283 lines
11 KiB
Python
283 lines
11 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
import sqlite3
|
|
|
|
from sqlalchemy import func, select
|
|
|
|
from app.config import settings
|
|
from app.db import reset_db, session_scope
|
|
from app.models import Dataset, PipelineRun, Source
|
|
from app.osm_storage import features_are_sidecar, osm_feature_count, query_osm_features, sidecar_path
|
|
from app.pipeline.osm_labeling import relabel_osm_features
|
|
from app.pipeline.run import run_source
|
|
|
|
|
|
def test_osm_pbf_source_commits_raw_and_extracts_route_geometry(tmp_path):
    """Run an OSM XML fixture through an ``osm_pbf`` source and inspect the result.

    The fixture holds a bus route relation made of two ways, a tagged bus
    stop node, a closed aerialway station way and a ferry way. The test
    checks the raw and derived dataset rows, the sidecar file, the route
    geometry, the stop/station/infra lookups, and that running the source a
    second time returns the same dataset without adding a second raw row.
    """
    reset_db()

    fixture_xml = """<?xml version="1.0" encoding="UTF-8"?>
<osm version="0.6" generator="mobility-workbench-test">
<node id="1" lat="52.5000" lon="13.4000" />
<node id="2" lat="52.5010" lon="13.4100" />
<node id="3" lat="52.5020" lon="13.4200">
<tag k="highway" v="bus_stop"/>
<tag k="name" v="Example Stop"/>
</node>
<node id="4" lat="52.5030" lon="13.4300" />
<node id="5" lat="52.5030" lon="13.4310" />
<node id="6" lat="52.5040" lon="13.4310" />
<node id="7" lat="52.5040" lon="13.4300" />
<node id="8" lat="52.5050" lon="13.4400" />
<node id="9" lat="52.5060" lon="13.4500" />
<way id="10">
<nd ref="1"/>
<nd ref="2"/>
<nd ref="3"/>
<tag k="highway" v="primary"/>
</way>
<way id="11">
<nd ref="4"/>
<nd ref="3"/>
<tag k="highway" v="primary"/>
</way>
<way id="20">
<nd ref="4"/>
<nd ref="5"/>
<nd ref="6"/>
<nd ref="7"/>
<nd ref="4"/>
<tag k="aerialway" v="station"/>
<tag k="name" v="Cable Station"/>
</way>
<way id="30">
<nd ref="8"/>
<nd ref="9"/>
<tag k="route" v="ferry"/>
<tag k="name" v="Ferry Waterway"/>
</way>
<relation id="100">
<member type="way" ref="10" role=""/>
<member type="way" ref="11" role=""/>
<tag k="type" v="route"/>
<tag k="route" v="bus"/>
<tag k="ref" v="100"/>
<tag k="name" v="Bus 100"/>
<tag k="operator" v="BVG"/>
</relation>
</osm>
"""
    fixture_file = tmp_path / "transport.osm"
    fixture_file.write_text(fixture_xml, encoding="utf-8")

    with session_scope() as session:
        osm_source = Source(name="Test OSM", kind="osm_pbf", url=str(fixture_file), country="DE")
        session.add(osm_source)
        session.flush()

        derived = run_source(session, osm_source)

        def first_feature(kind, search):
            # First feature of the given kind matching ``search``, or None.
            matches = query_osm_features(session, [derived.id], kinds=[kind], search=search)
            return next(iter(matches), None)

        # ``.one()`` doubles as an assertion that exactly one raw row exists.
        raw_rows = select(Dataset).where(Dataset.kind == "osm_pbf_raw")
        raw = session.scalars(raw_rows).one()
        assert raw.status == "extracted"
        assert raw.is_active is False
        assert derived.kind == "osm_geojson"
        assert derived.is_active is True
        assert features_are_sidecar(derived)
        assert sidecar_path(derived) is not None
        assert sidecar_path(derived).exists()

        bus_route = first_feature("route", "100")
        assert bus_route is not None
        assert bus_route.osm_type == "relation"
        assert bus_route.mode == "bus"
        # Way 11 is stored as 4 -> 3, so the expected line visits nodes 1, 2, 3, 4.
        expected_line = {
            "type": "LineString",
            "coordinates": [[13.4, 52.5], [13.41, 52.501], [13.42, 52.502], [13.43, 52.503]],
        }
        assert json.loads(bus_route.geometry_geojson or "{}") == expected_line

        bus_stop = first_feature("stop", "Example Stop")
        assert bus_stop is not None

        aerial_station = first_feature("station", "Cable Station")
        assert aerial_station is not None

        ferry_way = first_feature("infra", "Ferry Waterway")
        assert ferry_way is not None
        assert ferry_way.mode == "ferry"

        # A second run must hand back the same derived dataset and keep one raw row.
        rerun = run_source(session, osm_source)
        assert rerun.id == derived.id
        raw_total = select(func.count()).select_from(Dataset).where(Dataset.kind == "osm_pbf_raw")
        assert session.scalar(raw_total) == 1
|
|
|
|
|
|
def test_osm_pbf_source_reuses_raw_and_filtered_transport_dataset(tmp_path):
    """Check the raw -> filtered -> derived dataset chain with the prefilter enabled.

    The prefilter settings are pointed at a small shell script that copies
    its first argument to its second. After one run the test expects one
    ``osm_pbf_raw`` dataset (status ``filtered``), one ``osm_pbf_transport``
    dataset (status ``extracted``) and an active ``osm_geojson`` dataset,
    cross-linked through their metadata. A second run must return the same
    derived dataset and leave the raw and filtered row counts at one.
    """
    reset_db()
    osm_path = tmp_path / "transport.osm"
    osm_path.write_text(
        """<?xml version="1.0" encoding="UTF-8"?>
<osm version="0.6" generator="mobility-workbench-test">
<node id="1" lat="52.5000" lon="13.4000" />
<node id="2" lat="52.5010" lon="13.4100" />
<way id="10">
<nd ref="1"/>
<nd ref="2"/>
<tag k="highway" v="primary"/>
</way>
<relation id="100">
<member type="way" ref="10" role=""/>
<tag k="type" v="route"/>
<tag k="route" v="bus"/>
<tag k="ref" v="100"/>
</relation>
</osm>
""",
        encoding="utf-8",
    )
    filter_script = tmp_path / "copy_filter.sh"
    filter_script.write_text("#!/usr/bin/env sh\nset -eu\ncp \"$1\" \"$2\"\n", encoding="utf-8")
    filter_script.chmod(0o755)

    # Snapshot the global settings first; everything that mutates them lives
    # inside the ``try`` so the ``finally`` restores them even if one of the
    # overrides itself raises part-way through.
    old_enabled = settings.osm_pbf_prefilter_enabled
    old_formats = settings.osm_pbf_prefilter_formats
    old_script = settings.osm_pbf_prefilter_script
    try:
        settings.osm_pbf_prefilter_enabled = True
        settings.osm_pbf_prefilter_formats = "osm_xml"
        settings.osm_pbf_prefilter_script = filter_script

        with session_scope() as session:
            source = Source(name="Filtered OSM", kind="osm_pbf", url=str(osm_path), country="DE")
            session.add(source)
            session.flush()

            dataset = run_source(session, source)

            # ``.one()`` also asserts that exactly one row of each kind exists.
            raw_dataset = session.scalars(select(Dataset).where(Dataset.kind == "osm_pbf_raw")).one()
            filtered_dataset = session.scalars(select(Dataset).where(Dataset.kind == "osm_pbf_transport")).one()
            raw_metadata = json.loads(raw_dataset.metadata_json or "{}")
            filtered_metadata = json.loads(filtered_dataset.metadata_json or "{}")
            derived_metadata = json.loads(dataset.metadata_json or "{}")

            assert raw_dataset.status == "filtered"
            assert raw_dataset.is_active is False
            assert raw_metadata["filtered_dataset_id"] == filtered_dataset.id
            assert filtered_dataset.status == "extracted"
            assert filtered_dataset.is_active is False
            assert filtered_metadata["stage"] == "filtered_osm_transport_pbf"
            assert filtered_metadata["derived_from_dataset_id"] == raw_dataset.id
            assert filtered_metadata["filter"] == "osmium_transport_filter_v1"
            assert dataset.kind == "osm_geojson"
            assert dataset.is_active is True
            assert derived_metadata["raw_dataset_id"] == raw_dataset.id
            assert derived_metadata["filtered_dataset_id"] == filtered_dataset.id
            assert derived_metadata["derived_from_dataset_id"] == filtered_dataset.id

            # Re-running the same source must reuse every stage, not duplicate it.
            second_dataset = run_source(session, source)
            assert second_dataset.id == dataset.id
            assert session.scalar(select(func.count()).select_from(Dataset).where(Dataset.kind == "osm_pbf_raw")) == 1
            assert session.scalar(select(func.count()).select_from(Dataset).where(Dataset.kind == "osm_pbf_transport")) == 1
    finally:
        settings.osm_pbf_prefilter_enabled = old_enabled
        settings.osm_pbf_prefilter_formats = old_formats
        settings.osm_pbf_prefilter_script = old_script
|
|
|
|
|
|
def test_osm_geojson_import_deduplicates_duplicate_osm_identities(tmp_path):
    """Import a GeoJSON file in which two features share one OSM identity.

    Both features are ``relation`` / ``100``; only the first carries route
    tags and a geometry. The test expects the import to store one feature,
    report one skipped duplicate in the ``osm_storage`` metadata, and expose
    the stored feature as a route with ref ``100``.
    """
    reset_db()

    tagged_route = {
        "type": "Feature",
        "properties": {"osm_type": "relation", "osm_id": "100", "type": "route", "route": "bus", "ref": "100"},
        "geometry": {
            "type": "LineString",
            "coordinates": [[13.4, 52.5], [13.41, 52.501]],
        },
    }
    untagged_duplicate = {
        "type": "Feature",
        "properties": {"osm_type": "relation", "osm_id": "100", "name": "Duplicate without route geometry"},
        "geometry": None,
    }
    collection = {
        "type": "FeatureCollection",
        "features": [tagged_route, untagged_duplicate],
    }
    fixture_file = tmp_path / "duplicate-osm-identities.geojson"
    fixture_file.write_text(json.dumps(collection), encoding="utf-8")

    with session_scope() as session:
        geojson_source = Source(name="Duplicate OSM IDs", kind="osm_geojson", url=str(fixture_file), country="DE")
        session.add(geojson_source)
        session.flush()

        imported = run_source(session, geojson_source)

        dataset_metadata = json.loads(imported.metadata_json or "{}")
        storage_stats = dataset_metadata["osm_storage"]
        assert imported.status == "imported"
        assert storage_stats["features"] == 1
        assert storage_stats["duplicate_features_skipped"] == 1
        assert osm_feature_count(session, imported.id) == 1

        routes = query_osm_features(session, [imported.id], kinds=["route"])
        surviving = routes[0]
        assert surviving.osm_type == "relation"
        assert surviving.osm_id == "100"
        assert surviving.ref == "100"
|
|
|
|
|
|
def test_osm_relabel_updates_sidecar_route_scope_without_reparse(tmp_path):
    """Verify ``relabel_osm_features`` rewrites a stale ``route_scope`` in the sidecar.

    A single bus route relation is imported, then its ``route_scope`` is
    forced to ``local`` with a direct SQL update on the sidecar database.
    Relabeling must report one changed feature, set the scope to
    ``long_distance``, record ``route_scope_v2`` under ``label_features`` in
    the dataset metadata, and leave exactly one ``label_features`` pipeline
    run. A second relabel call is expected to report the dataset as skipped.
    """
    reset_db()
    geojson_path = tmp_path / "scope.geojson"
    geojson_path.write_text(
        json.dumps(
            {
                "type": "FeatureCollection",
                "features": [
                    {
                        "type": "Feature",
                        "properties": {
                            "osm_type": "relation",
                            "osm_id": "900",
                            "type": "route",
                            "route": "bus",
                            "name": "FlixBus Berlin Hamburg",
                            "ref": "N900",
                        },
                        "geometry": {"type": "LineString", "coordinates": [[13.4, 52.5], [10.0, 53.55]]},
                    }
                ],
            }
        ),
        encoding="utf-8",
    )

    with session_scope() as session:
        source = Source(name="Scope OSM", kind="osm_geojson", url=str(geojson_path), country="DE")
        session.add(source)
        session.flush()
        dataset = run_source(session, source)
        path = sidecar_path(dataset)
        assert path is not None

        # ``with connection`` only commits or rolls back the transaction; it
        # does not close the connection. Close it explicitly so no stray
        # handle on the sidecar file outlives this block while the pipeline
        # goes on to read and rewrite the same database.
        connection = sqlite3.connect(path)
        try:
            with connection:
                connection.execute("UPDATE osm_features SET route_scope = 'local'")
                connection.commit()
        finally:
            connection.close()

        # Sanity check: the tampered value is what the query layer now sees.
        stale = query_osm_features(session, [dataset.id], kinds=["route"])[0]
        assert stale.route_scope == "local"

        result = relabel_osm_features(session, dataset_id=dataset.id, rebuild_indexes=False)
        assert result["changed"] == 1

        relabeled = query_osm_features(session, [dataset.id], kinds=["route"])[0]
        assert relabeled.route_scope == "long_distance"
        metadata = json.loads(session.get(Dataset, dataset.id).metadata_json or "{}")
        assert metadata["label_features"]["version"] == "route_scope_v2"
        assert session.scalar(select(func.count()).select_from(PipelineRun).where(PipelineRun.stage == "label_features")) == 1

        # Relabeling again with the same label version should be a no-op.
        skipped = relabel_osm_features(session, dataset_id=dataset.id)
        assert skipped["skipped"] == 1
|