310 lines
11 KiB
Python
310 lines
11 KiB
Python
from __future__ import annotations
|
|
|
|
import csv
|
|
import hashlib
|
|
from datetime import datetime, timezone
|
|
from pathlib import Path
|
|
from typing import Iterable
|
|
|
|
from sqlalchemy import func, or_, select
|
|
from sqlalchemy.orm import Session
|
|
|
|
from app.models import Source, SourceCatalogEntry
|
|
|
|
|
|
# ``kind`` values (compared after lower-casing) that import_ingestable_sources
# accepts; CSV rows with any other kind are counted as skipped.
DIRECT_INGEST_KINDS = {"osm_pbf", "osm_geojson", "gtfs"}
|
|
|
|
|
|
def default_source_catalog_path() -> Path:
    """Return the location of the bundled source-catalog CSV seed.

    The file is ``docs/source_catalog_seed.csv`` inside the grandparent
    directory of this module file (``parents[1]`` of its resolved path).
    """
    base_dir = Path(__file__).resolve().parents[1]
    return base_dir.joinpath("docs", "source_catalog_seed.csv")
|
|
|
|
|
|
def default_ingestable_sources_path() -> Path:
    """Return the location of the bundled ingestable-sources CSV seed.

    The file is ``docs/ingestable_sources_seed.csv`` inside the grandparent
    directory of this module file (``parents[1]`` of its resolved path).
    """
    base_dir = Path(__file__).resolve().parents[1]
    return base_dir.joinpath("docs", "ingestable_sources_seed.csv")
|
|
|
|
|
|
# SourceCatalogEntry column -> CSV header it is read from, in payload order.
_CATALOG_COLUMN_HEADERS: tuple[tuple[str, str], ...] = (
    ("geography", "Geography"),
    ("country_code", "Country code"),
    ("mode_scope", "Mode scope"),
    ("source_name", "Source name"),
    ("source_category", "Source category"),
    ("formats_apis", "Formats / APIs"),
    ("availability", "Availability"),
    ("coverage_notes", "Coverage notes"),
    ("geometry_notes", "Supersedes OSM for"),
    ("disruptions_closures", "Disruptions / closures"),
    ("operator_list_use", "Operator-list use"),
    ("access_license_notes", "Access / licence notes"),
    ("priority", "Priority"),
    ("source_url", "Source URL"),
    ("evidence_url", "Evidence URL"),
    ("next_pipeline_action", "Next pipeline action"),
)


def import_source_catalog(session: Session, path: Path | str | None = None, *, update_existing: bool = True) -> dict[str, int]:
    """Upsert ``SourceCatalogEntry`` rows from the source-catalog CSV seed.

    Entries are identified by ``_catalog_key(row)``.

    Args:
        session: SQLAlchemy session. New entries are added to it and it is
            flushed at the end; it is not committed here.
        path: CSV file to read. ``None`` selects
            ``default_source_catalog_path()``. Relative paths are resolved
            against the current working directory.
        update_existing: When True, an entry whose key already exists is
            overwritten with the CSV values and stamped with the current UTC
            time in ``updated_at``. When False it is left untouched and
            counted as skipped.

    Returns:
        A dict with ``created``, ``updated`` and ``skipped`` counts. Rows
        without a ``Source name`` are also counted as skipped.

    Raises:
        FileNotFoundError: if the CSV file does not exist.
    """
    counts = {"created": 0, "updated": 0, "skipped": 0}

    for row in _read_csv(_resolve_path(path, default_source_catalog_path())):
        # A catalog row is meaningless without a name; don't import it.
        if not _value(row, "Source name"):
            counts["skipped"] += 1
            continue

        payload = {"catalog_key": _catalog_key(row)}
        payload.update((column, _value(row, header)) for column, header in _CATALOG_COLUMN_HEADERS)

        match = session.scalar(
            select(SourceCatalogEntry).where(SourceCatalogEntry.catalog_key == payload["catalog_key"])
        )
        if match is None:
            session.add(SourceCatalogEntry(**payload))
            counts["created"] += 1
        elif update_existing:
            for column, value in payload.items():
                setattr(match, column, value)
            match.updated_at = datetime.now(timezone.utc)
            counts["updated"] += 1
        else:
            counts["skipped"] += 1

    session.flush()
    return counts
|
|
|
|
|
|
def import_ingestable_sources(
    session: Session,
    path: Path | str | None = None,
    *,
    update_existing: bool = True,
) -> dict[str, int]:
    """Upsert ``Source`` rows from the ingestable-sources CSV seed.

    Which rows are imported:
        Only rows with a ``name``, a ``url`` and a ``kind`` (lower-cased) in
        ``DIRECT_INGEST_KINDS`` are imported. Every other row counts as
        skipped. Each imported row is linked to the catalog entry returned
        by ``_catalog_entry_for_ingestable_row``, if any.

    Matching existing sources:
        A source is matched first on (kind, url), then on (name, url). The
        lowest ``id`` wins in both lookups.

    Args:
        session: SQLAlchemy session. It is flushed at the end and not
            committed here.
        path: CSV file to read. ``None`` selects
            ``default_ingestable_sources_path()``. Relative paths are
            resolved against the current working directory.
        update_existing: When True, a matched source is overwritten with the
            CSV values and re-enabled. This includes ``catalog_entry_id``,
            which becomes None when no catalog entry is found. When False,
            matched sources are left untouched and counted as skipped.

    Returns:
        A dict with ``created``, ``updated``, ``skipped`` and
        ``linked_catalog`` counts. ``linked_catalog`` counts created or
        updated sources that were linked to a catalog entry.

    Raises:
        FileNotFoundError: if the CSV file does not exist.
    """
    csv_rows = _read_csv(_resolve_path(path, default_ingestable_sources_path()))
    stats = {"created": 0, "updated": 0, "skipped": 0, "linked_catalog": 0}

    for row in csv_rows:
        name = _value(row, "name")
        kind = (_value(row, "kind") or "").lower()
        url = _value(row, "url")
        if not (name and url and kind in DIRECT_INGEST_KINDS):
            stats["skipped"] += 1
            continue

        catalog_entry = _catalog_entry_for_ingestable_row(session, row)
        payload = {
            "name": name,
            "kind": kind,
            "url": url,
            # These CSV headers share their names with the Source columns.
            **{
                column: _value(row, column)
                for column in ("country", "license", "priority", "mode_scope", "source_basis", "notes")
            },
            "catalog_entry_id": catalog_entry.id if catalog_entry is not None else None,
        }

        existing = _first_source(session, Source.kind == kind, Source.url == url)
        if existing is None:
            existing = _first_source(session, Source.name == name, Source.url == url)

        if existing is not None and not update_existing:
            stats["skipped"] += 1
            continue

        if existing is None:
            session.add(Source(**payload))
            stats["created"] += 1
        else:
            for column, value in payload.items():
                setattr(existing, column, value)
            existing.enabled = True
            stats["updated"] += 1

        if catalog_entry is not None:
            stats["linked_catalog"] += 1

    session.flush()
    return stats


def _first_source(session: Session, *criteria) -> Source | None:
    """Return the lowest-id ``Source`` matching all *criteria*, or None."""
    return session.scalar(select(Source).where(*criteria).order_by(Source.id).limit(1))
|
|
|
|
|
|
def source_catalog_summary(session: Session) -> dict[str, object]:
    """Return aggregate counts describing the source catalog.

    Returned keys:
        catalog_entries: total number of ``SourceCatalogEntry`` rows.
        catalog_by_priority: entry count per ``priority`` value.
        catalog_by_status: entry count per ``status`` value.
        seeded_ingestable_sources: number of ``Source`` rows with a non-NULL
            ``source_basis`` or ``priority``. These are presumably the rows
            written by import_ingestable_sources; verify.

    In both per-value dicts, NULL or empty values are reported under
    ``"unknown"``.
    """

    def counts_by(column) -> dict[str, int]:
        # NOTE(review): NULL and "" both collapse to "unknown"; if both groups
        # exist, the group returned later overwrites the earlier count.
        grouped = session.execute(select(column, func.count()).group_by(column)).all()
        return {(value or "unknown"): total for value, total in grouped}

    by_priority = counts_by(SourceCatalogEntry.priority)
    by_status = counts_by(SourceCatalogEntry.status)

    has_seed_metadata = or_(Source.source_basis.is_not(None), Source.priority.is_not(None))
    seeded_sources = session.scalar(select(func.count()).select_from(Source).where(has_seed_metadata)) or 0
    total_entries = session.scalar(select(func.count()).select_from(SourceCatalogEntry)) or 0

    return {
        "catalog_entries": total_entries,
        "catalog_by_priority": by_priority,
        "catalog_by_status": by_status,
        "seeded_ingestable_sources": seeded_sources,
    }
|
|
|
|
|
|
# Escape character for LIKE patterns built from caller-supplied text. "/" is
# used instead of a backslash to avoid dialect-specific backslash handling in
# SQL string literals.
_LIKE_ESCAPE_CHAR = "/"


def source_catalog_rows(
    session: Session,
    *,
    q: str | None = None,
    country: str | None = None,
    priority: str | None = None,
    status: str | None = None,
    limit: int = 100,
) -> list[SourceCatalogEntry]:
    """Search the source catalog with optional filters.

    Args:
        session: SQLAlchemy session used to run the query.
        q: Free-text search. It is matched case-insensitively as a substring
            of ``source_name``, ``source_category``, ``formats_apis``,
            ``coverage_notes`` or ``next_pipeline_action``; any one column
            matching suffices.
        country: Case-insensitive substring match on ``country_code``.
        priority: Exact match on ``priority``.
        status: Exact match on ``status``.
        limit: Maximum number of rows, clamped to the range 1..500.

    Every filter is stripped of surrounding whitespace and ignored when it is
    blank. ``%``, ``_`` and ``/`` inside ``q`` and ``country`` match
    literally instead of acting as LIKE wildcards.

    Returns:
        Matching entries ordered by priority, country code, source name and
        id.
    """
    search = _clean_filter(q)
    country_code = _clean_filter(country)
    priority_value = _clean_filter(priority)
    status_value = _clean_filter(status)

    stmt = select(SourceCatalogEntry).order_by(
        SourceCatalogEntry.priority,
        SourceCatalogEntry.country_code,
        SourceCatalogEntry.source_name,
        SourceCatalogEntry.id,
    )
    if search is not None:
        pattern = _like_contains(search)
        searchable_columns = (
            SourceCatalogEntry.source_name,
            SourceCatalogEntry.source_category,
            SourceCatalogEntry.formats_apis,
            SourceCatalogEntry.coverage_notes,
            SourceCatalogEntry.next_pipeline_action,
        )
        stmt = stmt.where(
            or_(*(column.ilike(pattern, escape=_LIKE_ESCAPE_CHAR) for column in searchable_columns))
        )
    if country_code is not None:
        stmt = stmt.where(
            SourceCatalogEntry.country_code.ilike(_like_contains(country_code), escape=_LIKE_ESCAPE_CHAR)
        )
    if priority_value is not None:
        stmt = stmt.where(SourceCatalogEntry.priority == priority_value)
    if status_value is not None:
        stmt = stmt.where(SourceCatalogEntry.status == status_value)

    # Clamp the page size to 1..500 regardless of what the caller asked for.
    return list(session.scalars(stmt.limit(max(1, min(limit, 500)))).all())


def _clean_filter(value: str | None) -> str | None:
    """Return *value* stripped of surrounding whitespace, or None if it is None or blank."""
    if value is None:
        return None
    stripped = value.strip()
    return stripped or None


def _like_contains(value: str) -> str:
    """Build a ``%value%`` LIKE pattern that matches *value* literally.

    Escaping:
        The escape character itself and the LIKE metacharacters ``%`` and
        ``_`` are prefixed with ``_LIKE_ESCAPE_CHAR``. The escape character
        is handled first so that the escapes added for ``%`` and ``_`` are
        not doubled.

    Usage:
        Callers must pass ``escape=_LIKE_ESCAPE_CHAR`` to ``ilike``.
    """
    escaped = (
        value.replace(_LIKE_ESCAPE_CHAR, _LIKE_ESCAPE_CHAR * 2)
        .replace("%", _LIKE_ESCAPE_CHAR + "%")
        .replace("_", _LIKE_ESCAPE_CHAR + "_")
    )
    return f"%{escaped}%"
|
|
|
|
|
|
def catalog_entry_payload(entry: SourceCatalogEntry, *, linked_source_count: int = 0) -> dict[str, object]:
    """Serialise a catalog entry into a plain dict.

    Fields:
        Catalog fields are copied as-is. ``linked_source_count`` is passed
        through unchanged (for example a value taken from
        ``linked_source_counts``).

    Timestamps:
        ``created_at`` and ``updated_at`` are rendered with ``isoformat()``,
        or None when unset.

    Key order:
        id, the catalog fields, status, linked_source_count, created_at,
        updated_at.
    """
    copied_fields = (
        "geography",
        "country_code",
        "mode_scope",
        "source_name",
        "source_category",
        "formats_apis",
        "availability",
        "coverage_notes",
        "geometry_notes",
        "disruptions_closures",
        "operator_list_use",
        "access_license_notes",
        "priority",
        "source_url",
        "evidence_url",
        "next_pipeline_action",
        "status",
    )

    def as_iso(moment):
        return moment.isoformat() if moment else None

    result: dict[str, object] = {"id": entry.id}
    for field in copied_fields:
        result[field] = getattr(entry, field)
    result["linked_source_count"] = linked_source_count
    result["created_at"] = as_iso(entry.created_at)
    result["updated_at"] = as_iso(entry.updated_at)
    return result
|
|
|
|
|
|
def linked_source_counts(session: Session, entries: Iterable[SourceCatalogEntry]) -> dict[int, int]:
    """Count ``Source`` rows linked to each of the given catalog entries.

    *entries* is iterated once. Entries with no linked source are absent from
    the result instead of mapping to 0. An empty input returns ``{}``
    without querying the database.
    """
    wanted_ids = [entry.id for entry in entries]
    if not wanted_ids:
        return {}

    stmt = (
        select(Source.catalog_entry_id, func.count())
        .where(Source.catalog_entry_id.in_(wanted_ids))
        .group_by(Source.catalog_entry_id)
    )
    per_entry: dict[int, int] = {}
    for catalog_entry_id, total in session.execute(stmt).all():
        if catalog_entry_id is not None:
            per_entry[catalog_entry_id] = total
    return per_entry
|
|
|
|
|
|
def _catalog_entry_for_ingestable_row(session: Session, row: dict[str, str]) -> SourceCatalogEntry | None:
    """Find the catalog entry an ingestable-sources CSV row most plausibly refers to.

    Lookup order:

    1. If the row has a ``name``, return the lowest-id entry whose
       ``source_name`` equals it case-insensitively.
    2. Otherwise return the entry with the smallest (``priority``, ``id``)
       that satisfies ANY of:
       - ``country_code`` contains the row's ``country``;
       - ``source_name`` or ``coverage_notes`` contains one of the tokens
         from ``_basis_tokens(source_basis)``;
       - ``source_name`` contains the first word of ``name``, used only when
         that word is longer than two characters.

    Substring matches use ``ilike`` (case-insensitive).

    Returns:
        The matching entry, or None in any of these cases:
        - the row has no ``country``, ``source_basis`` or ``name``;
        - no fuzzy criterion could be built;
        - nothing matches.
    """
    country = _value(row, "country")
    source_basis = _value(row, "source_basis")
    name = _value(row, "name")
    if not (country or source_basis or name):
        return None

    if name:
        exact_match = session.scalar(
            select(SourceCatalogEntry)
            .where(func.lower(SourceCatalogEntry.source_name) == name.lower())
            .order_by(SourceCatalogEntry.id)
            .limit(1)
        )
        if exact_match is not None:
            return exact_match

    # NOTE(review): the criteria are OR-ed, so a country match on its own is
    # enough -- a row may be linked to any entry of the same country. Confirm
    # this looseness is intended.
    criteria = []
    if country:
        criteria.append(SourceCatalogEntry.country_code.ilike(f"%{country}%"))
    basis_tokens = _basis_tokens(source_basis) if source_basis else []
    for token in basis_tokens:
        like_token = f"%{token}%"
        criteria.append(SourceCatalogEntry.source_name.ilike(like_token))
        criteria.append(SourceCatalogEntry.coverage_notes.ilike(like_token))
    leading_word = name.split()[0] if name else ""
    if len(leading_word) > 2:
        criteria.append(SourceCatalogEntry.source_name.ilike(f"%{leading_word}%"))

    if not criteria:
        return None
    return session.scalar(
        select(SourceCatalogEntry)
        .where(or_(*criteria))
        .order_by(SourceCatalogEntry.priority, SourceCatalogEntry.id)
        .limit(1)
    )
|
|
|
|
|
|
def _basis_tokens(value: str) -> list[str]:
    """Extract up to four distinctive search words from a ``source_basis`` string.

    Splitting:
        ``/`` and ``-`` act as separators in addition to whitespace. Leading
        and trailing ``,.;()`` and spaces are trimmed from each word.

    Filtering:
        Words shorter than five characters are dropped. So are the generic
        words official/mirror/feeds/transport, compared case-insensitively.

    Surviving words keep their original order and casing.
    """
    ignored_words = {"official", "mirror", "feeds", "transport"}
    words = value.replace("/", " ").replace("-", " ").split()
    trimmed = (word.strip(" ,.;()") for word in words)
    kept = [word for word in trimmed if len(word) >= 5 and word.lower() not in ignored_words]
    return kept[:4]
|
|
|
|
|
|
def _catalog_key(row: dict[str, str]) -> str:
    """Derive the SHA-256 hex identity used to upsert a source-catalog CSV row.

    Key material:
        The present (non-blank) values of Country code, Source name,
        Source URL and Formats / APIs, in that order, are lower-cased and
        joined with ``|``.

    Fallback:
        If none of those columns is present, the ``repr`` of the row's
        sorted items is hashed instead.

    The key is stored on every catalog entry, so this derivation must stay
    stable across releases.
    """
    identity_headers = ("Country code", "Source name", "Source URL", "Formats / APIs")
    present_values = (_value(row, header) for header in identity_headers)
    material = "|".join(value.lower() for value in present_values if value)
    material = material or repr(sorted(row.items()))
    return hashlib.sha256(material.encode("utf-8")).hexdigest()
|
|
|
|
|
|
def _read_csv(path: Path) -> list[dict[str, str]]:
    """Load every data row of a CSV file as a dict keyed by the header row.

    The file is decoded with ``utf-8-sig``, so a leading UTF-8 BOM is
    tolerated.

    Raises:
        FileNotFoundError: if *path* does not exist. The exception carries
            the path as its only argument.
    """
    if path.exists():
        with path.open("r", encoding="utf-8-sig", newline="") as handle:
            return list(map(dict, csv.DictReader(handle)))
    raise FileNotFoundError(path)
|
|
|
|
|
|
def _resolve_path(path: Path | str | None, default_path: Path) -> Path:
|
|
if path is None:
|
|
return default_path
|
|
candidate = Path(path)
|
|
if candidate.is_absolute():
|
|
return candidate
|
|
return Path.cwd() / candidate
|
|
|
|
|
|
def _value(row: dict[str, str], key: str) -> str | None:
|
|
value = row.get(key)
|
|
if value is None:
|
|
return None
|
|
stripped = value.strip()
|
|
return stripped or None
|