"""Migrate a Mobility Workbench SQLite database to PostgreSQL/PostGIS.

The script copies the tables listed in ``TABLE_ORDER`` from a SQLite file into
a PostgreSQL database using ``COPY``, imports rows stored in per-dataset
"sidecar" SQLite files, resets ``id`` sequences, and finally asks the
application helpers to refresh geometries, indexes and table statistics.
"""

from __future__ import annotations

import argparse
import json
import os
import sqlite3
import sys
from pathlib import Path
from typing import Any, Iterable

# Make the repository root importable so the lazy ``app.*`` imports in main() resolve
# when this file is executed directly as a script.
REPO_ROOT = Path(__file__).resolve().parents[1]
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))

# Tables are copied in exactly this order.
TABLE_ORDER = [
    "source_catalog_entries",
    "sources",
    "datasets",
    "source_update_checks",
    "osm_diff_states",
    "jobs",
    "job_events",
    "pipeline_runs",
    "gtfs_agencies",
    "gtfs_stops",
    "gtfs_routes",
    "gtfs_trips",
    "gtfs_calendars",
    "gtfs_calendar_dates",
    "gtfs_shapes",
    "gtfs_stop_times",
    "osm_features",
    "canonical_stops",
    "canonical_stop_links",
    "route_matches",
    "match_rules",
    "route_patterns",
    "route_pattern_stops",
    "gtfs_route_pattern_links",
    "gtfs_trip_route_pattern_links",
    "travel_requests",
    "itineraries",
    "itinerary_legs",
]


def main() -> int:
    """Run the migration from the command line.

    Returns:
        ``0`` on success. Argument problems are reported through
        ``parser.error`` (which exits the process).
    """
    parser = argparse.ArgumentParser(
        description="Migrate a Mobility Workbench SQLite database to PostgreSQL/PostGIS."
    )
    parser.add_argument(
        "--postgres-url",
        default=os.environ.get("POSTGRES_DATABASE_URL") or os.environ.get("DATABASE_URL"),
    )
    parser.add_argument("--sqlite-path", default="data/workbench.sqlite")
    parser.add_argument(
        "--reset",
        action="store_true",
        help="Drop and recreate the target PostgreSQL schema before copying.",
    )
    parser.add_argument("--batch-size", type=int, default=100_000)
    parser.add_argument(
        "--strict-sidecars",
        action="store_true",
        help="Fail when a referenced sidecar file is missing.",
    )
    args = parser.parse_args()

    if not args.postgres_url:
        # Fall back to the application settings when they already point at PostgreSQL.
        from app.config import settings as parsed_settings

        if parsed_settings.is_postgresql_database:
            args.postgres_url = parsed_settings.database_url
    if not args.postgres_url:
        parser.error("--postgres-url or POSTGRES_DATABASE_URL is required")
    if not str(args.postgres_url).startswith(("postgresql://", "postgresql+psycopg://")):
        parser.error("--postgres-url must be a PostgreSQL SQLAlchemy URL")

    sqlite_path = Path(args.sqlite_path)
    if not sqlite_path.exists():
        parser.error(f"SQLite database does not exist: {sqlite_path}")

    # DATABASE_URL is exported before the app.* modules below are imported.
    os.environ["DATABASE_URL"] = str(args.postgres_url)

    from app import models  # noqa: F401
    from app.db import (
        Base,
        SessionLocal,
        _ensure_database_extensions,
        _ensure_runtime_columns,
        _ensure_runtime_indexes,
        engine,
        init_db,
    )
    from app.gtfs_storage import GTFS_STORAGE_METADATA_KEY, GTFS_STORAGE_MAIN, GTFS_STOP_TIME_COLUMNS
    from app.osm_storage import OSM_FEATURE_COLUMNS, OSM_STORAGE_MAIN, OSM_STORAGE_METADATA_KEY
    from app.spatial import analyze_postgresql_tables, refresh_postgis_geometries

    if args.reset:
        print("Resetting PostgreSQL schema without secondary indexes...")
        _ensure_database_extensions()
        Base.metadata.drop_all(bind=engine)
        Base.metadata.create_all(bind=engine)
        _ensure_runtime_columns()
    else:
        print("Initializing PostgreSQL schema...")
        init_db()

    # Batches smaller than 1,000 rows are never used, whatever the user passes.
    batch_size = max(1_000, int(args.batch_size))

    source = sqlite3.connect(sqlite_path)
    source.row_factory = sqlite3.Row
    try:
        source_tables = _sqlite_tables(source)
        target_columns = {name: list(table.c.keys()) for name, table in Base.metadata.tables.items()}
        bool_columns = {name: _boolean_columns(table) for name, table in Base.metadata.tables.items()}

        import psycopg

        with psycopg.connect(_psycopg_url(str(args.postgres_url))) as pg:
            copied_tables: list[str] = []
            for table_name in TABLE_ORDER:
                # Only tables present on both sides are migrated.
                if table_name not in source_tables or table_name not in target_columns:
                    continue
                copied = _copy_sqlite_table(
                    source,
                    pg,
                    table_name=table_name,
                    target_columns=target_columns[table_name],
                    bool_columns=bool_columns.get(table_name, set()),
                    batch_size=batch_size,
                )
                copied_tables.append(table_name)
                print(f"Copied {copied:,} rows from {table_name}.")
                pg.commit()

            _reset_sequences(pg, target_columns)
            pg.commit()

            sidecar_results = _copy_sidecars(
                source,
                pg,
                sqlite_base_dir=sqlite_path.parent,
                batch_size=batch_size,
                strict=args.strict_sidecars,
                osm_columns=OSM_FEATURE_COLUMNS,
                gtfs_stop_time_columns=GTFS_STOP_TIME_COLUMNS,
                gtfs_storage_key=GTFS_STORAGE_METADATA_KEY,
                osm_storage_key=OSM_STORAGE_METADATA_KEY,
                gtfs_main_mode=GTFS_STORAGE_MAIN,
                osm_main_mode=OSM_STORAGE_MAIN,
            )
            # Sidecar imports add rows, so the sequences are aligned once more.
            _reset_sequences(pg, target_columns)
            pg.commit()

        print("Refreshing PostGIS geometries and indexes...")
        with SessionLocal() as session:
            refresh_postgis_geometries(session, only_missing=False)
            session.commit()
        _ensure_runtime_indexes()

        # De-duplicate while keeping order: the two sidecar-fed tables may already
        # be in copied_tables and should not be analyzed twice.
        analyze_targets = list(dict.fromkeys([*copied_tables, "osm_features", "gtfs_stop_times"]))
        with SessionLocal() as session:
            analyze_postgresql_tables(session, analyze_targets)
            session.commit()

        print("Migration complete.")
        for message in sidecar_results:
            print(message)
        return 0
    finally:
        source.close()


def _copy_sqlite_table(
    source: sqlite3.Connection,
    pg,
    *,
    table_name: str,
    target_columns: list[str],
    bool_columns: set[str],
    batch_size: int,
) -> int:
    """Stream one SQLite table into the PostgreSQL table of the same name.

    Only columns that exist in both the SQLite table and ``target_columns`` are
    copied. Values in ``bool_columns`` are converted to ``bool``.

    Returns:
        The number of rows read from SQLite and written to PostgreSQL.
    """
    wanted = set(target_columns)
    source_columns = [column for column in _sqlite_columns(source, table_name) if column in wanted]
    if not source_columns:
        return 0

    column_list = ", ".join(_quote_sqlite(column) for column in source_columns)
    cursor = source.execute(f"SELECT {column_list} FROM {_quote_sqlite(table_name)}")
    total = 0
    try:
        while True:
            batch = cursor.fetchmany(batch_size)
            if not batch:
                break
            _copy_rows(
                pg,
                table_name=table_name,
                columns=source_columns,
                rows=(_row_values(row, source_columns, bool_columns) for row in batch),
            )
            total += len(batch)
    finally:
        cursor.close()
    return total


def _copy_sidecars(
    source: sqlite3.Connection,
    pg,
    *,
    sqlite_base_dir: Path,
    batch_size: int,
    strict: bool,
    osm_columns: list[str],
    gtfs_stop_time_columns: list[str],
    gtfs_storage_key: str,
    osm_storage_key: str,
    gtfs_main_mode: str,
    osm_main_mode: str,
) -> list[str]:
    """Import sidecar-stored GTFS stop_times and OSM features for every dataset.

    For each row of the SQLite ``datasets`` table the storage descriptors found
    under ``gtfs_storage_key`` / ``osm_storage_key`` in ``metadata_json`` are
    inspected. When a descriptor points at a sidecar file, its rows are copied
    into PostgreSQL, the descriptor is rewritten to "main" storage, and the
    updated metadata is saved on the PostgreSQL ``datasets`` row. Each imported
    dataset is committed individually.

    Returns:
        Human-readable status messages, one per sidecar handled or skipped.

    Raises:
        FileNotFoundError: if ``strict`` is true and a referenced sidecar file
            is missing.
    """
    # A source database without a datasets table has no sidecars to import.
    if "datasets" not in _sqlite_tables(source):
        return []

    messages: list[str] = []
    dataset_rows = source.execute("SELECT id, kind, metadata_json FROM datasets ORDER BY id").fetchall()
    for row in dataset_rows:
        dataset_id = int(row["id"])
        metadata = _json_dict(row["metadata_json"])

        gtfs_storage = metadata.get(gtfs_storage_key)
        if isinstance(gtfs_storage, dict) and _storage_uses_sidecar(gtfs_storage, "gtfs_stop_times"):
            path = _resolve_sidecar_path(gtfs_storage.get("sidecar_path"), sqlite_base_dir)
            if path is None or not path.exists():
                message = f"Missing GTFS sidecar for dataset #{dataset_id}: {path}"
                if strict:
                    raise FileNotFoundError(message)
                messages.append(message)
            else:
                existing = _pg_scalar(
                    pg,
                    "SELECT COUNT(*) FROM gtfs_stop_times WHERE dataset_id = %s",
                    [dataset_id],
                )
                if int(existing or 0) > 0:
                    # stop_times has no conflict handling, so never import twice.
                    messages.append(
                        f"Skipped GTFS sidecar for dataset #{dataset_id}; target already has stop_times rows."
                    )
                else:
                    copied = _copy_gtfs_sidecar(pg, dataset_id, path, gtfs_stop_time_columns, batch_size)
                    _mark_storage_main(metadata, gtfs_storage_key, "gtfs_stop_times", gtfs_main_mode, path)
                    _update_dataset_metadata(pg, dataset_id, metadata)
                    pg.commit()
                    messages.append(f"Copied {copied:,} GTFS stop_times rows from {path}.")

        osm_storage = metadata.get(osm_storage_key)
        if isinstance(osm_storage, dict) and _storage_uses_sidecar(osm_storage, "osm_features"):
            path = _resolve_sidecar_path(osm_storage.get("sidecar_path"), sqlite_base_dir)
            if path is None or not path.exists():
                message = f"Missing OSM sidecar for dataset #{dataset_id}: {path}"
                if strict:
                    raise FileNotFoundError(message)
                messages.append(message)
            else:
                copied, inserted = _copy_osm_sidecar(pg, dataset_id, path, osm_columns, batch_size)
                _mark_storage_main(metadata, osm_storage_key, "osm_features", osm_main_mode, path)
                _update_dataset_metadata(pg, dataset_id, metadata)
                pg.commit()
                messages.append(
                    f"Copied {copied:,} OSM sidecar rows from {path}; inserted {inserted:,} new main rows."
                )
    return messages


def _sidecar_copy_columns(columns: list[str]) -> list[str]:
    """Return ``columns`` with ``dataset_id`` placed first and listed exactly once."""
    return ["dataset_id", *(column for column in columns if column != "dataset_id")]


def _copy_gtfs_sidecar(pg, dataset_id: int, path: Path, columns: list[str], batch_size: int) -> int:
    """Copy every ``gtfs_stop_times`` row of a sidecar file into PostgreSQL.

    Each row is written with ``dataset_id`` as its first value. Columns that
    the sidecar table lacks are written as NULL.

    Returns:
        The number of rows copied.
    """
    copy_columns = _sidecar_copy_columns(columns)
    payload_columns = copy_columns[1:]

    source = sqlite3.connect(path)
    source.row_factory = sqlite3.Row
    try:
        available = set(_sqlite_columns(source, "gtfs_stop_times"))
        select_columns = [
            _quote_sqlite(column) if column in available else f"NULL AS {_quote_sqlite(column)}"
            for column in payload_columns
        ]
        total = 0
        cursor = source.execute(f"SELECT {', '.join(select_columns)} FROM gtfs_stop_times")
        try:
            while True:
                batch = cursor.fetchmany(batch_size)
                if not batch:
                    break
                _copy_rows(
                    pg,
                    table_name="gtfs_stop_times",
                    columns=copy_columns,
                    rows=([dataset_id, *[row[column] for column in payload_columns]] for row in batch),
                )
                total += len(batch)
        finally:
            cursor.close()
        return total
    finally:
        source.close()


def _copy_osm_sidecar(pg, dataset_id: int, path: Path, columns: list[str], batch_size: int) -> tuple[int, int]:
    """Merge the ``osm_features`` rows of a sidecar file into PostgreSQL.

    Rows are first copied into a temporary table shaped like ``osm_features``
    and then inserted into the real table with ``ON CONFLICT ... DO NOTHING``,
    so rows that collide on ``uq_osm_feature_dataset_type_id`` are skipped.

    Returns:
        ``(copied, inserted)``: rows read from the sidecar, and the row count
        reported by the final INSERT.
    """
    temp_table = "tmp_osm_sidecar_features"
    # The column list and the row values must be built from the same ordering.
    copy_columns = _sidecar_copy_columns(columns)
    payload_columns = copy_columns[1:]

    source = sqlite3.connect(path)
    source.row_factory = sqlite3.Row
    try:
        available = set(_sqlite_columns(source, "osm_features"))
        select_columns = [
            _quote_sqlite(column) if column in available else f"NULL AS {_quote_sqlite(column)}"
            for column in payload_columns
        ]

        with pg.cursor() as cur:
            cur.execute(f"DROP TABLE IF EXISTS pg_temp.{_quote_pg(temp_table)}")
            cur.execute(
                f"CREATE TEMP TABLE {temp_table} (LIKE osm_features INCLUDING DEFAULTS) ON COMMIT DROP"
            )

        copied = 0
        cursor = source.execute(f"SELECT {', '.join(select_columns)} FROM osm_features")
        try:
            while True:
                batch = cursor.fetchmany(batch_size)
                if not batch:
                    break
                _copy_rows(
                    pg,
                    table_name=temp_table,
                    columns=copy_columns,
                    rows=([dataset_id, *[row[column] for column in payload_columns]] for row in batch),
                )
                copied += len(batch)
        finally:
            cursor.close()

        with pg.cursor() as cur:
            column_sql = ", ".join(_quote_pg(column) for column in copy_columns)
            cur.execute(
                f"""
                INSERT INTO osm_features ({column_sql})
                SELECT {column_sql}
                FROM {temp_table}
                ON CONFLICT ON CONSTRAINT uq_osm_feature_dataset_type_id DO NOTHING
                """
            )
            inserted = int(cur.rowcount or 0)
            cur.execute(f"DROP TABLE IF EXISTS pg_temp.{_quote_pg(temp_table)}")
        return copied, inserted
    finally:
        source.close()


def _copy_rows(pg, *, table_name: str, columns: list[str], rows: Iterable[Iterable[Any]]) -> None:
    """Write ``rows`` into ``table_name`` through a single ``COPY ... FROM STDIN``."""
    column_sql = ", ".join(_quote_pg(column) for column in columns)
    with pg.cursor() as cur:
        with cur.copy(f"COPY {_quote_pg(table_name)} ({column_sql}) FROM STDIN") as copy:
            for row in rows:
                copy.write_row(list(row))


def _reset_sequences(pg, target_columns: dict[str, list[str]]) -> None:
    """Align the ``id`` sequence of every table with that table's ``MAX(id)``.

    Tables without an ``id`` column, or whose ``id`` has no sequence according
    to ``pg_get_serial_sequence``, are left alone.
    """
    with pg.cursor() as cur:
        for table_name, columns in target_columns.items():
            if "id" not in columns:
                continue
            cur.execute("SELECT pg_get_serial_sequence(%s, 'id')", [table_name])
            row = cur.fetchone()
            sequence_name = row[0] if row else None
            if not sequence_name:
                continue
            quoted_table = _quote_pg(table_name)
            # For an empty table the sequence is set to 1 with is_called = false.
            cur.execute(
                f"""
                SELECT setval(
                    %s,
                    COALESCE((SELECT MAX(id) FROM {quoted_table}), 1),
                    (SELECT MAX(id) IS NOT NULL FROM {quoted_table})
                )
                """,
                [sequence_name],
            )


def _mark_storage_main(metadata: dict[str, Any], key: str, table_name: str, mode: str, sidecar_path: Path) -> None:
    """Rewrite ``metadata[key]`` in place to describe main-database storage.

    The previous sidecar location is kept under ``legacy_sidecar_path``; the
    ``sidecar_path`` and ``sidecar_status`` entries are removed.
    """
    storage = metadata.setdefault(key, {})
    if not isinstance(storage, dict):
        # Replace a malformed descriptor rather than failing on item assignment.
        storage = {}
        metadata[key] = storage
    storage["mode"] = mode
    storage["tables"] = {table_name: "main"}
    storage["storage_status"] = "ready"
    storage["legacy_sidecar_path"] = str(sidecar_path)
    storage.pop("sidecar_path", None)
    storage.pop("sidecar_status", None)


def _update_dataset_metadata(pg, dataset_id: int, metadata: dict[str, Any]) -> None:
    """Store ``metadata`` as compact JSON on the PostgreSQL ``datasets`` row."""
    with pg.cursor() as cur:
        cur.execute(
            "UPDATE datasets SET metadata_json = %s WHERE id = %s",
            [json.dumps(metadata, separators=(",", ":")), dataset_id],
        )


def _pg_scalar(pg, sql: str, params: list[Any]) -> Any:
    """Run ``sql`` and return the first column of the first row, or ``None``."""
    with pg.cursor() as cur:
        cur.execute(sql, params)
        row = cur.fetchone()
        return row[0] if row else None


def _sqlite_tables(connection: sqlite3.Connection) -> set[str]:
    """Return the table names in ``connection``.

    Rows are read by column name, so the connection must use a row factory
    that supports it (``main`` sets ``sqlite3.Row``).
    """
    rows = connection.execute("SELECT name FROM sqlite_master WHERE type = 'table'").fetchall()
    return {str(row["name"]) for row in rows}


def _sqlite_columns(connection: sqlite3.Connection, table_name: str) -> list[str]:
    """Return the column names of ``table_name`` as reported by ``PRAGMA table_info``."""
    rows = connection.execute(f"PRAGMA table_info({_quote_sqlite(table_name)})").fetchall()
    return [str(row["name"]) for row in rows]


def _row_values(row: sqlite3.Row, columns: list[str], bool_columns: set[str]) -> list[Any]:
    """Extract ``columns`` from ``row``, converting non-NULL boolean columns to ``bool``."""
    values: list[Any] = []
    for column in columns:
        value = row[column]
        if column in bool_columns and value is not None:
            value = bool(value)
        values.append(value)
    return values


def _boolean_columns(table) -> set[str]:
    """Return the names of the columns of ``table`` whose Python type is ``bool``.

    Column types whose ``python_type`` raises ``NotImplementedError`` are skipped.
    """
    columns: set[str] = set()
    for column in table.c:
        try:
            if column.type.python_type is bool:
                columns.add(str(column.name))
        except NotImplementedError:
            continue
    return columns


def _storage_uses_sidecar(storage: dict[str, Any], table_name: str) -> bool:
    """Tell whether a storage descriptor places ``table_name`` in a sidecar.

    True when ``storage["tables"][table_name] == "sidecar"``, or otherwise when
    ``storage["mode"]`` starts with ``"sidecar"``.
    """
    tables = storage.get("tables")
    if isinstance(tables, dict) and tables.get(table_name) == "sidecar":
        return True
    return str(storage.get("mode") or "").startswith("sidecar")


def _resolve_sidecar_path(value: Any, base_dir: Path) -> Path | None:
    """Turn a stored sidecar path into a ``Path``, trying several base directories.

    Absolute paths are returned unchanged. Relative paths are tried as given
    (relative to the current working directory), then under ``base_dir``, then
    under ``base_dir.parent``; the first existing candidate wins. If none
    exists the relative path is returned as-is so the caller can report it.

    Returns:
        ``None`` when ``value`` is empty, otherwise the resolved path.
    """
    if not value:
        return None
    path = Path(str(value))
    if path.is_absolute():
        return path
    for candidate in (path, base_dir / path, base_dir.parent / path):
        if candidate.exists():
            return candidate
    return path


def _json_dict(value: str | None) -> dict[str, Any]:
    """Parse ``value`` as JSON and return it only if it is an object.

    Empty input, invalid JSON, non-object JSON and values that are not
    text/bytes at all (SQLite columns are dynamically typed) all yield ``{}``.
    """
    try:
        data = json.loads(value or "{}")
    except (json.JSONDecodeError, TypeError):
        # TypeError covers non-text values such as integers read from SQLite.
        return {}
    return data if isinstance(data, dict) else {}


def _psycopg_url(value: str) -> str:
    """Convert a ``postgresql+psycopg://`` SQLAlchemy URL into a plain ``postgresql://`` URL."""
    return value.replace("postgresql+psycopg://", "postgresql://", 1)


def _quote_pg(identifier: str) -> str:
    """Quote ``identifier`` for PostgreSQL by doubling embedded double quotes."""
    return '"' + identifier.replace('"', '""') + '"'


def _quote_sqlite(identifier: str) -> str:
    """Quote ``identifier`` for SQLite by doubling embedded double quotes."""
    return '"' + identifier.replace('"', '""') + '"'


if __name__ == "__main__":
    sys.exit(main())