from __future__ import annotations import json import hashlib import threading import time import uuid from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass, field from datetime import date, datetime, timedelta, timezone from typing import Any from sqlalchemy import select from app.address_search import is_location_token from app.db import SessionLocal from app.journey import find_journeys, parse_service_date, resolve_location_summary from app.models import JourneySearchCache from app.routing import direct_route_between_points, route_between_points MAX_PROGRESSIVE_TRANSFERS = 5 TRANSIT_STAGE_CACHE_TTL_SECONDS = 5 * 60 TRANSIT_STAGE_CACHE_MAX_ENTRIES = 256 PROGRESSIVE_SEARCH_CACHE_TTL_SECONDS = 10 * 60 PROGRESSIVE_SEARCH_CACHE_MAX_ENTRIES = 128 JOURNEY_SEARCH_CACHE_VERSION = "journey-search-v7" _executor = ThreadPoolExecutor(max_workers=2, thread_name_prefix="journey-search") _lock = threading.RLock() _searches: dict[str, "_SearchState"] = {} _progressive_search_inflight: dict[tuple[object, ...], str] = {} _transit_stage_cache: dict[tuple[object, ...], tuple[float, dict[str, Any]]] = {} _progressive_search_cache: dict[tuple[object, ...], tuple[float, dict[str, Any]]] = {} @dataclass class _SearchState: id: str request: dict[str, Any] cache_key: tuple[object, ...] | None = None status: str = "queued" message: str = "Queued." stage: str = "queued" journeys: list[dict] = field(default_factory=list) routing: dict[str, Any] | None = None context: dict[str, Any] = field(default_factory=dict) error: str | None = None created_at: float = field(default_factory=time.time) updated_at: float = field(default_factory=time.time) complete: bool = False cancelled: bool = False def start_journey_search(request: dict[str, Any]) -> dict[str, Any]: key = _progressive_cache_key(request) cached = _progressive_cache_get(key) search_id = uuid.uuid4().hex state = _SearchState(id=search_id, request=dict(request), cache_key=key) if cached is not None: _apply_cached_payload(state, cached) with _lock: _prune_old_searches() if cached is None: existing_search_id = _progressive_search_inflight.get(key) existing_state = None if existing_search_id is None else _searches.get(existing_search_id) if existing_state is not None and not existing_state.complete and not existing_state.cancelled: return _payload(existing_state) _progressive_search_inflight[key] = search_id _searches[search_id] = state if cached is None: _executor.submit(_run_search, search_id) return journey_search_payload(search_id) def journey_search_payload(search_id: str) -> dict[str, Any]: with _lock: state = _searches.get(search_id) if state is None: raise KeyError(search_id) return _payload(state) def cancel_journey_search(search_id: str) -> dict[str, Any]: with _lock: state = _searches.get(search_id) if state is None: raise KeyError(search_id) state.cancelled = True if not state.complete: state.status = "cancelled" state.message = "Search cancelled." state.complete = True state.updated_at = time.time() _clear_inflight_search_locked(state) return _payload(state) def _run_search(search_id: str) -> None: with _lock: state = _searches.get(search_id) if state is None or state.cancelled: return state.status = "running" state.stage = "starting" state.message = "Starting search..." state.updated_at = time.time() request = dict(state.request) try: mode = str(request.get("mode") or "transit") if mode in {"walk", "drive", "car"}: _run_point_route_search(search_id, "drive" if mode == "car" else mode, request) else: _run_transit_search(search_id, request) except Exception as exc: # noqa: BLE001 - report progressive-search failure to client _publish_error(search_id, str(exc)) def _run_transit_search(search_id: str, request: dict[str, Any]) -> None: direct_only = bool(request.get("direct_only")) limit = max(3, min(int(request.get("limit") or 5), 10)) transfer_seconds = max(0, min(int(request.get("transfer_seconds") or 120), 3600)) source_ids = _csv_ints(request.get("source_id")) service_date = request.get("service_date") or None stages = [0] if direct_only else list(range(0, MAX_PROGRESSIVE_TRANSFERS + 1)) address_search = is_location_token(request.get("from_stop_id")) or is_location_token(request.get("to_stop_id")) stage_limit = limit if address_search else max(limit, 10) merged: dict[str, dict] = {} context: dict[str, Any] = {} diagnostics: dict[str, Any] = {"stages": []} best_count = 0 stale_stages = 0 for transfers in stages: if _is_cancelled(search_id): return label = "direct" if transfers == 0 else f"up to {transfers} transfer{'s' if transfers != 1 else ''}" _publish_status(search_id, "running", f"Searching {label}...", f"transfers_{transfers}") stage_started_at = time.monotonic() with SessionLocal() as db: result = _cached_find_journeys( db, from_stop_id=str(request.get("from_stop_id") or ""), to_stop_id=str(request.get("to_stop_id") or ""), via_stop_id=request.get("via_stop_id") or None, source_ids=source_ids, departure=str(request.get("departure") or "08:00"), service_date=service_date, max_transfers=transfers, transfer_seconds=transfer_seconds, limit=stage_limit, ) cache_status = str(result.pop("_cache_status", "miss")) elapsed_ms = int((time.monotonic() - stage_started_at) * 1000) stage_diagnostics = { "transfers": transfers, "cache": cache_status, "elapsed_ms": elapsed_ms, "journeys": len(result.get("journeys") or []), } result_diagnostics = result.get("diagnostics") if isinstance(result_diagnostics, dict): stage_diagnostics["details"] = result_diagnostics diagnostics["stages"].append(stage_diagnostics) context = _context_from_result(result) context["diagnostics"] = diagnostics before = len(merged) for journey in result.get("journeys") or []: merged.setdefault(_journey_key(journey), journey) ranked = _select_diverse_journeys(_rank_journeys(merged.values(), str(request.get("ranking") or "recommended")), limit=limit) _publish_results( search_id, journeys=ranked, context=context, status="running", stage=f"transfers_{transfers}", message=f"Found {len(ranked)} option{'s' if len(ranked) != 1 else ''}; still searching..." if not direct_only else "Direct search complete.", ) if len(merged) <= before and ranked: stale_stages += 1 else: stale_stages = 0 best_count = max(best_count, len(ranked)) if ranked and stale_stages >= 2 and transfers >= 2: break if _major_hub_address_stage_is_complete(result_diagnostics, ranked, transfers=transfers, limit=limit): break complete_message = ( f"Search complete. Found {best_count} option{'s' if best_count != 1 else ''}." if best_count else "Search complete. No route found in the imported timetable." ) _publish_complete(search_id, message=complete_message) payload = journey_search_payload(search_id) if payload.get("status") == "complete" and not payload.get("error"): _progressive_cache_put(_progressive_cache_key(request), payload) def _major_hub_address_stage_is_complete( diagnostics: dict[str, Any] | None, ranked: list[dict], *, transfers: int, limit: int, ) -> bool: if transfers < 1 or not ranked or not isinstance(diagnostics, dict): return False address_access = diagnostics.get("address_access") if not isinstance(address_access, dict) or not address_access.get("major_hubs"): return False return len(ranked) >= min(3, limit) def _run_point_route_search(search_id: str, mode: str, request: dict[str, Any]) -> None: _publish_status(search_id, "running", f"Searching {mode} route...", mode) with SessionLocal() as db: from_location = resolve_location_summary(db, str(request.get("from_stop_id") or ""), source_ids=_csv_ints(request.get("source_id"))) to_location = resolve_location_summary(db, str(request.get("to_stop_id") or ""), source_ids=_csv_ints(request.get("source_id"))) if from_location.lon is None or from_location.lat is None: raise ValueError("Selected start has no coordinates.") if to_location.lon is None or to_location.lat is None: raise ValueError("Selected destination has no coordinates.") try: route = route_between_points( db, from_lon=float(from_location.lon), from_lat=float(from_location.lat), to_lon=float(to_location.lon), to_lat=float(to_location.lat), mode=mode, max_visited=300_000, ) message = f"{mode.title()} route found." except Exception as exc: # noqa: BLE001 - point routing should still return an approximate connector route = direct_route_between_points( db, from_lon=float(from_location.lon), from_lat=float(from_location.lat), to_lon=float(to_location.lon), to_lat=float(to_location.lat), mode=mode, reason=str(exc), ) message = f"{mode.title()} route approximated." context = { "from": _stop_payload(from_location), "to": _stop_payload(to_location), "mode": mode, } _publish_routing(search_id, route, context=context, message=message) _publish_complete(search_id, message=f"{mode.title()} route complete.") payload = journey_search_payload(search_id) if payload.get("status") == "complete" and not payload.get("error"): _progressive_cache_put(_progressive_cache_key(request), payload) def _cached_find_journeys( db, *, from_stop_id: str, to_stop_id: str, via_stop_id: object, source_ids: list[int] | None, departure: str, service_date: object, max_transfers: int, transfer_seconds: int, limit: int, ) -> dict[str, Any]: key = ( from_stop_id, to_stop_id, str(via_stop_id or ""), tuple(sorted(int(source_id) for source_id in source_ids or [])), departure, str(service_date or ""), int(max_transfers), int(transfer_seconds), int(limit), ) now = time.monotonic() with _lock: cached = _transit_stage_cache.get(key) if cached is not None: expires_at, payload = cached if expires_at > now: return _with_cache_status(payload, "memory") _transit_stage_cache.pop(key, None) durable = _durable_cache_get("transit_stage", key) if durable is not None: with _lock: _transit_stage_cache[key] = (now + TRANSIT_STAGE_CACHE_TTL_SECONDS, json.loads(json.dumps(durable))) _prune_timed_cache(_transit_stage_cache, TRANSIT_STAGE_CACHE_MAX_ENTRIES) return _with_cache_status(durable, "persistent") result = find_journeys( db=db, from_stop_id=from_stop_id, to_stop_id=to_stop_id, via_stop_id=via_stop_id, source_ids=source_ids, departure=departure, service_date=service_date, max_transfers=max_transfers, transfer_seconds=transfer_seconds, limit=limit, ) stored_result = json.loads(json.dumps(result)) with _lock: _transit_stage_cache[key] = (now + TRANSIT_STAGE_CACHE_TTL_SECONDS, stored_result) _prune_timed_cache(_transit_stage_cache, TRANSIT_STAGE_CACHE_MAX_ENTRIES) _durable_cache_put("transit_stage", key, stored_result, ttl_seconds=TRANSIT_STAGE_CACHE_TTL_SECONDS) return _with_cache_status(result, "miss") def _with_cache_status(payload: dict[str, Any], cache_status: str) -> dict[str, Any]: copied = json.loads(json.dumps(payload)) copied["_cache_status"] = cache_status return copied def _prune_timed_cache(cache: dict[tuple[object, ...], tuple[float, dict[str, Any]]], max_entries: int) -> None: if len(cache) <= max_entries: return oldest = sorted(cache.items(), key=lambda item: item[1][0])[: len(cache) - max_entries] for old_key, _ in oldest: cache.pop(old_key, None) def _durable_cache_get(cache_type: str, key: tuple[object, ...]) -> dict[str, Any] | None: storage_key = _durable_cache_key(cache_type, key) now = datetime.now(timezone.utc) try: with SessionLocal() as session: row = session.scalar(select(JourneySearchCache).where(JourneySearchCache.cache_key == storage_key)) if row is None: return None expires_at = _as_utc(row.expires_at) if expires_at is None or expires_at <= now: session.delete(row) session.commit() return None return json.loads(row.payload_json) except Exception: # noqa: BLE001 - cache misses must not break journey search return None def _durable_cache_put(cache_type: str, key: tuple[object, ...], payload: dict[str, Any], *, ttl_seconds: int) -> None: storage_key = _durable_cache_key(cache_type, key) now = datetime.now(timezone.utc) expires_at = now + timedelta(seconds=max(1, int(ttl_seconds))) try: with SessionLocal() as session: row = session.scalar(select(JourneySearchCache).where(JourneySearchCache.cache_key == storage_key)) if row is None: row = JourneySearchCache( cache_key=storage_key, cache_type=cache_type, payload_json=json.dumps(payload, separators=(",", ":")), created_at=now, updated_at=now, expires_at=expires_at, ) session.add(row) else: row.cache_type = cache_type row.payload_json = json.dumps(payload, separators=(",", ":")) row.updated_at = now row.expires_at = expires_at session.commit() except Exception: # noqa: BLE001 - cache writes are best-effort return def _durable_cache_key(cache_type: str, key: tuple[object, ...]) -> str: raw = json.dumps( { "version": JOURNEY_SEARCH_CACHE_VERSION, "cache_type": cache_type, "key": _json_safe(key), }, sort_keys=True, separators=(",", ":"), ) return hashlib.sha256(raw.encode("utf-8")).hexdigest() def _json_safe(value: object) -> object: if isinstance(value, tuple): return [_json_safe(item) for item in value] if isinstance(value, list): return [_json_safe(item) for item in value] if isinstance(value, dict): return {str(key): _json_safe(item) for key, item in sorted(value.items(), key=lambda item: str(item[0]))} if isinstance(value, (str, int, float, bool)) or value is None: return value if isinstance(value, (date, datetime)): return value.isoformat() return str(value) def _as_utc(value: datetime | None) -> datetime | None: if value is None: return None if value.tzinfo is None: return value.replace(tzinfo=timezone.utc) return value.astimezone(timezone.utc) def _progressive_cache_key(request: dict[str, Any]) -> tuple[object, ...]: source_ids = _csv_ints(request.get("source_id")) return ( str(request.get("mode") or "transit"), str(request.get("from_stop_id") or ""), str(request.get("to_stop_id") or ""), str(request.get("via_stop_id") or ""), tuple(sorted(int(source_id) for source_id in source_ids or [])), str(request.get("departure") or "08:00"), str(request.get("service_date") or ""), bool(request.get("direct_only")), str(request.get("ranking") or "recommended"), int(request.get("transfer_seconds") or 120), max(3, min(int(request.get("limit") or 5), 10)), ) def _progressive_cache_get(key: tuple[object, ...]) -> dict[str, Any] | None: now = time.monotonic() with _lock: cached = _progressive_search_cache.get(key) if cached is not None: expires_at, payload = cached if expires_at > now: copied = json.loads(json.dumps(payload)) copied["cache_status"] = "memory" return copied _progressive_search_cache.pop(key, None) durable = _durable_cache_get("progressive", key) if durable is None: return None with _lock: _progressive_search_cache[key] = (now + PROGRESSIVE_SEARCH_CACHE_TTL_SECONDS, json.loads(json.dumps(durable))) _prune_timed_cache(_progressive_search_cache, PROGRESSIVE_SEARCH_CACHE_MAX_ENTRIES) copied = json.loads(json.dumps(durable)) copied["cache_status"] = "persistent" return copied def _progressive_cache_put(key: tuple[object, ...], payload: dict[str, Any]) -> None: stored_payload = json.loads(json.dumps(payload)) stored_payload.pop("cache_status", None) with _lock: _progressive_search_cache[key] = (time.monotonic() + PROGRESSIVE_SEARCH_CACHE_TTL_SECONDS, stored_payload) _prune_timed_cache(_progressive_search_cache, PROGRESSIVE_SEARCH_CACHE_MAX_ENTRIES) _durable_cache_put("progressive", key, stored_payload, ttl_seconds=PROGRESSIVE_SEARCH_CACHE_TTL_SECONDS) def _apply_cached_payload(state: _SearchState, payload: dict[str, Any]) -> None: state.status = str(payload.get("status") or "complete") state.message = "Cached result." state.stage = str(payload.get("stage") or "cached") state.journeys = json.loads(json.dumps(payload.get("journeys") or [])) state.routing = json.loads(json.dumps(payload.get("routing"))) if payload.get("routing") is not None else None state.context = { key: value for key, value in payload.items() if key not in {"search_id", "status", "stage", "message", "complete", "error", "journeys", "routing", "created_at", "updated_at"} } state.error = None state.complete = True state.updated_at = time.time() def _publish_status(search_id: str, status: str, message: str, stage: str) -> None: with _lock: state = _searches.get(search_id) if state is None or state.cancelled: return state.status = status state.message = message state.stage = stage state.updated_at = time.time() def _publish_results(search_id: str, *, journeys: list[dict], context: dict[str, Any], status: str, stage: str, message: str) -> None: with _lock: state = _searches.get(search_id) if state is None or state.cancelled: return state.status = status state.stage = stage state.message = message state.journeys = list(journeys) state.context = dict(context) state.updated_at = time.time() def _publish_routing(search_id: str, routing: dict[str, Any], *, context: dict[str, Any], message: str) -> None: with _lock: state = _searches.get(search_id) if state is None or state.cancelled: return state.status = "running" state.stage = str(routing.get("mode") or "route") state.message = message state.routing = routing state.context = dict(context) state.updated_at = time.time() def _publish_complete(search_id: str, *, message: str) -> None: with _lock: state = _searches.get(search_id) if state is None or state.cancelled: return state.status = "complete" state.message = message state.complete = True state.updated_at = time.time() _clear_inflight_search_locked(state) def _publish_error(search_id: str, message: str) -> None: with _lock: state = _searches.get(search_id) if state is None: return state.status = "error" state.stage = "error" state.message = message state.error = message state.complete = True state.updated_at = time.time() _clear_inflight_search_locked(state) def _clear_inflight_search_locked(state: _SearchState) -> None: if state.cache_key is not None and _progressive_search_inflight.get(state.cache_key) == state.id: _progressive_search_inflight.pop(state.cache_key, None) def _is_cancelled(search_id: str) -> bool: with _lock: state = _searches.get(search_id) return state is None or state.cancelled def _payload(state: _SearchState) -> dict[str, Any]: return { "search_id": state.id, "status": state.status, "stage": state.stage, "message": state.message, "complete": state.complete, "error": state.error, "journeys": json.loads(json.dumps(state.journeys)), "routing": json.loads(json.dumps(state.routing)) if state.routing is not None else None, "created_at": state.created_at, "updated_at": state.updated_at, **json.loads(json.dumps(state.context)), } def _context_from_result(result: dict[str, Any]) -> dict[str, Any]: return { key: value for key, value in result.items() if key not in {"journeys"} and key not in {"error"} } def _journey_key(journey: dict[str, Any]) -> str: parts = [] for leg in journey.get("legs") or []: parts.append( "|".join( str(part or "") for part in [ leg.get("dataset_id"), leg.get("mode"), leg.get("route_id"), leg.get("trip_id"), (leg.get("from") or {}).get("stop_id") or (leg.get("from") or {}).get("name"), (leg.get("to") or {}).get("stop_id") or (leg.get("to") or {}).get("name"), leg.get("departure_time"), leg.get("arrival_time"), ] ) ) return "||".join(parts) def _rank_journeys(journeys, ranking: str) -> list[dict]: def key(journey: dict[str, Any]) -> tuple[float, float, int, float]: departure = journey.get("departure_seconds") arrival = journey.get("arrival_seconds") duration = journey.get("duration_minutes") transfers = int(journey.get("transfers") or 0) walking = sum(float(leg.get("distance_m") or 0) for leg in journey.get("legs") or [] if leg.get("mode") == "walk") walking_seconds = walking / 1.35 if ranking == "duration": return ( float("inf") if duration is None else float(duration), float("inf") if arrival is None else float(arrival), transfers, walking, ) if ranking == "fewest_transfers": return ( transfers, float("inf") if arrival is None else float(arrival), float("inf") if duration is None else float(duration), walking, ) if ranking == "earliest_arrival": return ( float("inf") if arrival is None else float(arrival), float("inf") if duration is None else float(duration), transfers, walking, ) return ( float("inf") if arrival is None else float(arrival) + transfers * 600 + walking_seconds, float("inf") if arrival is None else float(arrival), transfers, walking, ) return sorted((dict(journey) for journey in journeys), key=key) def _select_diverse_journeys(journeys: list[dict], *, limit: int) -> list[dict]: selected: list[dict] = [] selected_exact: set[str] = set() selected_diversity: set[tuple[object, ...]] = set() for journey in journeys: exact_key = _journey_key(journey) if exact_key in selected_exact: continue diversity_key = _journey_diversity_key(journey) if diversity_key in selected_diversity and len(selected) >= 3: continue selected.append(journey) selected_exact.add(exact_key) selected_diversity.add(diversity_key) if len(selected) >= limit: return selected if len(selected) >= min(3, limit): return selected for journey in journeys: exact_key = _journey_key(journey) if exact_key in selected_exact: continue selected.append(journey) selected_exact.add(exact_key) if len(selected) >= min(3, limit): break return _ensure_walk_only_option(selected, journeys, limit=limit) def _ensure_walk_only_option(selected: list[dict], ranked: list[dict], *, limit: int) -> list[dict]: if any(_journey_is_walk_only(journey) for journey in selected): return selected walk = next((journey for journey in ranked if _journey_is_walk_only(journey)), None) if walk is None: return selected if len(selected) < limit: return [*selected, walk] if selected: selected[-1] = walk return selected def _journey_is_walk_only(journey: dict) -> bool: legs = journey.get("legs") or [] return bool(legs) and all(leg.get("mode") == "walk" for leg in legs) def _journey_diversity_key(journey: dict[str, Any]) -> tuple[object, ...]: route_signature = tuple( str(leg.get("route_ref") or leg.get("route_id") or leg.get("mode") or "") for leg in journey.get("legs") or [] if leg.get("mode") != "walk" ) departure = journey.get("departure_seconds") time_band = None if departure is None else int(departure) // (30 * 60) return (int(journey.get("transfers") or 0), route_signature, time_band) def _csv_ints(value: object) -> list[int] | None: if value is None: return None items = [item.strip() for item in str(value).split(",") if item.strip()] if not items: return None return [int(item) for item in items] def _stop_payload(stop) -> dict[str, Any]: return { "id": stop.id, "dataset_id": stop.dataset_id, "stop_id": stop.stop_id, "name": stop.name, "lat": stop.lat, "lon": stop.lon, } def _prune_old_searches() -> None: now = time.time() stale = [ search_id for search_id, state in _searches.items() if now - state.updated_at > 15 * 60 or (state.complete and now - state.updated_at > 3 * 60) ] for search_id in stale: state = _searches.pop(search_id, None) if state is not None: _clear_inflight_search_locked(state)