Alpha stage commit
This commit is contained in:
1
app/__init__.py
Normal file
1
app/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Mobility Workbench prototype."""
|
||||
1272
app/address_search.py
Normal file
1272
app/address_search.py
Normal file
File diff suppressed because it is too large
Load Diff
394
app/cli.py
Normal file
394
app/cli.py
Normal file
@@ -0,0 +1,394 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import csv
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import typer
|
||||
from sqlalchemy import func, select, text
|
||||
|
||||
from app.config import settings
|
||||
from app.data_management import dataset_sidecar_paths, prune_inactive_datasets
|
||||
from app.db import engine, init_db, reset_db, session_scope
|
||||
from app.db_lock import database_write_lock
|
||||
from app.feed_discovery import build_gtfs_discovery_manifests, default_generated_dir
|
||||
from app.models import (
|
||||
Dataset,
|
||||
GtfsRoute,
|
||||
GtfsShape,
|
||||
GtfsStop,
|
||||
RouteMatch,
|
||||
RoutePattern,
|
||||
Source,
|
||||
SourceCatalogEntry,
|
||||
)
|
||||
from app.pipeline.matcher import run_route_matching
|
||||
from app.pipeline.osm_labeling import relabel_osm_features
|
||||
from app.pipeline.osm_pbf import run_osm_pbf_source_staged
|
||||
from app.pipeline.run import run_source
|
||||
from app.pipeline.gtfs import backfill_gtfs_shapes
|
||||
from app.pipeline.route_layer import rebuild_route_layer
|
||||
from app.pipeline.sample_data import load_sample_project
|
||||
from app.osm_storage import osm_feature_count
|
||||
from app.jobs import run_worker_loop
|
||||
from app.jobs import create_route_layer_rebuild_job, create_route_matching_job, create_source_import_job
|
||||
from app.source_catalog import (
|
||||
default_ingestable_sources_path,
|
||||
default_source_catalog_path,
|
||||
import_ingestable_sources,
|
||||
import_source_catalog,
|
||||
source_catalog_summary,
|
||||
)
|
||||
|
||||
# Root Typer application; each @cli.command(...) below registers one subcommand on it.
cli = typer.Typer(help="Mobility Workbench pipeline CLI")
|
||||
|
||||
|
||||
@cli.command("init-db")
def init_db_command() -> None:
    # Run init_db() under the CLI write lock, then confirm on stdout.
    # (A comment, not a docstring: Typer would surface a docstring as help text.)
    lock = _write_lock("init-db")
    with lock:
        init_db()
        typer.echo("Database initialized")
|
||||
|
||||
|
||||
@cli.command("reset-db")
def reset_db_command() -> None:
    # Run reset_db() under the CLI write lock, then confirm on stdout.
    lock = _write_lock("reset-db")
    with lock:
        reset_db()
        typer.echo("Database reset")
|
||||
|
||||
|
||||
@cli.command("load-sample")
def load_sample_command() -> None:
    # Call load_sample_project() inside one session and print its result as JSON.
    with _write_lock("load-sample"):
        init_db()
        with session_scope() as db:
            summary = load_sample_project(db)
            rendered = json.dumps(summary, indent=2)
            typer.echo(rendered)
|
||||
|
||||
|
||||
@cli.command("add-source")
def add_source_command(
    name: str = typer.Option(..., help="Source name"),
    kind: str = typer.Option(..., help="gtfs, osm_geojson, osm_pbf, or osm_diff"),
    url: str = typer.Option(..., help="HTTP URL or local path"),
    country: Optional[str] = typer.Option(None),
    license: Optional[str] = typer.Option(None),
    priority: Optional[str] = typer.Option(None),
    mode_scope: Optional[str] = typer.Option(None),
    source_basis: Optional[str] = typer.Option(None),
    notes: Optional[str] = typer.Option(None),
) -> None:
    # Insert one Source row built from the options and print {"id", "name"} as JSON.
    # Raises typer.BadParameter when `kind` is not one of the four supported kinds.
    #
    # The kind check runs first so that a typo fails immediately instead of
    # waiting for the database write lock and running init_db() beforehand.
    if kind not in {"gtfs", "osm_geojson", "osm_pbf", "osm_diff"}:
        raise typer.BadParameter("kind must be gtfs, osm_geojson, osm_pbf, or osm_diff")
    with _write_lock("add-source"):
        init_db()
        with session_scope() as session:
            source = Source(
                name=name,
                kind=kind,
                url=url,
                country=country,
                license=license,
                priority=priority,
                mode_scope=mode_scope,
                source_basis=source_basis,
                notes=notes,
            )
            session.add(source)
            # Flush before reading source.id for the output below
            # (presumably a database-generated key -- confirm against the model).
            session.flush()
            typer.echo(json.dumps({"id": source.id, "name": source.name}, indent=2))
|
||||
|
||||
|
||||
@cli.command("run-source")
def run_source_command(source_id: int) -> None:
    # Import one source and print the resulting dataset id/status as JSON.
    # Raises typer.BadParameter when no Source row has the given id.
    init_db()

    # First pass: only read the source kind, without taking the write lock.
    with session_scope() as lookup_session:
        found = lookup_session.get(Source, source_id)
        if found is None:
            raise typer.BadParameter(f"source not found: {source_id}")
        source_kind = found.kind

    # osm_pbf sources go through the staged importer, which is called here
    # without the CLI write lock.
    if source_kind == "osm_pbf":
        dataset = run_osm_pbf_source_staged(source_id)
        staged_payload = {
            "source_id": source_id,
            "dataset_id": dataset.id,
            "status": dataset.status,
            "import_mode": "staged_short_lock",
        }
        typer.echo(json.dumps(staged_payload, indent=2))
        return

    # Every other kind is imported under the write lock in a fresh session.
    with _write_lock("run-source"):
        with session_scope() as session:
            source = session.get(Source, source_id)
            if source is None:
                raise typer.BadParameter(f"source not found: {source_id}")
            dataset = run_source(session, source)
            payload = {"source_id": source.id, "dataset_id": dataset.id, "status": dataset.status}
            typer.echo(json.dumps(payload, indent=2))
|
||||
|
||||
|
||||
@cli.command("run-match")
def run_match_command() -> None:
    # Call run_route_matching() under the write lock and print its result as JSON.
    with _write_lock("run-match"):
        init_db()
        with session_scope() as db:
            outcome = run_route_matching(db)
            rendered = json.dumps(outcome, indent=2)
            typer.echo(rendered)
|
||||
|
||||
|
||||
@cli.command("build-route-layer")
def build_route_layer_command() -> None:
    # Call rebuild_route_layer() under the write lock and print its result as JSON.
    with _write_lock("build-route-layer"):
        init_db()
        with session_scope() as db:
            outcome = rebuild_route_layer(db)
            rendered = json.dumps(outcome, indent=2)
            typer.echo(rendered)
|
||||
|
||||
|
||||
@cli.command("relabel-osm-features")
def relabel_osm_features_command(
    dataset_id: Optional[int] = typer.Option(None, help="Only relabel one OSM dataset"),
    force: bool = typer.Option(False, help="Run even when the recorded dependency signature is current"),
    chunk_size: int = typer.Option(5000, help="Rows per relabel batch"),
    rebuild_indexes: bool = typer.Option(True, help="Drop/rebuild affected route-scope indexes around large relabel writes"),
    build_route_layer: bool = typer.Option(True, help="Rebuild the route layer after relabeling"),
) -> None:
    # Relabel OSM features, optionally rebuild the route layer in the same
    # session, and print the combined result as JSON.
    relabel_options = {
        "dataset_id": dataset_id,
        "force": force,
        "chunk_size": chunk_size,
        "rebuild_indexes": rebuild_indexes,
    }
    with _write_lock("relabel-osm-features"):
        init_db()
        with session_scope() as session:
            result = relabel_osm_features(session, **relabel_options)
            # Rebuild only when requested, and only if something changed or the
            # run was forced; result["changed"] is not read when rebuilding is off.
            if build_route_layer:
                if result["changed"] or force:
                    result["route_layer_result"] = rebuild_route_layer(session)
            typer.echo(json.dumps(result, indent=2))
|
||||
|
||||
|
||||
@cli.command("backfill-gtfs-shapes")
def backfill_gtfs_shapes_command(dataset_id: Optional[int] = typer.Option(None, help="Only backfill one GTFS dataset")) -> None:
    # Call backfill_gtfs_shapes() under the write lock and print its result as JSON.
    with _write_lock("backfill-gtfs-shapes"):
        init_db()
        with session_scope() as db:
            outcome = backfill_gtfs_shapes(db, dataset_id=dataset_id)
            rendered = json.dumps(outcome, indent=2)
            typer.echo(rendered)
|
||||
|
||||
|
||||
@cli.command("stats")
def stats_command() -> None:
    # Print row counts for sources, catalog entries, active datasets and match
    # statuses as JSON. Runs without the write lock.
    init_db()
    with session_scope() as session:
        active_rows = session.execute(select(Dataset.id).where(Dataset.is_active.is_(True))).all()
        active_dataset_ids = [dataset_id for (dataset_id,) in active_rows]

        def table_total(model):
            # Unfiltered COUNT(*) over one mapped table.
            return session.scalar(select(func.count()).select_from(model))

        def active_total(model):
            # COUNT(*) restricted to active datasets; no query when none are active.
            if not active_dataset_ids:
                return 0
            return session.scalar(
                select(func.count()).select_from(model).where(model.dataset_id.in_(active_dataset_ids))
            )

        stats = {
            "sources": table_total(Source),
            "source_catalog_entries": table_total(SourceCatalogEntry) or 0,
            "active_datasets": len(active_dataset_ids),
            "gtfs_routes": active_total(GtfsRoute),
            "gtfs_stops": active_total(GtfsStop),
            "gtfs_shapes": active_total(GtfsShape),
            "route_patterns": table_total(RoutePattern) or 0,
            "osm_routes": sum(osm_feature_count(session, active_id, kind="route") for active_id in active_dataset_ids),
            "matches": {
                row[0]: row[1]
                for row in session.execute(select(RouteMatch.status, func.count()).group_by(RouteMatch.status)).all()
            },
        }
        typer.echo(json.dumps(stats, indent=2))
|
||||
|
||||
|
||||
@cli.command("import-source-catalog")
def import_source_catalog_command(
    csv_path: Path = typer.Option(default_source_catalog_path(), "--csv", help="Source catalog CSV path"),
    no_update: bool = typer.Option(False, help="Skip rows that already exist"),
) -> None:
    # Import the source catalog CSV, attach the catalog summary under "summary",
    # and print everything as JSON.
    update_existing = not no_update
    with _write_lock("import-source-catalog"):
        init_db()
        with session_scope() as db:
            report = import_source_catalog(db, csv_path, update_existing=update_existing)
            report["summary"] = source_catalog_summary(db)
            typer.echo(json.dumps(report, indent=2))
|
||||
|
||||
|
||||
@cli.command("import-ingestable-sources")
def import_ingestable_sources_command(
    csv_path: Path = typer.Option(default_ingestable_sources_path(), "--csv", help="Ingestable source seed CSV path"),
    no_update: bool = typer.Option(False, help="Skip sources that already exist"),
) -> None:
    # Import the ingestable-sources CSV, attach the catalog summary under
    # "summary", and print everything as JSON.
    update_existing = not no_update
    with _write_lock("import-ingestable-sources"):
        init_db()
        with session_scope() as db:
            report = import_ingestable_sources(db, csv_path, update_existing=update_existing)
            report["summary"] = source_catalog_summary(db)
            typer.echo(json.dumps(report, indent=2))
|
||||
|
||||
|
||||
@cli.command("discover-gtfs-sources")
def discover_gtfs_sources_command(
    output_dir: Path = typer.Option(default_generated_dir(), "--output-dir", help="Directory for generated discovery CSVs"),
    countries: str = typer.Option(
        ",".join(["DE", "AT", "CH", "NL", "DK", "FR", "BE", "LU", "NO", "SE", "FI", "IE", "GB"]),
        "--countries",
        help="Comma-separated country codes, or ALL for every country exposed by the upstream catalogs",
    ),
    no_mobility_database: bool = typer.Option(False, help="Skip Mobility Database feeds_v2.csv"),
    no_acceptance_test_list: bool = typer.Option(False, help="Skip MobilityData validator acceptance-test feed list"),
    no_ptna: bool = typer.Option(False, help="Skip PTNA GTFS analysis pages"),
    max_ptna_details: int = typer.Option(80, help="Maximum PTNA detail pages to fetch for license/crosswalk metadata"),
    test_limit: int = typer.Option(24, help="Rows to write to the focused test-run ingestable CSV"),
    check_urls: bool = typer.Option(False, help="Run HEAD/range checks for ingestable feed URLs"),
) -> None:
    # Build the GTFS discovery manifests and print the builder's result as JSON.
    # Takes no database lock and opens no session.

    # Split the comma-separated option, trimming whitespace and dropping blanks.
    trimmed_parts = (part.strip() for part in countries.split(","))
    country_codes = [code for code in trimmed_parts if code]

    # The CLI exposes "skip" switches; the builder expects "include" flags.
    result = build_gtfs_discovery_manifests(
        output_dir=output_dir,
        countries=country_codes,
        include_mobility_database=not no_mobility_database,
        include_acceptance_test_list=not no_acceptance_test_list,
        include_ptna=not no_ptna,
        max_ptna_details=max_ptna_details,
        test_limit=test_limit,
        check_urls=check_urls,
    )
    rendered = json.dumps(result, indent=2, ensure_ascii=False)
    typer.echo(rendered)
|
||||
|
||||
|
||||
@cli.command("queue-source-imports-from-csv")
def queue_source_imports_from_csv_command(
    csv_path: Path = typer.Option(default_ingestable_sources_path(), "--csv", help="Ingestable source CSV path"),
    no_update: bool = typer.Option(False, help="Skip sources that already exist instead of updating them"),
    run_match_at_end: bool = typer.Option(True, help="Queue one route-matching job after all source imports"),
    build_route_layer_at_end: bool = typer.Option(True, help="Queue one route-layer rebuild after route matching"),
    priority: int = typer.Option(0, help="Priority for queued source import jobs"),
) -> None:
    # Import sources from the CSV, queue one import job per GTFS source whose URL
    # appears in it, optionally queue follow-up jobs, and print a JSON report.
    with _write_lock("queue-source-imports-from-csv"):
        init_db()
        with session_scope() as session:
            # Relative paths are resolved against the current working directory.
            if not csv_path.is_absolute():
                csv_path = Path.cwd() / csv_path

            imported = import_ingestable_sources(session, csv_path, update_existing=not no_update)

            source_urls = _source_urls_from_ingestable_csv(csv_path)
            source_query = (
                select(Source)
                .where(Source.kind == "gtfs", Source.url.in_(source_urls))
                .order_by(Source.id)
            )
            sources = session.scalars(source_query).all()

            # Per-source jobs never match or rebuild themselves; that happens at
            # most once, through the two optional jobs created afterwards.
            jobs = []
            for source in sources:
                job = create_source_import_job(
                    session,
                    source,
                    run_match=False,
                    build_route_layer=False,
                    priority=priority,
                )
                jobs.append(job)

            route_match_job = None
            if run_match_at_end:
                route_match_job = create_route_matching_job(session, priority=priority)
            route_layer_job = None
            if build_route_layer_at_end:
                route_layer_job = create_route_layer_rebuild_job(session, priority=priority)

            report = {
                "csv": str(csv_path),
                "imported": imported,
                "sources": [{"id": source.id, "name": source.name} for source in sources],
                "source_import_jobs": [job.id for job in jobs],
                "route_match_job": route_match_job.id if route_match_job is not None else None,
                "route_layer_job": route_layer_job.id if route_layer_job is not None else None,
            }
            typer.echo(json.dumps(report, indent=2, ensure_ascii=False))
|
||||
|
||||
|
||||
@cli.command("prune-cache")
def prune_cache_command(dry_run: bool = typer.Option(False, help="Report files without deleting them")) -> None:
    # Find files under the data directory that no Dataset references, delete them
    # unless dry_run is set, and print a JSON summary.
    with _write_lock("prune-cache"):
        init_db()
        with session_scope() as session:
            # Referenced = every non-empty Dataset.local_path plus each dataset's sidecar paths.
            referenced = set()
            for local_path in session.scalars(select(Dataset.local_path)).all():
                if local_path:
                    referenced.add(Path(local_path).resolve())
            for dataset in session.scalars(select(Dataset)).all():
                for sidecar in dataset_sidecar_paths(dataset):
                    referenced.add(sidecar.resolve())

        # This command also scans "staging", in addition to sources/derived/sidecars.
        data_dir = settings.data_dir
        roots = [data_dir / "sources", data_dir / "derived", data_dir / "sidecars", data_dir / "staging"]

        candidates = []
        for root in roots:
            if not root.exists():
                continue
            for entry in root.rglob("*"):
                if entry.is_file() and entry.resolve() not in referenced:
                    candidates.append(entry)

        # Sizes are read before anything is unlinked.
        total_bytes = 0
        for entry in candidates:
            total_bytes += entry.stat().st_size

        if not dry_run:
            for entry in candidates:
                entry.unlink()
            for root in roots:
                _remove_empty_dirs(root)

    summary = {
        "dry_run": dry_run,
        "files": len(candidates),
        "bytes": total_bytes,
        "deleted": 0 if dry_run else len(candidates),
    }
    typer.echo(json.dumps(summary, indent=2))
|
||||
|
||||
|
||||
@cli.command("prune-inactive-datasets")
def prune_inactive_datasets_command(
    dry_run: bool = typer.Option(False, help="Report inactive normalized datasets without deleting them"),
) -> None:
    # Call prune_inactive_datasets() under the write lock and print its result as JSON.
    with _write_lock("prune-inactive-datasets"):
        init_db()
        with session_scope() as db:
            outcome = prune_inactive_datasets(db, dry_run=dry_run)
            rendered = json.dumps(outcome, indent=2)
            typer.echo(rendered)
|
||||
|
||||
|
||||
@cli.command("vacuum-db")
def vacuum_db_command() -> None:
    # Run VACUUM on the application database under the CLI write lock and, on
    # SQLite only, truncate the write-ahead log afterwards.
    with _write_lock("vacuum-db"):
        init_db()
        # AUTOCOMMIT: VACUUM must not run inside an open transaction.
        with engine.connect().execution_options(isolation_level="AUTOCOMMIT") as connection:
            connection.execute(text("VACUUM"))
            # PRAGMA is SQLite-specific; sending it to PostgreSQL (which the
            # settings also support) would fail after the VACUUM already ran.
            if settings.is_sqlite_database:
                connection.execute(text("PRAGMA wal_checkpoint(TRUNCATE)"))
        typer.echo("Database vacuumed")
|
||||
|
||||
|
||||
@cli.command("worker")
def worker_command(
    once: bool = typer.Option(False, help="Process at most one queued job and exit"),
    max_jobs: Optional[int] = typer.Option(None, help="Process at most this many jobs and exit"),
    poll_interval: float = typer.Option(2.0, help="Seconds to wait between queue polls"),
    worker_id: Optional[str] = typer.Option(None, help="Stable worker identifier"),
) -> None:
    # Hand the options to run_worker_loop() and print whatever it returns as JSON.
    loop_options = {
        "worker_id": worker_id,
        "poll_interval": poll_interval,
        "max_jobs": max_jobs,
        "once": once,
    }
    outcome = run_worker_loop(**loop_options)
    typer.echo(json.dumps(outcome, indent=2))
|
||||
|
||||
|
||||
def _remove_empty_dirs(root: Path) -> None:
|
||||
if not root.exists():
|
||||
return
|
||||
for path in sorted((p for p in root.rglob("*") if p.is_dir()), key=lambda p: len(p.parts), reverse=True):
|
||||
try:
|
||||
path.rmdir()
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
|
||||
def _write_lock(operation: str):
    """Return ``database_write_lock`` for a CLI operation.

    The lock is named ``cli:<operation>`` and uses the CLI-specific timeout
    from the settings.
    """
    lock_name = f"cli:{operation}"
    wait_seconds = settings.database_write_lock_cli_timeout_seconds
    return database_write_lock(lock_name, timeout=wait_seconds)
|
||||
|
||||
|
||||
def _source_urls_from_ingestable_csv(path: Path) -> list[str]:
|
||||
urls: list[str] = []
|
||||
with path.open("r", encoding="utf-8-sig", newline="") as handle:
|
||||
for row in csv.DictReader(handle):
|
||||
if (row.get("kind") or "").strip().lower() != "gtfs":
|
||||
continue
|
||||
url = (row.get("url") or "").strip()
|
||||
if url and url not in urls:
|
||||
urls.append(url)
|
||||
return urls
|
||||
|
||||
|
||||
# Invoke the Typer application when this module is executed as a script.
if __name__ == "__main__":
    cli()
|
||||
74
app/config.py
Normal file
74
app/config.py
Normal file
@@ -0,0 +1,74 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class Settings(BaseSettings):
    """Runtime configuration for the workbench.

    The defaults target SQLite so the prototype runs without extra setup; the
    schema is intentionally simple enough to move to PostGIS later.
    """

    database_url: str = "sqlite:///./data/workbench.sqlite"
    data_dir: Path = Path("./data")

    # GTFS import -----------------------------------------------------------
    # 0 imports every stop_times row. A positive limit is only meant for
    # constrained demos that do not need full timetable routing.
    gtfs_stop_times_import_limit: int = 0
    # "sidecar_stop_times": the large timetable call table lives in a
    # per-dataset SQLite file while the compact GTFS tables stay in the main
    # database. "main" selects the older all-in-one SQLite layout.
    gtfs_timetable_storage: str = "sidecar_stop_times"
    gtfs_keep_activation_stage: bool = False

    # OSM import ------------------------------------------------------------
    # "sidecar_features": extracted OSM transport features live in a
    # per-dataset SQLite file. The main database only materializes the OSM rows
    # that need stable foreign keys for matches or route-layer output.
    osm_feature_storage: str = "sidecar_features"
    osm_sidecar_create_visual_only_stops: bool = False
    # Large OSM PBF extracts are reduced to transport objects before the Python
    # extractor scans them. XML fixtures remain unfiltered by default.
    osm_pbf_prefilter_enabled: bool = True
    osm_pbf_prefilter_formats: str = "osm_pbf"
    osm_pbf_prefilter_script: Path = Path("scripts/osmium_transport_filter.sh")
    osm_diff_max_sequence_gap: int = 14
    osm_diff_apply_batch_size: int = 7
    osm_diff_state_timeout_seconds: float = 30.0

    # Database timeouts and locking -------------------------------------------
    sqlite_timeout_seconds: float = 120.0
    sqlite_busy_timeout_ms: int = 120000
    database_write_lock_timeout_seconds: float = 1.0
    database_write_lock_cli_timeout_seconds: float = 3600.0

    # Job queue and batch sizes -----------------------------------------------
    queue_worker_autostart: bool = True
    queue_worker_count: int = 1
    queue_worker_poll_interval_seconds: float = 2.0
    queue_job_lease_seconds: int = 7200
    route_matching_batch_size: int = 100
    route_layer_osm_route_batch_size: int = 1000
    route_layer_osm_stop_batch_size: int = 5000

    # SQLite defaults to sidecar storage. PostgreSQL/PostGIS defaults to main
    # table storage so that indexes, joins and spatial operators can work over
    # the complete imported datasets.
    postgres_use_sidecars: bool = False
    # Supervised workers stay alive across API server restarts. Stale workers
    # are detected through PID files at the next startup, and stale job leases
    # are requeued.
    queue_worker_stop_on_shutdown: bool = False

    model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8")

    @property
    def normalized_database_url(self) -> str:
        """Return ``database_url`` with a bare ``postgresql://`` scheme rewritten.

        URLs starting with ``postgresql://`` are returned with the scheme
        ``postgresql+psycopg://``; every other URL is returned unchanged.
        """
        plain_scheme = "postgresql://"
        url = self.database_url
        if not url.startswith(plain_scheme):
            return url
        return "postgresql+psycopg://" + url[len(plain_scheme):]

    @property
    def is_sqlite_database(self) -> bool:
        """Whether the normalized database URL starts with ``sqlite``."""
        normalized = self.normalized_database_url
        return normalized.startswith("sqlite")

    @property
    def is_postgresql_database(self) -> bool:
        """Whether the normalized database URL starts with ``postgresql``."""
        normalized = self.normalized_database_url
        return normalized.startswith("postgresql")
|
||||
|
||||
|
||||
# Module-level settings instance, built once at import time (model_config names ".env" as env file).
settings = Settings()
# Import-time side effect: create the data directory (and any missing parents) if it does not exist.
settings.data_dir.mkdir(parents=True, exist_ok=True)
|
||||
327
app/data_management.py
Normal file
327
app/data_management.py
Normal file
@@ -0,0 +1,327 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from sqlalchemy import delete, func, or_, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.config import settings
|
||||
from app.gtfs_storage import dataset_sidecar_paths as gtfs_dataset_sidecar_paths, missing_sidecar_paths as gtfs_missing_sidecar_paths, stop_time_count
|
||||
from app.models import (
|
||||
CanonicalStopLink,
|
||||
Dataset,
|
||||
GtfsAgency,
|
||||
GtfsCalendar,
|
||||
GtfsCalendarDate,
|
||||
GtfsRoute,
|
||||
GtfsRoutePatternLink,
|
||||
GtfsShape,
|
||||
GtfsStop,
|
||||
GtfsStopTime,
|
||||
GtfsTripRoutePatternLink,
|
||||
GtfsTrip,
|
||||
OsmDiffState,
|
||||
OsmFeature,
|
||||
RouteMatch,
|
||||
RoutePattern,
|
||||
RoutePatternStop,
|
||||
Source,
|
||||
SourceUpdateCheck,
|
||||
)
|
||||
from app.osm_storage import (
|
||||
dataset_sidecar_paths as osm_dataset_sidecar_paths,
|
||||
missing_sidecar_paths as osm_missing_sidecar_paths,
|
||||
osm_feature_count,
|
||||
)
|
||||
|
||||
|
||||
def dataset_row_counts(session: Session, dataset_id: int, kind: str) -> dict[str, int]:
    """Collect per-table row counts for one dataset.

    Returns a dict of counts for ``kind == "gtfs"`` or ``kind == "osm_geojson"``
    and an empty dict for any other kind. Note that the GTFS result also holds a
    nested ``match_counts`` dict (status -> count) and both results carry a
    ``missing_sidecar`` entry, so not every value is a plain row count.
    """
    if kind == "osm_geojson":
        return {
            "features": _safe_osm_feature_count(session, dataset_id),
            "routes": _safe_osm_feature_count(session, dataset_id, kind="route"),
            "stops": _safe_osm_feature_count(session, dataset_id, kind=["stop", "station", "terminal"]),
            "infra": _safe_osm_feature_count(session, dataset_id, kind="infra"),
            "missing_sidecar": _osm_sidecar_missing(session, dataset_id),
        }
    if kind != "gtfs":
        return {}

    # Route matches are attributed to the dataset through its GTFS routes.
    routes_of_dataset = select(GtfsRoute.id).where(GtfsRoute.dataset_id == dataset_id)
    status_query = (
        select(RouteMatch.status, func.count())
        .where(RouteMatch.gtfs_route_id.in_(routes_of_dataset))
        .group_by(RouteMatch.status)
    )
    match_counts = {}
    for status, count in session.execute(status_query).all():
        match_counts[status] = count

    return {
        "agencies": _count(session, GtfsAgency, dataset_id),
        "stops": _count(session, GtfsStop, dataset_id),
        "routes": _count(session, GtfsRoute, dataset_id),
        "trips": _count(session, GtfsTrip, dataset_id),
        "calendars": _count(session, GtfsCalendar, dataset_id),
        "calendar_dates": _count(session, GtfsCalendarDate, dataset_id),
        "shapes": _count(session, GtfsShape, dataset_id),
        "stop_times": stop_time_count(session, dataset_id),
        "missing_sidecar": _gtfs_sidecar_missing(session, dataset_id),
        "matches": sum(match_counts.values()),
        "match_counts": match_counts,
    }
|
||||
|
||||
|
||||
def source_row_counts(session: Session, source: Source) -> dict[str, object]:
    """Aggregate ``dataset_row_counts`` over every dataset of ``source``.

    Sums the additive counters, counts datasets that report a missing sidecar
    (in total and split by GTFS / OSM), and merges the per-status match counts.
    """
    datasets = source.datasets
    counts = {
        "datasets": len(datasets),
        "active_datasets": len([dataset for dataset in datasets if dataset.is_active]),
        "routes": 0,
        "stops": 0,
        "features": 0,
        "trips": 0,
        "shapes": 0,
        "stop_times": 0,
        "missing_sidecars": 0,
        "match_counts": {},
        "missing_gtfs_sidecars": 0,
        "missing_osm_sidecars": 0,
    }
    additive_keys = ("routes", "stops", "features", "trips", "shapes", "stop_times")
    # Which per-kind counter to bump when a dataset reports a missing sidecar.
    missing_key_by_kind = {"gtfs": "missing_gtfs_sidecars", "osm_geojson": "missing_osm_sidecars"}
    merged_matches: dict[str, int] = {}

    for dataset in source.datasets:
        stats = dataset_row_counts(session, dataset.id, dataset.kind)
        for key in additive_keys:
            counts[key] += int(stats.get(key, 0))
        if stats.get("missing_sidecar"):
            counts["missing_sidecars"] += 1
            kind_key = missing_key_by_kind.get(dataset.kind)
            if kind_key is not None:
                counts[kind_key] += 1
        for status, count in stats.get("match_counts", {}).items():
            merged_matches[status] = merged_matches.get(status, 0) + int(count)

    counts["match_counts"] = merged_matches
    return counts
|
||||
|
||||
|
||||
def delete_dataset(session: Session, dataset_id: int) -> dict[str, object]:
    """Delete one dataset with its dependent rows and files.

    Returns ``{"deleted": False, ...}`` when the dataset does not exist,
    otherwise ``{"deleted": True, ...}`` with the row counts captured before
    deletion. The session is flushed, not committed.
    """
    target = session.get(Dataset, dataset_id)
    if target is None:
        return {"deleted": False, "reason": "dataset not found", "dataset_id": dataset_id}

    # Snapshot the counts first; they describe what is about to be removed.
    counts_before = dataset_row_counts(session, target.id, target.kind)

    _detach_update_checks_for_dataset(session, target.id)
    diff_state_cleanup = delete(OsmDiffState).where(OsmDiffState.raw_dataset_id == target.id)
    session.execute(diff_state_cleanup)
    _delete_dataset_rows(session, target)
    _delete_dataset_files(target)
    session.delete(target)
    session.flush()

    return {"deleted": True, "dataset_id": dataset_id, "counts": counts_before}
|
||||
|
||||
|
||||
def delete_source(session: Session, source_id: int) -> dict[str, object]:
    """Delete a source together with all of its datasets.

    Returns ``{"deleted": False, ...}`` when the source does not exist,
    otherwise ``{"deleted": True, ...}`` with one entry per removed dataset
    (id, kind and the row counts captured before deletion). The session is
    flushed, not committed.
    """
    target = session.get(Source, source_id)
    if target is None:
        return {"deleted": False, "reason": "source not found", "source_id": source_id}

    dataset_results = []
    # Iterate over a copy of the relationship collection taken up front.
    for dataset in list(target.datasets):
        entry = {
            "dataset_id": dataset.id,
            "kind": dataset.kind,
            "counts": dataset_row_counts(session, dataset.id, dataset.kind),
        }
        dataset_results.append(entry)

        _detach_update_checks_for_dataset(session, dataset.id)
        per_dataset_cleanup = delete(OsmDiffState).where(OsmDiffState.raw_dataset_id == dataset.id)
        session.execute(per_dataset_cleanup)
        _delete_dataset_rows(session, dataset)
        _delete_dataset_files(dataset)
        session.delete(dataset)

    # Remove diff state keyed by the source itself before deleting the source row.
    per_source_cleanup = delete(OsmDiffState).where(OsmDiffState.source_id == target.id)
    session.execute(per_source_cleanup)
    session.delete(target)
    session.flush()

    return {"deleted": True, "source_id": source_id, "datasets": dataset_results}
|
||||
|
||||
|
||||
def unreferenced_cache_file_summary(session: Session) -> dict[str, int]:
    """Report how many unreferenced cache files exist and their total size in bytes."""
    orphans = _unreferenced_cache_files(session)
    file_count = len(orphans)
    byte_total = 0
    for orphan in orphans:
        byte_total += orphan.stat().st_size
    return {"files": file_count, "bytes": byte_total}
|
||||
|
||||
|
||||
def prune_unreferenced_cache_files(session: Session) -> dict[str, int]:
    """Delete unreferenced cache files and report the number of files and bytes removed."""
    orphans = _unreferenced_cache_files(session)

    # Sizes are read before anything is unlinked.
    byte_total = 0
    for orphan in orphans:
        byte_total += orphan.stat().st_size

    for orphan in orphans:
        orphan.unlink()

    # Clean up directories that the deletions may have left empty.
    for cache_root in _cache_roots():
        _remove_empty_dirs(cache_root)

    return {"files": len(orphans), "bytes": byte_total}
|
||||
|
||||
|
||||
def _unreferenced_cache_files(session: Session) -> list[Path]:
    """List files under the cache roots that no dataset references.

    A file counts as referenced when its resolved path equals a dataset's
    ``local_path`` or one of the paths from ``dataset_sidecar_paths``.
    """
    referenced = set()
    for local_path in session.scalars(select(Dataset.local_path)).all():
        if local_path:
            referenced.add(Path(local_path).resolve())
    for dataset in session.scalars(select(Dataset)).all():
        for sidecar in dataset_sidecar_paths(dataset):
            referenced.add(sidecar.resolve())

    orphans = []
    for cache_root in _cache_roots():
        if not cache_root.exists():
            continue
        for entry in cache_root.rglob("*"):
            if not entry.is_file():
                continue
            if entry.resolve() in referenced:
                continue
            orphans.append(entry)
    return orphans
|
||||
|
||||
|
||||
def _cache_roots() -> list[Path]:
    """Return the data sub-directories that automatic cache pruning may scan."""
    # "staging" is left out on purpose: staging files are not referenced by any
    # dataset until activation, so automatic pruning could otherwise delete the
    # staging DB of an import that is still running.
    base_dir = settings.data_dir
    return [base_dir / name for name in ("sources", "derived", "sidecars")]
|
||||
|
||||
|
||||
def prune_inactive_datasets(session: Session, dry_run: bool = True) -> dict[str, object]:
    """Delete every inactive ``gtfs`` / ``osm_geojson`` dataset and its dependent rows.

    Args:
        session: Open SQLAlchemy session; it is flushed but not committed here.
        dry_run: When true (the default) only the counts are computed and
            nothing is deleted.

    Returns:
        A dict with ``dry_run``, ``dataset_ids``, ``deleted`` and
        ``would_delete``. On a dry run the counts are under ``would_delete``
        and ``deleted`` is empty; after a real run it is the other way round.
        When ``dry_run`` is false but no dataset qualifies, the (zero) counts
        appear under both keys.
    """
    # Select the inactive datasets of the two normalized kinds.
    dataset_rows = session.execute(
        select(Dataset.id, Dataset.kind).where(Dataset.is_active.is_(False), Dataset.kind.in_(["gtfs", "osm_geojson"]))
    ).all()
    dataset_ids = [int(row[0]) for row in dataset_rows]
    gtfs_ids = [int(dataset_id) for dataset_id, kind in dataset_rows if kind == "gtfs"]
    osm_ids = [int(dataset_id) for dataset_id, kind in dataset_rows if kind == "osm_geojson"]

    # Subqueries identifying the routes/features owned by those datasets; None
    # when there is nothing of that kind, so no empty IN filter is built.
    route_ids = select(GtfsRoute.id).where(GtfsRoute.dataset_id.in_(gtfs_ids)) if gtfs_ids else None
    osm_feature_ids = select(OsmFeature.id).where(OsmFeature.dataset_id.in_(osm_ids)) if osm_ids else None
    # A route match is affected when either of its two sides is being removed.
    match_filters = []
    if route_ids is not None:
        match_filters.append(RouteMatch.gtfs_route_id.in_(route_ids))
    if osm_feature_ids is not None:
        match_filters.append(RouteMatch.osm_feature_id.in_(osm_feature_ids))

    # Counts are taken before any delete so they describe what will be removed.
    # NOTE(review): "canonical_stop_links" counts every link of the selected
    # datasets, while the deletes below additionally filter by object_type, so
    # the reported number may exceed the rows actually deleted.
    counts = {
        "datasets": len(dataset_ids),
        "gtfs_stop_times": sum(stop_time_count(session, dataset_id) for dataset_id in gtfs_ids),
        "gtfs_shapes": _count_dataset_rows(session, GtfsShape, gtfs_ids),
        "gtfs_trips": _count_dataset_rows(session, GtfsTrip, gtfs_ids),
        "gtfs_calendar_dates": _count_dataset_rows(session, GtfsCalendarDate, gtfs_ids),
        "gtfs_calendars": _count_dataset_rows(session, GtfsCalendar, gtfs_ids),
        "gtfs_routes": _count_dataset_rows(session, GtfsRoute, gtfs_ids),
        "gtfs_stops": _count_dataset_rows(session, GtfsStop, gtfs_ids),
        "gtfs_agencies": _count_dataset_rows(session, GtfsAgency, gtfs_ids),
        "osm_features": sum(_safe_osm_feature_count(session, dataset_id) for dataset_id in osm_ids),
        "missing_osm_sidecars": sum(1 for dataset_id in osm_ids if _osm_sidecar_missing(session, dataset_id)),
        "gtfs_route_pattern_links": session.scalar(select(func.count()).select_from(GtfsRoutePatternLink).where(GtfsRoutePatternLink.dataset_id.in_(gtfs_ids))) if gtfs_ids else 0,
        "gtfs_trip_route_pattern_links": session.scalar(select(func.count()).select_from(GtfsTripRoutePatternLink).where(GtfsTripRoutePatternLink.dataset_id.in_(gtfs_ids))) if gtfs_ids else 0,
        "canonical_stop_links": session.scalar(select(func.count()).select_from(CanonicalStopLink).where(CanonicalStopLink.dataset_id.in_(dataset_ids))) if dataset_ids else 0,
        "route_matches": session.scalar(select(func.count()).select_from(RouteMatch).where(or_(*match_filters))) if match_filters else 0,
    }
    # Stop here on a dry run, or when there is nothing to delete.
    if dry_run or not dataset_ids:
        return {"dry_run": dry_run, "dataset_ids": dataset_ids, "deleted": counts if not dry_run else {}, "would_delete": counts}

    for dataset_id in dataset_ids:
        _detach_update_checks_for_dataset(session, dataset_id)
    # Matches go first, then pattern stops/links, then patterns, and only then
    # the routes/features that the subqueries above select from.
    if match_filters:
        session.execute(delete(RouteMatch).where(or_(*match_filters)))
    if gtfs_ids:
        route_ids = select(GtfsRoute.id).where(GtfsRoute.dataset_id.in_(gtfs_ids))
        pattern_ids = select(RoutePattern.id).where(RoutePattern.gtfs_route_id.in_(route_ids))
        session.execute(delete(RoutePatternStop).where(RoutePatternStop.route_pattern_id.in_(pattern_ids)))
        session.execute(delete(GtfsTripRoutePatternLink).where(GtfsTripRoutePatternLink.dataset_id.in_(gtfs_ids)))
        session.execute(delete(GtfsRoutePatternLink).where(GtfsRoutePatternLink.dataset_id.in_(gtfs_ids)))
        session.execute(delete(RoutePattern).where(RoutePattern.gtfs_route_id.in_(route_ids)))
        session.execute(delete(CanonicalStopLink).where(CanonicalStopLink.dataset_id.in_(gtfs_ids), CanonicalStopLink.object_type == "gtfs_stop"))
        # GtfsRoute is deleted inside this loop, after every statement that uses route_ids.
        for model in [GtfsStopTime, GtfsShape, GtfsTrip, GtfsCalendarDate, GtfsCalendar, GtfsRoute, GtfsStop, GtfsAgency]:
            session.execute(delete(model).where(model.dataset_id.in_(gtfs_ids)))
    if osm_ids:
        osm_feature_ids = select(OsmFeature.id).where(OsmFeature.dataset_id.in_(osm_ids))
        pattern_ids = select(RoutePattern.id).where(RoutePattern.osm_feature_id.in_(osm_feature_ids))
        session.execute(delete(RoutePatternStop).where(RoutePatternStop.route_pattern_id.in_(pattern_ids)))
        # For OSM the GTFS link tables are cleaned through the affected patterns.
        session.execute(delete(GtfsTripRoutePatternLink).where(GtfsTripRoutePatternLink.route_pattern_id.in_(pattern_ids)))
        session.execute(delete(GtfsRoutePatternLink).where(GtfsRoutePatternLink.route_pattern_id.in_(pattern_ids)))
        session.execute(delete(RoutePattern).where(RoutePattern.osm_feature_id.in_(osm_feature_ids)))
        session.execute(delete(CanonicalStopLink).where(CanonicalStopLink.dataset_id.in_(osm_ids), CanonicalStopLink.object_type == "osm_feature"))
        session.execute(delete(OsmFeature).where(OsmFeature.dataset_id.in_(osm_ids)))
    # Files are removed while the Dataset rows can still be loaded.
    for dataset in session.scalars(select(Dataset).where(Dataset.id.in_(dataset_ids))).all():
        _delete_dataset_files(dataset)
    session.execute(delete(Dataset).where(Dataset.id.in_(dataset_ids)))
    session.flush()
    return {"dry_run": dry_run, "dataset_ids": dataset_ids, "deleted": counts, "would_delete": {}}
|
||||
|
||||
|
||||
def _delete_dataset_rows(session: Session, dataset: Dataset) -> None:
    """Delete the rows that belong to *dataset*, dependants first.

    Only the ``"gtfs"`` and ``"osm_geojson"`` dataset kinds are handled; any
    other kind leaves the database untouched.
    """
    if dataset.kind == "gtfs":
        own_routes = select(GtfsRoute.id).where(GtfsRoute.dataset_id == dataset.id)
        own_patterns = select(RoutePattern.id).where(RoutePattern.gtfs_route_id.in_(own_routes))
        statements = [
            delete(RouteMatch).where(RouteMatch.gtfs_route_id.in_(own_routes)),
            delete(RoutePatternStop).where(RoutePatternStop.route_pattern_id.in_(own_patterns)),
            delete(GtfsTripRoutePatternLink).where(GtfsTripRoutePatternLink.dataset_id == dataset.id),
            delete(GtfsRoutePatternLink).where(GtfsRoutePatternLink.dataset_id == dataset.id),
            delete(RoutePattern).where(RoutePattern.gtfs_route_id.in_(own_routes)),
            delete(CanonicalStopLink).where(
                CanonicalStopLink.dataset_id == dataset.id,
                CanonicalStopLink.object_type == "gtfs_stop",
            ),
        ]
        # Raw GTFS tables are removed in this fixed order.
        gtfs_tables = (GtfsStopTime, GtfsShape, GtfsTrip, GtfsCalendarDate, GtfsCalendar, GtfsRoute, GtfsStop, GtfsAgency)
        statements.extend(delete(table).where(table.dataset_id == dataset.id) for table in gtfs_tables)
    elif dataset.kind == "osm_geojson":
        own_features = select(OsmFeature.id).where(OsmFeature.dataset_id == dataset.id)
        own_patterns = select(RoutePattern.id).where(RoutePattern.osm_feature_id.in_(own_features))
        statements = [
            delete(RouteMatch).where(RouteMatch.osm_feature_id.in_(own_features)),
            delete(RoutePatternStop).where(RoutePatternStop.route_pattern_id.in_(own_patterns)),
            delete(GtfsTripRoutePatternLink).where(GtfsTripRoutePatternLink.route_pattern_id.in_(own_patterns)),
            delete(GtfsRoutePatternLink).where(GtfsRoutePatternLink.route_pattern_id.in_(own_patterns)),
            delete(RoutePattern).where(RoutePattern.osm_feature_id.in_(own_features)),
            delete(CanonicalStopLink).where(
                CanonicalStopLink.dataset_id == dataset.id,
                CanonicalStopLink.object_type == "osm_feature",
            ),
            delete(OsmFeature).where(OsmFeature.dataset_id == dataset.id),
        ]
    else:
        return

    for statement in statements:
        session.execute(statement)
|
||||
|
||||
|
||||
def _delete_dataset_files(dataset: Dataset) -> None:
    """Unlink every sidecar path of *dataset*; files that are already gone are ignored."""
    for sidecar in dataset_sidecar_paths(dataset):
        sidecar.unlink(missing_ok=True)
|
||||
|
||||
|
||||
def dataset_sidecar_paths(dataset: Dataset) -> list[Path]:
    """Return the GTFS sidecar paths of *dataset* followed by its OSM sidecar paths."""
    collected: list[Path] = []
    collected.extend(gtfs_dataset_sidecar_paths(dataset))
    collected.extend(osm_dataset_sidecar_paths(dataset))
    return collected
|
||||
|
||||
|
||||
def _gtfs_sidecar_missing(session: Session, dataset_id: int) -> bool:
    """Return True when ``gtfs_missing_sidecar_paths`` reports anything for the dataset."""
    record = session.get(Dataset, dataset_id)
    missing = gtfs_missing_sidecar_paths(record)
    return True if missing else False
|
||||
|
||||
|
||||
def _safe_osm_feature_count(session: Session, dataset_id: int, *, kind=None) -> int:
    """Return ``osm_feature_count`` for the dataset, or 0 if it raises ``FileNotFoundError``."""
    try:
        total = osm_feature_count(session, dataset_id, kind=kind)
    except FileNotFoundError:
        total = 0
    return total
|
||||
|
||||
|
||||
def _osm_sidecar_missing(session: Session, dataset_id: int) -> bool:
    """Return True when ``osm_missing_sidecar_paths`` reports anything for the dataset."""
    record = session.get(Dataset, dataset_id)
    missing = osm_missing_sidecar_paths(record)
    return True if missing else False
|
||||
|
||||
|
||||
def _detach_update_checks_for_dataset(session: Session, dataset_id: int) -> None:
    """Clear ``active_dataset_id`` on every update check that points at *dataset_id*."""
    referencing = select(SourceUpdateCheck).where(SourceUpdateCheck.active_dataset_id == dataset_id)
    linked_checks = session.scalars(referencing).all()
    for update_check in linked_checks:
        update_check.active_dataset_id = None
|
||||
|
||||
|
||||
def _count(session: Session, model, dataset_id: int) -> int:
    """Count the rows of *model* whose ``dataset_id`` equals *dataset_id*."""
    counting = select(func.count()).select_from(model).where(model.dataset_id == dataset_id)
    total = session.scalar(counting)
    return total or 0
|
||||
|
||||
|
||||
def _count_where(session: Session, model, dataset_id: int, *where) -> int:
    """Count the rows of *model* in the dataset that also satisfy the extra *where* criteria."""
    criteria = (model.dataset_id == dataset_id, *where)
    counting = select(func.count()).select_from(model).where(*criteria)
    total = session.scalar(counting)
    return total or 0
|
||||
|
||||
|
||||
def _count_dataset_rows(session: Session, model, dataset_ids: list[int]) -> int:
|
||||
if not dataset_ids:
|
||||
return 0
|
||||
return session.scalar(select(func.count()).select_from(model).where(model.dataset_id.in_(dataset_ids))) or 0
|
||||
|
||||
|
||||
def _remove_empty_dirs(root: Path) -> None:
|
||||
if not root.exists():
|
||||
return
|
||||
for path in sorted((p for p in root.rglob("*") if p.is_dir()), key=lambda p: len(p.parts), reverse=True):
|
||||
try:
|
||||
path.rmdir()
|
||||
except OSError:
|
||||
pass
|
||||
252
app/dataset_search.py
Normal file
252
app/dataset_search.py
Normal file
@@ -0,0 +1,252 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlalchemy import func, or_, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.gtfs_storage import execute_sidecar_query, uses_sidecar_stop_times
|
||||
from app.models import Dataset, GtfsRoute, GtfsShape, GtfsStopTime, GtfsTrip, OsmFeature, RoutePattern, Source
|
||||
from app.osm_storage import osm_feature_public_id, query_osm_features
|
||||
from app.pipeline.utils import norm_ref
|
||||
|
||||
|
||||
def search_datasets(session: Session, query: str, *, active_only: bool = False, limit: int = 80) -> dict:
    """Search GTFS routes, OSM routes and route patterns for *query*.

    A blank query yields empty result lists and empty totals. The row limit
    passed to each per-type search is clamped to the range 1..250.
    """
    q = (query or "").strip()
    if not q:
        return {"query": q, "gtfs_routes": [], "osm_routes": [], "route_patterns": [], "totals": {}}

    max_rows = max(1, min(limit, 250))
    hits = {
        "gtfs_routes": _gtfs_route_hits(session, q, active_only=active_only, limit=max_rows),
        "osm_routes": _osm_route_hits(session, q, active_only=active_only, limit=max_rows),
        "route_patterns": _route_pattern_hits(session, q, limit=max_rows),
    }
    totals = {group: len(found) for group, found in hits.items()}
    return {"query": q, **hits, "totals": totals}
|
||||
|
||||
|
||||
def _gtfs_route_hits(session: Session, query: str, *, active_only: bool, limit: int) -> list[dict]:
    """Return GTFS route hits for *query* with source, dataset, geometry and timetable counts.

    A route matches when its short name, route id or long name contains the
    query (case-insensitive LIKE) or when its ``route_key`` equals the
    normalised query.
    """
    like = f"%{query}%"
    wanted_key = norm_ref(query)
    text_match = or_(
        GtfsRoute.short_name.ilike(like),
        GtfsRoute.route_id.ilike(like),
        GtfsRoute.long_name.ilike(like),
        GtfsRoute.route_key == wanted_key,
    )
    stmt = (
        select(GtfsRoute, Dataset, Source)
        .join(Dataset, Dataset.id == GtfsRoute.dataset_id)
        .join(Source, Source.id == Dataset.source_id)
        .where(text_match)
    )
    if active_only:
        stmt = stmt.where(Dataset.is_active.is_(True))
    stmt = stmt.order_by(Dataset.is_active.desc(), Source.name, GtfsRoute.short_name, GtfsRoute.route_id).limit(limit)

    matched = session.execute(stmt).all()
    route_pks = [route.id for route, _, _ in matched]
    trips_by_route = _trip_counts(session, route_pks)
    stop_times_by_route = _stop_time_counts(session, route_pks)
    shapes_by_route = _shape_counts(session, route_pks)

    hits: list[dict] = []
    for route, dataset, source in matched:
        hits.append(
            {
                "type": "gtfs_route",
                "source": _source_payload(source),
                "dataset": _dataset_payload(dataset),
                "route": {
                    "id": route.id,
                    "route_id": route.route_id,
                    "ref": route.short_name,
                    "name": route.long_name,
                    "mode": route.mode,
                    "operator": route.operator_name,
                },
                "geometry": _geometry_payload(route),
                "timetable": {
                    "trips": trips_by_route.get(route.id, 0),
                    "stop_times": stop_times_by_route.get(route.id, 0),
                    "shapes": shapes_by_route.get(route.id, 0),
                },
            }
        )
    return hits
|
||||
|
||||
|
||||
def _osm_route_hits(session: Session, query: str, *, active_only: bool, limit: int) -> list[dict]:
    """Return OSM route hits for *query* across the ``osm_geojson`` datasets.

    Free-text matches and, when the normalised query is non-empty, route-key
    matches are merged by (dataset id, OSM type, OSM id), sorted with active
    datasets first and truncated to *limit*.
    """
    wanted_key = norm_ref(query)
    dataset_stmt = select(Dataset).where(Dataset.kind == "osm_geojson")
    if active_only:
        dataset_stmt = dataset_stmt.where(Dataset.is_active.is_(True))
    datasets = session.scalars(dataset_stmt.order_by(Dataset.is_active.desc(), Dataset.id)).all()
    if not datasets:
        return []

    dataset_ids = [dataset.id for dataset in datasets]
    source_rows = session.scalars(
        select(Source).where(Source.id.in_([dataset.source_id for dataset in datasets]))
    ).all()
    sources = {source.id: source for source in source_rows}
    dataset_by_id = {dataset.id: dataset for dataset in datasets}

    found: dict[tuple[int, str, str], OsmFeature] = {}

    def remember(batch) -> None:
        # Later batches overwrite earlier entries that share the same identity.
        for feature in batch:
            found[(feature.dataset_id, feature.osm_type, feature.osm_id)] = feature

    remember(query_osm_features(session, dataset_ids, kinds=["route"], search=query, limit=limit))
    if wanted_key:
        remember(query_osm_features(session, dataset_ids, kinds=["route"], route_key=wanted_key, limit=limit))

    def sort_key(feature):
        dataset = dataset_by_id.get(feature.dataset_id)
        source = sources.get(dataset.source_id) if dataset else None
        return (
            0 if dataset and dataset.is_active else 1,
            source.name if source else "",
            feature.ref or "",
            feature.name or "",
            feature.id or 0,
        )

    ranked = sorted(found.values(), key=sort_key)[:limit]

    hits: list[dict] = []
    for feature in ranked:
        dataset = dataset_by_id.get(feature.dataset_id)
        if dataset is None:
            continue
        source = sources.get(dataset.source_id)
        if source is None:
            continue
        hits.append(
            {
                "type": "osm_route",
                "source": _source_payload(source),
                "dataset": _dataset_payload(dataset),
                "osm": {
                    "id": osm_feature_public_id(feature),
                    "osm_type": feature.osm_type,
                    "osm_id": feature.osm_id,
                    "ref": feature.ref,
                    "name": feature.name,
                    "mode": feature.mode,
                    "route_scope": feature.route_scope,
                    "operator": feature.operator,
                    "network": feature.network,
                },
                "geometry": _geometry_payload(feature),
            }
        )
    return hits
|
||||
|
||||
|
||||
def _route_pattern_hits(session: Session, query: str, *, limit: int) -> list[dict]:
    """Return route-pattern hits for *query*.

    Candidates are fetched with a case-insensitive LIKE on ref, name and
    pattern key (at most *limit* rows) and are then filtered again in Python
    by normalised ref or by the query being a substring of the name.
    """
    like = f"%{query}%"
    wanted_key = norm_ref(query)
    text_match = or_(
        RoutePattern.route_ref.ilike(like),
        RoutePattern.route_name.ilike(like),
        RoutePattern.pattern_key.ilike(like),
    )
    stmt = (
        select(RoutePattern)
        .where(text_match)
        .order_by(RoutePattern.source_kind, RoutePattern.route_ref, RoutePattern.id)
        .limit(limit)
    )
    candidates = session.scalars(stmt).all()

    def keeps(candidate) -> bool:
        # Without a normalised ref there is nothing to narrow by.
        if not wanted_key:
            return True
        if norm_ref(candidate.route_ref or candidate.route_name or "") == wanted_key:
            return True
        return query.lower() in (candidate.route_name or "").lower()

    hits: list[dict] = []
    for candidate in candidates:
        if not keeps(candidate):
            continue
        hits.append(
            {
                "type": "route_pattern",
                "id": candidate.id,
                "ref": candidate.route_ref,
                "name": candidate.route_name,
                "mode": candidate.mode,
                "route_scope": candidate.route_scope,
                "source_kind": candidate.source_kind,
                "status": candidate.status,
                "confidence": candidate.confidence,
                "gtfs_route_id": candidate.gtfs_route_id,
                "osm_feature_id": candidate.osm_feature_id,
                "geometry": _geometry_payload(candidate),
            }
        )
    return hits
|
||||
|
||||
|
||||
def _trip_counts(session: Session, route_row_ids: list[int]) -> dict[int, int]:
|
||||
if not route_row_ids:
|
||||
return {}
|
||||
rows = session.execute(
|
||||
select(GtfsRoute.id, func.count(GtfsTrip.id))
|
||||
.join(GtfsTrip, (GtfsTrip.dataset_id == GtfsRoute.dataset_id) & (GtfsTrip.route_id == GtfsRoute.route_id))
|
||||
.where(GtfsRoute.id.in_(route_row_ids))
|
||||
.group_by(GtfsRoute.id)
|
||||
).all()
|
||||
return {int(route_id): int(count) for route_id, count in rows}
|
||||
|
||||
|
||||
def _stop_time_counts(session: Session, route_row_ids: list[int]) -> dict[int, int]:
|
||||
if not route_row_ids:
|
||||
return {}
|
||||
routes = session.scalars(select(GtfsRoute).where(GtfsRoute.id.in_(route_row_ids))).all()
|
||||
sidecar_routes = [route for route in routes if uses_sidecar_stop_times(session, route.dataset_id)]
|
||||
sidecar_route_ids = {route.id for route in sidecar_routes}
|
||||
main_route_ids = [route.id for route in routes if route.id not in sidecar_route_ids]
|
||||
counts: dict[int, int] = {}
|
||||
if main_route_ids:
|
||||
rows = session.execute(
|
||||
select(GtfsRoute.id, func.count(GtfsStopTime.id))
|
||||
.join(GtfsTrip, (GtfsTrip.dataset_id == GtfsRoute.dataset_id) & (GtfsTrip.route_id == GtfsRoute.route_id))
|
||||
.join(GtfsStopTime, (GtfsStopTime.dataset_id == GtfsTrip.dataset_id) & (GtfsStopTime.trip_id == GtfsTrip.trip_id))
|
||||
.where(GtfsRoute.id.in_(main_route_ids))
|
||||
.group_by(GtfsRoute.id)
|
||||
).all()
|
||||
counts.update({int(route_id): int(count) for route_id, count in rows})
|
||||
for route in sidecar_routes:
|
||||
rows = execute_sidecar_query(
|
||||
session,
|
||||
route.dataset_id,
|
||||
"""
|
||||
SELECT COUNT(*) AS count
|
||||
FROM gtfs_stop_times AS stop_times
|
||||
JOIN gtfs_trips AS trips
|
||||
ON trips.trip_id = stop_times.trip_id
|
||||
WHERE trips.route_id = ?
|
||||
""",
|
||||
[route.route_id],
|
||||
)
|
||||
counts[int(route.id)] = int(rows[0]["count"] or 0) if rows else 0
|
||||
return counts
|
||||
|
||||
|
||||
def _shape_counts(session: Session, route_row_ids: list[int]) -> dict[int, int]:
|
||||
if not route_row_ids:
|
||||
return {}
|
||||
rows = session.execute(
|
||||
select(GtfsRoute.id, func.count(func.distinct(GtfsShape.shape_id)))
|
||||
.join(GtfsTrip, (GtfsTrip.dataset_id == GtfsRoute.dataset_id) & (GtfsTrip.route_id == GtfsRoute.route_id))
|
||||
.join(GtfsShape, (GtfsShape.dataset_id == GtfsTrip.dataset_id) & (GtfsShape.shape_id == GtfsTrip.shape_id))
|
||||
.where(GtfsRoute.id.in_(route_row_ids))
|
||||
.group_by(GtfsRoute.id)
|
||||
).all()
|
||||
return {int(route_id): int(count) for route_id, count in rows}
|
||||
|
||||
|
||||
def _source_payload(source: Source) -> dict:
|
||||
return {"id": source.id, "name": source.name, "kind": source.kind, "country": source.country}
|
||||
|
||||
|
||||
def _dataset_payload(dataset: Dataset) -> dict:
|
||||
return {
|
||||
"id": dataset.id,
|
||||
"kind": dataset.kind,
|
||||
"is_active": dataset.is_active,
|
||||
"status": dataset.status,
|
||||
"created_at": dataset.created_at.isoformat() if dataset.created_at else None,
|
||||
"sha256": dataset.sha256,
|
||||
}
|
||||
|
||||
|
||||
def _geometry_payload(row) -> dict:
|
||||
bbox = None
|
||||
if all(getattr(row, attr, None) is not None for attr in ("min_lon", "min_lat", "max_lon", "max_lat")):
|
||||
bbox = {
|
||||
"min_lon": row.min_lon,
|
||||
"min_lat": row.min_lat,
|
||||
"max_lon": row.max_lon,
|
||||
"max_lat": row.max_lat,
|
||||
}
|
||||
return {"present": bool(getattr(row, "geometry_geojson", None)), "bbox": bbox}
|
||||
339
app/db.py
Normal file
339
app/db.py
Normal file
@@ -0,0 +1,339 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
import re
|
||||
from typing import Iterator
|
||||
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy import event
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.engine import Connection
|
||||
from sqlalchemy.orm import DeclarativeBase, Session, sessionmaker
|
||||
|
||||
from app.config import settings
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
    """Declarative base whose metadata this module uses to create and drop tables."""
|
||||
|
||||
|
||||
def _connect_args() -> dict:
    """Return the DBAPI connect arguments: SQLite-specific options, or nothing otherwise."""
    if not settings.is_sqlite_database:
        return {}
    sqlite_options = {
        "check_same_thread": False,
        "timeout": settings.sqlite_timeout_seconds,
    }
    return sqlite_options
|
||||
|
||||
|
||||
def _ensure_sqlite_parent() -> None:
    """Create the parent directory of the SQLite database file if needed.

    Does nothing for non-SQLite databases and for SQLite URLs without a file
    path (``sqlite://``) or pointing at ``:memory:``.
    """
    if not settings.is_sqlite_database:
        return
    # Local import: only this helper needs the URL parser.
    from sqlalchemy.engine import make_url

    # sqlite:///./data/workbench.sqlite -> ./data/workbench.sqlite
    # Parsing the URL (rather than stripping a literal "sqlite:///" prefix)
    # also copes with driver-qualified schemes such as "sqlite+pysqlite:///"
    # and with path-less URLs, which previously produced a bogus directory.
    database = make_url(settings.normalized_database_url).database
    if not database or database == ":memory:":
        return
    Path(database).parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
# The SQLite parent directory is ensured before the engine is created.
_ensure_sqlite_parent()
engine = create_engine(
    settings.normalized_database_url,
    connect_args=_connect_args(),
    pool_pre_ping=True,
    future=True,
)
SessionLocal = sessionmaker(
    bind=engine,
    autoflush=False,
    autocommit=False,
    expire_on_commit=False,
    future=True,
)
|
||||
|
||||
_CREATE_INDEX_NAME_RE = re.compile(
|
||||
r"CREATE\s+(?:UNIQUE\s+)?INDEX\s+(?:CONCURRENTLY\s+)?(?:IF\s+NOT\s+EXISTS\s+)?([A-Za-z_][A-Za-z0-9_]*)",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
|
||||
if settings.is_sqlite_database:

    @event.listens_for(engine, "connect")
    def _set_sqlite_pragmas(dbapi_connection, _connection_record) -> None:
        """Apply the SQLite PRAGMAs on every new DBAPI connection of the engine."""
        cursor = dbapi_connection.cursor()
        run = cursor.execute
        try:
            run("PRAGMA journal_mode=WAL")
            run(f"PRAGMA busy_timeout={int(settings.sqlite_busy_timeout_ms)}")
            run("PRAGMA synchronous=NORMAL")
            run("PRAGMA foreign_keys=ON")
        finally:
            # The cursor is closed even when a PRAGMA fails.
            cursor.close()
|
||||
|
||||
|
||||
def init_db() -> None:
    """Create extensions, tables, runtime columns and runtime indexes, in that order."""
    # Import models so metadata is populated.
    from app import models  # noqa: F401

    _ensure_database_extensions()
    Base.metadata.create_all(bind=engine)
    for ensure_step in (_ensure_runtime_columns, _ensure_runtime_indexes):
        ensure_step()
|
||||
|
||||
|
||||
def reset_db() -> None:
    """Drop every table known to the metadata and rebuild the schema from scratch."""
    # Import models so metadata is populated.
    from app import models  # noqa: F401

    _ensure_database_extensions()
    schema = Base.metadata
    schema.drop_all(bind=engine)
    schema.create_all(bind=engine)
    for ensure_step in (_ensure_runtime_columns, _ensure_runtime_indexes):
        ensure_step()
|
||||
|
||||
|
||||
def _ensure_database_extensions() -> None:
    """On PostgreSQL, create postgis and pg_trgm, plus pgrouting when it is available."""
    if settings.is_postgresql_database:
        with engine.begin() as connection:
            connection.execute(text("CREATE EXTENSION IF NOT EXISTS postgis"))
            connection.execute(text("CREATE EXTENSION IF NOT EXISTS pg_trgm"))
            # pgrouting is optional: only create it if the server offers it.
            pgrouting_probe = text("SELECT EXISTS (SELECT 1 FROM pg_available_extensions WHERE name = 'pgrouting')")
            if connection.execute(pgrouting_probe).scalar():
                connection.execute(text("CREATE EXTENSION IF NOT EXISTS pgrouting"))
|
||||
|
||||
|
||||
def _ensure_runtime_columns() -> None:
    """Add columns that may be missing from tables created by an older schema.

    PostgreSQL is delegated to ``_ensure_postgresql_runtime_columns``; SQLite
    is handled here via ``PRAGMA table_info`` and ``ALTER TABLE``. Any other
    database is left untouched.
    """
    if settings.is_postgresql_database:
        _ensure_postgresql_runtime_columns()
        return
    if not settings.is_sqlite_database:
        return

    with engine.begin() as conn:

        def add_missing(table: str, wanted: dict) -> None:
            # Read the table's current columns once, then add the absent ones.
            present = {row[1] for row in conn.execute(text(f"PRAGMA table_info({table})")).all()}
            for name, ddl_type in wanted.items():
                if name not in present:
                    conn.execute(text(f"ALTER TABLE {table} ADD COLUMN {name} {ddl_type}"))

        add_missing(
            "gtfs_stop_times",
            {
                "arrival_seconds": "INTEGER",
                "departure_seconds": "INTEGER",
            },
        )
        add_missing(
            "sources",
            {
                "catalog_entry_id": "INTEGER",
                "priority": "VARCHAR(16)",
                "mode_scope": "TEXT",
                "source_basis": "TEXT",
                "notes": "TEXT",
            },
        )
        add_missing(
            "jobs",
            {
                "priority": "INTEGER NOT NULL DEFAULT 0",
                "requested_action": "VARCHAR(32)",
                "lease_owner": "VARCHAR(255)",
                "lease_expires_at": "DATETIME",
                "paused_at": "DATETIME",
                "dismissed_at": "DATETIME",
            },
        )
        for route_table in ("gtfs_routes", "route_patterns", "osm_features"):
            add_missing(route_table, {"route_scope": "VARCHAR(64)"})
        add_missing("osm_addresses", {"geometry_geojson": "TEXT"})
|
||||
|
||||
|
||||
def _ensure_postgresql_runtime_columns() -> None:
    """Add missing runtime columns on PostgreSQL and widen ``osm_addresses.country`` to TEXT."""
    # (table, column, DDL that adds the column)
    wanted_columns = [
        ("osm_features", "geom", "ALTER TABLE osm_features ADD COLUMN geom geometry(Geometry, 4326)"),
        ("gtfs_routes", "geom", "ALTER TABLE gtfs_routes ADD COLUMN geom geometry(Geometry, 4326)"),
        ("gtfs_shapes", "geom", "ALTER TABLE gtfs_shapes ADD COLUMN geom geometry(Geometry, 4326)"),
        ("route_patterns", "geom", "ALTER TABLE route_patterns ADD COLUMN geom geometry(Geometry, 4326)"),
        ("osm_addresses", "geometry_geojson", "ALTER TABLE osm_addresses ADD COLUMN geometry_geojson TEXT"),
        ("osm_addresses", "geom", "ALTER TABLE osm_addresses ADD COLUMN geom geometry(Point, 4326)"),
        ("osm_addresses", "area_geom", "ALTER TABLE osm_addresses ADD COLUMN area_geom geometry(Geometry, 4326)"),
        ("gtfs_stops", "geom", "ALTER TABLE gtfs_stops ADD COLUMN geom geometry(Point, 4326)"),
        ("canonical_stops", "geom", "ALTER TABLE canonical_stops ADD COLUMN geom geometry(Point, 4326)"),
        ("routing_nodes", "geom", "ALTER TABLE routing_nodes ADD COLUMN geom geometry(Point, 4326)"),
        ("routing_edges", "geom", "ALTER TABLE routing_edges ADD COLUMN geom geometry(LineString, 4326)"),
    ]
    with engine.begin() as conn:
        # Snapshot of the columns that exist before anything is added.
        known = _postgresql_columns(conn)
        pending = [ddl for table, column, ddl in wanted_columns if (table, column) not in known]
        for ddl in pending:
            conn.execute(text(ddl))
        country = known.get(("osm_addresses", "country"))
        if country is None:
            return
        if country["data_type"] != "text":
            conn.execute(text("ALTER TABLE osm_addresses ALTER COLUMN country TYPE TEXT"))
|
||||
|
||||
|
||||
def _ensure_runtime_indexes() -> None:
    """Create the runtime indexes.

    On PostgreSQL the portable statements plus the PostgreSQL-only ones are
    run through ``_execute_missing_postgresql_indexes``; on every other
    database each portable statement is executed directly.
    """
    portable_indexes = [
        "CREATE INDEX IF NOT EXISTS ix_osm_features_map_bbox ON osm_features (dataset_id, kind, mode, min_lon, max_lon, min_lat, max_lat)",
        "CREATE INDEX IF NOT EXISTS ix_osm_features_scope_bbox ON osm_features (dataset_id, kind, mode, route_scope, min_lon, max_lon, min_lat, max_lat)",
        "CREATE INDEX IF NOT EXISTS ix_gtfs_routes_map_bbox ON gtfs_routes (dataset_id, mode, min_lon, max_lon, min_lat, max_lat)",
        "CREATE INDEX IF NOT EXISTS ix_gtfs_routes_scope_bbox ON gtfs_routes (dataset_id, mode, route_scope, min_lon, max_lon, min_lat, max_lat)",
        "CREATE INDEX IF NOT EXISTS ix_gtfs_stops_map_point ON gtfs_stops (dataset_id, lon, lat)",
        "CREATE INDEX IF NOT EXISTS ix_gtfs_stop_times_stop ON gtfs_stop_times (dataset_id, stop_id, departure_seconds, trip_id, stop_sequence)",
        "CREATE INDEX IF NOT EXISTS ix_gtfs_stop_times_stop_depart_trip ON gtfs_stop_times (dataset_id, stop_id, departure_seconds, trip_id)",
        "CREATE INDEX IF NOT EXISTS ix_gtfs_stop_times_stop_arrival ON gtfs_stop_times (dataset_id, stop_id, arrival_seconds, trip_id, stop_sequence)",
        "CREATE INDEX IF NOT EXISTS ix_gtfs_stop_times_stop_arrive_trip ON gtfs_stop_times (dataset_id, stop_id, arrival_seconds, trip_id)",
        "CREATE INDEX IF NOT EXISTS ix_gtfs_stop_times_trip_seq ON gtfs_stop_times (dataset_id, trip_id, stop_sequence)",
        "CREATE INDEX IF NOT EXISTS ix_gtfs_stop_times_trip_stop_seq ON gtfs_stop_times (dataset_id, trip_id, stop_id, stop_sequence)",
        "CREATE INDEX IF NOT EXISTS ix_gtfs_trips_dataset_trip ON gtfs_trips (dataset_id, trip_id)",
        "CREATE INDEX IF NOT EXISTS ix_gtfs_trips_dataset_route ON gtfs_trips (dataset_id, route_id)",
        "CREATE INDEX IF NOT EXISTS ix_gtfs_trips_dataset_service ON gtfs_trips (dataset_id, service_id, trip_id)",
        "CREATE INDEX IF NOT EXISTS ix_gtfs_trips_dataset_route_service ON gtfs_trips (dataset_id, route_id, service_id)",
        "CREATE INDEX IF NOT EXISTS ix_gtfs_routes_dataset_route ON gtfs_routes (dataset_id, route_id)",
        "CREATE INDEX IF NOT EXISTS ix_gtfs_shapes_dataset_shape ON gtfs_shapes (dataset_id, shape_id)",
        "CREATE INDEX IF NOT EXISTS ix_gtfs_calendars_dataset_service_dates ON gtfs_calendars (dataset_id, service_id, start_date, end_date)",
        "CREATE INDEX IF NOT EXISTS ix_gtfs_calendar_dates_dataset_date ON gtfs_calendar_dates (dataset_id, date, service_id, exception_type)",
        "CREATE INDEX IF NOT EXISTS ix_canonical_stop_links_object ON canonical_stop_links (object_type, dataset_id, object_id)",
        "CREATE INDEX IF NOT EXISTS ix_canonical_stop_links_external ON canonical_stop_links (object_type, dataset_id, external_id)",
        "CREATE INDEX IF NOT EXISTS ix_route_patterns_ref_mode ON route_patterns (route_ref, mode, source_kind)",
        "CREATE INDEX IF NOT EXISTS ix_route_patterns_bbox ON route_patterns (mode, min_lon, max_lon, min_lat, max_lat)",
        "CREATE INDEX IF NOT EXISTS ix_route_patterns_scope_bbox ON route_patterns (mode, route_scope, source_kind, min_lon, max_lon, min_lat, max_lat)",
        "CREATE INDEX IF NOT EXISTS ix_gtfs_route_pattern_links_trip_shape ON gtfs_route_pattern_links (dataset_id, route_id, shape_id)",
        "CREATE INDEX IF NOT EXISTS ix_gtfs_trip_route_pattern_links_trip ON gtfs_trip_route_pattern_links (dataset_id, trip_id)",
        "CREATE INDEX IF NOT EXISTS ix_gtfs_trip_route_pattern_links_pattern ON gtfs_trip_route_pattern_links (route_pattern_id, dataset_id, trip_id)",
        "CREATE INDEX IF NOT EXISTS ix_sources_catalog_entry ON sources (catalog_entry_id)",
        "CREATE INDEX IF NOT EXISTS ix_sources_priority_country_kind ON sources (priority, country, kind)",
        "CREATE INDEX IF NOT EXISTS ix_source_catalog_country_priority ON source_catalog_entries (country_code, priority, status)",
        "CREATE INDEX IF NOT EXISTS ix_source_catalog_name ON source_catalog_entries (source_name)",
        "CREATE INDEX IF NOT EXISTS ix_source_update_checks_source_checked ON source_update_checks (source_id, checked_at)",
        "CREATE INDEX IF NOT EXISTS ix_source_update_checks_available ON source_update_checks (source_id, update_available, checked_at)",
        "CREATE INDEX IF NOT EXISTS ix_osm_diff_states_source_sequence ON osm_diff_states (source_id, sequence_number)",
        "CREATE INDEX IF NOT EXISTS ix_osm_diff_states_source_status ON osm_diff_states (source_id, status, updated_at)",
        "CREATE INDEX IF NOT EXISTS ix_jobs_status_created ON jobs (status, created_at)",
        "CREATE INDEX IF NOT EXISTS ix_jobs_kind_status ON jobs (kind, status)",
        "CREATE INDEX IF NOT EXISTS ix_jobs_queue_claim ON jobs (status, priority, created_at, id)",
        "CREATE INDEX IF NOT EXISTS ix_jobs_lease ON jobs (status, lease_expires_at)",
        "CREATE INDEX IF NOT EXISTS ix_jobs_dismissed_status ON jobs (dismissed_at, status, created_at)",
        "CREATE INDEX IF NOT EXISTS ix_job_events_job_created ON job_events (job_id, created_at, id)",
        "CREATE INDEX IF NOT EXISTS ix_pipeline_runs_stage_dataset_hash ON pipeline_runs (stage, dataset_id, dependency_hash, status, started_at)",
        "CREATE INDEX IF NOT EXISTS ix_pipeline_runs_stage_source_hash ON pipeline_runs (stage, source_id, dependency_hash, status, started_at)",
        "CREATE INDEX IF NOT EXISTS ix_pipeline_runs_job ON pipeline_runs (job_id, stage, status)",
        "CREATE INDEX IF NOT EXISTS ix_match_rules_type_active ON match_rules (rule_type, active)",
        "CREATE INDEX IF NOT EXISTS ix_journey_search_cache_type_expires ON journey_search_cache (cache_type, expires_at)",
        "CREATE INDEX IF NOT EXISTS ix_travel_requests_created ON travel_requests (created_at)",
        "CREATE INDEX IF NOT EXISTS ix_itineraries_request_saved ON itineraries (request_id, saved, created_at)",
        "CREATE INDEX IF NOT EXISTS ix_itinerary_legs_itinerary_sequence ON itinerary_legs (itinerary_id, sequence)",
        "CREATE INDEX IF NOT EXISTS ix_routing_nodes_dataset_osm ON routing_nodes (dataset_id, osm_node_id)",
        "CREATE INDEX IF NOT EXISTS ix_routing_edges_dataset_source ON routing_edges (dataset_id, source_osm_node_id)",
        "CREATE INDEX IF NOT EXISTS ix_routing_edges_dataset_target ON routing_edges (dataset_id, target_osm_node_id)",
        "CREATE INDEX IF NOT EXISTS ix_routing_edges_dataset_drive ON routing_edges (dataset_id, source_osm_node_id) WHERE drive_cost_s IS NOT NULL",
        "CREATE INDEX IF NOT EXISTS ix_routing_edges_dataset_walk ON routing_edges (dataset_id, source_osm_node_id) WHERE walk_cost_s IS NOT NULL",
        "CREATE INDEX IF NOT EXISTS ix_routing_edges_dataset_reverse_drive ON routing_edges (dataset_id, target_osm_node_id) WHERE reverse_drive_cost_s IS NOT NULL",
        "CREATE INDEX IF NOT EXISTS ix_routing_edges_dataset_reverse_walk ON routing_edges (dataset_id, target_osm_node_id) WHERE reverse_walk_cost_s IS NOT NULL",
        "CREATE INDEX IF NOT EXISTS ix_routing_edges_bbox ON routing_edges (dataset_id, min_lon, max_lon, min_lat, max_lat)",
        "CREATE INDEX IF NOT EXISTS ix_osm_addresses_dataset_city_street ON osm_addresses (dataset_id, city, street, housenumber)",
        "CREATE INDEX IF NOT EXISTS ix_osm_addresses_dataset_postcode ON osm_addresses (dataset_id, postcode)",
        "CREATE INDEX IF NOT EXISTS ix_osm_addresses_bbox ON osm_addresses (dataset_id, min_lon, max_lon, min_lat, max_lat)",
    ]
    with engine.begin() as conn:
        if settings.is_sqlite_database:
            conn.execute(text("PRAGMA journal_mode=WAL"))
            conn.execute(text(f"PRAGMA busy_timeout={int(settings.sqlite_busy_timeout_ms)}"))
        if settings.is_postgresql_database:
            # PostgreSQL skips statements whose index already exists.
            _execute_missing_postgresql_indexes(conn, portable_indexes + _postgresql_index_statements())
            return
        for ddl in portable_indexes:
            conn.execute(text(ddl))
|
||||
|
||||
|
||||
def _postgresql_columns(conn: Connection) -> dict[tuple[str, str], dict[str, str]]:
    """Map ``(table_name, column_name)`` to the column's ``data_type`` and ``udt_name``.

    Covers the columns of every schema returned by ``current_schemas(false)``.
    """
    listing = text(
        """
        SELECT table_name, column_name, data_type, udt_name
        FROM information_schema.columns
        WHERE table_schema = ANY (current_schemas(false))
        """
    )
    described: dict[tuple[str, str], dict[str, str]] = {}
    for record in conn.execute(listing).mappings():
        identity = (str(record["table_name"]), str(record["column_name"]))
        described[identity] = {
            "data_type": str(record["data_type"]),
            "udt_name": str(record["udt_name"]),
        }
    return described
|
||||
|
||||
|
||||
def _execute_missing_postgresql_indexes(conn: Connection, statements: list[str]) -> None:
    """Run each CREATE INDEX statement unless its index name is already known.

    Statements whose index name cannot be extracted are always executed.
    """
    known = _postgresql_index_names(conn)
    for ddl in statements:
        name = _index_name_from_create_statement(ddl)
        already_present = bool(name) and name in known
        if not already_present:
            conn.execute(text(ddl))
            if name:
                # Remember it so a repeated statement in the list is skipped.
                known.add(name)
|
||||
|
||||
|
||||
def _postgresql_index_names(conn: Connection) -> set[str]:
    """Return the names of all indexes in the schemas given by ``current_schemas(false)``."""
    listing = text(
        """
        SELECT indexname
        FROM pg_indexes
        WHERE schemaname = ANY (current_schemas(false))
        """
    )
    names: set[str] = set()
    for record in conn.execute(listing):
        names.add(str(record[0]))
    return names
|
||||
|
||||
|
||||
def _index_name_from_create_statement(statement: str) -> str | None:
    """Return the index name found by ``_CREATE_INDEX_NAME_RE`` in *statement*, or None."""
    found = _CREATE_INDEX_NAME_RE.search(statement)
    if found is None:
        return None
    return found.group(1)
|
||||
|
||||
|
||||
def _postgresql_index_statements() -> list[str]:
    """Return the PostgreSQL-only ``CREATE INDEX IF NOT EXISTS`` statements.

    The statements are grouped for readability only; the returned list keeps
    them in one fixed order (GiST indexes, plain/expression lookup indexes, then
    the name/address search indexes).
    """
    gist_indexes = [
        "CREATE INDEX IF NOT EXISTS ix_osm_features_geom_gist ON osm_features USING GIST (geom)",
        "CREATE INDEX IF NOT EXISTS ix_osm_features_stop_geom_gist ON osm_features USING GIST (geom) WHERE kind IN ('stop', 'station', 'terminal')",
        "CREATE INDEX IF NOT EXISTS ix_osm_features_route_geom_gist ON osm_features USING GIST (geom) WHERE kind = 'route'",
        "CREATE INDEX IF NOT EXISTS ix_gtfs_stops_geom_gist ON gtfs_stops USING GIST (geom)",
        "CREATE INDEX IF NOT EXISTS ix_canonical_stops_geom_gist ON canonical_stops USING GIST (geom)",
        "CREATE INDEX IF NOT EXISTS ix_gtfs_routes_geom_gist ON gtfs_routes USING GIST (geom)",
        "CREATE INDEX IF NOT EXISTS ix_gtfs_shapes_geom_gist ON gtfs_shapes USING GIST (geom)",
        "CREATE INDEX IF NOT EXISTS ix_route_patterns_geom_gist ON route_patterns USING GIST (geom)",
        "CREATE INDEX IF NOT EXISTS ix_osm_addresses_geom_gist ON osm_addresses USING GIST (geom)",
        "CREATE INDEX IF NOT EXISTS ix_osm_addresses_area_geom_gist ON osm_addresses USING GIST (area_geom)",
        "CREATE INDEX IF NOT EXISTS ix_routing_nodes_geom_gist ON routing_nodes USING GIST (geom)",
        "CREATE INDEX IF NOT EXISTS ix_routing_edges_bbox_box_gist ON routing_edges USING GIST (box(point(max_lon, max_lat), point(min_lon, min_lat)))",
    ]
    lookup_indexes = [
        "CREATE INDEX IF NOT EXISTS ix_gtfs_trips_dataset_route_shape_expr ON gtfs_trips (dataset_id, route_id, (COALESCE(shape_id, '__route__')))",
        "CREATE INDEX IF NOT EXISTS ix_gtfs_stop_times_dataset_stop ON gtfs_stop_times (dataset_id, stop_id)",
        "CREATE INDEX IF NOT EXISTS ix_canonical_stop_links_gtfs_external ON canonical_stop_links (dataset_id, external_id, canonical_stop_id) WHERE object_type = 'gtfs_stop'",
        "CREATE INDEX IF NOT EXISTS ix_gtfs_stops_dataset_parent ON gtfs_stops (dataset_id, parent_station)",
        "CREATE INDEX IF NOT EXISTS ix_gtfs_stops_dataset_stop_prefix ON gtfs_stops (dataset_id, (split_part(stop_id, '::', 1)))",
    ]
    search_indexes = [
        "CREATE INDEX IF NOT EXISTS ix_osm_features_name_trgm ON osm_features USING GIN (LOWER(COALESCE(name, '')) gin_trgm_ops)",
        "CREATE INDEX IF NOT EXISTS ix_osm_features_ref_trgm ON osm_features USING GIN (LOWER(COALESCE(ref, '')) gin_trgm_ops)",
        "CREATE INDEX IF NOT EXISTS ix_osm_features_tags_trgm ON osm_features USING GIN (LOWER(COALESCE(tags_json, '')) gin_trgm_ops)",
        "CREATE INDEX IF NOT EXISTS ix_osm_addresses_search_trgm ON osm_addresses USING GIN (LOWER(COALESCE(search_text, '')) gin_trgm_ops)",
        "CREATE INDEX IF NOT EXISTS ix_osm_addresses_display_trgm ON osm_addresses USING GIN (LOWER(COALESCE(display_name, '')) gin_trgm_ops)",
        "CREATE INDEX IF NOT EXISTS ix_osm_addresses_street_key_house ON osm_addresses (dataset_id, REPLACE(LOWER(COALESCE(NULLIF(street, ''), NULLIF(place, ''), '')), 'ß', 'ss'), housenumber)",
        "CREATE INDEX IF NOT EXISTS ix_osm_addresses_street_key_trgm ON osm_addresses USING GIN (REPLACE(LOWER(COALESCE(NULLIF(street, ''), NULLIF(place, ''), '')), 'ß', 'ss') gin_trgm_ops)",
        "CREATE INDEX IF NOT EXISTS ix_gtfs_stops_name_trgm ON gtfs_stops USING GIN (name gin_trgm_ops)",
        "CREATE INDEX IF NOT EXISTS ix_gtfs_stops_stop_id_trgm ON gtfs_stops USING GIN (stop_id gin_trgm_ops)",
        "CREATE INDEX IF NOT EXISTS ix_route_patterns_ref_trgm ON route_patterns USING GIN (LOWER(COALESCE(route_ref, '')) gin_trgm_ops)",
        "CREATE INDEX IF NOT EXISTS ix_route_patterns_name_trgm ON route_patterns USING GIN (LOWER(COALESCE(route_name, '')) gin_trgm_ops)",
    ]
    return [*gist_indexes, *lookup_indexes, *search_indexes]
|
||||
|
||||
|
||||
def get_db() -> Iterator[Session]:
    """Yield a fresh ``SessionLocal`` session and always close it afterwards.

    No commit or rollback is issued here; the session is only closed.
    """
    session = SessionLocal()
    try:
        yield session
    finally:
        session.close()
|
||||
|
||||
|
||||
@contextmanager
def session_scope() -> Iterator[Session]:
    """Context manager around a ``SessionLocal`` session.

    Commits when the managed block finishes normally. On any ``Exception``
    (including one raised by the commit itself) the session is rolled back and
    the exception re-raised. The session is closed in every case.
    """
    session = SessionLocal()
    try:
        yield session
        session.commit()
    except Exception:
        session.rollback()
        raise
    finally:
        session.close()
|
||||
211
app/db_lock.py
Normal file
211
app/db_lock.py
Normal file
@@ -0,0 +1,211 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
import threading
|
||||
import time
|
||||
from typing import Iterator
|
||||
|
||||
from app.config import settings
|
||||
|
||||
try:
|
||||
import fcntl
|
||||
except ImportError: # pragma: no cover - this app currently targets Linux/macOS dev hosts
|
||||
fcntl = None # type: ignore[assignment]
|
||||
|
||||
|
||||
class DatabaseWriteBusy(RuntimeError):
    """Raised when the database write lock could not be obtained.

    Attributes:
        operation: Label of the operation that wanted the lock.
        active: Metadata describing the current lock holder; an empty dict when
            nothing is known. Its ``"operation"`` entry, when truthy, is
            appended to the exception message.
    """

    def __init__(self, operation: str, active: dict[str, object] | None = None) -> None:
        self.operation = operation
        self.active = active or {}
        active_operation = self.active.get("operation")
        # Plain literal: the original used an f-string with no placeholders.
        detail = "Database is busy with another write operation"
        if active_operation:
            detail += f": {active_operation}"
        super().__init__(detail)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class DatabaseWriteState:
    """Immutable snapshot of the write lock as seen by this process."""

    # True while a write operation holds the lock.
    locked: bool
    # Label of the operation holding the lock, if any.
    operation: str | None = None
    # Process id recorded for the holder, if any.
    pid: int | None = None
    # ``time.time()`` timestamp taken when the lock was obtained, if any.
    started_at: float | None = None

    @property
    def elapsed_seconds(self) -> float | None:
        """Seconds since ``started_at``, never negative; ``None`` if unset."""
        began = self.started_at
        if began is not None:
            return max(0.0, time.time() - began)
        return None
|
||||
|
||||
|
||||
# Serializes writers that run as threads inside this process.
_process_write_lock = threading.Lock()
# Guards every read and replacement of the ``_state`` snapshot below.
_state_lock = threading.Lock()
# Most recently published write state for this process; starts unlocked.
_state = DatabaseWriteState(locked=False)
|
||||
|
||||
|
||||
def is_sqlite_database() -> bool:
    """Report whether the configured database is SQLite, per ``settings``."""
    uses_sqlite = settings.is_sqlite_database
    return uses_sqlite
|
||||
|
||||
|
||||
@contextmanager
def database_write_lock(operation: str, timeout: float | None = None) -> Iterator[None]:
    """Serialize SQLite writes inside and across app processes.

    SQLite allows only one writer. This lock prevents mutating endpoints from
    competing until SQLite times out with a low-level "database is locked" error.

    For non-SQLite databases the context manager does nothing. Otherwise it
    takes the in-process lock first, then an advisory lock on the lock file,
    writes holder metadata into that file and publishes the in-process state.

    Args:
        operation: Label stored in the lock metadata and the published state.
        timeout: Seconds to wait for the lock. ``None`` falls back to
            ``settings.database_write_lock_timeout_seconds``; if that is also
            ``None`` the wait has no deadline.

    Raises:
        DatabaseWriteBusy: If either lock could not be obtained in time.
    """
    if not is_sqlite_database():
        yield
        return

    effective_timeout = settings.database_write_lock_timeout_seconds if timeout is None else timeout
    deadline = None if effective_timeout is None else time.monotonic() + max(0.0, effective_timeout)
    if not _acquire_process_lock(deadline):
        raise DatabaseWriteBusy(operation, database_write_status().__dict__)

    handle = None
    file_locked = False
    try:
        lock_path = _lock_path()
        lock_path.parent.mkdir(parents=True, exist_ok=True)
        handle = _open_locked_handle(lock_path, deadline)
        if handle is None:
            raise DatabaseWriteBusy(operation, _read_lock_metadata(lock_path))
        file_locked = True
        _write_lock_metadata(handle, operation)
        _set_state(DatabaseWriteState(locked=True, operation=operation, pid=os.getpid(), started_at=time.time()))
        yield
    finally:
        # The inner try/finally guarantees the process lock is released even if
        # file cleanup raises; otherwise every later writer would block.
        try:
            _set_state(DatabaseWriteState(locked=False))
            if handle is not None:
                if file_locked and fcntl is not None:
                    # Remove the path while the advisory lock is still held, so a
                    # process opening the path afterwards gets a fresh file rather
                    # than one that is about to be unlinked from under it.
                    try:
                        _lock_path().unlink()
                    except OSError:
                        pass
                    try:
                        fcntl.flock(handle.fileno(), fcntl.LOCK_UN)
                    except OSError:
                        pass
                handle.close()
                if file_locked and fcntl is None:
                    # Without fcntl there is no advisory lock to order against;
                    # keep the original close-then-unlink sequence.
                    try:
                        _lock_path().unlink()
                    except OSError:
                        pass
        finally:
            _process_write_lock.release()
|
||||
|
||||
|
||||
def database_write_status() -> DatabaseWriteState:
    """Return the currently published write state, read under ``_state_lock``."""
    with _state_lock:
        snapshot = _state
    return snapshot
|
||||
|
||||
|
||||
def _acquire_process_lock(deadline: float | None) -> bool:
    """Acquire the in-process write lock, waiting until ``deadline`` at most.

    Args:
        deadline: ``time.monotonic()`` value after which to give up, or ``None``
            to wait indefinitely.

    Returns:
        ``True`` once the lock is held, ``False`` if the deadline passed first.
        A deadline already in the past still gets one non-blocking attempt.
    """
    if deadline is None:
        return _process_write_lock.acquire()
    remaining = deadline - time.monotonic()
    if remaining <= 0:
        return _process_write_lock.acquire(blocking=False)
    # Let the lock wait natively instead of polling every 50 ms; the clamp
    # avoids OverflowError for very large configured timeouts.
    return _process_write_lock.acquire(timeout=min(remaining, threading.TIMEOUT_MAX))
|
||||
|
||||
|
||||
def _acquire_file_lock(handle, deadline: float | None) -> bool:
    """Take an exclusive ``flock`` on ``handle``, retrying until ``deadline``.

    Returns ``True`` immediately when ``fcntl`` is unavailable. Otherwise a
    non-blocking lock is attempted every 50 ms; ``False`` is returned once an
    attempt fails at or after the deadline.
    """
    if fcntl is None:
        return True
    locked = False
    while not locked:
        try:
            fcntl.flock(handle.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB)
        except BlockingIOError:
            out_of_time = deadline is not None and time.monotonic() >= deadline
            if out_of_time:
                return False
            time.sleep(0.05)
        else:
            locked = True
    return True
|
||||
|
||||
|
||||
def _open_locked_handle(lock_path: Path, deadline: float | None):
    """Open ``lock_path`` and return the handle once it is exclusively locked.

    Args:
        lock_path: Location of the lock file; parent directories are created.
        deadline: ``time.monotonic()`` value after which to give up, or ``None``
            to keep trying.

    Returns:
        The open, locked file handle, or ``None`` if the deadline passed or a
        stale lock file could not be removed.
    """

    def expired() -> bool:
        return deadline is not None and time.monotonic() >= deadline

    while True:
        try:
            lock_path.parent.mkdir(parents=True, exist_ok=True)
            handle = lock_path.open("a+", encoding="utf-8")
        except FileNotFoundError:
            if expired():
                return None
            time.sleep(0.05)
            continue
        try:
            locked = _try_file_lock(handle)
        except BaseException:
            # Do not leak the descriptor if flock fails unexpectedly.
            handle.close()
            raise
        if locked:
            if fcntl is None:
                return handle
            # The path may have been unlinked (and possibly recreated) between
            # our open() and flock(). Holding a lock on a file that is no longer
            # at lock_path excludes nobody, so only accept the handle if it
            # still refers to the file currently at that path.
            try:
                current = os.path.samestat(os.fstat(handle.fileno()), os.stat(lock_path))
            except OSError:
                current = False
            if current:
                return handle
            handle.close()
            if expired():
                return None
            continue
        metadata = _read_lock_metadata(lock_path)
        handle.close()
        if not _lock_metadata_is_stale(metadata):
            if expired():
                return None
            time.sleep(0.05)
            continue
        # NOTE(review): a failed flock means some process currently holds the
        # lock, while the metadata may still be a previous holder's; confirm
        # that removing the file on a dead recorded PID is still wanted.
        try:
            lock_path.unlink()
        except FileNotFoundError:
            pass
        except OSError:
            return None
        if expired():
            return None
|
||||
|
||||
|
||||
def _try_file_lock(handle) -> bool:
    """Make one non-blocking attempt at an exclusive ``flock`` on ``handle``.

    Returns ``True`` when the lock was taken or ``fcntl`` is unavailable, and
    ``False`` when the attempt raised ``BlockingIOError``.
    """
    if fcntl is not None:
        try:
            fcntl.flock(handle.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB)
        except BlockingIOError:
            return False
    return True
|
||||
|
||||
|
||||
def _lock_metadata_is_stale(metadata: dict[str, object]) -> bool:
    """Decide whether lock metadata names a process that no longer exists.

    Metadata without a usable ``"pid"``, with a non-positive pid, or with this
    process's own pid is never stale. Otherwise it is stale exactly when
    ``_pid_exists`` reports the pid as gone.
    """
    try:
        owner_pid = int(metadata.get("pid"))  # type: ignore[arg-type]
    except (TypeError, ValueError):
        return False
    belongs_to_other_process = owner_pid > 0 and owner_pid != os.getpid()
    return belongs_to_other_process and not _pid_exists(owner_pid)
|
||||
|
||||
|
||||
def _pid_exists(pid: int) -> bool:
    """Probe ``pid`` with signal 0 and report whether a process answers.

    ``ProcessLookupError`` means no such process. ``PermissionError`` is
    treated as "exists". Any other exception propagates.
    """
    try:
        os.kill(pid, 0)
    except ProcessLookupError:
        alive = False
    except PermissionError:
        alive = True
    else:
        alive = True
    return alive
|
||||
|
||||
|
||||
def _set_state(state: DatabaseWriteState) -> None:
    """Replace the module-level ``_state`` snapshot while holding ``_state_lock``."""
    global _state
    new_state = state
    with _state_lock:
        _state = new_state
|
||||
|
||||
|
||||
def _lock_path() -> Path:
    """Return the lock file path, ``workbench.write.lock`` inside ``settings.data_dir``."""
    data_dir = settings.data_dir
    return data_dir / "workbench.write.lock"
|
||||
|
||||
|
||||
def _write_lock_metadata(handle, operation: str) -> None:
    """Overwrite the lock file with compact JSON describing the holder.

    The payload has the keys ``operation``, ``pid`` (this process) and
    ``started_at`` (``time.time()``). The file is truncated first, then flushed
    and fsynced after writing.
    """
    handle.seek(0)
    handle.truncate()
    payload = {"operation": operation, "pid": os.getpid(), "started_at": time.time()}
    handle.write(json.dumps(payload, separators=(",", ":")))
    handle.flush()
    os.fsync(handle.fileno())
|
||||
|
||||
|
||||
def _read_lock_metadata(path: Path) -> dict[str, object]:
    """Read the holder metadata stored in the lock file at ``path``.

    Returns:
        The decoded JSON object, or an empty dict when the file is missing,
        unreadable, empty, not valid UTF-8, not valid JSON, or holds JSON that
        is not an object.
    """
    try:
        raw = path.read_text(encoding="utf-8").strip()
        parsed = json.loads(raw) if raw else {}
    except (OSError, ValueError):
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError; a
        # half-written or corrupt lock file must read as "no metadata".
        return {}
    # Callers use ``.get`` on the result, so never hand back a list or scalar.
    return parsed if isinstance(parsed, dict) else {}
|
||||
923
app/feed_discovery.py
Normal file
923
app/feed_discovery.py
Normal file
@@ -0,0 +1,923 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import csv
|
||||
import hashlib
|
||||
import json
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from html import unescape
|
||||
from html.parser import HTMLParser
|
||||
from pathlib import Path
|
||||
from typing import Iterable
|
||||
from urllib.parse import parse_qs, urljoin, urlparse
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
# Remote catalogs and pages read during discovery.
MOBILITY_DATABASE_FEEDS_URL = "https://files.mobilitydatabase.org/feeds_v2.csv"
MOBILITY_DATABASE_ACCEPTANCE_TEST_URL = (
    "https://raw.githubusercontent.com/MobilityData/gtfs-validator/master/"
    "scripts/mobility-database-harvester/acceptance_test_feed_list.csv"
)
PTNA_GTFS_INDEX_URL = "https://ptna.openstreetmap.de/gtfs/index.html"
PTNA_COUNTRY_URL_TEMPLATE = "https://ptna.openstreetmap.de/gtfs/{country}/index.php"

# Country codes used when the caller passes none, and the order in which
# countries are visited when picking test-run feeds.
DEFAULT_DISCOVERY_COUNTRIES = ["DE", "AT", "CH", "NL", "DK", "FR", "BE", "LU", "NO", "SE", "FI", "IE", "GB"]
CURATED_TEST_COUNTRIES = ["DE", "CH", "AT", "NL", "DK", "FI", "NO", "SE", "IE", "GB", "FR", "BE", "LU"]

# Column order of the ingestable / test-run source CSV files.
DIRECT_INGEST_HEADERS = ["name", "kind", "url", "country", "license", "mode_scope", "source_basis", "priority", "notes"]

# Column order of the full candidate CSV file.
CANONICAL_HEADERS = [
    "candidate_id",
    "discovery_source",
    "country",
    "subdivision",
    "provider",
    "feed_name",
    "stable_id",
    "ptna_feed_id",
    "data_type",
    "status",
    "is_official",
    "selected_url",
    "direct_download_url",
    "latest_url",
    "original_release_url",
    "license_url",
    "license_text",
    "osm_license_text",
    "details_url",
    "routes_url",
    "valid_from",
    "valid_to",
    "release_date",
    "feed_version",
    "bbox",
    "features",
    "priority",
    "availability_status",
    "http_status",
    "content_type",
    "content_length",
    "final_url",
    "source_basis",
    "notes",
]
|
||||
|
||||
|
||||
@dataclass
class FeedCandidate:
    """One GTFS feed found during discovery.

    Apart from ``evidence_sources`` every attribute is a string, empty when
    unknown. The attribute names mirror the columns in ``CANONICAL_HEADERS``,
    which ``row()`` uses to serialize a candidate.
    """

    # Origin of the record, e.g. "mobility_database", "ptna" or "curated_seed".
    discovery_source: str
    country: str = ""
    subdivision: str = ""
    provider: str = ""
    feed_name: str = ""
    stable_id: str = ""
    ptna_feed_id: str = ""
    data_type: str = "gtfs"
    status: str = ""
    is_official: str = ""
    # URL chosen for ingestion.
    selected_url: str = ""
    direct_download_url: str = ""
    latest_url: str = ""
    original_release_url: str = ""
    license_url: str = ""
    license_text: str = ""
    osm_license_text: str = ""
    details_url: str = ""
    routes_url: str = ""
    valid_from: str = ""
    valid_to: str = ""
    release_date: str = ""
    feed_version: str = ""
    bbox: str = ""
    features: str = ""
    priority: str = ""
    availability_status: str = "unchecked"
    http_status: str = ""
    content_type: str = ""
    content_length: str = ""
    final_url: str = ""
    source_basis: str = ""
    notes: str = ""
    evidence_sources: list[str] = field(default_factory=list)

    def key(self) -> str:
        """Return an identity key for this candidate.

        Preference order: stable id, normalized selected URL, PTNA feed id, and
        finally a SHA-256 hash of the serialized row.
        """
        if self.stable_id:
            return f"stable:{self.stable_id}"
        if self.selected_url:
            return f"url:{_normalize_url_key(self.selected_url)}"
        if self.ptna_feed_id:
            return f"ptna:{self.ptna_feed_id}"
        serialized = json.dumps(self.row(), sort_keys=True)
        digest = hashlib.sha256(serialized.encode("utf-8")).hexdigest()
        return "hash:" + digest

    def candidate_id(self) -> str:
        """Return a 16-hex-character id derived from the identifying fields."""
        identifying_fields = (
            self.discovery_source,
            self.country,
            self.stable_id,
            self.ptna_feed_id,
            self.selected_url,
            self.provider,
            self.feed_name,
        )
        seed = "|".join(identifying_fields)
        return hashlib.sha256(seed.encode("utf-8")).hexdigest()[:16]

    def row(self) -> dict[str, str]:
        """Serialize the candidate into a ``CANONICAL_HEADERS``-keyed dict."""
        payload: dict[str, str] = {}
        for header in CANONICAL_HEADERS:
            if header == "candidate_id":
                continue
            payload[header] = _string(getattr(self, header, ""))
        # The id is computed rather than stored, so it is added afterwards.
        payload["candidate_id"] = self.candidate_id()
        return payload

    def ingestable_row(self) -> dict[str, str]:
        """Serialize the candidate into a ``DIRECT_INGEST_HEADERS``-style dict."""
        name = _feed_source_name(self.country, self.provider or self.feed_name)

        if self.license_text:
            license_value = self.license_text
        elif self.license_url:
            license_value = f"see {self.license_url}"
        else:
            license_value = ""

        basis_parts = [self.source_basis or self.discovery_source]
        if self.details_url:
            basis_parts.append(f"details: {self.details_url}")
        has_distinct_release = self.original_release_url and self.original_release_url != self.selected_url
        if has_distinct_release:
            basis_parts.append(f"release: {self.original_release_url}")
        source_basis = "; ".join(filter(None, basis_parts))

        notes = self.notes or ""
        if self.latest_url and self.latest_url != self.selected_url:
            notes = _join_notes(notes, f"Mobility Database mirror: {self.latest_url}")
        if self.osm_license_text:
            notes = _join_notes(notes, f"OSM permission note: {_truncate(self.osm_license_text, 240)}")

        return {
            "name": _truncate(name, 240),
            "kind": "gtfs",
            "url": self.selected_url,
            "country": self.country,
            "license": _truncate(license_value, 240),
            "mode_scope": _mode_scope_from_features(self.features),
            "source_basis": _truncate(source_basis, 500),
            "priority": self.priority or _candidate_priority(self),
            "notes": _truncate(notes, 1200),
        }
|
||||
|
||||
|
||||
def default_generated_dir() -> Path:
    """Return ``docs/generated`` under the directory two levels above this file."""
    project_root = Path(__file__).resolve().parents[1]
    return project_root / "docs" / "generated"
|
||||
|
||||
|
||||
def build_gtfs_discovery_manifests(
    *,
    output_dir: Path | str | None = None,
    countries: Iterable[str] | None = None,
    include_mobility_database: bool = True,
    include_acceptance_test_list: bool = True,
    include_ptna: bool = True,
    max_ptna_details: int = 80,
    test_limit: int = 24,
    check_urls: bool = False,
    timeout: float = 30.0,
) -> dict[str, object]:
    """Collect GTFS feed candidates and write the discovery manifests.

    Candidates come from the curated seed plus each enabled remote source. They
    are merged, filtered to GTFS entries that have a selected URL, optionally
    probed for availability, and written as three CSV files and one JSON report
    into ``output_dir`` (``default_generated_dir()`` when omitted).

    Args:
        output_dir: Destination directory; created if missing.
        countries: Country filter, normalized by ``_normalize_countries``.
        include_mobility_database: Fetch the Mobility Database catalog.
        include_acceptance_test_list: Fetch the validator acceptance-test list.
        include_ptna: Fetch PTNA country pages.
        max_ptna_details: Forwarded to ``fetch_ptna_candidates`` as ``max_details``.
        test_limit: Maximum number of feeds in the test-run file.
        check_urls: Probe each ingestable URL (per-request timeout capped at 12 s).
        timeout: Network timeout in seconds for the fetches.

    Returns:
        The report dict that is also written to ``gtfs_discovery_report.json``.
    """
    selected_countries = _normalize_countries(countries)
    out_dir = default_generated_dir() if output_dir is None else Path(output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    discovered: list[FeedCandidate] = list(load_curated_ingestable_seed(countries=selected_countries))
    if include_mobility_database:
        discovered += fetch_mobility_database_candidates(countries=selected_countries, timeout=timeout)
    if include_acceptance_test_list:
        discovered += fetch_mobility_acceptance_candidates(countries=selected_countries, timeout=timeout)
    if include_ptna:
        discovered += fetch_ptna_candidates(countries=selected_countries, max_details=max_ptna_details, timeout=timeout)

    merged = merge_candidates(discovered)
    ingestable = [item for item in merged if item.selected_url and item.data_type == "gtfs"]
    if check_urls:
        probe_timeout = min(timeout, 12.0)
        for item in ingestable:
            annotate_url_availability(item, timeout=probe_timeout)
    test_run = select_test_run_candidates(ingestable, limit=test_limit)

    candidates_path = out_dir / "gtfs_feed_candidates.csv"
    ingestable_path = out_dir / "gtfs_ingestable_sources.csv"
    test_path = out_dir / "gtfs_test_run_sources.csv"
    report_path = out_dir / "gtfs_discovery_report.json"

    _write_csv(candidates_path, CANONICAL_HEADERS, [item.row() for item in merged])
    _write_csv(ingestable_path, DIRECT_INGEST_HEADERS, [item.ingestable_row() for item in ingestable])
    _write_csv(test_path, DIRECT_INGEST_HEADERS, [item.ingestable_row() for item in test_run])

    by_source = _count_by(merged, lambda item: item.discovery_source)
    by_country = _count_by(ingestable, lambda item: item.country or "unknown")

    source_urls = {
        "mobility_database": MOBILITY_DATABASE_FEEDS_URL if include_mobility_database else None,
        "mobility_acceptance_test_list": MOBILITY_DATABASE_ACCEPTANCE_TEST_URL if include_acceptance_test_list else None,
        "ptna": PTNA_GTFS_INDEX_URL if include_ptna else None,
    }
    counts = {
        "candidates": len(merged),
        "ingestable": len(ingestable),
        "test_run": len(test_run),
        "by_source": by_source,
        "ingestable_by_country": by_country,
    }
    files = {
        "candidates": str(candidates_path),
        "ingestable": str(ingestable_path),
        "test_run": str(test_path),
    }
    report = {
        "generated_at": datetime.now(timezone.utc).isoformat(),
        "countries": selected_countries or "all",
        "sources": source_urls,
        "counts": counts,
        "files": files,
    }
    serialized = json.dumps(report, indent=2, ensure_ascii=False) + "\n"
    report_path.write_text(serialized, encoding="utf-8")
    return report
|
||||
|
||||
|
||||
def fetch_mobility_database_candidates(
    *,
    countries: list[str] | None = None,
    timeout: float = 30.0,
    url: str = MOBILITY_DATABASE_FEEDS_URL,
) -> list[FeedCandidate]:
    """Download the Mobility Database feed catalog and build GTFS candidates.

    Rows whose ``data_type`` is not ``gtfs`` are skipped, as are rows whose
    country code is not in ``countries`` when a filter is given.

    Args:
        countries: Upper-case country codes to keep; falsy keeps every country.
        timeout: Network timeout in seconds passed to ``_fetch_text``.
        url: Location of the catalog CSV.

    Returns:
        One ``FeedCandidate`` per accepted catalog row, in catalog order.
    """
    import io

    payload = _fetch_text(url, timeout=timeout)
    # Parse from a file-like object opened with newline="" as the csv module
    # expects. ``str.splitlines()`` strips line terminators, which silently
    # drops newlines inside quoted fields, and it also breaks records on extra
    # Unicode line boundaries such as U+2028.
    reader = csv.DictReader(io.StringIO(payload, newline=""))
    candidates: list[FeedCandidate] = []
    for row in reader:
        if _value(row, "data_type").lower() != "gtfs":
            continue
        country = _value(row, "location.country_code").upper()
        if countries and country not in countries:
            continue
        direct_url = _normalize_feed_url(_value(row, "urls.direct_download"))
        latest_url = _normalize_feed_url(_value(row, "urls.latest"))
        selected_url = _choose_feed_url(direct_url, latest_url)
        candidate = FeedCandidate(
            discovery_source="mobility_database",
            country=country,
            subdivision=_value(row, "location.subdivision_name"),
            provider=_value(row, "provider"),
            feed_name=_value(row, "name"),
            stable_id=_value(row, "id"),
            data_type="gtfs",
            status=_value(row, "status"),
            is_official=_value(row, "is_official"),
            selected_url=selected_url,
            direct_download_url=direct_url,
            latest_url=latest_url,
            license_url=_value(row, "urls.license"),
            bbox=_bbox_from_mobility_row(row),
            features=_value(row, "features"),
            source_basis="Mobility Database feed catalog",
            notes=_value(row, "note"),
        )
        normalize_candidate_geography(candidate)
        apply_known_download_overrides(candidate)
        # Priority is computed last so it sees the normalized/overridden fields.
        candidate.priority = _candidate_priority(candidate)
        candidates.append(candidate)
    return candidates
|
||||
|
||||
|
||||
def fetch_mobility_acceptance_candidates(
    *,
    countries: list[str] | None = None,
    timeout: float = 30.0,
    url: str = MOBILITY_DATABASE_ACCEPTANCE_TEST_URL,
) -> list[FeedCandidate]:
    """Download the validator acceptance-test feed list and build candidates.

    Rows without a usable ``urls.latest`` value are skipped, as are rows whose
    country code is not in ``countries`` when a filter is given. Every
    candidate gets priority ``"P3"`` and status ``"acceptance_test"``.

    Args:
        countries: Upper-case country codes to keep; falsy keeps every country.
        timeout: Network timeout in seconds passed to ``_fetch_text``.
        url: Location of the acceptance-test CSV.

    Returns:
        One ``FeedCandidate`` per accepted row, in file order.
    """
    import io

    payload = _fetch_text(url, timeout=timeout)
    # Parse from a file-like object opened with newline="" as the csv module
    # expects. ``str.splitlines()`` strips line terminators, which silently
    # drops newlines inside quoted fields, and it also breaks records on extra
    # Unicode line boundaries such as U+2028.
    reader = csv.DictReader(io.StringIO(payload, newline=""))
    candidates: list[FeedCandidate] = []
    for row in reader:
        country = _value(row, "country_code").upper()
        if countries and country not in countries:
            continue
        latest_url = _normalize_feed_url(_value(row, "urls.latest"))
        if not latest_url:
            continue
        candidate = FeedCandidate(
            discovery_source="mobility_validator_acceptance",
            country=country,
            subdivision=_value(row, "subdivision_name"),
            provider=_value(row, "provider"),
            feed_name=_value(row, "provider"),
            stable_id=_value(row, "stable_id"),
            status="acceptance_test",
            selected_url=latest_url,
            latest_url=latest_url,
            source_basis="MobilityData validator acceptance-test feed list",
            notes="Useful smoke-test feed list; prefer Mobility Database feeds_v2 metadata for production source review.",
            priority="P3",
        )
        normalize_candidate_geography(candidate)
        apply_known_download_overrides(candidate)
        candidates.append(candidate)
    return candidates
|
||||
|
||||
|
||||
def fetch_ptna_candidates(
    *,
    countries: list[str] | None = None,
    max_details: int = 80,
    timeout: float = 30.0,
) -> list[FeedCandidate]:
    """Scrape PTNA country pages (and some detail pages) for GTFS candidates.

    Args:
        countries: Country codes to visit; falsy uses
            ``DEFAULT_DISCOVERY_COUNTRIES``, and if that is empty too the codes
            are discovered from the PTNA index page.
        max_details: Upper bound on detail-page requests across all countries.
            Failed requests count as well, so an unreachable site cannot cause
            one timed-out request per candidate.
        timeout: Network timeout in seconds for each request.

    Returns:
        Candidates in page order. Countries whose page cannot be fetched are
        skipped; a failed detail fetch only adds a note to that candidate.
    """
    country_codes = countries or DEFAULT_DISCOVERY_COUNTRIES
    if not country_codes:
        country_codes = discover_ptna_country_codes(timeout=timeout)
    candidates: list[FeedCandidate] = []
    detail_fetches = 0
    for country in country_codes:
        country_url = PTNA_COUNTRY_URL_TEMPLATE.format(country=country)
        try:
            html = _fetch_text(country_url, timeout=timeout)
        except requests.RequestException:
            continue
        for candidate in parse_ptna_country_page(html, country=country, page_url=country_url):
            if candidate.details_url and detail_fetches < max_details:
                # Count the attempt, not just the success, to bound network use.
                detail_fetches += 1
                try:
                    detail_html = _fetch_text(candidate.details_url, timeout=timeout)
                    enrich_ptna_candidate_from_details(candidate, detail_html, candidate.details_url)
                except requests.RequestException:
                    candidate.notes = _join_notes(candidate.notes, "PTNA detail page could not be fetched during discovery.")
            candidate.priority = _candidate_priority(candidate)
            candidates.append(candidate)
    return candidates
|
||||
|
||||
|
||||
def discover_ptna_country_codes(*, timeout: float = 30.0) -> list[str]:
    """List the two-letter country codes linked from the PTNA GTFS index page.

    A link counts when its URL path ends in ``/gtfs/<CC>/index.php``. Codes are
    returned once each, in order of first appearance.
    """
    index_html = _fetch_text(PTNA_GTFS_INDEX_URL, timeout=timeout)
    country_path = re.compile(r"/gtfs/([A-Z]{2})/index\.php$")
    hits = (
        country_path.search(urlparse(href).path)
        for href in _all_links(index_html, PTNA_GTFS_INDEX_URL)
    )
    # dict.fromkeys removes duplicates while keeping first-seen order.
    return list(dict.fromkeys(hit.group(1) for hit in hits if hit))
|
||||
|
||||
|
||||
def parse_ptna_country_page(html: str, *, country: str, page_url: str) -> list[FeedCandidate]:
    """Turn the feed table of a PTNA country page into candidates.

    A table row is used only if it links to a routes or GTFS-details page and a
    feed id can be derived from that link. Cell texts are read by position
    (1 feed name, 2 provider, 3/4 validity, 5 version, 6 release date); the
    external link in cell 6 is kept as the release URL and becomes the selected
    URL only when it looks like a download.
    """

    def column(texts: list[str], index: int, fallback: str = "") -> str:
        return texts[index] if len(texts) > index else fallback

    found: list[FeedCandidate] = []
    for row in _parse_table_rows(html, page_url):
        cells = row.cells
        row_links = [link for cell in cells for link in cell.links]
        routes_url = _first_link_matching(row_links, "routes.php?feed=")
        details_url = _first_link_matching(row_links, "gtfs-details.php?feed=")
        if not (routes_url or details_url):
            continue
        feed_id = _feed_id_from_url(routes_url or details_url)
        if not feed_id:
            continue
        texts = [cell.text for cell in cells]
        release_link = _normalize_feed_url(cells[6].first_external_link if len(cells) > 6 else "")
        direct_url = release_link if _looks_like_download_url(release_link) else ""
        candidate = FeedCandidate(
            discovery_source="ptna",
            country=country,
            provider=column(texts, 2),
            feed_name=column(texts, 1, feed_id),
            ptna_feed_id=feed_id,
            selected_url=direct_url,
            direct_download_url=direct_url,
            original_release_url=release_link,
            details_url=details_url,
            routes_url=routes_url,
            valid_from=column(texts, 3),
            valid_to=column(texts, 4),
            feed_version=column(texts, 5),
            release_date=column(texts, 6),
            source_basis="PTNA GTFS analysis",
            notes="PTNA candidate; use original publisher URL where available.",
        )
        normalize_candidate_geography(candidate)
        apply_known_download_overrides(candidate)
        found.append(candidate)
    return found
|
||||
|
||||
|
||||
def enrich_ptna_candidate_from_details(candidate: FeedCandidate, html: str, page_url: str) -> None:
    """Fill ``candidate`` in place from a PTNA GTFS detail page.

    Each detail field overrides the existing attribute only when it is present
    and non-empty. If the candidate still has no selected URL and the release
    URL looks like a download, the release URL is selected.
    """
    fields = parse_ptna_detail_fields(html, page_url)
    release_value = fields.get("release url href") or fields.get("release url") or candidate.original_release_url
    candidate.original_release_url = _normalize_feed_url(release_value)

    # (attribute, detail-page label) pairs that are copied over verbatim.
    direct_mappings = (
        ("license_url", "publisher's license href"),
        ("license_text", "publisher's license"),
        ("osm_license_text", "license given for use in osm"),
        ("valid_from", "feed start date"),
        ("valid_to", "feed end date"),
        ("feed_version", "feed version"),
        ("release_date", "release date"),
    )
    for attribute, label in direct_mappings:
        setattr(candidate, attribute, fields.get(label) or getattr(candidate, attribute))

    network_guid = fields.get('"network:guid"')
    if network_guid:
        candidate.notes = _join_notes(candidate.notes, f"PTNA network:guid={network_guid}")

    needs_url = not candidate.selected_url
    if needs_url and _looks_like_download_url(candidate.original_release_url):
        candidate.selected_url = _normalize_feed_url(candidate.original_release_url)
        candidate.direct_download_url = candidate.selected_url
    normalize_candidate_geography(candidate)
|
||||
|
||||
|
||||
def parse_ptna_detail_fields(html: str, page_url: str) -> dict[str, str]:
    """Read label/value rows from a PTNA detail page into a dict.

    Keys are the cleaned, lower-cased text of each row's first cell; values are
    the cleaned text of the second cell. When the second cell has an external
    link it is stored under ``"<label> href"``. Rows with fewer than two cells
    or an empty label are ignored; later rows overwrite earlier ones.
    """
    result: dict[str, str] = {}
    for row in _parse_table_rows(html, page_url):
        if len(row.cells) < 2:
            continue
        label_cell, value_cell = row.cells[0], row.cells[1]
        label = _clean_text(label_cell.text).lower()
        if not label:
            continue
        result[label] = _clean_text(value_cell.text)
        href = value_cell.first_external_link
        if href:
            result[f"{label} href"] = href
    return result
|
||||
|
||||
|
||||
def load_curated_ingestable_seed(
    *,
    countries: list[str] | None = None,
    path: Path | str | None = None,
) -> list[FeedCandidate]:
    """Load GTFS candidates from the curated seed CSV.

    Args:
        countries: Country codes to keep; rows marked ``EU`` are always kept.
            Falsy keeps every row.
        path: CSV location; defaults to ``docs/ingestable_sources_seed.csv``
            under the directory two levels above this file.

    Returns:
        Candidates for every row whose ``kind`` is ``gtfs``, or an empty list
        when the file does not exist.
    """
    if path is None:
        seed_path = Path(__file__).resolve().parents[1] / "docs" / "ingestable_sources_seed.csv"
    else:
        seed_path = Path(path)
    if not seed_path.exists():
        return []

    loaded: list[FeedCandidate] = []
    with seed_path.open("r", encoding="utf-8-sig", newline="") as handle:
        for row in csv.DictReader(handle):
            if _value(row, "kind").lower() != "gtfs":
                continue
            country = _value(row, "country").upper()
            excluded = bool(countries) and country not in countries and country != "EU"
            if excluded:
                continue
            display_name = _value(row, "name")
            feed_url = _normalize_feed_url(_value(row, "url"))
            candidate = FeedCandidate(
                discovery_source="curated_seed",
                country=country,
                provider=display_name.removesuffix(" GTFS"),
                feed_name=display_name,
                selected_url=feed_url,
                direct_download_url=feed_url,
                license_text=_value(row, "license"),
                features=_value(row, "mode_scope"),
                priority=_value(row, "priority"),
                source_basis=_value(row, "source_basis") or "curated seed",
                notes=_value(row, "notes"),
            )
            normalize_candidate_geography(candidate)
            apply_known_download_overrides(candidate)
            loaded.append(candidate)
    return loaded
|
||||
|
||||
|
||||
def merge_candidates(candidates: Iterable[FeedCandidate]) -> list[FeedCandidate]:
    """Collapse candidates that share an alias key and return them sorted.

    The first candidate seen for a group is kept and later ones are folded into
    it with ``_merge_candidate``. All alias keys of a merged candidate are then
    pointed at that group. The result is ordered by priority, country, provider
    and feed name (the latter two case-insensitively).
    """
    groups: dict[str, FeedCandidate] = {}
    alias_owner: dict[str, str] = {}
    for candidate in candidates:
        aliases = _candidate_alias_keys(candidate)
        primary_key = aliases[0]
        existing_key = None
        for alias in aliases:
            if alias in alias_owner:
                existing_key = alias_owner[alias]
                break
        existing = groups.get(existing_key) if existing_key is not None else None
        if existing is None:
            groups[primary_key] = candidate
            alias_owner.update(dict.fromkeys(aliases, primary_key))
            continue
        _merge_candidate(existing, candidate)
        alias_owner.update(dict.fromkeys(aliases, existing_key or primary_key))

    def ordering(item: FeedCandidate):
        return (_priority_sort_key(item.priority), item.country, item.provider.lower(), item.feed_name.lower())

    return sorted(groups.values(), key=ordering)
|
||||
|
||||
|
||||
def select_test_run_candidates(candidates: Iterable[FeedCandidate], *, limit: int = 24) -> list[FeedCandidate]:
    """Pick a bounded, country-balanced set of feeds for a test run.

    Acceptance-test-list candidates and candidates rejected by
    ``_test_candidate_eligible`` are dropped. Selection runs in three passes
    over the sorted remainder: feeds matching a preferred token, then feeds per
    country in ``CURATED_TEST_COUNTRIES`` order, then everything else. Each URL
    is taken once. A country contributes at most 3 feeds, except ``DE`` which
    may reach 7 during the preferred pass.
    """
    eligible = [
        item
        for item in candidates
        if item.discovery_source != "mobility_validator_acceptance" and _test_candidate_eligible(item)
    ]
    eligible.sort(key=_test_candidate_sort_key)

    chosen: list[FeedCandidate] = []
    used_urls: set[str] = set()
    country_counts: dict[str, int] = {}

    def is_full() -> bool:
        return len(chosen) >= limit

    def add(candidate: FeedCandidate, *, force: bool = False) -> None:
        if is_full():
            return
        url_key = _normalize_url_key(candidate.selected_url)
        if not candidate.selected_url or url_key in used_urls:
            return
        country = candidate.country or "unknown"
        cap = 7 if force and country == "DE" else 3
        taken = country_counts.get(country, 0)
        if taken >= cap:
            return
        chosen.append(candidate)
        used_urls.add(url_key)
        country_counts[country] = taken + 1

    preferred_tokens = [
        "opendata-oepnv.de",
        "download.gtfs.de/germany/",
        "vbb.de/vbbgtfs",
        "rnv-online.de",
        "vrn.de",
        "gtfs.geops.ch",
        "wienerlinien.at",
        "gtfs.openov.nl",
        "gtfs.ovapi.nl",
        "rejseplanen.info",
        "dev.hsl.fi/gtfs",
        "hsldev.com/gtfs",
        "rb_norway-aggregated-gtfs",
        "data.bus-data.dft.gov.uk",
        "transportforireland",
        "gtfs.irail.be/de-lijn",
    ]

    # Pass 1: well-known publishers, matched against provider/name/basis/URL.
    for candidate in eligible:
        haystack = " ".join(
            [candidate.provider, candidate.feed_name, candidate.source_basis, candidate.selected_url]
        ).lower()
        if any(token in haystack for token in preferred_tokens):
            add(candidate, force=True)

    # Pass 2: walk the curated country order.
    for country in CURATED_TEST_COUNTRIES:
        for candidate in eligible:
            if candidate.country == country:
                add(candidate)
            if is_full():
                break
        if is_full():
            break

    # Pass 3: fill any remaining slots from the rest.
    for candidate in eligible:
        add(candidate)
        if is_full():
            break
    return chosen
|
||||
|
||||
|
||||
def _test_candidate_eligible(candidate: FeedCandidate) -> bool:
    """Return whether ``candidate`` may be picked for the curated test set.

    A candidate is rejected when it has no selected URL, when its priority
    ranks worse than 2 (per ``_priority_sort_key``), or when its status, URL,
    provider, feed name or notes mention one of the blocker phrases below.
    """
    if not candidate.selected_url or _priority_sort_key(candidate.priority) > 2:
        return False
    haystack = " ".join(
        (
            candidate.status,
            candidate.selected_url,
            candidate.provider,
            candidate.feed_name,
            candidate.notes,
        )
    ).lower()
    blockers = (
        "deprecated",
        "inactive",
        "{apikey}",
        "registration required",
        "authentication",
    )
    return not any(marker in haystack for marker in blockers)
|
||||
|
||||
|
||||
def annotate_url_availability(candidate: FeedCandidate, *, timeout: float = 10.0) -> FeedCandidate:
    """Probe ``candidate.selected_url`` and record the result on the candidate.

    A HEAD request is tried first.  When the server answers 403, 405 or any
    5xx status, a streamed GET asking for a single byte (``Range: bytes=0-0``)
    is issued instead.  The HTTP status, content type, content length, final
    URL and an ``"ok"``/``"error"`` availability flag are written onto the
    candidate.  Request failures are recorded in the notes rather than raised.

    Args:
        candidate: Feed candidate to annotate in place.
        timeout: Timeout in seconds passed to each HTTP request.

    Returns:
        The same ``candidate`` object, for chaining.
    """
    if not candidate.selected_url:
        candidate.availability_status = "missing_url"
        return candidate
    headers = {"User-Agent": "meubility-workbench-feed-discovery/0.1"}
    response = None
    try:
        response = requests.head(candidate.selected_url, allow_redirects=True, timeout=timeout, headers=headers)
        if response.status_code in {405, 403} or response.status_code >= 500:
            # Release the HEAD response before retrying; otherwise it is
            # dropped without ever being closed.
            response.close()
            response = None
            response = requests.get(
                candidate.selected_url,
                allow_redirects=True,
                timeout=timeout,
                headers={**headers, "Range": "bytes=0-0"},
                stream=True,
            )
        candidate.http_status = str(response.status_code)
        candidate.content_type = response.headers.get("content-type", "")
        candidate.content_length = response.headers.get("content-length", "")
        candidate.final_url = response.url
        candidate.availability_status = "ok" if response.status_code < 400 else "error"
    except requests.RequestException as exc:
        candidate.availability_status = "error"
        candidate.notes = _join_notes(candidate.notes, f"Availability check failed: {exc}")
    finally:
        # The streamed GET body is never read, so the response must always be
        # closed explicitly, including on unexpected errors.
        if response is not None:
            response.close()
    return candidate
|
||||
|
||||
|
||||
def normalize_candidate_geography(candidate: FeedCandidate) -> None:
    """Overwrite ``candidate.country`` for a few known feed hosts.

    The candidate's URLs, provider, feed name and source basis are joined and
    lower-cased; the first rule whose marker occurs in that text sets the
    country.  Candidates matching no rule are left untouched.
    """
    haystack = " ".join(
        (
            candidate.selected_url,
            candidate.direct_download_url,
            candidate.latest_url,
            candidate.original_release_url,
            candidate.provider,
            candidate.feed_name,
            candidate.source_basis,
        )
    ).lower()
    # Order matters: the first matching rule wins.
    rules = (
        (("download.gtfs.de/germany/", "gtfs for germany"), "DE"),
        (("storage.googleapis.com/marduk-production/outbound/gtfs/rb_norway",), "NO"),
        (("gtfs.ovapi.nl", "openov.nl"), "NL"),
        (("www.nvbw.de/fileadmin/user_upload/service/open_data/",), "DE"),
    )
    for markers, country_code in rules:
        if any(marker in haystack for marker in markers):
            candidate.country = country_code
            return
|
||||
|
||||
|
||||
def apply_known_download_overrides(candidate: FeedCandidate) -> None:
    """Switch known-stale catalog entries to their ``latest_url`` mirror.

    Only candidates whose ``stable_id`` is in the hard-coded set and that
    carry a ``latest_url`` are changed; a note explaining the switch is
    appended to their notes.
    """
    stale_direct_ids = {"mdb-684", "mdb-777"}
    if candidate.stable_id not in stale_direct_ids or not candidate.latest_url:
        return
    candidate.selected_url = candidate.latest_url
    explanation = "Selected Mobility Database latest.zip mirror because the catalog direct URL is known to be stale."
    candidate.notes = _join_notes(candidate.notes, explanation)
|
||||
|
||||
|
||||
@dataclass
class _HtmlCell:
    """Text and hyperlinks collected from one HTML table cell."""

    # Accumulated text content of the cell.
    text: str = ""
    # Link targets found inside the cell, in document order.
    links: list[str] = field(default_factory=list)

    @property
    def first_external_link(self) -> str:
        """Return the first http(s) link not hosted on ptna.openstreetmap.de, or ""."""

        def _is_external(link: str) -> bool:
            parts = urlparse(link)
            return parts.scheme in {"http", "https"} and "ptna.openstreetmap.de" not in parts.netloc

        return next(filter(_is_external, self.links), "")
|
||||
|
||||
|
||||
@dataclass
class _HtmlRow:
    """One parsed HTML table row, holding its cells in document order."""

    cells: list[_HtmlCell] = field(default_factory=list)
|
||||
|
||||
|
||||
class _TableParser(HTMLParser):
    """Collect ``<tr>`` rows, their ``<td>``/``<th>`` cells and cell links.

    Relative ``href`` values are resolved against ``base_url``.  Rows that end
    without any completed cell are discarded.
    """

    def __init__(self, base_url: str):
        super().__init__(convert_charrefs=True)
        self.base_url = base_url
        self.rows: list[_HtmlRow] = []
        # Parser state: the row and cell currently open, plus the last link seen.
        self._row: _HtmlRow | None = None
        self._cell: _HtmlCell | None = None
        self._active_link: str = ""

    def handle_starttag(self, tag: str, attrs: list[tuple[str, str | None]]) -> None:
        """Open a row or cell, or record a link inside the open cell."""
        if tag == "tr":
            self._row = _HtmlRow()
            return
        if tag in ("td", "th") and self._row is not None:
            self._cell = _HtmlCell()
            return
        if tag == "a" and self._cell is not None:
            # dict() keeps the last value when an attribute is repeated.
            href = dict(attrs).get("href") or ""
            if href:
                self._active_link = urljoin(self.base_url, href)
                self._cell.links.append(self._active_link)

    def handle_endtag(self, tag: str) -> None:
        """Close the current cell or row and reset the matching state."""
        if tag in ("td", "th"):
            if self._row is None or self._cell is None:
                return
            self._cell.text = _clean_text(self._cell.text)
            self._row.cells.append(self._cell)
            self._cell = None
            self._active_link = ""
            return
        if tag == "a":
            self._active_link = ""
            return
        if tag == "tr":
            finished = self._row
            if finished is not None and finished.cells:
                self.rows.append(finished)
            self._row = None
            self._cell = None
            self._active_link = ""

    def handle_data(self, data: str) -> None:
        """Append text to the open cell; text outside a cell is ignored."""
        if self._cell is None:
            return
        self._cell.text += data
|
||||
|
||||
|
||||
class _LinkParser(HTMLParser):
    """Collect every ``<a href>`` target, resolved against ``base_url``."""

    def __init__(self, base_url: str):
        super().__init__(convert_charrefs=True)
        self.base_url = base_url
        self.links: list[str] = []

    def handle_starttag(self, tag: str, attrs: list[tuple[str, str | None]]) -> None:
        """Record the non-empty ``href`` values of anchor tags."""
        if tag != "a":
            return
        self.links.extend(
            urljoin(self.base_url, target)
            for name, target in attrs
            if name == "href" and target
        )
|
||||
|
||||
|
||||
def _parse_table_rows(html: str, base_url: str) -> list[_HtmlRow]:
    """Feed ``html`` through ``_TableParser`` and return the rows it collected."""
    table_parser = _TableParser(base_url)
    table_parser.feed(html)
    collected = table_parser.rows
    return collected
|
||||
|
||||
|
||||
def _all_links(html: str, base_url: str) -> list[str]:
    """Feed ``html`` through ``_LinkParser`` and return the links it collected."""
    link_parser = _LinkParser(base_url)
    link_parser.feed(html)
    collected = link_parser.links
    return collected
|
||||
|
||||
|
||||
def _fetch_text(url: str, *, timeout: float) -> str:
    """GET ``url`` and return the response text.

    ``raise_for_status()`` is called on the response before the text is read.
    """
    request_headers = {"User-Agent": "meubility-workbench-feed-discovery/0.1"}
    response = requests.get(url, timeout=timeout, headers=request_headers)
    response.raise_for_status()
    return response.text
|
||||
|
||||
|
||||
def _first_link_matching(links: Iterable[str], needle: str) -> str:
    """Return the first link containing ``needle``, or "" when none does."""
    return next((candidate for candidate in links if needle in candidate), "")
|
||||
|
||||
|
||||
def _feed_id_from_url(url: str) -> str:
    """Return the first ``feed`` query-parameter value of ``url``, or ""."""
    feed_values = parse_qs(urlparse(url).query).get("feed")
    if not feed_values:
        return ""
    return feed_values[0]
|
||||
|
||||
|
||||
def _looks_like_download_url(url: str) -> bool:
    """Guess whether ``url`` points directly at a downloadable GTFS feed.

    True when the path ends in ``.zip``, the URL mentions
    ``exportformat=gtfs`` or ``google_transit``, the path (ignoring a trailing
    slash) ends in ``gtfs``/``current_gtfs``, or the host contains
    gtfs.ovapi.nl and the path mentions ``gtfs``.  Matching is
    case-insensitive.
    """
    if not url:
        return False
    parts = urlparse(url)
    path = parts.path.lower()
    whole = url.lower()
    return (
        path.endswith(".zip")
        or "exportformat=gtfs" in whole
        or "google_transit" in whole
        or path.rstrip("/").endswith(("current_gtfs", "gtfs"))
        or ("gtfs.ovapi.nl" in parts.netloc.lower() and "gtfs" in path)
    )
|
||||
|
||||
|
||||
def _normalize_feed_url(url: str) -> str:
    """Clean ``url`` and prepend ``https://`` to scheme-less host-like values.

    Values that already carry a scheme are returned cleaned but otherwise
    unchanged.  A value without a scheme is treated as host-like when the part
    before its first ``/`` contains a dot.
    """
    cleaned = _clean_text(url)
    if not cleaned:
        return ""
    if urlparse(cleaned).scheme:
        return cleaned
    leading_segment, _, _ = cleaned.partition("/")
    return f"https://{cleaned}" if "." in leading_segment else cleaned
|
||||
|
||||
|
||||
def _choose_feed_url(direct_url: str, latest_url: str) -> str:
    """Prefer ``direct_url``; fall back to ``latest_url`` when it is empty."""
    return direct_url or latest_url
|
||||
|
||||
|
||||
def _candidate_priority(candidate: FeedCandidate) -> str:
    """Derive a ``P0``..``P4`` priority label for ``candidate``.

    Curated seeds keep their own priority (default ``P1``).  Active feeds get
    ``P0`` (official with a direct download), ``P1`` (direct download) or
    ``P2`` (latest URL only).  Remaining PTNA candidates get ``P2`` with a
    selected URL and ``P4`` without; everything else is ``P3``.
    """
    is_active = candidate.status.lower() == "active"
    is_official = candidate.is_official.lower() == "true"
    origin = candidate.discovery_source
    if origin == "curated_seed":
        return candidate.priority or "P1"
    if is_active:
        if candidate.direct_download_url:
            return "P0" if is_official else "P1"
        if candidate.latest_url:
            return "P2"
    if origin == "ptna":
        return "P2" if candidate.selected_url else "P4"
    return "P3"
|
||||
|
||||
|
||||
def _test_candidate_sort_key(candidate: FeedCandidate) -> tuple[int, int, str, str]:
    """Sort key ordering test candidates by priority, then source and country.

    Curated seeds get a rank bonus of 0 and others 1.  The country rank is its
    position in ``CURATED_TEST_COUNTRIES``, or 99 when the country is not
    listed.
    """
    seed_rank = 0 if candidate.discovery_source == "curated_seed" else 1
    if candidate.country in CURATED_TEST_COUNTRIES:
        country_rank = CURATED_TEST_COUNTRIES.index(candidate.country)
    else:
        country_rank = 99
    return (
        _priority_sort_key(candidate.priority),
        seed_rank + country_rank,
        candidate.country,
        candidate.provider.lower(),
    )
|
||||
|
||||
|
||||
def _priority_sort_key(priority: str) -> int:
    """Return the number after a leading ``P`` in ``priority``, or 9 if absent."""
    found = re.match(r"P(\d+)", priority or "")
    if found is None:
        return 9
    return int(found.group(1))
|
||||
|
||||
|
||||
def _candidate_alias_keys(candidate: FeedCandidate) -> list[str]:
    """List the distinct keys under which ``candidate`` can be looked up.

    The keys are, in order: ``candidate.key()``, ``stable:<id>``, one
    ``url:<normalized>`` entry per non-empty selected/direct/latest URL, and
    ``ptna:<feed id>``.  Duplicates are dropped, keeping first occurrences.
    """
    aliases = [candidate.key()]
    if candidate.stable_id:
        aliases.append(f"stable:{candidate.stable_id}")
    aliases.extend(
        f"url:{_normalize_url_key(url)}"
        for url in (candidate.selected_url, candidate.direct_download_url, candidate.latest_url)
        if url
    )
    if candidate.ptna_feed_id:
        aliases.append(f"ptna:{candidate.ptna_feed_id}")
    # dict.fromkeys removes duplicates while preserving insertion order.
    return list(dict.fromkeys(aliases))
|
||||
|
||||
|
||||
def _merge_candidate(existing: FeedCandidate, incoming: FeedCandidate) -> None:
    """Fold ``incoming`` into ``existing`` in place.

    Curated-seed values overwrite the descriptive fields when non-empty.
    Every other canonical field except ``candidate_id`` is only filled where
    ``existing`` has no value.  The discovery source, source basis and notes
    are joined, and the better of the two priorities is kept.
    """
    curated_fields = ("country", "provider", "feed_name", "license_text", "features", "source_basis", "notes")
    if incoming.discovery_source == "curated_seed":
        for name in curated_fields:
            replacement = getattr(incoming, name, "")
            if replacement:
                setattr(existing, name, replacement)
    existing.discovery_source = _join_unique(existing.discovery_source, incoming.discovery_source)
    for name in CANONICAL_HEADERS:
        if name == "candidate_id":
            continue
        present = getattr(existing, name, "")
        offered = getattr(incoming, name, "")
        if offered and not present:
            setattr(existing, name, offered)
    existing.priority = _better_priority(existing.priority, incoming.priority)
    existing.source_basis = _join_unique(existing.source_basis, incoming.source_basis)
    existing.notes = _join_notes(existing.notes, incoming.notes)
|
||||
|
||||
|
||||
def _better_priority(left: str, right: str) -> str:
    """Return the label with the lower ``_priority_sort_key``; ties favor ``left``."""
    # min() returns the first of equally ranked items, so ``left`` wins ties.
    return min((left, right), key=_priority_sort_key)
|
||||
|
||||
|
||||
def _join_unique(left: str, right: str) -> str:
    """Merge two ``;``-separated lists, dropping blanks and duplicates.

    Parts are stripped, kept in first-seen order and re-joined with ``"; "``.
    """
    stripped_parts = (part.strip() for value in (left, right) for part in value.split(";"))
    # dict.fromkeys removes duplicates while preserving insertion order.
    distinct = dict.fromkeys(part for part in stripped_parts if part)
    return "; ".join(distinct)
|
||||
|
||||
|
||||
def _join_notes(left: str, right: str) -> str:
    """Combine two note strings; delegates to ``_join_unique``."""
    merged = _join_unique(left, right)
    return merged
|
||||
|
||||
|
||||
def _compact_name(value: str) -> str:
    """Clean ``value`` and collapse whitespace runs into single spaces."""
    cleaned = _clean_text(value)
    collapsed = re.sub(r"\s+", " ", cleaned)
    return collapsed.strip()
|
||||
|
||||
|
||||
def _feed_source_name(country: str, value: str) -> str:
    """Build a display name of the form ``"<COUNTRY> <name> GTFS"``.

    The upper-cased country is prepended unless the name already starts with
    it followed by a space.  ``" GTFS"`` is appended unless the result already
    mentions "gtfs" in any case.  An empty name falls back to ``"GTFS feed"``.
    """
    base = _compact_name(value) or "GTFS feed"
    prefix = country.upper()
    already_prefixed = base.upper().startswith(f"{prefix} ")
    display = f"{prefix} {base}" if prefix and not already_prefixed else base
    if "gtfs" in display.lower():
        return display
    return f"{display} GTFS"
|
||||
|
||||
|
||||
def _clean_text(value: str) -> str:
    """Unescape HTML entities and normalize whitespace in ``value``.

    ``None`` and empty input yield "".  Non-breaking spaces become regular
    spaces, whitespace runs collapse to one space, and the ends are trimmed.
    """
    unescaped = unescape(value or "")
    without_nbsp = unescaped.replace("\xa0", " ")
    return re.sub(r"\s+", " ", without_nbsp).strip()
|
||||
|
||||
|
||||
def _mode_scope_from_features(features: str) -> str:
    """Map a free-text feature description to a comma-separated mode list.

    Modes are emitted in the fixed order rail, tram, metro, bus, ferry, using
    case-insensitive substring matching.  ``bus`` is also added when none of
    rail, tram or metro matched.
    """
    lowered = features.lower()
    keyword_rules = (
        ("rail", ("rail", "train")),
        ("tram", ("tram", "light_rail")),
        ("metro", ("subway", "metro")),
    )
    found = [
        mode
        for mode, keywords in keyword_rules
        if any(keyword in lowered for keyword in keywords)
    ]
    # Bus is the fallback when no rail-based mode matched.
    if "bus" in lowered or not found:
        found.append("bus")
    if "ferry" in lowered:
        found.append("ferry")
    return ",".join(found)
|
||||
|
||||
|
||||
def _bbox_from_mobility_row(row: dict[str, str]) -> str:
    """Format the row's bounding box as ``"min_lon,min_lat,max_lon,max_lat"``.

    Returns "" when any of the four ``location.bounding_box.*`` values is
    empty.
    """
    column_prefix = "location.bounding_box."
    min_lat = _value(row, f"{column_prefix}minimum_latitude")
    max_lat = _value(row, f"{column_prefix}maximum_latitude")
    min_lon = _value(row, f"{column_prefix}minimum_longitude")
    max_lon = _value(row, f"{column_prefix}maximum_longitude")
    if min_lat and max_lat and min_lon and max_lon:
        return f"{min_lon},{min_lat},{max_lon},{max_lat}"
    return ""
|
||||
|
||||
|
||||
def _normalize_countries(countries: Iterable[str] | None) -> list[str] | None:
    """Normalize a country filter to upper-case codes.

    ``None`` yields ``DEFAULT_DISCOVERY_COUNTRIES``.  Blank entries are
    dropped.  If any entry equals ``ALL`` (case-insensitive) the result is
    ``None``.
    """
    if countries is None:
        return DEFAULT_DISCOVERY_COUNTRIES
    codes = [entry.strip().upper() for entry in countries if entry and entry.strip()]
    if "ALL" in codes:
        return None
    return codes
|
||||
|
||||
|
||||
def _normalize_url_key(url: str) -> str:
    """Return a comparison key for ``url``.

    The scheme and host are lower-cased, trailing slashes are removed from the
    path, the query string is kept verbatim and the fragment is dropped.
    """
    parts = urlparse(url.strip())
    base = f"{parts.scheme.lower()}://{parts.netloc.lower()}{parts.path.rstrip('/')}"
    if parts.query:
        return f"{base}?{parts.query}"
    return base
|
||||
|
||||
|
||||
def _write_csv(path: Path, headers: list[str], rows: list[dict[str, str]]) -> None:
    """Write ``rows`` to ``path`` as UTF-8 CSV with a header line.

    Keys not listed in ``headers`` are ignored.
    """
    with path.open("w", encoding="utf-8", newline="") as handle:
        writer = csv.DictWriter(handle, fieldnames=headers, extrasaction="ignore")
        writer.writeheader()
        for record in rows:
            writer.writerow(record)
|
||||
|
||||
|
||||
def _count_by(items: Iterable[FeedCandidate], key_fn) -> dict[str, int]:
    """Count ``items`` per ``key_fn(item)`` and return the counts sorted by key."""
    tally: dict[str, int] = {}
    for group in map(key_fn, items):
        tally[group] = tally.get(group, 0) + 1
    return {group: tally[group] for group in sorted(tally)}
|
||||
|
||||
|
||||
def _value(row: dict[str, str], key: str) -> str:
    """Return the cleaned value stored under ``key``; missing keys yield ""."""
    raw = row.get(key, "")
    return _clean_text(raw)
|
||||
|
||||
|
||||
def _string(value: object) -> str:
    """Convert ``value`` with ``str()``, mapping ``None`` to ""."""
    if value is None:
        return ""
    return str(value)
|
||||
|
||||
|
||||
def _truncate(value: str, length: int) -> str:
    """Return the first ``length`` characters of ``value``; falsy input yields ""."""
    if not value:
        return ""
    return value[:length]
|
||||
120
app/geofabrik.py
Normal file
120
app/geofabrik.py
Normal file
@@ -0,0 +1,120 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models import Source
|
||||
|
||||
|
||||
GEOFABRIK_INDEX_URL = "https://download.geofabrik.de/index-v1-nogeom.json"
|
||||
_CACHE: dict[str, Any] = {"expires_at": None, "rows": None}
|
||||
|
||||
|
||||
def geofabrik_catalog(q: str | None = None, limit: int = 80) -> list[dict[str, Any]]:
    """Return Geofabrik extracts, optionally filtered by a search term.

    The term is matched case-insensitively as a substring of the id, name,
    parent or space-joined country codes.  Results are sorted by parent then
    name.  ``limit`` is clamped to the range 1..500.
    """
    rows = _geofabrik_rows()
    needle = (q or "").strip().casefold()

    def _matches(row: dict[str, Any]) -> bool:
        return (
            needle in row["id"].casefold()
            or needle in row["name"].casefold()
            or needle in (row.get("parent") or "").casefold()
            or needle in " ".join(row.get("country_codes") or []).casefold()
        )

    if needle:
        rows = [row for row in rows if _matches(row)]
    ordered = sorted(rows, key=lambda row: (row.get("parent") or "", row["name"]))
    cap = max(1, min(limit, 500))
    return ordered[:cap]
|
||||
|
||||
|
||||
def geofabrik_entry(geofabrik_id: str) -> dict[str, Any] | None:
    """Return the catalog row whose id equals ``geofabrik_id``, or ``None``.

    The comparison ignores case and surrounding whitespace in the argument.
    """
    wanted = geofabrik_id.strip().casefold()
    return next((row for row in _geofabrik_rows() if row["id"].casefold() == wanted), None)
|
||||
|
||||
|
||||
def create_geofabrik_source(session: Session, geofabrik_id: str, *, import_updates: bool = False) -> Source:
    """Return the ``osm_pbf`` source for a Geofabrik extract, creating it if needed.

    An existing source with the same PBF URL is returned unchanged.  Otherwise
    a new source is added and flushed.  When ``import_updates`` is set and the
    extract lists an updates URL, a companion ``osm_diff`` source is added as
    well.

    Raises:
        ValueError: If the extract id is unknown or the extract has no PBF URL.
    """
    entry = geofabrik_entry(geofabrik_id)
    if entry is None:
        raise ValueError(f"Geofabrik extract not found: {geofabrik_id}")
    pbf_url = entry.get("pbf_url")
    if not pbf_url:
        raise ValueError(f"Geofabrik extract has no PBF URL: {geofabrik_id}")

    lookup = select(Source).where(Source.kind == "osm_pbf", Source.url == pbf_url)
    existing = session.scalar(lookup)
    if existing is not None:
        return existing

    # The joined country codes are cut to 8 characters; an empty result is stored as None.
    joined_codes = ",".join(entry.get("country_codes") or [])
    source = Source(
        name=f"Geofabrik {entry['name']}",
        kind="osm_pbf",
        url=pbf_url,
        country=joined_codes[:8] or None,
        license="ODbL / Geofabrik extract terms",
        priority="P0 fallback",
        mode_scope="public transport OSM routes, stops, and infrastructure",
        source_basis="OpenStreetMap / Geofabrik extracts",
        notes=_geofabrik_notes(entry, import_updates=import_updates),
    )
    session.add(source)
    session.flush()

    if not import_updates:
        return source
    updates_url = entry.get("updates_url")
    if not updates_url:
        return source
    session.add(
        Source(
            name=f"Geofabrik {entry['name']} updates",
            kind="osm_diff",
            url=updates_url,
            country=source.country,
            license=source.license,
            priority=source.priority,
            mode_scope=source.mode_scope,
            source_basis="OpenStreetMap / Geofabrik replication diffs",
            notes=f"Diff base for Geofabrik extract {entry['id']}; applying diffs to a local base extract is not implemented yet.",
        )
    )
    return source
|
||||
|
||||
|
||||
def _geofabrik_rows() -> list[dict[str, Any]]:
    """Return the normalized Geofabrik index, cached in ``_CACHE`` for 12 hours.

    Each call returns a fresh list, so callers may reorder or filter it
    without touching the cached list.  Features without an id or PBF URL are
    dropped.
    """
    now = datetime.now(timezone.utc)
    cached_rows = _CACHE.get("rows")
    deadline = _CACHE.get("expires_at")
    if cached_rows is not None and isinstance(deadline, datetime) and deadline > now:
        return list(cached_rows)

    response = requests.get(GEOFABRIK_INDEX_URL, timeout=45)
    response.raise_for_status()
    features = response.json().get("features", [])
    rows = [
        row
        for row in map(_normalize_feature, features)
        if row.get("id") and row.get("pbf_url")
    ]
    _CACHE["rows"] = rows
    _CACHE["expires_at"] = now + timedelta(hours=12)
    return list(rows)
|
||||
|
||||
|
||||
def _normalize_feature(feature: dict[str, Any]) -> dict[str, Any]:
    """Flatten one index feature into the row shape used by this module.

    Missing ``properties`` or ``urls`` are treated as empty.  A single
    country-code string is wrapped in a list.  The name falls back to the id.
    """
    properties = feature.get("properties") or {}
    links = properties.get("urls") or {}
    codes = properties.get("iso3166-1:alpha2") or []
    if isinstance(codes, str):
        codes = [codes]
    identifier = properties.get("id")
    return {
        "id": str(identifier or ""),
        "name": str(properties.get("name") or identifier or ""),
        "parent": properties.get("parent"),
        "country_codes": codes,
        "pbf_url": links.get("pbf"),
        "updates_url": links.get("updates"),
        "taginfo_url": links.get("taginfo"),
        "urls": links,
    }
|
||||
|
||||
|
||||
def _geofabrik_notes(entry: dict[str, Any], *, import_updates: bool) -> str:
    """Build the ``"; "``-joined notes string stored on a Geofabrik source."""
    parent = entry.get("parent") or "root"
    updates_url = entry.get("updates_url") or ""
    requested = "true" if import_updates else "false"
    return "; ".join(
        (
            f"geofabrik_id={entry['id']}",
            f"parent={parent}",
            f"updates_url={updates_url}",
            f"diff_source_requested={requested}",
            "Overlap dedupe is handled by OSM object identity in the route layer; source-specific map layers may still show both extracts.",
        )
    )
|
||||
308
app/gtfs_storage.py
Normal file
308
app/gtfs_storage.py
Normal file
@@ -0,0 +1,308 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sqlite3
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Iterator, Sequence
|
||||
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.config import settings
|
||||
from app.models import Dataset, GtfsStopTime
|
||||
|
||||
|
||||
GTFS_STORAGE_METADATA_KEY = "gtfs_storage"
|
||||
GTFS_STORAGE_MAIN = "main"
|
||||
GTFS_STORAGE_SIDECAR_STOP_TIMES = "sidecar_stop_times"
|
||||
GTFS_STOP_TIME_COLUMNS = [
|
||||
"trip_id",
|
||||
"stop_id",
|
||||
"stop_sequence",
|
||||
"arrival_time",
|
||||
"departure_time",
|
||||
"arrival_seconds",
|
||||
"departure_seconds",
|
||||
]
|
||||
SQLITE_IN_CHUNK_SIZE = 800
|
||||
|
||||
|
||||
def effective_gtfs_timetable_storage(value: str | None = None) -> str:
    """Resolve the stop_times storage mode to one of the two storage constants.

    ``value`` takes precedence over ``settings.gtfs_timetable_storage``.
    Several aliases select main-database storage.  PostgreSQL deployments with
    sidecars disabled also get main storage.  Everything else resolves to
    sidecar stop_times.
    """
    raw = value or settings.gtfs_timetable_storage or GTFS_STORAGE_SIDECAR_STOP_TIMES
    configured = str(raw).strip().lower()
    main_aliases = {GTFS_STORAGE_MAIN, "main_db", "main_sqlite", "postgres", "postgresql"}
    if configured in main_aliases or (settings.is_postgresql_database and not settings.postgres_use_sidecars):
        return GTFS_STORAGE_MAIN
    return GTFS_STORAGE_SIDECAR_STOP_TIMES
|
||||
|
||||
|
||||
class MissingGtfsSidecar(FileNotFoundError):
    """Raised when a dataset's GTFS sidecar is not configured or not on disk."""

    def __init__(self, dataset_id: int | None, path: Path | None) -> None:
        # Both values stay available to handlers via attributes.
        self.dataset_id = dataset_id
        self.path = path
        super().__init__(
            f"dataset #{dataset_id} does not reference a GTFS sidecar"
            if path is None
            else f"GTFS sidecar does not exist: {path}"
        )
|
||||
|
||||
|
||||
def dataset_metadata(dataset: Dataset) -> dict:
    """Parse ``dataset.metadata_json`` into a dict.

    Empty, invalid or non-object JSON all yield an empty dict.
    """
    raw = dataset.metadata_json or "{}"
    try:
        parsed = json.loads(raw)
    except json.JSONDecodeError:
        return {}
    if isinstance(parsed, dict):
        return parsed
    return {}
|
||||
|
||||
|
||||
def stop_times_are_sidecar(dataset: Dataset | None) -> bool:
    """Return whether the dataset's metadata places stop_times in a sidecar.

    A per-table entry (``tables["gtfs_stop_times"] == "sidecar"``) decides when
    a ``tables`` mapping exists.  Otherwise the storage ``mode`` is compared
    with the sidecar constant.  A missing dataset or missing storage metadata
    means ``False``.
    """
    if dataset is None:
        return False
    storage = dataset_metadata(dataset).get(GTFS_STORAGE_METADATA_KEY)
    if not isinstance(storage, dict):
        return False
    per_table = storage.get("tables")
    if isinstance(per_table, dict):
        return per_table.get("gtfs_stop_times") == "sidecar"
    return storage.get("mode") == GTFS_STORAGE_SIDECAR_STOP_TIMES
|
||||
|
||||
|
||||
def sidecar_path(dataset: Dataset | None) -> Path | None:
    """Return the sidecar path recorded in the dataset metadata, if any.

    ``None`` is returned for a missing dataset, missing storage metadata or an
    empty ``sidecar_path`` entry.
    """
    if dataset is None:
        return None
    storage = dataset_metadata(dataset).get(GTFS_STORAGE_METADATA_KEY)
    if not isinstance(storage, dict):
        return None
    recorded = storage.get("sidecar_path")
    return Path(str(recorded)) if recorded else None
|
||||
|
||||
|
||||
def dataset_sidecar_paths(dataset: Dataset) -> list[Path]:
    """Return the dataset's sidecar path in a list, or an empty list if unset."""
    location = sidecar_path(dataset)
    if location is None:
        return []
    return [location]
|
||||
|
||||
|
||||
def missing_sidecar_paths(dataset: Dataset | None) -> list[str]:
    """Describe the sidecar problems of ``dataset`` as a list of strings.

    The list is empty when the dataset does not use sidecar stop_times or its
    sidecar file exists.  Otherwise it holds either the missing file path or a
    message that no path is configured.
    """
    if not stop_times_are_sidecar(dataset):
        return []
    location = sidecar_path(dataset)
    if location is not None:
        return [] if location.exists() else [str(location)]
    dataset_id = "unknown" if dataset is None else str(dataset.id)
    return [f"dataset #{dataset_id} has no configured GTFS sidecar path"]
|
||||
|
||||
|
||||
def uses_sidecar_stop_times(session: Session, dataset_id: int) -> bool:
    """Load the dataset by id and report whether its stop_times are in a sidecar."""
    dataset = session.get(Dataset, dataset_id)
    return stop_times_are_sidecar(dataset)
|
||||
|
||||
|
||||
@contextmanager
def sidecar_connection(dataset: Dataset) -> Iterator[sqlite3.Connection]:
    """Yield a read-only SQLite connection to the dataset's GTFS sidecar.

    Rows are returned as ``sqlite3.Row``.  The connection is closed when the
    context exits.

    Raises:
        MissingGtfsSidecar: If no sidecar path is configured or the file does
            not exist.
    """
    path = sidecar_path(dataset)
    if path is None:
        raise MissingGtfsSidecar(dataset.id, None)
    if not path.exists():
        raise MissingGtfsSidecar(dataset.id, path)
    # Build the URI with as_uri() so that "?", "#", "%" or spaces in the path
    # are percent-encoded instead of being parsed as URI syntax by SQLite.
    # as_uri() needs an absolute path, hence resolve().
    uri = f"{path.resolve().as_uri()}?mode=ro"
    connection = sqlite3.connect(uri, uri=True)
    connection.row_factory = sqlite3.Row
    try:
        yield connection
    finally:
        connection.close()
|
||||
|
||||
|
||||
def stop_time_count(session: Session, dataset_id: int) -> int:
    """Count the stop_time rows of a dataset, wherever they are stored.

    Sidecar datasets are counted in their SQLite sidecar; a missing sidecar
    counts as 0.  Other datasets are counted in the main database.
    """
    dataset = session.get(Dataset, dataset_id)
    if not stop_times_are_sidecar(dataset):
        main_count = select(func.count()).select_from(GtfsStopTime).where(GtfsStopTime.dataset_id == dataset_id)
        return session.scalar(main_count) or 0
    try:
        with sidecar_connection(dataset) as connection:
            record = connection.execute("SELECT COUNT(*) FROM gtfs_stop_times").fetchone()
            return int(record[0] or 0)
    except MissingGtfsSidecar:
        return 0
|
||||
|
||||
|
||||
def stop_time_counts_by_dataset(session: Session, dataset_ids: Sequence[int]) -> dict[int, int]:
    """Map each dataset id (coerced to ``int``) to its stop_time row count."""
    return {
        int(dataset_id): stop_time_count(session, int(dataset_id))
        for dataset_id in dataset_ids
    }
|
||||
|
||||
|
||||
def scheduled_stop_ids(session: Session, dataset_id: int, stop_ids: Sequence[str]) -> tuple[str, ...]:
    """Return the subset of ``stop_ids`` that occurs in the dataset's stop_times.

    The ids are queried in chunks of ``SQLITE_IN_CHUNK_SIZE``.  The result is a
    sorted tuple without duplicates.  A missing sidecar yields an empty tuple.
    """
    if not stop_ids:
        return ()
    dataset = session.get(Dataset, dataset_id)
    wanted = [str(stop_id) for stop_id in stop_ids]
    matched: set[str] = set()

    if not stop_times_are_sidecar(dataset):
        for batch in _chunks(wanted, SQLITE_IN_CHUNK_SIZE):
            statement = (
                select(GtfsStopTime.stop_id)
                .where(GtfsStopTime.dataset_id == dataset_id, GtfsStopTime.stop_id.in_(batch))
                .group_by(GtfsStopTime.stop_id)
            )
            matched.update(str(value) for value in session.scalars(statement).all())
        return tuple(sorted(matched))

    try:
        with sidecar_connection(dataset) as connection:
            for batch in _chunks(wanted, SQLITE_IN_CHUNK_SIZE):
                placeholders = ", ".join("?" for _ in batch)
                records = connection.execute(
                    f"""
                    SELECT stop_id
                    FROM gtfs_stop_times
                    WHERE stop_id IN ({placeholders})
                    GROUP BY stop_id
                    """,
                    list(batch),
                ).fetchall()
                matched.update(str(record["stop_id"]) for record in records)
    except MissingGtfsSidecar:
        return ()
    return tuple(sorted(matched))
|
||||
|
||||
|
||||
def all_scheduled_stop_ids(session: Session, dataset_id: int) -> set[str]:
    """Return every distinct stop id referenced by the dataset's stop_times.

    A missing sidecar yields an empty set.
    """
    dataset = session.get(Dataset, dataset_id)
    if not stop_times_are_sidecar(dataset):
        statement = (
            select(GtfsStopTime.stop_id)
            .where(GtfsStopTime.dataset_id == dataset_id)
            .group_by(GtfsStopTime.stop_id)
        )
        return {str(value) for value in session.scalars(statement).all()}
    try:
        with sidecar_connection(dataset) as connection:
            records = connection.execute("SELECT stop_id FROM gtfs_stop_times GROUP BY stop_id").fetchall()
            return {str(record["stop_id"]) for record in records}
    except MissingGtfsSidecar:
        return set()
|
||||
|
||||
|
||||
def scheduled_stop_ids_by_dataset(session: Session, dataset_ids: Sequence[int]) -> dict[int, set[str]]:
    """Map each dataset id (coerced to ``int``) to its set of scheduled stop ids."""
    result: dict[int, set[str]] = {}
    for dataset_id in dataset_ids:
        result[int(dataset_id)] = all_scheduled_stop_ids(session, int(dataset_id))
    return result
|
||||
|
||||
|
||||
def has_scheduled_stop(session: Session, dataset_id: int, stop_id: str) -> bool:
    """Return whether ``stop_id`` appears in the dataset's stop_times."""
    matches = scheduled_stop_ids(session, dataset_id, [stop_id])
    return len(matches) > 0
|
||||
|
||||
|
||||
def stop_times_by_trip(
    session: Session,
    dataset_id: int,
    trip_ids: Sequence[str],
) -> dict[str, list[GtfsStopTime]]:
    """Load stop_times for ``trip_ids``, grouped by trip id.

    Each trip's list is ordered by ``stop_sequence``.  Trip ids are queried in
    chunks of ``SQLITE_IN_CHUNK_SIZE``.  Sidecar rows are converted with
    ``stop_time_from_row``.  A missing sidecar yields an empty dict.
    """
    if not trip_ids:
        return {}
    by_trip: dict[str, list[GtfsStopTime]] = {}
    dataset = session.get(Dataset, dataset_id)
    wanted = [str(trip_id) for trip_id in trip_ids]

    if not stop_times_are_sidecar(dataset):
        for batch in _chunks(wanted, SQLITE_IN_CHUNK_SIZE):
            statement = (
                select(GtfsStopTime)
                .where(GtfsStopTime.dataset_id == dataset_id, GtfsStopTime.trip_id.in_(batch))
                .order_by(GtfsStopTime.trip_id, GtfsStopTime.stop_sequence)
            )
            for stop_time in session.scalars(statement).all():
                by_trip.setdefault(stop_time.trip_id, []).append(stop_time)
        return by_trip

    column_sql = ", ".join(GTFS_STOP_TIME_COLUMNS)
    try:
        with sidecar_connection(dataset) as connection:
            for batch in _chunks(wanted, SQLITE_IN_CHUNK_SIZE):
                placeholders = ", ".join("?" for _ in batch)
                records = connection.execute(
                    f"""
                    SELECT {column_sql}
                    FROM gtfs_stop_times
                    WHERE trip_id IN ({placeholders})
                    ORDER BY trip_id, stop_sequence
                    """,
                    list(batch),
                ).fetchall()
                for record in records:
                    stop_time = stop_time_from_row(dataset_id, record)
                    by_trip.setdefault(stop_time.trip_id, []).append(stop_time)
    except MissingGtfsSidecar:
        return {}
    return by_trip
|
||||
|
||||
|
||||
def stop_times_for_trip_range(
    session: Session,
    dataset_id: int,
    trip_id: str,
    start_sequence: int,
    end_sequence: int,
) -> list[GtfsStopTime]:
    """Load one trip's stop_times within an inclusive ``stop_sequence`` range.

    Results are ordered by ``stop_sequence``.  A missing sidecar yields an
    empty list.
    """
    dataset = session.get(Dataset, dataset_id)
    if not stop_times_are_sidecar(dataset):
        statement = (
            select(GtfsStopTime)
            .where(
                GtfsStopTime.dataset_id == dataset_id,
                GtfsStopTime.trip_id == trip_id,
                GtfsStopTime.stop_sequence >= start_sequence,
                GtfsStopTime.stop_sequence <= end_sequence,
            )
            .order_by(GtfsStopTime.stop_sequence)
        )
        return list(session.scalars(statement).all())

    column_sql = ", ".join(GTFS_STOP_TIME_COLUMNS)
    try:
        with sidecar_connection(dataset) as connection:
            records = connection.execute(
                f"""
                SELECT {column_sql}
                FROM gtfs_stop_times
                WHERE trip_id = ?
                AND stop_sequence >= ?
                AND stop_sequence <= ?
                ORDER BY stop_sequence
                """,
                (trip_id, int(start_sequence), int(end_sequence)),
            ).fetchall()
            return [stop_time_from_row(dataset_id, record) for record in records]
    except MissingGtfsSidecar:
        return []
|
||||
|
||||
|
||||
def stop_time_from_row(dataset_id: int, row) -> GtfsStopTime:
    """Build a ``GtfsStopTime`` from a sidecar row addressed by column name.

    ``trip_id`` and ``stop_id`` are coerced to ``str`` and ``stop_sequence`` to
    ``int``.  The time columns are passed through unchanged.  The object is
    not added to any session here.
    """
    trip_id = str(row["trip_id"])
    stop_id = str(row["stop_id"])
    sequence = int(row["stop_sequence"])
    passthrough = {
        column: row[column]
        for column in ("arrival_time", "departure_time", "arrival_seconds", "departure_seconds")
    }
    return GtfsStopTime(
        dataset_id=dataset_id,
        trip_id=trip_id,
        stop_id=stop_id,
        stop_sequence=sequence,
        **passthrough,
    )
|
||||
|
||||
|
||||
def execute_sidecar_query(session: Session, dataset_id: int, sql: str, params: Sequence[object]) -> list[sqlite3.Row]:
    """Run ``sql`` with ``params`` against the dataset's sidecar and return all rows.

    A missing sidecar yields an empty list.

    Raises:
        ValueError: If the dataset does not store stop_times in a sidecar.
    """
    dataset = session.get(Dataset, dataset_id)
    if not stop_times_are_sidecar(dataset):
        raise ValueError(f"dataset #{dataset_id} does not use sidecar stop_times")
    try:
        with sidecar_connection(dataset) as connection:
            cursor = connection.execute(sql, list(params))
            return list(cursor.fetchall())
    except MissingGtfsSidecar:
        return []
|
||||
|
||||
|
||||
def _chunks[T](items: Sequence[T], size: int) -> Iterator[Sequence[T]]:
|
||||
for index in range(0, len(items), size):
|
||||
yield items[index : index + size]
|
||||
394
app/harmonization.py
Normal file
394
app/harmonization.py
Normal file
@@ -0,0 +1,394 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import date, datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import and_, func, select
|
||||
from sqlalchemy.orm import Session, aliased
|
||||
|
||||
from app.data_management import dataset_row_counts
|
||||
from app.models import (
|
||||
CanonicalStopLink,
|
||||
Dataset,
|
||||
GtfsCalendar,
|
||||
GtfsCalendarDate,
|
||||
GtfsRoute,
|
||||
GtfsStop,
|
||||
GtfsStopTime,
|
||||
GtfsTrip,
|
||||
RouteMatch,
|
||||
Source,
|
||||
)
|
||||
|
||||
|
||||
GTFS_QA_NOTE_PREFIX = "[GTFS QA]"
|
||||
|
||||
|
||||
def gtfs_harmonization_inventory(session: Session) -> dict[str, Any]:
    """Return the inventory of all GTFS sources plus aggregate counters.

    The summary counts the sources, the sources with an active dataset, all
    datasets, and the feeds per QA status (ready / needs_review / blocked).
    """
    feeds = [_feed_inventory_item(session, source) for source in _gtfs_sources(session)]

    def _with_status(status: str) -> int:
        return sum(1 for feed in feeds if feed["qa_status"] == status)

    return {
        "summary": {
            "sources": len(feeds),
            "active_sources": sum(1 for feed in feeds if feed["active_dataset"] is not None),
            "datasets": sum(len(feed["datasets"]) for feed in feeds),
            "ready": _with_status("ready"),
            "needs_review": _with_status("needs_review"),
            "blocked": _with_status("blocked"),
        },
        "feeds": feeds,
    }
|
||||
|
||||
|
||||
def gtfs_harmonization_feed_detail(session: Session, source_id: int) -> dict[str, Any] | None:
    """Return the inventory item of one GTFS source plus its display sections.

    Returns ``None`` when the source does not exist or is not of kind "gtfs".
    """
    source = session.get(Source, source_id)
    is_gtfs_source = source is not None and source.kind == "gtfs"
    if not is_gtfs_source:
        return None
    feed = _feed_inventory_item(session, source)
    detail = dict(feed)
    detail["sections"] = _feed_sections(feed)
    return detail
|
||||
|
||||
|
||||
def _gtfs_sources(session: Session) -> list[Source]:
    """Return every source of kind "gtfs", ordered by country, priority, name and id."""
    ordering = (Source.country, Source.priority, Source.name, Source.id)
    query = select(Source).where(Source.kind == "gtfs").order_by(*ordering)
    return session.scalars(query).all()
|
||||
|
||||
|
||||
def _feed_inventory_item(session: Session, source: Source) -> dict[str, Any]:
    """Assemble the QA inventory entry for one GTFS source.

    The entry combines the source payload, the row counts of the active
    dataset, the validation / service-horizon / overlap / license checks and
    the overall ``qa_status`` derived from all collected issues.

    Args:
        session: Database session used for the count and QA queries.
        source: The GTFS source to describe.

    Returns:
        A dict with the keys ``source``, ``active_dataset``, ``datasets``,
        ``counts``, ``validation``, ``service``, ``overlap``, ``license``,
        ``issues`` and ``qa_status``.
    """
    # Active datasets sort first, then by creation time and id.
    datasets = sorted(
        (dataset for dataset in source.datasets if dataset.kind == "gtfs"),
        key=lambda item: (not item.is_active, item.created_at, item.id),
    )
    active_dataset = next((dataset for dataset in datasets if dataset.is_active), None)
    counts = dataset_row_counts(session, active_dataset.id, active_dataset.kind) if active_dataset is not None else {}
    validation = _validate_gtfs_dataset(session, source, active_dataset, counts)
    overlap = _overlap_summary(session, active_dataset)
    service = _service_horizon(session, active_dataset)
    issues = [*validation["issues"], *service["issues"], *overlap["issues"], *_license_issues(source)]
    qa_status = _qa_status(issues, active_dataset)
    # Reuse the counts already loaded for the active dataset instead of running
    # the row-count queries for it a second time.
    dataset_payloads = [
        _dataset_payload(
            dataset,
            counts if dataset is active_dataset else dataset_row_counts(session, dataset.id, dataset.kind),
        )
        for dataset in datasets
    ]
    return {
        "source": _source_payload(source),
        "active_dataset": None if active_dataset is None else _dataset_payload(active_dataset, counts),
        "datasets": dataset_payloads,
        "counts": counts,
        "validation": validation,
        "service": service,
        "overlap": overlap,
        "license": _license_payload(source),
        "issues": issues,
        "qa_status": qa_status,
    }
|
||||
|
||||
|
||||
def _source_payload(source: Source) -> dict[str, Any]:
    """Serialize a source into a plain dict, including its parsed QA review note."""
    payload: dict[str, Any] = {
        field: getattr(source, field)
        for field in (
            "id",
            "name",
            "country",
            "license",
            "priority",
            "mode_scope",
            "source_basis",
            "status",
            "enabled",
            "last_error",
        )
    }
    payload["last_run_at"] = _iso(source.last_run_at)
    payload["url"] = source.url
    payload["catalog_entry_id"] = source.catalog_entry_id
    payload["notes"] = source.notes
    payload["qa_review"] = _qa_review_payload(source.notes)
    return payload
|
||||
|
||||
|
||||
def _dataset_payload(dataset: Dataset, counts: dict[str, Any]) -> dict[str, Any]:
    """Serialize a dataset into a plain dict and attach the given row counts."""
    payload: dict[str, Any] = {
        field: getattr(dataset, field)
        for field in ("id", "kind", "is_active", "status", "sha256", "local_path")
    }
    payload["created_at"] = _iso(dataset.created_at)
    payload["counts"] = counts
    return payload
|
||||
|
||||
|
||||
def _validate_gtfs_dataset(session: Session, source: Source, dataset: Dataset | None, counts: dict[str, Any]) -> dict[str, Any]:
    """Run the structural GTFS checks for the active dataset of a source.

    Args:
        session: Database session used for the count queries.
        source: Owning source (not read by the checks).
        dataset: Active dataset, or ``None`` when the source has none.
        counts: Row counts previously loaded for ``dataset``.

    Returns:
        A dict with ``status`` (QA status derived from the issues), ``items``
        (display metrics) and ``issues`` (issue dicts).
    """
    if dataset is None:
        no_dataset = _issue("missing_active_dataset", "bad", "No active GTFS dataset", "Import this source before harmonization.")
        return {"status": "blocked", "items": [], "issues": [no_dataset]}

    def tally(key: str) -> Any:
        return counts.get(key, 0)

    def tone_when(flagged: Any, tone: str) -> str:
        return tone if flagged else "good"

    # (label, counts key, tone used when the table is empty)
    table_metrics = (
        ("Agencies", "agencies", "bad"),
        ("Stops", "stops", "bad"),
        ("Routes", "routes", "bad"),
        ("Trips", "trips", "bad"),
        ("Stop times", "stop_times", "bad"),
        ("Shapes", "shapes", "warn"),
    )
    items = [
        _metric(label, tally(key), "good" if tally(key) else empty_tone)
        for label, key, empty_tone in table_metrics
    ]

    missing_coords = _count(session, GtfsStop, dataset.id, (GtfsStop.lat.is_(None) | GtfsStop.lon.is_(None)))
    out_of_range = (GtfsStop.lat < -90) | (GtfsStop.lat > 90) | (GtfsStop.lon < -180) | (GtfsStop.lon > 180)
    invalid_coords = _count(session, GtfsStop, dataset.id, out_of_range)
    routes_without_trips = _routes_without_trips(session, dataset.id)
    trips_without_stop_times = _trips_without_stop_times(session, dataset.id)
    stop_times_without_seconds = _stop_times_without_seconds(session, dataset.id)
    route_geometry_missing = _count(session, GtfsRoute, dataset.id, GtfsRoute.geometry_geojson.is_(None))
    canonical_links = _count(session, CanonicalStopLink, dataset.id, CanonicalStopLink.object_type == "gtfs_stop")
    raw_match_counts = counts.get("match_counts")
    match_counts = raw_match_counts if isinstance(raw_match_counts, dict) else {}

    items += [
        _metric("Stops missing coordinates", missing_coords, tone_when(missing_coords, "bad")),
        _metric("Stops with invalid coordinates", invalid_coords, tone_when(invalid_coords, "bad")),
        _metric("Routes without trips", routes_without_trips, tone_when(routes_without_trips, "bad")),
        _metric("Trips without stop_times", trips_without_stop_times, tone_when(trips_without_stop_times, "bad")),
        _metric("Stop times without parsed seconds", stop_times_without_seconds, tone_when(stop_times_without_seconds, "warn")),
        _metric("Routes without geometry", route_geometry_missing, tone_when(route_geometry_missing, "warn")),
        _metric("Canonical stop links", canonical_links, tone_when(tally("stops") and canonical_links == 0, "warn")),
        _metric("Route matches", tally("matches"), tone_when(tally("routes") and not tally("matches"), "warn")),
    ]

    issues: list[dict[str, str]] = []
    if counts.get("missing_sidecar"):
        issues.append(_issue("missing_sidecar", "bad", "GTFS sidecar is missing", "Queue a recovery import for this dataset."))

    required_tables = {
        "agencies": "No agencies imported",
        "stops": "No stops imported",
        "routes": "No routes imported",
        "trips": "No trips imported",
        "stop_times": "No stop_times imported",
    }
    issues.extend(
        _issue(f"missing_{key}", "bad", label, "Required GTFS content is absent or failed to import.")
        for key, label in required_tables.items()
        if not tally(key)
    )

    # (trigger, issue arguments) pairs, kept in reporting order.
    conditional_issues = (
        (missing_coords, ("missing_stop_coordinates", "bad", f"{missing_coords:,} stops have no coordinates", "Stop coordinates are required for deduplication and routing access.")),
        (invalid_coords, ("invalid_stop_coordinates", "bad", f"{invalid_coords:,} stops have invalid coordinates", "Fix or exclude invalid stop coordinates before publication.")),
        (routes_without_trips, ("routes_without_trips", "warn", f"{routes_without_trips:,} routes have no trips", "These routes cannot contribute timetable service.")),
        (trips_without_stop_times, ("trips_without_stop_times", "bad", f"{trips_without_stop_times:,} trips have no stop_times", "These trips cannot be routed.")),
        (route_geometry_missing, ("route_geometry_missing", "warn", f"{route_geometry_missing:,} routes have no geometry", "Use GTFS shapes, route-layer matching, or stop-by-stop fallback.")),
        (tally("routes") and not tally("shapes"), ("missing_shapes", "warn", "No GTFS shapes imported", "OSM route matching or generated geometry will be needed.")),
        (tally("routes") and not match_counts, ("no_route_matching", "warn", "No route-match rows", "Run route matching before route-layer publication QA.")),
    )
    issues.extend(_issue(*arguments) for triggered, arguments in conditional_issues if triggered)

    return {"status": _qa_status(issues, dataset), "items": items, "issues": issues}
|
||||
|
||||
|
||||
def _service_horizon(session: Session, dataset: Dataset | None) -> dict[str, Any]:
    """Summarize the first and last service day covered by a dataset.

    The horizon spans the earliest and latest values found in the calendar
    and calendar-date rows of the dataset. A missing end date or an end date
    in the past is a "bad" issue; an end date less than 30 days away is a
    "warn" issue. "Today" is the current UTC date.
    """
    if dataset is None:
        return {"start_date": None, "end_date": None, "days_until_end": None, "items": [], "issues": []}

    calendar_bounds = select(func.min(GtfsCalendar.start_date), func.max(GtfsCalendar.end_date)).where(
        GtfsCalendar.dataset_id == dataset.id
    )
    exception_bounds = select(func.min(GtfsCalendarDate.date), func.max(GtfsCalendarDate.date)).where(
        GtfsCalendarDate.dataset_id == dataset.id
    )
    cal_min, cal_max = session.execute(calendar_bounds).one()
    date_min, date_max = session.execute(exception_bounds).one()

    start_date = _gtfs_date(_min_int(cal_min, date_min))
    end_date = _gtfs_date(_max_int(cal_max, date_max))
    today = datetime.now(timezone.utc).date()

    issues: list[dict[str, str]] = []
    if end_date is None:
        days_until_end = None
        end_tone = "bad"
        issues.append(_issue("service_horizon_missing", "bad", "No service calendar horizon", "calendar.txt or calendar_dates.txt is required for reliable routing."))
    else:
        days_until_end = (end_date - today).days
        if days_until_end < 0:
            end_tone = "bad"
            issues.append(_issue("service_horizon_expired", "bad", f"Service expired {abs(days_until_end):,} days ago", "Update or exclude this feed."))
        elif days_until_end < 30:
            end_tone = "warn"
            issues.append(_issue("service_horizon_short", "warn", f"Service ends in {days_until_end:,} days", "Update cadence is too close for publication confidence."))
        else:
            end_tone = "good"

    start_iso = None if start_date is None else start_date.isoformat()
    end_iso = None if end_date is None else end_date.isoformat()
    return {
        "start_date": start_iso,
        "end_date": end_iso,
        "days_until_end": days_until_end,
        "items": [
            _metric("Service starts", start_iso or "n/a", "info"),
            _metric("Service ends", end_iso or "n/a", end_tone),
        ],
        "issues": issues,
    }
|
||||
|
||||
|
||||
def _overlap_summary(session: Session, dataset: Dataset | None) -> dict[str, Any]:
    """Report how many route keys and canonical stops a dataset shares with other active feeds.

    Each non-zero overlap produces a "warn" issue. A missing dataset yields
    empty ``items`` and ``issues``.
    """
    if dataset is None:
        return {"items": [], "issues": []}

    route_key_overlaps = _shared_route_keys(session, dataset.id)
    canonical_stop_overlaps = _shared_canonical_stops(session, dataset.id)

    def tone(overlap_count: int) -> str:
        return "warn" if overlap_count else "good"

    items = [
        _metric("Shared route keys", route_key_overlaps, tone(route_key_overlaps)),
        _metric("Shared canonical stops", canonical_stop_overlaps, tone(canonical_stop_overlaps)),
    ]
    issues: list[dict[str, str]] = []
    if route_key_overlaps:
        issues.append(_issue("shared_route_keys", "warn", f"{route_key_overlaps:,} route keys also exist in another active feed", "Deduplicate or rank source authority for overlapping routes."))
    if canonical_stop_overlaps:
        issues.append(_issue("shared_canonical_stops", "warn", f"{canonical_stop_overlaps:,} canonical stops are shared with another active feed", "This is useful linking evidence, but conflicts need review."))
    return {"items": items, "issues": issues}
|
||||
|
||||
|
||||
def _license_payload(source: Source) -> dict[str, Any]:
|
||||
text = (source.license or "").strip()
|
||||
unknown = not text or "unknown" in text.lower()
|
||||
return {
|
||||
"label": text or "unknown",
|
||||
"redistribution_status": "unknown" if unknown else "review_required",
|
||||
"tone": "warn" if unknown else "info",
|
||||
}
|
||||
|
||||
|
||||
def _license_issues(source: Source) -> list[dict[str, str]]:
    """Return a single "warn" issue when the redistribution status of the source is unknown."""
    status = _license_payload(source)["redistribution_status"]
    if status != "unknown":
        return []
    return [_issue("license_unknown", "warn", "License/redistribution status is unknown", "Publication needs explicit import, derivation, redistribution, and attribution flags.")]
|
||||
|
||||
|
||||
def _qa_review_payload(notes: str | None) -> dict[str, Any]:
    """Extract the QA review state stored in the notes of a source.

    Only the first line starting with ``GTFS_QA_NOTE_PREFIX`` is used. The
    text after the prefix is split on ";" into ``key=value`` pairs, and parts
    without "=" are ignored. Missing values fall back to status "unreviewed",
    an empty note and ``updated_at`` of ``None``.
    """
    if not notes:
        return {"status": "unreviewed", "note": "", "updated_at": None}

    qa_line = next(
        (line for line in str(notes).splitlines() if line.startswith(GTFS_QA_NOTE_PREFIX)),
        None,
    )
    if qa_line is None:
        return {"status": "unreviewed", "note": "", "updated_at": None}

    fields: dict[str, str] = {}
    for part in qa_line[len(GTFS_QA_NOTE_PREFIX) :].strip().split(";"):
        key, separator, value = part.partition("=")
        if separator:
            fields[key.strip()] = value.strip()
    return {
        "status": fields.get("status") or "unreviewed",
        "note": fields.get("note") or "",
        "updated_at": fields.get("updated_at"),
    }
|
||||
|
||||
|
||||
def _routes_without_trips(session: Session, dataset_id: int) -> int:
    """Count the routes of a dataset that have no trip with the same ``route_id`` in that dataset."""
    has_trip = (
        select(GtfsTrip.id)
        .where(GtfsTrip.dataset_id == dataset_id, GtfsTrip.route_id == GtfsRoute.route_id)
        .exists()
    )
    query = select(func.count()).select_from(GtfsRoute).where(GtfsRoute.dataset_id == dataset_id, ~has_trip)
    total = session.scalar(query)
    return int(total or 0)
|
||||
|
||||
|
||||
def _trips_without_stop_times(session: Session, dataset_id: int) -> int:
    """Count the trips of a dataset that have no ``GtfsStopTime`` row with the same ``trip_id`` in that dataset."""
    has_stop_time = (
        select(GtfsStopTime.id)
        .where(GtfsStopTime.dataset_id == dataset_id, GtfsStopTime.trip_id == GtfsTrip.trip_id)
        .exists()
    )
    query = select(func.count()).select_from(GtfsTrip).where(GtfsTrip.dataset_id == dataset_id, ~has_stop_time)
    total = session.scalar(query)
    return int(total or 0)
|
||||
|
||||
|
||||
def _stop_times_without_seconds(session: Session, dataset_id: int) -> int:
    """Count the stop_time rows of a dataset whose arrival and departure seconds are both NULL."""
    neither_time_parsed = (
        GtfsStopTime.dataset_id == dataset_id,
        GtfsStopTime.arrival_seconds.is_(None),
        GtfsStopTime.departure_seconds.is_(None),
    )
    query = select(func.count()).select_from(GtfsStopTime).where(*neither_time_parsed)
    total = session.scalar(query)
    return int(total or 0)
|
||||
|
||||
|
||||
def _shared_route_keys(session: Session, dataset_id: int) -> int:
    """Count the distinct non-empty route keys of a dataset that also occur in a different active GTFS dataset."""
    mine = aliased(GtfsRoute)
    theirs = aliased(GtfsRoute)
    their_dataset = aliased(Dataset)
    same_key_elsewhere = and_(theirs.route_key == mine.route_key, theirs.dataset_id != mine.dataset_id)
    query = (
        select(func.count(func.distinct(mine.route_key)))
        .select_from(mine)
        .join(theirs, same_key_elsewhere)
        .join(their_dataset, their_dataset.id == theirs.dataset_id)
        .where(
            mine.dataset_id == dataset_id,
            mine.route_key.is_not(None),
            mine.route_key != "",
            their_dataset.kind == "gtfs",
            their_dataset.is_active.is_(True),
        )
    )
    total = session.scalar(query)
    return int(total or 0)
|
||||
|
||||
|
||||
def _shared_canonical_stops(session: Session, dataset_id: int) -> int:
    """Count the distinct canonical stops linked from a dataset's GTFS stops that a different active GTFS dataset also links to."""
    mine = aliased(CanonicalStopLink)
    theirs = aliased(CanonicalStopLink)
    their_dataset = aliased(Dataset)
    same_stop_elsewhere = and_(
        theirs.canonical_stop_id == mine.canonical_stop_id,
        theirs.dataset_id != mine.dataset_id,
    )
    query = (
        select(func.count(func.distinct(mine.canonical_stop_id)))
        .select_from(mine)
        .join(theirs, same_stop_elsewhere)
        .join(their_dataset, their_dataset.id == theirs.dataset_id)
        .where(
            mine.dataset_id == dataset_id,
            mine.object_type == "gtfs_stop",
            theirs.object_type == "gtfs_stop",
            their_dataset.kind == "gtfs",
            their_dataset.is_active.is_(True),
        )
    )
    total = session.scalar(query)
    return int(total or 0)
|
||||
|
||||
|
||||
def _count(session: Session, model: Any, dataset_id: int, *criteria: Any) -> int:
    """Count the rows of ``model`` in a dataset, narrowed by any extra filter criteria."""
    query = select(func.count()).select_from(model).where(model.dataset_id == dataset_id, *criteria)
    total = session.scalar(query)
    return int(total or 0)
|
||||
|
||||
|
||||
def _metric(label: str, value: Any, tone: str = "info", description: str = "") -> dict[str, Any]:
|
||||
return {"label": label, "value": value, "tone": tone, "description": description}
|
||||
|
||||
|
||||
def _issue(issue_id: str, severity: str, title: str, detail: str) -> dict[str, str]:
|
||||
return {"id": issue_id, "severity": severity, "title": title, "detail": detail}
|
||||
|
||||
|
||||
def _qa_status(issues: list[dict[str, str]], dataset: Dataset | None) -> str:
|
||||
if dataset is None or any(issue.get("severity") == "bad" for issue in issues):
|
||||
return "blocked"
|
||||
if any(issue.get("severity") == "warn" for issue in issues):
|
||||
return "needs_review"
|
||||
return "ready"
|
||||
|
||||
|
||||
def _feed_sections(feed: dict[str, Any]) -> list[dict[str, Any]]:
    """Group the metrics of a feed inventory item into titled display sections."""
    license_info = feed["license"]
    license_tone = license_info["tone"]
    license_items = [
        _metric("Redistribution", license_info["redistribution_status"], license_tone),
        _metric("License", license_info["label"], license_tone),
    ]
    sections = [
        {"id": section_id, "title": title, "items": feed[section_id]["items"]}
        for section_id, title in (
            ("validation", "GTFS Validation"),
            ("service", "Service Horizon"),
            ("overlap", "Overlap and Deduplication"),
        )
    ]
    sections.append({"id": "license", "title": "License", "items": license_items})
    return sections
|
||||
|
||||
|
||||
def _gtfs_date(value: int | None) -> date | None:
|
||||
if value is None:
|
||||
return None
|
||||
try:
|
||||
return datetime.strptime(str(int(value)), "%Y%m%d").date()
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
|
||||
def _min_int(*values: int | None) -> int | None:
|
||||
clean = [int(value) for value in values if value is not None]
|
||||
return min(clean) if clean else None
|
||||
|
||||
|
||||
def _max_int(*values: int | None) -> int | None:
|
||||
clean = [int(value) for value in values if value is not None]
|
||||
return max(clean) if clean else None
|
||||
|
||||
|
||||
def _iso(value: datetime | None) -> str | None:
|
||||
return None if value is None else value.isoformat()
|
||||
360
app/itineraries.py
Normal file
360
app/itineraries.py
Normal file
@@ -0,0 +1,360 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.journey import duration_minutes_ceil, find_journeys, format_duration_label
|
||||
from app.models import Itinerary, ItineraryLeg, TravelRequest
|
||||
from app.routing import route_between_points
|
||||
|
||||
|
||||
def generate_itineraries(
    db: Session,
    *,
    from_stop_id: str,
    to_stop_id: str,
    via_stop_id: str | None,
    departure: str,
    service_date: str | None,
    max_transfers: int,
    transfer_seconds: int,
    limit: int,
    source_ids: list[int] | None,
    preferences: dict[str, Any] | None = None,
) -> dict:
    """Store a travel request and the candidate itineraries generated for it.

    Steps, in order: persist the ``TravelRequest``; run ``find_journeys`` and
    store one public-transport itinerary (with legs) per journey; try to add
    a routed car itinerary; add placeholder itineraries, including a car
    placeholder only when no routed car itinerary was produced.

    Negative ``max_transfers`` / ``transfer_seconds`` are clamped to 0. The
    session is flushed but not committed here.

    Returns:
        A dict with the serialized ``request``, the ``journey_context``
        (from / to / via / sources taken from the journey result) and the
        serialized ``itineraries``.
    """
    transfer_cap = max(0, max_transfers)
    transfer_buffer = max(0, transfer_seconds)
    source_filter = ",".join(str(source_id) for source_id in source_ids or []) or None

    request = TravelRequest(
        origin_stop_id=from_stop_id,
        destination_stop_id=to_stop_id,
        via_stop_id=via_stop_id or None,
        departure_time=departure,
        service_date=service_date or None,
        max_transfers=transfer_cap,
        transfer_seconds=transfer_buffer,
        source_filter=source_filter,
        preferences_json=json.dumps(preferences or {}, separators=(",", ":")),
    )
    db.add(request)
    db.flush()  # assigns request.id for the itineraries created below

    journey_result = find_journeys(
        db=db,
        from_stop_id=from_stop_id,
        to_stop_id=to_stop_id,
        via_stop_id=via_stop_id,
        departure=departure,
        service_date=service_date,
        max_transfers=transfer_cap,
        transfer_seconds=transfer_buffer,
        limit=limit,
        source_ids=source_ids,
    )

    itineraries: list[Itinerary] = []
    for position, journey in enumerate(journey_result.get("journeys", []), start=1):
        transit_option = _journey_itinerary(request.id, journey, position)
        db.add(transit_option)
        db.flush()  # assigns the itinerary id needed by its legs
        _add_journey_legs(db, transit_option.id, journey)
        itineraries.append(transit_option)

    car_itinerary = _car_itinerary(db, request.id, journey_result.get("from"), journey_result.get("to"))
    if car_itinerary is not None:
        db.add(car_itinerary)
        db.flush()
        _add_routing_leg(db, car_itinerary.id, car_itinerary)
        itineraries.append(car_itinerary)

    placeholders = _placeholder_itineraries(
        request.id,
        journey_result.get("from"),
        journey_result.get("to"),
        service_date=service_date,
        include_car=car_itinerary is None,
    )
    for placeholder in placeholders:
        db.add(placeholder)
        db.flush()
        itineraries.append(placeholder)

    db.flush()
    journey_context = {
        "from": journey_result.get("from"),
        "to": journey_result.get("to"),
        "via": journey_result.get("via"),
        "sources": journey_result.get("sources", []),
    }
    return {
        "request": travel_request_payload(request),
        "journey_context": journey_context,
        "itineraries": [itinerary_payload(db, itinerary) for itinerary in itineraries],
    }
|
||||
|
||||
|
||||
def travel_request_payload(request: TravelRequest) -> dict[str, Any]:
    """Serialize a travel request, decoding its stored preferences JSON."""
    created_at = request.created_at
    payload: dict[str, Any] = {
        field: getattr(request, field)
        for field in (
            "id",
            "origin_stop_id",
            "destination_stop_id",
            "via_stop_id",
            "departure_time",
            "service_date",
            "max_transfers",
            "transfer_seconds",
            "source_filter",
        )
    }
    payload["preferences"] = _json_dict(request.preferences_json)
    payload["created_at"] = created_at.isoformat() if created_at else None
    return payload
|
||||
|
||||
|
||||
def itinerary_payload(db: Session, itinerary: Itinerary) -> dict[str, Any]:
    """Serialize an itinerary together with its legs, loaded in sequence order."""
    legs_query = (
        select(ItineraryLeg)
        .where(ItineraryLeg.itinerary_id == itinerary.id)
        .order_by(ItineraryLeg.sequence)
    )
    legs = db.scalars(legs_query).all()
    created_at = itinerary.created_at
    updated_at = itinerary.updated_at
    return {
        "id": itinerary.id,
        "request_id": itinerary.request_id,
        "title": itinerary.title,
        "family": itinerary.family,
        "status": itinerary.status,
        "saved": itinerary.saved,
        "summary": _json_dict(itinerary.summary_json),
        "score": _json_dict(itinerary.score_json),
        "payload": _json_dict(itinerary.payload_json),
        "legs": [itinerary_leg_payload(leg) for leg in legs],
        "created_at": created_at.isoformat() if created_at else None,
        "updated_at": updated_at.isoformat() if updated_at else None,
    }
|
||||
|
||||
|
||||
def itinerary_leg_payload(leg: ItineraryLeg) -> dict[str, Any]:
    """Serialize one itinerary leg, decoding its stored payload JSON."""
    payload: dict[str, Any] = {
        field: getattr(leg, field)
        for field in (
            "id",
            "itinerary_id",
            "sequence",
            "mode",
            "route_ref",
            "route_name",
            "from_name",
            "to_name",
            "departure_time",
            "arrival_time",
            "locked",
        )
    }
    payload["payload"] = _json_dict(leg.payload_json)
    return payload
|
||||
|
||||
|
||||
def set_itinerary_saved(db: Session, itinerary: Itinerary, saved: bool) -> dict[str, Any]:
    """Mark an itinerary as saved or unsaved and return its refreshed payload.

    The status becomes "saved" or "candidate" accordingly, and ``updated_at``
    is set to the current UTC time. The session is flushed, not committed.
    """
    if saved:
        new_status = "saved"
    else:
        new_status = "candidate"
    itinerary.saved = saved
    itinerary.status = new_status
    itinerary.updated_at = datetime.now(timezone.utc)
    db.flush()
    return itinerary_payload(db, itinerary)
|
||||
|
||||
|
||||
def set_leg_locked(db: Session, leg: ItineraryLeg, locked: bool) -> dict[str, Any]:
    """Set the lock flag of a leg and return the serialized leg.

    When the parent itinerary exists, its ``updated_at`` is set to the
    current UTC time. The session is flushed, not committed.
    """
    leg.locked = locked
    parent = db.get(Itinerary, leg.itinerary_id)
    if parent is not None:
        parent.updated_at = datetime.now(timezone.utc)
    db.flush()
    return itinerary_leg_payload(leg)
|
||||
|
||||
|
||||
def recent_itineraries(db: Session, *, saved_only: bool = False, limit: int = 30) -> list[dict[str, Any]]:
    """Return the most recently updated itineraries as payload dicts.

    ``limit`` is clamped to the range 1..100. With ``saved_only`` only
    itineraries flagged as saved are returned.
    """
    row_cap = max(1, min(limit, 100))
    query = select(Itinerary).order_by(Itinerary.updated_at.desc(), Itinerary.id.desc())
    if saved_only:
        query = query.where(Itinerary.saved.is_(True))
    query = query.limit(row_cap)
    return [itinerary_payload(db, itinerary) for itinerary in db.scalars(query).all()]
|
||||
|
||||
|
||||
def _journey_itinerary(request_id: int, journey: dict, index: int) -> Itinerary:
    """Build the (unsaved, unpersisted) public-transport itinerary for one journey.

    ``index`` is only used to number the title. The full journey dict is
    stored in the payload JSON under the key "journey".
    """
    legs = journey.get("legs", [])
    summary = {
        key: journey.get(key)
        for key in ("departure_time", "arrival_time", "duration_minutes", "duration_label", "transfers")
    }
    summary["leg_count"] = len(legs)
    summary["route_refs"] = [leg.get("route_ref") or leg.get("route_id") for leg in legs]
    score = _journey_score(journey)

    def compact(data: dict) -> str:
        return json.dumps(data, separators=(",", ":"))

    return Itinerary(
        request_id=request_id,
        title=f"Public transport option {index}",
        family="public_transport",
        status="candidate",
        saved=False,
        summary_json=compact(summary),
        score_json=compact(score),
        payload_json=compact({"journey": journey}),
    )
|
||||
|
||||
|
||||
def _add_journey_legs(db: Session, itinerary_id: int, journey: dict) -> None:
    """Add one ``ItineraryLeg`` per journey leg to the session, numbered from 1.

    Endpoint names fall back to the stop id when a stop has no name.
    """
    for sequence, leg in enumerate(journey.get("legs", []), start=1):
        boarding = leg.get("from") or {}
        alighting = leg.get("to") or {}
        record = ItineraryLeg(
            itinerary_id=itinerary_id,
            sequence=sequence,
            mode=leg.get("mode"),
            route_ref=leg.get("route_ref"),
            route_name=leg.get("route_name"),
            from_name=boarding.get("name") or boarding.get("stop_id"),
            to_name=alighting.get("name") or alighting.get("stop_id"),
            departure_time=leg.get("departure_time"),
            arrival_time=leg.get("arrival_time"),
            locked=False,
            payload_json=json.dumps({"journey_leg": leg}, separators=(",", ":")),
        )
        db.add(record)
|
||||
|
||||
|
||||
def _car_itinerary(db: Session, request_id: int, from_stop: dict | None, to_stop: dict | None) -> Itinerary | None:
    """Build a car-only itinerary by routing between the two stop coordinates.

    Returns ``None`` when either stop lacks usable coordinates or when the
    routing call raises; the car comparison is optional. The returned
    itinerary is not added to the session.
    """
    origin = from_stop or {}
    destination = to_stop or {}
    from_lon = _float_or_none(origin.get("lon"))
    from_lat = _float_or_none(origin.get("lat"))
    to_lon = _float_or_none(destination.get("lon"))
    to_lat = _float_or_none(destination.get("lat"))
    if any(coordinate is None for coordinate in (from_lon, from_lat, to_lon, to_lat)):
        return None

    try:
        route = route_between_points(
            db,
            from_lon=from_lon,
            from_lat=from_lat,
            to_lon=to_lon,
            to_lat=to_lat,
            mode="drive",
            max_visited=300_000,
        )
    except Exception:  # noqa: BLE001 - car comparison is optional
        return None

    duration_seconds = _float_or_none(route.get("duration_seconds"))
    duration_minutes = duration_minutes_ceil(duration_seconds)
    distance_m = _float_or_none(route.get("distance_m"))
    if distance_m is None:
        distance_km = None
    else:
        distance_km = round(distance_m / 1000, 1)

    summary = {
        "from": origin.get("name") or origin.get("stop_id") or "origin",
        "to": destination.get("name") or destination.get("stop_id") or "destination",
        "duration_minutes": duration_minutes,
        "duration_label": format_duration_label(duration_seconds),
        "distance_km": distance_km,
        "transfers": 0,
        "engine": route.get("engine"),
    }
    score = {
        "duration_minutes": duration_minutes,
        "transfers": 0,
        "complexity": 1,
        "emissions": "high",
        "estimated_cost": None,
    }

    def compact(data: dict) -> str:
        return json.dumps(data, separators=(",", ":"))

    return Itinerary(
        request_id=request_id,
        title="Car only",
        family="car",
        status="candidate",
        saved=False,
        summary_json=compact(summary),
        score_json=compact(score),
        payload_json=compact({"routing": route}),
    )
|
||||
|
||||
|
||||
def _add_routing_leg(db: Session, itinerary_id: int, itinerary: Itinerary) -> None:
    """Add the single road-routing leg of a routed itinerary to the session.

    Does nothing when the itinerary payload has no "routing" dict. Endpoint
    names come from the OSM node ids of the start and target nodes, falling
    back to "origin" / "destination".
    """
    routing = _json_dict(itinerary.payload_json).get("routing")
    if not isinstance(routing, dict):
        return
    start_node = routing.get("start_node") or {}
    target_node = routing.get("target_node") or {}
    leg = ItineraryLeg(
        itinerary_id=itinerary_id,
        sequence=1,
        mode=str(routing.get("mode") or "drive"),
        route_ref=None,
        route_name="Road route",
        from_name=str(start_node.get("osm_node_id") or "origin"),
        to_name=str(target_node.get("osm_node_id") or "destination"),
        departure_time=None,
        arrival_time=None,
        locked=False,
        payload_json=json.dumps({"routing_leg": routing}, separators=(",", ":")),
    )
    db.add(leg)
|
||||
|
||||
|
||||
def _placeholder_itineraries(
    request_id: int,
    from_stop: dict | None,
    to_stop: dict | None,
    *,
    service_date: str | None,
    include_car: bool = True,
) -> list[Itinerary]:
    """Build "placeholder" itineraries for travel families that have no connector yet.

    The car placeholder is listed first, and only when ``include_car`` is
    true. The returned itineraries are not added to the session.
    """
    origin = from_stop or {}
    destination = to_stop or {}
    from_name = origin.get("name") or origin.get("stop_id") or "origin"
    to_name = destination.get("name") or destination.get("stop_id") or "destination"

    # (family, title, note, score)
    families = []
    if include_car:
        families.append(("car", "Car only", "Needs road-routing connector", {"complexity": 1, "emissions": "high"}))
    families.extend(
        [
            ("car_ferry", "Car + ferry", "Needs ferry-port candidate graph", {"complexity": 3, "emissions": "medium_high"}),
            ("flight_access", "Flight + airport access", "Needs airport/flight schedule connector", {"complexity": 4, "emissions": "high"}),
            ("rail_long_stay", "Rail with adjustable city stop", "Use via stop and leg locking to refine", {"complexity": 3, "emissions": "low"}),
        ]
    )

    def compact(data: dict) -> str:
        return json.dumps(data, separators=(",", ":"))

    def build(family: str, title: str, note: str, score: dict) -> Itinerary:
        summary = {
            "from": from_name,
            "to": to_name,
            "service_date": service_date,
            "note": note,
            "duration_minutes": None,
            "transfers": None,
        }
        return Itinerary(
            request_id=request_id,
            title=title,
            family=family,
            status="placeholder",
            saved=False,
            summary_json=compact(summary),
            score_json=compact(score),
            payload_json=compact({"placeholder": True, "note": note}),
        )

    return [build(*entry) for entry in families]
|
||||
|
||||
|
||||
def _float_or_none(value: object) -> float | None:
|
||||
try:
|
||||
return None if value is None else float(value)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
|
||||
def _journey_score(journey: dict) -> dict[str, Any]:
|
||||
modes = [leg.get("mode") for leg in journey.get("legs", [])]
|
||||
duration = journey.get("duration_minutes")
|
||||
transfers = int(journey.get("transfers") or 0)
|
||||
railish = sum(1 for mode in modes if mode in {"train", "subway", "tram", "light_rail"})
|
||||
busish = sum(1 for mode in modes if mode in {"bus", "coach", "trolleybus"})
|
||||
emissions_hint = "low" if railish >= busish else "medium"
|
||||
return {
|
||||
"duration_minutes": duration,
|
||||
"transfers": transfers,
|
||||
"complexity": transfers + len(modes),
|
||||
"emissions": emissions_hint,
|
||||
"overnight": False,
|
||||
"estimated_cost": None,
|
||||
}
|
||||
|
||||
|
||||
def _json_dict(value: str | None) -> dict[str, Any]:
|
||||
try:
|
||||
data = json.loads(value or "{}")
|
||||
except json.JSONDecodeError:
|
||||
return {}
|
||||
return data if isinstance(data, dict) else {}
|
||||
1932
app/jobs.py
Normal file
1932
app/jobs.py
Normal file
File diff suppressed because it is too large
Load Diff
5385
app/journey.py
Normal file
5385
app/journey.py
Normal file
File diff suppressed because it is too large
Load Diff
717
app/journey_search.py
Normal file
717
app/journey_search.py
Normal file
@@ -0,0 +1,717 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import hashlib
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import date, datetime, timedelta, timezone
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.address_search import is_location_token
|
||||
from app.db import SessionLocal
|
||||
from app.journey import find_journeys, parse_service_date, resolve_location_summary
|
||||
from app.models import JourneySearchCache
|
||||
from app.routing import direct_route_between_points, route_between_points
|
||||
|
||||
|
||||
# Highest transfer count tried by the progressive transit search (stages 0..N).
MAX_PROGRESSIVE_TRANSFERS = 5
# Lifetime and size bound of the per-stage result cache.
TRANSIT_STAGE_CACHE_TTL_SECONDS = 5 * 60
TRANSIT_STAGE_CACHE_MAX_ENTRIES = 256
# Lifetime and size bound of the whole-search payload cache.
PROGRESSIVE_SEARCH_CACHE_TTL_SECONDS = 10 * 60
PROGRESSIVE_SEARCH_CACHE_MAX_ENTRIES = 128
# Mixed into every durable cache key, so changing it orphans previously stored rows.
JOURNEY_SEARCH_CACHE_VERSION = "journey-search-v7"
# Background pool that runs the searches submitted by start_journey_search.
_executor = ThreadPoolExecutor(max_workers=2, thread_name_prefix="journey-search")
# Guards the module-level dictionaries below (re-entrant).
_lock = threading.RLock()
# Search id -> state for every search that has not been pruned yet.
_searches: dict[str, "_SearchState"] = {}
# Cache key -> id of the search currently computing that key.
_progressive_search_inflight: dict[tuple[object, ...], str] = {}
# Cache key -> (monotonic expiry, payload) for single transfer stages.
_transit_stage_cache: dict[tuple[object, ...], tuple[float, dict[str, Any]]] = {}
# Cache key -> (monotonic expiry, payload) for completed searches.
_progressive_search_cache: dict[tuple[object, ...], tuple[float, dict[str, Any]]] = {}
|
||||
|
||||
|
||||
@dataclass
class _SearchState:
    """Mutable record of one progressive journey search.

    Instances are stored in ``_searches`` and read or updated by the
    ``_publish_*`` helpers and ``_payload``.
    """

    # Identifier handed back to the client.
    id: str
    # Copy of the request dictionary the search was started with.
    request: dict[str, Any]
    # Key used for the payload cache and the in-flight table.
    cache_key: tuple[object, ...] | None = None
    # Values assigned in this module: "queued", "running", "complete", "cancelled", "error".
    status: str = "queued"
    message: str = "Queued."
    stage: str = "queued"
    journeys: list[dict] = field(default_factory=list)
    # Set by _publish_routing for point-to-point (walk/drive) searches.
    routing: dict[str, Any] | None = None
    # Extra keys merged into the payload by _payload.
    context: dict[str, Any] = field(default_factory=dict)
    error: str | None = None
    # Wall-clock timestamps from time.time(); updated_at drives pruning.
    created_at: float = field(default_factory=time.time)
    updated_at: float = field(default_factory=time.time)
    complete: bool = False
    cancelled: bool = False
|
||||
|
||||
|
||||
def start_journey_search(request: dict[str, Any]) -> dict[str, Any]:
    """Register a search for *request* and return its current payload.

    Three outcomes are possible:

    * a cached payload exists: a completed state is registered from it and
      nothing is scheduled;
    * an unfinished, uncancelled search for the same cache key is already
      registered: that search's payload is returned and the new state is
      discarded;
    * otherwise the new state is registered as the in-flight search for the
      key and ``_run_search`` is submitted to the executor.
    """
    cache_key = _progressive_cache_key(request)
    cached_payload = _progressive_cache_get(cache_key)
    needs_run = cached_payload is None
    state = _SearchState(id=uuid.uuid4().hex, request=dict(request), cache_key=cache_key)
    if not needs_run:
        _apply_cached_payload(state, cached_payload)

    with _lock:
        _prune_old_searches()
        if needs_run:
            running_id = _progressive_search_inflight.get(cache_key)
            running = _searches.get(running_id) if running_id is not None else None
            if running is not None and not (running.complete or running.cancelled):
                # Piggy-back on the search that is already working on this key.
                return _payload(running)
            _progressive_search_inflight[cache_key] = state.id
        _searches[state.id] = state

    if needs_run:
        _executor.submit(_run_search, state.id)
    return journey_search_payload(state.id)
|
||||
|
||||
|
||||
def journey_search_payload(search_id: str) -> dict[str, Any]:
    """Return the payload snapshot of the search registered under *search_id*.

    Raises:
        KeyError: no search with that id is registered.
    """
    with _lock:
        if search_id not in _searches:
            raise KeyError(search_id)
        return _payload(_searches[search_id])
|
||||
|
||||
|
||||
def cancel_journey_search(search_id: str) -> dict[str, Any]:
    """Flag a search as cancelled and return its payload.

    The ``cancelled`` flag is always set. Status and message only change to
    the cancelled wording when the search had not completed yet. The search
    is then marked complete and removed from the in-flight table.

    Raises:
        KeyError: no search with that id is registered.
    """
    with _lock:
        state = _searches.get(search_id)
        if state is None:
            raise KeyError(search_id)
        already_complete = state.complete
        state.cancelled = True
        if not already_complete:
            state.status, state.message = "cancelled", "Search cancelled."
        state.complete = True
        state.updated_at = time.time()
        _clear_inflight_search_locked(state)
        return _payload(state)
|
||||
|
||||
|
||||
def _run_search(search_id: str) -> None:
    """Executor entry point: run the search registered under *search_id*.

    Returns immediately when the search is unknown or already cancelled.
    Otherwise the state is marked as running and the request is dispatched
    by its ``mode``: ``walk``, ``drive`` and ``car`` go to the point-route
    search (``car`` is passed on as ``drive``), anything else to the transit
    search. Any exception is reported to the client through
    ``_publish_error`` instead of propagating.
    """
    with _lock:
        state = _searches.get(search_id)
        if state is None or state.cancelled:
            return
        state.status = "running"
        state.stage = "starting"
        state.message = "Starting search..."
        state.updated_at = time.time()
        request = dict(state.request)
    try:
        mode = str(request.get("mode") or "transit")
        if mode in {"walk", "drive", "car"}:
            _run_point_route_search(search_id, "drive" if mode == "car" else mode, request)
        else:
            _run_transit_search(search_id, request)
    except Exception as exc:  # noqa: BLE001 - report progressive-search failure to client
        # Some exceptions stringify to "", which would leave the published
        # error falsy; fall back to the exception class name in that case.
        _publish_error(search_id, str(exc) or type(exc).__name__)
|
||||
|
||||
|
||||
def _run_transit_search(search_id: str, request: dict[str, Any]) -> None:
    """Run the staged transit search and publish results after every stage.

    Stages raise the allowed transfer count from 0 up to
    ``MAX_PROGRESSIVE_TRANSFERS`` (only stage 0 when ``direct_only`` is set).
    After each stage the journeys found so far are merged, ranked, trimmed to
    ``limit`` and published. The loop stops early when the search is
    cancelled, when two consecutive stages from stage 2 on add nothing new,
    or when ``_major_hub_address_stage_is_complete`` says so. A payload that
    completed without error is stored in the progressive cache.
    """
    direct_only = bool(request.get("direct_only"))
    limit = max(3, min(int(request.get("limit") or 5), 10))
    transfer_seconds = max(0, min(int(request.get("transfer_seconds") or 120), 3600))
    source_ids = _csv_ints(request.get("source_id"))
    service_date = request.get("service_date") or None
    stages = [0] if direct_only else list(range(0, MAX_PROGRESSIVE_TRANSFERS + 1))
    address_search = is_location_token(request.get("from_stop_id")) or is_location_token(request.get("to_stop_id"))
    stage_limit = limit if address_search else max(limit, 10)
    ranking = str(request.get("ranking") or "recommended")
    # Arguments that are identical for every stage.
    query = {
        "from_stop_id": str(request.get("from_stop_id") or ""),
        "to_stop_id": str(request.get("to_stop_id") or ""),
        "via_stop_id": request.get("via_stop_id") or None,
        "source_ids": source_ids,
        "departure": str(request.get("departure") or "08:00"),
        "service_date": service_date,
        "transfer_seconds": transfer_seconds,
        "limit": stage_limit,
    }

    merged: dict[str, dict] = {}
    context: dict[str, Any] = {}
    diagnostics: dict[str, Any] = {"stages": []}
    best_count = 0
    stale_stages = 0

    for transfers in stages:
        if _is_cancelled(search_id):
            return
        stage_name = f"transfers_{transfers}"
        if transfers == 0:
            label = "direct"
        else:
            label = f"up to {transfers} transfer{'s' if transfers != 1 else ''}"
        _publish_status(search_id, "running", f"Searching {label}...", stage_name)

        stage_started_at = time.monotonic()
        with SessionLocal() as db:
            result = _cached_find_journeys(db, max_transfers=transfers, **query)
        cache_status = str(result.pop("_cache_status", "miss"))
        elapsed_ms = int((time.monotonic() - stage_started_at) * 1000)
        found = result.get("journeys") or []

        stage_diagnostics = {
            "transfers": transfers,
            "cache": cache_status,
            "elapsed_ms": elapsed_ms,
            "journeys": len(found),
        }
        result_diagnostics = result.get("diagnostics")
        if isinstance(result_diagnostics, dict):
            stage_diagnostics["details"] = result_diagnostics
        diagnostics["stages"].append(stage_diagnostics)

        context = _context_from_result(result)
        # The accumulated per-stage diagnostics replace the stage's own entry.
        context["diagnostics"] = diagnostics

        before = len(merged)
        for journey in found:
            merged.setdefault(_journey_key(journey), journey)
        ranked = _select_diverse_journeys(_rank_journeys(merged.values(), ranking), limit=limit)

        if direct_only:
            progress_message = "Direct search complete."
        else:
            progress_message = f"Found {len(ranked)} option{'s' if len(ranked) != 1 else ''}; still searching..."
        _publish_results(
            search_id,
            journeys=ranked,
            context=context,
            status="running",
            stage=stage_name,
            message=progress_message,
        )

        added_nothing = len(merged) <= before
        stale_stages = stale_stages + 1 if added_nothing and ranked else 0
        best_count = max(best_count, len(ranked))
        if ranked and stale_stages >= 2 and transfers >= 2:
            break
        if _major_hub_address_stage_is_complete(result_diagnostics, ranked, transfers=transfers, limit=limit):
            break

    if best_count:
        complete_message = f"Search complete. Found {best_count} option{'s' if best_count != 1 else ''}."
    else:
        complete_message = "Search complete. No route found in the imported timetable."
    _publish_complete(search_id, message=complete_message)

    payload = journey_search_payload(search_id)
    if payload.get("status") == "complete" and not payload.get("error"):
        _progressive_cache_put(_progressive_cache_key(request), payload)
|
||||
|
||||
|
||||
def _major_hub_address_stage_is_complete(
|
||||
diagnostics: dict[str, Any] | None,
|
||||
ranked: list[dict],
|
||||
*,
|
||||
transfers: int,
|
||||
limit: int,
|
||||
) -> bool:
|
||||
if transfers < 1 or not ranked or not isinstance(diagnostics, dict):
|
||||
return False
|
||||
address_access = diagnostics.get("address_access")
|
||||
if not isinstance(address_access, dict) or not address_access.get("major_hubs"):
|
||||
return False
|
||||
return len(ranked) >= min(3, limit)
|
||||
|
||||
|
||||
def _run_point_route_search(search_id: str, mode: str, request: dict[str, Any]) -> None:
    """Compute a single point-to-point route for *mode* and publish it.

    Both endpoints are resolved with ``resolve_location_summary``. If
    ``route_between_points`` raises, ``direct_route_between_points`` is used
    instead and is given the failure text as ``reason``. A payload that
    completed without error is stored in the progressive cache.

    Raises:
        ValueError: the start or the destination has no coordinates.
    """
    _publish_status(search_id, "running", f"Searching {mode} route...", mode)
    with SessionLocal() as db:
        origin = resolve_location_summary(
            db,
            str(request.get("from_stop_id") or ""),
            source_ids=_csv_ints(request.get("source_id")),
        )
        destination = resolve_location_summary(
            db,
            str(request.get("to_stop_id") or ""),
            source_ids=_csv_ints(request.get("source_id")),
        )
        if origin.lon is None or origin.lat is None:
            raise ValueError("Selected start has no coordinates.")
        if destination.lon is None or destination.lat is None:
            raise ValueError("Selected destination has no coordinates.")

        def _endpoints() -> dict[str, float]:
            # Built on demand so the float() conversions run at the same
            # points as before: inside the try, and again in the fallback.
            return {
                "from_lon": float(origin.lon),
                "from_lat": float(origin.lat),
                "to_lon": float(destination.lon),
                "to_lat": float(destination.lat),
            }

        try:
            route = route_between_points(db, **_endpoints(), mode=mode, max_visited=300_000)
            message = f"{mode.title()} route found."
        except Exception as exc:  # noqa: BLE001 - point routing should still return an approximate connector
            route = direct_route_between_points(db, **_endpoints(), mode=mode, reason=str(exc))
            message = f"{mode.title()} route approximated."

        context = {
            "from": _stop_payload(origin),
            "to": _stop_payload(destination),
            "mode": mode,
        }

    _publish_routing(search_id, route, context=context, message=message)
    _publish_complete(search_id, message=f"{mode.title()} route complete.")
    payload = journey_search_payload(search_id)
    if payload.get("status") == "complete" and not payload.get("error"):
        _progressive_cache_put(_progressive_cache_key(request), payload)
|
||||
|
||||
|
||||
def _cached_find_journeys(
    db,
    *,
    from_stop_id: str,
    to_stop_id: str,
    via_stop_id: object,
    source_ids: list[int] | None,
    departure: str,
    service_date: object,
    max_transfers: int,
    transfer_seconds: int,
    limit: int,
) -> dict[str, Any]:
    """Run ``find_journeys`` behind a memory cache and a durable cache.

    Lookup order is the in-process cache, then the durable cache (a hit is
    copied back into memory), then a real ``find_journeys`` call whose result
    is written to both. The returned dict is always a fresh copy carrying a
    ``_cache_status`` key of ``"memory"``, ``"persistent"`` or ``"miss"``.

    The in-memory expiry is stamped when the entry is stored, not when the
    lookup began, so a slow durable lookup or search does not shorten the
    entry's lifetime.
    """
    key = (
        from_stop_id,
        to_stop_id,
        str(via_stop_id or ""),
        tuple(sorted(int(source_id) for source_id in source_ids or [])),
        departure,
        str(service_date or ""),
        int(max_transfers),
        int(transfer_seconds),
        int(limit),
    )
    with _lock:
        cached = _transit_stage_cache.get(key)
        if cached is not None:
            expires_at, payload = cached
            if expires_at > time.monotonic():
                return _with_cache_status(payload, "memory")
            _transit_stage_cache.pop(key, None)

    durable = _durable_cache_get("transit_stage", key)
    if durable is not None:
        with _lock:
            _transit_stage_cache[key] = (
                time.monotonic() + TRANSIT_STAGE_CACHE_TTL_SECONDS,
                json.loads(json.dumps(durable)),
            )
            _prune_timed_cache(_transit_stage_cache, TRANSIT_STAGE_CACHE_MAX_ENTRIES)
        return _with_cache_status(durable, "persistent")

    result = find_journeys(
        db=db,
        from_stop_id=from_stop_id,
        to_stop_id=to_stop_id,
        via_stop_id=via_stop_id,
        source_ids=source_ids,
        departure=departure,
        service_date=service_date,
        max_transfers=max_transfers,
        transfer_seconds=transfer_seconds,
        limit=limit,
    )
    # JSON round trip: detaches the cached copy from the caller's object.
    stored_result = json.loads(json.dumps(result))
    with _lock:
        _transit_stage_cache[key] = (
            time.monotonic() + TRANSIT_STAGE_CACHE_TTL_SECONDS,
            stored_result,
        )
        _prune_timed_cache(_transit_stage_cache, TRANSIT_STAGE_CACHE_MAX_ENTRIES)
    _durable_cache_put("transit_stage", key, stored_result, ttl_seconds=TRANSIT_STAGE_CACHE_TTL_SECONDS)
    return _with_cache_status(result, "miss")
|
||||
|
||||
|
||||
def _with_cache_status(payload: dict[str, Any], cache_status: str) -> dict[str, Any]:
|
||||
copied = json.loads(json.dumps(payload))
|
||||
copied["_cache_status"] = cache_status
|
||||
return copied
|
||||
|
||||
|
||||
def _prune_timed_cache(cache: dict[tuple[object, ...], tuple[float, dict[str, Any]]], max_entries: int) -> None:
|
||||
if len(cache) <= max_entries:
|
||||
return
|
||||
oldest = sorted(cache.items(), key=lambda item: item[1][0])[: len(cache) - max_entries]
|
||||
for old_key, _ in oldest:
|
||||
cache.pop(old_key, None)
|
||||
|
||||
|
||||
def _durable_cache_get(cache_type: str, key: tuple[object, ...]) -> dict[str, Any] | None:
    """Fetch an unexpired payload from the ``JourneySearchCache`` table.

    Returns ``None`` when no row exists, when the row has expired (the row is
    deleted in that case), or when anything goes wrong while reading.
    """
    storage_key = _durable_cache_key(cache_type, key)
    now = datetime.now(timezone.utc)
    try:
        with SessionLocal() as session:
            row = session.scalar(select(JourneySearchCache).where(JourneySearchCache.cache_key == storage_key))
            if row is None:
                return None
            expires_at = _as_utc(row.expires_at)
            if expires_at is not None and expires_at > now:
                return json.loads(row.payload_json)
            # Missing or past expiry: drop the stale row.
            session.delete(row)
            session.commit()
            return None
    except Exception:  # noqa: BLE001 - cache misses must not break journey search
        return None
|
||||
|
||||
|
||||
def _durable_cache_put(cache_type: str, key: tuple[object, ...], payload: dict[str, Any], *, ttl_seconds: int) -> None:
    """Insert or refresh the ``JourneySearchCache`` row for *key*.

    The row expires ``ttl_seconds`` from now, with a minimum of one second.
    Failures while serialising or writing are swallowed: the durable cache is
    best-effort.
    """
    storage_key = _durable_cache_key(cache_type, key)
    now = datetime.now(timezone.utc)
    expires_at = now + timedelta(seconds=max(1, int(ttl_seconds)))
    try:
        with SessionLocal() as session:
            row = session.scalar(select(JourneySearchCache).where(JourneySearchCache.cache_key == storage_key))
            serialized = json.dumps(payload, separators=(",", ":"))
            if row is not None:
                row.cache_type = cache_type
                row.payload_json = serialized
                row.updated_at = now
                row.expires_at = expires_at
            else:
                session.add(
                    JourneySearchCache(
                        cache_key=storage_key,
                        cache_type=cache_type,
                        payload_json=serialized,
                        created_at=now,
                        updated_at=now,
                        expires_at=expires_at,
                    )
                )
            session.commit()
    except Exception:  # noqa: BLE001 - cache writes are best-effort
        return
|
||||
|
||||
|
||||
def _durable_cache_key(cache_type: str, key: tuple[object, ...]) -> str:
    """Return the SHA-256 hex digest used as the durable cache row key.

    The digest covers the cache version constant, *cache_type* and the
    JSON-safe form of *key*, serialised as compact JSON with sorted keys.
    """
    envelope = {
        "version": JOURNEY_SEARCH_CACHE_VERSION,
        "cache_type": cache_type,
        "key": _json_safe(key),
    }
    canonical = json.dumps(envelope, sort_keys=True, separators=(",", ":"))
    digest = hashlib.sha256()
    digest.update(canonical.encode("utf-8"))
    return digest.hexdigest()
|
||||
|
||||
|
||||
def _json_safe(value: object) -> object:
|
||||
if isinstance(value, tuple):
|
||||
return [_json_safe(item) for item in value]
|
||||
if isinstance(value, list):
|
||||
return [_json_safe(item) for item in value]
|
||||
if isinstance(value, dict):
|
||||
return {str(key): _json_safe(item) for key, item in sorted(value.items(), key=lambda item: str(item[0]))}
|
||||
if isinstance(value, (str, int, float, bool)) or value is None:
|
||||
return value
|
||||
if isinstance(value, (date, datetime)):
|
||||
return value.isoformat()
|
||||
return str(value)
|
||||
|
||||
|
||||
def _as_utc(value: datetime | None) -> datetime | None:
|
||||
if value is None:
|
||||
return None
|
||||
if value.tzinfo is None:
|
||||
return value.replace(tzinfo=timezone.utc)
|
||||
return value.astimezone(timezone.utc)
|
||||
|
||||
|
||||
def _progressive_cache_key(request: dict[str, Any]) -> tuple[object, ...]:
    """Build the cache and in-flight key for a search request.

    Every field is normalised the way the search itself reads it, so that
    requests which lead to the same search share one key. That includes
    ``limit`` (clamped to 3..10) and ``transfer_seconds`` (clamped to 0..3600,
    matching ``_run_transit_search``).

    Raises:
        ValueError: a numeric field cannot be parsed as an integer.
    """
    source_ids = _csv_ints(request.get("source_id"))
    return (
        str(request.get("mode") or "transit"),
        str(request.get("from_stop_id") or ""),
        str(request.get("to_stop_id") or ""),
        str(request.get("via_stop_id") or ""),
        tuple(sorted(int(source_id) for source_id in source_ids or [])),
        str(request.get("departure") or "08:00"),
        str(request.get("service_date") or ""),
        bool(request.get("direct_only")),
        str(request.get("ranking") or "recommended"),
        # Same clamp as the transit search: out-of-range values run the same
        # search as the boundary value and must not get a separate key.
        max(0, min(int(request.get("transfer_seconds") or 120), 3600)),
        max(3, min(int(request.get("limit") or 5), 10)),
    )
|
||||
|
||||
|
||||
def _progressive_cache_get(key: tuple[object, ...]) -> dict[str, Any] | None:
    """Look up a completed search payload, memory first and durable second.

    Returns a fresh copy whose ``cache_status`` is ``"memory"`` or
    ``"persistent"``, or ``None`` when neither cache has a live entry. A
    durable hit is copied into the memory cache.
    """

    def _tagged(payload: dict[str, Any], origin: str) -> dict[str, Any]:
        clone = json.loads(json.dumps(payload))
        clone["cache_status"] = origin
        return clone

    now = time.monotonic()
    with _lock:
        entry = _progressive_search_cache.get(key)
        if entry is not None:
            expires_at, payload = entry
            if expires_at > now:
                return _tagged(payload, "memory")
            _progressive_search_cache.pop(key, None)

    durable = _durable_cache_get("progressive", key)
    if durable is None:
        return None
    with _lock:
        _progressive_search_cache[key] = (
            now + PROGRESSIVE_SEARCH_CACHE_TTL_SECONDS,
            json.loads(json.dumps(durable)),
        )
        _prune_timed_cache(_progressive_search_cache, PROGRESSIVE_SEARCH_CACHE_MAX_ENTRIES)
    return _tagged(durable, "persistent")
|
||||
|
||||
|
||||
def _progressive_cache_put(key: tuple[object, ...], payload: dict[str, Any]) -> None:
    """Store a copy of *payload* in the memory cache and the durable cache.

    The ``cache_status`` key is removed from the stored copy; it is set again
    when the entry is read back.
    """
    snapshot = json.loads(json.dumps(payload))
    snapshot.pop("cache_status", None)
    with _lock:
        expires_at = time.monotonic() + PROGRESSIVE_SEARCH_CACHE_TTL_SECONDS
        _progressive_search_cache[key] = (expires_at, snapshot)
        _prune_timed_cache(_progressive_search_cache, PROGRESSIVE_SEARCH_CACHE_MAX_ENTRIES)
    _durable_cache_put("progressive", key, snapshot, ttl_seconds=PROGRESSIVE_SEARCH_CACHE_TTL_SECONDS)
|
||||
|
||||
|
||||
def _apply_cached_payload(state: _SearchState, payload: dict[str, Any]) -> None:
|
||||
state.status = str(payload.get("status") or "complete")
|
||||
state.message = "Cached result."
|
||||
state.stage = str(payload.get("stage") or "cached")
|
||||
state.journeys = json.loads(json.dumps(payload.get("journeys") or []))
|
||||
state.routing = json.loads(json.dumps(payload.get("routing"))) if payload.get("routing") is not None else None
|
||||
state.context = {
|
||||
key: value
|
||||
for key, value in payload.items()
|
||||
if key not in {"search_id", "status", "stage", "message", "complete", "error", "journeys", "routing", "created_at", "updated_at"}
|
||||
}
|
||||
state.error = None
|
||||
state.complete = True
|
||||
state.updated_at = time.time()
|
||||
|
||||
|
||||
def _publish_status(search_id: str, status: str, message: str, stage: str) -> None:
    """Update status, message and stage of a search; skip unknown or cancelled ones."""
    with _lock:
        state = _searches.get(search_id)
        if state is not None and not state.cancelled:
            state.status = status
            state.message = message
            state.stage = stage
            state.updated_at = time.time()
|
||||
|
||||
|
||||
def _publish_results(search_id: str, *, journeys: list[dict], context: dict[str, Any], status: str, stage: str, message: str) -> None:
    """Replace the journeys and context of a search; skip unknown or cancelled ones.

    The state receives shallow copies of *journeys* and *context*.
    """
    with _lock:
        state = _searches.get(search_id)
        if state is not None and not state.cancelled:
            state.status = status
            state.stage = stage
            state.message = message
            state.journeys = list(journeys)
            state.context = dict(context)
            state.updated_at = time.time()
|
||||
|
||||
|
||||
def _publish_routing(search_id: str, routing: dict[str, Any], *, context: dict[str, Any], message: str) -> None:
    """Attach a point-route result to a search; skip unknown or cancelled ones.

    The status becomes ``"running"`` and the stage is the route's ``mode``
    (``"route"`` when the route has none).
    """
    with _lock:
        state = _searches.get(search_id)
        if state is not None and not state.cancelled:
            state.status = "running"
            state.stage = str(routing.get("mode") or "route")
            state.message = message
            state.routing = routing
            state.context = dict(context)
            state.updated_at = time.time()
|
||||
|
||||
|
||||
def _publish_complete(search_id: str, *, message: str) -> None:
    """Mark a search complete and release its in-flight slot.

    Unknown or cancelled searches are left untouched.
    """
    with _lock:
        state = _searches.get(search_id)
        if state is not None and not state.cancelled:
            state.status = "complete"
            state.message = message
            state.complete = True
            state.updated_at = time.time()
            _clear_inflight_search_locked(state)
|
||||
|
||||
|
||||
def _publish_error(search_id: str, message: str) -> None:
    """Record a failure on a search, mark it complete and release its in-flight slot.

    Only unknown searches are skipped; unlike the other publishers this one
    does not check the ``cancelled`` flag.
    """
    with _lock:
        state = _searches.get(search_id)
        if state is not None:
            state.status = state.stage = "error"
            state.message = state.error = message
            state.complete = True
            state.updated_at = time.time()
            _clear_inflight_search_locked(state)
|
||||
|
||||
|
||||
def _clear_inflight_search_locked(state: _SearchState) -> None:
    """Remove *state* from the in-flight table if it is the registered search for its key.

    Every caller in this module invokes it while holding ``_lock``.
    """
    cache_key = state.cache_key
    if cache_key is None:
        return
    if _progressive_search_inflight.get(cache_key) == state.id:
        _progressive_search_inflight.pop(cache_key, None)
|
||||
|
||||
|
||||
def _is_cancelled(search_id: str) -> bool:
    """Return ``True`` when the search is unknown or has been cancelled."""
    with _lock:
        state = _searches.get(search_id)
        if state is None:
            return True
        return state.cancelled
|
||||
|
||||
|
||||
def _payload(state: _SearchState) -> dict[str, Any]:
|
||||
return {
|
||||
"search_id": state.id,
|
||||
"status": state.status,
|
||||
"stage": state.stage,
|
||||
"message": state.message,
|
||||
"complete": state.complete,
|
||||
"error": state.error,
|
||||
"journeys": json.loads(json.dumps(state.journeys)),
|
||||
"routing": json.loads(json.dumps(state.routing)) if state.routing is not None else None,
|
||||
"created_at": state.created_at,
|
||||
"updated_at": state.updated_at,
|
||||
**json.loads(json.dumps(state.context)),
|
||||
}
|
||||
|
||||
|
||||
def _context_from_result(result: dict[str, Any]) -> dict[str, Any]:
|
||||
return {
|
||||
key: value
|
||||
for key, value in result.items()
|
||||
if key not in {"journeys"} and key not in {"error"}
|
||||
}
|
||||
|
||||
|
||||
def _journey_key(journey: dict[str, Any]) -> str:
|
||||
parts = []
|
||||
for leg in journey.get("legs") or []:
|
||||
parts.append(
|
||||
"|".join(
|
||||
str(part or "")
|
||||
for part in [
|
||||
leg.get("dataset_id"),
|
||||
leg.get("mode"),
|
||||
leg.get("route_id"),
|
||||
leg.get("trip_id"),
|
||||
(leg.get("from") or {}).get("stop_id") or (leg.get("from") or {}).get("name"),
|
||||
(leg.get("to") or {}).get("stop_id") or (leg.get("to") or {}).get("name"),
|
||||
leg.get("departure_time"),
|
||||
leg.get("arrival_time"),
|
||||
]
|
||||
)
|
||||
)
|
||||
return "||".join(parts)
|
||||
|
||||
|
||||
def _rank_journeys(journeys, ranking: str) -> list[dict]:
|
||||
def key(journey: dict[str, Any]) -> tuple[float, float, int, float]:
|
||||
departure = journey.get("departure_seconds")
|
||||
arrival = journey.get("arrival_seconds")
|
||||
duration = journey.get("duration_minutes")
|
||||
transfers = int(journey.get("transfers") or 0)
|
||||
walking = sum(float(leg.get("distance_m") or 0) for leg in journey.get("legs") or [] if leg.get("mode") == "walk")
|
||||
walking_seconds = walking / 1.35
|
||||
if ranking == "duration":
|
||||
return (
|
||||
float("inf") if duration is None else float(duration),
|
||||
float("inf") if arrival is None else float(arrival),
|
||||
transfers,
|
||||
walking,
|
||||
)
|
||||
if ranking == "fewest_transfers":
|
||||
return (
|
||||
transfers,
|
||||
float("inf") if arrival is None else float(arrival),
|
||||
float("inf") if duration is None else float(duration),
|
||||
walking,
|
||||
)
|
||||
if ranking == "earliest_arrival":
|
||||
return (
|
||||
float("inf") if arrival is None else float(arrival),
|
||||
float("inf") if duration is None else float(duration),
|
||||
transfers,
|
||||
walking,
|
||||
)
|
||||
return (
|
||||
float("inf") if arrival is None else float(arrival) + transfers * 600 + walking_seconds,
|
||||
float("inf") if arrival is None else float(arrival),
|
||||
transfers,
|
||||
walking,
|
||||
)
|
||||
|
||||
return sorted((dict(journey) for journey in journeys), key=key)
|
||||
|
||||
|
||||
def _select_diverse_journeys(journeys: list[dict], *, limit: int) -> list[dict]:
    """Pick up to *limit* journeys from a ranked list, preferring variety.

    Journeys are taken in ranked order. Exact duplicates (same
    ``_journey_key``) are always skipped. Once three journeys are selected,
    further journeys that share a diversity key with an already selected one
    are skipped as well. The selection is then passed through
    ``_ensure_walk_only_option`` so that a walk-only journey present in
    *journeys* is part of the result.

    Previously the walk-only step was only reached when fewer than
    ``min(3, limit)`` journeys had been selected. In that situation every
    distinct journey is already selected, so the step (and the fallback loop
    in front of it) could never change the result, and a full selection never
    received the walk-only option.
    """
    selected: list[dict] = []
    seen_exact: set[str] = set()
    seen_diversity: set[tuple[object, ...]] = set()
    for journey in journeys:
        exact_key = _journey_key(journey)
        if exact_key in seen_exact:
            continue
        diversity_key = _journey_diversity_key(journey)
        # Near-duplicates are tolerated until three options are on the list.
        if diversity_key in seen_diversity and len(selected) >= 3:
            continue
        selected.append(journey)
        seen_exact.add(exact_key)
        seen_diversity.add(diversity_key)
        if len(selected) >= limit:
            break
    return _ensure_walk_only_option(selected, journeys, limit=limit)
|
||||
|
||||
|
||||
def _ensure_walk_only_option(selected: list[dict], ranked: list[dict], *, limit: int) -> list[dict]:
    """Make sure the selection offers a walk-only journey when *ranked* has one.

    *selected* is returned unchanged when it already contains a walk-only
    journey or *ranked* has none. Otherwise the first walk-only journey from
    *ranked* is appended in a new list if there is room below *limit*, or it
    overwrites the last entry of *selected* in place.
    """
    for journey in selected:
        if _journey_is_walk_only(journey):
            return selected

    walk = None
    for candidate in ranked:
        if _journey_is_walk_only(candidate):
            walk = candidate
            break
    if walk is None:
        return selected

    if len(selected) < limit:
        return selected + [walk]
    if selected:
        selected[-1] = walk
    return selected
|
||||
|
||||
|
||||
def _journey_is_walk_only(journey: dict) -> bool:
|
||||
legs = journey.get("legs") or []
|
||||
return bool(legs) and all(leg.get("mode") == "walk" for leg in legs)
|
||||
|
||||
|
||||
def _journey_diversity_key(journey: dict[str, Any]) -> tuple[object, ...]:
|
||||
route_signature = tuple(
|
||||
str(leg.get("route_ref") or leg.get("route_id") or leg.get("mode") or "")
|
||||
for leg in journey.get("legs") or []
|
||||
if leg.get("mode") != "walk"
|
||||
)
|
||||
departure = journey.get("departure_seconds")
|
||||
time_band = None if departure is None else int(departure) // (30 * 60)
|
||||
return (int(journey.get("transfers") or 0), route_signature, time_band)
|
||||
|
||||
|
||||
def _csv_ints(value: object) -> list[int] | None:
|
||||
if value is None:
|
||||
return None
|
||||
items = [item.strip() for item in str(value).split(",") if item.strip()]
|
||||
if not items:
|
||||
return None
|
||||
return [int(item) for item in items]
|
||||
|
||||
|
||||
def _stop_payload(stop) -> dict[str, Any]:
|
||||
return {
|
||||
"id": stop.id,
|
||||
"dataset_id": stop.dataset_id,
|
||||
"stop_id": stop.stop_id,
|
||||
"name": stop.name,
|
||||
"lat": stop.lat,
|
||||
"lon": stop.lon,
|
||||
}
|
||||
|
||||
|
||||
def _prune_old_searches() -> None:
    """Forget searches that have not been updated recently.

    A search is dropped after 15 minutes without an update, or after 3
    minutes when it is complete. Dropped searches are also removed from the
    in-flight table.
    """
    now = time.time()

    def _expired(state: _SearchState) -> bool:
        idle_seconds = now - state.updated_at
        return idle_seconds > 15 * 60 or (state.complete and idle_seconds > 3 * 60)

    # Collect ids first: the dict must not change size while it is iterated.
    expired_ids = [search_id for search_id, state in _searches.items() if _expired(state)]
    for search_id in expired_ids:
        removed = _searches.pop(search_id, None)
        if removed is not None:
            _clear_inflight_search_locked(removed)
|
||||
2653
app/main.py
Normal file
2653
app/main.py
Normal file
File diff suppressed because it is too large
Load Diff
612
app/models.py
Normal file
612
app/models.py
Normal file
@@ -0,0 +1,612 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import BigInteger, Boolean, DateTime, Float, ForeignKey, Integer, String, Text, UniqueConstraint
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.db import Base
|
||||
|
||||
|
||||
def now_utc() -> datetime:
    """Return the current time as a timezone-aware UTC datetime."""
    return datetime.now(tz=timezone.utc)
|
||||
|
||||
|
||||
class Source(Base):
    """ORM model for the ``sources`` table: one importable data source."""

    __tablename__ = "sources"

    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    # Optional link to the catalog entry this source is associated with.
    catalog_entry_id: Mapped[Optional[int]] = mapped_column(ForeignKey("source_catalog_entries.id"), nullable=True, index=True)
    name: Mapped[str] = mapped_column(String(255), nullable=False)
    kind: Mapped[str] = mapped_column(String(64), nullable=False)  # gtfs, osm_geojson, osm_pbf, osm_diff
    url: Mapped[str] = mapped_column(Text, nullable=False)
    country: Mapped[Optional[str]] = mapped_column(String(8), nullable=True)
    license: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
    priority: Mapped[Optional[str]] = mapped_column(String(16), nullable=True, index=True)
    mode_scope: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    source_basis: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    notes: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    enabled: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False)
    # New rows start with status "new".
    status: Mapped[str] = mapped_column(String(64), default="new", nullable=False)
    last_error: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    last_run_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=now_utc, nullable=False)

    catalog_entry: Mapped[Optional["SourceCatalogEntry"]] = relationship()
    # Child rows are removed together with the source (delete-orphan cascade).
    datasets: Mapped[list["Dataset"]] = relationship(back_populates="source", cascade="all, delete-orphan")
    update_checks: Mapped[list["SourceUpdateCheck"]] = relationship(back_populates="source", cascade="all, delete-orphan")
|
||||
|
||||
|
||||
class SourceCatalogEntry(Base):
    """ORM model for the ``source_catalog_entries`` table.

    Each row is one catalog entry, unique by ``catalog_key``.
    """

    __tablename__ = "source_catalog_entries"
    __table_args__ = (UniqueConstraint("catalog_key", name="uq_source_catalog_entry_key"),)

    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    # Unique identifier of the entry (see the table-level constraint above).
    catalog_key: Mapped[str] = mapped_column(String(255), nullable=False, index=True)
    geography: Mapped[Optional[str]] = mapped_column(String(128), nullable=True, index=True)
    country_code: Mapped[Optional[str]] = mapped_column(String(64), nullable=True, index=True)
    mode_scope: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    source_name: Mapped[str] = mapped_column(Text, nullable=False)
    source_category: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    formats_apis: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    availability: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    coverage_notes: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    geometry_notes: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    disruptions_closures: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    operator_list_use: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    access_license_notes: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    priority: Mapped[Optional[str]] = mapped_column(String(32), nullable=True, index=True)
    source_url: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    evidence_url: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    next_pipeline_action: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    # New rows start with status "backlog".
    status: Mapped[str] = mapped_column(String(64), default="backlog", nullable=False, index=True)
    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=now_utc, nullable=False)
    # Only an insert default is declared here; no ``onupdate`` hook is configured on this column.
    updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=now_utc, nullable=False)
|
||||
|
||||
|
||||
class Dataset(Base):
    """ORM model for the ``datasets`` table.

    Each row belongs to one ``sources`` row and records a kind, a local
    path, a SHA-256 string, an active flag and a status.
    """

    __tablename__ = "datasets"

    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    source_id: Mapped[int] = mapped_column(
        ForeignKey("sources.id"), nullable=False, index=True
    )

    # Required descriptors of the stored artefact.
    kind: Mapped[str] = mapped_column(String(64), nullable=False)
    local_path: Mapped[str] = mapped_column(Text, nullable=False)
    sha256: Mapped[str] = mapped_column(String(64), nullable=False)

    # State: new rows default to active with status "imported".
    is_active: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False)
    status: Mapped[str] = mapped_column(String(64), default="imported", nullable=False)

    # Optional text payload (presumably serialized JSON, going by the name).
    metadata_json: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=now_utc, nullable=False
    )

    # Paired with the ``datasets`` attribute on ``Source``.
    source: Mapped[Source] = relationship(back_populates="datasets")
|
||||
|
||||
|
||||
class SourceUpdateCheck(Base):
    """ORM model for the ``source_update_checks`` table.

    Each row belongs to one source and stores a check timestamp, a status,
    an ``update_available`` flag, and optional remote, local and
    active-dataset details.
    """

    __tablename__ = "source_update_checks"

    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    source_id: Mapped[int] = mapped_column(
        ForeignKey("sources.id"), nullable=False, index=True
    )

    # Outcome of the check; status defaults to "checked".
    checked_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=now_utc, nullable=False, index=True
    )
    status: Mapped[str] = mapped_column(
        String(64), nullable=False, default="checked", index=True
    )
    update_available: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
    reason: Mapped[Optional[str]] = mapped_column(Text, nullable=True)

    # Remote-side fields (names mirror HTTP response headers -- presumably
    # populated from them; verify against the writer).
    remote_url: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    etag: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    last_modified: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    content_length: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
    content_type: Mapped[Optional[str]] = mapped_column(Text, nullable=True)

    # Local-side fields.
    local_mtime: Mapped[Optional[datetime]] = mapped_column(
        DateTime(timezone=True), nullable=True
    )
    local_size: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
    local_sha256: Mapped[Optional[str]] = mapped_column(String(64), nullable=True)

    # Optional pointer to a ``datasets`` row plus a SHA-256 string for it.
    active_dataset_id: Mapped[Optional[int]] = mapped_column(
        ForeignKey("datasets.id"), nullable=True, index=True
    )
    active_dataset_sha256: Mapped[Optional[str]] = mapped_column(String(64), nullable=True)
    metadata_json: Mapped[Optional[str]] = mapped_column(Text, nullable=True)

    # ``source`` is paired with ``Source.update_checks``; ``active_dataset``
    # is a one-way relationship.
    source: Mapped[Source] = relationship(back_populates="update_checks")
    active_dataset: Mapped[Optional[Dataset]] = relationship()
|
||||
|
||||
|
||||
class OsmDiffState(Base):
    """ORM model for the ``osm_diff_states`` table.

    Each row ties a source (and optionally a dataset) to an updates URL, a
    sequence number, an optional timestamp string and a status.
    """

    __tablename__ = "osm_diff_states"

    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    source_id: Mapped[int] = mapped_column(
        ForeignKey("sources.id"), nullable=False, index=True
    )
    raw_dataset_id: Mapped[Optional[int]] = mapped_column(
        ForeignKey("datasets.id"), nullable=True, index=True
    )

    # Position fields; ``timestamp`` is kept as a short string, not a datetime.
    updates_url: Mapped[str] = mapped_column(Text, nullable=False)
    sequence_number: Mapped[int] = mapped_column(Integer, nullable=False, index=True)
    timestamp: Mapped[Optional[str]] = mapped_column(String(64), nullable=True, index=True)

    # Status defaults to "active".
    status: Mapped[str] = mapped_column(
        String(64), nullable=False, default="active", index=True
    )
    metadata_json: Mapped[Optional[str]] = mapped_column(Text, nullable=True)

    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=now_utc, nullable=False, index=True
    )
    # NOTE(review): no ``onupdate`` hook is configured, so this is presumably
    # refreshed explicitly by callers -- confirm.
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=now_utc, nullable=False
    )

    # One-way relationships; neither declares a back-reference.
    source: Mapped[Source] = relationship()
    raw_dataset: Mapped[Optional[Dataset]] = relationship()
|
||||
|
||||
|
||||
class Job(Base):
    """ORM model for the ``jobs`` table.

    Holds a kind, a status, progress counters, a priority, lease fields,
    result and error text, and lifecycle timestamps. Child ``JobEvent``
    rows hang off ``events`` with an ``all, delete-orphan`` cascade.
    """

    __tablename__ = "jobs"

    id: Mapped[int] = mapped_column(Integer, primary_key=True)

    # What the job is and where it stands; status defaults to "queued".
    kind: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
    status: Mapped[str] = mapped_column(
        String(64), nullable=False, default="queued", index=True
    )
    description: Mapped[Optional[str]] = mapped_column(Text, nullable=True)

    # Progress counters and priority, all defaulting to 0.
    progress_current: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
    progress_total: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
    priority: Mapped[int] = mapped_column(Integer, nullable=False, default=0, index=True)

    # Control and lease fields, all optional.
    requested_action: Mapped[Optional[str]] = mapped_column(
        String(32), nullable=True, index=True
    )
    lease_owner: Mapped[Optional[str]] = mapped_column(
        String(255), nullable=True, index=True
    )
    lease_expires_at: Mapped[Optional[datetime]] = mapped_column(
        DateTime(timezone=True), nullable=True, index=True
    )
    paused_at: Mapped[Optional[datetime]] = mapped_column(
        DateTime(timezone=True), nullable=True
    )

    # Outcome text.
    result_json: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    error: Mapped[Optional[str]] = mapped_column(Text, nullable=True)

    # Lifecycle timestamps; only the two with defaults are required.
    dismissed_at: Mapped[Optional[datetime]] = mapped_column(
        DateTime(timezone=True), nullable=True, index=True
    )
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=now_utc, nullable=False, index=True
    )
    started_at: Mapped[Optional[datetime]] = mapped_column(
        DateTime(timezone=True), nullable=True
    )
    # NOTE(review): no ``onupdate`` hook is configured, so this is presumably
    # refreshed explicitly by callers -- confirm.
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=now_utc, nullable=False
    )
    finished_at: Mapped[Optional[datetime]] = mapped_column(
        DateTime(timezone=True), nullable=True
    )

    events: Mapped[list["JobEvent"]] = relationship(
        back_populates="job", cascade="all, delete-orphan"
    )
|
||||
|
||||
|
||||
class JobEvent(Base):
    """ORM model for the ``job_events`` table.

    Each row belongs to one job and carries a level, an event type, a
    message, and optional progress numbers and metadata text.
    """

    __tablename__ = "job_events"

    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    job_id: Mapped[int] = mapped_column(ForeignKey("jobs.id"), nullable=False, index=True)

    # Classification and text; ``level`` defaults to "info".
    level: Mapped[str] = mapped_column(
        String(32), nullable=False, default="info", index=True
    )
    event_type: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
    message: Mapped[str] = mapped_column(Text, nullable=False)

    # Optional progress numbers; unlike ``Job`` these may be NULL.
    progress_current: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
    progress_total: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)

    metadata_json: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=now_utc, nullable=False, index=True
    )

    # Paired with ``Job.events``.
    job: Mapped[Job] = relationship(back_populates="events")
|
||||
|
||||
|
||||
class PipelineRun(Base):
    """ORM model for the ``pipeline_runs`` table.

    Each row records a stage, a version, a dependency hash and a status,
    with optional links to a source, a dataset and a job.
    """

    __tablename__ = "pipeline_runs"

    id: Mapped[int] = mapped_column(Integer, primary_key=True)

    # Required, indexed identification of the run; status defaults to "running".
    stage: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
    version: Mapped[str] = mapped_column(String(128), nullable=False, index=True)
    dependency_hash: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
    status: Mapped[str] = mapped_column(
        String(64), nullable=False, default="running", index=True
    )

    # Optional foreign keys.
    source_id: Mapped[Optional[int]] = mapped_column(
        ForeignKey("sources.id"), nullable=True, index=True
    )
    dataset_id: Mapped[Optional[int]] = mapped_column(
        ForeignKey("datasets.id"), nullable=True, index=True
    )
    job_id: Mapped[Optional[int]] = mapped_column(
        ForeignKey("jobs.id"), nullable=True, index=True
    )

    # Optional text payloads.
    input_json: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    output_json: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    error: Mapped[Optional[str]] = mapped_column(Text, nullable=True)

    started_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=now_utc, nullable=False, index=True
    )
    # NOTE(review): no ``onupdate`` hook is configured, so this is presumably
    # refreshed explicitly by callers -- confirm.
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=now_utc, nullable=False
    )
    finished_at: Mapped[Optional[datetime]] = mapped_column(
        DateTime(timezone=True), nullable=True
    )

    # One-way relationships; none declares a back-reference.
    source: Mapped[Optional[Source]] = relationship()
    dataset: Mapped[Optional[Dataset]] = relationship()
    job: Mapped[Optional[Job]] = relationship()
|
||||
|
||||
|
||||
class GtfsAgency(Base):
    """ORM model for the ``gtfs_agencies`` table.

    ``(dataset_id, agency_id)`` is unique.
    """

    __tablename__ = "gtfs_agencies"
    __table_args__ = (
        UniqueConstraint("dataset_id", "agency_id", name="uq_gtfs_agency_dataset_id"),
    )

    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    dataset_id: Mapped[int] = mapped_column(
        ForeignKey("datasets.id"), nullable=False, index=True
    )
    agency_id: Mapped[str] = mapped_column(String(255), nullable=False)

    # Name is required; URL and timezone are optional.
    name: Mapped[str] = mapped_column(Text, nullable=False)
    url: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    timezone: Mapped[Optional[str]] = mapped_column(String(128), nullable=True)
|
||||
|
||||
|
||||
class GtfsStop(Base):
    """ORM model for the ``gtfs_stops`` table.

    ``(dataset_id, stop_id)`` is unique. Name, coordinates and parent
    station are all optional.
    """

    __tablename__ = "gtfs_stops"
    __table_args__ = (
        UniqueConstraint("dataset_id", "stop_id", name="uq_gtfs_stop_dataset_id"),
    )

    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    dataset_id: Mapped[int] = mapped_column(
        ForeignKey("datasets.id"), nullable=False, index=True
    )
    stop_id: Mapped[str] = mapped_column(String(255), nullable=False)

    name: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    lat: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
    lon: Mapped[Optional[float]] = mapped_column(Float, nullable=True)

    # Plain string column, not a foreign key.
    parent_station: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
|
||||
|
||||
|
||||
class GtfsRoute(Base):
    """ORM model for the ``gtfs_routes`` table.

    ``(dataset_id, route_id)`` is unique. Everything beyond those two
    columns is optional.
    """

    __tablename__ = "gtfs_routes"
    __table_args__ = (
        UniqueConstraint("dataset_id", "route_id", name="uq_gtfs_route_dataset_id"),
    )

    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    dataset_id: Mapped[int] = mapped_column(
        ForeignKey("datasets.id"), nullable=False, index=True
    )
    route_id: Mapped[str] = mapped_column(String(255), nullable=False)

    # Descriptive attributes; ``agency_id`` is a plain string, not a foreign key.
    agency_id: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
    short_name: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    long_name: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    route_type: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
    mode: Mapped[Optional[str]] = mapped_column(String(64), nullable=True, index=True)
    route_scope: Mapped[Optional[str]] = mapped_column(String(64), nullable=True, index=True)
    operator_name: Mapped[Optional[str]] = mapped_column(Text, nullable=True)

    # Geometry text plus min/max lon/lat extents (presumably the geometry's
    # bounding box -- confirm against the writer).
    geometry_geojson: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    min_lon: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
    min_lat: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
    max_lon: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
    max_lat: Mapped[Optional[float]] = mapped_column(Float, nullable=True)

    # Indexed lookup keys.
    route_key: Mapped[Optional[str]] = mapped_column(Text, nullable=True, index=True)
    operator_key: Mapped[Optional[str]] = mapped_column(Text, nullable=True, index=True)
|
||||
|
||||
|
||||
class GtfsTrip(Base):
    """ORM model for the ``gtfs_trips`` table.

    ``(dataset_id, trip_id)`` is unique.
    """

    __tablename__ = "gtfs_trips"
    __table_args__ = (
        UniqueConstraint("dataset_id", "trip_id", name="uq_gtfs_trip_dataset_id"),
    )

    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    dataset_id: Mapped[int] = mapped_column(
        ForeignKey("datasets.id"), nullable=False, index=True
    )

    # String identifiers; none of them is declared as a foreign key.
    route_id: Mapped[str] = mapped_column(String(255), nullable=False, index=True)
    trip_id: Mapped[str] = mapped_column(String(255), nullable=False)
    service_id: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
    shape_id: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
|
||||
|
||||
|
||||
class GtfsCalendar(Base):
    """ORM model for the ``gtfs_calendars`` table.

    ``(dataset_id, service_id)`` is unique. Holds one boolean per weekday
    and an integer start and end date.
    """

    __tablename__ = "gtfs_calendars"
    __table_args__ = (
        UniqueConstraint(
            "dataset_id", "service_id", name="uq_gtfs_calendar_dataset_service"
        ),
    )

    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    dataset_id: Mapped[int] = mapped_column(
        ForeignKey("datasets.id"), nullable=False, index=True
    )
    service_id: Mapped[str] = mapped_column(String(255), nullable=False, index=True)

    # Weekday flags, each defaulting to False.
    monday: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
    tuesday: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
    wednesday: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
    thursday: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
    friday: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
    saturday: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
    sunday: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)

    # Dates are stored as integers (presumably YYYYMMDD -- confirm against
    # the importer).
    start_date: Mapped[int] = mapped_column(Integer, nullable=False, index=True)
    end_date: Mapped[int] = mapped_column(Integer, nullable=False, index=True)
|
||||
|
||||
|
||||
class GtfsCalendarDate(Base):
    """ORM model for the ``gtfs_calendar_dates`` table.

    ``(dataset_id, service_id, date)`` is unique.
    """

    __tablename__ = "gtfs_calendar_dates"
    __table_args__ = (
        UniqueConstraint(
            "dataset_id",
            "service_id",
            "date",
            name="uq_gtfs_calendar_date_dataset_service_date",
        ),
    )

    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    dataset_id: Mapped[int] = mapped_column(
        ForeignKey("datasets.id"), nullable=False, index=True
    )
    service_id: Mapped[str] = mapped_column(String(255), nullable=False, index=True)

    # Date is stored as an integer (presumably YYYYMMDD -- confirm against
    # the importer).
    date: Mapped[int] = mapped_column(Integer, nullable=False, index=True)
    exception_type: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||
|
||||
|
||||
class GtfsShape(Base):
    """ORM model for the ``gtfs_shapes`` table.

    ``(dataset_id, shape_id)`` is unique. The geometry text is required;
    the four extent columns are optional.
    """

    __tablename__ = "gtfs_shapes"
    __table_args__ = (
        UniqueConstraint("dataset_id", "shape_id", name="uq_gtfs_shape_dataset_id"),
    )

    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    dataset_id: Mapped[int] = mapped_column(
        ForeignKey("datasets.id"), nullable=False, index=True
    )
    shape_id: Mapped[str] = mapped_column(String(255), nullable=False, index=True)

    geometry_geojson: Mapped[str] = mapped_column(Text, nullable=False)

    # Min/max lon/lat extents (presumably the geometry's bounding box --
    # confirm against the writer).
    min_lon: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
    min_lat: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
    max_lon: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
    max_lat: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
|
||||
|
||||
|
||||
class GtfsStopTime(Base):
    """ORM model for the ``gtfs_stop_times`` table.

    No uniqueness constraint is declared. Times are kept both as short
    strings and as optional integer second counts.
    """

    __tablename__ = "gtfs_stop_times"

    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    dataset_id: Mapped[int] = mapped_column(
        ForeignKey("datasets.id"), nullable=False, index=True
    )

    # String identifiers plus the position within the trip.
    trip_id: Mapped[str] = mapped_column(String(255), nullable=False, index=True)
    stop_id: Mapped[str] = mapped_column(String(255), nullable=False)
    stop_sequence: Mapped[int] = mapped_column(Integer, nullable=False)

    # Times as text.
    arrival_time: Mapped[Optional[str]] = mapped_column(String(32), nullable=True)
    departure_time: Mapped[Optional[str]] = mapped_column(String(32), nullable=True)

    # Times as indexed integers (presumably seconds derived from the text
    # columns above -- confirm against the importer).
    arrival_seconds: Mapped[Optional[int]] = mapped_column(Integer, nullable=True, index=True)
    departure_seconds: Mapped[Optional[int]] = mapped_column(
        Integer, nullable=True, index=True
    )
|
||||
|
||||
|
||||
class CanonicalStop(Base):
    """ORM model for the ``canonical_stops`` table.

    ``stop_key`` is unique. Name and normalized name are required;
    coordinates, mode and metadata are optional.
    """

    __tablename__ = "canonical_stops"

    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    stop_key: Mapped[str] = mapped_column(
        String(255), nullable=False, unique=True, index=True
    )

    name: Mapped[str] = mapped_column(Text, nullable=False)
    normalized_name: Mapped[str] = mapped_column(Text, nullable=False, index=True)

    lat: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
    lon: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
    mode: Mapped[Optional[str]] = mapped_column(String(64), nullable=True, index=True)

    metadata_json: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=now_utc, nullable=False
    )
|
||||
|
||||
|
||||
class CanonicalStopLink(Base):
    """ORM model for the ``canonical_stop_links`` table.

    Links a canonical stop to an object identified by ``object_type``,
    ``dataset_id`` and ``object_id``; that triple is unique.
    """

    __tablename__ = "canonical_stop_links"
    __table_args__ = (
        UniqueConstraint(
            "object_type", "dataset_id", "object_id", name="uq_canonical_stop_link_object"
        ),
    )

    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    canonical_stop_id: Mapped[int] = mapped_column(
        ForeignKey("canonical_stops.id"), nullable=False, index=True
    )

    # layer values: timetable, visual
    layer: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
    # object_type values: gtfs_stop, osm_feature
    object_type: Mapped[str] = mapped_column(String(64), nullable=False, index=True)

    dataset_id: Mapped[int] = mapped_column(
        ForeignKey("datasets.id"), nullable=False, index=True
    )
    # Integer id of the linked object; no foreign key is declared because the
    # target table depends on ``object_type``.
    object_id: Mapped[int] = mapped_column(Integer, nullable=False, index=True)
    external_id: Mapped[str] = mapped_column(Text, nullable=False)

    role: Mapped[Optional[str]] = mapped_column(String(64), nullable=True)
    # Confidence defaults to 1.0.
    confidence: Mapped[float] = mapped_column(Float, nullable=False, default=1.0)
    distance_m: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
    metadata_json: Mapped[Optional[str]] = mapped_column(Text, nullable=True)

    canonical_stop: Mapped[CanonicalStop] = relationship()
|
||||
|
||||
|
||||
class RoutePattern(Base):
    """ORM model for the ``route_patterns`` table.

    ``pattern_key`` is unique. A pattern has a required geometry and source
    kind, and may reference an OSM feature and/or a GTFS route.
    """

    __tablename__ = "route_patterns"

    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    pattern_key: Mapped[str] = mapped_column(
        String(255), nullable=False, unique=True, index=True
    )

    # Optional descriptive attributes.
    route_ref: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    route_name: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    mode: Mapped[Optional[str]] = mapped_column(String(64), nullable=True, index=True)
    route_scope: Mapped[Optional[str]] = mapped_column(String(64), nullable=True, index=True)
    operator_name: Mapped[Optional[str]] = mapped_column(Text, nullable=True)

    # source_kind values: osm, gtfs_proposed
    source_kind: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
    # Status defaults to "active".
    status: Mapped[str] = mapped_column(
        String(64), nullable=False, default="active", index=True
    )

    # Optional provenance pointers; ``gtfs_shape_id`` is a plain string.
    osm_feature_id: Mapped[Optional[int]] = mapped_column(
        ForeignKey("osm_features.id"), nullable=True, index=True
    )
    gtfs_route_id: Mapped[Optional[int]] = mapped_column(
        ForeignKey("gtfs_routes.id"), nullable=True, index=True
    )
    gtfs_shape_id: Mapped[Optional[str]] = mapped_column(
        String(255), nullable=True, index=True
    )

    # Required geometry text plus optional min/max lon/lat extents.
    geometry_geojson: Mapped[str] = mapped_column(Text, nullable=False)
    min_lon: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
    min_lat: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
    max_lon: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
    max_lat: Mapped[Optional[float]] = mapped_column(Float, nullable=True)

    # Confidence defaults to 1.0.
    confidence: Mapped[float] = mapped_column(Float, nullable=False, default=1.0)
    metadata_json: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=now_utc, nullable=False
    )

    # ``OsmFeature`` is referenced by name because it is declared later in
    # the module.
    osm_feature: Mapped[Optional["OsmFeature"]] = relationship()
    gtfs_route: Mapped[Optional[GtfsRoute]] = relationship()
|
||||
|
||||
|
||||
class RoutePatternStop(Base):
    """ORM model for the ``route_pattern_stops`` table.

    Associates a canonical stop with a route pattern at a given
    ``sequence``; ``(route_pattern_id, sequence)`` is unique.
    """

    __tablename__ = "route_pattern_stops"
    __table_args__ = (
        UniqueConstraint(
            "route_pattern_id", "sequence", name="uq_route_pattern_stop_sequence"
        ),
    )

    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    route_pattern_id: Mapped[int] = mapped_column(
        ForeignKey("route_patterns.id"), nullable=False, index=True
    )
    canonical_stop_id: Mapped[int] = mapped_column(
        ForeignKey("canonical_stops.id"), nullable=False, index=True
    )

    sequence: Mapped[int] = mapped_column(Integer, nullable=False)
    distance_along: Mapped[Optional[float]] = mapped_column(Float, nullable=True)

    # Defaults: source kind "timetable", confidence 1.0.
    source_kind: Mapped[str] = mapped_column(
        String(64), nullable=False, default="timetable"
    )
    confidence: Mapped[float] = mapped_column(Float, nullable=False, default=1.0)

    route_pattern: Mapped[RoutePattern] = relationship()
    canonical_stop: Mapped[CanonicalStop] = relationship()
|
||||
|
||||
|
||||
class GtfsRoutePatternLink(Base):
    """ORM model for the ``gtfs_route_pattern_links`` table.

    Links a GTFS route and shape id to a route pattern;
    ``(dataset_id, route_id, shape_id)`` is unique.
    """

    __tablename__ = "gtfs_route_pattern_links"
    __table_args__ = (
        UniqueConstraint(
            "dataset_id", "route_id", "shape_id", name="uq_gtfs_route_pattern_shape"
        ),
    )

    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    dataset_id: Mapped[int] = mapped_column(
        ForeignKey("datasets.id"), nullable=False, index=True
    )

    # The route is stored both as a foreign key and as its string id.
    gtfs_route_id: Mapped[int] = mapped_column(
        ForeignKey("gtfs_routes.id"), nullable=False, index=True
    )
    route_id: Mapped[str] = mapped_column(String(255), nullable=False, index=True)
    shape_id: Mapped[str] = mapped_column(String(255), nullable=False, index=True)
    route_pattern_id: Mapped[int] = mapped_column(
        ForeignKey("route_patterns.id"), nullable=False, index=True
    )

    # Defaults: confidence 0, status "linked".
    confidence: Mapped[float] = mapped_column(Float, nullable=False, default=0)
    status: Mapped[str] = mapped_column(
        String(64), nullable=False, default="linked", index=True
    )
    source_kind: Mapped[str] = mapped_column(String(64), nullable=False)
    reasons_json: Mapped[Optional[str]] = mapped_column(Text, nullable=True)

    gtfs_route: Mapped[GtfsRoute] = relationship()
    route_pattern: Mapped[RoutePattern] = relationship()
|
||||
|
||||
|
||||
class GtfsTripRoutePatternLink(Base):
    """ORM model for the ``gtfs_trip_route_pattern_links`` table.

    Links a trip id to a route pattern; ``(dataset_id, trip_id)`` is unique.
    """

    __tablename__ = "gtfs_trip_route_pattern_links"
    __table_args__ = (
        UniqueConstraint("dataset_id", "trip_id", name="uq_gtfs_trip_route_pattern"),
    )

    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    dataset_id: Mapped[int] = mapped_column(
        ForeignKey("datasets.id"), nullable=False, index=True
    )

    # Required, indexed string identifiers; none is a foreign key.
    trip_id: Mapped[str] = mapped_column(String(255), nullable=False, index=True)
    route_id: Mapped[str] = mapped_column(String(255), nullable=False, index=True)
    shape_id: Mapped[str] = mapped_column(String(255), nullable=False, index=True)

    route_pattern_id: Mapped[int] = mapped_column(
        ForeignKey("route_patterns.id"), nullable=False, index=True
    )

    # Defaults: confidence 0, status "linked".
    source_kind: Mapped[str] = mapped_column(String(64), nullable=False)
    confidence: Mapped[float] = mapped_column(Float, nullable=False, default=0)
    status: Mapped[str] = mapped_column(
        String(64), nullable=False, default="linked", index=True
    )

    route_pattern: Mapped[RoutePattern] = relationship()
|
||||
|
||||
|
||||
class OsmFeature(Base):
    """ORM model for the ``osm_features`` table.

    ``(dataset_id, osm_type, osm_id)`` is unique. Only the dataset, the OSM
    identity and ``kind`` are required.
    """

    __tablename__ = "osm_features"
    __table_args__ = (
        UniqueConstraint(
            "dataset_id", "osm_type", "osm_id", name="uq_osm_feature_dataset_type_id"
        ),
    )

    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    dataset_id: Mapped[int] = mapped_column(
        ForeignKey("datasets.id"), nullable=False, index=True
    )

    # OSM identity; the id is stored as a string.
    osm_type: Mapped[str] = mapped_column(String(32), nullable=False)
    osm_id: Mapped[str] = mapped_column(String(64), nullable=False)

    # kind values: route, stop, terminal, station, infra
    kind: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
    mode: Mapped[Optional[str]] = mapped_column(String(64), nullable=True, index=True)
    route_scope: Mapped[Optional[str]] = mapped_column(String(64), nullable=True, index=True)

    # Optional descriptive text columns.
    name: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    ref: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    operator: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    network: Mapped[Optional[str]] = mapped_column(Text, nullable=True)

    # Optional geometry text and min/max lon/lat extents.
    geometry_geojson: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    min_lon: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
    min_lat: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
    max_lon: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
    max_lat: Mapped[Optional[float]] = mapped_column(Float, nullable=True)

    tags_json: Mapped[Optional[str]] = mapped_column(Text, nullable=True)

    # Indexed lookup keys.
    route_key: Mapped[Optional[str]] = mapped_column(Text, nullable=True, index=True)
    operator_key: Mapped[Optional[str]] = mapped_column(Text, nullable=True, index=True)
|
||||
|
||||
|
||||
class OsmAddress(Base):
    """ORM model for the ``osm_addresses`` table.

    ``(dataset_id, osm_type, osm_id)`` is unique. The display name, search
    text and point coordinates are required; the individual address parts
    are optional.
    """

    __tablename__ = "osm_addresses"
    __table_args__ = (
        UniqueConstraint(
            "dataset_id", "osm_type", "osm_id", name="uq_osm_address_dataset_type_id"
        ),
    )

    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    dataset_id: Mapped[int] = mapped_column(
        ForeignKey("datasets.id"), nullable=False, index=True
    )

    # OSM identity; the id is stored as a string.
    osm_type: Mapped[str] = mapped_column(String(32), nullable=False)
    osm_id: Mapped[str] = mapped_column(String(64), nullable=False)

    # Optional address parts.
    housenumber: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    street: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    place: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    postcode: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    city: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    country: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    unit: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    name: Mapped[Optional[str]] = mapped_column(Text, nullable=True)

    # Required text columns.
    display_name: Mapped[str] = mapped_column(Text, nullable=False)
    search_text: Mapped[str] = mapped_column(Text, nullable=False)

    # Required point coordinates.
    lat: Mapped[float] = mapped_column(Float, nullable=False)
    lon: Mapped[float] = mapped_column(Float, nullable=False)

    # Optional min/max lon/lat extents and geometry text.
    min_lon: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
    min_lat: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
    max_lon: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
    max_lat: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
    geometry_geojson: Mapped[Optional[str]] = mapped_column(Text, nullable=True)

    tags_json: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=now_utc, nullable=False
    )
|
||||
|
||||
|
||||
class RoutingNode(Base):
    """ORM model for the ``routing_nodes`` table.

    ``(dataset_id, osm_node_id)`` is unique. Every column is required.
    """

    __tablename__ = "routing_nodes"
    __table_args__ = (
        UniqueConstraint("dataset_id", "osm_node_id", name="uq_routing_node_dataset_osm"),
    )

    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    dataset_id: Mapped[int] = mapped_column(
        ForeignKey("datasets.id"), nullable=False, index=True
    )

    # Stored as BigInteger rather than Integer.
    osm_node_id: Mapped[int] = mapped_column(BigInteger, nullable=False, index=True)

    lat: Mapped[float] = mapped_column(Float, nullable=False)
    lon: Mapped[float] = mapped_column(Float, nullable=False)

    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=now_utc, nullable=False
    )
|
||||
|
||||
|
||||
class RoutingEdge(Base):
    """ORM model for the ``routing_edges`` table.

    Each row describes one edge between a source and a target OSM node,
    with endpoint coordinates, a length, optional per-direction walk and
    drive costs, and a required geometry with extents. No uniqueness
    constraint is declared.
    """

    __tablename__ = "routing_edges"

    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    dataset_id: Mapped[int] = mapped_column(
        ForeignKey("datasets.id"), nullable=False, index=True
    )

    # OSM identifiers, stored as BigInteger and indexed.
    osm_way_id: Mapped[int] = mapped_column(BigInteger, nullable=False, index=True)
    source_osm_node_id: Mapped[int] = mapped_column(BigInteger, nullable=False, index=True)
    target_osm_node_id: Mapped[int] = mapped_column(BigInteger, nullable=False, index=True)

    # Endpoint coordinates, copied onto the edge row.
    source_lat: Mapped[float] = mapped_column(Float, nullable=False)
    source_lon: Mapped[float] = mapped_column(Float, nullable=False)
    target_lat: Mapped[float] = mapped_column(Float, nullable=False)
    target_lon: Mapped[float] = mapped_column(Float, nullable=False)

    highway: Mapped[Optional[str]] = mapped_column(String(64), nullable=True, index=True)
    name: Mapped[Optional[str]] = mapped_column(Text, nullable=True)

    # Required length, then optional costs per travel type and direction
    # (presumably NULL means not traversable that way -- confirm against
    # the router).
    length_m: Mapped[float] = mapped_column(Float, nullable=False)
    walk_cost_s: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
    reverse_walk_cost_s: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
    drive_cost_s: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
    reverse_drive_cost_s: Mapped[Optional[float]] = mapped_column(Float, nullable=True)

    # Geometry and extents are all required on this table.
    geometry_geojson: Mapped[str] = mapped_column(Text, nullable=False)
    min_lon: Mapped[float] = mapped_column(Float, nullable=False)
    min_lat: Mapped[float] = mapped_column(Float, nullable=False)
    max_lon: Mapped[float] = mapped_column(Float, nullable=False)
    max_lat: Mapped[float] = mapped_column(Float, nullable=False)

    tags_json: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=now_utc, nullable=False
    )
|
||||
|
||||
|
||||
class RouteMatch(Base):
    """ORM model for the ``route_matches`` table.

    Pairs a GTFS route with an optional OSM feature and records a
    confidence, a status and the rule source that produced the match.
    """

    __tablename__ = "route_matches"

    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    gtfs_route_id: Mapped[int] = mapped_column(
        ForeignKey("gtfs_routes.id"), nullable=False, index=True
    )
    # Nullable: a match row can exist without an OSM feature.
    osm_feature_id: Mapped[Optional[int]] = mapped_column(
        ForeignKey("osm_features.id"), nullable=True, index=True
    )

    # Confidence defaults to 0.
    confidence: Mapped[float] = mapped_column(Float, nullable=False, default=0)
    # status values: matched, probable, weak, missing, accepted, rejected
    status: Mapped[str] = mapped_column(String(64), nullable=False)
    # rule_source values: auto, manual
    rule_source: Mapped[str] = mapped_column(String(64), default="auto", nullable=False)
    reasons_json: Mapped[Optional[str]] = mapped_column(Text, nullable=True)

    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=now_utc, nullable=False
    )
    # NOTE(review): no ``onupdate`` hook is configured, so this is presumably
    # refreshed explicitly by callers -- confirm.
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=now_utc, nullable=False
    )

    gtfs_route: Mapped[GtfsRoute] = relationship()
    osm_feature: Mapped[Optional[OsmFeature]] = relationship()
|
||||
|
||||
|
||||
class MatchRule(Base):
    """ORM model for the ``match_rules`` table.

    Each rule has a type, a selector and an action (both stored as text),
    an optional note and an ``active`` flag.
    """

    __tablename__ = "match_rules"

    id: Mapped[int] = mapped_column(Integer, primary_key=True)

    # rule_type values: accept_match, reject_match, alias, force_operator
    rule_type: Mapped[str] = mapped_column(String(64), nullable=False)

    # Required text payloads (presumably serialized JSON, going by the names).
    selector_json: Mapped[str] = mapped_column(Text, nullable=False)
    action_json: Mapped[str] = mapped_column(Text, nullable=False)

    note: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    # New rules default to active.
    active: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False)
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=now_utc, nullable=False
    )
|
||||
|
||||
|
||||
class JourneySearchCache(Base):
    """ORM model for the ``journey_search_cache`` table.

    ``cache_key`` is unique. Every column is required, including the
    expiry timestamp, which has no default.
    """

    __tablename__ = "journey_search_cache"

    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    cache_key: Mapped[str] = mapped_column(
        String(128), nullable=False, unique=True, index=True
    )
    cache_type: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
    payload_json: Mapped[str] = mapped_column(Text, nullable=False)

    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=now_utc, nullable=False, index=True
    )
    # NOTE(review): no ``onupdate`` hook is configured, so this is presumably
    # refreshed explicitly by callers -- confirm.
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), default=now_utc, nullable=False
    )
    # Must be supplied by the writer.
    expires_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), nullable=False, index=True
    )
|
||||
|
||||
|
||||
class TravelRequest(Base):
    """ORM row of the ``travel_requests`` table.

    Holds the search parameters of one request and owns the ``Itinerary``
    rows attached to it through the ``itineraries`` relationship.
    """

    __tablename__ = "travel_requests"

    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    origin_stop_id: Mapped[str] = mapped_column(Text, nullable=False)
    destination_stop_id: Mapped[str] = mapped_column(Text, nullable=False)
    via_stop_id: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    departure_time: Mapped[str] = mapped_column(String(32), nullable=False)
    service_date: Mapped[Optional[str]] = mapped_column(
        String(10),
        nullable=True,
        index=True,
    )
    max_transfers: Mapped[int] = mapped_column(Integer, nullable=False, default=1)
    transfer_seconds: Mapped[int] = mapped_column(Integer, nullable=False, default=120)
    source_filter: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    preferences_json: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        default=now_utc,
        nullable=False,
        index=True,
    )

    # Itineraries are deleted together with their request (delete-orphan).
    itineraries: Mapped[list["Itinerary"]] = relationship(
        back_populates="request",
        cascade="all, delete-orphan",
    )
|
||||
|
||||
|
||||
class Itinerary(Base):
    """ORM row of the ``itineraries`` table.

    An itinerary belongs to one ``TravelRequest`` and owns its
    ``ItineraryLeg`` rows.
    """

    __tablename__ = "itineraries"

    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    request_id: Mapped[int] = mapped_column(
        ForeignKey("travel_requests.id"),
        nullable=False,
        index=True,
    )
    title: Mapped[str] = mapped_column(Text, nullable=False)
    family: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
    status: Mapped[str] = mapped_column(
        String(64),
        nullable=False,
        default="candidate",
        index=True,
    )
    saved: Mapped[bool] = mapped_column(
        Boolean,
        nullable=False,
        default=False,
        index=True,
    )
    summary_json: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    score_json: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    payload_json: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        default=now_utc,
        nullable=False,
        index=True,
    )
    # Only an insert-time default is declared here; no ``onupdate`` hook.
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        default=now_utc,
        nullable=False,
    )

    request: Mapped[TravelRequest] = relationship(back_populates="itineraries")
    # Legs are deleted together with their itinerary (delete-orphan).
    legs: Mapped[list["ItineraryLeg"]] = relationship(
        back_populates="itinerary",
        cascade="all, delete-orphan",
    )
|
||||
|
||||
|
||||
class ItineraryLeg(Base):
    """ORM row of the ``itinerary_legs`` table.

    Legs are ordered by ``sequence``, which is unique within one itinerary.
    """

    __tablename__ = "itinerary_legs"
    __table_args__ = (
        UniqueConstraint("itinerary_id", "sequence", name="uq_itinerary_leg_sequence"),
    )

    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    itinerary_id: Mapped[int] = mapped_column(
        ForeignKey("itineraries.id"),
        nullable=False,
        index=True,
    )
    sequence: Mapped[int] = mapped_column(Integer, nullable=False)
    mode: Mapped[Optional[str]] = mapped_column(String(64), nullable=True, index=True)
    route_ref: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    route_name: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    from_name: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    to_name: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    departure_time: Mapped[Optional[str]] = mapped_column(String(32), nullable=True)
    arrival_time: Mapped[Optional[str]] = mapped_column(String(32), nullable=True)
    locked: Mapped[bool] = mapped_column(
        Boolean,
        nullable=False,
        default=False,
        index=True,
    )
    payload_json: Mapped[Optional[str]] = mapped_column(Text, nullable=True)

    itinerary: Mapped[Itinerary] = relationship(back_populates="legs")
|
||||
111
app/osm_classification.py
Normal file
111
app/osm_classification.py
Normal file
@@ -0,0 +1,111 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
from typing import Mapping
|
||||
|
||||
|
||||
# Scope labels returned by the route-scope classifier.
LOCAL_SCOPE = "local"
REGIONAL_SCOPE = "regional"
LONG_DISTANCE_SCOPE = "long_distance"
UNKNOWN_SCOPE = "unknown"
OSM_ROUTE_SCOPE_CLASSIFIER_VERSION = "route_scope_v2"

# Transport modes that decide (or help decide) the scope on their own.
BUS_MODES = {"bus", "trolleybus"}
LOCAL_MODES = {
    "tram",
    "light_rail",
    "subway",
    "ferry",
    "funicular",
    "aerialway",
    "monorail",
}
LONG_DISTANCE_MODES = {"coach"}

# Lower-cased tag values that map directly to a scope.
LONG_DISTANCE_SERVICE_VALUES = {
    "high_speed",
    "long_distance",
    "intercity",
    "international",
    "night",
    "sleeper",
}
REGIONAL_SERVICE_VALUES = {"regional", "interurban", "commuter", "branch", "suburban"}
LOCAL_SERVICE_VALUES = {
    "local",
    "urban",
    "city",
    "subway",
    "tram",
    "light_rail",
    "s-bahn",
    "sbahn",
}

# Prefix heuristics applied to upper-cased text: a known prefix at the very
# start, followed either by a word boundary or directly by a digit.
LONG_DISTANCE_PREFIX_RE = re.compile(
    r"^(ICE|IC|EC|ECE|EN|NJ|RJ|RJX|TGV|THA|EST|FLX|WB)\b|^(ICE|IC|EC|ECE|EN|NJ|RJ|RJX|TGV|THA|EST|FLX|WB)\d"
)
REGIONAL_PREFIX_RE = re.compile(
    r"^(IRE|RE|RB|RER|TER|REX|MEX|ALX|WFB|R)\b|^(IRE|RE|RB|RER|TER|REX|MEX|ALX|WFB|R)\d"
)
LOCAL_TRAIN_PREFIX_RE = re.compile(r"^(S|S-BAHN)\b|^S\d")
|
||||
|
||||
|
||||
def infer_osm_route_scope(
    *,
    mode: str | None,
    ref: str | None = None,
    name: str | None = None,
    network: str | None = None,
    tags: Mapping[str, object] | str | None = None,
) -> str | None:
    """Classify a public-transport route into a display scope.

    OSM tagging varies by country and operator, so explicit service tags are
    consulted first and conservative reference-prefix / keyword heuristics are
    only used as a fallback.

    Returns one of the ``*_SCOPE`` constants, or ``None`` when the mode is not
    one this classifier handles.
    """
    mode_key = (mode or "").strip().lower()
    tag_map = _tags_dict(tags)

    # Collect the normalized values of every scope-relevant tag that is set.
    tag_values = set()
    for key in ("service", "train", "bus", "passenger", "network:type", "route_scope"):
        raw = tag_map.get(key)
        if raw:
            tag_values.add(str(raw).strip().lower())

    # Explicit tags win; precedence is long-distance, then local, then regional.
    for known_values, scope in (
        (LONG_DISTANCE_SERVICE_VALUES, LONG_DISTANCE_SCOPE),
        (LOCAL_SERVICE_VALUES, LOCAL_SCOPE),
        (REGIONAL_SERVICE_VALUES, REGIONAL_SCOPE),
    ):
        if not tag_values.isdisjoint(known_values):
            return scope

    # Some modes imply the scope on their own.
    if mode_key in LOCAL_MODES:
        return LOCAL_SCOPE
    if mode_key in LONG_DISTANCE_MODES:
        return LONG_DISTANCE_SCOPE

    text = _classification_text(ref, name, network, tag_map)

    def mentions(*markers: str) -> bool:
        return any(marker in text for marker in markers)

    if mode_key in BUS_MODES:
        if mentions("FLIXBUS", "EUROLINES", "INTERCITYBUS", "IC BUS", "LONG DISTANCE", "FERNBUS"):
            return LONG_DISTANCE_SCOPE
        if mentions("REGIONALBUS", "REGIOBUS", "REGIONAL BUS", "REGIONALVERKEHR", "REGIONAL VERKEHR"):
            return REGIONAL_SCOPE
        # Buses without a long-distance or regional marker count as local.
        return LOCAL_SCOPE

    if mode_key != "train":
        return None

    if LONG_DISTANCE_PREFIX_RE.search(text) or mentions("INTERCITY", "EUROCITY", "NIGHTJET", "FLIXTRAIN"):
        return LONG_DISTANCE_SCOPE
    if LOCAL_TRAIN_PREFIX_RE.search(text) or mentions("S-BAHN", "SBAHN"):
        return LOCAL_SCOPE
    if REGIONAL_PREFIX_RE.search(text) or mentions("REGIONAL", "REGIO", "REGIONALBAHN", "REGIONALEXPRESS"):
        return REGIONAL_SCOPE
    return UNKNOWN_SCOPE
|
||||
|
||||
|
||||
def infer_osm_route_scope_from_tags(
    mode: str | None,
    ref: str | None,
    name: str | None,
    network: str | None,
    tags_json: str | None,
) -> str | None:
    """Positional-argument wrapper around :func:`infer_osm_route_scope`.

    ``tags_json`` is forwarded unchanged as the ``tags`` argument.
    """
    return infer_osm_route_scope(
        mode=mode,
        ref=ref,
        name=name,
        network=network,
        tags=tags_json,
    )
|
||||
|
||||
|
||||
def _tags_dict(tags: Mapping[str, object] | str | None) -> dict[str, object]:
|
||||
if isinstance(tags, str):
|
||||
try:
|
||||
data = json.loads(tags or "{}")
|
||||
except json.JSONDecodeError:
|
||||
return {}
|
||||
return data if isinstance(data, dict) else {}
|
||||
if isinstance(tags, Mapping):
|
||||
return dict(tags)
|
||||
return {}
|
||||
|
||||
|
||||
def _classification_text(ref: str | None, name: str | None, network: str | None, tags: Mapping[str, object]) -> str:
|
||||
parts = [
|
||||
ref or "",
|
||||
name or "",
|
||||
network or "",
|
||||
str(tags.get("ref") or ""),
|
||||
str(tags.get("name") or ""),
|
||||
str(tags.get("network") or ""),
|
||||
str(tags.get("network:short") or ""),
|
||||
]
|
||||
return " ".join(parts).strip().upper().replace("_", " ")
|
||||
981
app/osm_storage.py
Normal file
981
app/osm_storage.py
Normal file
@@ -0,0 +1,981 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sqlite3
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Iterator, Sequence
|
||||
|
||||
from sqlalchemy import and_, func, insert, not_, or_, select, text
|
||||
from sqlalchemy.dialects.postgresql import insert as postgresql_insert
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.config import settings
|
||||
from app.models import Dataset, OsmFeature
|
||||
from app.spatial import refresh_postgis_geometries
|
||||
|
||||
|
||||
# Key under which the storage description lives inside dataset metadata.
OSM_STORAGE_METADATA_KEY = "osm_storage"

# Storage mode names.
OSM_STORAGE_MAIN = "main"
OSM_STORAGE_SIDECAR_FEATURES = "sidecar_features"

SQLITE_IN_CHUNK_SIZE = 800

# Sidecar indexes that include the ``route_scope`` column.
OSM_SIDECAR_ROUTE_SCOPE_INDEXES = ["ix_osm_sidecar_scope_bbox"]

# Feature columns other than ``id``, in the order used for sidecar inserts.
OSM_FEATURE_COLUMNS = [
    "dataset_id",
    "osm_type",
    "osm_id",
    "kind",
    "mode",
    "route_scope",
    "name",
    "ref",
    "operator",
    "network",
    "geometry_geojson",
    "min_lon",
    "min_lat",
    "max_lon",
    "max_lat",
    "tags_json",
    "route_key",
    "operator_key",
]
|
||||
|
||||
|
||||
def effective_osm_feature_storage(value: str | None = None) -> str:
    """Resolve the OSM feature storage mode that is actually in effect.

    ``value`` overrides ``settings.osm_feature_storage``; when neither is set
    the sidecar mode is assumed. Main storage is chosen when the configured
    value is one of the main-database aliases, or when the database is
    PostgreSQL and sidecars are not enabled for it.
    """
    raw = value or settings.osm_feature_storage or OSM_STORAGE_SIDECAR_FEATURES
    normalized = str(raw).strip().lower()
    main_aliases = {OSM_STORAGE_MAIN, "main", "main_db", "postgres", "postgresql"}

    use_main = normalized in main_aliases
    if not use_main:
        # The PostgreSQL check is only consulted when no alias matched.
        use_main = settings.is_postgresql_database and not settings.postgres_use_sidecars
    return OSM_STORAGE_MAIN if use_main else OSM_STORAGE_SIDECAR_FEATURES
|
||||
|
||||
|
||||
class MissingOsmSidecar(FileNotFoundError):
    """Raised when a dataset has no usable OSM sidecar database."""
|
||||
|
||||
|
||||
def dataset_metadata(dataset: Dataset) -> dict:
    """Return the dataset's ``metadata_json`` decoded as a dict.

    Missing, malformed, or non-object JSON all yield an empty dict.
    """
    raw = dataset.metadata_json or "{}"
    try:
        decoded = json.loads(raw)
    except json.JSONDecodeError:
        return {}
    if isinstance(decoded, dict):
        return decoded
    return {}
|
||||
|
||||
|
||||
def features_are_sidecar(dataset: Dataset | None) -> bool:
    """Tell whether the dataset's OSM features are kept in a sidecar file.

    The per-table entry ``tables["osm_features"]`` takes precedence when a
    ``tables`` dict is present; otherwise the overall ``mode`` decides.
    """
    if dataset is None:
        return False
    storage = dataset_metadata(dataset).get(OSM_STORAGE_METADATA_KEY)
    if not isinstance(storage, dict):
        return False
    per_table = storage.get("tables")
    if not isinstance(per_table, dict):
        return storage.get("mode") == OSM_STORAGE_SIDECAR_FEATURES
    return per_table.get("osm_features") == "sidecar"
|
||||
|
||||
|
||||
def sidecar_path(dataset: Dataset | None) -> Path | None:
    """Return the sidecar file path recorded in the dataset metadata, if any.

    The path is returned as recorded; its existence is not checked here.
    """
    if dataset is None:
        return None
    storage = dataset_metadata(dataset).get(OSM_STORAGE_METADATA_KEY)
    recorded = storage.get("sidecar_path") if isinstance(storage, dict) else None
    return Path(str(recorded)) if recorded else None
|
||||
|
||||
|
||||
def dataset_sidecar_paths(dataset: Dataset) -> list[Path]:
    """Return the dataset's recorded sidecar path as a zero- or one-item list."""
    recorded = sidecar_path(dataset)
    if recorded is None:
        return []
    return [recorded]
|
||||
|
||||
|
||||
def missing_sidecar_paths(dataset: Dataset | None) -> list[str]:
    """List the sidecar path (as a string) when it is recorded but absent on disk.

    Datasets that do not use sidecar storage, or that record no path, yield
    an empty list.
    """
    if not features_are_sidecar(dataset):
        return []
    recorded = sidecar_path(dataset)
    if recorded is not None and not recorded.exists():
        return [str(recorded)]
    return []
|
||||
|
||||
|
||||
@contextmanager
def sidecar_connection(dataset: Dataset) -> Iterator[sqlite3.Connection]:
    """Open the dataset's OSM sidecar read-only and close it on exit.

    Yields a connection whose rows are ``sqlite3.Row`` instances.

    Raises:
        MissingOsmSidecar: if the dataset records no sidecar path, or the
            recorded file does not exist.
    """
    path = sidecar_path(dataset)
    if path is None:
        raise MissingOsmSidecar(f"dataset #{dataset.id} does not reference an OSM sidecar")
    if not path.exists():
        raise MissingOsmSidecar(f"OSM sidecar does not exist: {path}")
    # Build the URI with as_uri() so characters that are special in URIs
    # ('?', '#', '%') are percent-encoded instead of being read as the query
    # string, a fragment, or an escape sequence. as_uri() needs an absolute
    # path, hence resolve().
    read_only_uri = f"{path.resolve().as_uri()}?mode=ro"
    connection = sqlite3.connect(read_only_uri, uri=True)
    connection.row_factory = sqlite3.Row
    try:
        yield connection
    finally:
        connection.close()
|
||||
|
||||
|
||||
@contextmanager
def writable_sidecar_connection(dataset: Dataset) -> Iterator[sqlite3.Connection]:
    """Open the dataset's OSM sidecar for writing and close it on exit.

    The busy timeout comes from ``settings.sqlite_busy_timeout_ms`` and
    ``synchronous`` is set to NORMAL before the connection is handed out.

    Raises:
        MissingOsmSidecar: if the dataset records no sidecar path, or the
            recorded file does not exist.
    """
    target = sidecar_path(dataset)
    if target is None:
        raise MissingOsmSidecar(f"dataset #{dataset.id} does not reference an OSM sidecar")
    if not target.exists():
        raise MissingOsmSidecar(f"OSM sidecar does not exist: {target}")

    db = sqlite3.connect(target)
    db.row_factory = sqlite3.Row
    try:
        # PRAGMAs run inside the try so a failure still closes the handle.
        timeout_ms = int(settings.sqlite_busy_timeout_ms)
        db.execute(f"PRAGMA busy_timeout={timeout_ms}")
        db.execute("PRAGMA synchronous=NORMAL")
        yield db
    finally:
        db.close()
|
||||
|
||||
|
||||
def create_osm_sidecar(dataset: Dataset, rows: Sequence[dict[str, object]], *, source_hash: str | None = None) -> dict:
    """Write *rows* into a fresh SQLite sidecar and describe the result.

    Any existing file at the target path is removed first. Rows are
    de-duplicated, numbered from 1, inserted in batches of 5000, and the
    indexes are built after the data is committed.

    Returns a dict with the storage mode, the sidecar path, the number of
    inserted features, the number of skipped duplicates, and per-kind counts.
    """
    target = _new_sidecar_path(dataset, source_hash or dataset.sha256)
    target.parent.mkdir(parents=True, exist_ok=True)
    if target.exists():
        target.unlink()

    column_list = ", ".join(["id", *OSM_FEATURE_COLUMNS])
    placeholder_list = ", ".join(["?"] * (len(OSM_FEATURE_COLUMNS) + 1))
    insert_sql = f"""
        INSERT INTO osm_features
        ({column_list})
        VALUES
        ({placeholder_list})
    """
    batch_size = 5000

    db = sqlite3.connect(target)
    try:
        # The file is built from scratch, so journaling and fsyncs are turned off.
        db.execute("PRAGMA journal_mode=OFF")
        db.execute("PRAGMA synchronous=OFF")
        _create_schema(db)

        unique_rows, duplicate_count = dedupe_osm_feature_rows(rows)
        inserted = 0
        counts = {"route": 0, "stop": 0, "station": 0, "terminal": 0, "infra": 0, "feature": 0}
        pending = []
        for row_id, row in enumerate(unique_rows, start=1):
            kind = str(row.get("kind") or "feature")
            counts[kind] = counts.get(kind, 0) + 1
            pending.append((row_id, *(row.get(column) for column in OSM_FEATURE_COLUMNS)))
            if len(pending) < batch_size:
                continue
            db.executemany(insert_sql, pending)
            inserted += len(pending)
            pending.clear()
        if pending:
            db.executemany(insert_sql, pending)
            inserted += len(pending)
        db.commit()

        _create_indexes(db)
        db.commit()
    finally:
        db.close()

    return {
        "mode": OSM_STORAGE_SIDECAR_FEATURES,
        "tables": {"osm_features": "sidecar"},
        "sidecar_path": str(target),
        "features": inserted,
        "duplicate_features_skipped": duplicate_count,
        "counts": counts,
    }
|
||||
|
||||
|
||||
def ensure_osm_sidecar_schema(connection: sqlite3.Connection) -> None:
    """Add the ``route_scope`` column to a sidecar that lacks it.

    Does nothing when the column is already present.
    """
    if "route_scope" in _sidecar_columns(connection):
        return
    connection.execute("ALTER TABLE osm_features ADD COLUMN route_scope TEXT")
    connection.commit()
|
||||
|
||||
|
||||
def drop_osm_sidecar_route_scope_indexes(connection: sqlite3.Connection) -> None:
    """Drop every index named in ``OSM_SIDECAR_ROUTE_SCOPE_INDEXES`` if it exists.

    No commit is issued here.
    """
    for name in OSM_SIDECAR_ROUTE_SCOPE_INDEXES:
        # Index names come from a module constant, not from user input.
        drop_sql = f"DROP INDEX IF EXISTS {name}"
        connection.execute(drop_sql)
|
||||
|
||||
|
||||
def rebuild_osm_sidecar_indexes(connection: sqlite3.Connection) -> None:
    """Create any sidecar indexes that are missing (public wrapper)."""
    _create_indexes(connection)
|
||||
|
||||
|
||||
def osm_feature_count(session: Session, dataset_id: int, *, kind: str | Sequence[str] | None = None) -> int:
    """Count a dataset's OSM features, optionally restricted to some kinds.

    Sidecar-backed datasets are counted in their sidecar file (0 when the
    file is missing); all others are counted in the main database.
    """
    dataset = session.get(Dataset, dataset_id)
    kinds = _as_list(kind)

    if not features_are_sidecar(dataset):
        stmt = select(func.count()).select_from(OsmFeature).where(OsmFeature.dataset_id == dataset_id)
        if kinds:
            stmt = stmt.where(OsmFeature.kind.in_(kinds))
        return int(session.scalar(stmt) or 0)

    sql = "SELECT COUNT(*) FROM osm_features"
    params: list[object] = []
    if kinds:
        marks = ", ".join(["?"] * len(kinds))
        sql += f" WHERE kind IN ({marks})"
        params.extend(kinds)
    try:
        with sidecar_connection(dataset) as connection:
            total = connection.execute(sql, params).fetchone()[0]
            return int(total or 0)
    except MissingOsmSidecar:
        return 0
|
||||
|
||||
|
||||
def osm_feature_bbox(
    session: Session,
    dataset_ids: Sequence[int],
    *,
    kinds: Sequence[str] | None = None,
) -> tuple[float | None, float | None, float | None, float | None]:
    """Return the combined ``(min_lon, min_lat, max_lon, max_lat)`` of the datasets.

    Features in the main database and in sidecar files are both considered;
    missing sidecars are skipped. All four values are ``None`` when nothing
    contributes an extent.
    """
    nothing = (None, None, None, None)
    if not dataset_ids:
        return nothing

    wanted_ids = [int(value) for value in dataset_ids]
    found = session.scalars(select(Dataset).where(Dataset.id.in_(wanted_ids))).all()
    datasets = {dataset.id: dataset for dataset in found}

    extents: list[tuple[float, float, float, float]] = []

    def remember(row) -> None:
        # Aggregates over zero rows come back as NULLs; ignore those.
        if row is not None and None not in row:
            extents.append((float(row[0]), float(row[1]), float(row[2]), float(row[3])))

    main_ids = [dataset_id for dataset_id, dataset in datasets.items() if not features_are_sidecar(dataset)]
    if main_ids:
        stmt = select(
            func.min(OsmFeature.min_lon),
            func.min(OsmFeature.min_lat),
            func.max(OsmFeature.max_lon),
            func.max(OsmFeature.max_lat),
        ).where(OsmFeature.dataset_id.in_(main_ids))
        if kinds:
            stmt = stmt.where(OsmFeature.kind.in_(list(kinds)))
        remember(session.execute(stmt).one())

    # The sidecar query is the same for every sidecar, so build it once.
    sidecar_sql = "SELECT MIN(min_lon), MIN(min_lat), MAX(max_lon), MAX(max_lat) FROM osm_features"
    sidecar_params: list[object] = []
    if kinds:
        marks = ", ".join(["?"] * len(kinds))
        sidecar_sql += f" WHERE kind IN ({marks})"
        sidecar_params.extend(list(kinds))

    for dataset in datasets.values():
        if not features_are_sidecar(dataset):
            continue
        try:
            with sidecar_connection(dataset) as connection:
                remember(connection.execute(sidecar_sql, sidecar_params).fetchone())
        except MissingOsmSidecar:
            continue

    if not extents:
        return nothing
    min_lons, min_lats, max_lons, max_lats = zip(*extents)
    return (min(min_lons), min(min_lats), max(max_lons), max(max_lats))
|
||||
|
||||
|
||||
def query_osm_features(
    session: Session,
    dataset_ids: Sequence[int],
    *,
    kinds: Sequence[str] | None = None,
    modes: Sequence[str] | None = None,
    bbox: tuple[float, float, float, float] | None = None,
    geometry_required: bool | None = None,
    search: str | None = None,
    route_key: str | None = None,
    route_scopes: Sequence[str] | None = None,
    ref: str | None = None,
    osm_type: str | None = None,
    osm_id: str | None = None,
    limit: int | None = None,
    offset: int | None = None,
    prefer_materialized_ids: bool = True,
) -> list[OsmFeature]:
    """Query OSM features across main-database and sidecar-backed datasets.

    The same filters are applied to both storage kinds. Results are merged,
    sorted by ``(kind, mode, ref, name, id)`` and finally cut to ``limit``
    (at least one row is kept when a limit is given).

    NOTE(review): ``limit`` and ``offset`` are applied to each source query
    separately before the merge, so paging across several sources may not be
    contiguous — confirm this is intended.
    """
    if not dataset_ids:
        return []

    wanted_ids = [int(value) for value in dataset_ids]
    found = session.scalars(select(Dataset).where(Dataset.id.in_(wanted_ids))).all()
    datasets = {dataset.id: dataset for dataset in found}

    materialized_ids = _materialized_ids_by_identity(session, list(datasets)) if prefer_materialized_ids else {}

    # Filters shared by the main-database query and every sidecar query.
    filters = dict(
        kinds=kinds,
        modes=modes,
        bbox=bbox,
        geometry_required=geometry_required,
        search=search,
        route_key=route_key,
        route_scopes=route_scopes,
        ref=ref,
        osm_type=osm_type,
        osm_id=osm_id,
    )

    results: list[OsmFeature] = []

    main_ids = [dataset_id for dataset_id, dataset in datasets.items() if not features_are_sidecar(dataset)]
    if main_ids:
        stmt = _apply_main_filters(
            select(OsmFeature).where(OsmFeature.dataset_id.in_(main_ids)),
            **filters,
        )
        if offset:
            stmt = stmt.offset(max(0, int(offset)))
        stmt = stmt.order_by(
            OsmFeature.kind,
            OsmFeature.mode,
            OsmFeature.ref,
            OsmFeature.name,
            OsmFeature.id,
        ).limit(limit)
        results.extend(session.scalars(stmt).all())

    for dataset in datasets.values():
        if not features_are_sidecar(dataset):
            continue
        results.extend(
            _query_sidecar_features(
                dataset,
                **filters,
                limit=limit,
                offset=offset,
                materialized_ids=materialized_ids,
            )
        )

    def sort_key(row: OsmFeature):
        return (row.kind or "", row.mode or "", row.ref or "", row.name or "", int(row.id or 0))

    results.sort(key=sort_key)
    if limit is None:
        return results
    return results[: max(1, int(limit))]
|
||||
|
||||
|
||||
def get_osm_feature(session: Session, feature_id: int) -> OsmFeature | None:
    """Fetch a main-database OSM feature by primary key, or ``None``."""
    feature = session.get(OsmFeature, feature_id)
    return feature
|
||||
|
||||
|
||||
def osm_feature_identity_key(feature: OsmFeature) -> str:
    """Return the ``dataset_id|osm_type|osm_id`` identity string of *feature*."""
    parts = (feature.dataset_id, feature.osm_type, feature.osm_id)
    return "|".join(f"{part}" for part in parts)
|
||||
|
||||
|
||||
def osm_feature_public_id(feature: OsmFeature) -> int | str | None:
    """Return the identifier to expose for *feature*.

    Features flagged with ``_osm_sidecar_source`` get their identity key;
    every other feature gets its ``id``.
    """
    from_sidecar = getattr(feature, "_osm_sidecar_source", False)
    return osm_feature_identity_key(feature) if from_sidecar else feature.id
|
||||
|
||||
|
||||
def resolve_osm_feature(session: Session, value: int | str) -> OsmFeature | None:
    """Resolve a public feature id (integer id or identity key) to a feature.

    Lookup order: main-database primary key, then main-database identity
    match, then the dataset's sidecar file. Returns ``None`` when nothing
    matches or the sidecar is missing.
    """
    numeric_id = _safe_int(value)
    if numeric_id is not None:
        by_primary_key = session.get(OsmFeature, numeric_id)
        if by_primary_key is not None:
            return by_primary_key

    identity = parse_osm_feature_identity_key(str(value))
    if identity is None:
        return None
    dataset_id, osm_type, osm_id = identity

    in_main = session.scalar(
        select(OsmFeature).where(
            OsmFeature.dataset_id == dataset_id,
            OsmFeature.osm_type == osm_type,
            OsmFeature.osm_id == osm_id,
        )
    )
    if in_main is not None:
        return in_main

    dataset = session.get(Dataset, dataset_id)
    if not features_are_sidecar(dataset):
        return None

    try:
        with sidecar_connection(dataset) as connection:
            # Select only the columns this particular sidecar provides.
            select_columns = ", ".join(_sidecar_select_columns(_sidecar_columns(connection)))
            lookup_sql = f"""
                SELECT id, {select_columns}
                FROM osm_features
                WHERE dataset_id = ?
                AND osm_type = ?
                AND osm_id = ?
            """
            row = connection.execute(lookup_sql, (dataset_id, osm_type, osm_id)).fetchone()
    except MissingOsmSidecar:
        return None

    return None if row is None else _feature_from_row(row, {})
|
||||
|
||||
|
||||
def parse_osm_feature_identity_key(value: str) -> tuple[int, str, str] | None:
    """Split a ``dataset_id|osm_type|osm_id`` key into its three components.

    Returns ``None`` when there are not three ``|``-separated fields, the
    dataset id is not an integer, or the type or id is blank. The string is
    split at most twice, so any further ``|`` stays inside ``osm_id``.
    """
    fields = value.split("|", 2)
    if len(fields) != 3:
        return None
    raw_dataset_id, raw_type, raw_id = fields
    dataset_id = _safe_int(raw_dataset_id)
    if dataset_id is None:
        return None
    osm_type, osm_id = raw_type.strip(), raw_id.strip()
    if osm_type and osm_id:
        return dataset_id, osm_type, osm_id
    return None
|
||||
|
||||
|
||||
def ensure_main_osm_feature(session: Session, feature: OsmFeature) -> OsmFeature:
    """Return the main-database row for *feature*, inserting it when absent.

    The insert ignores conflicts on ``(dataset_id, osm_type, osm_id)``; the
    row is then read back after a flush and a PostGIS geometry refresh.

    Raises:
        RuntimeError: if the row still cannot be found after the insert.
    """
    identity_query = select(OsmFeature).where(
        OsmFeature.dataset_id == feature.dataset_id,
        OsmFeature.osm_type == feature.osm_type,
        OsmFeature.osm_id == feature.osm_id,
    )

    already_stored = session.scalar(identity_query)
    if already_stored is not None:
        return already_stored

    copied_fields = (
        "dataset_id",
        "osm_type",
        "osm_id",
        "kind",
        "mode",
        "route_scope",
        "name",
        "ref",
        "operator",
        "network",
        "geometry_geojson",
        "min_lon",
        "min_lat",
        "max_lon",
        "max_lat",
        "tags_json",
        "route_key",
        "operator_key",
    )
    values = {field: getattr(feature, field) for field in copied_fields}

    # Each dialect has its own spelling of "insert unless it already exists".
    if settings.is_postgresql_database:
        insert_stmt = (
            postgresql_insert(OsmFeature)
            .values(**values)
            .on_conflict_do_nothing(index_elements=["dataset_id", "osm_type", "osm_id"])
        )
    else:
        insert_stmt = insert(OsmFeature).values(**values).prefix_with("OR IGNORE")
    session.execute(insert_stmt)
    session.flush()
    refresh_postgis_geometries(session, dataset_id=feature.dataset_id, tables=["osm_features"])

    stored = session.scalar(identity_query)
    if stored is None:
        raise RuntimeError(f"Could not materialize OSM feature {feature.dataset_id}:{feature.osm_type}:{feature.osm_id}")
    return stored
|
||||
|
||||
|
||||
def materialize_osm_features(session: Session, features: Sequence[OsmFeature]) -> list[OsmFeature]:
    """Ensure each feature has a main-database row; return those rows in order."""
    materialized: list[OsmFeature] = []
    for feature in features:
        materialized.append(ensure_main_osm_feature(session, feature))
    return materialized
|
||||
|
||||
|
||||
def _new_sidecar_path(dataset: Dataset, source_hash: str | None) -> Path:
    """Compute where a new sidecar file for *dataset* should be written.

    The file name ends in the first 12 characters of the source hash, falling
    back to the dataset's sha256 and then to its id.
    """
    tag = (source_hash or dataset.sha256 or str(dataset.id))[:12]
    source_dir = settings.data_dir / "sidecars" / f"source_{dataset.source_id}"
    return source_dir / f"osm_dataset_{dataset.id}_{tag}.sqlite"
|
||||
|
||||
|
||||
def dedupe_osm_feature_rows(rows: Sequence[dict[str, object]]) -> tuple[list[dict[str, object]], int]:
    """Keep one row per ``(dataset_id, osm_type, osm_id)`` identity.

    When several rows share an identity, the one with the smallest
    ``_feature_row_preference`` wins; ties keep the earlier row. Returns the
    surviving rows (as copies) and the number of rows dropped.
    """
    best: dict[tuple[int, str, str], dict[str, object]] = {}
    for row in rows:
        identity = (int(row["dataset_id"]), str(row["osm_type"]), str(row["osm_id"]))
        incumbent = best.get(identity)
        if incumbent is not None and _feature_row_preference(row) >= _feature_row_preference(incumbent):
            continue
        best[identity] = dict(row)
    dropped = max(0, len(rows) - len(best))
    return list(best.values()), dropped
|
||||
|
||||
|
||||
def _feature_row_preference(row: dict[str, object]) -> tuple[int, int, int]:
    """Return a sort key for duplicate resolution; smaller is preferred.

    Ordering: by kind rank first, then rows with geometry before rows
    without, then longer geometry text before shorter.
    """
    rank_by_kind = {
        "route": 0,
        "station": 1,
        "terminal": 2,
        "stop": 3,
        "infra": 4,
        "feature": 5,
    }
    kind = str(row.get("kind") or "feature")
    geometry = row.get("geometry_geojson")
    missing_geometry = 1 if not geometry else 0
    # Negated length so that a larger geometry sorts first.
    return (rank_by_kind.get(kind, 6), missing_geometry, -len(str(geometry or "")))
|
||||
|
||||
|
||||
def _create_schema(connection: sqlite3.Connection) -> None:
    """Create the ``osm_features`` table in a new sidecar database.

    The statement has no ``IF NOT EXISTS``, so the table must not exist yet.
    """
    ddl = """
        CREATE TABLE osm_features (
            id INTEGER PRIMARY KEY,
            dataset_id INTEGER NOT NULL,
            osm_type TEXT NOT NULL,
            osm_id TEXT NOT NULL,
            kind TEXT NOT NULL,
            mode TEXT,
            route_scope TEXT,
            name TEXT,
            ref TEXT,
            operator TEXT,
            network TEXT,
            geometry_geojson TEXT,
            min_lon REAL,
            min_lat REAL,
            max_lon REAL,
            max_lat REAL,
            tags_json TEXT,
            route_key TEXT,
            operator_key TEXT,
            UNIQUE(dataset_id, osm_type, osm_id)
        )
    """
    connection.execute(ddl)
|
||||
|
||||
|
||||
def _create_indexes(connection: sqlite3.Connection) -> None:
    """Create the sidecar lookup indexes that do not exist yet.

    Every statement uses ``IF NOT EXISTS``, so repeated calls are harmless.
    No commit is issued here.
    """
    index_columns = (
        ("ix_osm_sidecar_kind_mode_bbox", "kind, mode, min_lon, max_lon, min_lat, max_lat"),
        ("ix_osm_sidecar_scope_bbox", "kind, mode, route_scope, min_lon, max_lon, min_lat, max_lat"),
        ("ix_osm_sidecar_route_key", "route_key"),
        ("ix_osm_sidecar_ref", "ref"),
        ("ix_osm_sidecar_identity", "dataset_id, osm_type, osm_id"),
        ("ix_osm_sidecar_kind_ref_mode", "kind, ref, mode"),
    )
    for index_name, columns in index_columns:
        connection.execute(f"CREATE INDEX IF NOT EXISTS {index_name} ON osm_features ({columns})")
|
||||
|
||||
|
||||
def _apply_main_filters(stmt, *, kinds, modes, bbox, geometry_required, search, route_key, route_scopes, ref, osm_type, osm_id):
    """Attach the requested feature filters to a main-database ``select``.

    Falsy filters are skipped. ``geometry_required`` is tri-state: ``True``
    keeps only rows with geometry, ``False`` only rows without, ``None``
    applies no constraint. On PostgreSQL the bbox and search filters use raw
    SQL; otherwise ORM comparisons are used.
    """
    if kinds:
        stmt = stmt.where(OsmFeature.kind.in_(list(kinds)))
    if modes:
        stmt = stmt.where(OsmFeature.mode.in_(list(modes)))
    if route_scopes:
        scope_names = [str(scope) for scope in route_scopes]
        stmt = stmt.where(_main_route_scope_condition(scope_names))

    if bbox:
        min_lon, min_lat, max_lon, max_lat = bbox
        if settings.is_postgresql_database:
            # Use the geometry column when present; rows without one fall
            # back to the stored bounding-box columns.
            bbox_sql = """
                (
                    osm_features.geom && ST_MakeEnvelope(:bbox_min_lon, :bbox_min_lat, :bbox_max_lon, :bbox_max_lat, 4326)
                    OR (
                        osm_features.geom IS NULL
                        AND osm_features.min_lon <= :bbox_max_lon
                        AND osm_features.max_lon >= :bbox_min_lon
                        AND osm_features.min_lat <= :bbox_max_lat
                        AND osm_features.max_lat >= :bbox_min_lat
                    )
                )
            """
            stmt = stmt.where(text(bbox_sql)).params(
                bbox_min_lon=min_lon,
                bbox_min_lat=min_lat,
                bbox_max_lon=max_lon,
                bbox_max_lat=max_lat,
            )
        else:
            # Two boxes overlap when each starts before the other ends.
            stmt = stmt.where(
                OsmFeature.min_lon <= max_lon,
                OsmFeature.max_lon >= min_lon,
                OsmFeature.min_lat <= max_lat,
                OsmFeature.max_lat >= min_lat,
            )

    if geometry_required is True:
        stmt = stmt.where(OsmFeature.geometry_geojson.is_not(None))
    elif geometry_required is False:
        stmt = stmt.where(OsmFeature.geometry_geojson.is_(None))

    if search:
        if settings.is_postgresql_database:
            search_sql = """
                (
                    LOWER(COALESCE(osm_features.ref, '')) LIKE :search_pattern
                    OR LOWER(COALESCE(osm_features.name, '')) LIKE :search_pattern
                    OR LOWER(COALESCE(osm_features.tags_json, '')) LIKE :search_pattern
                )
            """
            stmt = stmt.where(text(search_sql)).params(search_pattern=f"%{search.lower()}%")
        else:
            needle = f"%{search}%"
            matches_ref = OsmFeature.ref.ilike(needle)
            matches_name = OsmFeature.name.ilike(needle)
            matches_tags = OsmFeature.tags_json.ilike(needle)
            stmt = stmt.where(matches_ref | matches_name | matches_tags)

    # Plain equality filters, applied in a fixed order.
    equality_filters = (
        (OsmFeature.route_key, route_key),
        (OsmFeature.ref, ref),
        (OsmFeature.osm_type, osm_type),
        (OsmFeature.osm_id, osm_id),
    )
    for column, wanted in equality_filters:
        if wanted:
            stmt = stmt.where(column == wanted)
    return stmt
|
||||
|
||||
|
||||
def _main_route_scope_condition(route_scopes: list[str]):
    """Build the SQL condition selecting features in any of *route_scopes*.

    A row matches on its stored ``route_scope`` or on the heuristic fallback
    condition. When ``"local"`` is requested, bus and trolleybus rows that
    the fallback classifies as long-distance or regional are excluded from
    the stored-value match.
    """
    stored_match = OsmFeature.route_scope.in_(route_scopes)
    if "local" in route_scopes:
        reclassified_bus = and_(
            OsmFeature.mode.in_(["bus", "trolleybus"]),
            _main_route_scope_fallback_condition(["long_distance", "regional"]),
        )
        stored_match = and_(stored_match, not_(reclassified_bus))
    heuristic_match = _main_route_scope_fallback_condition(route_scopes)
    return or_(stored_match, heuristic_match)
|
||||
|
||||
|
||||
def _main_route_scope_fallback_condition(route_scopes: list[str]):
    """Build a heuristic route-scope condition over ``OsmFeature`` columns.

    Scopes are derived from ``mode`` plus LIKE patterns on the upper-cased
    ``ref``/``name``/``network`` and the lower-cased ``tags_json``. The
    conditions for every scope named in ``route_scopes`` are OR-ed together;
    if none of the known scope names is requested the condition is
    ``route_scope IS NULL``.
    """
    ref = func.upper(func.coalesce(OsmFeature.ref, ""))
    name = func.upper(func.coalesce(OsmFeature.name, ""))
    network = func.upper(func.coalesce(OsmFeature.network, ""))
    tags = func.lower(func.coalesce(OsmFeature.tags_json, ""))

    def starts_with(column, prefixes):
        return [column.like(f"{prefix}%") for prefix in prefixes]

    def contains(column, fragments):
        return [column.like(f"%{fragment}%") for fragment in fragments]

    def has_tag(pairs):
        return [tags.like(f'%"{key}":"{value}"%') for key, value in pairs]

    def is_train():
        return OsmFeature.mode == "train"

    def is_bus():
        return OsmFeature.mode.in_(["bus", "trolleybus"])

    train_long_distance = and_(
        is_train(),
        or_(
            *starts_with(ref, ("ICE", "IC", "EC", "ECE", "EN", "NJ", "RJ", "RJX", "TGV", "THA", "FLX")),
            *contains(name, ("INTERCITY", "EUROCITY", "NIGHTJET", "FLIXTRAIN")),
            *has_tag(
                (
                    ("service", "long_distance"),
                    ("train", "long_distance"),
                    ("train", "high_speed"),
                    ("train", "intercity"),
                )
            ),
        ),
    )
    bus_long_distance = and_(
        is_bus(),
        or_(
            name.like("%FLIXBUS%"),
            network.like("%FLIXBUS%"),
            name.like("%EUROLINES%"),
            network.like("%EUROLINES%"),
            *contains(name, ("INTERCITYBUS", "IC BUS", "FERNBUS")),
            *has_tag(
                (
                    ("service", "long_distance"),
                    ("bus", "long_distance"),
                    ("bus", "intercity"),
                    ("network:type", "long_distance"),
                )
            ),
        ),
    )
    long_distance = or_(OsmFeature.mode == "coach", train_long_distance, bus_long_distance)

    bus_regional = and_(
        is_bus(),
        not_(bus_long_distance),
        or_(
            *contains(name, ("REGIONALBUS", "REGIOBUS", "REGIONAL BUS", "REGIONALVERKEHR")),
            *contains(network, ("REGIONALBUS", "REGIOBUS", "REGIONALVERKEHR")),
            *has_tag(
                (
                    ("service", "regional"),
                    ("bus", "regional"),
                    ("bus", "interurban"),
                    ("network:type", "regional"),
                )
            ),
        ),
    )
    local = or_(
        OsmFeature.mode.in_(["tram", "light_rail", "subway", "ferry", "funicular", "aerialway", "monorail"]),
        and_(is_bus(), not_(or_(bus_long_distance, bus_regional))),
        and_(
            is_train(),
            or_(
                *starts_with(ref, ("S",)),
                *contains(name, ("S-BAHN",)),
                *contains(network, ("S-BAHN",)),
                *has_tag((("train", "commuter"),)),
            ),
        ),
    )
    train_regional = and_(
        is_train(),
        not_(train_long_distance),
        or_(
            *starts_with(ref, ("IRE", "RE", "RB", "RER", "TER", "REX", "MEX", "ALX", "WFB", "R")),
            *contains(name, ("REGIONAL", "REGIO")),
            *has_tag((("service", "regional"), ("train", "regional"))),
        ),
    )
    regional = or_(train_regional, bus_regional)

    # Insertion order fixes the order of the OR-ed clauses.
    by_scope = {
        "long_distance": long_distance,
        "regional": regional,
        "local": local,
        # "unknown" only covers trains that none of the other scopes claim.
        "unknown": and_(is_train(), not_(or_(long_distance, regional, local))),
    }
    selected = [condition for scope, condition in by_scope.items() if scope in route_scopes]
    if not selected:
        return OsmFeature.route_scope.is_(None)
    return or_(*selected)
|
||||
|
||||
|
||||
def _query_sidecar_features(
    dataset: Dataset,
    *,
    kinds: Sequence[str] | None,
    modes: Sequence[str] | None,
    bbox: tuple[float, float, float, float] | None,
    geometry_required: bool | None,
    search: str | None,
    route_key: str | None,
    route_scopes: Sequence[str] | None,
    ref: str | None,
    osm_type: str | None,
    osm_id: str | None,
    limit: int | None,
    offset: int | None,
    materialized_ids: dict[tuple[int, str, str], int],
) -> list[OsmFeature]:
    """Query the dataset's sidecar ``osm_features`` table with optional filters.

    Every filter that is set adds one AND-ed clause with positional ``?``
    parameters. Rows are ordered by ``kind, mode, ref, name, id`` and turned
    into ``OsmFeature`` objects via ``_feature_from_row``. Returns an empty
    list when ``MissingOsmSidecar`` is raised.
    """
    clauses: list[str] = []
    bound: list[object] = []

    def require(clause: str, *values: object) -> None:
        # Keep clause order and parameter order in lock-step.
        clauses.append(clause)
        bound.extend(values)

    try:
        with sidecar_connection(dataset) as connection:
            available_columns = _sidecar_columns(connection)
            if kinds:
                marks = ", ".join("?" for _ in kinds)
                require(f"kind IN ({marks})", *kinds)
            if modes:
                marks = ", ".join("?" for _ in modes)
                require(f"mode IN ({marks})", *modes)
            if bbox:
                min_lon, min_lat, max_lon, max_lat = bbox
                # Bounding boxes overlap when each side starts before the other ends.
                require("min_lon <= ?", max_lon)
                require("max_lon >= ?", min_lon)
                require("min_lat <= ?", max_lat)
                require("max_lat >= ?", min_lat)
            if geometry_required is True:
                require("geometry_geojson IS NOT NULL")
            elif geometry_required is False:
                require("geometry_geojson IS NULL")
            if search:
                needle = f"%{search.lower()}%"
                require(
                    "(LOWER(COALESCE(ref, '')) LIKE ? OR LOWER(COALESCE(name, '')) LIKE ? OR LOWER(COALESCE(tags_json, '')) LIKE ?)",
                    needle,
                    needle,
                    needle,
                )
            if route_key:
                require("route_key = ?", route_key)
            if route_scopes:
                scope_sql, scope_params = _sidecar_route_scope_condition(
                    [str(scope) for scope in route_scopes],
                    has_route_scope="route_scope" in available_columns,
                )
                require(scope_sql, *scope_params)
            if ref:
                require("ref = ?", ref)
            if osm_type:
                require("osm_type = ?", osm_type)
            if osm_id:
                require("osm_id = ?", osm_id)

            column_list = ", ".join(_sidecar_select_columns(available_columns))
            pieces = [f"SELECT id, {column_list} FROM osm_features"]
            if clauses:
                pieces.append("WHERE " + " AND ".join(clauses))
            pieces.append("ORDER BY kind, mode, ref, name, id")
            if limit is not None:
                pieces.append("LIMIT ?")
                bound.append(max(1, int(limit)))
            if offset:
                if limit is None:
                    # OFFSET is only emitted after a LIMIT; -1 is the unbounded limit.
                    pieces.append("LIMIT -1")
                pieces.append("OFFSET ?")
                bound.append(max(0, int(offset)))

            fetched = connection.execute(" ".join(pieces), bound).fetchall()
            return [_feature_from_row(record, materialized_ids) for record in fetched]
    except MissingOsmSidecar:
        return []
|
||||
|
||||
|
||||
def _sidecar_columns(connection: sqlite3.Connection) -> set[str]:
    """Return the column names of the sidecar ``osm_features`` table.

    Rows are read by key (``row["name"]``), so the connection must yield
    mapping-style rows.
    """
    table_info = connection.execute("PRAGMA table_info(osm_features)").fetchall()
    names: set[str] = set()
    for column_info in table_info:
        names.add(str(column_info["name"]))
    return names
|
||||
|
||||
|
||||
def _sidecar_select_columns(available_columns: set[str]) -> list[str]:
    """Return one SELECT expression per entry of ``OSM_FEATURE_COLUMNS``.

    Columns missing from ``available_columns`` are selected as
    ``NULL AS <column>`` so result rows always expose every expected key.
    """
    expressions: list[str] = []
    for column in OSM_FEATURE_COLUMNS:
        if column in available_columns:
            expressions.append(column)
        else:
            expressions.append(f"NULL AS {column}")
    return expressions
|
||||
|
||||
|
||||
def _sidecar_route_scope_condition(route_scopes: list[str], *, has_route_scope: bool) -> tuple[str, list[object]]:
    """Return the sidecar SQL fragment and parameters for a route-scope filter.

    Without a ``route_scope`` column only the heuristic fallback is used.
    With it, a stored-scope match is OR-ed with the fallback; when "local" is
    requested, stored matches are rejected for bus/trolleybus rows that the
    heuristics classify as long-distance or regional.
    """
    heuristic_sql, heuristic_params = _sidecar_route_scope_fallback_condition(route_scopes)
    if not has_route_scope:
        return heuristic_sql, heuristic_params

    marks = ", ".join("?" for _ in route_scopes)
    stored_sql = f"route_scope IN ({marks})"
    bound: list[object] = list(route_scopes)
    if "local" in route_scopes:
        non_local_sql, non_local_params = _sidecar_route_scope_fallback_condition(["long_distance", "regional"])
        stored_sql = f"({stored_sql} AND NOT (mode IN ('bus', 'trolleybus') AND {non_local_sql}))"
        bound.extend(non_local_params)
    return f"({stored_sql} OR {heuristic_sql})", bound + list(heuristic_params)
|
||||
|
||||
|
||||
def _sidecar_route_scope_fallback_condition(route_scopes: list[str]) -> tuple[str, list[object]]:
    """Build the heuristic route-scope filter as raw SQL for the sidecar table.

    Scopes are derived from ``mode`` plus LIKE patterns on ``ref``, ``name``,
    ``network`` and ``tags_json``. The fragments for every scope named in
    ``route_scopes`` are OR-ed together; if none of the known scope names is
    requested the result is ``(0)``, which matches no row.

    Returns:
        ``(sql, params)``. All patterns are inlined literals, so ``params`` is
        always an empty list.
    """
    # Trains recognised as long-distance by ref prefix, name or tags.
    train_long_distance = """(
        mode = 'train'
        AND (
            UPPER(COALESCE(ref, '')) LIKE 'ICE%'
            OR UPPER(COALESCE(ref, '')) LIKE 'IC%'
            OR UPPER(COALESCE(ref, '')) LIKE 'EC%'
            OR UPPER(COALESCE(ref, '')) LIKE 'ECE%'
            OR UPPER(COALESCE(ref, '')) LIKE 'EN%'
            OR UPPER(COALESCE(ref, '')) LIKE 'NJ%'
            OR UPPER(COALESCE(ref, '')) LIKE 'RJ%'
            OR UPPER(COALESCE(ref, '')) LIKE 'RJX%'
            OR UPPER(COALESCE(ref, '')) LIKE 'TGV%'
            OR UPPER(COALESCE(ref, '')) LIKE 'THA%'
            OR UPPER(COALESCE(ref, '')) LIKE 'FLX%'
            OR UPPER(COALESCE(name, '')) LIKE '%INTERCITY%'
            OR UPPER(COALESCE(name, '')) LIKE '%EUROCITY%'
            OR UPPER(COALESCE(name, '')) LIKE '%NIGHTJET%'
            OR UPPER(COALESCE(name, '')) LIKE '%FLIXTRAIN%'
            OR LOWER(COALESCE(tags_json, '')) LIKE '%"service":"long_distance"%'
            OR LOWER(COALESCE(tags_json, '')) LIKE '%"train":"long_distance"%'
            OR LOWER(COALESCE(tags_json, '')) LIKE '%"train":"high_speed"%'
            OR LOWER(COALESCE(tags_json, '')) LIKE '%"train":"intercity"%'
        )
    )"""
    # Buses/trolleybuses recognised as long-distance by name, network or tags.
    bus_long_distance = """(
        mode IN ('bus', 'trolleybus')
        AND (
            UPPER(COALESCE(name, '')) LIKE '%FLIXBUS%'
            OR UPPER(COALESCE(network, '')) LIKE '%FLIXBUS%'
            OR UPPER(COALESCE(name, '')) LIKE '%EUROLINES%'
            OR UPPER(COALESCE(network, '')) LIKE '%EUROLINES%'
            OR UPPER(COALESCE(name, '')) LIKE '%INTERCITYBUS%'
            OR UPPER(COALESCE(name, '')) LIKE '%IC BUS%'
            OR UPPER(COALESCE(name, '')) LIKE '%FERNBUS%'
            OR LOWER(COALESCE(tags_json, '')) LIKE '%"service":"long_distance"%'
            OR LOWER(COALESCE(tags_json, '')) LIKE '%"bus":"long_distance"%'
            OR LOWER(COALESCE(tags_json, '')) LIKE '%"bus":"intercity"%'
            OR LOWER(COALESCE(tags_json, '')) LIKE '%"network:type":"long_distance"%'
        )
    )"""
    # Every coach counts as long-distance regardless of name or tags.
    long_distance = f"(mode = 'coach' OR {train_long_distance} OR {bus_long_distance})"
    # Regional buses: long-distance buses are excluded first.
    bus_regional = f"""(
        mode IN ('bus', 'trolleybus')
        AND NOT {bus_long_distance}
        AND (
            UPPER(COALESCE(name, '')) LIKE '%REGIONALBUS%'
            OR UPPER(COALESCE(name, '')) LIKE '%REGIOBUS%'
            OR UPPER(COALESCE(name, '')) LIKE '%REGIONAL BUS%'
            OR UPPER(COALESCE(name, '')) LIKE '%REGIONALVERKEHR%'
            OR UPPER(COALESCE(network, '')) LIKE '%REGIONALBUS%'
            OR UPPER(COALESCE(network, '')) LIKE '%REGIOBUS%'
            OR UPPER(COALESCE(network, '')) LIKE '%REGIONALVERKEHR%'
            OR LOWER(COALESCE(tags_json, '')) LIKE '%"service":"regional"%'
            OR LOWER(COALESCE(tags_json, '')) LIKE '%"bus":"regional"%'
            OR LOWER(COALESCE(tags_json, '')) LIKE '%"bus":"interurban"%'
            OR LOWER(COALESCE(tags_json, '')) LIKE '%"network:type":"regional"%'
        )
    )"""
    # Regional trains: long-distance trains are excluded first, which matters
    # because the broad 'R%' prefix would otherwise also catch 'RJ'/'RJX' refs.
    train_regional = f"""(
        mode = 'train'
        AND NOT {train_long_distance}
        AND (
            UPPER(COALESCE(ref, '')) LIKE 'IRE%'
            OR UPPER(COALESCE(ref, '')) LIKE 'RE%'
            OR UPPER(COALESCE(ref, '')) LIKE 'RB%'
            OR UPPER(COALESCE(ref, '')) LIKE 'RER%'
            OR UPPER(COALESCE(ref, '')) LIKE 'TER%'
            OR UPPER(COALESCE(ref, '')) LIKE 'REX%'
            OR UPPER(COALESCE(ref, '')) LIKE 'MEX%'
            OR UPPER(COALESCE(ref, '')) LIKE 'ALX%'
            OR UPPER(COALESCE(ref, '')) LIKE 'WFB%'
            OR UPPER(COALESCE(ref, '')) LIKE 'R%'
            OR UPPER(COALESCE(name, '')) LIKE '%REGIONAL%'
            OR UPPER(COALESCE(name, '')) LIKE '%REGIO%'
            OR LOWER(COALESCE(tags_json, '')) LIKE '%"service":"regional"%'
            OR LOWER(COALESCE(tags_json, '')) LIKE '%"train":"regional"%'
        )
    )"""
    regional = f"({train_regional} OR {bus_regional})"
    # Local: a fixed set of modes, buses not claimed by the two scopes above,
    # and trains with an 'S' ref prefix, an S-Bahn name/network or a commuter tag.
    local = f"""(
        mode IN ('tram', 'light_rail', 'subway', 'ferry', 'funicular', 'aerialway', 'monorail')
        OR (mode IN ('bus', 'trolleybus') AND NOT ({bus_long_distance} OR {bus_regional}))
        OR (
            mode = 'train'
            AND (
                UPPER(COALESCE(ref, '')) LIKE 'S%'
                OR UPPER(COALESCE(name, '')) LIKE '%S-BAHN%'
                OR UPPER(COALESCE(network, '')) LIKE '%S-BAHN%'
                OR LOWER(COALESCE(tags_json, '')) LIKE '%"train":"commuter"%'
            )
        )
    )"""
    parts = []
    if "long_distance" in route_scopes:
        parts.append(long_distance)
    if "regional" in route_scopes:
        parts.append(regional)
    if "local" in route_scopes:
        parts.append(local)
    if "unknown" in route_scopes:
        # "unknown" only covers trains that no other scope claims.
        parts.append(f"(mode = 'train' AND NOT ({long_distance} OR {regional} OR {local}))")
    # "0" is the always-false placeholder when no known scope was requested.
    return "(" + " OR ".join(parts or ["0"]) + ")", []
|
||||
|
||||
|
||||
def _feature_from_row(row: sqlite3.Row, materialized_ids: dict[tuple[int, str, str], int]) -> OsmFeature:
    """Build an ``OsmFeature`` object from one sidecar row.

    The id is taken from ``materialized_ids`` for the row's
    ``(dataset_id, osm_type, osm_id)`` identity when present, otherwise the
    sidecar row id is used. The returned object is tagged with
    ``_osm_sidecar_source`` and ``_osm_sidecar_row_id``.
    """
    identity = (int(row["dataset_id"]), str(row["osm_type"]), str(row["osm_id"]))
    sidecar_row_id = int(row["id"])
    copied_columns = (
        "mode",
        "route_scope",
        "name",
        "ref",
        "operator",
        "network",
        "geometry_geojson",
        "min_lon",
        "min_lat",
        "max_lon",
        "max_lat",
        "tags_json",
        "route_key",
        "operator_key",
    )
    feature = OsmFeature(
        id=materialized_ids.get(identity, sidecar_row_id),
        dataset_id=identity[0],
        osm_type=identity[1],
        osm_id=identity[2],
        kind=str(row["kind"]),
        **{column: row[column] for column in copied_columns},
    )
    feature._osm_sidecar_source = True
    feature._osm_sidecar_row_id = sidecar_row_id
    return feature
|
||||
|
||||
|
||||
def _materialized_ids_by_identity(session: Session, dataset_ids: Sequence[int]) -> dict[tuple[int, str, str], int]:
    """Map ``(dataset_id, osm_type, osm_id)`` to the main-table ``OsmFeature.id``.

    Only rows belonging to ``dataset_ids`` are read; an empty ``dataset_ids``
    returns an empty mapping without querying.
    """
    if not dataset_ids:
        return {}
    statement = select(
        OsmFeature.dataset_id,
        OsmFeature.osm_type,
        OsmFeature.osm_id,
        OsmFeature.id,
    ).where(OsmFeature.dataset_id.in_(dataset_ids))
    lookup: dict[tuple[int, str, str], int] = {}
    for dataset_id, osm_type, osm_id, feature_id in session.execute(statement).all():
        lookup[(int(dataset_id), str(osm_type), str(osm_id))] = int(feature_id)
    return lookup
|
||||
|
||||
|
||||
def _as_list(value: str | Sequence[str] | None) -> list[str]:
    """Normalise an optional string or sequence into a list of strings.

    ``None`` becomes ``[]``, a single string becomes a one-element list, and
    any other iterable has each item converted with ``str``.
    """
    if isinstance(value, str):
        return [value]
    return [] if value is None else list(map(str, value))
|
||||
|
||||
|
||||
def _safe_int(value: object) -> int | None:
    """Convert ``value`` to ``int``, returning ``None`` when it cannot be converted.

    Besides unsupported types and unparsable strings, infinite floats are also
    mapped to ``None``: ``int(float("inf"))`` raises ``OverflowError``, which
    would otherwise escape this helper.
    """
    try:
        return int(value)  # type: ignore[arg-type]
    except (TypeError, ValueError, OverflowError):
        return None
|
||||
61
app/performance.py
Normal file
61
app/performance.py
Normal file
@@ -0,0 +1,61 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Iterator
|
||||
|
||||
from app.config import settings
|
||||
|
||||
|
||||
@contextmanager
def measure_pipeline_phase(
    phase: str,
    *,
    source_id: int | None = None,
    dataset_id: int | None = None,
    metadata: dict[str, object] | None = None,
) -> Iterator[dict[str, object]]:
    """Time a pipeline phase and record it as a metric when the block exits.

    Yields a fresh dict seeded from ``metadata``; the caller may add entries
    to it while the phase runs. On exit, including when the body raises, the
    elapsed time rounded to milliseconds is stored under ``duration_seconds``
    and the dict is passed to ``record_pipeline_metric``.
    """
    started = time.perf_counter()
    details: dict[str, object] = {}
    if metadata:
        details.update(metadata)
    try:
        yield details
    finally:
        elapsed = round(time.perf_counter() - started, 3)
        details["duration_seconds"] = elapsed
        record_pipeline_metric(
            phase,
            source_id=source_id,
            dataset_id=dataset_id,
            duration_seconds=elapsed,
            metadata=details,
        )
|
||||
|
||||
|
||||
def record_pipeline_metric(
    phase: str,
    *,
    source_id: int | None = None,
    dataset_id: int | None = None,
    duration_seconds: float | None = None,
    metadata: dict[str, object] | None = None,
) -> None:
    """Append one metric record as a compact JSON line to the metrics file.

    The parent directory is created when missing. Values that JSON cannot
    encode natively are stringified via ``default=str``.
    """
    destination = _metric_path()
    destination.parent.mkdir(parents=True, exist_ok=True)
    record = dict(
        timestamp=datetime.now(timezone.utc).isoformat(),
        phase=phase,
        source_id=source_id,
        dataset_id=dataset_id,
        duration_seconds=duration_seconds,
        metadata=metadata or {},
    )
    with destination.open("a", encoding="utf-8") as stream:
        stream.write(json.dumps(record, separators=(",", ":"), default=str) + "\n")
|
||||
|
||||
|
||||
def _metric_path() -> Path:
    """Return the JSONL metrics file location below ``settings.data_dir``."""
    metrics_dir = settings.data_dir / "metrics"
    return metrics_dir / "pipeline_metrics.jsonl"
|
||||
111
app/pipeline/download.py
Normal file
111
app/pipeline/download.py
Normal file
@@ -0,0 +1,111 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import shutil
|
||||
import time
|
||||
from pathlib import Path
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import requests
|
||||
|
||||
from app.config import settings
|
||||
from app.models import Source
|
||||
from app.pipeline.utils import sha256_file
|
||||
|
||||
|
||||
def materialize_source(source: Source) -> Path:
    """Download/copy a source into the local cache and return the file path.

    Files are stored by content hash per source. Re-running an unchanged source
    reuses the existing cached file instead of creating another timestamped copy.

    HTTP(S) sources are streamed into a ``*.download<suffix>`` temp file that
    is resumed with a ``Range`` request when a partial file is present. Any
    other URL is treated as a local path (``file://`` or plain); a local file
    that already lives inside the source cache directory is returned as-is.

    Raises:
        FileNotFoundError: if a local source path does not exist.
        requests.HTTPError: if the download answers with an error status.
    """
    source_dir = settings.data_dir / "sources" / f"source_{source.id}"
    source_dir.mkdir(parents=True, exist_ok=True)
    suffix = _guess_suffix(source.url, source.kind)

    parsed = urlparse(source.url)
    if parsed.scheme in {"http", "https"}:
        temp_path = _download_temp_path(source_dir, suffix)
        _download_with_resume(source.url, temp_path)
        return _store_or_reuse_cached_file(source_dir=source_dir, source_path=temp_path, suffix=suffix, move=True)

    if parsed.scheme == "file":
        source_path = Path(parsed.path)
    else:
        source_path = Path(source.url)

    if not source_path.exists():
        raise FileNotFoundError(f"Source file does not exist: {source.url}")
    if _is_relative_to(source_path.resolve(), source_dir.resolve()):
        return source_path
    return _store_or_reuse_cached_file(source_dir=source_dir, source_path=source_path, suffix=suffix, move=False)


def _download_with_resume(url: str, temp_path: Path) -> None:
    """Stream ``url`` into ``temp_path``, resuming a partial file when possible."""
    resume_from = temp_path.stat().st_size if temp_path.exists() else 0
    while True:
        headers = {"Range": f"bytes={resume_from}-"} if resume_from > 0 else None
        with requests.get(url, stream=True, timeout=120, headers=headers) as response:
            if resume_from > 0 and response.status_code == 416:
                # The leftover file is already as large as (or larger than) the
                # remote resource, e.g. a finished download from an interrupted
                # run. It cannot be resumed, so drop it and fetch from scratch
                # instead of failing on every subsequent run. The retry sends
                # no Range header, so this branch runs at most once.
                temp_path.unlink(missing_ok=True)
                resume_from = 0
                continue
            response.raise_for_status()
            # Append only when the server honoured the range (206); a plain 200
            # carries the whole body and must overwrite the partial file.
            mode = "ab" if resume_from > 0 and response.status_code == 206 else "wb"
            with temp_path.open(mode) as handle:
                for chunk in response.iter_content(chunk_size=1024 * 1024):
                    if chunk:
                        handle.write(chunk)
            return
|
||||
|
||||
|
||||
def _download_temp_path(source_dir: Path, suffix: str) -> Path:
    """Pick the temp file an HTTP download should write to.

    Returns the most recently modified ``*.download<suffix>`` file in
    ``source_dir`` when one exists, otherwise a new path named after the
    current Unix timestamp.
    """

    def modified_at(candidate: Path) -> float:
        # A file may vanish between globbing and stat-ing; rank it oldest.
        return candidate.stat().st_mtime if candidate.exists() else 0

    partials = list(source_dir.glob(f"*.download{suffix}"))
    if not partials:
        return source_dir / f"{int(time.time())}.download{suffix}"
    return max(partials, key=modified_at)
|
||||
|
||||
|
||||
def _guess_suffix(url: str, kind: str) -> str:
    """Choose the cache file suffix for a source.

    The first known suffix that the lower-cased URL path ends with wins;
    otherwise the suffix falls back on ``kind`` and finally on ``.dat``.
    """
    # Order matters: the first matching entry is returned.
    known_suffixes = (".zip", ".geojson", ".json", ".osm.pbf", ".pbf", ".osm", ".osm.xml", ".osc.gz", ".osc", ".csv")
    suffix_by_kind = {"gtfs": ".zip", "osm_geojson": ".geojson"}

    location = (urlparse(url).path or url).lower()
    matched = next((candidate for candidate in known_suffixes if location.endswith(candidate)), None)
    if matched is not None:
        return matched
    return suffix_by_kind.get(kind, ".dat")
|
||||
|
||||
|
||||
def _store_or_reuse_cached_file(source_dir: Path, source_path: Path, suffix: str, move: bool) -> Path:
    """Place ``source_path`` in the cache under a hash-based name, or reuse a copy.

    The cache name is the first 16 hex characters of the file's SHA-256 plus
    ``suffix``. If that file, or any other cached file with identical content,
    already exists it is returned and, when ``move`` is set, the now redundant
    ``source_path`` is deleted. Otherwise the file is moved (``move=True``) or
    copied (``move=False``) to the cache name.
    """
    digest = sha256_file(source_path)
    canonical = source_dir / f"{digest[:16]}{suffix}"

    if canonical.exists() and sha256_file(canonical) == digest:
        reusable = canonical
    else:
        reusable = _find_existing_cached_file(source_dir, digest, suffix, exclude=source_path)

    if reusable is not None:
        if move and source_path != reusable:
            source_path.unlink(missing_ok=True)
        return reusable

    if not move:
        shutil.copyfile(source_path, canonical)
    else:
        source_path.replace(canonical)
    return canonical
|
||||
|
||||
|
||||
def _find_existing_cached_file(source_dir: Path, source_hash: str, suffix: str, exclude: Path | None = None) -> Path | None:
    """Return the first file in ``source_dir`` whose SHA-256 equals ``source_hash``.

    Only files matching ``*<suffix>`` are considered, in sorted order, and the
    file resolving to ``exclude`` is skipped. Returns ``None`` when no file
    matches.
    """
    excluded = None if exclude is None else exclude.resolve()
    for candidate in sorted(source_dir.glob(f"*{suffix}")):
        if excluded is not None and candidate.resolve() == excluded:
            continue
        if not candidate.is_file():
            continue
        if sha256_file(candidate) == source_hash:
            return candidate
    return None
|
||||
|
||||
|
||||
def _is_relative_to(path: Path, parent: Path) -> bool:
    """Return whether ``path`` is ``parent`` itself or located below it."""
    try:
        path.relative_to(parent)
    except ValueError:
        return False
    else:
        return True
|
||||
1327
app/pipeline/gtfs.py
Normal file
1327
app/pipeline/gtfs.py
Normal file
File diff suppressed because it is too large
Load Diff
995
app/pipeline/matcher.py
Normal file
995
app/pipeline/matcher.py
Normal file
@@ -0,0 +1,995 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
import json
|
||||
from typing import Callable, Optional
|
||||
|
||||
from shapely.geometry import LineString, MultiLineString, Point, shape
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.config import settings
|
||||
from app.models import Dataset, GtfsRoute, MatchRule, OsmFeature, RouteMatch
|
||||
from app.osm_storage import ensure_main_osm_feature, osm_feature_bbox, query_osm_features
|
||||
from app.pipeline.state import STAGE_MATCH_ROUTES, dependency_hash, finish_pipeline_run, start_pipeline_run
|
||||
from app.pipeline.utils import approx_bbox_center_distance_deg, bbox_overlap, norm_ref, norm_text
|
||||
|
||||
# Maps a mode name to the set of mode names grouped with it.
# NOTE(review): presumably the modes treated as compatible when comparing a
# GTFS route with an OSM route — confirm against the scoring code.
MODE_GROUPS = {
    "train": {"train", "rail", "railway"},
    "subway": {"subway", "metro"},
    "tram": {"tram", "light_rail"},
    "light_rail": {"light_rail", "tram"},
    "bus": {"bus", "coach", "trolleybus"},
    "coach": {"coach", "bus"},
    "trolleybus": {"trolleybus", "bus"},
    "ferry": {"ferry"},
    "funicular": {"funicular"},
    "aerialway": {"aerialway", "cable_car"},
    "monorail": {"monorail"},
}
# Candidate caps and spatial tolerances (the *_DEG values are in degrees).
# NOTE(review): their consumers are outside this view — confirm exact use.
MAX_FALLBACK_CANDIDATES_WITH_REF = 40
MAX_FALLBACK_CANDIDATES_WITHOUT_REF = 80
MAX_EXACT_REF_CANDIDATES = 120
OSM_SCOPE_NEAR_DISTANCE_DEG = 0.15
GEOMETRY_PROXIMITY_DEG = 0.0035
GEOMETRY_SAMPLE_POINTS = 24
# Recorded as the pipeline-run version and inside the dependency payload of a
# matching run.
MATCHER_VERSION = "matcher_v4_scope_spatial_manual_rules"
# Progress hook signature; run_route_matching passes
# (event name, message, processed count, total count, details dict).
ProgressCallback = Callable[[str, str, int | None, int | None, dict[str, object] | None], None]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class _ManualMatchRule:
    """Immutable in-memory form of a manual match rule used while matching.

    NOTE(review): instances are presumably built from ``MatchRule`` rows by
    ``_manual_match_rules`` — confirm there.
    """

    # Rule identifier; written into match reasons as "manual_rule_id".
    id: int
    rule_type: str
    # Selector describing the GTFS route the rule applies to.
    route_selector: dict[str, object]
    # Selector describing the OSM feature, when the rule names one.
    osm_selector: dict[str, object] | None
    status: str
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class _OsmRouteIndex:
    """Immutable bundle of OSM route features with lookup tables.

    NOTE(review): the key normalisation used for the three mappings is not
    visible here — confirm where the index is built.
    """

    all_routes: list[OsmFeature]
    # Features grouped by ref, by route key and by mode respectively.
    by_ref: dict[str, list[OsmFeature]]
    by_route_key: dict[str, list[OsmFeature]]
    by_mode: dict[str, list[OsmFeature]]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class _GeometryProfile:
    """Immutable bundle of geometry data derived from a route geometry.

    NOTE(review): presumably produced by ``_geometry_profile`` from a GeoJSON
    string and reused while scoring — confirm there.
    """

    # The geometry object itself (typed loosely as ``object``).
    geom: object
    # Line strings making up the geometry.
    lines: list[LineString]
    length: float
    # Points sampled from the geometry.
    sample_points: list[Point]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class _RouteMatchPayload:
    """Desired ``RouteMatch`` state for one GTFS route.

    Payloads are produced per route by ``_match_route_batch`` and written to
    the database by ``_apply_route_match_payloads``.
    """

    gtfs_route_id: int
    # None when no OSM feature was matched.
    osm_feature_id: int | None
    confidence: float
    # The visible code produces "accepted" and "missing" directly; other
    # values come from _status_from_score.
    status: str
    # "manual" for rule-based matches, "auto" for scored ones.
    rule_source: str
    # Compact JSON explaining the decision.
    reasons_json: str | None
|
||||
|
||||
|
||||
def run_route_matching(
    session: Session,
    *,
    progress_callback: ProgressCallback | None = None,
    batch_size: int | None = None,
) -> dict[str, object]:
    """Match active GTFS routes against active OSM route features.

    Routes of all active GTFS datasets are processed in batches; each batch
    is committed and reported through ``progress_callback``. Returns the
    accumulated counters plus the per-scope counts under ``"scope"``. When
    there are no active datasets, no GTFS datasets or no routes, a zeroed
    summary is returned without starting a pipeline run.
    """
    active_datasets = session.execute(
        select(Dataset.id, Dataset.kind, Dataset.source_id).where(Dataset.is_active.is_(True))
    ).all()
    if not active_datasets:
        return {"routes": 0, "matches": 0, "missing": 0}

    dataset_source_ids: dict[int, int] = {}
    gtfs_dataset_ids: list[int] = []
    osm_dataset_ids: list[int] = []
    for dataset_id, kind, source_id in active_datasets:
        dataset_source_ids[int(dataset_id)] = int(source_id)
        if kind == "gtfs":
            gtfs_dataset_ids.append(int(dataset_id))
        elif kind == "osm_geojson":
            osm_dataset_ids.append(int(dataset_id))
    if not gtfs_dataset_ids:
        return {"routes": 0, "matches": 0, "missing": 0}

    route_id_query = (
        select(GtfsRoute.id)
        .where(GtfsRoute.dataset_id.in_(gtfs_dataset_ids))
        .order_by(GtfsRoute.dataset_id, GtfsRoute.route_id, GtfsRoute.id)
    )
    route_row_ids = session.scalars(route_id_query).all()
    total_routes = len(route_row_ids)
    if total_routes == 0:
        return {"routes": 0, "matches": 0, "missing": 0}

    # Reconcile current match rows from auto scoring plus durable manual rules.
    dependency = _route_matching_dependency(session, active_datasets)
    run = start_pipeline_run(
        session,
        stage=STAGE_MATCH_ROUTES,
        version=MATCHER_VERSION,
        dependency_hash_value=dependency_hash(dependency),
        inputs=dependency,
    )
    session.commit()

    effective_batch_size = max(1, int(batch_size or settings.route_matching_batch_size))
    _emit_progress(
        progress_callback,
        "route_matching_started",
        f"Matching {total_routes} GTFS routes in batches of {effective_batch_size}.",
        0,
        total_routes,
        {"gtfs_datasets": gtfs_dataset_ids, "osm_datasets": osm_dataset_ids, "batch_size": effective_batch_size},
    )

    manual_rules = _manual_match_rules(session)
    osm_scope_bbox = osm_feature_bbox(session, osm_dataset_ids, kinds=["route"])
    tallied = ("matches", "missing", "manual", "created", "updated", "unchanged")
    counts = {"routes": total_routes, **dict.fromkeys(tallied, 0)}
    scoped_counts = {"in_osm_scope": 0, "near_osm_scope": 0, "outside_osm_scope": 0, "unknown_scope": 0}
    processed = 0
    for chunk in _chunks_int(route_row_ids, effective_batch_size):
        batch_query = (
            select(GtfsRoute)
            .where(GtfsRoute.id.in_(chunk))
            .order_by(GtfsRoute.dataset_id, GtfsRoute.route_id, GtfsRoute.id)
        )
        routes = session.scalars(batch_query).all()
        batch_counts = _match_route_batch(
            session=session,
            routes=routes,
            osm_dataset_ids=osm_dataset_ids,
            dataset_source_ids=dataset_source_ids,
            manual_rules=manual_rules,
            osm_scope_bbox=osm_scope_bbox,
            scoped_counts=scoped_counts,
        )
        for key in tallied:
            counts[key] += batch_counts[key]
        processed += len(routes)
        # Persist each batch before reporting it.
        session.commit()
        _emit_progress(
            progress_callback,
            "route_matching_batch",
            f"Matched {processed}/{total_routes} GTFS routes.",
            processed,
            total_routes,
            {
                "processed": processed,
                **{key: counts[key] for key in tallied},
                "scope": dict(scoped_counts),
            },
        )

    result = {**counts, "scope": scoped_counts}
    finish_pipeline_run(session, run, outputs=result)
    session.commit()
    _emit_progress(
        progress_callback,
        "route_matching_completed",
        "Route matching completed.",
        total_routes,
        total_routes,
        result,
    )
    return result
|
||||
|
||||
|
||||
def _route_matching_dependency(session: Session, active_datasets) -> dict[str, object]:
    """Describe the inputs a route-matching run depends on.

    The result holds the matcher version, one entry per active dataset
    (id, kind, source id, content hash) and one entry per ``MatchRule`` row
    ordered by id.
    """
    datasets = []
    for dataset_id, kind, source_id in active_datasets:
        datasets.append(
            {
                "id": int(dataset_id),
                "kind": str(kind),
                "source_id": int(source_id),
                "sha256": _dataset_sha(session, int(dataset_id)),
            }
        )

    rules = []
    for rule in session.scalars(select(MatchRule).order_by(MatchRule.id)).all():
        rules.append(
            {
                "id": int(rule.id),
                "type": rule.rule_type,
                "active": bool(rule.active),
                "selector": rule.selector_json,
                "action": rule.action_json,
            }
        )
    return {"version": MATCHER_VERSION, "active_datasets": datasets, "manual_rules": rules}
|
||||
|
||||
|
||||
def _dataset_sha(session: Session, dataset_id: int) -> str | None:
    """Return the stored ``sha256`` of a dataset, or ``None`` if it does not exist."""
    dataset = session.get(Dataset, dataset_id)
    if dataset is not None:
        return dataset.sha256
    return None
|
||||
|
||||
|
||||
def _match_route_batch(
    *,
    session: Session,
    routes: list[GtfsRoute],
    osm_dataset_ids: list[int],
    dataset_source_ids: dict[int, int],
    manual_rules: list[_ManualMatchRule],
    osm_scope_bbox: tuple[float | None, float | None, float | None, float | None],
    scoped_counts: dict[str, int],
) -> dict[str, int]:
    """Decide a match for every route in the batch and persist the results.

    For each route, in order: an accepted manual rule whose OSM feature can
    be loaded wins; otherwise a route outside the loaded OSM scope is marked
    missing; otherwise the best-scoring candidate that is not manually
    rejected is used when its score yields a non-"missing" status.

    ``scoped_counts`` is updated in place with the scope of every route.

    Returns:
        The batch's ``matches``, ``missing`` and ``manual`` counts merged with
        the ``created``/``updated``/``unchanged`` counts reported by
        ``_apply_route_match_payloads``.
    """
    matches = 0
    missing = 0
    manual = 0
    payloads: list[_RouteMatchPayload] = []
    for route in routes:
        scope = route_match_scope(route, osm_scope_bbox)
        scoped_counts[scope] = scoped_counts.get(scope, 0) + 1
        route_source_id = dataset_source_ids.get(route.dataset_id)
        accepted_rule = _accepted_rule_for_route(manual_rules, route, route_source_id)
        if accepted_rule is not None:
            accepted_feature = _feature_for_rule_from_storage(session, osm_dataset_ids, dataset_source_ids, accepted_rule)
            # If the rule's feature cannot be found, fall through to the
            # automatic handling below.
            if accepted_feature is not None:
                accepted_feature = ensure_main_osm_feature(session, accepted_feature)
                payloads.append(
                    _RouteMatchPayload(
                        gtfs_route_id=route.id,
                        osm_feature_id=accepted_feature.id,
                        confidence=100.0,
                        status="accepted",
                        rule_source="manual",
                        reasons_json=json.dumps(
                            {"manual_rule_id": accepted_rule.id, "manual": "accepted_match", "scope": scope},
                            separators=(",", ":"),
                        ),
                    )
                )
                matches += 1
                manual += 1
                continue

        # Routes outside the loaded OSM scope are recorded as missing without
        # looking at any candidate.
        if scope == "outside_osm_scope":
            missing += 1
            payloads.append(
                _RouteMatchPayload(
                    gtfs_route_id=route.id,
                    osm_feature_id=None,
                    confidence=0.0,
                    status="missing",
                    rule_source="auto",
                    reasons_json=json.dumps(
                        {
                            "reason": "outside loaded OSM route scope",
                            "scope": scope,
                        },
                        separators=(",", ":"),
                    ),
                )
            )
            continue

        best_feature: Optional[OsmFeature] = None
        best_score = 0.0
        best_reasons: dict[str, object] = {}
        # The route's geometry profile is computed once and reused for every candidate.
        route_geometry_profile = _geometry_profile(route.geometry_geojson)
        for feature in candidate_osm_routes_for_route(session, route, osm_dataset_ids):
            # Manually rejected route/feature pairs are never scored.
            if _is_rejected_pair(manual_rules, route, route_source_id, feature, dataset_source_ids.get(feature.dataset_id)):
                continue
            feature_geometry_profile = _geometry_profile(feature.geometry_geojson)
            score, reasons = score_route_pair(
                route,
                feature,
                route_geometry_profile=route_geometry_profile,
                feature_geometry_profile=feature_geometry_profile,
            )
            # Strict comparison: on equal scores the earlier candidate is kept.
            if score > best_score:
                best_score = score
                best_feature = feature
                best_reasons = reasons
        status = _status_from_score(best_score)
        if best_feature is None or status == "missing":
            missing += 1
            best_feature_id = None
            # Keep the rejected best candidate's score and reasons for diagnostics.
            best_reasons = {
                "reason": "no OSM candidate above threshold",
                "scope": scope,
                "best_score_below_threshold": round(float(best_score), 2) if best_score else 0,
                "best_reasons": best_reasons,
            }
            # A missing match is stored with confidence 0.
            best_score = 0
        else:
            matches += 1
            best_feature = ensure_main_osm_feature(session, best_feature)
            best_feature_id = best_feature.id
            best_reasons["scope"] = scope
        payloads.append(
            _RouteMatchPayload(
                gtfs_route_id=route.id,
                osm_feature_id=best_feature_id,
                confidence=round(float(best_score), 2),
                status=status,
                rule_source="auto",
                reasons_json=json.dumps(best_reasons, separators=(",", ":")),
            )
        )
    changes = _apply_route_match_payloads(session, payloads)
    session.flush()
    return {"matches": matches, "missing": missing, "manual": manual, **changes}
|
||||
|
||||
|
||||
def _apply_route_match_payloads(session: Session, payloads: list[_RouteMatchPayload]) -> dict[str, int]:
    """Write computed route matches to the database, reusing existing rows.

    For each payload the preferred existing ``RouteMatch`` row of that GTFS
    route is updated in place, or left untouched when it already equals the
    payload. Routes without any row get a new one. Every other row found
    for the same route is treated as a duplicate and deleted in batches.

    Returns a dict with the ``created``, ``updated`` and ``unchanged`` counts.
    """
    if not payloads:
        return {"created": 0, "updated": 0, "unchanged": 0}

    wanted_route_ids = [item.gtfs_route_id for item in payloads]
    lookup = (
        select(RouteMatch)
        .where(RouteMatch.gtfs_route_id.in_(wanted_route_ids))
        .order_by(RouteMatch.gtfs_route_id, RouteMatch.id)
    )
    rows_by_route: dict[int, list[RouteMatch]] = {}
    for match_row in session.scalars(lookup).all():
        rows_by_route.setdefault(match_row.gtfs_route_id, []).append(match_row)

    tally = {"created": 0, "updated": 0, "unchanged": 0}
    surplus_ids: list[int] = []
    stamp = datetime.now(timezone.utc)
    for item in payloads:
        siblings = rows_by_route.get(item.gtfs_route_id, [])
        keeper = _preferred_existing_match(siblings)
        if keeper is None:
            fresh = RouteMatch(
                gtfs_route_id=item.gtfs_route_id,
                osm_feature_id=item.osm_feature_id,
                confidence=item.confidence,
                status=item.status,
                rule_source=item.rule_source,
                reasons_json=item.reasons_json,
            )
            session.add(fresh)
            tally["created"] += 1
            continue

        # Everything except the kept row is redundant for this route.
        surplus_ids.extend(other.id for other in siblings if other.id != keeper.id)
        if _route_match_payload_equal(keeper, item):
            tally["unchanged"] += 1
            continue

        keeper.osm_feature_id = item.osm_feature_id
        keeper.confidence = item.confidence
        keeper.status = item.status
        keeper.rule_source = item.rule_source
        keeper.reasons_json = item.reasons_json
        keeper.updated_at = stamp
        tally["updated"] += 1

    for batch in _chunks_int(surplus_ids, 1000):
        session.execute(delete(RouteMatch).where(RouteMatch.id.in_(batch)))
    return tally
|
||||
|
||||
|
||||
def _preferred_existing_match(rows: list[RouteMatch]) -> RouteMatch | None:
|
||||
if not rows:
|
||||
return None
|
||||
return next((row for row in rows if row.rule_source == "manual"), rows[0])
|
||||
|
||||
|
||||
def _route_match_payload_equal(row: RouteMatch, payload: _RouteMatchPayload) -> bool:
|
||||
return (
|
||||
row.osm_feature_id == payload.osm_feature_id
|
||||
and round(float(row.confidence or 0), 2) == round(float(payload.confidence or 0), 2)
|
||||
and row.status == payload.status
|
||||
and row.rule_source == payload.rule_source
|
||||
and (row.reasons_json or None) == (payload.reasons_json or None)
|
||||
)
|
||||
|
||||
|
||||
def _build_osm_route_index(osm_routes: list[OsmFeature]) -> _OsmRouteIndex:
    """Bucket OSM route features by normalized ref, route key and mode.

    A feature is only added to a bucket when the corresponding value is
    non-empty. The unfiltered input list is kept as ``all_routes``.
    """
    ref_buckets: dict[str, list[OsmFeature]] = {}
    key_buckets: dict[str, list[OsmFeature]] = {}
    mode_buckets: dict[str, list[OsmFeature]] = {}
    for osm_route in osm_routes:
        grouping = (
            (ref_buckets, norm_ref(osm_route.ref or "")),
            (key_buckets, osm_route.route_key),
            (mode_buckets, osm_route.mode),
        )
        for bucket, bucket_key in grouping:
            if bucket_key:
                bucket.setdefault(bucket_key, []).append(osm_route)
    return _OsmRouteIndex(
        all_routes=osm_routes,
        by_ref=ref_buckets,
        by_route_key=key_buckets,
        by_mode=mode_buckets,
    )
|
||||
|
||||
|
||||
def _candidate_osm_routes(route: GtfsRoute, index: _OsmRouteIndex) -> list[OsmFeature]:
    """Select OSM route candidates for a GTFS route from an in-memory index.

    Candidates sharing the route's normalized ref or route key are preferred;
    when any exist they are returned (spatially ranked, capped at
    ``MAX_EXACT_REF_CANDIDATES``). Otherwise features of a compatible mode
    that overlap or lie near the route's bounding box are used as fallback.
    """
    selected: list[OsmFeature] = []
    seen: set[int] = set()

    def add(features: list[OsmFeature], *, require_compatible_mode: bool = True) -> None:
        # Append features not seen yet (by database id), optionally dropping
        # those whose mode is incompatible with the GTFS route.
        for feature in features:
            if feature.id in seen:
                continue
            if require_compatible_mode and not _mode_compatible(route.mode or "", feature.mode or ""):
                continue
            seen.add(feature.id)
            selected.append(feature)

    # Stage 1: identity-based lookup by normalized ref and by route key.
    route_ref = norm_ref(route.short_name or route.route_id)
    if route_ref:
        add(index.by_ref.get(route_ref, []))
    if route.route_key:
        add(index.by_route_key.get(route.route_key, []))
    if selected:
        return _spatially_ranked_candidates(route, selected, MAX_EXACT_REF_CANDIDATES)

    # Stage 2: gather features of every mode grouped with the route's mode;
    # with no mode hits at all, fall back to every indexed route.
    compatible_modes = MODE_GROUPS.get(route.mode or "", {route.mode or ""})
    mode_candidates: list[OsmFeature] = []
    for mode in compatible_modes:
        if mode:
            mode_candidates.extend(index.by_mode.get(mode, []))
    if not mode_candidates:
        mode_candidates = index.all_routes

    gtfs_bbox = (route.min_lon, route.min_lat, route.max_lon, route.max_lat)
    near_candidates: list[tuple[float, OsmFeature]] = []
    for feature in mode_candidates:
        osm_bbox = (feature.min_lon, feature.min_lat, feature.max_lon, feature.max_lat)
        distance = approx_bbox_center_distance_deg(gtfs_bbox, osm_bbox)
        if bbox_overlap(gtfs_bbox, osm_bbox):
            # Overlapping boxes sort ahead of everything else.
            near_candidates.append((0.0, feature))
        elif distance is not None and distance < 0.12:
            near_candidates.append((distance, feature))
    fallback_limit = MAX_FALLBACK_CANDIDATES_WITH_REF if route_ref else MAX_FALLBACK_CANDIDATES_WITHOUT_REF
    fallback = [feature for _, feature in sorted(near_candidates, key=lambda item: item[0])[:fallback_limit]]
    if not fallback:
        # Nothing spatially close: take the first mode candidates as-is.
        fallback = mode_candidates[:fallback_limit]
    add(fallback)
    return _spatially_ranked_candidates(route, selected, fallback_limit)
|
||||
|
||||
|
||||
def candidate_osm_routes_for_route(session: Session, route: GtfsRoute, osm_dataset_ids: list[int]) -> list[OsmFeature]:
    """Select OSM route candidates for a GTFS route by querying storage.

    Looks up features by the route's key and normalized ref first; if that
    finds anything the result is spatially ranked and capped at
    ``MAX_EXACT_REF_CANDIDATES``. Otherwise it queries compatible-mode
    features inside a padded bounding box, and finally compatible-mode
    features without any spatial filter. Returns ``[]`` when no OSM dataset
    ids are given.
    """
    if not osm_dataset_ids:
        return []
    selected: list[OsmFeature] = []
    seen: set[tuple[int, str, str]] = set()

    def add(features: list[OsmFeature], *, require_compatible_mode: bool = True) -> None:
        # De-duplicate on (dataset, OSM type, OSM id) and optionally drop
        # features whose mode is incompatible with the GTFS route.
        for feature in features:
            key = (feature.dataset_id, feature.osm_type, feature.osm_id)
            if key in seen:
                continue
            if require_compatible_mode and not _mode_compatible(route.mode or "", feature.mode or ""):
                continue
            seen.add(key)
            selected.append(feature)

    # Stage 1: identity lookup. dict.fromkeys removes a duplicate key while
    # keeping the original order (route_key before normalized ref).
    route_ref = norm_ref(route.short_name or route.route_id)
    route_keys = [key for key in [route.route_key, route_ref] if key]
    for route_key in dict.fromkeys(route_keys):
        add(
            query_osm_features(
                session,
                osm_dataset_ids,
                kinds=["route"],
                route_key=route_key,
            )
        )
    if selected:
        return _spatially_ranked_candidates(route, selected, MAX_EXACT_REF_CANDIDATES)

    # Stage 2: spatial lookup, only possible when the route has a full bbox.
    gtfs_bbox = (route.min_lon, route.min_lat, route.max_lon, route.max_lat)
    compatible_modes = sorted(MODE_GROUPS.get(route.mode or "", {route.mode or ""}) - {""})
    if not any(value is None for value in gtfs_bbox):
        bbox = _expanded_bbox(gtfs_bbox, 0.10)
        add(
            query_osm_features(
                session,
                osm_dataset_ids,
                kinds=["route"],
                modes=compatible_modes or None,
                bbox=bbox,
                limit=MAX_FALLBACK_CANDIDATES_WITHOUT_REF * 4,
            ),
            # The query is already restricted to compatible modes.
            require_compatible_mode=False,
        )
    if not selected:
        # Stage 3: last resort, compatible-mode features without a bbox filter.
        add(
            query_osm_features(
                session,
                osm_dataset_ids,
                kinds=["route"],
                modes=compatible_modes or None,
                limit=MAX_FALLBACK_CANDIDATES_WITHOUT_REF,
            ),
            require_compatible_mode=False,
        )
    fallback_limit = MAX_FALLBACK_CANDIDATES_WITH_REF if route_ref else MAX_FALLBACK_CANDIDATES_WITHOUT_REF
    return _spatially_ranked_candidates(route, selected, fallback_limit)
|
||||
|
||||
|
||||
def score_route_pair(
    route: GtfsRoute,
    feature: OsmFeature,
    route_geometry_profile: _GeometryProfile | None = None,
    feature_geometry_profile: _GeometryProfile | None = None,
) -> tuple[float, dict[str, object]]:
    """Score how well an OSM route feature matches a GTFS route.

    The score is a sum of components (mode, ref, name, operator, bounding
    box, geometry overlap, route key), then raised to fixed floors for
    strong identity signals and finally capped for far-away exact-ref pairs.

    Args:
        route: GTFS route being matched.
        feature: OSM route feature candidate.
        route_geometry_profile: Optional precomputed geometry profile of the
            route; used only when the feature profile is given as well.
        feature_geometry_profile: Optional precomputed geometry profile of
            the feature.

    Returns:
        ``(score, reasons)`` where ``score`` is at most 100.0 and ``reasons``
        records which signals contributed. A mode mismatch returns ``0.0``
        immediately.
    """
    score = 0.0
    reasons: dict[str, object] = {}

    # Mode: compatible modes earn 25 points; a definite mismatch ends scoring.
    gtfs_mode = route.mode or ""
    osm_mode = feature.mode or ""
    if _mode_compatible(gtfs_mode, osm_mode):
        score += 25
        reasons["mode"] = "compatible"
    elif gtfs_mode and osm_mode:
        reasons["mode"] = f"mismatch: {gtfs_mode} != {osm_mode}"
        return 0.0, reasons

    # Ref: exact normalized match 25 points, substring match 15.
    gtfs_ref = norm_ref(route.short_name or route.route_id)
    osm_ref = norm_ref(feature.ref or "")
    if gtfs_ref and osm_ref:
        if gtfs_ref == osm_ref:
            score += 25
            reasons["ref"] = "exact"
        elif gtfs_ref in osm_ref or osm_ref in gtfs_ref:
            score += 15
            reasons["ref"] = "partial"

    # Name similarity contributes up to 20 points.
    gtfs_name = norm_text(" ".join(v for v in [route.long_name, route.short_name, route.route_id] if v))
    osm_name = norm_text(" ".join(v for v in [feature.name, feature.ref] if v))
    name_similarity = _ratio(gtfs_name, osm_name)
    score += 20 * name_similarity
    reasons["name_similarity"] = round(name_similarity, 3)

    # Operator/network similarity contributes up to 15 points.
    gtfs_operator = norm_text(route.operator_name or "")
    osm_operator = norm_text(" ".join(v for v in [feature.operator, feature.network] if v))
    operator_similarity = _ratio(gtfs_operator, osm_operator) if gtfs_operator and osm_operator else 0
    score += 15 * operator_similarity
    reasons["operator_similarity"] = round(operator_similarity, 3)

    # Bounding boxes: overlap earns 14 (+8 with an exact ref); otherwise the
    # center distance is tiered, with a penalty for far-away exact refs.
    gtfs_bbox = (route.min_lon, route.min_lat, route.max_lon, route.max_lat)
    osm_bbox = (feature.min_lon, feature.min_lat, feature.max_lon, feature.max_lat)
    center_distance = None
    if bbox_overlap(gtfs_bbox, osm_bbox):
        score += 14
        reasons["bbox"] = "overlap"
        if gtfs_ref and osm_ref and gtfs_ref == osm_ref and _mode_compatible(gtfs_mode, osm_mode):
            score += 8
            reasons["line_identity"] = "exact_ref_mode_bbox_overlap"
    else:
        center_distance = approx_bbox_center_distance_deg(gtfs_bbox, osm_bbox)
        if center_distance is not None:
            if center_distance < 0.01:
                score += 12
            elif center_distance < 0.03:
                score += 8
            elif center_distance < 0.08:
                score += 4
            elif gtfs_ref and osm_ref and gtfs_ref == osm_ref and center_distance > OSM_SCOPE_NEAR_DISTANCE_DEG:
                score -= 8
                reasons["spatial_penalty"] = "exact_ref_far_bbox_center"
            reasons["bbox_center_distance_deg"] = round(center_distance, 5)

    # Geometry overlap: reuse the precomputed profiles when both are given.
    geometry_metrics = (
        _geometry_match_metrics_from_profiles(route_geometry_profile, feature_geometry_profile)
        if route_geometry_profile is not None and feature_geometry_profile is not None
        else _geometry_match_metrics(route.geometry_geojson, feature.geometry_geojson)
    )
    if geometry_metrics is not None:
        reasons["geometry"] = geometry_metrics
        geometry_score = 34 * float(geometry_metrics["gtfs_on_osm_ratio"]) + 8 * float(geometry_metrics["osm_on_gtfs_ratio"])
        if float(geometry_metrics["endpoint_distance_deg"]) < GEOMETRY_PROXIMITY_DEG * 2:
            geometry_score += 6
        if float(geometry_metrics["length_ratio"]) < 0.35 or float(geometry_metrics["length_ratio"]) > 2.8:
            geometry_score -= 8
            reasons["geometry_length"] = "implausible_ratio"
        # The geometry component is clamped to the range 0..42.
        score += max(0.0, min(42.0, geometry_score))

    # Extra small boost for same normalized route key.
    if route.route_key and feature.route_key and route.route_key == feature.route_key:
        score += 5
        reasons["route_key"] = "same"

    # Floors: strong identity signals raise the score to a fixed minimum.
    if gtfs_ref and osm_ref and gtfs_ref == osm_ref and _mode_compatible(gtfs_mode, osm_mode):
        if bbox_overlap(gtfs_bbox, osm_bbox):
            score = max(score, 88.0)
            reasons["strong_identity"] = "exact_ref_mode_bbox_overlap"
        elif center_distance is not None and center_distance < 0.02:
            score = max(score, 82.0)
            reasons["strong_identity"] = "exact_ref_mode_near_bbox_center"

    if route.route_key and feature.route_key and route.route_key == feature.route_key and _mode_compatible(gtfs_mode, osm_mode):
        if bbox_overlap(gtfs_bbox, osm_bbox):
            score = max(score, 86.0)
            # setdefault keeps an identity reason recorded by the ref check above.
            reasons.setdefault("strong_identity", "same_route_key_mode_bbox_overlap")

    if geometry_metrics is not None:
        gtfs_on_osm = float(geometry_metrics["gtfs_on_osm_ratio"])
        endpoint_distance = float(geometry_metrics["endpoint_distance_deg"])
        if gtfs_on_osm >= 0.82 and endpoint_distance < GEOMETRY_PROXIMITY_DEG * 3 and _mode_compatible(gtfs_mode, osm_mode):
            if gtfs_ref and osm_ref and gtfs_ref == osm_ref:
                score = max(score, 90.0)
                reasons["strong_identity"] = "exact_ref_mode_geometry_overlap"
            elif gtfs_ref and osm_ref and (gtfs_ref in osm_ref or osm_ref in gtfs_ref):
                score = max(score, 82.0)
                reasons["strong_identity"] = "partial_ref_mode_geometry_overlap"

    # Cap: an exact ref far away with little or no geometry overlap is
    # limited to 58 regardless of the other signals.
    if (
        gtfs_ref
        and osm_ref
        and gtfs_ref == osm_ref
        and center_distance is not None
        and center_distance > OSM_SCOPE_NEAR_DISTANCE_DEG
        and not bbox_overlap(gtfs_bbox, osm_bbox)
        and (
            geometry_metrics is None
            or float(geometry_metrics.get("gtfs_on_osm_ratio", 0.0)) < 0.25
        )
    ):
        score = min(score, 58.0)
        reasons["spatial_cap"] = "exact_ref_far_without_geometry_overlap"

    return min(score, 100.0), reasons
|
||||
|
||||
|
||||
def route_match_scope(route: GtfsRoute, osm_scope_bbox: tuple[float | None, float | None, float | None, float | None]) -> str:
    """Classify a route's position relative to the loaded OSM extent.

    Returns ``"unknown_scope"`` when either bounding box is incomplete,
    ``"in_osm_scope"`` when the boxes overlap, ``"near_osm_scope"`` when
    their centers are closer than ``OSM_SCOPE_NEAR_DISTANCE_DEG``, and
    ``"outside_osm_scope"`` otherwise.
    """
    route_bbox = (route.min_lon, route.min_lat, route.max_lon, route.max_lat)
    every_coordinate = (*route_bbox, *osm_scope_bbox)
    if any(coordinate is None for coordinate in every_coordinate):
        return "unknown_scope"
    if bbox_overlap(route_bbox, osm_scope_bbox):
        return "in_osm_scope"
    gap = approx_bbox_center_distance_deg(route_bbox, osm_scope_bbox)
    is_near = gap is not None and gap < OSM_SCOPE_NEAR_DISTANCE_DEG
    return "near_osm_scope" if is_near else "outside_osm_scope"
|
||||
|
||||
|
||||
def _combined_bbox(features: list[OsmFeature]) -> tuple[float | None, float | None, float | None, float | None]:
|
||||
boxes = [
|
||||
(feature.min_lon, feature.min_lat, feature.max_lon, feature.max_lat)
|
||||
for feature in features
|
||||
if None not in (feature.min_lon, feature.min_lat, feature.max_lon, feature.max_lat)
|
||||
]
|
||||
if not boxes:
|
||||
return (None, None, None, None)
|
||||
return (
|
||||
min(float(box[0]) for box in boxes if box[0] is not None),
|
||||
min(float(box[1]) for box in boxes if box[1] is not None),
|
||||
max(float(box[2]) for box in boxes if box[2] is not None),
|
||||
max(float(box[3]) for box in boxes if box[3] is not None),
|
||||
)
|
||||
|
||||
|
||||
def _spatially_ranked_candidates(route: GtfsRoute, candidates: list[OsmFeature], limit: int) -> list[OsmFeature]:
    """Order candidates by spatial rank and keep the best ones.

    At least one candidate is kept even when ``limit`` is below 1.
    """
    keep = max(1, limit)
    ordered = sorted(candidates, key=lambda feature: _spatial_rank(route, feature))
    return ordered[:keep]
|
||||
|
||||
|
||||
def _spatial_rank(route: GtfsRoute, feature: OsmFeature) -> tuple[int, float, str]:
    """Build the sort key that orders a feature by closeness to a route.

    The key is ``(bucket, distance, osm_id)``. Bucket 0 means overlapping
    bounding boxes, 1 means centers closer than
    ``OSM_SCOPE_NEAR_DISTANCE_DEG``, 2 means farther away, and 3 means the
    distance is unknown (then ``999.0`` stands in for it).
    """
    route_bbox = (route.min_lon, route.min_lat, route.max_lon, route.max_lat)
    feature_bbox = (feature.min_lon, feature.min_lat, feature.max_lon, feature.max_lat)
    gap = approx_bbox_center_distance_deg(route_bbox, feature_bbox)
    if bbox_overlap(route_bbox, feature_bbox):
        bucket = 0
    elif gap is None:
        bucket = 3
    elif gap < OSM_SCOPE_NEAR_DISTANCE_DEG:
        bucket = 1
    else:
        bucket = 2
    return (bucket, 999.0 if gap is None else gap, feature.osm_id)
|
||||
|
||||
|
||||
def _expanded_bbox(
|
||||
bbox: tuple[float | None, float | None, float | None, float | None],
|
||||
padding: float,
|
||||
) -> tuple[float, float, float, float] | None:
|
||||
min_lon, min_lat, max_lon, max_lat = bbox
|
||||
if None in (min_lon, min_lat, max_lon, max_lat):
|
||||
return None
|
||||
return (float(min_lon) - padding, float(min_lat) - padding, float(max_lon) + padding, float(max_lat) + padding)
|
||||
|
||||
|
||||
def _chunks_int(values: list[int], size: int) -> list[list[int]]:
|
||||
return [values[start : start + size] for start in range(0, len(values), max(1, size))]
|
||||
|
||||
|
||||
def _emit_progress(
|
||||
progress_callback: ProgressCallback | None,
|
||||
event_type: str,
|
||||
message: str,
|
||||
progress_current: int | None,
|
||||
progress_total: int | None,
|
||||
metadata: dict[str, object] | None = None,
|
||||
) -> None:
|
||||
if progress_callback is not None:
|
||||
progress_callback(event_type, message, progress_current, progress_total, metadata)
|
||||
|
||||
|
||||
def _geometry_match_metrics(route_geometry: str | None, feature_geometry: str | None) -> dict[str, float] | None:
    """Compute geometry match metrics straight from two GeoJSON strings.

    Both geometries are profiled first; the metrics come from
    ``_geometry_match_metrics_from_profiles``.
    """
    return _geometry_match_metrics_from_profiles(
        _geometry_profile(route_geometry),
        _geometry_profile(feature_geometry),
    )
|
||||
|
||||
|
||||
def _geometry_profile(geometry_text: str | None) -> _GeometryProfile | None:
    """Parse GeoJSON text into a reusable profile for geometry comparison.

    Returns ``None`` when the text is empty or unparsable, when it contains
    no usable lines, when the total line length is zero, or when no sample
    points could be produced.
    """
    if not geometry_text:
        return None
    try:
        parsed = shape(json.loads(geometry_text))
    except Exception:  # noqa: BLE001 - malformed geometry should not break matching
        return None
    segments = _iter_lines(parsed)
    total_length = sum(segment.length for segment in segments) if segments else 0
    if total_length == 0:
        return None
    samples = _sample_line_points(segments, GEOMETRY_SAMPLE_POINTS)
    if not samples:
        return None
    return _GeometryProfile(geom=parsed, lines=segments, length=total_length, sample_points=samples)
|
||||
|
||||
|
||||
def _geometry_match_metrics_from_profiles(
    route_profile: _GeometryProfile | None, feature_profile: _GeometryProfile | None
) -> dict[str, float] | None:
    """Compare two geometry profiles and return rounded overlap metrics.

    Returns ``None`` when either profile is missing. Otherwise the dict
    holds the share of route sample points near the feature
    (``gtfs_on_osm_ratio``), the reverse share (``osm_on_gtfs_ratio``), the
    summed endpoint distance, and the route/feature length ratio.
    """
    if route_profile is None:
        return None
    if feature_profile is None:
        return None
    route_share = _near_point_ratio(route_profile.sample_points, feature_profile.geom, GEOMETRY_PROXIMITY_DEG)
    feature_share = _near_point_ratio(feature_profile.sample_points, route_profile.geom, GEOMETRY_PROXIMITY_DEG)
    endpoint_gap = _endpoint_distance(route_profile.lines, feature_profile.geom)
    if feature_profile.length:
        relative_length = route_profile.length / feature_profile.length
    else:
        relative_length = 0.0
    metrics = {}
    metrics["gtfs_on_osm_ratio"] = round(route_share, 3)
    metrics["osm_on_gtfs_ratio"] = round(feature_share, 3)
    metrics["endpoint_distance_deg"] = round(endpoint_gap, 6)
    metrics["length_ratio"] = round(relative_length, 3)
    return metrics
|
||||
|
||||
|
||||
def _iter_lines(geom) -> list[LineString]:
    """Return the line parts of a geometry.

    A ``LineString`` is returned as a one-item list; a ``MultiLineString``
    contributes its ``LineString`` members of non-zero length; any other
    geometry yields an empty list.
    """
    if isinstance(geom, LineString):
        return [geom]
    if not isinstance(geom, MultiLineString):
        return []
    return [part for part in geom.geoms if isinstance(part, LineString) and part.length > 0]
|
||||
|
||||
|
||||
def _sample_line_points(lines: list[LineString], count: int) -> list[Point]:
|
||||
total_length = sum(line.length for line in lines)
|
||||
if total_length == 0:
|
||||
return []
|
||||
points = []
|
||||
for index in range(count):
|
||||
target = total_length * (index / max(1, count - 1))
|
||||
traversed = 0.0
|
||||
for line in lines:
|
||||
next_traversed = traversed + line.length
|
||||
if target <= next_traversed or line is lines[-1]:
|
||||
points.append(line.interpolate(max(0.0, min(line.length, target - traversed))))
|
||||
break
|
||||
traversed = next_traversed
|
||||
return points
|
||||
|
||||
|
||||
def _near_point_ratio(points: list[Point], geom, max_distance: float) -> float:
|
||||
if not points:
|
||||
return 0.0
|
||||
near = sum(1 for point in points if geom.distance(point) <= max_distance)
|
||||
return near / len(points)
|
||||
|
||||
|
||||
def _endpoint_distance(gtfs_lines: list[LineString], osm_geom) -> float:
    """Sum the distances from both ends of the longest GTFS line to the OSM geometry.

    Returns ``999.0`` when the longest line has fewer than two coordinates.
    """
    main_line = max(gtfs_lines, key=lambda line: line.length)
    vertices = list(main_line.coords)
    if len(vertices) < 2:
        return 999.0
    start_gap = osm_geom.distance(Point(vertices[0]))
    end_gap = osm_geom.distance(Point(vertices[-1]))
    return start_gap + end_gap
|
||||
|
||||
|
||||
def _manual_match_rules(session: Session) -> list[_ManualMatchRule]:
    """Load the active manual accept/reject rules in parsed form.

    Rules are read in descending id order. A rule is skipped, rather than
    aborting the whole run, when its selector or action JSON cannot be
    decoded or decodes to something other than a JSON object (for example
    ``null`` or a list), since the selector lookups below need dicts.

    Returns:
        One ``_ManualMatchRule`` per usable rule. ``osm_selector`` is
        ``None`` when no OSM selector dict could be derived.
    """
    rules = session.scalars(
        select(MatchRule)
        .where(MatchRule.active.is_(True), MatchRule.rule_type.in_(["accept_match", "reject_match"]))
        .order_by(MatchRule.id.desc())
    ).all()
    parsed: list[_ManualMatchRule] = []
    for rule in rules:
        try:
            selector = json.loads(rule.selector_json or "{}")
            action = json.loads(rule.action_json or "{}")
        except json.JSONDecodeError:
            continue
        if not isinstance(selector, dict) or not isinstance(action, dict):
            # Valid JSON, but not an object: treat like any other malformed rule.
            continue
        # The route selector is either nested under "gtfs" or the selector itself.
        route_selector = selector.get("gtfs") if isinstance(selector.get("gtfs"), dict) else selector
        # The OSM selector comes from the action first, then from the selector.
        osm_selector = action.get("osm") if isinstance(action.get("osm"), dict) else selector.get("osm")
        if not isinstance(osm_selector, dict) and selector.get("osm_feature_id") is not None:
            osm_selector = {"osm_feature_id": selector.get("osm_feature_id")}
        # An explicit action status wins; otherwise it follows the rule type.
        status = str(action.get("status") or ("accepted" if rule.rule_type == "accept_match" else "rejected"))
        parsed.append(
            _ManualMatchRule(
                id=rule.id,
                rule_type=rule.rule_type,
                route_selector=route_selector,
                osm_selector=osm_selector if isinstance(osm_selector, dict) else None,
                status=status,
            )
        )
    return parsed
|
||||
|
||||
|
||||
def _accepted_rule_for_route(
    rules: list[_ManualMatchRule], route: GtfsRoute, route_source_id: int | None
) -> _ManualMatchRule | None:
    """Return the first ``accept_match`` rule with status ``accepted`` that selects the route.

    Rules are checked in the order given; ``None`` means no rule applies.
    """
    applicable = (
        rule
        for rule in rules
        if rule.rule_type == "accept_match"
        and rule.status == "accepted"
        and _route_matches_selector(route, route_source_id, rule.route_selector)
    )
    return next(applicable, None)
|
||||
|
||||
|
||||
def _feature_for_rule(
    features: list[OsmFeature], dataset_source_ids: dict[int, int], rule: _ManualMatchRule
) -> OsmFeature | None:
    """Find the first feature in ``features`` selected by the rule's OSM selector.

    Returns ``None`` when the rule has no OSM selector or nothing matches.
    """
    wanted = rule.osm_selector
    if not wanted:
        return None
    hits = (
        feature
        for feature in features
        if _feature_matches_selector(feature, dataset_source_ids.get(feature.dataset_id), wanted)
    )
    return next(hits, None)
|
||||
|
||||
|
||||
def _feature_for_rule_from_storage(
    session: Session,
    osm_dataset_ids: list[int],
    dataset_source_ids: dict[int, int],
    rule: _ManualMatchRule,
) -> OsmFeature | None:
    """Resolve the OSM feature a manual rule points at by querying storage.

    Lookup order: the legacy ``osm_feature_id`` primary key, then OSM
    type/id, then ``route_key``, then the normalized ``ref`` (queried as a
    route key). Queries are limited to the given datasets, narrowed further
    by the selector's ``source_id`` and ``dataset_id`` when present.

    Returns:
        The first feature that satisfies the selector, or ``None`` when the
        rule has no OSM selector, no dataset remains in scope, or nothing
        matches.
    """
    if not rule.osm_selector:
        return None
    selector = rule.osm_selector
    # Legacy rules reference the feature row directly by primary key.
    legacy_id = _safe_int(selector.get("osm_feature_id"))
    if legacy_id is not None:
        feature = session.get(OsmFeature, legacy_id)
        if feature is not None and _feature_matches_selector(feature, dataset_source_ids.get(feature.dataset_id), selector):
            return feature
    # Narrow the dataset scope by source and/or dataset id from the selector.
    scoped_dataset_ids = list(osm_dataset_ids)
    expected_source = selector.get("source_id")
    if expected_source is not None:
        expected_source_id = _safe_int(expected_source)
        if expected_source_id is not None:
            scoped_dataset_ids = [
                dataset_id
                for dataset_id in scoped_dataset_ids
                if dataset_source_ids.get(dataset_id) == expected_source_id
            ]
    dataset_id = _safe_int(selector.get("dataset_id"))
    if dataset_id is not None:
        scoped_dataset_ids = [value for value in scoped_dataset_ids if value == dataset_id]
    if not scoped_dataset_ids:
        return None

    # Try progressively weaker identifiers until a query returns something.
    features: list[OsmFeature] = []
    osm_type = selector.get("osm_type")
    osm_id = selector.get("osm_id")
    if osm_type and osm_id:
        features = query_osm_features(
            session,
            scoped_dataset_ids,
            kinds=["route"],
            osm_type=str(osm_type),
            osm_id=str(osm_id),
            limit=10,
        )
    if not features:
        route_key = selector.get("route_key")
        if route_key:
            features = query_osm_features(session, scoped_dataset_ids, kinds=["route"], route_key=str(route_key))
    if not features:
        ref = norm_ref(selector.get("ref"))
        if ref:
            features = query_osm_features(session, scoped_dataset_ids, kinds=["route"], route_key=ref)
    # Query hits are re-checked against the full selector before returning.
    for feature in features:
        if _feature_matches_selector(feature, dataset_source_ids.get(feature.dataset_id), selector):
            return feature
    return None
|
||||
|
||||
|
||||
def _is_rejected_pair(
    rules: list[_ManualMatchRule],
    route: GtfsRoute,
    route_source_id: int | None,
    feature: OsmFeature,
    feature_source_id: int | None,
) -> bool:
    """Tell whether a manual ``reject_match`` rule forbids pairing route and feature.

    A rule applies only when it selects the route and has an OSM selector
    that selects the feature.
    """
    for rule in rules:
        is_rejection = rule.rule_type == "reject_match"
        if not is_rejection or not _route_matches_selector(route, route_source_id, rule.route_selector):
            continue
        osm_selector = rule.osm_selector
        if osm_selector and _feature_matches_selector(feature, feature_source_id, osm_selector):
            return True
    return False
|
||||
|
||||
|
||||
def _route_matches_selector(route: GtfsRoute, source_id: int | None, selector: dict[str, object]) -> bool:
    """Check whether a manual-rule selector addresses the given GTFS route.

    Checks run in this order: legacy ``gtfs_route_id`` (matches at once),
    ``source_id`` (a known, differing source rejects), ``route_id``,
    ``route_key``, and finally the normalized ``ref``, optionally narrowed
    by a compatible ``mode``.
    """
    legacy_key = selector.get("gtfs_route_id")
    if legacy_key is not None and _safe_int(legacy_key) == route.id:
        return True

    wanted_source = selector.get("source_id")
    if wanted_source is not None and source_id is not None:
        if _safe_int(wanted_source) != source_id:
            return False

    wanted_route_id = selector.get("route_id")
    if wanted_route_id and str(wanted_route_id) == route.route_id:
        return True

    wanted_key = selector.get("route_key")
    if wanted_key and route.route_key and str(wanted_key) == route.route_key:
        return True

    wanted_ref = norm_ref(selector.get("ref"))
    wanted_mode = selector.get("mode")
    if not wanted_ref or wanted_ref != norm_ref(route.short_name or route.route_id):
        return False
    return not wanted_mode or _mode_compatible(str(wanted_mode), route.mode or "")
|
||||
|
||||
|
||||
def _feature_matches_selector(feature: OsmFeature, source_id: int | None, selector: dict[str, object]) -> bool:
    """Check whether a manual-rule selector addresses the given OSM feature.

    Checks run in this order: legacy ``osm_feature_id`` (matches at once),
    ``source_id`` (a known, differing source rejects), OSM type and id,
    ``route_key``, and finally the normalized ``ref``, optionally narrowed
    by a compatible ``mode``.
    """
    legacy_key = selector.get("osm_feature_id")
    if legacy_key is not None and _safe_int(legacy_key) == feature.id:
        return True

    wanted_source = selector.get("source_id")
    if wanted_source is not None and source_id is not None:
        if _safe_int(wanted_source) != source_id:
            return False

    wanted_type = selector.get("osm_type")
    wanted_id = selector.get("osm_id")
    if wanted_type and wanted_id:
        if str(wanted_type) == feature.osm_type and str(wanted_id) == feature.osm_id:
            return True

    wanted_key = selector.get("route_key")
    if wanted_key and feature.route_key and str(wanted_key) == feature.route_key:
        return True

    wanted_ref = norm_ref(selector.get("ref"))
    wanted_mode = selector.get("mode")
    if not wanted_ref or wanted_ref != norm_ref(feature.ref or ""):
        return False
    return not wanted_mode or _mode_compatible(str(wanted_mode), feature.mode or "")
|
||||
|
||||
|
||||
def _safe_int(value: object) -> int | None:
|
||||
try:
|
||||
return int(value) # type: ignore[arg-type]
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
|
||||
def _mode_compatible(gtfs_mode: str, osm_mode: str) -> bool:
    """Tell whether a GTFS mode and an OSM mode may describe the same service.

    An unknown (empty) mode on either side is compatible with anything.
    Otherwise the modes must be equal or share a ``MODE_GROUPS`` entry,
    looked up in either direction.
    """
    if not (gtfs_mode and osm_mode) or gtfs_mode == osm_mode:
        return True
    if osm_mode in MODE_GROUPS.get(gtfs_mode, {gtfs_mode}):
        return True
    return gtfs_mode in MODE_GROUPS.get(osm_mode, {osm_mode})
|
||||
|
||||
|
||||
def _ratio(a: str, b: str) -> float:
    """Return a similarity in ``[0, 1]`` for two normalized strings.

    Empty input scores 0.0 and identical strings 1.0. Otherwise the token
    similarity is used, raised to at least 0.82 when one string contains
    the other.
    """
    if not a or not b:
        return 0.0
    if a == b:
        return 1.0
    similarity = _token_similarity(a, b)
    one_contains_other = a in b or b in a
    return max(similarity, 0.82) if one_contains_other else similarity
|
||||
|
||||
|
||||
def _token_similarity(a: str, b: str) -> float:
|
||||
left = set(a.split())
|
||||
right = set(b.split())
|
||||
if not left or not right:
|
||||
return 0.0
|
||||
return len(left & right) / len(left | right)
|
||||
|
||||
|
||||
def _status_from_score(score: float) -> str:
|
||||
if score >= 85:
|
||||
return "matched"
|
||||
if score >= 65:
|
||||
return "probable"
|
||||
if score >= 40:
|
||||
return "weak"
|
||||
return "missing"
|
||||
508
app/pipeline/osm_addresses.py
Normal file
508
app/pipeline/osm_addresses.py
Normal file
@@ -0,0 +1,508 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import math
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
import osmium
|
||||
from sqlalchemy import delete, func, select, text
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.config import settings
|
||||
from app.models import Dataset, OsmAddress
|
||||
from app.pipeline.routing_layer import active_routing_dataset
|
||||
from app.spatial import analyze_postgresql_tables, refresh_postgis_geometries
|
||||
|
||||
|
||||
# Signature of progress callbacks accepted by this module. Judging by how
# `_emit` is called below, the arguments are presumably (event type, message,
# progress current, progress total, metadata) - confirm against `_emit`.
ProgressCallback = Callable[[str, str, int | None, int | None, dict[str, object] | None], None]

# Version tag recorded with every address index (stored under
# metadata["address_index"]["version"] by finalize_address_index).
ADDRESS_INDEX_VERSION = "osm_addresses_v2_nodes_ways_area_geometry"

# OSM tag keys relevant to address indexing.
# NOTE(review): the consumer of this set is outside the visible code -
# presumably the address handler uses it to pick tags; verify there.
ADDRESS_TAGS = {
    "addr:housenumber",
    "addr:housename",
    "addr:street",
    "addr:place",
    "addr:postcode",
    "addr:city",
    "addr:country",
    "addr:unit",
    "addr:suburb",
    "addr:district",
    "addr:municipality",
    "entrance",
    "name",
}
|
||||
|
||||
|
||||
@dataclass
class AddressIndexResult:
    """Summary of one address-index build for a dataset."""

    dataset_id: int
    input_path: str
    addresses: int
    node_addresses: int
    way_addresses: int
    skipped: int
    version: str = ADDRESS_INDEX_VERSION

    def as_dict(self) -> dict[str, object]:
        """Return the summary as a plain dict, with ``version`` listed first."""
        ordered_keys = (
            "version",
            "dataset_id",
            "input_path",
            "addresses",
            "node_addresses",
            "way_addresses",
            "skipped",
        )
        return {key: getattr(self, key) for key in ordered_keys}
|
||||
|
||||
|
||||
def rebuild_address_index(
    session: Session,
    *,
    dataset_id: int | None = None,
    input_path: str | Path | None = None,
    reset: bool = True,
    batch_size: int = 20_000,
    progress_callback: ProgressCallback | None = None,
) -> dict[str, object]:
    """Build the OSM address index for a dataset from its PBF file.

    Args:
        session: Database session; committed several times along the way.
        dataset_id: Dataset to index. When ``None`` the dataset returned by
            ``active_routing_dataset`` is used.
        input_path: PBF file to read; defaults to the dataset's ``local_path``.
        reset: When true, existing address rows of the dataset are cleared first.
        batch_size: Passed through to the address handler.
        progress_callback: Optional receiver of progress events.

    Returns:
        The result dict produced by ``finalize_address_index``.

    Raises:
        ValueError: If no dataset could be resolved.
        FileNotFoundError: If the PBF path does not exist.
    """
    dataset = session.get(Dataset, dataset_id) if dataset_id is not None else active_routing_dataset(session)
    if dataset is None:
        raise ValueError("No OSM PBF dataset is available for address indexing.")
    path = Path(input_path or dataset.local_path)
    if not path.exists():
        raise FileNotFoundError(f"Address index PBF does not exist: {path}")

    if reset:
        _emit(progress_callback, "address_index_clear_started", "Clearing existing OSM address index.", None, None, {"dataset_id": dataset.id})
        _clear_address_rows(session, dataset_id=int(dataset.id))
        session.commit()

    if settings.is_postgresql_database:
        # Lookup indexes are dropped ahead of the bulk import and rebuilt in
        # finalize_address_index.
        _emit(progress_callback, "address_index_indexes_dropped", "Dropping address lookup indexes before bulk import.", None, None, {"dataset_id": dataset.id})
        _drop_address_indexes(session)
        session.commit()

    _emit(progress_callback, "address_index_import_started", "Importing OSM address nodes and ways.", None, None, {"dataset_id": dataset.id, "path": str(path)})
    handler = _AddressHandler(
        session=session,
        dataset_id=dataset.id,
        batch_size=batch_size,
        progress_callback=progress_callback,
    )
    # Use the FileProcessor-based path when this osmium build exposes it,
    # otherwise fall back to the handler's apply_file.
    if hasattr(osmium, "FileProcessor"):
        _apply_address_file_processor(handler, path)
    else:
        handler.apply_file(str(path), locations=True)
    handler.flush()

    return finalize_address_index(
        session,
        dataset_id=dataset.id,
        input_path=path,
        node_addresses=handler.node_address_count,
        way_addresses=handler.way_address_count,
        skipped=handler.skipped_count,
        progress_callback=progress_callback,
    )
|
||||
|
||||
|
||||
def finalize_address_index(
    session: Session,
    *,
    dataset_id: int,
    input_path: str | Path,
    node_addresses: int = 0,
    way_addresses: int = 0,
    skipped: int = 0,
    progress_callback: ProgressCallback | None = None,
) -> dict[str, object]:
    """Complete an address import for one dataset and return its summary.

    On PostgreSQL the address point geometries are refreshed first. The
    lookup indexes are then (re)created, the table is analyzed, and an
    ``address_index`` entry is written into the dataset metadata before the
    completion event is emitted.

    Args:
        session: Database session; committed several times along the way.
        dataset_id: Id of the dataset the addresses belong to.
        input_path: Path of the imported file, recorded in metadata/result.
        node_addresses: Number of addresses imported from nodes.
        way_addresses: Number of addresses imported from ways.
        skipped: Number of address candidates that were skipped.
        progress_callback: Optional progress hook.

    Returns:
        The ``AddressIndexResult`` of the import as a dict.

    Raises:
        ValueError: If no dataset with ``dataset_id`` exists.
    """
    dataset = session.get(Dataset, dataset_id)
    if dataset is None:
        raise ValueError("Address index dataset does not exist.")

    if settings.is_postgresql_database:
        _emit(
            progress_callback,
            "address_index_geometry_started",
            "Refreshing address point geometries.",
            None,
            None,
            {"dataset_id": dataset.id},
        )
        refresh_postgis_geometries(session, dataset_id=dataset.id, tables=["osm_addresses"], only_missing=False)
        session.commit()

    _emit(
        progress_callback,
        "address_index_indexes_started",
        "Rebuilding address lookup indexes.",
        None,
        None,
        {"dataset_id": dataset.id},
    )
    _create_address_indexes(session)
    session.commit()
    analyze_postgresql_tables(session, ["osm_addresses"])

    # The stored total comes from the table itself, not from the handler counters.
    count_query = select(func.count()).select_from(OsmAddress).where(OsmAddress.dataset_id == dataset.id)
    address_count = int(session.scalar(count_query) or 0)

    index_entry = {
        "version": ADDRESS_INDEX_VERSION,
        "addresses": address_count,
        "node_addresses": int(node_addresses),
        "way_addresses": int(way_addresses),
        "skipped": int(skipped),
        "input_path": str(input_path),
    }
    metadata = _metadata(dataset)
    metadata["address_index"] = index_entry
    dataset.metadata_json = json.dumps(metadata, indent=2)
    session.commit()

    summary = AddressIndexResult(
        dataset_id=dataset.id,
        input_path=str(input_path),
        addresses=address_count,
        node_addresses=node_addresses,
        way_addresses=way_addresses,
        skipped=skipped,
    ).as_dict()
    _emit(
        progress_callback,
        "address_index_import_completed",
        "OSM address index import completed.",
        address_count,
        address_count,
        summary,
    )
    return summary
|
||||
|
||||
|
||||
def _clear_address_rows(session: Session, *, dataset_id: int) -> None:
    """Remove the stored address rows of ``dataset_id``.

    On PostgreSQL the whole table is truncated (with identity restart) when
    no other dataset has rows in it; in every other case only the rows of
    this dataset are deleted.
    """
    target_id = int(dataset_id)
    if settings.is_postgresql_database:
        other_datasets_query = select(func.count(func.distinct(OsmAddress.dataset_id))).where(
            OsmAddress.dataset_id != target_id
        )
        other_dataset_count = int(session.scalar(other_datasets_query) or 0)
        if other_dataset_count == 0:
            session.execute(text("TRUNCATE TABLE osm_addresses RESTART IDENTITY"))
            return
    session.execute(delete(OsmAddress).where(OsmAddress.dataset_id == target_id))
|
||||
|
||||
|
||||
def address_index_status(session: Session) -> dict[str, object]:
    """Describe the address index of the active routing dataset.

    The address count is taken from the dataset's ``address_index`` metadata
    when it holds a usable non-zero number, otherwise it is counted from the
    table. ``stale`` is true when addresses exist but the recorded version
    differs from ``ADDRESS_INDEX_VERSION``.
    """
    dataset = active_routing_dataset(session)
    dataset_id = int(dataset.id) if dataset is not None else None
    address_count = 0
    metadata: dict[str, object] = {}

    if dataset is not None:
        metadata = _metadata(dataset).get("address_index") or {}
        if isinstance(metadata, dict):
            try:
                address_count = int(metadata.get("addresses") or 0)
            except (TypeError, ValueError):
                address_count = 0
        if not address_count:
            count_query = select(func.count()).select_from(OsmAddress).where(OsmAddress.dataset_id == dataset.id)
            address_count = int(session.scalar(count_query) or 0)

    # The metadata entry may hold a non-dict value; only read keys from a dict.
    metadata_is_dict = isinstance(metadata, dict)
    installed_version = metadata.get("version") if metadata_is_dict else None
    recorded_path = metadata.get("input_path") if metadata_is_dict else None
    return {
        "dataset_id": dataset_id,
        "addresses": address_count,
        "available": address_count > 0,
        "version": installed_version,
        "current_version": ADDRESS_INDEX_VERSION,
        "stale": bool(address_count and installed_version != ADDRESS_INDEX_VERSION),
        "input_path": recorded_path,
    }
|
||||
|
||||
|
||||
class _AddressHandler(osmium.SimpleHandler):
    """Turn address-tagged OSM nodes and ways into ``OsmAddress`` rows.

    Rows are buffered in ``rows`` and written with a bulk insert plus commit
    whenever the buffer reaches ``batch_size`` (and on an explicit
    ``flush()``). The counters track processed objects, imported node/way
    addresses and skipped candidates.
    """

    def __init__(
        self,
        *,
        session: Session,
        dataset_id: int,
        batch_size: int,
        progress_callback: ProgressCallback | None,
    ) -> None:
        super().__init__()
        self.session = session
        self.dataset_id = int(dataset_id)
        # Requested batch sizes below 1,000 are raised to 1,000.
        self.batch_size = max(1_000, int(batch_size))
        self.progress_callback = progress_callback
        self.rows: list[dict[str, object]] = []
        self.address_count = 0
        self.node_address_count = 0
        self.way_address_count = 0
        self.skipped_count = 0
        self.processed_count = 0

    def node(self, node) -> None:
        """Handler callback for nodes; delegates to ``process_node``."""
        self.process_node(node)

    def way(self, way) -> None:
        """Handler callback for ways; delegates to ``process_way``."""
        self.process_way(way)

    def process_object(self, obj) -> None:
        """Dispatch ``obj`` by shape: ``nodes`` means way, ``location`` means node."""
        if hasattr(obj, "nodes"):
            self.process_way(obj)
            return
        if hasattr(obj, "location"):
            self.process_node(obj)

    def process_node(self, node) -> None:
        """Buffer an address row for ``node`` if it carries a usable address."""
        self.processed_count += 1
        tags = {item.k: item.v for item in node.tags}
        if not _has_address(tags):
            return
        if not node.location.valid():
            self.skipped_count += 1
            return
        lon = float(node.location.lon)
        lat = float(node.location.lat)
        row = _address_row(
            dataset_id=self.dataset_id,
            osm_type="node",
            osm_id=str(node.id),
            tags=tags,
            lon=lon,
            lat=lat,
            # A point's bounding box collapses to the point itself.
            bounds=(lon, lat, lon, lat),
            geometry_geojson=None,
        )
        if row is None:
            self.skipped_count += 1
            return
        self.rows.append(row)
        self.node_address_count += 1
        self._after_address()

    def process_way(self, way) -> None:
        """Buffer an address row for ``way`` located at its centroid."""
        self.processed_count += 1
        tags = {item.k: item.v for item in way.tags}
        if not _has_address(tags):
            return
        # Member nodes without a valid location are ignored.
        coords = [
            (float(member.location.lon), float(member.location.lat))
            for member in way.nodes
            if member.location.valid()
        ]
        if not coords:
            self.skipped_count += 1
            return
        lon, lat = _centroid(coords)
        lons = [point[0] for point in coords]
        lats = [point[1] for point in coords]
        row = _address_row(
            dataset_id=self.dataset_id,
            osm_type="way",
            osm_id=str(way.id),
            tags=tags,
            lon=lon,
            lat=lat,
            bounds=(min(lons), min(lats), max(lons), max(lats)),
            geometry_geojson=_address_area_geometry_geojson(coords, closed=_way_is_closed(way)),
        )
        if row is None:
            self.skipped_count += 1
            return
        self.rows.append(row)
        self.way_address_count += 1
        self._after_address()

    def _after_address(self) -> None:
        """Count a buffered address, flush full batches and report progress."""
        self.address_count += 1
        if len(self.rows) >= self.batch_size:
            self.flush()
        # Progress is reported once per 50,000 imported addresses.
        if self.address_count % 50_000 != 0:
            return
        _emit(
            self.progress_callback,
            "address_index_import_batch",
            f"Imported {self.address_count:,} OSM addresses.",
            self.address_count,
            None,
            {"processed": self.processed_count, "skipped": self.skipped_count},
        )

    def flush(self) -> None:
        """Insert and commit all buffered rows; no-op when the buffer is empty."""
        if self.rows:
            self.session.bulk_insert_mappings(OsmAddress, self.rows)
            self.session.commit()
            self.rows = []
|
||||
|
||||
|
||||
def _apply_address_file_processor(handler: _AddressHandler, path: Path) -> None:
    """Feed nodes and ways from ``path`` to ``handler`` via ``osmium.FileProcessor``.

    Only objects passing the key filter on ``addr:housenumber`` /
    ``addr:housename`` are forwarded; locations are enabled on the processor.
    """
    processor = osmium.FileProcessor(str(path), osmium.osm.NODE | osmium.osm.WAY)
    processor = processor.with_locations()
    processor = processor.with_filter(osmium.filter.KeyFilter("addr:housenumber", "addr:housename"))
    for osm_object in processor:
        handler.process_object(osm_object)
|
||||
|
||||
|
||||
def _has_address(tags: dict[str, str]) -> bool:
    """Return whether ``tags`` describe an indexable address.

    Requires a non-blank house number (or house name) plus at least one
    non-blank street, place, city or postcode tag.
    """
    if not _clean(tags.get("addr:housenumber") or tags.get("addr:housename")):
        return False
    for key in ("addr:street", "addr:place", "addr:city", "addr:postcode"):
        if _clean(tags.get(key)):
            return True
    return False
|
||||
|
||||
|
||||
def _address_row(
    *,
    dataset_id: int,
    osm_type: str,
    osm_id: str,
    tags: dict[str, str],
    lon: float,
    lat: float,
    bounds: tuple[float, float, float, float],
    geometry_geojson: str | None = None,
) -> dict[str, object] | None:
    """Build the insert mapping for one address, or ``None`` if it has no display name.

    Address parts are read from the ``addr:*`` tags (and ``name``), cleaned,
    and combined into ``display_name`` and ``search_text``. ``bounds`` is
    ``(min_lon, min_lat, max_lon, max_lat)``. Only tags listed in
    ``ADDRESS_TAGS`` are serialized into ``tags_json``.
    """
    housenumber = _clean(tags.get("addr:housenumber") or tags.get("addr:housename"))
    street = _clean(tags.get("addr:street"))
    place = _clean(tags.get("addr:place"))
    postcode = _clean(tags.get("addr:postcode"))
    city = _clean(tags.get("addr:city") or tags.get("addr:municipality"))
    country = _clean(tags.get("addr:country"))
    unit = _clean(tags.get("addr:unit"))
    name = _clean(tags.get("name"))

    display_name = _display_name(
        housenumber=housenumber,
        street=street,
        place=place,
        postcode=postcode,
        city=city,
        name=name,
    )
    if not display_name:
        return None

    search_text = _search_text(display_name, housenumber, street, place, postcode, city, country, unit, name)
    # Sorted keys keep the serialized tag subset stable between runs.
    selected_tags = {key: tags[key] for key in sorted(ADDRESS_TAGS) if key in tags}
    tags_json = json.dumps(selected_tags, separators=(",", ":")) if selected_tags else None
    min_lon, min_lat, max_lon, max_lat = bounds

    return {
        "dataset_id": dataset_id,
        "osm_type": osm_type,
        "osm_id": osm_id,
        "housenumber": housenumber,
        "street": street,
        "place": place,
        "postcode": postcode,
        "city": city,
        "country": country,
        "unit": unit,
        "name": name,
        "display_name": display_name,
        "search_text": search_text,
        "lon": lon,
        "lat": lat,
        "min_lon": min_lon,
        "min_lat": min_lat,
        "max_lon": max_lon,
        "max_lat": max_lat,
        "geometry_geojson": geometry_geojson,
        "tags_json": tags_json,
    }
|
||||
|
||||
|
||||
def _address_area_geometry_geojson(coords: list[tuple[float, float]], *, closed: bool | None = None) -> str | None:
|
||||
if closed is False:
|
||||
return None
|
||||
if len(coords) < 3:
|
||||
return None
|
||||
ring_coords = list(coords)
|
||||
first = ring_coords[0]
|
||||
last = ring_coords[-1]
|
||||
already_closed = abs(first[0] - last[0]) <= 1e-12 and abs(first[1] - last[1]) <= 1e-12
|
||||
if not already_closed:
|
||||
if closed is not True:
|
||||
return None
|
||||
ring_coords.append(first)
|
||||
if len(ring_coords) < 4:
|
||||
return None
|
||||
ring = [[float(lon), float(lat)] for lon, lat in ring_coords]
|
||||
if len({(round(lon, 12), round(lat, 12)) for lon, lat in ring_coords[:-1]}) < 3:
|
||||
return None
|
||||
return json.dumps({"type": "Polygon", "coordinates": [ring]}, separators=(",", ":"))
|
||||
|
||||
|
||||
def _way_is_closed(way) -> bool:
|
||||
try:
|
||||
nodes = way.nodes
|
||||
return len(nodes) >= 3 and nodes[0].ref == nodes[-1].ref
|
||||
except (AttributeError, IndexError, TypeError):
|
||||
return False
|
||||
|
||||
|
||||
def _display_name(
|
||||
*,
|
||||
housenumber: str | None,
|
||||
street: str | None,
|
||||
place: str | None,
|
||||
postcode: str | None,
|
||||
city: str | None,
|
||||
name: str | None,
|
||||
) -> str | None:
|
||||
road = street or place or name
|
||||
if road and housenumber:
|
||||
first = f"{road} {housenumber}"
|
||||
else:
|
||||
first = road or housenumber
|
||||
locality = " ".join(part for part in [postcode, city] if part)
|
||||
if first and locality:
|
||||
return f"{first}, {locality}"
|
||||
return first or locality
|
||||
|
||||
|
||||
def _search_text(*parts: str | None) -> str:
|
||||
return re.sub(r"\s+", " ", " ".join(part.casefold() for part in parts if part)).strip()
|
||||
|
||||
|
||||
def _clean(value: object) -> str | None:
|
||||
cleaned = re.sub(r"\s+", " ", str(value or "")).strip()
|
||||
return cleaned or None
|
||||
|
||||
|
||||
def _centroid(coords: list[tuple[float, float]]) -> tuple[float, float]:
|
||||
if len(coords) >= 4 and coords[0] == coords[-1]:
|
||||
area = 0.0
|
||||
cx = 0.0
|
||||
cy = 0.0
|
||||
for (x1, y1), (x2, y2) in zip(coords, coords[1:]):
|
||||
cross = x1 * y2 - x2 * y1
|
||||
area += cross
|
||||
cx += (x1 + x2) * cross
|
||||
cy += (y1 + y2) * cross
|
||||
if abs(area) > 1e-18:
|
||||
factor = 1 / (3 * area)
|
||||
return cx * factor, cy * factor
|
||||
return (
|
||||
math.fsum(coord[0] for coord in coords) / len(coords),
|
||||
math.fsum(coord[1] for coord in coords) / len(coords),
|
||||
)
|
||||
|
||||
|
||||
def _drop_address_indexes(session: Session) -> None:
    """Drop every address lookup index, ignoring ones that do not exist."""
    index_names = (
        "ix_osm_addresses_dataset_city_street",
        "ix_osm_addresses_dataset_postcode",
        "ix_osm_addresses_bbox",
        "ix_osm_addresses_geom_gist",
        "ix_osm_addresses_area_geom_gist",
        "ix_osm_addresses_search_trgm",
        "ix_osm_addresses_display_trgm",
        "ix_osm_addresses_street_key_house",
        "ix_osm_addresses_street_key_trgm",
    )
    # The names are fixed literals, so interpolating them into SQL is safe.
    for name in index_names:
        session.execute(text(f"DROP INDEX IF EXISTS {name}"))
|
||||
|
||||
|
||||
def _create_address_indexes(session: Session) -> None:
    """Create the address lookup indexes that do not exist yet.

    The plain column indexes are created on every database; the GiST and
    trigram/expression indexes are added on PostgreSQL only.
    """
    portable_statements = [
        "CREATE INDEX IF NOT EXISTS ix_osm_addresses_dataset_city_street ON osm_addresses (dataset_id, city, street, housenumber)",
        "CREATE INDEX IF NOT EXISTS ix_osm_addresses_dataset_postcode ON osm_addresses (dataset_id, postcode)",
        "CREATE INDEX IF NOT EXISTS ix_osm_addresses_bbox ON osm_addresses (dataset_id, min_lon, max_lon, min_lat, max_lat)",
    ]
    postgresql_statements = [
        "CREATE INDEX IF NOT EXISTS ix_osm_addresses_geom_gist ON osm_addresses USING GIST (geom)",
        "CREATE INDEX IF NOT EXISTS ix_osm_addresses_area_geom_gist ON osm_addresses USING GIST (area_geom)",
        "CREATE INDEX IF NOT EXISTS ix_osm_addresses_search_trgm ON osm_addresses USING GIN (LOWER(COALESCE(search_text, '')) gin_trgm_ops)",
        "CREATE INDEX IF NOT EXISTS ix_osm_addresses_display_trgm ON osm_addresses USING GIN (LOWER(COALESCE(display_name, '')) gin_trgm_ops)",
        "CREATE INDEX IF NOT EXISTS ix_osm_addresses_street_key_house ON osm_addresses (dataset_id, REPLACE(LOWER(COALESCE(NULLIF(street, ''), NULLIF(place, ''), '')), 'ß', 'ss'), housenumber)",
        "CREATE INDEX IF NOT EXISTS ix_osm_addresses_street_key_trgm ON osm_addresses USING GIN (REPLACE(LOWER(COALESCE(NULLIF(street, ''), NULLIF(place, ''), '')), 'ß', 'ss') gin_trgm_ops)",
    ]
    statements = list(portable_statements)
    if settings.is_postgresql_database:
        statements += postgresql_statements
    for statement in statements:
        session.execute(text(statement))
|
||||
|
||||
|
||||
def _metadata(dataset: Dataset) -> dict[str, object]:
|
||||
try:
|
||||
value = json.loads(dataset.metadata_json or "{}")
|
||||
except json.JSONDecodeError:
|
||||
return {}
|
||||
return value if isinstance(value, dict) else {}
|
||||
|
||||
|
||||
def _emit(
|
||||
progress_callback: ProgressCallback | None,
|
||||
event_type: str,
|
||||
message: str,
|
||||
progress_current: int | None,
|
||||
progress_total: int | None,
|
||||
metadata: dict[str, object] | None = None,
|
||||
) -> None:
|
||||
if progress_callback is not None:
|
||||
progress_callback(event_type, message, progress_current, progress_total, metadata)
|
||||
100
app/pipeline/osm_diff.py
Normal file
100
app/pipeline/osm_diff.py
Normal file
@@ -0,0 +1,100 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.config import settings
|
||||
from app.models import Dataset, Source
|
||||
from app.pipeline.download import materialize_source
|
||||
from app.pipeline.osm_pbf import _raw_format
|
||||
from app.pipeline.osm_replication import fetch_replication_state
|
||||
from app.pipeline.utils import sha256_file
|
||||
|
||||
|
||||
def run_osm_diff_source(session: Session, source: Source) -> Dataset:
    """Record an OSM change file as a raw, inactive update artifact.

    A source whose URL looks like an ``-updates`` directory is handled by
    recording the directory's replication state instead. Otherwise the file
    is materialized and hashed; an existing ``osm_diff_raw`` dataset of this
    source with the same hash is returned as-is, and a new one is added only
    when none exists.

    The diff is not applied to any base extract here; that is a separate
    step, and the recorded file is not treated as a complete route layer.
    """
    if _looks_like_update_directory(source.url):
        return _commit_update_directory_state(session, source)

    raw_path = materialize_source(source)
    raw_hash = sha256_file(raw_path)

    # Newest dataset of this source that already holds the same file content.
    duplicate_query = (
        select(Dataset)
        .where(Dataset.source_id == source.id, Dataset.kind == "osm_diff_raw", Dataset.sha256 == raw_hash)
        .order_by(Dataset.id.desc())
    )
    existing = session.scalar(duplicate_query)
    if existing is not None:
        return existing

    metadata = {
        "stage": "raw_osm_diff",
        "raw_format": _raw_format(raw_path),
        "source_url": source.url,
    }
    dataset = Dataset(
        source_id=source.id,
        kind="osm_diff_raw",
        local_path=str(raw_path),
        sha256=raw_hash,
        is_active=False,
        status="committed",
        metadata_json=json.dumps(metadata, indent=2),
    )
    session.add(dataset)
    session.flush()
    return dataset
|
||||
|
||||
|
||||
def _commit_update_directory_state(session: Session, source: Source) -> Dataset:
    """Record the current replication state of an updates directory.

    The fetched state is written as sorted ``key=value`` lines to
    ``state_<sequence>.txt`` in the source's data directory. A matching
    ``osm_diff_state`` dataset (same source and file hash) is reused;
    otherwise a new inactive dataset describing the state is added.
    """
    state = fetch_replication_state(source.url, timeout=settings.osm_diff_state_timeout_seconds)

    source_dir = settings.data_dir / "sources" / f"source_{source.id}"
    source_dir.mkdir(parents=True, exist_ok=True)
    state_path = source_dir / f"state_{state.sequence_number}.txt"
    state_lines = [f"{key}={value}" for key, value in sorted(state.raw.items())]
    state_path.write_text("\n".join(state_lines) + "\n", encoding="utf-8")
    state_hash = sha256_file(state_path)

    duplicate_query = (
        select(Dataset)
        .where(Dataset.source_id == source.id, Dataset.kind == "osm_diff_state", Dataset.sha256 == state_hash)
        .order_by(Dataset.id.desc())
    )
    existing = session.scalar(duplicate_query)
    if existing is not None:
        return existing

    metadata = {
        "stage": "osm_diff_state",
        "updates_url": source.url,
        "sequence_number": state.sequence_number,
        "timestamp": state.timestamp,
        "state": state.raw,
    }
    dataset = Dataset(
        source_id=source.id,
        kind="osm_diff_state",
        local_path=str(state_path),
        sha256=state_hash,
        is_active=False,
        status="committed",
        metadata_json=json.dumps(metadata, indent=2),
    )
    session.add(dataset)
    session.flush()
    return dataset
|
||||
|
||||
|
||||
def _looks_like_update_directory(url: str) -> bool:
|
||||
lower_path = urlparse(url).path.lower()
|
||||
return lower_path.endswith("-updates") or lower_path.endswith("-updates/")
|
||||
248
app/pipeline/osm_geojson.py
Normal file
248
app/pipeline/osm_geojson.py
Normal file
@@ -0,0 +1,248 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.config import settings
|
||||
from app.models import Dataset, OsmFeature, Source
|
||||
from app.osm_classification import infer_osm_route_scope
|
||||
from app.osm_storage import (
|
||||
OSM_STORAGE_METADATA_KEY,
|
||||
OSM_STORAGE_MAIN,
|
||||
OSM_STORAGE_SIDECAR_FEATURES,
|
||||
create_osm_sidecar,
|
||||
dedupe_osm_feature_rows,
|
||||
effective_osm_feature_storage,
|
||||
)
|
||||
from app.pipeline.download import materialize_source
|
||||
from app.pipeline.utils import first_nonempty, geometry_json_and_bbox, norm_ref, norm_text, sha256_file
|
||||
from app.spatial import analyze_postgresql_tables, refresh_postgis_geometries
|
||||
|
||||
# Values of the ``mode`` / ``route`` / ``route_master`` properties that are
# accepted as transport modes when classifying features.
ROUTE_MODES = {
    # rail
    "train", "railway", "light_rail", "subway", "tram", "monorail", "funicular",
    # road
    "bus", "trolleybus", "coach",
    # water and cable
    "ferry", "aerialway",
}
|
||||
|
||||
|
||||
def run_osm_geojson_source(session: Session, source: Source) -> Dataset:
    """Import the GeoJSON file of ``source`` unless the same content is already live.

    The file is materialized and hashed. If the source already has an active
    ``osm_geojson`` dataset with status ``imported`` and the same hash, that
    dataset is returned; otherwise a fresh import is run.
    """
    local_path = materialize_source(source)
    source_hash = sha256_file(local_path)

    unchanged_filters = (
        Dataset.source_id == source.id,
        Dataset.kind == "osm_geojson",
        Dataset.sha256 == source_hash,
        Dataset.is_active.is_(True),
        Dataset.status == "imported",
    )
    existing = session.scalar(select(Dataset).where(*unchanged_filters).order_by(Dataset.id.desc()))
    if existing is None:
        return import_osm_geojson(session=session, source=source, path=local_path, source_hash=source_hash)
    return existing
|
||||
|
||||
|
||||
def import_osm_geojson(
    session: Session,
    source: Source,
    path: Path,
    source_hash: str | None = None,
    *,
    storage_mode: str | None = None,
) -> Dataset:
    """Import the GeoJSON file at ``path`` as the new active dataset of ``source``.

    All existing datasets of the source are deactivated, a new
    ``osm_geojson`` dataset is created and flushed, its features are stored
    via ``prepare_osm_geojson_storage`` and the resulting metadata is saved
    on the dataset. Finally the source is marked ``ok``.

    Args:
        session: Database session; flushed but not committed here.
        source: Source that owns the file.
        path: Local GeoJSON file to import.
        source_hash: SHA-256 of the file if already known; computed otherwise.
        storage_mode: Optional feature storage mode passed through to
            ``prepare_osm_geojson_storage``.

    Returns:
        The newly created dataset with status ``imported``.
    """
    # Hash the file once up front; it is needed both for the dataset row and
    # for the storage step, and re-reading a large file is wasteful.
    content_hash = source_hash or sha256_file(path)

    for previous in source.datasets:
        previous.is_active = False

    dataset = Dataset(
        source_id=source.id,
        kind="osm_geojson",
        local_path=str(path),
        sha256=content_hash,
        is_active=True,
        status="importing",
    )
    session.add(dataset)
    # Flush so the dataset has an id before feature rows reference it.
    session.flush()

    storage_metadata = prepare_osm_geojson_storage(
        session=session,
        dataset=dataset,
        path=path,
        source_hash=content_hash,
        storage_mode=storage_mode,
    )
    dataset.metadata_json = json.dumps(storage_metadata, indent=2)

    dataset.status = "imported"
    source.status = "ok"
    source.last_error = None
    session.flush()
    return dataset
|
||||
|
||||
|
||||
def prepare_osm_geojson_storage(
    *,
    session: Session,
    dataset: Dataset,
    path: Path,
    source_hash: str | None = None,
    storage_mode: str | None = None,
) -> dict[str, object]:
    """Parse the GeoJSON at ``path`` and store its features for ``dataset``.

    Depending on the effective storage mode the feature rows go either into
    a sidecar or into the main ``osm_features`` table (followed by a geometry
    refresh and table analyze).

    Returns:
        Metadata with the feature count and a description of the storage used.

    Raises:
        ValueError: If the effective storage mode is not supported.
    """
    payload = json.loads(path.read_text(encoding="utf-8"))
    feature_rows = [
        _feature_row(dataset.id, position, feature)
        for position, feature in enumerate(_as_features(payload))
    ]

    storage = effective_osm_feature_storage(storage_mode)
    supported_modes = {OSM_STORAGE_MAIN, OSM_STORAGE_SIDECAR_FEATURES}
    if storage not in supported_modes:
        raise ValueError(f"Unsupported OSM feature storage mode: {storage}")

    feature_total = len(feature_rows)
    if storage == OSM_STORAGE_SIDECAR_FEATURES:
        sidecar_metadata = create_osm_sidecar(dataset, feature_rows, source_hash=source_hash or dataset.sha256)
        return {"features": feature_total, OSM_STORAGE_METADATA_KEY: sidecar_metadata}

    _insert_main_features(session, feature_rows)
    session.flush()
    refresh_postgis_geometries(session, dataset_id=dataset.id, tables=["osm_features"])
    analyze_postgresql_tables(session, ["osm_features"])
    return {"features": feature_total, OSM_STORAGE_METADATA_KEY: {"mode": OSM_STORAGE_MAIN}}
|
||||
|
||||
|
||||
def _insert_main_features(session: Session, feature_rows: list[dict[str, object]]) -> None:
    """Save the de-duplicated ``feature_rows`` as ``OsmFeature`` objects in batches of 5000."""
    columns = (
        "dataset_id",
        "osm_type",
        "osm_id",
        "kind",
        "mode",
        "route_scope",
        "name",
        "ref",
        "operator",
        "network",
        "geometry_geojson",
        "min_lon",
        "min_lat",
        "max_lon",
        "max_lat",
        "tags_json",
        "route_key",
        "operator_key",
    )
    unique_rows, _duplicate_count = dedupe_osm_feature_rows(feature_rows)
    pending: list[OsmFeature] = []
    for row in unique_rows:
        pending.append(OsmFeature(**{column: row[column] for column in columns}))
        if len(pending) >= 5000:
            session.bulk_save_objects(pending)
            pending.clear()
    # Save whatever is left over from the last partial batch.
    if pending:
        session.bulk_save_objects(pending)
|
||||
|
||||
|
||||
def _feature_row(dataset_id: int, idx: int, feature: dict[str, Any]) -> dict[str, object]:
    """Convert one GeoJSON feature into a feature row dict.

    Type, id, name, ref, operator and network are taken from the first
    non-empty of several candidate properties; ``idx`` only supplies the
    fallback id ``feature_<idx>``. All properties are kept in ``tags_json``.
    """
    props = feature.get("properties") or {}
    geometry_text, bbox = geometry_json_and_bbox(feature.get("geometry"))
    min_lon, min_lat, max_lon, max_lat = bbox[0], bbox[1], bbox[2], bbox[3]

    osm_type = str(first_nonempty(props.get("osm_type"), props.get("@type"), props.get("type"), "feature"))
    osm_id = str(first_nonempty(props.get("osm_id"), props.get("@id"), props.get("id"), f"feature_{idx}"))

    mode = _infer_mode(props)
    kind = _infer_kind(props, mode)

    name = first_nonempty(props.get("name"), props.get("official_name")) or None
    ref = first_nonempty(props.get("ref"), props.get("route_ref"), props.get("line")) or None
    operator = first_nonempty(props.get("operator"), props.get("agency"), props.get("brand")) or None
    network = first_nonempty(props.get("network"), props.get("network:short")) or None

    route_scope = infer_osm_route_scope(mode=mode, ref=ref, name=name, network=network, tags=props)
    # Matching keys: prefer the ref, then the name, then the OSM id.
    route_key = norm_ref(ref) or norm_text(name) or norm_ref(osm_id)
    operator_key = norm_text(operator or network or "")

    return {
        "dataset_id": dataset_id,
        "osm_type": osm_type,
        "osm_id": osm_id,
        "kind": kind,
        "mode": mode,
        "route_scope": route_scope,
        "name": name,
        "ref": ref,
        "operator": operator,
        "network": network,
        "geometry_geojson": geometry_text,
        "min_lon": min_lon,
        "min_lat": min_lat,
        "max_lon": max_lon,
        "max_lat": max_lat,
        "tags_json": json.dumps(props, separators=(",", ":")),
        "route_key": route_key,
        "operator_key": operator_key,
    }
|
||||
|
||||
|
||||
def _as_features(data: Any) -> list[dict[str, Any]]:
|
||||
if isinstance(data, dict) and data.get("type") == "FeatureCollection":
|
||||
return [f for f in data.get("features", []) if isinstance(f, dict)]
|
||||
if isinstance(data, dict) and data.get("type") == "Feature":
|
||||
return [data]
|
||||
if isinstance(data, list):
|
||||
return [f for f in data if isinstance(f, dict)]
|
||||
raise ValueError("OSM source must be GeoJSON FeatureCollection, Feature, or list of Features")
|
||||
|
||||
|
||||
def _infer_mode(props: dict[str, Any]) -> str | None:
    """Derive the transport mode of a feature from its properties.

    An explicit ``mode`` / ``route`` / ``route_master`` value listed in
    ``ROUTE_MODES`` wins (``railway`` is reported as ``train``). Otherwise
    the mode is guessed from stop/station tags; ``None`` means unknown.
    """

    def tag(key: str) -> str:
        return str(props.get(key) or "")

    for key in ("mode", "route", "route_master"):
        candidate = tag(key).strip()
        if candidate in ROUTE_MODES:
            return "train" if candidate == "railway" else candidate

    railway_modes = {
        "station": "train",
        "halt": "train",
        "tram_stop": "tram",
        "subway_entrance": "subway",
    }
    railway = tag("railway").strip()
    if railway in railway_modes:
        return railway_modes[railway]

    if tag("highway") == "bus_stop" or tag("amenity") == "bus_station":
        return "bus"
    if tag("amenity") == "ferry_terminal":
        return "ferry"
    if tag("aerialway") == "station":
        return "aerialway"
    return None
|
||||
|
||||
|
||||
def _infer_kind(props: dict[str, Any], mode: str | None) -> str:
    """Classify a feature as route, terminal, station, stop, infra or feature.

    A valid explicit ``kind`` property is used as-is. Otherwise the first
    matching rule decides, in this order: route, terminal, station, stop.
    Anything else is ``infra`` when a mode is known and ``feature`` when not.
    """

    def tag(key: str) -> str:
        return str(props.get(key) or "")

    explicit_kind = tag("kind").strip()
    if explicit_kind in {"route", "stop", "station", "terminal", "infra", "feature"}:
        return explicit_kind

    if tag("type") in {"route", "route_master"} or tag("route") in ROUTE_MODES:
        return "route"
    if tag("amenity") in {"ferry_terminal", "bus_station"}:
        return "terminal"
    if tag("railway") in {"station", "halt"} or tag("aerialway") == "station":
        return "station"
    if tag("public_transport") in {"platform", "stop_position", "station"} or tag("highway") == "bus_stop":
        return "stop"
    return "infra" if mode else "feature"
|
||||
456
app/pipeline/osm_labeling.py
Normal file
456
app/pipeline/osm_labeling.py
Normal file
@@ -0,0 +1,456 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
import json
|
||||
from pathlib import Path
|
||||
import sqlite3
|
||||
from typing import Callable
|
||||
|
||||
from sqlalchemy import func, select, text
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models import Dataset, OsmFeature
|
||||
from app.osm_classification import OSM_ROUTE_SCOPE_CLASSIFIER_VERSION, infer_osm_route_scope_from_tags
|
||||
from app.osm_storage import (
|
||||
dataset_metadata,
|
||||
drop_osm_sidecar_route_scope_indexes,
|
||||
ensure_osm_sidecar_schema,
|
||||
features_are_sidecar,
|
||||
rebuild_osm_sidecar_indexes,
|
||||
sidecar_path,
|
||||
writable_sidecar_connection,
|
||||
)
|
||||
from app.pipeline.state import (
|
||||
STAGE_BUILD_INDEXES,
|
||||
STAGE_LABEL_FEATURES,
|
||||
dependency_hash,
|
||||
finish_pipeline_run,
|
||||
latest_completed_run,
|
||||
start_pipeline_run,
|
||||
)
|
||||
|
||||
|
||||
# The labeling stage shares its version with the route-scope classifier.
OSM_LABEL_FEATURES_VERSION = OSM_ROUTE_SCOPE_CLASSIFIER_VERSION

MAIN_ROUTE_SCOPE_INDEX = "ix_osm_features_scope_bbox"

# NOTE(review): presumably the changed-row counts above which indexes are
# dropped and rebuilt around a relabel; confirm against the relabel helpers.
MAIN_INDEX_REBUILD_THRESHOLD = 10_000
SIDECAR_INDEX_REBUILD_THRESHOLD = 10_000

# Progress hook, called as (event_type, message, current, total, metadata).
ProgressCallback = Callable[[str, str, int | None, int | None, dict[str, object] | None], None]
|
||||
|
||||
|
||||
def relabel_osm_features(
    session: Session,
    *,
    dataset_id: int | None = None,
    chunk_size: int = 5000,
    force: bool = False,
    rebuild_indexes: bool = True,
    progress_callback: ProgressCallback | None = None,
    job_id: int | None = None,
) -> dict[str, object]:
    """Relabel the target OSM datasets and return aggregated totals.

    Each dataset selected by ``_target_datasets`` is passed to
    ``relabel_osm_dataset`` with the given options. The result sums the
    processed/changed/index-rebuild counts, counts datasets reported as
    ``skipped`` or ``missing_sidecar``, and keeps every per-dataset result
    under ``dataset_results``. Progress events are emitted at the start,
    after each dataset and at the end.
    """
    datasets = _target_datasets(session, dataset_id)
    dataset_total = len(datasets)
    dataset_results: list[dict[str, object]] = []
    result: dict[str, object] = {
        "version": OSM_LABEL_FEATURES_VERSION,
        "datasets": dataset_total,
        "processed": 0,
        "changed": 0,
        "skipped": 0,
        "missing": 0,
        "index_rebuilds": 0,
        "dataset_results": dataset_results,
    }
    _emit_progress(
        progress_callback,
        "osm_labeling_started",
        f"Relabeling {len(datasets)} OSM dataset(s).",
        0,
        dataset_total,
        {"dataset_id": dataset_id, "force": force, "version": OSM_LABEL_FEATURES_VERSION},
    )

    summed_keys = ("processed", "changed", "index_rebuilds")
    status_counters = {"skipped": "skipped", "missing_sidecar": "missing"}
    for index, dataset in enumerate(datasets, start=1):
        dataset_result = relabel_osm_dataset(
            session,
            dataset,
            chunk_size=chunk_size,
            force=force,
            rebuild_indexes=rebuild_indexes,
            progress_callback=progress_callback,
            job_id=job_id,
        )
        for key in summed_keys:
            result[key] = int(result[key]) + int(dataset_result.get(key, 0) or 0)
        # Skipped and missing-sidecar datasets are tallied per dataset, not per row.
        status = dataset_result.get("status")
        for status_name, counter_key in status_counters.items():
            result[counter_key] = int(result[counter_key]) + (1 if status == status_name else 0)
        dataset_results.append(dataset_result)
        _emit_progress(
            progress_callback,
            "osm_labeling_dataset_completed",
            f"Relabeled {index}/{len(datasets)} OSM dataset(s).",
            index,
            dataset_total,
            dataset_result,
        )

    _emit_progress(
        progress_callback,
        "osm_labeling_completed",
        "OSM feature relabeling completed.",
        dataset_total,
        dataset_total,
        result,
    )
    return result
|
||||
|
||||
|
||||
def relabel_osm_dataset(
    session: Session,
    dataset: Dataset,
    *,
    chunk_size: int = 5000,
    force: bool = False,
    rebuild_indexes: bool = True,
    progress_callback: ProgressCallback | None = None,
    job_id: int | None = None,
) -> dict[str, object]:
    """Recompute ``route_scope`` labels for one OSM dataset and record the run.

    The work is skipped when ``force`` is false and the dataset is already
    labeled for the current dependency hash. Otherwise a pipeline run is
    started and committed, the features are relabeled (sidecar or main-table
    storage, as reported by ``features_are_sidecar``), the dataset metadata is
    stamped, and the run is finished.

    Args:
        session: ORM session used for pipeline bookkeeping and main-table work.
        dataset: Dataset whose features are relabeled.
        chunk_size: Number of features handled per batch.
        force: Relabel even when the stored label information looks current.
        rebuild_indexes: Allow dropping/rebuilding route-scope indexes.
        progress_callback: Optional progress sink forwarded to the workers.
        job_id: Optional job identifier attached to the pipeline run.

    Returns:
        A result dict whose ``status`` is ``"skipped"``, ``"completed"`` or
        ``"missing_sidecar"``, together with counters.

    Raises:
        Exception: Any error other than ``FileNotFoundError`` is re-raised
            after the pipeline run has been marked as failed.
    """
    dependency = _label_dependency(dataset)
    dependency_hash_value = dependency_hash(dependency)
    if not force and _dataset_label_is_current(session, dataset, dependency_hash_value):
        return {
            "dataset_id": dataset.id,
            "source_id": dataset.source_id,
            "status": "skipped",
            "reason": "label_features dependency is current",
            "dependency_hash": dependency_hash_value,
            "version": OSM_LABEL_FEATURES_VERSION,
            "processed": 0,
            "changed": 0,
            "index_rebuilds": 0,
        }

    run = start_pipeline_run(
        session,
        stage=STAGE_LABEL_FEATURES,
        version=OSM_LABEL_FEATURES_VERSION,
        dependency_hash_value=dependency_hash_value,
        source_id=dataset.source_id,
        dataset_id=dataset.id,
        job_id=job_id,
        inputs=dependency,
    )
    # Commit right away so the run row survives a later rollback.
    session.commit()
    try:
        if features_are_sidecar(dataset):
            counts = _relabel_sidecar_dataset(dataset, chunk_size=chunk_size, rebuild_indexes=rebuild_indexes, progress_callback=progress_callback)
        else:
            counts = _relabel_main_dataset(session, dataset, chunk_size=chunk_size, rebuild_indexes=rebuild_indexes, progress_callback=progress_callback)
        output = {
            "dataset_id": dataset.id,
            "source_id": dataset.source_id,
            "status": "completed",
            "dependency_hash": dependency_hash_value,
            "version": OSM_LABEL_FEATURES_VERSION,
            **counts,
        }
        _stamp_dataset_metadata(session, dataset, dependency_hash_value, output)
        finish_pipeline_run(session, run, outputs=output)
        session.commit()
        return output
    except FileNotFoundError as exc:
        # Discard any half-finished unit of work before writing the failure.
        session.rollback()
        output = {
            "dataset_id": dataset.id,
            "source_id": dataset.source_id,
            "status": "missing_sidecar",
            "dependency_hash": dependency_hash_value,
            "version": OSM_LABEL_FEATURES_VERSION,
            "processed": 0,
            "changed": 0,
            "index_rebuilds": 0,
            "error": str(exc),
        }
        finish_pipeline_run(session, run, status="failed", outputs=output, error=str(exc))
        session.commit()
        return output
    except Exception as exc:
        # A database error leaves the session unusable until it is rolled back;
        # without this the failure bookkeeping below would raise a second error
        # and hide the original one. It also prevents committing a partially
        # applied metadata stamp for a run that did not finish.
        session.rollback()
        finish_pipeline_run(session, run, status="failed", error=str(exc))
        session.commit()
        raise
|
||||
|
||||
|
||||
def _target_datasets(session: Session, dataset_id: int | None) -> list[Dataset]:
    """Select the imported ``osm_geojson`` datasets that should be relabeled.

    With no ``dataset_id`` every active dataset qualifies; otherwise only the
    dataset with that id (active or not) is considered. Results are ordered by
    source id, then dataset id.
    """
    if dataset_id is None:
        selector = Dataset.is_active.is_(True)
    else:
        selector = Dataset.id == dataset_id
    query = (
        select(Dataset)
        .where(Dataset.kind == "osm_geojson", Dataset.status == "imported")
        .where(selector)
        .order_by(Dataset.source_id, Dataset.id)
    )
    return session.scalars(query).all()
|
||||
|
||||
|
||||
def _dataset_label_is_current(session: Session, dataset: Dataset, dependency_hash_value: str) -> bool:
    """Tell whether ``dataset`` is already labeled for this dependency hash.

    Both conditions must hold: the dataset metadata carries a
    ``label_features`` entry matching the current version and hash, and a
    completed pipeline run exists for the same stage, version, hash, source
    and dataset.
    """
    label_info = dataset_metadata(dataset).get("label_features")
    if not isinstance(label_info, dict):
        return False
    if label_info.get("version") != OSM_LABEL_FEATURES_VERSION:
        return False
    if label_info.get("dependency_hash") != dependency_hash_value:
        return False
    # Metadata alone is not trusted; a completed run must back it up.
    completed_run = latest_completed_run(
        session,
        stage=STAGE_LABEL_FEATURES,
        version=OSM_LABEL_FEATURES_VERSION,
        dependency_hash_value=dependency_hash_value,
        source_id=dataset.source_id,
        dataset_id=dataset.id,
    )
    return completed_run is not None
|
||||
|
||||
|
||||
def _relabel_sidecar_dataset(
    dataset: Dataset,
    *,
    chunk_size: int,
    rebuild_indexes: bool,
    progress_callback: ProgressCallback | None,
) -> dict[str, int | str]:
    """Recompute ``route_scope`` for every feature stored in the dataset sidecar.

    Rows of ``osm_features`` are walked in ascending ``id`` order in batches of
    ``chunk_size`` (at least 1); only rows whose normalized scope differs from
    the freshly classified one are updated, and each batch is committed.

    When ``rebuild_indexes`` is set and the table holds at least
    ``SIDECAR_INDEX_REBUILD_THRESHOLD`` rows, the route-scope indexes are
    dropped first and rebuilt afterwards. The rebuild runs in ``finally`` so
    the indexes are restored even if relabeling fails part-way.

    The ``label_features`` record is written only after the whole table was
    processed, so an interrupted run is not recorded as labeled.

    Returns:
        ``{"storage": "sidecar", "processed", "changed", "index_rebuilds"}``.

    Raises:
        FileNotFoundError: The sidecar path is unknown or does not exist.
    """
    path = sidecar_path(dataset)
    if path is None or not path.exists():
        raise FileNotFoundError(f"OSM sidecar does not exist: {path}")
    batch_limit = max(1, int(chunk_size))
    with writable_sidecar_connection(dataset) as connection:
        ensure_osm_sidecar_schema(connection)
        total = int(connection.execute("SELECT COUNT(*) FROM osm_features").fetchone()[0] or 0)
        should_rebuild_index = rebuild_indexes and total >= SIDECAR_INDEX_REBUILD_THRESHOLD
        if should_rebuild_index:
            drop_osm_sidecar_route_scope_indexes(connection)
            connection.commit()
        processed = 0
        changed = 0
        last_id = 0
        index_rebuilds = 0
        try:
            while True:
                # Keyset pagination: resume strictly after the last id seen.
                rows = connection.execute(
                    """
                    SELECT id, mode, ref, name, network, tags_json, route_scope
                    FROM osm_features
                    WHERE id > ?
                    ORDER BY id
                    LIMIT ?
                    """,
                    (last_id, batch_limit),
                ).fetchall()
                if not rows:
                    break
                updates: list[tuple[str | None, int]] = []
                for row in rows:
                    last_id = int(row["id"])
                    new_scope = _classified_scope(row["mode"], row["ref"], row["name"], row["network"], row["tags_json"])
                    if _normalize_scope(row["route_scope"]) != new_scope:
                        updates.append((new_scope, last_id))
                if updates:
                    connection.executemany("UPDATE osm_features SET route_scope = ? WHERE id = ?", updates)
                processed += len(rows)
                changed += len(updates)
                connection.commit()
                _emit_progress(
                    progress_callback,
                    "osm_labeling_batch",
                    f"Relabeled {processed}/{total} OSM sidecar features.",
                    processed,
                    total,
                    {"dataset_id": dataset.id, "changed": changed, "storage": "sidecar"},
                )
        finally:
            # Always restore the indexes that were dropped above.
            if should_rebuild_index:
                rebuild_osm_sidecar_indexes(connection)
                connection.commit()
                index_rebuilds = 1
                _record_sidecar_index_build(connection, dataset, path)
                connection.commit()
        # Reached only when every batch succeeded.
        _record_sidecar_label(connection, dataset, processed=processed, changed=changed)
        connection.commit()
    return {"storage": "sidecar", "processed": processed, "changed": changed, "index_rebuilds": index_rebuilds}
|
||||
|
||||
|
||||
def _relabel_main_dataset(
    session: Session,
    dataset: Dataset,
    *,
    chunk_size: int,
    rebuild_indexes: bool,
    progress_callback: ProgressCallback | None,
) -> dict[str, int | str]:
    """Recompute ``route_scope`` for the dataset's features in the main table.

    Features are read in ascending ``id`` order in batches of ``chunk_size``
    (at least 1); only rows whose normalized scope differs from the freshly
    classified one are updated, and each batch is committed.

    When ``rebuild_indexes`` is set and the dataset holds at least
    ``MAIN_INDEX_REBUILD_THRESHOLD`` features, the route-scope index is dropped
    first and recreated in ``finally`` so it is restored even on failure.

    Returns:
        ``{"storage": "main", "processed", "changed", "index_rebuilds"}``.
    """
    total = int(session.scalar(select(func.count()).select_from(OsmFeature).where(OsmFeature.dataset_id == dataset.id)) or 0)
    should_rebuild_index = rebuild_indexes and total >= MAIN_INDEX_REBUILD_THRESHOLD
    index_rebuilds = 0
    if should_rebuild_index:
        session.execute(text(f"DROP INDEX IF EXISTS {MAIN_ROUTE_SCOPE_INDEX}"))
        session.commit()
    batch_limit = max(1, int(chunk_size))
    processed = 0
    changed = 0
    last_id = 0
    try:
        while True:
            # Keyset pagination: resume strictly after the last id seen.
            rows = session.scalars(
                select(OsmFeature)
                .where(OsmFeature.dataset_id == dataset.id, OsmFeature.id > last_id)
                .order_by(OsmFeature.id)
                .limit(batch_limit)
            ).all()
            if not rows:
                break
            updates: list[dict[str, object]] = []
            for feature in rows:
                last_id = int(feature.id)
                new_scope = _classified_scope(feature.mode, feature.ref, feature.name, feature.network, feature.tags_json)
                if _normalize_scope(feature.route_scope) != new_scope:
                    updates.append({"id": feature.id, "route_scope": new_scope})
            if updates:
                session.bulk_update_mappings(OsmFeature, updates)
            processed += len(rows)
            changed += len(updates)
            session.commit()
            _emit_progress(
                progress_callback,
                "osm_labeling_batch",
                f"Relabeled {processed}/{total} main-table OSM features.",
                processed,
                total,
                {"dataset_id": dataset.id, "changed": changed, "storage": "main"},
            )
    except Exception:
        # Reset the session so the index rebuild in ``finally`` can still run;
        # otherwise a database error would surface there as a second failure
        # and the dropped index would never be recreated.
        session.rollback()
        raise
    finally:
        if should_rebuild_index:
            # NOTE(review): the index name is spelled out here while the DROP
            # above uses MAIN_ROUTE_SCOPE_INDEX - confirm both refer to the
            # same index.
            session.execute(
                text(
                    "CREATE INDEX IF NOT EXISTS ix_osm_features_scope_bbox "
                    "ON osm_features (dataset_id, kind, mode, route_scope, min_lon, max_lon, min_lat, max_lat)"
                )
            )
            session.commit()
            index_rebuilds = 1
            _record_main_index_build(session, dataset)
    return {"storage": "main", "processed": processed, "changed": changed, "index_rebuilds": index_rebuilds}
|
||||
|
||||
|
||||
def _classified_scope(mode: object, ref: object, name: object, network: object, tags_json: object) -> str | None:
    """Classify a feature's route scope from its raw column values.

    Each non-``None`` value is converted with ``str`` before being handed to
    ``infer_osm_route_scope_from_tags``; the answer is normalized so blank
    results become ``None``.
    """

    def _as_text(value: object) -> str | None:
        return None if value is None else str(value)

    inferred = infer_osm_route_scope_from_tags(
        _as_text(mode),
        _as_text(ref),
        _as_text(name),
        _as_text(network),
        _as_text(tags_json),
    )
    return _normalize_scope(inferred)
|
||||
|
||||
|
||||
def _normalize_scope(value: object) -> str | None:
|
||||
text_value = str(value or "").strip()
|
||||
return text_value or None
|
||||
|
||||
|
||||
def _label_dependency(dataset: Dataset) -> dict[str, object]:
    """Describe the inputs the labeling stage depends on for ``dataset``.

    The description covers the dataset identity and checksum, its
    ``osm_storage`` metadata entry, the classifier version, and - when a
    sidecar path is known - that path together with a flag saying whether the
    file currently exists or is missing.
    """
    metadata = dataset_metadata(dataset)
    storage = metadata.get("osm_storage") if isinstance(metadata, dict) else None
    path = sidecar_path(dataset)
    sidecar_info: dict[str, object] | None = None
    if path is not None:
        resolved = Path(path)
        # Only presence is captured; file size or mtime are not part of it.
        presence_key = "exists" if resolved.exists() else "missing"
        sidecar_info = {"path": str(resolved), presence_key: True}
    return {
        "dataset_id": dataset.id,
        "source_id": dataset.source_id,
        "kind": dataset.kind,
        "dataset_sha256": dataset.sha256,
        "storage": storage,
        "sidecar": sidecar_info,
        "classifier_version": OSM_LABEL_FEATURES_VERSION,
    }
|
||||
|
||||
|
||||
def _stamp_dataset_metadata(session: Session, dataset: Dataset, dependency_hash_value: str, output: dict[str, object]) -> None:
    """Write the ``label_features`` entry into the dataset's metadata JSON.

    The dataset is re-fetched through the session by id; nothing happens when
    it can no longer be found. The change is flushed, not committed.
    """
    current = session.get(Dataset, dataset.id)
    if current is None:
        return
    stamp = {
        "stage": STAGE_LABEL_FEATURES,
        "version": OSM_LABEL_FEATURES_VERSION,
        "dependency_hash": dependency_hash_value,
        "labeled_at": datetime.now(timezone.utc).isoformat(),
        "processed": output.get("processed", 0),
        "changed": output.get("changed", 0),
        "storage": output.get("storage"),
    }
    metadata = dataset_metadata(current)
    metadata["label_features"] = stamp
    current.metadata_json = json.dumps(metadata, indent=2)
    session.flush()
|
||||
|
||||
|
||||
def _record_sidecar_label(connection: sqlite3.Connection, dataset: Dataset, *, processed: int, changed: int) -> None:
    """Store the outcome of a labeling pass in the sidecar's ``pipeline_metadata``.

    The table is created on demand and the ``label_features`` key is
    overwritten with a compact JSON document. No commit is issued here.
    """
    connection.execute("CREATE TABLE IF NOT EXISTS pipeline_metadata (key TEXT PRIMARY KEY, value TEXT NOT NULL)")
    payload = {
        "stage": STAGE_LABEL_FEATURES,
        "version": OSM_LABEL_FEATURES_VERSION,
        "dataset_id": dataset.id,
        "processed": processed,
        "changed": changed,
        "updated_at": datetime.now(timezone.utc).isoformat(),
    }
    encoded = json.dumps(payload, sort_keys=True, separators=(",", ":"))
    connection.execute(
        "INSERT OR REPLACE INTO pipeline_metadata (key, value) VALUES (?, ?)",
        ("label_features", encoded),
    )
|
||||
|
||||
|
||||
def _record_sidecar_index_build(connection: sqlite3.Connection, dataset: Dataset, path: Path) -> None:
    """Note in the sidecar's ``pipeline_metadata`` that its indexes were rebuilt.

    The table is created on demand and the ``build_indexes:route_scope`` key
    is overwritten with a compact JSON document. No commit is issued here.
    """
    connection.execute("CREATE TABLE IF NOT EXISTS pipeline_metadata (key TEXT PRIMARY KEY, value TEXT NOT NULL)")
    payload = {
        "stage": STAGE_BUILD_INDEXES,
        "version": "osm_sidecar_indexes_v1",
        "dataset_id": dataset.id,
        "path": str(path),
        "updated_at": datetime.now(timezone.utc).isoformat(),
    }
    encoded = json.dumps(payload, sort_keys=True, separators=(",", ":"))
    connection.execute(
        "INSERT OR REPLACE INTO pipeline_metadata (key, value) VALUES (?, ?)",
        ("build_indexes:route_scope", encoded),
    )
|
||||
|
||||
|
||||
def _record_main_index_build(session: Session, dataset: Dataset) -> None:
    """Log a finished ``build_indexes`` pipeline run for the main-table index.

    A run is started and finished in one go, then committed.
    """
    index_version = "osm_main_indexes_v1"
    dependency = {
        "dataset_id": dataset.id,
        "index": MAIN_ROUTE_SCOPE_INDEX,
        "version": index_version,
    }
    run = start_pipeline_run(
        session,
        stage=STAGE_BUILD_INDEXES,
        version=index_version,
        dependency_hash_value=dependency_hash(dependency),
        source_id=dataset.source_id,
        dataset_id=dataset.id,
        inputs=dependency,
    )
    outputs = {"index": MAIN_ROUTE_SCOPE_INDEX}
    finish_pipeline_run(session, run, outputs=outputs)
    session.commit()
|
||||
|
||||
|
||||
def _emit_progress(
|
||||
callback: ProgressCallback | None,
|
||||
event_type: str,
|
||||
message: str,
|
||||
current: int | None,
|
||||
total: int | None,
|
||||
metadata: dict[str, object] | None,
|
||||
) -> None:
|
||||
if callback is not None:
|
||||
callback(event_type, message, current, total, metadata)
|
||||
1581
app/pipeline/osm_pbf.py
Normal file
1581
app/pipeline/osm_pbf.py
Normal file
File diff suppressed because it is too large
Load Diff
105
app/pipeline/osm_replication.py
Normal file
105
app/pipeline/osm_replication.py
Normal file
@@ -0,0 +1,105 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
import subprocess
|
||||
from urllib.parse import urljoin, urlparse
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class ReplicationState:
    """Immutable snapshot of a replication ``state.txt`` document."""

    # Value of the ``sequenceNumber`` entry, parsed as an integer.
    sequence_number: int
    # Value of the ``timestamp`` entry, or None when the entry is absent.
    timestamp: str | None
    # Every ``key=value`` entry found in the document, after unescaping.
    raw: dict[str, str]
|
||||
|
||||
|
||||
def fetch_replication_state(updates_url: str, *, timeout: float = 30) -> ReplicationState:
    """Download and parse ``state.txt`` from the replication directory.

    Raises whatever ``requests`` raises for transport problems or non-success
    status codes, and ``ValueError`` when the document cannot be parsed.
    """
    response = requests.get(_state_url(updates_url), timeout=timeout)
    response.raise_for_status()
    body = response.text
    return parse_replication_state_text(body)
|
||||
|
||||
|
||||
def parse_replication_state_text(text: str) -> ReplicationState:
    """Parse the ``key=value`` lines of a replication state document.

    Blank lines, lines starting with ``#`` and lines without ``=`` are
    ignored. Keys and values are stripped and values are unescaped.

    Raises:
        ValueError: ``sequenceNumber`` is missing or is not an integer.
    """
    values: dict[str, str] = {}
    for raw_line in text.splitlines():
        entry = raw_line.strip()
        if entry.startswith("#"):
            continue
        key, separator, value = entry.partition("=")
        if not separator:
            # Covers both empty lines and lines lacking "=".
            continue
        values[key.strip()] = _unescape_state_value(value.strip())

    sequence = values.get("sequenceNumber")
    if sequence is None:
        raise ValueError("replication state is missing sequenceNumber")
    try:
        sequence_number = int(sequence)
    except ValueError as exc:
        raise ValueError(f"invalid replication sequenceNumber: {sequence}") from exc
    return ReplicationState(sequence_number=sequence_number, timestamp=values.get("timestamp"), raw=values)
|
||||
|
||||
|
||||
def diff_url_for_sequence(updates_url: str, sequence_number: int) -> str:
    """Build the ``.osc.gz`` URL for a replication sequence number.

    The number is zero-padded to at least nine digits (and always to a
    multiple of three), then split into three-digit path segments, e.g.
    ``1234567`` becomes ``001/234/567.osc.gz`` below the updates directory.
    """
    digits = str(sequence_number)
    group_count = (len(digits) + 2) // 3
    width = max(9, group_count * 3)
    padded = digits.zfill(width)
    segments = []
    for start in range(0, len(padded), 3):
        segments.append(padded[start : start + 3])
    relative = "/".join(segments) + ".osc.gz"
    return urljoin(_directory_url(updates_url), relative)
|
||||
|
||||
|
||||
def download_diff(updates_url: str, sequence_number: int, output_dir: Path, *, timeout: float = 120) -> Path:
    """Fetch the change file for ``sequence_number`` into ``output_dir``.

    An already present file is reused: first ``output_dir/<name>`` is checked,
    then ``output_dir/<last remote directory>/<name>``. Otherwise the file is
    streamed to ``<sequence_number>.download`` and renamed into place once the
    transfer has finished.

    Returns:
        Path of the existing or freshly downloaded file.
    """
    url = diff_url_for_sequence(updates_url, sequence_number)
    remote_path = Path(urlparse(url).path)
    filename = remote_path.name
    flat_target = output_dir / filename
    nested_target = output_dir / remote_path.parent.name / filename
    for cached in (flat_target, nested_target):
        if cached.exists():
            return cached

    output_dir.mkdir(parents=True, exist_ok=True)
    partial = output_dir / f"{sequence_number}.download"
    with requests.get(url, stream=True, timeout=timeout) as response:
        response.raise_for_status()
        with partial.open("wb") as handle:
            for chunk in response.iter_content(chunk_size=1024 * 1024):
                if not chunk:
                    continue
                handle.write(chunk)
    # Rename only after the body was written completely.
    partial.replace(flat_target)
    return flat_target
|
||||
|
||||
|
||||
def apply_osm_changes(base_path: Path, diff_paths: list[Path], output_path: Path, host_tool_path: Path) -> subprocess.CompletedProcess[str]:
    """Run ``osmium apply-changes`` through the host tool wrapper.

    The output's parent directory is created when needed and an existing
    output file is overwritten.

    Raises:
        ValueError: ``diff_paths`` is empty.
        subprocess.CalledProcessError: The tool exited with a non-zero status.
    """
    if len(diff_paths) == 0:
        raise ValueError("no OSM change files supplied")
    output_path.parent.mkdir(parents=True, exist_ok=True)
    command = [str(host_tool_path), "osmium", "apply-changes", "--output", str(output_path), "--overwrite", str(base_path)]
    command.extend(str(diff) for diff in diff_paths)
    return subprocess.run(command, check=True, capture_output=True, text=True)
|
||||
|
||||
|
||||
def _state_url(updates_url: str) -> str:
    """Return the URL of ``state.txt`` inside the updates directory."""
    directory = _directory_url(updates_url)
    return urljoin(directory, "state.txt")
|
||||
|
||||
|
||||
def _directory_url(url: str) -> str:
|
||||
return url if url.endswith("/") else f"{url}/"
|
||||
|
||||
|
||||
def _unescape_state_value(value: str) -> str:
|
||||
return (
|
||||
value.replace("\\:", ":")
|
||||
.replace("\\=", "=")
|
||||
.replace("\\ ", " ")
|
||||
.replace("\\\\", "\\")
|
||||
)
|
||||
1903
app/pipeline/route_layer.py
Normal file
1903
app/pipeline/route_layer.py
Normal file
File diff suppressed because it is too large
Load Diff
473
app/pipeline/routing_layer.py
Normal file
473
app/pipeline/routing_layer.py
Normal file
@@ -0,0 +1,473 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
import osmium
|
||||
from sqlalchemy import delete, func, select, text
|
||||
from sqlalchemy.dialects.postgresql import insert as postgresql_insert
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.config import settings
|
||||
from app.models import Dataset, RoutingEdge, RoutingNode
|
||||
from app.spatial import analyze_postgresql_tables, refresh_postgis_geometries
|
||||
|
||||
|
||||
# Signature of the optional progress sink: five positional values.
# NOTE(review): presumably (event type, message, current, total, metadata) - confirm against _emit.
ProgressCallback = Callable[[str, str, int | None, int | None, dict[str, object] | None], None]
# Version string written into dataset metadata and import results.
ROUTING_LAYER_VERSION = "routing_layer_v2_osm_highway_segments_service_tags"

# ``highway`` values that _drivable accepts as candidates for driving.
DRIVE_HIGHWAYS = {
    "motorway",
    "motorway_link",
    "trunk",
    "trunk_link",
    "primary",
    "primary_link",
    "secondary",
    "secondary_link",
    "tertiary",
    "tertiary_link",
    "unclassified",
    "residential",
    "living_street",
    "service",
    "road",
    "track",
}
# ``highway`` values that _walkable accepts as candidates for walking.
WALK_HIGHWAYS = {
    "pedestrian",
    "footway",
    "path",
    "steps",
    "cycleway",
    "bridleway",
    "living_street",
    "residential",
    "service",
    "track",
    "unclassified",
    "tertiary",
    "tertiary_link",
    "secondary",
    "secondary_link",
    "primary",
    "primary_link",
    "road",
}
# ``highway`` values that are never imported, regardless of access tags.
EXCLUDED_HIGHWAYS = {"construction", "proposed", "abandoned", "platform", "raceway"}
# Access tag values treated as "not allowed" / "allowed" by the access checks.
NO_VALUES = {"no", "private", "agricultural", "forestry", "delivery", "customers"}
YES_VALUES = {"yes", "designated", "permissive", "destination"}
# ``oneway`` tag values meaning travel along / against the way's node order.
ONEWAY_FORWARD = {"yes", "true", "1"}
ONEWAY_REVERSE = {"-1", "reverse"}
# Fallback driving speed per ``highway`` value, in km/h (converted with /3.6
# in _drive_speed_mps); used when no usable ``maxspeed`` is available.
DEFAULT_DRIVE_SPEED_KMH = {
    "motorway": 110,
    "motorway_link": 50,
    "trunk": 90,
    "trunk_link": 45,
    "primary": 70,
    "primary_link": 40,
    "secondary": 60,
    "secondary_link": 35,
    "tertiary": 50,
    "tertiary_link": 30,
    "unclassified": 40,
    "residential": 30,
    "living_street": 10,
    "service": 15,
    "road": 30,
    "track": 15,
}
# Walking speeds in metres per second; the slower one applies to ``steps``.
DEFAULT_WALK_SPEED_MPS = 1.35
STEP_WALK_SPEED_MPS = 0.65
|
||||
|
||||
|
||||
@dataclass
class RoutingImportResult:
    """Counters summarising one routing graph import."""

    dataset_id: int
    input_path: str
    nodes: int
    edges: int
    walk_edges: int
    drive_edges: int
    skipped_ways: int
    version: str = ROUTING_LAYER_VERSION

    def as_dict(self) -> dict[str, object]:
        """Return the result as a plain dict, ``version`` first."""
        payload: dict[str, object] = {"version": self.version}
        for field_name in (
            "dataset_id",
            "input_path",
            "nodes",
            "edges",
            "walk_edges",
            "drive_edges",
            "skipped_ways",
        ):
            payload[field_name] = getattr(self, field_name)
        return payload
|
||||
|
||||
|
||||
def active_routing_dataset(session: Session) -> Dataset | None:
    """Pick the raw PBF dataset that should feed the routing import.

    Preference goes to the raw dataset referenced by ``raw_dataset_id`` in the
    metadata of the newest active ``osm_geojson`` dataset, provided its local
    file exists. Failing that, the ``osm_pbf_raw`` dataset ranked first by
    active flag, then by highest id, is returned (possibly ``None``).
    """
    newest_active = select(Dataset).where(Dataset.kind == "osm_geojson", Dataset.is_active.is_(True)).order_by(Dataset.id.desc())
    active_osm = session.scalar(newest_active)
    raw_dataset_id = None if active_osm is None else _metadata(active_osm).get("raw_dataset_id")
    if raw_dataset_id is not None:
        candidate = session.get(Dataset, int(raw_dataset_id))
        # NOTE(review): assumes local_path is set on raw datasets - confirm.
        if candidate is not None and Path(candidate.local_path).exists():
            return candidate

    fallback = (
        select(Dataset)
        .where(Dataset.kind == "osm_pbf_raw")
        .order_by(Dataset.is_active.desc(), Dataset.id.desc())
    )
    return session.scalar(fallback)
|
||||
|
||||
|
||||
def rebuild_routing_layer(
    session: Session,
    *,
    dataset_id: int | None = None,
    input_path: str | Path | None = None,
    reset: bool = True,
    batch_size: int = 5000,
    progress_callback: ProgressCallback | None = None,
) -> dict[str, object]:
    """Import the routable highway graph from an OSM PBF file.

    The dataset is looked up by ``dataset_id`` or, without one, chosen by
    ``active_routing_dataset``. With ``reset`` the dataset's existing routing
    edges and nodes are deleted first. The file is then streamed through the
    graph handler and the import is finalized.

    Raises:
        RuntimeError: The configured database is not PostgreSQL.
        ValueError: No dataset could be determined.
        FileNotFoundError: The PBF file does not exist.
    """
    if not settings.is_postgresql_database:
        raise RuntimeError("The routing layer importer requires PostgreSQL/PostGIS.")
    if dataset_id is not None:
        dataset = session.get(Dataset, dataset_id)
    else:
        dataset = active_routing_dataset(session)
    if dataset is None:
        raise ValueError("No OSM PBF dataset is available for routing import.")
    path = Path(input_path or dataset.local_path)
    if not path.exists():
        raise FileNotFoundError(f"Routing import PBF does not exist: {path}")

    if reset:
        _emit(progress_callback, "routing_layer_clear_started", "Clearing existing routing graph.", None, None, {"dataset_id": dataset.id})
        # Edges go first, then the nodes of the same dataset.
        for model in (RoutingEdge, RoutingNode):
            session.execute(delete(model).where(model.dataset_id == dataset.id))
        session.commit()

    start_details = {"dataset_id": dataset.id, "path": str(path)}
    _emit(progress_callback, "routing_layer_import_started", "Importing routable OSM highway graph.", None, None, start_details)
    handler = _RoutingGraphHandler(
        session=session,
        dataset_id=dataset.id,
        batch_size=batch_size,
        progress_callback=progress_callback,
    )
    handler.apply_file(str(path), locations=True)
    # Write whatever is still buffered after the last full batch.
    handler.flush()

    return finalize_routing_layer(
        session,
        dataset_id=dataset.id,
        input_path=str(path),
        skipped_way_count=handler.skipped_way_count,
        progress_callback=progress_callback,
    )
|
||||
|
||||
|
||||
def finalize_routing_layer(
    session: Session,
    *,
    dataset_id: int | None = None,
    input_path: str | Path | None = None,
    skipped_way_count: int = 0,
    progress_callback: ProgressCallback | None = None,
) -> dict[str, object]:
    """Refresh geometries and indexes after an import and record the totals.

    Steps: drop the routing geometry indexes, refresh the PostGIS geometries
    of ``routing_nodes``, recreate the indexes, analyze both routing tables,
    count nodes and edges for the dataset, and store the figures under
    ``routing_layer`` in the dataset metadata.

    Raises:
        RuntimeError: The configured database is not PostgreSQL.
        ValueError: No dataset could be determined.
    """
    if not settings.is_postgresql_database:
        raise RuntimeError("The routing layer finalizer requires PostgreSQL/PostGIS.")
    if dataset_id is not None:
        dataset = session.get(Dataset, dataset_id)
    else:
        dataset = active_routing_dataset(session)
    if dataset is None:
        raise ValueError("No routing dataset is available to finalize.")
    path = Path(input_path or dataset.local_path)

    def announce(event_type: str, message: str) -> None:
        # A fresh metadata dict per event, so callbacks never share one.
        _emit(progress_callback, event_type, message, None, None, {"dataset_id": dataset.id})

    def count_rows(model, *extra_criteria) -> int:
        stmt = select(func.count()).select_from(model).where(model.dataset_id == dataset.id, *extra_criteria)
        return int(session.scalar(stmt) or 0)

    announce("routing_layer_geometry_indexes_dropped", "Dropping routing geometry indexes before bulk refresh.")
    _drop_routing_geometry_indexes(session)
    session.commit()

    announce("routing_layer_geometry_started", "Refreshing routing node PostGIS geometries.")
    refresh_postgis_geometries(session, dataset_id=dataset.id, tables=["routing_nodes"], only_missing=False)
    session.commit()

    announce("routing_layer_geometry_indexes_started", "Rebuilding routing geometry indexes.")
    _create_routing_geometry_indexes(session)
    session.commit()

    analyze_postgresql_tables(session, ["routing_nodes", "routing_edges"])
    node_count = count_rows(RoutingNode)
    edge_count = count_rows(RoutingEdge)
    walk_edge_count = count_rows(RoutingEdge, RoutingEdge.walk_cost_s.is_not(None))
    drive_edge_count = count_rows(RoutingEdge, RoutingEdge.drive_cost_s.is_not(None))

    dataset_metadata = _metadata(dataset)
    dataset_metadata["routing_layer"] = {
        "version": ROUTING_LAYER_VERSION,
        "nodes": node_count,
        "edges": edge_count,
        "walk_edges": walk_edge_count,
        "drive_edges": drive_edge_count,
        "input_path": str(path),
    }
    dataset.metadata_json = json.dumps(dataset_metadata, indent=2)
    session.commit()

    summary = RoutingImportResult(
        dataset_id=dataset.id,
        input_path=str(path),
        nodes=node_count,
        edges=edge_count,
        walk_edges=walk_edge_count,
        drive_edges=drive_edge_count,
        skipped_ways=skipped_way_count,
    )
    result = summary.as_dict()
    _emit(progress_callback, "routing_layer_import_completed", "Routing graph import completed.", edge_count, edge_count, result)
    return result
|
||||
|
||||
|
||||
def _drop_routing_geometry_indexes(session: Session) -> None:
    """Drop the routing geometry indexes if they exist (no commit)."""
    index_names = (
        "ix_routing_nodes_geom_gist",
        "ix_routing_edges_geom_gist",
        "ix_routing_edges_bbox_box_gist",
    )
    for index_name in index_names:
        session.execute(text(f"DROP INDEX IF EXISTS {index_name}"))
|
||||
|
||||
|
||||
def _create_routing_geometry_indexes(session: Session) -> None:
    """Create the node geometry and edge bounding-box GiST indexes (no commit)."""
    statements = (
        "CREATE INDEX IF NOT EXISTS ix_routing_nodes_geom_gist ON routing_nodes USING GIST (geom)",
        "CREATE INDEX IF NOT EXISTS ix_routing_edges_bbox_box_gist ON routing_edges USING GIST (box(point(max_lon, max_lat), point(min_lon, min_lat)))",
    )
    for statement in statements:
        session.execute(text(statement))
|
||||
|
||||
|
||||
class _RoutingGraphHandler(osmium.SimpleHandler):
    """Osmium handler that turns routable highway ways into node/edge rows.

    Every pair of consecutive way nodes becomes one edge row. Rows are
    buffered in memory and written to the database in batches by ``flush``.
    """

    def __init__(
        self,
        *,
        session: Session,
        dataset_id: int,
        batch_size: int,
        progress_callback: ProgressCallback | None,
    ) -> None:
        """Set up buffers and counters for ``dataset_id``.

        ``batch_size`` is raised to at least 500. The node and edge counters
        start from the rows already stored for the dataset.
        """
        super().__init__()
        self.session = session
        self.dataset_id = dataset_id
        self.batch_size = max(500, int(batch_size))
        self.progress_callback = progress_callback
        # Pending node rows keyed by OSM node id, and pending edge rows.
        self.nodes: dict[int, dict[str, object]] = {}
        self.edges: list[dict[str, object]] = []
        self.node_count = int(
            session.scalar(select(func.count()).select_from(RoutingNode).where(RoutingNode.dataset_id == dataset_id)) or 0
        )
        self.edge_count = int(
            session.scalar(select(func.count()).select_from(RoutingEdge).where(RoutingEdge.dataset_id == dataset_id)) or 0
        )
        self.walk_edge_count = 0
        self.drive_edge_count = 0
        self.skipped_way_count = 0
        self.processed_way_count = 0

    def way(self, way) -> None:
        """Convert one OSM way into routing edges, or count it as skipped.

        A way is skipped when it has no ``highway`` tag, its highway value is
        excluded, it is neither walkable nor drivable, or fewer than two of
        its nodes have a valid location.
        """
        tags = {tag.k: tag.v for tag in way.tags}
        highway = tags.get("highway")
        if not highway or highway in EXCLUDED_HIGHWAYS:
            self.skipped_way_count += 1
            return
        walkable = _walkable(tags, highway)
        drivable = _drivable(tags, highway)
        if not walkable and not drivable:
            self.skipped_way_count += 1
            return

        nodes = []
        for node in way.nodes:
            # Nodes without a valid location are dropped; the remaining
            # neighbours are then joined directly by an edge.
            if not node.location.valid():
                continue
            nodes.append((int(node.ref), float(node.location.lon), float(node.location.lat)))
        if len(nodes) < 2:
            self.skipped_way_count += 1
            return

        oneway = _oneway_direction(tags, highway)
        drive_speed_mps = _drive_speed_mps(tags, highway)
        walk_speed_mps = STEP_WALK_SPEED_MPS if highway == "steps" else DEFAULT_WALK_SPEED_MPS
        for left, right in zip(nodes, nodes[1:]):
            source_id, source_lon, source_lat = left
            target_id, target_lon, target_lat = right
            if source_id == target_id:
                continue
            length_m = _distance_m(source_lat, source_lon, target_lat, target_lon)
            if length_m <= 0:
                continue
            if oneway == "reverse":
                # Flip the endpoints so the stored source -> target direction
                # is the one in which driving is permitted.
                source_id, target_id = target_id, source_id
                source_lon, target_lon = target_lon, source_lon
                source_lat, target_lat = target_lat, source_lat

            walk_cost = length_m / walk_speed_mps if walkable else None
            drive_cost = length_m / drive_speed_mps if drivable and drive_speed_mps > 0 else None
            # Walking costs the same both ways; driving against a one-way is
            # marked as impossible with None.
            reverse_walk_cost = walk_cost
            reverse_drive_cost = None if oneway in {"forward", "reverse"} else drive_cost
            self.nodes[source_id] = {"dataset_id": self.dataset_id, "osm_node_id": source_id, "lon": source_lon, "lat": source_lat}
            self.nodes[target_id] = {"dataset_id": self.dataset_id, "osm_node_id": target_id, "lon": target_lon, "lat": target_lat}
            self.edges.append(
                {
                    "dataset_id": self.dataset_id,
                    "osm_way_id": int(way.id),
                    "source_osm_node_id": source_id,
                    "target_osm_node_id": target_id,
                    "source_lon": source_lon,
                    "source_lat": source_lat,
                    "target_lon": target_lon,
                    "target_lat": target_lat,
                    "highway": highway,
                    "name": tags.get("name"),
                    "length_m": length_m,
                    "walk_cost_s": walk_cost,
                    "reverse_walk_cost_s": reverse_walk_cost,
                    "drive_cost_s": drive_cost,
                    "reverse_drive_cost_s": reverse_drive_cost,
                    "geometry_geojson": json.dumps({"type": "LineString", "coordinates": [[source_lon, source_lat], [target_lon, target_lat]]}, separators=(",", ":")),
                    "min_lon": min(source_lon, target_lon),
                    "min_lat": min(source_lat, target_lat),
                    "max_lon": max(source_lon, target_lon),
                    "max_lat": max(source_lat, target_lat),
                    "tags_json": _routing_tags_json(tags),
                }
            )
            self.edge_count += 1
            if walk_cost is not None:
                self.walk_edge_count += 1
            if drive_cost is not None:
                self.drive_edge_count += 1

        self.processed_way_count += 1
        # Flush between ways only, so a way's edges are never split mid-loop.
        if len(self.edges) >= self.batch_size:
            self.flush()
        if self.processed_way_count % 100_000 == 0:
            _emit(
                self.progress_callback,
                "routing_layer_import_batch",
                f"Imported {self.edge_count:,} routing edges.",
                self.edge_count,
                None,
                {"processed_ways": self.processed_way_count, "nodes_pending": len(self.nodes), "edges": self.edge_count},
            )

    def flush(self) -> None:
        """Write the buffered nodes and edges to the database and commit.

        Nodes are inserted with ``ON CONFLICT DO NOTHING`` on
        ``(dataset_id, osm_node_id)``, so nodes seen in earlier batches are
        ignored by the database.
        """
        if not self.nodes and not self.edges:
            return
        node_rows = list(self.nodes.values())
        edge_rows = self.edges
        if node_rows:
            stmt = postgresql_insert(RoutingNode).values(node_rows)
            stmt = stmt.on_conflict_do_nothing(index_elements=["dataset_id", "osm_node_id"])
            self.session.execute(stmt)
            # NOTE(review): counts submitted rows, including ones the database
            # skipped as duplicates, so this may exceed the stored node total.
            self.node_count += len(node_rows)
            self.nodes.clear()
        if edge_rows:
            self.session.bulk_insert_mappings(RoutingEdge, edge_rows)
            self.edges = []
        self.session.commit()
|
||||
|
||||
|
||||
def _walkable(tags: dict[str, str], highway: str) -> bool:
    """Return True when a way with these tags may be used on foot.

    The way must be of a walkable highway class and must not carry a ``foot``
    value listed in ``NO_VALUES``. An explicit ``foot`` value in
    ``YES_VALUES`` overrides both a restrictive ``access`` tag and the
    default exclusion of motorway/trunk roads.
    """
    if highway not in WALK_HIGHWAYS:
        return False

    foot = _tag_value(tags, "foot")
    if foot in NO_VALUES:
        return False
    if foot in YES_VALUES:
        return True

    # Without an explicit foot permission, general access restrictions and
    # motorway/trunk classes rule the way out.
    access_denied = _tag_value(tags, "access") in NO_VALUES
    major_road = highway in {"motorway", "motorway_link", "trunk", "trunk_link"}
    return not (access_denied or major_road)
|
||||
|
||||
|
||||
def _drivable(tags: dict[str, str], highway: str) -> bool:
    """Return True when a way with these tags may be used by car.

    The way must be of a drivable highway class. Any of ``motorcar``,
    ``motor_vehicle`` or ``vehicle`` set to a value in ``NO_VALUES`` excludes
    it. A restrictive ``access`` tag excludes it unless ``motorcar`` or
    ``motor_vehicle`` explicitly allows driving, and non-motorised classes
    (footway, path, ...) are drivable only with such an explicit permission.
    """
    if highway not in DRIVE_HIGHWAYS:
        return False

    motorcar = _tag_value(tags, "motorcar")
    motor_vehicle = _tag_value(tags, "motor_vehicle")
    vehicle = _tag_value(tags, "vehicle")
    if any(value in NO_VALUES for value in (motorcar, motor_vehicle, vehicle)):
        return False

    explicitly_allowed = motorcar in YES_VALUES or motor_vehicle in YES_VALUES
    if _tag_value(tags, "access") in NO_VALUES and not explicitly_allowed:
        return False
    if highway in {"footway", "path", "pedestrian", "steps", "cycleway", "bridleway"}:
        return explicitly_allowed
    return True
|
||||
|
||||
|
||||
def _oneway_direction(tags: dict[str, str], highway: str) -> str:
    """Classify the permitted travel direction as "forward", "reverse" or "both".

    A ``oneway`` value in ``ONEWAY_REVERSE`` wins. Otherwise the way is
    forward-only when ``oneway`` is in ``ONEWAY_FORWARD``, when it is tagged
    ``junction=roundabout``, or when its highway class is ``motorway``.
    """
    declared = _tag_value(tags, "oneway")
    if declared in ONEWAY_REVERSE:
        return "reverse"

    # Roundabouts and motorways count as one-way even without a oneway tag.
    implied_forward = tags.get("junction") == "roundabout" or highway == "motorway"
    if declared in ONEWAY_FORWARD or implied_forward:
        return "forward"
    return "both"
|
||||
|
||||
|
||||
def _drive_speed_mps(tags: dict[str, str], highway: str) -> float:
    """Return the assumed driving speed for a way in metres per second.

    The parsed ``maxspeed`` tag is used when it yields a non-zero number;
    otherwise the per-class default from ``DEFAULT_DRIVE_SPEED_KMH`` applies
    (30 km/h for unlisted classes). The result is never below 5.0 m/s.
    """
    limit_kmh = _parse_maxspeed(tags.get("maxspeed"))
    if not limit_kmh:
        # Covers both "no usable tag" (None) and a parsed value of 0.
        limit_kmh = DEFAULT_DRIVE_SPEED_KMH.get(highway, 30)
    metres_per_second = float(limit_kmh) / 3.6
    return max(5.0, metres_per_second)
|
||||
|
||||
|
||||
def _parse_maxspeed(value: str | None) -> float | None:
|
||||
if not value:
|
||||
return None
|
||||
text = value.strip().lower()
|
||||
if text in {"signals", "none", "walk", "variable"}:
|
||||
return None
|
||||
if text.endswith("mph"):
|
||||
number = _leading_float(text[:-3])
|
||||
return None if number is None else number * 1.60934
|
||||
return _leading_float(text)
|
||||
|
||||
|
||||
def _leading_float(value: str) -> float | None:
|
||||
digits = []
|
||||
for char in value.strip():
|
||||
if char.isdigit() or char == ".":
|
||||
digits.append(char)
|
||||
elif digits:
|
||||
break
|
||||
if not digits:
|
||||
return None
|
||||
try:
|
||||
return float("".join(digits))
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
|
||||
def _routing_tags_json(tags: dict[str, str]) -> str:
|
||||
selected = {
|
||||
key: value
|
||||
for key, value in tags.items()
|
||||
if key in {"access", "bicycle", "bridge", "foot", "highway", "junction", "maxspeed", "motor_vehicle", "motorcar", "name", "oneway", "service", "surface", "tunnel", "vehicle"}
|
||||
}
|
||||
return json.dumps(selected, separators=(",", ":"))
|
||||
|
||||
|
||||
def _tag_value(tags: dict[str, str], key: str) -> str:
|
||||
return str(tags.get(key) or "").strip().lower()
|
||||
|
||||
|
||||
def _distance_m(lat_a: float, lon_a: float, lat_b: float, lon_b: float) -> float:
|
||||
radius = 6_371_000.0
|
||||
phi_a = math.radians(lat_a)
|
||||
phi_b = math.radians(lat_b)
|
||||
delta_phi = math.radians(lat_b - lat_a)
|
||||
delta_lambda = math.radians(lon_b - lon_a)
|
||||
hav = math.sin(delta_phi / 2) ** 2 + math.cos(phi_a) * math.cos(phi_b) * math.sin(delta_lambda / 2) ** 2
|
||||
return radius * 2 * math.atan2(math.sqrt(hav), math.sqrt(1 - hav))
|
||||
|
||||
|
||||
def _metadata(dataset: Dataset) -> dict[str, object]:
|
||||
try:
|
||||
value = json.loads(dataset.metadata_json or "{}")
|
||||
except json.JSONDecodeError:
|
||||
return {}
|
||||
return value if isinstance(value, dict) else {}
|
||||
|
||||
|
||||
def _emit(
|
||||
progress_callback: ProgressCallback | None,
|
||||
event_type: str,
|
||||
message: str,
|
||||
progress_current: int | None,
|
||||
progress_total: int | None,
|
||||
metadata: dict[str, object] | None = None,
|
||||
) -> None:
|
||||
if progress_callback is not None:
|
||||
progress_callback(event_type, message, progress_current, progress_total, metadata)
|
||||
40
app/pipeline/run.py
Normal file
40
app/pipeline/run.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Callable, Any
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models import Source
|
||||
from app.pipeline.gtfs import run_gtfs_source
|
||||
from app.pipeline.osm_diff import run_osm_diff_source
|
||||
from app.pipeline.osm_geojson import run_osm_geojson_source
|
||||
from app.pipeline.osm_pbf import run_osm_pbf_source
|
||||
|
||||
|
||||
# Type of the optional progress hook accepted by run_source: a callable taking
# (str, str, int | None, int | None, dict | None) positionally and returning None.
# NOTE(review): presumably (event_type, message, progress_current, progress_total,
# metadata) — confirm against the pipeline modules that invoke the callback.
ProgressCallback = Callable[[str, str, int | None, int | None, dict[str, Any] | None], None]
|
||||
|
||||
|
||||
def run_source(session: Session, source: Source, progress_callback: ProgressCallback | None = None):
|
||||
source.status = "running"
|
||||
source.last_run_at = datetime.now(timezone.utc)
|
||||
source.last_error = None
|
||||
session.flush()
|
||||
try:
|
||||
if source.kind == "gtfs":
|
||||
dataset = run_gtfs_source(session, source, progress_callback=progress_callback)
|
||||
elif source.kind == "osm_geojson":
|
||||
dataset = run_osm_geojson_source(session, source)
|
||||
elif source.kind == "osm_pbf":
|
||||
dataset = run_osm_pbf_source(session, source, progress_callback=progress_callback)
|
||||
elif source.kind == "osm_diff":
|
||||
dataset = run_osm_diff_source(session, source)
|
||||
else:
|
||||
raise ValueError(f"Unsupported source kind: {source.kind}")
|
||||
source.status = "ok"
|
||||
source.last_error = None
|
||||
return dataset
|
||||
except Exception as exc: # noqa: BLE001 - persist pipeline error for UI
|
||||
source.status = "error"
|
||||
source.last_error = str(exc)
|
||||
raise
|
||||
294
app/pipeline/sample_data.py
Normal file
294
app/pipeline/sample_data.py
Normal file
@@ -0,0 +1,294 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import csv
|
||||
import io
|
||||
import json
|
||||
import zipfile
|
||||
from pathlib import Path
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.config import settings
|
||||
from app.db import init_db
|
||||
from app.models import (
|
||||
Dataset,
|
||||
CanonicalStop,
|
||||
CanonicalStopLink,
|
||||
GtfsAgency,
|
||||
GtfsCalendar,
|
||||
GtfsCalendarDate,
|
||||
GtfsRoute,
|
||||
GtfsRoutePatternLink,
|
||||
GtfsShape,
|
||||
GtfsStop,
|
||||
GtfsStopTime,
|
||||
GtfsTripRoutePatternLink,
|
||||
GtfsTrip,
|
||||
Itinerary,
|
||||
ItineraryLeg,
|
||||
Job,
|
||||
JobEvent,
|
||||
MatchRule,
|
||||
OsmDiffState,
|
||||
OsmFeature,
|
||||
PipelineRun,
|
||||
RouteMatch,
|
||||
RoutePattern,
|
||||
RoutePatternStop,
|
||||
RoutingEdge,
|
||||
RoutingNode,
|
||||
Source,
|
||||
SourceCatalogEntry,
|
||||
SourceUpdateCheck,
|
||||
TravelRequest,
|
||||
)
|
||||
from app.pipeline.matcher import run_route_matching
|
||||
from app.pipeline.route_layer import rebuild_route_layer
|
||||
from app.pipeline.run import run_source
|
||||
|
||||
|
||||
def load_sample_project(session: Session, *, preserve_job_id: int | None = None) -> dict:
    """Reset the project and load the built-in Berlin-like sample.

    Steps, in order: initialise the database, clear project data (the source
    catalog is kept), write the sample GTFS zip and OSM GeoJSON under
    ``settings.data_dir / "sample"``, register and import both as sources,
    run route matching and rebuild the route layer.

    Returns a dict with "status", both dataset ids and the results of the
    matching and route-layer steps.
    """
    init_db()
    clear_project_data(session, preserve_job_id=preserve_job_id, preserve_catalog=True)

    target_dir = settings.data_dir / "sample"
    target_dir.mkdir(parents=True, exist_ok=True)
    gtfs_file = target_dir / "sample_berlin.gtfs.zip"
    osm_file = target_dir / "sample_berlin_osm.geojson"
    create_sample_gtfs(gtfs_file)
    create_sample_osm_geojson(osm_file)

    gtfs_source = Source(
        name="Sample Berlin GTFS",
        kind="gtfs",
        url=str(gtfs_file),
        country="DE",
        license="sample",
    )
    osm_source = Source(
        name="Sample Berlin OSM transport",
        kind="osm_geojson",
        url=str(osm_file),
        country="DE",
        license="sample",
    )
    session.add_all([gtfs_source, osm_source])
    session.flush()

    summary = {"status": "ok"}
    summary["gtfs_dataset_id"] = run_source(session, gtfs_source).id
    summary["osm_dataset_id"] = run_source(session, osm_source).id
    summary["match_result"] = run_route_matching(session)
    summary["route_layer_result"] = rebuild_route_layer(session)
    return summary
|
||||
|
||||
|
||||
def clear_project_data(
    session: Session,
    *,
    preserve_job_id: int | None = None,
    preserve_catalog: bool = True,
) -> None:
    """Delete all project data, optionally sparing the job queue and the catalog.

    Pipeline runs are always deleted. With ``preserve_job_id`` set, jobs and
    their events are kept and every other active job is cancelled; without it
    all jobs and job events are deleted. ``SourceCatalogEntry`` rows are
    removed only when ``preserve_catalog`` is False. Changes are flushed, not
    committed.
    """
    session.execute(delete(PipelineRun))
    if preserve_job_id is not None:
        _cancel_other_jobs_for_reset(session, preserve_job_id)
    else:
        session.execute(delete(JobEvent))
        session.execute(delete(Job))

    # NOTE(review): the order presumably removes dependent rows before the rows
    # they reference — keep it unchanged unless the foreign keys are checked.
    deletion_order = (
        ItineraryLeg,
        Itinerary,
        TravelRequest,
        SourceUpdateCheck,
        OsmDiffState,
        MatchRule,
        RouteMatch,
        GtfsTripRoutePatternLink,
        GtfsRoutePatternLink,
        RoutePatternStop,
        RoutePattern,
        CanonicalStopLink,
        CanonicalStop,
        RoutingEdge,
        RoutingNode,
        GtfsStopTime,
        GtfsCalendarDate,
        GtfsCalendar,
        GtfsShape,
        GtfsTrip,
        GtfsRoute,
        GtfsStop,
        GtfsAgency,
        OsmFeature,
        Dataset,
        Source,
    )
    for table_model in deletion_order:
        session.execute(delete(table_model))

    if not preserve_catalog:
        session.execute(delete(SourceCatalogEntry))
    session.flush()
|
||||
|
||||
|
||||
def _cancel_other_jobs_for_reset(session: Session, preserve_job_id: int) -> None:
    """Cancel every queued, running or paused job except ``preserve_job_id``.

    Each affected job is marked "cancelled", loses its lease, pending action,
    pause marker and error, and is stamped as finished now. A
    "cancelled_by_reset" ``JobEvent`` carrying the job's current progress is
    added for each one.
    """
    cancelled_at = datetime.now(timezone.utc)
    still_active = select(Job).where(
        Job.id != preserve_job_id,
        Job.status.in_({"queued", "running", "paused"}),
    )
    for job in session.scalars(still_active).all():
        job.status = "cancelled"
        job.requested_action = None
        job.lease_owner = None
        job.lease_expires_at = None
        job.paused_at = None
        job.error = None
        job.updated_at = cancelled_at
        job.finished_at = cancelled_at
        event = JobEvent(
            job_id=job.id,
            event_type="cancelled_by_reset",
            message=f"Job cancelled by reset job #{preserve_job_id}.",
            progress_current=job.progress_current,
            progress_total=job.progress_total,
        )
        session.add(event)
|
||||
|
||||
|
||||
def create_sample_gtfs(path: Path) -> None:
    """Write a small Berlin-like GTFS feed as a deflated zip archive at ``path``.

    The feed has three agencies, fourteen stops, six routes with one trip
    each and a single "daily" service valid for 2026. Stop times and shape
    points are derived from each trip's stop sequence, so every shape simply
    connects the trip's stops in order.
    """
    agencies = [
        {"agency_id": "BVG", "agency_name": "BVG", "agency_url": "https://example.invalid/bvg", "agency_timezone": "Europe/Berlin"},
        {"agency_id": "DB", "agency_name": "DB Regio", "agency_url": "https://example.invalid/db", "agency_timezone": "Europe/Berlin"},
        {"agency_id": "XAIR", "agency_name": "Example Airport Shuttle", "agency_url": "https://example.invalid/xair", "agency_timezone": "Europe/Berlin"},
    ]
    stops = [
        {"stop_id": "hbf", "stop_name": "Berlin Hauptbahnhof", "stop_lat": "52.5251", "stop_lon": "13.3696"},
        {"stop_id": "friedrich", "stop_name": "Friedrichstraße", "stop_lat": "52.5201", "stop_lon": "13.3862"},
        {"stop_id": "alex", "stop_name": "Alexanderplatz", "stop_lat": "52.5219", "stop_lon": "13.4132"},
        {"stop_id": "ost", "stop_name": "Ostbahnhof", "stop_lat": "52.5100", "stop_lon": "13.4344"},
        {"stop_id": "zoo", "stop_name": "Zoologischer Garten", "stop_lat": "52.5069", "stop_lon": "13.3320"},
        {"stop_id": "wittenberg", "stop_name": "Wittenbergplatz", "stop_lat": "52.5020", "stop_lon": "13.3430"},
        {"stop_id": "potsdamer", "stop_name": "Potsdamer Platz", "stop_lat": "52.5096", "stop_lon": "13.3760"},
        {"stop_id": "stadtmitte", "stop_name": "Stadtmitte", "stop_lat": "52.5113", "stop_lon": "13.3907"},
        {"stop_id": "reichstag", "stop_name": "Reichstag", "stop_lat": "52.5186", "stop_lon": "13.3763"},
        {"stop_id": "hackescher", "stop_name": "Hackescher Markt", "stop_lat": "52.5220", "stop_lon": "13.4023"},
        {"stop_id": "naturkunde", "stop_name": "Naturkundemuseum", "stop_lat": "52.5300", "stop_lon": "13.3790"},
        {"stop_id": "wannsee", "stop_name": "Wannsee", "stop_lat": "52.4210", "stop_lon": "13.1797"},
        {"stop_id": "kladow", "stop_name": "Kladow", "stop_lat": "52.4547", "stop_lon": "13.1439"},
        {"stop_id": "airport", "stop_name": "Example Airport Terminal", "stop_lat": "52.3650", "stop_lon": "13.5100"},
    ]
    routes = [
        {"route_id": "u2", "agency_id": "BVG", "route_short_name": "U2", "route_long_name": "Pankow - Ruhleben", "route_type": "1"},
        {"route_id": "re1", "agency_id": "DB", "route_short_name": "RE1", "route_long_name": "Magdeburg - Frankfurt Oder", "route_type": "2"},
        {"route_id": "m5", "agency_id": "BVG", "route_short_name": "M5", "route_long_name": "Hauptbahnhof - Hohenschönhausen", "route_type": "0"},
        {"route_id": "bus100", "agency_id": "BVG", "route_short_name": "100", "route_long_name": "Zoologischer Garten - Alexanderplatz", "route_type": "3"},
        {"route_id": "f10", "agency_id": "BVG", "route_short_name": "F10", "route_long_name": "Wannsee - Kladow", "route_type": "4"},
        {"route_id": "x99", "agency_id": "XAIR", "route_short_name": "X99", "route_long_name": "Airport Express Sample", "route_type": "3"},
    ]
    stop_sequences = {
        "u2_trip": ["zoo", "wittenberg", "potsdamer", "stadtmitte", "alex"],
        "re1_trip": ["hbf", "friedrich", "alex", "ost"],
        "m5_trip": ["hbf", "naturkunde", "hackescher", "alex"],
        "bus100_trip": ["zoo", "reichstag", "alex"],
        "f10_trip": ["wannsee", "kladow"],
        "x99_trip": ["alex", "airport"],
    }

    # One trip (and one shape) per route, both named after the route id.
    trips = []
    for route in routes:
        route_id = route["route_id"]
        trips.append(
            {
                "route_id": route_id,
                "service_id": "daily",
                "trip_id": f"{route_id}_trip",
                "shape_id": f"{route_id}_shape",
            }
        )

    stop_by_id = {stop["stop_id"]: stop for stop in stops}
    stop_times = []
    shapes = []
    for trip in trips:
        for position, stop_id in enumerate(stop_sequences[trip["trip_id"]], start=1):
            # Stops are five minutes apart with a one-minute dwell.
            minute = position * 5
            stop_times.append(
                {
                    "trip_id": trip["trip_id"],
                    "arrival_time": f"08:{minute:02d}:00",
                    "departure_time": f"08:{minute + 1:02d}:00",
                    "stop_id": stop_id,
                    "stop_sequence": str(position),
                }
            )
            stop = stop_by_id[stop_id]
            shapes.append(
                {
                    "shape_id": trip["shape_id"],
                    "shape_pt_lat": stop["stop_lat"],
                    "shape_pt_lon": stop["stop_lon"],
                    "shape_pt_sequence": str(position),
                }
            )

    weekdays = ["monday", "tuesday", "wednesday", "thursday", "friday", "saturday", "sunday"]
    calendar = [
        {
            "service_id": "daily",
            **{day: "1" for day in weekdays},
            "start_date": "20260101",
            "end_date": "20261231",
        }
    ]

    tables = [
        ("agency.txt", ["agency_id", "agency_name", "agency_url", "agency_timezone"], agencies),
        ("stops.txt", ["stop_id", "stop_name", "stop_lat", "stop_lon"], stops),
        ("routes.txt", ["route_id", "agency_id", "route_short_name", "route_long_name", "route_type"], routes),
        ("trips.txt", ["route_id", "service_id", "trip_id", "shape_id"], trips),
        ("stop_times.txt", ["trip_id", "arrival_time", "departure_time", "stop_id", "stop_sequence"], stop_times),
        ("calendar.txt", ["service_id", *weekdays, "start_date", "end_date"], calendar),
        ("shapes.txt", ["shape_id", "shape_pt_lat", "shape_pt_lon", "shape_pt_sequence"], shapes),
    ]
    with zipfile.ZipFile(path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
        for member_name, fields, rows in tables:
            _write_csv(zf, member_name, fields, rows)
|
||||
|
||||
|
||||
def _write_csv(zf: zipfile.ZipFile, name: str, fields: list[str], rows: list[dict[str, str]]) -> None:
|
||||
buffer = io.StringIO(newline="")
|
||||
writer = csv.DictWriter(buffer, fieldnames=fields)
|
||||
writer.writeheader()
|
||||
writer.writerows(rows)
|
||||
zf.writestr(name, buffer.getvalue())
|
||||
|
||||
|
||||
def create_sample_osm_geojson(path: Path) -> None:
    """Write the sample OSM transport data as a GeoJSON FeatureCollection.

    The collection holds six route relations (LineStrings) followed by five
    station/terminal nodes (Points) and is written to ``path`` as UTF-8 text
    with a two-space indent.
    """
    # (osm id, route mode, ref, name, operator, [lon, lat] coordinates)
    route_rows = [
        (1002, "subway", "U2", "U2 Ruhleben - Pankow", "BVG", [[13.3320, 52.5069], [13.3430, 52.5020], [13.3760, 52.5096], [13.3907, 52.5113], [13.4132, 52.5219]]),
        (2001, "train", "RE1", "RE1 Magdeburg - Frankfurt Oder", "DB Regio", [[13.3696, 52.5251], [13.3862, 52.5201], [13.4132, 52.5219], [13.4344, 52.5100]]),
        (5005, "tram", "M5", "M5 Hauptbahnhof - Hohenschönhausen", "BVG", [[13.3696, 52.5251], [13.3790, 52.5300], [13.4023, 52.5220], [13.4132, 52.5219]]),
        (6100, "bus", "100", "Bus 100 Zoologischer Garten - Alexanderplatz", "BVG", [[13.3320, 52.5069], [13.3763, 52.5186], [13.4132, 52.5219]]),
        (7010, "ferry", "F10", "F10 Wannsee - Kladow", "BVG", [[13.1797, 52.4210], [13.1439, 52.4547]]),
        (5010, "tram", "M10", "M10 Warschauer Straße - Hauptbahnhof", "BVG", [[13.4500, 52.5050], [13.4020, 52.5300], [13.3696, 52.5251]]),
    ]
    # (osm id, kind, name, [lon, lat], extra tags); the kind label is not written out.
    node_rows = [
        (9001, "station", "Berlin Hauptbahnhof", [13.3696, 52.5251], {"railway": "station"}),
        (9002, "station", "Alexanderplatz", [13.4132, 52.5219], {"railway": "station"}),
        (9003, "stop", "Zoologischer Garten", [13.3320, 52.5069], {"public_transport": "station", "railway": "station"}),
        (9004, "terminal", "Wannsee Ferry Terminal", [13.1797, 52.4210], {"amenity": "ferry_terminal"}),
        (9005, "terminal", "Kladow Ferry Terminal", [13.1439, 52.4547], {"amenity": "ferry_terminal"}),
    ]

    features = []
    for osm_id, mode, ref, name, operator, coordinates in route_rows:
        if operator == "BVG":
            network = "VBB"
        else:
            network = "DB"
        features.append(
            {
                "type": "Feature",
                "geometry": {"type": "LineString", "coordinates": coordinates},
                "properties": {
                    "osm_type": "relation",
                    "osm_id": str(osm_id),
                    "type": "route",
                    "route": mode,
                    "ref": ref,
                    "name": name,
                    "operator": operator,
                    "network": network,
                },
            }
        )
    for osm_id, _kind, name, coordinates, extra_tags in node_rows:
        # Extra tags come first; the identifying keys are layered on top.
        properties = {**extra_tags, "osm_type": "node", "osm_id": str(osm_id), "name": name}
        features.append(
            {
                "type": "Feature",
                "geometry": {"type": "Point", "coordinates": coordinates},
                "properties": properties,
            }
        )

    collection = {"type": "FeatureCollection", "features": features}
    path.write_text(json.dumps(collection, indent=2), encoding="utf-8")
|
||||
135
app/pipeline/state.py
Normal file
135
app/pipeline/state.py
Normal file
@@ -0,0 +1,135 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
import hashlib
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models import PipelineRun
|
||||
|
||||
|
||||
# Names of the pipeline stages, as plain strings.
# NOTE(review): presumably these are the values passed as the `stage` argument
# of the pipeline-run helpers in this module and stored in PipelineRun.stage —
# confirm against the callers.
STAGE_ACQUIRE_RAW = "acquire_raw"
|
||||
STAGE_FILTER_TRANSPORT = "filter_transport"
|
||||
STAGE_EXTRACT_GEOMETRY = "extract_geometry"
|
||||
STAGE_LABEL_FEATURES = "label_features"
|
||||
STAGE_BUILD_INDEXES = "build_indexes"
|
||||
STAGE_MATCH_ROUTES = "match_routes"
|
||||
STAGE_BUILD_ROUTE_LAYER = "build_route_layer"
|
||||
|
||||
|
||||
def stable_json(value: Any) -> str:
    """Serialise ``value`` as compact JSON with sorted keys.

    Objects the JSON encoder cannot handle natively are converted with
    ``str``.
    """
    compact_separators = (",", ":")
    return json.dumps(value, default=str, separators=compact_separators, sort_keys=True)
|
||||
|
||||
|
||||
def dependency_hash(value: Any) -> str:
    """Return the SHA-256 hex digest of the stable JSON form of ``value``."""
    canonical = stable_json(value)
    digest = hashlib.sha256()
    digest.update(canonical.encode("utf-8"))
    return digest.hexdigest()
|
||||
|
||||
|
||||
def latest_completed_run(
    session: Session,
    *,
    stage: str,
    version: str,
    dependency_hash_value: str,
    source_id: int | None = None,
    dataset_id: int | None = None,
) -> PipelineRun | None:
    """Return the most recent completed run for the given stage and inputs.

    A run matches when its stage, version and dependency hash are equal and
    its status is "completed". ``source_id`` and ``dataset_id`` are matched
    exactly; None matches only runs whose column is NULL. The newest match
    by ``finished_at`` (then ``id``) is returned, or None if there is none.
    """
    if source_id is None:
        source_filter = PipelineRun.source_id.is_(None)
    else:
        source_filter = PipelineRun.source_id == source_id

    if dataset_id is None:
        dataset_filter = PipelineRun.dataset_id.is_(None)
    else:
        dataset_filter = PipelineRun.dataset_id == dataset_id

    query = (
        select(PipelineRun)
        .where(
            PipelineRun.stage == stage,
            PipelineRun.version == version,
            PipelineRun.dependency_hash == dependency_hash_value,
            PipelineRun.status == "completed",
            source_filter,
            dataset_filter,
        )
        .order_by(PipelineRun.finished_at.desc(), PipelineRun.id.desc())
        .limit(1)
    )
    return session.scalar(query)
|
||||
|
||||
|
||||
def start_pipeline_run(
    session: Session,
    *,
    stage: str,
    version: str,
    dependency_hash_value: str,
    source_id: int | None = None,
    dataset_id: int | None = None,
    job_id: int | None = None,
    inputs: dict[str, Any] | None = None,
) -> PipelineRun:
    """Create, add and flush a ``PipelineRun`` in status "running".

    ``inputs`` is stored as stable JSON in ``input_json`` (None stays None).
    ``started_at`` and ``updated_at`` are both set to the current UTC time.
    """
    started = datetime.now(timezone.utc)
    serialized_inputs = stable_json(inputs) if inputs is not None else None
    run = PipelineRun(
        stage=stage,
        version=version,
        dependency_hash=dependency_hash_value,
        status="running",
        source_id=source_id,
        dataset_id=dataset_id,
        job_id=job_id,
        input_json=serialized_inputs,
        started_at=started,
        updated_at=started,
    )
    session.add(run)
    session.flush()
    return run
|
||||
|
||||
|
||||
def finish_pipeline_run(
    session: Session,
    run: PipelineRun,
    *,
    status: str = "completed",
    outputs: dict[str, Any] | None = None,
    error: str | None = None,
) -> PipelineRun:
    """Close ``run`` with the given status, outputs and error, then flush.

    ``outputs`` is stored as stable JSON in ``output_json`` (None stays
    None). ``updated_at`` and ``finished_at`` are both set to the current
    UTC time. The same ``run`` object is returned.
    """
    finished = datetime.now(timezone.utc)
    serialized_outputs = stable_json(outputs) if outputs is not None else None

    run.status = status
    run.output_json = serialized_outputs
    run.error = error
    run.updated_at = finished
    run.finished_at = finished
    session.flush()
    return run
|
||||
|
||||
|
||||
def pipeline_run_payload(run: PipelineRun) -> dict[str, Any]:
    """Convert ``run`` into a plain dict.

    ``input`` and ``output`` are the decoded JSON objects (empty dict when
    absent or invalid); the three timestamps are ISO 8601 strings, or None
    when unset.
    """

    def iso(moment):
        return moment.isoformat() if moment else None

    plain_fields = (
        "id",
        "stage",
        "version",
        "dependency_hash",
        "status",
        "source_id",
        "dataset_id",
        "job_id",
    )
    payload: dict[str, Any] = {field: getattr(run, field) for field in plain_fields}
    payload["input"] = _json_object(run.input_json)
    payload["output"] = _json_object(run.output_json)
    payload["error"] = run.error
    payload["started_at"] = iso(run.started_at)
    payload["updated_at"] = iso(run.updated_at)
    payload["finished_at"] = iso(run.finished_at)
    return payload
|
||||
|
||||
|
||||
def _json_object(text: str | None) -> dict[str, Any]:
|
||||
if not text:
|
||||
return {}
|
||||
try:
|
||||
value = json.loads(text)
|
||||
except json.JSONDecodeError:
|
||||
return {}
|
||||
return value if isinstance(value, dict) else {}
|
||||
89
app/pipeline/utils.py
Normal file
89
app/pipeline/utils.py
Normal file
@@ -0,0 +1,89 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Iterable, Optional
|
||||
|
||||
from shapely.geometry import shape
|
||||
|
||||
|
||||
def sha256_file(path: Path) -> str:
    """Return the SHA-256 hex digest of the file at ``path``.

    The file is read in 1 MiB chunks, so it is never held in memory whole.
    """
    chunk_size = 1024 * 1024
    digest = hashlib.sha256()
    with path.open("rb") as stream:
        while chunk := stream.read(chunk_size):
            digest.update(chunk)
    return digest.hexdigest()
|
||||
|
||||
|
||||
def norm_text(value: object) -> str:
    """Normalise ``value`` into lower-case ASCII words separated by single spaces.

    None gives "". The text is lower-cased and trimmed, "ß" becomes "ss", and
    every run of characters outside ``a-z0-9`` is treated as one separator.
    """
    if value is None:
        return ""
    lowered = str(value).lower().strip().replace("ß", "ss")
    # NOTE(review): other non-ASCII letters (e.g. umlauts) also count as
    # separators here, so "schön" normalises to "sch n".
    words = re.sub(r"[^a-z0-9]+", " ", lowered).split()
    return " ".join(words)
|
||||
|
||||
|
||||
def norm_ref(value: object) -> str:
    """Return ``value`` lower-cased with everything outside ``a-z0-9`` removed.

    None gives "".
    """
    if value is None:
        return ""
    kept_runs = re.findall(r"[a-z0-9]+", str(value).lower())
    return "".join(kept_runs)
|
||||
|
||||
|
||||
def first_nonempty(*values: object) -> str:
    """Return the first value whose stripped string form is non-empty.

    None values are skipped. Returns "" when no value qualifies.
    """
    stripped = (str(value).strip() for value in values if value is not None)
    return next((text for text in stripped if text), "")
|
||||
|
||||
|
||||
def geometry_json_and_bbox(geometry: object) -> tuple[Optional[str], tuple[Optional[float], Optional[float], Optional[float], Optional[float]]]:
    """Return compact GeoJSON text and the bounding box of ``geometry``.

    A dict is first converted with shapely's ``shape``; any other object is
    used as-is and must provide ``is_empty``, ``bounds`` and
    ``__geo_interface__``. The bounding box is ``geom.bounds`` unpacked as
    ``(min_lon, min_lat, max_lon, max_lat)``.

    None, an empty geometry, or any error while processing yields
    ``(None, (None, None, None, None))``.
    """
    nothing = (None, (None, None, None, None))
    if geometry is None:
        return nothing
    try:
        geom = geometry if not isinstance(geometry, dict) else shape(geometry)
        if geom.is_empty:
            return nothing
        west, south, east, north = geom.bounds
        encoded = json.dumps(geom.__geo_interface__, separators=(",", ":"))
        return encoded, (west, south, east, north)
    except Exception:
        # Best effort: a geometry that cannot be processed is treated as absent.
        return nothing
|
||||
|
||||
|
||||
def bbox_overlap(a: tuple[float | None, float | None, float | None, float | None], b: tuple[float | None, float | None, float | None, float | None]) -> bool:
    """Return True when two ``(minx, miny, maxx, maxy)`` boxes intersect.

    Boxes that merely touch count as overlapping. Returns False when any
    coordinate of either box is None.
    """
    for coordinate in (*a, *b):
        if coordinate is None:
            return False
    a_min_x, a_min_y, a_max_x, a_max_y = a  # type: ignore[misc]
    b_min_x, b_min_y, b_max_x, b_max_y = b  # type: ignore[misc]
    apart_in_x = a_max_x < b_min_x or b_max_x < a_min_x
    apart_in_y = a_max_y < b_min_y or b_max_y < a_min_y
    return not (apart_in_x or apart_in_y)
|
||||
|
||||
|
||||
def bbox_center(b: tuple[float | None, float | None, float | None, float | None]) -> Optional[tuple[float, float]]:
    """Return the midpoint ``(x, y)`` of a ``(minx, miny, maxx, maxy)`` box.

    Returns None when any coordinate is None.
    """
    for coordinate in b:
        if coordinate is None:
            return None
    left, bottom, right, top = b  # type: ignore[misc]
    mid_x = (left + right) / 2
    mid_y = (bottom + top) / 2
    return (mid_x, mid_y)
|
||||
|
||||
|
||||
def approx_bbox_center_distance_deg(a: tuple[float | None, float | None, float | None, float | None], b: tuple[float | None, float | None, float | None, float | None]) -> Optional[float]:
    """Return the straight-line distance between the centres of two boxes.

    The result is in the same units as the box coordinates. Returns None
    when either box has no centre (a coordinate is None).
    """
    center_a = bbox_center(a)
    center_b = bbox_center(b)
    if center_a is None or center_b is None:
        return None
    delta_x = center_a[0] - center_b[0]
    delta_y = center_a[1] - center_b[1]
    return (delta_x ** 2 + delta_y ** 2) ** 0.5
|
||||
|
||||
|
||||
def batched(iterable: Iterable[dict], batch_size: int = 1000) -> Iterable[list[dict]]:
    """Yield the items of ``iterable`` as lists of up to ``batch_size`` items.

    Every yielded list is a fresh object; the last one may be shorter.
    Nothing is yielded for an empty iterable.
    """
    pending: list[dict] = []
    for item in iterable:
        pending.append(item)
        if len(pending) < batch_size:
            continue
        yield pending
        pending = []
    if pending:
        yield pending
|
||||
393
app/qa.py
Normal file
393
app/qa.py
Normal file
@@ -0,0 +1,393 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.gtfs_storage import missing_sidecar_paths as missing_gtfs_sidecar_paths
|
||||
from app.models import (
|
||||
CanonicalStop,
|
||||
CanonicalStopLink,
|
||||
Dataset,
|
||||
GtfsAgency,
|
||||
GtfsCalendar,
|
||||
GtfsCalendarDate,
|
||||
GtfsRoute,
|
||||
GtfsShape,
|
||||
GtfsStop,
|
||||
GtfsTrip,
|
||||
Job,
|
||||
OsmFeature,
|
||||
RouteMatch,
|
||||
RoutePattern,
|
||||
RoutePatternStop,
|
||||
Source,
|
||||
SourceCatalogEntry,
|
||||
)
|
||||
from app.osm_storage import missing_sidecar_paths as missing_osm_sidecar_paths
|
||||
from app.pipeline.osm_addresses import ADDRESS_INDEX_VERSION
|
||||
from app.pipeline.routing_layer import active_routing_dataset
|
||||
|
||||
|
||||
def qa_summary(session: Session) -> dict[str, Any]:
|
||||
active_gtfs_datasets = session.scalars(
|
||||
select(Dataset).where(Dataset.kind == "gtfs", Dataset.is_active.is_(True)).order_by(Dataset.id)
|
||||
).all()
|
||||
active_osm_datasets = session.scalars(
|
||||
select(Dataset).where(Dataset.kind == "osm_geojson", Dataset.is_active.is_(True)).order_by(Dataset.id)
|
||||
).all()
|
||||
active_gtfs_ids = [int(dataset.id) for dataset in active_gtfs_datasets]
|
||||
active_osm_ids = [int(dataset.id) for dataset in active_osm_datasets]
|
||||
|
||||
source_catalog_total = _count(session, SourceCatalogEntry)
|
||||
registered_sources = _count(session, Source)
|
||||
linked_catalog_entries = int(
|
||||
session.scalar(
|
||||
select(func.count(func.distinct(Source.catalog_entry_id))).where(Source.catalog_entry_id.is_not(None))
|
||||
)
|
||||
or 0
|
||||
)
|
||||
priority_backlog = _priority_catalog_backlog(session)
|
||||
failed_sources = int(
|
||||
session.scalar(
|
||||
select(func.count())
|
||||
.select_from(Source)
|
||||
.where((Source.last_error.is_not(None)) | Source.status.in_(["failed", "error"]))
|
||||
)
|
||||
or 0
|
||||
)
|
||||
|
||||
active_jobs = _job_status_counts(session)
|
||||
missing_gtfs_sidecars = sum(1 for dataset in active_gtfs_datasets if missing_gtfs_sidecar_paths(dataset))
|
||||
missing_osm_sidecars = sum(1 for dataset in active_osm_datasets if missing_osm_sidecar_paths(dataset))
|
||||
|
||||
gtfs_counts = _gtfs_validation_counts(session, active_gtfs_ids)
|
||||
link_counts = _link_quality_counts(session, active_gtfs_ids, active_osm_ids)
|
||||
route_counts = _route_quality_counts(session, active_gtfs_ids)
|
||||
address_status = _lightweight_address_index_status(session)
|
||||
license_unknown = int(
|
||||
session.scalar(
|
||||
select(func.count())
|
||||
.select_from(Source)
|
||||
.where(Source.kind == "gtfs", (Source.license.is_(None)) | (func.lower(Source.license).in_(["", "unknown"])))
|
||||
)
|
||||
or 0
|
||||
)
|
||||
|
||||
return {
|
||||
"generated_at": datetime.now(timezone.utc).isoformat(),
|
||||
"decision": {
|
||||
"deployment": "same_workbench_for_now",
|
||||
"database": "same_postgresql_database_for_now",
|
||||
"split_trigger": "Split when third-party API, accounts/billing, heavy export jobs, or independent scaling are needed.",
|
||||
"api_contract": "/api/qa/summary is intentionally display-ready but stable enough to become a harmonization-service summary endpoint.",
|
||||
},
|
||||
"sections": [
|
||||
{
|
||||
"id": "source_discovery",
|
||||
"title": "Source Discovery",
|
||||
"items": [
|
||||
_item("Identified sources", source_catalog_total, "info", "Rows in the source catalog."),
|
||||
_item("Registered sources", registered_sources, "info", "Sources known to the importer."),
|
||||
_item("Catalog entries linked", linked_catalog_entries, "good" if linked_catalog_entries else "warn", "Catalog rows connected to importer sources."),
|
||||
_item("Priority catalog backlog", priority_backlog, "warn" if priority_backlog else "good", "P0/P1 catalog rows without a registered source."),
|
||||
],
|
||||
},
|
||||
{
|
||||
"id": "import_health",
|
||||
"title": "Import Health",
|
||||
"items": [
|
||||
_item("Active GTFS datasets", len(active_gtfs_ids), "good" if active_gtfs_ids else "warn", "Feeds currently participating in harmonization."),
|
||||
_item("Active OSM datasets", len(active_osm_ids), "good" if active_osm_ids else "warn", "Visual/spatial datasets currently active."),
|
||||
_item("Running jobs", active_jobs.get("running", 0), "warn" if active_jobs.get("running", 0) else "info", "Currently running queued work."),
|
||||
_item("Queued jobs", active_jobs.get("queued", 0), "info", "Outstanding queued work."),
|
||||
_item("Failed sources", failed_sources, "bad" if failed_sources else "good", "Sources with failed status or last_error."),
|
||||
_item("Missing GTFS sidecars", missing_gtfs_sidecars, "bad" if missing_gtfs_sidecars else "good", "Active GTFS datasets whose sidecar is unavailable."),
|
||||
_item("Missing OSM sidecars", missing_osm_sidecars, "bad" if missing_osm_sidecars else "good", "Active OSM datasets whose sidecar is unavailable."),
|
||||
],
|
||||
},
|
||||
{
|
||||
"id": "gtfs_validation",
|
||||
"title": "GTFS Validation",
|
||||
"items": [
|
||||
_item("Agencies", gtfs_counts["agencies"], "info", "Imported agency.txt rows."),
|
||||
_item("Stops", gtfs_counts["stops"], "info", "Imported stops."),
|
||||
_item("Routes", gtfs_counts["routes"], "info", "Imported routes."),
|
||||
_item("Trips", gtfs_counts["trips"], "info", "Imported trips."),
|
||||
_item("Shapes", gtfs_counts["shapes"], "info", "Imported shape records."),
|
||||
_item("Stops without coordinates", gtfs_counts["stops_without_coordinates"], "bad" if gtfs_counts["stops_without_coordinates"] else "good", "Stops that cannot be spatially linked or routed."),
|
||||
_item("Routes without geometry", gtfs_counts["routes_without_geometry"], "warn" if gtfs_counts["routes_without_geometry"] else "good", "Routes with no stored GTFS shape geometry."),
|
||||
_item("Routes without agency", gtfs_counts["routes_without_agency"], "warn" if gtfs_counts["routes_without_agency"] else "good", "Routes missing agency/operator references."),
|
||||
_item("Calendar range", gtfs_counts["calendar_range"], "info", "Min/max imported service dates from calendars and exceptions."),
|
||||
],
|
||||
},
|
||||
{
|
||||
"id": "deduplication",
|
||||
"title": "Deduplication and Stop Links",
|
||||
"items": [
|
||||
_item("Canonical stops", link_counts["canonical_stops"], "info", "Current normalized stop/station records."),
|
||||
_item("GTFS stop links", link_counts["gtfs_stop_links"], "good" if link_counts["gtfs_stop_links"] else "warn", "Timetable stops linked into canonical stops."),
|
||||
_item("GTFS stops without canonical link", link_counts["gtfs_stops_without_canonical"], "bad" if link_counts["gtfs_stops_without_canonical"] else "good", "Imported active stops that still need deduplication/linking."),
|
||||
_item("OSM visual stop links", link_counts["osm_stop_links"], "good" if link_counts["osm_stop_links"] else "warn", "OSM stop/station features linked to canonical stops."),
|
||||
_item("OSM stops without canonical link", link_counts["osm_stops_without_canonical"], "warn" if link_counts["osm_stops_without_canonical"] else "good", "Visual stops that are not yet linked to GTFS/canonical stops."),
|
||||
_item("Multi-source stop groups", link_counts["multi_source_stop_groups"], "info", "Canonical stops that merge GTFS stops from multiple datasets."),
|
||||
_item("Long-distance OSM links", link_counts["long_distance_osm_links"], "warn" if link_counts["long_distance_osm_links"] else "good", "OSM stop links over 150m from the canonical stop."),
|
||||
],
|
||||
},
|
||||
{
|
||||
"id": "route_quality",
|
||||
"title": "Route Matching and Geometry",
|
||||
"items": [
|
||||
_item("Matched/accepted routes", route_counts["matched_or_accepted"], "good" if route_counts["matched_or_accepted"] else "warn", "GTFS routes with accepted or automatic OSM matches."),
|
||||
_item("Probable matches", route_counts["probable"], "warn" if route_counts["probable"] else "info", "Potential conflicts needing review."),
|
||||
_item("Weak matches", route_counts["weak"], "warn" if route_counts["weak"] else "good", "Low-confidence route links."),
|
||||
_item("Missing route matches", route_counts["missing"], "bad" if route_counts["missing"] else "good", "Routes with no visual match."),
|
||||
_item("Unreviewed GTFS routes", route_counts["routes_without_match"], "warn" if route_counts["routes_without_match"] else "good", "Active GTFS routes without a RouteMatch row."),
|
||||
_item("Route patterns", route_counts["route_patterns"], "info", "Published visual route-layer patterns."),
|
||||
_item("Route patterns without stops", route_counts["route_patterns_without_stops"], "warn" if route_counts["route_patterns_without_stops"] else "good", "Visual patterns missing canonical stop sequence evidence."),
|
||||
],
|
||||
},
|
||||
{
|
||||
"id": "publication_readiness",
|
||||
"title": "Publication Readiness",
|
||||
"items": [
|
||||
_item("Address index stale", "yes" if address_status.get("stale") else "no", "warn" if address_status.get("stale") else "good", "Address polygons/search index version status."),
|
||||
_item("GTFS licenses unknown", license_unknown, "warn" if license_unknown else "good", "GTFS sources without explicit redistribution/license status."),
|
||||
_item("Canonical export", "draft", "warn", "Canonical Europe dataset export tables/API are not versioned yet."),
|
||||
_item("Third-party API", "later", "info", "Accounts, billing, quotas, and API backend are intentionally out of scope for this step."),
|
||||
],
|
||||
},
|
||||
],
|
||||
"next_actions": [
|
||||
"Add review queues for each non-zero bad/warn metric.",
|
||||
"Persist source authority and redistribution policy before publishing third-party exports.",
|
||||
"Create versioned canonical snapshots and export manifests.",
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def _item(label: str, value: object, tone: str, description: str) -> dict[str, object]:
    """Bundle one metric into the dict shape used by the report sections.

    The keys are always ``label``, ``value``, ``tone`` and ``description``,
    in that order; the arguments are stored unchanged.
    """
    entry: dict[str, object] = dict(
        label=label,
        value=value,
        tone=tone,
        description=description,
    )
    return entry
|
||||
|
||||
|
||||
def _lightweight_address_index_status(session: Session) -> dict[str, object]:
    """Report whether the stored address index version differs from the current one.

    Reads only the ``address_index`` entry of the active routing dataset's
    ``metadata_json``. A missing dataset, empty metadata, unparsable JSON or
    a non-dict ``address_index`` value are all reported as "not stale".

    Returns:
        A dict with ``stale`` (bool), ``version`` (the stored version, or
        ``None``) and ``current_version`` (``ADDRESS_INDEX_VERSION``).
    """
    active = active_routing_dataset(session)
    raw_metadata = None if active is None else active.metadata_json
    if not raw_metadata:
        return {"stale": False, "version": None, "current_version": ADDRESS_INDEX_VERSION}

    try:
        parsed = json.loads(raw_metadata)
    except json.JSONDecodeError:
        parsed = {}

    section = parsed.get("address_index") if isinstance(parsed, dict) else {}
    if not isinstance(section, dict):
        section = {}

    stored_version = section.get("version")
    # An empty section means no index metadata was recorded, which is not "stale".
    is_stale = bool(section) and bool(stored_version != ADDRESS_INDEX_VERSION)
    return {
        "stale": is_stale,
        "version": stored_version,
        "current_version": ADDRESS_INDEX_VERSION,
    }
|
||||
|
||||
|
||||
def _count(session: Session, model, *where) -> int:
    """Return ``COUNT(*)`` for ``model``, restricted by any filter expressions given.

    Args:
        session: Session the count query is executed on.
        model: Mapped class (or selectable) to count rows from.
        *where: Optional filter expressions, all applied in one ``where`` call.

    Returns:
        The row count as an ``int``; a ``None`` scalar result becomes ``0``.
    """
    query = select(func.count()).select_from(model)
    query = query.where(*where) if where else query
    total = session.scalar(query)
    return int(total or 0)
|
||||
|
||||
|
||||
def _priority_catalog_backlog(session: Session) -> int:
    """Count high-priority catalog entries that no ``Source`` row points at.

    "High priority" means a ``priority`` of ``P0``, ``P0 fallback`` or ``P1``.
    An entry counts as linked when some ``Source.catalog_entry_id`` equals its id.
    """
    priority_levels = ["P0", "P0 fallback", "P1"]
    has_registered_source = (
        select(Source.id)
        .where(Source.catalog_entry_id == SourceCatalogEntry.id)
        .exists()
    )
    return _count(
        session,
        SourceCatalogEntry,
        SourceCatalogEntry.priority.in_(priority_levels),
        ~has_registered_source,
    )
|
||||
|
||||
|
||||
def _job_status_counts(session: Session) -> dict[str, int]:
    """Count non-dismissed jobs per status.

    Only jobs whose ``dismissed_at`` is NULL and whose status is one of
    ``queued``, ``running``, ``paused`` or ``failed`` are counted. Statuses
    with no matching rows are absent from the result.
    """
    tracked_statuses = ["queued", "running", "paused", "failed"]
    grouped = (
        select(Job.status, func.count())
        .where(Job.dismissed_at.is_(None), Job.status.in_(tracked_statuses))
        .group_by(Job.status)
    )
    totals: dict[str, int] = {}
    for status, total in session.execute(grouped).all():
        totals[str(status)] = int(total)
    return totals
|
||||
|
||||
|
||||
def _gtfs_validation_counts(session: Session, dataset_ids: list[int]) -> dict[str, object]:
    """Summarize imported GTFS rows and basic validation gaps for some datasets.

    Args:
        session: Session used for every query.
        dataset_ids: Dataset ids the counts are restricted to. An empty list
            short-circuits to all-zero counts and a calendar range of ``"none"``.

    Returns:
        A dict with the row counts (``agencies``, ``stops``, ``routes``,
        ``trips``, ``shapes``), the three ``*_without_*`` gap counts, and
        ``calendar_range`` formatted as ``"<min> -> <max>"`` where a bound
        that could not be determined is rendered as ``unknown``.
    """
    if not dataset_ids:
        return {
            "agencies": 0,
            "stops": 0,
            "routes": 0,
            "trips": 0,
            "shapes": 0,
            "stops_without_coordinates": 0,
            "routes_without_geometry": 0,
            "routes_without_agency": 0,
            "calendar_range": "none",
        }

    def pick_bound(chooser, *candidates):
        # Ignore NULL *and* blank values with one consistent rule. Previously the
        # guard tested truthiness while the filter tested ``is not None``, so a
        # blank date from one table could win ``min()`` and hide the real bound
        # coming from the other table.
        usable = [value for value in candidates if value]
        return chooser(usable) if usable else None

    calendar_min, calendar_max = session.execute(
        select(func.min(GtfsCalendar.start_date), func.max(GtfsCalendar.end_date)).where(
            GtfsCalendar.dataset_id.in_(dataset_ids)
        )
    ).one()
    exception_min, exception_max = session.execute(
        select(func.min(GtfsCalendarDate.date), func.max(GtfsCalendarDate.date)).where(
            GtfsCalendarDate.dataset_id.in_(dataset_ids)
        )
    ).one()
    min_date = pick_bound(min, calendar_min, exception_min)
    max_date = pick_bound(max, calendar_max, exception_max)

    return {
        "agencies": _count(session, GtfsAgency, GtfsAgency.dataset_id.in_(dataset_ids)),
        "stops": _count(session, GtfsStop, GtfsStop.dataset_id.in_(dataset_ids)),
        "routes": _count(session, GtfsRoute, GtfsRoute.dataset_id.in_(dataset_ids)),
        "trips": _count(session, GtfsTrip, GtfsTrip.dataset_id.in_(dataset_ids)),
        "shapes": _count(session, GtfsShape, GtfsShape.dataset_id.in_(dataset_ids)),
        # A stop is unusable spatially as soon as either coordinate is missing.
        "stops_without_coordinates": _count(
            session,
            GtfsStop,
            GtfsStop.dataset_id.in_(dataset_ids),
            (GtfsStop.lat.is_(None)) | (GtfsStop.lon.is_(None)),
        ),
        # NULL and empty-string values both count as "missing" below.
        "routes_without_geometry": _count(
            session,
            GtfsRoute,
            GtfsRoute.dataset_id.in_(dataset_ids),
            (GtfsRoute.geometry_geojson.is_(None)) | (GtfsRoute.geometry_geojson == ""),
        ),
        "routes_without_agency": _count(
            session,
            GtfsRoute,
            GtfsRoute.dataset_id.in_(dataset_ids),
            (GtfsRoute.agency_id.is_(None)) | (GtfsRoute.agency_id == ""),
        ),
        "calendar_range": f"{min_date or 'unknown'} -> {max_date or 'unknown'}",
    }
|
||||
|
||||
|
||||
def _link_quality_counts(session: Session, gtfs_dataset_ids: list[int], osm_dataset_ids: list[int]) -> dict[str, int]:
    """Count canonical-stop links and unlinked stops for the given datasets.

    Args:
        session: Session used for every query.
        gtfs_dataset_ids: GTFS dataset ids; when empty, all GTFS metrics are 0.
        osm_dataset_ids: OSM dataset ids; when empty, all OSM metrics are 0.

    Returns:
        A dict with ``canonical_stops`` (unfiltered total), the GTFS and OSM
        link counts, the counts of stops without a canonical link,
        ``multi_source_stop_groups`` and ``long_distance_osm_links``.
    """
    gtfs_stops_without_canonical = gtfs_stop_links = multi_source_stop_groups = 0
    if gtfs_dataset_ids:
        link_is_gtfs = CanonicalStopLink.object_type == "gtfs_stop"
        link_in_gtfs_scope = CanonicalStopLink.dataset_id.in_(gtfs_dataset_ids)
        stop_has_link = (
            select(CanonicalStopLink.id)
            .where(
                link_is_gtfs,
                CanonicalStopLink.dataset_id == GtfsStop.dataset_id,
                CanonicalStopLink.object_id == GtfsStop.id,
            )
            .exists()
        )
        gtfs_stops_without_canonical = _count(
            session,
            GtfsStop,
            GtfsStop.dataset_id.in_(gtfs_dataset_ids),
            ~stop_has_link,
        )
        gtfs_stop_links = _count(session, CanonicalStopLink, link_is_gtfs, link_in_gtfs_scope)
        # Canonical stops whose GTFS links come from more than one dataset.
        shared_canonical_stops = (
            select(CanonicalStopLink.canonical_stop_id)
            .where(link_is_gtfs, link_in_gtfs_scope)
            .group_by(CanonicalStopLink.canonical_stop_id)
            .having(func.count(func.distinct(CanonicalStopLink.dataset_id)) > 1)
            .subquery()
        )
        multi_source_stop_groups = int(
            session.scalar(select(func.count()).select_from(shared_canonical_stops)) or 0
        )

    osm_stops_without_canonical = osm_stop_links = long_distance_osm_links = 0
    if osm_dataset_ids:
        link_is_osm = CanonicalStopLink.object_type == "osm_feature"
        link_in_osm_scope = CanonicalStopLink.dataset_id.in_(osm_dataset_ids)
        feature_has_link = (
            select(CanonicalStopLink.id)
            .where(
                link_is_osm,
                CanonicalStopLink.dataset_id == OsmFeature.dataset_id,
                CanonicalStopLink.object_id == OsmFeature.id,
            )
            .exists()
        )
        # Only stop-like OSM features are expected to carry a canonical link.
        osm_stops_without_canonical = _count(
            session,
            OsmFeature,
            OsmFeature.dataset_id.in_(osm_dataset_ids),
            OsmFeature.kind.in_(["stop", "station", "terminal"]),
            ~feature_has_link,
        )
        osm_stop_links = _count(session, CanonicalStopLink, link_is_osm, link_in_osm_scope)
        long_distance_osm_links = _count(
            session,
            CanonicalStopLink,
            link_is_osm,
            link_in_osm_scope,
            CanonicalStopLink.distance_m > 150,
        )

    return {
        "canonical_stops": _count(session, CanonicalStop),
        "gtfs_stop_links": gtfs_stop_links,
        "gtfs_stops_without_canonical": gtfs_stops_without_canonical,
        "osm_stop_links": osm_stop_links,
        "osm_stops_without_canonical": osm_stops_without_canonical,
        "multi_source_stop_groups": multi_source_stop_groups,
        "long_distance_osm_links": long_distance_osm_links,
    }
|
||||
|
||||
|
||||
def _route_quality_counts(session: Session, gtfs_dataset_ids: list[int]) -> dict[str, int]:
    """Count route matches by status plus route-pattern coverage.

    The two route-pattern counts are computed over all ``RoutePattern`` rows.
    The match counts are limited to routes in ``gtfs_dataset_ids`` and stay 0
    when that list is empty.
    """
    pattern_total = _count(session, RoutePattern)
    pattern_has_stops = (
        select(RoutePatternStop.id)
        .where(RoutePatternStop.route_pattern_id == RoutePattern.id)
        .exists()
    )
    patterns_missing_stops = _count(session, RoutePattern, ~pattern_has_stops)

    summary: dict[str, int] = {
        "matched_or_accepted": 0,
        "probable": 0,
        "weak": 0,
        "missing": 0,
        "routes_without_match": 0,
        "route_patterns": pattern_total,
        "route_patterns_without_stops": patterns_missing_stops,
    }
    if not gtfs_dataset_ids:
        return summary

    per_status_query = (
        select(RouteMatch.status, func.count())
        .join(GtfsRoute, GtfsRoute.id == RouteMatch.gtfs_route_id)
        .where(GtfsRoute.dataset_id.in_(gtfs_dataset_ids))
        .group_by(RouteMatch.status)
    )
    per_status: dict[str, int] = {}
    for status, total in session.execute(per_status_query).all():
        per_status[str(status)] = int(total)

    route_has_match = select(RouteMatch.id).where(RouteMatch.gtfs_route_id == GtfsRoute.id).exists()
    unmatched_routes = _count(
        session,
        GtfsRoute,
        GtfsRoute.dataset_id.in_(gtfs_dataset_ids),
        ~route_has_match,
    )

    summary["matched_or_accepted"] = per_status.get("matched", 0) + per_status.get("accepted", 0)
    for status_key in ("probable", "weak", "missing"):
        summary[status_key] = per_status.get(status_key, 0)
    summary["routes_without_match"] = unmatched_routes
    return summary
|
||||
911
app/routing.py
Normal file
911
app/routing.py
Normal file
@@ -0,0 +1,911 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import heapq
|
||||
import json
|
||||
import math
|
||||
import threading
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass
|
||||
|
||||
from sqlalchemy import func, select, text
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.config import settings
|
||||
from app.models import Dataset, RoutingEdge, RoutingNode
|
||||
from app.pipeline.routing_layer import active_routing_dataset
|
||||
from app.serializers import feature_collection
|
||||
|
||||
|
||||
# Divisors applied to the straight-line distance to the target when
# route_between_points computes its A* heuristic (one per travel mode).
WALK_HEURISTIC_MPS = 1.6
DRIVE_HEURISTIC_MPS = 36.0
# Default for route_between_points(max_visited=...): cap on expanded graph nodes.
DEFAULT_MAX_VISITED = 160_000
# Bounding-box paddings (km) tried in order by _route_with_pgrouting until a
# route is found.
PGR_WALK_BBOX_PADDING_KM = [0.5, 1.5, 4, 10, 25]
PGR_DRIVE_BBOX_PADDING_KM = [2, 8, 25, 75, 200]
# Values passed to _set_local_statement_timeout before each pgRouting attempt.
PGR_WALK_STATEMENT_TIMEOUT_MS = 2_500
PGR_DRIVE_STATEMENT_TIMEOUT_MS = 7_500
# Route cache settings: entry lifetime added to time.monotonic() on insert,
# and the size above which the oldest entries are evicted.
ROUTE_CACHE_TTL_SECONDS = 15 * 60
ROUTE_CACHE_MAX_ENTRIES = 512
# Lock guarding every read and write of _route_cache.
_route_cache_lock = threading.RLock()
# Maps a cache key to (expiry on the monotonic clock, stored route payload).
_route_cache: OrderedDict[tuple[object, ...], tuple[float, dict[str, object]]] = OrderedDict()
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class _GraphNode:
    """Immutable routing-graph node as returned by the nearest-node lookup."""

    # Identifier used as the node key throughout the search and in pgRouting calls.
    osm_node_id: int
    lon: float
    lat: float
    # Compared against max_distance_m when snapping; presumably the distance in
    # metres from the queried point to this node — confirm in _nearest_node.
    distance_m: float
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class _Traversal:
    """Immutable record of one edge being traversed from ``from_node`` to ``to_node``."""

    edge_id: int
    from_node: int
    to_node: int
    # Coordinates of the traversal's start and end nodes, in travel direction.
    from_lon: float
    from_lat: float
    to_lon: float
    to_lat: float
    # Cost of this traversal; summed into the route's total cost.
    cost_s: float
    length_m: float
    highway: str | None
    name: str | None
    # Edge geometry kept as a GeoJSON string.
    geometry_geojson: str
    # True when the edge is travelled from its stored target to its stored source.
    reversed: bool
|
||||
|
||||
|
||||
def routing_status(db: Session) -> dict[str, object]:
    """Describe the active routing dataset and which routing engine will be used.

    Returns:
        A dict with ``dataset_id`` (``None`` without an active dataset),
        ``nodes``/``edges`` counts, ``available`` (true when there is at least
        one edge), ``engine`` (``"pgrouting"`` when the extension is installed,
        otherwise ``"python_astar"``) and the two pgRouting flags, which are
        only probed on PostgreSQL.
    """
    active = active_routing_dataset(db)
    if active is None:
        active_id, nodes, edges = None, 0, 0
    else:
        active_id = int(active.id)
        nodes, edges = _routing_status_counts(db, active, active_id)

    extension_available = extension_installed = False
    if settings.is_postgresql_database:
        availability_probe = text(
            "SELECT EXISTS (SELECT 1 FROM pg_available_extensions WHERE name = 'pgrouting')"
        )
        extension_available = bool(db.execute(availability_probe).scalar())
        extension_installed = _pgrouting_installed(db)

    engine_name = "pgrouting" if extension_installed else "python_astar"
    return {
        "dataset_id": active_id,
        "nodes": nodes,
        "edges": edges,
        "available": edges > 0,
        "engine": engine_name,
        "pgrouting_available": extension_available,
        "pgrouting_installed": extension_installed,
    }
|
||||
|
||||
|
||||
def _routing_status_counts(db: Session, dataset: Dataset, dataset_id: int) -> tuple[int, int]:
    """Return ``(node_count, edge_count)`` for the routing graph, cheapest source first.

    Sources are tried in this order:

    1. The ``routing_layer`` entry of the dataset metadata, when it holds a
       non-zero ``nodes`` or ``edges`` value.
    2. On PostgreSQL, the planner's ``pg_class.reltuples`` estimates for the
       ``routing_nodes`` and ``routing_edges`` tables (whole-table estimates,
       not limited to ``dataset_id``).
    3. Exact ``COUNT(*)`` queries filtered by ``dataset_id``.

    Args:
        db: Session used for the catalog lookup and the exact counts.
        dataset: Dataset whose ``metadata_json`` is inspected.
        dataset_id: Id used to filter the exact counts.
    """
    metadata = _metadata(dataset)
    routing_layer = metadata.get("routing_layer")
    if isinstance(routing_layer, dict):
        try:
            nodes = int(routing_layer.get("nodes") or 0)
            edges = int(routing_layer.get("edges") or 0)
        except (TypeError, ValueError):
            nodes = 0
            edges = 0
        if nodes or edges:
            return nodes, edges
    if settings.is_postgresql_database:
        rows = db.execute(
            text(
                """
                SELECT relname, COALESCE(reltuples, 0)::bigint AS estimate
                FROM pg_class
                WHERE oid IN ('routing_nodes'::regclass, 'routing_edges'::regclass)
                """
            )
        ).mappings()
        estimates = {str(row["relname"]): int(row["estimate"] or 0) for row in rows}
        node_estimate = estimates.get("routing_nodes", 0)
        edge_estimate = estimates.get("routing_edges", 0)
        # reltuples is -1 while a table has never been vacuumed or analyzed,
        # meaning "unknown". Reporting that as a count would be wrong, so only
        # trust the estimates when both are known; otherwise count exactly.
        if node_estimate >= 0 and edge_estimate >= 0:
            return node_estimate, edge_estimate
    node_count = int(db.scalar(select(func.count()).select_from(RoutingNode).where(RoutingNode.dataset_id == dataset_id)) or 0)
    edge_count = int(db.scalar(select(func.count()).select_from(RoutingEdge).where(RoutingEdge.dataset_id == dataset_id)) or 0)
    return node_count, edge_count
|
||||
|
||||
|
||||
def _metadata(dataset: Dataset) -> dict[str, object]:
    """Decode ``dataset.metadata_json`` into a dict.

    Returns an empty dict when the field is empty, is not valid JSON, or
    decodes to anything other than a JSON object.
    """
    raw = dataset.metadata_json
    if not raw:
        return {}
    try:
        decoded = json.loads(raw)
    except json.JSONDecodeError:
        return {}
    if isinstance(decoded, dict):
        return decoded
    return {}
|
||||
|
||||
|
||||
def route_between_points(
    db: Session,
    *,
    from_lon: float,
    from_lat: float,
    to_lon: float,
    to_lat: float,
    mode: str = "walk",
    dataset_id: int | None = None,
    max_visited: int = DEFAULT_MAX_VISITED,
) -> dict[str, object]:
    """Compute a walk or drive route between two coordinates.

    Results are served from and stored in the in-process route cache. On
    PostgreSQL with pgRouting installed the route is computed in the
    database; if that finds nothing or fails, the search falls back to an
    in-Python A* over edges loaded node by node.

    Args:
        db: Session used for dataset lookup and graph queries.
        from_lon, from_lat: Origin coordinates.
        to_lon, to_lat: Destination coordinates.
        mode: ``"walk"`` or ``"drive"``.
        dataset_id: Routing dataset to use; the active routing dataset when ``None``.
        max_visited: Upper bound on nodes expanded by the A* fallback
            (values below 1 are treated as 1).

    Returns:
        The route payload dict produced by the payload builders.

    Raises:
        ValueError: For an unknown mode, a missing dataset, no nearby graph
            nodes, or when the A* fallback ends without reaching the target.
    """
    if mode not in {"walk", "drive"}:
        raise ValueError("mode must be walk or drive")
    dataset = db.get(Dataset, dataset_id) if dataset_id is not None else active_routing_dataset(db)
    if dataset is None:
        raise ValueError("No routing dataset is available.")
    dataset_id = int(dataset.id)

    cache_key = _route_cache_key(dataset_id, mode, from_lon, from_lat, to_lon, to_lat)
    cached = _route_cache_get(cache_key)
    if cached is not None:
        return cached

    start = _nearest_node(db, dataset_id, from_lon, from_lat, mode)
    target = _nearest_node(db, dataset_id, to_lon, to_lat, mode)
    if start is None or target is None:
        raise ValueError("Routing graph has no nearby nodes for the requested mode.")
    if start.osm_node_id == target.osm_node_id:
        # Both points snap to the same node: there is nothing to search.
        payload = _single_point_route(start, from_lon, from_lat, to_lon, to_lat, mode, dataset_id)
        _route_cache_put(cache_key, payload)
        return payload

    if settings.is_postgresql_database and _pgrouting_installed(db):
        try:
            payload = _route_with_pgrouting(
                db,
                dataset_id=dataset_id,
                mode=mode,
                start=start,
                target=target,
                from_lon=from_lon,
                from_lat=from_lat,
                to_lon=to_lon,
                to_lat=to_lat,
            )
            _route_cache_put(cache_key, payload)
            return payload
        except ValueError:
            # Deliberate best-effort: no pgRouting route, so try the A* fallback.
            pass
        except SQLAlchemyError:
            # Leave the session usable for the fallback queries below.
            db.rollback()

    heuristic_mps = WALK_HEURISTIC_MPS if mode == "walk" else DRIVE_HEURISTIC_MPS
    visit_limit = max(1, max_visited)
    # Heap entries are (estimated total cost, cost so far, node id).
    queue: list[tuple[float, float, int]] = [(0.0, 0.0, start.osm_node_id)]
    costs: dict[int, float] = {start.osm_node_id: 0.0}
    previous: dict[int, tuple[int, _Traversal]] = {}
    visited: set[int] = set()

    while queue and len(visited) < visit_limit:
        _, cost, node_id = heapq.heappop(queue)
        if node_id in visited:
            continue
        visited.add(node_id)
        if node_id == target.osm_node_id:
            payload = _route_payload(
                dataset_id=dataset_id,
                mode=mode,
                start=start,
                target=target,
                from_lon=from_lon,
                from_lat=from_lat,
                to_lon=to_lon,
                to_lat=to_lat,
                previous=previous,
                total_cost_s=cost,
                visited=len(visited),
            )
            _route_cache_put(cache_key, payload)
            return payload
        # Each node is expanded at most once (see ``visited``), so its edges
        # are fetched directly; memoising them could never produce a hit.
        for edge in _outgoing_edges(db, dataset_id, node_id, mode):
            next_cost = cost + edge.cost_s
            if next_cost >= costs.get(edge.to_node, float("inf")):
                continue
            costs[edge.to_node] = next_cost
            previous[edge.to_node] = (node_id, edge)
            heuristic = _distance_m(edge.to_lat, edge.to_lon, target.lat, target.lon) / heuristic_mps
            heapq.heappush(queue, (next_cost + heuristic, next_cost, edge.to_node))

    raise ValueError(f"No {mode} route found within {max_visited:,} visited graph nodes.")
|
||||
|
||||
|
||||
def direct_route_between_points(
    db: Session,
    *,
    from_lon: float,
    from_lat: float,
    to_lon: float,
    to_lat: float,
    mode: str = "walk",
    dataset_id: int | None = None,
    reason: str | None = None,
) -> dict[str, object]:
    """Build a direct (non-graph) route payload between two coordinates.

    Args:
        db: Session used only to resolve the dataset.
        from_lon, from_lat, to_lon, to_lat: Endpoints; coerced to ``float``.
        mode: ``"walk"`` or ``"drive"``.
        dataset_id: Dataset to report; the active routing dataset when ``None``.
            When no dataset can be resolved the payload carries dataset id 0.
        reason: Optional text stored under the payload's ``warning`` key.

    Raises:
        ValueError: If ``mode`` is neither ``walk`` nor ``drive``.
    """
    if mode not in {"walk", "drive"}:
        raise ValueError("mode must be walk or drive")

    if dataset_id is None:
        resolved = active_routing_dataset(db)
    else:
        resolved = db.get(Dataset, dataset_id)
    reported_dataset_id = int(resolved.id) if resolved is not None else 0

    result = _direct_route_payload(
        dataset_id=reported_dataset_id,
        mode=mode,
        from_lon=float(from_lon),
        from_lat=float(from_lat),
        to_lon=float(to_lon),
        to_lat=float(to_lat),
    )
    if reason:
        result["warning"] = reason
    return result
|
||||
|
||||
|
||||
def snap_point_to_routing_graph(
    db: Session,
    *,
    lon: float,
    lat: float,
    mode: str = "walk",
    dataset_id: int | None = None,
    max_distance_m: float = 250,
) -> dict[str, object] | None:
    """Snap a coordinate onto the routing graph.

    On PostgreSQL the point is snapped onto the nearest suitable edge; on
    other databases it is snapped to the nearest node.

    Returns:
        A dict describing the snapped position, or ``None`` when there is no
        dataset or nothing lies within ``max_distance_m``.

    Raises:
        ValueError: If ``mode`` is neither ``walk`` nor ``drive``.
    """
    if mode not in {"walk", "drive"}:
        raise ValueError("mode must be walk or drive")

    if dataset_id is None:
        resolved = active_routing_dataset(db)
    else:
        resolved = db.get(Dataset, dataset_id)
    if resolved is None:
        return None
    dataset_id = int(resolved.id)

    if settings.is_postgresql_database:
        return _snap_point_to_routing_edge_postgresql(
            db,
            dataset_id=dataset_id,
            lon=float(lon),
            lat=float(lat),
            mode=mode,
            max_distance_m=float(max_distance_m),
        )

    nearest = _nearest_node(db, dataset_id, float(lon), float(lat), mode)
    if nearest is None:
        return None
    if nearest.distance_m > max_distance_m:
        return None
    snapped: dict[str, object] = {
        "dataset_id": dataset_id,
        "lon": nearest.lon,
        "lat": nearest.lat,
        "distance_m": round(nearest.distance_m, 1),
        "source": "routing_node",
        "osm_node_id": nearest.osm_node_id,
    }
    return snapped
|
||||
|
||||
|
||||
def _snap_point_to_routing_edge_postgresql(
    db: Session,
    *,
    dataset_id: int,
    lon: float,
    lat: float,
    mode: str,
    max_distance_m: float,
) -> dict[str, object] | None:
    """Snap a point onto the best routing edge within ``max_distance_m`` (PostgreSQL only).

    Candidate edges are first narrowed with a bounding-box overlap test, then
    filtered by true distance. Each edge is treated as the straight line
    between its source and target nodes. Candidates are ordered by distance
    plus a per-highway penalty (applied in walk mode only), then by distance,
    highway rank and id. In walk mode, service edges tagged as driveway,
    parking aisle or drive-through are excluded.

    Args:
        db: Session the query runs on.
        dataset_id: Routing dataset whose edges are searched.
        lon, lat: Point to snap.
        mode: ``"walk"`` or ``"drive"``; selects the cost columns that must be usable.
        max_distance_m: Maximum snapping distance in metres.

    Returns:
        A dict describing the snapped position and edge, or ``None`` when no
        edge qualifies.
    """
    cost_column = "walk_cost_s" if mode == "walk" else "drive_cost_s"
    reverse_cost_column = "reverse_walk_cost_s" if mode == "walk" else "reverse_drive_cost_s"
    # Size the prefilter box per axis. One degree of longitude spans fewer
    # metres as |latitude| grows, so using the latitude radius for both axes
    # would drop edges that really are within max_distance_m to the east or
    # west. The cosine is floored to keep the division finite near the poles.
    radius_lat_deg = max_distance_m / 111_320
    cos_lat = max(math.cos(math.radians(lat)), 0.01)
    radius_lon_deg = min(radius_lat_deg / cos_lat, 180.0)
    row = db.execute(
        text(
            f"""
            WITH point AS (
                SELECT ST_SetSRID(ST_MakePoint(:lon, :lat), 4326) AS geom
            ),
            edges AS MATERIALIZED (
                SELECT
                    edge.id,
                    edge.highway,
                    edge.name,
                    CASE
                        WHEN edge.tags_json IS NULL OR edge.tags_json = '' THEN NULL
                        ELSE edge.tags_json::jsonb ->> 'service'
                    END AS service,
                    edge.source_osm_node_id,
                    edge.target_osm_node_id,
                    ST_SetSRID(
                        ST_MakeLine(
                            ST_MakePoint(edge.source_lon, edge.source_lat),
                            ST_MakePoint(edge.target_lon, edge.target_lat)
                        ),
                        4326
                    ) AS edge_geom
                FROM routing_edges AS edge
                CROSS JOIN point
                WHERE edge.dataset_id = :dataset_id
                  AND (edge.{cost_column} IS NOT NULL OR edge.{reverse_cost_column} IS NOT NULL)
                  AND box(point(edge.max_lon, edge.max_lat), point(edge.min_lon, edge.min_lat))
                      && box(
                          point(:lon + :radius_lon_deg, :lat + :radius_lat_deg),
                          point(:lon - :radius_lon_deg, :lat - :radius_lat_deg)
                      )
            ),
            candidate AS (
                SELECT
                    edges.id,
                    edges.highway,
                    edges.name,
                    edges.service,
                    edges.source_osm_node_id,
                    edges.target_osm_node_id,
                    ST_ClosestPoint(edges.edge_geom, point.geom) AS snapped_geom,
                    ST_DistanceSphere(edges.edge_geom, point.geom) AS distance_m,
                    CASE
                        WHEN edges.highway IN ('footway', 'pedestrian', 'steps') THEN 0
                        WHEN edges.highway IN ('path', 'cycleway', 'bridleway') THEN 1
                        WHEN edges.highway IN ('living_street', 'residential') THEN 2
                        WHEN edges.highway = 'service' THEN 3
                        ELSE 4
                    END AS highway_rank,
                    CASE
                        WHEN :mode != 'walk' THEN 0
                        WHEN edges.highway = 'service' THEN 20
                        WHEN edges.highway IN ('primary', 'primary_link', 'secondary', 'secondary_link') THEN 10
                        WHEN edges.highway IN ('tertiary', 'tertiary_link', 'unclassified', 'road') THEN 5
                        ELSE 0
                    END AS snap_penalty_m
                FROM edges
                CROSS JOIN point
                WHERE ST_DWithin(edges.edge_geom::geography, point.geom::geography, :max_distance_m)
                  AND NOT (
                      :mode = 'walk'
                      AND edges.highway = 'service'
                      AND COALESCE(edges.service, '') IN ('driveway', 'parking_aisle', 'drive-through')
                  )
                ORDER BY
                    ST_DistanceSphere(edges.edge_geom, point.geom) + CASE
                        WHEN :mode != 'walk' THEN 0
                        WHEN edges.highway = 'service' THEN 20
                        WHEN edges.highway IN ('primary', 'primary_link', 'secondary', 'secondary_link') THEN 10
                        WHEN edges.highway IN ('tertiary', 'tertiary_link', 'unclassified', 'road') THEN 5
                        ELSE 0
                    END,
                    ST_DistanceSphere(edges.edge_geom, point.geom),
                    highway_rank,
                    edges.id
                LIMIT 1
            )
            SELECT
                id,
                highway,
                name,
                source_osm_node_id,
                target_osm_node_id,
                ST_X(snapped_geom) AS lon,
                ST_Y(snapped_geom) AS lat,
                distance_m
            FROM candidate
            """
        ),
        {
            "dataset_id": dataset_id,
            "lon": lon,
            "lat": lat,
            "radius_lon_deg": radius_lon_deg,
            "radius_lat_deg": radius_lat_deg,
            "max_distance_m": max_distance_m,
            "mode": mode,
        },
    ).mappings().first()
    if row is None:
        return None
    return {
        "dataset_id": dataset_id,
        "lon": float(row["lon"]),
        "lat": float(row["lat"]),
        "distance_m": round(float(row["distance_m"] or 0), 1),
        "source": "routing_edge",
        "edge_id": int(row["id"]),
        "highway": row["highway"],
        "name": row["name"],
        "source_osm_node_id": int(row["source_osm_node_id"]),
        "target_osm_node_id": int(row["target_osm_node_id"]),
    }
|
||||
|
||||
|
||||
def _route_cache_key(dataset_id: int, mode: str, from_lon: float, from_lat: float, to_lon: float, to_lat: float) -> tuple[object, ...]:
    """Build the route-cache key for a request.

    The key is ``(dataset_id, mode, from_lon, from_lat, to_lon, to_lat)`` with
    the dataset id coerced to ``int`` and each coordinate coerced to ``float``
    and rounded to 6 decimal places.
    """
    dataset_part = int(dataset_id)
    rounded = tuple(
        round(float(coordinate), 6)
        for coordinate in (from_lon, from_lat, to_lon, to_lat)
    )
    return (dataset_part, mode, *rounded)
|
||||
|
||||
|
||||
def _route_cache_get(key: tuple[object, ...]) -> dict[str, object] | None:
    """Return a deep copy of the cached payload for ``key``, or ``None``.

    An entry whose expiry time has been reached is removed and treated as a
    miss. A hit marks the entry as most recently used.
    """
    checked_at = time.monotonic()
    with _route_cache_lock:
        entry = _route_cache.get(key)
        if entry is None:
            return None
        deadline, stored_payload = entry
        if checked_at >= deadline:
            _route_cache.pop(key, None)
            return None
        _route_cache.move_to_end(key)
        # Hand out a copy so callers cannot mutate the cached payload.
        return copy.deepcopy(stored_payload)
|
||||
|
||||
|
||||
def _route_cache_put(key: tuple[object, ...], payload: dict[str, object]) -> None:
    """Store a deep copy of ``payload`` under ``key`` and evict the oldest overflow.

    The entry expires ``ROUTE_CACHE_TTL_SECONDS`` after insertion (monotonic
    clock) and becomes the most recently used one. Entries beyond
    ``ROUTE_CACHE_MAX_ENTRIES`` are dropped from the least recently used end.
    """
    with _route_cache_lock:
        expires_at = time.monotonic() + ROUTE_CACHE_TTL_SECONDS
        snapshot = copy.deepcopy(payload)
        _route_cache[key] = (expires_at, snapshot)
        _route_cache.move_to_end(key)
        overflow = len(_route_cache) - ROUTE_CACHE_MAX_ENTRIES
        for _ in range(max(overflow, 0)):
            _route_cache.popitem(last=False)
|
||||
|
||||
|
||||
def _pgrouting_installed(db: Session) -> bool:
    """Return whether ``pg_extension`` lists the ``pgrouting`` extension."""
    probe = text("SELECT EXISTS (SELECT 1 FROM pg_extension WHERE extname = 'pgrouting')")
    found = db.execute(probe).scalar()
    return bool(found)
|
||||
|
||||
|
||||
def _route_with_pgrouting(
    db: Session,
    *,
    dataset_id: int,
    mode: str,
    start: _GraphNode,
    target: _GraphNode,
    from_lon: float,
    from_lat: float,
    to_lon: float,
    to_lat: float,
) -> dict[str, object]:
    """Route between two graph nodes with ``pgr_dijkstra`` over a bounded edge set.

    The search area is the bounding box of the request and snapped points,
    grown by each padding in the mode's padding list until a route is found.
    The mode's statement timeout is applied before every attempt.

    Returns:
        The payload built by ``_pgrouting_payload`` for the first attempt
        that yields rows.

    Raises:
        ValueError: If no padding produces a route.
    """
    walking = mode == "walk"
    cost_column = "walk_cost_s" if walking else "drive_cost_s"
    reverse_cost_column = "reverse_walk_cost_s" if walking else "reverse_drive_cost_s"
    routing_cost = _routing_cost_expression(cost_column, mode)
    reverse_routing_cost = _routing_cost_expression(reverse_cost_column, mode)
    paddings_km = PGR_WALK_BBOX_PADDING_KM if walking else PGR_DRIVE_BBOX_PADDING_KM
    timeout_ms = PGR_WALK_STATEMENT_TIMEOUT_MS if walking else PGR_DRIVE_STATEMENT_TIMEOUT_MS
    longitudes = (from_lon, to_lon, start.lon, target.lon)
    latitudes = (from_lat, to_lat, start.lat, target.lat)

    # The outer statement is identical for every attempt; only the inner edge
    # query (passed as text to pgr_dijkstra) changes with the bounding box.
    route_statement = text(
        f"""
        WITH route AS (
            SELECT *
            FROM pgr_dijkstra(:edge_sql, :start_node, :target_node, directed := true)
        ),
        steps AS (
            SELECT
                route.path_seq,
                route.node AS from_node,
                LEAD(route.node) OVER (ORDER BY route.path_seq) AS to_node,
                route.edge,
                route.cost
            FROM route
        )
        SELECT
            steps.path_seq,
            steps.from_node,
            steps.to_node,
            steps.cost,
            edge.id,
            edge.source_osm_node_id,
            edge.target_osm_node_id,
            edge.source_lon,
            edge.source_lat,
            edge.target_lon,
            edge.target_lat,
            edge.length_m,
            edge.highway,
            edge.name,
            edge.geometry_geojson,
            CASE
                WHEN steps.from_node = edge.source_osm_node_id THEN edge.{cost_column}
                ELSE edge.{reverse_cost_column}
            END AS actual_cost_s
        FROM steps
        JOIN routing_edges AS edge ON edge.id = steps.edge
        WHERE steps.edge <> -1
        ORDER BY steps.path_seq
        """
    )

    for padding_km in paddings_km:
        _set_local_statement_timeout(db, timeout_ms)
        bbox = _expanded_bbox(
            min(longitudes),
            min(latitudes),
            max(longitudes),
            max(latitudes),
            padding_km,
        )
        # Values interpolated here are an int-coerced id, formatted floats and
        # column names chosen from fixed literals above.
        edge_sql = f"""
            SELECT
                id,
                source_osm_node_id AS source,
                target_osm_node_id AS target,
                COALESCE({routing_cost}, -1)::float8 AS cost,
                COALESCE({reverse_routing_cost}, -1)::float8 AS reverse_cost
            FROM routing_edges
            WHERE dataset_id = {int(dataset_id)}
              AND ({cost_column} IS NOT NULL OR {reverse_cost_column} IS NOT NULL)
              AND box(point(max_lon, max_lat), point(min_lon, min_lat))
                  && box(point({bbox[2]:.8f}, {bbox[3]:.8f}), point({bbox[0]:.8f}, {bbox[1]:.8f}))
        """
        bind_values = {
            "edge_sql": edge_sql,
            "start_node": start.osm_node_id,
            "target_node": target.osm_node_id,
        }
        route_rows = db.execute(route_statement, bind_values).all()
        if not route_rows:
            continue
        return _pgrouting_payload(
            dataset_id=dataset_id,
            mode=mode,
            start=start,
            target=target,
            from_lon=from_lon,
            from_lat=from_lat,
            to_lon=to_lon,
            to_lat=to_lat,
            rows=route_rows,
            padding_km=padding_km,
        )
    raise ValueError("pgRouting did not find a route in the bounded search area.")
|
||||
|
||||
|
||||
def _set_local_statement_timeout(db: Session, timeout_ms: int) -> None:
    """Set ``statement_timeout`` through PostgreSQL's ``set_config``.

    ``timeout_ms`` is coerced to ``int`` and sent as ``"<n>ms"``. The third
    ``set_config`` argument is ``true`` (the "is_local" flag).
    """
    timeout_value = f"{int(timeout_ms)}ms"
    statement = text("SELECT set_config('statement_timeout', :timeout, true)")
    db.execute(statement, {"timeout": timeout_value})
|
||||
|
||||
|
||||
def _pgrouting_payload(
    *,
    dataset_id: int,
    mode: str,
    start: _GraphNode,
    target: _GraphNode,
    from_lon: float,
    from_lat: float,
    to_lon: float,
    to_lat: float,
    rows,
    padding_km: float,
) -> dict[str, object]:
    """Turn pgRouting path rows into the common route payload.

    Each row must expose ``from_node``, ``to_node``, ``cost``,
    ``actual_cost_s`` and the joined edge columns (``id``, endpoint node ids
    and coordinates, ``length_m``, ``highway``, ``name``,
    ``geometry_geojson``). Rows are folded into a predecessor map keyed by
    the node each step arrives at, which ``_route_payload`` then walks back
    from the target.

    Returns the ``_route_payload`` dict with ``engine`` overwritten to
    ``"pgrouting"`` and ``bbox_padding_km`` added.
    """
    predecessors: dict[int, tuple[int, _Traversal]] = {}
    accumulated_cost = 0.0
    for step in rows:
        # A step without a successor node cannot be turned into a traversal.
        if step.to_node is None:
            continue
        origin = int(step.from_node)
        destination = int(step.to_node)
        edge_source = int(step.source_osm_node_id)
        edge_target = int(step.target_osm_node_id)
        # Prefer the per-direction edge cost; fall back to the routed cost.
        if step.actual_cost_s is not None:
            step_cost = float(step.actual_cost_s)
        else:
            step_cost = float(step.cost or 0)
        # The edge was travelled against its stored direction.
        against_stored_direction = origin == edge_target and destination == edge_source
        source_point = (float(step.source_lon), float(step.source_lat))
        target_point = (float(step.target_lon), float(step.target_lat))
        if against_stored_direction:
            departure, arrival = target_point, source_point
        else:
            departure, arrival = source_point, target_point
        accumulated_cost += step_cost
        traversal = _Traversal(
            edge_id=int(step.id),
            from_node=origin,
            to_node=destination,
            from_lon=departure[0],
            from_lat=departure[1],
            to_lon=arrival[0],
            to_lat=arrival[1],
            cost_s=step_cost,
            length_m=float(step.length_m),
            highway=step.highway,
            name=step.name,
            geometry_geojson=str(step.geometry_geojson),
            reversed=against_stored_direction,
        )
        predecessors[destination] = (origin, traversal)
    result = _route_payload(
        dataset_id=dataset_id,
        mode=mode,
        start=start,
        target=target,
        from_lon=from_lon,
        from_lat=from_lat,
        to_lon=to_lon,
        to_lat=to_lat,
        previous=predecessors,
        total_cost_s=accumulated_cost,
        visited=len(rows),
    )
    result["engine"] = "pgrouting"
    result["bbox_padding_km"] = padding_km
    return result
|
||||
|
||||
|
||||
def _routing_cost_expression(column: str, mode: str) -> str:
|
||||
if mode != "walk":
|
||||
return column
|
||||
return f"""
|
||||
CASE
|
||||
WHEN {column} IS NULL THEN NULL
|
||||
ELSE {column} * CASE
|
||||
WHEN highway IN ('footway', 'pedestrian') THEN 0.70
|
||||
WHEN highway = 'path' THEN 0.78
|
||||
WHEN highway = 'steps' THEN 0.95
|
||||
WHEN highway = 'cycleway' THEN 1.05
|
||||
WHEN highway = 'bridleway' THEN 1.10
|
||||
WHEN highway IN ('living_street', 'track') THEN 1.15
|
||||
WHEN highway IN ('residential', 'service') THEN 1.35
|
||||
WHEN highway IN ('unclassified', 'road') THEN 1.55
|
||||
WHEN highway IN ('tertiary', 'tertiary_link') THEN 1.80
|
||||
WHEN highway IN ('secondary', 'secondary_link') THEN 2.15
|
||||
WHEN highway IN ('primary', 'primary_link') THEN 2.50
|
||||
ELSE 1.30
|
||||
END
|
||||
END
|
||||
"""
|
||||
|
||||
|
||||
def _nearest_node(db: Session, dataset_id: int, lon: float, lat: float, mode: str) -> _GraphNode | None:
    """Find the closest routing node that has a usable edge for ``mode``.

    The query takes the ``candidate_limit`` nodes nearest to the point
    (ordered with the ``<->`` operator), keeps those that are the source of
    an edge with a forward cost or the target of an edge with a reverse cost,
    and returns the nearest survivor. The candidate pool is widened
    64 -> 512 -> 4096 until a node is found.

    Returns ``None`` when even the widest pool has no usable node.
    """
    if mode == "walk":
        cost_column, reverse_cost_column = "walk_cost_s", "reverse_walk_cost_s"
    else:
        cost_column, reverse_cost_column = "drive_cost_s", "reverse_drive_cost_s"
    statement = text(
        f"""
        WITH nearest AS MATERIALIZED (
            SELECT node.osm_node_id, node.lon, node.lat, node.geom
            FROM routing_nodes AS node
            WHERE node.dataset_id = :dataset_id
              AND node.geom IS NOT NULL
            ORDER BY node.geom <-> ST_SetSRID(ST_MakePoint(:lon, :lat), 4326)
            LIMIT :candidate_limit
        ),
        candidate AS (
            SELECT nearest.osm_node_id, nearest.lon, nearest.lat, nearest.geom
            FROM nearest
            WHERE EXISTS (
                SELECT 1
                FROM routing_edges AS edge
                WHERE edge.dataset_id = :dataset_id
                  AND (
                    (edge.source_osm_node_id = nearest.osm_node_id AND edge.{cost_column} IS NOT NULL)
                    OR (edge.target_osm_node_id = nearest.osm_node_id AND edge.{reverse_cost_column} IS NOT NULL)
                  )
                LIMIT 1
            )
            ORDER BY nearest.geom <-> ST_SetSRID(ST_MakePoint(:lon, :lat), 4326)
            LIMIT 1
        )
        SELECT osm_node_id, lon, lat, ST_DistanceSphere(geom, ST_SetSRID(ST_MakePoint(:lon, :lat), 4326)) AS distance_m
        FROM candidate
        """
    )
    for candidate_limit in (64, 512, 4096):
        params = {"dataset_id": dataset_id, "lon": lon, "lat": lat, "candidate_limit": candidate_limit}
        found = db.execute(statement, params).first()
        if found is not None:
            return _GraphNode(
                osm_node_id=int(found.osm_node_id),
                lon=float(found.lon),
                lat=float(found.lat),
                distance_m=float(found.distance_m or 0),
            )
    return None
|
||||
|
||||
|
||||
def _outgoing_edges(db: Session, dataset_id: int, node_id: int, mode: str) -> list[_Traversal]:
    """Return every edge traversal that can leave ``node_id`` in ``mode``.

    An edge is usable forwards when the node is its source and the forward
    cost column is set, and backwards when the node is its target and the
    reverse cost column is set.

    The cost and the direction flag come from the same predicate. Previously
    the cost was chosen by ``source = node`` while the direction was chosen by
    ``target != node``; for a self-loop edge (source == target) with a NULL
    forward cost this selected the NULL cost and ``float(None)`` raised.

    Returns a list of ``_Traversal`` objects whose ``from_*``/``to_*``
    coordinates follow the travel direction and whose ``reversed`` flag is
    set for backward traversals.
    """
    cost_column = "walk_cost_s" if mode == "walk" else "drive_cost_s"
    reverse_cost_column = "reverse_walk_cost_s" if mode == "walk" else "reverse_drive_cost_s"
    rows = db.execute(
        text(
            f"""
            SELECT
                id, source_osm_node_id, target_osm_node_id,
                source_lon, source_lat, target_lon, target_lat,
                length_m, highway, name, geometry_geojson,
                CASE
                    WHEN source_osm_node_id = :node_id AND {cost_column} IS NOT NULL THEN {cost_column}
                    ELSE {reverse_cost_column}
                END AS cost_s,
                (source_osm_node_id = :node_id AND {cost_column} IS NOT NULL) AS forward
            FROM routing_edges
            WHERE dataset_id = :dataset_id
              AND (
                (source_osm_node_id = :node_id AND {cost_column} IS NOT NULL)
                OR (target_osm_node_id = :node_id AND {reverse_cost_column} IS NOT NULL)
              )
            """
        ),
        {"dataset_id": dataset_id, "node_id": node_id},
    ).all()
    edges = []
    for row in rows:
        forward = bool(row.forward)
        if forward:
            to_node = int(row.target_osm_node_id)
            from_lon, from_lat = float(row.source_lon), float(row.source_lat)
            to_lon, to_lat = float(row.target_lon), float(row.target_lat)
        else:
            # Travelling against the stored direction: swap the endpoints.
            to_node = int(row.source_osm_node_id)
            from_lon, from_lat = float(row.target_lon), float(row.target_lat)
            to_lon, to_lat = float(row.source_lon), float(row.source_lat)
        edges.append(
            _Traversal(
                edge_id=int(row.id),
                from_node=node_id,
                to_node=to_node,
                from_lon=from_lon,
                from_lat=from_lat,
                to_lon=to_lon,
                to_lat=to_lat,
                cost_s=float(row.cost_s),
                length_m=float(row.length_m),
                highway=row.highway,
                name=row.name,
                geometry_geojson=str(row.geometry_geojson),
                reversed=not forward,
            )
        )
    return edges
|
||||
|
||||
|
||||
def _route_payload(
    *,
    dataset_id: int,
    mode: str,
    start: _GraphNode,
    target: _GraphNode,
    from_lon: float,
    from_lat: float,
    to_lon: float,
    to_lat: float,
    previous: dict[int, tuple[int, _Traversal]],
    total_cost_s: float,
    visited: int,
) -> dict[str, object]:
    """Assemble the route response from a predecessor map.

    ``previous`` maps a node id to ``(prior node id, traversal used to reach
    it)``. The path is rebuilt by walking from the target node back to the
    start node; a missing entry raises ``KeyError``.

    The feature list holds an optional "access" connector (requested origin
    to start node), one feature per traversed edge in travel order, and an
    optional "egress" connector (target node to requested destination).
    Connectors are emitted only when the snap distance is non-zero. The
    duration is the network cost plus connector time for both snap distances.
    """
    path: list[_Traversal] = []
    node_id = target.osm_node_id
    while node_id != start.osm_node_id:
        node_id, traversal = previous[node_id]
        path.append(traversal)
    path.reverse()

    network_distance = sum(traversal.length_m for traversal in path)
    access_distance = start.distance_m + target.distance_m

    features = []
    if start.distance_m:
        access_line = [[from_lon, from_lat], [start.lon, start.lat]]
        features.append(_connector_feature("access", mode, access_line, start.distance_m))
    for sequence, traversal in enumerate(path, start=1):
        geometry = json.loads(traversal.geometry_geojson)
        if traversal.reversed:
            # Stored geometry runs source -> target; flip it to match travel direction.
            geometry["coordinates"] = list(reversed(geometry.get("coordinates", [])))
        properties = {
            "feature_type": "routing_edge",
            "sequence": sequence,
            "mode": mode,
            "edge_id": traversal.edge_id,
            "highway": traversal.highway,
            "name": traversal.name,
            "length_m": traversal.length_m,
            "cost_s": traversal.cost_s,
        }
        features.append({"type": "Feature", "geometry": geometry, "properties": properties})
    if target.distance_m:
        egress_line = [[target.lon, target.lat], [to_lon, to_lat]]
        features.append(_connector_feature("egress", mode, egress_line, target.distance_m))

    duration_seconds = total_cost_s + _connector_seconds(access_distance, mode)
    return {
        "dataset_id": dataset_id,
        "mode": mode,
        "engine": "python_astar",
        "distance_m": round(network_distance + access_distance, 1),
        "network_distance_m": round(network_distance, 1),
        "access_distance_m": round(access_distance, 1),
        "duration_seconds": round(duration_seconds, 1),
        "duration_minutes": _duration_minutes_ceil(duration_seconds),
        "duration_label": _duration_label(duration_seconds),
        "visited_nodes": visited,
        "start_node": {"osm_node_id": start.osm_node_id, "distance_m": round(start.distance_m, 1)},
        "target_node": {"osm_node_id": target.osm_node_id, "distance_m": round(target.distance_m, 1)},
        "features": feature_collection(features),
    }
|
||||
|
||||
|
||||
def _single_point_route(start: _GraphNode, from_lon: float, from_lat: float, to_lon: float, to_lat: float, mode: str, dataset_id: int) -> dict[str, object]:
    """Build a direct-line payload that reports ``start`` as both endpoints.

    Delegates to ``_direct_route_payload`` with the ``python_astar`` engine
    label and a visited-node count of one. The start and target node entries
    are separate dicts with identical content.
    """
    node_summary = {"osm_node_id": start.osm_node_id, "distance_m": round(start.distance_m, 1)}
    return _direct_route_payload(
        dataset_id=dataset_id,
        mode=mode,
        from_lon=from_lon,
        from_lat=from_lat,
        to_lon=to_lon,
        to_lat=to_lat,
        engine="python_astar",
        start_node=node_summary,
        target_node=dict(node_summary),
        visited_nodes=1,
    )
|
||||
|
||||
|
||||
def _direct_route_payload(
    *,
    dataset_id: int,
    mode: str,
    from_lon: float,
    from_lat: float,
    to_lon: float,
    to_lat: float,
    engine: str = "direct_fallback",
    start_node: dict[str, object] | None = None,
    target_node: dict[str, object] | None = None,
    visited_nodes: int = 0,
) -> dict[str, object]:
    """Build a route payload made of one straight connector.

    The distance is the great-circle distance between the two points and the
    duration is the connector travel time for ``mode``. No network edges are
    involved, so ``network_distance_m`` is 0 and the whole distance is
    reported as access distance.
    """
    straight_line_m = _distance_m(from_lat, from_lon, to_lat, to_lon)
    duration_seconds = _connector_seconds(straight_line_m, mode)
    connector = _connector_feature("direct", mode, [[from_lon, from_lat], [to_lon, to_lat]], straight_line_m)
    rounded_distance = round(straight_line_m, 1)
    return {
        "dataset_id": dataset_id,
        "mode": mode,
        "engine": engine,
        "distance_m": rounded_distance,
        "network_distance_m": 0,
        "access_distance_m": rounded_distance,
        "duration_seconds": round(duration_seconds, 1),
        "duration_minutes": _duration_minutes_ceil(duration_seconds),
        "duration_label": _duration_label(duration_seconds),
        "visited_nodes": visited_nodes,
        "start_node": start_node,
        "target_node": target_node,
        "features": feature_collection([connector]),
    }
|
||||
|
||||
|
||||
def _connector_feature(kind: str, mode: str, coordinates: list[list[float]], distance_m: float) -> dict:
    """Return a GeoJSON LineString feature for an off-network connector.

    ``kind`` is stored as the ``connector`` property (callers in this module
    pass "access", "egress" or "direct"); the cost is the connector travel
    time for ``distance_m`` in ``mode``.
    """
    properties = {
        "feature_type": "routing_connector",
        "connector": kind,
        "mode": mode,
        "length_m": distance_m,
        "cost_s": _connector_seconds(distance_m, mode),
    }
    line = {"type": "LineString", "coordinates": coordinates}
    return {"type": "Feature", "geometry": line, "properties": properties}
|
||||
|
||||
|
||||
def _connector_seconds(distance_m: float, mode: str) -> float:
|
||||
speed = 1.35 if mode == "walk" else 8.0
|
||||
return float(distance_m) / speed
|
||||
|
||||
|
||||
def _duration_minutes_ceil(seconds: int | float | None) -> int | None:
|
||||
if seconds is None:
|
||||
return None
|
||||
return max(0, int(math.ceil(float(seconds) / 60)))
|
||||
|
||||
|
||||
def _duration_label(seconds: int | float | None) -> str | None:
    """Format a duration for display, rounded up to whole minutes.

    Yields ``"<d>d HH:MM"`` when at least one day, ``"H:MM"`` when at least
    one hour, otherwise ``"<m> min"``. ``None`` is passed through.
    """
    minutes_total = _duration_minutes_ceil(seconds)
    if minutes_total is None:
        return None
    days, within_day = divmod(minutes_total, 24 * 60)
    hours, minutes = divmod(within_day, 60)
    if days:
        return f"{days}d {hours:02d}:{minutes:02d}"
    if hours:
        return f"{hours}:{minutes:02d}"
    return f"{minutes} min"
|
||||
|
||||
|
||||
def _expanded_bbox(min_lon: float, min_lat: float, max_lon: float, max_lat: float, padding_km: float) -> tuple[float, float, float, float]:
|
||||
mid_lat = (min_lat + max_lat) / 2
|
||||
lat_delta = padding_km / 111.0
|
||||
lon_delta = padding_km / max(1.0, 111.0 * math.cos(math.radians(mid_lat)))
|
||||
return (min_lon - lon_delta, min_lat - lat_delta, max_lon + lon_delta, max_lat + lat_delta)
|
||||
|
||||
|
||||
def _distance_m(lat_a: float, lon_a: float, lat_b: float, lon_b: float) -> float:
|
||||
radius = 6_371_000.0
|
||||
phi_a = math.radians(lat_a)
|
||||
phi_b = math.radians(lat_b)
|
||||
delta_phi = math.radians(lat_b - lat_a)
|
||||
delta_lambda = math.radians(lon_b - lon_a)
|
||||
hav = math.sin(delta_phi / 2) ** 2 + math.cos(phi_a) * math.cos(phi_b) * math.sin(delta_lambda / 2) ** 2
|
||||
return radius * 2 * math.atan2(math.sqrt(hav), math.sqrt(1 - hav))
|
||||
130
app/serializers.py
Normal file
130
app/serializers.py
Normal file
@@ -0,0 +1,130 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any, Iterable
|
||||
|
||||
from app.models import GtfsRoute, GtfsStop, OsmFeature, RouteMatch, RoutePattern
|
||||
from app.osm_storage import osm_feature_public_id
|
||||
|
||||
|
||||
def feature_collection(features: Iterable[dict[str, Any]]) -> dict[str, Any]:
    """Wrap features in a GeoJSON ``FeatureCollection``.

    The iterable is consumed into a new list.
    """
    collection: dict[str, Any] = {"type": "FeatureCollection"}
    collection["features"] = list(features)
    return collection
|
||||
|
||||
|
||||
def gtfs_route_feature(route: GtfsRoute, extra: dict[str, Any] | None = None) -> dict[str, Any] | None:
    """Serialize a GTFS route as a GeoJSON feature.

    Returns ``None`` when the route has no stored geometry. Entries in
    ``extra`` are merged into the properties last, so they override the
    built-in ones.
    """
    if not route.geometry_geojson:
        return None
    properties: dict[str, Any] = dict(
        id=route.id,
        dataset_id=route.dataset_id,
        route_id=route.route_id,
        mode=route.mode,
        route_scope=route.route_scope,
        ref=route.short_name,
        name=route.long_name,
        operator=route.operator_name,
        source="gtfs",
    )
    if extra:
        properties.update(extra)
    geometry = json.loads(route.geometry_geojson)
    return {"type": "Feature", "geometry": geometry, "properties": properties}
|
||||
|
||||
|
||||
def osm_feature_feature(feature: OsmFeature, extra: dict[str, Any] | None = None) -> dict[str, Any] | None:
    """Serialize an OSM feature as a GeoJSON feature.

    Returns ``None`` when the feature has no stored geometry. The ``id``
    property is the public id from ``osm_feature_public_id``; the database
    row id is exposed separately as ``row_id``. Entries in ``extra`` are
    merged last and override the built-in properties.
    """
    if not feature.geometry_geojson:
        return None
    properties: dict[str, Any] = dict(
        id=osm_feature_public_id(feature),
        row_id=feature.id,
        dataset_id=feature.dataset_id,
        osm_type=feature.osm_type,
        osm_id=feature.osm_id,
        kind=feature.kind,
        mode=feature.mode,
        route_scope=feature.route_scope,
        ref=feature.ref,
        name=feature.name,
        operator=feature.operator,
        network=feature.network,
        source="osm",
    )
    if extra:
        properties.update(extra)
    geometry = json.loads(feature.geometry_geojson)
    return {"type": "Feature", "geometry": geometry, "properties": properties}
|
||||
|
||||
|
||||
def route_pattern_feature(pattern: RoutePattern, extra: dict[str, Any] | None = None) -> dict[str, Any] | None:
    """Serialize a route-layer pattern as a GeoJSON feature.

    Returns ``None`` when the pattern has no stored geometry. The pattern id
    is exposed as both ``id`` and ``route_pattern_id``, and the route ref as
    both ``route_ref`` and ``ref``. Entries in ``extra`` are merged last and
    override the built-in properties.
    """
    if not pattern.geometry_geojson:
        return None
    properties: dict[str, Any] = dict(
        id=pattern.id,
        route_pattern_id=pattern.id,
        route_ref=pattern.route_ref,
        ref=pattern.route_ref,
        name=pattern.route_name,
        mode=pattern.mode,
        route_scope=pattern.route_scope,
        operator=pattern.operator_name,
        source="route_layer",
        source_kind=pattern.source_kind,
        status=pattern.status,
        confidence=pattern.confidence,
        osm_feature_id=pattern.osm_feature_id,
        gtfs_route_id=pattern.gtfs_route_id,
        gtfs_shape_id=pattern.gtfs_shape_id,
    )
    if extra:
        properties.update(extra)
    geometry = json.loads(pattern.geometry_geojson)
    return {"type": "Feature", "geometry": geometry, "properties": properties}
|
||||
|
||||
|
||||
def gtfs_stop_feature(stop: GtfsStop) -> dict[str, Any] | None:
    """Serialize a GTFS stop as a GeoJSON point feature.

    Returns ``None`` when either coordinate is missing; a coordinate of 0 is
    a valid value.
    """
    lon, lat = stop.lon, stop.lat
    if lon is None or lat is None:
        return None
    point = {"type": "Point", "coordinates": [lon, lat]}
    properties = {
        "id": stop.id,
        "dataset_id": stop.dataset_id,
        "stop_id": stop.stop_id,
        "name": stop.name,
        "source": "gtfs",
    }
    return {"type": "Feature", "geometry": point, "properties": properties}
|
||||
|
||||
|
||||
def match_row(match: RouteMatch) -> dict[str, Any]:
    """Serialize a GTFS-to-OSM route match as a plain dict.

    The ``osm`` entry is ``None`` when the match has no OSM feature. The
    ``reasons`` entry is the decoded ``reasons_json`` and defaults to an
    empty dict when that column is empty.
    """
    route = match.gtfs_route
    feature = match.osm_feature
    gtfs_side = {
        "id": route.id,
        "dataset_id": route.dataset_id,
        "route_id": route.route_id,
        "mode": route.mode,
        "route_scope": route.route_scope,
        "ref": route.short_name,
        "name": route.long_name,
        "operator": route.operator_name,
    }
    if feature is None:
        osm_side = None
    else:
        osm_side = {
            "id": feature.id,
            "dataset_id": feature.dataset_id,
            "osm_type": feature.osm_type,
            "osm_id": feature.osm_id,
            "mode": feature.mode,
            "route_scope": feature.route_scope,
            "ref": feature.ref,
            "name": feature.name,
            "operator": feature.operator,
            "network": feature.network,
        }
    return {
        "id": match.id,
        "status": match.status,
        "confidence": match.confidence,
        "rule_source": match.rule_source,
        "gtfs": gtfs_side,
        "osm": osm_side,
        "reasons": json.loads(match.reasons_json or "{}"),
    }
|
||||
309
app/source_catalog.py
Normal file
309
app/source_catalog.py
Normal file
@@ -0,0 +1,309 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import csv
|
||||
import hashlib
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Iterable
|
||||
|
||||
from sqlalchemy import func, or_, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models import Source, SourceCatalogEntry
|
||||
|
||||
|
||||
# Source kinds accepted when seeding ingestable sources; CSV rows whose
# lower-cased "kind" is not in this set are counted as skipped.
DIRECT_INGEST_KINDS = {"gtfs", "osm_geojson", "osm_pbf"}
|
||||
|
||||
|
||||
def default_source_catalog_path() -> Path:
    """Return the bundled catalog seed CSV path.

    The path is ``docs/source_catalog_seed.csv`` under the directory two
    levels above this module file.
    """
    project_root = Path(__file__).resolve().parents[1]
    return project_root.joinpath("docs", "source_catalog_seed.csv")
|
||||
|
||||
|
||||
def default_ingestable_sources_path() -> Path:
    """Return the bundled ingestable-sources seed CSV path.

    The path is ``docs/ingestable_sources_seed.csv`` under the directory two
    levels above this module file.
    """
    project_root = Path(__file__).resolve().parents[1]
    return project_root.joinpath("docs", "ingestable_sources_seed.csv")
|
||||
|
||||
|
||||
def import_source_catalog(session: Session, path: Path | str | None = None, *, update_existing: bool = True) -> dict[str, int]:
    """Upsert catalog entries from a seed CSV.

    Rows without a "Source name" are skipped. Every other row is keyed by
    ``_catalog_key``: unknown keys create a ``SourceCatalogEntry``, known keys
    are overwritten (with ``updated_at`` refreshed) unless ``update_existing``
    is false, in which case they count as skipped. The session is flushed but
    not committed.

    Args:
        session: Session used for lookups and inserts.
        path: CSV location; ``None`` selects the bundled default.
        update_existing: Whether rows matching an existing entry overwrite it.

    Returns:
        Counts under ``created``, ``updated`` and ``skipped``.

    Raises:
        FileNotFoundError: If the CSV does not exist.
    """
    # Model attribute -> CSV header, in assignment order.
    column_headers = {
        "geography": "Geography",
        "country_code": "Country code",
        "mode_scope": "Mode scope",
        "source_name": "Source name",
        "source_category": "Source category",
        "formats_apis": "Formats / APIs",
        "availability": "Availability",
        "coverage_notes": "Coverage notes",
        "geometry_notes": "Supersedes OSM for",
        "disruptions_closures": "Disruptions / closures",
        "operator_list_use": "Operator-list use",
        "access_license_notes": "Access / licence notes",
        "priority": "Priority",
        "source_url": "Source URL",
        "evidence_url": "Evidence URL",
        "next_pipeline_action": "Next pipeline action",
    }
    counts = {"created": 0, "updated": 0, "skipped": 0}
    for record in _read_csv(_resolve_path(path, default_source_catalog_path())):
        if not _value(record, "Source name"):
            counts["skipped"] += 1
            continue
        payload = {"catalog_key": _catalog_key(record)}
        payload.update((attribute, _value(record, header)) for attribute, header in column_headers.items())
        existing = session.scalar(select(SourceCatalogEntry).where(SourceCatalogEntry.catalog_key == payload["catalog_key"]))
        if existing is None:
            session.add(SourceCatalogEntry(**payload))
            counts["created"] += 1
        elif not update_existing:
            counts["skipped"] += 1
        else:
            for attribute, new_value in payload.items():
                setattr(existing, attribute, new_value)
            existing.updated_at = datetime.now(timezone.utc)
            counts["updated"] += 1
    session.flush()
    return counts
|
||||
|
||||
|
||||
def import_ingestable_sources(
    session: Session,
    path: Path | str | None = None,
    *,
    update_existing: bool = True,
) -> dict[str, int]:
    """Upsert directly ingestable ``Source`` rows from a seed CSV.

    A row is skipped when it lacks a name or URL, or when its lower-cased
    kind is not in ``DIRECT_INGEST_KINDS``. An existing source is looked up
    by ``(kind, url)`` first and by ``(name, url)`` second. Updated sources
    are also re-enabled. The session is flushed but not committed.

    Args:
        session: Session used for lookups and inserts.
        path: CSV location; ``None`` selects the bundled default.
        update_existing: Whether rows matching an existing source overwrite it.

    Returns:
        Counts under ``created``, ``updated``, ``skipped`` and
        ``linked_catalog`` (created or updated rows tied to a catalog entry).

    Raises:
        FileNotFoundError: If the CSV does not exist.
    """
    counts = {"created": 0, "updated": 0, "skipped": 0, "linked_catalog": 0}
    passthrough_columns = ("country", "license", "priority", "mode_scope", "source_basis", "notes")
    for record in _read_csv(_resolve_path(path, default_ingestable_sources_path())):
        name = _value(record, "name")
        kind = (_value(record, "kind") or "").lower()
        url = _value(record, "url")
        if not (name and url and kind in DIRECT_INGEST_KINDS):
            counts["skipped"] += 1
            continue
        catalog_entry = _catalog_entry_for_ingestable_row(session, record)
        payload = {"name": name, "kind": kind, "url": url}
        payload.update((column, _value(record, column)) for column in passthrough_columns)
        payload["catalog_entry_id"] = catalog_entry.id if catalog_entry is not None else None

        existing = session.scalar(
            select(Source)
            .where(Source.kind == kind, Source.url == url)
            .order_by(Source.id)
            .limit(1)
        )
        if existing is None:
            # Fall back to a name match, e.g. when the stored kind differs.
            existing = session.scalar(select(Source).where(Source.name == name, Source.url == url).order_by(Source.id).limit(1))

        if existing is None:
            session.add(Source(**payload))
            counts["created"] += 1
        elif not update_existing:
            counts["skipped"] += 1
            continue
        else:
            for attribute, new_value in payload.items():
                setattr(existing, attribute, new_value)
            existing.enabled = True
            counts["updated"] += 1
        if catalog_entry is not None:
            counts["linked_catalog"] += 1
    session.flush()
    return counts
|
||||
|
||||
|
||||
def source_catalog_summary(session: Session) -> dict[str, object]:
    """Return headline counts for the source catalog.

    The result holds the number of catalog entries, entry counts grouped by
    priority and by status (missing values grouped under ``"unknown"``), and
    the number of ``Source`` rows that have a ``source_basis`` or a
    ``priority`` set.
    """

    def grouped_counts(column) -> dict:
        grouped = session.execute(select(column, func.count()).group_by(column)).all()
        return {label or "unknown": total for label, total in grouped}

    by_priority = grouped_counts(SourceCatalogEntry.priority)
    by_status = grouped_counts(SourceCatalogEntry.status)
    seeded_filter = Source.source_basis.is_not(None) | Source.priority.is_not(None)
    seeded_sources = session.scalar(select(func.count()).select_from(Source).where(seeded_filter)) or 0
    entry_total = session.scalar(select(func.count()).select_from(SourceCatalogEntry)) or 0
    return {
        "catalog_entries": entry_total,
        "catalog_by_priority": by_priority,
        "catalog_by_status": by_status,
        "seeded_ingestable_sources": seeded_sources,
    }
|
||||
|
||||
|
||||
def source_catalog_rows(
    session: Session,
    *,
    q: str | None = None,
    country: str | None = None,
    priority: str | None = None,
    status: str | None = None,
    limit: int = 100,
) -> list[SourceCatalogEntry]:
    """List catalog entries, optionally filtered.

    Args:
        session: Session used to run the query.
        q: Case-insensitive substring matched against the name, category,
            formats, coverage notes and next pipeline action columns.
        country: Case-insensitive substring of the country code.
        priority: Exact priority, compared after stripping whitespace.
        status: Exact status, compared after stripping whitespace.
        limit: Maximum row count, clamped to the range 1..500.

    Returns:
        Entries ordered by priority, country code, source name and id.
    """
    conditions = []
    if q:
        needle = f"%{q.strip()}%"
        searchable = (
            SourceCatalogEntry.source_name,
            SourceCatalogEntry.source_category,
            SourceCatalogEntry.formats_apis,
            SourceCatalogEntry.coverage_notes,
            SourceCatalogEntry.next_pipeline_action,
        )
        conditions.append(or_(*(column.ilike(needle) for column in searchable)))
    if country:
        conditions.append(SourceCatalogEntry.country_code.ilike(f"%{country.strip()}%"))
    if priority:
        conditions.append(SourceCatalogEntry.priority == priority.strip())
    if status:
        conditions.append(SourceCatalogEntry.status == status.strip())

    stmt = select(SourceCatalogEntry).order_by(
        SourceCatalogEntry.priority,
        SourceCatalogEntry.country_code,
        SourceCatalogEntry.source_name,
        SourceCatalogEntry.id,
    )
    for condition in conditions:
        stmt = stmt.where(condition)
    row_cap = max(1, min(limit, 500))
    return session.scalars(stmt.limit(row_cap)).all()
|
||||
|
||||
|
||||
def catalog_entry_payload(entry: SourceCatalogEntry, *, linked_source_count: int = 0) -> dict[str, object]:
    """Serialize a catalog entry as a JSON-ready dict.

    Column values are copied as-is, ``linked_source_count`` is taken from the
    argument, and the two timestamps are ISO-8601 strings, or ``None`` when
    unset.
    """
    copied_fields = (
        "id",
        "geography",
        "country_code",
        "mode_scope",
        "source_name",
        "source_category",
        "formats_apis",
        "availability",
        "coverage_notes",
        "geometry_notes",
        "disruptions_closures",
        "operator_list_use",
        "access_license_notes",
        "priority",
        "source_url",
        "evidence_url",
        "next_pipeline_action",
        "status",
    )
    payload: dict[str, object] = {field: getattr(entry, field) for field in copied_fields}
    payload["linked_source_count"] = linked_source_count
    payload["created_at"] = entry.created_at.isoformat() if entry.created_at else None
    payload["updated_at"] = entry.updated_at.isoformat() if entry.updated_at else None
    return payload
|
||||
|
||||
|
||||
def linked_source_counts(session: Session, entries: Iterable[SourceCatalogEntry]) -> dict[int, int]:
    """Count the sources linked to each of the given catalog entries.

    Returns a mapping of catalog entry id to linked ``Source`` count. Entries
    with no linked source are absent, and no query runs for an empty input.
    """
    wanted_ids = [entry.id for entry in entries]
    if not wanted_ids:
        return {}
    per_entry = session.execute(
        select(Source.catalog_entry_id, func.count())
        .where(Source.catalog_entry_id.in_(wanted_ids))
        .group_by(Source.catalog_entry_id)
    ).all()
    counts: dict[int, int] = {}
    for entry_id, total in per_entry:
        if entry_id is not None:
            counts[entry_id] = total
    return counts
|
||||
|
||||
|
||||
def _catalog_entry_for_ingestable_row(session: Session, row: dict[str, str]) -> SourceCatalogEntry | None:
    """Pick the catalog entry an ingestable-source CSV row most likely belongs to.

    A case-insensitive exact match on the source name wins. Otherwise the
    first entry (ordered by priority, then id) is returned that matches ANY
    of: the country code, a ``source_basis`` token found in the entry's name
    or coverage notes, or the first word of the row's name when that word is
    longer than two characters.

    Returns ``None`` when the row has no country, basis or name, or when no
    fuzzy clause could be built, or when nothing matches.
    """
    country = _value(row, "country")
    source_basis = _value(row, "source_basis")
    name = _value(row, "name")
    if not (country or source_basis or name):
        return None

    if name:
        exact = session.scalar(
            select(SourceCatalogEntry)
            .where(func.lower(SourceCatalogEntry.source_name) == name.lower())
            .order_by(SourceCatalogEntry.id)
            .limit(1)
        )
        if exact is not None:
            return exact

    clauses = []
    if country:
        clauses.append(SourceCatalogEntry.country_code.ilike(f"%{country}%"))
    if source_basis:
        for token in _basis_tokens(source_basis):
            needle = f"%{token}%"
            clauses.extend(
                (
                    SourceCatalogEntry.source_name.ilike(needle),
                    SourceCatalogEntry.coverage_notes.ilike(needle),
                )
            )
    if name:
        leading_word = name.split()[0]
        if len(leading_word) > 2:
            clauses.append(SourceCatalogEntry.source_name.ilike(f"%{leading_word}%"))
    if not clauses:
        return None

    fuzzy = select(SourceCatalogEntry).where(or_(*clauses)).order_by(SourceCatalogEntry.priority, SourceCatalogEntry.id).limit(1)
    return session.scalar(fuzzy)
|
||||
|
||||
|
||||
def _basis_tokens(value: str) -> list[str]:
|
||||
tokens = []
|
||||
for raw in value.replace("/", " ").replace("-", " ").split():
|
||||
token = raw.strip(" ,.;()")
|
||||
if len(token) >= 5 and token.lower() not in {"official", "mirror", "feeds", "transport"}:
|
||||
tokens.append(token)
|
||||
return tokens[:4]
|
||||
|
||||
|
||||
def _catalog_key(row: dict[str, str]) -> str:
    """Derive the stable identity hash for a catalog CSV row.

    The key is the SHA-256 hex digest of the lower-cased, ``|``-joined
    non-empty values of "Country code", "Source name", "Source URL" and
    "Formats / APIs". When all four are empty the digest is taken over the
    ``repr`` of the row's sorted items instead.
    """
    identity_columns = ("Country code", "Source name", "Source URL", "Formats / APIs")
    present = [cell.lower() for cell in (_value(row, column) for column in identity_columns) if cell]
    digest_input = "|".join(present) or repr(sorted(row.items()))
    return hashlib.sha256(digest_input.encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
def _read_csv(path: Path) -> list[dict[str, str]]:
|
||||
if not path.exists():
|
||||
raise FileNotFoundError(path)
|
||||
with path.open("r", encoding="utf-8-sig", newline="") as handle:
|
||||
reader = csv.DictReader(handle)
|
||||
return [dict(row) for row in reader]
|
||||
|
||||
|
||||
def _resolve_path(path: Path | str | None, default_path: Path) -> Path:
|
||||
if path is None:
|
||||
return default_path
|
||||
candidate = Path(path)
|
||||
if candidate.is_absolute():
|
||||
return candidate
|
||||
return Path.cwd() / candidate
|
||||
|
||||
|
||||
def _value(row: dict[str, str], key: str) -> str | None:
|
||||
value = row.get(key)
|
||||
if value is None:
|
||||
return None
|
||||
stripped = value.strip()
|
||||
return stripped or None
|
||||
256
app/source_updates.py
Normal file
256
app/source_updates.py
Normal file
@@ -0,0 +1,256 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import requests
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.config import settings
|
||||
from app.models import Dataset, Source, SourceUpdateCheck
|
||||
from app.pipeline.utils import norm_text, sha256_file
|
||||
|
||||
|
||||
def check_source_for_update(session: Session, source: Source) -> SourceUpdateCheck:
    """Record one update check for ``source`` and refresh its status.

    The newest active dataset of the source is compared with freshly gathered
    source metadata via ``_update_decision``. A ``SourceUpdateCheck`` row is
    added and flushed (not committed), and ``source.status`` becomes
    ``update_check_error``, ``update_available`` or ``up_to_date``.
    ``source.last_error`` is cleared when the check status is ``"checked"``
    and otherwise set to the decision reason.

    Returns:
        The new, flushed ``SourceUpdateCheck``.
    """
    active_dataset = session.scalar(
        select(Dataset)
        .where(Dataset.source_id == source.id, Dataset.is_active.is_(True))
        .order_by(Dataset.created_at.desc(), Dataset.id.desc())
        # Only the newest row is used, so fetch just that one.
        .limit(1)
    )
    # NOTE(review): recovery runs before the metadata probe and is presumed to
    # be able to rewrite source.url — confirm against the helper.
    recovery = _recover_missing_managed_cache_url(source)
    remote = _source_remote_metadata(source)
    if recovery is not None:
        remote["recovered_source_url"] = recovery["url"]
        remote["previous_source_url"] = recovery["previous_url"]
    update_available, reason = _update_decision(active_dataset, remote)
    check = SourceUpdateCheck(
        source_id=source.id,
        status=remote["status"],
        update_available=update_available,
        reason=reason,
        remote_url=source.url,
        etag=remote.get("etag"),
        last_modified=remote.get("last_modified"),
        content_length=remote.get("content_length"),
        content_type=remote.get("content_type"),
        local_mtime=remote.get("local_mtime"),
        local_size=remote.get("local_size"),
        local_sha256=remote.get("local_sha256"),
        active_dataset_id=None if active_dataset is None else active_dataset.id,
        active_dataset_sha256=None if active_dataset is None else active_dataset.sha256,
        metadata_json=json.dumps(remote, separators=(",", ":"), default=_json_default),
    )
    session.add(check)
    check_succeeded = remote["status"] == "checked"
    if not check_succeeded:
        source.status = "update_check_error"
    elif update_available:
        source.status = "update_available"
    else:
        source.status = "up_to_date"
    source.last_error = None if check_succeeded else reason
    session.flush()
    return check
|
||||
|
||||
|
||||
def latest_source_update_check(session: Session, source_id: int) -> SourceUpdateCheck | None:
    """Return the most recent update check for a source, or ``None`` if none exist.

    "Most recent" is the highest ``checked_at``, with ties broken by the
    highest id.
    """
    newest_first = (
        select(SourceUpdateCheck)
        .where(SourceUpdateCheck.source_id == source_id)
        .order_by(SourceUpdateCheck.checked_at.desc(), SourceUpdateCheck.id.desc())
        # scalar() consumes only the first row; without a limit the database
        # would return the source's whole check history for every call.
        .limit(1)
    )
    return session.scalar(newest_first)
|
||||
|
||||
|
||||
def update_check_payload(check: SourceUpdateCheck | None) -> dict | None:
    """Convert an update check into a plain dict, or return ``None`` for ``None``.

    Timestamps are rendered with ``isoformat()`` (``None`` when unset) and
    ``metadata_json`` is decoded into the ``"metadata"`` entry; undecodable
    JSON yields an empty dict there.
    """
    if check is None:
        return None

    def _iso(moment):
        return moment.isoformat() if moment else None

    try:
        metadata = json.loads(check.metadata_json or "{}")
    except json.JSONDecodeError:
        metadata = {}

    # Key order mirrors the record layout; keep it stable for serialised output.
    payload = {
        "id": check.id,
        "source_id": check.source_id,
        "checked_at": _iso(check.checked_at),
    }
    payload.update(
        (field, getattr(check, field))
        for field in (
            "status",
            "update_available",
            "reason",
            "etag",
            "last_modified",
            "content_length",
            "content_type",
        )
    )
    payload["local_mtime"] = _iso(check.local_mtime)
    payload.update(
        (field, getattr(check, field))
        for field in ("local_size", "local_sha256", "active_dataset_id", "active_dataset_sha256")
    )
    payload["metadata"] = metadata
    return payload
|
||||
|
||||
|
||||
def record_dataset_update_metadata(dataset: Dataset, check: SourceUpdateCheck | None) -> None:
    """Store a snapshot of ``check`` under ``"source_update_check"`` in the dataset metadata.

    Does nothing when ``check`` is ``None``.  Existing ``metadata_json`` that is
    not valid JSON, or whose top-level value is not an object, is replaced by a
    fresh object so the snapshot can always be written.
    """
    if check is None:
        return
    try:
        metadata = json.loads(dataset.metadata_json or "{}")
    except json.JSONDecodeError:
        metadata = {}
    if not isinstance(metadata, dict):
        # e.g. "null" or "[]": the item assignment below would raise TypeError.
        metadata = {}
    payload = update_check_payload(check)
    metadata["source_update_check"] = {
        "id": check.id,
        "checked_at": check.checked_at.isoformat() if check.checked_at else None,
        "etag": check.etag,
        "last_modified": check.last_modified,
        "content_length": check.content_length,
        "content_type": check.content_type,
        "local_mtime": check.local_mtime.isoformat() if check.local_mtime else None,
        "local_size": check.local_size,
        "local_sha256": check.local_sha256,
        "metadata": payload.get("metadata", {}),
    }
    dataset.metadata_json = json.dumps(metadata, indent=2, default=_json_default)
|
||||
|
||||
|
||||
def _source_remote_metadata(source: Source) -> dict:
    """Collect update metadata for ``source.url``.

    ``http``/``https`` URLs are probed over the network; anything else is
    treated as a local file (the path component for ``file:`` URLs, the raw
    string otherwise).
    """
    location = urlparse(source.url)
    if location.scheme not in {"http", "https"}:
        if location.scheme == "file":
            local_path = Path(location.path)
        else:
            local_path = Path(source.url)
        return _local_metadata(local_path)
    return _http_metadata(source.url)
|
||||
|
||||
|
||||
def _recover_missing_managed_cache_url(source: Source) -> dict | None:
    """Point ``source`` back at a seed URL when its managed cache file is gone.

    Returns ``{"previous_url": ..., "url": ...}`` after rewriting ``source.url``,
    or ``None`` when the URL is remote, the file still exists, the path is not a
    managed cache path for this source, or no seed URL matches.
    """
    location = urlparse(source.url)
    if location.scheme in {"http", "https"}:
        return None

    if location.scheme == "file":
        path = Path(location.path)
    else:
        path = Path(source.url)

    if path.exists():
        return None
    if not _is_managed_source_cache_path(path, source.id):
        return None

    replacement = _seed_source_url_for(source)
    if replacement is None:
        return None

    recovered = {"previous_url": source.url, "url": replacement}
    source.url = replacement
    return recovered
|
||||
|
||||
|
||||
def _is_managed_source_cache_path(path: Path, source_id: int) -> bool:
    """Tell whether ``path`` lies in the cache directory of source ``source_id``.

    A path qualifies when it resolves below ``<data_dir>/sources/source_<id>``
    or, failing that, when its components contain ``sources`` immediately
    followed by ``source_<id>``.
    """
    source_dir = f"source_{source_id}"
    try:
        resolved = path.resolve()
        managed_dir = (settings.data_dir / "sources" / source_dir).resolve()
        resolved.relative_to(managed_dir)
    except ValueError:
        # Not below the configured data dir: fall back to a structural match.
        segments = path.parts
        return any(
            parent == "sources" and child == source_dir
            for parent, child in zip(segments, segments[1:])
        )
    return True
|
||||
|
||||
|
||||
def _seed_source_url_for(source: Source) -> str | None:
    """Look up an http(s) seed URL for ``source`` in ``scripts/example_sources.json``.

    A seed row matches when its ``kind`` equals ``source.kind``, its country
    does not contradict ``source.country`` (only compared when both are set),
    and the normalised name tokens of one side are a subset of the other's.
    Returns the first matching URL, or ``None`` when the seed file is missing
    or unreadable, is not a JSON list, or nothing matches.
    """
    seed_path = Path(__file__).resolve().parents[1] / "scripts" / "example_sources.json"
    if not seed_path.exists():
        return None
    try:
        rows = json.loads(seed_path.read_text(encoding="utf-8"))
    except (OSError, ValueError):
        # ValueError covers JSONDecodeError as well as UnicodeDecodeError.
        return None
    if not isinstance(rows, list):
        return None

    source_tokens = set(norm_text(source.name).split())
    if not source_tokens:
        # An empty set is a subset of every row, which would otherwise match
        # the first seed of the right kind regardless of its name.
        return None

    for row in rows:
        if not isinstance(row, dict):
            continue
        url = str(row.get("url") or "")
        if urlparse(url).scheme not in {"http", "https"}:
            continue
        if row.get("kind") != source.kind:
            continue
        if source.country and row.get("country") and str(row.get("country")) != source.country:
            continue
        row_tokens = set(norm_text(row.get("name")).split())
        if row_tokens and (row_tokens <= source_tokens or source_tokens <= row_tokens):
            return url
    return None
|
||||
|
||||
|
||||
def _http_metadata(url: str) -> dict:
    """Fetch response headers for ``url`` and summarise them for an update check.

    A HEAD request is tried first; when the server answers 405 or 501 a
    streamed GET is issued instead.  Any failure is returned as
    ``{"status": "error", "error": ...}`` rather than raised.
    """
    response = None
    try:
        response = requests.head(url, allow_redirects=True, timeout=30)
        if response.status_code in {405, 501}:
            response.close()
            response = requests.get(url, stream=True, timeout=30)
        response.raise_for_status()
    except Exception as exc:  # noqa: BLE001 - persisted as update-check status
        return {"status": "error", "error": str(exc)}
    finally:
        if response is not None:
            response.close()

    headers = response.headers
    content_type = headers.get("Content-Type")
    raw_length = headers.get("Content-Length")
    if raw_length and raw_length.isdigit():
        content_length = int(raw_length)
    else:
        content_length = None

    return {
        "status": "checked",
        "etag": headers.get("ETag"),
        "last_modified": headers.get("Last-Modified"),
        "content_length": content_length,
        "content_type": content_type,
        "final_url": response.url,
        "update_artifact": _update_artifact(url, content_type),
    }
|
||||
|
||||
|
||||
def _local_metadata(path: Path) -> dict:
    """Describe a local source file for an update check.

    Returns a ``"checked"`` dict with the file's UTC mtime, size and SHA-256,
    or an ``"error"`` dict when the file is missing or cannot be read, so that
    local failures are reported the same way as failed HTTP probes.
    """
    if not path.exists():
        return {"status": "error", "error": f"Source file does not exist: {path}"}
    try:
        stat = path.stat()
        digest = sha256_file(path)
    except OSError as exc:
        # Directory, permission problem, or file removed after the exists() check.
        return {"status": "error", "error": f"Source file could not be read: {path}: {exc}"}
    return {
        "status": "checked",
        "local_mtime": datetime.fromtimestamp(stat.st_mtime, tz=timezone.utc),
        "local_size": stat.st_size,
        "local_sha256": digest,
        "update_artifact": _update_artifact(str(path), None),
    }
|
||||
|
||||
|
||||
def _update_decision(active_dataset: Dataset | None, remote: dict) -> tuple[bool, str]:
    """Decide whether ``remote`` indicates newer data than ``active_dataset``.

    Returns ``(update_available, reason)``.  A failed probe never reports an
    update; a source without an active dataset always does.  Local files are
    compared by SHA-256; remote ones by the ``etag``, ``last_modified`` and
    ``content_length`` values recorded with the active dataset.
    """
    if remote["status"] != "checked":
        return False, remote.get("error") or "update check failed"
    if active_dataset is None:
        return True, "no active dataset imported"

    local_hash = remote.get("local_sha256")
    if local_hash:
        if local_hash == active_dataset.sha256:
            return False, "local file hash matches active dataset"
        return True, "local file hash differs from active dataset"

    previous = _dataset_update_metadata(active_dataset)
    compared_any = False
    for key in ("etag", "last_modified", "content_length"):
        current, old = remote.get(key), previous.get(key)
        if current is None or old is None:
            # Only values known on both sides can be compared.
            continue
        compared_any = True
        if str(current) != str(old):
            return True, f"remote {key} changed"

    if not compared_any:
        return True, "no previous remote metadata recorded"
    return False, "remote metadata matches active dataset"
|
||||
|
||||
|
||||
def _dataset_update_metadata(dataset: Dataset) -> dict:
    """Return the ``"source_update_check"`` object stored in the dataset metadata.

    Always returns a dict: an empty one when ``metadata_json`` is empty, is not
    valid JSON, is not a JSON object, or carries no object under that key.
    """
    try:
        metadata = json.loads(dataset.metadata_json or "{}")
    except json.JSONDecodeError:
        return {}
    if not isinstance(metadata, dict):
        # "null", "[]", ... have no .get(); treat them as "nothing recorded".
        return {}
    recorded = metadata.get("source_update_check")
    # Callers call .get() on the result, so never hand back a non-dict value.
    return recorded if isinstance(recorded, dict) else {}
|
||||
|
||||
|
||||
def _json_default(value):
    """``json.dumps`` fallback: serialise ``datetime`` as ISO-8601, reject anything else."""
    if not isinstance(value, datetime):
        raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")
    return value.isoformat()
|
||||
|
||||
|
||||
def _update_artifact(url_or_path: str, content_type: str | None) -> dict:
    """Classify what kind of artifact a source location delivers.

    ``kind`` is ``"osm_diff"`` for ``.osc`` / ``.osc.gz`` locations,
    ``"gtfs_or_archive"`` for ``.zip`` locations or zip media types, and
    ``"full_snapshot"`` otherwise.  For http(s) URLs the suffix is also tested
    on the path alone, so a query string or fragment does not hide it, and
    media-type parameters (``"; charset=..."``) are ignored.  The returned
    ``content_type`` is the value passed in, unmodified.
    """
    candidates = [url_or_path.lower()]
    parsed = urlparse(url_or_path)
    if parsed.scheme in {"http", "https"}:
        # "feed.zip?token=..." should still count as a zip download.
        candidates.append(parsed.path.lower())

    # "application/zip; charset=binary" -> "application/zip"
    media_type = (content_type or "").split(";", 1)[0].strip().lower()

    is_osm_diff = any(name.endswith((".osc", ".osc.gz")) for name in candidates)
    is_gtfs_zip = any(name.endswith(".zip") for name in candidates) or media_type in {
        "application/zip",
        "application/x-zip-compressed",
    }
    return {
        "kind": "osm_diff" if is_osm_diff else "gtfs_or_archive" if is_gtfs_zip else "full_snapshot",
        "is_diff": is_osm_diff,
        "content_type": content_type,
    }
|
||||
158
app/spatial.py
Normal file
158
app/spatial.py
Normal file
@@ -0,0 +1,158 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterable
|
||||
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.config import settings
|
||||
|
||||
|
||||
# Tables whose PostGIS geometry columns refresh_postgis_geometries() may rebuild.
# It doubles as the allow-list for the table names accepted by that function.
POSTGIS_GEOMETRY_TABLES = {
    "canonical_stops",
    "gtfs_routes",
    "gtfs_shapes",
    "gtfs_stops",
    "osm_addresses",
    "osm_features",
    "route_patterns",
    "routing_edges",
    "routing_nodes",
}
|
||||
|
||||
|
||||
def using_postgresql() -> bool:
    """Return ``settings.is_postgresql_database``; the helpers in this module are no-ops when it is false."""
    return settings.is_postgresql_database
|
||||
|
||||
|
||||
def refresh_postgis_geometries(
    session: Session,
    *,
    dataset_id: int | None = None,
    tables: Iterable[str] | None = None,
    only_missing: bool = True,
) -> None:
    """Rebuild PostGIS geometry columns from the stored GeoJSON / coordinates.

    Does nothing unless PostgreSQL is in use.  ``tables`` limits the refresh to
    the named tables (all of ``POSTGIS_GEOMETRY_TABLES`` when empty or
    ``None``); ``dataset_id`` restricts it to one dataset, except for
    ``route_patterns`` and ``canonical_stops``, which are always refreshed
    without a dataset filter.  With ``only_missing`` only rows whose geometry
    is still NULL are touched.

    Raises:
        ValueError: if ``tables`` names a table outside ``POSTGIS_GEOMETRY_TABLES``.
    """
    if not using_postgresql():
        return
    selected = set(tables or POSTGIS_GEOMETRY_TABLES)
    unknown = selected - POSTGIS_GEOMETRY_TABLES
    if unknown:
        raise ValueError(f"Unsupported PostGIS geometry table(s): {', '.join(sorted(unknown))}")

    per_dataset = {"dataset_id": dataset_id, "only_missing": only_missing}
    all_rows = {"dataset_id": None, "only_missing": only_missing}
    # (table, refresher, extra positional args, keyword options), in execution order.
    steps = (
        ("osm_features", _refresh_geojson_geometry, ("osm_features",), per_dataset),
        ("gtfs_routes", _refresh_geojson_geometry, ("gtfs_routes",), per_dataset),
        ("gtfs_shapes", _refresh_geojson_geometry, ("gtfs_shapes",), per_dataset),
        ("route_patterns", _refresh_geojson_geometry, ("route_patterns",), all_rows),
        ("osm_addresses", _refresh_address_geometry, (), per_dataset),
        ("gtfs_stops", _refresh_point_geometry, ("gtfs_stops",), per_dataset),
        ("canonical_stops", _refresh_point_geometry, ("canonical_stops",), all_rows),
        ("routing_nodes", _refresh_point_geometry, ("routing_nodes",), per_dataset),
        ("routing_edges", _refresh_routing_edge_geometry, (), per_dataset),
    )
    for table, refresh, table_args, options in steps:
        if table in selected:
            refresh(session, *table_args, **options)
|
||||
|
||||
|
||||
def analyze_postgresql_tables(session: Session, tables: Iterable[str]) -> None:
    """Run ``ANALYZE`` on each named table when PostgreSQL is in use.

    Table names are interpolated into the SQL text verbatim, so callers must
    pass trusted identifiers only.
    """
    if using_postgresql():
        for table_name in tables:
            statement = text(f"ANALYZE {table_name}")
            session.execute(statement)
|
||||
|
||||
|
||||
def _refresh_geojson_geometry(session: Session, table: str, *, dataset_id: int | None, only_missing: bool) -> None:
    """Set ``geom`` of ``table`` from its non-empty ``geometry_geojson`` column (SRID 4326).

    ``dataset_id`` (when given) and ``only_missing`` narrow the rows updated.
    """
    params: dict[str, object] = {}
    conditions = ["geometry_geojson IS NOT NULL", "geometry_geojson <> ''"]
    if dataset_id is not None:
        params["dataset_id"] = int(dataset_id)
        conditions.append("dataset_id = :dataset_id")
    if only_missing:
        conditions.append("geom IS NULL")
    where_clause = " AND ".join(conditions)
    statement = text(
        f"""
        UPDATE {table}
        SET geom = ST_SetSRID(ST_GeomFromGeoJSON(geometry_geojson), 4326)
        WHERE {where_clause}
        """
    )
    session.execute(statement, params)
|
||||
|
||||
|
||||
def _refresh_point_geometry(session: Session, table: str, *, dataset_id: int | None, only_missing: bool) -> None:
    """Set ``geom`` of ``table`` to a point built from its ``lon`` / ``lat`` columns (SRID 4326).

    Rows lacking either coordinate are skipped; ``dataset_id`` (when given)
    and ``only_missing`` narrow the rows updated.
    """
    params: dict[str, object] = {}
    conditions = ["lon IS NOT NULL", "lat IS NOT NULL"]
    if dataset_id is not None:
        params["dataset_id"] = int(dataset_id)
        conditions.append("dataset_id = :dataset_id")
    if only_missing:
        conditions.append("geom IS NULL")
    where_clause = " AND ".join(conditions)
    statement = text(
        f"""
        UPDATE {table}
        SET geom = ST_SetSRID(ST_MakePoint(lon, lat), 4326)
        WHERE {where_clause}
        """
    )
    session.execute(statement, params)
|
||||
|
||||
|
||||
def _refresh_address_geometry(session: Session, *, dataset_id: int | None, only_missing: bool) -> None:
    """Refresh both geometry columns of ``osm_addresses``.

    ``geom`` is rebuilt from ``lon`` / ``lat`` first; ``area_geom`` is then
    rebuilt from the non-empty ``geometry_geojson`` column (SRID 4326).
    """
    _refresh_point_geometry(session, "osm_addresses", dataset_id=dataset_id, only_missing=only_missing)

    params: dict[str, object] = {}
    conditions = ["geometry_geojson IS NOT NULL", "geometry_geojson <> ''"]
    if dataset_id is not None:
        params["dataset_id"] = int(dataset_id)
        conditions.append("dataset_id = :dataset_id")
    if only_missing:
        conditions.append("area_geom IS NULL")
    where_clause = " AND ".join(conditions)
    statement = text(
        f"""
        UPDATE osm_addresses
        SET area_geom = ST_SetSRID(ST_GeomFromGeoJSON(geometry_geojson), 4326)
        WHERE {where_clause}
        """
    )
    session.execute(statement, params)
|
||||
|
||||
|
||||
def _refresh_routing_edge_geometry(session: Session, *, dataset_id: int | None, only_missing: bool) -> None:
    """Set ``geom`` of ``routing_edges`` to the straight line from source to target (SRID 4326).

    Edges missing any of the four endpoint coordinates are skipped;
    ``dataset_id`` (when given) and ``only_missing`` narrow the rows updated.
    """
    params: dict[str, object] = {}
    conditions = [
        f"{column} IS NOT NULL"
        for column in ("source_lon", "source_lat", "target_lon", "target_lat")
    ]
    if dataset_id is not None:
        params["dataset_id"] = int(dataset_id)
        conditions.append("dataset_id = :dataset_id")
    if only_missing:
        conditions.append("geom IS NULL")
    where_clause = " AND ".join(conditions)
    statement = text(
        f"""
        UPDATE routing_edges
        SET geom = ST_SetSRID(
            ST_MakeLine(
                ST_MakePoint(source_lon, source_lat),
                ST_MakePoint(target_lon, target_lat)
            ),
            4326
        )
        WHERE {where_clause}
        """
    )
    session.execute(statement, params)
|
||||
4090
app/static/app.js
Normal file
4090
app/static/app.js
Normal file
File diff suppressed because it is too large
Load Diff
1498
app/static/style.css
Normal file
1498
app/static/style.css
Normal file
File diff suppressed because it is too large
Load Diff
329
app/templates/index.html
Normal file
329
app/templates/index.html
Normal file
@@ -0,0 +1,329 @@
|
||||
<!doctype html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="utf-8" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1" />
|
||||
<title>Mobility Workbench</title>
|
||||
<link rel="stylesheet" href="https://unpkg.com/leaflet@1.9.4/dist/leaflet.css" crossorigin="" />
|
||||
<link rel="stylesheet" href="/static/style.css?v=20260701-harmonizer-module" />
|
||||
</head>
|
||||
<body>
|
||||
<header>
|
||||
<div>
|
||||
<h1>Mobility Workbench</h1>
|
||||
<p>Harmonized transit, mapping data, route layer, map review, and journey tests.</p>
|
||||
</div>
|
||||
<div class="actions">
|
||||
<button id="refreshBtn">Refresh</button>
|
||||
</div>
|
||||
</header>
|
||||
|
||||
<main>
|
||||
<aside>
|
||||
<div class="sidebar-content">
|
||||
<details class="card sidebar-section" data-sidebar-section="stats" open>
|
||||
<summary><h2>Stats</h2></summary>
|
||||
<div class="sidebar-section-body">
|
||||
<div id="stats" class="stats"></div>
|
||||
</div>
|
||||
</details>
|
||||
|
||||
<details class="card sidebar-section" data-sidebar-section="qa" open>
|
||||
<summary><h2>QA</h2></summary>
|
||||
<div class="sidebar-section-body">
|
||||
<div class="qa-toolbar">
|
||||
<button type="button" id="refreshQaBtn">Refresh QA</button>
|
||||
</div>
|
||||
<div id="qaDashboard" class="qa-dashboard muted">No QA loaded.</div>
|
||||
</div>
|
||||
</details>
|
||||
|
||||
<details class="card sidebar-section" data-sidebar-section="jobs" open>
|
||||
<summary><h2>Jobs</h2></summary>
|
||||
<div class="sidebar-section-body">
|
||||
<div id="jobs" class="jobs muted">No jobs loaded.</div>
|
||||
</div>
|
||||
</details>
|
||||
|
||||
<details class="card sidebar-section" data-sidebar-section="harmonization" open>
|
||||
<summary><h2>GTFS Harmonization</h2></summary>
|
||||
<div class="sidebar-section-body">
|
||||
<details class="nested-section" data-sidebar-section="add-gtfs-source">
|
||||
<summary><h3>Add GTFS source</h3></summary>
|
||||
<div class="nested-section-body">
|
||||
<form id="sourceForm">
|
||||
<input name="catalog_entry_id" type="hidden" />
|
||||
<input name="kind" type="hidden" value="gtfs" />
|
||||
<label>Name <input name="name" required placeholder="DELFI / national GTFS" /></label>
|
||||
<label>URL or path <input name="url" required placeholder="https://.../feed.zip or ./data/feed.zip" /></label>
|
||||
<label>Country <input name="country" placeholder="DE" maxlength="8" /></label>
|
||||
<label>License <input name="license" placeholder="ODbL / CC-BY / unknown" /></label>
|
||||
<button type="submit">Add GTFS source</button>
|
||||
</form>
|
||||
</div>
|
||||
</details>
|
||||
|
||||
<details class="nested-section source-catalog-card" data-sidebar-section="source-catalog">
|
||||
<summary><h3>Transit source catalog</h3></summary>
|
||||
<div class="nested-section-body">
|
||||
<div id="sourceCatalogSummary" class="muted"></div>
|
||||
<div class="filter-row source-catalog-filter">
|
||||
<input id="sourceCatalogSearch" placeholder="Search catalog" />
|
||||
<input id="sourceCatalogCountry" placeholder="Country" />
|
||||
<select id="sourceCatalogPriority">
|
||||
<option value="">all priorities</option>
|
||||
<option value="P0">P0</option>
|
||||
<option value="P0 fallback">P0 fallback</option>
|
||||
<option value="P1">P1</option>
|
||||
<option value="P2">P2</option>
|
||||
<option value="P3">P3</option>
|
||||
<option value="P4">P4</option>
|
||||
<option value="P5">P5</option>
|
||||
</select>
|
||||
</div>
|
||||
<div class="source-catalog-actions">
|
||||
<button type="button" id="importSourceCatalogBtn">Import catalog</button>
|
||||
<button type="button" id="importIngestableSourcesBtn">Import ingestable seeds</button>
|
||||
</div>
|
||||
<div id="sourceCatalog"></div>
|
||||
</div>
|
||||
</details>
|
||||
|
||||
<details class="nested-section" data-sidebar-section="gtfs-feed-qa" open>
|
||||
<summary><h3>Feed QA</h3></summary>
|
||||
<div class="nested-section-body">
|
||||
<div class="qa-toolbar">
|
||||
<button type="button" id="refreshGtfsHarmonizationBtn">Refresh feeds</button>
|
||||
</div>
|
||||
<div id="gtfsHarmonizationInventory" class="harmonization-inventory muted">No GTFS feed QA loaded.</div>
|
||||
</div>
|
||||
</details>
|
||||
|
||||
<details class="nested-section" data-sidebar-section="gtfs-source-management" open>
|
||||
<summary><h3>GTFS source library</h3></summary>
|
||||
<div class="nested-section-body">
|
||||
<div class="filter-row">
|
||||
<input id="sourceSearch" placeholder="Filter GTFS sources" />
|
||||
</div>
|
||||
<div id="sources"></div>
|
||||
</div>
|
||||
</details>
|
||||
</div>
|
||||
</details>
|
||||
|
||||
<details class="card sidebar-section" data-sidebar-section="mapping" open>
|
||||
<summary><h2>Mapping Data</h2></summary>
|
||||
<div class="sidebar-section-body">
|
||||
<details class="nested-section" data-sidebar-section="add-map-source">
|
||||
<summary><h3>Add map source</h3></summary>
|
||||
<div class="nested-section-body">
|
||||
<form id="mappingSourceForm">
|
||||
<input name="catalog_entry_id" type="hidden" />
|
||||
<label>Name <input name="name" required placeholder="Germany OSM PBF" /></label>
|
||||
<label>Kind
|
||||
<select name="kind">
|
||||
<option value="osm_pbf">OSM PBF extract</option>
|
||||
<option value="osm_geojson">OSM transport GeoJSON</option>
|
||||
<option value="osm_diff">OSM change diff</option>
|
||||
</select>
|
||||
</label>
|
||||
<label>URL or path <input name="url" required placeholder="https://.../latest.osm.pbf or ./data/routes.geojson" /></label>
|
||||
<label>Country <input name="country" placeholder="DE" maxlength="8" /></label>
|
||||
<label>License <input name="license" placeholder="ODbL / CC-BY / unknown" /></label>
|
||||
<button type="submit">Add map source</button>
|
||||
</form>
|
||||
</div>
|
||||
</details>
|
||||
|
||||
<details class="nested-section source-catalog-card" data-sidebar-section="geofabrik">
|
||||
<summary><h3>Geofabrik OSM</h3></summary>
|
||||
<div class="nested-section-body">
|
||||
<div class="filter-row geofabrik-filter">
|
||||
<input id="geofabrikSearch" placeholder="Berlin, Germany, Hamburg" />
|
||||
<button type="button" id="geofabrikSearchBtn">Search</button>
|
||||
</div>
|
||||
<label class="inline-check"><input id="geofabrikDiffSource" type="checkbox" checked /> add diff source metadata</label>
|
||||
<div id="geofabrikResults" class="dataset-search-results muted">Search Geofabrik extracts, then add or import one as an OSM PBF source.</div>
|
||||
</div>
|
||||
</details>
|
||||
|
||||
<details class="nested-section" data-sidebar-section="mapping-source-management" open>
|
||||
<summary><h3>Map source library</h3></summary>
|
||||
<div class="nested-section-body">
|
||||
<div class="filter-row">
|
||||
<input id="mappingSourceSearch" placeholder="Filter map sources" />
|
||||
<select id="mappingSourceKindFilter">
|
||||
<option value="">all map kinds</option>
|
||||
<option value="osm_geojson">OSM GeoJSON</option>
|
||||
<option value="osm_pbf">OSM PBF</option>
|
||||
<option value="osm_diff">OSM diff</option>
|
||||
</select>
|
||||
</div>
|
||||
<div id="mappingSources"></div>
|
||||
</div>
|
||||
</details>
|
||||
</div>
|
||||
</details>
|
||||
|
||||
<details class="card sidebar-section" data-sidebar-section="datasets" open>
|
||||
<summary><h2>Datasets</h2></summary>
|
||||
<div class="sidebar-section-body">
|
||||
<details class="nested-section" data-sidebar-section="dataset-pipeline" open>
|
||||
<summary><h3>Derivation pipeline</h3></summary>
|
||||
<div class="nested-section-body">
|
||||
<div class="workflow-actions">
|
||||
<button id="runMatchBtn" type="button">Run matcher</button>
|
||||
<button id="buildRouteLayerBtn" type="button">Build route layer</button>
|
||||
<button id="loadSampleBtn" type="button">Reset sample</button>
|
||||
</div>
|
||||
</div>
|
||||
</details>
|
||||
|
||||
<details class="nested-section" data-sidebar-section="dataset-search">
|
||||
<summary><h3>Dataset search</h3></summary>
|
||||
<div class="nested-section-body">
|
||||
<form id="datasetSearchForm" class="dataset-search-form">
|
||||
<input id="datasetSearchQuery" placeholder="Route, line, stop, shape ID" autocomplete="off" />
|
||||
<div class="filter-row">
|
||||
<label class="inline-check"><input id="datasetSearchActiveOnly" type="checkbox" checked /> active only</label>
|
||||
<button type="submit">Search</button>
|
||||
</div>
|
||||
</form>
|
||||
<div id="datasetSearchResults" class="dataset-search-results muted">Search all imported datasets by label, route ID, and route-layer reference.</div>
|
||||
</div>
|
||||
</details>
|
||||
|
||||
<details class="nested-section matches-card" data-sidebar-section="route-matches">
|
||||
<summary><h3>Route matches</h3></summary>
|
||||
<div class="nested-section-body">
|
||||
<div class="filter-row">
|
||||
<select id="matchStatusFilter">
|
||||
<option value="">all</option>
|
||||
<option value="matched">matched</option>
|
||||
<option value="probable">probable</option>
|
||||
<option value="weak">weak</option>
|
||||
<option value="missing">missing</option>
|
||||
<option value="accepted">accepted</option>
|
||||
<option value="rejected">rejected</option>
|
||||
</select>
|
||||
<button id="reloadMatchesBtn">Reload</button>
|
||||
</div>
|
||||
<div id="matches"></div>
|
||||
</div>
|
||||
</details>
|
||||
|
||||
<details class="nested-section" data-sidebar-section="maintenance">
|
||||
<summary><h3>Maintenance</h3></summary>
|
||||
<div class="nested-section-body">
|
||||
<div class="maintenance-grid">
|
||||
<button type="button" data-admin-action="init-db">Init DB</button>
|
||||
<button type="button" data-admin-action="backfill-gtfs-shapes">Backfill GTFS shapes</button>
|
||||
<button type="button" data-admin-action="prune-cache-dry">Check cache</button>
|
||||
<button type="button" data-admin-action="prune-cache">Prune cache</button>
|
||||
<button type="button" data-admin-action="prune-inactive-dry">Check inactive</button>
|
||||
<button type="button" data-admin-action="prune-inactive">Prune inactive</button>
|
||||
<button type="button" data-admin-action="vacuum-db">Vacuum DB</button>
|
||||
<button type="button" class="danger" data-admin-action="reset-db">Reset DB</button>
|
||||
</div>
|
||||
<div id="adminStatus" class="admin-status muted"></div>
|
||||
</div>
|
||||
</details>
|
||||
</div>
|
||||
</details>
|
||||
|
||||
<details class="card sidebar-section" data-sidebar-section="layers" open>
|
||||
<summary><h2>Layers</h2></summary>
|
||||
<div class="sidebar-section-body">
|
||||
<div class="preset-row">
|
||||
<button type="button" data-layer-preset="network">Network</button>
|
||||
<button type="button" data-layer-preset="review">Matched/unmatched</button>
|
||||
<button type="button" data-layer-preset="unmatched">Unmatched</button>
|
||||
<button type="button" data-layer-preset="all">All</button>
|
||||
</div>
|
||||
<div id="layerControls" class="layer-controls"></div>
|
||||
<div id="mapStatus" class="map-status muted"></div>
|
||||
</div>
|
||||
</details>
|
||||
|
||||
</div>
|
||||
<button id="sidebarCollapseBtn" class="sidebar-collapse-handle" type="button" aria-label="Collapse left panel" title="Collapse left panel" aria-expanded="true">‹</button>
|
||||
</aside>
|
||||
|
||||
<section class="map-panel">
|
||||
<div id="map"></div>
|
||||
<div id="mapLoading" class="map-loading" hidden>
|
||||
<span class="spinner" aria-hidden="true"></span>
|
||||
<span id="mapLoadingText">Loading map layers...</span>
|
||||
</div>
|
||||
<section class="map-floating journey-card">
|
||||
<h2>Journey</h2>
|
||||
<form id="journeyForm">
|
||||
<div id="journeyTransitSnapshot" class="journey-snapshot muted">Transit snapshot loading...</div>
|
||||
<label>From <input id="journeyFromQuery" placeholder="Hauptbahnhof" autocomplete="off" /></label>
|
||||
<input id="journeyFromStop" type="hidden" />
|
||||
<div id="journeyFromSuggestions" class="stop-suggestions"></div>
|
||||
<button type="button" id="journeySwapBtn" class="journey-swap" title="Switch start and destination">Swap</button>
|
||||
<label>To <input id="journeyToQuery" placeholder="Alexanderplatz" autocomplete="off" /></label>
|
||||
<input id="journeyToStop" type="hidden" />
|
||||
<div id="journeyToSuggestions" class="stop-suggestions"></div>
|
||||
<label>Via <input id="journeyViaQuery" placeholder="optional stop" autocomplete="off" /></label>
|
||||
<input id="journeyViaStop" type="hidden" />
|
||||
<div id="journeyViaSuggestions" class="stop-suggestions"></div>
|
||||
<div class="journey-mode" role="radiogroup" aria-label="Route mode">
|
||||
<label><input type="radio" name="journeyMode" value="transit" checked /> Public transport</label>
|
||||
<label><input type="radio" name="journeyMode" value="walk" /> Walk</label>
|
||||
<label><input type="radio" name="journeyMode" value="drive" /> Car</label>
|
||||
</div>
|
||||
<div class="journey-options">
|
||||
<label>Date <input id="journeyServiceDate" type="date" /></label>
|
||||
<label>Departure <input id="journeyDeparture" type="time" value="08:00" /></label>
|
||||
<label>Transfer buffer <input id="journeyTransferMinutes" type="number" min="0" max="60" step="1" value="2" /></label>
|
||||
<label>Rank by
|
||||
<select id="journeyRanking">
|
||||
<option value="recommended">Recommended</option>
|
||||
<option value="earliest_arrival">Earliest arrival</option>
|
||||
<option value="duration">Shortest duration</option>
|
||||
<option value="fewest_transfers">Fewest transfers</option>
|
||||
</select>
|
||||
</label>
|
||||
</div>
|
||||
<label class="journey-direct"><input id="journeyDirectOnly" type="checkbox" /> Direct public transport only</label>
|
||||
<div class="journey-actions">
|
||||
<button type="button" id="journeyEarlierBtn">Earlier</button>
|
||||
<button type="submit" class="primary">Search</button>
|
||||
<button type="button" id="journeyLaterBtn">Later</button>
|
||||
</div>
|
||||
<button type="button" id="generateItinerariesBtn">Generate travel options</button>
|
||||
</form>
|
||||
<div id="journeyResults" class="journey-results"></div>
|
||||
<section class="itinerary-panel">
|
||||
<div class="journey-title">
|
||||
<span>Comparison</span>
|
||||
<button type="button" id="reloadItinerariesBtn">Reload</button>
|
||||
</div>
|
||||
<div id="itineraryResults" class="itinerary-results muted">Generate travel options to compare route families.</div>
|
||||
</section>
|
||||
</section>
|
||||
<div class="legend">
|
||||
<span><b class="line osm"></b>OSM existing routes</span>
|
||||
<span><b class="line gtfs"></b>GTFS covered routes</span>
|
||||
<span><b class="line missing"></b>GTFS missing OSM match</span>
|
||||
<span><b class="dot stops"></b>Stops / stations / terminals</span>
|
||||
</div>
|
||||
</section>
|
||||
</main>
|
||||
|
||||
<div id="overlay" class="overlay" hidden>
|
||||
<section class="overlay-panel">
|
||||
<div class="overlay-title">
|
||||
<h2 id="overlayTitle">Candidates</h2>
|
||||
<button id="overlayCloseBtn">Close</button>
|
||||
</div>
|
||||
<div id="overlayContent"></div>
|
||||
</section>
|
||||
</div>
|
||||
|
||||
<script src="https://unpkg.com/leaflet@1.9.4/dist/leaflet.js" crossorigin=""></script>
|
||||
<script src="/static/app.js?v=20260701-harmonizer-module"></script>
|
||||
</body>
|
||||
</html>
|
||||
155
app/worker_supervisor.py
Normal file
155
app/worker_supervisor.py
Normal file
@@ -0,0 +1,155 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import signal
|
||||
import subprocess
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
from app.config import settings
|
||||
|
||||
|
||||
@dataclass
class WorkerHandle:
    """Bookkeeping record for one queue worker slot managed by this module."""

    # Zero-based slot number; the worker id is derived from index + 1.
    index: int
    # Identifier passed to the worker process via --worker-id.
    worker_id: str
    # Process id of the worker.
    pid: int | None
    # "already_running" when an existing process was adopted, "started" when spawned here.
    status: str
    # File holding the worker's pid.
    pid_file: Path
    # File receiving the worker's stdout and stderr.
    log_file: Path
    # True only for workers spawned by start_queue_workers(); only those are stopped on shutdown.
    started_by_server: bool = False
|
||||
|
||||
|
||||
# Handles produced by the most recent start_queue_workers() call; consulted by stop_queue_workers().
_handles: list[WorkerHandle] = []
|
||||
|
||||
|
||||
def start_queue_workers() -> list[WorkerHandle]:
    """Ensure the configured number of queue workers is running.

    Returns an empty list when autostart is disabled.  For every slot a live
    process named in the slot's pid file is adopted as ``"already_running"``;
    otherwise a new worker is spawned, its pid written to the pid file, and it
    is recorded as ``"started"``.  The module-level handle list is replaced
    and a copy of it returned.
    """
    if not settings.queue_worker_autostart:
        return []
    worker_count = max(0, int(settings.queue_worker_count))
    worker_dir = settings.data_dir / "workers"
    worker_dir.mkdir(parents=True, exist_ok=True)

    collected: list[WorkerHandle] = []
    for index in range(worker_count):
        worker_id = f"server-worker-{index + 1}"
        pid_file = worker_dir / f"{worker_id}.pid"
        log_file = worker_dir / f"{worker_id}.log"

        known_pid = _read_pid(pid_file)
        if known_pid is not None and _pid_running(known_pid):
            handle = WorkerHandle(
                index=index,
                worker_id=worker_id,
                pid=known_pid,
                status="already_running",
                pid_file=pid_file,
                log_file=log_file,
            )
        else:
            # Stale or missing pid file: start a fresh worker for this slot.
            pid_file.unlink(missing_ok=True)
            process = _spawn_worker(worker_id, log_file)
            pid_file.write_text(str(process.pid), encoding="utf-8")
            handle = WorkerHandle(
                index=index,
                worker_id=worker_id,
                pid=process.pid,
                status="started",
                pid_file=pid_file,
                log_file=log_file,
                started_by_server=True,
            )
        collected.append(handle)

    _handles[:] = collected
    return list(_handles)
|
||||
|
||||
|
||||
def stop_queue_workers() -> None:
    """Terminate the workers this server started and remove their pid files.

    Does nothing unless ``settings.queue_worker_stop_on_shutdown`` is set.
    Adopted workers (not started by this server) are left running.  Each
    stopped worker is dropped from the module-level handle list, so a repeated
    call cannot signal a pid that the OS may have reassigned in the meantime.
    """
    if not settings.queue_worker_stop_on_shutdown:
        return
    for handle in list(_handles):
        if not handle.started_by_server or handle.pid is None:
            continue
        _terminate_pid(handle.pid)
        handle.pid_file.unlink(missing_ok=True)
        # Forget the worker right away so an error on a later handle does not
        # leave already-terminated pids registered.
        _handles[:] = [kept for kept in _handles if kept is not handle]
|
||||
|
||||
|
||||
def queue_worker_status() -> list[dict[str, object]]:
    """Report the state of every configured queue-worker slot.

    Returns an empty list when ``settings.queue_worker_autostart`` is false.
    Otherwise returns one dict per slot with the keys ``index``,
    ``worker_id``, ``pid`` (as read from the pid file, possibly ``None``),
    ``running``, ``pid_file`` and ``log_file`` (both paths as strings).
    """
    if not settings.queue_worker_autostart:
        return []

    runtime_dir = settings.data_dir / "workers"
    total = max(0, int(settings.queue_worker_count))

    def describe(slot: int) -> dict[str, object]:
        # Build the status entry for a single worker slot.
        name = f"server-worker-{slot + 1}"
        pid_path = runtime_dir / f"{name}.pid"
        log_path = runtime_dir / f"{name}.log"
        recorded_pid = _read_pid(pid_path)
        alive = False if recorded_pid is None else _pid_running(recorded_pid)
        return {
            "index": slot,
            "worker_id": name,
            "pid": recorded_pid,
            "running": alive,
            "pid_file": str(pid_path),
            "log_file": str(log_path),
        }

    return [describe(slot) for slot in range(total)]
|
||||
|
||||
|
||||
def _spawn_worker(worker_id: str, log_file: Path) -> subprocess.Popen:
    """Launch one ``app.cli worker`` subprocess and return its ``Popen``.

    Args:
        worker_id: Identifier forwarded to the worker via ``--worker-id``.
        log_file: File that receives the worker's stdout and stderr
            (opened in append mode; parent directories are created).

    Returns:
        The ``subprocess.Popen`` object for the started worker.
    """
    project_root = Path(__file__).resolve().parents[1]
    poll_interval = str(settings.queue_worker_poll_interval_seconds)
    argv = [
        sys.executable,
        "-m",
        "app.cli",
        "worker",
        "--worker-id",
        worker_id,
        "--poll-interval",
        poll_interval,
    ]
    # Copy of the current environment plus a marker variable for the child.
    child_env = {**os.environ, "MOBILITY_SUPERVISED_WORKER": "1"}

    log_file.parent.mkdir(parents=True, exist_ok=True)
    # The parent's handle is closed as soon as Popen returns (or raises);
    # the child keeps its own copy of the descriptor for its output.
    with log_file.open("ab", buffering=0) as sink:
        return subprocess.Popen(
            argv,
            cwd=str(project_root),
            env=child_env,
            stdin=subprocess.DEVNULL,
            stdout=sink,
            stderr=subprocess.STDOUT,
            start_new_session=True,
        )
|
||||
|
||||
|
||||
def _read_pid(path: Path) -> int | None:
    """Read a process id from a pid file.

    Args:
        path: Location of the pid file.

    Returns:
        The positive integer stored in the file, or ``None`` when the file
        is missing, unreadable, not a valid integer, or holds a value that
        is not a usable process id (zero or negative).
    """
    try:
        pid = int(path.read_text(encoding="utf-8").strip())
    except (OSError, ValueError):
        # OSError covers a missing/unreadable file; ValueError covers both
        # non-numeric content and undecodable bytes.
        return None
    # pid 0 and negative pids address process groups (or every process) when
    # handed to kill(), so a corrupted pid file must never yield them.
    return pid if pid > 0 else None
|
||||
|
||||
|
||||
def _pid_running(pid: int) -> bool:
    """Return whether a process with the given pid currently exists.

    The check sends signal 0, which performs the existence and permission
    checks without delivering a signal.

    Args:
        pid: Process id to probe.

    Returns:
        ``True`` if the process exists (including when it exists but may not
        be signalled by this user), ``False`` otherwise. Non-positive pids
        always yield ``False``.
    """
    # kill() interprets 0 and negative values as process-group / broadcast
    # targets, which would report "running" for something that is not a
    # single worker process.
    if pid <= 0:
        return False
    try:
        os.kill(pid, 0)
    except ProcessLookupError:
        return False
    except PermissionError:
        # The process exists but belongs to someone we may not signal.
        return True
    return True
|
||||
|
||||
|
||||
def _terminate_pid(pid: int) -> None:
    """Send ``SIGTERM`` to a single process, ignoring one that is already gone.

    Args:
        pid: Process id to terminate. Non-positive values are ignored.
    """
    # Never forward pid <= 0 to kill(): 0 targets the caller's own process
    # group and negative values target a group or every permitted process,
    # which would take down far more than one worker.
    if pid <= 0:
        return
    try:
        os.kill(pid, signal.SIGTERM)
    except ProcessLookupError:
        # The process exited before the signal could be delivered.
        return
|
||||
Reference in New Issue
Block a user