Files
meubility-workbench/app/dataset_search.py
2026-07-01 23:29:51 +02:00

253 lines
10 KiB
Python

from __future__ import annotations
from sqlalchemy import func, or_, select
from sqlalchemy.orm import Session
from app.gtfs_storage import execute_sidecar_query, uses_sidecar_stop_times
from app.models import Dataset, GtfsRoute, GtfsShape, GtfsStopTime, GtfsTrip, OsmFeature, RoutePattern, Source
from app.osm_storage import osm_feature_public_id, query_osm_features
from app.pipeline.utils import norm_ref
def search_datasets(session: Session, query: str, *, active_only: bool = False, limit: int = 80) -> dict:
    """Search GTFS routes, OSM routes and route patterns for a free-text query.

    The query is stripped of surrounding whitespace; a blank (or ``None``)
    query short-circuits to an empty result without running any lookup.
    ``limit`` is clamped to the range 1..250 and handed to each category
    lookup separately. ``active_only`` is forwarded to the GTFS and OSM
    lookups only; the route-pattern lookup does not take it.

    Returns a dict with the keys ``query``, ``gtfs_routes``, ``osm_routes``,
    ``route_patterns`` and ``totals`` (number of hits per category).
    """
    needle = (query or "").strip()
    if not needle:
        return {"query": needle, "gtfs_routes": [], "osm_routes": [], "route_patterns": [], "totals": {}}

    row_cap = max(1, min(limit, 250))
    # Insertion order here is the key order of the returned payload.
    hits = {
        "gtfs_routes": _gtfs_route_hits(session, needle, active_only=active_only, limit=row_cap),
        "osm_routes": _osm_route_hits(session, needle, active_only=active_only, limit=row_cap),
        "route_patterns": _route_pattern_hits(session, needle, limit=row_cap),
    }
    totals = {category: len(found) for category, found in hits.items()}
    return {"query": needle, **hits, "totals": totals}
def _gtfs_route_hits(session: Session, query: str, *, active_only: bool, limit: int) -> list[dict]:
    """Return GTFS routes whose names, ids or normalised ref match ``query``.

    A route matches when ``query`` occurs (case-insensitively) in its short
    name, route id or long name, or when its ``route_key`` equals
    ``norm_ref(query)``. With ``active_only`` set, only routes of active
    datasets are considered. Hits are ordered active datasets first, then by
    source name, short name and route id, and capped at ``limit`` rows.

    Each hit carries source, dataset and route payloads, a geometry summary
    and per-route trip, stop-time and shape counts.
    """
    pattern = f"%{query}%"
    ref = norm_ref(query)

    conditions = [
        GtfsRoute.short_name.ilike(pattern),
        GtfsRoute.route_id.ilike(pattern),
        GtfsRoute.long_name.ilike(pattern),
    ]
    # Only compare route_key when the query normalises to a usable ref, the
    # same guard the OSM lookup applies. An empty or None ref would otherwise
    # become ``route_key = ''`` / ``route_key IS NULL`` and pull in every
    # route without a key, whatever the user typed.
    if ref:
        conditions.append(GtfsRoute.route_key == ref)

    stmt = (
        select(GtfsRoute, Dataset, Source)
        .join(Dataset, Dataset.id == GtfsRoute.dataset_id)
        .join(Source, Source.id == Dataset.source_id)
        .where(or_(*conditions))
    )
    if active_only:
        stmt = stmt.where(Dataset.is_active.is_(True))
    stmt = stmt.order_by(
        Dataset.is_active.desc(), Source.name, GtfsRoute.short_name, GtfsRoute.route_id
    ).limit(limit)

    rows = session.execute(stmt).all()

    # Timetable statistics are fetched in bulk for the whole page of hits.
    route_ids = [route.id for route, _, _ in rows]
    trip_counts = _trip_counts(session, route_ids)
    stop_time_counts = _stop_time_counts(session, route_ids)
    shape_counts = _shape_counts(session, route_ids)

    return [
        {
            "type": "gtfs_route",
            "source": _source_payload(source),
            "dataset": _dataset_payload(dataset),
            "route": {
                "id": route.id,
                "route_id": route.route_id,
                "ref": route.short_name,
                "name": route.long_name,
                "mode": route.mode,
                "operator": route.operator_name,
            },
            "geometry": _geometry_payload(route),
            "timetable": {
                "trips": trip_counts.get(route.id, 0),
                "stop_times": stop_time_counts.get(route.id, 0),
                "shapes": shape_counts.get(route.id, 0),
            },
        }
        for route, dataset, source in rows
    ]
def _osm_route_hits(session: Session, query: str, *, active_only: bool, limit: int) -> list[dict]:
    """Return OSM route features matching ``query`` across OSM GeoJSON datasets.

    Datasets of kind ``"osm_geojson"`` are loaded (only active ones when
    ``active_only`` is set). Route features are then collected from a
    free-text search and, when ``norm_ref(query)`` is non-empty, from a
    route-key lookup. Features found by both lookups are merged on
    ``(dataset_id, osm_type, osm_id)``.

    The merged features are sorted active datasets first, then by source
    name, ref, name and feature id, and truncated to ``limit``. Features
    whose dataset or source cannot be resolved are left out of the result.
    """
    ref = norm_ref(query)

    dataset_query = select(Dataset).where(Dataset.kind == "osm_geojson")
    if active_only:
        dataset_query = dataset_query.where(Dataset.is_active.is_(True))
    dataset_query = dataset_query.order_by(Dataset.is_active.desc(), Dataset.id)
    datasets = session.scalars(dataset_query).all()
    if not datasets:
        return []

    dataset_ids = [dataset.id for dataset in datasets]
    source_ids = [dataset.source_id for dataset in datasets]
    source_rows = session.scalars(select(Source).where(Source.id.in_(source_ids))).all()
    sources = {source.id: source for source in source_rows}
    dataset_by_id = {dataset.id: dataset for dataset in datasets}

    # Run the text search first and the route-key lookup second; a feature
    # returned by both keeps a single entry keyed by its OSM identity.
    lookups = [{"search": query}]
    if ref:
        lookups.append({"route_key": ref})
    merged: dict[tuple[int, str, str], OsmFeature] = {}
    for criteria in lookups:
        for feature in query_osm_features(session, dataset_ids, kinds=["route"], limit=limit, **criteria):
            merged[(feature.dataset_id, feature.osm_type, feature.osm_id)] = feature

    def rank(feature):
        """Sort key: active datasets first, then source name, ref, name, id."""
        dataset = dataset_by_id.get(feature.dataset_id)
        source = sources.get(dataset.source_id) if dataset else None
        return (
            0 if dataset and dataset.is_active else 1,
            source.name if source else "",
            feature.ref or "",
            feature.name or "",
            feature.id or 0,
        )

    ordered = sorted(merged.values(), key=rank)[:limit]

    hits: list[dict] = []
    for feature in ordered:
        dataset = dataset_by_id.get(feature.dataset_id)
        if dataset is None:
            continue
        source = sources.get(dataset.source_id)
        if source is None:
            continue
        hits.append(
            {
                "type": "osm_route",
                "source": _source_payload(source),
                "dataset": _dataset_payload(dataset),
                "osm": {
                    "id": osm_feature_public_id(feature),
                    "osm_type": feature.osm_type,
                    "osm_id": feature.osm_id,
                    "ref": feature.ref,
                    "name": feature.name,
                    "mode": feature.mode,
                    "route_scope": feature.route_scope,
                    "operator": feature.operator,
                    "network": feature.network,
                },
                "geometry": _geometry_payload(feature),
            }
        )
    return hits
def _route_pattern_hits(session: Session, query: str, *, limit: int) -> list[dict]:
    """Return route patterns matching ``query``.

    Candidates are rows whose route ref, route name or pattern key contains
    ``query`` (case-insensitive), ordered by source kind, route ref and id
    and capped at ``limit`` in SQL. When ``norm_ref(query)`` is non-empty the
    candidates are narrowed further in Python: a row is kept only if its
    normalised ref (falling back to its name) equals that ref, or if its
    route name contains the query case-insensitively.

    NOTE(review): the SQL limit is applied before the Python-side narrowing,
    so fewer than ``limit`` rows can come back even when more would qualify.
    """
    like = f"%{query}%"
    ref = norm_ref(query)

    text_match = or_(
        RoutePattern.route_ref.ilike(like),
        RoutePattern.route_name.ilike(like),
        RoutePattern.pattern_key.ilike(like),
    )
    stmt = select(RoutePattern).where(text_match)
    stmt = stmt.order_by(RoutePattern.source_kind, RoutePattern.route_ref, RoutePattern.id)
    stmt = stmt.limit(limit)
    candidates = session.scalars(stmt).all()

    lowered_query = query.lower()

    def is_relevant(row) -> bool:
        """Apply the ref/name narrowing; everything passes when ref is empty."""
        if not ref:
            return True
        if norm_ref(row.route_ref or row.route_name or "") == ref:
            return True
        return lowered_query in (row.route_name or "").lower()

    hits: list[dict] = []
    for row in candidates:
        if not is_relevant(row):
            continue
        hits.append(
            {
                "type": "route_pattern",
                "id": row.id,
                "ref": row.route_ref,
                "name": row.route_name,
                "mode": row.mode,
                "route_scope": row.route_scope,
                "source_kind": row.source_kind,
                "status": row.status,
                "confidence": row.confidence,
                "gtfs_route_id": row.gtfs_route_id,
                "osm_feature_id": row.osm_feature_id,
                "geometry": _geometry_payload(row),
            }
        )
    return hits
def _trip_counts(session: Session, route_row_ids: list[int]) -> dict[int, int]:
    """Count trips per GTFS route row.

    ``route_row_ids`` are ``GtfsRoute.id`` values. Trips are joined to routes
    on dataset id and route id. The result maps route row id to trip count;
    routes without any trip are absent from the mapping, and an empty input
    yields an empty mapping without querying.
    """
    counts: dict[int, int] = {}
    if route_row_ids:
        trip_of_route = (GtfsTrip.dataset_id == GtfsRoute.dataset_id) & (GtfsTrip.route_id == GtfsRoute.route_id)
        stmt = (
            select(GtfsRoute.id, func.count(GtfsTrip.id))
            .join(GtfsTrip, trip_of_route)
            .where(GtfsRoute.id.in_(route_row_ids))
            .group_by(GtfsRoute.id)
        )
        for route_pk, total in session.execute(stmt).all():
            counts[int(route_pk)] = int(total)
    return counts
def _stop_time_counts(session: Session, route_row_ids: list[int]) -> dict[int, int]:
    """Count stop times per GTFS route row.

    ``route_row_ids`` are ``GtfsRoute.id`` values. Routes are split by whether
    ``uses_sidecar_stop_times`` reports their dataset as sidecar-backed:

    * routes of other datasets are counted with one grouped query joining
      routes, trips and stop times in the main database;
    * sidecar-backed routes are counted one by one through
      ``execute_sidecar_query`` against the dataset's sidecar tables.

    Returns a mapping of route row id to stop-time count. Main-database
    routes without stop times are absent; sidecar routes always get an entry
    (``0`` when the query returns nothing). An empty input yields an empty
    mapping without querying.
    """
    if not route_row_ids:
        return {}

    routes = session.scalars(select(GtfsRoute).where(GtfsRoute.id.in_(route_row_ids))).all()

    # Many routes share a dataset, so resolve the sidecar flag once per
    # dataset rather than once per route.
    sidecar_by_dataset: dict[int, bool] = {}
    for route in routes:
        if route.dataset_id not in sidecar_by_dataset:
            sidecar_by_dataset[route.dataset_id] = bool(uses_sidecar_stop_times(session, route.dataset_id))

    sidecar_routes = [route for route in routes if sidecar_by_dataset[route.dataset_id]]
    main_route_ids = [route.id for route in routes if not sidecar_by_dataset[route.dataset_id]]

    counts: dict[int, int] = {}
    if main_route_ids:
        rows = session.execute(
            select(GtfsRoute.id, func.count(GtfsStopTime.id))
            .join(GtfsTrip, (GtfsTrip.dataset_id == GtfsRoute.dataset_id) & (GtfsTrip.route_id == GtfsRoute.route_id))
            .join(GtfsStopTime, (GtfsStopTime.dataset_id == GtfsTrip.dataset_id) & (GtfsStopTime.trip_id == GtfsTrip.trip_id))
            .where(GtfsRoute.id.in_(main_route_ids))
            .group_by(GtfsRoute.id)
        ).all()
        counts.update({int(route_id): int(count) for route_id, count in rows})

    for route in sidecar_routes:
        # The sidecar query is scoped by dataset through the call itself, so
        # the SQL only filters on the GTFS route_id.
        rows = execute_sidecar_query(
            session,
            route.dataset_id,
            """
            SELECT COUNT(*) AS count
            FROM gtfs_stop_times AS stop_times
            JOIN gtfs_trips AS trips
              ON trips.trip_id = stop_times.trip_id
            WHERE trips.route_id = ?
            """,
            [route.route_id],
        )
        counts[int(route.id)] = int(rows[0]["count"] or 0) if rows else 0
    return counts
def _shape_counts(session: Session, route_row_ids: list[int]) -> dict[int, int]:
    """Count distinct shape ids per GTFS route row.

    ``route_row_ids`` are ``GtfsRoute.id`` values. Routes are joined to their
    trips and the trips to shapes of the same dataset; the count is over
    distinct ``GtfsShape.shape_id`` values. Routes without a matching shape
    are absent from the mapping, and an empty input yields an empty mapping
    without querying.
    """
    counts: dict[int, int] = {}
    if route_row_ids:
        trip_of_route = (GtfsTrip.dataset_id == GtfsRoute.dataset_id) & (GtfsTrip.route_id == GtfsRoute.route_id)
        shape_of_trip = (GtfsShape.dataset_id == GtfsTrip.dataset_id) & (GtfsShape.shape_id == GtfsTrip.shape_id)
        distinct_shapes = func.count(func.distinct(GtfsShape.shape_id))
        stmt = (
            select(GtfsRoute.id, distinct_shapes)
            .join(GtfsTrip, trip_of_route)
            .join(GtfsShape, shape_of_trip)
            .where(GtfsRoute.id.in_(route_row_ids))
            .group_by(GtfsRoute.id)
        )
        for route_pk, total in session.execute(stmt).all():
            counts[int(route_pk)] = int(total)
    return counts
def _source_payload(source: Source) -> dict:
    """Serialise a source to a dict with its id, name, kind and country."""
    payload = {}
    for field in ("id", "name", "kind", "country"):
        payload[field] = getattr(source, field)
    return payload
def _dataset_payload(dataset: Dataset) -> dict:
    """Serialise a dataset to a plain dict.

    ``created_at`` is rendered with ``isoformat()`` when set and as ``None``
    otherwise; all other fields are copied through unchanged.
    """
    created = dataset.created_at
    created_iso = created.isoformat() if created else None
    return {
        "id": dataset.id,
        "kind": dataset.kind,
        "is_active": dataset.is_active,
        "status": dataset.status,
        "created_at": created_iso,
        "sha256": dataset.sha256,
    }
def _geometry_payload(row) -> dict:
    """Summarise the geometry information carried by ``row``.

    ``bbox`` is a dict of ``min_lon``/``min_lat``/``max_lon``/``max_lat``
    when all four attributes exist and are not ``None``; otherwise it is
    ``None``. ``present`` is the truthiness of ``row.geometry_geojson``,
    with a missing attribute counting as absent.
    """
    corner_names = ("min_lon", "min_lat", "max_lon", "max_lat")
    corners = {name: getattr(row, name, None) for name in corner_names}
    # Zero is a valid coordinate, so test against None rather than truthiness.
    complete = all(value is not None for value in corners.values())
    has_geometry = bool(getattr(row, "geometry_geojson", None))
    return {"present": has_geometry, "bbox": corners if complete else None}