Files
meubility-workbench/app/journey_search.py
2026-07-01 23:29:51 +02:00

718 lines
27 KiB
Python

from __future__ import annotations
import json
import hashlib
import threading
import time
import uuid
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from datetime import date, datetime, timedelta, timezone
from typing import Any
from sqlalchemy import select
from app.address_search import is_location_token
from app.db import SessionLocal
from app.journey import find_journeys, parse_service_date, resolve_location_summary
from app.models import JourneySearchCache
from app.routing import direct_route_between_points, route_between_points
# Highest transfer allowance the progressive transit search escalates to (stages 0..N).
MAX_PROGRESSIVE_TRANSFERS = 5
# Lifetime and size cap of the per-stage transit result cache.
TRANSIT_STAGE_CACHE_TTL_SECONDS = 5 * 60
TRANSIT_STAGE_CACHE_MAX_ENTRIES = 256
# Lifetime and size cap of the finished-search payload cache.
PROGRESSIVE_SEARCH_CACHE_TTL_SECONDS = 10 * 60
PROGRESSIVE_SEARCH_CACHE_MAX_ENTRIES = 128
# Hashed into every durable cache key, so changing it makes older stored rows unreachable.
JOURNEY_SEARCH_CACHE_VERSION = "journey-search-v7"
# Worker pool that runs searches in the background (two worker threads).
_executor = ThreadPoolExecutor(max_workers=2, thread_name_prefix="journey-search")
# Re-entrant lock held while reading or mutating the module-level dicts below.
_lock = threading.RLock()
# Search id -> state of that search.
_searches: dict[str, "_SearchState"] = {}
# Request cache key -> id of the search currently computing it (used to de-duplicate requests).
_progressive_search_inflight: dict[tuple[object, ...], str] = {}
# In-memory caches mapping key -> (time.monotonic() expiry, JSON-compatible payload).
_transit_stage_cache: dict[tuple[object, ...], tuple[float, dict[str, Any]]] = {}
_progressive_search_cache: dict[tuple[object, ...], tuple[float, dict[str, Any]]] = {}
@dataclass
class _SearchState:
id: str
request: dict[str, Any]
cache_key: tuple[object, ...] | None = None
status: str = "queued"
message: str = "Queued."
stage: str = "queued"
journeys: list[dict] = field(default_factory=list)
routing: dict[str, Any] | None = None
context: dict[str, Any] = field(default_factory=dict)
error: str | None = None
created_at: float = field(default_factory=time.time)
updated_at: float = field(default_factory=time.time)
complete: bool = False
cancelled: bool = False
def start_journey_search(request: dict[str, Any]) -> dict[str, Any]:
    """Start a journey search, or reuse a cached or in-flight one.

    On a cache hit a new, already complete search state is registered and
    returned. Otherwise, if a search with the same cache key is still running,
    that search's payload is returned; if not, a new search is registered and
    submitted to the worker pool.

    Returns:
        The client payload of the search that serves this request.
    """
    cache_key = _progressive_cache_key(request)
    cached_payload = _progressive_cache_get(cache_key)
    cache_hit = cached_payload is not None

    new_id = uuid.uuid4().hex
    new_state = _SearchState(id=new_id, request=dict(request), cache_key=cache_key)
    if cache_hit:
        _apply_cached_payload(new_state, cached_payload)

    with _lock:
        _prune_old_searches()
        if not cache_hit:
            running_id = _progressive_search_inflight.get(cache_key)
            running = _searches.get(running_id) if running_id is not None else None
            if running is not None and not (running.complete or running.cancelled):
                # An identical search is already under way; share it.
                return _payload(running)
            _progressive_search_inflight[cache_key] = new_id
        _searches[new_id] = new_state

    if not cache_hit:
        _executor.submit(_run_search, new_id)
    return journey_search_payload(new_id)
def journey_search_payload(search_id: str) -> dict[str, Any]:
    """Return the current client payload of a search.

    Raises:
        KeyError: if no search with ``search_id`` is registered.
    """
    with _lock:
        # Indexing raises KeyError(search_id) for an unknown id.
        return _payload(_searches[search_id])
def cancel_journey_search(search_id: str) -> dict[str, Any]:
    """Flag a search as cancelled and return its payload.

    A search that has not completed yet is also moved to the terminal
    ``cancelled`` status; an already complete search keeps its status.

    Raises:
        KeyError: if no search with ``search_id`` is registered.
    """
    with _lock:
        state = _searches.get(search_id)
        if state is None:
            raise KeyError(search_id)
        already_complete = state.complete
        state.cancelled = True
        if not already_complete:
            state.status, state.message = "cancelled", "Search cancelled."
            state.complete = True
            state.updated_at = time.time()
        _clear_inflight_search_locked(state)
        return _payload(state)
def _run_search(search_id: str) -> None:
    """Worker entry point: mark the search as running and dispatch by mode.

    Modes ``walk``, ``drive`` and ``car`` (treated as ``drive``) use point
    routing; every other mode runs the transit search. Any exception is
    reported on the search state instead of being raised.
    """
    with _lock:
        state = _searches.get(search_id)
        if state is None or state.cancelled:
            return
        state.status, state.stage = "running", "starting"
        state.message = "Starting search..."
        state.updated_at = time.time()
        request = dict(state.request)

    try:
        mode = str(request.get("mode") or "transit")
        if mode not in {"walk", "drive", "car"}:
            _run_transit_search(search_id, request)
        else:
            routing_mode = "drive" if mode == "car" else mode
            _run_point_route_search(search_id, routing_mode, request)
    except Exception as exc:  # noqa: BLE001 - report progressive-search failure to client
        _publish_error(search_id, str(exc))
def _run_transit_search(search_id: str, request: dict[str, Any]) -> None:
    """Run the staged transit search for ``search_id`` and publish progress.

    Stages raise the transfer allowance from 0 up to
    ``MAX_PROGRESSIVE_TRANSFERS`` (stage 0 only when ``direct_only`` is set).
    After each stage the merged, re-ranked journeys are published. The loop
    stops early once two consecutive stages add nothing new at two or more
    transfers, or once the major-hub check reports enough options. A search
    that ends with status ``complete`` and no error is stored in the
    progressive cache.
    """
    direct_only = bool(request.get("direct_only"))
    limit = max(3, min(int(request.get("limit") or 5), 10))
    transfer_seconds = max(0, min(int(request.get("transfer_seconds") or 120), 3600))
    source_ids = _csv_ints(request.get("source_id"))
    service_date = request.get("service_date") or None
    transfer_budgets = [0] if direct_only else list(range(MAX_PROGRESSIVE_TRANSFERS + 1))
    involves_address = is_location_token(request.get("from_stop_id")) or is_location_token(request.get("to_stop_id"))
    # Address searches fetch ``limit`` journeys per stage; others at least ten.
    stage_limit = limit if involves_address else max(limit, 10)
    ranking = str(request.get("ranking") or "recommended")
    origin = str(request.get("from_stop_id") or "")
    destination = str(request.get("to_stop_id") or "")
    via = request.get("via_stop_id") or None
    departure = str(request.get("departure") or "08:00")

    unique_journeys: dict[str, dict] = {}
    diagnostics: dict[str, Any] = {"stages": []}
    most_options = 0
    idle_stages = 0

    for transfers in transfer_budgets:
        if _is_cancelled(search_id):
            return
        if transfers == 0:
            label = "direct"
        else:
            label = f"up to {transfers} transfer{'s' if transfers != 1 else ''}"
        stage_name = f"transfers_{transfers}"
        _publish_status(search_id, "running", f"Searching {label}...", stage_name)

        started = time.monotonic()
        with SessionLocal() as db:
            result = _cached_find_journeys(
                db,
                from_stop_id=origin,
                to_stop_id=destination,
                via_stop_id=via,
                source_ids=source_ids,
                departure=departure,
                service_date=service_date,
                max_transfers=transfers,
                transfer_seconds=transfer_seconds,
                limit=stage_limit,
            )
        cache_status = str(result.pop("_cache_status", "miss"))
        elapsed_ms = int((time.monotonic() - started) * 1000)
        found = result.get("journeys") or []

        stage_entry = {
            "transfers": transfers,
            "cache": cache_status,
            "elapsed_ms": elapsed_ms,
            "journeys": len(found),
        }
        result_diagnostics = result.get("diagnostics")
        if isinstance(result_diagnostics, dict):
            stage_entry["details"] = result_diagnostics
        diagnostics["stages"].append(stage_entry)

        context = _context_from_result(result)
        # The accumulated per-stage diagnostics replace the stage's own entry.
        context["diagnostics"] = diagnostics

        known_before = len(unique_journeys)
        for journey in found:
            unique_journeys.setdefault(_journey_key(journey), journey)
        ranked = _select_diverse_journeys(_rank_journeys(unique_journeys.values(), ranking), limit=limit)

        if direct_only:
            progress = "Direct search complete."
        else:
            progress = f"Found {len(ranked)} option{'s' if len(ranked) != 1 else ''}; still searching..."
        _publish_results(
            search_id,
            journeys=ranked,
            context=context,
            status="running",
            stage=stage_name,
            message=progress,
        )

        nothing_new = len(unique_journeys) <= known_before
        idle_stages = idle_stages + 1 if (nothing_new and ranked) else 0
        most_options = max(most_options, len(ranked))
        if ranked and idle_stages >= 2 and transfers >= 2:
            break
        if _major_hub_address_stage_is_complete(result_diagnostics, ranked, transfers=transfers, limit=limit):
            break

    if most_options:
        summary = f"Search complete. Found {most_options} option{'s' if most_options != 1 else ''}."
    else:
        summary = "Search complete. No route found in the imported timetable."
    _publish_complete(search_id, message=summary)

    payload = journey_search_payload(search_id)
    if payload.get("status") == "complete" and not payload.get("error"):
        _progressive_cache_put(_progressive_cache_key(request), payload)
def _major_hub_address_stage_is_complete(
diagnostics: dict[str, Any] | None,
ranked: list[dict],
*,
transfers: int,
limit: int,
) -> bool:
if transfers < 1 or not ranked or not isinstance(diagnostics, dict):
return False
address_access = diagnostics.get("address_access")
if not isinstance(address_access, dict) or not address_access.get("major_hubs"):
return False
return len(ranked) >= min(3, limit)
def _run_point_route_search(search_id: str, mode: str, request: dict[str, Any]) -> None:
    """Route directly between the two requested locations for ``mode``.

    Both endpoints are resolved to coordinates. If ``route_between_points``
    fails, ``direct_route_between_points`` supplies an approximate route
    instead. The route is published, the search is completed, and a result
    that ends with status ``complete`` and no error is stored in the
    progressive cache.

    Raises:
        ValueError: if either endpoint has no coordinates.
    """
    _publish_status(search_id, "running", f"Searching {mode} route...", mode)
    with SessionLocal() as db:
        origin = resolve_location_summary(
            db,
            str(request.get("from_stop_id") or ""),
            source_ids=_csv_ints(request.get("source_id")),
        )
        destination = resolve_location_summary(
            db,
            str(request.get("to_stop_id") or ""),
            source_ids=_csv_ints(request.get("source_id")),
        )
        if origin.lon is None or origin.lat is None:
            raise ValueError("Selected start has no coordinates.")
        if destination.lon is None or destination.lat is None:
            raise ValueError("Selected destination has no coordinates.")

        def endpoints() -> dict[str, float]:
            # Converted on each use so a conversion failure surfaces at the call site.
            return {
                "from_lon": float(origin.lon),
                "from_lat": float(origin.lat),
                "to_lon": float(destination.lon),
                "to_lat": float(destination.lat),
            }

        try:
            route = route_between_points(db, **endpoints(), mode=mode, max_visited=300_000)
            message = f"{mode.title()} route found."
        except Exception as exc:  # noqa: BLE001 - point routing should still return an approximate connector
            route = direct_route_between_points(db, **endpoints(), mode=mode, reason=str(exc))
            message = f"{mode.title()} route approximated."

    context = {
        "from": _stop_payload(origin),
        "to": _stop_payload(destination),
        "mode": mode,
    }
    _publish_routing(search_id, route, context=context, message=message)
    _publish_complete(search_id, message=f"{mode.title()} route complete.")

    payload = journey_search_payload(search_id)
    if payload.get("status") == "complete" and not payload.get("error"):
        _progressive_cache_put(_progressive_cache_key(request), payload)
def _cached_find_journeys(
    db,
    *,
    from_stop_id: str,
    to_stop_id: str,
    via_stop_id: object,
    source_ids: list[int] | None,
    departure: str,
    service_date: object,
    max_transfers: int,
    transfer_seconds: int,
    limit: int,
) -> dict[str, Any]:
    """Return ``find_journeys`` output for one stage, consulting caches first.

    Lookup order is the in-memory stage cache, then the durable cache, then
    the search itself (whose result is written to both caches). The returned
    dict is a copy carrying ``_cache_status`` set to ``"memory"``,
    ``"persistent"`` or ``"miss"``.
    """
    cache_key = (
        from_stop_id,
        to_stop_id,
        str(via_stop_id or ""),
        tuple(sorted(int(source_id) for source_id in source_ids or [])),
        departure,
        str(service_date or ""),
        int(max_transfers),
        int(transfer_seconds),
        int(limit),
    )
    looked_up_at = time.monotonic()
    # Memory entries expire relative to the moment of lookup.
    memory_expiry = looked_up_at + TRANSIT_STAGE_CACHE_TTL_SECONDS

    with _lock:
        entry = _transit_stage_cache.get(cache_key)
        if entry is not None:
            expires_at, payload = entry
            if expires_at > looked_up_at:
                return _with_cache_status(payload, "memory")
            _transit_stage_cache.pop(cache_key, None)

    persisted = _durable_cache_get("transit_stage", cache_key)
    if persisted is not None:
        with _lock:
            _transit_stage_cache[cache_key] = (memory_expiry, json.loads(json.dumps(persisted)))
            _prune_timed_cache(_transit_stage_cache, TRANSIT_STAGE_CACHE_MAX_ENTRIES)
        return _with_cache_status(persisted, "persistent")

    result = find_journeys(
        db=db,
        from_stop_id=from_stop_id,
        to_stop_id=to_stop_id,
        via_stop_id=via_stop_id,
        source_ids=source_ids,
        departure=departure,
        service_date=service_date,
        max_transfers=max_transfers,
        transfer_seconds=transfer_seconds,
        limit=limit,
    )
    # A JSON round-trip gives the caches their own copy of the result.
    snapshot = json.loads(json.dumps(result))
    with _lock:
        _transit_stage_cache[cache_key] = (memory_expiry, snapshot)
        _prune_timed_cache(_transit_stage_cache, TRANSIT_STAGE_CACHE_MAX_ENTRIES)
    _durable_cache_put("transit_stage", cache_key, snapshot, ttl_seconds=TRANSIT_STAGE_CACHE_TTL_SECONDS)
    return _with_cache_status(result, "miss")
def _with_cache_status(payload: dict[str, Any], cache_status: str) -> dict[str, Any]:
copied = json.loads(json.dumps(payload))
copied["_cache_status"] = cache_status
return copied
def _prune_timed_cache(cache: dict[tuple[object, ...], tuple[float, dict[str, Any]]], max_entries: int) -> None:
if len(cache) <= max_entries:
return
oldest = sorted(cache.items(), key=lambda item: item[1][0])[: len(cache) - max_entries]
for old_key, _ in oldest:
cache.pop(old_key, None)
def _durable_cache_get(cache_type: str, key: tuple[object, ...]) -> dict[str, Any] | None:
    """Load an unexpired payload from the ``JourneySearchCache`` table.

    An expired row, or one without an expiry, is deleted and treated as a
    miss. Any exception is swallowed and reported as a miss.
    """
    storage_key = _durable_cache_key(cache_type, key)
    now = datetime.now(timezone.utc)
    try:
        with SessionLocal() as session:
            lookup = select(JourneySearchCache).where(JourneySearchCache.cache_key == storage_key)
            row = session.scalar(lookup)
            if row is None:
                return None
            expires_at = _as_utc(row.expires_at)
            if expires_at is not None and expires_at > now:
                return json.loads(row.payload_json)
            # Stale row: remove it so it is not checked again.
            session.delete(row)
            session.commit()
            return None
    except Exception:  # noqa: BLE001 - cache misses must not break journey search
        return None
def _durable_cache_put(cache_type: str, key: tuple[object, ...], payload: dict[str, Any], *, ttl_seconds: int) -> None:
    """Insert or refresh the ``JourneySearchCache`` row for ``key``.

    The row expires ``ttl_seconds`` from now (at least one second). Writes are
    best-effort: any exception is swallowed.
    """
    storage_key = _durable_cache_key(cache_type, key)
    now = datetime.now(timezone.utc)
    lifetime = timedelta(seconds=max(1, int(ttl_seconds)))
    expires_at = now + lifetime
    try:
        with SessionLocal() as session:
            lookup = select(JourneySearchCache).where(JourneySearchCache.cache_key == storage_key)
            row = session.scalar(lookup)
            encoded = json.dumps(payload, separators=(",", ":"))
            if row is not None:
                row.cache_type = cache_type
                row.payload_json = encoded
                row.updated_at = now
                row.expires_at = expires_at
            else:
                session.add(
                    JourneySearchCache(
                        cache_key=storage_key,
                        cache_type=cache_type,
                        payload_json=encoded,
                        created_at=now,
                        updated_at=now,
                        expires_at=expires_at,
                    )
                )
            session.commit()
    except Exception:  # noqa: BLE001 - cache writes are best-effort
        return
def _durable_cache_key(cache_type: str, key: tuple[object, ...]) -> str:
    """Return the SHA-256 hex digest identifying ``key`` in the durable cache.

    The digest covers the cache version, the cache type and the
    JSON-normalised key, serialised compactly with sorted keys.
    """
    envelope = {
        "version": JOURNEY_SEARCH_CACHE_VERSION,
        "cache_type": cache_type,
        "key": _json_safe(key),
    }
    canonical = json.dumps(envelope, sort_keys=True, separators=(",", ":"))
    digest = hashlib.sha256()
    digest.update(canonical.encode("utf-8"))
    return digest.hexdigest()
def _json_safe(value: object) -> object:
if isinstance(value, tuple):
return [_json_safe(item) for item in value]
if isinstance(value, list):
return [_json_safe(item) for item in value]
if isinstance(value, dict):
return {str(key): _json_safe(item) for key, item in sorted(value.items(), key=lambda item: str(item[0]))}
if isinstance(value, (str, int, float, bool)) or value is None:
return value
if isinstance(value, (date, datetime)):
return value.isoformat()
return str(value)
def _as_utc(value: datetime | None) -> datetime | None:
if value is None:
return None
if value.tzinfo is None:
return value.replace(tzinfo=timezone.utc)
return value.astimezone(timezone.utc)
def _progressive_cache_key(request: dict[str, Any]) -> tuple[object, ...]:
    """Build the cache / de-duplication key for a search request.

    The key holds the request fields after the same defaulting used by the
    search itself: mode, endpoints, via stop, sorted source ids, departure,
    service date, ``direct_only``, ranking, transfer time and result limit.

    Raises:
        ValueError: if ``source_id``, ``transfer_seconds`` or ``limit`` is not
            numeric.
    """
    source_ids = _csv_ints(request.get("source_id"))
    # Clamp exactly as _run_transit_search does, so requests that run the same
    # search (e.g. 3600 and 9999 seconds) share one cache entry and one
    # in-flight search instead of being treated as different.
    transfer_seconds = max(0, min(int(request.get("transfer_seconds") or 120), 3600))
    return (
        str(request.get("mode") or "transit"),
        str(request.get("from_stop_id") or ""),
        str(request.get("to_stop_id") or ""),
        str(request.get("via_stop_id") or ""),
        tuple(sorted(int(source_id) for source_id in source_ids or [])),
        str(request.get("departure") or "08:00"),
        str(request.get("service_date") or ""),
        bool(request.get("direct_only")),
        str(request.get("ranking") or "recommended"),
        transfer_seconds,
        max(3, min(int(request.get("limit") or 5), 10)),
    )
def _progressive_cache_get(key: tuple[object, ...]) -> dict[str, Any] | None:
    """Return a cached finished-search payload for ``key``, or ``None``.

    The in-memory cache is checked first, then the durable cache (whose hit is
    copied into memory). The returned copy carries ``cache_status`` set to
    ``"memory"`` or ``"persistent"``.
    """
    looked_up_at = time.monotonic()
    with _lock:
        entry = _progressive_search_cache.get(key)
        if entry is not None:
            expires_at, payload = entry
            if expires_at > looked_up_at:
                hit = json.loads(json.dumps(payload))
                hit["cache_status"] = "memory"
                return hit
            _progressive_search_cache.pop(key, None)

    persisted = _durable_cache_get("progressive", key)
    if persisted is None:
        return None
    memory_expiry = looked_up_at + PROGRESSIVE_SEARCH_CACHE_TTL_SECONDS
    with _lock:
        _progressive_search_cache[key] = (memory_expiry, json.loads(json.dumps(persisted)))
        _prune_timed_cache(_progressive_search_cache, PROGRESSIVE_SEARCH_CACHE_MAX_ENTRIES)
    hit = json.loads(json.dumps(persisted))
    hit["cache_status"] = "persistent"
    return hit
def _progressive_cache_put(key: tuple[object, ...], payload: dict[str, Any]) -> None:
    """Store a finished-search payload in the memory and durable caches.

    A copy without the ``cache_status`` marker is stored, so the caller's dict
    is left untouched.
    """
    snapshot = json.loads(json.dumps(payload))
    snapshot.pop("cache_status", None)
    with _lock:
        expires_at = time.monotonic() + PROGRESSIVE_SEARCH_CACHE_TTL_SECONDS
        _progressive_search_cache[key] = (expires_at, snapshot)
        _prune_timed_cache(_progressive_search_cache, PROGRESSIVE_SEARCH_CACHE_MAX_ENTRIES)
    _durable_cache_put("progressive", key, snapshot, ttl_seconds=PROGRESSIVE_SEARCH_CACHE_TTL_SECONDS)
def _apply_cached_payload(state: _SearchState, payload: dict[str, Any]) -> None:
    """Fill ``state`` from a cached payload and mark it complete.

    Journeys and routing are copied. Every payload key that is not one of the
    fixed payload fields is kept as context.
    """
    fixed_fields = {
        "search_id",
        "status",
        "stage",
        "message",
        "complete",
        "error",
        "journeys",
        "routing",
        "created_at",
        "updated_at",
    }
    routing = payload.get("routing")

    state.status = str(payload.get("status") or "complete")
    state.message = "Cached result."
    state.stage = str(payload.get("stage") or "cached")
    state.journeys = json.loads(json.dumps(payload.get("journeys") or []))
    state.routing = None if routing is None else json.loads(json.dumps(routing))
    state.context = {name: value for name, value in payload.items() if name not in fixed_fields}
    state.error = None
    state.complete = True
    state.updated_at = time.time()
def _publish_status(search_id: str, status: str, message: str, stage: str) -> None:
    """Update status, message and stage of a search that is still registered and not cancelled."""
    with _lock:
        state = _searches.get(search_id)
        if state is not None and not state.cancelled:
            state.status, state.message, state.stage = status, message, stage
            state.updated_at = time.time()
def _publish_results(search_id: str, *, journeys: list[dict], context: dict[str, Any], status: str, stage: str, message: str) -> None:
    """Publish an intermediate result set on a search that is registered and not cancelled.

    ``journeys`` and ``context`` are stored as shallow copies.
    """
    with _lock:
        state = _searches.get(search_id)
        if state is not None and not state.cancelled:
            state.status, state.stage, state.message = status, stage, message
            state.journeys = list(journeys)
            state.context = dict(context)
            state.updated_at = time.time()
def _publish_routing(search_id: str, routing: dict[str, Any], *, context: dict[str, Any], message: str) -> None:
    """Publish a point-to-point route on a search that is registered and not cancelled.

    The stage is taken from ``routing["mode"]``, falling back to ``"route"``.
    """
    with _lock:
        state = _searches.get(search_id)
        if state is not None and not state.cancelled:
            state.status = "running"
            state.stage = str(routing.get("mode") or "route")
            state.message = message
            state.routing = routing
            state.context = dict(context)
            state.updated_at = time.time()
def _publish_complete(search_id: str, *, message: str) -> None:
    """Mark a registered, non-cancelled search as complete and release its in-flight slot."""
    with _lock:
        state = _searches.get(search_id)
        if state is not None and not state.cancelled:
            state.status, state.message = "complete", message
            state.complete = True
            state.updated_at = time.time()
            _clear_inflight_search_locked(state)
def _publish_error(search_id: str, message: str) -> None:
    """Mark a search as failed with ``message`` and release its in-flight slot.

    Nothing happens when the search is no longer registered or was cancelled.
    Cancellation has already put the state into its terminal ``cancelled``
    form, so a worker failing afterwards must not rewrite it as an error;
    this matches the guard used by the other ``_publish_*`` helpers.
    """
    with _lock:
        state = _searches.get(search_id)
        if state is None or state.cancelled:
            return
        state.status = "error"
        state.stage = "error"
        state.message = message
        state.error = message
        state.complete = True
        state.updated_at = time.time()
        _clear_inflight_search_locked(state)
def _clear_inflight_search_locked(state: _SearchState) -> None:
    """Remove the in-flight entry for ``state``'s cache key if it still points at this search.

    NOTE(review): the ``_locked`` suffix suggests callers are expected to hold
    ``_lock``; all visible call sites do.
    """
    cache_key = state.cache_key
    if cache_key is None:
        return
    if _progressive_search_inflight.get(cache_key) == state.id:
        _progressive_search_inflight.pop(cache_key, None)
def _is_cancelled(search_id: str) -> bool:
    """Return True when the search was cancelled or is no longer registered."""
    with _lock:
        state = _searches.get(search_id)
        if state is None:
            return True
        return state.cancelled
def _payload(state: _SearchState) -> dict[str, Any]:
    """Build the client payload for ``state``.

    Journeys, routing and context are JSON round-tripped copies. Context keys
    are merged last, so a context key equal to a fixed field name replaces
    that field.
    """
    routing = state.routing
    payload: dict[str, Any] = {
        "search_id": state.id,
        "status": state.status,
        "stage": state.stage,
        "message": state.message,
        "complete": state.complete,
        "error": state.error,
        "journeys": json.loads(json.dumps(state.journeys)),
        "routing": None if routing is None else json.loads(json.dumps(routing)),
        "created_at": state.created_at,
        "updated_at": state.updated_at,
    }
    payload.update(json.loads(json.dumps(state.context)))
    return payload
def _context_from_result(result: dict[str, Any]) -> dict[str, Any]:
return {
key: value
for key, value in result.items()
if key not in {"journeys"} and key not in {"error"}
}
def _journey_key(journey: dict[str, Any]) -> str:
parts = []
for leg in journey.get("legs") or []:
parts.append(
"|".join(
str(part or "")
for part in [
leg.get("dataset_id"),
leg.get("mode"),
leg.get("route_id"),
leg.get("trip_id"),
(leg.get("from") or {}).get("stop_id") or (leg.get("from") or {}).get("name"),
(leg.get("to") or {}).get("stop_id") or (leg.get("to") or {}).get("name"),
leg.get("departure_time"),
leg.get("arrival_time"),
]
)
)
return "||".join(parts)
def _rank_journeys(journeys, ranking: str) -> list[dict]:
def key(journey: dict[str, Any]) -> tuple[float, float, int, float]:
departure = journey.get("departure_seconds")
arrival = journey.get("arrival_seconds")
duration = journey.get("duration_minutes")
transfers = int(journey.get("transfers") or 0)
walking = sum(float(leg.get("distance_m") or 0) for leg in journey.get("legs") or [] if leg.get("mode") == "walk")
walking_seconds = walking / 1.35
if ranking == "duration":
return (
float("inf") if duration is None else float(duration),
float("inf") if arrival is None else float(arrival),
transfers,
walking,
)
if ranking == "fewest_transfers":
return (
transfers,
float("inf") if arrival is None else float(arrival),
float("inf") if duration is None else float(duration),
walking,
)
if ranking == "earliest_arrival":
return (
float("inf") if arrival is None else float(arrival),
float("inf") if duration is None else float(duration),
transfers,
walking,
)
return (
float("inf") if arrival is None else float(arrival) + transfers * 600 + walking_seconds,
float("inf") if arrival is None else float(arrival),
transfers,
walking,
)
return sorted((dict(journey) for journey in journeys), key=key)
def _select_diverse_journeys(journeys: list[dict], *, limit: int) -> list[dict]:
    """Pick up to ``limit`` journeys from a ranked list, favouring variety.

    Exact duplicates (same ``_journey_key``) are always skipped. Once three
    journeys are selected, a journey whose diversity key (transfer count,
    route signature, 30-minute departure band) is already represented is
    skipped as well. The selection is then passed through
    ``_ensure_walk_only_option`` so a walk-only journey present in
    ``journeys`` is part of the result.

    Previously both early exits returned the selection directly. The
    walk-only step was then reached only with fewer than three selected
    journeys, i.e. only when every unique journey had already been selected,
    so neither it nor the backfill loop before it could ever change the
    result. Every exit now goes through the walk-only step and the
    ineffective backfill loop is gone.
    """
    selected: list[dict] = []
    selected_exact: set[str] = set()
    selected_diversity: set[tuple[object, ...]] = set()
    for journey in journeys:
        exact_key = _journey_key(journey)
        if exact_key in selected_exact:
            continue
        diversity_key = _journey_diversity_key(journey)
        # The first three picks are taken regardless of similarity.
        if diversity_key in selected_diversity and len(selected) >= 3:
            continue
        selected.append(journey)
        selected_exact.add(exact_key)
        selected_diversity.add(diversity_key)
        if len(selected) >= limit:
            break
    return _ensure_walk_only_option(selected, journeys, limit=limit)
def _ensure_walk_only_option(selected: list[dict], ranked: list[dict], *, limit: int) -> list[dict]:
    """Make sure ``selected`` contains a walk-only journey when ``ranked`` has one.

    If there is room below ``limit`` a new list with the walk journey appended
    is returned; otherwise the last selected journey is replaced in place.
    """
    for journey in selected:
        if _journey_is_walk_only(journey):
            return selected

    walk_option = None
    for journey in ranked:
        if _journey_is_walk_only(journey):
            walk_option = journey
            break
    if walk_option is None:
        return selected

    if len(selected) < limit:
        return selected + [walk_option]
    if selected:
        selected[-1] = walk_option
    return selected
def _journey_is_walk_only(journey: dict) -> bool:
legs = journey.get("legs") or []
return bool(legs) and all(leg.get("mode") == "walk" for leg in legs)
def _journey_diversity_key(journey: dict[str, Any]) -> tuple[object, ...]:
route_signature = tuple(
str(leg.get("route_ref") or leg.get("route_id") or leg.get("mode") or "")
for leg in journey.get("legs") or []
if leg.get("mode") != "walk"
)
departure = journey.get("departure_seconds")
time_band = None if departure is None else int(departure) // (30 * 60)
return (int(journey.get("transfers") or 0), route_signature, time_band)
def _csv_ints(value: object) -> list[int] | None:
if value is None:
return None
items = [item.strip() for item in str(value).split(",") if item.strip()]
if not items:
return None
return [int(item) for item in items]
def _stop_payload(stop) -> dict[str, Any]:
return {
"id": stop.id,
"dataset_id": stop.dataset_id,
"stop_id": stop.stop_id,
"name": stop.name,
"lat": stop.lat,
"lon": stop.lon,
}
def _prune_old_searches() -> None:
    """Forget searches idle for over 15 minutes, or complete and idle for over 3 minutes.

    The in-flight entry of every removed search is released as well.
    """
    now = time.time()

    def is_stale(state: _SearchState) -> bool:
        idle = now - state.updated_at
        return idle > 15 * 60 or (state.complete and idle > 3 * 60)

    # Collect the ids first so the dict is not mutated while being iterated.
    stale_ids = [search_id for search_id, state in _searches.items() if is_stale(state)]
    for search_id in stale_ids:
        removed = _searches.pop(search_id, None)
        if removed is not None:
            _clear_inflight_search_locked(removed)