from __future__ import annotations import json import sqlite3 from contextlib import contextmanager from pathlib import Path from typing import Iterator, Sequence from sqlalchemy import func, select from sqlalchemy.orm import Session from app.config import settings from app.models import Dataset, GtfsStopTime GTFS_STORAGE_METADATA_KEY = "gtfs_storage" GTFS_STORAGE_MAIN = "main" GTFS_STORAGE_SIDECAR_STOP_TIMES = "sidecar_stop_times" GTFS_STOP_TIME_COLUMNS = [ "trip_id", "stop_id", "stop_sequence", "arrival_time", "departure_time", "arrival_seconds", "departure_seconds", ] SQLITE_IN_CHUNK_SIZE = 800 def effective_gtfs_timetable_storage(value: str | None = None) -> str: configured = str(value or settings.gtfs_timetable_storage or GTFS_STORAGE_SIDECAR_STOP_TIMES).strip().lower() if configured in {GTFS_STORAGE_MAIN, "main_db", "main_sqlite", "postgres", "postgresql"}: return GTFS_STORAGE_MAIN if settings.is_postgresql_database and not settings.postgres_use_sidecars: return GTFS_STORAGE_MAIN return GTFS_STORAGE_SIDECAR_STOP_TIMES class MissingGtfsSidecar(FileNotFoundError): def __init__(self, dataset_id: int | None, path: Path | None) -> None: self.dataset_id = dataset_id self.path = path if path is None: message = f"dataset #{dataset_id} does not reference a GTFS sidecar" else: message = f"GTFS sidecar does not exist: {path}" super().__init__(message) def dataset_metadata(dataset: Dataset) -> dict: try: metadata = json.loads(dataset.metadata_json or "{}") except json.JSONDecodeError: return {} return metadata if isinstance(metadata, dict) else {} def stop_times_are_sidecar(dataset: Dataset | None) -> bool: if dataset is None: return False storage = dataset_metadata(dataset).get(GTFS_STORAGE_METADATA_KEY) if not isinstance(storage, dict): return False tables = storage.get("tables") if isinstance(tables, dict): return tables.get("gtfs_stop_times") == "sidecar" return storage.get("mode") == GTFS_STORAGE_SIDECAR_STOP_TIMES def sidecar_path(dataset: Dataset | None) -> Path | None: if dataset is None: return None storage = dataset_metadata(dataset).get(GTFS_STORAGE_METADATA_KEY) if not isinstance(storage, dict): return None value = storage.get("sidecar_path") if not value: return None return Path(str(value)) def dataset_sidecar_paths(dataset: Dataset) -> list[Path]: path = sidecar_path(dataset) return [] if path is None else [path] def missing_sidecar_paths(dataset: Dataset | None) -> list[str]: if not stop_times_are_sidecar(dataset): return [] path = sidecar_path(dataset) if path is None: dataset_id = "unknown" if dataset is None else str(dataset.id) return [f"dataset #{dataset_id} has no configured GTFS sidecar path"] return [] if path.exists() else [str(path)] def uses_sidecar_stop_times(session: Session, dataset_id: int) -> bool: return stop_times_are_sidecar(session.get(Dataset, dataset_id)) @contextmanager def sidecar_connection(dataset: Dataset) -> Iterator[sqlite3.Connection]: path = sidecar_path(dataset) if path is None: raise MissingGtfsSidecar(dataset.id, None) if not path.exists(): raise MissingGtfsSidecar(dataset.id, path) connection = sqlite3.connect(f"file:{path}?mode=ro", uri=True) connection.row_factory = sqlite3.Row try: yield connection finally: connection.close() def stop_time_count(session: Session, dataset_id: int) -> int: dataset = session.get(Dataset, dataset_id) if stop_times_are_sidecar(dataset): try: with sidecar_connection(dataset) as connection: return int(connection.execute("SELECT COUNT(*) FROM gtfs_stop_times").fetchone()[0] or 0) except MissingGtfsSidecar: return 0 return session.scalar(select(func.count()).select_from(GtfsStopTime).where(GtfsStopTime.dataset_id == dataset_id)) or 0 def stop_time_counts_by_dataset(session: Session, dataset_ids: Sequence[int]) -> dict[int, int]: counts: dict[int, int] = {} for dataset_id in dataset_ids: counts[int(dataset_id)] = stop_time_count(session, int(dataset_id)) return counts def scheduled_stop_ids(session: Session, dataset_id: int, stop_ids: Sequence[str]) -> tuple[str, ...]: if not stop_ids: return () dataset = session.get(Dataset, dataset_id) requested = [str(stop_id) for stop_id in stop_ids] found: set[str] = set() if stop_times_are_sidecar(dataset): try: with sidecar_connection(dataset) as connection: for chunk in _chunks(requested, SQLITE_IN_CHUNK_SIZE): placeholders = ", ".join(["?"] * len(chunk)) rows = connection.execute( f""" SELECT stop_id FROM gtfs_stop_times WHERE stop_id IN ({placeholders}) GROUP BY stop_id """, list(chunk), ).fetchall() found.update(str(row["stop_id"]) for row in rows) except MissingGtfsSidecar: return () else: for chunk in _chunks(requested, SQLITE_IN_CHUNK_SIZE): rows = session.scalars( select(GtfsStopTime.stop_id) .where(GtfsStopTime.dataset_id == dataset_id, GtfsStopTime.stop_id.in_(chunk)) .group_by(GtfsStopTime.stop_id) ).all() found.update(str(row) for row in rows) return tuple(sorted(found)) def all_scheduled_stop_ids(session: Session, dataset_id: int) -> set[str]: dataset = session.get(Dataset, dataset_id) if stop_times_are_sidecar(dataset): try: with sidecar_connection(dataset) as connection: return { str(row["stop_id"]) for row in connection.execute("SELECT stop_id FROM gtfs_stop_times GROUP BY stop_id").fetchall() } except MissingGtfsSidecar: return set() return { str(row) for row in session.scalars( select(GtfsStopTime.stop_id) .where(GtfsStopTime.dataset_id == dataset_id) .group_by(GtfsStopTime.stop_id) ).all() } def scheduled_stop_ids_by_dataset(session: Session, dataset_ids: Sequence[int]) -> dict[int, set[str]]: return {int(dataset_id): all_scheduled_stop_ids(session, int(dataset_id)) for dataset_id in dataset_ids} def has_scheduled_stop(session: Session, dataset_id: int, stop_id: str) -> bool: return bool(scheduled_stop_ids(session, dataset_id, [stop_id])) def stop_times_by_trip( session: Session, dataset_id: int, trip_ids: Sequence[str], ) -> dict[str, list[GtfsStopTime]]: if not trip_ids: return {} grouped: dict[str, list[GtfsStopTime]] = {} dataset = session.get(Dataset, dataset_id) requested = [str(trip_id) for trip_id in trip_ids] if stop_times_are_sidecar(dataset): column_sql = ", ".join(GTFS_STOP_TIME_COLUMNS) try: with sidecar_connection(dataset) as connection: for chunk in _chunks(requested, SQLITE_IN_CHUNK_SIZE): placeholders = ", ".join(["?"] * len(chunk)) rows = connection.execute( f""" SELECT {column_sql} FROM gtfs_stop_times WHERE trip_id IN ({placeholders}) ORDER BY trip_id, stop_sequence """, list(chunk), ).fetchall() for row in rows: stop_time = stop_time_from_row(dataset_id, row) grouped.setdefault(stop_time.trip_id, []).append(stop_time) except MissingGtfsSidecar: return {} return grouped for chunk in _chunks(requested, SQLITE_IN_CHUNK_SIZE): rows = session.scalars( select(GtfsStopTime) .where(GtfsStopTime.dataset_id == dataset_id, GtfsStopTime.trip_id.in_(chunk)) .order_by(GtfsStopTime.trip_id, GtfsStopTime.stop_sequence) ).all() for row in rows: grouped.setdefault(row.trip_id, []).append(row) return grouped def stop_times_for_trip_range( session: Session, dataset_id: int, trip_id: str, start_sequence: int, end_sequence: int, ) -> list[GtfsStopTime]: dataset = session.get(Dataset, dataset_id) if stop_times_are_sidecar(dataset): column_sql = ", ".join(GTFS_STOP_TIME_COLUMNS) try: with sidecar_connection(dataset) as connection: rows = connection.execute( f""" SELECT {column_sql} FROM gtfs_stop_times WHERE trip_id = ? AND stop_sequence >= ? AND stop_sequence <= ? ORDER BY stop_sequence """, (trip_id, int(start_sequence), int(end_sequence)), ).fetchall() return [stop_time_from_row(dataset_id, row) for row in rows] except MissingGtfsSidecar: return [] return list( session.scalars( select(GtfsStopTime) .where( GtfsStopTime.dataset_id == dataset_id, GtfsStopTime.trip_id == trip_id, GtfsStopTime.stop_sequence >= start_sequence, GtfsStopTime.stop_sequence <= end_sequence, ) .order_by(GtfsStopTime.stop_sequence) ).all() ) def stop_time_from_row(dataset_id: int, row) -> GtfsStopTime: return GtfsStopTime( dataset_id=dataset_id, trip_id=str(row["trip_id"]), stop_id=str(row["stop_id"]), stop_sequence=int(row["stop_sequence"]), arrival_time=row["arrival_time"], departure_time=row["departure_time"], arrival_seconds=row["arrival_seconds"], departure_seconds=row["departure_seconds"], ) def execute_sidecar_query(session: Session, dataset_id: int, sql: str, params: Sequence[object]) -> list[sqlite3.Row]: dataset = session.get(Dataset, dataset_id) if not stop_times_are_sidecar(dataset): raise ValueError(f"dataset #{dataset_id} does not use sidecar stop_times") try: with sidecar_connection(dataset) as connection: return list(connection.execute(sql, list(params)).fetchall()) except MissingGtfsSidecar: return [] def _chunks[T](items: Sequence[T], size: int) -> Iterator[Sequence[T]]: for index in range(0, len(items), size): yield items[index : index + size]