Files
meubility-workbench/app/gtfs_storage.py
2026-07-01 23:29:51 +02:00

309 lines
11 KiB
Python

from __future__ import annotations
import json
import sqlite3
from contextlib import contextmanager
from pathlib import Path
from typing import Iterator, Sequence
from sqlalchemy import func, select
from sqlalchemy.orm import Session
from app.config import settings
from app.models import Dataset, GtfsStopTime
# Key inside Dataset.metadata_json whose value describes where GTFS tables live.
GTFS_STORAGE_METADATA_KEY: str = "gtfs_storage"

# Storage modes understood by effective_gtfs_timetable_storage().
GTFS_STORAGE_MAIN: str = "main"
GTFS_STORAGE_SIDECAR_STOP_TIMES: str = "sidecar_stop_times"

# Columns read from the sidecar's gtfs_stop_times table, in SELECT order;
# each must also be a keyword accepted by GtfsStopTime (see stop_time_from_row).
GTFS_STOP_TIME_COLUMNS: list[str] = [
    "trip_id", "stop_id", "stop_sequence",
    "arrival_time", "departure_time",
    "arrival_seconds", "departure_seconds",
]

# Maximum number of ids bound into a single "IN (...)" query. Presumably chosen
# to stay under SQLite's bound-parameter limit (999 in older builds) — confirm.
SQLITE_IN_CHUNK_SIZE: int = 800
def effective_gtfs_timetable_storage(value: str | None = None) -> str:
    """Resolve which backend should hold GTFS stop_times.

    ``value`` (or, when it is falsy, ``settings.gtfs_timetable_storage``, and
    finally the sidecar mode) is stripped and lower-cased. Any alias of the main
    database selects ``GTFS_STORAGE_MAIN``. A PostgreSQL deployment with
    sidecars disabled is also forced onto the main database. Every other value,
    including unrecognised ones, resolves to ``GTFS_STORAGE_SIDECAR_STOP_TIMES``.
    """
    main_aliases = {GTFS_STORAGE_MAIN, "main_db", "main_sqlite", "postgres", "postgresql"}
    requested = value or settings.gtfs_timetable_storage or GTFS_STORAGE_SIDECAR_STOP_TIMES
    normalized = str(requested).strip().lower()
    # The settings lookups on the right only run when the alias test fails.
    if normalized in main_aliases or (
        settings.is_postgresql_database and not settings.postgres_use_sidecars
    ):
        return GTFS_STORAGE_MAIN
    return GTFS_STORAGE_SIDECAR_STOP_TIMES
class MissingGtfsSidecar(FileNotFoundError):
    """Raised when a dataset's GTFS stop_times sidecar cannot be opened.

    Attributes:
        dataset_id: id of the dataset whose sidecar was requested.
        path: the configured sidecar path, or None when the dataset's metadata
            does not reference a sidecar at all.
    """

    def __init__(self, dataset_id: int | None, path: Path | None) -> None:
        message = (
            f"dataset #{dataset_id} does not reference a GTFS sidecar"
            if path is None
            else f"GTFS sidecar does not exist: {path}"
        )
        super().__init__(message)
        self.dataset_id = dataset_id
        self.path = path
def dataset_metadata(dataset: Dataset) -> dict:
    """Return ``dataset.metadata_json`` decoded as a dict.

    An empty/None value, invalid JSON, or JSON that is not an object all
    yield an empty dict.
    """
    raw = dataset.metadata_json or "{}"
    try:
        parsed = json.loads(raw)
    except json.JSONDecodeError:
        return {}
    if isinstance(parsed, dict):
        return parsed
    return {}
def stop_times_are_sidecar(dataset: Dataset | None) -> bool:
    """Return True when the dataset's metadata says stop_times live in a sidecar.

    The storage entry is ``metadata[GTFS_STORAGE_METADATA_KEY]``. When it holds
    a ``tables`` dict, that dict decides (``gtfs_stop_times == "sidecar"``).
    Otherwise the entry's ``mode`` must equal ``GTFS_STORAGE_SIDECAR_STOP_TIMES``.
    A None dataset or a non-dict storage entry yields False.
    """
    storage = None if dataset is None else dataset_metadata(dataset).get(GTFS_STORAGE_METADATA_KEY)
    if not isinstance(storage, dict):
        return False
    tables = storage.get("tables")
    if not isinstance(tables, dict):
        return storage.get("mode") == GTFS_STORAGE_SIDECAR_STOP_TIMES
    return tables.get("gtfs_stop_times") == "sidecar"
def sidecar_path(dataset: Dataset | None) -> Path | None:
    """Return the sidecar database path recorded in the dataset's storage metadata.

    Reads ``metadata[GTFS_STORAGE_METADATA_KEY]["sidecar_path"]``. Returns None
    when the dataset is None, the storage entry is not a dict, or the path is
    empty. Existence on disk is not checked here.
    """
    if dataset is not None:
        storage = dataset_metadata(dataset).get(GTFS_STORAGE_METADATA_KEY)
        if isinstance(storage, dict):
            raw_path = storage.get("sidecar_path")
            if raw_path:
                return Path(str(raw_path))
    return None
def dataset_sidecar_paths(dataset: Dataset) -> list[Path]:
    """Return the dataset's configured sidecar path as a list (empty when none)."""
    paths: list[Path] = []
    configured = sidecar_path(dataset)
    if configured is not None:
        paths.append(configured)
    return paths
def missing_sidecar_paths(dataset: Dataset | None) -> list[str]:
    """List sidecar problems for a dataset whose stop_times are sidecar-backed.

    Returns an empty list when stop_times are not sidecar-backed or the sidecar
    file exists. Otherwise returns a single entry: either a message that no
    sidecar path is configured, or the missing path as a string.
    """
    if stop_times_are_sidecar(dataset):
        path = sidecar_path(dataset)
        if path is None:
            label = str(dataset.id) if dataset is not None else "unknown"
            return [f"dataset #{label} has no configured GTFS sidecar path"]
        if not path.exists():
            return [str(path)]
    return []
def uses_sidecar_stop_times(session: Session, dataset_id: int) -> bool:
    """Load the dataset by id and report whether its stop_times are sidecar-backed."""
    dataset = session.get(Dataset, dataset_id)
    return stop_times_are_sidecar(dataset)
@contextmanager
def sidecar_connection(dataset: Dataset) -> Iterator[sqlite3.Connection]:
    """Open the dataset's stop_times sidecar read-only for the body of a ``with``.

    Yields a ``sqlite3.Connection`` whose ``row_factory`` is ``sqlite3.Row``.
    The connection is closed when the ``with`` block exits, including on error.

    Raises:
        MissingGtfsSidecar: if the dataset's metadata has no sidecar path, or
            the configured path does not exist.
    """
    path = sidecar_path(dataset)
    if path is None:
        raise MissingGtfsSidecar(dataset.id, None)
    if not path.exists():
        raise MissingGtfsSidecar(dataset.id, path)
    # Build the URI with Path.as_uri() rather than interpolating the raw path.
    # as_uri() percent-encodes characters that are special in URIs ('?', '#',
    # '%', spaces) and produces a valid form for Windows drive paths.
    # Otherwise a '?' or '#' in a directory name would truncate the filename
    # or clobber the query, and a literal '%XX' would be decoded by SQLite.
    # as_uri() requires an absolute path, hence resolve().
    uri = f"{path.resolve().as_uri()}?mode=ro"
    connection = sqlite3.connect(uri, uri=True)
    connection.row_factory = sqlite3.Row
    try:
        yield connection
    finally:
        connection.close()
def stop_time_count(session: Session, dataset_id: int) -> int:
    """Count the stop_times rows for a dataset.

    The sidecar branch counts every row of the sidecar's ``gtfs_stop_times``
    table. It is not filtered by ``dataset_id`` — presumably one sidecar per
    dataset. It returns 0 if the sidecar is missing. Otherwise the
    ``GtfsStopTime`` rows for ``dataset_id`` in the main database are counted.
    """
    dataset = session.get(Dataset, dataset_id)
    if not stop_times_are_sidecar(dataset):
        query = (
            select(func.count())
            .select_from(GtfsStopTime)
            .where(GtfsStopTime.dataset_id == dataset_id)
        )
        return session.scalar(query) or 0
    try:
        with sidecar_connection(dataset) as connection:
            first_row = connection.execute("SELECT COUNT(*) FROM gtfs_stop_times").fetchone()
            return int(first_row[0] or 0)
    except MissingGtfsSidecar:
        return 0
def stop_time_counts_by_dataset(session: Session, dataset_ids: Sequence[int]) -> dict[int, int]:
    """Map each dataset id (coerced to int) to its stop_time_count().

    One count query (or sidecar open) is issued per dataset.
    """
    return {int(dataset_id): stop_time_count(session, int(dataset_id)) for dataset_id in dataset_ids}
def scheduled_stop_ids(session: Session, dataset_id: int, stop_ids: Sequence[str]) -> tuple[str, ...]:
    """Return, sorted, those ``stop_ids`` that occur in the dataset's stop_times.

    Ids are compared as strings and queried in batches of
    ``SQLITE_IN_CHUNK_SIZE``. An empty request yields ``()``, as does a
    sidecar-backed dataset whose sidecar is missing.
    """
    if not stop_ids:
        return ()
    dataset = session.get(Dataset, dataset_id)
    wanted = [str(stop_id) for stop_id in stop_ids]
    hits: set[str] = set()
    if not stop_times_are_sidecar(dataset):
        for batch in _chunks(wanted, SQLITE_IN_CHUNK_SIZE):
            query = (
                select(GtfsStopTime.stop_id)
                .where(GtfsStopTime.dataset_id == dataset_id, GtfsStopTime.stop_id.in_(batch))
                .group_by(GtfsStopTime.stop_id)
            )
            hits.update(map(str, session.scalars(query).all()))
        return tuple(sorted(hits))
    try:
        with sidecar_connection(dataset) as connection:
            for batch in _chunks(wanted, SQLITE_IN_CHUNK_SIZE):
                # One "?" per id; the ids themselves are bound, never interpolated.
                placeholders = ", ".join("?" for _ in batch)
                cursor = connection.execute(
                    f"""
                    SELECT stop_id
                    FROM gtfs_stop_times
                    WHERE stop_id IN ({placeholders})
                    GROUP BY stop_id
                    """,
                    list(batch),
                )
                hits.update(str(row["stop_id"]) for row in cursor.fetchall())
    except MissingGtfsSidecar:
        return ()
    return tuple(sorted(hits))
def all_scheduled_stop_ids(session: Session, dataset_id: int) -> set[str]:
    """Return every distinct stop_id (as str) referenced by the dataset's stop_times.

    A sidecar-backed dataset whose sidecar is missing yields an empty set.
    """
    dataset = session.get(Dataset, dataset_id)
    if not stop_times_are_sidecar(dataset):
        query = (
            select(GtfsStopTime.stop_id)
            .where(GtfsStopTime.dataset_id == dataset_id)
            .group_by(GtfsStopTime.stop_id)
        )
        return set(map(str, session.scalars(query).all()))
    try:
        with sidecar_connection(dataset) as connection:
            rows = connection.execute("SELECT stop_id FROM gtfs_stop_times GROUP BY stop_id").fetchall()
            return {str(row["stop_id"]) for row in rows}
    except MissingGtfsSidecar:
        return set()
def scheduled_stop_ids_by_dataset(session: Session, dataset_ids: Sequence[int]) -> dict[int, set[str]]:
    """Map each dataset id (coerced to int) to its all_scheduled_stop_ids() set."""
    by_dataset: dict[int, set[str]] = {}
    for raw_id in dataset_ids:
        key = int(raw_id)
        by_dataset[key] = all_scheduled_stop_ids(session, key)
    return by_dataset
def has_scheduled_stop(session: Session, dataset_id: int, stop_id: str) -> bool:
    """Return True if ``stop_id`` appears in the dataset's stop_times."""
    matches = scheduled_stop_ids(session, dataset_id, [stop_id])
    return len(matches) > 0
def stop_times_by_trip(
    session: Session,
    dataset_id: int,
    trip_ids: Sequence[str],
) -> dict[str, list[GtfsStopTime]]:
    """Load stop_times for the given trips, grouped by trip_id.

    Each list is ordered by ``stop_sequence``. Trip ids are compared as strings
    and queried in batches of ``SQLITE_IN_CHUNK_SIZE``. Trips with no stop_times
    are absent from the result. The result is empty when ``trip_ids`` is empty,
    or when the dataset is sidecar-backed and its sidecar is missing.

    Sidecar rows are materialised as transient ``GtfsStopTime`` objects via
    stop_time_from_row(). Main-database rows are the ORM instances themselves.
    """
    if not trip_ids:
        return {}
    grouped: dict[str, list[GtfsStopTime]] = {}
    dataset = session.get(Dataset, dataset_id)
    # De-duplicate while keeping first-seen order. Without this, an id that
    # appears in two different IN (...) batches would have its stop_times
    # fetched and appended twice.
    requested = list(dict.fromkeys(str(trip_id) for trip_id in trip_ids))
    if stop_times_are_sidecar(dataset):
        column_sql = ", ".join(GTFS_STOP_TIME_COLUMNS)
        try:
            with sidecar_connection(dataset) as connection:
                for chunk in _chunks(requested, SQLITE_IN_CHUNK_SIZE):
                    placeholders = ", ".join(["?"] * len(chunk))
                    rows = connection.execute(
                        f"""
                        SELECT {column_sql}
                        FROM gtfs_stop_times
                        WHERE trip_id IN ({placeholders})
                        ORDER BY trip_id, stop_sequence
                        """,
                        list(chunk),
                    ).fetchall()
                    for row in rows:
                        stop_time = stop_time_from_row(dataset_id, row)
                        grouped.setdefault(stop_time.trip_id, []).append(stop_time)
        except MissingGtfsSidecar:
            return {}
        return grouped
    for chunk in _chunks(requested, SQLITE_IN_CHUNK_SIZE):
        rows = session.scalars(
            select(GtfsStopTime)
            .where(GtfsStopTime.dataset_id == dataset_id, GtfsStopTime.trip_id.in_(chunk))
            .order_by(GtfsStopTime.trip_id, GtfsStopTime.stop_sequence)
        ).all()
        for row in rows:
            grouped.setdefault(row.trip_id, []).append(row)
    return grouped
def stop_times_for_trip_range(
    session: Session,
    dataset_id: int,
    trip_id: str,
    start_sequence: int,
    end_sequence: int,
) -> list[GtfsStopTime]:
    """Return one trip's stop_times with ``start_sequence <= stop_sequence <= end_sequence``.

    Results are ordered by ``stop_sequence``. A sidecar-backed dataset whose
    sidecar is missing yields an empty list.
    """
    dataset = session.get(Dataset, dataset_id)
    if not stop_times_are_sidecar(dataset):
        query = (
            select(GtfsStopTime)
            .where(
                GtfsStopTime.dataset_id == dataset_id,
                GtfsStopTime.trip_id == trip_id,
                GtfsStopTime.stop_sequence >= start_sequence,
                GtfsStopTime.stop_sequence <= end_sequence,
            )
            .order_by(GtfsStopTime.stop_sequence)
        )
        return list(session.scalars(query).all())
    selected_columns = ", ".join(GTFS_STOP_TIME_COLUMNS)
    try:
        with sidecar_connection(dataset) as connection:
            cursor = connection.execute(
                f"""
                SELECT {selected_columns}
                FROM gtfs_stop_times
                WHERE trip_id = ?
                AND stop_sequence >= ?
                AND stop_sequence <= ?
                ORDER BY stop_sequence
                """,
                (trip_id, int(start_sequence), int(end_sequence)),
            )
            return [stop_time_from_row(dataset_id, row) for row in cursor.fetchall()]
    except MissingGtfsSidecar:
        return []
def stop_time_from_row(dataset_id: int, row) -> GtfsStopTime:
    """Build a transient ``GtfsStopTime`` from a sidecar row.

    ``row`` must support key access for each name in GTFS_STOP_TIME_COLUMNS
    (e.g. ``sqlite3.Row``). ``trip_id``/``stop_id`` are coerced to str and
    ``stop_sequence`` to int. The time fields are passed through unchanged. The
    object is not added to any session here.
    """
    fields = {
        "trip_id": str(row["trip_id"]),
        "stop_id": str(row["stop_id"]),
        "stop_sequence": int(row["stop_sequence"]),
        "arrival_time": row["arrival_time"],
        "departure_time": row["departure_time"],
        "arrival_seconds": row["arrival_seconds"],
        "departure_seconds": row["departure_seconds"],
    }
    return GtfsStopTime(dataset_id=dataset_id, **fields)
def execute_sidecar_query(session: Session, dataset_id: int, sql: str, params: Sequence[object]) -> list[sqlite3.Row]:
    """Run ``sql`` with bound ``params`` against the dataset's read-only sidecar.

    Returns all result rows. Raises ValueError if the dataset does not store
    stop_times in a sidecar. Returns an empty list if the sidecar file is
    missing.
    """
    dataset = session.get(Dataset, dataset_id)
    if not stop_times_are_sidecar(dataset):
        raise ValueError(f"dataset #{dataset_id} does not use sidecar stop_times")
    try:
        with sidecar_connection(dataset) as connection:
            cursor = connection.execute(sql, list(params))
            return list(cursor.fetchall())
    except MissingGtfsSidecar:
        return []
def _chunks[T](items: Sequence[T], size: int) -> Iterator[Sequence[T]]:
for index in range(0, len(items), size):
yield items[index : index + size]