Files
meubility-workbench/app/pipeline/sample_data.py
2026-07-01 23:29:51 +02:00

295 lines
12 KiB
Python

from __future__ import annotations
import csv
import io
import json
import zipfile
from pathlib import Path
from datetime import datetime, timezone
from sqlalchemy import delete, select
from sqlalchemy.orm import Session
from app.config import settings
from app.db import init_db
from app.models import (
Dataset,
CanonicalStop,
CanonicalStopLink,
GtfsAgency,
GtfsCalendar,
GtfsCalendarDate,
GtfsRoute,
GtfsRoutePatternLink,
GtfsShape,
GtfsStop,
GtfsStopTime,
GtfsTripRoutePatternLink,
GtfsTrip,
Itinerary,
ItineraryLeg,
Job,
JobEvent,
MatchRule,
OsmDiffState,
OsmFeature,
PipelineRun,
RouteMatch,
RoutePattern,
RoutePatternStop,
RoutingEdge,
RoutingNode,
Source,
SourceCatalogEntry,
SourceUpdateCheck,
TravelRequest,
)
from app.pipeline.matcher import run_route_matching
from app.pipeline.route_layer import rebuild_route_layer
from app.pipeline.run import run_source
def load_sample_project(session: Session, *, preserve_job_id: int | None = None) -> dict:
    """Reset the project and rebuild it from a freshly generated Berlin-like sample.

    Steps performed:

    1. Initialise the schema.
    2. Clear all project data. The source catalog is always kept; the job
       ``preserve_job_id`` is also kept when given.
    3. Write a sample GTFS zip and an OSM GeoJSON file into
       ``settings.data_dir / "sample"``.
    4. Register both files as ``Source`` rows and import each with ``run_source``.
    5. Run route matching and rebuild the route layer.

    Returns:
        A summary dict with keys ``status`` (always ``"ok"``), ``gtfs_dataset_id``,
        ``osm_dataset_id``, ``match_result`` and ``route_layer_result``.
    """
    init_db()
    clear_project_data(session, preserve_job_id=preserve_job_id, preserve_catalog=True)

    sample_root = settings.data_dir / "sample"
    sample_root.mkdir(parents=True, exist_ok=True)

    gtfs_file = sample_root / "sample_berlin.gtfs.zip"
    osm_file = sample_root / "sample_berlin_osm.geojson"
    create_sample_gtfs(gtfs_file)
    create_sample_osm_geojson(osm_file)

    # Both sample sources carry the same country/license metadata.
    shared_meta = {"country": "DE", "license": "sample"}
    gtfs_src = Source(name="Sample Berlin GTFS", kind="gtfs", url=str(gtfs_file), **shared_meta)
    osm_src = Source(name="Sample Berlin OSM transport", kind="osm_geojson", url=str(osm_file), **shared_meta)
    session.add_all([gtfs_src, osm_src])
    # Flush so both sources are persisted before the importers run against them.
    session.flush()

    gtfs_dataset = run_source(session, gtfs_src)
    osm_dataset = run_source(session, osm_src)
    matching = run_route_matching(session)
    route_layer = rebuild_route_layer(session)

    return dict(
        status="ok",
        gtfs_dataset_id=gtfs_dataset.id,
        osm_dataset_id=osm_dataset.id,
        match_result=matching,
        route_layer_result=route_layer,
    )
def clear_project_data(
    session: Session,
    *,
    preserve_job_id: int | None = None,
    preserve_catalog: bool = True,
) -> None:
    """Delete the project's data through *session*, optionally sparing one job.

    Deletion proceeds as follows:

    - All ``PipelineRun`` rows are removed first.
    - Without *preserve_job_id*, every ``JobEvent`` and ``Job`` row is deleted as
      well. With it, job rows are kept, and the other active jobs are cancelled
      via ``_cancel_other_jobs_for_reset`` instead.
    - The project tables listed below are then emptied in order.
    - ``SourceCatalogEntry`` rows are removed only when *preserve_catalog* is false.

    The session is flushed at the end but not committed.
    """
    session.execute(delete(PipelineRun))

    if preserve_job_id is not None:
        _cancel_other_jobs_for_reset(session, preserve_job_id)
    else:
        for job_table in (JobEvent, Job):
            session.execute(delete(job_table))

    # NOTE(review): the order appears to remove dependent rows (legs, links, stop
    # times, edges, datasets) before the rows they reference. Keep that property
    # when adding models, and confirm it against the FK definitions in app.models.
    tables_to_empty = [
        ItineraryLeg,
        Itinerary,
        TravelRequest,
        SourceUpdateCheck,
        OsmDiffState,
        MatchRule,
        RouteMatch,
        GtfsTripRoutePatternLink,
        GtfsRoutePatternLink,
        RoutePatternStop,
        RoutePattern,
        CanonicalStopLink,
        CanonicalStop,
        RoutingEdge,
        RoutingNode,
        GtfsStopTime,
        GtfsCalendarDate,
        GtfsCalendar,
        GtfsShape,
        GtfsTrip,
        GtfsRoute,
        GtfsStop,
        GtfsAgency,
        OsmFeature,
        Dataset,
        Source,
    ]
    if not preserve_catalog:
        tables_to_empty.append(SourceCatalogEntry)
    for table in tables_to_empty:
        session.execute(delete(table))

    session.flush()
def _cancel_other_jobs_for_reset(session: Session, preserve_job_id: int) -> None:
    """Cancel every active job except *preserve_job_id* as part of a project reset.

    A job counts as active when its status is ``queued``, ``running`` or ``paused``.
    For each such job, this function:

    - sets the status to ``cancelled``;
    - clears the pending action, lease, pause timestamp and error;
    - stamps ``updated_at`` and ``finished_at`` with one shared UTC timestamp;
    - adds a ``cancelled_by_reset`` ``JobEvent`` that records the job's progress.

    This function does not flush or commit the session itself.
    """
    cancelled_at = datetime.now(timezone.utc)
    still_active = select(Job).where(
        Job.id != preserve_job_id,
        Job.status.in_({"queued", "running", "paused"}),
    )
    for doomed in session.scalars(still_active).all():
        doomed.status = "cancelled"
        for field in ("requested_action", "lease_owner", "lease_expires_at", "paused_at", "error"):
            setattr(doomed, field, None)
        doomed.updated_at = cancelled_at
        doomed.finished_at = cancelled_at
        event = JobEvent(
            job_id=doomed.id,
            event_type="cancelled_by_reset",
            message=f"Job cancelled by reset job #{preserve_job_id}.",
            progress_current=doomed.progress_current,
            progress_total=doomed.progress_total,
        )
        session.add(event)
def _gtfs_time(total_minutes: int) -> str:
    """Format *total_minutes* after midnight as a GTFS ``HH:MM:SS`` time string.

    Minutes carry into the hour field instead of overflowing past 59. Hours are
    not wrapped at 24, because GTFS allows service-day times beyond midnight.
    """
    hours, minutes = divmod(total_minutes, 60)
    return f"{hours:02d}:{minutes:02d}:00"


def create_sample_gtfs(path: Path) -> None:
    """Write a small Berlin-like GTFS feed to *path* as a zip archive.

    Feed contents:

    - Three agencies, 14 stops and six routes (subway, regional rail, tram, two
      buses and a ferry).
    - One trip per route, all on a single ``daily`` service valid for every day
      of 2026.
    - Stop times: stop *n* of a trip arrives ``n * 5`` minutes after 08:00 and
      departs one minute later.
    - Shapes: each trip's shape is the straight polyline through its stops'
      coordinates.

    The archive contains agency.txt, stops.txt, routes.txt, trips.txt,
    stop_times.txt, calendar.txt and shapes.txt. Any existing file at *path* is
    overwritten.
    """
    agencies = [
        {"agency_id": "BVG", "agency_name": "BVG", "agency_url": "https://example.invalid/bvg", "agency_timezone": "Europe/Berlin"},
        {"agency_id": "DB", "agency_name": "DB Regio", "agency_url": "https://example.invalid/db", "agency_timezone": "Europe/Berlin"},
        {"agency_id": "XAIR", "agency_name": "Example Airport Shuttle", "agency_url": "https://example.invalid/xair", "agency_timezone": "Europe/Berlin"},
    ]
    stops = [
        {"stop_id": "hbf", "stop_name": "Berlin Hauptbahnhof", "stop_lat": "52.5251", "stop_lon": "13.3696"},
        {"stop_id": "friedrich", "stop_name": "Friedrichstraße", "stop_lat": "52.5201", "stop_lon": "13.3862"},
        {"stop_id": "alex", "stop_name": "Alexanderplatz", "stop_lat": "52.5219", "stop_lon": "13.4132"},
        {"stop_id": "ost", "stop_name": "Ostbahnhof", "stop_lat": "52.5100", "stop_lon": "13.4344"},
        {"stop_id": "zoo", "stop_name": "Zoologischer Garten", "stop_lat": "52.5069", "stop_lon": "13.3320"},
        {"stop_id": "wittenberg", "stop_name": "Wittenbergplatz", "stop_lat": "52.5020", "stop_lon": "13.3430"},
        {"stop_id": "potsdamer", "stop_name": "Potsdamer Platz", "stop_lat": "52.5096", "stop_lon": "13.3760"},
        {"stop_id": "stadtmitte", "stop_name": "Stadtmitte", "stop_lat": "52.5113", "stop_lon": "13.3907"},
        {"stop_id": "reichstag", "stop_name": "Reichstag", "stop_lat": "52.5186", "stop_lon": "13.3763"},
        {"stop_id": "hackescher", "stop_name": "Hackescher Markt", "stop_lat": "52.5220", "stop_lon": "13.4023"},
        {"stop_id": "naturkunde", "stop_name": "Naturkundemuseum", "stop_lat": "52.5300", "stop_lon": "13.3790"},
        {"stop_id": "wannsee", "stop_name": "Wannsee", "stop_lat": "52.4210", "stop_lon": "13.1797"},
        {"stop_id": "kladow", "stop_name": "Kladow", "stop_lat": "52.4547", "stop_lon": "13.1439"},
        {"stop_id": "airport", "stop_name": "Example Airport Terminal", "stop_lat": "52.3650", "stop_lon": "13.5100"},
    ]
    # GTFS route_type: 0 tram, 1 subway, 2 rail, 3 bus, 4 ferry.
    routes = [
        {"route_id": "u2", "agency_id": "BVG", "route_short_name": "U2", "route_long_name": "Pankow - Ruhleben", "route_type": "1"},
        {"route_id": "re1", "agency_id": "DB", "route_short_name": "RE1", "route_long_name": "Magdeburg - Frankfurt Oder", "route_type": "2"},
        {"route_id": "m5", "agency_id": "BVG", "route_short_name": "M5", "route_long_name": "Hauptbahnhof - Hohenschönhausen", "route_type": "0"},
        {"route_id": "bus100", "agency_id": "BVG", "route_short_name": "100", "route_long_name": "Zoologischer Garten - Alexanderplatz", "route_type": "3"},
        {"route_id": "f10", "agency_id": "BVG", "route_short_name": "F10", "route_long_name": "Wannsee - Kladow", "route_type": "4"},
        {"route_id": "x99", "agency_id": "XAIR", "route_short_name": "X99", "route_long_name": "Airport Express Sample", "route_type": "3"},
    ]
    # One trip (and one shape) per route, named after the route id.
    trips = [
        {"route_id": r["route_id"], "service_id": "daily", "trip_id": f"{r['route_id']}_trip", "shape_id": f"{r['route_id']}_shape"}
        for r in routes
    ]
    stop_sequences = {
        "u2_trip": ["zoo", "wittenberg", "potsdamer", "stadtmitte", "alex"],
        "re1_trip": ["hbf", "friedrich", "alex", "ost"],
        "m5_trip": ["hbf", "naturkunde", "hackescher", "alex"],
        "bus100_trip": ["zoo", "reichstag", "alex"],
        "f10_trip": ["wannsee", "kladow"],
        "x99_trip": ["alex", "airport"],
    }
    coords = {row["stop_id"]: (row["stop_lon"], row["stop_lat"]) for row in stops}

    trip_start_minutes = 8 * 60  # every sample trip is anchored at 08:00
    minutes_between_stops = 5
    dwell_minutes = 1

    stop_times = []
    shapes = []
    for trip in trips:
        trip_id = trip["trip_id"]
        for idx, stop_id in enumerate(stop_sequences[trip_id], start=1):
            # Compute in whole minutes so long trips roll over into the next hour
            # instead of producing an invalid minute field such as "08:60:00".
            arrival = trip_start_minutes + idx * minutes_between_stops
            stop_times.append(
                {
                    "trip_id": trip_id,
                    "arrival_time": _gtfs_time(arrival),
                    "departure_time": _gtfs_time(arrival + dwell_minutes),
                    "stop_id": stop_id,
                    "stop_sequence": str(idx),
                }
            )
            lon, lat = coords[stop_id]
            shapes.append(
                {
                    "shape_id": trip["shape_id"],
                    "shape_pt_lat": lat,
                    "shape_pt_lon": lon,
                    "shape_pt_sequence": str(idx),
                }
            )
    calendar = [
        {
            "service_id": "daily",
            "monday": "1",
            "tuesday": "1",
            "wednesday": "1",
            "thursday": "1",
            "friday": "1",
            "saturday": "1",
            "sunday": "1",
            "start_date": "20260101",
            "end_date": "20261231",
        }
    ]
    with zipfile.ZipFile(path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
        _write_csv(zf, "agency.txt", ["agency_id", "agency_name", "agency_url", "agency_timezone"], agencies)
        _write_csv(zf, "stops.txt", ["stop_id", "stop_name", "stop_lat", "stop_lon"], stops)
        _write_csv(zf, "routes.txt", ["route_id", "agency_id", "route_short_name", "route_long_name", "route_type"], routes)
        _write_csv(zf, "trips.txt", ["route_id", "service_id", "trip_id", "shape_id"], trips)
        _write_csv(zf, "stop_times.txt", ["trip_id", "arrival_time", "departure_time", "stop_id", "stop_sequence"], stop_times)
        _write_csv(
            zf,
            "calendar.txt",
            ["service_id", "monday", "tuesday", "wednesday", "thursday", "friday", "saturday", "sunday", "start_date", "end_date"],
            calendar,
        )
        _write_csv(zf, "shapes.txt", ["shape_id", "shape_pt_lat", "shape_pt_lon", "shape_pt_sequence"], shapes)
def _write_csv(zf: zipfile.ZipFile, name: str, fields: list[str], rows: list[dict[str, str]]) -> None:
buffer = io.StringIO(newline="")
writer = csv.DictWriter(buffer, fieldnames=fields)
writer.writeheader()
writer.writerows(rows)
zf.writestr(name, buffer.getvalue())
def create_sample_osm_geojson(path: Path) -> None:
    """Write the sample OSM transit layer to *path* as a GeoJSON FeatureCollection.

    The collection contains six route relations (LineString features) followed
    by five station/terminal nodes (Point features). Coordinates are
    ``[lon, lat]`` pairs. The file is written as UTF-8 JSON with two-space
    indentation.
    """

    def route_feature(osm_id, mode, ref, name, operator, coordinates):
        # Only BVG lines are tagged with the VBB network; every other operator gets "DB".
        network = "VBB" if operator == "BVG" else "DB"
        properties = {
            "osm_type": "relation",
            "osm_id": str(osm_id),
            "type": "route",
            "route": mode,
            "ref": ref,
            "name": name,
            "operator": operator,
            "network": network,
        }
        return {
            "type": "Feature",
            "geometry": {"type": "LineString", "coordinates": coordinates},
            "properties": properties,
        }

    def node_feature(osm_id, _kind, name, coordinates, tags=None):
        # `_kind` only labels the call site; it is not written to the feature.
        # Caller tags come first, then the node identity keys are appended (or override).
        properties = {**(tags or {}), "osm_type": "node", "osm_id": str(osm_id), "name": name}
        return {
            "type": "Feature",
            "geometry": {"type": "Point", "coordinates": coordinates},
            "properties": properties,
        }

    route_specs = (
        (1002, "subway", "U2", "U2 Ruhleben - Pankow", "BVG",
         [[13.3320, 52.5069], [13.3430, 52.5020], [13.3760, 52.5096], [13.3907, 52.5113], [13.4132, 52.5219]]),
        (2001, "train", "RE1", "RE1 Magdeburg - Frankfurt Oder", "DB Regio",
         [[13.3696, 52.5251], [13.3862, 52.5201], [13.4132, 52.5219], [13.4344, 52.5100]]),
        (5005, "tram", "M5", "M5 Hauptbahnhof - Hohenschönhausen", "BVG",
         [[13.3696, 52.5251], [13.3790, 52.5300], [13.4023, 52.5220], [13.4132, 52.5219]]),
        (6100, "bus", "100", "Bus 100 Zoologischer Garten - Alexanderplatz", "BVG",
         [[13.3320, 52.5069], [13.3763, 52.5186], [13.4132, 52.5219]]),
        (7010, "ferry", "F10", "F10 Wannsee - Kladow", "BVG",
         [[13.1797, 52.4210], [13.1439, 52.4547]]),
        (5010, "tram", "M10", "M10 Warschauer Straße - Hauptbahnhof", "BVG",
         [[13.4500, 52.5050], [13.4020, 52.5300], [13.3696, 52.5251]]),
    )
    node_specs = (
        (9001, "station", "Berlin Hauptbahnhof", [13.3696, 52.5251], {"railway": "station"}),
        (9002, "station", "Alexanderplatz", [13.4132, 52.5219], {"railway": "station"}),
        (9003, "stop", "Zoologischer Garten", [13.3320, 52.5069], {"public_transport": "station", "railway": "station"}),
        (9004, "terminal", "Wannsee Ferry Terminal", [13.1797, 52.4210], {"amenity": "ferry_terminal"}),
        (9005, "terminal", "Kladow Ferry Terminal", [13.1439, 52.4547], {"amenity": "ferry_terminal"}),
    )

    collection = {
        "type": "FeatureCollection",
        "features": [route_feature(*spec) for spec in route_specs] + [node_feature(*spec) for spec in node_specs],
    }
    path.write_text(json.dumps(collection, indent=2), encoding="utf-8")