Files
meubility-workbench/tests/test_pipeline.py
2026-07-01 23:29:51 +02:00

318 lines
9.8 KiB
Python

from __future__ import annotations
import json
from shapely.geometry import LineString
from sqlalchemy import select
from app.db import init_db, session_scope
from app.models import GtfsRoute, OsmFeature, RouteMatch, RoutePattern
from app.osm_classification import infer_osm_route_scope
from app.pipeline.gtfs import _gtfs_mode
from app.pipeline.matcher import _build_osm_route_index, _candidate_osm_routes, route_match_scope, run_route_matching, score_route_pair
from app.pipeline.osm_addresses import _address_area_geometry_geojson
from app.pipeline.sample_data import load_sample_project
from app.pipeline.route_layer import rebuild_route_layer
from app.pipeline.utils import geometry_json_and_bbox
def test_sample_pipeline_imports_and_matches():
    """Load the sample project and check its routes, OSM features and match statuses."""
    init_db()
    with session_scope() as session:
        summary = load_sample_project(session)
        assert summary["match_result"]["routes"] == 6

        # Both the GTFS side and the OSM side must contain the RE1 line.
        gtfs_re1 = session.scalar(select(GtfsRoute).where(GtfsRoute.short_name == "RE1"))
        assert gtfs_re1 is not None
        osm_re1 = session.scalar(select(OsmFeature).where(OsmFeature.ref == "RE1"))
        assert osm_re1 is not None

        # Only the distinct status values are checked; confidence is selected but unused.
        match_rows = session.execute(select(RouteMatch.status, RouteMatch.confidence))
        seen_statuses = {row[0] for row in match_rows}
        assert "matched" in seen_statuses
        assert seen_statuses & {"weak", "missing"}
def test_route_matching_preserves_unchanged_match_rows():
    """Re-run matching on the sample data and check that match ids and timestamps are kept."""
    init_db()
    with session_scope() as session:
        load_sample_project(session)

        def snapshot_matches():
            # Map each GTFS route id to the (id, updated_at) pair of its match row.
            rows = session.execute(
                select(RouteMatch.gtfs_route_id, RouteMatch.id, RouteMatch.updated_at)
            )
            return {row[0]: (row[1], row[2]) for row in rows}

        rows_before = snapshot_matches()
        rerun = run_route_matching(session)
        rows_after = snapshot_matches()

        # The second run must report every route as unchanged ...
        assert rerun["unchanged"] == rerun["routes"]
        assert rerun["created"] == 0
        assert rerun["updated"] == 0
        # ... and the stored rows must be identical to the first snapshot.
        assert rows_after == rows_before
def test_route_layer_reuses_unchanged_route_patterns():
    """Rebuild the route layer on the sample data and check that pattern ids are kept."""
    init_db()
    with session_scope() as session:
        load_sample_project(session)

        def snapshot_patterns():
            # Map each pattern key to the primary key of its RoutePattern row.
            rows = session.execute(select(RoutePattern.pattern_key, RoutePattern.id))
            return {row[0]: row[1] for row in rows}

        patterns_before = snapshot_patterns()
        rebuild = rebuild_route_layer(session)
        patterns_after = snapshot_patterns()

        assert rebuild["route_patterns_created"] == 0
        assert rebuild["route_patterns_removed"] == 0
        assert rebuild["route_patterns_reused"] == rebuild["route_patterns"]
        assert patterns_after == patterns_before
def test_extended_gtfs_route_types_are_mapped_to_modes():
    """Check the mode ``_gtfs_mode`` returns for several three- and four-digit route types."""
    expected_modes = {
        700: "bus",
        100: "train",
        109: "train",
        900: "tram",
        1000: "ferry",
    }
    for route_type, expected_mode in expected_modes.items():
        assert _gtfs_mode(route_type) == expected_mode
def test_closed_address_way_is_stored_as_polygon_geometry():
    """Check the GeoJSON produced for closed address ways and the ``None`` result for triangles."""
    corners = [
        (8.68590, 49.40435),
        (8.68600, 49.40435),
        (8.68600, 49.40445),
        (8.68590, 49.40445),
    ]

    # Case 1: the ring is passed already closed (first corner repeated at the end).
    explicit_ring = [*corners, corners[0]]
    # Build the expectation before the call so it cannot be affected by the call itself.
    expected_polygon = {
        "type": "Polygon",
        "coordinates": [[list(corner) for corner in explicit_ring]],
    }
    polygon_json = _address_area_geometry_geojson(explicit_ring)
    assert json.loads(polygon_json or "{}") == expected_polygon

    # Case 2: open coordinates plus closed=True; the ring must end on the first corner.
    polygon_json = _address_area_geometry_geojson(list(corners), closed=True)
    outer_ring = json.loads(polygon_json or "{}")["coordinates"][0]
    assert outer_ring[-1] == [8.6859, 49.40435]

    # Case 3: three unclosed points yield no geometry, with or without closed=False.
    assert _address_area_geometry_geojson([(0, 0), (1, 0), (1, 1)]) is None
    assert _address_area_geometry_geojson([(0, 0), (1, 0), (1, 1)], closed=False) is None
def test_osm_route_scope_classifier_distinguishes_train_service_classes():
    """Check the scope ``infer_osm_route_scope`` returns for a table of mode/ref/name inputs."""
    cases = [
        ({"mode": "train", "ref": "ICE 28"}, "long_distance"),
        ({"mode": "train", "ref": "RE1"}, "regional"),
        ({"mode": "train", "ref": "S5", "network": "S-Bahn Berlin"}, "local"),
        ({"mode": "subway", "ref": "U5"}, "local"),
        ({"mode": "coach", "ref": "FLX"}, "long_distance"),
        ({"mode": "bus", "ref": "100"}, "local"),
        ({"mode": "bus", "ref": "800", "tags": {"bus": "regional"}}, "regional"),
        ({"mode": "bus", "name": "FlixBus Berlin Hamburg"}, "long_distance"),
    ]
    for classifier_kwargs, expected_scope in cases:
        assert infer_osm_route_scope(**classifier_kwargs) == expected_scope
def test_exact_line_ref_with_overlapping_geometry_scores_as_match_candidate():
    """Score an M11 bus route against an M11 OSM relation whose bbox contains the route's bbox."""
    gtfs_bbox = {"min_lon": 13.29, "min_lat": 52.42, "max_lon": 13.33, "max_lat": 52.45}
    osm_bbox = {"min_lon": 13.28, "min_lat": 52.40, "max_lon": 13.51, "max_lat": 52.46}

    gtfs_route = GtfsRoute(
        route_id="17441_700",
        short_name="M11",
        long_name=None,
        mode="bus",
        operator_name="Verkehrsverbund Berlin-Brandenburg",
        route_key="m11",
        **gtfs_bbox,
    )
    osm_relation = OsmFeature(
        osm_type="relation",
        osm_id="123",
        kind="route",
        mode="bus",
        ref="M11",
        name="Bus M11: U Dahlem-Dorf => S Schöneweide/Sterndamm",
        operator="Berliner Verkehrsbetriebe",
        network="Verkehrsverbund Berlin-Brandenburg",
        route_key="m11",
        **osm_bbox,
    )

    pair_score, pair_reasons = score_route_pair(gtfs_route, osm_relation)

    assert pair_score >= 85
    assert pair_reasons["line_identity"] == "exact_ref_mode_bbox_overlap"
def test_exact_line_ref_with_bbox_overlap_is_strong_without_name_or_operator_match():
    """Score an M11 pair with overlapping bboxes when names, operators and network are empty."""
    gtfs_bbox = {"min_lon": 13.29, "min_lat": 52.42, "max_lon": 13.33, "max_lat": 52.45}
    osm_bbox = {"min_lon": 13.30, "min_lat": 52.43, "max_lon": 13.35, "max_lat": 52.46}

    # Every descriptive text field is an empty string on both sides.
    gtfs_route = GtfsRoute(
        route_id="route-1",
        short_name="M11",
        long_name="",
        mode="bus",
        operator_name="",
        route_key="m11",
        **gtfs_bbox,
    )
    osm_relation = OsmFeature(
        osm_type="relation",
        osm_id="456",
        kind="route",
        mode="bus",
        ref="M11",
        name="",
        operator="",
        network="",
        route_key="m11",
        **osm_bbox,
    )

    pair_score, pair_reasons = score_route_pair(gtfs_route, osm_relation)

    assert pair_score >= 88
    assert pair_reasons["strong_identity"] == "exact_ref_mode_bbox_overlap"
def test_common_short_ref_candidates_are_spatially_ranked():
    """Check that the nearby bus "2" relation is listed before the distant one."""

    def bus_2_relation(pk, osm_id, west, south, east, north):
        # Both candidates share everything except primary key, OSM id and bounding box.
        return OsmFeature(
            id=pk,
            osm_type="relation",
            osm_id=osm_id,
            kind="route",
            mode="bus",
            ref="2",
            min_lon=west,
            min_lat=south,
            max_lon=east,
            max_lat=north,
            route_key="2",
        )

    gtfs_route = GtfsRoute(
        route_id="bus-2-berlin",
        short_name="2",
        mode="bus",
        min_lon=13.30,
        min_lat=52.40,
        max_lon=13.40,
        max_lat=52.50,
        route_key="2",
    )
    distant = bus_2_relation(1, "far", 7.0, 50.0, 7.1, 50.1)
    nearby = bus_2_relation(2, "near", 13.31, 52.41, 13.39, 52.49)

    # The distant relation is deliberately indexed first.
    osm_index = _build_osm_route_index([distant, nearby])
    ranked = _candidate_osm_routes(gtfs_route, osm_index)

    assert ranked[0].osm_id == "near"
def test_exact_ref_far_away_is_not_promoted_without_spatial_or_geometry_evidence():
    """Score a bus "2" pair with equal ref, mode and operator but widely separated bboxes."""
    shared_operator = "Example Operator"
    gtfs_bbox = {"min_lon": 13.30, "min_lat": 52.40, "max_lon": 13.40, "max_lat": 52.50}
    osm_bbox = {"min_lon": 6.9, "min_lat": 50.9, "max_lon": 7.1, "max_lat": 51.0}

    gtfs_route = GtfsRoute(
        route_id="bus-2-berlin",
        short_name="2",
        mode="bus",
        operator_name=shared_operator,
        route_key="2",
        **gtfs_bbox,
    )
    osm_relation = OsmFeature(
        osm_type="relation",
        osm_id="2-cologne",
        kind="route",
        mode="bus",
        ref="2",
        operator=shared_operator,
        route_key="2",
        **osm_bbox,
    )

    pair_score, pair_reasons = score_route_pair(gtfs_route, osm_relation)

    assert pair_score < 65
    assert pair_reasons["spatial_penalty"] == "exact_ref_far_bbox_center"
    assert pair_reasons["spatial_cap"] == "exact_ref_far_without_geometry_overlap"
def test_geometry_overlap_can_confirm_exact_ref_match():
    """Score a bus "2" pair whose line geometries run almost on top of each other."""

    def bbox_kwargs(bbox):
        # Spread a (min_lon, min_lat, max_lon, max_lat) sequence into model keyword arguments.
        return {
            "min_lon": bbox[0],
            "min_lat": bbox[1],
            "max_lon": bbox[2],
            "max_lat": bbox[3],
        }

    gtfs_points = [(13.30, 52.40), (13.35, 52.45), (13.40, 52.50)]
    osm_points = [(13.3005, 52.4005), (13.3505, 52.4505), (13.4005, 52.5005)]
    gtfs_geometry, gtfs_bbox = geometry_json_and_bbox(LineString(gtfs_points))
    osm_geometry, osm_bbox = geometry_json_and_bbox(LineString(osm_points))

    gtfs_route = GtfsRoute(
        route_id="bus-2-berlin",
        short_name="2",
        mode="bus",
        route_key="2",
        geometry_geojson=gtfs_geometry,
        **bbox_kwargs(gtfs_bbox),
    )
    osm_relation = OsmFeature(
        osm_type="relation",
        osm_id="2-berlin",
        kind="route",
        mode="bus",
        ref="2",
        route_key="2",
        geometry_geojson=osm_geometry,
        **bbox_kwargs(osm_bbox),
    )

    pair_score, pair_reasons = score_route_pair(gtfs_route, osm_relation)

    assert pair_score >= 90
    assert pair_reasons["strong_identity"] == "exact_ref_mode_geometry_overlap"
    assert pair_reasons["geometry"]["gtfs_on_osm_ratio"] >= 0.9
def test_route_match_scope_distinguishes_outside_loaded_osm_area():
    """Check ``route_match_scope`` for an OSM area containing the route and one far from it."""
    gtfs_route = GtfsRoute(min_lon=13.3, min_lat=52.4, max_lon=13.4, max_lat=52.5)
    expected_scope_by_area = {
        (13.0, 52.3, 13.8, 52.7): "in_osm_scope",
        (6.0, 50.0, 7.0, 51.0): "outside_osm_scope",
    }
    for osm_area, expected_scope in expected_scope_by_area.items():
        assert route_match_scope(gtfs_route, osm_area) == expected_scope