Files
2026-07-01 23:29:51 +02:00

340 lines
20 KiB
Python

from __future__ import annotations
from contextlib import contextmanager
from pathlib import Path
import re
from typing import Iterator
from sqlalchemy import create_engine
from sqlalchemy import event
from sqlalchemy import text
from sqlalchemy.engine import Connection
from sqlalchemy.orm import DeclarativeBase, Session, sessionmaker
from app.config import settings
class Base(DeclarativeBase):
    """Declarative base whose metadata ``init_db``/``reset_db`` create and drop."""
def _connect_args() -> dict:
    """Return DB-API ``connect()`` keyword arguments for the configured backend.

    Only SQLite needs extra arguments; every other backend gets an empty dict.
    """
    if not settings.is_sqlite_database:
        return {}
    # check_same_thread=False lets pooled connections be used from threads other
    # than the one that opened them; timeout comes from settings.
    return {"check_same_thread": False, "timeout": settings.sqlite_timeout_seconds}
def _ensure_sqlite_parent() -> None:
    """Create the parent directory of a file-backed SQLite database.

    Runs before the engine is created so that e.g. ``sqlite:///./data/x.sqlite``
    can be opened even when ``./data`` does not exist yet. Does nothing for
    non-SQLite URLs, in-memory databases, or URLs without a database path.
    """
    if not settings.is_sqlite_database:
        return
    # Parse with SQLAlchemy instead of stripping a literal "sqlite:///" prefix:
    # driver-qualified URLs (sqlite+pysqlite:///...) are not matched by that
    # prefix and would produce a bogus directory, and query strings
    # (?timeout=...) are kept out of the path.
    from sqlalchemy.engine import make_url

    database = make_url(settings.normalized_database_url).database
    # "sqlite://" parses to None, "sqlite:///" to "", in-memory to ":memory:".
    if not database or database == ":memory:":
        return
    Path(database).parent.mkdir(parents=True, exist_ok=True)
# The SQLite database directory must exist before the engine connects.
_ensure_sqlite_parent()
engine = create_engine(
    settings.normalized_database_url,
    connect_args=_connect_args(),
    pool_pre_ping=True,
    future=True,
)
# Sessions neither autoflush nor expire loaded objects on commit.
SessionLocal = sessionmaker(
    bind=engine,
    autoflush=False,
    autocommit=False,
    expire_on_commit=False,
    future=True,
)
# Extracts the index name from a ``CREATE [UNIQUE] INDEX [CONCURRENTLY]
# [IF NOT EXISTS] name ON ...`` statement (group 1).
#
# PostgreSQL allows the name to be omitted (``CREATE INDEX ON t (...)``), and
# the optional groups let the regex backtrack onto keywords. The negative
# lookahead rejects ON / CONCURRENTLY / IF as a "name", and the trailing
# ``\s+ON\b`` rejects partial captures such as the schema of ``public.ix_x``
# or a quoted identifier. Without these guards a bogus name like "ON" would be
# recorded as existing and cause later statements to be skipped. Statements
# that cannot be parsed yield no match, so callers simply execute them.
_CREATE_INDEX_NAME_RE = re.compile(
    r"CREATE\s+(?:UNIQUE\s+)?INDEX\s+(?:CONCURRENTLY\s+)?(?:IF\s+NOT\s+EXISTS\s+)?"
    r"(?!(?:ON|CONCURRENTLY|IF)\b)([A-Za-z_][A-Za-z0-9_]*)\s+ON\b",
    re.IGNORECASE,
)
if settings.is_sqlite_database:

    @event.listens_for(engine, "connect")
    def _set_sqlite_pragmas(dbapi_connection, _connection_record) -> None:
        """Apply SQLite PRAGMAs to every new DB-API connection the pool opens.

        Sets WAL journaling, the configured busy timeout, ``synchronous=NORMAL``
        and foreign-key enforcement. The cursor is always closed, even if a
        PRAGMA fails.
        """
        cursor = dbapi_connection.cursor()
        try:
            cursor.execute("PRAGMA journal_mode=WAL")
            for pragma in (
                f"PRAGMA busy_timeout={int(settings.sqlite_busy_timeout_ms)}",
                "PRAGMA synchronous=NORMAL",
                "PRAGMA foreign_keys=ON",
            ):
                cursor.execute(pragma)
        finally:
            cursor.close()
def init_db() -> None:
    """Create missing tables and bring an existing schema up to date.

    Steps run in dependency order:
    1. database extensions (later steps use PostGIS geometry types and
       pg_trgm operator classes on PostgreSQL);
    2. ``create_all`` for the declarative models;
    3. runtime columns that ``create_all`` would not add to existing tables;
    4. runtime indexes.
    """
    # Importing the models module registers its tables on Base.metadata.
    from app import models  # noqa: F401

    _ensure_database_extensions()
    Base.metadata.create_all(bind=engine)
    _ensure_runtime_columns()
    _ensure_runtime_indexes()
def reset_db() -> None:
    """Drop and recreate every table on ``Base.metadata``; all data is lost.

    Follows the same sequence as ``init_db``, with a ``drop_all`` inserted
    between extension setup and table creation.
    """
    # Importing the models module registers its tables on Base.metadata.
    from app import models  # noqa: F401

    _ensure_database_extensions()
    metadata = Base.metadata
    metadata.drop_all(bind=engine)
    metadata.create_all(bind=engine)
    _ensure_runtime_columns()
    _ensure_runtime_indexes()
def _ensure_database_extensions() -> None:
    """Install the PostgreSQL extensions the schema depends on.

    PostGIS and pg_trgm are always requested. pgRouting is installed only if
    the server lists it in ``pg_available_extensions``. Does nothing on
    non-PostgreSQL databases.
    """
    if settings.is_postgresql_database:
        with engine.begin() as conn:
            for statement in (
                "CREATE EXTENSION IF NOT EXISTS postgis",
                "CREATE EXTENSION IF NOT EXISTS pg_trgm",
            ):
                conn.execute(text(statement))
            pgrouting_available = conn.execute(
                text("SELECT EXISTS (SELECT 1 FROM pg_available_extensions WHERE name = 'pgrouting')")
            ).scalar()
            if pgrouting_available:
                conn.execute(text("CREATE EXTENSION IF NOT EXISTS pgrouting"))
def _ensure_runtime_columns() -> None:
    """Add columns that newer code expects but older databases may lack.

    ``create_all`` only creates missing tables; it never alters existing ones,
    so columns added to the models later are patched in here. PostgreSQL is
    delegated to ``_ensure_postgresql_runtime_columns``. SQLite is handled
    inline. Any other backend is left untouched.
    """
    if settings.is_postgresql_database:
        _ensure_postgresql_runtime_columns()
        return
    if not settings.is_sqlite_database:
        return

    route_scope_type = "VARCHAR(64)"
    # (table, {column: SQLite column definition}) in execution order; dicts keep
    # insertion order, so ALTERs run in exactly the order listed.
    sqlite_runtime_columns = (
        ("gtfs_stop_times", {"arrival_seconds": "INTEGER", "departure_seconds": "INTEGER"}),
        (
            "sources",
            {
                "catalog_entry_id": "INTEGER",
                "priority": "VARCHAR(16)",
                "mode_scope": "TEXT",
                "source_basis": "TEXT",
                "notes": "TEXT",
            },
        ),
        (
            "jobs",
            {
                # SQLite only allows ADD COLUMN ... NOT NULL with a non-NULL default.
                "priority": "INTEGER NOT NULL DEFAULT 0",
                "requested_action": "VARCHAR(32)",
                "lease_owner": "VARCHAR(255)",
                "lease_expires_at": "DATETIME",
                "paused_at": "DATETIME",
                "dismissed_at": "DATETIME",
            },
        ),
        ("gtfs_routes", {"route_scope": route_scope_type}),
        ("route_patterns", {"route_scope": route_scope_type}),
        ("osm_features", {"route_scope": route_scope_type}),
        ("osm_addresses", {"geometry_geojson": "TEXT"}),
    )
    with engine.begin() as conn:
        for table_name, wanted_columns in sqlite_runtime_columns:
            # PRAGMA table_info rows are (cid, name, type, ...); index 1 is the name.
            present = {row[1] for row in conn.execute(text(f"PRAGMA table_info({table_name})")).all()}
            for column_name, column_type in wanted_columns.items():
                if column_name in present:
                    continue
                conn.execute(text(f"ALTER TABLE {table_name} ADD COLUMN {column_name} {column_type}"))
def _ensure_postgresql_runtime_columns() -> None:
    """Add PostGIS geometry columns and related fixes on PostgreSQL.

    Each ``(table, column)`` pair missing from ``information_schema`` gets its
    ALTER statement executed. Afterwards, ``osm_addresses.country`` is widened
    to TEXT if it exists with any other type. The geometry type comes from
    PostGIS, which ``init_db`` installs before this runs.
    """
    column_statements = (
        ("osm_features", "geom", "ALTER TABLE osm_features ADD COLUMN geom geometry(Geometry, 4326)"),
        ("gtfs_routes", "geom", "ALTER TABLE gtfs_routes ADD COLUMN geom geometry(Geometry, 4326)"),
        ("gtfs_shapes", "geom", "ALTER TABLE gtfs_shapes ADD COLUMN geom geometry(Geometry, 4326)"),
        ("route_patterns", "geom", "ALTER TABLE route_patterns ADD COLUMN geom geometry(Geometry, 4326)"),
        ("osm_addresses", "geometry_geojson", "ALTER TABLE osm_addresses ADD COLUMN geometry_geojson TEXT"),
        ("osm_addresses", "geom", "ALTER TABLE osm_addresses ADD COLUMN geom geometry(Point, 4326)"),
        ("osm_addresses", "area_geom", "ALTER TABLE osm_addresses ADD COLUMN area_geom geometry(Geometry, 4326)"),
        ("gtfs_stops", "geom", "ALTER TABLE gtfs_stops ADD COLUMN geom geometry(Point, 4326)"),
        ("canonical_stops", "geom", "ALTER TABLE canonical_stops ADD COLUMN geom geometry(Point, 4326)"),
        ("routing_nodes", "geom", "ALTER TABLE routing_nodes ADD COLUMN geom geometry(Point, 4326)"),
        ("routing_edges", "geom", "ALTER TABLE routing_edges ADD COLUMN geom geometry(LineString, 4326)"),
    )
    with engine.begin() as conn:
        existing = _postgresql_columns(conn)
        # The snapshot is taken once; none of these ALTERs touch the checked keys.
        missing = [sql for table, column, sql in column_statements if (table, column) not in existing]
        for sql in missing:
            conn.execute(text(sql))
        country = existing.get(("osm_addresses", "country"))
        if country is None:
            return
        if country["data_type"] != "text":
            conn.execute(text("ALTER TABLE osm_addresses ALTER COLUMN country TYPE TEXT"))
def _ensure_runtime_indexes() -> None:
    """Create secondary indexes that the ORM models do not declare.

    Every statement is ``CREATE INDEX IF NOT EXISTS``, so this is safe to run
    on every start-up. On PostgreSQL the portable list is extended with
    ``_postgresql_index_statements()`` (GIST/GIN/expression indexes), and
    statements whose index name already exists are skipped. Other backends
    execute every statement.
    """
    statements = [
        # Map viewport / bounding-box lookups.
        "CREATE INDEX IF NOT EXISTS ix_osm_features_map_bbox ON osm_features (dataset_id, kind, mode, min_lon, max_lon, min_lat, max_lat)",
        "CREATE INDEX IF NOT EXISTS ix_osm_features_scope_bbox ON osm_features (dataset_id, kind, mode, route_scope, min_lon, max_lon, min_lat, max_lat)",
        "CREATE INDEX IF NOT EXISTS ix_gtfs_routes_map_bbox ON gtfs_routes (dataset_id, mode, min_lon, max_lon, min_lat, max_lat)",
        "CREATE INDEX IF NOT EXISTS ix_gtfs_routes_scope_bbox ON gtfs_routes (dataset_id, mode, route_scope, min_lon, max_lon, min_lat, max_lat)",
        "CREATE INDEX IF NOT EXISTS ix_gtfs_stops_map_point ON gtfs_stops (dataset_id, lon, lat)",
        # GTFS stop_times.
        # NOTE(review): *_stop_depart_trip and *_stop_arrive_trip are column
        # prefixes of *_stop and *_stop_arrival; they may be redundant. Confirm
        # with query plans before dropping.
        "CREATE INDEX IF NOT EXISTS ix_gtfs_stop_times_stop ON gtfs_stop_times (dataset_id, stop_id, departure_seconds, trip_id, stop_sequence)",
        "CREATE INDEX IF NOT EXISTS ix_gtfs_stop_times_stop_depart_trip ON gtfs_stop_times (dataset_id, stop_id, departure_seconds, trip_id)",
        "CREATE INDEX IF NOT EXISTS ix_gtfs_stop_times_stop_arrival ON gtfs_stop_times (dataset_id, stop_id, arrival_seconds, trip_id, stop_sequence)",
        "CREATE INDEX IF NOT EXISTS ix_gtfs_stop_times_stop_arrive_trip ON gtfs_stop_times (dataset_id, stop_id, arrival_seconds, trip_id)",
        "CREATE INDEX IF NOT EXISTS ix_gtfs_stop_times_trip_seq ON gtfs_stop_times (dataset_id, trip_id, stop_sequence)",
        "CREATE INDEX IF NOT EXISTS ix_gtfs_stop_times_trip_stop_seq ON gtfs_stop_times (dataset_id, trip_id, stop_id, stop_sequence)",
        # GTFS trips / routes / shapes / calendars.
        # NOTE(review): ix_gtfs_trips_dataset_route is a column prefix of
        # ix_gtfs_trips_dataset_route_service; it may be redundant.
        "CREATE INDEX IF NOT EXISTS ix_gtfs_trips_dataset_trip ON gtfs_trips (dataset_id, trip_id)",
        "CREATE INDEX IF NOT EXISTS ix_gtfs_trips_dataset_route ON gtfs_trips (dataset_id, route_id)",
        "CREATE INDEX IF NOT EXISTS ix_gtfs_trips_dataset_service ON gtfs_trips (dataset_id, service_id, trip_id)",
        "CREATE INDEX IF NOT EXISTS ix_gtfs_trips_dataset_route_service ON gtfs_trips (dataset_id, route_id, service_id)",
        "CREATE INDEX IF NOT EXISTS ix_gtfs_routes_dataset_route ON gtfs_routes (dataset_id, route_id)",
        "CREATE INDEX IF NOT EXISTS ix_gtfs_shapes_dataset_shape ON gtfs_shapes (dataset_id, shape_id)",
        "CREATE INDEX IF NOT EXISTS ix_gtfs_calendars_dataset_service_dates ON gtfs_calendars (dataset_id, service_id, start_date, end_date)",
        "CREATE INDEX IF NOT EXISTS ix_gtfs_calendar_dates_dataset_date ON gtfs_calendar_dates (dataset_id, date, service_id, exception_type)",
        # Canonical stops and route patterns.
        "CREATE INDEX IF NOT EXISTS ix_canonical_stop_links_object ON canonical_stop_links (object_type, dataset_id, object_id)",
        "CREATE INDEX IF NOT EXISTS ix_canonical_stop_links_external ON canonical_stop_links (object_type, dataset_id, external_id)",
        "CREATE INDEX IF NOT EXISTS ix_route_patterns_ref_mode ON route_patterns (route_ref, mode, source_kind)",
        "CREATE INDEX IF NOT EXISTS ix_route_patterns_bbox ON route_patterns (mode, min_lon, max_lon, min_lat, max_lat)",
        "CREATE INDEX IF NOT EXISTS ix_route_patterns_scope_bbox ON route_patterns (mode, route_scope, source_kind, min_lon, max_lon, min_lat, max_lat)",
        "CREATE INDEX IF NOT EXISTS ix_gtfs_route_pattern_links_trip_shape ON gtfs_route_pattern_links (dataset_id, route_id, shape_id)",
        "CREATE INDEX IF NOT EXISTS ix_gtfs_trip_route_pattern_links_trip ON gtfs_trip_route_pattern_links (dataset_id, trip_id)",
        "CREATE INDEX IF NOT EXISTS ix_gtfs_trip_route_pattern_links_pattern ON gtfs_trip_route_pattern_links (route_pattern_id, dataset_id, trip_id)",
        # Sources, catalog, update checks and OSM diffs.
        "CREATE INDEX IF NOT EXISTS ix_sources_catalog_entry ON sources (catalog_entry_id)",
        "CREATE INDEX IF NOT EXISTS ix_sources_priority_country_kind ON sources (priority, country, kind)",
        "CREATE INDEX IF NOT EXISTS ix_source_catalog_country_priority ON source_catalog_entries (country_code, priority, status)",
        "CREATE INDEX IF NOT EXISTS ix_source_catalog_name ON source_catalog_entries (source_name)",
        "CREATE INDEX IF NOT EXISTS ix_source_update_checks_source_checked ON source_update_checks (source_id, checked_at)",
        "CREATE INDEX IF NOT EXISTS ix_source_update_checks_available ON source_update_checks (source_id, update_available, checked_at)",
        "CREATE INDEX IF NOT EXISTS ix_osm_diff_states_source_sequence ON osm_diff_states (source_id, sequence_number)",
        "CREATE INDEX IF NOT EXISTS ix_osm_diff_states_source_status ON osm_diff_states (source_id, status, updated_at)",
        # Job queue and pipeline bookkeeping.
        "CREATE INDEX IF NOT EXISTS ix_jobs_status_created ON jobs (status, created_at)",
        "CREATE INDEX IF NOT EXISTS ix_jobs_kind_status ON jobs (kind, status)",
        "CREATE INDEX IF NOT EXISTS ix_jobs_queue_claim ON jobs (status, priority, created_at, id)",
        "CREATE INDEX IF NOT EXISTS ix_jobs_lease ON jobs (status, lease_expires_at)",
        "CREATE INDEX IF NOT EXISTS ix_jobs_dismissed_status ON jobs (dismissed_at, status, created_at)",
        "CREATE INDEX IF NOT EXISTS ix_job_events_job_created ON job_events (job_id, created_at, id)",
        "CREATE INDEX IF NOT EXISTS ix_pipeline_runs_stage_dataset_hash ON pipeline_runs (stage, dataset_id, dependency_hash, status, started_at)",
        "CREATE INDEX IF NOT EXISTS ix_pipeline_runs_stage_source_hash ON pipeline_runs (stage, source_id, dependency_hash, status, started_at)",
        "CREATE INDEX IF NOT EXISTS ix_pipeline_runs_job ON pipeline_runs (job_id, stage, status)",
        "CREATE INDEX IF NOT EXISTS ix_match_rules_type_active ON match_rules (rule_type, active)",
        # Journey search and saved itineraries.
        "CREATE INDEX IF NOT EXISTS ix_journey_search_cache_type_expires ON journey_search_cache (cache_type, expires_at)",
        "CREATE INDEX IF NOT EXISTS ix_travel_requests_created ON travel_requests (created_at)",
        "CREATE INDEX IF NOT EXISTS ix_itineraries_request_saved ON itineraries (request_id, saved, created_at)",
        "CREATE INDEX IF NOT EXISTS ix_itinerary_legs_itinerary_sequence ON itinerary_legs (itinerary_id, sequence)",
        # Street routing graph (partial indexes only cover edges with a cost).
        "CREATE INDEX IF NOT EXISTS ix_routing_nodes_dataset_osm ON routing_nodes (dataset_id, osm_node_id)",
        "CREATE INDEX IF NOT EXISTS ix_routing_edges_dataset_source ON routing_edges (dataset_id, source_osm_node_id)",
        "CREATE INDEX IF NOT EXISTS ix_routing_edges_dataset_target ON routing_edges (dataset_id, target_osm_node_id)",
        "CREATE INDEX IF NOT EXISTS ix_routing_edges_dataset_drive ON routing_edges (dataset_id, source_osm_node_id) WHERE drive_cost_s IS NOT NULL",
        "CREATE INDEX IF NOT EXISTS ix_routing_edges_dataset_walk ON routing_edges (dataset_id, source_osm_node_id) WHERE walk_cost_s IS NOT NULL",
        "CREATE INDEX IF NOT EXISTS ix_routing_edges_dataset_reverse_drive ON routing_edges (dataset_id, target_osm_node_id) WHERE reverse_drive_cost_s IS NOT NULL",
        "CREATE INDEX IF NOT EXISTS ix_routing_edges_dataset_reverse_walk ON routing_edges (dataset_id, target_osm_node_id) WHERE reverse_walk_cost_s IS NOT NULL",
        "CREATE INDEX IF NOT EXISTS ix_routing_edges_bbox ON routing_edges (dataset_id, min_lon, max_lon, min_lat, max_lat)",
        # Address lookup.
        "CREATE INDEX IF NOT EXISTS ix_osm_addresses_dataset_city_street ON osm_addresses (dataset_id, city, street, housenumber)",
        "CREATE INDEX IF NOT EXISTS ix_osm_addresses_dataset_postcode ON osm_addresses (dataset_id, postcode)",
        "CREATE INDEX IF NOT EXISTS ix_osm_addresses_bbox ON osm_addresses (dataset_id, min_lon, max_lon, min_lat, max_lat)",
    ]
    with engine.begin() as conn:
        if settings.is_sqlite_database:
            # The same pragmas the SQLite connect listener applies; re-issued on this connection.
            conn.execute(text("PRAGMA journal_mode=WAL"))
            conn.execute(text(f"PRAGMA busy_timeout={int(settings.sqlite_busy_timeout_ms)}"))
        if not settings.is_postgresql_database:
            for statement in statements:
                conn.execute(text(statement))
        else:
            postgresql_statements = [*statements, *_postgresql_index_statements()]
            _execute_missing_postgresql_indexes(conn, postgresql_statements)
def _postgresql_columns(conn: Connection) -> dict[tuple[str, str], dict[str, str]]:
    """Map ``(table_name, column_name)`` to its ``data_type`` and ``udt_name``.

    Only schemas on the current search path are inspected. If the same table
    name exists in several of them, the row read last wins.
    """
    query = text(
        "\n"
        "SELECT table_name, column_name, data_type, udt_name\n"
        "FROM information_schema.columns\n"
        "WHERE table_schema = ANY (current_schemas(false))\n"
    )
    columns: dict[tuple[str, str], dict[str, str]] = {}
    for row in conn.execute(query).mappings():
        key = (str(row["table_name"]), str(row["column_name"]))
        columns[key] = {
            "data_type": str(row["data_type"]),
            "udt_name": str(row["udt_name"]),
        }
    return columns
def _execute_missing_postgresql_indexes(conn: Connection, statements: list[str]) -> None:
    """Execute each CREATE INDEX statement unless its index already exists.

    Existing index names are read once from ``pg_indexes``. Statements whose
    name cannot be parsed are always executed; every statement visible in this
    module uses IF NOT EXISTS. Newly created names are remembered, so a
    duplicate later in *statements* is skipped.
    """
    known_indexes = _postgresql_index_names(conn)
    for statement in statements:
        name = _index_name_from_create_statement(statement)
        if not name:
            conn.execute(text(statement))
            continue
        if name in known_indexes:
            continue
        conn.execute(text(statement))
        known_indexes.add(name)
def _postgresql_index_names(conn: Connection) -> set[str]:
    """Return the names of all indexes in schemas on the current search path."""
    query = text(
        "\n"
        "SELECT indexname\n"
        "FROM pg_indexes\n"
        "WHERE schemaname = ANY (current_schemas(false))\n"
    )
    names: set[str] = set()
    for row in conn.execute(query):
        names.add(str(row[0]))
    return names
def _index_name_from_create_statement(statement: str) -> str | None:
    """Return the index name declared by a CREATE INDEX statement, or None.

    None means the module-level name pattern did not match *statement*.
    """
    found = _CREATE_INDEX_NAME_RE.search(statement)
    if found is None:
        return None
    return found.group(1)
def _postgresql_index_statements() -> list[str]:
    """Return PostgreSQL-only index DDL: PostGIS GIST, expression and pg_trgm GIN.

    A new list is built on every call, so callers may concatenate or mutate it.
    The statement order is fixed (GIST, then B-tree/expression, then search).
    """
    gist_indexes = [
        "CREATE INDEX IF NOT EXISTS ix_osm_features_geom_gist ON osm_features USING GIST (geom)",
        "CREATE INDEX IF NOT EXISTS ix_osm_features_stop_geom_gist ON osm_features USING GIST (geom) WHERE kind IN ('stop', 'station', 'terminal')",
        "CREATE INDEX IF NOT EXISTS ix_osm_features_route_geom_gist ON osm_features USING GIST (geom) WHERE kind = 'route'",
        "CREATE INDEX IF NOT EXISTS ix_gtfs_stops_geom_gist ON gtfs_stops USING GIST (geom)",
        "CREATE INDEX IF NOT EXISTS ix_canonical_stops_geom_gist ON canonical_stops USING GIST (geom)",
        "CREATE INDEX IF NOT EXISTS ix_gtfs_routes_geom_gist ON gtfs_routes USING GIST (geom)",
        "CREATE INDEX IF NOT EXISTS ix_gtfs_shapes_geom_gist ON gtfs_shapes USING GIST (geom)",
        "CREATE INDEX IF NOT EXISTS ix_route_patterns_geom_gist ON route_patterns USING GIST (geom)",
        "CREATE INDEX IF NOT EXISTS ix_osm_addresses_geom_gist ON osm_addresses USING GIST (geom)",
        "CREATE INDEX IF NOT EXISTS ix_osm_addresses_area_geom_gist ON osm_addresses USING GIST (area_geom)",
        "CREATE INDEX IF NOT EXISTS ix_routing_nodes_geom_gist ON routing_nodes USING GIST (geom)",
        # Built-in geometric box over the bbox columns (no PostGIS geometry needed).
        "CREATE INDEX IF NOT EXISTS ix_routing_edges_bbox_box_gist ON routing_edges USING GIST (box(point(max_lon, max_lat), point(min_lon, min_lat)))",
    ]
    btree_indexes = [
        "CREATE INDEX IF NOT EXISTS ix_gtfs_trips_dataset_route_shape_expr ON gtfs_trips (dataset_id, route_id, (COALESCE(shape_id, '__route__')))",
        # NOTE(review): (dataset_id, stop_id) is a prefix of the portable
        # ix_gtfs_stop_times_stop index; it may be redundant.
        "CREATE INDEX IF NOT EXISTS ix_gtfs_stop_times_dataset_stop ON gtfs_stop_times (dataset_id, stop_id)",
        "CREATE INDEX IF NOT EXISTS ix_canonical_stop_links_gtfs_external ON canonical_stop_links (dataset_id, external_id, canonical_stop_id) WHERE object_type = 'gtfs_stop'",
        "CREATE INDEX IF NOT EXISTS ix_gtfs_stops_dataset_parent ON gtfs_stops (dataset_id, parent_station)",
        "CREATE INDEX IF NOT EXISTS ix_gtfs_stops_dataset_stop_prefix ON gtfs_stops (dataset_id, (split_part(stop_id, '::', 1)))",
    ]
    search_indexes = [
        "CREATE INDEX IF NOT EXISTS ix_osm_features_name_trgm ON osm_features USING GIN (LOWER(COALESCE(name, '')) gin_trgm_ops)",
        "CREATE INDEX IF NOT EXISTS ix_osm_features_ref_trgm ON osm_features USING GIN (LOWER(COALESCE(ref, '')) gin_trgm_ops)",
        "CREATE INDEX IF NOT EXISTS ix_osm_features_tags_trgm ON osm_features USING GIN (LOWER(COALESCE(tags_json, '')) gin_trgm_ops)",
        "CREATE INDEX IF NOT EXISTS ix_osm_addresses_search_trgm ON osm_addresses USING GIN (LOWER(COALESCE(search_text, '')) gin_trgm_ops)",
        "CREATE INDEX IF NOT EXISTS ix_osm_addresses_display_trgm ON osm_addresses USING GIN (LOWER(COALESCE(display_name, '')) gin_trgm_ops)",
        # B-tree and trigram indexes on the same normalized street key
        # (street, else place; lower-cased; 'ß' folded to 'ss').
        "CREATE INDEX IF NOT EXISTS ix_osm_addresses_street_key_house ON osm_addresses (dataset_id, REPLACE(LOWER(COALESCE(NULLIF(street, ''), NULLIF(place, ''), '')), 'ß', 'ss'), housenumber)",
        "CREATE INDEX IF NOT EXISTS ix_osm_addresses_street_key_trgm ON osm_addresses USING GIN (REPLACE(LOWER(COALESCE(NULLIF(street, ''), NULLIF(place, ''), '')), 'ß', 'ss') gin_trgm_ops)",
        "CREATE INDEX IF NOT EXISTS ix_gtfs_stops_name_trgm ON gtfs_stops USING GIN (name gin_trgm_ops)",
        "CREATE INDEX IF NOT EXISTS ix_gtfs_stops_stop_id_trgm ON gtfs_stops USING GIN (stop_id gin_trgm_ops)",
        "CREATE INDEX IF NOT EXISTS ix_route_patterns_ref_trgm ON route_patterns USING GIN (LOWER(COALESCE(route_ref, '')) gin_trgm_ops)",
        "CREATE INDEX IF NOT EXISTS ix_route_patterns_name_trgm ON route_patterns USING GIN (LOWER(COALESCE(route_name, '')) gin_trgm_ops)",
    ]
    return gist_indexes + btree_indexes + search_indexes
def get_db() -> Iterator[Session]:
    """Yield one session and close it when the generator finishes.

    Nothing is committed here; callers commit explicitly. Closing the session
    on exit returns its connection to the pool.
    """
    with SessionLocal() as db:
        yield db
@contextmanager
def session_scope() -> Iterator[Session]:
    """Provide a transactional session for a ``with`` block.

    On normal exit the session is committed. If the block or the commit
    raises an ``Exception``, the session is rolled back and the error
    re-raised. The session is closed in every case.
    """
    session = SessionLocal()
    try:
        try:
            yield session
            session.commit()
        except Exception:
            session.rollback()
            raise
    finally:
        session.close()