253 lines
10 KiB
Python
253 lines
10 KiB
Python
from __future__ import annotations
|
|
|
|
from sqlalchemy import func, or_, select
|
|
from sqlalchemy.orm import Session
|
|
|
|
from app.gtfs_storage import execute_sidecar_query, uses_sidecar_stop_times
|
|
from app.models import Dataset, GtfsRoute, GtfsShape, GtfsStopTime, GtfsTrip, OsmFeature, RoutePattern, Source
|
|
from app.osm_storage import osm_feature_public_id, query_osm_features
|
|
from app.pipeline.utils import norm_ref
|
|
|
|
|
|
def search_datasets(session: Session, query: str, *, active_only: bool = False, limit: int = 80) -> dict:
    """Search GTFS routes, OSM route features and route patterns for ``query``.

    Args:
        session: Database session used for every lookup.
        query: Free-text search string; ``None`` and surrounding whitespace are ignored.
        active_only: Restrict GTFS and OSM hits to active datasets. Route patterns are
            not filtered by this flag.
        limit: Maximum number of hits per category, clamped to the range ``1..250``.

    Returns:
        A dict with the stripped ``query``, the hit lists ``gtfs_routes``, ``osm_routes``
        and ``route_patterns``, and a ``totals`` mapping holding the length of each list.
        ``totals`` always carries all three keys (zeros for a blank query), so callers
        can index it without special-casing empty searches.
    """
    q = (query or "").strip()
    if q:
        max_rows = max(1, min(limit, 250))
        gtfs_routes = _gtfs_route_hits(session, q, active_only=active_only, limit=max_rows)
        osm_routes = _osm_route_hits(session, q, active_only=active_only, limit=max_rows)
        route_patterns = _route_pattern_hits(session, q, limit=max_rows)
    else:
        # Nothing to search for: skip the database but keep the usual response shape.
        gtfs_routes, osm_routes, route_patterns = [], [], []

    return {
        "query": q,
        "gtfs_routes": gtfs_routes,
        "osm_routes": osm_routes,
        "route_patterns": route_patterns,
        "totals": {
            "gtfs_routes": len(gtfs_routes),
            "osm_routes": len(osm_routes),
            "route_patterns": len(route_patterns),
        },
    }
|
|
|
|
|
|
def _gtfs_route_hits(session: Session, query: str, *, active_only: bool, limit: int) -> list[dict]:
    """Find GTFS routes whose short name, route id or long name contains ``query``.

    Matching is case-insensitive substring matching (SQL ``ILIKE``) with ``%``, ``_``
    and ``/`` in ``query`` treated literally. A route also matches when its
    ``route_key`` equals ``norm_ref(query)``, but only if that normalisation produced a
    non-empty value.

    Args:
        session: Database session.
        query: Search text.
        active_only: Restrict hits to active datasets.
        limit: Maximum number of routes returned.

    Returns:
        One payload dict per route with ``source``, ``dataset``, ``route``, ``geometry``
        and ``timetable`` (trip / stop-time / shape counts, 0 when none) entries. Hits are
        ordered active datasets first, then by source name, short name and route id.
    """
    # Escape LIKE metacharacters ("/" is the escape char) so user input is matched literally.
    escaped = query.replace("/", "//").replace("%", "/%").replace("_", "/_")
    pattern = f"%{escaped}%"
    conditions = [
        GtfsRoute.short_name.ilike(pattern, escape="/"),
        GtfsRoute.route_id.ilike(pattern, escape="/"),
        GtfsRoute.long_name.ilike(pattern, escape="/"),
    ]
    ref = norm_ref(query)
    if ref:
        # An empty/None ref must not be compared: "== None" renders as IS NULL and would
        # match every route without a route_key.
        conditions.append(GtfsRoute.route_key == ref)

    stmt = (
        select(GtfsRoute, Dataset, Source)
        .join(Dataset, Dataset.id == GtfsRoute.dataset_id)
        .join(Source, Source.id == Dataset.source_id)
        .where(or_(*conditions))
    )
    if active_only:
        stmt = stmt.where(Dataset.is_active.is_(True))
    stmt = stmt.order_by(
        Dataset.is_active.desc(), Source.name, GtfsRoute.short_name, GtfsRoute.route_id
    ).limit(limit)

    rows = session.execute(stmt).all()
    route_ids = [route.id for route, _, _ in rows]
    trip_counts = _trip_counts(session, route_ids)
    stop_time_counts = _stop_time_counts(session, route_ids)
    shape_counts = _shape_counts(session, route_ids)

    return [
        {
            "type": "gtfs_route",
            "source": _source_payload(source),
            "dataset": _dataset_payload(dataset),
            "route": {
                "id": route.id,
                "route_id": route.route_id,
                "ref": route.short_name,
                "name": route.long_name,
                "mode": route.mode,
                "operator": route.operator_name,
            },
            "geometry": _geometry_payload(route),
            "timetable": {
                "trips": trip_counts.get(route.id, 0),
                "stop_times": stop_time_counts.get(route.id, 0),
                "shapes": shape_counts.get(route.id, 0),
            },
        }
        for route, dataset, source in rows
    ]
|
|
|
|
|
|
def _osm_route_hits(session: Session, query: str, *, active_only: bool, limit: int) -> list[dict]:
    """Search OSM route features in ``osm_geojson`` datasets.

    Candidates come from two ``query_osm_features`` calls:
    - a text search on ``query``;
    - when ``norm_ref(query)`` yields a value, an exact ``route_key`` lookup.

    Candidates are de-duplicated on ``(dataset_id, osm_type, osm_id)``, with the later
    lookup winning. Features whose dataset or source cannot be resolved are dropped
    before sorting and limiting, so they never take a result slot.

    Args:
        session: Database session.
        query: Search text.
        active_only: Restrict the search to active datasets.
        limit: Maximum number of hits returned; also passed to each feature lookup.

    Returns:
        One payload dict per hit with ``source``, ``dataset``, ``osm`` and ``geometry``
        entries. Hits are ordered active datasets first, then by source name, OSM ref,
        name and feature id. Missing text values sort as ``""`` and a missing id as 0.
    """
    ref = norm_ref(query)

    dataset_stmt = select(Dataset).where(Dataset.kind == "osm_geojson")
    if active_only:
        dataset_stmt = dataset_stmt.where(Dataset.is_active.is_(True))
    datasets = session.scalars(dataset_stmt.order_by(Dataset.is_active.desc(), Dataset.id)).all()
    if not datasets:
        return []

    dataset_ids = [dataset.id for dataset in datasets]
    dataset_by_id = {dataset.id: dataset for dataset in datasets}
    source_stmt = select(Source).where(Source.id.in_([dataset.source_id for dataset in datasets]))
    sources = {source.id: source for source in session.scalars(source_stmt).all()}

    candidates = list(query_osm_features(session, dataset_ids, kinds=["route"], search=query, limit=limit))
    if ref:
        candidates.extend(query_osm_features(session, dataset_ids, kinds=["route"], route_key=ref, limit=limit))

    features_by_identity: dict[tuple[int, str, str], OsmFeature] = {}
    for feature in candidates:
        features_by_identity[(feature.dataset_id, feature.osm_type, feature.osm_id)] = feature

    # Resolve dataset/source once per feature and drop orphans up front.
    hits: list[tuple[OsmFeature, Dataset, Source]] = []
    for feature in features_by_identity.values():
        dataset = dataset_by_id.get(feature.dataset_id)
        source = sources.get(dataset.source_id) if dataset is not None else None
        if source is None:
            continue
        hits.append((feature, dataset, source))

    def sort_key(hit: tuple[OsmFeature, Dataset, Source]) -> tuple:
        feature, dataset, source = hit
        # "or" guards keep None values from being compared with strings/ints.
        return (
            0 if dataset.is_active else 1,
            source.name or "",
            feature.ref or "",
            feature.name or "",
            feature.id or 0,
        )

    hits.sort(key=sort_key)

    return [
        {
            "type": "osm_route",
            "source": _source_payload(source),
            "dataset": _dataset_payload(dataset),
            "osm": {
                "id": osm_feature_public_id(feature),
                "osm_type": feature.osm_type,
                "osm_id": feature.osm_id,
                "ref": feature.ref,
                "name": feature.name,
                "mode": feature.mode,
                "route_scope": feature.route_scope,
                "operator": feature.operator,
                "network": feature.network,
            },
            "geometry": _geometry_payload(feature),
        }
        for feature, dataset, source in hits[:limit]
    ]
|
|
|
|
|
|
def _route_pattern_hits(session: Session, query: str, *, limit: int, overfetch: int = 5) -> list[dict]:
    """Search route patterns by ref, name or pattern key.

    SQL candidates are rows whose ``route_ref``, ``route_name`` or ``pattern_key``
    contains ``query`` case-insensitively; ``%``, ``_`` and ``/`` are matched literally.
    When ``norm_ref(query)`` yields a value, a candidate is kept only if either:
    - its normalised ref (falling back to its name) equals that value, or
    - its name contains ``query``.

    Args:
        session: Database session.
        query: Search text.
        limit: Maximum number of hits returned.
        overfetch: Multiplier applied to ``limit`` for the SQL fetch while the relevance
            filter is active. Without it, rows rejected after the SQL ``LIMIT`` would leave
            the result short. Values below 1 are treated as 1.

    Returns:
        One payload dict per hit, ordered by source kind, route ref and id.
    """
    escaped = query.replace("/", "//").replace("%", "/%").replace("_", "/_")
    pattern = f"%{escaped}%"
    ref = norm_ref(query)
    needle = query.lower()

    def is_relevant(row: RoutePattern) -> bool:
        if not ref:
            return True
        if norm_ref(row.route_ref or row.route_name or "") == ref:
            return True
        return needle in (row.route_name or "").lower()

    fetch_limit = limit * max(1, overfetch) if ref else limit
    stmt = (
        select(RoutePattern)
        .where(
            or_(
                RoutePattern.route_ref.ilike(pattern, escape="/"),
                RoutePattern.route_name.ilike(pattern, escape="/"),
                RoutePattern.pattern_key.ilike(pattern, escape="/"),
            )
        )
        .order_by(RoutePattern.source_kind, RoutePattern.route_ref, RoutePattern.id)
        .limit(fetch_limit)
    )
    rows = [row for row in session.scalars(stmt).all() if is_relevant(row)][:limit]

    return [
        {
            "type": "route_pattern",
            "id": pattern_row.id,
            "ref": pattern_row.route_ref,
            "name": pattern_row.route_name,
            "mode": pattern_row.mode,
            "route_scope": pattern_row.route_scope,
            "source_kind": pattern_row.source_kind,
            "status": pattern_row.status,
            "confidence": pattern_row.confidence,
            "gtfs_route_id": pattern_row.gtfs_route_id,
            "osm_feature_id": pattern_row.osm_feature_id,
            "geometry": _geometry_payload(pattern_row),
        }
        for pattern_row in rows
    ]
|
|
|
|
|
|
def _trip_counts(session: Session, route_row_ids: list[int]) -> dict[int, int]:
    """Return the number of trips per GTFS route row id.

    Trips belong to a route when both ``dataset_id`` and ``route_id`` agree. The query
    uses an inner join, so routes without trips are absent from the mapping. An empty
    input returns ``{}`` without touching the database.
    """
    if not route_row_ids:
        return {}

    trip_matches_route = (GtfsTrip.dataset_id == GtfsRoute.dataset_id) & (
        GtfsTrip.route_id == GtfsRoute.route_id
    )
    count_stmt = (
        select(GtfsRoute.id, func.count(GtfsTrip.id))
        .join(GtfsTrip, trip_matches_route)
        .where(GtfsRoute.id.in_(route_row_ids))
        .group_by(GtfsRoute.id)
    )

    per_route: dict[int, int] = {}
    for route_pk, trip_total in session.execute(count_stmt).all():
        per_route[int(route_pk)] = int(trip_total)
    return per_route
|
|
|
|
|
|
def _stop_time_counts(session: Session, route_row_ids: list[int]) -> dict[int, int]:
    """Count GTFS stop times per route row id.

    Stop times reach a route through its trips. The main database joins on
    ``(dataset_id, route_id)`` and ``(dataset_id, trip_id)``. Routes of datasets where
    ``uses_sidecar_stop_times`` is true are counted in the sidecar store via
    ``execute_sidecar_query``, with one grouped query per dataset (chunked). All other
    routes are counted with one grouped query against the main database.

    Args:
        session: Database session.
        route_row_ids: Primary keys of ``GtfsRoute`` rows.

    Returns:
        Mapping of route row id to stop-time count. Sidecar-backed routes are always
        present (0 when nothing matched). Main-database routes without stop times are
        omitted, so look results up with ``.get(id, 0)``.
    """
    if not route_row_ids:
        return {}

    routes = session.scalars(select(GtfsRoute).where(GtfsRoute.id.in_(route_row_ids))).all()

    # Decide the storage backend once per dataset instead of once per route.
    sidecar_by_dataset: dict[int, bool] = {}
    sidecar_routes: dict[int, list[GtfsRoute]] = {}
    main_route_ids: list[int] = []
    for route in routes:
        dataset_id = route.dataset_id
        if dataset_id not in sidecar_by_dataset:
            sidecar_by_dataset[dataset_id] = bool(uses_sidecar_stop_times(session, dataset_id))
        if sidecar_by_dataset[dataset_id]:
            sidecar_routes.setdefault(dataset_id, []).append(route)
        else:
            main_route_ids.append(route.id)

    counts: dict[int, int] = {}
    if main_route_ids:
        rows = session.execute(
            select(GtfsRoute.id, func.count(GtfsStopTime.id))
            .join(GtfsTrip, (GtfsTrip.dataset_id == GtfsRoute.dataset_id) & (GtfsTrip.route_id == GtfsRoute.route_id))
            .join(GtfsStopTime, (GtfsStopTime.dataset_id == GtfsTrip.dataset_id) & (GtfsStopTime.trip_id == GtfsTrip.trip_id))
            .where(GtfsRoute.id.in_(main_route_ids))
            .group_by(GtfsRoute.id)
        ).all()
        counts.update({int(route_id): int(count) for route_id, count in rows})

    # Keep the number of bound parameters per sidecar statement modest
    # (e.g. older SQLite builds cap it at 999).
    batch_size = 500
    for dataset_id, dataset_routes in sidecar_routes.items():
        gtfs_route_ids = list(dict.fromkeys(route.route_id for route in dataset_routes))
        per_route_id: dict[str, int] = {}
        for start in range(0, len(gtfs_route_ids), batch_size):
            batch = gtfs_route_ids[start : start + batch_size]
            # Only "?" markers are interpolated; the route ids themselves are bound parameters.
            placeholders = ", ".join(["?"] * len(batch))
            rows = execute_sidecar_query(
                session,
                dataset_id,
                f"""
                SELECT trips.route_id AS route_id, COUNT(*) AS count
                FROM gtfs_stop_times AS stop_times
                JOIN gtfs_trips AS trips
                  ON trips.trip_id = stop_times.trip_id
                WHERE trips.route_id IN ({placeholders})
                GROUP BY trips.route_id
                """,
                batch,
            )
            for row in rows or []:
                # Key by str() in case the sidecar returns route_id with a different type.
                per_route_id[str(row["route_id"])] = int(row["count"] or 0)
        for route in dataset_routes:
            # Routes without stop times produce no group row; report them as 0.
            counts[int(route.id)] = per_route_id.get(str(route.route_id), 0)

    return counts
|
|
|
|
|
|
def _shape_counts(session: Session, route_row_ids: list[int]) -> dict[int, int]:
    """Return how many distinct shapes each GTFS route's trips reference.

    A shape is counted only when a ``GtfsShape`` row with the same ``dataset_id`` and
    ``shape_id`` as one of the route's trips exists (inner joins). Routes with no such
    shapes are absent from the mapping. An empty input returns ``{}`` without querying.
    """
    if not route_row_ids:
        return {}

    trip_on_route = (GtfsTrip.dataset_id == GtfsRoute.dataset_id) & (GtfsTrip.route_id == GtfsRoute.route_id)
    shape_on_trip = (GtfsShape.dataset_id == GtfsTrip.dataset_id) & (GtfsShape.shape_id == GtfsTrip.shape_id)
    distinct_shapes = func.count(func.distinct(GtfsShape.shape_id))

    shape_stmt = (
        select(GtfsRoute.id, distinct_shapes)
        .join(GtfsTrip, trip_on_route)
        .join(GtfsShape, shape_on_trip)
        .where(GtfsRoute.id.in_(route_row_ids))
        .group_by(GtfsRoute.id)
    )
    result_rows = session.execute(shape_stmt).all()
    return dict((int(route_pk), int(total)) for route_pk, total in result_rows)
|
|
|
|
|
|
def _source_payload(source: Source) -> dict:
    """Serialise ``source`` into a dict of its ``id``, ``name``, ``kind`` and ``country``."""
    payload: dict = {}
    for attribute in ("id", "name", "kind", "country"):
        payload[attribute] = getattr(source, attribute)
    return payload
|
|
|
|
|
|
def _dataset_payload(dataset: Dataset) -> dict:
    """Serialise ``dataset`` into a JSON-friendly summary dict.

    ``created_at`` is rendered with ``isoformat()`` when set, otherwise as ``None``.
    """
    created = dataset.created_at
    return dict(
        id=dataset.id,
        kind=dataset.kind,
        is_active=dataset.is_active,
        status=dataset.status,
        created_at=None if not created else created.isoformat(),
        sha256=dataset.sha256,
    )
|
|
|
|
|
|
def _geometry_payload(row) -> dict:
    """Summarise the geometry stored on ``row``.

    ``present`` reflects the truthiness of ``row.geometry_geojson`` (missing attribute
    counts as absent). ``bbox`` holds the four ``min_lon``/``min_lat``/``max_lon``/
    ``max_lat`` values only when none of them is missing or ``None``; otherwise ``None``.
    """
    corners = {key: getattr(row, key, None) for key in ("min_lon", "min_lat", "max_lon", "max_lat")}
    complete = all(value is not None for value in corners.values())
    return {
        "present": bool(getattr(row, "geometry_geojson", None)),
        "bbox": corners if complete else None,
    }
|