"""Update checks for dataset sources.

Probes a source's HTTP(S) URL or local file, compares the result with the
active dataset, and records the outcome as a ``SourceUpdateCheck`` row.
"""

from __future__ import annotations

import json
from datetime import datetime, timezone
from pathlib import Path
from urllib.parse import urlparse

import requests
from sqlalchemy import select
from sqlalchemy.orm import Session

from app.config import settings
from app.models import Dataset, Source, SourceUpdateCheck
from app.pipeline.utils import norm_text, sha256_file

_HTTP_SCHEMES = frozenset({"http", "https"})
_HTTP_TIMEOUT_SECONDS = 30
_ZIP_CONTENT_TYPES = frozenset({"application/zip", "application/x-zip-compressed"})


def check_source_for_update(session: Session, source: Source) -> SourceUpdateCheck:
    """Probe *source*, persist a ``SourceUpdateCheck`` and update the source status.

    Side effects: may rewrite ``source.url`` (when a missing managed cache file
    is replaced by a seed URL), sets ``source.status`` and ``source.last_error``,
    adds the new check to *session* and flushes it.

    Returns:
        The newly added ``SourceUpdateCheck``.
    """
    active_dataset = session.scalar(
        select(Dataset)
        .where(Dataset.source_id == source.id, Dataset.is_active.is_(True))
        .order_by(Dataset.created_at.desc(), Dataset.id.desc())
        .limit(1)
    )

    # Recovery must run before the probe: it may swap source.url for a seed URL,
    # and the probe (and remote_url below) must see the replacement.
    recovery = _recover_missing_managed_cache_url(source)
    remote = _source_remote_metadata(source)
    if recovery is not None:
        remote["recovered_source_url"] = recovery["url"]
        remote["previous_source_url"] = recovery["previous_url"]

    update_available, reason = _update_decision(active_dataset, remote)
    checked = remote["status"] == "checked"

    check = SourceUpdateCheck(
        source_id=source.id,
        status=remote["status"],
        update_available=update_available,
        reason=reason,
        remote_url=source.url,
        etag=remote.get("etag"),
        last_modified=remote.get("last_modified"),
        content_length=remote.get("content_length"),
        content_type=remote.get("content_type"),
        local_mtime=remote.get("local_mtime"),
        local_size=remote.get("local_size"),
        local_sha256=remote.get("local_sha256"),
        active_dataset_id=None if active_dataset is None else active_dataset.id,
        active_dataset_sha256=None if active_dataset is None else active_dataset.sha256,
        metadata_json=json.dumps(remote, separators=(",", ":"), default=_json_default),
    )
    session.add(check)

    if not checked:
        source.status = "update_check_error"
    elif update_available:
        source.status = "update_available"
    else:
        source.status = "up_to_date"
    source.last_error = None if checked else reason

    session.flush()
    return check


def latest_source_update_check(session: Session, source_id: int) -> SourceUpdateCheck | None:
    """Return the most recent update check for *source_id*, or ``None`` if there is none."""
    return session.scalar(
        select(SourceUpdateCheck)
        .where(SourceUpdateCheck.source_id == source_id)
        .order_by(SourceUpdateCheck.checked_at.desc(), SourceUpdateCheck.id.desc())
        .limit(1)
    )


def update_check_payload(check: SourceUpdateCheck | None) -> dict | None:
    """Convert *check* into a JSON-friendly dict; ``None`` passes through as ``None``.

    Datetimes are rendered with ``isoformat()``. ``metadata`` is the parsed
    ``metadata_json``, or ``{}`` when it is missing, malformed or not a JSON object.
    """
    if check is None:
        return None
    return {
        "id": check.id,
        "source_id": check.source_id,
        "checked_at": check.checked_at.isoformat() if check.checked_at else None,
        "status": check.status,
        "update_available": check.update_available,
        "reason": check.reason,
        "etag": check.etag,
        "last_modified": check.last_modified,
        "content_length": check.content_length,
        "content_type": check.content_type,
        "local_mtime": check.local_mtime.isoformat() if check.local_mtime else None,
        "local_size": check.local_size,
        "local_sha256": check.local_sha256,
        "active_dataset_id": check.active_dataset_id,
        "active_dataset_sha256": check.active_dataset_sha256,
        "metadata": _load_json_object(check.metadata_json),
    }


def record_dataset_update_metadata(dataset: Dataset, check: SourceUpdateCheck | None) -> None:
    """Store a snapshot of *check* under ``source_update_check`` in the dataset metadata.

    Does nothing when *check* is ``None``. Existing metadata that is malformed
    or not a JSON object is replaced by a fresh object.
    """
    if check is None:
        return
    metadata = _load_json_object(dataset.metadata_json)
    metadata["source_update_check"] = {
        "id": check.id,
        "checked_at": check.checked_at.isoformat() if check.checked_at else None,
        "etag": check.etag,
        "last_modified": check.last_modified,
        "content_length": check.content_length,
        "content_type": check.content_type,
        "local_mtime": check.local_mtime.isoformat() if check.local_mtime else None,
        "local_size": check.local_size,
        "local_sha256": check.local_sha256,
        "metadata": _load_json_object(check.metadata_json),
    }
    dataset.metadata_json = json.dumps(metadata, indent=2, default=_json_default)


def _load_json_object(raw: str | None) -> dict:
    """Parse *raw* as a JSON object, returning ``{}`` for anything else."""
    try:
        value = json.loads(raw or "{}")
    except (json.JSONDecodeError, TypeError):
        return {}
    # Valid JSON such as "null" or "[]" would break callers that index by key.
    return value if isinstance(value, dict) else {}


def _local_path_for(url: str) -> Path | None:
    """Return the filesystem path *url* refers to, or ``None`` for an HTTP(S) URL."""
    parsed = urlparse(url)
    if parsed.scheme in _HTTP_SCHEMES:
        return None
    return Path(parsed.path) if parsed.scheme == "file" else Path(url)


def _source_remote_metadata(source: Source) -> dict:
    """Probe ``source.url`` over HTTP(S) or on the local filesystem."""
    local_path = _local_path_for(source.url)
    if local_path is None:
        return _http_metadata(source.url)
    return _local_metadata(local_path)


def _recover_missing_managed_cache_url(source: Source) -> dict | None:
    """Replace a missing managed-cache path in ``source.url`` with a seed URL.

    Only acts when the URL is local, the file no longer exists, the path lies
    in this source's managed cache directory, and a seed URL can be found.

    Returns:
        ``{"previous_url": ..., "url": ...}`` when ``source.url`` was replaced,
        otherwise ``None``.
    """
    local_path = _local_path_for(source.url)
    if local_path is None:
        return None
    if local_path.exists() or not _is_managed_source_cache_path(local_path, source.id):
        return None
    replacement = _seed_source_url_for(source)
    if replacement is None:
        return None
    previous_url = source.url
    source.url = replacement
    return {"previous_url": previous_url, "url": replacement}


def _is_managed_source_cache_path(path: Path, source_id: int) -> bool:
    """Tell whether *path* lies inside the ``sources/source_<id>`` cache directory."""
    source_dir = f"source_{source_id}"
    try:
        managed_dir = (settings.data_dir / "sources" / source_dir).resolve()
        path.resolve().relative_to(managed_dir)
    except (ValueError, OSError, RuntimeError):
        # Not under the configured data dir, or the path could not be resolved;
        # fall through to the purely lexical check below.
        pass
    else:
        return True
    parts = path.parts
    return any(
        parent == "sources" and child == source_dir
        for parent, child in zip(parts, parts[1:])
    )


def _seed_source_url_for(source: Source) -> str | None:
    """Look up an HTTP(S) URL for *source* in ``scripts/example_sources.json``.

    A row matches when its kind equals the source's kind, its country does not
    conflict, and the normalized name tokens of one side are a subset of the
    other's. Returns ``None`` when the seed file is unreadable or nothing matches.
    """
    seed_path = Path(__file__).resolve().parents[1] / "scripts" / "example_sources.json"
    try:
        rows = json.loads(seed_path.read_text(encoding="utf-8"))
    except (OSError, ValueError):
        # Missing/unreadable file, undecodable bytes, or invalid JSON.
        return None
    if not isinstance(rows, list):
        return None

    source_tokens = set(norm_text(source.name).split())
    if not source_tokens:
        # An empty set is a subset of every row's tokens and would match any row.
        return None

    for row in rows:
        if not isinstance(row, dict):
            continue
        url = str(row.get("url") or "")
        if urlparse(url).scheme not in _HTTP_SCHEMES:
            continue
        if row.get("kind") != source.kind:
            continue
        row_country = row.get("country")
        if source.country and row_country and str(row_country) != source.country:
            continue
        row_tokens = set(norm_text(row.get("name") or "").split())
        if row_tokens and (row_tokens <= source_tokens or source_tokens <= row_tokens):
            return url
    return None


def _http_metadata(url: str) -> dict:
    """Fetch response headers for *url* with HEAD, falling back to GET on 405/501.

    Returns:
        A dict with ``status`` ``"checked"`` and the header-derived fields, or
        ``{"status": "error", "error": ...}`` when the request fails.
    """
    response = None
    try:
        response = requests.head(url, allow_redirects=True, timeout=_HTTP_TIMEOUT_SECONDS)
        if response.status_code in {405, 501}:
            # Server rejects HEAD; use a streamed GET so only headers are read.
            response.close()
            response = requests.get(url, stream=True, timeout=_HTTP_TIMEOUT_SECONDS)
        response.raise_for_status()
        headers = response.headers
        final_url = response.url
    except Exception as exc:  # noqa: BLE001 - persisted as update-check status
        return {"status": "error", "error": str(exc)}
    finally:
        if response is not None:
            response.close()

    content_type = headers.get("Content-Type")
    return {
        "status": "checked",
        "etag": headers.get("ETag"),
        "last_modified": headers.get("Last-Modified"),
        "content_length": _parse_content_length(headers.get("Content-Length")),
        "content_type": content_type,
        "final_url": final_url,
        "update_artifact": _update_artifact(url, content_type),
    }


def _parse_content_length(value: str | None) -> int | None:
    """Return *value* as an int when it is a plain ASCII digit string, else ``None``."""
    # str.isdigit() alone accepts characters such as "²" that int() rejects.
    if value and value.isascii() and value.isdigit():
        return int(value)
    return None


def _local_metadata(path: Path) -> dict:
    """Collect mtime, size and SHA-256 of the local file at *path*.

    Returns:
        A dict with ``status`` ``"checked"`` and the file facts, or
        ``{"status": "error", "error": ...}`` when *path* is missing, is not a
        regular file, or cannot be read.
    """
    if not path.exists():
        return {"status": "error", "error": f"Source file does not exist: {path}"}
    if not path.is_file():
        return {"status": "error", "error": f"Source path is not a file: {path}"}
    try:
        stat = path.stat()
        digest = sha256_file(path)
    except OSError as exc:
        # The file can vanish or become unreadable between the checks and the read.
        return {"status": "error", "error": str(exc)}
    return {
        "status": "checked",
        "local_mtime": datetime.fromtimestamp(stat.st_mtime, tz=timezone.utc),
        "local_size": stat.st_size,
        "local_sha256": digest,
        "update_artifact": _update_artifact(str(path), None),
    }


def _update_decision(active_dataset: Dataset | None, remote: dict) -> tuple[bool, str]:
    """Decide whether an update is available and give the reason.

    A failed probe never reports an update. With no active dataset an update is
    always reported. Local files are compared by SHA-256; remote sources are
    compared on ``etag``, ``last_modified`` and ``content_length`` against the
    values recorded in the active dataset's metadata.
    """
    if remote["status"] != "checked":
        return False, remote.get("error") or "update check failed"
    if active_dataset is None:
        return True, "no active dataset imported"

    local_sha256 = remote.get("local_sha256")
    if local_sha256:
        if local_sha256 == active_dataset.sha256:
            return False, "local file hash matches active dataset"
        return True, "local file hash differs from active dataset"

    previous = _dataset_update_metadata(active_dataset)
    compared_any = False
    for key in ("etag", "last_modified", "content_length"):
        current = remote.get(key)
        old = previous.get(key)
        if current is None or old is None:
            continue
        compared_any = True
        # Compare as strings: values read back from JSON may differ in type.
        if str(current) != str(old):
            return True, f"remote {key} changed"
    if compared_any:
        return False, "remote metadata matches active dataset"
    return True, "no previous remote metadata recorded"


def _dataset_update_metadata(dataset: Dataset) -> dict:
    """Return the ``source_update_check`` object stored in the dataset metadata, or ``{}``."""
    recorded = _load_json_object(dataset.metadata_json).get("source_update_check")
    return recorded if isinstance(recorded, dict) else {}


def _json_default(value):
    """``json.dumps`` fallback: serialize datetimes as ISO strings, reject everything else."""
    if isinstance(value, datetime):
        return value.isoformat()
    raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")


def _update_artifact(url_or_path: str, content_type: str | None) -> dict:
    """Classify the artifact behind *url_or_path* by file suffix and content type.

    Returns:
        A dict with ``kind`` (``"osm_diff"``, ``"gtfs_or_archive"`` or
        ``"full_snapshot"``), ``is_diff`` and the original ``content_type``.
    """
    lower = url_or_path.lower()
    is_osm_diff = lower.endswith((".osc", ".osc.gz"))
    # Drop parameters such as "; charset=binary" before comparing the media type.
    media_type = (content_type or "").split(";", 1)[0].strip().lower()
    is_gtfs_zip = lower.endswith(".zip") or media_type in _ZIP_CONTENT_TYPES

    if is_osm_diff:
        kind = "osm_diff"
    elif is_gtfs_zip:
        kind = "gtfs_or_archive"
    else:
        kind = "full_snapshot"
    return {
        "kind": kind,
        "is_diff": is_osm_diff,
        "content_type": content_type,
    }