Alpha stage commit

This commit is contained in:
2026-07-01 23:29:51 +02:00
parent b583bb1233
commit e23387738b
84 changed files with 40807 additions and 326 deletions

1
app/__init__.py Normal file
View File

@@ -0,0 +1 @@
"""Mobility Workbench prototype."""

1272
app/address_search.py Normal file

File diff suppressed because it is too large Load Diff

394
app/cli.py Normal file
View File

@@ -0,0 +1,394 @@
from __future__ import annotations
import json
import csv
from pathlib import Path
from typing import Optional
import typer
from sqlalchemy import func, select, text
from app.config import settings
from app.data_management import dataset_sidecar_paths, prune_inactive_datasets
from app.db import engine, init_db, reset_db, session_scope
from app.db_lock import database_write_lock
from app.feed_discovery import build_gtfs_discovery_manifests, default_generated_dir
from app.models import (
Dataset,
GtfsRoute,
GtfsShape,
GtfsStop,
RouteMatch,
RoutePattern,
Source,
SourceCatalogEntry,
)
from app.pipeline.matcher import run_route_matching
from app.pipeline.osm_labeling import relabel_osm_features
from app.pipeline.osm_pbf import run_osm_pbf_source_staged
from app.pipeline.run import run_source
from app.pipeline.gtfs import backfill_gtfs_shapes
from app.pipeline.route_layer import rebuild_route_layer
from app.pipeline.sample_data import load_sample_project
from app.osm_storage import osm_feature_count
from app.jobs import run_worker_loop
from app.jobs import create_route_layer_rebuild_job, create_route_matching_job, create_source_import_job
from app.source_catalog import (
default_ingestable_sources_path,
default_source_catalog_path,
import_ingestable_sources,
import_source_catalog,
source_catalog_summary,
)
cli = typer.Typer(help="Mobility Workbench pipeline CLI")
@cli.command("init-db")
def init_db_command() -> None:
with _write_lock("init-db"):
init_db()
typer.echo("Database initialized")
@cli.command("reset-db")
def reset_db_command() -> None:
with _write_lock("reset-db"):
reset_db()
typer.echo("Database reset")
@cli.command("load-sample")
def load_sample_command() -> None:
with _write_lock("load-sample"):
init_db()
with session_scope() as session:
result = load_sample_project(session)
typer.echo(json.dumps(result, indent=2))
@cli.command("add-source")
def add_source_command(
name: str = typer.Option(..., help="Source name"),
kind: str = typer.Option(..., help="gtfs, osm_geojson, osm_pbf, or osm_diff"),
url: str = typer.Option(..., help="HTTP URL or local path"),
country: Optional[str] = typer.Option(None),
license: Optional[str] = typer.Option(None),
priority: Optional[str] = typer.Option(None),
mode_scope: Optional[str] = typer.Option(None),
source_basis: Optional[str] = typer.Option(None),
notes: Optional[str] = typer.Option(None),
) -> None:
with _write_lock("add-source"):
init_db()
if kind not in {"gtfs", "osm_geojson", "osm_pbf", "osm_diff"}:
raise typer.BadParameter("kind must be gtfs, osm_geojson, osm_pbf, or osm_diff")
with session_scope() as session:
source = Source(
name=name,
kind=kind,
url=url,
country=country,
license=license,
priority=priority,
mode_scope=mode_scope,
source_basis=source_basis,
notes=notes,
)
session.add(source)
session.flush()
typer.echo(json.dumps({"id": source.id, "name": source.name}, indent=2))
@cli.command("run-source")
def run_source_command(source_id: int) -> None:
init_db()
with session_scope() as session:
source = session.get(Source, source_id)
if source is None:
raise typer.BadParameter(f"source not found: {source_id}")
source_kind = source.kind
if source_kind == "osm_pbf":
dataset = run_osm_pbf_source_staged(source_id)
typer.echo(json.dumps({"source_id": source_id, "dataset_id": dataset.id, "status": dataset.status, "import_mode": "staged_short_lock"}, indent=2))
return
with _write_lock("run-source"):
with session_scope() as session:
source = session.get(Source, source_id)
if source is None:
raise typer.BadParameter(f"source not found: {source_id}")
dataset = run_source(session, source)
typer.echo(json.dumps({"source_id": source.id, "dataset_id": dataset.id, "status": dataset.status}, indent=2))
@cli.command("run-match")
def run_match_command() -> None:
with _write_lock("run-match"):
init_db()
with session_scope() as session:
result = run_route_matching(session)
typer.echo(json.dumps(result, indent=2))
@cli.command("build-route-layer")
def build_route_layer_command() -> None:
with _write_lock("build-route-layer"):
init_db()
with session_scope() as session:
result = rebuild_route_layer(session)
typer.echo(json.dumps(result, indent=2))
@cli.command("relabel-osm-features")
def relabel_osm_features_command(
dataset_id: Optional[int] = typer.Option(None, help="Only relabel one OSM dataset"),
force: bool = typer.Option(False, help="Run even when the recorded dependency signature is current"),
chunk_size: int = typer.Option(5000, help="Rows per relabel batch"),
rebuild_indexes: bool = typer.Option(True, help="Drop/rebuild affected route-scope indexes around large relabel writes"),
build_route_layer: bool = typer.Option(True, help="Rebuild the route layer after relabeling"),
) -> None:
with _write_lock("relabel-osm-features"):
init_db()
with session_scope() as session:
result = relabel_osm_features(
session,
dataset_id=dataset_id,
force=force,
chunk_size=chunk_size,
rebuild_indexes=rebuild_indexes,
)
if build_route_layer and (result["changed"] or force):
result["route_layer_result"] = rebuild_route_layer(session)
typer.echo(json.dumps(result, indent=2))
@cli.command("backfill-gtfs-shapes")
def backfill_gtfs_shapes_command(dataset_id: Optional[int] = typer.Option(None, help="Only backfill one GTFS dataset")) -> None:
with _write_lock("backfill-gtfs-shapes"):
init_db()
with session_scope() as session:
result = backfill_gtfs_shapes(session, dataset_id=dataset_id)
typer.echo(json.dumps(result, indent=2))
@cli.command("stats")
def stats_command() -> None:
init_db()
with session_scope() as session:
active_dataset_ids = [row[0] for row in session.execute(select(Dataset.id).where(Dataset.is_active.is_(True))).all()]
stats = {
"sources": session.scalar(select(func.count()).select_from(Source)),
"source_catalog_entries": session.scalar(select(func.count()).select_from(SourceCatalogEntry)) or 0,
"active_datasets": len(active_dataset_ids),
"gtfs_routes": session.scalar(select(func.count()).select_from(GtfsRoute).where(GtfsRoute.dataset_id.in_(active_dataset_ids))) if active_dataset_ids else 0,
"gtfs_stops": session.scalar(select(func.count()).select_from(GtfsStop).where(GtfsStop.dataset_id.in_(active_dataset_ids))) if active_dataset_ids else 0,
"gtfs_shapes": session.scalar(select(func.count()).select_from(GtfsShape).where(GtfsShape.dataset_id.in_(active_dataset_ids))) if active_dataset_ids else 0,
"route_patterns": session.scalar(select(func.count()).select_from(RoutePattern)) or 0,
"osm_routes": sum(osm_feature_count(session, dataset_id, kind="route") for dataset_id in active_dataset_ids),
"matches": {status: count for status, count in session.execute(select(RouteMatch.status, func.count()).group_by(RouteMatch.status)).all()},
}
typer.echo(json.dumps(stats, indent=2))
@cli.command("import-source-catalog")
def import_source_catalog_command(
csv_path: Path = typer.Option(default_source_catalog_path(), "--csv", help="Source catalog CSV path"),
no_update: bool = typer.Option(False, help="Skip rows that already exist"),
) -> None:
with _write_lock("import-source-catalog"):
init_db()
with session_scope() as session:
result = import_source_catalog(session, csv_path, update_existing=not no_update)
result["summary"] = source_catalog_summary(session)
typer.echo(json.dumps(result, indent=2))
@cli.command("import-ingestable-sources")
def import_ingestable_sources_command(
csv_path: Path = typer.Option(default_ingestable_sources_path(), "--csv", help="Ingestable source seed CSV path"),
no_update: bool = typer.Option(False, help="Skip sources that already exist"),
) -> None:
with _write_lock("import-ingestable-sources"):
init_db()
with session_scope() as session:
result = import_ingestable_sources(session, csv_path, update_existing=not no_update)
result["summary"] = source_catalog_summary(session)
typer.echo(json.dumps(result, indent=2))
@cli.command("discover-gtfs-sources")
def discover_gtfs_sources_command(
output_dir: Path = typer.Option(default_generated_dir(), "--output-dir", help="Directory for generated discovery CSVs"),
countries: str = typer.Option(
",".join(["DE", "AT", "CH", "NL", "DK", "FR", "BE", "LU", "NO", "SE", "FI", "IE", "GB"]),
"--countries",
help="Comma-separated country codes, or ALL for every country exposed by the upstream catalogs",
),
no_mobility_database: bool = typer.Option(False, help="Skip Mobility Database feeds_v2.csv"),
no_acceptance_test_list: bool = typer.Option(False, help="Skip MobilityData validator acceptance-test feed list"),
no_ptna: bool = typer.Option(False, help="Skip PTNA GTFS analysis pages"),
max_ptna_details: int = typer.Option(80, help="Maximum PTNA detail pages to fetch for license/crosswalk metadata"),
test_limit: int = typer.Option(24, help="Rows to write to the focused test-run ingestable CSV"),
check_urls: bool = typer.Option(False, help="Run HEAD/range checks for ingestable feed URLs"),
) -> None:
result = build_gtfs_discovery_manifests(
output_dir=output_dir,
countries=[part.strip() for part in countries.split(",") if part.strip()],
include_mobility_database=not no_mobility_database,
include_acceptance_test_list=not no_acceptance_test_list,
include_ptna=not no_ptna,
max_ptna_details=max_ptna_details,
test_limit=test_limit,
check_urls=check_urls,
)
typer.echo(json.dumps(result, indent=2, ensure_ascii=False))
@cli.command("queue-source-imports-from-csv")
def queue_source_imports_from_csv_command(
csv_path: Path = typer.Option(default_ingestable_sources_path(), "--csv", help="Ingestable source CSV path"),
no_update: bool = typer.Option(False, help="Skip sources that already exist instead of updating them"),
run_match_at_end: bool = typer.Option(True, help="Queue one route-matching job after all source imports"),
build_route_layer_at_end: bool = typer.Option(True, help="Queue one route-layer rebuild after route matching"),
priority: int = typer.Option(0, help="Priority for queued source import jobs"),
) -> None:
with _write_lock("queue-source-imports-from-csv"):
init_db()
with session_scope() as session:
csv_path = csv_path if csv_path.is_absolute() else Path.cwd() / csv_path
imported = import_ingestable_sources(session, csv_path, update_existing=not no_update)
source_urls = _source_urls_from_ingestable_csv(csv_path)
sources = session.scalars(
select(Source)
.where(Source.kind == "gtfs", Source.url.in_(source_urls))
.order_by(Source.id)
).all()
jobs = [
create_source_import_job(
session,
source,
run_match=False,
build_route_layer=False,
priority=priority,
)
for source in sources
]
route_match_job = create_route_matching_job(session, priority=priority) if run_match_at_end else None
route_layer_job = create_route_layer_rebuild_job(session, priority=priority) if build_route_layer_at_end else None
typer.echo(
json.dumps(
{
"csv": str(csv_path),
"imported": imported,
"sources": [{"id": source.id, "name": source.name} for source in sources],
"source_import_jobs": [job.id for job in jobs],
"route_match_job": None if route_match_job is None else route_match_job.id,
"route_layer_job": None if route_layer_job is None else route_layer_job.id,
},
indent=2,
ensure_ascii=False,
)
)
@cli.command("prune-cache")
def prune_cache_command(dry_run: bool = typer.Option(False, help="Report files without deleting them")) -> None:
with _write_lock("prune-cache"):
init_db()
with session_scope() as session:
referenced = {
Path(path).resolve()
for path in session.scalars(select(Dataset.local_path)).all()
if path
}
for dataset in session.scalars(select(Dataset)).all():
referenced.update(path.resolve() for path in dataset_sidecar_paths(dataset))
roots = [settings.data_dir / "sources", settings.data_dir / "derived", settings.data_dir / "sidecars", settings.data_dir / "staging"]
candidates = [
path
for root in roots
if root.exists()
for path in root.rglob("*")
if path.is_file() and path.resolve() not in referenced
]
total_bytes = sum(path.stat().st_size for path in candidates)
if not dry_run:
for path in candidates:
path.unlink()
for root in roots:
_remove_empty_dirs(root)
typer.echo(
json.dumps(
{
"dry_run": dry_run,
"files": len(candidates),
"bytes": total_bytes,
"deleted": 0 if dry_run else len(candidates),
},
indent=2,
)
)
@cli.command("prune-inactive-datasets")
def prune_inactive_datasets_command(
dry_run: bool = typer.Option(False, help="Report inactive normalized datasets without deleting them"),
) -> None:
with _write_lock("prune-inactive-datasets"):
init_db()
with session_scope() as session:
result = prune_inactive_datasets(session, dry_run=dry_run)
typer.echo(json.dumps(result, indent=2))
@cli.command("vacuum-db")
def vacuum_db_command() -> None:
with _write_lock("vacuum-db"):
init_db()
with engine.connect().execution_options(isolation_level="AUTOCOMMIT") as connection:
connection.execute(text("VACUUM"))
connection.execute(text("PRAGMA wal_checkpoint(TRUNCATE)"))
typer.echo("Database vacuumed")
@cli.command("worker")
def worker_command(
once: bool = typer.Option(False, help="Process at most one queued job and exit"),
max_jobs: Optional[int] = typer.Option(None, help="Process at most this many jobs and exit"),
poll_interval: float = typer.Option(2.0, help="Seconds to wait between queue polls"),
worker_id: Optional[str] = typer.Option(None, help="Stable worker identifier"),
) -> None:
result = run_worker_loop(worker_id=worker_id, poll_interval=poll_interval, max_jobs=max_jobs, once=once)
typer.echo(json.dumps(result, indent=2))
def _remove_empty_dirs(root: Path) -> None:
if not root.exists():
return
for path in sorted((p for p in root.rglob("*") if p.is_dir()), key=lambda p: len(p.parts), reverse=True):
try:
path.rmdir()
except OSError:
pass
def _write_lock(operation: str):
return database_write_lock(f"cli:{operation}", timeout=settings.database_write_lock_cli_timeout_seconds)
def _source_urls_from_ingestable_csv(path: Path) -> list[str]:
urls: list[str] = []
with path.open("r", encoding="utf-8-sig", newline="") as handle:
for row in csv.DictReader(handle):
if (row.get("kind") or "").strip().lower() != "gtfs":
continue
url = (row.get("url") or "").strip()
if url and url not in urls:
urls.append(url)
return urls
if __name__ == "__main__":
cli()

74
app/config.py Normal file
View File

@@ -0,0 +1,74 @@
from __future__ import annotations
from pathlib import Path
from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings):
"""Runtime settings.
SQLite is the default because this prototype should run immediately.
The schema is deliberately plain enough to migrate to PostGIS later.
"""
database_url: str = "sqlite:///./data/workbench.sqlite"
data_dir: Path = Path("./data")
# 0 means import all stop_times. Use a positive value only for constrained
# demos where full timetable routing is not needed.
gtfs_stop_times_import_limit: int = 0
# "sidecar_stop_times" keeps the large timetable call table in a per-dataset
# SQLite file and stores compact GTFS tables in the main app database.
# Set to "main" for the old all-in-one SQLite layout.
gtfs_timetable_storage: str = "sidecar_stop_times"
gtfs_keep_activation_stage: bool = False
# "sidecar_features" keeps extracted OSM transport features in a per-dataset
# SQLite file. The main DB materializes only OSM rows that need stable
# foreign keys for matches or route-layer output.
osm_feature_storage: str = "sidecar_features"
osm_sidecar_create_visual_only_stops: bool = False
# Large OSM PBF extracts should be reduced to transport objects before the
# Python extractor scans them. XML fixtures stay unfiltered by default.
osm_pbf_prefilter_enabled: bool = True
osm_pbf_prefilter_formats: str = "osm_pbf"
osm_pbf_prefilter_script: Path = Path("scripts/osmium_transport_filter.sh")
osm_diff_max_sequence_gap: int = 14
osm_diff_apply_batch_size: int = 7
osm_diff_state_timeout_seconds: float = 30.0
sqlite_timeout_seconds: float = 120.0
sqlite_busy_timeout_ms: int = 120000
database_write_lock_timeout_seconds: float = 1.0
database_write_lock_cli_timeout_seconds: float = 3600.0
queue_worker_autostart: bool = True
queue_worker_count: int = 1
queue_worker_poll_interval_seconds: float = 2.0
queue_job_lease_seconds: int = 7200
route_matching_batch_size: int = 100
route_layer_osm_route_batch_size: int = 1000
route_layer_osm_stop_batch_size: int = 5000
# SQLite defaults to sidecar storage. PostgreSQL/PostGIS defaults to main
# table storage so indexes, joins, and spatial operators can work over the
# full imported datasets.
postgres_use_sidecars: bool = False
# Keep supervised workers alive across API server restarts. Stale workers are
# detected by PID files at the next startup; stale job leases are requeued.
queue_worker_stop_on_shutdown: bool = False
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8")
@property
def normalized_database_url(self) -> str:
if self.database_url.startswith("postgresql://"):
return "postgresql+psycopg://" + self.database_url.removeprefix("postgresql://")
return self.database_url
@property
def is_sqlite_database(self) -> bool:
return self.normalized_database_url.startswith("sqlite")
@property
def is_postgresql_database(self) -> bool:
return self.normalized_database_url.startswith("postgresql")
settings = Settings()
settings.data_dir.mkdir(parents=True, exist_ok=True)

327
app/data_management.py Normal file
View File

@@ -0,0 +1,327 @@
from __future__ import annotations
from pathlib import Path
from sqlalchemy import delete, func, or_, select
from sqlalchemy.orm import Session
from app.config import settings
from app.gtfs_storage import dataset_sidecar_paths as gtfs_dataset_sidecar_paths, missing_sidecar_paths as gtfs_missing_sidecar_paths, stop_time_count
from app.models import (
CanonicalStopLink,
Dataset,
GtfsAgency,
GtfsCalendar,
GtfsCalendarDate,
GtfsRoute,
GtfsRoutePatternLink,
GtfsShape,
GtfsStop,
GtfsStopTime,
GtfsTripRoutePatternLink,
GtfsTrip,
OsmDiffState,
OsmFeature,
RouteMatch,
RoutePattern,
RoutePatternStop,
Source,
SourceUpdateCheck,
)
from app.osm_storage import (
dataset_sidecar_paths as osm_dataset_sidecar_paths,
missing_sidecar_paths as osm_missing_sidecar_paths,
osm_feature_count,
)
def dataset_row_counts(session: Session, dataset_id: int, kind: str) -> dict[str, int]:
if kind == "gtfs":
route_ids = select(GtfsRoute.id).where(GtfsRoute.dataset_id == dataset_id)
match_counts = {
status: count
for status, count in session.execute(
select(RouteMatch.status, func.count())
.where(RouteMatch.gtfs_route_id.in_(route_ids))
.group_by(RouteMatch.status)
).all()
}
return {
"agencies": _count(session, GtfsAgency, dataset_id),
"stops": _count(session, GtfsStop, dataset_id),
"routes": _count(session, GtfsRoute, dataset_id),
"trips": _count(session, GtfsTrip, dataset_id),
"calendars": _count(session, GtfsCalendar, dataset_id),
"calendar_dates": _count(session, GtfsCalendarDate, dataset_id),
"shapes": _count(session, GtfsShape, dataset_id),
"stop_times": stop_time_count(session, dataset_id),
"missing_sidecar": _gtfs_sidecar_missing(session, dataset_id),
"matches": sum(match_counts.values()),
"match_counts": match_counts,
}
if kind == "osm_geojson":
return {
"features": _safe_osm_feature_count(session, dataset_id),
"routes": _safe_osm_feature_count(session, dataset_id, kind="route"),
"stops": _safe_osm_feature_count(session, dataset_id, kind=["stop", "station", "terminal"]),
"infra": _safe_osm_feature_count(session, dataset_id, kind="infra"),
"missing_sidecar": _osm_sidecar_missing(session, dataset_id),
}
return {}
def source_row_counts(session: Session, source: Source) -> dict[str, object]:
counts = {
"datasets": len(source.datasets),
"active_datasets": sum(1 for dataset in source.datasets if dataset.is_active),
"routes": 0,
"stops": 0,
"features": 0,
"trips": 0,
"shapes": 0,
"stop_times": 0,
"missing_sidecars": 0,
"match_counts": {},
"missing_gtfs_sidecars": 0,
"missing_osm_sidecars": 0,
}
match_counts: dict[str, int] = {}
for dataset in source.datasets:
stats = dataset_row_counts(session, dataset.id, dataset.kind)
counts["routes"] += int(stats.get("routes", 0))
counts["stops"] += int(stats.get("stops", 0))
counts["features"] += int(stats.get("features", 0))
counts["trips"] += int(stats.get("trips", 0))
counts["shapes"] += int(stats.get("shapes", 0))
counts["stop_times"] += int(stats.get("stop_times", 0))
if stats.get("missing_sidecar"):
counts["missing_sidecars"] += 1
if dataset.kind == "gtfs":
counts["missing_gtfs_sidecars"] += 1
elif dataset.kind == "osm_geojson":
counts["missing_osm_sidecars"] += 1
for status, count in stats.get("match_counts", {}).items():
match_counts[status] = match_counts.get(status, 0) + int(count)
counts["match_counts"] = match_counts
return counts
def delete_dataset(session: Session, dataset_id: int) -> dict[str, object]:
dataset = session.get(Dataset, dataset_id)
if dataset is None:
return {"deleted": False, "reason": "dataset not found", "dataset_id": dataset_id}
counts = dataset_row_counts(session, dataset.id, dataset.kind)
_detach_update_checks_for_dataset(session, dataset.id)
session.execute(delete(OsmDiffState).where(OsmDiffState.raw_dataset_id == dataset.id))
_delete_dataset_rows(session, dataset)
_delete_dataset_files(dataset)
session.delete(dataset)
session.flush()
return {"deleted": True, "dataset_id": dataset_id, "counts": counts}
def delete_source(session: Session, source_id: int) -> dict[str, object]:
source = session.get(Source, source_id)
if source is None:
return {"deleted": False, "reason": "source not found", "source_id": source_id}
datasets = list(source.datasets)
dataset_results = []
for dataset in datasets:
dataset_results.append({"dataset_id": dataset.id, "kind": dataset.kind, "counts": dataset_row_counts(session, dataset.id, dataset.kind)})
_detach_update_checks_for_dataset(session, dataset.id)
session.execute(delete(OsmDiffState).where(OsmDiffState.raw_dataset_id == dataset.id))
_delete_dataset_rows(session, dataset)
_delete_dataset_files(dataset)
session.delete(dataset)
session.execute(delete(OsmDiffState).where(OsmDiffState.source_id == source.id))
session.delete(source)
session.flush()
return {"deleted": True, "source_id": source_id, "datasets": dataset_results}
def unreferenced_cache_file_summary(session: Session) -> dict[str, int]:
candidates = _unreferenced_cache_files(session)
return {"files": len(candidates), "bytes": sum(path.stat().st_size for path in candidates)}
def prune_unreferenced_cache_files(session: Session) -> dict[str, int]:
candidates = _unreferenced_cache_files(session)
total_bytes = sum(path.stat().st_size for path in candidates)
for path in candidates:
path.unlink()
for root in _cache_roots():
_remove_empty_dirs(root)
return {"files": len(candidates), "bytes": total_bytes}
def _unreferenced_cache_files(session: Session) -> list[Path]:
referenced = {
Path(path).resolve()
for path in session.scalars(select(Dataset.local_path)).all()
if path
}
for dataset in session.scalars(select(Dataset)).all():
referenced.update(path.resolve() for path in dataset_sidecar_paths(dataset))
return [
path
for root in _cache_roots()
if root.exists()
for path in root.rglob("*")
if path.is_file() and path.resolve() not in referenced
]
def _cache_roots() -> list[Path]:
# Staging files are not referenced by datasets until activation. Automatic
# pruning must not remove a staging DB from a running import.
return [settings.data_dir / "sources", settings.data_dir / "derived", settings.data_dir / "sidecars"]
def prune_inactive_datasets(session: Session, dry_run: bool = True) -> dict[str, object]:
dataset_rows = session.execute(
select(Dataset.id, Dataset.kind).where(Dataset.is_active.is_(False), Dataset.kind.in_(["gtfs", "osm_geojson"]))
).all()
dataset_ids = [int(row[0]) for row in dataset_rows]
gtfs_ids = [int(dataset_id) for dataset_id, kind in dataset_rows if kind == "gtfs"]
osm_ids = [int(dataset_id) for dataset_id, kind in dataset_rows if kind == "osm_geojson"]
route_ids = select(GtfsRoute.id).where(GtfsRoute.dataset_id.in_(gtfs_ids)) if gtfs_ids else None
osm_feature_ids = select(OsmFeature.id).where(OsmFeature.dataset_id.in_(osm_ids)) if osm_ids else None
match_filters = []
if route_ids is not None:
match_filters.append(RouteMatch.gtfs_route_id.in_(route_ids))
if osm_feature_ids is not None:
match_filters.append(RouteMatch.osm_feature_id.in_(osm_feature_ids))
counts = {
"datasets": len(dataset_ids),
"gtfs_stop_times": sum(stop_time_count(session, dataset_id) for dataset_id in gtfs_ids),
"gtfs_shapes": _count_dataset_rows(session, GtfsShape, gtfs_ids),
"gtfs_trips": _count_dataset_rows(session, GtfsTrip, gtfs_ids),
"gtfs_calendar_dates": _count_dataset_rows(session, GtfsCalendarDate, gtfs_ids),
"gtfs_calendars": _count_dataset_rows(session, GtfsCalendar, gtfs_ids),
"gtfs_routes": _count_dataset_rows(session, GtfsRoute, gtfs_ids),
"gtfs_stops": _count_dataset_rows(session, GtfsStop, gtfs_ids),
"gtfs_agencies": _count_dataset_rows(session, GtfsAgency, gtfs_ids),
"osm_features": sum(_safe_osm_feature_count(session, dataset_id) for dataset_id in osm_ids),
"missing_osm_sidecars": sum(1 for dataset_id in osm_ids if _osm_sidecar_missing(session, dataset_id)),
"gtfs_route_pattern_links": session.scalar(select(func.count()).select_from(GtfsRoutePatternLink).where(GtfsRoutePatternLink.dataset_id.in_(gtfs_ids))) if gtfs_ids else 0,
"gtfs_trip_route_pattern_links": session.scalar(select(func.count()).select_from(GtfsTripRoutePatternLink).where(GtfsTripRoutePatternLink.dataset_id.in_(gtfs_ids))) if gtfs_ids else 0,
"canonical_stop_links": session.scalar(select(func.count()).select_from(CanonicalStopLink).where(CanonicalStopLink.dataset_id.in_(dataset_ids))) if dataset_ids else 0,
"route_matches": session.scalar(select(func.count()).select_from(RouteMatch).where(or_(*match_filters))) if match_filters else 0,
}
if dry_run or not dataset_ids:
return {"dry_run": dry_run, "dataset_ids": dataset_ids, "deleted": counts if not dry_run else {}, "would_delete": counts}
for dataset_id in dataset_ids:
_detach_update_checks_for_dataset(session, dataset_id)
if match_filters:
session.execute(delete(RouteMatch).where(or_(*match_filters)))
if gtfs_ids:
route_ids = select(GtfsRoute.id).where(GtfsRoute.dataset_id.in_(gtfs_ids))
pattern_ids = select(RoutePattern.id).where(RoutePattern.gtfs_route_id.in_(route_ids))
session.execute(delete(RoutePatternStop).where(RoutePatternStop.route_pattern_id.in_(pattern_ids)))
session.execute(delete(GtfsTripRoutePatternLink).where(GtfsTripRoutePatternLink.dataset_id.in_(gtfs_ids)))
session.execute(delete(GtfsRoutePatternLink).where(GtfsRoutePatternLink.dataset_id.in_(gtfs_ids)))
session.execute(delete(RoutePattern).where(RoutePattern.gtfs_route_id.in_(route_ids)))
session.execute(delete(CanonicalStopLink).where(CanonicalStopLink.dataset_id.in_(gtfs_ids), CanonicalStopLink.object_type == "gtfs_stop"))
for model in [GtfsStopTime, GtfsShape, GtfsTrip, GtfsCalendarDate, GtfsCalendar, GtfsRoute, GtfsStop, GtfsAgency]:
session.execute(delete(model).where(model.dataset_id.in_(gtfs_ids)))
if osm_ids:
osm_feature_ids = select(OsmFeature.id).where(OsmFeature.dataset_id.in_(osm_ids))
pattern_ids = select(RoutePattern.id).where(RoutePattern.osm_feature_id.in_(osm_feature_ids))
session.execute(delete(RoutePatternStop).where(RoutePatternStop.route_pattern_id.in_(pattern_ids)))
session.execute(delete(GtfsTripRoutePatternLink).where(GtfsTripRoutePatternLink.route_pattern_id.in_(pattern_ids)))
session.execute(delete(GtfsRoutePatternLink).where(GtfsRoutePatternLink.route_pattern_id.in_(pattern_ids)))
session.execute(delete(RoutePattern).where(RoutePattern.osm_feature_id.in_(osm_feature_ids)))
session.execute(delete(CanonicalStopLink).where(CanonicalStopLink.dataset_id.in_(osm_ids), CanonicalStopLink.object_type == "osm_feature"))
session.execute(delete(OsmFeature).where(OsmFeature.dataset_id.in_(osm_ids)))
for dataset in session.scalars(select(Dataset).where(Dataset.id.in_(dataset_ids))).all():
_delete_dataset_files(dataset)
session.execute(delete(Dataset).where(Dataset.id.in_(dataset_ids)))
session.flush()
return {"dry_run": dry_run, "dataset_ids": dataset_ids, "deleted": counts, "would_delete": {}}
def _delete_dataset_rows(session: Session, dataset: Dataset) -> None:
if dataset.kind == "gtfs":
route_ids = select(GtfsRoute.id).where(GtfsRoute.dataset_id == dataset.id)
pattern_ids = select(RoutePattern.id).where(RoutePattern.gtfs_route_id.in_(route_ids))
session.execute(delete(RouteMatch).where(RouteMatch.gtfs_route_id.in_(route_ids)))
session.execute(delete(RoutePatternStop).where(RoutePatternStop.route_pattern_id.in_(pattern_ids)))
session.execute(delete(GtfsTripRoutePatternLink).where(GtfsTripRoutePatternLink.dataset_id == dataset.id))
session.execute(delete(GtfsRoutePatternLink).where(GtfsRoutePatternLink.dataset_id == dataset.id))
session.execute(delete(RoutePattern).where(RoutePattern.gtfs_route_id.in_(route_ids)))
session.execute(delete(CanonicalStopLink).where(CanonicalStopLink.dataset_id == dataset.id, CanonicalStopLink.object_type == "gtfs_stop"))
for model in [GtfsStopTime, GtfsShape, GtfsTrip, GtfsCalendarDate, GtfsCalendar, GtfsRoute, GtfsStop, GtfsAgency]:
session.execute(delete(model).where(model.dataset_id == dataset.id))
elif dataset.kind == "osm_geojson":
osm_feature_ids = select(OsmFeature.id).where(OsmFeature.dataset_id == dataset.id)
pattern_ids = select(RoutePattern.id).where(RoutePattern.osm_feature_id.in_(osm_feature_ids))
session.execute(delete(RouteMatch).where(RouteMatch.osm_feature_id.in_(osm_feature_ids)))
session.execute(delete(RoutePatternStop).where(RoutePatternStop.route_pattern_id.in_(pattern_ids)))
session.execute(delete(GtfsTripRoutePatternLink).where(GtfsTripRoutePatternLink.route_pattern_id.in_(pattern_ids)))
session.execute(delete(GtfsRoutePatternLink).where(GtfsRoutePatternLink.route_pattern_id.in_(pattern_ids)))
session.execute(delete(RoutePattern).where(RoutePattern.osm_feature_id.in_(osm_feature_ids)))
session.execute(delete(CanonicalStopLink).where(CanonicalStopLink.dataset_id == dataset.id, CanonicalStopLink.object_type == "osm_feature"))
session.execute(delete(OsmFeature).where(OsmFeature.dataset_id == dataset.id))
def _delete_dataset_files(dataset: Dataset) -> None:
for path in dataset_sidecar_paths(dataset):
try:
path.unlink()
except FileNotFoundError:
pass
def dataset_sidecar_paths(dataset: Dataset) -> list[Path]:
return [*gtfs_dataset_sidecar_paths(dataset), *osm_dataset_sidecar_paths(dataset)]
def _gtfs_sidecar_missing(session: Session, dataset_id: int) -> bool:
dataset = session.get(Dataset, dataset_id)
return bool(gtfs_missing_sidecar_paths(dataset))
def _safe_osm_feature_count(session: Session, dataset_id: int, *, kind=None) -> int:
try:
return osm_feature_count(session, dataset_id, kind=kind)
except FileNotFoundError:
return 0
def _osm_sidecar_missing(session: Session, dataset_id: int) -> bool:
dataset = session.get(Dataset, dataset_id)
return bool(osm_missing_sidecar_paths(dataset))
def _detach_update_checks_for_dataset(session: Session, dataset_id: int) -> None:
for check in session.scalars(select(SourceUpdateCheck).where(SourceUpdateCheck.active_dataset_id == dataset_id)).all():
check.active_dataset_id = None
def _count(session: Session, model, dataset_id: int) -> int:
return session.scalar(select(func.count()).select_from(model).where(model.dataset_id == dataset_id)) or 0
def _count_where(session: Session, model, dataset_id: int, *where) -> int:
return session.scalar(select(func.count()).select_from(model).where(model.dataset_id == dataset_id, *where)) or 0
def _count_dataset_rows(session: Session, model, dataset_ids: list[int]) -> int:
if not dataset_ids:
return 0
return session.scalar(select(func.count()).select_from(model).where(model.dataset_id.in_(dataset_ids))) or 0
def _remove_empty_dirs(root: Path) -> None:
if not root.exists():
return
for path in sorted((p for p in root.rglob("*") if p.is_dir()), key=lambda p: len(p.parts), reverse=True):
try:
path.rmdir()
except OSError:
pass

252
app/dataset_search.py Normal file
View File

@@ -0,0 +1,252 @@
from __future__ import annotations
from sqlalchemy import func, or_, select
from sqlalchemy.orm import Session
from app.gtfs_storage import execute_sidecar_query, uses_sidecar_stop_times
from app.models import Dataset, GtfsRoute, GtfsShape, GtfsStopTime, GtfsTrip, OsmFeature, RoutePattern, Source
from app.osm_storage import osm_feature_public_id, query_osm_features
from app.pipeline.utils import norm_ref
def search_datasets(session: Session, query: str, *, active_only: bool = False, limit: int = 80) -> dict:
q = (query or "").strip()
if len(q) < 1:
return {"query": q, "gtfs_routes": [], "osm_routes": [], "route_patterns": [], "totals": {}}
max_rows = max(1, min(limit, 250))
gtfs_routes = _gtfs_route_hits(session, q, active_only=active_only, limit=max_rows)
osm_routes = _osm_route_hits(session, q, active_only=active_only, limit=max_rows)
route_patterns = _route_pattern_hits(session, q, limit=max_rows)
return {
"query": q,
"gtfs_routes": gtfs_routes,
"osm_routes": osm_routes,
"route_patterns": route_patterns,
"totals": {
"gtfs_routes": len(gtfs_routes),
"osm_routes": len(osm_routes),
"route_patterns": len(route_patterns),
},
}
def _gtfs_route_hits(session: Session, query: str, *, active_only: bool, limit: int) -> list[dict]:
pattern = f"%{query}%"
ref = norm_ref(query)
stmt = (
select(GtfsRoute, Dataset, Source)
.join(Dataset, Dataset.id == GtfsRoute.dataset_id)
.join(Source, Source.id == Dataset.source_id)
.where(
or_(
GtfsRoute.short_name.ilike(pattern),
GtfsRoute.route_id.ilike(pattern),
GtfsRoute.long_name.ilike(pattern),
GtfsRoute.route_key == ref,
)
)
.order_by(Dataset.is_active.desc(), Source.name, GtfsRoute.short_name, GtfsRoute.route_id)
.limit(limit)
)
if active_only:
stmt = stmt.where(Dataset.is_active.is_(True))
rows = session.execute(stmt).all()
route_ids = [route.id for route, _, _ in rows]
trip_counts = _trip_counts(session, route_ids)
stop_time_counts = _stop_time_counts(session, route_ids)
shape_counts = _shape_counts(session, route_ids)
return [
{
"type": "gtfs_route",
"source": _source_payload(source),
"dataset": _dataset_payload(dataset),
"route": {
"id": route.id,
"route_id": route.route_id,
"ref": route.short_name,
"name": route.long_name,
"mode": route.mode,
"operator": route.operator_name,
},
"geometry": _geometry_payload(route),
"timetable": {
"trips": trip_counts.get(route.id, 0),
"stop_times": stop_time_counts.get(route.id, 0),
"shapes": shape_counts.get(route.id, 0),
},
}
for route, dataset, source in rows
]
def _osm_route_hits(session: Session, query: str, *, active_only: bool, limit: int) -> list[dict]:
ref = norm_ref(query)
dataset_stmt = select(Dataset).where(Dataset.kind == "osm_geojson")
if active_only:
dataset_stmt = dataset_stmt.where(Dataset.is_active.is_(True))
datasets = session.scalars(dataset_stmt.order_by(Dataset.is_active.desc(), Dataset.id)).all()
if not datasets:
return []
dataset_ids = [dataset.id for dataset in datasets]
sources = {source.id: source for source in session.scalars(select(Source).where(Source.id.in_([dataset.source_id for dataset in datasets]))).all()}
dataset_by_id = {dataset.id: dataset for dataset in datasets}
features_by_identity: dict[tuple[int, str, str], OsmFeature] = {}
for feature in query_osm_features(session, dataset_ids, kinds=["route"], search=query, limit=limit):
features_by_identity[(feature.dataset_id, feature.osm_type, feature.osm_id)] = feature
if ref:
for feature in query_osm_features(session, dataset_ids, kinds=["route"], route_key=ref, limit=limit):
features_by_identity[(feature.dataset_id, feature.osm_type, feature.osm_id)] = feature
features = sorted(
features_by_identity.values(),
key=lambda feature: (
0 if dataset_by_id.get(feature.dataset_id) and dataset_by_id[feature.dataset_id].is_active else 1,
sources.get(dataset_by_id[feature.dataset_id].source_id).name if dataset_by_id.get(feature.dataset_id) and sources.get(dataset_by_id[feature.dataset_id].source_id) else "",
feature.ref or "",
feature.name or "",
feature.id or 0,
),
)[:limit]
return [
{
"type": "osm_route",
"source": _source_payload(source),
"dataset": _dataset_payload(dataset),
"osm": {
"id": osm_feature_public_id(feature),
"osm_type": feature.osm_type,
"osm_id": feature.osm_id,
"ref": feature.ref,
"name": feature.name,
"mode": feature.mode,
"route_scope": feature.route_scope,
"operator": feature.operator,
"network": feature.network,
},
"geometry": _geometry_payload(feature),
}
for feature in features
if (dataset := dataset_by_id.get(feature.dataset_id)) is not None
if (source := sources.get(dataset.source_id)) is not None
]
def _route_pattern_hits(session: Session, query: str, *, limit: int) -> list[dict]:
pattern = f"%{query}%"
ref = norm_ref(query)
stmt = (
select(RoutePattern)
.where(
or_(
RoutePattern.route_ref.ilike(pattern),
RoutePattern.route_name.ilike(pattern),
RoutePattern.pattern_key.ilike(pattern),
)
)
.order_by(RoutePattern.source_kind, RoutePattern.route_ref, RoutePattern.id)
.limit(limit)
)
rows = session.scalars(stmt).all()
return [
{
"type": "route_pattern",
"id": pattern_row.id,
"ref": pattern_row.route_ref,
"name": pattern_row.route_name,
"mode": pattern_row.mode,
"route_scope": pattern_row.route_scope,
"source_kind": pattern_row.source_kind,
"status": pattern_row.status,
"confidence": pattern_row.confidence,
"gtfs_route_id": pattern_row.gtfs_route_id,
"osm_feature_id": pattern_row.osm_feature_id,
"geometry": _geometry_payload(pattern_row),
}
for pattern_row in rows
if not ref or norm_ref(pattern_row.route_ref or pattern_row.route_name or "") == ref or query.lower() in (pattern_row.route_name or "").lower()
]
def _trip_counts(session: Session, route_row_ids: list[int]) -> dict[int, int]:
if not route_row_ids:
return {}
rows = session.execute(
select(GtfsRoute.id, func.count(GtfsTrip.id))
.join(GtfsTrip, (GtfsTrip.dataset_id == GtfsRoute.dataset_id) & (GtfsTrip.route_id == GtfsRoute.route_id))
.where(GtfsRoute.id.in_(route_row_ids))
.group_by(GtfsRoute.id)
).all()
return {int(route_id): int(count) for route_id, count in rows}
def _stop_time_counts(session: Session, route_row_ids: list[int]) -> dict[int, int]:
if not route_row_ids:
return {}
routes = session.scalars(select(GtfsRoute).where(GtfsRoute.id.in_(route_row_ids))).all()
sidecar_routes = [route for route in routes if uses_sidecar_stop_times(session, route.dataset_id)]
sidecar_route_ids = {route.id for route in sidecar_routes}
main_route_ids = [route.id for route in routes if route.id not in sidecar_route_ids]
counts: dict[int, int] = {}
if main_route_ids:
rows = session.execute(
select(GtfsRoute.id, func.count(GtfsStopTime.id))
.join(GtfsTrip, (GtfsTrip.dataset_id == GtfsRoute.dataset_id) & (GtfsTrip.route_id == GtfsRoute.route_id))
.join(GtfsStopTime, (GtfsStopTime.dataset_id == GtfsTrip.dataset_id) & (GtfsStopTime.trip_id == GtfsTrip.trip_id))
.where(GtfsRoute.id.in_(main_route_ids))
.group_by(GtfsRoute.id)
).all()
counts.update({int(route_id): int(count) for route_id, count in rows})
for route in sidecar_routes:
rows = execute_sidecar_query(
session,
route.dataset_id,
"""
SELECT COUNT(*) AS count
FROM gtfs_stop_times AS stop_times
JOIN gtfs_trips AS trips
ON trips.trip_id = stop_times.trip_id
WHERE trips.route_id = ?
""",
[route.route_id],
)
counts[int(route.id)] = int(rows[0]["count"] or 0) if rows else 0
return counts
def _shape_counts(session: Session, route_row_ids: list[int]) -> dict[int, int]:
if not route_row_ids:
return {}
rows = session.execute(
select(GtfsRoute.id, func.count(func.distinct(GtfsShape.shape_id)))
.join(GtfsTrip, (GtfsTrip.dataset_id == GtfsRoute.dataset_id) & (GtfsTrip.route_id == GtfsRoute.route_id))
.join(GtfsShape, (GtfsShape.dataset_id == GtfsTrip.dataset_id) & (GtfsShape.shape_id == GtfsTrip.shape_id))
.where(GtfsRoute.id.in_(route_row_ids))
.group_by(GtfsRoute.id)
).all()
return {int(route_id): int(count) for route_id, count in rows}
def _source_payload(source: Source) -> dict:
return {"id": source.id, "name": source.name, "kind": source.kind, "country": source.country}
def _dataset_payload(dataset: Dataset) -> dict:
return {
"id": dataset.id,
"kind": dataset.kind,
"is_active": dataset.is_active,
"status": dataset.status,
"created_at": dataset.created_at.isoformat() if dataset.created_at else None,
"sha256": dataset.sha256,
}
def _geometry_payload(row) -> dict:
bbox = None
if all(getattr(row, attr, None) is not None for attr in ("min_lon", "min_lat", "max_lon", "max_lat")):
bbox = {
"min_lon": row.min_lon,
"min_lat": row.min_lat,
"max_lon": row.max_lon,
"max_lat": row.max_lat,
}
return {"present": bool(getattr(row, "geometry_geojson", None)), "bbox": bbox}

339
app/db.py Normal file
View File

@@ -0,0 +1,339 @@
from __future__ import annotations
from contextlib import contextmanager
from pathlib import Path
import re
from typing import Iterator
from sqlalchemy import create_engine
from sqlalchemy import event
from sqlalchemy import text
from sqlalchemy.engine import Connection
from sqlalchemy.orm import DeclarativeBase, Session, sessionmaker
from app.config import settings
class Base(DeclarativeBase):
pass
def _connect_args() -> dict:
if settings.is_sqlite_database:
return {"check_same_thread": False, "timeout": settings.sqlite_timeout_seconds}
return {}
def _ensure_sqlite_parent() -> None:
if not settings.is_sqlite_database:
return
# sqlite:///./data/workbench.sqlite -> ./data/workbench.sqlite
path = settings.normalized_database_url.replace("sqlite:///", "", 1)
if path and path != ":memory:":
Path(path).parent.mkdir(parents=True, exist_ok=True)
_ensure_sqlite_parent()
engine = create_engine(settings.normalized_database_url, connect_args=_connect_args(), pool_pre_ping=True, future=True)
SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False, expire_on_commit=False, future=True)
_CREATE_INDEX_NAME_RE = re.compile(
r"CREATE\s+(?:UNIQUE\s+)?INDEX\s+(?:CONCURRENTLY\s+)?(?:IF\s+NOT\s+EXISTS\s+)?([A-Za-z_][A-Za-z0-9_]*)",
re.IGNORECASE,
)
if settings.is_sqlite_database:
@event.listens_for(engine, "connect")
def _set_sqlite_pragmas(dbapi_connection, _connection_record) -> None:
cursor = dbapi_connection.cursor()
try:
cursor.execute("PRAGMA journal_mode=WAL")
cursor.execute(f"PRAGMA busy_timeout={int(settings.sqlite_busy_timeout_ms)}")
cursor.execute("PRAGMA synchronous=NORMAL")
cursor.execute("PRAGMA foreign_keys=ON")
finally:
cursor.close()
def init_db() -> None:
# Import models so metadata is populated.
from app import models # noqa: F401
_ensure_database_extensions()
Base.metadata.create_all(bind=engine)
_ensure_runtime_columns()
_ensure_runtime_indexes()
def reset_db() -> None:
from app import models # noqa: F401
_ensure_database_extensions()
Base.metadata.drop_all(bind=engine)
Base.metadata.create_all(bind=engine)
_ensure_runtime_columns()
_ensure_runtime_indexes()
def _ensure_database_extensions() -> None:
if not settings.is_postgresql_database:
return
with engine.begin() as conn:
conn.execute(text("CREATE EXTENSION IF NOT EXISTS postgis"))
conn.execute(text("CREATE EXTENSION IF NOT EXISTS pg_trgm"))
has_pgrouting = conn.execute(text("SELECT EXISTS (SELECT 1 FROM pg_available_extensions WHERE name = 'pgrouting')")).scalar()
if has_pgrouting:
conn.execute(text("CREATE EXTENSION IF NOT EXISTS pgrouting"))
def _ensure_runtime_columns() -> None:
if settings.is_postgresql_database:
_ensure_postgresql_runtime_columns()
return
if not settings.is_sqlite_database:
return
with engine.begin() as conn:
columns = {row[1] for row in conn.execute(text("PRAGMA table_info(gtfs_stop_times)")).all()}
if "arrival_seconds" not in columns:
conn.execute(text("ALTER TABLE gtfs_stop_times ADD COLUMN arrival_seconds INTEGER"))
if "departure_seconds" not in columns:
conn.execute(text("ALTER TABLE gtfs_stop_times ADD COLUMN departure_seconds INTEGER"))
source_columns = {row[1] for row in conn.execute(text("PRAGMA table_info(sources)")).all()}
source_runtime_columns = {
"catalog_entry_id": "INTEGER",
"priority": "VARCHAR(16)",
"mode_scope": "TEXT",
"source_basis": "TEXT",
"notes": "TEXT",
}
for column_name, column_type in source_runtime_columns.items():
if column_name not in source_columns:
conn.execute(text(f"ALTER TABLE sources ADD COLUMN {column_name} {column_type}"))
job_columns = {row[1] for row in conn.execute(text("PRAGMA table_info(jobs)")).all()}
job_runtime_columns = {
"priority": "INTEGER NOT NULL DEFAULT 0",
"requested_action": "VARCHAR(32)",
"lease_owner": "VARCHAR(255)",
"lease_expires_at": "DATETIME",
"paused_at": "DATETIME",
"dismissed_at": "DATETIME",
}
for column_name, column_type in job_runtime_columns.items():
if column_name not in job_columns:
conn.execute(text(f"ALTER TABLE jobs ADD COLUMN {column_name} {column_type}"))
route_runtime_tables = {
"gtfs_routes": "VARCHAR(64)",
"route_patterns": "VARCHAR(64)",
"osm_features": "VARCHAR(64)",
}
for table_name, column_type in route_runtime_tables.items():
table_columns = {row[1] for row in conn.execute(text(f"PRAGMA table_info({table_name})")).all()}
if "route_scope" not in table_columns:
conn.execute(text(f"ALTER TABLE {table_name} ADD COLUMN route_scope {column_type}"))
address_columns = {row[1] for row in conn.execute(text("PRAGMA table_info(osm_addresses)")).all()}
if "geometry_geojson" not in address_columns:
conn.execute(text("ALTER TABLE osm_addresses ADD COLUMN geometry_geojson TEXT"))
def _ensure_postgresql_runtime_columns() -> None:
column_statements = [
("osm_features", "geom", "ALTER TABLE osm_features ADD COLUMN geom geometry(Geometry, 4326)"),
("gtfs_routes", "geom", "ALTER TABLE gtfs_routes ADD COLUMN geom geometry(Geometry, 4326)"),
("gtfs_shapes", "geom", "ALTER TABLE gtfs_shapes ADD COLUMN geom geometry(Geometry, 4326)"),
("route_patterns", "geom", "ALTER TABLE route_patterns ADD COLUMN geom geometry(Geometry, 4326)"),
("osm_addresses", "geometry_geojson", "ALTER TABLE osm_addresses ADD COLUMN geometry_geojson TEXT"),
("osm_addresses", "geom", "ALTER TABLE osm_addresses ADD COLUMN geom geometry(Point, 4326)"),
("osm_addresses", "area_geom", "ALTER TABLE osm_addresses ADD COLUMN area_geom geometry(Geometry, 4326)"),
("gtfs_stops", "geom", "ALTER TABLE gtfs_stops ADD COLUMN geom geometry(Point, 4326)"),
("canonical_stops", "geom", "ALTER TABLE canonical_stops ADD COLUMN geom geometry(Point, 4326)"),
("routing_nodes", "geom", "ALTER TABLE routing_nodes ADD COLUMN geom geometry(Point, 4326)"),
("routing_edges", "geom", "ALTER TABLE routing_edges ADD COLUMN geom geometry(LineString, 4326)"),
]
with engine.begin() as conn:
columns = _postgresql_columns(conn)
for table_name, column_name, statement in column_statements:
if (table_name, column_name) not in columns:
conn.execute(text(statement))
country_column = columns.get(("osm_addresses", "country"))
if country_column is not None and country_column["data_type"] != "text":
conn.execute(text("ALTER TABLE osm_addresses ALTER COLUMN country TYPE TEXT"))
def _ensure_runtime_indexes() -> None:
statements = [
"CREATE INDEX IF NOT EXISTS ix_osm_features_map_bbox ON osm_features (dataset_id, kind, mode, min_lon, max_lon, min_lat, max_lat)",
"CREATE INDEX IF NOT EXISTS ix_osm_features_scope_bbox ON osm_features (dataset_id, kind, mode, route_scope, min_lon, max_lon, min_lat, max_lat)",
"CREATE INDEX IF NOT EXISTS ix_gtfs_routes_map_bbox ON gtfs_routes (dataset_id, mode, min_lon, max_lon, min_lat, max_lat)",
"CREATE INDEX IF NOT EXISTS ix_gtfs_routes_scope_bbox ON gtfs_routes (dataset_id, mode, route_scope, min_lon, max_lon, min_lat, max_lat)",
"CREATE INDEX IF NOT EXISTS ix_gtfs_stops_map_point ON gtfs_stops (dataset_id, lon, lat)",
"CREATE INDEX IF NOT EXISTS ix_gtfs_stop_times_stop ON gtfs_stop_times (dataset_id, stop_id, departure_seconds, trip_id, stop_sequence)",
"CREATE INDEX IF NOT EXISTS ix_gtfs_stop_times_stop_depart_trip ON gtfs_stop_times (dataset_id, stop_id, departure_seconds, trip_id)",
"CREATE INDEX IF NOT EXISTS ix_gtfs_stop_times_stop_arrival ON gtfs_stop_times (dataset_id, stop_id, arrival_seconds, trip_id, stop_sequence)",
"CREATE INDEX IF NOT EXISTS ix_gtfs_stop_times_stop_arrive_trip ON gtfs_stop_times (dataset_id, stop_id, arrival_seconds, trip_id)",
"CREATE INDEX IF NOT EXISTS ix_gtfs_stop_times_trip_seq ON gtfs_stop_times (dataset_id, trip_id, stop_sequence)",
"CREATE INDEX IF NOT EXISTS ix_gtfs_stop_times_trip_stop_seq ON gtfs_stop_times (dataset_id, trip_id, stop_id, stop_sequence)",
"CREATE INDEX IF NOT EXISTS ix_gtfs_trips_dataset_trip ON gtfs_trips (dataset_id, trip_id)",
"CREATE INDEX IF NOT EXISTS ix_gtfs_trips_dataset_route ON gtfs_trips (dataset_id, route_id)",
"CREATE INDEX IF NOT EXISTS ix_gtfs_trips_dataset_service ON gtfs_trips (dataset_id, service_id, trip_id)",
"CREATE INDEX IF NOT EXISTS ix_gtfs_trips_dataset_route_service ON gtfs_trips (dataset_id, route_id, service_id)",
"CREATE INDEX IF NOT EXISTS ix_gtfs_routes_dataset_route ON gtfs_routes (dataset_id, route_id)",
"CREATE INDEX IF NOT EXISTS ix_gtfs_shapes_dataset_shape ON gtfs_shapes (dataset_id, shape_id)",
"CREATE INDEX IF NOT EXISTS ix_gtfs_calendars_dataset_service_dates ON gtfs_calendars (dataset_id, service_id, start_date, end_date)",
"CREATE INDEX IF NOT EXISTS ix_gtfs_calendar_dates_dataset_date ON gtfs_calendar_dates (dataset_id, date, service_id, exception_type)",
"CREATE INDEX IF NOT EXISTS ix_canonical_stop_links_object ON canonical_stop_links (object_type, dataset_id, object_id)",
"CREATE INDEX IF NOT EXISTS ix_canonical_stop_links_external ON canonical_stop_links (object_type, dataset_id, external_id)",
"CREATE INDEX IF NOT EXISTS ix_route_patterns_ref_mode ON route_patterns (route_ref, mode, source_kind)",
"CREATE INDEX IF NOT EXISTS ix_route_patterns_bbox ON route_patterns (mode, min_lon, max_lon, min_lat, max_lat)",
"CREATE INDEX IF NOT EXISTS ix_route_patterns_scope_bbox ON route_patterns (mode, route_scope, source_kind, min_lon, max_lon, min_lat, max_lat)",
"CREATE INDEX IF NOT EXISTS ix_gtfs_route_pattern_links_trip_shape ON gtfs_route_pattern_links (dataset_id, route_id, shape_id)",
"CREATE INDEX IF NOT EXISTS ix_gtfs_trip_route_pattern_links_trip ON gtfs_trip_route_pattern_links (dataset_id, trip_id)",
"CREATE INDEX IF NOT EXISTS ix_gtfs_trip_route_pattern_links_pattern ON gtfs_trip_route_pattern_links (route_pattern_id, dataset_id, trip_id)",
"CREATE INDEX IF NOT EXISTS ix_sources_catalog_entry ON sources (catalog_entry_id)",
"CREATE INDEX IF NOT EXISTS ix_sources_priority_country_kind ON sources (priority, country, kind)",
"CREATE INDEX IF NOT EXISTS ix_source_catalog_country_priority ON source_catalog_entries (country_code, priority, status)",
"CREATE INDEX IF NOT EXISTS ix_source_catalog_name ON source_catalog_entries (source_name)",
"CREATE INDEX IF NOT EXISTS ix_source_update_checks_source_checked ON source_update_checks (source_id, checked_at)",
"CREATE INDEX IF NOT EXISTS ix_source_update_checks_available ON source_update_checks (source_id, update_available, checked_at)",
"CREATE INDEX IF NOT EXISTS ix_osm_diff_states_source_sequence ON osm_diff_states (source_id, sequence_number)",
"CREATE INDEX IF NOT EXISTS ix_osm_diff_states_source_status ON osm_diff_states (source_id, status, updated_at)",
"CREATE INDEX IF NOT EXISTS ix_jobs_status_created ON jobs (status, created_at)",
"CREATE INDEX IF NOT EXISTS ix_jobs_kind_status ON jobs (kind, status)",
"CREATE INDEX IF NOT EXISTS ix_jobs_queue_claim ON jobs (status, priority, created_at, id)",
"CREATE INDEX IF NOT EXISTS ix_jobs_lease ON jobs (status, lease_expires_at)",
"CREATE INDEX IF NOT EXISTS ix_jobs_dismissed_status ON jobs (dismissed_at, status, created_at)",
"CREATE INDEX IF NOT EXISTS ix_job_events_job_created ON job_events (job_id, created_at, id)",
"CREATE INDEX IF NOT EXISTS ix_pipeline_runs_stage_dataset_hash ON pipeline_runs (stage, dataset_id, dependency_hash, status, started_at)",
"CREATE INDEX IF NOT EXISTS ix_pipeline_runs_stage_source_hash ON pipeline_runs (stage, source_id, dependency_hash, status, started_at)",
"CREATE INDEX IF NOT EXISTS ix_pipeline_runs_job ON pipeline_runs (job_id, stage, status)",
"CREATE INDEX IF NOT EXISTS ix_match_rules_type_active ON match_rules (rule_type, active)",
"CREATE INDEX IF NOT EXISTS ix_journey_search_cache_type_expires ON journey_search_cache (cache_type, expires_at)",
"CREATE INDEX IF NOT EXISTS ix_travel_requests_created ON travel_requests (created_at)",
"CREATE INDEX IF NOT EXISTS ix_itineraries_request_saved ON itineraries (request_id, saved, created_at)",
"CREATE INDEX IF NOT EXISTS ix_itinerary_legs_itinerary_sequence ON itinerary_legs (itinerary_id, sequence)",
"CREATE INDEX IF NOT EXISTS ix_routing_nodes_dataset_osm ON routing_nodes (dataset_id, osm_node_id)",
"CREATE INDEX IF NOT EXISTS ix_routing_edges_dataset_source ON routing_edges (dataset_id, source_osm_node_id)",
"CREATE INDEX IF NOT EXISTS ix_routing_edges_dataset_target ON routing_edges (dataset_id, target_osm_node_id)",
"CREATE INDEX IF NOT EXISTS ix_routing_edges_dataset_drive ON routing_edges (dataset_id, source_osm_node_id) WHERE drive_cost_s IS NOT NULL",
"CREATE INDEX IF NOT EXISTS ix_routing_edges_dataset_walk ON routing_edges (dataset_id, source_osm_node_id) WHERE walk_cost_s IS NOT NULL",
"CREATE INDEX IF NOT EXISTS ix_routing_edges_dataset_reverse_drive ON routing_edges (dataset_id, target_osm_node_id) WHERE reverse_drive_cost_s IS NOT NULL",
"CREATE INDEX IF NOT EXISTS ix_routing_edges_dataset_reverse_walk ON routing_edges (dataset_id, target_osm_node_id) WHERE reverse_walk_cost_s IS NOT NULL",
"CREATE INDEX IF NOT EXISTS ix_routing_edges_bbox ON routing_edges (dataset_id, min_lon, max_lon, min_lat, max_lat)",
"CREATE INDEX IF NOT EXISTS ix_osm_addresses_dataset_city_street ON osm_addresses (dataset_id, city, street, housenumber)",
"CREATE INDEX IF NOT EXISTS ix_osm_addresses_dataset_postcode ON osm_addresses (dataset_id, postcode)",
"CREATE INDEX IF NOT EXISTS ix_osm_addresses_bbox ON osm_addresses (dataset_id, min_lon, max_lon, min_lat, max_lat)",
]
with engine.begin() as conn:
if settings.is_sqlite_database:
conn.execute(text("PRAGMA journal_mode=WAL"))
conn.execute(text(f"PRAGMA busy_timeout={int(settings.sqlite_busy_timeout_ms)}"))
if settings.is_postgresql_database:
_execute_missing_postgresql_indexes(conn, statements + _postgresql_index_statements())
else:
for statement in statements:
conn.execute(text(statement))
def _postgresql_columns(conn: Connection) -> dict[tuple[str, str], dict[str, str]]:
rows = conn.execute(
text(
"""
SELECT table_name, column_name, data_type, udt_name
FROM information_schema.columns
WHERE table_schema = ANY (current_schemas(false))
"""
)
).mappings()
return {
(str(row["table_name"]), str(row["column_name"])): {
"data_type": str(row["data_type"]),
"udt_name": str(row["udt_name"]),
}
for row in rows
}
def _execute_missing_postgresql_indexes(conn: Connection, statements: list[str]) -> None:
existing = _postgresql_index_names(conn)
for statement in statements:
index_name = _index_name_from_create_statement(statement)
if index_name and index_name in existing:
continue
conn.execute(text(statement))
if index_name:
existing.add(index_name)
def _postgresql_index_names(conn: Connection) -> set[str]:
rows = conn.execute(
text(
"""
SELECT indexname
FROM pg_indexes
WHERE schemaname = ANY (current_schemas(false))
"""
)
)
return {str(row[0]) for row in rows}
def _index_name_from_create_statement(statement: str) -> str | None:
match = _CREATE_INDEX_NAME_RE.search(statement)
return match.group(1) if match else None
def _postgresql_index_statements() -> list[str]:
return [
"CREATE INDEX IF NOT EXISTS ix_osm_features_geom_gist ON osm_features USING GIST (geom)",
"CREATE INDEX IF NOT EXISTS ix_osm_features_stop_geom_gist ON osm_features USING GIST (geom) WHERE kind IN ('stop', 'station', 'terminal')",
"CREATE INDEX IF NOT EXISTS ix_osm_features_route_geom_gist ON osm_features USING GIST (geom) WHERE kind = 'route'",
"CREATE INDEX IF NOT EXISTS ix_gtfs_stops_geom_gist ON gtfs_stops USING GIST (geom)",
"CREATE INDEX IF NOT EXISTS ix_canonical_stops_geom_gist ON canonical_stops USING GIST (geom)",
"CREATE INDEX IF NOT EXISTS ix_gtfs_routes_geom_gist ON gtfs_routes USING GIST (geom)",
"CREATE INDEX IF NOT EXISTS ix_gtfs_shapes_geom_gist ON gtfs_shapes USING GIST (geom)",
"CREATE INDEX IF NOT EXISTS ix_route_patterns_geom_gist ON route_patterns USING GIST (geom)",
"CREATE INDEX IF NOT EXISTS ix_osm_addresses_geom_gist ON osm_addresses USING GIST (geom)",
"CREATE INDEX IF NOT EXISTS ix_osm_addresses_area_geom_gist ON osm_addresses USING GIST (area_geom)",
"CREATE INDEX IF NOT EXISTS ix_routing_nodes_geom_gist ON routing_nodes USING GIST (geom)",
"CREATE INDEX IF NOT EXISTS ix_routing_edges_bbox_box_gist ON routing_edges USING GIST (box(point(max_lon, max_lat), point(min_lon, min_lat)))",
"CREATE INDEX IF NOT EXISTS ix_gtfs_trips_dataset_route_shape_expr ON gtfs_trips (dataset_id, route_id, (COALESCE(shape_id, '__route__')))",
"CREATE INDEX IF NOT EXISTS ix_gtfs_stop_times_dataset_stop ON gtfs_stop_times (dataset_id, stop_id)",
"CREATE INDEX IF NOT EXISTS ix_canonical_stop_links_gtfs_external ON canonical_stop_links (dataset_id, external_id, canonical_stop_id) WHERE object_type = 'gtfs_stop'",
"CREATE INDEX IF NOT EXISTS ix_gtfs_stops_dataset_parent ON gtfs_stops (dataset_id, parent_station)",
"CREATE INDEX IF NOT EXISTS ix_gtfs_stops_dataset_stop_prefix ON gtfs_stops (dataset_id, (split_part(stop_id, '::', 1)))",
"CREATE INDEX IF NOT EXISTS ix_osm_features_name_trgm ON osm_features USING GIN (LOWER(COALESCE(name, '')) gin_trgm_ops)",
"CREATE INDEX IF NOT EXISTS ix_osm_features_ref_trgm ON osm_features USING GIN (LOWER(COALESCE(ref, '')) gin_trgm_ops)",
"CREATE INDEX IF NOT EXISTS ix_osm_features_tags_trgm ON osm_features USING GIN (LOWER(COALESCE(tags_json, '')) gin_trgm_ops)",
"CREATE INDEX IF NOT EXISTS ix_osm_addresses_search_trgm ON osm_addresses USING GIN (LOWER(COALESCE(search_text, '')) gin_trgm_ops)",
"CREATE INDEX IF NOT EXISTS ix_osm_addresses_display_trgm ON osm_addresses USING GIN (LOWER(COALESCE(display_name, '')) gin_trgm_ops)",
"CREATE INDEX IF NOT EXISTS ix_osm_addresses_street_key_house ON osm_addresses (dataset_id, REPLACE(LOWER(COALESCE(NULLIF(street, ''), NULLIF(place, ''), '')), 'ß', 'ss'), housenumber)",
"CREATE INDEX IF NOT EXISTS ix_osm_addresses_street_key_trgm ON osm_addresses USING GIN (REPLACE(LOWER(COALESCE(NULLIF(street, ''), NULLIF(place, ''), '')), 'ß', 'ss') gin_trgm_ops)",
"CREATE INDEX IF NOT EXISTS ix_gtfs_stops_name_trgm ON gtfs_stops USING GIN (name gin_trgm_ops)",
"CREATE INDEX IF NOT EXISTS ix_gtfs_stops_stop_id_trgm ON gtfs_stops USING GIN (stop_id gin_trgm_ops)",
"CREATE INDEX IF NOT EXISTS ix_route_patterns_ref_trgm ON route_patterns USING GIN (LOWER(COALESCE(route_ref, '')) gin_trgm_ops)",
"CREATE INDEX IF NOT EXISTS ix_route_patterns_name_trgm ON route_patterns USING GIN (LOWER(COALESCE(route_name, '')) gin_trgm_ops)",
]
def get_db() -> Iterator[Session]:
db = SessionLocal()
try:
yield db
finally:
db.close()
@contextmanager
def session_scope() -> Iterator[Session]:
db = SessionLocal()
try:
yield db
db.commit()
except Exception:
db.rollback()
raise
finally:
db.close()

211
app/db_lock.py Normal file
View File

@@ -0,0 +1,211 @@
from __future__ import annotations
from contextlib import contextmanager
from dataclasses import dataclass
import json
import os
from pathlib import Path
import threading
import time
from typing import Iterator
from app.config import settings
try:
import fcntl
except ImportError: # pragma: no cover - this app currently targets Linux/macOS dev hosts
fcntl = None # type: ignore[assignment]
class DatabaseWriteBusy(RuntimeError):
def __init__(self, operation: str, active: dict[str, object] | None = None) -> None:
self.operation = operation
self.active = active or {}
active_operation = self.active.get("operation")
detail = f"Database is busy with another write operation"
if active_operation:
detail += f": {active_operation}"
super().__init__(detail)
@dataclass(frozen=True)
class DatabaseWriteState:
locked: bool
operation: str | None = None
pid: int | None = None
started_at: float | None = None
@property
def elapsed_seconds(self) -> float | None:
if self.started_at is None:
return None
return max(0.0, time.time() - self.started_at)
_process_write_lock = threading.Lock()
_state_lock = threading.Lock()
_state = DatabaseWriteState(locked=False)
def is_sqlite_database() -> bool:
return settings.is_sqlite_database
@contextmanager
def database_write_lock(operation: str, timeout: float | None = None) -> Iterator[None]:
"""Serialize SQLite writes inside and across app processes.
SQLite allows only one writer. This lock prevents mutating endpoints from
competing until SQLite times out with a low-level "database is locked" error.
"""
if not is_sqlite_database():
yield
return
effective_timeout = settings.database_write_lock_timeout_seconds if timeout is None else timeout
deadline = None if effective_timeout is None else time.monotonic() + max(0.0, effective_timeout)
if not _acquire_process_lock(deadline):
raise DatabaseWriteBusy(operation, database_write_status().__dict__)
handle = None
file_locked = False
try:
lock_path = _lock_path()
lock_path.parent.mkdir(parents=True, exist_ok=True)
handle = _open_locked_handle(lock_path, deadline)
if handle is None:
raise DatabaseWriteBusy(operation, _read_lock_metadata(lock_path))
file_locked = True
_write_lock_metadata(handle, operation)
_set_state(DatabaseWriteState(locked=True, operation=operation, pid=os.getpid(), started_at=time.time()))
yield
finally:
_set_state(DatabaseWriteState(locked=False))
if handle is not None:
if file_locked and fcntl is not None:
try:
fcntl.flock(handle.fileno(), fcntl.LOCK_UN)
except OSError:
pass
handle.close()
if file_locked:
try:
_lock_path().unlink()
except FileNotFoundError:
pass
except OSError:
pass
_process_write_lock.release()
def database_write_status() -> DatabaseWriteState:
with _state_lock:
return _state
def _acquire_process_lock(deadline: float | None) -> bool:
while True:
if _process_write_lock.acquire(blocking=False):
return True
if deadline is not None and time.monotonic() >= deadline:
return False
time.sleep(0.05)
def _acquire_file_lock(handle, deadline: float | None) -> bool:
if fcntl is None:
return True
while True:
try:
fcntl.flock(handle.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB)
return True
except BlockingIOError:
if deadline is not None and time.monotonic() >= deadline:
return False
time.sleep(0.05)
def _open_locked_handle(lock_path: Path, deadline: float | None):
while True:
try:
lock_path.parent.mkdir(parents=True, exist_ok=True)
handle = lock_path.open("a+", encoding="utf-8")
except FileNotFoundError:
if deadline is not None and time.monotonic() >= deadline:
return None
time.sleep(0.05)
continue
if _try_file_lock(handle):
return handle
metadata = _read_lock_metadata(lock_path)
handle.close()
if not _lock_metadata_is_stale(metadata):
if deadline is not None and time.monotonic() >= deadline:
return None
time.sleep(0.05)
continue
try:
lock_path.unlink()
except FileNotFoundError:
pass
except OSError:
return None
if deadline is not None and time.monotonic() >= deadline:
return None
def _try_file_lock(handle) -> bool:
if fcntl is None:
return True
try:
fcntl.flock(handle.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB)
except BlockingIOError:
return False
return True
def _lock_metadata_is_stale(metadata: dict[str, object]) -> bool:
pid = metadata.get("pid")
try:
pid_int = int(pid) # type: ignore[arg-type]
except (TypeError, ValueError):
return False
if pid_int <= 0 or pid_int == os.getpid():
return False
return not _pid_exists(pid_int)
def _pid_exists(pid: int) -> bool:
try:
os.kill(pid, 0)
except ProcessLookupError:
return False
except PermissionError:
return True
return True
def _set_state(state: DatabaseWriteState) -> None:
global _state
with _state_lock:
_state = state
def _lock_path() -> Path:
return settings.data_dir / "workbench.write.lock"
def _write_lock_metadata(handle, operation: str) -> None:
handle.seek(0)
handle.truncate()
json.dump({"operation": operation, "pid": os.getpid(), "started_at": time.time()}, handle, separators=(",", ":"))
handle.flush()
os.fsync(handle.fileno())
def _read_lock_metadata(path: Path) -> dict[str, object]:
try:
text = path.read_text(encoding="utf-8").strip()
return json.loads(text) if text else {}
except (OSError, json.JSONDecodeError):
return {}

923
app/feed_discovery.py Normal file
View File

@@ -0,0 +1,923 @@
from __future__ import annotations
import csv
import hashlib
import json
import re
from dataclasses import dataclass, field
from datetime import datetime, timezone
from html import unescape
from html.parser import HTMLParser
from pathlib import Path
from typing import Iterable
from urllib.parse import parse_qs, urljoin, urlparse
import requests
MOBILITY_DATABASE_FEEDS_URL = "https://files.mobilitydatabase.org/feeds_v2.csv"
MOBILITY_DATABASE_ACCEPTANCE_TEST_URL = (
"https://raw.githubusercontent.com/MobilityData/gtfs-validator/master/"
"scripts/mobility-database-harvester/acceptance_test_feed_list.csv"
)
PTNA_GTFS_INDEX_URL = "https://ptna.openstreetmap.de/gtfs/index.html"
PTNA_COUNTRY_URL_TEMPLATE = "https://ptna.openstreetmap.de/gtfs/{country}/index.php"
DEFAULT_DISCOVERY_COUNTRIES = ["DE", "AT", "CH", "NL", "DK", "FR", "BE", "LU", "NO", "SE", "FI", "IE", "GB"]
CURATED_TEST_COUNTRIES = ["DE", "CH", "AT", "NL", "DK", "FI", "NO", "SE", "IE", "GB", "FR", "BE", "LU"]
DIRECT_INGEST_HEADERS = ["name", "kind", "url", "country", "license", "mode_scope", "source_basis", "priority", "notes"]
CANONICAL_HEADERS = [
"candidate_id",
"discovery_source",
"country",
"subdivision",
"provider",
"feed_name",
"stable_id",
"ptna_feed_id",
"data_type",
"status",
"is_official",
"selected_url",
"direct_download_url",
"latest_url",
"original_release_url",
"license_url",
"license_text",
"osm_license_text",
"details_url",
"routes_url",
"valid_from",
"valid_to",
"release_date",
"feed_version",
"bbox",
"features",
"priority",
"availability_status",
"http_status",
"content_type",
"content_length",
"final_url",
"source_basis",
"notes",
]
@dataclass
class FeedCandidate:
discovery_source: str
country: str = ""
subdivision: str = ""
provider: str = ""
feed_name: str = ""
stable_id: str = ""
ptna_feed_id: str = ""
data_type: str = "gtfs"
status: str = ""
is_official: str = ""
selected_url: str = ""
direct_download_url: str = ""
latest_url: str = ""
original_release_url: str = ""
license_url: str = ""
license_text: str = ""
osm_license_text: str = ""
details_url: str = ""
routes_url: str = ""
valid_from: str = ""
valid_to: str = ""
release_date: str = ""
feed_version: str = ""
bbox: str = ""
features: str = ""
priority: str = ""
availability_status: str = "unchecked"
http_status: str = ""
content_type: str = ""
content_length: str = ""
final_url: str = ""
source_basis: str = ""
notes: str = ""
evidence_sources: list[str] = field(default_factory=list)
def key(self) -> str:
if self.stable_id:
return f"stable:{self.stable_id}"
if self.selected_url:
return f"url:{_normalize_url_key(self.selected_url)}"
if self.ptna_feed_id:
return f"ptna:{self.ptna_feed_id}"
return "hash:" + hashlib.sha256(json.dumps(self.row(), sort_keys=True).encode("utf-8")).hexdigest()
def candidate_id(self) -> str:
seed = "|".join(
[
self.discovery_source,
self.country,
self.stable_id,
self.ptna_feed_id,
self.selected_url,
self.provider,
self.feed_name,
]
)
return hashlib.sha256(seed.encode("utf-8")).hexdigest()[:16]
def row(self) -> dict[str, str]:
payload = {header: _string(getattr(self, header, "")) for header in CANONICAL_HEADERS if header != "candidate_id"}
payload["candidate_id"] = self.candidate_id()
return payload
def ingestable_row(self) -> dict[str, str]:
name = _feed_source_name(self.country, self.provider or self.feed_name)
license_value = self.license_text or (f"see {self.license_url}" if self.license_url else "")
basis_parts = [self.source_basis or self.discovery_source]
if self.details_url:
basis_parts.append(f"details: {self.details_url}")
if self.original_release_url and self.original_release_url != self.selected_url:
basis_parts.append(f"release: {self.original_release_url}")
notes = self.notes or ""
if self.latest_url and self.latest_url != self.selected_url:
notes = _join_notes(notes, f"Mobility Database mirror: {self.latest_url}")
if self.osm_license_text:
notes = _join_notes(notes, f"OSM permission note: {_truncate(self.osm_license_text, 240)}")
return {
"name": _truncate(name, 240),
"kind": "gtfs",
"url": self.selected_url,
"country": self.country,
"license": _truncate(license_value, 240),
"mode_scope": _mode_scope_from_features(self.features),
"source_basis": _truncate("; ".join(part for part in basis_parts if part), 500),
"priority": self.priority or _candidate_priority(self),
"notes": _truncate(notes, 1200),
}
def default_generated_dir() -> Path:
return Path(__file__).resolve().parents[1] / "docs" / "generated"
def build_gtfs_discovery_manifests(
*,
output_dir: Path | str | None = None,
countries: Iterable[str] | None = None,
include_mobility_database: bool = True,
include_acceptance_test_list: bool = True,
include_ptna: bool = True,
max_ptna_details: int = 80,
test_limit: int = 24,
check_urls: bool = False,
timeout: float = 30.0,
) -> dict[str, object]:
selected_countries = _normalize_countries(countries)
out_dir = Path(output_dir) if output_dir is not None else default_generated_dir()
out_dir.mkdir(parents=True, exist_ok=True)
candidates: list[FeedCandidate] = []
candidates.extend(load_curated_ingestable_seed(countries=selected_countries))
if include_mobility_database:
candidates.extend(fetch_mobility_database_candidates(countries=selected_countries, timeout=timeout))
if include_acceptance_test_list:
candidates.extend(fetch_mobility_acceptance_candidates(countries=selected_countries, timeout=timeout))
if include_ptna:
candidates.extend(fetch_ptna_candidates(countries=selected_countries, max_details=max_ptna_details, timeout=timeout))
merged = merge_candidates(candidates)
ingestable = [candidate for candidate in merged if candidate.selected_url and candidate.data_type == "gtfs"]
if check_urls:
for candidate in ingestable:
annotate_url_availability(candidate, timeout=min(timeout, 12.0))
test_run = select_test_run_candidates(ingestable, limit=test_limit)
candidates_path = out_dir / "gtfs_feed_candidates.csv"
ingestable_path = out_dir / "gtfs_ingestable_sources.csv"
test_path = out_dir / "gtfs_test_run_sources.csv"
report_path = out_dir / "gtfs_discovery_report.json"
_write_csv(candidates_path, CANONICAL_HEADERS, [candidate.row() for candidate in merged])
_write_csv(ingestable_path, DIRECT_INGEST_HEADERS, [candidate.ingestable_row() for candidate in ingestable])
_write_csv(test_path, DIRECT_INGEST_HEADERS, [candidate.ingestable_row() for candidate in test_run])
by_source = _count_by(merged, lambda item: item.discovery_source)
by_country = _count_by(ingestable, lambda item: item.country or "unknown")
report = {
"generated_at": datetime.now(timezone.utc).isoformat(),
"countries": selected_countries or "all",
"sources": {
"mobility_database": MOBILITY_DATABASE_FEEDS_URL if include_mobility_database else None,
"mobility_acceptance_test_list": MOBILITY_DATABASE_ACCEPTANCE_TEST_URL if include_acceptance_test_list else None,
"ptna": PTNA_GTFS_INDEX_URL if include_ptna else None,
},
"counts": {
"candidates": len(merged),
"ingestable": len(ingestable),
"test_run": len(test_run),
"by_source": by_source,
"ingestable_by_country": by_country,
},
"files": {
"candidates": str(candidates_path),
"ingestable": str(ingestable_path),
"test_run": str(test_path),
},
}
report_path.write_text(json.dumps(report, indent=2, ensure_ascii=False) + "\n", encoding="utf-8")
return report
def fetch_mobility_database_candidates(
*,
countries: list[str] | None = None,
timeout: float = 30.0,
url: str = MOBILITY_DATABASE_FEEDS_URL,
) -> list[FeedCandidate]:
text = _fetch_text(url, timeout=timeout)
rows = csv.DictReader(text.splitlines())
candidates: list[FeedCandidate] = []
for row in rows:
if _value(row, "data_type").lower() != "gtfs":
continue
country = _value(row, "location.country_code").upper()
if countries and country not in countries:
continue
direct_url = _normalize_feed_url(_value(row, "urls.direct_download"))
latest_url = _normalize_feed_url(_value(row, "urls.latest"))
selected_url = _choose_feed_url(direct_url, latest_url)
candidate = FeedCandidate(
discovery_source="mobility_database",
country=country,
subdivision=_value(row, "location.subdivision_name"),
provider=_value(row, "provider"),
feed_name=_value(row, "name"),
stable_id=_value(row, "id"),
data_type="gtfs",
status=_value(row, "status"),
is_official=_value(row, "is_official"),
selected_url=selected_url,
direct_download_url=direct_url,
latest_url=latest_url,
license_url=_value(row, "urls.license"),
bbox=_bbox_from_mobility_row(row),
features=_value(row, "features"),
source_basis="Mobility Database feed catalog",
notes=_value(row, "note"),
)
normalize_candidate_geography(candidate)
apply_known_download_overrides(candidate)
candidate.priority = _candidate_priority(candidate)
candidates.append(candidate)
return candidates
def fetch_mobility_acceptance_candidates(
*,
countries: list[str] | None = None,
timeout: float = 30.0,
url: str = MOBILITY_DATABASE_ACCEPTANCE_TEST_URL,
) -> list[FeedCandidate]:
text = _fetch_text(url, timeout=timeout)
rows = csv.DictReader(text.splitlines())
candidates: list[FeedCandidate] = []
for row in rows:
country = _value(row, "country_code").upper()
if countries and country not in countries:
continue
latest_url = _normalize_feed_url(_value(row, "urls.latest"))
if not latest_url:
continue
candidate = FeedCandidate(
discovery_source="mobility_validator_acceptance",
country=country,
subdivision=_value(row, "subdivision_name"),
provider=_value(row, "provider"),
feed_name=_value(row, "provider"),
stable_id=_value(row, "stable_id"),
status="acceptance_test",
selected_url=latest_url,
latest_url=latest_url,
source_basis="MobilityData validator acceptance-test feed list",
notes="Useful smoke-test feed list; prefer Mobility Database feeds_v2 metadata for production source review.",
priority="P3",
)
normalize_candidate_geography(candidate)
apply_known_download_overrides(candidate)
candidates.append(candidate)
return candidates
def fetch_ptna_candidates(
*,
countries: list[str] | None = None,
max_details: int = 80,
timeout: float = 30.0,
) -> list[FeedCandidate]:
country_codes = countries or DEFAULT_DISCOVERY_COUNTRIES
if not country_codes:
country_codes = discover_ptna_country_codes(timeout=timeout)
candidates: list[FeedCandidate] = []
detail_fetches = 0
for country in country_codes:
country_url = PTNA_COUNTRY_URL_TEMPLATE.format(country=country)
try:
html = _fetch_text(country_url, timeout=timeout)
except requests.RequestException:
continue
for candidate in parse_ptna_country_page(html, country=country, page_url=country_url):
if candidate.details_url and detail_fetches < max_details:
try:
detail_html = _fetch_text(candidate.details_url, timeout=timeout)
enrich_ptna_candidate_from_details(candidate, detail_html, candidate.details_url)
detail_fetches += 1
except requests.RequestException:
candidate.notes = _join_notes(candidate.notes, "PTNA detail page could not be fetched during discovery.")
candidate.priority = _candidate_priority(candidate)
candidates.append(candidate)
return candidates
def discover_ptna_country_codes(*, timeout: float = 30.0) -> list[str]:
html = _fetch_text(PTNA_GTFS_INDEX_URL, timeout=timeout)
links = _all_links(html, PTNA_GTFS_INDEX_URL)
codes: list[str] = []
for link in links:
match = re.search(r"/gtfs/([A-Z]{2})/index\.php$", urlparse(link).path)
if match and match.group(1) not in codes:
codes.append(match.group(1))
return codes
def parse_ptna_country_page(html: str, *, country: str, page_url: str) -> list[FeedCandidate]:
rows = _parse_table_rows(html, page_url)
candidates: list[FeedCandidate] = []
for row in rows:
links = [link for cell in row.cells for link in cell.links]
routes_url = _first_link_matching(links, "routes.php?feed=")
details_url = _first_link_matching(links, "gtfs-details.php?feed=")
if not routes_url and not details_url:
continue
feed_id = _feed_id_from_url(routes_url or details_url)
if not feed_id:
continue
texts = [cell.text for cell in row.cells]
release_link = _normalize_feed_url(row.cells[6].first_external_link if len(row.cells) > 6 else "")
direct_url = release_link if _looks_like_download_url(release_link) else ""
candidate = FeedCandidate(
discovery_source="ptna",
country=country,
provider=texts[2] if len(texts) > 2 else "",
feed_name=texts[1] if len(texts) > 1 else feed_id,
ptna_feed_id=feed_id,
selected_url=direct_url,
direct_download_url=direct_url,
original_release_url=release_link,
details_url=details_url,
routes_url=routes_url,
valid_from=texts[3] if len(texts) > 3 else "",
valid_to=texts[4] if len(texts) > 4 else "",
feed_version=texts[5] if len(texts) > 5 else "",
release_date=texts[6] if len(texts) > 6 else "",
source_basis="PTNA GTFS analysis",
notes="PTNA candidate; use original publisher URL where available.",
)
normalize_candidate_geography(candidate)
apply_known_download_overrides(candidate)
candidates.append(candidate)
return candidates
def enrich_ptna_candidate_from_details(candidate: FeedCandidate, html: str, page_url: str) -> None:
fields = parse_ptna_detail_fields(html, page_url)
candidate.original_release_url = _normalize_feed_url(fields.get("release url href") or fields.get("release url") or candidate.original_release_url)
candidate.license_url = fields.get("publisher's license href") or candidate.license_url
candidate.license_text = fields.get("publisher's license") or candidate.license_text
candidate.osm_license_text = fields.get("license given for use in osm") or candidate.osm_license_text
candidate.valid_from = fields.get("feed start date") or candidate.valid_from
candidate.valid_to = fields.get("feed end date") or candidate.valid_to
candidate.feed_version = fields.get("feed version") or candidate.feed_version
candidate.release_date = fields.get("release date") or candidate.release_date
network_guid = fields.get('"network:guid"')
if network_guid:
candidate.notes = _join_notes(candidate.notes, f"PTNA network:guid={network_guid}")
if not candidate.selected_url and _looks_like_download_url(candidate.original_release_url):
candidate.selected_url = _normalize_feed_url(candidate.original_release_url)
candidate.direct_download_url = candidate.selected_url
normalize_candidate_geography(candidate)
def parse_ptna_detail_fields(html: str, page_url: str) -> dict[str, str]:
parsed: dict[str, str] = {}
for row in _parse_table_rows(html, page_url):
if len(row.cells) < 2:
continue
label = _clean_text(row.cells[0].text).lower()
if not label:
continue
detail = _clean_text(row.cells[1].text)
parsed[label] = detail
if row.cells[1].first_external_link:
parsed[f"{label} href"] = row.cells[1].first_external_link
return parsed
def load_curated_ingestable_seed(
*,
countries: list[str] | None = None,
path: Path | str | None = None,
) -> list[FeedCandidate]:
seed_path = Path(path) if path is not None else Path(__file__).resolve().parents[1] / "docs" / "ingestable_sources_seed.csv"
if not seed_path.exists():
return []
candidates: list[FeedCandidate] = []
with seed_path.open("r", encoding="utf-8-sig", newline="") as handle:
for row in csv.DictReader(handle):
if _value(row, "kind").lower() != "gtfs":
continue
country = _value(row, "country").upper()
if countries and country not in countries and country != "EU":
continue
candidate = FeedCandidate(
discovery_source="curated_seed",
country=country,
provider=_value(row, "name").removesuffix(" GTFS"),
feed_name=_value(row, "name"),
selected_url=_normalize_feed_url(_value(row, "url")),
direct_download_url=_normalize_feed_url(_value(row, "url")),
license_text=_value(row, "license"),
features=_value(row, "mode_scope"),
priority=_value(row, "priority"),
source_basis=_value(row, "source_basis") or "curated seed",
notes=_value(row, "notes"),
)
normalize_candidate_geography(candidate)
apply_known_download_overrides(candidate)
candidates.append(candidate)
return candidates
def merge_candidates(candidates: Iterable[FeedCandidate]) -> list[FeedCandidate]:
by_key: dict[str, FeedCandidate] = {}
alias_to_key: dict[str, str] = {}
for candidate in candidates:
keys = _candidate_alias_keys(candidate)
primary_key = keys[0]
existing_key = next((alias_to_key[key] for key in keys if key in alias_to_key), None)
existing = by_key.get(existing_key) if existing_key is not None else None
if existing is None:
by_key[primary_key] = candidate
for key in keys:
alias_to_key[key] = primary_key
continue
_merge_candidate(existing, candidate)
for key in keys:
alias_to_key[key] = existing_key or primary_key
return sorted(by_key.values(), key=lambda item: (_priority_sort_key(item.priority), item.country, item.provider.lower(), item.feed_name.lower()))
def select_test_run_candidates(candidates: Iterable[FeedCandidate], *, limit: int = 24) -> list[FeedCandidate]:
sorted_candidates = sorted(
[
candidate
for candidate in candidates
if candidate.discovery_source != "mobility_validator_acceptance" and _test_candidate_eligible(candidate)
],
key=_test_candidate_sort_key,
)
selected: list[FeedCandidate] = []
seen_urls: set[str] = set()
per_country: dict[str, int] = {}
def add(candidate: FeedCandidate, *, force: bool = False) -> None:
if len(selected) >= limit:
return
url_key = _normalize_url_key(candidate.selected_url)
if not candidate.selected_url or url_key in seen_urls:
return
country = candidate.country or "unknown"
country_limit = 7 if force and country == "DE" else 3
if per_country.get(country, 0) >= country_limit:
return
selected.append(candidate)
seen_urls.add(url_key)
per_country[country] = per_country.get(country, 0) + 1
preferred_tokens = [
"opendata-oepnv.de",
"download.gtfs.de/germany/",
"vbb.de/vbbgtfs",
"rnv-online.de",
"vrn.de",
"gtfs.geops.ch",
"wienerlinien.at",
"gtfs.openov.nl",
"gtfs.ovapi.nl",
"rejseplanen.info",
"dev.hsl.fi/gtfs",
"hsldev.com/gtfs",
"rb_norway-aggregated-gtfs",
"data.bus-data.dft.gov.uk",
"transportforireland",
"gtfs.irail.be/de-lijn",
]
for candidate in sorted_candidates:
text = " ".join([candidate.provider, candidate.feed_name, candidate.source_basis, candidate.selected_url]).lower()
if any(token in text for token in preferred_tokens):
add(candidate, force=True)
for country in CURATED_TEST_COUNTRIES:
for candidate in sorted_candidates:
if candidate.country == country:
add(candidate)
if len(selected) >= limit:
break
if len(selected) >= limit:
break
for candidate in sorted_candidates:
add(candidate)
if len(selected) >= limit:
break
return selected
def _test_candidate_eligible(candidate: FeedCandidate) -> bool:
if not candidate.selected_url:
return False
if _priority_sort_key(candidate.priority) > 2:
return False
text = " ".join([candidate.status, candidate.selected_url, candidate.provider, candidate.feed_name, candidate.notes]).lower()
if "deprecated" in text or "inactive" in text or "{apikey}" in text:
return False
if "registration required" in text or "authentication" in text:
return False
return True
def annotate_url_availability(candidate: FeedCandidate, *, timeout: float = 10.0) -> FeedCandidate:
if not candidate.selected_url:
candidate.availability_status = "missing_url"
return candidate
headers = {"User-Agent": "meubility-workbench-feed-discovery/0.1"}
try:
response = requests.head(candidate.selected_url, allow_redirects=True, timeout=timeout, headers=headers)
if response.status_code in {405, 403} or response.status_code >= 500:
response = requests.get(
candidate.selected_url,
allow_redirects=True,
timeout=timeout,
headers={**headers, "Range": "bytes=0-0"},
stream=True,
)
candidate.http_status = str(response.status_code)
candidate.content_type = response.headers.get("content-type", "")
candidate.content_length = response.headers.get("content-length", "")
candidate.final_url = response.url
candidate.availability_status = "ok" if response.status_code < 400 else "error"
response.close()
except requests.RequestException as exc:
candidate.availability_status = "error"
candidate.notes = _join_notes(candidate.notes, f"Availability check failed: {exc}")
return candidate
def normalize_candidate_geography(candidate: FeedCandidate) -> None:
text = " ".join(
[
candidate.selected_url,
candidate.direct_download_url,
candidate.latest_url,
candidate.original_release_url,
candidate.provider,
candidate.feed_name,
candidate.source_basis,
]
).lower()
if "download.gtfs.de/germany/" in text or "gtfs for germany" in text:
candidate.country = "DE"
elif "storage.googleapis.com/marduk-production/outbound/gtfs/rb_norway" in text:
candidate.country = "NO"
elif "gtfs.ovapi.nl" in text or "openov.nl" in text:
candidate.country = "NL"
elif "www.nvbw.de/fileadmin/user_upload/service/open_data/" in text:
candidate.country = "DE"
def apply_known_download_overrides(candidate: FeedCandidate) -> None:
stale_direct_ids = {"mdb-684", "mdb-777"}
if candidate.stable_id in stale_direct_ids and candidate.latest_url:
candidate.selected_url = candidate.latest_url
candidate.notes = _join_notes(
candidate.notes,
"Selected Mobility Database latest.zip mirror because the catalog direct URL is known to be stale.",
)
@dataclass
class _HtmlCell:
text: str = ""
links: list[str] = field(default_factory=list)
@property
def first_external_link(self) -> str:
for link in self.links:
parsed = urlparse(link)
if parsed.scheme in {"http", "https"} and "ptna.openstreetmap.de" not in parsed.netloc:
return link
return ""
@dataclass
class _HtmlRow:
cells: list[_HtmlCell] = field(default_factory=list)
class _TableParser(HTMLParser):
def __init__(self, base_url: str):
super().__init__(convert_charrefs=True)
self.base_url = base_url
self.rows: list[_HtmlRow] = []
self._row: _HtmlRow | None = None
self._cell: _HtmlCell | None = None
self._active_link: str = ""
def handle_starttag(self, tag: str, attrs: list[tuple[str, str | None]]) -> None:
attrs_dict = {key: value or "" for key, value in attrs}
if tag == "tr":
self._row = _HtmlRow()
elif tag in {"td", "th"} and self._row is not None:
self._cell = _HtmlCell()
elif tag == "a" and self._cell is not None:
href = attrs_dict.get("href", "")
if href:
self._active_link = urljoin(self.base_url, href)
self._cell.links.append(self._active_link)
def handle_endtag(self, tag: str) -> None:
if tag in {"td", "th"} and self._row is not None and self._cell is not None:
self._cell.text = _clean_text(self._cell.text)
self._row.cells.append(self._cell)
self._cell = None
self._active_link = ""
elif tag == "a":
self._active_link = ""
elif tag == "tr":
if self._row is not None and self._row.cells:
self.rows.append(self._row)
self._row = None
self._cell = None
self._active_link = ""
def handle_data(self, data: str) -> None:
if self._cell is not None:
self._cell.text += data
class _LinkParser(HTMLParser):
def __init__(self, base_url: str):
super().__init__(convert_charrefs=True)
self.base_url = base_url
self.links: list[str] = []
def handle_starttag(self, tag: str, attrs: list[tuple[str, str | None]]) -> None:
if tag != "a":
return
for key, value in attrs:
if key == "href" and value:
self.links.append(urljoin(self.base_url, value))
def _parse_table_rows(html: str, base_url: str) -> list[_HtmlRow]:
parser = _TableParser(base_url)
parser.feed(html)
return parser.rows
def _all_links(html: str, base_url: str) -> list[str]:
parser = _LinkParser(base_url)
parser.feed(html)
return parser.links
def _fetch_text(url: str, *, timeout: float) -> str:
response = requests.get(url, timeout=timeout, headers={"User-Agent": "meubility-workbench-feed-discovery/0.1"})
response.raise_for_status()
return response.text
def _first_link_matching(links: Iterable[str], needle: str) -> str:
for link in links:
if needle in link:
return link
return ""
def _feed_id_from_url(url: str) -> str:
query = parse_qs(urlparse(url).query)
return (query.get("feed") or [""])[0]
def _looks_like_download_url(url: str) -> bool:
if not url:
return False
parsed = urlparse(url)
lower_path = parsed.path.lower()
lower_url = url.lower()
if lower_path.endswith(".zip"):
return True
if "exportformat=gtfs" in lower_url or "google_transit" in lower_url:
return True
if lower_path.rstrip("/").endswith(("current_gtfs", "gtfs")):
return True
if "gtfs.ovapi.nl" in parsed.netloc.lower() and "gtfs" in lower_path:
return True
return False
def _normalize_feed_url(url: str) -> str:
cleaned = _clean_text(url)
if not cleaned:
return ""
parsed = urlparse(cleaned)
if parsed.scheme:
return cleaned
first = cleaned.split("/", 1)[0]
if "." in first:
return f"https://{cleaned}"
return cleaned
def _choose_feed_url(direct_url: str, latest_url: str) -> str:
if direct_url:
return direct_url
return latest_url
def _candidate_priority(candidate: FeedCandidate) -> str:
status = candidate.status.lower()
official = candidate.is_official.lower() == "true"
if candidate.discovery_source == "curated_seed":
return candidate.priority or "P1"
if status == "active" and official and candidate.direct_download_url:
return "P0"
if status == "active" and candidate.direct_download_url:
return "P1"
if status == "active" and candidate.latest_url:
return "P2"
if candidate.discovery_source == "ptna":
return "P2" if candidate.selected_url else "P4"
return "P3"
def _test_candidate_sort_key(candidate: FeedCandidate) -> tuple[int, int, str, str]:
source_bonus = 0 if candidate.discovery_source == "curated_seed" else 1
country_bonus = CURATED_TEST_COUNTRIES.index(candidate.country) if candidate.country in CURATED_TEST_COUNTRIES else 99
return (_priority_sort_key(candidate.priority), source_bonus + country_bonus, candidate.country, candidate.provider.lower())
def _priority_sort_key(priority: str) -> int:
match = re.match(r"P(\d+)", priority or "")
return int(match.group(1)) if match else 9
def _candidate_alias_keys(candidate: FeedCandidate) -> list[str]:
keys = [candidate.key()]
if candidate.stable_id:
keys.append(f"stable:{candidate.stable_id}")
for url in [candidate.selected_url, candidate.direct_download_url, candidate.latest_url]:
if url:
keys.append(f"url:{_normalize_url_key(url)}")
if candidate.ptna_feed_id:
keys.append(f"ptna:{candidate.ptna_feed_id}")
deduped: list[str] = []
for key in keys:
if key not in deduped:
deduped.append(key)
return deduped
def _merge_candidate(existing: FeedCandidate, incoming: FeedCandidate) -> None:
if incoming.discovery_source == "curated_seed":
for field_name in ["country", "provider", "feed_name", "license_text", "features", "source_basis", "notes"]:
new_value = getattr(incoming, field_name, "")
if new_value:
setattr(existing, field_name, new_value)
existing.discovery_source = _join_unique(existing.discovery_source, incoming.discovery_source)
for field_name in CANONICAL_HEADERS:
if field_name == "candidate_id":
continue
current = getattr(existing, field_name, "")
new_value = getattr(incoming, field_name, "")
if not current and new_value:
setattr(existing, field_name, new_value)
existing.priority = _better_priority(existing.priority, incoming.priority)
existing.source_basis = _join_unique(existing.source_basis, incoming.source_basis)
existing.notes = _join_notes(existing.notes, incoming.notes)
def _better_priority(left: str, right: str) -> str:
return left if _priority_sort_key(left) <= _priority_sort_key(right) else right
def _join_unique(left: str, right: str) -> str:
parts: list[str] = []
for value in [left, right]:
for part in value.split(";"):
cleaned = part.strip()
if cleaned and cleaned not in parts:
parts.append(cleaned)
return "; ".join(parts)
def _join_notes(left: str, right: str) -> str:
return _join_unique(left, right)
def _compact_name(value: str) -> str:
return re.sub(r"\s+", " ", _clean_text(value)).strip()
def _feed_source_name(country: str, value: str) -> str:
base = _compact_name(value) or "GTFS feed"
prefix = country.upper()
display = base
if prefix and not base.upper().startswith(f"{prefix} "):
display = f"{prefix} {base}"
if "gtfs" not in display.lower():
display = f"{display} GTFS"
return display
def _clean_text(value: str) -> str:
cleaned = unescape(value or "").replace("\xa0", " ")
cleaned = re.sub(r"\s+", " ", cleaned)
return cleaned.strip()
def _mode_scope_from_features(features: str) -> str:
lower = features.lower()
modes = []
if "rail" in lower or "train" in lower:
modes.append("rail")
if "tram" in lower or "light_rail" in lower:
modes.append("tram")
if "subway" in lower or "metro" in lower:
modes.append("metro")
if "bus" in lower or not modes:
modes.append("bus")
if "ferry" in lower:
modes.append("ferry")
return ",".join(dict.fromkeys(modes))
def _bbox_from_mobility_row(row: dict[str, str]) -> str:
min_lat = _value(row, "location.bounding_box.minimum_latitude")
max_lat = _value(row, "location.bounding_box.maximum_latitude")
min_lon = _value(row, "location.bounding_box.minimum_longitude")
max_lon = _value(row, "location.bounding_box.maximum_longitude")
if not all([min_lat, max_lat, min_lon, max_lon]):
return ""
return f"{min_lon},{min_lat},{max_lon},{max_lat}"
def _normalize_countries(countries: Iterable[str] | None) -> list[str] | None:
if countries is None:
return DEFAULT_DISCOVERY_COUNTRIES
normalized = [country.strip().upper() for country in countries if country and country.strip()]
if any(country == "ALL" for country in normalized):
return None
return normalized
def _normalize_url_key(url: str) -> str:
parsed = urlparse(url.strip())
scheme = parsed.scheme.lower()
netloc = parsed.netloc.lower()
path = parsed.path.rstrip("/")
query = parsed.query
return f"{scheme}://{netloc}{path}" + (f"?{query}" if query else "")
def _write_csv(path: Path, headers: list[str], rows: list[dict[str, str]]) -> None:
with path.open("w", encoding="utf-8", newline="") as handle:
writer = csv.DictWriter(handle, fieldnames=headers, extrasaction="ignore")
writer.writeheader()
writer.writerows(rows)
def _count_by(items: Iterable[FeedCandidate], key_fn) -> dict[str, int]:
counts: dict[str, int] = {}
for item in items:
key = key_fn(item)
counts[key] = counts.get(key, 0) + 1
return dict(sorted(counts.items()))
def _value(row: dict[str, str], key: str) -> str:
return _clean_text(row.get(key, ""))
def _string(value: object) -> str:
return "" if value is None else str(value)
def _truncate(value: str, length: int) -> str:
return value[:length] if value else ""

120
app/geofabrik.py Normal file
View File

@@ -0,0 +1,120 @@
from __future__ import annotations
from datetime import datetime, timedelta, timezone
from typing import Any
import requests
from sqlalchemy import select
from sqlalchemy.orm import Session
from app.models import Source
GEOFABRIK_INDEX_URL = "https://download.geofabrik.de/index-v1-nogeom.json"
_CACHE: dict[str, Any] = {"expires_at": None, "rows": None}
def geofabrik_catalog(q: str | None = None, limit: int = 80) -> list[dict[str, Any]]:
rows = _geofabrik_rows()
query = (q or "").strip().casefold()
if query:
rows = [
row
for row in rows
if query in row["id"].casefold()
or query in row["name"].casefold()
or query in (row.get("parent") or "").casefold()
or query in " ".join(row.get("country_codes") or []).casefold()
]
rows.sort(key=lambda row: (row.get("parent") or "", row["name"]))
return rows[: max(1, min(limit, 500))]
def geofabrik_entry(geofabrik_id: str) -> dict[str, Any] | None:
target = geofabrik_id.strip().casefold()
for row in _geofabrik_rows():
if row["id"].casefold() == target:
return row
return None
def create_geofabrik_source(session: Session, geofabrik_id: str, *, import_updates: bool = False) -> Source:
entry = geofabrik_entry(geofabrik_id)
if entry is None:
raise ValueError(f"Geofabrik extract not found: {geofabrik_id}")
if not entry.get("pbf_url"):
raise ValueError(f"Geofabrik extract has no PBF URL: {geofabrik_id}")
existing = session.scalar(select(Source).where(Source.kind == "osm_pbf", Source.url == entry["pbf_url"]))
if existing is not None:
return existing
source = Source(
name=f"Geofabrik {entry['name']}",
kind="osm_pbf",
url=entry["pbf_url"],
country=",".join(entry.get("country_codes") or [])[:8] or None,
license="ODbL / Geofabrik extract terms",
priority="P0 fallback",
mode_scope="public transport OSM routes, stops, and infrastructure",
source_basis="OpenStreetMap / Geofabrik extracts",
notes=_geofabrik_notes(entry, import_updates=import_updates),
)
session.add(source)
session.flush()
if import_updates and entry.get("updates_url"):
update_source = Source(
name=f"Geofabrik {entry['name']} updates",
kind="osm_diff",
url=entry["updates_url"],
country=source.country,
license=source.license,
priority=source.priority,
mode_scope=source.mode_scope,
source_basis="OpenStreetMap / Geofabrik replication diffs",
notes=f"Diff base for Geofabrik extract {entry['id']}; applying diffs to a local base extract is not implemented yet.",
)
session.add(update_source)
return source
def _geofabrik_rows() -> list[dict[str, Any]]:
now = datetime.now(timezone.utc)
expires_at = _CACHE.get("expires_at")
if _CACHE.get("rows") is not None and isinstance(expires_at, datetime) and expires_at > now:
return list(_CACHE["rows"])
response = requests.get(GEOFABRIK_INDEX_URL, timeout=45)
response.raise_for_status()
payload = response.json()
rows = [_normalize_feature(feature) for feature in payload.get("features", [])]
rows = [row for row in rows if row.get("id") and row.get("pbf_url")]
_CACHE["rows"] = rows
_CACHE["expires_at"] = now + timedelta(hours=12)
return list(rows)
def _normalize_feature(feature: dict[str, Any]) -> dict[str, Any]:
props = feature.get("properties") or {}
urls = props.get("urls") or {}
country_codes = props.get("iso3166-1:alpha2") or []
if isinstance(country_codes, str):
country_codes = [country_codes]
return {
"id": str(props.get("id") or ""),
"name": str(props.get("name") or props.get("id") or ""),
"parent": props.get("parent"),
"country_codes": country_codes,
"pbf_url": urls.get("pbf"),
"updates_url": urls.get("updates"),
"taginfo_url": urls.get("taginfo"),
"urls": urls,
}
def _geofabrik_notes(entry: dict[str, Any], *, import_updates: bool) -> str:
parts = [
f"geofabrik_id={entry['id']}",
f"parent={entry.get('parent') or 'root'}",
f"updates_url={entry.get('updates_url') or ''}",
"diff_source_requested=true" if import_updates else "diff_source_requested=false",
"Overlap dedupe is handled by OSM object identity in the route layer; source-specific map layers may still show both extracts.",
]
return "; ".join(parts)

308
app/gtfs_storage.py Normal file
View File

@@ -0,0 +1,308 @@
from __future__ import annotations
import json
import sqlite3
from contextlib import contextmanager
from pathlib import Path
from typing import Iterator, Sequence
from sqlalchemy import func, select
from sqlalchemy.orm import Session
from app.config import settings
from app.models import Dataset, GtfsStopTime
GTFS_STORAGE_METADATA_KEY = "gtfs_storage"
GTFS_STORAGE_MAIN = "main"
GTFS_STORAGE_SIDECAR_STOP_TIMES = "sidecar_stop_times"
GTFS_STOP_TIME_COLUMNS = [
"trip_id",
"stop_id",
"stop_sequence",
"arrival_time",
"departure_time",
"arrival_seconds",
"departure_seconds",
]
SQLITE_IN_CHUNK_SIZE = 800
def effective_gtfs_timetable_storage(value: str | None = None) -> str:
configured = str(value or settings.gtfs_timetable_storage or GTFS_STORAGE_SIDECAR_STOP_TIMES).strip().lower()
if configured in {GTFS_STORAGE_MAIN, "main_db", "main_sqlite", "postgres", "postgresql"}:
return GTFS_STORAGE_MAIN
if settings.is_postgresql_database and not settings.postgres_use_sidecars:
return GTFS_STORAGE_MAIN
return GTFS_STORAGE_SIDECAR_STOP_TIMES
class MissingGtfsSidecar(FileNotFoundError):
def __init__(self, dataset_id: int | None, path: Path | None) -> None:
self.dataset_id = dataset_id
self.path = path
if path is None:
message = f"dataset #{dataset_id} does not reference a GTFS sidecar"
else:
message = f"GTFS sidecar does not exist: {path}"
super().__init__(message)
def dataset_metadata(dataset: Dataset) -> dict:
try:
metadata = json.loads(dataset.metadata_json or "{}")
except json.JSONDecodeError:
return {}
return metadata if isinstance(metadata, dict) else {}
def stop_times_are_sidecar(dataset: Dataset | None) -> bool:
if dataset is None:
return False
storage = dataset_metadata(dataset).get(GTFS_STORAGE_METADATA_KEY)
if not isinstance(storage, dict):
return False
tables = storage.get("tables")
if isinstance(tables, dict):
return tables.get("gtfs_stop_times") == "sidecar"
return storage.get("mode") == GTFS_STORAGE_SIDECAR_STOP_TIMES
def sidecar_path(dataset: Dataset | None) -> Path | None:
if dataset is None:
return None
storage = dataset_metadata(dataset).get(GTFS_STORAGE_METADATA_KEY)
if not isinstance(storage, dict):
return None
value = storage.get("sidecar_path")
if not value:
return None
return Path(str(value))
def dataset_sidecar_paths(dataset: Dataset) -> list[Path]:
path = sidecar_path(dataset)
return [] if path is None else [path]
def missing_sidecar_paths(dataset: Dataset | None) -> list[str]:
if not stop_times_are_sidecar(dataset):
return []
path = sidecar_path(dataset)
if path is None:
dataset_id = "unknown" if dataset is None else str(dataset.id)
return [f"dataset #{dataset_id} has no configured GTFS sidecar path"]
return [] if path.exists() else [str(path)]
def uses_sidecar_stop_times(session: Session, dataset_id: int) -> bool:
return stop_times_are_sidecar(session.get(Dataset, dataset_id))
@contextmanager
def sidecar_connection(dataset: Dataset) -> Iterator[sqlite3.Connection]:
path = sidecar_path(dataset)
if path is None:
raise MissingGtfsSidecar(dataset.id, None)
if not path.exists():
raise MissingGtfsSidecar(dataset.id, path)
connection = sqlite3.connect(f"file:{path}?mode=ro", uri=True)
connection.row_factory = sqlite3.Row
try:
yield connection
finally:
connection.close()
def stop_time_count(session: Session, dataset_id: int) -> int:
dataset = session.get(Dataset, dataset_id)
if stop_times_are_sidecar(dataset):
try:
with sidecar_connection(dataset) as connection:
return int(connection.execute("SELECT COUNT(*) FROM gtfs_stop_times").fetchone()[0] or 0)
except MissingGtfsSidecar:
return 0
return session.scalar(select(func.count()).select_from(GtfsStopTime).where(GtfsStopTime.dataset_id == dataset_id)) or 0
def stop_time_counts_by_dataset(session: Session, dataset_ids: Sequence[int]) -> dict[int, int]:
counts: dict[int, int] = {}
for dataset_id in dataset_ids:
counts[int(dataset_id)] = stop_time_count(session, int(dataset_id))
return counts
def scheduled_stop_ids(session: Session, dataset_id: int, stop_ids: Sequence[str]) -> tuple[str, ...]:
if not stop_ids:
return ()
dataset = session.get(Dataset, dataset_id)
requested = [str(stop_id) for stop_id in stop_ids]
found: set[str] = set()
if stop_times_are_sidecar(dataset):
try:
with sidecar_connection(dataset) as connection:
for chunk in _chunks(requested, SQLITE_IN_CHUNK_SIZE):
placeholders = ", ".join(["?"] * len(chunk))
rows = connection.execute(
f"""
SELECT stop_id
FROM gtfs_stop_times
WHERE stop_id IN ({placeholders})
GROUP BY stop_id
""",
list(chunk),
).fetchall()
found.update(str(row["stop_id"]) for row in rows)
except MissingGtfsSidecar:
return ()
else:
for chunk in _chunks(requested, SQLITE_IN_CHUNK_SIZE):
rows = session.scalars(
select(GtfsStopTime.stop_id)
.where(GtfsStopTime.dataset_id == dataset_id, GtfsStopTime.stop_id.in_(chunk))
.group_by(GtfsStopTime.stop_id)
).all()
found.update(str(row) for row in rows)
return tuple(sorted(found))
def all_scheduled_stop_ids(session: Session, dataset_id: int) -> set[str]:
dataset = session.get(Dataset, dataset_id)
if stop_times_are_sidecar(dataset):
try:
with sidecar_connection(dataset) as connection:
return {
str(row["stop_id"])
for row in connection.execute("SELECT stop_id FROM gtfs_stop_times GROUP BY stop_id").fetchall()
}
except MissingGtfsSidecar:
return set()
return {
str(row)
for row in session.scalars(
select(GtfsStopTime.stop_id)
.where(GtfsStopTime.dataset_id == dataset_id)
.group_by(GtfsStopTime.stop_id)
).all()
}
def scheduled_stop_ids_by_dataset(session: Session, dataset_ids: Sequence[int]) -> dict[int, set[str]]:
return {int(dataset_id): all_scheduled_stop_ids(session, int(dataset_id)) for dataset_id in dataset_ids}
def has_scheduled_stop(session: Session, dataset_id: int, stop_id: str) -> bool:
return bool(scheduled_stop_ids(session, dataset_id, [stop_id]))
def stop_times_by_trip(
session: Session,
dataset_id: int,
trip_ids: Sequence[str],
) -> dict[str, list[GtfsStopTime]]:
if not trip_ids:
return {}
grouped: dict[str, list[GtfsStopTime]] = {}
dataset = session.get(Dataset, dataset_id)
requested = [str(trip_id) for trip_id in trip_ids]
if stop_times_are_sidecar(dataset):
column_sql = ", ".join(GTFS_STOP_TIME_COLUMNS)
try:
with sidecar_connection(dataset) as connection:
for chunk in _chunks(requested, SQLITE_IN_CHUNK_SIZE):
placeholders = ", ".join(["?"] * len(chunk))
rows = connection.execute(
f"""
SELECT {column_sql}
FROM gtfs_stop_times
WHERE trip_id IN ({placeholders})
ORDER BY trip_id, stop_sequence
""",
list(chunk),
).fetchall()
for row in rows:
stop_time = stop_time_from_row(dataset_id, row)
grouped.setdefault(stop_time.trip_id, []).append(stop_time)
except MissingGtfsSidecar:
return {}
return grouped
for chunk in _chunks(requested, SQLITE_IN_CHUNK_SIZE):
rows = session.scalars(
select(GtfsStopTime)
.where(GtfsStopTime.dataset_id == dataset_id, GtfsStopTime.trip_id.in_(chunk))
.order_by(GtfsStopTime.trip_id, GtfsStopTime.stop_sequence)
).all()
for row in rows:
grouped.setdefault(row.trip_id, []).append(row)
return grouped
def stop_times_for_trip_range(
session: Session,
dataset_id: int,
trip_id: str,
start_sequence: int,
end_sequence: int,
) -> list[GtfsStopTime]:
dataset = session.get(Dataset, dataset_id)
if stop_times_are_sidecar(dataset):
column_sql = ", ".join(GTFS_STOP_TIME_COLUMNS)
try:
with sidecar_connection(dataset) as connection:
rows = connection.execute(
f"""
SELECT {column_sql}
FROM gtfs_stop_times
WHERE trip_id = ?
AND stop_sequence >= ?
AND stop_sequence <= ?
ORDER BY stop_sequence
""",
(trip_id, int(start_sequence), int(end_sequence)),
).fetchall()
return [stop_time_from_row(dataset_id, row) for row in rows]
except MissingGtfsSidecar:
return []
return list(
session.scalars(
select(GtfsStopTime)
.where(
GtfsStopTime.dataset_id == dataset_id,
GtfsStopTime.trip_id == trip_id,
GtfsStopTime.stop_sequence >= start_sequence,
GtfsStopTime.stop_sequence <= end_sequence,
)
.order_by(GtfsStopTime.stop_sequence)
).all()
)
def stop_time_from_row(dataset_id: int, row) -> GtfsStopTime:
return GtfsStopTime(
dataset_id=dataset_id,
trip_id=str(row["trip_id"]),
stop_id=str(row["stop_id"]),
stop_sequence=int(row["stop_sequence"]),
arrival_time=row["arrival_time"],
departure_time=row["departure_time"],
arrival_seconds=row["arrival_seconds"],
departure_seconds=row["departure_seconds"],
)
def execute_sidecar_query(session: Session, dataset_id: int, sql: str, params: Sequence[object]) -> list[sqlite3.Row]:
dataset = session.get(Dataset, dataset_id)
if not stop_times_are_sidecar(dataset):
raise ValueError(f"dataset #{dataset_id} does not use sidecar stop_times")
try:
with sidecar_connection(dataset) as connection:
return list(connection.execute(sql, list(params)).fetchall())
except MissingGtfsSidecar:
return []
def _chunks[T](items: Sequence[T], size: int) -> Iterator[Sequence[T]]:
for index in range(0, len(items), size):
yield items[index : index + size]

394
app/harmonization.py Normal file
View File

@@ -0,0 +1,394 @@
from __future__ import annotations
from datetime import date, datetime, timezone
from typing import Any
from sqlalchemy import and_, func, select
from sqlalchemy.orm import Session, aliased
from app.data_management import dataset_row_counts
from app.models import (
CanonicalStopLink,
Dataset,
GtfsCalendar,
GtfsCalendarDate,
GtfsRoute,
GtfsStop,
GtfsStopTime,
GtfsTrip,
RouteMatch,
Source,
)
GTFS_QA_NOTE_PREFIX = "[GTFS QA]"
def gtfs_harmonization_inventory(session: Session) -> dict[str, Any]:
feeds = [_feed_inventory_item(session, source) for source in _gtfs_sources(session)]
summary = {
"sources": len(feeds),
"active_sources": sum(1 for feed in feeds if feed["active_dataset"] is not None),
"datasets": sum(len(feed["datasets"]) for feed in feeds),
"ready": sum(1 for feed in feeds if feed["qa_status"] == "ready"),
"needs_review": sum(1 for feed in feeds if feed["qa_status"] == "needs_review"),
"blocked": sum(1 for feed in feeds if feed["qa_status"] == "blocked"),
}
return {
"summary": summary,
"feeds": feeds,
}
def gtfs_harmonization_feed_detail(session: Session, source_id: int) -> dict[str, Any] | None:
source = session.get(Source, source_id)
if source is None or source.kind != "gtfs":
return None
feed = _feed_inventory_item(session, source)
return {
**feed,
"sections": _feed_sections(feed),
}
def _gtfs_sources(session: Session) -> list[Source]:
return session.scalars(select(Source).where(Source.kind == "gtfs").order_by(Source.country, Source.priority, Source.name, Source.id)).all()
def _feed_inventory_item(session: Session, source: Source) -> dict[str, Any]:
datasets = sorted([dataset for dataset in source.datasets if dataset.kind == "gtfs"], key=lambda item: (not item.is_active, item.created_at, item.id))
active_dataset = next((dataset for dataset in datasets if dataset.is_active), None)
counts = dataset_row_counts(session, active_dataset.id, active_dataset.kind) if active_dataset is not None else {}
validation = _validate_gtfs_dataset(session, source, active_dataset, counts)
overlap = _overlap_summary(session, active_dataset)
service = _service_horizon(session, active_dataset)
issues = [*validation["issues"], *service["issues"], *overlap["issues"], *_license_issues(source)]
qa_status = _qa_status(issues, active_dataset)
return {
"source": _source_payload(source),
"active_dataset": None if active_dataset is None else _dataset_payload(active_dataset, counts),
"datasets": [_dataset_payload(dataset, dataset_row_counts(session, dataset.id, dataset.kind)) for dataset in datasets],
"counts": counts,
"validation": validation,
"service": service,
"overlap": overlap,
"license": _license_payload(source),
"issues": issues,
"qa_status": qa_status,
}
def _source_payload(source: Source) -> dict[str, Any]:
return {
"id": source.id,
"name": source.name,
"country": source.country,
"license": source.license,
"priority": source.priority,
"mode_scope": source.mode_scope,
"source_basis": source.source_basis,
"status": source.status,
"enabled": source.enabled,
"last_error": source.last_error,
"last_run_at": _iso(source.last_run_at),
"url": source.url,
"catalog_entry_id": source.catalog_entry_id,
"notes": source.notes,
"qa_review": _qa_review_payload(source.notes),
}
def _dataset_payload(dataset: Dataset, counts: dict[str, Any]) -> dict[str, Any]:
return {
"id": dataset.id,
"kind": dataset.kind,
"is_active": dataset.is_active,
"status": dataset.status,
"sha256": dataset.sha256,
"local_path": dataset.local_path,
"created_at": _iso(dataset.created_at),
"counts": counts,
}
def _validate_gtfs_dataset(session: Session, source: Source, dataset: Dataset | None, counts: dict[str, Any]) -> dict[str, Any]:
if dataset is None:
return {
"status": "blocked",
"items": [],
"issues": [_issue("missing_active_dataset", "bad", "No active GTFS dataset", "Import this source before harmonization.")],
}
items = [
_metric("Agencies", counts.get("agencies", 0), "bad" if not counts.get("agencies", 0) else "good"),
_metric("Stops", counts.get("stops", 0), "bad" if not counts.get("stops", 0) else "good"),
_metric("Routes", counts.get("routes", 0), "bad" if not counts.get("routes", 0) else "good"),
_metric("Trips", counts.get("trips", 0), "bad" if not counts.get("trips", 0) else "good"),
_metric("Stop times", counts.get("stop_times", 0), "bad" if not counts.get("stop_times", 0) else "good"),
_metric("Shapes", counts.get("shapes", 0), "warn" if not counts.get("shapes", 0) else "good"),
]
missing_coords = _count(session, GtfsStop, dataset.id, (GtfsStop.lat.is_(None) | GtfsStop.lon.is_(None)))
invalid_coords = _count(
session,
GtfsStop,
dataset.id,
(GtfsStop.lat < -90) | (GtfsStop.lat > 90) | (GtfsStop.lon < -180) | (GtfsStop.lon > 180),
)
routes_without_trips = _routes_without_trips(session, dataset.id)
trips_without_stop_times = _trips_without_stop_times(session, dataset.id)
stop_times_without_seconds = _stop_times_without_seconds(session, dataset.id)
route_geometry_missing = _count(session, GtfsRoute, dataset.id, GtfsRoute.geometry_geojson.is_(None))
canonical_links = _count(session, CanonicalStopLink, dataset.id, CanonicalStopLink.object_type == "gtfs_stop")
match_counts = counts.get("match_counts", {}) if isinstance(counts.get("match_counts"), dict) else {}
items.extend(
[
_metric("Stops missing coordinates", missing_coords, "bad" if missing_coords else "good"),
_metric("Stops with invalid coordinates", invalid_coords, "bad" if invalid_coords else "good"),
_metric("Routes without trips", routes_without_trips, "bad" if routes_without_trips else "good"),
_metric("Trips without stop_times", trips_without_stop_times, "bad" if trips_without_stop_times else "good"),
_metric("Stop times without parsed seconds", stop_times_without_seconds, "warn" if stop_times_without_seconds else "good"),
_metric("Routes without geometry", route_geometry_missing, "warn" if route_geometry_missing else "good"),
_metric("Canonical stop links", canonical_links, "warn" if counts.get("stops", 0) and canonical_links == 0 else "good"),
_metric("Route matches", counts.get("matches", 0), "warn" if counts.get("routes", 0) and not counts.get("matches", 0) else "good"),
]
)
issues: list[dict[str, str]] = []
if counts.get("missing_sidecar"):
issues.append(_issue("missing_sidecar", "bad", "GTFS sidecar is missing", "Queue a recovery import for this dataset."))
for key, label in [
("agencies", "No agencies imported"),
("stops", "No stops imported"),
("routes", "No routes imported"),
("trips", "No trips imported"),
("stop_times", "No stop_times imported"),
]:
if not counts.get(key, 0):
issues.append(_issue(f"missing_{key}", "bad", label, "Required GTFS content is absent or failed to import."))
if missing_coords:
issues.append(_issue("missing_stop_coordinates", "bad", f"{missing_coords:,} stops have no coordinates", "Stop coordinates are required for deduplication and routing access."))
if invalid_coords:
issues.append(_issue("invalid_stop_coordinates", "bad", f"{invalid_coords:,} stops have invalid coordinates", "Fix or exclude invalid stop coordinates before publication."))
if routes_without_trips:
issues.append(_issue("routes_without_trips", "warn", f"{routes_without_trips:,} routes have no trips", "These routes cannot contribute timetable service."))
if trips_without_stop_times:
issues.append(_issue("trips_without_stop_times", "bad", f"{trips_without_stop_times:,} trips have no stop_times", "These trips cannot be routed."))
if route_geometry_missing:
issues.append(_issue("route_geometry_missing", "warn", f"{route_geometry_missing:,} routes have no geometry", "Use GTFS shapes, route-layer matching, or stop-by-stop fallback."))
if counts.get("routes", 0) and not counts.get("shapes", 0):
issues.append(_issue("missing_shapes", "warn", "No GTFS shapes imported", "OSM route matching or generated geometry will be needed."))
if counts.get("routes", 0) and not match_counts:
issues.append(_issue("no_route_matching", "warn", "No route-match rows", "Run route matching before route-layer publication QA."))
return {
"status": _qa_status(issues, dataset),
"items": items,
"issues": issues,
}
def _service_horizon(session: Session, dataset: Dataset | None) -> dict[str, Any]:
if dataset is None:
return {"start_date": None, "end_date": None, "days_until_end": None, "items": [], "issues": []}
cal_min, cal_max = session.execute(
select(func.min(GtfsCalendar.start_date), func.max(GtfsCalendar.end_date)).where(GtfsCalendar.dataset_id == dataset.id)
).one()
date_min, date_max = session.execute(
select(func.min(GtfsCalendarDate.date), func.max(GtfsCalendarDate.date)).where(GtfsCalendarDate.dataset_id == dataset.id)
).one()
start_int = _min_int(cal_min, date_min)
end_int = _max_int(cal_max, date_max)
start_date = _gtfs_date(start_int)
end_date = _gtfs_date(end_int)
today = datetime.now(timezone.utc).date()
days_until_end = None if end_date is None else (end_date - today).days
issues: list[dict[str, str]] = []
if end_date is None:
issues.append(_issue("service_horizon_missing", "bad", "No service calendar horizon", "calendar.txt or calendar_dates.txt is required for reliable routing."))
elif days_until_end is not None and days_until_end < 0:
issues.append(_issue("service_horizon_expired", "bad", f"Service expired {abs(days_until_end):,} days ago", "Update or exclude this feed."))
elif days_until_end is not None and days_until_end < 30:
issues.append(_issue("service_horizon_short", "warn", f"Service ends in {days_until_end:,} days", "Update cadence is too close for publication confidence."))
return {
"start_date": None if start_date is None else start_date.isoformat(),
"end_date": None if end_date is None else end_date.isoformat(),
"days_until_end": days_until_end,
"items": [
_metric("Service starts", start_date.isoformat() if start_date else "n/a", "info"),
_metric("Service ends", end_date.isoformat() if end_date else "n/a", "bad" if end_date is None or (days_until_end is not None and days_until_end < 0) else "warn" if days_until_end is not None and days_until_end < 30 else "good"),
],
"issues": issues,
}
def _overlap_summary(session: Session, dataset: Dataset | None) -> dict[str, Any]:
if dataset is None:
return {"items": [], "issues": []}
route_key_overlaps = _shared_route_keys(session, dataset.id)
canonical_stop_overlaps = _shared_canonical_stops(session, dataset.id)
issues: list[dict[str, str]] = []
if route_key_overlaps:
issues.append(_issue("shared_route_keys", "warn", f"{route_key_overlaps:,} route keys also exist in another active feed", "Deduplicate or rank source authority for overlapping routes."))
if canonical_stop_overlaps:
issues.append(_issue("shared_canonical_stops", "warn", f"{canonical_stop_overlaps:,} canonical stops are shared with another active feed", "This is useful linking evidence, but conflicts need review."))
return {
"items": [
_metric("Shared route keys", route_key_overlaps, "warn" if route_key_overlaps else "good"),
_metric("Shared canonical stops", canonical_stop_overlaps, "warn" if canonical_stop_overlaps else "good"),
],
"issues": issues,
}
def _license_payload(source: Source) -> dict[str, Any]:
text = (source.license or "").strip()
unknown = not text or "unknown" in text.lower()
return {
"label": text or "unknown",
"redistribution_status": "unknown" if unknown else "review_required",
"tone": "warn" if unknown else "info",
}
def _license_issues(source: Source) -> list[dict[str, str]]:
if _license_payload(source)["redistribution_status"] == "unknown":
return [_issue("license_unknown", "warn", "License/redistribution status is unknown", "Publication needs explicit import, derivation, redistribution, and attribution flags.")]
return []
def _qa_review_payload(notes: str | None) -> dict[str, Any]:
if not notes:
return {"status": "unreviewed", "note": "", "updated_at": None}
for line in str(notes).splitlines():
if not line.startswith(GTFS_QA_NOTE_PREFIX):
continue
payload: dict[str, str] = {}
for part in line[len(GTFS_QA_NOTE_PREFIX) :].strip().split(";"):
if "=" not in part:
continue
key, value = part.split("=", 1)
payload[key.strip()] = value.strip()
return {
"status": payload.get("status") or "unreviewed",
"note": payload.get("note") or "",
"updated_at": payload.get("updated_at"),
}
return {"status": "unreviewed", "note": "", "updated_at": None}
def _routes_without_trips(session: Session, dataset_id: int) -> int:
trip_exists = select(GtfsTrip.id).where(GtfsTrip.dataset_id == dataset_id, GtfsTrip.route_id == GtfsRoute.route_id).exists()
return int(session.scalar(select(func.count()).select_from(GtfsRoute).where(GtfsRoute.dataset_id == dataset_id, ~trip_exists)) or 0)
def _trips_without_stop_times(session: Session, dataset_id: int) -> int:
stop_time_exists = select(GtfsStopTime.id).where(GtfsStopTime.dataset_id == dataset_id, GtfsStopTime.trip_id == GtfsTrip.trip_id).exists()
return int(session.scalar(select(func.count()).select_from(GtfsTrip).where(GtfsTrip.dataset_id == dataset_id, ~stop_time_exists)) or 0)
def _stop_times_without_seconds(session: Session, dataset_id: int) -> int:
return int(
session.scalar(
select(func.count())
.select_from(GtfsStopTime)
.where(GtfsStopTime.dataset_id == dataset_id, GtfsStopTime.arrival_seconds.is_(None), GtfsStopTime.departure_seconds.is_(None))
)
or 0
)
def _shared_route_keys(session: Session, dataset_id: int) -> int:
current = aliased(GtfsRoute)
other = aliased(GtfsRoute)
other_dataset = aliased(Dataset)
return int(
session.scalar(
select(func.count(func.distinct(current.route_key)))
.select_from(current)
.join(other, and_(other.route_key == current.route_key, other.dataset_id != current.dataset_id))
.join(other_dataset, other_dataset.id == other.dataset_id)
.where(
current.dataset_id == dataset_id,
current.route_key.is_not(None),
current.route_key != "",
other_dataset.kind == "gtfs",
other_dataset.is_active.is_(True),
)
)
or 0
)
def _shared_canonical_stops(session: Session, dataset_id: int) -> int:
current = aliased(CanonicalStopLink)
other = aliased(CanonicalStopLink)
other_dataset = aliased(Dataset)
return int(
session.scalar(
select(func.count(func.distinct(current.canonical_stop_id)))
.select_from(current)
.join(other, and_(other.canonical_stop_id == current.canonical_stop_id, other.dataset_id != current.dataset_id))
.join(other_dataset, other_dataset.id == other.dataset_id)
.where(
current.dataset_id == dataset_id,
current.object_type == "gtfs_stop",
other.object_type == "gtfs_stop",
other_dataset.kind == "gtfs",
other_dataset.is_active.is_(True),
)
)
or 0
)
def _count(session: Session, model: Any, dataset_id: int, *criteria: Any) -> int:
stmt = select(func.count()).select_from(model).where(model.dataset_id == dataset_id)
if criteria:
stmt = stmt.where(*criteria)
return int(session.scalar(stmt) or 0)
def _metric(label: str, value: Any, tone: str = "info", description: str = "") -> dict[str, Any]:
return {"label": label, "value": value, "tone": tone, "description": description}
def _issue(issue_id: str, severity: str, title: str, detail: str) -> dict[str, str]:
return {"id": issue_id, "severity": severity, "title": title, "detail": detail}
def _qa_status(issues: list[dict[str, str]], dataset: Dataset | None) -> str:
if dataset is None or any(issue.get("severity") == "bad" for issue in issues):
return "blocked"
if any(issue.get("severity") == "warn" for issue in issues):
return "needs_review"
return "ready"
def _feed_sections(feed: dict[str, Any]) -> list[dict[str, Any]]:
return [
{"id": "validation", "title": "GTFS Validation", "items": feed["validation"]["items"]},
{"id": "service", "title": "Service Horizon", "items": feed["service"]["items"]},
{"id": "overlap", "title": "Overlap and Deduplication", "items": feed["overlap"]["items"]},
{"id": "license", "title": "License", "items": [_metric("Redistribution", feed["license"]["redistribution_status"], feed["license"]["tone"]), _metric("License", feed["license"]["label"], feed["license"]["tone"])]},
]
def _gtfs_date(value: int | None) -> date | None:
if value is None:
return None
try:
return datetime.strptime(str(int(value)), "%Y%m%d").date()
except ValueError:
return None
def _min_int(*values: int | None) -> int | None:
clean = [int(value) for value in values if value is not None]
return min(clean) if clean else None
def _max_int(*values: int | None) -> int | None:
clean = [int(value) for value in values if value is not None]
return max(clean) if clean else None
def _iso(value: datetime | None) -> str | None:
return None if value is None else value.isoformat()

360
app/itineraries.py Normal file
View File

@@ -0,0 +1,360 @@
from __future__ import annotations
import json
from datetime import datetime, timezone
from typing import Any
from sqlalchemy import select
from sqlalchemy.orm import Session
from app.journey import duration_minutes_ceil, find_journeys, format_duration_label
from app.models import Itinerary, ItineraryLeg, TravelRequest
from app.routing import route_between_points
def generate_itineraries(
db: Session,
*,
from_stop_id: str,
to_stop_id: str,
via_stop_id: str | None,
departure: str,
service_date: str | None,
max_transfers: int,
transfer_seconds: int,
limit: int,
source_ids: list[int] | None,
preferences: dict[str, Any] | None = None,
) -> dict:
request = TravelRequest(
origin_stop_id=from_stop_id,
destination_stop_id=to_stop_id,
via_stop_id=via_stop_id or None,
departure_time=departure,
service_date=service_date or None,
max_transfers=max(0, max_transfers),
transfer_seconds=max(0, transfer_seconds),
source_filter=",".join(str(source_id) for source_id in source_ids or []) or None,
preferences_json=json.dumps(preferences or {}, separators=(",", ":")),
)
db.add(request)
db.flush()
journey_result = find_journeys(
db=db,
from_stop_id=from_stop_id,
to_stop_id=to_stop_id,
via_stop_id=via_stop_id,
departure=departure,
service_date=service_date,
max_transfers=max(0, max_transfers),
transfer_seconds=max(0, transfer_seconds),
limit=limit,
source_ids=source_ids,
)
itineraries: list[Itinerary] = []
for index, journey in enumerate(journey_result.get("journeys", []), start=1):
itinerary = _journey_itinerary(request.id, journey, index)
db.add(itinerary)
db.flush()
_add_journey_legs(db, itinerary.id, journey)
itineraries.append(itinerary)
car_itinerary = _car_itinerary(db, request.id, journey_result.get("from"), journey_result.get("to"))
if car_itinerary is not None:
db.add(car_itinerary)
db.flush()
_add_routing_leg(db, car_itinerary.id, car_itinerary)
itineraries.append(car_itinerary)
placeholders = _placeholder_itineraries(
request.id,
journey_result.get("from"),
journey_result.get("to"),
service_date=service_date,
include_car=car_itinerary is None,
)
for itinerary in placeholders:
db.add(itinerary)
db.flush()
itineraries.append(itinerary)
db.flush()
return {
"request": travel_request_payload(request),
"journey_context": {
"from": journey_result.get("from"),
"to": journey_result.get("to"),
"via": journey_result.get("via"),
"sources": journey_result.get("sources", []),
},
"itineraries": [itinerary_payload(db, itinerary) for itinerary in itineraries],
}
def travel_request_payload(request: TravelRequest) -> dict[str, Any]:
return {
"id": request.id,
"origin_stop_id": request.origin_stop_id,
"destination_stop_id": request.destination_stop_id,
"via_stop_id": request.via_stop_id,
"departure_time": request.departure_time,
"service_date": request.service_date,
"max_transfers": request.max_transfers,
"transfer_seconds": request.transfer_seconds,
"source_filter": request.source_filter,
"preferences": _json_dict(request.preferences_json),
"created_at": request.created_at.isoformat() if request.created_at else None,
}
def itinerary_payload(db: Session, itinerary: Itinerary) -> dict[str, Any]:
legs = db.scalars(
select(ItineraryLeg)
.where(ItineraryLeg.itinerary_id == itinerary.id)
.order_by(ItineraryLeg.sequence)
).all()
return {
"id": itinerary.id,
"request_id": itinerary.request_id,
"title": itinerary.title,
"family": itinerary.family,
"status": itinerary.status,
"saved": itinerary.saved,
"summary": _json_dict(itinerary.summary_json),
"score": _json_dict(itinerary.score_json),
"payload": _json_dict(itinerary.payload_json),
"legs": [itinerary_leg_payload(leg) for leg in legs],
"created_at": itinerary.created_at.isoformat() if itinerary.created_at else None,
"updated_at": itinerary.updated_at.isoformat() if itinerary.updated_at else None,
}
def itinerary_leg_payload(leg: ItineraryLeg) -> dict[str, Any]:
return {
"id": leg.id,
"itinerary_id": leg.itinerary_id,
"sequence": leg.sequence,
"mode": leg.mode,
"route_ref": leg.route_ref,
"route_name": leg.route_name,
"from_name": leg.from_name,
"to_name": leg.to_name,
"departure_time": leg.departure_time,
"arrival_time": leg.arrival_time,
"locked": leg.locked,
"payload": _json_dict(leg.payload_json),
}
def set_itinerary_saved(db: Session, itinerary: Itinerary, saved: bool) -> dict[str, Any]:
itinerary.saved = saved
itinerary.status = "saved" if saved else "candidate"
itinerary.updated_at = datetime.now(timezone.utc)
db.flush()
return itinerary_payload(db, itinerary)
def set_leg_locked(db: Session, leg: ItineraryLeg, locked: bool) -> dict[str, Any]:
leg.locked = locked
itinerary = db.get(Itinerary, leg.itinerary_id)
if itinerary is not None:
itinerary.updated_at = datetime.now(timezone.utc)
db.flush()
return itinerary_leg_payload(leg)
def recent_itineraries(db: Session, *, saved_only: bool = False, limit: int = 30) -> list[dict[str, Any]]:
stmt = select(Itinerary).order_by(Itinerary.updated_at.desc(), Itinerary.id.desc())
if saved_only:
stmt = stmt.where(Itinerary.saved.is_(True))
rows = db.scalars(stmt.limit(max(1, min(limit, 100)))).all()
return [itinerary_payload(db, itinerary) for itinerary in rows]
def _journey_itinerary(request_id: int, journey: dict, index: int) -> Itinerary:
score = _journey_score(journey)
summary = {
"departure_time": journey.get("departure_time"),
"arrival_time": journey.get("arrival_time"),
"duration_minutes": journey.get("duration_minutes"),
"duration_label": journey.get("duration_label"),
"transfers": journey.get("transfers"),
"leg_count": len(journey.get("legs", [])),
"route_refs": [leg.get("route_ref") or leg.get("route_id") for leg in journey.get("legs", [])],
}
return Itinerary(
request_id=request_id,
title=f"Public transport option {index}",
family="public_transport",
status="candidate",
saved=False,
summary_json=json.dumps(summary, separators=(",", ":")),
score_json=json.dumps(score, separators=(",", ":")),
payload_json=json.dumps({"journey": journey}, separators=(",", ":")),
)
def _add_journey_legs(db: Session, itinerary_id: int, journey: dict) -> None:
for index, leg in enumerate(journey.get("legs", []), start=1):
db.add(
ItineraryLeg(
itinerary_id=itinerary_id,
sequence=index,
mode=leg.get("mode"),
route_ref=leg.get("route_ref"),
route_name=leg.get("route_name"),
from_name=(leg.get("from") or {}).get("name") or (leg.get("from") or {}).get("stop_id"),
to_name=(leg.get("to") or {}).get("name") or (leg.get("to") or {}).get("stop_id"),
departure_time=leg.get("departure_time"),
arrival_time=leg.get("arrival_time"),
locked=False,
payload_json=json.dumps({"journey_leg": leg}, separators=(",", ":")),
)
)
def _car_itinerary(db: Session, request_id: int, from_stop: dict | None, to_stop: dict | None) -> Itinerary | None:
from_lon = _float_or_none((from_stop or {}).get("lon"))
from_lat = _float_or_none((from_stop or {}).get("lat"))
to_lon = _float_or_none((to_stop or {}).get("lon"))
to_lat = _float_or_none((to_stop or {}).get("lat"))
if None in {from_lon, from_lat, to_lon, to_lat}:
return None
try:
route = route_between_points(
db,
from_lon=from_lon,
from_lat=from_lat,
to_lon=to_lon,
to_lat=to_lat,
mode="drive",
max_visited=300_000,
)
except Exception: # noqa: BLE001 - car comparison is optional
return None
duration_seconds = _float_or_none(route.get("duration_seconds"))
duration_minutes = duration_minutes_ceil(duration_seconds)
distance_m = _float_or_none(route.get("distance_m"))
summary = {
"from": (from_stop or {}).get("name") or (from_stop or {}).get("stop_id") or "origin",
"to": (to_stop or {}).get("name") or (to_stop or {}).get("stop_id") or "destination",
"duration_minutes": duration_minutes,
"duration_label": format_duration_label(duration_seconds),
"distance_km": None if distance_m is None else round(distance_m / 1000, 1),
"transfers": 0,
"engine": route.get("engine"),
}
score = {
"duration_minutes": duration_minutes,
"transfers": 0,
"complexity": 1,
"emissions": "high",
"estimated_cost": None,
}
return Itinerary(
request_id=request_id,
title="Car only",
family="car",
status="candidate",
saved=False,
summary_json=json.dumps(summary, separators=(",", ":")),
score_json=json.dumps(score, separators=(",", ":")),
payload_json=json.dumps({"routing": route}, separators=(",", ":")),
)
def _add_routing_leg(db: Session, itinerary_id: int, itinerary: Itinerary) -> None:
payload = _json_dict(itinerary.payload_json)
route = payload.get("routing") if isinstance(payload, dict) else None
if not isinstance(route, dict):
return
db.add(
ItineraryLeg(
itinerary_id=itinerary_id,
sequence=1,
mode=str(route.get("mode") or "drive"),
route_ref=None,
route_name="Road route",
from_name=str((route.get("start_node") or {}).get("osm_node_id") or "origin"),
to_name=str((route.get("target_node") or {}).get("osm_node_id") or "destination"),
departure_time=None,
arrival_time=None,
locked=False,
payload_json=json.dumps({"routing_leg": route}, separators=(",", ":")),
)
)
def _placeholder_itineraries(
request_id: int,
from_stop: dict | None,
to_stop: dict | None,
*,
service_date: str | None,
include_car: bool = True,
) -> list[Itinerary]:
from_name = (from_stop or {}).get("name") or (from_stop or {}).get("stop_id") or "origin"
to_name = (to_stop or {}).get("name") or (to_stop or {}).get("stop_id") or "destination"
placeholders = [
("car_ferry", "Car + ferry", "Needs ferry-port candidate graph", {"complexity": 3, "emissions": "medium_high"}),
("flight_access", "Flight + airport access", "Needs airport/flight schedule connector", {"complexity": 4, "emissions": "high"}),
("rail_long_stay", "Rail with adjustable city stop", "Use via stop and leg locking to refine", {"complexity": 3, "emissions": "low"}),
]
if include_car:
placeholders.insert(0, ("car", "Car only", "Needs road-routing connector", {"complexity": 1, "emissions": "high"}))
rows = []
for family, title, note, score in placeholders:
summary = {
"from": from_name,
"to": to_name,
"service_date": service_date,
"note": note,
"duration_minutes": None,
"transfers": None,
}
rows.append(
Itinerary(
request_id=request_id,
title=title,
family=family,
status="placeholder",
saved=False,
summary_json=json.dumps(summary, separators=(",", ":")),
score_json=json.dumps(score, separators=(",", ":")),
payload_json=json.dumps({"placeholder": True, "note": note}, separators=(",", ":")),
)
)
return rows
def _float_or_none(value: object) -> float | None:
try:
return None if value is None else float(value)
except (TypeError, ValueError):
return None
def _journey_score(journey: dict) -> dict[str, Any]:
modes = [leg.get("mode") for leg in journey.get("legs", [])]
duration = journey.get("duration_minutes")
transfers = int(journey.get("transfers") or 0)
railish = sum(1 for mode in modes if mode in {"train", "subway", "tram", "light_rail"})
busish = sum(1 for mode in modes if mode in {"bus", "coach", "trolleybus"})
emissions_hint = "low" if railish >= busish else "medium"
return {
"duration_minutes": duration,
"transfers": transfers,
"complexity": transfers + len(modes),
"emissions": emissions_hint,
"overnight": False,
"estimated_cost": None,
}
def _json_dict(value: str | None) -> dict[str, Any]:
try:
data = json.loads(value or "{}")
except json.JSONDecodeError:
return {}
return data if isinstance(data, dict) else {}

1932
app/jobs.py Normal file

File diff suppressed because it is too large Load Diff

5385
app/journey.py Normal file

File diff suppressed because it is too large Load Diff

717
app/journey_search.py Normal file
View File

@@ -0,0 +1,717 @@
from __future__ import annotations
import json
import hashlib
import threading
import time
import uuid
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from datetime import date, datetime, timedelta, timezone
from typing import Any
from sqlalchemy import select
from app.address_search import is_location_token
from app.db import SessionLocal
from app.journey import find_journeys, parse_service_date, resolve_location_summary
from app.models import JourneySearchCache
from app.routing import direct_route_between_points, route_between_points
MAX_PROGRESSIVE_TRANSFERS = 5
TRANSIT_STAGE_CACHE_TTL_SECONDS = 5 * 60
TRANSIT_STAGE_CACHE_MAX_ENTRIES = 256
PROGRESSIVE_SEARCH_CACHE_TTL_SECONDS = 10 * 60
PROGRESSIVE_SEARCH_CACHE_MAX_ENTRIES = 128
JOURNEY_SEARCH_CACHE_VERSION = "journey-search-v7"
_executor = ThreadPoolExecutor(max_workers=2, thread_name_prefix="journey-search")
_lock = threading.RLock()
_searches: dict[str, "_SearchState"] = {}
_progressive_search_inflight: dict[tuple[object, ...], str] = {}
_transit_stage_cache: dict[tuple[object, ...], tuple[float, dict[str, Any]]] = {}
_progressive_search_cache: dict[tuple[object, ...], tuple[float, dict[str, Any]]] = {}
@dataclass
class _SearchState:
id: str
request: dict[str, Any]
cache_key: tuple[object, ...] | None = None
status: str = "queued"
message: str = "Queued."
stage: str = "queued"
journeys: list[dict] = field(default_factory=list)
routing: dict[str, Any] | None = None
context: dict[str, Any] = field(default_factory=dict)
error: str | None = None
created_at: float = field(default_factory=time.time)
updated_at: float = field(default_factory=time.time)
complete: bool = False
cancelled: bool = False
def start_journey_search(request: dict[str, Any]) -> dict[str, Any]:
key = _progressive_cache_key(request)
cached = _progressive_cache_get(key)
search_id = uuid.uuid4().hex
state = _SearchState(id=search_id, request=dict(request), cache_key=key)
if cached is not None:
_apply_cached_payload(state, cached)
with _lock:
_prune_old_searches()
if cached is None:
existing_search_id = _progressive_search_inflight.get(key)
existing_state = None if existing_search_id is None else _searches.get(existing_search_id)
if existing_state is not None and not existing_state.complete and not existing_state.cancelled:
return _payload(existing_state)
_progressive_search_inflight[key] = search_id
_searches[search_id] = state
if cached is None:
_executor.submit(_run_search, search_id)
return journey_search_payload(search_id)
def journey_search_payload(search_id: str) -> dict[str, Any]:
with _lock:
state = _searches.get(search_id)
if state is None:
raise KeyError(search_id)
return _payload(state)
def cancel_journey_search(search_id: str) -> dict[str, Any]:
with _lock:
state = _searches.get(search_id)
if state is None:
raise KeyError(search_id)
state.cancelled = True
if not state.complete:
state.status = "cancelled"
state.message = "Search cancelled."
state.complete = True
state.updated_at = time.time()
_clear_inflight_search_locked(state)
return _payload(state)
def _run_search(search_id: str) -> None:
with _lock:
state = _searches.get(search_id)
if state is None or state.cancelled:
return
state.status = "running"
state.stage = "starting"
state.message = "Starting search..."
state.updated_at = time.time()
request = dict(state.request)
try:
mode = str(request.get("mode") or "transit")
if mode in {"walk", "drive", "car"}:
_run_point_route_search(search_id, "drive" if mode == "car" else mode, request)
else:
_run_transit_search(search_id, request)
except Exception as exc: # noqa: BLE001 - report progressive-search failure to client
_publish_error(search_id, str(exc))
def _run_transit_search(search_id: str, request: dict[str, Any]) -> None:
direct_only = bool(request.get("direct_only"))
limit = max(3, min(int(request.get("limit") or 5), 10))
transfer_seconds = max(0, min(int(request.get("transfer_seconds") or 120), 3600))
source_ids = _csv_ints(request.get("source_id"))
service_date = request.get("service_date") or None
stages = [0] if direct_only else list(range(0, MAX_PROGRESSIVE_TRANSFERS + 1))
address_search = is_location_token(request.get("from_stop_id")) or is_location_token(request.get("to_stop_id"))
stage_limit = limit if address_search else max(limit, 10)
merged: dict[str, dict] = {}
context: dict[str, Any] = {}
diagnostics: dict[str, Any] = {"stages": []}
best_count = 0
stale_stages = 0
for transfers in stages:
if _is_cancelled(search_id):
return
label = "direct" if transfers == 0 else f"up to {transfers} transfer{'s' if transfers != 1 else ''}"
_publish_status(search_id, "running", f"Searching {label}...", f"transfers_{transfers}")
stage_started_at = time.monotonic()
with SessionLocal() as db:
result = _cached_find_journeys(
db,
from_stop_id=str(request.get("from_stop_id") or ""),
to_stop_id=str(request.get("to_stop_id") or ""),
via_stop_id=request.get("via_stop_id") or None,
source_ids=source_ids,
departure=str(request.get("departure") or "08:00"),
service_date=service_date,
max_transfers=transfers,
transfer_seconds=transfer_seconds,
limit=stage_limit,
)
cache_status = str(result.pop("_cache_status", "miss"))
elapsed_ms = int((time.monotonic() - stage_started_at) * 1000)
stage_diagnostics = {
"transfers": transfers,
"cache": cache_status,
"elapsed_ms": elapsed_ms,
"journeys": len(result.get("journeys") or []),
}
result_diagnostics = result.get("diagnostics")
if isinstance(result_diagnostics, dict):
stage_diagnostics["details"] = result_diagnostics
diagnostics["stages"].append(stage_diagnostics)
context = _context_from_result(result)
context["diagnostics"] = diagnostics
before = len(merged)
for journey in result.get("journeys") or []:
merged.setdefault(_journey_key(journey), journey)
ranked = _select_diverse_journeys(_rank_journeys(merged.values(), str(request.get("ranking") or "recommended")), limit=limit)
_publish_results(
search_id,
journeys=ranked,
context=context,
status="running",
stage=f"transfers_{transfers}",
message=f"Found {len(ranked)} option{'s' if len(ranked) != 1 else ''}; still searching..." if not direct_only else "Direct search complete.",
)
if len(merged) <= before and ranked:
stale_stages += 1
else:
stale_stages = 0
best_count = max(best_count, len(ranked))
if ranked and stale_stages >= 2 and transfers >= 2:
break
if _major_hub_address_stage_is_complete(result_diagnostics, ranked, transfers=transfers, limit=limit):
break
complete_message = (
f"Search complete. Found {best_count} option{'s' if best_count != 1 else ''}."
if best_count
else "Search complete. No route found in the imported timetable."
)
_publish_complete(search_id, message=complete_message)
payload = journey_search_payload(search_id)
if payload.get("status") == "complete" and not payload.get("error"):
_progressive_cache_put(_progressive_cache_key(request), payload)
def _major_hub_address_stage_is_complete(
diagnostics: dict[str, Any] | None,
ranked: list[dict],
*,
transfers: int,
limit: int,
) -> bool:
if transfers < 1 or not ranked or not isinstance(diagnostics, dict):
return False
address_access = diagnostics.get("address_access")
if not isinstance(address_access, dict) or not address_access.get("major_hubs"):
return False
return len(ranked) >= min(3, limit)
def _run_point_route_search(search_id: str, mode: str, request: dict[str, Any]) -> None:
_publish_status(search_id, "running", f"Searching {mode} route...", mode)
with SessionLocal() as db:
from_location = resolve_location_summary(db, str(request.get("from_stop_id") or ""), source_ids=_csv_ints(request.get("source_id")))
to_location = resolve_location_summary(db, str(request.get("to_stop_id") or ""), source_ids=_csv_ints(request.get("source_id")))
if from_location.lon is None or from_location.lat is None:
raise ValueError("Selected start has no coordinates.")
if to_location.lon is None or to_location.lat is None:
raise ValueError("Selected destination has no coordinates.")
try:
route = route_between_points(
db,
from_lon=float(from_location.lon),
from_lat=float(from_location.lat),
to_lon=float(to_location.lon),
to_lat=float(to_location.lat),
mode=mode,
max_visited=300_000,
)
message = f"{mode.title()} route found."
except Exception as exc: # noqa: BLE001 - point routing should still return an approximate connector
route = direct_route_between_points(
db,
from_lon=float(from_location.lon),
from_lat=float(from_location.lat),
to_lon=float(to_location.lon),
to_lat=float(to_location.lat),
mode=mode,
reason=str(exc),
)
message = f"{mode.title()} route approximated."
context = {
"from": _stop_payload(from_location),
"to": _stop_payload(to_location),
"mode": mode,
}
_publish_routing(search_id, route, context=context, message=message)
_publish_complete(search_id, message=f"{mode.title()} route complete.")
payload = journey_search_payload(search_id)
if payload.get("status") == "complete" and not payload.get("error"):
_progressive_cache_put(_progressive_cache_key(request), payload)
def _cached_find_journeys(
db,
*,
from_stop_id: str,
to_stop_id: str,
via_stop_id: object,
source_ids: list[int] | None,
departure: str,
service_date: object,
max_transfers: int,
transfer_seconds: int,
limit: int,
) -> dict[str, Any]:
key = (
from_stop_id,
to_stop_id,
str(via_stop_id or ""),
tuple(sorted(int(source_id) for source_id in source_ids or [])),
departure,
str(service_date or ""),
int(max_transfers),
int(transfer_seconds),
int(limit),
)
now = time.monotonic()
with _lock:
cached = _transit_stage_cache.get(key)
if cached is not None:
expires_at, payload = cached
if expires_at > now:
return _with_cache_status(payload, "memory")
_transit_stage_cache.pop(key, None)
durable = _durable_cache_get("transit_stage", key)
if durable is not None:
with _lock:
_transit_stage_cache[key] = (now + TRANSIT_STAGE_CACHE_TTL_SECONDS, json.loads(json.dumps(durable)))
_prune_timed_cache(_transit_stage_cache, TRANSIT_STAGE_CACHE_MAX_ENTRIES)
return _with_cache_status(durable, "persistent")
result = find_journeys(
db=db,
from_stop_id=from_stop_id,
to_stop_id=to_stop_id,
via_stop_id=via_stop_id,
source_ids=source_ids,
departure=departure,
service_date=service_date,
max_transfers=max_transfers,
transfer_seconds=transfer_seconds,
limit=limit,
)
stored_result = json.loads(json.dumps(result))
with _lock:
_transit_stage_cache[key] = (now + TRANSIT_STAGE_CACHE_TTL_SECONDS, stored_result)
_prune_timed_cache(_transit_stage_cache, TRANSIT_STAGE_CACHE_MAX_ENTRIES)
_durable_cache_put("transit_stage", key, stored_result, ttl_seconds=TRANSIT_STAGE_CACHE_TTL_SECONDS)
return _with_cache_status(result, "miss")
def _with_cache_status(payload: dict[str, Any], cache_status: str) -> dict[str, Any]:
copied = json.loads(json.dumps(payload))
copied["_cache_status"] = cache_status
return copied
def _prune_timed_cache(cache: dict[tuple[object, ...], tuple[float, dict[str, Any]]], max_entries: int) -> None:
if len(cache) <= max_entries:
return
oldest = sorted(cache.items(), key=lambda item: item[1][0])[: len(cache) - max_entries]
for old_key, _ in oldest:
cache.pop(old_key, None)
def _durable_cache_get(cache_type: str, key: tuple[object, ...]) -> dict[str, Any] | None:
storage_key = _durable_cache_key(cache_type, key)
now = datetime.now(timezone.utc)
try:
with SessionLocal() as session:
row = session.scalar(select(JourneySearchCache).where(JourneySearchCache.cache_key == storage_key))
if row is None:
return None
expires_at = _as_utc(row.expires_at)
if expires_at is None or expires_at <= now:
session.delete(row)
session.commit()
return None
return json.loads(row.payload_json)
except Exception: # noqa: BLE001 - cache misses must not break journey search
return None
def _durable_cache_put(cache_type: str, key: tuple[object, ...], payload: dict[str, Any], *, ttl_seconds: int) -> None:
storage_key = _durable_cache_key(cache_type, key)
now = datetime.now(timezone.utc)
expires_at = now + timedelta(seconds=max(1, int(ttl_seconds)))
try:
with SessionLocal() as session:
row = session.scalar(select(JourneySearchCache).where(JourneySearchCache.cache_key == storage_key))
if row is None:
row = JourneySearchCache(
cache_key=storage_key,
cache_type=cache_type,
payload_json=json.dumps(payload, separators=(",", ":")),
created_at=now,
updated_at=now,
expires_at=expires_at,
)
session.add(row)
else:
row.cache_type = cache_type
row.payload_json = json.dumps(payload, separators=(",", ":"))
row.updated_at = now
row.expires_at = expires_at
session.commit()
except Exception: # noqa: BLE001 - cache writes are best-effort
return
def _durable_cache_key(cache_type: str, key: tuple[object, ...]) -> str:
raw = json.dumps(
{
"version": JOURNEY_SEARCH_CACHE_VERSION,
"cache_type": cache_type,
"key": _json_safe(key),
},
sort_keys=True,
separators=(",", ":"),
)
return hashlib.sha256(raw.encode("utf-8")).hexdigest()
def _json_safe(value: object) -> object:
if isinstance(value, tuple):
return [_json_safe(item) for item in value]
if isinstance(value, list):
return [_json_safe(item) for item in value]
if isinstance(value, dict):
return {str(key): _json_safe(item) for key, item in sorted(value.items(), key=lambda item: str(item[0]))}
if isinstance(value, (str, int, float, bool)) or value is None:
return value
if isinstance(value, (date, datetime)):
return value.isoformat()
return str(value)
def _as_utc(value: datetime | None) -> datetime | None:
if value is None:
return None
if value.tzinfo is None:
return value.replace(tzinfo=timezone.utc)
return value.astimezone(timezone.utc)
def _progressive_cache_key(request: dict[str, Any]) -> tuple[object, ...]:
source_ids = _csv_ints(request.get("source_id"))
return (
str(request.get("mode") or "transit"),
str(request.get("from_stop_id") or ""),
str(request.get("to_stop_id") or ""),
str(request.get("via_stop_id") or ""),
tuple(sorted(int(source_id) for source_id in source_ids or [])),
str(request.get("departure") or "08:00"),
str(request.get("service_date") or ""),
bool(request.get("direct_only")),
str(request.get("ranking") or "recommended"),
int(request.get("transfer_seconds") or 120),
max(3, min(int(request.get("limit") or 5), 10)),
)
def _progressive_cache_get(key: tuple[object, ...]) -> dict[str, Any] | None:
now = time.monotonic()
with _lock:
cached = _progressive_search_cache.get(key)
if cached is not None:
expires_at, payload = cached
if expires_at > now:
copied = json.loads(json.dumps(payload))
copied["cache_status"] = "memory"
return copied
_progressive_search_cache.pop(key, None)
durable = _durable_cache_get("progressive", key)
if durable is None:
return None
with _lock:
_progressive_search_cache[key] = (now + PROGRESSIVE_SEARCH_CACHE_TTL_SECONDS, json.loads(json.dumps(durable)))
_prune_timed_cache(_progressive_search_cache, PROGRESSIVE_SEARCH_CACHE_MAX_ENTRIES)
copied = json.loads(json.dumps(durable))
copied["cache_status"] = "persistent"
return copied
def _progressive_cache_put(key: tuple[object, ...], payload: dict[str, Any]) -> None:
stored_payload = json.loads(json.dumps(payload))
stored_payload.pop("cache_status", None)
with _lock:
_progressive_search_cache[key] = (time.monotonic() + PROGRESSIVE_SEARCH_CACHE_TTL_SECONDS, stored_payload)
_prune_timed_cache(_progressive_search_cache, PROGRESSIVE_SEARCH_CACHE_MAX_ENTRIES)
_durable_cache_put("progressive", key, stored_payload, ttl_seconds=PROGRESSIVE_SEARCH_CACHE_TTL_SECONDS)
def _apply_cached_payload(state: _SearchState, payload: dict[str, Any]) -> None:
state.status = str(payload.get("status") or "complete")
state.message = "Cached result."
state.stage = str(payload.get("stage") or "cached")
state.journeys = json.loads(json.dumps(payload.get("journeys") or []))
state.routing = json.loads(json.dumps(payload.get("routing"))) if payload.get("routing") is not None else None
state.context = {
key: value
for key, value in payload.items()
if key not in {"search_id", "status", "stage", "message", "complete", "error", "journeys", "routing", "created_at", "updated_at"}
}
state.error = None
state.complete = True
state.updated_at = time.time()
def _publish_status(search_id: str, status: str, message: str, stage: str) -> None:
with _lock:
state = _searches.get(search_id)
if state is None or state.cancelled:
return
state.status = status
state.message = message
state.stage = stage
state.updated_at = time.time()
def _publish_results(search_id: str, *, journeys: list[dict], context: dict[str, Any], status: str, stage: str, message: str) -> None:
with _lock:
state = _searches.get(search_id)
if state is None or state.cancelled:
return
state.status = status
state.stage = stage
state.message = message
state.journeys = list(journeys)
state.context = dict(context)
state.updated_at = time.time()
def _publish_routing(search_id: str, routing: dict[str, Any], *, context: dict[str, Any], message: str) -> None:
with _lock:
state = _searches.get(search_id)
if state is None or state.cancelled:
return
state.status = "running"
state.stage = str(routing.get("mode") or "route")
state.message = message
state.routing = routing
state.context = dict(context)
state.updated_at = time.time()
def _publish_complete(search_id: str, *, message: str) -> None:
with _lock:
state = _searches.get(search_id)
if state is None or state.cancelled:
return
state.status = "complete"
state.message = message
state.complete = True
state.updated_at = time.time()
_clear_inflight_search_locked(state)
def _publish_error(search_id: str, message: str) -> None:
with _lock:
state = _searches.get(search_id)
if state is None:
return
state.status = "error"
state.stage = "error"
state.message = message
state.error = message
state.complete = True
state.updated_at = time.time()
_clear_inflight_search_locked(state)
def _clear_inflight_search_locked(state: _SearchState) -> None:
if state.cache_key is not None and _progressive_search_inflight.get(state.cache_key) == state.id:
_progressive_search_inflight.pop(state.cache_key, None)
def _is_cancelled(search_id: str) -> bool:
with _lock:
state = _searches.get(search_id)
return state is None or state.cancelled
def _payload(state: _SearchState) -> dict[str, Any]:
return {
"search_id": state.id,
"status": state.status,
"stage": state.stage,
"message": state.message,
"complete": state.complete,
"error": state.error,
"journeys": json.loads(json.dumps(state.journeys)),
"routing": json.loads(json.dumps(state.routing)) if state.routing is not None else None,
"created_at": state.created_at,
"updated_at": state.updated_at,
**json.loads(json.dumps(state.context)),
}
def _context_from_result(result: dict[str, Any]) -> dict[str, Any]:
return {
key: value
for key, value in result.items()
if key not in {"journeys"} and key not in {"error"}
}
def _journey_key(journey: dict[str, Any]) -> str:
parts = []
for leg in journey.get("legs") or []:
parts.append(
"|".join(
str(part or "")
for part in [
leg.get("dataset_id"),
leg.get("mode"),
leg.get("route_id"),
leg.get("trip_id"),
(leg.get("from") or {}).get("stop_id") or (leg.get("from") or {}).get("name"),
(leg.get("to") or {}).get("stop_id") or (leg.get("to") or {}).get("name"),
leg.get("departure_time"),
leg.get("arrival_time"),
]
)
)
return "||".join(parts)
def _rank_journeys(journeys, ranking: str) -> list[dict]:
def key(journey: dict[str, Any]) -> tuple[float, float, int, float]:
departure = journey.get("departure_seconds")
arrival = journey.get("arrival_seconds")
duration = journey.get("duration_minutes")
transfers = int(journey.get("transfers") or 0)
walking = sum(float(leg.get("distance_m") or 0) for leg in journey.get("legs") or [] if leg.get("mode") == "walk")
walking_seconds = walking / 1.35
if ranking == "duration":
return (
float("inf") if duration is None else float(duration),
float("inf") if arrival is None else float(arrival),
transfers,
walking,
)
if ranking == "fewest_transfers":
return (
transfers,
float("inf") if arrival is None else float(arrival),
float("inf") if duration is None else float(duration),
walking,
)
if ranking == "earliest_arrival":
return (
float("inf") if arrival is None else float(arrival),
float("inf") if duration is None else float(duration),
transfers,
walking,
)
return (
float("inf") if arrival is None else float(arrival) + transfers * 600 + walking_seconds,
float("inf") if arrival is None else float(arrival),
transfers,
walking,
)
return sorted((dict(journey) for journey in journeys), key=key)
def _select_diverse_journeys(journeys: list[dict], *, limit: int) -> list[dict]:
selected: list[dict] = []
selected_exact: set[str] = set()
selected_diversity: set[tuple[object, ...]] = set()
for journey in journeys:
exact_key = _journey_key(journey)
if exact_key in selected_exact:
continue
diversity_key = _journey_diversity_key(journey)
if diversity_key in selected_diversity and len(selected) >= 3:
continue
selected.append(journey)
selected_exact.add(exact_key)
selected_diversity.add(diversity_key)
if len(selected) >= limit:
return selected
if len(selected) >= min(3, limit):
return selected
for journey in journeys:
exact_key = _journey_key(journey)
if exact_key in selected_exact:
continue
selected.append(journey)
selected_exact.add(exact_key)
if len(selected) >= min(3, limit):
break
return _ensure_walk_only_option(selected, journeys, limit=limit)
def _ensure_walk_only_option(selected: list[dict], ranked: list[dict], *, limit: int) -> list[dict]:
if any(_journey_is_walk_only(journey) for journey in selected):
return selected
walk = next((journey for journey in ranked if _journey_is_walk_only(journey)), None)
if walk is None:
return selected
if len(selected) < limit:
return [*selected, walk]
if selected:
selected[-1] = walk
return selected
def _journey_is_walk_only(journey: dict) -> bool:
legs = journey.get("legs") or []
return bool(legs) and all(leg.get("mode") == "walk" for leg in legs)
def _journey_diversity_key(journey: dict[str, Any]) -> tuple[object, ...]:
route_signature = tuple(
str(leg.get("route_ref") or leg.get("route_id") or leg.get("mode") or "")
for leg in journey.get("legs") or []
if leg.get("mode") != "walk"
)
departure = journey.get("departure_seconds")
time_band = None if departure is None else int(departure) // (30 * 60)
return (int(journey.get("transfers") or 0), route_signature, time_band)
def _csv_ints(value: object) -> list[int] | None:
if value is None:
return None
items = [item.strip() for item in str(value).split(",") if item.strip()]
if not items:
return None
return [int(item) for item in items]
def _stop_payload(stop) -> dict[str, Any]:
return {
"id": stop.id,
"dataset_id": stop.dataset_id,
"stop_id": stop.stop_id,
"name": stop.name,
"lat": stop.lat,
"lon": stop.lon,
}
def _prune_old_searches() -> None:
now = time.time()
stale = [
search_id
for search_id, state in _searches.items()
if now - state.updated_at > 15 * 60 or (state.complete and now - state.updated_at > 3 * 60)
]
for search_id in stale:
state = _searches.pop(search_id, None)
if state is not None:
_clear_inflight_search_locked(state)

2653
app/main.py Normal file

File diff suppressed because it is too large Load Diff

612
app/models.py Normal file
View File

@@ -0,0 +1,612 @@
from __future__ import annotations
from datetime import datetime, timezone
from typing import Optional
from sqlalchemy import BigInteger, Boolean, DateTime, Float, ForeignKey, Integer, String, Text, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.db import Base
def now_utc() -> datetime:
return datetime.now(timezone.utc)
class Source(Base):
__tablename__ = "sources"
id: Mapped[int] = mapped_column(Integer, primary_key=True)
catalog_entry_id: Mapped[Optional[int]] = mapped_column(ForeignKey("source_catalog_entries.id"), nullable=True, index=True)
name: Mapped[str] = mapped_column(String(255), nullable=False)
kind: Mapped[str] = mapped_column(String(64), nullable=False) # gtfs, osm_geojson, osm_pbf, osm_diff
url: Mapped[str] = mapped_column(Text, nullable=False)
country: Mapped[Optional[str]] = mapped_column(String(8), nullable=True)
license: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
priority: Mapped[Optional[str]] = mapped_column(String(16), nullable=True, index=True)
mode_scope: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
source_basis: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
notes: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
enabled: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False)
status: Mapped[str] = mapped_column(String(64), default="new", nullable=False)
last_error: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
last_run_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=now_utc, nullable=False)
catalog_entry: Mapped[Optional["SourceCatalogEntry"]] = relationship()
datasets: Mapped[list["Dataset"]] = relationship(back_populates="source", cascade="all, delete-orphan")
update_checks: Mapped[list["SourceUpdateCheck"]] = relationship(back_populates="source", cascade="all, delete-orphan")
class SourceCatalogEntry(Base):
__tablename__ = "source_catalog_entries"
__table_args__ = (UniqueConstraint("catalog_key", name="uq_source_catalog_entry_key"),)
id: Mapped[int] = mapped_column(Integer, primary_key=True)
catalog_key: Mapped[str] = mapped_column(String(255), nullable=False, index=True)
geography: Mapped[Optional[str]] = mapped_column(String(128), nullable=True, index=True)
country_code: Mapped[Optional[str]] = mapped_column(String(64), nullable=True, index=True)
mode_scope: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
source_name: Mapped[str] = mapped_column(Text, nullable=False)
source_category: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
formats_apis: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
availability: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
coverage_notes: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
geometry_notes: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
disruptions_closures: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
operator_list_use: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
access_license_notes: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
priority: Mapped[Optional[str]] = mapped_column(String(32), nullable=True, index=True)
source_url: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
evidence_url: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
next_pipeline_action: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
status: Mapped[str] = mapped_column(String(64), default="backlog", nullable=False, index=True)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=now_utc, nullable=False)
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=now_utc, nullable=False)
class Dataset(Base):
__tablename__ = "datasets"
id: Mapped[int] = mapped_column(Integer, primary_key=True)
source_id: Mapped[int] = mapped_column(ForeignKey("sources.id"), nullable=False, index=True)
kind: Mapped[str] = mapped_column(String(64), nullable=False)
local_path: Mapped[str] = mapped_column(Text, nullable=False)
sha256: Mapped[str] = mapped_column(String(64), nullable=False)
is_active: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False)
status: Mapped[str] = mapped_column(String(64), default="imported", nullable=False)
metadata_json: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=now_utc, nullable=False)
source: Mapped[Source] = relationship(back_populates="datasets")
class SourceUpdateCheck(Base):
__tablename__ = "source_update_checks"
id: Mapped[int] = mapped_column(Integer, primary_key=True)
source_id: Mapped[int] = mapped_column(ForeignKey("sources.id"), nullable=False, index=True)
checked_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=now_utc, nullable=False, index=True)
status: Mapped[str] = mapped_column(String(64), nullable=False, default="checked", index=True)
update_available: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
reason: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
remote_url: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
etag: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
last_modified: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
content_length: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
content_type: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
local_mtime: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
local_size: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
local_sha256: Mapped[Optional[str]] = mapped_column(String(64), nullable=True)
active_dataset_id: Mapped[Optional[int]] = mapped_column(ForeignKey("datasets.id"), nullable=True, index=True)
active_dataset_sha256: Mapped[Optional[str]] = mapped_column(String(64), nullable=True)
metadata_json: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
source: Mapped[Source] = relationship(back_populates="update_checks")
active_dataset: Mapped[Optional[Dataset]] = relationship()
class OsmDiffState(Base):
__tablename__ = "osm_diff_states"
id: Mapped[int] = mapped_column(Integer, primary_key=True)
source_id: Mapped[int] = mapped_column(ForeignKey("sources.id"), nullable=False, index=True)
raw_dataset_id: Mapped[Optional[int]] = mapped_column(ForeignKey("datasets.id"), nullable=True, index=True)
updates_url: Mapped[str] = mapped_column(Text, nullable=False)
sequence_number: Mapped[int] = mapped_column(Integer, nullable=False, index=True)
timestamp: Mapped[Optional[str]] = mapped_column(String(64), nullable=True, index=True)
status: Mapped[str] = mapped_column(String(64), nullable=False, default="active", index=True)
metadata_json: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=now_utc, nullable=False, index=True)
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=now_utc, nullable=False)
source: Mapped[Source] = relationship()
raw_dataset: Mapped[Optional[Dataset]] = relationship()
class Job(Base):
__tablename__ = "jobs"
id: Mapped[int] = mapped_column(Integer, primary_key=True)
kind: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
status: Mapped[str] = mapped_column(String(64), nullable=False, default="queued", index=True)
description: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
progress_current: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
progress_total: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
priority: Mapped[int] = mapped_column(Integer, nullable=False, default=0, index=True)
requested_action: Mapped[Optional[str]] = mapped_column(String(32), nullable=True, index=True)
lease_owner: Mapped[Optional[str]] = mapped_column(String(255), nullable=True, index=True)
lease_expires_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True, index=True)
paused_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
result_json: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
error: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
dismissed_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True, index=True)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=now_utc, nullable=False, index=True)
started_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=now_utc, nullable=False)
finished_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
events: Mapped[list["JobEvent"]] = relationship(back_populates="job", cascade="all, delete-orphan")
class JobEvent(Base):
__tablename__ = "job_events"
id: Mapped[int] = mapped_column(Integer, primary_key=True)
job_id: Mapped[int] = mapped_column(ForeignKey("jobs.id"), nullable=False, index=True)
level: Mapped[str] = mapped_column(String(32), nullable=False, default="info", index=True)
event_type: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
message: Mapped[str] = mapped_column(Text, nullable=False)
progress_current: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
progress_total: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
metadata_json: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=now_utc, nullable=False, index=True)
job: Mapped[Job] = relationship(back_populates="events")
class PipelineRun(Base):
__tablename__ = "pipeline_runs"
id: Mapped[int] = mapped_column(Integer, primary_key=True)
stage: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
version: Mapped[str] = mapped_column(String(128), nullable=False, index=True)
dependency_hash: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
status: Mapped[str] = mapped_column(String(64), nullable=False, default="running", index=True)
source_id: Mapped[Optional[int]] = mapped_column(ForeignKey("sources.id"), nullable=True, index=True)
dataset_id: Mapped[Optional[int]] = mapped_column(ForeignKey("datasets.id"), nullable=True, index=True)
job_id: Mapped[Optional[int]] = mapped_column(ForeignKey("jobs.id"), nullable=True, index=True)
input_json: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
output_json: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
error: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
started_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=now_utc, nullable=False, index=True)
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=now_utc, nullable=False)
finished_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
source: Mapped[Optional[Source]] = relationship()
dataset: Mapped[Optional[Dataset]] = relationship()
job: Mapped[Optional[Job]] = relationship()
class GtfsAgency(Base):
__tablename__ = "gtfs_agencies"
__table_args__ = (UniqueConstraint("dataset_id", "agency_id", name="uq_gtfs_agency_dataset_id"),)
id: Mapped[int] = mapped_column(Integer, primary_key=True)
dataset_id: Mapped[int] = mapped_column(ForeignKey("datasets.id"), nullable=False, index=True)
agency_id: Mapped[str] = mapped_column(String(255), nullable=False)
name: Mapped[str] = mapped_column(Text, nullable=False)
url: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
timezone: Mapped[Optional[str]] = mapped_column(String(128), nullable=True)
class GtfsStop(Base):
__tablename__ = "gtfs_stops"
__table_args__ = (UniqueConstraint("dataset_id", "stop_id", name="uq_gtfs_stop_dataset_id"),)
id: Mapped[int] = mapped_column(Integer, primary_key=True)
dataset_id: Mapped[int] = mapped_column(ForeignKey("datasets.id"), nullable=False, index=True)
stop_id: Mapped[str] = mapped_column(String(255), nullable=False)
name: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
lat: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
lon: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
parent_station: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
class GtfsRoute(Base):
__tablename__ = "gtfs_routes"
__table_args__ = (UniqueConstraint("dataset_id", "route_id", name="uq_gtfs_route_dataset_id"),)
id: Mapped[int] = mapped_column(Integer, primary_key=True)
dataset_id: Mapped[int] = mapped_column(ForeignKey("datasets.id"), nullable=False, index=True)
route_id: Mapped[str] = mapped_column(String(255), nullable=False)
agency_id: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
short_name: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
long_name: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
route_type: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
mode: Mapped[Optional[str]] = mapped_column(String(64), nullable=True, index=True)
route_scope: Mapped[Optional[str]] = mapped_column(String(64), nullable=True, index=True)
operator_name: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
geometry_geojson: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
min_lon: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
min_lat: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
max_lon: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
max_lat: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
route_key: Mapped[Optional[str]] = mapped_column(Text, nullable=True, index=True)
operator_key: Mapped[Optional[str]] = mapped_column(Text, nullable=True, index=True)
class GtfsTrip(Base):
__tablename__ = "gtfs_trips"
__table_args__ = (UniqueConstraint("dataset_id", "trip_id", name="uq_gtfs_trip_dataset_id"),)
id: Mapped[int] = mapped_column(Integer, primary_key=True)
dataset_id: Mapped[int] = mapped_column(ForeignKey("datasets.id"), nullable=False, index=True)
route_id: Mapped[str] = mapped_column(String(255), nullable=False, index=True)
trip_id: Mapped[str] = mapped_column(String(255), nullable=False)
service_id: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
shape_id: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
class GtfsCalendar(Base):
__tablename__ = "gtfs_calendars"
__table_args__ = (UniqueConstraint("dataset_id", "service_id", name="uq_gtfs_calendar_dataset_service"),)
id: Mapped[int] = mapped_column(Integer, primary_key=True)
dataset_id: Mapped[int] = mapped_column(ForeignKey("datasets.id"), nullable=False, index=True)
service_id: Mapped[str] = mapped_column(String(255), nullable=False, index=True)
monday: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
tuesday: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
wednesday: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
thursday: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
friday: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
saturday: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
sunday: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
start_date: Mapped[int] = mapped_column(Integer, nullable=False, index=True)
end_date: Mapped[int] = mapped_column(Integer, nullable=False, index=True)
class GtfsCalendarDate(Base):
__tablename__ = "gtfs_calendar_dates"
__table_args__ = (UniqueConstraint("dataset_id", "service_id", "date", name="uq_gtfs_calendar_date_dataset_service_date"),)
id: Mapped[int] = mapped_column(Integer, primary_key=True)
dataset_id: Mapped[int] = mapped_column(ForeignKey("datasets.id"), nullable=False, index=True)
service_id: Mapped[str] = mapped_column(String(255), nullable=False, index=True)
date: Mapped[int] = mapped_column(Integer, nullable=False, index=True)
exception_type: Mapped[int] = mapped_column(Integer, nullable=False)
class GtfsShape(Base):
__tablename__ = "gtfs_shapes"
__table_args__ = (UniqueConstraint("dataset_id", "shape_id", name="uq_gtfs_shape_dataset_id"),)
id: Mapped[int] = mapped_column(Integer, primary_key=True)
dataset_id: Mapped[int] = mapped_column(ForeignKey("datasets.id"), nullable=False, index=True)
shape_id: Mapped[str] = mapped_column(String(255), nullable=False, index=True)
geometry_geojson: Mapped[str] = mapped_column(Text, nullable=False)
min_lon: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
min_lat: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
max_lon: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
max_lat: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
class GtfsStopTime(Base):
__tablename__ = "gtfs_stop_times"
id: Mapped[int] = mapped_column(Integer, primary_key=True)
dataset_id: Mapped[int] = mapped_column(ForeignKey("datasets.id"), nullable=False, index=True)
trip_id: Mapped[str] = mapped_column(String(255), nullable=False, index=True)
stop_id: Mapped[str] = mapped_column(String(255), nullable=False)
stop_sequence: Mapped[int] = mapped_column(Integer, nullable=False)
arrival_time: Mapped[Optional[str]] = mapped_column(String(32), nullable=True)
departure_time: Mapped[Optional[str]] = mapped_column(String(32), nullable=True)
arrival_seconds: Mapped[Optional[int]] = mapped_column(Integer, nullable=True, index=True)
departure_seconds: Mapped[Optional[int]] = mapped_column(Integer, nullable=True, index=True)
class CanonicalStop(Base):
__tablename__ = "canonical_stops"
id: Mapped[int] = mapped_column(Integer, primary_key=True)
stop_key: Mapped[str] = mapped_column(String(255), nullable=False, unique=True, index=True)
name: Mapped[str] = mapped_column(Text, nullable=False)
normalized_name: Mapped[str] = mapped_column(Text, nullable=False, index=True)
lat: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
lon: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
mode: Mapped[Optional[str]] = mapped_column(String(64), nullable=True, index=True)
metadata_json: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=now_utc, nullable=False)
class CanonicalStopLink(Base):
__tablename__ = "canonical_stop_links"
__table_args__ = (
UniqueConstraint("object_type", "dataset_id", "object_id", name="uq_canonical_stop_link_object"),
)
id: Mapped[int] = mapped_column(Integer, primary_key=True)
canonical_stop_id: Mapped[int] = mapped_column(ForeignKey("canonical_stops.id"), nullable=False, index=True)
layer: Mapped[str] = mapped_column(String(64), nullable=False, index=True) # timetable, visual
object_type: Mapped[str] = mapped_column(String(64), nullable=False, index=True) # gtfs_stop, osm_feature
dataset_id: Mapped[int] = mapped_column(ForeignKey("datasets.id"), nullable=False, index=True)
object_id: Mapped[int] = mapped_column(Integer, nullable=False, index=True)
external_id: Mapped[str] = mapped_column(Text, nullable=False)
role: Mapped[Optional[str]] = mapped_column(String(64), nullable=True)
confidence: Mapped[float] = mapped_column(Float, nullable=False, default=1.0)
distance_m: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
metadata_json: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
canonical_stop: Mapped[CanonicalStop] = relationship()
class RoutePattern(Base):
__tablename__ = "route_patterns"
id: Mapped[int] = mapped_column(Integer, primary_key=True)
pattern_key: Mapped[str] = mapped_column(String(255), nullable=False, unique=True, index=True)
route_ref: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
route_name: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
mode: Mapped[Optional[str]] = mapped_column(String(64), nullable=True, index=True)
route_scope: Mapped[Optional[str]] = mapped_column(String(64), nullable=True, index=True)
operator_name: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
source_kind: Mapped[str] = mapped_column(String(64), nullable=False, index=True) # osm, gtfs_proposed
status: Mapped[str] = mapped_column(String(64), nullable=False, default="active", index=True)
osm_feature_id: Mapped[Optional[int]] = mapped_column(ForeignKey("osm_features.id"), nullable=True, index=True)
gtfs_route_id: Mapped[Optional[int]] = mapped_column(ForeignKey("gtfs_routes.id"), nullable=True, index=True)
gtfs_shape_id: Mapped[Optional[str]] = mapped_column(String(255), nullable=True, index=True)
geometry_geojson: Mapped[str] = mapped_column(Text, nullable=False)
min_lon: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
min_lat: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
max_lon: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
max_lat: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
confidence: Mapped[float] = mapped_column(Float, nullable=False, default=1.0)
metadata_json: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=now_utc, nullable=False)
osm_feature: Mapped[Optional["OsmFeature"]] = relationship()
gtfs_route: Mapped[Optional[GtfsRoute]] = relationship()
class RoutePatternStop(Base):
__tablename__ = "route_pattern_stops"
__table_args__ = (UniqueConstraint("route_pattern_id", "sequence", name="uq_route_pattern_stop_sequence"),)
id: Mapped[int] = mapped_column(Integer, primary_key=True)
route_pattern_id: Mapped[int] = mapped_column(ForeignKey("route_patterns.id"), nullable=False, index=True)
canonical_stop_id: Mapped[int] = mapped_column(ForeignKey("canonical_stops.id"), nullable=False, index=True)
sequence: Mapped[int] = mapped_column(Integer, nullable=False)
distance_along: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
source_kind: Mapped[str] = mapped_column(String(64), nullable=False, default="timetable")
confidence: Mapped[float] = mapped_column(Float, nullable=False, default=1.0)
route_pattern: Mapped[RoutePattern] = relationship()
canonical_stop: Mapped[CanonicalStop] = relationship()
class GtfsRoutePatternLink(Base):
__tablename__ = "gtfs_route_pattern_links"
__table_args__ = (UniqueConstraint("dataset_id", "route_id", "shape_id", name="uq_gtfs_route_pattern_shape"),)
id: Mapped[int] = mapped_column(Integer, primary_key=True)
dataset_id: Mapped[int] = mapped_column(ForeignKey("datasets.id"), nullable=False, index=True)
gtfs_route_id: Mapped[int] = mapped_column(ForeignKey("gtfs_routes.id"), nullable=False, index=True)
route_id: Mapped[str] = mapped_column(String(255), nullable=False, index=True)
shape_id: Mapped[str] = mapped_column(String(255), nullable=False, index=True)
route_pattern_id: Mapped[int] = mapped_column(ForeignKey("route_patterns.id"), nullable=False, index=True)
confidence: Mapped[float] = mapped_column(Float, nullable=False, default=0)
status: Mapped[str] = mapped_column(String(64), nullable=False, default="linked", index=True)
source_kind: Mapped[str] = mapped_column(String(64), nullable=False)
reasons_json: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
gtfs_route: Mapped[GtfsRoute] = relationship()
route_pattern: Mapped[RoutePattern] = relationship()
class GtfsTripRoutePatternLink(Base):
__tablename__ = "gtfs_trip_route_pattern_links"
__table_args__ = (UniqueConstraint("dataset_id", "trip_id", name="uq_gtfs_trip_route_pattern"),)
id: Mapped[int] = mapped_column(Integer, primary_key=True)
dataset_id: Mapped[int] = mapped_column(ForeignKey("datasets.id"), nullable=False, index=True)
trip_id: Mapped[str] = mapped_column(String(255), nullable=False, index=True)
route_id: Mapped[str] = mapped_column(String(255), nullable=False, index=True)
shape_id: Mapped[str] = mapped_column(String(255), nullable=False, index=True)
route_pattern_id: Mapped[int] = mapped_column(ForeignKey("route_patterns.id"), nullable=False, index=True)
source_kind: Mapped[str] = mapped_column(String(64), nullable=False)
confidence: Mapped[float] = mapped_column(Float, nullable=False, default=0)
status: Mapped[str] = mapped_column(String(64), nullable=False, default="linked", index=True)
route_pattern: Mapped[RoutePattern] = relationship()
class OsmFeature(Base):
__tablename__ = "osm_features"
__table_args__ = (UniqueConstraint("dataset_id", "osm_type", "osm_id", name="uq_osm_feature_dataset_type_id"),)
id: Mapped[int] = mapped_column(Integer, primary_key=True)
dataset_id: Mapped[int] = mapped_column(ForeignKey("datasets.id"), nullable=False, index=True)
osm_type: Mapped[str] = mapped_column(String(32), nullable=False)
osm_id: Mapped[str] = mapped_column(String(64), nullable=False)
kind: Mapped[str] = mapped_column(String(64), nullable=False, index=True) # route, stop, terminal, station, infra
mode: Mapped[Optional[str]] = mapped_column(String(64), nullable=True, index=True)
route_scope: Mapped[Optional[str]] = mapped_column(String(64), nullable=True, index=True)
name: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
ref: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
operator: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
network: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
geometry_geojson: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
min_lon: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
min_lat: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
max_lon: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
max_lat: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
tags_json: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
route_key: Mapped[Optional[str]] = mapped_column(Text, nullable=True, index=True)
operator_key: Mapped[Optional[str]] = mapped_column(Text, nullable=True, index=True)
class OsmAddress(Base):
__tablename__ = "osm_addresses"
__table_args__ = (UniqueConstraint("dataset_id", "osm_type", "osm_id", name="uq_osm_address_dataset_type_id"),)
id: Mapped[int] = mapped_column(Integer, primary_key=True)
dataset_id: Mapped[int] = mapped_column(ForeignKey("datasets.id"), nullable=False, index=True)
osm_type: Mapped[str] = mapped_column(String(32), nullable=False)
osm_id: Mapped[str] = mapped_column(String(64), nullable=False)
housenumber: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
street: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
place: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
postcode: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
city: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
country: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
unit: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
name: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
display_name: Mapped[str] = mapped_column(Text, nullable=False)
search_text: Mapped[str] = mapped_column(Text, nullable=False)
lat: Mapped[float] = mapped_column(Float, nullable=False)
lon: Mapped[float] = mapped_column(Float, nullable=False)
min_lon: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
min_lat: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
max_lon: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
max_lat: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
geometry_geojson: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
tags_json: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=now_utc, nullable=False)
class RoutingNode(Base):
__tablename__ = "routing_nodes"
__table_args__ = (UniqueConstraint("dataset_id", "osm_node_id", name="uq_routing_node_dataset_osm"),)
id: Mapped[int] = mapped_column(Integer, primary_key=True)
dataset_id: Mapped[int] = mapped_column(ForeignKey("datasets.id"), nullable=False, index=True)
osm_node_id: Mapped[int] = mapped_column(BigInteger, nullable=False, index=True)
lat: Mapped[float] = mapped_column(Float, nullable=False)
lon: Mapped[float] = mapped_column(Float, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=now_utc, nullable=False)
class RoutingEdge(Base):
__tablename__ = "routing_edges"
id: Mapped[int] = mapped_column(Integer, primary_key=True)
dataset_id: Mapped[int] = mapped_column(ForeignKey("datasets.id"), nullable=False, index=True)
osm_way_id: Mapped[int] = mapped_column(BigInteger, nullable=False, index=True)
source_osm_node_id: Mapped[int] = mapped_column(BigInteger, nullable=False, index=True)
target_osm_node_id: Mapped[int] = mapped_column(BigInteger, nullable=False, index=True)
source_lat: Mapped[float] = mapped_column(Float, nullable=False)
source_lon: Mapped[float] = mapped_column(Float, nullable=False)
target_lat: Mapped[float] = mapped_column(Float, nullable=False)
target_lon: Mapped[float] = mapped_column(Float, nullable=False)
highway: Mapped[Optional[str]] = mapped_column(String(64), nullable=True, index=True)
name: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
length_m: Mapped[float] = mapped_column(Float, nullable=False)
walk_cost_s: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
reverse_walk_cost_s: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
drive_cost_s: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
reverse_drive_cost_s: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
geometry_geojson: Mapped[str] = mapped_column(Text, nullable=False)
min_lon: Mapped[float] = mapped_column(Float, nullable=False)
min_lat: Mapped[float] = mapped_column(Float, nullable=False)
max_lon: Mapped[float] = mapped_column(Float, nullable=False)
max_lat: Mapped[float] = mapped_column(Float, nullable=False)
tags_json: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=now_utc, nullable=False)
class RouteMatch(Base):
__tablename__ = "route_matches"
id: Mapped[int] = mapped_column(Integer, primary_key=True)
gtfs_route_id: Mapped[int] = mapped_column(ForeignKey("gtfs_routes.id"), nullable=False, index=True)
osm_feature_id: Mapped[Optional[int]] = mapped_column(ForeignKey("osm_features.id"), nullable=True, index=True)
confidence: Mapped[float] = mapped_column(Float, nullable=False, default=0)
status: Mapped[str] = mapped_column(String(64), nullable=False) # matched, probable, weak, missing, accepted, rejected
rule_source: Mapped[str] = mapped_column(String(64), default="auto", nullable=False) # auto, manual
reasons_json: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=now_utc, nullable=False)
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=now_utc, nullable=False)
gtfs_route: Mapped[GtfsRoute] = relationship()
osm_feature: Mapped[Optional[OsmFeature]] = relationship()
class MatchRule(Base):
__tablename__ = "match_rules"
id: Mapped[int] = mapped_column(Integer, primary_key=True)
rule_type: Mapped[str] = mapped_column(String(64), nullable=False) # accept_match, reject_match, alias, force_operator
selector_json: Mapped[str] = mapped_column(Text, nullable=False)
action_json: Mapped[str] = mapped_column(Text, nullable=False)
note: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
active: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=now_utc, nullable=False)
class JourneySearchCache(Base):
__tablename__ = "journey_search_cache"
id: Mapped[int] = mapped_column(Integer, primary_key=True)
cache_key: Mapped[str] = mapped_column(String(128), nullable=False, unique=True, index=True)
cache_type: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
payload_json: Mapped[str] = mapped_column(Text, nullable=False)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=now_utc, nullable=False, index=True)
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=now_utc, nullable=False)
expires_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, index=True)
class TravelRequest(Base):
__tablename__ = "travel_requests"
id: Mapped[int] = mapped_column(Integer, primary_key=True)
origin_stop_id: Mapped[str] = mapped_column(Text, nullable=False)
destination_stop_id: Mapped[str] = mapped_column(Text, nullable=False)
via_stop_id: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
departure_time: Mapped[str] = mapped_column(String(32), nullable=False)
service_date: Mapped[Optional[str]] = mapped_column(String(10), nullable=True, index=True)
max_transfers: Mapped[int] = mapped_column(Integer, nullable=False, default=1)
transfer_seconds: Mapped[int] = mapped_column(Integer, nullable=False, default=120)
source_filter: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
preferences_json: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=now_utc, nullable=False, index=True)
itineraries: Mapped[list["Itinerary"]] = relationship(back_populates="request", cascade="all, delete-orphan")
class Itinerary(Base):
__tablename__ = "itineraries"
id: Mapped[int] = mapped_column(Integer, primary_key=True)
request_id: Mapped[int] = mapped_column(ForeignKey("travel_requests.id"), nullable=False, index=True)
title: Mapped[str] = mapped_column(Text, nullable=False)
family: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
status: Mapped[str] = mapped_column(String(64), nullable=False, default="candidate", index=True)
saved: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False, index=True)
summary_json: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
score_json: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
payload_json: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=now_utc, nullable=False, index=True)
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=now_utc, nullable=False)
request: Mapped[TravelRequest] = relationship(back_populates="itineraries")
legs: Mapped[list["ItineraryLeg"]] = relationship(back_populates="itinerary", cascade="all, delete-orphan")
class ItineraryLeg(Base):
__tablename__ = "itinerary_legs"
__table_args__ = (UniqueConstraint("itinerary_id", "sequence", name="uq_itinerary_leg_sequence"),)
id: Mapped[int] = mapped_column(Integer, primary_key=True)
itinerary_id: Mapped[int] = mapped_column(ForeignKey("itineraries.id"), nullable=False, index=True)
sequence: Mapped[int] = mapped_column(Integer, nullable=False)
mode: Mapped[Optional[str]] = mapped_column(String(64), nullable=True, index=True)
route_ref: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
route_name: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
from_name: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
to_name: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
departure_time: Mapped[Optional[str]] = mapped_column(String(32), nullable=True)
arrival_time: Mapped[Optional[str]] = mapped_column(String(32), nullable=True)
locked: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False, index=True)
payload_json: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
itinerary: Mapped[Itinerary] = relationship(back_populates="legs")

111
app/osm_classification.py Normal file
View File

@@ -0,0 +1,111 @@
from __future__ import annotations
import json
import re
from typing import Mapping
LOCAL_SCOPE = "local"
REGIONAL_SCOPE = "regional"
LONG_DISTANCE_SCOPE = "long_distance"
UNKNOWN_SCOPE = "unknown"
OSM_ROUTE_SCOPE_CLASSIFIER_VERSION = "route_scope_v2"
BUS_MODES = {"bus", "trolleybus"}
LOCAL_MODES = {"tram", "light_rail", "subway", "ferry", "funicular", "aerialway", "monorail"}
LONG_DISTANCE_MODES = {"coach"}
LONG_DISTANCE_SERVICE_VALUES = {
"high_speed",
"long_distance",
"intercity",
"international",
"night",
"sleeper",
}
REGIONAL_SERVICE_VALUES = {"regional", "interurban", "commuter", "branch", "suburban"}
LOCAL_SERVICE_VALUES = {"local", "urban", "city", "subway", "tram", "light_rail", "s-bahn", "sbahn"}
LONG_DISTANCE_PREFIX_RE = re.compile(r"^(ICE|IC|EC|ECE|EN|NJ|RJ|RJX|TGV|THA|EST|FLX|WB)\b|^(ICE|IC|EC|ECE|EN|NJ|RJ|RJX|TGV|THA|EST|FLX|WB)\d")
REGIONAL_PREFIX_RE = re.compile(r"^(IRE|RE|RB|RER|TER|REX|MEX|ALX|WFB|R)\b|^(IRE|RE|RB|RER|TER|REX|MEX|ALX|WFB|R)\d")
LOCAL_TRAIN_PREFIX_RE = re.compile(r"^(S|S-BAHN)\b|^S\d")
def infer_osm_route_scope(
*,
mode: str | None,
ref: str | None = None,
name: str | None = None,
network: str | None = None,
tags: Mapping[str, object] | str | None = None,
) -> str | None:
"""Classify a public-transport route into a display scope.
OSM tagging varies by country and operator, so this intentionally combines
explicit service tags with conservative reference-prefix heuristics.
"""
normalized_mode = (mode or "").strip().lower()
tags_dict = _tags_dict(tags)
values = {
str(tags_dict.get(key) or "").strip().lower()
for key in ("service", "train", "bus", "passenger", "network:type", "route_scope")
if tags_dict.get(key)
}
if values & LONG_DISTANCE_SERVICE_VALUES:
return LONG_DISTANCE_SCOPE
if values & LOCAL_SERVICE_VALUES:
return LOCAL_SCOPE
if values & REGIONAL_SERVICE_VALUES:
return REGIONAL_SCOPE
if normalized_mode in LOCAL_MODES:
return LOCAL_SCOPE
if normalized_mode in LONG_DISTANCE_MODES:
return LONG_DISTANCE_SCOPE
text = _classification_text(ref, name, network, tags_dict)
if normalized_mode in BUS_MODES:
if any(marker in text for marker in ("FLIXBUS", "EUROLINES", "INTERCITYBUS", "IC BUS", "LONG DISTANCE", "FERNBUS")):
return LONG_DISTANCE_SCOPE
if any(marker in text for marker in ("REGIONALBUS", "REGIOBUS", "REGIONAL BUS", "REGIONALVERKEHR", "REGIONAL VERKEHR")):
return REGIONAL_SCOPE
return LOCAL_SCOPE
if normalized_mode == "train":
if LONG_DISTANCE_PREFIX_RE.search(text) or any(marker in text for marker in ("INTERCITY", "EUROCITY", "NIGHTJET", "FLIXTRAIN")):
return LONG_DISTANCE_SCOPE
if LOCAL_TRAIN_PREFIX_RE.search(text) or "S-BAHN" in text or "SBahn".upper() in text:
return LOCAL_SCOPE
if REGIONAL_PREFIX_RE.search(text) or any(marker in text for marker in ("REGIONAL", "REGIO", "REGIONALBAHN", "REGIONALEXPRESS")):
return REGIONAL_SCOPE
return UNKNOWN_SCOPE
return None
def infer_osm_route_scope_from_tags(mode: str | None, ref: str | None, name: str | None, network: str | None, tags_json: str | None) -> str | None:
return infer_osm_route_scope(mode=mode, ref=ref, name=name, network=network, tags=tags_json)
def _tags_dict(tags: Mapping[str, object] | str | None) -> dict[str, object]:
if isinstance(tags, str):
try:
data = json.loads(tags or "{}")
except json.JSONDecodeError:
return {}
return data if isinstance(data, dict) else {}
if isinstance(tags, Mapping):
return dict(tags)
return {}
def _classification_text(ref: str | None, name: str | None, network: str | None, tags: Mapping[str, object]) -> str:
parts = [
ref or "",
name or "",
network or "",
str(tags.get("ref") or ""),
str(tags.get("name") or ""),
str(tags.get("network") or ""),
str(tags.get("network:short") or ""),
]
return " ".join(parts).strip().upper().replace("_", " ")

981
app/osm_storage.py Normal file
View File

@@ -0,0 +1,981 @@
from __future__ import annotations
import json
import sqlite3
from contextlib import contextmanager
from pathlib import Path
from typing import Iterator, Sequence
from sqlalchemy import and_, func, insert, not_, or_, select, text
from sqlalchemy.dialects.postgresql import insert as postgresql_insert
from sqlalchemy.orm import Session
from app.config import settings
from app.models import Dataset, OsmFeature
from app.spatial import refresh_postgis_geometries
OSM_STORAGE_METADATA_KEY = "osm_storage"
OSM_STORAGE_MAIN = "main"
OSM_STORAGE_SIDECAR_FEATURES = "sidecar_features"
SQLITE_IN_CHUNK_SIZE = 800
OSM_SIDECAR_ROUTE_SCOPE_INDEXES = ["ix_osm_sidecar_scope_bbox"]
OSM_FEATURE_COLUMNS = [
"dataset_id",
"osm_type",
"osm_id",
"kind",
"mode",
"route_scope",
"name",
"ref",
"operator",
"network",
"geometry_geojson",
"min_lon",
"min_lat",
"max_lon",
"max_lat",
"tags_json",
"route_key",
"operator_key",
]
def effective_osm_feature_storage(value: str | None = None) -> str:
configured = str(value or settings.osm_feature_storage or OSM_STORAGE_SIDECAR_FEATURES).strip().lower()
if configured in {OSM_STORAGE_MAIN, "main", "main_db", "postgres", "postgresql"}:
return OSM_STORAGE_MAIN
if settings.is_postgresql_database and not settings.postgres_use_sidecars:
return OSM_STORAGE_MAIN
return OSM_STORAGE_SIDECAR_FEATURES
class MissingOsmSidecar(FileNotFoundError):
pass
def dataset_metadata(dataset: Dataset) -> dict:
try:
metadata = json.loads(dataset.metadata_json or "{}")
except json.JSONDecodeError:
return {}
return metadata if isinstance(metadata, dict) else {}
def features_are_sidecar(dataset: Dataset | None) -> bool:
if dataset is None:
return False
storage = dataset_metadata(dataset).get(OSM_STORAGE_METADATA_KEY)
if not isinstance(storage, dict):
return False
tables = storage.get("tables")
if isinstance(tables, dict):
return tables.get("osm_features") == "sidecar"
return storage.get("mode") == OSM_STORAGE_SIDECAR_FEATURES
def sidecar_path(dataset: Dataset | None) -> Path | None:
if dataset is None:
return None
storage = dataset_metadata(dataset).get(OSM_STORAGE_METADATA_KEY)
if not isinstance(storage, dict):
return None
value = storage.get("sidecar_path")
if not value:
return None
return Path(str(value))
def dataset_sidecar_paths(dataset: Dataset) -> list[Path]:
path = sidecar_path(dataset)
return [] if path is None else [path]
def missing_sidecar_paths(dataset: Dataset | None) -> list[str]:
if not features_are_sidecar(dataset):
return []
path = sidecar_path(dataset)
if path is None or path.exists():
return []
return [str(path)]
@contextmanager
def sidecar_connection(dataset: Dataset) -> Iterator[sqlite3.Connection]:
path = sidecar_path(dataset)
if path is None:
raise MissingOsmSidecar(f"dataset #{dataset.id} does not reference an OSM sidecar")
if not path.exists():
raise MissingOsmSidecar(f"OSM sidecar does not exist: {path}")
connection = sqlite3.connect(f"file:{path}?mode=ro", uri=True)
connection.row_factory = sqlite3.Row
try:
yield connection
finally:
connection.close()
@contextmanager
def writable_sidecar_connection(dataset: Dataset) -> Iterator[sqlite3.Connection]:
path = sidecar_path(dataset)
if path is None:
raise MissingOsmSidecar(f"dataset #{dataset.id} does not reference an OSM sidecar")
if not path.exists():
raise MissingOsmSidecar(f"OSM sidecar does not exist: {path}")
connection = sqlite3.connect(path)
connection.row_factory = sqlite3.Row
try:
connection.execute(f"PRAGMA busy_timeout={int(settings.sqlite_busy_timeout_ms)}")
connection.execute("PRAGMA synchronous=NORMAL")
yield connection
finally:
connection.close()
def create_osm_sidecar(dataset: Dataset, rows: Sequence[dict[str, object]], *, source_hash: str | None = None) -> dict:
path = _new_sidecar_path(dataset, source_hash or dataset.sha256)
path.parent.mkdir(parents=True, exist_ok=True)
if path.exists():
path.unlink()
connection = sqlite3.connect(path)
try:
connection.execute("PRAGMA journal_mode=OFF")
connection.execute("PRAGMA synchronous=OFF")
_create_schema(connection)
deduped_rows, duplicate_count = dedupe_osm_feature_rows(rows)
inserted = 0
counts = {"route": 0, "stop": 0, "station": 0, "terminal": 0, "infra": 0, "feature": 0}
insert_sql = f"""
INSERT INTO osm_features
({", ".join(["id", *OSM_FEATURE_COLUMNS])})
VALUES
({", ".join(["?"] * (len(OSM_FEATURE_COLUMNS) + 1))})
"""
batch = []
for index, row in enumerate(deduped_rows, start=1):
kind = str(row.get("kind") or "feature")
counts[kind] = counts.get(kind, 0) + 1
batch.append((index, *[row.get(column) for column in OSM_FEATURE_COLUMNS]))
if len(batch) >= 5000:
connection.executemany(insert_sql, batch)
inserted += len(batch)
batch.clear()
if batch:
connection.executemany(insert_sql, batch)
inserted += len(batch)
connection.commit()
_create_indexes(connection)
connection.commit()
finally:
connection.close()
return {
"mode": OSM_STORAGE_SIDECAR_FEATURES,
"tables": {"osm_features": "sidecar"},
"sidecar_path": str(path),
"features": inserted,
"duplicate_features_skipped": duplicate_count,
"counts": counts,
}
def ensure_osm_sidecar_schema(connection: sqlite3.Connection) -> None:
columns = _sidecar_columns(connection)
if "route_scope" not in columns:
connection.execute("ALTER TABLE osm_features ADD COLUMN route_scope TEXT")
connection.commit()
def drop_osm_sidecar_route_scope_indexes(connection: sqlite3.Connection) -> None:
for index_name in OSM_SIDECAR_ROUTE_SCOPE_INDEXES:
connection.execute(f"DROP INDEX IF EXISTS {index_name}")
def rebuild_osm_sidecar_indexes(connection: sqlite3.Connection) -> None:
_create_indexes(connection)
def osm_feature_count(session: Session, dataset_id: int, *, kind: str | Sequence[str] | None = None) -> int:
dataset = session.get(Dataset, dataset_id)
if features_are_sidecar(dataset):
kinds = _as_list(kind)
sql = "SELECT COUNT(*) FROM osm_features"
params: list[object] = []
if kinds:
placeholders = ", ".join(["?"] * len(kinds))
sql += f" WHERE kind IN ({placeholders})"
params.extend(kinds)
try:
with sidecar_connection(dataset) as connection:
return int(connection.execute(sql, params).fetchone()[0] or 0)
except MissingOsmSidecar:
return 0
stmt = select(func.count()).select_from(OsmFeature).where(OsmFeature.dataset_id == dataset_id)
kinds = _as_list(kind)
if kinds:
stmt = stmt.where(OsmFeature.kind.in_(kinds))
return int(session.scalar(stmt) or 0)
def osm_feature_bbox(
session: Session,
dataset_ids: Sequence[int],
*,
kinds: Sequence[str] | None = None,
) -> tuple[float | None, float | None, float | None, float | None]:
if not dataset_ids:
return (None, None, None, None)
datasets = {
dataset.id: dataset
for dataset in session.scalars(select(Dataset).where(Dataset.id.in_([int(value) for value in dataset_ids]))).all()
}
boxes: list[tuple[float, float, float, float]] = []
main_dataset_ids = [dataset_id for dataset_id, dataset in datasets.items() if not features_are_sidecar(dataset)]
if main_dataset_ids:
stmt = select(func.min(OsmFeature.min_lon), func.min(OsmFeature.min_lat), func.max(OsmFeature.max_lon), func.max(OsmFeature.max_lat)).where(
OsmFeature.dataset_id.in_(main_dataset_ids)
)
if kinds:
stmt = stmt.where(OsmFeature.kind.in_(list(kinds)))
row = session.execute(stmt).one()
if None not in row:
boxes.append((float(row[0]), float(row[1]), float(row[2]), float(row[3])))
for dataset in datasets.values():
if not features_are_sidecar(dataset):
continue
where = []
params: list[object] = []
if kinds:
placeholders = ", ".join(["?"] * len(kinds))
where.append(f"kind IN ({placeholders})")
params.extend(list(kinds))
sql = "SELECT MIN(min_lon), MIN(min_lat), MAX(max_lon), MAX(max_lat) FROM osm_features"
if where:
sql += " WHERE " + " AND ".join(where)
try:
with sidecar_connection(dataset) as connection:
row = connection.execute(sql, params).fetchone()
if row is not None and None not in row:
boxes.append((float(row[0]), float(row[1]), float(row[2]), float(row[3])))
except MissingOsmSidecar:
continue
if not boxes:
return (None, None, None, None)
return (
min(box[0] for box in boxes),
min(box[1] for box in boxes),
max(box[2] for box in boxes),
max(box[3] for box in boxes),
)
def query_osm_features(
session: Session,
dataset_ids: Sequence[int],
*,
kinds: Sequence[str] | None = None,
modes: Sequence[str] | None = None,
bbox: tuple[float, float, float, float] | None = None,
geometry_required: bool | None = None,
search: str | None = None,
route_key: str | None = None,
route_scopes: Sequence[str] | None = None,
ref: str | None = None,
osm_type: str | None = None,
osm_id: str | None = None,
limit: int | None = None,
offset: int | None = None,
prefer_materialized_ids: bool = True,
) -> list[OsmFeature]:
if not dataset_ids:
return []
datasets = {
dataset.id: dataset
for dataset in session.scalars(select(Dataset).where(Dataset.id.in_([int(value) for value in dataset_ids]))).all()
}
materialized_ids = _materialized_ids_by_identity(session, list(datasets)) if prefer_materialized_ids else {}
rows: list[OsmFeature] = []
main_dataset_ids = [dataset_id for dataset_id, dataset in datasets.items() if not features_are_sidecar(dataset)]
if main_dataset_ids:
stmt = select(OsmFeature).where(OsmFeature.dataset_id.in_(main_dataset_ids))
stmt = _apply_main_filters(
stmt,
kinds=kinds,
modes=modes,
bbox=bbox,
geometry_required=geometry_required,
search=search,
route_key=route_key,
route_scopes=route_scopes,
ref=ref,
osm_type=osm_type,
osm_id=osm_id,
)
if offset:
stmt = stmt.offset(max(0, int(offset)))
rows.extend(
session.scalars(
stmt.order_by(OsmFeature.kind, OsmFeature.mode, OsmFeature.ref, OsmFeature.name, OsmFeature.id).limit(limit)
).all()
)
for dataset_id, dataset in datasets.items():
if not features_are_sidecar(dataset):
continue
rows.extend(
_query_sidecar_features(
dataset,
kinds=kinds,
modes=modes,
bbox=bbox,
geometry_required=geometry_required,
search=search,
route_key=route_key,
route_scopes=route_scopes,
ref=ref,
osm_type=osm_type,
osm_id=osm_id,
limit=limit,
offset=offset,
materialized_ids=materialized_ids,
)
)
rows.sort(key=lambda row: (row.kind or "", row.mode or "", row.ref or "", row.name or "", int(row.id or 0)))
if limit is not None:
return rows[: max(1, int(limit))]
return rows
def get_osm_feature(session: Session, feature_id: int) -> OsmFeature | None:
return session.get(OsmFeature, feature_id)
def osm_feature_identity_key(feature: OsmFeature) -> str:
return f"{feature.dataset_id}|{feature.osm_type}|{feature.osm_id}"
def osm_feature_public_id(feature: OsmFeature) -> int | str | None:
if getattr(feature, "_osm_sidecar_source", False):
return osm_feature_identity_key(feature)
return feature.id
def resolve_osm_feature(session: Session, value: int | str) -> OsmFeature | None:
int_value = _safe_int(value)
if int_value is not None:
feature = session.get(OsmFeature, int_value)
if feature is not None:
return feature
parsed = parse_osm_feature_identity_key(str(value))
if parsed is None:
return None
dataset_id, osm_type, osm_id = parsed
existing = session.scalar(
select(OsmFeature).where(
OsmFeature.dataset_id == dataset_id,
OsmFeature.osm_type == osm_type,
OsmFeature.osm_id == osm_id,
)
)
if existing is not None:
return existing
dataset = session.get(Dataset, dataset_id)
if not features_are_sidecar(dataset):
return None
try:
with sidecar_connection(dataset) as connection:
select_columns = ", ".join(_sidecar_select_columns(_sidecar_columns(connection)))
row = connection.execute(
f"""
SELECT id, {select_columns}
FROM osm_features
WHERE dataset_id = ?
AND osm_type = ?
AND osm_id = ?
""",
(dataset_id, osm_type, osm_id),
).fetchone()
except MissingOsmSidecar:
return None
if row is None:
return None
return _feature_from_row(row, {})
def parse_osm_feature_identity_key(value: str) -> tuple[int, str, str] | None:
parts = value.split("|", 2)
if len(parts) != 3:
return None
dataset_id = _safe_int(parts[0])
if dataset_id is None:
return None
osm_type = parts[1].strip()
osm_id = parts[2].strip()
if not osm_type or not osm_id:
return None
return dataset_id, osm_type, osm_id
def ensure_main_osm_feature(session: Session, feature: OsmFeature) -> OsmFeature:
existing = session.scalar(
select(OsmFeature).where(
OsmFeature.dataset_id == feature.dataset_id,
OsmFeature.osm_type == feature.osm_type,
OsmFeature.osm_id == feature.osm_id,
)
)
if existing is not None:
return existing
values = dict(
dataset_id=feature.dataset_id,
osm_type=feature.osm_type,
osm_id=feature.osm_id,
kind=feature.kind,
mode=feature.mode,
route_scope=feature.route_scope,
name=feature.name,
ref=feature.ref,
operator=feature.operator,
network=feature.network,
geometry_geojson=feature.geometry_geojson,
min_lon=feature.min_lon,
min_lat=feature.min_lat,
max_lon=feature.max_lon,
max_lat=feature.max_lat,
tags_json=feature.tags_json,
route_key=feature.route_key,
operator_key=feature.operator_key,
)
if settings.is_postgresql_database:
session.execute(
postgresql_insert(OsmFeature)
.values(**values)
.on_conflict_do_nothing(index_elements=["dataset_id", "osm_type", "osm_id"])
)
else:
session.execute(insert(OsmFeature).values(**values).prefix_with("OR IGNORE"))
session.flush()
refresh_postgis_geometries(session, dataset_id=feature.dataset_id, tables=["osm_features"])
existing = session.scalar(
select(OsmFeature).where(
OsmFeature.dataset_id == feature.dataset_id,
OsmFeature.osm_type == feature.osm_type,
OsmFeature.osm_id == feature.osm_id,
)
)
if existing is None:
raise RuntimeError(f"Could not materialize OSM feature {feature.dataset_id}:{feature.osm_type}:{feature.osm_id}")
return existing
def materialize_osm_features(session: Session, features: Sequence[OsmFeature]) -> list[OsmFeature]:
return [ensure_main_osm_feature(session, feature) for feature in features]
def _new_sidecar_path(dataset: Dataset, source_hash: str | None) -> Path:
suffix = (source_hash or dataset.sha256 or str(dataset.id))[:12]
return settings.data_dir / "sidecars" / f"source_{dataset.source_id}" / f"osm_dataset_{dataset.id}_{suffix}.sqlite"
def dedupe_osm_feature_rows(rows: Sequence[dict[str, object]]) -> tuple[list[dict[str, object]], int]:
selected: dict[tuple[int, str, str], dict[str, object]] = {}
for row in rows:
key = (int(row["dataset_id"]), str(row["osm_type"]), str(row["osm_id"]))
current = selected.get(key)
if current is None or _feature_row_preference(row) < _feature_row_preference(current):
selected[key] = dict(row)
return list(selected.values()), max(0, len(rows) - len(selected))
def _feature_row_preference(row: dict[str, object]) -> tuple[int, int, int]:
kind_rank = {
"route": 0,
"station": 1,
"terminal": 2,
"stop": 3,
"infra": 4,
"feature": 5,
}.get(str(row.get("kind") or "feature"), 6)
has_geometry = 0 if row.get("geometry_geojson") else 1
geometry_size = -len(str(row.get("geometry_geojson") or ""))
return (kind_rank, has_geometry, geometry_size)
def _create_schema(connection: sqlite3.Connection) -> None:
connection.execute(
"""
CREATE TABLE osm_features (
id INTEGER PRIMARY KEY,
dataset_id INTEGER NOT NULL,
osm_type TEXT NOT NULL,
osm_id TEXT NOT NULL,
kind TEXT NOT NULL,
mode TEXT,
route_scope TEXT,
name TEXT,
ref TEXT,
operator TEXT,
network TEXT,
geometry_geojson TEXT,
min_lon REAL,
min_lat REAL,
max_lon REAL,
max_lat REAL,
tags_json TEXT,
route_key TEXT,
operator_key TEXT,
UNIQUE(dataset_id, osm_type, osm_id)
)
"""
)
def _create_indexes(connection: sqlite3.Connection) -> None:
statements = [
"CREATE INDEX IF NOT EXISTS ix_osm_sidecar_kind_mode_bbox ON osm_features (kind, mode, min_lon, max_lon, min_lat, max_lat)",
"CREATE INDEX IF NOT EXISTS ix_osm_sidecar_scope_bbox ON osm_features (kind, mode, route_scope, min_lon, max_lon, min_lat, max_lat)",
"CREATE INDEX IF NOT EXISTS ix_osm_sidecar_route_key ON osm_features (route_key)",
"CREATE INDEX IF NOT EXISTS ix_osm_sidecar_ref ON osm_features (ref)",
"CREATE INDEX IF NOT EXISTS ix_osm_sidecar_identity ON osm_features (dataset_id, osm_type, osm_id)",
"CREATE INDEX IF NOT EXISTS ix_osm_sidecar_kind_ref_mode ON osm_features (kind, ref, mode)",
]
for statement in statements:
connection.execute(statement)
def _apply_main_filters(stmt, *, kinds, modes, bbox, geometry_required, search, route_key, route_scopes, ref, osm_type, osm_id):
if kinds:
stmt = stmt.where(OsmFeature.kind.in_(list(kinds)))
if modes:
stmt = stmt.where(OsmFeature.mode.in_(list(modes)))
if route_scopes:
stmt = stmt.where(_main_route_scope_condition([str(scope) for scope in route_scopes]))
if bbox:
min_lon, min_lat, max_lon, max_lat = bbox
if settings.is_postgresql_database:
stmt = stmt.where(
text(
"""
(
osm_features.geom && ST_MakeEnvelope(:bbox_min_lon, :bbox_min_lat, :bbox_max_lon, :bbox_max_lat, 4326)
OR (
osm_features.geom IS NULL
AND osm_features.min_lon <= :bbox_max_lon
AND osm_features.max_lon >= :bbox_min_lon
AND osm_features.min_lat <= :bbox_max_lat
AND osm_features.max_lat >= :bbox_min_lat
)
)
"""
)
).params(
bbox_min_lon=min_lon,
bbox_min_lat=min_lat,
bbox_max_lon=max_lon,
bbox_max_lat=max_lat,
)
else:
stmt = stmt.where(OsmFeature.min_lon <= max_lon, OsmFeature.max_lon >= min_lon, OsmFeature.min_lat <= max_lat, OsmFeature.max_lat >= min_lat)
if geometry_required is True:
stmt = stmt.where(OsmFeature.geometry_geojson.is_not(None))
elif geometry_required is False:
stmt = stmt.where(OsmFeature.geometry_geojson.is_(None))
if search:
if settings.is_postgresql_database:
stmt = stmt.where(
text(
"""
(
LOWER(COALESCE(osm_features.ref, '')) LIKE :search_pattern
OR LOWER(COALESCE(osm_features.name, '')) LIKE :search_pattern
OR LOWER(COALESCE(osm_features.tags_json, '')) LIKE :search_pattern
)
"""
)
).params(search_pattern=f"%{search.lower()}%")
else:
pattern = f"%{search}%"
stmt = stmt.where(
(OsmFeature.ref.ilike(pattern))
| (OsmFeature.name.ilike(pattern))
| (OsmFeature.tags_json.ilike(pattern))
)
if route_key:
stmt = stmt.where(OsmFeature.route_key == route_key)
if ref:
stmt = stmt.where(OsmFeature.ref == ref)
if osm_type:
stmt = stmt.where(OsmFeature.osm_type == osm_type)
if osm_id:
stmt = stmt.where(OsmFeature.osm_id == osm_id)
return stmt
def _main_route_scope_condition(route_scopes: list[str]):
fallback = _main_route_scope_fallback_condition(route_scopes)
stored = OsmFeature.route_scope.in_(route_scopes)
if "local" in route_scopes:
non_local_bus_fallback = _main_route_scope_fallback_condition(["long_distance", "regional"])
stored = and_(stored, not_(and_(OsmFeature.mode.in_(["bus", "trolleybus"]), non_local_bus_fallback)))
return or_(stored, fallback)
def _main_route_scope_fallback_condition(route_scopes: list[str]):
ref = func.upper(func.coalesce(OsmFeature.ref, ""))
name = func.upper(func.coalesce(OsmFeature.name, ""))
network = func.upper(func.coalesce(OsmFeature.network, ""))
tags = func.lower(func.coalesce(OsmFeature.tags_json, ""))
train_long_distance = and_(
OsmFeature.mode == "train",
or_(
ref.like("ICE%"),
ref.like("IC%"),
ref.like("EC%"),
ref.like("ECE%"),
ref.like("EN%"),
ref.like("NJ%"),
ref.like("RJ%"),
ref.like("RJX%"),
ref.like("TGV%"),
ref.like("THA%"),
ref.like("FLX%"),
name.like("%INTERCITY%"),
name.like("%EUROCITY%"),
name.like("%NIGHTJET%"),
name.like("%FLIXTRAIN%"),
tags.like('%"service":"long_distance"%'),
tags.like('%"train":"long_distance"%'),
tags.like('%"train":"high_speed"%'),
tags.like('%"train":"intercity"%'),
),
)
bus_long_distance = and_(
OsmFeature.mode.in_(["bus", "trolleybus"]),
or_(
name.like("%FLIXBUS%"),
network.like("%FLIXBUS%"),
name.like("%EUROLINES%"),
network.like("%EUROLINES%"),
name.like("%INTERCITYBUS%"),
name.like("%IC BUS%"),
name.like("%FERNBUS%"),
tags.like('%"service":"long_distance"%'),
tags.like('%"bus":"long_distance"%'),
tags.like('%"bus":"intercity"%'),
tags.like('%"network:type":"long_distance"%'),
),
)
long_distance = or_(OsmFeature.mode == "coach", train_long_distance, bus_long_distance)
bus_regional = and_(
OsmFeature.mode.in_(["bus", "trolleybus"]),
not_(bus_long_distance),
or_(
name.like("%REGIONALBUS%"),
name.like("%REGIOBUS%"),
name.like("%REGIONAL BUS%"),
name.like("%REGIONALVERKEHR%"),
network.like("%REGIONALBUS%"),
network.like("%REGIOBUS%"),
network.like("%REGIONALVERKEHR%"),
tags.like('%"service":"regional"%'),
tags.like('%"bus":"regional"%'),
tags.like('%"bus":"interurban"%'),
tags.like('%"network:type":"regional"%'),
),
)
local = or_(
OsmFeature.mode.in_(["tram", "light_rail", "subway", "ferry", "funicular", "aerialway", "monorail"]),
and_(OsmFeature.mode.in_(["bus", "trolleybus"]), not_(or_(bus_long_distance, bus_regional))),
and_(
OsmFeature.mode == "train",
or_(ref.like("S%"), name.like("%S-BAHN%"), network.like("%S-BAHN%"), tags.like('%"train":"commuter"%')),
),
)
train_regional = and_(
OsmFeature.mode == "train",
not_(train_long_distance),
or_(
ref.like("IRE%"),
ref.like("RE%"),
ref.like("RB%"),
ref.like("RER%"),
ref.like("TER%"),
ref.like("REX%"),
ref.like("MEX%"),
ref.like("ALX%"),
ref.like("WFB%"),
ref.like("R%"),
name.like("%REGIONAL%"),
name.like("%REGIO%"),
tags.like('%"service":"regional"%'),
tags.like('%"train":"regional"%'),
),
)
regional = or_(train_regional, bus_regional)
conditions = []
if "long_distance" in route_scopes:
conditions.append(long_distance)
if "regional" in route_scopes:
conditions.append(regional)
if "local" in route_scopes:
conditions.append(local)
if "unknown" in route_scopes:
conditions.append(and_(OsmFeature.mode == "train", not_(or_(long_distance, regional, local))))
return or_(*conditions) if conditions else OsmFeature.route_scope.is_(None)
def _query_sidecar_features(
dataset: Dataset,
*,
kinds: Sequence[str] | None,
modes: Sequence[str] | None,
bbox: tuple[float, float, float, float] | None,
geometry_required: bool | None,
search: str | None,
route_key: str | None,
route_scopes: Sequence[str] | None,
ref: str | None,
osm_type: str | None,
osm_id: str | None,
limit: int | None,
offset: int | None,
materialized_ids: dict[tuple[int, str, str], int],
) -> list[OsmFeature]:
where = []
params: list[object] = []
try:
with sidecar_connection(dataset) as connection:
available_columns = _sidecar_columns(connection)
if kinds:
placeholders = ", ".join(["?"] * len(kinds))
where.append(f"kind IN ({placeholders})")
params.extend(list(kinds))
if modes:
placeholders = ", ".join(["?"] * len(modes))
where.append(f"mode IN ({placeholders})")
params.extend(list(modes))
if bbox:
min_lon, min_lat, max_lon, max_lat = bbox
where.extend(["min_lon <= ?", "max_lon >= ?", "min_lat <= ?", "max_lat >= ?"])
params.extend([max_lon, min_lon, max_lat, min_lat])
if geometry_required is True:
where.append("geometry_geojson IS NOT NULL")
elif geometry_required is False:
where.append("geometry_geojson IS NULL")
if search:
where.append("(LOWER(COALESCE(ref, '')) LIKE ? OR LOWER(COALESCE(name, '')) LIKE ? OR LOWER(COALESCE(tags_json, '')) LIKE ?)")
pattern = f"%{search.lower()}%"
params.extend([pattern, pattern, pattern])
if route_key:
where.append("route_key = ?")
params.append(route_key)
if route_scopes:
condition, condition_params = _sidecar_route_scope_condition([str(scope) for scope in route_scopes], has_route_scope="route_scope" in available_columns)
where.append(condition)
params.extend(condition_params)
if ref:
where.append("ref = ?")
params.append(ref)
if osm_type:
where.append("osm_type = ?")
params.append(osm_type)
if osm_id:
where.append("osm_id = ?")
params.append(osm_id)
select_columns = ", ".join(_sidecar_select_columns(available_columns))
sql = f"SELECT id, {select_columns} FROM osm_features"
if where:
sql += " WHERE " + " AND ".join(where)
sql += " ORDER BY kind, mode, ref, name, id"
if limit is not None:
sql += " LIMIT ?"
params.append(max(1, int(limit)))
if offset:
if limit is None:
sql += " LIMIT -1"
sql += " OFFSET ?"
params.append(max(0, int(offset)))
return [_feature_from_row(row, materialized_ids) for row in connection.execute(sql, params).fetchall()]
except MissingOsmSidecar:
return []
def _sidecar_columns(connection: sqlite3.Connection) -> set[str]:
return {str(row["name"]) for row in connection.execute("PRAGMA table_info(osm_features)").fetchall()}
def _sidecar_select_columns(available_columns: set[str]) -> list[str]:
return [column if column in available_columns else f"NULL AS {column}" for column in OSM_FEATURE_COLUMNS]
def _sidecar_route_scope_condition(route_scopes: list[str], *, has_route_scope: bool) -> tuple[str, list[object]]:
fallback_sql, fallback_params = _sidecar_route_scope_fallback_condition(route_scopes)
if has_route_scope:
placeholders = ", ".join(["?"] * len(route_scopes))
stored_sql = f"route_scope IN ({placeholders})"
params: list[object] = [*route_scopes]
if "local" in route_scopes:
non_local_sql, non_local_params = _sidecar_route_scope_fallback_condition(["long_distance", "regional"])
stored_sql = f"({stored_sql} AND NOT (mode IN ('bus', 'trolleybus') AND {non_local_sql}))"
params.extend(non_local_params)
return f"({stored_sql} OR {fallback_sql})", [*params, *fallback_params]
return fallback_sql, fallback_params
def _sidecar_route_scope_fallback_condition(route_scopes: list[str]) -> tuple[str, list[object]]:
train_long_distance = """(
mode = 'train'
AND (
UPPER(COALESCE(ref, '')) LIKE 'ICE%'
OR UPPER(COALESCE(ref, '')) LIKE 'IC%'
OR UPPER(COALESCE(ref, '')) LIKE 'EC%'
OR UPPER(COALESCE(ref, '')) LIKE 'ECE%'
OR UPPER(COALESCE(ref, '')) LIKE 'EN%'
OR UPPER(COALESCE(ref, '')) LIKE 'NJ%'
OR UPPER(COALESCE(ref, '')) LIKE 'RJ%'
OR UPPER(COALESCE(ref, '')) LIKE 'RJX%'
OR UPPER(COALESCE(ref, '')) LIKE 'TGV%'
OR UPPER(COALESCE(ref, '')) LIKE 'THA%'
OR UPPER(COALESCE(ref, '')) LIKE 'FLX%'
OR UPPER(COALESCE(name, '')) LIKE '%INTERCITY%'
OR UPPER(COALESCE(name, '')) LIKE '%EUROCITY%'
OR UPPER(COALESCE(name, '')) LIKE '%NIGHTJET%'
OR UPPER(COALESCE(name, '')) LIKE '%FLIXTRAIN%'
OR LOWER(COALESCE(tags_json, '')) LIKE '%"service":"long_distance"%'
OR LOWER(COALESCE(tags_json, '')) LIKE '%"train":"long_distance"%'
OR LOWER(COALESCE(tags_json, '')) LIKE '%"train":"high_speed"%'
OR LOWER(COALESCE(tags_json, '')) LIKE '%"train":"intercity"%'
)
)"""
bus_long_distance = """(
mode IN ('bus', 'trolleybus')
AND (
UPPER(COALESCE(name, '')) LIKE '%FLIXBUS%'
OR UPPER(COALESCE(network, '')) LIKE '%FLIXBUS%'
OR UPPER(COALESCE(name, '')) LIKE '%EUROLINES%'
OR UPPER(COALESCE(network, '')) LIKE '%EUROLINES%'
OR UPPER(COALESCE(name, '')) LIKE '%INTERCITYBUS%'
OR UPPER(COALESCE(name, '')) LIKE '%IC BUS%'
OR UPPER(COALESCE(name, '')) LIKE '%FERNBUS%'
OR LOWER(COALESCE(tags_json, '')) LIKE '%"service":"long_distance"%'
OR LOWER(COALESCE(tags_json, '')) LIKE '%"bus":"long_distance"%'
OR LOWER(COALESCE(tags_json, '')) LIKE '%"bus":"intercity"%'
OR LOWER(COALESCE(tags_json, '')) LIKE '%"network:type":"long_distance"%'
)
)"""
long_distance = f"(mode = 'coach' OR {train_long_distance} OR {bus_long_distance})"
bus_regional = f"""(
mode IN ('bus', 'trolleybus')
AND NOT {bus_long_distance}
AND (
UPPER(COALESCE(name, '')) LIKE '%REGIONALBUS%'
OR UPPER(COALESCE(name, '')) LIKE '%REGIOBUS%'
OR UPPER(COALESCE(name, '')) LIKE '%REGIONAL BUS%'
OR UPPER(COALESCE(name, '')) LIKE '%REGIONALVERKEHR%'
OR UPPER(COALESCE(network, '')) LIKE '%REGIONALBUS%'
OR UPPER(COALESCE(network, '')) LIKE '%REGIOBUS%'
OR UPPER(COALESCE(network, '')) LIKE '%REGIONALVERKEHR%'
OR LOWER(COALESCE(tags_json, '')) LIKE '%"service":"regional"%'
OR LOWER(COALESCE(tags_json, '')) LIKE '%"bus":"regional"%'
OR LOWER(COALESCE(tags_json, '')) LIKE '%"bus":"interurban"%'
OR LOWER(COALESCE(tags_json, '')) LIKE '%"network:type":"regional"%'
)
)"""
train_regional = f"""(
mode = 'train'
AND NOT {train_long_distance}
AND (
UPPER(COALESCE(ref, '')) LIKE 'IRE%'
OR UPPER(COALESCE(ref, '')) LIKE 'RE%'
OR UPPER(COALESCE(ref, '')) LIKE 'RB%'
OR UPPER(COALESCE(ref, '')) LIKE 'RER%'
OR UPPER(COALESCE(ref, '')) LIKE 'TER%'
OR UPPER(COALESCE(ref, '')) LIKE 'REX%'
OR UPPER(COALESCE(ref, '')) LIKE 'MEX%'
OR UPPER(COALESCE(ref, '')) LIKE 'ALX%'
OR UPPER(COALESCE(ref, '')) LIKE 'WFB%'
OR UPPER(COALESCE(ref, '')) LIKE 'R%'
OR UPPER(COALESCE(name, '')) LIKE '%REGIONAL%'
OR UPPER(COALESCE(name, '')) LIKE '%REGIO%'
OR LOWER(COALESCE(tags_json, '')) LIKE '%"service":"regional"%'
OR LOWER(COALESCE(tags_json, '')) LIKE '%"train":"regional"%'
)
)"""
regional = f"({train_regional} OR {bus_regional})"
local = f"""(
mode IN ('tram', 'light_rail', 'subway', 'ferry', 'funicular', 'aerialway', 'monorail')
OR (mode IN ('bus', 'trolleybus') AND NOT ({bus_long_distance} OR {bus_regional}))
OR (
mode = 'train'
AND (
UPPER(COALESCE(ref, '')) LIKE 'S%'
OR UPPER(COALESCE(name, '')) LIKE '%S-BAHN%'
OR UPPER(COALESCE(network, '')) LIKE '%S-BAHN%'
OR LOWER(COALESCE(tags_json, '')) LIKE '%"train":"commuter"%'
)
)
)"""
parts = []
if "long_distance" in route_scopes:
parts.append(long_distance)
if "regional" in route_scopes:
parts.append(regional)
if "local" in route_scopes:
parts.append(local)
if "unknown" in route_scopes:
parts.append(f"(mode = 'train' AND NOT ({long_distance} OR {regional} OR {local}))")
return "(" + " OR ".join(parts or ["0"]) + ")", []
def _feature_from_row(row: sqlite3.Row, materialized_ids: dict[tuple[int, str, str], int]) -> OsmFeature:
dataset_id = int(row["dataset_id"])
osm_type = str(row["osm_type"])
osm_id = str(row["osm_id"])
feature_id = materialized_ids.get((dataset_id, osm_type, osm_id), int(row["id"]))
feature = OsmFeature(
id=feature_id,
dataset_id=dataset_id,
osm_type=osm_type,
osm_id=osm_id,
kind=str(row["kind"]),
mode=row["mode"],
route_scope=row["route_scope"],
name=row["name"],
ref=row["ref"],
operator=row["operator"],
network=row["network"],
geometry_geojson=row["geometry_geojson"],
min_lon=row["min_lon"],
min_lat=row["min_lat"],
max_lon=row["max_lon"],
max_lat=row["max_lat"],
tags_json=row["tags_json"],
route_key=row["route_key"],
operator_key=row["operator_key"],
)
setattr(feature, "_osm_sidecar_source", True)
setattr(feature, "_osm_sidecar_row_id", int(row["id"]))
return feature
def _materialized_ids_by_identity(session: Session, dataset_ids: Sequence[int]) -> dict[tuple[int, str, str], int]:
if not dataset_ids:
return {}
rows = session.execute(
select(OsmFeature.dataset_id, OsmFeature.osm_type, OsmFeature.osm_id, OsmFeature.id).where(OsmFeature.dataset_id.in_(dataset_ids))
).all()
return {(int(dataset_id), str(osm_type), str(osm_id)): int(feature_id) for dataset_id, osm_type, osm_id, feature_id in rows}
def _as_list(value: str | Sequence[str] | None) -> list[str]:
if value is None:
return []
if isinstance(value, str):
return [value]
return [str(item) for item in value]
def _safe_int(value: object) -> int | None:
try:
return int(value) # type: ignore[arg-type]
except (TypeError, ValueError):
return None

61
app/performance.py Normal file
View File

@@ -0,0 +1,61 @@
from __future__ import annotations
import json
import time
from contextlib import contextmanager
from datetime import datetime, timezone
from pathlib import Path
from typing import Iterator
from app.config import settings
@contextmanager
def measure_pipeline_phase(
phase: str,
*,
source_id: int | None = None,
dataset_id: int | None = None,
metadata: dict[str, object] | None = None,
) -> Iterator[dict[str, object]]:
start = time.perf_counter()
payload: dict[str, object] = dict(metadata or {})
try:
yield payload
finally:
duration = round(time.perf_counter() - start, 3)
payload["duration_seconds"] = duration
record_pipeline_metric(
phase,
source_id=source_id,
dataset_id=dataset_id,
duration_seconds=duration,
metadata=payload,
)
def record_pipeline_metric(
phase: str,
*,
source_id: int | None = None,
dataset_id: int | None = None,
duration_seconds: float | None = None,
metadata: dict[str, object] | None = None,
) -> None:
path = _metric_path()
path.parent.mkdir(parents=True, exist_ok=True)
row = {
"timestamp": datetime.now(timezone.utc).isoformat(),
"phase": phase,
"source_id": source_id,
"dataset_id": dataset_id,
"duration_seconds": duration_seconds,
"metadata": metadata or {},
}
with path.open("a", encoding="utf-8") as handle:
handle.write(json.dumps(row, separators=(",", ":"), default=str))
handle.write("\n")
def _metric_path() -> Path:
return settings.data_dir / "metrics" / "pipeline_metrics.jsonl"

111
app/pipeline/download.py Normal file
View File

@@ -0,0 +1,111 @@
from __future__ import annotations
import shutil
import time
from pathlib import Path
from urllib.parse import urlparse
import requests
from app.config import settings
from app.models import Source
from app.pipeline.utils import sha256_file
def materialize_source(source: Source) -> Path:
"""Download/copy a source into the local cache and return the file path.
Files are stored by content hash per source. Re-running an unchanged source
reuses the existing cached file instead of creating another timestamped copy.
"""
source_dir = settings.data_dir / "sources" / f"source_{source.id}"
source_dir.mkdir(parents=True, exist_ok=True)
suffix = _guess_suffix(source.url, source.kind)
parsed = urlparse(source.url)
if parsed.scheme in {"http", "https"}:
temp_path = _download_temp_path(source_dir, suffix)
existing_size = temp_path.stat().st_size if temp_path.exists() else 0
headers = {"Range": f"bytes={existing_size}-"} if existing_size > 0 else None
with requests.get(source.url, stream=True, timeout=120, headers=headers) as r:
r.raise_for_status()
mode = "ab" if existing_size > 0 and r.status_code == 206 else "wb"
with temp_path.open(mode) as f:
for chunk in r.iter_content(chunk_size=1024 * 1024):
if chunk:
f.write(chunk)
return _store_or_reuse_cached_file(source_dir=source_dir, source_path=temp_path, suffix=suffix, move=True)
if parsed.scheme == "file":
source_path = Path(parsed.path)
else:
source_path = Path(source.url)
if not source_path.exists():
raise FileNotFoundError(f"Source file does not exist: {source.url}")
if _is_relative_to(source_path.resolve(), source_dir.resolve()):
return source_path
return _store_or_reuse_cached_file(source_dir=source_dir, source_path=source_path, suffix=suffix, move=False)
def _download_temp_path(source_dir: Path, suffix: str) -> Path:
candidates = sorted(
source_dir.glob(f"*.download{suffix}"),
key=lambda path: path.stat().st_mtime if path.exists() else 0,
reverse=True,
)
if candidates:
return candidates[0]
return source_dir / f"{int(time.time())}.download{suffix}"
def _guess_suffix(url: str, kind: str) -> str:
path = urlparse(url).path or url
lower = path.lower()
for suffix in (".zip", ".geojson", ".json", ".osm.pbf", ".pbf", ".osm", ".osm.xml", ".osc.gz", ".osc", ".csv"):
if lower.endswith(suffix):
return suffix
if kind == "gtfs":
return ".zip"
if kind == "osm_geojson":
return ".geojson"
return ".dat"
def _store_or_reuse_cached_file(source_dir: Path, source_path: Path, suffix: str, move: bool) -> Path:
source_hash = sha256_file(source_path)
target = source_dir / f"{source_hash[:16]}{suffix}"
if target.exists() and sha256_file(target) == source_hash:
if move and source_path != target:
source_path.unlink(missing_ok=True)
return target
existing = _find_existing_cached_file(source_dir, source_hash, suffix, exclude=source_path)
if existing is not None:
if move and source_path != existing:
source_path.unlink(missing_ok=True)
return existing
if move:
source_path.replace(target)
else:
shutil.copyfile(source_path, target)
return target
def _find_existing_cached_file(source_dir: Path, source_hash: str, suffix: str, exclude: Path | None = None) -> Path | None:
for candidate in sorted(source_dir.glob(f"*{suffix}")):
if exclude is not None and candidate.resolve() == exclude.resolve():
continue
if candidate.is_file() and sha256_file(candidate) == source_hash:
return candidate
return None
def _is_relative_to(path: Path, parent: Path) -> bool:
try:
path.relative_to(parent)
return True
except ValueError:
return False

1327
app/pipeline/gtfs.py Normal file

File diff suppressed because it is too large Load Diff

995
app/pipeline/matcher.py Normal file
View File

@@ -0,0 +1,995 @@
from __future__ import annotations
from dataclasses import dataclass
from datetime import datetime, timezone
import json
from typing import Callable, Optional
from shapely.geometry import LineString, MultiLineString, Point, shape
from sqlalchemy import delete, select
from sqlalchemy.orm import Session
from app.config import settings
from app.models import Dataset, GtfsRoute, MatchRule, OsmFeature, RouteMatch
from app.osm_storage import ensure_main_osm_feature, osm_feature_bbox, query_osm_features
from app.pipeline.state import STAGE_MATCH_ROUTES, dependency_hash, finish_pipeline_run, start_pipeline_run
from app.pipeline.utils import approx_bbox_center_distance_deg, bbox_overlap, norm_ref, norm_text
MODE_GROUPS = {
"train": {"train", "rail", "railway"},
"subway": {"subway", "metro"},
"tram": {"tram", "light_rail"},
"light_rail": {"light_rail", "tram"},
"bus": {"bus", "coach", "trolleybus"},
"coach": {"coach", "bus"},
"trolleybus": {"trolleybus", "bus"},
"ferry": {"ferry"},
"funicular": {"funicular"},
"aerialway": {"aerialway", "cable_car"},
"monorail": {"monorail"},
}
MAX_FALLBACK_CANDIDATES_WITH_REF = 40
MAX_FALLBACK_CANDIDATES_WITHOUT_REF = 80
MAX_EXACT_REF_CANDIDATES = 120
OSM_SCOPE_NEAR_DISTANCE_DEG = 0.15
GEOMETRY_PROXIMITY_DEG = 0.0035
GEOMETRY_SAMPLE_POINTS = 24
MATCHER_VERSION = "matcher_v4_scope_spatial_manual_rules"
ProgressCallback = Callable[[str, str, int | None, int | None, dict[str, object] | None], None]
@dataclass(frozen=True)
class _ManualMatchRule:
id: int
rule_type: str
route_selector: dict[str, object]
osm_selector: dict[str, object] | None
status: str
@dataclass(frozen=True)
class _OsmRouteIndex:
all_routes: list[OsmFeature]
by_ref: dict[str, list[OsmFeature]]
by_route_key: dict[str, list[OsmFeature]]
by_mode: dict[str, list[OsmFeature]]
@dataclass(frozen=True)
class _GeometryProfile:
geom: object
lines: list[LineString]
length: float
sample_points: list[Point]
@dataclass(frozen=True)
class _RouteMatchPayload:
gtfs_route_id: int
osm_feature_id: int | None
confidence: float
status: str
rule_source: str
reasons_json: str | None
def run_route_matching(
session: Session,
*,
progress_callback: ProgressCallback | None = None,
batch_size: int | None = None,
) -> dict[str, object]:
"""Match active GTFS routes against active OSM route features."""
active_datasets = session.execute(
select(Dataset.id, Dataset.kind, Dataset.source_id).where(Dataset.is_active.is_(True))
).all()
if not active_datasets:
return {"routes": 0, "matches": 0, "missing": 0}
dataset_source_ids = {int(dataset_id): int(source_id) for dataset_id, _, source_id in active_datasets}
gtfs_dataset_ids = [int(dataset_id) for dataset_id, kind, _ in active_datasets if kind == "gtfs"]
osm_dataset_ids = [int(dataset_id) for dataset_id, kind, _ in active_datasets if kind == "osm_geojson"]
if not gtfs_dataset_ids:
return {"routes": 0, "matches": 0, "missing": 0}
route_row_ids = session.scalars(
select(GtfsRoute.id)
.where(GtfsRoute.dataset_id.in_(gtfs_dataset_ids))
.order_by(GtfsRoute.dataset_id, GtfsRoute.route_id, GtfsRoute.id)
).all()
# Reconcile current match rows from auto scoring plus durable manual rules.
total_routes = len(route_row_ids)
if total_routes == 0:
return {"routes": 0, "matches": 0, "missing": 0}
dependency = _route_matching_dependency(session, active_datasets)
run = start_pipeline_run(
session,
stage=STAGE_MATCH_ROUTES,
version=MATCHER_VERSION,
dependency_hash_value=dependency_hash(dependency),
inputs=dependency,
)
session.commit()
effective_batch_size = max(1, int(batch_size or settings.route_matching_batch_size))
_emit_progress(
progress_callback,
"route_matching_started",
f"Matching {total_routes} GTFS routes in batches of {effective_batch_size}.",
0,
total_routes,
{"gtfs_datasets": gtfs_dataset_ids, "osm_datasets": osm_dataset_ids, "batch_size": effective_batch_size},
)
manual_rules = _manual_match_rules(session)
osm_scope_bbox = osm_feature_bbox(session, osm_dataset_ids, kinds=["route"])
counts = {"routes": total_routes, "matches": 0, "missing": 0, "manual": 0, "created": 0, "updated": 0, "unchanged": 0}
scoped_counts = {"in_osm_scope": 0, "near_osm_scope": 0, "outside_osm_scope": 0, "unknown_scope": 0}
processed = 0
for chunk in _chunks_int(route_row_ids, effective_batch_size):
routes = session.scalars(
select(GtfsRoute)
.where(GtfsRoute.id.in_(chunk))
.order_by(GtfsRoute.dataset_id, GtfsRoute.route_id, GtfsRoute.id)
).all()
batch_counts = _match_route_batch(
session=session,
routes=routes,
osm_dataset_ids=osm_dataset_ids,
dataset_source_ids=dataset_source_ids,
manual_rules=manual_rules,
osm_scope_bbox=osm_scope_bbox,
scoped_counts=scoped_counts,
)
counts["matches"] += batch_counts["matches"]
counts["missing"] += batch_counts["missing"]
counts["manual"] += batch_counts["manual"]
counts["created"] += batch_counts["created"]
counts["updated"] += batch_counts["updated"]
counts["unchanged"] += batch_counts["unchanged"]
processed += len(routes)
session.commit()
_emit_progress(
progress_callback,
"route_matching_batch",
f"Matched {processed}/{total_routes} GTFS routes.",
processed,
total_routes,
{
"processed": processed,
"matches": counts["matches"],
"missing": counts["missing"],
"manual": counts["manual"],
"created": counts["created"],
"updated": counts["updated"],
"unchanged": counts["unchanged"],
"scope": dict(scoped_counts),
},
)
result = {**counts, "scope": scoped_counts}
finish_pipeline_run(session, run, outputs=result)
session.commit()
_emit_progress(
progress_callback,
"route_matching_completed",
"Route matching completed.",
total_routes,
total_routes,
result,
)
return result
def _route_matching_dependency(session: Session, active_datasets) -> dict[str, object]:
datasets = [
{"id": int(dataset_id), "kind": str(kind), "source_id": int(source_id), "sha256": _dataset_sha(session, int(dataset_id))}
for dataset_id, kind, source_id in active_datasets
]
rules = [
{
"id": int(rule.id),
"type": rule.rule_type,
"active": bool(rule.active),
"selector": rule.selector_json,
"action": rule.action_json,
}
for rule in session.scalars(select(MatchRule).order_by(MatchRule.id)).all()
]
return {"version": MATCHER_VERSION, "active_datasets": datasets, "manual_rules": rules}
def _dataset_sha(session: Session, dataset_id: int) -> str | None:
dataset = session.get(Dataset, dataset_id)
return None if dataset is None else dataset.sha256
def _match_route_batch(
*,
session: Session,
routes: list[GtfsRoute],
osm_dataset_ids: list[int],
dataset_source_ids: dict[int, int],
manual_rules: list[_ManualMatchRule],
osm_scope_bbox: tuple[float | None, float | None, float | None, float | None],
scoped_counts: dict[str, int],
) -> dict[str, int]:
matches = 0
missing = 0
manual = 0
payloads: list[_RouteMatchPayload] = []
for route in routes:
scope = route_match_scope(route, osm_scope_bbox)
scoped_counts[scope] = scoped_counts.get(scope, 0) + 1
route_source_id = dataset_source_ids.get(route.dataset_id)
accepted_rule = _accepted_rule_for_route(manual_rules, route, route_source_id)
if accepted_rule is not None:
accepted_feature = _feature_for_rule_from_storage(session, osm_dataset_ids, dataset_source_ids, accepted_rule)
if accepted_feature is not None:
accepted_feature = ensure_main_osm_feature(session, accepted_feature)
payloads.append(
_RouteMatchPayload(
gtfs_route_id=route.id,
osm_feature_id=accepted_feature.id,
confidence=100.0,
status="accepted",
rule_source="manual",
reasons_json=json.dumps(
{"manual_rule_id": accepted_rule.id, "manual": "accepted_match", "scope": scope},
separators=(",", ":"),
),
)
)
matches += 1
manual += 1
continue
if scope == "outside_osm_scope":
missing += 1
payloads.append(
_RouteMatchPayload(
gtfs_route_id=route.id,
osm_feature_id=None,
confidence=0.0,
status="missing",
rule_source="auto",
reasons_json=json.dumps(
{
"reason": "outside loaded OSM route scope",
"scope": scope,
},
separators=(",", ":"),
),
)
)
continue
best_feature: Optional[OsmFeature] = None
best_score = 0.0
best_reasons: dict[str, object] = {}
route_geometry_profile = _geometry_profile(route.geometry_geojson)
for feature in candidate_osm_routes_for_route(session, route, osm_dataset_ids):
if _is_rejected_pair(manual_rules, route, route_source_id, feature, dataset_source_ids.get(feature.dataset_id)):
continue
feature_geometry_profile = _geometry_profile(feature.geometry_geojson)
score, reasons = score_route_pair(
route,
feature,
route_geometry_profile=route_geometry_profile,
feature_geometry_profile=feature_geometry_profile,
)
if score > best_score:
best_score = score
best_feature = feature
best_reasons = reasons
status = _status_from_score(best_score)
if best_feature is None or status == "missing":
missing += 1
best_feature_id = None
best_reasons = {
"reason": "no OSM candidate above threshold",
"scope": scope,
"best_score_below_threshold": round(float(best_score), 2) if best_score else 0,
"best_reasons": best_reasons,
}
best_score = 0
else:
matches += 1
best_feature = ensure_main_osm_feature(session, best_feature)
best_feature_id = best_feature.id
best_reasons["scope"] = scope
payloads.append(
_RouteMatchPayload(
gtfs_route_id=route.id,
osm_feature_id=best_feature_id,
confidence=round(float(best_score), 2),
status=status,
rule_source="auto",
reasons_json=json.dumps(best_reasons, separators=(",", ":")),
)
)
changes = _apply_route_match_payloads(session, payloads)
session.flush()
return {"matches": matches, "missing": missing, "manual": manual, **changes}
def _apply_route_match_payloads(session: Session, payloads: list[_RouteMatchPayload]) -> dict[str, int]:
if not payloads:
return {"created": 0, "updated": 0, "unchanged": 0}
route_ids = [payload.gtfs_route_id for payload in payloads]
existing_rows = session.scalars(
select(RouteMatch).where(RouteMatch.gtfs_route_id.in_(route_ids)).order_by(RouteMatch.gtfs_route_id, RouteMatch.id)
).all()
existing_by_route: dict[int, list[RouteMatch]] = {}
for row in existing_rows:
existing_by_route.setdefault(row.gtfs_route_id, []).append(row)
created = 0
updated = 0
unchanged = 0
duplicate_ids: list[int] = []
now = datetime.now(timezone.utc)
for payload in payloads:
existing = existing_by_route.get(payload.gtfs_route_id, [])
current = _preferred_existing_match(existing)
if current is None:
session.add(
RouteMatch(
gtfs_route_id=payload.gtfs_route_id,
osm_feature_id=payload.osm_feature_id,
confidence=payload.confidence,
status=payload.status,
rule_source=payload.rule_source,
reasons_json=payload.reasons_json,
)
)
created += 1
continue
duplicate_ids.extend(row.id for row in existing if row.id != current.id)
if _route_match_payload_equal(current, payload):
unchanged += 1
continue
current.osm_feature_id = payload.osm_feature_id
current.confidence = payload.confidence
current.status = payload.status
current.rule_source = payload.rule_source
current.reasons_json = payload.reasons_json
current.updated_at = now
updated += 1
for chunk in _chunks_int(duplicate_ids, 1000):
session.execute(delete(RouteMatch).where(RouteMatch.id.in_(chunk)))
return {"created": created, "updated": updated, "unchanged": unchanged}
def _preferred_existing_match(rows: list[RouteMatch]) -> RouteMatch | None:
if not rows:
return None
return next((row for row in rows if row.rule_source == "manual"), rows[0])
def _route_match_payload_equal(row: RouteMatch, payload: _RouteMatchPayload) -> bool:
return (
row.osm_feature_id == payload.osm_feature_id
and round(float(row.confidence or 0), 2) == round(float(payload.confidence or 0), 2)
and row.status == payload.status
and row.rule_source == payload.rule_source
and (row.reasons_json or None) == (payload.reasons_json or None)
)
def _build_osm_route_index(osm_routes: list[OsmFeature]) -> _OsmRouteIndex:
by_ref: dict[str, list[OsmFeature]] = {}
by_route_key: dict[str, list[OsmFeature]] = {}
by_mode: dict[str, list[OsmFeature]] = {}
for feature in osm_routes:
ref = norm_ref(feature.ref or "")
if ref:
by_ref.setdefault(ref, []).append(feature)
if feature.route_key:
by_route_key.setdefault(feature.route_key, []).append(feature)
if feature.mode:
by_mode.setdefault(feature.mode, []).append(feature)
return _OsmRouteIndex(all_routes=osm_routes, by_ref=by_ref, by_route_key=by_route_key, by_mode=by_mode)
def _candidate_osm_routes(route: GtfsRoute, index: _OsmRouteIndex) -> list[OsmFeature]:
selected: list[OsmFeature] = []
seen: set[int] = set()
def add(features: list[OsmFeature], *, require_compatible_mode: bool = True) -> None:
for feature in features:
if feature.id in seen:
continue
if require_compatible_mode and not _mode_compatible(route.mode or "", feature.mode or ""):
continue
seen.add(feature.id)
selected.append(feature)
route_ref = norm_ref(route.short_name or route.route_id)
if route_ref:
add(index.by_ref.get(route_ref, []))
if route.route_key:
add(index.by_route_key.get(route.route_key, []))
if selected:
return _spatially_ranked_candidates(route, selected, MAX_EXACT_REF_CANDIDATES)
compatible_modes = MODE_GROUPS.get(route.mode or "", {route.mode or ""})
mode_candidates: list[OsmFeature] = []
for mode in compatible_modes:
if mode:
mode_candidates.extend(index.by_mode.get(mode, []))
if not mode_candidates:
mode_candidates = index.all_routes
gtfs_bbox = (route.min_lon, route.min_lat, route.max_lon, route.max_lat)
near_candidates: list[tuple[float, OsmFeature]] = []
for feature in mode_candidates:
osm_bbox = (feature.min_lon, feature.min_lat, feature.max_lon, feature.max_lat)
distance = approx_bbox_center_distance_deg(gtfs_bbox, osm_bbox)
if bbox_overlap(gtfs_bbox, osm_bbox):
near_candidates.append((0.0, feature))
elif distance is not None and distance < 0.12:
near_candidates.append((distance, feature))
fallback_limit = MAX_FALLBACK_CANDIDATES_WITH_REF if route_ref else MAX_FALLBACK_CANDIDATES_WITHOUT_REF
fallback = [feature for _, feature in sorted(near_candidates, key=lambda item: item[0])[:fallback_limit]]
if not fallback:
fallback = mode_candidates[:fallback_limit]
add(fallback)
return _spatially_ranked_candidates(route, selected, fallback_limit)
def candidate_osm_routes_for_route(session: Session, route: GtfsRoute, osm_dataset_ids: list[int]) -> list[OsmFeature]:
if not osm_dataset_ids:
return []
selected: list[OsmFeature] = []
seen: set[tuple[int, str, str]] = set()
def add(features: list[OsmFeature], *, require_compatible_mode: bool = True) -> None:
for feature in features:
key = (feature.dataset_id, feature.osm_type, feature.osm_id)
if key in seen:
continue
if require_compatible_mode and not _mode_compatible(route.mode or "", feature.mode or ""):
continue
seen.add(key)
selected.append(feature)
route_ref = norm_ref(route.short_name or route.route_id)
route_keys = [key for key in [route.route_key, route_ref] if key]
for route_key in dict.fromkeys(route_keys):
add(
query_osm_features(
session,
osm_dataset_ids,
kinds=["route"],
route_key=route_key,
)
)
if selected:
return _spatially_ranked_candidates(route, selected, MAX_EXACT_REF_CANDIDATES)
gtfs_bbox = (route.min_lon, route.min_lat, route.max_lon, route.max_lat)
compatible_modes = sorted(MODE_GROUPS.get(route.mode or "", {route.mode or ""}) - {""})
if not any(value is None for value in gtfs_bbox):
bbox = _expanded_bbox(gtfs_bbox, 0.10)
add(
query_osm_features(
session,
osm_dataset_ids,
kinds=["route"],
modes=compatible_modes or None,
bbox=bbox,
limit=MAX_FALLBACK_CANDIDATES_WITHOUT_REF * 4,
),
require_compatible_mode=False,
)
if not selected:
add(
query_osm_features(
session,
osm_dataset_ids,
kinds=["route"],
modes=compatible_modes or None,
limit=MAX_FALLBACK_CANDIDATES_WITHOUT_REF,
),
require_compatible_mode=False,
)
fallback_limit = MAX_FALLBACK_CANDIDATES_WITH_REF if route_ref else MAX_FALLBACK_CANDIDATES_WITHOUT_REF
return _spatially_ranked_candidates(route, selected, fallback_limit)
def score_route_pair(
route: GtfsRoute,
feature: OsmFeature,
route_geometry_profile: _GeometryProfile | None = None,
feature_geometry_profile: _GeometryProfile | None = None,
) -> tuple[float, dict[str, object]]:
score = 0.0
reasons: dict[str, object] = {}
gtfs_mode = route.mode or ""
osm_mode = feature.mode or ""
if _mode_compatible(gtfs_mode, osm_mode):
score += 25
reasons["mode"] = "compatible"
elif gtfs_mode and osm_mode:
reasons["mode"] = f"mismatch: {gtfs_mode} != {osm_mode}"
return 0.0, reasons
gtfs_ref = norm_ref(route.short_name or route.route_id)
osm_ref = norm_ref(feature.ref or "")
if gtfs_ref and osm_ref:
if gtfs_ref == osm_ref:
score += 25
reasons["ref"] = "exact"
elif gtfs_ref in osm_ref or osm_ref in gtfs_ref:
score += 15
reasons["ref"] = "partial"
gtfs_name = norm_text(" ".join(v for v in [route.long_name, route.short_name, route.route_id] if v))
osm_name = norm_text(" ".join(v for v in [feature.name, feature.ref] if v))
name_similarity = _ratio(gtfs_name, osm_name)
score += 20 * name_similarity
reasons["name_similarity"] = round(name_similarity, 3)
gtfs_operator = norm_text(route.operator_name or "")
osm_operator = norm_text(" ".join(v for v in [feature.operator, feature.network] if v))
operator_similarity = _ratio(gtfs_operator, osm_operator) if gtfs_operator and osm_operator else 0
score += 15 * operator_similarity
reasons["operator_similarity"] = round(operator_similarity, 3)
gtfs_bbox = (route.min_lon, route.min_lat, route.max_lon, route.max_lat)
osm_bbox = (feature.min_lon, feature.min_lat, feature.max_lon, feature.max_lat)
center_distance = None
if bbox_overlap(gtfs_bbox, osm_bbox):
score += 14
reasons["bbox"] = "overlap"
if gtfs_ref and osm_ref and gtfs_ref == osm_ref and _mode_compatible(gtfs_mode, osm_mode):
score += 8
reasons["line_identity"] = "exact_ref_mode_bbox_overlap"
else:
center_distance = approx_bbox_center_distance_deg(gtfs_bbox, osm_bbox)
if center_distance is not None:
if center_distance < 0.01:
score += 12
elif center_distance < 0.03:
score += 8
elif center_distance < 0.08:
score += 4
elif gtfs_ref and osm_ref and gtfs_ref == osm_ref and center_distance > OSM_SCOPE_NEAR_DISTANCE_DEG:
score -= 8
reasons["spatial_penalty"] = "exact_ref_far_bbox_center"
reasons["bbox_center_distance_deg"] = round(center_distance, 5)
geometry_metrics = (
_geometry_match_metrics_from_profiles(route_geometry_profile, feature_geometry_profile)
if route_geometry_profile is not None and feature_geometry_profile is not None
else _geometry_match_metrics(route.geometry_geojson, feature.geometry_geojson)
)
if geometry_metrics is not None:
reasons["geometry"] = geometry_metrics
geometry_score = 34 * float(geometry_metrics["gtfs_on_osm_ratio"]) + 8 * float(geometry_metrics["osm_on_gtfs_ratio"])
if float(geometry_metrics["endpoint_distance_deg"]) < GEOMETRY_PROXIMITY_DEG * 2:
geometry_score += 6
if float(geometry_metrics["length_ratio"]) < 0.35 or float(geometry_metrics["length_ratio"]) > 2.8:
geometry_score -= 8
reasons["geometry_length"] = "implausible_ratio"
score += max(0.0, min(42.0, geometry_score))
# Extra small boost for same normalized route key.
if route.route_key and feature.route_key and route.route_key == feature.route_key:
score += 5
reasons["route_key"] = "same"
if gtfs_ref and osm_ref and gtfs_ref == osm_ref and _mode_compatible(gtfs_mode, osm_mode):
if bbox_overlap(gtfs_bbox, osm_bbox):
score = max(score, 88.0)
reasons["strong_identity"] = "exact_ref_mode_bbox_overlap"
elif center_distance is not None and center_distance < 0.02:
score = max(score, 82.0)
reasons["strong_identity"] = "exact_ref_mode_near_bbox_center"
if route.route_key and feature.route_key and route.route_key == feature.route_key and _mode_compatible(gtfs_mode, osm_mode):
if bbox_overlap(gtfs_bbox, osm_bbox):
score = max(score, 86.0)
reasons.setdefault("strong_identity", "same_route_key_mode_bbox_overlap")
if geometry_metrics is not None:
gtfs_on_osm = float(geometry_metrics["gtfs_on_osm_ratio"])
endpoint_distance = float(geometry_metrics["endpoint_distance_deg"])
if gtfs_on_osm >= 0.82 and endpoint_distance < GEOMETRY_PROXIMITY_DEG * 3 and _mode_compatible(gtfs_mode, osm_mode):
if gtfs_ref and osm_ref and gtfs_ref == osm_ref:
score = max(score, 90.0)
reasons["strong_identity"] = "exact_ref_mode_geometry_overlap"
elif gtfs_ref and osm_ref and (gtfs_ref in osm_ref or osm_ref in gtfs_ref):
score = max(score, 82.0)
reasons["strong_identity"] = "partial_ref_mode_geometry_overlap"
if (
gtfs_ref
and osm_ref
and gtfs_ref == osm_ref
and center_distance is not None
and center_distance > OSM_SCOPE_NEAR_DISTANCE_DEG
and not bbox_overlap(gtfs_bbox, osm_bbox)
and (
geometry_metrics is None
or float(geometry_metrics.get("gtfs_on_osm_ratio", 0.0)) < 0.25
)
):
score = min(score, 58.0)
reasons["spatial_cap"] = "exact_ref_far_without_geometry_overlap"
return min(score, 100.0), reasons
def route_match_scope(route: GtfsRoute, osm_scope_bbox: tuple[float | None, float | None, float | None, float | None]) -> str:
route_bbox = (route.min_lon, route.min_lat, route.max_lon, route.max_lat)
if any(value is None for value in route_bbox) or any(value is None for value in osm_scope_bbox):
return "unknown_scope"
if bbox_overlap(route_bbox, osm_scope_bbox):
return "in_osm_scope"
distance = approx_bbox_center_distance_deg(route_bbox, osm_scope_bbox)
if distance is not None and distance < OSM_SCOPE_NEAR_DISTANCE_DEG:
return "near_osm_scope"
return "outside_osm_scope"
def _combined_bbox(features: list[OsmFeature]) -> tuple[float | None, float | None, float | None, float | None]:
boxes = [
(feature.min_lon, feature.min_lat, feature.max_lon, feature.max_lat)
for feature in features
if None not in (feature.min_lon, feature.min_lat, feature.max_lon, feature.max_lat)
]
if not boxes:
return (None, None, None, None)
return (
min(float(box[0]) for box in boxes if box[0] is not None),
min(float(box[1]) for box in boxes if box[1] is not None),
max(float(box[2]) for box in boxes if box[2] is not None),
max(float(box[3]) for box in boxes if box[3] is not None),
)
def _spatially_ranked_candidates(route: GtfsRoute, candidates: list[OsmFeature], limit: int) -> list[OsmFeature]:
return [
feature
for _, feature in sorted(
((_spatial_rank(route, feature), feature) for feature in candidates),
key=lambda item: item[0],
)[: max(1, limit)]
]
def _spatial_rank(route: GtfsRoute, feature: OsmFeature) -> tuple[int, float, str]:
route_bbox = (route.min_lon, route.min_lat, route.max_lon, route.max_lat)
feature_bbox = (feature.min_lon, feature.min_lat, feature.max_lon, feature.max_lat)
distance = approx_bbox_center_distance_deg(route_bbox, feature_bbox)
if bbox_overlap(route_bbox, feature_bbox):
bucket = 0
elif distance is not None and distance < OSM_SCOPE_NEAR_DISTANCE_DEG:
bucket = 1
elif distance is not None:
bucket = 2
else:
bucket = 3
return (bucket, distance if distance is not None else 999.0, feature.osm_id)
def _expanded_bbox(
bbox: tuple[float | None, float | None, float | None, float | None],
padding: float,
) -> tuple[float, float, float, float] | None:
min_lon, min_lat, max_lon, max_lat = bbox
if None in (min_lon, min_lat, max_lon, max_lat):
return None
return (float(min_lon) - padding, float(min_lat) - padding, float(max_lon) + padding, float(max_lat) + padding)
def _chunks_int(values: list[int], size: int) -> list[list[int]]:
return [values[start : start + size] for start in range(0, len(values), max(1, size))]
def _emit_progress(
progress_callback: ProgressCallback | None,
event_type: str,
message: str,
progress_current: int | None,
progress_total: int | None,
metadata: dict[str, object] | None = None,
) -> None:
if progress_callback is not None:
progress_callback(event_type, message, progress_current, progress_total, metadata)
def _geometry_match_metrics(route_geometry: str | None, feature_geometry: str | None) -> dict[str, float] | None:
route_profile = _geometry_profile(route_geometry)
feature_profile = _geometry_profile(feature_geometry)
return _geometry_match_metrics_from_profiles(route_profile, feature_profile)
def _geometry_profile(geometry_text: str | None) -> _GeometryProfile | None:
if not geometry_text:
return None
try:
geom = shape(json.loads(geometry_text))
except Exception: # noqa: BLE001 - malformed geometry should not break matching
return None
lines = _iter_lines(geom)
if not lines:
return None
length = sum(line.length for line in lines)
if length == 0:
return None
sample_points = _sample_line_points(lines, GEOMETRY_SAMPLE_POINTS)
if not sample_points:
return None
return _GeometryProfile(geom=geom, lines=lines, length=length, sample_points=sample_points)
def _geometry_match_metrics_from_profiles(
route_profile: _GeometryProfile | None, feature_profile: _GeometryProfile | None
) -> dict[str, float] | None:
if route_profile is None or feature_profile is None:
return None
gtfs_on_osm = _near_point_ratio(route_profile.sample_points, feature_profile.geom, GEOMETRY_PROXIMITY_DEG)
osm_on_gtfs = _near_point_ratio(feature_profile.sample_points, route_profile.geom, GEOMETRY_PROXIMITY_DEG)
endpoint_distance = _endpoint_distance(route_profile.lines, feature_profile.geom)
length_ratio = route_profile.length / feature_profile.length if feature_profile.length else 0.0
return {
"gtfs_on_osm_ratio": round(gtfs_on_osm, 3),
"osm_on_gtfs_ratio": round(osm_on_gtfs, 3),
"endpoint_distance_deg": round(endpoint_distance, 6),
"length_ratio": round(length_ratio, 3),
}
def _iter_lines(geom) -> list[LineString]:
if isinstance(geom, LineString):
return [geom]
if isinstance(geom, MultiLineString):
return [line for line in geom.geoms if isinstance(line, LineString) and line.length > 0]
return []
def _sample_line_points(lines: list[LineString], count: int) -> list[Point]:
total_length = sum(line.length for line in lines)
if total_length == 0:
return []
points = []
for index in range(count):
target = total_length * (index / max(1, count - 1))
traversed = 0.0
for line in lines:
next_traversed = traversed + line.length
if target <= next_traversed or line is lines[-1]:
points.append(line.interpolate(max(0.0, min(line.length, target - traversed))))
break
traversed = next_traversed
return points
def _near_point_ratio(points: list[Point], geom, max_distance: float) -> float:
if not points:
return 0.0
near = sum(1 for point in points if geom.distance(point) <= max_distance)
return near / len(points)
def _endpoint_distance(gtfs_lines: list[LineString], osm_geom) -> float:
longest = max(gtfs_lines, key=lambda line: line.length)
coords = list(longest.coords)
if len(coords) < 2:
return 999.0
return osm_geom.distance(Point(coords[0])) + osm_geom.distance(Point(coords[-1]))
def _manual_match_rules(session: Session) -> list[_ManualMatchRule]:
rules = session.scalars(
select(MatchRule)
.where(MatchRule.active.is_(True), MatchRule.rule_type.in_(["accept_match", "reject_match"]))
.order_by(MatchRule.id.desc())
).all()
parsed: list[_ManualMatchRule] = []
for rule in rules:
try:
selector = json.loads(rule.selector_json or "{}")
action = json.loads(rule.action_json or "{}")
except json.JSONDecodeError:
continue
route_selector = selector.get("gtfs") if isinstance(selector.get("gtfs"), dict) else selector
osm_selector = action.get("osm") if isinstance(action.get("osm"), dict) else selector.get("osm")
if not isinstance(osm_selector, dict) and selector.get("osm_feature_id") is not None:
osm_selector = {"osm_feature_id": selector.get("osm_feature_id")}
status = str(action.get("status") or ("accepted" if rule.rule_type == "accept_match" else "rejected"))
parsed.append(
_ManualMatchRule(
id=rule.id,
rule_type=rule.rule_type,
route_selector=route_selector,
osm_selector=osm_selector if isinstance(osm_selector, dict) else None,
status=status,
)
)
return parsed
def _accepted_rule_for_route(
rules: list[_ManualMatchRule], route: GtfsRoute, route_source_id: int | None
) -> _ManualMatchRule | None:
for rule in rules:
if rule.rule_type != "accept_match":
continue
if rule.status != "accepted":
continue
if _route_matches_selector(route, route_source_id, rule.route_selector):
return rule
return None
def _feature_for_rule(
features: list[OsmFeature], dataset_source_ids: dict[int, int], rule: _ManualMatchRule
) -> OsmFeature | None:
if not rule.osm_selector:
return None
for feature in features:
if _feature_matches_selector(feature, dataset_source_ids.get(feature.dataset_id), rule.osm_selector):
return feature
return None
def _feature_for_rule_from_storage(
session: Session,
osm_dataset_ids: list[int],
dataset_source_ids: dict[int, int],
rule: _ManualMatchRule,
) -> OsmFeature | None:
if not rule.osm_selector:
return None
selector = rule.osm_selector
legacy_id = _safe_int(selector.get("osm_feature_id"))
if legacy_id is not None:
feature = session.get(OsmFeature, legacy_id)
if feature is not None and _feature_matches_selector(feature, dataset_source_ids.get(feature.dataset_id), selector):
return feature
scoped_dataset_ids = list(osm_dataset_ids)
expected_source = selector.get("source_id")
if expected_source is not None:
expected_source_id = _safe_int(expected_source)
if expected_source_id is not None:
scoped_dataset_ids = [
dataset_id
for dataset_id in scoped_dataset_ids
if dataset_source_ids.get(dataset_id) == expected_source_id
]
dataset_id = _safe_int(selector.get("dataset_id"))
if dataset_id is not None:
scoped_dataset_ids = [value for value in scoped_dataset_ids if value == dataset_id]
if not scoped_dataset_ids:
return None
features: list[OsmFeature] = []
osm_type = selector.get("osm_type")
osm_id = selector.get("osm_id")
if osm_type and osm_id:
features = query_osm_features(
session,
scoped_dataset_ids,
kinds=["route"],
osm_type=str(osm_type),
osm_id=str(osm_id),
limit=10,
)
if not features:
route_key = selector.get("route_key")
if route_key:
features = query_osm_features(session, scoped_dataset_ids, kinds=["route"], route_key=str(route_key))
if not features:
ref = norm_ref(selector.get("ref"))
if ref:
features = query_osm_features(session, scoped_dataset_ids, kinds=["route"], route_key=ref)
for feature in features:
if _feature_matches_selector(feature, dataset_source_ids.get(feature.dataset_id), selector):
return feature
return None
def _is_rejected_pair(
rules: list[_ManualMatchRule],
route: GtfsRoute,
route_source_id: int | None,
feature: OsmFeature,
feature_source_id: int | None,
) -> bool:
for rule in rules:
if rule.rule_type != "reject_match":
continue
if not _route_matches_selector(route, route_source_id, rule.route_selector):
continue
if rule.osm_selector and _feature_matches_selector(feature, feature_source_id, rule.osm_selector):
return True
return False
def _route_matches_selector(route: GtfsRoute, source_id: int | None, selector: dict[str, object]) -> bool:
legacy_id = selector.get("gtfs_route_id")
if legacy_id is not None and _safe_int(legacy_id) == route.id:
return True
expected_source = selector.get("source_id")
if expected_source is not None and source_id is not None and _safe_int(expected_source) != source_id:
return False
route_id = selector.get("route_id")
if route_id and str(route_id) == route.route_id:
return True
route_key = selector.get("route_key")
if route_key and route.route_key and str(route_key) == route.route_key:
return True
ref = norm_ref(selector.get("ref"))
mode = selector.get("mode")
if ref and ref == norm_ref(route.short_name or route.route_id):
return not mode or _mode_compatible(str(mode), route.mode or "")
return False
def _feature_matches_selector(feature: OsmFeature, source_id: int | None, selector: dict[str, object]) -> bool:
legacy_id = selector.get("osm_feature_id")
if legacy_id is not None and _safe_int(legacy_id) == feature.id:
return True
expected_source = selector.get("source_id")
if expected_source is not None and source_id is not None and _safe_int(expected_source) != source_id:
return False
osm_type = selector.get("osm_type")
osm_id = selector.get("osm_id")
if osm_type and osm_id and str(osm_type) == feature.osm_type and str(osm_id) == feature.osm_id:
return True
route_key = selector.get("route_key")
if route_key and feature.route_key and str(route_key) == feature.route_key:
return True
ref = norm_ref(selector.get("ref"))
mode = selector.get("mode")
if ref and ref == norm_ref(feature.ref or ""):
return not mode or _mode_compatible(str(mode), feature.mode or "")
return False
def _safe_int(value: object) -> int | None:
try:
return int(value) # type: ignore[arg-type]
except (TypeError, ValueError):
return None
def _mode_compatible(gtfs_mode: str, osm_mode: str) -> bool:
if not gtfs_mode or not osm_mode:
return True
if gtfs_mode == osm_mode:
return True
return osm_mode in MODE_GROUPS.get(gtfs_mode, {gtfs_mode}) or gtfs_mode in MODE_GROUPS.get(osm_mode, {osm_mode})
def _ratio(a: str, b: str) -> float:
if not a or not b:
return 0.0
if a == b:
return 1.0
token_ratio = _token_similarity(a, b)
if a in b or b in a:
token_ratio = max(token_ratio, 0.82)
return token_ratio
def _token_similarity(a: str, b: str) -> float:
left = set(a.split())
right = set(b.split())
if not left or not right:
return 0.0
return len(left & right) / len(left | right)
def _status_from_score(score: float) -> str:
if score >= 85:
return "matched"
if score >= 65:
return "probable"
if score >= 40:
return "weak"
return "missing"

View File

@@ -0,0 +1,508 @@
from __future__ import annotations
import json
import math
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Callable
import osmium
from sqlalchemy import delete, func, select, text
from sqlalchemy.orm import Session
from app.config import settings
from app.models import Dataset, OsmAddress
from app.pipeline.routing_layer import active_routing_dataset
from app.spatial import analyze_postgresql_tables, refresh_postgis_geometries
ProgressCallback = Callable[[str, str, int | None, int | None, dict[str, object] | None], None]
ADDRESS_INDEX_VERSION = "osm_addresses_v2_nodes_ways_area_geometry"
ADDRESS_TAGS = {
"addr:housenumber",
"addr:housename",
"addr:street",
"addr:place",
"addr:postcode",
"addr:city",
"addr:country",
"addr:unit",
"addr:suburb",
"addr:district",
"addr:municipality",
"entrance",
"name",
}
@dataclass
class AddressIndexResult:
dataset_id: int
input_path: str
addresses: int
node_addresses: int
way_addresses: int
skipped: int
version: str = ADDRESS_INDEX_VERSION
def as_dict(self) -> dict[str, object]:
return {
"version": self.version,
"dataset_id": self.dataset_id,
"input_path": self.input_path,
"addresses": self.addresses,
"node_addresses": self.node_addresses,
"way_addresses": self.way_addresses,
"skipped": self.skipped,
}
def rebuild_address_index(
session: Session,
*,
dataset_id: int | None = None,
input_path: str | Path | None = None,
reset: bool = True,
batch_size: int = 20_000,
progress_callback: ProgressCallback | None = None,
) -> dict[str, object]:
dataset = session.get(Dataset, dataset_id) if dataset_id is not None else active_routing_dataset(session)
if dataset is None:
raise ValueError("No OSM PBF dataset is available for address indexing.")
path = Path(input_path or dataset.local_path)
if not path.exists():
raise FileNotFoundError(f"Address index PBF does not exist: {path}")
if reset:
_emit(progress_callback, "address_index_clear_started", "Clearing existing OSM address index.", None, None, {"dataset_id": dataset.id})
_clear_address_rows(session, dataset_id=int(dataset.id))
session.commit()
if settings.is_postgresql_database:
_emit(progress_callback, "address_index_indexes_dropped", "Dropping address lookup indexes before bulk import.", None, None, {"dataset_id": dataset.id})
_drop_address_indexes(session)
session.commit()
_emit(progress_callback, "address_index_import_started", "Importing OSM address nodes and ways.", None, None, {"dataset_id": dataset.id, "path": str(path)})
handler = _AddressHandler(
session=session,
dataset_id=dataset.id,
batch_size=batch_size,
progress_callback=progress_callback,
)
if hasattr(osmium, "FileProcessor"):
_apply_address_file_processor(handler, path)
else:
handler.apply_file(str(path), locations=True)
handler.flush()
return finalize_address_index(
session,
dataset_id=dataset.id,
input_path=path,
node_addresses=handler.node_address_count,
way_addresses=handler.way_address_count,
skipped=handler.skipped_count,
progress_callback=progress_callback,
)
def finalize_address_index(
session: Session,
*,
dataset_id: int,
input_path: str | Path,
node_addresses: int = 0,
way_addresses: int = 0,
skipped: int = 0,
progress_callback: ProgressCallback | None = None,
) -> dict[str, object]:
dataset = session.get(Dataset, dataset_id)
if dataset is None:
raise ValueError("Address index dataset does not exist.")
if settings.is_postgresql_database:
_emit(progress_callback, "address_index_geometry_started", "Refreshing address point geometries.", None, None, {"dataset_id": dataset.id})
refresh_postgis_geometries(session, dataset_id=dataset.id, tables=["osm_addresses"], only_missing=False)
session.commit()
_emit(progress_callback, "address_index_indexes_started", "Rebuilding address lookup indexes.", None, None, {"dataset_id": dataset.id})
_create_address_indexes(session)
session.commit()
analyze_postgresql_tables(session, ["osm_addresses"])
address_count = int(session.scalar(select(func.count()).select_from(OsmAddress).where(OsmAddress.dataset_id == dataset.id)) or 0)
metadata = _metadata(dataset)
metadata["address_index"] = {
"version": ADDRESS_INDEX_VERSION,
"addresses": address_count,
"node_addresses": int(node_addresses),
"way_addresses": int(way_addresses),
"skipped": int(skipped),
"input_path": str(input_path),
}
dataset.metadata_json = json.dumps(metadata, indent=2)
session.commit()
result = AddressIndexResult(
dataset_id=dataset.id,
input_path=str(input_path),
addresses=address_count,
node_addresses=node_addresses,
way_addresses=way_addresses,
skipped=skipped,
).as_dict()
_emit(progress_callback, "address_index_import_completed", "OSM address index import completed.", address_count, address_count, result)
return result
def _clear_address_rows(session: Session, *, dataset_id: int) -> None:
if settings.is_postgresql_database:
other_dataset_count = int(
session.scalar(
select(func.count(func.distinct(OsmAddress.dataset_id))).where(OsmAddress.dataset_id != int(dataset_id))
)
or 0
)
if other_dataset_count == 0:
session.execute(text("TRUNCATE TABLE osm_addresses RESTART IDENTITY"))
return
session.execute(delete(OsmAddress).where(OsmAddress.dataset_id == int(dataset_id)))
def address_index_status(session: Session) -> dict[str, object]:
dataset = active_routing_dataset(session)
dataset_id = None if dataset is None else int(dataset.id)
address_count = 0
metadata: dict[str, object] = {}
if dataset is not None:
metadata = _metadata(dataset).get("address_index") or {}
if isinstance(metadata, dict):
try:
address_count = int(metadata.get("addresses") or 0)
except (TypeError, ValueError):
address_count = 0
if not address_count:
address_count = int(session.scalar(select(func.count()).select_from(OsmAddress).where(OsmAddress.dataset_id == dataset.id)) or 0)
installed_version = metadata.get("version") if isinstance(metadata, dict) else None
return {
"dataset_id": dataset_id,
"addresses": address_count,
"available": address_count > 0,
"version": installed_version,
"current_version": ADDRESS_INDEX_VERSION,
"stale": bool(address_count and installed_version != ADDRESS_INDEX_VERSION),
"input_path": metadata.get("input_path") if isinstance(metadata, dict) else None,
}
class _AddressHandler(osmium.SimpleHandler):
def __init__(
self,
*,
session: Session,
dataset_id: int,
batch_size: int,
progress_callback: ProgressCallback | None,
) -> None:
super().__init__()
self.session = session
self.dataset_id = int(dataset_id)
self.batch_size = max(1_000, int(batch_size))
self.progress_callback = progress_callback
self.rows: list[dict[str, object]] = []
self.address_count = 0
self.node_address_count = 0
self.way_address_count = 0
self.skipped_count = 0
self.processed_count = 0
def node(self, node) -> None:
self.process_node(node)
def way(self, way) -> None:
self.process_way(way)
def process_object(self, obj) -> None:
if hasattr(obj, "nodes"):
self.process_way(obj)
elif hasattr(obj, "location"):
self.process_node(obj)
def process_node(self, node) -> None:
self.processed_count += 1
tags = {tag.k: tag.v for tag in node.tags}
if not _has_address(tags):
return
if not node.location.valid():
self.skipped_count += 1
return
row = _address_row(
dataset_id=self.dataset_id,
osm_type="node",
osm_id=str(node.id),
tags=tags,
lon=float(node.location.lon),
lat=float(node.location.lat),
bounds=(float(node.location.lon), float(node.location.lat), float(node.location.lon), float(node.location.lat)),
geometry_geojson=None,
)
if row is None:
self.skipped_count += 1
return
self.rows.append(row)
self.node_address_count += 1
self._after_address()
def process_way(self, way) -> None:
self.processed_count += 1
tags = {tag.k: tag.v for tag in way.tags}
if not _has_address(tags):
return
coords = [
(float(node.location.lon), float(node.location.lat))
for node in way.nodes
if node.location.valid()
]
if not coords:
self.skipped_count += 1
return
lon, lat = _centroid(coords)
min_lon = min(coord[0] for coord in coords)
max_lon = max(coord[0] for coord in coords)
min_lat = min(coord[1] for coord in coords)
max_lat = max(coord[1] for coord in coords)
row = _address_row(
dataset_id=self.dataset_id,
osm_type="way",
osm_id=str(way.id),
tags=tags,
lon=lon,
lat=lat,
bounds=(min_lon, min_lat, max_lon, max_lat),
geometry_geojson=_address_area_geometry_geojson(coords, closed=_way_is_closed(way)),
)
if row is None:
self.skipped_count += 1
return
self.rows.append(row)
self.way_address_count += 1
self._after_address()
def _after_address(self) -> None:
self.address_count += 1
if len(self.rows) >= self.batch_size:
self.flush()
if self.address_count % 50_000 == 0:
_emit(
self.progress_callback,
"address_index_import_batch",
f"Imported {self.address_count:,} OSM addresses.",
self.address_count,
None,
{"processed": self.processed_count, "skipped": self.skipped_count},
)
def flush(self) -> None:
if not self.rows:
return
self.session.bulk_insert_mappings(OsmAddress, self.rows)
self.session.commit()
self.rows = []
def _apply_address_file_processor(handler: _AddressHandler, path: Path) -> None:
processor = (
osmium.FileProcessor(str(path), osmium.osm.NODE | osmium.osm.WAY)
.with_locations()
.with_filter(osmium.filter.KeyFilter("addr:housenumber", "addr:housename"))
)
for obj in processor:
handler.process_object(obj)
def _has_address(tags: dict[str, str]) -> bool:
housenumber = _clean(tags.get("addr:housenumber") or tags.get("addr:housename"))
if not housenumber:
return False
return any(_clean(tags.get(key)) for key in ("addr:street", "addr:place", "addr:city", "addr:postcode"))
def _address_row(
*,
dataset_id: int,
osm_type: str,
osm_id: str,
tags: dict[str, str],
lon: float,
lat: float,
bounds: tuple[float, float, float, float],
geometry_geojson: str | None = None,
) -> dict[str, object] | None:
housenumber = _clean(tags.get("addr:housenumber") or tags.get("addr:housename"))
street = _clean(tags.get("addr:street"))
place = _clean(tags.get("addr:place"))
postcode = _clean(tags.get("addr:postcode"))
city = _clean(tags.get("addr:city") or tags.get("addr:municipality"))
country = _clean(tags.get("addr:country"))
unit = _clean(tags.get("addr:unit"))
name = _clean(tags.get("name"))
display_name = _display_name(housenumber=housenumber, street=street, place=place, postcode=postcode, city=city, name=name)
if not display_name:
return None
search_text = _search_text(display_name, housenumber, street, place, postcode, city, country, unit, name)
selected_tags = {key: tags[key] for key in sorted(ADDRESS_TAGS) if key in tags}
min_lon, min_lat, max_lon, max_lat = bounds
return {
"dataset_id": dataset_id,
"osm_type": osm_type,
"osm_id": osm_id,
"housenumber": housenumber,
"street": street,
"place": place,
"postcode": postcode,
"city": city,
"country": country,
"unit": unit,
"name": name,
"display_name": display_name,
"search_text": search_text,
"lon": lon,
"lat": lat,
"min_lon": min_lon,
"min_lat": min_lat,
"max_lon": max_lon,
"max_lat": max_lat,
"geometry_geojson": geometry_geojson,
"tags_json": json.dumps(selected_tags, separators=(",", ":")) if selected_tags else None,
}
def _address_area_geometry_geojson(coords: list[tuple[float, float]], *, closed: bool | None = None) -> str | None:
if closed is False:
return None
if len(coords) < 3:
return None
ring_coords = list(coords)
first = ring_coords[0]
last = ring_coords[-1]
already_closed = abs(first[0] - last[0]) <= 1e-12 and abs(first[1] - last[1]) <= 1e-12
if not already_closed:
if closed is not True:
return None
ring_coords.append(first)
if len(ring_coords) < 4:
return None
ring = [[float(lon), float(lat)] for lon, lat in ring_coords]
if len({(round(lon, 12), round(lat, 12)) for lon, lat in ring_coords[:-1]}) < 3:
return None
return json.dumps({"type": "Polygon", "coordinates": [ring]}, separators=(",", ":"))
def _way_is_closed(way) -> bool:
try:
nodes = way.nodes
return len(nodes) >= 3 and nodes[0].ref == nodes[-1].ref
except (AttributeError, IndexError, TypeError):
return False
def _display_name(
*,
housenumber: str | None,
street: str | None,
place: str | None,
postcode: str | None,
city: str | None,
name: str | None,
) -> str | None:
road = street or place or name
if road and housenumber:
first = f"{road} {housenumber}"
else:
first = road or housenumber
locality = " ".join(part for part in [postcode, city] if part)
if first and locality:
return f"{first}, {locality}"
return first or locality
def _search_text(*parts: str | None) -> str:
return re.sub(r"\s+", " ", " ".join(part.casefold() for part in parts if part)).strip()
def _clean(value: object) -> str | None:
cleaned = re.sub(r"\s+", " ", str(value or "")).strip()
return cleaned or None
def _centroid(coords: list[tuple[float, float]]) -> tuple[float, float]:
if len(coords) >= 4 and coords[0] == coords[-1]:
area = 0.0
cx = 0.0
cy = 0.0
for (x1, y1), (x2, y2) in zip(coords, coords[1:]):
cross = x1 * y2 - x2 * y1
area += cross
cx += (x1 + x2) * cross
cy += (y1 + y2) * cross
if abs(area) > 1e-18:
factor = 1 / (3 * area)
return cx * factor, cy * factor
return (
math.fsum(coord[0] for coord in coords) / len(coords),
math.fsum(coord[1] for coord in coords) / len(coords),
)
def _drop_address_indexes(session: Session) -> None:
for name in [
"ix_osm_addresses_dataset_city_street",
"ix_osm_addresses_dataset_postcode",
"ix_osm_addresses_bbox",
"ix_osm_addresses_geom_gist",
"ix_osm_addresses_area_geom_gist",
"ix_osm_addresses_search_trgm",
"ix_osm_addresses_display_trgm",
"ix_osm_addresses_street_key_house",
"ix_osm_addresses_street_key_trgm",
]:
session.execute(text(f"DROP INDEX IF EXISTS {name}"))
def _create_address_indexes(session: Session) -> None:
statements = [
"CREATE INDEX IF NOT EXISTS ix_osm_addresses_dataset_city_street ON osm_addresses (dataset_id, city, street, housenumber)",
"CREATE INDEX IF NOT EXISTS ix_osm_addresses_dataset_postcode ON osm_addresses (dataset_id, postcode)",
"CREATE INDEX IF NOT EXISTS ix_osm_addresses_bbox ON osm_addresses (dataset_id, min_lon, max_lon, min_lat, max_lat)",
]
if settings.is_postgresql_database:
statements.extend(
[
"CREATE INDEX IF NOT EXISTS ix_osm_addresses_geom_gist ON osm_addresses USING GIST (geom)",
"CREATE INDEX IF NOT EXISTS ix_osm_addresses_area_geom_gist ON osm_addresses USING GIST (area_geom)",
"CREATE INDEX IF NOT EXISTS ix_osm_addresses_search_trgm ON osm_addresses USING GIN (LOWER(COALESCE(search_text, '')) gin_trgm_ops)",
"CREATE INDEX IF NOT EXISTS ix_osm_addresses_display_trgm ON osm_addresses USING GIN (LOWER(COALESCE(display_name, '')) gin_trgm_ops)",
"CREATE INDEX IF NOT EXISTS ix_osm_addresses_street_key_house ON osm_addresses (dataset_id, REPLACE(LOWER(COALESCE(NULLIF(street, ''), NULLIF(place, ''), '')), 'ß', 'ss'), housenumber)",
"CREATE INDEX IF NOT EXISTS ix_osm_addresses_street_key_trgm ON osm_addresses USING GIN (REPLACE(LOWER(COALESCE(NULLIF(street, ''), NULLIF(place, ''), '')), 'ß', 'ss') gin_trgm_ops)",
]
)
for statement in statements:
session.execute(text(statement))
def _metadata(dataset: Dataset) -> dict[str, object]:
try:
value = json.loads(dataset.metadata_json or "{}")
except json.JSONDecodeError:
return {}
return value if isinstance(value, dict) else {}
def _emit(
progress_callback: ProgressCallback | None,
event_type: str,
message: str,
progress_current: int | None,
progress_total: int | None,
metadata: dict[str, object] | None = None,
) -> None:
if progress_callback is not None:
progress_callback(event_type, message, progress_current, progress_total, metadata)

100
app/pipeline/osm_diff.py Normal file
View File

@@ -0,0 +1,100 @@
from __future__ import annotations
import json
from urllib.parse import urlparse
from sqlalchemy import select
from sqlalchemy.orm import Session
from app.config import settings
from app.models import Dataset, Source
from app.pipeline.download import materialize_source
from app.pipeline.osm_pbf import _raw_format
from app.pipeline.osm_replication import fetch_replication_state
from app.pipeline.utils import sha256_file
def run_osm_diff_source(session: Session, source: Source) -> Dataset:
"""Commit an OSM change file as a raw update artifact.
Applying the diff to an authoritative OSM base extract is a separate step;
this importer deliberately records the file without treating it as a
complete visual route layer.
"""
if _looks_like_update_directory(source.url):
return _commit_update_directory_state(session, source)
raw_path = materialize_source(source)
raw_hash = sha256_file(raw_path)
existing = session.scalar(
select(Dataset)
.where(Dataset.source_id == source.id, Dataset.kind == "osm_diff_raw", Dataset.sha256 == raw_hash)
.order_by(Dataset.id.desc())
)
if existing is not None:
return existing
dataset = Dataset(
source_id=source.id,
kind="osm_diff_raw",
local_path=str(raw_path),
sha256=raw_hash,
is_active=False,
status="committed",
metadata_json=json.dumps(
{
"stage": "raw_osm_diff",
"raw_format": _raw_format(raw_path),
"source_url": source.url,
},
indent=2,
),
)
session.add(dataset)
session.flush()
return dataset
def _commit_update_directory_state(session: Session, source: Source) -> Dataset:
state = fetch_replication_state(source.url, timeout=settings.osm_diff_state_timeout_seconds)
source_dir = settings.data_dir / "sources" / f"source_{source.id}"
source_dir.mkdir(parents=True, exist_ok=True)
state_path = source_dir / f"state_{state.sequence_number}.txt"
state_path.write_text(
"\n".join(f"{key}={value}" for key, value in sorted(state.raw.items())) + "\n",
encoding="utf-8",
)
state_hash = sha256_file(state_path)
existing = session.scalar(
select(Dataset)
.where(Dataset.source_id == source.id, Dataset.kind == "osm_diff_state", Dataset.sha256 == state_hash)
.order_by(Dataset.id.desc())
)
if existing is not None:
return existing
dataset = Dataset(
source_id=source.id,
kind="osm_diff_state",
local_path=str(state_path),
sha256=state_hash,
is_active=False,
status="committed",
metadata_json=json.dumps(
{
"stage": "osm_diff_state",
"updates_url": source.url,
"sequence_number": state.sequence_number,
"timestamp": state.timestamp,
"state": state.raw,
},
indent=2,
),
)
session.add(dataset)
session.flush()
return dataset
def _looks_like_update_directory(url: str) -> bool:
lower_path = urlparse(url).path.lower()
return lower_path.endswith("-updates") or lower_path.endswith("-updates/")

248
app/pipeline/osm_geojson.py Normal file
View File

@@ -0,0 +1,248 @@
from __future__ import annotations
import json
from pathlib import Path
from typing import Any
from sqlalchemy import select
from sqlalchemy.orm import Session
from app.config import settings
from app.models import Dataset, OsmFeature, Source
from app.osm_classification import infer_osm_route_scope
from app.osm_storage import (
OSM_STORAGE_METADATA_KEY,
OSM_STORAGE_MAIN,
OSM_STORAGE_SIDECAR_FEATURES,
create_osm_sidecar,
dedupe_osm_feature_rows,
effective_osm_feature_storage,
)
from app.pipeline.download import materialize_source
from app.pipeline.utils import first_nonempty, geometry_json_and_bbox, norm_ref, norm_text, sha256_file
from app.spatial import analyze_postgresql_tables, refresh_postgis_geometries
ROUTE_MODES = {
"train",
"railway",
"light_rail",
"subway",
"tram",
"bus",
"trolleybus",
"coach",
"ferry",
"monorail",
"funicular",
"aerialway",
}
def run_osm_geojson_source(session: Session, source: Source) -> Dataset:
local_path = materialize_source(source)
source_hash = sha256_file(local_path)
existing = session.scalar(
select(Dataset)
.where(
Dataset.source_id == source.id,
Dataset.kind == "osm_geojson",
Dataset.sha256 == source_hash,
Dataset.is_active.is_(True),
Dataset.status == "imported",
)
.order_by(Dataset.id.desc())
)
if existing is not None:
return existing
return import_osm_geojson(session=session, source=source, path=local_path, source_hash=source_hash)
def import_osm_geojson(
session: Session,
source: Source,
path: Path,
source_hash: str | None = None,
*,
storage_mode: str | None = None,
) -> Dataset:
for dataset in source.datasets:
dataset.is_active = False
dataset = Dataset(
source_id=source.id,
kind="osm_geojson",
local_path=str(path),
sha256=source_hash or sha256_file(path),
is_active=True,
status="importing",
)
session.add(dataset)
session.flush()
source_hash = source_hash or sha256_file(path)
dataset.metadata_json = json.dumps(
prepare_osm_geojson_storage(
session=session,
dataset=dataset,
path=path,
source_hash=source_hash,
storage_mode=storage_mode,
),
indent=2,
)
dataset.status = "imported"
source.status = "ok"
source.last_error = None
session.flush()
return dataset
def prepare_osm_geojson_storage(
*,
session: Session,
dataset: Dataset,
path: Path,
source_hash: str | None = None,
storage_mode: str | None = None,
) -> dict[str, object]:
data = json.loads(path.read_text(encoding="utf-8"))
features = _as_features(data)
feature_rows = [_feature_row(dataset.id, idx, feature) for idx, feature in enumerate(features)]
storage = effective_osm_feature_storage(storage_mode)
if storage not in {OSM_STORAGE_MAIN, OSM_STORAGE_SIDECAR_FEATURES}:
raise ValueError(f"Unsupported OSM feature storage mode: {storage}")
if storage == OSM_STORAGE_SIDECAR_FEATURES:
return {
"features": len(feature_rows),
OSM_STORAGE_METADATA_KEY: create_osm_sidecar(dataset, feature_rows, source_hash=source_hash or dataset.sha256),
}
_insert_main_features(session, feature_rows)
session.flush()
refresh_postgis_geometries(session, dataset_id=dataset.id, tables=["osm_features"])
analyze_postgresql_tables(session, ["osm_features"])
return {"features": len(feature_rows), OSM_STORAGE_METADATA_KEY: {"mode": OSM_STORAGE_MAIN}}
def _insert_main_features(session: Session, feature_rows: list[dict[str, object]]) -> None:
objects: list[OsmFeature] = []
deduped_rows, _duplicate_count = dedupe_osm_feature_rows(feature_rows)
for row in deduped_rows:
objects.append(
OsmFeature(
dataset_id=row["dataset_id"],
osm_type=row["osm_type"],
osm_id=row["osm_id"],
kind=row["kind"],
mode=row["mode"],
route_scope=row["route_scope"],
name=row["name"],
ref=row["ref"],
operator=row["operator"],
network=row["network"],
geometry_geojson=row["geometry_geojson"],
min_lon=row["min_lon"],
min_lat=row["min_lat"],
max_lon=row["max_lon"],
max_lat=row["max_lat"],
tags_json=row["tags_json"],
route_key=row["route_key"],
operator_key=row["operator_key"],
)
)
if len(objects) >= 5000:
session.bulk_save_objects(objects)
objects.clear()
if objects:
session.bulk_save_objects(objects)
def _feature_row(dataset_id: int, idx: int, feature: dict[str, Any]) -> dict[str, object]:
props = feature.get("properties") or {}
geometry = feature.get("geometry")
geometry_text, bbox = geometry_json_and_bbox(geometry)
osm_type = str(first_nonempty(props.get("osm_type"), props.get("@type"), props.get("type"), "feature"))
osm_id = str(first_nonempty(props.get("osm_id"), props.get("@id"), props.get("id"), f"feature_{idx}"))
mode = _infer_mode(props)
kind = _infer_kind(props, mode)
name = first_nonempty(props.get("name"), props.get("official_name")) or None
ref = first_nonempty(props.get("ref"), props.get("route_ref"), props.get("line")) or None
operator = first_nonempty(props.get("operator"), props.get("agency"), props.get("brand")) or None
network = first_nonempty(props.get("network"), props.get("network:short")) or None
route_scope = infer_osm_route_scope(mode=mode, ref=ref, name=name, network=network, tags=props)
route_key = norm_ref(ref) or norm_text(name) or norm_ref(osm_id)
operator_key = norm_text(operator or network or "")
return {
"dataset_id": dataset_id,
"osm_type": osm_type,
"osm_id": osm_id,
"kind": kind,
"mode": mode,
"route_scope": route_scope,
"name": name,
"ref": ref,
"operator": operator,
"network": network,
"geometry_geojson": geometry_text,
"min_lon": bbox[0],
"min_lat": bbox[1],
"max_lon": bbox[2],
"max_lat": bbox[3],
"tags_json": json.dumps(props, separators=(",", ":")),
"route_key": route_key,
"operator_key": operator_key,
}
def _as_features(data: Any) -> list[dict[str, Any]]:
if isinstance(data, dict) and data.get("type") == "FeatureCollection":
return [f for f in data.get("features", []) if isinstance(f, dict)]
if isinstance(data, dict) and data.get("type") == "Feature":
return [data]
if isinstance(data, list):
return [f for f in data if isinstance(f, dict)]
raise ValueError("OSM source must be GeoJSON FeatureCollection, Feature, or list of Features")
def _infer_mode(props: dict[str, Any]) -> str | None:
for key in ("mode", "route", "route_master"):
value = str(props.get(key) or "").strip()
if value in ROUTE_MODES:
return "train" if value == "railway" else value
railway = str(props.get("railway") or "").strip()
if railway in {"station", "halt"}:
return "train"
if railway == "tram_stop":
return "tram"
if railway == "subway_entrance":
return "subway"
if str(props.get("highway") or "") == "bus_stop" or str(props.get("amenity") or "") == "bus_station":
return "bus"
if str(props.get("amenity") or "") == "ferry_terminal":
return "ferry"
if str(props.get("aerialway") or "") == "station":
return "aerialway"
return None
def _infer_kind(props: dict[str, Any], mode: str | None) -> str:
explicit_kind = str(props.get("kind") or "").strip()
if explicit_kind in {"route", "stop", "station", "terminal", "infra", "feature"}:
return explicit_kind
if str(props.get("type") or "") in {"route", "route_master"} or str(props.get("route") or "") in ROUTE_MODES:
return "route"
if str(props.get("amenity") or "") == "ferry_terminal":
return "terminal"
if str(props.get("amenity") or "") == "bus_station":
return "terminal"
if str(props.get("railway") or "") in {"station", "halt"}:
return "station"
if str(props.get("aerialway") or "") == "station":
return "station"
if str(props.get("public_transport") or "") in {"platform", "stop_position", "station"}:
return "stop"
if str(props.get("highway") or "") == "bus_stop":
return "stop"
if mode:
return "infra"
return "feature"

View File

@@ -0,0 +1,456 @@
from __future__ import annotations
from datetime import datetime, timezone
import json
from pathlib import Path
import sqlite3
from typing import Callable
from sqlalchemy import func, select, text
from sqlalchemy.orm import Session
from app.models import Dataset, OsmFeature
from app.osm_classification import OSM_ROUTE_SCOPE_CLASSIFIER_VERSION, infer_osm_route_scope_from_tags
from app.osm_storage import (
dataset_metadata,
drop_osm_sidecar_route_scope_indexes,
ensure_osm_sidecar_schema,
features_are_sidecar,
rebuild_osm_sidecar_indexes,
sidecar_path,
writable_sidecar_connection,
)
from app.pipeline.state import (
STAGE_BUILD_INDEXES,
STAGE_LABEL_FEATURES,
dependency_hash,
finish_pipeline_run,
latest_completed_run,
start_pipeline_run,
)
OSM_LABEL_FEATURES_VERSION = OSM_ROUTE_SCOPE_CLASSIFIER_VERSION
MAIN_ROUTE_SCOPE_INDEX = "ix_osm_features_scope_bbox"
MAIN_INDEX_REBUILD_THRESHOLD = 10_000
SIDECAR_INDEX_REBUILD_THRESHOLD = 10_000
ProgressCallback = Callable[[str, str, int | None, int | None, dict[str, object] | None], None]
def relabel_osm_features(
session: Session,
*,
dataset_id: int | None = None,
chunk_size: int = 5000,
force: bool = False,
rebuild_indexes: bool = True,
progress_callback: ProgressCallback | None = None,
job_id: int | None = None,
) -> dict[str, object]:
datasets = _target_datasets(session, dataset_id)
result: dict[str, object] = {
"version": OSM_LABEL_FEATURES_VERSION,
"datasets": len(datasets),
"processed": 0,
"changed": 0,
"skipped": 0,
"missing": 0,
"index_rebuilds": 0,
"dataset_results": [],
}
_emit_progress(
progress_callback,
"osm_labeling_started",
f"Relabeling {len(datasets)} OSM dataset(s).",
0,
len(datasets),
{"dataset_id": dataset_id, "force": force, "version": OSM_LABEL_FEATURES_VERSION},
)
for index, dataset in enumerate(datasets, start=1):
dataset_result = relabel_osm_dataset(
session,
dataset,
chunk_size=chunk_size,
force=force,
rebuild_indexes=rebuild_indexes,
progress_callback=progress_callback,
job_id=job_id,
)
result["processed"] = int(result["processed"]) + int(dataset_result.get("processed", 0) or 0)
result["changed"] = int(result["changed"]) + int(dataset_result.get("changed", 0) or 0)
result["skipped"] = int(result["skipped"]) + (1 if dataset_result.get("status") == "skipped" else 0)
result["missing"] = int(result["missing"]) + (1 if dataset_result.get("status") == "missing_sidecar" else 0)
result["index_rebuilds"] = int(result["index_rebuilds"]) + int(dataset_result.get("index_rebuilds", 0) or 0)
result["dataset_results"].append(dataset_result) # type: ignore[union-attr]
_emit_progress(
progress_callback,
"osm_labeling_dataset_completed",
f"Relabeled {index}/{len(datasets)} OSM dataset(s).",
index,
len(datasets),
dataset_result,
)
_emit_progress(progress_callback, "osm_labeling_completed", "OSM feature relabeling completed.", len(datasets), len(datasets), result)
return result
def relabel_osm_dataset(
session: Session,
dataset: Dataset,
*,
chunk_size: int = 5000,
force: bool = False,
rebuild_indexes: bool = True,
progress_callback: ProgressCallback | None = None,
job_id: int | None = None,
) -> dict[str, object]:
dependency = _label_dependency(dataset)
dependency_hash_value = dependency_hash(dependency)
if not force and _dataset_label_is_current(session, dataset, dependency_hash_value):
return {
"dataset_id": dataset.id,
"source_id": dataset.source_id,
"status": "skipped",
"reason": "label_features dependency is current",
"dependency_hash": dependency_hash_value,
"version": OSM_LABEL_FEATURES_VERSION,
"processed": 0,
"changed": 0,
"index_rebuilds": 0,
}
run = start_pipeline_run(
session,
stage=STAGE_LABEL_FEATURES,
version=OSM_LABEL_FEATURES_VERSION,
dependency_hash_value=dependency_hash_value,
source_id=dataset.source_id,
dataset_id=dataset.id,
job_id=job_id,
inputs=dependency,
)
session.commit()
try:
if features_are_sidecar(dataset):
counts = _relabel_sidecar_dataset(dataset, chunk_size=chunk_size, rebuild_indexes=rebuild_indexes, progress_callback=progress_callback)
else:
counts = _relabel_main_dataset(session, dataset, chunk_size=chunk_size, rebuild_indexes=rebuild_indexes, progress_callback=progress_callback)
output = {
"dataset_id": dataset.id,
"source_id": dataset.source_id,
"status": "completed",
"dependency_hash": dependency_hash_value,
"version": OSM_LABEL_FEATURES_VERSION,
**counts,
}
_stamp_dataset_metadata(session, dataset, dependency_hash_value, output)
finish_pipeline_run(session, run, outputs=output)
session.commit()
return output
except FileNotFoundError as exc:
output = {
"dataset_id": dataset.id,
"source_id": dataset.source_id,
"status": "missing_sidecar",
"dependency_hash": dependency_hash_value,
"version": OSM_LABEL_FEATURES_VERSION,
"processed": 0,
"changed": 0,
"index_rebuilds": 0,
"error": str(exc),
}
finish_pipeline_run(session, run, status="failed", outputs=output, error=str(exc))
session.commit()
return output
except Exception as exc:
finish_pipeline_run(session, run, status="failed", error=str(exc))
session.commit()
raise
def _target_datasets(session: Session, dataset_id: int | None) -> list[Dataset]:
stmt = select(Dataset).where(Dataset.kind == "osm_geojson", Dataset.status == "imported")
if dataset_id is None:
stmt = stmt.where(Dataset.is_active.is_(True))
else:
stmt = stmt.where(Dataset.id == dataset_id)
return session.scalars(stmt.order_by(Dataset.source_id, Dataset.id)).all()
def _dataset_label_is_current(session: Session, dataset: Dataset, dependency_hash_value: str) -> bool:
metadata = dataset_metadata(dataset)
label_info = metadata.get("label_features")
metadata_current = (
isinstance(label_info, dict)
and label_info.get("version") == OSM_LABEL_FEATURES_VERSION
and label_info.get("dependency_hash") == dependency_hash_value
)
if not metadata_current:
return False
return (
latest_completed_run(
session,
stage=STAGE_LABEL_FEATURES,
version=OSM_LABEL_FEATURES_VERSION,
dependency_hash_value=dependency_hash_value,
source_id=dataset.source_id,
dataset_id=dataset.id,
)
is not None
)
def _relabel_sidecar_dataset(
dataset: Dataset,
*,
chunk_size: int,
rebuild_indexes: bool,
progress_callback: ProgressCallback | None,
) -> dict[str, int | str]:
path = sidecar_path(dataset)
if path is None or not path.exists():
raise FileNotFoundError(f"OSM sidecar does not exist: {path}")
with writable_sidecar_connection(dataset) as connection:
ensure_osm_sidecar_schema(connection)
total = int(connection.execute("SELECT COUNT(*) FROM osm_features").fetchone()[0] or 0)
should_rebuild_index = rebuild_indexes and total >= SIDECAR_INDEX_REBUILD_THRESHOLD
if should_rebuild_index:
drop_osm_sidecar_route_scope_indexes(connection)
connection.commit()
processed = 0
changed = 0
last_id = 0
try:
while True:
rows = connection.execute(
"""
SELECT id, mode, ref, name, network, tags_json, route_scope
FROM osm_features
WHERE id > ?
ORDER BY id
LIMIT ?
""",
(last_id, max(1, int(chunk_size))),
).fetchall()
if not rows:
break
updates: list[tuple[str | None, int]] = []
for row in rows:
last_id = int(row["id"])
new_scope = _classified_scope(row["mode"], row["ref"], row["name"], row["network"], row["tags_json"])
if _normalize_scope(row["route_scope"]) != new_scope:
updates.append((new_scope, last_id))
if updates:
connection.executemany("UPDATE osm_features SET route_scope = ? WHERE id = ?", updates)
processed += len(rows)
changed += len(updates)
connection.commit()
_emit_progress(
progress_callback,
"osm_labeling_batch",
f"Relabeled {processed}/{total} OSM sidecar features.",
processed,
total,
{"dataset_id": dataset.id, "changed": changed, "storage": "sidecar"},
)
finally:
index_rebuilds = 0
if should_rebuild_index:
rebuild_osm_sidecar_indexes(connection)
connection.commit()
index_rebuilds = 1
_record_sidecar_index_build(connection, dataset, path)
_record_sidecar_label(connection, dataset, processed=processed, changed=changed)
connection.commit()
return {"storage": "sidecar", "processed": processed, "changed": changed, "index_rebuilds": index_rebuilds}
def _relabel_main_dataset(
session: Session,
dataset: Dataset,
*,
chunk_size: int,
rebuild_indexes: bool,
progress_callback: ProgressCallback | None,
) -> dict[str, int | str]:
total = int(session.scalar(select(func.count()).select_from(OsmFeature).where(OsmFeature.dataset_id == dataset.id)) or 0)
should_rebuild_index = rebuild_indexes and total >= MAIN_INDEX_REBUILD_THRESHOLD
index_rebuilds = 0
if should_rebuild_index:
session.execute(text(f"DROP INDEX IF EXISTS {MAIN_ROUTE_SCOPE_INDEX}"))
session.commit()
processed = 0
changed = 0
last_id = 0
try:
while True:
rows = session.scalars(
select(OsmFeature)
.where(OsmFeature.dataset_id == dataset.id, OsmFeature.id > last_id)
.order_by(OsmFeature.id)
.limit(max(1, int(chunk_size)))
).all()
if not rows:
break
updates: list[dict[str, object]] = []
for feature in rows:
last_id = int(feature.id)
new_scope = _classified_scope(feature.mode, feature.ref, feature.name, feature.network, feature.tags_json)
if _normalize_scope(feature.route_scope) != new_scope:
updates.append({"id": feature.id, "route_scope": new_scope})
if updates:
session.bulk_update_mappings(OsmFeature, updates)
processed += len(rows)
changed += len(updates)
session.commit()
_emit_progress(
progress_callback,
"osm_labeling_batch",
f"Relabeled {processed}/{total} main-table OSM features.",
processed,
total,
{"dataset_id": dataset.id, "changed": changed, "storage": "main"},
)
finally:
if should_rebuild_index:
session.execute(
text(
"CREATE INDEX IF NOT EXISTS ix_osm_features_scope_bbox "
"ON osm_features (dataset_id, kind, mode, route_scope, min_lon, max_lon, min_lat, max_lat)"
)
)
session.commit()
index_rebuilds = 1
_record_main_index_build(session, dataset)
return {"storage": "main", "processed": processed, "changed": changed, "index_rebuilds": index_rebuilds}
def _classified_scope(mode: object, ref: object, name: object, network: object, tags_json: object) -> str | None:
return _normalize_scope(
infer_osm_route_scope_from_tags(
None if mode is None else str(mode),
None if ref is None else str(ref),
None if name is None else str(name),
None if network is None else str(network),
None if tags_json is None else str(tags_json),
)
)
def _normalize_scope(value: object) -> str | None:
text_value = str(value or "").strip()
return text_value or None
def _label_dependency(dataset: Dataset) -> dict[str, object]:
metadata = dataset_metadata(dataset)
storage = metadata.get("osm_storage") if isinstance(metadata, dict) else None
path = sidecar_path(dataset)
path_fingerprint: dict[str, object] | None = None
if path is not None:
resolved = Path(path)
if resolved.exists():
path_fingerprint = {"path": str(resolved), "exists": True}
else:
path_fingerprint = {"path": str(resolved), "missing": True}
return {
"dataset_id": dataset.id,
"source_id": dataset.source_id,
"kind": dataset.kind,
"dataset_sha256": dataset.sha256,
"storage": storage,
"sidecar": path_fingerprint,
"classifier_version": OSM_LABEL_FEATURES_VERSION,
}
def _stamp_dataset_metadata(session: Session, dataset: Dataset, dependency_hash_value: str, output: dict[str, object]) -> None:
refreshed = session.get(Dataset, dataset.id)
if refreshed is None:
return
metadata = dataset_metadata(refreshed)
metadata["label_features"] = {
"stage": STAGE_LABEL_FEATURES,
"version": OSM_LABEL_FEATURES_VERSION,
"dependency_hash": dependency_hash_value,
"labeled_at": datetime.now(timezone.utc).isoformat(),
"processed": output.get("processed", 0),
"changed": output.get("changed", 0),
"storage": output.get("storage"),
}
refreshed.metadata_json = json.dumps(metadata, indent=2)
session.flush()
def _record_sidecar_label(connection: sqlite3.Connection, dataset: Dataset, *, processed: int, changed: int) -> None:
connection.execute("CREATE TABLE IF NOT EXISTS pipeline_metadata (key TEXT PRIMARY KEY, value TEXT NOT NULL)")
connection.execute(
"INSERT OR REPLACE INTO pipeline_metadata (key, value) VALUES (?, ?)",
(
"label_features",
json.dumps(
{
"stage": STAGE_LABEL_FEATURES,
"version": OSM_LABEL_FEATURES_VERSION,
"dataset_id": dataset.id,
"processed": processed,
"changed": changed,
"updated_at": datetime.now(timezone.utc).isoformat(),
},
sort_keys=True,
separators=(",", ":"),
),
),
)
def _record_sidecar_index_build(connection: sqlite3.Connection, dataset: Dataset, path: Path) -> None:
connection.execute("CREATE TABLE IF NOT EXISTS pipeline_metadata (key TEXT PRIMARY KEY, value TEXT NOT NULL)")
connection.execute(
"INSERT OR REPLACE INTO pipeline_metadata (key, value) VALUES (?, ?)",
(
"build_indexes:route_scope",
json.dumps(
{
"stage": STAGE_BUILD_INDEXES,
"version": "osm_sidecar_indexes_v1",
"dataset_id": dataset.id,
"path": str(path),
"updated_at": datetime.now(timezone.utc).isoformat(),
},
sort_keys=True,
separators=(",", ":"),
),
),
)
def _record_main_index_build(session: Session, dataset: Dataset) -> None:
dependency = {
"dataset_id": dataset.id,
"index": MAIN_ROUTE_SCOPE_INDEX,
"version": "osm_main_indexes_v1",
}
run = start_pipeline_run(
session,
stage=STAGE_BUILD_INDEXES,
version="osm_main_indexes_v1",
dependency_hash_value=dependency_hash(dependency),
source_id=dataset.source_id,
dataset_id=dataset.id,
inputs=dependency,
)
finish_pipeline_run(session, run, outputs={"index": MAIN_ROUTE_SCOPE_INDEX})
session.commit()
def _emit_progress(
callback: ProgressCallback | None,
event_type: str,
message: str,
current: int | None,
total: int | None,
metadata: dict[str, object] | None,
) -> None:
if callback is not None:
callback(event_type, message, current, total, metadata)

1581
app/pipeline/osm_pbf.py Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,105 @@
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
import subprocess
from urllib.parse import urljoin, urlparse
import requests
@dataclass(frozen=True)
class ReplicationState:
sequence_number: int
timestamp: str | None
raw: dict[str, str]
def fetch_replication_state(updates_url: str, *, timeout: float = 30) -> ReplicationState:
state_url = _state_url(updates_url)
response = requests.get(state_url, timeout=timeout)
response.raise_for_status()
return parse_replication_state_text(response.text)
def parse_replication_state_text(text: str) -> ReplicationState:
values: dict[str, str] = {}
for line in text.splitlines():
line = line.strip()
if not line or line.startswith("#") or "=" not in line:
continue
key, value = line.split("=", 1)
values[key.strip()] = _unescape_state_value(value.strip())
sequence = values.get("sequenceNumber")
if sequence is None:
raise ValueError("replication state is missing sequenceNumber")
try:
sequence_number = int(sequence)
except ValueError as exc:
raise ValueError(f"invalid replication sequenceNumber: {sequence}") from exc
return ReplicationState(
sequence_number=sequence_number,
timestamp=values.get("timestamp"),
raw=values,
)
def diff_url_for_sequence(updates_url: str, sequence_number: int) -> str:
padded = str(sequence_number).zfill(max(9, ((len(str(sequence_number)) + 2) // 3) * 3))
parts = [padded[index : index + 3] for index in range(0, len(padded), 3)]
return urljoin(_directory_url(updates_url), "/".join(parts) + ".osc.gz")
def download_diff(updates_url: str, sequence_number: int, output_dir: Path, *, timeout: float = 120) -> Path:
url = diff_url_for_sequence(updates_url, sequence_number)
parsed_path = Path(urlparse(url).path)
output_path = output_dir / parsed_path.name
nested = output_dir / parsed_path.parent.name / output_path.name
if output_path.exists():
return output_path
if nested.exists():
return nested
output_dir.mkdir(parents=True, exist_ok=True)
temp_path = output_dir / f"{sequence_number}.download"
with requests.get(url, stream=True, timeout=timeout) as response:
response.raise_for_status()
with temp_path.open("wb") as handle:
for chunk in response.iter_content(chunk_size=1024 * 1024):
if chunk:
handle.write(chunk)
temp_path.replace(output_path)
return output_path
def apply_osm_changes(base_path: Path, diff_paths: list[Path], output_path: Path, host_tool_path: Path) -> subprocess.CompletedProcess[str]:
if not diff_paths:
raise ValueError("no OSM change files supplied")
output_path.parent.mkdir(parents=True, exist_ok=True)
command = [
str(host_tool_path),
"osmium",
"apply-changes",
"--output",
str(output_path),
"--overwrite",
str(base_path),
*[str(path) for path in diff_paths],
]
return subprocess.run(command, check=True, capture_output=True, text=True)
def _state_url(updates_url: str) -> str:
return urljoin(_directory_url(updates_url), "state.txt")
def _directory_url(url: str) -> str:
return url if url.endswith("/") else f"{url}/"
def _unescape_state_value(value: str) -> str:
return (
value.replace("\\:", ":")
.replace("\\=", "=")
.replace("\\ ", " ")
.replace("\\\\", "\\")
)

1903
app/pipeline/route_layer.py Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,473 @@
from __future__ import annotations
import json
import math
from dataclasses import dataclass
from pathlib import Path
from typing import Callable
import osmium
from sqlalchemy import delete, func, select, text
from sqlalchemy.dialects.postgresql import insert as postgresql_insert
from sqlalchemy.orm import Session
from app.config import settings
from app.models import Dataset, RoutingEdge, RoutingNode
from app.spatial import analyze_postgresql_tables, refresh_postgis_geometries
ProgressCallback = Callable[[str, str, int | None, int | None, dict[str, object] | None], None]
ROUTING_LAYER_VERSION = "routing_layer_v2_osm_highway_segments_service_tags"
DRIVE_HIGHWAYS = {
"motorway",
"motorway_link",
"trunk",
"trunk_link",
"primary",
"primary_link",
"secondary",
"secondary_link",
"tertiary",
"tertiary_link",
"unclassified",
"residential",
"living_street",
"service",
"road",
"track",
}
WALK_HIGHWAYS = {
"pedestrian",
"footway",
"path",
"steps",
"cycleway",
"bridleway",
"living_street",
"residential",
"service",
"track",
"unclassified",
"tertiary",
"tertiary_link",
"secondary",
"secondary_link",
"primary",
"primary_link",
"road",
}
EXCLUDED_HIGHWAYS = {"construction", "proposed", "abandoned", "platform", "raceway"}
NO_VALUES = {"no", "private", "agricultural", "forestry", "delivery", "customers"}
YES_VALUES = {"yes", "designated", "permissive", "destination"}
ONEWAY_FORWARD = {"yes", "true", "1"}
ONEWAY_REVERSE = {"-1", "reverse"}
DEFAULT_DRIVE_SPEED_KMH = {
"motorway": 110,
"motorway_link": 50,
"trunk": 90,
"trunk_link": 45,
"primary": 70,
"primary_link": 40,
"secondary": 60,
"secondary_link": 35,
"tertiary": 50,
"tertiary_link": 30,
"unclassified": 40,
"residential": 30,
"living_street": 10,
"service": 15,
"road": 30,
"track": 15,
}
DEFAULT_WALK_SPEED_MPS = 1.35
STEP_WALK_SPEED_MPS = 0.65
@dataclass
class RoutingImportResult:
dataset_id: int
input_path: str
nodes: int
edges: int
walk_edges: int
drive_edges: int
skipped_ways: int
version: str = ROUTING_LAYER_VERSION
def as_dict(self) -> dict[str, object]:
return {
"version": self.version,
"dataset_id": self.dataset_id,
"input_path": self.input_path,
"nodes": self.nodes,
"edges": self.edges,
"walk_edges": self.walk_edges,
"drive_edges": self.drive_edges,
"skipped_ways": self.skipped_ways,
}
def active_routing_dataset(session: Session) -> Dataset | None:
active_osm = session.scalar(
select(Dataset).where(Dataset.kind == "osm_geojson", Dataset.is_active.is_(True)).order_by(Dataset.id.desc())
)
if active_osm is not None:
metadata = _metadata(active_osm)
raw_dataset_id = metadata.get("raw_dataset_id")
if raw_dataset_id is not None:
raw = session.get(Dataset, int(raw_dataset_id))
if raw is not None and Path(raw.local_path).exists():
return raw
return session.scalar(
select(Dataset)
.where(Dataset.kind == "osm_pbf_raw")
.order_by(Dataset.is_active.desc(), Dataset.id.desc())
)
def rebuild_routing_layer(
session: Session,
*,
dataset_id: int | None = None,
input_path: str | Path | None = None,
reset: bool = True,
batch_size: int = 5000,
progress_callback: ProgressCallback | None = None,
) -> dict[str, object]:
if not settings.is_postgresql_database:
raise RuntimeError("The routing layer importer requires PostgreSQL/PostGIS.")
dataset = session.get(Dataset, dataset_id) if dataset_id is not None else active_routing_dataset(session)
if dataset is None:
raise ValueError("No OSM PBF dataset is available for routing import.")
path = Path(input_path or dataset.local_path)
if not path.exists():
raise FileNotFoundError(f"Routing import PBF does not exist: {path}")
if reset:
_emit(progress_callback, "routing_layer_clear_started", "Clearing existing routing graph.", None, None, {"dataset_id": dataset.id})
session.execute(delete(RoutingEdge).where(RoutingEdge.dataset_id == dataset.id))
session.execute(delete(RoutingNode).where(RoutingNode.dataset_id == dataset.id))
session.commit()
_emit(progress_callback, "routing_layer_import_started", "Importing routable OSM highway graph.", None, None, {"dataset_id": dataset.id, "path": str(path)})
handler = _RoutingGraphHandler(session=session, dataset_id=dataset.id, batch_size=batch_size, progress_callback=progress_callback)
handler.apply_file(str(path), locations=True)
handler.flush()
return finalize_routing_layer(
session,
dataset_id=dataset.id,
input_path=str(path),
skipped_way_count=handler.skipped_way_count,
progress_callback=progress_callback,
)
def finalize_routing_layer(
session: Session,
*,
dataset_id: int | None = None,
input_path: str | Path | None = None,
skipped_way_count: int = 0,
progress_callback: ProgressCallback | None = None,
) -> dict[str, object]:
if not settings.is_postgresql_database:
raise RuntimeError("The routing layer finalizer requires PostgreSQL/PostGIS.")
dataset = session.get(Dataset, dataset_id) if dataset_id is not None else active_routing_dataset(session)
if dataset is None:
raise ValueError("No routing dataset is available to finalize.")
path = Path(input_path or dataset.local_path)
_emit(progress_callback, "routing_layer_geometry_indexes_dropped", "Dropping routing geometry indexes before bulk refresh.", None, None, {"dataset_id": dataset.id})
_drop_routing_geometry_indexes(session)
session.commit()
_emit(progress_callback, "routing_layer_geometry_started", "Refreshing routing node PostGIS geometries.", None, None, {"dataset_id": dataset.id})
refresh_postgis_geometries(session, dataset_id=dataset.id, tables=["routing_nodes"], only_missing=False)
session.commit()
_emit(progress_callback, "routing_layer_geometry_indexes_started", "Rebuilding routing geometry indexes.", None, None, {"dataset_id": dataset.id})
_create_routing_geometry_indexes(session)
session.commit()
analyze_postgresql_tables(session, ["routing_nodes", "routing_edges"])
node_count = int(session.scalar(select(func.count()).select_from(RoutingNode).where(RoutingNode.dataset_id == dataset.id)) or 0)
edge_count = int(session.scalar(select(func.count()).select_from(RoutingEdge).where(RoutingEdge.dataset_id == dataset.id)) or 0)
walk_edge_count = int(session.scalar(select(func.count()).select_from(RoutingEdge).where(RoutingEdge.dataset_id == dataset.id, RoutingEdge.walk_cost_s.is_not(None))) or 0)
drive_edge_count = int(session.scalar(select(func.count()).select_from(RoutingEdge).where(RoutingEdge.dataset_id == dataset.id, RoutingEdge.drive_cost_s.is_not(None))) or 0)
dataset_metadata = _metadata(dataset)
dataset_metadata["routing_layer"] = {
"version": ROUTING_LAYER_VERSION,
"nodes": node_count,
"edges": edge_count,
"walk_edges": walk_edge_count,
"drive_edges": drive_edge_count,
"input_path": str(path),
}
dataset.metadata_json = json.dumps(dataset_metadata, indent=2)
session.commit()
result = RoutingImportResult(
dataset_id=dataset.id,
input_path=str(path),
nodes=node_count,
edges=edge_count,
walk_edges=walk_edge_count,
drive_edges=drive_edge_count,
skipped_ways=skipped_way_count,
).as_dict()
_emit(progress_callback, "routing_layer_import_completed", "Routing graph import completed.", edge_count, edge_count, result)
return result
def _drop_routing_geometry_indexes(session: Session) -> None:
session.execute(text("DROP INDEX IF EXISTS ix_routing_nodes_geom_gist"))
session.execute(text("DROP INDEX IF EXISTS ix_routing_edges_geom_gist"))
session.execute(text("DROP INDEX IF EXISTS ix_routing_edges_bbox_box_gist"))
def _create_routing_geometry_indexes(session: Session) -> None:
session.execute(text("CREATE INDEX IF NOT EXISTS ix_routing_nodes_geom_gist ON routing_nodes USING GIST (geom)"))
session.execute(text("CREATE INDEX IF NOT EXISTS ix_routing_edges_bbox_box_gist ON routing_edges USING GIST (box(point(max_lon, max_lat), point(min_lon, min_lat)))"))
class _RoutingGraphHandler(osmium.SimpleHandler):
def __init__(
self,
*,
session: Session,
dataset_id: int,
batch_size: int,
progress_callback: ProgressCallback | None,
) -> None:
super().__init__()
self.session = session
self.dataset_id = dataset_id
self.batch_size = max(500, int(batch_size))
self.progress_callback = progress_callback
self.nodes: dict[int, dict[str, object]] = {}
self.edges: list[dict[str, object]] = []
self.node_count = int(
session.scalar(select(func.count()).select_from(RoutingNode).where(RoutingNode.dataset_id == dataset_id)) or 0
)
self.edge_count = int(
session.scalar(select(func.count()).select_from(RoutingEdge).where(RoutingEdge.dataset_id == dataset_id)) or 0
)
self.walk_edge_count = 0
self.drive_edge_count = 0
self.skipped_way_count = 0
self.processed_way_count = 0
def way(self, way) -> None:
tags = {tag.k: tag.v for tag in way.tags}
highway = tags.get("highway")
if not highway or highway in EXCLUDED_HIGHWAYS:
self.skipped_way_count += 1
return
walkable = _walkable(tags, highway)
drivable = _drivable(tags, highway)
if not walkable and not drivable:
self.skipped_way_count += 1
return
nodes = []
for node in way.nodes:
if not node.location.valid():
continue
nodes.append((int(node.ref), float(node.location.lon), float(node.location.lat)))
if len(nodes) < 2:
self.skipped_way_count += 1
return
oneway = _oneway_direction(tags, highway)
drive_speed_mps = _drive_speed_mps(tags, highway)
walk_speed_mps = STEP_WALK_SPEED_MPS if highway == "steps" else DEFAULT_WALK_SPEED_MPS
for left, right in zip(nodes, nodes[1:]):
source_id, source_lon, source_lat = left
target_id, target_lon, target_lat = right
if source_id == target_id:
continue
length_m = _distance_m(source_lat, source_lon, target_lat, target_lon)
if length_m <= 0:
continue
if oneway == "reverse":
source_id, target_id = target_id, source_id
source_lon, target_lon = target_lon, source_lon
source_lat, target_lat = target_lat, source_lat
walk_cost = length_m / walk_speed_mps if walkable else None
drive_cost = length_m / drive_speed_mps if drivable and drive_speed_mps > 0 else None
reverse_walk_cost = walk_cost
reverse_drive_cost = None if oneway in {"forward", "reverse"} else drive_cost
self.nodes[source_id] = {"dataset_id": self.dataset_id, "osm_node_id": source_id, "lon": source_lon, "lat": source_lat}
self.nodes[target_id] = {"dataset_id": self.dataset_id, "osm_node_id": target_id, "lon": target_lon, "lat": target_lat}
self.edges.append(
{
"dataset_id": self.dataset_id,
"osm_way_id": int(way.id),
"source_osm_node_id": source_id,
"target_osm_node_id": target_id,
"source_lon": source_lon,
"source_lat": source_lat,
"target_lon": target_lon,
"target_lat": target_lat,
"highway": highway,
"name": tags.get("name"),
"length_m": length_m,
"walk_cost_s": walk_cost,
"reverse_walk_cost_s": reverse_walk_cost,
"drive_cost_s": drive_cost,
"reverse_drive_cost_s": reverse_drive_cost,
"geometry_geojson": json.dumps({"type": "LineString", "coordinates": [[source_lon, source_lat], [target_lon, target_lat]]}, separators=(",", ":")),
"min_lon": min(source_lon, target_lon),
"min_lat": min(source_lat, target_lat),
"max_lon": max(source_lon, target_lon),
"max_lat": max(source_lat, target_lat),
"tags_json": _routing_tags_json(tags),
}
)
self.edge_count += 1
if walk_cost is not None:
self.walk_edge_count += 1
if drive_cost is not None:
self.drive_edge_count += 1
self.processed_way_count += 1
if len(self.edges) >= self.batch_size:
self.flush()
if self.processed_way_count % 100_000 == 0:
_emit(
self.progress_callback,
"routing_layer_import_batch",
f"Imported {self.edge_count:,} routing edges.",
self.edge_count,
None,
{"processed_ways": self.processed_way_count, "nodes_pending": len(self.nodes), "edges": self.edge_count},
)
def flush(self) -> None:
if not self.nodes and not self.edges:
return
node_rows = list(self.nodes.values())
edge_rows = self.edges
if node_rows:
stmt = postgresql_insert(RoutingNode).values(node_rows)
stmt = stmt.on_conflict_do_nothing(index_elements=["dataset_id", "osm_node_id"])
self.session.execute(stmt)
self.node_count += len(node_rows)
self.nodes.clear()
if edge_rows:
self.session.bulk_insert_mappings(RoutingEdge, edge_rows)
self.edges = []
self.session.commit()
def _walkable(tags: dict[str, str], highway: str) -> bool:
if highway not in WALK_HIGHWAYS:
return False
access = _tag_value(tags, "access")
foot = _tag_value(tags, "foot")
if foot in NO_VALUES:
return False
if access in NO_VALUES and foot not in YES_VALUES:
return False
if highway in {"motorway", "motorway_link", "trunk", "trunk_link"} and foot not in YES_VALUES:
return False
return True
def _drivable(tags: dict[str, str], highway: str) -> bool:
if highway not in DRIVE_HIGHWAYS:
return False
access = _tag_value(tags, "access")
motor_vehicle = _tag_value(tags, "motor_vehicle")
motorcar = _tag_value(tags, "motorcar")
vehicle = _tag_value(tags, "vehicle")
if motorcar in NO_VALUES or motor_vehicle in NO_VALUES or vehicle in NO_VALUES:
return False
if access in NO_VALUES and motorcar not in YES_VALUES and motor_vehicle not in YES_VALUES:
return False
if highway in {"footway", "path", "pedestrian", "steps", "cycleway", "bridleway"}:
return motorcar in YES_VALUES or motor_vehicle in YES_VALUES
return True
def _oneway_direction(tags: dict[str, str], highway: str) -> str:
oneway = _tag_value(tags, "oneway")
if oneway in ONEWAY_REVERSE:
return "reverse"
if oneway in ONEWAY_FORWARD or tags.get("junction") == "roundabout" or highway == "motorway":
return "forward"
return "both"
def _drive_speed_mps(tags: dict[str, str], highway: str) -> float:
maxspeed = _parse_maxspeed(tags.get("maxspeed"))
kmh = maxspeed or DEFAULT_DRIVE_SPEED_KMH.get(highway, 30)
return max(5.0, float(kmh) / 3.6)
def _parse_maxspeed(value: str | None) -> float | None:
if not value:
return None
text = value.strip().lower()
if text in {"signals", "none", "walk", "variable"}:
return None
if text.endswith("mph"):
number = _leading_float(text[:-3])
return None if number is None else number * 1.60934
return _leading_float(text)
def _leading_float(value: str) -> float | None:
digits = []
for char in value.strip():
if char.isdigit() or char == ".":
digits.append(char)
elif digits:
break
if not digits:
return None
try:
return float("".join(digits))
except ValueError:
return None
def _routing_tags_json(tags: dict[str, str]) -> str:
selected = {
key: value
for key, value in tags.items()
if key in {"access", "bicycle", "bridge", "foot", "highway", "junction", "maxspeed", "motor_vehicle", "motorcar", "name", "oneway", "service", "surface", "tunnel", "vehicle"}
}
return json.dumps(selected, separators=(",", ":"))
def _tag_value(tags: dict[str, str], key: str) -> str:
return str(tags.get(key) or "").strip().lower()
def _distance_m(lat_a: float, lon_a: float, lat_b: float, lon_b: float) -> float:
radius = 6_371_000.0
phi_a = math.radians(lat_a)
phi_b = math.radians(lat_b)
delta_phi = math.radians(lat_b - lat_a)
delta_lambda = math.radians(lon_b - lon_a)
hav = math.sin(delta_phi / 2) ** 2 + math.cos(phi_a) * math.cos(phi_b) * math.sin(delta_lambda / 2) ** 2
return radius * 2 * math.atan2(math.sqrt(hav), math.sqrt(1 - hav))
def _metadata(dataset: Dataset) -> dict[str, object]:
try:
value = json.loads(dataset.metadata_json or "{}")
except json.JSONDecodeError:
return {}
return value if isinstance(value, dict) else {}
def _emit(
progress_callback: ProgressCallback | None,
event_type: str,
message: str,
progress_current: int | None,
progress_total: int | None,
metadata: dict[str, object] | None = None,
) -> None:
if progress_callback is not None:
progress_callback(event_type, message, progress_current, progress_total, metadata)

40
app/pipeline/run.py Normal file
View File

@@ -0,0 +1,40 @@
from __future__ import annotations
from datetime import datetime, timezone
from typing import Callable, Any
from sqlalchemy.orm import Session
from app.models import Source
from app.pipeline.gtfs import run_gtfs_source
from app.pipeline.osm_diff import run_osm_diff_source
from app.pipeline.osm_geojson import run_osm_geojson_source
from app.pipeline.osm_pbf import run_osm_pbf_source
ProgressCallback = Callable[[str, str, int | None, int | None, dict[str, Any] | None], None]
def run_source(session: Session, source: Source, progress_callback: ProgressCallback | None = None):
source.status = "running"
source.last_run_at = datetime.now(timezone.utc)
source.last_error = None
session.flush()
try:
if source.kind == "gtfs":
dataset = run_gtfs_source(session, source, progress_callback=progress_callback)
elif source.kind == "osm_geojson":
dataset = run_osm_geojson_source(session, source)
elif source.kind == "osm_pbf":
dataset = run_osm_pbf_source(session, source, progress_callback=progress_callback)
elif source.kind == "osm_diff":
dataset = run_osm_diff_source(session, source)
else:
raise ValueError(f"Unsupported source kind: {source.kind}")
source.status = "ok"
source.last_error = None
return dataset
except Exception as exc: # noqa: BLE001 - persist pipeline error for UI
source.status = "error"
source.last_error = str(exc)
raise

294
app/pipeline/sample_data.py Normal file
View File

@@ -0,0 +1,294 @@
from __future__ import annotations
import csv
import io
import json
import zipfile
from pathlib import Path
from datetime import datetime, timezone
from sqlalchemy import delete, select
from sqlalchemy.orm import Session
from app.config import settings
from app.db import init_db
from app.models import (
Dataset,
CanonicalStop,
CanonicalStopLink,
GtfsAgency,
GtfsCalendar,
GtfsCalendarDate,
GtfsRoute,
GtfsRoutePatternLink,
GtfsShape,
GtfsStop,
GtfsStopTime,
GtfsTripRoutePatternLink,
GtfsTrip,
Itinerary,
ItineraryLeg,
Job,
JobEvent,
MatchRule,
OsmDiffState,
OsmFeature,
PipelineRun,
RouteMatch,
RoutePattern,
RoutePatternStop,
RoutingEdge,
RoutingNode,
Source,
SourceCatalogEntry,
SourceUpdateCheck,
TravelRequest,
)
from app.pipeline.matcher import run_route_matching
from app.pipeline.route_layer import rebuild_route_layer
from app.pipeline.run import run_source
def load_sample_project(session: Session, *, preserve_job_id: int | None = None) -> dict:
"""Clear the DB, create a small Berlin-like GTFS + OSM sample, import, and match."""
init_db()
clear_project_data(session, preserve_job_id=preserve_job_id, preserve_catalog=True)
sample_dir = settings.data_dir / "sample"
sample_dir.mkdir(parents=True, exist_ok=True)
gtfs_path = sample_dir / "sample_berlin.gtfs.zip"
osm_path = sample_dir / "sample_berlin_osm.geojson"
create_sample_gtfs(gtfs_path)
create_sample_osm_geojson(osm_path)
gtfs_source = Source(name="Sample Berlin GTFS", kind="gtfs", url=str(gtfs_path), country="DE", license="sample")
osm_source = Source(name="Sample Berlin OSM transport", kind="osm_geojson", url=str(osm_path), country="DE", license="sample")
session.add_all([gtfs_source, osm_source])
session.flush()
gtfs_dataset = run_source(session, gtfs_source)
osm_dataset = run_source(session, osm_source)
match_result = run_route_matching(session)
route_layer_result = rebuild_route_layer(session)
return {
"status": "ok",
"gtfs_dataset_id": gtfs_dataset.id,
"osm_dataset_id": osm_dataset.id,
"match_result": match_result,
"route_layer_result": route_layer_result,
}
def clear_project_data(
session: Session,
*,
preserve_job_id: int | None = None,
preserve_catalog: bool = True,
) -> None:
"""Clear user/project data while optionally preserving the current queue job."""
session.execute(delete(PipelineRun))
if preserve_job_id is None:
session.execute(delete(JobEvent))
session.execute(delete(Job))
else:
_cancel_other_jobs_for_reset(session, preserve_job_id)
for model in [
ItineraryLeg,
Itinerary,
TravelRequest,
SourceUpdateCheck,
OsmDiffState,
MatchRule,
RouteMatch,
GtfsTripRoutePatternLink,
GtfsRoutePatternLink,
RoutePatternStop,
RoutePattern,
CanonicalStopLink,
CanonicalStop,
RoutingEdge,
RoutingNode,
GtfsStopTime,
GtfsCalendarDate,
GtfsCalendar,
GtfsShape,
GtfsTrip,
GtfsRoute,
GtfsStop,
GtfsAgency,
OsmFeature,
Dataset,
Source,
]:
session.execute(delete(model))
if not preserve_catalog:
session.execute(delete(SourceCatalogEntry))
session.flush()
def _cancel_other_jobs_for_reset(session: Session, preserve_job_id: int) -> None:
now = datetime.now(timezone.utc)
active_statuses = {"queued", "running", "paused"}
jobs = session.scalars(
select(Job).where(Job.id != preserve_job_id, Job.status.in_(active_statuses))
).all()
for job in jobs:
job.status = "cancelled"
job.requested_action = None
job.lease_owner = None
job.lease_expires_at = None
job.paused_at = None
job.error = None
job.updated_at = now
job.finished_at = now
session.add(
JobEvent(
job_id=job.id,
event_type="cancelled_by_reset",
message=f"Job cancelled by reset job #{preserve_job_id}.",
progress_current=job.progress_current,
progress_total=job.progress_total,
)
)
def create_sample_gtfs(path: Path) -> None:
agencies = [
{"agency_id": "BVG", "agency_name": "BVG", "agency_url": "https://example.invalid/bvg", "agency_timezone": "Europe/Berlin"},
{"agency_id": "DB", "agency_name": "DB Regio", "agency_url": "https://example.invalid/db", "agency_timezone": "Europe/Berlin"},
{"agency_id": "XAIR", "agency_name": "Example Airport Shuttle", "agency_url": "https://example.invalid/xair", "agency_timezone": "Europe/Berlin"},
]
stops = [
{"stop_id": "hbf", "stop_name": "Berlin Hauptbahnhof", "stop_lat": "52.5251", "stop_lon": "13.3696"},
{"stop_id": "friedrich", "stop_name": "Friedrichstraße", "stop_lat": "52.5201", "stop_lon": "13.3862"},
{"stop_id": "alex", "stop_name": "Alexanderplatz", "stop_lat": "52.5219", "stop_lon": "13.4132"},
{"stop_id": "ost", "stop_name": "Ostbahnhof", "stop_lat": "52.5100", "stop_lon": "13.4344"},
{"stop_id": "zoo", "stop_name": "Zoologischer Garten", "stop_lat": "52.5069", "stop_lon": "13.3320"},
{"stop_id": "wittenberg", "stop_name": "Wittenbergplatz", "stop_lat": "52.5020", "stop_lon": "13.3430"},
{"stop_id": "potsdamer", "stop_name": "Potsdamer Platz", "stop_lat": "52.5096", "stop_lon": "13.3760"},
{"stop_id": "stadtmitte", "stop_name": "Stadtmitte", "stop_lat": "52.5113", "stop_lon": "13.3907"},
{"stop_id": "reichstag", "stop_name": "Reichstag", "stop_lat": "52.5186", "stop_lon": "13.3763"},
{"stop_id": "hackescher", "stop_name": "Hackescher Markt", "stop_lat": "52.5220", "stop_lon": "13.4023"},
{"stop_id": "naturkunde", "stop_name": "Naturkundemuseum", "stop_lat": "52.5300", "stop_lon": "13.3790"},
{"stop_id": "wannsee", "stop_name": "Wannsee", "stop_lat": "52.4210", "stop_lon": "13.1797"},
{"stop_id": "kladow", "stop_name": "Kladow", "stop_lat": "52.4547", "stop_lon": "13.1439"},
{"stop_id": "airport", "stop_name": "Example Airport Terminal", "stop_lat": "52.3650", "stop_lon": "13.5100"},
]
routes = [
{"route_id": "u2", "agency_id": "BVG", "route_short_name": "U2", "route_long_name": "Pankow - Ruhleben", "route_type": "1"},
{"route_id": "re1", "agency_id": "DB", "route_short_name": "RE1", "route_long_name": "Magdeburg - Frankfurt Oder", "route_type": "2"},
{"route_id": "m5", "agency_id": "BVG", "route_short_name": "M5", "route_long_name": "Hauptbahnhof - Hohenschönhausen", "route_type": "0"},
{"route_id": "bus100", "agency_id": "BVG", "route_short_name": "100", "route_long_name": "Zoologischer Garten - Alexanderplatz", "route_type": "3"},
{"route_id": "f10", "agency_id": "BVG", "route_short_name": "F10", "route_long_name": "Wannsee - Kladow", "route_type": "4"},
{"route_id": "x99", "agency_id": "XAIR", "route_short_name": "X99", "route_long_name": "Airport Express Sample", "route_type": "3"},
]
trips = [
{"route_id": r["route_id"], "service_id": "daily", "trip_id": f"{r['route_id']}_trip", "shape_id": f"{r['route_id']}_shape"}
for r in routes
]
stop_sequences = {
"u2_trip": ["zoo", "wittenberg", "potsdamer", "stadtmitte", "alex"],
"re1_trip": ["hbf", "friedrich", "alex", "ost"],
"m5_trip": ["hbf", "naturkunde", "hackescher", "alex"],
"bus100_trip": ["zoo", "reichstag", "alex"],
"f10_trip": ["wannsee", "kladow"],
"x99_trip": ["alex", "airport"],
}
coords = {row["stop_id"]: (row["stop_lon"], row["stop_lat"]) for row in stops}
stop_times = []
shapes = []
for trip in trips:
trip_id = trip["trip_id"]
for idx, stop_id in enumerate(stop_sequences[trip_id], start=1):
stop_times.append(
{
"trip_id": trip_id,
"arrival_time": f"08:{idx * 5:02d}:00",
"departure_time": f"08:{idx * 5 + 1:02d}:00",
"stop_id": stop_id,
"stop_sequence": str(idx),
}
)
lon, lat = coords[stop_id]
shapes.append(
{
"shape_id": trip["shape_id"],
"shape_pt_lat": lat,
"shape_pt_lon": lon,
"shape_pt_sequence": str(idx),
}
)
calendar = [
{
"service_id": "daily",
"monday": "1",
"tuesday": "1",
"wednesday": "1",
"thursday": "1",
"friday": "1",
"saturday": "1",
"sunday": "1",
"start_date": "20260101",
"end_date": "20261231",
}
]
with zipfile.ZipFile(path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
_write_csv(zf, "agency.txt", ["agency_id", "agency_name", "agency_url", "agency_timezone"], agencies)
_write_csv(zf, "stops.txt", ["stop_id", "stop_name", "stop_lat", "stop_lon"], stops)
_write_csv(zf, "routes.txt", ["route_id", "agency_id", "route_short_name", "route_long_name", "route_type"], routes)
_write_csv(zf, "trips.txt", ["route_id", "service_id", "trip_id", "shape_id"], trips)
_write_csv(zf, "stop_times.txt", ["trip_id", "arrival_time", "departure_time", "stop_id", "stop_sequence"], stop_times)
_write_csv(
zf,
"calendar.txt",
["service_id", "monday", "tuesday", "wednesday", "thursday", "friday", "saturday", "sunday", "start_date", "end_date"],
calendar,
)
_write_csv(zf, "shapes.txt", ["shape_id", "shape_pt_lat", "shape_pt_lon", "shape_pt_sequence"], shapes)
def _write_csv(zf: zipfile.ZipFile, name: str, fields: list[str], rows: list[dict[str, str]]) -> None:
buffer = io.StringIO(newline="")
writer = csv.DictWriter(buffer, fieldnames=fields)
writer.writeheader()
writer.writerows(rows)
zf.writestr(name, buffer.getvalue())
def create_sample_osm_geojson(path: Path) -> None:
def line(fid, mode, ref, name, operator, coords):
return {
"type": "Feature",
"geometry": {"type": "LineString", "coordinates": coords},
"properties": {
"osm_type": "relation",
"osm_id": str(fid),
"type": "route",
"route": mode,
"ref": ref,
"name": name,
"operator": operator,
"network": "VBB" if operator == "BVG" else "DB",
},
}
def point(fid, kind, name, coords, props=None):
props = props or {}
props.update({"osm_type": "node", "osm_id": str(fid), "name": name})
return {"type": "Feature", "geometry": {"type": "Point", "coordinates": coords}, "properties": props}
features = [
line(1002, "subway", "U2", "U2 Ruhleben - Pankow", "BVG", [[13.3320, 52.5069], [13.3430, 52.5020], [13.3760, 52.5096], [13.3907, 52.5113], [13.4132, 52.5219]]),
line(2001, "train", "RE1", "RE1 Magdeburg - Frankfurt Oder", "DB Regio", [[13.3696, 52.5251], [13.3862, 52.5201], [13.4132, 52.5219], [13.4344, 52.5100]]),
line(5005, "tram", "M5", "M5 Hauptbahnhof - Hohenschönhausen", "BVG", [[13.3696, 52.5251], [13.3790, 52.5300], [13.4023, 52.5220], [13.4132, 52.5219]]),
line(6100, "bus", "100", "Bus 100 Zoologischer Garten - Alexanderplatz", "BVG", [[13.3320, 52.5069], [13.3763, 52.5186], [13.4132, 52.5219]]),
line(7010, "ferry", "F10", "F10 Wannsee - Kladow", "BVG", [[13.1797, 52.4210], [13.1439, 52.4547]]),
line(5010, "tram", "M10", "M10 Warschauer Straße - Hauptbahnhof", "BVG", [[13.4500, 52.5050], [13.4020, 52.5300], [13.3696, 52.5251]]),
point(9001, "station", "Berlin Hauptbahnhof", [13.3696, 52.5251], {"railway": "station"}),
point(9002, "station", "Alexanderplatz", [13.4132, 52.5219], {"railway": "station"}),
point(9003, "stop", "Zoologischer Garten", [13.3320, 52.5069], {"public_transport": "station", "railway": "station"}),
point(9004, "terminal", "Wannsee Ferry Terminal", [13.1797, 52.4210], {"amenity": "ferry_terminal"}),
point(9005, "terminal", "Kladow Ferry Terminal", [13.1439, 52.4547], {"amenity": "ferry_terminal"}),
]
path.write_text(json.dumps({"type": "FeatureCollection", "features": features}, indent=2), encoding="utf-8")

135
app/pipeline/state.py Normal file
View File

@@ -0,0 +1,135 @@
from __future__ import annotations
from datetime import datetime, timezone
import hashlib
import json
from typing import Any
from sqlalchemy import select
from sqlalchemy.orm import Session
from app.models import PipelineRun
STAGE_ACQUIRE_RAW = "acquire_raw"
STAGE_FILTER_TRANSPORT = "filter_transport"
STAGE_EXTRACT_GEOMETRY = "extract_geometry"
STAGE_LABEL_FEATURES = "label_features"
STAGE_BUILD_INDEXES = "build_indexes"
STAGE_MATCH_ROUTES = "match_routes"
STAGE_BUILD_ROUTE_LAYER = "build_route_layer"
def stable_json(value: Any) -> str:
return json.dumps(value, sort_keys=True, separators=(",", ":"), default=str)
def dependency_hash(value: Any) -> str:
return hashlib.sha256(stable_json(value).encode("utf-8")).hexdigest()
def latest_completed_run(
session: Session,
*,
stage: str,
version: str,
dependency_hash_value: str,
source_id: int | None = None,
dataset_id: int | None = None,
) -> PipelineRun | None:
stmt = (
select(PipelineRun)
.where(
PipelineRun.stage == stage,
PipelineRun.version == version,
PipelineRun.dependency_hash == dependency_hash_value,
PipelineRun.status == "completed",
)
.order_by(PipelineRun.finished_at.desc(), PipelineRun.id.desc())
.limit(1)
)
if source_id is None:
stmt = stmt.where(PipelineRun.source_id.is_(None))
else:
stmt = stmt.where(PipelineRun.source_id == source_id)
if dataset_id is None:
stmt = stmt.where(PipelineRun.dataset_id.is_(None))
else:
stmt = stmt.where(PipelineRun.dataset_id == dataset_id)
return session.scalar(stmt)
def start_pipeline_run(
session: Session,
*,
stage: str,
version: str,
dependency_hash_value: str,
source_id: int | None = None,
dataset_id: int | None = None,
job_id: int | None = None,
inputs: dict[str, Any] | None = None,
) -> PipelineRun:
now = datetime.now(timezone.utc)
run = PipelineRun(
stage=stage,
version=version,
dependency_hash=dependency_hash_value,
status="running",
source_id=source_id,
dataset_id=dataset_id,
job_id=job_id,
input_json=None if inputs is None else stable_json(inputs),
started_at=now,
updated_at=now,
)
session.add(run)
session.flush()
return run
def finish_pipeline_run(
session: Session,
run: PipelineRun,
*,
status: str = "completed",
outputs: dict[str, Any] | None = None,
error: str | None = None,
) -> PipelineRun:
now = datetime.now(timezone.utc)
run.status = status
run.output_json = None if outputs is None else stable_json(outputs)
run.error = error
run.updated_at = now
run.finished_at = now
session.flush()
return run
def pipeline_run_payload(run: PipelineRun) -> dict[str, Any]:
return {
"id": run.id,
"stage": run.stage,
"version": run.version,
"dependency_hash": run.dependency_hash,
"status": run.status,
"source_id": run.source_id,
"dataset_id": run.dataset_id,
"job_id": run.job_id,
"input": _json_object(run.input_json),
"output": _json_object(run.output_json),
"error": run.error,
"started_at": run.started_at.isoformat() if run.started_at else None,
"updated_at": run.updated_at.isoformat() if run.updated_at else None,
"finished_at": run.finished_at.isoformat() if run.finished_at else None,
}
def _json_object(text: str | None) -> dict[str, Any]:
if not text:
return {}
try:
value = json.loads(text)
except json.JSONDecodeError:
return {}
return value if isinstance(value, dict) else {}

89
app/pipeline/utils.py Normal file
View File

@@ -0,0 +1,89 @@
from __future__ import annotations
import hashlib
import json
import re
from pathlib import Path
from typing import Iterable, Optional
from shapely.geometry import shape
def sha256_file(path: Path) -> str:
h = hashlib.sha256()
with path.open("rb") as f:
for chunk in iter(lambda: f.read(1024 * 1024), b""):
h.update(chunk)
return h.hexdigest()
def norm_text(value: object) -> str:
if value is None:
return ""
value = str(value).lower().strip()
value = value.replace("ß", "ss")
value = re.sub(r"[^a-z0-9]+", " ", value)
return re.sub(r"\s+", " ", value).strip()
def norm_ref(value: object) -> str:
if value is None:
return ""
return re.sub(r"[^a-z0-9]+", "", str(value).lower())
def first_nonempty(*values: object) -> str:
for value in values:
if value is None:
continue
text = str(value).strip()
if text:
return text
return ""
def geometry_json_and_bbox(geometry: object) -> tuple[Optional[str], tuple[Optional[float], Optional[float], Optional[float], Optional[float]]]:
if geometry is None:
return None, (None, None, None, None)
try:
geom = shape(geometry) if isinstance(geometry, dict) else geometry
if geom.is_empty:
return None, (None, None, None, None)
min_lon, min_lat, max_lon, max_lat = geom.bounds
return json.dumps(geom.__geo_interface__, separators=(",", ":")), (min_lon, min_lat, max_lon, max_lat)
except Exception:
return None, (None, None, None, None)
def bbox_overlap(a: tuple[float | None, float | None, float | None, float | None], b: tuple[float | None, float | None, float | None, float | None]) -> bool:
if any(v is None for v in (*a, *b)):
return False
aminx, aminy, amaxx, amaxy = a # type: ignore[misc]
bminx, bminy, bmaxx, bmaxy = b # type: ignore[misc]
return not (amaxx < bminx or bmaxx < aminx or amaxy < bminy or bmaxy < aminy)
def bbox_center(b: tuple[float | None, float | None, float | None, float | None]) -> Optional[tuple[float, float]]:
if any(v is None for v in b):
return None
minx, miny, maxx, maxy = b # type: ignore[misc]
return ((minx + maxx) / 2, (miny + maxy) / 2)
def approx_bbox_center_distance_deg(a: tuple[float | None, float | None, float | None, float | None], b: tuple[float | None, float | None, float | None, float | None]) -> Optional[float]:
ca = bbox_center(a)
cb = bbox_center(b)
if ca is None or cb is None:
return None
return ((ca[0] - cb[0]) ** 2 + (ca[1] - cb[1]) ** 2) ** 0.5
def batched(iterable: Iterable[dict], batch_size: int = 1000) -> Iterable[list[dict]]:
batch: list[dict] = []
for item in iterable:
batch.append(item)
if len(batch) >= batch_size:
yield batch
batch = []
if batch:
yield batch

393
app/qa.py Normal file
View File

@@ -0,0 +1,393 @@
from __future__ import annotations
import json
from datetime import datetime, timezone
from typing import Any
from sqlalchemy import func, select
from sqlalchemy.orm import Session
from app.gtfs_storage import missing_sidecar_paths as missing_gtfs_sidecar_paths
from app.models import (
CanonicalStop,
CanonicalStopLink,
Dataset,
GtfsAgency,
GtfsCalendar,
GtfsCalendarDate,
GtfsRoute,
GtfsShape,
GtfsStop,
GtfsTrip,
Job,
OsmFeature,
RouteMatch,
RoutePattern,
RoutePatternStop,
Source,
SourceCatalogEntry,
)
from app.osm_storage import missing_sidecar_paths as missing_osm_sidecar_paths
from app.pipeline.osm_addresses import ADDRESS_INDEX_VERSION
from app.pipeline.routing_layer import active_routing_dataset
def qa_summary(session: Session) -> dict[str, Any]:
active_gtfs_datasets = session.scalars(
select(Dataset).where(Dataset.kind == "gtfs", Dataset.is_active.is_(True)).order_by(Dataset.id)
).all()
active_osm_datasets = session.scalars(
select(Dataset).where(Dataset.kind == "osm_geojson", Dataset.is_active.is_(True)).order_by(Dataset.id)
).all()
active_gtfs_ids = [int(dataset.id) for dataset in active_gtfs_datasets]
active_osm_ids = [int(dataset.id) for dataset in active_osm_datasets]
source_catalog_total = _count(session, SourceCatalogEntry)
registered_sources = _count(session, Source)
linked_catalog_entries = int(
session.scalar(
select(func.count(func.distinct(Source.catalog_entry_id))).where(Source.catalog_entry_id.is_not(None))
)
or 0
)
priority_backlog = _priority_catalog_backlog(session)
failed_sources = int(
session.scalar(
select(func.count())
.select_from(Source)
.where((Source.last_error.is_not(None)) | Source.status.in_(["failed", "error"]))
)
or 0
)
active_jobs = _job_status_counts(session)
missing_gtfs_sidecars = sum(1 for dataset in active_gtfs_datasets if missing_gtfs_sidecar_paths(dataset))
missing_osm_sidecars = sum(1 for dataset in active_osm_datasets if missing_osm_sidecar_paths(dataset))
gtfs_counts = _gtfs_validation_counts(session, active_gtfs_ids)
link_counts = _link_quality_counts(session, active_gtfs_ids, active_osm_ids)
route_counts = _route_quality_counts(session, active_gtfs_ids)
address_status = _lightweight_address_index_status(session)
license_unknown = int(
session.scalar(
select(func.count())
.select_from(Source)
.where(Source.kind == "gtfs", (Source.license.is_(None)) | (func.lower(Source.license).in_(["", "unknown"])))
)
or 0
)
return {
"generated_at": datetime.now(timezone.utc).isoformat(),
"decision": {
"deployment": "same_workbench_for_now",
"database": "same_postgresql_database_for_now",
"split_trigger": "Split when third-party API, accounts/billing, heavy export jobs, or independent scaling are needed.",
"api_contract": "/api/qa/summary is intentionally display-ready but stable enough to become a harmonization-service summary endpoint.",
},
"sections": [
{
"id": "source_discovery",
"title": "Source Discovery",
"items": [
_item("Identified sources", source_catalog_total, "info", "Rows in the source catalog."),
_item("Registered sources", registered_sources, "info", "Sources known to the importer."),
_item("Catalog entries linked", linked_catalog_entries, "good" if linked_catalog_entries else "warn", "Catalog rows connected to importer sources."),
_item("Priority catalog backlog", priority_backlog, "warn" if priority_backlog else "good", "P0/P1 catalog rows without a registered source."),
],
},
{
"id": "import_health",
"title": "Import Health",
"items": [
_item("Active GTFS datasets", len(active_gtfs_ids), "good" if active_gtfs_ids else "warn", "Feeds currently participating in harmonization."),
_item("Active OSM datasets", len(active_osm_ids), "good" if active_osm_ids else "warn", "Visual/spatial datasets currently active."),
_item("Running jobs", active_jobs.get("running", 0), "warn" if active_jobs.get("running", 0) else "info", "Currently running queued work."),
_item("Queued jobs", active_jobs.get("queued", 0), "info", "Outstanding queued work."),
_item("Failed sources", failed_sources, "bad" if failed_sources else "good", "Sources with failed status or last_error."),
_item("Missing GTFS sidecars", missing_gtfs_sidecars, "bad" if missing_gtfs_sidecars else "good", "Active GTFS datasets whose sidecar is unavailable."),
_item("Missing OSM sidecars", missing_osm_sidecars, "bad" if missing_osm_sidecars else "good", "Active OSM datasets whose sidecar is unavailable."),
],
},
{
"id": "gtfs_validation",
"title": "GTFS Validation",
"items": [
_item("Agencies", gtfs_counts["agencies"], "info", "Imported agency.txt rows."),
_item("Stops", gtfs_counts["stops"], "info", "Imported stops."),
_item("Routes", gtfs_counts["routes"], "info", "Imported routes."),
_item("Trips", gtfs_counts["trips"], "info", "Imported trips."),
_item("Shapes", gtfs_counts["shapes"], "info", "Imported shape records."),
_item("Stops without coordinates", gtfs_counts["stops_without_coordinates"], "bad" if gtfs_counts["stops_without_coordinates"] else "good", "Stops that cannot be spatially linked or routed."),
_item("Routes without geometry", gtfs_counts["routes_without_geometry"], "warn" if gtfs_counts["routes_without_geometry"] else "good", "Routes with no stored GTFS shape geometry."),
_item("Routes without agency", gtfs_counts["routes_without_agency"], "warn" if gtfs_counts["routes_without_agency"] else "good", "Routes missing agency/operator references."),
_item("Calendar range", gtfs_counts["calendar_range"], "info", "Min/max imported service dates from calendars and exceptions."),
],
},
{
"id": "deduplication",
"title": "Deduplication and Stop Links",
"items": [
_item("Canonical stops", link_counts["canonical_stops"], "info", "Current normalized stop/station records."),
_item("GTFS stop links", link_counts["gtfs_stop_links"], "good" if link_counts["gtfs_stop_links"] else "warn", "Timetable stops linked into canonical stops."),
_item("GTFS stops without canonical link", link_counts["gtfs_stops_without_canonical"], "bad" if link_counts["gtfs_stops_without_canonical"] else "good", "Imported active stops that still need deduplication/linking."),
_item("OSM visual stop links", link_counts["osm_stop_links"], "good" if link_counts["osm_stop_links"] else "warn", "OSM stop/station features linked to canonical stops."),
_item("OSM stops without canonical link", link_counts["osm_stops_without_canonical"], "warn" if link_counts["osm_stops_without_canonical"] else "good", "Visual stops that are not yet linked to GTFS/canonical stops."),
_item("Multi-source stop groups", link_counts["multi_source_stop_groups"], "info", "Canonical stops that merge GTFS stops from multiple datasets."),
_item("Long-distance OSM links", link_counts["long_distance_osm_links"], "warn" if link_counts["long_distance_osm_links"] else "good", "OSM stop links over 150m from the canonical stop."),
],
},
{
"id": "route_quality",
"title": "Route Matching and Geometry",
"items": [
_item("Matched/accepted routes", route_counts["matched_or_accepted"], "good" if route_counts["matched_or_accepted"] else "warn", "GTFS routes with accepted or automatic OSM matches."),
_item("Probable matches", route_counts["probable"], "warn" if route_counts["probable"] else "info", "Potential conflicts needing review."),
_item("Weak matches", route_counts["weak"], "warn" if route_counts["weak"] else "good", "Low-confidence route links."),
_item("Missing route matches", route_counts["missing"], "bad" if route_counts["missing"] else "good", "Routes with no visual match."),
_item("Unreviewed GTFS routes", route_counts["routes_without_match"], "warn" if route_counts["routes_without_match"] else "good", "Active GTFS routes without a RouteMatch row."),
_item("Route patterns", route_counts["route_patterns"], "info", "Published visual route-layer patterns."),
_item("Route patterns without stops", route_counts["route_patterns_without_stops"], "warn" if route_counts["route_patterns_without_stops"] else "good", "Visual patterns missing canonical stop sequence evidence."),
],
},
{
"id": "publication_readiness",
"title": "Publication Readiness",
"items": [
_item("Address index stale", "yes" if address_status.get("stale") else "no", "warn" if address_status.get("stale") else "good", "Address polygons/search index version status."),
_item("GTFS licenses unknown", license_unknown, "warn" if license_unknown else "good", "GTFS sources without explicit redistribution/license status."),
_item("Canonical export", "draft", "warn", "Canonical Europe dataset export tables/API are not versioned yet."),
_item("Third-party API", "later", "info", "Accounts, billing, quotas, and API backend are intentionally out of scope for this step."),
],
},
],
"next_actions": [
"Add review queues for each non-zero bad/warn metric.",
"Persist source authority and redistribution policy before publishing third-party exports.",
"Create versioned canonical snapshots and export manifests.",
],
}
def _item(label: str, value: object, tone: str, description: str) -> dict[str, object]:
return {"label": label, "value": value, "tone": tone, "description": description}
def _lightweight_address_index_status(session: Session) -> dict[str, object]:
dataset = active_routing_dataset(session)
if dataset is None or not dataset.metadata_json:
return {"stale": False, "version": None, "current_version": ADDRESS_INDEX_VERSION}
try:
metadata = json.loads(dataset.metadata_json or "{}")
except json.JSONDecodeError:
metadata = {}
address_index = metadata.get("address_index") if isinstance(metadata, dict) else {}
if not isinstance(address_index, dict):
address_index = {}
version = address_index.get("version")
return {
"stale": bool(address_index and version != ADDRESS_INDEX_VERSION),
"version": version,
"current_version": ADDRESS_INDEX_VERSION,
}
def _count(session: Session, model, *where) -> int:
stmt = select(func.count()).select_from(model)
if where:
stmt = stmt.where(*where)
return int(session.scalar(stmt) or 0)
def _priority_catalog_backlog(session: Session) -> int:
linked = select(Source.id).where(Source.catalog_entry_id == SourceCatalogEntry.id).exists()
return int(
session.scalar(
select(func.count())
.select_from(SourceCatalogEntry)
.where(SourceCatalogEntry.priority.in_(["P0", "P0 fallback", "P1"]), ~linked)
)
or 0
)
def _job_status_counts(session: Session) -> dict[str, int]:
return {
str(status): int(count)
for status, count in session.execute(
select(Job.status, func.count())
.where(Job.dismissed_at.is_(None), Job.status.in_(["queued", "running", "paused", "failed"]))
.group_by(Job.status)
).all()
}
def _gtfs_validation_counts(session: Session, dataset_ids: list[int]) -> dict[str, object]:
if not dataset_ids:
return {
"agencies": 0,
"stops": 0,
"routes": 0,
"trips": 0,
"shapes": 0,
"stops_without_coordinates": 0,
"routes_without_geometry": 0,
"routes_without_agency": 0,
"calendar_range": "none",
}
calendar_min, calendar_max = session.execute(
select(func.min(GtfsCalendar.start_date), func.max(GtfsCalendar.end_date)).where(GtfsCalendar.dataset_id.in_(dataset_ids))
).one()
exception_min, exception_max = session.execute(
select(func.min(GtfsCalendarDate.date), func.max(GtfsCalendarDate.date)).where(GtfsCalendarDate.dataset_id.in_(dataset_ids))
).one()
min_date = min(value for value in [calendar_min, exception_min] if value is not None) if (calendar_min or exception_min) else None
max_date = max(value for value in [calendar_max, exception_max] if value is not None) if (calendar_max or exception_max) else None
return {
"agencies": _count(session, GtfsAgency, GtfsAgency.dataset_id.in_(dataset_ids)),
"stops": _count(session, GtfsStop, GtfsStop.dataset_id.in_(dataset_ids)),
"routes": _count(session, GtfsRoute, GtfsRoute.dataset_id.in_(dataset_ids)),
"trips": _count(session, GtfsTrip, GtfsTrip.dataset_id.in_(dataset_ids)),
"shapes": _count(session, GtfsShape, GtfsShape.dataset_id.in_(dataset_ids)),
"stops_without_coordinates": _count(
session,
GtfsStop,
GtfsStop.dataset_id.in_(dataset_ids),
(GtfsStop.lat.is_(None)) | (GtfsStop.lon.is_(None)),
),
"routes_without_geometry": _count(
session,
GtfsRoute,
GtfsRoute.dataset_id.in_(dataset_ids),
(GtfsRoute.geometry_geojson.is_(None)) | (GtfsRoute.geometry_geojson == ""),
),
"routes_without_agency": _count(
session,
GtfsRoute,
GtfsRoute.dataset_id.in_(dataset_ids),
(GtfsRoute.agency_id.is_(None)) | (GtfsRoute.agency_id == ""),
),
"calendar_range": f"{min_date or 'unknown'} -> {max_date or 'unknown'}",
}
def _link_quality_counts(session: Session, gtfs_dataset_ids: list[int], osm_dataset_ids: list[int]) -> dict[str, int]:
if gtfs_dataset_ids:
gtfs_link_exists = (
select(CanonicalStopLink.id)
.where(
CanonicalStopLink.object_type == "gtfs_stop",
CanonicalStopLink.dataset_id == GtfsStop.dataset_id,
CanonicalStopLink.object_id == GtfsStop.id,
)
.exists()
)
gtfs_stops_without_canonical = _count(
session,
GtfsStop,
GtfsStop.dataset_id.in_(gtfs_dataset_ids),
~gtfs_link_exists,
)
gtfs_stop_links = _count(
session,
CanonicalStopLink,
CanonicalStopLink.object_type == "gtfs_stop",
CanonicalStopLink.dataset_id.in_(gtfs_dataset_ids),
)
multi_source_subquery = (
select(CanonicalStopLink.canonical_stop_id)
.where(CanonicalStopLink.object_type == "gtfs_stop", CanonicalStopLink.dataset_id.in_(gtfs_dataset_ids))
.group_by(CanonicalStopLink.canonical_stop_id)
.having(func.count(func.distinct(CanonicalStopLink.dataset_id)) > 1)
.subquery()
)
multi_source_stop_groups = int(session.scalar(select(func.count()).select_from(multi_source_subquery)) or 0)
else:
gtfs_stops_without_canonical = 0
gtfs_stop_links = 0
multi_source_stop_groups = 0
if osm_dataset_ids:
osm_link_exists = (
select(CanonicalStopLink.id)
.where(
CanonicalStopLink.object_type == "osm_feature",
CanonicalStopLink.dataset_id == OsmFeature.dataset_id,
CanonicalStopLink.object_id == OsmFeature.id,
)
.exists()
)
osm_stops_without_canonical = _count(
session,
OsmFeature,
OsmFeature.dataset_id.in_(osm_dataset_ids),
OsmFeature.kind.in_(["stop", "station", "terminal"]),
~osm_link_exists,
)
osm_stop_links = _count(
session,
CanonicalStopLink,
CanonicalStopLink.object_type == "osm_feature",
CanonicalStopLink.dataset_id.in_(osm_dataset_ids),
)
long_distance_osm_links = _count(
session,
CanonicalStopLink,
CanonicalStopLink.object_type == "osm_feature",
CanonicalStopLink.dataset_id.in_(osm_dataset_ids),
CanonicalStopLink.distance_m > 150,
)
else:
osm_stops_without_canonical = 0
osm_stop_links = 0
long_distance_osm_links = 0
return {
"canonical_stops": _count(session, CanonicalStop),
"gtfs_stop_links": gtfs_stop_links,
"gtfs_stops_without_canonical": gtfs_stops_without_canonical,
"osm_stop_links": osm_stop_links,
"osm_stops_without_canonical": osm_stops_without_canonical,
"multi_source_stop_groups": multi_source_stop_groups,
"long_distance_osm_links": long_distance_osm_links,
}
def _route_quality_counts(session: Session, gtfs_dataset_ids: list[int]) -> dict[str, int]:
route_patterns = _count(session, RoutePattern)
route_pattern_stop_exists = (
select(RoutePatternStop.id)
.where(RoutePatternStop.route_pattern_id == RoutePattern.id)
.exists()
)
route_patterns_without_stops = _count(session, RoutePattern, ~route_pattern_stop_exists)
if not gtfs_dataset_ids:
return {
"matched_or_accepted": 0,
"probable": 0,
"weak": 0,
"missing": 0,
"routes_without_match": 0,
"route_patterns": route_patterns,
"route_patterns_without_stops": route_patterns_without_stops,
}
match_rows = {
str(status): int(count)
for status, count in session.execute(
select(RouteMatch.status, func.count())
.join(GtfsRoute, GtfsRoute.id == RouteMatch.gtfs_route_id)
.where(GtfsRoute.dataset_id.in_(gtfs_dataset_ids))
.group_by(RouteMatch.status)
).all()
}
match_exists = select(RouteMatch.id).where(RouteMatch.gtfs_route_id == GtfsRoute.id).exists()
routes_without_match = _count(session, GtfsRoute, GtfsRoute.dataset_id.in_(gtfs_dataset_ids), ~match_exists)
return {
"matched_or_accepted": match_rows.get("matched", 0) + match_rows.get("accepted", 0),
"probable": match_rows.get("probable", 0),
"weak": match_rows.get("weak", 0),
"missing": match_rows.get("missing", 0),
"routes_without_match": routes_without_match,
"route_patterns": route_patterns,
"route_patterns_without_stops": route_patterns_without_stops,
}

911
app/routing.py Normal file
View File

@@ -0,0 +1,911 @@
from __future__ import annotations
import copy
import heapq
import json
import math
import threading
import time
from collections import OrderedDict
from dataclasses import dataclass
from sqlalchemy import func, select, text
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Session
from app.config import settings
from app.models import Dataset, RoutingEdge, RoutingNode
from app.pipeline.routing_layer import active_routing_dataset
from app.serializers import feature_collection
WALK_HEURISTIC_MPS = 1.6
DRIVE_HEURISTIC_MPS = 36.0
DEFAULT_MAX_VISITED = 160_000
PGR_WALK_BBOX_PADDING_KM = [0.5, 1.5, 4, 10, 25]
PGR_DRIVE_BBOX_PADDING_KM = [2, 8, 25, 75, 200]
PGR_WALK_STATEMENT_TIMEOUT_MS = 2_500
PGR_DRIVE_STATEMENT_TIMEOUT_MS = 7_500
ROUTE_CACHE_TTL_SECONDS = 15 * 60
ROUTE_CACHE_MAX_ENTRIES = 512
_route_cache_lock = threading.RLock()
_route_cache: OrderedDict[tuple[object, ...], tuple[float, dict[str, object]]] = OrderedDict()
@dataclass(frozen=True)
class _GraphNode:
osm_node_id: int
lon: float
lat: float
distance_m: float
@dataclass(frozen=True)
class _Traversal:
edge_id: int
from_node: int
to_node: int
from_lon: float
from_lat: float
to_lon: float
to_lat: float
cost_s: float
length_m: float
highway: str | None
name: str | None
geometry_geojson: str
reversed: bool
def routing_status(db: Session) -> dict[str, object]:
dataset = active_routing_dataset(db)
dataset_id = None if dataset is None else int(dataset.id)
node_count = 0
edge_count = 0
if dataset_id is not None:
node_count, edge_count = _routing_status_counts(db, dataset, dataset_id)
pgrouting_available = False
pgrouting_installed = False
if settings.is_postgresql_database:
pgrouting_available = bool(
db.execute(text("SELECT EXISTS (SELECT 1 FROM pg_available_extensions WHERE name = 'pgrouting')")).scalar()
)
pgrouting_installed = bool(
db.execute(text("SELECT EXISTS (SELECT 1 FROM pg_extension WHERE extname = 'pgrouting')")).scalar()
)
return {
"dataset_id": dataset_id,
"nodes": node_count,
"edges": edge_count,
"available": edge_count > 0,
"engine": "pgrouting" if pgrouting_installed else "python_astar",
"pgrouting_available": pgrouting_available,
"pgrouting_installed": pgrouting_installed,
}
def _routing_status_counts(db: Session, dataset: Dataset, dataset_id: int) -> tuple[int, int]:
metadata = _metadata(dataset)
routing_layer = metadata.get("routing_layer")
if isinstance(routing_layer, dict):
try:
nodes = int(routing_layer.get("nodes") or 0)
edges = int(routing_layer.get("edges") or 0)
except (TypeError, ValueError):
nodes = 0
edges = 0
if nodes or edges:
return nodes, edges
if settings.is_postgresql_database:
rows = db.execute(
text(
"""
SELECT relname, COALESCE(reltuples, 0)::bigint AS estimate
FROM pg_class
WHERE oid IN ('routing_nodes'::regclass, 'routing_edges'::regclass)
"""
)
).mappings()
estimates = {str(row["relname"]): int(row["estimate"] or 0) for row in rows}
return estimates.get("routing_nodes", 0), estimates.get("routing_edges", 0)
node_count = int(db.scalar(select(func.count()).select_from(RoutingNode).where(RoutingNode.dataset_id == dataset_id)) or 0)
edge_count = int(db.scalar(select(func.count()).select_from(RoutingEdge).where(RoutingEdge.dataset_id == dataset_id)) or 0)
return node_count, edge_count
def _metadata(dataset: Dataset) -> dict[str, object]:
if not dataset.metadata_json:
return {}
try:
value = json.loads(dataset.metadata_json)
except json.JSONDecodeError:
return {}
return value if isinstance(value, dict) else {}
def route_between_points(
db: Session,
*,
from_lon: float,
from_lat: float,
to_lon: float,
to_lat: float,
mode: str = "walk",
dataset_id: int | None = None,
max_visited: int = DEFAULT_MAX_VISITED,
) -> dict[str, object]:
if mode not in {"walk", "drive"}:
raise ValueError("mode must be walk or drive")
dataset = db.get(Dataset, dataset_id) if dataset_id is not None else active_routing_dataset(db)
if dataset is None:
raise ValueError("No routing dataset is available.")
dataset_id = int(dataset.id)
cache_key = _route_cache_key(dataset_id, mode, from_lon, from_lat, to_lon, to_lat)
cached = _route_cache_get(cache_key)
if cached is not None:
return cached
start = _nearest_node(db, dataset_id, from_lon, from_lat, mode)
target = _nearest_node(db, dataset_id, to_lon, to_lat, mode)
if start is None or target is None:
raise ValueError("Routing graph has no nearby nodes for the requested mode.")
if start.osm_node_id == target.osm_node_id:
payload = _single_point_route(start, from_lon, from_lat, to_lon, to_lat, mode, dataset_id)
_route_cache_put(cache_key, payload)
return payload
if settings.is_postgresql_database and _pgrouting_installed(db):
try:
payload = _route_with_pgrouting(
db,
dataset_id=dataset_id,
mode=mode,
start=start,
target=target,
from_lon=from_lon,
from_lat=from_lat,
to_lon=to_lon,
to_lat=to_lat,
)
_route_cache_put(cache_key, payload)
return payload
except ValueError:
pass
except SQLAlchemyError:
db.rollback()
heuristic_mps = WALK_HEURISTIC_MPS if mode == "walk" else DRIVE_HEURISTIC_MPS
queue: list[tuple[float, float, int]] = []
heapq.heappush(queue, (0.0, 0.0, start.osm_node_id))
costs: dict[int, float] = {start.osm_node_id: 0.0}
coords: dict[int, tuple[float, float]] = {start.osm_node_id: (start.lon, start.lat), target.osm_node_id: (target.lon, target.lat)}
previous: dict[int, tuple[int, _Traversal]] = {}
adjacency_cache: dict[int, list[_Traversal]] = {}
visited: set[int] = set()
while queue and len(visited) < max(1, max_visited):
_, cost, node_id = heapq.heappop(queue)
if node_id in visited:
continue
visited.add(node_id)
if node_id == target.osm_node_id:
payload = _route_payload(
dataset_id=dataset_id,
mode=mode,
start=start,
target=target,
from_lon=from_lon,
from_lat=from_lat,
to_lon=to_lon,
to_lat=to_lat,
previous=previous,
total_cost_s=cost,
visited=len(visited),
)
_route_cache_put(cache_key, payload)
return payload
for edge in adjacency_cache.setdefault(node_id, _outgoing_edges(db, dataset_id, node_id, mode)):
coords[edge.to_node] = (edge.to_lon, edge.to_lat)
next_cost = cost + edge.cost_s
if next_cost >= costs.get(edge.to_node, float("inf")):
continue
costs[edge.to_node] = next_cost
previous[edge.to_node] = (node_id, edge)
heuristic = _distance_m(edge.to_lat, edge.to_lon, target.lat, target.lon) / heuristic_mps
heapq.heappush(queue, (next_cost + heuristic, next_cost, edge.to_node))
raise ValueError(f"No {mode} route found within {max_visited:,} visited graph nodes.")
def direct_route_between_points(
db: Session,
*,
from_lon: float,
from_lat: float,
to_lon: float,
to_lat: float,
mode: str = "walk",
dataset_id: int | None = None,
reason: str | None = None,
) -> dict[str, object]:
if mode not in {"walk", "drive"}:
raise ValueError("mode must be walk or drive")
dataset = db.get(Dataset, dataset_id) if dataset_id is not None else active_routing_dataset(db)
payload = _direct_route_payload(
dataset_id=0 if dataset is None else int(dataset.id),
mode=mode,
from_lon=float(from_lon),
from_lat=float(from_lat),
to_lon=float(to_lon),
to_lat=float(to_lat),
)
if reason:
payload["warning"] = reason
return payload
def snap_point_to_routing_graph(
db: Session,
*,
lon: float,
lat: float,
mode: str = "walk",
dataset_id: int | None = None,
max_distance_m: float = 250,
) -> dict[str, object] | None:
if mode not in {"walk", "drive"}:
raise ValueError("mode must be walk or drive")
dataset = db.get(Dataset, dataset_id) if dataset_id is not None else active_routing_dataset(db)
if dataset is None:
return None
dataset_id = int(dataset.id)
if settings.is_postgresql_database:
return _snap_point_to_routing_edge_postgresql(
db,
dataset_id=dataset_id,
lon=float(lon),
lat=float(lat),
mode=mode,
max_distance_m=float(max_distance_m),
)
node = _nearest_node(db, dataset_id, float(lon), float(lat), mode)
if node is None or node.distance_m > max_distance_m:
return None
return {
"dataset_id": dataset_id,
"lon": node.lon,
"lat": node.lat,
"distance_m": round(node.distance_m, 1),
"source": "routing_node",
"osm_node_id": node.osm_node_id,
}
def _snap_point_to_routing_edge_postgresql(
db: Session,
*,
dataset_id: int,
lon: float,
lat: float,
mode: str,
max_distance_m: float,
) -> dict[str, object] | None:
cost_column = "walk_cost_s" if mode == "walk" else "drive_cost_s"
reverse_cost_column = "reverse_walk_cost_s" if mode == "walk" else "reverse_drive_cost_s"
radius_deg = max_distance_m / 111_320
row = db.execute(
text(
f"""
WITH point AS (
SELECT ST_SetSRID(ST_MakePoint(:lon, :lat), 4326) AS geom
),
edges AS MATERIALIZED (
SELECT
edge.id,
edge.highway,
edge.name,
CASE
WHEN edge.tags_json IS NULL OR edge.tags_json = '' THEN NULL
ELSE edge.tags_json::jsonb ->> 'service'
END AS service,
edge.source_osm_node_id,
edge.target_osm_node_id,
ST_SetSRID(
ST_MakeLine(
ST_MakePoint(edge.source_lon, edge.source_lat),
ST_MakePoint(edge.target_lon, edge.target_lat)
),
4326
) AS edge_geom
FROM routing_edges AS edge
CROSS JOIN point
WHERE edge.dataset_id = :dataset_id
AND (edge.{cost_column} IS NOT NULL OR edge.{reverse_cost_column} IS NOT NULL)
AND box(point(edge.max_lon, edge.max_lat), point(edge.min_lon, edge.min_lat))
&& box(
point(:lon + :radius_deg, :lat + :radius_deg),
point(:lon - :radius_deg, :lat - :radius_deg)
)
),
candidate AS (
SELECT
edges.id,
edges.highway,
edges.name,
edges.service,
edges.source_osm_node_id,
edges.target_osm_node_id,
ST_ClosestPoint(edges.edge_geom, point.geom) AS snapped_geom,
ST_DistanceSphere(edges.edge_geom, point.geom) AS distance_m,
CASE
WHEN edges.highway IN ('footway', 'pedestrian', 'steps') THEN 0
WHEN edges.highway IN ('path', 'cycleway', 'bridleway') THEN 1
WHEN edges.highway IN ('living_street', 'residential') THEN 2
WHEN edges.highway = 'service' THEN 3
ELSE 4
END AS highway_rank,
CASE
WHEN :mode != 'walk' THEN 0
WHEN edges.highway = 'service' THEN 20
WHEN edges.highway IN ('primary', 'primary_link', 'secondary', 'secondary_link') THEN 10
WHEN edges.highway IN ('tertiary', 'tertiary_link', 'unclassified', 'road') THEN 5
ELSE 0
END AS snap_penalty_m
FROM edges
CROSS JOIN point
WHERE ST_DWithin(edges.edge_geom::geography, point.geom::geography, :max_distance_m)
AND NOT (
:mode = 'walk'
AND edges.highway = 'service'
AND COALESCE(edges.service, '') IN ('driveway', 'parking_aisle', 'drive-through')
)
ORDER BY
ST_DistanceSphere(edges.edge_geom, point.geom) + CASE
WHEN :mode != 'walk' THEN 0
WHEN edges.highway = 'service' THEN 20
WHEN edges.highway IN ('primary', 'primary_link', 'secondary', 'secondary_link') THEN 10
WHEN edges.highway IN ('tertiary', 'tertiary_link', 'unclassified', 'road') THEN 5
ELSE 0
END,
ST_DistanceSphere(edges.edge_geom, point.geom),
highway_rank,
edges.id
LIMIT 1
)
SELECT
id,
highway,
name,
source_osm_node_id,
target_osm_node_id,
ST_X(snapped_geom) AS lon,
ST_Y(snapped_geom) AS lat,
distance_m
FROM candidate
"""
),
{
"dataset_id": dataset_id,
"lon": lon,
"lat": lat,
"radius_deg": radius_deg,
"max_distance_m": max_distance_m,
"mode": mode,
},
).mappings().first()
if row is None:
return None
return {
"dataset_id": dataset_id,
"lon": float(row["lon"]),
"lat": float(row["lat"]),
"distance_m": round(float(row["distance_m"] or 0), 1),
"source": "routing_edge",
"edge_id": int(row["id"]),
"highway": row["highway"],
"name": row["name"],
"source_osm_node_id": int(row["source_osm_node_id"]),
"target_osm_node_id": int(row["target_osm_node_id"]),
}
def _route_cache_key(dataset_id: int, mode: str, from_lon: float, from_lat: float, to_lon: float, to_lat: float) -> tuple[object, ...]:
return (
int(dataset_id),
mode,
round(float(from_lon), 6),
round(float(from_lat), 6),
round(float(to_lon), 6),
round(float(to_lat), 6),
)
def _route_cache_get(key: tuple[object, ...]) -> dict[str, object] | None:
now = time.monotonic()
with _route_cache_lock:
cached = _route_cache.get(key)
if cached is None:
return None
expires_at, payload = cached
if expires_at <= now:
_route_cache.pop(key, None)
return None
_route_cache.move_to_end(key)
return copy.deepcopy(payload)
def _route_cache_put(key: tuple[object, ...], payload: dict[str, object]) -> None:
with _route_cache_lock:
_route_cache[key] = (time.monotonic() + ROUTE_CACHE_TTL_SECONDS, copy.deepcopy(payload))
_route_cache.move_to_end(key)
while len(_route_cache) > ROUTE_CACHE_MAX_ENTRIES:
_route_cache.popitem(last=False)
def _pgrouting_installed(db: Session) -> bool:
return bool(db.execute(text("SELECT EXISTS (SELECT 1 FROM pg_extension WHERE extname = 'pgrouting')")).scalar())
def _route_with_pgrouting(
db: Session,
*,
dataset_id: int,
mode: str,
start: _GraphNode,
target: _GraphNode,
from_lon: float,
from_lat: float,
to_lon: float,
to_lat: float,
) -> dict[str, object]:
cost_column = "walk_cost_s" if mode == "walk" else "drive_cost_s"
reverse_cost_column = "reverse_walk_cost_s" if mode == "walk" else "reverse_drive_cost_s"
routing_cost = _routing_cost_expression(cost_column, mode)
reverse_routing_cost = _routing_cost_expression(reverse_cost_column, mode)
for padding_km in PGR_WALK_BBOX_PADDING_KM if mode == "walk" else PGR_DRIVE_BBOX_PADDING_KM:
_set_local_statement_timeout(
db,
PGR_WALK_STATEMENT_TIMEOUT_MS if mode == "walk" else PGR_DRIVE_STATEMENT_TIMEOUT_MS,
)
bbox = _expanded_bbox(
min(from_lon, to_lon, start.lon, target.lon),
min(from_lat, to_lat, start.lat, target.lat),
max(from_lon, to_lon, start.lon, target.lon),
max(from_lat, to_lat, start.lat, target.lat),
padding_km,
)
edge_sql = f"""
SELECT
id,
source_osm_node_id AS source,
target_osm_node_id AS target,
COALESCE({routing_cost}, -1)::float8 AS cost,
COALESCE({reverse_routing_cost}, -1)::float8 AS reverse_cost
FROM routing_edges
WHERE dataset_id = {int(dataset_id)}
AND ({cost_column} IS NOT NULL OR {reverse_cost_column} IS NOT NULL)
AND box(point(max_lon, max_lat), point(min_lon, min_lat))
&& box(point({bbox[2]:.8f}, {bbox[3]:.8f}), point({bbox[0]:.8f}, {bbox[1]:.8f}))
"""
rows = db.execute(
text(
f"""
WITH route AS (
SELECT *
FROM pgr_dijkstra(:edge_sql, :start_node, :target_node, directed := true)
),
steps AS (
SELECT
route.path_seq,
route.node AS from_node,
LEAD(route.node) OVER (ORDER BY route.path_seq) AS to_node,
route.edge,
route.cost
FROM route
)
SELECT
steps.path_seq,
steps.from_node,
steps.to_node,
steps.cost,
edge.id,
edge.source_osm_node_id,
edge.target_osm_node_id,
edge.source_lon,
edge.source_lat,
edge.target_lon,
edge.target_lat,
edge.length_m,
edge.highway,
edge.name,
edge.geometry_geojson,
CASE
WHEN steps.from_node = edge.source_osm_node_id THEN edge.{cost_column}
ELSE edge.{reverse_cost_column}
END AS actual_cost_s
FROM steps
JOIN routing_edges AS edge ON edge.id = steps.edge
WHERE steps.edge <> -1
ORDER BY steps.path_seq
"""
),
{"edge_sql": edge_sql, "start_node": start.osm_node_id, "target_node": target.osm_node_id},
).all()
if rows:
return _pgrouting_payload(
dataset_id=dataset_id,
mode=mode,
start=start,
target=target,
from_lon=from_lon,
from_lat=from_lat,
to_lon=to_lon,
to_lat=to_lat,
rows=rows,
padding_km=padding_km,
)
raise ValueError("pgRouting did not find a route in the bounded search area.")
def _set_local_statement_timeout(db: Session, timeout_ms: int) -> None:
db.execute(text("SELECT set_config('statement_timeout', :timeout, true)"), {"timeout": f"{int(timeout_ms)}ms"})
def _pgrouting_payload(
*,
dataset_id: int,
mode: str,
start: _GraphNode,
target: _GraphNode,
from_lon: float,
from_lat: float,
to_lon: float,
to_lat: float,
rows,
padding_km: float,
) -> dict[str, object]:
previous: dict[int, tuple[int, _Traversal]] = {}
total_cost = 0.0
for row in rows:
if row.to_node is None:
continue
from_node = int(row.from_node)
to_node = int(row.to_node)
source_node = int(row.source_osm_node_id)
target_node = int(row.target_osm_node_id)
actual_cost = float(row.actual_cost_s if row.actual_cost_s is not None else row.cost or 0)
reversed_edge = from_node == target_node and to_node == source_node
if reversed_edge:
from_lon_edge, from_lat_edge = float(row.target_lon), float(row.target_lat)
to_lon_edge, to_lat_edge = float(row.source_lon), float(row.source_lat)
else:
from_lon_edge, from_lat_edge = float(row.source_lon), float(row.source_lat)
to_lon_edge, to_lat_edge = float(row.target_lon), float(row.target_lat)
total_cost += actual_cost
previous[to_node] = (
from_node,
_Traversal(
edge_id=int(row.id),
from_node=from_node,
to_node=to_node,
from_lon=from_lon_edge,
from_lat=from_lat_edge,
to_lon=to_lon_edge,
to_lat=to_lat_edge,
cost_s=actual_cost,
length_m=float(row.length_m),
highway=row.highway,
name=row.name,
geometry_geojson=str(row.geometry_geojson),
reversed=reversed_edge,
),
)
payload = _route_payload(
dataset_id=dataset_id,
mode=mode,
start=start,
target=target,
from_lon=from_lon,
from_lat=from_lat,
to_lon=to_lon,
to_lat=to_lat,
previous=previous,
total_cost_s=total_cost,
visited=len(rows),
)
payload["engine"] = "pgrouting"
payload["bbox_padding_km"] = padding_km
return payload
def _routing_cost_expression(column: str, mode: str) -> str:
if mode != "walk":
return column
return f"""
CASE
WHEN {column} IS NULL THEN NULL
ELSE {column} * CASE
WHEN highway IN ('footway', 'pedestrian') THEN 0.70
WHEN highway = 'path' THEN 0.78
WHEN highway = 'steps' THEN 0.95
WHEN highway = 'cycleway' THEN 1.05
WHEN highway = 'bridleway' THEN 1.10
WHEN highway IN ('living_street', 'track') THEN 1.15
WHEN highway IN ('residential', 'service') THEN 1.35
WHEN highway IN ('unclassified', 'road') THEN 1.55
WHEN highway IN ('tertiary', 'tertiary_link') THEN 1.80
WHEN highway IN ('secondary', 'secondary_link') THEN 2.15
WHEN highway IN ('primary', 'primary_link') THEN 2.50
ELSE 1.30
END
END
"""
def _nearest_node(db: Session, dataset_id: int, lon: float, lat: float, mode: str) -> _GraphNode | None:
cost_column = "walk_cost_s" if mode == "walk" else "drive_cost_s"
reverse_cost_column = "reverse_walk_cost_s" if mode == "walk" else "reverse_drive_cost_s"
row = None
for candidate_limit in (64, 512, 4096):
row = db.execute(
text(
f"""
WITH nearest AS MATERIALIZED (
SELECT node.osm_node_id, node.lon, node.lat, node.geom
FROM routing_nodes AS node
WHERE node.dataset_id = :dataset_id
AND node.geom IS NOT NULL
ORDER BY node.geom <-> ST_SetSRID(ST_MakePoint(:lon, :lat), 4326)
LIMIT :candidate_limit
),
candidate AS (
SELECT nearest.osm_node_id, nearest.lon, nearest.lat, nearest.geom
FROM nearest
WHERE EXISTS (
SELECT 1
FROM routing_edges AS edge
WHERE edge.dataset_id = :dataset_id
AND (
(edge.source_osm_node_id = nearest.osm_node_id AND edge.{cost_column} IS NOT NULL)
OR (edge.target_osm_node_id = nearest.osm_node_id AND edge.{reverse_cost_column} IS NOT NULL)
)
LIMIT 1
)
ORDER BY nearest.geom <-> ST_SetSRID(ST_MakePoint(:lon, :lat), 4326)
LIMIT 1
)
SELECT osm_node_id, lon, lat, ST_DistanceSphere(geom, ST_SetSRID(ST_MakePoint(:lon, :lat), 4326)) AS distance_m
FROM candidate
"""
),
{"dataset_id": dataset_id, "lon": lon, "lat": lat, "candidate_limit": candidate_limit},
).first()
if row is not None:
break
if row is None:
return None
return _GraphNode(osm_node_id=int(row.osm_node_id), lon=float(row.lon), lat=float(row.lat), distance_m=float(row.distance_m or 0))
def _outgoing_edges(db: Session, dataset_id: int, node_id: int, mode: str) -> list[_Traversal]:
cost_column = "walk_cost_s" if mode == "walk" else "drive_cost_s"
reverse_cost_column = "reverse_walk_cost_s" if mode == "walk" else "reverse_drive_cost_s"
rows = db.execute(
text(
f"""
SELECT
id, source_osm_node_id, target_osm_node_id,
source_lon, source_lat, target_lon, target_lat,
length_m, highway, name, geometry_geojson,
CASE
WHEN source_osm_node_id = :node_id THEN {cost_column}
ELSE {reverse_cost_column}
END AS cost_s,
target_osm_node_id != :node_id AS forward
FROM routing_edges
WHERE dataset_id = :dataset_id
AND (
(source_osm_node_id = :node_id AND {cost_column} IS NOT NULL)
OR (target_osm_node_id = :node_id AND {reverse_cost_column} IS NOT NULL)
)
"""
),
{"dataset_id": dataset_id, "node_id": node_id},
).all()
edges = []
for row in rows:
forward = bool(row.forward)
if forward:
to_node = int(row.target_osm_node_id)
from_lon, from_lat = float(row.source_lon), float(row.source_lat)
to_lon, to_lat = float(row.target_lon), float(row.target_lat)
else:
to_node = int(row.source_osm_node_id)
from_lon, from_lat = float(row.target_lon), float(row.target_lat)
to_lon, to_lat = float(row.source_lon), float(row.source_lat)
edges.append(
_Traversal(
edge_id=int(row.id),
from_node=node_id,
to_node=to_node,
from_lon=from_lon,
from_lat=from_lat,
to_lon=to_lon,
to_lat=to_lat,
cost_s=float(row.cost_s),
length_m=float(row.length_m),
highway=row.highway,
name=row.name,
geometry_geojson=str(row.geometry_geojson),
reversed=not forward,
)
)
return edges
def _route_payload(
*,
dataset_id: int,
mode: str,
start: _GraphNode,
target: _GraphNode,
from_lon: float,
from_lat: float,
to_lon: float,
to_lat: float,
previous: dict[int, tuple[int, _Traversal]],
total_cost_s: float,
visited: int,
) -> dict[str, object]:
edges: list[_Traversal] = []
current = target.osm_node_id
while current != start.osm_node_id:
prior, edge = previous[current]
edges.append(edge)
current = prior
edges.reverse()
network_distance = sum(edge.length_m for edge in edges)
access_distance = start.distance_m + target.distance_m
features = []
if start.distance_m:
features.append(_connector_feature("access", mode, [[from_lon, from_lat], [start.lon, start.lat]], start.distance_m))
for index, edge in enumerate(edges, start=1):
geometry = json.loads(edge.geometry_geojson)
if edge.reversed:
geometry["coordinates"] = list(reversed(geometry.get("coordinates", [])))
features.append(
{
"type": "Feature",
"geometry": geometry,
"properties": {
"feature_type": "routing_edge",
"sequence": index,
"mode": mode,
"edge_id": edge.edge_id,
"highway": edge.highway,
"name": edge.name,
"length_m": edge.length_m,
"cost_s": edge.cost_s,
},
}
)
if target.distance_m:
features.append(_connector_feature("egress", mode, [[target.lon, target.lat], [to_lon, to_lat]], target.distance_m))
duration_seconds = total_cost_s + _connector_seconds(access_distance, mode)
return {
"dataset_id": dataset_id,
"mode": mode,
"engine": "python_astar",
"distance_m": round(network_distance + access_distance, 1),
"network_distance_m": round(network_distance, 1),
"access_distance_m": round(access_distance, 1),
"duration_seconds": round(duration_seconds, 1),
"duration_minutes": _duration_minutes_ceil(duration_seconds),
"duration_label": _duration_label(duration_seconds),
"visited_nodes": visited,
"start_node": {"osm_node_id": start.osm_node_id, "distance_m": round(start.distance_m, 1)},
"target_node": {"osm_node_id": target.osm_node_id, "distance_m": round(target.distance_m, 1)},
"features": feature_collection(features),
}
def _single_point_route(start: _GraphNode, from_lon: float, from_lat: float, to_lon: float, to_lat: float, mode: str, dataset_id: int) -> dict[str, object]:
return _direct_route_payload(
dataset_id=dataset_id,
mode=mode,
from_lon=from_lon,
from_lat=from_lat,
to_lon=to_lon,
to_lat=to_lat,
engine="python_astar",
start_node={"osm_node_id": start.osm_node_id, "distance_m": round(start.distance_m, 1)},
target_node={"osm_node_id": start.osm_node_id, "distance_m": round(start.distance_m, 1)},
visited_nodes=1,
)
def _direct_route_payload(
*,
dataset_id: int,
mode: str,
from_lon: float,
from_lat: float,
to_lon: float,
to_lat: float,
engine: str = "direct_fallback",
start_node: dict[str, object] | None = None,
target_node: dict[str, object] | None = None,
visited_nodes: int = 0,
) -> dict[str, object]:
distance = _distance_m(from_lat, from_lon, to_lat, to_lon)
duration_seconds = _connector_seconds(distance, mode)
return {
"dataset_id": dataset_id,
"mode": mode,
"engine": engine,
"distance_m": round(distance, 1),
"network_distance_m": 0,
"access_distance_m": round(distance, 1),
"duration_seconds": round(duration_seconds, 1),
"duration_minutes": _duration_minutes_ceil(duration_seconds),
"duration_label": _duration_label(duration_seconds),
"visited_nodes": visited_nodes,
"start_node": start_node,
"target_node": target_node,
"features": feature_collection([_connector_feature("direct", mode, [[from_lon, from_lat], [to_lon, to_lat]], distance)]),
}
def _connector_feature(kind: str, mode: str, coordinates: list[list[float]], distance_m: float) -> dict:
return {
"type": "Feature",
"geometry": {"type": "LineString", "coordinates": coordinates},
"properties": {
"feature_type": "routing_connector",
"connector": kind,
"mode": mode,
"length_m": distance_m,
"cost_s": _connector_seconds(distance_m, mode),
},
}
def _connector_seconds(distance_m: float, mode: str) -> float:
speed = 1.35 if mode == "walk" else 8.0
return float(distance_m) / speed
def _duration_minutes_ceil(seconds: int | float | None) -> int | None:
if seconds is None:
return None
return max(0, int(math.ceil(float(seconds) / 60)))
def _duration_label(seconds: int | float | None) -> str | None:
minutes_total = _duration_minutes_ceil(seconds)
if minutes_total is None:
return None
days = minutes_total // (24 * 60)
remaining = minutes_total % (24 * 60)
hours = remaining // 60
minutes = remaining % 60
if days:
return f"{days}d {hours:02d}:{minutes:02d}"
if hours:
return f"{hours}:{minutes:02d}"
return f"{minutes} min"
def _expanded_bbox(min_lon: float, min_lat: float, max_lon: float, max_lat: float, padding_km: float) -> tuple[float, float, float, float]:
mid_lat = (min_lat + max_lat) / 2
lat_delta = padding_km / 111.0
lon_delta = padding_km / max(1.0, 111.0 * math.cos(math.radians(mid_lat)))
return (min_lon - lon_delta, min_lat - lat_delta, max_lon + lon_delta, max_lat + lat_delta)
def _distance_m(lat_a: float, lon_a: float, lat_b: float, lon_b: float) -> float:
radius = 6_371_000.0
phi_a = math.radians(lat_a)
phi_b = math.radians(lat_b)
delta_phi = math.radians(lat_b - lat_a)
delta_lambda = math.radians(lon_b - lon_a)
hav = math.sin(delta_phi / 2) ** 2 + math.cos(phi_a) * math.cos(phi_b) * math.sin(delta_lambda / 2) ** 2
return radius * 2 * math.atan2(math.sqrt(hav), math.sqrt(1 - hav))

130
app/serializers.py Normal file
View File

@@ -0,0 +1,130 @@
from __future__ import annotations
import json
from typing import Any, Iterable
from app.models import GtfsRoute, GtfsStop, OsmFeature, RouteMatch, RoutePattern
from app.osm_storage import osm_feature_public_id
def feature_collection(features: Iterable[dict[str, Any]]) -> dict[str, Any]:
return {"type": "FeatureCollection", "features": list(features)}
def gtfs_route_feature(route: GtfsRoute, extra: dict[str, Any] | None = None) -> dict[str, Any] | None:
if not route.geometry_geojson:
return None
props = {
"id": route.id,
"dataset_id": route.dataset_id,
"route_id": route.route_id,
"mode": route.mode,
"route_scope": route.route_scope,
"ref": route.short_name,
"name": route.long_name,
"operator": route.operator_name,
"source": "gtfs",
}
if extra:
props.update(extra)
return {"type": "Feature", "geometry": json.loads(route.geometry_geojson), "properties": props}
def osm_feature_feature(feature: OsmFeature, extra: dict[str, Any] | None = None) -> dict[str, Any] | None:
if not feature.geometry_geojson:
return None
props = {
"id": osm_feature_public_id(feature),
"row_id": feature.id,
"dataset_id": feature.dataset_id,
"osm_type": feature.osm_type,
"osm_id": feature.osm_id,
"kind": feature.kind,
"mode": feature.mode,
"route_scope": feature.route_scope,
"ref": feature.ref,
"name": feature.name,
"operator": feature.operator,
"network": feature.network,
"source": "osm",
}
if extra:
props.update(extra)
return {"type": "Feature", "geometry": json.loads(feature.geometry_geojson), "properties": props}
def route_pattern_feature(pattern: RoutePattern, extra: dict[str, Any] | None = None) -> dict[str, Any] | None:
if not pattern.geometry_geojson:
return None
props = {
"id": pattern.id,
"route_pattern_id": pattern.id,
"route_ref": pattern.route_ref,
"ref": pattern.route_ref,
"name": pattern.route_name,
"mode": pattern.mode,
"route_scope": pattern.route_scope,
"operator": pattern.operator_name,
"source": "route_layer",
"source_kind": pattern.source_kind,
"status": pattern.status,
"confidence": pattern.confidence,
"osm_feature_id": pattern.osm_feature_id,
"gtfs_route_id": pattern.gtfs_route_id,
"gtfs_shape_id": pattern.gtfs_shape_id,
}
if extra:
props.update(extra)
return {"type": "Feature", "geometry": json.loads(pattern.geometry_geojson), "properties": props}
def gtfs_stop_feature(stop: GtfsStop) -> dict[str, Any] | None:
if stop.lon is None or stop.lat is None:
return None
return {
"type": "Feature",
"geometry": {"type": "Point", "coordinates": [stop.lon, stop.lat]},
"properties": {
"id": stop.id,
"dataset_id": stop.dataset_id,
"stop_id": stop.stop_id,
"name": stop.name,
"source": "gtfs",
},
}
def match_row(match: RouteMatch) -> dict[str, Any]:
route = match.gtfs_route
feature = match.osm_feature
return {
"id": match.id,
"status": match.status,
"confidence": match.confidence,
"rule_source": match.rule_source,
"gtfs": {
"id": route.id,
"dataset_id": route.dataset_id,
"route_id": route.route_id,
"mode": route.mode,
"route_scope": route.route_scope,
"ref": route.short_name,
"name": route.long_name,
"operator": route.operator_name,
},
"osm": None
if feature is None
else {
"id": feature.id,
"dataset_id": feature.dataset_id,
"osm_type": feature.osm_type,
"osm_id": feature.osm_id,
"mode": feature.mode,
"route_scope": feature.route_scope,
"ref": feature.ref,
"name": feature.name,
"operator": feature.operator,
"network": feature.network,
},
"reasons": json.loads(match.reasons_json or "{}"),
}

309
app/source_catalog.py Normal file
View File

@@ -0,0 +1,309 @@
from __future__ import annotations
import csv
import hashlib
from datetime import datetime, timezone
from pathlib import Path
from typing import Iterable
from sqlalchemy import func, or_, select
from sqlalchemy.orm import Session
from app.models import Source, SourceCatalogEntry
DIRECT_INGEST_KINDS = {"gtfs", "osm_geojson", "osm_pbf"}
def default_source_catalog_path() -> Path:
return Path(__file__).resolve().parents[1] / "docs" / "source_catalog_seed.csv"
def default_ingestable_sources_path() -> Path:
return Path(__file__).resolve().parents[1] / "docs" / "ingestable_sources_seed.csv"
def import_source_catalog(session: Session, path: Path | str | None = None, *, update_existing: bool = True) -> dict[str, int]:
csv_path = _resolve_path(path, default_source_catalog_path())
rows = _read_csv(csv_path)
created = 0
updated = 0
skipped = 0
for row in rows:
source_name = _value(row, "Source name")
if not source_name:
skipped += 1
continue
payload = {
"catalog_key": _catalog_key(row),
"geography": _value(row, "Geography"),
"country_code": _value(row, "Country code"),
"mode_scope": _value(row, "Mode scope"),
"source_name": source_name,
"source_category": _value(row, "Source category"),
"formats_apis": _value(row, "Formats / APIs"),
"availability": _value(row, "Availability"),
"coverage_notes": _value(row, "Coverage notes"),
"geometry_notes": _value(row, "Supersedes OSM for"),
"disruptions_closures": _value(row, "Disruptions / closures"),
"operator_list_use": _value(row, "Operator-list use"),
"access_license_notes": _value(row, "Access / licence notes"),
"priority": _value(row, "Priority"),
"source_url": _value(row, "Source URL"),
"evidence_url": _value(row, "Evidence URL"),
"next_pipeline_action": _value(row, "Next pipeline action"),
}
existing = session.scalar(select(SourceCatalogEntry).where(SourceCatalogEntry.catalog_key == payload["catalog_key"]))
if existing is None:
session.add(SourceCatalogEntry(**payload))
created += 1
continue
if not update_existing:
skipped += 1
continue
for key, value in payload.items():
setattr(existing, key, value)
existing.updated_at = datetime.now(timezone.utc)
updated += 1
session.flush()
return {"created": created, "updated": updated, "skipped": skipped}
def import_ingestable_sources(
session: Session,
path: Path | str | None = None,
*,
update_existing: bool = True,
) -> dict[str, int]:
csv_path = _resolve_path(path, default_ingestable_sources_path())
rows = _read_csv(csv_path)
created = 0
updated = 0
skipped = 0
linked_catalog = 0
for row in rows:
name = _value(row, "name")
kind = (_value(row, "kind") or "").lower()
url = _value(row, "url")
if not name or not url or kind not in DIRECT_INGEST_KINDS:
skipped += 1
continue
catalog_entry = _catalog_entry_for_ingestable_row(session, row)
payload = {
"name": name,
"kind": kind,
"url": url,
"country": _value(row, "country"),
"license": _value(row, "license"),
"priority": _value(row, "priority"),
"mode_scope": _value(row, "mode_scope"),
"source_basis": _value(row, "source_basis"),
"notes": _value(row, "notes"),
"catalog_entry_id": None if catalog_entry is None else catalog_entry.id,
}
existing = session.scalar(
select(Source)
.where(Source.kind == kind, Source.url == url)
.order_by(Source.id)
.limit(1)
)
if existing is None:
existing = session.scalar(select(Source).where(Source.name == name, Source.url == url).order_by(Source.id).limit(1))
if existing is None:
session.add(Source(**payload))
created += 1
if catalog_entry is not None:
linked_catalog += 1
continue
if not update_existing:
skipped += 1
continue
for key, value in payload.items():
setattr(existing, key, value)
existing.enabled = True
updated += 1
if catalog_entry is not None:
linked_catalog += 1
session.flush()
return {"created": created, "updated": updated, "skipped": skipped, "linked_catalog": linked_catalog}
def source_catalog_summary(session: Session) -> dict[str, object]:
priority_counts = {
priority or "unknown": count
for priority, count in session.execute(
select(SourceCatalogEntry.priority, func.count()).group_by(SourceCatalogEntry.priority)
).all()
}
status_counts = {
status or "unknown": count
for status, count in session.execute(select(SourceCatalogEntry.status, func.count()).group_by(SourceCatalogEntry.status)).all()
}
ingestable_sources = session.scalar(
select(func.count()).select_from(Source).where(Source.source_basis.is_not(None) | Source.priority.is_not(None))
) or 0
return {
"catalog_entries": session.scalar(select(func.count()).select_from(SourceCatalogEntry)) or 0,
"catalog_by_priority": priority_counts,
"catalog_by_status": status_counts,
"seeded_ingestable_sources": ingestable_sources,
}
def source_catalog_rows(
session: Session,
*,
q: str | None = None,
country: str | None = None,
priority: str | None = None,
status: str | None = None,
limit: int = 100,
) -> list[SourceCatalogEntry]:
stmt = select(SourceCatalogEntry).order_by(
SourceCatalogEntry.priority,
SourceCatalogEntry.country_code,
SourceCatalogEntry.source_name,
SourceCatalogEntry.id,
)
if q:
pattern = f"%{q.strip()}%"
stmt = stmt.where(
or_(
SourceCatalogEntry.source_name.ilike(pattern),
SourceCatalogEntry.source_category.ilike(pattern),
SourceCatalogEntry.formats_apis.ilike(pattern),
SourceCatalogEntry.coverage_notes.ilike(pattern),
SourceCatalogEntry.next_pipeline_action.ilike(pattern),
)
)
if country:
stmt = stmt.where(SourceCatalogEntry.country_code.ilike(f"%{country.strip()}%"))
if priority:
stmt = stmt.where(SourceCatalogEntry.priority == priority.strip())
if status:
stmt = stmt.where(SourceCatalogEntry.status == status.strip())
return session.scalars(stmt.limit(max(1, min(limit, 500)))).all()
def catalog_entry_payload(entry: SourceCatalogEntry, *, linked_source_count: int = 0) -> dict[str, object]:
return {
"id": entry.id,
"geography": entry.geography,
"country_code": entry.country_code,
"mode_scope": entry.mode_scope,
"source_name": entry.source_name,
"source_category": entry.source_category,
"formats_apis": entry.formats_apis,
"availability": entry.availability,
"coverage_notes": entry.coverage_notes,
"geometry_notes": entry.geometry_notes,
"disruptions_closures": entry.disruptions_closures,
"operator_list_use": entry.operator_list_use,
"access_license_notes": entry.access_license_notes,
"priority": entry.priority,
"source_url": entry.source_url,
"evidence_url": entry.evidence_url,
"next_pipeline_action": entry.next_pipeline_action,
"status": entry.status,
"linked_source_count": linked_source_count,
"created_at": entry.created_at.isoformat() if entry.created_at else None,
"updated_at": entry.updated_at.isoformat() if entry.updated_at else None,
}
def linked_source_counts(session: Session, entries: Iterable[SourceCatalogEntry]) -> dict[int, int]:
entry_ids = [entry.id for entry in entries]
if not entry_ids:
return {}
return {
entry_id: count
for entry_id, count in session.execute(
select(Source.catalog_entry_id, func.count())
.where(Source.catalog_entry_id.in_(entry_ids))
.group_by(Source.catalog_entry_id)
).all()
if entry_id is not None
}
def _catalog_entry_for_ingestable_row(session: Session, row: dict[str, str]) -> SourceCatalogEntry | None:
country = _value(row, "country")
source_basis = _value(row, "source_basis")
name = _value(row, "name")
if not country and not source_basis and not name:
return None
if name:
exact = session.scalar(
select(SourceCatalogEntry)
.where(func.lower(SourceCatalogEntry.source_name) == name.lower())
.order_by(SourceCatalogEntry.id)
.limit(1)
)
if exact is not None:
return exact
clauses = []
if country:
clauses.append(SourceCatalogEntry.country_code.ilike(f"%{country}%"))
if source_basis:
for token in _basis_tokens(source_basis):
clauses.append(SourceCatalogEntry.source_name.ilike(f"%{token}%"))
clauses.append(SourceCatalogEntry.coverage_notes.ilike(f"%{token}%"))
if name:
first_word = name.split()[0]
if len(first_word) > 2:
clauses.append(SourceCatalogEntry.source_name.ilike(f"%{first_word}%"))
if not clauses:
return None
return session.scalar(
select(SourceCatalogEntry)
.where(or_(*clauses))
.order_by(SourceCatalogEntry.priority, SourceCatalogEntry.id)
.limit(1)
)
def _basis_tokens(value: str) -> list[str]:
tokens = []
for raw in value.replace("/", " ").replace("-", " ").split():
token = raw.strip(" ,.;()")
if len(token) >= 5 and token.lower() not in {"official", "mirror", "feeds", "transport"}:
tokens.append(token)
return tokens[:4]
def _catalog_key(row: dict[str, str]) -> str:
parts = [
_value(row, "Country code"),
_value(row, "Source name"),
_value(row, "Source URL"),
_value(row, "Formats / APIs"),
]
text = "|".join(part.lower() for part in parts if part)
if not text:
text = repr(sorted(row.items()))
return hashlib.sha256(text.encode("utf-8")).hexdigest()
def _read_csv(path: Path) -> list[dict[str, str]]:
if not path.exists():
raise FileNotFoundError(path)
with path.open("r", encoding="utf-8-sig", newline="") as handle:
reader = csv.DictReader(handle)
return [dict(row) for row in reader]
def _resolve_path(path: Path | str | None, default_path: Path) -> Path:
if path is None:
return default_path
candidate = Path(path)
if candidate.is_absolute():
return candidate
return Path.cwd() / candidate
def _value(row: dict[str, str], key: str) -> str | None:
value = row.get(key)
if value is None:
return None
stripped = value.strip()
return stripped or None

256
app/source_updates.py Normal file
View File

@@ -0,0 +1,256 @@
from __future__ import annotations
import json
from datetime import datetime, timezone
from pathlib import Path
from urllib.parse import urlparse
import requests
from sqlalchemy import select
from sqlalchemy.orm import Session
from app.config import settings
from app.models import Dataset, Source, SourceUpdateCheck
from app.pipeline.utils import norm_text, sha256_file
def check_source_for_update(session: Session, source: Source) -> SourceUpdateCheck:
active_dataset = session.scalar(
select(Dataset)
.where(Dataset.source_id == source.id, Dataset.is_active.is_(True))
.order_by(Dataset.created_at.desc(), Dataset.id.desc())
)
recovery = _recover_missing_managed_cache_url(source)
remote = _source_remote_metadata(source)
if recovery is not None:
remote["recovered_source_url"] = recovery["url"]
remote["previous_source_url"] = recovery["previous_url"]
update_available, reason = _update_decision(active_dataset, remote)
check = SourceUpdateCheck(
source_id=source.id,
status=remote["status"],
update_available=update_available,
reason=reason,
remote_url=source.url,
etag=remote.get("etag"),
last_modified=remote.get("last_modified"),
content_length=remote.get("content_length"),
content_type=remote.get("content_type"),
local_mtime=remote.get("local_mtime"),
local_size=remote.get("local_size"),
local_sha256=remote.get("local_sha256"),
active_dataset_id=None if active_dataset is None else active_dataset.id,
active_dataset_sha256=None if active_dataset is None else active_dataset.sha256,
metadata_json=json.dumps(remote, separators=(",", ":"), default=_json_default),
)
session.add(check)
source.status = "update_check_error" if remote["status"] != "checked" else "update_available" if update_available else "up_to_date"
source.last_error = None if remote["status"] == "checked" else reason
session.flush()
return check
def latest_source_update_check(session: Session, source_id: int) -> SourceUpdateCheck | None:
return session.scalar(
select(SourceUpdateCheck)
.where(SourceUpdateCheck.source_id == source_id)
.order_by(SourceUpdateCheck.checked_at.desc(), SourceUpdateCheck.id.desc())
)
def update_check_payload(check: SourceUpdateCheck | None) -> dict | None:
if check is None:
return None
try:
metadata = json.loads(check.metadata_json or "{}")
except json.JSONDecodeError:
metadata = {}
return {
"id": check.id,
"source_id": check.source_id,
"checked_at": check.checked_at.isoformat() if check.checked_at else None,
"status": check.status,
"update_available": check.update_available,
"reason": check.reason,
"etag": check.etag,
"last_modified": check.last_modified,
"content_length": check.content_length,
"content_type": check.content_type,
"local_mtime": check.local_mtime.isoformat() if check.local_mtime else None,
"local_size": check.local_size,
"local_sha256": check.local_sha256,
"active_dataset_id": check.active_dataset_id,
"active_dataset_sha256": check.active_dataset_sha256,
"metadata": metadata,
}
def record_dataset_update_metadata(dataset: Dataset, check: SourceUpdateCheck | None) -> None:
if check is None:
return
try:
metadata = json.loads(dataset.metadata_json or "{}")
except json.JSONDecodeError:
metadata = {}
metadata["source_update_check"] = {
"id": check.id,
"checked_at": check.checked_at.isoformat() if check.checked_at else None,
"etag": check.etag,
"last_modified": check.last_modified,
"content_length": check.content_length,
"content_type": check.content_type,
"local_mtime": check.local_mtime.isoformat() if check.local_mtime else None,
"local_size": check.local_size,
"local_sha256": check.local_sha256,
"metadata": update_check_payload(check).get("metadata", {}),
}
dataset.metadata_json = json.dumps(metadata, indent=2, default=_json_default)
def _source_remote_metadata(source: Source) -> dict:
parsed = urlparse(source.url)
if parsed.scheme in {"http", "https"}:
return _http_metadata(source.url)
path = Path(parsed.path) if parsed.scheme == "file" else Path(source.url)
return _local_metadata(path)
def _recover_missing_managed_cache_url(source: Source) -> dict | None:
parsed = urlparse(source.url)
if parsed.scheme in {"http", "https"}:
return None
path = Path(parsed.path) if parsed.scheme == "file" else Path(source.url)
if path.exists() or not _is_managed_source_cache_path(path, source.id):
return None
replacement = _seed_source_url_for(source)
if replacement is None:
return None
previous_url = source.url
source.url = replacement
return {"previous_url": previous_url, "url": replacement}
def _is_managed_source_cache_path(path: Path, source_id: int) -> bool:
source_dir = f"source_{source_id}"
try:
resolved = path.resolve()
managed_dir = (settings.data_dir / "sources" / source_dir).resolve()
resolved.relative_to(managed_dir)
return True
except ValueError:
pass
parts = path.parts
return any(part == "sources" and index + 1 < len(parts) and parts[index + 1] == source_dir for index, part in enumerate(parts))
def _seed_source_url_for(source: Source) -> str | None:
seed_path = Path(__file__).resolve().parents[1] / "scripts" / "example_sources.json"
if not seed_path.exists():
return None
try:
rows = json.loads(seed_path.read_text(encoding="utf-8"))
except (OSError, json.JSONDecodeError):
return None
source_tokens = set(norm_text(source.name).split())
for row in rows if isinstance(rows, list) else []:
if not isinstance(row, dict):
continue
url = str(row.get("url") or "")
if urlparse(url).scheme not in {"http", "https"}:
continue
if row.get("kind") != source.kind:
continue
if source.country and row.get("country") and str(row.get("country")) != source.country:
continue
row_tokens = set(norm_text(row.get("name")).split())
if row_tokens and (row_tokens <= source_tokens or source_tokens <= row_tokens):
return url
return None
def _http_metadata(url: str) -> dict:
response = None
try:
response = requests.head(url, allow_redirects=True, timeout=30)
if response.status_code in {405, 501}:
response.close()
response = requests.get(url, stream=True, timeout=30)
response.raise_for_status()
except Exception as exc: # noqa: BLE001 - persisted as update-check status
return {"status": "error", "error": str(exc)}
finally:
if response is not None:
response.close()
headers = response.headers
content_length = headers.get("Content-Length")
return {
"status": "checked",
"etag": headers.get("ETag"),
"last_modified": headers.get("Last-Modified"),
"content_length": int(content_length) if content_length and content_length.isdigit() else None,
"content_type": headers.get("Content-Type"),
"final_url": response.url,
"update_artifact": _update_artifact(url, headers.get("Content-Type")),
}
def _local_metadata(path: Path) -> dict:
if not path.exists():
return {"status": "error", "error": f"Source file does not exist: {path}"}
stat = path.stat()
return {
"status": "checked",
"local_mtime": datetime.fromtimestamp(stat.st_mtime, tz=timezone.utc),
"local_size": stat.st_size,
"local_sha256": sha256_file(path),
"update_artifact": _update_artifact(str(path), None),
}
def _update_decision(active_dataset: Dataset | None, remote: dict) -> tuple[bool, str]:
if remote["status"] != "checked":
return False, remote.get("error") or "update check failed"
if active_dataset is None:
return True, "no active dataset imported"
if remote.get("local_sha256"):
if remote["local_sha256"] == active_dataset.sha256:
return False, "local file hash matches active dataset"
return True, "local file hash differs from active dataset"
previous = _dataset_update_metadata(active_dataset)
comparable = []
for key in ("etag", "last_modified", "content_length"):
current = remote.get(key)
old = previous.get(key)
if current is not None and old is not None:
comparable.append(key)
if str(current) != str(old):
return True, f"remote {key} changed"
if comparable:
return False, "remote metadata matches active dataset"
return True, "no previous remote metadata recorded"
def _dataset_update_metadata(dataset: Dataset) -> dict:
try:
metadata = json.loads(dataset.metadata_json or "{}")
except json.JSONDecodeError:
return {}
return metadata.get("source_update_check") or {}
def _json_default(value):
if isinstance(value, datetime):
return value.isoformat()
raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")
def _update_artifact(url_or_path: str, content_type: str | None) -> dict:
lower = url_or_path.lower()
is_osm_diff = lower.endswith(".osc") or lower.endswith(".osc.gz")
is_gtfs_zip = lower.endswith(".zip") or (content_type or "").lower() in {"application/zip", "application/x-zip-compressed"}
return {
"kind": "osm_diff" if is_osm_diff else "gtfs_or_archive" if is_gtfs_zip else "full_snapshot",
"is_diff": is_osm_diff,
"content_type": content_type,
}

158
app/spatial.py Normal file
View File

@@ -0,0 +1,158 @@
from __future__ import annotations
from collections.abc import Iterable
from sqlalchemy import text
from sqlalchemy.orm import Session
from app.config import settings
POSTGIS_GEOMETRY_TABLES = {
"osm_features",
"gtfs_routes",
"gtfs_shapes",
"gtfs_stops",
"canonical_stops",
"route_patterns",
"osm_addresses",
"routing_nodes",
"routing_edges",
}
def using_postgresql() -> bool:
return settings.is_postgresql_database
def refresh_postgis_geometries(
session: Session,
*,
dataset_id: int | None = None,
tables: Iterable[str] | None = None,
only_missing: bool = True,
) -> None:
if not using_postgresql():
return
selected = set(tables or POSTGIS_GEOMETRY_TABLES)
unknown = selected - POSTGIS_GEOMETRY_TABLES
if unknown:
raise ValueError(f"Unsupported PostGIS geometry table(s): {', '.join(sorted(unknown))}")
if "osm_features" in selected:
_refresh_geojson_geometry(session, "osm_features", dataset_id=dataset_id, only_missing=only_missing)
if "gtfs_routes" in selected:
_refresh_geojson_geometry(session, "gtfs_routes", dataset_id=dataset_id, only_missing=only_missing)
if "gtfs_shapes" in selected:
_refresh_geojson_geometry(session, "gtfs_shapes", dataset_id=dataset_id, only_missing=only_missing)
if "route_patterns" in selected:
_refresh_geojson_geometry(session, "route_patterns", dataset_id=None, only_missing=only_missing)
if "osm_addresses" in selected:
_refresh_address_geometry(session, dataset_id=dataset_id, only_missing=only_missing)
if "gtfs_stops" in selected:
_refresh_point_geometry(session, "gtfs_stops", dataset_id=dataset_id, only_missing=only_missing)
if "canonical_stops" in selected:
_refresh_point_geometry(session, "canonical_stops", dataset_id=None, only_missing=only_missing)
if "routing_nodes" in selected:
_refresh_point_geometry(session, "routing_nodes", dataset_id=dataset_id, only_missing=only_missing)
if "routing_edges" in selected:
_refresh_routing_edge_geometry(session, dataset_id=dataset_id, only_missing=only_missing)
def analyze_postgresql_tables(session: Session, tables: Iterable[str]) -> None:
if not using_postgresql():
return
for table in tables:
session.execute(text(f"ANALYZE {table}"))
def _refresh_geojson_geometry(session: Session, table: str, *, dataset_id: int | None, only_missing: bool) -> None:
where = ["geometry_geojson IS NOT NULL", "geometry_geojson <> ''"]
params: dict[str, object] = {}
if dataset_id is not None:
where.append("dataset_id = :dataset_id")
params["dataset_id"] = int(dataset_id)
if only_missing:
where.append("geom IS NULL")
session.execute(
text(
f"""
UPDATE {table}
SET geom = ST_SetSRID(ST_GeomFromGeoJSON(geometry_geojson), 4326)
WHERE {" AND ".join(where)}
"""
),
params,
)
def _refresh_point_geometry(session: Session, table: str, *, dataset_id: int | None, only_missing: bool) -> None:
where = ["lon IS NOT NULL", "lat IS NOT NULL"]
params: dict[str, object] = {}
if dataset_id is not None:
where.append("dataset_id = :dataset_id")
params["dataset_id"] = int(dataset_id)
if only_missing:
where.append("geom IS NULL")
session.execute(
text(
f"""
UPDATE {table}
SET geom = ST_SetSRID(ST_MakePoint(lon, lat), 4326)
WHERE {" AND ".join(where)}
"""
),
params,
)
def _refresh_address_geometry(session: Session, *, dataset_id: int | None, only_missing: bool) -> None:
_refresh_point_geometry(session, "osm_addresses", dataset_id=dataset_id, only_missing=only_missing)
where = ["geometry_geojson IS NOT NULL", "geometry_geojson <> ''"]
params: dict[str, object] = {}
if dataset_id is not None:
where.append("dataset_id = :dataset_id")
params["dataset_id"] = int(dataset_id)
if only_missing:
where.append("area_geom IS NULL")
session.execute(
text(
f"""
UPDATE osm_addresses
SET area_geom = ST_SetSRID(ST_GeomFromGeoJSON(geometry_geojson), 4326)
WHERE {" AND ".join(where)}
"""
),
params,
)
def _refresh_routing_edge_geometry(session: Session, *, dataset_id: int | None, only_missing: bool) -> None:
where = [
"source_lon IS NOT NULL",
"source_lat IS NOT NULL",
"target_lon IS NOT NULL",
"target_lat IS NOT NULL",
]
params: dict[str, object] = {}
if dataset_id is not None:
where.append("dataset_id = :dataset_id")
params["dataset_id"] = int(dataset_id)
if only_missing:
where.append("geom IS NULL")
session.execute(
text(
f"""
UPDATE routing_edges
SET geom = ST_SetSRID(
ST_MakeLine(
ST_MakePoint(source_lon, source_lat),
ST_MakePoint(target_lon, target_lat)
),
4326
)
WHERE {" AND ".join(where)}
"""
),
params,
)

4090
app/static/app.js Normal file

File diff suppressed because it is too large Load Diff

1498
app/static/style.css Normal file

File diff suppressed because it is too large Load Diff

329
app/templates/index.html Normal file
View File

@@ -0,0 +1,329 @@
<!doctype html>
<html lang="en">
<head>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1" />
<title>Mobility Workbench</title>
<link rel="stylesheet" href="https://unpkg.com/leaflet@1.9.4/dist/leaflet.css" crossorigin="" />
<link rel="stylesheet" href="/static/style.css?v=20260701-harmonizer-module" />
</head>
<body>
<header>
<div>
<h1>Mobility Workbench</h1>
<p>Harmonized transit, mapping data, route layer, map review, and journey tests.</p>
</div>
<div class="actions">
<button id="refreshBtn">Refresh</button>
</div>
</header>
<main>
<aside>
<div class="sidebar-content">
<details class="card sidebar-section" data-sidebar-section="stats" open>
<summary><h2>Stats</h2></summary>
<div class="sidebar-section-body">
<div id="stats" class="stats"></div>
</div>
</details>
<details class="card sidebar-section" data-sidebar-section="qa" open>
<summary><h2>QA</h2></summary>
<div class="sidebar-section-body">
<div class="qa-toolbar">
<button type="button" id="refreshQaBtn">Refresh QA</button>
</div>
<div id="qaDashboard" class="qa-dashboard muted">No QA loaded.</div>
</div>
</details>
<details class="card sidebar-section" data-sidebar-section="jobs" open>
<summary><h2>Jobs</h2></summary>
<div class="sidebar-section-body">
<div id="jobs" class="jobs muted">No jobs loaded.</div>
</div>
</details>
<details class="card sidebar-section" data-sidebar-section="harmonization" open>
<summary><h2>GTFS Harmonization</h2></summary>
<div class="sidebar-section-body">
<details class="nested-section" data-sidebar-section="add-gtfs-source">
<summary><h3>Add GTFS source</h3></summary>
<div class="nested-section-body">
<form id="sourceForm">
<input name="catalog_entry_id" type="hidden" />
<input name="kind" type="hidden" value="gtfs" />
<label>Name <input name="name" required placeholder="DELFI / national GTFS" /></label>
<label>URL or path <input name="url" required placeholder="https://.../feed.zip or ./data/feed.zip" /></label>
<label>Country <input name="country" placeholder="DE" maxlength="8" /></label>
<label>License <input name="license" placeholder="ODbL / CC-BY / unknown" /></label>
<button type="submit">Add GTFS source</button>
</form>
</div>
</details>
<details class="nested-section source-catalog-card" data-sidebar-section="source-catalog">
<summary><h3>Transit source catalog</h3></summary>
<div class="nested-section-body">
<div id="sourceCatalogSummary" class="muted"></div>
<div class="filter-row source-catalog-filter">
<input id="sourceCatalogSearch" placeholder="Search catalog" />
<input id="sourceCatalogCountry" placeholder="Country" />
<select id="sourceCatalogPriority">
<option value="">all priorities</option>
<option value="P0">P0</option>
<option value="P0 fallback">P0 fallback</option>
<option value="P1">P1</option>
<option value="P2">P2</option>
<option value="P3">P3</option>
<option value="P4">P4</option>
<option value="P5">P5</option>
</select>
</div>
<div class="source-catalog-actions">
<button type="button" id="importSourceCatalogBtn">Import catalog</button>
<button type="button" id="importIngestableSourcesBtn">Import ingestable seeds</button>
</div>
<div id="sourceCatalog"></div>
</div>
</details>
<details class="nested-section" data-sidebar-section="gtfs-feed-qa" open>
<summary><h3>Feed QA</h3></summary>
<div class="nested-section-body">
<div class="qa-toolbar">
<button type="button" id="refreshGtfsHarmonizationBtn">Refresh feeds</button>
</div>
<div id="gtfsHarmonizationInventory" class="harmonization-inventory muted">No GTFS feed QA loaded.</div>
</div>
</details>
<details class="nested-section" data-sidebar-section="gtfs-source-management" open>
<summary><h3>GTFS source library</h3></summary>
<div class="nested-section-body">
<div class="filter-row">
<input id="sourceSearch" placeholder="Filter GTFS sources" />
</div>
<div id="sources"></div>
</div>
</details>
</div>
</details>
<details class="card sidebar-section" data-sidebar-section="mapping" open>
<summary><h2>Mapping Data</h2></summary>
<div class="sidebar-section-body">
<details class="nested-section" data-sidebar-section="add-map-source">
<summary><h3>Add map source</h3></summary>
<div class="nested-section-body">
<form id="mappingSourceForm">
<input name="catalog_entry_id" type="hidden" />
<label>Name <input name="name" required placeholder="Germany OSM PBF" /></label>
<label>Kind
<select name="kind">
<option value="osm_pbf">OSM PBF extract</option>
<option value="osm_geojson">OSM transport GeoJSON</option>
<option value="osm_diff">OSM change diff</option>
</select>
</label>
<label>URL or path <input name="url" required placeholder="https://.../latest.osm.pbf or ./data/routes.geojson" /></label>
<label>Country <input name="country" placeholder="DE" maxlength="8" /></label>
<label>License <input name="license" placeholder="ODbL / CC-BY / unknown" /></label>
<button type="submit">Add map source</button>
</form>
</div>
</details>
<details class="nested-section source-catalog-card" data-sidebar-section="geofabrik">
<summary><h3>Geofabrik OSM</h3></summary>
<div class="nested-section-body">
<div class="filter-row geofabrik-filter">
<input id="geofabrikSearch" placeholder="Berlin, Germany, Hamburg" />
<button type="button" id="geofabrikSearchBtn">Search</button>
</div>
<label class="inline-check"><input id="geofabrikDiffSource" type="checkbox" checked /> add diff source metadata</label>
<div id="geofabrikResults" class="dataset-search-results muted">Search Geofabrik extracts, then add or import one as an OSM PBF source.</div>
</div>
</details>
<details class="nested-section" data-sidebar-section="mapping-source-management" open>
<summary><h3>Map source library</h3></summary>
<div class="nested-section-body">
<div class="filter-row">
<input id="mappingSourceSearch" placeholder="Filter map sources" />
<select id="mappingSourceKindFilter">
<option value="">all map kinds</option>
<option value="osm_geojson">OSM GeoJSON</option>
<option value="osm_pbf">OSM PBF</option>
<option value="osm_diff">OSM diff</option>
</select>
</div>
<div id="mappingSources"></div>
</div>
</details>
</div>
</details>
<details class="card sidebar-section" data-sidebar-section="datasets" open>
<summary><h2>Datasets</h2></summary>
<div class="sidebar-section-body">
<details class="nested-section" data-sidebar-section="dataset-pipeline" open>
<summary><h3>Derivation pipeline</h3></summary>
<div class="nested-section-body">
<div class="workflow-actions">
<button id="runMatchBtn" type="button">Run matcher</button>
<button id="buildRouteLayerBtn" type="button">Build route layer</button>
<button id="loadSampleBtn" type="button">Reset sample</button>
</div>
</div>
</details>
<details class="nested-section" data-sidebar-section="dataset-search">
<summary><h3>Dataset search</h3></summary>
<div class="nested-section-body">
<form id="datasetSearchForm" class="dataset-search-form">
<input id="datasetSearchQuery" placeholder="Route, line, stop, shape ID" autocomplete="off" />
<div class="filter-row">
<label class="inline-check"><input id="datasetSearchActiveOnly" type="checkbox" checked /> active only</label>
<button type="submit">Search</button>
</div>
</form>
<div id="datasetSearchResults" class="dataset-search-results muted">Search all imported datasets by label, route ID, and route-layer reference.</div>
</div>
</details>
<details class="nested-section matches-card" data-sidebar-section="route-matches">
<summary><h3>Route matches</h3></summary>
<div class="nested-section-body">
<div class="filter-row">
<select id="matchStatusFilter">
<option value="">all</option>
<option value="matched">matched</option>
<option value="probable">probable</option>
<option value="weak">weak</option>
<option value="missing">missing</option>
<option value="accepted">accepted</option>
<option value="rejected">rejected</option>
</select>
<button id="reloadMatchesBtn">Reload</button>
</div>
<div id="matches"></div>
</div>
</details>
<details class="nested-section" data-sidebar-section="maintenance">
<summary><h3>Maintenance</h3></summary>
<div class="nested-section-body">
<div class="maintenance-grid">
<button type="button" data-admin-action="init-db">Init DB</button>
<button type="button" data-admin-action="backfill-gtfs-shapes">Backfill GTFS shapes</button>
<button type="button" data-admin-action="prune-cache-dry">Check cache</button>
<button type="button" data-admin-action="prune-cache">Prune cache</button>
<button type="button" data-admin-action="prune-inactive-dry">Check inactive</button>
<button type="button" data-admin-action="prune-inactive">Prune inactive</button>
<button type="button" data-admin-action="vacuum-db">Vacuum DB</button>
<button type="button" class="danger" data-admin-action="reset-db">Reset DB</button>
</div>
<div id="adminStatus" class="admin-status muted"></div>
</div>
</details>
</div>
</details>
<details class="card sidebar-section" data-sidebar-section="layers" open>
<summary><h2>Layers</h2></summary>
<div class="sidebar-section-body">
<div class="preset-row">
<button type="button" data-layer-preset="network">Network</button>
<button type="button" data-layer-preset="review">Matched/unmatched</button>
<button type="button" data-layer-preset="unmatched">Unmatched</button>
<button type="button" data-layer-preset="all">All</button>
</div>
<div id="layerControls" class="layer-controls"></div>
<div id="mapStatus" class="map-status muted"></div>
</div>
</details>
</div>
<button id="sidebarCollapseBtn" class="sidebar-collapse-handle" type="button" aria-label="Collapse left panel" title="Collapse left panel" aria-expanded="true"></button>
</aside>
<section class="map-panel">
<div id="map"></div>
<div id="mapLoading" class="map-loading" hidden>
<span class="spinner" aria-hidden="true"></span>
<span id="mapLoadingText">Loading map layers...</span>
</div>
<section class="map-floating journey-card">
<h2>Journey</h2>
<form id="journeyForm">
<div id="journeyTransitSnapshot" class="journey-snapshot muted">Transit snapshot loading...</div>
<label>From <input id="journeyFromQuery" placeholder="Hauptbahnhof" autocomplete="off" /></label>
<input id="journeyFromStop" type="hidden" />
<div id="journeyFromSuggestions" class="stop-suggestions"></div>
<button type="button" id="journeySwapBtn" class="journey-swap" title="Switch start and destination">Swap</button>
<label>To <input id="journeyToQuery" placeholder="Alexanderplatz" autocomplete="off" /></label>
<input id="journeyToStop" type="hidden" />
<div id="journeyToSuggestions" class="stop-suggestions"></div>
<label>Via <input id="journeyViaQuery" placeholder="optional stop" autocomplete="off" /></label>
<input id="journeyViaStop" type="hidden" />
<div id="journeyViaSuggestions" class="stop-suggestions"></div>
<div class="journey-mode" role="radiogroup" aria-label="Route mode">
<label><input type="radio" name="journeyMode" value="transit" checked /> Public transport</label>
<label><input type="radio" name="journeyMode" value="walk" /> Walk</label>
<label><input type="radio" name="journeyMode" value="drive" /> Car</label>
</div>
<div class="journey-options">
<label>Date <input id="journeyServiceDate" type="date" /></label>
<label>Departure <input id="journeyDeparture" type="time" value="08:00" /></label>
<label>Transfer buffer <input id="journeyTransferMinutes" type="number" min="0" max="60" step="1" value="2" /></label>
<label>Rank by
<select id="journeyRanking">
<option value="recommended">Recommended</option>
<option value="earliest_arrival">Earliest arrival</option>
<option value="duration">Shortest duration</option>
<option value="fewest_transfers">Fewest transfers</option>
</select>
</label>
</div>
<label class="journey-direct"><input id="journeyDirectOnly" type="checkbox" /> Direct public transport only</label>
<div class="journey-actions">
<button type="button" id="journeyEarlierBtn">Earlier</button>
<button type="submit" class="primary">Search</button>
<button type="button" id="journeyLaterBtn">Later</button>
</div>
<button type="button" id="generateItinerariesBtn">Generate travel options</button>
</form>
<div id="journeyResults" class="journey-results"></div>
<section class="itinerary-panel">
<div class="journey-title">
<span>Comparison</span>
<button type="button" id="reloadItinerariesBtn">Reload</button>
</div>
<div id="itineraryResults" class="itinerary-results muted">Generate travel options to compare route families.</div>
</section>
</section>
<div class="legend">
<span><b class="line osm"></b>OSM existing routes</span>
<span><b class="line gtfs"></b>GTFS covered routes</span>
<span><b class="line missing"></b>GTFS missing OSM match</span>
<span><b class="dot stops"></b>Stops / stations / terminals</span>
</div>
</section>
</main>
<div id="overlay" class="overlay" hidden>
<section class="overlay-panel">
<div class="overlay-title">
<h2 id="overlayTitle">Candidates</h2>
<button id="overlayCloseBtn">Close</button>
</div>
<div id="overlayContent"></div>
</section>
</div>
<script src="https://unpkg.com/leaflet@1.9.4/dist/leaflet.js" crossorigin=""></script>
<script src="/static/app.js?v=20260701-harmonizer-module"></script>
</body>
</html>

155
app/worker_supervisor.py Normal file
View File

@@ -0,0 +1,155 @@
from __future__ import annotations
import os
import signal
import subprocess
import sys
from dataclasses import dataclass
from pathlib import Path
from app.config import settings
@dataclass
class WorkerHandle:
index: int
worker_id: str
pid: int | None
status: str
pid_file: Path
log_file: Path
started_by_server: bool = False
_handles: list[WorkerHandle] = []
def start_queue_workers() -> list[WorkerHandle]:
if not settings.queue_worker_autostart:
return []
worker_count = max(0, int(settings.queue_worker_count))
handles: list[WorkerHandle] = []
worker_dir = settings.data_dir / "workers"
worker_dir.mkdir(parents=True, exist_ok=True)
for index in range(worker_count):
worker_id = f"server-worker-{index + 1}"
pid_file = worker_dir / f"{worker_id}.pid"
log_file = worker_dir / f"{worker_id}.log"
existing_pid = _read_pid(pid_file)
if existing_pid is not None and _pid_running(existing_pid):
handles.append(
WorkerHandle(
index=index,
worker_id=worker_id,
pid=existing_pid,
status="already_running",
pid_file=pid_file,
log_file=log_file,
)
)
continue
pid_file.unlink(missing_ok=True)
process = _spawn_worker(worker_id, log_file)
pid_file.write_text(str(process.pid), encoding="utf-8")
handles.append(
WorkerHandle(
index=index,
worker_id=worker_id,
pid=process.pid,
status="started",
pid_file=pid_file,
log_file=log_file,
started_by_server=True,
)
)
_handles[:] = handles
return list(_handles)
def stop_queue_workers() -> None:
if not settings.queue_worker_stop_on_shutdown:
return
for handle in list(_handles):
if not handle.started_by_server or handle.pid is None:
continue
_terminate_pid(handle.pid)
handle.pid_file.unlink(missing_ok=True)
def queue_worker_status() -> list[dict[str, object]]:
if not settings.queue_worker_autostart:
return []
worker_dir = settings.data_dir / "workers"
statuses: list[dict[str, object]] = []
configured_count = max(0, int(settings.queue_worker_count))
for index in range(configured_count):
worker_id = f"server-worker-{index + 1}"
pid_file = worker_dir / f"{worker_id}.pid"
log_file = worker_dir / f"{worker_id}.log"
pid = _read_pid(pid_file)
running = pid is not None and _pid_running(pid)
statuses.append(
{
"index": index,
"worker_id": worker_id,
"pid": pid,
"running": running,
"pid_file": str(pid_file),
"log_file": str(log_file),
}
)
return statuses
def _spawn_worker(worker_id: str, log_file: Path) -> subprocess.Popen:
root = Path(__file__).resolve().parents[1]
command = [
sys.executable,
"-m",
"app.cli",
"worker",
"--worker-id",
worker_id,
"--poll-interval",
str(settings.queue_worker_poll_interval_seconds),
]
env = os.environ.copy()
env["MOBILITY_SUPERVISED_WORKER"] = "1"
log_file.parent.mkdir(parents=True, exist_ok=True)
log_handle = log_file.open("ab", buffering=0)
try:
return subprocess.Popen(
command,
cwd=str(root),
env=env,
stdin=subprocess.DEVNULL,
stdout=log_handle,
stderr=subprocess.STDOUT,
start_new_session=True,
)
finally:
log_handle.close()
def _read_pid(path: Path) -> int | None:
try:
return int(path.read_text(encoding="utf-8").strip())
except (FileNotFoundError, ValueError, OSError):
return None
def _pid_running(pid: int) -> bool:
try:
os.kill(pid, 0)
except ProcessLookupError:
return False
except PermissionError:
return True
return True
def _terminate_pid(pid: int) -> None:
try:
os.kill(pid, signal.SIGTERM)
except ProcessLookupError:
return