from __future__ import annotations import json import csv from pathlib import Path from typing import Optional import typer from sqlalchemy import func, select, text from app.config import settings from app.data_management import dataset_sidecar_paths, prune_inactive_datasets from app.db import engine, init_db, reset_db, session_scope from app.db_lock import database_write_lock from app.feed_discovery import build_gtfs_discovery_manifests, default_generated_dir from app.models import ( Dataset, GtfsRoute, GtfsShape, GtfsStop, RouteMatch, RoutePattern, Source, SourceCatalogEntry, ) from app.pipeline.matcher import run_route_matching from app.pipeline.osm_labeling import relabel_osm_features from app.pipeline.osm_pbf import run_osm_pbf_source_staged from app.pipeline.run import run_source from app.pipeline.gtfs import backfill_gtfs_shapes from app.pipeline.route_layer import rebuild_route_layer from app.pipeline.sample_data import load_sample_project from app.osm_storage import osm_feature_count from app.jobs import run_worker_loop from app.jobs import create_route_layer_rebuild_job, create_route_matching_job, create_source_import_job from app.source_catalog import ( default_ingestable_sources_path, default_source_catalog_path, import_ingestable_sources, import_source_catalog, source_catalog_summary, ) cli = typer.Typer(help="Mobility Workbench pipeline CLI") @cli.command("init-db") def init_db_command() -> None: with _write_lock("init-db"): init_db() typer.echo("Database initialized") @cli.command("reset-db") def reset_db_command() -> None: with _write_lock("reset-db"): reset_db() typer.echo("Database reset") @cli.command("load-sample") def load_sample_command() -> None: with _write_lock("load-sample"): init_db() with session_scope() as session: result = load_sample_project(session) typer.echo(json.dumps(result, indent=2)) @cli.command("add-source") def add_source_command( name: str = typer.Option(..., help="Source name"), kind: str = typer.Option(..., help="gtfs, osm_geojson, osm_pbf, or osm_diff"), url: str = typer.Option(..., help="HTTP URL or local path"), country: Optional[str] = typer.Option(None), license: Optional[str] = typer.Option(None), priority: Optional[str] = typer.Option(None), mode_scope: Optional[str] = typer.Option(None), source_basis: Optional[str] = typer.Option(None), notes: Optional[str] = typer.Option(None), ) -> None: with _write_lock("add-source"): init_db() if kind not in {"gtfs", "osm_geojson", "osm_pbf", "osm_diff"}: raise typer.BadParameter("kind must be gtfs, osm_geojson, osm_pbf, or osm_diff") with session_scope() as session: source = Source( name=name, kind=kind, url=url, country=country, license=license, priority=priority, mode_scope=mode_scope, source_basis=source_basis, notes=notes, ) session.add(source) session.flush() typer.echo(json.dumps({"id": source.id, "name": source.name}, indent=2)) @cli.command("run-source") def run_source_command(source_id: int) -> None: init_db() with session_scope() as session: source = session.get(Source, source_id) if source is None: raise typer.BadParameter(f"source not found: {source_id}") source_kind = source.kind if source_kind == "osm_pbf": dataset = run_osm_pbf_source_staged(source_id) typer.echo(json.dumps({"source_id": source_id, "dataset_id": dataset.id, "status": dataset.status, "import_mode": "staged_short_lock"}, indent=2)) return with _write_lock("run-source"): with session_scope() as session: source = session.get(Source, source_id) if source is None: raise typer.BadParameter(f"source not found: {source_id}") dataset = run_source(session, source) typer.echo(json.dumps({"source_id": source.id, "dataset_id": dataset.id, "status": dataset.status}, indent=2)) @cli.command("run-match") def run_match_command() -> None: with _write_lock("run-match"): init_db() with session_scope() as session: result = run_route_matching(session) typer.echo(json.dumps(result, indent=2)) @cli.command("build-route-layer") def build_route_layer_command() -> None: with _write_lock("build-route-layer"): init_db() with session_scope() as session: result = rebuild_route_layer(session) typer.echo(json.dumps(result, indent=2)) @cli.command("relabel-osm-features") def relabel_osm_features_command( dataset_id: Optional[int] = typer.Option(None, help="Only relabel one OSM dataset"), force: bool = typer.Option(False, help="Run even when the recorded dependency signature is current"), chunk_size: int = typer.Option(5000, help="Rows per relabel batch"), rebuild_indexes: bool = typer.Option(True, help="Drop/rebuild affected route-scope indexes around large relabel writes"), build_route_layer: bool = typer.Option(True, help="Rebuild the route layer after relabeling"), ) -> None: with _write_lock("relabel-osm-features"): init_db() with session_scope() as session: result = relabel_osm_features( session, dataset_id=dataset_id, force=force, chunk_size=chunk_size, rebuild_indexes=rebuild_indexes, ) if build_route_layer and (result["changed"] or force): result["route_layer_result"] = rebuild_route_layer(session) typer.echo(json.dumps(result, indent=2)) @cli.command("backfill-gtfs-shapes") def backfill_gtfs_shapes_command(dataset_id: Optional[int] = typer.Option(None, help="Only backfill one GTFS dataset")) -> None: with _write_lock("backfill-gtfs-shapes"): init_db() with session_scope() as session: result = backfill_gtfs_shapes(session, dataset_id=dataset_id) typer.echo(json.dumps(result, indent=2)) @cli.command("stats") def stats_command() -> None: init_db() with session_scope() as session: active_dataset_ids = [row[0] for row in session.execute(select(Dataset.id).where(Dataset.is_active.is_(True))).all()] stats = { "sources": session.scalar(select(func.count()).select_from(Source)), "source_catalog_entries": session.scalar(select(func.count()).select_from(SourceCatalogEntry)) or 0, "active_datasets": len(active_dataset_ids), "gtfs_routes": session.scalar(select(func.count()).select_from(GtfsRoute).where(GtfsRoute.dataset_id.in_(active_dataset_ids))) if active_dataset_ids else 0, "gtfs_stops": session.scalar(select(func.count()).select_from(GtfsStop).where(GtfsStop.dataset_id.in_(active_dataset_ids))) if active_dataset_ids else 0, "gtfs_shapes": session.scalar(select(func.count()).select_from(GtfsShape).where(GtfsShape.dataset_id.in_(active_dataset_ids))) if active_dataset_ids else 0, "route_patterns": session.scalar(select(func.count()).select_from(RoutePattern)) or 0, "osm_routes": sum(osm_feature_count(session, dataset_id, kind="route") for dataset_id in active_dataset_ids), "matches": {status: count for status, count in session.execute(select(RouteMatch.status, func.count()).group_by(RouteMatch.status)).all()}, } typer.echo(json.dumps(stats, indent=2)) @cli.command("import-source-catalog") def import_source_catalog_command( csv_path: Path = typer.Option(default_source_catalog_path(), "--csv", help="Source catalog CSV path"), no_update: bool = typer.Option(False, help="Skip rows that already exist"), ) -> None: with _write_lock("import-source-catalog"): init_db() with session_scope() as session: result = import_source_catalog(session, csv_path, update_existing=not no_update) result["summary"] = source_catalog_summary(session) typer.echo(json.dumps(result, indent=2)) @cli.command("import-ingestable-sources") def import_ingestable_sources_command( csv_path: Path = typer.Option(default_ingestable_sources_path(), "--csv", help="Ingestable source seed CSV path"), no_update: bool = typer.Option(False, help="Skip sources that already exist"), ) -> None: with _write_lock("import-ingestable-sources"): init_db() with session_scope() as session: result = import_ingestable_sources(session, csv_path, update_existing=not no_update) result["summary"] = source_catalog_summary(session) typer.echo(json.dumps(result, indent=2)) @cli.command("discover-gtfs-sources") def discover_gtfs_sources_command( output_dir: Path = typer.Option(default_generated_dir(), "--output-dir", help="Directory for generated discovery CSVs"), countries: str = typer.Option( ",".join(["DE", "AT", "CH", "NL", "DK", "FR", "BE", "LU", "NO", "SE", "FI", "IE", "GB"]), "--countries", help="Comma-separated country codes, or ALL for every country exposed by the upstream catalogs", ), no_mobility_database: bool = typer.Option(False, help="Skip Mobility Database feeds_v2.csv"), no_acceptance_test_list: bool = typer.Option(False, help="Skip MobilityData validator acceptance-test feed list"), no_ptna: bool = typer.Option(False, help="Skip PTNA GTFS analysis pages"), max_ptna_details: int = typer.Option(80, help="Maximum PTNA detail pages to fetch for license/crosswalk metadata"), test_limit: int = typer.Option(24, help="Rows to write to the focused test-run ingestable CSV"), check_urls: bool = typer.Option(False, help="Run HEAD/range checks for ingestable feed URLs"), ) -> None: result = build_gtfs_discovery_manifests( output_dir=output_dir, countries=[part.strip() for part in countries.split(",") if part.strip()], include_mobility_database=not no_mobility_database, include_acceptance_test_list=not no_acceptance_test_list, include_ptna=not no_ptna, max_ptna_details=max_ptna_details, test_limit=test_limit, check_urls=check_urls, ) typer.echo(json.dumps(result, indent=2, ensure_ascii=False)) @cli.command("queue-source-imports-from-csv") def queue_source_imports_from_csv_command( csv_path: Path = typer.Option(default_ingestable_sources_path(), "--csv", help="Ingestable source CSV path"), no_update: bool = typer.Option(False, help="Skip sources that already exist instead of updating them"), run_match_at_end: bool = typer.Option(True, help="Queue one route-matching job after all source imports"), build_route_layer_at_end: bool = typer.Option(True, help="Queue one route-layer rebuild after route matching"), priority: int = typer.Option(0, help="Priority for queued source import jobs"), ) -> None: with _write_lock("queue-source-imports-from-csv"): init_db() with session_scope() as session: csv_path = csv_path if csv_path.is_absolute() else Path.cwd() / csv_path imported = import_ingestable_sources(session, csv_path, update_existing=not no_update) source_urls = _source_urls_from_ingestable_csv(csv_path) sources = session.scalars( select(Source) .where(Source.kind == "gtfs", Source.url.in_(source_urls)) .order_by(Source.id) ).all() jobs = [ create_source_import_job( session, source, run_match=False, build_route_layer=False, priority=priority, ) for source in sources ] route_match_job = create_route_matching_job(session, priority=priority) if run_match_at_end else None route_layer_job = create_route_layer_rebuild_job(session, priority=priority) if build_route_layer_at_end else None typer.echo( json.dumps( { "csv": str(csv_path), "imported": imported, "sources": [{"id": source.id, "name": source.name} for source in sources], "source_import_jobs": [job.id for job in jobs], "route_match_job": None if route_match_job is None else route_match_job.id, "route_layer_job": None if route_layer_job is None else route_layer_job.id, }, indent=2, ensure_ascii=False, ) ) @cli.command("prune-cache") def prune_cache_command(dry_run: bool = typer.Option(False, help="Report files without deleting them")) -> None: with _write_lock("prune-cache"): init_db() with session_scope() as session: referenced = { Path(path).resolve() for path in session.scalars(select(Dataset.local_path)).all() if path } for dataset in session.scalars(select(Dataset)).all(): referenced.update(path.resolve() for path in dataset_sidecar_paths(dataset)) roots = [settings.data_dir / "sources", settings.data_dir / "derived", settings.data_dir / "sidecars", settings.data_dir / "staging"] candidates = [ path for root in roots if root.exists() for path in root.rglob("*") if path.is_file() and path.resolve() not in referenced ] total_bytes = sum(path.stat().st_size for path in candidates) if not dry_run: for path in candidates: path.unlink() for root in roots: _remove_empty_dirs(root) typer.echo( json.dumps( { "dry_run": dry_run, "files": len(candidates), "bytes": total_bytes, "deleted": 0 if dry_run else len(candidates), }, indent=2, ) ) @cli.command("prune-inactive-datasets") def prune_inactive_datasets_command( dry_run: bool = typer.Option(False, help="Report inactive normalized datasets without deleting them"), ) -> None: with _write_lock("prune-inactive-datasets"): init_db() with session_scope() as session: result = prune_inactive_datasets(session, dry_run=dry_run) typer.echo(json.dumps(result, indent=2)) @cli.command("vacuum-db") def vacuum_db_command() -> None: with _write_lock("vacuum-db"): init_db() with engine.connect().execution_options(isolation_level="AUTOCOMMIT") as connection: connection.execute(text("VACUUM")) connection.execute(text("PRAGMA wal_checkpoint(TRUNCATE)")) typer.echo("Database vacuumed") @cli.command("worker") def worker_command( once: bool = typer.Option(False, help="Process at most one queued job and exit"), max_jobs: Optional[int] = typer.Option(None, help="Process at most this many jobs and exit"), poll_interval: float = typer.Option(2.0, help="Seconds to wait between queue polls"), worker_id: Optional[str] = typer.Option(None, help="Stable worker identifier"), ) -> None: result = run_worker_loop(worker_id=worker_id, poll_interval=poll_interval, max_jobs=max_jobs, once=once) typer.echo(json.dumps(result, indent=2)) def _remove_empty_dirs(root: Path) -> None: if not root.exists(): return for path in sorted((p for p in root.rglob("*") if p.is_dir()), key=lambda p: len(p.parts), reverse=True): try: path.rmdir() except OSError: pass def _write_lock(operation: str): return database_write_lock(f"cli:{operation}", timeout=settings.database_write_lock_cli_timeout_seconds) def _source_urls_from_ingestable_csv(path: Path) -> list[str]: urls: list[str] = [] with path.open("r", encoding="utf-8-sig", newline="") as handle: for row in csv.DictReader(handle): if (row.get("kind") or "").strip().lower() != "gtfs": continue url = (row.get("url") or "").strip() if url and url not in urls: urls.append(url) return urls if __name__ == "__main__": cli()