318 lines
9.8 KiB
Python
318 lines
9.8 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
|
|
from shapely.geometry import LineString
|
|
from sqlalchemy import select
|
|
|
|
from app.db import init_db, session_scope
|
|
from app.models import GtfsRoute, OsmFeature, RouteMatch, RoutePattern
|
|
from app.osm_classification import infer_osm_route_scope
|
|
from app.pipeline.gtfs import _gtfs_mode
|
|
from app.pipeline.matcher import _build_osm_route_index, _candidate_osm_routes, route_match_scope, run_route_matching, score_route_pair
|
|
from app.pipeline.osm_addresses import _address_area_geometry_geojson
|
|
from app.pipeline.sample_data import load_sample_project
|
|
from app.pipeline.route_layer import rebuild_route_layer
|
|
from app.pipeline.utils import geometry_json_and_bbox
|
|
|
|
|
|
def test_sample_pipeline_imports_and_matches():
    """Load the sample project and check the imported and matched data.

    Expects six routes in the reported match result, an "RE1" entry on both
    the GTFS and the OSM side, and match statuses that include "matched" plus
    at least one of "weak" / "missing".
    """
    init_db()
    with session_scope() as session:
        load_result = load_sample_project(session)
        assert load_result["match_result"]["routes"] == 6

        gtfs_re1 = session.scalar(select(GtfsRoute).where(GtfsRoute.short_name == "RE1"))
        assert gtfs_re1 is not None
        osm_re1 = session.scalar(select(OsmFeature).where(OsmFeature.ref == "RE1"))
        assert osm_re1 is not None

        # Only the distinct status values matter for the checks below.
        status_rows = session.execute(select(RouteMatch.status, RouteMatch.confidence))
        seen_statuses = {row[0] for row in status_rows}
        assert "matched" in seen_statuses
        assert seen_statuses & {"weak", "missing"}
|
|
|
|
|
|
def test_route_matching_preserves_unchanged_match_rows():
    """Re-running route matching on unchanged data must not touch RouteMatch rows.

    The row id and ``updated_at`` of every match are captured before and after
    a second ``run_route_matching`` call and must be identical; the run itself
    must report every route as unchanged and nothing created or updated.
    """
    init_db()
    with session_scope() as session:
        load_sample_project(session)

        def snapshot_matches():
            # GTFS route id -> (match row id, updated_at timestamp).
            query = select(RouteMatch.gtfs_route_id, RouteMatch.id, RouteMatch.updated_at)
            return {row[0]: (row[1], row[2]) for row in session.execute(query)}

        rows_before = snapshot_matches()

        rerun = run_route_matching(session)

        rows_after = snapshot_matches()
        assert rerun["unchanged"] == rerun["routes"]
        assert rerun["created"] == 0
        assert rerun["updated"] == 0
        assert rows_after == rows_before
|
|
|
|
|
|
def test_route_layer_reuses_unchanged_route_patterns():
    """Rebuilding the route layer on unchanged data must keep existing patterns.

    The pattern-key -> row-id mapping is captured before and after
    ``rebuild_route_layer`` and must be identical; the rebuild must report no
    created or removed patterns and every pattern as reused.
    """
    init_db()
    with session_scope() as session:
        load_sample_project(session)

        def pattern_ids_by_key():
            # Pattern key -> primary key of the stored RoutePattern row.
            ids = {}
            for row in session.execute(select(RoutePattern.pattern_key, RoutePattern.id)):
                ids[row[0]] = row[1]
            return ids

        ids_before = pattern_ids_by_key()

        rebuild = rebuild_route_layer(session)

        ids_after = pattern_ids_by_key()
        assert rebuild["route_patterns_created"] == 0
        assert rebuild["route_patterns_removed"] == 0
        assert rebuild["route_patterns_reused"] == rebuild["route_patterns"]
        assert ids_after == ids_before
|
|
|
|
|
|
def test_extended_gtfs_route_types_are_mapped_to_modes():
    """Check that ``_gtfs_mode`` maps three- and four-digit route types to modes."""
    # (GTFS route_type code, expected mode string)
    expected_modes = (
        (700, "bus"),
        (100, "train"),
        (109, "train"),
        (900, "tram"),
        (1000, "ferry"),
    )
    for route_type, mode in expected_modes:
        assert _gtfs_mode(route_type) == mode
|
|
|
|
|
|
def test_closed_address_way_is_stored_as_polygon_geometry():
    """Check how ``_address_area_geometry_geojson`` turns address ways into polygons.

    Covers an explicitly closed ring, an open ring passed with ``closed=True``,
    and a three-point open way that must yield ``None`` with and without
    ``closed=False``.
    """
    corners = [
        (8.68590, 49.40435),
        (8.68600, 49.40435),
        (8.68600, 49.40445),
        (8.68590, 49.40445),
    ]
    # The ring expected in the GeoJSON output: four corners plus the repeated start.
    expected_ring = [
        [8.6859, 49.40435],
        [8.686, 49.40435],
        [8.686, 49.40445],
        [8.6859, 49.40445],
        [8.6859, 49.40435],
    ]

    # Ring whose last coordinate already repeats the first one.
    explicit = _address_area_geometry_geojson([*corners, corners[0]])

    assert json.loads(explicit or "{}") == {
        "type": "Polygon",
        "coordinates": [expected_ring],
    }

    # Open coordinate list flagged as closed: the output ring must end on the start point.
    flagged = _address_area_geometry_geojson(list(corners), closed=True)
    assert json.loads(flagged or "{}")["coordinates"][0][-1] == expected_ring[0]

    # Three points without closure do not produce an area geometry.
    triangle = [(0, 0), (1, 0), (1, 1)]
    assert _address_area_geometry_geojson(list(triangle)) is None
    assert _address_area_geometry_geojson(list(triangle), closed=False) is None
|
|
|
|
|
|
def test_osm_route_scope_classifier_distinguishes_train_service_classes():
    """Check the scope returned by ``infer_osm_route_scope`` for sample routes."""
    # (keyword arguments passed to the classifier, expected scope)
    cases = [
        ({"mode": "train", "ref": "ICE 28"}, "long_distance"),
        ({"mode": "train", "ref": "RE1"}, "regional"),
        ({"mode": "train", "ref": "S5", "network": "S-Bahn Berlin"}, "local"),
        ({"mode": "subway", "ref": "U5"}, "local"),
        ({"mode": "coach", "ref": "FLX"}, "long_distance"),
        ({"mode": "bus", "ref": "100"}, "local"),
        ({"mode": "bus", "ref": "800", "tags": {"bus": "regional"}}, "regional"),
        ({"mode": "bus", "name": "FlixBus Berlin Hamburg"}, "long_distance"),
    ]
    for classifier_kwargs, expected_scope in cases:
        assert infer_osm_route_scope(**classifier_kwargs) == expected_scope
|
|
|
|
|
|
def test_exact_line_ref_with_overlapping_geometry_scores_as_match_candidate():
    """Score a GTFS/OSM pair sharing ref "M11", mode "bus" and overlapping bounding boxes.

    The pair must score at least 85 and report the
    "exact_ref_mode_bbox_overlap" line identity reason.
    """
    # The GTFS bounding box lies entirely inside the OSM one.
    gtfs_bbox = {"min_lon": 13.29, "min_lat": 52.42, "max_lon": 13.33, "max_lat": 52.45}
    osm_bbox = {"min_lon": 13.28, "min_lat": 52.40, "max_lon": 13.51, "max_lat": 52.46}

    gtfs_route = GtfsRoute(
        route_id="17441_700",
        short_name="M11",
        long_name=None,
        mode="bus",
        operator_name="Verkehrsverbund Berlin-Brandenburg",
        route_key="m11",
        **gtfs_bbox,
    )
    osm_relation = OsmFeature(
        osm_type="relation",
        osm_id="123",
        kind="route",
        mode="bus",
        ref="M11",
        name="Bus M11: U Dahlem-Dorf => S Schöneweide/Sterndamm",
        operator="Berliner Verkehrsbetriebe",
        network="Verkehrsverbund Berlin-Brandenburg",
        route_key="m11",
        **osm_bbox,
    )

    total, details = score_route_pair(gtfs_route, osm_relation)

    assert total >= 85
    assert details["line_identity"] == "exact_ref_mode_bbox_overlap"
|
|
|
|
|
|
def test_exact_line_ref_with_bbox_overlap_is_strong_without_name_or_operator_match():
    """Score a pair that shares only ref, mode and overlapping bounding boxes.

    Names, operator and network are empty strings on both sides; the pair must
    still score at least 88 with the "exact_ref_mode_bbox_overlap" strong
    identity reason.
    """
    # The two bounding boxes overlap only partially.
    gtfs_bbox = {"min_lon": 13.29, "min_lat": 52.42, "max_lon": 13.33, "max_lat": 52.45}
    osm_bbox = {"min_lon": 13.30, "min_lat": 52.43, "max_lon": 13.35, "max_lat": 52.46}

    gtfs_route = GtfsRoute(
        route_id="route-1",
        short_name="M11",
        long_name="",
        mode="bus",
        operator_name="",
        route_key="m11",
        **gtfs_bbox,
    )
    osm_relation = OsmFeature(
        osm_type="relation",
        osm_id="456",
        kind="route",
        mode="bus",
        ref="M11",
        name="",
        operator="",
        network="",
        route_key="m11",
        **osm_bbox,
    )

    total, details = score_route_pair(gtfs_route, osm_relation)

    assert total >= 88
    assert details["strong_identity"] == "exact_ref_mode_bbox_overlap"
|
|
|
|
|
|
def test_common_short_ref_candidates_are_spatially_ranked():
    """Two OSM routes share ref "2"; the one near the GTFS route must rank first."""
    gtfs_route = GtfsRoute(
        route_id="bus-2-berlin",
        short_name="2",
        mode="bus",
        min_lon=13.30,
        min_lat=52.40,
        max_lon=13.40,
        max_lat=52.50,
        route_key="2",
    )

    # Attributes common to both OSM candidates; only id and bounding box differ.
    shared = {
        "osm_type": "relation",
        "kind": "route",
        "mode": "bus",
        "ref": "2",
        "route_key": "2",
    }
    distant = OsmFeature(
        id=1,
        osm_id="far",
        min_lon=7.0,
        min_lat=50.0,
        max_lon=7.1,
        max_lat=50.1,
        **shared,
    )
    nearby = OsmFeature(
        id=2,
        osm_id="near",
        min_lon=13.31,
        min_lat=52.41,
        max_lon=13.39,
        max_lat=52.49,
        **shared,
    )

    # The distant feature is indexed first, so it does not lead by input order.
    osm_index = _build_osm_route_index([distant, nearby])
    ranked = _candidate_osm_routes(gtfs_route, osm_index)

    assert ranked[0].osm_id == "near"
|
|
|
|
|
|
def test_exact_ref_far_away_is_not_promoted_without_spatial_or_geometry_evidence():
    """A same-ref, same-operator OSM route with a distant bounding box must stay below 65.

    The reasons must carry both the far-bbox-center penalty and the
    no-geometry-overlap cap.
    """
    berlin_bbox = {"min_lon": 13.30, "min_lat": 52.40, "max_lon": 13.40, "max_lat": 52.50}
    cologne_bbox = {"min_lon": 6.9, "min_lat": 50.9, "max_lon": 7.1, "max_lat": 51.0}

    gtfs_route = GtfsRoute(
        route_id="bus-2-berlin",
        short_name="2",
        mode="bus",
        operator_name="Example Operator",
        route_key="2",
        **berlin_bbox,
    )
    osm_relation = OsmFeature(
        osm_type="relation",
        osm_id="2-cologne",
        kind="route",
        mode="bus",
        ref="2",
        operator="Example Operator",
        route_key="2",
        **cologne_bbox,
    )

    total, details = score_route_pair(gtfs_route, osm_relation)

    assert total < 65
    assert details["spatial_penalty"] == "exact_ref_far_bbox_center"
    assert details["spatial_cap"] == "exact_ref_far_without_geometry_overlap"
|
|
|
|
|
|
def test_geometry_overlap_can_confirm_exact_ref_match():
    """Two nearly coincident line geometries with the same ref must score at least 90.

    The reasons must name the geometry-overlap strong identity and report a
    GTFS-on-OSM ratio of at least 0.9.
    """

    def bbox_kwargs(bounds):
        # Spread a 4-item bounds sequence onto the model's bbox column names.
        return {
            "min_lon": bounds[0],
            "min_lat": bounds[1],
            "max_lon": bounds[2],
            "max_lat": bounds[3],
        }

    gtfs_line = LineString([(13.30, 52.40), (13.35, 52.45), (13.40, 52.50)])
    # Same path, shifted by 0.0005 degrees on both axes.
    osm_line = LineString([(13.3005, 52.4005), (13.3505, 52.4505), (13.4005, 52.5005)])
    gtfs_geojson, gtfs_bounds = geometry_json_and_bbox(gtfs_line)
    osm_geojson, osm_bounds = geometry_json_and_bbox(osm_line)

    gtfs_route = GtfsRoute(
        route_id="bus-2-berlin",
        short_name="2",
        mode="bus",
        route_key="2",
        geometry_geojson=gtfs_geojson,
        **bbox_kwargs(gtfs_bounds),
    )
    osm_relation = OsmFeature(
        osm_type="relation",
        osm_id="2-berlin",
        kind="route",
        mode="bus",
        ref="2",
        route_key="2",
        geometry_geojson=osm_geojson,
        **bbox_kwargs(osm_bounds),
    )

    total, details = score_route_pair(gtfs_route, osm_relation)

    assert total >= 90
    assert details["strong_identity"] == "exact_ref_mode_geometry_overlap"
    assert details["geometry"]["gtfs_on_osm_ratio"] >= 0.9
|
|
|
|
|
|
def test_route_match_scope_distinguishes_outside_loaded_osm_area():
    """Check ``route_match_scope`` for an enclosing and a disjoint OSM bounding box."""
    gtfs_route = GtfsRoute(min_lon=13.3, min_lat=52.4, max_lon=13.4, max_lat=52.5)
    # (OSM bounding box passed as second argument, expected scope label)
    expectations = (
        ((13.0, 52.3, 13.8, 52.7), "in_osm_scope"),
        ((6.0, 50.0, 7.0, 51.0), "outside_osm_scope"),
    )
    for osm_bbox, expected_scope in expectations:
        assert route_match_scope(gtfs_route, osm_bbox) == expected_scope
|