Files
meubility-workbench/scripts/migrate_sqlite_to_postgres.py
2026-07-01 23:29:51 +02:00

453 lines
16 KiB
Python

from __future__ import annotations
import argparse
import json
import os
import sqlite3
import sys
from pathlib import Path
from typing import Any, Iterable
# Make the repository root (the parent of this script's directory) importable so
# the `app` package imported inside main() resolves when the script is run directly.
REPO_ROOT = Path(__file__).resolve().parents[1]
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))

# Tables copied by main(), in exactly this order. Entries missing from either the
# SQLite source or the SQLAlchemy metadata are skipped.
# NOTE(review): the order appears to place referenced tables before the tables
# that reference them (e.g. sources before datasets, jobs before job_events) —
# confirm against the models when adding a table.
TABLE_ORDER = [
    "source_catalog_entries",
    "sources",
    "datasets",
    "source_update_checks",
    "osm_diff_states",
    "jobs",
    "job_events",
    "pipeline_runs",
    "gtfs_agencies",
    "gtfs_stops",
    "gtfs_routes",
    "gtfs_trips",
    "gtfs_calendars",
    "gtfs_calendar_dates",
    "gtfs_shapes",
    "gtfs_stop_times",
    "osm_features",
    "canonical_stops",
    "canonical_stop_links",
    "route_matches",
    "match_rules",
    "route_patterns",
    "route_pattern_stops",
    "gtfs_route_pattern_links",
    "gtfs_trip_route_pattern_links",
    "travel_requests",
    "itineraries",
    "itinerary_legs",
]
def main() -> int:
    """Migrate a Mobility Workbench SQLite database into PostgreSQL/PostGIS.

    The run:

    1. resolves the target URL from ``--postgres-url``, ``POSTGRES_DATABASE_URL``
       or ``DATABASE_URL``, falling back to the app settings only when they
       already point at PostgreSQL;
    2. prepares the schema: ``--reset`` drops and recreates every mapped table
       (secondary indexes are built at the end), otherwise ``init_db()``;
    3. bulk-copies each table in ``TABLE_ORDER`` that exists on both sides,
       committing after each table, then realigns ``id`` sequences;
    4. imports GTFS/OSM sidecar databases referenced by dataset metadata and
       realigns the sequences again;
    5. refreshes PostGIS geometries, builds runtime indexes and ANALYZEs the
       touched tables.

    Returns:
        0 on success. Invalid arguments terminate via ``parser.error``
        (``SystemExit``); database errors propagate to the caller.
    """
    parser = argparse.ArgumentParser(description="Migrate a Mobility Workbench SQLite database to PostgreSQL/PostGIS.")
    parser.add_argument("--postgres-url", default=os.environ.get("POSTGRES_DATABASE_URL") or os.environ.get("DATABASE_URL"))
    parser.add_argument("--sqlite-path", default="data/workbench.sqlite")
    parser.add_argument("--reset", action="store_true", help="Drop and recreate the target PostgreSQL schema before copying.")
    parser.add_argument("--batch-size", type=int, default=100_000)
    parser.add_argument("--strict-sidecars", action="store_true", help="Fail when a referenced sidecar file is missing.")
    args = parser.parse_args()

    if not args.postgres_url:
        # Fall back to the application settings, but only if they already target PostgreSQL.
        from app.config import settings as parsed_settings

        if parsed_settings.is_postgresql_database:
            args.postgres_url = parsed_settings.database_url
    if not args.postgres_url:
        parser.error("--postgres-url or POSTGRES_DATABASE_URL is required")
    if not str(args.postgres_url).startswith(("postgresql://", "postgresql+psycopg://")):
        parser.error("--postgres-url must be a PostgreSQL SQLAlchemy URL")
    sqlite_path = Path(args.sqlite_path)
    # is_file() also rejects a directory, which sqlite3.connect() would choke on later.
    if not sqlite_path.is_file():
        parser.error(f"SQLite database does not exist: {sqlite_path}")
    # Clamp once: the table copy and the sidecar copy share the same batch size.
    batch_size = max(1_000, int(args.batch_size))

    # Point the app's database configuration at the target before importing app.db.
    # NOTE(review): assumes app.db builds its engine from DATABASE_URL at import time.
    os.environ["DATABASE_URL"] = str(args.postgres_url)
    from app import models  # noqa: F401
    from app.db import Base, SessionLocal, _ensure_database_extensions, _ensure_runtime_columns, _ensure_runtime_indexes, engine, init_db
    from app.gtfs_storage import GTFS_STORAGE_METADATA_KEY, GTFS_STORAGE_MAIN, GTFS_STOP_TIME_COLUMNS
    from app.osm_storage import OSM_FEATURE_COLUMNS, OSM_STORAGE_MAIN, OSM_STORAGE_METADATA_KEY
    from app.spatial import analyze_postgresql_tables, refresh_postgis_geometries

    if args.reset:
        print("Resetting PostgreSQL schema without secondary indexes...")
        _ensure_database_extensions()
        Base.metadata.drop_all(bind=engine)
        Base.metadata.create_all(bind=engine)
        _ensure_runtime_columns()
    else:
        print("Initializing PostgreSQL schema...")
        init_db()

    source = sqlite3.connect(sqlite_path)
    source.row_factory = sqlite3.Row
    try:
        source_tables = _sqlite_tables(source)
        target_columns = {name: list(table.c.keys()) for name, table in Base.metadata.tables.items()}
        bool_columns = {name: _boolean_columns(table) for name, table in Base.metadata.tables.items()}
        import psycopg

        with psycopg.connect(_psycopg_url(str(args.postgres_url))) as pg:
            copied_tables: list[str] = []
            for table_name in TABLE_ORDER:
                if table_name not in source_tables or table_name not in target_columns:
                    continue
                copied = _copy_sqlite_table(
                    source,
                    pg,
                    table_name=table_name,
                    target_columns=target_columns[table_name],
                    bool_columns=bool_columns.get(table_name, set()),
                    batch_size=batch_size,
                )
                copied_tables.append(table_name)
                print(f"Copied {copied:,} rows from {table_name}.")
                pg.commit()
            # Rows were copied with their original ids, so move each id sequence past them.
            _reset_sequences(pg, target_columns)
            pg.commit()
            sidecar_results = _copy_sidecars(
                source,
                pg,
                sqlite_base_dir=sqlite_path.parent,
                batch_size=batch_size,
                strict=args.strict_sidecars,
                osm_columns=OSM_FEATURE_COLUMNS,
                gtfs_stop_time_columns=GTFS_STOP_TIME_COLUMNS,
                gtfs_storage_key=GTFS_STORAGE_METADATA_KEY,
                osm_storage_key=OSM_STORAGE_METADATA_KEY,
                gtfs_main_mode=GTFS_STORAGE_MAIN,
                osm_main_mode=OSM_STORAGE_MAIN,
            )
            # Sidecar imports add rows too; realign the sequences once more.
            _reset_sequences(pg, target_columns)
            pg.commit()

        print("Refreshing PostGIS geometries and indexes...")
        with SessionLocal() as session:
            refresh_postgis_geometries(session, only_missing=False)
            session.commit()
        _ensure_runtime_indexes()
        # Sidecar imports write to osm_features / gtfs_stop_times even when those
        # tables were not copied from SQLite; dedupe (order-preserving) so each
        # table is analyzed once.
        analyze_targets = list(dict.fromkeys([*copied_tables, "osm_features", "gtfs_stop_times"]))
        with SessionLocal() as session:
            analyze_postgresql_tables(session, analyze_targets)
            session.commit()
        print("Migration complete.")
        for message in sidecar_results:
            print(message)
        return 0
    finally:
        source.close()
def _copy_sqlite_table(
    source: sqlite3.Connection,
    pg,
    *,
    table_name: str,
    target_columns: list[str],
    bool_columns: set[str],
    batch_size: int,
) -> int:
    """Copy one SQLite table into the same-named PostgreSQL table.

    Only columns present on both sides are transferred. Rows are read in chunks
    of ``batch_size`` and each chunk is sent as its own COPY. Values in
    ``bool_columns`` are coerced to ``bool``.

    Returns:
        The number of rows copied; 0 when the tables share no columns.
    """
    shared = [name for name in _sqlite_columns(source, table_name) if name in target_columns]
    if not shared:
        return 0
    projection = ", ".join(_quote_sqlite(name) for name in shared)
    cursor = source.execute(f"SELECT {projection} FROM {_quote_sqlite(table_name)}")
    copied = 0
    try:
        # fetchmany() returns [] once the result set is exhausted.
        for batch in iter(lambda: cursor.fetchmany(batch_size), []):
            _copy_rows(
                pg,
                table_name=table_name,
                columns=shared,
                rows=(_row_values(record, shared, bool_columns) for record in batch),
            )
            copied += len(batch)
    finally:
        cursor.close()
    return copied
def _copy_sidecars(
source: sqlite3.Connection,
pg,
*,
sqlite_base_dir: Path,
batch_size: int,
strict: bool,
osm_columns: list[str],
gtfs_stop_time_columns: list[str],
gtfs_storage_key: str,
osm_storage_key: str,
gtfs_main_mode: str,
osm_main_mode: str,
) -> list[str]:
messages: list[str] = []
dataset_rows = source.execute("SELECT id, kind, metadata_json FROM datasets ORDER BY id").fetchall()
for row in dataset_rows:
dataset_id = int(row["id"])
metadata = _json_dict(row["metadata_json"])
gtfs_storage = metadata.get(gtfs_storage_key)
if isinstance(gtfs_storage, dict) and _storage_uses_sidecar(gtfs_storage, "gtfs_stop_times"):
path = _resolve_sidecar_path(gtfs_storage.get("sidecar_path"), sqlite_base_dir)
if path is None or not path.exists():
message = f"Missing GTFS sidecar for dataset #{dataset_id}: {path}"
if strict:
raise FileNotFoundError(message)
messages.append(message)
else:
existing = _pg_scalar(pg, "SELECT COUNT(*) FROM gtfs_stop_times WHERE dataset_id = %s", [dataset_id])
if int(existing or 0) > 0:
messages.append(f"Skipped GTFS sidecar for dataset #{dataset_id}; target already has stop_times rows.")
else:
copied = _copy_gtfs_sidecar(pg, dataset_id, path, gtfs_stop_time_columns, batch_size)
_mark_storage_main(metadata, gtfs_storage_key, "gtfs_stop_times", gtfs_main_mode, path)
_update_dataset_metadata(pg, dataset_id, metadata)
pg.commit()
messages.append(f"Copied {copied:,} GTFS stop_times rows from {path}.")
osm_storage = metadata.get(osm_storage_key)
if isinstance(osm_storage, dict) and _storage_uses_sidecar(osm_storage, "osm_features"):
path = _resolve_sidecar_path(osm_storage.get("sidecar_path"), sqlite_base_dir)
if path is None or not path.exists():
message = f"Missing OSM sidecar for dataset #{dataset_id}: {path}"
if strict:
raise FileNotFoundError(message)
messages.append(message)
else:
copied, inserted = _copy_osm_sidecar(pg, dataset_id, path, osm_columns, batch_size)
_mark_storage_main(metadata, osm_storage_key, "osm_features", osm_main_mode, path)
_update_dataset_metadata(pg, dataset_id, metadata)
pg.commit()
messages.append(f"Copied {copied:,} OSM sidecar rows from {path}; inserted {inserted:,} new main rows.")
return messages
def _copy_gtfs_sidecar(pg, dataset_id: int, path: Path, columns: list[str], batch_size: int) -> int:
    """Stream a sidecar's ``gtfs_stop_times`` rows into PostgreSQL for one dataset.

    Args:
        pg: Open psycopg connection; the caller commits.
        dataset_id: Value written into ``dataset_id`` for every copied row.
        path: Sidecar SQLite file containing a ``gtfs_stop_times`` table.
        columns: Target stop_times columns to copy. Columns the sidecar lacks
            are filled with NULL. A ``dataset_id`` entry is ignored, because
            that value always comes from the ``dataset_id`` argument.
        batch_size: Rows fetched from SQLite per COPY statement.

    Returns:
        Number of rows copied.
    """
    # Mirror _copy_osm_sidecar: dataset_id is supplied explicitly, so keep it out
    # of the payload to avoid naming the column twice in the COPY.
    payload_columns = [column for column in columns if column != "dataset_id"]
    source = sqlite3.connect(path)
    source.row_factory = sqlite3.Row
    try:
        available = set(_sqlite_columns(source, "gtfs_stop_times"))
        select_columns = [
            _quote_sqlite(column) if column in available else f"NULL AS {_quote_sqlite(column)}"
            for column in payload_columns
        ]
        total = 0
        cursor = source.execute(f"SELECT {', '.join(select_columns)} FROM gtfs_stop_times")
        try:
            while True:
                rows = cursor.fetchmany(batch_size)
                if not rows:
                    break
                _copy_rows(
                    pg,
                    table_name="gtfs_stop_times",
                    columns=["dataset_id", *payload_columns],
                    rows=([dataset_id, *[row[column] for column in payload_columns]] for row in rows),
                )
                total += len(rows)
        finally:
            cursor.close()
        return total
    finally:
        source.close()
def _copy_osm_sidecar(pg, dataset_id: int, path: Path, columns: list[str], batch_size: int) -> tuple[int, int]:
    """Merge a sidecar's ``osm_features`` rows into PostgreSQL for one dataset.

    Rows are COPYed into a temporary table created ``LIKE osm_features`` and
    then inserted into ``osm_features`` with ``ON CONFLICT ON CONSTRAINT
    uq_osm_feature_dataset_type_id DO NOTHING``, so rows already present in the
    main table are kept. The caller commits.

    Args:
        pg: Open psycopg connection.
        dataset_id: Value written into ``dataset_id`` for every row.
        path: Sidecar SQLite file containing an ``osm_features`` table.
        columns: Target columns to copy. Columns missing from the sidecar are
            filled with NULL. ``dataset_id`` always comes from the argument.
        batch_size: Rows fetched from SQLite per COPY statement.

    Returns:
        ``(copied, inserted)``: rows read from the sidecar, and rows actually
        inserted into ``osm_features``.
    """
    source = sqlite3.connect(path)
    source.row_factory = sqlite3.Row
    temp_table = "tmp_osm_sidecar_features"
    quoted_temp = _quote_pg(temp_table)
    try:
        available = set(_sqlite_columns(source, "osm_features"))
        payload_columns = [column for column in columns if column != "dataset_id"]
        # Rows are built as [dataset_id, *payload]; the COPY and INSERT column
        # lists must follow exactly that layout. Reusing ``columns`` directly
        # would misalign values whenever dataset_id is not its first entry.
        copy_columns = ["dataset_id", *payload_columns]
        select_columns = [
            _quote_sqlite(column) if column in available else f"NULL AS {_quote_sqlite(column)}"
            for column in payload_columns
        ]
        with pg.cursor() as cur:
            cur.execute(f"DROP TABLE IF EXISTS pg_temp.{quoted_temp}")
            cur.execute(f"CREATE TEMP TABLE {quoted_temp} (LIKE osm_features INCLUDING DEFAULTS) ON COMMIT DROP")
        copied = 0
        cursor = source.execute(f"SELECT {', '.join(select_columns)} FROM osm_features")
        try:
            while True:
                rows = cursor.fetchmany(batch_size)
                if not rows:
                    break
                _copy_rows(
                    pg,
                    table_name=temp_table,
                    columns=copy_columns,
                    rows=([dataset_id, *[row[column] for column in payload_columns]] for row in rows),
                )
                copied += len(rows)
        finally:
            cursor.close()
        column_sql = ", ".join(_quote_pg(column) for column in copy_columns)
        with pg.cursor() as cur:
            cur.execute(
                f"""
                INSERT INTO osm_features ({column_sql})
                SELECT {column_sql}
                FROM {quoted_temp}
                ON CONFLICT ON CONSTRAINT uq_osm_feature_dataset_type_id DO NOTHING
                """
            )
            inserted = int(cur.rowcount or 0)
            # Drop eagerly; ON COMMIT DROP would only remove it at the caller's commit.
            cur.execute(f"DROP TABLE IF EXISTS pg_temp.{quoted_temp}")
        return copied, inserted
    finally:
        source.close()
def _copy_rows(pg, *, table_name: str, columns: list[str], rows: Iterable[Iterable[Any]]) -> None:
    """Stream ``rows`` into ``table_name`` through one ``COPY ... FROM STDIN``.

    Each row is materialised as a list and handed to psycopg's ``write_row``.
    Its values must line up positionally with ``columns``.
    """
    quoted_columns = ", ".join(map(_quote_pg, columns))
    statement = f"COPY {_quote_pg(table_name)} ({quoted_columns}) FROM STDIN"
    with pg.cursor() as cur, cur.copy(statement) as copy:
        for row in rows:
            copy.write_row(list(row))
def _reset_sequences(pg, target_columns: dict[str, list[str]]) -> None:
with pg.cursor() as cur:
for table_name, columns in target_columns.items():
if "id" not in columns:
continue
cur.execute("SELECT pg_get_serial_sequence(%s, 'id')", [table_name])
row = cur.fetchone()
sequence_name = row[0] if row else None
if not sequence_name:
continue
cur.execute(
"""
SELECT setval(
%s,
COALESCE((SELECT MAX(id) FROM {table}), 1),
(SELECT MAX(id) IS NOT NULL FROM {table})
)
""".format(table=_quote_pg(table_name)),
[sequence_name],
)
def _mark_storage_main(metadata: dict[str, Any], key: str, table_name: str, mode: str, sidecar_path: Path) -> None:
storage = metadata.setdefault(key, {})
if not isinstance(storage, dict):
storage = {}
metadata[key] = storage
storage["mode"] = mode
storage["tables"] = {table_name: "main"}
storage["storage_status"] = "ready"
storage["legacy_sidecar_path"] = str(sidecar_path)
storage.pop("sidecar_path", None)
storage.pop("sidecar_status", None)
def _update_dataset_metadata(pg, dataset_id: int, metadata: dict[str, Any]) -> None:
with pg.cursor() as cur:
cur.execute(
"UPDATE datasets SET metadata_json = %s WHERE id = %s",
[json.dumps(metadata, separators=(",", ":")), dataset_id],
)
def _pg_scalar(pg, sql: str, params: list[Any]) -> Any:
with pg.cursor() as cur:
cur.execute(sql, params)
row = cur.fetchone()
return row[0] if row else None
def _sqlite_tables(connection: sqlite3.Connection) -> set[str]:
return {
str(row["name"])
for row in connection.execute("SELECT name FROM sqlite_master WHERE type = 'table'").fetchall()
}
def _sqlite_columns(connection: sqlite3.Connection, table_name: str) -> list[str]:
    """Return the column names of ``table_name`` in declaration order, via ``PRAGMA table_info``.

    The row factory must support access by column name (e.g. ``sqlite3.Row``).
    SQLite returns no rows for an unknown table, so the result is then empty.
    """
    pragma = f"PRAGMA table_info({_quote_sqlite(table_name)})"
    info = connection.execute(pragma).fetchall()
    return [str(entry["name"]) for entry in info]
def _row_values(row: sqlite3.Row, columns: list[str], bool_columns: set[str]) -> list[Any]:
values: list[Any] = []
for column in columns:
value = row[column]
if column in bool_columns and value is not None:
value = bool(value)
values.append(value)
return values
def _boolean_columns(table) -> set[str]:
columns: set[str] = set()
for column in table.c:
try:
if column.type.python_type is bool:
columns.add(str(column.name))
except NotImplementedError:
continue
return columns
def _storage_uses_sidecar(storage: dict[str, Any], table_name: str) -> bool:
tables = storage.get("tables")
if isinstance(tables, dict) and tables.get(table_name) == "sidecar":
return True
return str(storage.get("mode") or "").startswith("sidecar")
def _resolve_sidecar_path(value: Any, base_dir: Path) -> Path | None:
if not value:
return None
path = Path(str(value))
if path.is_absolute():
return path
if path.exists():
return path
source_relative = base_dir / path
if source_relative.exists():
return source_relative
repo_relative = base_dir.parent / path
if repo_relative.exists():
return repo_relative
return path
def _json_dict(value: str | None) -> dict[str, Any]:
try:
data = json.loads(value or "{}")
except json.JSONDecodeError:
return {}
return data if isinstance(data, dict) else {}
def _psycopg_url(value: str) -> str:
return value.replace("postgresql+psycopg://", "postgresql://", 1)
def _quote_pg(identifier: str) -> str:
return '"' + identifier.replace('"', '""') + '"'
def _quote_sqlite(identifier: str) -> str:
return '"' + identifier.replace('"', '""') + '"'
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    sys.exit(main())