Alpha stage commit
This commit is contained in:
317
tests/test_pipeline.py
Normal file
317
tests/test_pipeline.py
Normal file
@@ -0,0 +1,317 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
from shapely.geometry import LineString
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.db import init_db, session_scope
|
||||
from app.models import GtfsRoute, OsmFeature, RouteMatch, RoutePattern
|
||||
from app.osm_classification import infer_osm_route_scope
|
||||
from app.pipeline.gtfs import _gtfs_mode
|
||||
from app.pipeline.matcher import _build_osm_route_index, _candidate_osm_routes, route_match_scope, run_route_matching, score_route_pair
|
||||
from app.pipeline.osm_addresses import _address_area_geometry_geojson
|
||||
from app.pipeline.sample_data import load_sample_project
|
||||
from app.pipeline.route_layer import rebuild_route_layer
|
||||
from app.pipeline.utils import geometry_json_and_bbox
|
||||
|
||||
|
||||
def test_sample_pipeline_imports_and_matches():
    """Load the sample project and check the imported routes and match statuses.

    After ``init_db()``, ``load_sample_project`` must report six matched
    routes, both the GTFS route and the OSM feature for line "RE1" must be
    present, and the stored match statuses must include "matched" plus at
    least one of "weak" or "missing".
    """
    init_db()
    with session_scope() as session:
        summary = load_sample_project(session)
        assert summary["match_result"]["routes"] == 6

        gtfs_re1 = session.scalar(select(GtfsRoute).where(GtfsRoute.short_name == "RE1"))
        assert gtfs_re1 is not None

        osm_re1 = session.scalar(select(OsmFeature).where(OsmFeature.ref == "RE1"))
        assert osm_re1 is not None

        # Only the distinct status values matter for the checks below; the
        # confidence column is selected but not inspected.
        match_rows = session.execute(select(RouteMatch.status, RouteMatch.confidence))
        seen_statuses = {status for status, _confidence in match_rows}

        assert "matched" in seen_statuses
        assert seen_statuses & {"weak", "missing"}
|
||||
|
||||
|
||||
def test_route_matching_preserves_unchanged_match_rows():
    """Re-running route matching on the loaded sample must not touch match rows.

    A snapshot of (id, updated_at) per GTFS route is taken before and after
    an extra ``run_route_matching`` call; the snapshots must be equal and the
    returned counters must report every route as unchanged.
    """
    init_db()
    with session_scope() as session:
        load_sample_project(session)

        def snapshot_matches():
            # Key each match row by its GTFS route id so rows can be compared
            # across the two runs.
            rows = session.execute(
                select(RouteMatch.gtfs_route_id, RouteMatch.id, RouteMatch.updated_at)
            )
            return {gtfs_route_id: (row_id, stamp) for gtfs_route_id, row_id, stamp in rows}

        rows_before = snapshot_matches()
        outcome = run_route_matching(session)
        rows_after = snapshot_matches()

        assert outcome["unchanged"] == outcome["routes"]
        assert outcome["created"] == 0
        assert outcome["updated"] == 0
        assert rows_after == rows_before
|
||||
|
||||
|
||||
def test_route_layer_reuses_unchanged_route_patterns():
    """Rebuilding the route layer on unchanged data must keep existing patterns.

    The mapping of pattern key to RoutePattern id is captured before and
    after ``rebuild_route_layer``; it must be identical, and the returned
    counters must show nothing created or removed and everything reused.
    """
    init_db()
    with session_scope() as session:
        load_sample_project(session)

        def snapshot_patterns():
            # pattern_key -> primary key of the stored pattern row.
            rows = session.execute(select(RoutePattern.pattern_key, RoutePattern.id))
            return {key: row_id for key, row_id in rows}

        patterns_before = snapshot_patterns()
        outcome = rebuild_route_layer(session)
        patterns_after = snapshot_patterns()

        assert outcome["route_patterns_created"] == 0
        assert outcome["route_patterns_removed"] == 0
        assert outcome["route_patterns_reused"] == outcome["route_patterns"]
        assert patterns_after == patterns_before
|
||||
|
||||
|
||||
def test_extended_gtfs_route_types_are_mapped_to_modes():
    """Check that ``_gtfs_mode`` maps three- and four-digit route types to modes."""
    expected_modes = (
        (700, "bus"),
        (100, "train"),
        (109, "train"),
        (900, "tram"),
        (1000, "ferry"),
    )
    for route_type, mode in expected_modes:
        assert _gtfs_mode(route_type) == mode
|
||||
|
||||
|
||||
def test_closed_address_way_is_stored_as_polygon_geometry():
    """Check polygon GeoJSON output of ``_address_area_geometry_geojson``.

    Covers three situations: a ring whose last point repeats the first, an
    open four-point ring passed with ``closed=True`` (the first point must be
    appended as the last coordinate), and a three-point input, which must
    yield ``None`` both by default and with ``closed=False``.
    """
    open_ring = [
        (8.68590, 49.40435),
        (8.68600, 49.40435),
        (8.68600, 49.40445),
        (8.68590, 49.40445),
    ]
    first_corner = (8.68590, 49.40435)

    # Ring that is already explicitly closed by repeating the first corner.
    explicit_geojson = _address_area_geometry_geojson([*open_ring, first_corner])
    expected_polygon = {
        "type": "Polygon",
        "coordinates": [
            [
                [8.6859, 49.40435],
                [8.686, 49.40435],
                [8.686, 49.40445],
                [8.6859, 49.40445],
                [8.6859, 49.40435],
            ]
        ],
    }
    assert json.loads(explicit_geojson or "{}") == expected_polygon

    # Open ring flagged as closed by the caller.
    flagged_geojson = _address_area_geometry_geojson(list(open_ring), closed=True)
    flagged_polygon = json.loads(flagged_geojson or "{}")
    assert flagged_polygon["coordinates"][0][-1] == [8.6859, 49.40435]

    assert _address_area_geometry_geojson([(0, 0), (1, 0), (1, 1)]) is None
    assert _address_area_geometry_geojson([(0, 0), (1, 0), (1, 1)], closed=False) is None
|
||||
|
||||
|
||||
def test_osm_route_scope_classifier_distinguishes_train_service_classes():
    """Check ``infer_osm_route_scope`` across train, subway, coach and bus inputs."""
    scope_cases = [
        ({"mode": "train", "ref": "ICE 28"}, "long_distance"),
        ({"mode": "train", "ref": "RE1"}, "regional"),
        ({"mode": "train", "ref": "S5", "network": "S-Bahn Berlin"}, "local"),
        ({"mode": "subway", "ref": "U5"}, "local"),
        ({"mode": "coach", "ref": "FLX"}, "long_distance"),
        ({"mode": "bus", "ref": "100"}, "local"),
        ({"mode": "bus", "ref": "800", "tags": {"bus": "regional"}}, "regional"),
        ({"mode": "bus", "name": "FlixBus Berlin Hamburg"}, "long_distance"),
    ]
    for classifier_kwargs, expected_scope in scope_cases:
        assert infer_osm_route_scope(**classifier_kwargs) == expected_scope
|
||||
|
||||
|
||||
def test_exact_line_ref_with_overlapping_geometry_scores_as_match_candidate():
    """Score a GTFS/OSM pair sharing ref "M11", mode and overlapping bounding boxes.

    The OSM bounding box fully covers the GTFS one. The pair must score at
    least 85 and report ``exact_ref_mode_bbox_overlap`` under
    ``reasons["line_identity"]``.
    """
    gtfs_bbox = {"min_lon": 13.29, "min_lat": 52.42, "max_lon": 13.33, "max_lat": 52.45}
    osm_bbox = {"min_lon": 13.28, "min_lat": 52.40, "max_lon": 13.51, "max_lat": 52.46}

    gtfs_route = GtfsRoute(
        route_id="17441_700",
        short_name="M11",
        long_name=None,
        mode="bus",
        operator_name="Verkehrsverbund Berlin-Brandenburg",
        **gtfs_bbox,
        route_key="m11",
    )
    osm_route = OsmFeature(
        osm_type="relation",
        osm_id="123",
        kind="route",
        mode="bus",
        ref="M11",
        name="Bus M11: U Dahlem-Dorf => S Schöneweide/Sterndamm",
        operator="Berliner Verkehrsbetriebe",
        network="Verkehrsverbund Berlin-Brandenburg",
        **osm_bbox,
        route_key="m11",
    )

    pair_score, pair_reasons = score_route_pair(gtfs_route, osm_route)

    assert pair_score >= 85
    assert pair_reasons["line_identity"] == "exact_ref_mode_bbox_overlap"
|
||||
|
||||
|
||||
def test_exact_line_ref_with_bbox_overlap_is_strong_without_name_or_operator_match():
    """Score a pair that shares only ref, mode and overlapping bounding boxes.

    Name, operator and network are empty strings on both sides. The pair
    must still score at least 88 and report ``exact_ref_mode_bbox_overlap``
    under ``reasons["strong_identity"]``.
    """
    gtfs_bbox = {"min_lon": 13.29, "min_lat": 52.42, "max_lon": 13.33, "max_lat": 52.45}
    osm_bbox = {"min_lon": 13.30, "min_lat": 52.43, "max_lon": 13.35, "max_lat": 52.46}

    gtfs_route = GtfsRoute(
        route_id="route-1",
        short_name="M11",
        long_name="",
        mode="bus",
        operator_name="",
        **gtfs_bbox,
        route_key="m11",
    )
    osm_route = OsmFeature(
        osm_type="relation",
        osm_id="456",
        kind="route",
        mode="bus",
        ref="M11",
        name="",
        operator="",
        network="",
        **osm_bbox,
        route_key="m11",
    )

    pair_score, pair_reasons = score_route_pair(gtfs_route, osm_route)

    assert pair_score >= 88
    assert pair_reasons["strong_identity"] == "exact_ref_mode_bbox_overlap"
|
||||
|
||||
|
||||
def test_common_short_ref_candidates_are_spatially_ranked():
    """Check that candidates sharing ref "2" are ordered with the nearby one first.

    Two OSM bus routes with the same ref are indexed, the distant one listed
    first; the first candidate returned for the GTFS route must be the one
    whose bounding box lies inside the GTFS route's bounding box.
    """

    def osm_bus_two(row_id, osm_id, west, south, east, north):
        # Both candidates are identical except for identity and bounding box.
        return OsmFeature(
            id=row_id,
            osm_type="relation",
            osm_id=osm_id,
            kind="route",
            mode="bus",
            ref="2",
            min_lon=west,
            min_lat=south,
            max_lon=east,
            max_lat=north,
            route_key="2",
        )

    gtfs_route = GtfsRoute(
        route_id="bus-2-berlin",
        short_name="2",
        mode="bus",
        min_lon=13.30,
        min_lat=52.40,
        max_lon=13.40,
        max_lat=52.50,
        route_key="2",
    )
    distant = osm_bus_two(1, "far", 7.0, 50.0, 7.1, 50.1)
    nearby = osm_bus_two(2, "near", 13.31, 52.41, 13.39, 52.49)

    osm_index = _build_osm_route_index([distant, nearby])
    ranked = _candidate_osm_routes(gtfs_route, osm_index)

    assert ranked[0].osm_id == "near"
|
||||
|
||||
|
||||
def test_exact_ref_far_away_is_not_promoted_without_spatial_or_geometry_evidence():
    """Score a pair with matching ref, mode and operator but distant bounding boxes.

    The GTFS box sits around lon 13.3-13.4 / lat 52.4-52.5 and the OSM box
    around lon 6.9-7.1 / lat 50.9-51.0. The score must stay below 65 and the
    reasons must carry both the far-bbox penalty and the spatial cap markers.
    """
    gtfs_bbox = {"min_lon": 13.30, "min_lat": 52.40, "max_lon": 13.40, "max_lat": 52.50}
    osm_bbox = {"min_lon": 6.9, "min_lat": 50.9, "max_lon": 7.1, "max_lat": 51.0}

    gtfs_route = GtfsRoute(
        route_id="bus-2-berlin",
        short_name="2",
        mode="bus",
        operator_name="Example Operator",
        **gtfs_bbox,
        route_key="2",
    )
    osm_route = OsmFeature(
        osm_type="relation",
        osm_id="2-cologne",
        kind="route",
        mode="bus",
        ref="2",
        operator="Example Operator",
        **osm_bbox,
        route_key="2",
    )

    pair_score, pair_reasons = score_route_pair(gtfs_route, osm_route)

    assert pair_score < 65
    assert pair_reasons["spatial_penalty"] == "exact_ref_far_bbox_center"
    assert pair_reasons["spatial_cap"] == "exact_ref_far_without_geometry_overlap"
|
||||
|
||||
|
||||
def test_geometry_overlap_can_confirm_exact_ref_match():
    """Score a pair whose line geometries run almost on top of each other.

    The OSM line is the GTFS line shifted by 0.0005 degrees on both axes.
    The pair must score at least 90, report
    ``exact_ref_mode_geometry_overlap`` under ``reasons["strong_identity"]``
    and a ``gtfs_on_osm_ratio`` of at least 0.9.
    """

    def bbox_fields(bbox):
        # Spread a bbox sequence over the four model columns by position.
        return {
            "min_lon": bbox[0],
            "min_lat": bbox[1],
            "max_lon": bbox[2],
            "max_lat": bbox[3],
        }

    gtfs_line = LineString([(13.30, 52.40), (13.35, 52.45), (13.40, 52.50)])
    gtfs_geometry, gtfs_bbox = geometry_json_and_bbox(gtfs_line)
    osm_line = LineString([(13.3005, 52.4005), (13.3505, 52.4505), (13.4005, 52.5005)])
    osm_geometry, osm_bbox = geometry_json_and_bbox(osm_line)

    gtfs_route = GtfsRoute(
        route_id="bus-2-berlin",
        short_name="2",
        mode="bus",
        **bbox_fields(gtfs_bbox),
        route_key="2",
        geometry_geojson=gtfs_geometry,
    )
    osm_route = OsmFeature(
        osm_type="relation",
        osm_id="2-berlin",
        kind="route",
        mode="bus",
        ref="2",
        **bbox_fields(osm_bbox),
        route_key="2",
        geometry_geojson=osm_geometry,
    )

    pair_score, pair_reasons = score_route_pair(gtfs_route, osm_route)

    assert pair_score >= 90
    assert pair_reasons["strong_identity"] == "exact_ref_mode_geometry_overlap"
    assert pair_reasons["geometry"]["gtfs_on_osm_ratio"] >= 0.9
|
||||
|
||||
|
||||
def test_route_match_scope_distinguishes_outside_loaded_osm_area():
    """Check ``route_match_scope`` for an enclosing and a disjoint OSM bounding box."""
    gtfs_route = GtfsRoute(min_lon=13.3, min_lat=52.4, max_lon=13.4, max_lat=52.5)

    # OSM area that fully contains the route's bounding box.
    enclosing_area = (13.0, 52.3, 13.8, 52.7)
    # OSM area that does not touch the route's bounding box at all.
    disjoint_area = (6.0, 50.0, 7.0, 51.0)

    assert route_match_scope(gtfs_route, enclosing_area) == "in_osm_scope"
    assert route_match_scope(gtfs_route, disjoint_area) == "outside_osm_scope"
|
||||
Reference in New Issue
Block a user