Files
meubility-workbench/app/pipeline/osm_geojson.py
2026-07-01 23:29:51 +02:00

249 lines
8.6 KiB
Python

from __future__ import annotations
import json
from pathlib import Path
from typing import Any
from sqlalchemy import select
from sqlalchemy.orm import Session
from app.config import settings
from app.models import Dataset, OsmFeature, Source
from app.osm_classification import infer_osm_route_scope
from app.osm_storage import (
OSM_STORAGE_METADATA_KEY,
OSM_STORAGE_MAIN,
OSM_STORAGE_SIDECAR_FEATURES,
create_osm_sidecar,
dedupe_osm_feature_rows,
effective_osm_feature_storage,
)
from app.pipeline.download import materialize_source
from app.pipeline.utils import first_nonempty, geometry_json_and_bbox, norm_ref, norm_text, sha256_file
from app.spatial import analyze_postgresql_tables, refresh_postgis_geometries
# Transport-mode values recognised in feature properties (``mode``, ``route``,
# ``route_master``) by the classifiers below. "railway" is accepted as an input
# value but _infer_mode reports it as "train".
ROUTE_MODES = {
    # rail-bound
    "train",
    "railway",
    "light_rail",
    "subway",
    "tram",
    "monorail",
    "funicular",
    # road
    "bus",
    "trolleybus",
    "coach",
    # water and cable
    "ferry",
    "aerialway",
}
def run_osm_geojson_source(session: Session, source: Source) -> Dataset:
    """Import an OSM GeoJSON source, reusing an identical earlier import if present.

    The source is materialized to a local file and hashed (SHA-256). If this
    source already has an active ``osm_geojson`` dataset with status
    ``"imported"`` and the same hash, the newest one (highest id) is returned
    as-is. Otherwise a fresh import runs via ``import_osm_geojson``.
    """
    local_path = materialize_source(source)
    digest = sha256_file(local_path)

    reuse_query = (
        select(Dataset)
        .where(
            Dataset.source_id == source.id,
            Dataset.kind == "osm_geojson",
            Dataset.sha256 == digest,
            Dataset.is_active.is_(True),
            Dataset.status == "imported",
        )
        .order_by(Dataset.id.desc())
    )
    previous = session.scalar(reuse_query)
    if previous is None:
        return import_osm_geojson(session=session, source=source, path=local_path, source_hash=digest)
    return previous
def import_osm_geojson(
    session: Session,
    source: Source,
    path: Path,
    source_hash: str | None = None,
    *,
    storage_mode: str | None = None,
) -> Dataset:
    """Create a new active ``osm_geojson`` dataset for *source* from *path*.

    Every dataset currently in ``source.datasets`` is marked inactive. A new
    dataset row is then added and flushed, and its features are stored via
    ``prepare_osm_geojson_storage``. On success the dataset status becomes
    ``"imported"``, the source status becomes ``"ok"`` and its last error is
    cleared. The session is flushed but never committed here.

    Args:
        session: Active SQLAlchemy session.
        source: Source whose datasets are superseded by this import.
        path: Local GeoJSON file to import.
        source_hash: Precomputed SHA-256 of *path*. Computed here when omitted
            or empty.
        storage_mode: Requested OSM feature storage mode, resolved by
            ``effective_osm_feature_storage``.

    Returns:
        The newly created dataset.

    Exceptions from the storage step propagate unchanged. In that case the new
    dataset has already been flushed with status ``"importing"``.
    NOTE(review): presumably the caller rolls the session back; confirm.
    """
    # Hash the file once up front. Previously it was hashed twice when no digest
    # was supplied: once for the dataset row and once for the storage step.
    if not source_hash:
        source_hash = sha256_file(path)

    # Only the dataset created below should remain active for this source.
    for previous in source.datasets:
        previous.is_active = False

    dataset = Dataset(
        source_id=source.id,
        kind="osm_geojson",
        local_path=str(path),
        sha256=source_hash,
        is_active=True,
        status="importing",
    )
    session.add(dataset)
    # Flush so dataset.id is available; the storage step copies it into every
    # feature row.
    session.flush()

    storage_metadata = prepare_osm_geojson_storage(
        session=session,
        dataset=dataset,
        path=path,
        source_hash=source_hash,
        storage_mode=storage_mode,
    )
    dataset.metadata_json = json.dumps(storage_metadata, indent=2)
    dataset.status = "imported"
    source.status = "ok"
    source.last_error = None
    session.flush()
    return dataset
def prepare_osm_geojson_storage(
    *,
    session: Session,
    dataset: Dataset,
    path: Path,
    source_hash: str | None = None,
    storage_mode: str | None = None,
) -> dict[str, object]:
    """Parse the GeoJSON file at *path* and persist its features for *dataset*.

    The resolved storage mode decides where the features go:

    - Sidecar mode writes them via ``create_osm_sidecar``.
    - Main mode bulk-inserts them into ``osm_features``. It then flushes the
      session, refreshes PostGIS geometries for this dataset and analyzes the
      table.

    Nothing here commits.

    Args:
        session: Active SQLAlchemy session.
        dataset: Owning dataset. ``dataset.id`` is copied into every feature
            row, so it should already be assigned.
        path: Local GeoJSON file, read as UTF-8.
        source_hash: Digest passed to the sidecar. Falls back to
            ``dataset.sha256``.
        storage_mode: Requested mode, resolved by ``effective_osm_feature_storage``.

    Returns:
        Metadata with the number of parsed features under ``"features"`` and the
        storage descriptor under ``OSM_STORAGE_METADATA_KEY``. In main mode the
        count is taken before ``_insert_main_features`` deduplicates the rows.

    Raises:
        ValueError: If the storage mode is unsupported, the file is not valid
            JSON, or it is not a GeoJSON FeatureCollection, Feature, or list of
            Features.
    """
    # Validate the mode before touching the file, so a bad configuration fails
    # fast instead of after reading and parsing a potentially large export.
    storage = effective_osm_feature_storage(storage_mode)
    if storage not in {OSM_STORAGE_MAIN, OSM_STORAGE_SIDECAR_FEATURES}:
        raise ValueError(f"Unsupported OSM feature storage mode: {storage}")

    data = json.loads(path.read_text(encoding="utf-8"))
    features = _as_features(data)
    feature_rows = [_feature_row(dataset.id, idx, feature) for idx, feature in enumerate(features)]

    if storage == OSM_STORAGE_SIDECAR_FEATURES:
        return {
            "features": len(feature_rows),
            OSM_STORAGE_METADATA_KEY: create_osm_sidecar(dataset, feature_rows, source_hash=source_hash or dataset.sha256),
        }

    _insert_main_features(session, feature_rows)
    session.flush()
    refresh_postgis_geometries(session, dataset_id=dataset.id, tables=["osm_features"])
    analyze_postgresql_tables(session, ["osm_features"])
    return {"features": len(feature_rows), OSM_STORAGE_METADATA_KEY: {"mode": OSM_STORAGE_MAIN}}
def _insert_main_features(
    session: Session,
    feature_rows: list[dict[str, object]],
    *,
    batch_size: int = 5000,
) -> None:
    """Deduplicate *feature_rows* and bulk-save them as ``OsmFeature`` objects.

    Rows are first passed through ``dedupe_osm_feature_rows``; the duplicate
    count it reports is ignored. ORM objects are handed to
    ``session.bulk_save_objects`` in chunks, so at most *batch_size* objects
    accumulate before each call.

    Args:
        session: Active SQLAlchemy session.
        feature_rows: Column dicts as produced by ``_feature_row``.
        batch_size: Maximum number of objects per ``bulk_save_objects`` call.
            Defaults to 5000, the previously hard-coded value.

    Raises:
        ValueError: If *batch_size* is not positive.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    deduped_rows, _duplicate_count = dedupe_osm_feature_rows(feature_rows)
    pending: list[OsmFeature] = []
    for row in deduped_rows:
        pending.append(
            OsmFeature(
                dataset_id=row["dataset_id"],
                osm_type=row["osm_type"],
                osm_id=row["osm_id"],
                kind=row["kind"],
                mode=row["mode"],
                route_scope=row["route_scope"],
                name=row["name"],
                ref=row["ref"],
                operator=row["operator"],
                network=row["network"],
                geometry_geojson=row["geometry_geojson"],
                min_lon=row["min_lon"],
                min_lat=row["min_lat"],
                max_lon=row["max_lon"],
                max_lat=row["max_lat"],
                tags_json=row["tags_json"],
                route_key=row["route_key"],
                operator_key=row["operator_key"],
            )
        )
        if len(pending) >= batch_size:
            session.bulk_save_objects(pending)
            pending.clear()
    # Flush the final partial batch.
    if pending:
        session.bulk_save_objects(pending)
def _feature_row(dataset_id: int, idx: int, feature: dict[str, Any]) -> dict[str, object]:
    """Flatten one GeoJSON feature into a column dict for ``osm_features``.

    Each descriptive column takes the first non-empty value among a few
    alternative property names, for example ``ref`` / ``route_ref`` / ``line``.
    The OSM type and id fall back to ``"feature"`` and ``"feature_<idx>"``.
    The geometry is serialized, and its bbox is split into the
    ``min_lon`` / ``min_lat`` / ``max_lon`` / ``max_lat`` columns. The complete
    property dict is kept as compact JSON in ``tags_json``.
    """
    tags = feature.get("properties") or {}
    geometry_text, bbox = geometry_json_and_bbox(feature.get("geometry"))

    def candidates(*keys: str) -> list[Any]:
        # Property values in priority order, to be fed to first_nonempty.
        return [tags.get(key) for key in keys]

    osm_type = str(first_nonempty(*candidates("osm_type", "@type", "type"), "feature"))
    osm_id = str(first_nonempty(*candidates("osm_id", "@id", "id"), f"feature_{idx}"))
    mode = _infer_mode(tags)
    kind = _infer_kind(tags, mode)
    name = first_nonempty(*candidates("name", "official_name")) or None
    ref = first_nonempty(*candidates("ref", "route_ref", "line")) or None
    operator = first_nonempty(*candidates("operator", "agency", "brand")) or None
    network = first_nonempty(*candidates("network", "network:short")) or None
    route_scope = infer_osm_route_scope(mode=mode, ref=ref, name=name, network=network, tags=tags)
    # The route key falls back from the ref, to the name, to the OSM id.
    route_key = norm_ref(ref) or norm_text(name) or norm_ref(osm_id)
    operator_key = norm_text(operator or network or "")

    return {
        "dataset_id": dataset_id,
        "osm_type": osm_type,
        "osm_id": osm_id,
        "kind": kind,
        "mode": mode,
        "route_scope": route_scope,
        "name": name,
        "ref": ref,
        "operator": operator,
        "network": network,
        "geometry_geojson": geometry_text,
        "min_lon": bbox[0],
        "min_lat": bbox[1],
        "max_lon": bbox[2],
        "max_lat": bbox[3],
        "tags_json": json.dumps(tags, separators=(",", ":")),
        "route_key": route_key,
        "operator_key": operator_key,
    }
def _as_features(data: Any) -> list[dict[str, Any]]:
    """Normalize parsed GeoJSON into a list of feature dicts.

    Accepted shapes:

    - A ``FeatureCollection``: its ``features`` entries are returned. A missing
      or ``null`` ``features`` member yields an empty list.
    - A single ``Feature``: wrapped in a one-element list.
    - A bare list of features.

    In the collection and list cases, entries that are not dicts are dropped.

    Raises:
        ValueError: If *data* matches none of the accepted shapes.
    """
    if isinstance(data, list):
        candidates = data
    elif isinstance(data, dict) and data.get("type") == "FeatureCollection":
        # ``or []`` also covers an explicit ``"features": null``. With the old
        # ``.get("features", [])`` that returned None and iterating it raised
        # TypeError.
        candidates = data.get("features") or []
    elif isinstance(data, dict) and data.get("type") == "Feature":
        return [data]
    else:
        raise ValueError("OSM source must be GeoJSON FeatureCollection, Feature, or list of Features")
    return [f for f in candidates if isinstance(f, dict)]
def _infer_mode(props: dict[str, Any]) -> str | None:
    """Infer the transport mode of a feature from its OSM properties.

    Explicit mode values are checked first, in the order ``mode``, ``route``,
    ``route_master``. The first value found in ``ROUTE_MODES`` wins, with
    "railway" reported as "train". Otherwise stop and station tags decide:

    - ``railway`` = station/halt, tram_stop or subway_entrance
    - bus stop or bus station
    - ferry terminal
    - aerialway station

    Returns None when nothing matches.
    """
    def tag(key: str) -> str:
        # Property value as text; missing or falsy values become "".
        return str(props.get(key) or "")

    for route_key in ("mode", "route", "route_master"):
        candidate = tag(route_key).strip()
        if candidate in ROUTE_MODES:
            return "train" if candidate == "railway" else candidate

    railway_modes = {
        "station": "train",
        "halt": "train",
        "tram_stop": "tram",
        "subway_entrance": "subway",
    }
    railway_mode = railway_modes.get(tag("railway").strip())
    if railway_mode is not None:
        return railway_mode

    amenity = tag("amenity")
    if tag("highway") == "bus_stop" or amenity == "bus_station":
        return "bus"
    if amenity == "ferry_terminal":
        return "ferry"
    if tag("aerialway") == "station":
        return "aerialway"
    return None
def _infer_kind(props: dict[str, Any], mode: str | None) -> str:
explicit_kind = str(props.get("kind") or "").strip()
if explicit_kind in {"route", "stop", "station", "terminal", "infra", "feature"}:
return explicit_kind
if str(props.get("type") or "") in {"route", "route_master"} or str(props.get("route") or "") in ROUTE_MODES:
return "route"
if str(props.get("amenity") or "") == "ferry_terminal":
return "terminal"
if str(props.get("amenity") or "") == "bus_station":
return "terminal"
if str(props.get("railway") or "") in {"station", "halt"}:
return "station"
if str(props.get("aerialway") or "") == "station":
return "station"
if str(props.get("public_transport") or "") in {"platform", "stop_position", "station"}:
return "stop"
if str(props.get("highway") or "") == "bus_stop":
return "stop"
if mode:
return "infra"
return "feature"