from __future__ import annotations import csv import hashlib from datetime import datetime, timezone from pathlib import Path from typing import Iterable from sqlalchemy import func, or_, select from sqlalchemy.orm import Session from app.models import Source, SourceCatalogEntry DIRECT_INGEST_KINDS = {"gtfs", "osm_geojson", "osm_pbf"} def default_source_catalog_path() -> Path: return Path(__file__).resolve().parents[1] / "docs" / "source_catalog_seed.csv" def default_ingestable_sources_path() -> Path: return Path(__file__).resolve().parents[1] / "docs" / "ingestable_sources_seed.csv" def import_source_catalog(session: Session, path: Path | str | None = None, *, update_existing: bool = True) -> dict[str, int]: csv_path = _resolve_path(path, default_source_catalog_path()) rows = _read_csv(csv_path) created = 0 updated = 0 skipped = 0 for row in rows: source_name = _value(row, "Source name") if not source_name: skipped += 1 continue payload = { "catalog_key": _catalog_key(row), "geography": _value(row, "Geography"), "country_code": _value(row, "Country code"), "mode_scope": _value(row, "Mode scope"), "source_name": source_name, "source_category": _value(row, "Source category"), "formats_apis": _value(row, "Formats / APIs"), "availability": _value(row, "Availability"), "coverage_notes": _value(row, "Coverage notes"), "geometry_notes": _value(row, "Supersedes OSM for"), "disruptions_closures": _value(row, "Disruptions / closures"), "operator_list_use": _value(row, "Operator-list use"), "access_license_notes": _value(row, "Access / licence notes"), "priority": _value(row, "Priority"), "source_url": _value(row, "Source URL"), "evidence_url": _value(row, "Evidence URL"), "next_pipeline_action": _value(row, "Next pipeline action"), } existing = session.scalar(select(SourceCatalogEntry).where(SourceCatalogEntry.catalog_key == payload["catalog_key"])) if existing is None: session.add(SourceCatalogEntry(**payload)) created += 1 continue if not update_existing: skipped += 1 continue for key, value in payload.items(): setattr(existing, key, value) existing.updated_at = datetime.now(timezone.utc) updated += 1 session.flush() return {"created": created, "updated": updated, "skipped": skipped} def import_ingestable_sources( session: Session, path: Path | str | None = None, *, update_existing: bool = True, ) -> dict[str, int]: csv_path = _resolve_path(path, default_ingestable_sources_path()) rows = _read_csv(csv_path) created = 0 updated = 0 skipped = 0 linked_catalog = 0 for row in rows: name = _value(row, "name") kind = (_value(row, "kind") or "").lower() url = _value(row, "url") if not name or not url or kind not in DIRECT_INGEST_KINDS: skipped += 1 continue catalog_entry = _catalog_entry_for_ingestable_row(session, row) payload = { "name": name, "kind": kind, "url": url, "country": _value(row, "country"), "license": _value(row, "license"), "priority": _value(row, "priority"), "mode_scope": _value(row, "mode_scope"), "source_basis": _value(row, "source_basis"), "notes": _value(row, "notes"), "catalog_entry_id": None if catalog_entry is None else catalog_entry.id, } existing = session.scalar( select(Source) .where(Source.kind == kind, Source.url == url) .order_by(Source.id) .limit(1) ) if existing is None: existing = session.scalar(select(Source).where(Source.name == name, Source.url == url).order_by(Source.id).limit(1)) if existing is None: session.add(Source(**payload)) created += 1 if catalog_entry is not None: linked_catalog += 1 continue if not update_existing: skipped += 1 continue for key, value in payload.items(): setattr(existing, key, value) existing.enabled = True updated += 1 if catalog_entry is not None: linked_catalog += 1 session.flush() return {"created": created, "updated": updated, "skipped": skipped, "linked_catalog": linked_catalog} def source_catalog_summary(session: Session) -> dict[str, object]: priority_counts = { priority or "unknown": count for priority, count in session.execute( select(SourceCatalogEntry.priority, func.count()).group_by(SourceCatalogEntry.priority) ).all() } status_counts = { status or "unknown": count for status, count in session.execute(select(SourceCatalogEntry.status, func.count()).group_by(SourceCatalogEntry.status)).all() } ingestable_sources = session.scalar( select(func.count()).select_from(Source).where(Source.source_basis.is_not(None) | Source.priority.is_not(None)) ) or 0 return { "catalog_entries": session.scalar(select(func.count()).select_from(SourceCatalogEntry)) or 0, "catalog_by_priority": priority_counts, "catalog_by_status": status_counts, "seeded_ingestable_sources": ingestable_sources, } def source_catalog_rows( session: Session, *, q: str | None = None, country: str | None = None, priority: str | None = None, status: str | None = None, limit: int = 100, ) -> list[SourceCatalogEntry]: stmt = select(SourceCatalogEntry).order_by( SourceCatalogEntry.priority, SourceCatalogEntry.country_code, SourceCatalogEntry.source_name, SourceCatalogEntry.id, ) if q: pattern = f"%{q.strip()}%" stmt = stmt.where( or_( SourceCatalogEntry.source_name.ilike(pattern), SourceCatalogEntry.source_category.ilike(pattern), SourceCatalogEntry.formats_apis.ilike(pattern), SourceCatalogEntry.coverage_notes.ilike(pattern), SourceCatalogEntry.next_pipeline_action.ilike(pattern), ) ) if country: stmt = stmt.where(SourceCatalogEntry.country_code.ilike(f"%{country.strip()}%")) if priority: stmt = stmt.where(SourceCatalogEntry.priority == priority.strip()) if status: stmt = stmt.where(SourceCatalogEntry.status == status.strip()) return session.scalars(stmt.limit(max(1, min(limit, 500)))).all() def catalog_entry_payload(entry: SourceCatalogEntry, *, linked_source_count: int = 0) -> dict[str, object]: return { "id": entry.id, "geography": entry.geography, "country_code": entry.country_code, "mode_scope": entry.mode_scope, "source_name": entry.source_name, "source_category": entry.source_category, "formats_apis": entry.formats_apis, "availability": entry.availability, "coverage_notes": entry.coverage_notes, "geometry_notes": entry.geometry_notes, "disruptions_closures": entry.disruptions_closures, "operator_list_use": entry.operator_list_use, "access_license_notes": entry.access_license_notes, "priority": entry.priority, "source_url": entry.source_url, "evidence_url": entry.evidence_url, "next_pipeline_action": entry.next_pipeline_action, "status": entry.status, "linked_source_count": linked_source_count, "created_at": entry.created_at.isoformat() if entry.created_at else None, "updated_at": entry.updated_at.isoformat() if entry.updated_at else None, } def linked_source_counts(session: Session, entries: Iterable[SourceCatalogEntry]) -> dict[int, int]: entry_ids = [entry.id for entry in entries] if not entry_ids: return {} return { entry_id: count for entry_id, count in session.execute( select(Source.catalog_entry_id, func.count()) .where(Source.catalog_entry_id.in_(entry_ids)) .group_by(Source.catalog_entry_id) ).all() if entry_id is not None } def _catalog_entry_for_ingestable_row(session: Session, row: dict[str, str]) -> SourceCatalogEntry | None: country = _value(row, "country") source_basis = _value(row, "source_basis") name = _value(row, "name") if not country and not source_basis and not name: return None if name: exact = session.scalar( select(SourceCatalogEntry) .where(func.lower(SourceCatalogEntry.source_name) == name.lower()) .order_by(SourceCatalogEntry.id) .limit(1) ) if exact is not None: return exact clauses = [] if country: clauses.append(SourceCatalogEntry.country_code.ilike(f"%{country}%")) if source_basis: for token in _basis_tokens(source_basis): clauses.append(SourceCatalogEntry.source_name.ilike(f"%{token}%")) clauses.append(SourceCatalogEntry.coverage_notes.ilike(f"%{token}%")) if name: first_word = name.split()[0] if len(first_word) > 2: clauses.append(SourceCatalogEntry.source_name.ilike(f"%{first_word}%")) if not clauses: return None return session.scalar( select(SourceCatalogEntry) .where(or_(*clauses)) .order_by(SourceCatalogEntry.priority, SourceCatalogEntry.id) .limit(1) ) def _basis_tokens(value: str) -> list[str]: tokens = [] for raw in value.replace("/", " ").replace("-", " ").split(): token = raw.strip(" ,.;()") if len(token) >= 5 and token.lower() not in {"official", "mirror", "feeds", "transport"}: tokens.append(token) return tokens[:4] def _catalog_key(row: dict[str, str]) -> str: parts = [ _value(row, "Country code"), _value(row, "Source name"), _value(row, "Source URL"), _value(row, "Formats / APIs"), ] text = "|".join(part.lower() for part in parts if part) if not text: text = repr(sorted(row.items())) return hashlib.sha256(text.encode("utf-8")).hexdigest() def _read_csv(path: Path) -> list[dict[str, str]]: if not path.exists(): raise FileNotFoundError(path) with path.open("r", encoding="utf-8-sig", newline="") as handle: reader = csv.DictReader(handle) return [dict(row) for row in reader] def _resolve_path(path: Path | str | None, default_path: Path) -> Path: if path is None: return default_path candidate = Path(path) if candidate.is_absolute(): return candidate return Path.cwd() / candidate def _value(row: dict[str, str], key: str) -> str | None: value = row.get(key) if value is None: return None stripped = value.strip() return stripped or None