Files
2026-07-01 23:29:51 +02:00

2654 lines
96 KiB
Python

from __future__ import annotations
import json
from contextlib import asynccontextmanager
from datetime import datetime, timezone
from functools import wraps
from pathlib import Path
from typing import Optional
from urllib.parse import urlparse
from fastapi import Depends, FastAPI, HTTPException, Request, Response
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from pydantic import BaseModel, Field
from sqlalchemy import and_, exists, func, not_, or_, select, text
from sqlalchemy.exc import OperationalError
from sqlalchemy.orm import Session, joinedload
from app.address_search import address_at_point, search_addresses
from app.config import settings
from app.data_management import dataset_row_counts, source_row_counts
from app.dataset_search import search_datasets
from app.db import engine, get_db, init_db
from app.db_lock import DatabaseWriteBusy, database_write_lock, database_write_status
from app.geofabrik import create_geofabrik_source, geofabrik_catalog
from app.gtfs_storage import scheduled_stop_ids
from app.harmonization import GTFS_QA_NOTE_PREFIX, gtfs_harmonization_feed_detail, gtfs_harmonization_inventory
from app.itineraries import generate_itineraries, itinerary_payload, recent_itineraries, set_itinerary_saved, set_leg_locked
from app.journey import find_journeys, nearest_scheduled_stops, search_scheduled_stops
from app.journey_search import cancel_journey_search, journey_search_payload, start_journey_search
from app.jobs import (
active_address_index_rebuild_job,
active_source_import_job,
active_source_workflow_jobs,
active_dataset_delete_jobs,
cancel_job,
create_dataset_delete_job,
create_address_index_rebuild_job,
create_maintenance_job,
create_osm_relabel_job,
create_route_matching_job,
create_source_delete_job,
create_source_import_job,
create_route_layer_rebuild_job,
dismiss_job,
dismiss_terminal_jobs,
job_event_payload,
job_events,
job_payload,
latest_jobs,
job_queue_revision,
pause_job,
queue_missing_gtfs_sidecar_recovery_jobs,
reconcile_interrupted_jobs,
reconcile_source_workflow_state,
request_job_control,
resume_job,
retry_job,
set_job_priority,
)
from app.models import (
CanonicalStop,
CanonicalStopLink,
Dataset,
GtfsRoute,
GtfsStop,
Itinerary,
ItineraryLeg,
Job,
MatchRule,
OsmFeature,
PipelineRun,
RouteMatch,
RoutePattern,
Source,
SourceCatalogEntry,
SourceUpdateCheck,
)
from app.osm_storage import (
ensure_main_osm_feature,
osm_feature_count,
osm_feature_public_id,
query_osm_features,
resolve_osm_feature,
)
from app.pipeline.matcher import candidate_osm_routes_for_route, run_route_matching, score_route_pair
from app.pipeline.osm_addresses import address_index_status
from app.qa import qa_summary
from app.pipeline.route_layer import rebuild_route_layer
from app.pipeline.sample_data import load_sample_project
from app.pipeline.state import pipeline_run_payload
from app.routing import route_between_points, routing_status
from app.serializers import feature_collection, gtfs_route_feature, gtfs_stop_feature, match_row, osm_feature_feature, route_pattern_feature
from app.spatial import using_postgresql
from app.source_updates import (
check_source_for_update,
latest_source_update_check,
update_check_payload,
)
from app.source_catalog import (
catalog_entry_payload,
import_ingestable_sources,
import_source_catalog,
linked_source_counts,
source_catalog_rows,
source_catalog_summary,
)
from app.worker_supervisor import queue_worker_status, start_queue_workers, stop_queue_workers
@asynccontextmanager
async def lifespan(_app: FastAPI):
    """Run startup and shutdown work around the application's lifetime.

    On startup the database is initialised, interrupted jobs are reconciled,
    missing GTFS sidecar recovery jobs are queued and the queue workers are
    started. The workers are stopped again when the application exits.
    """
    init_db()
    with Session(engine) as startup_session:
        reconcile_interrupted_jobs(startup_session)
        queue_missing_gtfs_sidecar_recovery_jobs(startup_session)
        startup_session.commit()
    start_queue_workers()
    try:
        yield
    finally:
        # Stop the workers even when the application exits with an error.
        stop_queue_workers()
# Application object; startup/shutdown work is delegated to ``lifespan``.
app = FastAPI(title="Mobility Workbench", version="0.1.0", lifespan=lifespan)
# Static assets and Jinja templates are resolved relative to this module's directory.
app.mount("/static", StaticFiles(directory=Path(__file__).parent / "static"), name="static")
templates = Jinja2Templates(directory=str(Path(__file__).parent / "templates"))
# Map feature limits. Their consumers are outside this excerpt; presumably the
# default and maximum number of features returned by map endpoints - TODO confirm.
MAP_FEATURE_LIMIT = 5000
MAP_FEATURE_LIMIT_MAX = 20000
class SourceCreate(BaseModel):
    """Request body for ``POST /api/sources``.

    ``kind`` is validated by ``create_source``. Optional fields left empty are
    filled from the linked catalog entry when ``catalog_entry_id`` is given.
    """

    name: str
    kind: str
    url: str
    country: Optional[str] = None
    license: Optional[str] = None
    priority: Optional[str] = None
    mode_scope: Optional[str] = None
    source_basis: Optional[str] = None
    notes: Optional[str] = None
    catalog_entry_id: Optional[int] = None
class GtfsFeedReviewUpdate(BaseModel):
    """Request body for the GTFS feed QA review ``PATCH`` endpoint.

    Every field is optional; ``license`` and ``enabled`` are only applied when
    they are not ``None``.
    """

    license: Optional[str] = None
    review_status: Optional[str] = None
    review_note: Optional[str] = None
    enabled: Optional[bool] = None
class CatalogImportRequest(BaseModel):
    """Optional request body for the source-catalog import endpoints."""

    # Forwarded to the catalog import helpers; ``None`` lets them pick their default.
    csv_path: Optional[str] = None
    update_existing: bool = True
class RuleCreate(BaseModel):
    """Request model describing a rule (type, selector, action, optional note).

    The endpoint that consumes it is outside this excerpt.
    """

    rule_type: str
    selector: dict
    action: dict
    note: Optional[str] = None
class CanonicalStopGtfsLinkRequest(BaseModel):
    """Request model for linking a GTFS stop to a canonical stop.

    All fields are optional; the endpoint that consumes it is outside this
    excerpt.
    """

    gtfs_stop_id: Optional[int] = None
    dataset_id: Optional[int] = None
    stop_id: Optional[str] = None
    note: Optional[str] = None
class ItineraryGenerateRequest(BaseModel):
    """Request model for itinerary generation between two stops.

    The endpoint that consumes it is outside this excerpt.
    """

    from_stop_id: str
    to_stop_id: str
    via_stop_id: Optional[str] = None
    source_id: Optional[str] = None
    departure: str = "08:00"
    service_date: Optional[str] = None
    max_transfers: int = 1
    transfer_seconds: int = 120
    limit: int = 5
    # default_factory gives every request its own dict instead of a shared one.
    preferences: dict = Field(default_factory=dict)
class JourneySearchRequest(BaseModel):
    """Request model for a journey search between two stops.

    The endpoint that consumes it is outside this excerpt.
    """

    from_stop_id: str
    to_stop_id: str
    via_stop_id: Optional[str] = None
    source_id: Optional[str] = None
    departure: str = "08:00"
    service_date: Optional[str] = None
    mode: str = "transit"
    direct_only: bool = False
    ranking: str = "recommended"
    transfer_seconds: int = 120
    limit: int = 5
class ItinerarySaveRequest(BaseModel):
    """Request model carrying the desired ``saved`` flag (defaults to ``True``)."""

    saved: bool = True
class ItineraryLegLockRequest(BaseModel):
    """Request model carrying the desired ``locked`` flag (defaults to ``True``)."""

    locked: bool = True
class GeofabrikSourceRequest(BaseModel):
    """Request body for ``POST /api/geofabrik/sources``.

    ``run_import`` decides whether an import job is queued right away;
    ``run_match`` and ``build_route_layer`` are forwarded to that job.
    """

    geofabrik_id: str
    import_updates: bool = False
    run_import: bool = False
    run_match: bool = True
    build_route_layer: bool = True
class AdminActionRequest(BaseModel):
    """Request body shared by the admin maintenance endpoints.

    The fields are passed through to ``_queue_admin_maintenance_job``, which
    is defined outside this excerpt.
    """

    dry_run: bool = True
    confirm: Optional[str] = None
    dataset_id: Optional[int] = None
class JobPriorityRequest(BaseModel):
    """Request body for ``POST /api/jobs/{job_id}/priority``."""

    priority: int
def write_endpoint(operation: str):
    """Build a decorator that runs an endpoint under the database write lock.

    The wrapped callable executes while ``database_write_lock(operation)`` is
    held. A ``DatabaseWriteBusy`` error is turned into an HTTP 409 response
    carrying a ``Retry-After: 5`` header.
    """

    def decorator(fn):
        @wraps(fn)
        def wrapper(*args, **kwargs):
            try:
                with database_write_lock(operation):
                    outcome = fn(*args, **kwargs)
            except DatabaseWriteBusy as busy:
                raise HTTPException(status_code=409, detail=str(busy), headers={"Retry-After": "5"}) from busy
            return outcome

        return wrapper

    return decorator
@app.exception_handler(OperationalError)
async def database_operational_error_handler(_request: Request, exc: OperationalError) -> JSONResponse:
    """Translate SQLAlchemy ``OperationalError`` into a JSON error response.

    A locked-database error becomes a 409 with a ``Retry-After`` header; any
    other operational error becomes a generic 500.
    """
    if not _is_database_locked_error(exc):
        return JSONResponse(status_code=500, content={"detail": "Database operation failed."})
    busy_detail = "Database is busy with another operation. Try again when the current write finishes."
    return JSONResponse(
        status_code=409,
        content={"detail": busy_detail},
        headers={"Retry-After": "5"},
    )
@app.get("/", response_class=HTMLResponse)
def index(request: Request) -> HTMLResponse:
    """Render the ``index.html`` template for the application root."""
    page = templates.TemplateResponse(request=request, name="index.html")
    return page
@app.get("/api/sources")
def list_sources(db: Session = Depends(get_db)) -> list[dict]:
    """Return every source, ordered by id, as a response payload.

    Before listing, interrupted jobs are reconciled and missing GTFS sidecar
    recovery jobs are queued, and that work is committed.
    """
    reconcile_interrupted_jobs(db)
    queue_missing_gtfs_sidecar_recovery_jobs(db)
    db.commit()

    all_sources = db.scalars(select(Source).order_by(Source.id)).all()
    checks_by_source = {src.id: latest_source_update_check(db, src.id) for src in all_sources}
    workflow_jobs = active_source_workflow_jobs(db)
    dataset_delete_jobs = active_dataset_delete_jobs(db)
    return [
        _source_response(
            db,
            src,
            checks_by_source.get(src.id),
            workflow_jobs.get(src.id),
            dataset_delete_jobs,
        )
        for src in all_sources
    ]
@app.post("/api/sources")
@write_endpoint("create source")
def create_source(payload: SourceCreate, db: Session = Depends(get_db)) -> dict:
    """Create a source, filling unset fields from its catalog entry.

    Raises:
        HTTPException: 400 for an unsupported ``kind``; 404 when
            ``catalog_entry_id`` does not resolve to a catalog entry.
    """
    supported_kinds = {"gtfs", "osm_geojson", "osm_pbf", "osm_diff"}
    if payload.kind not in supported_kinds:
        raise HTTPException(status_code=400, detail="kind must be 'gtfs', 'osm_geojson', 'osm_pbf', or 'osm_diff'")

    entry = None
    if payload.catalog_entry_id:
        entry = db.get(SourceCatalogEntry, payload.catalog_entry_id)
        if entry is None:
            raise HTTPException(status_code=404, detail="catalog entry not found")

    # Explicit payload values win; the catalog entry only supplies fallbacks.
    new_source = Source(
        catalog_entry_id=payload.catalog_entry_id,
        name=_truncate(payload.name, 255) or payload.name,
        kind=payload.kind,
        url=payload.url,
        country=payload.country or _catalog_country(entry),
        license=_truncate(payload.license, 255) or _truncate(entry.access_license_notes if entry else None, 255),
        priority=payload.priority or (entry.priority if entry else None),
        mode_scope=payload.mode_scope or (entry.mode_scope if entry else None),
        source_basis=payload.source_basis or (entry.source_category if entry else None),
        notes=payload.notes or _catalog_notes(entry),
    )
    db.add(new_source)
    db.commit()
    db.refresh(new_source)
    return {"id": new_source.id, "status": new_source.status}
@app.post("/api/sources/{source_id}/run")
@write_endpoint("run source import")
def run_source(source_id: int, db: Session = Depends(get_db)) -> dict:
    """Queue an import for a source without route matching or a route-layer build.

    Raises:
        HTTPException: 404 when the source does not exist.
    """
    source = db.get(Source, source_id)
    if source is None:
        raise HTTPException(status_code=404, detail="source not found")
    job, started = _queue_source_import_job(db, source, run_match=False, build_route_layer=False)
    outcome = "queued" if started else "already_running"
    return {"source_id": source.id, "job_id": job.id, "status": outcome, "job": job_payload(job)}
@app.post("/api/sources/{source_id}/update")
@write_endpoint("update source")
def update_source(source_id: int, force: bool = False, db: Session = Depends(get_db)) -> dict:
    """Check a source for updates and queue a full import when one is needed.

    An import that is already active is reported as ``already_running``.
    Without ``force``, a check reporting no update yields ``skipped``.

    Raises:
        HTTPException: 404 when the source does not exist; 502 when the
            update check did not finish with status ``checked``.
    """
    source = db.get(Source, source_id)
    if source is None:
        raise HTTPException(status_code=404, detail="source not found")

    running = active_source_import_job(db, source.id)
    if running is not None:
        return {"source_id": source.id, "status": "already_running", "job": job_payload(running)}

    check = check_source_for_update(db, source)
    if check.status != "checked":
        # Commit before failing, as the original does, so the check is kept.
        db.commit()
        raise HTTPException(status_code=502, detail=check.reason or "update check failed")

    if not (force or check.update_available):
        db.commit()
        return {
            "source_id": source.id,
            "status": "skipped",
            "reason": check.reason,
            "check": update_check_payload(check),
        }

    job, started = _queue_source_import_job(db, source, run_match=True, build_route_layer=True)
    outcome = "queued" if started else "already_running"
    return {"source_id": source.id, "status": outcome, "job": job_payload(job), "check": update_check_payload(check)}
@app.post("/api/sources/{source_id}/check-update")
@write_endpoint("check source update")
def check_source_update(source_id: int, db: Session = Depends(get_db)) -> dict:
    """Run an update check for a source and return its payload.

    Raises:
        HTTPException: 404 when the source does not exist.
    """
    source = db.get(Source, source_id)
    if source is None:
        raise HTTPException(status_code=404, detail="source not found")
    result = check_source_for_update(db, source)
    db.commit()
    db.refresh(result)
    body = update_check_payload(result)
    # Fall back to an empty object when the payload helper returns nothing.
    return body or {}
@app.get("/api/sources/{source_id}/update-checks")
def list_source_update_checks(source_id: int, limit: int = 20, db: Session = Depends(get_db)) -> dict:
    """Return a source's update checks, newest first.

    ``limit`` is clamped to the range 1..100.

    Raises:
        HTTPException: 404 when the source does not exist.
    """
    source = db.get(Source, source_id)
    if source is None:
        raise HTTPException(status_code=404, detail="source not found")
    page_size = max(1, min(limit, 100))
    query = (
        select(SourceUpdateCheck)
        .where(SourceUpdateCheck.source_id == source.id)
        .order_by(SourceUpdateCheck.checked_at.desc(), SourceUpdateCheck.id.desc())
        .limit(page_size)
    )
    history = db.scalars(query).all()
    return {"source_id": source.id, "checks": [update_check_payload(item) for item in history]}
@app.get("/api/database/write-status")
def database_write_status_endpoint() -> dict:
    """Report the current database write-lock state.

    ``elapsed_seconds`` is rounded to one decimal place, or ``None`` when the
    status carries no elapsed time.
    """
    snapshot = database_write_status()
    body = {
        "locked": snapshot.locked,
        "operation": snapshot.operation,
        "pid": snapshot.pid,
    }
    elapsed = snapshot.elapsed_seconds
    body["elapsed_seconds"] = elapsed if elapsed is None else round(elapsed, 1)
    return body
# Names of the admin maintenance actions; each has a dedicated
# ``/api/admin/<action>`` endpoint below. Where this set is consulted is not
# visible in this excerpt; presumably it validates the ``action`` path
# parameter of ``/api/jobs/admin/{action}`` - TODO confirm.
ADMIN_JOB_ACTIONS = {
    "init-db",
    "reset-db",
    "backfill-gtfs-shapes",
    "prune-cache",
    "prune-inactive-datasets",
    "vacuum-db",
}
@app.post("/api/admin/init-db")
@write_endpoint("admin init database")
def admin_init_database(priority: int = 0, db: Session = Depends(get_db)) -> dict:
    """Queue the ``init-db`` maintenance action (takes no request body)."""
    queued = _queue_admin_maintenance_job(db, "init-db", None, priority=priority)
    return queued
@app.post("/api/admin/reset-db")
@write_endpoint("admin reset database")
def admin_reset_database(payload: AdminActionRequest, priority: int = 0, db: Session = Depends(get_db)) -> dict:
    """Queue the ``reset-db`` maintenance action with the given request body."""
    queued = _queue_admin_maintenance_job(db, "reset-db", payload, priority=priority)
    return queued
@app.post("/api/admin/backfill-gtfs-shapes")
@write_endpoint("admin backfill GTFS shapes")
def admin_backfill_gtfs_shapes(payload: AdminActionRequest | None = None, priority: int = 0, db: Session = Depends(get_db)) -> dict:
    """Queue the ``backfill-gtfs-shapes`` maintenance action; the body is optional."""
    queued = _queue_admin_maintenance_job(db, "backfill-gtfs-shapes", payload, priority=priority)
    return queued
@app.post("/api/admin/prune-cache")
@write_endpoint("admin prune cache")
def admin_prune_cache(payload: AdminActionRequest, priority: int = 0, db: Session = Depends(get_db)) -> dict:
    """Queue the ``prune-cache`` maintenance action with the given request body."""
    queued = _queue_admin_maintenance_job(db, "prune-cache", payload, priority=priority)
    return queued
@app.post("/api/admin/prune-inactive-datasets")
@write_endpoint("admin prune inactive datasets")
def admin_prune_inactive_datasets(payload: AdminActionRequest, priority: int = 0, db: Session = Depends(get_db)) -> dict:
    """Queue the ``prune-inactive-datasets`` maintenance action with the given body."""
    queued = _queue_admin_maintenance_job(db, "prune-inactive-datasets", payload, priority=priority)
    return queued
@app.post("/api/admin/vacuum-db")
@write_endpoint("admin vacuum database")
def admin_vacuum_database(payload: AdminActionRequest, priority: int = 0, db: Session = Depends(get_db)) -> dict:
    """Queue the ``vacuum-db`` maintenance action with the given request body."""
    queued = _queue_admin_maintenance_job(db, "vacuum-db", payload, priority=priority)
    return queued
@app.post("/api/jobs/admin/{action}")
@write_endpoint("queue admin maintenance")
def queue_admin_maintenance(action: str, payload: AdminActionRequest | None = None, priority: int = 0, db: Session = Depends(get_db)) -> dict:
    """Queue the maintenance action named in the URL path.

    The action name is passed unchanged to ``_queue_admin_maintenance_job``.
    """
    queued = _queue_admin_maintenance_job(db, action, payload, priority=priority)
    return queued
@app.delete("/api/sources/{source_id}")
@write_endpoint("queue source delete")
def remove_source(source_id: int, db: Session = Depends(get_db)) -> dict:
    """Queue a delete job for a source and return the job payload.

    Raises:
        HTTPException: 404 when the source does not exist.
    """
    target = db.get(Source, source_id)
    if target is None:
        raise HTTPException(status_code=404, detail="source not found")
    delete_job = create_source_delete_job(db, target)
    db.commit()
    db.refresh(delete_job)
    return job_payload(delete_job)
@app.get("/api/source-catalog")
def list_source_catalog(
    q: Optional[str] = None,
    country: Optional[str] = None,
    priority: Optional[str] = None,
    status: Optional[str] = None,
    limit: int = 100,
    db: Session = Depends(get_db),
) -> dict:
    """Return filtered catalog entries together with the catalog summary.

    Each entry payload includes the number of sources linked to it (0 when
    none are recorded).
    """
    rows = source_catalog_rows(db, q=q, country=country, priority=priority, status=status, limit=limit)
    linked = linked_source_counts(db, rows)
    summary = source_catalog_summary(db)
    entry_payloads = [
        catalog_entry_payload(row, linked_source_count=linked.get(row.id, 0))
        for row in rows
    ]
    return {"summary": summary, "entries": entry_payloads}
@app.post("/api/source-catalog/import")
@write_endpoint("import source catalog")
def import_source_catalog_endpoint(payload: CatalogImportRequest | None = None, db: Session = Depends(get_db)) -> dict:
    """Import the source catalog and return the result plus the new summary.

    With no request body, ``csv_path`` is ``None`` and ``update_existing`` is
    ``True``.
    """
    csv_path = None if payload is None else payload.csv_path
    update_existing = payload.update_existing if payload else True
    outcome = import_source_catalog(db, csv_path, update_existing=update_existing)
    db.commit()
    return {**outcome, "summary": source_catalog_summary(db)}
@app.post("/api/source-catalog/import-ingestable")
@write_endpoint("import ingestable sources")
def import_ingestable_sources_endpoint(payload: CatalogImportRequest | None = None, db: Session = Depends(get_db)) -> dict:
    """Import ingestable sources and return the result plus the catalog summary.

    With no request body, ``csv_path`` is ``None`` and ``update_existing`` is
    ``True``.
    """
    csv_path = None if payload is None else payload.csv_path
    update_existing = payload.update_existing if payload else True
    outcome = import_ingestable_sources(db, csv_path, update_existing=update_existing)
    db.commit()
    return {**outcome, "summary": source_catalog_summary(db)}
@app.post("/api/jobs/source-catalog/import")
@write_endpoint("queue source catalog import")
def queue_source_catalog_import(payload: CatalogImportRequest | None = None, priority: int = 0, db: Session = Depends(get_db)) -> dict:
    """Queue a ``source-catalog-import`` maintenance job.

    A missing request body is replaced by a default ``CatalogImportRequest``.
    """
    effective_request = payload or CatalogImportRequest()
    job_args = _request_model_payload(effective_request)
    queued_job = create_maintenance_job(db, "source-catalog-import", job_args, priority=priority)
    db.commit()
    db.refresh(queued_job)
    return job_payload(queued_job)
@app.post("/api/jobs/source-catalog/import-ingestable")
@write_endpoint("queue ingestable source import")
def queue_ingestable_source_import(payload: CatalogImportRequest | None = None, priority: int = 0, db: Session = Depends(get_db)) -> dict:
    """Queue a ``source-catalog-import-ingestable`` maintenance job.

    A missing request body is replaced by a default ``CatalogImportRequest``.
    """
    effective_request = payload or CatalogImportRequest()
    job_args = _request_model_payload(effective_request)
    queued_job = create_maintenance_job(db, "source-catalog-import-ingestable", job_args, priority=priority)
    db.commit()
    db.refresh(queued_job)
    return job_payload(queued_job)
@app.get("/api/geofabrik/catalog")
def list_geofabrik_catalog(q: Optional[str] = None, limit: int = 80) -> dict:
    """Return Geofabrik catalog entries, optionally filtered by ``q``.

    Raises:
        HTTPException: 502 when fetching the catalog fails for any reason.
    """
    try:
        entries = geofabrik_catalog(q=q, limit=limit)
    except Exception as exc:  # noqa: BLE001 - remote catalog errors should be visible in UI
        raise HTTPException(status_code=502, detail=f"Geofabrik catalog fetch failed: {exc}") from exc
    return {"entries": entries}
@app.post("/api/geofabrik/sources")
@write_endpoint("create Geofabrik source")
def create_geofabrik_source_endpoint(payload: GeofabrikSourceRequest, db: Session = Depends(get_db)) -> dict:
    """Create a source from a Geofabrik entry and optionally queue its import.

    Returns a dict with ``source`` and ``job``; ``job`` stays ``None`` unless
    ``run_import`` is set.

    Raises:
        HTTPException: 400 when ``create_geofabrik_source`` raises ``ValueError``.
    """
    try:
        source = create_geofabrik_source(db, payload.geofabrik_id, import_updates=payload.import_updates)
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc

    response = {"source": _source_response(db, source), "job": None}
    if not payload.run_import:
        # No import requested: commit here and rebuild the source payload.
        db.commit()
        db.refresh(source)
        response["source"] = _source_response(db, source)
        return response

    import_job, _started = _queue_source_import_job(
        db,
        source,
        run_match=payload.run_match,
        build_route_layer=payload.build_route_layer,
    )
    response["job"] = job_payload(import_job)
    return response
@app.post("/api/jobs/sources/{source_id}/import")
@write_endpoint("queue source import")
def queue_source_import(
    source_id: int,
    run_match: bool = True,
    build_route_layer: bool = True,
    priority: int = 0,
    db: Session = Depends(get_db),
) -> dict:
    """Queue an import job for a source and return the job payload.

    Raises:
        HTTPException: 404 when the source does not exist.
    """
    source = db.get(Source, source_id)
    if source is None:
        raise HTTPException(status_code=404, detail="source not found")
    import_job, _started = _queue_source_import_job(
        db,
        source,
        run_match=run_match,
        build_route_layer=build_route_layer,
        priority=priority,
    )
    return job_payload(import_job)
@app.delete("/api/datasets/{dataset_id}")
@write_endpoint("delete dataset")
def remove_dataset(dataset_id: int, db: Session = Depends(get_db)) -> dict:
    """Queue a delete job for a dataset and return the job payload.

    Raises:
        HTTPException: 404 when the dataset does not exist.
    """
    target = db.get(Dataset, dataset_id)
    if target is None:
        raise HTTPException(status_code=404, detail="dataset not found")
    delete_job = create_dataset_delete_job(db, target)
    db.commit()
    db.refresh(delete_job)
    return job_payload(delete_job)
@app.get("/api/datasets/search")
def dataset_search(q: str = "", active_only: bool = False, limit: int = 80, db: Session = Depends(get_db)) -> dict:
    """Search datasets by delegating to ``search_datasets``."""
    results = search_datasets(db, q, active_only=active_only, limit=limit)
    return results
@app.get("/api/datasets/search/feature.geojson")
def dataset_search_feature(type: str, id: str, db: Session = Depends(get_db)) -> JSONResponse:
    """Return one search result as a GeoJSON feature collection.

    ``type`` selects the lookup: ``gtfs_route``, ``osm_route`` or
    ``route_pattern``. The collection is empty when the serializer yields no
    feature.

    Raises:
        HTTPException: 400 for an unknown ``type``; 404 when the requested
            object is missing (or the OSM feature is not a route).
    """
    if type not in {"gtfs_route", "osm_route", "route_pattern"}:
        raise HTTPException(status_code=400, detail="type must be gtfs_route, osm_route, or route_pattern")

    extra_properties = {"search_result_type": type}
    if type == "gtfs_route":
        route = db.get(GtfsRoute, _path_int(id, "GTFS route id"))
        if route is None:
            raise HTTPException(status_code=404, detail="GTFS route not found")
        feature = gtfs_route_feature(route, extra_properties)
    elif type == "osm_route":
        osm_feature = resolve_osm_feature(db, id)
        if osm_feature is None or osm_feature.kind != "route":
            raise HTTPException(status_code=404, detail="OSM route not found")
        feature = osm_feature_feature(osm_feature, extra_properties)
    else:
        pattern = db.get(RoutePattern, _path_int(id, "route pattern id"))
        if pattern is None:
            raise HTTPException(status_code=404, detail="route pattern not found")
        feature = route_pattern_feature(pattern, extra_properties)

    features = [feature] if feature is not None else []
    return JSONResponse(feature_collection(features))
@app.post("/api/sample/reset")
@write_endpoint("reset sample data")
def reset_and_load_sample(db: Session = Depends(get_db)) -> dict:
    """Load the sample project synchronously, commit, and return its result."""
    sample_result = load_sample_project(db)
    db.commit()
    return sample_result
@app.post("/api/jobs/sample-reset")
@write_endpoint("queue sample reset")
def queue_sample_reset(priority: int = 0, db: Session = Depends(get_db)) -> dict:
    """Queue a ``sample-reset`` maintenance job and return its payload."""
    queued_job = create_maintenance_job(db, "sample-reset", {}, priority=priority)
    db.commit()
    db.refresh(queued_job)
    return job_payload(queued_job)
@app.post("/api/match/run")
@write_endpoint("run route matching")
def run_matching(db: Session = Depends(get_db)) -> dict:
    """Run route matching synchronously, commit, and return its result."""
    matching_result = run_route_matching(db)
    db.commit()
    return matching_result
@app.post("/api/route-layer/build")
@write_endpoint("build route layer")
def build_route_layer(db: Session = Depends(get_db)) -> dict:
    """Rebuild the route layer synchronously, commit, and return its result."""
    rebuild_result = rebuild_route_layer(db)
    db.commit()
    return rebuild_result
@app.post("/api/jobs/route-layer-build")
@write_endpoint("queue route layer build")
def queue_route_layer_build(priority: int = 0, db: Session = Depends(get_db)) -> dict:
    """Queue a route-layer rebuild job and return its payload."""
    queued_job = create_route_layer_rebuild_job(db, priority=priority)
    db.commit()
    db.refresh(queued_job)
    return job_payload(queued_job)
@app.post("/api/jobs/address-index-build")
@write_endpoint("queue OSM address index build")
def queue_address_index_build(priority: int = 0, db: Session = Depends(get_db)) -> dict:
    """Queue an address-index rebuild job and return its payload."""
    queued_job = create_address_index_rebuild_job(db, priority=priority)
    db.commit()
    db.refresh(queued_job)
    return job_payload(queued_job)
@app.post("/api/jobs/match-run")
@write_endpoint("queue route matching")
def queue_route_matching(priority: int = 0, db: Session = Depends(get_db)) -> dict:
    """Queue a route-matching job and return its payload."""
    queued_job = create_route_matching_job(db, priority=priority)
    db.commit()
    db.refresh(queued_job)
    return job_payload(queued_job)
@app.post("/api/jobs/osm-relabel")
@write_endpoint("queue OSM relabeling")
def queue_osm_relabel(
    dataset_id: Optional[int] = None,
    build_route_layer: bool = True,
    force: bool = False,
    priority: int = 0,
    db: Session = Depends(get_db),
) -> dict:
    """Queue an OSM relabel job with the given options and return its payload."""
    queued_job = create_osm_relabel_job(
        db,
        dataset_id=dataset_id,
        build_route_layer=build_route_layer,
        force=force,
        priority=priority,
    )
    db.commit()
    db.refresh(queued_job)
    return job_payload(queued_job)
@app.get("/api/jobs")
def list_jobs(
    response: Response,
    kind: Optional[str] = None,
    limit: int = 20,
    include_dismissed: bool = False,
    db: Session = Depends(get_db),
) -> dict:
    """Return the latest jobs, worker status and the queue revision.

    Interrupted jobs are reconciled (and committed) first. The revision is
    also passed to ``_set_etag`` together with the response object.
    """
    reconcile_interrupted_jobs(db)
    db.commit()

    workers = queue_worker_status()
    revision_info = _job_queue_revision_payload(db, workers=workers, include_dismissed=include_dismissed)
    revision_token = revision_info["revision"]
    _set_etag(response, revision_token)

    visible_jobs = latest_jobs(db, limit=limit, kind=kind, include_dismissed=include_dismissed)
    return {
        "jobs": [job_payload(entry) for entry in visible_jobs],
        "workers": workers,
        "revision": revision_token,
    }
@app.get("/api/jobs/revision")
def get_jobs_revision(
    response: Response,
    since: Optional[str] = None,
    include_dismissed: bool = False,
    db: Session = Depends(get_db),
) -> dict:
    """Return the job-queue revision payload plus a ``changed`` flag.

    ``changed`` is true when ``since`` differs from the current revision.
    """
    reconcile_interrupted_jobs(db)
    db.commit()
    revision_info = _job_queue_revision_payload(db, include_dismissed=include_dismissed)
    current = revision_info["revision"]
    _set_etag(response, current)
    body = dict(revision_info)
    body["changed"] = since != current
    return body
@app.get("/api/jobs/{job_id}")
def get_job(job_id: int, db: Session = Depends(get_db)) -> dict:
    """Return the payload for a single job.

    Raises:
        HTTPException: 404 when the job does not exist.
    """
    found = db.get(Job, job_id)
    if found is None:
        raise HTTPException(status_code=404, detail="job not found")
    return job_payload(found)
@app.get("/api/jobs/{job_id}/events")
def get_job_events(job_id: int, limit: int = 100, db: Session = Depends(get_db)) -> dict:
    """Return a job's payload together with its events.

    Raises:
        HTTPException: 404 when the job does not exist.
    """
    found = db.get(Job, job_id)
    if found is None:
        raise HTTPException(status_code=404, detail="job not found")
    job_body = job_payload(found)
    event_bodies = [job_event_payload(item) for item in job_events(db, job_id, limit=limit)]
    return {"job": job_body, "events": event_bodies}
@app.get("/api/pipeline-runs")
def list_pipeline_runs(
    stage: Optional[str] = None,
    source_id: Optional[int] = None,
    dataset_id: Optional[int] = None,
    limit: int = 50,
    db: Session = Depends(get_db),
) -> dict:
    """Return pipeline runs, newest first, optionally filtered.

    Filters apply to stage, source id and dataset id. ``limit`` is clamped to
    the range 1..200.
    """
    conditions = []
    if stage:
        conditions.append(PipelineRun.stage == stage)
    if source_id is not None:
        conditions.append(PipelineRun.source_id == source_id)
    if dataset_id is not None:
        conditions.append(PipelineRun.dataset_id == dataset_id)

    query = select(PipelineRun).order_by(PipelineRun.started_at.desc(), PipelineRun.id.desc())
    for condition in conditions:
        query = query.where(condition)

    page_size = max(1, min(int(limit), 200))
    found_runs = db.scalars(query.limit(page_size)).all()
    return {"runs": [pipeline_run_payload(item) for item in found_runs]}
@app.post("/api/jobs/{job_id}/pause")
def pause_job_endpoint(job_id: int, db: Session = Depends(get_db)) -> dict:
    """Pause a job via ``_job_control_response`` with the ``pause`` action."""
    outcome = _job_control_response(db, job_id, "pause")
    return outcome
@app.post("/api/jobs/{job_id}/resume")
@write_endpoint("resume job")
def resume_job_endpoint(job_id: int, db: Session = Depends(get_db)) -> dict:
    """Resume a job and return its refreshed payload.

    Raises:
        HTTPException: 404 when ``resume_job`` reports "job not found",
            400 for any other ``ValueError`` it raises.
    """
    try:
        job = resume_job(db, job_id)
    except ValueError as exc:
        detail = str(exc)
        # Same mapping as the retry and dismiss endpoints: a missing job is a
        # 404, any other rejected request is a 400.
        status_code = 404 if detail.startswith("job not found") else 400
        raise HTTPException(status_code=status_code, detail=detail) from exc
    db.commit()
    db.refresh(job)
    return job_payload(job)
@app.post("/api/jobs/{job_id}/retry")
@write_endpoint("retry job")
def retry_job_endpoint(job_id: int, db: Session = Depends(get_db)) -> dict:
    """Retry a job and return its refreshed payload.

    Raises:
        HTTPException: 404 when the error message starts with
            "job not found", otherwise 400.
    """
    try:
        retried = retry_job(db, job_id)
    except ValueError as exc:
        message = str(exc)
        if message.startswith("job not found"):
            code = 404
        else:
            code = 400
        raise HTTPException(status_code=code, detail=message) from exc
    db.commit()
    db.refresh(retried)
    return job_payload(retried)
@app.post("/api/jobs/{job_id}/cancel")
def cancel_job_endpoint(job_id: int, db: Session = Depends(get_db)) -> dict:
    """Cancel a job via ``_job_control_response`` with the ``cancel`` action."""
    outcome = _job_control_response(db, job_id, "cancel")
    return outcome
@app.post("/api/jobs/{job_id}/stop")
def stop_job_endpoint(job_id: int, db: Session = Depends(get_db)) -> dict:
    """Stop a job; behaves exactly like the cancel endpoint."""
    outcome = _job_control_response(db, job_id, "cancel")
    return outcome
def _job_control_response(db: Session, job_id: int, action: str) -> dict:
    """Pause or cancel a job and return the resulting payload.

    ``action`` of ``"pause"`` pauses the job; any other value cancels it. When
    the database is locked, the change is rolled back and the request is
    handed to ``request_job_control`` instead.

    Raises:
        HTTPException: 404 when the helper reports "job not found", 400 for
            any other ``ValueError``.
        OperationalError: re-raised when it is not a database-locked error.
    """
    try:
        job = pause_job(db, job_id) if action == "pause" else cancel_job(db, job_id)
        db.commit()
        db.refresh(job)
        return job_payload(job)
    except OperationalError as exc:
        db.rollback()
        if not _is_database_locked_error(exc):
            raise
        # The database is busy: fall back to request_job_control.
        return request_job_control(job_id, action)
    except ValueError as exc:
        detail = str(exc)
        # Same mapping as the retry and dismiss endpoints: a missing job is a
        # 404, any other rejected request is a 400.
        status_code = 404 if detail.startswith("job not found") else 400
        raise HTTPException(status_code=status_code, detail=detail) from exc
@app.post("/api/jobs/{job_id}/priority")
@write_endpoint("set job priority")
def set_job_priority_endpoint(job_id: int, payload: JobPriorityRequest, db: Session = Depends(get_db)) -> dict:
    """Set a job's priority and return its refreshed payload.

    Raises:
        HTTPException: 404 when ``set_job_priority`` reports "job not found",
            400 for any other ``ValueError`` it raises.
    """
    try:
        job = set_job_priority(db, job_id, payload.priority)
    except ValueError as exc:
        detail = str(exc)
        # Same mapping as the retry and dismiss endpoints: a missing job is a
        # 404, any other rejected request is a 400.
        status_code = 404 if detail.startswith("job not found") else 400
        raise HTTPException(status_code=status_code, detail=detail) from exc
    db.commit()
    db.refresh(job)
    return job_payload(job)
@app.post("/api/jobs/{job_id}/dismiss")
@write_endpoint("dismiss job")
def dismiss_job_endpoint(job_id: int, db: Session = Depends(get_db)) -> dict:
    """Dismiss a job and return its refreshed payload.

    Raises:
        HTTPException: 404 when the error message starts with
            "job not found", otherwise 400.
    """
    try:
        dismissed = dismiss_job(db, job_id)
    except ValueError as exc:
        message = str(exc)
        if message.startswith("job not found"):
            code = 404
        else:
            code = 400
        raise HTTPException(status_code=code, detail=message) from exc
    db.commit()
    db.refresh(dismissed)
    return job_payload(dismissed)
@app.post("/api/jobs/dismiss-terminal")
@write_endpoint("dismiss terminal jobs")
def dismiss_terminal_jobs_endpoint(db: Session = Depends(get_db)) -> dict:
    """Dismiss terminal jobs and return the value ``dismiss_terminal_jobs`` reports."""
    dismissed_count = dismiss_terminal_jobs(db)
    db.commit()
    return {"dismissed": dismissed_count}
@app.get("/api/stats")
def stats(db: Session = Depends(get_db)) -> dict:
    """Return headline counts and route-match coverage for the workbench.

    Missing GTFS sidecar recovery jobs are queued (and committed) first. GTFS
    and OSM figures are restricted to active datasets; match statistics are
    grouped by status, overall and for the ``in_osm_scope`` bucket of the
    match scope summary.
    """
    queue_missing_gtfs_sidecar_recovery_jobs(db)
    db.commit()

    def active_ids(*extra_filters):
        # Ids of active datasets, optionally narrowed by further conditions.
        rows = db.execute(select(Dataset.id).where(Dataset.is_active.is_(True), *extra_filters)).all()
        return [row[0] for row in rows]

    def percent(part, whole):
        # Share of ``part`` in ``whole`` with one decimal; 0 for an empty whole.
        return round((part / whole) * 100, 1) if whole else 0

    active_dataset_ids = active_ids()

    def count_active(model, *where):
        if not active_dataset_ids:
            return 0
        stmt = select(func.count()).select_from(model).where(model.dataset_id.in_(active_dataset_ids), *where)
        return db.scalar(stmt) or 0

    active_gtfs_dataset_ids = active_ids(Dataset.kind == "gtfs")
    active_osm_dataset_ids = active_ids(Dataset.kind == "osm_geojson")

    match_counts = {}
    if active_gtfs_dataset_ids:
        status_rows = db.execute(
            select(RouteMatch.status, func.count())
            .join(RouteMatch.gtfs_route)
            .where(GtfsRoute.dataset_id.in_(active_gtfs_dataset_ids))
            .group_by(RouteMatch.status)
        ).all()
        match_counts = {status: count for status, count in status_rows}

    matched_total = (match_counts.get("matched", 0) or 0) + (match_counts.get("accepted", 0) or 0)
    review_total = sum(match_counts.values())

    scope_summary = _match_scope_summary(db, active_gtfs_dataset_ids)
    in_scope_counts = scope_summary.get("in_osm_scope", {})
    scoped_review_total = sum(in_scope_counts.values())
    scoped_matched_total = in_scope_counts.get("matched", 0) + in_scope_counts.get("accepted", 0)

    return {
        "sources": db.scalar(select(func.count()).select_from(Source)) or 0,
        "source_catalog": source_catalog_summary(db),
        "active_datasets": len(active_dataset_ids),
        "gtfs_routes": count_active(GtfsRoute),
        "gtfs_stops": count_active(GtfsStop),
        "osm_routes": sum(osm_feature_count(db, dataset_id, kind="route") for dataset_id in active_osm_dataset_ids),
        "osm_stops_terminals": sum(
            osm_feature_count(db, dataset_id, kind=["stop", "station", "terminal"])
            for dataset_id in active_osm_dataset_ids
        ),
        "route_patterns": db.scalar(select(func.count()).select_from(RoutePattern)) or 0,
        "matches": match_counts,
        "match_summary": {
            "reviewed_routes": review_total,
            "matched_or_accepted": matched_total,
            "probable": match_counts.get("probable", 0) or 0,
            "weak": match_counts.get("weak", 0) or 0,
            "missing": match_counts.get("missing", 0) or 0,
            "coverage_percent": percent(matched_total, review_total),
            "in_scope_reviewed_routes": scoped_review_total,
            "in_scope_matched_or_accepted": scoped_matched_total,
            "in_scope_coverage_percent": percent(scoped_matched_total, scoped_review_total),
        },
        "match_scope_summary": scope_summary,
        "by_source": _source_stats(db),
    }
@app.get("/api/qa/summary")
def qa_summary_endpoint(db: Session = Depends(get_db)) -> dict:
    """Return the QA summary produced by ``qa_summary``."""
    summary = qa_summary(db)
    return summary
@app.get("/api/harmonization/gtfs/inventory")
def gtfs_harmonization_inventory_endpoint(db: Session = Depends(get_db)) -> dict:
    """Return the GTFS harmonization inventory."""
    inventory = gtfs_harmonization_inventory(db)
    return inventory
@app.get("/api/harmonization/gtfs/sources/{source_id}")
def gtfs_harmonization_feed_detail_endpoint(source_id: int, db: Session = Depends(get_db)) -> dict:
    """Return harmonization detail for one GTFS source.

    Raises:
        HTTPException: 404 when no detail is available for the source.
    """
    feed_detail = gtfs_harmonization_feed_detail(db, source_id)
    if feed_detail is not None:
        return feed_detail
    raise HTTPException(status_code=404, detail="GTFS source not found")
@app.patch("/api/harmonization/gtfs/sources/{source_id}/review")
@write_endpoint("update GTFS feed QA review")
def update_gtfs_harmonization_feed_review(source_id: int, payload: GtfsFeedReviewUpdate, db: Session = Depends(get_db)) -> dict:
    """Apply a QA review update to a GTFS source and return its detail.

    ``license`` and ``enabled`` are only changed when provided. The QA note is
    rewritten when either ``review_status`` or ``review_note`` is provided.

    Raises:
        HTTPException: 404 when the source is missing or not of kind ``gtfs``;
            400 for an unsupported ``review_status``.
    """
    source = db.get(Source, source_id)
    if source is None or source.kind != "gtfs":
        raise HTTPException(status_code=404, detail="GTFS source not found")

    allowed_statuses = {"unreviewed", "approved", "needs_review", "blocked", "rejected"}
    new_status = payload.review_status
    if new_status is not None and new_status not in allowed_statuses:
        raise HTTPException(status_code=400, detail="review_status must be unreviewed, approved, needs_review, blocked, or rejected")

    if payload.license is not None:
        source.license = _truncate(payload.license, 255)
    if payload.enabled is not None:
        source.enabled = bool(payload.enabled)

    new_note = payload.review_note
    if not (new_status is None and new_note is None):
        # NOTE(review): when only one of status/note is sent, the other is
        # passed as "unreviewed" / "" - confirm that _upsert_gtfs_qa_note does
        # not overwrite a previously stored status or note with these defaults.
        source.notes = _upsert_gtfs_qa_note(
            source.notes,
            status=new_status or "unreviewed",
            note=new_note or "",
        )

    db.commit()
    feed_detail = gtfs_harmonization_feed_detail(db, source_id)
    if feed_detail is None:
        raise HTTPException(status_code=404, detail="GTFS source not found")
    return feed_detail
@app.get("/api/matches")
def list_matches(
    status: Optional[str] = None,
    source_id: Optional[str] = None,
    limit: int = 250,
    db: Session = Depends(get_db),
) -> list[dict]:
    """Return route matches for active GTFS datasets, best confidence first.

    Args:
        status: Optional match status to filter on.
        source_id: Optional comma-separated source ids used to narrow the
            active GTFS datasets.
        limit: Maximum number of rows, clamped to the range 1..1000.

    Returns:
        A list of match rows; empty when no active GTFS dataset qualifies.
    """
    source_ids = _csv_ints(source_id, "source_id")
    gtfs_dataset_ids = _active_dataset_ids(db, source_ids=source_ids, dataset_kinds=["gtfs"])
    if not gtfs_dataset_ids:
        return []
    stmt = (
        select(RouteMatch)
        .options(joinedload(RouteMatch.gtfs_route), joinedload(RouteMatch.osm_feature))
        .order_by(RouteMatch.confidence.desc(), RouteMatch.id)
        .join(RouteMatch.gtfs_route)
        .where(GtfsRoute.dataset_id.in_(gtfs_dataset_ids))
    )
    if status:
        stmt = stmt.where(RouteMatch.status == status)
    # Clamp both ends, like the other list endpoints. A negative LIMIT removes
    # the cap on SQLite and is rejected outright by PostgreSQL.
    page_size = max(1, min(limit, 1000))
    matches = db.scalars(stmt.limit(page_size)).all()
    return [match_row(m) for m in matches]
@app.post("/api/matches/{match_id}/accept")
@write_endpoint("accept route match")
def accept_match(match_id: int, db: Session = Depends(get_db)) -> dict:
    """Mark a route match as manually accepted and persist a matching rule.

    Raises:
        HTTPException: 404 when the match does not exist.
    """
    record = db.get(RouteMatch, match_id)
    if record is None:
        raise HTTPException(status_code=404, detail="match not found")
    record.status, record.rule_source = "accepted", "manual"
    record.updated_at = datetime.now(timezone.utc)
    _persist_match_rule(db, record, "accept_match")
    db.commit()
    return {"id": record.id, "status": record.status}
@app.post("/api/matches/{match_id}/reject")
@write_endpoint("reject route match")
def reject_match(match_id: int, db: Session = Depends(get_db)) -> dict:
    """Mark a route match as manually rejected.

    The decision is also handed to ``_persist_match_rule`` with the
    ``"reject_match"`` rule type before committing. Raises a 404 when no
    ``RouteMatch`` has the given id.
    """
    route_match = db.get(RouteMatch, match_id)
    if route_match is None:
        raise HTTPException(status_code=404, detail="match not found")

    route_match.status = "rejected"
    route_match.rule_source = "manual"
    route_match.updated_at = datetime.now(timezone.utc)
    _persist_match_rule(db, route_match, "reject_match")
    db.commit()

    response = {"id": route_match.id, "status": route_match.status}
    return response
@app.get("/api/matches/{match_id}/candidates")
def match_candidates(match_id: int, limit: int = 25, db: Session = Depends(get_db)) -> dict:
    """Score alternative OSM route candidates for the GTFS route of a match.

    Candidates come from ``candidate_osm_routes_for_route`` over the active
    ``osm_geojson`` datasets and are scored with ``score_route_pair``. The
    best ``limit`` rows (clamped to 1..100) are returned together with a
    preview built by ``_match_candidate_preview``.

    Raises a 404 when the match or its GTFS route does not exist. When no
    OSM dataset is active the response has the same keys as the regular one,
    with an empty candidate list.
    """
    match = db.get(RouteMatch, match_id)
    if match is None:
        raise HTTPException(status_code=404, detail="match not found")
    route = db.get(GtfsRoute, match.gtfs_route_id)
    if route is None:
        raise HTTPException(status_code=404, detail="GTFS route for match not found")

    osm_dataset_ids = _active_dataset_ids(db, dataset_kinds=["osm_geojson"])
    if not osm_dataset_ids:
        # Keep the response shape identical to the populated case so clients
        # can read current_status / current_confidence unconditionally.
        return {
            "match_id": match.id,
            "current_status": match.status,
            "current_confidence": match.confidence,
            "route": _gtfs_route_summary(route),
            "candidates": [],
            "preview": _match_candidate_preview(route, match, []),
        }

    candidate_rows = []
    for feature in candidate_osm_routes_for_route(db, route, osm_dataset_ids):
        score, reasons = score_route_pair(route, feature)
        candidate_rows.append((float(score), feature.ref or "", feature.name or "", feature, reasons))
    # Highest score first; ref and name act as tie-breakers (also descending
    # because the whole key is reversed).
    candidate_rows.sort(key=lambda row: (row[0], row[1], row[2]), reverse=True)
    selected_rows = candidate_rows[: max(1, min(limit, 100))]

    candidates = [
        {
            "score": round(score, 2),
            "status": _status_from_candidate_score(score),
            "osm": _osm_route_summary(feature),
            "reasons": reasons,
            "current_match": feature.id == match.osm_feature_id,
        }
        for score, _, _, feature, reasons in selected_rows
    ]
    return {
        "match_id": match.id,
        "current_status": match.status,
        "current_confidence": match.confidence,
        "route": _gtfs_route_summary(route),
        "candidates": candidates,
        "preview": _match_candidate_preview(
            route,
            match,
            [(feature, score, reasons) for score, _, _, feature, reasons in selected_rows],
        ),
    }
@app.post("/api/matches/{match_id}/candidates/{osm_feature_id}/accept")
@write_endpoint("accept route match candidate")
def accept_match_candidate(match_id: int, osm_feature_id: str, db: Session = Depends(get_db)) -> dict:
    """Re-point a match at a chosen OSM route candidate and accept it.

    The feature is resolved with ``resolve_osm_feature``, passed through
    ``ensure_main_osm_feature`` and re-scored against the GTFS route; the
    match then stores the new feature, score and reasons as a manual accept.

    Raises a 404 when the match, its GTFS route, or an OSM feature of kind
    ``route`` cannot be found.
    """
    match = db.get(RouteMatch, match_id)
    if match is None:
        raise HTTPException(status_code=404, detail="match not found")

    route = db.get(GtfsRoute, match.gtfs_route_id)
    candidate = resolve_osm_feature(db, osm_feature_id)
    if route is None:
        raise HTTPException(status_code=404, detail="GTFS route for match not found")
    if candidate is None or candidate.kind != "route":
        raise HTTPException(status_code=404, detail="OSM route candidate not found")

    candidate = ensure_main_osm_feature(db, candidate)
    score, reasons = score_route_pair(route, candidate)

    match.osm_feature = candidate
    match.osm_feature_id = candidate.id
    match.confidence = float(score)
    match.status = "accepted"
    match.rule_source = "manual"
    match.reasons_json = json.dumps(reasons, separators=(",", ":"))
    match.updated_at = datetime.now(timezone.utc)
    _persist_match_rule(db, match, "accept_match")

    db.commit()
    db.refresh(match)
    return {
        "id": match.id,
        "status": match.status,
        "confidence": match.confidence,
        "match": match_row(match),
    }
@app.post("/api/rules")
@write_endpoint("create match rule")
def create_rule(payload: RuleCreate, db: Session = Depends(get_db)) -> dict:
    """Store a new ``MatchRule`` and return its id and rule type.

    The selector and action from the payload are serialized as compact JSON.
    """
    compact = (",", ":")
    selector_json = json.dumps(payload.selector, separators=compact)
    action_json = json.dumps(payload.action, separators=compact)
    new_rule = MatchRule(
        rule_type=payload.rule_type,
        selector_json=selector_json,
        action_json=action_json,
        note=payload.note,
    )
    db.add(new_rule)
    db.commit()
    db.refresh(new_rule)
    return {"id": new_rule.id, "rule_type": new_rule.rule_type}
@app.get("/api/rules")
def list_rules(db: Session = Depends(get_db)) -> list[dict]:
    """Return every match rule, newest id first, with decoded selector/action."""
    newest_first = select(MatchRule).order_by(MatchRule.id.desc())
    payloads = []
    for rule in db.scalars(newest_first).all():
        payloads.append(
            {
                "id": rule.id,
                "rule_type": rule.rule_type,
                "selector": json.loads(rule.selector_json),
                "action": json.loads(rule.action_json),
                "note": rule.note,
                "active": rule.active,
                "created_at": rule.created_at.isoformat() if rule.created_at else None,
            }
        )
    return payloads
@app.get("/api/canonical-stops/{canonical_stop_id}")
def canonical_stop_detail(canonical_stop_id: int, db: Session = Depends(get_db)) -> dict:
    """Return the detail payload of one canonical stop, or a 404 if absent."""
    canonical = db.get(CanonicalStop, canonical_stop_id)
    if canonical is not None:
        return _canonical_stop_detail(db, canonical)
    raise HTTPException(status_code=404, detail="canonical stop not found")
@app.get("/api/canonical-stops/{canonical_stop_id}/gtfs-candidates")
def canonical_stop_gtfs_candidates(
    canonical_stop_id: int,
    q: Optional[str] = None,
    source_id: Optional[str] = None,
    limit: int = 30,
    db: Session = Depends(get_db),
) -> dict:
    """List GTFS stops that could be linked to the given canonical stop.

    Returns the canonical stop summary plus the candidates produced by
    ``_gtfs_stop_candidates_for_canonical_stop``. Raises a 404 when the
    canonical stop does not exist.
    """
    canonical = db.get(CanonicalStop, canonical_stop_id)
    if canonical is None:
        raise HTTPException(status_code=404, detail="canonical stop not found")

    source_ids = _csv_ints(source_id, "source_id")
    candidates = _gtfs_stop_candidates_for_canonical_stop(db, canonical, q=q, source_ids=source_ids, limit=limit)
    summary = _canonical_stop_summary(canonical)
    return {"canonical_stop": summary, "candidates": candidates}
@app.post("/api/canonical-stops/{canonical_stop_id}/link-gtfs-stop")
@write_endpoint("link GTFS stop to canonical stop")
def link_gtfs_stop_to_canonical_stop(
    canonical_stop_id: int,
    payload: CanonicalStopGtfsLinkRequest,
    db: Session = Depends(get_db),
) -> dict:
    """Manually attach a GTFS stop to a canonical stop.

    An existing ``gtfs_stop`` link for the stop is re-pointed; otherwise a new
    ``timetable`` link is created. Either way the link gets confidence 1.0 and
    a ``manual_rule`` metadata marker, and the decision is recorded through
    ``_persist_canonical_stop_rule``. Returns the refreshed canonical stop
    detail. Raises a 404 when the canonical stop does not exist.
    """
    canonical = db.get(CanonicalStop, canonical_stop_id)
    if canonical is None:
        raise HTTPException(status_code=404, detail="canonical stop not found")

    stop = _resolve_gtfs_stop_link_payload(db, payload)
    existing_link_query = select(CanonicalStopLink).where(
        CanonicalStopLink.object_type == "gtfs_stop",
        CanonicalStopLink.dataset_id == stop.dataset_id,
        CanonicalStopLink.object_id == stop.id,
    )
    link = db.scalar(existing_link_query)

    # Stops without a parent station are linked as "parent", others as "platform".
    role = "platform" if stop.parent_station is not None else "parent"
    manual_metadata = json.dumps({"manual_rule": "link_canonical_stop"}, separators=(",", ":"))

    if link is not None:
        link.canonical_stop_id = canonical.id
        link.role = role
        link.confidence = 1.0
        link.metadata_json = manual_metadata
    else:
        link = CanonicalStopLink(
            canonical_stop_id=canonical.id,
            layer="timetable",
            object_type="gtfs_stop",
            dataset_id=stop.dataset_id,
            object_id=stop.id,
            external_id=stop.stop_id,
            role=role,
            confidence=1.0,
            distance_m=None,
            metadata_json=manual_metadata,
        )
        db.add(link)

    _persist_canonical_stop_rule(db, "link_canonical_stop", stop, canonical, payload.note)
    db.commit()
    db.refresh(canonical)
    return _canonical_stop_detail(db, canonical)
@app.post("/api/canonical-stop-links/{link_id}/unlink")
@write_endpoint("unlink GTFS stop from canonical stop")
def unlink_gtfs_stop_from_canonical_stop(link_id: int, db: Session = Depends(get_db)) -> dict:
    """Detach a GTFS stop link by moving it to its own standalone canonical stop.

    The link row is kept and re-pointed at the canonical stop returned by
    ``_manual_standalone_canonical_stop``; the change is recorded through
    ``_persist_canonical_stop_rule``.

    Raises a 404 when the link or its GTFS stop is missing, and a 400 when the
    link is not a ``gtfs_stop`` link.
    """
    link = db.get(CanonicalStopLink, link_id)
    if link is None:
        raise HTTPException(status_code=404, detail="canonical stop link not found")
    if link.object_type != "gtfs_stop":
        raise HTTPException(status_code=400, detail="only GTFS stop links can be unlinked here")

    stop = db.get(GtfsStop, link.object_id)
    if stop is None:
        raise HTTPException(status_code=404, detail="linked GTFS stop not found")

    target = _manual_standalone_canonical_stop(db, stop)
    link.canonical_stop_id = target.id
    link.confidence = 1.0
    link.metadata_json = json.dumps({"manual_rule": "unlink_canonical_stop"}, separators=(",", ":"))
    _persist_canonical_stop_rule(db, "unlink_canonical_stop", stop, target, None)

    db.commit()
    db.refresh(target)
    return {"canonical_stop": _canonical_stop_summary(target), "moved_link_id": link.id}
@app.get("/api/map/osm_routes.geojson")
def map_osm_routes(
    mode: Optional[str] = None,
    route_scope: Optional[str] = None,
    source_id: Optional[str] = None,
    dataset_id: Optional[str] = None,
    bbox: Optional[str] = None,
    zoom: Optional[int] = None,
    limit: int = MAP_FEATURE_LIMIT,
    db: Session = Depends(get_db),
) -> JSONResponse:
    """Serve OSM features of kind ``route`` via ``_osm_features_response``.

    ``zoom`` is accepted as a query parameter but not used by this handler.
    """
    filters = {
        "kind": "route",
        "mode": mode,
        "route_scope": route_scope,
        "source_id": source_id,
        "dataset_id": dataset_id,
        "bbox": bbox,
        "limit": limit,
    }
    return _osm_features_response(db=db, **filters)
@app.get("/api/map/osm_stops.geojson")
def map_osm_stops(
    geometry: Optional[str] = None,
    source_id: Optional[str] = None,
    dataset_id: Optional[str] = None,
    bbox: Optional[str] = None,
    zoom: Optional[int] = None,
    limit: int = MAP_FEATURE_LIMIT,
    db: Session = Depends(get_db),
) -> JSONResponse:
    """Serve OSM stop, station and terminal features via ``_osm_features_response``.

    ``zoom`` is accepted as a query parameter but not used by this handler.
    """
    filters = {
        "kind": "stop,station,terminal",
        "geometry": geometry,
        "source_id": source_id,
        "dataset_id": dataset_id,
        "bbox": bbox,
        "limit": limit,
    }
    return _osm_features_response(db=db, **filters)
@app.get("/api/map/osm_features.geojson")
def map_osm_features(
    kind: Optional[str] = None,
    mode: Optional[str] = None,
    route_scope: Optional[str] = None,
    geometry: Optional[str] = None,
    source_id: Optional[str] = None,
    dataset_id: Optional[str] = None,
    bbox: Optional[str] = None,
    zoom: Optional[int] = None,
    limit: int = MAP_FEATURE_LIMIT,
    db: Session = Depends(get_db),
) -> JSONResponse:
    """Serve OSM features with all filters forwarded to ``_osm_features_response``.

    ``zoom`` is accepted as a query parameter but not used by this handler.
    """
    filters = {
        "kind": kind,
        "mode": mode,
        "route_scope": route_scope,
        "geometry": geometry,
        "source_id": source_id,
        "dataset_id": dataset_id,
        "bbox": bbox,
        "limit": limit,
    }
    return _osm_features_response(db=db, **filters)
@app.get("/api/map/gtfs_routes.geojson")
def map_gtfs_routes(
    mode: Optional[str] = None,
    route_scope: Optional[str] = None,
    source_id: Optional[str] = None,
    dataset_id: Optional[str] = None,
    bbox: Optional[str] = None,
    zoom: Optional[int] = None,
    limit: int = MAP_FEATURE_LIMIT,
    db: Session = Depends(get_db),
) -> JSONResponse:
    """Serve GTFS routes of active GTFS datasets as a feature collection.

    Optional filters: comma-separated ``mode`` and ``route_scope`` values and
    a ``bbox`` overlap test. Routes for which ``gtfs_route_feature`` returns
    ``None`` are left out. ``zoom`` is accepted but not used here.
    """
    requested_sources = _csv_ints(source_id, "source_id")
    requested_datasets = _csv_ints(dataset_id, "dataset_id")
    active_dataset_ids = _active_dataset_ids(
        db,
        source_ids=requested_sources,
        dataset_ids=requested_datasets,
        dataset_kinds=["gtfs"],
    )
    if not active_dataset_ids:
        return JSONResponse(feature_collection([]))

    stmt = select(GtfsRoute).where(GtfsRoute.dataset_id.in_(active_dataset_ids))
    modes = _csv_values(mode)
    if modes:
        stmt = stmt.where(GtfsRoute.mode.in_(modes))
    route_scopes = _csv_values(route_scope)
    if route_scopes:
        stmt = stmt.where(_gtfs_route_scope_condition(route_scopes))
    parsed_bbox = _parse_bbox(bbox)
    if parsed_bbox:
        stmt = _where_bbox_overlaps(stmt, GtfsRoute, parsed_bbox)

    ordered = stmt.order_by(GtfsRoute.mode, GtfsRoute.short_name, GtfsRoute.id)
    routes = db.scalars(ordered.limit(_clamp_limit(limit))).all()
    return JSONResponse(
        feature_collection(feature for route in routes if (feature := gtfs_route_feature(route)) is not None)
    )
@app.get("/api/map/route_patterns.geojson")
def map_route_patterns(
    mode: Optional[str] = None,
    route_scope: Optional[str] = None,
    source_kind: Optional[str] = None,
    status: Optional[str] = None,
    bbox: Optional[str] = None,
    zoom: Optional[int] = None,
    limit: int = MAP_FEATURE_LIMIT,
    db: Session = Depends(get_db),
) -> JSONResponse:
    """Serve route patterns as a feature collection.

    Optional filters: comma-separated ``mode``, ``route_scope``,
    ``source_kind`` and ``status`` values and a ``bbox`` overlap test.
    Patterns for which ``route_pattern_feature`` returns ``None`` are left
    out. ``zoom`` is accepted but not used here.
    """
    stmt = select(RoutePattern)

    modes = _csv_values(mode)
    if modes:
        stmt = stmt.where(RoutePattern.mode.in_(modes))

    route_scopes = _csv_values(route_scope)
    if route_scopes:
        stmt = stmt.where(_route_pattern_scope_condition(route_scopes))

    source_kinds = _csv_values(source_kind)
    if source_kinds:
        stmt = stmt.where(RoutePattern.source_kind.in_(source_kinds))

    statuses = _csv_values(status)
    if statuses:
        stmt = stmt.where(RoutePattern.status.in_(statuses))

    parsed_bbox = _parse_bbox(bbox)
    if parsed_bbox:
        stmt = _where_bbox_overlaps(stmt, RoutePattern, parsed_bbox)

    ordered = stmt.order_by(RoutePattern.mode, RoutePattern.route_ref, RoutePattern.source_kind, RoutePattern.id)
    patterns = db.scalars(ordered.limit(_clamp_limit(limit))).all()
    return JSONResponse(
        feature_collection(feature for pattern in patterns if (feature := route_pattern_feature(pattern)) is not None)
    )
@app.get("/api/map/gtfs_stops.geojson")
def map_gtfs_stops(
    source_id: Optional[str] = None,
    dataset_id: Optional[str] = None,
    bbox: Optional[str] = None,
    zoom: Optional[int] = None,
    limit: int = MAP_FEATURE_LIMIT,
    db: Session = Depends(get_db),
) -> JSONResponse:
    """Serve GTFS stops of active GTFS datasets as a feature collection.

    An optional ``bbox`` restricts the stops by point position. Stops for
    which ``gtfs_stop_feature`` returns ``None`` are left out. ``zoom`` is
    accepted but not used here.
    """
    requested_sources = _csv_ints(source_id, "source_id")
    requested_datasets = _csv_ints(dataset_id, "dataset_id")
    active_dataset_ids = _active_dataset_ids(
        db,
        source_ids=requested_sources,
        dataset_ids=requested_datasets,
        dataset_kinds=["gtfs"],
    )
    if not active_dataset_ids:
        return JSONResponse(feature_collection([]))

    stmt = select(GtfsStop).where(GtfsStop.dataset_id.in_(active_dataset_ids))
    parsed_bbox = _parse_bbox(bbox)
    if parsed_bbox:
        stmt = _where_point_bbox(stmt, GtfsStop, parsed_bbox)

    ordered = stmt.order_by(GtfsStop.name, GtfsStop.id)
    stops = db.scalars(ordered.limit(_clamp_limit(limit))).all()
    return JSONResponse(
        feature_collection(feature for stop in stops if (feature := gtfs_stop_feature(stop)) is not None)
    )
@app.get("/api/map/matched_gtfs_routes.geojson")
def map_matched_gtfs_routes(
    status: Optional[str] = None,
    source_id: Optional[str] = None,
    dataset_id: Optional[str] = None,
    bbox: Optional[str] = None,
    zoom: Optional[int] = None,
    limit: int = MAP_FEATURE_LIMIT,
    db: Session = Depends(get_db),
) -> JSONResponse:
    """Serve matched GTFS routes, drawn with the OSM geometry when one is linked.

    Each feature carries match id, status, confidence and GTFS route fields
    as extra properties; ``visual_source`` says whether the OSM feature or
    the GTFS route supplied the feature. Matches are ordered by descending
    confidence. ``zoom`` is accepted but not used here.
    """
    gtfs_dataset_ids = _active_dataset_ids(
        db,
        source_ids=_csv_ints(source_id, "source_id"),
        dataset_ids=_csv_ints(dataset_id, "dataset_id"),
        dataset_kinds=["gtfs"],
    )
    if not gtfs_dataset_ids:
        return JSONResponse(feature_collection([]))

    stmt = select(RouteMatch).options(joinedload(RouteMatch.gtfs_route), joinedload(RouteMatch.osm_feature))
    if status:
        stmt = stmt.where(RouteMatch.status == status)
    stmt = stmt.join(RouteMatch.gtfs_route).where(GtfsRoute.dataset_id.in_(gtfs_dataset_ids))
    parsed_bbox = _parse_bbox(bbox)
    if parsed_bbox:
        stmt = _where_bbox_overlaps(stmt, GtfsRoute, parsed_bbox)
    ordered = stmt.order_by(RouteMatch.confidence.desc(), RouteMatch.id)

    features = []
    for match in db.scalars(ordered.limit(_clamp_limit(limit))).all():
        route = match.gtfs_route
        osm_feature = match.osm_feature
        extra = {
            "match_id": match.id,
            "match_status": match.status,
            "confidence": match.confidence,
            "gtfs_route_id": route.route_id,
            "gtfs_ref": route.short_name,
            "gtfs_mode": route.mode,
            "visual_source": "osm" if osm_feature else "gtfs",
        }
        feature = osm_feature_feature(osm_feature, extra) if osm_feature else gtfs_route_feature(route, extra)
        if feature:
            features.append(feature)
    return JSONResponse(feature_collection(features))
@app.get("/api/journey/stops")
def journey_stops(
    q: Optional[str] = None,
    source_id: Optional[str] = None,
    bbox: Optional[str] = None,
    limit: int = 25,
    db: Session = Depends(get_db),
) -> dict:
    """Search scheduled stops, and addresses when no source filter is given.

    ``limit`` is clamped to 1..100. Without ``source_id`` roughly 70% of the
    budget goes to stops and the rest to addresses. On PostgreSQL a 4 second
    statement timeout is set first; a timed-out stop search returns an empty
    result flagged ``timed_out``, a timed-out address search returns the
    stops found so far with the same flag. Other ``OperationalError``s are
    re-raised.
    """
    if settings.is_postgresql_database:
        db.execute(text("SET LOCAL statement_timeout = '4000ms'"))

    parsed_bbox = _parse_bbox(bbox)
    selected_limit = max(1, min(limit, 100))
    if source_id:
        # A source filter only makes sense for stops, so skip addresses.
        stop_limit, address_limit = selected_limit, 0
    else:
        stop_limit = max(1, int(selected_limit * 0.7))
        address_limit = max(1, selected_limit - stop_limit)

    stops = []
    try:
        stops = search_scheduled_stops(
            db=db,
            query=q,
            source_ids=_csv_ints(source_id, "source_id"),
            bbox=parsed_bbox,
            limit=stop_limit,
        )
    except OperationalError as exc:
        if not _is_statement_timeout_error(exc):
            raise
        db.rollback()
        return {"stops": [], "timed_out": True, "message": "Search timed out; keep typing or try a more specific query."}

    if address_limit:
        try:
            address_hits = search_addresses(db=db, query=q, bbox=parsed_bbox, limit=address_limit)
            stops.extend(address_hits)
        except OperationalError as exc:
            if not _is_statement_timeout_error(exc):
                raise
            db.rollback()
            return {
                "stops": stops[:selected_limit],
                "timed_out": True,
                "message": "Address search timed out; showing stop results.",
            }

    return {"stops": stops[:selected_limit]}
@app.get("/api/journey/nearest-location")
def journey_nearest_location(
    lat: float,
    lon: float,
    source_id: Optional[str] = None,
    limit: int = 4,
    stop_radius_m: float = 35,
    db: Session = Depends(get_db),
) -> dict:
    """Resolve a clicked coordinate to a scheduled stop, an address, or nothing.

    Lookup order: the nearest scheduled stop within ``stop_radius_m`` (clamped
    to 5..120), then ``address_at_point``; the address lookup is skipped while
    an address index rebuild job is active. ``selection_kind`` in the response
    is ``"stop"``, ``"address"`` or ``"coordinate"``. ``limit`` is accepted
    but not used by this handler.

    Raises a 400 when ``lat``/``lon`` are outside the valid range or NaN.
    """
    # Chained comparisons are False for NaN, so NaN is rejected here too
    # (the former `lat < -90 or lat > 90` test let it through).
    if not (-90 <= lat <= 90 and -180 <= lon <= 180):
        raise HTTPException(status_code=400, detail="lat/lon out of range")

    point = {"lat": lat, "lon": lon}
    stops = nearest_scheduled_stops(
        db=db,
        lat=lat,
        lon=lon,
        source_ids=_csv_ints(source_id, "source_id"),
        limit=1,
        radius_m=max(5, min(float(stop_radius_m), 120)),
    )
    location = stops[0] if stops else None
    if location is not None:
        return {**point, "location": location, "locations": [location], "selection_kind": "stop"}

    unresolved = {**point, "location": None, "locations": [], "selection_kind": "coordinate"}
    address_job = active_address_index_rebuild_job(db)
    if address_job is not None:
        return {
            **unresolved,
            "address_lookup_skipped": True,
            "message": f"Address index rebuild job #{address_job.id} is {address_job.status}; using coordinates.",
        }

    address = address_at_point(db=db, lat=lat, lon=lon)
    if address is not None:
        return {**point, "location": address, "locations": [address], "selection_kind": "address"}
    return unresolved
@app.get("/api/addresses/status")
def addresses_status(db: Session = Depends(get_db)) -> dict:
    """Return whatever ``address_index_status`` reports for this session."""
    status_payload = address_index_status(db)
    return status_payload
@app.get("/api/addresses/search")
def addresses_search(
    q: Optional[str] = None,
    bbox: Optional[str] = None,
    limit: int = 25,
    db: Session = Depends(get_db),
) -> dict:
    """Run ``search_addresses`` and wrap the hits under the ``addresses`` key."""
    parsed_bbox = _parse_bbox(bbox)
    hits = search_addresses(db=db, query=q, bbox=parsed_bbox, limit=limit)
    return {"addresses": hits}
@app.get("/api/journey/search")
def journey_search(
    from_stop_id: str,
    to_stop_id: str,
    via_stop_id: Optional[str] = None,
    source_id: Optional[str] = None,
    departure: str = "08:00",
    service_date: Optional[str] = None,
    max_transfers: int = 0,
    transfer_seconds: int = 120,
    limit: int = 5,
    db: Session = Depends(get_db),
) -> dict:
    """Run a synchronous journey search through ``find_journeys``.

    ``max_transfers`` is clamped to 0..5 and ``transfer_seconds`` to 0..3600;
    ``limit`` is passed through unchanged. A ``ValueError`` raised while
    parsing ``source_id`` or searching is reported as a 400.
    """
    bounded_transfers = max(0, min(max_transfers, 5))
    bounded_transfer_seconds = max(0, min(transfer_seconds, 3600))
    try:
        return find_journeys(
            db=db,
            from_stop_id=from_stop_id,
            to_stop_id=to_stop_id,
            departure=departure,
            max_transfers=bounded_transfers,
            transfer_seconds=bounded_transfer_seconds,
            limit=limit,
            source_ids=_csv_ints(source_id, "source_id"),
            via_stop_id=via_stop_id,
            service_date=service_date,
        )
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
@app.post("/api/journey/searches")
def start_progressive_journey_search(payload: JourneySearchRequest) -> dict:
    """Start a progressive journey search and return ``start_journey_search``'s result.

    Unknown ``mode`` values fall back to ``"transit"`` and unknown ``ranking``
    values to ``"recommended"``. ``transfer_seconds`` is clamped to 0..3600
    and ``limit`` to 1..10. A ``ValueError`` from the search is reported as
    a 400.
    """
    known_modes = {"transit", "walk", "drive", "car"}
    known_rankings = {"recommended", "earliest_arrival", "duration", "fewest_transfers"}
    mode = payload.mode if payload.mode in known_modes else "transit"
    ranking = payload.ranking if payload.ranking in known_rankings else "recommended"
    try:
        request = {
            "from_stop_id": payload.from_stop_id,
            "to_stop_id": payload.to_stop_id,
            "via_stop_id": payload.via_stop_id,
            "source_id": payload.source_id,
            "departure": payload.departure,
            "service_date": payload.service_date,
            "mode": mode,
            "direct_only": bool(payload.direct_only),
            "ranking": ranking,
            "transfer_seconds": max(0, min(payload.transfer_seconds, 3600)),
            "limit": max(1, min(payload.limit, 10)),
        }
        return start_journey_search(request)
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
@app.get("/api/journey/searches/{search_id}")
def progressive_journey_search_status(search_id: str) -> dict:
    """Return the current payload of a progressive search; 404 on ``KeyError``."""
    try:
        payload = journey_search_payload(search_id)
    except KeyError as exc:
        raise HTTPException(status_code=404, detail="journey search not found") from exc
    return payload
@app.delete("/api/journey/searches/{search_id}")
def cancel_progressive_journey_search(search_id: str) -> dict:
    """Cancel a progressive search via ``cancel_journey_search``; 404 on ``KeyError``."""
    try:
        outcome = cancel_journey_search(search_id)
    except KeyError as exc:
        raise HTTPException(status_code=404, detail="journey search not found") from exc
    return outcome
@app.post("/api/itineraries/generate")
@write_endpoint("generate itineraries")
def generate_itineraries_endpoint(payload: ItineraryGenerateRequest, db: Session = Depends(get_db)) -> dict:
    """Generate itineraries for the request and commit them.

    ``max_transfers`` is clamped to 0..5, ``transfer_seconds`` to 0..3600 and
    ``limit`` to 1..10. A ``ValueError`` rolls the session back and is
    reported as a 400.
    """
    try:
        options = {
            "from_stop_id": payload.from_stop_id,
            "to_stop_id": payload.to_stop_id,
            "via_stop_id": payload.via_stop_id,
            "departure": payload.departure,
            "service_date": payload.service_date,
            "max_transfers": max(0, min(payload.max_transfers, 5)),
            "transfer_seconds": max(0, min(payload.transfer_seconds, 3600)),
            "limit": max(1, min(payload.limit, 10)),
            "source_ids": _csv_ints(payload.source_id, "source_id"),
            "preferences": payload.preferences,
        }
        result = generate_itineraries(db, **options)
        db.commit()
        return result
    except ValueError as exc:
        db.rollback()
        raise HTTPException(status_code=400, detail=str(exc)) from exc
@app.get("/api/itineraries")
def list_itineraries(saved_only: bool = False, limit: int = 30, db: Session = Depends(get_db)) -> dict:
    """Return ``recent_itineraries`` results under the ``itineraries`` key."""
    itineraries = recent_itineraries(db, saved_only=saved_only, limit=limit)
    return {"itineraries": itineraries}
@app.post("/api/itineraries/{itinerary_id}/save")
@write_endpoint("save itinerary")
def save_itinerary(itinerary_id: int, payload: ItinerarySaveRequest, db: Session = Depends(get_db)) -> dict:
    """Set the saved flag of an itinerary and commit; 404 if it does not exist."""
    itinerary = db.get(Itinerary, itinerary_id)
    if itinerary is None:
        raise HTTPException(status_code=404, detail="itinerary not found")
    outcome = set_itinerary_saved(db, itinerary, payload.saved)
    db.commit()
    return outcome
@app.get("/api/routing/status")
def routing_status_endpoint(db: Session = Depends(get_db)) -> dict:
    """Return whatever ``routing_status`` reports for this session."""
    status_payload = routing_status(db)
    return status_payload
@app.get("/api/routing/route")
def routing_route_endpoint(
    from_lon: float,
    from_lat: float,
    to_lon: float,
    to_lat: float,
    mode: str = "walk",
    dataset_id: Optional[int] = None,
    max_visited: int = 160_000,
    db: Session = Depends(get_db),
) -> dict:
    """Route between two points via ``route_between_points``.

    ``max_visited`` is clamped to 1,000..1,000,000. A ``ValueError`` from the
    router is reported as a 400.
    """
    visit_budget = max(1_000, min(max_visited, 1_000_000))
    try:
        return route_between_points(
            db,
            from_lon=from_lon,
            from_lat=from_lat,
            to_lon=to_lon,
            to_lat=to_lat,
            mode=mode,
            dataset_id=dataset_id,
            max_visited=visit_budget,
        )
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
@app.post("/api/itinerary-legs/{leg_id}/lock")
@write_endpoint("lock itinerary leg")
def lock_itinerary_leg(leg_id: int, payload: ItineraryLegLockRequest, db: Session = Depends(get_db)) -> dict:
    """Set the locked flag of an itinerary leg and commit; 404 if it does not exist."""
    leg = db.get(ItineraryLeg, leg_id)
    if leg is None:
        raise HTTPException(status_code=404, detail="itinerary leg not found")
    outcome = set_leg_locked(db, leg, payload.locked)
    db.commit()
    return outcome
def _canonical_stop_summary(canonical: CanonicalStop) -> dict:
    """Serialize the scalar fields of a canonical stop into a JSON-ready dict.

    ``metadata`` is decoded with ``_json_object``; ``created_at`` becomes an
    ISO string, or ``None`` when unset.
    """
    created = canonical.created_at
    summary = {
        "id": canonical.id,
        "stop_key": canonical.stop_key,
        "name": canonical.name,
        "normalized_name": canonical.normalized_name,
        "lat": canonical.lat,
        "lon": canonical.lon,
        "mode": canonical.mode,
        "metadata": _json_object(canonical.metadata_json),
        "created_at": created.isoformat() if created else None,
    }
    return summary
def _canonical_stop_detail(db: Session, canonical: CanonicalStop) -> dict:
    """Build the full detail payload for a canonical stop.

    Loads every link of the stop with its dataset and source, splits them
    into GTFS stop links and OSM feature links (other object types are
    ignored), and adds the related manual rules.
    """
    link_query = (
        select(CanonicalStopLink, Dataset, Source)
        .join(Dataset, Dataset.id == CanonicalStopLink.dataset_id)
        .join(Source, Source.id == Dataset.source_id)
        .where(CanonicalStopLink.canonical_stop_id == canonical.id)
        .order_by(CanonicalStopLink.layer, Source.name, Dataset.id, CanonicalStopLink.role, CanonicalStopLink.external_id)
    )
    gtfs_links: list[dict] = []
    osm_links: list[dict] = []
    for link, dataset, source in db.execute(link_query).all():
        object_type = link.object_type
        if object_type == "gtfs_stop":
            linked_stop = db.get(GtfsStop, link.object_id)
            gtfs_links.append(_canonical_gtfs_stop_link_payload(link, linked_stop, dataset, source))
        elif object_type == "osm_feature":
            linked_feature = db.get(OsmFeature, link.object_id)
            osm_links.append(_canonical_osm_feature_link_payload(link, linked_feature, dataset, source))

    detail = {
        "canonical_stop": _canonical_stop_summary(canonical),
        "gtfs_stops": gtfs_links,
        "osm_features": osm_links,
    }
    detail["rules"] = _canonical_stop_rule_payloads(db, canonical, gtfs_links)
    return detail
def _canonical_gtfs_stop_link_payload(link: CanonicalStopLink, stop: GtfsStop | None, dataset: Dataset, source: Source) -> dict:
    """Serialize a GTFS stop link with its source/dataset context.

    ``stop`` is ``None`` in the payload when the linked GTFS stop row was
    not found.
    """
    stop_payload = None
    if stop is not None:
        stop_payload = {
            "id": stop.id,
            "dataset_id": stop.dataset_id,
            "stop_id": stop.stop_id,
            "name": stop.name,
            "lat": stop.lat,
            "lon": stop.lon,
            "parent_station": stop.parent_station,
        }
    return {
        "link_id": link.id,
        "layer": link.layer,
        "object_type": link.object_type,
        "dataset_id": link.dataset_id,
        "source_id": source.id,
        "source_name": source.name,
        "source_kind": source.kind,
        "dataset_active": dataset.is_active,
        "external_id": link.external_id,
        "role": link.role,
        "confidence": link.confidence,
        "distance_m": link.distance_m,
        "metadata": _json_object(link.metadata_json),
        "stop": stop_payload,
    }
def _canonical_osm_feature_link_payload(
    link: CanonicalStopLink,
    feature: OsmFeature | None,
    dataset: Dataset,
    source: Source,
) -> dict:
    """Serialize an OSM feature link with its source/dataset context.

    ``feature`` is ``None`` in the payload when the linked OSM feature row
    was not found. The geometry itself is not included, only whether one is
    present and the feature's bounding box.
    """
    feature_payload = None
    if feature is not None:
        geometry_info = {
            "present": bool(feature.geometry_geojson),
            "bbox": [feature.min_lon, feature.min_lat, feature.max_lon, feature.max_lat],
        }
        feature_payload = {
            "id": feature.id,
            "dataset_id": feature.dataset_id,
            "osm_type": feature.osm_type,
            "osm_id": feature.osm_id,
            "kind": feature.kind,
            "mode": feature.mode,
            "name": feature.name,
            "ref": feature.ref,
            "operator": feature.operator,
            "network": feature.network,
            "geometry": geometry_info,
        }
    return {
        "link_id": link.id,
        "layer": link.layer,
        "object_type": link.object_type,
        "dataset_id": link.dataset_id,
        "source_id": source.id,
        "source_name": source.name,
        "source_kind": source.kind,
        "dataset_active": dataset.is_active,
        "external_id": link.external_id,
        "role": link.role,
        "confidence": link.confidence,
        "distance_m": link.distance_m,
        "metadata": _json_object(link.metadata_json),
        "feature": feature_payload,
    }
def _canonical_stop_rule_payloads(db: Session, canonical: CanonicalStop, gtfs_links: list[dict]) -> list[dict]:
    """Collect the manual link/unlink rules that concern a canonical stop.

    Looks at the 300 newest ``link_canonical_stop`` / ``unlink_canonical_stop``
    rules and keeps those that either target this stop's ``stop_key`` or
    select a (source_id, external_id) pair currently linked to it. At most
    50 payloads are returned, newest first.
    """
    recent_rules = db.scalars(
        select(MatchRule)
        .where(MatchRule.rule_type.in_(["link_canonical_stop", "unlink_canonical_stop"]))
        .order_by(MatchRule.id.desc())
        .limit(300)
    ).all()

    linked_keys = set()
    for link in gtfs_links:
        if link.get("source_id") is not None and link.get("external_id"):
            linked_keys.add((int(link["source_id"]), str(link["external_id"])))

    payloads = []
    for rule in recent_rules:
        selector = _json_object(rule.selector_json)
        action = _json_object(rule.action_json)
        # The selector may carry its ids at the top level or nested under
        # "gtfs_stop"; the top-level value wins when present.
        nested = selector.get("gtfs_stop")
        source_ref = selector.get("source_id") or _json_object_value(nested, "source_id")
        external_ref = selector.get("external_id") or _json_object_value(nested, "external_id") or ""
        selector_key = (_optional_int(source_ref), str(external_ref))

        targets_this_stop = action.get("target_stop_key") == canonical.stop_key
        if targets_this_stop or selector_key in linked_keys:
            payloads.append(
                {
                    "id": rule.id,
                    "rule_type": rule.rule_type,
                    "selector": selector,
                    "action": action,
                    "note": rule.note,
                    "active": rule.active,
                    "created_at": rule.created_at.isoformat() if rule.created_at else None,
                }
            )
    return payloads[:50]
def _gtfs_stop_candidates_for_canonical_stop(
    db: Session,
    canonical: CanonicalStop,
    *,
    q: str | None,
    source_ids: list[int] | None,
    limit: int,
) -> list[dict]:
    """Find GTFS stops in active GTFS datasets that could be linked to a canonical stop.

    With a non-blank ``q`` stops are matched by name or stop id (whole phrase,
    or every token); otherwise stops inside a +/-0.015 degree box around the
    canonical stop are used. Returns ``[]`` when there is neither a query nor
    a position, or no active dataset. Up to four times the clamped ``limit``
    (1..100) rows are fetched, then ranked: stops already linked to this
    canonical stop first, then scheduled stops, then by distance, source
    name, stop name and stop id.
    """
    active_dataset_ids = _active_dataset_ids(db, source_ids=source_ids, dataset_kinds=["gtfs"])
    if not active_dataset_ids:
        return []

    result_cap = max(1, min(limit, 100))
    stmt = (
        select(GtfsStop, Dataset, Source)
        .join(Dataset, Dataset.id == GtfsStop.dataset_id)
        .join(Source, Source.id == Dataset.source_id)
        .where(GtfsStop.dataset_id.in_(active_dataset_ids))
    )

    query = (q or "").strip()
    if query:
        pattern = f"%{query}%"
        tokens = query.replace(",", " ").replace("/", " ").split()
        where_parts = [GtfsStop.name.ilike(pattern), GtfsStop.stop_id.ilike(pattern)]
        if tokens:
            per_token = [
                or_(GtfsStop.name.ilike(f"%{token}%"), GtfsStop.stop_id.ilike(f"%{token}%"))
                for token in tokens
            ]
            where_parts.append(and_(*per_token))
        stmt = stmt.where(or_(*where_parts))
    elif canonical.lon is not None and canonical.lat is not None:
        radius = 0.015
        stmt = stmt.where(
            GtfsStop.lon >= canonical.lon - radius,
            GtfsStop.lon <= canonical.lon + radius,
            GtfsStop.lat >= canonical.lat - radius,
            GtfsStop.lat <= canonical.lat + radius,
        )
    else:
        return []

    # Over-fetch so the Python-side ranking below has something to choose from.
    rows = db.execute(stmt.order_by(Source.name, GtfsStop.name, GtfsStop.stop_id).limit(result_cap * 4)).all()
    if not rows:
        return []

    row_ids = [stop.id for stop, _, _ in rows]
    existing_links = db.scalars(
        select(CanonicalStopLink).where(
            CanonicalStopLink.object_type == "gtfs_stop",
            CanonicalStopLink.object_id.in_(row_ids),
        )
    ).all()
    link_by_stop_id = {link.object_id: link for link in existing_links}

    external_ids_by_dataset: dict[int, set[str]] = {}
    for stop, _, _ in rows:
        external_ids_by_dataset.setdefault(stop.dataset_id, set()).add(stop.stop_id)

    scheduled_keys = set()
    for dataset_id, external_ids in external_ids_by_dataset.items():
        for scheduled_id in set(scheduled_stop_ids(db, dataset_id, sorted(external_ids))):
            if scheduled_id in external_ids:
                scheduled_keys.add((dataset_id, scheduled_id))

    candidates = []
    for stop, dataset, source in rows:
        link = link_by_stop_id.get(stop.id)
        candidates.append(
            {
                "id": stop.id,
                "dataset_id": stop.dataset_id,
                "source_id": source.id,
                "source_name": source.name,
                "dataset_active": dataset.is_active,
                "stop_id": stop.stop_id,
                "name": stop.name,
                "lat": stop.lat,
                "lon": stop.lon,
                "parent_station": stop.parent_station,
                "scheduled": (stop.dataset_id, stop.stop_id) in scheduled_keys,
                "distance_m": _distance_m(canonical.lon, canonical.lat, stop.lon, stop.lat),
                "current_canonical_stop_id": None if link is None else link.canonical_stop_id,
                "current_link_id": None if link is None else link.id,
            }
        )

    def rank(item: dict) -> tuple:
        return (
            0 if item["current_canonical_stop_id"] == canonical.id else 1,
            0 if item["scheduled"] else 1,
            float("inf") if item["distance_m"] is None else item["distance_m"],
            item["source_name"] or "",
            item["name"] or "",
            item["stop_id"],
        )

    candidates.sort(key=rank)
    return candidates[:result_cap]
def _resolve_gtfs_stop_link_payload(db: Session, payload: CanonicalStopGtfsLinkRequest) -> GtfsStop:
    """Find the GTFS stop a link request refers to.

    The primary key ``gtfs_stop_id`` is tried first; if that yields nothing,
    the (``dataset_id``, ``stop_id``) pair is used when both are supplied.
    Raises a 404 when neither lookup finds a stop.
    """
    stop = None
    if payload.gtfs_stop_id is not None:
        stop = db.get(GtfsStop, payload.gtfs_stop_id)
    if stop is None and payload.dataset_id is not None and payload.stop_id:
        by_external_id = select(GtfsStop).where(
            GtfsStop.dataset_id == payload.dataset_id,
            GtfsStop.stop_id == payload.stop_id,
        )
        stop = db.scalar(by_external_id)
    if stop is not None:
        return stop
    raise HTTPException(status_code=404, detail="GTFS stop not found")
def _manual_standalone_canonical_stop(db: Session, stop: GtfsStop) -> CanonicalStop:
    """Get or create the standalone canonical stop used for a manually unlinked GTFS stop.

    The stop is keyed ``manual:gtfs_stop:<dataset_id>:<stop_id>``. A new row
    is added and flushed (not committed) so that its id is available.
    """
    stop_key = f"manual:gtfs_stop:{stop.dataset_id}:{stop.stop_id}"
    existing = db.scalar(select(CanonicalStop).where(CanonicalStop.stop_key == stop_key))
    if existing is not None:
        return existing

    display_name = stop.name or stop.stop_id
    origin = {"source": "manual_unlink", "dataset_id": stop.dataset_id, "stop_id": stop.stop_id}
    created = CanonicalStop(
        stop_key=stop_key,
        name=display_name,
        normalized_name=" ".join(display_name.casefold().split()),
        lat=stop.lat,
        lon=stop.lon,
        metadata_json=json.dumps(origin, separators=(",", ":")),
    )
    db.add(created)
    db.flush()
    return created
def _persist_canonical_stop_rule(
    db: Session,
    rule_type: str,
    stop: GtfsStop,
    canonical: CanonicalStop,
    note: str | None,
) -> None:
    """Add a ``MatchRule`` recording a manual link/unlink decision (no commit).

    The selector identifies the GTFS stop by source, dataset and external
    stop id; the action describes the target canonical stop and the GTFS
    stops currently linked to it. A default note is used when ``note`` is
    empty.
    """
    dataset = db.get(Dataset, stop.dataset_id)
    source_id = dataset.source_id if dataset is not None else None
    selector = {
        "object_type": "gtfs_stop",
        "source_id": source_id,
        "dataset_id": stop.dataset_id,
        "external_id": stop.stop_id,
    }
    action = {
        "target_stop_key": canonical.stop_key,
        "target_name": canonical.name,
        "target_lat": canonical.lat,
        "target_lon": canonical.lon,
        "target_mode": canonical.mode,
        "target_gtfs_stops": _target_gtfs_stop_rule_refs(db, canonical.id),
    }
    compact = (",", ":")
    rule = MatchRule(
        rule_type=rule_type,
        selector_json=json.dumps(selector, separators=compact),
        action_json=json.dumps(action, separators=compact),
        note=note or f"Created from canonical stop {canonical.id}",
    )
    db.add(rule)
def _target_gtfs_stop_rule_refs(db: Session, canonical_stop_id: int) -> list[dict]:
    """List up to 50 GTFS stop references linked to a canonical stop.

    Links on active datasets come first, then ordering is by source id, role
    and external id.
    """
    query = (
        select(CanonicalStopLink, Dataset)
        .join(Dataset, Dataset.id == CanonicalStopLink.dataset_id)
        .where(
            CanonicalStopLink.canonical_stop_id == canonical_stop_id,
            CanonicalStopLink.object_type == "gtfs_stop",
        )
        .order_by(Dataset.is_active.desc(), Dataset.source_id, CanonicalStopLink.role, CanonicalStopLink.external_id)
        .limit(50)
    )
    refs = []
    for link, dataset in db.execute(query).all():
        refs.append(
            {
                "source_id": dataset.source_id,
                "dataset_id": link.dataset_id,
                "external_id": link.external_id,
            }
        )
    return refs
def _distance_m(left_lon: float | None, left_lat: float | None, right_lon: float | None, right_lat: float | None) -> float | None:
if None in {left_lon, left_lat, right_lon, right_lat}:
return None
return round((((float(left_lon) - float(right_lon)) ** 2 + (float(left_lat) - float(right_lat)) ** 2) ** 0.5) * 111_320, 1)
def _optional_int(value: object) -> int | None:
try:
return None if value is None else int(value)
except (TypeError, ValueError):
return None
def _json_object_value(value: object, key: str) -> object:
return value.get(key) if isinstance(value, dict) else None
def _job_queue_revision_payload(
    db: Session,
    *,
    workers: list[dict] | None = None,
    include_dismissed: bool = False,
) -> dict:
    """Combine the job-queue revision with the worker status into one payload.

    The result starts as a copy of ``job_queue_revision(...)`` and adds:

    * ``job_revision`` - the queue's own revision,
    * ``worker_revision`` - the fingerprint from ``_worker_revision``,
    * ``revision`` - both joined as ``"<job>|workers:<worker>"``, replacing
      the queue's ``revision`` value,
    * ``workers`` - the worker status list that was used.

    When *workers* is ``None`` the status is fetched with
    ``queue_worker_status()``.
    """
    queue_state = job_queue_revision(db, include_dismissed=include_dismissed)
    worker_status = queue_worker_status() if workers is None else workers
    worker_rev = _worker_revision(worker_status)
    job_rev = queue_state["revision"]

    payload = dict(queue_state)
    payload["job_revision"] = job_rev
    payload["worker_revision"] = worker_rev
    payload["revision"] = f"{job_rev}|workers:{worker_rev}"
    payload["workers"] = worker_status
    return payload
def _worker_revision(workers: list[dict]) -> str:
    """Build a deterministic fingerprint string for a list of worker dicts.

    Workers are sorted by their stringified ``worker_id``.  Each contributes
    ``"<worker_id>:<1|0 running>:<pid>:<log_file>"`` and the entries are
    joined with ``"|"``.  An empty list yields ``"none"``.
    """
    if not workers:
        return "none"

    def fingerprint(worker: dict) -> str:
        fields = (
            str(worker.get("worker_id") or ""),
            "1" if worker.get("running") else "0",
            str(worker.get("pid") or 0),
            str(worker.get("log_file") or ""),
        )
        return ":".join(fields)

    ordered = sorted(workers, key=lambda item: str(item.get("worker_id") or ""))
    return "|".join(fingerprint(worker) for worker in ordered)
def _set_etag(response: Response, revision: str) -> None:
    """Set a weak ``ETag`` header (``W/"<revision>"``) on *response*."""
    weak_tag = 'W/"{}"'.format(revision)
    response.headers["ETag"] = weak_tag
def _upsert_gtfs_qa_note(notes: str | None, *, status: str, note: str) -> str | None:
    """Replace the GTFS QA marker line inside a free-text notes field.

    Existing lines that are blank or start with ``GTFS_QA_NOTE_PREFIX`` are
    dropped; all other lines are kept in order.  A new marker line is put
    first unless the status is ``"unreviewed"`` and no note text was given.
    The marker carries the status, a UTC ISO-8601 ``updated_at`` timestamp
    and, when present, the whitespace-normalised note.

    Returns the rebuilt text, or ``None`` when nothing remains.
    """
    status_text = (status or "unreviewed").strip() or "unreviewed"
    # str.split() with no argument both trims and collapses whitespace runs.
    note_text = " ".join((note or "").split())

    kept_lines: list[str] = []
    for line in str(notes or "").splitlines():
        if not line.strip() or line.startswith(GTFS_QA_NOTE_PREFIX):
            continue
        kept_lines.append(line)

    if status_text == "unreviewed" and not note_text:
        return "\n".join(kept_lines) or None

    timestamp = datetime.now(timezone.utc).isoformat()
    marker_fields = [
        f"{GTFS_QA_NOTE_PREFIX} status={status_text}",
        f"updated_at={timestamp}",
    ]
    if note_text:
        marker_fields.append(f"note={note_text}")
    marker = "; ".join(marker_fields)
    return "\n".join([marker, *kept_lines]) or None
def _source_response(
    db: Session,
    source: Source,
    update_check: SourceUpdateCheck | None = None,
    active_job: Job | None = None,
    active_dataset_jobs: dict[int, Job] | None = None,
) -> dict:
    """Serialise a source, its row counts and its datasets for the API.

    ``is_online`` is true when the source URL uses the ``http`` or ``https``
    scheme.  ``active_job`` and each dataset's ``active_job`` are rendered
    with ``job_payload`` when a job is supplied, otherwise ``None``.
    *active_dataset_jobs* maps dataset id to its running job.
    """
    jobs_by_dataset = active_dataset_jobs or {}

    def dataset_entry(dataset) -> dict:
        running_job = None
        if dataset.id in jobs_by_dataset:
            running_job = job_payload(jobs_by_dataset[dataset.id])
        created_at = dataset.created_at.isoformat() if dataset.created_at else None
        return {
            "id": dataset.id,
            "kind": dataset.kind,
            "is_active": dataset.is_active,
            "status": dataset.status,
            "local_path": dataset.local_path,
            "sha256": dataset.sha256,
            "created_at": created_at,
            "metadata": _json_object(dataset.metadata_json),
            "stats": dataset_row_counts(db, dataset.id, dataset.kind),
            "active_job": running_job,
        }

    scheme = urlparse(source.url).scheme
    last_run_at = source.last_run_at.isoformat() if source.last_run_at else None
    payload = {
        "id": source.id,
        "name": source.name,
        "kind": source.kind,
        "url": source.url,
        "is_online": scheme in {"http", "https"},
        "country": source.country,
        "license": source.license,
        "priority": source.priority,
        "mode_scope": source.mode_scope,
        "source_basis": source.source_basis,
        "notes": source.notes,
        "catalog_entry_id": source.catalog_entry_id,
        "enabled": source.enabled,
        "status": source.status,
        "last_error": source.last_error,
        "last_run_at": last_run_at,
        "active_job": job_payload(active_job) if active_job is not None else None,
        "latest_update_check": update_check_payload(update_check),
        "stats": source_row_counts(db, source),
    }
    payload["datasets"] = [dataset_entry(dataset) for dataset in source.datasets]
    return payload
def _queue_source_import_job(db: Session, source: Source, *, run_match: bool, build_route_layer: bool, priority: int = 0) -> tuple[Job, bool]:
    """Return the import job for *source*, creating one only if none is active.

    Returns ``(job, created)``: ``created`` is ``False`` when an already
    active import job was reused and ``True`` when a new job was created,
    committed and refreshed.
    """
    existing_job = active_source_import_job(db, source.id)
    if existing_job is None:
        new_job = create_source_import_job(
            db,
            source,
            run_match=run_match,
            build_route_layer=build_route_layer,
            priority=priority,
        )
        db.commit()
        db.refresh(new_job)
        return new_job, True

    # The session is committed on the reuse path as well.
    db.commit()
    return existing_job, False
def _queue_admin_maintenance_job(
    db: Session,
    action: str,
    payload: AdminActionRequest | None,
    *,
    priority: int = 0,
) -> dict:
    """Validate an admin action request and queue it as a maintenance job.

    Raises ``HTTPException`` 404 when *action* is not in
    ``ADMIN_JOB_ACTIONS`` and 400 when the required confirmation text is
    missing: ``RESET`` for ``reset-db``, ``VACUUM`` for ``vacuum-db``, and
    ``PRUNE`` for the prune actions unless ``dry_run`` is set.

    ``confirm`` is never stored in the job payload, and ``dry_run`` is
    dropped for ``init-db``, ``vacuum-db`` and ``reset-db``.  Returns the
    ``job_payload`` of the committed job.
    """
    if action not in ADMIN_JOB_ACTIONS:
        raise HTTPException(status_code=404, detail="admin action not found")

    request_payload = payload or AdminActionRequest()

    required_confirmation = None
    if action == "reset-db":
        required_confirmation = "RESET"
    elif action == "vacuum-db":
        required_confirmation = "VACUUM"
    elif action in {"prune-cache", "prune-inactive-datasets"} and not request_payload.dry_run:
        required_confirmation = "PRUNE"
    if required_confirmation is not None and request_payload.confirm != required_confirmation:
        raise HTTPException(
            status_code=400,
            detail=f"confirmation text {required_confirmation} is required",
        )

    job_args = _request_model_payload(request_payload)
    job_args.pop("confirm", None)
    if action in {"init-db", "vacuum-db", "reset-db"}:
        job_args.pop("dry_run", None)

    job = create_maintenance_job(db, action, job_args, priority=priority)
    db.commit()
    db.refresh(job)
    return job_payload(job)
def _request_model_payload(model: BaseModel) -> dict:
    """Dump a request model to a dict, omitting ``None`` fields.

    Calls ``model_dump`` when the model has it and falls back to ``dict``
    otherwise.
    """
    dump_method = "model_dump" if hasattr(model, "model_dump") else "dict"
    return getattr(model, dump_method)(exclude_none=True)
def _active_dataset_ids(
    db: Session,
    source_ids: list[int] | None = None,
    dataset_ids: list[int] | None = None,
    dataset_kinds: list[str] | None = None,
) -> list[int]:
    """Return ids of active datasets, optionally narrowed by the given filters.

    Each filter applies only when it is non-empty; an empty list or ``None``
    means "no restriction" for that dimension.
    """
    optional_filters = (
        (source_ids, Dataset.source_id),
        (dataset_ids, Dataset.id),
        (dataset_kinds, Dataset.kind),
    )
    stmt = select(Dataset.id).where(Dataset.is_active.is_(True))
    for allowed_values, column in optional_filters:
        if allowed_values:
            stmt = stmt.where(column.in_(allowed_values))
    return [dataset_id for (dataset_id,) in db.execute(stmt).all()]
def _gtfs_route_scope_condition(route_scopes: list[str]):
    """Build the SQL condition selecting GTFS routes in the requested scopes.

    A route qualifies when its stored ``route_scope`` is requested or when
    the name/ref fallback heuristic places it in a requested scope.  When
    ``"local"`` is requested, the stored-scope match excludes bus and
    trolleybus routes that the heuristic classifies as long-distance or
    regional.
    """
    build_fallback = _route_scope_fallback_condition(
        mode_column=GtfsRoute.mode,
        ref_column=GtfsRoute.short_name,
        name_column=GtfsRoute.long_name,
    )
    heuristic_match = build_fallback(route_scopes)
    stored_match = GtfsRoute.route_scope.in_(route_scopes)
    if "local" in route_scopes:
        looks_non_local = build_fallback(["long_distance", "regional"])
        is_bus = GtfsRoute.mode.in_(["bus", "trolleybus"])
        stored_match = and_(stored_match, not_(and_(is_bus, looks_non_local)))
    return or_(stored_match, heuristic_match)
def _route_pattern_scope_condition(route_scopes: list[str]):
    """Build the SQL condition selecting route patterns in the requested scopes.

    A pattern qualifies when its stored ``route_scope`` is requested or when
    the name/ref fallback heuristic places it in a requested scope.  When
    ``"local"`` is requested, the stored-scope match excludes bus and
    trolleybus patterns that the heuristic classifies as long-distance or
    regional.
    """
    build_fallback = _route_scope_fallback_condition(
        mode_column=RoutePattern.mode,
        ref_column=RoutePattern.route_ref,
        name_column=RoutePattern.route_name,
    )
    heuristic_match = build_fallback(route_scopes)
    stored_match = RoutePattern.route_scope.in_(route_scopes)
    if "local" in route_scopes:
        looks_non_local = build_fallback(["long_distance", "regional"])
        is_bus = RoutePattern.mode.in_(["bus", "trolleybus"])
        stored_match = and_(stored_match, not_(and_(is_bus, looks_non_local)))
    return or_(stored_match, heuristic_match)
def _route_scope_fallback_condition(*, mode_column, ref_column, name_column):
    """Return a builder of heuristic scope conditions for the given columns.

    The returned ``condition(route_scopes)`` callable classifies rows by
    mode plus upper-cased ref prefixes and name fragments:

    * ``long_distance`` - coaches, trains with long-distance refs/names,
      buses with long-distance refs/names;
    * ``regional`` - trains with regional refs/names that are not
      long-distance, buses with regional names that are not long-distance;
    * ``local`` - tram/light rail/subway/ferry/funicular/aerialway/monorail,
      buses that are neither long-distance nor regional, and trains whose
      ref starts with ``S`` or whose name contains ``S-BAHN``;
    * ``unknown`` - trains matching none of the above.

    Conditions for the requested scopes are OR-ed.  When no known scope is
    requested the result is ``mode IS NULL``.
    """

    def condition(route_scopes: list[str]):
        ref = func.upper(func.coalesce(ref_column, ""))
        name = func.upper(func.coalesce(name_column, ""))

        def ref_starts_with(*prefixes: str) -> list:
            return [ref.like(f"{prefix}%") for prefix in prefixes]

        def name_contains(*fragments: str) -> list:
            return [name.like(f"%{fragment}%") for fragment in fragments]

        train_long_distance = and_(
            mode_column == "train",
            or_(
                *ref_starts_with("ICE", "IC", "EC", "ECE", "EN", "NJ", "RJ", "RJX", "TGV", "THA", "FLX"),
                *name_contains("INTERCITY", "EUROCITY", "NIGHTJET", "FLIXTRAIN"),
            ),
        )
        bus_long_distance = and_(
            mode_column.in_(["bus", "trolleybus"]),
            or_(
                *ref_starts_with("FLX"),
                *name_contains("FLIXBUS", "EUROLINES", "INTERCITYBUS", "IC BUS", "FERNBUS", "LONG DISTANCE"),
            ),
        )
        long_distance = or_(mode_column == "coach", train_long_distance, bus_long_distance)

        bus_regional = and_(
            mode_column.in_(["bus", "trolleybus"]),
            not_(bus_long_distance),
            or_(*name_contains("REGIONALBUS", "REGIOBUS", "REGIONAL BUS", "REGIONALVERKEHR")),
        )
        train_regional = and_(
            mode_column == "train",
            not_(train_long_distance),
            or_(
                *ref_starts_with("IRE", "RE", "RB", "RER", "TER", "REX", "MEX", "ALX", "WFB", "R"),
                *name_contains("REGIONAL", "REGIO"),
            ),
        )
        regional = or_(train_regional, bus_regional)

        local = or_(
            mode_column.in_(["tram", "light_rail", "subway", "ferry", "funicular", "aerialway", "monorail"]),
            and_(mode_column.in_(["bus", "trolleybus"]), not_(or_(bus_long_distance, bus_regional))),
            and_(mode_column == "train", or_(*ref_starts_with("S"), *name_contains("S-BAHN"))),
        )

        # Insertion order fixes the order of the OR-ed terms.
        by_scope = {
            "long_distance": long_distance,
            "regional": regional,
            "local": local,
            "unknown": and_(mode_column == "train", not_(or_(long_distance, regional, local))),
        }
        selected = [expression for scope, expression in by_scope.items() if scope in route_scopes]
        if not selected:
            return mode_column.is_(None)
        return or_(*selected)

    return condition
def _osm_features_response(
    db: Session,
    kind: Optional[str],
    mode: Optional[str] = None,
    route_scope: Optional[str] = None,
    geometry: Optional[str] = None,
    source_id: Optional[str] = None,
    dataset_id: Optional[str] = None,
    bbox: Optional[str] = None,
    limit: int = MAP_FEATURE_LIMIT,
) -> JSONResponse:
    """Return OSM features from active ``osm_geojson`` datasets as GeoJSON.

    *kind*, *mode*, *route_scope*, *source_id* and *dataset_id* are
    comma-separated lists; *bbox* is ``min_lon,min_lat,max_lon,max_lat``.
    When neither *source_id* nor *dataset_id* is given, rows are
    de-duplicated across datasets.  Rows that produce no feature or whose
    geometry type does not match *geometry* are left out.  An empty feature
    collection is returned when no active dataset matches.
    """
    dataset_scope = _active_dataset_ids(
        db,
        source_ids=_csv_ints(source_id, "source_id"),
        dataset_ids=_csv_ints(dataset_id, "dataset_id"),
        dataset_kinds=["osm_geojson"],
    )
    if not dataset_scope:
        return JSONResponse(feature_collection([]))

    rows = query_osm_features(
        db,
        dataset_scope,
        kinds=_csv_values(kind) or None,
        modes=_csv_values(mode) or None,
        route_scopes=_csv_values(route_scope) or None,
        bbox=_parse_bbox(bbox),
        limit=_clamp_limit(limit),
    )
    if not (source_id or dataset_id):
        rows = _dedupe_osm_feature_rows(rows)

    features = [
        feature
        for row in rows
        if (feature := osm_feature_feature(row)) is not None
        and _geometry_matches(feature["geometry"], geometry)
    ]
    return JSONResponse(feature_collection(features))
def _dedupe_osm_feature_rows(rows: list[OsmFeature]) -> list[OsmFeature]:
    """Keep one row per ``(kind, osm_type, osm_id)``.

    Among duplicates the row with the smallest ``_osm_feature_preference``
    key wins; on a tie the earlier row stays.  Output follows the order in
    which each identity was first seen.
    """
    best_by_identity: dict[tuple[str, str, str], OsmFeature] = {}
    for candidate in rows:
        identity = (candidate.kind, candidate.osm_type, candidate.osm_id)
        incumbent = best_by_identity.get(identity)
        if incumbent is not None and not (
            _osm_feature_preference(candidate) < _osm_feature_preference(incumbent)
        ):
            continue
        best_by_identity[identity] = candidate
    return list(best_by_identity.values())
def _osm_feature_preference(row: OsmFeature) -> tuple[int, int, int]:
    """Return the sort key used to pick the preferred duplicate OSM row.

    Lower keys win.  The key is ``(has_no_bbox, -span_micro, dataset_id)``:

    * rows with a complete bounding box sort before rows without one,
    * a larger bounding-box span (width plus height in degrees, scaled by
      1 000 000 and truncated) sorts first,
    * the lower ``dataset_id`` breaks remaining ties.
    """
    bounds = (row.min_lon, row.min_lat, row.max_lon, row.max_lat)
    if any(bound is None for bound in bounds):
        return (1, 0, row.dataset_id)
    min_lon, min_lat, max_lon, max_lat = (float(bound) for bound in bounds)
    span = abs(max_lon - min_lon) + abs(max_lat - min_lat)
    return (0, -int(span * 1_000_000), row.dataset_id)
def _csv_values(value: Optional[str]) -> list[str]:
    """Split a comma-separated string into trimmed, non-empty parts.

    ``None`` and the empty string yield an empty list.
    """
    if not value:
        return []
    trimmed = (piece.strip() for piece in value.split(","))
    return [piece for piece in trimmed if piece]
def _csv_ints(value: Optional[str], name: str) -> list[int]:
    """Parse a comma-separated string of integers.

    Blank entries are skipped; ``None`` and the empty string yield an empty
    list.  Raises ``HTTPException`` 400 naming *name* when an entry is not
    an integer.
    """
    if not value:
        return []
    tokens = [token for token in (piece.strip() for piece in value.split(",")) if token]
    try:
        return [int(token) for token in tokens]
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=f"{name} values must be integers") from exc
def _path_int(value: str, name: str) -> int:
    """Convert a path segment to ``int``.

    Raises ``HTTPException`` 400 naming *name* when the value is not an
    integer.
    """
    try:
        parsed = int(value)
    except (TypeError, ValueError) as exc:
        raise HTTPException(status_code=400, detail=f"{name} must be an integer") from exc
    return parsed
def _json_object(value: str | None) -> dict:
    """Decode *value* as a JSON object.

    Returns ``{}`` when the text is empty, is not valid JSON, or decodes to
    something other than a dict.
    """
    try:
        decoded = json.loads(value or "{}")
    except json.JSONDecodeError:
        decoded = None
    if isinstance(decoded, dict):
        return decoded
    return {}
def _source_stats(db: Session) -> list[dict]:
    """Return one row-count summary per active dataset whose source exists.

    Every entry carries the source and dataset identity plus ``routes``,
    ``stops``, ``trips``, ``stop_times``, ``features`` and ``match_counts``.
    GTFS datasets fill routes/stops/trips/stop_times/match_counts;
    ``osm_geojson`` datasets fill routes/stops/features.  Other kinds keep
    the zero defaults.
    """
    copied_counts = {
        "gtfs": ("routes", "stops", "trips", "stop_times"),
        "osm_geojson": ("routes", "stops", "features"),
    }
    summaries: list[dict] = []
    for dataset in db.scalars(select(Dataset).where(Dataset.is_active.is_(True))).all():
        source = db.get(Source, dataset.source_id)
        if source is None:
            continue
        entry = {
            "source_id": source.id,
            "source_name": source.name,
            "source_kind": source.kind,
            "dataset_id": dataset.id,
            "dataset_kind": dataset.kind,
            "routes": 0,
            "stops": 0,
            "trips": 0,
            "stop_times": 0,
            "features": 0,
            "match_counts": {},
        }
        counts = dataset_row_counts(db, dataset.id, dataset.kind)
        for field in copied_counts.get(dataset.kind, ()):
            entry[field] = counts.get(field, 0)
        if dataset.kind == "gtfs":
            entry["match_counts"] = counts.get("match_counts", {})
        summaries.append(entry)
    return summaries
def _match_scope_summary(db: Session, active_gtfs_dataset_ids: list[int]) -> dict[str, dict[str, int]]:
    """Count route matches per scope and status for the given GTFS datasets.

    The scope is read from the ``scope`` key of each match's
    ``reasons_json``; matches without one are grouped under
    ``"unknown_scope"``.  Returns ``{scope: {status: count}}``, or ``{}``
    when no dataset ids are given.
    """
    if not active_gtfs_dataset_ids:
        return {}
    rows = db.execute(
        select(RouteMatch.status, RouteMatch.reasons_json)
        .join(RouteMatch.gtfs_route)
        .where(GtfsRoute.dataset_id.in_(active_gtfs_dataset_ids))
    ).all()
    summary: dict[str, dict[str, int]] = {}
    for status, reasons_json in rows:
        # _json_object also covers valid JSON that is not an object (e.g. a
        # list or null), which would otherwise fail on ``.get``.
        reasons = _json_object(reasons_json)
        scope = str(reasons.get("scope") or "unknown_scope")
        status_key = str(status)
        scope_counts = summary.setdefault(scope, {})
        scope_counts[status_key] = scope_counts.get(status_key, 0) + 1
    return summary
def _parse_bbox(value: Optional[str]) -> tuple[float, float, float, float] | None:
    """Parse a ``min_lon,min_lat,max_lon,max_lat`` string into a float tuple.

    Returns ``None`` for an empty or missing value.

    Raises ``HTTPException`` 400 when the value does not have exactly four
    parts, when a part is not a finite number (``nan`` and ``inf`` are
    rejected), or when a minimum exceeds its maximum.
    """
    import math  # local import: this block cannot rely on a module-level one

    if not value:
        return None
    parts = value.split(",")
    if len(parts) != 4:
        raise HTTPException(status_code=400, detail="bbox must be min_lon,min_lat,max_lon,max_lat")
    try:
        min_lon, min_lat, max_lon, max_lat = (float(part) for part in parts)
    except ValueError as exc:
        raise HTTPException(status_code=400, detail="bbox values must be numbers") from exc
    # float() accepts "nan"/"inf"; NaN also slips past the ordering check
    # below because every comparison with NaN is false.
    if not all(math.isfinite(bound) for bound in (min_lon, min_lat, max_lon, max_lat)):
        raise HTTPException(status_code=400, detail="bbox values must be finite numbers")
    if min_lon > max_lon or min_lat > max_lat:
        raise HTTPException(status_code=400, detail="bbox minimums must be less than maximums")
    return min_lon, min_lat, max_lon, max_lat
def _where_bbox_overlaps(stmt, model, bbox: tuple[float, float, float, float]):
    """Restrict *stmt* to rows whose stored bounding box overlaps *bbox*.

    On PostgreSQL, for the ``gtfs_routes`` and ``route_patterns`` tables,
    the filter is ``geom && ST_MakeEnvelope(...)`` with a fallback to the
    min/max columns for rows whose ``geom`` is NULL.  In every other case
    only the min/max columns are compared.
    """
    west, south, east, north = bbox
    table_name = getattr(model, "__tablename__", "")
    spatial_tables = {"gtfs_routes", "route_patterns"}
    if not (using_postgresql() and table_name in spatial_tables):
        return stmt.where(
            model.min_lon <= east,
            model.max_lon >= west,
            model.min_lat <= north,
            model.max_lat >= south,
        )
    # table_name is interpolated only after passing the allow-list above;
    # the coordinates travel as bound parameters.
    overlap_sql = text(
        f"""
(
{table_name}.geom && ST_MakeEnvelope(:bbox_min_lon, :bbox_min_lat, :bbox_max_lon, :bbox_max_lat, 4326)
OR (
{table_name}.geom IS NULL
AND {table_name}.min_lon <= :bbox_max_lon
AND {table_name}.max_lon >= :bbox_min_lon
AND {table_name}.min_lat <= :bbox_max_lat
AND {table_name}.max_lat >= :bbox_min_lat
)
)
"""
    )
    return stmt.where(overlap_sql).params(
        bbox_min_lon=west,
        bbox_min_lat=south,
        bbox_max_lon=east,
        bbox_max_lat=north,
    )
def _where_point_bbox(stmt, model, bbox: tuple[float, float, float, float]):
    """Restrict *stmt* to rows whose point position lies inside *bbox*.

    On PostgreSQL, for the ``gtfs_stops`` and ``canonical_stops`` tables,
    the filter is ``geom && ST_MakeEnvelope(...)`` with a fallback to the
    ``lon``/``lat`` columns for rows whose ``geom`` is NULL.  In every other
    case only ``lon``/``lat`` are compared (bounds inclusive).
    """
    west, south, east, north = bbox
    table_name = getattr(model, "__tablename__", "")
    spatial_tables = {"gtfs_stops", "canonical_stops"}
    if not (using_postgresql() and table_name in spatial_tables):
        return stmt.where(
            model.lon >= west,
            model.lon <= east,
            model.lat >= south,
            model.lat <= north,
        )
    # table_name is interpolated only after passing the allow-list above;
    # the coordinates travel as bound parameters.
    inside_sql = text(
        f"""
(
{table_name}.geom && ST_MakeEnvelope(:bbox_min_lon, :bbox_min_lat, :bbox_max_lon, :bbox_max_lat, 4326)
OR (
{table_name}.geom IS NULL
AND {table_name}.lon >= :bbox_min_lon
AND {table_name}.lon <= :bbox_max_lon
AND {table_name}.lat >= :bbox_min_lat
AND {table_name}.lat <= :bbox_max_lat
)
)
"""
    )
    return stmt.where(inside_sql).params(
        bbox_min_lon=west,
        bbox_min_lat=south,
        bbox_max_lon=east,
        bbox_max_lat=north,
    )
def _geometry_matches(geometry: dict, requested: Optional[str]) -> bool:
    """Tell whether a GeoJSON geometry belongs to the requested family.

    *requested* may be empty (everything matches), ``point``, ``line``,
    ``polygon`` (each also accepting the ``Multi*`` variant) or
    ``nonpoint`` (anything except ``Point``/``MultiPoint``).  Any other
    value raises ``HTTPException`` 400.
    """
    if not requested:
        return True
    geometry_type = geometry.get("type")
    families = {
        "point": ("Point", "MultiPoint"),
        "line": ("LineString", "MultiLineString"),
        "polygon": ("Polygon", "MultiPolygon"),
    }
    accepted = families.get(requested)
    if accepted is not None:
        return geometry_type in accepted
    if requested == "nonpoint":
        return geometry_type not in {"Point", "MultiPoint"}
    raise HTTPException(status_code=400, detail="geometry must be point, line, polygon, or nonpoint")
def _clamp_limit(value: int) -> int:
    """Clamp *value* to at most ``MAP_FEATURE_LIMIT_MAX`` and at least 1."""
    capped = min(value, MAP_FEATURE_LIMIT_MAX)
    return max(1, capped)
def _persist_match_rule(db: Session, match: RouteMatch, rule_type: str) -> None:
    """Add a ``MatchRule`` derived from a route match.

    The selector describes the GTFS route (source, dataset, ids, ref, mode)
    and the match's ``gtfs_route_id``.  The action carries the match status.
    When the match points at an OSM feature that still exists, the selector
    also gets ``osm_feature_id`` and the action an ``osm`` description of
    that feature.  Nothing is added when the GTFS route cannot be loaded.
    The rule is only added to the session; committing is left to the caller.
    """
    route = db.get(GtfsRoute, match.gtfs_route_id)
    if route is None:
        return
    route_dataset = db.get(Dataset, route.dataset_id)

    feature = None
    if match.osm_feature_id:
        feature = db.get(OsmFeature, match.osm_feature_id)
    feature_dataset = None
    if feature is not None:
        feature_dataset = db.get(Dataset, feature.dataset_id)

    gtfs_selector = {
        "source_id": route_dataset.source_id if route_dataset is not None else None,
        "dataset_id": route.dataset_id,
        "route_id": route.route_id,
        "route_key": route.route_key,
        "ref": route.short_name,
        "mode": route.mode,
    }
    selector = {"gtfs": gtfs_selector, "gtfs_route_id": match.gtfs_route_id}
    action = {"status": match.status}
    if feature is not None:
        selector["osm_feature_id"] = match.osm_feature_id
        action["osm"] = {
            "source_id": feature_dataset.source_id if feature_dataset is not None else None,
            "dataset_id": feature.dataset_id,
            "osm_type": feature.osm_type,
            "osm_id": feature.osm_id,
            "route_key": feature.route_key,
            "ref": feature.ref,
            "mode": feature.mode,
        }

    def compact(payload: dict) -> str:
        return json.dumps(payload, separators=(",", ":"))

    rule = MatchRule(
        rule_type=rule_type,
        selector_json=compact(selector),
        action_json=compact(action),
        note=f"Created from match {match.id}",
    )
    db.add(rule)
def _status_from_candidate_score(score: float) -> str:
    """Map a candidate score to a match status label.

    ``>= 85`` is ``matched``, ``>= 65`` is ``probable``, ``>= 40`` is
    ``weak``; anything lower is ``below_threshold``.
    """
    bands = ((85, "matched"), (65, "probable"), (40, "weak"))
    for floor, label in bands:
        if score >= floor:
            return label
    return "below_threshold"
def _catalog_country(entry: SourceCatalogEntry | None) -> str | None:
    """Return the entry's country code as an upper-case two-letter code.

    Yields ``None`` when there is no entry, no country code, or the trimmed
    code is not exactly two alphabetic characters.
    """
    raw_code = entry.country_code if entry is not None else None
    if not raw_code:
        return None
    code = raw_code.strip()
    is_alpha2 = len(code) == 2 and code.isalpha()
    return code.upper() if is_alpha2 else None
def _catalog_notes(entry: SourceCatalogEntry | None) -> str | None:
    """Join the entry's non-empty note fields into one text, capped at 2000 chars.

    The fields are, in order, ``next_pipeline_action``, ``coverage_notes``
    and ``geometry_notes``, joined with single spaces.  Returns ``None``
    when there is no entry or every field is empty.
    """
    if entry is None:
        return None
    fragments = (
        entry.next_pipeline_action,
        entry.coverage_notes,
        entry.geometry_notes,
    )
    combined = " ".join(filter(None, fragments))
    return _truncate(combined, 2000)
def _truncate(value: str | None, length: int) -> str | None:
    """Return the first *length* characters of *value*, or ``None`` if it is empty."""
    return value[:length] if value else None
def _is_database_locked_error(exc: OperationalError) -> bool:
    """Tell whether *exc* (or its ``orig`` cause) reports a locked or busy database.

    The check is a case-insensitive substring search over the string forms
    of the exception and of its ``orig`` attribute, when present.
    """
    message = f"{str(exc).lower()} {str(getattr(exc, 'orig', '')).lower()}"
    markers = (
        "database is locked",
        "database table is locked",
        "database is busy",
    )
    return any(marker in message for marker in markers)
def _is_statement_timeout_error(exc: OperationalError) -> bool:
    """Tell whether *exc* (or its ``orig`` cause) reports a statement timeout.

    The check is a case-insensitive substring search over the string forms
    of the exception and of its ``orig`` attribute, when present.
    """
    message = f"{str(exc).lower()} {str(getattr(exc, 'orig', '')).lower()}"
    markers = (
        "statement timeout",
        "canceling statement due to statement timeout",
    )
    return any(marker in message for marker in markers)
def _gtfs_route_summary(route: GtfsRoute) -> dict:
    """Summarise a GTFS route for API output.

    ``geometry.present`` tells whether the route has stored GeoJSON;
    ``geometry.bbox`` is ``[min_lon, min_lat, max_lon, max_lat]``.
    """
    bbox = [route.min_lon, route.min_lat, route.max_lon, route.max_lat]
    geometry = {"present": bool(route.geometry_geojson), "bbox": bbox}
    summary = {
        "id": route.id,
        "dataset_id": route.dataset_id,
        "route_id": route.route_id,
        "ref": route.short_name,
        "name": route.long_name,
        "mode": route.mode,
        "operator": route.operator_name,
    }
    summary["geometry"] = geometry
    return summary
def _osm_route_summary(feature: OsmFeature) -> dict:
    """Summarise an OSM route feature for API output.

    ``id`` is the public id from ``osm_feature_public_id``; the database row
    id is exposed as ``row_id``.  ``geometry.present`` tells whether the
    feature has stored GeoJSON; ``geometry.bbox`` is
    ``[min_lon, min_lat, max_lon, max_lat]``.
    """
    summary = {
        "id": osm_feature_public_id(feature),
        "row_id": feature.id,
        "dataset_id": feature.dataset_id,
        "osm_type": feature.osm_type,
        "osm_id": feature.osm_id,
        "ref": feature.ref,
        "name": feature.name,
        "mode": feature.mode,
        "operator": feature.operator,
        "network": feature.network,
    }
    bbox = [feature.min_lon, feature.min_lat, feature.max_lon, feature.max_lat]
    summary["geometry"] = {"present": bool(feature.geometry_geojson), "bbox": bbox}
    return summary
def _match_candidate_preview(
    route: GtfsRoute,
    match: RouteMatch,
    candidate_rows: list[tuple[OsmFeature, float, dict[str, object]]],
) -> dict:
    """Build a feature collection previewing a GTFS route and its OSM candidates.

    The GTFS route comes first (role ``gtfs_route``), followed by one
    feature per candidate (role ``candidate``) carrying its 1-based rank,
    rounded score, score-derived status and whether it is the match's
    current OSM feature.  Items for which no feature can be built are
    skipped.
    """
    route_feature = gtfs_route_feature(
        route,
        {
            "preview_role": "gtfs_route",
            "match_id": match.id,
            "match_status": match.status,
            "label": route.short_name or route.route_id,
        },
    )
    features: list[dict] = [] if route_feature is None else [route_feature]

    for rank, (feature, score, _reasons) in enumerate(candidate_rows, start=1):
        properties = {
            "preview_role": "candidate",
            "match_id": match.id,
            "candidate_rank": rank,
            "candidate_score": round(float(score), 2),
            "candidate_status": _status_from_candidate_score(score),
            "current_match": feature.id == match.osm_feature_id,
            "label": feature.ref or feature.name or feature.osm_id,
        }
        rendered = osm_feature_feature(feature, properties)
        if rendered is None:
            continue
        features.append(rendered)

    return feature_collection(features)