Alpha stage commit

This commit is contained in:
2026-07-01 23:29:51 +02:00
parent b583bb1233
commit e23387738b
84 changed files with 40807 additions and 326 deletions

111
app/pipeline/download.py Normal file
View File

@@ -0,0 +1,111 @@
from __future__ import annotations
import shutil
import time
from pathlib import Path
from urllib.parse import urlparse
import requests
from app.config import settings
from app.models import Source
from app.pipeline.utils import sha256_file
def materialize_source(source: Source) -> Path:
    """Download/copy a source into the local cache and return the file path.

    Files are stored by content hash per source. Re-running an unchanged source
    reuses the existing cached file instead of creating another timestamped copy.

    HTTP(S) sources are streamed into a ``*.download<suffix>`` temp file inside
    the per-source cache directory; a leftover temp file from an earlier attempt
    is resumed with a ``Range`` request. ``file://`` URLs and plain paths are
    copied into the cache unless they already live inside it.

    Raises:
        FileNotFoundError: if a local (or ``file://``) source does not exist.
        requests.HTTPError: if the download answers with an error status.
    """
    # Local import: only needed for file:// URLs, keeps the module imports untouched.
    from urllib.request import url2pathname

    source_dir = settings.data_dir / "sources" / f"source_{source.id}"
    source_dir.mkdir(parents=True, exist_ok=True)
    suffix = _guess_suffix(source.url, source.kind)
    parsed = urlparse(source.url)
    if parsed.scheme in {"http", "https"}:
        temp_path = _download_temp_path(source_dir, suffix)
        _download_with_resume(source.url, temp_path)
        return _store_or_reuse_cached_file(source_dir=source_dir, source_path=temp_path, suffix=suffix, move=True)
    if parsed.scheme == "file":
        # Percent-decode the URL path ("%20" -> " ") before touching the filesystem.
        source_path = Path(url2pathname(parsed.path))
    else:
        source_path = Path(source.url)
    if not source_path.exists():
        raise FileNotFoundError(f"Source file does not exist: {source.url}")
    if _is_relative_to(source_path.resolve(), source_dir.resolve()):
        # Already inside the cache directory: use it in place.
        return source_path
    return _store_or_reuse_cached_file(source_dir=source_dir, source_path=source_path, suffix=suffix, move=False)


def _download_with_resume(url: str, temp_path: Path) -> None:
    """Stream *url* into *temp_path*, resuming a partial file when possible.

    A non-empty temp file is resumed with a ``Range`` header. If the server
    answers 416 (the requested range starts at or beyond the end of the
    resource, e.g. the leftover file is already complete or stale), the
    leftover is discarded and the download restarts once from scratch instead
    of failing on every subsequent run.
    """
    existing_size = temp_path.stat().st_size if temp_path.exists() else 0
    offsets = [existing_size, 0] if existing_size > 0 else [0]
    for offset in offsets:
        headers = {"Range": f"bytes={offset}-"} if offset > 0 else None
        with requests.get(url, stream=True, timeout=120, headers=headers) as response:
            if offset > 0 and response.status_code == 416:
                temp_path.unlink(missing_ok=True)
                continue
            response.raise_for_status()
            # Append only when the server actually honoured the range (206);
            # a plain 200 carries the full body and must overwrite the partial file.
            mode = "ab" if offset > 0 and response.status_code == 206 else "wb"
            with temp_path.open(mode) as f:
                for chunk in response.iter_content(chunk_size=1024 * 1024):
                    if chunk:
                        f.write(chunk)
            return
def _download_temp_path(source_dir: Path, suffix: str) -> Path:
    """Return the temp file to download into for this source directory.

    The most recently modified ``*.download<suffix>`` file is reused when one
    exists; otherwise a new name based on the current Unix time is returned
    (the file itself is not created here).
    """

    def modified_at(candidate: Path) -> float:
        # A file that vanished since the glob sorts as the oldest.
        return candidate.stat().st_mtime if candidate.exists() else 0

    leftovers = list(source_dir.glob(f"*.download{suffix}"))
    if leftovers:
        return max(leftovers, key=modified_at)
    return source_dir / f"{int(time.time())}.download{suffix}"
def _guess_suffix(url: str, kind: str) -> str:
    """Pick the cache file suffix for a source.

    The URL path (or the raw string when it has no path component) is matched
    case-insensitively against a fixed list of known suffixes, first hit wins.
    Without a hit the suffix is derived from *kind*, defaulting to ``.dat``.
    """
    known_suffixes = (".zip", ".geojson", ".json", ".osm.pbf", ".pbf", ".osm", ".osm.xml", ".osc.gz", ".osc", ".csv")
    lowered = (urlparse(url).path or url).lower()
    matched = next((ext for ext in known_suffixes if lowered.endswith(ext)), None)
    if matched is not None:
        return matched
    if kind == "gtfs":
        return ".zip"
    return ".geojson" if kind == "osm_geojson" else ".dat"
def _store_or_reuse_cached_file(source_dir: Path, source_path: Path, suffix: str, move: bool) -> Path:
    """Place *source_path* in the cache under a content-hash name, or reuse a copy.

    The target name is the first 16 hex characters of the file's SHA-256 plus
    *suffix*. If that target (or any other cached file with the same suffix)
    already has identical content, it is returned and, when *move* is true, the
    now redundant *source_path* is deleted. Otherwise the file is moved or
    copied to the target depending on *move*.
    """
    digest = sha256_file(source_path)
    target = source_dir / f"{digest[:16]}{suffix}"

    if target.exists() and sha256_file(target) == digest:
        cached = target
    else:
        cached = _find_existing_cached_file(source_dir, digest, suffix, exclude=source_path)

    if cached is not None:
        if move and source_path != cached:
            source_path.unlink(missing_ok=True)
        return cached

    if move:
        source_path.replace(target)
    else:
        shutil.copyfile(source_path, target)
    return target
def _find_existing_cached_file(source_dir: Path, source_hash: str, suffix: str, exclude: Path | None = None) -> Path | None:
    """Return the first file in *source_dir* ending in *suffix* whose SHA-256 equals *source_hash*.

    Candidates are visited in sorted order; *exclude* (compared by resolved
    path) is skipped. Returns ``None`` when nothing matches.
    """
    skipped = None if exclude is None else exclude.resolve()
    for cached in sorted(source_dir.glob(f"*{suffix}")):
        if skipped is not None and cached.resolve() == skipped:
            continue
        if not cached.is_file():
            continue
        if sha256_file(cached) == source_hash:
            return cached
    return None
def _is_relative_to(path: Path, parent: Path) -> bool:
    """Return True when *path* is *parent* itself or located underneath it."""
    try:
        path.relative_to(parent)
    except ValueError:
        # relative_to() signals "not a sub-path" with ValueError.
        return False
    else:
        return True

1327
app/pipeline/gtfs.py Normal file

File diff suppressed because it is too large Load Diff

995
app/pipeline/matcher.py Normal file
View File

@@ -0,0 +1,995 @@
from __future__ import annotations
from dataclasses import dataclass
from datetime import datetime, timezone
import json
from typing import Callable, Optional
from shapely.geometry import LineString, MultiLineString, Point, shape
from sqlalchemy import delete, select
from sqlalchemy.orm import Session
from app.config import settings
from app.models import Dataset, GtfsRoute, MatchRule, OsmFeature, RouteMatch
from app.osm_storage import ensure_main_osm_feature, osm_feature_bbox, query_osm_features
from app.pipeline.state import STAGE_MATCH_ROUTES, dependency_hash, finish_pipeline_run, start_pipeline_run
from app.pipeline.utils import approx_bbox_center_distance_deg, bbox_overlap, norm_ref, norm_text
# Maps a GTFS route mode to the set of mode names treated as compatible when
# looking up OSM candidates (see the MODE_GROUPS.get(...) calls in the candidate
# selection functions). A mode missing from this table is only compatible with itself.
MODE_GROUPS = {
    "train": {"train", "rail", "railway"},
    "subway": {"subway", "metro"},
    "tram": {"tram", "light_rail"},
    "light_rail": {"light_rail", "tram"},
    "bus": {"bus", "coach", "trolleybus"},
    "coach": {"coach", "bus"},
    "trolleybus": {"trolleybus", "bus"},
    "ferry": {"ferry"},
    "funicular": {"funicular"},
    "aerialway": {"aerialway", "cable_car"},
    "monorail": {"monorail"},
}
# Caps on the spatial fallback candidate list, depending on whether the GTFS
# route has a normalized ref.
MAX_FALLBACK_CANDIDATES_WITH_REF = 40
MAX_FALLBACK_CANDIDATES_WITHOUT_REF = 80
# Cap on candidates found through an exact ref / route-key lookup.
MAX_EXACT_REF_CANDIDATES = 120
# Bounding-box centre distance below which a route counts as "near" the OSM
# scope or a candidate (units presumably degrees, per the name).
OSM_SCOPE_NEAR_DISTANCE_DEG = 0.15
# Maximum distance for a sampled point to count as lying on the other geometry.
GEOMETRY_PROXIMITY_DEG = 0.0035
# Number of points sampled along a geometry when building its profile.
GEOMETRY_SAMPLE_POINTS = 24
# Recorded with each pipeline run and included in the dependency description.
MATCHER_VERSION = "matcher_v4_scope_spatial_manual_rules"
# Called as callback(event_type, message, progress_current, progress_total, metadata).
ProgressCallback = Callable[[str, str, int | None, int | None, dict[str, object] | None], None]
@dataclass(frozen=True)
class _ManualMatchRule:
    """Parsed, immutable view of one active accept/reject ``MatchRule`` row."""

    # Primary key of the originating MatchRule row.
    id: int
    # Rule type as stored on the row; the loader only selects
    # "accept_match" and "reject_match".
    rule_type: str
    # Selector identifying the GTFS route(s) the rule applies to.
    route_selector: dict[str, object]
    # Selector identifying the OSM feature, or None when the rule has none.
    osm_selector: dict[str, object] | None
    # Status from the rule action, defaulting to "accepted"/"rejected" by rule type.
    status: str
@dataclass(frozen=True)
class _OsmRouteIndex:
    """In-memory lookup tables over a list of OSM route features."""

    # The full feature list the index was built from.
    all_routes: list[OsmFeature]
    # Features grouped by normalized ref (features without a ref are omitted).
    by_ref: dict[str, list[OsmFeature]]
    # Features grouped by route_key (features without one are omitted).
    by_route_key: dict[str, list[OsmFeature]]
    # Features grouped by mode (features without one are omitted).
    by_mode: dict[str, list[OsmFeature]]
@dataclass(frozen=True)
class _GeometryProfile:
    """Pre-computed geometry data for one route, reused across pair scoring."""

    # Geometry object produced by shapely's shape() from the stored GeoJSON.
    geom: object
    # The line parts of the geometry.
    lines: list[LineString]
    # Sum of the lengths of all line parts (always non-zero for a built profile).
    length: float
    # Points sampled along the line parts.
    sample_points: list[Point]
@dataclass(frozen=True)
class _RouteMatchPayload:
    """Desired state of the ``RouteMatch`` row for one GTFS route."""

    # Primary key of the GTFS route row being matched.
    gtfs_route_id: int
    # Primary key of the chosen OSM feature, or None when nothing matched.
    osm_feature_id: int | None
    # Match confidence; 100.0 for manual matches, 0.0 when missing.
    confidence: float
    # Match status string (e.g. "accepted", "missing", or a score-derived status).
    status: str
    # "manual" when produced by a manual rule, otherwise "auto".
    rule_source: str
    # Compact JSON explaining the decision.
    reasons_json: str | None
def run_route_matching(
    session: Session,
    *,
    progress_callback: ProgressCallback | None = None,
    batch_size: int | None = None,
) -> dict[str, object]:
    """Match active GTFS routes against active OSM route features.

    Routes of all active GTFS datasets are processed in batches; the session is
    committed after every batch and progress is reported through
    *progress_callback* when given. *batch_size* falls back to
    ``settings.route_matching_batch_size`` and is clamped to at least 1.

    Returns the aggregated counters plus a ``"scope"`` breakdown. When there is
    nothing to match (no active datasets, no GTFS dataset, or no routes) only
    the zeroed ``routes``/``matches``/``missing`` counters are returned and no
    pipeline run is recorded.
    """
    dataset_rows = session.execute(
        select(Dataset.id, Dataset.kind, Dataset.source_id).where(Dataset.is_active.is_(True))
    ).all()
    if not dataset_rows:
        return {"routes": 0, "matches": 0, "missing": 0}

    dataset_source_ids: dict[int, int] = {}
    gtfs_dataset_ids: list[int] = []
    osm_dataset_ids: list[int] = []
    for dataset_id, kind, source_id in dataset_rows:
        dataset_source_ids[int(dataset_id)] = int(source_id)
        if kind == "gtfs":
            gtfs_dataset_ids.append(int(dataset_id))
        elif kind == "osm_geojson":
            osm_dataset_ids.append(int(dataset_id))
    if not gtfs_dataset_ids:
        return {"routes": 0, "matches": 0, "missing": 0}

    route_order = (GtfsRoute.dataset_id, GtfsRoute.route_id, GtfsRoute.id)
    route_row_ids = session.scalars(
        select(GtfsRoute.id).where(GtfsRoute.dataset_id.in_(gtfs_dataset_ids)).order_by(*route_order)
    ).all()
    # Current match rows are reconciled from auto scoring plus durable manual rules.
    total_routes = len(route_row_ids)
    if total_routes == 0:
        return {"routes": 0, "matches": 0, "missing": 0}

    dependency = _route_matching_dependency(session, dataset_rows)
    run = start_pipeline_run(
        session,
        stage=STAGE_MATCH_ROUTES,
        version=MATCHER_VERSION,
        dependency_hash_value=dependency_hash(dependency),
        inputs=dependency,
    )
    session.commit()

    effective_batch_size = max(1, int(batch_size or settings.route_matching_batch_size))
    _emit_progress(
        progress_callback,
        "route_matching_started",
        f"Matching {total_routes} GTFS routes in batches of {effective_batch_size}.",
        0,
        total_routes,
        {"gtfs_datasets": gtfs_dataset_ids, "osm_datasets": osm_dataset_ids, "batch_size": effective_batch_size},
    )

    manual_rules = _manual_match_rules(session)
    osm_scope_bbox = osm_feature_bbox(session, osm_dataset_ids, kinds=["route"])
    counts = {"routes": total_routes, "matches": 0, "missing": 0, "manual": 0, "created": 0, "updated": 0, "unchanged": 0}
    scoped_counts = {"in_osm_scope": 0, "near_osm_scope": 0, "outside_osm_scope": 0, "unknown_scope": 0}
    # Counters accumulated from every batch result, in reporting order.
    tallied_keys = ("matches", "missing", "manual", "created", "updated", "unchanged")
    processed = 0
    for chunk in _chunks_int(route_row_ids, effective_batch_size):
        routes = session.scalars(
            select(GtfsRoute).where(GtfsRoute.id.in_(chunk)).order_by(*route_order)
        ).all()
        batch_counts = _match_route_batch(
            session=session,
            routes=routes,
            osm_dataset_ids=osm_dataset_ids,
            dataset_source_ids=dataset_source_ids,
            manual_rules=manual_rules,
            osm_scope_bbox=osm_scope_bbox,
            scoped_counts=scoped_counts,
        )
        for key in tallied_keys:
            counts[key] += batch_counts[key]
        processed += len(routes)
        session.commit()
        batch_metadata = {
            "processed": processed,
            **{key: counts[key] for key in tallied_keys},
            "scope": dict(scoped_counts),
        }
        _emit_progress(
            progress_callback,
            "route_matching_batch",
            f"Matched {processed}/{total_routes} GTFS routes.",
            processed,
            total_routes,
            batch_metadata,
        )

    result = {**counts, "scope": scoped_counts}
    finish_pipeline_run(session, run, outputs=result)
    session.commit()
    _emit_progress(
        progress_callback,
        "route_matching_completed",
        "Route matching completed.",
        total_routes,
        total_routes,
        result,
    )
    return result
def _route_matching_dependency(session: Session, active_datasets) -> dict[str, object]:
    """Describe everything the matching result depends on.

    The description covers the matcher version, each active dataset (id, kind,
    source id, sha256) and every ``MatchRule`` row ordered by id, including
    inactive ones. *active_datasets* yields ``(dataset_id, kind, source_id)``
    rows.
    """
    dataset_entries = []
    for dataset_id, kind, source_id in active_datasets:
        dataset_entries.append(
            {
                "id": int(dataset_id),
                "kind": str(kind),
                "source_id": int(source_id),
                "sha256": _dataset_sha(session, int(dataset_id)),
            }
        )

    rule_entries = []
    for rule in session.scalars(select(MatchRule).order_by(MatchRule.id)).all():
        rule_entries.append(
            {
                "id": int(rule.id),
                "type": rule.rule_type,
                "active": bool(rule.active),
                "selector": rule.selector_json,
                "action": rule.action_json,
            }
        )

    return {"version": MATCHER_VERSION, "active_datasets": dataset_entries, "manual_rules": rule_entries}
def _dataset_sha(session: Session, dataset_id: int) -> str | None:
    """Return the stored sha256 of the dataset, or None when the dataset row is missing."""
    dataset = session.get(Dataset, dataset_id)
    if dataset is None:
        return None
    return dataset.sha256
def _match_route_batch(
    *,
    session: Session,
    routes: list[GtfsRoute],
    osm_dataset_ids: list[int],
    dataset_source_ids: dict[int, int],
    manual_rules: list[_ManualMatchRule],
    osm_scope_bbox: tuple[float | None, float | None, float | None, float | None],
    scoped_counts: dict[str, int],
) -> dict[str, int]:
    """Decide the match for every route in one batch and persist the results.

    Per route, in order of precedence:
    1. an accepted manual rule whose OSM feature can be resolved wins outright;
    2. a route outside the loaded OSM scope is recorded as missing without scoring;
    3. otherwise the best-scoring candidate (excluding manually rejected pairs)
       is used, or the route is recorded as missing when no candidate reaches
       a non-"missing" status.

    *scoped_counts* is updated in place with the scope of every route.
    Returns the batch counters ``matches``, ``missing``, ``manual`` plus the
    ``created``/``updated``/``unchanged`` counters from applying the payloads.
    The session is flushed but not committed here.
    """
    matches = 0
    missing = 0
    manual = 0
    payloads: list[_RouteMatchPayload] = []
    for route in routes:
        scope = route_match_scope(route, osm_scope_bbox)
        scoped_counts[scope] = scoped_counts.get(scope, 0) + 1
        route_source_id = dataset_source_ids.get(route.dataset_id)
        accepted_rule = _accepted_rule_for_route(manual_rules, route, route_source_id)
        if accepted_rule is not None:
            accepted_feature = _feature_for_rule_from_storage(session, osm_dataset_ids, dataset_source_ids, accepted_rule)
            # A rule whose feature cannot be resolved falls through to automatic matching.
            if accepted_feature is not None:
                accepted_feature = ensure_main_osm_feature(session, accepted_feature)
                payloads.append(
                    _RouteMatchPayload(
                        gtfs_route_id=route.id,
                        osm_feature_id=accepted_feature.id,
                        confidence=100.0,
                        status="accepted",
                        rule_source="manual",
                        reasons_json=json.dumps(
                            {"manual_rule_id": accepted_rule.id, "manual": "accepted_match", "scope": scope},
                            separators=(",", ":"),
                        ),
                    )
                )
                matches += 1
                manual += 1
                continue
        if scope == "outside_osm_scope":
            # No candidate lookup or scoring for routes outside the loaded OSM extent.
            missing += 1
            payloads.append(
                _RouteMatchPayload(
                    gtfs_route_id=route.id,
                    osm_feature_id=None,
                    confidence=0.0,
                    status="missing",
                    rule_source="auto",
                    reasons_json=json.dumps(
                        {
                            "reason": "outside loaded OSM route scope",
                            "scope": scope,
                        },
                        separators=(",", ":"),
                    ),
                )
            )
            continue
        best_feature: Optional[OsmFeature] = None
        best_score = 0.0
        best_reasons: dict[str, object] = {}
        # Parse the route geometry once; it is reused for every candidate.
        route_geometry_profile = _geometry_profile(route.geometry_geojson)
        for feature in candidate_osm_routes_for_route(session, route, osm_dataset_ids):
            if _is_rejected_pair(manual_rules, route, route_source_id, feature, dataset_source_ids.get(feature.dataset_id)):
                continue
            feature_geometry_profile = _geometry_profile(feature.geometry_geojson)
            score, reasons = score_route_pair(
                route,
                feature,
                route_geometry_profile=route_geometry_profile,
                feature_geometry_profile=feature_geometry_profile,
            )
            # Strict ">" keeps the first candidate on ties and ignores zero scores.
            if score > best_score:
                best_score = score
                best_feature = feature
                best_reasons = reasons
        status = _status_from_score(best_score)
        if best_feature is None or status == "missing":
            missing += 1
            best_feature_id = None
            # Keep the rejected best candidate's score and reasons for diagnostics.
            best_reasons = {
                "reason": "no OSM candidate above threshold",
                "scope": scope,
                "best_score_below_threshold": round(float(best_score), 2) if best_score else 0,
                "best_reasons": best_reasons,
            }
            best_score = 0
        else:
            matches += 1
            best_feature = ensure_main_osm_feature(session, best_feature)
            best_feature_id = best_feature.id
            best_reasons["scope"] = scope
        payloads.append(
            _RouteMatchPayload(
                gtfs_route_id=route.id,
                osm_feature_id=best_feature_id,
                confidence=round(float(best_score), 2),
                status=status,
                rule_source="auto",
                reasons_json=json.dumps(best_reasons, separators=(",", ":")),
            )
        )
    changes = _apply_route_match_payloads(session, payloads)
    session.flush()
    return {"matches": matches, "missing": missing, "manual": manual, **changes}
def _apply_route_match_payloads(session: Session, payloads: list[_RouteMatchPayload]) -> dict[str, int]:
    """Reconcile ``RouteMatch`` rows with the desired payloads.

    For each payload the preferred existing row of that route is updated in
    place (or left alone when nothing changed); a new row is added when the
    route has none. Any additional rows of the same route are deleted.
    Returns the ``created``/``updated``/``unchanged`` counters.
    """
    tally = {"created": 0, "updated": 0, "unchanged": 0}
    if not payloads:
        return tally

    wanted_route_ids = [payload.gtfs_route_id for payload in payloads]
    existing_by_route: dict[int, list[RouteMatch]] = {}
    for row in session.scalars(
        select(RouteMatch).where(RouteMatch.gtfs_route_id.in_(wanted_route_ids)).order_by(RouteMatch.gtfs_route_id, RouteMatch.id)
    ).all():
        existing_by_route.setdefault(row.gtfs_route_id, []).append(row)

    copied_fields = ("osm_feature_id", "confidence", "status", "rule_source", "reasons_json")
    surplus_ids: list[int] = []
    now = datetime.now(timezone.utc)
    for payload in payloads:
        route_rows = existing_by_route.get(payload.gtfs_route_id, [])
        current = _preferred_existing_match(route_rows)
        if current is None:
            session.add(RouteMatch(gtfs_route_id=payload.gtfs_route_id, **{name: getattr(payload, name) for name in copied_fields}))
            tally["created"] += 1
            continue
        # Every other row of this route is a duplicate and gets removed below.
        surplus_ids.extend(row.id for row in route_rows if row.id != current.id)
        if _route_match_payload_equal(current, payload):
            tally["unchanged"] += 1
            continue
        for name in copied_fields:
            setattr(current, name, getattr(payload, name))
        current.updated_at = now
        tally["updated"] += 1

    for chunk in _chunks_int(surplus_ids, 1000):
        session.execute(delete(RouteMatch).where(RouteMatch.id.in_(chunk)))
    return tally
def _preferred_existing_match(rows: list[RouteMatch]) -> RouteMatch | None:
    """Pick the row to keep: the first manual one, else the first row, else None."""
    if not rows:
        return None
    for row in rows:
        if row.rule_source == "manual":
            return row
    return rows[0]
def _route_match_payload_equal(row: RouteMatch, payload: _RouteMatchPayload) -> bool:
    """Return True when the stored row already holds exactly what the payload asks for.

    Confidence is compared after rounding to two decimals (None counts as 0);
    an empty ``reasons_json`` is treated like None.
    """
    if row.osm_feature_id != payload.osm_feature_id:
        return False
    if round(float(row.confidence or 0), 2) != round(float(payload.confidence or 0), 2):
        return False
    if row.status != payload.status:
        return False
    if row.rule_source != payload.rule_source:
        return False
    return (row.reasons_json or None) == (payload.reasons_json or None)
def _build_osm_route_index(osm_routes: list[OsmFeature]) -> _OsmRouteIndex:
    """Group OSM route features by normalized ref, route key and mode.

    A feature is only filed under a key that is non-empty for it.
    """
    by_ref: dict[str, list[OsmFeature]] = {}
    by_route_key: dict[str, list[OsmFeature]] = {}
    by_mode: dict[str, list[OsmFeature]] = {}

    def file_under(table: dict[str, list[OsmFeature]], key, feature: OsmFeature) -> None:
        if key:
            table.setdefault(key, []).append(feature)

    for feature in osm_routes:
        file_under(by_ref, norm_ref(feature.ref or ""), feature)
        file_under(by_route_key, feature.route_key, feature)
        file_under(by_mode, feature.mode, feature)

    return _OsmRouteIndex(all_routes=osm_routes, by_ref=by_ref, by_route_key=by_route_key, by_mode=by_mode)
def _candidate_osm_routes(route: GtfsRoute, index: _OsmRouteIndex) -> list[OsmFeature]:
    """Select OSM candidates for a GTFS route from an in-memory index.

    Mode-compatible features sharing the route's normalized ref or route key
    are preferred. Only when none exist does a spatial fallback run over
    features of compatible modes (all features if there are none): those whose
    bbox overlaps the route's or whose bbox centre lies within 0.12 of it,
    nearest first. The result is spatially ranked and capped.
    """
    chosen: list[OsmFeature] = []
    seen_ids: set[int] = set()

    def add(features: list[OsmFeature], *, require_compatible_mode: bool = True) -> None:
        for feature in features:
            if feature.id in seen_ids or (
                require_compatible_mode and not _mode_compatible(route.mode or "", feature.mode or "")
            ):
                continue
            seen_ids.add(feature.id)
            chosen.append(feature)

    route_ref = norm_ref(route.short_name or route.route_id)
    if route_ref:
        add(index.by_ref.get(route_ref, []))
    if route.route_key:
        add(index.by_route_key.get(route.route_key, []))
    if chosen:
        return _spatially_ranked_candidates(route, chosen, MAX_EXACT_REF_CANDIDATES)

    compatible_modes = MODE_GROUPS.get(route.mode or "", {route.mode or ""})
    mode_candidates: list[OsmFeature] = [
        feature for mode in compatible_modes if mode for feature in index.by_mode.get(mode, [])
    ]
    if not mode_candidates:
        mode_candidates = index.all_routes

    gtfs_bbox = (route.min_lon, route.min_lat, route.max_lon, route.max_lat)
    nearby: list[tuple[float, OsmFeature]] = []
    for feature in mode_candidates:
        osm_bbox = (feature.min_lon, feature.min_lat, feature.max_lon, feature.max_lat)
        distance = approx_bbox_center_distance_deg(gtfs_bbox, osm_bbox)
        if bbox_overlap(gtfs_bbox, osm_bbox):
            # Overlapping boxes sort ahead of everything that is merely close.
            nearby.append((0.0, feature))
        elif distance is not None and distance < 0.12:
            nearby.append((distance, feature))

    fallback_limit = MAX_FALLBACK_CANDIDATES_WITH_REF if route_ref else MAX_FALLBACK_CANDIDATES_WITHOUT_REF
    nearby.sort(key=lambda item: item[0])
    fallback = [feature for _, feature in nearby[:fallback_limit]]
    if not fallback:
        fallback = mode_candidates[:fallback_limit]
    add(fallback)
    return _spatially_ranked_candidates(route, chosen, fallback_limit)
def candidate_osm_routes_for_route(session: Session, route: GtfsRoute, osm_dataset_ids: list[int]) -> list[OsmFeature]:
    """Query storage for OSM route candidates of one GTFS route.

    Lookup order:
    1. features whose route key equals the route's route key or normalized ref
       (mode-compatible ones only);
    2. if none, features of compatible modes inside the route bbox padded by 0.10;
    3. if still none, features of compatible modes without a spatial filter.

    Features are de-duplicated by (dataset_id, osm_type, osm_id); the result is
    spatially ranked and capped. Returns an empty list without OSM datasets.
    """
    if not osm_dataset_ids:
        return []

    chosen: list[OsmFeature] = []
    seen_keys: set[tuple[int, str, str]] = set()

    def add(features: list[OsmFeature], *, require_compatible_mode: bool = True) -> None:
        for feature in features:
            identity = (feature.dataset_id, feature.osm_type, feature.osm_id)
            if identity in seen_keys or (
                require_compatible_mode and not _mode_compatible(route.mode or "", feature.mode or "")
            ):
                continue
            seen_keys.add(identity)
            chosen.append(feature)

    route_ref = norm_ref(route.short_name or route.route_id)
    # dict.fromkeys keeps the order and drops a ref identical to the route key.
    for route_key in dict.fromkeys(key for key in (route.route_key, route_ref) if key):
        add(query_osm_features(session, osm_dataset_ids, kinds=["route"], route_key=route_key))
    if chosen:
        return _spatially_ranked_candidates(route, chosen, MAX_EXACT_REF_CANDIDATES)

    gtfs_bbox = (route.min_lon, route.min_lat, route.max_lon, route.max_lat)
    compatible_modes = sorted(MODE_GROUPS.get(route.mode or "", {route.mode or ""}) - {""})
    if all(value is not None for value in gtfs_bbox):
        nearby = query_osm_features(
            session,
            osm_dataset_ids,
            kinds=["route"],
            modes=compatible_modes or None,
            bbox=_expanded_bbox(gtfs_bbox, 0.10),
            limit=MAX_FALLBACK_CANDIDATES_WITHOUT_REF * 4,
        )
        # The query already filtered by mode, so the per-feature check is skipped.
        add(nearby, require_compatible_mode=False)
    if not chosen:
        anywhere = query_osm_features(
            session,
            osm_dataset_ids,
            kinds=["route"],
            modes=compatible_modes or None,
            limit=MAX_FALLBACK_CANDIDATES_WITHOUT_REF,
        )
        add(anywhere, require_compatible_mode=False)

    if route_ref:
        fallback_limit = MAX_FALLBACK_CANDIDATES_WITH_REF
    else:
        fallback_limit = MAX_FALLBACK_CANDIDATES_WITHOUT_REF
    return _spatially_ranked_candidates(route, chosen, fallback_limit)
def score_route_pair(
    route: GtfsRoute,
    feature: OsmFeature,
    route_geometry_profile: _GeometryProfile | None = None,
    feature_geometry_profile: _GeometryProfile | None = None,
) -> tuple[float, dict[str, object]]:
    """Score how well one OSM route feature matches one GTFS route.

    The score adds up evidence from mode, ref, name, operator, bounding boxes,
    geometry overlap and route key. Afterwards "strong identity" combinations
    raise the score to a floor, and an exact ref that is spatially far away
    without geometry overlap is capped. Two known but incompatible modes
    short-circuit to a score of 0.

    Geometry metrics use the supplied profiles when both are given; otherwise
    they are derived from the stored GeoJSON of both objects.

    Returns:
        ``(score, reasons)`` where the score is at most 100 and *reasons*
        records which signals contributed.
    """
    score = 0.0
    reasons: dict[str, object] = {}
    gtfs_mode = route.mode or ""
    osm_mode = feature.mode or ""
    # Mode: compatible modes earn points; two known, incompatible modes veto the pair.
    if _mode_compatible(gtfs_mode, osm_mode):
        score += 25
        reasons["mode"] = "compatible"
    elif gtfs_mode and osm_mode:
        reasons["mode"] = f"mismatch: {gtfs_mode} != {osm_mode}"
        return 0.0, reasons
    # Ref: exact match beats substring containment in either direction.
    gtfs_ref = norm_ref(route.short_name or route.route_id)
    osm_ref = norm_ref(feature.ref or "")
    if gtfs_ref and osm_ref:
        if gtfs_ref == osm_ref:
            score += 25
            reasons["ref"] = "exact"
        elif gtfs_ref in osm_ref or osm_ref in gtfs_ref:
            score += 15
            reasons["ref"] = "partial"
    # Name similarity contributes up to 20 points (assuming _ratio returns 0..1 — TODO confirm).
    gtfs_name = norm_text(" ".join(v for v in [route.long_name, route.short_name, route.route_id] if v))
    osm_name = norm_text(" ".join(v for v in [feature.name, feature.ref] if v))
    name_similarity = _ratio(gtfs_name, osm_name)
    score += 20 * name_similarity
    reasons["name_similarity"] = round(name_similarity, 3)
    # Operator similarity is only computed when both sides have operator text.
    gtfs_operator = norm_text(route.operator_name or "")
    osm_operator = norm_text(" ".join(v for v in [feature.operator, feature.network] if v))
    operator_similarity = _ratio(gtfs_operator, osm_operator) if gtfs_operator and osm_operator else 0
    score += 15 * operator_similarity
    reasons["operator_similarity"] = round(operator_similarity, 3)
    # Bounding boxes: overlap scores best; otherwise grade by centre distance.
    gtfs_bbox = (route.min_lon, route.min_lat, route.max_lon, route.max_lat)
    osm_bbox = (feature.min_lon, feature.min_lat, feature.max_lon, feature.max_lat)
    # center_distance stays None when the boxes overlap; later checks rely on that.
    center_distance = None
    if bbox_overlap(gtfs_bbox, osm_bbox):
        score += 14
        reasons["bbox"] = "overlap"
        if gtfs_ref and osm_ref and gtfs_ref == osm_ref and _mode_compatible(gtfs_mode, osm_mode):
            score += 8
            reasons["line_identity"] = "exact_ref_mode_bbox_overlap"
    else:
        center_distance = approx_bbox_center_distance_deg(gtfs_bbox, osm_bbox)
        if center_distance is not None:
            if center_distance < 0.01:
                score += 12
            elif center_distance < 0.03:
                score += 8
            elif center_distance < 0.08:
                score += 4
            elif gtfs_ref and osm_ref and gtfs_ref == osm_ref and center_distance > OSM_SCOPE_NEAR_DISTANCE_DEG:
                # Same ref far away: likely a same-numbered line elsewhere.
                score -= 8
                reasons["spatial_penalty"] = "exact_ref_far_bbox_center"
            reasons["bbox_center_distance_deg"] = round(center_distance, 5)
    geometry_metrics = (
        _geometry_match_metrics_from_profiles(route_geometry_profile, feature_geometry_profile)
        if route_geometry_profile is not None and feature_geometry_profile is not None
        else _geometry_match_metrics(route.geometry_geojson, feature.geometry_geojson)
    )
    if geometry_metrics is not None:
        reasons["geometry"] = geometry_metrics
        # GTFS-on-OSM coverage is weighted much higher than the reverse direction.
        geometry_score = 34 * float(geometry_metrics["gtfs_on_osm_ratio"]) + 8 * float(geometry_metrics["osm_on_gtfs_ratio"])
        if float(geometry_metrics["endpoint_distance_deg"]) < GEOMETRY_PROXIMITY_DEG * 2:
            geometry_score += 6
        if float(geometry_metrics["length_ratio"]) < 0.35 or float(geometry_metrics["length_ratio"]) > 2.8:
            geometry_score -= 8
            reasons["geometry_length"] = "implausible_ratio"
        # The geometry contribution is clamped to the range 0..42.
        score += max(0.0, min(42.0, geometry_score))
    # Extra small boost for same normalized route key.
    if route.route_key and feature.route_key and route.route_key == feature.route_key:
        score += 5
        reasons["route_key"] = "same"
    # Strong identity floors: these only ever raise the score.
    if gtfs_ref and osm_ref and gtfs_ref == osm_ref and _mode_compatible(gtfs_mode, osm_mode):
        if bbox_overlap(gtfs_bbox, osm_bbox):
            score = max(score, 88.0)
            reasons["strong_identity"] = "exact_ref_mode_bbox_overlap"
        elif center_distance is not None and center_distance < 0.02:
            score = max(score, 82.0)
            reasons["strong_identity"] = "exact_ref_mode_near_bbox_center"
    if route.route_key and feature.route_key and route.route_key == feature.route_key and _mode_compatible(gtfs_mode, osm_mode):
        if bbox_overlap(gtfs_bbox, osm_bbox):
            score = max(score, 86.0)
            # setdefault keeps a reason already recorded by the exact-ref rule above.
            reasons.setdefault("strong_identity", "same_route_key_mode_bbox_overlap")
    if geometry_metrics is not None:
        gtfs_on_osm = float(geometry_metrics["gtfs_on_osm_ratio"])
        endpoint_distance = float(geometry_metrics["endpoint_distance_deg"])
        if gtfs_on_osm >= 0.82 and endpoint_distance < GEOMETRY_PROXIMITY_DEG * 3 and _mode_compatible(gtfs_mode, osm_mode):
            if gtfs_ref and osm_ref and gtfs_ref == osm_ref:
                score = max(score, 90.0)
                reasons["strong_identity"] = "exact_ref_mode_geometry_overlap"
            elif gtfs_ref and osm_ref and (gtfs_ref in osm_ref or osm_ref in gtfs_ref):
                score = max(score, 82.0)
                reasons["strong_identity"] = "partial_ref_mode_geometry_overlap"
    # Cap: an exact ref that is far away and has no (or little) geometry overlap.
    if (
        gtfs_ref
        and osm_ref
        and gtfs_ref == osm_ref
        and center_distance is not None
        and center_distance > OSM_SCOPE_NEAR_DISTANCE_DEG
        and not bbox_overlap(gtfs_bbox, osm_bbox)
        and (
            geometry_metrics is None
            or float(geometry_metrics.get("gtfs_on_osm_ratio", 0.0)) < 0.25
        )
    ):
        score = min(score, 58.0)
        reasons["spatial_cap"] = "exact_ref_far_without_geometry_overlap"
    return min(score, 100.0), reasons
def route_match_scope(route: GtfsRoute, osm_scope_bbox: tuple[float | None, float | None, float | None, float | None]) -> str:
    """Classify a GTFS route relative to the bounding box of the loaded OSM routes.

    Returns ``"unknown_scope"`` when either box has a missing coordinate,
    ``"in_osm_scope"`` when the boxes overlap, ``"near_osm_scope"`` when the
    box centres are closer than ``OSM_SCOPE_NEAR_DISTANCE_DEG``, and
    ``"outside_osm_scope"`` otherwise.
    """
    route_bbox = (route.min_lon, route.min_lat, route.max_lon, route.max_lat)
    if any(value is None for value in (*route_bbox, *osm_scope_bbox)):
        return "unknown_scope"
    if bbox_overlap(route_bbox, osm_scope_bbox):
        return "in_osm_scope"
    distance = approx_bbox_center_distance_deg(route_bbox, osm_scope_bbox)
    is_near = distance is not None and distance < OSM_SCOPE_NEAR_DISTANCE_DEG
    return "near_osm_scope" if is_near else "outside_osm_scope"
def _combined_bbox(features: list[OsmFeature]) -> tuple[float | None, float | None, float | None, float | None]:
    """Return the bounding box enclosing all features that have a complete bbox.

    Features with any missing coordinate are ignored; if none remain the result
    is ``(None, None, None, None)``.
    """
    complete_boxes = [
        box
        for box in ((item.min_lon, item.min_lat, item.max_lon, item.max_lat) for item in features)
        if None not in box
    ]
    if not complete_boxes:
        return (None, None, None, None)
    # Transpose the list of boxes into one sequence per coordinate.
    min_lons, min_lats, max_lons, max_lats = zip(*complete_boxes)
    return (
        min(float(value) for value in min_lons),
        min(float(value) for value in min_lats),
        max(float(value) for value in max_lons),
        max(float(value) for value in max_lats),
    )
def _spatially_ranked_candidates(route: GtfsRoute, candidates: list[OsmFeature], limit: int) -> list[OsmFeature]:
    """Return the candidates ordered by spatial rank, truncated to *limit* (at least 1)."""
    ranked = sorted(candidates, key=lambda feature: _spatial_rank(route, feature))
    return ranked[: max(1, limit)]
def _spatial_rank(route: GtfsRoute, feature: OsmFeature) -> tuple[int, float, str]:
    """Build the sort key used to order candidates by spatial closeness.

    The key is ``(bucket, distance, osm_id)``: bucket 0 for overlapping bboxes,
    1 for centres closer than ``OSM_SCOPE_NEAR_DISTANCE_DEG``, 2 for any other
    known distance and 3 when the distance is unknown (then 999.0 is used as
    the distance). ``osm_id`` breaks remaining ties.
    """
    route_bbox = (route.min_lon, route.min_lat, route.max_lon, route.max_lat)
    feature_bbox = (feature.min_lon, feature.min_lat, feature.max_lon, feature.max_lat)
    distance = approx_bbox_center_distance_deg(route_bbox, feature_bbox)
    if bbox_overlap(route_bbox, feature_bbox):
        bucket = 0
    elif distance is None:
        bucket = 3
    elif distance < OSM_SCOPE_NEAR_DISTANCE_DEG:
        bucket = 1
    else:
        bucket = 2
    sortable_distance = 999.0 if distance is None else distance
    return (bucket, sortable_distance, feature.osm_id)
def _expanded_bbox(
    bbox: tuple[float | None, float | None, float | None, float | None],
    padding: float,
) -> tuple[float, float, float, float] | None:
    """Grow *bbox* by *padding* on every side; return None if any coordinate is missing."""
    min_lon, min_lat, max_lon, max_lat = bbox
    if None in (min_lon, min_lat, max_lon, max_lat):
        return None
    lower_lon = float(min_lon) - padding
    lower_lat = float(min_lat) - padding
    upper_lon = float(max_lon) + padding
    upper_lat = float(max_lat) + padding
    return (lower_lon, lower_lat, upper_lon, upper_lat)
def _chunks_int(values: list[int], size: int) -> list[list[int]]:
    """Split *values* into consecutive chunks of at most *size* items.

    *size* is clamped to at least 1. The clamped value is used for both the
    stride and the slice width, so a zero or negative size yields one-element
    chunks instead of empty or inconsistent slices.
    """
    step = max(1, size)
    return [values[start : start + step] for start in range(0, len(values), step)]
def _emit_progress(
    progress_callback: ProgressCallback | None,
    event_type: str,
    message: str,
    progress_current: int | None,
    progress_total: int | None,
    metadata: dict[str, object] | None = None,
) -> None:
    """Forward a progress event to the callback; do nothing when no callback is set."""
    if progress_callback is None:
        return
    progress_callback(event_type, message, progress_current, progress_total, metadata)
def _geometry_match_metrics(route_geometry: str | None, feature_geometry: str | None) -> dict[str, float] | None:
    """Compute geometry match metrics straight from two GeoJSON geometry strings.

    Returns None when either geometry cannot be turned into a profile.
    """
    return _geometry_match_metrics_from_profiles(
        _geometry_profile(route_geometry),
        _geometry_profile(feature_geometry),
    )
def _geometry_profile(geometry_text: str | None) -> _GeometryProfile | None:
    """Parse a GeoJSON geometry string into a reusable profile.

    Returns None for empty or malformed input, for geometries without line
    parts, for zero total length, or when no sample points could be produced.
    """
    if not geometry_text:
        return None
    try:
        parsed_geometry = shape(json.loads(geometry_text))
    except Exception:  # noqa: BLE001 - malformed geometry should not break matching
        return None

    line_parts = _iter_lines(parsed_geometry)
    if not line_parts:
        return None
    total_length = sum(part.length for part in line_parts)
    if total_length == 0:
        return None
    samples = _sample_line_points(line_parts, GEOMETRY_SAMPLE_POINTS)
    if not samples:
        return None
    return _GeometryProfile(geom=parsed_geometry, lines=line_parts, length=total_length, sample_points=samples)
def _geometry_match_metrics_from_profiles(
    route_profile: _GeometryProfile | None, feature_profile: _GeometryProfile | None
) -> dict[str, float] | None:
    """Compare two geometry profiles and return rounded overlap metrics.

    The result holds the share of sampled GTFS points near the OSM geometry and
    vice versa, the endpoint distance of the GTFS line to the OSM geometry, and
    the GTFS/OSM length ratio. Returns None when either profile is missing.
    """
    if route_profile is None or feature_profile is None:
        return None

    metrics: dict[str, float] = {}
    metrics["gtfs_on_osm_ratio"] = round(
        _near_point_ratio(route_profile.sample_points, feature_profile.geom, GEOMETRY_PROXIMITY_DEG), 3
    )
    metrics["osm_on_gtfs_ratio"] = round(
        _near_point_ratio(feature_profile.sample_points, route_profile.geom, GEOMETRY_PROXIMITY_DEG), 3
    )
    metrics["endpoint_distance_deg"] = round(_endpoint_distance(route_profile.lines, feature_profile.geom), 6)
    if feature_profile.length:
        length_ratio = route_profile.length / feature_profile.length
    else:
        length_ratio = 0.0
    metrics["length_ratio"] = round(length_ratio, 3)
    return metrics
def _iter_lines(geom) -> list[LineString]:
    """Return the line parts of a geometry.

    A LineString is returned as a single-element list; a MultiLineString yields
    its LineString members with non-zero length; anything else yields ``[]``.
    """
    if isinstance(geom, MultiLineString):
        return [part for part in geom.geoms if isinstance(part, LineString) and part.length > 0]
    return [geom] if isinstance(geom, LineString) else []
def _sample_line_points(lines: list[LineString], count: int) -> list[Point]:
    """Sample *count* points evenly spaced along the concatenated *lines*.

    The lines are treated as one path in list order; the first sample sits at
    the start and (for ``count > 1``) the last at the end. Returns an empty
    list when the total length is zero.
    """
    total_length = sum(line.length for line in lines)
    if total_length == 0:
        return []

    final_line = lines[-1]
    intervals = max(1, count - 1)
    samples = []
    for step in range(count):
        target = total_length * (step / intervals)
        offset = 0.0
        for line in lines:
            segment_end = offset + line.length
            # The last line absorbs any floating-point overshoot of the target.
            if target <= segment_end or line is final_line:
                local_distance = max(0.0, min(line.length, target - offset))
                samples.append(line.interpolate(local_distance))
                break
            offset = segment_end
    return samples
def _near_point_ratio(points: list[Point], geom, max_distance: float) -> float:
    """Return the fraction of *points* lying within *max_distance* of *geom* (0.0 for no points)."""
    if not points:
        return 0.0
    near_count = 0
    for point in points:
        if geom.distance(point) <= max_distance:
            near_count += 1
    return near_count / len(points)
def _endpoint_distance(gtfs_lines: list[LineString], osm_geom) -> float:
    """Sum the distances from both ends of the longest GTFS line to the OSM geometry.

    Returns 999.0 when the longest line has fewer than two coordinates.
    """
    longest_line = max(gtfs_lines, key=lambda line: line.length)
    coordinates = list(longest_line.coords)
    if len(coordinates) < 2:
        return 999.0
    start, end = coordinates[0], coordinates[-1]
    return osm_geom.distance(Point(start)) + osm_geom.distance(Point(end))
def _manual_match_rules(session: Session) -> list[_ManualMatchRule]:
    """Load and parse all active accept/reject match rules, newest (highest id) first.

    Rules whose selector or action is not valid JSON, or is valid JSON but not
    an object (e.g. ``null`` or a list), are skipped so that a single malformed
    rule cannot abort the whole matching run.

    Selector layout handled here:
    - the route selector is ``selector["gtfs"]`` when that is an object,
      otherwise the selector itself;
    - the OSM selector is ``action["osm"]``, else ``selector["osm"]``, else a
      legacy ``{"osm_feature_id": ...}`` built from the selector.
    """
    rules = session.scalars(
        select(MatchRule)
        .where(MatchRule.active.is_(True), MatchRule.rule_type.in_(["accept_match", "reject_match"]))
        .order_by(MatchRule.id.desc())
    ).all()
    parsed: list[_ManualMatchRule] = []
    for rule in rules:
        try:
            selector = json.loads(rule.selector_json or "{}")
            action = json.loads(rule.action_json or "{}")
        except json.JSONDecodeError:
            continue
        if not isinstance(selector, dict) or not isinstance(action, dict):
            # Well-formed JSON of the wrong shape is just as unusable as invalid JSON.
            continue
        gtfs_selector = selector.get("gtfs")
        route_selector = gtfs_selector if isinstance(gtfs_selector, dict) else selector
        osm_selector = action.get("osm") if isinstance(action.get("osm"), dict) else selector.get("osm")
        if not isinstance(osm_selector, dict) and selector.get("osm_feature_id") is not None:
            osm_selector = {"osm_feature_id": selector.get("osm_feature_id")}
        default_status = "accepted" if rule.rule_type == "accept_match" else "rejected"
        status = str(action.get("status") or default_status)
        parsed.append(
            _ManualMatchRule(
                id=rule.id,
                rule_type=rule.rule_type,
                route_selector=route_selector,
                osm_selector=osm_selector if isinstance(osm_selector, dict) else None,
                status=status,
            )
        )
    return parsed
def _accepted_rule_for_route(
    rules: list[_ManualMatchRule], route: GtfsRoute, route_source_id: int | None
) -> _ManualMatchRule | None:
    """Return the first accepted ``accept_match`` rule whose route selector matches *route*, or None."""
    return next(
        (
            rule
            for rule in rules
            if rule.rule_type == "accept_match"
            and rule.status == "accepted"
            and _route_matches_selector(route, route_source_id, rule.route_selector)
        ),
        None,
    )
def _feature_for_rule(
    features: list[OsmFeature], dataset_source_ids: dict[int, int], rule: _ManualMatchRule
) -> OsmFeature | None:
    """Return the first of *features* matching the rule's OSM selector, or None.

    A rule without an OSM selector never matches.
    """
    selector = rule.osm_selector
    if not selector:
        return None
    return next(
        (
            feature
            for feature in features
            if _feature_matches_selector(feature, dataset_source_ids.get(feature.dataset_id), selector)
        ),
        None,
    )
def _feature_for_rule_from_storage(
session: Session,
osm_dataset_ids: list[int],
dataset_source_ids: dict[int, int],
rule: _ManualMatchRule,
) -> OsmFeature | None:
if not rule.osm_selector:
return None
selector = rule.osm_selector
legacy_id = _safe_int(selector.get("osm_feature_id"))
if legacy_id is not None:
feature = session.get(OsmFeature, legacy_id)
if feature is not None and _feature_matches_selector(feature, dataset_source_ids.get(feature.dataset_id), selector):
return feature
scoped_dataset_ids = list(osm_dataset_ids)
expected_source = selector.get("source_id")
if expected_source is not None:
expected_source_id = _safe_int(expected_source)
if expected_source_id is not None:
scoped_dataset_ids = [
dataset_id
for dataset_id in scoped_dataset_ids
if dataset_source_ids.get(dataset_id) == expected_source_id
]
dataset_id = _safe_int(selector.get("dataset_id"))
if dataset_id is not None:
scoped_dataset_ids = [value for value in scoped_dataset_ids if value == dataset_id]
if not scoped_dataset_ids:
return None
features: list[OsmFeature] = []
osm_type = selector.get("osm_type")
osm_id = selector.get("osm_id")
if osm_type and osm_id:
features = query_osm_features(
session,
scoped_dataset_ids,
kinds=["route"],
osm_type=str(osm_type),
osm_id=str(osm_id),
limit=10,
)
if not features:
route_key = selector.get("route_key")
if route_key:
features = query_osm_features(session, scoped_dataset_ids, kinds=["route"], route_key=str(route_key))
if not features:
ref = norm_ref(selector.get("ref"))
if ref:
features = query_osm_features(session, scoped_dataset_ids, kinds=["route"], route_key=ref)
for feature in features:
if _feature_matches_selector(feature, dataset_source_ids.get(feature.dataset_id), selector):
return feature
return None
def _is_rejected_pair(
rules: list[_ManualMatchRule],
route: GtfsRoute,
route_source_id: int | None,
feature: OsmFeature,
feature_source_id: int | None,
) -> bool:
for rule in rules:
if rule.rule_type != "reject_match":
continue
if not _route_matches_selector(route, route_source_id, rule.route_selector):
continue
if rule.osm_selector and _feature_matches_selector(feature, feature_source_id, rule.osm_selector):
return True
return False
def _route_matches_selector(route: GtfsRoute, source_id: int | None, selector: dict[str, object]) -> bool:
legacy_id = selector.get("gtfs_route_id")
if legacy_id is not None and _safe_int(legacy_id) == route.id:
return True
expected_source = selector.get("source_id")
if expected_source is not None and source_id is not None and _safe_int(expected_source) != source_id:
return False
route_id = selector.get("route_id")
if route_id and str(route_id) == route.route_id:
return True
route_key = selector.get("route_key")
if route_key and route.route_key and str(route_key) == route.route_key:
return True
ref = norm_ref(selector.get("ref"))
mode = selector.get("mode")
if ref and ref == norm_ref(route.short_name or route.route_id):
return not mode or _mode_compatible(str(mode), route.mode or "")
return False
def _feature_matches_selector(feature: OsmFeature, source_id: int | None, selector: dict[str, object]) -> bool:
legacy_id = selector.get("osm_feature_id")
if legacy_id is not None and _safe_int(legacy_id) == feature.id:
return True
expected_source = selector.get("source_id")
if expected_source is not None and source_id is not None and _safe_int(expected_source) != source_id:
return False
osm_type = selector.get("osm_type")
osm_id = selector.get("osm_id")
if osm_type and osm_id and str(osm_type) == feature.osm_type and str(osm_id) == feature.osm_id:
return True
route_key = selector.get("route_key")
if route_key and feature.route_key and str(route_key) == feature.route_key:
return True
ref = norm_ref(selector.get("ref"))
mode = selector.get("mode")
if ref and ref == norm_ref(feature.ref or ""):
return not mode or _mode_compatible(str(mode), feature.mode or "")
return False
def _safe_int(value: object) -> int | None:
try:
return int(value) # type: ignore[arg-type]
except (TypeError, ValueError):
return None
def _mode_compatible(gtfs_mode: str, osm_mode: str) -> bool:
if not gtfs_mode or not osm_mode:
return True
if gtfs_mode == osm_mode:
return True
return osm_mode in MODE_GROUPS.get(gtfs_mode, {gtfs_mode}) or gtfs_mode in MODE_GROUPS.get(osm_mode, {osm_mode})
def _ratio(a: str, b: str) -> float:
if not a or not b:
return 0.0
if a == b:
return 1.0
token_ratio = _token_similarity(a, b)
if a in b or b in a:
token_ratio = max(token_ratio, 0.82)
return token_ratio
def _token_similarity(a: str, b: str) -> float:
left = set(a.split())
right = set(b.split())
if not left or not right:
return 0.0
return len(left & right) / len(left | right)
def _status_from_score(score: float) -> str:
if score >= 85:
return "matched"
if score >= 65:
return "probable"
if score >= 40:
return "weak"
return "missing"

View File

@@ -0,0 +1,508 @@
from __future__ import annotations
import json
import math
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Callable
import osmium
from sqlalchemy import delete, func, select, text
from sqlalchemy.orm import Session
from app.config import settings
from app.models import Dataset, OsmAddress
from app.pipeline.routing_layer import active_routing_dataset
from app.spatial import analyze_postgresql_tables, refresh_postgis_geometries
# Progress hook signature, as invoked by ``_emit``:
# (event_type, message, progress_current, progress_total, metadata).
ProgressCallback = Callable[[str, str, int | None, int | None, dict[str, object] | None], None]
# Written into the dataset metadata under "address_index"/"version" and
# compared by ``address_index_status`` to flag an index as stale.
ADDRESS_INDEX_VERSION = "osm_addresses_v2_nodes_ways_area_geometry"
# OSM tag keys that ``_address_row`` copies into each row's ``tags_json``.
ADDRESS_TAGS = {
    "addr:housenumber",
    "addr:housename",
    "addr:street",
    "addr:place",
    "addr:postcode",
    "addr:city",
    "addr:country",
    "addr:unit",
    "addr:suburb",
    "addr:district",
    "addr:municipality",
    "entrance",
    "name",
}
@dataclass
class AddressIndexResult:
    """Summary of one address-index build, as returned to callers and progress hooks."""

    dataset_id: int
    input_path: str
    addresses: int
    node_addresses: int
    way_addresses: int
    skipped: int
    version: str = ADDRESS_INDEX_VERSION

    def as_dict(self) -> dict[str, object]:
        """Return the summary as a plain dict, with ``version`` as the first key."""
        ordered_fields = (
            "version",
            "dataset_id",
            "input_path",
            "addresses",
            "node_addresses",
            "way_addresses",
            "skipped",
        )
        return {field_name: getattr(self, field_name) for field_name in ordered_fields}
def rebuild_address_index(
    session: Session,
    *,
    dataset_id: int | None = None,
    input_path: str | Path | None = None,
    reset: bool = True,
    batch_size: int = 20_000,
    progress_callback: ProgressCallback | None = None,
) -> dict[str, object]:
    """Import address nodes and ways from an OSM PBF into ``osm_addresses``.

    The dataset is ``dataset_id`` when given, else the active routing dataset.
    The PBF is ``input_path`` when given, else the dataset's ``local_path``.
    With ``reset`` the dataset's existing rows are cleared first. On PostgreSQL
    the lookup indexes are dropped before the bulk import; they are recreated
    by ``finalize_address_index``, whose result dict is returned.

    Raises:
        ValueError: no dataset could be resolved.
        FileNotFoundError: the PBF path does not exist.
    """
    if dataset_id is None:
        dataset = active_routing_dataset(session)
    else:
        dataset = session.get(Dataset, dataset_id)
    if dataset is None:
        raise ValueError("No OSM PBF dataset is available for address indexing.")

    path = Path(input_path or dataset.local_path)
    if not path.exists():
        raise FileNotFoundError(f"Address index PBF does not exist: {path}")

    def announce(event_type: str, message: str, details: dict[str, object]) -> None:
        # None of the setup events carry progress counters.
        _emit(progress_callback, event_type, message, None, None, details)

    if reset:
        announce("address_index_clear_started", "Clearing existing OSM address index.", {"dataset_id": dataset.id})
        _clear_address_rows(session, dataset_id=int(dataset.id))
        session.commit()

    if settings.is_postgresql_database:
        announce(
            "address_index_indexes_dropped",
            "Dropping address lookup indexes before bulk import.",
            {"dataset_id": dataset.id},
        )
        _drop_address_indexes(session)
        session.commit()

    announce(
        "address_index_import_started",
        "Importing OSM address nodes and ways.",
        {"dataset_id": dataset.id, "path": str(path)},
    )
    handler = _AddressHandler(
        session=session,
        dataset_id=dataset.id,
        batch_size=batch_size,
        progress_callback=progress_callback,
    )
    # Use the FileProcessor API when this osmium build exposes it; otherwise
    # fall back to the SimpleHandler callbacks.
    if hasattr(osmium, "FileProcessor"):
        _apply_address_file_processor(handler, path)
    else:
        handler.apply_file(str(path), locations=True)
    handler.flush()  # write the final partial batch

    return finalize_address_index(
        session,
        dataset_id=dataset.id,
        input_path=path,
        node_addresses=handler.node_address_count,
        way_addresses=handler.way_address_count,
        skipped=handler.skipped_count,
        progress_callback=progress_callback,
    )
def finalize_address_index(
    session: Session,
    *,
    dataset_id: int,
    input_path: str | Path,
    node_addresses: int = 0,
    way_addresses: int = 0,
    skipped: int = 0,
    progress_callback: ProgressCallback | None = None,
) -> dict[str, object]:
    """Finish an address import: refresh geometries, rebuild indexes, record metadata.

    The stored address count is read back from the database and written,
    together with the supplied counters, under the dataset metadata key
    ``"address_index"``. Returns the ``AddressIndexResult`` as a dict.

    Raises:
        ValueError: the dataset does not exist.
    """
    dataset = session.get(Dataset, dataset_id)
    if dataset is None:
        raise ValueError("Address index dataset does not exist.")

    if settings.is_postgresql_database:
        _emit(
            progress_callback,
            "address_index_geometry_started",
            "Refreshing address point geometries.",
            None,
            None,
            {"dataset_id": dataset.id},
        )
        refresh_postgis_geometries(session, dataset_id=dataset.id, tables=["osm_addresses"], only_missing=False)
        session.commit()

    _emit(
        progress_callback,
        "address_index_indexes_started",
        "Rebuilding address lookup indexes.",
        None,
        None,
        {"dataset_id": dataset.id},
    )
    _create_address_indexes(session)
    session.commit()
    analyze_postgresql_tables(session, ["osm_addresses"])

    count_query = select(func.count()).select_from(OsmAddress).where(OsmAddress.dataset_id == dataset.id)
    address_count = int(session.scalar(count_query) or 0)

    metadata = _metadata(dataset)
    metadata["address_index"] = {
        "version": ADDRESS_INDEX_VERSION,
        "addresses": address_count,
        "node_addresses": int(node_addresses),
        "way_addresses": int(way_addresses),
        "skipped": int(skipped),
        "input_path": str(input_path),
    }
    dataset.metadata_json = json.dumps(metadata, indent=2)
    session.commit()

    summary = AddressIndexResult(
        dataset_id=dataset.id,
        input_path=str(input_path),
        addresses=address_count,
        node_addresses=node_addresses,
        way_addresses=way_addresses,
        skipped=skipped,
    )
    result = summary.as_dict()
    _emit(
        progress_callback,
        "address_index_import_completed",
        "OSM address index import completed.",
        address_count,
        address_count,
        result,
    )
    return result
def _clear_address_rows(session: Session, *, dataset_id: int) -> None:
    """Remove all address rows belonging to ``dataset_id``.

    On PostgreSQL, when no other dataset has rows in the table, the whole
    table is truncated (resetting its identity sequence) instead of deleting
    row by row. Does not commit.
    """
    target_id = int(dataset_id)
    if settings.is_postgresql_database:
        other_datasets_query = select(func.count(func.distinct(OsmAddress.dataset_id))).where(
            OsmAddress.dataset_id != target_id
        )
        other_dataset_count = int(session.scalar(other_datasets_query) or 0)
        if other_dataset_count == 0:
            session.execute(text("TRUNCATE TABLE osm_addresses RESTART IDENTITY"))
            return
    session.execute(delete(OsmAddress).where(OsmAddress.dataset_id == target_id))
def address_index_status(session: Session) -> dict[str, object]:
    """Report the address-index state of the active routing dataset.

    The address count comes from the dataset's ``"address_index"`` metadata
    when it holds a usable non-zero number; otherwise the rows are counted.
    ``stale`` is true when addresses exist but were built by a different
    ``ADDRESS_INDEX_VERSION``.
    """
    dataset = active_routing_dataset(session)
    index_meta: dict[str, object] = {}
    address_count = 0

    if dataset is not None:
        index_meta = _metadata(dataset).get("address_index") or {}
        if isinstance(index_meta, dict):
            try:
                address_count = int(index_meta.get("addresses") or 0)
            except (TypeError, ValueError):
                address_count = 0
        if not address_count:
            # Metadata missing or zero: fall back to counting the rows.
            count_query = select(func.count()).select_from(OsmAddress).where(OsmAddress.dataset_id == dataset.id)
            address_count = int(session.scalar(count_query) or 0)

    meta_is_mapping = isinstance(index_meta, dict)
    installed_version = index_meta.get("version") if meta_is_mapping else None
    return {
        "dataset_id": int(dataset.id) if dataset is not None else None,
        "addresses": address_count,
        "available": address_count > 0,
        "version": installed_version,
        "current_version": ADDRESS_INDEX_VERSION,
        "stale": bool(address_count and installed_version != ADDRESS_INDEX_VERSION),
        "input_path": index_meta.get("input_path") if meta_is_mapping else None,
    }
class _AddressHandler(osmium.SimpleHandler):
    """Collects address rows from OSM nodes and ways and inserts them in batches.

    Works both through the ``SimpleHandler`` callbacks (``node`` / ``way``) and
    through ``process_object`` for objects yielded by a file processor.
    """

    def __init__(
        self,
        *,
        session: Session,
        dataset_id: int,
        batch_size: int,
        progress_callback: ProgressCallback | None,
    ) -> None:
        super().__init__()
        self.session = session
        self.dataset_id = int(dataset_id)
        # Never flush in batches smaller than 1000 rows.
        self.batch_size = max(1_000, int(batch_size))
        self.progress_callback = progress_callback
        self.rows: list[dict[str, object]] = []
        self.address_count = 0
        self.node_address_count = 0
        self.way_address_count = 0
        self.skipped_count = 0
        self.processed_count = 0

    def node(self, node) -> None:
        """SimpleHandler callback for nodes."""
        self.process_node(node)

    def way(self, way) -> None:
        """SimpleHandler callback for ways."""
        self.process_way(way)

    def process_object(self, obj) -> None:
        """Dispatch an object by shape: ``nodes`` means way, ``location`` means node."""
        if hasattr(obj, "nodes"):
            self.process_way(obj)
        elif hasattr(obj, "location"):
            self.process_node(obj)

    def process_node(self, node) -> None:
        """Queue an address row for a tagged node with a valid location."""
        self.processed_count += 1
        tags = {tag.k: tag.v for tag in node.tags}
        if not _has_address(tags):
            return
        if not node.location.valid():
            self.skipped_count += 1
            return
        lon = float(node.location.lon)
        lat = float(node.location.lat)
        row = _address_row(
            dataset_id=self.dataset_id,
            osm_type="node",
            osm_id=str(node.id),
            tags=tags,
            lon=lon,
            lat=lat,
            bounds=(lon, lat, lon, lat),  # a point's bounding box is the point itself
            geometry_geojson=None,
        )
        if row is None:
            self.skipped_count += 1
            return
        self.rows.append(row)
        self.node_address_count += 1
        self._after_address()

    def process_way(self, way) -> None:
        """Queue an address row for a tagged way, located at its centroid."""
        self.processed_count += 1
        tags = {tag.k: tag.v for tag in way.tags}
        if not _has_address(tags):
            return
        # Only member nodes with a resolved location contribute to the geometry.
        coords = [
            (float(member.location.lon), float(member.location.lat))
            for member in way.nodes
            if member.location.valid()
        ]
        if not coords:
            self.skipped_count += 1
            return
        lons = [point[0] for point in coords]
        lats = [point[1] for point in coords]
        lon, lat = _centroid(coords)
        row = _address_row(
            dataset_id=self.dataset_id,
            osm_type="way",
            osm_id=str(way.id),
            tags=tags,
            lon=lon,
            lat=lat,
            bounds=(min(lons), min(lats), max(lons), max(lats)),
            geometry_geojson=_address_area_geometry_geojson(coords, closed=_way_is_closed(way)),
        )
        if row is None:
            self.skipped_count += 1
            return
        self.rows.append(row)
        self.way_address_count += 1
        self._after_address()

    def _after_address(self) -> None:
        """Count a queued address, flush a full batch, and report every 50,000 rows."""
        self.address_count += 1
        if len(self.rows) >= self.batch_size:
            self.flush()
        if self.address_count % 50_000 != 0:
            return
        _emit(
            self.progress_callback,
            "address_index_import_batch",
            f"Imported {self.address_count:,} OSM addresses.",
            self.address_count,
            None,
            {"processed": self.processed_count, "skipped": self.skipped_count},
        )

    def flush(self) -> None:
        """Insert and commit the queued rows; does nothing when the queue is empty."""
        if not self.rows:
            return
        self.session.bulk_insert_mappings(OsmAddress, self.rows)
        self.session.commit()
        self.rows = []
def _apply_address_file_processor(handler: _AddressHandler, path: Path) -> None:
    """Stream nodes and ways from ``path`` into ``handler`` through ``osmium.FileProcessor``.

    Node locations are enabled, and only objects carrying ``addr:housenumber``
    or ``addr:housename`` pass the key filter.
    """
    processor = osmium.FileProcessor(str(path), osmium.osm.NODE | osmium.osm.WAY)
    processor = processor.with_locations()
    processor = processor.with_filter(osmium.filter.KeyFilter("addr:housenumber", "addr:housename"))
    for osm_object in processor:
        handler.process_object(osm_object)
def _has_address(tags: dict[str, str]) -> bool:
    """Tell whether ``tags`` describe an indexable address.

    Requires a house number or house name, plus at least one non-blank
    locator among street, place, city and postcode.
    """
    house_label = _clean(tags.get("addr:housenumber") or tags.get("addr:housename"))
    if not house_label:
        return False
    for locator_key in ("addr:street", "addr:place", "addr:city", "addr:postcode"):
        if _clean(tags.get(locator_key)):
            return True
    return False
def _address_row(
    *,
    dataset_id: int,
    osm_type: str,
    osm_id: str,
    tags: dict[str, str],
    lon: float,
    lat: float,
    bounds: tuple[float, float, float, float],
    geometry_geojson: str | None = None,
) -> dict[str, object] | None:
    """Build the ``osm_addresses`` row mapping for one tagged OSM object.

    ``bounds`` is ``(min_lon, min_lat, max_lon, max_lat)``. Returns ``None``
    when no display name can be derived from the tags.
    """

    def tag(key: str) -> str | None:
        return _clean(tags.get(key))

    housenumber = _clean(tags.get("addr:housenumber") or tags.get("addr:housename"))
    street = tag("addr:street")
    place = tag("addr:place")
    postcode = tag("addr:postcode")
    city = _clean(tags.get("addr:city") or tags.get("addr:municipality"))
    country = tag("addr:country")
    unit = tag("addr:unit")
    name = tag("name")

    display_name = _display_name(
        housenumber=housenumber,
        street=street,
        place=place,
        postcode=postcode,
        city=city,
        name=name,
    )
    if not display_name:
        return None

    search_text = _search_text(display_name, housenumber, street, place, postcode, city, country, unit, name)
    # Keep only the tags in ADDRESS_TAGS, in sorted key order.
    kept_tags = {key: tags[key] for key in sorted(ADDRESS_TAGS) if key in tags}
    tags_json = json.dumps(kept_tags, separators=(",", ":")) if kept_tags else None
    min_lon, min_lat, max_lon, max_lat = bounds

    row: dict[str, object] = {
        "dataset_id": dataset_id,
        "osm_type": osm_type,
        "osm_id": osm_id,
        "housenumber": housenumber,
        "street": street,
        "place": place,
        "postcode": postcode,
        "city": city,
        "country": country,
        "unit": unit,
        "name": name,
        "display_name": display_name,
        "search_text": search_text,
        "lon": lon,
        "lat": lat,
        "min_lon": min_lon,
        "min_lat": min_lat,
        "max_lon": max_lon,
        "max_lat": max_lat,
        "geometry_geojson": geometry_geojson,
        "tags_json": tags_json,
    }
    return row
def _address_area_geometry_geojson(coords: list[tuple[float, float]], *, closed: bool | None = None) -> str | None:
if closed is False:
return None
if len(coords) < 3:
return None
ring_coords = list(coords)
first = ring_coords[0]
last = ring_coords[-1]
already_closed = abs(first[0] - last[0]) <= 1e-12 and abs(first[1] - last[1]) <= 1e-12
if not already_closed:
if closed is not True:
return None
ring_coords.append(first)
if len(ring_coords) < 4:
return None
ring = [[float(lon), float(lat)] for lon, lat in ring_coords]
if len({(round(lon, 12), round(lat, 12)) for lon, lat in ring_coords[:-1]}) < 3:
return None
return json.dumps({"type": "Polygon", "coordinates": [ring]}, separators=(",", ":"))
def _way_is_closed(way) -> bool:
try:
nodes = way.nodes
return len(nodes) >= 3 and nodes[0].ref == nodes[-1].ref
except (AttributeError, IndexError, TypeError):
return False
def _display_name(
*,
housenumber: str | None,
street: str | None,
place: str | None,
postcode: str | None,
city: str | None,
name: str | None,
) -> str | None:
road = street or place or name
if road and housenumber:
first = f"{road} {housenumber}"
else:
first = road or housenumber
locality = " ".join(part for part in [postcode, city] if part)
if first and locality:
return f"{first}, {locality}"
return first or locality
def _search_text(*parts: str | None) -> str:
return re.sub(r"\s+", " ", " ".join(part.casefold() for part in parts if part)).strip()
def _clean(value: object) -> str | None:
cleaned = re.sub(r"\s+", " ", str(value or "")).strip()
return cleaned or None
def _centroid(coords: list[tuple[float, float]]) -> tuple[float, float]:
if len(coords) >= 4 and coords[0] == coords[-1]:
area = 0.0
cx = 0.0
cy = 0.0
for (x1, y1), (x2, y2) in zip(coords, coords[1:]):
cross = x1 * y2 - x2 * y1
area += cross
cx += (x1 + x2) * cross
cy += (y1 + y2) * cross
if abs(area) > 1e-18:
factor = 1 / (3 * area)
return cx * factor, cy * factor
return (
math.fsum(coord[0] for coord in coords) / len(coords),
math.fsum(coord[1] for coord in coords) / len(coords),
)
def _drop_address_indexes(session: Session) -> None:
    """Drop every ``osm_addresses`` lookup index that exists. Does not commit."""
    index_names = (
        "ix_osm_addresses_dataset_city_street",
        "ix_osm_addresses_dataset_postcode",
        "ix_osm_addresses_bbox",
        "ix_osm_addresses_geom_gist",
        "ix_osm_addresses_area_geom_gist",
        "ix_osm_addresses_search_trgm",
        "ix_osm_addresses_display_trgm",
        "ix_osm_addresses_street_key_house",
        "ix_osm_addresses_street_key_trgm",
    )
    # The names are fixed literals, so interpolating them into SQL is safe.
    for index_name in index_names:
        session.execute(text(f"DROP INDEX IF EXISTS {index_name}"))
def _create_address_indexes(session: Session) -> None:
    """Create the ``osm_addresses`` lookup indexes that are missing. Does not commit.

    The three plain B-tree indexes are always issued. The GiST, trigram and
    expression indexes are added only on PostgreSQL.
    """
    portable_statements = [
        "CREATE INDEX IF NOT EXISTS ix_osm_addresses_dataset_city_street ON osm_addresses (dataset_id, city, street, housenumber)",
        "CREATE INDEX IF NOT EXISTS ix_osm_addresses_dataset_postcode ON osm_addresses (dataset_id, postcode)",
        "CREATE INDEX IF NOT EXISTS ix_osm_addresses_bbox ON osm_addresses (dataset_id, min_lon, max_lon, min_lat, max_lat)",
    ]
    postgresql_statements = [
        "CREATE INDEX IF NOT EXISTS ix_osm_addresses_geom_gist ON osm_addresses USING GIST (geom)",
        "CREATE INDEX IF NOT EXISTS ix_osm_addresses_area_geom_gist ON osm_addresses USING GIST (area_geom)",
        "CREATE INDEX IF NOT EXISTS ix_osm_addresses_search_trgm ON osm_addresses USING GIN (LOWER(COALESCE(search_text, '')) gin_trgm_ops)",
        "CREATE INDEX IF NOT EXISTS ix_osm_addresses_display_trgm ON osm_addresses USING GIN (LOWER(COALESCE(display_name, '')) gin_trgm_ops)",
        "CREATE INDEX IF NOT EXISTS ix_osm_addresses_street_key_house ON osm_addresses (dataset_id, REPLACE(LOWER(COALESCE(NULLIF(street, ''), NULLIF(place, ''), '')), 'ß', 'ss'), housenumber)",
        "CREATE INDEX IF NOT EXISTS ix_osm_addresses_street_key_trgm ON osm_addresses USING GIN (REPLACE(LOWER(COALESCE(NULLIF(street, ''), NULLIF(place, ''), '')), 'ß', 'ss') gin_trgm_ops)",
    ]
    if settings.is_postgresql_database:
        statements = portable_statements + postgresql_statements
    else:
        statements = portable_statements
    for ddl in statements:
        session.execute(text(ddl))
def _metadata(dataset: Dataset) -> dict[str, object]:
try:
value = json.loads(dataset.metadata_json or "{}")
except json.JSONDecodeError:
return {}
return value if isinstance(value, dict) else {}
def _emit(
progress_callback: ProgressCallback | None,
event_type: str,
message: str,
progress_current: int | None,
progress_total: int | None,
metadata: dict[str, object] | None = None,
) -> None:
if progress_callback is not None:
progress_callback(event_type, message, progress_current, progress_total, metadata)

100
app/pipeline/osm_diff.py Normal file
View File

@@ -0,0 +1,100 @@
from __future__ import annotations
import json
from urllib.parse import urlparse
from sqlalchemy import select
from sqlalchemy.orm import Session
from app.config import settings
from app.models import Dataset, Source
from app.pipeline.download import materialize_source
from app.pipeline.osm_pbf import _raw_format
from app.pipeline.osm_replication import fetch_replication_state
from app.pipeline.utils import sha256_file
def run_osm_diff_source(session: Session, source: Source) -> Dataset:
    """Record an OSM change file as a raw, inactive update artifact.

    URLs that look like an ``-updates`` replication directory are handled by
    recording the replication state instead. Otherwise the file is
    materialised and hashed, and an existing ``osm_diff_raw`` dataset with the
    same hash is reused when there is one.

    The diff is only stored here. Applying it to an OSM base extract is a
    separate step, and the file is not treated as a complete route layer.
    """
    if _looks_like_update_directory(source.url):
        return _commit_update_directory_state(session, source)

    raw_path = materialize_source(source)
    raw_hash = sha256_file(raw_path)

    same_content = (
        select(Dataset)
        .where(Dataset.source_id == source.id, Dataset.kind == "osm_diff_raw", Dataset.sha256 == raw_hash)
        .order_by(Dataset.id.desc())
    )
    already_committed = session.scalar(same_content)
    if already_committed is not None:
        return already_committed

    metadata = {
        "stage": "raw_osm_diff",
        "raw_format": _raw_format(raw_path),
        "source_url": source.url,
    }
    dataset = Dataset(
        source_id=source.id,
        kind="osm_diff_raw",
        local_path=str(raw_path),
        sha256=raw_hash,
        is_active=False,
        status="committed",
        metadata_json=json.dumps(metadata, indent=2),
    )
    session.add(dataset)
    session.flush()
    return dataset
def _commit_update_directory_state(session: Session, source: Source) -> Dataset:
    """Fetch the replication state of an updates directory and record it as a dataset.

    The state is written to ``state_<sequence>.txt`` in the source's cache
    directory as sorted ``key=value`` lines. An existing ``osm_diff_state``
    dataset with the same file hash is reused when there is one.
    """
    state = fetch_replication_state(source.url, timeout=settings.osm_diff_state_timeout_seconds)

    source_dir = settings.data_dir / "sources" / f"source_{source.id}"
    source_dir.mkdir(parents=True, exist_ok=True)
    state_path = source_dir / f"state_{state.sequence_number}.txt"
    state_lines = [f"{key}={value}" for key, value in sorted(state.raw.items())]
    state_path.write_text("\n".join(state_lines) + "\n", encoding="utf-8")
    state_hash = sha256_file(state_path)

    same_state = (
        select(Dataset)
        .where(Dataset.source_id == source.id, Dataset.kind == "osm_diff_state", Dataset.sha256 == state_hash)
        .order_by(Dataset.id.desc())
    )
    already_committed = session.scalar(same_state)
    if already_committed is not None:
        return already_committed

    metadata = {
        "stage": "osm_diff_state",
        "updates_url": source.url,
        "sequence_number": state.sequence_number,
        "timestamp": state.timestamp,
        "state": state.raw,
    }
    dataset = Dataset(
        source_id=source.id,
        kind="osm_diff_state",
        local_path=str(state_path),
        sha256=state_hash,
        is_active=False,
        status="committed",
        metadata_json=json.dumps(metadata, indent=2),
    )
    session.add(dataset)
    session.flush()
    return dataset
def _looks_like_update_directory(url: str) -> bool:
lower_path = urlparse(url).path.lower()
return lower_path.endswith("-updates") or lower_path.endswith("-updates/")

248
app/pipeline/osm_geojson.py Normal file
View File

@@ -0,0 +1,248 @@
from __future__ import annotations
import json
from pathlib import Path
from typing import Any
from sqlalchemy import select
from sqlalchemy.orm import Session
from app.config import settings
from app.models import Dataset, OsmFeature, Source
from app.osm_classification import infer_osm_route_scope
from app.osm_storage import (
OSM_STORAGE_METADATA_KEY,
OSM_STORAGE_MAIN,
OSM_STORAGE_SIDECAR_FEATURES,
create_osm_sidecar,
dedupe_osm_feature_rows,
effective_osm_feature_storage,
)
from app.pipeline.download import materialize_source
from app.pipeline.utils import first_nonempty, geometry_json_and_bbox, norm_ref, norm_text, sha256_file
from app.spatial import analyze_postgresql_tables, refresh_postgis_geometries
# Values of the "mode" / "route" / "route_master" properties that
# ``_infer_mode`` accepts as transport modes ("railway" is reported as
# "train"); ``_infer_kind`` also uses the set to recognise route features.
ROUTE_MODES = {
    "train",
    "railway",
    "light_rail",
    "subway",
    "tram",
    "bus",
    "trolleybus",
    "coach",
    "ferry",
    "monorail",
    "funicular",
    "aerialway",
}
def run_osm_geojson_source(session: Session, source: Source) -> Dataset:
    """Materialise an OSM GeoJSON source and import it unless it is already current.

    If the source already has an active, fully imported ``osm_geojson``
    dataset with the same content hash, that dataset is returned and nothing
    is imported.
    """
    local_path = materialize_source(source)
    source_hash = sha256_file(local_path)

    up_to_date = (
        select(Dataset)
        .where(
            Dataset.source_id == source.id,
            Dataset.kind == "osm_geojson",
            Dataset.sha256 == source_hash,
            Dataset.is_active.is_(True),
            Dataset.status == "imported",
        )
        .order_by(Dataset.id.desc())
    )
    current = session.scalar(up_to_date)
    if current is None:
        return import_osm_geojson(session=session, source=source, path=local_path, source_hash=source_hash)
    return current
def import_osm_geojson(
    session: Session,
    source: Source,
    path: Path,
    source_hash: str | None = None,
    *,
    storage_mode: str | None = None,
) -> Dataset:
    """Import a GeoJSON file as the new active ``osm_geojson`` dataset of ``source``.

    All existing datasets of the source are deactivated. A new dataset is
    created in status ``"importing"``, its features are stored through
    ``prepare_osm_geojson_storage``, and it is then marked ``"imported"``.
    The source is set to status ``"ok"`` with its last error cleared.
    Flushes but does not commit.

    ``source_hash`` is the file's SHA-256. When omitted it is computed here.
    """
    for previous in source.datasets:
        previous.is_active = False

    # Hash the file at most once; the value is used for both the dataset row
    # and the storage metadata.
    resolved_hash = source_hash or sha256_file(path)

    dataset = Dataset(
        source_id=source.id,
        kind="osm_geojson",
        local_path=str(path),
        sha256=resolved_hash,
        is_active=True,
        status="importing",
    )
    session.add(dataset)
    session.flush()  # assigns dataset.id, which the feature rows need

    storage_metadata = prepare_osm_geojson_storage(
        session=session,
        dataset=dataset,
        path=path,
        source_hash=resolved_hash,
        storage_mode=storage_mode,
    )
    dataset.metadata_json = json.dumps(storage_metadata, indent=2)
    dataset.status = "imported"
    source.status = "ok"
    source.last_error = None
    session.flush()
    return dataset
def prepare_osm_geojson_storage(
    *,
    session: Session,
    dataset: Dataset,
    path: Path,
    source_hash: str | None = None,
    storage_mode: str | None = None,
) -> dict[str, object]:
    """Parse the GeoJSON file and store its features in the selected storage.

    In sidecar mode the rows go to an OSM sidecar. In main mode they are
    inserted into ``osm_features``, followed by a geometry refresh and table
    analysis. Returns the metadata to record on the dataset: the feature
    count plus the storage descriptor.

    Raises:
        ValueError: the effective storage mode is not supported.
    """
    document = json.loads(path.read_text(encoding="utf-8"))
    feature_rows = [
        _feature_row(dataset.id, position, feature)
        for position, feature in enumerate(_as_features(document))
    ]

    storage = effective_osm_feature_storage(storage_mode)
    if storage not in {OSM_STORAGE_MAIN, OSM_STORAGE_SIDECAR_FEATURES}:
        raise ValueError(f"Unsupported OSM feature storage mode: {storage}")

    feature_total = len(feature_rows)
    if storage == OSM_STORAGE_SIDECAR_FEATURES:
        sidecar_info = create_osm_sidecar(dataset, feature_rows, source_hash=source_hash or dataset.sha256)
        return {"features": feature_total, OSM_STORAGE_METADATA_KEY: sidecar_info}

    _insert_main_features(session, feature_rows)
    session.flush()
    refresh_postgis_geometries(session, dataset_id=dataset.id, tables=["osm_features"])
    analyze_postgresql_tables(session, ["osm_features"])
    return {"features": feature_total, OSM_STORAGE_METADATA_KEY: {"mode": OSM_STORAGE_MAIN}}
def _insert_main_features(session: Session, feature_rows: list[dict[str, object]]) -> None:
    """De-duplicate the rows and bulk-save them as ``OsmFeature`` objects, 5000 at a time."""
    column_names = (
        "dataset_id",
        "osm_type",
        "osm_id",
        "kind",
        "mode",
        "route_scope",
        "name",
        "ref",
        "operator",
        "network",
        "geometry_geojson",
        "min_lon",
        "min_lat",
        "max_lon",
        "max_lat",
        "tags_json",
        "route_key",
        "operator_key",
    )
    deduped_rows, _duplicate_count = dedupe_osm_feature_rows(feature_rows)
    batch: list[OsmFeature] = []
    for row in deduped_rows:
        batch.append(OsmFeature(**{column: row[column] for column in column_names}))
        if len(batch) >= 5000:
            session.bulk_save_objects(batch)
            batch.clear()
    if batch:
        session.bulk_save_objects(batch)
def _feature_row(dataset_id: int, idx: int, feature: dict[str, Any]) -> dict[str, object]:
    """Turn one GeoJSON feature into an ``osm_features`` row mapping.

    Identity, naming and operator fields are each taken from the first
    non-empty of several alternative property names. ``idx`` only supplies a
    fallback OSM id (``feature_<idx>``) when the properties carry none.
    """
    props = feature.get("properties") or {}
    geometry_text, bbox = geometry_json_and_bbox(feature.get("geometry"))

    osm_type = str(first_nonempty(props.get("osm_type"), props.get("@type"), props.get("type"), "feature"))
    osm_id = str(first_nonempty(props.get("osm_id"), props.get("@id"), props.get("id"), f"feature_{idx}"))

    mode = _infer_mode(props)
    kind = _infer_kind(props, mode)

    name = first_nonempty(props.get("name"), props.get("official_name")) or None
    ref = first_nonempty(props.get("ref"), props.get("route_ref"), props.get("line")) or None
    operator = first_nonempty(props.get("operator"), props.get("agency"), props.get("brand")) or None
    network = first_nonempty(props.get("network"), props.get("network:short")) or None

    route_scope = infer_osm_route_scope(mode=mode, ref=ref, name=name, network=network, tags=props)
    # Matching keys: prefer the ref, then the name, then the OSM id.
    route_key = norm_ref(ref) or norm_text(name) or norm_ref(osm_id)
    operator_key = norm_text(operator or network or "")

    min_lon, min_lat, max_lon, max_lat = bbox[0], bbox[1], bbox[2], bbox[3]
    row: dict[str, object] = {
        "dataset_id": dataset_id,
        "osm_type": osm_type,
        "osm_id": osm_id,
        "kind": kind,
        "mode": mode,
        "route_scope": route_scope,
        "name": name,
        "ref": ref,
        "operator": operator,
        "network": network,
        "geometry_geojson": geometry_text,
        "min_lon": min_lon,
        "min_lat": min_lat,
        "max_lon": max_lon,
        "max_lat": max_lat,
        "tags_json": json.dumps(props, separators=(",", ":")),
        "route_key": route_key,
        "operator_key": operator_key,
    }
    return row
def _as_features(data: Any) -> list[dict[str, Any]]:
if isinstance(data, dict) and data.get("type") == "FeatureCollection":
return [f for f in data.get("features", []) if isinstance(f, dict)]
if isinstance(data, dict) and data.get("type") == "Feature":
return [data]
if isinstance(data, list):
return [f for f in data if isinstance(f, dict)]
raise ValueError("OSM source must be GeoJSON FeatureCollection, Feature, or list of Features")
def _infer_mode(props: dict[str, Any]) -> str | None:
    """Derive the transport mode of a feature from its properties, or ``None``.

    An explicit ``mode`` / ``route`` / ``route_master`` value found in
    ``ROUTE_MODES`` wins ("railway" is reported as "train"). Otherwise the mode
    is inferred from stop and station tagging.
    """
    for key in ("mode", "route", "route_master"):
        declared = str(props.get(key) or "").strip()
        if declared in ROUTE_MODES:
            return "train" if declared == "railway" else declared

    railway_modes = {
        "station": "train",
        "halt": "train",
        "tram_stop": "tram",
        "subway_entrance": "subway",
    }
    railway = str(props.get("railway") or "").strip()
    if railway in railway_modes:
        return railway_modes[railway]

    amenity = str(props.get("amenity") or "")
    if str(props.get("highway") or "") == "bus_stop" or amenity == "bus_station":
        return "bus"
    if amenity == "ferry_terminal":
        return "ferry"
    if str(props.get("aerialway") or "") == "station":
        return "aerialway"
    return None
def _infer_kind(props: dict[str, Any], mode: str | None) -> str:
explicit_kind = str(props.get("kind") or "").strip()
if explicit_kind in {"route", "stop", "station", "terminal", "infra", "feature"}:
return explicit_kind
if str(props.get("type") or "") in {"route", "route_master"} or str(props.get("route") or "") in ROUTE_MODES:
return "route"
if str(props.get("amenity") or "") == "ferry_terminal":
return "terminal"
if str(props.get("amenity") or "") == "bus_station":
return "terminal"
if str(props.get("railway") or "") in {"station", "halt"}:
return "station"
if str(props.get("aerialway") or "") == "station":
return "station"
if str(props.get("public_transport") or "") in {"platform", "stop_position", "station"}:
return "stop"
if str(props.get("highway") or "") == "bus_stop":
return "stop"
if mode:
return "infra"
return "feature"

View File

@@ -0,0 +1,456 @@
from __future__ import annotations
from datetime import datetime, timezone
import json
from pathlib import Path
import sqlite3
from typing import Callable
from sqlalchemy import func, select, text
from sqlalchemy.orm import Session
from app.models import Dataset, OsmFeature
from app.osm_classification import OSM_ROUTE_SCOPE_CLASSIFIER_VERSION, infer_osm_route_scope_from_tags
from app.osm_storage import (
dataset_metadata,
drop_osm_sidecar_route_scope_indexes,
ensure_osm_sidecar_schema,
features_are_sidecar,
rebuild_osm_sidecar_indexes,
sidecar_path,
writable_sidecar_connection,
)
from app.pipeline.state import (
STAGE_BUILD_INDEXES,
STAGE_LABEL_FEATURES,
dependency_hash,
finish_pipeline_run,
latest_completed_run,
start_pipeline_run,
)
# The labeling version is tied to the route-scope classifier version; it is
# reported in the relabeling result and progress events.
OSM_LABEL_FEATURES_VERSION = OSM_ROUTE_SCOPE_CLASSIFIER_VERSION
# NOTE(review): the three constants below are not used in the visible part of
# this module; presumably they name the main-table route-scope index and the
# thresholds for rebuilding indexes -- confirm at their use sites.
MAIN_ROUTE_SCOPE_INDEX = "ix_osm_features_scope_bbox"
MAIN_INDEX_REBUILD_THRESHOLD = 10_000
SIDECAR_INDEX_REBUILD_THRESHOLD = 10_000
# Progress hook signature: (event_type, message, progress_current,
# progress_total, metadata) -- names assumed from the call sites below.
ProgressCallback = Callable[[str, str, int | None, int | None, dict[str, object] | None], None]
def relabel_osm_features(
    session: Session,
    *,
    dataset_id: int | None = None,
    chunk_size: int = 5000,
    force: bool = False,
    rebuild_indexes: bool = True,
    progress_callback: ProgressCallback | None = None,
    job_id: int | None = None,
) -> dict[str, object]:
    """Relabel every targeted OSM dataset and return an aggregate summary.

    Datasets are chosen by ``_target_datasets`` and processed one at a time
    through ``relabel_osm_dataset``. The summary carries totals for processed
    and changed features, the number of skipped and missing-sidecar datasets,
    the number of index rebuilds, and the per-dataset results in order.
    Progress events are emitted at the start, after each dataset and at the
    end.
    """
    datasets = _target_datasets(session, dataset_id)
    dataset_count = len(datasets)
    per_dataset: list[dict[str, object]] = []
    summary: dict[str, object] = {
        "version": OSM_LABEL_FEATURES_VERSION,
        "datasets": dataset_count,
        "processed": 0,
        "changed": 0,
        "skipped": 0,
        "missing": 0,
        "index_rebuilds": 0,
        "dataset_results": per_dataset,
    }
    totals = {"processed": 0, "changed": 0, "skipped": 0, "missing": 0, "index_rebuilds": 0}
    _emit_progress(
        progress_callback,
        "osm_labeling_started",
        f"Relabeling {dataset_count} OSM dataset(s).",
        0,
        dataset_count,
        {"dataset_id": dataset_id, "force": force, "version": OSM_LABEL_FEATURES_VERSION},
    )
    for position, dataset in enumerate(datasets, start=1):
        outcome = relabel_osm_dataset(
            session,
            dataset,
            chunk_size=chunk_size,
            force=force,
            rebuild_indexes=rebuild_indexes,
            progress_callback=progress_callback,
            job_id=job_id,
        )
        status = outcome.get("status")
        totals["processed"] += int(outcome.get("processed", 0) or 0)
        totals["changed"] += int(outcome.get("changed", 0) or 0)
        totals["skipped"] += 1 if status == "skipped" else 0
        totals["missing"] += 1 if status == "missing_sidecar" else 0
        totals["index_rebuilds"] += int(outcome.get("index_rebuilds", 0) or 0)
        per_dataset.append(outcome)
        _emit_progress(
            progress_callback,
            "osm_labeling_dataset_completed",
            f"Relabeled {position}/{dataset_count} OSM dataset(s).",
            position,
            dataset_count,
            outcome,
        )
    # The keys already exist, so this keeps the summary's key order intact.
    summary.update(totals)
    _emit_progress(
        progress_callback,
        "osm_labeling_completed",
        "OSM feature relabeling completed.",
        dataset_count,
        dataset_count,
        summary,
    )
    return summary
def relabel_osm_dataset(
    session: Session,
    dataset: Dataset,
    *,
    chunk_size: int = 5000,
    force: bool = False,
    rebuild_indexes: bool = True,
    progress_callback: ProgressCallback | None = None,
    job_id: int | None = None,
) -> dict[str, object]:
    """Relabel the route scope of one OSM dataset and record the pipeline run.

    The dataset is skipped, and no run is started, when ``force`` is false
    and its label metadata plus a completed run already match the current
    dependency hash. Otherwise a ``label_features`` run is started and
    committed, the features are relabeled in sidecar or main-table storage,
    and the outcome is stamped on the dataset and on the run.

    Returns:
        A result dict whose ``status`` is ``"skipped"``, ``"completed"`` or
        ``"missing_sidecar"``.

    Raises:
        Exception: anything other than ``FileNotFoundError`` raised while
            relabeling. The run is marked failed before it is re-raised.
    """
    dependency = _label_dependency(dataset)
    dependency_hash_value = dependency_hash(dependency)
    if not force and _dataset_label_is_current(session, dataset, dependency_hash_value):
        return {
            "dataset_id": dataset.id,
            "source_id": dataset.source_id,
            "status": "skipped",
            "reason": "label_features dependency is current",
            "dependency_hash": dependency_hash_value,
            "version": OSM_LABEL_FEATURES_VERSION,
            "processed": 0,
            "changed": 0,
            "index_rebuilds": 0,
        }
    run = start_pipeline_run(
        session,
        stage=STAGE_LABEL_FEATURES,
        version=OSM_LABEL_FEATURES_VERSION,
        dependency_hash_value=dependency_hash_value,
        source_id=dataset.source_id,
        dataset_id=dataset.id,
        job_id=job_id,
        inputs=dependency,
    )
    session.commit()
    try:
        if features_are_sidecar(dataset):
            counts = _relabel_sidecar_dataset(
                dataset,
                chunk_size=chunk_size,
                rebuild_indexes=rebuild_indexes,
                progress_callback=progress_callback,
            )
        else:
            counts = _relabel_main_dataset(
                session,
                dataset,
                chunk_size=chunk_size,
                rebuild_indexes=rebuild_indexes,
                progress_callback=progress_callback,
            )
        output = {
            "dataset_id": dataset.id,
            "source_id": dataset.source_id,
            "status": "completed",
            "dependency_hash": dependency_hash_value,
            "version": OSM_LABEL_FEATURES_VERSION,
            **counts,
        }
        _stamp_dataset_metadata(session, dataset, dependency_hash_value, output)
        finish_pipeline_run(session, run, outputs=output)
        session.commit()
        return output
    except FileNotFoundError as exc:
        # Discard any half-applied work so the failure record is written on
        # a clean transaction.
        session.rollback()
        output = {
            "dataset_id": dataset.id,
            "source_id": dataset.source_id,
            "status": "missing_sidecar",
            "dependency_hash": dependency_hash_value,
            "version": OSM_LABEL_FEATURES_VERSION,
            "processed": 0,
            "changed": 0,
            "index_rebuilds": 0,
            "error": str(exc),
        }
        finish_pipeline_run(session, run, status="failed", outputs=output, error=str(exc))
        session.commit()
        return output
    except Exception as exc:
        # After a database error the session refuses further work until it
        # is rolled back. Without this, the bookkeeping below would raise
        # and hide the original exception.
        session.rollback()
        finish_pipeline_run(session, run, status="failed", error=str(exc))
        session.commit()
        raise
def _target_datasets(session: Session, dataset_id: int | None) -> list[Dataset]:
    """Select the imported ``osm_geojson`` datasets to relabel.

    Without ``dataset_id`` every active dataset is returned. With it, only
    that dataset is returned, whether or not it is active. Results are
    ordered by source id, then dataset id.
    """
    if dataset_id is None:
        selector = Dataset.is_active.is_(True)
    else:
        selector = Dataset.id == dataset_id
    query = (
        select(Dataset)
        .where(Dataset.kind == "osm_geojson", Dataset.status == "imported")
        .where(selector)
        .order_by(Dataset.source_id, Dataset.id)
    )
    return session.scalars(query).all()
def _dataset_label_is_current(session: Session, dataset: Dataset, dependency_hash_value: str) -> bool:
    """Tell whether the dataset was already labeled for this dependency hash.

    Two things must hold: the dataset's ``label_features`` metadata carries
    the current version and hash, and a completed pipeline run exists for
    the same stage, version, hash, source and dataset.
    """
    label_info = dataset_metadata(dataset).get("label_features")
    if not isinstance(label_info, dict):
        return False
    if label_info.get("version") != OSM_LABEL_FEATURES_VERSION:
        return False
    if label_info.get("dependency_hash") != dependency_hash_value:
        return False
    completed_run = latest_completed_run(
        session,
        stage=STAGE_LABEL_FEATURES,
        version=OSM_LABEL_FEATURES_VERSION,
        dependency_hash_value=dependency_hash_value,
        source_id=dataset.source_id,
        dataset_id=dataset.id,
    )
    return completed_run is not None
def _relabel_sidecar_dataset(
    dataset: Dataset,
    *,
    chunk_size: int,
    rebuild_indexes: bool,
    progress_callback: ProgressCallback | None,
) -> dict[str, int | str]:
    """Recompute ``route_scope`` for every feature in the dataset's sidecar.

    Features are read in id order, ``chunk_size`` rows at a time (at least
    one), and only rows whose normalized scope differs are updated. Each
    batch is committed and reported through ``progress_callback``.

    Returns a dict with the storage kind, processed and changed row counts
    and the number of index rebuilds (0 or 1).

    Raises:
        FileNotFoundError: if the sidecar path is unknown or does not exist.
    """
    path = sidecar_path(dataset)
    if path is None or not path.exists():
        raise FileNotFoundError(f"OSM sidecar does not exist: {path}")
    with writable_sidecar_connection(dataset) as connection:
        ensure_osm_sidecar_schema(connection)
        total = int(connection.execute("SELECT COUNT(*) FROM osm_features").fetchone()[0] or 0)
        # Indexes are only dropped and rebuilt for sidecars at or above the threshold.
        should_rebuild_index = rebuild_indexes and total >= SIDECAR_INDEX_REBUILD_THRESHOLD
        if should_rebuild_index:
            drop_osm_sidecar_route_scope_indexes(connection)
            connection.commit()
        processed = 0
        changed = 0
        # Keyset pagination: each batch starts after the last id already seen.
        last_id = 0
        try:
            while True:
                # Rows are read by column name below, so the connection presumably
                # uses a name-addressable row factory; verify in writable_sidecar_connection.
                rows = connection.execute(
                    """
SELECT id, mode, ref, name, network, tags_json, route_scope
FROM osm_features
WHERE id > ?
ORDER BY id
LIMIT ?
""",
                    (last_id, max(1, int(chunk_size))),
                ).fetchall()
                if not rows:
                    break
                # (new_scope, id) pairs for rows whose stored scope is stale.
                updates: list[tuple[str | None, int]] = []
                for row in rows:
                    last_id = int(row["id"])
                    new_scope = _classified_scope(row["mode"], row["ref"], row["name"], row["network"], row["tags_json"])
                    if _normalize_scope(row["route_scope"]) != new_scope:
                        updates.append((new_scope, last_id))
                if updates:
                    connection.executemany("UPDATE osm_features SET route_scope = ? WHERE id = ?", updates)
                processed += len(rows)
                changed += len(updates)
                connection.commit()
                _emit_progress(
                    progress_callback,
                    "osm_labeling_batch",
                    f"Relabeled {processed}/{total} OSM sidecar features.",
                    processed,
                    total,
                    {"dataset_id": dataset.id, "changed": changed, "storage": "sidecar"},
                )
        finally:
            # The dropped indexes are rebuilt whether or not the loop finished.
            index_rebuilds = 0
            if should_rebuild_index:
                rebuild_osm_sidecar_indexes(connection)
                connection.commit()
                index_rebuilds = 1
                _record_sidecar_index_build(connection, dataset, path)
        # Store the labeling summary inside the sidecar itself.
        _record_sidecar_label(connection, dataset, processed=processed, changed=changed)
        connection.commit()
    return {"storage": "sidecar", "processed": processed, "changed": changed, "index_rebuilds": index_rebuilds}
def _relabel_main_dataset(
    session: Session,
    dataset: Dataset,
    *,
    chunk_size: int,
    rebuild_indexes: bool,
    progress_callback: ProgressCallback | None,
) -> dict[str, int | str]:
    """Recompute ``route_scope`` for the dataset's rows in the main table.

    Features are read in id order, ``chunk_size`` rows at a time (at least
    one), and only rows whose normalized scope differs are updated. Each
    batch is committed and reported through ``progress_callback``. For
    datasets at or above ``MAIN_INDEX_REBUILD_THRESHOLD`` the route-scope
    index is dropped first and recreated afterwards, even if relabeling
    fails.

    Returns a dict with the storage kind, processed and changed row counts
    and the number of index rebuilds (0 or 1).
    """
    total = int(session.scalar(select(func.count()).select_from(OsmFeature).where(OsmFeature.dataset_id == dataset.id)) or 0)
    should_rebuild_index = rebuild_indexes and total >= MAIN_INDEX_REBUILD_THRESHOLD
    index_rebuilds = 0
    if should_rebuild_index:
        session.execute(text(f"DROP INDEX IF EXISTS {MAIN_ROUTE_SCOPE_INDEX}"))
        session.commit()
    processed = 0
    changed = 0
    # Keyset pagination: each batch starts after the last id already seen.
    last_id = 0
    batch_limit = max(1, int(chunk_size))
    try:
        while True:
            rows = session.scalars(
                select(OsmFeature)
                .where(OsmFeature.dataset_id == dataset.id, OsmFeature.id > last_id)
                .order_by(OsmFeature.id)
                .limit(batch_limit)
            ).all()
            if not rows:
                break
            updates: list[dict[str, object]] = []
            for feature in rows:
                last_id = int(feature.id)
                new_scope = _classified_scope(feature.mode, feature.ref, feature.name, feature.network, feature.tags_json)
                if _normalize_scope(feature.route_scope) != new_scope:
                    updates.append({"id": feature.id, "route_scope": new_scope})
            if updates:
                session.bulk_update_mappings(OsmFeature, updates)
            processed += len(rows)
            changed += len(updates)
            session.commit()
            _emit_progress(
                progress_callback,
                "osm_labeling_batch",
                f"Relabeled {processed}/{total} main-table OSM features.",
                processed,
                total,
                {"dataset_id": dataset.id, "changed": changed, "storage": "main"},
            )
    except Exception:
        # Leave the session usable for the index rebuild in ``finally``.
        # Otherwise a failed flush would make that rebuild raise and hide
        # the original error.
        session.rollback()
        raise
    finally:
        if should_rebuild_index:
            # Use the same constant as the DROP above so the two statements
            # can never refer to different indexes.
            session.execute(
                text(
                    f"CREATE INDEX IF NOT EXISTS {MAIN_ROUTE_SCOPE_INDEX} "
                    "ON osm_features (dataset_id, kind, mode, route_scope, min_lon, max_lon, min_lat, max_lat)"
                )
            )
            session.commit()
            index_rebuilds = 1
            _record_main_index_build(session, dataset)
    return {"storage": "main", "processed": processed, "changed": changed, "index_rebuilds": index_rebuilds}
def _classified_scope(mode: object, ref: object, name: object, network: object, tags_json: object) -> str | None:
    """Classify a feature's route scope from its stored column values.

    Each value is passed to the classifier as ``str`` unless it is ``None``.
    The classifier's answer is normalized so blank results become ``None``.
    """
    as_text = [None if part is None else str(part) for part in (mode, ref, name, network, tags_json)]
    return _normalize_scope(infer_osm_route_scope_from_tags(*as_text))
def _normalize_scope(value: object) -> str | None:
text_value = str(value or "").strip()
return text_value or None
def _label_dependency(dataset: Dataset) -> dict[str, object]:
    """Describe the inputs the labeling result depends on.

    The description covers the dataset's identity and hash, its
    ``osm_storage`` metadata, the classifier version, and the sidecar path
    together with whether that path currently exists. ``sidecar`` is
    ``None`` when the dataset has no sidecar path.
    """
    metadata = dataset_metadata(dataset)
    storage = metadata.get("osm_storage") if isinstance(metadata, dict) else None
    sidecar = sidecar_path(dataset)
    fingerprint: dict[str, object] | None
    if sidecar is None:
        fingerprint = None
    else:
        location = Path(sidecar)
        state_key = "exists" if location.exists() else "missing"
        fingerprint = {"path": str(location), state_key: True}
    return {
        "dataset_id": dataset.id,
        "source_id": dataset.source_id,
        "kind": dataset.kind,
        "dataset_sha256": dataset.sha256,
        "storage": storage,
        "sidecar": fingerprint,
        "classifier_version": OSM_LABEL_FEATURES_VERSION,
    }
def _stamp_dataset_metadata(session: Session, dataset: Dataset, dependency_hash_value: str, output: dict[str, object]) -> None:
    """Write the labeling outcome into the dataset's ``label_features`` metadata.

    The dataset is looked up again through the session. Nothing happens if
    it can no longer be found. The change is flushed, not committed.
    """
    current = session.get(Dataset, dataset.id)
    if current is None:
        return
    stamp = {
        "stage": STAGE_LABEL_FEATURES,
        "version": OSM_LABEL_FEATURES_VERSION,
        "dependency_hash": dependency_hash_value,
        "labeled_at": datetime.now(timezone.utc).isoformat(),
        "processed": output.get("processed", 0),
        "changed": output.get("changed", 0),
        "storage": output.get("storage"),
    }
    merged = dataset_metadata(current)
    merged["label_features"] = stamp
    current.metadata_json = json.dumps(merged, indent=2)
    session.flush()
def _record_sidecar_label(connection: sqlite3.Connection, dataset: Dataset, *, processed: int, changed: int) -> None:
    """Store the labeling summary in the sidecar's ``pipeline_metadata`` table.

    The table is created if absent and the ``label_features`` row is
    replaced with a compact JSON document. The caller commits.
    """
    connection.execute("CREATE TABLE IF NOT EXISTS pipeline_metadata (key TEXT PRIMARY KEY, value TEXT NOT NULL)")
    record = {
        "stage": STAGE_LABEL_FEATURES,
        "version": OSM_LABEL_FEATURES_VERSION,
        "dataset_id": dataset.id,
        "processed": processed,
        "changed": changed,
        "updated_at": datetime.now(timezone.utc).isoformat(),
    }
    encoded = json.dumps(record, sort_keys=True, separators=(",", ":"))
    connection.execute(
        "INSERT OR REPLACE INTO pipeline_metadata (key, value) VALUES (?, ?)",
        ("label_features", encoded),
    )
def _record_sidecar_index_build(connection: sqlite3.Connection, dataset: Dataset, path: Path) -> None:
    """Store an index-rebuild marker in the sidecar's ``pipeline_metadata`` table.

    The table is created if absent and the ``build_indexes:route_scope`` row
    is replaced with a compact JSON document. The caller commits.
    """
    connection.execute("CREATE TABLE IF NOT EXISTS pipeline_metadata (key TEXT PRIMARY KEY, value TEXT NOT NULL)")
    record = {
        "stage": STAGE_BUILD_INDEXES,
        "version": "osm_sidecar_indexes_v1",
        "dataset_id": dataset.id,
        "path": str(path),
        "updated_at": datetime.now(timezone.utc).isoformat(),
    }
    encoded = json.dumps(record, sort_keys=True, separators=(",", ":"))
    connection.execute(
        "INSERT OR REPLACE INTO pipeline_metadata (key, value) VALUES (?, ?)",
        ("build_indexes:route_scope", encoded),
    )
def _record_main_index_build(session: Session, dataset: Dataset) -> None:
    """Record a main-table index rebuild as a finished ``build_indexes`` run.

    The run is started and finished back to back, then committed.
    """
    index_version = "osm_main_indexes_v1"
    inputs = {
        "dataset_id": dataset.id,
        "index": MAIN_ROUTE_SCOPE_INDEX,
        "version": index_version,
    }
    index_run = start_pipeline_run(
        session,
        stage=STAGE_BUILD_INDEXES,
        version=index_version,
        dependency_hash_value=dependency_hash(inputs),
        source_id=dataset.source_id,
        dataset_id=dataset.id,
        inputs=inputs,
    )
    finish_pipeline_run(session, index_run, outputs={"index": MAIN_ROUTE_SCOPE_INDEX})
    session.commit()
def _emit_progress(
callback: ProgressCallback | None,
event_type: str,
message: str,
current: int | None,
total: int | None,
metadata: dict[str, object] | None,
) -> None:
if callback is not None:
callback(event_type, message, current, total, metadata)

1581
app/pipeline/osm_pbf.py Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,105 @@
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
import subprocess
from urllib.parse import urljoin, urlparse
import requests
@dataclass(frozen=True)
class ReplicationState:
    """Immutable view of a parsed replication ``state.txt`` document."""

    # Parsed integer value of the ``sequenceNumber`` entry.
    sequence_number: int
    # Value of the ``timestamp`` entry, or None when the entry is absent.
    timestamp: str | None
    # Every key/value pair found in the document, after unescaping.
    raw: dict[str, str]
def fetch_replication_state(updates_url: str, *, timeout: float = 30) -> ReplicationState:
    """Download and parse ``state.txt`` from a replication directory.

    Raises whatever ``requests`` raises for transport errors or non-success
    HTTP statuses, and ``ValueError`` when the document cannot be parsed.
    """
    reply = requests.get(_state_url(updates_url), timeout=timeout)
    reply.raise_for_status()
    return parse_replication_state_text(reply.text)
def parse_replication_state_text(text: str) -> ReplicationState:
    """Parse the ``key=value`` lines of a replication state document.

    Blank lines, ``#`` comments and lines without ``=`` are ignored. Keys
    and values are stripped and values are unescaped.

    Raises:
        ValueError: if ``sequenceNumber`` is missing or not an integer.
    """
    values: dict[str, str] = {}
    for raw_line in text.splitlines():
        entry = raw_line.strip()
        if not entry or entry.startswith("#"):
            continue
        key, separator, value = entry.partition("=")
        if not separator:
            continue
        values[key.strip()] = _unescape_state_value(value.strip())
    if "sequenceNumber" not in values:
        raise ValueError("replication state is missing sequenceNumber")
    sequence = values["sequenceNumber"]
    try:
        sequence_number = int(sequence)
    except ValueError as exc:
        raise ValueError(f"invalid replication sequenceNumber: {sequence}") from exc
    return ReplicationState(sequence_number=sequence_number, timestamp=values.get("timestamp"), raw=values)
def diff_url_for_sequence(updates_url: str, sequence_number: int) -> str:
    """Build the ``.osc.gz`` URL for a replication sequence number.

    The number is zero-padded to at least nine digits, rounded up to a
    multiple of three, and split into three-digit path segments, for
    example ``004/000/123.osc.gz``.
    """
    digits = str(sequence_number)
    group_count = (len(digits) + 2) // 3
    padded = digits.zfill(max(9, group_count * 3))
    relative_path = "/".join(padded[start : start + 3] for start in range(0, len(padded), 3))
    return urljoin(_directory_url(updates_url), relative_path + ".osc.gz")
def download_diff(updates_url: str, sequence_number: int, output_dir: Path, *, timeout: float = 120) -> Path:
    """Download one replication diff into ``output_dir`` and return its path.

    The cached file is named after the full zero-padded sequence number. The
    remote file name holds only the last three digits (``123.osc.gz``), so
    using it as the cache key made sequences 1000 apart share one file and
    returned a stale diff from the cache.

    An already cached file is returned without any network access. The
    download goes to a temporary file first and is moved into place only
    once complete.

    Raises whatever ``requests`` raises for transport errors or non-success
    HTTP statuses.
    """
    url = diff_url_for_sequence(updates_url, sequence_number)
    output_path = output_dir / f"{sequence_number:09d}.osc.gz"
    if output_path.exists():
        return output_path
    output_dir.mkdir(parents=True, exist_ok=True)
    temp_path = output_dir / f"{sequence_number}.download"
    try:
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            with temp_path.open("wb") as handle:
                for chunk in response.iter_content(chunk_size=1024 * 1024):
                    if chunk:
                        handle.write(chunk)
        temp_path.replace(output_path)
    except Exception:
        # Do not leave a partial download behind.
        temp_path.unlink(missing_ok=True)
        raise
    return output_path
def apply_osm_changes(base_path: Path, diff_paths: list[Path], output_path: Path, host_tool_path: Path) -> subprocess.CompletedProcess[str]:
    """Run ``osmium apply-changes`` through the host tool wrapper.

    The output's parent directory is created if needed and an existing
    output file is overwritten.

    Raises:
        ValueError: if ``diff_paths`` is empty.
        subprocess.CalledProcessError: if the command exits non-zero.
    """
    if not diff_paths:
        raise ValueError("no OSM change files supplied")
    output_path.parent.mkdir(parents=True, exist_ok=True)
    command = [
        str(host_tool_path),
        "osmium",
        "apply-changes",
        "--output",
        str(output_path),
        "--overwrite",
        str(base_path),
    ]
    command.extend(str(diff) for diff in diff_paths)
    return subprocess.run(command, check=True, capture_output=True, text=True)
def _state_url(updates_url: str) -> str:
    """Return the URL of ``state.txt`` inside the replication directory."""
    directory = _directory_url(updates_url)
    return urljoin(directory, "state.txt")
def _directory_url(url: str) -> str:
return url if url.endswith("/") else f"{url}/"
def _unescape_state_value(value: str) -> str:
return (
value.replace("\\:", ":")
.replace("\\=", "=")
.replace("\\ ", " ")
.replace("\\\\", "\\")
)

1903
app/pipeline/route_layer.py Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,473 @@
from __future__ import annotations
import json
import math
from dataclasses import dataclass
from pathlib import Path
from typing import Callable
import osmium
from sqlalchemy import delete, func, select, text
from sqlalchemy.dialects.postgresql import insert as postgresql_insert
from sqlalchemy.orm import Session
from app.config import settings
from app.models import Dataset, RoutingEdge, RoutingNode
from app.spatial import analyze_postgresql_tables, refresh_postgis_geometries
# Progress hook, called as (event_type, message, progress_current, progress_total, metadata).
ProgressCallback = Callable[[str, str, int | None, int | None, dict[str, object] | None], None]
# Version tag written into import results and the dataset's routing metadata.
ROUTING_LAYER_VERSION = "routing_layer_v2_osm_highway_segments_service_tags"
# ``highway`` values that may be drivable, subject to access tags.
DRIVE_HIGHWAYS = {
    "motorway",
    "motorway_link",
    "trunk",
    "trunk_link",
    "primary",
    "primary_link",
    "secondary",
    "secondary_link",
    "tertiary",
    "tertiary_link",
    "unclassified",
    "residential",
    "living_street",
    "service",
    "road",
    "track",
}
# ``highway`` values that may be walkable, subject to access tags.
WALK_HIGHWAYS = {
    "pedestrian",
    "footway",
    "path",
    "steps",
    "cycleway",
    "bridleway",
    "living_street",
    "residential",
    "service",
    "track",
    "unclassified",
    "tertiary",
    "tertiary_link",
    "secondary",
    "secondary_link",
    "primary",
    "primary_link",
    "road",
}
# ``highway`` values that are never imported.
EXCLUDED_HIGHWAYS = {"construction", "proposed", "abandoned", "platform", "raceway"}
# Access-tag values treated as "not allowed" and "allowed" respectively.
NO_VALUES = {"no", "private", "agricultural", "forestry", "delivery", "customers"}
YES_VALUES = {"yes", "designated", "permissive", "destination"}
# ``oneway`` values meaning travel along, or against, the way's node order.
ONEWAY_FORWARD = {"yes", "true", "1"}
ONEWAY_REVERSE = {"-1", "reverse"}
# Fallback driving speed in km/h per ``highway`` value, used when the way has
# no usable ``maxspeed`` tag.
DEFAULT_DRIVE_SPEED_KMH = {
    "motorway": 110,
    "motorway_link": 50,
    "trunk": 90,
    "trunk_link": 45,
    "primary": 70,
    "primary_link": 40,
    "secondary": 60,
    "secondary_link": 35,
    "tertiary": 50,
    "tertiary_link": 30,
    "unclassified": 40,
    "residential": 30,
    "living_street": 10,
    "service": 15,
    "road": 30,
    "track": 15,
}
# Walking speeds in metres per second; ``highway=steps`` uses the slower one.
DEFAULT_WALK_SPEED_MPS = 1.35
STEP_WALK_SPEED_MPS = 0.65
@dataclass
class RoutingImportResult:
    """Summary counters of one routing-graph import."""

    dataset_id: int
    input_path: str
    nodes: int
    edges: int
    walk_edges: int
    drive_edges: int
    skipped_ways: int
    version: str = ROUTING_LAYER_VERSION

    def as_dict(self) -> dict[str, object]:
        """Return the summary as a plain dict with ``version`` listed first."""
        ordered_fields = (
            "version",
            "dataset_id",
            "input_path",
            "nodes",
            "edges",
            "walk_edges",
            "drive_edges",
            "skipped_ways",
        )
        return {field_name: getattr(self, field_name) for field_name in ordered_fields}
def active_routing_dataset(session: Session) -> Dataset | None:
    """Pick the raw PBF dataset the routing graph should be built from.

    The newest active ``osm_geojson`` dataset is consulted first. If its
    metadata names a ``raw_dataset_id`` whose file exists locally, that raw
    dataset is returned. Otherwise the fallback is an ``osm_pbf_raw``
    dataset, preferring active ones and then the highest id.
    """
    newest_active_osm = session.scalar(
        select(Dataset).where(Dataset.kind == "osm_geojson", Dataset.is_active.is_(True)).order_by(Dataset.id.desc())
    )
    raw_reference = None
    if newest_active_osm is not None:
        raw_reference = _metadata(newest_active_osm).get("raw_dataset_id")
    if raw_reference is not None:
        raw = session.get(Dataset, int(raw_reference))
        if raw is not None and Path(raw.local_path).exists():
            return raw
    fallback_query = (
        select(Dataset)
        .where(Dataset.kind == "osm_pbf_raw")
        .order_by(Dataset.is_active.desc(), Dataset.id.desc())
    )
    return session.scalar(fallback_query)
def rebuild_routing_layer(
    session: Session,
    *,
    dataset_id: int | None = None,
    input_path: str | Path | None = None,
    reset: bool = True,
    batch_size: int = 5000,
    progress_callback: ProgressCallback | None = None,
) -> dict[str, object]:
    """Import the routable highway graph from a PBF file and finalize it.

    The dataset is ``dataset_id`` when given, otherwise the one chosen by
    ``active_routing_dataset``. The file is ``input_path`` when given,
    otherwise the dataset's local path. With ``reset`` the dataset's
    existing routing edges and nodes are deleted first.

    Raises:
        RuntimeError: if the configured database is not PostgreSQL.
        ValueError: if no dataset can be resolved.
        FileNotFoundError: if the PBF file does not exist.
    """
    if not settings.is_postgresql_database:
        raise RuntimeError("The routing layer importer requires PostgreSQL/PostGIS.")
    if dataset_id is not None:
        dataset = session.get(Dataset, dataset_id)
    else:
        dataset = active_routing_dataset(session)
    if dataset is None:
        raise ValueError("No OSM PBF dataset is available for routing import.")
    path = Path(input_path or dataset.local_path)
    if not path.exists():
        raise FileNotFoundError(f"Routing import PBF does not exist: {path}")
    if reset:
        _emit(
            progress_callback,
            "routing_layer_clear_started",
            "Clearing existing routing graph.",
            None,
            None,
            {"dataset_id": dataset.id},
        )
        # Edges are removed before nodes.
        for model in (RoutingEdge, RoutingNode):
            session.execute(delete(model).where(model.dataset_id == dataset.id))
        session.commit()
    _emit(
        progress_callback,
        "routing_layer_import_started",
        "Importing routable OSM highway graph.",
        None,
        None,
        {"dataset_id": dataset.id, "path": str(path)},
    )
    graph_handler = _RoutingGraphHandler(
        session=session,
        dataset_id=dataset.id,
        batch_size=batch_size,
        progress_callback=progress_callback,
    )
    graph_handler.apply_file(str(path), locations=True)
    # Write whatever is still buffered after the last full batch.
    graph_handler.flush()
    return finalize_routing_layer(
        session,
        dataset_id=dataset.id,
        input_path=str(path),
        skipped_way_count=graph_handler.skipped_way_count,
        progress_callback=progress_callback,
    )
def finalize_routing_layer(
    session: Session,
    *,
    dataset_id: int | None = None,
    input_path: str | Path | None = None,
    skipped_way_count: int = 0,
    progress_callback: ProgressCallback | None = None,
) -> dict[str, object]:
    """Refresh geometries and indexes, then record the routing graph's size.

    Steps, each committed: drop the routing geometry indexes, refresh the
    ``routing_nodes`` geometries, recreate the indexes. The tables are then
    analyzed, node and edge counts are stored under ``routing_layer`` in the
    dataset metadata, and the same counts are returned as a dict.

    Raises:
        RuntimeError: if the configured database is not PostgreSQL.
        ValueError: if no dataset can be resolved.
    """
    if not settings.is_postgresql_database:
        raise RuntimeError("The routing layer finalizer requires PostgreSQL/PostGIS.")
    if dataset_id is not None:
        dataset = session.get(Dataset, dataset_id)
    else:
        dataset = active_routing_dataset(session)
    if dataset is None:
        raise ValueError("No routing dataset is available to finalize.")
    path = Path(input_path or dataset.local_path)

    def count_rows(model, *conditions) -> int:
        # Row count for this dataset, optionally narrowed by extra conditions.
        query = select(func.count()).select_from(model).where(model.dataset_id == dataset.id, *conditions)
        return int(session.scalar(query) or 0)

    _emit(
        progress_callback,
        "routing_layer_geometry_indexes_dropped",
        "Dropping routing geometry indexes before bulk refresh.",
        None,
        None,
        {"dataset_id": dataset.id},
    )
    _drop_routing_geometry_indexes(session)
    session.commit()
    _emit(
        progress_callback,
        "routing_layer_geometry_started",
        "Refreshing routing node PostGIS geometries.",
        None,
        None,
        {"dataset_id": dataset.id},
    )
    refresh_postgis_geometries(session, dataset_id=dataset.id, tables=["routing_nodes"], only_missing=False)
    session.commit()
    _emit(
        progress_callback,
        "routing_layer_geometry_indexes_started",
        "Rebuilding routing geometry indexes.",
        None,
        None,
        {"dataset_id": dataset.id},
    )
    _create_routing_geometry_indexes(session)
    session.commit()
    analyze_postgresql_tables(session, ["routing_nodes", "routing_edges"])

    node_count = count_rows(RoutingNode)
    edge_count = count_rows(RoutingEdge)
    walk_edge_count = count_rows(RoutingEdge, RoutingEdge.walk_cost_s.is_not(None))
    drive_edge_count = count_rows(RoutingEdge, RoutingEdge.drive_cost_s.is_not(None))

    stored_metadata = _metadata(dataset)
    stored_metadata["routing_layer"] = {
        "version": ROUTING_LAYER_VERSION,
        "nodes": node_count,
        "edges": edge_count,
        "walk_edges": walk_edge_count,
        "drive_edges": drive_edge_count,
        "input_path": str(path),
    }
    dataset.metadata_json = json.dumps(stored_metadata, indent=2)
    session.commit()

    summary = RoutingImportResult(
        dataset_id=dataset.id,
        input_path=str(path),
        nodes=node_count,
        edges=edge_count,
        walk_edges=walk_edge_count,
        drive_edges=drive_edge_count,
        skipped_ways=skipped_way_count,
    )
    result = summary.as_dict()
    _emit(progress_callback, "routing_layer_import_completed", "Routing graph import completed.", edge_count, edge_count, result)
    return result
def _drop_routing_geometry_indexes(session: Session) -> None:
    """Drop the three routing GiST indexes if they exist; the caller commits."""
    index_names = (
        "ix_routing_nodes_geom_gist",
        "ix_routing_edges_geom_gist",
        "ix_routing_edges_bbox_box_gist",
    )
    for index_name in index_names:
        session.execute(text(f"DROP INDEX IF EXISTS {index_name}"))
def _create_routing_geometry_indexes(session: Session) -> None:
    """Create the node-geometry and edge-bounding-box GiST indexes if absent.

    Only these two indexes are created; the caller commits.
    """
    statements = (
        "CREATE INDEX IF NOT EXISTS ix_routing_nodes_geom_gist ON routing_nodes USING GIST (geom)",
        "CREATE INDEX IF NOT EXISTS ix_routing_edges_bbox_box_gist ON routing_edges USING GIST (box(point(max_lon, max_lat), point(min_lon, min_lat)))",
    )
    for statement in statements:
        session.execute(text(statement))
class _RoutingGraphHandler(osmium.SimpleHandler):
    """osmium handler that turns routable highway ways into routing rows.

    Each accepted way is split into one edge per consecutive node pair.
    Nodes and edges are buffered in memory and written to the database in
    batches by ``flush``.
    """

    def __init__(
        self,
        *,
        session: Session,
        dataset_id: int,
        batch_size: int,
        progress_callback: ProgressCallback | None,
    ) -> None:
        """Set up empty buffers and counters for one import run.

        ``batch_size`` is raised to at least 500. Node and edge counters
        start from the rows already stored for ``dataset_id``.
        """
        super().__init__()
        self.session = session
        self.dataset_id = dataset_id
        self.batch_size = max(500, int(batch_size))
        self.progress_callback = progress_callback
        # Pending node rows keyed by OSM node id, so a node is buffered once per batch.
        self.nodes: dict[int, dict[str, object]] = {}
        # Pending edge rows awaiting the next flush.
        self.edges: list[dict[str, object]] = []
        self.node_count = int(
            session.scalar(select(func.count()).select_from(RoutingNode).where(RoutingNode.dataset_id == dataset_id)) or 0
        )
        self.edge_count = int(
            session.scalar(select(func.count()).select_from(RoutingEdge).where(RoutingEdge.dataset_id == dataset_id)) or 0
        )
        self.walk_edge_count = 0
        self.drive_edge_count = 0
        self.skipped_way_count = 0
        self.processed_way_count = 0

    def way(self, way) -> None:
        """Convert one OSM way into buffered routing nodes and edges.

        Ways are skipped, and counted as skipped, when they have no
        ``highway`` tag, an excluded one, are neither walkable nor drivable,
        or have fewer than two nodes with a valid location.
        """
        tags = {tag.k: tag.v for tag in way.tags}
        highway = tags.get("highway")
        if not highway or highway in EXCLUDED_HIGHWAYS:
            self.skipped_way_count += 1
            return
        walkable = _walkable(tags, highway)
        drivable = _drivable(tags, highway)
        if not walkable and not drivable:
            self.skipped_way_count += 1
            return
        # Nodes without a valid location are dropped; their neighbours become adjacent.
        nodes = []
        for node in way.nodes:
            if not node.location.valid():
                continue
            nodes.append((int(node.ref), float(node.location.lon), float(node.location.lat)))
        if len(nodes) < 2:
            self.skipped_way_count += 1
            return
        oneway = _oneway_direction(tags, highway)
        drive_speed_mps = _drive_speed_mps(tags, highway)
        walk_speed_mps = STEP_WALK_SPEED_MPS if highway == "steps" else DEFAULT_WALK_SPEED_MPS
        # One edge per consecutive pair of nodes.
        for left, right in zip(nodes, nodes[1:]):
            source_id, source_lon, source_lat = left
            target_id, target_lon, target_lat = right
            if source_id == target_id:
                continue
            length_m = _distance_m(source_lat, source_lon, target_lat, target_lon)
            if length_m <= 0:
                continue
            if oneway == "reverse":
                # Store the edge in its direction of travel.
                source_id, target_id = target_id, source_id
                source_lon, target_lon = target_lon, source_lon
                source_lat, target_lat = target_lat, source_lat
            walk_cost = length_m / walk_speed_mps if walkable else None
            drive_cost = length_m / drive_speed_mps if drivable and drive_speed_mps > 0 else None
            # Walking is always allowed both ways; driving against a one-way is not.
            reverse_walk_cost = walk_cost
            reverse_drive_cost = None if oneway in {"forward", "reverse"} else drive_cost
            self.nodes[source_id] = {"dataset_id": self.dataset_id, "osm_node_id": source_id, "lon": source_lon, "lat": source_lat}
            self.nodes[target_id] = {"dataset_id": self.dataset_id, "osm_node_id": target_id, "lon": target_lon, "lat": target_lat}
            self.edges.append(
                {
                    "dataset_id": self.dataset_id,
                    "osm_way_id": int(way.id),
                    "source_osm_node_id": source_id,
                    "target_osm_node_id": target_id,
                    "source_lon": source_lon,
                    "source_lat": source_lat,
                    "target_lon": target_lon,
                    "target_lat": target_lat,
                    "highway": highway,
                    "name": tags.get("name"),
                    "length_m": length_m,
                    "walk_cost_s": walk_cost,
                    "reverse_walk_cost_s": reverse_walk_cost,
                    "drive_cost_s": drive_cost,
                    "reverse_drive_cost_s": reverse_drive_cost,
                    "geometry_geojson": json.dumps({"type": "LineString", "coordinates": [[source_lon, source_lat], [target_lon, target_lat]]}, separators=(",", ":")),
                    "min_lon": min(source_lon, target_lon),
                    "min_lat": min(source_lat, target_lat),
                    "max_lon": max(source_lon, target_lon),
                    "max_lat": max(source_lat, target_lat),
                    "tags_json": _routing_tags_json(tags),
                }
            )
            self.edge_count += 1
            if walk_cost is not None:
                self.walk_edge_count += 1
            if drive_cost is not None:
                self.drive_edge_count += 1
        self.processed_way_count += 1
        # The batch limit is checked once per way, so a batch can exceed it
        # by the edges of the last way.
        if len(self.edges) >= self.batch_size:
            self.flush()
        if self.processed_way_count % 100_000 == 0:
            _emit(
                self.progress_callback,
                "routing_layer_import_batch",
                f"Imported {self.edge_count:,} routing edges.",
                self.edge_count,
                None,
                {"processed_ways": self.processed_way_count, "nodes_pending": len(self.nodes), "edges": self.edge_count},
            )

    def flush(self) -> None:
        """Write the buffered nodes and edges to the database and commit.

        Does nothing when both buffers are empty.
        """
        if not self.nodes and not self.edges:
            return
        node_rows = list(self.nodes.values())
        edge_rows = self.edges
        if node_rows:
            # Nodes already stored for this dataset are left untouched.
            stmt = postgresql_insert(RoutingNode).values(node_rows)
            stmt = stmt.on_conflict_do_nothing(index_elements=["dataset_id", "osm_node_id"])
            self.session.execute(stmt)
            # Counts submitted rows, including any the conflict clause skipped.
            self.node_count += len(node_rows)
            self.nodes.clear()
        if edge_rows:
            self.session.bulk_insert_mappings(RoutingEdge, edge_rows)
            self.edges = []
        self.session.commit()
def _walkable(tags: dict[str, str], highway: str) -> bool:
    """Decide whether a way may be walked.

    The highway class must be in ``WALK_HIGHWAYS``. A denying ``foot`` value
    always blocks. A denying ``access`` value blocks unless ``foot`` allows
    explicitly. Motorway and trunk classes also need an explicit ``foot``
    allowance.
    """
    if highway not in WALK_HIGHWAYS:
        return False
    foot = _tag_value(tags, "foot")
    if foot in NO_VALUES:
        return False
    foot_allowed = foot in YES_VALUES
    if _tag_value(tags, "access") in NO_VALUES and not foot_allowed:
        return False
    is_major_road = highway in {"motorway", "motorway_link", "trunk", "trunk_link"}
    return not (is_major_road and not foot_allowed)
def _drivable(tags: dict[str, str], highway: str) -> bool:
    """Decide whether a way may be driven by car.

    The highway class must be in ``DRIVE_HIGHWAYS``. A denying value on
    ``motorcar``, ``motor_vehicle`` or ``vehicle`` always blocks. A denying
    ``access`` value blocks unless one of those three tags allows
    explicitly. ``vehicle`` counts as an override as well, because it was
    already honoured as a denial and a vehicle allowance covers motor cars.
    Foot and cycle classes need an explicit allowance.
    """
    if highway not in DRIVE_HIGHWAYS:
        return False
    access = _tag_value(tags, "access")
    motor_vehicle = _tag_value(tags, "motor_vehicle")
    motorcar = _tag_value(tags, "motorcar")
    vehicle = _tag_value(tags, "vehicle")
    if motorcar in NO_VALUES or motor_vehicle in NO_VALUES or vehicle in NO_VALUES:
        return False
    explicitly_allowed = motorcar in YES_VALUES or motor_vehicle in YES_VALUES or vehicle in YES_VALUES
    if access in NO_VALUES and not explicitly_allowed:
        return False
    if highway in {"footway", "path", "pedestrian", "steps", "cycleway", "bridleway"}:
        return explicitly_allowed
    return True
def _oneway_direction(tags: dict[str, str], highway: str) -> str:
    """Return the driving direction of a way: forward, reverse or both.

    An explicit ``oneway`` tag is decisive. In particular an explicit
    negative value (``no``, ``false``, ``0``) yields ``"both"`` even on
    motorways and roundabouts, which are otherwise treated as one-way in
    node order.
    """
    oneway = _tag_value(tags, "oneway")
    if oneway in ONEWAY_REVERSE:
        return "reverse"
    if oneway in ONEWAY_FORWARD:
        return "forward"
    if oneway in {"no", "false", "0"}:
        # Explicit two-way tagging overrides the implied one-way below.
        return "both"
    if tags.get("junction") == "roundabout" or highway == "motorway":
        return "forward"
    return "both"
def _drive_speed_mps(tags: dict[str, str], highway: str) -> float:
    """Return the driving speed for a way in metres per second.

    A usable ``maxspeed`` tag wins. Otherwise the per-class default applies,
    or 30 km/h for classes without one. The result is never below 5.0 m/s
    (18 km/h), so slower table entries are raised to that floor.
    """
    posted_kmh = _parse_maxspeed(tags.get("maxspeed"))
    if posted_kmh:
        kmh = posted_kmh
    else:
        kmh = DEFAULT_DRIVE_SPEED_KMH.get(highway, 30)
    metres_per_second = float(kmh) / 3.6
    return max(5.0, metres_per_second)
def _parse_maxspeed(value: str | None) -> float | None:
    """Parse a ``maxspeed`` tag value into km/h.

    Returns ``None`` for empty input, for the symbolic values ``signals``,
    ``none``, ``walk`` and ``variable``, and when no number is found. Values
    ending in ``mph`` are converted to km/h.
    """
    if not value:
        return None
    normalized = value.strip().lower()
    if normalized in {"signals", "none", "walk", "variable"}:
        return None
    if not normalized.endswith("mph"):
        return _leading_float(normalized)
    miles_per_hour = _leading_float(normalized[:-3])
    if miles_per_hour is None:
        return None
    return miles_per_hour * 1.60934
def _leading_float(value: str) -> float | None:
digits = []
for char in value.strip():
if char.isdigit() or char == ".":
digits.append(char)
elif digits:
break
if not digits:
return None
try:
return float("".join(digits))
except ValueError:
return None
def _routing_tags_json(tags: dict[str, str]) -> str:
selected = {
key: value
for key, value in tags.items()
if key in {"access", "bicycle", "bridge", "foot", "highway", "junction", "maxspeed", "motor_vehicle", "motorcar", "name", "oneway", "service", "surface", "tunnel", "vehicle"}
}
return json.dumps(selected, separators=(",", ":"))
def _tag_value(tags: dict[str, str], key: str) -> str:
return str(tags.get(key) or "").strip().lower()
def _distance_m(lat_a: float, lon_a: float, lat_b: float, lon_b: float) -> float:
radius = 6_371_000.0
phi_a = math.radians(lat_a)
phi_b = math.radians(lat_b)
delta_phi = math.radians(lat_b - lat_a)
delta_lambda = math.radians(lon_b - lon_a)
hav = math.sin(delta_phi / 2) ** 2 + math.cos(phi_a) * math.cos(phi_b) * math.sin(delta_lambda / 2) ** 2
return radius * 2 * math.atan2(math.sqrt(hav), math.sqrt(1 - hav))
def _metadata(dataset: Dataset) -> dict[str, object]:
try:
value = json.loads(dataset.metadata_json or "{}")
except json.JSONDecodeError:
return {}
return value if isinstance(value, dict) else {}
def _emit(
progress_callback: ProgressCallback | None,
event_type: str,
message: str,
progress_current: int | None,
progress_total: int | None,
metadata: dict[str, object] | None = None,
) -> None:
if progress_callback is not None:
progress_callback(event_type, message, progress_current, progress_total, metadata)

40
app/pipeline/run.py Normal file
View File

@@ -0,0 +1,40 @@
from __future__ import annotations
from datetime import datetime, timezone
from typing import Callable, Any
from sqlalchemy.orm import Session
from app.models import Source
from app.pipeline.gtfs import run_gtfs_source
from app.pipeline.osm_diff import run_osm_diff_source
from app.pipeline.osm_geojson import run_osm_geojson_source
from app.pipeline.osm_pbf import run_osm_pbf_source
# Progress hook type: two strings, two optional ints and an optional metadata
# dict. It is forwarded unchanged to the GTFS and OSM PBF runners.
ProgressCallback = Callable[[str, str, int | None, int | None, dict[str, Any] | None], None]
def run_source(session: Session, source: Source, progress_callback: ProgressCallback | None = None):
source.status = "running"
source.last_run_at = datetime.now(timezone.utc)
source.last_error = None
session.flush()
try:
if source.kind == "gtfs":
dataset = run_gtfs_source(session, source, progress_callback=progress_callback)
elif source.kind == "osm_geojson":
dataset = run_osm_geojson_source(session, source)
elif source.kind == "osm_pbf":
dataset = run_osm_pbf_source(session, source, progress_callback=progress_callback)
elif source.kind == "osm_diff":
dataset = run_osm_diff_source(session, source)
else:
raise ValueError(f"Unsupported source kind: {source.kind}")
source.status = "ok"
source.last_error = None
return dataset
except Exception as exc: # noqa: BLE001 - persist pipeline error for UI
source.status = "error"
source.last_error = str(exc)
raise

294
app/pipeline/sample_data.py Normal file
View File

@@ -0,0 +1,294 @@
from __future__ import annotations
import csv
import io
import json
import zipfile
from pathlib import Path
from datetime import datetime, timezone
from sqlalchemy import delete, select
from sqlalchemy.orm import Session
from app.config import settings
from app.db import init_db
from app.models import (
Dataset,
CanonicalStop,
CanonicalStopLink,
GtfsAgency,
GtfsCalendar,
GtfsCalendarDate,
GtfsRoute,
GtfsRoutePatternLink,
GtfsShape,
GtfsStop,
GtfsStopTime,
GtfsTripRoutePatternLink,
GtfsTrip,
Itinerary,
ItineraryLeg,
Job,
JobEvent,
MatchRule,
OsmDiffState,
OsmFeature,
PipelineRun,
RouteMatch,
RoutePattern,
RoutePatternStop,
RoutingEdge,
RoutingNode,
Source,
SourceCatalogEntry,
SourceUpdateCheck,
TravelRequest,
)
from app.pipeline.matcher import run_route_matching
from app.pipeline.route_layer import rebuild_route_layer
from app.pipeline.run import run_source
def load_sample_project(session: Session, *, preserve_job_id: int | None = None) -> dict:
    """Reset the project and rebuild it from the bundled Berlin-like sample.

    Steps, in order: initialise the DB, clear project data (the catalog is
    always preserved; ``preserve_job_id`` is passed through), write the sample
    GTFS zip and OSM GeoJSON under ``settings.data_dir / "sample"``, register
    one source for each, import both, then run route matching and rebuild the
    route layer.

    Returns a dict with ``status``, the two dataset ids and the results of the
    matching and route-layer steps.
    """
    init_db()
    clear_project_data(session, preserve_job_id=preserve_job_id, preserve_catalog=True)

    sample_root = settings.data_dir / "sample"
    sample_root.mkdir(parents=True, exist_ok=True)
    gtfs_file = sample_root / "sample_berlin.gtfs.zip"
    osm_file = sample_root / "sample_berlin_osm.geojson"
    create_sample_gtfs(gtfs_file)
    create_sample_osm_geojson(osm_file)

    gtfs_source = Source(
        name="Sample Berlin GTFS",
        kind="gtfs",
        url=str(gtfs_file),
        country="DE",
        license="sample",
    )
    osm_source = Source(
        name="Sample Berlin OSM transport",
        kind="osm_geojson",
        url=str(osm_file),
        country="DE",
        license="sample",
    )
    session.add_all([gtfs_source, osm_source])
    session.flush()

    imported_gtfs = run_source(session, gtfs_source)
    imported_osm = run_source(session, osm_source)
    matching = run_route_matching(session)
    route_layer = rebuild_route_layer(session)

    summary = {"status": "ok"}
    summary["gtfs_dataset_id"] = imported_gtfs.id
    summary["osm_dataset_id"] = imported_osm.id
    summary["match_result"] = matching
    summary["route_layer_result"] = route_layer
    return summary
def clear_project_data(
    session: Session,
    *,
    preserve_job_id: int | None = None,
    preserve_catalog: bool = True,
) -> None:
    """Clear user/project data while optionally preserving the current queue job.

    Pipeline runs are always deleted. When ``preserve_job_id`` is ``None`` all
    jobs and job events are deleted too; otherwise jobs are kept and the other
    active ones are cancelled. Source catalog entries are removed only when
    ``preserve_catalog`` is false. The session is flushed at the end.
    """
    session.execute(delete(PipelineRun))
    if preserve_job_id is not None:
        _cancel_other_jobs_for_reset(session, preserve_job_id)
    else:
        session.execute(delete(JobEvent))
        session.execute(delete(Job))

    # Deletion order is significant and kept exactly as authored.
    # NOTE(review): it looks like dependent rows are removed before the rows
    # they reference — confirm against the schema before reordering.
    deletion_order = (
        ItineraryLeg,
        Itinerary,
        TravelRequest,
        SourceUpdateCheck,
        OsmDiffState,
        MatchRule,
        RouteMatch,
        GtfsTripRoutePatternLink,
        GtfsRoutePatternLink,
        RoutePatternStop,
        RoutePattern,
        CanonicalStopLink,
        CanonicalStop,
        RoutingEdge,
        RoutingNode,
        GtfsStopTime,
        GtfsCalendarDate,
        GtfsCalendar,
        GtfsShape,
        GtfsTrip,
        GtfsRoute,
        GtfsStop,
        GtfsAgency,
        OsmFeature,
        Dataset,
        Source,
    )
    for table_model in deletion_order:
        session.execute(delete(table_model))

    if not preserve_catalog:
        session.execute(delete(SourceCatalogEntry))
    session.flush()
def _cancel_other_jobs_for_reset(session: Session, preserve_job_id: int) -> None:
    """Cancel every queued/running/paused job except ``preserve_job_id``.

    Each affected job is set to ``"cancelled"``, has its requested action,
    lease, pause and error fields cleared, is stamped with the current UTC
    time, and gets a ``cancelled_by_reset`` event carrying its last progress.
    """
    stamp = datetime.now(timezone.utc)
    active_statuses = {"queued", "running", "paused"}
    query = select(Job).where(Job.id != preserve_job_id, Job.status.in_(active_statuses))
    for job in session.scalars(query).all():
        job.status = "cancelled"
        job.requested_action = None
        job.lease_owner = None
        job.lease_expires_at = None
        job.paused_at = None
        job.error = None
        job.updated_at = stamp
        job.finished_at = stamp
        reset_event = JobEvent(
            job_id=job.id,
            event_type="cancelled_by_reset",
            message=f"Job cancelled by reset job #{preserve_job_id}.",
            progress_current=job.progress_current,
            progress_total=job.progress_total,
        )
        session.add(reset_event)
def create_sample_gtfs(path: Path) -> None:
    """Write a small, fixed Berlin-like GTFS feed as a deflated zip at ``path``.

    The archive contains agency, stops, routes, trips, stop_times, calendar
    and shapes tables (written in that order). Every route gets one trip on
    the ``daily`` service, and each trip's shape simply follows the
    coordinates of its stops.
    """
    agency_fields = ["agency_id", "agency_name", "agency_url", "agency_timezone"]
    agencies = [
        dict(zip(agency_fields, row))
        for row in (
            ("BVG", "BVG", "https://example.invalid/bvg", "Europe/Berlin"),
            ("DB", "DB Regio", "https://example.invalid/db", "Europe/Berlin"),
            ("XAIR", "Example Airport Shuttle", "https://example.invalid/xair", "Europe/Berlin"),
        )
    ]

    stop_fields = ["stop_id", "stop_name", "stop_lat", "stop_lon"]
    stops = [
        dict(zip(stop_fields, row))
        for row in (
            ("hbf", "Berlin Hauptbahnhof", "52.5251", "13.3696"),
            ("friedrich", "Friedrichstraße", "52.5201", "13.3862"),
            ("alex", "Alexanderplatz", "52.5219", "13.4132"),
            ("ost", "Ostbahnhof", "52.5100", "13.4344"),
            ("zoo", "Zoologischer Garten", "52.5069", "13.3320"),
            ("wittenberg", "Wittenbergplatz", "52.5020", "13.3430"),
            ("potsdamer", "Potsdamer Platz", "52.5096", "13.3760"),
            ("stadtmitte", "Stadtmitte", "52.5113", "13.3907"),
            ("reichstag", "Reichstag", "52.5186", "13.3763"),
            ("hackescher", "Hackescher Markt", "52.5220", "13.4023"),
            ("naturkunde", "Naturkundemuseum", "52.5300", "13.3790"),
            ("wannsee", "Wannsee", "52.4210", "13.1797"),
            ("kladow", "Kladow", "52.4547", "13.1439"),
            ("airport", "Example Airport Terminal", "52.3650", "13.5100"),
        )
    ]

    route_fields = ["route_id", "agency_id", "route_short_name", "route_long_name", "route_type"]
    routes = [
        dict(zip(route_fields, row))
        for row in (
            ("u2", "BVG", "U2", "Pankow - Ruhleben", "1"),
            ("re1", "DB", "RE1", "Magdeburg - Frankfurt Oder", "2"),
            ("m5", "BVG", "M5", "Hauptbahnhof - Hohenschönhausen", "0"),
            ("bus100", "BVG", "100", "Zoologischer Garten - Alexanderplatz", "3"),
            ("f10", "BVG", "F10", "Wannsee - Kladow", "4"),
            ("x99", "XAIR", "X99", "Airport Express Sample", "3"),
        )
    ]

    # One trip (and one shape) per route, ids derived from the route id.
    trip_fields = ["route_id", "service_id", "trip_id", "shape_id"]
    trips = []
    for route in routes:
        route_id = route["route_id"]
        trips.append(
            {
                "route_id": route_id,
                "service_id": "daily",
                "trip_id": f"{route_id}_trip",
                "shape_id": f"{route_id}_shape",
            }
        )

    stop_sequences = {
        "u2_trip": ["zoo", "wittenberg", "potsdamer", "stadtmitte", "alex"],
        "re1_trip": ["hbf", "friedrich", "alex", "ost"],
        "m5_trip": ["hbf", "naturkunde", "hackescher", "alex"],
        "bus100_trip": ["zoo", "reichstag", "alex"],
        "f10_trip": ["wannsee", "kladow"],
        "x99_trip": ["alex", "airport"],
    }
    lon_lat_by_stop = {stop["stop_id"]: (stop["stop_lon"], stop["stop_lat"]) for stop in stops}

    stop_time_fields = ["trip_id", "arrival_time", "departure_time", "stop_id", "stop_sequence"]
    shape_fields = ["shape_id", "shape_pt_lat", "shape_pt_lon", "shape_pt_sequence"]
    stop_times = []
    shapes = []
    for trip in trips:
        trip_id = trip["trip_id"]
        for position, stop_id in enumerate(stop_sequences[trip_id], start=1):
            # Stops are five minutes apart with a one-minute dwell.
            minute = position * 5
            stop_times.append(
                {
                    "trip_id": trip_id,
                    "arrival_time": f"08:{minute:02d}:00",
                    "departure_time": f"08:{minute + 1:02d}:00",
                    "stop_id": stop_id,
                    "stop_sequence": str(position),
                }
            )
            stop_lon, stop_lat = lon_lat_by_stop[stop_id]
            shapes.append(
                {
                    "shape_id": trip["shape_id"],
                    "shape_pt_lat": stop_lat,
                    "shape_pt_lon": stop_lon,
                    "shape_pt_sequence": str(position),
                }
            )

    weekdays = ["monday", "tuesday", "wednesday", "thursday", "friday", "saturday", "sunday"]
    calendar_fields = ["service_id", *weekdays, "start_date", "end_date"]
    calendar = [
        {
            "service_id": "daily",
            **{day: "1" for day in weekdays},
            "start_date": "20260101",
            "end_date": "20261231",
        }
    ]

    tables = [
        ("agency.txt", agency_fields, agencies),
        ("stops.txt", stop_fields, stops),
        ("routes.txt", route_fields, routes),
        ("trips.txt", trip_fields, trips),
        ("stop_times.txt", stop_time_fields, stop_times),
        ("calendar.txt", calendar_fields, calendar),
        ("shapes.txt", shape_fields, shapes),
    ]
    with zipfile.ZipFile(path, "w", compression=zipfile.ZIP_DEFLATED) as archive:
        for member_name, fields, rows in tables:
            _write_csv(archive, member_name, fields, rows)
def _write_csv(zf: zipfile.ZipFile, name: str, fields: list[str], rows: list[dict[str, str]]) -> None:
    """Add a CSV member called ``name`` to ``zf``.

    The member consists of a header row built from ``fields`` followed by one
    line per dict in ``rows``, rendered in memory with ``csv.DictWriter``.
    """
    with io.StringIO(newline="") as text:
        table = csv.DictWriter(text, fieldnames=fields)
        table.writeheader()
        for row in rows:
            table.writerow(row)
        payload = text.getvalue()
    zf.writestr(name, payload)
def create_sample_osm_geojson(path: Path) -> None:
    """Write the fixed sample OSM FeatureCollection to ``path`` as indented JSON.

    The collection holds six route relations (LineStrings) followed by five
    station/terminal nodes (Points). The file is written as UTF-8 text.
    """

    def route_feature(fid, mode, ref, name, operator, coords):
        # Operator "BVG" is placed in network "VBB"; every other operator in "DB".
        network = "VBB" if operator == "BVG" else "DB"
        properties = {
            "osm_type": "relation",
            "osm_id": str(fid),
            "type": "route",
            "route": mode,
            "ref": ref,
            "name": name,
            "operator": operator,
            "network": network,
        }
        geometry = {"type": "LineString", "coordinates": coords}
        return {"type": "Feature", "geometry": geometry, "properties": properties}

    def node_feature(fid, kind, name, coords, props=None):
        # ``kind`` is accepted but not written to the feature.
        # Caller-supplied tags come first, then the fixed identification keys.
        properties = {**(props or {}), "osm_type": "node", "osm_id": str(fid), "name": name}
        geometry = {"type": "Point", "coordinates": coords}
        return {"type": "Feature", "geometry": geometry, "properties": properties}

    route_specs = [
        (1002, "subway", "U2", "U2 Ruhleben - Pankow", "BVG", [[13.3320, 52.5069], [13.3430, 52.5020], [13.3760, 52.5096], [13.3907, 52.5113], [13.4132, 52.5219]]),
        (2001, "train", "RE1", "RE1 Magdeburg - Frankfurt Oder", "DB Regio", [[13.3696, 52.5251], [13.3862, 52.5201], [13.4132, 52.5219], [13.4344, 52.5100]]),
        (5005, "tram", "M5", "M5 Hauptbahnhof - Hohenschönhausen", "BVG", [[13.3696, 52.5251], [13.3790, 52.5300], [13.4023, 52.5220], [13.4132, 52.5219]]),
        (6100, "bus", "100", "Bus 100 Zoologischer Garten - Alexanderplatz", "BVG", [[13.3320, 52.5069], [13.3763, 52.5186], [13.4132, 52.5219]]),
        (7010, "ferry", "F10", "F10 Wannsee - Kladow", "BVG", [[13.1797, 52.4210], [13.1439, 52.4547]]),
        (5010, "tram", "M10", "M10 Warschauer Straße - Hauptbahnhof", "BVG", [[13.4500, 52.5050], [13.4020, 52.5300], [13.3696, 52.5251]]),
    ]
    node_specs = [
        (9001, "station", "Berlin Hauptbahnhof", [13.3696, 52.5251], {"railway": "station"}),
        (9002, "station", "Alexanderplatz", [13.4132, 52.5219], {"railway": "station"}),
        (9003, "stop", "Zoologischer Garten", [13.3320, 52.5069], {"public_transport": "station", "railway": "station"}),
        (9004, "terminal", "Wannsee Ferry Terminal", [13.1797, 52.4210], {"amenity": "ferry_terminal"}),
        (9005, "terminal", "Kladow Ferry Terminal", [13.1439, 52.4547], {"amenity": "ferry_terminal"}),
    ]

    features = [route_feature(*spec) for spec in route_specs]
    features += [node_feature(*spec) for spec in node_specs]
    collection = {"type": "FeatureCollection", "features": features}
    path.write_text(json.dumps(collection, indent=2), encoding="utf-8")

135
app/pipeline/state.py Normal file
View File

@@ -0,0 +1,135 @@
from __future__ import annotations
from datetime import datetime, timezone
import hashlib
import json
from typing import Any
from sqlalchemy import select
from sqlalchemy.orm import Session
from app.models import PipelineRun
# Names of the pipeline stages, as plain strings.
# NOTE(review): presumably these are the values passed as ``stage`` to the
# run helpers in this module and stored on PipelineRun.stage — confirm
# against the callers.
STAGE_ACQUIRE_RAW = "acquire_raw"
STAGE_FILTER_TRANSPORT = "filter_transport"
STAGE_EXTRACT_GEOMETRY = "extract_geometry"
STAGE_LABEL_FEATURES = "label_features"
STAGE_BUILD_INDEXES = "build_indexes"
STAGE_MATCH_ROUTES = "match_routes"
STAGE_BUILD_ROUTE_LAYER = "build_route_layer"
def stable_json(value: Any) -> str:
    """Encode ``value`` as compact JSON with sorted object keys.

    Objects the JSON encoder cannot handle natively are converted with
    ``str``.
    """
    encoder = json.JSONEncoder(sort_keys=True, separators=(",", ":"), default=str)
    return encoder.encode(value)
def dependency_hash(value: Any) -> str:
    """Return the SHA-256 hex digest of ``stable_json(value)`` encoded as UTF-8."""
    digest = hashlib.sha256()
    digest.update(stable_json(value).encode("utf-8"))
    return digest.hexdigest()
def latest_completed_run(
    session: Session,
    *,
    stage: str,
    version: str,
    dependency_hash_value: str,
    source_id: int | None = None,
    dataset_id: int | None = None,
) -> PipelineRun | None:
    """Return the newest completed run matching the given key, or ``None``.

    A run matches when its stage, version and dependency hash equal the
    arguments and its status is ``"completed"``. ``source_id`` and
    ``dataset_id`` are matched exactly; ``None`` matches only rows where the
    column is NULL. Results are ordered by ``finished_at`` then ``id``, both
    descending, and the first row is returned.
    """
    if source_id is None:
        source_filter = PipelineRun.source_id.is_(None)
    else:
        source_filter = PipelineRun.source_id == source_id

    if dataset_id is None:
        dataset_filter = PipelineRun.dataset_id.is_(None)
    else:
        dataset_filter = PipelineRun.dataset_id == dataset_id

    conditions = [
        PipelineRun.stage == stage,
        PipelineRun.version == version,
        PipelineRun.dependency_hash == dependency_hash_value,
        PipelineRun.status == "completed",
        source_filter,
        dataset_filter,
    ]
    newest_first = (PipelineRun.finished_at.desc(), PipelineRun.id.desc())
    query = select(PipelineRun).where(*conditions).order_by(*newest_first).limit(1)
    return session.scalar(query)
def start_pipeline_run(
    session: Session,
    *,
    stage: str,
    version: str,
    dependency_hash_value: str,
    source_id: int | None = None,
    dataset_id: int | None = None,
    job_id: int | None = None,
    inputs: dict[str, Any] | None = None,
) -> PipelineRun:
    """Create a ``PipelineRun`` in status ``"running"``, add it and flush.

    ``inputs`` is stored as stable JSON in ``input_json`` (or ``None`` when
    not given). ``started_at`` and ``updated_at`` are both set to the same
    current UTC timestamp.
    """
    created_at = datetime.now(timezone.utc)
    serialized_inputs = stable_json(inputs) if inputs is not None else None
    fields = {
        "stage": stage,
        "version": version,
        "dependency_hash": dependency_hash_value,
        "status": "running",
        "source_id": source_id,
        "dataset_id": dataset_id,
        "job_id": job_id,
        "input_json": serialized_inputs,
        "started_at": created_at,
        "updated_at": created_at,
    }
    run = PipelineRun(**fields)
    session.add(run)
    session.flush()
    return run
def finish_pipeline_run(
    session: Session,
    run: PipelineRun,
    *,
    status: str = "completed",
    outputs: dict[str, Any] | None = None,
    error: str | None = None,
) -> PipelineRun:
    """Close ``run`` with the given status, outputs and error, then flush.

    ``outputs`` is stored as stable JSON in ``output_json`` (``None`` when not
    given). ``updated_at`` and ``finished_at`` receive the same current UTC
    timestamp. The same ``run`` object is returned.
    """
    finished = datetime.now(timezone.utc)
    run.status = status
    if outputs is None:
        run.output_json = None
    else:
        run.output_json = stable_json(outputs)
    run.error = error
    run.updated_at = finished
    run.finished_at = finished
    session.flush()
    return run
def pipeline_run_payload(run: PipelineRun) -> dict[str, Any]:
    """Return a plain-dict view of ``run``.

    The stored input/output JSON is decoded into dicts under ``input`` and
    ``output``; the three timestamps are rendered with ``isoformat()`` or
    left as ``None`` when unset.
    """

    def iso_or_none(moment):
        return moment.isoformat() if moment else None

    payload: dict[str, Any] = {}
    for attribute in ("id", "stage", "version", "dependency_hash", "status", "source_id", "dataset_id", "job_id"):
        payload[attribute] = getattr(run, attribute)
    payload["input"] = _json_object(run.input_json)
    payload["output"] = _json_object(run.output_json)
    payload["error"] = run.error
    payload["started_at"] = iso_or_none(run.started_at)
    payload["updated_at"] = iso_or_none(run.updated_at)
    payload["finished_at"] = iso_or_none(run.finished_at)
    return payload
def _json_object(text: str | None) -> dict[str, Any]:
    """Decode ``text`` as a JSON object.

    Returns an empty dict when ``text`` is empty/``None``, is not valid JSON,
    or decodes to anything other than a dict.
    """
    if text:
        try:
            decoded = json.loads(text)
        except json.JSONDecodeError:
            return {}
        if isinstance(decoded, dict):
            return decoded
    return {}

89
app/pipeline/utils.py Normal file
View File

@@ -0,0 +1,89 @@
from __future__ import annotations
import hashlib
import json
import re
from pathlib import Path
from typing import Iterable, Optional
from shapely.geometry import shape
def sha256_file(path: Path) -> str:
    """Return the SHA-256 hex digest of the file at ``path``.

    The file is read in 1 MiB pieces so it never has to fit in memory.
    """
    digest = hashlib.sha256()
    with path.open("rb") as stream:
        while piece := stream.read(1024 * 1024):
            digest.update(piece)
    return digest.hexdigest()
def norm_text(value: object) -> str:
    """Normalize free text for comparison.

    ``None`` becomes ``""``. Otherwise the value is stringified, lower-cased
    and stripped, ``ß`` is replaced by ``ss``, every run of characters outside
    ``a-z0-9`` becomes a single space, and outer spaces are removed.
    """
    if value is None:
        return ""
    lowered = str(value).lower().strip().replace("ß", "ss")
    spaced = re.sub(r"[^a-z0-9]+", " ", lowered)
    # Only ASCII spaces can remain here, so split/join collapses and trims them.
    return " ".join(spaced.split())
def norm_ref(value: object) -> str:
    """Normalize a reference code: lower-case it and keep only ``a-z0-9``.

    ``None`` becomes ``""``.
    """
    if value is None:
        return ""
    kept_runs = re.findall(r"[a-z0-9]+", str(value).lower())
    return "".join(kept_runs)
def first_nonempty(*values: object) -> str:
    """Return the first value whose stripped string form is non-empty.

    ``None`` entries are skipped. Returns ``""`` when nothing qualifies.
    """
    stripped = (str(candidate).strip() for candidate in values if candidate is not None)
    return next((text for text in stripped if text), "")
def geometry_json_and_bbox(geometry: object) -> tuple[Optional[str], tuple[Optional[float], Optional[float], Optional[float], Optional[float]]]:
    """Return ``(compact GeoJSON text, (min_lon, min_lat, max_lon, max_lat))``.

    Dict input is converted with ``shape``; any other object is used as-is
    and must offer ``is_empty``, ``bounds`` and ``__geo_interface__``.
    ``(None, (None, None, None, None))`` is returned for ``None``, for empty
    geometries, and whenever any step raises.
    """
    nothing = (None, (None, None, None, None))
    if geometry is None:
        return nothing
    try:
        geom = geometry
        if isinstance(geometry, dict):
            geom = shape(geometry)
        if geom.is_empty:
            return nothing
        west, south, east, north = geom.bounds
        encoded = json.dumps(geom.__geo_interface__, separators=(",", ":"))
        return encoded, (west, south, east, north)
    except Exception:
        # Deliberately best-effort: unusable geometry yields the "nothing" result.
        return nothing
def bbox_overlap(a: tuple[float | None, float | None, float | None, float | None], b: tuple[float | None, float | None, float | None, float | None]) -> bool:
    """Return whether two ``(minx, miny, maxx, maxy)`` boxes intersect.

    Boxes that merely touch count as overlapping. If any coordinate of either
    box is ``None`` the result is ``False``.
    """
    for coordinate in (*a, *b):
        if coordinate is None:
            return False
    a_min_x, a_min_y, a_max_x, a_max_y = a  # type: ignore[misc]
    b_min_x, b_min_y, b_max_x, b_max_y = b  # type: ignore[misc]
    apart_in_x = a_max_x < b_min_x or b_max_x < a_min_x
    if apart_in_x:
        return False
    apart_in_y = a_max_y < b_min_y or b_max_y < a_min_y
    return not apart_in_y
def bbox_center(b: tuple[float | None, float | None, float | None, float | None]) -> Optional[tuple[float, float]]:
    """Return the midpoint ``(x, y)`` of a ``(minx, miny, maxx, maxy)`` box.

    Returns ``None`` when any coordinate is ``None``.
    """
    for coordinate in b:
        if coordinate is None:
            return None
    left, bottom, right, top = b  # type: ignore[misc]
    mid_x = (left + right) / 2
    mid_y = (bottom + top) / 2
    return (mid_x, mid_y)
def approx_bbox_center_distance_deg(a: tuple[float | None, float | None, float | None, float | None], b: tuple[float | None, float | None, float | None, float | None]) -> Optional[float]:
    """Return the straight-line distance between the centres of two boxes.

    The distance is the plain Euclidean norm of the coordinate differences,
    in the same units as the box coordinates. Returns ``None`` when either
    box has no centre (i.e. contains a ``None`` coordinate).
    """
    center_a = bbox_center(a)
    center_b = bbox_center(b)
    if center_a is None or center_b is None:
        return None
    delta_x = center_a[0] - center_b[0]
    delta_y = center_a[1] - center_b[1]
    return (delta_x ** 2 + delta_y ** 2) ** 0.5
def batched(iterable: Iterable[dict], batch_size: int = 1000) -> Iterable[list[dict]]:
    """Yield consecutive lists of items taken from ``iterable``.

    A list is emitted as soon as it holds at least ``batch_size`` items; a
    final, shorter list is emitted for any leftover items. Nothing is yielded
    for an empty input.
    """
    pending: list[dict] = []
    for entry in iterable:
        pending.append(entry)
        if len(pending) < batch_size:
            continue
        yield pending
        pending = []
    if pending:
        yield pending