from __future__ import annotations from datetime import datetime, timezone import json from pathlib import Path import sqlite3 from typing import Callable from sqlalchemy import func, select, text from sqlalchemy.orm import Session from app.models import Dataset, OsmFeature from app.osm_classification import OSM_ROUTE_SCOPE_CLASSIFIER_VERSION, infer_osm_route_scope_from_tags from app.osm_storage import ( dataset_metadata, drop_osm_sidecar_route_scope_indexes, ensure_osm_sidecar_schema, features_are_sidecar, rebuild_osm_sidecar_indexes, sidecar_path, writable_sidecar_connection, ) from app.pipeline.state import ( STAGE_BUILD_INDEXES, STAGE_LABEL_FEATURES, dependency_hash, finish_pipeline_run, latest_completed_run, start_pipeline_run, ) OSM_LABEL_FEATURES_VERSION = OSM_ROUTE_SCOPE_CLASSIFIER_VERSION MAIN_ROUTE_SCOPE_INDEX = "ix_osm_features_scope_bbox" MAIN_INDEX_REBUILD_THRESHOLD = 10_000 SIDECAR_INDEX_REBUILD_THRESHOLD = 10_000 ProgressCallback = Callable[[str, str, int | None, int | None, dict[str, object] | None], None] def relabel_osm_features( session: Session, *, dataset_id: int | None = None, chunk_size: int = 5000, force: bool = False, rebuild_indexes: bool = True, progress_callback: ProgressCallback | None = None, job_id: int | None = None, ) -> dict[str, object]: datasets = _target_datasets(session, dataset_id) result: dict[str, object] = { "version": OSM_LABEL_FEATURES_VERSION, "datasets": len(datasets), "processed": 0, "changed": 0, "skipped": 0, "missing": 0, "index_rebuilds": 0, "dataset_results": [], } _emit_progress( progress_callback, "osm_labeling_started", f"Relabeling {len(datasets)} OSM dataset(s).", 0, len(datasets), {"dataset_id": dataset_id, "force": force, "version": OSM_LABEL_FEATURES_VERSION}, ) for index, dataset in enumerate(datasets, start=1): dataset_result = relabel_osm_dataset( session, dataset, chunk_size=chunk_size, force=force, rebuild_indexes=rebuild_indexes, progress_callback=progress_callback, job_id=job_id, ) result["processed"] = int(result["processed"]) + int(dataset_result.get("processed", 0) or 0) result["changed"] = int(result["changed"]) + int(dataset_result.get("changed", 0) or 0) result["skipped"] = int(result["skipped"]) + (1 if dataset_result.get("status") == "skipped" else 0) result["missing"] = int(result["missing"]) + (1 if dataset_result.get("status") == "missing_sidecar" else 0) result["index_rebuilds"] = int(result["index_rebuilds"]) + int(dataset_result.get("index_rebuilds", 0) or 0) result["dataset_results"].append(dataset_result) # type: ignore[union-attr] _emit_progress( progress_callback, "osm_labeling_dataset_completed", f"Relabeled {index}/{len(datasets)} OSM dataset(s).", index, len(datasets), dataset_result, ) _emit_progress(progress_callback, "osm_labeling_completed", "OSM feature relabeling completed.", len(datasets), len(datasets), result) return result def relabel_osm_dataset( session: Session, dataset: Dataset, *, chunk_size: int = 5000, force: bool = False, rebuild_indexes: bool = True, progress_callback: ProgressCallback | None = None, job_id: int | None = None, ) -> dict[str, object]: dependency = _label_dependency(dataset) dependency_hash_value = dependency_hash(dependency) if not force and _dataset_label_is_current(session, dataset, dependency_hash_value): return { "dataset_id": dataset.id, "source_id": dataset.source_id, "status": "skipped", "reason": "label_features dependency is current", "dependency_hash": dependency_hash_value, "version": OSM_LABEL_FEATURES_VERSION, "processed": 0, "changed": 0, "index_rebuilds": 0, } run = start_pipeline_run( session, stage=STAGE_LABEL_FEATURES, version=OSM_LABEL_FEATURES_VERSION, dependency_hash_value=dependency_hash_value, source_id=dataset.source_id, dataset_id=dataset.id, job_id=job_id, inputs=dependency, ) session.commit() try: if features_are_sidecar(dataset): counts = _relabel_sidecar_dataset(dataset, chunk_size=chunk_size, rebuild_indexes=rebuild_indexes, progress_callback=progress_callback) else: counts = _relabel_main_dataset(session, dataset, chunk_size=chunk_size, rebuild_indexes=rebuild_indexes, progress_callback=progress_callback) output = { "dataset_id": dataset.id, "source_id": dataset.source_id, "status": "completed", "dependency_hash": dependency_hash_value, "version": OSM_LABEL_FEATURES_VERSION, **counts, } _stamp_dataset_metadata(session, dataset, dependency_hash_value, output) finish_pipeline_run(session, run, outputs=output) session.commit() return output except FileNotFoundError as exc: output = { "dataset_id": dataset.id, "source_id": dataset.source_id, "status": "missing_sidecar", "dependency_hash": dependency_hash_value, "version": OSM_LABEL_FEATURES_VERSION, "processed": 0, "changed": 0, "index_rebuilds": 0, "error": str(exc), } finish_pipeline_run(session, run, status="failed", outputs=output, error=str(exc)) session.commit() return output except Exception as exc: finish_pipeline_run(session, run, status="failed", error=str(exc)) session.commit() raise def _target_datasets(session: Session, dataset_id: int | None) -> list[Dataset]: stmt = select(Dataset).where(Dataset.kind == "osm_geojson", Dataset.status == "imported") if dataset_id is None: stmt = stmt.where(Dataset.is_active.is_(True)) else: stmt = stmt.where(Dataset.id == dataset_id) return session.scalars(stmt.order_by(Dataset.source_id, Dataset.id)).all() def _dataset_label_is_current(session: Session, dataset: Dataset, dependency_hash_value: str) -> bool: metadata = dataset_metadata(dataset) label_info = metadata.get("label_features") metadata_current = ( isinstance(label_info, dict) and label_info.get("version") == OSM_LABEL_FEATURES_VERSION and label_info.get("dependency_hash") == dependency_hash_value ) if not metadata_current: return False return ( latest_completed_run( session, stage=STAGE_LABEL_FEATURES, version=OSM_LABEL_FEATURES_VERSION, dependency_hash_value=dependency_hash_value, source_id=dataset.source_id, dataset_id=dataset.id, ) is not None ) def _relabel_sidecar_dataset( dataset: Dataset, *, chunk_size: int, rebuild_indexes: bool, progress_callback: ProgressCallback | None, ) -> dict[str, int | str]: path = sidecar_path(dataset) if path is None or not path.exists(): raise FileNotFoundError(f"OSM sidecar does not exist: {path}") with writable_sidecar_connection(dataset) as connection: ensure_osm_sidecar_schema(connection) total = int(connection.execute("SELECT COUNT(*) FROM osm_features").fetchone()[0] or 0) should_rebuild_index = rebuild_indexes and total >= SIDECAR_INDEX_REBUILD_THRESHOLD if should_rebuild_index: drop_osm_sidecar_route_scope_indexes(connection) connection.commit() processed = 0 changed = 0 last_id = 0 try: while True: rows = connection.execute( """ SELECT id, mode, ref, name, network, tags_json, route_scope FROM osm_features WHERE id > ? ORDER BY id LIMIT ? """, (last_id, max(1, int(chunk_size))), ).fetchall() if not rows: break updates: list[tuple[str | None, int]] = [] for row in rows: last_id = int(row["id"]) new_scope = _classified_scope(row["mode"], row["ref"], row["name"], row["network"], row["tags_json"]) if _normalize_scope(row["route_scope"]) != new_scope: updates.append((new_scope, last_id)) if updates: connection.executemany("UPDATE osm_features SET route_scope = ? WHERE id = ?", updates) processed += len(rows) changed += len(updates) connection.commit() _emit_progress( progress_callback, "osm_labeling_batch", f"Relabeled {processed}/{total} OSM sidecar features.", processed, total, {"dataset_id": dataset.id, "changed": changed, "storage": "sidecar"}, ) finally: index_rebuilds = 0 if should_rebuild_index: rebuild_osm_sidecar_indexes(connection) connection.commit() index_rebuilds = 1 _record_sidecar_index_build(connection, dataset, path) _record_sidecar_label(connection, dataset, processed=processed, changed=changed) connection.commit() return {"storage": "sidecar", "processed": processed, "changed": changed, "index_rebuilds": index_rebuilds} def _relabel_main_dataset( session: Session, dataset: Dataset, *, chunk_size: int, rebuild_indexes: bool, progress_callback: ProgressCallback | None, ) -> dict[str, int | str]: total = int(session.scalar(select(func.count()).select_from(OsmFeature).where(OsmFeature.dataset_id == dataset.id)) or 0) should_rebuild_index = rebuild_indexes and total >= MAIN_INDEX_REBUILD_THRESHOLD index_rebuilds = 0 if should_rebuild_index: session.execute(text(f"DROP INDEX IF EXISTS {MAIN_ROUTE_SCOPE_INDEX}")) session.commit() processed = 0 changed = 0 last_id = 0 try: while True: rows = session.scalars( select(OsmFeature) .where(OsmFeature.dataset_id == dataset.id, OsmFeature.id > last_id) .order_by(OsmFeature.id) .limit(max(1, int(chunk_size))) ).all() if not rows: break updates: list[dict[str, object]] = [] for feature in rows: last_id = int(feature.id) new_scope = _classified_scope(feature.mode, feature.ref, feature.name, feature.network, feature.tags_json) if _normalize_scope(feature.route_scope) != new_scope: updates.append({"id": feature.id, "route_scope": new_scope}) if updates: session.bulk_update_mappings(OsmFeature, updates) processed += len(rows) changed += len(updates) session.commit() _emit_progress( progress_callback, "osm_labeling_batch", f"Relabeled {processed}/{total} main-table OSM features.", processed, total, {"dataset_id": dataset.id, "changed": changed, "storage": "main"}, ) finally: if should_rebuild_index: session.execute( text( "CREATE INDEX IF NOT EXISTS ix_osm_features_scope_bbox " "ON osm_features (dataset_id, kind, mode, route_scope, min_lon, max_lon, min_lat, max_lat)" ) ) session.commit() index_rebuilds = 1 _record_main_index_build(session, dataset) return {"storage": "main", "processed": processed, "changed": changed, "index_rebuilds": index_rebuilds} def _classified_scope(mode: object, ref: object, name: object, network: object, tags_json: object) -> str | None: return _normalize_scope( infer_osm_route_scope_from_tags( None if mode is None else str(mode), None if ref is None else str(ref), None if name is None else str(name), None if network is None else str(network), None if tags_json is None else str(tags_json), ) ) def _normalize_scope(value: object) -> str | None: text_value = str(value or "").strip() return text_value or None def _label_dependency(dataset: Dataset) -> dict[str, object]: metadata = dataset_metadata(dataset) storage = metadata.get("osm_storage") if isinstance(metadata, dict) else None path = sidecar_path(dataset) path_fingerprint: dict[str, object] | None = None if path is not None: resolved = Path(path) if resolved.exists(): path_fingerprint = {"path": str(resolved), "exists": True} else: path_fingerprint = {"path": str(resolved), "missing": True} return { "dataset_id": dataset.id, "source_id": dataset.source_id, "kind": dataset.kind, "dataset_sha256": dataset.sha256, "storage": storage, "sidecar": path_fingerprint, "classifier_version": OSM_LABEL_FEATURES_VERSION, } def _stamp_dataset_metadata(session: Session, dataset: Dataset, dependency_hash_value: str, output: dict[str, object]) -> None: refreshed = session.get(Dataset, dataset.id) if refreshed is None: return metadata = dataset_metadata(refreshed) metadata["label_features"] = { "stage": STAGE_LABEL_FEATURES, "version": OSM_LABEL_FEATURES_VERSION, "dependency_hash": dependency_hash_value, "labeled_at": datetime.now(timezone.utc).isoformat(), "processed": output.get("processed", 0), "changed": output.get("changed", 0), "storage": output.get("storage"), } refreshed.metadata_json = json.dumps(metadata, indent=2) session.flush() def _record_sidecar_label(connection: sqlite3.Connection, dataset: Dataset, *, processed: int, changed: int) -> None: connection.execute("CREATE TABLE IF NOT EXISTS pipeline_metadata (key TEXT PRIMARY KEY, value TEXT NOT NULL)") connection.execute( "INSERT OR REPLACE INTO pipeline_metadata (key, value) VALUES (?, ?)", ( "label_features", json.dumps( { "stage": STAGE_LABEL_FEATURES, "version": OSM_LABEL_FEATURES_VERSION, "dataset_id": dataset.id, "processed": processed, "changed": changed, "updated_at": datetime.now(timezone.utc).isoformat(), }, sort_keys=True, separators=(",", ":"), ), ), ) def _record_sidecar_index_build(connection: sqlite3.Connection, dataset: Dataset, path: Path) -> None: connection.execute("CREATE TABLE IF NOT EXISTS pipeline_metadata (key TEXT PRIMARY KEY, value TEXT NOT NULL)") connection.execute( "INSERT OR REPLACE INTO pipeline_metadata (key, value) VALUES (?, ?)", ( "build_indexes:route_scope", json.dumps( { "stage": STAGE_BUILD_INDEXES, "version": "osm_sidecar_indexes_v1", "dataset_id": dataset.id, "path": str(path), "updated_at": datetime.now(timezone.utc).isoformat(), }, sort_keys=True, separators=(",", ":"), ), ), ) def _record_main_index_build(session: Session, dataset: Dataset) -> None: dependency = { "dataset_id": dataset.id, "index": MAIN_ROUTE_SCOPE_INDEX, "version": "osm_main_indexes_v1", } run = start_pipeline_run( session, stage=STAGE_BUILD_INDEXES, version="osm_main_indexes_v1", dependency_hash_value=dependency_hash(dependency), source_id=dataset.source_id, dataset_id=dataset.id, inputs=dependency, ) finish_pipeline_run(session, run, outputs={"index": MAIN_ROUTE_SCOPE_INDEX}) session.commit() def _emit_progress( callback: ProgressCallback | None, event_type: str, message: str, current: int | None, total: int | None, metadata: dict[str, object] | None, ) -> None: if callback is not None: callback(event_type, message, current, total, metadata)