Files
meubility-workbench/app/pipeline/matcher.py
2026-07-01 23:29:51 +02:00

996 lines
39 KiB
Python

from __future__ import annotations
from dataclasses import dataclass
from datetime import datetime, timezone
import json
from typing import Callable, Optional
from shapely.geometry import LineString, MultiLineString, Point, shape
from sqlalchemy import delete, select
from sqlalchemy.orm import Session
from app.config import settings
from app.models import Dataset, GtfsRoute, MatchRule, OsmFeature, RouteMatch
from app.osm_storage import ensure_main_osm_feature, osm_feature_bbox, query_osm_features
from app.pipeline.state import STAGE_MATCH_ROUTES, dependency_hash, finish_pipeline_run, start_pipeline_run
from app.pipeline.utils import approx_bbox_center_distance_deg, bbox_overlap, norm_ref, norm_text
# Mode compatibility table: each key maps to the set of modes that are treated
# as compatible with it when a GTFS route is compared with an OSM route.
MODE_GROUPS = {
    "train": {"train", "rail", "railway"},
    "subway": {"subway", "metro"},
    "tram": {"tram", "light_rail"},
    "light_rail": {"light_rail", "tram"},
    "bus": {"bus", "coach", "trolleybus"},
    "coach": {"coach", "bus"},
    "trolleybus": {"trolleybus", "bus"},
    "ferry": {"ferry"},
    "funicular": {"funicular"},
    "aerialway": {"aerialway", "cable_car"},
    "monorail": {"monorail"},
}
# Cap on fallback (mode/bbox based) candidates when the route has a normalized ref.
MAX_FALLBACK_CANDIDATES_WITH_REF = 40
# Cap on fallback candidates when the route has no normalized ref.
MAX_FALLBACK_CANDIDATES_WITHOUT_REF = 80
# Cap on candidates that were found through a ref / route-key lookup.
MAX_EXACT_REF_CANDIDATES = 120
# Bbox-center distance below which a route counts as "near" the OSM scope; also
# used as the "far away" threshold for exact-ref penalties and ranking buckets.
# NOTE(review): unit is presumably degrees, as the name suggests - confirm.
OSM_SCOPE_NEAR_DISTANCE_DEG = 0.15
# Maximum distance between a sample point and the other geometry for the point
# to count as lying on that geometry.
GEOMETRY_PROXIMITY_DEG = 0.0035
# Number of points sampled along a geometry when building a geometry profile.
GEOMETRY_SAMPLE_POINTS = 24
# Version tag recorded with each pipeline run and included in the dependency hash input.
MATCHER_VERSION = "matcher_v4_scope_spatial_manual_rules"
# Progress hook signature: (event_type, message, progress_current, progress_total, metadata).
ProgressCallback = Callable[[str, str, int | None, int | None, dict[str, object] | None], None]
@dataclass(frozen=True)
class _ManualMatchRule:
    """Immutable, parsed form of a stored manual match rule."""

    # Identifier of the stored rule; written into match reasons as "manual_rule_id".
    id: int
    # Rule kind; the matcher distinguishes "accept_match" and "reject_match".
    rule_type: str
    # Criteria used to decide whether a GTFS route is covered by the rule.
    route_selector: dict[str, object]
    # Criteria identifying the OSM feature, or None when the rule names none.
    osm_selector: dict[str, object] | None
    # Rule status; accept rules are only honoured when this is "accepted".
    status: str
@dataclass(frozen=True)
class _OsmRouteIndex:
    """In-memory lookup tables over a list of OSM route features."""

    # Every indexed feature, in the order it was supplied.
    all_routes: list[OsmFeature]
    # Features grouped by their normalized ref (features without a ref are omitted).
    by_ref: dict[str, list[OsmFeature]]
    # Features grouped by route key (features without a route key are omitted).
    by_route_key: dict[str, list[OsmFeature]]
    # Features grouped by mode (features without a mode are omitted).
    by_mode: dict[str, list[OsmFeature]]
@dataclass(frozen=True)
class _GeometryProfile:
    """Pre-computed geometric data for one route geometry, reused across comparisons."""

    # Geometry object built from the GeoJSON text via shapely's shape().
    geom: object
    # Line parts extracted from the geometry.
    lines: list[LineString]
    # Sum of the lengths of all line parts (never 0 for a built profile).
    length: float
    # Points sampled along the line parts, used for proximity ratios.
    sample_points: list[Point]
@dataclass(frozen=True)
class _RouteMatchPayload:
    """Desired state of the match row for one GTFS route, before it is persisted."""

    # Primary key of the GTFS route the match belongs to.
    gtfs_route_id: int
    # Primary key of the matched OSM feature, or None when nothing was matched.
    osm_feature_id: int | None
    # Match score; 100.0 for manual accepts, 0.0 for routes outside the OSM scope.
    confidence: float
    # Match status such as "accepted", "matched", "probable", "weak" or "missing".
    status: str
    # Origin of the decision: "manual" or "auto".
    rule_source: str
    # Compact JSON text explaining the decision, or None.
    reasons_json: str | None
def run_route_matching(
    session: Session,
    *,
    progress_callback: ProgressCallback | None = None,
    batch_size: int | None = None,
) -> dict[str, object]:
    """Match the GTFS routes of all active datasets against active OSM routes.

    Routes are processed in batches; the session is committed after every
    batch and a progress event is emitted through ``progress_callback``.

    Args:
        session: Database session used for all reads and writes.
        progress_callback: Optional hook receiving progress events.
        batch_size: Routes per batch; falls back to
            ``settings.route_matching_batch_size`` and is never below 1.

    Returns:
        Counters for the run. When there is nothing to match only
        ``routes``, ``matches`` and ``missing`` (all 0) are returned;
        otherwise the dict also holds ``manual``, ``created``, ``updated``,
        ``unchanged`` and a ``scope`` breakdown.
    """
    nothing_to_do: dict[str, object] = {"routes": 0, "matches": 0, "missing": 0}

    dataset_rows = session.execute(
        select(Dataset.id, Dataset.kind, Dataset.source_id).where(Dataset.is_active.is_(True))
    ).all()
    if not dataset_rows:
        return nothing_to_do

    dataset_source_ids: dict[int, int] = {}
    gtfs_dataset_ids: list[int] = []
    osm_dataset_ids: list[int] = []
    for dataset_id, kind, source_id in dataset_rows:
        dataset_source_ids[int(dataset_id)] = int(source_id)
        if kind == "gtfs":
            gtfs_dataset_ids.append(int(dataset_id))
        elif kind == "osm_geojson":
            osm_dataset_ids.append(int(dataset_id))
    if not gtfs_dataset_ids:
        return nothing_to_do

    route_query = (
        select(GtfsRoute.id)
        .where(GtfsRoute.dataset_id.in_(gtfs_dataset_ids))
        .order_by(GtfsRoute.dataset_id, GtfsRoute.route_id, GtfsRoute.id)
    )
    route_row_ids = session.scalars(route_query).all()
    total_routes = len(route_row_ids)
    if total_routes == 0:
        return nothing_to_do

    # Register the run before any matching work so it is visible early.
    dependency = _route_matching_dependency(session, dataset_rows)
    run = start_pipeline_run(
        session,
        stage=STAGE_MATCH_ROUTES,
        version=MATCHER_VERSION,
        dependency_hash_value=dependency_hash(dependency),
        inputs=dependency,
    )
    session.commit()

    effective_batch_size = max(1, int(batch_size or settings.route_matching_batch_size))
    start_details = {
        "gtfs_datasets": gtfs_dataset_ids,
        "osm_datasets": osm_dataset_ids,
        "batch_size": effective_batch_size,
    }
    _emit_progress(
        progress_callback,
        "route_matching_started",
        f"Matching {total_routes} GTFS routes in batches of {effective_batch_size}.",
        0,
        total_routes,
        start_details,
    )

    manual_rules = _manual_match_rules(session)
    osm_scope_bbox = osm_feature_bbox(session, osm_dataset_ids, kinds=["route"])
    tally_keys = ("matches", "missing", "manual", "created", "updated", "unchanged")
    counts = {"routes": total_routes, **dict.fromkeys(tally_keys, 0)}
    scoped_counts = {"in_osm_scope": 0, "near_osm_scope": 0, "outside_osm_scope": 0, "unknown_scope": 0}
    processed = 0

    for id_chunk in _chunks_int(route_row_ids, effective_batch_size):
        batch_routes = session.scalars(
            select(GtfsRoute)
            .where(GtfsRoute.id.in_(id_chunk))
            .order_by(GtfsRoute.dataset_id, GtfsRoute.route_id, GtfsRoute.id)
        ).all()
        batch_counts = _match_route_batch(
            session=session,
            routes=batch_routes,
            osm_dataset_ids=osm_dataset_ids,
            dataset_source_ids=dataset_source_ids,
            manual_rules=manual_rules,
            osm_scope_bbox=osm_scope_bbox,
            scoped_counts=scoped_counts,
        )
        for key in tally_keys:
            counts[key] += batch_counts[key]
        processed += len(batch_routes)
        # Persist each batch on its own so progress survives a later failure.
        session.commit()
        batch_details = {
            "processed": processed,
            **{key: counts[key] for key in tally_keys},
            "scope": dict(scoped_counts),
        }
        _emit_progress(
            progress_callback,
            "route_matching_batch",
            f"Matched {processed}/{total_routes} GTFS routes.",
            processed,
            total_routes,
            batch_details,
        )

    result = {**counts, "scope": scoped_counts}
    finish_pipeline_run(session, run, outputs=result)
    session.commit()
    _emit_progress(
        progress_callback,
        "route_matching_completed",
        "Route matching completed.",
        total_routes,
        total_routes,
        result,
    )
    return result
def _route_matching_dependency(session: Session, active_datasets) -> dict[str, object]:
    """Describe everything the matching result depends on.

    The description covers the matcher version, each active dataset (with
    its content hash) and every stored match rule, active or not.
    """
    dataset_entries = []
    for dataset_id, kind, source_id in active_datasets:
        dataset_entries.append(
            {
                "id": int(dataset_id),
                "kind": str(kind),
                "source_id": int(source_id),
                "sha256": _dataset_sha(session, int(dataset_id)),
            }
        )

    stored_rules = session.scalars(select(MatchRule).order_by(MatchRule.id)).all()
    rule_entries = []
    for stored in stored_rules:
        rule_entries.append(
            {
                "id": int(stored.id),
                "type": stored.rule_type,
                "active": bool(stored.active),
                "selector": stored.selector_json,
                "action": stored.action_json,
            }
        )

    return {"version": MATCHER_VERSION, "active_datasets": dataset_entries, "manual_rules": rule_entries}
def _dataset_sha(session: Session, dataset_id: int) -> str | None:
    """Return the stored ``sha256`` of a dataset, or None if the dataset does not exist."""
    record = session.get(Dataset, dataset_id)
    if record is None:
        return None
    return record.sha256
def _match_route_batch(
    *,
    session: Session,
    routes: list[GtfsRoute],
    osm_dataset_ids: list[int],
    dataset_source_ids: dict[int, int],
    manual_rules: list[_ManualMatchRule],
    osm_scope_bbox: tuple[float | None, float | None, float | None, float | None],
    scoped_counts: dict[str, int],
) -> dict[str, int]:
    """Compute and persist the match decision for one batch of GTFS routes.

    For every route, in order:
      1. classify it against the OSM scope bbox and count the result in
         ``scoped_counts`` (mutated in place);
      2. if an accepted manual rule applies and its OSM feature can be
         resolved, record a manual match with confidence 100;
      3. otherwise, if the route lies outside the OSM scope, record it as missing;
      4. otherwise score every candidate OSM route (skipping manually rejected
         pairs) and keep the highest-scoring one.

    The resulting payloads are applied to the match rows and the session is flushed.

    Returns:
        Counters ``matches``, ``missing`` and ``manual`` for this batch plus the
        ``created`` / ``updated`` / ``unchanged`` counters of the row update.
    """
    matches = 0
    missing = 0
    manual = 0
    payloads: list[_RouteMatchPayload] = []
    for route in routes:
        scope = route_match_scope(route, osm_scope_bbox)
        scoped_counts[scope] = scoped_counts.get(scope, 0) + 1
        route_source_id = dataset_source_ids.get(route.dataset_id)
        # Manual accept rules take precedence over scope checks and auto scoring.
        accepted_rule = _accepted_rule_for_route(manual_rules, route, route_source_id)
        if accepted_rule is not None:
            accepted_feature = _feature_for_rule_from_storage(session, osm_dataset_ids, dataset_source_ids, accepted_rule)
            if accepted_feature is not None:
                accepted_feature = ensure_main_osm_feature(session, accepted_feature)
                payloads.append(
                    _RouteMatchPayload(
                        gtfs_route_id=route.id,
                        osm_feature_id=accepted_feature.id,
                        confidence=100.0,
                        status="accepted",
                        rule_source="manual",
                        reasons_json=json.dumps(
                            {"manual_rule_id": accepted_rule.id, "manual": "accepted_match", "scope": scope},
                            separators=(",", ":"),
                        ),
                    )
                )
                matches += 1
                manual += 1
                # NOTE(review): when the rule's OSM feature cannot be resolved the
                # route falls through to automatic matching - confirm this is intended.
                continue
        if scope == "outside_osm_scope":
            # No candidate lookup is attempted for routes outside the OSM scope.
            missing += 1
            payloads.append(
                _RouteMatchPayload(
                    gtfs_route_id=route.id,
                    osm_feature_id=None,
                    confidence=0.0,
                    status="missing",
                    rule_source="auto",
                    reasons_json=json.dumps(
                        {
                            "reason": "outside loaded OSM route scope",
                            "scope": scope,
                        },
                        separators=(",", ":"),
                    ),
                )
            )
            continue
        best_feature: Optional[OsmFeature] = None
        best_score = 0.0
        best_reasons: dict[str, object] = {}
        # The route geometry is profiled once and reused for every candidate.
        route_geometry_profile = _geometry_profile(route.geometry_geojson)
        for feature in candidate_osm_routes_for_route(session, route, osm_dataset_ids):
            if _is_rejected_pair(manual_rules, route, route_source_id, feature, dataset_source_ids.get(feature.dataset_id)):
                continue
            feature_geometry_profile = _geometry_profile(feature.geometry_geojson)
            score, reasons = score_route_pair(
                route,
                feature,
                route_geometry_profile=route_geometry_profile,
                feature_geometry_profile=feature_geometry_profile,
            )
            # Strict comparison: on equal scores the earlier candidate is kept.
            if score > best_score:
                best_score = score
                best_feature = feature
                best_reasons = reasons
        status = _status_from_score(best_score)
        if best_feature is None or status == "missing":
            missing += 1
            best_feature_id = None
            # Keep the rejected best candidate's details for later inspection.
            best_reasons = {
                "reason": "no OSM candidate above threshold",
                "scope": scope,
                "best_score_below_threshold": round(float(best_score), 2) if best_score else 0,
                "best_reasons": best_reasons,
            }
            best_score = 0
        else:
            matches += 1
            best_feature = ensure_main_osm_feature(session, best_feature)
            best_feature_id = best_feature.id
            best_reasons["scope"] = scope
        payloads.append(
            _RouteMatchPayload(
                gtfs_route_id=route.id,
                osm_feature_id=best_feature_id,
                confidence=round(float(best_score), 2),
                status=status,
                rule_source="auto",
                reasons_json=json.dumps(best_reasons, separators=(",", ":")),
            )
        )
    changes = _apply_route_match_payloads(session, payloads)
    session.flush()
    return {"matches": matches, "missing": missing, "manual": manual, **changes}
def _apply_route_match_payloads(session: Session, payloads: list[_RouteMatchPayload]) -> dict[str, int]:
if not payloads:
return {"created": 0, "updated": 0, "unchanged": 0}
route_ids = [payload.gtfs_route_id for payload in payloads]
existing_rows = session.scalars(
select(RouteMatch).where(RouteMatch.gtfs_route_id.in_(route_ids)).order_by(RouteMatch.gtfs_route_id, RouteMatch.id)
).all()
existing_by_route: dict[int, list[RouteMatch]] = {}
for row in existing_rows:
existing_by_route.setdefault(row.gtfs_route_id, []).append(row)
created = 0
updated = 0
unchanged = 0
duplicate_ids: list[int] = []
now = datetime.now(timezone.utc)
for payload in payloads:
existing = existing_by_route.get(payload.gtfs_route_id, [])
current = _preferred_existing_match(existing)
if current is None:
session.add(
RouteMatch(
gtfs_route_id=payload.gtfs_route_id,
osm_feature_id=payload.osm_feature_id,
confidence=payload.confidence,
status=payload.status,
rule_source=payload.rule_source,
reasons_json=payload.reasons_json,
)
)
created += 1
continue
duplicate_ids.extend(row.id for row in existing if row.id != current.id)
if _route_match_payload_equal(current, payload):
unchanged += 1
continue
current.osm_feature_id = payload.osm_feature_id
current.confidence = payload.confidence
current.status = payload.status
current.rule_source = payload.rule_source
current.reasons_json = payload.reasons_json
current.updated_at = now
updated += 1
for chunk in _chunks_int(duplicate_ids, 1000):
session.execute(delete(RouteMatch).where(RouteMatch.id.in_(chunk)))
return {"created": created, "updated": updated, "unchanged": unchanged}
def _preferred_existing_match(rows: list[RouteMatch]) -> RouteMatch | None:
if not rows:
return None
return next((row for row in rows if row.rule_source == "manual"), rows[0])
def _route_match_payload_equal(row: RouteMatch, payload: _RouteMatchPayload) -> bool:
return (
row.osm_feature_id == payload.osm_feature_id
and round(float(row.confidence or 0), 2) == round(float(payload.confidence or 0), 2)
and row.status == payload.status
and row.rule_source == payload.rule_source
and (row.reasons_json or None) == (payload.reasons_json or None)
)
def _build_osm_route_index(osm_routes: list[OsmFeature]) -> _OsmRouteIndex:
    """Group OSM routes by normalized ref, route key and mode for quick lookup.

    A feature is left out of a grouping when it has no value for that attribute.
    """
    index = _OsmRouteIndex(all_routes=osm_routes, by_ref={}, by_route_key={}, by_mode={})
    for osm_route in osm_routes:
        groupings = (
            (index.by_ref, norm_ref(osm_route.ref or "")),
            (index.by_route_key, osm_route.route_key),
            (index.by_mode, osm_route.mode),
        )
        for table, key in groupings:
            if key:
                table.setdefault(key, []).append(osm_route)
    return index
def _candidate_osm_routes(route: GtfsRoute, index: _OsmRouteIndex) -> list[OsmFeature]:
    """Select candidate OSM routes for a GTFS route from an in-memory index.

    Lookup by normalized ref and by route key is tried first; if that
    yields anything, those candidates are returned (spatially ranked and
    capped at ``MAX_EXACT_REF_CANDIDATES``). Otherwise a fallback set is
    built from routes of a compatible mode whose bbox overlaps the route's
    bbox or whose bbox center lies within 0.12 of it.

    Args:
        route: GTFS route to find candidates for.
        index: Index built by ``_build_osm_route_index``.

    Returns:
        Mode-compatible candidates, spatially ranked and capped.
    """
    selected: list[OsmFeature] = []
    seen: set[int] = set()

    def add(features: list[OsmFeature], *, require_compatible_mode: bool = True) -> None:
        # Append unseen features, optionally filtering out incompatible modes.
        for feature in features:
            if feature.id in seen:
                continue
            if require_compatible_mode and not _mode_compatible(route.mode or "", feature.mode or ""):
                continue
            seen.add(feature.id)
            selected.append(feature)

    route_ref = norm_ref(route.short_name or route.route_id)
    if route_ref:
        add(index.by_ref.get(route_ref, []))
    if route.route_key:
        add(index.by_route_key.get(route.route_key, []))
    if selected:
        return _spatially_ranked_candidates(route, selected, MAX_EXACT_REF_CANDIDATES)

    compatible_modes = MODE_GROUPS.get(route.mode or "", {route.mode or ""})
    mode_candidates: list[OsmFeature] = []
    # Iterate the modes in sorted order: set iteration order is not stable
    # across interpreter runs, and the order of mode_candidates decides which
    # equally-distant features survive the stable sort + truncation below.
    for mode in sorted(compatible_modes):
        if mode:
            mode_candidates.extend(index.by_mode.get(mode, []))
    if not mode_candidates:
        # No route of a compatible mode is indexed: consider every route and
        # let the mode filter in add() decide.
        mode_candidates = index.all_routes

    gtfs_bbox = (route.min_lon, route.min_lat, route.max_lon, route.max_lat)
    near_candidates: list[tuple[float, OsmFeature]] = []
    for feature in mode_candidates:
        osm_bbox = (feature.min_lon, feature.min_lat, feature.max_lon, feature.max_lat)
        distance = approx_bbox_center_distance_deg(gtfs_bbox, osm_bbox)
        if bbox_overlap(gtfs_bbox, osm_bbox):
            # Overlapping boxes always rank first.
            near_candidates.append((0.0, feature))
        elif distance is not None and distance < 0.12:
            near_candidates.append((distance, feature))

    fallback_limit = MAX_FALLBACK_CANDIDATES_WITH_REF if route_ref else MAX_FALLBACK_CANDIDATES_WITHOUT_REF
    fallback = [feature for _, feature in sorted(near_candidates, key=lambda item: item[0])[:fallback_limit]]
    if not fallback:
        # Nothing nearby: fall back to the first routes of a compatible mode.
        fallback = mode_candidates[:fallback_limit]
    add(fallback)
    return _spatially_ranked_candidates(route, selected, fallback_limit)
def candidate_osm_routes_for_route(session: Session, route: GtfsRoute, osm_dataset_ids: list[int]) -> list[OsmFeature]:
    """Query storage for OSM routes that could correspond to a GTFS route.

    Lookups run from most to least specific:
      1. by the route's route key and by its normalized ref (both passed as
         ``route_key``), keeping only mode-compatible features;
      2. by compatible mode inside the route's bbox padded by 0.10;
      3. by compatible mode without any spatial restriction.

    The first stage that yields features decides the result, which is
    spatially ranked and capped.
    """
    if not osm_dataset_ids:
        return []

    picked: list[OsmFeature] = []
    picked_keys: set[tuple[int, str, str]] = set()
    gtfs_mode = route.mode or ""

    def collect(features: list[OsmFeature], *, require_compatible_mode: bool = True) -> None:
        # Keep each (dataset, osm type, osm id) once; optionally filter by mode.
        for candidate in features:
            identity = (candidate.dataset_id, candidate.osm_type, candidate.osm_id)
            if identity in picked_keys:
                continue
            if require_compatible_mode and not _mode_compatible(gtfs_mode, candidate.mode or ""):
                continue
            picked_keys.add(identity)
            picked.append(candidate)

    route_ref = norm_ref(route.short_name or route.route_id)
    lookup_keys: list[str] = []
    for lookup_key in (route.route_key, route_ref):
        if lookup_key and lookup_key not in lookup_keys:
            lookup_keys.append(lookup_key)
    for lookup_key in lookup_keys:
        collect(query_osm_features(session, osm_dataset_ids, kinds=["route"], route_key=lookup_key))
    if picked:
        return _spatially_ranked_candidates(route, picked, MAX_EXACT_REF_CANDIDATES)

    route_bbox = (route.min_lon, route.min_lat, route.max_lon, route.max_lat)
    mode_filter = sorted(MODE_GROUPS.get(gtfs_mode, {gtfs_mode}) - {""}) or None
    if all(edge is not None for edge in route_bbox):
        # The query already filters by mode, so the mode check is skipped.
        collect(
            query_osm_features(
                session,
                osm_dataset_ids,
                kinds=["route"],
                modes=mode_filter,
                bbox=_expanded_bbox(route_bbox, 0.10),
                limit=MAX_FALLBACK_CANDIDATES_WITHOUT_REF * 4,
            ),
            require_compatible_mode=False,
        )
    if not picked:
        collect(
            query_osm_features(
                session,
                osm_dataset_ids,
                kinds=["route"],
                modes=mode_filter,
                limit=MAX_FALLBACK_CANDIDATES_WITHOUT_REF,
            ),
            require_compatible_mode=False,
        )

    if route_ref:
        fallback_limit = MAX_FALLBACK_CANDIDATES_WITH_REF
    else:
        fallback_limit = MAX_FALLBACK_CANDIDATES_WITHOUT_REF
    return _spatially_ranked_candidates(route, picked, fallback_limit)
def score_route_pair(
    route: GtfsRoute,
    feature: OsmFeature,
    route_geometry_profile: _GeometryProfile | None = None,
    feature_geometry_profile: _GeometryProfile | None = None,
) -> tuple[float, dict[str, object]]:
    """Score how well an OSM route feature matches a GTFS route.

    The score is built from additive components (mode, ref, name, operator,
    bbox / bbox-center distance, geometry overlap, route key), then raised to
    fixed floors for strong identity signals and finally capped for exact-ref
    pairs that are far apart without geometric overlap.

    Args:
        route: GTFS route being matched.
        feature: Candidate OSM route feature.
        route_geometry_profile: Optional pre-computed profile of the route geometry.
        feature_geometry_profile: Optional pre-computed profile of the feature
            geometry. Geometry is parsed from the GeoJSON text unless both
            profiles are supplied.

    Returns:
        ``(score, reasons)`` where ``score`` is at most 100.0 and ``reasons``
        records which components contributed. Incompatible modes yield
        ``(0.0, reasons)`` immediately.
    """
    score = 0.0
    reasons: dict[str, object] = {}
    gtfs_mode = route.mode or ""
    osm_mode = feature.mode or ""
    # _mode_compatible treats a missing mode on either side as compatible, so
    # past this point the modes are always compatible.
    if _mode_compatible(gtfs_mode, osm_mode):
        score += 25
        reasons["mode"] = "compatible"
    elif gtfs_mode and osm_mode:
        reasons["mode"] = f"mismatch: {gtfs_mode} != {osm_mode}"
        return 0.0, reasons
    gtfs_ref = norm_ref(route.short_name or route.route_id)
    osm_ref = norm_ref(feature.ref or "")
    if gtfs_ref and osm_ref:
        if gtfs_ref == osm_ref:
            score += 25
            reasons["ref"] = "exact"
        elif gtfs_ref in osm_ref or osm_ref in gtfs_ref:
            # One ref is a substring of the other.
            score += 15
            reasons["ref"] = "partial"
    # Name similarity contributes up to 20 points.
    gtfs_name = norm_text(" ".join(v for v in [route.long_name, route.short_name, route.route_id] if v))
    osm_name = norm_text(" ".join(v for v in [feature.name, feature.ref] if v))
    name_similarity = _ratio(gtfs_name, osm_name)
    score += 20 * name_similarity
    reasons["name_similarity"] = round(name_similarity, 3)
    # Operator similarity contributes up to 15 points; OSM operator and network are combined.
    gtfs_operator = norm_text(route.operator_name or "")
    osm_operator = norm_text(" ".join(v for v in [feature.operator, feature.network] if v))
    operator_similarity = _ratio(gtfs_operator, osm_operator) if gtfs_operator and osm_operator else 0
    score += 15 * operator_similarity
    reasons["operator_similarity"] = round(operator_similarity, 3)
    gtfs_bbox = (route.min_lon, route.min_lat, route.max_lon, route.max_lat)
    osm_bbox = (feature.min_lon, feature.min_lat, feature.max_lon, feature.max_lat)
    # center_distance is only computed when the bounding boxes do not overlap.
    center_distance = None
    if bbox_overlap(gtfs_bbox, osm_bbox):
        score += 14
        reasons["bbox"] = "overlap"
        if gtfs_ref and osm_ref and gtfs_ref == osm_ref and _mode_compatible(gtfs_mode, osm_mode):
            score += 8
            reasons["line_identity"] = "exact_ref_mode_bbox_overlap"
    else:
        center_distance = approx_bbox_center_distance_deg(gtfs_bbox, osm_bbox)
        if center_distance is not None:
            if center_distance < 0.01:
                score += 12
            elif center_distance < 0.03:
                score += 8
            elif center_distance < 0.08:
                score += 4
            elif gtfs_ref and osm_ref and gtfs_ref == osm_ref and center_distance > OSM_SCOPE_NEAR_DISTANCE_DEG:
                # Same ref but far away: likely a different line that reuses the number.
                score -= 8
                reasons["spatial_penalty"] = "exact_ref_far_bbox_center"
            # NOTE(review): nesting reconstructed - recorded for every known distance; confirm.
            reasons["bbox_center_distance_deg"] = round(center_distance, 5)
    geometry_metrics = (
        _geometry_match_metrics_from_profiles(route_geometry_profile, feature_geometry_profile)
        if route_geometry_profile is not None and feature_geometry_profile is not None
        else _geometry_match_metrics(route.geometry_geojson, feature.geometry_geojson)
    )
    if geometry_metrics is not None:
        reasons["geometry"] = geometry_metrics
        # GTFS-on-OSM coverage is weighted more heavily than OSM-on-GTFS coverage.
        geometry_score = 34 * float(geometry_metrics["gtfs_on_osm_ratio"]) + 8 * float(geometry_metrics["osm_on_gtfs_ratio"])
        if float(geometry_metrics["endpoint_distance_deg"]) < GEOMETRY_PROXIMITY_DEG * 2:
            geometry_score += 6
        if float(geometry_metrics["length_ratio"]) < 0.35 or float(geometry_metrics["length_ratio"]) > 2.8:
            geometry_score -= 8
            reasons["geometry_length"] = "implausible_ratio"
        # The geometry component is clamped to the range 0..42.
        score += max(0.0, min(42.0, geometry_score))
    # Extra small boost for same normalized route key.
    if route.route_key and feature.route_key and route.route_key == feature.route_key:
        score += 5
        reasons["route_key"] = "same"
    # Score floors: strong identity signals lift the score to a fixed minimum.
    if gtfs_ref and osm_ref and gtfs_ref == osm_ref and _mode_compatible(gtfs_mode, osm_mode):
        if bbox_overlap(gtfs_bbox, osm_bbox):
            score = max(score, 88.0)
            reasons["strong_identity"] = "exact_ref_mode_bbox_overlap"
        elif center_distance is not None and center_distance < 0.02:
            score = max(score, 82.0)
            reasons["strong_identity"] = "exact_ref_mode_near_bbox_center"
    if route.route_key and feature.route_key and route.route_key == feature.route_key and _mode_compatible(gtfs_mode, osm_mode):
        if bbox_overlap(gtfs_bbox, osm_bbox):
            score = max(score, 86.0)
            # setdefault keeps an identity reason already recorded above.
            reasons.setdefault("strong_identity", "same_route_key_mode_bbox_overlap")
    if geometry_metrics is not None:
        gtfs_on_osm = float(geometry_metrics["gtfs_on_osm_ratio"])
        endpoint_distance = float(geometry_metrics["endpoint_distance_deg"])
        if gtfs_on_osm >= 0.82 and endpoint_distance < GEOMETRY_PROXIMITY_DEG * 3 and _mode_compatible(gtfs_mode, osm_mode):
            if gtfs_ref and osm_ref and gtfs_ref == osm_ref:
                score = max(score, 90.0)
                reasons["strong_identity"] = "exact_ref_mode_geometry_overlap"
            elif gtfs_ref and osm_ref and (gtfs_ref in osm_ref or osm_ref in gtfs_ref):
                score = max(score, 82.0)
                reasons["strong_identity"] = "partial_ref_mode_geometry_overlap"
    # Score cap: an exact ref match that is far away and has little or no
    # geometric overlap is limited to 58.
    if (
        gtfs_ref
        and osm_ref
        and gtfs_ref == osm_ref
        and center_distance is not None
        and center_distance > OSM_SCOPE_NEAR_DISTANCE_DEG
        and not bbox_overlap(gtfs_bbox, osm_bbox)
        and (
            geometry_metrics is None
            or float(geometry_metrics.get("gtfs_on_osm_ratio", 0.0)) < 0.25
        )
    ):
        score = min(score, 58.0)
        reasons["spatial_cap"] = "exact_ref_far_without_geometry_overlap"
    return min(score, 100.0), reasons
def route_match_scope(route: GtfsRoute, osm_scope_bbox: tuple[float | None, float | None, float | None, float | None]) -> str:
    """Classify a GTFS route relative to the bbox covered by the loaded OSM routes.

    Returns:
        ``"unknown_scope"`` when either bbox has a missing edge,
        ``"in_osm_scope"`` when the boxes overlap, ``"near_osm_scope"`` when
        the bbox centers are closer than ``OSM_SCOPE_NEAR_DISTANCE_DEG``,
        otherwise ``"outside_osm_scope"``.
    """
    route_bbox = (route.min_lon, route.min_lat, route.max_lon, route.max_lat)
    for edge in (*route_bbox, *osm_scope_bbox):
        if edge is None:
            return "unknown_scope"

    if bbox_overlap(route_bbox, osm_scope_bbox):
        return "in_osm_scope"

    center_gap = approx_bbox_center_distance_deg(route_bbox, osm_scope_bbox)
    is_near = center_gap is not None and center_gap < OSM_SCOPE_NEAR_DISTANCE_DEG
    return "near_osm_scope" if is_near else "outside_osm_scope"
def _combined_bbox(features: list[OsmFeature]) -> tuple[float | None, float | None, float | None, float | None]:
boxes = [
(feature.min_lon, feature.min_lat, feature.max_lon, feature.max_lat)
for feature in features
if None not in (feature.min_lon, feature.min_lat, feature.max_lon, feature.max_lat)
]
if not boxes:
return (None, None, None, None)
return (
min(float(box[0]) for box in boxes if box[0] is not None),
min(float(box[1]) for box in boxes if box[1] is not None),
max(float(box[2]) for box in boxes if box[2] is not None),
max(float(box[3]) for box in boxes if box[3] is not None),
)
def _spatially_ranked_candidates(route: GtfsRoute, candidates: list[OsmFeature], limit: int) -> list[OsmFeature]:
return [
feature
for _, feature in sorted(
((_spatial_rank(route, feature), feature) for feature in candidates),
key=lambda item: item[0],
)[: max(1, limit)]
]
def _spatial_rank(route: GtfsRoute, feature: OsmFeature) -> tuple[int, float, str]:
    """Build the sort key that orders a candidate by closeness to the route.

    The key is ``(bucket, distance, osm_id)`` with bucket 0 for overlapping
    bboxes, 1 for centers closer than ``OSM_SCOPE_NEAR_DISTANCE_DEG``, 2 for
    any other known distance and 3 when the distance is unknown (in which
    case 999.0 is used as the distance).
    """
    route_bbox = (route.min_lon, route.min_lat, route.max_lon, route.max_lat)
    feature_bbox = (feature.min_lon, feature.min_lat, feature.max_lon, feature.max_lat)
    center_gap = approx_bbox_center_distance_deg(route_bbox, feature_bbox)

    if bbox_overlap(route_bbox, feature_bbox):
        bucket = 0
    elif center_gap is None:
        bucket = 3
    elif center_gap < OSM_SCOPE_NEAR_DISTANCE_DEG:
        bucket = 1
    else:
        bucket = 2

    sortable_gap = 999.0 if center_gap is None else center_gap
    return (bucket, sortable_gap, feature.osm_id)
def _expanded_bbox(
bbox: tuple[float | None, float | None, float | None, float | None],
padding: float,
) -> tuple[float, float, float, float] | None:
min_lon, min_lat, max_lon, max_lat = bbox
if None in (min_lon, min_lat, max_lon, max_lat):
return None
return (float(min_lon) - padding, float(min_lat) - padding, float(max_lon) + padding, float(max_lat) + padding)
def _chunks_int(values: list[int], size: int) -> list[list[int]]:
return [values[start : start + size] for start in range(0, len(values), max(1, size))]
def _emit_progress(
progress_callback: ProgressCallback | None,
event_type: str,
message: str,
progress_current: int | None,
progress_total: int | None,
metadata: dict[str, object] | None = None,
) -> None:
if progress_callback is not None:
progress_callback(event_type, message, progress_current, progress_total, metadata)
def _geometry_match_metrics(route_geometry: str | None, feature_geometry: str | None) -> dict[str, float] | None:
    """Compute geometry match metrics straight from two GeoJSON geometry texts.

    Returns None when either geometry cannot be turned into a profile.
    """
    return _geometry_match_metrics_from_profiles(
        _geometry_profile(route_geometry),
        _geometry_profile(feature_geometry),
    )
def _geometry_profile(geometry_text: str | None) -> _GeometryProfile | None:
if not geometry_text:
return None
try:
geom = shape(json.loads(geometry_text))
except Exception: # noqa: BLE001 - malformed geometry should not break matching
return None
lines = _iter_lines(geom)
if not lines:
return None
length = sum(line.length for line in lines)
if length == 0:
return None
sample_points = _sample_line_points(lines, GEOMETRY_SAMPLE_POINTS)
if not sample_points:
return None
return _GeometryProfile(geom=geom, lines=lines, length=length, sample_points=sample_points)
def _geometry_match_metrics_from_profiles(
route_profile: _GeometryProfile | None, feature_profile: _GeometryProfile | None
) -> dict[str, float] | None:
if route_profile is None or feature_profile is None:
return None
gtfs_on_osm = _near_point_ratio(route_profile.sample_points, feature_profile.geom, GEOMETRY_PROXIMITY_DEG)
osm_on_gtfs = _near_point_ratio(feature_profile.sample_points, route_profile.geom, GEOMETRY_PROXIMITY_DEG)
endpoint_distance = _endpoint_distance(route_profile.lines, feature_profile.geom)
length_ratio = route_profile.length / feature_profile.length if feature_profile.length else 0.0
return {
"gtfs_on_osm_ratio": round(gtfs_on_osm, 3),
"osm_on_gtfs_ratio": round(osm_on_gtfs, 3),
"endpoint_distance_deg": round(endpoint_distance, 6),
"length_ratio": round(length_ratio, 3),
}
def _iter_lines(geom) -> list[LineString]:
    """Return the line parts of a geometry.

    A LineString yields itself, a MultiLineString yields its parts of
    non-zero length, and any other geometry yields an empty list.
    """
    if isinstance(geom, LineString):
        return [geom]
    if not isinstance(geom, MultiLineString):
        return []
    parts = []
    for part in geom.geoms:
        if isinstance(part, LineString) and part.length > 0:
            parts.append(part)
    return parts
def _sample_line_points(lines: list[LineString], count: int) -> list[Point]:
total_length = sum(line.length for line in lines)
if total_length == 0:
return []
points = []
for index in range(count):
target = total_length * (index / max(1, count - 1))
traversed = 0.0
for line in lines:
next_traversed = traversed + line.length
if target <= next_traversed or line is lines[-1]:
points.append(line.interpolate(max(0.0, min(line.length, target - traversed))))
break
traversed = next_traversed
return points
def _near_point_ratio(points: list[Point], geom, max_distance: float) -> float:
if not points:
return 0.0
near = sum(1 for point in points if geom.distance(point) <= max_distance)
return near / len(points)
def _endpoint_distance(gtfs_lines: list[LineString], osm_geom) -> float:
longest = max(gtfs_lines, key=lambda line: line.length)
coords = list(longest.coords)
if len(coords) < 2:
return 999.0
return osm_geom.distance(Point(coords[0])) + osm_geom.distance(Point(coords[-1]))
def _manual_match_rules(session: Session) -> list[_ManualMatchRule]:
    """Load the active accept/reject match rules, newest rule id first.

    Rules whose selector or action is not usable are skipped instead of
    aborting the run: this covers malformed JSON as well as valid JSON that
    is not an object (for example ``[]`` or ``null``).

    The GTFS selector is ``selector["gtfs"]`` when that is an object,
    otherwise the whole selector. The OSM selector is taken from
    ``action["osm"]``, then ``selector["osm"]``, then a bare
    ``selector["osm_feature_id"]``; it is None when none of these apply.
    """
    rules = session.scalars(
        select(MatchRule)
        .where(MatchRule.active.is_(True), MatchRule.rule_type.in_(["accept_match", "reject_match"]))
        .order_by(MatchRule.id.desc())
    ).all()
    parsed: list[_ManualMatchRule] = []
    for rule in rules:
        try:
            selector = json.loads(rule.selector_json or "{}")
            action = json.loads(rule.action_json or "{}")
        except json.JSONDecodeError:
            continue
        if not isinstance(selector, dict) or not isinstance(action, dict):
            # json.loads can also return lists, strings, numbers or None; calling
            # .get() on those would raise and take the whole matching run down.
            continue

        gtfs_selector = selector.get("gtfs")
        route_selector = gtfs_selector if isinstance(gtfs_selector, dict) else selector

        action_osm = action.get("osm")
        osm_selector = action_osm if isinstance(action_osm, dict) else selector.get("osm")
        if not isinstance(osm_selector, dict) and selector.get("osm_feature_id") is not None:
            # Legacy form: the OSM feature is referenced by primary key only.
            osm_selector = {"osm_feature_id": selector.get("osm_feature_id")}

        default_status = "accepted" if rule.rule_type == "accept_match" else "rejected"
        status = str(action.get("status") or default_status)
        parsed.append(
            _ManualMatchRule(
                id=rule.id,
                rule_type=rule.rule_type,
                route_selector=route_selector,
                osm_selector=osm_selector if isinstance(osm_selector, dict) else None,
                status=status,
            )
        )
    return parsed
def _accepted_rule_for_route(
rules: list[_ManualMatchRule], route: GtfsRoute, route_source_id: int | None
) -> _ManualMatchRule | None:
for rule in rules:
if rule.rule_type != "accept_match":
continue
if rule.status != "accepted":
continue
if _route_matches_selector(route, route_source_id, rule.route_selector):
return rule
return None
def _feature_for_rule(
features: list[OsmFeature], dataset_source_ids: dict[int, int], rule: _ManualMatchRule
) -> OsmFeature | None:
if not rule.osm_selector:
return None
for feature in features:
if _feature_matches_selector(feature, dataset_source_ids.get(feature.dataset_id), rule.osm_selector):
return feature
return None
def _feature_for_rule_from_storage(
session: Session,
osm_dataset_ids: list[int],
dataset_source_ids: dict[int, int],
rule: _ManualMatchRule,
) -> OsmFeature | None:
if not rule.osm_selector:
return None
selector = rule.osm_selector
legacy_id = _safe_int(selector.get("osm_feature_id"))
if legacy_id is not None:
feature = session.get(OsmFeature, legacy_id)
if feature is not None and _feature_matches_selector(feature, dataset_source_ids.get(feature.dataset_id), selector):
return feature
scoped_dataset_ids = list(osm_dataset_ids)
expected_source = selector.get("source_id")
if expected_source is not None:
expected_source_id = _safe_int(expected_source)
if expected_source_id is not None:
scoped_dataset_ids = [
dataset_id
for dataset_id in scoped_dataset_ids
if dataset_source_ids.get(dataset_id) == expected_source_id
]
dataset_id = _safe_int(selector.get("dataset_id"))
if dataset_id is not None:
scoped_dataset_ids = [value for value in scoped_dataset_ids if value == dataset_id]
if not scoped_dataset_ids:
return None
features: list[OsmFeature] = []
osm_type = selector.get("osm_type")
osm_id = selector.get("osm_id")
if osm_type and osm_id:
features = query_osm_features(
session,
scoped_dataset_ids,
kinds=["route"],
osm_type=str(osm_type),
osm_id=str(osm_id),
limit=10,
)
if not features:
route_key = selector.get("route_key")
if route_key:
features = query_osm_features(session, scoped_dataset_ids, kinds=["route"], route_key=str(route_key))
if not features:
ref = norm_ref(selector.get("ref"))
if ref:
features = query_osm_features(session, scoped_dataset_ids, kinds=["route"], route_key=ref)
for feature in features:
if _feature_matches_selector(feature, dataset_source_ids.get(feature.dataset_id), selector):
return feature
return None
def _is_rejected_pair(
rules: list[_ManualMatchRule],
route: GtfsRoute,
route_source_id: int | None,
feature: OsmFeature,
feature_source_id: int | None,
) -> bool:
for rule in rules:
if rule.rule_type != "reject_match":
continue
if not _route_matches_selector(route, route_source_id, rule.route_selector):
continue
if rule.osm_selector and _feature_matches_selector(feature, feature_source_id, rule.osm_selector):
return True
return False
def _route_matches_selector(route: GtfsRoute, source_id: int | None, selector: dict[str, object]) -> bool:
legacy_id = selector.get("gtfs_route_id")
if legacy_id is not None and _safe_int(legacy_id) == route.id:
return True
expected_source = selector.get("source_id")
if expected_source is not None and source_id is not None and _safe_int(expected_source) != source_id:
return False
route_id = selector.get("route_id")
if route_id and str(route_id) == route.route_id:
return True
route_key = selector.get("route_key")
if route_key and route.route_key and str(route_key) == route.route_key:
return True
ref = norm_ref(selector.get("ref"))
mode = selector.get("mode")
if ref and ref == norm_ref(route.short_name or route.route_id):
return not mode or _mode_compatible(str(mode), route.mode or "")
return False
def _feature_matches_selector(feature: OsmFeature, source_id: int | None, selector: dict[str, object]) -> bool:
legacy_id = selector.get("osm_feature_id")
if legacy_id is not None and _safe_int(legacy_id) == feature.id:
return True
expected_source = selector.get("source_id")
if expected_source is not None and source_id is not None and _safe_int(expected_source) != source_id:
return False
osm_type = selector.get("osm_type")
osm_id = selector.get("osm_id")
if osm_type and osm_id and str(osm_type) == feature.osm_type and str(osm_id) == feature.osm_id:
return True
route_key = selector.get("route_key")
if route_key and feature.route_key and str(route_key) == feature.route_key:
return True
ref = norm_ref(selector.get("ref"))
mode = selector.get("mode")
if ref and ref == norm_ref(feature.ref or ""):
return not mode or _mode_compatible(str(mode), feature.mode or "")
return False
def _safe_int(value: object) -> int | None:
try:
return int(value) # type: ignore[arg-type]
except (TypeError, ValueError):
return None
def _mode_compatible(gtfs_mode: str, osm_mode: str) -> bool:
if not gtfs_mode or not osm_mode:
return True
if gtfs_mode == osm_mode:
return True
return osm_mode in MODE_GROUPS.get(gtfs_mode, {gtfs_mode}) or gtfs_mode in MODE_GROUPS.get(osm_mode, {osm_mode})
def _ratio(a: str, b: str) -> float:
if not a or not b:
return 0.0
if a == b:
return 1.0
token_ratio = _token_similarity(a, b)
if a in b or b in a:
token_ratio = max(token_ratio, 0.82)
return token_ratio
def _token_similarity(a: str, b: str) -> float:
left = set(a.split())
right = set(b.split())
if not left or not right:
return 0.0
return len(left & right) / len(left | right)
def _status_from_score(score: float) -> str:
if score >= 85:
return "matched"
if score >= 65:
return "probable"
if score >= 40:
return "weak"
return "missing"