from __future__ import annotations import json from pathlib import Path from typing import Any from sqlalchemy import select from sqlalchemy.orm import Session from app.config import settings from app.models import Dataset, OsmFeature, Source from app.osm_classification import infer_osm_route_scope from app.osm_storage import ( OSM_STORAGE_METADATA_KEY, OSM_STORAGE_MAIN, OSM_STORAGE_SIDECAR_FEATURES, create_osm_sidecar, dedupe_osm_feature_rows, effective_osm_feature_storage, ) from app.pipeline.download import materialize_source from app.pipeline.utils import first_nonempty, geometry_json_and_bbox, norm_ref, norm_text, sha256_file from app.spatial import analyze_postgresql_tables, refresh_postgis_geometries ROUTE_MODES = { "train", "railway", "light_rail", "subway", "tram", "bus", "trolleybus", "coach", "ferry", "monorail", "funicular", "aerialway", } def run_osm_geojson_source(session: Session, source: Source) -> Dataset: local_path = materialize_source(source) source_hash = sha256_file(local_path) existing = session.scalar( select(Dataset) .where( Dataset.source_id == source.id, Dataset.kind == "osm_geojson", Dataset.sha256 == source_hash, Dataset.is_active.is_(True), Dataset.status == "imported", ) .order_by(Dataset.id.desc()) ) if existing is not None: return existing return import_osm_geojson(session=session, source=source, path=local_path, source_hash=source_hash) def import_osm_geojson( session: Session, source: Source, path: Path, source_hash: str | None = None, *, storage_mode: str | None = None, ) -> Dataset: for dataset in source.datasets: dataset.is_active = False dataset = Dataset( source_id=source.id, kind="osm_geojson", local_path=str(path), sha256=source_hash or sha256_file(path), is_active=True, status="importing", ) session.add(dataset) session.flush() source_hash = source_hash or sha256_file(path) dataset.metadata_json = json.dumps( prepare_osm_geojson_storage( session=session, dataset=dataset, path=path, source_hash=source_hash, storage_mode=storage_mode, ), indent=2, ) dataset.status = "imported" source.status = "ok" source.last_error = None session.flush() return dataset def prepare_osm_geojson_storage( *, session: Session, dataset: Dataset, path: Path, source_hash: str | None = None, storage_mode: str | None = None, ) -> dict[str, object]: data = json.loads(path.read_text(encoding="utf-8")) features = _as_features(data) feature_rows = [_feature_row(dataset.id, idx, feature) for idx, feature in enumerate(features)] storage = effective_osm_feature_storage(storage_mode) if storage not in {OSM_STORAGE_MAIN, OSM_STORAGE_SIDECAR_FEATURES}: raise ValueError(f"Unsupported OSM feature storage mode: {storage}") if storage == OSM_STORAGE_SIDECAR_FEATURES: return { "features": len(feature_rows), OSM_STORAGE_METADATA_KEY: create_osm_sidecar(dataset, feature_rows, source_hash=source_hash or dataset.sha256), } _insert_main_features(session, feature_rows) session.flush() refresh_postgis_geometries(session, dataset_id=dataset.id, tables=["osm_features"]) analyze_postgresql_tables(session, ["osm_features"]) return {"features": len(feature_rows), OSM_STORAGE_METADATA_KEY: {"mode": OSM_STORAGE_MAIN}} def _insert_main_features(session: Session, feature_rows: list[dict[str, object]]) -> None: objects: list[OsmFeature] = [] deduped_rows, _duplicate_count = dedupe_osm_feature_rows(feature_rows) for row in deduped_rows: objects.append( OsmFeature( dataset_id=row["dataset_id"], osm_type=row["osm_type"], osm_id=row["osm_id"], kind=row["kind"], mode=row["mode"], route_scope=row["route_scope"], name=row["name"], ref=row["ref"], operator=row["operator"], network=row["network"], geometry_geojson=row["geometry_geojson"], min_lon=row["min_lon"], min_lat=row["min_lat"], max_lon=row["max_lon"], max_lat=row["max_lat"], tags_json=row["tags_json"], route_key=row["route_key"], operator_key=row["operator_key"], ) ) if len(objects) >= 5000: session.bulk_save_objects(objects) objects.clear() if objects: session.bulk_save_objects(objects) def _feature_row(dataset_id: int, idx: int, feature: dict[str, Any]) -> dict[str, object]: props = feature.get("properties") or {} geometry = feature.get("geometry") geometry_text, bbox = geometry_json_and_bbox(geometry) osm_type = str(first_nonempty(props.get("osm_type"), props.get("@type"), props.get("type"), "feature")) osm_id = str(first_nonempty(props.get("osm_id"), props.get("@id"), props.get("id"), f"feature_{idx}")) mode = _infer_mode(props) kind = _infer_kind(props, mode) name = first_nonempty(props.get("name"), props.get("official_name")) or None ref = first_nonempty(props.get("ref"), props.get("route_ref"), props.get("line")) or None operator = first_nonempty(props.get("operator"), props.get("agency"), props.get("brand")) or None network = first_nonempty(props.get("network"), props.get("network:short")) or None route_scope = infer_osm_route_scope(mode=mode, ref=ref, name=name, network=network, tags=props) route_key = norm_ref(ref) or norm_text(name) or norm_ref(osm_id) operator_key = norm_text(operator or network or "") return { "dataset_id": dataset_id, "osm_type": osm_type, "osm_id": osm_id, "kind": kind, "mode": mode, "route_scope": route_scope, "name": name, "ref": ref, "operator": operator, "network": network, "geometry_geojson": geometry_text, "min_lon": bbox[0], "min_lat": bbox[1], "max_lon": bbox[2], "max_lat": bbox[3], "tags_json": json.dumps(props, separators=(",", ":")), "route_key": route_key, "operator_key": operator_key, } def _as_features(data: Any) -> list[dict[str, Any]]: if isinstance(data, dict) and data.get("type") == "FeatureCollection": return [f for f in data.get("features", []) if isinstance(f, dict)] if isinstance(data, dict) and data.get("type") == "Feature": return [data] if isinstance(data, list): return [f for f in data if isinstance(f, dict)] raise ValueError("OSM source must be GeoJSON FeatureCollection, Feature, or list of Features") def _infer_mode(props: dict[str, Any]) -> str | None: for key in ("mode", "route", "route_master"): value = str(props.get(key) or "").strip() if value in ROUTE_MODES: return "train" if value == "railway" else value railway = str(props.get("railway") or "").strip() if railway in {"station", "halt"}: return "train" if railway == "tram_stop": return "tram" if railway == "subway_entrance": return "subway" if str(props.get("highway") or "") == "bus_stop" or str(props.get("amenity") or "") == "bus_station": return "bus" if str(props.get("amenity") or "") == "ferry_terminal": return "ferry" if str(props.get("aerialway") or "") == "station": return "aerialway" return None def _infer_kind(props: dict[str, Any], mode: str | None) -> str: explicit_kind = str(props.get("kind") or "").strip() if explicit_kind in {"route", "stop", "station", "terminal", "infra", "feature"}: return explicit_kind if str(props.get("type") or "") in {"route", "route_master"} or str(props.get("route") or "") in ROUTE_MODES: return "route" if str(props.get("amenity") or "") == "ferry_terminal": return "terminal" if str(props.get("amenity") or "") == "bus_station": return "terminal" if str(props.get("railway") or "") in {"station", "halt"}: return "station" if str(props.get("aerialway") or "") == "station": return "station" if str(props.get("public_transport") or "") in {"platform", "stop_position", "station"}: return "stop" if str(props.get("highway") or "") == "bus_stop": return "stop" if mode: return "infra" return "feature"