"""Migrate a Mobility Workbench SQLite database to PostgreSQL/PostGIS."""
from __future__ import annotations
|
|
|
|
import argparse
|
|
import json
|
|
import os
|
|
import sqlite3
|
|
import sys
|
|
from pathlib import Path
|
|
from typing import Any, Iterable
|
|
|
|
|
|
# Directory two levels above this file.  It is put at the front of ``sys.path``
# (once), presumably so the project's ``app`` package imported inside ``main``
# resolves when the script is run directly -- confirm against the repo layout.
REPO_ROOT = Path(__file__).resolve().parents[1]
if sys.path.count(str(REPO_ROOT)) == 0:
    sys.path.insert(0, str(REPO_ROOT))
|
|
|
|
# Tables copied by ``main``, in exactly this order.  Tables missing from either
# the SQLite source or the target metadata are skipped at copy time.
# NOTE(review): the order presumably keeps referenced tables ahead of the
# tables that point at them -- confirm against the models before reordering.
TABLE_ORDER = [
    # Sources, datasets and bookkeeping.
    "source_catalog_entries",
    "sources",
    "datasets",
    "source_update_checks",
    "osm_diff_states",
    "jobs",
    "job_events",
    "pipeline_runs",
    # GTFS feed tables.
    "gtfs_agencies",
    "gtfs_stops",
    "gtfs_routes",
    "gtfs_trips",
    "gtfs_calendars",
    "gtfs_calendar_dates",
    "gtfs_shapes",
    "gtfs_stop_times",
    # OSM features.
    "osm_features",
    # Matching and canonical entities.
    "canonical_stops",
    "canonical_stop_links",
    "route_matches",
    "match_rules",
    "route_patterns",
    "route_pattern_stops",
    "gtfs_route_pattern_links",
    "gtfs_trip_route_pattern_links",
    # Travel requests and itineraries.
    "travel_requests",
    "itineraries",
    "itinerary_legs",
]
|
|
|
|
|
|
def main() -> int:
    """Migrate a Mobility Workbench SQLite database into PostgreSQL/PostGIS.

    Steps, in order:

    1. Parse the command line.  The target URL falls back to the
       ``POSTGRES_DATABASE_URL`` / ``DATABASE_URL`` environment variables and
       then to ``app.config.settings`` when that reports a PostgreSQL database.
    2. Prepare the target schema (``--reset`` drops and recreates it first).
    3. Copy every table in ``TABLE_ORDER`` that exists on both sides.
    4. Import rows stored in sidecar SQLite files and reset ``id`` sequences.
    5. Refresh PostGIS geometries, ensure runtime indexes and analyze the
       touched tables.

    Returns:
        ``0`` on success.  Invalid arguments end the process via
        ``parser.error``.

    Raises:
        FileNotFoundError: With ``--strict-sidecars`` when a referenced
            sidecar file is missing.
    """
    parser = argparse.ArgumentParser(description="Migrate a Mobility Workbench SQLite database to PostgreSQL/PostGIS.")
    parser.add_argument("--postgres-url", default=os.environ.get("POSTGRES_DATABASE_URL") or os.environ.get("DATABASE_URL"))
    parser.add_argument("--sqlite-path", default="data/workbench.sqlite")
    parser.add_argument("--reset", action="store_true", help="Drop and recreate the target PostgreSQL schema before copying.")
    parser.add_argument("--batch-size", type=int, default=100_000)
    parser.add_argument("--strict-sidecars", action="store_true", help="Fail when a referenced sidecar file is missing.")
    args = parser.parse_args()

    if not args.postgres_url:
        # Last resort: reuse the application's configured database URL.
        from app.config import settings as parsed_settings

        if parsed_settings.is_postgresql_database:
            args.postgres_url = parsed_settings.database_url

    if not args.postgres_url:
        parser.error("--postgres-url or POSTGRES_DATABASE_URL is required")
    if not str(args.postgres_url).startswith(("postgresql://", "postgresql+psycopg://")):
        parser.error("--postgres-url must be a PostgreSQL SQLAlchemy URL")

    sqlite_path = Path(args.sqlite_path)
    if not sqlite_path.exists():
        parser.error(f"SQLite database does not exist: {sqlite_path}")

    # Exported before the ``app`` imports below; presumably ``app.db`` builds
    # its engine from DATABASE_URL at import time -- verify before reordering.
    os.environ["DATABASE_URL"] = str(args.postgres_url)

    from app import models  # noqa: F401
    from app.db import Base, SessionLocal, _ensure_database_extensions, _ensure_runtime_columns, _ensure_runtime_indexes, engine, init_db
    from app.gtfs_storage import GTFS_STORAGE_METADATA_KEY, GTFS_STORAGE_MAIN, GTFS_STOP_TIME_COLUMNS
    from app.osm_storage import OSM_FEATURE_COLUMNS, OSM_STORAGE_MAIN, OSM_STORAGE_METADATA_KEY
    from app.spatial import analyze_postgresql_tables, refresh_postgis_geometries

    # Batches smaller than 1,000 rows are clamped up; computed once and shared
    # by the table copy and the sidecar copy.
    batch_size = max(1_000, int(args.batch_size))

    if args.reset:
        print("Resetting PostgreSQL schema without secondary indexes...")
        _ensure_database_extensions()
        Base.metadata.drop_all(bind=engine)
        Base.metadata.create_all(bind=engine)
        _ensure_runtime_columns()
    else:
        print("Initializing PostgreSQL schema...")
        init_db()

    source = sqlite3.connect(sqlite_path)
    source.row_factory = sqlite3.Row
    try:
        source_tables = _sqlite_tables(source)
        target_columns = {name: list(table.c.keys()) for name, table in Base.metadata.tables.items()}
        bool_columns = {
            name: _boolean_columns(table)
            for name, table in Base.metadata.tables.items()
        }

        import psycopg

        with psycopg.connect(_psycopg_url(str(args.postgres_url))) as pg:
            copied_tables: list[str] = []
            for table_name in TABLE_ORDER:
                # Only tables present in both the source file and the target
                # metadata can be copied.
                if table_name not in source_tables or table_name not in target_columns:
                    continue
                copied = _copy_sqlite_table(
                    source,
                    pg,
                    table_name=table_name,
                    target_columns=target_columns[table_name],
                    bool_columns=bool_columns.get(table_name, set()),
                    batch_size=batch_size,
                )
                copied_tables.append(table_name)
                print(f"Copied {copied:,} rows from {table_name}.")
                pg.commit()

            _reset_sequences(pg, target_columns)
            pg.commit()

            sidecar_results = _copy_sidecars(
                source,
                pg,
                sqlite_base_dir=sqlite_path.parent,
                batch_size=batch_size,
                strict=args.strict_sidecars,
                osm_columns=OSM_FEATURE_COLUMNS,
                gtfs_stop_time_columns=GTFS_STOP_TIME_COLUMNS,
                gtfs_storage_key=GTFS_STORAGE_METADATA_KEY,
                osm_storage_key=OSM_STORAGE_METADATA_KEY,
                gtfs_main_mode=GTFS_STORAGE_MAIN,
                osm_main_mode=OSM_STORAGE_MAIN,
            )
            # The sidecar import may have added rows, so sequences are reset
            # a second time.
            _reset_sequences(pg, target_columns)
            pg.commit()

        print("Refreshing PostGIS geometries and indexes...")
        with SessionLocal() as session:
            refresh_postgis_geometries(session, only_missing=False)
            session.commit()
        _ensure_runtime_indexes()

        # The sidecar targets are always analyzed.  De-duplicate while keeping
        # order so a table that was also copied directly is not listed twice.
        tables_to_analyze = list(dict.fromkeys([*copied_tables, "osm_features", "gtfs_stop_times"]))
        with SessionLocal() as session:
            analyze_postgresql_tables(session, tables_to_analyze)
            session.commit()

        print("Migration complete.")
        for message in sidecar_results:
            print(message)
        return 0
    finally:
        source.close()
|
|
|
|
|
|
def _copy_sqlite_table(
    source: sqlite3.Connection,
    pg,
    *,
    table_name: str,
    target_columns: list[str],
    bool_columns: set[str],
    batch_size: int,
) -> int:
    """Stream one SQLite table into the same-named PostgreSQL table.

    Only columns that exist in the SQLite table and are listed in
    ``target_columns`` are transferred; values of ``bool_columns`` are
    converted by ``_row_values``.

    Args:
        source: Open SQLite connection to read from.
        pg: Open PostgreSQL connection handed to ``_copy_rows``.
        table_name: Table to copy; the name is identical on both sides.
        target_columns: Column names available on the PostgreSQL table.
        bool_columns: Columns whose non-NULL values must become ``bool``.
        batch_size: Number of rows fetched per ``fetchmany`` call.

    Returns:
        Number of rows read from SQLite and written to PostgreSQL; ``0``
        when the two tables share no column.
    """
    shared_columns = [name for name in _sqlite_columns(source, table_name) if name in target_columns]
    if not shared_columns:
        return 0

    column_list = ", ".join(map(_quote_sqlite, shared_columns))
    reader = source.execute(f"SELECT {column_list} FROM {_quote_sqlite(table_name)}")
    copied = 0
    try:
        # ``fetchmany`` returns an empty list once the cursor is exhausted,
        # which is the sentinel that ends the iteration.
        for batch in iter(lambda: reader.fetchmany(batch_size), []):
            _copy_rows(
                pg,
                table_name=table_name,
                columns=shared_columns,
                rows=(_row_values(record, shared_columns, bool_columns) for record in batch),
            )
            copied += len(batch)
    finally:
        reader.close()
    return copied
|
|
|
|
|
|
def _copy_sidecars(
    source: sqlite3.Connection,
    pg,
    *,
    sqlite_base_dir: Path,
    batch_size: int,
    strict: bool,
    osm_columns: list[str],
    gtfs_stop_time_columns: list[str],
    gtfs_storage_key: str,
    osm_storage_key: str,
    gtfs_main_mode: str,
    osm_main_mode: str,
) -> list[str]:
    """Import sidecar-stored GTFS stop_times and OSM features per dataset.

    Every row of the source ``datasets`` table is inspected.  When its
    metadata marks ``gtfs_stop_times`` or ``osm_features`` as kept in a
    sidecar SQLite file, that file is copied into the main PostgreSQL table,
    the dataset metadata is rewritten to point at main storage, and the
    transaction is committed.

    Args:
        source: Main SQLite connection; its rows must support access by
            column name.
        pg: Open PostgreSQL connection.
        sqlite_base_dir: Directory of the main SQLite file, used to resolve
            relative sidecar paths.
        batch_size: Rows fetched per batch from a sidecar file.
        strict: Raise instead of reporting when a sidecar file is missing.
        osm_columns: Column list for ``osm_features``.
        gtfs_stop_time_columns: Column list for ``gtfs_stop_times``.
        gtfs_storage_key: Metadata key holding the GTFS storage descriptor.
        osm_storage_key: Metadata key holding the OSM storage descriptor.
        gtfs_main_mode: ``mode`` value written once GTFS rows are in main.
        osm_main_mode: ``mode`` value written once OSM rows are in main.

    Returns:
        Human-readable messages describing what was copied, skipped or
        found missing.

    Raises:
        FileNotFoundError: When ``strict`` is true and a sidecar is missing.
    """
    report: list[str] = []
    datasets = source.execute("SELECT id, kind, metadata_json FROM datasets ORDER BY id").fetchall()
    for dataset in datasets:
        dataset_id = int(dataset["id"])
        metadata = _json_dict(dataset["metadata_json"])

        # --- GTFS stop_times -------------------------------------------
        gtfs_info = metadata.get(gtfs_storage_key)
        if isinstance(gtfs_info, dict) and _storage_uses_sidecar(gtfs_info, "gtfs_stop_times"):
            gtfs_path = _resolve_sidecar_path(gtfs_info.get("sidecar_path"), sqlite_base_dir)
            if gtfs_path is not None and gtfs_path.exists():
                already_present = _pg_scalar(pg, "SELECT COUNT(*) FROM gtfs_stop_times WHERE dataset_id = %s", [dataset_id])
                if int(already_present or 0) > 0:
                    # Existing target rows are left untouched, and so is the
                    # dataset metadata.
                    report.append(f"Skipped GTFS sidecar for dataset #{dataset_id}; target already has stop_times rows.")
                else:
                    copied = _copy_gtfs_sidecar(pg, dataset_id, gtfs_path, gtfs_stop_time_columns, batch_size)
                    _mark_storage_main(metadata, gtfs_storage_key, "gtfs_stop_times", gtfs_main_mode, gtfs_path)
                    _update_dataset_metadata(pg, dataset_id, metadata)
                    pg.commit()
                    report.append(f"Copied {copied:,} GTFS stop_times rows from {gtfs_path}.")
            else:
                problem = f"Missing GTFS sidecar for dataset #{dataset_id}: {gtfs_path}"
                if strict:
                    raise FileNotFoundError(problem)
                report.append(problem)

        # --- OSM features ----------------------------------------------
        # Looked up after the GTFS step so it sees any metadata changes that
        # step made.
        osm_info = metadata.get(osm_storage_key)
        if isinstance(osm_info, dict) and _storage_uses_sidecar(osm_info, "osm_features"):
            osm_path = _resolve_sidecar_path(osm_info.get("sidecar_path"), sqlite_base_dir)
            if osm_path is not None and osm_path.exists():
                copied, inserted = _copy_osm_sidecar(pg, dataset_id, osm_path, osm_columns, batch_size)
                _mark_storage_main(metadata, osm_storage_key, "osm_features", osm_main_mode, osm_path)
                _update_dataset_metadata(pg, dataset_id, metadata)
                pg.commit()
                report.append(f"Copied {copied:,} OSM sidecar rows from {osm_path}; inserted {inserted:,} new main rows.")
            else:
                problem = f"Missing OSM sidecar for dataset #{dataset_id}: {osm_path}"
                if strict:
                    raise FileNotFoundError(problem)
                report.append(problem)
    return report
|
|
|
|
|
|
def _copy_gtfs_sidecar(pg, dataset_id: int, path: Path, columns: list[str], batch_size: int) -> int:
    """Copy ``gtfs_stop_times`` rows from a sidecar SQLite file into PostgreSQL.

    Every copied row is stamped with ``dataset_id``.  Columns listed in
    ``columns`` but absent from the sidecar table are written as NULL.

    Args:
        pg: Open PostgreSQL connection handed to ``_copy_rows``.
        dataset_id: Dataset the sidecar rows belong to.
        path: Location of the sidecar SQLite file.
        columns: Target columns to fill.  A ``dataset_id`` entry is ignored
            because that value always comes from the ``dataset_id``
            argument; this mirrors ``_copy_osm_sidecar`` and avoids naming
            the column twice in the COPY statement.
        batch_size: Rows fetched per ``fetchmany`` call.

    Returns:
        Number of rows copied.
    """
    source = sqlite3.connect(path)
    source.row_factory = sqlite3.Row
    try:
        available = set(_sqlite_columns(source, "gtfs_stop_times"))
        payload_columns = [column for column in columns if column != "dataset_id"]
        select_columns = [
            (_quote_sqlite(column) if column in available else f"NULL AS {_quote_sqlite(column)}")
            for column in payload_columns
        ]
        # The COPY column list is built in the same order as the row values.
        copy_columns = ["dataset_id", *payload_columns]
        total = 0
        cursor = source.execute(f"SELECT {', '.join(select_columns)} FROM gtfs_stop_times")
        try:
            while True:
                rows = cursor.fetchmany(batch_size)
                if not rows:
                    break
                _copy_rows(
                    pg,
                    table_name="gtfs_stop_times",
                    columns=copy_columns,
                    rows=([dataset_id, *[row[column] for column in payload_columns]] for row in rows),
                )
                total += len(rows)
        finally:
            cursor.close()
        return total
    finally:
        source.close()
|
|
|
|
|
|
def _copy_osm_sidecar(pg, dataset_id: int, path: Path, columns: list[str], batch_size: int) -> tuple[int, int]:
    """Merge ``osm_features`` rows from a sidecar SQLite file into PostgreSQL.

    Rows are first streamed into a transaction-scoped temporary table and then
    inserted into ``osm_features``; rows that collide with the
    ``uq_osm_feature_dataset_type_id`` constraint are skipped.

    Args:
        pg: Open PostgreSQL connection.
        dataset_id: Dataset the sidecar rows belong to; written to every row.
        path: Location of the sidecar SQLite file.
        columns: Target columns to fill.  ``dataset_id`` may appear at any
            position (or not at all): the statements always use
            ``dataset_id`` followed by the remaining columns, so the column
            list and the row values cannot get out of step.
        batch_size: Rows fetched per ``fetchmany`` call.

    Returns:
        ``(copied, inserted)``: rows read from the sidecar, and rows actually
        added to ``osm_features`` as reported by the cursor's ``rowcount``.
    """
    source = sqlite3.connect(path)
    source.row_factory = sqlite3.Row
    temp_table = "tmp_osm_sidecar_features"
    try:
        available = set(_sqlite_columns(source, "osm_features"))
        payload_columns = [column for column in columns if column != "dataset_id"]
        select_columns = [
            (_quote_sqlite(column) if column in available else f"NULL AS {_quote_sqlite(column)}")
            for column in payload_columns
        ]
        # Same order as the values produced for each row below.
        copy_columns = ["dataset_id", *payload_columns]
        with pg.cursor() as cur:
            cur.execute(f"DROP TABLE IF EXISTS pg_temp.{_quote_pg(temp_table)}")
            cur.execute(f"CREATE TEMP TABLE {temp_table} (LIKE osm_features INCLUDING DEFAULTS) ON COMMIT DROP")
        copied = 0
        cursor = source.execute(f"SELECT {', '.join(select_columns)} FROM osm_features")
        try:
            while True:
                rows = cursor.fetchmany(batch_size)
                if not rows:
                    break
                _copy_rows(
                    pg,
                    table_name=temp_table,
                    columns=copy_columns,
                    rows=([dataset_id, *[row[column] for column in payload_columns]] for row in rows),
                )
                copied += len(rows)
        finally:
            cursor.close()
        with pg.cursor() as cur:
            column_sql = ", ".join(_quote_pg(column) for column in copy_columns)
            cur.execute(
                f"""
                INSERT INTO osm_features ({column_sql})
                SELECT {column_sql}
                FROM {temp_table}
                ON CONFLICT ON CONSTRAINT uq_osm_feature_dataset_type_id DO NOTHING
                """
            )
            inserted = int(cur.rowcount or 0)
            cur.execute(f"DROP TABLE IF EXISTS pg_temp.{_quote_pg(temp_table)}")
        return copied, inserted
    finally:
        source.close()
|
|
|
|
|
|
def _copy_rows(pg, *, table_name: str, columns: list[str], rows: Iterable[Iterable[Any]]) -> None:
    """Write ``rows`` into ``table_name`` with a single ``COPY ... FROM STDIN``.

    Args:
        pg: Open PostgreSQL connection providing ``cursor()``.
        table_name: Target table; quoted before use.
        columns: Target columns, in the order of the values in each row.
        rows: Iterable of rows; each row is materialised as a list before
            being handed to ``write_row``.
    """
    quoted_columns = ", ".join(map(_quote_pg, columns))
    statement = f"COPY {_quote_pg(table_name)} ({quoted_columns}) FROM STDIN"
    with pg.cursor() as cur, cur.copy(statement) as copy:
        for values in rows:
            copy.write_row(list(values))
|
|
|
|
|
|
def _reset_sequences(pg, target_columns: dict[str, list[str]]) -> None:
    """Align each table's ``id`` sequence with the rows now in the table.

    For every table that has an ``id`` column backed by a sequence (as
    reported by ``pg_get_serial_sequence``), the sequence is set to
    ``MAX(id)``.  For an empty table it is set to ``1`` with ``is_called``
    false.

    Args:
        pg: Open PostgreSQL connection providing ``cursor()``.
        target_columns: Mapping of table name to its column names.
    """
    tables_with_id = [name for name, names in target_columns.items() if "id" in names]
    with pg.cursor() as cur:
        for table_name in tables_with_id:
            cur.execute("SELECT pg_get_serial_sequence(%s, 'id')", [table_name])
            found = cur.fetchone()
            sequence_name = found[0] if found else None
            if not sequence_name:
                # ``id`` is not sequence-backed; nothing to reset.
                continue
            quoted_table = _quote_pg(table_name)
            cur.execute(
                f"""
                SELECT setval(
                    %s,
                    COALESCE((SELECT MAX(id) FROM {quoted_table}), 1),
                    (SELECT MAX(id) IS NOT NULL FROM {quoted_table})
                )
                """,
                [sequence_name],
            )
|
|
|
|
|
|
def _mark_storage_main(metadata: dict[str, Any], key: str, table_name: str, mode: str, sidecar_path: Path) -> None:
|
|
storage = metadata.setdefault(key, {})
|
|
if not isinstance(storage, dict):
|
|
storage = {}
|
|
metadata[key] = storage
|
|
storage["mode"] = mode
|
|
storage["tables"] = {table_name: "main"}
|
|
storage["storage_status"] = "ready"
|
|
storage["legacy_sidecar_path"] = str(sidecar_path)
|
|
storage.pop("sidecar_path", None)
|
|
storage.pop("sidecar_status", None)
|
|
|
|
|
|
def _update_dataset_metadata(pg, dataset_id: int, metadata: dict[str, Any]) -> None:
    """Store ``metadata`` as compact JSON in ``datasets.metadata_json``.

    Args:
        pg: Open PostgreSQL connection providing ``cursor()``.
        dataset_id: Primary key of the dataset row to update.
        metadata: Metadata to serialise; written without extra whitespace.
    """
    statement = "UPDATE datasets SET metadata_json = %s WHERE id = %s"
    with pg.cursor() as cur:
        encoded = json.dumps(metadata, separators=(",", ":"))
        cur.execute(statement, [encoded, dataset_id])
|
|
|
|
|
|
def _pg_scalar(pg, sql: str, params: list[Any]) -> Any:
    """Run ``sql`` and return the first column of the first result row.

    Args:
        pg: Open PostgreSQL connection providing ``cursor()``.
        sql: Statement to execute.
        params: Parameters bound to the statement.

    Returns:
        The first value of the first row, or ``None`` when no row is returned.
    """
    with pg.cursor() as cur:
        cur.execute(sql, params)
        first_row = cur.fetchone()
    if not first_row:
        return None
    return first_row[0]
|
|
|
|
|
|
def _sqlite_tables(connection: sqlite3.Connection) -> set[str]:
|
|
return {
|
|
str(row["name"])
|
|
for row in connection.execute("SELECT name FROM sqlite_master WHERE type = 'table'").fetchall()
|
|
}
|
|
|
|
|
|
def _sqlite_columns(connection: sqlite3.Connection, table_name: str) -> list[str]:
    """Return the column names of ``table_name`` in declaration order.

    ``PRAGMA table_info`` yields ``(cid, name, type, notnull, dflt_value, pk)``
    rows; the name is read by position so the function does not depend on
    ``connection.row_factory`` being ``sqlite3.Row``.

    Args:
        connection: Open SQLite connection.
        table_name: Table to inspect; quoted before use.

    Returns:
        Column names, or an empty list when the table does not exist.
    """
    cursor = connection.execute(f"PRAGMA table_info({_quote_sqlite(table_name)})")
    return [str(row[1]) for row in cursor.fetchall()]
|
|
|
|
|
|
def _row_values(row: sqlite3.Row, columns: list[str], bool_columns: set[str]) -> list[Any]:
|
|
values: list[Any] = []
|
|
for column in columns:
|
|
value = row[column]
|
|
if column in bool_columns and value is not None:
|
|
value = bool(value)
|
|
values.append(value)
|
|
return values
|
|
|
|
|
|
def _boolean_columns(table) -> set[str]:
|
|
columns: set[str] = set()
|
|
for column in table.c:
|
|
try:
|
|
if column.type.python_type is bool:
|
|
columns.add(str(column.name))
|
|
except NotImplementedError:
|
|
continue
|
|
return columns
|
|
|
|
|
|
def _storage_uses_sidecar(storage: dict[str, Any], table_name: str) -> bool:
|
|
tables = storage.get("tables")
|
|
if isinstance(tables, dict) and tables.get(table_name) == "sidecar":
|
|
return True
|
|
return str(storage.get("mode") or "").startswith("sidecar")
|
|
|
|
|
|
def _resolve_sidecar_path(value: Any, base_dir: Path) -> Path | None:
|
|
if not value:
|
|
return None
|
|
path = Path(str(value))
|
|
if path.is_absolute():
|
|
return path
|
|
if path.exists():
|
|
return path
|
|
source_relative = base_dir / path
|
|
if source_relative.exists():
|
|
return source_relative
|
|
repo_relative = base_dir.parent / path
|
|
if repo_relative.exists():
|
|
return repo_relative
|
|
return path
|
|
|
|
|
|
def _json_dict(value: str | None) -> dict[str, Any]:
|
|
try:
|
|
data = json.loads(value or "{}")
|
|
except json.JSONDecodeError:
|
|
return {}
|
|
return data if isinstance(data, dict) else {}
|
|
|
|
|
|
def _psycopg_url(value: str) -> str:
|
|
return value.replace("postgresql+psycopg://", "postgresql://", 1)
|
|
|
|
|
|
def _quote_pg(identifier: str) -> str:
|
|
return '"' + identifier.replace('"', '""') + '"'
|
|
|
|
|
|
def _quote_sqlite(identifier: str) -> str:
|
|
return '"' + identifier.replace('"', '""') + '"'
|
|
|
|
|
|
# Script entry point: the return value of ``main`` becomes the exit status.
if __name__ == "__main__":
    raise SystemExit(main())
|