Alpha stage commit
This commit is contained in:
294
app/pipeline/sample_data.py
Normal file
294
app/pipeline/sample_data.py
Normal file
@@ -0,0 +1,294 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import csv
|
||||
import io
|
||||
import json
|
||||
import zipfile
|
||||
from pathlib import Path
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.config import settings
|
||||
from app.db import init_db
|
||||
from app.models import (
|
||||
Dataset,
|
||||
CanonicalStop,
|
||||
CanonicalStopLink,
|
||||
GtfsAgency,
|
||||
GtfsCalendar,
|
||||
GtfsCalendarDate,
|
||||
GtfsRoute,
|
||||
GtfsRoutePatternLink,
|
||||
GtfsShape,
|
||||
GtfsStop,
|
||||
GtfsStopTime,
|
||||
GtfsTripRoutePatternLink,
|
||||
GtfsTrip,
|
||||
Itinerary,
|
||||
ItineraryLeg,
|
||||
Job,
|
||||
JobEvent,
|
||||
MatchRule,
|
||||
OsmDiffState,
|
||||
OsmFeature,
|
||||
PipelineRun,
|
||||
RouteMatch,
|
||||
RoutePattern,
|
||||
RoutePatternStop,
|
||||
RoutingEdge,
|
||||
RoutingNode,
|
||||
Source,
|
||||
SourceCatalogEntry,
|
||||
SourceUpdateCheck,
|
||||
TravelRequest,
|
||||
)
|
||||
from app.pipeline.matcher import run_route_matching
|
||||
from app.pipeline.route_layer import rebuild_route_layer
|
||||
from app.pipeline.run import run_source
|
||||
|
||||
|
||||
def load_sample_project(session: Session, *, preserve_job_id: int | None = None) -> dict:
    """Reset the project and rebuild it from the bundled Berlin-like sample.

    The steps run in this order: initialise the database, clear existing
    project data (keeping the source catalog, and job ``preserve_job_id``
    when one is given), write the sample GTFS zip and OSM GeoJSON under
    ``settings.data_dir / "sample"``, register one ``Source`` per file, run
    both sources, run route matching, and rebuild the route layer.

    Args:
        session: Session used for every database operation.
        preserve_job_id: Passed through to ``clear_project_data``; ``None``
            means no job is preserved.

    Returns:
        A dict with ``status`` (always ``"ok"``), ``gtfs_dataset_id``,
        ``osm_dataset_id``, ``match_result`` and ``route_layer_result``.
    """
    init_db()
    clear_project_data(session, preserve_job_id=preserve_job_id, preserve_catalog=True)

    # Materialise both sample files on disk before any Source points at them.
    target_dir = settings.data_dir / "sample"
    target_dir.mkdir(parents=True, exist_ok=True)
    gtfs_file = target_dir / "sample_berlin.gtfs.zip"
    osm_file = target_dir / "sample_berlin_osm.geojson"
    create_sample_gtfs(gtfs_file)
    create_sample_osm_geojson(osm_file)

    common = {"country": "DE", "license": "sample"}
    gtfs_source = Source(name="Sample Berlin GTFS", kind="gtfs", url=str(gtfs_file), **common)
    osm_source = Source(name="Sample Berlin OSM transport", kind="osm_geojson", url=str(osm_file), **common)
    session.add_all([gtfs_source, osm_source])
    # NOTE(review): presumably flushed so both sources have ids before
    # run_source uses them -- confirm against run_source.
    session.flush()

    # GTFS is imported first, then OSM; matching and the route layer follow.
    gtfs_dataset, osm_dataset = [
        run_source(session, source) for source in (gtfs_source, osm_source)
    ]
    match_result = run_route_matching(session)
    route_layer_result = rebuild_route_layer(session)

    return dict(
        status="ok",
        gtfs_dataset_id=gtfs_dataset.id,
        osm_dataset_id=osm_dataset.id,
        match_result=match_result,
        route_layer_result=route_layer_result,
    )
|
||||
|
||||
|
||||
def clear_project_data(
    session: Session,
    *,
    preserve_job_id: int | None = None,
    preserve_catalog: bool = True,
) -> None:
    """Delete user/project rows, optionally sparing one queue job.

    Args:
        session: Session the DELETE statements are executed on.
        preserve_job_id: When ``None`` every ``JobEvent`` and ``Job`` row is
            deleted. Otherwise no job rows are deleted; instead the other
            still-active jobs are cancelled via
            ``_cancel_other_jobs_for_reset``.
        preserve_catalog: When false, ``SourceCatalogEntry`` rows are
            deleted as well.

    The session is flushed at the end; nothing is committed here.
    """
    session.execute(delete(PipelineRun))

    if preserve_job_id is not None:
        _cancel_other_jobs_for_reset(session, preserve_job_id)
    else:
        session.execute(delete(JobEvent))
        session.execute(delete(Job))

    # The sequence below is significant and must not be reordered.
    # NOTE(review): it looks like dependent tables are listed before the
    # tables they reference -- confirm against the model foreign keys.
    wipe_order = (
        ItineraryLeg,
        Itinerary,
        TravelRequest,
        SourceUpdateCheck,
        OsmDiffState,
        MatchRule,
        RouteMatch,
        GtfsTripRoutePatternLink,
        GtfsRoutePatternLink,
        RoutePatternStop,
        RoutePattern,
        CanonicalStopLink,
        CanonicalStop,
        RoutingEdge,
        RoutingNode,
        GtfsStopTime,
        GtfsCalendarDate,
        GtfsCalendar,
        GtfsShape,
        GtfsTrip,
        GtfsRoute,
        GtfsStop,
        GtfsAgency,
        OsmFeature,
        Dataset,
        Source,
    )
    for table_model in wipe_order:
        session.execute(delete(table_model))

    if not preserve_catalog:
        session.execute(delete(SourceCatalogEntry))

    session.flush()
|
||||
|
||||
|
||||
def _cancel_other_jobs_for_reset(session: Session, preserve_job_id: int) -> None:
    """Cancel every queued, running or paused job except ``preserve_job_id``.

    Each affected job is marked ``"cancelled"``, has its pending action,
    lease, pause marker and error cleared, and is stamped with the same UTC
    timestamp for ``updated_at`` and ``finished_at``. One
    ``cancelled_by_reset`` ``JobEvent`` carrying the job's current progress
    counters is added per cancelled job.

    Args:
        session: Session used to load the jobs and add the events.
        preserve_job_id: Id of the job that must be left untouched.
    """
    cancelled_at = datetime.now(timezone.utc)

    is_other_job = Job.id != preserve_job_id
    is_still_active = Job.status.in_({"queued", "running", "paused"})
    to_cancel = session.scalars(select(Job).where(is_other_job, is_still_active)).all()

    for victim in to_cancel:
        victim.status = "cancelled"
        # Drop any in-flight control state so the job reads as cleanly finished.
        victim.requested_action = None
        victim.lease_owner = None
        victim.lease_expires_at = None
        victim.paused_at = None
        victim.error = None
        victim.updated_at = victim.finished_at = cancelled_at

        event = JobEvent(
            job_id=victim.id,
            event_type="cancelled_by_reset",
            message=f"Job cancelled by reset job #{preserve_job_id}.",
            progress_current=victim.progress_current,
            progress_total=victim.progress_total,
        )
        session.add(event)
|
||||
|
||||
|
||||
def create_sample_gtfs(path: Path) -> None:
    """Write the sample GTFS feed to ``path`` as a deflate-compressed zip.

    The archive holds, in this order: ``agency.txt``, ``stops.txt``,
    ``routes.txt``, ``trips.txt``, ``stop_times.txt``, ``calendar.txt`` and
    ``shapes.txt``. There is one trip per route, all on the single
    ``daily`` service, and each trip's shape is simply the sequence of its
    stop coordinates.

    Args:
        path: Destination of the zip file; an existing file is overwritten.
    """
    agency_fields = ["agency_id", "agency_name", "agency_url", "agency_timezone"]
    stop_fields = ["stop_id", "stop_name", "stop_lat", "stop_lon"]
    route_fields = ["route_id", "agency_id", "route_short_name", "route_long_name", "route_type"]
    trip_fields = ["route_id", "service_id", "trip_id", "shape_id"]
    stop_time_fields = ["trip_id", "arrival_time", "departure_time", "stop_id", "stop_sequence"]
    weekdays = ["monday", "tuesday", "wednesday", "thursday", "friday", "saturday", "sunday"]
    calendar_fields = ["service_id", *weekdays, "start_date", "end_date"]
    shape_fields = ["shape_id", "shape_pt_lat", "shape_pt_lon", "shape_pt_sequence"]

    def as_rows(fields, table):
        # Pair every tuple of values with the column names of its file.
        return [dict(zip(fields, values)) for values in table]

    agencies = as_rows(
        agency_fields,
        [
            ("BVG", "BVG", "https://example.invalid/bvg", "Europe/Berlin"),
            ("DB", "DB Regio", "https://example.invalid/db", "Europe/Berlin"),
            ("XAIR", "Example Airport Shuttle", "https://example.invalid/xair", "Europe/Berlin"),
        ],
    )
    # Coordinates stay strings so they are written exactly as typed here.
    stops = as_rows(
        stop_fields,
        [
            ("hbf", "Berlin Hauptbahnhof", "52.5251", "13.3696"),
            ("friedrich", "Friedrichstraße", "52.5201", "13.3862"),
            ("alex", "Alexanderplatz", "52.5219", "13.4132"),
            ("ost", "Ostbahnhof", "52.5100", "13.4344"),
            ("zoo", "Zoologischer Garten", "52.5069", "13.3320"),
            ("wittenberg", "Wittenbergplatz", "52.5020", "13.3430"),
            ("potsdamer", "Potsdamer Platz", "52.5096", "13.3760"),
            ("stadtmitte", "Stadtmitte", "52.5113", "13.3907"),
            ("reichstag", "Reichstag", "52.5186", "13.3763"),
            ("hackescher", "Hackescher Markt", "52.5220", "13.4023"),
            ("naturkunde", "Naturkundemuseum", "52.5300", "13.3790"),
            ("wannsee", "Wannsee", "52.4210", "13.1797"),
            ("kladow", "Kladow", "52.4547", "13.1439"),
            ("airport", "Example Airport Terminal", "52.3650", "13.5100"),
        ],
    )
    routes = as_rows(
        route_fields,
        [
            ("u2", "BVG", "U2", "Pankow - Ruhleben", "1"),
            ("re1", "DB", "RE1", "Magdeburg - Frankfurt Oder", "2"),
            ("m5", "BVG", "M5", "Hauptbahnhof - Hohenschönhausen", "0"),
            ("bus100", "BVG", "100", "Zoologischer Garten - Alexanderplatz", "3"),
            ("f10", "BVG", "F10", "Wannsee - Kladow", "4"),
            ("x99", "XAIR", "X99", "Airport Express Sample", "3"),
        ],
    )
    trips = [
        {
            "route_id": route["route_id"],
            "service_id": "daily",
            "trip_id": route["route_id"] + "_trip",
            "shape_id": route["route_id"] + "_shape",
        }
        for route in routes
    ]
    stop_sequences = {
        "u2_trip": ["zoo", "wittenberg", "potsdamer", "stadtmitte", "alex"],
        "re1_trip": ["hbf", "friedrich", "alex", "ost"],
        "m5_trip": ["hbf", "naturkunde", "hackescher", "alex"],
        "bus100_trip": ["zoo", "reichstag", "alex"],
        "f10_trip": ["wannsee", "kladow"],
        "x99_trip": ["alex", "airport"],
    }

    stop_by_id = {stop["stop_id"]: stop for stop in stops}
    # One entry per (trip, stop) visit, in trip order and then stop order.
    visits = [
        (trip["trip_id"], trip["shape_id"], position, stop_id)
        for trip in trips
        for position, stop_id in enumerate(stop_sequences[trip["trip_id"]], start=1)
    ]
    # Stops are 5 minutes apart with a 1 minute dwell, starting at 08:05.
    # The minute field only stays below 60 while a trip has at most 11 stops.
    stop_times = [
        {
            "trip_id": trip_id,
            "arrival_time": f"08:{position * 5:02d}:00",
            "departure_time": f"08:{position * 5 + 1:02d}:00",
            "stop_id": stop_id,
            "stop_sequence": str(position),
        }
        for trip_id, _shape_id, position, stop_id in visits
    ]
    shapes = [
        {
            "shape_id": shape_id,
            "shape_pt_lat": stop_by_id[stop_id]["stop_lat"],
            "shape_pt_lon": stop_by_id[stop_id]["stop_lon"],
            "shape_pt_sequence": str(position),
        }
        for _trip_id, shape_id, position, stop_id in visits
    ]

    calendar = [
        {
            "service_id": "daily",
            **dict.fromkeys(weekdays, "1"),
            "start_date": "20260101",
            "end_date": "20261231",
        }
    ]

    members = [
        ("agency.txt", agency_fields, agencies),
        ("stops.txt", stop_fields, stops),
        ("routes.txt", route_fields, routes),
        ("trips.txt", trip_fields, trips),
        ("stop_times.txt", stop_time_fields, stop_times),
        ("calendar.txt", calendar_fields, calendar),
        ("shapes.txt", shape_fields, shapes),
    ]
    with zipfile.ZipFile(path, "w", compression=zipfile.ZIP_DEFLATED) as archive:
        for member_name, fields, rows in members:
            _write_csv(archive, member_name, fields, rows)
|
||||
|
||||
|
||||
def _write_csv(zf: zipfile.ZipFile, name: str, fields: list[str], rows: list[dict[str, str]]) -> None:
    """Store ``rows`` as a CSV member called ``name`` inside the open zip ``zf``.

    Args:
        zf: Zip archive opened for writing.
        name: Member name inside the archive, e.g. ``"stops.txt"``.
        fields: Column names; written as the header row and used as the
            column order for every data row.
        rows: One dict per data row, keyed by the names in ``fields``.

    Raises:
        ValueError: If a row has a key that is not in ``fields``
            (``csv.DictWriter`` default behaviour).
    """
    # newline="" keeps the csv module's own line terminators untouched.
    with io.StringIO(newline="") as text:
        table = csv.DictWriter(text, fieldnames=fields)
        table.writeheader()
        for record in rows:
            table.writerow(record)
        payload = text.getvalue()
    zf.writestr(name, payload)
|
||||
|
||||
|
||||
def create_sample_osm_geojson(path: Path) -> None:
    """Write the sample OSM transport data to ``path`` as a GeoJSON file.

    The output is one ``FeatureCollection`` holding six route relations as
    ``LineString`` features followed by five nodes as ``Point`` features.
    It is serialised with a two-space indent and written as UTF-8.

    Args:
        path: Destination file; an existing file is overwritten.
    """
    # (osm id, route mode, ref, name, operator, [lon, lat] coordinates)
    route_table = [
        (1002, "subway", "U2", "U2 Ruhleben - Pankow", "BVG",
         [[13.3320, 52.5069], [13.3430, 52.5020], [13.3760, 52.5096], [13.3907, 52.5113], [13.4132, 52.5219]]),
        (2001, "train", "RE1", "RE1 Magdeburg - Frankfurt Oder", "DB Regio",
         [[13.3696, 52.5251], [13.3862, 52.5201], [13.4132, 52.5219], [13.4344, 52.5100]]),
        (5005, "tram", "M5", "M5 Hauptbahnhof - Hohenschönhausen", "BVG",
         [[13.3696, 52.5251], [13.3790, 52.5300], [13.4023, 52.5220], [13.4132, 52.5219]]),
        (6100, "bus", "100", "Bus 100 Zoologischer Garten - Alexanderplatz", "BVG",
         [[13.3320, 52.5069], [13.3763, 52.5186], [13.4132, 52.5219]]),
        (7010, "ferry", "F10", "F10 Wannsee - Kladow", "BVG",
         [[13.1797, 52.4210], [13.1439, 52.4547]]),
        (5010, "tram", "M10", "M10 Warschauer Straße - Hauptbahnhof", "BVG",
         [[13.4500, 52.5050], [13.4020, 52.5300], [13.3696, 52.5251]]),
    ]
    # (osm id, label, name, [lon, lat], extra tags)
    # The label is informational only; it is not written to the output.
    node_table = [
        (9001, "station", "Berlin Hauptbahnhof", [13.3696, 52.5251], {"railway": "station"}),
        (9002, "station", "Alexanderplatz", [13.4132, 52.5219], {"railway": "station"}),
        (9003, "stop", "Zoologischer Garten", [13.3320, 52.5069],
         {"public_transport": "station", "railway": "station"}),
        (9004, "terminal", "Wannsee Ferry Terminal", [13.1797, 52.4210], {"amenity": "ferry_terminal"}),
        (9005, "terminal", "Kladow Ferry Terminal", [13.1439, 52.4547], {"amenity": "ferry_terminal"}),
    ]

    def feature(geometry_type, coordinates, properties):
        return {
            "type": "Feature",
            "geometry": {"type": geometry_type, "coordinates": coordinates},
            "properties": properties,
        }

    features = []
    for osm_id, mode, ref, name, operator, coordinates in route_table:
        # Only the operator string "BVG" maps to the VBB network.
        network = "DB"
        if operator == "BVG":
            network = "VBB"
        features.append(
            feature(
                "LineString",
                coordinates,
                {
                    "osm_type": "relation",
                    "osm_id": str(osm_id),
                    "type": "route",
                    "route": mode,
                    "ref": ref,
                    "name": name,
                    "operator": operator,
                    "network": network,
                },
            )
        )
    for osm_id, _label, name, coordinates, tags in node_table:
        # Extra tags come first; the identifying keys follow (and would
        # override a same-named tag).
        features.append(
            feature("Point", coordinates, {**tags, "osm_type": "node", "osm_id": str(osm_id), "name": name})
        )

    document = {"type": "FeatureCollection", "features": features}
    path.write_text(json.dumps(document, indent=2), encoding="utf-8")
|
||||
Reference in New Issue
Block a user