Alpha stage commit
This commit is contained in:
13
tests/conftest.py
Normal file
13
tests/conftest.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
TEST_DATA_DIR = Path("./data/test-runtime")
|
||||
|
||||
shutil.rmtree(TEST_DATA_DIR, ignore_errors=True)
|
||||
os.environ["QUEUE_WORKER_AUTOSTART"] = "false"
|
||||
os.environ["DATA_DIR"] = str(TEST_DATA_DIR)
|
||||
os.environ["DATABASE_URL"] = f"sqlite:///{TEST_DATA_DIR / 'test_workbench.sqlite'}"
|
||||
33
tests/test_address_search.py
Normal file
33
tests/test_address_search.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from app.address_search import _folded_query_candidates, _numbered_query_candidates, coordinate_token, parse_coordinate_token
|
||||
|
||||
|
||||
def test_numbered_address_query_accepts_city_first_without_comma():
|
||||
assert ("alexanderplatz", "1", "berlin") in _numbered_query_candidates("Berlin Alexanderplatz 1")
|
||||
|
||||
|
||||
def test_numbered_address_query_accepts_city_last_without_comma():
|
||||
assert ("alexanderplatz", "1", "berlin") in _numbered_query_candidates("Alexanderplatz 1 Berlin")
|
||||
|
||||
|
||||
def test_numbered_address_query_prefers_comma_locality():
|
||||
assert _numbered_query_candidates("Berlin, Alexanderplatz 1")[0] == ("alexanderplatz", "1", "berlin")
|
||||
|
||||
|
||||
def test_folded_address_query_accepts_city_first_without_comma():
|
||||
assert ("alexanderplatz", "berlin") in _folded_query_candidates("Berlin Alexanderplatz")
|
||||
|
||||
|
||||
def test_folded_address_query_accepts_city_last_without_comma():
|
||||
assert ("alexanderplatz", "berlin") in _folded_query_candidates("Alexanderplatz Berlin")
|
||||
|
||||
|
||||
def test_folded_address_query_prefers_comma_locality():
|
||||
assert _folded_query_candidates("Berlin, Alexanderplatz")[0] == ("alexanderplatz", "berlin")
|
||||
|
||||
|
||||
def test_coordinate_token_round_trips():
|
||||
token = coordinate_token(49.404539659, 8.685940101)
|
||||
assert token == "coord:49.4045397:8.6859401"
|
||||
assert parse_coordinate_token(token) == (49.4045397, 8.6859401)
|
||||
666
tests/test_api.py
Normal file
666
tests/test_api.py
Normal file
@@ -0,0 +1,666 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
from sqlalchemy import select
|
||||
|
||||
import app.jobs as jobs_module
|
||||
import app.main as main_module
|
||||
from app.config import settings
|
||||
from app.db import init_db, session_scope
|
||||
from app.db_lock import DatabaseWriteBusy, database_write_lock
|
||||
from app.jobs import run_worker_once
|
||||
from app.main import app
|
||||
from app.models import Dataset, GtfsRoute, Job, Source
|
||||
from app.source_catalog import import_ingestable_sources
|
||||
|
||||
|
||||
def test_api_sample_and_geojson():
|
||||
client = TestClient(app)
|
||||
response = client.get("/")
|
||||
assert response.status_code == 200
|
||||
assert "Mobility Workbench" in response.text
|
||||
assert "GTFS Harmonization" in response.text
|
||||
assert "Mapping Data" in response.text
|
||||
assert "journeyTransitSnapshot" in response.text
|
||||
assert "journeySource" not in response.text
|
||||
|
||||
response = client.post("/api/sample/reset")
|
||||
assert response.status_code == 200
|
||||
stats = client.get("/api/stats").json()
|
||||
assert stats["gtfs_routes"] == 6
|
||||
assert stats["osm_routes"] == 6
|
||||
geojson = client.get("/api/map/gtfs_routes.geojson").json()
|
||||
assert geojson["type"] == "FeatureCollection"
|
||||
assert len(geojson["features"]) == 6
|
||||
matched_geojson = client.get("/api/map/matched_gtfs_routes.geojson?status=matched").json()
|
||||
assert matched_geojson["features"]
|
||||
assert {feature["properties"]["visual_source"] for feature in matched_geojson["features"]} == {"osm"}
|
||||
filtered = client.get("/api/map/osm_features.geojson?kind=route&mode=tram&bbox=13.3,52.4,13.5,52.6").json()
|
||||
assert filtered["type"] == "FeatureCollection"
|
||||
assert {feature["properties"]["ref"] for feature in filtered["features"]} == {"M5", "M10"}
|
||||
source_filtered_gtfs = client.get("/api/map/gtfs_routes.geojson?source_id=1").json()
|
||||
assert len(source_filtered_gtfs["features"]) == 6
|
||||
source_filtered_osm = client.get("/api/map/osm_features.geojson?source_id=2&kind=route&mode=tram").json()
|
||||
assert {feature["properties"]["ref"] for feature in source_filtered_osm["features"]} == {"M5", "M10"}
|
||||
route_layer = client.post("/api/route-layer/build").json()
|
||||
assert route_layer["route_patterns"] > 0
|
||||
assert client.get("/api/stats").json()["route_patterns"] == route_layer["route_patterns"]
|
||||
regional_osm = client.get("/api/map/osm_features.geojson?kind=route&mode=train&route_scope=regional").json()
|
||||
assert {feature["properties"]["ref"] for feature in regional_osm["features"]} == {"RE1"}
|
||||
regional_patterns = client.get("/api/map/route_patterns.geojson?mode=train&source_kind=osm&route_scope=regional").json()
|
||||
assert {feature["properties"]["ref"] for feature in regional_patterns["features"]} == {"RE1"}
|
||||
local_patterns = client.get("/api/map/route_patterns.geojson?mode=subway&source_kind=osm&route_scope=local").json()
|
||||
assert {feature["properties"]["ref"] for feature in local_patterns["features"]} == {"U2"}
|
||||
local_bus_patterns = client.get("/api/map/route_patterns.geojson?mode=bus&source_kind=osm&route_scope=local").json()
|
||||
assert {feature["properties"]["ref"] for feature in local_bus_patterns["features"]} == {"100"}
|
||||
|
||||
|
||||
def test_journey_demo_direct_and_one_transfer():
|
||||
client = TestClient(app)
|
||||
assert client.post("/api/sample/reset").status_code == 200
|
||||
|
||||
hbf = _first_stop(client, "Hauptbahnhof")
|
||||
alex = _first_stop(client, "Alexanderplatz")
|
||||
direct = client.get(f"/api/journey/search?from_stop_id={hbf['id']}&to_stop_id={alex['id']}&departure=08:00&max_transfers=0").json()
|
||||
assert direct["journeys"]
|
||||
assert direct["journeys"][0]["transfers"] == 0
|
||||
assert direct["journeys"][0]["legs"][0]["route_ref"] in {"RE1", "M5"}
|
||||
coords = direct["journeys"][0]["features"]["features"][0]["geometry"]["coordinates"]
|
||||
assert coords[-1] == [13.4132, 52.5219]
|
||||
assert [13.4344, 52.51] not in coords
|
||||
stop_roles = {
|
||||
feature["properties"]["role"]
|
||||
for feature in direct["journeys"][0]["features"]["features"]
|
||||
if feature["geometry"]["type"] == "Point"
|
||||
}
|
||||
assert {"start", "end", "passed"} <= stop_roles
|
||||
|
||||
zoo = _first_stop(client, "Zoologischer")
|
||||
ost = _first_stop(client, "Ostbahnhof")
|
||||
transfer = client.get(
|
||||
f"/api/journey/search?from_stop_id={zoo['id']}&to_stop_id={ost['id']}&departure=08:00&max_transfers=1&transfer_seconds=0"
|
||||
).json()
|
||||
assert transfer["journeys"]
|
||||
assert transfer["journeys"][0]["transfers"] == 1
|
||||
assert [leg["route_ref"] for leg in transfer["journeys"][0]["legs"]] == ["100", "RE1"]
|
||||
|
||||
|
||||
def test_route_layer_job_endpoint_completes():
|
||||
client = TestClient(app)
|
||||
assert client.post("/api/sample/reset").status_code == 200
|
||||
|
||||
queued = client.post("/api/jobs/route-layer-build").json()
|
||||
assert queued["kind"] == "route_layer_rebuild"
|
||||
assert queued["status"] == "queued"
|
||||
assert queued["priority"] == 0
|
||||
|
||||
worker = run_worker_once(worker_id="test-worker")
|
||||
assert worker["processed"] == 1
|
||||
job = client.get(f"/api/jobs/{queued['id']}").json()
|
||||
|
||||
assert job["status"] == "completed"
|
||||
assert job["result"]["route_patterns"] > 0
|
||||
events = client.get(f"/api/jobs/{queued['id']}/events").json()
|
||||
assert [event["event_type"] for event in events["events"]][-1] == "completed"
|
||||
|
||||
|
||||
def test_route_matching_job_endpoint_completes():
|
||||
client = TestClient(app)
|
||||
assert client.post("/api/sample/reset").status_code == 200
|
||||
|
||||
queued = client.post("/api/jobs/match-run").json()
|
||||
assert queued["kind"] == "route_matching"
|
||||
assert queued["status"] == "queued"
|
||||
|
||||
worker = run_worker_once(worker_id="test-worker")
|
||||
assert worker["processed"] == 1
|
||||
job = client.get(f"/api/jobs/{queued['id']}").json()
|
||||
|
||||
assert job["status"] == "completed"
|
||||
assert job["result"]["routes"] == 6
|
||||
assert job["result"]["matches"] > 0
|
||||
events = client.get(f"/api/jobs/{queued['id']}/events").json()
|
||||
event_types = [event["event_type"] for event in events["events"]]
|
||||
assert "route_matching_batch" in event_types
|
||||
assert event_types[-1] == "completed"
|
||||
|
||||
|
||||
def test_qa_summary_endpoint_exposes_harmonization_sections():
|
||||
client = TestClient(app)
|
||||
assert client.post("/api/sample/reset").status_code == 200
|
||||
|
||||
summary = client.get("/api/qa/summary").json()
|
||||
|
||||
assert summary["decision"]["deployment"] == "same_workbench_for_now"
|
||||
section_ids = {section["id"] for section in summary["sections"]}
|
||||
assert {
|
||||
"source_discovery",
|
||||
"import_health",
|
||||
"gtfs_validation",
|
||||
"deduplication",
|
||||
"route_quality",
|
||||
"publication_readiness",
|
||||
} <= section_ids
|
||||
gtfs_section = next(section for section in summary["sections"] if section["id"] == "gtfs_validation")
|
||||
assert any(item["label"] == "Routes" for item in gtfs_section["items"])
|
||||
|
||||
|
||||
def test_gtfs_harmonization_inventory_and_detail():
|
||||
client = TestClient(app)
|
||||
assert client.post("/api/sample/reset").status_code == 200
|
||||
|
||||
inventory = client.get("/api/harmonization/gtfs/inventory").json()
|
||||
assert inventory["summary"]["sources"] == 1
|
||||
assert inventory["summary"]["active_sources"] == 1
|
||||
feed = inventory["feeds"][0]
|
||||
assert feed["source"]["name"] == "Sample Berlin GTFS"
|
||||
assert feed["active_dataset"]["counts"]["routes"] == 6
|
||||
assert feed["validation"]["items"]
|
||||
assert feed["service"]["items"]
|
||||
|
||||
detail = client.get(f"/api/harmonization/gtfs/sources/{feed['source']['id']}").json()
|
||||
assert detail["source"]["id"] == feed["source"]["id"]
|
||||
assert {section["id"] for section in detail["sections"]} == {"validation", "service", "overlap", "license"}
|
||||
assert all({"id", "severity", "title", "detail"} <= set(issue) for issue in detail["issues"])
|
||||
assert detail["qa_status"] in {"ready", "needs_review", "blocked"}
|
||||
|
||||
reviewed = client.patch(
|
||||
f"/api/harmonization/gtfs/sources/{feed['source']['id']}/review",
|
||||
json={"license": "CC-BY-4.0", "review_status": "approved", "review_note": "Operator publication allowed.", "enabled": True},
|
||||
).json()
|
||||
assert reviewed["source"]["license"] == "CC-BY-4.0"
|
||||
assert reviewed["source"]["qa_review"]["status"] == "approved"
|
||||
assert reviewed["source"]["qa_review"]["note"] == "Operator publication allowed."
|
||||
assert reviewed["source"]["enabled"] is True
|
||||
|
||||
|
||||
def test_terminal_jobs_can_be_dismissed_from_default_view():
|
||||
client = TestClient(app)
|
||||
assert client.post("/api/sample/reset").status_code == 200
|
||||
|
||||
queued = client.post("/api/jobs/route-layer-build").json()
|
||||
assert run_worker_once(worker_id="test-worker")["processed"] == 1
|
||||
|
||||
listed = client.get("/api/jobs").json()
|
||||
assert any(job["id"] == queued["id"] for job in listed["jobs"])
|
||||
|
||||
dismissed = client.post(f"/api/jobs/{queued['id']}/dismiss").json()
|
||||
assert dismissed["dismissed_at"]
|
||||
|
||||
hidden = client.get("/api/jobs").json()
|
||||
assert all(job["id"] != queued["id"] for job in hidden["jobs"])
|
||||
|
||||
visible = client.get("/api/jobs?include_dismissed=true").json()
|
||||
assert any(job["id"] == queued["id"] for job in visible["jobs"])
|
||||
|
||||
|
||||
def test_jobs_revision_endpoint_reports_changes():
|
||||
client = TestClient(app)
|
||||
assert client.post("/api/sample/reset").status_code == 200
|
||||
|
||||
initial = client.get("/api/jobs/revision").json()
|
||||
assert initial["changed"] is True
|
||||
assert initial["revision"]
|
||||
assert initial["job_revision"]
|
||||
assert "workers" in initial
|
||||
|
||||
queued = client.post("/api/jobs/route-layer-build").json()
|
||||
changed = client.get("/api/jobs/revision", params={"since": initial["revision"]}).json()
|
||||
assert changed["changed"] is True
|
||||
assert changed["latest_job_id"] >= queued["id"]
|
||||
assert changed["job_count"] >= 1
|
||||
|
||||
unchanged = client.get("/api/jobs/revision", params={"since": changed["revision"]}).json()
|
||||
assert unchanged["changed"] is False
|
||||
|
||||
listed = client.get("/api/jobs").json()
|
||||
assert listed["revision"] == unchanged["revision"]
|
||||
assert listed["jobs"]
|
||||
|
||||
|
||||
def test_nearest_location_skips_address_lookup_while_address_index_rebuilds(monkeypatch):
|
||||
client = TestClient(app)
|
||||
assert client.post("/api/sample/reset").status_code == 200
|
||||
with session_scope() as session:
|
||||
session.add(Job(kind="address_index_rebuild", status="running", description="test address rebuild"))
|
||||
session.commit()
|
||||
|
||||
def fail_address_lookup(**_kwargs):
|
||||
raise AssertionError("address lookup should be skipped while address index rebuilds")
|
||||
|
||||
monkeypatch.setattr(main_module, "address_at_point", fail_address_lookup)
|
||||
response = client.get("/api/journey/nearest-location?lat=0&lon=0")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["selection_kind"] == "coordinate"
|
||||
assert data["address_lookup_skipped"] is True
|
||||
assert "Address index rebuild" in data["message"]
|
||||
|
||||
|
||||
def test_job_queue_controls_for_queued_job():
|
||||
client = TestClient(app)
|
||||
assert client.post("/api/sample/reset").status_code == 200
|
||||
|
||||
queued = client.post("/api/jobs/route-layer-build?priority=5").json()
|
||||
assert queued["status"] == "queued"
|
||||
assert queued["priority"] == 5
|
||||
|
||||
priority = client.post(f"/api/jobs/{queued['id']}/priority", json={"priority": 20}).json()
|
||||
assert priority["priority"] == 20
|
||||
|
||||
paused = client.post(f"/api/jobs/{queued['id']}/pause").json()
|
||||
assert paused["status"] == "paused"
|
||||
|
||||
idle_worker = run_worker_once(worker_id="test-worker")
|
||||
assert idle_worker["processed"] == 0
|
||||
|
||||
resumed = client.post(f"/api/jobs/{queued['id']}/resume").json()
|
||||
assert resumed["status"] == "queued"
|
||||
|
||||
stopped = client.post(f"/api/jobs/{queued['id']}/stop").json()
|
||||
assert stopped["status"] == "cancelled"
|
||||
|
||||
retried = client.post(f"/api/jobs/{queued['id']}/retry").json()
|
||||
assert retried["status"] == "queued"
|
||||
assert retried["error"] is None
|
||||
|
||||
|
||||
def test_worker_once_returns_idle_when_claim_is_busy(monkeypatch):
|
||||
def busy_claim(_worker_id):
|
||||
raise DatabaseWriteBusy("job:claim", {"operation": "update source"})
|
||||
|
||||
monkeypatch.setattr(jobs_module, "claim_next_job", busy_claim)
|
||||
|
||||
assert jobs_module.run_worker_once(worker_id="test-worker") == {"worker_id": "test-worker", "processed": 0}
|
||||
|
||||
|
||||
def test_running_job_can_be_stopped_while_write_lock_is_held():
|
||||
client = TestClient(app)
|
||||
assert client.post("/api/sample/reset").status_code == 200
|
||||
queued = client.post("/api/jobs/route-layer-build").json()
|
||||
|
||||
with session_scope() as session:
|
||||
job = session.get(Job, queued["id"])
|
||||
job.status = "running"
|
||||
job.lease_owner = "test-worker"
|
||||
|
||||
with database_write_lock("job:route_layer_rebuild:test"):
|
||||
response = client.post(f"/api/jobs/{queued['id']}/stop")
|
||||
|
||||
assert response.status_code == 200
|
||||
stopped = response.json()
|
||||
assert stopped["id"] == queued["id"]
|
||||
assert stopped["requested_action"] == "cancel"
|
||||
|
||||
|
||||
def test_itinerary_generation_and_leg_locking():
|
||||
client = TestClient(app)
|
||||
assert client.post("/api/sample/reset").status_code == 200
|
||||
hbf = _first_stop(client, "Hauptbahnhof")
|
||||
alex = _first_stop(client, "Alexanderplatz")
|
||||
|
||||
generated = client.post(
|
||||
"/api/itineraries/generate",
|
||||
json={
|
||||
"from_stop_id": hbf["id"],
|
||||
"to_stop_id": alex["id"],
|
||||
"departure": "08:00",
|
||||
"service_date": "2026-06-27",
|
||||
"max_transfers": 1,
|
||||
"transfer_seconds": 120,
|
||||
"limit": 2,
|
||||
},
|
||||
).json()
|
||||
|
||||
assert generated["request"]["service_date"] == "2026-06-27"
|
||||
assert any(item["family"] == "public_transport" for item in generated["itineraries"])
|
||||
assert any(item["family"] == "flight_access" for item in generated["itineraries"])
|
||||
public = next(item for item in generated["itineraries"] if item["family"] == "public_transport")
|
||||
saved = client.post(f"/api/itineraries/{public['id']}/save", json={"saved": True}).json()
|
||||
assert saved["saved"] is True
|
||||
leg_id = saved["legs"][0]["id"]
|
||||
locked = client.post(f"/api/itinerary-legs/{leg_id}/lock", json={"locked": True}).json()
|
||||
assert locked["locked"] is True
|
||||
recent = client.get("/api/itineraries?saved_only=true").json()
|
||||
assert any(item["id"] == public["id"] for item in recent["itineraries"])
|
||||
|
||||
|
||||
def test_geofabrik_catalog_source_creation(monkeypatch):
|
||||
from app import main
|
||||
from app.geofabrik import create_geofabrik_source
|
||||
|
||||
client = TestClient(app)
|
||||
assert client.post("/api/sample/reset").status_code == 200
|
||||
fake_entry = {
|
||||
"id": "berlin",
|
||||
"name": "Berlin",
|
||||
"parent": "germany",
|
||||
"country_codes": ["DE"],
|
||||
"pbf_url": "https://download.geofabrik.de/europe/germany/berlin-latest.osm.pbf",
|
||||
"updates_url": "https://download.geofabrik.de/europe/germany/berlin-updates",
|
||||
"taginfo_url": "https://taginfo.geofabrik.de/europe:germany:berlin",
|
||||
"urls": {},
|
||||
}
|
||||
|
||||
monkeypatch.setattr(main, "geofabrik_catalog", lambda q=None, limit=80: [fake_entry])
|
||||
monkeypatch.setattr("app.geofabrik.geofabrik_entry", lambda geofabrik_id: fake_entry if geofabrik_id == "berlin" else None)
|
||||
|
||||
catalog = client.get("/api/geofabrik/catalog?q=berlin").json()
|
||||
assert catalog["entries"][0]["id"] == "berlin"
|
||||
created = client.post(
|
||||
"/api/geofabrik/sources",
|
||||
json={"geofabrik_id": "berlin", "import_updates": True, "run_import": False},
|
||||
).json()
|
||||
assert created["source"]["kind"] == "osm_pbf"
|
||||
assert "berlin-latest.osm.pbf" in created["source"]["url"]
|
||||
|
||||
|
||||
def test_source_management_and_match_candidates():
|
||||
client = TestClient(app)
|
||||
assert client.post("/api/sample/reset").status_code == 200
|
||||
|
||||
stats = client.get("/api/stats").json()
|
||||
assert stats["match_summary"]["missing"] + stats["match_summary"]["weak"] >= 1
|
||||
sources = client.get("/api/sources").json()
|
||||
gtfs_source = next(source for source in sources if source["kind"] == "gtfs")
|
||||
assert gtfs_source["stats"]["routes"] == 6
|
||||
assert gtfs_source["datasets"][0]["stats"]["stop_times"] == 20
|
||||
|
||||
match = client.get("/api/matches?limit=1").json()[0]
|
||||
candidates = client.get(f"/api/matches/{match['id']}/candidates").json()
|
||||
assert candidates["route"]["id"] == match["gtfs"]["id"]
|
||||
assert candidates["route"]["geometry"]["present"] is True
|
||||
assert candidates["candidates"]
|
||||
assert "score" in candidates["candidates"][0]
|
||||
assert candidates["candidates"][0]["osm"]["geometry"]["present"] is True
|
||||
assert candidates["preview"]["type"] == "FeatureCollection"
|
||||
preview_roles = {feature["properties"]["preview_role"] for feature in candidates["preview"]["features"]}
|
||||
assert {"gtfs_route", "candidate"} <= preview_roles
|
||||
candidate_preview = next(feature for feature in candidates["preview"]["features"] if feature["properties"]["preview_role"] == "candidate")
|
||||
assert "candidate_score" in candidate_preview["properties"]
|
||||
picked = candidates["candidates"][0]
|
||||
accepted = client.post(f"/api/matches/{match['id']}/candidates/{picked['osm']['id']}/accept").json()
|
||||
assert accepted["status"] == "accepted"
|
||||
assert accepted["match"]["osm"]["osm_type"] == picked["osm"]["osm_type"]
|
||||
assert accepted["match"]["osm"]["osm_id"] == picked["osm"]["osm_id"]
|
||||
|
||||
search = client.get("/api/datasets/search?q=M5&active_only=true").json()
|
||||
assert search["gtfs_routes"]
|
||||
assert search["osm_routes"]
|
||||
m5_route = next(item for item in search["gtfs_routes"] if item["route"]["ref"] == "M5")
|
||||
assert m5_route["timetable"]["stop_times"] > 0
|
||||
assert m5_route["geometry"]["present"] is True
|
||||
feature = client.get(f"/api/datasets/search/feature.geojson?type=gtfs_route&id={m5_route['route']['id']}").json()
|
||||
assert feature["features"]
|
||||
assert feature["features"][0]["properties"]["search_result_type"] == "gtfs_route"
|
||||
|
||||
update_check = client.post(f"/api/sources/{gtfs_source['id']}/check-update").json()
|
||||
assert update_check["status"] == "checked"
|
||||
assert update_check["update_available"] is False
|
||||
update_result = client.post(f"/api/sources/{gtfs_source['id']}/update").json()
|
||||
assert update_result["status"] == "skipped"
|
||||
history = client.get(f"/api/sources/{gtfs_source['id']}/update-checks").json()
|
||||
assert history["checks"]
|
||||
|
||||
response = client.delete(f"/api/sources/{gtfs_source['id']}")
|
||||
assert response.status_code == 200
|
||||
delete_job = response.json()
|
||||
assert delete_job["kind"] == "source_delete"
|
||||
assert delete_job["status"] == "queued"
|
||||
duplicate = client.delete(f"/api/sources/{gtfs_source['id']}").json()
|
||||
assert duplicate["id"] == delete_job["id"]
|
||||
|
||||
worker = run_worker_once(worker_id="test-worker")
|
||||
assert worker["processed"] == 1
|
||||
completed = client.get(f"/api/jobs/{delete_job['id']}").json()
|
||||
assert completed["status"] == "completed"
|
||||
assert completed["result"]["delete_result"]["deleted"] is True
|
||||
stats_after_delete = client.get("/api/stats").json()
|
||||
assert stats_after_delete["gtfs_routes"] == 0
|
||||
assert stats_after_delete["osm_routes"] == 6
|
||||
|
||||
osm_source = next(source for source in client.get("/api/sources").json() if source["kind"] == "osm_geojson")
|
||||
dataset_id = osm_source["datasets"][0]["id"]
|
||||
dataset_delete_job = client.delete(f"/api/datasets/{dataset_id}").json()
|
||||
assert dataset_delete_job["kind"] == "dataset_delete"
|
||||
assert dataset_delete_job["status"] == "queued"
|
||||
queued_source = next(source for source in client.get("/api/sources").json() if source["id"] == osm_source["id"])
|
||||
assert queued_source["datasets"][0]["active_job"]["id"] == dataset_delete_job["id"]
|
||||
assert queued_source["active_job"]["id"] == dataset_delete_job["id"]
|
||||
|
||||
assert run_worker_once(worker_id="test-worker")["processed"] == 1
|
||||
completed_dataset_delete = client.get(f"/api/jobs/{dataset_delete_job['id']}").json()
|
||||
assert completed_dataset_delete["status"] == "completed"
|
||||
assert completed_dataset_delete["result"]["delete_result"]["deleted"] is True
|
||||
assert client.get("/api/stats").json()["osm_routes"] == 0
|
||||
|
||||
|
||||
def test_missing_gtfs_sidecar_queues_recovery_without_breaking_sources():
|
||||
client = TestClient(app)
|
||||
assert client.post("/api/sample/reset").status_code == 200
|
||||
|
||||
with session_scope() as session:
|
||||
dataset = session.scalar(select(Dataset).where(Dataset.kind == "gtfs", Dataset.is_active.is_(True)))
|
||||
assert dataset is not None
|
||||
source_id = dataset.source_id
|
||||
metadata = json.loads(dataset.metadata_json or "{}")
|
||||
metadata["gtfs_storage"]["sidecar_path"] = str(settings.data_dir / "sidecars" / f"missing_gtfs_dataset_{dataset.id}.sqlite")
|
||||
dataset.metadata_json = json.dumps(metadata)
|
||||
dataset_id = dataset.id
|
||||
|
||||
response = client.get("/api/sources")
|
||||
|
||||
assert response.status_code == 200
|
||||
source = next(item for item in response.json() if item["id"] == source_id)
|
||||
assert source["active_job"]["kind"] == "source_import"
|
||||
assert "GTFS sidecar missing" in source["active_job"]["result"]["recovery_reason"]
|
||||
recovered_dataset = next(item for item in source["datasets"] if item["id"] == dataset_id)
|
||||
assert recovered_dataset["status"] == "missing_files"
|
||||
assert recovered_dataset["stats"]["missing_sidecar"] is True
|
||||
assert recovered_dataset["stats"]["stop_times"] == 0
|
||||
|
||||
second_response = client.get("/api/sources")
|
||||
assert second_response.status_code == 200
|
||||
with session_scope() as session:
|
||||
recovery_jobs = session.scalars(select(Job).where(Job.kind == "source_import", Job.status == "queued")).all()
|
||||
assert len(recovery_jobs) == 1
|
||||
|
||||
|
||||
def test_admin_maintenance_endpoints_are_guarded_and_callable():
|
||||
client = TestClient(app)
|
||||
assert client.post("/api/sample/reset").status_code == 200
|
||||
|
||||
init_job = client.post("/api/admin/init-db").json()
|
||||
assert init_job["kind"] == "maintenance"
|
||||
assert init_job["result"]["action"] == "init-db"
|
||||
assert run_worker_once(worker_id="test-worker")["processed"] == 1
|
||||
init_completed = client.get(f"/api/jobs/{init_job['id']}").json()
|
||||
assert init_completed["status"] == "completed"
|
||||
assert init_completed["result"]["result"]["status"] == "initialized"
|
||||
|
||||
backfill_job = client.post("/api/admin/backfill-gtfs-shapes", json={}).json()
|
||||
assert backfill_job["kind"] == "maintenance"
|
||||
assert run_worker_once(worker_id="test-worker")["processed"] == 1
|
||||
backfill = client.get(f"/api/jobs/{backfill_job['id']}").json()
|
||||
assert "datasets" in backfill["result"]["result"]
|
||||
|
||||
prune_cache_job = client.post("/api/admin/prune-cache", json={}).json()
|
||||
assert prune_cache_job["kind"] == "maintenance"
|
||||
assert run_worker_once(worker_id="test-worker")["processed"] == 1
|
||||
prune_cache = client.get(f"/api/jobs/{prune_cache_job['id']}").json()["result"]["result"]
|
||||
assert prune_cache["dry_run"] is True
|
||||
assert "files" in prune_cache
|
||||
assert "bytes" in prune_cache
|
||||
|
||||
prune_inactive_job = client.post("/api/admin/prune-inactive-datasets", json={}).json()
|
||||
assert prune_inactive_job["kind"] == "maintenance"
|
||||
assert run_worker_once(worker_id="test-worker")["processed"] == 1
|
||||
prune_inactive = client.get(f"/api/jobs/{prune_inactive_job['id']}").json()["result"]["result"]
|
||||
assert prune_inactive["dry_run"] is True
|
||||
assert "would_delete" in prune_inactive
|
||||
|
||||
sample_job = client.post("/api/jobs/sample-reset").json()
|
||||
assert sample_job["kind"] == "maintenance"
|
||||
assert sample_job["result"]["action"] == "sample-reset"
|
||||
assert run_worker_once(worker_id="test-worker")["processed"] == 1
|
||||
sample_completed = client.get(f"/api/jobs/{sample_job['id']}").json()
|
||||
assert sample_completed["status"] == "completed"
|
||||
assert sample_completed["result"]["result"]["status"] == "ok"
|
||||
assert client.get("/api/stats").json()["gtfs_routes"] == 6
|
||||
|
||||
assert client.post("/api/admin/prune-cache", json={"dry_run": False}).status_code == 400
|
||||
assert client.post("/api/admin/prune-inactive-datasets", json={"dry_run": False}).status_code == 400
|
||||
assert client.post("/api/admin/vacuum-db", json={}).status_code == 400
|
||||
assert client.post("/api/admin/reset-db", json={}).status_code == 400
|
||||
|
||||
|
||||
def test_source_catalog_import_and_ingestable_seed_metadata():
|
||||
init_db()
|
||||
client = TestClient(app)
|
||||
|
||||
catalog_import = client.post("/api/source-catalog/import").json()
|
||||
assert catalog_import["summary"]["catalog_entries"] >= 50
|
||||
|
||||
catalog = client.get("/api/source-catalog?country=DE&priority=P0&limit=10").json()
|
||||
assert catalog["entries"]
|
||||
assert any("DELFI" in entry["source_name"] for entry in catalog["entries"])
|
||||
assert "geometry_notes" in catalog["entries"][0]
|
||||
|
||||
osm_catalog = client.get("/api/source-catalog?q=Geofabrik&limit=5").json()
|
||||
osm_entry = next(entry for entry in osm_catalog["entries"] if "Geofabrik" in entry["source_name"])
|
||||
created_source = client.post(
|
||||
"/api/sources",
|
||||
json={
|
||||
"catalog_entry_id": osm_entry["id"],
|
||||
"name": "Berlin Geofabrik OSM PBF",
|
||||
"kind": "osm_pbf",
|
||||
"url": "https://download.geofabrik.de/europe/germany/berlin-latest.osm.pbf",
|
||||
"country": "DE",
|
||||
},
|
||||
).json()
|
||||
sources = client.get("/api/sources").json()
|
||||
linked_source = next(source for source in sources if source["id"] == created_source["id"])
|
||||
assert linked_source["catalog_entry_id"] == osm_entry["id"]
|
||||
assert linked_source["priority"] == osm_entry["priority"]
|
||||
linked_catalog = client.get("/api/source-catalog?q=Geofabrik&limit=5").json()
|
||||
linked_entry = next(entry for entry in linked_catalog["entries"] if entry["id"] == osm_entry["id"])
|
||||
assert linked_entry["linked_source_count"] == 1
|
||||
|
||||
seed_import = client.post("/api/source-catalog/import-ingestable").json()
|
||||
assert seed_import["created"] + seed_import["updated"] >= 10
|
||||
|
||||
sources = client.get("/api/sources").json()
|
||||
swiss = next(source for source in sources if source["name"] == "CH Swiss national GTFS")
|
||||
assert swiss["kind"] == "gtfs"
|
||||
assert swiss["priority"] == "P0"
|
||||
assert "rail" in swiss["mode_scope"]
|
||||
assert swiss["notes"]
|
||||
vbb = next(source for source in sources if source["name"] == "VBB Berlin-Brandenburg GTFS")
|
||||
vbb_catalog = next(entry for entry in client.get("/api/source-catalog?q=VBB&limit=5").json()["entries"] if entry["source_name"] == "VBB Berlin-Brandenburg GTFS")
|
||||
assert vbb["kind"] == "gtfs"
|
||||
assert vbb["priority"] == "P5"
|
||||
assert vbb["catalog_entry_id"] == vbb_catalog["id"]
|
||||
|
||||
|
||||
def test_ingestable_source_import_deduplicates_by_kind_and_url(tmp_path):
|
||||
init_db()
|
||||
first = tmp_path / "first.csv"
|
||||
first.write_text(
|
||||
"name,kind,url,country,license,mode_scope,source_basis,priority,notes\n"
|
||||
"Original GTFS,gtfs,https://example.test/feed.zip,DE,CC0,bus,test,P1,first\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
second = tmp_path / "second.csv"
|
||||
second.write_text(
|
||||
"name,kind,url,country,license,mode_scope,source_basis,priority,notes\n"
|
||||
"Renamed GTFS,gtfs,https://example.test/feed.zip,DE,CC0,bus,test,P0,second\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
with session_scope() as session:
|
||||
assert import_ingestable_sources(session, first)["created"] == 1
|
||||
with session_scope() as session:
|
||||
result = import_ingestable_sources(session, second)
|
||||
assert result["created"] == 0
|
||||
assert result["updated"] == 1
|
||||
sources = session.scalars(select(Source).where(Source.url == "https://example.test/feed.zip")).all()
|
||||
assert len(sources) == 1
|
||||
assert sources[0].name == "Renamed GTFS"
|
||||
assert sources[0].priority == "P0"
|
||||
|
||||
|
||||
def test_write_endpoint_returns_busy_when_another_write_is_active():
|
||||
init_db()
|
||||
client = TestClient(app)
|
||||
previous_timeout = settings.database_write_lock_timeout_seconds
|
||||
settings.database_write_lock_timeout_seconds = 0.05
|
||||
try:
|
||||
with database_write_lock("test long write", timeout=0.1):
|
||||
response = client.post(
|
||||
"/api/sources",
|
||||
json={"name": "Busy test source", "kind": "gtfs", "url": "https://example.invalid/feed.zip"},
|
||||
)
|
||||
finally:
|
||||
settings.database_write_lock_timeout_seconds = previous_timeout
|
||||
|
||||
assert response.status_code == 409
|
||||
assert "Database is busy" in response.json()["detail"]
|
||||
|
||||
|
||||
def test_manual_match_rule_survives_new_gtfs_dataset_row():
|
||||
client = TestClient(app)
|
||||
assert client.post("/api/sample/reset").status_code == 200
|
||||
|
||||
match = next(item for item in client.get("/api/matches?status=matched").json() if item["osm"])
|
||||
accepted = client.post(f"/api/matches/{match['id']}/accept").json()
|
||||
assert accepted["status"] == "accepted"
|
||||
|
||||
with session_scope() as session:
|
||||
old_route = session.get(GtfsRoute, match["gtfs"]["id"])
|
||||
assert old_route is not None
|
||||
old_dataset = session.get(Dataset, old_route.dataset_id)
|
||||
assert old_dataset is not None
|
||||
old_dataset.is_active = False
|
||||
replacement_dataset = Dataset(
|
||||
source_id=old_dataset.source_id,
|
||||
kind="gtfs",
|
||||
local_path="./data/replacement.gtfs.zip",
|
||||
sha256="replacement",
|
||||
is_active=True,
|
||||
status="imported",
|
||||
)
|
||||
session.add(replacement_dataset)
|
||||
session.flush()
|
||||
session.add(
|
||||
GtfsRoute(
|
||||
dataset_id=replacement_dataset.id,
|
||||
route_id=old_route.route_id,
|
||||
agency_id=old_route.agency_id,
|
||||
short_name=old_route.short_name,
|
||||
long_name=old_route.long_name,
|
||||
route_type=old_route.route_type,
|
||||
mode=old_route.mode,
|
||||
operator_name=old_route.operator_name,
|
||||
min_lon=old_route.min_lon,
|
||||
min_lat=old_route.min_lat,
|
||||
max_lon=old_route.max_lon,
|
||||
max_lat=old_route.max_lat,
|
||||
route_key=old_route.route_key,
|
||||
operator_key=old_route.operator_key,
|
||||
)
|
||||
)
|
||||
|
||||
rerun = client.post("/api/match/run").json()
|
||||
assert rerun["manual"] >= 1
|
||||
matches = client.get("/api/matches?status=accepted").json()
|
||||
assert any(item["gtfs"]["route_id"] == match["gtfs"]["route_id"] for item in matches)
|
||||
|
||||
|
||||
def _first_stop(client: TestClient, query: str) -> dict:
|
||||
response = client.get(f"/api/journey/stops?q={query}")
|
||||
assert response.status_code == 200
|
||||
stops = response.json()["stops"]
|
||||
assert stops
|
||||
return stops[0]
|
||||
148
tests/test_feed_discovery.py
Normal file
148
tests/test_feed_discovery.py
Normal file
@@ -0,0 +1,148 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import csv
|
||||
|
||||
from app import feed_discovery
|
||||
from app.feed_discovery import (
|
||||
FeedCandidate,
|
||||
build_gtfs_discovery_manifests,
|
||||
enrich_ptna_candidate_from_details,
|
||||
parse_ptna_country_page,
|
||||
parse_ptna_detail_fields,
|
||||
select_test_run_candidates,
|
||||
)
|
||||
|
||||
|
||||
def test_parse_ptna_country_and_detail_pages():
|
||||
country_html = """
|
||||
<table>
|
||||
<tr class="gtfs-tablerow">
|
||||
<td><a href="routes.php?feed=DE-BE-VBB">DE-BE-VBB</a></td>
|
||||
<td><a href="https://www.vbb.de">Verkehrsverbund Berlin-Brandenburg</a></td>
|
||||
<td><a href="https://www.vbb.de">VBB Verkehrsverbund Berlin-Brandenburg GmbH</a></td>
|
||||
<td>2026-01-01</td>
|
||||
<td>2026-12-12</td>
|
||||
<td>20260603</td>
|
||||
<td><a href="https://www.vbb.de/vbb-services/api-open-data/datensaetze/">2026-06-03</a></td>
|
||||
<td>2026-06-03</td>
|
||||
<td><a href="/en/gtfs-details.php?feed=DE-BE-VBB">Details, ...</a></td>
|
||||
</tr>
|
||||
</table>
|
||||
"""
|
||||
candidates = parse_ptna_country_page(
|
||||
country_html,
|
||||
country="DE",
|
||||
page_url="https://ptna.openstreetmap.de/gtfs/DE/index.php",
|
||||
)
|
||||
|
||||
assert len(candidates) == 1
|
||||
candidate = candidates[0]
|
||||
assert candidate.ptna_feed_id == "DE-BE-VBB"
|
||||
assert candidate.country == "DE"
|
||||
assert candidate.original_release_url == "https://www.vbb.de/vbb-services/api-open-data/datensaetze/"
|
||||
assert candidate.details_url == "https://ptna.openstreetmap.de/en/gtfs-details.php?feed=DE-BE-VBB"
|
||||
|
||||
detail_html = """
|
||||
<table>
|
||||
<tr><td>Release Url</td><td><a href="https://example.test/gtfs.zip">https://example.test/gtfs.zip</a></td></tr>
|
||||
<tr><td>Publisher's License</td><td><a href="https://example.test/license">CC BY 4.0</a></td></tr>
|
||||
<tr><td>License given for use in OSM</td><td>Attribution on contributor page is sufficient.</td></tr>
|
||||
<tr><td>"network:guid"</td><td>DE-BE-VBB</td></tr>
|
||||
</table>
|
||||
"""
|
||||
fields = parse_ptna_detail_fields(detail_html, "https://ptna.openstreetmap.de/en/gtfs-details.php?feed=DE-BE-VBB")
|
||||
assert fields["publisher's license"] == "CC BY 4.0"
|
||||
assert fields["publisher's license href"] == "https://example.test/license"
|
||||
|
||||
enrich_ptna_candidate_from_details(candidate, detail_html, candidate.details_url)
|
||||
assert candidate.selected_url == "https://example.test/gtfs.zip"
|
||||
assert candidate.license_text == "CC BY 4.0"
|
||||
assert "network:guid=DE-BE-VBB" in candidate.notes
|
||||
|
||||
|
||||
def test_build_gtfs_discovery_manifests_from_stubbed_sources(tmp_path, monkeypatch):
|
||||
mobility = [
|
||||
FeedCandidate(
|
||||
discovery_source="mobility_database",
|
||||
country="DE",
|
||||
provider="Rhein-Neckar-Verkehr",
|
||||
feed_name="RNV",
|
||||
stable_id="mdb-rnv",
|
||||
status="active",
|
||||
is_official="True",
|
||||
selected_url="https://example.test/rnv.zip",
|
||||
direct_download_url="https://example.test/rnv.zip",
|
||||
license_url="https://example.test/license",
|
||||
features="Shapes|Feed Information",
|
||||
priority="P0",
|
||||
)
|
||||
]
|
||||
ptna = [
|
||||
FeedCandidate(
|
||||
discovery_source="ptna",
|
||||
country="DE",
|
||||
provider="Rhein-Neckar-Verkehr",
|
||||
feed_name="RNV",
|
||||
ptna_feed_id="DE-BW-RNV",
|
||||
selected_url="https://example.test/rnv.zip",
|
||||
original_release_url="https://example.test/rnv.zip",
|
||||
license_text="CC BY 4.0",
|
||||
priority="P2",
|
||||
)
|
||||
]
|
||||
curated = [
|
||||
FeedCandidate(
|
||||
discovery_source="curated_seed",
|
||||
country="CH",
|
||||
provider="Swiss national",
|
||||
feed_name="CH Swiss national GTFS",
|
||||
selected_url="https://example.test/ch.zip",
|
||||
license_text="verify",
|
||||
features="rail,bus",
|
||||
priority="P0",
|
||||
)
|
||||
]
|
||||
monkeypatch.setattr(feed_discovery, "fetch_mobility_database_candidates", lambda **_: mobility)
|
||||
monkeypatch.setattr(feed_discovery, "fetch_mobility_acceptance_candidates", lambda **_: [])
|
||||
monkeypatch.setattr(feed_discovery, "fetch_ptna_candidates", lambda **_: ptna)
|
||||
monkeypatch.setattr(feed_discovery, "load_curated_ingestable_seed", lambda **_: curated)
|
||||
|
||||
report = build_gtfs_discovery_manifests(output_dir=tmp_path, countries=["DE", "CH"], test_limit=10)
|
||||
|
||||
assert report["counts"]["candidates"] == 2
|
||||
assert report["counts"]["ingestable"] == 2
|
||||
ingestable_rows = list(csv.DictReader((tmp_path / "gtfs_ingestable_sources.csv").open(encoding="utf-8")))
|
||||
assert {row["url"] for row in ingestable_rows} == {"https://example.test/rnv.zip", "https://example.test/ch.zip"}
|
||||
assert "ptna" in next(row for row in ingestable_rows if row["url"] == "https://example.test/rnv.zip")["source_basis"]
|
||||
|
||||
|
||||
def test_select_test_run_candidates_keeps_overlapping_german_feeds():
|
||||
candidates = [
|
||||
FeedCandidate(
|
||||
discovery_source="curated_seed",
|
||||
country="DE",
|
||||
provider="DB Long-distance Rail GTFS.DE",
|
||||
selected_url="https://download.gtfs.de/germany/fv_free/latest.zip",
|
||||
priority="P1",
|
||||
),
|
||||
FeedCandidate(
|
||||
discovery_source="mobility_database",
|
||||
country="DE",
|
||||
provider="Rhein-Neckar-Verkehr",
|
||||
selected_url="https://gtfs-sandbox-dds.rnv-online.de/latest/gtfs.zip",
|
||||
priority="P0",
|
||||
),
|
||||
FeedCandidate(
|
||||
discovery_source="curated_seed",
|
||||
country="CH",
|
||||
provider="Swiss national",
|
||||
selected_url="https://gtfs.geops.ch/dl/gtfs_complete.zip",
|
||||
priority="P0",
|
||||
),
|
||||
]
|
||||
|
||||
selected = select_test_run_candidates(candidates, limit=3)
|
||||
|
||||
assert len(selected) == 3
|
||||
assert any("gtfs.de" in candidate.selected_url for candidate in selected)
|
||||
assert any("rnv" in candidate.selected_url for candidate in selected)
|
||||
72
tests/test_gtfs_import.py
Normal file
72
tests/test_gtfs_import.py
Normal file
@@ -0,0 +1,72 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import zipfile
|
||||
|
||||
from sqlalchemy import func, select
|
||||
|
||||
from app.db import reset_db, session_scope
|
||||
from app.gtfs_storage import sidecar_path, stop_time_count, stop_times_by_trip
|
||||
from app.journey import find_journeys, search_scheduled_stops
|
||||
from app.models import Dataset, GtfsCalendar, Source
|
||||
from app.pipeline.run import run_source
|
||||
|
||||
|
||||
def test_gtfs_import_uses_staging_bulk_loader_and_reports_chunks(tmp_path, monkeypatch):
|
||||
reset_db()
|
||||
gtfs_path = tmp_path / "small.gtfs.zip"
|
||||
with zipfile.ZipFile(gtfs_path, "w") as zf:
|
||||
zf.writestr("agency.txt", "agency_id,agency_name,agency_url,agency_timezone\nA,Agency,https://example.invalid,Europe/Berlin\n")
|
||||
zf.writestr(
|
||||
"stops.txt",
|
||||
"stop_id,stop_name,stop_lat,stop_lon\nA,Alpha,52.0,13.0\nB,Beta,52.1,13.1\nC,Gamma,52.2,13.2\n",
|
||||
)
|
||||
zf.writestr("routes.txt", "route_id,agency_id,route_short_name,route_long_name,route_type\nR,A,R1,Alpha - Gamma,3\n")
|
||||
zf.writestr("trips.txt", "route_id,service_id,trip_id,shape_id\nR,daily,t1,s1\nR,daily,t2,s1\n")
|
||||
zf.writestr("calendar.txt", "service_id,monday,tuesday,wednesday,thursday,friday,saturday,sunday,start_date,end_date\ndaily,1,1,1,1,1,1,1,20260101,20261231\n")
|
||||
zf.writestr(
|
||||
"stop_times.txt",
|
||||
"\n".join(
|
||||
[
|
||||
"trip_id,arrival_time,departure_time,stop_id,stop_sequence",
|
||||
"t1,08:00:00,08:00:00,A,1",
|
||||
"t1,08:05:00,08:05:00,B,2",
|
||||
"t1,08:10:00,08:10:00,C,3",
|
||||
"t2,09:00:00,09:00:00,A,1",
|
||||
"t2,09:10:00,09:10:00,C,2",
|
||||
]
|
||||
)
|
||||
+ "\n",
|
||||
)
|
||||
zf.writestr("shapes.txt", "shape_id,shape_pt_lat,shape_pt_lon,shape_pt_sequence\ns1,52.0,13.0,1\ns1,52.2,13.2,2\n")
|
||||
|
||||
monkeypatch.setattr("app.pipeline.gtfs.GTFS_STAGE_BATCH_SIZE", 2)
|
||||
events = []
|
||||
with session_scope() as session:
|
||||
source = Source(name="Small GTFS", kind="gtfs", url=str(gtfs_path))
|
||||
session.add(source)
|
||||
session.flush()
|
||||
dataset = run_source(session, source, progress_callback=lambda *args: events.append(args))
|
||||
|
||||
metadata = json.loads(dataset.metadata_json or "{}")
|
||||
assert metadata["importer"] == "gtfs_import_v6_sidecar_stop_times"
|
||||
assert metadata["staging"] == "sqlite_promoted_to_sidecar"
|
||||
assert metadata["gtfs_storage"]["tables"]["gtfs_stop_times"] == "sidecar"
|
||||
assert metadata["stop_times_imported"] == 5
|
||||
assert sidecar_path(dataset) is not None
|
||||
assert sidecar_path(dataset).exists()
|
||||
assert stop_time_count(session, dataset.id) == 5
|
||||
assert len(stop_times_by_trip(session, dataset.id, ["t1"])["t1"]) == 3
|
||||
assert session.scalar(select(func.count()).select_from(GtfsCalendar).where(GtfsCalendar.dataset_id == dataset.id)) == 1
|
||||
assert session.scalar(select(func.count()).select_from(Dataset).where(Dataset.kind == "gtfs", Dataset.is_active.is_(True))) == 1
|
||||
alpha = search_scheduled_stops(session, "Alpha", limit=1)[0]
|
||||
gamma = search_scheduled_stops(session, "Gamma", limit=1)[0]
|
||||
journey = find_journeys(session, alpha["id"], gamma["id"], "08:00", limit=1)
|
||||
assert journey["journeys"][0]["departure_time"] == "08:00:00"
|
||||
assert journey["journeys"][0]["arrival_time"] == "08:10:00"
|
||||
|
||||
event_types = [event[0] for event in events]
|
||||
assert "gtfs_staging_started" in event_types
|
||||
assert "gtfs_file_chunk" in event_types
|
||||
assert "gtfs_activation_sidecar_stop_times" in event_types
|
||||
assert "gtfs_activation_completed" in event_types
|
||||
282
tests/test_osm_pbf.py
Normal file
282
tests/test_osm_pbf.py
Normal file
@@ -0,0 +1,282 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sqlite3
|
||||
|
||||
from sqlalchemy import func, select
|
||||
|
||||
from app.config import settings
|
||||
from app.db import reset_db, session_scope
|
||||
from app.models import Dataset, PipelineRun, Source
|
||||
from app.osm_storage import features_are_sidecar, osm_feature_count, query_osm_features, sidecar_path
|
||||
from app.pipeline.osm_labeling import relabel_osm_features
|
||||
from app.pipeline.run import run_source
|
||||
|
||||
|
||||
def test_osm_pbf_source_commits_raw_and_extracts_route_geometry(tmp_path):
|
||||
reset_db()
|
||||
osm_path = tmp_path / "transport.osm"
|
||||
osm_path.write_text(
|
||||
"""<?xml version="1.0" encoding="UTF-8"?>
|
||||
<osm version="0.6" generator="mobility-workbench-test">
|
||||
<node id="1" lat="52.5000" lon="13.4000" />
|
||||
<node id="2" lat="52.5010" lon="13.4100" />
|
||||
<node id="3" lat="52.5020" lon="13.4200">
|
||||
<tag k="highway" v="bus_stop"/>
|
||||
<tag k="name" v="Example Stop"/>
|
||||
</node>
|
||||
<node id="4" lat="52.5030" lon="13.4300" />
|
||||
<node id="5" lat="52.5030" lon="13.4310" />
|
||||
<node id="6" lat="52.5040" lon="13.4310" />
|
||||
<node id="7" lat="52.5040" lon="13.4300" />
|
||||
<node id="8" lat="52.5050" lon="13.4400" />
|
||||
<node id="9" lat="52.5060" lon="13.4500" />
|
||||
<way id="10">
|
||||
<nd ref="1"/>
|
||||
<nd ref="2"/>
|
||||
<nd ref="3"/>
|
||||
<tag k="highway" v="primary"/>
|
||||
</way>
|
||||
<way id="11">
|
||||
<nd ref="4"/>
|
||||
<nd ref="3"/>
|
||||
<tag k="highway" v="primary"/>
|
||||
</way>
|
||||
<way id="20">
|
||||
<nd ref="4"/>
|
||||
<nd ref="5"/>
|
||||
<nd ref="6"/>
|
||||
<nd ref="7"/>
|
||||
<nd ref="4"/>
|
||||
<tag k="aerialway" v="station"/>
|
||||
<tag k="name" v="Cable Station"/>
|
||||
</way>
|
||||
<way id="30">
|
||||
<nd ref="8"/>
|
||||
<nd ref="9"/>
|
||||
<tag k="route" v="ferry"/>
|
||||
<tag k="name" v="Ferry Waterway"/>
|
||||
</way>
|
||||
<relation id="100">
|
||||
<member type="way" ref="10" role=""/>
|
||||
<member type="way" ref="11" role=""/>
|
||||
<tag k="type" v="route"/>
|
||||
<tag k="route" v="bus"/>
|
||||
<tag k="ref" v="100"/>
|
||||
<tag k="name" v="Bus 100"/>
|
||||
<tag k="operator" v="BVG"/>
|
||||
</relation>
|
||||
</osm>
|
||||
""",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
with session_scope() as session:
|
||||
source = Source(name="Test OSM", kind="osm_pbf", url=str(osm_path), country="DE")
|
||||
session.add(source)
|
||||
session.flush()
|
||||
|
||||
dataset = run_source(session, source)
|
||||
|
||||
raw_dataset = session.scalars(select(Dataset).where(Dataset.kind == "osm_pbf_raw")).one()
|
||||
assert raw_dataset.status == "extracted"
|
||||
assert raw_dataset.is_active is False
|
||||
assert dataset.kind == "osm_geojson"
|
||||
assert dataset.is_active is True
|
||||
assert features_are_sidecar(dataset)
|
||||
assert sidecar_path(dataset) is not None
|
||||
assert sidecar_path(dataset).exists()
|
||||
|
||||
route = next(iter(query_osm_features(session, [dataset.id], kinds=["route"], search="100")), None)
|
||||
assert route is not None
|
||||
assert route.osm_type == "relation"
|
||||
assert route.mode == "bus"
|
||||
assert json.loads(route.geometry_geojson or "{}") == {
|
||||
"type": "LineString",
|
||||
"coordinates": [[13.4, 52.5], [13.41, 52.501], [13.42, 52.502], [13.43, 52.503]],
|
||||
}
|
||||
|
||||
stop = next(iter(query_osm_features(session, [dataset.id], kinds=["stop"], search="Example Stop")), None)
|
||||
assert stop is not None
|
||||
|
||||
cable_station = next(iter(query_osm_features(session, [dataset.id], kinds=["station"], search="Cable Station")), None)
|
||||
assert cable_station is not None
|
||||
|
||||
ferry_infra = next(iter(query_osm_features(session, [dataset.id], kinds=["infra"], search="Ferry Waterway")), None)
|
||||
assert ferry_infra is not None
|
||||
assert ferry_infra.mode == "ferry"
|
||||
|
||||
second_dataset = run_source(session, source)
|
||||
assert second_dataset.id == dataset.id
|
||||
assert session.scalar(select(func.count()).select_from(Dataset).where(Dataset.kind == "osm_pbf_raw")) == 1
|
||||
|
||||
|
||||
def test_osm_pbf_source_reuses_raw_and_filtered_transport_dataset(tmp_path):
|
||||
reset_db()
|
||||
osm_path = tmp_path / "transport.osm"
|
||||
osm_path.write_text(
|
||||
"""<?xml version="1.0" encoding="UTF-8"?>
|
||||
<osm version="0.6" generator="mobility-workbench-test">
|
||||
<node id="1" lat="52.5000" lon="13.4000" />
|
||||
<node id="2" lat="52.5010" lon="13.4100" />
|
||||
<way id="10">
|
||||
<nd ref="1"/>
|
||||
<nd ref="2"/>
|
||||
<tag k="highway" v="primary"/>
|
||||
</way>
|
||||
<relation id="100">
|
||||
<member type="way" ref="10" role=""/>
|
||||
<tag k="type" v="route"/>
|
||||
<tag k="route" v="bus"/>
|
||||
<tag k="ref" v="100"/>
|
||||
</relation>
|
||||
</osm>
|
||||
""",
|
||||
encoding="utf-8",
|
||||
)
|
||||
filter_script = tmp_path / "copy_filter.sh"
|
||||
filter_script.write_text("#!/usr/bin/env sh\nset -eu\ncp \"$1\" \"$2\"\n", encoding="utf-8")
|
||||
filter_script.chmod(0o755)
|
||||
|
||||
old_enabled = settings.osm_pbf_prefilter_enabled
|
||||
old_formats = settings.osm_pbf_prefilter_formats
|
||||
old_script = settings.osm_pbf_prefilter_script
|
||||
settings.osm_pbf_prefilter_enabled = True
|
||||
settings.osm_pbf_prefilter_formats = "osm_xml"
|
||||
settings.osm_pbf_prefilter_script = filter_script
|
||||
try:
|
||||
with session_scope() as session:
|
||||
source = Source(name="Filtered OSM", kind="osm_pbf", url=str(osm_path), country="DE")
|
||||
session.add(source)
|
||||
session.flush()
|
||||
|
||||
dataset = run_source(session, source)
|
||||
|
||||
raw_dataset = session.scalars(select(Dataset).where(Dataset.kind == "osm_pbf_raw")).one()
|
||||
filtered_dataset = session.scalars(select(Dataset).where(Dataset.kind == "osm_pbf_transport")).one()
|
||||
raw_metadata = json.loads(raw_dataset.metadata_json or "{}")
|
||||
filtered_metadata = json.loads(filtered_dataset.metadata_json or "{}")
|
||||
derived_metadata = json.loads(dataset.metadata_json or "{}")
|
||||
|
||||
assert raw_dataset.status == "filtered"
|
||||
assert raw_dataset.is_active is False
|
||||
assert raw_metadata["filtered_dataset_id"] == filtered_dataset.id
|
||||
assert filtered_dataset.status == "extracted"
|
||||
assert filtered_dataset.is_active is False
|
||||
assert filtered_metadata["stage"] == "filtered_osm_transport_pbf"
|
||||
assert filtered_metadata["derived_from_dataset_id"] == raw_dataset.id
|
||||
assert filtered_metadata["filter"] == "osmium_transport_filter_v1"
|
||||
assert dataset.kind == "osm_geojson"
|
||||
assert dataset.is_active is True
|
||||
assert derived_metadata["raw_dataset_id"] == raw_dataset.id
|
||||
assert derived_metadata["filtered_dataset_id"] == filtered_dataset.id
|
||||
assert derived_metadata["derived_from_dataset_id"] == filtered_dataset.id
|
||||
|
||||
second_dataset = run_source(session, source)
|
||||
assert second_dataset.id == dataset.id
|
||||
assert session.scalar(select(func.count()).select_from(Dataset).where(Dataset.kind == "osm_pbf_raw")) == 1
|
||||
assert session.scalar(select(func.count()).select_from(Dataset).where(Dataset.kind == "osm_pbf_transport")) == 1
|
||||
finally:
|
||||
settings.osm_pbf_prefilter_enabled = old_enabled
|
||||
settings.osm_pbf_prefilter_formats = old_formats
|
||||
settings.osm_pbf_prefilter_script = old_script
|
||||
|
||||
|
||||
def test_osm_geojson_import_deduplicates_duplicate_osm_identities(tmp_path):
|
||||
reset_db()
|
||||
geojson_path = tmp_path / "duplicate-osm-identities.geojson"
|
||||
geojson_path.write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"type": "FeatureCollection",
|
||||
"features": [
|
||||
{
|
||||
"type": "Feature",
|
||||
"properties": {"osm_type": "relation", "osm_id": "100", "type": "route", "route": "bus", "ref": "100"},
|
||||
"geometry": {
|
||||
"type": "LineString",
|
||||
"coordinates": [[13.4, 52.5], [13.41, 52.501]],
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "Feature",
|
||||
"properties": {"osm_type": "relation", "osm_id": "100", "name": "Duplicate without route geometry"},
|
||||
"geometry": None,
|
||||
},
|
||||
],
|
||||
}
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
with session_scope() as session:
|
||||
source = Source(name="Duplicate OSM IDs", kind="osm_geojson", url=str(geojson_path), country="DE")
|
||||
session.add(source)
|
||||
session.flush()
|
||||
|
||||
dataset = run_source(session, source)
|
||||
|
||||
metadata = json.loads(dataset.metadata_json or "{}")
|
||||
storage = metadata["osm_storage"]
|
||||
assert dataset.status == "imported"
|
||||
assert storage["features"] == 1
|
||||
assert storage["duplicate_features_skipped"] == 1
|
||||
assert osm_feature_count(session, dataset.id) == 1
|
||||
route = query_osm_features(session, [dataset.id], kinds=["route"])[0]
|
||||
assert route.osm_type == "relation"
|
||||
assert route.osm_id == "100"
|
||||
assert route.ref == "100"
|
||||
|
||||
|
||||
def test_osm_relabel_updates_sidecar_route_scope_without_reparse(tmp_path):
|
||||
reset_db()
|
||||
geojson_path = tmp_path / "scope.geojson"
|
||||
geojson_path.write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"type": "FeatureCollection",
|
||||
"features": [
|
||||
{
|
||||
"type": "Feature",
|
||||
"properties": {
|
||||
"osm_type": "relation",
|
||||
"osm_id": "900",
|
||||
"type": "route",
|
||||
"route": "bus",
|
||||
"name": "FlixBus Berlin Hamburg",
|
||||
"ref": "N900",
|
||||
},
|
||||
"geometry": {"type": "LineString", "coordinates": [[13.4, 52.5], [10.0, 53.55]]},
|
||||
}
|
||||
],
|
||||
}
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
with session_scope() as session:
|
||||
source = Source(name="Scope OSM", kind="osm_geojson", url=str(geojson_path), country="DE")
|
||||
session.add(source)
|
||||
session.flush()
|
||||
dataset = run_source(session, source)
|
||||
path = sidecar_path(dataset)
|
||||
assert path is not None
|
||||
|
||||
with sqlite3.connect(path) as connection:
|
||||
connection.execute("UPDATE osm_features SET route_scope = 'local'")
|
||||
connection.commit()
|
||||
|
||||
stale = query_osm_features(session, [dataset.id], kinds=["route"])[0]
|
||||
assert stale.route_scope == "local"
|
||||
|
||||
result = relabel_osm_features(session, dataset_id=dataset.id, rebuild_indexes=False)
|
||||
assert result["changed"] == 1
|
||||
|
||||
relabeled = query_osm_features(session, [dataset.id], kinds=["route"])[0]
|
||||
assert relabeled.route_scope == "long_distance"
|
||||
metadata = json.loads(session.get(Dataset, dataset.id).metadata_json or "{}")
|
||||
assert metadata["label_features"]["version"] == "route_scope_v2"
|
||||
assert session.scalar(select(func.count()).select_from(PipelineRun).where(PipelineRun.stage == "label_features")) == 1
|
||||
|
||||
skipped = relabel_osm_features(session, dataset_id=dataset.id)
|
||||
assert skipped["skipped"] == 1
|
||||
92
tests/test_osm_replication.py
Normal file
92
tests/test_osm_replication.py
Normal file
@@ -0,0 +1,92 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import subprocess
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.db import reset_db, session_scope
|
||||
from app.models import Dataset, OsmDiffState, Source
|
||||
from app.pipeline.osm_pbf import _try_prepare_raw_from_diffs
|
||||
from app.pipeline.osm_replication import ReplicationState, diff_url_for_sequence, parse_replication_state_text
|
||||
|
||||
|
||||
def test_parse_replication_state_text_and_diff_url():
|
||||
state = parse_replication_state_text(
|
||||
"""
|
||||
#Sat Jun 27 21:21:03 UTC 2026
|
||||
sequenceNumber=1234
|
||||
timestamp=2026-06-27T21\\:21\\:02Z
|
||||
"""
|
||||
)
|
||||
|
||||
assert state.sequence_number == 1234
|
||||
assert state.timestamp == "2026-06-27T21:21:02Z"
|
||||
assert diff_url_for_sequence("https://download.geofabrik.de/europe/germany/berlin-updates", 1234).endswith(
|
||||
"/000/001/234.osc.gz"
|
||||
)
|
||||
|
||||
|
||||
def test_osm_diff_application_records_new_raw_dataset_and_state(tmp_path, monkeypatch):
|
||||
reset_db()
|
||||
base_path = tmp_path / "base.osm.pbf"
|
||||
base_path.write_bytes(b"base")
|
||||
diff_paths = []
|
||||
|
||||
def fake_fetch(_updates_url, timeout=30):
|
||||
return ReplicationState(sequence_number=3, timestamp="2026-06-27T21:21:02Z", raw={"sequenceNumber": "3"})
|
||||
|
||||
def fake_download(_updates_url, sequence_number, output_dir, timeout=120):
|
||||
path = output_dir / f"{sequence_number}.osc.gz"
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
path.write_bytes(f"diff-{sequence_number}".encode())
|
||||
diff_paths.append(path)
|
||||
return path
|
||||
|
||||
def fake_apply(base, diffs, output, host_tool):
|
||||
output.write_bytes(base.read_bytes() + b"+" + b"+".join(path.read_bytes() for path in diffs))
|
||||
return subprocess.CompletedProcess(args=["osmium"], returncode=0, stdout="applied", stderr="")
|
||||
|
||||
monkeypatch.setattr("app.pipeline.osm_pbf.fetch_replication_state", fake_fetch)
|
||||
monkeypatch.setattr("app.pipeline.osm_pbf.download_diff", fake_download)
|
||||
monkeypatch.setattr("app.pipeline.osm_pbf.apply_osm_changes", fake_apply)
|
||||
|
||||
with session_scope() as session:
|
||||
source = Source(
|
||||
name="Berlin OSM",
|
||||
kind="osm_pbf",
|
||||
url="https://download.geofabrik.de/europe/germany/berlin-latest.osm.pbf",
|
||||
notes="geofabrik_id=berlin; updates_url=https://download.geofabrik.de/europe/germany/berlin-updates",
|
||||
)
|
||||
session.add(source)
|
||||
session.flush()
|
||||
base_dataset = Dataset(
|
||||
source_id=source.id,
|
||||
kind="osm_pbf_raw",
|
||||
local_path=str(base_path),
|
||||
sha256="b" * 64,
|
||||
is_active=False,
|
||||
status="committed",
|
||||
)
|
||||
session.add(base_dataset)
|
||||
session.flush()
|
||||
session.add(
|
||||
OsmDiffState(
|
||||
source_id=source.id,
|
||||
raw_dataset_id=base_dataset.id,
|
||||
updates_url="https://download.geofabrik.de/europe/germany/berlin-updates",
|
||||
sequence_number=1,
|
||||
timestamp="2026-06-26T21:21:02Z",
|
||||
status="active",
|
||||
)
|
||||
)
|
||||
session.flush()
|
||||
|
||||
new_dataset = _try_prepare_raw_from_diffs(session, source)
|
||||
|
||||
assert new_dataset is not None
|
||||
assert new_dataset.id != base_dataset.id
|
||||
assert new_dataset.kind == "osm_pbf_raw"
|
||||
assert len(diff_paths) == 2
|
||||
states = session.scalars(select(OsmDiffState).where(OsmDiffState.source_id == source.id).order_by(OsmDiffState.sequence_number)).all()
|
||||
assert [state.sequence_number for state in states] == [1, 3]
|
||||
assert [state.status for state in states] == ["superseded", "active"]
|
||||
317
tests/test_pipeline.py
Normal file
317
tests/test_pipeline.py
Normal file
@@ -0,0 +1,317 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
from shapely.geometry import LineString
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.db import init_db, session_scope
|
||||
from app.models import GtfsRoute, OsmFeature, RouteMatch, RoutePattern
|
||||
from app.osm_classification import infer_osm_route_scope
|
||||
from app.pipeline.gtfs import _gtfs_mode
|
||||
from app.pipeline.matcher import _build_osm_route_index, _candidate_osm_routes, route_match_scope, run_route_matching, score_route_pair
|
||||
from app.pipeline.osm_addresses import _address_area_geometry_geojson
|
||||
from app.pipeline.sample_data import load_sample_project
|
||||
from app.pipeline.route_layer import rebuild_route_layer
|
||||
from app.pipeline.utils import geometry_json_and_bbox
|
||||
|
||||
|
||||
def test_sample_pipeline_imports_and_matches():
|
||||
init_db()
|
||||
with session_scope() as session:
|
||||
result = load_sample_project(session)
|
||||
assert result["match_result"]["routes"] == 6
|
||||
assert session.scalar(select(GtfsRoute).where(GtfsRoute.short_name == "RE1")) is not None
|
||||
assert session.scalar(select(OsmFeature).where(OsmFeature.ref == "RE1")) is not None
|
||||
statuses = {row[0]: row[1] for row in session.execute(select(RouteMatch.status, RouteMatch.confidence))}
|
||||
assert "matched" in statuses
|
||||
assert set(statuses) & {"weak", "missing"}
|
||||
|
||||
|
||||
def test_route_matching_preserves_unchanged_match_rows():
|
||||
init_db()
|
||||
with session_scope() as session:
|
||||
load_sample_project(session)
|
||||
before = {
|
||||
route_id: (match_id, updated_at)
|
||||
for route_id, match_id, updated_at in session.execute(
|
||||
select(RouteMatch.gtfs_route_id, RouteMatch.id, RouteMatch.updated_at)
|
||||
)
|
||||
}
|
||||
|
||||
result = run_route_matching(session)
|
||||
|
||||
after = {
|
||||
route_id: (match_id, updated_at)
|
||||
for route_id, match_id, updated_at in session.execute(
|
||||
select(RouteMatch.gtfs_route_id, RouteMatch.id, RouteMatch.updated_at)
|
||||
)
|
||||
}
|
||||
assert result["unchanged"] == result["routes"]
|
||||
assert result["created"] == 0
|
||||
assert result["updated"] == 0
|
||||
assert after == before
|
||||
|
||||
|
||||
def test_route_layer_reuses_unchanged_route_patterns():
|
||||
init_db()
|
||||
with session_scope() as session:
|
||||
load_sample_project(session)
|
||||
before = {
|
||||
pattern_key: pattern_id
|
||||
for pattern_key, pattern_id in session.execute(select(RoutePattern.pattern_key, RoutePattern.id))
|
||||
}
|
||||
|
||||
result = rebuild_route_layer(session)
|
||||
|
||||
after = {
|
||||
pattern_key: pattern_id
|
||||
for pattern_key, pattern_id in session.execute(select(RoutePattern.pattern_key, RoutePattern.id))
|
||||
}
|
||||
assert result["route_patterns_created"] == 0
|
||||
assert result["route_patterns_removed"] == 0
|
||||
assert result["route_patterns_reused"] == result["route_patterns"]
|
||||
assert after == before
|
||||
|
||||
|
||||
def test_extended_gtfs_route_types_are_mapped_to_modes():
|
||||
assert _gtfs_mode(700) == "bus"
|
||||
assert _gtfs_mode(100) == "train"
|
||||
assert _gtfs_mode(109) == "train"
|
||||
assert _gtfs_mode(900) == "tram"
|
||||
assert _gtfs_mode(1000) == "ferry"
|
||||
|
||||
|
||||
def test_closed_address_way_is_stored_as_polygon_geometry():
|
||||
geometry = _address_area_geometry_geojson(
|
||||
[
|
||||
(8.68590, 49.40435),
|
||||
(8.68600, 49.40435),
|
||||
(8.68600, 49.40445),
|
||||
(8.68590, 49.40445),
|
||||
(8.68590, 49.40435),
|
||||
]
|
||||
)
|
||||
|
||||
assert json.loads(geometry or "{}") == {
|
||||
"type": "Polygon",
|
||||
"coordinates": [
|
||||
[
|
||||
[8.6859, 49.40435],
|
||||
[8.686, 49.40435],
|
||||
[8.686, 49.40445],
|
||||
[8.6859, 49.40445],
|
||||
[8.6859, 49.40435],
|
||||
]
|
||||
],
|
||||
}
|
||||
geometry = _address_area_geometry_geojson(
|
||||
[
|
||||
(8.68590, 49.40435),
|
||||
(8.68600, 49.40435),
|
||||
(8.68600, 49.40445),
|
||||
(8.68590, 49.40445),
|
||||
],
|
||||
closed=True,
|
||||
)
|
||||
assert json.loads(geometry or "{}")["coordinates"][0][-1] == [8.6859, 49.40435]
|
||||
assert _address_area_geometry_geojson([(0, 0), (1, 0), (1, 1)]) is None
|
||||
assert _address_area_geometry_geojson([(0, 0), (1, 0), (1, 1)], closed=False) is None
|
||||
|
||||
|
||||
def test_osm_route_scope_classifier_distinguishes_train_service_classes():
|
||||
assert infer_osm_route_scope(mode="train", ref="ICE 28") == "long_distance"
|
||||
assert infer_osm_route_scope(mode="train", ref="RE1") == "regional"
|
||||
assert infer_osm_route_scope(mode="train", ref="S5", network="S-Bahn Berlin") == "local"
|
||||
assert infer_osm_route_scope(mode="subway", ref="U5") == "local"
|
||||
assert infer_osm_route_scope(mode="coach", ref="FLX") == "long_distance"
|
||||
assert infer_osm_route_scope(mode="bus", ref="100") == "local"
|
||||
assert infer_osm_route_scope(mode="bus", ref="800", tags={"bus": "regional"}) == "regional"
|
||||
assert infer_osm_route_scope(mode="bus", name="FlixBus Berlin Hamburg") == "long_distance"
|
||||
|
||||
|
||||
def test_exact_line_ref_with_overlapping_geometry_scores_as_match_candidate():
|
||||
route = GtfsRoute(
|
||||
route_id="17441_700",
|
||||
short_name="M11",
|
||||
long_name=None,
|
||||
mode="bus",
|
||||
operator_name="Verkehrsverbund Berlin-Brandenburg",
|
||||
min_lon=13.29,
|
||||
min_lat=52.42,
|
||||
max_lon=13.33,
|
||||
max_lat=52.45,
|
||||
route_key="m11",
|
||||
)
|
||||
feature = OsmFeature(
|
||||
osm_type="relation",
|
||||
osm_id="123",
|
||||
kind="route",
|
||||
mode="bus",
|
||||
ref="M11",
|
||||
name="Bus M11: U Dahlem-Dorf => S Schöneweide/Sterndamm",
|
||||
operator="Berliner Verkehrsbetriebe",
|
||||
network="Verkehrsverbund Berlin-Brandenburg",
|
||||
min_lon=13.28,
|
||||
min_lat=52.40,
|
||||
max_lon=13.51,
|
||||
max_lat=52.46,
|
||||
route_key="m11",
|
||||
)
|
||||
|
||||
score, reasons = score_route_pair(route, feature)
|
||||
|
||||
assert score >= 85
|
||||
assert reasons["line_identity"] == "exact_ref_mode_bbox_overlap"
|
||||
|
||||
|
||||
def test_exact_line_ref_with_bbox_overlap_is_strong_without_name_or_operator_match():
|
||||
route = GtfsRoute(
|
||||
route_id="route-1",
|
||||
short_name="M11",
|
||||
long_name="",
|
||||
mode="bus",
|
||||
operator_name="",
|
||||
min_lon=13.29,
|
||||
min_lat=52.42,
|
||||
max_lon=13.33,
|
||||
max_lat=52.45,
|
||||
route_key="m11",
|
||||
)
|
||||
feature = OsmFeature(
|
||||
osm_type="relation",
|
||||
osm_id="456",
|
||||
kind="route",
|
||||
mode="bus",
|
||||
ref="M11",
|
||||
name="",
|
||||
operator="",
|
||||
network="",
|
||||
min_lon=13.30,
|
||||
min_lat=52.43,
|
||||
max_lon=13.35,
|
||||
max_lat=52.46,
|
||||
route_key="m11",
|
||||
)
|
||||
|
||||
score, reasons = score_route_pair(route, feature)
|
||||
|
||||
assert score >= 88
|
||||
assert reasons["strong_identity"] == "exact_ref_mode_bbox_overlap"
|
||||
|
||||
|
||||
def test_common_short_ref_candidates_are_spatially_ranked():
|
||||
route = GtfsRoute(
|
||||
route_id="bus-2-berlin",
|
||||
short_name="2",
|
||||
mode="bus",
|
||||
min_lon=13.30,
|
||||
min_lat=52.40,
|
||||
max_lon=13.40,
|
||||
max_lat=52.50,
|
||||
route_key="2",
|
||||
)
|
||||
far = OsmFeature(
|
||||
id=1,
|
||||
osm_type="relation",
|
||||
osm_id="far",
|
||||
kind="route",
|
||||
mode="bus",
|
||||
ref="2",
|
||||
min_lon=7.0,
|
||||
min_lat=50.0,
|
||||
max_lon=7.1,
|
||||
max_lat=50.1,
|
||||
route_key="2",
|
||||
)
|
||||
near = OsmFeature(
|
||||
id=2,
|
||||
osm_type="relation",
|
||||
osm_id="near",
|
||||
kind="route",
|
||||
mode="bus",
|
||||
ref="2",
|
||||
min_lon=13.31,
|
||||
min_lat=52.41,
|
||||
max_lon=13.39,
|
||||
max_lat=52.49,
|
||||
route_key="2",
|
||||
)
|
||||
|
||||
candidates = _candidate_osm_routes(route, _build_osm_route_index([far, near]))
|
||||
|
||||
assert candidates[0].osm_id == "near"
|
||||
|
||||
|
||||
def test_exact_ref_far_away_is_not_promoted_without_spatial_or_geometry_evidence():
|
||||
route = GtfsRoute(
|
||||
route_id="bus-2-berlin",
|
||||
short_name="2",
|
||||
mode="bus",
|
||||
operator_name="Example Operator",
|
||||
min_lon=13.30,
|
||||
min_lat=52.40,
|
||||
max_lon=13.40,
|
||||
max_lat=52.50,
|
||||
route_key="2",
|
||||
)
|
||||
feature = OsmFeature(
|
||||
osm_type="relation",
|
||||
osm_id="2-cologne",
|
||||
kind="route",
|
||||
mode="bus",
|
||||
ref="2",
|
||||
operator="Example Operator",
|
||||
min_lon=6.9,
|
||||
min_lat=50.9,
|
||||
max_lon=7.1,
|
||||
max_lat=51.0,
|
||||
route_key="2",
|
||||
)
|
||||
|
||||
score, reasons = score_route_pair(route, feature)
|
||||
|
||||
assert score < 65
|
||||
assert reasons["spatial_penalty"] == "exact_ref_far_bbox_center"
|
||||
assert reasons["spatial_cap"] == "exact_ref_far_without_geometry_overlap"
|
||||
|
||||
|
||||
def test_geometry_overlap_can_confirm_exact_ref_match():
|
||||
gtfs_geometry, gtfs_bbox = geometry_json_and_bbox(LineString([(13.30, 52.40), (13.35, 52.45), (13.40, 52.50)]))
|
||||
osm_geometry, osm_bbox = geometry_json_and_bbox(LineString([(13.3005, 52.4005), (13.3505, 52.4505), (13.4005, 52.5005)]))
|
||||
route = GtfsRoute(
|
||||
route_id="bus-2-berlin",
|
||||
short_name="2",
|
||||
mode="bus",
|
||||
min_lon=gtfs_bbox[0],
|
||||
min_lat=gtfs_bbox[1],
|
||||
max_lon=gtfs_bbox[2],
|
||||
max_lat=gtfs_bbox[3],
|
||||
route_key="2",
|
||||
geometry_geojson=gtfs_geometry,
|
||||
)
|
||||
feature = OsmFeature(
|
||||
osm_type="relation",
|
||||
osm_id="2-berlin",
|
||||
kind="route",
|
||||
mode="bus",
|
||||
ref="2",
|
||||
min_lon=osm_bbox[0],
|
||||
min_lat=osm_bbox[1],
|
||||
max_lon=osm_bbox[2],
|
||||
max_lat=osm_bbox[3],
|
||||
route_key="2",
|
||||
geometry_geojson=osm_geometry,
|
||||
)
|
||||
|
||||
score, reasons = score_route_pair(route, feature)
|
||||
|
||||
assert score >= 90
|
||||
assert reasons["strong_identity"] == "exact_ref_mode_geometry_overlap"
|
||||
assert reasons["geometry"]["gtfs_on_osm_ratio"] >= 0.9
|
||||
|
||||
|
||||
def test_route_match_scope_distinguishes_outside_loaded_osm_area():
|
||||
route = GtfsRoute(min_lon=13.3, min_lat=52.4, max_lon=13.4, max_lat=52.5)
|
||||
assert route_match_scope(route, (13.0, 52.3, 13.8, 52.7)) == "in_osm_scope"
|
||||
assert route_match_scope(route, (6.0, 50.0, 7.0, 51.0)) == "outside_osm_scope"
|
||||
1004
tests/test_route_layer.py
Normal file
1004
tests/test_route_layer.py
Normal file
File diff suppressed because it is too large
Load Diff
22
tests/test_source_updates.py
Normal file
22
tests/test_source_updates.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from app.models import Source
|
||||
from app.source_updates import _recover_missing_managed_cache_url
|
||||
|
||||
|
||||
def test_missing_managed_cache_source_recovers_seed_url_for_online_update():
|
||||
source = Source(
|
||||
id=3,
|
||||
name="Geofabrik Berlin OSM PBF",
|
||||
kind="osm_pbf",
|
||||
url="data/sources/source_3/1782478365.osm.pbf",
|
||||
country="DE",
|
||||
)
|
||||
|
||||
recovery = _recover_missing_managed_cache_url(source)
|
||||
|
||||
assert recovery == {
|
||||
"previous_url": "data/sources/source_3/1782478365.osm.pbf",
|
||||
"url": "https://download.geofabrik.de/europe/germany/berlin-latest.osm.pbf",
|
||||
}
|
||||
assert source.url == "https://download.geofabrik.de/europe/germany/berlin-latest.osm.pbf"
|
||||
Reference in New Issue
Block a user