Alpha stage commit
This commit is contained in:
309
app/source_catalog.py
Normal file
309
app/source_catalog.py
Normal file
@@ -0,0 +1,309 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import csv
|
||||
import hashlib
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Iterable
|
||||
|
||||
from sqlalchemy import func, or_, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models import Source, SourceCatalogEntry
|
||||
|
||||
|
||||
DIRECT_INGEST_KINDS = {"gtfs", "osm_geojson", "osm_pbf"}
|
||||
|
||||
|
||||
def default_source_catalog_path() -> Path:
|
||||
return Path(__file__).resolve().parents[1] / "docs" / "source_catalog_seed.csv"
|
||||
|
||||
|
||||
def default_ingestable_sources_path() -> Path:
|
||||
return Path(__file__).resolve().parents[1] / "docs" / "ingestable_sources_seed.csv"
|
||||
|
||||
|
||||
def import_source_catalog(session: Session, path: Path | str | None = None, *, update_existing: bool = True) -> dict[str, int]:
|
||||
csv_path = _resolve_path(path, default_source_catalog_path())
|
||||
rows = _read_csv(csv_path)
|
||||
created = 0
|
||||
updated = 0
|
||||
skipped = 0
|
||||
for row in rows:
|
||||
source_name = _value(row, "Source name")
|
||||
if not source_name:
|
||||
skipped += 1
|
||||
continue
|
||||
payload = {
|
||||
"catalog_key": _catalog_key(row),
|
||||
"geography": _value(row, "Geography"),
|
||||
"country_code": _value(row, "Country code"),
|
||||
"mode_scope": _value(row, "Mode scope"),
|
||||
"source_name": source_name,
|
||||
"source_category": _value(row, "Source category"),
|
||||
"formats_apis": _value(row, "Formats / APIs"),
|
||||
"availability": _value(row, "Availability"),
|
||||
"coverage_notes": _value(row, "Coverage notes"),
|
||||
"geometry_notes": _value(row, "Supersedes OSM for"),
|
||||
"disruptions_closures": _value(row, "Disruptions / closures"),
|
||||
"operator_list_use": _value(row, "Operator-list use"),
|
||||
"access_license_notes": _value(row, "Access / licence notes"),
|
||||
"priority": _value(row, "Priority"),
|
||||
"source_url": _value(row, "Source URL"),
|
||||
"evidence_url": _value(row, "Evidence URL"),
|
||||
"next_pipeline_action": _value(row, "Next pipeline action"),
|
||||
}
|
||||
existing = session.scalar(select(SourceCatalogEntry).where(SourceCatalogEntry.catalog_key == payload["catalog_key"]))
|
||||
if existing is None:
|
||||
session.add(SourceCatalogEntry(**payload))
|
||||
created += 1
|
||||
continue
|
||||
if not update_existing:
|
||||
skipped += 1
|
||||
continue
|
||||
for key, value in payload.items():
|
||||
setattr(existing, key, value)
|
||||
existing.updated_at = datetime.now(timezone.utc)
|
||||
updated += 1
|
||||
session.flush()
|
||||
return {"created": created, "updated": updated, "skipped": skipped}
|
||||
|
||||
|
||||
def import_ingestable_sources(
|
||||
session: Session,
|
||||
path: Path | str | None = None,
|
||||
*,
|
||||
update_existing: bool = True,
|
||||
) -> dict[str, int]:
|
||||
csv_path = _resolve_path(path, default_ingestable_sources_path())
|
||||
rows = _read_csv(csv_path)
|
||||
created = 0
|
||||
updated = 0
|
||||
skipped = 0
|
||||
linked_catalog = 0
|
||||
for row in rows:
|
||||
name = _value(row, "name")
|
||||
kind = (_value(row, "kind") or "").lower()
|
||||
url = _value(row, "url")
|
||||
if not name or not url or kind not in DIRECT_INGEST_KINDS:
|
||||
skipped += 1
|
||||
continue
|
||||
catalog_entry = _catalog_entry_for_ingestable_row(session, row)
|
||||
payload = {
|
||||
"name": name,
|
||||
"kind": kind,
|
||||
"url": url,
|
||||
"country": _value(row, "country"),
|
||||
"license": _value(row, "license"),
|
||||
"priority": _value(row, "priority"),
|
||||
"mode_scope": _value(row, "mode_scope"),
|
||||
"source_basis": _value(row, "source_basis"),
|
||||
"notes": _value(row, "notes"),
|
||||
"catalog_entry_id": None if catalog_entry is None else catalog_entry.id,
|
||||
}
|
||||
existing = session.scalar(
|
||||
select(Source)
|
||||
.where(Source.kind == kind, Source.url == url)
|
||||
.order_by(Source.id)
|
||||
.limit(1)
|
||||
)
|
||||
if existing is None:
|
||||
existing = session.scalar(select(Source).where(Source.name == name, Source.url == url).order_by(Source.id).limit(1))
|
||||
if existing is None:
|
||||
session.add(Source(**payload))
|
||||
created += 1
|
||||
if catalog_entry is not None:
|
||||
linked_catalog += 1
|
||||
continue
|
||||
if not update_existing:
|
||||
skipped += 1
|
||||
continue
|
||||
for key, value in payload.items():
|
||||
setattr(existing, key, value)
|
||||
existing.enabled = True
|
||||
updated += 1
|
||||
if catalog_entry is not None:
|
||||
linked_catalog += 1
|
||||
session.flush()
|
||||
return {"created": created, "updated": updated, "skipped": skipped, "linked_catalog": linked_catalog}
|
||||
|
||||
|
||||
def source_catalog_summary(session: Session) -> dict[str, object]:
|
||||
priority_counts = {
|
||||
priority or "unknown": count
|
||||
for priority, count in session.execute(
|
||||
select(SourceCatalogEntry.priority, func.count()).group_by(SourceCatalogEntry.priority)
|
||||
).all()
|
||||
}
|
||||
status_counts = {
|
||||
status or "unknown": count
|
||||
for status, count in session.execute(select(SourceCatalogEntry.status, func.count()).group_by(SourceCatalogEntry.status)).all()
|
||||
}
|
||||
ingestable_sources = session.scalar(
|
||||
select(func.count()).select_from(Source).where(Source.source_basis.is_not(None) | Source.priority.is_not(None))
|
||||
) or 0
|
||||
return {
|
||||
"catalog_entries": session.scalar(select(func.count()).select_from(SourceCatalogEntry)) or 0,
|
||||
"catalog_by_priority": priority_counts,
|
||||
"catalog_by_status": status_counts,
|
||||
"seeded_ingestable_sources": ingestable_sources,
|
||||
}
|
||||
|
||||
|
||||
def source_catalog_rows(
|
||||
session: Session,
|
||||
*,
|
||||
q: str | None = None,
|
||||
country: str | None = None,
|
||||
priority: str | None = None,
|
||||
status: str | None = None,
|
||||
limit: int = 100,
|
||||
) -> list[SourceCatalogEntry]:
|
||||
stmt = select(SourceCatalogEntry).order_by(
|
||||
SourceCatalogEntry.priority,
|
||||
SourceCatalogEntry.country_code,
|
||||
SourceCatalogEntry.source_name,
|
||||
SourceCatalogEntry.id,
|
||||
)
|
||||
if q:
|
||||
pattern = f"%{q.strip()}%"
|
||||
stmt = stmt.where(
|
||||
or_(
|
||||
SourceCatalogEntry.source_name.ilike(pattern),
|
||||
SourceCatalogEntry.source_category.ilike(pattern),
|
||||
SourceCatalogEntry.formats_apis.ilike(pattern),
|
||||
SourceCatalogEntry.coverage_notes.ilike(pattern),
|
||||
SourceCatalogEntry.next_pipeline_action.ilike(pattern),
|
||||
)
|
||||
)
|
||||
if country:
|
||||
stmt = stmt.where(SourceCatalogEntry.country_code.ilike(f"%{country.strip()}%"))
|
||||
if priority:
|
||||
stmt = stmt.where(SourceCatalogEntry.priority == priority.strip())
|
||||
if status:
|
||||
stmt = stmt.where(SourceCatalogEntry.status == status.strip())
|
||||
return session.scalars(stmt.limit(max(1, min(limit, 500)))).all()
|
||||
|
||||
|
||||
def catalog_entry_payload(entry: SourceCatalogEntry, *, linked_source_count: int = 0) -> dict[str, object]:
|
||||
return {
|
||||
"id": entry.id,
|
||||
"geography": entry.geography,
|
||||
"country_code": entry.country_code,
|
||||
"mode_scope": entry.mode_scope,
|
||||
"source_name": entry.source_name,
|
||||
"source_category": entry.source_category,
|
||||
"formats_apis": entry.formats_apis,
|
||||
"availability": entry.availability,
|
||||
"coverage_notes": entry.coverage_notes,
|
||||
"geometry_notes": entry.geometry_notes,
|
||||
"disruptions_closures": entry.disruptions_closures,
|
||||
"operator_list_use": entry.operator_list_use,
|
||||
"access_license_notes": entry.access_license_notes,
|
||||
"priority": entry.priority,
|
||||
"source_url": entry.source_url,
|
||||
"evidence_url": entry.evidence_url,
|
||||
"next_pipeline_action": entry.next_pipeline_action,
|
||||
"status": entry.status,
|
||||
"linked_source_count": linked_source_count,
|
||||
"created_at": entry.created_at.isoformat() if entry.created_at else None,
|
||||
"updated_at": entry.updated_at.isoformat() if entry.updated_at else None,
|
||||
}
|
||||
|
||||
|
||||
def linked_source_counts(session: Session, entries: Iterable[SourceCatalogEntry]) -> dict[int, int]:
|
||||
entry_ids = [entry.id for entry in entries]
|
||||
if not entry_ids:
|
||||
return {}
|
||||
return {
|
||||
entry_id: count
|
||||
for entry_id, count in session.execute(
|
||||
select(Source.catalog_entry_id, func.count())
|
||||
.where(Source.catalog_entry_id.in_(entry_ids))
|
||||
.group_by(Source.catalog_entry_id)
|
||||
).all()
|
||||
if entry_id is not None
|
||||
}
|
||||
|
||||
|
||||
def _catalog_entry_for_ingestable_row(session: Session, row: dict[str, str]) -> SourceCatalogEntry | None:
|
||||
country = _value(row, "country")
|
||||
source_basis = _value(row, "source_basis")
|
||||
name = _value(row, "name")
|
||||
if not country and not source_basis and not name:
|
||||
return None
|
||||
if name:
|
||||
exact = session.scalar(
|
||||
select(SourceCatalogEntry)
|
||||
.where(func.lower(SourceCatalogEntry.source_name) == name.lower())
|
||||
.order_by(SourceCatalogEntry.id)
|
||||
.limit(1)
|
||||
)
|
||||
if exact is not None:
|
||||
return exact
|
||||
clauses = []
|
||||
if country:
|
||||
clauses.append(SourceCatalogEntry.country_code.ilike(f"%{country}%"))
|
||||
if source_basis:
|
||||
for token in _basis_tokens(source_basis):
|
||||
clauses.append(SourceCatalogEntry.source_name.ilike(f"%{token}%"))
|
||||
clauses.append(SourceCatalogEntry.coverage_notes.ilike(f"%{token}%"))
|
||||
if name:
|
||||
first_word = name.split()[0]
|
||||
if len(first_word) > 2:
|
||||
clauses.append(SourceCatalogEntry.source_name.ilike(f"%{first_word}%"))
|
||||
if not clauses:
|
||||
return None
|
||||
return session.scalar(
|
||||
select(SourceCatalogEntry)
|
||||
.where(or_(*clauses))
|
||||
.order_by(SourceCatalogEntry.priority, SourceCatalogEntry.id)
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
|
||||
def _basis_tokens(value: str) -> list[str]:
|
||||
tokens = []
|
||||
for raw in value.replace("/", " ").replace("-", " ").split():
|
||||
token = raw.strip(" ,.;()")
|
||||
if len(token) >= 5 and token.lower() not in {"official", "mirror", "feeds", "transport"}:
|
||||
tokens.append(token)
|
||||
return tokens[:4]
|
||||
|
||||
|
||||
def _catalog_key(row: dict[str, str]) -> str:
|
||||
parts = [
|
||||
_value(row, "Country code"),
|
||||
_value(row, "Source name"),
|
||||
_value(row, "Source URL"),
|
||||
_value(row, "Formats / APIs"),
|
||||
]
|
||||
text = "|".join(part.lower() for part in parts if part)
|
||||
if not text:
|
||||
text = repr(sorted(row.items()))
|
||||
return hashlib.sha256(text.encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
def _read_csv(path: Path) -> list[dict[str, str]]:
|
||||
if not path.exists():
|
||||
raise FileNotFoundError(path)
|
||||
with path.open("r", encoding="utf-8-sig", newline="") as handle:
|
||||
reader = csv.DictReader(handle)
|
||||
return [dict(row) for row in reader]
|
||||
|
||||
|
||||
def _resolve_path(path: Path | str | None, default_path: Path) -> Path:
|
||||
if path is None:
|
||||
return default_path
|
||||
candidate = Path(path)
|
||||
if candidate.is_absolute():
|
||||
return candidate
|
||||
return Path.cwd() / candidate
|
||||
|
||||
|
||||
def _value(row: dict[str, str], key: str) -> str | None:
|
||||
value = row.get(key)
|
||||
if value is None:
|
||||
return None
|
||||
stripped = value.strip()
|
||||
return stripped or None
|
||||
Reference in New Issue
Block a user