from __future__ import annotations import csv import io import json import zipfile from pathlib import Path from datetime import datetime, timezone from sqlalchemy import delete, select from sqlalchemy.orm import Session from app.config import settings from app.db import init_db from app.models import ( Dataset, CanonicalStop, CanonicalStopLink, GtfsAgency, GtfsCalendar, GtfsCalendarDate, GtfsRoute, GtfsRoutePatternLink, GtfsShape, GtfsStop, GtfsStopTime, GtfsTripRoutePatternLink, GtfsTrip, Itinerary, ItineraryLeg, Job, JobEvent, MatchRule, OsmDiffState, OsmFeature, PipelineRun, RouteMatch, RoutePattern, RoutePatternStop, RoutingEdge, RoutingNode, Source, SourceCatalogEntry, SourceUpdateCheck, TravelRequest, ) from app.pipeline.matcher import run_route_matching from app.pipeline.route_layer import rebuild_route_layer from app.pipeline.run import run_source def load_sample_project(session: Session, *, preserve_job_id: int | None = None) -> dict: """Clear the DB, create a small Berlin-like GTFS + OSM sample, import, and match.""" init_db() clear_project_data(session, preserve_job_id=preserve_job_id, preserve_catalog=True) sample_dir = settings.data_dir / "sample" sample_dir.mkdir(parents=True, exist_ok=True) gtfs_path = sample_dir / "sample_berlin.gtfs.zip" osm_path = sample_dir / "sample_berlin_osm.geojson" create_sample_gtfs(gtfs_path) create_sample_osm_geojson(osm_path) gtfs_source = Source(name="Sample Berlin GTFS", kind="gtfs", url=str(gtfs_path), country="DE", license="sample") osm_source = Source(name="Sample Berlin OSM transport", kind="osm_geojson", url=str(osm_path), country="DE", license="sample") session.add_all([gtfs_source, osm_source]) session.flush() gtfs_dataset = run_source(session, gtfs_source) osm_dataset = run_source(session, osm_source) match_result = run_route_matching(session) route_layer_result = rebuild_route_layer(session) return { "status": "ok", "gtfs_dataset_id": gtfs_dataset.id, "osm_dataset_id": osm_dataset.id, "match_result": match_result, "route_layer_result": route_layer_result, } def clear_project_data( session: Session, *, preserve_job_id: int | None = None, preserve_catalog: bool = True, ) -> None: """Clear user/project data while optionally preserving the current queue job.""" session.execute(delete(PipelineRun)) if preserve_job_id is None: session.execute(delete(JobEvent)) session.execute(delete(Job)) else: _cancel_other_jobs_for_reset(session, preserve_job_id) for model in [ ItineraryLeg, Itinerary, TravelRequest, SourceUpdateCheck, OsmDiffState, MatchRule, RouteMatch, GtfsTripRoutePatternLink, GtfsRoutePatternLink, RoutePatternStop, RoutePattern, CanonicalStopLink, CanonicalStop, RoutingEdge, RoutingNode, GtfsStopTime, GtfsCalendarDate, GtfsCalendar, GtfsShape, GtfsTrip, GtfsRoute, GtfsStop, GtfsAgency, OsmFeature, Dataset, Source, ]: session.execute(delete(model)) if not preserve_catalog: session.execute(delete(SourceCatalogEntry)) session.flush() def _cancel_other_jobs_for_reset(session: Session, preserve_job_id: int) -> None: now = datetime.now(timezone.utc) active_statuses = {"queued", "running", "paused"} jobs = session.scalars( select(Job).where(Job.id != preserve_job_id, Job.status.in_(active_statuses)) ).all() for job in jobs: job.status = "cancelled" job.requested_action = None job.lease_owner = None job.lease_expires_at = None job.paused_at = None job.error = None job.updated_at = now job.finished_at = now session.add( JobEvent( job_id=job.id, event_type="cancelled_by_reset", message=f"Job cancelled by reset job #{preserve_job_id}.", progress_current=job.progress_current, progress_total=job.progress_total, ) ) def create_sample_gtfs(path: Path) -> None: agencies = [ {"agency_id": "BVG", "agency_name": "BVG", "agency_url": "https://example.invalid/bvg", "agency_timezone": "Europe/Berlin"}, {"agency_id": "DB", "agency_name": "DB Regio", "agency_url": "https://example.invalid/db", "agency_timezone": "Europe/Berlin"}, {"agency_id": "XAIR", "agency_name": "Example Airport Shuttle", "agency_url": "https://example.invalid/xair", "agency_timezone": "Europe/Berlin"}, ] stops = [ {"stop_id": "hbf", "stop_name": "Berlin Hauptbahnhof", "stop_lat": "52.5251", "stop_lon": "13.3696"}, {"stop_id": "friedrich", "stop_name": "Friedrichstraße", "stop_lat": "52.5201", "stop_lon": "13.3862"}, {"stop_id": "alex", "stop_name": "Alexanderplatz", "stop_lat": "52.5219", "stop_lon": "13.4132"}, {"stop_id": "ost", "stop_name": "Ostbahnhof", "stop_lat": "52.5100", "stop_lon": "13.4344"}, {"stop_id": "zoo", "stop_name": "Zoologischer Garten", "stop_lat": "52.5069", "stop_lon": "13.3320"}, {"stop_id": "wittenberg", "stop_name": "Wittenbergplatz", "stop_lat": "52.5020", "stop_lon": "13.3430"}, {"stop_id": "potsdamer", "stop_name": "Potsdamer Platz", "stop_lat": "52.5096", "stop_lon": "13.3760"}, {"stop_id": "stadtmitte", "stop_name": "Stadtmitte", "stop_lat": "52.5113", "stop_lon": "13.3907"}, {"stop_id": "reichstag", "stop_name": "Reichstag", "stop_lat": "52.5186", "stop_lon": "13.3763"}, {"stop_id": "hackescher", "stop_name": "Hackescher Markt", "stop_lat": "52.5220", "stop_lon": "13.4023"}, {"stop_id": "naturkunde", "stop_name": "Naturkundemuseum", "stop_lat": "52.5300", "stop_lon": "13.3790"}, {"stop_id": "wannsee", "stop_name": "Wannsee", "stop_lat": "52.4210", "stop_lon": "13.1797"}, {"stop_id": "kladow", "stop_name": "Kladow", "stop_lat": "52.4547", "stop_lon": "13.1439"}, {"stop_id": "airport", "stop_name": "Example Airport Terminal", "stop_lat": "52.3650", "stop_lon": "13.5100"}, ] routes = [ {"route_id": "u2", "agency_id": "BVG", "route_short_name": "U2", "route_long_name": "Pankow - Ruhleben", "route_type": "1"}, {"route_id": "re1", "agency_id": "DB", "route_short_name": "RE1", "route_long_name": "Magdeburg - Frankfurt Oder", "route_type": "2"}, {"route_id": "m5", "agency_id": "BVG", "route_short_name": "M5", "route_long_name": "Hauptbahnhof - Hohenschönhausen", "route_type": "0"}, {"route_id": "bus100", "agency_id": "BVG", "route_short_name": "100", "route_long_name": "Zoologischer Garten - Alexanderplatz", "route_type": "3"}, {"route_id": "f10", "agency_id": "BVG", "route_short_name": "F10", "route_long_name": "Wannsee - Kladow", "route_type": "4"}, {"route_id": "x99", "agency_id": "XAIR", "route_short_name": "X99", "route_long_name": "Airport Express Sample", "route_type": "3"}, ] trips = [ {"route_id": r["route_id"], "service_id": "daily", "trip_id": f"{r['route_id']}_trip", "shape_id": f"{r['route_id']}_shape"} for r in routes ] stop_sequences = { "u2_trip": ["zoo", "wittenberg", "potsdamer", "stadtmitte", "alex"], "re1_trip": ["hbf", "friedrich", "alex", "ost"], "m5_trip": ["hbf", "naturkunde", "hackescher", "alex"], "bus100_trip": ["zoo", "reichstag", "alex"], "f10_trip": ["wannsee", "kladow"], "x99_trip": ["alex", "airport"], } coords = {row["stop_id"]: (row["stop_lon"], row["stop_lat"]) for row in stops} stop_times = [] shapes = [] for trip in trips: trip_id = trip["trip_id"] for idx, stop_id in enumerate(stop_sequences[trip_id], start=1): stop_times.append( { "trip_id": trip_id, "arrival_time": f"08:{idx * 5:02d}:00", "departure_time": f"08:{idx * 5 + 1:02d}:00", "stop_id": stop_id, "stop_sequence": str(idx), } ) lon, lat = coords[stop_id] shapes.append( { "shape_id": trip["shape_id"], "shape_pt_lat": lat, "shape_pt_lon": lon, "shape_pt_sequence": str(idx), } ) calendar = [ { "service_id": "daily", "monday": "1", "tuesday": "1", "wednesday": "1", "thursday": "1", "friday": "1", "saturday": "1", "sunday": "1", "start_date": "20260101", "end_date": "20261231", } ] with zipfile.ZipFile(path, "w", compression=zipfile.ZIP_DEFLATED) as zf: _write_csv(zf, "agency.txt", ["agency_id", "agency_name", "agency_url", "agency_timezone"], agencies) _write_csv(zf, "stops.txt", ["stop_id", "stop_name", "stop_lat", "stop_lon"], stops) _write_csv(zf, "routes.txt", ["route_id", "agency_id", "route_short_name", "route_long_name", "route_type"], routes) _write_csv(zf, "trips.txt", ["route_id", "service_id", "trip_id", "shape_id"], trips) _write_csv(zf, "stop_times.txt", ["trip_id", "arrival_time", "departure_time", "stop_id", "stop_sequence"], stop_times) _write_csv( zf, "calendar.txt", ["service_id", "monday", "tuesday", "wednesday", "thursday", "friday", "saturday", "sunday", "start_date", "end_date"], calendar, ) _write_csv(zf, "shapes.txt", ["shape_id", "shape_pt_lat", "shape_pt_lon", "shape_pt_sequence"], shapes) def _write_csv(zf: zipfile.ZipFile, name: str, fields: list[str], rows: list[dict[str, str]]) -> None: buffer = io.StringIO(newline="") writer = csv.DictWriter(buffer, fieldnames=fields) writer.writeheader() writer.writerows(rows) zf.writestr(name, buffer.getvalue()) def create_sample_osm_geojson(path: Path) -> None: def line(fid, mode, ref, name, operator, coords): return { "type": "Feature", "geometry": {"type": "LineString", "coordinates": coords}, "properties": { "osm_type": "relation", "osm_id": str(fid), "type": "route", "route": mode, "ref": ref, "name": name, "operator": operator, "network": "VBB" if operator == "BVG" else "DB", }, } def point(fid, kind, name, coords, props=None): props = props or {} props.update({"osm_type": "node", "osm_id": str(fid), "name": name}) return {"type": "Feature", "geometry": {"type": "Point", "coordinates": coords}, "properties": props} features = [ line(1002, "subway", "U2", "U2 Ruhleben - Pankow", "BVG", [[13.3320, 52.5069], [13.3430, 52.5020], [13.3760, 52.5096], [13.3907, 52.5113], [13.4132, 52.5219]]), line(2001, "train", "RE1", "RE1 Magdeburg - Frankfurt Oder", "DB Regio", [[13.3696, 52.5251], [13.3862, 52.5201], [13.4132, 52.5219], [13.4344, 52.5100]]), line(5005, "tram", "M5", "M5 Hauptbahnhof - Hohenschönhausen", "BVG", [[13.3696, 52.5251], [13.3790, 52.5300], [13.4023, 52.5220], [13.4132, 52.5219]]), line(6100, "bus", "100", "Bus 100 Zoologischer Garten - Alexanderplatz", "BVG", [[13.3320, 52.5069], [13.3763, 52.5186], [13.4132, 52.5219]]), line(7010, "ferry", "F10", "F10 Wannsee - Kladow", "BVG", [[13.1797, 52.4210], [13.1439, 52.4547]]), line(5010, "tram", "M10", "M10 Warschauer Straße - Hauptbahnhof", "BVG", [[13.4500, 52.5050], [13.4020, 52.5300], [13.3696, 52.5251]]), point(9001, "station", "Berlin Hauptbahnhof", [13.3696, 52.5251], {"railway": "station"}), point(9002, "station", "Alexanderplatz", [13.4132, 52.5219], {"railway": "station"}), point(9003, "stop", "Zoologischer Garten", [13.3320, 52.5069], {"public_transport": "station", "railway": "station"}), point(9004, "terminal", "Wannsee Ferry Terminal", [13.1797, 52.4210], {"amenity": "ferry_terminal"}), point(9005, "terminal", "Kladow Ferry Terminal", [13.1439, 52.4547], {"amenity": "ferry_terminal"}), ] path.write_text(json.dumps({"type": "FeatureCollection", "features": features}, indent=2), encoding="utf-8")