Alpha stage commit
This commit is contained in:
394
app/cli.py
Normal file
394
app/cli.py
Normal file
@@ -0,0 +1,394 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import csv
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import typer
|
||||
from sqlalchemy import func, select, text
|
||||
|
||||
from app.config import settings
|
||||
from app.data_management import dataset_sidecar_paths, prune_inactive_datasets
|
||||
from app.db import engine, init_db, reset_db, session_scope
|
||||
from app.db_lock import database_write_lock
|
||||
from app.feed_discovery import build_gtfs_discovery_manifests, default_generated_dir
|
||||
from app.models import (
|
||||
Dataset,
|
||||
GtfsRoute,
|
||||
GtfsShape,
|
||||
GtfsStop,
|
||||
RouteMatch,
|
||||
RoutePattern,
|
||||
Source,
|
||||
SourceCatalogEntry,
|
||||
)
|
||||
from app.pipeline.matcher import run_route_matching
|
||||
from app.pipeline.osm_labeling import relabel_osm_features
|
||||
from app.pipeline.osm_pbf import run_osm_pbf_source_staged
|
||||
from app.pipeline.run import run_source
|
||||
from app.pipeline.gtfs import backfill_gtfs_shapes
|
||||
from app.pipeline.route_layer import rebuild_route_layer
|
||||
from app.pipeline.sample_data import load_sample_project
|
||||
from app.osm_storage import osm_feature_count
|
||||
from app.jobs import run_worker_loop
|
||||
from app.jobs import create_route_layer_rebuild_job, create_route_matching_job, create_source_import_job
|
||||
from app.source_catalog import (
|
||||
default_ingestable_sources_path,
|
||||
default_source_catalog_path,
|
||||
import_ingestable_sources,
|
||||
import_source_catalog,
|
||||
source_catalog_summary,
|
||||
)
|
||||
|
||||
cli = typer.Typer(help="Mobility Workbench pipeline CLI")
|
||||
|
||||
|
||||
@cli.command("init-db")
|
||||
def init_db_command() -> None:
|
||||
with _write_lock("init-db"):
|
||||
init_db()
|
||||
typer.echo("Database initialized")
|
||||
|
||||
|
||||
@cli.command("reset-db")
|
||||
def reset_db_command() -> None:
|
||||
with _write_lock("reset-db"):
|
||||
reset_db()
|
||||
typer.echo("Database reset")
|
||||
|
||||
|
||||
@cli.command("load-sample")
|
||||
def load_sample_command() -> None:
|
||||
with _write_lock("load-sample"):
|
||||
init_db()
|
||||
with session_scope() as session:
|
||||
result = load_sample_project(session)
|
||||
typer.echo(json.dumps(result, indent=2))
|
||||
|
||||
|
||||
@cli.command("add-source")
|
||||
def add_source_command(
|
||||
name: str = typer.Option(..., help="Source name"),
|
||||
kind: str = typer.Option(..., help="gtfs, osm_geojson, osm_pbf, or osm_diff"),
|
||||
url: str = typer.Option(..., help="HTTP URL or local path"),
|
||||
country: Optional[str] = typer.Option(None),
|
||||
license: Optional[str] = typer.Option(None),
|
||||
priority: Optional[str] = typer.Option(None),
|
||||
mode_scope: Optional[str] = typer.Option(None),
|
||||
source_basis: Optional[str] = typer.Option(None),
|
||||
notes: Optional[str] = typer.Option(None),
|
||||
) -> None:
|
||||
with _write_lock("add-source"):
|
||||
init_db()
|
||||
if kind not in {"gtfs", "osm_geojson", "osm_pbf", "osm_diff"}:
|
||||
raise typer.BadParameter("kind must be gtfs, osm_geojson, osm_pbf, or osm_diff")
|
||||
with session_scope() as session:
|
||||
source = Source(
|
||||
name=name,
|
||||
kind=kind,
|
||||
url=url,
|
||||
country=country,
|
||||
license=license,
|
||||
priority=priority,
|
||||
mode_scope=mode_scope,
|
||||
source_basis=source_basis,
|
||||
notes=notes,
|
||||
)
|
||||
session.add(source)
|
||||
session.flush()
|
||||
typer.echo(json.dumps({"id": source.id, "name": source.name}, indent=2))
|
||||
|
||||
|
||||
@cli.command("run-source")
|
||||
def run_source_command(source_id: int) -> None:
|
||||
init_db()
|
||||
with session_scope() as session:
|
||||
source = session.get(Source, source_id)
|
||||
if source is None:
|
||||
raise typer.BadParameter(f"source not found: {source_id}")
|
||||
source_kind = source.kind
|
||||
if source_kind == "osm_pbf":
|
||||
dataset = run_osm_pbf_source_staged(source_id)
|
||||
typer.echo(json.dumps({"source_id": source_id, "dataset_id": dataset.id, "status": dataset.status, "import_mode": "staged_short_lock"}, indent=2))
|
||||
return
|
||||
with _write_lock("run-source"):
|
||||
with session_scope() as session:
|
||||
source = session.get(Source, source_id)
|
||||
if source is None:
|
||||
raise typer.BadParameter(f"source not found: {source_id}")
|
||||
dataset = run_source(session, source)
|
||||
typer.echo(json.dumps({"source_id": source.id, "dataset_id": dataset.id, "status": dataset.status}, indent=2))
|
||||
|
||||
|
||||
@cli.command("run-match")
|
||||
def run_match_command() -> None:
|
||||
with _write_lock("run-match"):
|
||||
init_db()
|
||||
with session_scope() as session:
|
||||
result = run_route_matching(session)
|
||||
typer.echo(json.dumps(result, indent=2))
|
||||
|
||||
|
||||
@cli.command("build-route-layer")
|
||||
def build_route_layer_command() -> None:
|
||||
with _write_lock("build-route-layer"):
|
||||
init_db()
|
||||
with session_scope() as session:
|
||||
result = rebuild_route_layer(session)
|
||||
typer.echo(json.dumps(result, indent=2))
|
||||
|
||||
|
||||
@cli.command("relabel-osm-features")
|
||||
def relabel_osm_features_command(
|
||||
dataset_id: Optional[int] = typer.Option(None, help="Only relabel one OSM dataset"),
|
||||
force: bool = typer.Option(False, help="Run even when the recorded dependency signature is current"),
|
||||
chunk_size: int = typer.Option(5000, help="Rows per relabel batch"),
|
||||
rebuild_indexes: bool = typer.Option(True, help="Drop/rebuild affected route-scope indexes around large relabel writes"),
|
||||
build_route_layer: bool = typer.Option(True, help="Rebuild the route layer after relabeling"),
|
||||
) -> None:
|
||||
with _write_lock("relabel-osm-features"):
|
||||
init_db()
|
||||
with session_scope() as session:
|
||||
result = relabel_osm_features(
|
||||
session,
|
||||
dataset_id=dataset_id,
|
||||
force=force,
|
||||
chunk_size=chunk_size,
|
||||
rebuild_indexes=rebuild_indexes,
|
||||
)
|
||||
if build_route_layer and (result["changed"] or force):
|
||||
result["route_layer_result"] = rebuild_route_layer(session)
|
||||
typer.echo(json.dumps(result, indent=2))
|
||||
|
||||
|
||||
@cli.command("backfill-gtfs-shapes")
|
||||
def backfill_gtfs_shapes_command(dataset_id: Optional[int] = typer.Option(None, help="Only backfill one GTFS dataset")) -> None:
|
||||
with _write_lock("backfill-gtfs-shapes"):
|
||||
init_db()
|
||||
with session_scope() as session:
|
||||
result = backfill_gtfs_shapes(session, dataset_id=dataset_id)
|
||||
typer.echo(json.dumps(result, indent=2))
|
||||
|
||||
|
||||
@cli.command("stats")
|
||||
def stats_command() -> None:
|
||||
init_db()
|
||||
with session_scope() as session:
|
||||
active_dataset_ids = [row[0] for row in session.execute(select(Dataset.id).where(Dataset.is_active.is_(True))).all()]
|
||||
stats = {
|
||||
"sources": session.scalar(select(func.count()).select_from(Source)),
|
||||
"source_catalog_entries": session.scalar(select(func.count()).select_from(SourceCatalogEntry)) or 0,
|
||||
"active_datasets": len(active_dataset_ids),
|
||||
"gtfs_routes": session.scalar(select(func.count()).select_from(GtfsRoute).where(GtfsRoute.dataset_id.in_(active_dataset_ids))) if active_dataset_ids else 0,
|
||||
"gtfs_stops": session.scalar(select(func.count()).select_from(GtfsStop).where(GtfsStop.dataset_id.in_(active_dataset_ids))) if active_dataset_ids else 0,
|
||||
"gtfs_shapes": session.scalar(select(func.count()).select_from(GtfsShape).where(GtfsShape.dataset_id.in_(active_dataset_ids))) if active_dataset_ids else 0,
|
||||
"route_patterns": session.scalar(select(func.count()).select_from(RoutePattern)) or 0,
|
||||
"osm_routes": sum(osm_feature_count(session, dataset_id, kind="route") for dataset_id in active_dataset_ids),
|
||||
"matches": {status: count for status, count in session.execute(select(RouteMatch.status, func.count()).group_by(RouteMatch.status)).all()},
|
||||
}
|
||||
typer.echo(json.dumps(stats, indent=2))
|
||||
|
||||
|
||||
@cli.command("import-source-catalog")
|
||||
def import_source_catalog_command(
|
||||
csv_path: Path = typer.Option(default_source_catalog_path(), "--csv", help="Source catalog CSV path"),
|
||||
no_update: bool = typer.Option(False, help="Skip rows that already exist"),
|
||||
) -> None:
|
||||
with _write_lock("import-source-catalog"):
|
||||
init_db()
|
||||
with session_scope() as session:
|
||||
result = import_source_catalog(session, csv_path, update_existing=not no_update)
|
||||
result["summary"] = source_catalog_summary(session)
|
||||
typer.echo(json.dumps(result, indent=2))
|
||||
|
||||
|
||||
@cli.command("import-ingestable-sources")
|
||||
def import_ingestable_sources_command(
|
||||
csv_path: Path = typer.Option(default_ingestable_sources_path(), "--csv", help="Ingestable source seed CSV path"),
|
||||
no_update: bool = typer.Option(False, help="Skip sources that already exist"),
|
||||
) -> None:
|
||||
with _write_lock("import-ingestable-sources"):
|
||||
init_db()
|
||||
with session_scope() as session:
|
||||
result = import_ingestable_sources(session, csv_path, update_existing=not no_update)
|
||||
result["summary"] = source_catalog_summary(session)
|
||||
typer.echo(json.dumps(result, indent=2))
|
||||
|
||||
|
||||
@cli.command("discover-gtfs-sources")
|
||||
def discover_gtfs_sources_command(
|
||||
output_dir: Path = typer.Option(default_generated_dir(), "--output-dir", help="Directory for generated discovery CSVs"),
|
||||
countries: str = typer.Option(
|
||||
",".join(["DE", "AT", "CH", "NL", "DK", "FR", "BE", "LU", "NO", "SE", "FI", "IE", "GB"]),
|
||||
"--countries",
|
||||
help="Comma-separated country codes, or ALL for every country exposed by the upstream catalogs",
|
||||
),
|
||||
no_mobility_database: bool = typer.Option(False, help="Skip Mobility Database feeds_v2.csv"),
|
||||
no_acceptance_test_list: bool = typer.Option(False, help="Skip MobilityData validator acceptance-test feed list"),
|
||||
no_ptna: bool = typer.Option(False, help="Skip PTNA GTFS analysis pages"),
|
||||
max_ptna_details: int = typer.Option(80, help="Maximum PTNA detail pages to fetch for license/crosswalk metadata"),
|
||||
test_limit: int = typer.Option(24, help="Rows to write to the focused test-run ingestable CSV"),
|
||||
check_urls: bool = typer.Option(False, help="Run HEAD/range checks for ingestable feed URLs"),
|
||||
) -> None:
|
||||
result = build_gtfs_discovery_manifests(
|
||||
output_dir=output_dir,
|
||||
countries=[part.strip() for part in countries.split(",") if part.strip()],
|
||||
include_mobility_database=not no_mobility_database,
|
||||
include_acceptance_test_list=not no_acceptance_test_list,
|
||||
include_ptna=not no_ptna,
|
||||
max_ptna_details=max_ptna_details,
|
||||
test_limit=test_limit,
|
||||
check_urls=check_urls,
|
||||
)
|
||||
typer.echo(json.dumps(result, indent=2, ensure_ascii=False))
|
||||
|
||||
|
||||
@cli.command("queue-source-imports-from-csv")
|
||||
def queue_source_imports_from_csv_command(
|
||||
csv_path: Path = typer.Option(default_ingestable_sources_path(), "--csv", help="Ingestable source CSV path"),
|
||||
no_update: bool = typer.Option(False, help="Skip sources that already exist instead of updating them"),
|
||||
run_match_at_end: bool = typer.Option(True, help="Queue one route-matching job after all source imports"),
|
||||
build_route_layer_at_end: bool = typer.Option(True, help="Queue one route-layer rebuild after route matching"),
|
||||
priority: int = typer.Option(0, help="Priority for queued source import jobs"),
|
||||
) -> None:
|
||||
with _write_lock("queue-source-imports-from-csv"):
|
||||
init_db()
|
||||
with session_scope() as session:
|
||||
csv_path = csv_path if csv_path.is_absolute() else Path.cwd() / csv_path
|
||||
imported = import_ingestable_sources(session, csv_path, update_existing=not no_update)
|
||||
source_urls = _source_urls_from_ingestable_csv(csv_path)
|
||||
sources = session.scalars(
|
||||
select(Source)
|
||||
.where(Source.kind == "gtfs", Source.url.in_(source_urls))
|
||||
.order_by(Source.id)
|
||||
).all()
|
||||
jobs = [
|
||||
create_source_import_job(
|
||||
session,
|
||||
source,
|
||||
run_match=False,
|
||||
build_route_layer=False,
|
||||
priority=priority,
|
||||
)
|
||||
for source in sources
|
||||
]
|
||||
route_match_job = create_route_matching_job(session, priority=priority) if run_match_at_end else None
|
||||
route_layer_job = create_route_layer_rebuild_job(session, priority=priority) if build_route_layer_at_end else None
|
||||
typer.echo(
|
||||
json.dumps(
|
||||
{
|
||||
"csv": str(csv_path),
|
||||
"imported": imported,
|
||||
"sources": [{"id": source.id, "name": source.name} for source in sources],
|
||||
"source_import_jobs": [job.id for job in jobs],
|
||||
"route_match_job": None if route_match_job is None else route_match_job.id,
|
||||
"route_layer_job": None if route_layer_job is None else route_layer_job.id,
|
||||
},
|
||||
indent=2,
|
||||
ensure_ascii=False,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@cli.command("prune-cache")
|
||||
def prune_cache_command(dry_run: bool = typer.Option(False, help="Report files without deleting them")) -> None:
|
||||
with _write_lock("prune-cache"):
|
||||
init_db()
|
||||
with session_scope() as session:
|
||||
referenced = {
|
||||
Path(path).resolve()
|
||||
for path in session.scalars(select(Dataset.local_path)).all()
|
||||
if path
|
||||
}
|
||||
for dataset in session.scalars(select(Dataset)).all():
|
||||
referenced.update(path.resolve() for path in dataset_sidecar_paths(dataset))
|
||||
|
||||
roots = [settings.data_dir / "sources", settings.data_dir / "derived", settings.data_dir / "sidecars", settings.data_dir / "staging"]
|
||||
candidates = [
|
||||
path
|
||||
for root in roots
|
||||
if root.exists()
|
||||
for path in root.rglob("*")
|
||||
if path.is_file() and path.resolve() not in referenced
|
||||
]
|
||||
total_bytes = sum(path.stat().st_size for path in candidates)
|
||||
if not dry_run:
|
||||
for path in candidates:
|
||||
path.unlink()
|
||||
for root in roots:
|
||||
_remove_empty_dirs(root)
|
||||
|
||||
typer.echo(
|
||||
json.dumps(
|
||||
{
|
||||
"dry_run": dry_run,
|
||||
"files": len(candidates),
|
||||
"bytes": total_bytes,
|
||||
"deleted": 0 if dry_run else len(candidates),
|
||||
},
|
||||
indent=2,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@cli.command("prune-inactive-datasets")
|
||||
def prune_inactive_datasets_command(
|
||||
dry_run: bool = typer.Option(False, help="Report inactive normalized datasets without deleting them"),
|
||||
) -> None:
|
||||
with _write_lock("prune-inactive-datasets"):
|
||||
init_db()
|
||||
with session_scope() as session:
|
||||
result = prune_inactive_datasets(session, dry_run=dry_run)
|
||||
typer.echo(json.dumps(result, indent=2))
|
||||
|
||||
|
||||
@cli.command("vacuum-db")
|
||||
def vacuum_db_command() -> None:
|
||||
with _write_lock("vacuum-db"):
|
||||
init_db()
|
||||
with engine.connect().execution_options(isolation_level="AUTOCOMMIT") as connection:
|
||||
connection.execute(text("VACUUM"))
|
||||
connection.execute(text("PRAGMA wal_checkpoint(TRUNCATE)"))
|
||||
typer.echo("Database vacuumed")
|
||||
|
||||
|
||||
@cli.command("worker")
|
||||
def worker_command(
|
||||
once: bool = typer.Option(False, help="Process at most one queued job and exit"),
|
||||
max_jobs: Optional[int] = typer.Option(None, help="Process at most this many jobs and exit"),
|
||||
poll_interval: float = typer.Option(2.0, help="Seconds to wait between queue polls"),
|
||||
worker_id: Optional[str] = typer.Option(None, help="Stable worker identifier"),
|
||||
) -> None:
|
||||
result = run_worker_loop(worker_id=worker_id, poll_interval=poll_interval, max_jobs=max_jobs, once=once)
|
||||
typer.echo(json.dumps(result, indent=2))
|
||||
|
||||
|
||||
def _remove_empty_dirs(root: Path) -> None:
|
||||
if not root.exists():
|
||||
return
|
||||
for path in sorted((p for p in root.rglob("*") if p.is_dir()), key=lambda p: len(p.parts), reverse=True):
|
||||
try:
|
||||
path.rmdir()
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
|
||||
def _write_lock(operation: str):
|
||||
return database_write_lock(f"cli:{operation}", timeout=settings.database_write_lock_cli_timeout_seconds)
|
||||
|
||||
|
||||
def _source_urls_from_ingestable_csv(path: Path) -> list[str]:
|
||||
urls: list[str] = []
|
||||
with path.open("r", encoding="utf-8-sig", newline="") as handle:
|
||||
for row in csv.DictReader(handle):
|
||||
if (row.get("kind") or "").strip().lower() != "gtfs":
|
||||
continue
|
||||
url = (row.get("url") or "").strip()
|
||||
if url and url not in urls:
|
||||
urls.append(url)
|
||||
return urls
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli()
|
||||
Reference in New Issue
Block a user