295 lines
12 KiB
Python
295 lines
12 KiB
Python
from __future__ import annotations
|
|
|
|
import csv
|
|
import io
|
|
import json
|
|
import zipfile
|
|
from pathlib import Path
|
|
from datetime import datetime, timezone
|
|
|
|
from sqlalchemy import delete, select
|
|
from sqlalchemy.orm import Session
|
|
|
|
from app.config import settings
|
|
from app.db import init_db
|
|
from app.models import (
|
|
Dataset,
|
|
CanonicalStop,
|
|
CanonicalStopLink,
|
|
GtfsAgency,
|
|
GtfsCalendar,
|
|
GtfsCalendarDate,
|
|
GtfsRoute,
|
|
GtfsRoutePatternLink,
|
|
GtfsShape,
|
|
GtfsStop,
|
|
GtfsStopTime,
|
|
GtfsTripRoutePatternLink,
|
|
GtfsTrip,
|
|
Itinerary,
|
|
ItineraryLeg,
|
|
Job,
|
|
JobEvent,
|
|
MatchRule,
|
|
OsmDiffState,
|
|
OsmFeature,
|
|
PipelineRun,
|
|
RouteMatch,
|
|
RoutePattern,
|
|
RoutePatternStop,
|
|
RoutingEdge,
|
|
RoutingNode,
|
|
Source,
|
|
SourceCatalogEntry,
|
|
SourceUpdateCheck,
|
|
TravelRequest,
|
|
)
|
|
from app.pipeline.matcher import run_route_matching
|
|
from app.pipeline.route_layer import rebuild_route_layer
|
|
from app.pipeline.run import run_source
|
|
|
|
|
|
def load_sample_project(session: Session, *, preserve_job_id: int | None = None) -> dict:
    """Reset project data and load a small Berlin-like GTFS + OSM sample.

    Workflow, in order:

    1. ``init_db()``, then ``clear_project_data`` with the source catalog kept
       and ``preserve_job_id`` passed through. When a job id is given, other
       active jobs are cancelled rather than deleted.
    2. Regenerate the sample GTFS zip and OSM GeoJSON under
       ``settings.data_dir / "sample"``, creating that directory if needed.
    3. Register one ``Source`` row per file and flush so both get ids.
    4. Import both sources, then run route matching and rebuild the route layer.

    Args:
        session: Open ORM session. It is flushed here but never committed.
        preserve_job_id: Id of the queue job performing this reset, if any.

    Returns:
        A dict with ``status``, the two dataset ids, and the results of
        route matching and route-layer rebuilding.
    """
    init_db()  # NOTE(review): presumably ensures the schema exists — confirm in app.db
    clear_project_data(session, preserve_job_id=preserve_job_id, preserve_catalog=True)

    sample_root = settings.data_dir / "sample"
    sample_root.mkdir(parents=True, exist_ok=True)

    feed_zip = sample_root / "sample_berlin.gtfs.zip"
    osm_geojson = sample_root / "sample_berlin_osm.geojson"
    create_sample_gtfs(feed_zip)
    create_sample_osm_geojson(osm_geojson)

    # Both sample sources share country and license metadata.
    shared_meta = {"country": "DE", "license": "sample"}
    gtfs_src = Source(name="Sample Berlin GTFS", kind="gtfs", url=str(feed_zip), **shared_meta)
    osm_src = Source(name="Sample Berlin OSM transport", kind="osm_geojson", url=str(osm_geojson), **shared_meta)
    session.add_all([gtfs_src, osm_src])
    session.flush()  # assign primary keys before the importers use the sources

    imported_gtfs = run_source(session, gtfs_src)
    imported_osm = run_source(session, osm_src)
    matching = run_route_matching(session)
    layer = rebuild_route_layer(session)

    summary = {
        "status": "ok",
        "gtfs_dataset_id": imported_gtfs.id,
        "osm_dataset_id": imported_osm.id,
        "match_result": matching,
        "route_layer_result": layer,
    }
    return summary
|
|
|
|
|
|
def clear_project_data(
    session: Session,
    *,
    preserve_job_id: int | None = None,
    preserve_catalog: bool = True,
) -> None:
    """Wipe user/project data, optionally keeping the job that triggered the reset.

    All ``PipelineRun`` rows are always deleted. The job tables are handled in
    one of two ways:

    * ``preserve_job_id is None``: every ``JobEvent`` and ``Job`` row is deleted.
    * otherwise: job rows are kept. Other queued, running or paused jobs are
      marked cancelled via ``_cancel_other_jobs_for_reset``.

    Each model in the fixed list below is then bulk-deleted in order.
    ``SourceCatalogEntry`` rows are deleted only when ``preserve_catalog`` is
    False. The session is flushed but not committed; the caller owns the
    transaction.
    """
    session.execute(delete(PipelineRun))

    if preserve_job_id is not None:
        _cancel_other_jobs_for_reset(session, preserve_job_id)
    else:
        session.execute(delete(JobEvent))
        session.execute(delete(Job))

    # NOTE(review): the order appears to put dependent tables before the
    # tables they reference (e.g. ItineraryLeg before Itinerary, Dataset
    # before Source) — keep that invariant when adding models.
    deletion_order = (
        ItineraryLeg,
        Itinerary,
        TravelRequest,
        SourceUpdateCheck,
        OsmDiffState,
        MatchRule,
        RouteMatch,
        GtfsTripRoutePatternLink,
        GtfsRoutePatternLink,
        RoutePatternStop,
        RoutePattern,
        CanonicalStopLink,
        CanonicalStop,
        RoutingEdge,
        RoutingNode,
        GtfsStopTime,
        GtfsCalendarDate,
        GtfsCalendar,
        GtfsShape,
        GtfsTrip,
        GtfsRoute,
        GtfsStop,
        GtfsAgency,
        OsmFeature,
        Dataset,
        Source,
    )
    for table_model in deletion_order:
        session.execute(delete(table_model))

    if not preserve_catalog:
        session.execute(delete(SourceCatalogEntry))

    session.flush()
|
|
|
|
|
|
def _cancel_other_jobs_for_reset(session: Session, preserve_job_id: int) -> None:
    """Cancel every active job except ``preserve_job_id`` and log why.

    A job counts as active when its status is ``queued``, ``running`` or
    ``paused``. For each such job this:

    * sets ``status`` to ``"cancelled"``;
    * clears ``requested_action``, the lease fields, ``paused_at`` and ``error``;
    * stamps ``updated_at`` and ``finished_at`` with one shared UTC timestamp;
    * adds a ``cancelled_by_reset`` ``JobEvent`` that records the job's current
      progress counters.

    Nothing is flushed or committed here.
    """
    cancelled_at = datetime.now(timezone.utc)
    still_active = select(Job).where(
        Job.id != preserve_job_id,
        Job.status.in_({"queued", "running", "paused"}),
    )
    # Materialise the result before mutating rows and adding events in the loop.
    victims = session.scalars(still_active).all()

    for other_job in victims:
        other_job.status = "cancelled"
        for cleared_field in ("requested_action", "lease_owner", "lease_expires_at", "paused_at", "error"):
            setattr(other_job, cleared_field, None)
        other_job.updated_at = cancelled_at
        other_job.finished_at = cancelled_at

        audit_event = JobEvent(
            job_id=other_job.id,
            event_type="cancelled_by_reset",
            message=f"Job cancelled by reset job #{preserve_job_id}.",
            progress_current=other_job.progress_current,
            progress_total=other_job.progress_total,
        )
        session.add(audit_event)
|
|
|
|
|
|
def create_sample_gtfs(path: Path) -> None:
|
|
agencies = [
|
|
{"agency_id": "BVG", "agency_name": "BVG", "agency_url": "https://example.invalid/bvg", "agency_timezone": "Europe/Berlin"},
|
|
{"agency_id": "DB", "agency_name": "DB Regio", "agency_url": "https://example.invalid/db", "agency_timezone": "Europe/Berlin"},
|
|
{"agency_id": "XAIR", "agency_name": "Example Airport Shuttle", "agency_url": "https://example.invalid/xair", "agency_timezone": "Europe/Berlin"},
|
|
]
|
|
stops = [
|
|
{"stop_id": "hbf", "stop_name": "Berlin Hauptbahnhof", "stop_lat": "52.5251", "stop_lon": "13.3696"},
|
|
{"stop_id": "friedrich", "stop_name": "Friedrichstraße", "stop_lat": "52.5201", "stop_lon": "13.3862"},
|
|
{"stop_id": "alex", "stop_name": "Alexanderplatz", "stop_lat": "52.5219", "stop_lon": "13.4132"},
|
|
{"stop_id": "ost", "stop_name": "Ostbahnhof", "stop_lat": "52.5100", "stop_lon": "13.4344"},
|
|
{"stop_id": "zoo", "stop_name": "Zoologischer Garten", "stop_lat": "52.5069", "stop_lon": "13.3320"},
|
|
{"stop_id": "wittenberg", "stop_name": "Wittenbergplatz", "stop_lat": "52.5020", "stop_lon": "13.3430"},
|
|
{"stop_id": "potsdamer", "stop_name": "Potsdamer Platz", "stop_lat": "52.5096", "stop_lon": "13.3760"},
|
|
{"stop_id": "stadtmitte", "stop_name": "Stadtmitte", "stop_lat": "52.5113", "stop_lon": "13.3907"},
|
|
{"stop_id": "reichstag", "stop_name": "Reichstag", "stop_lat": "52.5186", "stop_lon": "13.3763"},
|
|
{"stop_id": "hackescher", "stop_name": "Hackescher Markt", "stop_lat": "52.5220", "stop_lon": "13.4023"},
|
|
{"stop_id": "naturkunde", "stop_name": "Naturkundemuseum", "stop_lat": "52.5300", "stop_lon": "13.3790"},
|
|
{"stop_id": "wannsee", "stop_name": "Wannsee", "stop_lat": "52.4210", "stop_lon": "13.1797"},
|
|
{"stop_id": "kladow", "stop_name": "Kladow", "stop_lat": "52.4547", "stop_lon": "13.1439"},
|
|
{"stop_id": "airport", "stop_name": "Example Airport Terminal", "stop_lat": "52.3650", "stop_lon": "13.5100"},
|
|
]
|
|
routes = [
|
|
{"route_id": "u2", "agency_id": "BVG", "route_short_name": "U2", "route_long_name": "Pankow - Ruhleben", "route_type": "1"},
|
|
{"route_id": "re1", "agency_id": "DB", "route_short_name": "RE1", "route_long_name": "Magdeburg - Frankfurt Oder", "route_type": "2"},
|
|
{"route_id": "m5", "agency_id": "BVG", "route_short_name": "M5", "route_long_name": "Hauptbahnhof - Hohenschönhausen", "route_type": "0"},
|
|
{"route_id": "bus100", "agency_id": "BVG", "route_short_name": "100", "route_long_name": "Zoologischer Garten - Alexanderplatz", "route_type": "3"},
|
|
{"route_id": "f10", "agency_id": "BVG", "route_short_name": "F10", "route_long_name": "Wannsee - Kladow", "route_type": "4"},
|
|
{"route_id": "x99", "agency_id": "XAIR", "route_short_name": "X99", "route_long_name": "Airport Express Sample", "route_type": "3"},
|
|
]
|
|
trips = [
|
|
{"route_id": r["route_id"], "service_id": "daily", "trip_id": f"{r['route_id']}_trip", "shape_id": f"{r['route_id']}_shape"}
|
|
for r in routes
|
|
]
|
|
stop_sequences = {
|
|
"u2_trip": ["zoo", "wittenberg", "potsdamer", "stadtmitte", "alex"],
|
|
"re1_trip": ["hbf", "friedrich", "alex", "ost"],
|
|
"m5_trip": ["hbf", "naturkunde", "hackescher", "alex"],
|
|
"bus100_trip": ["zoo", "reichstag", "alex"],
|
|
"f10_trip": ["wannsee", "kladow"],
|
|
"x99_trip": ["alex", "airport"],
|
|
}
|
|
coords = {row["stop_id"]: (row["stop_lon"], row["stop_lat"]) for row in stops}
|
|
stop_times = []
|
|
shapes = []
|
|
for trip in trips:
|
|
trip_id = trip["trip_id"]
|
|
for idx, stop_id in enumerate(stop_sequences[trip_id], start=1):
|
|
stop_times.append(
|
|
{
|
|
"trip_id": trip_id,
|
|
"arrival_time": f"08:{idx * 5:02d}:00",
|
|
"departure_time": f"08:{idx * 5 + 1:02d}:00",
|
|
"stop_id": stop_id,
|
|
"stop_sequence": str(idx),
|
|
}
|
|
)
|
|
lon, lat = coords[stop_id]
|
|
shapes.append(
|
|
{
|
|
"shape_id": trip["shape_id"],
|
|
"shape_pt_lat": lat,
|
|
"shape_pt_lon": lon,
|
|
"shape_pt_sequence": str(idx),
|
|
}
|
|
)
|
|
calendar = [
|
|
{
|
|
"service_id": "daily",
|
|
"monday": "1",
|
|
"tuesday": "1",
|
|
"wednesday": "1",
|
|
"thursday": "1",
|
|
"friday": "1",
|
|
"saturday": "1",
|
|
"sunday": "1",
|
|
"start_date": "20260101",
|
|
"end_date": "20261231",
|
|
}
|
|
]
|
|
|
|
with zipfile.ZipFile(path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
|
|
_write_csv(zf, "agency.txt", ["agency_id", "agency_name", "agency_url", "agency_timezone"], agencies)
|
|
_write_csv(zf, "stops.txt", ["stop_id", "stop_name", "stop_lat", "stop_lon"], stops)
|
|
_write_csv(zf, "routes.txt", ["route_id", "agency_id", "route_short_name", "route_long_name", "route_type"], routes)
|
|
_write_csv(zf, "trips.txt", ["route_id", "service_id", "trip_id", "shape_id"], trips)
|
|
_write_csv(zf, "stop_times.txt", ["trip_id", "arrival_time", "departure_time", "stop_id", "stop_sequence"], stop_times)
|
|
_write_csv(
|
|
zf,
|
|
"calendar.txt",
|
|
["service_id", "monday", "tuesday", "wednesday", "thursday", "friday", "saturday", "sunday", "start_date", "end_date"],
|
|
calendar,
|
|
)
|
|
_write_csv(zf, "shapes.txt", ["shape_id", "shape_pt_lat", "shape_pt_lon", "shape_pt_sequence"], shapes)
|
|
|
|
|
|
def _write_csv(zf: zipfile.ZipFile, name: str, fields: list[str], rows: list[dict[str, str]]) -> None:
|
|
buffer = io.StringIO(newline="")
|
|
writer = csv.DictWriter(buffer, fieldnames=fields)
|
|
writer.writeheader()
|
|
writer.writerows(rows)
|
|
zf.writestr(name, buffer.getvalue())
|
|
|
|
|
|
def create_sample_osm_geojson(path: Path) -> None:
    """Write a small GeoJSON FeatureCollection imitating an OSM transport extract.

    The collection holds six route relations as ``LineString`` features,
    followed by five station/terminal nodes as ``Point`` features.
    Coordinates are ``[lon, lat]`` pairs. The JSON is written with two-space
    indentation (ASCII-escaped, the ``json.dumps`` default) to *path* as UTF-8,
    replacing any existing file.
    """

    def route_feature(osm_id, route_mode, ref, label, operator, coordinates):
        # The BVG operator maps to the VBB network; every other operator maps to DB.
        network = "VBB" if operator == "BVG" else "DB"
        return {
            "type": "Feature",
            "geometry": {"type": "LineString", "coordinates": coordinates},
            "properties": {
                "osm_type": "relation",
                "osm_id": str(osm_id),
                "type": "route",
                "route": route_mode,
                "ref": ref,
                "name": label,
                "operator": operator,
                "network": network,
            },
        }

    def node_feature(osm_id, _kind, label, coordinates, tags=None):
        # ``_kind`` only labels the call site; it is not written to the output.
        # Caller-supplied tags come first, then the node identity keys.
        merged = dict(tags or {})
        merged.update({"osm_type": "node", "osm_id": str(osm_id), "name": label})
        return {"type": "Feature", "geometry": {"type": "Point", "coordinates": coordinates}, "properties": merged}

    route_specs = [
        (1002, "subway", "U2", "U2 Ruhleben - Pankow", "BVG", [[13.3320, 52.5069], [13.3430, 52.5020], [13.3760, 52.5096], [13.3907, 52.5113], [13.4132, 52.5219]]),
        (2001, "train", "RE1", "RE1 Magdeburg - Frankfurt Oder", "DB Regio", [[13.3696, 52.5251], [13.3862, 52.5201], [13.4132, 52.5219], [13.4344, 52.5100]]),
        (5005, "tram", "M5", "M5 Hauptbahnhof - Hohenschönhausen", "BVG", [[13.3696, 52.5251], [13.3790, 52.5300], [13.4023, 52.5220], [13.4132, 52.5219]]),
        (6100, "bus", "100", "Bus 100 Zoologischer Garten - Alexanderplatz", "BVG", [[13.3320, 52.5069], [13.3763, 52.5186], [13.4132, 52.5219]]),
        (7010, "ferry", "F10", "F10 Wannsee - Kladow", "BVG", [[13.1797, 52.4210], [13.1439, 52.4547]]),
        (5010, "tram", "M10", "M10 Warschauer Straße - Hauptbahnhof", "BVG", [[13.4500, 52.5050], [13.4020, 52.5300], [13.3696, 52.5251]]),
    ]
    node_specs = [
        (9001, "station", "Berlin Hauptbahnhof", [13.3696, 52.5251], {"railway": "station"}),
        (9002, "station", "Alexanderplatz", [13.4132, 52.5219], {"railway": "station"}),
        (9003, "stop", "Zoologischer Garten", [13.3320, 52.5069], {"public_transport": "station", "railway": "station"}),
        (9004, "terminal", "Wannsee Ferry Terminal", [13.1797, 52.4210], {"amenity": "ferry_terminal"}),
        (9005, "terminal", "Kladow Ferry Terminal", [13.1439, 52.4547], {"amenity": "ferry_terminal"}),
    ]

    collection = {
        "type": "FeatureCollection",
        "features": [route_feature(*spec) for spec in route_specs] + [node_feature(*spec) for spec in node_specs],
    }
    path.write_text(json.dumps(collection, indent=2), encoding="utf-8")
|