Alpha stage commit
This commit is contained in:
456
app/pipeline/osm_labeling.py
Normal file
456
app/pipeline/osm_labeling.py
Normal file
@@ -0,0 +1,456 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
import json
|
||||
from pathlib import Path
|
||||
import sqlite3
|
||||
from typing import Callable
|
||||
|
||||
from sqlalchemy import func, select, text
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models import Dataset, OsmFeature
|
||||
from app.osm_classification import OSM_ROUTE_SCOPE_CLASSIFIER_VERSION, infer_osm_route_scope_from_tags
|
||||
from app.osm_storage import (
|
||||
dataset_metadata,
|
||||
drop_osm_sidecar_route_scope_indexes,
|
||||
ensure_osm_sidecar_schema,
|
||||
features_are_sidecar,
|
||||
rebuild_osm_sidecar_indexes,
|
||||
sidecar_path,
|
||||
writable_sidecar_connection,
|
||||
)
|
||||
from app.pipeline.state import (
|
||||
STAGE_BUILD_INDEXES,
|
||||
STAGE_LABEL_FEATURES,
|
||||
dependency_hash,
|
||||
finish_pipeline_run,
|
||||
latest_completed_run,
|
||||
start_pipeline_run,
|
||||
)
|
||||
|
||||
|
||||
# The labeling stage shares its version with the route-scope classifier, so a
# classifier bump changes every value derived from this constant.
OSM_LABEL_FEATURES_VERSION = OSM_ROUTE_SCOPE_CLASSIFIER_VERSION

# Main-table index that is dropped and recreated around large relabel runs.
MAIN_ROUTE_SCOPE_INDEX = "ix_osm_features_scope_bbox"

# Feature counts at or above which the route-scope indexes are dropped before
# relabeling and rebuilt afterwards (main table and sidecar respectively).
MAIN_INDEX_REBUILD_THRESHOLD = 10_000
SIDECAR_INDEX_REBUILD_THRESHOLD = 10_000

# Progress hook signature: (event_type, message, current, total, metadata).
ProgressCallback = Callable[
    [str, str, int | None, int | None, dict[str, object] | None],
    None,
]
|
||||
|
||||
|
||||
def relabel_osm_features(
    session: Session,
    *,
    dataset_id: int | None = None,
    chunk_size: int = 5000,
    force: bool = False,
    rebuild_indexes: bool = True,
    progress_callback: ProgressCallback | None = None,
    job_id: int | None = None,
) -> dict[str, object]:
    """Relabel every targeted OSM dataset and return aggregated counters.

    Args:
        session: SQLAlchemy session used for dataset lookup and run tracking.
        dataset_id: Restrict the run to this dataset; ``None`` targets all
            active imported OSM datasets.
        chunk_size: Number of features handled per batch.
        force: Relabel even when the dataset's label stamp is current.
        rebuild_indexes: Allow route-scope indexes to be dropped and rebuilt.
        progress_callback: Optional hook receiving progress events.
        job_id: Optional job identifier forwarded to pipeline-run tracking.

    Returns:
        A summary dict holding the labeling version, the dataset count, the
        summed ``processed`` / ``changed`` / ``index_rebuilds`` values, the
        number of ``skipped`` and ``missing`` datasets, and the per-dataset
        results under ``dataset_results``.
    """
    datasets = _target_datasets(session, dataset_id)
    dataset_count = len(datasets)

    _emit_progress(
        progress_callback,
        "osm_labeling_started",
        f"Relabeling {dataset_count} OSM dataset(s).",
        0,
        dataset_count,
        {"dataset_id": dataset_id, "force": force, "version": OSM_LABEL_FEATURES_VERSION},
    )

    # Insertion order here fixes the key order of the summary dict below.
    totals = {"processed": 0, "changed": 0, "skipped": 0, "missing": 0, "index_rebuilds": 0}
    per_dataset: list[dict[str, object]] = []

    for position, dataset in enumerate(datasets, start=1):
        outcome = relabel_osm_dataset(
            session,
            dataset,
            chunk_size=chunk_size,
            force=force,
            rebuild_indexes=rebuild_indexes,
            progress_callback=progress_callback,
            job_id=job_id,
        )
        status = outcome.get("status")
        for counter in ("processed", "changed", "index_rebuilds"):
            totals[counter] += int(outcome.get(counter, 0) or 0)
        if status == "skipped":
            totals["skipped"] += 1
        if status == "missing_sidecar":
            totals["missing"] += 1
        per_dataset.append(outcome)

        _emit_progress(
            progress_callback,
            "osm_labeling_dataset_completed",
            f"Relabeled {position}/{dataset_count} OSM dataset(s).",
            position,
            dataset_count,
            outcome,
        )

    result: dict[str, object] = {
        "version": OSM_LABEL_FEATURES_VERSION,
        "datasets": dataset_count,
        **totals,
        "dataset_results": per_dataset,
    }
    _emit_progress(
        progress_callback,
        "osm_labeling_completed",
        "OSM feature relabeling completed.",
        dataset_count,
        dataset_count,
        result,
    )
    return result
|
||||
|
||||
|
||||
def relabel_osm_dataset(
    session: Session,
    dataset: Dataset,
    *,
    chunk_size: int = 5000,
    force: bool = False,
    rebuild_indexes: bool = True,
    progress_callback: ProgressCallback | None = None,
    job_id: int | None = None,
) -> dict[str, object]:
    """Relabel one dataset's OSM features and track the attempt as a pipeline run.

    Unless ``force`` is set, the dataset is skipped when its stored label
    stamp and a completed pipeline run both match the current dependency
    hash. Otherwise a run is started and committed, the features are
    relabeled through the sidecar or main-table path, and the run is
    finished with the outcome.

    Args:
        session: SQLAlchemy session used for run tracking and main-table work.
        dataset: Dataset whose features are relabeled.
        chunk_size: Number of features handled per batch.
        force: Relabel even when the label stamp is current.
        rebuild_indexes: Allow route-scope indexes to be dropped and rebuilt.
        progress_callback: Optional hook receiving per-batch progress events.
        job_id: Optional job identifier stored on the pipeline run.

    Returns:
        A result dict whose ``status`` is ``"skipped"``, ``"completed"`` or
        ``"missing_sidecar"``, together with the dependency hash, version and
        the ``processed`` / ``changed`` / ``index_rebuilds`` counters.

    Raises:
        Exception: Any error other than ``FileNotFoundError`` is re-raised
            after the pipeline run has been marked as failed.
    """
    dependency = _label_dependency(dataset)
    dependency_hash_value = dependency_hash(dependency)

    if not force and _dataset_label_is_current(session, dataset, dependency_hash_value):
        return {
            "dataset_id": dataset.id,
            "source_id": dataset.source_id,
            "status": "skipped",
            "reason": "label_features dependency is current",
            "dependency_hash": dependency_hash_value,
            "version": OSM_LABEL_FEATURES_VERSION,
            "processed": 0,
            "changed": 0,
            "index_rebuilds": 0,
        }

    run = start_pipeline_run(
        session,
        stage=STAGE_LABEL_FEATURES,
        version=OSM_LABEL_FEATURES_VERSION,
        dependency_hash_value=dependency_hash_value,
        source_id=dataset.source_id,
        dataset_id=dataset.id,
        job_id=job_id,
        inputs=dependency,
    )
    # Commit the run before the long-running work starts so the attempt is
    # persisted independently of whatever the relabel pass does afterwards.
    session.commit()

    try:
        if features_are_sidecar(dataset):
            counts = _relabel_sidecar_dataset(
                dataset,
                chunk_size=chunk_size,
                rebuild_indexes=rebuild_indexes,
                progress_callback=progress_callback,
            )
        else:
            counts = _relabel_main_dataset(
                session,
                dataset,
                chunk_size=chunk_size,
                rebuild_indexes=rebuild_indexes,
                progress_callback=progress_callback,
            )
        output = {
            "dataset_id": dataset.id,
            "source_id": dataset.source_id,
            "status": "completed",
            "dependency_hash": dependency_hash_value,
            "version": OSM_LABEL_FEATURES_VERSION,
            **counts,
        }
        _stamp_dataset_metadata(session, dataset, dependency_hash_value, output)
        finish_pipeline_run(session, run, outputs=output)
        session.commit()
        return output
    except FileNotFoundError as exc:
        # A missing sidecar is reported as a result, not propagated, so a
        # multi-dataset run can continue with the remaining datasets.
        output = {
            "dataset_id": dataset.id,
            "source_id": dataset.source_id,
            "status": "missing_sidecar",
            "dependency_hash": dependency_hash_value,
            "version": OSM_LABEL_FEATURES_VERSION,
            "processed": 0,
            "changed": 0,
            "index_rebuilds": 0,
            "error": str(exc),
        }
        finish_pipeline_run(session, run, status="failed", outputs=output, error=str(exc))
        session.commit()
        return output
    except Exception as exc:
        # After a failed flush or statement the session refuses further work
        # until it is rolled back. Without this, marking the run as failed
        # would raise as well and hide the original error.
        session.rollback()
        finish_pipeline_run(session, run, status="failed", error=str(exc))
        session.commit()
        raise
|
||||
|
||||
|
||||
def _target_datasets(session: Session, dataset_id: int | None) -> list[Dataset]:
    """Return the imported OSM GeoJSON datasets a relabel run should visit.

    With ``dataset_id`` set, only that dataset is selected (active or not);
    otherwise all active datasets are selected. Results are ordered by source
    id, then dataset id.
    """
    selector = Dataset.is_active.is_(True) if dataset_id is None else Dataset.id == dataset_id
    query = (
        select(Dataset)
        .where(Dataset.kind == "osm_geojson", Dataset.status == "imported")
        .where(selector)
        .order_by(Dataset.source_id, Dataset.id)
    )
    return session.scalars(query).all()
|
||||
|
||||
|
||||
def _dataset_label_is_current(session: Session, dataset: Dataset, dependency_hash_value: str) -> bool:
    """Tell whether the dataset was already labeled for this dependency hash.

    Both conditions must hold: the dataset metadata carries a
    ``label_features`` stamp with the current version and hash, and a
    completed pipeline run exists for the same stage, version, hash, source
    and dataset.
    """
    stamp = dataset_metadata(dataset).get("label_features")
    if not isinstance(stamp, dict):
        return False
    if stamp.get("version") != OSM_LABEL_FEATURES_VERSION:
        return False
    if stamp.get("dependency_hash") != dependency_hash_value:
        return False

    # The metadata stamp alone is not trusted; a completed run must back it.
    matching_run = latest_completed_run(
        session,
        stage=STAGE_LABEL_FEATURES,
        version=OSM_LABEL_FEATURES_VERSION,
        dependency_hash_value=dependency_hash_value,
        source_id=dataset.source_id,
        dataset_id=dataset.id,
    )
    return matching_run is not None
|
||||
|
||||
|
||||
def _relabel_sidecar_dataset(
    dataset: Dataset,
    *,
    chunk_size: int,
    rebuild_indexes: bool,
    progress_callback: ProgressCallback | None,
) -> dict[str, int | str]:
    """Recompute ``route_scope`` for every feature stored in the dataset's sidecar.

    Features are read in id order in batches (keyset pagination on ``id``),
    only rows whose normalized scope differs are updated, and each batch is
    committed before a progress event is emitted. For sidecars with at least
    ``SIDECAR_INDEX_REBUILD_THRESHOLD`` features (and ``rebuild_indexes``
    enabled) the route-scope indexes are dropped up front and rebuilt
    afterwards, including when the pass fails part-way.

    Args:
        dataset: Dataset whose sidecar is relabeled.
        chunk_size: Requested batch size; values below 1 are treated as 1.
        rebuild_indexes: Allow the drop/rebuild of route-scope indexes.
        progress_callback: Optional hook receiving ``osm_labeling_batch`` events.

    Returns:
        ``{"storage": "sidecar", "processed", "changed", "index_rebuilds"}``.

    Raises:
        FileNotFoundError: If the sidecar path is unknown or does not exist.
    """
    path = sidecar_path(dataset)
    if path is None or not path.exists():
        raise FileNotFoundError(f"OSM sidecar does not exist: {path}")

    # Normalise once, before any index is dropped, so a bad chunk size fails
    # fast and the value is not recomputed on every loop iteration.
    batch_limit = max(1, int(chunk_size))

    with writable_sidecar_connection(dataset) as connection:
        ensure_osm_sidecar_schema(connection)
        total = int(connection.execute("SELECT COUNT(*) FROM osm_features").fetchone()[0] or 0)
        should_rebuild_index = rebuild_indexes and total >= SIDECAR_INDEX_REBUILD_THRESHOLD
        if should_rebuild_index:
            drop_osm_sidecar_route_scope_indexes(connection)
            connection.commit()

        processed = 0
        changed = 0
        last_id = 0
        index_rebuilds = 0
        try:
            while True:
                # Rows are read by column name below; this presumably relies on
                # the connection using a name-addressable row factory
                # (NOTE(review): confirm in writable_sidecar_connection).
                rows = connection.execute(
                    """
                    SELECT id, mode, ref, name, network, tags_json, route_scope
                    FROM osm_features
                    WHERE id > ?
                    ORDER BY id
                    LIMIT ?
                    """,
                    (last_id, batch_limit),
                ).fetchall()
                if not rows:
                    break

                updates: list[tuple[str | None, int]] = []
                for row in rows:
                    last_id = int(row["id"])
                    new_scope = _classified_scope(
                        row["mode"], row["ref"], row["name"], row["network"], row["tags_json"]
                    )
                    # Compare normalized values so blank and NULL scopes are
                    # not rewritten needlessly.
                    if _normalize_scope(row["route_scope"]) != new_scope:
                        updates.append((new_scope, last_id))
                if updates:
                    connection.executemany("UPDATE osm_features SET route_scope = ? WHERE id = ?", updates)

                processed += len(rows)
                changed += len(updates)
                connection.commit()
                _emit_progress(
                    progress_callback,
                    "osm_labeling_batch",
                    f"Relabeled {processed}/{total} OSM sidecar features.",
                    processed,
                    total,
                    {"dataset_id": dataset.id, "changed": changed, "storage": "sidecar"},
                )
        finally:
            # The indexes were dropped above, so they must come back even when
            # the relabel loop fails part-way.
            if should_rebuild_index:
                rebuild_osm_sidecar_indexes(connection)
                connection.commit()
                index_rebuilds = 1
                _record_sidecar_index_build(connection, dataset, path)
                connection.commit()

        # Stamp the sidecar only after a full pass, so a failed run never
        # records itself as labeled.
        _record_sidecar_label(connection, dataset, processed=processed, changed=changed)
        connection.commit()

    return {"storage": "sidecar", "processed": processed, "changed": changed, "index_rebuilds": index_rebuilds}
|
||||
|
||||
|
||||
def _relabel_main_dataset(
    session: Session,
    dataset: Dataset,
    *,
    chunk_size: int,
    rebuild_indexes: bool,
    progress_callback: ProgressCallback | None,
) -> dict[str, int | str]:
    """Recompute ``route_scope`` for the dataset's features in the main table.

    Features are read in id order in batches (keyset pagination on ``id``),
    only rows whose normalized scope differs are updated, and each batch is
    committed before a progress event is emitted. For datasets with at least
    ``MAIN_INDEX_REBUILD_THRESHOLD`` features (and ``rebuild_indexes``
    enabled) the route-scope index is dropped up front and recreated
    afterwards, including when the pass fails part-way.

    Args:
        session: SQLAlchemy session bound to the main database.
        dataset: Dataset whose features are relabeled.
        chunk_size: Requested batch size; values below 1 are treated as 1.
        rebuild_indexes: Allow the drop/rebuild of the route-scope index.
        progress_callback: Optional hook receiving ``osm_labeling_batch`` events.

    Returns:
        ``{"storage": "main", "processed", "changed", "index_rebuilds"}``.
    """
    # Normalise once, before any index is dropped, so a bad chunk size fails
    # fast and the value is not recomputed on every loop iteration.
    batch_limit = max(1, int(chunk_size))

    total = int(
        session.scalar(
            select(func.count()).select_from(OsmFeature).where(OsmFeature.dataset_id == dataset.id)
        )
        or 0
    )
    should_rebuild_index = rebuild_indexes and total >= MAIN_INDEX_REBUILD_THRESHOLD
    index_rebuilds = 0
    if should_rebuild_index:
        session.execute(text(f"DROP INDEX IF EXISTS {MAIN_ROUTE_SCOPE_INDEX}"))
        session.commit()

    processed = 0
    changed = 0
    last_id = 0
    try:
        while True:
            rows = session.scalars(
                select(OsmFeature)
                .where(OsmFeature.dataset_id == dataset.id, OsmFeature.id > last_id)
                .order_by(OsmFeature.id)
                .limit(batch_limit)
            ).all()
            if not rows:
                break

            updates: list[dict[str, object]] = []
            for feature in rows:
                last_id = int(feature.id)
                new_scope = _classified_scope(
                    feature.mode, feature.ref, feature.name, feature.network, feature.tags_json
                )
                # Compare normalized values so blank and NULL scopes are not
                # rewritten needlessly.
                if _normalize_scope(feature.route_scope) != new_scope:
                    updates.append({"id": feature.id, "route_scope": new_scope})
            if updates:
                session.bulk_update_mappings(OsmFeature, updates)

            processed += len(rows)
            changed += len(updates)
            session.commit()
            _emit_progress(
                progress_callback,
                "osm_labeling_batch",
                f"Relabeled {processed}/{total} main-table OSM features.",
                processed,
                total,
                {"dataset_id": dataset.id, "changed": changed, "storage": "main"},
            )
    except Exception:
        # Clear the failed transaction so the index rebuild in ``finally`` can
        # still run; otherwise a database error here would leave the index
        # dropped and mask the original exception.
        session.rollback()
        raise
    finally:
        if should_rebuild_index:
            # Use the same constant as the DROP above so the two statements
            # cannot drift apart.
            session.execute(
                text(
                    f"CREATE INDEX IF NOT EXISTS {MAIN_ROUTE_SCOPE_INDEX} "
                    "ON osm_features (dataset_id, kind, mode, route_scope, min_lon, max_lon, min_lat, max_lat)"
                )
            )
            session.commit()
            index_rebuilds = 1
            _record_main_index_build(session, dataset)

    return {"storage": "main", "processed": processed, "changed": changed, "index_rebuilds": index_rebuilds}
|
||||
|
||||
|
||||
def _classified_scope(mode: object, ref: object, name: object, network: object, tags_json: object) -> str | None:
    """Classify a feature's route scope from its raw column values.

    Each non-``None`` value is converted to ``str`` before being passed to the
    classifier; the classifier's answer is normalized so blank results become
    ``None``.
    """
    as_text = [
        None if raw is None else str(raw)
        for raw in (mode, ref, name, network, tags_json)
    ]
    return _normalize_scope(infer_osm_route_scope_from_tags(*as_text))
|
||||
|
||||
|
||||
def _normalize_scope(value: object) -> str | None:
|
||||
text_value = str(value or "").strip()
|
||||
return text_value or None
|
||||
|
||||
|
||||
def _label_dependency(dataset: Dataset) -> dict[str, object]:
    """Describe the inputs that determine whether the dataset needs relabeling.

    The description covers the dataset identity and checksum, its
    ``osm_storage`` metadata entry, a sidecar fingerprint (path plus an
    ``exists`` or ``missing`` flag, or ``None`` when there is no sidecar
    path) and the classifier version.
    """
    metadata = dataset_metadata(dataset)
    storage = metadata.get("osm_storage") if isinstance(metadata, dict) else None

    sidecar = sidecar_path(dataset)
    fingerprint: dict[str, object] | None = None
    if sidecar is not None:
        resolved = Path(sidecar)
        # Only presence is fingerprinted, under a key naming the state.
        state_key = "exists" if resolved.exists() else "missing"
        fingerprint = {"path": str(resolved), state_key: True}

    return {
        "dataset_id": dataset.id,
        "source_id": dataset.source_id,
        "kind": dataset.kind,
        "dataset_sha256": dataset.sha256,
        "storage": storage,
        "sidecar": fingerprint,
        "classifier_version": OSM_LABEL_FEATURES_VERSION,
    }
|
||||
|
||||
|
||||
def _stamp_dataset_metadata(session: Session, dataset: Dataset, dependency_hash_value: str, output: dict[str, object]) -> None:
    """Write the ``label_features`` stamp into the dataset's metadata JSON.

    The dataset is re-read through the session first; nothing happens when it
    can no longer be found. The change is flushed but not committed here.
    """
    current = session.get(Dataset, dataset.id)
    if current is None:
        return

    stamp = {
        "stage": STAGE_LABEL_FEATURES,
        "version": OSM_LABEL_FEATURES_VERSION,
        "dependency_hash": dependency_hash_value,
        "labeled_at": datetime.now(timezone.utc).isoformat(),
        "processed": output.get("processed", 0),
        "changed": output.get("changed", 0),
        "storage": output.get("storage"),
    }
    metadata = dataset_metadata(current)
    metadata["label_features"] = stamp
    current.metadata_json = json.dumps(metadata, indent=2)
    session.flush()
|
||||
|
||||
|
||||
def _upsert_sidecar_metadata(connection: sqlite3.Connection, key: str, payload: dict[str, object]) -> None:
    """Store ``payload`` as compact, key-sorted JSON under ``key`` in ``pipeline_metadata``.

    The table is created on first use and an existing row for ``key`` is
    replaced. The caller is responsible for committing.
    """
    connection.execute("CREATE TABLE IF NOT EXISTS pipeline_metadata (key TEXT PRIMARY KEY, value TEXT NOT NULL)")
    connection.execute(
        "INSERT OR REPLACE INTO pipeline_metadata (key, value) VALUES (?, ?)",
        (key, json.dumps(payload, sort_keys=True, separators=(",", ":"))),
    )


def _record_sidecar_label(connection: sqlite3.Connection, dataset: Dataset, *, processed: int, changed: int) -> None:
    """Record the outcome of a label pass in the sidecar's ``pipeline_metadata`` table.

    Args:
        connection: Open sidecar connection; the write is not committed here.
        dataset: Dataset the sidecar belongs to.
        processed: Number of features examined.
        changed: Number of features whose route scope was updated.
    """
    _upsert_sidecar_metadata(
        connection,
        "label_features",
        {
            "stage": STAGE_LABEL_FEATURES,
            "version": OSM_LABEL_FEATURES_VERSION,
            "dataset_id": dataset.id,
            "processed": processed,
            "changed": changed,
            "updated_at": datetime.now(timezone.utc).isoformat(),
        },
    )


def _record_sidecar_index_build(connection: sqlite3.Connection, dataset: Dataset, path: Path) -> None:
    """Record a route-scope index rebuild in the sidecar's ``pipeline_metadata`` table.

    Args:
        connection: Open sidecar connection; the write is not committed here.
        dataset: Dataset the sidecar belongs to.
        path: Location of the sidecar file, stored as a string.
    """
    _upsert_sidecar_metadata(
        connection,
        "build_indexes:route_scope",
        {
            "stage": STAGE_BUILD_INDEXES,
            "version": "osm_sidecar_indexes_v1",
            "dataset_id": dataset.id,
            "path": str(path),
            "updated_at": datetime.now(timezone.utc).isoformat(),
        },
    )
|
||||
|
||||
|
||||
def _record_main_index_build(session: Session, dataset: Dataset) -> None:
    """Log a rebuild of the main-table route-scope index as a finished pipeline run.

    A ``STAGE_BUILD_INDEXES`` run is started and immediately finished with the
    index name as its output, then the session is committed.
    """
    index_version = "osm_main_indexes_v1"
    inputs = {
        "dataset_id": dataset.id,
        "index": MAIN_ROUTE_SCOPE_INDEX,
        "version": index_version,
    }
    run = start_pipeline_run(
        session,
        stage=STAGE_BUILD_INDEXES,
        version=index_version,
        dependency_hash_value=dependency_hash(inputs),
        source_id=dataset.source_id,
        dataset_id=dataset.id,
        inputs=inputs,
    )
    finish_pipeline_run(session, run, outputs={"index": MAIN_ROUTE_SCOPE_INDEX})
    session.commit()
|
||||
|
||||
|
||||
def _emit_progress(
|
||||
callback: ProgressCallback | None,
|
||||
event_type: str,
|
||||
message: str,
|
||||
current: int | None,
|
||||
total: int | None,
|
||||
metadata: dict[str, object] | None,
|
||||
) -> None:
|
||||
if callback is not None:
|
||||
callback(event_type, message, current, total, metadata)
|
||||
Reference in New Issue
Block a user