from __future__ import annotations

from pathlib import Path

from sqlalchemy import delete, func, or_, select
from sqlalchemy.orm import Session

from app.config import settings
from app.gtfs_storage import (
    dataset_sidecar_paths as gtfs_dataset_sidecar_paths,
    missing_sidecar_paths as gtfs_missing_sidecar_paths,
    stop_time_count,
)
from app.models import (
    CanonicalStopLink,
    Dataset,
    GtfsAgency,
    GtfsCalendar,
    GtfsCalendarDate,
    GtfsRoute,
    GtfsRoutePatternLink,
    GtfsShape,
    GtfsStop,
    GtfsStopTime,
    GtfsTripRoutePatternLink,
    GtfsTrip,
    OsmDiffState,
    OsmFeature,
    RouteMatch,
    RoutePattern,
    RoutePatternStop,
    Source,
    SourceUpdateCheck,
)
from app.osm_storage import (
    dataset_sidecar_paths as osm_dataset_sidecar_paths,
    missing_sidecar_paths as osm_missing_sidecar_paths,
    osm_feature_count,
)

# GTFS tables keyed by ``dataset_id``, in the order their rows are deleted.
# NOTE(review): the order presumably removes dependent rows before the rows
# they reference (stop times before trips, trips before routes, ...); confirm
# against the FK definitions before reordering.
_GTFS_DATASET_MODELS = (
    GtfsStopTime,
    GtfsShape,
    GtfsTrip,
    GtfsCalendarDate,
    GtfsCalendar,
    GtfsRoute,
    GtfsStop,
    GtfsAgency,
)


def dataset_row_counts(session: Session, dataset_id: int, kind: str) -> dict[str, int]:
    """Return row counts for one dataset.

    For ``kind == "gtfs"`` this covers the GTFS tables, stop times (via
    ``stop_time_count``), a ``missing_sidecar`` flag, and route matches for
    the dataset's routes: the total under ``matches`` and a per-status
    breakdown under ``match_counts``. For ``kind == "osm_geojson"`` it covers
    OSM feature counts (all features, then by feature kind) and the
    ``missing_sidecar`` flag. Any other kind yields ``{}``.

    NOTE: despite the annotation, ``missing_sidecar`` is a bool and
    ``match_counts`` is a nested ``{status: count}`` dict.
    """
    if kind == "gtfs":
        route_ids = select(GtfsRoute.id).where(GtfsRoute.dataset_id == dataset_id)
        match_counts = {
            status: count
            for status, count in session.execute(
                select(RouteMatch.status, func.count())
                .where(RouteMatch.gtfs_route_id.in_(route_ids))
                .group_by(RouteMatch.status)
            ).all()
        }
        return {
            "agencies": _count(session, GtfsAgency, dataset_id),
            "stops": _count(session, GtfsStop, dataset_id),
            "routes": _count(session, GtfsRoute, dataset_id),
            "trips": _count(session, GtfsTrip, dataset_id),
            "calendars": _count(session, GtfsCalendar, dataset_id),
            "calendar_dates": _count(session, GtfsCalendarDate, dataset_id),
            "shapes": _count(session, GtfsShape, dataset_id),
            "stop_times": stop_time_count(session, dataset_id),
            "missing_sidecar": _gtfs_sidecar_missing(session, dataset_id),
            "matches": sum(match_counts.values()),
            "match_counts": match_counts,
        }
    if kind == "osm_geojson":
        return {
            "features": _safe_osm_feature_count(session, dataset_id),
            "routes": _safe_osm_feature_count(session, dataset_id, kind="route"),
            "stops": _safe_osm_feature_count(session, dataset_id, kind=["stop", "station", "terminal"]),
            "infra": _safe_osm_feature_count(session, dataset_id, kind="infra"),
            "missing_sidecar": _osm_sidecar_missing(session, dataset_id),
        }
    return {}


def source_row_counts(session: Session, source: Source) -> dict[str, object]:
    """Aggregate :func:`dataset_row_counts` over every dataset of *source*.

    Numeric per-dataset counts are summed. ``missing_sidecars`` counts
    datasets with a truthy ``missing_sidecar`` flag and is also split by kind.
    ``match_counts`` merges the per-status route-match counts.
    """
    counts = {
        "datasets": len(source.datasets),
        "active_datasets": sum(1 for dataset in source.datasets if dataset.is_active),
        "routes": 0,
        "stops": 0,
        "features": 0,
        "trips": 0,
        "shapes": 0,
        "stop_times": 0,
        "missing_sidecars": 0,
        "match_counts": {},
        "missing_gtfs_sidecars": 0,
        "missing_osm_sidecars": 0,
    }
    match_counts: dict[str, int] = {}
    for dataset in source.datasets:
        stats = dataset_row_counts(session, dataset.id, dataset.kind)
        for key in ("routes", "stops", "features", "trips", "shapes", "stop_times"):
            counts[key] += int(stats.get(key, 0))
        if stats.get("missing_sidecar"):
            counts["missing_sidecars"] += 1
            if dataset.kind == "gtfs":
                counts["missing_gtfs_sidecars"] += 1
            elif dataset.kind == "osm_geojson":
                counts["missing_osm_sidecars"] += 1
        for status, count in stats.get("match_counts", {}).items():
            match_counts[status] = match_counts.get(status, 0) + int(count)
    counts["match_counts"] = match_counts
    return counts


def delete_dataset(session: Session, dataset_id: int) -> dict[str, object]:
    """Delete one dataset together with its dependent rows and sidecar files.

    Returns ``{"deleted": False, "reason": ..., "dataset_id": ...}`` when no
    such dataset exists. Otherwise returns the row counts captured before
    deletion. The session is flushed but not committed. Sidecar files are
    unlinked only after the flush succeeds, so a failing flush leaves them
    on disk.
    """
    dataset = session.get(Dataset, dataset_id)
    if dataset is None:
        return {"deleted": False, "reason": "dataset not found", "dataset_id": dataset_id}
    counts = dataset_row_counts(session, dataset.id, dataset.kind)
    sidecars = _purge_dataset(session, dataset)
    session.flush()
    _unlink_files(sidecars)
    return {"deleted": True, "dataset_id": dataset_id, "counts": counts}


def delete_source(session: Session, source_id: int) -> dict[str, object]:
    """Delete a source, all of its datasets and the source's OSM diff state.

    The result lists each deleted dataset's id, kind and pre-deletion row
    counts. The session is flushed but not committed. Sidecar files of all
    datasets are unlinked only after the flush succeeds.
    """
    source = session.get(Source, source_id)
    if source is None:
        return {"deleted": False, "reason": "source not found", "source_id": source_id}
    dataset_results = []
    sidecars: list[Path] = []
    # Copy the collection: datasets are marked for deletion while iterating.
    for dataset in list(source.datasets):
        dataset_results.append(
            {
                "dataset_id": dataset.id,
                "kind": dataset.kind,
                "counts": dataset_row_counts(session, dataset.id, dataset.kind),
            }
        )
        sidecars.extend(_purge_dataset(session, dataset))
    session.execute(delete(OsmDiffState).where(OsmDiffState.source_id == source.id))
    session.delete(source)
    session.flush()
    _unlink_files(sidecars)
    return {"deleted": True, "source_id": source_id, "datasets": dataset_results}


def unreferenced_cache_file_summary(session: Session) -> dict[str, int]:
    """Report how many cache files no dataset references, and their total size in bytes."""
    candidates = _unreferenced_cache_files(session)
    return {"files": len(candidates), "bytes": sum(_file_size(path) for path in candidates)}


def prune_unreferenced_cache_files(session: Session) -> dict[str, int]:
    """Delete unreferenced cache files, then remove empty directories under the cache roots.

    Files that disappear concurrently are skipped rather than aborting the
    prune. The result counts only files this call actually removed.
    """
    removed_files = 0
    removed_bytes = 0
    for path in _unreferenced_cache_files(session):
        size = _file_size(path)
        try:
            path.unlink()
        except FileNotFoundError:
            continue  # already gone; nothing was freed by this call
        removed_files += 1
        removed_bytes += size
    for root in _cache_roots():
        _remove_empty_dirs(root)
    return {"files": removed_files, "bytes": removed_bytes}


def _unreferenced_cache_files(session: Session) -> list[Path]:
    """List files under the cache roots that no dataset references.

    A file counts as referenced when its resolved path equals a resolved
    ``Dataset.local_path`` or one of a dataset's sidecar paths.
    """
    referenced = {
        Path(path).resolve()
        for path in session.scalars(select(Dataset.local_path)).all()
        if path
    }
    for dataset in session.scalars(select(Dataset)).all():
        referenced.update(path.resolve() for path in dataset_sidecar_paths(dataset))
    return [
        path
        for root in _cache_roots()
        if root.exists()
        for path in root.rglob("*")
        if path.is_file() and path.resolve() not in referenced
    ]


def _cache_roots() -> list[Path]:
    """Return the directories scanned for unreferenced cache files."""
    # Staging files are not referenced by datasets until activation. Automatic
    # pruning must not remove a staging DB from a running import.
    return [settings.data_dir / "sources", settings.data_dir / "derived", settings.data_dir / "sidecars"]


def prune_inactive_datasets(session: Session, dry_run: bool = True) -> dict[str, object]:
    """Delete (or, when *dry_run*, only count) inactive GTFS and OSM GeoJSON datasets.

    Only datasets with ``is_active`` false and kind ``"gtfs"`` or
    ``"osm_geojson"`` are selected. The result carries the selected
    ``dataset_ids`` and per-table counts. On a dry run the counts are under
    ``would_delete``; on a real run they are under ``deleted``. The other key
    is ``{}``. A real run flushes but does not commit, and unlinks sidecar
    files only after the flush succeeds.
    """
    dataset_rows = session.execute(
        select(Dataset.id, Dataset.kind).where(
            Dataset.is_active.is_(False),
            Dataset.kind.in_(["gtfs", "osm_geojson"]),
        )
    ).all()
    dataset_ids = [int(row[0]) for row in dataset_rows]
    gtfs_ids = [int(dataset_id) for dataset_id, kind in dataset_rows if kind == "gtfs"]
    osm_ids = [int(dataset_id) for dataset_id, kind in dataset_rows if kind == "osm_geojson"]

    route_ids = select(GtfsRoute.id).where(GtfsRoute.dataset_id.in_(gtfs_ids)) if gtfs_ids else None
    osm_feature_ids = select(OsmFeature.id).where(OsmFeature.dataset_id.in_(osm_ids)) if osm_ids else None
    # Route matches touching either a pruned GTFS route or a pruned OSM feature.
    match_filters = []
    if route_ids is not None:
        match_filters.append(RouteMatch.gtfs_route_id.in_(route_ids))
    if osm_feature_ids is not None:
        match_filters.append(RouteMatch.osm_feature_id.in_(osm_feature_ids))

    counts = {
        "datasets": len(dataset_ids),
        "gtfs_stop_times": sum(stop_time_count(session, dataset_id) for dataset_id in gtfs_ids),
        "gtfs_shapes": _count_dataset_rows(session, GtfsShape, gtfs_ids),
        "gtfs_trips": _count_dataset_rows(session, GtfsTrip, gtfs_ids),
        "gtfs_calendar_dates": _count_dataset_rows(session, GtfsCalendarDate, gtfs_ids),
        "gtfs_calendars": _count_dataset_rows(session, GtfsCalendar, gtfs_ids),
        "gtfs_routes": _count_dataset_rows(session, GtfsRoute, gtfs_ids),
        "gtfs_stops": _count_dataset_rows(session, GtfsStop, gtfs_ids),
        "gtfs_agencies": _count_dataset_rows(session, GtfsAgency, gtfs_ids),
        "osm_features": sum(_safe_osm_feature_count(session, dataset_id) for dataset_id in osm_ids),
        "missing_osm_sidecars": sum(1 for dataset_id in osm_ids if _osm_sidecar_missing(session, dataset_id)),
        "gtfs_route_pattern_links": _count_dataset_rows(session, GtfsRoutePatternLink, gtfs_ids),
        "gtfs_trip_route_pattern_links": _count_dataset_rows(session, GtfsTripRoutePatternLink, gtfs_ids),
        "canonical_stop_links": _count_dataset_rows(session, CanonicalStopLink, dataset_ids),
        "route_matches": (
            session.scalar(select(func.count()).select_from(RouteMatch).where(or_(*match_filters)))
            if match_filters
            else 0
        ),
    }
    if dry_run:
        return {"dry_run": dry_run, "dataset_ids": dataset_ids, "deleted": {}, "would_delete": counts}
    if not dataset_ids:
        # Real run with nothing to prune: report like any other real run.
        return {"dry_run": dry_run, "dataset_ids": dataset_ids, "deleted": counts, "would_delete": {}}

    for dataset_id in dataset_ids:
        _detach_update_checks_for_dataset(session, dataset_id)
    # Same cleanup delete_dataset performs before removing a dataset.
    # Otherwise diff-state rows would point at deleted datasets, or block the
    # delete if a FK exists.
    session.execute(delete(OsmDiffState).where(OsmDiffState.raw_dataset_id.in_(dataset_ids)))
    if match_filters:
        session.execute(delete(RouteMatch).where(or_(*match_filters)))
    if route_ids is not None:
        pattern_ids = select(RoutePattern.id).where(RoutePattern.gtfs_route_id.in_(route_ids))
        session.execute(delete(RoutePatternStop).where(RoutePatternStop.route_pattern_id.in_(pattern_ids)))
        session.execute(delete(GtfsTripRoutePatternLink).where(GtfsTripRoutePatternLink.dataset_id.in_(gtfs_ids)))
        session.execute(delete(GtfsRoutePatternLink).where(GtfsRoutePatternLink.dataset_id.in_(gtfs_ids)))
        session.execute(delete(RoutePattern).where(RoutePattern.gtfs_route_id.in_(route_ids)))
        session.execute(
            delete(CanonicalStopLink).where(
                CanonicalStopLink.dataset_id.in_(gtfs_ids),
                CanonicalStopLink.object_type == "gtfs_stop",
            )
        )
        for model in _GTFS_DATASET_MODELS:
            session.execute(delete(model).where(model.dataset_id.in_(gtfs_ids)))
    if osm_feature_ids is not None:
        pattern_ids = select(RoutePattern.id).where(RoutePattern.osm_feature_id.in_(osm_feature_ids))
        session.execute(delete(RoutePatternStop).where(RoutePatternStop.route_pattern_id.in_(pattern_ids)))
        session.execute(delete(GtfsTripRoutePatternLink).where(GtfsTripRoutePatternLink.route_pattern_id.in_(pattern_ids)))
        session.execute(delete(GtfsRoutePatternLink).where(GtfsRoutePatternLink.route_pattern_id.in_(pattern_ids)))
        session.execute(delete(RoutePattern).where(RoutePattern.osm_feature_id.in_(osm_feature_ids)))
        session.execute(
            delete(CanonicalStopLink).where(
                CanonicalStopLink.dataset_id.in_(osm_ids),
                CanonicalStopLink.object_type == "osm_feature",
            )
        )
        session.execute(delete(OsmFeature).where(OsmFeature.dataset_id.in_(osm_ids)))

    # Collect sidecar paths while the Dataset rows still exist; unlink after the flush.
    sidecars = [
        path
        for dataset in session.scalars(select(Dataset).where(Dataset.id.in_(dataset_ids))).all()
        for path in dataset_sidecar_paths(dataset)
    ]
    session.execute(delete(Dataset).where(Dataset.id.in_(dataset_ids)))
    session.flush()
    _unlink_files(sidecars)
    return {"dry_run": dry_run, "dataset_ids": dataset_ids, "deleted": counts, "would_delete": {}}


def _purge_dataset(session: Session, dataset: Dataset) -> list[Path]:
    """Mark *dataset* and its dependent rows for deletion; return its sidecar paths.

    Update checks pointing at the dataset are detached and OSM diff state
    using it as the raw dataset is deleted. No files are touched here.
    Callers unlink the returned paths once their flush has succeeded.
    """
    sidecars = dataset_sidecar_paths(dataset)
    _detach_update_checks_for_dataset(session, dataset.id)
    session.execute(delete(OsmDiffState).where(OsmDiffState.raw_dataset_id == dataset.id))
    _delete_dataset_rows(session, dataset)
    session.delete(dataset)
    return sidecars


def _delete_dataset_rows(session: Session, dataset: Dataset) -> None:
    """Delete every row that belongs to, or is derived from, *dataset*.

    This covers route matches, route patterns and their stops/links,
    canonical stop links, and the dataset's own GTFS or OSM feature rows.
    Only ``"gtfs"`` and ``"osm_geojson"`` datasets have rows to delete.
    """
    if dataset.kind == "gtfs":
        route_ids = select(GtfsRoute.id).where(GtfsRoute.dataset_id == dataset.id)
        pattern_ids = select(RoutePattern.id).where(RoutePattern.gtfs_route_id.in_(route_ids))
        session.execute(delete(RouteMatch).where(RouteMatch.gtfs_route_id.in_(route_ids)))
        session.execute(delete(RoutePatternStop).where(RoutePatternStop.route_pattern_id.in_(pattern_ids)))
        session.execute(delete(GtfsTripRoutePatternLink).where(GtfsTripRoutePatternLink.dataset_id == dataset.id))
        session.execute(delete(GtfsRoutePatternLink).where(GtfsRoutePatternLink.dataset_id == dataset.id))
        session.execute(delete(RoutePattern).where(RoutePattern.gtfs_route_id.in_(route_ids)))
        session.execute(
            delete(CanonicalStopLink).where(
                CanonicalStopLink.dataset_id == dataset.id,
                CanonicalStopLink.object_type == "gtfs_stop",
            )
        )
        for model in _GTFS_DATASET_MODELS:
            session.execute(delete(model).where(model.dataset_id == dataset.id))
    elif dataset.kind == "osm_geojson":
        osm_feature_ids = select(OsmFeature.id).where(OsmFeature.dataset_id == dataset.id)
        pattern_ids = select(RoutePattern.id).where(RoutePattern.osm_feature_id.in_(osm_feature_ids))
        session.execute(delete(RouteMatch).where(RouteMatch.osm_feature_id.in_(osm_feature_ids)))
        session.execute(delete(RoutePatternStop).where(RoutePatternStop.route_pattern_id.in_(pattern_ids)))
        session.execute(delete(GtfsTripRoutePatternLink).where(GtfsTripRoutePatternLink.route_pattern_id.in_(pattern_ids)))
        session.execute(delete(GtfsRoutePatternLink).where(GtfsRoutePatternLink.route_pattern_id.in_(pattern_ids)))
        session.execute(delete(RoutePattern).where(RoutePattern.osm_feature_id.in_(osm_feature_ids)))
        session.execute(
            delete(CanonicalStopLink).where(
                CanonicalStopLink.dataset_id == dataset.id,
                CanonicalStopLink.object_type == "osm_feature",
            )
        )
        session.execute(delete(OsmFeature).where(OsmFeature.dataset_id == dataset.id))


def _unlink_files(paths) -> None:
    """Unlink each path, ignoring files that are already gone."""
    for path in paths:
        try:
            path.unlink()
        except FileNotFoundError:
            pass


def _file_size(path: Path) -> int:
    """Return the size of *path* in bytes, or 0 if it no longer exists."""
    try:
        return path.stat().st_size
    except FileNotFoundError:
        return 0


def _delete_dataset_files(dataset: Dataset) -> None:
    """Unlink all sidecar files of *dataset*, ignoring ones that are already missing."""
    _unlink_files(dataset_sidecar_paths(dataset))


def dataset_sidecar_paths(dataset: Dataset) -> list[Path]:
    """Return the GTFS sidecar paths followed by the OSM sidecar paths for *dataset*.

    NOTE(review): presumably each storage module returns ``[]`` for datasets
    of the other kind; confirm in app.gtfs_storage / app.osm_storage.
    """
    return [*gtfs_dataset_sidecar_paths(dataset), *osm_dataset_sidecar_paths(dataset)]


def _gtfs_sidecar_missing(session: Session, dataset_id: int) -> bool:
    """Return True when any GTFS sidecar file of the dataset is missing."""
    dataset = session.get(Dataset, dataset_id)
    return bool(gtfs_missing_sidecar_paths(dataset))


def _safe_osm_feature_count(session: Session, dataset_id: int, *, kind=None) -> int:
    """Count OSM features of the dataset (optionally filtered by *kind*), or 0 if its storage file is missing.

    *kind* is passed through to ``osm_feature_count``. Callers here pass
    ``None``, a single kind string, or a list of kind strings.
    """
    try:
        return osm_feature_count(session, dataset_id, kind=kind)
    except FileNotFoundError:
        return 0


def _osm_sidecar_missing(session: Session, dataset_id: int) -> bool:
    """Return True when any OSM sidecar file of the dataset is missing."""
    dataset = session.get(Dataset, dataset_id)
    return bool(osm_missing_sidecar_paths(dataset))


def _detach_update_checks_for_dataset(session: Session, dataset_id: int) -> None:
    """Clear ``active_dataset_id`` on update checks that point at *dataset_id*."""
    for check in session.scalars(
        select(SourceUpdateCheck).where(SourceUpdateCheck.active_dataset_id == dataset_id)
    ).all():
        check.active_dataset_id = None


def _count(session: Session, model, dataset_id: int) -> int:
    """Count rows of *model* belonging to one dataset."""
    return session.scalar(select(func.count()).select_from(model).where(model.dataset_id == dataset_id)) or 0


def _count_where(session: Session, model, dataset_id: int, *where) -> int:
    """Count rows of *model* belonging to one dataset that also satisfy every extra clause in *where*."""
    return session.scalar(select(func.count()).select_from(model).where(model.dataset_id == dataset_id, *where)) or 0


def _count_dataset_rows(session: Session, model, dataset_ids: list[int]) -> int:
    """Count rows of *model* belonging to any of *dataset_ids*; 0 when the list is empty."""
    if not dataset_ids:
        return 0
    return session.scalar(select(func.count()).select_from(model).where(model.dataset_id.in_(dataset_ids))) or 0


def _remove_empty_dirs(root: Path) -> None:
    """Remove empty directories below *root*, deepest first; *root* itself is kept.

    Non-empty directories make ``rmdir`` raise ``OSError``, and they are left
    in place.
    """
    if not root.exists():
        return
    for path in sorted((p for p in root.rglob("*") if p.is_dir()), key=lambda p: len(p.parts), reverse=True):
        try:
            path.rmdir()
        except OSError:
            pass