Alpha stage commit
This commit is contained in:
111
app/pipeline/download.py
Normal file
111
app/pipeline/download.py
Normal file
@@ -0,0 +1,111 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import shutil
|
||||
import time
|
||||
from pathlib import Path
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import requests
|
||||
|
||||
from app.config import settings
|
||||
from app.models import Source
|
||||
from app.pipeline.utils import sha256_file
|
||||
|
||||
|
||||
def materialize_source(source: Source) -> Path:
|
||||
"""Download/copy a source into the local cache and return the file path.
|
||||
|
||||
Files are stored by content hash per source. Re-running an unchanged source
|
||||
reuses the existing cached file instead of creating another timestamped copy.
|
||||
"""
|
||||
source_dir = settings.data_dir / "sources" / f"source_{source.id}"
|
||||
source_dir.mkdir(parents=True, exist_ok=True)
|
||||
suffix = _guess_suffix(source.url, source.kind)
|
||||
|
||||
parsed = urlparse(source.url)
|
||||
if parsed.scheme in {"http", "https"}:
|
||||
temp_path = _download_temp_path(source_dir, suffix)
|
||||
existing_size = temp_path.stat().st_size if temp_path.exists() else 0
|
||||
headers = {"Range": f"bytes={existing_size}-"} if existing_size > 0 else None
|
||||
with requests.get(source.url, stream=True, timeout=120, headers=headers) as r:
|
||||
r.raise_for_status()
|
||||
mode = "ab" if existing_size > 0 and r.status_code == 206 else "wb"
|
||||
with temp_path.open(mode) as f:
|
||||
for chunk in r.iter_content(chunk_size=1024 * 1024):
|
||||
if chunk:
|
||||
f.write(chunk)
|
||||
return _store_or_reuse_cached_file(source_dir=source_dir, source_path=temp_path, suffix=suffix, move=True)
|
||||
|
||||
if parsed.scheme == "file":
|
||||
source_path = Path(parsed.path)
|
||||
else:
|
||||
source_path = Path(source.url)
|
||||
|
||||
if not source_path.exists():
|
||||
raise FileNotFoundError(f"Source file does not exist: {source.url}")
|
||||
if _is_relative_to(source_path.resolve(), source_dir.resolve()):
|
||||
return source_path
|
||||
return _store_or_reuse_cached_file(source_dir=source_dir, source_path=source_path, suffix=suffix, move=False)
|
||||
|
||||
|
||||
def _download_temp_path(source_dir: Path, suffix: str) -> Path:
|
||||
candidates = sorted(
|
||||
source_dir.glob(f"*.download{suffix}"),
|
||||
key=lambda path: path.stat().st_mtime if path.exists() else 0,
|
||||
reverse=True,
|
||||
)
|
||||
if candidates:
|
||||
return candidates[0]
|
||||
return source_dir / f"{int(time.time())}.download{suffix}"
|
||||
|
||||
|
||||
def _guess_suffix(url: str, kind: str) -> str:
|
||||
path = urlparse(url).path or url
|
||||
lower = path.lower()
|
||||
for suffix in (".zip", ".geojson", ".json", ".osm.pbf", ".pbf", ".osm", ".osm.xml", ".osc.gz", ".osc", ".csv"):
|
||||
if lower.endswith(suffix):
|
||||
return suffix
|
||||
if kind == "gtfs":
|
||||
return ".zip"
|
||||
if kind == "osm_geojson":
|
||||
return ".geojson"
|
||||
return ".dat"
|
||||
|
||||
|
||||
def _store_or_reuse_cached_file(source_dir: Path, source_path: Path, suffix: str, move: bool) -> Path:
|
||||
source_hash = sha256_file(source_path)
|
||||
target = source_dir / f"{source_hash[:16]}{suffix}"
|
||||
|
||||
if target.exists() and sha256_file(target) == source_hash:
|
||||
if move and source_path != target:
|
||||
source_path.unlink(missing_ok=True)
|
||||
return target
|
||||
|
||||
existing = _find_existing_cached_file(source_dir, source_hash, suffix, exclude=source_path)
|
||||
if existing is not None:
|
||||
if move and source_path != existing:
|
||||
source_path.unlink(missing_ok=True)
|
||||
return existing
|
||||
|
||||
if move:
|
||||
source_path.replace(target)
|
||||
else:
|
||||
shutil.copyfile(source_path, target)
|
||||
return target
|
||||
|
||||
|
||||
def _find_existing_cached_file(source_dir: Path, source_hash: str, suffix: str, exclude: Path | None = None) -> Path | None:
|
||||
for candidate in sorted(source_dir.glob(f"*{suffix}")):
|
||||
if exclude is not None and candidate.resolve() == exclude.resolve():
|
||||
continue
|
||||
if candidate.is_file() and sha256_file(candidate) == source_hash:
|
||||
return candidate
|
||||
return None
|
||||
|
||||
|
||||
def _is_relative_to(path: Path, parent: Path) -> bool:
|
||||
try:
|
||||
path.relative_to(parent)
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
1327
app/pipeline/gtfs.py
Normal file
1327
app/pipeline/gtfs.py
Normal file
File diff suppressed because it is too large
Load Diff
995
app/pipeline/matcher.py
Normal file
995
app/pipeline/matcher.py
Normal file
@@ -0,0 +1,995 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
import json
|
||||
from typing import Callable, Optional
|
||||
|
||||
from shapely.geometry import LineString, MultiLineString, Point, shape
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.config import settings
|
||||
from app.models import Dataset, GtfsRoute, MatchRule, OsmFeature, RouteMatch
|
||||
from app.osm_storage import ensure_main_osm_feature, osm_feature_bbox, query_osm_features
|
||||
from app.pipeline.state import STAGE_MATCH_ROUTES, dependency_hash, finish_pipeline_run, start_pipeline_run
|
||||
from app.pipeline.utils import approx_bbox_center_distance_deg, bbox_overlap, norm_ref, norm_text
|
||||
|
||||
MODE_GROUPS = {
|
||||
"train": {"train", "rail", "railway"},
|
||||
"subway": {"subway", "metro"},
|
||||
"tram": {"tram", "light_rail"},
|
||||
"light_rail": {"light_rail", "tram"},
|
||||
"bus": {"bus", "coach", "trolleybus"},
|
||||
"coach": {"coach", "bus"},
|
||||
"trolleybus": {"trolleybus", "bus"},
|
||||
"ferry": {"ferry"},
|
||||
"funicular": {"funicular"},
|
||||
"aerialway": {"aerialway", "cable_car"},
|
||||
"monorail": {"monorail"},
|
||||
}
|
||||
MAX_FALLBACK_CANDIDATES_WITH_REF = 40
|
||||
MAX_FALLBACK_CANDIDATES_WITHOUT_REF = 80
|
||||
MAX_EXACT_REF_CANDIDATES = 120
|
||||
OSM_SCOPE_NEAR_DISTANCE_DEG = 0.15
|
||||
GEOMETRY_PROXIMITY_DEG = 0.0035
|
||||
GEOMETRY_SAMPLE_POINTS = 24
|
||||
MATCHER_VERSION = "matcher_v4_scope_spatial_manual_rules"
|
||||
ProgressCallback = Callable[[str, str, int | None, int | None, dict[str, object] | None], None]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _ManualMatchRule:
|
||||
id: int
|
||||
rule_type: str
|
||||
route_selector: dict[str, object]
|
||||
osm_selector: dict[str, object] | None
|
||||
status: str
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _OsmRouteIndex:
|
||||
all_routes: list[OsmFeature]
|
||||
by_ref: dict[str, list[OsmFeature]]
|
||||
by_route_key: dict[str, list[OsmFeature]]
|
||||
by_mode: dict[str, list[OsmFeature]]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _GeometryProfile:
|
||||
geom: object
|
||||
lines: list[LineString]
|
||||
length: float
|
||||
sample_points: list[Point]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _RouteMatchPayload:
|
||||
gtfs_route_id: int
|
||||
osm_feature_id: int | None
|
||||
confidence: float
|
||||
status: str
|
||||
rule_source: str
|
||||
reasons_json: str | None
|
||||
|
||||
|
||||
def run_route_matching(
|
||||
session: Session,
|
||||
*,
|
||||
progress_callback: ProgressCallback | None = None,
|
||||
batch_size: int | None = None,
|
||||
) -> dict[str, object]:
|
||||
"""Match active GTFS routes against active OSM route features."""
|
||||
active_datasets = session.execute(
|
||||
select(Dataset.id, Dataset.kind, Dataset.source_id).where(Dataset.is_active.is_(True))
|
||||
).all()
|
||||
if not active_datasets:
|
||||
return {"routes": 0, "matches": 0, "missing": 0}
|
||||
dataset_source_ids = {int(dataset_id): int(source_id) for dataset_id, _, source_id in active_datasets}
|
||||
gtfs_dataset_ids = [int(dataset_id) for dataset_id, kind, _ in active_datasets if kind == "gtfs"]
|
||||
osm_dataset_ids = [int(dataset_id) for dataset_id, kind, _ in active_datasets if kind == "osm_geojson"]
|
||||
if not gtfs_dataset_ids:
|
||||
return {"routes": 0, "matches": 0, "missing": 0}
|
||||
|
||||
route_row_ids = session.scalars(
|
||||
select(GtfsRoute.id)
|
||||
.where(GtfsRoute.dataset_id.in_(gtfs_dataset_ids))
|
||||
.order_by(GtfsRoute.dataset_id, GtfsRoute.route_id, GtfsRoute.id)
|
||||
).all()
|
||||
# Reconcile current match rows from auto scoring plus durable manual rules.
|
||||
total_routes = len(route_row_ids)
|
||||
if total_routes == 0:
|
||||
return {"routes": 0, "matches": 0, "missing": 0}
|
||||
|
||||
dependency = _route_matching_dependency(session, active_datasets)
|
||||
run = start_pipeline_run(
|
||||
session,
|
||||
stage=STAGE_MATCH_ROUTES,
|
||||
version=MATCHER_VERSION,
|
||||
dependency_hash_value=dependency_hash(dependency),
|
||||
inputs=dependency,
|
||||
)
|
||||
session.commit()
|
||||
effective_batch_size = max(1, int(batch_size or settings.route_matching_batch_size))
|
||||
_emit_progress(
|
||||
progress_callback,
|
||||
"route_matching_started",
|
||||
f"Matching {total_routes} GTFS routes in batches of {effective_batch_size}.",
|
||||
0,
|
||||
total_routes,
|
||||
{"gtfs_datasets": gtfs_dataset_ids, "osm_datasets": osm_dataset_ids, "batch_size": effective_batch_size},
|
||||
)
|
||||
manual_rules = _manual_match_rules(session)
|
||||
osm_scope_bbox = osm_feature_bbox(session, osm_dataset_ids, kinds=["route"])
|
||||
counts = {"routes": total_routes, "matches": 0, "missing": 0, "manual": 0, "created": 0, "updated": 0, "unchanged": 0}
|
||||
scoped_counts = {"in_osm_scope": 0, "near_osm_scope": 0, "outside_osm_scope": 0, "unknown_scope": 0}
|
||||
processed = 0
|
||||
for chunk in _chunks_int(route_row_ids, effective_batch_size):
|
||||
routes = session.scalars(
|
||||
select(GtfsRoute)
|
||||
.where(GtfsRoute.id.in_(chunk))
|
||||
.order_by(GtfsRoute.dataset_id, GtfsRoute.route_id, GtfsRoute.id)
|
||||
).all()
|
||||
batch_counts = _match_route_batch(
|
||||
session=session,
|
||||
routes=routes,
|
||||
osm_dataset_ids=osm_dataset_ids,
|
||||
dataset_source_ids=dataset_source_ids,
|
||||
manual_rules=manual_rules,
|
||||
osm_scope_bbox=osm_scope_bbox,
|
||||
scoped_counts=scoped_counts,
|
||||
)
|
||||
counts["matches"] += batch_counts["matches"]
|
||||
counts["missing"] += batch_counts["missing"]
|
||||
counts["manual"] += batch_counts["manual"]
|
||||
counts["created"] += batch_counts["created"]
|
||||
counts["updated"] += batch_counts["updated"]
|
||||
counts["unchanged"] += batch_counts["unchanged"]
|
||||
processed += len(routes)
|
||||
session.commit()
|
||||
_emit_progress(
|
||||
progress_callback,
|
||||
"route_matching_batch",
|
||||
f"Matched {processed}/{total_routes} GTFS routes.",
|
||||
processed,
|
||||
total_routes,
|
||||
{
|
||||
"processed": processed,
|
||||
"matches": counts["matches"],
|
||||
"missing": counts["missing"],
|
||||
"manual": counts["manual"],
|
||||
"created": counts["created"],
|
||||
"updated": counts["updated"],
|
||||
"unchanged": counts["unchanged"],
|
||||
"scope": dict(scoped_counts),
|
||||
},
|
||||
)
|
||||
result = {**counts, "scope": scoped_counts}
|
||||
finish_pipeline_run(session, run, outputs=result)
|
||||
session.commit()
|
||||
_emit_progress(
|
||||
progress_callback,
|
||||
"route_matching_completed",
|
||||
"Route matching completed.",
|
||||
total_routes,
|
||||
total_routes,
|
||||
result,
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
def _route_matching_dependency(session: Session, active_datasets) -> dict[str, object]:
|
||||
datasets = [
|
||||
{"id": int(dataset_id), "kind": str(kind), "source_id": int(source_id), "sha256": _dataset_sha(session, int(dataset_id))}
|
||||
for dataset_id, kind, source_id in active_datasets
|
||||
]
|
||||
rules = [
|
||||
{
|
||||
"id": int(rule.id),
|
||||
"type": rule.rule_type,
|
||||
"active": bool(rule.active),
|
||||
"selector": rule.selector_json,
|
||||
"action": rule.action_json,
|
||||
}
|
||||
for rule in session.scalars(select(MatchRule).order_by(MatchRule.id)).all()
|
||||
]
|
||||
return {"version": MATCHER_VERSION, "active_datasets": datasets, "manual_rules": rules}
|
||||
|
||||
|
||||
def _dataset_sha(session: Session, dataset_id: int) -> str | None:
|
||||
dataset = session.get(Dataset, dataset_id)
|
||||
return None if dataset is None else dataset.sha256
|
||||
|
||||
|
||||
def _match_route_batch(
|
||||
*,
|
||||
session: Session,
|
||||
routes: list[GtfsRoute],
|
||||
osm_dataset_ids: list[int],
|
||||
dataset_source_ids: dict[int, int],
|
||||
manual_rules: list[_ManualMatchRule],
|
||||
osm_scope_bbox: tuple[float | None, float | None, float | None, float | None],
|
||||
scoped_counts: dict[str, int],
|
||||
) -> dict[str, int]:
|
||||
matches = 0
|
||||
missing = 0
|
||||
manual = 0
|
||||
payloads: list[_RouteMatchPayload] = []
|
||||
for route in routes:
|
||||
scope = route_match_scope(route, osm_scope_bbox)
|
||||
scoped_counts[scope] = scoped_counts.get(scope, 0) + 1
|
||||
route_source_id = dataset_source_ids.get(route.dataset_id)
|
||||
accepted_rule = _accepted_rule_for_route(manual_rules, route, route_source_id)
|
||||
if accepted_rule is not None:
|
||||
accepted_feature = _feature_for_rule_from_storage(session, osm_dataset_ids, dataset_source_ids, accepted_rule)
|
||||
if accepted_feature is not None:
|
||||
accepted_feature = ensure_main_osm_feature(session, accepted_feature)
|
||||
payloads.append(
|
||||
_RouteMatchPayload(
|
||||
gtfs_route_id=route.id,
|
||||
osm_feature_id=accepted_feature.id,
|
||||
confidence=100.0,
|
||||
status="accepted",
|
||||
rule_source="manual",
|
||||
reasons_json=json.dumps(
|
||||
{"manual_rule_id": accepted_rule.id, "manual": "accepted_match", "scope": scope},
|
||||
separators=(",", ":"),
|
||||
),
|
||||
)
|
||||
)
|
||||
matches += 1
|
||||
manual += 1
|
||||
continue
|
||||
|
||||
if scope == "outside_osm_scope":
|
||||
missing += 1
|
||||
payloads.append(
|
||||
_RouteMatchPayload(
|
||||
gtfs_route_id=route.id,
|
||||
osm_feature_id=None,
|
||||
confidence=0.0,
|
||||
status="missing",
|
||||
rule_source="auto",
|
||||
reasons_json=json.dumps(
|
||||
{
|
||||
"reason": "outside loaded OSM route scope",
|
||||
"scope": scope,
|
||||
},
|
||||
separators=(",", ":"),
|
||||
),
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
best_feature: Optional[OsmFeature] = None
|
||||
best_score = 0.0
|
||||
best_reasons: dict[str, object] = {}
|
||||
route_geometry_profile = _geometry_profile(route.geometry_geojson)
|
||||
for feature in candidate_osm_routes_for_route(session, route, osm_dataset_ids):
|
||||
if _is_rejected_pair(manual_rules, route, route_source_id, feature, dataset_source_ids.get(feature.dataset_id)):
|
||||
continue
|
||||
feature_geometry_profile = _geometry_profile(feature.geometry_geojson)
|
||||
score, reasons = score_route_pair(
|
||||
route,
|
||||
feature,
|
||||
route_geometry_profile=route_geometry_profile,
|
||||
feature_geometry_profile=feature_geometry_profile,
|
||||
)
|
||||
if score > best_score:
|
||||
best_score = score
|
||||
best_feature = feature
|
||||
best_reasons = reasons
|
||||
status = _status_from_score(best_score)
|
||||
if best_feature is None or status == "missing":
|
||||
missing += 1
|
||||
best_feature_id = None
|
||||
best_reasons = {
|
||||
"reason": "no OSM candidate above threshold",
|
||||
"scope": scope,
|
||||
"best_score_below_threshold": round(float(best_score), 2) if best_score else 0,
|
||||
"best_reasons": best_reasons,
|
||||
}
|
||||
best_score = 0
|
||||
else:
|
||||
matches += 1
|
||||
best_feature = ensure_main_osm_feature(session, best_feature)
|
||||
best_feature_id = best_feature.id
|
||||
best_reasons["scope"] = scope
|
||||
payloads.append(
|
||||
_RouteMatchPayload(
|
||||
gtfs_route_id=route.id,
|
||||
osm_feature_id=best_feature_id,
|
||||
confidence=round(float(best_score), 2),
|
||||
status=status,
|
||||
rule_source="auto",
|
||||
reasons_json=json.dumps(best_reasons, separators=(",", ":")),
|
||||
)
|
||||
)
|
||||
changes = _apply_route_match_payloads(session, payloads)
|
||||
session.flush()
|
||||
return {"matches": matches, "missing": missing, "manual": manual, **changes}
|
||||
|
||||
|
||||
def _apply_route_match_payloads(session: Session, payloads: list[_RouteMatchPayload]) -> dict[str, int]:
|
||||
if not payloads:
|
||||
return {"created": 0, "updated": 0, "unchanged": 0}
|
||||
route_ids = [payload.gtfs_route_id for payload in payloads]
|
||||
existing_rows = session.scalars(
|
||||
select(RouteMatch).where(RouteMatch.gtfs_route_id.in_(route_ids)).order_by(RouteMatch.gtfs_route_id, RouteMatch.id)
|
||||
).all()
|
||||
existing_by_route: dict[int, list[RouteMatch]] = {}
|
||||
for row in existing_rows:
|
||||
existing_by_route.setdefault(row.gtfs_route_id, []).append(row)
|
||||
|
||||
created = 0
|
||||
updated = 0
|
||||
unchanged = 0
|
||||
duplicate_ids: list[int] = []
|
||||
now = datetime.now(timezone.utc)
|
||||
for payload in payloads:
|
||||
existing = existing_by_route.get(payload.gtfs_route_id, [])
|
||||
current = _preferred_existing_match(existing)
|
||||
if current is None:
|
||||
session.add(
|
||||
RouteMatch(
|
||||
gtfs_route_id=payload.gtfs_route_id,
|
||||
osm_feature_id=payload.osm_feature_id,
|
||||
confidence=payload.confidence,
|
||||
status=payload.status,
|
||||
rule_source=payload.rule_source,
|
||||
reasons_json=payload.reasons_json,
|
||||
)
|
||||
)
|
||||
created += 1
|
||||
continue
|
||||
|
||||
duplicate_ids.extend(row.id for row in existing if row.id != current.id)
|
||||
if _route_match_payload_equal(current, payload):
|
||||
unchanged += 1
|
||||
continue
|
||||
current.osm_feature_id = payload.osm_feature_id
|
||||
current.confidence = payload.confidence
|
||||
current.status = payload.status
|
||||
current.rule_source = payload.rule_source
|
||||
current.reasons_json = payload.reasons_json
|
||||
current.updated_at = now
|
||||
updated += 1
|
||||
|
||||
for chunk in _chunks_int(duplicate_ids, 1000):
|
||||
session.execute(delete(RouteMatch).where(RouteMatch.id.in_(chunk)))
|
||||
return {"created": created, "updated": updated, "unchanged": unchanged}
|
||||
|
||||
|
||||
def _preferred_existing_match(rows: list[RouteMatch]) -> RouteMatch | None:
|
||||
if not rows:
|
||||
return None
|
||||
return next((row for row in rows if row.rule_source == "manual"), rows[0])
|
||||
|
||||
|
||||
def _route_match_payload_equal(row: RouteMatch, payload: _RouteMatchPayload) -> bool:
|
||||
return (
|
||||
row.osm_feature_id == payload.osm_feature_id
|
||||
and round(float(row.confidence or 0), 2) == round(float(payload.confidence or 0), 2)
|
||||
and row.status == payload.status
|
||||
and row.rule_source == payload.rule_source
|
||||
and (row.reasons_json or None) == (payload.reasons_json or None)
|
||||
)
|
||||
|
||||
|
||||
def _build_osm_route_index(osm_routes: list[OsmFeature]) -> _OsmRouteIndex:
|
||||
by_ref: dict[str, list[OsmFeature]] = {}
|
||||
by_route_key: dict[str, list[OsmFeature]] = {}
|
||||
by_mode: dict[str, list[OsmFeature]] = {}
|
||||
for feature in osm_routes:
|
||||
ref = norm_ref(feature.ref or "")
|
||||
if ref:
|
||||
by_ref.setdefault(ref, []).append(feature)
|
||||
if feature.route_key:
|
||||
by_route_key.setdefault(feature.route_key, []).append(feature)
|
||||
if feature.mode:
|
||||
by_mode.setdefault(feature.mode, []).append(feature)
|
||||
return _OsmRouteIndex(all_routes=osm_routes, by_ref=by_ref, by_route_key=by_route_key, by_mode=by_mode)
|
||||
|
||||
|
||||
def _candidate_osm_routes(route: GtfsRoute, index: _OsmRouteIndex) -> list[OsmFeature]:
|
||||
selected: list[OsmFeature] = []
|
||||
seen: set[int] = set()
|
||||
|
||||
def add(features: list[OsmFeature], *, require_compatible_mode: bool = True) -> None:
|
||||
for feature in features:
|
||||
if feature.id in seen:
|
||||
continue
|
||||
if require_compatible_mode and not _mode_compatible(route.mode or "", feature.mode or ""):
|
||||
continue
|
||||
seen.add(feature.id)
|
||||
selected.append(feature)
|
||||
|
||||
route_ref = norm_ref(route.short_name or route.route_id)
|
||||
if route_ref:
|
||||
add(index.by_ref.get(route_ref, []))
|
||||
if route.route_key:
|
||||
add(index.by_route_key.get(route.route_key, []))
|
||||
if selected:
|
||||
return _spatially_ranked_candidates(route, selected, MAX_EXACT_REF_CANDIDATES)
|
||||
|
||||
compatible_modes = MODE_GROUPS.get(route.mode or "", {route.mode or ""})
|
||||
mode_candidates: list[OsmFeature] = []
|
||||
for mode in compatible_modes:
|
||||
if mode:
|
||||
mode_candidates.extend(index.by_mode.get(mode, []))
|
||||
if not mode_candidates:
|
||||
mode_candidates = index.all_routes
|
||||
|
||||
gtfs_bbox = (route.min_lon, route.min_lat, route.max_lon, route.max_lat)
|
||||
near_candidates: list[tuple[float, OsmFeature]] = []
|
||||
for feature in mode_candidates:
|
||||
osm_bbox = (feature.min_lon, feature.min_lat, feature.max_lon, feature.max_lat)
|
||||
distance = approx_bbox_center_distance_deg(gtfs_bbox, osm_bbox)
|
||||
if bbox_overlap(gtfs_bbox, osm_bbox):
|
||||
near_candidates.append((0.0, feature))
|
||||
elif distance is not None and distance < 0.12:
|
||||
near_candidates.append((distance, feature))
|
||||
fallback_limit = MAX_FALLBACK_CANDIDATES_WITH_REF if route_ref else MAX_FALLBACK_CANDIDATES_WITHOUT_REF
|
||||
fallback = [feature for _, feature in sorted(near_candidates, key=lambda item: item[0])[:fallback_limit]]
|
||||
if not fallback:
|
||||
fallback = mode_candidates[:fallback_limit]
|
||||
add(fallback)
|
||||
return _spatially_ranked_candidates(route, selected, fallback_limit)
|
||||
|
||||
|
||||
def candidate_osm_routes_for_route(session: Session, route: GtfsRoute, osm_dataset_ids: list[int]) -> list[OsmFeature]:
|
||||
if not osm_dataset_ids:
|
||||
return []
|
||||
selected: list[OsmFeature] = []
|
||||
seen: set[tuple[int, str, str]] = set()
|
||||
|
||||
def add(features: list[OsmFeature], *, require_compatible_mode: bool = True) -> None:
|
||||
for feature in features:
|
||||
key = (feature.dataset_id, feature.osm_type, feature.osm_id)
|
||||
if key in seen:
|
||||
continue
|
||||
if require_compatible_mode and not _mode_compatible(route.mode or "", feature.mode or ""):
|
||||
continue
|
||||
seen.add(key)
|
||||
selected.append(feature)
|
||||
|
||||
route_ref = norm_ref(route.short_name or route.route_id)
|
||||
route_keys = [key for key in [route.route_key, route_ref] if key]
|
||||
for route_key in dict.fromkeys(route_keys):
|
||||
add(
|
||||
query_osm_features(
|
||||
session,
|
||||
osm_dataset_ids,
|
||||
kinds=["route"],
|
||||
route_key=route_key,
|
||||
)
|
||||
)
|
||||
if selected:
|
||||
return _spatially_ranked_candidates(route, selected, MAX_EXACT_REF_CANDIDATES)
|
||||
|
||||
gtfs_bbox = (route.min_lon, route.min_lat, route.max_lon, route.max_lat)
|
||||
compatible_modes = sorted(MODE_GROUPS.get(route.mode or "", {route.mode or ""}) - {""})
|
||||
if not any(value is None for value in gtfs_bbox):
|
||||
bbox = _expanded_bbox(gtfs_bbox, 0.10)
|
||||
add(
|
||||
query_osm_features(
|
||||
session,
|
||||
osm_dataset_ids,
|
||||
kinds=["route"],
|
||||
modes=compatible_modes or None,
|
||||
bbox=bbox,
|
||||
limit=MAX_FALLBACK_CANDIDATES_WITHOUT_REF * 4,
|
||||
),
|
||||
require_compatible_mode=False,
|
||||
)
|
||||
if not selected:
|
||||
add(
|
||||
query_osm_features(
|
||||
session,
|
||||
osm_dataset_ids,
|
||||
kinds=["route"],
|
||||
modes=compatible_modes or None,
|
||||
limit=MAX_FALLBACK_CANDIDATES_WITHOUT_REF,
|
||||
),
|
||||
require_compatible_mode=False,
|
||||
)
|
||||
fallback_limit = MAX_FALLBACK_CANDIDATES_WITH_REF if route_ref else MAX_FALLBACK_CANDIDATES_WITHOUT_REF
|
||||
return _spatially_ranked_candidates(route, selected, fallback_limit)
|
||||
|
||||
|
||||
def score_route_pair(
|
||||
route: GtfsRoute,
|
||||
feature: OsmFeature,
|
||||
route_geometry_profile: _GeometryProfile | None = None,
|
||||
feature_geometry_profile: _GeometryProfile | None = None,
|
||||
) -> tuple[float, dict[str, object]]:
|
||||
score = 0.0
|
||||
reasons: dict[str, object] = {}
|
||||
|
||||
gtfs_mode = route.mode or ""
|
||||
osm_mode = feature.mode or ""
|
||||
if _mode_compatible(gtfs_mode, osm_mode):
|
||||
score += 25
|
||||
reasons["mode"] = "compatible"
|
||||
elif gtfs_mode and osm_mode:
|
||||
reasons["mode"] = f"mismatch: {gtfs_mode} != {osm_mode}"
|
||||
return 0.0, reasons
|
||||
|
||||
gtfs_ref = norm_ref(route.short_name or route.route_id)
|
||||
osm_ref = norm_ref(feature.ref or "")
|
||||
if gtfs_ref and osm_ref:
|
||||
if gtfs_ref == osm_ref:
|
||||
score += 25
|
||||
reasons["ref"] = "exact"
|
||||
elif gtfs_ref in osm_ref or osm_ref in gtfs_ref:
|
||||
score += 15
|
||||
reasons["ref"] = "partial"
|
||||
|
||||
gtfs_name = norm_text(" ".join(v for v in [route.long_name, route.short_name, route.route_id] if v))
|
||||
osm_name = norm_text(" ".join(v for v in [feature.name, feature.ref] if v))
|
||||
name_similarity = _ratio(gtfs_name, osm_name)
|
||||
score += 20 * name_similarity
|
||||
reasons["name_similarity"] = round(name_similarity, 3)
|
||||
|
||||
gtfs_operator = norm_text(route.operator_name or "")
|
||||
osm_operator = norm_text(" ".join(v for v in [feature.operator, feature.network] if v))
|
||||
operator_similarity = _ratio(gtfs_operator, osm_operator) if gtfs_operator and osm_operator else 0
|
||||
score += 15 * operator_similarity
|
||||
reasons["operator_similarity"] = round(operator_similarity, 3)
|
||||
|
||||
gtfs_bbox = (route.min_lon, route.min_lat, route.max_lon, route.max_lat)
|
||||
osm_bbox = (feature.min_lon, feature.min_lat, feature.max_lon, feature.max_lat)
|
||||
center_distance = None
|
||||
if bbox_overlap(gtfs_bbox, osm_bbox):
|
||||
score += 14
|
||||
reasons["bbox"] = "overlap"
|
||||
if gtfs_ref and osm_ref and gtfs_ref == osm_ref and _mode_compatible(gtfs_mode, osm_mode):
|
||||
score += 8
|
||||
reasons["line_identity"] = "exact_ref_mode_bbox_overlap"
|
||||
else:
|
||||
center_distance = approx_bbox_center_distance_deg(gtfs_bbox, osm_bbox)
|
||||
if center_distance is not None:
|
||||
if center_distance < 0.01:
|
||||
score += 12
|
||||
elif center_distance < 0.03:
|
||||
score += 8
|
||||
elif center_distance < 0.08:
|
||||
score += 4
|
||||
elif gtfs_ref and osm_ref and gtfs_ref == osm_ref and center_distance > OSM_SCOPE_NEAR_DISTANCE_DEG:
|
||||
score -= 8
|
||||
reasons["spatial_penalty"] = "exact_ref_far_bbox_center"
|
||||
reasons["bbox_center_distance_deg"] = round(center_distance, 5)
|
||||
|
||||
geometry_metrics = (
|
||||
_geometry_match_metrics_from_profiles(route_geometry_profile, feature_geometry_profile)
|
||||
if route_geometry_profile is not None and feature_geometry_profile is not None
|
||||
else _geometry_match_metrics(route.geometry_geojson, feature.geometry_geojson)
|
||||
)
|
||||
if geometry_metrics is not None:
|
||||
reasons["geometry"] = geometry_metrics
|
||||
geometry_score = 34 * float(geometry_metrics["gtfs_on_osm_ratio"]) + 8 * float(geometry_metrics["osm_on_gtfs_ratio"])
|
||||
if float(geometry_metrics["endpoint_distance_deg"]) < GEOMETRY_PROXIMITY_DEG * 2:
|
||||
geometry_score += 6
|
||||
if float(geometry_metrics["length_ratio"]) < 0.35 or float(geometry_metrics["length_ratio"]) > 2.8:
|
||||
geometry_score -= 8
|
||||
reasons["geometry_length"] = "implausible_ratio"
|
||||
score += max(0.0, min(42.0, geometry_score))
|
||||
|
||||
# Extra small boost for same normalized route key.
|
||||
if route.route_key and feature.route_key and route.route_key == feature.route_key:
|
||||
score += 5
|
||||
reasons["route_key"] = "same"
|
||||
|
||||
if gtfs_ref and osm_ref and gtfs_ref == osm_ref and _mode_compatible(gtfs_mode, osm_mode):
|
||||
if bbox_overlap(gtfs_bbox, osm_bbox):
|
||||
score = max(score, 88.0)
|
||||
reasons["strong_identity"] = "exact_ref_mode_bbox_overlap"
|
||||
elif center_distance is not None and center_distance < 0.02:
|
||||
score = max(score, 82.0)
|
||||
reasons["strong_identity"] = "exact_ref_mode_near_bbox_center"
|
||||
|
||||
if route.route_key and feature.route_key and route.route_key == feature.route_key and _mode_compatible(gtfs_mode, osm_mode):
|
||||
if bbox_overlap(gtfs_bbox, osm_bbox):
|
||||
score = max(score, 86.0)
|
||||
reasons.setdefault("strong_identity", "same_route_key_mode_bbox_overlap")
|
||||
|
||||
if geometry_metrics is not None:
|
||||
gtfs_on_osm = float(geometry_metrics["gtfs_on_osm_ratio"])
|
||||
endpoint_distance = float(geometry_metrics["endpoint_distance_deg"])
|
||||
if gtfs_on_osm >= 0.82 and endpoint_distance < GEOMETRY_PROXIMITY_DEG * 3 and _mode_compatible(gtfs_mode, osm_mode):
|
||||
if gtfs_ref and osm_ref and gtfs_ref == osm_ref:
|
||||
score = max(score, 90.0)
|
||||
reasons["strong_identity"] = "exact_ref_mode_geometry_overlap"
|
||||
elif gtfs_ref and osm_ref and (gtfs_ref in osm_ref or osm_ref in gtfs_ref):
|
||||
score = max(score, 82.0)
|
||||
reasons["strong_identity"] = "partial_ref_mode_geometry_overlap"
|
||||
|
||||
if (
|
||||
gtfs_ref
|
||||
and osm_ref
|
||||
and gtfs_ref == osm_ref
|
||||
and center_distance is not None
|
||||
and center_distance > OSM_SCOPE_NEAR_DISTANCE_DEG
|
||||
and not bbox_overlap(gtfs_bbox, osm_bbox)
|
||||
and (
|
||||
geometry_metrics is None
|
||||
or float(geometry_metrics.get("gtfs_on_osm_ratio", 0.0)) < 0.25
|
||||
)
|
||||
):
|
||||
score = min(score, 58.0)
|
||||
reasons["spatial_cap"] = "exact_ref_far_without_geometry_overlap"
|
||||
|
||||
return min(score, 100.0), reasons
|
||||
|
||||
|
||||
def route_match_scope(route: GtfsRoute, osm_scope_bbox: tuple[float | None, float | None, float | None, float | None]) -> str:
|
||||
route_bbox = (route.min_lon, route.min_lat, route.max_lon, route.max_lat)
|
||||
if any(value is None for value in route_bbox) or any(value is None for value in osm_scope_bbox):
|
||||
return "unknown_scope"
|
||||
if bbox_overlap(route_bbox, osm_scope_bbox):
|
||||
return "in_osm_scope"
|
||||
distance = approx_bbox_center_distance_deg(route_bbox, osm_scope_bbox)
|
||||
if distance is not None and distance < OSM_SCOPE_NEAR_DISTANCE_DEG:
|
||||
return "near_osm_scope"
|
||||
return "outside_osm_scope"
|
||||
|
||||
|
||||
def _combined_bbox(features: list[OsmFeature]) -> tuple[float | None, float | None, float | None, float | None]:
|
||||
boxes = [
|
||||
(feature.min_lon, feature.min_lat, feature.max_lon, feature.max_lat)
|
||||
for feature in features
|
||||
if None not in (feature.min_lon, feature.min_lat, feature.max_lon, feature.max_lat)
|
||||
]
|
||||
if not boxes:
|
||||
return (None, None, None, None)
|
||||
return (
|
||||
min(float(box[0]) for box in boxes if box[0] is not None),
|
||||
min(float(box[1]) for box in boxes if box[1] is not None),
|
||||
max(float(box[2]) for box in boxes if box[2] is not None),
|
||||
max(float(box[3]) for box in boxes if box[3] is not None),
|
||||
)
|
||||
|
||||
|
||||
def _spatially_ranked_candidates(route: GtfsRoute, candidates: list[OsmFeature], limit: int) -> list[OsmFeature]:
|
||||
return [
|
||||
feature
|
||||
for _, feature in sorted(
|
||||
((_spatial_rank(route, feature), feature) for feature in candidates),
|
||||
key=lambda item: item[0],
|
||||
)[: max(1, limit)]
|
||||
]
|
||||
|
||||
|
||||
def _spatial_rank(route: GtfsRoute, feature: OsmFeature) -> tuple[int, float, str]:
|
||||
route_bbox = (route.min_lon, route.min_lat, route.max_lon, route.max_lat)
|
||||
feature_bbox = (feature.min_lon, feature.min_lat, feature.max_lon, feature.max_lat)
|
||||
distance = approx_bbox_center_distance_deg(route_bbox, feature_bbox)
|
||||
if bbox_overlap(route_bbox, feature_bbox):
|
||||
bucket = 0
|
||||
elif distance is not None and distance < OSM_SCOPE_NEAR_DISTANCE_DEG:
|
||||
bucket = 1
|
||||
elif distance is not None:
|
||||
bucket = 2
|
||||
else:
|
||||
bucket = 3
|
||||
return (bucket, distance if distance is not None else 999.0, feature.osm_id)
|
||||
|
||||
|
||||
def _expanded_bbox(
|
||||
bbox: tuple[float | None, float | None, float | None, float | None],
|
||||
padding: float,
|
||||
) -> tuple[float, float, float, float] | None:
|
||||
min_lon, min_lat, max_lon, max_lat = bbox
|
||||
if None in (min_lon, min_lat, max_lon, max_lat):
|
||||
return None
|
||||
return (float(min_lon) - padding, float(min_lat) - padding, float(max_lon) + padding, float(max_lat) + padding)
|
||||
|
||||
|
||||
def _chunks_int(values: list[int], size: int) -> list[list[int]]:
|
||||
return [values[start : start + size] for start in range(0, len(values), max(1, size))]
|
||||
|
||||
|
||||
def _emit_progress(
|
||||
progress_callback: ProgressCallback | None,
|
||||
event_type: str,
|
||||
message: str,
|
||||
progress_current: int | None,
|
||||
progress_total: int | None,
|
||||
metadata: dict[str, object] | None = None,
|
||||
) -> None:
|
||||
if progress_callback is not None:
|
||||
progress_callback(event_type, message, progress_current, progress_total, metadata)
|
||||
|
||||
|
||||
def _geometry_match_metrics(route_geometry: str | None, feature_geometry: str | None) -> dict[str, float] | None:
|
||||
route_profile = _geometry_profile(route_geometry)
|
||||
feature_profile = _geometry_profile(feature_geometry)
|
||||
return _geometry_match_metrics_from_profiles(route_profile, feature_profile)
|
||||
|
||||
|
||||
def _geometry_profile(geometry_text: str | None) -> _GeometryProfile | None:
|
||||
if not geometry_text:
|
||||
return None
|
||||
try:
|
||||
geom = shape(json.loads(geometry_text))
|
||||
except Exception: # noqa: BLE001 - malformed geometry should not break matching
|
||||
return None
|
||||
lines = _iter_lines(geom)
|
||||
if not lines:
|
||||
return None
|
||||
length = sum(line.length for line in lines)
|
||||
if length == 0:
|
||||
return None
|
||||
sample_points = _sample_line_points(lines, GEOMETRY_SAMPLE_POINTS)
|
||||
if not sample_points:
|
||||
return None
|
||||
return _GeometryProfile(geom=geom, lines=lines, length=length, sample_points=sample_points)
|
||||
|
||||
|
||||
def _geometry_match_metrics_from_profiles(
|
||||
route_profile: _GeometryProfile | None, feature_profile: _GeometryProfile | None
|
||||
) -> dict[str, float] | None:
|
||||
if route_profile is None or feature_profile is None:
|
||||
return None
|
||||
gtfs_on_osm = _near_point_ratio(route_profile.sample_points, feature_profile.geom, GEOMETRY_PROXIMITY_DEG)
|
||||
osm_on_gtfs = _near_point_ratio(feature_profile.sample_points, route_profile.geom, GEOMETRY_PROXIMITY_DEG)
|
||||
endpoint_distance = _endpoint_distance(route_profile.lines, feature_profile.geom)
|
||||
length_ratio = route_profile.length / feature_profile.length if feature_profile.length else 0.0
|
||||
return {
|
||||
"gtfs_on_osm_ratio": round(gtfs_on_osm, 3),
|
||||
"osm_on_gtfs_ratio": round(osm_on_gtfs, 3),
|
||||
"endpoint_distance_deg": round(endpoint_distance, 6),
|
||||
"length_ratio": round(length_ratio, 3),
|
||||
}
|
||||
|
||||
|
||||
def _iter_lines(geom) -> list[LineString]:
|
||||
if isinstance(geom, LineString):
|
||||
return [geom]
|
||||
if isinstance(geom, MultiLineString):
|
||||
return [line for line in geom.geoms if isinstance(line, LineString) and line.length > 0]
|
||||
return []
|
||||
|
||||
|
||||
def _sample_line_points(lines: list[LineString], count: int) -> list[Point]:
|
||||
total_length = sum(line.length for line in lines)
|
||||
if total_length == 0:
|
||||
return []
|
||||
points = []
|
||||
for index in range(count):
|
||||
target = total_length * (index / max(1, count - 1))
|
||||
traversed = 0.0
|
||||
for line in lines:
|
||||
next_traversed = traversed + line.length
|
||||
if target <= next_traversed or line is lines[-1]:
|
||||
points.append(line.interpolate(max(0.0, min(line.length, target - traversed))))
|
||||
break
|
||||
traversed = next_traversed
|
||||
return points
|
||||
|
||||
|
||||
def _near_point_ratio(points: list[Point], geom, max_distance: float) -> float:
|
||||
if not points:
|
||||
return 0.0
|
||||
near = sum(1 for point in points if geom.distance(point) <= max_distance)
|
||||
return near / len(points)
|
||||
|
||||
|
||||
def _endpoint_distance(gtfs_lines: list[LineString], osm_geom) -> float:
|
||||
longest = max(gtfs_lines, key=lambda line: line.length)
|
||||
coords = list(longest.coords)
|
||||
if len(coords) < 2:
|
||||
return 999.0
|
||||
return osm_geom.distance(Point(coords[0])) + osm_geom.distance(Point(coords[-1]))
|
||||
|
||||
|
||||
def _manual_match_rules(session: Session) -> list[_ManualMatchRule]:
|
||||
rules = session.scalars(
|
||||
select(MatchRule)
|
||||
.where(MatchRule.active.is_(True), MatchRule.rule_type.in_(["accept_match", "reject_match"]))
|
||||
.order_by(MatchRule.id.desc())
|
||||
).all()
|
||||
parsed: list[_ManualMatchRule] = []
|
||||
for rule in rules:
|
||||
try:
|
||||
selector = json.loads(rule.selector_json or "{}")
|
||||
action = json.loads(rule.action_json or "{}")
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
route_selector = selector.get("gtfs") if isinstance(selector.get("gtfs"), dict) else selector
|
||||
osm_selector = action.get("osm") if isinstance(action.get("osm"), dict) else selector.get("osm")
|
||||
if not isinstance(osm_selector, dict) and selector.get("osm_feature_id") is not None:
|
||||
osm_selector = {"osm_feature_id": selector.get("osm_feature_id")}
|
||||
status = str(action.get("status") or ("accepted" if rule.rule_type == "accept_match" else "rejected"))
|
||||
parsed.append(
|
||||
_ManualMatchRule(
|
||||
id=rule.id,
|
||||
rule_type=rule.rule_type,
|
||||
route_selector=route_selector,
|
||||
osm_selector=osm_selector if isinstance(osm_selector, dict) else None,
|
||||
status=status,
|
||||
)
|
||||
)
|
||||
return parsed
|
||||
|
||||
|
||||
def _accepted_rule_for_route(
|
||||
rules: list[_ManualMatchRule], route: GtfsRoute, route_source_id: int | None
|
||||
) -> _ManualMatchRule | None:
|
||||
for rule in rules:
|
||||
if rule.rule_type != "accept_match":
|
||||
continue
|
||||
if rule.status != "accepted":
|
||||
continue
|
||||
if _route_matches_selector(route, route_source_id, rule.route_selector):
|
||||
return rule
|
||||
return None
|
||||
|
||||
|
||||
def _feature_for_rule(
|
||||
features: list[OsmFeature], dataset_source_ids: dict[int, int], rule: _ManualMatchRule
|
||||
) -> OsmFeature | None:
|
||||
if not rule.osm_selector:
|
||||
return None
|
||||
for feature in features:
|
||||
if _feature_matches_selector(feature, dataset_source_ids.get(feature.dataset_id), rule.osm_selector):
|
||||
return feature
|
||||
return None
|
||||
|
||||
|
||||
def _feature_for_rule_from_storage(
|
||||
session: Session,
|
||||
osm_dataset_ids: list[int],
|
||||
dataset_source_ids: dict[int, int],
|
||||
rule: _ManualMatchRule,
|
||||
) -> OsmFeature | None:
|
||||
if not rule.osm_selector:
|
||||
return None
|
||||
selector = rule.osm_selector
|
||||
legacy_id = _safe_int(selector.get("osm_feature_id"))
|
||||
if legacy_id is not None:
|
||||
feature = session.get(OsmFeature, legacy_id)
|
||||
if feature is not None and _feature_matches_selector(feature, dataset_source_ids.get(feature.dataset_id), selector):
|
||||
return feature
|
||||
scoped_dataset_ids = list(osm_dataset_ids)
|
||||
expected_source = selector.get("source_id")
|
||||
if expected_source is not None:
|
||||
expected_source_id = _safe_int(expected_source)
|
||||
if expected_source_id is not None:
|
||||
scoped_dataset_ids = [
|
||||
dataset_id
|
||||
for dataset_id in scoped_dataset_ids
|
||||
if dataset_source_ids.get(dataset_id) == expected_source_id
|
||||
]
|
||||
dataset_id = _safe_int(selector.get("dataset_id"))
|
||||
if dataset_id is not None:
|
||||
scoped_dataset_ids = [value for value in scoped_dataset_ids if value == dataset_id]
|
||||
if not scoped_dataset_ids:
|
||||
return None
|
||||
|
||||
features: list[OsmFeature] = []
|
||||
osm_type = selector.get("osm_type")
|
||||
osm_id = selector.get("osm_id")
|
||||
if osm_type and osm_id:
|
||||
features = query_osm_features(
|
||||
session,
|
||||
scoped_dataset_ids,
|
||||
kinds=["route"],
|
||||
osm_type=str(osm_type),
|
||||
osm_id=str(osm_id),
|
||||
limit=10,
|
||||
)
|
||||
if not features:
|
||||
route_key = selector.get("route_key")
|
||||
if route_key:
|
||||
features = query_osm_features(session, scoped_dataset_ids, kinds=["route"], route_key=str(route_key))
|
||||
if not features:
|
||||
ref = norm_ref(selector.get("ref"))
|
||||
if ref:
|
||||
features = query_osm_features(session, scoped_dataset_ids, kinds=["route"], route_key=ref)
|
||||
for feature in features:
|
||||
if _feature_matches_selector(feature, dataset_source_ids.get(feature.dataset_id), selector):
|
||||
return feature
|
||||
return None
|
||||
|
||||
|
||||
def _is_rejected_pair(
|
||||
rules: list[_ManualMatchRule],
|
||||
route: GtfsRoute,
|
||||
route_source_id: int | None,
|
||||
feature: OsmFeature,
|
||||
feature_source_id: int | None,
|
||||
) -> bool:
|
||||
for rule in rules:
|
||||
if rule.rule_type != "reject_match":
|
||||
continue
|
||||
if not _route_matches_selector(route, route_source_id, rule.route_selector):
|
||||
continue
|
||||
if rule.osm_selector and _feature_matches_selector(feature, feature_source_id, rule.osm_selector):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _route_matches_selector(route: GtfsRoute, source_id: int | None, selector: dict[str, object]) -> bool:
|
||||
legacy_id = selector.get("gtfs_route_id")
|
||||
if legacy_id is not None and _safe_int(legacy_id) == route.id:
|
||||
return True
|
||||
expected_source = selector.get("source_id")
|
||||
if expected_source is not None and source_id is not None and _safe_int(expected_source) != source_id:
|
||||
return False
|
||||
route_id = selector.get("route_id")
|
||||
if route_id and str(route_id) == route.route_id:
|
||||
return True
|
||||
route_key = selector.get("route_key")
|
||||
if route_key and route.route_key and str(route_key) == route.route_key:
|
||||
return True
|
||||
ref = norm_ref(selector.get("ref"))
|
||||
mode = selector.get("mode")
|
||||
if ref and ref == norm_ref(route.short_name or route.route_id):
|
||||
return not mode or _mode_compatible(str(mode), route.mode or "")
|
||||
return False
|
||||
|
||||
|
||||
def _feature_matches_selector(feature: OsmFeature, source_id: int | None, selector: dict[str, object]) -> bool:
|
||||
legacy_id = selector.get("osm_feature_id")
|
||||
if legacy_id is not None and _safe_int(legacy_id) == feature.id:
|
||||
return True
|
||||
expected_source = selector.get("source_id")
|
||||
if expected_source is not None and source_id is not None and _safe_int(expected_source) != source_id:
|
||||
return False
|
||||
osm_type = selector.get("osm_type")
|
||||
osm_id = selector.get("osm_id")
|
||||
if osm_type and osm_id and str(osm_type) == feature.osm_type and str(osm_id) == feature.osm_id:
|
||||
return True
|
||||
route_key = selector.get("route_key")
|
||||
if route_key and feature.route_key and str(route_key) == feature.route_key:
|
||||
return True
|
||||
ref = norm_ref(selector.get("ref"))
|
||||
mode = selector.get("mode")
|
||||
if ref and ref == norm_ref(feature.ref or ""):
|
||||
return not mode or _mode_compatible(str(mode), feature.mode or "")
|
||||
return False
|
||||
|
||||
|
||||
def _safe_int(value: object) -> int | None:
|
||||
try:
|
||||
return int(value) # type: ignore[arg-type]
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
|
||||
def _mode_compatible(gtfs_mode: str, osm_mode: str) -> bool:
|
||||
if not gtfs_mode or not osm_mode:
|
||||
return True
|
||||
if gtfs_mode == osm_mode:
|
||||
return True
|
||||
return osm_mode in MODE_GROUPS.get(gtfs_mode, {gtfs_mode}) or gtfs_mode in MODE_GROUPS.get(osm_mode, {osm_mode})
|
||||
|
||||
|
||||
def _ratio(a: str, b: str) -> float:
|
||||
if not a or not b:
|
||||
return 0.0
|
||||
if a == b:
|
||||
return 1.0
|
||||
token_ratio = _token_similarity(a, b)
|
||||
if a in b or b in a:
|
||||
token_ratio = max(token_ratio, 0.82)
|
||||
return token_ratio
|
||||
|
||||
|
||||
def _token_similarity(a: str, b: str) -> float:
|
||||
left = set(a.split())
|
||||
right = set(b.split())
|
||||
if not left or not right:
|
||||
return 0.0
|
||||
return len(left & right) / len(left | right)
|
||||
|
||||
|
||||
def _status_from_score(score: float) -> str:
|
||||
if score >= 85:
|
||||
return "matched"
|
||||
if score >= 65:
|
||||
return "probable"
|
||||
if score >= 40:
|
||||
return "weak"
|
||||
return "missing"
|
||||
508
app/pipeline/osm_addresses.py
Normal file
508
app/pipeline/osm_addresses.py
Normal file
@@ -0,0 +1,508 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import math
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
import osmium
|
||||
from sqlalchemy import delete, func, select, text
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.config import settings
|
||||
from app.models import Dataset, OsmAddress
|
||||
from app.pipeline.routing_layer import active_routing_dataset
|
||||
from app.spatial import analyze_postgresql_tables, refresh_postgis_geometries
|
||||
|
||||
|
||||
ProgressCallback = Callable[[str, str, int | None, int | None, dict[str, object] | None], None]
|
||||
ADDRESS_INDEX_VERSION = "osm_addresses_v2_nodes_ways_area_geometry"
|
||||
ADDRESS_TAGS = {
|
||||
"addr:housenumber",
|
||||
"addr:housename",
|
||||
"addr:street",
|
||||
"addr:place",
|
||||
"addr:postcode",
|
||||
"addr:city",
|
||||
"addr:country",
|
||||
"addr:unit",
|
||||
"addr:suburb",
|
||||
"addr:district",
|
||||
"addr:municipality",
|
||||
"entrance",
|
||||
"name",
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class AddressIndexResult:
|
||||
dataset_id: int
|
||||
input_path: str
|
||||
addresses: int
|
||||
node_addresses: int
|
||||
way_addresses: int
|
||||
skipped: int
|
||||
version: str = ADDRESS_INDEX_VERSION
|
||||
|
||||
def as_dict(self) -> dict[str, object]:
|
||||
return {
|
||||
"version": self.version,
|
||||
"dataset_id": self.dataset_id,
|
||||
"input_path": self.input_path,
|
||||
"addresses": self.addresses,
|
||||
"node_addresses": self.node_addresses,
|
||||
"way_addresses": self.way_addresses,
|
||||
"skipped": self.skipped,
|
||||
}
|
||||
|
||||
|
||||
def rebuild_address_index(
|
||||
session: Session,
|
||||
*,
|
||||
dataset_id: int | None = None,
|
||||
input_path: str | Path | None = None,
|
||||
reset: bool = True,
|
||||
batch_size: int = 20_000,
|
||||
progress_callback: ProgressCallback | None = None,
|
||||
) -> dict[str, object]:
|
||||
dataset = session.get(Dataset, dataset_id) if dataset_id is not None else active_routing_dataset(session)
|
||||
if dataset is None:
|
||||
raise ValueError("No OSM PBF dataset is available for address indexing.")
|
||||
path = Path(input_path or dataset.local_path)
|
||||
if not path.exists():
|
||||
raise FileNotFoundError(f"Address index PBF does not exist: {path}")
|
||||
|
||||
if reset:
|
||||
_emit(progress_callback, "address_index_clear_started", "Clearing existing OSM address index.", None, None, {"dataset_id": dataset.id})
|
||||
_clear_address_rows(session, dataset_id=int(dataset.id))
|
||||
session.commit()
|
||||
|
||||
if settings.is_postgresql_database:
|
||||
_emit(progress_callback, "address_index_indexes_dropped", "Dropping address lookup indexes before bulk import.", None, None, {"dataset_id": dataset.id})
|
||||
_drop_address_indexes(session)
|
||||
session.commit()
|
||||
|
||||
_emit(progress_callback, "address_index_import_started", "Importing OSM address nodes and ways.", None, None, {"dataset_id": dataset.id, "path": str(path)})
|
||||
handler = _AddressHandler(
|
||||
session=session,
|
||||
dataset_id=dataset.id,
|
||||
batch_size=batch_size,
|
||||
progress_callback=progress_callback,
|
||||
)
|
||||
if hasattr(osmium, "FileProcessor"):
|
||||
_apply_address_file_processor(handler, path)
|
||||
else:
|
||||
handler.apply_file(str(path), locations=True)
|
||||
handler.flush()
|
||||
|
||||
return finalize_address_index(
|
||||
session,
|
||||
dataset_id=dataset.id,
|
||||
input_path=path,
|
||||
node_addresses=handler.node_address_count,
|
||||
way_addresses=handler.way_address_count,
|
||||
skipped=handler.skipped_count,
|
||||
progress_callback=progress_callback,
|
||||
)
|
||||
|
||||
|
||||
def finalize_address_index(
|
||||
session: Session,
|
||||
*,
|
||||
dataset_id: int,
|
||||
input_path: str | Path,
|
||||
node_addresses: int = 0,
|
||||
way_addresses: int = 0,
|
||||
skipped: int = 0,
|
||||
progress_callback: ProgressCallback | None = None,
|
||||
) -> dict[str, object]:
|
||||
dataset = session.get(Dataset, dataset_id)
|
||||
if dataset is None:
|
||||
raise ValueError("Address index dataset does not exist.")
|
||||
if settings.is_postgresql_database:
|
||||
_emit(progress_callback, "address_index_geometry_started", "Refreshing address point geometries.", None, None, {"dataset_id": dataset.id})
|
||||
refresh_postgis_geometries(session, dataset_id=dataset.id, tables=["osm_addresses"], only_missing=False)
|
||||
session.commit()
|
||||
_emit(progress_callback, "address_index_indexes_started", "Rebuilding address lookup indexes.", None, None, {"dataset_id": dataset.id})
|
||||
_create_address_indexes(session)
|
||||
session.commit()
|
||||
analyze_postgresql_tables(session, ["osm_addresses"])
|
||||
address_count = int(session.scalar(select(func.count()).select_from(OsmAddress).where(OsmAddress.dataset_id == dataset.id)) or 0)
|
||||
metadata = _metadata(dataset)
|
||||
metadata["address_index"] = {
|
||||
"version": ADDRESS_INDEX_VERSION,
|
||||
"addresses": address_count,
|
||||
"node_addresses": int(node_addresses),
|
||||
"way_addresses": int(way_addresses),
|
||||
"skipped": int(skipped),
|
||||
"input_path": str(input_path),
|
||||
}
|
||||
dataset.metadata_json = json.dumps(metadata, indent=2)
|
||||
session.commit()
|
||||
result = AddressIndexResult(
|
||||
dataset_id=dataset.id,
|
||||
input_path=str(input_path),
|
||||
addresses=address_count,
|
||||
node_addresses=node_addresses,
|
||||
way_addresses=way_addresses,
|
||||
skipped=skipped,
|
||||
).as_dict()
|
||||
_emit(progress_callback, "address_index_import_completed", "OSM address index import completed.", address_count, address_count, result)
|
||||
return result
|
||||
|
||||
|
||||
def _clear_address_rows(session: Session, *, dataset_id: int) -> None:
|
||||
if settings.is_postgresql_database:
|
||||
other_dataset_count = int(
|
||||
session.scalar(
|
||||
select(func.count(func.distinct(OsmAddress.dataset_id))).where(OsmAddress.dataset_id != int(dataset_id))
|
||||
)
|
||||
or 0
|
||||
)
|
||||
if other_dataset_count == 0:
|
||||
session.execute(text("TRUNCATE TABLE osm_addresses RESTART IDENTITY"))
|
||||
return
|
||||
session.execute(delete(OsmAddress).where(OsmAddress.dataset_id == int(dataset_id)))
|
||||
|
||||
|
||||
def address_index_status(session: Session) -> dict[str, object]:
|
||||
dataset = active_routing_dataset(session)
|
||||
dataset_id = None if dataset is None else int(dataset.id)
|
||||
address_count = 0
|
||||
metadata: dict[str, object] = {}
|
||||
if dataset is not None:
|
||||
metadata = _metadata(dataset).get("address_index") or {}
|
||||
if isinstance(metadata, dict):
|
||||
try:
|
||||
address_count = int(metadata.get("addresses") or 0)
|
||||
except (TypeError, ValueError):
|
||||
address_count = 0
|
||||
if not address_count:
|
||||
address_count = int(session.scalar(select(func.count()).select_from(OsmAddress).where(OsmAddress.dataset_id == dataset.id)) or 0)
|
||||
installed_version = metadata.get("version") if isinstance(metadata, dict) else None
|
||||
return {
|
||||
"dataset_id": dataset_id,
|
||||
"addresses": address_count,
|
||||
"available": address_count > 0,
|
||||
"version": installed_version,
|
||||
"current_version": ADDRESS_INDEX_VERSION,
|
||||
"stale": bool(address_count and installed_version != ADDRESS_INDEX_VERSION),
|
||||
"input_path": metadata.get("input_path") if isinstance(metadata, dict) else None,
|
||||
}
|
||||
|
||||
|
||||
class _AddressHandler(osmium.SimpleHandler):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
session: Session,
|
||||
dataset_id: int,
|
||||
batch_size: int,
|
||||
progress_callback: ProgressCallback | None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.session = session
|
||||
self.dataset_id = int(dataset_id)
|
||||
self.batch_size = max(1_000, int(batch_size))
|
||||
self.progress_callback = progress_callback
|
||||
self.rows: list[dict[str, object]] = []
|
||||
self.address_count = 0
|
||||
self.node_address_count = 0
|
||||
self.way_address_count = 0
|
||||
self.skipped_count = 0
|
||||
self.processed_count = 0
|
||||
|
||||
def node(self, node) -> None:
|
||||
self.process_node(node)
|
||||
|
||||
def way(self, way) -> None:
|
||||
self.process_way(way)
|
||||
|
||||
def process_object(self, obj) -> None:
|
||||
if hasattr(obj, "nodes"):
|
||||
self.process_way(obj)
|
||||
elif hasattr(obj, "location"):
|
||||
self.process_node(obj)
|
||||
|
||||
def process_node(self, node) -> None:
|
||||
self.processed_count += 1
|
||||
tags = {tag.k: tag.v for tag in node.tags}
|
||||
if not _has_address(tags):
|
||||
return
|
||||
if not node.location.valid():
|
||||
self.skipped_count += 1
|
||||
return
|
||||
row = _address_row(
|
||||
dataset_id=self.dataset_id,
|
||||
osm_type="node",
|
||||
osm_id=str(node.id),
|
||||
tags=tags,
|
||||
lon=float(node.location.lon),
|
||||
lat=float(node.location.lat),
|
||||
bounds=(float(node.location.lon), float(node.location.lat), float(node.location.lon), float(node.location.lat)),
|
||||
geometry_geojson=None,
|
||||
)
|
||||
if row is None:
|
||||
self.skipped_count += 1
|
||||
return
|
||||
self.rows.append(row)
|
||||
self.node_address_count += 1
|
||||
self._after_address()
|
||||
|
||||
def process_way(self, way) -> None:
|
||||
self.processed_count += 1
|
||||
tags = {tag.k: tag.v for tag in way.tags}
|
||||
if not _has_address(tags):
|
||||
return
|
||||
coords = [
|
||||
(float(node.location.lon), float(node.location.lat))
|
||||
for node in way.nodes
|
||||
if node.location.valid()
|
||||
]
|
||||
if not coords:
|
||||
self.skipped_count += 1
|
||||
return
|
||||
lon, lat = _centroid(coords)
|
||||
min_lon = min(coord[0] for coord in coords)
|
||||
max_lon = max(coord[0] for coord in coords)
|
||||
min_lat = min(coord[1] for coord in coords)
|
||||
max_lat = max(coord[1] for coord in coords)
|
||||
row = _address_row(
|
||||
dataset_id=self.dataset_id,
|
||||
osm_type="way",
|
||||
osm_id=str(way.id),
|
||||
tags=tags,
|
||||
lon=lon,
|
||||
lat=lat,
|
||||
bounds=(min_lon, min_lat, max_lon, max_lat),
|
||||
geometry_geojson=_address_area_geometry_geojson(coords, closed=_way_is_closed(way)),
|
||||
)
|
||||
if row is None:
|
||||
self.skipped_count += 1
|
||||
return
|
||||
self.rows.append(row)
|
||||
self.way_address_count += 1
|
||||
self._after_address()
|
||||
|
||||
def _after_address(self) -> None:
|
||||
self.address_count += 1
|
||||
if len(self.rows) >= self.batch_size:
|
||||
self.flush()
|
||||
if self.address_count % 50_000 == 0:
|
||||
_emit(
|
||||
self.progress_callback,
|
||||
"address_index_import_batch",
|
||||
f"Imported {self.address_count:,} OSM addresses.",
|
||||
self.address_count,
|
||||
None,
|
||||
{"processed": self.processed_count, "skipped": self.skipped_count},
|
||||
)
|
||||
|
||||
def flush(self) -> None:
|
||||
if not self.rows:
|
||||
return
|
||||
self.session.bulk_insert_mappings(OsmAddress, self.rows)
|
||||
self.session.commit()
|
||||
self.rows = []
|
||||
|
||||
|
||||
def _apply_address_file_processor(handler: _AddressHandler, path: Path) -> None:
|
||||
processor = (
|
||||
osmium.FileProcessor(str(path), osmium.osm.NODE | osmium.osm.WAY)
|
||||
.with_locations()
|
||||
.with_filter(osmium.filter.KeyFilter("addr:housenumber", "addr:housename"))
|
||||
)
|
||||
for obj in processor:
|
||||
handler.process_object(obj)
|
||||
|
||||
|
||||
def _has_address(tags: dict[str, str]) -> bool:
|
||||
housenumber = _clean(tags.get("addr:housenumber") or tags.get("addr:housename"))
|
||||
if not housenumber:
|
||||
return False
|
||||
return any(_clean(tags.get(key)) for key in ("addr:street", "addr:place", "addr:city", "addr:postcode"))
|
||||
|
||||
|
||||
def _address_row(
|
||||
*,
|
||||
dataset_id: int,
|
||||
osm_type: str,
|
||||
osm_id: str,
|
||||
tags: dict[str, str],
|
||||
lon: float,
|
||||
lat: float,
|
||||
bounds: tuple[float, float, float, float],
|
||||
geometry_geojson: str | None = None,
|
||||
) -> dict[str, object] | None:
|
||||
housenumber = _clean(tags.get("addr:housenumber") or tags.get("addr:housename"))
|
||||
street = _clean(tags.get("addr:street"))
|
||||
place = _clean(tags.get("addr:place"))
|
||||
postcode = _clean(tags.get("addr:postcode"))
|
||||
city = _clean(tags.get("addr:city") or tags.get("addr:municipality"))
|
||||
country = _clean(tags.get("addr:country"))
|
||||
unit = _clean(tags.get("addr:unit"))
|
||||
name = _clean(tags.get("name"))
|
||||
display_name = _display_name(housenumber=housenumber, street=street, place=place, postcode=postcode, city=city, name=name)
|
||||
if not display_name:
|
||||
return None
|
||||
search_text = _search_text(display_name, housenumber, street, place, postcode, city, country, unit, name)
|
||||
selected_tags = {key: tags[key] for key in sorted(ADDRESS_TAGS) if key in tags}
|
||||
min_lon, min_lat, max_lon, max_lat = bounds
|
||||
return {
|
||||
"dataset_id": dataset_id,
|
||||
"osm_type": osm_type,
|
||||
"osm_id": osm_id,
|
||||
"housenumber": housenumber,
|
||||
"street": street,
|
||||
"place": place,
|
||||
"postcode": postcode,
|
||||
"city": city,
|
||||
"country": country,
|
||||
"unit": unit,
|
||||
"name": name,
|
||||
"display_name": display_name,
|
||||
"search_text": search_text,
|
||||
"lon": lon,
|
||||
"lat": lat,
|
||||
"min_lon": min_lon,
|
||||
"min_lat": min_lat,
|
||||
"max_lon": max_lon,
|
||||
"max_lat": max_lat,
|
||||
"geometry_geojson": geometry_geojson,
|
||||
"tags_json": json.dumps(selected_tags, separators=(",", ":")) if selected_tags else None,
|
||||
}
|
||||
|
||||
|
||||
def _address_area_geometry_geojson(coords: list[tuple[float, float]], *, closed: bool | None = None) -> str | None:
|
||||
if closed is False:
|
||||
return None
|
||||
if len(coords) < 3:
|
||||
return None
|
||||
ring_coords = list(coords)
|
||||
first = ring_coords[0]
|
||||
last = ring_coords[-1]
|
||||
already_closed = abs(first[0] - last[0]) <= 1e-12 and abs(first[1] - last[1]) <= 1e-12
|
||||
if not already_closed:
|
||||
if closed is not True:
|
||||
return None
|
||||
ring_coords.append(first)
|
||||
if len(ring_coords) < 4:
|
||||
return None
|
||||
ring = [[float(lon), float(lat)] for lon, lat in ring_coords]
|
||||
if len({(round(lon, 12), round(lat, 12)) for lon, lat in ring_coords[:-1]}) < 3:
|
||||
return None
|
||||
return json.dumps({"type": "Polygon", "coordinates": [ring]}, separators=(",", ":"))
|
||||
|
||||
|
||||
def _way_is_closed(way) -> bool:
|
||||
try:
|
||||
nodes = way.nodes
|
||||
return len(nodes) >= 3 and nodes[0].ref == nodes[-1].ref
|
||||
except (AttributeError, IndexError, TypeError):
|
||||
return False
|
||||
|
||||
|
||||
def _display_name(
|
||||
*,
|
||||
housenumber: str | None,
|
||||
street: str | None,
|
||||
place: str | None,
|
||||
postcode: str | None,
|
||||
city: str | None,
|
||||
name: str | None,
|
||||
) -> str | None:
|
||||
road = street or place or name
|
||||
if road and housenumber:
|
||||
first = f"{road} {housenumber}"
|
||||
else:
|
||||
first = road or housenumber
|
||||
locality = " ".join(part for part in [postcode, city] if part)
|
||||
if first and locality:
|
||||
return f"{first}, {locality}"
|
||||
return first or locality
|
||||
|
||||
|
||||
def _search_text(*parts: str | None) -> str:
|
||||
return re.sub(r"\s+", " ", " ".join(part.casefold() for part in parts if part)).strip()
|
||||
|
||||
|
||||
def _clean(value: object) -> str | None:
|
||||
cleaned = re.sub(r"\s+", " ", str(value or "")).strip()
|
||||
return cleaned or None
|
||||
|
||||
|
||||
def _centroid(coords: list[tuple[float, float]]) -> tuple[float, float]:
|
||||
if len(coords) >= 4 and coords[0] == coords[-1]:
|
||||
area = 0.0
|
||||
cx = 0.0
|
||||
cy = 0.0
|
||||
for (x1, y1), (x2, y2) in zip(coords, coords[1:]):
|
||||
cross = x1 * y2 - x2 * y1
|
||||
area += cross
|
||||
cx += (x1 + x2) * cross
|
||||
cy += (y1 + y2) * cross
|
||||
if abs(area) > 1e-18:
|
||||
factor = 1 / (3 * area)
|
||||
return cx * factor, cy * factor
|
||||
return (
|
||||
math.fsum(coord[0] for coord in coords) / len(coords),
|
||||
math.fsum(coord[1] for coord in coords) / len(coords),
|
||||
)
|
||||
|
||||
|
||||
def _drop_address_indexes(session: Session) -> None:
|
||||
for name in [
|
||||
"ix_osm_addresses_dataset_city_street",
|
||||
"ix_osm_addresses_dataset_postcode",
|
||||
"ix_osm_addresses_bbox",
|
||||
"ix_osm_addresses_geom_gist",
|
||||
"ix_osm_addresses_area_geom_gist",
|
||||
"ix_osm_addresses_search_trgm",
|
||||
"ix_osm_addresses_display_trgm",
|
||||
"ix_osm_addresses_street_key_house",
|
||||
"ix_osm_addresses_street_key_trgm",
|
||||
]:
|
||||
session.execute(text(f"DROP INDEX IF EXISTS {name}"))
|
||||
|
||||
|
||||
def _create_address_indexes(session: Session) -> None:
|
||||
statements = [
|
||||
"CREATE INDEX IF NOT EXISTS ix_osm_addresses_dataset_city_street ON osm_addresses (dataset_id, city, street, housenumber)",
|
||||
"CREATE INDEX IF NOT EXISTS ix_osm_addresses_dataset_postcode ON osm_addresses (dataset_id, postcode)",
|
||||
"CREATE INDEX IF NOT EXISTS ix_osm_addresses_bbox ON osm_addresses (dataset_id, min_lon, max_lon, min_lat, max_lat)",
|
||||
]
|
||||
if settings.is_postgresql_database:
|
||||
statements.extend(
|
||||
[
|
||||
"CREATE INDEX IF NOT EXISTS ix_osm_addresses_geom_gist ON osm_addresses USING GIST (geom)",
|
||||
"CREATE INDEX IF NOT EXISTS ix_osm_addresses_area_geom_gist ON osm_addresses USING GIST (area_geom)",
|
||||
"CREATE INDEX IF NOT EXISTS ix_osm_addresses_search_trgm ON osm_addresses USING GIN (LOWER(COALESCE(search_text, '')) gin_trgm_ops)",
|
||||
"CREATE INDEX IF NOT EXISTS ix_osm_addresses_display_trgm ON osm_addresses USING GIN (LOWER(COALESCE(display_name, '')) gin_trgm_ops)",
|
||||
"CREATE INDEX IF NOT EXISTS ix_osm_addresses_street_key_house ON osm_addresses (dataset_id, REPLACE(LOWER(COALESCE(NULLIF(street, ''), NULLIF(place, ''), '')), 'ß', 'ss'), housenumber)",
|
||||
"CREATE INDEX IF NOT EXISTS ix_osm_addresses_street_key_trgm ON osm_addresses USING GIN (REPLACE(LOWER(COALESCE(NULLIF(street, ''), NULLIF(place, ''), '')), 'ß', 'ss') gin_trgm_ops)",
|
||||
]
|
||||
)
|
||||
for statement in statements:
|
||||
session.execute(text(statement))
|
||||
|
||||
|
||||
def _metadata(dataset: Dataset) -> dict[str, object]:
|
||||
try:
|
||||
value = json.loads(dataset.metadata_json or "{}")
|
||||
except json.JSONDecodeError:
|
||||
return {}
|
||||
return value if isinstance(value, dict) else {}
|
||||
|
||||
|
||||
def _emit(
|
||||
progress_callback: ProgressCallback | None,
|
||||
event_type: str,
|
||||
message: str,
|
||||
progress_current: int | None,
|
||||
progress_total: int | None,
|
||||
metadata: dict[str, object] | None = None,
|
||||
) -> None:
|
||||
if progress_callback is not None:
|
||||
progress_callback(event_type, message, progress_current, progress_total, metadata)
|
||||
100
app/pipeline/osm_diff.py
Normal file
100
app/pipeline/osm_diff.py
Normal file
@@ -0,0 +1,100 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.config import settings
|
||||
from app.models import Dataset, Source
|
||||
from app.pipeline.download import materialize_source
|
||||
from app.pipeline.osm_pbf import _raw_format
|
||||
from app.pipeline.osm_replication import fetch_replication_state
|
||||
from app.pipeline.utils import sha256_file
|
||||
|
||||
|
||||
def run_osm_diff_source(session: Session, source: Source) -> Dataset:
|
||||
"""Commit an OSM change file as a raw update artifact.
|
||||
|
||||
Applying the diff to an authoritative OSM base extract is a separate step;
|
||||
this importer deliberately records the file without treating it as a
|
||||
complete visual route layer.
|
||||
"""
|
||||
if _looks_like_update_directory(source.url):
|
||||
return _commit_update_directory_state(session, source)
|
||||
|
||||
raw_path = materialize_source(source)
|
||||
raw_hash = sha256_file(raw_path)
|
||||
existing = session.scalar(
|
||||
select(Dataset)
|
||||
.where(Dataset.source_id == source.id, Dataset.kind == "osm_diff_raw", Dataset.sha256 == raw_hash)
|
||||
.order_by(Dataset.id.desc())
|
||||
)
|
||||
if existing is not None:
|
||||
return existing
|
||||
|
||||
dataset = Dataset(
|
||||
source_id=source.id,
|
||||
kind="osm_diff_raw",
|
||||
local_path=str(raw_path),
|
||||
sha256=raw_hash,
|
||||
is_active=False,
|
||||
status="committed",
|
||||
metadata_json=json.dumps(
|
||||
{
|
||||
"stage": "raw_osm_diff",
|
||||
"raw_format": _raw_format(raw_path),
|
||||
"source_url": source.url,
|
||||
},
|
||||
indent=2,
|
||||
),
|
||||
)
|
||||
session.add(dataset)
|
||||
session.flush()
|
||||
return dataset
|
||||
|
||||
|
||||
def _commit_update_directory_state(session: Session, source: Source) -> Dataset:
|
||||
state = fetch_replication_state(source.url, timeout=settings.osm_diff_state_timeout_seconds)
|
||||
source_dir = settings.data_dir / "sources" / f"source_{source.id}"
|
||||
source_dir.mkdir(parents=True, exist_ok=True)
|
||||
state_path = source_dir / f"state_{state.sequence_number}.txt"
|
||||
state_path.write_text(
|
||||
"\n".join(f"{key}={value}" for key, value in sorted(state.raw.items())) + "\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
state_hash = sha256_file(state_path)
|
||||
existing = session.scalar(
|
||||
select(Dataset)
|
||||
.where(Dataset.source_id == source.id, Dataset.kind == "osm_diff_state", Dataset.sha256 == state_hash)
|
||||
.order_by(Dataset.id.desc())
|
||||
)
|
||||
if existing is not None:
|
||||
return existing
|
||||
dataset = Dataset(
|
||||
source_id=source.id,
|
||||
kind="osm_diff_state",
|
||||
local_path=str(state_path),
|
||||
sha256=state_hash,
|
||||
is_active=False,
|
||||
status="committed",
|
||||
metadata_json=json.dumps(
|
||||
{
|
||||
"stage": "osm_diff_state",
|
||||
"updates_url": source.url,
|
||||
"sequence_number": state.sequence_number,
|
||||
"timestamp": state.timestamp,
|
||||
"state": state.raw,
|
||||
},
|
||||
indent=2,
|
||||
),
|
||||
)
|
||||
session.add(dataset)
|
||||
session.flush()
|
||||
return dataset
|
||||
|
||||
|
||||
def _looks_like_update_directory(url: str) -> bool:
|
||||
lower_path = urlparse(url).path.lower()
|
||||
return lower_path.endswith("-updates") or lower_path.endswith("-updates/")
|
||||
248
app/pipeline/osm_geojson.py
Normal file
248
app/pipeline/osm_geojson.py
Normal file
@@ -0,0 +1,248 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.config import settings
|
||||
from app.models import Dataset, OsmFeature, Source
|
||||
from app.osm_classification import infer_osm_route_scope
|
||||
from app.osm_storage import (
|
||||
OSM_STORAGE_METADATA_KEY,
|
||||
OSM_STORAGE_MAIN,
|
||||
OSM_STORAGE_SIDECAR_FEATURES,
|
||||
create_osm_sidecar,
|
||||
dedupe_osm_feature_rows,
|
||||
effective_osm_feature_storage,
|
||||
)
|
||||
from app.pipeline.download import materialize_source
|
||||
from app.pipeline.utils import first_nonempty, geometry_json_and_bbox, norm_ref, norm_text, sha256_file
|
||||
from app.spatial import analyze_postgresql_tables, refresh_postgis_geometries
|
||||
|
||||
ROUTE_MODES = {
|
||||
"train",
|
||||
"railway",
|
||||
"light_rail",
|
||||
"subway",
|
||||
"tram",
|
||||
"bus",
|
||||
"trolleybus",
|
||||
"coach",
|
||||
"ferry",
|
||||
"monorail",
|
||||
"funicular",
|
||||
"aerialway",
|
||||
}
|
||||
|
||||
|
||||
def run_osm_geojson_source(session: Session, source: Source) -> Dataset:
|
||||
local_path = materialize_source(source)
|
||||
source_hash = sha256_file(local_path)
|
||||
existing = session.scalar(
|
||||
select(Dataset)
|
||||
.where(
|
||||
Dataset.source_id == source.id,
|
||||
Dataset.kind == "osm_geojson",
|
||||
Dataset.sha256 == source_hash,
|
||||
Dataset.is_active.is_(True),
|
||||
Dataset.status == "imported",
|
||||
)
|
||||
.order_by(Dataset.id.desc())
|
||||
)
|
||||
if existing is not None:
|
||||
return existing
|
||||
return import_osm_geojson(session=session, source=source, path=local_path, source_hash=source_hash)
|
||||
|
||||
|
||||
def import_osm_geojson(
|
||||
session: Session,
|
||||
source: Source,
|
||||
path: Path,
|
||||
source_hash: str | None = None,
|
||||
*,
|
||||
storage_mode: str | None = None,
|
||||
) -> Dataset:
|
||||
for dataset in source.datasets:
|
||||
dataset.is_active = False
|
||||
|
||||
dataset = Dataset(
|
||||
source_id=source.id,
|
||||
kind="osm_geojson",
|
||||
local_path=str(path),
|
||||
sha256=source_hash or sha256_file(path),
|
||||
is_active=True,
|
||||
status="importing",
|
||||
)
|
||||
session.add(dataset)
|
||||
session.flush()
|
||||
|
||||
source_hash = source_hash or sha256_file(path)
|
||||
dataset.metadata_json = json.dumps(
|
||||
prepare_osm_geojson_storage(
|
||||
session=session,
|
||||
dataset=dataset,
|
||||
path=path,
|
||||
source_hash=source_hash,
|
||||
storage_mode=storage_mode,
|
||||
),
|
||||
indent=2,
|
||||
)
|
||||
|
||||
dataset.status = "imported"
|
||||
source.status = "ok"
|
||||
source.last_error = None
|
||||
session.flush()
|
||||
return dataset
|
||||
|
||||
|
||||
def prepare_osm_geojson_storage(
|
||||
*,
|
||||
session: Session,
|
||||
dataset: Dataset,
|
||||
path: Path,
|
||||
source_hash: str | None = None,
|
||||
storage_mode: str | None = None,
|
||||
) -> dict[str, object]:
|
||||
data = json.loads(path.read_text(encoding="utf-8"))
|
||||
features = _as_features(data)
|
||||
feature_rows = [_feature_row(dataset.id, idx, feature) for idx, feature in enumerate(features)]
|
||||
storage = effective_osm_feature_storage(storage_mode)
|
||||
if storage not in {OSM_STORAGE_MAIN, OSM_STORAGE_SIDECAR_FEATURES}:
|
||||
raise ValueError(f"Unsupported OSM feature storage mode: {storage}")
|
||||
if storage == OSM_STORAGE_SIDECAR_FEATURES:
|
||||
return {
|
||||
"features": len(feature_rows),
|
||||
OSM_STORAGE_METADATA_KEY: create_osm_sidecar(dataset, feature_rows, source_hash=source_hash or dataset.sha256),
|
||||
}
|
||||
_insert_main_features(session, feature_rows)
|
||||
session.flush()
|
||||
refresh_postgis_geometries(session, dataset_id=dataset.id, tables=["osm_features"])
|
||||
analyze_postgresql_tables(session, ["osm_features"])
|
||||
return {"features": len(feature_rows), OSM_STORAGE_METADATA_KEY: {"mode": OSM_STORAGE_MAIN}}
|
||||
|
||||
|
||||
def _insert_main_features(session: Session, feature_rows: list[dict[str, object]]) -> None:
|
||||
objects: list[OsmFeature] = []
|
||||
deduped_rows, _duplicate_count = dedupe_osm_feature_rows(feature_rows)
|
||||
for row in deduped_rows:
|
||||
objects.append(
|
||||
OsmFeature(
|
||||
dataset_id=row["dataset_id"],
|
||||
osm_type=row["osm_type"],
|
||||
osm_id=row["osm_id"],
|
||||
kind=row["kind"],
|
||||
mode=row["mode"],
|
||||
route_scope=row["route_scope"],
|
||||
name=row["name"],
|
||||
ref=row["ref"],
|
||||
operator=row["operator"],
|
||||
network=row["network"],
|
||||
geometry_geojson=row["geometry_geojson"],
|
||||
min_lon=row["min_lon"],
|
||||
min_lat=row["min_lat"],
|
||||
max_lon=row["max_lon"],
|
||||
max_lat=row["max_lat"],
|
||||
tags_json=row["tags_json"],
|
||||
route_key=row["route_key"],
|
||||
operator_key=row["operator_key"],
|
||||
)
|
||||
)
|
||||
if len(objects) >= 5000:
|
||||
session.bulk_save_objects(objects)
|
||||
objects.clear()
|
||||
if objects:
|
||||
session.bulk_save_objects(objects)
|
||||
|
||||
|
||||
def _feature_row(dataset_id: int, idx: int, feature: dict[str, Any]) -> dict[str, object]:
|
||||
props = feature.get("properties") or {}
|
||||
geometry = feature.get("geometry")
|
||||
geometry_text, bbox = geometry_json_and_bbox(geometry)
|
||||
osm_type = str(first_nonempty(props.get("osm_type"), props.get("@type"), props.get("type"), "feature"))
|
||||
osm_id = str(first_nonempty(props.get("osm_id"), props.get("@id"), props.get("id"), f"feature_{idx}"))
|
||||
mode = _infer_mode(props)
|
||||
kind = _infer_kind(props, mode)
|
||||
name = first_nonempty(props.get("name"), props.get("official_name")) or None
|
||||
ref = first_nonempty(props.get("ref"), props.get("route_ref"), props.get("line")) or None
|
||||
operator = first_nonempty(props.get("operator"), props.get("agency"), props.get("brand")) or None
|
||||
network = first_nonempty(props.get("network"), props.get("network:short")) or None
|
||||
route_scope = infer_osm_route_scope(mode=mode, ref=ref, name=name, network=network, tags=props)
|
||||
route_key = norm_ref(ref) or norm_text(name) or norm_ref(osm_id)
|
||||
operator_key = norm_text(operator or network or "")
|
||||
return {
|
||||
"dataset_id": dataset_id,
|
||||
"osm_type": osm_type,
|
||||
"osm_id": osm_id,
|
||||
"kind": kind,
|
||||
"mode": mode,
|
||||
"route_scope": route_scope,
|
||||
"name": name,
|
||||
"ref": ref,
|
||||
"operator": operator,
|
||||
"network": network,
|
||||
"geometry_geojson": geometry_text,
|
||||
"min_lon": bbox[0],
|
||||
"min_lat": bbox[1],
|
||||
"max_lon": bbox[2],
|
||||
"max_lat": bbox[3],
|
||||
"tags_json": json.dumps(props, separators=(",", ":")),
|
||||
"route_key": route_key,
|
||||
"operator_key": operator_key,
|
||||
}
|
||||
|
||||
|
||||
def _as_features(data: Any) -> list[dict[str, Any]]:
|
||||
if isinstance(data, dict) and data.get("type") == "FeatureCollection":
|
||||
return [f for f in data.get("features", []) if isinstance(f, dict)]
|
||||
if isinstance(data, dict) and data.get("type") == "Feature":
|
||||
return [data]
|
||||
if isinstance(data, list):
|
||||
return [f for f in data if isinstance(f, dict)]
|
||||
raise ValueError("OSM source must be GeoJSON FeatureCollection, Feature, or list of Features")
|
||||
|
||||
|
||||
def _infer_mode(props: dict[str, Any]) -> str | None:
|
||||
for key in ("mode", "route", "route_master"):
|
||||
value = str(props.get(key) or "").strip()
|
||||
if value in ROUTE_MODES:
|
||||
return "train" if value == "railway" else value
|
||||
railway = str(props.get("railway") or "").strip()
|
||||
if railway in {"station", "halt"}:
|
||||
return "train"
|
||||
if railway == "tram_stop":
|
||||
return "tram"
|
||||
if railway == "subway_entrance":
|
||||
return "subway"
|
||||
if str(props.get("highway") or "") == "bus_stop" or str(props.get("amenity") or "") == "bus_station":
|
||||
return "bus"
|
||||
if str(props.get("amenity") or "") == "ferry_terminal":
|
||||
return "ferry"
|
||||
if str(props.get("aerialway") or "") == "station":
|
||||
return "aerialway"
|
||||
return None
|
||||
|
||||
|
||||
def _infer_kind(props: dict[str, Any], mode: str | None) -> str:
|
||||
explicit_kind = str(props.get("kind") or "").strip()
|
||||
if explicit_kind in {"route", "stop", "station", "terminal", "infra", "feature"}:
|
||||
return explicit_kind
|
||||
if str(props.get("type") or "") in {"route", "route_master"} or str(props.get("route") or "") in ROUTE_MODES:
|
||||
return "route"
|
||||
if str(props.get("amenity") or "") == "ferry_terminal":
|
||||
return "terminal"
|
||||
if str(props.get("amenity") or "") == "bus_station":
|
||||
return "terminal"
|
||||
if str(props.get("railway") or "") in {"station", "halt"}:
|
||||
return "station"
|
||||
if str(props.get("aerialway") or "") == "station":
|
||||
return "station"
|
||||
if str(props.get("public_transport") or "") in {"platform", "stop_position", "station"}:
|
||||
return "stop"
|
||||
if str(props.get("highway") or "") == "bus_stop":
|
||||
return "stop"
|
||||
if mode:
|
||||
return "infra"
|
||||
return "feature"
|
||||
456
app/pipeline/osm_labeling.py
Normal file
456
app/pipeline/osm_labeling.py
Normal file
@@ -0,0 +1,456 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
import json
|
||||
from pathlib import Path
|
||||
import sqlite3
|
||||
from typing import Callable
|
||||
|
||||
from sqlalchemy import func, select, text
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models import Dataset, OsmFeature
|
||||
from app.osm_classification import OSM_ROUTE_SCOPE_CLASSIFIER_VERSION, infer_osm_route_scope_from_tags
|
||||
from app.osm_storage import (
|
||||
dataset_metadata,
|
||||
drop_osm_sidecar_route_scope_indexes,
|
||||
ensure_osm_sidecar_schema,
|
||||
features_are_sidecar,
|
||||
rebuild_osm_sidecar_indexes,
|
||||
sidecar_path,
|
||||
writable_sidecar_connection,
|
||||
)
|
||||
from app.pipeline.state import (
|
||||
STAGE_BUILD_INDEXES,
|
||||
STAGE_LABEL_FEATURES,
|
||||
dependency_hash,
|
||||
finish_pipeline_run,
|
||||
latest_completed_run,
|
||||
start_pipeline_run,
|
||||
)
|
||||
|
||||
|
||||
OSM_LABEL_FEATURES_VERSION = OSM_ROUTE_SCOPE_CLASSIFIER_VERSION
|
||||
MAIN_ROUTE_SCOPE_INDEX = "ix_osm_features_scope_bbox"
|
||||
MAIN_INDEX_REBUILD_THRESHOLD = 10_000
|
||||
SIDECAR_INDEX_REBUILD_THRESHOLD = 10_000
|
||||
ProgressCallback = Callable[[str, str, int | None, int | None, dict[str, object] | None], None]
|
||||
|
||||
|
||||
def relabel_osm_features(
|
||||
session: Session,
|
||||
*,
|
||||
dataset_id: int | None = None,
|
||||
chunk_size: int = 5000,
|
||||
force: bool = False,
|
||||
rebuild_indexes: bool = True,
|
||||
progress_callback: ProgressCallback | None = None,
|
||||
job_id: int | None = None,
|
||||
) -> dict[str, object]:
|
||||
datasets = _target_datasets(session, dataset_id)
|
||||
result: dict[str, object] = {
|
||||
"version": OSM_LABEL_FEATURES_VERSION,
|
||||
"datasets": len(datasets),
|
||||
"processed": 0,
|
||||
"changed": 0,
|
||||
"skipped": 0,
|
||||
"missing": 0,
|
||||
"index_rebuilds": 0,
|
||||
"dataset_results": [],
|
||||
}
|
||||
_emit_progress(
|
||||
progress_callback,
|
||||
"osm_labeling_started",
|
||||
f"Relabeling {len(datasets)} OSM dataset(s).",
|
||||
0,
|
||||
len(datasets),
|
||||
{"dataset_id": dataset_id, "force": force, "version": OSM_LABEL_FEATURES_VERSION},
|
||||
)
|
||||
for index, dataset in enumerate(datasets, start=1):
|
||||
dataset_result = relabel_osm_dataset(
|
||||
session,
|
||||
dataset,
|
||||
chunk_size=chunk_size,
|
||||
force=force,
|
||||
rebuild_indexes=rebuild_indexes,
|
||||
progress_callback=progress_callback,
|
||||
job_id=job_id,
|
||||
)
|
||||
result["processed"] = int(result["processed"]) + int(dataset_result.get("processed", 0) or 0)
|
||||
result["changed"] = int(result["changed"]) + int(dataset_result.get("changed", 0) or 0)
|
||||
result["skipped"] = int(result["skipped"]) + (1 if dataset_result.get("status") == "skipped" else 0)
|
||||
result["missing"] = int(result["missing"]) + (1 if dataset_result.get("status") == "missing_sidecar" else 0)
|
||||
result["index_rebuilds"] = int(result["index_rebuilds"]) + int(dataset_result.get("index_rebuilds", 0) or 0)
|
||||
result["dataset_results"].append(dataset_result) # type: ignore[union-attr]
|
||||
_emit_progress(
|
||||
progress_callback,
|
||||
"osm_labeling_dataset_completed",
|
||||
f"Relabeled {index}/{len(datasets)} OSM dataset(s).",
|
||||
index,
|
||||
len(datasets),
|
||||
dataset_result,
|
||||
)
|
||||
_emit_progress(progress_callback, "osm_labeling_completed", "OSM feature relabeling completed.", len(datasets), len(datasets), result)
|
||||
return result
|
||||
|
||||
|
||||
def relabel_osm_dataset(
|
||||
session: Session,
|
||||
dataset: Dataset,
|
||||
*,
|
||||
chunk_size: int = 5000,
|
||||
force: bool = False,
|
||||
rebuild_indexes: bool = True,
|
||||
progress_callback: ProgressCallback | None = None,
|
||||
job_id: int | None = None,
|
||||
) -> dict[str, object]:
|
||||
dependency = _label_dependency(dataset)
|
||||
dependency_hash_value = dependency_hash(dependency)
|
||||
if not force and _dataset_label_is_current(session, dataset, dependency_hash_value):
|
||||
return {
|
||||
"dataset_id": dataset.id,
|
||||
"source_id": dataset.source_id,
|
||||
"status": "skipped",
|
||||
"reason": "label_features dependency is current",
|
||||
"dependency_hash": dependency_hash_value,
|
||||
"version": OSM_LABEL_FEATURES_VERSION,
|
||||
"processed": 0,
|
||||
"changed": 0,
|
||||
"index_rebuilds": 0,
|
||||
}
|
||||
|
||||
run = start_pipeline_run(
|
||||
session,
|
||||
stage=STAGE_LABEL_FEATURES,
|
||||
version=OSM_LABEL_FEATURES_VERSION,
|
||||
dependency_hash_value=dependency_hash_value,
|
||||
source_id=dataset.source_id,
|
||||
dataset_id=dataset.id,
|
||||
job_id=job_id,
|
||||
inputs=dependency,
|
||||
)
|
||||
session.commit()
|
||||
try:
|
||||
if features_are_sidecar(dataset):
|
||||
counts = _relabel_sidecar_dataset(dataset, chunk_size=chunk_size, rebuild_indexes=rebuild_indexes, progress_callback=progress_callback)
|
||||
else:
|
||||
counts = _relabel_main_dataset(session, dataset, chunk_size=chunk_size, rebuild_indexes=rebuild_indexes, progress_callback=progress_callback)
|
||||
output = {
|
||||
"dataset_id": dataset.id,
|
||||
"source_id": dataset.source_id,
|
||||
"status": "completed",
|
||||
"dependency_hash": dependency_hash_value,
|
||||
"version": OSM_LABEL_FEATURES_VERSION,
|
||||
**counts,
|
||||
}
|
||||
_stamp_dataset_metadata(session, dataset, dependency_hash_value, output)
|
||||
finish_pipeline_run(session, run, outputs=output)
|
||||
session.commit()
|
||||
return output
|
||||
except FileNotFoundError as exc:
|
||||
output = {
|
||||
"dataset_id": dataset.id,
|
||||
"source_id": dataset.source_id,
|
||||
"status": "missing_sidecar",
|
||||
"dependency_hash": dependency_hash_value,
|
||||
"version": OSM_LABEL_FEATURES_VERSION,
|
||||
"processed": 0,
|
||||
"changed": 0,
|
||||
"index_rebuilds": 0,
|
||||
"error": str(exc),
|
||||
}
|
||||
finish_pipeline_run(session, run, status="failed", outputs=output, error=str(exc))
|
||||
session.commit()
|
||||
return output
|
||||
except Exception as exc:
|
||||
finish_pipeline_run(session, run, status="failed", error=str(exc))
|
||||
session.commit()
|
||||
raise
|
||||
|
||||
|
||||
def _target_datasets(session: Session, dataset_id: int | None) -> list[Dataset]:
|
||||
stmt = select(Dataset).where(Dataset.kind == "osm_geojson", Dataset.status == "imported")
|
||||
if dataset_id is None:
|
||||
stmt = stmt.where(Dataset.is_active.is_(True))
|
||||
else:
|
||||
stmt = stmt.where(Dataset.id == dataset_id)
|
||||
return session.scalars(stmt.order_by(Dataset.source_id, Dataset.id)).all()
|
||||
|
||||
|
||||
def _dataset_label_is_current(session: Session, dataset: Dataset, dependency_hash_value: str) -> bool:
|
||||
metadata = dataset_metadata(dataset)
|
||||
label_info = metadata.get("label_features")
|
||||
metadata_current = (
|
||||
isinstance(label_info, dict)
|
||||
and label_info.get("version") == OSM_LABEL_FEATURES_VERSION
|
||||
and label_info.get("dependency_hash") == dependency_hash_value
|
||||
)
|
||||
if not metadata_current:
|
||||
return False
|
||||
return (
|
||||
latest_completed_run(
|
||||
session,
|
||||
stage=STAGE_LABEL_FEATURES,
|
||||
version=OSM_LABEL_FEATURES_VERSION,
|
||||
dependency_hash_value=dependency_hash_value,
|
||||
source_id=dataset.source_id,
|
||||
dataset_id=dataset.id,
|
||||
)
|
||||
is not None
|
||||
)
|
||||
|
||||
|
||||
def _relabel_sidecar_dataset(
|
||||
dataset: Dataset,
|
||||
*,
|
||||
chunk_size: int,
|
||||
rebuild_indexes: bool,
|
||||
progress_callback: ProgressCallback | None,
|
||||
) -> dict[str, int | str]:
|
||||
path = sidecar_path(dataset)
|
||||
if path is None or not path.exists():
|
||||
raise FileNotFoundError(f"OSM sidecar does not exist: {path}")
|
||||
with writable_sidecar_connection(dataset) as connection:
|
||||
ensure_osm_sidecar_schema(connection)
|
||||
total = int(connection.execute("SELECT COUNT(*) FROM osm_features").fetchone()[0] or 0)
|
||||
should_rebuild_index = rebuild_indexes and total >= SIDECAR_INDEX_REBUILD_THRESHOLD
|
||||
if should_rebuild_index:
|
||||
drop_osm_sidecar_route_scope_indexes(connection)
|
||||
connection.commit()
|
||||
processed = 0
|
||||
changed = 0
|
||||
last_id = 0
|
||||
try:
|
||||
while True:
|
||||
rows = connection.execute(
|
||||
"""
|
||||
SELECT id, mode, ref, name, network, tags_json, route_scope
|
||||
FROM osm_features
|
||||
WHERE id > ?
|
||||
ORDER BY id
|
||||
LIMIT ?
|
||||
""",
|
||||
(last_id, max(1, int(chunk_size))),
|
||||
).fetchall()
|
||||
if not rows:
|
||||
break
|
||||
updates: list[tuple[str | None, int]] = []
|
||||
for row in rows:
|
||||
last_id = int(row["id"])
|
||||
new_scope = _classified_scope(row["mode"], row["ref"], row["name"], row["network"], row["tags_json"])
|
||||
if _normalize_scope(row["route_scope"]) != new_scope:
|
||||
updates.append((new_scope, last_id))
|
||||
if updates:
|
||||
connection.executemany("UPDATE osm_features SET route_scope = ? WHERE id = ?", updates)
|
||||
processed += len(rows)
|
||||
changed += len(updates)
|
||||
connection.commit()
|
||||
_emit_progress(
|
||||
progress_callback,
|
||||
"osm_labeling_batch",
|
||||
f"Relabeled {processed}/{total} OSM sidecar features.",
|
||||
processed,
|
||||
total,
|
||||
{"dataset_id": dataset.id, "changed": changed, "storage": "sidecar"},
|
||||
)
|
||||
finally:
|
||||
index_rebuilds = 0
|
||||
if should_rebuild_index:
|
||||
rebuild_osm_sidecar_indexes(connection)
|
||||
connection.commit()
|
||||
index_rebuilds = 1
|
||||
_record_sidecar_index_build(connection, dataset, path)
|
||||
_record_sidecar_label(connection, dataset, processed=processed, changed=changed)
|
||||
connection.commit()
|
||||
return {"storage": "sidecar", "processed": processed, "changed": changed, "index_rebuilds": index_rebuilds}
|
||||
|
||||
|
||||
def _relabel_main_dataset(
|
||||
session: Session,
|
||||
dataset: Dataset,
|
||||
*,
|
||||
chunk_size: int,
|
||||
rebuild_indexes: bool,
|
||||
progress_callback: ProgressCallback | None,
|
||||
) -> dict[str, int | str]:
|
||||
total = int(session.scalar(select(func.count()).select_from(OsmFeature).where(OsmFeature.dataset_id == dataset.id)) or 0)
|
||||
should_rebuild_index = rebuild_indexes and total >= MAIN_INDEX_REBUILD_THRESHOLD
|
||||
index_rebuilds = 0
|
||||
if should_rebuild_index:
|
||||
session.execute(text(f"DROP INDEX IF EXISTS {MAIN_ROUTE_SCOPE_INDEX}"))
|
||||
session.commit()
|
||||
processed = 0
|
||||
changed = 0
|
||||
last_id = 0
|
||||
try:
|
||||
while True:
|
||||
rows = session.scalars(
|
||||
select(OsmFeature)
|
||||
.where(OsmFeature.dataset_id == dataset.id, OsmFeature.id > last_id)
|
||||
.order_by(OsmFeature.id)
|
||||
.limit(max(1, int(chunk_size)))
|
||||
).all()
|
||||
if not rows:
|
||||
break
|
||||
updates: list[dict[str, object]] = []
|
||||
for feature in rows:
|
||||
last_id = int(feature.id)
|
||||
new_scope = _classified_scope(feature.mode, feature.ref, feature.name, feature.network, feature.tags_json)
|
||||
if _normalize_scope(feature.route_scope) != new_scope:
|
||||
updates.append({"id": feature.id, "route_scope": new_scope})
|
||||
if updates:
|
||||
session.bulk_update_mappings(OsmFeature, updates)
|
||||
processed += len(rows)
|
||||
changed += len(updates)
|
||||
session.commit()
|
||||
_emit_progress(
|
||||
progress_callback,
|
||||
"osm_labeling_batch",
|
||||
f"Relabeled {processed}/{total} main-table OSM features.",
|
||||
processed,
|
||||
total,
|
||||
{"dataset_id": dataset.id, "changed": changed, "storage": "main"},
|
||||
)
|
||||
finally:
|
||||
if should_rebuild_index:
|
||||
session.execute(
|
||||
text(
|
||||
"CREATE INDEX IF NOT EXISTS ix_osm_features_scope_bbox "
|
||||
"ON osm_features (dataset_id, kind, mode, route_scope, min_lon, max_lon, min_lat, max_lat)"
|
||||
)
|
||||
)
|
||||
session.commit()
|
||||
index_rebuilds = 1
|
||||
_record_main_index_build(session, dataset)
|
||||
return {"storage": "main", "processed": processed, "changed": changed, "index_rebuilds": index_rebuilds}
|
||||
|
||||
|
||||
def _classified_scope(mode: object, ref: object, name: object, network: object, tags_json: object) -> str | None:
|
||||
return _normalize_scope(
|
||||
infer_osm_route_scope_from_tags(
|
||||
None if mode is None else str(mode),
|
||||
None if ref is None else str(ref),
|
||||
None if name is None else str(name),
|
||||
None if network is None else str(network),
|
||||
None if tags_json is None else str(tags_json),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _normalize_scope(value: object) -> str | None:
|
||||
text_value = str(value or "").strip()
|
||||
return text_value or None
|
||||
|
||||
|
||||
def _label_dependency(dataset: Dataset) -> dict[str, object]:
|
||||
metadata = dataset_metadata(dataset)
|
||||
storage = metadata.get("osm_storage") if isinstance(metadata, dict) else None
|
||||
path = sidecar_path(dataset)
|
||||
path_fingerprint: dict[str, object] | None = None
|
||||
if path is not None:
|
||||
resolved = Path(path)
|
||||
if resolved.exists():
|
||||
path_fingerprint = {"path": str(resolved), "exists": True}
|
||||
else:
|
||||
path_fingerprint = {"path": str(resolved), "missing": True}
|
||||
return {
|
||||
"dataset_id": dataset.id,
|
||||
"source_id": dataset.source_id,
|
||||
"kind": dataset.kind,
|
||||
"dataset_sha256": dataset.sha256,
|
||||
"storage": storage,
|
||||
"sidecar": path_fingerprint,
|
||||
"classifier_version": OSM_LABEL_FEATURES_VERSION,
|
||||
}
|
||||
|
||||
|
||||
def _stamp_dataset_metadata(session: Session, dataset: Dataset, dependency_hash_value: str, output: dict[str, object]) -> None:
|
||||
refreshed = session.get(Dataset, dataset.id)
|
||||
if refreshed is None:
|
||||
return
|
||||
metadata = dataset_metadata(refreshed)
|
||||
metadata["label_features"] = {
|
||||
"stage": STAGE_LABEL_FEATURES,
|
||||
"version": OSM_LABEL_FEATURES_VERSION,
|
||||
"dependency_hash": dependency_hash_value,
|
||||
"labeled_at": datetime.now(timezone.utc).isoformat(),
|
||||
"processed": output.get("processed", 0),
|
||||
"changed": output.get("changed", 0),
|
||||
"storage": output.get("storage"),
|
||||
}
|
||||
refreshed.metadata_json = json.dumps(metadata, indent=2)
|
||||
session.flush()
|
||||
|
||||
|
||||
def _record_sidecar_label(connection: sqlite3.Connection, dataset: Dataset, *, processed: int, changed: int) -> None:
|
||||
connection.execute("CREATE TABLE IF NOT EXISTS pipeline_metadata (key TEXT PRIMARY KEY, value TEXT NOT NULL)")
|
||||
connection.execute(
|
||||
"INSERT OR REPLACE INTO pipeline_metadata (key, value) VALUES (?, ?)",
|
||||
(
|
||||
"label_features",
|
||||
json.dumps(
|
||||
{
|
||||
"stage": STAGE_LABEL_FEATURES,
|
||||
"version": OSM_LABEL_FEATURES_VERSION,
|
||||
"dataset_id": dataset.id,
|
||||
"processed": processed,
|
||||
"changed": changed,
|
||||
"updated_at": datetime.now(timezone.utc).isoformat(),
|
||||
},
|
||||
sort_keys=True,
|
||||
separators=(",", ":"),
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _record_sidecar_index_build(connection: sqlite3.Connection, dataset: Dataset, path: Path) -> None:
|
||||
connection.execute("CREATE TABLE IF NOT EXISTS pipeline_metadata (key TEXT PRIMARY KEY, value TEXT NOT NULL)")
|
||||
connection.execute(
|
||||
"INSERT OR REPLACE INTO pipeline_metadata (key, value) VALUES (?, ?)",
|
||||
(
|
||||
"build_indexes:route_scope",
|
||||
json.dumps(
|
||||
{
|
||||
"stage": STAGE_BUILD_INDEXES,
|
||||
"version": "osm_sidecar_indexes_v1",
|
||||
"dataset_id": dataset.id,
|
||||
"path": str(path),
|
||||
"updated_at": datetime.now(timezone.utc).isoformat(),
|
||||
},
|
||||
sort_keys=True,
|
||||
separators=(",", ":"),
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _record_main_index_build(session: Session, dataset: Dataset) -> None:
|
||||
dependency = {
|
||||
"dataset_id": dataset.id,
|
||||
"index": MAIN_ROUTE_SCOPE_INDEX,
|
||||
"version": "osm_main_indexes_v1",
|
||||
}
|
||||
run = start_pipeline_run(
|
||||
session,
|
||||
stage=STAGE_BUILD_INDEXES,
|
||||
version="osm_main_indexes_v1",
|
||||
dependency_hash_value=dependency_hash(dependency),
|
||||
source_id=dataset.source_id,
|
||||
dataset_id=dataset.id,
|
||||
inputs=dependency,
|
||||
)
|
||||
finish_pipeline_run(session, run, outputs={"index": MAIN_ROUTE_SCOPE_INDEX})
|
||||
session.commit()
|
||||
|
||||
|
||||
def _emit_progress(
|
||||
callback: ProgressCallback | None,
|
||||
event_type: str,
|
||||
message: str,
|
||||
current: int | None,
|
||||
total: int | None,
|
||||
metadata: dict[str, object] | None,
|
||||
) -> None:
|
||||
if callback is not None:
|
||||
callback(event_type, message, current, total, metadata)
|
||||
1581
app/pipeline/osm_pbf.py
Normal file
1581
app/pipeline/osm_pbf.py
Normal file
File diff suppressed because it is too large
Load Diff
105
app/pipeline/osm_replication.py
Normal file
105
app/pipeline/osm_replication.py
Normal file
@@ -0,0 +1,105 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
import subprocess
|
||||
from urllib.parse import urljoin, urlparse
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ReplicationState:
|
||||
sequence_number: int
|
||||
timestamp: str | None
|
||||
raw: dict[str, str]
|
||||
|
||||
|
||||
def fetch_replication_state(updates_url: str, *, timeout: float = 30) -> ReplicationState:
|
||||
state_url = _state_url(updates_url)
|
||||
response = requests.get(state_url, timeout=timeout)
|
||||
response.raise_for_status()
|
||||
return parse_replication_state_text(response.text)
|
||||
|
||||
|
||||
def parse_replication_state_text(text: str) -> ReplicationState:
|
||||
values: dict[str, str] = {}
|
||||
for line in text.splitlines():
|
||||
line = line.strip()
|
||||
if not line or line.startswith("#") or "=" not in line:
|
||||
continue
|
||||
key, value = line.split("=", 1)
|
||||
values[key.strip()] = _unescape_state_value(value.strip())
|
||||
sequence = values.get("sequenceNumber")
|
||||
if sequence is None:
|
||||
raise ValueError("replication state is missing sequenceNumber")
|
||||
try:
|
||||
sequence_number = int(sequence)
|
||||
except ValueError as exc:
|
||||
raise ValueError(f"invalid replication sequenceNumber: {sequence}") from exc
|
||||
return ReplicationState(
|
||||
sequence_number=sequence_number,
|
||||
timestamp=values.get("timestamp"),
|
||||
raw=values,
|
||||
)
|
||||
|
||||
|
||||
def diff_url_for_sequence(updates_url: str, sequence_number: int) -> str:
|
||||
padded = str(sequence_number).zfill(max(9, ((len(str(sequence_number)) + 2) // 3) * 3))
|
||||
parts = [padded[index : index + 3] for index in range(0, len(padded), 3)]
|
||||
return urljoin(_directory_url(updates_url), "/".join(parts) + ".osc.gz")
|
||||
|
||||
|
||||
def download_diff(updates_url: str, sequence_number: int, output_dir: Path, *, timeout: float = 120) -> Path:
|
||||
url = diff_url_for_sequence(updates_url, sequence_number)
|
||||
parsed_path = Path(urlparse(url).path)
|
||||
output_path = output_dir / parsed_path.name
|
||||
nested = output_dir / parsed_path.parent.name / output_path.name
|
||||
if output_path.exists():
|
||||
return output_path
|
||||
if nested.exists():
|
||||
return nested
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
temp_path = output_dir / f"{sequence_number}.download"
|
||||
with requests.get(url, stream=True, timeout=timeout) as response:
|
||||
response.raise_for_status()
|
||||
with temp_path.open("wb") as handle:
|
||||
for chunk in response.iter_content(chunk_size=1024 * 1024):
|
||||
if chunk:
|
||||
handle.write(chunk)
|
||||
temp_path.replace(output_path)
|
||||
return output_path
|
||||
|
||||
|
||||
def apply_osm_changes(base_path: Path, diff_paths: list[Path], output_path: Path, host_tool_path: Path) -> subprocess.CompletedProcess[str]:
|
||||
if not diff_paths:
|
||||
raise ValueError("no OSM change files supplied")
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
command = [
|
||||
str(host_tool_path),
|
||||
"osmium",
|
||||
"apply-changes",
|
||||
"--output",
|
||||
str(output_path),
|
||||
"--overwrite",
|
||||
str(base_path),
|
||||
*[str(path) for path in diff_paths],
|
||||
]
|
||||
return subprocess.run(command, check=True, capture_output=True, text=True)
|
||||
|
||||
|
||||
def _state_url(updates_url: str) -> str:
|
||||
return urljoin(_directory_url(updates_url), "state.txt")
|
||||
|
||||
|
||||
def _directory_url(url: str) -> str:
|
||||
return url if url.endswith("/") else f"{url}/"
|
||||
|
||||
|
||||
def _unescape_state_value(value: str) -> str:
|
||||
return (
|
||||
value.replace("\\:", ":")
|
||||
.replace("\\=", "=")
|
||||
.replace("\\ ", " ")
|
||||
.replace("\\\\", "\\")
|
||||
)
|
||||
1903
app/pipeline/route_layer.py
Normal file
1903
app/pipeline/route_layer.py
Normal file
File diff suppressed because it is too large
Load Diff
473
app/pipeline/routing_layer.py
Normal file
473
app/pipeline/routing_layer.py
Normal file
@@ -0,0 +1,473 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
import osmium
|
||||
from sqlalchemy import delete, func, select, text
|
||||
from sqlalchemy.dialects.postgresql import insert as postgresql_insert
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.config import settings
|
||||
from app.models import Dataset, RoutingEdge, RoutingNode
|
||||
from app.spatial import analyze_postgresql_tables, refresh_postgis_geometries
|
||||
|
||||
|
||||
ProgressCallback = Callable[[str, str, int | None, int | None, dict[str, object] | None], None]
|
||||
ROUTING_LAYER_VERSION = "routing_layer_v2_osm_highway_segments_service_tags"
|
||||
|
||||
DRIVE_HIGHWAYS = {
|
||||
"motorway",
|
||||
"motorway_link",
|
||||
"trunk",
|
||||
"trunk_link",
|
||||
"primary",
|
||||
"primary_link",
|
||||
"secondary",
|
||||
"secondary_link",
|
||||
"tertiary",
|
||||
"tertiary_link",
|
||||
"unclassified",
|
||||
"residential",
|
||||
"living_street",
|
||||
"service",
|
||||
"road",
|
||||
"track",
|
||||
}
|
||||
WALK_HIGHWAYS = {
|
||||
"pedestrian",
|
||||
"footway",
|
||||
"path",
|
||||
"steps",
|
||||
"cycleway",
|
||||
"bridleway",
|
||||
"living_street",
|
||||
"residential",
|
||||
"service",
|
||||
"track",
|
||||
"unclassified",
|
||||
"tertiary",
|
||||
"tertiary_link",
|
||||
"secondary",
|
||||
"secondary_link",
|
||||
"primary",
|
||||
"primary_link",
|
||||
"road",
|
||||
}
|
||||
EXCLUDED_HIGHWAYS = {"construction", "proposed", "abandoned", "platform", "raceway"}
|
||||
NO_VALUES = {"no", "private", "agricultural", "forestry", "delivery", "customers"}
|
||||
YES_VALUES = {"yes", "designated", "permissive", "destination"}
|
||||
ONEWAY_FORWARD = {"yes", "true", "1"}
|
||||
ONEWAY_REVERSE = {"-1", "reverse"}
|
||||
DEFAULT_DRIVE_SPEED_KMH = {
|
||||
"motorway": 110,
|
||||
"motorway_link": 50,
|
||||
"trunk": 90,
|
||||
"trunk_link": 45,
|
||||
"primary": 70,
|
||||
"primary_link": 40,
|
||||
"secondary": 60,
|
||||
"secondary_link": 35,
|
||||
"tertiary": 50,
|
||||
"tertiary_link": 30,
|
||||
"unclassified": 40,
|
||||
"residential": 30,
|
||||
"living_street": 10,
|
||||
"service": 15,
|
||||
"road": 30,
|
||||
"track": 15,
|
||||
}
|
||||
DEFAULT_WALK_SPEED_MPS = 1.35
|
||||
STEP_WALK_SPEED_MPS = 0.65
|
||||
|
||||
|
||||
@dataclass
|
||||
class RoutingImportResult:
|
||||
dataset_id: int
|
||||
input_path: str
|
||||
nodes: int
|
||||
edges: int
|
||||
walk_edges: int
|
||||
drive_edges: int
|
||||
skipped_ways: int
|
||||
version: str = ROUTING_LAYER_VERSION
|
||||
|
||||
def as_dict(self) -> dict[str, object]:
|
||||
return {
|
||||
"version": self.version,
|
||||
"dataset_id": self.dataset_id,
|
||||
"input_path": self.input_path,
|
||||
"nodes": self.nodes,
|
||||
"edges": self.edges,
|
||||
"walk_edges": self.walk_edges,
|
||||
"drive_edges": self.drive_edges,
|
||||
"skipped_ways": self.skipped_ways,
|
||||
}
|
||||
|
||||
|
||||
def active_routing_dataset(session: Session) -> Dataset | None:
|
||||
active_osm = session.scalar(
|
||||
select(Dataset).where(Dataset.kind == "osm_geojson", Dataset.is_active.is_(True)).order_by(Dataset.id.desc())
|
||||
)
|
||||
if active_osm is not None:
|
||||
metadata = _metadata(active_osm)
|
||||
raw_dataset_id = metadata.get("raw_dataset_id")
|
||||
if raw_dataset_id is not None:
|
||||
raw = session.get(Dataset, int(raw_dataset_id))
|
||||
if raw is not None and Path(raw.local_path).exists():
|
||||
return raw
|
||||
return session.scalar(
|
||||
select(Dataset)
|
||||
.where(Dataset.kind == "osm_pbf_raw")
|
||||
.order_by(Dataset.is_active.desc(), Dataset.id.desc())
|
||||
)
|
||||
|
||||
|
||||
def rebuild_routing_layer(
|
||||
session: Session,
|
||||
*,
|
||||
dataset_id: int | None = None,
|
||||
input_path: str | Path | None = None,
|
||||
reset: bool = True,
|
||||
batch_size: int = 5000,
|
||||
progress_callback: ProgressCallback | None = None,
|
||||
) -> dict[str, object]:
|
||||
if not settings.is_postgresql_database:
|
||||
raise RuntimeError("The routing layer importer requires PostgreSQL/PostGIS.")
|
||||
dataset = session.get(Dataset, dataset_id) if dataset_id is not None else active_routing_dataset(session)
|
||||
if dataset is None:
|
||||
raise ValueError("No OSM PBF dataset is available for routing import.")
|
||||
path = Path(input_path or dataset.local_path)
|
||||
if not path.exists():
|
||||
raise FileNotFoundError(f"Routing import PBF does not exist: {path}")
|
||||
|
||||
if reset:
|
||||
_emit(progress_callback, "routing_layer_clear_started", "Clearing existing routing graph.", None, None, {"dataset_id": dataset.id})
|
||||
session.execute(delete(RoutingEdge).where(RoutingEdge.dataset_id == dataset.id))
|
||||
session.execute(delete(RoutingNode).where(RoutingNode.dataset_id == dataset.id))
|
||||
session.commit()
|
||||
|
||||
_emit(progress_callback, "routing_layer_import_started", "Importing routable OSM highway graph.", None, None, {"dataset_id": dataset.id, "path": str(path)})
|
||||
handler = _RoutingGraphHandler(session=session, dataset_id=dataset.id, batch_size=batch_size, progress_callback=progress_callback)
|
||||
handler.apply_file(str(path), locations=True)
|
||||
handler.flush()
|
||||
|
||||
return finalize_routing_layer(
|
||||
session,
|
||||
dataset_id=dataset.id,
|
||||
input_path=str(path),
|
||||
skipped_way_count=handler.skipped_way_count,
|
||||
progress_callback=progress_callback,
|
||||
)
|
||||
|
||||
|
||||
def finalize_routing_layer(
|
||||
session: Session,
|
||||
*,
|
||||
dataset_id: int | None = None,
|
||||
input_path: str | Path | None = None,
|
||||
skipped_way_count: int = 0,
|
||||
progress_callback: ProgressCallback | None = None,
|
||||
) -> dict[str, object]:
|
||||
if not settings.is_postgresql_database:
|
||||
raise RuntimeError("The routing layer finalizer requires PostgreSQL/PostGIS.")
|
||||
dataset = session.get(Dataset, dataset_id) if dataset_id is not None else active_routing_dataset(session)
|
||||
if dataset is None:
|
||||
raise ValueError("No routing dataset is available to finalize.")
|
||||
path = Path(input_path or dataset.local_path)
|
||||
_emit(progress_callback, "routing_layer_geometry_indexes_dropped", "Dropping routing geometry indexes before bulk refresh.", None, None, {"dataset_id": dataset.id})
|
||||
_drop_routing_geometry_indexes(session)
|
||||
session.commit()
|
||||
_emit(progress_callback, "routing_layer_geometry_started", "Refreshing routing node PostGIS geometries.", None, None, {"dataset_id": dataset.id})
|
||||
refresh_postgis_geometries(session, dataset_id=dataset.id, tables=["routing_nodes"], only_missing=False)
|
||||
session.commit()
|
||||
_emit(progress_callback, "routing_layer_geometry_indexes_started", "Rebuilding routing geometry indexes.", None, None, {"dataset_id": dataset.id})
|
||||
_create_routing_geometry_indexes(session)
|
||||
session.commit()
|
||||
analyze_postgresql_tables(session, ["routing_nodes", "routing_edges"])
|
||||
node_count = int(session.scalar(select(func.count()).select_from(RoutingNode).where(RoutingNode.dataset_id == dataset.id)) or 0)
|
||||
edge_count = int(session.scalar(select(func.count()).select_from(RoutingEdge).where(RoutingEdge.dataset_id == dataset.id)) or 0)
|
||||
walk_edge_count = int(session.scalar(select(func.count()).select_from(RoutingEdge).where(RoutingEdge.dataset_id == dataset.id, RoutingEdge.walk_cost_s.is_not(None))) or 0)
|
||||
drive_edge_count = int(session.scalar(select(func.count()).select_from(RoutingEdge).where(RoutingEdge.dataset_id == dataset.id, RoutingEdge.drive_cost_s.is_not(None))) or 0)
|
||||
dataset_metadata = _metadata(dataset)
|
||||
dataset_metadata["routing_layer"] = {
|
||||
"version": ROUTING_LAYER_VERSION,
|
||||
"nodes": node_count,
|
||||
"edges": edge_count,
|
||||
"walk_edges": walk_edge_count,
|
||||
"drive_edges": drive_edge_count,
|
||||
"input_path": str(path),
|
||||
}
|
||||
dataset.metadata_json = json.dumps(dataset_metadata, indent=2)
|
||||
session.commit()
|
||||
result = RoutingImportResult(
|
||||
dataset_id=dataset.id,
|
||||
input_path=str(path),
|
||||
nodes=node_count,
|
||||
edges=edge_count,
|
||||
walk_edges=walk_edge_count,
|
||||
drive_edges=drive_edge_count,
|
||||
skipped_ways=skipped_way_count,
|
||||
).as_dict()
|
||||
_emit(progress_callback, "routing_layer_import_completed", "Routing graph import completed.", edge_count, edge_count, result)
|
||||
return result
|
||||
|
||||
|
||||
def _drop_routing_geometry_indexes(session: Session) -> None:
|
||||
session.execute(text("DROP INDEX IF EXISTS ix_routing_nodes_geom_gist"))
|
||||
session.execute(text("DROP INDEX IF EXISTS ix_routing_edges_geom_gist"))
|
||||
session.execute(text("DROP INDEX IF EXISTS ix_routing_edges_bbox_box_gist"))
|
||||
|
||||
|
||||
def _create_routing_geometry_indexes(session: Session) -> None:
|
||||
session.execute(text("CREATE INDEX IF NOT EXISTS ix_routing_nodes_geom_gist ON routing_nodes USING GIST (geom)"))
|
||||
session.execute(text("CREATE INDEX IF NOT EXISTS ix_routing_edges_bbox_box_gist ON routing_edges USING GIST (box(point(max_lon, max_lat), point(min_lon, min_lat)))"))
|
||||
|
||||
|
||||
class _RoutingGraphHandler(osmium.SimpleHandler):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
session: Session,
|
||||
dataset_id: int,
|
||||
batch_size: int,
|
||||
progress_callback: ProgressCallback | None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.session = session
|
||||
self.dataset_id = dataset_id
|
||||
self.batch_size = max(500, int(batch_size))
|
||||
self.progress_callback = progress_callback
|
||||
self.nodes: dict[int, dict[str, object]] = {}
|
||||
self.edges: list[dict[str, object]] = []
|
||||
self.node_count = int(
|
||||
session.scalar(select(func.count()).select_from(RoutingNode).where(RoutingNode.dataset_id == dataset_id)) or 0
|
||||
)
|
||||
self.edge_count = int(
|
||||
session.scalar(select(func.count()).select_from(RoutingEdge).where(RoutingEdge.dataset_id == dataset_id)) or 0
|
||||
)
|
||||
self.walk_edge_count = 0
|
||||
self.drive_edge_count = 0
|
||||
self.skipped_way_count = 0
|
||||
self.processed_way_count = 0
|
||||
|
||||
def way(self, way) -> None:
|
||||
tags = {tag.k: tag.v for tag in way.tags}
|
||||
highway = tags.get("highway")
|
||||
if not highway or highway in EXCLUDED_HIGHWAYS:
|
||||
self.skipped_way_count += 1
|
||||
return
|
||||
walkable = _walkable(tags, highway)
|
||||
drivable = _drivable(tags, highway)
|
||||
if not walkable and not drivable:
|
||||
self.skipped_way_count += 1
|
||||
return
|
||||
|
||||
nodes = []
|
||||
for node in way.nodes:
|
||||
if not node.location.valid():
|
||||
continue
|
||||
nodes.append((int(node.ref), float(node.location.lon), float(node.location.lat)))
|
||||
if len(nodes) < 2:
|
||||
self.skipped_way_count += 1
|
||||
return
|
||||
|
||||
oneway = _oneway_direction(tags, highway)
|
||||
drive_speed_mps = _drive_speed_mps(tags, highway)
|
||||
walk_speed_mps = STEP_WALK_SPEED_MPS if highway == "steps" else DEFAULT_WALK_SPEED_MPS
|
||||
for left, right in zip(nodes, nodes[1:]):
|
||||
source_id, source_lon, source_lat = left
|
||||
target_id, target_lon, target_lat = right
|
||||
if source_id == target_id:
|
||||
continue
|
||||
length_m = _distance_m(source_lat, source_lon, target_lat, target_lon)
|
||||
if length_m <= 0:
|
||||
continue
|
||||
if oneway == "reverse":
|
||||
source_id, target_id = target_id, source_id
|
||||
source_lon, target_lon = target_lon, source_lon
|
||||
source_lat, target_lat = target_lat, source_lat
|
||||
|
||||
walk_cost = length_m / walk_speed_mps if walkable else None
|
||||
drive_cost = length_m / drive_speed_mps if drivable and drive_speed_mps > 0 else None
|
||||
reverse_walk_cost = walk_cost
|
||||
reverse_drive_cost = None if oneway in {"forward", "reverse"} else drive_cost
|
||||
self.nodes[source_id] = {"dataset_id": self.dataset_id, "osm_node_id": source_id, "lon": source_lon, "lat": source_lat}
|
||||
self.nodes[target_id] = {"dataset_id": self.dataset_id, "osm_node_id": target_id, "lon": target_lon, "lat": target_lat}
|
||||
self.edges.append(
|
||||
{
|
||||
"dataset_id": self.dataset_id,
|
||||
"osm_way_id": int(way.id),
|
||||
"source_osm_node_id": source_id,
|
||||
"target_osm_node_id": target_id,
|
||||
"source_lon": source_lon,
|
||||
"source_lat": source_lat,
|
||||
"target_lon": target_lon,
|
||||
"target_lat": target_lat,
|
||||
"highway": highway,
|
||||
"name": tags.get("name"),
|
||||
"length_m": length_m,
|
||||
"walk_cost_s": walk_cost,
|
||||
"reverse_walk_cost_s": reverse_walk_cost,
|
||||
"drive_cost_s": drive_cost,
|
||||
"reverse_drive_cost_s": reverse_drive_cost,
|
||||
"geometry_geojson": json.dumps({"type": "LineString", "coordinates": [[source_lon, source_lat], [target_lon, target_lat]]}, separators=(",", ":")),
|
||||
"min_lon": min(source_lon, target_lon),
|
||||
"min_lat": min(source_lat, target_lat),
|
||||
"max_lon": max(source_lon, target_lon),
|
||||
"max_lat": max(source_lat, target_lat),
|
||||
"tags_json": _routing_tags_json(tags),
|
||||
}
|
||||
)
|
||||
self.edge_count += 1
|
||||
if walk_cost is not None:
|
||||
self.walk_edge_count += 1
|
||||
if drive_cost is not None:
|
||||
self.drive_edge_count += 1
|
||||
|
||||
self.processed_way_count += 1
|
||||
if len(self.edges) >= self.batch_size:
|
||||
self.flush()
|
||||
if self.processed_way_count % 100_000 == 0:
|
||||
_emit(
|
||||
self.progress_callback,
|
||||
"routing_layer_import_batch",
|
||||
f"Imported {self.edge_count:,} routing edges.",
|
||||
self.edge_count,
|
||||
None,
|
||||
{"processed_ways": self.processed_way_count, "nodes_pending": len(self.nodes), "edges": self.edge_count},
|
||||
)
|
||||
|
||||
def flush(self) -> None:
|
||||
if not self.nodes and not self.edges:
|
||||
return
|
||||
node_rows = list(self.nodes.values())
|
||||
edge_rows = self.edges
|
||||
if node_rows:
|
||||
stmt = postgresql_insert(RoutingNode).values(node_rows)
|
||||
stmt = stmt.on_conflict_do_nothing(index_elements=["dataset_id", "osm_node_id"])
|
||||
self.session.execute(stmt)
|
||||
self.node_count += len(node_rows)
|
||||
self.nodes.clear()
|
||||
if edge_rows:
|
||||
self.session.bulk_insert_mappings(RoutingEdge, edge_rows)
|
||||
self.edges = []
|
||||
self.session.commit()
|
||||
|
||||
|
||||
def _walkable(tags: dict[str, str], highway: str) -> bool:
|
||||
if highway not in WALK_HIGHWAYS:
|
||||
return False
|
||||
access = _tag_value(tags, "access")
|
||||
foot = _tag_value(tags, "foot")
|
||||
if foot in NO_VALUES:
|
||||
return False
|
||||
if access in NO_VALUES and foot not in YES_VALUES:
|
||||
return False
|
||||
if highway in {"motorway", "motorway_link", "trunk", "trunk_link"} and foot not in YES_VALUES:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _drivable(tags: dict[str, str], highway: str) -> bool:
|
||||
if highway not in DRIVE_HIGHWAYS:
|
||||
return False
|
||||
access = _tag_value(tags, "access")
|
||||
motor_vehicle = _tag_value(tags, "motor_vehicle")
|
||||
motorcar = _tag_value(tags, "motorcar")
|
||||
vehicle = _tag_value(tags, "vehicle")
|
||||
if motorcar in NO_VALUES or motor_vehicle in NO_VALUES or vehicle in NO_VALUES:
|
||||
return False
|
||||
if access in NO_VALUES and motorcar not in YES_VALUES and motor_vehicle not in YES_VALUES:
|
||||
return False
|
||||
if highway in {"footway", "path", "pedestrian", "steps", "cycleway", "bridleway"}:
|
||||
return motorcar in YES_VALUES or motor_vehicle in YES_VALUES
|
||||
return True
|
||||
|
||||
|
||||
def _oneway_direction(tags: dict[str, str], highway: str) -> str:
|
||||
oneway = _tag_value(tags, "oneway")
|
||||
if oneway in ONEWAY_REVERSE:
|
||||
return "reverse"
|
||||
if oneway in ONEWAY_FORWARD or tags.get("junction") == "roundabout" or highway == "motorway":
|
||||
return "forward"
|
||||
return "both"
|
||||
|
||||
|
||||
def _drive_speed_mps(tags: dict[str, str], highway: str) -> float:
|
||||
maxspeed = _parse_maxspeed(tags.get("maxspeed"))
|
||||
kmh = maxspeed or DEFAULT_DRIVE_SPEED_KMH.get(highway, 30)
|
||||
return max(5.0, float(kmh) / 3.6)
|
||||
|
||||
|
||||
def _parse_maxspeed(value: str | None) -> float | None:
|
||||
if not value:
|
||||
return None
|
||||
text = value.strip().lower()
|
||||
if text in {"signals", "none", "walk", "variable"}:
|
||||
return None
|
||||
if text.endswith("mph"):
|
||||
number = _leading_float(text[:-3])
|
||||
return None if number is None else number * 1.60934
|
||||
return _leading_float(text)
|
||||
|
||||
|
||||
def _leading_float(value: str) -> float | None:
|
||||
digits = []
|
||||
for char in value.strip():
|
||||
if char.isdigit() or char == ".":
|
||||
digits.append(char)
|
||||
elif digits:
|
||||
break
|
||||
if not digits:
|
||||
return None
|
||||
try:
|
||||
return float("".join(digits))
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
|
||||
def _routing_tags_json(tags: dict[str, str]) -> str:
|
||||
selected = {
|
||||
key: value
|
||||
for key, value in tags.items()
|
||||
if key in {"access", "bicycle", "bridge", "foot", "highway", "junction", "maxspeed", "motor_vehicle", "motorcar", "name", "oneway", "service", "surface", "tunnel", "vehicle"}
|
||||
}
|
||||
return json.dumps(selected, separators=(",", ":"))
|
||||
|
||||
|
||||
def _tag_value(tags: dict[str, str], key: str) -> str:
|
||||
return str(tags.get(key) or "").strip().lower()
|
||||
|
||||
|
||||
def _distance_m(lat_a: float, lon_a: float, lat_b: float, lon_b: float) -> float:
|
||||
radius = 6_371_000.0
|
||||
phi_a = math.radians(lat_a)
|
||||
phi_b = math.radians(lat_b)
|
||||
delta_phi = math.radians(lat_b - lat_a)
|
||||
delta_lambda = math.radians(lon_b - lon_a)
|
||||
hav = math.sin(delta_phi / 2) ** 2 + math.cos(phi_a) * math.cos(phi_b) * math.sin(delta_lambda / 2) ** 2
|
||||
return radius * 2 * math.atan2(math.sqrt(hav), math.sqrt(1 - hav))
|
||||
|
||||
|
||||
def _metadata(dataset: Dataset) -> dict[str, object]:
|
||||
try:
|
||||
value = json.loads(dataset.metadata_json or "{}")
|
||||
except json.JSONDecodeError:
|
||||
return {}
|
||||
return value if isinstance(value, dict) else {}
|
||||
|
||||
|
||||
def _emit(
|
||||
progress_callback: ProgressCallback | None,
|
||||
event_type: str,
|
||||
message: str,
|
||||
progress_current: int | None,
|
||||
progress_total: int | None,
|
||||
metadata: dict[str, object] | None = None,
|
||||
) -> None:
|
||||
if progress_callback is not None:
|
||||
progress_callback(event_type, message, progress_current, progress_total, metadata)
|
||||
40
app/pipeline/run.py
Normal file
40
app/pipeline/run.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Callable, Any
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models import Source
|
||||
from app.pipeline.gtfs import run_gtfs_source
|
||||
from app.pipeline.osm_diff import run_osm_diff_source
|
||||
from app.pipeline.osm_geojson import run_osm_geojson_source
|
||||
from app.pipeline.osm_pbf import run_osm_pbf_source
|
||||
|
||||
|
||||
ProgressCallback = Callable[[str, str, int | None, int | None, dict[str, Any] | None], None]
|
||||
|
||||
|
||||
def run_source(session: Session, source: Source, progress_callback: ProgressCallback | None = None):
|
||||
source.status = "running"
|
||||
source.last_run_at = datetime.now(timezone.utc)
|
||||
source.last_error = None
|
||||
session.flush()
|
||||
try:
|
||||
if source.kind == "gtfs":
|
||||
dataset = run_gtfs_source(session, source, progress_callback=progress_callback)
|
||||
elif source.kind == "osm_geojson":
|
||||
dataset = run_osm_geojson_source(session, source)
|
||||
elif source.kind == "osm_pbf":
|
||||
dataset = run_osm_pbf_source(session, source, progress_callback=progress_callback)
|
||||
elif source.kind == "osm_diff":
|
||||
dataset = run_osm_diff_source(session, source)
|
||||
else:
|
||||
raise ValueError(f"Unsupported source kind: {source.kind}")
|
||||
source.status = "ok"
|
||||
source.last_error = None
|
||||
return dataset
|
||||
except Exception as exc: # noqa: BLE001 - persist pipeline error for UI
|
||||
source.status = "error"
|
||||
source.last_error = str(exc)
|
||||
raise
|
||||
294
app/pipeline/sample_data.py
Normal file
294
app/pipeline/sample_data.py
Normal file
@@ -0,0 +1,294 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import csv
|
||||
import io
|
||||
import json
|
||||
import zipfile
|
||||
from pathlib import Path
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.config import settings
|
||||
from app.db import init_db
|
||||
from app.models import (
|
||||
Dataset,
|
||||
CanonicalStop,
|
||||
CanonicalStopLink,
|
||||
GtfsAgency,
|
||||
GtfsCalendar,
|
||||
GtfsCalendarDate,
|
||||
GtfsRoute,
|
||||
GtfsRoutePatternLink,
|
||||
GtfsShape,
|
||||
GtfsStop,
|
||||
GtfsStopTime,
|
||||
GtfsTripRoutePatternLink,
|
||||
GtfsTrip,
|
||||
Itinerary,
|
||||
ItineraryLeg,
|
||||
Job,
|
||||
JobEvent,
|
||||
MatchRule,
|
||||
OsmDiffState,
|
||||
OsmFeature,
|
||||
PipelineRun,
|
||||
RouteMatch,
|
||||
RoutePattern,
|
||||
RoutePatternStop,
|
||||
RoutingEdge,
|
||||
RoutingNode,
|
||||
Source,
|
||||
SourceCatalogEntry,
|
||||
SourceUpdateCheck,
|
||||
TravelRequest,
|
||||
)
|
||||
from app.pipeline.matcher import run_route_matching
|
||||
from app.pipeline.route_layer import rebuild_route_layer
|
||||
from app.pipeline.run import run_source
|
||||
|
||||
|
||||
def load_sample_project(session: Session, *, preserve_job_id: int | None = None) -> dict:
|
||||
"""Clear the DB, create a small Berlin-like GTFS + OSM sample, import, and match."""
|
||||
init_db()
|
||||
clear_project_data(session, preserve_job_id=preserve_job_id, preserve_catalog=True)
|
||||
sample_dir = settings.data_dir / "sample"
|
||||
sample_dir.mkdir(parents=True, exist_ok=True)
|
||||
gtfs_path = sample_dir / "sample_berlin.gtfs.zip"
|
||||
osm_path = sample_dir / "sample_berlin_osm.geojson"
|
||||
create_sample_gtfs(gtfs_path)
|
||||
create_sample_osm_geojson(osm_path)
|
||||
|
||||
gtfs_source = Source(name="Sample Berlin GTFS", kind="gtfs", url=str(gtfs_path), country="DE", license="sample")
|
||||
osm_source = Source(name="Sample Berlin OSM transport", kind="osm_geojson", url=str(osm_path), country="DE", license="sample")
|
||||
session.add_all([gtfs_source, osm_source])
|
||||
session.flush()
|
||||
|
||||
gtfs_dataset = run_source(session, gtfs_source)
|
||||
osm_dataset = run_source(session, osm_source)
|
||||
match_result = run_route_matching(session)
|
||||
route_layer_result = rebuild_route_layer(session)
|
||||
return {
|
||||
"status": "ok",
|
||||
"gtfs_dataset_id": gtfs_dataset.id,
|
||||
"osm_dataset_id": osm_dataset.id,
|
||||
"match_result": match_result,
|
||||
"route_layer_result": route_layer_result,
|
||||
}
|
||||
|
||||
|
||||
def clear_project_data(
|
||||
session: Session,
|
||||
*,
|
||||
preserve_job_id: int | None = None,
|
||||
preserve_catalog: bool = True,
|
||||
) -> None:
|
||||
"""Clear user/project data while optionally preserving the current queue job."""
|
||||
session.execute(delete(PipelineRun))
|
||||
if preserve_job_id is None:
|
||||
session.execute(delete(JobEvent))
|
||||
session.execute(delete(Job))
|
||||
else:
|
||||
_cancel_other_jobs_for_reset(session, preserve_job_id)
|
||||
|
||||
for model in [
|
||||
ItineraryLeg,
|
||||
Itinerary,
|
||||
TravelRequest,
|
||||
SourceUpdateCheck,
|
||||
OsmDiffState,
|
||||
MatchRule,
|
||||
RouteMatch,
|
||||
GtfsTripRoutePatternLink,
|
||||
GtfsRoutePatternLink,
|
||||
RoutePatternStop,
|
||||
RoutePattern,
|
||||
CanonicalStopLink,
|
||||
CanonicalStop,
|
||||
RoutingEdge,
|
||||
RoutingNode,
|
||||
GtfsStopTime,
|
||||
GtfsCalendarDate,
|
||||
GtfsCalendar,
|
||||
GtfsShape,
|
||||
GtfsTrip,
|
||||
GtfsRoute,
|
||||
GtfsStop,
|
||||
GtfsAgency,
|
||||
OsmFeature,
|
||||
Dataset,
|
||||
Source,
|
||||
]:
|
||||
session.execute(delete(model))
|
||||
if not preserve_catalog:
|
||||
session.execute(delete(SourceCatalogEntry))
|
||||
session.flush()
|
||||
|
||||
|
||||
def _cancel_other_jobs_for_reset(session: Session, preserve_job_id: int) -> None:
|
||||
now = datetime.now(timezone.utc)
|
||||
active_statuses = {"queued", "running", "paused"}
|
||||
jobs = session.scalars(
|
||||
select(Job).where(Job.id != preserve_job_id, Job.status.in_(active_statuses))
|
||||
).all()
|
||||
for job in jobs:
|
||||
job.status = "cancelled"
|
||||
job.requested_action = None
|
||||
job.lease_owner = None
|
||||
job.lease_expires_at = None
|
||||
job.paused_at = None
|
||||
job.error = None
|
||||
job.updated_at = now
|
||||
job.finished_at = now
|
||||
session.add(
|
||||
JobEvent(
|
||||
job_id=job.id,
|
||||
event_type="cancelled_by_reset",
|
||||
message=f"Job cancelled by reset job #{preserve_job_id}.",
|
||||
progress_current=job.progress_current,
|
||||
progress_total=job.progress_total,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def create_sample_gtfs(path: Path) -> None:
|
||||
agencies = [
|
||||
{"agency_id": "BVG", "agency_name": "BVG", "agency_url": "https://example.invalid/bvg", "agency_timezone": "Europe/Berlin"},
|
||||
{"agency_id": "DB", "agency_name": "DB Regio", "agency_url": "https://example.invalid/db", "agency_timezone": "Europe/Berlin"},
|
||||
{"agency_id": "XAIR", "agency_name": "Example Airport Shuttle", "agency_url": "https://example.invalid/xair", "agency_timezone": "Europe/Berlin"},
|
||||
]
|
||||
stops = [
|
||||
{"stop_id": "hbf", "stop_name": "Berlin Hauptbahnhof", "stop_lat": "52.5251", "stop_lon": "13.3696"},
|
||||
{"stop_id": "friedrich", "stop_name": "Friedrichstraße", "stop_lat": "52.5201", "stop_lon": "13.3862"},
|
||||
{"stop_id": "alex", "stop_name": "Alexanderplatz", "stop_lat": "52.5219", "stop_lon": "13.4132"},
|
||||
{"stop_id": "ost", "stop_name": "Ostbahnhof", "stop_lat": "52.5100", "stop_lon": "13.4344"},
|
||||
{"stop_id": "zoo", "stop_name": "Zoologischer Garten", "stop_lat": "52.5069", "stop_lon": "13.3320"},
|
||||
{"stop_id": "wittenberg", "stop_name": "Wittenbergplatz", "stop_lat": "52.5020", "stop_lon": "13.3430"},
|
||||
{"stop_id": "potsdamer", "stop_name": "Potsdamer Platz", "stop_lat": "52.5096", "stop_lon": "13.3760"},
|
||||
{"stop_id": "stadtmitte", "stop_name": "Stadtmitte", "stop_lat": "52.5113", "stop_lon": "13.3907"},
|
||||
{"stop_id": "reichstag", "stop_name": "Reichstag", "stop_lat": "52.5186", "stop_lon": "13.3763"},
|
||||
{"stop_id": "hackescher", "stop_name": "Hackescher Markt", "stop_lat": "52.5220", "stop_lon": "13.4023"},
|
||||
{"stop_id": "naturkunde", "stop_name": "Naturkundemuseum", "stop_lat": "52.5300", "stop_lon": "13.3790"},
|
||||
{"stop_id": "wannsee", "stop_name": "Wannsee", "stop_lat": "52.4210", "stop_lon": "13.1797"},
|
||||
{"stop_id": "kladow", "stop_name": "Kladow", "stop_lat": "52.4547", "stop_lon": "13.1439"},
|
||||
{"stop_id": "airport", "stop_name": "Example Airport Terminal", "stop_lat": "52.3650", "stop_lon": "13.5100"},
|
||||
]
|
||||
routes = [
|
||||
{"route_id": "u2", "agency_id": "BVG", "route_short_name": "U2", "route_long_name": "Pankow - Ruhleben", "route_type": "1"},
|
||||
{"route_id": "re1", "agency_id": "DB", "route_short_name": "RE1", "route_long_name": "Magdeburg - Frankfurt Oder", "route_type": "2"},
|
||||
{"route_id": "m5", "agency_id": "BVG", "route_short_name": "M5", "route_long_name": "Hauptbahnhof - Hohenschönhausen", "route_type": "0"},
|
||||
{"route_id": "bus100", "agency_id": "BVG", "route_short_name": "100", "route_long_name": "Zoologischer Garten - Alexanderplatz", "route_type": "3"},
|
||||
{"route_id": "f10", "agency_id": "BVG", "route_short_name": "F10", "route_long_name": "Wannsee - Kladow", "route_type": "4"},
|
||||
{"route_id": "x99", "agency_id": "XAIR", "route_short_name": "X99", "route_long_name": "Airport Express Sample", "route_type": "3"},
|
||||
]
|
||||
trips = [
|
||||
{"route_id": r["route_id"], "service_id": "daily", "trip_id": f"{r['route_id']}_trip", "shape_id": f"{r['route_id']}_shape"}
|
||||
for r in routes
|
||||
]
|
||||
stop_sequences = {
|
||||
"u2_trip": ["zoo", "wittenberg", "potsdamer", "stadtmitte", "alex"],
|
||||
"re1_trip": ["hbf", "friedrich", "alex", "ost"],
|
||||
"m5_trip": ["hbf", "naturkunde", "hackescher", "alex"],
|
||||
"bus100_trip": ["zoo", "reichstag", "alex"],
|
||||
"f10_trip": ["wannsee", "kladow"],
|
||||
"x99_trip": ["alex", "airport"],
|
||||
}
|
||||
coords = {row["stop_id"]: (row["stop_lon"], row["stop_lat"]) for row in stops}
|
||||
stop_times = []
|
||||
shapes = []
|
||||
for trip in trips:
|
||||
trip_id = trip["trip_id"]
|
||||
for idx, stop_id in enumerate(stop_sequences[trip_id], start=1):
|
||||
stop_times.append(
|
||||
{
|
||||
"trip_id": trip_id,
|
||||
"arrival_time": f"08:{idx * 5:02d}:00",
|
||||
"departure_time": f"08:{idx * 5 + 1:02d}:00",
|
||||
"stop_id": stop_id,
|
||||
"stop_sequence": str(idx),
|
||||
}
|
||||
)
|
||||
lon, lat = coords[stop_id]
|
||||
shapes.append(
|
||||
{
|
||||
"shape_id": trip["shape_id"],
|
||||
"shape_pt_lat": lat,
|
||||
"shape_pt_lon": lon,
|
||||
"shape_pt_sequence": str(idx),
|
||||
}
|
||||
)
|
||||
calendar = [
|
||||
{
|
||||
"service_id": "daily",
|
||||
"monday": "1",
|
||||
"tuesday": "1",
|
||||
"wednesday": "1",
|
||||
"thursday": "1",
|
||||
"friday": "1",
|
||||
"saturday": "1",
|
||||
"sunday": "1",
|
||||
"start_date": "20260101",
|
||||
"end_date": "20261231",
|
||||
}
|
||||
]
|
||||
|
||||
with zipfile.ZipFile(path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
|
||||
_write_csv(zf, "agency.txt", ["agency_id", "agency_name", "agency_url", "agency_timezone"], agencies)
|
||||
_write_csv(zf, "stops.txt", ["stop_id", "stop_name", "stop_lat", "stop_lon"], stops)
|
||||
_write_csv(zf, "routes.txt", ["route_id", "agency_id", "route_short_name", "route_long_name", "route_type"], routes)
|
||||
_write_csv(zf, "trips.txt", ["route_id", "service_id", "trip_id", "shape_id"], trips)
|
||||
_write_csv(zf, "stop_times.txt", ["trip_id", "arrival_time", "departure_time", "stop_id", "stop_sequence"], stop_times)
|
||||
_write_csv(
|
||||
zf,
|
||||
"calendar.txt",
|
||||
["service_id", "monday", "tuesday", "wednesday", "thursday", "friday", "saturday", "sunday", "start_date", "end_date"],
|
||||
calendar,
|
||||
)
|
||||
_write_csv(zf, "shapes.txt", ["shape_id", "shape_pt_lat", "shape_pt_lon", "shape_pt_sequence"], shapes)
|
||||
|
||||
|
||||
def _write_csv(zf: zipfile.ZipFile, name: str, fields: list[str], rows: list[dict[str, str]]) -> None:
|
||||
buffer = io.StringIO(newline="")
|
||||
writer = csv.DictWriter(buffer, fieldnames=fields)
|
||||
writer.writeheader()
|
||||
writer.writerows(rows)
|
||||
zf.writestr(name, buffer.getvalue())
|
||||
|
||||
|
||||
def create_sample_osm_geojson(path: Path) -> None:
|
||||
def line(fid, mode, ref, name, operator, coords):
|
||||
return {
|
||||
"type": "Feature",
|
||||
"geometry": {"type": "LineString", "coordinates": coords},
|
||||
"properties": {
|
||||
"osm_type": "relation",
|
||||
"osm_id": str(fid),
|
||||
"type": "route",
|
||||
"route": mode,
|
||||
"ref": ref,
|
||||
"name": name,
|
||||
"operator": operator,
|
||||
"network": "VBB" if operator == "BVG" else "DB",
|
||||
},
|
||||
}
|
||||
|
||||
def point(fid, kind, name, coords, props=None):
|
||||
props = props or {}
|
||||
props.update({"osm_type": "node", "osm_id": str(fid), "name": name})
|
||||
return {"type": "Feature", "geometry": {"type": "Point", "coordinates": coords}, "properties": props}
|
||||
|
||||
features = [
|
||||
line(1002, "subway", "U2", "U2 Ruhleben - Pankow", "BVG", [[13.3320, 52.5069], [13.3430, 52.5020], [13.3760, 52.5096], [13.3907, 52.5113], [13.4132, 52.5219]]),
|
||||
line(2001, "train", "RE1", "RE1 Magdeburg - Frankfurt Oder", "DB Regio", [[13.3696, 52.5251], [13.3862, 52.5201], [13.4132, 52.5219], [13.4344, 52.5100]]),
|
||||
line(5005, "tram", "M5", "M5 Hauptbahnhof - Hohenschönhausen", "BVG", [[13.3696, 52.5251], [13.3790, 52.5300], [13.4023, 52.5220], [13.4132, 52.5219]]),
|
||||
line(6100, "bus", "100", "Bus 100 Zoologischer Garten - Alexanderplatz", "BVG", [[13.3320, 52.5069], [13.3763, 52.5186], [13.4132, 52.5219]]),
|
||||
line(7010, "ferry", "F10", "F10 Wannsee - Kladow", "BVG", [[13.1797, 52.4210], [13.1439, 52.4547]]),
|
||||
line(5010, "tram", "M10", "M10 Warschauer Straße - Hauptbahnhof", "BVG", [[13.4500, 52.5050], [13.4020, 52.5300], [13.3696, 52.5251]]),
|
||||
point(9001, "station", "Berlin Hauptbahnhof", [13.3696, 52.5251], {"railway": "station"}),
|
||||
point(9002, "station", "Alexanderplatz", [13.4132, 52.5219], {"railway": "station"}),
|
||||
point(9003, "stop", "Zoologischer Garten", [13.3320, 52.5069], {"public_transport": "station", "railway": "station"}),
|
||||
point(9004, "terminal", "Wannsee Ferry Terminal", [13.1797, 52.4210], {"amenity": "ferry_terminal"}),
|
||||
point(9005, "terminal", "Kladow Ferry Terminal", [13.1439, 52.4547], {"amenity": "ferry_terminal"}),
|
||||
]
|
||||
path.write_text(json.dumps({"type": "FeatureCollection", "features": features}, indent=2), encoding="utf-8")
|
||||
135
app/pipeline/state.py
Normal file
135
app/pipeline/state.py
Normal file
@@ -0,0 +1,135 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
import hashlib
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models import PipelineRun
|
||||
|
||||
|
||||
STAGE_ACQUIRE_RAW = "acquire_raw"
|
||||
STAGE_FILTER_TRANSPORT = "filter_transport"
|
||||
STAGE_EXTRACT_GEOMETRY = "extract_geometry"
|
||||
STAGE_LABEL_FEATURES = "label_features"
|
||||
STAGE_BUILD_INDEXES = "build_indexes"
|
||||
STAGE_MATCH_ROUTES = "match_routes"
|
||||
STAGE_BUILD_ROUTE_LAYER = "build_route_layer"
|
||||
|
||||
|
||||
def stable_json(value: Any) -> str:
|
||||
return json.dumps(value, sort_keys=True, separators=(",", ":"), default=str)
|
||||
|
||||
|
||||
def dependency_hash(value: Any) -> str:
|
||||
return hashlib.sha256(stable_json(value).encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
def latest_completed_run(
|
||||
session: Session,
|
||||
*,
|
||||
stage: str,
|
||||
version: str,
|
||||
dependency_hash_value: str,
|
||||
source_id: int | None = None,
|
||||
dataset_id: int | None = None,
|
||||
) -> PipelineRun | None:
|
||||
stmt = (
|
||||
select(PipelineRun)
|
||||
.where(
|
||||
PipelineRun.stage == stage,
|
||||
PipelineRun.version == version,
|
||||
PipelineRun.dependency_hash == dependency_hash_value,
|
||||
PipelineRun.status == "completed",
|
||||
)
|
||||
.order_by(PipelineRun.finished_at.desc(), PipelineRun.id.desc())
|
||||
.limit(1)
|
||||
)
|
||||
if source_id is None:
|
||||
stmt = stmt.where(PipelineRun.source_id.is_(None))
|
||||
else:
|
||||
stmt = stmt.where(PipelineRun.source_id == source_id)
|
||||
if dataset_id is None:
|
||||
stmt = stmt.where(PipelineRun.dataset_id.is_(None))
|
||||
else:
|
||||
stmt = stmt.where(PipelineRun.dataset_id == dataset_id)
|
||||
return session.scalar(stmt)
|
||||
|
||||
|
||||
def start_pipeline_run(
|
||||
session: Session,
|
||||
*,
|
||||
stage: str,
|
||||
version: str,
|
||||
dependency_hash_value: str,
|
||||
source_id: int | None = None,
|
||||
dataset_id: int | None = None,
|
||||
job_id: int | None = None,
|
||||
inputs: dict[str, Any] | None = None,
|
||||
) -> PipelineRun:
|
||||
now = datetime.now(timezone.utc)
|
||||
run = PipelineRun(
|
||||
stage=stage,
|
||||
version=version,
|
||||
dependency_hash=dependency_hash_value,
|
||||
status="running",
|
||||
source_id=source_id,
|
||||
dataset_id=dataset_id,
|
||||
job_id=job_id,
|
||||
input_json=None if inputs is None else stable_json(inputs),
|
||||
started_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
session.add(run)
|
||||
session.flush()
|
||||
return run
|
||||
|
||||
|
||||
def finish_pipeline_run(
|
||||
session: Session,
|
||||
run: PipelineRun,
|
||||
*,
|
||||
status: str = "completed",
|
||||
outputs: dict[str, Any] | None = None,
|
||||
error: str | None = None,
|
||||
) -> PipelineRun:
|
||||
now = datetime.now(timezone.utc)
|
||||
run.status = status
|
||||
run.output_json = None if outputs is None else stable_json(outputs)
|
||||
run.error = error
|
||||
run.updated_at = now
|
||||
run.finished_at = now
|
||||
session.flush()
|
||||
return run
|
||||
|
||||
|
||||
def pipeline_run_payload(run: PipelineRun) -> dict[str, Any]:
|
||||
return {
|
||||
"id": run.id,
|
||||
"stage": run.stage,
|
||||
"version": run.version,
|
||||
"dependency_hash": run.dependency_hash,
|
||||
"status": run.status,
|
||||
"source_id": run.source_id,
|
||||
"dataset_id": run.dataset_id,
|
||||
"job_id": run.job_id,
|
||||
"input": _json_object(run.input_json),
|
||||
"output": _json_object(run.output_json),
|
||||
"error": run.error,
|
||||
"started_at": run.started_at.isoformat() if run.started_at else None,
|
||||
"updated_at": run.updated_at.isoformat() if run.updated_at else None,
|
||||
"finished_at": run.finished_at.isoformat() if run.finished_at else None,
|
||||
}
|
||||
|
||||
|
||||
def _json_object(text: str | None) -> dict[str, Any]:
|
||||
if not text:
|
||||
return {}
|
||||
try:
|
||||
value = json.loads(text)
|
||||
except json.JSONDecodeError:
|
||||
return {}
|
||||
return value if isinstance(value, dict) else {}
|
||||
89
app/pipeline/utils.py
Normal file
89
app/pipeline/utils.py
Normal file
@@ -0,0 +1,89 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Iterable, Optional
|
||||
|
||||
from shapely.geometry import shape
|
||||
|
||||
|
||||
def sha256_file(path: Path) -> str:
|
||||
h = hashlib.sha256()
|
||||
with path.open("rb") as f:
|
||||
for chunk in iter(lambda: f.read(1024 * 1024), b""):
|
||||
h.update(chunk)
|
||||
return h.hexdigest()
|
||||
|
||||
|
||||
def norm_text(value: object) -> str:
|
||||
if value is None:
|
||||
return ""
|
||||
value = str(value).lower().strip()
|
||||
value = value.replace("ß", "ss")
|
||||
value = re.sub(r"[^a-z0-9]+", " ", value)
|
||||
return re.sub(r"\s+", " ", value).strip()
|
||||
|
||||
|
||||
def norm_ref(value: object) -> str:
|
||||
if value is None:
|
||||
return ""
|
||||
return re.sub(r"[^a-z0-9]+", "", str(value).lower())
|
||||
|
||||
|
||||
def first_nonempty(*values: object) -> str:
|
||||
for value in values:
|
||||
if value is None:
|
||||
continue
|
||||
text = str(value).strip()
|
||||
if text:
|
||||
return text
|
||||
return ""
|
||||
|
||||
|
||||
def geometry_json_and_bbox(geometry: object) -> tuple[Optional[str], tuple[Optional[float], Optional[float], Optional[float], Optional[float]]]:
|
||||
if geometry is None:
|
||||
return None, (None, None, None, None)
|
||||
try:
|
||||
geom = shape(geometry) if isinstance(geometry, dict) else geometry
|
||||
if geom.is_empty:
|
||||
return None, (None, None, None, None)
|
||||
min_lon, min_lat, max_lon, max_lat = geom.bounds
|
||||
return json.dumps(geom.__geo_interface__, separators=(",", ":")), (min_lon, min_lat, max_lon, max_lat)
|
||||
except Exception:
|
||||
return None, (None, None, None, None)
|
||||
|
||||
|
||||
def bbox_overlap(a: tuple[float | None, float | None, float | None, float | None], b: tuple[float | None, float | None, float | None, float | None]) -> bool:
|
||||
if any(v is None for v in (*a, *b)):
|
||||
return False
|
||||
aminx, aminy, amaxx, amaxy = a # type: ignore[misc]
|
||||
bminx, bminy, bmaxx, bmaxy = b # type: ignore[misc]
|
||||
return not (amaxx < bminx or bmaxx < aminx or amaxy < bminy or bmaxy < aminy)
|
||||
|
||||
|
||||
def bbox_center(b: tuple[float | None, float | None, float | None, float | None]) -> Optional[tuple[float, float]]:
|
||||
if any(v is None for v in b):
|
||||
return None
|
||||
minx, miny, maxx, maxy = b # type: ignore[misc]
|
||||
return ((minx + maxx) / 2, (miny + maxy) / 2)
|
||||
|
||||
|
||||
def approx_bbox_center_distance_deg(a: tuple[float | None, float | None, float | None, float | None], b: tuple[float | None, float | None, float | None, float | None]) -> Optional[float]:
|
||||
ca = bbox_center(a)
|
||||
cb = bbox_center(b)
|
||||
if ca is None or cb is None:
|
||||
return None
|
||||
return ((ca[0] - cb[0]) ** 2 + (ca[1] - cb[1]) ** 2) ** 0.5
|
||||
|
||||
|
||||
def batched(iterable: Iterable[dict], batch_size: int = 1000) -> Iterable[list[dict]]:
|
||||
batch: list[dict] = []
|
||||
for item in iterable:
|
||||
batch.append(item)
|
||||
if len(batch) >= batch_size:
|
||||
yield batch
|
||||
batch = []
|
||||
if batch:
|
||||
yield batch
|
||||
Reference in New Issue
Block a user