# Files
# meubility-workbench/app/pipeline/route_layer.py
# 2026-07-01 23:29:51 +02:00
#
# 1904 lines
# 75 KiB
# Python

from __future__ import annotations
import json
from dataclasses import dataclass
from typing import Callable, Iterable
from shapely.geometry import LineString, MultiLineString, Point, shape
from shapely.ops import linemerge
from sqlalchemy import and_, delete, func, or_, select, text
from sqlalchemy.orm import Session
from app.config import settings
from app.gtfs_storage import all_scheduled_stop_ids, stop_times_by_trip as storage_stop_times_by_trip
from app.models import (
CanonicalStop,
CanonicalStopLink,
Dataset,
GtfsRoute,
GtfsRoutePatternLink,
GtfsShape,
GtfsStop,
GtfsStopTime,
GtfsTrip,
GtfsTripRoutePatternLink,
MatchRule,
OsmFeature,
RouteMatch,
RoutePattern,
RoutePatternStop,
)
from app.osm_classification import infer_osm_route_scope_from_tags
from app.osm_storage import ensure_main_osm_feature, features_are_sidecar, osm_feature_count, query_osm_features
from app.pipeline.matcher import MODE_GROUPS
from app.pipeline.state import STAGE_BUILD_ROUTE_LAYER, dependency_hash, finish_pipeline_run, start_pipeline_run
from app.pipeline.utils import bbox_overlap, geometry_json_and_bbox, norm_ref, norm_text
from app.spatial import analyze_postgresql_tables, refresh_postgis_geometries, using_postgresql
# Version tag passed to start_pipeline_run and embedded in the dependency
# payload built by _route_layer_dependency.
ROUTE_LAYER_VERSION = "route_layer_v3_stop_alias_matching"
# Shape key substituted wherever a trip or pattern seed has no shape_id.
GTFS_ROUTE_PATTERN_NULL_SHAPE = "__route__"
# Radii below are in degrees; they are compared against planar lon/lat
# distances (shapely Point.distance / PostGIS ST_Distance on geometry).
# OSM stop -> canonical stop: base radius, and the wider radius allowed when
# the names are similar enough.
OSM_STOP_LINK_RADIUS_DEG = 0.0018
OSM_STOP_NAME_LINK_RADIUS_DEG = 0.0032
# GTFS stop group -> existing canonical stop: the stronger the name match, the
# larger the permitted radius. The exact-name radius also sets the grid cell
# size in _gtfs_grid_cell.
GTFS_STOP_EXACT_NAME_LINK_RADIUS_DEG = 0.006
GTFS_STOP_NAME_LINK_RADIUS_DEG = 0.0032
GTFS_STOP_PARTIAL_NAME_LINK_RADIUS_DEG = 0.0014
# NOTE(review): not referenced in the visible part of this file; presumably the
# minimum score for accepting an OSM route candidate - confirm at the use site.
OSM_ROUTE_MIN_SCORE = 62.0
# Callback signature: (event_type, message, progress_current, progress_total, metadata).
ProgressCallback = Callable[[str, str, int | None, int | None, dict[str, object] | None], None]
# Tokens dropped from normalized stop names before names are compared
# (see _stop_match_tokens).
STOP_MATCH_NOISE_TOKENS = {
    "s",
    "u",
    "bhf",
    "station",
    "train",
    "flixtrain",
    "flixbus",
}
@dataclass(frozen=True)
class _GtfsPatternSeed:
    """Immutable input record from which one GTFS route pattern is built."""

    # GTFS route the seed belongs to.
    route: GtfsRoute
    # Shape id of the seed; None is mapped to GTFS_ROUTE_PATTERN_NULL_SHAPE by consumers.
    shape_id: str | None
    # presumably a representative trip for the route/shape - confirm in _gtfs_pattern_seeds
    trip_id: str | None
    # Serialized geometry; seeds without it are skipped by _build_route_patterns.
    geometry_text: str | None
    # Label describing where the geometry came from (values not visible in this excerpt).
    geometry_source: str
    # Four-component bounding box; component order is not visible here - TODO confirm.
    bbox: tuple[float | None, float | None, float | None, float | None]
    # Optional reference points of the geometry (presumably first/last/central point).
    start_point: Point | None
    end_point: Point | None
    center_point: Point | None
@dataclass(frozen=True)
class _OsmRouteCandidate:
    """Immutable OSM route feature prepared for matching against GTFS pattern seeds."""

    # Underlying OSM feature row.
    feature: OsmFeature
    # Parsed geometry object (presumably a shapely geometry - confirm in _osm_route_candidates).
    geom: object
    # Serialized form of the geometry.
    geometry_text: str
    # Four-component bounding box; copied onto new RoutePattern rows by _build_route_patterns.
    bbox: tuple[float | None, float | None, float | None, float | None]
    # Normalized route reference used as part of the lookup key in _OsmRouteCandidateIndex.
    ref_key: str
    # Transport mode of the candidate, if known.
    mode: str | None
@dataclass(frozen=True)
class _OsmRouteCandidateIndex:
    """Lookup tables over the OSM route candidates."""

    # Candidates grouped by a (str, str) key; per the field name presumably (ref_key, mode) - TODO confirm order.
    by_ref_mode: dict[tuple[str, str], list[_OsmRouteCandidate]]
    # Candidates keyed by an integer id (presumably the OSM feature id - confirm).
    by_id: dict[int, _OsmRouteCandidate]
@dataclass(frozen=True)
class _RouteLayerOverrides:
    """Manual accept/reject decisions applied when choosing OSM candidates for GTFS routes."""

    # GTFS route id -> single accepted id (presumably an OSM feature id - confirm in _route_layer_overrides).
    accepted_by_gtfs_route_id: dict[int, int]
    # GTFS route id -> set of rejected ids (presumably OSM feature ids - confirm).
    rejected_by_gtfs_route_id: dict[int, set[int]]
@dataclass(frozen=True)
class _CanonicalStopLinkOverrides:
    """Manual link/unlink rules for GTFS stops, keyed by (dataset_id, stop_id).

    Built by _canonical_stop_link_overrides, which keeps a key in at most one
    of the two mappings (the later rule wins).
    """

    # (dataset_id, stop_id) -> parsed action of a "link_canonical_stop" rule.
    link_by_stop: dict[tuple[int, str], dict[str, object]]
    # (dataset_id, stop_id) -> parsed action of an "unlink_canonical_stop" rule.
    unlink_by_stop: dict[tuple[int, str], dict[str, object]]
@dataclass(frozen=True)
class _PatternBuildItem:
    """Immutable work item pairing a GTFS seed with the route pattern chosen for it."""

    # Seed the item was built from.
    seed: _GtfsPatternSeed
    # Route pattern row the seed resolves to.
    pattern: RoutePattern
    # Confidence of the seed -> pattern association.
    confidence: float
    # Origin of the pattern; _build_route_patterns uses "osm" for OSM-backed patterns.
    source_kind: str
    # Pattern status; _build_route_patterns uses "active" for OSM-backed patterns.
    status: str
    # Free-form explanation of how the pattern was chosen.
    reasons: dict[str, object]
def rebuild_route_layer(
    session: Session,
    *,
    progress_callback: ProgressCallback | None = None,
    commit_between_steps: bool = True,
) -> dict[str, object]:
    """Rebuild the visual route layer from active GTFS and OSM datasets.

    Runs the stage as a tracked pipeline run: clear derived link tables
    (route patterns themselves are preserved), build canonical GTFS stops,
    link OSM stops, then build route patterns. After each step the session is
    committed when ``commit_between_steps`` is true and flushed otherwise, and
    a progress event is emitted when ``progress_callback`` is given.

    Returns the summary dict that is also stored as the run's outputs.
    """
    total_steps = 4

    def checkpoint() -> None:
        # Persist the step just finished (commit or flush, per caller's choice).
        _commit_or_flush(session, commit_between_steps)

    def report(event_type: str, message: str, step: int, metadata: dict[str, object] | None) -> None:
        _emit_progress(progress_callback, event_type, message, step, total_steps, metadata)

    dependency = _route_layer_dependency(session)
    run = start_pipeline_run(
        session,
        stage=STAGE_BUILD_ROUTE_LAYER,
        version=ROUTE_LAYER_VERSION,
        dependency_hash_value=dependency_hash(dependency),
        inputs=dependency,
    )
    checkpoint()
    report("route_layer_started", "Rebuilding visual route layer.", 0, {"version": ROUTE_LAYER_VERSION})

    _clear_route_layer(session, preserve_route_patterns=True)
    checkpoint()
    report("route_layer_cleared", "Cleared derived route-layer link tables.", 1, None)

    canonical_result = _build_canonical_stops(session)
    checkpoint()
    report("route_layer_canonical_stops", "Built canonical GTFS stops.", 2, canonical_result)

    osm_link_result = _link_osm_stops(
        session,
        progress_callback=progress_callback,
        commit_batches=commit_between_steps,
    )
    checkpoint()
    report("route_layer_osm_stop_links", "Linked OSM visual stops to canonical stops.", 3, osm_link_result)

    pattern_result = _build_route_patterns(session, progress_callback=progress_callback)
    checkpoint()

    total_links = canonical_result["canonical_stop_links"] + osm_link_result["canonical_stop_links"]
    result: dict[str, object] = {
        "version": ROUTE_LAYER_VERSION,
        "canonical_stops": canonical_result["canonical_stops"],
        "canonical_stop_links": total_links,
        "route_patterns": pattern_result["route_patterns"],
    }
    # Optional counters default to 0 when the pattern builder did not report them.
    for counter in (
        "route_patterns_created",
        "route_patterns_updated",
        "route_patterns_reused",
        "route_patterns_removed",
    ):
        result[counter] = pattern_result.get(counter, 0)
    # Mandatory counters: a missing key is an error.
    for counter in (
        "route_pattern_links",
        "trip_pattern_links",
        "route_pattern_stops",
        "gtfs_proposed_patterns",
    ):
        result[counter] = pattern_result[counter]

    finish_pipeline_run(session, run, outputs=result)
    checkpoint()
    report("route_layer_completed", "Visual route layer rebuilt.", 4, result)
    return result
def _route_layer_dependency(session: Session) -> dict[str, object]:
    """Describe the inputs the route layer depends on.

    The payload contains the layer version, one record per active dataset
    (ordered by kind, then id) and a count plus hash signature over all route
    matches (ordered by id).
    """
    dataset_query = (
        select(Dataset)
        .where(Dataset.is_active.is_(True))
        .order_by(Dataset.kind, Dataset.id)
    )
    active_datasets = []
    for dataset in session.scalars(dataset_query).all():
        active_datasets.append(
            {
                "id": int(dataset.id),
                "source_id": int(dataset.source_id),
                "kind": dataset.kind,
                "sha256": dataset.sha256,
                "metadata": dataset.metadata_json,
            }
        )

    match_query = select(
        RouteMatch.id,
        RouteMatch.gtfs_route_id,
        RouteMatch.osm_feature_id,
        RouteMatch.status,
        RouteMatch.updated_at,
    ).order_by(RouteMatch.id)
    match_rows = session.execute(match_query).all()

    signature_rows = []
    for row in match_rows:
        osm_feature_id = int(row.osm_feature_id) if row.osm_feature_id is not None else None
        updated_at = row.updated_at.isoformat() if row.updated_at else None
        signature_rows.append(
            [int(row.id), int(row.gtfs_route_id), osm_feature_id, row.status, updated_at]
        )
    match_signature = dependency_hash(signature_rows)

    return {
        "version": ROUTE_LAYER_VERSION,
        "active_datasets": active_datasets,
        "route_matches": {"count": len(match_rows), "signature": match_signature},
    }
def logical_stop_group_id(stop: GtfsStop) -> str:
    """Return the id of the logical stop group a GTFS stop belongs to.

    A non-empty ``parent_station`` wins. Otherwise the part of ``stop_id``
    before the first ``"::"`` is used, which is the whole ``stop_id`` when it
    contains no ``"::"``.
    """
    parent = stop.parent_station
    if parent:
        return parent
    group_id, _separator, _suffix = stop.stop_id.partition("::")
    return group_id
def route_pattern_for_trip(session: Session, route: GtfsRoute, trip: GtfsTrip) -> RoutePattern | None:
    """Resolve the route pattern that serves a GTFS trip.

    A trip-level link (matched on the trip's dataset and trip id) takes
    precedence. Without one, the route-level link for the trip's shape is
    used, where a missing ``shape_id`` maps to GTFS_ROUTE_PATTERN_NULL_SHAPE.
    When several links exist, the one with the highest confidence (lowest id
    on ties) is chosen.

    Returns None when neither kind of link exists; otherwise whatever
    ``session.get`` returns for the linked pattern id.
    """
    # Only the best-ranked row is consumed, so let the database stop after one
    # row instead of returning every link for the trip.
    trip_link = session.scalar(
        select(GtfsTripRoutePatternLink)
        .where(
            GtfsTripRoutePatternLink.dataset_id == trip.dataset_id,
            GtfsTripRoutePatternLink.trip_id == trip.trip_id,
        )
        .order_by(GtfsTripRoutePatternLink.confidence.desc(), GtfsTripRoutePatternLink.id)
        .limit(1)
    )
    if trip_link is not None:
        return session.get(RoutePattern, trip_link.route_pattern_id)

    shape_key = trip.shape_id or GTFS_ROUTE_PATTERN_NULL_SHAPE
    link = session.scalar(
        select(GtfsRoutePatternLink)
        .where(
            GtfsRoutePatternLink.dataset_id == route.dataset_id,
            GtfsRoutePatternLink.route_id == route.route_id,
            GtfsRoutePatternLink.shape_id == shape_key,
        )
        .order_by(GtfsRoutePatternLink.confidence.desc(), GtfsRoutePatternLink.id)
        .limit(1)
    )
    if link is None:
        return None
    return session.get(RoutePattern, link.route_pattern_id)
def canonical_stop_for_gtfs_stop(session: Session, stop: GtfsStop) -> CanonicalStop | None:
    """Return the canonical stop a GTFS stop row is linked to, or None if unlinked.

    The link is looked up by object type "gtfs_stop", the stop's dataset and
    the stop's primary key.
    """
    link_query = select(CanonicalStopLink).where(
        CanonicalStopLink.object_type == "gtfs_stop",
        CanonicalStopLink.dataset_id == stop.dataset_id,
        CanonicalStopLink.object_id == stop.id,
    )
    found = session.scalar(link_query)
    return None if found is None else session.get(CanonicalStop, found.canonical_stop_id)
def gtfs_stop_ids_for_canonical_stop(session: Session, canonical_stop_id: int, dataset_id: int) -> tuple[str, ...]:
    """List the GTFS stop ids linked to a canonical stop within one dataset.

    Ids come from the links' ``external_id`` column, ordered by role and then
    by external id, and are returned as strings.
    """
    id_query = (
        select(CanonicalStopLink.external_id)
        .where(
            CanonicalStopLink.canonical_stop_id == canonical_stop_id,
            CanonicalStopLink.object_type == "gtfs_stop",
            CanonicalStopLink.dataset_id == dataset_id,
        )
        .order_by(CanonicalStopLink.role, CanonicalStopLink.external_id)
    )
    external_ids = session.scalars(id_query).all()
    return tuple(map(str, external_ids))
def _clear_route_layer(session: Session, *, preserve_route_patterns: bool = False) -> None:
    """Delete all rows of the derived route-layer tables, then flush.

    Deletion order is: trip/route pattern links and pattern stops, then
    (unless ``preserve_route_patterns`` is true) route patterns, then
    canonical stop links and canonical stops.
    """
    link_models = (GtfsTripRoutePatternLink, GtfsRoutePatternLink, RoutePatternStop)
    pattern_models = () if preserve_route_patterns else (RoutePattern,)
    stop_models = (CanonicalStopLink, CanonicalStop)
    for model in (*link_models, *pattern_models, *stop_models):
        session.execute(delete(model))
    session.flush()
def _commit_or_flush(session: Session, should_commit: bool) -> None:
if should_commit:
session.commit()
else:
session.flush()
def _emit_progress(
progress_callback: ProgressCallback | None,
event_type: str,
message: str,
progress_current: int | None,
progress_total: int | None,
metadata: dict[str, object] | None = None,
) -> None:
if progress_callback is not None:
progress_callback(event_type, message, progress_current, progress_total, metadata)
def _build_canonical_stops(session: Session) -> dict[str, int]:
    """Create canonical stops and timetable links for all scheduled GTFS stops.

    Scheduled stops of the active GTFS datasets are grouped by
    ``(dataset_id, logical_stop_group_id(stop))``. Each group either joins a
    nearby, similarly named canonical stop created earlier in this run or gets
    a new one. Every stop of the group is then linked to the group's canonical
    stop, unless a manual link/unlink rule redirects it.

    Returns ``{"canonical_stops": ..., "canonical_stop_links": ...}``; both
    counts are 0 when there is no active GTFS dataset.
    """
    active_gtfs_dataset_ids = _active_dataset_ids(session, "gtfs")
    if not active_gtfs_dataset_ids:
        return {"canonical_stops": 0, "canonical_stop_links": 0}
    overrides = _canonical_stop_link_overrides(session)
    source_id_by_dataset = {
        int(dataset_id): int(source_id)
        for dataset_id, source_id in session.execute(
            select(Dataset.id, Dataset.source_id).where(Dataset.id.in_(active_gtfs_dataset_ids))
        ).all()
    }
    stops = _scheduled_gtfs_stops(session, active_gtfs_dataset_ids)
    groups: dict[tuple[int, str], list[GtfsStop]] = {}
    for stop in stops:
        groups.setdefault((stop.dataset_id, logical_stop_group_id(stop)), []).append(stop)

    canonical_by_group: dict[tuple[int, str], CanonicalStop] = {}
    link_quality_by_group: dict[tuple[int, str], tuple[float, float | None]] = {}
    canonical_grid: dict[tuple[int, int], list[CanonicalStop]] = {}
    for (dataset_id, group_id), group_stops in groups.items():
        display = _best_display_stop(group_id, group_stops)
        canonical, distance_m, confidence = _nearest_gtfs_canonical_from_grid(canonical_grid, display)
        if canonical is None:
            canonical = CanonicalStop(
                stop_key=f"gtfs:{dataset_id}:{group_id}",
                name=display.name or group_id,
                normalized_name=norm_text(display.name or group_id),
                lat=display.lat,
                lon=display.lon,
                metadata_json=json.dumps({"dataset_id": dataset_id, "group_id": group_id}, separators=(",", ":")),
            )
            _add_canonical_to_gtfs_grid(canonical_grid, canonical)
            # A group that founds its own canonical stop is a perfect match.
            confidence = 1.0
            distance_m = None
        else:
            _maybe_update_canonical_stop_display(canonical, display)
        canonical_by_group[(dataset_id, group_id)] = canonical
        link_quality_by_group[(dataset_id, group_id)] = (confidence, distance_m)

    # Several groups can share one canonical stop; add each object only once.
    unique_canonicals = list(dict.fromkeys(canonical_by_group.values()))
    session.add_all(unique_canonicals)
    # Flush so that canonical.id is populated before links reference it.
    session.flush()
    canonical_by_stop_key = {canonical.stop_key: canonical for canonical in unique_canonicals}
    canonical_by_gtfs_stop = {
        (stop.dataset_id, stop.stop_id): canonical_by_group[(dataset_id, group_id)]
        for (dataset_id, group_id), group_stops in groups.items()
        for stop in group_stops
    }

    link_objects: list[CanonicalStopLink] = []
    for (dataset_id, group_id), group_stops in groups.items():
        group_canonical = canonical_by_group[(dataset_id, group_id)]
        confidence, distance_m = link_quality_by_group[(dataset_id, group_id)]
        for stop in group_stops:
            canonical = _canonical_for_gtfs_stop_link(
                session=session,
                stop=stop,
                group_canonical=group_canonical,
                overrides=overrides,
                canonical_by_stop_key=canonical_by_stop_key,
                canonical_by_gtfs_stop=canonical_by_gtfs_stop,
                source_id_by_dataset=source_id_by_dataset,
            )
            # "No parent" is tested by truthiness, exactly as in
            # logical_stop_group_id, so an empty-string parent_station is
            # treated like None and the group root is labelled "parent".
            is_group_root = stop.stop_id == group_id and not stop.parent_station
            role = "parent" if is_group_root else "platform"
            metadata = None
            if (stop.dataset_id, stop.stop_id) in overrides.link_by_stop:
                metadata = json.dumps({"manual_rule": "link_canonical_stop"}, separators=(",", ":"))
            elif (stop.dataset_id, stop.stop_id) in overrides.unlink_by_stop:
                metadata = json.dumps({"manual_rule": "unlink_canonical_stop"}, separators=(",", ":"))
            link_objects.append(
                CanonicalStopLink(
                    canonical_stop_id=canonical.id,
                    layer="timetable",
                    object_type="gtfs_stop",
                    dataset_id=stop.dataset_id,
                    object_id=stop.id,
                    external_id=stop.stop_id,
                    role=role,
                    confidence=confidence,
                    distance_m=distance_m,
                    metadata_json=metadata,
                )
            )
            # Write links in batches of 5000 to bound memory use.
            if len(link_objects) >= 5000:
                session.bulk_save_objects(link_objects)
                link_objects.clear()
    if link_objects:
        session.bulk_save_objects(link_objects)
    session.flush()
    refresh_postgis_geometries(session, tables=["canonical_stops"])
    analyze_postgresql_tables(session, ["canonical_stops", "canonical_stop_links"])
    linked_stop_count = sum(len(group_stops) for group_stops in groups.values())
    # canonical_by_stop_key also holds stops created by manual rules, because
    # _manual_canonical_stop registers them in this mapping.
    return {"canonical_stops": len(canonical_by_stop_key), "canonical_stop_links": linked_stop_count}
def _scheduled_gtfs_stops(session: Session, active_gtfs_dataset_ids: list[int]) -> list[GtfsStop]:
    """Return the GTFS stops of the given datasets that have at least one stop time.

    Results are ordered by dataset id, stop name and stop id. On PostgreSQL
    the filter is an EXISTS subquery; otherwise the scheduled stop ids are
    loaded per dataset and the filter is applied in Python.
    """
    ordering = (GtfsStop.dataset_id, GtfsStop.name, GtfsStop.stop_id)
    in_active_datasets = GtfsStop.dataset_id.in_(active_gtfs_dataset_ids)

    if using_postgresql():
        has_stop_time = (
            select(GtfsStopTime.id)
            .where(
                GtfsStopTime.dataset_id == GtfsStop.dataset_id,
                GtfsStopTime.stop_id == GtfsStop.stop_id,
            )
            .limit(1)
            .exists()
        )
        scheduled_query = select(GtfsStop).where(in_active_datasets, has_stop_time).order_by(*ordering)
        return session.scalars(scheduled_query).all()

    scheduled_by_dataset = {}
    for dataset_id in active_gtfs_dataset_ids:
        scheduled_by_dataset[dataset_id] = all_scheduled_stop_ids(session, dataset_id)

    all_stops_query = select(GtfsStop).where(in_active_datasets).order_by(*ordering)
    scheduled: list[GtfsStop] = []
    for stop in session.scalars(all_stops_query).all():
        if stop.stop_id in scheduled_by_dataset.get(stop.dataset_id, set()):
            scheduled.append(stop)
    return scheduled
def _canonical_for_gtfs_stop_link(
    *,
    session: Session,
    stop: GtfsStop,
    group_canonical: CanonicalStop,
    overrides: _CanonicalStopLinkOverrides,
    canonical_by_stop_key: dict[str, CanonicalStop],
    canonical_by_gtfs_stop: dict[tuple[int, str], CanonicalStop],
    source_id_by_dataset: dict[int, int],
) -> CanonicalStop:
    """Pick the canonical stop a GTFS stop should be linked to.

    Without a manual rule for the stop this is ``group_canonical``. An unlink
    rule moves the stop to a manual canonical stop (the rule's
    ``target_stop_key`` or a key derived from the stop). A link rule resolves,
    in order: a canonical stop referenced through ``target_gtfs_stops``, an
    existing canonical stop with the target key, or a newly created manual
    canonical stop with that key.
    """
    override_key = (stop.dataset_id, stop.stop_id)

    unlink_action = overrides.unlink_by_stop.get(override_key)
    if unlink_action is not None:
        default_key = f"manual:gtfs_stop:{stop.dataset_id}:{stop.stop_id}"
        return _manual_canonical_stop(
            session=session,
            stop=stop,
            stop_key=str(unlink_action.get("target_stop_key") or default_key),
            action=unlink_action,
            canonical_by_stop_key=canonical_by_stop_key,
            metadata_type="manual_unlink",
        )

    link_action = overrides.link_by_stop.get(override_key)
    if link_action is None:
        return group_canonical

    referenced = _canonical_from_target_gtfs_refs(link_action, canonical_by_gtfs_stop, source_id_by_dataset)
    if referenced is not None:
        return referenced

    target_stop_key = str(link_action.get("target_stop_key") or group_canonical.stop_key)
    known = canonical_by_stop_key.get(target_stop_key)
    if known is not None:
        return known
    return _manual_canonical_stop(
        session=session,
        stop=stop,
        stop_key=target_stop_key,
        action=link_action,
        canonical_by_stop_key=canonical_by_stop_key,
        metadata_type="manual_link_target",
    )
def _canonical_from_target_gtfs_refs(
action: dict[str, object],
canonical_by_gtfs_stop: dict[tuple[int, str], CanonicalStop],
source_id_by_dataset: dict[int, int],
) -> CanonicalStop | None:
refs = action.get("target_gtfs_stops")
if not isinstance(refs, list):
return None
for ref in refs:
if not isinstance(ref, dict):
continue
external_id = ref.get("external_id") or ref.get("stop_id")
if not external_id:
continue
source_id = ref.get("source_id")
for (dataset_id, stop_id), canonical in canonical_by_gtfs_stop.items():
if stop_id != str(external_id):
continue
if source_id is not None:
try:
if source_id_by_dataset.get(dataset_id) != int(source_id):
continue
except (TypeError, ValueError):
continue
return canonical
return None
def _manual_canonical_stop(
    *,
    session: Session,
    stop: GtfsStop,
    stop_key: str,
    action: dict[str, object],
    canonical_by_stop_key: dict[str, CanonicalStop],
    metadata_type: str,
) -> CanonicalStop:
    """Return the canonical stop registered under ``stop_key``, creating it if needed.

    A new stop takes name, coordinates and mode from the rule action's
    ``target_*`` fields, falling back to the GTFS stop's own name and
    coordinates. It is added to the session, flushed and registered in
    ``canonical_by_stop_key`` before being returned.
    """
    existing = canonical_by_stop_key.get(stop_key)
    if existing is not None:
        return existing

    display_name = str(action.get("target_name") or stop.name or stop.stop_id)
    mode_text = str(action.get("target_mode") or "")
    provenance = {
        "source": metadata_type,
        "dataset_id": stop.dataset_id,
        "stop_id": stop.stop_id,
    }
    created = CanonicalStop(
        stop_key=stop_key,
        name=display_name,
        normalized_name=norm_text(display_name),
        lat=_float_or_default(action.get("target_lat"), stop.lat),
        lon=_float_or_default(action.get("target_lon"), stop.lon),
        mode=mode_text or None,
        metadata_json=json.dumps(provenance, separators=(",", ":")),
    )
    session.add(created)
    # Flush so the new row gets its id before any link refers to it.
    session.flush()
    canonical_by_stop_key[stop_key] = created
    return created
def _canonical_stop_link_overrides(session: Session) -> _CanonicalStopLinkOverrides:
    """Load the active manual link/unlink rules for GTFS stops.

    Rules are applied in id order. For each ``(dataset_id, stop_id)`` key a
    later rule replaces an earlier one, and a key never sits in both
    mappings at once.
    """
    active_dataset_ids_by_source: dict[int, list[int]] = {}
    dataset_rows = session.execute(
        select(Dataset.source_id, Dataset.id).where(Dataset.is_active.is_(True), Dataset.kind == "gtfs")
    ).all()
    for source_id, dataset_id in dataset_rows:
        active_dataset_ids_by_source.setdefault(int(source_id), []).append(int(dataset_id))

    rule_query = (
        select(MatchRule)
        .where(
            MatchRule.active.is_(True),
            MatchRule.rule_type.in_(["link_canonical_stop", "unlink_canonical_stop"]),
        )
        .order_by(MatchRule.id)
    )
    link_by_stop: dict[tuple[int, str], dict[str, object]] = {}
    unlink_by_stop: dict[tuple[int, str], dict[str, object]] = {}

    for rule in session.scalars(rule_query).all():
        selector = _json_dict(rule.selector_json)
        action = _json_dict(rule.action_json)
        stop_keys = _gtfs_stop_rule_keys(selector, active_dataset_ids_by_source)
        if rule.rule_type == "link_canonical_stop":
            winner, loser = link_by_stop, unlink_by_stop
        elif rule.rule_type == "unlink_canonical_stop":
            winner, loser = unlink_by_stop, link_by_stop
        else:
            continue
        for stop_key in stop_keys:
            winner[stop_key] = action
            loser.pop(stop_key, None)

    return _CanonicalStopLinkOverrides(link_by_stop=link_by_stop, unlink_by_stop=unlink_by_stop)
def _gtfs_stop_rule_keys(
selector: dict[str, object],
active_dataset_ids_by_source: dict[int, list[int]],
) -> list[tuple[int, str]]:
if selector.get("object_type") not in {None, "gtfs_stop"}:
return []
nested = selector.get("gtfs_stop")
nested_selector = nested if isinstance(nested, dict) else {}
dataset_id = selector.get("dataset_id", nested_selector.get("dataset_id"))
source_id = selector.get("source_id", nested_selector.get("source_id"))
external_id = selector.get("external_id", nested_selector.get("external_id", nested_selector.get("stop_id")))
if external_id is None:
return []
keys: list[tuple[int, str]] = []
try:
if dataset_id is not None:
keys.append((int(dataset_id), str(external_id)))
if source_id is not None:
keys.extend((active_dataset_id, str(external_id)) for active_dataset_id in active_dataset_ids_by_source.get(int(source_id), []))
except (TypeError, ValueError):
return []
return list(dict.fromkeys(keys))
def _json_dict(value: str | None) -> dict[str, object]:
try:
data = json.loads(value or "{}")
except json.JSONDecodeError:
return {}
return data if isinstance(data, dict) else {}
def _float_or_default(value: object, default: float | None) -> float | None:
if value is None:
return default
try:
return float(value)
except (TypeError, ValueError):
return default
def _nearest_gtfs_canonical_from_grid(
    grid: dict[tuple[int, int], list[CanonicalStop]], display: GtfsStop
) -> tuple[CanonicalStop | None, float | None, float]:
    """Find the best existing canonical stop for a GTFS display stop.

    Candidates come from the stop's grid cell and its eight neighbours. A
    candidate is eligible when it lies within a radius that depends on name
    agreement: exact match key, token overlap >= 0.5, or token overlap
    >= 0.25. The score mixes closeness (62 %) with name overlap (38 %), and
    the first candidate with the strictly highest score wins.

    Returns ``(canonical, distance_m, confidence)``. The distance is the
    degree distance multiplied by 111 320 and rounded to 0.1; the confidence
    is the score rounded to 3 places. Returns ``(None, None, 0.0)`` when the
    stop has no coordinates or nothing is eligible.
    """
    if display.lon is None or display.lat is None:
        return None, None, 0.0

    normalized_name = norm_text(display.name or display.stop_id)
    display_key = _stop_match_key(normalized_name)
    display_point = Point(display.lon, display.lat)
    cell_x, cell_y = _gtfs_grid_cell(display.lon, display.lat)

    winner: CanonicalStop | None = None
    winner_distance_m: float | None = None
    winner_confidence = 0.0
    winner_score = -1.0
    for dx in (-1, 0, 1):
        for dy in (-1, 0, 1):
            for candidate in grid.get((cell_x + dx, cell_y + dy), ()):
                if candidate.lon is None or candidate.lat is None:
                    continue
                distance_deg = Point(candidate.lon, candidate.lat).distance(display_point)
                name_overlap = _stop_name_similarity(normalized_name, candidate.normalized_name)
                exact_name = bool(display_key) and display_key == _stop_match_key(candidate.normalized_name)

                # The better the name agreement, the farther away a candidate may be.
                if exact_name:
                    max_radius = GTFS_STOP_EXACT_NAME_LINK_RADIUS_DEG
                elif name_overlap >= 0.5:
                    max_radius = GTFS_STOP_NAME_LINK_RADIUS_DEG
                elif name_overlap >= 0.25:
                    max_radius = GTFS_STOP_PARTIAL_NAME_LINK_RADIUS_DEG
                else:
                    continue
                if distance_deg > max_radius:
                    continue

                closeness = max(0.0, 1.0 - (distance_deg / max_radius))
                score = closeness * 0.62 + name_overlap * 0.38
                if score > winner_score:
                    winner = candidate
                    winner_distance_m = round(distance_deg * 111_320, 1)
                    winner_confidence = round(score, 3)
                    winner_score = score

    if winner is None:
        return None, None, 0.0
    return winner, winner_distance_m, winner_confidence
def _stop_name_similarity(left: str, right: str) -> float:
    """Return the Jaccard overlap (0.0-1.0) of the match tokens of two stop names.

    The result is 0.0 when either name yields no tokens.
    """
    tokens_a = _stop_match_tokens(left)
    tokens_b = _stop_match_tokens(right)
    if not (tokens_a and tokens_b):
        return 0.0
    if tokens_a == tokens_b:
        return 1.0
    shared = tokens_a & tokens_b
    combined = tokens_a | tokens_b
    return len(shared) / len(combined)
def _stop_match_key(value: str) -> str:
    """Return an order-independent key: the name's match tokens, sorted and space-joined."""
    ordered_tokens = sorted(_stop_match_tokens(value))
    return " ".join(ordered_tokens)
def _stop_match_tokens(value: str) -> set[str]:
    """Reduce a stop name to the token set used for name comparison.

    The name is normalized and split on whitespace, and tokens in
    STOP_MATCH_NOISE_TOKENS are removed. Names that denote a main station
    ("hauptbahnhof", "hbf", or "central"/"main" together with "station" and
    without "bus") have those words replaced by the single token
    "mainstation".
    """
    tokens = set(norm_text(value).split())
    if not tokens:
        return set()

    names_hbf = not tokens.isdisjoint({"hauptbahnhof", "hbf"})
    names_rail_hub = (
        "bus" not in tokens
        and "station" in tokens
        and ("central" in tokens or "main" in tokens)
    )

    cleaned = tokens - STOP_MATCH_NOISE_TOKENS
    if names_hbf or names_rail_hub:
        cleaned -= {"hauptbahnhof", "hbf", "central", "main", "station"}
        cleaned.add("mainstation")
    return cleaned
def _maybe_update_canonical_stop_display(canonical: CanonicalStop, display: GtfsStop) -> None:
    """Adopt the display stop's name when it scores strictly higher than the current one.

    The name (falling back to the stop id) is rated with
    _stop_display_name_quality. On replacement both ``name`` and
    ``normalized_name`` of the canonical stop are updated.
    """
    candidate_name = display.name or display.stop_id
    candidate_quality = _stop_display_name_quality(candidate_name)
    if candidate_quality > _stop_display_name_quality(canonical.name):
        canonical.name = candidate_name
        canonical.normalized_name = norm_text(candidate_name)
def _stop_display_name_quality(name: str | None) -> int:
    """Rate how suitable a stop name is for display; higher is better.

    Names that normalize to nothing score 0. Others start at 100 and are
    adjusted by: -35 for "flixtrain"/"flixbus", -5 for "central" plus
    "station", +8 for "hauptbahnhof"/"hbf", +1 for "berlin".
    """
    normalized = norm_text(name or "")
    if not normalized:
        return 0

    tokens = set(normalized.split())
    adjustments = (
        (not tokens.isdisjoint({"flixtrain", "flixbus"}), -35),
        ({"central", "station"} <= tokens, -5),
        (not tokens.isdisjoint({"hauptbahnhof", "hbf"}), 8),
        ("berlin" in tokens, 1),
    )
    return 100 + sum(delta for applies, delta in adjustments if applies)
def _add_canonical_to_gtfs_grid(grid: dict[tuple[int, int], list[CanonicalStop]], canonical: CanonicalStop) -> None:
    """Register a canonical stop in its grid cell; stops without coordinates are ignored."""
    lon, lat = canonical.lon, canonical.lat
    if lon is None or lat is None:
        return
    cell = _gtfs_grid_cell(lon, lat)
    bucket = grid.get(cell)
    if bucket is None:
        grid[cell] = [canonical]
    else:
        bucket.append(canonical)
def _gtfs_grid_cell(lon: float, lat: float) -> tuple[int, int]:
    """Map a lon/lat position to its grid cell.

    The cell size is GTFS_STOP_EXACT_NAME_LINK_RADIUS_DEG, the largest GTFS
    link radius. ``int()`` truncates toward zero, so the cells touching 0
    on either axis are twice as wide as the others.
    """
    cell_size = GTFS_STOP_EXACT_NAME_LINK_RADIUS_DEG
    column = int(lon / cell_size)
    row = int(lat / cell_size)
    return column, row
def _link_osm_stops(
    session: Session,
    *,
    progress_callback: ProgressCallback | None = None,
    commit_batches: bool = False,
) -> dict[str, int]:
    """Link OSM stop/station/terminal features to canonical stops.

    Returns ``{"canonical_stop_links": n}`` with the number of links written
    (0 when there is no active "osm_geojson" dataset). On PostgreSQL, with no
    sidecar datasets and visual-only stop creation disabled, the work is
    delegated to _link_osm_stops_postgis. Otherwise features are paged per
    dataset and matched against an in-memory grid of canonical stops.
    """
    active_osm_dataset_ids = _active_dataset_ids(session, "osm_geojson")
    if not active_osm_dataset_ids:
        return {"canonical_stop_links": 0}
    sidecar_dataset_ids = {
        dataset.id
        for dataset in session.scalars(select(Dataset).where(Dataset.id.in_(active_osm_dataset_ids))).all()
        if features_are_sidecar(dataset)
    }
    if using_postgresql() and not sidecar_dataset_ids and not settings.osm_sidecar_create_visual_only_stops:
        return _link_osm_stops_postgis(
            session,
            active_osm_dataset_ids,
            progress_callback=progress_callback,
            commit_batches=commit_batches,
        )
    canonical_grid = _canonical_stop_grid(session)
    # Links waiting to be written, and (feature, new canonical stop, point)
    # triples for features that matched no existing canonical stop.
    link_objects: list[CanonicalStopLink] = []
    visual_only: list[tuple[OsmFeature, CanonicalStop, Point]] = []
    link_count = 0
    total_features = sum(osm_feature_count(session, dataset_id, kind=["stop", "station", "terminal"]) for dataset_id in active_osm_dataset_ids)
    processed = 0
    # The page size never drops below 100, whatever the setting says.
    batch_size = max(100, int(settings.route_layer_osm_stop_batch_size))

    def flush_links() -> None:
        """Persist pending visual-only stops and links, then commit or flush."""
        nonlocal link_count
        if visual_only:
            # New canonical stops must be flushed first so canonical.id is set
            # before the links below reference it.
            session.add_all([canonical for _, canonical, _ in visual_only])
            session.flush()
            for feature, canonical, _ in visual_only:
                link_objects.append(
                    CanonicalStopLink(
                        canonical_stop_id=canonical.id,
                        layer="visual",
                        object_type="osm_feature",
                        dataset_id=feature.dataset_id,
                        object_id=feature.id,
                        external_id=f"{feature.osm_type}:{feature.osm_id}",
                        role=feature.kind,
                        confidence=1.0,
                        distance_m=None,
                    )
                )
            visual_only.clear()
        if not link_objects:
            return
        for chunk in _chunks_objects(link_objects, 5000):
            session.bulk_save_objects(chunk)
        link_count += len(link_objects)
        link_objects.clear()
        _commit_or_flush(session, commit_batches)

    for dataset_id in active_osm_dataset_ids:
        offset = 0
        while True:
            # Page through the dataset's stop-like features with limit/offset.
            features = query_osm_features(
                session,
                [dataset_id],
                kinds=["stop", "station", "terminal"],
                geometry_required=True,
                limit=batch_size,
                offset=offset,
            )
            if not features:
                break
            for feature in features:
                point = _representative_point(feature.geometry_geojson)
                if point is None:
                    continue
                canonical, distance_m, confidence = _nearest_canonical_stop_from_grid(canonical_grid, feature, point)
                if canonical is None:
                    # Unmatched sidecar features are skipped unless visual-only
                    # stop creation is enabled.
                    if feature.dataset_id in sidecar_dataset_ids and not settings.osm_sidecar_create_visual_only_stops:
                        continue
                    feature = ensure_main_osm_feature(session, feature)
                    canonical = CanonicalStop(
                        stop_key=f"osm:{feature.dataset_id}:{feature.id}",
                        name=feature.name or feature.ref or f"OSM {feature.osm_type} {feature.osm_id}",
                        normalized_name=norm_text(feature.name or feature.ref or feature.osm_id),
                        lat=point.y,
                        lon=point.x,
                        mode=feature.mode,
                        metadata_json=json.dumps({"osm_feature_id": feature.id}, separators=(",", ":")),
                    )
                    # NOTE(review): new visual-only stops are not added to
                    # canonical_grid, so later features cannot match them.
                    visual_only.append((feature, canonical, point))
                    continue
                feature = ensure_main_osm_feature(session, feature)
                link_objects.append(
                    CanonicalStopLink(
                        canonical_stop_id=canonical.id,
                        layer="visual",
                        object_type="osm_feature",
                        dataset_id=feature.dataset_id,
                        object_id=feature.id,
                        external_id=f"{feature.osm_type}:{feature.osm_id}",
                        role=feature.kind,
                        confidence=confidence,
                        distance_m=distance_m,
                    )
                )
            processed += len(features)
            offset += len(features)
            flush_links()
            _emit_progress(
                progress_callback,
                "route_layer_osm_stop_batch",
                f"Linked OSM stops for dataset #{dataset_id}.",
                processed,
                total_features or None,
                {"dataset_id": dataset_id, "processed": processed, "links": link_count},
            )
            # A short page means the dataset is exhausted.
            if len(features) < batch_size:
                break
    flush_links()
    session.flush()
    return {"canonical_stop_links": link_count}
def _link_osm_stops_postgis(
    session: Session,
    active_osm_dataset_ids: list[int],
    *,
    progress_callback: ProgressCallback | None,
    commit_batches: bool,
) -> dict[str, int]:
    """Link OSM stops to canonical stops with one PostGIS spatial join.

    For every stop/station/terminal feature of the given datasets, the up to
    12 nearest canonical stops within OSM_STOP_NAME_LINK_RADIUS_DEG are
    scored. A candidate is eligible when it lies within
    OSM_STOP_LINK_RADIUS_DEG, or within the wider name radius with a trigram
    name similarity of at least 0.25. The best-ranked eligible candidate is
    inserted as a "visual" link; existing links are left untouched.

    Eligibility is applied before ranking. Ranking first and filtering only
    the top row would leave a feature unlinked whenever a nearer, weakly
    named candidate outside the base radius outranked an eligible one.

    ``active_osm_dataset_ids`` must be non-empty. Returns
    ``{"canonical_stop_links": n}`` where n counts all visual OSM links of
    these datasets after the insert.
    """
    refresh_postgis_geometries(session, tables=["canonical_stops", "osm_features"])
    # Ids are forced through int() before being inlined, so the f-string
    # cannot carry arbitrary SQL.
    dataset_sql = ", ".join(str(int(dataset_id)) for dataset_id in active_osm_dataset_ids)
    total_features = sum(osm_feature_count(session, dataset_id, kind=["stop", "station", "terminal"]) for dataset_id in active_osm_dataset_ids)
    _emit_progress(
        progress_callback,
        "route_layer_osm_stop_postgis_started",
        "Linking OSM stops with PostGIS spatial join.",
        0,
        total_features or None,
        {"datasets": active_osm_dataset_ids},
    )
    params = {
        "base_radius_deg": OSM_STOP_LINK_RADIUS_DEG,
        "name_radius_deg": OSM_STOP_NAME_LINK_RADIUS_DEG,
        "name_threshold": 0.25,
    }
    session.execute(
        text(
            f"""
            WITH scored AS (
                SELECT
                    o.dataset_id,
                    o.id AS osm_feature_id,
                    o.osm_type,
                    o.osm_id,
                    o.kind,
                    c.id AS canonical_stop_id,
                    ST_Distance(o.geom, c.geom) AS distance_deg,
                    ST_Distance(o.geom::geography, c.geom::geography) AS distance_m,
                    GREATEST(
                        similarity(LOWER(COALESCE(o.name, '')), LOWER(COALESCE(c.normalized_name, ''))),
                        similarity(LOWER(COALESCE(o.ref, '')), LOWER(COALESCE(c.normalized_name, '')))
                    ) AS name_score
                FROM osm_features AS o
                JOIN LATERAL (
                    SELECT candidate.*
                    FROM canonical_stops AS candidate
                    WHERE candidate.geom IS NOT NULL
                      AND candidate.geom && ST_Expand(o.geom, :name_radius_deg)
                      AND ST_DWithin(candidate.geom, o.geom, :name_radius_deg)
                    ORDER BY o.geom <-> candidate.geom
                    LIMIT 12
                ) AS c ON TRUE
                WHERE o.dataset_id IN ({dataset_sql})
                  AND o.kind IN ('stop', 'station', 'terminal')
                  AND o.geom IS NOT NULL
            ),
            ranked AS (
                SELECT
                    scored.*,
                    ROW_NUMBER() OVER (
                        PARTITION BY dataset_id, osm_feature_id
                        ORDER BY
                            (distance_deg * 111320.0) - (name_score * 120.0),
                            canonical_stop_id
                    ) AS rn
                FROM scored
                WHERE distance_deg <= :base_radius_deg
                   OR (name_score >= :name_threshold AND distance_deg <= :name_radius_deg)
            )
            INSERT INTO canonical_stop_links
                (canonical_stop_id, layer, object_type, dataset_id, object_id, external_id, role, confidence, distance_m)
            SELECT
                canonical_stop_id,
                'visual',
                'osm_feature',
                dataset_id,
                osm_feature_id,
                osm_type || ':' || osm_id,
                kind,
                ROUND(
                    LEAST(
                        1.0::double precision,
                        GREATEST(
                            0.0::double precision,
                            (
                                1.0
                                - distance_deg
                                / CASE WHEN name_score >= :name_threshold THEN :name_radius_deg ELSE :base_radius_deg END
                            ) * 0.6
                            + name_score * 0.4
                        )
                    )::numeric,
                    3
                )::double precision,
                ROUND(distance_m::numeric, 1)::double precision
            FROM ranked
            WHERE rn = 1
            ON CONFLICT ON CONSTRAINT uq_canonical_stop_link_object DO NOTHING
            """
        ),
        params,
    )
    _commit_or_flush(session, commit_batches)
    # Count rather than trust the INSERT row count: ON CONFLICT DO NOTHING may
    # have skipped rows, and the caller wants the number of links present.
    link_count = int(
        session.scalar(
            text(
                f"""
                SELECT COUNT(*)
                FROM canonical_stop_links
                WHERE layer = 'visual'
                  AND object_type = 'osm_feature'
                  AND dataset_id IN ({dataset_sql})
                """
            )
        )
        or 0
    )
    analyze_postgresql_tables(session, ["canonical_stop_links"])
    _emit_progress(
        progress_callback,
        "route_layer_osm_stop_postgis_completed",
        "Linked OSM stops with PostGIS spatial join.",
        total_features,
        total_features or None,
        {"datasets": active_osm_dataset_ids, "links": link_count},
    )
    return {"canonical_stop_links": link_count}
def _build_route_patterns(
    session: Session,
    *,
    progress_callback: ProgressCallback | None = None,
) -> dict[str, int]:
    """Materialize route patterns, their GTFS links and their stop sequences.

    For every GTFS seed that has geometry, an OSM route candidate is chosen via
    ``_choose_osm_candidate``. A match yields (or refreshes) an ``osm:`` keyed
    pattern; no match yields a ``gtfs:`` keyed proposal that needs visual
    review. Patterns already stored under the same key are updated in place,
    stored patterns whose key was not produced in this run are deleted, and
    link / stop / trip-link rows are written afterwards.

    Args:
        session: SQLAlchemy session used for all reads and writes.
        progress_callback: Optional callback forwarded to ``_emit_progress``.

    Returns:
        Counters keyed ``route_patterns``, ``route_patterns_created``,
        ``route_patterns_updated``, ``route_patterns_reused``,
        ``route_patterns_removed``, ``route_pattern_links``,
        ``trip_pattern_links``, ``route_pattern_stops`` and
        ``gtfs_proposed_patterns``.
    """
    osm_candidates = _osm_route_candidates(session, progress_callback=progress_callback)
    overrides = _route_layer_overrides(session)
    seeds = _gtfs_pattern_seeds(session)
    _emit_progress(
        progress_callback,
        "route_layer_pattern_seeds",
        f"Loaded {len(seeds)} GTFS route-pattern seeds.",
        0,
        len(seeds),
        {"seeds": len(seeds)},
    )
    link_count = 0
    stop_count = 0
    proposed_count = 0
    existing_patterns_by_key = {
        pattern.pattern_key: pattern
        for pattern in session.scalars(select(RoutePattern).order_by(RoutePattern.id)).all()
    }
    patterns_by_key: dict[str, RoutePattern] = {}
    pattern_usage: dict[str, int] = {}
    pattern_confidence_by_key: dict[str, float] = {}
    created_pattern_count = 0
    updated_pattern_keys: set[str] = set()
    pending: list[_PatternBuildItem] = []
    for index, seed in enumerate(seeds, start=1):
        if not seed.geometry_text:
            continue
        chosen, score, reasons = _choose_osm_candidate(seed, osm_candidates, overrides)
        if chosen is not None:
            chosen_feature = ensure_main_osm_feature(session, chosen.feature)
            pattern_key = _osm_pattern_key(chosen_feature)
            confidence = score
            fields = _osm_route_pattern_fields(seed, chosen, chosen_feature)
        else:
            shape_key = seed.shape_id or GTFS_ROUTE_PATTERN_NULL_SHAPE
            pattern_key = f"gtfs:{seed.route.dataset_id}:{seed.route.route_id}:{shape_key}"
            confidence = 0.0
            proposed_count += 1
            fields = _gtfs_proposed_route_pattern_fields(seed, reasons)

        # Only patterns that were stored before this run can count as "updated";
        # a pattern created a few seeds ago and hit again is still just "created".
        is_preexisting = pattern_key in existing_patterns_by_key
        pattern = patterns_by_key.get(pattern_key) or existing_patterns_by_key.get(pattern_key)
        if pattern is None:
            pattern = RoutePattern(pattern_key=pattern_key, confidence=confidence, **fields)
            session.add(pattern)
            created_pattern_count += 1
        elif _update_route_pattern(pattern, **fields) and is_preexisting:
            # NOTE(review): several seeds sharing one pattern each write their own
            # gtfs_route_id, so a shared stored pattern is reported as updated.
            updated_pattern_keys.add(pattern_key)
        patterns_by_key[pattern_key] = pattern

        # A pattern keeps the best confidence seen across all seeds linked to it.
        # GTFS proposals always carry 0.0, so the max is a no-op for them.
        next_confidence = max(pattern_confidence_by_key.get(pattern_key, confidence), confidence)
        pattern_confidence_by_key[pattern_key] = next_confidence
        if is_preexisting and float(pattern.confidence or 0) != float(next_confidence):
            updated_pattern_keys.add(pattern_key)
        pattern.confidence = next_confidence

        link_reasons = _link_reasons(seed, chosen, reasons) if chosen is not None else reasons
        pattern_usage[pattern_key] = pattern_usage.get(pattern_key, 0) + 1
        pending.append(
            _PatternBuildItem(
                seed=seed,
                pattern=pattern,
                confidence=confidence,
                source_kind=fields["source_kind"],
                status=fields["status"],
                reasons=link_reasons,
            )
        )
        if index % 500 == 0:
            session.flush()
            _emit_progress(
                progress_callback,
                "route_layer_pattern_batch",
                f"Built {index}/{len(seeds)} route-pattern candidates.",
                index,
                len(seeds),
                {"patterns": len(patterns_by_key), "links_pending": len(pending), "gtfs_proposed_patterns": proposed_count},
            )
    # Flush so that newly added patterns have primary keys before links reference them.
    session.flush()
    obsolete_pattern_ids = [
        pattern.id
        for pattern_key, pattern in existing_patterns_by_key.items()
        if pattern_key not in patterns_by_key and pattern.id is not None
    ]
    for chunk in _chunks_objects(obsolete_pattern_ids, 1000):
        session.execute(delete(RoutePattern).where(RoutePattern.id.in_(chunk)))
    if obsolete_pattern_ids:
        session.flush()
    refresh_postgis_geometries(session, tables=["route_patterns"])
    analyze_postgresql_tables(session, ["route_patterns"])
    _emit_progress(
        progress_callback,
        "route_layer_patterns_materialized",
        "Materialized route-pattern rows.",
        len(seeds),
        len(seeds),
        {
            "route_patterns": len(patterns_by_key),
            "route_patterns_created": created_pattern_count,
            "route_patterns_updated": len(updated_pattern_keys),
            "route_patterns_reused": max(0, len(patterns_by_key) - created_pattern_count - len(updated_pattern_keys)),
            "route_patterns_removed": len(obsolete_pattern_ids),
            "gtfs_proposed_patterns": proposed_count,
        },
    )
    for pattern_key, count in pattern_usage.items():
        _update_pattern_metadata(patterns_by_key[pattern_key], linked_gtfs_patterns=count)
    link_objects: list[GtfsRoutePatternLink] = []
    for item in pending:
        seed = item.seed
        shape_key = seed.shape_id or GTFS_ROUTE_PATTERN_NULL_SHAPE
        link_objects.append(
            GtfsRoutePatternLink(
                dataset_id=seed.route.dataset_id,
                gtfs_route_id=seed.route.id,
                route_id=seed.route.route_id,
                shape_id=shape_key,
                route_pattern_id=item.pattern.id,
                confidence=item.confidence,
                status=item.status,
                source_kind=item.source_kind,
                reasons_json=json.dumps(item.reasons, separators=(",", ":")),
            )
        )
        link_count += 1
    for chunk in _chunks_objects(link_objects, 5000):
        session.bulk_save_objects(chunk)
    _emit_progress(
        progress_callback,
        "route_layer_pattern_links",
        "Stored GTFS route-pattern links.",
        link_count,
        link_count,
        {"route_pattern_links": link_count},
    )
    stop_times_by_trip = _representative_stop_times(session, pending)
    canonical_lookup = _canonical_link_lookup(session, stop_times_by_trip)
    stop_objects: list[RoutePatternStop] = []
    # The first pending item seen for a pattern supplies that pattern's stop sequence.
    representative_stop_items: dict[int, _PatternBuildItem] = {}
    for item in pending:
        if item.pattern.id is not None:
            representative_stop_items.setdefault(item.pattern.id, item)
    for item in representative_stop_items.values():
        seed = item.seed
        objects = _route_pattern_stop_objects(
            pattern=item.pattern,
            dataset_id=seed.route.dataset_id,
            trip_id=seed.trip_id,
            rows=stop_times_by_trip.get((seed.route.dataset_id, seed.trip_id or ""), []),
            canonical_lookup=canonical_lookup,
        )
        stop_objects.extend(objects)
        stop_count += len(objects)
        if len(stop_objects) >= 10000:
            session.bulk_save_objects(stop_objects)
            stop_objects.clear()
            _emit_progress(
                progress_callback,
                "route_layer_pattern_stop_batch",
                "Stored route-pattern stop links.",
                stop_count,
                None,
                {"route_pattern_stops": stop_count},
            )
    if stop_objects:
        session.bulk_save_objects(stop_objects)
    trip_link_count = _build_trip_route_pattern_links(session)
    session.flush()
    result = {
        "route_patterns": len(patterns_by_key),
        "route_patterns_created": created_pattern_count,
        "route_patterns_updated": len(updated_pattern_keys),
        "route_patterns_reused": max(0, len(patterns_by_key) - created_pattern_count - len(updated_pattern_keys)),
        "route_patterns_removed": len(obsolete_pattern_ids),
        "route_pattern_links": link_count,
        "trip_pattern_links": trip_link_count,
        "route_pattern_stops": stop_count,
        "gtfs_proposed_patterns": proposed_count,
    }
    _emit_progress(
        progress_callback,
        "route_layer_patterns_completed",
        "Route-pattern build completed.",
        len(seeds),
        len(seeds),
        result,
    )
    return result


def _osm_route_pattern_fields(
    seed: _GtfsPatternSeed,
    chosen: _OsmRouteCandidate,
    chosen_feature: OsmFeature,
) -> dict[str, object]:
    """Return RoutePattern column values for a pattern drawn from an OSM route feature.

    OSM attributes take precedence; GTFS route attributes only fill in when the
    OSM value is empty. ``pattern_key`` and ``confidence`` are not included.
    """
    bbox = chosen.bbox
    return {
        "route_ref": chosen_feature.ref or seed.route.short_name or seed.route.route_id,
        "route_name": chosen_feature.name or seed.route.long_name,
        "mode": chosen_feature.mode or seed.route.mode,
        "route_scope": chosen_feature.route_scope
        or infer_osm_route_scope_from_tags(
            chosen_feature.mode,
            chosen_feature.ref,
            chosen_feature.name,
            chosen_feature.network,
            chosen_feature.tags_json,
        ),
        "operator_name": chosen_feature.operator or seed.route.operator_name,
        "source_kind": "osm",
        "status": "active",
        "osm_feature_id": chosen_feature.id,
        "gtfs_route_id": seed.route.id,
        "gtfs_shape_id": None,
        "geometry_geojson": chosen.geometry_text,
        "min_lon": bbox[0],
        "min_lat": bbox[1],
        "max_lon": bbox[2],
        "max_lat": bbox[3],
        "metadata_json": json.dumps(
            {
                "version": ROUTE_LAYER_VERSION,
                "visual_source": "osm_feature",
                "osm_feature_id": chosen_feature.id,
                "osm_type": chosen_feature.osm_type,
                "osm_id": chosen_feature.osm_id,
            },
            separators=(",", ":"),
        ),
    }


def _gtfs_proposed_route_pattern_fields(
    seed: _GtfsPatternSeed,
    reasons: dict[str, object],
) -> dict[str, object]:
    """Return RoutePattern column values for a GTFS-only proposal awaiting visual review.

    ``reasons`` (why no OSM candidate was chosen) is stored in the metadata.
    ``pattern_key`` and ``confidence`` are not included.
    """
    return {
        "route_ref": seed.route.short_name or seed.route.route_id,
        "route_name": seed.route.long_name,
        "mode": seed.route.mode,
        "route_scope": seed.route.route_scope,
        "operator_name": seed.route.operator_name,
        "source_kind": "gtfs_proposed",
        "status": "needs_visual_review",
        "osm_feature_id": None,
        "gtfs_route_id": seed.route.id,
        "gtfs_shape_id": seed.shape_id,
        "geometry_geojson": seed.geometry_text,
        "min_lon": seed.bbox[0],
        "min_lat": seed.bbox[1],
        "max_lon": seed.bbox[2],
        "max_lat": seed.bbox[3],
        "metadata_json": json.dumps(
            {
                "version": ROUTE_LAYER_VERSION,
                "visual_source": "gtfs_shape",
                "gtfs_geometry_source": seed.geometry_source,
                "match_reasons": reasons,
            },
            separators=(",", ":"),
        ),
    }
def _update_route_pattern(pattern: RoutePattern, **fields) -> bool:
changed = False
for key, value in fields.items():
if key == "metadata_json":
value = _route_pattern_metadata_with_existing_derived_values(pattern.metadata_json, value)
if getattr(pattern, key) == value:
continue
setattr(pattern, key, value)
changed = True
return changed
def _route_pattern_metadata_with_existing_derived_values(existing_json: str | None, next_json: str | None) -> str | None:
if not next_json:
return next_json
try:
existing = json.loads(existing_json or "{}")
next_metadata = json.loads(next_json)
except json.JSONDecodeError:
return next_json
if "linked_gtfs_patterns" in existing:
next_metadata["linked_gtfs_patterns"] = existing["linked_gtfs_patterns"]
return json.dumps(next_metadata, separators=(",", ":"))
def _build_trip_route_pattern_links(session: Session) -> int:
    """Rebuild the trip-to-pattern link table from the route/shape-level links.

    Flushes pending work, deletes every ``GtfsTripRoutePatternLink`` row, then
    inserts one row per trip joined to ``gtfs_route_pattern_links`` on dataset,
    route and shape (trips without a shape use the null-shape placeholder).

    Returns:
        The row count reported for the INSERT, or 0 when none is reported.
    """
    session.flush()
    session.execute(delete(GtfsTripRoutePatternLink))
    insert_statement = text(
        """
INSERT INTO gtfs_trip_route_pattern_links
(dataset_id, trip_id, route_id, shape_id, route_pattern_id, source_kind, confidence, status)
SELECT
trips.dataset_id,
trips.trip_id,
trips.route_id,
COALESCE(trips.shape_id, :null_shape) AS shape_id,
links.route_pattern_id,
links.source_kind,
links.confidence,
links.status
FROM gtfs_trips AS trips
JOIN gtfs_route_pattern_links AS links
ON links.dataset_id = trips.dataset_id
AND links.route_id = trips.route_id
AND links.shape_id = COALESCE(trips.shape_id, :null_shape)
"""
    )
    bind_values = {"null_shape": GTFS_ROUTE_PATTERN_NULL_SHAPE}
    outcome = session.execute(insert_statement, bind_values)
    return int(outcome.rowcount or 0)
def _active_dataset_ids(session: Session, kind: str) -> list[int]:
    """Return the ids of all datasets flagged active whose ``kind`` matches."""
    active_of_kind = select(Dataset.id).where(Dataset.is_active.is_(True), Dataset.kind == kind)
    dataset_ids = []
    for record in session.execute(active_of_kind).all():
        dataset_ids.append(record[0])
    return dataset_ids
def _best_display_stop(group_id: str, stops: list[GtfsStop]) -> GtfsStop:
return min(
stops,
key=lambda stop: (
0 if stop.stop_id == group_id and stop.parent_station is None else 1,
0 if stop.parent_station == group_id else 1,
0 if stop.parent_station is not None else 1,
0 if stop.lat is not None and stop.lon is not None else 1,
stop.name or "",
stop.stop_id,
),
)
def _canonical_stop_grid(session: Session) -> dict[tuple[int, int], list[CanonicalStop]]:
    """Bucket every canonical stop that has both coordinates by its ``_grid_cell``."""
    located_stops = select(CanonicalStop).where(CanonicalStop.lon.is_not(None), CanonicalStop.lat.is_not(None))
    buckets: dict[tuple[int, int], list[CanonicalStop]] = {}
    for stop in session.scalars(located_stops).all():
        cell = _grid_cell(stop.lon, stop.lat)
        if cell not in buckets:
            buckets[cell] = []
        buckets[cell].append(stop)
    return buckets
def _nearest_canonical_stop_from_grid(
    grid: dict[tuple[int, int], list[CanonicalStop]], feature: OsmFeature, point: Point
) -> tuple[CanonicalStop | None, float | None, float]:
    """Find the best canonical stop for an OSM feature using the spatial grid.

    Candidates are scored as ``0.6 * distance_score + 0.4 * name_overlap``.
    A candidate may be up to ``OSM_STOP_NAME_LINK_RADIUS_DEG`` away when its
    name overlap is at least 0.25, otherwise at most ``OSM_STOP_LINK_RADIUS_DEG``.

    Args:
        grid: Cell -> stops mapping keyed by ``_grid_cell`` values.
        feature: OSM feature whose name (or ref) is compared to stop names.
        point: Location of the feature in lon/lat degrees.

    Returns:
        ``(stop, distance_m, score)`` for the best candidate, or
        ``(None, None, 0.0)`` when no candidate is within range.
    """
    import math

    # Grid cells are OSM_STOP_LINK_RADIUS_DEG wide, but name-supported links are
    # allowed out to the larger name radius. Scan enough neighbouring cells to
    # cover that radius; a fixed 3x3 window would miss valid candidates.
    reach = max(1, math.ceil(OSM_STOP_NAME_LINK_RADIUS_DEG / OSM_STOP_LINK_RADIUS_DEG))
    offsets = range(-reach, reach + 1)
    cell_x, cell_y = _grid_cell(point.x, point.y)
    candidates = [
        stop
        for dx in offsets
        for dy in offsets
        for stop in grid.get((cell_x + dx, cell_y + dy), [])
    ]
    best = None
    best_score = -1.0
    feature_name = norm_text(feature.name or feature.ref or "")
    for candidate in candidates:
        if candidate.lon is None or candidate.lat is None:
            continue
        distance_deg = Point(candidate.lon, candidate.lat).distance(point)
        distance_m = distance_deg * 111_320
        name_overlap = _name_overlap(feature_name, candidate.normalized_name)
        max_radius = OSM_STOP_NAME_LINK_RADIUS_DEG if name_overlap >= 0.25 else OSM_STOP_LINK_RADIUS_DEG
        if distance_deg > max_radius:
            continue
        distance_score = max(0.0, 1.0 - (distance_deg / max_radius))
        score = distance_score * 0.6 + name_overlap * 0.4
        # Strict comparison: on equal scores the first candidate seen is kept.
        if score > best_score:
            best = (candidate, round(distance_m, 1), round(score, 3))
            best_score = score
    if best is None:
        return None, None, 0.0
    return best
def _grid_cell(lon: float, lat: float) -> tuple[int, int]:
    """Map a lon/lat pair to integer grid indices with cells ``OSM_STOP_LINK_RADIUS_DEG`` wide.

    ``int()`` truncates toward zero, so coordinates just either side of zero
    share cell 0.
    """
    column = int(lon / OSM_STOP_LINK_RADIUS_DEG)
    row = int(lat / OSM_STOP_LINK_RADIUS_DEG)
    return column, row
def _name_overlap(left: str, right: str) -> float:
if not left or not right:
return 0.0
left_tokens = set(left.split())
right_tokens = set(right.split())
if not left_tokens or not right_tokens:
return 0.0
return len(left_tokens & right_tokens) / len(left_tokens | right_tokens)
def _representative_point(geometry_text: str | None) -> Point | None:
if not geometry_text:
return None
try:
geom = shape(json.loads(geometry_text))
except Exception: # noqa: BLE001 - malformed source geometry should not stop extraction
return None
if isinstance(geom, Point):
return geom
return geom.representative_point()
def _osm_route_candidates(
    session: Session,
    *,
    progress_callback: ProgressCallback | None = None,
) -> _OsmRouteCandidateIndex:
    """Load OSM route features of all active OSM datasets into a lookup index.

    Features are read page by page. Each feature with parsable geometry and a
    non-empty normalized ref (falling back to its name) becomes a candidate,
    indexed both by ``(ref_key, mode)`` and by feature id. Progress is emitted
    after every page and once at the end.
    """
    dataset_ids = _active_dataset_ids(session, "osm_geojson")
    if not dataset_ids:
        return _OsmRouteCandidateIndex(by_ref_mode={}, by_id={})

    by_ref_mode: dict[tuple[str, str], list[_OsmRouteCandidate]] = {}
    by_id: dict[int, _OsmRouteCandidate] = {}
    total_features = sum(osm_feature_count(session, dataset_id, kind="route") for dataset_id in dataset_ids)
    processed = 0
    page_size = max(100, int(settings.route_layer_osm_route_batch_size))

    for dataset_id in dataset_ids:
        offset = 0
        last_page_reached = False
        while not last_page_reached:
            page = query_osm_features(
                session,
                [dataset_id],
                kinds=["route"],
                geometry_required=True,
                limit=page_size,
                offset=offset,
            )
            if not page:
                break
            for feature in page:
                try:
                    geometry_text = _normalized_geometry_text(feature.geometry_geojson) or feature.geometry_geojson
                    geom = shape(json.loads(geometry_text))
                except Exception:  # noqa: BLE001 - ignore malformed route geometry
                    continue
                ref_key = norm_ref(feature.ref or feature.name or "")
                if not ref_key:
                    continue
                _, bbox = geometry_json_and_bbox(json.loads(geometry_text))
                candidate = _OsmRouteCandidate(
                    feature=feature,
                    geom=geom,
                    geometry_text=geometry_text,
                    bbox=bbox,
                    ref_key=ref_key,
                    mode=feature.mode,
                )
                bucket_key = (ref_key, feature.mode or "")
                by_ref_mode.setdefault(bucket_key, []).append(candidate)
                by_id[feature.id] = candidate
            page_length = len(page)
            processed += page_length
            offset += page_length
            _emit_progress(
                progress_callback,
                "route_layer_osm_route_batch",
                f"Indexed OSM route candidates for dataset #{dataset_id}.",
                processed,
                total_features or None,
                {"dataset_id": dataset_id, "processed": processed, "candidate_refs": len(by_ref_mode), "candidates": len(by_id)},
            )
            # A short page means the dataset has no further rows.
            last_page_reached = page_length < page_size

    _emit_progress(
        progress_callback,
        "route_layer_osm_routes_indexed",
        "Indexed OSM route candidates.",
        processed,
        total_features or None,
        {"candidate_refs": len(by_ref_mode), "candidates": len(by_id)},
    )
    return _OsmRouteCandidateIndex(by_ref_mode=by_ref_mode, by_id=by_id)
def _route_layer_overrides(session: Session) -> _RouteLayerOverrides:
    """Collect accepted / rejected route matches as per-GTFS-route overrides.

    Matches without an OSM feature id are ignored. For accepted matches a later
    row for the same GTFS route replaces an earlier one; rejected feature ids
    accumulate in a set per GTFS route.
    """
    reviewed = select(RouteMatch).where(RouteMatch.status.in_(["accepted", "rejected"]))
    accepted: dict[int, int] = {}
    rejected: dict[int, set[int]] = {}
    for match in session.scalars(reviewed).all():
        feature_id = match.osm_feature_id
        if feature_id is None:
            continue
        verdict = match.status
        if verdict == "accepted":
            accepted[match.gtfs_route_id] = feature_id
        elif verdict == "rejected":
            rejected.setdefault(match.gtfs_route_id, set()).add(feature_id)
    return _RouteLayerOverrides(accepted_by_gtfs_route_id=accepted, rejected_by_gtfs_route_id=rejected)
def _gtfs_pattern_seeds(session: Session) -> list[_GtfsPatternSeed]:
    """Build one seed per (GTFS route, shape id) pair of the active GTFS datasets.

    Each seed carries the smallest trip id of its group as representative trip.
    Geometry, bbox and endpoints come from the matching ``GtfsShape`` row when
    one exists; the route's own geometry is used as geometry text when no shape
    geometry is available.
    """
    dataset_ids = _active_dataset_ids(session, "gtfs")
    if not dataset_ids:
        return []

    same_route = and_(GtfsTrip.dataset_id == GtfsRoute.dataset_id, GtfsTrip.route_id == GtfsRoute.route_id)
    grouped_trips = (
        select(GtfsRoute, GtfsTrip.shape_id, func.min(GtfsTrip.trip_id))
        .join(GtfsTrip, same_route)
        .where(GtfsRoute.dataset_id.in_(dataset_ids))
        .group_by(GtfsRoute.id, GtfsTrip.shape_id)
        .order_by(GtfsRoute.dataset_id, GtfsRoute.route_id, GtfsTrip.shape_id)
    )
    route_rows = session.execute(grouped_trips).all()

    shape_query = select(
        GtfsShape.dataset_id,
        GtfsShape.shape_id,
        GtfsShape.geometry_geojson,
        GtfsShape.min_lon,
        GtfsShape.min_lat,
        GtfsShape.max_lon,
        GtfsShape.max_lat,
    ).where(GtfsShape.dataset_id.in_(dataset_ids))
    # (dataset id, shape id) -> (geometry text, bbox, coordinate list)
    shape_index = {}
    for dataset_id, shape_id, geometry, min_lon, min_lat, max_lon, max_lat in session.execute(shape_query).all():
        shape_index[(dataset_id, shape_id)] = (
            geometry,
            (min_lon, min_lat, max_lon, max_lat),
            _geometry_points_from_text(geometry),
        )

    seeds = []
    for route, shape_id, trip_id in route_rows:
        geometry_text = None
        geometry_source = "none"
        bbox = (route.min_lon, route.min_lat, route.max_lon, route.max_lat)
        points = _geometry_points_from_text(route.geometry_geojson)
        shape_entry = shape_index.get((route.dataset_id, shape_id)) if shape_id else None
        if shape_entry is not None:
            geometry_text, bbox, points = shape_entry
            geometry_source = "gtfs_shape"
        if not geometry_text and route.geometry_geojson:
            geometry_text = route.geometry_geojson
            geometry_source = "gtfs_route"
        if points:
            start_point = Point(points[0])
            end_point = Point(points[-1])
        else:
            start_point = None
            end_point = None
        center_point = _bbox_center_point(bbox)
        seeds.append(
            _GtfsPatternSeed(
                route=route,
                shape_id=shape_id,
                trip_id=trip_id,
                geometry_text=geometry_text,
                geometry_source=geometry_source,
                bbox=bbox,
                start_point=start_point,
                end_point=end_point,
                center_point=center_point,
            )
        )
    return seeds
def _choose_osm_candidate(
    seed: _GtfsPatternSeed,
    candidate_index: _OsmRouteCandidateIndex,
    overrides: _RouteLayerOverrides,
) -> tuple[_OsmRouteCandidate | None, float, dict[str, object]]:
    """Pick the OSM route candidate that best matches a GTFS seed.

    A manually accepted match that is present in the index wins outright with
    score 100. Otherwise candidates sharing the normalized route ref and a
    compatible mode (minus manually rejected features) are scored on bbox
    overlap, endpoint proximity, direction alignment and centre proximity.

    Returns:
        ``(candidate, score, reasons)``. ``candidate`` is ``None`` when nothing
        qualifies or the best score is below ``OSM_ROUTE_MIN_SCORE``; the score
        is capped at 100 and ``reasons`` explains the outcome.
    """
    if not seed.geometry_text:
        return None, 0.0, {"reason": "no GTFS geometry available"}

    pinned_feature_id = overrides.accepted_by_gtfs_route_id.get(seed.route.id)
    pinned = candidate_index.by_id.get(pinned_feature_id) if pinned_feature_id is not None else None
    if pinned is not None:
        manual_reasons = {
            "manual": "accepted_route_match",
            "osm_feature_id": pinned.feature.id,
            "osm_id": pinned.feature.osm_id,
        }
        return pinned, 100.0, manual_reasons

    route_ref = norm_ref(seed.route.short_name or seed.route.route_id)
    if not route_ref:
        return None, 0.0, {"reason": "no GTFS route ref"}

    rejected_feature_ids = overrides.rejected_by_gtfs_route_id.get(seed.route.id, set())
    pool = []
    for (ref_key, mode), bucket in candidate_index.by_ref_mode.items():
        if ref_key == route_ref and _mode_compatible(seed.route.mode or "", mode):
            for candidate in bucket:
                if candidate.feature.id not in rejected_feature_ids:
                    pool.append(candidate)
    if not pool:
        return None, 0.0, {"reason": "no OSM route candidate with same ref and mode"}

    # (exclusive upper bound in degrees, bonus) tiers; the first matching tier applies.
    endpoint_tiers = ((0.002, 30), (0.01, 22), (0.03, 10))
    center_tiers = ((0.004, 10), (0.015, 5))

    best = None
    best_rank_score = 0.0
    best_score = 0.0
    best_reasons: dict[str, object] = {}
    for candidate in pool:
        score = 50.0
        reasons: dict[str, object] = {"ref": "exact", "mode": "compatible"}
        if bbox_overlap(seed.bbox, candidate.bbox):
            score += 20
            reasons["bbox"] = "overlap"
        if seed.start_point is not None and seed.end_point is not None:
            endpoint_distance = candidate.geom.distance(seed.start_point) + candidate.geom.distance(seed.end_point)
            reasons["endpoint_distance_deg"] = round(endpoint_distance, 6)
            for upper_bound, bonus in endpoint_tiers:
                if endpoint_distance < upper_bound:
                    score += bonus
                    break
            direction_metrics = _candidate_direction_metrics(seed, candidate)
            if direction_metrics:
                direction_score = _direction_alignment_score(direction_metrics)
                score += direction_score
                reasons["directional_match"] = {**direction_metrics, "score": direction_score}
        if seed.center_point is not None:
            centroid_distance = candidate.geom.distance(seed.center_point)
            reasons["center_distance_deg"] = round(centroid_distance, 6)
            for upper_bound, bonus in center_tiers:
                if centroid_distance < upper_bound:
                    score += bonus
                    break
        # Ranking uses the uncapped score; the reported score is capped at 100.
        if score > best_rank_score:
            best, best_rank_score, best_reasons = candidate, score, reasons
            best_score = min(score, 100.0)

    if best is None or best_score < OSM_ROUTE_MIN_SCORE:
        reasons = best_reasons or {"reason": "no OSM candidate above threshold"}
        reasons["fallback"] = "gtfs_proposed_route_layer_pattern"
        return None, best_score, reasons
    best_reasons["osm_feature_id"] = best.feature.id
    best_reasons["osm_id"] = best.feature.osm_id
    return best, best_score, best_reasons
def _osm_pattern_key(feature: OsmFeature) -> str:
return f"osm:{feature.osm_type}:{feature.osm_id}"
def _link_reasons(seed: _GtfsPatternSeed, chosen: _OsmRouteCandidate, reasons: dict[str, object]) -> dict[str, object]:
    """Return a copy of ``reasons`` extended with the seed's geometry source and direction evidence.

    The input mapping is left unmodified.
    """
    return {
        **reasons,
        "gtfs_geometry_source": seed.geometry_source,
        "direction": _direction_evidence(seed, chosen),
    }
def _direction_evidence(seed: _GtfsPatternSeed, candidate: _OsmRouteCandidate) -> dict[str, object]:
if seed.start_point is None or seed.end_point is None:
return {"direction": "unknown", "reason": "missing GTFS shape endpoints"}
evidence: dict[str, object] = {}
start_projection = _project_point_on_geometry(candidate.geom, seed.start_point)
end_projection = _project_point_on_geometry(candidate.geom, seed.end_point)
if start_projection is not None and end_projection is not None:
evidence["start_projection"] = round(start_projection, 6)
evidence["end_projection"] = round(end_projection, 6)
if abs(start_projection - end_projection) > 1e-9:
evidence["direction"] = "forward" if start_projection < end_projection else "reverse"
endpoints = _geometry_endpoints(candidate.geom)
if endpoints is not None:
osm_start, osm_end = endpoints
forward_distance = osm_start.distance(seed.start_point) + osm_end.distance(seed.end_point)
reverse_distance = osm_start.distance(seed.end_point) + osm_end.distance(seed.start_point)
evidence["endpoint_forward_distance_deg"] = round(forward_distance, 6)
evidence["endpoint_reverse_distance_deg"] = round(reverse_distance, 6)
if abs(forward_distance - reverse_distance) > 1e-9:
evidence["endpoint_direction"] = "forward" if forward_distance < reverse_distance else "reverse"
evidence.setdefault("direction", evidence.get("endpoint_direction", "unknown"))
evidence.setdefault("direction", "unknown")
return evidence
def _candidate_direction_metrics(seed: _GtfsPatternSeed, candidate: _OsmRouteCandidate) -> dict[str, object] | None:
if seed.start_point is None or seed.end_point is None:
return None
metrics: dict[str, object] = {}
start_projection = _project_point_on_geometry(candidate.geom, seed.start_point)
end_projection = _project_point_on_geometry(candidate.geom, seed.end_point)
if start_projection is not None and end_projection is not None:
projection_delta = end_projection - start_projection
metrics["projection_delta"] = round(projection_delta, 6)
if abs(projection_delta) > 1e-9:
metrics["projection_direction"] = "forward" if projection_delta > 0 else "reverse"
endpoints = _geometry_endpoints(candidate.geom)
if endpoints is not None:
osm_start, osm_end = endpoints
forward_distance = osm_start.distance(seed.start_point) + osm_end.distance(seed.end_point)
reverse_distance = osm_start.distance(seed.end_point) + osm_end.distance(seed.start_point)
metrics["endpoint_forward_distance_deg"] = round(forward_distance, 6)
metrics["endpoint_reverse_distance_deg"] = round(reverse_distance, 6)
metrics["endpoint_margin_deg"] = round(abs(reverse_distance - forward_distance), 6)
if abs(forward_distance - reverse_distance) > 1e-9:
metrics["endpoint_direction"] = "forward" if forward_distance < reverse_distance else "reverse"
return metrics or None
def _direction_alignment_score(metrics: dict[str, object]) -> float:
score = 0.0
if metrics.get("projection_direction") == "forward":
score += 16.0
if metrics.get("endpoint_direction") == "forward":
forward_distance = float(metrics.get("endpoint_forward_distance_deg") or 999.0)
margin = float(metrics.get("endpoint_margin_deg") or 0.0)
if forward_distance < 0.004:
score += 12.0
elif forward_distance < 0.015:
score += 7.0
elif forward_distance < 0.04:
score += 3.0
if margin > 0.01:
score += 4.0
elif margin > 0.002:
score += 2.0
return min(score, 28.0)
def _update_pattern_metadata(pattern: RoutePattern, **values: object) -> None:
try:
metadata = json.loads(pattern.metadata_json or "{}")
except json.JSONDecodeError:
metadata = {}
metadata.update(values)
pattern.metadata_json = json.dumps(metadata, separators=(",", ":"))
def _representative_stop_times(
session: Session, pending: list[_PatternBuildItem]
) -> dict[tuple[int, str], list[GtfsStopTime]]:
trip_ids_by_dataset: dict[int, set[str]] = {}
for item in pending:
seed = item.seed
if seed.trip_id:
trip_ids_by_dataset.setdefault(seed.route.dataset_id, set()).add(seed.trip_id)
grouped: dict[tuple[int, str], list[GtfsStopTime]] = {}
for dataset_id, trip_ids in trip_ids_by_dataset.items():
for chunk in _chunks(sorted(trip_ids), 600):
rows_by_trip = storage_stop_times_by_trip(session, dataset_id, chunk)
rows = [row for trip_id in chunk for row in rows_by_trip.get(trip_id, [])]
for row in rows:
grouped.setdefault((dataset_id, row.trip_id), []).append(row)
return grouped
def _canonical_link_lookup(
session: Session, stop_times_by_trip: dict[tuple[int, str], list[GtfsStopTime]]
) -> dict[tuple[int, str], int]:
stop_ids_by_dataset: dict[int, set[str]] = {}
for (dataset_id, _), rows in stop_times_by_trip.items():
stop_ids_by_dataset.setdefault(dataset_id, set()).update(row.stop_id for row in rows)
lookup = {}
for dataset_id, stop_ids in stop_ids_by_dataset.items():
for chunk in _chunks(sorted(stop_ids), 900):
links = session.scalars(
select(CanonicalStopLink).where(
CanonicalStopLink.object_type == "gtfs_stop",
CanonicalStopLink.dataset_id == dataset_id,
CanonicalStopLink.external_id.in_(chunk),
)
).all()
lookup.update({(link.dataset_id, link.external_id): link.canonical_stop_id for link in links})
return lookup
def _route_pattern_stop_objects(
pattern: RoutePattern,
dataset_id: int,
trip_id: str | None,
rows: list[GtfsStopTime],
canonical_lookup: dict[tuple[int, str], int],
) -> list[RoutePatternStop]:
if not trip_id:
return []
if not rows:
return []
objects: list[RoutePatternStop] = []
seen: set[int] = set()
for row in rows:
canonical_stop_id = canonical_lookup.get((dataset_id, row.stop_id))
if canonical_stop_id is None:
continue
if canonical_stop_id in seen:
continue
seen.add(canonical_stop_id)
objects.append(
RoutePatternStop(
route_pattern_id=pattern.id,
canonical_stop_id=canonical_stop_id,
sequence=row.stop_sequence,
distance_along=None,
source_kind="timetable_link",
confidence=0.75 if pattern.source_kind == "osm" else 0.45,
)
)
return objects
def _chunks(values: list[str], size: int) -> Iterable[list[str]]:
for start in range(0, len(values), size):
yield values[start : start + size]
def _chunks_objects(values: list, size: int) -> Iterable[list]:
for start in range(0, len(values), size):
yield values[start : start + size]
def _normalized_geometry_text(geometry_text: str | None) -> str | None:
if not geometry_text:
return None
try:
geom = shape(json.loads(geometry_text))
if isinstance(geom, MultiLineString):
merged = linemerge(geom)
if isinstance(merged, (LineString, MultiLineString)) and not merged.is_empty:
geom = merged
return json.dumps(geom.__geo_interface__, separators=(",", ":"))
except Exception: # noqa: BLE001 - preserve source geometry if normalization fails
return geometry_text
def _geometry_points_from_text(geometry_text: str | None) -> list[tuple[float, float]]:
if not geometry_text:
return []
try:
geometry = json.loads(geometry_text)
except json.JSONDecodeError:
return []
geometry_type = geometry.get("type")
coords = geometry.get("coordinates") or []
if geometry_type == "LineString":
return [(float(lon), float(lat)) for lon, lat, *_ in coords]
if geometry_type == "MultiLineString":
lines = [
[(float(lon), float(lat)) for lon, lat, *_ in line]
for line in coords
if len(line) >= 2
]
if not lines:
return []
return max(lines, key=len)
return []
def _bbox_center_point(bbox: tuple[float | None, float | None, float | None, float | None]) -> Point | None:
min_lon, min_lat, max_lon, max_lat = bbox
if None in bbox:
return None
return Point((float(min_lon) + float(max_lon)) / 2, (float(min_lat) + float(max_lat)) / 2)
def _geometry_endpoints(geom) -> tuple[Point, Point] | None:
    """Return the first and last vertex of the longest line in ``geom``.

    Returns ``None`` when ``geom`` yields no lines or the longest line has
    fewer than two coordinates. On equal lengths the earliest line wins.
    """
    longest = None
    for line in _iter_lines(geom):
        if longest is None or line.length > longest.length:
            longest = line
    if longest is None:
        return None
    vertices = list(longest.coords)
    if len(vertices) < 2:
        return None
    first_vertex, last_vertex = vertices[0], vertices[-1]
    return Point(first_vertex), Point(last_vertex)
def _iter_lines(geom) -> Iterable[LineString]:
    """Yield the ``LineString`` parts of a geometry.

    A ``LineString`` yields itself, a ``MultiLineString`` yields each member,
    and any other geometry yields nothing.
    """
    if isinstance(geom, LineString):
        yield geom
        return
    if isinstance(geom, MultiLineString):
        for part in geom.geoms:
            yield part
def _project_point_on_geometry(geom, point: Point) -> float | None:
    """Project ``point`` onto the line of ``geom`` that lies closest to it.

    Returns the projected position along that line as a float, or ``None``
    when ``geom`` contains no lines. On equal distances the earliest line wins.
    """
    nearest_line = min(_iter_lines(geom), key=lambda line: line.distance(point), default=None)
    if nearest_line is None:
        return None
    return float(nearest_line.project(point))
def _bounds_tuple(geom) -> tuple[float | None, float | None, float | None, float | None]:
if geom.is_empty:
return (None, None, None, None)
min_lon, min_lat, max_lon, max_lat = geom.bounds
return min_lon, min_lat, max_lon, max_lat
def _mode_compatible(gtfs_mode: str, osm_mode: str) -> bool:
if not gtfs_mode or not osm_mode:
return True
if gtfs_mode == osm_mode:
return True
return osm_mode in MODE_GROUPS.get(gtfs_mode, {gtfs_mode}) or gtfs_mode in MODE_GROUPS.get(osm_mode, {osm_mode})