Alpha stage commit
This commit is contained in:
717
app/journey_search.py
Normal file
717
app/journey_search.py
Normal file
@@ -0,0 +1,717 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import hashlib
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import date, datetime, timedelta, timezone
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.address_search import is_location_token
|
||||
from app.db import SessionLocal
|
||||
from app.journey import find_journeys, parse_service_date, resolve_location_summary
|
||||
from app.models import JourneySearchCache
|
||||
from app.routing import direct_route_between_points, route_between_points
|
||||
|
||||
|
||||
# Highest transfer count tried by the staged transit search (stages run 0..N).
MAX_PROGRESSIVE_TRANSFERS = 5
# Lifetime and size bound for cached single-stage find_journeys results.
TRANSIT_STAGE_CACHE_TTL_SECONDS = 5 * 60
TRANSIT_STAGE_CACHE_MAX_ENTRIES = 256
# Lifetime and size bound for cached payloads of completed searches.
PROGRESSIVE_SEARCH_CACHE_TTL_SECONDS = 10 * 60
PROGRESSIVE_SEARCH_CACHE_MAX_ENTRIES = 128
# Hashed into every durable cache key, so changing it stops old rows from matching.
JOURNEY_SEARCH_CACHE_VERSION = "journey-search-v7"
# Background workers that execute _run_search.
_executor = ThreadPoolExecutor(max_workers=2, thread_name_prefix="journey-search")
# Re-entrant lock taken around accesses to the module-level dictionaries below.
_lock = threading.RLock()
# Search id -> state of that search.
_searches: dict[str, "_SearchState"] = {}
# Progressive cache key -> id of the search currently computing that key.
_progressive_search_inflight: dict[tuple[object, ...], str] = {}
# Cache key -> (time.monotonic() expiry, JSON-compatible payload).
_transit_stage_cache: dict[tuple[object, ...], tuple[float, dict[str, Any]]] = {}
_progressive_search_cache: dict[tuple[object, ...], tuple[float, dict[str, Any]]] = {}
|
||||
|
||||
|
||||
@dataclass
class _SearchState:
    """Mutable record of one progressive search, stored in ``_searches`` by id."""

    # Identifier handed to the client for polling and cancelling.
    id: str
    # Copy of the request parameters the search was started with.
    request: dict[str, Any]
    # Progressive cache key of the request; also used for in-flight de-duplication.
    cache_key: tuple[object, ...] | None = None
    # Values written in this module: "queued", "running", "complete", "cancelled", "error".
    status: str = "queued"
    # Human-readable progress text.
    message: str = "Queued."
    # Progress marker such as "queued", "starting", "transfers_<n>", a route mode, or "error".
    stage: str = "queued"
    # Ranked journeys published so far (transit searches).
    journeys: list[dict] = field(default_factory=list)
    # Route payload published by walk/drive searches.
    routing: dict[str, Any] | None = None
    # Extra keys that _payload merges into the client payload.
    context: dict[str, Any] = field(default_factory=dict)
    # Error text when the search failed, otherwise None.
    error: str | None = None
    # Wall-clock timestamps from time.time(); updated_at drives pruning.
    created_at: float = field(default_factory=time.time)
    updated_at: float = field(default_factory=time.time)
    # True once the search has finished, failed, been cancelled or been served from cache.
    complete: bool = False
    # True once cancellation was requested; publishers other than _publish_error then no-op.
    cancelled: bool = False
|
||||
|
||||
|
||||
def start_journey_search(request: dict[str, Any]) -> dict[str, Any]:
    """Start a progressive search, or reuse an equivalent one, and return its payload.

    A cached result for the same request is registered as an already complete
    search under a fresh id. Otherwise, if an identical search is still in
    flight, that search's payload (and id) is returned instead of starting a
    duplicate. Only when neither applies is a worker task submitted.

    Args:
        request: Raw search parameters; a shallow copy is kept on the state.

    Returns:
        The client payload of the new, cached, or reused search.
    """
    cache_key = _progressive_cache_key(request)
    cached_payload = _progressive_cache_get(cache_key)
    needs_worker = cached_payload is None

    new_state = _SearchState(id=uuid.uuid4().hex, request=dict(request), cache_key=cache_key)
    if not needs_worker:
        _apply_cached_payload(new_state, cached_payload)

    with _lock:
        _prune_old_searches()
        if needs_worker:
            running_id = _progressive_search_inflight.get(cache_key)
            running = _searches.get(running_id) if running_id is not None else None
            if running is not None and not (running.complete or running.cancelled):
                # Same request is already being computed: share that search.
                return _payload(running)
            _progressive_search_inflight[cache_key] = new_state.id
        _searches[new_state.id] = new_state

    if needs_worker:
        _executor.submit(_run_search, new_state.id)
    return journey_search_payload(new_state.id)
|
||||
|
||||
|
||||
def journey_search_payload(search_id: str) -> dict[str, Any]:
    """Return the current client payload of a registered search.

    Raises:
        KeyError: if no search with ``search_id`` is registered.
    """
    with _lock:
        found = _searches.get(search_id)
        if found is not None:
            return _payload(found)
    raise KeyError(search_id)
|
||||
|
||||
|
||||
def cancel_journey_search(search_id: str) -> dict[str, Any]:
    """Flag a search as cancelled and return its payload.

    A search that has not completed yet is switched to the "cancelled" status
    and marked complete; an already completed search keeps its status and only
    gets the ``cancelled`` flag. The in-flight registration for the search's
    cache key is released in both cases.

    Raises:
        KeyError: if no search with ``search_id`` is registered.
    """
    with _lock:
        target = _searches.get(search_id)
        if target is None:
            raise KeyError(search_id)

        already_complete = target.complete
        target.cancelled = True
        if not already_complete:
            target.status, target.message = "cancelled", "Search cancelled."
            target.complete = True
        target.updated_at = time.time()
        _clear_inflight_search_locked(target)
        return _payload(target)
|
||||
|
||||
|
||||
def _run_search(search_id: str) -> None:
    """Worker entry point: mark the search as running and dispatch by mode.

    Modes "walk", "drive" and "car" (treated as "drive") go to the point-route
    search; every other mode runs the staged transit search. Any exception is
    reported to the client through ``_publish_error`` instead of propagating.
    """
    with _lock:
        current = _searches.get(search_id)
        if current is None or current.cancelled:
            return
        current.status = "running"
        current.stage = "starting"
        current.message = "Starting search..."
        current.updated_at = time.time()
        params = dict(current.request)

    try:
        requested_mode = str(params.get("mode") or "transit")
        if requested_mode not in {"walk", "drive", "car"}:
            _run_transit_search(search_id, params)
            return
        point_mode = "drive" if requested_mode == "car" else requested_mode
        _run_point_route_search(search_id, point_mode, params)
    except Exception as exc:  # noqa: BLE001 - report progressive-search failure to client
        _publish_error(search_id, str(exc))
|
||||
|
||||
|
||||
def _run_transit_search(search_id: str, request: dict[str, Any]) -> None:
    """Run the staged transit search for ``search_id`` and publish progress.

    Stages are tried with an increasing ``max_transfers`` (only 0 when the
    request sets ``direct_only``). After every stage the journeys found so far
    are merged, ranked, trimmed and published, so a polling client sees partial
    results. The loop stops early on cancellation, after two consecutive stages
    that added nothing new (once at least two transfers were tried), or when
    ``_major_hub_address_stage_is_complete`` accepts the stage. A search that
    ends with status "complete" and no error is stored in the progressive cache.

    Args:
        search_id: Id of the registered search to update.
        request: Copy of the request parameters.

    Raises:
        ValueError: if ``limit``, ``transfer_seconds`` or ``source_id`` hold
            text that is not an integer. Exceptions raised by the journey
            lookup are not caught here either.
    """
    direct_only = bool(request.get("direct_only"))
    # Result count is clamped to 3..10 and the transfer allowance to 0..3600.
    limit = max(3, min(int(request.get("limit") or 5), 10))
    transfer_seconds = max(0, min(int(request.get("transfer_seconds") or 120), 3600))
    source_ids = _csv_ints(request.get("source_id"))
    service_date = request.get("service_date") or None
    stages = [0] if direct_only else list(range(0, MAX_PROGRESSIVE_TRANSFERS + 1))
    address_search = is_location_token(request.get("from_stop_id")) or is_location_token(request.get("to_stop_id"))
    # Searches without a location-token endpoint ask every stage for at least 10 candidates.
    stage_limit = limit if address_search else max(limit, 10)
    # Journeys collected across stages, keyed by _journey_key; the first one seen wins.
    merged: dict[str, dict] = {}
    context: dict[str, Any] = {}
    diagnostics: dict[str, Any] = {"stages": []}
    best_count = 0
    # Consecutive stages that added no new journey while results already exist.
    stale_stages = 0
    for transfers in stages:
        if _is_cancelled(search_id):
            return
        label = "direct" if transfers == 0 else f"up to {transfers} transfer{'s' if transfers != 1 else ''}"
        _publish_status(search_id, "running", f"Searching {label}...", f"transfers_{transfers}")
        stage_started_at = time.monotonic()
        with SessionLocal() as db:
            result = _cached_find_journeys(
                db,
                from_stop_id=str(request.get("from_stop_id") or ""),
                to_stop_id=str(request.get("to_stop_id") or ""),
                via_stop_id=request.get("via_stop_id") or None,
                source_ids=source_ids,
                departure=str(request.get("departure") or "08:00"),
                service_date=service_date,
                max_transfers=transfers,
                transfer_seconds=transfer_seconds,
                limit=stage_limit,
            )
        # The cache marker is bookkeeping only; remove it before the result becomes context.
        cache_status = str(result.pop("_cache_status", "miss"))
        elapsed_ms = int((time.monotonic() - stage_started_at) * 1000)
        stage_diagnostics = {
            "transfers": transfers,
            "cache": cache_status,
            "elapsed_ms": elapsed_ms,
            "journeys": len(result.get("journeys") or []),
        }
        result_diagnostics = result.get("diagnostics")
        if isinstance(result_diagnostics, dict):
            stage_diagnostics["details"] = result_diagnostics
        diagnostics["stages"].append(stage_diagnostics)
        # Context is rebuilt from the latest stage; "diagnostics" is replaced by the per-stage history.
        context = _context_from_result(result)
        context["diagnostics"] = diagnostics
        before = len(merged)
        for journey in result.get("journeys") or []:
            merged.setdefault(_journey_key(journey), journey)
        ranked = _select_diverse_journeys(_rank_journeys(merged.values(), str(request.get("ranking") or "recommended")), limit=limit)
        _publish_results(
            search_id,
            journeys=ranked,
            context=context,
            status="running",
            stage=f"transfers_{transfers}",
            message=f"Found {len(ranked)} option{'s' if len(ranked) != 1 else ''}; still searching..." if not direct_only else "Direct search complete.",
        )
        if len(merged) <= before and ranked:
            stale_stages += 1
        else:
            stale_stages = 0
        best_count = max(best_count, len(ranked))
        # Give up once two stages in a row found nothing new and 2+ transfers were tried.
        if ranked and stale_stages >= 2 and transfers >= 2:
            break
        if _major_hub_address_stage_is_complete(result_diagnostics, ranked, transfers=transfers, limit=limit):
            break
    complete_message = (
        f"Search complete. Found {best_count} option{'s' if best_count != 1 else ''}."
        if best_count
        else "Search complete. No route found in the imported timetable."
    )
    _publish_complete(search_id, message=complete_message)
    payload = journey_search_payload(search_id)
    # _publish_complete is a no-op for cancelled searches, so only real completions are cached.
    if payload.get("status") == "complete" and not payload.get("error"):
        _progressive_cache_put(_progressive_cache_key(request), payload)
|
||||
|
||||
|
||||
def _major_hub_address_stage_is_complete(
    diagnostics: dict[str, Any] | None,
    ranked: list[dict],
    *,
    transfers: int,
    limit: int,
) -> bool:
    """Decide whether a stage already yields enough results to stop searching.

    True only when at least one transfer was allowed, the stage diagnostics
    carry a truthy ``address_access["major_hubs"]`` entry, and the ranked list
    holds at least ``min(3, limit)`` journeys.
    """
    if transfers < 1 or not ranked or not isinstance(diagnostics, dict):
        return False
    access = diagnostics.get("address_access")
    has_major_hubs = isinstance(access, dict) and bool(access.get("major_hubs"))
    return has_major_hubs and len(ranked) >= min(3, limit)
|
||||
|
||||
|
||||
def _run_point_route_search(search_id: str, mode: str, request: dict[str, Any]) -> None:
    """Compute a walk/drive route between the two requested locations.

    Both endpoints are resolved to locations with coordinates. If the routed
    search raises, a direct connector carrying the failure reason is used
    instead. The route is published, the search is completed, and a search
    that ended "complete" without error is stored in the progressive cache.

    Raises:
        ValueError: if either resolved location lacks coordinates.
    """
    _publish_status(search_id, "running", f"Searching {mode} route...", mode)
    with SessionLocal() as db:
        source_ids = _csv_ints(request.get("source_id"))
        origin = resolve_location_summary(db, str(request.get("from_stop_id") or ""), source_ids=source_ids)
        destination = resolve_location_summary(db, str(request.get("to_stop_id") or ""), source_ids=source_ids)
        if origin.lon is None or origin.lat is None:
            raise ValueError("Selected start has no coordinates.")
        if destination.lon is None or destination.lat is None:
            raise ValueError("Selected destination has no coordinates.")

        endpoints = {
            "from_lon": float(origin.lon),
            "from_lat": float(origin.lat),
            "to_lon": float(destination.lon),
            "to_lat": float(destination.lat),
        }
        title = mode.title()
        try:
            route = route_between_points(db, mode=mode, max_visited=300_000, **endpoints)
            message = f"{title} route found."
        except Exception as exc:  # noqa: BLE001 - point routing should still return an approximate connector
            route = direct_route_between_points(db, mode=mode, reason=str(exc), **endpoints)
            message = f"{title} route approximated."

    context = {
        "from": _stop_payload(origin),
        "to": _stop_payload(destination),
        "mode": mode,
    }
    _publish_routing(search_id, route, context=context, message=message)
    _publish_complete(search_id, message=f"{mode.title()} route complete.")

    final_payload = journey_search_payload(search_id)
    finished_cleanly = final_payload.get("status") == "complete" and not final_payload.get("error")
    if finished_cleanly:
        _progressive_cache_put(_progressive_cache_key(request), final_payload)
|
||||
|
||||
|
||||
def _cached_find_journeys(
    db,
    *,
    from_stop_id: str,
    to_stop_id: str,
    via_stop_id: object,
    source_ids: list[int] | None,
    departure: str,
    service_date: object,
    max_transfers: int,
    transfer_seconds: int,
    limit: int,
) -> dict[str, Any]:
    """Return one stage of ``find_journeys``, served from cache when possible.

    Lookup order is the in-process cache, then the durable cache (whose hit is
    copied into the in-process cache), then a fresh ``find_journeys`` call
    whose result is written to both caches. The returned dict is always a copy
    and carries ``"_cache_status"`` set to "memory", "persistent" or "miss".
    """
    cache_key = (
        from_stop_id,
        to_stop_id,
        str(via_stop_id or ""),
        tuple(sorted(int(source_id) for source_id in source_ids or [])),
        departure,
        str(service_date or ""),
        int(max_transfers),
        int(transfer_seconds),
        int(limit),
    )
    started = time.monotonic()

    def remember(payload: dict[str, Any]) -> None:
        # Memory entries expire relative to the moment this lookup started.
        with _lock:
            _transit_stage_cache[cache_key] = (started + TRANSIT_STAGE_CACHE_TTL_SECONDS, payload)
            _prune_timed_cache(_transit_stage_cache, TRANSIT_STAGE_CACHE_MAX_ENTRIES)

    with _lock:
        entry = _transit_stage_cache.get(cache_key)
        if entry is not None:
            deadline, stored = entry
            if deadline > started:
                return _with_cache_status(stored, "memory")
            # Expired entry: drop it and fall through to the slower sources.
            _transit_stage_cache.pop(cache_key, None)

    persisted = _durable_cache_get("transit_stage", cache_key)
    if persisted is not None:
        remember(json.loads(json.dumps(persisted)))
        return _with_cache_status(persisted, "persistent")

    fresh = find_journeys(
        db=db,
        from_stop_id=from_stop_id,
        to_stop_id=to_stop_id,
        via_stop_id=via_stop_id,
        source_ids=source_ids,
        departure=departure,
        service_date=service_date,
        max_transfers=max_transfers,
        transfer_seconds=transfer_seconds,
        limit=limit,
    )
    # Cache a JSON round-tripped copy so later mutation of `fresh` cannot leak in.
    snapshot = json.loads(json.dumps(fresh))
    remember(snapshot)
    _durable_cache_put("transit_stage", cache_key, snapshot, ttl_seconds=TRANSIT_STAGE_CACHE_TTL_SECONDS)
    return _with_cache_status(fresh, "miss")
|
||||
|
||||
|
||||
def _with_cache_status(payload: dict[str, Any], cache_status: str) -> dict[str, Any]:
    """Return a JSON round-tripped copy of ``payload`` tagged with ``_cache_status``."""
    duplicate = json.loads(json.dumps(payload))
    return {**duplicate, "_cache_status": cache_status}
|
||||
|
||||
|
||||
def _prune_timed_cache(cache: dict[tuple[object, ...], tuple[float, dict[str, Any]]], max_entries: int) -> None:
    """Shrink ``cache`` in place to ``max_entries`` items, evicting earliest expiries first."""
    excess = len(cache) - max_entries
    if excess <= 0:
        return

    def expiry_of(entry):
        # entry is (key, (expires_at, payload))
        return entry[1][0]

    by_expiry = sorted(cache.items(), key=expiry_of)
    for stale_key, _value in by_expiry[:excess]:
        cache.pop(stale_key, None)
|
||||
|
||||
|
||||
def _durable_cache_get(cache_type: str, key: tuple[object, ...]) -> dict[str, Any] | None:
    """Load a non-expired payload from the database-backed cache.

    Expired rows (or rows without an expiry) are deleted on the way. The lookup
    is best-effort: any failure is logged as a warning and reported as a miss,
    so cache problems never break a journey search.

    Args:
        cache_type: Namespace of the entry, e.g. "transit_stage" or "progressive".
        key: In-process cache key; hashed into the stored key.

    Returns:
        The decoded payload, or ``None`` on miss, expiry or failure.
    """
    import logging  # local import keeps this block self-contained

    storage_key = _durable_cache_key(cache_type, key)
    now = datetime.now(timezone.utc)
    try:
        with SessionLocal() as session:
            row = session.scalar(select(JourneySearchCache).where(JourneySearchCache.cache_key == storage_key))
            if row is None:
                return None
            expires_at = _as_utc(row.expires_at)
            if expires_at is None or expires_at <= now:
                session.delete(row)
                session.commit()
                return None
            return json.loads(row.payload_json)
    except Exception as exc:  # noqa: BLE001 - cache misses must not break journey search
        # Still a miss for the caller, but no longer invisible to operators.
        logging.getLogger(__name__).warning("Journey search cache read failed (%s): %s", cache_type, exc)
        return None
|
||||
|
||||
|
||||
def _durable_cache_put(cache_type: str, key: tuple[object, ...], payload: dict[str, Any], *, ttl_seconds: int) -> None:
    """Insert or refresh a payload in the database-backed cache.

    The write is best-effort: serialisation or database failures are logged as
    a warning and otherwise ignored.

    Args:
        cache_type: Namespace of the entry, e.g. "transit_stage" or "progressive".
        key: In-process cache key; hashed into the stored key.
        payload: JSON-serialisable payload to store.
        ttl_seconds: Lifetime of the row; values below 1 are raised to 1.
    """
    import logging  # local import keeps this block self-contained

    storage_key = _durable_cache_key(cache_type, key)
    now = datetime.now(timezone.utc)
    expires_at = now + timedelta(seconds=max(1, int(ttl_seconds)))
    try:
        # Serialise once; both the insert and the update path store the same text.
        payload_json = json.dumps(payload, separators=(",", ":"))
        with SessionLocal() as session:
            row = session.scalar(select(JourneySearchCache).where(JourneySearchCache.cache_key == storage_key))
            if row is None:
                row = JourneySearchCache(
                    cache_key=storage_key,
                    cache_type=cache_type,
                    payload_json=payload_json,
                    created_at=now,
                    updated_at=now,
                    expires_at=expires_at,
                )
                session.add(row)
            else:
                row.cache_type = cache_type
                row.payload_json = payload_json
                row.updated_at = now
                row.expires_at = expires_at
            session.commit()
    except Exception as exc:  # noqa: BLE001 - cache writes are best-effort
        logging.getLogger(__name__).warning("Journey search cache write failed (%s): %s", cache_type, exc)
        return
|
||||
|
||||
|
||||
def _durable_cache_key(cache_type: str, key: tuple[object, ...]) -> str:
    """Derive the stored cache key: a SHA-256 hex digest of version, type and key."""
    envelope = {
        "version": JOURNEY_SEARCH_CACHE_VERSION,
        "cache_type": cache_type,
        "key": _json_safe(key),
    }
    # Sorted keys and compact separators make the hashed text deterministic.
    canonical = json.dumps(envelope, sort_keys=True, separators=(",", ":"))
    digest = hashlib.sha256()
    digest.update(canonical.encode("utf-8"))
    return digest.hexdigest()
|
||||
|
||||
|
||||
def _json_safe(value: object) -> object:
    """Convert ``value`` recursively into something ``json.dumps`` accepts.

    Tuples and lists become lists, dicts get string keys (items ordered by the
    string form of the key), dates and datetimes become ISO strings, JSON
    scalars pass through, and anything else is replaced by ``str(value)``.
    """
    if value is None or isinstance(value, (str, int, float, bool)):
        return value
    if isinstance(value, (list, tuple)):
        return [_json_safe(element) for element in value]
    if isinstance(value, dict):
        ordered = sorted(value.items(), key=lambda pair: str(pair[0]))
        return {str(name): _json_safe(inner) for name, inner in ordered}
    if isinstance(value, (date, datetime)):
        return value.isoformat()
    return str(value)
|
||||
|
||||
|
||||
def _as_utc(value: datetime | None) -> datetime | None:
    """Return ``value`` as an aware UTC datetime; naive values are taken to be UTC."""
    if value is None:
        return None
    if value.tzinfo is not None:
        return value.astimezone(timezone.utc)
    return value.replace(tzinfo=timezone.utc)
|
||||
|
||||
|
||||
def _progressive_cache_key(request: dict[str, Any]) -> tuple[object, ...]:
    """Build the cache / in-flight de-duplication key for a search request.

    Every parameter is normalised the way the search itself reads it (defaults
    applied, ``limit`` clamped to 3..10, ``transfer_seconds`` clamped to
    0..3600), so requests that run the same search share a key.

    Raises:
        ValueError: if ``source_id``, ``transfer_seconds`` or ``limit`` hold
            text that is not an integer.
    """
    source_ids = _csv_ints(request.get("source_id"))
    # Same clamp as _run_transit_search; without it out-of-range values that
    # run an identical search would be cached under different keys.
    transfer_seconds = max(0, min(int(request.get("transfer_seconds") or 120), 3600))
    limit = max(3, min(int(request.get("limit") or 5), 10))
    return (
        str(request.get("mode") or "transit"),
        str(request.get("from_stop_id") or ""),
        str(request.get("to_stop_id") or ""),
        str(request.get("via_stop_id") or ""),
        tuple(sorted(int(source_id) for source_id in source_ids or [])),
        str(request.get("departure") or "08:00"),
        str(request.get("service_date") or ""),
        bool(request.get("direct_only")),
        str(request.get("ranking") or "recommended"),
        transfer_seconds,
        limit,
    )
|
||||
|
||||
|
||||
def _progressive_cache_get(key: tuple[object, ...]) -> dict[str, Any] | None:
    """Look up a completed search payload, first in memory, then in the durable cache.

    Returns:
        A private copy of the payload with ``"cache_status"`` set to "memory"
        or "persistent", or ``None`` when neither cache holds a live entry.
    """
    checked_at = time.monotonic()
    with _lock:
        entry = _progressive_search_cache.get(key)
        if entry is not None:
            deadline, stored = entry
            if deadline > checked_at:
                hit = json.loads(json.dumps(stored))
                hit["cache_status"] = "memory"
                return hit
            # Expired in memory; the durable cache may still have it.
            _progressive_search_cache.pop(key, None)

    persisted = _durable_cache_get("progressive", key)
    if persisted is None:
        return None

    serialized = json.dumps(persisted)
    with _lock:
        _progressive_search_cache[key] = (
            checked_at + PROGRESSIVE_SEARCH_CACHE_TTL_SECONDS,
            json.loads(serialized),
        )
        _prune_timed_cache(_progressive_search_cache, PROGRESSIVE_SEARCH_CACHE_MAX_ENTRIES)

    result = json.loads(serialized)
    result["cache_status"] = "persistent"
    return result
|
||||
|
||||
|
||||
def _progressive_cache_put(key: tuple[object, ...], payload: dict[str, Any]) -> None:
    """Store a completed search payload in the in-process and durable caches.

    A JSON round-tripped copy is stored, without any ``"cache_status"`` marker
    the payload may carry.
    """
    snapshot = json.loads(json.dumps(payload))
    snapshot.pop("cache_status", None)

    with _lock:
        deadline = time.monotonic() + PROGRESSIVE_SEARCH_CACHE_TTL_SECONDS
        _progressive_search_cache[key] = (deadline, snapshot)
        _prune_timed_cache(_progressive_search_cache, PROGRESSIVE_SEARCH_CACHE_MAX_ENTRIES)

    _durable_cache_put("progressive", key, snapshot, ttl_seconds=PROGRESSIVE_SEARCH_CACHE_TTL_SECONDS)
|
||||
|
||||
|
||||
def _apply_cached_payload(state: _SearchState, payload: dict[str, Any]) -> None:
    """Fill ``state`` from a cached payload and mark it complete.

    Journeys and routing are deep-copied through JSON. Every payload key that
    is not one of the fixed payload fields is carried over as context.
    """
    fixed_fields = {
        "search_id",
        "status",
        "stage",
        "message",
        "complete",
        "error",
        "journeys",
        "routing",
        "created_at",
        "updated_at",
    }
    cached_routing = payload.get("routing")

    state.status = str(payload.get("status") or "complete")
    state.message = "Cached result."
    state.stage = str(payload.get("stage") or "cached")
    state.journeys = json.loads(json.dumps(payload.get("journeys") or []))
    state.routing = None if cached_routing is None else json.loads(json.dumps(cached_routing))
    state.context = {name: payload[name] for name in payload if name not in fixed_fields}
    state.error = None
    state.complete = True
    state.updated_at = time.time()
|
||||
|
||||
|
||||
def _publish_status(search_id: str, status: str, message: str, stage: str) -> None:
    """Update status, message and stage of a search; ignored once it is gone or cancelled."""
    with _lock:
        target = _searches.get(search_id)
        if target is not None and not target.cancelled:
            target.status = status
            target.message = message
            target.stage = stage
            target.updated_at = time.time()
|
||||
|
||||
|
||||
def _publish_results(search_id: str, *, journeys: list[dict], context: dict[str, Any], status: str, stage: str, message: str) -> None:
    """Publish a new set of journeys and context for a search.

    Shallow copies of ``journeys`` and ``context`` are stored. The call is
    ignored when the search is unknown or cancelled.
    """
    with _lock:
        target = _searches.get(search_id)
        if target is not None and not target.cancelled:
            target.status, target.stage, target.message = status, stage, message
            target.journeys = list(journeys)
            target.context = dict(context)
            target.updated_at = time.time()
|
||||
|
||||
|
||||
def _publish_routing(search_id: str, routing: dict[str, Any], *, context: dict[str, Any], message: str) -> None:
    """Publish a computed route for a search and keep it in the "running" status.

    The stage becomes the route's ``mode`` (or "route" when that is missing).
    ``routing`` is stored as passed in; ``context`` is shallow-copied. The call
    is ignored when the search is unknown or cancelled.
    """
    with _lock:
        target = _searches.get(search_id)
        if target is not None and not target.cancelled:
            target.status = "running"
            target.stage = str(routing.get("mode") or "route")
            target.message = message
            target.routing = routing
            target.context = dict(context)
            target.updated_at = time.time()
|
||||
|
||||
|
||||
def _publish_complete(search_id: str, *, message: str) -> None:
    """Mark a search as successfully complete and release its in-flight slot.

    Does nothing when the search is unknown or cancelled, so a cancelled
    search keeps its "cancelled" status.
    """
    with _lock:
        target = _searches.get(search_id)
        if target is not None and not target.cancelled:
            target.status = "complete"
            target.message = message
            target.complete = True
            target.updated_at = time.time()
            _clear_inflight_search_locked(target)
|
||||
|
||||
|
||||
def _publish_error(search_id: str, message: str) -> None:
    """Mark a search as failed, record the error text and release its in-flight slot.

    Like the other publishers, this leaves a cancelled search untouched: a
    worker exception that arrives after the client cancelled must not turn the
    "cancelled" status into "error". Unknown search ids are ignored.
    """
    with _lock:
        state = _searches.get(search_id)
        if state is None:
            return
        if not state.cancelled:
            state.status = "error"
            state.stage = "error"
            state.message = message
            state.error = message
            state.complete = True
            state.updated_at = time.time()
        # Clearing is idempotent and only removes an entry that points at this search.
        _clear_inflight_search_locked(state)
|
||||
|
||||
|
||||
def _clear_inflight_search_locked(state: _SearchState) -> None:
    """Drop the in-flight entry for ``state``'s cache key if it points at this search.

    NOTE(review): the ``_locked`` suffix and every call site in this module
    suggest the caller is expected to hold ``_lock`` - confirm for new callers.
    """
    cache_key = state.cache_key
    if cache_key is None:
        return
    if _progressive_search_inflight.get(cache_key) == state.id:
        _progressive_search_inflight.pop(cache_key, None)
|
||||
|
||||
|
||||
def _is_cancelled(search_id: str) -> bool:
    """Return True when the search was cancelled or is no longer registered."""
    with _lock:
        current = _searches.get(search_id)
        if current is None:
            return True
        return current.cancelled
|
||||
|
||||
|
||||
def _payload(state: _SearchState) -> dict[str, Any]:
    """Build the client-facing dict for a search state.

    Journeys, routing and context are deep-copied through JSON. Context keys
    are merged last, so a context key with the same name as a fixed field
    replaces that field's value.
    """
    routing_copy = None if state.routing is None else json.loads(json.dumps(state.routing))
    body: dict[str, Any] = {
        "search_id": state.id,
        "status": state.status,
        "stage": state.stage,
        "message": state.message,
        "complete": state.complete,
        "error": state.error,
        "journeys": json.loads(json.dumps(state.journeys)),
        "routing": routing_copy,
        "created_at": state.created_at,
        "updated_at": state.updated_at,
    }
    body.update(json.loads(json.dumps(state.context)))
    return body
|
||||
|
||||
|
||||
def _context_from_result(result: dict[str, Any]) -> dict[str, Any]:
    """Return every entry of ``result`` except "journeys" and "error" as a new dict."""
    excluded = ("journeys", "error")
    return {name: result[name] for name in result if name not in excluded}
|
||||
|
||||
|
||||
def _journey_key(journey: dict[str, Any]) -> str:
    """Build a string identifying a journey by the details of each of its legs.

    Each leg contributes dataset, mode, route, trip, both endpoints (stop id,
    falling back to the name) and both times, joined with "|". Legs are joined
    with "||". Missing or falsy parts become empty strings.
    """

    def endpoint(leg: dict[str, Any], side: str) -> object:
        place = leg.get(side) or {}
        return place.get("stop_id") or place.get("name")

    leg_keys = []
    for leg in journey.get("legs") or []:
        fields = (
            leg.get("dataset_id"),
            leg.get("mode"),
            leg.get("route_id"),
            leg.get("trip_id"),
            endpoint(leg, "from"),
            endpoint(leg, "to"),
            leg.get("departure_time"),
            leg.get("arrival_time"),
        )
        leg_keys.append("|".join(str(item or "") for item in fields))
    return "||".join(leg_keys)
|
||||
|
||||
|
||||
def _rank_journeys(journeys, ranking: str) -> list[dict]:
    """Return shallow copies of ``journeys`` sorted according to ``ranking``.

    Supported rankings are "duration", "fewest_transfers" and
    "earliest_arrival"; any other value uses the recommended order, which
    sorts by arrival time plus a penalty of 600 per transfer and the walking
    distance divided by 1.35. Journeys without an arrival (or duration) sort
    last on that criterion.

    Args:
        journeys: Iterable of journey dicts.
        ranking: Name of the ordering to apply.

    Returns:
        A new list of copied journey dicts in ranked order.
    """

    def or_infinity(value: object) -> float:
        # Missing values must sort after every real value.
        return float("inf") if value is None else float(value)

    def sort_key(journey: dict[str, Any]) -> tuple[float, float, float, float]:
        arrival = journey.get("arrival_seconds")
        duration = journey.get("duration_minutes")
        transfers = int(journey.get("transfers") or 0)
        walking = sum(
            float(leg.get("distance_m") or 0)
            for leg in journey.get("legs") or []
            if leg.get("mode") == "walk"
        )
        if ranking == "duration":
            return (or_infinity(duration), or_infinity(arrival), transfers, walking)
        if ranking == "fewest_transfers":
            return (transfers, or_infinity(arrival), or_infinity(duration), walking)
        if ranking == "earliest_arrival":
            return (or_infinity(arrival), or_infinity(duration), transfers, walking)
        arrival_value = or_infinity(arrival)
        # distance_m / 1.35 turns walking distance into seconds at 1.35 m/s.
        weighted = arrival_value if arrival is None else arrival_value + transfers * 600 + walking / 1.35
        return (weighted, arrival_value, transfers, walking)

    return sorted((dict(journey) for journey in journeys), key=sort_key)
|
||||
|
||||
|
||||
def _select_diverse_journeys(journeys: list[dict], *, limit: int) -> list[dict]:
    """Pick up to ``limit`` journeys from a ranked list, favouring variety.

    Exact duplicates (same ``_journey_key``) are always skipped. Once three
    journeys are selected, further journeys are skipped when another journey
    with the same diversity key (transfers, route signature, departure band)
    was already taken.

    Args:
        journeys: Journeys in ranked order, best first.
        limit: Maximum number of journeys to return.

    Returns:
        The selected journeys in ranked order.
    """
    selected: list[dict] = []
    seen_exact: set[str] = set()
    seen_diversity: set[tuple[object, ...]] = set()
    for journey in journeys:
        exact_key = _journey_key(journey)
        if exact_key in seen_exact:
            continue
        diversity_key = _journey_diversity_key(journey)
        if len(selected) >= 3 and diversity_key in seen_diversity:
            continue
        selected.append(journey)
        seen_exact.add(exact_key)
        seen_diversity.add(diversity_key)
        if len(selected) >= limit:
            return selected
    if len(selected) >= min(3, limit):
        return selected
    # Fewer than min(3, limit) journeys were selected, so the diversity filter
    # (active only from three selections on) never rejected anything: every
    # distinct journey is already in `selected`. The former second "fill up"
    # pass could therefore never add a journey and has been removed.
    # NOTE(review): the walk-only guarantee is only applied on this path, not
    # on the early returns above - confirm that is intended.
    return _ensure_walk_only_option(selected, journeys, limit=limit)
|
||||
|
||||
|
||||
def _ensure_walk_only_option(selected: list[dict], ranked: list[dict], *, limit: int) -> list[dict]:
    """Make sure the selection offers a walk-only journey when one exists in ``ranked``.

    If ``selected`` lacks one, the first walk-only journey from ``ranked`` is
    appended (returning a new list) while there is room below ``limit``;
    otherwise it overwrites the last selected entry in place.
    """
    for journey in selected:
        if _journey_is_walk_only(journey):
            return selected

    for candidate in ranked:
        if _journey_is_walk_only(candidate):
            break
    else:
        # No walk-only journey is available at all.
        return selected

    if len(selected) < limit:
        return [*selected, candidate]
    if selected:
        selected[-1] = candidate
    return selected
|
||||
|
||||
|
||||
def _journey_is_walk_only(journey: dict) -> bool:
    """Return True when the journey has at least one leg and every leg is a walk."""
    legs = journey.get("legs")
    if not legs:
        return False
    return not any(leg.get("mode") != "walk" for leg in legs)
|
||||
|
||||
|
||||
def _journey_diversity_key(journey: dict[str, Any]) -> tuple[object, ...]:
    """Summarise a journey as (transfers, non-walk route labels, 30-minute departure band).

    Each non-walk leg is labelled by its ``route_ref``, else ``route_id``, else
    ``mode``. The band is ``None`` when the journey has no departure time.
    """
    labels = []
    for leg in journey.get("legs") or []:
        if leg.get("mode") == "walk":
            continue
        labels.append(str(leg.get("route_ref") or leg.get("route_id") or leg.get("mode") or ""))

    departure = journey.get("departure_seconds")
    half_hour_band = int(departure) // (30 * 60) if departure is not None else None
    transfer_count = int(journey.get("transfers") or 0)
    return (transfer_count, tuple(labels), half_hour_band)
|
||||
|
||||
|
||||
def _csv_ints(value: object) -> list[int] | None:
    """Parse a comma-separated value into a list of ints.

    Returns:
        The parsed integers, or ``None`` when ``value`` is ``None`` or holds
        no non-blank items.

    Raises:
        ValueError: if any non-blank item is not an integer.
    """
    if value is None:
        return None
    stripped = (piece.strip() for piece in str(value).split(","))
    tokens = [token for token in stripped if token]
    return [int(token) for token in tokens] if tokens else None
|
||||
|
||||
|
||||
def _stop_payload(stop) -> dict[str, Any]:
    """Copy the identifying and coordinate attributes of ``stop`` into a dict."""
    attribute_names = ("id", "dataset_id", "stop_id", "name", "lat", "lon")
    return {attribute: getattr(stop, attribute) for attribute in attribute_names}
|
||||
|
||||
|
||||
def _prune_old_searches() -> None:
    """Forget searches that have been idle too long.

    A search is dropped after 15 minutes without an update, or after 3 minutes
    once it is complete. In-flight registrations of dropped searches are
    released as well.

    NOTE(review): the only caller in this module holds ``_lock`` - confirm for
    new callers.
    """
    now = time.time()
    expired_ids = []
    for search_id, state in _searches.items():
        idle_seconds = now - state.updated_at
        if idle_seconds > 15 * 60 or (state.complete and idle_seconds > 3 * 60):
            expired_ids.append(search_id)

    # Removal happens in a second pass so the dict is not mutated while iterating.
    for search_id in expired_ids:
        removed = _searches.pop(search_id, None)
        if removed is not None:
            _clear_inflight_search_locked(removed)
|
||||
Reference in New Issue
Block a user