309 lines
11 KiB
Python
309 lines
11 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
import sqlite3
|
|
from contextlib import contextmanager
|
|
from pathlib import Path
|
|
from typing import Iterator, Sequence
|
|
|
|
from sqlalchemy import func, select
|
|
from sqlalchemy.orm import Session
|
|
|
|
from app.config import settings
|
|
from app.models import Dataset, GtfsStopTime
|
|
|
|
|
|
# Key inside a dataset's metadata JSON under which the GTFS storage layout is recorded.
GTFS_STORAGE_METADATA_KEY = "gtfs_storage"
# Storage mode names: stop_times kept in the main database vs. in a SQLite sidecar file.
GTFS_STORAGE_MAIN = "main"
GTFS_STORAGE_SIDECAR_STOP_TIMES = "sidecar_stop_times"

# Columns selected from the sidecar ``gtfs_stop_times`` table, in SELECT order.
GTFS_STOP_TIME_COLUMNS = [
    "trip_id", "stop_id", "stop_sequence",
    "arrival_time", "departure_time",
    "arrival_seconds", "departure_seconds",
]

# Maximum number of ids bound into a single ``IN (...)`` query; presumably chosen to
# stay under SQLite's historical default limit of 999 host parameters.
SQLITE_IN_CHUNK_SIZE = 800
|
|
|
|
|
|
def effective_gtfs_timetable_storage(value: str | None = None) -> str:
    """Resolve which storage backend GTFS stop_times should use.

    ``value`` takes precedence over ``settings.gtfs_timetable_storage``; when both
    are empty, sidecar storage is assumed. The name is stripped and lower-cased
    before comparison.

    Returns ``GTFS_STORAGE_MAIN`` when the name is one of the main-database aliases,
    or when the app runs on PostgreSQL with ``postgres_use_sidecars`` disabled;
    otherwise returns ``GTFS_STORAGE_SIDECAR_STOP_TIMES``.
    """
    requested = value or settings.gtfs_timetable_storage or GTFS_STORAGE_SIDECAR_STOP_TIMES
    normalized = str(requested).strip().lower()
    main_aliases = {GTFS_STORAGE_MAIN, "main_db", "main_sqlite", "postgres", "postgresql"}
    # ``or`` short-circuits, so the PostgreSQL settings are only consulted when the
    # requested name is not already a main-database alias.
    if normalized in main_aliases or (
        settings.is_postgresql_database and not settings.postgres_use_sidecars
    ):
        return GTFS_STORAGE_MAIN
    return GTFS_STORAGE_SIDECAR_STOP_TIMES
|
|
|
|
|
|
class MissingGtfsSidecar(FileNotFoundError):
    """Raised when a dataset's GTFS sidecar database is not available.

    Attributes:
        dataset_id: id of the dataset being looked up (may be ``None``).
        path: the expected sidecar path, or ``None`` when the dataset does not
            name a sidecar at all.
    """

    def __init__(self, dataset_id: int | None, path: Path | None) -> None:
        self.dataset_id = dataset_id
        self.path = path
        message = (
            f"GTFS sidecar does not exist: {path}"
            if path is not None
            else f"dataset #{dataset_id} does not reference a GTFS sidecar"
        )
        super().__init__(message)
|
|
|
|
|
|
def dataset_metadata(dataset: Dataset) -> dict:
    """Decode ``dataset.metadata_json`` into a dict.

    Empty or missing metadata yields ``{}``. If the value is already a dict (for
    example, a column that was decoded upstream), a shallow copy is returned.
    Anything that cannot be decoded into a JSON object also yields ``{}``. This
    covers malformed JSON, bytes that are not valid UTF-8, a value that is not
    str/bytes, and a JSON document whose top level is not an object. Callers can
    therefore call ``.get`` on the result unconditionally.
    """
    raw = dataset.metadata_json or "{}"
    if isinstance(raw, dict):
        # Copy so callers cannot mutate the attribute through the returned dict,
        # matching the fresh dict produced by json.loads.
        return dict(raw)
    try:
        metadata = json.loads(raw)
    except (TypeError, ValueError):
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError (invalid
        # UTF-8 bytes); TypeError covers values json.loads cannot accept at all.
        return {}
    return metadata if isinstance(metadata, dict) else {}
|
|
|
|
|
|
def stop_times_are_sidecar(dataset: Dataset | None) -> bool:
    """Report whether ``dataset`` keeps its stop_times in a SQLite sidecar.

    The ``GTFS_STORAGE_METADATA_KEY`` entry of the dataset metadata is consulted.
    A per-table ``tables`` mapping, when present, decides the answer
    (``tables["gtfs_stop_times"] == "sidecar"``). Otherwise the storage ``mode``
    is compared with ``GTFS_STORAGE_SIDECAR_STOP_TIMES``. ``None`` datasets and
    missing or non-dict storage entries give ``False``.
    """
    storage = None if dataset is None else dataset_metadata(dataset).get(GTFS_STORAGE_METADATA_KEY)
    if not isinstance(storage, dict):
        return False
    tables = storage.get("tables")
    if not isinstance(tables, dict):
        return storage.get("mode") == GTFS_STORAGE_SIDECAR_STOP_TIMES
    return tables.get("gtfs_stop_times") == "sidecar"
|
|
|
|
|
|
def sidecar_path(dataset: Dataset | None) -> Path | None:
    """Return the sidecar database path recorded in the dataset's storage metadata.

    Returns ``None`` when there is no dataset, no dict-shaped storage entry, or
    an empty or missing ``sidecar_path`` value. The stored value is passed
    through ``str`` before the ``Path`` is built.
    """
    storage = {} if dataset is None else dataset_metadata(dataset).get(GTFS_STORAGE_METADATA_KEY)
    raw = storage.get("sidecar_path") if isinstance(storage, dict) else None
    return Path(str(raw)) if raw else None
|
|
|
|
|
|
def dataset_sidecar_paths(dataset: Dataset) -> list[Path]:
    """List the sidecar files referenced by ``dataset`` (zero or one entries)."""
    paths: list[Path] = []
    found = sidecar_path(dataset)
    if found is not None:
        paths.append(found)
    return paths
|
|
|
|
|
|
def missing_sidecar_paths(dataset: Dataset | None) -> list[str]:
    """Describe sidecar problems for a sidecar-backed dataset.

    Returns an empty list when the dataset does not keep stop_times in a
    sidecar, or when its sidecar file exists. Otherwise returns one entry:
    either the missing file's path as a string, or a message saying that no
    sidecar path is configured.
    """
    if stop_times_are_sidecar(dataset):
        path = sidecar_path(dataset)
        if path is None:
            label = str(dataset.id) if dataset is not None else "unknown"
            return [f"dataset #{label} has no configured GTFS sidecar path"]
        if not path.exists():
            return [str(path)]
    return []
|
|
|
|
|
|
def uses_sidecar_stop_times(session: Session, dataset_id: int) -> bool:
    """Load dataset ``dataset_id`` via ``session`` and report whether its stop_times live in a sidecar."""
    dataset = session.get(Dataset, dataset_id)
    return stop_times_are_sidecar(dataset)
|
|
|
|
|
|
@contextmanager
def sidecar_connection(dataset: Dataset) -> Iterator[sqlite3.Connection]:
    """Open the dataset's GTFS sidecar database read-only for a ``with`` block.

    Yields a ``sqlite3.Connection`` whose ``row_factory`` is ``sqlite3.Row``.
    The connection is closed on exit, including when the body raises.

    Raises:
        MissingGtfsSidecar: the dataset metadata names no sidecar, or the named
            file does not exist.
    """
    path = sidecar_path(dataset)
    if path is None:
        raise MissingGtfsSidecar(dataset.id, None)
    if not path.exists():
        raise MissingGtfsSidecar(dataset.id, path)
    # Build the URI with Path.as_uri() so characters with URI meaning ("?", "#",
    # "%", spaces, Windows drive letters/backslashes) are percent-encoded instead
    # of being parsed by SQLite as query/fragment delimiters. ``mode=ro`` opens the
    # database read-only.
    uri = f"{path.absolute().as_uri()}?mode=ro"
    connection = sqlite3.connect(uri, uri=True)
    try:
        connection.row_factory = sqlite3.Row
        yield connection
    finally:
        connection.close()
|
|
|
|
|
|
def stop_time_count(session: Session, dataset_id: int) -> int:
    """Count the stop_times rows stored for ``dataset_id``.

    For sidecar-backed datasets, counts the sidecar's ``gtfs_stop_times`` table,
    returning ``0`` if the sidecar is missing. Otherwise counts the dataset's
    ``GtfsStopTime`` rows in the main database.
    """
    dataset = session.get(Dataset, dataset_id)
    if not stop_times_are_sidecar(dataset):
        query = (
            select(func.count())
            .select_from(GtfsStopTime)
            .where(GtfsStopTime.dataset_id == dataset_id)
        )
        return session.scalar(query) or 0
    try:
        with sidecar_connection(dataset) as connection:
            first_row = connection.execute("SELECT COUNT(*) FROM gtfs_stop_times").fetchone()
            return int(first_row[0] or 0)
    except MissingGtfsSidecar:
        return 0
|
|
|
|
|
|
def stop_time_counts_by_dataset(session: Session, dataset_ids: Sequence[int]) -> dict[int, int]:
    """Map each id in ``dataset_ids`` (coerced to ``int``) to its stop_times count."""
    return {
        int(dataset_id): stop_time_count(session, int(dataset_id))
        for dataset_id in dataset_ids
    }
|
|
|
|
|
|
def scheduled_stop_ids(session: Session, dataset_id: int, stop_ids: Sequence[str]) -> tuple[str, ...]:
    """Return the subset of ``stop_ids`` that appears in the dataset's stop_times.

    Ids are compared as strings, and the result is a sorted tuple of distinct
    ids. Lookups are issued ``SQLITE_IN_CHUNK_SIZE`` ids at a time. For a
    sidecar-backed dataset whose sidecar is missing, ``()`` is returned.
    """
    if not stop_ids:
        return ()
    dataset = session.get(Dataset, dataset_id)
    wanted = [str(stop_id) for stop_id in stop_ids]
    hits: set[str] = set()

    if not stop_times_are_sidecar(dataset):
        for batch in _chunks(wanted, SQLITE_IN_CHUNK_SIZE):
            statement = (
                select(GtfsStopTime.stop_id)
                .where(GtfsStopTime.dataset_id == dataset_id, GtfsStopTime.stop_id.in_(batch))
                .group_by(GtfsStopTime.stop_id)
            )
            hits.update(map(str, session.scalars(statement).all()))
        return tuple(sorted(hits))

    try:
        with sidecar_connection(dataset) as connection:
            for batch in _chunks(wanted, SQLITE_IN_CHUNK_SIZE):
                # ", ".join over a string of "?" characters yields "?, ?, ..."
                marks = ", ".join("?" * len(batch))
                cursor = connection.execute(
                    f"""
                    SELECT stop_id
                    FROM gtfs_stop_times
                    WHERE stop_id IN ({marks})
                    GROUP BY stop_id
                    """,
                    list(batch),
                )
                hits.update(str(row["stop_id"]) for row in cursor.fetchall())
    except MissingGtfsSidecar:
        return ()
    return tuple(sorted(hits))
|
|
|
|
|
|
def all_scheduled_stop_ids(session: Session, dataset_id: int) -> set[str]:
    """Return every stop id (as ``str``) that has at least one stop_time in the dataset.

    For a sidecar-backed dataset whose sidecar is missing, an empty set is returned.
    """
    dataset = session.get(Dataset, dataset_id)
    if not stop_times_are_sidecar(dataset):
        statement = (
            select(GtfsStopTime.stop_id)
            .where(GtfsStopTime.dataset_id == dataset_id)
            .group_by(GtfsStopTime.stop_id)
        )
        return set(map(str, session.scalars(statement).all()))
    try:
        with sidecar_connection(dataset) as connection:
            cursor = connection.execute("SELECT stop_id FROM gtfs_stop_times GROUP BY stop_id")
            return {str(row["stop_id"]) for row in cursor.fetchall()}
    except MissingGtfsSidecar:
        return set()
|
|
|
|
|
|
def scheduled_stop_ids_by_dataset(session: Session, dataset_ids: Sequence[int]) -> dict[int, set[str]]:
    """Map each id in ``dataset_ids`` (coerced to ``int``) to its scheduled stop ids."""
    result: dict[int, set[str]] = {}
    for raw_id in dataset_ids:
        key = int(raw_id)
        result[key] = all_scheduled_stop_ids(session, key)
    return result
|
|
|
|
|
|
def has_scheduled_stop(session: Session, dataset_id: int, stop_id: str) -> bool:
    """Return ``True`` when ``stop_id`` has at least one stop_time in the dataset."""
    matches = scheduled_stop_ids(session, dataset_id, [stop_id])
    return len(matches) > 0
|
|
|
|
|
|
def stop_times_by_trip(
    session: Session,
    dataset_id: int,
    trip_ids: Sequence[str],
) -> dict[str, list[GtfsStopTime]]:
    """Load the stop_times for ``trip_ids``, grouped by trip id.

    Each trip's list is ordered by ``stop_sequence``. Trips without stop_times
    are absent from the result. Ids are compared as strings, and repeated ids
    are looked up only once.

    For sidecar-backed datasets, rows are read from the sidecar and converted
    with ``stop_time_from_row``; the instances are not added to ``session``
    here. An empty dict is returned when the sidecar is missing. Otherwise the
    ``GtfsStopTime`` rows are loaded from the main database through ``session``.
    """
    if not trip_ids:
        return {}
    dataset = session.get(Dataset, dataset_id)
    # De-duplicate while keeping first-seen order. A trip id repeated far enough
    # apart to land in two chunks would otherwise be queried twice and have its
    # stop_times appended to the same list twice.
    requested = list(dict.fromkeys(str(trip_id) for trip_id in trip_ids))
    grouped: dict[str, list[GtfsStopTime]] = {}

    if stop_times_are_sidecar(dataset):
        column_sql = ", ".join(GTFS_STOP_TIME_COLUMNS)
        try:
            with sidecar_connection(dataset) as connection:
                for chunk in _chunks(requested, SQLITE_IN_CHUNK_SIZE):
                    placeholders = ", ".join(["?"] * len(chunk))
                    rows = connection.execute(
                        f"""
                        SELECT {column_sql}
                        FROM gtfs_stop_times
                        WHERE trip_id IN ({placeholders})
                        ORDER BY trip_id, stop_sequence
                        """,
                        list(chunk),
                    ).fetchall()
                    for row in rows:
                        stop_time = stop_time_from_row(dataset_id, row)
                        grouped.setdefault(stop_time.trip_id, []).append(stop_time)
        except MissingGtfsSidecar:
            return {}
        return grouped

    for chunk in _chunks(requested, SQLITE_IN_CHUNK_SIZE):
        rows = session.scalars(
            select(GtfsStopTime)
            .where(GtfsStopTime.dataset_id == dataset_id, GtfsStopTime.trip_id.in_(chunk))
            .order_by(GtfsStopTime.trip_id, GtfsStopTime.stop_sequence)
        ).all()
        for row in rows:
            grouped.setdefault(row.trip_id, []).append(row)
    return grouped
|
|
|
|
|
|
def stop_times_for_trip_range(
    session: Session,
    dataset_id: int,
    trip_id: str,
    start_sequence: int,
    end_sequence: int,
) -> list[GtfsStopTime]:
    """Return one trip's stop_times with ``start_sequence <= stop_sequence <= end_sequence``.

    The rows are ordered by ``stop_sequence``. For sidecar-backed datasets they
    are read from the sidecar (bounds coerced with ``int``), and ``[]`` is
    returned when the sidecar is missing. Otherwise they are loaded from the
    main database through ``session``.
    """
    dataset = session.get(Dataset, dataset_id)
    if not stop_times_are_sidecar(dataset):
        statement = (
            select(GtfsStopTime)
            .where(
                GtfsStopTime.dataset_id == dataset_id,
                GtfsStopTime.trip_id == trip_id,
                GtfsStopTime.stop_sequence >= start_sequence,
                GtfsStopTime.stop_sequence <= end_sequence,
            )
            .order_by(GtfsStopTime.stop_sequence)
        )
        return list(session.scalars(statement).all())

    selected_columns = ", ".join(GTFS_STOP_TIME_COLUMNS)
    query = f"""
        SELECT {selected_columns}
        FROM gtfs_stop_times
        WHERE trip_id = ?
        AND stop_sequence >= ?
        AND stop_sequence <= ?
        ORDER BY stop_sequence
        """
    try:
        with sidecar_connection(dataset) as connection:
            # Bounds are converted only once the sidecar is open, so a missing
            # sidecar still wins over a bad bound (returns [] rather than raising).
            bounds = (trip_id, int(start_sequence), int(end_sequence))
            fetched = connection.execute(query, bounds).fetchall()
            return [stop_time_from_row(dataset_id, record) for record in fetched]
    except MissingGtfsSidecar:
        return []
|
|
|
|
|
|
def stop_time_from_row(dataset_id: int, row) -> GtfsStopTime:
    """Build a ``GtfsStopTime`` from a sidecar row (not added to any session here).

    ``row`` must support lookup by column name (the sidecar connection uses
    ``sqlite3.Row``). ``trip_id`` and ``stop_id`` are coerced to ``str`` and
    ``stop_sequence`` to ``int``; the four time columns are copied unchanged.
    """
    fields = {
        "dataset_id": dataset_id,
        "trip_id": str(row["trip_id"]),
        "stop_id": str(row["stop_id"]),
        "stop_sequence": int(row["stop_sequence"]),
    }
    for column in ("arrival_time", "departure_time", "arrival_seconds", "departure_seconds"):
        fields[column] = row[column]
    return GtfsStopTime(**fields)
|
|
|
|
|
|
def execute_sidecar_query(session: Session, dataset_id: int, sql: str, params: Sequence[object]) -> list[sqlite3.Row]:
    """Run ``sql`` with ``params`` against the dataset's sidecar and return all rows.

    Returns ``[]`` when the sidecar file is missing.

    Raises:
        ValueError: the dataset does not keep its stop_times in a sidecar.
    """
    dataset = session.get(Dataset, dataset_id)
    if stop_times_are_sidecar(dataset):
        try:
            with sidecar_connection(dataset) as connection:
                cursor = connection.execute(sql, list(params))
                return list(cursor.fetchall())
        except MissingGtfsSidecar:
            return []
    raise ValueError(f"dataset #{dataset_id} does not use sidecar stop_times")
|
|
|
|
|
|
def _chunks[T](items: Sequence[T], size: int) -> Iterator[Sequence[T]]:
|
|
for index in range(0, len(items), size):
|
|
yield items[index : index + size]
|