Alpha stage commit
This commit is contained in:
308
app/gtfs_storage.py
Normal file
308
app/gtfs_storage.py
Normal file
@@ -0,0 +1,308 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sqlite3
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Iterator, Sequence
|
||||
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.config import settings
|
||||
from app.models import Dataset, GtfsStopTime
|
||||
|
||||
|
||||
GTFS_STORAGE_METADATA_KEY = "gtfs_storage"
|
||||
GTFS_STORAGE_MAIN = "main"
|
||||
GTFS_STORAGE_SIDECAR_STOP_TIMES = "sidecar_stop_times"
|
||||
GTFS_STOP_TIME_COLUMNS = [
|
||||
"trip_id",
|
||||
"stop_id",
|
||||
"stop_sequence",
|
||||
"arrival_time",
|
||||
"departure_time",
|
||||
"arrival_seconds",
|
||||
"departure_seconds",
|
||||
]
|
||||
SQLITE_IN_CHUNK_SIZE = 800
|
||||
|
||||
|
||||
def effective_gtfs_timetable_storage(value: str | None = None) -> str:
|
||||
configured = str(value or settings.gtfs_timetable_storage or GTFS_STORAGE_SIDECAR_STOP_TIMES).strip().lower()
|
||||
if configured in {GTFS_STORAGE_MAIN, "main_db", "main_sqlite", "postgres", "postgresql"}:
|
||||
return GTFS_STORAGE_MAIN
|
||||
if settings.is_postgresql_database and not settings.postgres_use_sidecars:
|
||||
return GTFS_STORAGE_MAIN
|
||||
return GTFS_STORAGE_SIDECAR_STOP_TIMES
|
||||
|
||||
|
||||
class MissingGtfsSidecar(FileNotFoundError):
|
||||
def __init__(self, dataset_id: int | None, path: Path | None) -> None:
|
||||
self.dataset_id = dataset_id
|
||||
self.path = path
|
||||
if path is None:
|
||||
message = f"dataset #{dataset_id} does not reference a GTFS sidecar"
|
||||
else:
|
||||
message = f"GTFS sidecar does not exist: {path}"
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
def dataset_metadata(dataset: Dataset) -> dict:
|
||||
try:
|
||||
metadata = json.loads(dataset.metadata_json or "{}")
|
||||
except json.JSONDecodeError:
|
||||
return {}
|
||||
return metadata if isinstance(metadata, dict) else {}
|
||||
|
||||
|
||||
def stop_times_are_sidecar(dataset: Dataset | None) -> bool:
|
||||
if dataset is None:
|
||||
return False
|
||||
storage = dataset_metadata(dataset).get(GTFS_STORAGE_METADATA_KEY)
|
||||
if not isinstance(storage, dict):
|
||||
return False
|
||||
tables = storage.get("tables")
|
||||
if isinstance(tables, dict):
|
||||
return tables.get("gtfs_stop_times") == "sidecar"
|
||||
return storage.get("mode") == GTFS_STORAGE_SIDECAR_STOP_TIMES
|
||||
|
||||
|
||||
def sidecar_path(dataset: Dataset | None) -> Path | None:
|
||||
if dataset is None:
|
||||
return None
|
||||
storage = dataset_metadata(dataset).get(GTFS_STORAGE_METADATA_KEY)
|
||||
if not isinstance(storage, dict):
|
||||
return None
|
||||
value = storage.get("sidecar_path")
|
||||
if not value:
|
||||
return None
|
||||
return Path(str(value))
|
||||
|
||||
|
||||
def dataset_sidecar_paths(dataset: Dataset) -> list[Path]:
|
||||
path = sidecar_path(dataset)
|
||||
return [] if path is None else [path]
|
||||
|
||||
|
||||
def missing_sidecar_paths(dataset: Dataset | None) -> list[str]:
|
||||
if not stop_times_are_sidecar(dataset):
|
||||
return []
|
||||
path = sidecar_path(dataset)
|
||||
if path is None:
|
||||
dataset_id = "unknown" if dataset is None else str(dataset.id)
|
||||
return [f"dataset #{dataset_id} has no configured GTFS sidecar path"]
|
||||
return [] if path.exists() else [str(path)]
|
||||
|
||||
|
||||
def uses_sidecar_stop_times(session: Session, dataset_id: int) -> bool:
|
||||
return stop_times_are_sidecar(session.get(Dataset, dataset_id))
|
||||
|
||||
|
||||
@contextmanager
|
||||
def sidecar_connection(dataset: Dataset) -> Iterator[sqlite3.Connection]:
|
||||
path = sidecar_path(dataset)
|
||||
if path is None:
|
||||
raise MissingGtfsSidecar(dataset.id, None)
|
||||
if not path.exists():
|
||||
raise MissingGtfsSidecar(dataset.id, path)
|
||||
connection = sqlite3.connect(f"file:{path}?mode=ro", uri=True)
|
||||
connection.row_factory = sqlite3.Row
|
||||
try:
|
||||
yield connection
|
||||
finally:
|
||||
connection.close()
|
||||
|
||||
|
||||
def stop_time_count(session: Session, dataset_id: int) -> int:
|
||||
dataset = session.get(Dataset, dataset_id)
|
||||
if stop_times_are_sidecar(dataset):
|
||||
try:
|
||||
with sidecar_connection(dataset) as connection:
|
||||
return int(connection.execute("SELECT COUNT(*) FROM gtfs_stop_times").fetchone()[0] or 0)
|
||||
except MissingGtfsSidecar:
|
||||
return 0
|
||||
return session.scalar(select(func.count()).select_from(GtfsStopTime).where(GtfsStopTime.dataset_id == dataset_id)) or 0
|
||||
|
||||
|
||||
def stop_time_counts_by_dataset(session: Session, dataset_ids: Sequence[int]) -> dict[int, int]:
|
||||
counts: dict[int, int] = {}
|
||||
for dataset_id in dataset_ids:
|
||||
counts[int(dataset_id)] = stop_time_count(session, int(dataset_id))
|
||||
return counts
|
||||
|
||||
|
||||
def scheduled_stop_ids(session: Session, dataset_id: int, stop_ids: Sequence[str]) -> tuple[str, ...]:
|
||||
if not stop_ids:
|
||||
return ()
|
||||
dataset = session.get(Dataset, dataset_id)
|
||||
requested = [str(stop_id) for stop_id in stop_ids]
|
||||
found: set[str] = set()
|
||||
if stop_times_are_sidecar(dataset):
|
||||
try:
|
||||
with sidecar_connection(dataset) as connection:
|
||||
for chunk in _chunks(requested, SQLITE_IN_CHUNK_SIZE):
|
||||
placeholders = ", ".join(["?"] * len(chunk))
|
||||
rows = connection.execute(
|
||||
f"""
|
||||
SELECT stop_id
|
||||
FROM gtfs_stop_times
|
||||
WHERE stop_id IN ({placeholders})
|
||||
GROUP BY stop_id
|
||||
""",
|
||||
list(chunk),
|
||||
).fetchall()
|
||||
found.update(str(row["stop_id"]) for row in rows)
|
||||
except MissingGtfsSidecar:
|
||||
return ()
|
||||
else:
|
||||
for chunk in _chunks(requested, SQLITE_IN_CHUNK_SIZE):
|
||||
rows = session.scalars(
|
||||
select(GtfsStopTime.stop_id)
|
||||
.where(GtfsStopTime.dataset_id == dataset_id, GtfsStopTime.stop_id.in_(chunk))
|
||||
.group_by(GtfsStopTime.stop_id)
|
||||
).all()
|
||||
found.update(str(row) for row in rows)
|
||||
return tuple(sorted(found))
|
||||
|
||||
|
||||
def all_scheduled_stop_ids(session: Session, dataset_id: int) -> set[str]:
|
||||
dataset = session.get(Dataset, dataset_id)
|
||||
if stop_times_are_sidecar(dataset):
|
||||
try:
|
||||
with sidecar_connection(dataset) as connection:
|
||||
return {
|
||||
str(row["stop_id"])
|
||||
for row in connection.execute("SELECT stop_id FROM gtfs_stop_times GROUP BY stop_id").fetchall()
|
||||
}
|
||||
except MissingGtfsSidecar:
|
||||
return set()
|
||||
return {
|
||||
str(row)
|
||||
for row in session.scalars(
|
||||
select(GtfsStopTime.stop_id)
|
||||
.where(GtfsStopTime.dataset_id == dataset_id)
|
||||
.group_by(GtfsStopTime.stop_id)
|
||||
).all()
|
||||
}
|
||||
|
||||
|
||||
def scheduled_stop_ids_by_dataset(session: Session, dataset_ids: Sequence[int]) -> dict[int, set[str]]:
|
||||
return {int(dataset_id): all_scheduled_stop_ids(session, int(dataset_id)) for dataset_id in dataset_ids}
|
||||
|
||||
|
||||
def has_scheduled_stop(session: Session, dataset_id: int, stop_id: str) -> bool:
|
||||
return bool(scheduled_stop_ids(session, dataset_id, [stop_id]))
|
||||
|
||||
|
||||
def stop_times_by_trip(
|
||||
session: Session,
|
||||
dataset_id: int,
|
||||
trip_ids: Sequence[str],
|
||||
) -> dict[str, list[GtfsStopTime]]:
|
||||
if not trip_ids:
|
||||
return {}
|
||||
grouped: dict[str, list[GtfsStopTime]] = {}
|
||||
dataset = session.get(Dataset, dataset_id)
|
||||
requested = [str(trip_id) for trip_id in trip_ids]
|
||||
if stop_times_are_sidecar(dataset):
|
||||
column_sql = ", ".join(GTFS_STOP_TIME_COLUMNS)
|
||||
try:
|
||||
with sidecar_connection(dataset) as connection:
|
||||
for chunk in _chunks(requested, SQLITE_IN_CHUNK_SIZE):
|
||||
placeholders = ", ".join(["?"] * len(chunk))
|
||||
rows = connection.execute(
|
||||
f"""
|
||||
SELECT {column_sql}
|
||||
FROM gtfs_stop_times
|
||||
WHERE trip_id IN ({placeholders})
|
||||
ORDER BY trip_id, stop_sequence
|
||||
""",
|
||||
list(chunk),
|
||||
).fetchall()
|
||||
for row in rows:
|
||||
stop_time = stop_time_from_row(dataset_id, row)
|
||||
grouped.setdefault(stop_time.trip_id, []).append(stop_time)
|
||||
except MissingGtfsSidecar:
|
||||
return {}
|
||||
return grouped
|
||||
|
||||
for chunk in _chunks(requested, SQLITE_IN_CHUNK_SIZE):
|
||||
rows = session.scalars(
|
||||
select(GtfsStopTime)
|
||||
.where(GtfsStopTime.dataset_id == dataset_id, GtfsStopTime.trip_id.in_(chunk))
|
||||
.order_by(GtfsStopTime.trip_id, GtfsStopTime.stop_sequence)
|
||||
).all()
|
||||
for row in rows:
|
||||
grouped.setdefault(row.trip_id, []).append(row)
|
||||
return grouped
|
||||
|
||||
|
||||
def stop_times_for_trip_range(
|
||||
session: Session,
|
||||
dataset_id: int,
|
||||
trip_id: str,
|
||||
start_sequence: int,
|
||||
end_sequence: int,
|
||||
) -> list[GtfsStopTime]:
|
||||
dataset = session.get(Dataset, dataset_id)
|
||||
if stop_times_are_sidecar(dataset):
|
||||
column_sql = ", ".join(GTFS_STOP_TIME_COLUMNS)
|
||||
try:
|
||||
with sidecar_connection(dataset) as connection:
|
||||
rows = connection.execute(
|
||||
f"""
|
||||
SELECT {column_sql}
|
||||
FROM gtfs_stop_times
|
||||
WHERE trip_id = ?
|
||||
AND stop_sequence >= ?
|
||||
AND stop_sequence <= ?
|
||||
ORDER BY stop_sequence
|
||||
""",
|
||||
(trip_id, int(start_sequence), int(end_sequence)),
|
||||
).fetchall()
|
||||
return [stop_time_from_row(dataset_id, row) for row in rows]
|
||||
except MissingGtfsSidecar:
|
||||
return []
|
||||
|
||||
return list(
|
||||
session.scalars(
|
||||
select(GtfsStopTime)
|
||||
.where(
|
||||
GtfsStopTime.dataset_id == dataset_id,
|
||||
GtfsStopTime.trip_id == trip_id,
|
||||
GtfsStopTime.stop_sequence >= start_sequence,
|
||||
GtfsStopTime.stop_sequence <= end_sequence,
|
||||
)
|
||||
.order_by(GtfsStopTime.stop_sequence)
|
||||
).all()
|
||||
)
|
||||
|
||||
|
||||
def stop_time_from_row(dataset_id: int, row) -> GtfsStopTime:
|
||||
return GtfsStopTime(
|
||||
dataset_id=dataset_id,
|
||||
trip_id=str(row["trip_id"]),
|
||||
stop_id=str(row["stop_id"]),
|
||||
stop_sequence=int(row["stop_sequence"]),
|
||||
arrival_time=row["arrival_time"],
|
||||
departure_time=row["departure_time"],
|
||||
arrival_seconds=row["arrival_seconds"],
|
||||
departure_seconds=row["departure_seconds"],
|
||||
)
|
||||
|
||||
|
||||
def execute_sidecar_query(session: Session, dataset_id: int, sql: str, params: Sequence[object]) -> list[sqlite3.Row]:
|
||||
dataset = session.get(Dataset, dataset_id)
|
||||
if not stop_times_are_sidecar(dataset):
|
||||
raise ValueError(f"dataset #{dataset_id} does not use sidecar stop_times")
|
||||
try:
|
||||
with sidecar_connection(dataset) as connection:
|
||||
return list(connection.execute(sql, list(params)).fetchall())
|
||||
except MissingGtfsSidecar:
|
||||
return []
|
||||
|
||||
|
||||
def _chunks[T](items: Sequence[T], size: int) -> Iterator[Sequence[T]]:
|
||||
for index in range(0, len(items), size):
|
||||
yield items[index : index + size]
|
||||
Reference in New Issue
Block a user