"""Free-text search across GTFS routes, OSM route features and route patterns."""

from __future__ import annotations

from sqlalchemy import func, or_, select
from sqlalchemy.orm import Session

from app.gtfs_storage import execute_sidecar_query, uses_sidecar_stop_times
from app.models import Dataset, GtfsRoute, GtfsShape, GtfsStopTime, GtfsTrip, OsmFeature, RoutePattern, Source
from app.osm_storage import osm_feature_public_id, query_osm_features
from app.pipeline.utils import norm_ref

# Hard upper bound on rows returned per hit category, whatever ``limit`` the caller asks for.
_MAX_LIMIT = 250

# Escape character for LIKE patterns built from user input. "/" is used instead of
# a backslash to sidestep dialect-specific backslash handling in SQL string literals.
_LIKE_ESCAPE = "/"

# Route-pattern rows are narrowed by a Python post-filter after the SQL query; fetch
# this many times ``limit`` rows so the filter can still fill the requested limit.
_ROUTE_PATTERN_OVERFETCH = 4

# Hit categories, in the order they appear in the search response.
_HIT_KEYS = ("gtfs_routes", "osm_routes", "route_patterns")


def search_datasets(session: Session, query: str, *, active_only: bool = False, limit: int = 80) -> dict:
    """Search GTFS routes, OSM route features and route patterns for ``query``.

    Args:
        session: Database session used for all lookups.
        query: Free-text search string; surrounding whitespace is stripped.
        active_only: When true, GTFS and OSM hits are restricted to active
            datasets. Route patterns are not filtered by this flag.
        limit: Maximum hits per category, clamped to ``[1, _MAX_LIMIT]``.

    Returns:
        ``{"query", "gtfs_routes", "osm_routes", "route_patterns", "totals"}``
        where ``totals`` maps each hit category to the number of hits returned.
        An empty query yields empty hit lists and zero totals.
    """
    q = (query or "").strip()
    if not q:
        hits: dict[str, list[dict]] = {key: [] for key in _HIT_KEYS}
    else:
        max_rows = max(1, min(limit, _MAX_LIMIT))
        hits = {
            "gtfs_routes": _gtfs_route_hits(session, q, active_only=active_only, limit=max_rows),
            "osm_routes": _osm_route_hits(session, q, active_only=active_only, limit=max_rows),
            "route_patterns": _route_pattern_hits(session, q, limit=max_rows),
        }
    return {
        "query": q,
        **hits,
        # Same shape for empty and non-empty queries so callers can index it uniformly.
        "totals": {key: len(rows) for key, rows in hits.items()},
    }


def _contains_pattern(text: str) -> str:
    """Return a LIKE pattern that matches ``text`` as a literal substring.

    ``%`` and ``_`` in user input would otherwise act as wildcards (a query of
    ``"_"`` would match every row). Must be used with ``escape=_LIKE_ESCAPE``.
    """
    escaped = (
        text.replace(_LIKE_ESCAPE, _LIKE_ESCAPE * 2)
        .replace("%", _LIKE_ESCAPE + "%")
        .replace("_", _LIKE_ESCAPE + "_")
    )
    return f"%{escaped}%"


def _gtfs_route_hits(session: Session, query: str, *, active_only: bool, limit: int) -> list[dict]:
    """Find GTFS routes whose short name, route_id or long name contain ``query``,
    or whose ``route_key`` equals ``norm_ref(query)``.

    Results are ordered active datasets first, then by source name, short name
    and route_id, and carry trip / stop-time / shape counts per route.
    """
    pattern = _contains_pattern(query)
    ref = norm_ref(query)
    conditions = [
        GtfsRoute.short_name.ilike(pattern, escape=_LIKE_ESCAPE),
        GtfsRoute.route_id.ilike(pattern, escape=_LIKE_ESCAPE),
        GtfsRoute.long_name.ilike(pattern, escape=_LIKE_ESCAPE),
    ]
    if ref:
        # Without this guard an empty ref would compare ``route_key = ''`` (or
        # ``IS NULL`` for None) and match unrelated routes. The OSM search
        # applies the same guard.
        conditions.append(GtfsRoute.route_key == ref)

    stmt = (
        select(GtfsRoute, Dataset, Source)
        .join(Dataset, Dataset.id == GtfsRoute.dataset_id)
        .join(Source, Source.id == Dataset.source_id)
        .where(or_(*conditions))
    )
    if active_only:
        stmt = stmt.where(Dataset.is_active.is_(True))
    stmt = stmt.order_by(Dataset.is_active.desc(), Source.name, GtfsRoute.short_name, GtfsRoute.route_id).limit(limit)

    rows = session.execute(stmt).all()
    route_ids = [route.id for route, _dataset, _source in rows]
    trip_counts = _trip_counts(session, route_ids)
    stop_time_counts = _stop_time_counts(session, route_ids)
    shape_counts = _shape_counts(session, route_ids)
    return [
        {
            "type": "gtfs_route",
            "source": _source_payload(source),
            "dataset": _dataset_payload(dataset),
            "route": {
                "id": route.id,
                "route_id": route.route_id,
                "ref": route.short_name,
                "name": route.long_name,
                "mode": route.mode,
                "operator": route.operator_name,
            },
            "geometry": _geometry_payload(route),
            "timetable": {
                "trips": trip_counts.get(route.id, 0),
                "stop_times": stop_time_counts.get(route.id, 0),
                "shapes": shape_counts.get(route.id, 0),
            },
        }
        for route, dataset, source in rows
    ]


def _osm_route_hits(session: Session, query: str, *, active_only: bool, limit: int) -> list[dict]:
    """Find OSM route features in ``osm_geojson`` datasets matching ``query``.

    Two lookups are merged: a free-text ``search=query`` and, when the query
    normalizes to a non-empty ref, a ``route_key=ref`` lookup. Features are
    de-duplicated by ``(dataset_id, osm_type, osm_id)`` and ordered active
    datasets first, then by source name, ref, name and id.
    """
    ref = norm_ref(query)
    dataset_stmt = select(Dataset).where(Dataset.kind == "osm_geojson")
    if active_only:
        dataset_stmt = dataset_stmt.where(Dataset.is_active.is_(True))
    datasets = session.scalars(dataset_stmt.order_by(Dataset.is_active.desc(), Dataset.id)).all()
    if not datasets:
        return []

    dataset_ids = [dataset.id for dataset in datasets]
    dataset_by_id = {dataset.id: dataset for dataset in datasets}
    source_ids = {dataset.source_id for dataset in datasets}
    sources = {source.id: source for source in session.scalars(select(Source).where(Source.id.in_(source_ids))).all()}

    features_by_identity: dict[tuple[int, str, str], OsmFeature] = {}
    for feature in query_osm_features(session, dataset_ids, kinds=["route"], search=query, limit=limit):
        features_by_identity[(feature.dataset_id, feature.osm_type, feature.osm_id)] = feature
    if ref:
        for feature in query_osm_features(session, dataset_ids, kinds=["route"], route_key=ref, limit=limit):
            features_by_identity[(feature.dataset_id, feature.osm_type, feature.osm_id)] = feature

    # Resolve dataset/source BEFORE truncating so unresolvable features do not
    # consume slots of the requested limit.
    resolved = [
        (feature, dataset, source)
        for feature in features_by_identity.values()
        if (dataset := dataset_by_id.get(feature.dataset_id)) is not None
        if (source := sources.get(dataset.source_id)) is not None
    ]
    resolved.sort(
        key=lambda item: (
            0 if item[1].is_active else 1,
            item[2].name or "",
            item[0].ref or "",
            item[0].name or "",
            item[0].id or 0,
        )
    )

    return [
        {
            "type": "osm_route",
            "source": _source_payload(source),
            "dataset": _dataset_payload(dataset),
            "osm": {
                "id": osm_feature_public_id(feature),
                "osm_type": feature.osm_type,
                "osm_id": feature.osm_id,
                "ref": feature.ref,
                "name": feature.name,
                "mode": feature.mode,
                "route_scope": feature.route_scope,
                "operator": feature.operator,
                "network": feature.network,
            },
            "geometry": _geometry_payload(feature),
        }
        for feature, dataset, source in resolved[:limit]
    ]


def _route_pattern_hits(session: Session, query: str, *, limit: int) -> list[dict]:
    """Find route patterns whose ref, name or pattern key contain ``query``.

    When ``query`` normalizes to a non-empty ref, the SQL substring matches are
    narrowed in Python: a row is kept only if its normalized ref (falling back
    to its name) equals that ref, or its name contains ``query``
    case-insensitively. The SQL query overfetches so this filter can still
    return up to ``limit`` hits.
    """
    pattern = _contains_pattern(query)
    ref = norm_ref(query)
    needle = query.lower()
    stmt = (
        select(RoutePattern)
        .where(
            or_(
                RoutePattern.route_ref.ilike(pattern, escape=_LIKE_ESCAPE),
                RoutePattern.route_name.ilike(pattern, escape=_LIKE_ESCAPE),
                RoutePattern.pattern_key.ilike(pattern, escape=_LIKE_ESCAPE),
            )
        )
        .order_by(RoutePattern.source_kind, RoutePattern.route_ref, RoutePattern.id)
        .limit(limit * _ROUTE_PATTERN_OVERFETCH)
    )
    rows = session.scalars(stmt).all()

    def keep(row: RoutePattern) -> bool:
        if not ref:
            return True
        return norm_ref(row.route_ref or row.route_name or "") == ref or needle in (row.route_name or "").lower()

    matches = [row for row in rows if keep(row)][:limit]
    return [
        {
            "type": "route_pattern",
            "id": pattern_row.id,
            "ref": pattern_row.route_ref,
            "name": pattern_row.route_name,
            "mode": pattern_row.mode,
            "route_scope": pattern_row.route_scope,
            "source_kind": pattern_row.source_kind,
            "status": pattern_row.status,
            "confidence": pattern_row.confidence,
            "gtfs_route_id": pattern_row.gtfs_route_id,
            "osm_feature_id": pattern_row.osm_feature_id,
            "geometry": _geometry_payload(pattern_row),
        }
        for pattern_row in matches
    ]


def _trip_counts(session: Session, route_row_ids: list[int]) -> dict[int, int]:
    """Map GtfsRoute row id -> number of trips; routes with no trips are omitted."""
    if not route_row_ids:
        return {}
    rows = session.execute(
        select(GtfsRoute.id, func.count(GtfsTrip.id))
        .join(GtfsTrip, (GtfsTrip.dataset_id == GtfsRoute.dataset_id) & (GtfsTrip.route_id == GtfsRoute.route_id))
        .where(GtfsRoute.id.in_(route_row_ids))
        .group_by(GtfsRoute.id)
    ).all()
    return {int(route_id): int(count) for route_id, count in rows}


def _stop_time_counts(session: Session, route_row_ids: list[int]) -> dict[int, int]:
    """Map GtfsRoute row id -> number of stop_times across the route's trips.

    Routes in datasets for which ``uses_sidecar_stop_times`` is true are counted
    with one ``execute_sidecar_query`` per route (always present, possibly 0).
    All other routes are counted with a single grouped SQL query; routes with
    no stop times are omitted from that part of the result.
    """
    if not route_row_ids:
        return {}
    routes = session.scalars(select(GtfsRoute).where(GtfsRoute.id.in_(route_row_ids))).all()

    # Decide storage once per dataset rather than once per route.
    sidecar_by_dataset: dict[int, bool] = {}
    for route in routes:
        if route.dataset_id not in sidecar_by_dataset:
            sidecar_by_dataset[route.dataset_id] = bool(uses_sidecar_stop_times(session, route.dataset_id))
    sidecar_routes = [route for route in routes if sidecar_by_dataset[route.dataset_id]]
    main_route_ids = [route.id for route in routes if not sidecar_by_dataset[route.dataset_id]]

    counts: dict[int, int] = {}
    if main_route_ids:
        rows = session.execute(
            select(GtfsRoute.id, func.count(GtfsStopTime.id))
            .join(GtfsTrip, (GtfsTrip.dataset_id == GtfsRoute.dataset_id) & (GtfsTrip.route_id == GtfsRoute.route_id))
            .join(
                GtfsStopTime,
                (GtfsStopTime.dataset_id == GtfsTrip.dataset_id) & (GtfsStopTime.trip_id == GtfsTrip.trip_id),
            )
            .where(GtfsRoute.id.in_(main_route_ids))
            .group_by(GtfsRoute.id)
        ).all()
        counts.update({int(route_id): int(count) for route_id, count in rows})

    for route in sidecar_routes:
        rows = execute_sidecar_query(
            session,
            route.dataset_id,
            """
            SELECT COUNT(*) AS count
            FROM gtfs_stop_times AS stop_times
            JOIN gtfs_trips AS trips ON trips.trip_id = stop_times.trip_id
            WHERE trips.route_id = ?
            """,
            [route.route_id],
        )
        counts[int(route.id)] = int(rows[0]["count"] or 0) if rows else 0
    return counts


def _shape_counts(session: Session, route_row_ids: list[int]) -> dict[int, int]:
    """Map GtfsRoute row id -> number of distinct shape_ids used by its trips.

    Routes whose trips reference no stored shape are omitted.
    """
    if not route_row_ids:
        return {}
    rows = session.execute(
        select(GtfsRoute.id, func.count(func.distinct(GtfsShape.shape_id)))
        .join(GtfsTrip, (GtfsTrip.dataset_id == GtfsRoute.dataset_id) & (GtfsTrip.route_id == GtfsRoute.route_id))
        .join(GtfsShape, (GtfsShape.dataset_id == GtfsTrip.dataset_id) & (GtfsShape.shape_id == GtfsTrip.shape_id))
        .where(GtfsRoute.id.in_(route_row_ids))
        .group_by(GtfsRoute.id)
    ).all()
    return {int(route_id): int(count) for route_id, count in rows}


def _source_payload(source: Source) -> dict:
    """Serialize the public fields of a Source."""
    return {"id": source.id, "name": source.name, "kind": source.kind, "country": source.country}


def _dataset_payload(dataset: Dataset) -> dict:
    """Serialize the public fields of a Dataset; ``created_at`` as ISO-8601 or None."""
    return {
        "id": dataset.id,
        "kind": dataset.kind,
        "is_active": dataset.is_active,
        "status": dataset.status,
        "created_at": dataset.created_at.isoformat() if dataset.created_at else None,
        "sha256": dataset.sha256,
    }


def _geometry_payload(row) -> dict:
    """Summarize a row's geometry: whether ``geometry_geojson`` is set, plus a bbox.

    The bbox is returned only when all of min_lon/min_lat/max_lon/max_lat are
    present (not None) on ``row``; otherwise it is None.
    """
    bbox = None
    if all(getattr(row, attr, None) is not None for attr in ("min_lon", "min_lat", "max_lon", "max_lat")):
        bbox = {
            "min_lon": row.min_lon,
            "min_lat": row.min_lat,
            "max_lon": row.max_lon,
            "max_lat": row.max_lat,
        }
    return {"present": bool(getattr(row, "geometry_geojson", None)), "bbox": bbox}