Files
meubility-workbench/app/source_catalog.py
2026-07-01 23:29:51 +02:00

310 lines
11 KiB
Python

from __future__ import annotations
import csv
import hashlib
from datetime import datetime, timezone
from pathlib import Path
from typing import Iterable
from sqlalchemy import func, or_, select
from sqlalchemy.orm import Session
from app.models import Source, SourceCatalogEntry
# Lower-case ``kind`` values accepted by import_ingestable_sources; seed rows whose
# (lower-cased) kind is not one of these are counted as skipped.
DIRECT_INGEST_KINDS = set("gtfs osm_geojson osm_pbf".split())
def default_source_catalog_path() -> Path:
    """Return the bundled catalog seed CSV path.

    The file is ``docs/source_catalog_seed.csv`` inside the parent of the
    directory containing this module (resolved to an absolute path).
    """
    project_root = Path(__file__).resolve().parents[1]
    return project_root.joinpath("docs", "source_catalog_seed.csv")
def default_ingestable_sources_path() -> Path:
    """Return the bundled ingestable-sources seed CSV path.

    The file is ``docs/ingestable_sources_seed.csv`` inside the parent of the
    directory containing this module (resolved to an absolute path).
    """
    project_root = Path(__file__).resolve().parents[1]
    return project_root.joinpath("docs", "ingestable_sources_seed.csv")
# (SourceCatalogEntry attribute, seed-CSV header) pairs copied verbatim from each row,
# in payload order. Note that the "Supersedes OSM for" column is stored in geometry_notes.
_CATALOG_COLUMNS = (
    ("geography", "Geography"),
    ("country_code", "Country code"),
    ("mode_scope", "Mode scope"),
    ("source_name", "Source name"),
    ("source_category", "Source category"),
    ("formats_apis", "Formats / APIs"),
    ("availability", "Availability"),
    ("coverage_notes", "Coverage notes"),
    ("geometry_notes", "Supersedes OSM for"),
    ("disruptions_closures", "Disruptions / closures"),
    ("operator_list_use", "Operator-list use"),
    ("access_license_notes", "Access / licence notes"),
    ("priority", "Priority"),
    ("source_url", "Source URL"),
    ("evidence_url", "Evidence URL"),
    ("next_pipeline_action", "Next pipeline action"),
)


def _catalog_payload(row: dict[str, str]) -> dict[str, str | None]:
    """Map one seed-CSV row onto SourceCatalogEntry column values (plus its catalog_key)."""
    payload: dict[str, str | None] = {"catalog_key": _catalog_key(row)}
    for attribute, header in _CATALOG_COLUMNS:
        payload[attribute] = _value(row, header)
    return payload


def import_source_catalog(session: Session, path: Path | str | None = None, *, update_existing: bool = True) -> dict[str, int]:
    """Upsert SourceCatalogEntry rows from the source-catalog seed CSV.

    Rows are matched to existing entries by ``catalog_key`` (see ``_catalog_key``).
    Rows without a "Source name" are skipped. Existing entries are overwritten
    (and their ``updated_at`` bumped) only when ``update_existing`` is true;
    otherwise they count as skipped. The session is flushed but not committed.

    Args:
        session: SQLAlchemy session to write through.
        path: CSV location; relative paths resolve against the CWD, and None
            selects ``default_source_catalog_path()``.
        update_existing: whether to overwrite entries that already exist.

    Returns:
        Counts under the keys "created", "updated" and "skipped".

    Raises:
        FileNotFoundError: if the CSV file does not exist.
    """
    tally = {"created": 0, "updated": 0, "skipped": 0}
    for row in _read_csv(_resolve_path(path, default_source_catalog_path())):
        if not _value(row, "Source name"):
            tally["skipped"] += 1
            continue
        payload = _catalog_payload(row)
        existing = session.scalar(
            select(SourceCatalogEntry).where(SourceCatalogEntry.catalog_key == payload["catalog_key"])
        )
        if existing is None:
            session.add(SourceCatalogEntry(**payload))
            tally["created"] += 1
        elif update_existing:
            for attribute, value in payload.items():
                setattr(existing, attribute, value)
            existing.updated_at = datetime.now(timezone.utc)
            tally["updated"] += 1
        else:
            tally["skipped"] += 1
    session.flush()
    return tally
def _find_existing_source(session: Session, *, kind: str, name: str, url: str) -> Source | None:
    """Return the lowest-id Source sharing (kind, url), else the lowest-id one sharing (name, url)."""
    for criteria in (
        (Source.kind == kind, Source.url == url),
        (Source.name == name, Source.url == url),
    ):
        match = session.scalar(select(Source).where(*criteria).order_by(Source.id).limit(1))
        if match is not None:
            return match
    return None


def import_ingestable_sources(
    session: Session,
    path: Path | str | None = None,
    *,
    update_existing: bool = True,
) -> dict[str, int]:
    """Upsert directly-ingestable Source rows from the ingestable-sources seed CSV.

    A row is skipped when it lacks a name or url, or when its lower-cased kind is
    not in DIRECT_INGEST_KINDS. Each accepted row is best-effort linked to a
    catalog entry via ``_catalog_entry_for_ingestable_row``. Existing sources are
    found by (kind, url), then (name, url). When ``update_existing`` is true they
    are overwritten and re-enabled; otherwise they count as skipped. The session
    is flushed but not committed.

    Args:
        session: SQLAlchemy session to write through.
        path: CSV location; relative paths resolve against the CWD, and None
            selects ``default_ingestable_sources_path()``.
        update_existing: whether to overwrite sources that already exist.

    Returns:
        Counts under "created", "updated", "skipped" and "linked_catalog"
        ("linked_catalog" counts created or updated rows that matched a catalog entry).

    Raises:
        FileNotFoundError: if the CSV file does not exist.
    """
    counts = {"created": 0, "updated": 0, "skipped": 0, "linked_catalog": 0}
    for row in _read_csv(_resolve_path(path, default_ingestable_sources_path())):
        name = _value(row, "name")
        url = _value(row, "url")
        kind = (_value(row, "kind") or "").lower()
        if not (name and url and kind in DIRECT_INGEST_KINDS):
            counts["skipped"] += 1
            continue

        catalog_entry = _catalog_entry_for_ingestable_row(session, row)
        payload = {
            "name": name,
            "kind": kind,
            "url": url,
            "country": _value(row, "country"),
            "license": _value(row, "license"),
            "priority": _value(row, "priority"),
            "mode_scope": _value(row, "mode_scope"),
            "source_basis": _value(row, "source_basis"),
            "notes": _value(row, "notes"),
            "catalog_entry_id": catalog_entry.id if catalog_entry is not None else None,
        }

        existing = _find_existing_source(session, kind=kind, name=name, url=url)
        if existing is None:
            session.add(Source(**payload))
            counts["created"] += 1
        elif update_existing:
            for attribute, value in payload.items():
                setattr(existing, attribute, value)
            # Re-seeding always re-enables a matched source.
            existing.enabled = True
            counts["updated"] += 1
        else:
            counts["skipped"] += 1
            continue

        if catalog_entry is not None:
            counts["linked_catalog"] += 1
    session.flush()
    return counts
def _grouped_counts(session: Session, column) -> dict[str, int]:
    """Count rows per distinct value of ``column``, labelling falsy values "unknown".

    Counts are summed per label rather than assigned, so several raw values that
    collapse to one label (e.g. NULL and a literal "unknown") are not overwritten.
    """
    counts: dict[str, int] = {}
    for raw, count in session.execute(select(column, func.count()).group_by(column)).all():
        label = raw or "unknown"
        counts[label] = counts.get(label, 0) + count
    return counts


def source_catalog_summary(session: Session) -> dict[str, object]:
    """Return aggregate counts describing the source catalog.

    Returns:
        A dict with:
        - "catalog_entries": total SourceCatalogEntry rows;
        - "catalog_by_priority": entry count per priority (NULL/empty -> "unknown");
        - "catalog_by_status": entry count per status (NULL/empty -> "unknown");
        - "seeded_ingestable_sources": Source rows whose source_basis or priority
          is set. Those two columns are filled from the ingestable-sources seed
          CSV, which is used here as the "seeded" marker.
    """
    priority_counts = _grouped_counts(session, SourceCatalogEntry.priority)
    status_counts = _grouped_counts(session, SourceCatalogEntry.status)
    ingestable_sources = session.scalar(
        select(func.count())
        .select_from(Source)
        .where(or_(Source.source_basis.is_not(None), Source.priority.is_not(None)))
    ) or 0
    return {
        "catalog_entries": session.scalar(select(func.count()).select_from(SourceCatalogEntry)) or 0,
        "catalog_by_priority": priority_counts,
        "catalog_by_status": status_counts,
        "seeded_ingestable_sources": ingestable_sources,
    }
def source_catalog_rows(
    session: Session,
    *,
    q: str | None = None,
    country: str | None = None,
    priority: str | None = None,
    status: str | None = None,
    limit: int = 100,
) -> list[SourceCatalogEntry]:
    """Return catalog entries matching optional filters, ordered for display.

    Args:
        session: SQLAlchemy session to query.
        q: case-insensitive substring searched in source_name, source_category,
            formats_apis, coverage_notes and next_pipeline_action (any may match).
        country: case-insensitive substring of country_code.
        priority: exact priority value.
        status: exact status value.
        limit: maximum rows to return, clamped to the range 1..500.

    Filters that are None, empty or whitespace-only are ignored. Results are
    ordered by priority, country_code, source_name and then id.
    """
    # Normalise once up front. Without this, a whitespace-only value is truthy and
    # would filter on the empty string (e.g. priority == ""), returning no rows
    # instead of meaning "no filter".
    q = (q or "").strip()
    country = (country or "").strip()
    priority = (priority or "").strip()
    status = (status or "").strip()

    stmt = select(SourceCatalogEntry).order_by(
        SourceCatalogEntry.priority,
        SourceCatalogEntry.country_code,
        SourceCatalogEntry.source_name,
        SourceCatalogEntry.id,
    )
    if q:
        pattern = f"%{q}%"
        stmt = stmt.where(
            or_(
                SourceCatalogEntry.source_name.ilike(pattern),
                SourceCatalogEntry.source_category.ilike(pattern),
                SourceCatalogEntry.formats_apis.ilike(pattern),
                SourceCatalogEntry.coverage_notes.ilike(pattern),
                SourceCatalogEntry.next_pipeline_action.ilike(pattern),
            )
        )
    if country:
        stmt = stmt.where(SourceCatalogEntry.country_code.ilike(f"%{country}%"))
    if priority:
        stmt = stmt.where(SourceCatalogEntry.priority == priority)
    if status:
        stmt = stmt.where(SourceCatalogEntry.status == status)
    # Clamp the page size so callers cannot request zero/negative or unbounded results.
    return session.scalars(stmt.limit(max(1, min(limit, 500)))).all()
# Attributes copied as-is from a SourceCatalogEntry into its JSON payload, in output order.
_ENTRY_PAYLOAD_FIELDS = (
    "id",
    "geography",
    "country_code",
    "mode_scope",
    "source_name",
    "source_category",
    "formats_apis",
    "availability",
    "coverage_notes",
    "geometry_notes",
    "disruptions_closures",
    "operator_list_use",
    "access_license_notes",
    "priority",
    "source_url",
    "evidence_url",
    "next_pipeline_action",
    "status",
)


def catalog_entry_payload(entry: SourceCatalogEntry, *, linked_source_count: int = 0) -> dict[str, object]:
    """Serialise a catalog entry into a plain dict.

    Args:
        entry: the catalog entry to serialise.
        linked_source_count: value reported under "linked_source_count".

    Returns:
        The entry's column values, then "linked_source_count", then "created_at"
        and "updated_at" as ISO-8601 strings (None when the timestamp is unset).
    """
    payload: dict[str, object] = {field: getattr(entry, field) for field in _ENTRY_PAYLOAD_FIELDS}
    payload["linked_source_count"] = linked_source_count
    for stamp_field in ("created_at", "updated_at"):
        stamp = getattr(entry, stamp_field)
        payload[stamp_field] = stamp.isoformat() if stamp else None
    return payload
def linked_source_counts(session: Session, entries: Iterable[SourceCatalogEntry]) -> dict[int, int]:
    """Count Source rows linked to each of the given catalog entries.

    Args:
        session: SQLAlchemy session to query.
        entries: catalog entries (iterated once) whose ids are looked up.

    Returns:
        ``{catalog_entry_id: number_of_linked_sources}``. Entries with no linked
        sources are absent. An empty ``entries`` returns ``{}`` without querying.
    """
    ids = [entry.id for entry in entries]
    if not ids:
        return {}
    grouped = session.execute(
        select(Source.catalog_entry_id, func.count())
        .where(Source.catalog_entry_id.in_(ids))
        .group_by(Source.catalog_entry_id)
    ).all()
    counts: dict[int, int] = {}
    for catalog_entry_id, total in grouped:
        if catalog_entry_id is not None:
            counts[catalog_entry_id] = total
    return counts
def _catalog_entry_for_ingestable_row(session: Session, row: dict[str, str]) -> SourceCatalogEntry | None:
    """Best-effort match of an ingestable-source CSV row to a catalog entry.

    Matching order:
    1. A case-insensitive exact match of the row's ``name`` against
       ``SourceCatalogEntry.source_name`` wins outright (lowest id first).
    2. Otherwise, text clauses are built from the ``source_basis`` tokens
       (each checked against source_name and coverage_notes) and from the first
       word of ``name`` when it is longer than two characters; any one may match.
       If the row has a ``country``, an entry whose country_code contains it is
       preferred. Failing that, the best text match from any country is used.

    A country alone never produces a link. Candidates are ordered by priority,
    then id.

    Returns:
        The matched entry, or None when nothing matches or no text clue exists.
    """
    country = _value(row, "country")
    source_basis = _value(row, "source_basis")
    name = _value(row, "name")

    if name:
        exact = session.scalar(
            select(SourceCatalogEntry)
            .where(func.lower(SourceCatalogEntry.source_name) == name.lower())
            .order_by(SourceCatalogEntry.id)
            .limit(1)
        )
        if exact is not None:
            return exact

    text_clauses = []
    if source_basis:
        for token in _basis_tokens(source_basis):
            text_clauses.append(SourceCatalogEntry.source_name.ilike(f"%{token}%"))
            text_clauses.append(SourceCatalogEntry.coverage_notes.ilike(f"%{token}%"))
    if name:
        first_word = name.split()[0]
        if len(first_word) > 2:
            text_clauses.append(SourceCatalogEntry.source_name.ilike(f"%{first_word}%"))
    if not text_clauses:
        # A country by itself says nothing about *which* source this is; linking on it would be arbitrary.
        return None

    fuzzy = (
        select(SourceCatalogEntry)
        .where(or_(*text_clauses))
        .order_by(SourceCatalogEntry.priority, SourceCatalogEntry.id)
        .limit(1)
    )
    if country:
        # Country narrows the text match instead of being an alternative to it.
        # OR-ing it with the text clauses would make every entry from that country
        # a match, so the basis/name tokens would be ignored.
        in_country = session.scalar(fuzzy.where(SourceCatalogEntry.country_code.ilike(f"%{country}%")))
        if in_country is not None:
            return in_country
    # No country given, or no in-country entry matched the text (possibly because the
    # two seed files encode countries differently — unverified): fall back to text only.
    return session.scalar(fuzzy)
def _basis_tokens(value: str) -> list[str]:
    """Extract up to four distinctive search tokens from a free-text source basis.

    "/" and "-" act as word separators, and surrounding " ,.;()" characters are
    trimmed from each word. Words shorter than five characters, and the generic
    words official/mirror/feeds/transport (case-insensitive), are dropped.
    Original casing and order are kept.
    """
    ignored = {"official", "mirror", "feeds", "transport"}
    words = value.translate(str.maketrans("/-", "  ")).split()
    trimmed = (word.strip(" ,.;()") for word in words)
    return [token for token in trimmed if len(token) >= 5 and token.lower() not in ignored][:4]
def _catalog_key(row: dict[str, str]) -> str:
    """Return a SHA-256 hex digest identifying a catalog seed-CSV row.

    The hashed text is the lower-cased, non-blank values of "Country code",
    "Source name", "Source URL" and "Formats / APIs" joined with "|". If all four
    are blank, the repr of the row's sorted items is hashed instead.
    """
    # Existing catalog rows are matched by this digest, so the exact text format
    # must not change or a re-import would create duplicates.
    identity_columns = ("Country code", "Source name", "Source URL", "Formats / APIs")
    fields = (_value(row, column) for column in identity_columns)
    text = "|".join(field.lower() for field in fields if field) or repr(sorted(row.items()))
    return hashlib.sha256(text.encode("utf-8")).hexdigest()
def _read_csv(path: Path) -> list[dict[str, str]]:
    """Load a UTF-8 CSV (a leading BOM is tolerated) as a list of header-keyed dicts.

    Raises:
        FileNotFoundError: with ``path`` as its argument, when the file does not exist.
    """
    if not path.exists():
        raise FileNotFoundError(path)
    with path.open(mode="r", encoding="utf-8-sig", newline="") as stream:
        return list(map(dict, csv.DictReader(stream)))
def _resolve_path(path: Path | str | None, default_path: Path) -> Path:
    """Turn an optional user-supplied path into a concrete Path.

    None yields ``default_path`` unchanged. Absolute paths are returned as-is,
    and relative paths are anchored at the current working directory.
    """
    if path is None:
        return default_path
    candidate = Path(path)
    return candidate if candidate.is_absolute() else Path.cwd() / candidate
def _value(row: dict[str, str], key: str) -> str | None:
    """Return ``row[key]`` with surrounding whitespace removed.

    Missing keys, None values and blank/whitespace-only strings all yield None.
    """
    text = row.get(key)
    return None if text is None else (text.strip() or None)