718 lines
27 KiB
Python
718 lines
27 KiB
Python
from __future__ import annotations

import hashlib
import json
import logging
import threading
import time
import uuid
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from datetime import date, datetime, timedelta, timezone
from typing import Any

from sqlalchemy import select

from app.address_search import is_location_token
from app.db import SessionLocal
from app.journey import find_journeys, parse_service_date, resolve_location_summary
from app.models import JourneySearchCache
from app.routing import direct_route_between_points, route_between_points
|
|
|
|
|
|
# Highest transfer budget explored by the staged transit search (stages 0..N).
MAX_PROGRESSIVE_TRANSFERS = 5

# Per-stage find_journeys results, kept in memory and in the durable cache.
TRANSIT_STAGE_CACHE_TTL_SECONDS = 300  # 5 minutes
TRANSIT_STAGE_CACHE_MAX_ENTRIES = 256

# Completed progressive-search payloads, kept in memory and in the durable cache.
PROGRESSIVE_SEARCH_CACHE_TTL_SECONDS = 600  # 10 minutes
PROGRESSIVE_SEARCH_CACHE_MAX_ENTRIES = 128

# Hashed into every durable cache key; changing it orphans all persisted entries.
JOURNEY_SEARCH_CACHE_VERSION = "journey-search-v7"

# Worker threads that run searches submitted by start_journey_search.
_executor = ThreadPoolExecutor(max_workers=2, thread_name_prefix="journey-search")

# Re-entrant lock guarding the registries and in-memory caches below.
_lock = threading.RLock()

# search_id -> live state of that search.
_searches: dict[str, "_SearchState"] = {}

# progressive cache key -> id of the search currently computing that request.
_progressive_search_inflight: dict[tuple[object, ...], str] = {}

# Timed caches: key -> (expiry on the time.monotonic() clock, payload).
_transit_stage_cache: dict[tuple[object, ...], tuple[float, dict[str, Any]]] = {}
_progressive_search_cache: dict[tuple[object, ...], tuple[float, dict[str, Any]]] = {}
|
|
|
|
|
|
@dataclass
|
|
class _SearchState:
|
|
id: str
|
|
request: dict[str, Any]
|
|
cache_key: tuple[object, ...] | None = None
|
|
status: str = "queued"
|
|
message: str = "Queued."
|
|
stage: str = "queued"
|
|
journeys: list[dict] = field(default_factory=list)
|
|
routing: dict[str, Any] | None = None
|
|
context: dict[str, Any] = field(default_factory=dict)
|
|
error: str | None = None
|
|
created_at: float = field(default_factory=time.time)
|
|
updated_at: float = field(default_factory=time.time)
|
|
complete: bool = False
|
|
cancelled: bool = False
|
|
|
|
|
|
def start_journey_search(request: dict[str, Any]) -> dict[str, Any]:
    """Start (or join) a progressive journey search and return its payload.

    * Cache hit: a new search is registered already completed with the cached
      result; no worker is scheduled.
    * Identical search still queued/running: that search's payload is returned
      and nothing new is registered.
    * Otherwise: a new search is registered, recorded as the in-flight search
      for its cache key, and submitted to the executor.
    """
    cache_key = _progressive_cache_key(request)
    hit = _progressive_cache_get(cache_key)
    new_id = uuid.uuid4().hex
    new_state = _SearchState(id=new_id, request=dict(request), cache_key=cache_key)
    needs_worker = hit is None
    if not needs_worker:
        _apply_cached_payload(new_state, hit)

    with _lock:
        _prune_old_searches()
        if needs_worker:
            running_id = _progressive_search_inflight.get(cache_key)
            running = _searches.get(running_id) if running_id is not None else None
            if running is not None and not (running.complete or running.cancelled):
                return _payload(running)
            _progressive_search_inflight[cache_key] = new_id
        _searches[new_id] = new_state

    if needs_worker:
        _executor.submit(_run_search, new_id)
    return journey_search_payload(new_id)
|
|
|
|
|
|
def journey_search_payload(search_id: str) -> dict[str, Any]:
    """Return a detached snapshot payload of search *search_id*.

    Raises:
        KeyError: if no search with that id is registered (unknown or pruned).
    """
    with _lock:
        found = _searches.get(search_id)
        if found is not None:
            return _payload(found)
    raise KeyError(search_id)
|
|
|
|
|
|
def cancel_journey_search(search_id: str) -> dict[str, Any]:
    """Flag search *search_id* as cancelled and return its payload.

    A search that already finished keeps its status and results; an
    unfinished one becomes a completed "cancelled" search and releases its
    in-flight deduplication slot.

    Raises:
        KeyError: if no search with that id is registered.
    """
    with _lock:
        target = _searches.get(search_id)
        if target is None:
            raise KeyError(search_id)
        target.cancelled = True
        if target.complete:
            return _payload(target)
        target.status, target.message = "cancelled", "Search cancelled."
        target.complete = True
        target.updated_at = time.time()
        _clear_inflight_search_locked(target)
        return _payload(target)
|
|
|
|
|
|
def _run_search(search_id: str) -> None:
    """Executor entry point: run the search registered under *search_id*.

    Marks the search as running, then dispatches on ``request["mode"]``:
    "walk", "drive" and "car" (mapped to "drive") use point-to-point routing;
    any other mode runs the staged transit search. Returns immediately if the
    search was pruned or cancelled before a worker picked it up. Any exception
    is logged with its traceback and turned into an error state for pollers.
    """
    with _lock:
        state = _searches.get(search_id)
        if state is None or state.cancelled:
            return
        state.status = "running"
        state.stage = "starting"
        state.message = "Starting search..."
        state.updated_at = time.time()
        request = dict(state.request)
    try:
        mode = str(request.get("mode") or "transit")
        if mode in {"walk", "drive", "car"}:
            _run_point_route_search(search_id, "drive" if mode == "car" else mode, request)
        else:
            _run_transit_search(search_id, request)
    except Exception as exc:  # noqa: BLE001 - report progressive-search failure to client
        # This runs on an executor thread whose Future nobody inspects, so the
        # traceback would otherwise be lost; keep it for operators.
        logging.getLogger(__name__).exception("Journey search %s failed", search_id)
        # Some exceptions stringify to "" (e.g. a bare ValueError()); fall back to
        # the exception type so the client still receives a non-empty error.
        _publish_error(search_id, str(exc) or type(exc).__name__)
|
|
|
|
|
|
def _run_transit_search(search_id: str, request: dict[str, Any]) -> None:
    """Run a transit search in widening transfer stages, publishing partial results.

    Stage ``n`` asks ``_cached_find_journeys`` for journeys with at most ``n``
    transfers. Only stage 0 runs when ``direct_only`` is set; otherwise stages
    0 through ``MAX_PROGRESSIVE_TRANSFERS`` run. After every stage, journeys
    from all stages are merged by ``_journey_key``, ranked, diversified and
    published so pollers see results as they arrive. The loop stops early when:

    * the search was cancelled or pruned (returns without completing);
    * from two transfers on, two consecutive stages added no new journey; or
    * ``_major_hub_address_stage_is_complete`` reports enough results.

    On normal completion, the final payload is stored in the progressive cache.
    """
    direct_only = bool(request.get("direct_only"))
    limit = max(3, min(int(request.get("limit") or 5), 10))
    transfer_seconds = max(0, min(int(request.get("transfer_seconds") or 120), 3600))
    source_ids = _csv_ints(request.get("source_id"))
    service_date = request.get("service_date") or None
    ranking = str(request.get("ranking") or "recommended")
    stages = [0] if direct_only else list(range(0, MAX_PROGRESSIVE_TRANSFERS + 1))
    address_search = is_location_token(request.get("from_stop_id")) or is_location_token(request.get("to_stop_id"))
    # Stop-to-stop searches request at least 10 candidates per stage so merging
    # and diversity selection have more to choose from.
    stage_limit = limit if address_search else max(limit, 10)
    merged: dict[str, dict] = {}
    context: dict[str, Any] = {}
    # Grows by one entry per executed stage; only snapshots of it are published.
    diagnostics: dict[str, Any] = {"stages": []}
    best_count = 0
    stale_stages = 0  # consecutive stages that added no new journey

    for transfers in stages:
        if _is_cancelled(search_id):
            return
        label = "direct" if transfers == 0 else f"up to {transfers} transfer{'s' if transfers != 1 else ''}"
        _publish_status(search_id, "running", f"Searching {label}...", f"transfers_{transfers}")
        stage_started_at = time.monotonic()
        with SessionLocal() as db:
            result = _cached_find_journeys(
                db,
                from_stop_id=str(request.get("from_stop_id") or ""),
                to_stop_id=str(request.get("to_stop_id") or ""),
                via_stop_id=request.get("via_stop_id") or None,
                source_ids=source_ids,
                departure=str(request.get("departure") or "08:00"),
                service_date=service_date,
                max_transfers=transfers,
                transfer_seconds=transfer_seconds,
                limit=stage_limit,
            )
        cache_status = str(result.pop("_cache_status", "miss"))
        elapsed_ms = int((time.monotonic() - stage_started_at) * 1000)
        stage_diagnostics = {
            "transfers": transfers,
            "cache": cache_status,
            "elapsed_ms": elapsed_ms,
            "journeys": len(result.get("journeys") or []),
        }
        result_diagnostics = result.get("diagnostics")
        if isinstance(result_diagnostics, dict):
            stage_diagnostics["details"] = result_diagnostics
        diagnostics["stages"].append(stage_diagnostics)
        context = _context_from_result(result)
        # Publish a snapshot, not the live accumulator. Later stages append to
        # diagnostics["stages"] without holding _lock, while _payload
        # serializes the published context under _lock on other threads.
        context["diagnostics"] = {**diagnostics, "stages": list(diagnostics["stages"])}
        before = len(merged)
        for journey in result.get("journeys") or []:
            merged.setdefault(_journey_key(journey), journey)
        ranked = _select_diverse_journeys(_rank_journeys(merged.values(), ranking), limit=limit)
        _publish_results(
            search_id,
            journeys=ranked,
            context=context,
            status="running",
            stage=f"transfers_{transfers}",
            message=f"Found {len(ranked)} option{'s' if len(ranked) != 1 else ''}; still searching..." if not direct_only else "Direct search complete.",
        )
        if len(merged) <= before and ranked:
            stale_stages += 1
        else:
            stale_stages = 0
        best_count = max(best_count, len(ranked))
        if ranked and stale_stages >= 2 and transfers >= 2:
            break
        if _major_hub_address_stage_is_complete(result_diagnostics, ranked, transfers=transfers, limit=limit):
            break

    complete_message = (
        f"Search complete. Found {best_count} option{'s' if best_count != 1 else ''}."
        if best_count
        else "Search complete. No route found in the imported timetable."
    )
    _publish_complete(search_id, message=complete_message)
    payload = journey_search_payload(search_id)
    if payload.get("status") == "complete" and not payload.get("error"):
        _progressive_cache_put(_progressive_cache_key(request), payload)
|
|
|
|
|
|
def _major_hub_address_stage_is_complete(
    diagnostics: dict[str, Any] | None,
    ranked: list[dict],
    *,
    transfers: int,
    limit: int,
) -> bool:
    """Decide whether an address search served via major hubs can stop early.

    True only when the stage allowed at least one transfer, produced ranked
    results, and its diagnostics carry a dict ``address_access`` with a truthy
    ``major_hubs`` entry. The result list must also hold at least
    ``min(3, limit)`` journeys.
    """
    if transfers >= 1 and ranked and isinstance(diagnostics, dict):
        access = diagnostics.get("address_access")
        if isinstance(access, dict) and access.get("major_hubs"):
            return len(ranked) >= min(3, limit)
    return False
|
|
|
|
|
|
def _run_point_route_search(search_id: str, mode: str, request: dict[str, Any]) -> None:
    """Compute a single walk/drive route between the request's two locations.

    Both endpoints are resolved with ``resolve_location_summary`` and must have
    coordinates; otherwise ``ValueError`` is raised. ``route_between_points`` is
    tried first. If it raises, ``direct_route_between_points`` builds an
    approximate connector instead, receiving the failure text as ``reason``.
    The route is then published, the search completed, and a successful payload
    stored in the progressive cache.
    """
    _publish_status(search_id, "running", f"Searching {mode} route...", mode)
    label = mode.title()
    with SessionLocal() as db:
        origin = resolve_location_summary(
            db, str(request.get("from_stop_id") or ""), source_ids=_csv_ints(request.get("source_id"))
        )
        destination = resolve_location_summary(
            db, str(request.get("to_stop_id") or ""), source_ids=_csv_ints(request.get("source_id"))
        )
        if origin.lon is None or origin.lat is None:
            raise ValueError("Selected start has no coordinates.")
        if destination.lon is None or destination.lat is None:
            raise ValueError("Selected destination has no coordinates.")
        endpoints = {
            "from_lon": float(origin.lon),
            "from_lat": float(origin.lat),
            "to_lon": float(destination.lon),
            "to_lat": float(destination.lat),
        }
        try:
            route = route_between_points(db, **endpoints, mode=mode, max_visited=300_000)
            message = f"{label} route found."
        except Exception as exc:  # noqa: BLE001 - point routing should still return an approximate connector
            route = direct_route_between_points(db, **endpoints, mode=mode, reason=str(exc))
            message = f"{label} route approximated."

    _publish_routing(
        search_id,
        route,
        context={"from": _stop_payload(origin), "to": _stop_payload(destination), "mode": mode},
        message=message,
    )
    _publish_complete(search_id, message=f"{label} route complete.")
    final = journey_search_payload(search_id)
    if final.get("status") == "complete" and not final.get("error"):
        _progressive_cache_put(_progressive_cache_key(request), final)
|
|
|
|
|
|
def _cached_find_journeys(
    db,
    *,
    from_stop_id: str,
    to_stop_id: str,
    via_stop_id: object,
    source_ids: list[int] | None,
    departure: str,
    service_date: object,
    max_transfers: int,
    transfer_seconds: int,
    limit: int,
) -> dict[str, Any]:
    """Memoized wrapper around ``find_journeys`` for one search stage.

    Lookup order: the in-process stage cache, then the durable cache, then a
    real ``find_journeys`` call, whose result is stored in both caches. Every
    path returns a detached deep copy tagged with ``_cache_status``: "memory",
    "persistent" or "miss". Entries expire ``TRANSIT_STAGE_CACHE_TTL_SECONDS``
    after the moment this call started.
    """
    stage_key = (
        from_stop_id,
        to_stop_id,
        str(via_stop_id or ""),
        tuple(sorted(int(sid) for sid in source_ids or [])),
        departure,
        str(service_date or ""),
        int(max_transfers),
        int(transfer_seconds),
        int(limit),
    )
    started = time.monotonic()
    fresh_until = started + TRANSIT_STAGE_CACHE_TTL_SECONDS

    with _lock:
        entry = _transit_stage_cache.get(stage_key)
        if entry is not None:
            entry_expiry, entry_payload = entry
            if entry_expiry > started:
                return _with_cache_status(entry_payload, "memory")
            _transit_stage_cache.pop(stage_key, None)

    persisted = _durable_cache_get("transit_stage", stage_key)
    if persisted is not None:
        with _lock:
            _transit_stage_cache[stage_key] = (fresh_until, json.loads(json.dumps(persisted)))
            _prune_timed_cache(_transit_stage_cache, TRANSIT_STAGE_CACHE_MAX_ENTRIES)
        return _with_cache_status(persisted, "persistent")

    computed = find_journeys(
        db=db,
        from_stop_id=from_stop_id,
        to_stop_id=to_stop_id,
        via_stop_id=via_stop_id,
        source_ids=source_ids,
        departure=departure,
        service_date=service_date,
        max_transfers=max_transfers,
        transfer_seconds=transfer_seconds,
        limit=limit,
    )
    snapshot = json.loads(json.dumps(computed))
    with _lock:
        _transit_stage_cache[stage_key] = (fresh_until, snapshot)
        _prune_timed_cache(_transit_stage_cache, TRANSIT_STAGE_CACHE_MAX_ENTRIES)
    _durable_cache_put("transit_stage", stage_key, snapshot, ttl_seconds=TRANSIT_STAGE_CACHE_TTL_SECONDS)
    return _with_cache_status(computed, "miss")
|
|
|
|
|
|
def _with_cache_status(payload: dict[str, Any], cache_status: str) -> dict[str, Any]:
    """Return a deep copy of *payload* with ``_cache_status`` set to *cache_status*.

    The copy is made by a JSON round trip, so tuples come back as lists and
    non-string keys as strings; *payload* itself is never modified.
    """
    return {**json.loads(json.dumps(payload)), "_cache_status": cache_status}
|
|
|
|
|
|
def _prune_timed_cache(cache: dict[tuple[object, ...], tuple[float, dict[str, Any]]], max_entries: int) -> None:
    """Shrink *cache* in place to at most *max_entries* entries.

    Entries are ``key -> (expires_at, payload)``; the ones expiring soonest are
    evicted first (ties keep insertion order). Callers hold ``_lock``.
    """
    overflow = len(cache) - max_entries
    if overflow <= 0:
        return
    soonest_first = sorted(cache, key=lambda cache_key: cache[cache_key][0])
    for evicted in soonest_first[:overflow]:
        cache.pop(evicted, None)
|
|
|
|
|
|
def _durable_cache_get(cache_type: str, key: tuple[object, ...]) -> dict[str, Any] | None:
    """Read a persisted ``JourneySearchCache`` entry, or ``None`` on a miss.

    Rows with no expiry or an expiry at or before now (UTC) count as misses and
    are deleted. Best-effort: any exception during the lookup also yields
    ``None``.
    """
    digest = _durable_cache_key(cache_type, key)
    current = datetime.now(timezone.utc)
    try:
        with SessionLocal() as session:
            lookup = select(JourneySearchCache).where(JourneySearchCache.cache_key == digest)
            row = session.scalar(lookup)
            if row is None:
                return None
            deadline = _as_utc(row.expires_at)
            if deadline is not None and deadline > current:
                return json.loads(row.payload_json)
            # Expired: drop it so later lookups do not keep re-reading it.
            session.delete(row)
            session.commit()
            return None
    except Exception:  # noqa: BLE001 - cache misses must not break journey search
        return None
|
|
|
|
|
|
def _durable_cache_put(cache_type: str, key: tuple[object, ...], payload: dict[str, Any], *, ttl_seconds: int) -> None:
    """Insert or overwrite the persisted entry for *key* (best-effort).

    The payload is stored as compact JSON and expires after *ttl_seconds*
    (at least 1 second). Any failure during the database write is ignored.
    """
    digest = _durable_cache_key(cache_type, key)
    written_at = datetime.now(timezone.utc)
    expires = written_at + timedelta(seconds=max(1, int(ttl_seconds)))
    try:
        with SessionLocal() as session:
            existing = session.scalar(select(JourneySearchCache).where(JourneySearchCache.cache_key == digest))
            encoded = json.dumps(payload, separators=(",", ":"))
            if existing is not None:
                existing.cache_type = cache_type
                existing.payload_json = encoded
                existing.updated_at = written_at
                existing.expires_at = expires
            else:
                session.add(
                    JourneySearchCache(
                        cache_key=digest,
                        cache_type=cache_type,
                        payload_json=encoded,
                        created_at=written_at,
                        updated_at=written_at,
                        expires_at=expires,
                    )
                )
            session.commit()
    except Exception:  # noqa: BLE001 - cache writes are best-effort
        return
|
|
|
|
|
|
def _durable_cache_key(cache_type: str, key: tuple[object, ...]) -> str:
    """Return the SHA-256 hex digest identifying *key* in the durable cache.

    The digest covers the cache version, the cache type and a JSON-safe form of
    *key*, serialized canonically (sorted keys, compact separators).
    """
    envelope = {
        "cache_type": cache_type,
        "key": _json_safe(key),
        "version": JOURNEY_SEARCH_CACHE_VERSION,
    }
    canonical = json.dumps(envelope, sort_keys=True, separators=(",", ":"))
    return hashlib.sha256(canonical.encode("utf-8")).hexdigest()
|
|
|
|
|
|
def _json_safe(value: object) -> object:
    """Convert *value* recursively into plain JSON-serializable data.

    Tuples and lists become lists; dicts get string keys in sorted order;
    str/int/float/bool/None pass through; dates and datetimes become ISO
    strings; anything else becomes ``str(value)``.
    """
    if value is None or isinstance(value, (str, int, float, bool)):
        return value
    if isinstance(value, (tuple, list)):
        return [_json_safe(element) for element in value]
    if isinstance(value, dict):
        ordered = sorted(value.items(), key=lambda pair: str(pair[0]))
        return {str(name): _json_safe(element) for name, element in ordered}
    if isinstance(value, (date, datetime)):
        return value.isoformat()
    return str(value)
|
|
|
|
|
|
def _as_utc(value: datetime | None) -> datetime | None:
    """Normalize *value* to an aware UTC datetime.

    Naive datetimes are assumed to already be UTC and just get ``tzinfo``
    attached; aware ones are converted. ``None`` passes through.
    """
    if value is None:
        return None
    if value.tzinfo is not None:
        return value.astimezone(timezone.utc)
    return value.replace(tzinfo=timezone.utc)
|
|
|
|
|
|
def _progressive_cache_key(request: dict[str, Any]) -> tuple[object, ...]:
    """Build the key identifying a progressive search request.

    The key is used both for in-flight deduplication and for the progressive
    result cache. ``limit`` and ``transfer_seconds`` are clamped exactly as
    ``_run_transit_search`` clamps them (3..10 and 0..3600). Requests that
    would run the identical search therefore share one key. Values already
    in range produce the same key as before, so persisted entries stay valid.
    """
    source_ids = _csv_ints(request.get("source_id"))
    return (
        str(request.get("mode") or "transit"),
        str(request.get("from_stop_id") or ""),
        str(request.get("to_stop_id") or ""),
        str(request.get("via_stop_id") or ""),
        tuple(sorted(int(source_id) for source_id in source_ids or [])),
        str(request.get("departure") or "08:00"),
        str(request.get("service_date") or ""),
        bool(request.get("direct_only")),
        str(request.get("ranking") or "recommended"),
        max(0, min(int(request.get("transfer_seconds") or 120), 3600)),
        max(3, min(int(request.get("limit") or 5), 10)),
    )
|
|
|
|
|
|
def _progressive_cache_get(key: tuple[object, ...]) -> dict[str, Any] | None:
    """Return a cached progressive-search payload for *key*, or ``None``.

    The in-memory cache is checked first, then the durable cache. A durable hit
    is copied into memory with a fresh TTL. The returned copy carries
    ``cache_status``: "memory" or "persistent".
    """
    checked_at = time.monotonic()
    with _lock:
        entry = _progressive_search_cache.get(key)
        if entry is not None:
            expiry, stored = entry
            if expiry > checked_at:
                hit = json.loads(json.dumps(stored))
                hit["cache_status"] = "memory"
                return hit
            _progressive_search_cache.pop(key, None)

    persisted = _durable_cache_get("progressive", key)
    if persisted is None:
        return None
    with _lock:
        _progressive_search_cache[key] = (checked_at + PROGRESSIVE_SEARCH_CACHE_TTL_SECONDS, json.loads(json.dumps(persisted)))
        _prune_timed_cache(_progressive_search_cache, PROGRESSIVE_SEARCH_CACHE_MAX_ENTRIES)
    return {**json.loads(json.dumps(persisted)), "cache_status": "persistent"}
|
|
|
|
|
|
def _progressive_cache_put(key: tuple[object, ...], payload: dict[str, Any]) -> None:
    """Store a completed search payload in the memory and durable caches."""
    snapshot = json.loads(json.dumps(payload))
    # cache_status describes how a payload was served; readers add it back.
    snapshot.pop("cache_status", None)
    expires = time.monotonic() + PROGRESSIVE_SEARCH_CACHE_TTL_SECONDS
    with _lock:
        _progressive_search_cache[key] = (expires, snapshot)
        _prune_timed_cache(_progressive_search_cache, PROGRESSIVE_SEARCH_CACHE_MAX_ENTRIES)
    _durable_cache_put("progressive", key, snapshot, ttl_seconds=PROGRESSIVE_SEARCH_CACHE_TTL_SECONDS)
|
|
|
|
|
|
def _apply_cached_payload(state: _SearchState, payload: dict[str, Any]) -> None:
    """Load a cached search payload into *state* and mark it complete.

    Top-level payload fields that ``_payload`` generates from the state itself
    are skipped. Every other key (e.g. ``cache_status``, from/to info) becomes
    part of ``state.context``. Journeys and routing are deep-copied.
    """
    reserved = {
        "search_id", "status", "stage", "message", "complete",
        "error", "journeys", "routing", "created_at", "updated_at",
    }
    cached_routing = payload.get("routing")
    state.status = str(payload.get("status") or "complete")
    state.message = "Cached result."
    state.stage = str(payload.get("stage") or "cached")
    state.journeys = json.loads(json.dumps(payload.get("journeys") or []))
    state.routing = None if cached_routing is None else json.loads(json.dumps(cached_routing))
    state.context = {name: item for name, item in payload.items() if name not in reserved}
    state.error = None
    state.complete = True
    state.updated_at = time.time()
|
|
|
|
|
|
def _publish_status(search_id: str, status: str, message: str, stage: str) -> None:
    """Update the progress fields of a live search.

    Does nothing if the search is unknown or was cancelled.
    """
    with _lock:
        target = _searches.get(search_id)
        if target is not None and not target.cancelled:
            target.status, target.message, target.stage = status, message, stage
            target.updated_at = time.time()
|
|
|
|
|
|
def _publish_results(search_id: str, *, journeys: list[dict], context: dict[str, Any], status: str, stage: str, message: str) -> None:
    """Replace a live search's journeys and context, and update its progress.

    The journey list and the context dict are shallow-copied. Does nothing if
    the search is unknown or was cancelled.
    """
    with _lock:
        target = _searches.get(search_id)
        if target is None or target.cancelled:
            return
        target.status, target.stage, target.message = status, stage, message
        target.journeys = [*journeys]
        target.context = {**context}
        target.updated_at = time.time()
|
|
|
|
|
|
def _publish_routing(search_id: str, routing: dict[str, Any], *, context: dict[str, Any], message: str) -> None:
    """Attach a point-to-point route to a live search; status stays "running".

    The stage is ``routing["mode"]``, falling back to "route". *routing* is
    stored by reference and *context* is shallow-copied. Does nothing if the
    search is unknown or was cancelled.
    """
    with _lock:
        target = _searches.get(search_id)
        if target is None or target.cancelled:
            return
        target.status = "running"
        target.stage = str(routing.get("mode") or "route")
        target.message = message
        target.routing = routing
        target.context = {**context}
        target.updated_at = time.time()
|
|
|
|
|
|
def _publish_complete(search_id: str, *, message: str) -> None:
    """Move a live search into the terminal "complete" state.

    Also frees its in-flight deduplication slot. Does nothing if the search
    is unknown or was cancelled.
    """
    with _lock:
        target = _searches.get(search_id)
        if target is None or target.cancelled:
            return
        target.status, target.message = "complete", message
        target.complete = True
        target.updated_at = time.time()
        _clear_inflight_search_locked(target)
|
|
|
|
|
|
def _publish_error(search_id: str, message: str) -> None:
    """Move a search into the terminal "error" state with *message*.

    No-op when the search is unknown (e.g. pruned) or already finished. This
    includes a search the client cancelled: a late worker failure must not
    overwrite a terminal status ("cancelled" or "complete") the client has
    already seen. The other ``_publish_*`` helpers have the same guard.
    """
    with _lock:
        state = _searches.get(search_id)
        if state is None or state.cancelled or state.complete:
            return
        state.status = "error"
        state.stage = "error"
        state.message = message
        state.error = message
        state.complete = True
        state.updated_at = time.time()
        _clear_inflight_search_locked(state)
|
|
|
|
|
|
def _clear_inflight_search_locked(state: _SearchState) -> None:
    """Release *state*'s in-flight dedup slot if it still owns it.

    The caller must already hold ``_lock``. A slot taken over by a newer
    search with the same key is left alone.
    """
    slot = state.cache_key
    if slot is None:
        return
    if _progressive_search_inflight.get(slot) == state.id:
        _progressive_search_inflight.pop(slot, None)
|
|
|
|
|
|
def _is_cancelled(search_id: str) -> bool:
    """Return True if the search was cancelled or is no longer registered."""
    with _lock:
        current = _searches.get(search_id)
        return True if current is None else current.cancelled
|
|
|
|
|
|
def _payload(state: _SearchState) -> dict[str, Any]:
    """Serialize *state* into a detached, JSON-safe response dict.

    The fixed status fields come first. Every ``state.context`` entry is merged
    in last, so a context key with the same name overrides a fixed field's
    value. Journeys, routing and context are deep-copied through a JSON round
    trip, so callers cannot mutate live state.
    """
    def detached(value):
        return json.loads(json.dumps(value))

    snapshot = {
        "search_id": state.id,
        "status": state.status,
        "stage": state.stage,
        "message": state.message,
        "complete": state.complete,
        "error": state.error,
        "journeys": detached(state.journeys),
        "routing": None if state.routing is None else detached(state.routing),
        "created_at": state.created_at,
        "updated_at": state.updated_at,
    }
    snapshot.update(detached(state.context))
    return snapshot
|
|
|
|
|
|
def _context_from_result(result: dict[str, Any]) -> dict[str, Any]:
    """Return a shallow copy of *result* without its ``journeys`` and ``error`` keys."""
    dropped = ("journeys", "error")
    return {name: item for name, item in result.items() if name not in dropped}
|
|
|
|
|
|
def _journey_key(journey: dict[str, Any]) -> str:
    """Build an exact-identity string for *journey* from its legs.

    Each leg contributes the following, joined by "|": dataset, mode, route,
    trip, the from/to stop (stop_id, else name), and departure/arrival times.
    Falsy parts become "". Legs are joined by "||"; a journey without legs
    yields "".
    """
    def endpoint(stop):
        stop = stop or {}
        return stop.get("stop_id") or stop.get("name")

    def leg_signature(leg):
        parts = (
            leg.get("dataset_id"),
            leg.get("mode"),
            leg.get("route_id"),
            leg.get("trip_id"),
            endpoint(leg.get("from")),
            endpoint(leg.get("to")),
            leg.get("departure_time"),
            leg.get("arrival_time"),
        )
        return "|".join(str(part or "") for part in parts)

    return "||".join(leg_signature(leg) for leg in journey.get("legs") or [])
|
|
|
|
|
|
def _rank_journeys(
    journeys,
    ranking: str,
    *,
    transfer_penalty_seconds: float = 600,
    walking_speed_mps: float = 1.35,
) -> list[dict]:
    """Return shallow copies of *journeys* sorted according to *ranking*.

    Supported rankings (ties broken by the remaining criteria, then by walking
    distance in metres):

    * "duration": duration, arrival, transfers;
    * "fewest_transfers": transfers, arrival, duration;
    * "earliest_arrival": arrival, duration, transfers;
    * anything else ("recommended"): arrival plus ``transfer_penalty_seconds``
      per transfer plus walking distance converted to seconds at
      ``walking_speed_mps``, then arrival and transfers.

    A missing arrival or duration sorts last (treated as +inf). The keyword
    defaults reproduce the original hard-coded weights.
    """
    inf = float("inf")

    def as_float(value) -> float:
        return inf if value is None else float(value)

    def key(journey: dict[str, Any]) -> tuple[float, ...]:
        arrival = journey.get("arrival_seconds")
        duration = journey.get("duration_minutes")
        transfers = int(journey.get("transfers") or 0)
        walking = sum(float(leg.get("distance_m") or 0) for leg in journey.get("legs") or [] if leg.get("mode") == "walk")
        if ranking == "duration":
            return (as_float(duration), as_float(arrival), transfers, walking)
        if ranking == "fewest_transfers":
            return (transfers, as_float(arrival), as_float(duration), walking)
        if ranking == "earliest_arrival":
            return (as_float(arrival), as_float(duration), transfers, walking)
        score = inf if arrival is None else float(arrival) + transfers * transfer_penalty_seconds + walking / walking_speed_mps
        return (score, as_float(arrival), transfers, walking)

    return sorted((dict(journey) for journey in journeys), key=key)
|
|
|
|
|
|
def _select_diverse_journeys(journeys: list[dict], *, limit: int) -> list[dict]:
    """Pick up to *limit* journeys from the already-ranked *journeys*.

    Journeys are taken in ranked order, and exact duplicates (same
    ``_journey_key``) are skipped. Once three journeys are selected, a journey
    whose ``_journey_diversity_key`` matches an earlier pick is also skipped,
    so the list is not filled with near-identical departures.

    Finally ``_ensure_walk_only_option`` makes sure a walk-only journey is
    included whenever the ranked input contains one. It is appended if there
    is room; otherwise it replaces the last pick. Previously both early returns
    skipped this step. The only path that reached it had already taken every
    distinct journey, so the guarantee never took effect.
    """
    selected: list[dict] = []
    selected_exact: set[str] = set()
    selected_diversity: set[tuple[object, ...]] = set()
    for journey in journeys:
        exact_key = _journey_key(journey)
        if exact_key in selected_exact:
            continue
        diversity_key = _journey_diversity_key(journey)
        # Diversity filtering only applies once a minimum of three options exists.
        if diversity_key in selected_diversity and len(selected) >= 3:
            continue
        selected.append(journey)
        selected_exact.add(exact_key)
        selected_diversity.add(diversity_key)
        if len(selected) >= limit:
            break
    return _ensure_walk_only_option(selected, journeys, limit=limit)
|
|
|
|
|
|
def _ensure_walk_only_option(selected: list[dict], ranked: list[dict], *, limit: int) -> list[dict]:
    """Make sure *selected* contains a walk-only journey when *ranked* has one.

    If *selected* already has one, or *ranked* has none, *selected* is returned
    unchanged. Otherwise the best-ranked walk-only journey is appended (a new
    list is returned) when there is room below *limit*. If there is no room, it
    overwrites the last element of *selected* in place.
    """
    def walk_only(journey: dict) -> bool:
        legs = journey.get("legs") or []
        return bool(legs) and all(leg.get("mode") == "walk" for leg in legs)

    if any(map(walk_only, selected)):
        return selected
    fallback = next(filter(walk_only, ranked), None)
    if fallback is None:
        return selected
    if len(selected) < limit:
        return [*selected, fallback]
    if selected:
        selected[-1] = fallback
    return selected
|
|
|
|
|
|
def _journey_is_walk_only(journey: dict) -> bool:
    """Return True if *journey* has at least one leg and every leg is a walk."""
    legs = journey.get("legs") or []
    if not legs:
        return False
    return not any(leg.get("mode") != "walk" for leg in legs)
|
|
|
|
|
|
def _journey_diversity_key(journey: dict[str, Any]) -> tuple[object, ...]:
    """Return a coarse similarity key: (transfers, route sequence, departure band).

    The route sequence lists each non-walk leg's route_ref, else route_id, else
    mode. The band is the departure in 30-minute buckets, or None when the
    departure is unknown.
    """
    riding_legs = [leg for leg in journey.get("legs") or [] if leg.get("mode") != "walk"]
    routes = tuple(str(leg.get("route_ref") or leg.get("route_id") or leg.get("mode") or "") for leg in riding_legs)
    departs = journey.get("departure_seconds")
    band = int(departs) // 1800 if departs is not None else None
    return (int(journey.get("transfers") or 0), routes, band)
|
|
|
|
|
|
def _csv_ints(value: object) -> list[int] | None:
    """Parse a comma-separated list of integers.

    Blank items are ignored. Returns None for None or when nothing remains;
    a non-integer item raises ValueError.
    """
    if value is None:
        return None
    numbers = [int(token) for token in map(str.strip, str(value).split(",")) if token]
    return numbers or None
|
|
|
|
|
|
def _stop_payload(stop) -> dict[str, Any]:
    """Project a resolved location into a dict with its id, dataset, stop id, name and coordinates."""
    return {attr: getattr(stop, attr) for attr in ("id", "dataset_id", "stop_id", "name", "lat", "lon")}
|
|
|
|
|
|
def _prune_old_searches() -> None:
    """Drop stale searches from the registry.

    A finished search is dropped once its last update is more than 3 minutes
    old; an unfinished one after 15 minutes. Each dropped search also releases
    its in-flight deduplication slot. Must be called with ``_lock`` held.
    """
    now = time.time()
    expired = []
    for search_id, state in _searches.items():
        max_age = 3 * 60 if state.complete else 15 * 60
        if now - state.updated_at > max_age:
            expired.append(search_id)
    for search_id in expired:
        dropped = _searches.pop(search_id, None)
        if dropped is not None:
            _clear_inflight_search_locked(dropped)
|