from __future__ import annotations import json import sqlite3 from sqlalchemy import func, select from app.config import settings from app.db import reset_db, session_scope from app.models import Dataset, PipelineRun, Source from app.osm_storage import features_are_sidecar, osm_feature_count, query_osm_features, sidecar_path from app.pipeline.osm_labeling import relabel_osm_features from app.pipeline.run import run_source def test_osm_pbf_source_commits_raw_and_extracts_route_geometry(tmp_path): reset_db() osm_path = tmp_path / "transport.osm" osm_path.write_text( """ """, encoding="utf-8", ) with session_scope() as session: source = Source(name="Test OSM", kind="osm_pbf", url=str(osm_path), country="DE") session.add(source) session.flush() dataset = run_source(session, source) raw_dataset = session.scalars(select(Dataset).where(Dataset.kind == "osm_pbf_raw")).one() assert raw_dataset.status == "extracted" assert raw_dataset.is_active is False assert dataset.kind == "osm_geojson" assert dataset.is_active is True assert features_are_sidecar(dataset) assert sidecar_path(dataset) is not None assert sidecar_path(dataset).exists() route = next(iter(query_osm_features(session, [dataset.id], kinds=["route"], search="100")), None) assert route is not None assert route.osm_type == "relation" assert route.mode == "bus" assert json.loads(route.geometry_geojson or "{}") == { "type": "LineString", "coordinates": [[13.4, 52.5], [13.41, 52.501], [13.42, 52.502], [13.43, 52.503]], } stop = next(iter(query_osm_features(session, [dataset.id], kinds=["stop"], search="Example Stop")), None) assert stop is not None cable_station = next(iter(query_osm_features(session, [dataset.id], kinds=["station"], search="Cable Station")), None) assert cable_station is not None ferry_infra = next(iter(query_osm_features(session, [dataset.id], kinds=["infra"], search="Ferry Waterway")), None) assert ferry_infra is not None assert ferry_infra.mode == "ferry" second_dataset = run_source(session, source) assert second_dataset.id == dataset.id assert session.scalar(select(func.count()).select_from(Dataset).where(Dataset.kind == "osm_pbf_raw")) == 1 def test_osm_pbf_source_reuses_raw_and_filtered_transport_dataset(tmp_path): reset_db() osm_path = tmp_path / "transport.osm" osm_path.write_text( """ """, encoding="utf-8", ) filter_script = tmp_path / "copy_filter.sh" filter_script.write_text("#!/usr/bin/env sh\nset -eu\ncp \"$1\" \"$2\"\n", encoding="utf-8") filter_script.chmod(0o755) old_enabled = settings.osm_pbf_prefilter_enabled old_formats = settings.osm_pbf_prefilter_formats old_script = settings.osm_pbf_prefilter_script settings.osm_pbf_prefilter_enabled = True settings.osm_pbf_prefilter_formats = "osm_xml" settings.osm_pbf_prefilter_script = filter_script try: with session_scope() as session: source = Source(name="Filtered OSM", kind="osm_pbf", url=str(osm_path), country="DE") session.add(source) session.flush() dataset = run_source(session, source) raw_dataset = session.scalars(select(Dataset).where(Dataset.kind == "osm_pbf_raw")).one() filtered_dataset = session.scalars(select(Dataset).where(Dataset.kind == "osm_pbf_transport")).one() raw_metadata = json.loads(raw_dataset.metadata_json or "{}") filtered_metadata = json.loads(filtered_dataset.metadata_json or "{}") derived_metadata = json.loads(dataset.metadata_json or "{}") assert raw_dataset.status == "filtered" assert raw_dataset.is_active is False assert raw_metadata["filtered_dataset_id"] == filtered_dataset.id assert filtered_dataset.status == "extracted" assert filtered_dataset.is_active is False assert filtered_metadata["stage"] == "filtered_osm_transport_pbf" assert filtered_metadata["derived_from_dataset_id"] == raw_dataset.id assert filtered_metadata["filter"] == "osmium_transport_filter_v1" assert dataset.kind == "osm_geojson" assert dataset.is_active is True assert derived_metadata["raw_dataset_id"] == raw_dataset.id assert derived_metadata["filtered_dataset_id"] == filtered_dataset.id assert derived_metadata["derived_from_dataset_id"] == filtered_dataset.id second_dataset = run_source(session, source) assert second_dataset.id == dataset.id assert session.scalar(select(func.count()).select_from(Dataset).where(Dataset.kind == "osm_pbf_raw")) == 1 assert session.scalar(select(func.count()).select_from(Dataset).where(Dataset.kind == "osm_pbf_transport")) == 1 finally: settings.osm_pbf_prefilter_enabled = old_enabled settings.osm_pbf_prefilter_formats = old_formats settings.osm_pbf_prefilter_script = old_script def test_osm_geojson_import_deduplicates_duplicate_osm_identities(tmp_path): reset_db() geojson_path = tmp_path / "duplicate-osm-identities.geojson" geojson_path.write_text( json.dumps( { "type": "FeatureCollection", "features": [ { "type": "Feature", "properties": {"osm_type": "relation", "osm_id": "100", "type": "route", "route": "bus", "ref": "100"}, "geometry": { "type": "LineString", "coordinates": [[13.4, 52.5], [13.41, 52.501]], }, }, { "type": "Feature", "properties": {"osm_type": "relation", "osm_id": "100", "name": "Duplicate without route geometry"}, "geometry": None, }, ], } ), encoding="utf-8", ) with session_scope() as session: source = Source(name="Duplicate OSM IDs", kind="osm_geojson", url=str(geojson_path), country="DE") session.add(source) session.flush() dataset = run_source(session, source) metadata = json.loads(dataset.metadata_json or "{}") storage = metadata["osm_storage"] assert dataset.status == "imported" assert storage["features"] == 1 assert storage["duplicate_features_skipped"] == 1 assert osm_feature_count(session, dataset.id) == 1 route = query_osm_features(session, [dataset.id], kinds=["route"])[0] assert route.osm_type == "relation" assert route.osm_id == "100" assert route.ref == "100" def test_osm_relabel_updates_sidecar_route_scope_without_reparse(tmp_path): reset_db() geojson_path = tmp_path / "scope.geojson" geojson_path.write_text( json.dumps( { "type": "FeatureCollection", "features": [ { "type": "Feature", "properties": { "osm_type": "relation", "osm_id": "900", "type": "route", "route": "bus", "name": "FlixBus Berlin Hamburg", "ref": "N900", }, "geometry": {"type": "LineString", "coordinates": [[13.4, 52.5], [10.0, 53.55]]}, } ], } ), encoding="utf-8", ) with session_scope() as session: source = Source(name="Scope OSM", kind="osm_geojson", url=str(geojson_path), country="DE") session.add(source) session.flush() dataset = run_source(session, source) path = sidecar_path(dataset) assert path is not None with sqlite3.connect(path) as connection: connection.execute("UPDATE osm_features SET route_scope = 'local'") connection.commit() stale = query_osm_features(session, [dataset.id], kinds=["route"])[0] assert stale.route_scope == "local" result = relabel_osm_features(session, dataset_id=dataset.id, rebuild_indexes=False) assert result["changed"] == 1 relabeled = query_osm_features(session, [dataset.id], kinds=["route"])[0] assert relabeled.route_scope == "long_distance" metadata = json.loads(session.get(Dataset, dataset.id).metadata_json or "{}") assert metadata["label_features"]["version"] == "route_scope_v2" assert session.scalar(select(func.count()).select_from(PipelineRun).where(PipelineRun.stage == "label_features")) == 1 skipped = relabel_osm_features(session, dataset_id=dataset.id) assert skipped["skipped"] == 1