Files
meubility-workbench/app/pipeline/osm_labeling.py
2026-07-01 23:29:51 +02:00

457 lines
17 KiB
Python

from __future__ import annotations
from datetime import datetime, timezone
import json
from pathlib import Path
import sqlite3
from typing import Callable
from sqlalchemy import func, select, text
from sqlalchemy.orm import Session
from app.models import Dataset, OsmFeature
from app.osm_classification import OSM_ROUTE_SCOPE_CLASSIFIER_VERSION, infer_osm_route_scope_from_tags
from app.osm_storage import (
dataset_metadata,
drop_osm_sidecar_route_scope_indexes,
ensure_osm_sidecar_schema,
features_are_sidecar,
rebuild_osm_sidecar_indexes,
sidecar_path,
writable_sidecar_connection,
)
from app.pipeline.state import (
STAGE_BUILD_INDEXES,
STAGE_LABEL_FEATURES,
dependency_hash,
finish_pipeline_run,
latest_completed_run,
start_pipeline_run,
)
OSM_LABEL_FEATURES_VERSION = OSM_ROUTE_SCOPE_CLASSIFIER_VERSION
MAIN_ROUTE_SCOPE_INDEX = "ix_osm_features_scope_bbox"
MAIN_INDEX_REBUILD_THRESHOLD = 10_000
SIDECAR_INDEX_REBUILD_THRESHOLD = 10_000
ProgressCallback = Callable[[str, str, int | None, int | None, dict[str, object] | None], None]
def relabel_osm_features(
session: Session,
*,
dataset_id: int | None = None,
chunk_size: int = 5000,
force: bool = False,
rebuild_indexes: bool = True,
progress_callback: ProgressCallback | None = None,
job_id: int | None = None,
) -> dict[str, object]:
datasets = _target_datasets(session, dataset_id)
result: dict[str, object] = {
"version": OSM_LABEL_FEATURES_VERSION,
"datasets": len(datasets),
"processed": 0,
"changed": 0,
"skipped": 0,
"missing": 0,
"index_rebuilds": 0,
"dataset_results": [],
}
_emit_progress(
progress_callback,
"osm_labeling_started",
f"Relabeling {len(datasets)} OSM dataset(s).",
0,
len(datasets),
{"dataset_id": dataset_id, "force": force, "version": OSM_LABEL_FEATURES_VERSION},
)
for index, dataset in enumerate(datasets, start=1):
dataset_result = relabel_osm_dataset(
session,
dataset,
chunk_size=chunk_size,
force=force,
rebuild_indexes=rebuild_indexes,
progress_callback=progress_callback,
job_id=job_id,
)
result["processed"] = int(result["processed"]) + int(dataset_result.get("processed", 0) or 0)
result["changed"] = int(result["changed"]) + int(dataset_result.get("changed", 0) or 0)
result["skipped"] = int(result["skipped"]) + (1 if dataset_result.get("status") == "skipped" else 0)
result["missing"] = int(result["missing"]) + (1 if dataset_result.get("status") == "missing_sidecar" else 0)
result["index_rebuilds"] = int(result["index_rebuilds"]) + int(dataset_result.get("index_rebuilds", 0) or 0)
result["dataset_results"].append(dataset_result) # type: ignore[union-attr]
_emit_progress(
progress_callback,
"osm_labeling_dataset_completed",
f"Relabeled {index}/{len(datasets)} OSM dataset(s).",
index,
len(datasets),
dataset_result,
)
_emit_progress(progress_callback, "osm_labeling_completed", "OSM feature relabeling completed.", len(datasets), len(datasets), result)
return result
def relabel_osm_dataset(
session: Session,
dataset: Dataset,
*,
chunk_size: int = 5000,
force: bool = False,
rebuild_indexes: bool = True,
progress_callback: ProgressCallback | None = None,
job_id: int | None = None,
) -> dict[str, object]:
dependency = _label_dependency(dataset)
dependency_hash_value = dependency_hash(dependency)
if not force and _dataset_label_is_current(session, dataset, dependency_hash_value):
return {
"dataset_id": dataset.id,
"source_id": dataset.source_id,
"status": "skipped",
"reason": "label_features dependency is current",
"dependency_hash": dependency_hash_value,
"version": OSM_LABEL_FEATURES_VERSION,
"processed": 0,
"changed": 0,
"index_rebuilds": 0,
}
run = start_pipeline_run(
session,
stage=STAGE_LABEL_FEATURES,
version=OSM_LABEL_FEATURES_VERSION,
dependency_hash_value=dependency_hash_value,
source_id=dataset.source_id,
dataset_id=dataset.id,
job_id=job_id,
inputs=dependency,
)
session.commit()
try:
if features_are_sidecar(dataset):
counts = _relabel_sidecar_dataset(dataset, chunk_size=chunk_size, rebuild_indexes=rebuild_indexes, progress_callback=progress_callback)
else:
counts = _relabel_main_dataset(session, dataset, chunk_size=chunk_size, rebuild_indexes=rebuild_indexes, progress_callback=progress_callback)
output = {
"dataset_id": dataset.id,
"source_id": dataset.source_id,
"status": "completed",
"dependency_hash": dependency_hash_value,
"version": OSM_LABEL_FEATURES_VERSION,
**counts,
}
_stamp_dataset_metadata(session, dataset, dependency_hash_value, output)
finish_pipeline_run(session, run, outputs=output)
session.commit()
return output
except FileNotFoundError as exc:
output = {
"dataset_id": dataset.id,
"source_id": dataset.source_id,
"status": "missing_sidecar",
"dependency_hash": dependency_hash_value,
"version": OSM_LABEL_FEATURES_VERSION,
"processed": 0,
"changed": 0,
"index_rebuilds": 0,
"error": str(exc),
}
finish_pipeline_run(session, run, status="failed", outputs=output, error=str(exc))
session.commit()
return output
except Exception as exc:
finish_pipeline_run(session, run, status="failed", error=str(exc))
session.commit()
raise
def _target_datasets(session: Session, dataset_id: int | None) -> list[Dataset]:
stmt = select(Dataset).where(Dataset.kind == "osm_geojson", Dataset.status == "imported")
if dataset_id is None:
stmt = stmt.where(Dataset.is_active.is_(True))
else:
stmt = stmt.where(Dataset.id == dataset_id)
return session.scalars(stmt.order_by(Dataset.source_id, Dataset.id)).all()
def _dataset_label_is_current(session: Session, dataset: Dataset, dependency_hash_value: str) -> bool:
metadata = dataset_metadata(dataset)
label_info = metadata.get("label_features")
metadata_current = (
isinstance(label_info, dict)
and label_info.get("version") == OSM_LABEL_FEATURES_VERSION
and label_info.get("dependency_hash") == dependency_hash_value
)
if not metadata_current:
return False
return (
latest_completed_run(
session,
stage=STAGE_LABEL_FEATURES,
version=OSM_LABEL_FEATURES_VERSION,
dependency_hash_value=dependency_hash_value,
source_id=dataset.source_id,
dataset_id=dataset.id,
)
is not None
)
def _relabel_sidecar_dataset(
dataset: Dataset,
*,
chunk_size: int,
rebuild_indexes: bool,
progress_callback: ProgressCallback | None,
) -> dict[str, int | str]:
path = sidecar_path(dataset)
if path is None or not path.exists():
raise FileNotFoundError(f"OSM sidecar does not exist: {path}")
with writable_sidecar_connection(dataset) as connection:
ensure_osm_sidecar_schema(connection)
total = int(connection.execute("SELECT COUNT(*) FROM osm_features").fetchone()[0] or 0)
should_rebuild_index = rebuild_indexes and total >= SIDECAR_INDEX_REBUILD_THRESHOLD
if should_rebuild_index:
drop_osm_sidecar_route_scope_indexes(connection)
connection.commit()
processed = 0
changed = 0
last_id = 0
try:
while True:
rows = connection.execute(
"""
SELECT id, mode, ref, name, network, tags_json, route_scope
FROM osm_features
WHERE id > ?
ORDER BY id
LIMIT ?
""",
(last_id, max(1, int(chunk_size))),
).fetchall()
if not rows:
break
updates: list[tuple[str | None, int]] = []
for row in rows:
last_id = int(row["id"])
new_scope = _classified_scope(row["mode"], row["ref"], row["name"], row["network"], row["tags_json"])
if _normalize_scope(row["route_scope"]) != new_scope:
updates.append((new_scope, last_id))
if updates:
connection.executemany("UPDATE osm_features SET route_scope = ? WHERE id = ?", updates)
processed += len(rows)
changed += len(updates)
connection.commit()
_emit_progress(
progress_callback,
"osm_labeling_batch",
f"Relabeled {processed}/{total} OSM sidecar features.",
processed,
total,
{"dataset_id": dataset.id, "changed": changed, "storage": "sidecar"},
)
finally:
index_rebuilds = 0
if should_rebuild_index:
rebuild_osm_sidecar_indexes(connection)
connection.commit()
index_rebuilds = 1
_record_sidecar_index_build(connection, dataset, path)
_record_sidecar_label(connection, dataset, processed=processed, changed=changed)
connection.commit()
return {"storage": "sidecar", "processed": processed, "changed": changed, "index_rebuilds": index_rebuilds}
def _relabel_main_dataset(
session: Session,
dataset: Dataset,
*,
chunk_size: int,
rebuild_indexes: bool,
progress_callback: ProgressCallback | None,
) -> dict[str, int | str]:
total = int(session.scalar(select(func.count()).select_from(OsmFeature).where(OsmFeature.dataset_id == dataset.id)) or 0)
should_rebuild_index = rebuild_indexes and total >= MAIN_INDEX_REBUILD_THRESHOLD
index_rebuilds = 0
if should_rebuild_index:
session.execute(text(f"DROP INDEX IF EXISTS {MAIN_ROUTE_SCOPE_INDEX}"))
session.commit()
processed = 0
changed = 0
last_id = 0
try:
while True:
rows = session.scalars(
select(OsmFeature)
.where(OsmFeature.dataset_id == dataset.id, OsmFeature.id > last_id)
.order_by(OsmFeature.id)
.limit(max(1, int(chunk_size)))
).all()
if not rows:
break
updates: list[dict[str, object]] = []
for feature in rows:
last_id = int(feature.id)
new_scope = _classified_scope(feature.mode, feature.ref, feature.name, feature.network, feature.tags_json)
if _normalize_scope(feature.route_scope) != new_scope:
updates.append({"id": feature.id, "route_scope": new_scope})
if updates:
session.bulk_update_mappings(OsmFeature, updates)
processed += len(rows)
changed += len(updates)
session.commit()
_emit_progress(
progress_callback,
"osm_labeling_batch",
f"Relabeled {processed}/{total} main-table OSM features.",
processed,
total,
{"dataset_id": dataset.id, "changed": changed, "storage": "main"},
)
finally:
if should_rebuild_index:
session.execute(
text(
"CREATE INDEX IF NOT EXISTS ix_osm_features_scope_bbox "
"ON osm_features (dataset_id, kind, mode, route_scope, min_lon, max_lon, min_lat, max_lat)"
)
)
session.commit()
index_rebuilds = 1
_record_main_index_build(session, dataset)
return {"storage": "main", "processed": processed, "changed": changed, "index_rebuilds": index_rebuilds}
def _classified_scope(mode: object, ref: object, name: object, network: object, tags_json: object) -> str | None:
return _normalize_scope(
infer_osm_route_scope_from_tags(
None if mode is None else str(mode),
None if ref is None else str(ref),
None if name is None else str(name),
None if network is None else str(network),
None if tags_json is None else str(tags_json),
)
)
def _normalize_scope(value: object) -> str | None:
text_value = str(value or "").strip()
return text_value or None
def _label_dependency(dataset: Dataset) -> dict[str, object]:
metadata = dataset_metadata(dataset)
storage = metadata.get("osm_storage") if isinstance(metadata, dict) else None
path = sidecar_path(dataset)
path_fingerprint: dict[str, object] | None = None
if path is not None:
resolved = Path(path)
if resolved.exists():
path_fingerprint = {"path": str(resolved), "exists": True}
else:
path_fingerprint = {"path": str(resolved), "missing": True}
return {
"dataset_id": dataset.id,
"source_id": dataset.source_id,
"kind": dataset.kind,
"dataset_sha256": dataset.sha256,
"storage": storage,
"sidecar": path_fingerprint,
"classifier_version": OSM_LABEL_FEATURES_VERSION,
}
def _stamp_dataset_metadata(session: Session, dataset: Dataset, dependency_hash_value: str, output: dict[str, object]) -> None:
refreshed = session.get(Dataset, dataset.id)
if refreshed is None:
return
metadata = dataset_metadata(refreshed)
metadata["label_features"] = {
"stage": STAGE_LABEL_FEATURES,
"version": OSM_LABEL_FEATURES_VERSION,
"dependency_hash": dependency_hash_value,
"labeled_at": datetime.now(timezone.utc).isoformat(),
"processed": output.get("processed", 0),
"changed": output.get("changed", 0),
"storage": output.get("storage"),
}
refreshed.metadata_json = json.dumps(metadata, indent=2)
session.flush()
def _record_sidecar_label(connection: sqlite3.Connection, dataset: Dataset, *, processed: int, changed: int) -> None:
connection.execute("CREATE TABLE IF NOT EXISTS pipeline_metadata (key TEXT PRIMARY KEY, value TEXT NOT NULL)")
connection.execute(
"INSERT OR REPLACE INTO pipeline_metadata (key, value) VALUES (?, ?)",
(
"label_features",
json.dumps(
{
"stage": STAGE_LABEL_FEATURES,
"version": OSM_LABEL_FEATURES_VERSION,
"dataset_id": dataset.id,
"processed": processed,
"changed": changed,
"updated_at": datetime.now(timezone.utc).isoformat(),
},
sort_keys=True,
separators=(",", ":"),
),
),
)
def _record_sidecar_index_build(connection: sqlite3.Connection, dataset: Dataset, path: Path) -> None:
connection.execute("CREATE TABLE IF NOT EXISTS pipeline_metadata (key TEXT PRIMARY KEY, value TEXT NOT NULL)")
connection.execute(
"INSERT OR REPLACE INTO pipeline_metadata (key, value) VALUES (?, ?)",
(
"build_indexes:route_scope",
json.dumps(
{
"stage": STAGE_BUILD_INDEXES,
"version": "osm_sidecar_indexes_v1",
"dataset_id": dataset.id,
"path": str(path),
"updated_at": datetime.now(timezone.utc).isoformat(),
},
sort_keys=True,
separators=(",", ":"),
),
),
)
def _record_main_index_build(session: Session, dataset: Dataset) -> None:
dependency = {
"dataset_id": dataset.id,
"index": MAIN_ROUTE_SCOPE_INDEX,
"version": "osm_main_indexes_v1",
}
run = start_pipeline_run(
session,
stage=STAGE_BUILD_INDEXES,
version="osm_main_indexes_v1",
dependency_hash_value=dependency_hash(dependency),
source_id=dataset.source_id,
dataset_id=dataset.id,
inputs=dependency,
)
finish_pipeline_run(session, run, outputs={"index": MAIN_ROUTE_SCOPE_INDEX})
session.commit()
def _emit_progress(
callback: ProgressCallback | None,
event_type: str,
message: str,
current: int | None,
total: int | None,
metadata: dict[str, object] | None,
) -> None:
if callback is not None:
callback(event_type, message, current, total, metadata)