"""Helpers that populate PostGIS geometry columns from stored coordinates/GeoJSON.

Every public function is a no-op unless the configured database is PostgreSQL.
Statements are executed on the caller's session; nothing here commits.
"""

from __future__ import annotations

from collections.abc import Iterable

from sqlalchemy import text
from sqlalchemy.orm import Session

from app.config import settings

# Every table whose geometry column(s) this module knows how to refresh.
POSTGIS_GEOMETRY_TABLES = {
    "osm_features",
    "gtfs_routes",
    "gtfs_shapes",
    "gtfs_stops",
    "canonical_stops",
    "route_patterns",
    "osm_addresses",
    "routing_nodes",
    "routing_edges",
}

# Tables refreshed from a ``geometry_geojson`` column, in execution order.
_GEOJSON_TABLES = ("osm_features", "gtfs_routes", "gtfs_shapes", "route_patterns")

# Tables refreshed from ``lon`` / ``lat`` columns, in execution order.
_POINT_TABLES = ("gtfs_stops", "canonical_stops", "routing_nodes")

# Tables that are always refreshed without a ``dataset_id`` filter, whatever
# the caller passed.
# NOTE(review): presumably these tables are not dataset-scoped -- confirm
# against the schema.
_UNFILTERED_TABLES = frozenset({"route_patterns", "canonical_stops"})


def using_postgresql() -> bool:
    """Return whether the application settings report a PostgreSQL database."""
    return settings.is_postgresql_database


def refresh_postgis_geometries(
    session: Session,
    *,
    dataset_id: int | None = None,
    tables: Iterable[str] | None = None,
    only_missing: bool = True,
) -> None:
    """Recompute geometry columns for the selected tables.

    Does nothing (and performs no validation) when the database is not
    PostgreSQL.

    Args:
        session: Session the UPDATE statements are executed on.
        dataset_id: When given, restrict updates to rows of that dataset.
            Ignored for the tables listed in ``_UNFILTERED_TABLES``.
        tables: Table names to refresh. ``None`` -- or any other falsy value
            such as an empty list -- selects every table in
            ``POSTGIS_GEOMETRY_TABLES``.
        only_missing: When true, only rows whose target geometry column is
            still NULL are updated.

    Raises:
        ValueError: If ``tables`` names a table outside
            ``POSTGIS_GEOMETRY_TABLES``.
    """
    if not using_postgresql():
        return

    selected = set(tables or POSTGIS_GEOMETRY_TABLES)
    unknown = selected - POSTGIS_GEOMETRY_TABLES
    if unknown:
        raise ValueError(f"Unsupported PostGIS geometry table(s): {', '.join(sorted(unknown))}")

    def scoped_dataset_id(table: str) -> int | None:
        # Unfiltered tables are refreshed across all rows.
        return None if table in _UNFILTERED_TABLES else dataset_id

    # Order below mirrors the historical execution order: GeoJSON-backed
    # tables, then addresses, then point tables, then routing edges.
    for table in _GEOJSON_TABLES:
        if table in selected:
            _refresh_geojson_geometry(
                session,
                table,
                dataset_id=scoped_dataset_id(table),
                only_missing=only_missing,
            )

    if "osm_addresses" in selected:
        _refresh_address_geometry(session, dataset_id=dataset_id, only_missing=only_missing)

    for table in _POINT_TABLES:
        if table in selected:
            _refresh_point_geometry(
                session,
                table,
                dataset_id=scoped_dataset_id(table),
                only_missing=only_missing,
            )

    if "routing_edges" in selected:
        _refresh_routing_edge_geometry(session, dataset_id=dataset_id, only_missing=only_missing)


def analyze_postgresql_tables(session: Session, tables: Iterable[str]) -> None:
    """Run ``ANALYZE`` on each given table; a no-op when not on PostgreSQL.

    Args:
        session: Session the statements are executed on.
        tables: Table names to analyze, executed one statement per name.
    """
    if not using_postgresql():
        return
    for table in tables:
        # NOTE(review): the name is interpolated verbatim (identifiers cannot
        # be bound parameters), so callers must only pass trusted table names.
        session.execute(text(f"ANALYZE {table}"))


def _build_filter(
    conditions: Iterable[str],
    *,
    dataset_id: int | None,
    only_missing: bool,
    column: str = "geom",
) -> tuple[str, dict[str, object]]:
    """Build the WHERE clause and bound parameters shared by all refreshers.

    Args:
        conditions: Base SQL predicates that must always hold.
        dataset_id: When not ``None``, adds a bound ``dataset_id`` equality.
        only_missing: When true, adds ``<column> IS NULL``.
        column: Geometry column tested by the ``only_missing`` predicate.

    Returns:
        The predicates joined with ``AND`` and the parameter dict to bind.
    """
    where = list(conditions)
    params: dict[str, object] = {}
    if dataset_id is not None:
        where.append("dataset_id = :dataset_id")
        params["dataset_id"] = int(dataset_id)
    if only_missing:
        where.append(f"{column} IS NULL")
    return " AND ".join(where), params


def _refresh_geojson_geometry(
    session: Session,
    table: str,
    *,
    dataset_id: int | None,
    only_missing: bool,
    column: str = "geom",
) -> None:
    """Set ``column`` from the row's non-empty ``geometry_geojson`` (SRID 4326).

    ``table`` and ``column`` are interpolated into the SQL text, so they must
    be trusted identifiers; only ``dataset_id`` is sent as a bound parameter.
    """
    where_sql, params = _build_filter(
        ("geometry_geojson IS NOT NULL", "geometry_geojson <> ''"),
        dataset_id=dataset_id,
        only_missing=only_missing,
        column=column,
    )
    session.execute(
        text(
            f"""
            UPDATE {table}
            SET {column} = ST_SetSRID(ST_GeomFromGeoJSON(geometry_geojson), 4326)
            WHERE {where_sql}
            """
        ),
        params,
    )


def _refresh_point_geometry(
    session: Session,
    table: str,
    *,
    dataset_id: int | None,
    only_missing: bool,
) -> None:
    """Set ``geom`` to a point built from the row's ``lon`` / ``lat`` (SRID 4326).

    Rows with a NULL ``lon`` or ``lat`` are skipped. ``table`` is interpolated
    into the SQL text and must be a trusted identifier.
    """
    where_sql, params = _build_filter(
        ("lon IS NOT NULL", "lat IS NOT NULL"),
        dataset_id=dataset_id,
        only_missing=only_missing,
    )
    session.execute(
        text(
            f"""
            UPDATE {table}
            SET geom = ST_SetSRID(ST_MakePoint(lon, lat), 4326)
            WHERE {where_sql}
            """
        ),
        params,
    )


def _refresh_address_geometry(
    session: Session,
    *,
    dataset_id: int | None,
    only_missing: bool,
) -> None:
    """Refresh both geometry columns of ``osm_addresses``.

    ``geom`` is built from ``lon`` / ``lat`` first, then ``area_geom`` is built
    from ``geometry_geojson``; ``only_missing`` is evaluated per column.
    """
    _refresh_point_geometry(
        session,
        "osm_addresses",
        dataset_id=dataset_id,
        only_missing=only_missing,
    )
    _refresh_geojson_geometry(
        session,
        "osm_addresses",
        dataset_id=dataset_id,
        only_missing=only_missing,
        column="area_geom",
    )


def _refresh_routing_edge_geometry(
    session: Session,
    *,
    dataset_id: int | None,
    only_missing: bool,
) -> None:
    """Set ``routing_edges.geom`` to the source-to-target line (SRID 4326).

    Rows missing any of the four endpoint coordinates are skipped.
    """
    where_sql, params = _build_filter(
        (
            "source_lon IS NOT NULL",
            "source_lat IS NOT NULL",
            "target_lon IS NOT NULL",
            "target_lat IS NOT NULL",
        ),
        dataset_id=dataset_id,
        only_missing=only_missing,
    )
    session.execute(
        text(
            f"""
            UPDATE routing_edges
            SET geom = ST_SetSRID(
                ST_MakeLine(
                    ST_MakePoint(source_lon, source_lat),
                    ST_MakePoint(target_lon, target_lat)
                ),
                4326
            )
            WHERE {where_sql}
            """
        ),
        params,
    )