"""Database engine, session factories, and runtime schema maintenance.

Importing this module creates the SQLAlchemy ``engine`` and the
``SessionLocal`` factory from ``app.config.settings``. ``init_db`` and
``reset_db`` create the ORM tables and then add the columns and indexes
listed in this module that are not yet present.
"""

from __future__ import annotations

from contextlib import contextmanager
from pathlib import Path
import re
from typing import Iterator

from sqlalchemy import create_engine
from sqlalchemy import event
from sqlalchemy import text
from sqlalchemy.engine import Connection
from sqlalchemy.orm import DeclarativeBase, Session, sessionmaker

from app.config import settings


class Base(DeclarativeBase):
    """Declarative base whose metadata ``init_db``/``reset_db`` create and drop."""


def _connect_args() -> dict:
    """Return the DBAPI ``connect_args`` for the configured database.

    SQLite gets ``check_same_thread=False`` and the configured timeout;
    every other backend gets an empty dict.
    """
    if not settings.is_sqlite_database:
        return {}
    return {"check_same_thread": False, "timeout": settings.sqlite_timeout_seconds}


def _ensure_sqlite_parent() -> None:
    """Create the parent directory of the SQLite database file if needed.

    Does nothing for non-SQLite databases, for an empty path, or for
    ``:memory:``.
    """
    if settings.is_sqlite_database:
        # sqlite:///./data/workbench.sqlite -> ./data/workbench.sqlite
        db_path = settings.normalized_database_url.replace("sqlite:///", "", 1)
        if db_path not in ("", ":memory:"):
            Path(db_path).parent.mkdir(parents=True, exist_ok=True)


# The directory must exist before the engine can open the SQLite file.
_ensure_sqlite_parent()

engine = create_engine(
    settings.normalized_database_url,
    connect_args=_connect_args(),
    pool_pre_ping=True,
    future=True,
)
SessionLocal = sessionmaker(
    bind=engine,
    autoflush=False,
    autocommit=False,
    expire_on_commit=False,
    future=True,
)

# Captures the index name from a ``CREATE [UNIQUE] INDEX [CONCURRENTLY]
# [IF NOT EXISTS] <name>`` statement.
_CREATE_INDEX_NAME_RE = re.compile(
    r"CREATE\s+(?:UNIQUE\s+)?INDEX\s+(?:CONCURRENTLY\s+)?(?:IF\s+NOT\s+EXISTS\s+)?([A-Za-z_][A-Za-z0-9_]*)",
    re.IGNORECASE,
)

# SQLite: (table, {column: type}) pairs applied in this order by
# ``_ensure_runtime_columns``.
_SQLITE_RUNTIME_COLUMNS: tuple[tuple[str, dict[str, str]], ...] = (
    (
        "gtfs_stop_times",
        {
            "arrival_seconds": "INTEGER",
            "departure_seconds": "INTEGER",
        },
    ),
    (
        "sources",
        {
            "catalog_entry_id": "INTEGER",
            "priority": "VARCHAR(16)",
            "mode_scope": "TEXT",
            "source_basis": "TEXT",
            "notes": "TEXT",
        },
    ),
    (
        "jobs",
        {
            "priority": "INTEGER NOT NULL DEFAULT 0",
            "requested_action": "VARCHAR(32)",
            "lease_owner": "VARCHAR(255)",
            "lease_expires_at": "DATETIME",
            "paused_at": "DATETIME",
            "dismissed_at": "DATETIME",
        },
    ),
    ("gtfs_routes", {"route_scope": "VARCHAR(64)"}),
    ("route_patterns", {"route_scope": "VARCHAR(64)"}),
    ("osm_features", {"route_scope": "VARCHAR(64)"}),
    ("osm_addresses", {"geometry_geojson": "TEXT"}),
)

# PostgreSQL: (table, column, type) triples applied in this order by
# ``_ensure_postgresql_runtime_columns``.
_POSTGRESQL_RUNTIME_COLUMNS: tuple[tuple[str, str, str], ...] = (
    ("osm_features", "geom", "geometry(Geometry, 4326)"),
    ("gtfs_routes", "geom", "geometry(Geometry, 4326)"),
    ("gtfs_shapes", "geom", "geometry(Geometry, 4326)"),
    ("route_patterns", "geom", "geometry(Geometry, 4326)"),
    ("osm_addresses", "geometry_geojson", "TEXT"),
    ("osm_addresses", "geom", "geometry(Point, 4326)"),
    ("osm_addresses", "area_geom", "geometry(Geometry, 4326)"),
    ("gtfs_stops", "geom", "geometry(Point, 4326)"),
    ("canonical_stops", "geom", "geometry(Point, 4326)"),
    ("routing_nodes", "geom", "geometry(Point, 4326)"),
    ("routing_edges", "geom", "geometry(LineString, 4326)"),
)

# Index statements executed on every backend by ``_ensure_runtime_indexes``.
_RUNTIME_INDEX_STATEMENTS: tuple[str, ...] = (
    "CREATE INDEX IF NOT EXISTS ix_osm_features_map_bbox ON osm_features (dataset_id, kind, mode, min_lon, max_lon, min_lat, max_lat)",
    "CREATE INDEX IF NOT EXISTS ix_osm_features_scope_bbox ON osm_features (dataset_id, kind, mode, route_scope, min_lon, max_lon, min_lat, max_lat)",
    "CREATE INDEX IF NOT EXISTS ix_gtfs_routes_map_bbox ON gtfs_routes (dataset_id, mode, min_lon, max_lon, min_lat, max_lat)",
    "CREATE INDEX IF NOT EXISTS ix_gtfs_routes_scope_bbox ON gtfs_routes (dataset_id, mode, route_scope, min_lon, max_lon, min_lat, max_lat)",
    "CREATE INDEX IF NOT EXISTS ix_gtfs_stops_map_point ON gtfs_stops (dataset_id, lon, lat)",
    "CREATE INDEX IF NOT EXISTS ix_gtfs_stop_times_stop ON gtfs_stop_times (dataset_id, stop_id, departure_seconds, trip_id, stop_sequence)",
    "CREATE INDEX IF NOT EXISTS ix_gtfs_stop_times_stop_depart_trip ON gtfs_stop_times (dataset_id, stop_id, departure_seconds, trip_id)",
    "CREATE INDEX IF NOT EXISTS ix_gtfs_stop_times_stop_arrival ON gtfs_stop_times (dataset_id, stop_id, arrival_seconds, trip_id, stop_sequence)",
    "CREATE INDEX IF NOT EXISTS ix_gtfs_stop_times_stop_arrive_trip ON gtfs_stop_times (dataset_id, stop_id, arrival_seconds, trip_id)",
    "CREATE INDEX IF NOT EXISTS ix_gtfs_stop_times_trip_seq ON gtfs_stop_times (dataset_id, trip_id, stop_sequence)",
    "CREATE INDEX IF NOT EXISTS ix_gtfs_stop_times_trip_stop_seq ON gtfs_stop_times (dataset_id, trip_id, stop_id, stop_sequence)",
    "CREATE INDEX IF NOT EXISTS ix_gtfs_trips_dataset_trip ON gtfs_trips (dataset_id, trip_id)",
    "CREATE INDEX IF NOT EXISTS ix_gtfs_trips_dataset_route ON gtfs_trips (dataset_id, route_id)",
    "CREATE INDEX IF NOT EXISTS ix_gtfs_trips_dataset_service ON gtfs_trips (dataset_id, service_id, trip_id)",
    "CREATE INDEX IF NOT EXISTS ix_gtfs_trips_dataset_route_service ON gtfs_trips (dataset_id, route_id, service_id)",
    "CREATE INDEX IF NOT EXISTS ix_gtfs_routes_dataset_route ON gtfs_routes (dataset_id, route_id)",
    "CREATE INDEX IF NOT EXISTS ix_gtfs_shapes_dataset_shape ON gtfs_shapes (dataset_id, shape_id)",
    "CREATE INDEX IF NOT EXISTS ix_gtfs_calendars_dataset_service_dates ON gtfs_calendars (dataset_id, service_id, start_date, end_date)",
    "CREATE INDEX IF NOT EXISTS ix_gtfs_calendar_dates_dataset_date ON gtfs_calendar_dates (dataset_id, date, service_id, exception_type)",
    "CREATE INDEX IF NOT EXISTS ix_canonical_stop_links_object ON canonical_stop_links (object_type, dataset_id, object_id)",
    "CREATE INDEX IF NOT EXISTS ix_canonical_stop_links_external ON canonical_stop_links (object_type, dataset_id, external_id)",
    "CREATE INDEX IF NOT EXISTS ix_route_patterns_ref_mode ON route_patterns (route_ref, mode, source_kind)",
    "CREATE INDEX IF NOT EXISTS ix_route_patterns_bbox ON route_patterns (mode, min_lon, max_lon, min_lat, max_lat)",
    "CREATE INDEX IF NOT EXISTS ix_route_patterns_scope_bbox ON route_patterns (mode, route_scope, source_kind, min_lon, max_lon, min_lat, max_lat)",
    "CREATE INDEX IF NOT EXISTS ix_gtfs_route_pattern_links_trip_shape ON gtfs_route_pattern_links (dataset_id, route_id, shape_id)",
    "CREATE INDEX IF NOT EXISTS ix_gtfs_trip_route_pattern_links_trip ON gtfs_trip_route_pattern_links (dataset_id, trip_id)",
    "CREATE INDEX IF NOT EXISTS ix_gtfs_trip_route_pattern_links_pattern ON gtfs_trip_route_pattern_links (route_pattern_id, dataset_id, trip_id)",
    "CREATE INDEX IF NOT EXISTS ix_sources_catalog_entry ON sources (catalog_entry_id)",
    "CREATE INDEX IF NOT EXISTS ix_sources_priority_country_kind ON sources (priority, country, kind)",
    "CREATE INDEX IF NOT EXISTS ix_source_catalog_country_priority ON source_catalog_entries (country_code, priority, status)",
    "CREATE INDEX IF NOT EXISTS ix_source_catalog_name ON source_catalog_entries (source_name)",
    "CREATE INDEX IF NOT EXISTS ix_source_update_checks_source_checked ON source_update_checks (source_id, checked_at)",
    "CREATE INDEX IF NOT EXISTS ix_source_update_checks_available ON source_update_checks (source_id, update_available, checked_at)",
    "CREATE INDEX IF NOT EXISTS ix_osm_diff_states_source_sequence ON osm_diff_states (source_id, sequence_number)",
    "CREATE INDEX IF NOT EXISTS ix_osm_diff_states_source_status ON osm_diff_states (source_id, status, updated_at)",
    "CREATE INDEX IF NOT EXISTS ix_jobs_status_created ON jobs (status, created_at)",
    "CREATE INDEX IF NOT EXISTS ix_jobs_kind_status ON jobs (kind, status)",
    "CREATE INDEX IF NOT EXISTS ix_jobs_queue_claim ON jobs (status, priority, created_at, id)",
    "CREATE INDEX IF NOT EXISTS ix_jobs_lease ON jobs (status, lease_expires_at)",
    "CREATE INDEX IF NOT EXISTS ix_jobs_dismissed_status ON jobs (dismissed_at, status, created_at)",
    "CREATE INDEX IF NOT EXISTS ix_job_events_job_created ON job_events (job_id, created_at, id)",
    "CREATE INDEX IF NOT EXISTS ix_pipeline_runs_stage_dataset_hash ON pipeline_runs (stage, dataset_id, dependency_hash, status, started_at)",
    "CREATE INDEX IF NOT EXISTS ix_pipeline_runs_stage_source_hash ON pipeline_runs (stage, source_id, dependency_hash, status, started_at)",
    "CREATE INDEX IF NOT EXISTS ix_pipeline_runs_job ON pipeline_runs (job_id, stage, status)",
    "CREATE INDEX IF NOT EXISTS ix_match_rules_type_active ON match_rules (rule_type, active)",
    "CREATE INDEX IF NOT EXISTS ix_journey_search_cache_type_expires ON journey_search_cache (cache_type, expires_at)",
    "CREATE INDEX IF NOT EXISTS ix_travel_requests_created ON travel_requests (created_at)",
    "CREATE INDEX IF NOT EXISTS ix_itineraries_request_saved ON itineraries (request_id, saved, created_at)",
    "CREATE INDEX IF NOT EXISTS ix_itinerary_legs_itinerary_sequence ON itinerary_legs (itinerary_id, sequence)",
    "CREATE INDEX IF NOT EXISTS ix_routing_nodes_dataset_osm ON routing_nodes (dataset_id, osm_node_id)",
    "CREATE INDEX IF NOT EXISTS ix_routing_edges_dataset_source ON routing_edges (dataset_id, source_osm_node_id)",
    "CREATE INDEX IF NOT EXISTS ix_routing_edges_dataset_target ON routing_edges (dataset_id, target_osm_node_id)",
    "CREATE INDEX IF NOT EXISTS ix_routing_edges_dataset_drive ON routing_edges (dataset_id, source_osm_node_id) WHERE drive_cost_s IS NOT NULL",
    "CREATE INDEX IF NOT EXISTS ix_routing_edges_dataset_walk ON routing_edges (dataset_id, source_osm_node_id) WHERE walk_cost_s IS NOT NULL",
    "CREATE INDEX IF NOT EXISTS ix_routing_edges_dataset_reverse_drive ON routing_edges (dataset_id, target_osm_node_id) WHERE reverse_drive_cost_s IS NOT NULL",
    "CREATE INDEX IF NOT EXISTS ix_routing_edges_dataset_reverse_walk ON routing_edges (dataset_id, target_osm_node_id) WHERE reverse_walk_cost_s IS NOT NULL",
    "CREATE INDEX IF NOT EXISTS ix_routing_edges_bbox ON routing_edges (dataset_id, min_lon, max_lon, min_lat, max_lat)",
    "CREATE INDEX IF NOT EXISTS ix_osm_addresses_dataset_city_street ON osm_addresses (dataset_id, city, street, housenumber)",
    "CREATE INDEX IF NOT EXISTS ix_osm_addresses_dataset_postcode ON osm_addresses (dataset_id, postcode)",
    "CREATE INDEX IF NOT EXISTS ix_osm_addresses_bbox ON osm_addresses (dataset_id, min_lon, max_lon, min_lat, max_lat)",
)

# Index statements executed only when the database is PostgreSQL.
_POSTGRESQL_INDEX_STATEMENTS: tuple[str, ...] = (
    "CREATE INDEX IF NOT EXISTS ix_osm_features_geom_gist ON osm_features USING GIST (geom)",
    "CREATE INDEX IF NOT EXISTS ix_osm_features_stop_geom_gist ON osm_features USING GIST (geom) WHERE kind IN ('stop', 'station', 'terminal')",
    "CREATE INDEX IF NOT EXISTS ix_osm_features_route_geom_gist ON osm_features USING GIST (geom) WHERE kind = 'route'",
    "CREATE INDEX IF NOT EXISTS ix_gtfs_stops_geom_gist ON gtfs_stops USING GIST (geom)",
    "CREATE INDEX IF NOT EXISTS ix_canonical_stops_geom_gist ON canonical_stops USING GIST (geom)",
    "CREATE INDEX IF NOT EXISTS ix_gtfs_routes_geom_gist ON gtfs_routes USING GIST (geom)",
    "CREATE INDEX IF NOT EXISTS ix_gtfs_shapes_geom_gist ON gtfs_shapes USING GIST (geom)",
    "CREATE INDEX IF NOT EXISTS ix_route_patterns_geom_gist ON route_patterns USING GIST (geom)",
    "CREATE INDEX IF NOT EXISTS ix_osm_addresses_geom_gist ON osm_addresses USING GIST (geom)",
    "CREATE INDEX IF NOT EXISTS ix_osm_addresses_area_geom_gist ON osm_addresses USING GIST (area_geom)",
    "CREATE INDEX IF NOT EXISTS ix_routing_nodes_geom_gist ON routing_nodes USING GIST (geom)",
    "CREATE INDEX IF NOT EXISTS ix_routing_edges_bbox_box_gist ON routing_edges USING GIST (box(point(max_lon, max_lat), point(min_lon, min_lat)))",
    "CREATE INDEX IF NOT EXISTS ix_gtfs_trips_dataset_route_shape_expr ON gtfs_trips (dataset_id, route_id, (COALESCE(shape_id, '__route__')))",
    "CREATE INDEX IF NOT EXISTS ix_gtfs_stop_times_dataset_stop ON gtfs_stop_times (dataset_id, stop_id)",
    "CREATE INDEX IF NOT EXISTS ix_canonical_stop_links_gtfs_external ON canonical_stop_links (dataset_id, external_id, canonical_stop_id) WHERE object_type = 'gtfs_stop'",
    "CREATE INDEX IF NOT EXISTS ix_gtfs_stops_dataset_parent ON gtfs_stops (dataset_id, parent_station)",
    "CREATE INDEX IF NOT EXISTS ix_gtfs_stops_dataset_stop_prefix ON gtfs_stops (dataset_id, (split_part(stop_id, '::', 1)))",
    "CREATE INDEX IF NOT EXISTS ix_osm_features_name_trgm ON osm_features USING GIN (LOWER(COALESCE(name, '')) gin_trgm_ops)",
    "CREATE INDEX IF NOT EXISTS ix_osm_features_ref_trgm ON osm_features USING GIN (LOWER(COALESCE(ref, '')) gin_trgm_ops)",
    "CREATE INDEX IF NOT EXISTS ix_osm_features_tags_trgm ON osm_features USING GIN (LOWER(COALESCE(tags_json, '')) gin_trgm_ops)",
    "CREATE INDEX IF NOT EXISTS ix_osm_addresses_search_trgm ON osm_addresses USING GIN (LOWER(COALESCE(search_text, '')) gin_trgm_ops)",
    "CREATE INDEX IF NOT EXISTS ix_osm_addresses_display_trgm ON osm_addresses USING GIN (LOWER(COALESCE(display_name, '')) gin_trgm_ops)",
    "CREATE INDEX IF NOT EXISTS ix_osm_addresses_street_key_house ON osm_addresses (dataset_id, REPLACE(LOWER(COALESCE(NULLIF(street, ''), NULLIF(place, ''), '')), 'ß', 'ss'), housenumber)",
    "CREATE INDEX IF NOT EXISTS ix_osm_addresses_street_key_trgm ON osm_addresses USING GIN (REPLACE(LOWER(COALESCE(NULLIF(street, ''), NULLIF(place, ''), '')), 'ß', 'ss') gin_trgm_ops)",
    "CREATE INDEX IF NOT EXISTS ix_gtfs_stops_name_trgm ON gtfs_stops USING GIN (name gin_trgm_ops)",
    "CREATE INDEX IF NOT EXISTS ix_gtfs_stops_stop_id_trgm ON gtfs_stops USING GIN (stop_id gin_trgm_ops)",
    "CREATE INDEX IF NOT EXISTS ix_route_patterns_ref_trgm ON route_patterns USING GIN (LOWER(COALESCE(route_ref, '')) gin_trgm_ops)",
    "CREATE INDEX IF NOT EXISTS ix_route_patterns_name_trgm ON route_patterns USING GIN (LOWER(COALESCE(route_name, '')) gin_trgm_ops)",
)


if settings.is_sqlite_database:

    @event.listens_for(engine, "connect")
    def _set_sqlite_pragmas(dbapi_connection, _connection_record) -> None:
        """Apply the SQLite PRAGMAs on every new DBAPI connection."""
        cursor = dbapi_connection.cursor()
        run = cursor.execute
        try:
            run("PRAGMA journal_mode=WAL")
            run(f"PRAGMA busy_timeout={int(settings.sqlite_busy_timeout_ms)}")
            run("PRAGMA synchronous=NORMAL")
            run("PRAGMA foreign_keys=ON")
        finally:
            cursor.close()


def _create_schema() -> None:
    """Create all ORM tables, then add the runtime columns and indexes."""
    Base.metadata.create_all(bind=engine)
    _ensure_runtime_columns()
    _ensure_runtime_indexes()


def init_db() -> None:
    """Create extensions, tables, runtime columns and runtime indexes."""
    # Import models so metadata is populated.
    from app import models  # noqa: F401

    _ensure_database_extensions()
    _create_schema()


def reset_db() -> None:
    """Drop every ORM table and rebuild the schema from scratch."""
    # Import models so metadata is populated.
    from app import models  # noqa: F401

    _ensure_database_extensions()
    Base.metadata.drop_all(bind=engine)
    _create_schema()


def _ensure_database_extensions() -> None:
    """Create the PostgreSQL extensions; a no-op on other backends.

    ``postgis`` and ``pg_trgm`` are always created. ``pgrouting`` is created
    only when ``pg_available_extensions`` lists it.
    """
    if not settings.is_postgresql_database:
        return
    with engine.begin() as conn:
        for extension in ("postgis", "pg_trgm"):
            conn.execute(text(f"CREATE EXTENSION IF NOT EXISTS {extension}"))
        pgrouting_available = conn.execute(
            text("SELECT EXISTS (SELECT 1 FROM pg_available_extensions WHERE name = 'pgrouting')")
        ).scalar()
        if pgrouting_available:
            conn.execute(text("CREATE EXTENSION IF NOT EXISTS pgrouting"))


def _sqlite_add_missing_columns(conn: Connection, table_name: str, wanted: dict[str, str]) -> None:
    """Add each column in ``wanted`` that ``PRAGMA table_info`` does not report."""
    table_info = conn.execute(text(f"PRAGMA table_info({table_name})")).all()
    # Index 1 of each table_info row is read as the column name.
    present = {row[1] for row in table_info}
    for column_name, column_type in wanted.items():
        if column_name not in present:
            conn.execute(text(f"ALTER TABLE {table_name} ADD COLUMN {column_name} {column_type}"))


def _ensure_runtime_columns() -> None:
    """Add the runtime columns that are missing from existing tables.

    PostgreSQL is delegated to ``_ensure_postgresql_runtime_columns``;
    SQLite is handled here; any other backend is left untouched.
    """
    if settings.is_postgresql_database:
        _ensure_postgresql_runtime_columns()
        return
    if not settings.is_sqlite_database:
        return
    with engine.begin() as conn:
        for table_name, wanted in _SQLITE_RUNTIME_COLUMNS:
            _sqlite_add_missing_columns(conn, table_name, wanted)


def _ensure_postgresql_runtime_columns() -> None:
    """Add missing PostgreSQL runtime columns and widen ``osm_addresses.country``.

    The existing columns are read once up front; ``osm_addresses.country`` is
    altered to ``TEXT`` only when it exists with a different ``data_type``.
    """
    with engine.begin() as conn:
        columns = _postgresql_columns(conn)
        for table_name, column_name, column_type in _POSTGRESQL_RUNTIME_COLUMNS:
            if (table_name, column_name) in columns:
                continue
            conn.execute(text(f"ALTER TABLE {table_name} ADD COLUMN {column_name} {column_type}"))
        country_column = columns.get(("osm_addresses", "country"))
        if country_column is not None and country_column["data_type"] != "text":
            conn.execute(text("ALTER TABLE osm_addresses ALTER COLUMN country TYPE TEXT"))


def _ensure_runtime_indexes() -> None:
    """Create the runtime indexes inside a single transaction.

    On PostgreSQL the common and PostgreSQL-only statements are executed via
    ``_execute_missing_postgresql_indexes``; on every other backend each
    common statement is executed directly.
    """
    statements = list(_RUNTIME_INDEX_STATEMENTS)
    with engine.begin() as conn:
        if settings.is_sqlite_database:
            conn.execute(text("PRAGMA journal_mode=WAL"))
            conn.execute(text(f"PRAGMA busy_timeout={int(settings.sqlite_busy_timeout_ms)}"))
        if settings.is_postgresql_database:
            _execute_missing_postgresql_indexes(conn, statements + _postgresql_index_statements())
            return
        for statement in statements:
            conn.execute(text(statement))


def _postgresql_columns(conn: Connection) -> dict[tuple[str, str], dict[str, str]]:
    """Map ``(table_name, column_name)`` to its ``data_type`` and ``udt_name``.

    Only columns in the schemas returned by ``current_schemas(false)`` are
    included.
    """
    query = text(
        """
        SELECT table_name, column_name, data_type, udt_name
        FROM information_schema.columns
        WHERE table_schema = ANY (current_schemas(false))
        """
    )
    found: dict[tuple[str, str], dict[str, str]] = {}
    for row in conn.execute(query).mappings():
        key = (str(row["table_name"]), str(row["column_name"]))
        found[key] = {
            "data_type": str(row["data_type"]),
            "udt_name": str(row["udt_name"]),
        }
    return found


def _execute_missing_postgresql_indexes(conn: Connection, statements: list[str]) -> None:
    """Execute each statement whose index name is not already known.

    Statements whose name cannot be parsed are always executed. Names created
    here are remembered so that a repeated name is skipped.
    """
    existing = _postgresql_index_names(conn)
    for statement in statements:
        index_name = _index_name_from_create_statement(statement)
        if not index_name:
            conn.execute(text(statement))
            continue
        if index_name not in existing:
            conn.execute(text(statement))
            existing.add(index_name)


def _postgresql_index_names(conn: Connection) -> set[str]:
    """Return the index names listed in ``pg_indexes`` for the current schemas."""
    query = text(
        """
        SELECT indexname
        FROM pg_indexes
        WHERE schemaname = ANY (current_schemas(false))
        """
    )
    names: set[str] = set()
    for row in conn.execute(query):
        names.add(str(row[0]))
    return names


def _index_name_from_create_statement(statement: str) -> str | None:
    """Return the index name found in ``statement``, or ``None`` if none matches."""
    found = _CREATE_INDEX_NAME_RE.search(statement)
    if found is None:
        return None
    return found.group(1)


def _postgresql_index_statements() -> list[str]:
    """Return a fresh list of the PostgreSQL-only index statements."""
    return list(_POSTGRESQL_INDEX_STATEMENTS)


def get_db() -> Iterator[Session]:
    """Yield a new session and close it once the consumer is finished."""
    session = SessionLocal()
    try:
        yield session
    finally:
        session.close()


@contextmanager
def session_scope() -> Iterator[Session]:
    """Provide a session that commits on success and rolls back on ``Exception``.

    The session is closed in every case; a caught exception is re-raised
    after the rollback.
    """
    session = SessionLocal()
    try:
        # The commit stays inside the try so that a failing commit is rolled back too.
        yield session
        session.commit()
    except Exception:
        session.rollback()
        raise
    finally:
        session.close()