457 lines
17 KiB
Python
457 lines
17 KiB
Python
from __future__ import annotations
|
|
|
|
from datetime import datetime, timezone
|
|
import json
|
|
from pathlib import Path
|
|
import sqlite3
|
|
from typing import Callable
|
|
|
|
from sqlalchemy import func, select, text
|
|
from sqlalchemy.orm import Session
|
|
|
|
from app.models import Dataset, OsmFeature
|
|
from app.osm_classification import OSM_ROUTE_SCOPE_CLASSIFIER_VERSION, infer_osm_route_scope_from_tags
|
|
from app.osm_storage import (
|
|
dataset_metadata,
|
|
drop_osm_sidecar_route_scope_indexes,
|
|
ensure_osm_sidecar_schema,
|
|
features_are_sidecar,
|
|
rebuild_osm_sidecar_indexes,
|
|
sidecar_path,
|
|
writable_sidecar_connection,
|
|
)
|
|
from app.pipeline.state import (
|
|
STAGE_BUILD_INDEXES,
|
|
STAGE_LABEL_FEATURES,
|
|
dependency_hash,
|
|
finish_pipeline_run,
|
|
latest_completed_run,
|
|
start_pipeline_run,
|
|
)
|
|
|
|
|
|
# Version recorded for label_features runs and metadata stamps. It is the
# classifier version, so a classifier bump makes previously stored labels stale.
OSM_LABEL_FEATURES_VERSION = OSM_ROUTE_SCOPE_CLASSIFIER_VERSION

# Main-table index that covers route_scope; dropped and recreated around large relabels.
MAIN_ROUTE_SCOPE_INDEX = "ix_osm_features_scope_bbox"

# Feature counts at or above which the route-scope indexes are dropped before
# relabeling and rebuilt afterwards (main table and sidecar respectively).
MAIN_INDEX_REBUILD_THRESHOLD = 10000
SIDECAR_INDEX_REBUILD_THRESHOLD = 10000

# Progress hook signature: (event_type, message, current, total, metadata) -> None.
ProgressCallback = Callable[[str, str, int | None, int | None, dict[str, object] | None], None]
|
|
|
|
|
|
def relabel_osm_features(
    session: Session,
    *,
    dataset_id: int | None = None,
    chunk_size: int = 5000,
    force: bool = False,
    rebuild_indexes: bool = True,
    progress_callback: ProgressCallback | None = None,
    job_id: int | None = None,
) -> dict[str, object]:
    """Relabel route scopes for every targeted OSM dataset and aggregate the outcome.

    Targets come from ``_target_datasets``. If *dataset_id* is None, every
    active, imported ``osm_geojson`` dataset is relabeled. Otherwise only that
    dataset is relabeled, and it must still be an imported ``osm_geojson``
    dataset. Each dataset is handled by ``relabel_osm_dataset``.

    Progress events: ``osm_labeling_started`` once, then
    ``osm_labeling_dataset_completed`` after each dataset, then
    ``osm_labeling_completed`` with the summary dict.

    Returns:
        A summary dict with the version, dataset count, summed
        processed/changed/index_rebuilds counts, the number of skipped datasets
        and of datasets whose sidecar was missing, and the per-dataset results.
    """
    targets = _target_datasets(session, dataset_id)
    target_count = len(targets)

    _emit_progress(
        progress_callback,
        "osm_labeling_started",
        f"Relabeling {target_count} OSM dataset(s).",
        0,
        target_count,
        {"dataset_id": dataset_id, "force": force, "version": OSM_LABEL_FEATURES_VERSION},
    )

    processed_total = 0
    changed_total = 0
    skipped_total = 0
    missing_total = 0
    rebuild_total = 0
    per_dataset: list[dict[str, object]] = []

    for position, target in enumerate(targets, start=1):
        outcome = relabel_osm_dataset(
            session,
            target,
            chunk_size=chunk_size,
            force=force,
            rebuild_indexes=rebuild_indexes,
            progress_callback=progress_callback,
            job_id=job_id,
        )
        processed_total += int(outcome.get("processed", 0) or 0)
        changed_total += int(outcome.get("changed", 0) or 0)
        status = outcome.get("status")
        if status == "skipped":
            skipped_total += 1
        if status == "missing_sidecar":
            missing_total += 1
        rebuild_total += int(outcome.get("index_rebuilds", 0) or 0)
        per_dataset.append(outcome)

        _emit_progress(
            progress_callback,
            "osm_labeling_dataset_completed",
            f"Relabeled {position}/{target_count} OSM dataset(s).",
            position,
            target_count,
            outcome,
        )

    summary: dict[str, object] = {
        "version": OSM_LABEL_FEATURES_VERSION,
        "datasets": target_count,
        "processed": processed_total,
        "changed": changed_total,
        "skipped": skipped_total,
        "missing": missing_total,
        "index_rebuilds": rebuild_total,
        "dataset_results": per_dataset,
    }
    _emit_progress(
        progress_callback,
        "osm_labeling_completed",
        "OSM feature relabeling completed.",
        target_count,
        target_count,
        summary,
    )
    return summary
|
|
|
|
|
|
def relabel_osm_dataset(
    session: Session,
    dataset: Dataset,
    *,
    chunk_size: int = 5000,
    force: bool = False,
    rebuild_indexes: bool = True,
    progress_callback: ProgressCallback | None = None,
    job_id: int | None = None,
) -> dict[str, object]:
    """Relabel one OSM dataset's route scopes and record the attempt as a pipeline run.

    Unless *force* is set, the dataset is skipped when its ``label_features``
    metadata and latest completed run already match the current dependency
    hash. Otherwise a STAGE_LABEL_FEATURES run is started and committed. The
    features are then relabeled in the sidecar or in the main table, the
    dataset metadata is stamped, and the run is finished.

    Returns:
        A result dict whose ``status`` is ``"skipped"``, ``"completed"`` or
        ``"missing_sidecar"``. It also carries ``processed``, ``changed`` and
        ``index_rebuilds`` counts.

    Raises:
        Any non-``FileNotFoundError`` exception raised while relabeling. It is
        re-raised after the run has been marked failed.
    """
    dependency = _label_dependency(dataset)
    dependency_hash_value = dependency_hash(dependency)
    if not force and _dataset_label_is_current(session, dataset, dependency_hash_value):
        return {
            "dataset_id": dataset.id,
            "source_id": dataset.source_id,
            "status": "skipped",
            "reason": "label_features dependency is current",
            "dependency_hash": dependency_hash_value,
            "version": OSM_LABEL_FEATURES_VERSION,
            "processed": 0,
            "changed": 0,
            "index_rebuilds": 0,
        }

    run = start_pipeline_run(
        session,
        stage=STAGE_LABEL_FEATURES,
        version=OSM_LABEL_FEATURES_VERSION,
        dependency_hash_value=dependency_hash_value,
        source_id=dataset.source_id,
        dataset_id=dataset.id,
        job_id=job_id,
        inputs=dependency,
    )
    # Commit the "started" run up front so it survives the rollbacks in the
    # failure branches below.
    session.commit()
    try:
        if features_are_sidecar(dataset):
            counts = _relabel_sidecar_dataset(
                dataset,
                chunk_size=chunk_size,
                rebuild_indexes=rebuild_indexes,
                progress_callback=progress_callback,
            )
        else:
            counts = _relabel_main_dataset(
                session,
                dataset,
                chunk_size=chunk_size,
                rebuild_indexes=rebuild_indexes,
                progress_callback=progress_callback,
            )
        output = {
            "dataset_id": dataset.id,
            "source_id": dataset.source_id,
            "status": "completed",
            "dependency_hash": dependency_hash_value,
            "version": OSM_LABEL_FEATURES_VERSION,
            **counts,
        }
        _stamp_dataset_metadata(session, dataset, dependency_hash_value, output)
        finish_pipeline_run(session, run, outputs=output)
        session.commit()
        return output
    except FileNotFoundError as exc:
        # Discard anything the failed attempt left pending before recording the failure.
        session.rollback()
        output = {
            "dataset_id": dataset.id,
            "source_id": dataset.source_id,
            "status": "missing_sidecar",
            "dependency_hash": dependency_hash_value,
            "version": OSM_LABEL_FEATURES_VERSION,
            "processed": 0,
            "changed": 0,
            "index_rebuilds": 0,
            "error": str(exc),
        }
        finish_pipeline_run(session, run, status="failed", outputs=output, error=str(exc))
        session.commit()
        return output
    except Exception as exc:
        # A failed flush/commit leaves the session unusable until it is rolled back.
        # Without this, recording the failure would itself raise (PendingRollbackError)
        # and hide `exc`. The rollback also drops any uncommitted metadata stamp.
        session.rollback()
        finish_pipeline_run(session, run, status="failed", error=str(exc))
        session.commit()
        raise
|
|
|
|
|
|
def _target_datasets(session: Session, dataset_id: int | None) -> list[Dataset]:
    """Return imported ``osm_geojson`` datasets to relabel, ordered by source_id then id.

    With no *dataset_id*, all active datasets are returned. Otherwise only the
    dataset with that id is returned, and only if it passes the same kind and
    status filter.
    """
    selector = Dataset.is_active.is_(True) if dataset_id is None else Dataset.id == dataset_id
    query = (
        select(Dataset)
        .where(Dataset.kind == "osm_geojson", Dataset.status == "imported")
        .where(selector)
        .order_by(Dataset.source_id, Dataset.id)
    )
    return session.scalars(query).all()
|
|
|
|
|
|
def _dataset_label_is_current(session: Session, dataset: Dataset, dependency_hash_value: str) -> bool:
    """Return True when the dataset's labels match the current version and dependency hash.

    Two conditions must both hold. The ``label_features`` entry in the dataset
    metadata must carry this version and hash. There must also be a completed
    STAGE_LABEL_FEATURES pipeline run recorded for them.
    """
    stamp = dataset_metadata(dataset).get("label_features")
    if not isinstance(stamp, dict):
        return False
    if stamp.get("version") != OSM_LABEL_FEATURES_VERSION:
        return False
    if stamp.get("dependency_hash") != dependency_hash_value:
        return False
    previous_run = latest_completed_run(
        session,
        stage=STAGE_LABEL_FEATURES,
        version=OSM_LABEL_FEATURES_VERSION,
        dependency_hash_value=dependency_hash_value,
        source_id=dataset.source_id,
        dataset_id=dataset.id,
    )
    return previous_run is not None
|
|
|
|
|
|
def _relabel_sidecar_dataset(
    dataset: Dataset,
    *,
    chunk_size: int,
    rebuild_indexes: bool,
    progress_callback: ProgressCallback | None,
) -> dict[str, int | str]:
    """Recompute ``route_scope`` for every feature in the dataset's SQLite sidecar.

    Rows are read in id order, in batches of at least one row. Only rows whose
    normalized scope changes are updated, and each batch is committed
    separately. When *rebuild_indexes* is set and the table holds at least
    SIDECAR_INDEX_REBUILD_THRESHOLD rows, the route-scope indexes are dropped
    first and are always rebuilt afterwards, even on failure. The sidecar's
    ``label_features`` metadata entry is written only after every batch has
    been committed.

    Assumes the connection yields mapping-style rows (``row["id"]``), e.g.
    ``sqlite3.Row``. TODO confirm in writable_sidecar_connection.

    Returns:
        ``{"storage": "sidecar", "processed", "changed", "index_rebuilds"}``.

    Raises:
        FileNotFoundError: the dataset has no sidecar path or the file is missing.
    """
    path = sidecar_path(dataset)
    if path is None or not path.exists():
        raise FileNotFoundError(f"OSM sidecar does not exist: {path}")
    batch_limit = max(1, int(chunk_size))

    with writable_sidecar_connection(dataset) as connection:
        ensure_osm_sidecar_schema(connection)
        total = int(connection.execute("SELECT COUNT(*) FROM osm_features").fetchone()[0] or 0)
        should_rebuild_index = rebuild_indexes and total >= SIDECAR_INDEX_REBUILD_THRESHOLD
        if should_rebuild_index:
            # Updating without the route-scope indexes is cheaper; they are restored in `finally`.
            drop_osm_sidecar_route_scope_indexes(connection)
            connection.commit()

        processed = 0
        changed = 0
        last_id = 0
        index_rebuilds = 0
        try:
            while True:
                rows = connection.execute(
                    """
                    SELECT id, mode, ref, name, network, tags_json, route_scope
                    FROM osm_features
                    WHERE id > ?
                    ORDER BY id
                    LIMIT ?
                    """,
                    (last_id, batch_limit),
                ).fetchall()
                if not rows:
                    break
                updates: list[tuple[str | None, int]] = []
                for row in rows:
                    last_id = int(row["id"])
                    new_scope = _classified_scope(row["mode"], row["ref"], row["name"], row["network"], row["tags_json"])
                    if _normalize_scope(row["route_scope"]) != new_scope:
                        updates.append((new_scope, last_id))
                if updates:
                    connection.executemany("UPDATE osm_features SET route_scope = ? WHERE id = ?", updates)
                processed += len(rows)
                changed += len(updates)
                connection.commit()
                _emit_progress(
                    progress_callback,
                    "osm_labeling_batch",
                    f"Relabeled {processed}/{total} OSM sidecar features.",
                    processed,
                    total,
                    {"dataset_id": dataset.id, "changed": changed, "storage": "sidecar"},
                )
        except BaseException:
            # Discard a partially applied, uncommitted batch. Otherwise the commit
            # after the index rebuild below would persist it.
            connection.rollback()
            raise
        finally:
            if should_rebuild_index:
                rebuild_osm_sidecar_indexes(connection)
                connection.commit()
                index_rebuilds = 1
                _record_sidecar_index_build(connection, dataset, path)
                connection.commit()

        # Reached only on success, so a failed run never stamps the sidecar as labeled.
        _record_sidecar_label(connection, dataset, processed=processed, changed=changed)
        connection.commit()
        return {"storage": "sidecar", "processed": processed, "changed": changed, "index_rebuilds": index_rebuilds}
|
|
|
|
|
|
def _relabel_main_dataset(
    session: Session,
    dataset: Dataset,
    *,
    chunk_size: int,
    rebuild_indexes: bool,
    progress_callback: ProgressCallback | None,
) -> dict[str, int | str]:
    """Recompute ``route_scope`` for the dataset's rows in the main ``osm_features`` table.

    Rows are read in id order, in batches of at least one row, selecting only
    the columns the classifier needs. Changed scopes are written with
    ``bulk_update_mappings``, and every batch is committed. When
    *rebuild_indexes* is set and the dataset has at least
    MAIN_INDEX_REBUILD_THRESHOLD rows, MAIN_ROUTE_SCOPE_INDEX is dropped first.
    It is always recreated afterwards, even on failure, and then recorded as
    a build_indexes run.

    Returns:
        ``{"storage": "main", "processed", "changed", "index_rebuilds"}``.
    """
    # Read the id once. Each per-batch commit expires ORM instances, so touching
    # `dataset.id` inside the loop would reload the Dataset row every batch.
    dataset_id = dataset.id
    total = int(
        session.scalar(select(func.count()).select_from(OsmFeature).where(OsmFeature.dataset_id == dataset_id)) or 0
    )
    should_rebuild_index = rebuild_indexes and total >= MAIN_INDEX_REBUILD_THRESHOLD
    index_rebuilds = 0
    if should_rebuild_index:
        session.execute(text(f"DROP INDEX IF EXISTS {MAIN_ROUTE_SCOPE_INDEX}"))
        session.commit()

    batch_limit = max(1, int(chunk_size))
    processed = 0
    changed = 0
    last_id = 0
    try:
        while True:
            rows = session.execute(
                select(
                    OsmFeature.id,
                    OsmFeature.mode,
                    OsmFeature.ref,
                    OsmFeature.name,
                    OsmFeature.network,
                    OsmFeature.tags_json,
                    OsmFeature.route_scope,
                )
                .where(OsmFeature.dataset_id == dataset_id, OsmFeature.id > last_id)
                .order_by(OsmFeature.id)
                .limit(batch_limit)
            ).all()
            if not rows:
                break
            updates: list[dict[str, object]] = []
            for feature_id, mode, ref, name, network, tags_json, route_scope in rows:
                last_id = int(feature_id)
                new_scope = _classified_scope(mode, ref, name, network, tags_json)
                if _normalize_scope(route_scope) != new_scope:
                    updates.append({"id": feature_id, "route_scope": new_scope})
            if updates:
                session.bulk_update_mappings(OsmFeature, updates)
            processed += len(rows)
            changed += len(updates)
            session.commit()
            _emit_progress(
                progress_callback,
                "osm_labeling_batch",
                f"Relabeled {processed}/{total} main-table OSM features.",
                processed,
                total,
                {"dataset_id": dataset_id, "changed": changed, "storage": "main"},
            )
    except BaseException:
        # A failed flush/commit leaves the session needing a rollback. Without it the
        # CREATE INDEX below would raise, hide the real error and leave the index dropped.
        session.rollback()
        raise
    finally:
        if should_rebuild_index:
            session.execute(
                text(
                    f"CREATE INDEX IF NOT EXISTS {MAIN_ROUTE_SCOPE_INDEX} "
                    "ON osm_features (dataset_id, kind, mode, route_scope, min_lon, max_lon, min_lat, max_lat)"
                )
            )
            session.commit()
            index_rebuilds = 1
            _record_main_index_build(session, dataset)
    return {"storage": "main", "processed": processed, "changed": changed, "index_rebuilds": index_rebuilds}
|
|
|
|
|
|
def _classified_scope(mode: object, ref: object, name: object, network: object, tags_json: object) -> str | None:
    """Classify a feature's route scope from its raw column values and normalize the result.

    Every non-None input is coerced with ``str`` before it is handed to
    ``infer_osm_route_scope_from_tags``. None values are passed through as None.
    """

    def optional_text(value: object) -> str | None:
        return str(value) if value is not None else None

    arguments = [optional_text(value) for value in (mode, ref, name, network, tags_json)]
    inferred = infer_osm_route_scope_from_tags(*arguments)
    return _normalize_scope(inferred)
|
|
|
|
|
|
def _normalize_scope(value: object) -> str | None:
|
|
text_value = str(value or "").strip()
|
|
return text_value or None
|
|
|
|
|
|
def _label_dependency(dataset: Dataset) -> dict[str, object]:
    """Describe the inputs whose hash decides whether a dataset's labels are stale.

    Includes the dataset identity and sha256, the ``osm_storage`` metadata
    entry, and the classifier version. The sidecar is fingerprinted by path and
    presence only, not by contents or mtime.
    NOTE(review): presumably deliberate, because relabeling rewrites the sidecar
    itself. Confirm before adding content fingerprints.
    """
    metadata = dataset_metadata(dataset)
    storage = metadata.get("osm_storage") if isinstance(metadata, dict) else None

    raw_path = sidecar_path(dataset)
    sidecar_fingerprint: dict[str, object] | None = None
    if raw_path is not None:
        candidate = Path(raw_path)
        presence_key = "exists" if candidate.exists() else "missing"
        sidecar_fingerprint = {"path": str(candidate), presence_key: True}

    return {
        "dataset_id": dataset.id,
        "source_id": dataset.source_id,
        "kind": dataset.kind,
        "dataset_sha256": dataset.sha256,
        "storage": storage,
        "sidecar": sidecar_fingerprint,
        "classifier_version": OSM_LABEL_FEATURES_VERSION,
    }
|
|
|
|
|
|
def _stamp_dataset_metadata(session: Session, dataset: Dataset, dependency_hash_value: str, output: dict[str, object]) -> None:
    """Store a ``label_features`` summary in the dataset's ``metadata_json`` and flush.

    The Dataset is re-fetched by id. If it no longer exists, nothing happens.
    The caller is responsible for committing.
    """
    current = session.get(Dataset, dataset.id)
    if current is not None:
        merged = dataset_metadata(current)
        merged["label_features"] = {
            "stage": STAGE_LABEL_FEATURES,
            "version": OSM_LABEL_FEATURES_VERSION,
            "dependency_hash": dependency_hash_value,
            "labeled_at": datetime.now(timezone.utc).isoformat(),
            "processed": output.get("processed", 0),
            "changed": output.get("changed", 0),
            "storage": output.get("storage"),
        }
        current.metadata_json = json.dumps(merged, indent=2)
        session.flush()
|
|
|
|
|
|
def _record_sidecar_label(connection: sqlite3.Connection, dataset: Dataset, *, processed: int, changed: int) -> None:
    """Upsert the ``label_features`` row of the sidecar's ``pipeline_metadata`` key/value table.

    Creates the table if needed. The value is compact, key-sorted JSON. Does
    not commit; the caller owns the transaction.
    """
    payload = {
        "stage": STAGE_LABEL_FEATURES,
        "version": OSM_LABEL_FEATURES_VERSION,
        "dataset_id": dataset.id,
        "processed": processed,
        "changed": changed,
        "updated_at": datetime.now(timezone.utc).isoformat(),
    }
    serialized = json.dumps(payload, sort_keys=True, separators=(",", ":"))
    connection.execute("CREATE TABLE IF NOT EXISTS pipeline_metadata (key TEXT PRIMARY KEY, value TEXT NOT NULL)")
    connection.execute(
        "INSERT OR REPLACE INTO pipeline_metadata (key, value) VALUES (?, ?)",
        ("label_features", serialized),
    )
|
|
|
|
|
|
def _record_sidecar_index_build(connection: sqlite3.Connection, dataset: Dataset, path: Path) -> None:
    """Upsert the ``build_indexes:route_scope`` row of the sidecar's ``pipeline_metadata`` table.

    Creates the table if needed. The value is compact, key-sorted JSON. Does
    not commit; the caller owns the transaction.
    """
    payload = {
        "stage": STAGE_BUILD_INDEXES,
        "version": "osm_sidecar_indexes_v1",
        "dataset_id": dataset.id,
        "path": str(path),
        "updated_at": datetime.now(timezone.utc).isoformat(),
    }
    serialized = json.dumps(payload, sort_keys=True, separators=(",", ":"))
    connection.execute("CREATE TABLE IF NOT EXISTS pipeline_metadata (key TEXT PRIMARY KEY, value TEXT NOT NULL)")
    connection.execute(
        "INSERT OR REPLACE INTO pipeline_metadata (key, value) VALUES (?, ?)",
        ("build_indexes:route_scope", serialized),
    )
|
|
|
|
|
|
def _record_main_index_build(session: Session, dataset: Dataset) -> None:
    """Record a completed STAGE_BUILD_INDEXES run for MAIN_ROUTE_SCOPE_INDEX and commit it."""
    index_version = "osm_main_indexes_v1"
    run_inputs = {
        "dataset_id": dataset.id,
        "index": MAIN_ROUTE_SCOPE_INDEX,
        "version": index_version,
    }
    index_run = start_pipeline_run(
        session,
        stage=STAGE_BUILD_INDEXES,
        version=index_version,
        dependency_hash_value=dependency_hash(run_inputs),
        source_id=dataset.source_id,
        dataset_id=dataset.id,
        inputs=run_inputs,
    )
    finish_pipeline_run(session, index_run, outputs={"index": MAIN_ROUTE_SCOPE_INDEX})
    session.commit()
|
|
|
|
|
|
def _emit_progress(
|
|
callback: ProgressCallback | None,
|
|
event_type: str,
|
|
message: str,
|
|
current: int | None,
|
|
total: int | None,
|
|
metadata: dict[str, object] | None,
|
|
) -> None:
|
|
if callback is not None:
|
|
callback(event_type, message, current, total, metadata)
|