diff --git a/.env.example b/.env.example
new file mode 100644
index 0000000..41df7e3
--- /dev/null
+++ b/.env.example
@@ -0,0 +1,20 @@
+# SQLite is the default for the prototype. It keeps the project runnable without Docker.
+DATABASE_URL=sqlite:///./data/workbench.sqlite
+# For large imports, use PostgreSQL/PostGIS instead:
+# DATABASE_URL=postgresql://USER:PASSWORD@localhost:5432/meubility
+# POSTGRES_USE_SIDECARS=false
+DATA_DIR=./data
+GTFS_STOP_TIMES_IMPORT_LIMIT=250000
+
+# Start separate queue worker processes from the API server lifespan.
+# Workers survive normal server restarts by default; stale leases are recovered.
+QUEUE_WORKER_AUTOSTART=true
+QUEUE_WORKER_COUNT=1
+QUEUE_WORKER_POLL_INTERVAL_SECONDS=2
+QUEUE_JOB_LEASE_SECONDS=7200
+QUEUE_WORKER_STOP_ON_SHUTDOWN=false
+
+# Chunk sizes for queued data-preparation jobs.
+ROUTE_MATCHING_BATCH_SIZE=100
+ROUTE_LAYER_OSM_ROUTE_BATCH_SIZE=1000
+ROUTE_LAYER_OSM_STOP_BATCH_SIZE=5000
diff --git a/.gitignore b/.gitignore
index ccaad04..c2d91ed 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,328 +1,8 @@
-# ---> Python
-# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
-*$py.class
-
-# C extensions
-*.so
-
-# Distribution / packaging
-.Python
-build/
-develop-eggs/
-dist/
-downloads/
-eggs/
-.eggs/
-lib/
-lib64/
-parts/
-sdist/
-var/
-wheels/
-share/python-wheels/
-*.egg-info/
-.installed.cfg
-*.egg
-MANIFEST
-
-# PyInstaller
-# Usually these files are written by a python script from a template
-# before PyInstaller builds the exe, so as to inject date/other infos into it.
-*.manifest
-*.spec
-
-# Installer logs
-pip-log.txt
-pip-delete-this-directory.txt
-
-# Unit test / coverage reports
-htmlcov/
-.tox/
-.nox/
-.coverage
-.coverage.*
-.cache
-nosetests.xml
-coverage.xml
-*.cover
-*.py,cover
-.hypothesis/
-.pytest_cache/
-cover/
-
-# Translations
-*.mo
-*.pot
-
-# Django stuff:
-*.log
-local_settings.py
-db.sqlite3
-db.sqlite3-journal
-
-# Flask stuff:
-instance/
-.webassets-cache
-
-# Scrapy stuff:
-.scrapy
-
-# Sphinx documentation
-docs/_build/
-
-# PyBuilder
-.pybuilder/
-target/
-
-# Jupyter Notebook
-.ipynb_checkpoints
-
-# IPython
-profile_default/
-ipython_config.py
-
-# pyenv
-# For a library or package, you might want to ignore these files since the code is
-# intended to run in multiple environments; otherwise, check them in:
-# .python-version
-
-# pipenv
-# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
-# However, in case of collaboration, if having platform-specific dependencies or dependencies
-# having no cross-platform support, pipenv may install dependencies that don't work, or not
-# install all needed dependencies.
-#Pipfile.lock
-
-# UV
-# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
-# This is especially recommended for binary packages to ensure reproducibility, and is more
-# commonly ignored for libraries.
-#uv.lock
-
-# poetry
-# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
-# This is especially recommended for binary packages to ensure reproducibility, and is more
-# commonly ignored for libraries.
-# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
-#poetry.lock
-
-# pdm
-# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
-#pdm.lock
-# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
-# in version control.
-# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
-.pdm.toml
-.pdm-python
-.pdm-build/
-
-# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
-__pypackages__/
-
-# Celery stuff
-celerybeat-schedule
-celerybeat.pid
-
-# SageMath parsed files
-*.sage.py
-
-# Environments
.env
-.venv
-env/
-venv/
-ENV/
-env.bak/
-venv.bak/
-
-# Spyder project settings
-.spyderproject
-.spyproject
-
-# Rope project settings
-.ropeproject
-
-# mkdocs documentation
-/site
-
-# mypy
-.mypy_cache/
-.dmypy.json
-dmypy.json
-
-# Pyre type checker
-.pyre/
-
-# pytype static type analyzer
-.pytype/
-
-# Cython debug symbols
-cython_debug/
-
-# PyCharm
-# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
-# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
-# and can be added to the global gitignore or merged into this file. For a more nuclear
-# option (not recommended) you can uncomment the following to ignore the entire idea folder.
-#.idea/
-
-# Ruff stuff:
-.ruff_cache/
-
-# PyPI configuration file
-.pypirc
-
-# ---> Node
-# Logs
-logs
-*.log
-npm-debug.log*
-yarn-debug.log*
-yarn-error.log*
-lerna-debug.log*
-.pnpm-debug.log*
-
-# Diagnostic reports (https://nodejs.org/api/report.html)
-report.[0-9]*.[0-9]*.[0-9]*.[0-9]*.json
-
-# Runtime data
-pids
-*.pid
-*.seed
-*.pid.lock
-
-# Directory for instrumented libs generated by jscoverage/JSCover
-lib-cov
-
-# Coverage directory used by tools like istanbul
-coverage
-*.lcov
-
-# nyc test coverage
-.nyc_output
-
-# Grunt intermediate storage (https://gruntjs.com/creating-plugins#storing-task-files)
-.grunt
-
-# Bower dependency directory (https://bower.io/)
-bower_components
-
-# node-waf configuration
-.lock-wscript
-
-# Compiled binary addons (https://nodejs.org/api/addons.html)
-build/Release
-
-# Dependency directories
-node_modules/
-jspm_packages/
-
-# Snowpack dependency directory (https://snowpack.dev/)
-web_modules/
-
-# TypeScript cache
-*.tsbuildinfo
-
-# Optional npm cache directory
-.npm
-
-# Optional eslint cache
-.eslintcache
-
-# Optional stylelint cache
-.stylelintcache
-
-# Microbundle cache
-.rpt2_cache/
-.rts2_cache_cjs/
-.rts2_cache_es/
-.rts2_cache_umd/
-
-# Optional REPL history
-.node_repl_history
-
-# Output of 'npm pack'
-*.tgz
-
-# Yarn Integrity file
-.yarn-integrity
-
-# dotenv environment variable files
-.env
-.env.development.local
-.env.test.local
-.env.production.local
-.env.local
-
-# parcel-bundler cache (https://parceljs.org/)
-.cache
-.parcel-cache
-
-# Next.js build output
-.next
-out
-
-# Nuxt.js build / generate output
-.nuxt
-dist
-
-# Gatsby files
-.cache/
-# Comment in the public line in if your project uses Gatsby and not Next.js
-# https://nextjs.org/blog/next-9-1#public-directory-support
-# public
-
-# vuepress build output
-.vuepress/dist
-
-# vuepress v2.x temp and cache directory
-.temp
-.cache
-
-# vitepress build output
-**/.vitepress/dist
-
-# vitepress cache directory
-**/.vitepress/cache
-
-# Docusaurus cache and generated files
-.docusaurus
-
-# Serverless directories
-.serverless/
-
-# FuseBox cache
-.fusebox/
-
-# DynamoDB Local files
-.dynamodb/
-
-# TernJS port file
-.tern-port
-
-# Stores VSCode versions used for testing VSCode extensions
-.vscode-test
-
-# yarn v2
-.yarn/cache
-.yarn/unplugged
-.yarn/build-state.yml
-.yarn/install-state.gz
-.pnp.*
-
-# ---> VisualStudioCode
-.vscode/*
-!.vscode/settings.json
-!.vscode/tasks.json
-!.vscode/launch.json
-!.vscode/extensions.json
-!.vscode/*.code-snippets
-
-# Local History for Visual Studio Code
-.history/
-
-# Built Visual Studio Code Extensions
-*.vsix
-
+/data/*
+!/data/.gitkeep
+*.sqlite
+*.db
+.DS_Store
diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000..68f1fd3
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,17 @@
+FROM python:3.12-slim
+
+WORKDIR /app
+ENV PYTHONDONTWRITEBYTECODE=1 \
+ PYTHONUNBUFFERED=1 \
+ DATA_DIR=/app/data \
+ DATABASE_URL=sqlite:////app/data/workbench.sqlite
+
+COPY requirements.txt ./
+RUN pip install --no-cache-dir -r requirements.txt
+
+COPY app ./app
+COPY README.md MVP_ROADMAP.md ./
+RUN mkdir -p /app/data
+
+EXPOSE 8000
+CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
diff --git a/MVP_ROADMAP.md b/MVP_ROADMAP.md
new file mode 100644
index 0000000..a9a1326
--- /dev/null
+++ b/MVP_ROADMAP.md
@@ -0,0 +1,220 @@
+# MVP roadmap
+
+Last updated: 2026-07-01
+
+See also `docs/backlog.md` for the prioritized engineering backlog, caveats, and open optimization list.
+
+## Objective
+
+Build an internal management workbench that turns public mobility data into a normalized, auditable, coverage-scored dataset for a future traveller-facing web/native app.
+
+The workbench stays distinct from the public app. Its users are data engineers, analysts, and operations staff who need to ingest, inspect, link, correct, route against, and publish mobility data.
+
+## Current prototype: implemented
+
+The repository has moved beyond the original SQLite/Berlin prototype. The current development path is Germany-scale and PostGIS-first, while SQLite remains useful as a legacy/test fallback.
+
+Implemented:
+
+```text
+source registry and source catalog
+local source cache
+job queue with job events and worker process
+PostgreSQL/PostGIS runtime support with SQLite fallback
+GTFS static importer for large national feeds
+OSM PBF import path for Germany-scale extracts
+OSM address index and address-aware journey endpoints
+canonical stop/station linking from GTFS and OSM
+automatic GTFS <-> OSM route matching
+manual route and canonical-stop rule persistence
+visual route-layer builder from OSM routes and GTFS shapes
+walk/drive routing layer from OSM-derived routing graph
+progressive journey-search API and UI polling
+map right-click "from here" / "to here"
+management UI with map, sources, stats, jobs, matches, search, and journeys
+separate GTFS Harmonization and Mapping Data source modules in the UI
+generic job-details overlay with phase timeline, event log, and queue snapshot
+QA dashboard skeleton for source/import/link/route/publication health
+GTFS harmonization concept and service-boundary decision
+CLI commands
+tests and syntax checks for changed modules
+```
+
+Recent fixes:
+
+```text
+PostgreSQL startup avoids unnecessary DDL when PostGIS columns/indexes already exist.
+Queue route-layer rebuild can be claimed by a real worker instead of staying queued behind a stale worker pid.
+Timetable routing no longer requires visual route-pattern trip links.
+Walk-leg route geometry has a short-lived in-process cache.
+Address search is bbox-aware without being bbox-limited.
+Job rows expose a details overlay that polls job events only while open.
+Journey routing consumes the active harmonized GTFS snapshot instead of a raw feed picker.
+```
+
+## Current prototype: known limits
+
+The app can import and inspect Germany-scale OSM and GTFS, but the routing and route-layer rebuild paths are still prototype-grade.
+
+Important limits:
+
+```text
+journey search does not yet use RAPTOR or the Connection Scan Algorithm (CSA)
+address endpoints can multiply transit searches through several nearby access/egress stops
+progressive transfer stages still recompute too much
+route-layer rebuild is coarse-grained and rewrites derived link tables
+visual route-pattern links are not yet incrementally updated
+canonical stop extraction is CPU/memory heavy on national feeds
+route geometry cannot yet classify temporary GTFS detours as separate variants
+local-transport-only routing is not a first-class query mode
+route-search caches are process-local and not persisted
+Alembic migrations are still missing
+```
+
+## MVP 1: stable Germany data workbench
+
+### Backend
+
+- Add proper Alembic migrations for PostgreSQL and keep SQLite test support.
+- Add source-run history and dataset-version comparison.
+- Make route-layer rebuild incremental: update only affected matches/patterns/stops.
+- Keep old route-layer tables readable while a rebuild prepares replacement rows.
+- Add source health checks: download success, hash change, feed freshness, calendar validity.
+- Expand the QA dashboard into drill-down review queues for source health, GTFS validation, canonical stop conflicts, route conflicts, and publication blockers.
+- Add GTFS validation summary reports: service dates, route direction coverage, stop coordinate outliers, bad stop_times, missing shapes.
+- Add database maintenance jobs: analyze, vacuum, stale job recovery, orphan cleanup.
+- Add durable cache tables for journey stages, nearest stops, address access candidates, and common station-to-station searches.
+
+### Routing
+
+- Replace the demo round-expansion router with a GTFS-appropriate algorithm such as RAPTOR or CSA.
+- Precompute transfer graph edges: station-internal transfers, nearby walking transfers, and access/egress stop candidates.
+- Add routing profiles:
+
+```text
+fastest public transport
+fewest transfers
+local transport only / Deutschlandticket-like
+walk only
+drive
+car comparison
+```
+
+- Treat access/egress walking as access legs, not as public-transport transfers.
+- Add bounded hub-aware long-distance routing for city-to-city requests: local access to likely hubs, long-distance/regional trunk, local egress.
+- Add arrive-by search and better stop conditions for "good enough" results.
+- Add route diagnostics that explain why a route was found or pruned.
+
+### Frontend
+
+- Add source detail page.
+- Add dataset detail page.
+- Add match-review queue with filters by mode, operator, country, confidence, and source scope.
+- Add route detail inspection: GTFS geometry, OSM geometry, candidate matches, stops, evidence, and route-pattern provenance.
+- Add canonical stop/station detail overlay.
+- Add persistent rule editor.
+- Add routing controls for profile, transfer buffer, avoid/prefer modes, arrive-by, via, and local-only.
+- Show partial/progressive route results with clear stage labels.
+
+### Data outputs
+
+- GeoJSON exports for small regions.
+- GeoParquet exports for analysis.
+- PMTiles/vector-tile export for map display.
+- Coverage CSV/API for downstream services.
+
+## MVP 2: Europe-scale coverage map
+
+- Use Geofabrik country/Europe extracts and reproducible OSM PBF jobs.
+- Store OSM transport features, addresses, and routing graph in PostGIS.
+- Generate ranked/generalized transport route layers by zoom level.
+- Serve tiles with Martin or export PMTiles.
+- Add coverage statuses:
+
+```text
+existing_in_osm
+static_timetable_covered
+live_data_covered
+fare_data_covered
+booking_covered
+missing_static
+stale_feed
+restricted_license
+low_confidence_match
+detour_or_temporary_variant
+```
+
+- Add coverage metrics:
+
+```text
+operator coverage
+route coverage
+route-km coverage
+stop coverage
+live-data coverage
+feed freshness
+license confidence
+booking coverage
+route-layer provenance coverage
+```
+
+## MVP 3: more source formats
+
+Add importers:
+
+```text
+NeTEx
+TransXChange
+SIRI discovery/live endpoints
+GTFS-Realtime
+GBFS for shared mobility, optional
+operator CSV/API adapters
+```
+
+Target data model:
+
+```text
+canonical operators
+canonical stops/stations/terminals
+canonical routes
+route variants
+trip patterns
+calendar/service validity
+transfers
+access/egress legs
+coverage observations
+source evidence
+manual rules
+```
+
+## MVP 4: production journey-planning dataset
+
+- Build a canonical stop/station graph with transfer rules and transfer-time profiles.
+- Generate timetable-routing input for RAPTOR/CSA.
+- Add first/last-mile routing from OSM walk/drive graph.
+- Add emissions factors per mode/operator/country.
+- Add fare/ticket placeholders and booking/deep-link metadata.
+- Add confidence and provenance to every derived route/journey.
+
+## MVP 5: booking-readiness layer
+
+- Track booking availability separately from timetable coverage.
+- Add deep-link metadata per operator/route.
+- Add partner API adapters later.
+- Distinguish clearly:
+
+```text
+travel-plausible itinerary
+bookable itinerary
+single-interface multi-booking
+protected through-ticket
+```
+
+## Recommended next implementation sprint
+
+1. Finish route-layer rebuild resilience: incremental updates, shadow tables, and detour/provenance classification.
+2. Replace or heavily optimize journey routing: precomputed transfers, hub-aware long-distance routing, local-only profile, and bounded search.
+3. Add durable PostgreSQL-backed journey caches for address access, stop pairs, and repeated stage searches.
+4. Add Alembic migrations and remove runtime DDL from normal request/worker startup.
+5. Add route/journey diagnostics so slow or failed requests explain what was searched and pruned.
+6. Add vector-tile output for route layers and large map rendering.
diff --git a/README.md b/README.md
index 868fdfd..c02abfb 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,282 @@
-# meubility-workbench
+# Mobility Workbench
+Working prototype for a mobility-data management interface and pipeline.
+
+It is intentionally small but executable. The current implementation lets you:
+
+- register data sources;
+- download/copy source files into a local cache;
+- import GTFS static timetable feeds;
+- import raw OSM PBF extracts by deriving transport GeoJSON;
+- import OSM-derived transport GeoJSON;
+- persist raw datasets and normalized route/stop records;
+- run automatic GTFS-route ↔ OSM-route matching;
+- persist manual accept/reject rules from the UI;
+- expose GeoJSON layers for a zoomable map;
+- use a management web UI with separate GTFS Harmonization and Mapping Data modules, plus source runs, stats, matches, and map inspection.
+
+The default database is SQLite so the prototype runs immediately. The schema is kept simple, and PostgreSQL/PostGIS is already supported (see below) for larger imports that need spatial indexes; vector tiles for European-scale browsing are still a next step.
+
+## Quick start
+
+```bash
+cd mobility-workbench
+python -m venv .venv
+source .venv/bin/activate
+pip install -r requirements.txt
+python -m app.cli load-sample
+uvicorn app.main:app --reload
+```
+
+Open:
+
+```text
+http://127.0.0.1:8000
+```
+
+The sample project loads a small Berlin-like GTFS feed plus an OSM-like GeoJSON network. It imports routes/stops, runs the matcher, and shows matched and missing coverage on the map.
+
+## PostgreSQL/PostGIS
+
+SQLite remains the default. For Germany-scale imports, point `DATABASE_URL` at PostgreSQL:
+
+```bash
+export DATABASE_URL=postgresql://USER:PASSWORD@localhost:5432/meubility
+python -m app.cli init-db
+uvicorn app.main:app --reload
+```
+
+PostgreSQL mode automatically creates the `postgis` and `pg_trgm` extensions, stores GTFS `stop_times` and OSM features in main tables, and uses GiST/trigram indexes for map bbox queries, route-layer stop linking, and search filters. To keep using legacy sidecars with PostgreSQL, set:
+
+```bash
+export POSTGRES_USE_SIDECARS=true
+```
+
+To migrate the existing SQLite project into a fresh PostgreSQL database:
+
+```bash
+python scripts/migrate_sqlite_to_postgres.py \
+ --sqlite-path data/workbench.sqlite \
+ --postgres-url postgresql://USER:PASSWORD@localhost:5432/meubility \
+ --reset
+```
+
+The migration copies normal tables first, imports legacy GTFS/OSM sidecars into PostgreSQL main tables, rewrites dataset storage metadata to `main`, refreshes PostGIS geometry columns, and rebuilds runtime indexes.
+
+## Docker start
+
+```bash
+docker compose up --build
+```
+
+Then open:
+
+```text
+http://127.0.0.1:8000
+```
+
+## CLI commands
+
+```bash
+python -m app.cli init-db
+python -m app.cli reset-db
+python -m app.cli load-sample
+python -m app.cli stats
+python -m app.cli add-source --name "My GTFS" --kind gtfs --url ./data/feed.zip --country DE
+python -m app.cli add-source --name "VBB Online GTFS" --kind gtfs --url https://unternehmen.vbb.de/fileadmin/user_upload/VBB/Dokumente/API-Datensaetze/gtfs-mastscharf/GTFS.zip --country DE --license "CC BY 4.0"
+python -m app.cli add-source --name "DB Long-distance Rail GTFS.DE" --kind gtfs --url https://download.gtfs.de/germany/fv_free/latest.zip --country DE --license "Creative Commons 4.0"
+python -m app.cli add-source --name "Germany Regional Rail GTFS.DE" --kind gtfs --url https://download.gtfs.de/germany/rv_free/latest.zip --country DE --license "Creative Commons 4.0"
+python -m app.cli add-source --name "Berlin OSM" --kind osm_pbf --url https://download.geofabrik.de/europe/germany/berlin-latest.osm.pbf --country DE --license ODbL
+python -m app.cli run-source 1
+python -m app.cli run-match
+python -m app.cli prune-cache --dry-run
+python -m app.cli prune-cache
+```
+
+## HTTP API
+
+Core endpoints:
+
+```text
+GET /api/sources
+POST /api/sources
+POST /api/sources/{source_id}/run
+POST /api/sample/reset
+POST /api/match/run
+GET /api/stats
+GET /api/matches
+POST /api/matches/{match_id}/accept
+POST /api/matches/{match_id}/reject
+GET /api/rules
+POST /api/rules
+```
+
+Map layers:
+
+```text
+GET /api/map/osm_routes.geojson
+GET /api/map/osm_stops.geojson
+GET /api/map/gtfs_routes.geojson
+GET /api/map/gtfs_stops.geojson
+GET /api/map/matched_gtfs_routes.geojson
+GET /api/map/matched_gtfs_routes.geojson?status=missing
+```
+
+Map endpoints accept viewport and layer filters:
+
+```text
+bbox=min_lon,min_lat,max_lon,max_lat
+zoom=13
+kind=route,infra,stop,station,terminal
+mode=bus,tram,train,subway,light_rail,ferry
+geometry=point,line,polygon,nonpoint
+source_id=4
+dataset_id=5
+limit=5000
+```
+
+## Source types implemented
+
+### `gtfs`
+
+Expected input: GTFS static zip.
+
+Imported files:
+
+```text
+agency.txt
+stops.txt
+routes.txt
+trips.txt
+stop_times.txt
+shapes.txt, if available
+```
+
+The importer stores agencies, stops, routes, trips, limited stop-times, and representative route geometries. Route geometry comes from `shapes.txt` where available; otherwise it falls back to stop sequences from a representative trip.
+
+Multiple GTFS sources can be active at once. Map endpoints and layer controls keep sources separate with `source_id` filters, so VBB, DB long-distance rail, DB/regional rail, and local sample feeds can be rendered independently.
+
+The journey UI routes against the active harmonized transit snapshot instead of exposing a raw GTFS source selector. Feed-level filters remain available for map layers, QA, and source diagnostics.
+
+### `osm_pbf`
+
+Expected input: an OSM `.osm.pbf` extract, for example a Geofabrik regional extract.
+
+The importer records the downloaded/copied file once as an immutable raw dataset with kind `osm_pbf_raw`. For `.osm.pbf` inputs it then runs `scripts/osmium_transport_filter.sh` and stores one transport-only extract as `osm_pbf_transport`. The Python extractor reads that filtered extract, writes `transport.geojson`, and imports it through the `osm_geojson` importer.
+
+The raw and filtered datasets are inactive storage stages; the derived `osm_geojson` dataset is the active visual layer. Re-running an unchanged source reuses the existing raw, filtered, and derived datasets instead of duplicating the extract.
+
+The extractor emits:
+
+```text
+route relations as LineString/MultiLineString features built from member ways
+rail/tram/subway/ferry/aerialway infrastructure ways
+stations, stops, platforms, bus stations, and ferry terminals
+```
+
+Route display uses OSM route relation member ways, not stop-to-stop straight-line interpolation.
+
+### `osm_geojson`
+
+Expected input: GeoJSON `FeatureCollection` containing OSM-derived route/station/stop/terminal features.
+
+Minimum useful properties for route features:
+
+```json
+{
+ "osm_type": "relation",
+ "osm_id": "12345",
+ "type": "route",
+ "route": "train",
+ "ref": "RE1",
+ "name": "RE1 Example Line",
+ "operator": "Example Operator",
+ "network": "Example Network"
+}
+```
+
+Supported route modes include:
+
+```text
+train, light_rail, subway, tram, bus, trolleybus, coach,
+ferry, monorail, funicular, aerialway
+```
+
+## Matching logic
+
+The current automatic matcher scores each GTFS route against OSM route features using:
+
+```text
+mode compatibility
+route ref similarity
+route name similarity
+operator/network similarity
+bbox overlap or proximity, used as a major disambiguator for common refs
+GTFS/OSM geometry proximity, where both geometries are available
+same normalized route key
+```
+
+Each match also stores a scope classification:
+
+```text
+in_osm_scope
+near_osm_scope
+outside_osm_scope
+unknown_scope
+```
+
+Overall coverage and in-scope coverage are intentionally separate. A GTFS route outside the loaded OSM extract should not be interpreted as a failed route match.
+
+Status thresholds:
+
+```text
+>= 85 matched
+65–84 probable
+40–64 weak
+< 40 missing
+```
+
+Manual accept/reject actions are stored as `match_rules`. The current prototype records the rule; the next implementation step is applying those rules automatically before/after every matching run.
+
+The route layer treats OSM route geometry as the visual authority when a suitable match exists. Multiple GTFS timetable shapes or trips, including opposite directions, can link to the same OSM-backed `RoutePattern`; each GTFS shape link keeps its own match and direction evidence. When no OSM route matches, the builder creates a `gtfs_proposed` visual pattern from GTFS geometry for review.
+
+## Data flow
+
+```text
+source registration
+→ local source cache
+→ dataset record with hash
+→ raw OSM commit, if source is osm_pbf
+→ filtered transport extract, if source is osm_pbf and prefiltering is enabled
+→ derived transport GeoJSON extraction, if source is osm_pbf
+→ normalized GTFS / OSM tables
+→ route matching
+→ canonical stops and OSM-authoritative route layer
+→ manual review rules
+→ GeoJSON map layers
+→ downstream routing/coverage/tile generation
+```
+
+## Current limitations
+
+- PostgreSQL/PostGIS is supported for large local imports; vector tiles are still the next step for country/Europe-scale browsing.
+- OSM PBF snapshot extraction is implemented; applying replication `.osc.gz` diffs onto prior raw snapshots is still a next step.
+- GTFS-RT, SIRI, NeTEx, TransXChange, OSDM, fares, and booking APIs are not yet implemented.
+- The matcher is deliberately transparent rather than sophisticated.
+- The frontend requests viewport-bounded GeoJSON by layer; vector tiles are still the next step for country/Europe scale.
+
+## OSM extraction helper
+
+A starter Osmium shell filter script is included:
+
+```bash
+scripts/osmium_transport_filter.sh europe-latest.osm.pbf transport.osm.pbf
+```
+
+The script calls Osmium through `scripts/host_tool.sh`, which also works from a Flatpak/containerized terminal when `flatpak-spawn --host` is available. The app has a Python Osmium-based `osm_pbf` importer for repeatable prototype runs. For the next stage, add OSM replication diff application, run large-region imports on PostgreSQL/PostGIS, and serve generalized vector tiles where network editing requires broad viewport rendering.
+
+## Tests
+
+```bash
+pytest -q
+```
diff --git a/app/__init__.py b/app/__init__.py
new file mode 100644
index 0000000..fa0adc2
--- /dev/null
+++ b/app/__init__.py
@@ -0,0 +1 @@
+"""Mobility Workbench prototype."""
diff --git a/app/address_search.py b/app/address_search.py
new file mode 100644
index 0000000..6bbaadf
--- /dev/null
+++ b/app/address_search.py
@@ -0,0 +1,1272 @@
+from __future__ import annotations
+
+import re
+import math
+from typing import Any
+
+from sqlalchemy import select, text
+from sqlalchemy.orm import Session
+
+from app.config import settings
+from app.models import OsmAddress
+from app.pipeline.routing_layer import active_routing_dataset
+
+
+# Prefix of tokens that reference a stored address by numeric id: "address:<id>".
+ADDRESS_PREFIX = "address:"
+# Prefix of tokens that carry an address id plus an explicit point: "address-point:<id>:<lat>:<lon>".
+ADDRESS_POINT_PREFIX = "address-point:"
+# Prefix of tokens that carry a bare coordinate: "coord:<lat>:<lon>".
+COORDINATE_PREFIX = "coord:"
+# Row cap applied by the SQLite text search before results are ranked in Python.
+MAX_ADDRESS_SEARCH_ROWS = 250
+
+
+def address_token(address_id: int) -> str:
+    """Build the ``address:<id>`` token for a stored address id."""
+    numeric_id = int(address_id)
+    return ADDRESS_PREFIX + str(numeric_id)
+
+
+def address_point_token(address_id: int, lat: float, lon: float) -> str:
+    """Build the ``address-point:<id>:<lat>:<lon>`` token (coordinates with 7 decimals)."""
+    fields = (str(int(address_id)), f"{float(lat):.7f}", f"{float(lon):.7f}")
+    return ADDRESS_POINT_PREFIX + ":".join(fields)
+
+
+def coordinate_token(lat: float, lon: float) -> str:
+    """Build the ``coord:<lat>:<lon>`` token (coordinates with 7 decimals)."""
+    lat_text = format(float(lat), ".7f")
+    lon_text = format(float(lon), ".7f")
+    return f"{COORDINATE_PREFIX}{lat_text}:{lon_text}"
+
+
+def is_address_token(value: object) -> bool:
+    """Return True when ``value`` looks like an address or address-point token."""
+    accepted_prefixes = (ADDRESS_PREFIX, ADDRESS_POINT_PREFIX)
+    return str(value or "").strip().startswith(accepted_prefixes)
+
+
+def is_address_point_token(value: object) -> bool:
+    """Return True when ``value`` starts with the address-point token prefix."""
+    candidate = str(value or "").strip()
+    return candidate.startswith(ADDRESS_POINT_PREFIX)
+
+
+def is_coordinate_token(value: object) -> bool:
+    """Return True when ``value`` starts with the coordinate token prefix."""
+    candidate = str(value or "").strip()
+    return candidate.startswith(COORDINATE_PREFIX)
+
+
+def is_location_token(value: object) -> bool:
+    """Return True for any address, address-point or coordinate token."""
+    if is_address_token(value):
+        return True
+    return is_coordinate_token(value)
+
+
+def parse_address_token(value: object) -> int:
+    """Extract the positive integer id from an ``address:<id>`` token.
+
+    Raises:
+        ValueError: if the prefix is missing, the id part is not an integer,
+            or the id is not strictly positive.
+    """
+    message = "invalid address token"
+    text_value = str(value or "").strip()
+    if text_value[: len(ADDRESS_PREFIX)] != ADDRESS_PREFIX:
+        raise ValueError(message)
+    raw_id = text_value[len(ADDRESS_PREFIX) :]
+    try:
+        parsed = int(raw_id)
+    except ValueError as exc:
+        raise ValueError(message) from exc
+    if parsed < 1:
+        raise ValueError(message)
+    return parsed
+
+
+def parse_address_point_token(value: object) -> tuple[int, float, float]:
+    """Decode an ``address-point:<id>:<lat>:<lon>`` token into ``(id, lat, lon)``.
+
+    Raises:
+        ValueError: if the prefix is missing, the token does not have exactly
+            three fields, a field does not parse, the id is not positive, or
+            the coordinates fall outside [-90, 90] / [-180, 180].
+    """
+    message = "invalid address point token"
+    text_value = str(value or "").strip()
+    if not text_value.startswith(ADDRESS_POINT_PREFIX):
+        raise ValueError(message)
+    fields = text_value[len(ADDRESS_POINT_PREFIX) :].split(":")
+    if len(fields) != 3:
+        raise ValueError(message)
+    raw_id, raw_lat, raw_lon = fields
+    try:
+        address_id = int(raw_id)
+        lat = float(raw_lat)
+        lon = float(raw_lon)
+    except ValueError as exc:
+        raise ValueError(message) from exc
+    id_ok = address_id > 0
+    lat_ok = -90 <= lat <= 90
+    lon_ok = -180 <= lon <= 180
+    if not (id_ok and lat_ok and lon_ok):
+        raise ValueError(message)
+    return address_id, lat, lon
+
+
+def parse_coordinate_token(value: object) -> tuple[float, float]:
+    """Decode a ``coord:<lat>:<lon>`` token into ``(lat, lon)``.
+
+    Raises:
+        ValueError: if the prefix is missing, the token does not have exactly
+            two fields, a field is not a float, or the coordinates fall
+            outside [-90, 90] / [-180, 180].
+    """
+    message = "invalid coordinate token"
+    text_value = str(value or "").strip()
+    if not text_value.startswith(COORDINATE_PREFIX):
+        raise ValueError(message)
+    fields = text_value[len(COORDINATE_PREFIX) :].split(":")
+    if len(fields) != 2:
+        raise ValueError(message)
+    raw_lat, raw_lon = fields
+    try:
+        lat = float(raw_lat)
+        lon = float(raw_lon)
+    except ValueError as exc:
+        raise ValueError(message) from exc
+    lat_ok = -90 <= lat <= 90
+    lon_ok = -180 <= lon <= 180
+    if not (lat_ok and lon_ok):
+        raise ValueError(message)
+    return lat, lon
+
+
+def search_addresses(
+    db: Session,
+    query: str | None = None,
+    *,
+    limit: int = 25,
+    bbox: tuple[float, float, float, float] | None = None,
+) -> list[dict[str, Any]]:
+    """Search addresses of the active routing dataset.
+
+    Args:
+        db: Open SQLAlchemy session.
+        query: Free-text query; normalized with ``_normalize_query``.
+        limit: Maximum number of payloads; clamped to the range 1..100.
+        bbox: Optional ``(min_lon, min_lat, max_lon, max_lat)`` viewport passed
+            on to the backend-specific search helpers.
+
+    Returns:
+        Address payload dicts. Queries without a digit are folded to one entry
+        per street; an empty list is returned when no routing dataset is active.
+    """
+    dataset = active_routing_dataset(db)
+    if dataset is None:
+        return []
+    q = _normalize_query(query)
+    selected_limit = max(1, min(int(limit), 100))
+    dataset_id = int(dataset.id)
+    has_number = _query_has_number(q)
+
+    if not settings.is_postgresql_database:
+        rows = _search_addresses_sqlite(db, dataset_id, q, selected_limit, bbox)
+    elif q and not has_number:
+        # Street-only queries are grouped inside SQL and returned as-is.
+        folded = _search_folded_addresses_postgresql(db, dataset_id, q, selected_limit, bbox)
+        return folded[:selected_limit]
+    elif q and has_number:
+        rows = _search_numbered_addresses_postgresql(db, dataset_id, q, selected_limit, bbox)
+    else:
+        rows = _search_addresses_postgresql(db, dataset_id, q, selected_limit, bbox)
+
+    payloads = [_address_payload(row) for row in rows]
+    if not has_number:
+        payloads = _fold_street_payloads(payloads)
+    return payloads[:selected_limit]
+
+
+def _search_folded_addresses_postgresql(
+ db: Session,
+ dataset_id: int,
+ query: str,
+ limit: int,
+ bbox: tuple[float, float, float, float] | None,
+) -> list[dict[str, Any]]:
+ """Collect street-level ("folded") matches on PostgreSQL and merge them.
+
+ For every (street, locality) split yielded by ``_folded_query_candidates``
+ one or two SQL lookups are run via
+ ``_search_folded_addresses_postgresql_query``. Payloads whose
+ ``_folded_payload_key`` was already seen are skipped. At most ``limit``
+ payloads are returned, in the order they were first encountered.
+ """
+ combined: list[dict[str, Any]] = []
+ seen: set[tuple[str, str, str]] = set()
+ for street_query, locality_query in _folded_query_candidates(query):
+ # Each spec is (row limit, hard bbox filter); a None filter means "no spatial restriction".
+ query_specs: list[tuple[int, tuple[float, float, float, float] | None]] = []
+ if bbox is not None and locality_query is None:
+ # With a viewport and no locality part: look inside the viewport first,
+ # then run an unrestricted lookup with a larger row limit.
+ query_specs.append((limit, bbox))
+ query_specs.append((max(limit * 3, limit), None))
+ else:
+ query_specs.append((limit, None))
+ for query_limit, bbox_filter in query_specs:
+ # `bbox` is always forwarded for ranking; `bbox_filter` is the optional WHERE restriction.
+ for payload in _search_folded_addresses_postgresql_query(
+ db,
+ dataset_id,
+ street_query,
+ query_limit,
+ bbox,
+ bbox_filter=bbox_filter,
+ locality_query=locality_query,
+ ):
+ key = _folded_payload_key(payload)
+ if key in seen:
+ continue
+ seen.add(key)
+ combined.append(payload)
+ if len(combined) >= limit:
+ return combined[:limit]
+ return combined[:limit]
+
+
+def _search_folded_addresses_postgresql_query(
+ db: Session,
+ dataset_id: int,
+ query: str,
+ limit: int,
+ bbox: tuple[float, float, float, float] | None,
+ *,
+ bbox_filter: tuple[float, float, float, float] | None,
+ locality_query: str | None,
+) -> list[dict[str, Any]]:
+ """Run one grouped street lookup against ``osm_addresses`` (PostgreSQL).
+
+ Rows whose street key equals ``query`` or starts with it are grouped by
+ street label, postcode and city; each group reports the minimum id, the
+ average lat/lon and the number of grouped rows.
+
+ Args:
+ db: Open SQLAlchemy session.
+ dataset_id: Dataset whose addresses are searched.
+ query: Street text compared against the street key expression.
+ limit: Maximum number of groups returned.
+ bbox: Optional viewport used only for the ``bbox_rank`` / ``bbox_distance_m`` ordering columns.
+ bbox_filter: Optional viewport applied as a hard ``geom &&`` filter.
+ locality_query: Optional city/postcode text; when given, rows must match it
+ (city equal or prefix, or postcode equal).
+
+ Returns:
+ Payloads built by ``_folded_address_payload``, ordered by locality rank,
+ bbox rank, match rank, bbox distance and then label/postcode/city/id.
+ """
+ params: dict[str, Any] = {
+ "dataset_id": dataset_id,
+ "query": query,
+ "prefix": f"{query}%",
+ "limit": limit,
+ }
+ # Optional hard spatial restriction (envelope intersection in SRID 4326).
+ bbox_filter_sql = ""
+ if bbox_filter is not None:
+ min_lon, min_lat, max_lon, max_lat = bbox_filter
+ params.update(
+ {
+ "filter_min_lon": min_lon,
+ "filter_min_lat": min_lat,
+ "filter_max_lon": max_lon,
+ "filter_max_lat": max_lat,
+ }
+ )
+ bbox_filter_sql = """
+ AND geom && ST_MakeEnvelope(:filter_min_lon, :filter_min_lat, :filter_max_lon, :filter_max_lat, 4326)
+ """
+ # Optional locality restriction plus a rank that prefers exact city, then exact postcode, then city prefix.
+ locality_filter_sql = ""
+ locality_rank_sql = "0"
+ if locality_query:
+ params["locality_query"] = locality_query
+ params["locality_prefix"] = f"{locality_query}%"
+ locality_filter_sql = """
+ AND (
+ LOWER(COALESCE(city, '')) = :locality_query
+ OR LOWER(COALESCE(city, '')) LIKE :locality_prefix
+ OR LOWER(COALESCE(postcode, '')) = :locality_query
+ )
+ """
+ locality_rank_sql = """
+ CASE
+ WHEN LOWER(COALESCE(city, '')) = :locality_query THEN 0
+ WHEN LOWER(COALESCE(postcode, '')) = :locality_query THEN 1
+ WHEN LOWER(COALESCE(city, '')) LIKE :locality_prefix THEN 2
+ ELSE 3
+ END
+ """
+ # The bbox ranking is evaluated on the grouped CTE, hence the "grouped" alias.
+ bbox_rank_sql, bbox_distance_sql = _postgresql_bbox_rank_sql_for_alias("grouped", bbox, params)
+ # SQL expression for the comparable street key; defined by _street_key_sql() elsewhere in this module.
+ street_key_sql = _street_key_sql()
+ rows = db.execute(
+ text(
+ f"""
+ WITH grouped AS (
+ SELECT
+ MIN(id) AS id,
+ MIN(dataset_id) AS dataset_id,
+ COALESCE(NULLIF(street, ''), NULLIF(place, '')) AS street_label,
+ MIN(street) AS street,
+ MIN(place) AS place,
+ postcode,
+ city,
+ MIN(country) AS country,
+ AVG(lat) AS lat,
+ AVG(lon) AS lon,
+ COUNT(*) AS folded_address_count,
+ {locality_rank_sql} AS locality_rank,
+ CASE
+ WHEN {street_key_sql} = :query THEN 0
+ WHEN {street_key_sql} LIKE :prefix THEN 1
+ ELSE 2
+ END AS match_rank
+ FROM osm_addresses
+ WHERE dataset_id = :dataset_id
+ AND {street_key_sql} <> ''
+ AND ({street_key_sql} = :query OR {street_key_sql} LIKE :prefix)
+ {bbox_filter_sql}
+ {locality_filter_sql}
+ GROUP BY COALESCE(NULLIF(street, ''), NULLIF(place, '')), postcode, city, locality_rank, match_rank
+ )
+ SELECT
+ id,
+ dataset_id,
+ street_label,
+ street,
+ place,
+ postcode,
+ city,
+ country,
+ lat,
+ lon,
+ folded_address_count,
+ locality_rank,
+ match_rank,
+ {bbox_rank_sql} AS bbox_rank,
+ {bbox_distance_sql} AS bbox_distance_m
+ FROM grouped
+ ORDER BY locality_rank, bbox_rank, match_rank, bbox_distance_m, street_label, postcode, city, id
+ LIMIT :limit
+ """
+ ),
+ params,
+ ).mappings()
+ return [_folded_address_payload(dict(row)) for row in rows]
+
+
+def _search_numbered_addresses_postgresql(
+ db: Session,
+ dataset_id: int,
+ query: str,
+ limit: int,
+ bbox: tuple[float, float, float, float] | None,
+) -> list[dict[str, Any]]:
+ """Search PostgreSQL for a query that contains a house number.
+
+ Strategy, in order:
+ 1. Interpret the query as (street, housenumber, locality) candidates; if none
+ can be derived, fall back to the generic text search.
+ 2. Run the street + housenumber lookup for each candidate and collect rows by id.
+ 3. If that found nothing, run the street-only fallback lookup for each candidate.
+ 4. If that is still empty, fall back to the generic text search.
+
+ Returns row dicts (not payloads); at most ``limit`` rows from steps 2 and 3.
+ """
+ candidates = _numbered_query_candidates(query)
+ if not candidates:
+ return _search_addresses_postgresql(db, dataset_id, query, limit, bbox)
+ # Keyed by address id so that a row found via several candidates is kept once (first hit wins).
+ result_by_id: dict[int, dict[str, Any]] = {}
+ for street_query, housenumber_query, locality_query in candidates:
+ for row in _execute_numbered_addresses_postgresql(
+ db,
+ dataset_id=dataset_id,
+ street_query=street_query,
+ housenumber_query=housenumber_query,
+ locality_query=locality_query,
+ limit=limit,
+ bbox=bbox,
+ ):
+ result_by_id.setdefault(int(row["id"]), row)
+ if len(result_by_id) >= limit:
+ return list(result_by_id.values())[:limit]
+ result = list(result_by_id.values())
+ if result:
+ return result
+ # No address matched street + housenumber: retry on the street alone.
+ for street_query, housenumber_query, locality_query in candidates:
+ for row in _execute_numbered_street_fallback_postgresql(
+ db,
+ dataset_id=dataset_id,
+ street_query=street_query,
+ housenumber_query=housenumber_query,
+ locality_query=locality_query,
+ limit=limit,
+ bbox=bbox,
+ ):
+ result_by_id.setdefault(int(row["id"]), row)
+ if len(result_by_id) >= limit:
+ return list(result_by_id.values())[:limit]
+ result = list(result_by_id.values())
+ return result or _search_addresses_postgresql(db, dataset_id, query, limit, bbox)
+
+
+def _execute_numbered_addresses_postgresql(
+    db: Session,
+    *,
+    dataset_id: int,
+    street_query: str,
+    housenumber_query: str,
+    locality_query: str | None,
+    limit: int,
+    bbox: tuple[float, float, float, float] | None,
+) -> list[dict[str, Any]]:
+    """Look up addresses by street key and house-number prefix (PostgreSQL).
+
+    Args:
+        db: Open SQLAlchemy session.
+        dataset_id: Dataset whose addresses are searched.
+        street_query: Street text; the street key must equal it or start with it.
+        housenumber_query: House number text; the lower-cased ``housenumber``
+            column must start with it.
+        locality_query: Optional city/postcode restriction.
+        limit: Maximum number of rows.
+        bbox: Optional viewport used for the ``bbox_rank`` / ``bbox_distance_m``
+            ordering columns.
+
+    Returns:
+        Row dicts ordered by locality rank, bbox rank, match rank, bbox
+        distance, display name and id.
+    """
+    params: dict[str, Any] = {
+        "dataset_id": dataset_id,
+        "street_query": street_query,
+        "street_prefix": f"{street_query}%",
+        "housenumber_query": housenumber_query,
+        "housenumber_prefix": f"{housenumber_query}%",
+        "limit": limit,
+    }
+    # Reuse the shared helper (as the street-fallback query does) instead of a
+    # hand-copied locality filter/rank, so the two queries cannot drift apart.
+    locality_filter_sql, locality_rank_sql = _postgresql_locality_sql(locality_query, params, indent=" ")
+    bbox_rank_sql, bbox_distance_sql = _postgresql_bbox_rank_sql(bbox, params)
+    street_key_sql = _street_key_sql()
+    rows = db.execute(
+        text(
+            f"""
+ SELECT
+ id,
+ dataset_id,
+ housenumber,
+ street,
+ place,
+ postcode,
+ city,
+ country,
+ unit,
+ name,
+ display_name,
+ search_text,
+ lon,
+ lat,
+ {bbox_rank_sql} AS bbox_rank,
+ {bbox_distance_sql} AS bbox_distance_m,
+ {locality_rank_sql} AS locality_rank,
+ CASE
+ WHEN {street_key_sql} = :street_query AND LOWER(COALESCE(housenumber, '')) = :housenumber_query THEN 0
+ WHEN {street_key_sql} = :street_query AND LOWER(COALESCE(housenumber, '')) LIKE :housenumber_prefix THEN 1
+ WHEN {street_key_sql} LIKE :street_prefix AND LOWER(COALESCE(housenumber, '')) LIKE :housenumber_prefix THEN 2
+ ELSE 3
+ END AS match_rank,
+ 1.0 AS similarity_rank
+ FROM osm_addresses
+ WHERE dataset_id = :dataset_id
+ AND {street_key_sql} <> ''
+ AND ({street_key_sql} = :street_query OR {street_key_sql} LIKE :street_prefix)
+ AND LOWER(COALESCE(housenumber, '')) LIKE :housenumber_prefix
+ {locality_filter_sql}
+ ORDER BY locality_rank, bbox_rank, match_rank, bbox_distance_m, display_name, id
+ LIMIT :limit
+ """
+        ),
+        params,
+    ).mappings()
+    return [dict(row) for row in rows]
+
+
+def _execute_numbered_street_fallback_postgresql(
+    db: Session,
+    *,
+    dataset_id: int,
+    street_query: str,
+    housenumber_query: str,
+    locality_query: str | None,
+    limit: int,
+    bbox: tuple[float, float, float, float] | None,
+) -> list[dict[str, Any]]:
+    """Street-only fallback for numbered queries (PostgreSQL).
+
+    Matches on the street key alone and ranks the rows by how close their
+    house number is to the requested one.
+
+    Args:
+        db: Open SQLAlchemy session.
+        dataset_id: Dataset whose addresses are searched.
+        street_query: Street text; the street key must equal it or start with it.
+        housenumber_query: Requested house number, used for ranking only.
+        locality_query: Optional city/postcode restriction.
+        limit: Maximum number of rows.
+        bbox: Optional viewport used for the bbox ordering columns.
+
+    Returns:
+        Row dicts ordered by locality rank, bbox rank, match rank, numeric
+        house-number distance, bbox distance, display name and id.
+    """
+    params: dict[str, Any] = {
+        "dataset_id": dataset_id,
+        "street_query": street_query,
+        "street_prefix": f"{street_query}%",
+        "housenumber_query": housenumber_query,
+        "housenumber_prefix": f"{housenumber_query}%",
+        # _leading_number is defined elsewhere in this module; presumably the
+        # leading integer of the house number, or None - confirm there.
+        "housenumber_number": _leading_number(housenumber_query),
+        "limit": limit,
+    }
+    locality_filter_sql, locality_rank_sql = _postgresql_locality_sql(locality_query, params, indent=" ")
+    bbox_rank_sql, bbox_distance_sql = _postgresql_bbox_rank_sql(bbox, params)
+    street_key_sql = _street_key_sql()
+    # house_distance notes:
+    # * substring(... from pattern) yields NULL (not '') when nothing matches,
+    #   so the "no leading digits" sentinel is tested with IS NULL.
+    # * The digits are cast to NUMERIC so an unusually long digit run in a
+    #   housenumber cannot overflow INTEGER and abort the whole query.
+    rows = db.execute(
+        text(
+            f"""
+ SELECT
+ id,
+ dataset_id,
+ housenumber,
+ street,
+ place,
+ postcode,
+ city,
+ country,
+ unit,
+ name,
+ display_name,
+ search_text,
+ lon,
+ lat,
+ {bbox_rank_sql} AS bbox_rank,
+ {bbox_distance_sql} AS bbox_distance_m,
+ {locality_rank_sql} AS locality_rank,
+ CASE
+ WHEN LOWER(COALESCE(housenumber, '')) = :housenumber_query THEN 0
+ WHEN LOWER(COALESCE(housenumber, '')) LIKE :housenumber_prefix THEN 1
+ ELSE 2
+ END AS match_rank,
+ CASE
+ WHEN :housenumber_number IS NULL THEN 999999
+ WHEN substring(COALESCE(housenumber, '') from '^[0-9]+') IS NULL THEN 999999
+ ELSE abs(CAST(substring(COALESCE(housenumber, '') from '^[0-9]+') AS NUMERIC) - :housenumber_number)
+ END AS house_distance
+ FROM osm_addresses
+ WHERE dataset_id = :dataset_id
+ AND {street_key_sql} <> ''
+ AND ({street_key_sql} = :street_query OR {street_key_sql} LIKE :street_prefix)
+ {locality_filter_sql}
+ ORDER BY locality_rank, bbox_rank, match_rank, house_distance, bbox_distance_m, display_name, id
+ LIMIT :limit
+ """
+        ),
+        params,
+    ).mappings()
+    return [dict(row) for row in rows]
+
+
+def _postgresql_locality_sql(locality_query: str | None, params: dict[str, Any], *, indent: str = "") -> tuple[str, str]:
+ """Build the locality WHERE fragment and rank expression for PostgreSQL queries.
+
+ Args:
+ locality_query: City or postcode text; falsy means "no locality restriction".
+ params: Bind-parameter dict of the calling query. It is mutated:
+ ``locality_query`` and ``locality_prefix`` are added when a locality is given.
+ indent: Text prepended to each line of the filter fragment (cosmetic only).
+
+ Returns:
+ ``(filter_sql, rank_sql)``. Without a locality this is ``("", "0")``.
+ Otherwise the filter keeps rows whose city equals or starts with the
+ locality, or whose postcode equals it; the rank is 0 for exact city,
+ 1 for exact postcode, 2 for city prefix and 3 otherwise.
+ """
+ if not locality_query:
+ return "", "0"
+ params["locality_query"] = locality_query
+ # Prefix pattern for LIKE; the locality text itself is passed as a bind parameter.
+ params["locality_prefix"] = f"{locality_query}%"
+ filter_sql = f"""
+{indent}AND (
+{indent} LOWER(COALESCE(city, '')) = :locality_query
+{indent} OR LOWER(COALESCE(city, '')) LIKE :locality_prefix
+{indent} OR LOWER(COALESCE(postcode, '')) = :locality_query
+{indent})
+ """
+ rank_sql = """
+ CASE
+ WHEN LOWER(COALESCE(city, '')) = :locality_query THEN 0
+ WHEN LOWER(COALESCE(postcode, '')) = :locality_query THEN 1
+ WHEN LOWER(COALESCE(city, '')) LIKE :locality_prefix THEN 2
+ ELSE 3
+ END
+ """
+ return filter_sql, rank_sql
+
+
+def address_by_token(db: Session, value: object) -> OsmAddress:
+    """Load the ``OsmAddress`` referenced by an ``address:<id>`` token.
+
+    Raises:
+        ValueError: if the token is malformed or no such address exists.
+    """
+    record = db.get(OsmAddress, parse_address_token(value))
+    if record is not None:
+        return record
+    raise ValueError("selected address does not exist")
+
+
+def address_point_by_token(db: Session, value: object) -> tuple[OsmAddress, float, float]:
+    """Load the address referenced by an address-point token together with its point.
+
+    Returns:
+        ``(address, lat, lon)`` where lat/lon come from the token itself.
+
+    Raises:
+        ValueError: if the token is malformed or no such address exists.
+    """
+    parsed_id, point_lat, point_lon = parse_address_point_token(value)
+    record = db.get(OsmAddress, parsed_id)
+    if record is not None:
+        return record, point_lat, point_lon
+    raise ValueError("selected address does not exist")
+
+
+def nearest_addresses(
+    db: Session,
+    *,
+    lat: float,
+    lon: float,
+    limit: int = 3,
+    radius_m: float = 150,
+) -> list[dict[str, Any]]:
+    """Return the addresses of the active routing dataset closest to a point.
+
+    Args:
+        db: Open SQLAlchemy session.
+        lat: Latitude of the query point in degrees.
+        lon: Longitude of the query point in degrees.
+        limit: Maximum number of results; clamped to the range 1..25.
+        radius_m: Search radius in metres.
+
+    Returns:
+        Address payloads (see ``_address_payload``) with an added ``distance_m``
+        key, nearest first. Empty when no routing dataset is active.
+    """
+    dataset = active_routing_dataset(db)
+    if dataset is None:
+        return []
+    selected_limit = max(1, min(int(limit), 25))
+    lat_radius_deg = float(radius_m) / 111_320
+    # One degree of longitude shrinks with cos(latitude). Widen the longitude
+    # window accordingly (same 0.2 clamp as address_at_point) so east/west
+    # neighbours inside the radius are not dropped by the bounding-box prefilter.
+    lon_scale = max(0.2, abs(math.cos(math.radians(float(lat)))))
+    lon_radius_deg = lat_radius_deg / lon_scale
+    if not settings.is_postgresql_database:
+        delta_lat = OsmAddress.lat - lat
+        delta_lon = (OsmAddress.lon - lon) * lon_scale
+        rows = db.scalars(
+            select(OsmAddress)
+            .where(
+                OsmAddress.dataset_id == dataset.id,
+                OsmAddress.lat >= lat - lat_radius_deg,
+                OsmAddress.lat <= lat + lat_radius_deg,
+                OsmAddress.lon >= lon - lon_radius_deg,
+                OsmAddress.lon <= lon + lon_radius_deg,
+            )
+            # Nearest-first, so the 250-row cap cannot discard the closest candidates.
+            .order_by(delta_lat * delta_lat + delta_lon * delta_lon, OsmAddress.id)
+            .limit(250)
+        ).all()
+        payloads = []
+        for row in rows:
+            payload = _address_payload(row)
+            payload["distance_m"] = _distance_m(lat, lon, float(row.lat), float(row.lon))
+            # The window above is a square; keep only rows inside the radius.
+            if payload["distance_m"] <= radius_m:
+                payloads.append(payload)
+        payloads.sort(key=lambda item: (float(item.get("distance_m") or 0), item.get("display_name") or ""))
+        return payloads[:selected_limit]
+
+    # PostgreSQL/PostGIS: the && test is only an index-friendly prefilter. It uses
+    # the (larger) longitude radius for the whole box; ST_DWithin on geography then
+    # applies the exact metric radius. Results are ordered by metres, not by the
+    # planar degree distance of the <-> operator.
+    rows = db.execute(
+        text(
+            """
+ WITH point AS (
+ SELECT ST_SetSRID(ST_MakePoint(:lon, :lat), 4326) AS geom
+ )
+ SELECT
+ id,
+ dataset_id,
+ housenumber,
+ street,
+ place,
+ postcode,
+ city,
+ country,
+ unit,
+ name,
+ display_name,
+ search_text,
+ lon,
+ lat,
+ ST_DistanceSphere(osm_addresses.geom, point.geom) AS distance_m
+ FROM osm_addresses
+ CROSS JOIN point
+ WHERE dataset_id = :dataset_id
+ AND osm_addresses.geom IS NOT NULL
+ AND osm_addresses.geom && ST_Expand(point.geom, :radius_deg)
+ AND ST_DWithin(osm_addresses.geom::geography, point.geom::geography, :radius_m)
+ ORDER BY distance_m, id
+ LIMIT :limit
+ """
+        ),
+        {
+            "dataset_id": int(dataset.id),
+            "lat": float(lat),
+            "lon": float(lon),
+            "radius_deg": lon_radius_deg,
+            "radius_m": float(radius_m),
+            "limit": selected_limit,
+        },
+    ).mappings()
+    payloads = []
+    for row in rows:
+        payload = _address_payload(dict(row))
+        payload["distance_m"] = float(row["distance_m"] or 0)
+        payloads.append(payload)
+    return payloads
+
+
+def address_at_point(
+ db: Session,
+ *,
+ lat: float,
+ lon: float,
+ max_size_m: float = 250,
+ node_radius_m: float = 12,
+) -> dict[str, Any] | None:
+ """Find the single address that "owns" a point in the active routing dataset.
+
+ SQLite: picks the ``way`` address whose stored bounding box contains the
+ point, whose box is no wider/taller than ``max_size_m``, and whose box area
+ is smallest.
+
+ PostgreSQL, first non-empty step wins:
+ 1. ``address_polygon`` - a ``way`` address whose ``area_geom`` covers the
+ point (smallest area first).
+ 2. ``address_bbox`` - among the 200 nearest candidates, a ``way`` address
+ whose bounding box contains the point and respects the size limits.
+ 3. ``address_node`` - among the same candidates, the nearest ``node``
+ address within ``node_radius_m`` metres.
+
+ Args:
+ db: Open SQLAlchemy session.
+ lat: Latitude of the point in degrees.
+ lon: Longitude of the point in degrees.
+ max_size_m: Maximum bounding-box extent (metres) accepted for bbox hits.
+ node_radius_m: Maximum distance (metres) accepted for node hits.
+
+ Returns:
+ An address payload with ``distance_m`` and ``selection_reason`` added, or
+ ``None`` when nothing matches or no routing dataset is active.
+ """
+ dataset = active_routing_dataset(db)
+ if dataset is None:
+ return None
+ # Convert the metric size limit into degree spans; the longitude span is widened by
+ # 1 / cos(latitude), with the cosine clamped to at least 0.2.
+ lat_span = float(max_size_m) / 111_320
+ lon_span = float(max_size_m) / (111_320 * max(0.2, abs(math.cos(math.radians(float(lat))))))
+ if not settings.is_postgresql_database:
+ # SQLite path: bounding-box containment only, smallest box first.
+ row = db.scalar(
+ select(OsmAddress)
+ .where(
+ OsmAddress.dataset_id == dataset.id,
+ OsmAddress.osm_type == "way",
+ OsmAddress.min_lon <= lon,
+ OsmAddress.max_lon >= lon,
+ OsmAddress.min_lat <= lat,
+ OsmAddress.max_lat >= lat,
+ (OsmAddress.max_lon - OsmAddress.min_lon) <= lon_span,
+ (OsmAddress.max_lat - OsmAddress.min_lat) <= lat_span,
+ )
+ .order_by((OsmAddress.max_lon - OsmAddress.min_lon) * (OsmAddress.max_lat - OsmAddress.min_lat), OsmAddress.id)
+ )
+ if row is None:
+ return None
+ payload = _address_payload(row)
+ payload["distance_m"] = _distance_m(lat, lon, float(row.lat), float(row.lon))
+ payload["selection_reason"] = "address_bbox"
+ return payload
+
+ # PostgreSQL path: candidates are prefiltered within a square of at least 20 m half-width.
+ candidate_radius_m = max(float(max_size_m), float(node_radius_m), 20.0)
+ candidate_radius_deg = candidate_radius_m / 111_320
+ # The three *_hit CTEs are mutually exclusive via NOT EXISTS, so the final UNION ALL
+ # yields at most one row.
+ row = db.execute(
+ text(
+ """
+ WITH point AS (
+ SELECT ST_SetSRID(ST_MakePoint(:lon, :lat), 4326) AS geom
+ ),
+ polygon_hit AS (
+ SELECT
+ id,
+ dataset_id,
+ housenumber,
+ street,
+ place,
+ postcode,
+ city,
+ country,
+ unit,
+ name,
+ display_name,
+ search_text,
+ lon,
+ lat,
+ ST_DistanceSphere(osm_addresses.geom, point.geom) AS distance_m,
+ 'address_polygon' AS selection_reason
+ FROM osm_addresses
+ CROSS JOIN point
+ WHERE dataset_id = :dataset_id
+ AND osm_type = 'way'
+ AND area_geom IS NOT NULL
+ AND area_geom && point.geom
+ AND ST_Covers(area_geom, point.geom)
+ ORDER BY ST_Area(area_geom::geography), ST_DistanceSphere(osm_addresses.geom, point.geom), id
+ LIMIT 1
+ ),
+ nearby_candidates AS MATERIALIZED (
+ SELECT
+ id,
+ dataset_id,
+ osm_type,
+ housenumber,
+ street,
+ place,
+ postcode,
+ city,
+ country,
+ unit,
+ name,
+ display_name,
+ search_text,
+ lon,
+ lat,
+ min_lon,
+ min_lat,
+ max_lon,
+ max_lat,
+ osm_addresses.geom AS geom,
+ ST_DistanceSphere(osm_addresses.geom, point.geom) AS distance_m
+ FROM osm_addresses
+ CROSS JOIN point
+ WHERE dataset_id = :dataset_id
+ AND osm_addresses.geom IS NOT NULL
+ AND osm_addresses.geom && ST_Expand(point.geom, :candidate_radius_deg)
+ ORDER BY osm_addresses.geom <-> point.geom, id
+ LIMIT 200
+ ),
+ bbox_hit AS (
+ SELECT
+ id,
+ dataset_id,
+ housenumber,
+ street,
+ place,
+ postcode,
+ city,
+ country,
+ unit,
+ name,
+ display_name,
+ search_text,
+ lon,
+ lat,
+ distance_m,
+ 'address_bbox' AS selection_reason
+ FROM nearby_candidates
+ WHERE dataset_id = :dataset_id
+ AND osm_type = 'way'
+ AND min_lon <= :lon
+ AND max_lon >= :lon
+ AND min_lat <= :lat
+ AND max_lat >= :lat
+ AND (max_lon - min_lon) <= :lon_span
+ AND (max_lat - min_lat) <= :lat_span
+ AND NOT EXISTS (SELECT 1 FROM polygon_hit)
+ ORDER BY ABS((max_lon - min_lon) * (max_lat - min_lat)), distance_m, id
+ LIMIT 1
+ ),
+ node_hit AS (
+ SELECT
+ id,
+ dataset_id,
+ housenumber,
+ street,
+ place,
+ postcode,
+ city,
+ country,
+ unit,
+ name,
+ display_name,
+ search_text,
+ lon,
+ lat,
+ distance_m,
+ 'address_node' AS selection_reason
+ FROM nearby_candidates
+ WHERE osm_type = 'node'
+ AND distance_m <= :node_radius_m
+ AND NOT EXISTS (SELECT 1 FROM polygon_hit)
+ AND NOT EXISTS (SELECT 1 FROM bbox_hit)
+ ORDER BY distance_m, id
+ LIMIT 1
+ )
+ SELECT * FROM polygon_hit
+ UNION ALL
+ SELECT * FROM bbox_hit
+ UNION ALL
+ SELECT * FROM node_hit
+ LIMIT 1
+ """
+ ),
+ {
+ "dataset_id": int(dataset.id),
+ "lat": float(lat),
+ "lon": float(lon),
+ "lat_span": lat_span,
+ "lon_span": lon_span,
+ "candidate_radius_deg": candidate_radius_deg,
+ # A negative node radius is treated as 0.
+ "node_radius_m": max(0.0, float(node_radius_m)),
+ },
+ ).mappings().first()
+ if row is None:
+ return None
+ payload = _address_payload(dict(row))
+ payload["distance_m"] = float(row["distance_m"] or 0)
+ payload["selection_reason"] = row["selection_reason"]
+ return payload
+
+
+def _search_addresses_postgresql(
+ db: Session,
+ dataset_id: int,
+ query: str,
+ limit: int,
+ bbox: tuple[float, float, float, float] | None,
+) -> list[dict[str, Any]]:
+ """Generic free-text address search on PostgreSQL.
+
+ A row matches when its lower-cased ``search_text`` is trigram-similar to the
+ query (``%`` operator), contains the whole query, or contains every query
+ token of at least three characters (first six such tokens only). An empty
+ query matches every row of the dataset.
+
+ Returns row dicts ordered by bbox rank, match rank, descending similarity,
+ display name and id. The row limit comes from ``_raw_address_limit``.
+ """
+ params: dict[str, Any] = {"dataset_id": dataset_id, "limit": _raw_address_limit(query, limit)}
+ where = ["dataset_id = :dataset_id"]
+ tokens = [token for token in re.split(r"[\s,;/]+", query) if token]
+ # Short tokens (< 3 chars) are ignored for the all-tokens clause.
+ long_tokens = [token for token in tokens if len(token) >= 3]
+ if query:
+ params["query"] = query
+ params["pattern"] = f"%{query}%"
+ token_clauses = []
+ for index, token in enumerate(long_tokens[:6]):
+ key = f"token_{index}"
+ params[key] = f"%{token}%"
+ token_clauses.append(f"LOWER(COALESCE(search_text, '')) LIKE :{key}")
+ token_sql = " AND ".join(token_clauses)
+ # `%` and similarity() below presumably rely on the pg_trgm extension - confirm it is installed by the schema setup.
+ where.append(
+ "("
+ "LOWER(COALESCE(search_text, '')) % :query "
+ "OR LOWER(COALESCE(search_text, '')) LIKE :pattern "
+ + (f"OR ({token_sql})" if token_sql else "")
+ + ")"
+ )
+ bbox_rank_sql, bbox_distance_sql = _postgresql_bbox_rank_sql(bbox, params)
+ # Match rank: 0 exact display name, 1 display-name prefix, 2 substring of search_text, 3 otherwise; 4 for an empty query.
+ rank_sql = (
+ """
+ CASE
+ WHEN :query = '' THEN 4
+ WHEN LOWER(COALESCE(display_name, '')) = :query THEN 0
+ WHEN LOWER(COALESCE(display_name, '')) LIKE (:query || '%') THEN 1
+ WHEN LOWER(COALESCE(search_text, '')) LIKE :pattern THEN 2
+ ELSE 3
+ END
+ """
+ if query
+ else "4"
+ )
+ if not query:
+ # The SELECT still references :query, so bind placeholders for the empty case.
+ params["query"] = ""
+ params["pattern"] = "%"
+ rows = db.execute(
+ text(
+ f"""
+ SELECT
+ id,
+ dataset_id,
+ housenumber,
+ street,
+ place,
+ postcode,
+ city,
+ country,
+ unit,
+ name,
+ display_name,
+ search_text,
+ lon,
+ lat,
+ {bbox_rank_sql} AS bbox_rank,
+ {bbox_distance_sql} AS bbox_distance_m,
+ {rank_sql} AS match_rank,
+ CASE
+ WHEN :query = '' THEN 0
+ ELSE similarity(LOWER(COALESCE(search_text, '')), :query)
+ END AS similarity_rank
+ FROM osm_addresses
+ WHERE {" AND ".join(where)}
+ ORDER BY bbox_rank, match_rank, similarity_rank DESC, display_name, id
+ LIMIT :limit
+ """
+ ),
+ params,
+ ).mappings()
+ return [dict(row) for row in rows]
+
+
+def _search_addresses_sqlite(
+ db: Session,
+ dataset_id: int,
+ query: str,
+ limit: int,
+ bbox: tuple[float, float, float, float] | None,
+) -> list[OsmAddress]:
+ """Free-text address search for the SQLite backend.
+
+ Every one of the first six query tokens must occur (case-insensitively) in
+ ``search_text``. At most ``MAX_ADDRESS_SEARCH_ROWS`` rows are fetched and then
+ ranked in Python by bbox rank, match rank, display name and id.
+
+ Returns ``OsmAddress`` rows, truncated to ``_raw_address_limit(query, limit)``.
+ """
+ stmt = select(OsmAddress).where(OsmAddress.dataset_id == dataset_id)
+ if query:
+ tokens = [token for token in re.split(r"[\s,;/]+", query) if token]
+ for token in tokens[:6]:
+ stmt = stmt.where(OsmAddress.search_text.ilike(f"%{token}%"))
+ # NOTE(review): the cap is applied without ORDER BY, so which rows survive it is up to the database.
+ stmt = stmt.limit(MAX_ADDRESS_SEARCH_ROWS)
+ rows = list(db.scalars(stmt).all())
+ # NOTE(review): the sort key assumes display_name is never None - confirm against the model.
+ rows.sort(key=lambda row: (_bbox_rank(row.lat, row.lon, bbox), _address_match_rank(row, query), row.display_name, row.id))
+ return rows[: _raw_address_limit(query, limit)]
+
+
+def _postgresql_bbox_rank_sql(
+    bbox: tuple[float, float, float, float] | None,
+    params: dict[str, Any],
+) -> tuple[str, str]:
+    """Build the viewport rank and distance SQL expressions for unaliased columns.
+
+    Args:
+        bbox: Optional ``(min_lon, min_lat, max_lon, max_lat)`` viewport.
+        params: Bind-parameter dict of the calling query; the six ``bbox_*``
+            parameters are added to it when a viewport is given.
+
+    Returns:
+        ``(rank_sql, distance_sql)``. Without a viewport the constants
+        ``("1", "0.0")``. Otherwise the rank is 2 for rows without coordinates,
+        0 inside the viewport and 1 outside; the distance is an approximate
+        metric distance from the viewport centre.
+    """
+    if bbox is None:
+        return "1", "0.0"
+    west, south, east, north = bbox
+    mid_lon = (west + east) / 2
+    mid_lat = (south + north) / 2
+    params["bbox_min_lon"] = west
+    params["bbox_min_lat"] = south
+    params["bbox_max_lon"] = east
+    params["bbox_max_lat"] = north
+    params["bbox_center_lon"] = mid_lon
+    params["bbox_center_lat"] = mid_lat
+    rank_sql = """
+ CASE
+ WHEN lon IS NULL OR lat IS NULL THEN 2
+ WHEN lon BETWEEN :bbox_min_lon AND :bbox_max_lon
+ AND lat BETWEEN :bbox_min_lat AND :bbox_max_lat THEN 0
+ ELSE 1
+ END
+ """
+    distance_sql = """
+ sqrt(
+ power((lon - :bbox_center_lon) * 111320.0 * cos(radians(:bbox_center_lat)), 2)
+ + power((lat - :bbox_center_lat) * 111320.0, 2)
+ )
+ """
+    return rank_sql, distance_sql
+
+
+def _postgresql_bbox_rank_sql_for_alias(
+    alias: str,
+    bbox: tuple[float, float, float, float] | None,
+    params: dict[str, Any],
+) -> tuple[str, str]:
+    """Build the viewport rank and distance SQL expressions for an aliased relation.
+
+    Same contract as ``_postgresql_bbox_rank_sql`` except that the ``lon`` and
+    ``lat`` columns are qualified with ``alias``.
+
+    Args:
+        alias: Table or CTE alias interpolated directly into the SQL text; it
+            must be a trusted identifier, never user input.
+        bbox: Optional ``(min_lon, min_lat, max_lon, max_lat)`` viewport.
+        params: Bind-parameter dict of the calling query; the six ``bbox_*``
+            parameters are added to it when a viewport is given.
+
+    Returns:
+        ``(rank_sql, distance_sql)``, or ``("1", "0.0")`` without a viewport.
+    """
+    if bbox is None:
+        return "1", "0.0"
+    west, south, east, north = bbox
+    mid_lon = (west + east) / 2
+    mid_lat = (south + north) / 2
+    params["bbox_min_lon"] = west
+    params["bbox_min_lat"] = south
+    params["bbox_max_lon"] = east
+    params["bbox_max_lat"] = north
+    params["bbox_center_lon"] = mid_lon
+    params["bbox_center_lat"] = mid_lat
+    rank_sql = f"""
+ CASE
+ WHEN {alias}.lon IS NULL OR {alias}.lat IS NULL THEN 2
+ WHEN {alias}.lon BETWEEN :bbox_min_lon AND :bbox_max_lon
+ AND {alias}.lat BETWEEN :bbox_min_lat AND :bbox_max_lat THEN 0
+ ELSE 1
+ END
+ """
+    distance_sql = f"""
+ sqrt(
+ power(({alias}.lon - :bbox_center_lon) * 111320.0 * cos(radians(:bbox_center_lat)), 2)
+ + power(({alias}.lat - :bbox_center_lat) * 111320.0, 2)
+ )
+ """
+    return rank_sql, distance_sql
+
+
+def _address_payload(row: OsmAddress | dict[str, Any]) -> dict[str, Any]:
+    """Convert one address row (ORM object or mapping) into the API payload dict.
+
+    ``local_name`` is "<street or place> <housenumber>" when any of those parts
+    is present, otherwise the row's ``display_name``. The ``id`` and ``stop_id``
+    fields both carry the ``address:<id>`` token.
+    """
+    if isinstance(row, dict):
+        read = row.get
+    else:
+        def read(key: str, default: Any = None) -> Any:
+            return getattr(row, key, default)
+
+    address_id = int(read("id"))
+    token = address_token(address_id)
+    street = read("street")
+    place = read("place")
+    housenumber = read("housenumber")
+    city = read("city")
+    display_name = read("display_name")
+    label_parts = [str(part) for part in (street or place, housenumber) if part]
+    local_name = " ".join(label_parts).strip() or display_name
+    return {
+        "id": token,
+        "address_id": address_id,
+        "kind": "address",
+        "dataset_id": read("dataset_id"),
+        "stop_id": token,
+        "name": display_name,
+        "display_name": display_name,
+        "city": city,
+        "local_name": local_name,
+        "street": street,
+        "place": place,
+        "housenumber": housenumber,
+        "postcode": read("postcode"),
+        "lat": read("lat"),
+        "lon": read("lon"),
+        "source_id": None,
+        "source_name": "OSM address",
+        "scheduled": False,
+        "grouped": False,
+        "grouped_stop_count": 1,
+        "folded_address_count": 1,
+        "approximate": False,
+    }
+
+
+def _folded_address_payload(row: dict[str, Any]) -> dict[str, Any]:
+    """Convert one grouped street row into an approximate, street-level payload.
+
+    The display name is "<street label>, <postcode city>" (or just the label, or
+    "Address" when no label exists). When the row has coordinates the token is
+    an address-point token, otherwise a plain address token.
+    """
+    address_id = int(row["id"])
+    lat = row.get("lat")
+    lon = row.get("lon")
+    street_label = row.get("street_label") or row.get("street") or row.get("place")
+    locality_parts = [str(part) for part in (row.get("postcode"), row.get("city")) if part]
+    locality = " ".join(locality_parts).strip()
+    if locality:
+        display_name = f"{street_label}, {locality}"
+    else:
+        display_name = str(street_label or "Address")
+    if lat is None or lon is None:
+        token = address_token(address_id)
+    else:
+        token = address_point_token(address_id, float(lat), float(lon))
+    return {
+        "id": token,
+        "address_id": address_id,
+        "representative_address_id": address_id,
+        "kind": "address",
+        "dataset_id": row.get("dataset_id"),
+        "stop_id": token,
+        "name": display_name,
+        "display_name": display_name,
+        "city": row.get("city"),
+        "local_name": str(street_label or display_name),
+        "street": row.get("street") or street_label,
+        "place": row.get("place"),
+        "housenumber": None,
+        "postcode": row.get("postcode"),
+        "lat": lat,
+        "lon": lon,
+        "source_id": None,
+        "source_name": "OSM street address",
+        "scheduled": False,
+        "grouped": False,
+        "grouped_stop_count": 1,
+        "folded_address_count": int(row.get("folded_address_count") or 1),
+        "approximate": True,
+    }
+
+
+def _folded_payload_key(payload: dict[str, Any]) -> tuple[str, str, str]:
+    """Return the case-folded ``(street-ish label, postcode, city)`` dedupe key.
+
+    The label is the first non-empty value of ``street``, ``place`` and
+    ``display_name``; missing values become empty strings.
+    """
+    label = payload.get("street") or payload.get("place") or payload.get("display_name")
+    raw_parts = (label, payload.get("postcode"), payload.get("city"))
+    label_key, postcode_key, city_key = (str(part or "").casefold() for part in raw_parts)
+    return label_key, postcode_key, city_key
+
+
+def _fold_street_payloads(payloads: list[dict[str, Any]]) -> list[dict[str, Any]]:
+ """Collapse per-address payloads into one approximate payload per street.
+
+ Payloads are grouped by case-folded ``(street or place, city, postcode)``.
+ The first payload of a group becomes its representative: the house number is
+ cleared, the names are rebuilt as "<street>, <postcode city>", and the
+ coordinates are replaced by the mean of all group members that have
+ coordinates. Payloads without a street/place are passed through unchanged.
+
+ Returns the folded payloads (in first-seen order) followed by the unfolded ones.
+ """
+ folded: dict[tuple[str, str, str], dict[str, Any]] = {}
+ singles: list[dict[str, Any]] = []
+ for payload in payloads:
+ street = str(payload.get("street") or payload.get("place") or "").casefold().strip()
+ city = str(payload.get("city") or "").casefold().strip()
+ postcode = str(payload.get("postcode") or "").casefold().strip()
+ if not street:
+ singles.append(payload)
+ continue
+ key = (street, city, postcode)
+ current = folded.get(key)
+ if current is None:
+ # First member of the group: copy it and turn the copy into the street-level entry.
+ current = dict(payload)
+ # Temporary bookkeeping key; removed again before the result is returned.
+ current["_representatives"] = [payload]
+ current["folded_address_count"] = 1
+ current["approximate"] = True
+ current["housenumber"] = None
+ local_name = str(payload.get("street") or payload.get("place") or "")
+ # NOTE(review): unlike _folded_address_payload, the parts are joined without str(); assumes postcode/city are strings.
+ locality = " ".join(part for part in [payload.get("postcode"), payload.get("city")] if part)
+ current["local_name"] = local_name
+ current["display_name"] = f"{local_name}, {locality}" if locality else local_name
+ current["name"] = current["display_name"]
+ folded[key] = current
+ continue
+ current["folded_address_count"] = int(current.get("folded_address_count") or 1) + 1
+ current["_representatives"].append(payload)
+
+ result = list(folded.values())
+ for payload in result:
+ representatives = payload.pop("_representatives", [])
+ coords = [
+ (float(item["lat"]), float(item["lon"]))
+ for item in representatives
+ if item.get("lat") is not None and item.get("lon") is not None
+ ]
+ if coords:
+ payload["lat"] = sum(item[0] for item in coords) / len(coords)
+ payload["lon"] = sum(item[1] for item in coords) / len(coords)
+ # NOTE(review): float(payload["lat"]) raises TypeError if no member had coordinates and this line runs
+ # outside the `if coords:` guard - the nesting is not recoverable here; verify against the real file.
+ token = address_point_token(int(payload["address_id"]), float(payload["lat"]), float(payload["lon"]))
+ payload["id"] = token
+ payload["stop_id"] = token
+ payload["representative_address_id"] = payload["address_id"]
+ payload["source_name"] = "OSM street address"
+ result.extend(singles)
+ return result
+
+
+def _address_match_rank(row: OsmAddress, query: str) -> int:
+    """Rank how well an address row matches a normalized query (lower is better).
+
+    Ranks: 0 display name equals the query, 1 display name starts with it,
+    2 the query occurs in ``search_text``, 3 every query token occurs in
+    ``search_text``, 4 no match or empty query.
+
+    Args:
+        row: Object exposing ``display_name`` and ``search_text``.
+        query: Query text, expected to be already case-folded.
+    """
+    if not query:
+        return 4
+    # Treat missing text as empty instead of raising AttributeError, mirroring
+    # the COALESCE(..., '') guards used by the PostgreSQL ranking.
+    haystack = (row.search_text or "").casefold()
+    display = (row.display_name or "").casefold()
+    if display == query:
+        return 0
+    if display.startswith(query):
+        return 1
+    if query in haystack:
+        return 2
+    tokens = [token for token in re.split(r"[\s,;/]+", query) if token]
+    return 3 if tokens and all(token in haystack for token in tokens) else 4
+
+
+def _bbox_rank(lat: float | None, lon: float | None, bbox: tuple[float, float, float, float] | None) -> tuple[int, float]:
+    """Return a sortable ``(rank, distance)`` pair for a point relative to a viewport.
+
+    ``(1, 0.0)`` without a viewport, ``(2, inf)`` without coordinates,
+    ``(0, 0.0)`` inside the viewport, otherwise ``(1, d)`` where ``d`` is the
+    squared degree distance to the viewport centre.
+    """
+    if bbox is None:
+        return (1, 0.0)
+    if lat is None or lon is None:
+        return (2, float("inf"))
+    west, south, east, north = bbox
+    inside = west <= lon <= east and south <= lat <= north
+    if inside:
+        return (0, 0.0)
+    offset_lon = lon - (west + east) / 2
+    offset_lat = lat - (south + north) / 2
+    return (1, offset_lon * offset_lon + offset_lat * offset_lat)
+
+
+def _distance_m(lat_a: float, lon_a: float, lat_b: float, lon_b: float) -> float:
+    """Approximate ground distance in metres between two WGS84 points.
+
+    Uses the equirectangular approximation: the longitude difference is scaled
+    by the cosine of the mean latitude, because a degree of longitude gets
+    shorter away from the equator. Suitable for the short distances this module
+    compares; it is not a geodesic computation.
+    """
+    lat_a, lon_a, lat_b, lon_b = float(lat_a), float(lon_a), float(lat_b), float(lon_b)
+    mean_lat_rad = math.radians((lat_a + lat_b) / 2)
+    delta_lat = lat_b - lat_a
+    delta_lon = (lon_b - lon_a) * math.cos(mean_lat_rad)
+    return math.hypot(delta_lon, delta_lat) * 111_320
+
+
+def _normalize_query(query: str | None) -> str:
+    """Case-fold a query, trim it and collapse whitespace runs to single spaces."""
+    folded = str(query or "").casefold()
+    trimmed = folded.strip()
+    return re.sub(r"\s+", " ", trimmed)
+
+
+def _query_has_number(query: str) -> bool:
+    """Return True when the query contains at least one digit."""
+    digit_match = re.search(r"\d", query or "")
+    return digit_match is not None
+
+
+def _split_numbered_query(query: str) -> tuple[str, str, str | None] | None:
+    """Return the highest-priority ``(street, housenumber, locality)`` split, or None."""
+    ranked = _numbered_query_candidates(query)
+    if ranked:
+        return ranked[0]
+    return None
+
+
+def _numbered_query_candidates(query: str) -> list[tuple[str, str, str | None]]:
+ """Split a numbered query into ordered ``(street, housenumber, locality)`` candidates.
+
+ The order of the returned list is the order in which callers try the
+ candidates, so earlier entries are the preferred interpretations.
+
+ With a comma, the text is split at the first comma:
+ - number only on the left, or number on the left and a locality-looking right
+ part: the left part is the street side and the right part the locality;
+ - otherwise, if the right part has a number: "locality, street number" is tried
+ first, then "street number, locality", without duplicates.
+ Anything else is split by ``_split_locality_query``.
+ """
+ normalized = _normalize_query(query)
+ if "," in normalized:
+ left, right = [part.strip() for part in normalized.split(",", 1)]
+ left_has_number = _query_has_number(left)
+ right_has_number = _query_has_number(right)
+ if left_has_number and not right_has_number:
+ return _numbered_query_candidates_from_parts(left, right)
+ if left_has_number and _looks_like_locality(right):
+ return _numbered_query_candidates_from_parts(left, right)
+ if right_has_number:
+ candidates: list[tuple[str, str, str | None]] = []
+ # Prefer "locality, street number" ...
+ for candidate in _numbered_query_candidates_from_parts(right, left):
+ if candidate not in candidates:
+ candidates.append(candidate)
+ # ... then "street number, locality" (yields nothing when the left part has no number).
+ for candidate in _numbered_query_candidates_from_parts(left, right):
+ if candidate not in candidates:
+ candidates.append(candidate)
+ return candidates
+ # NOTE(review): this passes the raw `query`, not `normalized`; harmless if callers pre-normalize - confirm.
+ street_part, locality_query = _split_locality_query(query)
+ return _numbered_query_candidates_from_parts(street_part, locality_query)
+
+
+def _numbered_query_candidates_from_parts(street_part: str, locality_query: str | None) -> list[tuple[str, str, str | None]]:
+    """Parse ``street_part`` into (street, housenumber, locality) candidates.
+
+    The first standalone number, with an optional single letter suffix, is the
+    house number; the rest of ``street_part`` is the street. Returns an empty
+    list when there is no number or the street is shorter than three
+    characters. When ``locality_query`` is ``None``, every split of the street
+    words into "locality street" (longest locality first) and then
+    "street locality" is added as a further candidate.
+    """
+    number_match = re.search(r"\b(\d+[a-zäöüß]?)\b", street_part or "", flags=re.IGNORECASE)
+    if number_match is None:
+        return []
+    housenumber = number_match.group(1).casefold()
+    # Blank out the matched house number and normalise what is left.
+    without_number = street_part[: number_match.start(1)] + " " + street_part[number_match.end(1):]
+    street = _normalize_query(without_number)
+    if len(street) < 3 or not housenumber:
+        return []
+    results: list[tuple[str, str, str | None]] = []
+    for locality in _locality_candidates(locality_query):
+        _append_numbered_candidate(results, street, housenumber, locality)
+    if locality_query is None:
+        words = [word for word in street.split(" ") if word]
+        for cut in reversed(range(1, len(words))):
+            _append_numbered_candidate(results, " ".join(words[cut:]), housenumber, " ".join(words[:cut]))
+        for cut in range(1, len(words)):
+            _append_numbered_candidate(results, " ".join(words[:cut]), housenumber, " ".join(words[cut:]))
+    return results
+
+
+def _append_numbered_candidate(
+    candidates: list[tuple[str, str, str | None]],
+    street: str,
+    housenumber: str,
+    locality: str | None,
+) -> None:
+    """Append a normalised (street, housenumber, locality) tuple if it is new.
+
+    Nothing is appended when the normalised street is shorter than three
+    characters or the house number is empty. A locality shorter than two
+    characters after normalisation is stored as ``None``. ``candidates`` is
+    mutated in place.
+    """
+    street_key = _normalize_query(street)
+    locality_key = _normalize_query(locality) if locality else None
+    if not housenumber or len(street_key) < 3:
+        return
+    if locality_key is not None and len(locality_key) < 2:
+        locality_key = None
+    entry = (street_key, housenumber, locality_key)
+    if entry in candidates:
+        return
+    candidates.append(entry)
+
+
+def _locality_candidates(locality: str | None) -> list[str | None]:
+    """Expand a locality string into the variants worth searching.
+
+    Returns ``[None]`` for an empty locality. Otherwise the normalised string
+    comes first; if it starts or ends with a 4-5 digit number separated by
+    whitespace, the remaining text and then the number are added as well.
+    """
+    normalized = _normalize_query(locality)
+    if not normalized:
+        return [None]
+    variants: list[str | None] = []
+    _append_locality_candidate(variants, normalized)
+    # (pattern, group order): the text group is appended before the number group.
+    shapes = (
+        (r"^(\d{4,5})\s+(.+)$", (2, 1)),
+        (r"^(.+)\s+(\d{4,5})$", (1, 2)),
+    )
+    for pattern, group_order in shapes:
+        found = re.match(pattern, normalized)
+        if found:
+            for group_index in group_order:
+                _append_locality_candidate(variants, found.group(group_index))
+    return variants
+
+
+def _append_locality_candidate(candidates: list[str | None], value: str | None) -> None:
+    """Append the normalised ``value`` to ``candidates`` unless already present.
+
+    An empty normalised value is stored as ``None``.
+    """
+    entry = _normalize_query(value) or None
+    if entry in candidates:
+        return
+    candidates.append(entry)
+
+
+def _looks_like_locality(value: str) -> bool:
+    """Return True when ``value`` starts with a 4-5 digit number or has no digit."""
+    normalized = _normalize_query(value)
+    if re.match(r"^\d{4,5}(\s+|$)", normalized):
+        return True
+    return not _query_has_number(normalized)
+
+
+def _leading_number(value: str | None) -> int | None:
+    """Return the integer at the start of ``value``, ignoring leading whitespace.
+
+    Returns ``None`` when ``value`` is falsy or does not start with digits.
+    """
+    found = re.match(r"\s*(\d+)", str(value or ""))
+    if found is None:
+        return None
+    return int(found.group(1))
+
+
+def _split_locality_query(query: str) -> tuple[str, str | None]:
+    """Split a "locality, street" query at its first comma.
+
+    Returns ``(street, locality)``. When there is no comma, or either side is
+    shorter than two characters after trimming, the whole normalised query is
+    returned with ``None`` as the locality.
+    """
+    normalized = _normalize_query(query)
+    head, comma, tail = normalized.partition(",")
+    if not comma:
+        return normalized, None
+    locality, remainder = head.strip(), tail.strip()
+    if min(len(locality), len(remainder)) < 2:
+        return normalized, None
+    return remainder, locality
+
+
+def _folded_query_candidates(query: str) -> list[tuple[str, str | None]]:
+    """List (street, locality) readings of a query that has no house number.
+
+    An empty query yields an empty list. If ``_split_locality_query`` finds a
+    locality, only its variants are paired with the street part. Otherwise the
+    whole query comes first with no locality, followed by every split of its
+    words into "locality street" and then "street locality".
+    """
+    normalized = _normalize_query(query)
+    if not normalized:
+        return []
+    street_query, locality_query = _split_locality_query(normalized)
+    found: list[tuple[str, str | None]] = []
+    if locality_query:
+        for locality in _locality_candidates(locality_query):
+            _append_folded_candidate(found, street_query, locality)
+        return found
+
+    found.append((normalized, None))
+    words = [word for word in normalized.split(" ") if word]
+    # Empty for fewer than two words, so a single-word query adds no splits.
+    cut_points = range(1, len(words))
+    for cut in cut_points:
+        _append_folded_candidate(found, " ".join(words[cut:]), " ".join(words[:cut]))
+    for cut in cut_points:
+        _append_folded_candidate(found, " ".join(words[:cut]), " ".join(words[cut:]))
+    return found
+
+
+def _append_folded_candidate(
+    candidates: list[tuple[str, str | None]],
+    street: str,
+    locality: str | None,
+) -> None:
+    """Append (street, locality-variant) pairs that are not yet in ``candidates``.
+
+    Nothing is appended when the normalised street is shorter than three
+    characters. Each variant from ``_locality_candidates`` is paired with the
+    street; a variant shorter than two characters is stored as ``None``.
+    """
+    street_key = _normalize_query(street)
+    if len(street_key) < 3:
+        return
+    for variant in _locality_candidates(locality):
+        usable = variant if variant is None or len(variant) >= 2 else None
+        pair = (street_key, usable)
+        if pair not in candidates:
+            candidates.append(pair)
+
+
+def _street_key_sql() -> str:
+    """Return the SQL expression used as the folded street key.
+
+    The expression takes ``street``, or ``place`` when ``street`` is empty or
+    NULL, lower-cases it and replaces ``ß`` with ``ss``.
+    """
+    street_or_place = "COALESCE(NULLIF(street, ''), NULLIF(place, ''), '')"
+    return f"REPLACE(LOWER({street_or_place}), 'ß', 'ss')"
+
+
+def _raw_address_limit(query: str, limit: int) -> int:
+    """Return how many raw rows to fetch for ``limit`` final results.
+
+    A non-empty query without digits multiplies the limit by 30, any other
+    query by 6. The result is never below ``limit`` and is capped at
+    ``MAX_ADDRESS_SEARCH_ROWS``.
+    """
+    if query and not _query_has_number(query):
+        multiplier = 30
+    else:
+        multiplier = 6
+    wanted = max(limit * multiplier, limit)
+    return min(MAX_ADDRESS_SEARCH_ROWS, wanted)
diff --git a/app/cli.py b/app/cli.py
new file mode 100644
index 0000000..01ec81c
--- /dev/null
+++ b/app/cli.py
@@ -0,0 +1,394 @@
+from __future__ import annotations
+
+import json
+import csv
+from pathlib import Path
+from typing import Optional
+
+import typer
+from sqlalchemy import func, select, text
+
+from app.config import settings
+from app.data_management import dataset_sidecar_paths, prune_inactive_datasets
+from app.db import engine, init_db, reset_db, session_scope
+from app.db_lock import database_write_lock
+from app.feed_discovery import build_gtfs_discovery_manifests, default_generated_dir
+from app.models import (
+ Dataset,
+ GtfsRoute,
+ GtfsShape,
+ GtfsStop,
+ RouteMatch,
+ RoutePattern,
+ Source,
+ SourceCatalogEntry,
+)
+from app.pipeline.matcher import run_route_matching
+from app.pipeline.osm_labeling import relabel_osm_features
+from app.pipeline.osm_pbf import run_osm_pbf_source_staged
+from app.pipeline.run import run_source
+from app.pipeline.gtfs import backfill_gtfs_shapes
+from app.pipeline.route_layer import rebuild_route_layer
+from app.pipeline.sample_data import load_sample_project
+from app.osm_storage import osm_feature_count
+from app.jobs import run_worker_loop
+from app.jobs import create_route_layer_rebuild_job, create_route_matching_job, create_source_import_job
+from app.source_catalog import (
+ default_ingestable_sources_path,
+ default_source_catalog_path,
+ import_ingestable_sources,
+ import_source_catalog,
+ source_catalog_summary,
+)
+
+# Root Typer application; each @cli.command(...) below registers one subcommand on it.
+cli = typer.Typer(help="Mobility Workbench pipeline CLI")
+
+
+@cli.command("init-db")
+def init_db_command() -> None:
+ # Runs init_db() under the CLI write lock and prints a confirmation line.
+ with _write_lock("init-db"):
+ init_db()
+ typer.echo("Database initialized")
+
+
+@cli.command("reset-db")
+def reset_db_command() -> None:
+ # Runs reset_db() under the CLI write lock and prints a confirmation line.
+ # NOTE(review): presumably destructive (drops existing data); there is no
+ # confirmation prompt here. Verify against app.db.reset_db.
+ with _write_lock("reset-db"):
+ reset_db()
+ typer.echo("Database reset")
+
+
+@cli.command("load-sample")
+def load_sample_command() -> None:
+ # Initializes the database, loads the sample project in one session and
+ # prints the loader's result as indented JSON.
+ with _write_lock("load-sample"):
+ init_db()
+ with session_scope() as session:
+ result = load_sample_project(session)
+ typer.echo(json.dumps(result, indent=2))
+
+
+@cli.command("add-source")
+def add_source_command(
+    name: str = typer.Option(..., help="Source name"),
+    kind: str = typer.Option(..., help="gtfs, osm_geojson, osm_pbf, or osm_diff"),
+    url: str = typer.Option(..., help="HTTP URL or local path"),
+    country: Optional[str] = typer.Option(None),
+    license: Optional[str] = typer.Option(None),
+    priority: Optional[str] = typer.Option(None),
+    mode_scope: Optional[str] = typer.Option(None),
+    source_basis: Optional[str] = typer.Option(None),
+    notes: Optional[str] = typer.Option(None),
+) -> None:
+    """Register a new data source and print its id and name as JSON.
+
+    Raises ``typer.BadParameter`` when ``kind`` is not one of gtfs,
+    osm_geojson, osm_pbf or osm_diff.
+    """
+    # Reject a bad kind before asking for the write lock, so a typo fails at
+    # once instead of first waiting for the lock and initializing the database.
+    allowed_kinds = ("gtfs", "osm_geojson", "osm_pbf", "osm_diff")
+    if kind not in allowed_kinds:
+        raise typer.BadParameter("kind must be gtfs, osm_geojson, osm_pbf, or osm_diff")
+    with _write_lock("add-source"):
+        init_db()
+        with session_scope() as session:
+            source = Source(
+                name=name,
+                kind=kind,
+                url=url,
+                country=country,
+                license=license,
+                priority=priority,
+                mode_scope=mode_scope,
+                source_basis=source_basis,
+                notes=notes,
+            )
+            session.add(source)
+            # Flush before reading source.id (presumably assigned by the database
+            # on insert; confirm against the Source model).
+            session.flush()
+            typer.echo(json.dumps({"id": source.id, "name": source.name}, indent=2))
+
+
+@cli.command("run-source")
+def run_source_command(source_id: int) -> None:
+ # Imports one source by id and prints the resulting dataset id and status as JSON.
+ # Raises typer.BadParameter when no Source row has that id.
+ # NOTE(review): unlike most commands in this file, init_db() runs before any
+ # _write_lock is requested. Confirm that is intentional.
+ init_db()
+ with session_scope() as session:
+ source = session.get(Source, source_id)
+ if source is None:
+ raise typer.BadParameter(f"source not found: {source_id}")
+ # Only the kind is read from this first lookup; it selects the import path.
+ source_kind = source.kind
+ if source_kind == "osm_pbf":
+ # OSM PBF sources go to the staged importer, which receives only the id.
+ # This branch returns before the "run-source" write lock below is requested.
+ dataset = run_osm_pbf_source_staged(source_id)
+ typer.echo(json.dumps({"source_id": source_id, "dataset_id": dataset.id, "status": dataset.status, "import_mode": "staged_short_lock"}, indent=2))
+ return
+ with _write_lock("run-source"):
+ with session_scope() as session:
+ # The source is loaded again in the session that performs the import.
+ source = session.get(Source, source_id)
+ if source is None:
+ raise typer.BadParameter(f"source not found: {source_id}")
+ dataset = run_source(session, source)
+ typer.echo(json.dumps({"source_id": source.id, "dataset_id": dataset.id, "status": dataset.status}, indent=2))
+
+
+@cli.command("run-match")
+def run_match_command() -> None:
+ # Runs route matching in one session under the CLI write lock and prints
+ # the matcher's result as indented JSON.
+ with _write_lock("run-match"):
+ init_db()
+ with session_scope() as session:
+ result = run_route_matching(session)
+ typer.echo(json.dumps(result, indent=2))
+
+
+@cli.command("build-route-layer")
+def build_route_layer_command() -> None:
+ # Rebuilds the route layer in one session under the CLI write lock and
+ # prints the rebuild result as indented JSON.
+ with _write_lock("build-route-layer"):
+ init_db()
+ with session_scope() as session:
+ result = rebuild_route_layer(session)
+ typer.echo(json.dumps(result, indent=2))
+
+
+@cli.command("relabel-osm-features")
+def relabel_osm_features_command(
+ dataset_id: Optional[int] = typer.Option(None, help="Only relabel one OSM dataset"),
+ force: bool = typer.Option(False, help="Run even when the recorded dependency signature is current"),
+ chunk_size: int = typer.Option(5000, help="Rows per relabel batch"),
+ rebuild_indexes: bool = typer.Option(True, help="Drop/rebuild affected route-scope indexes around large relabel writes"),
+ build_route_layer: bool = typer.Option(True, help="Rebuild the route layer after relabeling"),
+) -> None:
+ # Relabels OSM features and prints the result as indented JSON.
+ with _write_lock("relabel-osm-features"):
+ init_db()
+ with session_scope() as session:
+ result = relabel_osm_features(
+ session,
+ dataset_id=dataset_id,
+ force=force,
+ chunk_size=chunk_size,
+ rebuild_indexes=rebuild_indexes,
+ )
+ # The route layer is rebuilt only when requested and when the relabel
+ # reported a change or was forced; its result is nested into the output.
+ if build_route_layer and (result["changed"] or force):
+ result["route_layer_result"] = rebuild_route_layer(session)
+ typer.echo(json.dumps(result, indent=2))
+
+
+@cli.command("backfill-gtfs-shapes")
+def backfill_gtfs_shapes_command(dataset_id: Optional[int] = typer.Option(None, help="Only backfill one GTFS dataset")) -> None:
+ # Runs the GTFS shape backfill, optionally for a single dataset, under the
+ # CLI write lock and prints the result as indented JSON.
+ with _write_lock("backfill-gtfs-shapes"):
+ init_db()
+ with session_scope() as session:
+ result = backfill_gtfs_shapes(session, dataset_id=dataset_id)
+ typer.echo(json.dumps(result, indent=2))
+
+
+@cli.command("stats")
+def stats_command() -> None:
+ # Prints row counts as indented JSON. No write lock is requested here.
+ # GTFS routes/stops/shapes and OSM routes are counted for active datasets
+ # only; sources, catalog entries, route patterns and matches are counted
+ # over all rows.
+ init_db()
+ with session_scope() as session:
+ active_dataset_ids = [row[0] for row in session.execute(select(Dataset.id).where(Dataset.is_active.is_(True))).all()]
+ stats = {
+ "sources": session.scalar(select(func.count()).select_from(Source)),
+ "source_catalog_entries": session.scalar(select(func.count()).select_from(SourceCatalogEntry)) or 0,
+ "active_datasets": len(active_dataset_ids),
+ # With no active dataset the GTFS queries are skipped and 0 is reported.
+ "gtfs_routes": session.scalar(select(func.count()).select_from(GtfsRoute).where(GtfsRoute.dataset_id.in_(active_dataset_ids))) if active_dataset_ids else 0,
+ "gtfs_stops": session.scalar(select(func.count()).select_from(GtfsStop).where(GtfsStop.dataset_id.in_(active_dataset_ids))) if active_dataset_ids else 0,
+ "gtfs_shapes": session.scalar(select(func.count()).select_from(GtfsShape).where(GtfsShape.dataset_id.in_(active_dataset_ids))) if active_dataset_ids else 0,
+ "route_patterns": session.scalar(select(func.count()).select_from(RoutePattern)) or 0,
+ # osm_feature_count is called once per active dataset id, whatever its kind.
+ "osm_routes": sum(osm_feature_count(session, dataset_id, kind="route") for dataset_id in active_dataset_ids),
+ # Match totals keyed by RouteMatch.status.
+ "matches": {status: count for status, count in session.execute(select(RouteMatch.status, func.count()).group_by(RouteMatch.status)).all()},
+ }
+ typer.echo(json.dumps(stats, indent=2))
+
+
+@cli.command("import-source-catalog")
+def import_source_catalog_command(
+ csv_path: Path = typer.Option(default_source_catalog_path(), "--csv", help="Source catalog CSV path"),
+ no_update: bool = typer.Option(False, help="Skip rows that already exist"),
+) -> None:
+ # Imports the source catalog CSV and prints the import result, with a
+ # catalog summary added under "summary", as indented JSON.
+ # The --csv default is computed once, when this module is imported.
+ with _write_lock("import-source-catalog"):
+ init_db()
+ with session_scope() as session:
+ # The no_update flag is inverted into the importer's update_existing argument.
+ result = import_source_catalog(session, csv_path, update_existing=not no_update)
+ result["summary"] = source_catalog_summary(session)
+ typer.echo(json.dumps(result, indent=2))
+
+
+@cli.command("import-ingestable-sources")
+def import_ingestable_sources_command(
+ csv_path: Path = typer.Option(default_ingestable_sources_path(), "--csv", help="Ingestable source seed CSV path"),
+ no_update: bool = typer.Option(False, help="Skip sources that already exist"),
+) -> None:
+ # Imports the ingestable-source seed CSV and prints the import result, with
+ # a catalog summary added under "summary", as indented JSON.
+ # The --csv default is computed once, when this module is imported.
+ with _write_lock("import-ingestable-sources"):
+ init_db()
+ with session_scope() as session:
+ # The no_update flag is inverted into the importer's update_existing argument.
+ result = import_ingestable_sources(session, csv_path, update_existing=not no_update)
+ result["summary"] = source_catalog_summary(session)
+ typer.echo(json.dumps(result, indent=2))
+
+
+@cli.command("discover-gtfs-sources")
+def discover_gtfs_sources_command(
+    output_dir: Path = typer.Option(default_generated_dir(), "--output-dir", help="Directory for generated discovery CSVs"),
+    countries: str = typer.Option(
+        ",".join(["DE", "AT", "CH", "NL", "DK", "FR", "BE", "LU", "NO", "SE", "FI", "IE", "GB"]),
+        "--countries",
+        help="Comma-separated country codes, or ALL for every country exposed by the upstream catalogs",
+    ),
+    no_mobility_database: bool = typer.Option(False, help="Skip Mobility Database feeds_v2.csv"),
+    no_acceptance_test_list: bool = typer.Option(False, help="Skip MobilityData validator acceptance-test feed list"),
+    no_ptna: bool = typer.Option(False, help="Skip PTNA GTFS analysis pages"),
+    max_ptna_details: int = typer.Option(80, help="Maximum PTNA detail pages to fetch for license/crosswalk metadata"),
+    test_limit: int = typer.Option(24, help="Rows to write to the focused test-run ingestable CSV"),
+    check_urls: bool = typer.Option(False, help="Run HEAD/range checks for ingestable feed URLs"),
+) -> None:
+    # Builds the GTFS discovery manifests and prints the builder's result as
+    # JSON (non-ASCII kept as is). No database call or write lock is made here.
+    # Split the comma-separated option, dropping blank entries.
+    country_codes = []
+    for part in countries.split(","):
+        code = part.strip()
+        if code:
+            country_codes.append(code)
+    # Each "no_*" flag is inverted into the matching "include_*" argument.
+    manifests = build_gtfs_discovery_manifests(
+        output_dir=output_dir,
+        countries=country_codes,
+        include_mobility_database=not no_mobility_database,
+        include_acceptance_test_list=not no_acceptance_test_list,
+        include_ptna=not no_ptna,
+        max_ptna_details=max_ptna_details,
+        test_limit=test_limit,
+        check_urls=check_urls,
+    )
+    typer.echo(json.dumps(manifests, indent=2, ensure_ascii=False))
+
+
+@cli.command("queue-source-imports-from-csv")
+def queue_source_imports_from_csv_command(
+ csv_path: Path = typer.Option(default_ingestable_sources_path(), "--csv", help="Ingestable source CSV path"),
+ no_update: bool = typer.Option(False, help="Skip sources that already exist instead of updating them"),
+ run_match_at_end: bool = typer.Option(True, help="Queue one route-matching job after all source imports"),
+ build_route_layer_at_end: bool = typer.Option(True, help="Queue one route-layer rebuild after route matching"),
+ priority: int = typer.Option(0, help="Priority for queued source import jobs"),
+) -> None:
+ # Imports sources from the CSV, queues one import job per GTFS source listed
+ # in it, optionally queues one route-matching and one route-layer job, and
+ # prints a JSON summary of everything created.
+ with _write_lock("queue-source-imports-from-csv"):
+ init_db()
+ with session_scope() as session:
+ # Relative paths are resolved against the current working directory.
+ csv_path = csv_path if csv_path.is_absolute() else Path.cwd() / csv_path
+ imported = import_ingestable_sources(session, csv_path, update_existing=not no_update)
+ # Only GTFS sources whose URL appears in the CSV are queued, in id order.
+ source_urls = _source_urls_from_ingestable_csv(csv_path)
+ sources = session.scalars(
+ select(Source)
+ .where(Source.kind == "gtfs", Source.url.in_(source_urls))
+ .order_by(Source.id)
+ ).all()
+ # Per-source jobs are created with matching and route-layer rebuilds
+ # switched off; those run at most once via the two jobs created below.
+ jobs = [
+ create_source_import_job(
+ session,
+ source,
+ run_match=False,
+ build_route_layer=False,
+ priority=priority,
+ )
+ for source in sources
+ ]
+ # NOTE(review): both follow-up jobs get the same priority as the imports;
+ # whether they run after them depends on the queue's ordering. Confirm in app.jobs.
+ route_match_job = create_route_matching_job(session, priority=priority) if run_match_at_end else None
+ route_layer_job = create_route_layer_rebuild_job(session, priority=priority) if build_route_layer_at_end else None
+ typer.echo(
+ json.dumps(
+ {
+ "csv": str(csv_path),
+ "imported": imported,
+ "sources": [{"id": source.id, "name": source.name} for source in sources],
+ "source_import_jobs": [job.id for job in jobs],
+ "route_match_job": None if route_match_job is None else route_match_job.id,
+ "route_layer_job": None if route_layer_job is None else route_layer_job.id,
+ },
+ indent=2,
+ ensure_ascii=False,
+ )
+ )
+
+
+@cli.command("prune-cache")
+def prune_cache_command(dry_run: bool = typer.Option(False, help="Report files without deleting them")) -> None:
+ # Deletes files under the data directory that no Dataset row references,
+ # or only reports them with dry_run, and prints a JSON summary.
+ with _write_lock("prune-cache"):
+ init_db()
+ with session_scope() as session:
+ # Referenced files: every non-empty Dataset.local_path plus every
+ # sidecar path of every dataset, all resolved to absolute paths.
+ referenced = {
+ Path(path).resolve()
+ for path in session.scalars(select(Dataset.local_path)).all()
+ if path
+ }
+ for dataset in session.scalars(select(Dataset)).all():
+ referenced.update(path.resolve() for path in dataset_sidecar_paths(dataset))
+
+ # NOTE(review): "staging" is scanned here, whereas _cache_roots() in
+ # app/data_management.py leaves it out so a running import keeps its
+ # staging DB. Confirm this command is never run during a staged import.
+ roots = [settings.data_dir / "sources", settings.data_dir / "derived", settings.data_dir / "sidecars", settings.data_dir / "staging"]
+ candidates = [
+ path
+ for root in roots
+ if root.exists()
+ for path in root.rglob("*")
+ if path.is_file() and path.resolve() not in referenced
+ ]
+ # Sizes are read before anything is deleted.
+ total_bytes = sum(path.stat().st_size for path in candidates)
+ if not dry_run:
+ for path in candidates:
+ path.unlink()
+ for root in roots:
+ _remove_empty_dirs(root)
+
+ typer.echo(
+ json.dumps(
+ {
+ "dry_run": dry_run,
+ "files": len(candidates),
+ "bytes": total_bytes,
+ "deleted": 0 if dry_run else len(candidates),
+ },
+ indent=2,
+ )
+ )
+
+
+@cli.command("prune-inactive-datasets")
+def prune_inactive_datasets_command(
+ dry_run: bool = typer.Option(False, help="Report inactive normalized datasets without deleting them"),
+) -> None:
+ # Calls prune_inactive_datasets under the CLI write lock and prints its
+ # result as indented JSON. dry_run defaults to False here, so the command
+ # deletes unless --dry-run is passed.
+ with _write_lock("prune-inactive-datasets"):
+ init_db()
+ with session_scope() as session:
+ result = prune_inactive_datasets(session, dry_run=dry_run)
+ typer.echo(json.dumps(result, indent=2))
+
+
+@cli.command("vacuum-db")
+def vacuum_db_command() -> None:
+    """Run VACUUM on the configured database; on SQLite also truncate the WAL file."""
+    with _write_lock("vacuum-db"):
+        init_db()
+        # VACUUM must not run inside a transaction block, hence AUTOCOMMIT.
+        with engine.connect().execution_options(isolation_level="AUTOCOMMIT") as connection:
+            connection.execute(text("VACUUM"))
+            # PRAGMA is SQLite syntax. Settings also accepts PostgreSQL URLs,
+            # where this statement would fail after the VACUUM has completed.
+            if settings.is_sqlite_database:
+                connection.execute(text("PRAGMA wal_checkpoint(TRUNCATE)"))
+    typer.echo("Database vacuumed")
+
+
+@cli.command("worker")
+def worker_command(
+    once: bool = typer.Option(False, help="Process at most one queued job and exit"),
+    max_jobs: Optional[int] = typer.Option(None, help="Process at most this many jobs and exit"),
+    poll_interval: float = typer.Option(2.0, help="Seconds to wait between queue polls"),
+    worker_id: Optional[str] = typer.Option(None, help="Stable worker identifier"),
+) -> None:
+    # Runs the queue worker loop with the given options and prints the value it
+    # returns as indented JSON.
+    summary = run_worker_loop(
+        worker_id=worker_id,
+        poll_interval=poll_interval,
+        max_jobs=max_jobs,
+        once=once,
+    )
+    typer.echo(json.dumps(summary, indent=2))
+
+
+def _remove_empty_dirs(root: Path) -> None:
+    """Remove every empty directory below ``root``, deepest first.
+
+    ``root`` itself is never removed, and a missing ``root`` is a no-op.
+    Directories that cannot be removed (typically because they still contain
+    files) are left in place.
+    """
+    if not root.exists():
+        return
+    directories = [entry for entry in root.rglob("*") if entry.is_dir()]
+    # Deepest paths first, so a parent emptied by removing its children can go too.
+    directories.sort(key=lambda entry: len(entry.parts), reverse=True)
+    for directory in directories:
+        try:
+            directory.rmdir()
+        except OSError:
+            # Best effort: a non-empty or otherwise unremovable directory stays.
+            continue
+
+
+def _write_lock(operation: str):
+    """Return the database write lock for a CLI operation.
+
+    The lock is named ``cli:<operation>`` and uses the CLI-specific timeout
+    from settings.
+    """
+    lock_name = f"cli:{operation}"
+    wait_seconds = settings.database_write_lock_cli_timeout_seconds
+    return database_write_lock(lock_name, timeout=wait_seconds)
+
+
+def _source_urls_from_ingestable_csv(path: Path) -> list[str]:
+    """Return the distinct GTFS feed URLs listed in an ingestable-sources CSV.
+
+    Only rows whose ``kind`` column is ``gtfs`` (case-insensitive, trimmed) are
+    considered. URLs are trimmed, blanks are skipped and first-seen order is
+    kept. The file is read as UTF-8 with an optional byte order mark.
+    """
+    ordered: dict[str, None] = {}
+    with path.open("r", encoding="utf-8-sig", newline="") as handle:
+        for record in csv.DictReader(handle):
+            kind = (record.get("kind") or "").strip().lower()
+            if kind != "gtfs":
+                continue
+            url = (record.get("url") or "").strip()
+            if url:
+                # A dict keeps insertion order and drops repeated URLs.
+                ordered.setdefault(url, None)
+    return list(ordered)
+
+
+# Run the Typer application when this module is executed as a script.
+if __name__ == "__main__":
+ cli()
diff --git a/app/config.py b/app/config.py
new file mode 100644
index 0000000..1325451
--- /dev/null
+++ b/app/config.py
@@ -0,0 +1,74 @@
+from __future__ import annotations
+
+from pathlib import Path
+from pydantic_settings import BaseSettings, SettingsConfigDict
+
+
+class Settings(BaseSettings):
+    """Runtime settings.
+
+    SQLite is the default because this prototype should run immediately.
+    The schema is deliberately plain enough to migrate to PostGIS later.
+
+    Values are read from the process environment and from a UTF-8 ``.env``
+    file (see ``model_config``); the defaults below apply otherwise.
+    """
+
+    database_url: str = "sqlite:///./data/workbench.sqlite"
+    data_dir: Path = Path("./data")
+    # 0 means import all stop_times. Use a positive value only for constrained
+    # demos where full timetable routing is not needed.
+    gtfs_stop_times_import_limit: int = 0
+    # "sidecar_stop_times" keeps the large timetable call table in a per-dataset
+    # SQLite file and stores compact GTFS tables in the main app database.
+    # Set to "main" for the old all-in-one SQLite layout.
+    gtfs_timetable_storage: str = "sidecar_stop_times"
+    gtfs_keep_activation_stage: bool = False
+    # "sidecar_features" keeps extracted OSM transport features in a per-dataset
+    # SQLite file. The main DB materializes only OSM rows that need stable
+    # foreign keys for matches or route-layer output.
+    osm_feature_storage: str = "sidecar_features"
+    osm_sidecar_create_visual_only_stops: bool = False
+    # Large OSM PBF extracts should be reduced to transport objects before the
+    # Python extractor scans them. XML fixtures stay unfiltered by default.
+    osm_pbf_prefilter_enabled: bool = True
+    osm_pbf_prefilter_formats: str = "osm_pbf"
+    osm_pbf_prefilter_script: Path = Path("scripts/osmium_transport_filter.sh")
+    osm_diff_max_sequence_gap: int = 14
+    osm_diff_apply_batch_size: int = 7
+    osm_diff_state_timeout_seconds: float = 30.0
+    sqlite_timeout_seconds: float = 120.0
+    sqlite_busy_timeout_ms: int = 120000
+    # Write-lock wait times; the CLI-specific value is the one _write_lock in
+    # app/cli.py passes to database_write_lock.
+    database_write_lock_timeout_seconds: float = 1.0
+    database_write_lock_cli_timeout_seconds: float = 3600.0
+    queue_worker_autostart: bool = True
+    queue_worker_count: int = 1
+    queue_worker_poll_interval_seconds: float = 2.0
+    queue_job_lease_seconds: int = 7200
+    route_matching_batch_size: int = 100
+    route_layer_osm_route_batch_size: int = 1000
+    route_layer_osm_stop_batch_size: int = 5000
+    # SQLite defaults to sidecar storage. PostgreSQL/PostGIS defaults to main
+    # table storage so indexes, joins, and spatial operators can work over the
+    # full imported datasets.
+    postgres_use_sidecars: bool = False
+    # Keep supervised workers alive across API server restarts. Stale workers are
+    # detected by PID files at the next startup; stale job leases are requeued.
+    queue_worker_stop_on_shutdown: bool = False
+
+    model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8")
+
+    @property
+    def normalized_database_url(self) -> str:
+        """Return ``database_url`` with PostgreSQL URLs pinned to the psycopg driver.
+
+        Both ``postgresql://`` and the ``postgres://`` alias are rewritten to
+        ``postgresql+psycopg://``. Any other URL, including one that already
+        names a driver, is returned unchanged.
+        """
+        # "postgres://" is still issued by some hosting providers, but SQLAlchemy
+        # no longer accepts it as a dialect name, so it is normalised as well.
+        for scheme in ("postgresql://", "postgres://"):
+            if self.database_url.startswith(scheme):
+                return "postgresql+psycopg://" + self.database_url.removeprefix(scheme)
+        return self.database_url
+
+    @property
+    def is_sqlite_database(self) -> bool:
+        """True when the normalised database URL starts with ``sqlite``."""
+        return self.normalized_database_url.startswith("sqlite")
+
+    @property
+    def is_postgresql_database(self) -> bool:
+        """True when the normalised database URL starts with ``postgresql``."""
+        return self.normalized_database_url.startswith("postgresql")
+
+# Module-level singleton. Importing this module also creates the data
+# directory (including parents) if it does not exist yet.
+settings = Settings()
+settings.data_dir.mkdir(parents=True, exist_ok=True)
diff --git a/app/data_management.py b/app/data_management.py
new file mode 100644
index 0000000..bc24c8d
--- /dev/null
+++ b/app/data_management.py
@@ -0,0 +1,327 @@
+from __future__ import annotations
+
+from pathlib import Path
+
+from sqlalchemy import delete, func, or_, select
+from sqlalchemy.orm import Session
+
+from app.config import settings
+from app.gtfs_storage import dataset_sidecar_paths as gtfs_dataset_sidecar_paths, missing_sidecar_paths as gtfs_missing_sidecar_paths, stop_time_count
+from app.models import (
+ CanonicalStopLink,
+ Dataset,
+ GtfsAgency,
+ GtfsCalendar,
+ GtfsCalendarDate,
+ GtfsRoute,
+ GtfsRoutePatternLink,
+ GtfsShape,
+ GtfsStop,
+ GtfsStopTime,
+ GtfsTripRoutePatternLink,
+ GtfsTrip,
+ OsmDiffState,
+ OsmFeature,
+ RouteMatch,
+ RoutePattern,
+ RoutePatternStop,
+ Source,
+ SourceUpdateCheck,
+)
+from app.osm_storage import (
+ dataset_sidecar_paths as osm_dataset_sidecar_paths,
+ missing_sidecar_paths as osm_missing_sidecar_paths,
+ osm_feature_count,
+)
+
+
+def dataset_row_counts(session: Session, dataset_id: int, kind: str) -> dict[str, int]:
+    """Return per-table row counts for one dataset.
+
+    For ``kind == "gtfs"`` the result has one count per GTFS table, the
+    stop_times count, a ``missing_sidecar`` entry, the total number of route
+    matches and a ``match_counts`` mapping keyed by match status. For
+    ``kind == "osm_geojson"`` it has feature counts by kind plus
+    ``missing_sidecar``. Any other kind yields an empty dict.
+
+    NOTE(review): ``match_counts`` is a nested dict, so the declared
+    ``dict[str, int]`` return type is narrower than the actual value.
+    """
+    if kind == "osm_geojson":
+        stop_like_kinds = ["stop", "station", "terminal"]
+        return {
+            "features": _safe_osm_feature_count(session, dataset_id),
+            "routes": _safe_osm_feature_count(session, dataset_id, kind="route"),
+            "stops": _safe_osm_feature_count(session, dataset_id, kind=stop_like_kinds),
+            "infra": _safe_osm_feature_count(session, dataset_id, kind="infra"),
+            "missing_sidecar": _osm_sidecar_missing(session, dataset_id),
+        }
+    if kind != "gtfs":
+        return {}
+
+    # Matches are attributed to the dataset through its GTFS routes.
+    dataset_routes = select(GtfsRoute.id).where(GtfsRoute.dataset_id == dataset_id)
+    status_rows = session.execute(
+        select(RouteMatch.status, func.count())
+        .where(RouteMatch.gtfs_route_id.in_(dataset_routes))
+        .group_by(RouteMatch.status)
+    ).all()
+    match_counts = {status: count for status, count in status_rows}
+
+    counted_tables = (
+        ("agencies", GtfsAgency),
+        ("stops", GtfsStop),
+        ("routes", GtfsRoute),
+        ("trips", GtfsTrip),
+        ("calendars", GtfsCalendar),
+        ("calendar_dates", GtfsCalendarDate),
+        ("shapes", GtfsShape),
+    )
+    result = {}
+    for label, model in counted_tables:
+        result[label] = _count(session, model, dataset_id)
+    result["stop_times"] = stop_time_count(session, dataset_id)
+    result["missing_sidecar"] = _gtfs_sidecar_missing(session, dataset_id)
+    result["matches"] = sum(match_counts.values())
+    result["match_counts"] = match_counts
+    return result
+
+
+def source_row_counts(session: Session, source: Source) -> dict[str, object]:
+    """Aggregate ``dataset_row_counts`` over all datasets of ``source``.
+
+    Returns the number of datasets and active datasets, summed row counts,
+    the number of datasets with a missing sidecar (in total and split by GTFS
+    and OSM) and match counts summed per status.
+    """
+    datasets = source.datasets
+    counts = {
+        "datasets": len(datasets),
+        "active_datasets": sum(1 for dataset in datasets if dataset.is_active),
+        "routes": 0,
+        "stops": 0,
+        "features": 0,
+        "trips": 0,
+        "shapes": 0,
+        "stop_times": 0,
+        "missing_sidecars": 0,
+        "match_counts": {},
+        "missing_gtfs_sidecars": 0,
+        "missing_osm_sidecars": 0,
+    }
+    summed_keys = ("routes", "stops", "features", "trips", "shapes", "stop_times")
+    sidecar_key_by_kind = {
+        "gtfs": "missing_gtfs_sidecars",
+        "osm_geojson": "missing_osm_sidecars",
+    }
+    match_totals: dict[str, int] = {}
+    for dataset in datasets:
+        stats = dataset_row_counts(session, dataset.id, dataset.kind)
+        for key in summed_keys:
+            counts[key] += int(stats.get(key, 0))
+        if stats.get("missing_sidecar"):
+            counts["missing_sidecars"] += 1
+            kind_key = sidecar_key_by_kind.get(dataset.kind)
+            if kind_key is not None:
+                counts[kind_key] += 1
+        for status, count in stats.get("match_counts", {}).items():
+            match_totals[status] = match_totals.get(status, 0) + int(count)
+    counts["match_counts"] = match_totals
+    return counts
+
+
+def delete_dataset(session: Session, dataset_id: int) -> dict[str, object]:
+ # Deletes one dataset together with its rows, its files and the rows that
+ # point at it. Returns {"deleted": False, "reason": ...} when the id is
+ # unknown, otherwise {"deleted": True, ...} with the pre-deletion counts.
+ # The session is flushed, not committed; committing is left to the caller.
+ dataset = session.get(Dataset, dataset_id)
+ if dataset is None:
+ return {"deleted": False, "reason": "dataset not found", "dataset_id": dataset_id}
+
+ # Counts are taken first so the result reports what was removed.
+ counts = dataset_row_counts(session, dataset.id, dataset.kind)
+ _detach_update_checks_for_dataset(session, dataset.id)
+ session.execute(delete(OsmDiffState).where(OsmDiffState.raw_dataset_id == dataset.id))
+ _delete_dataset_rows(session, dataset)
+ # NOTE(review): files are removed before the surrounding transaction has
+ # committed; a later rollback would restore the rows but not the files.
+ _delete_dataset_files(dataset)
+ session.delete(dataset)
+ session.flush()
+ return {"deleted": True, "dataset_id": dataset_id, "counts": counts}
+
+
+def delete_source(session: Session, source_id: int) -> dict[str, object]:
+ # Deletes a source and all of its datasets. Returns {"deleted": False,
+ # "reason": ...} when the id is unknown, otherwise {"deleted": True, ...}
+ # with one entry of pre-deletion counts per dataset.
+ # The session is flushed, not committed; committing is left to the caller.
+ source = session.get(Source, source_id)
+ if source is None:
+ return {"deleted": False, "reason": "source not found", "source_id": source_id}
+
+ # Copy the relationship list, since datasets are deleted while iterating.
+ datasets = list(source.datasets)
+ dataset_results = []
+ for dataset in datasets:
+ # Same per-dataset steps, in the same order, as delete_dataset().
+ dataset_results.append({"dataset_id": dataset.id, "kind": dataset.kind, "counts": dataset_row_counts(session, dataset.id, dataset.kind)})
+ _detach_update_checks_for_dataset(session, dataset.id)
+ session.execute(delete(OsmDiffState).where(OsmDiffState.raw_dataset_id == dataset.id))
+ _delete_dataset_rows(session, dataset)
+ # NOTE(review): files are removed before the surrounding transaction has
+ # committed; a later rollback would restore the rows but not the files.
+ _delete_dataset_files(dataset)
+ session.delete(dataset)
+ # Remove diff state rows tied to the source itself, then the source.
+ session.execute(delete(OsmDiffState).where(OsmDiffState.source_id == source.id))
+ session.delete(source)
+ session.flush()
+ return {"deleted": True, "source_id": source_id, "datasets": dataset_results}
+
+
+def unreferenced_cache_file_summary(session: Session) -> dict[str, int]:
+    """Report how many unreferenced cache files exist and their total size.
+
+    Returns ``{"files": <count>, "bytes": <sum of file sizes>}``. Nothing is
+    deleted.
+    """
+    orphaned = _unreferenced_cache_files(session)
+    total_bytes = 0
+    for path in orphaned:
+        total_bytes += path.stat().st_size
+    return {"files": len(orphaned), "bytes": total_bytes}
+
+
+def prune_unreferenced_cache_files(session: Session) -> dict[str, int]:
+    """Delete unreferenced cache files, then remove directories left empty.
+
+    Returns ``{"files": <count>, "bytes": <sum of file sizes>}`` for the
+    deleted files.
+    """
+    doomed = _unreferenced_cache_files(session)
+    # Sizes have to be read while the files still exist.
+    sizes = [path.stat().st_size for path in doomed]
+    for path in doomed:
+        path.unlink()
+    for root in _cache_roots():
+        _remove_empty_dirs(root)
+    return {"files": len(doomed), "bytes": sum(sizes)}
+
+
+def _unreferenced_cache_files(session: Session) -> list[Path]:
+    """List files under the cache roots that no dataset references.
+
+    A file counts as referenced when its resolved path equals a dataset's
+    ``local_path`` or one of the dataset's sidecar paths.
+    """
+    referenced = set()
+    for local_path in session.scalars(select(Dataset.local_path)).all():
+        if local_path:
+            referenced.add(Path(local_path).resolve())
+    for dataset in session.scalars(select(Dataset)).all():
+        for sidecar in dataset_sidecar_paths(dataset):
+            referenced.add(sidecar.resolve())
+
+    orphans = []
+    for root in _cache_roots():
+        if not root.exists():
+            continue
+        for path in root.rglob("*"):
+            if path.is_file() and path.resolve() not in referenced:
+                orphans.append(path)
+    return orphans
+
+
+def _cache_roots() -> list[Path]:
+    """Return the data sub-directories that automatic cache pruning may scan."""
+    # "staging" is deliberately absent: staging files are not referenced by
+    # datasets until activation, so pruning there could remove the staging DB
+    # of a running import.
+    base = settings.data_dir
+    return [base / name for name in ("sources", "derived", "sidecars")]
+
+
+def prune_inactive_datasets(session: Session, dry_run: bool = True) -> dict[str, object]:
+    """Delete every inactive ``gtfs`` / ``osm_geojson`` dataset and its dependent rows.
+
+    Args:
+        session: Open ORM session. The function flushes but does not commit.
+        dry_run: When true (the default) nothing is deleted and only the row
+            counts are reported.
+
+    Returns:
+        A dict with ``dry_run``, ``dataset_ids``, ``deleted`` and
+        ``would_delete``. In a dry run the counts are reported under
+        ``would_delete`` and ``deleted`` is empty; in a real run the counts
+        are reported under ``deleted`` and ``would_delete`` is empty, also
+        when there was nothing to prune.
+    """
+    dataset_rows = session.execute(
+        select(Dataset.id, Dataset.kind).where(Dataset.is_active.is_(False), Dataset.kind.in_(["gtfs", "osm_geojson"]))
+    ).all()
+    dataset_ids = [int(row[0]) for row in dataset_rows]
+    gtfs_ids = [int(dataset_id) for dataset_id, kind in dataset_rows if kind == "gtfs"]
+    osm_ids = [int(dataset_id) for dataset_id, kind in dataset_rows if kind == "osm_geojson"]
+
+    # Route matches hang off GTFS routes and/or OSM features, so they are
+    # selected through sub-queries rather than through a dataset_id column.
+    route_ids = select(GtfsRoute.id).where(GtfsRoute.dataset_id.in_(gtfs_ids)) if gtfs_ids else None
+    osm_feature_ids = select(OsmFeature.id).where(OsmFeature.dataset_id.in_(osm_ids)) if osm_ids else None
+    match_filters = []
+    if route_ids is not None:
+        match_filters.append(RouteMatch.gtfs_route_id.in_(route_ids))
+    if osm_feature_ids is not None:
+        match_filters.append(RouteMatch.osm_feature_id.in_(osm_feature_ids))
+
+    counts = {
+        "datasets": len(dataset_ids),
+        "gtfs_stop_times": sum(stop_time_count(session, dataset_id) for dataset_id in gtfs_ids),
+        "gtfs_shapes": _count_dataset_rows(session, GtfsShape, gtfs_ids),
+        "gtfs_trips": _count_dataset_rows(session, GtfsTrip, gtfs_ids),
+        "gtfs_calendar_dates": _count_dataset_rows(session, GtfsCalendarDate, gtfs_ids),
+        "gtfs_calendars": _count_dataset_rows(session, GtfsCalendar, gtfs_ids),
+        "gtfs_routes": _count_dataset_rows(session, GtfsRoute, gtfs_ids),
+        "gtfs_stops": _count_dataset_rows(session, GtfsStop, gtfs_ids),
+        "gtfs_agencies": _count_dataset_rows(session, GtfsAgency, gtfs_ids),
+        "osm_features": sum(_safe_osm_feature_count(session, dataset_id) for dataset_id in osm_ids),
+        "missing_osm_sidecars": sum(1 for dataset_id in osm_ids if _osm_sidecar_missing(session, dataset_id)),
+        "gtfs_route_pattern_links": _count_dataset_rows(session, GtfsRoutePatternLink, gtfs_ids),
+        "gtfs_trip_route_pattern_links": _count_dataset_rows(session, GtfsTripRoutePatternLink, gtfs_ids),
+        "canonical_stop_links": _count_dataset_rows(session, CanonicalStopLink, dataset_ids),
+        "route_matches": (session.scalar(select(func.count()).select_from(RouteMatch).where(or_(*match_filters))) or 0) if match_filters else 0,
+    }
+    if dry_run:
+        return {"dry_run": dry_run, "dataset_ids": dataset_ids, "deleted": {}, "would_delete": counts}
+    if not dataset_ids:
+        # Real run with nothing to prune: same result shape as a real run that deleted rows.
+        return {"dry_run": dry_run, "dataset_ids": dataset_ids, "deleted": counts, "would_delete": {}}
+
+    for dataset_id in dataset_ids:
+        _detach_update_checks_for_dataset(session, dataset_id)
+    if match_filters:
+        session.execute(delete(RouteMatch).where(or_(*match_filters)))
+    if gtfs_ids:
+        route_ids = select(GtfsRoute.id).where(GtfsRoute.dataset_id.in_(gtfs_ids))
+        pattern_ids = select(RoutePattern.id).where(RoutePattern.gtfs_route_id.in_(route_ids))
+        session.execute(delete(RoutePatternStop).where(RoutePatternStop.route_pattern_id.in_(pattern_ids)))
+        session.execute(delete(GtfsTripRoutePatternLink).where(GtfsTripRoutePatternLink.dataset_id.in_(gtfs_ids)))
+        session.execute(delete(GtfsRoutePatternLink).where(GtfsRoutePatternLink.dataset_id.in_(gtfs_ids)))
+        session.execute(delete(RoutePattern).where(RoutePattern.gtfs_route_id.in_(route_ids)))
+        session.execute(delete(CanonicalStopLink).where(CanonicalStopLink.dataset_id.in_(gtfs_ids), CanonicalStopLink.object_type == "gtfs_stop"))
+        for model in [GtfsStopTime, GtfsShape, GtfsTrip, GtfsCalendarDate, GtfsCalendar, GtfsRoute, GtfsStop, GtfsAgency]:
+            session.execute(delete(model).where(model.dataset_id.in_(gtfs_ids)))
+    if osm_ids:
+        osm_feature_ids = select(OsmFeature.id).where(OsmFeature.dataset_id.in_(osm_ids))
+        pattern_ids = select(RoutePattern.id).where(RoutePattern.osm_feature_id.in_(osm_feature_ids))
+        session.execute(delete(RoutePatternStop).where(RoutePatternStop.route_pattern_id.in_(pattern_ids)))
+        session.execute(delete(GtfsTripRoutePatternLink).where(GtfsTripRoutePatternLink.route_pattern_id.in_(pattern_ids)))
+        session.execute(delete(GtfsRoutePatternLink).where(GtfsRoutePatternLink.route_pattern_id.in_(pattern_ids)))
+        session.execute(delete(RoutePattern).where(RoutePattern.osm_feature_id.in_(osm_feature_ids)))
+        session.execute(delete(CanonicalStopLink).where(CanonicalStopLink.dataset_id.in_(osm_ids), CanonicalStopLink.object_type == "osm_feature"))
+        session.execute(delete(OsmFeature).where(OsmFeature.dataset_id.in_(osm_ids)))
+    # NOTE(review): sidecar files are unlinked before the surrounding
+    # transaction is committed; a later rollback would keep the rows but not
+    # the files — confirm this is acceptable for callers.
+    for dataset in session.scalars(select(Dataset).where(Dataset.id.in_(dataset_ids))).all():
+        _delete_dataset_files(dataset)
+    session.execute(delete(Dataset).where(Dataset.id.in_(dataset_ids)))
+    session.flush()
+    return {"dry_run": dry_run, "dataset_ids": dataset_ids, "deleted": counts, "would_delete": {}}
+
+
+def _delete_dataset_rows(session: Session, dataset: Dataset) -> None:
+    """Delete the rows that belong to ``dataset``; the dataset row itself is kept.
+
+    Only ``gtfs`` and ``osm_geojson`` datasets are handled, any other kind is
+    a no-op. Statements run in a fixed order: route matches, pattern stops,
+    pattern links, patterns, canonical stop links, then the dataset's own
+    tables.
+    """
+    dataset_id = dataset.id
+    if dataset.kind == "gtfs":
+        routes = select(GtfsRoute.id).where(GtfsRoute.dataset_id == dataset_id)
+        patterns = select(RoutePattern.id).where(RoutePattern.gtfs_route_id.in_(routes))
+        statements = [
+            delete(RouteMatch).where(RouteMatch.gtfs_route_id.in_(routes)),
+            delete(RoutePatternStop).where(RoutePatternStop.route_pattern_id.in_(patterns)),
+            delete(GtfsTripRoutePatternLink).where(GtfsTripRoutePatternLink.dataset_id == dataset_id),
+            delete(GtfsRoutePatternLink).where(GtfsRoutePatternLink.dataset_id == dataset_id),
+            delete(RoutePattern).where(RoutePattern.gtfs_route_id.in_(routes)),
+            delete(CanonicalStopLink).where(CanonicalStopLink.dataset_id == dataset_id, CanonicalStopLink.object_type == "gtfs_stop"),
+        ]
+        gtfs_models = (GtfsStopTime, GtfsShape, GtfsTrip, GtfsCalendarDate, GtfsCalendar, GtfsRoute, GtfsStop, GtfsAgency)
+        statements.extend(delete(model).where(model.dataset_id == dataset_id) for model in gtfs_models)
+    elif dataset.kind == "osm_geojson":
+        features = select(OsmFeature.id).where(OsmFeature.dataset_id == dataset_id)
+        patterns = select(RoutePattern.id).where(RoutePattern.osm_feature_id.in_(features))
+        statements = [
+            delete(RouteMatch).where(RouteMatch.osm_feature_id.in_(features)),
+            delete(RoutePatternStop).where(RoutePatternStop.route_pattern_id.in_(patterns)),
+            delete(GtfsTripRoutePatternLink).where(GtfsTripRoutePatternLink.route_pattern_id.in_(patterns)),
+            delete(GtfsRoutePatternLink).where(GtfsRoutePatternLink.route_pattern_id.in_(patterns)),
+            delete(RoutePattern).where(RoutePattern.osm_feature_id.in_(features)),
+            delete(CanonicalStopLink).where(CanonicalStopLink.dataset_id == dataset_id, CanonicalStopLink.object_type == "osm_feature"),
+            delete(OsmFeature).where(OsmFeature.dataset_id == dataset_id),
+        ]
+    else:
+        return
+    for statement in statements:
+        session.execute(statement)
+
+
+def _delete_dataset_files(dataset: Dataset) -> None:
+    """Unlink every sidecar file of ``dataset``; files that are already gone are ignored."""
+    for sidecar in dataset_sidecar_paths(dataset):
+        sidecar.unlink(missing_ok=True)
+
+
+def dataset_sidecar_paths(dataset: Dataset) -> list[Path]:
+    """Return the GTFS sidecar paths followed by the OSM sidecar paths of ``dataset``."""
+    paths = list(gtfs_dataset_sidecar_paths(dataset))
+    paths.extend(osm_dataset_sidecar_paths(dataset))
+    return paths
+
+
+def _gtfs_sidecar_missing(session: Session, dataset_id: int) -> bool:
+    """Tell whether ``gtfs_missing_sidecar_paths`` reports anything for the dataset."""
+    missing_paths = gtfs_missing_sidecar_paths(session.get(Dataset, dataset_id))
+    return True if missing_paths else False
+
+
+def _safe_osm_feature_count(session: Session, dataset_id: int, *, kind=None) -> int:
+    """Count OSM features via ``osm_feature_count``; a ``FileNotFoundError`` counts as zero."""
+    try:
+        total = osm_feature_count(session, dataset_id, kind=kind)
+    except FileNotFoundError:
+        total = 0
+    return total
+
+
+def _osm_sidecar_missing(session: Session, dataset_id: int) -> bool:
+    """Tell whether ``osm_missing_sidecar_paths`` reports anything for the dataset."""
+    missing_paths = osm_missing_sidecar_paths(session.get(Dataset, dataset_id))
+    return True if missing_paths else False
+
+
+def _detach_update_checks_for_dataset(session: Session, dataset_id: int) -> None:
+    """Clear ``active_dataset_id`` on every update check that points at ``dataset_id``."""
+    pointing_at_dataset = select(SourceUpdateCheck).where(SourceUpdateCheck.active_dataset_id == dataset_id)
+    for update_check in session.scalars(pointing_at_dataset).all():
+        update_check.active_dataset_id = None
+
+
+def _count(session: Session, model, dataset_id: int) -> int:
+    """Return the number of ``model`` rows whose ``dataset_id`` equals ``dataset_id``."""
+    statement = select(func.count()).select_from(model).where(model.dataset_id == dataset_id)
+    total = session.scalar(statement)
+    return total if total else 0
+
+
+def _count_where(session: Session, model, dataset_id: int, *where) -> int:
+    """Return the number of ``model`` rows of one dataset that also satisfy ``where``."""
+    conditions = (model.dataset_id == dataset_id, *where)
+    total = session.scalar(select(func.count()).select_from(model).where(*conditions))
+    return total if total else 0
+
+
+def _count_dataset_rows(session: Session, model, dataset_ids: list[int]) -> int:
+    """Return the number of ``model`` rows in any of ``dataset_ids`` (0 for an empty list)."""
+    if dataset_ids:
+        statement = select(func.count()).select_from(model).where(model.dataset_id.in_(dataset_ids))
+        return session.scalar(statement) or 0
+    return 0
+
+
+def _remove_empty_dirs(root: Path) -> None:
+    """Remove empty directories below ``root``, deepest first; ``root`` itself is kept.
+
+    Does nothing when ``root`` does not exist.
+    """
+    if not root.exists():
+        return
+    directories = [entry for entry in root.rglob("*") if entry.is_dir()]
+    directories.sort(key=lambda entry: len(entry.parts), reverse=True)
+    for directory in directories:
+        try:
+            directory.rmdir()
+        except OSError:
+            # rmdir() failed (for example the directory is not empty): leave it in place.
+            continue
diff --git a/app/dataset_search.py b/app/dataset_search.py
new file mode 100644
index 0000000..ce2967d
--- /dev/null
+++ b/app/dataset_search.py
@@ -0,0 +1,252 @@
+from __future__ import annotations
+
+from sqlalchemy import func, or_, select
+from sqlalchemy.orm import Session
+
+from app.gtfs_storage import execute_sidecar_query, uses_sidecar_stop_times
+from app.models import Dataset, GtfsRoute, GtfsShape, GtfsStopTime, GtfsTrip, OsmFeature, RoutePattern, Source
+from app.osm_storage import osm_feature_public_id, query_osm_features
+from app.pipeline.utils import norm_ref
+
+
+def search_datasets(session: Session, query: str, *, active_only: bool = False, limit: int = 80) -> dict:
+    """Search GTFS routes, OSM routes and route patterns for ``query``.
+
+    The query is stripped first; an empty query yields empty result lists and
+    an empty ``totals`` mapping. ``limit`` is clamped to 1..250 and applied to
+    each result group separately. ``active_only`` is forwarded to the GTFS and
+    OSM lookups only.
+    """
+    needle = (query or "").strip()
+    if not needle:
+        return {"query": needle, "gtfs_routes": [], "osm_routes": [], "route_patterns": [], "totals": {}}
+
+    row_cap = max(1, min(limit, 250))
+    groups = {
+        "gtfs_routes": _gtfs_route_hits(session, needle, active_only=active_only, limit=row_cap),
+        "osm_routes": _osm_route_hits(session, needle, active_only=active_only, limit=row_cap),
+        "route_patterns": _route_pattern_hits(session, needle, limit=row_cap),
+    }
+    response = {"query": needle}
+    response.update(groups)
+    response["totals"] = {group: len(hits) for group, hits in groups.items()}
+    return response
+
+
+def _gtfs_route_hits(session: Session, query: str, *, active_only: bool, limit: int) -> list[dict]:
+    """Find GTFS routes whose short name, route id or long name contains ``query``.
+
+    A route also matches when its ``route_key`` equals ``norm_ref(query)``,
+    but only if that normalized reference is non-empty; an empty reference is
+    not used as a filter so it cannot match routes with a blank key.
+
+    Args:
+        session: Open ORM session.
+        query: Search text, used as a case-insensitive substring pattern.
+        active_only: Restrict the search to active datasets.
+        limit: Maximum number of routes returned.
+
+    Returns:
+        One payload dict per route with source, dataset, route, geometry and
+        timetable counts (trips, stop times, shapes).
+    """
+    pattern = f"%{query}%"
+    ref = norm_ref(query)
+    conditions = [
+        GtfsRoute.short_name.ilike(pattern),
+        GtfsRoute.route_id.ilike(pattern),
+        GtfsRoute.long_name.ilike(pattern),
+    ]
+    if ref:
+        # Same guard as the OSM and route-pattern searches in this module.
+        conditions.append(GtfsRoute.route_key == ref)
+    stmt = (
+        select(GtfsRoute, Dataset, Source)
+        .join(Dataset, Dataset.id == GtfsRoute.dataset_id)
+        .join(Source, Source.id == Dataset.source_id)
+        .where(or_(*conditions))
+    )
+    if active_only:
+        stmt = stmt.where(Dataset.is_active.is_(True))
+    stmt = stmt.order_by(Dataset.is_active.desc(), Source.name, GtfsRoute.short_name, GtfsRoute.route_id).limit(limit)
+    rows = session.execute(stmt).all()
+
+    # The per-route counters are fetched in bulk for the routes on this page only.
+    route_ids = [route.id for route, _, _ in rows]
+    trip_counts = _trip_counts(session, route_ids)
+    stop_time_counts = _stop_time_counts(session, route_ids)
+    shape_counts = _shape_counts(session, route_ids)
+    return [
+        {
+            "type": "gtfs_route",
+            "source": _source_payload(source),
+            "dataset": _dataset_payload(dataset),
+            "route": {
+                "id": route.id,
+                "route_id": route.route_id,
+                "ref": route.short_name,
+                "name": route.long_name,
+                "mode": route.mode,
+                "operator": route.operator_name,
+            },
+            "geometry": _geometry_payload(route),
+            "timetable": {
+                "trips": trip_counts.get(route.id, 0),
+                "stop_times": stop_time_counts.get(route.id, 0),
+                "shapes": shape_counts.get(route.id, 0),
+            },
+        }
+        for route, dataset, source in rows
+    ]
+
+
+def _osm_route_hits(session: Session, query: str, *, active_only: bool, limit: int) -> list[dict]:
+    """Find OSM route features matching ``query`` across the ``osm_geojson`` datasets.
+
+    Two lookups are merged by ``(dataset_id, osm_type, osm_id)``: a text
+    search and, when ``norm_ref(query)`` is non-empty, a route-key lookup.
+    Hits are ordered by active dataset first, then source name, ref, name and
+    id, and cut to ``limit``. Features whose dataset or source cannot be
+    resolved are left out of the result.
+    """
+    ref = norm_ref(query)
+    dataset_query = select(Dataset).where(Dataset.kind == "osm_geojson")
+    if active_only:
+        dataset_query = dataset_query.where(Dataset.is_active.is_(True))
+    datasets = session.scalars(dataset_query.order_by(Dataset.is_active.desc(), Dataset.id)).all()
+    if not datasets:
+        return []
+
+    dataset_ids = [dataset.id for dataset in datasets]
+    source_ids = [dataset.source_id for dataset in datasets]
+    sources = {source.id: source for source in session.scalars(select(Source).where(Source.id.in_(source_ids))).all()}
+    dataset_by_id = {dataset.id: dataset for dataset in datasets}
+
+    merged: dict[tuple[int, str, str], OsmFeature] = {}
+    for hit in query_osm_features(session, dataset_ids, kinds=["route"], search=query, limit=limit):
+        merged[(hit.dataset_id, hit.osm_type, hit.osm_id)] = hit
+    if ref:
+        for hit in query_osm_features(session, dataset_ids, kinds=["route"], route_key=ref, limit=limit):
+            merged[(hit.dataset_id, hit.osm_type, hit.osm_id)] = hit
+
+    def sort_key(feature):
+        owner = dataset_by_id.get(feature.dataset_id)
+        origin = sources.get(owner.source_id) if owner else None
+        return (
+            0 if owner and owner.is_active else 1,
+            origin.name if origin else "",
+            feature.ref or "",
+            feature.name or "",
+            feature.id or 0,
+        )
+
+    payloads = []
+    for feature in sorted(merged.values(), key=sort_key)[:limit]:
+        dataset = dataset_by_id.get(feature.dataset_id)
+        if dataset is None:
+            continue
+        source = sources.get(dataset.source_id)
+        if source is None:
+            continue
+        payloads.append(
+            {
+                "type": "osm_route",
+                "source": _source_payload(source),
+                "dataset": _dataset_payload(dataset),
+                "osm": {
+                    "id": osm_feature_public_id(feature),
+                    "osm_type": feature.osm_type,
+                    "osm_id": feature.osm_id,
+                    "ref": feature.ref,
+                    "name": feature.name,
+                    "mode": feature.mode,
+                    "route_scope": feature.route_scope,
+                    "operator": feature.operator,
+                    "network": feature.network,
+                },
+                "geometry": _geometry_payload(feature),
+            }
+        )
+    return payloads
+
+
+def _route_pattern_hits(session: Session, query: str, *, limit: int) -> list[dict]:
+    """Find route patterns whose ref, name or pattern key contains ``query``.
+
+    The SQL query is limited to ``limit`` rows. When ``norm_ref(query)`` is
+    non-empty the rows are filtered again in Python: a row is kept only if
+    its normalized ref (or name) equals that reference or its name contains
+    the query case-insensitively, so fewer than ``limit`` rows may come back.
+    """
+    like = f"%{query}%"
+    ref = norm_ref(query)
+    lowered_query = query.lower()
+    stmt = (
+        select(RoutePattern)
+        .where(
+            or_(
+                RoutePattern.route_ref.ilike(like),
+                RoutePattern.route_name.ilike(like),
+                RoutePattern.pattern_key.ilike(like),
+            )
+        )
+        .order_by(RoutePattern.source_kind, RoutePattern.route_ref, RoutePattern.id)
+        .limit(limit)
+    )
+    hits = []
+    for candidate in session.scalars(stmt).all():
+        if ref:
+            same_ref = norm_ref(candidate.route_ref or candidate.route_name or "") == ref
+            if not same_ref and lowered_query not in (candidate.route_name or "").lower():
+                continue
+        hits.append(
+            {
+                "type": "route_pattern",
+                "id": candidate.id,
+                "ref": candidate.route_ref,
+                "name": candidate.route_name,
+                "mode": candidate.mode,
+                "route_scope": candidate.route_scope,
+                "source_kind": candidate.source_kind,
+                "status": candidate.status,
+                "confidence": candidate.confidence,
+                "gtfs_route_id": candidate.gtfs_route_id,
+                "osm_feature_id": candidate.osm_feature_id,
+                "geometry": _geometry_payload(candidate),
+            }
+        )
+    return hits
+
+
+def _trip_counts(session: Session, route_row_ids: list[int]) -> dict[int, int]:
+    """Map ``GtfsRoute.id`` to its number of trips; routes without trips are absent."""
+    if not route_row_ids:
+        return {}
+    same_route = (GtfsTrip.dataset_id == GtfsRoute.dataset_id) & (GtfsTrip.route_id == GtfsRoute.route_id)
+    statement = (
+        select(GtfsRoute.id, func.count(GtfsTrip.id))
+        .join(GtfsTrip, same_route)
+        .where(GtfsRoute.id.in_(route_row_ids))
+        .group_by(GtfsRoute.id)
+    )
+    totals: dict[int, int] = {}
+    for route_pk, trip_total in session.execute(statement).all():
+        totals[int(route_pk)] = int(trip_total)
+    return totals
+
+
+def _stop_time_counts(session: Session, route_row_ids: list[int]) -> dict[int, int]:
+ """Map ``GtfsRoute.id`` to the number of stop times on the route's trips.
+
+ Routes whose dataset reports ``uses_sidecar_stop_times`` are counted with
+ one sidecar query per route; all other routes are counted with a single
+ grouped query. Routes from the grouped query that have no stop times are
+ absent from the result.
+ """
+ if not route_row_ids:
+ return {}
+ routes = session.scalars(select(GtfsRoute).where(GtfsRoute.id.in_(route_row_ids))).all()
+ # Split the routes by where their stop times are stored.
+ sidecar_routes = [route for route in routes if uses_sidecar_stop_times(session, route.dataset_id)]
+ sidecar_route_ids = {route.id for route in sidecar_routes}
+ main_route_ids = [route.id for route in routes if route.id not in sidecar_route_ids]
+ counts: dict[int, int] = {}
+ if main_route_ids:
+ rows = session.execute(
+ select(GtfsRoute.id, func.count(GtfsStopTime.id))
+ .join(GtfsTrip, (GtfsTrip.dataset_id == GtfsRoute.dataset_id) & (GtfsTrip.route_id == GtfsRoute.route_id))
+ .join(GtfsStopTime, (GtfsStopTime.dataset_id == GtfsTrip.dataset_id) & (GtfsStopTime.trip_id == GtfsTrip.trip_id))
+ .where(GtfsRoute.id.in_(main_route_ids))
+ .group_by(GtfsRoute.id)
+ ).all()
+ counts.update({int(route_id): int(count) for route_id, count in rows})
+ # One sidecar query per route, keyed by the GTFS route_id string.
+ for route in sidecar_routes:
+ rows = execute_sidecar_query(
+ session,
+ route.dataset_id,
+ """
+ SELECT COUNT(*) AS count
+ FROM gtfs_stop_times AS stop_times
+ JOIN gtfs_trips AS trips
+ ON trips.trip_id = stop_times.trip_id
+ WHERE trips.route_id = ?
+ """,
+ [route.route_id],
+ )
+ # An empty result set or a NULL count is reported as 0.
+ counts[int(route.id)] = int(rows[0]["count"] or 0) if rows else 0
+ return counts
+
+
+def _shape_counts(session: Session, route_row_ids: list[int]) -> dict[int, int]:
+    """Map ``GtfsRoute.id`` to the number of distinct shape ids used by its trips."""
+    if not route_row_ids:
+        return {}
+    same_route = (GtfsTrip.dataset_id == GtfsRoute.dataset_id) & (GtfsTrip.route_id == GtfsRoute.route_id)
+    same_shape = (GtfsShape.dataset_id == GtfsTrip.dataset_id) & (GtfsShape.shape_id == GtfsTrip.shape_id)
+    statement = (
+        select(GtfsRoute.id, func.count(func.distinct(GtfsShape.shape_id)))
+        .join(GtfsTrip, same_route)
+        .join(GtfsShape, same_shape)
+        .where(GtfsRoute.id.in_(route_row_ids))
+        .group_by(GtfsRoute.id)
+    )
+    totals: dict[int, int] = {}
+    for route_pk, shape_total in session.execute(statement).all():
+        totals[int(route_pk)] = int(shape_total)
+    return totals
+
+
+def _source_payload(source: Source) -> dict:
+    """Return the ``id``, ``name``, ``kind`` and ``country`` of ``source`` as a dict."""
+    return {field: getattr(source, field) for field in ("id", "name", "kind", "country")}
+
+
+def _dataset_payload(dataset: Dataset) -> dict:
+    """Return the public fields of ``dataset``; ``created_at`` is ISO-formatted or ``None``."""
+    payload = {field: getattr(dataset, field) for field in ("id", "kind", "is_active", "status")}
+    created_at = dataset.created_at
+    payload["created_at"] = created_at.isoformat() if created_at else None
+    payload["sha256"] = dataset.sha256
+    return payload
+
+
+def _geometry_payload(row) -> dict:
+    """Summarize the geometry of ``row``.
+
+    ``present`` is true when ``row.geometry_geojson`` is truthy. ``bbox`` holds
+    the four bounding-box attributes, or ``None`` if any of them is missing
+    or ``None``.
+    """
+    corners = {name: getattr(row, name, None) for name in ("min_lon", "min_lat", "max_lon", "max_lat")}
+    has_bbox = not any(value is None for value in corners.values())
+    return {
+        "present": bool(getattr(row, "geometry_geojson", None)),
+        "bbox": corners if has_bbox else None,
+    }
diff --git a/app/db.py b/app/db.py
new file mode 100644
index 0000000..a5abafa
--- /dev/null
+++ b/app/db.py
@@ -0,0 +1,339 @@
+from __future__ import annotations
+
+from contextlib import contextmanager
+from pathlib import Path
+import re
+from typing import Iterator
+
+from sqlalchemy import create_engine
+from sqlalchemy import event
+from sqlalchemy import text
+from sqlalchemy.engine import Connection
+from sqlalchemy.orm import DeclarativeBase, Session, sessionmaker
+
+from app.config import settings
+
+
+class Base(DeclarativeBase):
+    """Declarative base class; its ``metadata`` is used to create and drop the tables."""
+
+
+def _connect_args() -> dict:
+    """Return the DBAPI connect arguments: SQLite options for SQLite, otherwise none."""
+    if not settings.is_sqlite_database:
+        return {}
+    return {"check_same_thread": False, "timeout": settings.sqlite_timeout_seconds}
+
+
+def _ensure_sqlite_parent() -> None:
+    """Create the parent directory of the SQLite database file if needed.
+
+    Does nothing for non-SQLite databases, an empty path or ``:memory:``.
+    """
+    if not settings.is_sqlite_database:
+        return
+    # sqlite:///./data/workbench.sqlite -> ./data/workbench.sqlite
+    db_path = settings.normalized_database_url.replace("sqlite:///", "", 1)
+    if not db_path or db_path == ":memory:":
+        return
+    Path(db_path).parent.mkdir(parents=True, exist_ok=True)
+
+
+# Import-time setup: make sure the SQLite directory exists, then build the
+# module-level engine and session factory.
+_ensure_sqlite_parent()
+engine = create_engine(settings.normalized_database_url, connect_args=_connect_args(), pool_pre_ping=True, future=True)
+SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False, expire_on_commit=False, future=True)
+
+# Captures the index name from a CREATE [UNIQUE] INDEX [CONCURRENTLY]
+# [IF NOT EXISTS] <name> statement; only unquoted identifiers are matched.
+_CREATE_INDEX_NAME_RE = re.compile(
+ r"CREATE\s+(?:UNIQUE\s+)?INDEX\s+(?:CONCURRENTLY\s+)?(?:IF\s+NOT\s+EXISTS\s+)?([A-Za-z_][A-Za-z0-9_]*)",
+ re.IGNORECASE,
+)
+
+
+if settings.is_sqlite_database:
+    @event.listens_for(engine, "connect")
+    def _set_sqlite_pragmas(dbapi_connection, _connection_record) -> None:
+        """Apply the SQLite PRAGMAs on every new DBAPI connection of the engine."""
+        pragmas = (
+            "PRAGMA journal_mode=WAL",
+            f"PRAGMA busy_timeout={int(settings.sqlite_busy_timeout_ms)}",
+            "PRAGMA synchronous=NORMAL",
+            "PRAGMA foreign_keys=ON",
+        )
+        cursor = dbapi_connection.cursor()
+        try:
+            for pragma in pragmas:
+                cursor.execute(pragma)
+        finally:
+            cursor.close()
+
+
+def init_db() -> None:
+    """Ensure extensions, tables, runtime columns and runtime indexes exist."""
+    # Importing the models module populates ``Base.metadata``.
+    from app import models  # noqa: F401
+
+    _ensure_database_extensions()
+    Base.metadata.create_all(bind=engine)
+    for ensure_step in (_ensure_runtime_columns, _ensure_runtime_indexes):
+        ensure_step()
+
+
+def reset_db() -> None:
+    """Drop all tables of ``Base.metadata`` and rebuild the schema from scratch."""
+    # Importing the models module populates ``Base.metadata``.
+    from app import models  # noqa: F401
+
+    _ensure_database_extensions()
+    metadata = Base.metadata
+    metadata.drop_all(bind=engine)
+    metadata.create_all(bind=engine)
+    for ensure_step in (_ensure_runtime_columns, _ensure_runtime_indexes):
+        ensure_step()
+
+
+def _ensure_database_extensions() -> None:
+    """Create the PostgreSQL extensions used by the app; no-op on other databases.
+
+    ``postgis`` and ``pg_trgm`` are always requested; ``pgrouting`` only when
+    ``pg_available_extensions`` lists it.
+    """
+    if not settings.is_postgresql_database:
+        return
+    required = (
+        "CREATE EXTENSION IF NOT EXISTS postgis",
+        "CREATE EXTENSION IF NOT EXISTS pg_trgm",
+    )
+    pgrouting_probe = "SELECT EXISTS (SELECT 1 FROM pg_available_extensions WHERE name = 'pgrouting')"
+    with engine.begin() as conn:
+        for statement in required:
+            conn.execute(text(statement))
+        if conn.execute(text(pgrouting_probe)).scalar():
+            conn.execute(text("CREATE EXTENSION IF NOT EXISTS pgrouting"))
+
+
+def _ensure_runtime_columns() -> None:
+    """Add columns that an existing database may still be missing.
+
+    PostgreSQL is delegated to ``_ensure_postgresql_runtime_columns``. On
+    SQLite each listed table is inspected with ``PRAGMA table_info`` and
+    every absent column is added with ``ALTER TABLE ... ADD COLUMN``. Other
+    databases are left untouched.
+    """
+    if settings.is_postgresql_database:
+        _ensure_postgresql_runtime_columns()
+        return
+    if not settings.is_sqlite_database:
+        return
+
+    # (table, {column: SQL type}) in the order the checks are performed.
+    wanted_columns = [
+        ("gtfs_stop_times", {"arrival_seconds": "INTEGER", "departure_seconds": "INTEGER"}),
+        (
+            "sources",
+            {
+                "catalog_entry_id": "INTEGER",
+                "priority": "VARCHAR(16)",
+                "mode_scope": "TEXT",
+                "source_basis": "TEXT",
+                "notes": "TEXT",
+            },
+        ),
+        (
+            "jobs",
+            {
+                "priority": "INTEGER NOT NULL DEFAULT 0",
+                "requested_action": "VARCHAR(32)",
+                "lease_owner": "VARCHAR(255)",
+                "lease_expires_at": "DATETIME",
+                "paused_at": "DATETIME",
+                "dismissed_at": "DATETIME",
+            },
+        ),
+        ("gtfs_routes", {"route_scope": "VARCHAR(64)"}),
+        ("route_patterns", {"route_scope": "VARCHAR(64)"}),
+        ("osm_features", {"route_scope": "VARCHAR(64)"}),
+        ("osm_addresses", {"geometry_geojson": "TEXT"}),
+    ]
+    with engine.begin() as conn:
+        for table_name, columns in wanted_columns:
+            # Row index 1 of PRAGMA table_info is the column name.
+            present = {row[1] for row in conn.execute(text(f"PRAGMA table_info({table_name})")).all()}
+            for column_name, column_type in columns.items():
+                if column_name not in present:
+                    conn.execute(text(f"ALTER TABLE {table_name} ADD COLUMN {column_name} {column_type}"))
+
+
+def _ensure_postgresql_runtime_columns() -> None:
+    """Add the PostgreSQL-only columns that are missing and widen ``osm_addresses.country``.
+
+    The column inventory is read once; every listed column that is absent is
+    added, and ``osm_addresses.country`` is altered to ``TEXT`` when it exists
+    with another data type.
+    """
+    # (table, column, DDL that adds the column)
+    column_statements = [
+        ("osm_features", "geom", "ALTER TABLE osm_features ADD COLUMN geom geometry(Geometry, 4326)"),
+        ("gtfs_routes", "geom", "ALTER TABLE gtfs_routes ADD COLUMN geom geometry(Geometry, 4326)"),
+        ("gtfs_shapes", "geom", "ALTER TABLE gtfs_shapes ADD COLUMN geom geometry(Geometry, 4326)"),
+        ("route_patterns", "geom", "ALTER TABLE route_patterns ADD COLUMN geom geometry(Geometry, 4326)"),
+        ("osm_addresses", "geometry_geojson", "ALTER TABLE osm_addresses ADD COLUMN geometry_geojson TEXT"),
+        ("osm_addresses", "geom", "ALTER TABLE osm_addresses ADD COLUMN geom geometry(Point, 4326)"),
+        ("osm_addresses", "area_geom", "ALTER TABLE osm_addresses ADD COLUMN area_geom geometry(Geometry, 4326)"),
+        ("gtfs_stops", "geom", "ALTER TABLE gtfs_stops ADD COLUMN geom geometry(Point, 4326)"),
+        ("canonical_stops", "geom", "ALTER TABLE canonical_stops ADD COLUMN geom geometry(Point, 4326)"),
+        ("routing_nodes", "geom", "ALTER TABLE routing_nodes ADD COLUMN geom geometry(Point, 4326)"),
+        ("routing_edges", "geom", "ALTER TABLE routing_edges ADD COLUMN geom geometry(LineString, 4326)"),
+    ]
+    with engine.begin() as conn:
+        existing = _postgresql_columns(conn)
+        for table_name, column_name, ddl in column_statements:
+            if (table_name, column_name) in existing:
+                continue
+            conn.execute(text(ddl))
+        country = existing.get(("osm_addresses", "country"))
+        needs_text = country is not None and country["data_type"] != "text"
+        if needs_text:
+            conn.execute(text("ALTER TABLE osm_addresses ALTER COLUMN country TYPE TEXT"))
+
+
+def _ensure_runtime_indexes() -> None:
+ """Create the secondary indexes the application relies on.
+
+ Every statement uses ``CREATE INDEX IF NOT EXISTS``. On PostgreSQL the
+ list is extended with ``_postgresql_index_statements()`` and statements
+ whose index name already exists are skipped before execution; on other
+ databases every statement is executed as is.
+ """
+ statements = [
+ "CREATE INDEX IF NOT EXISTS ix_osm_features_map_bbox ON osm_features (dataset_id, kind, mode, min_lon, max_lon, min_lat, max_lat)",
+ "CREATE INDEX IF NOT EXISTS ix_osm_features_scope_bbox ON osm_features (dataset_id, kind, mode, route_scope, min_lon, max_lon, min_lat, max_lat)",
+ "CREATE INDEX IF NOT EXISTS ix_gtfs_routes_map_bbox ON gtfs_routes (dataset_id, mode, min_lon, max_lon, min_lat, max_lat)",
+ "CREATE INDEX IF NOT EXISTS ix_gtfs_routes_scope_bbox ON gtfs_routes (dataset_id, mode, route_scope, min_lon, max_lon, min_lat, max_lat)",
+ "CREATE INDEX IF NOT EXISTS ix_gtfs_stops_map_point ON gtfs_stops (dataset_id, lon, lat)",
+ "CREATE INDEX IF NOT EXISTS ix_gtfs_stop_times_stop ON gtfs_stop_times (dataset_id, stop_id, departure_seconds, trip_id, stop_sequence)",
+ "CREATE INDEX IF NOT EXISTS ix_gtfs_stop_times_stop_depart_trip ON gtfs_stop_times (dataset_id, stop_id, departure_seconds, trip_id)",
+ "CREATE INDEX IF NOT EXISTS ix_gtfs_stop_times_stop_arrival ON gtfs_stop_times (dataset_id, stop_id, arrival_seconds, trip_id, stop_sequence)",
+ "CREATE INDEX IF NOT EXISTS ix_gtfs_stop_times_stop_arrive_trip ON gtfs_stop_times (dataset_id, stop_id, arrival_seconds, trip_id)",
+ "CREATE INDEX IF NOT EXISTS ix_gtfs_stop_times_trip_seq ON gtfs_stop_times (dataset_id, trip_id, stop_sequence)",
+ "CREATE INDEX IF NOT EXISTS ix_gtfs_stop_times_trip_stop_seq ON gtfs_stop_times (dataset_id, trip_id, stop_id, stop_sequence)",
+ "CREATE INDEX IF NOT EXISTS ix_gtfs_trips_dataset_trip ON gtfs_trips (dataset_id, trip_id)",
+ "CREATE INDEX IF NOT EXISTS ix_gtfs_trips_dataset_route ON gtfs_trips (dataset_id, route_id)",
+ "CREATE INDEX IF NOT EXISTS ix_gtfs_trips_dataset_service ON gtfs_trips (dataset_id, service_id, trip_id)",
+ "CREATE INDEX IF NOT EXISTS ix_gtfs_trips_dataset_route_service ON gtfs_trips (dataset_id, route_id, service_id)",
+ "CREATE INDEX IF NOT EXISTS ix_gtfs_routes_dataset_route ON gtfs_routes (dataset_id, route_id)",
+ "CREATE INDEX IF NOT EXISTS ix_gtfs_shapes_dataset_shape ON gtfs_shapes (dataset_id, shape_id)",
+ "CREATE INDEX IF NOT EXISTS ix_gtfs_calendars_dataset_service_dates ON gtfs_calendars (dataset_id, service_id, start_date, end_date)",
+ "CREATE INDEX IF NOT EXISTS ix_gtfs_calendar_dates_dataset_date ON gtfs_calendar_dates (dataset_id, date, service_id, exception_type)",
+ "CREATE INDEX IF NOT EXISTS ix_canonical_stop_links_object ON canonical_stop_links (object_type, dataset_id, object_id)",
+ "CREATE INDEX IF NOT EXISTS ix_canonical_stop_links_external ON canonical_stop_links (object_type, dataset_id, external_id)",
+ "CREATE INDEX IF NOT EXISTS ix_route_patterns_ref_mode ON route_patterns (route_ref, mode, source_kind)",
+ "CREATE INDEX IF NOT EXISTS ix_route_patterns_bbox ON route_patterns (mode, min_lon, max_lon, min_lat, max_lat)",
+ "CREATE INDEX IF NOT EXISTS ix_route_patterns_scope_bbox ON route_patterns (mode, route_scope, source_kind, min_lon, max_lon, min_lat, max_lat)",
+ "CREATE INDEX IF NOT EXISTS ix_gtfs_route_pattern_links_trip_shape ON gtfs_route_pattern_links (dataset_id, route_id, shape_id)",
+ "CREATE INDEX IF NOT EXISTS ix_gtfs_trip_route_pattern_links_trip ON gtfs_trip_route_pattern_links (dataset_id, trip_id)",
+ "CREATE INDEX IF NOT EXISTS ix_gtfs_trip_route_pattern_links_pattern ON gtfs_trip_route_pattern_links (route_pattern_id, dataset_id, trip_id)",
+ "CREATE INDEX IF NOT EXISTS ix_sources_catalog_entry ON sources (catalog_entry_id)",
+ "CREATE INDEX IF NOT EXISTS ix_sources_priority_country_kind ON sources (priority, country, kind)",
+ "CREATE INDEX IF NOT EXISTS ix_source_catalog_country_priority ON source_catalog_entries (country_code, priority, status)",
+ "CREATE INDEX IF NOT EXISTS ix_source_catalog_name ON source_catalog_entries (source_name)",
+ "CREATE INDEX IF NOT EXISTS ix_source_update_checks_source_checked ON source_update_checks (source_id, checked_at)",
+ "CREATE INDEX IF NOT EXISTS ix_source_update_checks_available ON source_update_checks (source_id, update_available, checked_at)",
+ "CREATE INDEX IF NOT EXISTS ix_osm_diff_states_source_sequence ON osm_diff_states (source_id, sequence_number)",
+ "CREATE INDEX IF NOT EXISTS ix_osm_diff_states_source_status ON osm_diff_states (source_id, status, updated_at)",
+ "CREATE INDEX IF NOT EXISTS ix_jobs_status_created ON jobs (status, created_at)",
+ "CREATE INDEX IF NOT EXISTS ix_jobs_kind_status ON jobs (kind, status)",
+ "CREATE INDEX IF NOT EXISTS ix_jobs_queue_claim ON jobs (status, priority, created_at, id)",
+ "CREATE INDEX IF NOT EXISTS ix_jobs_lease ON jobs (status, lease_expires_at)",
+ "CREATE INDEX IF NOT EXISTS ix_jobs_dismissed_status ON jobs (dismissed_at, status, created_at)",
+ "CREATE INDEX IF NOT EXISTS ix_job_events_job_created ON job_events (job_id, created_at, id)",
+ "CREATE INDEX IF NOT EXISTS ix_pipeline_runs_stage_dataset_hash ON pipeline_runs (stage, dataset_id, dependency_hash, status, started_at)",
+ "CREATE INDEX IF NOT EXISTS ix_pipeline_runs_stage_source_hash ON pipeline_runs (stage, source_id, dependency_hash, status, started_at)",
+ "CREATE INDEX IF NOT EXISTS ix_pipeline_runs_job ON pipeline_runs (job_id, stage, status)",
+ "CREATE INDEX IF NOT EXISTS ix_match_rules_type_active ON match_rules (rule_type, active)",
+ "CREATE INDEX IF NOT EXISTS ix_journey_search_cache_type_expires ON journey_search_cache (cache_type, expires_at)",
+ "CREATE INDEX IF NOT EXISTS ix_travel_requests_created ON travel_requests (created_at)",
+ "CREATE INDEX IF NOT EXISTS ix_itineraries_request_saved ON itineraries (request_id, saved, created_at)",
+ "CREATE INDEX IF NOT EXISTS ix_itinerary_legs_itinerary_sequence ON itinerary_legs (itinerary_id, sequence)",
+ "CREATE INDEX IF NOT EXISTS ix_routing_nodes_dataset_osm ON routing_nodes (dataset_id, osm_node_id)",
+ "CREATE INDEX IF NOT EXISTS ix_routing_edges_dataset_source ON routing_edges (dataset_id, source_osm_node_id)",
+ "CREATE INDEX IF NOT EXISTS ix_routing_edges_dataset_target ON routing_edges (dataset_id, target_osm_node_id)",
+ "CREATE INDEX IF NOT EXISTS ix_routing_edges_dataset_drive ON routing_edges (dataset_id, source_osm_node_id) WHERE drive_cost_s IS NOT NULL",
+ "CREATE INDEX IF NOT EXISTS ix_routing_edges_dataset_walk ON routing_edges (dataset_id, source_osm_node_id) WHERE walk_cost_s IS NOT NULL",
+ "CREATE INDEX IF NOT EXISTS ix_routing_edges_dataset_reverse_drive ON routing_edges (dataset_id, target_osm_node_id) WHERE reverse_drive_cost_s IS NOT NULL",
+ "CREATE INDEX IF NOT EXISTS ix_routing_edges_dataset_reverse_walk ON routing_edges (dataset_id, target_osm_node_id) WHERE reverse_walk_cost_s IS NOT NULL",
+ "CREATE INDEX IF NOT EXISTS ix_routing_edges_bbox ON routing_edges (dataset_id, min_lon, max_lon, min_lat, max_lat)",
+ "CREATE INDEX IF NOT EXISTS ix_osm_addresses_dataset_city_street ON osm_addresses (dataset_id, city, street, housenumber)",
+ "CREATE INDEX IF NOT EXISTS ix_osm_addresses_dataset_postcode ON osm_addresses (dataset_id, postcode)",
+ "CREATE INDEX IF NOT EXISTS ix_osm_addresses_bbox ON osm_addresses (dataset_id, min_lon, max_lon, min_lat, max_lat)",
+ ]
+ with engine.begin() as conn:
+ if settings.is_sqlite_database:
+ # Same journal_mode / busy_timeout PRAGMAs as the "connect" listener, issued again on this connection.
+ conn.execute(text("PRAGMA journal_mode=WAL"))
+ conn.execute(text(f"PRAGMA busy_timeout={int(settings.sqlite_busy_timeout_ms)}"))
+ if settings.is_postgresql_database:
+ # PostgreSQL also gets the backend-specific statements; existing index names are skipped.
+ _execute_missing_postgresql_indexes(conn, statements + _postgresql_index_statements())
+ else:
+ for statement in statements:
+ conn.execute(text(statement))
+
+
+def _postgresql_columns(conn: Connection) -> dict[tuple[str, str], dict[str, str]]:
+ """Read the column inventory of the schemas on the current search path.
+
+ Returns a mapping ``(table_name, column_name) -> {"data_type", "udt_name"}``
+ built from ``information_schema.columns``.
+ """
+ rows = conn.execute(
+ text(
+ """
+ SELECT table_name, column_name, data_type, udt_name
+ FROM information_schema.columns
+ WHERE table_schema = ANY (current_schemas(false))
+ """
+ )
+ ).mappings()
+ return {
+ (str(row["table_name"]), str(row["column_name"])): {
+ "data_type": str(row["data_type"]),
+ "udt_name": str(row["udt_name"]),
+ }
+ for row in rows
+ }
+
+
+def _execute_missing_postgresql_indexes(conn: Connection, statements: list[str]) -> None:
+    """Execute each CREATE INDEX statement unless its index name already exists.
+
+    Statements whose index name cannot be parsed are always executed. Names
+    created during this call are remembered, so a repeated name runs once.
+    """
+    known_names = _postgresql_index_names(conn)
+    for ddl in statements:
+        name = _index_name_from_create_statement(ddl)
+        already_present = bool(name) and name in known_names
+        if already_present:
+            continue
+        conn.execute(text(ddl))
+        if name:
+            known_names.add(name)
+
+
+def _postgresql_index_names(conn: Connection) -> set[str]:
+ """Return the names of all indexes listed in ``pg_indexes`` for the schemas on the current search path."""
+ rows = conn.execute(
+ text(
+ """
+ SELECT indexname
+ FROM pg_indexes
+ WHERE schemaname = ANY (current_schemas(false))
+ """
+ )
+ )
+ return {str(row[0]) for row in rows}
+
+
+def _index_name_from_create_statement(statement: str) -> str | None:
+    """Extract the index name from a CREATE INDEX statement, or ``None`` if none is found."""
+    found = _CREATE_INDEX_NAME_RE.search(statement)
+    if found is None:
+        return None
+    return found.group(1)
+
+
+def _postgresql_index_statements() -> list[str]:
+ """Return the index DDL that is only executed on PostgreSQL.
+
+ The list holds GiST indexes on geometry columns, trigram GIN indexes for
+ text search, and expression / partial B-tree indexes.
+ """
+ return [
+ "CREATE INDEX IF NOT EXISTS ix_osm_features_geom_gist ON osm_features USING GIST (geom)",
+ "CREATE INDEX IF NOT EXISTS ix_osm_features_stop_geom_gist ON osm_features USING GIST (geom) WHERE kind IN ('stop', 'station', 'terminal')",
+ "CREATE INDEX IF NOT EXISTS ix_osm_features_route_geom_gist ON osm_features USING GIST (geom) WHERE kind = 'route'",
+ "CREATE INDEX IF NOT EXISTS ix_gtfs_stops_geom_gist ON gtfs_stops USING GIST (geom)",
+ "CREATE INDEX IF NOT EXISTS ix_canonical_stops_geom_gist ON canonical_stops USING GIST (geom)",
+ "CREATE INDEX IF NOT EXISTS ix_gtfs_routes_geom_gist ON gtfs_routes USING GIST (geom)",
+ "CREATE INDEX IF NOT EXISTS ix_gtfs_shapes_geom_gist ON gtfs_shapes USING GIST (geom)",
+ "CREATE INDEX IF NOT EXISTS ix_route_patterns_geom_gist ON route_patterns USING GIST (geom)",
+ "CREATE INDEX IF NOT EXISTS ix_osm_addresses_geom_gist ON osm_addresses USING GIST (geom)",
+ "CREATE INDEX IF NOT EXISTS ix_osm_addresses_area_geom_gist ON osm_addresses USING GIST (area_geom)",
+ "CREATE INDEX IF NOT EXISTS ix_routing_nodes_geom_gist ON routing_nodes USING GIST (geom)",
+ "CREATE INDEX IF NOT EXISTS ix_routing_edges_bbox_box_gist ON routing_edges USING GIST (box(point(max_lon, max_lat), point(min_lon, min_lat)))",
+ "CREATE INDEX IF NOT EXISTS ix_gtfs_trips_dataset_route_shape_expr ON gtfs_trips (dataset_id, route_id, (COALESCE(shape_id, '__route__')))",
+ "CREATE INDEX IF NOT EXISTS ix_gtfs_stop_times_dataset_stop ON gtfs_stop_times (dataset_id, stop_id)",
+ "CREATE INDEX IF NOT EXISTS ix_canonical_stop_links_gtfs_external ON canonical_stop_links (dataset_id, external_id, canonical_stop_id) WHERE object_type = 'gtfs_stop'",
+ "CREATE INDEX IF NOT EXISTS ix_gtfs_stops_dataset_parent ON gtfs_stops (dataset_id, parent_station)",
+ "CREATE INDEX IF NOT EXISTS ix_gtfs_stops_dataset_stop_prefix ON gtfs_stops (dataset_id, (split_part(stop_id, '::', 1)))",
+ "CREATE INDEX IF NOT EXISTS ix_osm_features_name_trgm ON osm_features USING GIN (LOWER(COALESCE(name, '')) gin_trgm_ops)",
+ "CREATE INDEX IF NOT EXISTS ix_osm_features_ref_trgm ON osm_features USING GIN (LOWER(COALESCE(ref, '')) gin_trgm_ops)",
+ "CREATE INDEX IF NOT EXISTS ix_osm_features_tags_trgm ON osm_features USING GIN (LOWER(COALESCE(tags_json, '')) gin_trgm_ops)",
+ "CREATE INDEX IF NOT EXISTS ix_osm_addresses_search_trgm ON osm_addresses USING GIN (LOWER(COALESCE(search_text, '')) gin_trgm_ops)",
+ "CREATE INDEX IF NOT EXISTS ix_osm_addresses_display_trgm ON osm_addresses USING GIN (LOWER(COALESCE(display_name, '')) gin_trgm_ops)",
+ "CREATE INDEX IF NOT EXISTS ix_osm_addresses_street_key_house ON osm_addresses (dataset_id, REPLACE(LOWER(COALESCE(NULLIF(street, ''), NULLIF(place, ''), '')), 'ß', 'ss'), housenumber)",
+ "CREATE INDEX IF NOT EXISTS ix_osm_addresses_street_key_trgm ON osm_addresses USING GIN (REPLACE(LOWER(COALESCE(NULLIF(street, ''), NULLIF(place, ''), '')), 'ß', 'ss') gin_trgm_ops)",
+ "CREATE INDEX IF NOT EXISTS ix_gtfs_stops_name_trgm ON gtfs_stops USING GIN (name gin_trgm_ops)",
+ "CREATE INDEX IF NOT EXISTS ix_gtfs_stops_stop_id_trgm ON gtfs_stops USING GIN (stop_id gin_trgm_ops)",
+ "CREATE INDEX IF NOT EXISTS ix_route_patterns_ref_trgm ON route_patterns USING GIN (LOWER(COALESCE(route_ref, '')) gin_trgm_ops)",
+ "CREATE INDEX IF NOT EXISTS ix_route_patterns_name_trgm ON route_patterns USING GIN (LOWER(COALESCE(route_name, '')) gin_trgm_ops)",
+ ]
+
+
+def get_db() -> Iterator[Session]:
+    """Yield a new session and close it when the generator finishes.
+
+    Nothing is committed here; committing is left to the caller.
+    """
+    with SessionLocal() as db:
+        yield db
+
+
+@contextmanager
+def session_scope() -> Iterator[Session]:
+    """Provide a session that commits on success and rolls back on ``Exception``.
+
+    The exception is re-raised after the rollback and the session is closed
+    in every case.
+    """
+    with SessionLocal() as scoped_session:
+        try:
+            yield scoped_session
+            scoped_session.commit()
+        except Exception:
+            scoped_session.rollback()
+            raise
diff --git a/app/db_lock.py b/app/db_lock.py
new file mode 100644
index 0000000..e13d7c7
--- /dev/null
+++ b/app/db_lock.py
@@ -0,0 +1,211 @@
+from __future__ import annotations
+
+from contextlib import contextmanager
+from dataclasses import dataclass
+import json
+import os
+from pathlib import Path
+import threading
+import time
+from typing import Iterator
+
+from app.config import settings
+
+try:
+ import fcntl
+except ImportError: # pragma: no cover - this app currently targets Linux/macOS dev hosts
+ fcntl = None # type: ignore[assignment]
+
+
+class DatabaseWriteBusy(RuntimeError):
+    """Raised when the database write lock could not be obtained.
+
+    Attributes:
+        operation: Name of the operation that was refused.
+        active: Details about the current lock holder as passed by the raiser;
+            an empty dict when none were given.
+    """
+
+    def __init__(self, operation: str, active: dict[str, object] | None = None) -> None:
+        self.operation = operation
+        self.active = active or {}
+        active_operation = self.active.get("operation")
+        detail = "Database is busy with another write operation"
+        if active_operation:
+            # Name the competing operation when the holder details include one.
+            detail += f": {active_operation}"
+        super().__init__(detail)
+
+
+@dataclass(frozen=True)
+class DatabaseWriteState:
+    """Immutable snapshot of the write lock as tracked by this process."""
+
+    locked: bool
+    operation: str | None = None
+    pid: int | None = None
+    started_at: float | None = None
+
+    @property
+    def elapsed_seconds(self) -> float | None:
+        """Seconds since ``started_at`` (never negative), or ``None`` when it is unset."""
+        if self.started_at is not None:
+            age = time.time() - self.started_at
+            return age if age > 0.0 else 0.0
+        return None
+
+
+# In-process mutex taken by _acquire_process_lock() and released when
+# database_write_lock() exits.
+_process_write_lock = threading.Lock()
+# Guards reads and replacement of the module-level ``_state`` snapshot.
+_state_lock = threading.Lock()
+# Write-lock snapshot for this process; starts out unlocked.
+_state = DatabaseWriteState(locked=False)
+
+
+def is_sqlite_database() -> bool:
+    """Return the ``is_sqlite_database`` flag from the application settings."""
+    uses_sqlite = settings.is_sqlite_database
+    return uses_sqlite
+
+
+@contextmanager
+def database_write_lock(operation: str, timeout: float | None = None) -> Iterator[None]:
+    """Serialize SQLite writes inside and across app processes.
+
+    SQLite allows only one writer. This lock prevents mutating endpoints from
+    competing until SQLite times out with a low-level "database is locked" error.
+
+    For non-SQLite databases this is a no-op. Otherwise the in-process lock is
+    taken first and the lock file second; both waits share one deadline.
+
+    Args:
+        operation: Label stored in the lock-file metadata and the in-process state.
+        timeout: Seconds to wait for the locks. ``None`` falls back to
+            ``settings.database_write_lock_timeout_seconds``.
+
+    Raises:
+        DatabaseWriteBusy: If either lock could not be obtained before the deadline.
+    """
+    if not is_sqlite_database():
+        yield
+        return
+
+    effective_timeout = settings.database_write_lock_timeout_seconds if timeout is None else timeout
+    deadline = None if effective_timeout is None else time.monotonic() + max(0.0, effective_timeout)
+    if not _acquire_process_lock(deadline):
+        raise DatabaseWriteBusy(operation, database_write_status().__dict__)
+
+    handle = None
+    file_locked = False
+    try:
+        lock_path = _lock_path()
+        lock_path.parent.mkdir(parents=True, exist_ok=True)
+        handle = _open_locked_handle(lock_path, deadline)
+        if handle is None:
+            raise DatabaseWriteBusy(operation, _read_lock_metadata(lock_path))
+        file_locked = True
+        _write_lock_metadata(handle, operation)
+        _set_state(DatabaseWriteState(locked=True, operation=operation, pid=os.getpid(), started_at=time.time()))
+        yield
+    finally:
+        try:
+            _set_state(DatabaseWriteState(locked=False))
+            if handle is not None:
+                if file_locked:
+                    # Clear our metadata while the lock is still held. The file is
+                    # deliberately left in place: unlinking it after unlocking can
+                    # delete a file that another process has just opened and
+                    # locked, after which a third process would create and lock a
+                    # fresh file at the same path and both would write.
+                    try:
+                        handle.seek(0)
+                        handle.truncate()
+                        handle.flush()
+                    except (OSError, ValueError):
+                        pass  # best effort; empty or leftover metadata is tolerated by readers
+                    if fcntl is not None:
+                        try:
+                            fcntl.flock(handle.fileno(), fcntl.LOCK_UN)
+                        except OSError:
+                            pass  # closing the handle below drops the lock as well
+                handle.close()
+        finally:
+            # Release the in-process lock even if the file cleanup above failed.
+            _process_write_lock.release()
+
+
+def database_write_status() -> DatabaseWriteState:
+    """Return this process's current write-lock snapshot, read under ``_state_lock``."""
+    _state_lock.acquire()
+    try:
+        snapshot = _state
+    finally:
+        _state_lock.release()
+    return snapshot
+
+
+def _acquire_process_lock(deadline: float | None) -> bool:
+    """Acquire ``_process_write_lock``, waiting no longer than ``deadline``.
+
+    Args:
+        deadline: ``time.monotonic()`` value after which to give up, or ``None``
+            to wait indefinitely.
+
+    Returns:
+        ``True`` when the lock was acquired, ``False`` when the deadline passed.
+    """
+    if deadline is None:
+        return _process_write_lock.acquire()
+    remaining = deadline - time.monotonic()
+    if remaining <= 0:
+        # Deadline already reached: still make one non-blocking attempt.
+        return _process_write_lock.acquire(blocking=False)
+    # Let the lock primitive do the waiting instead of polling with sleeps;
+    # acquire() rejects timeouts above TIMEOUT_MAX, so clamp.
+    return _process_write_lock.acquire(timeout=min(remaining, threading.TIMEOUT_MAX))
+
+
+def _acquire_file_lock(handle, deadline: float | None) -> bool:
+    """Poll for an exclusive ``flock`` on ``handle`` until it succeeds or ``deadline`` passes.
+
+    Returns ``True`` immediately when ``fcntl`` is unavailable. ``deadline`` is
+    compared against ``time.monotonic()``; ``None`` means retry without limit.
+    """
+    if fcntl is None:
+        return True
+    flags = fcntl.LOCK_EX | fcntl.LOCK_NB
+    acquired = False
+    while not acquired:
+        try:
+            fcntl.flock(handle.fileno(), flags)
+        except BlockingIOError:
+            # Lock is held elsewhere: give up at the deadline, otherwise retry shortly.
+            if deadline is not None and time.monotonic() >= deadline:
+                break
+            time.sleep(0.05)
+        else:
+            acquired = True
+    return acquired
+
+
+def _open_locked_handle(lock_path: Path, deadline: float | None):
+    """Open ``lock_path`` and take the exclusive file lock on it.
+
+    Retries every 50 ms until ``deadline`` (a ``time.monotonic()`` value;
+    ``None`` retries without limit). A lock file whose recorded pid no longer
+    exists is removed so that it can be recreated.
+
+    Returns:
+        The open, locked file handle, or ``None`` when the deadline passed or a
+        stale lock file could not be removed.
+    """
+    while True:
+        try:
+            lock_path.parent.mkdir(parents=True, exist_ok=True)
+            handle = lock_path.open("a+", encoding="utf-8")
+        except FileNotFoundError:
+            if deadline is not None and time.monotonic() >= deadline:
+                return None
+            time.sleep(0.05)
+            continue
+        try:
+            locked = _try_file_lock(handle)
+        except BaseException:
+            # Do not leak the descriptor when locking fails unexpectedly.
+            handle.close()
+            raise
+        if locked:
+            # The path may have been unlinked or replaced between open() and
+            # flock(); a lock on such an orphaned file excludes nobody, so make
+            # sure the handle still refers to the file at lock_path.
+            try:
+                current = os.path.samestat(os.fstat(handle.fileno()), os.stat(lock_path))
+            except OSError:
+                current = False
+            if current:
+                return handle
+            handle.close()  # closing the handle also drops the lock
+            if deadline is not None and time.monotonic() >= deadline:
+                return None
+            continue
+        metadata = _read_lock_metadata(lock_path)
+        handle.close()
+        if not _lock_metadata_is_stale(metadata):
+            if deadline is not None and time.monotonic() >= deadline:
+                return None
+            time.sleep(0.05)
+            continue
+        # NOTE(review): the lock is still held by some open descriptor here even
+        # though the recorded pid is gone; removing the file lets a new one be
+        # created. Confirm this recovery path is wanted.
+        try:
+            lock_path.unlink()
+        except FileNotFoundError:
+            pass
+        except OSError:
+            return None
+        if deadline is not None and time.monotonic() >= deadline:
+            return None
+
+
+def _try_file_lock(handle) -> bool:
+    """Make a single non-blocking attempt to ``flock`` ``handle`` exclusively.
+
+    Returns ``True`` when the lock was taken or when ``fcntl`` is unavailable,
+    ``False`` when the lock is currently held elsewhere.
+    """
+    if fcntl is not None:
+        try:
+            fcntl.flock(handle.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB)
+        except BlockingIOError:
+            return False
+    return True
+
+
+def _lock_metadata_is_stale(metadata: dict[str, object]) -> bool:
+    """Tell whether ``metadata`` names a holder process that no longer exists.
+
+    A missing or non-numeric pid, a non-positive pid, and the current process's
+    own pid are never reported as stale.
+    """
+    raw_pid = metadata.get("pid")
+    try:
+        owner = int(raw_pid)  # type: ignore[arg-type]
+    except (TypeError, ValueError):
+        return False
+    if owner > 0 and owner != os.getpid():
+        return not _pid_exists(owner)
+    return False
+
+
+def _pid_exists(pid: int) -> bool:
+    """Probe ``pid`` with signal 0 and report whether the process is considered present."""
+    try:
+        os.kill(pid, 0)
+    except ProcessLookupError:
+        present = False
+    except PermissionError:
+        # Signalling was refused; this is treated as "process exists".
+        present = True
+    else:
+        present = True
+    return present
+
+
+def _set_state(state: DatabaseWriteState) -> None:
+    """Replace the module-level write-state snapshot while holding ``_state_lock``."""
+    global _state
+    _state_lock.acquire()
+    try:
+        _state = state
+    finally:
+        _state_lock.release()
+
+
+def _lock_path() -> Path:
+    """Return the write-lock file path, ``workbench.write.lock`` inside ``settings.data_dir``."""
+    lock_file_name = "workbench.write.lock"
+    return settings.data_dir / lock_file_name
+
+
+def _write_lock_metadata(handle, operation: str) -> None:
+    """Replace the content of ``handle`` with compact JSON describing the lock holder.
+
+    The payload records ``operation``, the current pid and the current wall-clock
+    time; it is flushed and fsynced before returning.
+    """
+    handle.seek(0)
+    handle.truncate()
+    payload = {"operation": operation, "pid": os.getpid(), "started_at": time.time()}
+    handle.write(json.dumps(payload, separators=(",", ":")))
+    handle.flush()
+    os.fsync(handle.fileno())
+
+
+def _read_lock_metadata(path: Path) -> dict[str, object]:
+    """Read the JSON metadata stored in the lock file at ``path``.
+
+    Returns:
+        The decoded JSON object, or ``{}`` when the file is missing, unreadable,
+        empty, not valid UTF-8 or JSON, or does not hold a JSON object.
+    """
+    try:
+        text = path.read_text(encoding="utf-8").strip()
+        data = json.loads(text) if text else {}
+    except (OSError, ValueError):
+        # ValueError covers json.JSONDecodeError and UnicodeDecodeError, e.g. when
+        # the file is read while another process is rewriting it.
+        return {}
+    # Callers use dict methods on the result, so reject any other JSON value.
+    return data if isinstance(data, dict) else {}
diff --git a/app/feed_discovery.py b/app/feed_discovery.py
new file mode 100644
index 0000000..3aa609d
--- /dev/null
+++ b/app/feed_discovery.py
@@ -0,0 +1,923 @@
+from __future__ import annotations
+
+import csv
+import hashlib
+import json
+import re
+from dataclasses import dataclass, field
+from datetime import datetime, timezone
+from html import unescape
+from html.parser import HTMLParser
+from pathlib import Path
+from typing import Iterable
+from urllib.parse import parse_qs, urljoin, urlparse
+
+import requests
+
+
+# Remote catalogs and pages that discovery reads from.
+MOBILITY_DATABASE_FEEDS_URL = "https://files.mobilitydatabase.org/feeds_v2.csv"
+MOBILITY_DATABASE_ACCEPTANCE_TEST_URL = (
+    "https://raw.githubusercontent.com/MobilityData/gtfs-validator/master/"
+    "scripts/mobility-database-harvester/acceptance_test_feed_list.csv"
+)
+PTNA_GTFS_INDEX_URL = "https://ptna.openstreetmap.de/gtfs/index.html"
+# Formatted with ``country=<code>`` to address one PTNA country page.
+PTNA_COUNTRY_URL_TEMPLATE = "https://ptna.openstreetmap.de/gtfs/{country}/index.php"
+
+# Countries scanned on PTNA when the caller passes no country list.
+DEFAULT_DISCOVERY_COUNTRIES = ["DE", "AT", "CH", "NL", "DK", "FR", "BE", "LU", "NO", "SE", "FI", "IE", "GB"]
+# Country order used when ranking and picking test-run candidates.
+CURATED_TEST_COUNTRIES = ["DE", "CH", "AT", "NL", "DK", "FI", "NO", "SE", "IE", "GB", "FR", "BE", "LU"]
+# Columns of the ingestable and test-run CSVs (keys of FeedCandidate.ingestable_row()).
+DIRECT_INGEST_HEADERS = ["name", "kind", "url", "country", "license", "mode_scope", "source_basis", "priority", "notes"]
+# Columns of the full candidates CSV (keys of FeedCandidate.row()).
+CANONICAL_HEADERS = [
+    "candidate_id",
+    "discovery_source",
+    "country",
+    "subdivision",
+    "provider",
+    "feed_name",
+    "stable_id",
+    "ptna_feed_id",
+    "data_type",
+    "status",
+    "is_official",
+    "selected_url",
+    "direct_download_url",
+    "latest_url",
+    "original_release_url",
+    "license_url",
+    "license_text",
+    "osm_license_text",
+    "details_url",
+    "routes_url",
+    "valid_from",
+    "valid_to",
+    "release_date",
+    "feed_version",
+    "bbox",
+    "features",
+    "priority",
+    "availability_status",
+    "http_status",
+    "content_type",
+    "content_length",
+    "final_url",
+    "source_basis",
+    "notes",
+]
+
+
+@dataclass
+class FeedCandidate:
+    """One discovered feed; every CSV column is kept as a string attribute."""
+
+    discovery_source: str
+    country: str = ""
+    subdivision: str = ""
+    provider: str = ""
+    feed_name: str = ""
+    stable_id: str = ""
+    ptna_feed_id: str = ""
+    data_type: str = "gtfs"
+    status: str = ""
+    is_official: str = ""
+    selected_url: str = ""
+    direct_download_url: str = ""
+    latest_url: str = ""
+    original_release_url: str = ""
+    license_url: str = ""
+    license_text: str = ""
+    osm_license_text: str = ""
+    details_url: str = ""
+    routes_url: str = ""
+    valid_from: str = ""
+    valid_to: str = ""
+    release_date: str = ""
+    feed_version: str = ""
+    bbox: str = ""
+    features: str = ""
+    priority: str = ""
+    availability_status: str = "unchecked"
+    http_status: str = ""
+    content_type: str = ""
+    content_length: str = ""
+    final_url: str = ""
+    source_basis: str = ""
+    notes: str = ""
+    evidence_sources: list[str] = field(default_factory=list)
+
+    def key(self) -> str:
+        """Return the identity key: stable id, else selected URL, else PTNA id, else a row hash."""
+        if self.stable_id:
+            return f"stable:{self.stable_id}"
+        if self.selected_url:
+            return f"url:{_normalize_url_key(self.selected_url)}"
+        if self.ptna_feed_id:
+            return f"ptna:{self.ptna_feed_id}"
+        serialized = json.dumps(self.row(), sort_keys=True)
+        digest = hashlib.sha256(serialized.encode("utf-8")).hexdigest()
+        return "hash:" + digest
+
+    def candidate_id(self) -> str:
+        """Return the first 16 hex digits of a SHA-256 over the identifying fields."""
+        identifying = (
+            self.discovery_source,
+            self.country,
+            self.stable_id,
+            self.ptna_feed_id,
+            self.selected_url,
+            self.provider,
+            self.feed_name,
+        )
+        return hashlib.sha256("|".join(identifying).encode("utf-8")).hexdigest()[:16]
+
+    def row(self) -> dict[str, str]:
+        """Return the canonical CSV row; ``candidate_id`` is computed and added last."""
+        payload: dict[str, str] = {}
+        for header in CANONICAL_HEADERS:
+            if header == "candidate_id":
+                continue
+            payload[header] = _string(getattr(self, header, ""))
+        payload["candidate_id"] = self.candidate_id()
+        return payload
+
+    def ingestable_row(self) -> dict[str, str]:
+        """Return the row for the direct-ingest CSVs, with long text fields truncated."""
+        name = _feed_source_name(self.country, self.provider or self.feed_name)
+
+        if self.license_text:
+            license_value = self.license_text
+        elif self.license_url:
+            license_value = f"see {self.license_url}"
+        else:
+            license_value = ""
+
+        basis_parts = [self.source_basis or self.discovery_source]
+        if self.details_url:
+            basis_parts.append(f"details: {self.details_url}")
+        if self.original_release_url and self.original_release_url != self.selected_url:
+            basis_parts.append(f"release: {self.original_release_url}")
+
+        notes = self.notes or ""
+        if self.latest_url and self.latest_url != self.selected_url:
+            notes = _join_notes(notes, f"Mobility Database mirror: {self.latest_url}")
+        if self.osm_license_text:
+            notes = _join_notes(notes, f"OSM permission note: {_truncate(self.osm_license_text, 240)}")
+
+        return {
+            "name": _truncate(name, 240),
+            "kind": "gtfs",
+            "url": self.selected_url,
+            "country": self.country,
+            "license": _truncate(license_value, 240),
+            "mode_scope": _mode_scope_from_features(self.features),
+            "source_basis": _truncate("; ".join(part for part in basis_parts if part), 500),
+            # Fall back to the derived priority only when none was assigned.
+            "priority": self.priority or _candidate_priority(self),
+            "notes": _truncate(notes, 1200),
+        }
+
+
+def default_generated_dir() -> Path:
+    """Return ``docs/generated`` under the parent of the directory containing this module."""
+    module_file = Path(__file__).resolve()
+    return module_file.parents[1].joinpath("docs", "generated")
+
+
+def build_gtfs_discovery_manifests(
+    *,
+    output_dir: Path | str | None = None,
+    countries: Iterable[str] | None = None,
+    include_mobility_database: bool = True,
+    include_acceptance_test_list: bool = True,
+    include_ptna: bool = True,
+    max_ptna_details: int = 80,
+    test_limit: int = 24,
+    check_urls: bool = False,
+    timeout: float = 30.0,
+) -> dict[str, object]:
+    """Collect GTFS feed candidates and write the discovery CSVs plus a JSON report.
+
+    The curated seed is always loaded; the Mobility Database catalog, the
+    validator acceptance list and PTNA are added according to the ``include_*``
+    flags. Candidates are merged, those with a selected URL and data type
+    ``gtfs`` become the ingestable set, and a test-run subset is picked from it.
+
+    Args:
+        output_dir: Target directory; defaults to ``default_generated_dir()``.
+            It is created if missing.
+        countries: Country filter, normalized before use.
+        include_mobility_database: Fetch the Mobility Database catalog.
+        include_acceptance_test_list: Fetch the validator acceptance feed list.
+        include_ptna: Scrape PTNA.
+        max_ptna_details: Passed to ``fetch_ptna_candidates`` as ``max_details``.
+        test_limit: Maximum number of test-run candidates.
+        check_urls: Probe each ingestable URL before the test-run selection.
+        timeout: Request timeout in seconds; URL probes use at most 12 seconds.
+
+    Returns:
+        The report dict that is also written to ``gtfs_discovery_report.json``.
+    """
+    selected_countries = _normalize_countries(countries)
+    out_dir = Path(output_dir) if output_dir is not None else default_generated_dir()
+    out_dir.mkdir(parents=True, exist_ok=True)
+
+    candidates: list[FeedCandidate] = []
+    candidates.extend(load_curated_ingestable_seed(countries=selected_countries))
+    if include_mobility_database:
+        candidates.extend(fetch_mobility_database_candidates(countries=selected_countries, timeout=timeout))
+    if include_acceptance_test_list:
+        candidates.extend(fetch_mobility_acceptance_candidates(countries=selected_countries, timeout=timeout))
+    if include_ptna:
+        candidates.extend(fetch_ptna_candidates(countries=selected_countries, max_details=max_ptna_details, timeout=timeout))
+
+    merged = merge_candidates(candidates)
+    ingestable = [candidate for candidate in merged if candidate.selected_url and candidate.data_type == "gtfs"]
+    if check_urls:
+        # Annotates the candidates in place, so the results appear in the CSVs below.
+        for candidate in ingestable:
+            annotate_url_availability(candidate, timeout=min(timeout, 12.0))
+    test_run = select_test_run_candidates(ingestable, limit=test_limit)
+
+    candidates_path = out_dir / "gtfs_feed_candidates.csv"
+    ingestable_path = out_dir / "gtfs_ingestable_sources.csv"
+    test_path = out_dir / "gtfs_test_run_sources.csv"
+    report_path = out_dir / "gtfs_discovery_report.json"
+
+    _write_csv(candidates_path, CANONICAL_HEADERS, [candidate.row() for candidate in merged])
+    _write_csv(ingestable_path, DIRECT_INGEST_HEADERS, [candidate.ingestable_row() for candidate in ingestable])
+    _write_csv(test_path, DIRECT_INGEST_HEADERS, [candidate.ingestable_row() for candidate in test_run])
+
+    by_source = _count_by(merged, lambda item: item.discovery_source)
+    by_country = _count_by(ingestable, lambda item: item.country or "unknown")
+    report = {
+        "generated_at": datetime.now(timezone.utc).isoformat(),
+        "countries": selected_countries or "all",
+        "sources": {
+            # Disabled sources are reported as None.
+            "mobility_database": MOBILITY_DATABASE_FEEDS_URL if include_mobility_database else None,
+            "mobility_acceptance_test_list": MOBILITY_DATABASE_ACCEPTANCE_TEST_URL if include_acceptance_test_list else None,
+            "ptna": PTNA_GTFS_INDEX_URL if include_ptna else None,
+        },
+        "counts": {
+            "candidates": len(merged),
+            "ingestable": len(ingestable),
+            "test_run": len(test_run),
+            "by_source": by_source,
+            "ingestable_by_country": by_country,
+        },
+        "files": {
+            "candidates": str(candidates_path),
+            "ingestable": str(ingestable_path),
+            "test_run": str(test_path),
+        },
+    }
+    report_path.write_text(json.dumps(report, indent=2, ensure_ascii=False) + "\n", encoding="utf-8")
+    return report
+
+
+def fetch_mobility_database_candidates(
+    *,
+    countries: list[str] | None = None,
+    timeout: float = 30.0,
+    url: str = MOBILITY_DATABASE_FEEDS_URL,
+) -> list[FeedCandidate]:
+    """Download the Mobility Database catalog and turn its GTFS rows into candidates.
+
+    Args:
+        countries: Upper-case country codes to keep; a falsy value keeps all rows.
+        timeout: Request timeout in seconds.
+        url: Location of the catalog CSV.
+
+    Returns:
+        One candidate per catalog row whose ``data_type`` is ``gtfs`` and whose
+        country passes the filter, with geography, overrides and priority applied.
+
+    Raises:
+        requests.RequestException: If the catalog cannot be downloaded.
+    """
+    import io
+
+    text = _fetch_text(url, timeout=timeout)
+    # Parse from a text stream rather than str.splitlines(): splitlines() also
+    # breaks on characters such as U+2028 or form feed and drops newlines inside
+    # quoted fields, either of which corrupts rows.
+    rows = csv.DictReader(io.StringIO(text, newline=""))
+    candidates: list[FeedCandidate] = []
+    for row in rows:
+        if _value(row, "data_type").lower() != "gtfs":
+            continue
+        country = _value(row, "location.country_code").upper()
+        if countries and country not in countries:
+            continue
+        direct_url = _normalize_feed_url(_value(row, "urls.direct_download"))
+        latest_url = _normalize_feed_url(_value(row, "urls.latest"))
+        selected_url = _choose_feed_url(direct_url, latest_url)
+        candidate = FeedCandidate(
+            discovery_source="mobility_database",
+            country=country,
+            subdivision=_value(row, "location.subdivision_name"),
+            provider=_value(row, "provider"),
+            feed_name=_value(row, "name"),
+            stable_id=_value(row, "id"),
+            data_type="gtfs",
+            status=_value(row, "status"),
+            is_official=_value(row, "is_official"),
+            selected_url=selected_url,
+            direct_download_url=direct_url,
+            latest_url=latest_url,
+            license_url=_value(row, "urls.license"),
+            bbox=_bbox_from_mobility_row(row),
+            features=_value(row, "features"),
+            source_basis="Mobility Database feed catalog",
+            notes=_value(row, "note"),
+        )
+        normalize_candidate_geography(candidate)
+        apply_known_download_overrides(candidate)
+        candidate.priority = _candidate_priority(candidate)
+        candidates.append(candidate)
+    return candidates
+
+
+def fetch_mobility_acceptance_candidates(
+    *,
+    countries: list[str] | None = None,
+    timeout: float = 30.0,
+    url: str = MOBILITY_DATABASE_ACCEPTANCE_TEST_URL,
+) -> list[FeedCandidate]:
+    """Download the validator acceptance-test feed list and build ``P3`` candidates.
+
+    Args:
+        countries: Upper-case country codes to keep; a falsy value keeps all rows.
+        timeout: Request timeout in seconds.
+        url: Location of the feed-list CSV.
+
+    Returns:
+        One candidate per row that passes the country filter and has a
+        ``urls.latest`` value.
+
+    Raises:
+        requests.RequestException: If the list cannot be downloaded.
+    """
+    import io
+
+    text = _fetch_text(url, timeout=timeout)
+    # Parse from a text stream rather than str.splitlines(): splitlines() also
+    # breaks on characters such as U+2028 or form feed and drops newlines inside
+    # quoted fields, either of which corrupts rows.
+    rows = csv.DictReader(io.StringIO(text, newline=""))
+    candidates: list[FeedCandidate] = []
+    for row in rows:
+        country = _value(row, "country_code").upper()
+        if countries and country not in countries:
+            continue
+        latest_url = _normalize_feed_url(_value(row, "urls.latest"))
+        if not latest_url:
+            continue
+        candidate = FeedCandidate(
+            discovery_source="mobility_validator_acceptance",
+            country=country,
+            subdivision=_value(row, "subdivision_name"),
+            provider=_value(row, "provider"),
+            # The provider value doubles as the feed name.
+            feed_name=_value(row, "provider"),
+            stable_id=_value(row, "stable_id"),
+            status="acceptance_test",
+            selected_url=latest_url,
+            latest_url=latest_url,
+            source_basis="MobilityData validator acceptance-test feed list",
+            notes="Useful smoke-test feed list; prefer Mobility Database feeds_v2 metadata for production source review.",
+            priority="P3",
+        )
+        normalize_candidate_geography(candidate)
+        apply_known_download_overrides(candidate)
+        candidates.append(candidate)
+    return candidates
+
+
+def fetch_ptna_candidates(
+    *,
+    countries: list[str] | None = None,
+    max_details: int = 80,
+    timeout: float = 30.0,
+) -> list[FeedCandidate]:
+    """Scrape the PTNA country pages and return one candidate per listed feed.
+
+    Args:
+        countries: Country codes to scan; a falsy value selects
+            ``DEFAULT_DISCOVERY_COUNTRIES``.
+        max_details: Upper bound on detail-page requests across all countries.
+            Failed requests count too, so an unreachable server cannot trigger
+            one timed-out request per feed.
+        timeout: Request timeout in seconds.
+
+    Returns:
+        The candidates of every country page that could be fetched; countries
+        whose page request fails are skipped.
+    """
+    country_codes = countries or DEFAULT_DISCOVERY_COUNTRIES
+    if not country_codes:
+        # Only reached when DEFAULT_DISCOVERY_COUNTRIES is empty.
+        country_codes = discover_ptna_country_codes(timeout=timeout)
+    candidates: list[FeedCandidate] = []
+    detail_attempts = 0
+    for country in country_codes:
+        country_url = PTNA_COUNTRY_URL_TEMPLATE.format(country=country)
+        try:
+            html = _fetch_text(country_url, timeout=timeout)
+        except requests.RequestException:
+            continue  # best effort: one unavailable country must not stop discovery
+        for candidate in parse_ptna_country_page(html, country=country, page_url=country_url):
+            if candidate.details_url and detail_attempts < max_details:
+                # Spend the budget before the request so failures are bounded as well.
+                detail_attempts += 1
+                try:
+                    detail_html = _fetch_text(candidate.details_url, timeout=timeout)
+                    enrich_ptna_candidate_from_details(candidate, detail_html, candidate.details_url)
+                except requests.RequestException:
+                    candidate.notes = _join_notes(candidate.notes, "PTNA detail page could not be fetched during discovery.")
+            candidate.priority = _candidate_priority(candidate)
+            candidates.append(candidate)
+    return candidates
+
+
+def discover_ptna_country_codes(*, timeout: float = 30.0) -> list[str]:
+    """Return the two-letter codes linked from the PTNA GTFS index, in first-seen order."""
+    index_html = _fetch_text(PTNA_GTFS_INDEX_URL, timeout=timeout)
+    pattern = re.compile(r"/gtfs/([A-Z]{2})/index\.php$")
+    seen: dict[str, None] = {}
+    for href in _all_links(index_html, PTNA_GTFS_INDEX_URL):
+        found = pattern.search(urlparse(href).path)
+        if found:
+            # A dict keeps insertion order and ignores repeated codes.
+            seen.setdefault(found.group(1), None)
+    return list(seen)
+
+
+def parse_ptna_country_page(html: str, *, country: str, page_url: str) -> list[FeedCandidate]:
+    """Build a candidate for every table row of a PTNA country page that links to a feed.
+
+    Rows without a ``routes.php?feed=`` or ``gtfs-details.php?feed=`` link, or
+    without a feed id in that link, are ignored. Cell texts are read by column
+    position; missing columns yield empty strings.
+    """
+    found: list[FeedCandidate] = []
+    for row in _parse_table_rows(html, page_url):
+        cells = row.cells
+        links = [link for cell in cells for link in cell.links]
+        routes_url = _first_link_matching(links, "routes.php?feed=")
+        details_url = _first_link_matching(links, "gtfs-details.php?feed=")
+        if not (routes_url or details_url):
+            continue
+        feed_id = _feed_id_from_url(routes_url or details_url)
+        if not feed_id:
+            continue
+        texts = [cell.text for cell in cells]
+
+        def column(index: int, fallback: str = "") -> str:
+            return texts[index] if len(texts) > index else fallback
+
+        external = cells[6].first_external_link if len(cells) > 6 else ""
+        release_link = _normalize_feed_url(external)
+        # Only release links that look like a downloadable feed become the feed URL.
+        direct_url = release_link if _looks_like_download_url(release_link) else ""
+        candidate = FeedCandidate(
+            discovery_source="ptna",
+            country=country,
+            provider=column(2),
+            feed_name=column(1, feed_id),
+            ptna_feed_id=feed_id,
+            selected_url=direct_url,
+            direct_download_url=direct_url,
+            original_release_url=release_link,
+            details_url=details_url,
+            routes_url=routes_url,
+            valid_from=column(3),
+            valid_to=column(4),
+            feed_version=column(5),
+            release_date=column(6),
+            source_basis="PTNA GTFS analysis",
+            notes="PTNA candidate; use original publisher URL where available.",
+        )
+        normalize_candidate_geography(candidate)
+        apply_known_download_overrides(candidate)
+        found.append(candidate)
+    return found
+
+
+def enrich_ptna_candidate_from_details(candidate: FeedCandidate, html: str, page_url: str) -> None:
+    """Update ``candidate`` in place with the fields found on its PTNA detail page.
+
+    Values from the page win over existing ones; fields missing from the page
+    keep their current value. If the candidate has no selected URL and the
+    release URL looks downloadable, the release URL becomes the selected URL.
+    """
+    fields = parse_ptna_detail_fields(html, page_url)
+    release = fields.get("release url href") or fields.get("release url") or candidate.original_release_url
+    candidate.original_release_url = _normalize_feed_url(release)
+
+    replacements = (
+        ("license_url", "publisher's license href"),
+        ("license_text", "publisher's license"),
+        ("osm_license_text", "license given for use in osm"),
+        ("valid_from", "feed start date"),
+        ("valid_to", "feed end date"),
+        ("feed_version", "feed version"),
+        ("release_date", "release date"),
+    )
+    for attribute, label in replacements:
+        setattr(candidate, attribute, fields.get(label) or getattr(candidate, attribute))
+
+    guid = fields.get('"network:guid"')
+    if guid:
+        candidate.notes = _join_notes(candidate.notes, f"PTNA network:guid={guid}")
+    if not candidate.selected_url and _looks_like_download_url(candidate.original_release_url):
+        candidate.selected_url = _normalize_feed_url(candidate.original_release_url)
+        candidate.direct_download_url = candidate.selected_url
+    normalize_candidate_geography(candidate)
+
+
+def parse_ptna_detail_fields(html: str, page_url: str) -> dict[str, str]:
+    """Map lower-cased first-cell labels of a detail page to their second-cell text.
+
+    When the second cell contains an external link, it is stored additionally
+    under ``"<label> href"``. Rows with fewer than two cells or an empty label
+    are skipped.
+    """
+    parsed: dict[str, str] = {}
+    for row in _parse_table_rows(html, page_url):
+        if len(row.cells) >= 2:
+            label_cell, value_cell = row.cells[0], row.cells[1]
+            label = _clean_text(label_cell.text).lower()
+            if label:
+                parsed[label] = _clean_text(value_cell.text)
+                href = value_cell.first_external_link
+                if href:
+                    parsed[f"{label} href"] = href
+    return parsed
+
+
+def load_curated_ingestable_seed(
+    *,
+    countries: list[str] | None = None,
+    path: Path | str | None = None,
+) -> list[FeedCandidate]:
+    """Load the GTFS rows of the curated seed CSV as ``curated_seed`` candidates.
+
+    ``path`` defaults to ``docs/ingestable_sources_seed.csv`` next to the package
+    directory; a missing file yields an empty list. With a country filter, rows
+    for country ``EU`` are kept regardless of the filter.
+    """
+    if path is not None:
+        seed_path = Path(path)
+    else:
+        seed_path = Path(__file__).resolve().parents[1] / "docs" / "ingestable_sources_seed.csv"
+    if not seed_path.exists():
+        return []
+
+    seeded: list[FeedCandidate] = []
+    with seed_path.open("r", encoding="utf-8-sig", newline="") as handle:
+        for row in csv.DictReader(handle):
+            if _value(row, "kind").lower() != "gtfs":
+                continue
+            country = _value(row, "country").upper()
+            wanted = not countries or country in countries or country == "EU"
+            if not wanted:
+                continue
+            name = _value(row, "name")
+            feed_url = _normalize_feed_url(_value(row, "url"))
+            candidate = FeedCandidate(
+                discovery_source="curated_seed",
+                country=country,
+                provider=name.removesuffix(" GTFS"),
+                feed_name=name,
+                selected_url=feed_url,
+                direct_download_url=feed_url,
+                license_text=_value(row, "license"),
+                features=_value(row, "mode_scope"),
+                priority=_value(row, "priority"),
+                source_basis=_value(row, "source_basis") or "curated seed",
+                notes=_value(row, "notes"),
+            )
+            normalize_candidate_geography(candidate)
+            apply_known_download_overrides(candidate)
+            seeded.append(candidate)
+    return seeded
+
+
+def merge_candidates(candidates: Iterable[FeedCandidate]) -> list[FeedCandidate]:
+    """Fold candidates that share any alias key into one and return them sorted.
+
+    The first candidate seen for a group is kept and later ones are merged into
+    it. The result is ordered by priority, country, provider and feed name.
+    """
+    by_key: dict[str, FeedCandidate] = {}
+    alias_to_key: dict[str, str] = {}
+    for candidate in candidates:
+        aliases = _candidate_alias_keys(candidate)
+        owner_key = None
+        for alias in aliases:
+            if alias in alias_to_key:
+                owner_key = alias_to_key[alias]
+                break
+        owner = by_key.get(owner_key) if owner_key is not None else None
+        if owner is None:
+            by_key[aliases[0]] = candidate
+            target = aliases[0]
+        else:
+            _merge_candidate(owner, candidate)
+            target = owner_key or aliases[0]
+        # Every alias of this candidate now points at the group it ended up in.
+        for alias in aliases:
+            alias_to_key[alias] = target
+
+    def sort_key(item: FeedCandidate):
+        return (_priority_sort_key(item.priority), item.country, item.provider.lower(), item.feed_name.lower())
+
+    return sorted(by_key.values(), key=sort_key)
+
+
+def select_test_run_candidates(candidates: Iterable[FeedCandidate], *, limit: int = 24) -> list[FeedCandidate]:
+    """Pick at most ``limit`` candidates for a test run.
+
+    Acceptance-list candidates and candidates rejected by
+    ``_test_candidate_eligible`` are dropped first. The rest is sorted and
+    offered in three passes: candidates matching a preferred token, then
+    candidates per country in ``CURATED_TEST_COUNTRIES`` order, then all
+    remaining ones. A candidate is accepted only if its normalized URL is new
+    and its country is still below its cap.
+    """
+    sorted_candidates = sorted(
+        [
+            candidate
+            for candidate in candidates
+            if candidate.discovery_source != "mobility_validator_acceptance" and _test_candidate_eligible(candidate)
+        ],
+        key=_test_candidate_sort_key,
+    )
+    selected: list[FeedCandidate] = []
+    seen_urls: set[str] = set()
+    per_country: dict[str, int] = {}
+
+    def add(candidate: FeedCandidate, *, force: bool = False) -> None:
+        # Shared acceptance check for all passes; silently ignores rejected candidates.
+        if len(selected) >= limit:
+            return
+        url_key = _normalize_url_key(candidate.selected_url)
+        if not candidate.selected_url or url_key in seen_urls:
+            return
+        country = candidate.country or "unknown"
+        # Cap is 3 per country; only forced additions for DE may go up to 7.
+        country_limit = 7 if force and country == "DE" else 3
+        if per_country.get(country, 0) >= country_limit:
+            return
+        selected.append(candidate)
+        seen_urls.add(url_key)
+        per_country[country] = per_country.get(country, 0) + 1
+
+    # Substrings looked for in provider, feed name, source basis and selected URL.
+    preferred_tokens = [
+        "opendata-oepnv.de",
+        "download.gtfs.de/germany/",
+        "vbb.de/vbbgtfs",
+        "rnv-online.de",
+        "vrn.de",
+        "gtfs.geops.ch",
+        "wienerlinien.at",
+        "gtfs.openov.nl",
+        "gtfs.ovapi.nl",
+        "rejseplanen.info",
+        "dev.hsl.fi/gtfs",
+        "hsldev.com/gtfs",
+        "rb_norway-aggregated-gtfs",
+        "data.bus-data.dft.gov.uk",
+        "transportforireland",
+        "gtfs.irail.be/de-lijn",
+    ]
+    # Pass 1: preferred sources.
+    for candidate in sorted_candidates:
+        text = " ".join([candidate.provider, candidate.feed_name, candidate.source_basis, candidate.selected_url]).lower()
+        if any(token in text for token in preferred_tokens):
+            add(candidate, force=True)
+    # Pass 2: walk the curated countries in order.
+    for country in CURATED_TEST_COUNTRIES:
+        for candidate in sorted_candidates:
+            if candidate.country == country:
+                add(candidate)
+            if len(selected) >= limit:
+                break
+        if len(selected) >= limit:
+            break
+    # Pass 3: fill remaining slots from any country.
+    for candidate in sorted_candidates:
+        add(candidate)
+        if len(selected) >= limit:
+            break
+    return selected
+
+
+def _test_candidate_eligible(candidate: FeedCandidate) -> bool:
+    """Return whether ``candidate`` may be used in a test run.
+
+    Requires a selected URL, a priority rank of at most 2 and none of the
+    blocking markers in status, URL, provider, feed name or notes.
+    """
+    if not candidate.selected_url or _priority_sort_key(candidate.priority) > 2:
+        return False
+    haystack = " ".join(
+        [candidate.status, candidate.selected_url, candidate.provider, candidate.feed_name, candidate.notes]
+    ).lower()
+    blockers = ("deprecated", "inactive", "{apikey}", "registration required", "authentication")
+    return not any(marker in haystack for marker in blockers)
+
+
+def annotate_url_availability(candidate: FeedCandidate, *, timeout: float = 10.0) -> FeedCandidate:
+    """Probe ``candidate.selected_url`` and record the outcome on the candidate.
+
+    A ``HEAD`` request is tried first. If it answers 403, 405 or a 5xx status, a
+    streamed ``GET`` asking for the first byte only is used instead.
+
+    Args:
+        candidate: Candidate to annotate in place.
+        timeout: Request timeout in seconds.
+
+    Returns:
+        The same candidate with ``availability_status`` set to ``missing_url``,
+        ``ok`` or ``error`` and, when a response arrived, ``http_status``,
+        ``content_type``, ``content_length`` and ``final_url`` filled in.
+        Request failures are added to ``notes`` instead of being raised.
+    """
+    if not candidate.selected_url:
+        candidate.availability_status = "missing_url"
+        return candidate
+    headers = {"User-Agent": "meubility-workbench-feed-discovery/0.1"}
+    try:
+        response = requests.head(candidate.selected_url, allow_redirects=True, timeout=timeout, headers=headers)
+        if response.status_code in {405, 403} or response.status_code >= 500:
+            response.close()
+            response = requests.get(
+                candidate.selected_url,
+                allow_redirects=True,
+                timeout=timeout,
+                headers={**headers, "Range": "bytes=0-0"},
+                stream=True,
+            )
+        try:
+            content_length = response.headers.get("content-length", "")
+            if response.status_code == 206:
+                # A partial response reports the size of the returned range (one
+                # byte); the full size is the part after "/" in Content-Range.
+                total = response.headers.get("content-range", "").rpartition("/")[2].strip()
+                if total.isdigit():
+                    content_length = total
+            candidate.http_status = str(response.status_code)
+            candidate.content_type = response.headers.get("content-type", "")
+            candidate.content_length = content_length
+            candidate.final_url = response.url
+            candidate.availability_status = "ok" if response.status_code < 400 else "error"
+        finally:
+            # The streamed body is never read, so always release the connection.
+            response.close()
+    except requests.RequestException as exc:
+        candidate.availability_status = "error"
+        candidate.notes = _join_notes(candidate.notes, f"Availability check failed: {exc}")
+    return candidate
+
+
+def normalize_candidate_geography(candidate: FeedCandidate) -> None:
+    """Overwrite ``candidate.country`` when a URL or name matches a known marker.
+
+    The markers are searched in the lower-cased URLs, provider, feed name and
+    source basis; the first matching rule wins and other candidates are left
+    unchanged.
+    """
+    haystack = " ".join(
+        [
+            candidate.selected_url,
+            candidate.direct_download_url,
+            candidate.latest_url,
+            candidate.original_release_url,
+            candidate.provider,
+            candidate.feed_name,
+            candidate.source_basis,
+        ]
+    ).lower()
+    rules = (
+        (("download.gtfs.de/germany/", "gtfs for germany"), "DE"),
+        (("storage.googleapis.com/marduk-production/outbound/gtfs/rb_norway",), "NO"),
+        (("gtfs.ovapi.nl", "openov.nl"), "NL"),
+        (("www.nvbw.de/fileadmin/user_upload/service/open_data/",), "DE"),
+    )
+    for markers, country_code in rules:
+        if any(marker in haystack for marker in markers):
+            candidate.country = country_code
+            break
+
+
+def apply_known_download_overrides(candidate: FeedCandidate) -> None:
+    """Switch listed stable ids to their ``latest_url`` and record the reason in the notes.
+
+    Candidates with another stable id or without a ``latest_url`` are left untouched.
+    """
+    known_stale = frozenset({"mdb-684", "mdb-777"})
+    if candidate.stable_id not in known_stale or not candidate.latest_url:
+        return
+    candidate.selected_url = candidate.latest_url
+    candidate.notes = _join_notes(
+        candidate.notes,
+        "Selected Mobility Database latest.zip mirror because the catalog direct URL is known to be stale.",
+    )
+
+
+@dataclass
+class _HtmlCell:
+    """Text and resolved link targets collected for one table cell."""
+
+    text: str = ""
+    links: list[str] = field(default_factory=list)
+
+    @property
+    def first_external_link(self) -> str:
+        """Return the first http(s) link whose host does not contain the PTNA domain, else ``""``."""
+
+        def is_external(link: str) -> bool:
+            parts = urlparse(link)
+            return parts.scheme in {"http", "https"} and "ptna.openstreetmap.de" not in parts.netloc
+
+        return next((link for link in self.links if is_external(link)), "")
+
+
+@dataclass
+class _HtmlRow:
+    # Cells of one table row in document order.
+    cells: list[_HtmlCell] = field(default_factory=list)
+
+
+class _TableParser(HTMLParser):
+    """Collect ``<tr>`` rows with their ``<td>``/``<th>`` cells from an HTML document.
+
+    Each cell keeps its cleaned text and the targets of the anchors inside it,
+    resolved against ``base_url``. Rows without any cell are not recorded.
+    """
+
+    def __init__(self, base_url: str):
+        super().__init__(convert_charrefs=True)
+        self.base_url = base_url
+        self.rows: list[_HtmlRow] = []
+        # Row and cell currently being filled; None while outside of one.
+        self._row: _HtmlRow | None = None
+        self._cell: _HtmlCell | None = None
+        # Target of the anchor currently open; written here but not read by this class.
+        self._active_link: str = ""
+
+    def handle_starttag(self, tag: str, attrs: list[tuple[str, str | None]]) -> None:
+        attrs_dict = {key: value or "" for key, value in attrs}
+        if tag == "tr":
+            self._row = _HtmlRow()
+        elif tag in {"td", "th"} and self._row is not None:
+            self._cell = _HtmlCell()
+        elif tag == "a" and self._cell is not None:
+            href = attrs_dict.get("href", "")
+            if href:
+                # Store absolute URLs so callers need not know the page location.
+                self._active_link = urljoin(self.base_url, href)
+                self._cell.links.append(self._active_link)
+
+    def handle_endtag(self, tag: str) -> None:
+        if tag in {"td", "th"} and self._row is not None and self._cell is not None:
+            # Cell finished: normalize its text and attach it to the open row.
+            self._cell.text = _clean_text(self._cell.text)
+            self._row.cells.append(self._cell)
+            self._cell = None
+            self._active_link = ""
+        elif tag == "a":
+            self._active_link = ""
+        elif tag == "tr":
+            if self._row is not None and self._row.cells:
+                self.rows.append(self._row)
+            self._row = None
+            self._cell = None
+            self._active_link = ""
+
+    def handle_data(self, data: str) -> None:
+        # Text outside of a cell is ignored.
+        if self._cell is not None:
+            self._cell.text += data
+
+
+class _LinkParser(HTMLParser):
+    """Collect the ``href`` targets of all anchors, resolved against ``base_url``."""
+
+    def __init__(self, base_url: str):
+        super().__init__(convert_charrefs=True)
+        self.base_url = base_url
+        self.links: list[str] = []
+
+    def handle_starttag(self, tag: str, attrs: list[tuple[str, str | None]]) -> None:
+        if tag == "a":
+            # Anchors without href, or with an empty one, contribute nothing.
+            self.links.extend(
+                urljoin(self.base_url, value) for key, value in attrs if key == "href" and value
+            )
+
+
+def _parse_table_rows(html: str, base_url: str) -> list[_HtmlRow]:
+    """Run ``_TableParser`` over ``html`` and return the rows it collected."""
+    table_parser = _TableParser(base_url)
+    table_parser.feed(html)
+    collected = table_parser.rows
+    return collected
+
+
+def _all_links(html: str, base_url: str) -> list[str]:
+    """Run ``_LinkParser`` over ``html`` and return the link targets it collected."""
+    link_parser = _LinkParser(base_url)
+    link_parser.feed(html)
+    collected = link_parser.links
+    return collected
+
+
+def _fetch_text(url: str, *, timeout: float) -> str:
+    """Download ``url`` and return the response body as text.
+
+    Args:
+        url: Address to request with ``GET``.
+        timeout: Request timeout in seconds.
+
+    Returns:
+        The decoded body. When the server declares no charset, UTF-8 (with an
+        optional byte-order mark) is tried before falling back to the decoding
+        chosen by requests.
+
+    Raises:
+        requests.RequestException: On connection problems or an HTTP error status.
+    """
+    response = requests.get(url, timeout=timeout, headers={"User-Agent": "meubility-workbench-feed-discovery/0.1"})
+    response.raise_for_status()
+    declared = response.headers.get("content-type", "")
+    if "charset=" not in declared.lower():
+        # requests decodes text/* bodies without a declared charset as
+        # ISO-8859-1, which garbles UTF-8 provider and feed names.
+        try:
+            return response.content.decode("utf-8-sig")
+        except UnicodeDecodeError:
+            pass
+    return response.text
+
+
+def _first_link_matching(links: Iterable[str], needle: str) -> str:
+    """Return the first link containing ``needle``, or ``""`` when none does."""
+    return next((link for link in links if needle in link), "")
+
+
+def _feed_id_from_url(url: str) -> str:
+    """Return the first ``feed`` query parameter of ``url``, or ``""`` when it is absent."""
+    values = parse_qs(urlparse(url).query).get("feed")
+    return values[0] if values else ""
+
+
+def _looks_like_download_url(url: str) -> bool:
+    """Guess from its shape whether ``url`` points directly at a GTFS download.
+
+    Accepted are paths ending in ``.zip``, ``gtfs`` or ``current_gtfs``, URLs
+    containing ``exportformat=gtfs`` or ``google_transit``, and any
+    ``gtfs.ovapi.nl`` URL with ``gtfs`` in its path. Matching ignores case.
+    """
+    if not url:
+        return False
+    parts = urlparse(url)
+    path = parts.path.lower()
+    whole = url.lower()
+    return (
+        path.endswith(".zip")
+        or "exportformat=gtfs" in whole
+        or "google_transit" in whole
+        or path.rstrip("/").endswith(("current_gtfs", "gtfs"))
+        or ("gtfs.ovapi.nl" in parts.netloc.lower() and "gtfs" in path)
+    )
+
+
+def _normalize_feed_url(url: str) -> str:
+    """Clean ``url`` and prefix ``https://`` when it has no scheme but starts with a dotted name.
+
+    Empty input yields ``""``; values with a scheme, or whose first path segment
+    contains no dot, are returned cleaned but otherwise unchanged.
+    """
+    cleaned = _clean_text(url)
+    if not cleaned:
+        return ""
+    if urlparse(cleaned).scheme:
+        return cleaned
+    leading_segment = cleaned.partition("/")[0]
+    return f"https://{cleaned}" if "." in leading_segment else cleaned
+
+
+def _choose_feed_url(direct_url: str, latest_url: str) -> str:
+    """Prefer ``direct_url``; fall back to ``latest_url`` when it is empty."""
+    return direct_url or latest_url
+
+
+def _candidate_priority(candidate: FeedCandidate) -> str:
+    """Derive the priority label (``P0`` is best) for ``candidate``.
+
+    Curated seeds keep their own priority (``P1`` when unset). Active feeds get
+    ``P0``/``P1`` with a direct download URL, depending on the official flag, or
+    ``P2`` with only a latest URL. Remaining PTNA candidates get ``P2`` with a
+    selected URL and ``P4`` without; everything else gets ``P3``.
+    """
+    status = candidate.status.lower()
+    official = candidate.is_official.lower() == "true"
+    if candidate.discovery_source == "curated_seed":
+        return candidate.priority or "P1"
+    if status == "active":
+        if candidate.direct_download_url:
+            return "P0" if official else "P1"
+        if candidate.latest_url:
+            return "P2"
+    if candidate.discovery_source == "ptna":
+        return "P2" if candidate.selected_url else "P4"
+    return "P3"
+
+
+def _test_candidate_sort_key(candidate: FeedCandidate) -> tuple[int, int, str, str]:
+    """Sort key for test-run selection: priority rank, source plus country rank, country, provider.
+
+    Curated seeds rank one step ahead of other sources; countries outside
+    ``CURATED_TEST_COUNTRIES`` rank as 99.
+    """
+    try:
+        country_rank = CURATED_TEST_COUNTRIES.index(candidate.country)
+    except ValueError:
+        country_rank = 99
+    source_rank = 0 if candidate.discovery_source == "curated_seed" else 1
+    return (
+        _priority_sort_key(candidate.priority),
+        source_rank + country_rank,
+        candidate.country,
+        candidate.provider.lower(),
+    )
+
+
+def _priority_sort_key(priority: str) -> int:
+    """Return the number following a leading ``P`` in ``priority``; 9 when there is none.
+
+    Only an upper-case ``P`` at the very start followed by digits is
+    recognised; ``None`` and the empty string also yield 9.
+    """
+    found = re.match(r"P(\d+)", priority or "")
+    if found is None:
+        return 9
+    return int(found.group(1))
+
+
+def _candidate_alias_keys(candidate: FeedCandidate) -> list[str]:
+    """List every key under which ``candidate`` may be looked up, without duplicates.
+
+    The list starts with ``candidate.key()``, followed by a ``stable:`` key, one
+    ``url:`` key per non-empty URL field (selected, direct download, latest)
+    and a ``ptna:`` key. First-occurrence order is preserved.
+    """
+    aliases = [candidate.key()]
+    if candidate.stable_id:
+        aliases.append(f"stable:{candidate.stable_id}")
+    aliases.extend(
+        f"url:{_normalize_url_key(url)}"
+        for url in (candidate.selected_url, candidate.direct_download_url, candidate.latest_url)
+        if url
+    )
+    if candidate.ptna_feed_id:
+        aliases.append(f"ptna:{candidate.ptna_feed_id}")
+    # dict keys keep insertion order, so this drops repeats but keeps the first occurrence.
+    return list(dict.fromkeys(aliases))
+
+
+def _merge_candidate(existing: FeedCandidate, incoming: FeedCandidate) -> None:
+    """Fold ``incoming`` into ``existing`` in place.
+
+    Steps, in order:
+      1. When ``incoming`` is a curated seed, its non-empty descriptive fields
+         overwrite those of ``existing``.
+      2. The discovery sources are unioned.
+      3. Every ``CANONICAL_HEADERS`` field except ``candidate_id`` that is
+         still empty on ``existing`` is filled from ``incoming``.
+      4. The better priority is kept, and ``source_basis`` / ``notes`` are
+         unioned.
+    """
+    if incoming.discovery_source == "curated_seed":
+        curated_fields = ("country", "provider", "feed_name", "license_text", "features", "source_basis", "notes")
+        for name in curated_fields:
+            curated_value = getattr(incoming, name, "")
+            if curated_value:
+                setattr(existing, name, curated_value)
+    existing.discovery_source = _join_unique(existing.discovery_source, incoming.discovery_source)
+    for name in CANONICAL_HEADERS:
+        if name == "candidate_id":
+            continue
+        present = getattr(existing, name, "")
+        offered = getattr(incoming, name, "")
+        if offered and not present:
+            setattr(existing, name, offered)
+    existing.priority = _better_priority(existing.priority, incoming.priority)
+    existing.source_basis = _join_unique(existing.source_basis, incoming.source_basis)
+    existing.notes = _join_notes(existing.notes, incoming.notes)
+
+
+def _better_priority(left: str, right: str) -> str:
+    """Return whichever label has the lower ``_priority_sort_key``; ``left`` wins ties."""
+    if _priority_sort_key(left) <= _priority_sort_key(right):
+        return left
+    return right
+
+
+def _join_unique(left: str, right: str) -> str:
+    """Union two ``;``-separated lists into one ``"; "``-joined string.
+
+    Items are stripped, empty items are dropped, and the first occurrence of
+    each item determines its position.
+    """
+    seen: dict[str, None] = {}
+    for value in (left, right):
+        for piece in value.split(";"):
+            token = piece.strip()
+            if token:
+                # setdefault keeps the position of the first occurrence.
+                seen.setdefault(token, None)
+    return "; ".join(seen)
+
+
+def _join_notes(left: str, right: str) -> str:
+    """Merge two note strings; notes use the same ``;``-separated union as ``_join_unique``."""
+    merged_notes = _join_unique(left, right)
+    return merged_notes
+
+
+def _compact_name(value: str) -> str:
+    """Return ``value`` cleaned by ``_clean_text`` with whitespace runs collapsed to one space."""
+    normalized = _clean_text(value)
+    return re.sub(r"\s+", " ", normalized).strip()
+
+
+def _feed_source_name(country: str, value: str) -> str:
+    """Build a display name of the form ``"<COUNTRY> <name> GTFS"``.
+
+    The name defaults to ``"GTFS feed"`` when ``value`` compacts to nothing.
+    The upper-cased country is prepended unless the name already starts with
+    it, and ``" GTFS"`` is appended unless the result already mentions
+    ``gtfs`` (case-insensitive).
+    """
+    label = _compact_name(value) or "GTFS feed"
+    country_code = country.upper()
+    already_prefixed = label.upper().startswith(f"{country_code} ")
+    if country_code and not already_prefixed:
+        label = f"{country_code} {label}"
+    if "gtfs" in label.lower():
+        return label
+    return f"{label} GTFS"
+
+
+def _clean_text(value: str) -> str:
+    """Normalize scraped text.
+
+    Runs ``unescape`` on the value (``None`` is treated as ``""``), turns
+    non-breaking spaces into plain spaces, collapses every whitespace run to a
+    single space and strips the ends.
+    """
+    decoded = unescape(value or "")
+    without_nbsp = decoded.replace("\xa0", " ")
+    return re.sub(r"\s+", " ", without_nbsp).strip()
+
+
+def _mode_scope_from_features(features: str) -> str:
+    """Map a free-text feature description to a comma-separated mode list.
+
+    Recognised keywords (case-insensitive substring matches):
+      * ``rail`` / ``train``        -> ``rail``
+      * ``tram`` / ``light_rail``   -> ``tram``
+      * ``subway`` / ``metro``      -> ``metro``
+      * ``bus``                     -> ``bus``
+      * ``ferry``                   -> ``ferry``
+
+    Modes are emitted in the order above. ``bus`` is the fallback when no
+    keyword matches at all.
+
+    Args:
+        features: Feature text to scan.
+
+    Returns:
+        The detected modes joined with ``","``; ``"bus"`` when none matched.
+    """
+    lower = features.lower()
+    detected = {
+        "rail": "rail" in lower or "train" in lower,
+        "tram": "tram" in lower or "light_rail" in lower,
+        "metro": "subway" in lower or "metro" in lower,
+        "bus": "bus" in lower,
+        "ferry": "ferry" in lower,
+    }
+    modes = [mode for mode, present in detected.items() if present]
+    # The bus fallback is decided only after every keyword has been checked.
+    # Previously it ran before the ferry check, so a ferry-only feed came out
+    # as "bus,ferry".
+    return ",".join(modes) if modes else "bus"
+
+
+def _bbox_from_mobility_row(row: dict[str, str]) -> str:
+    """Format the row's bounding box as ``"min_lon,min_lat,max_lon,max_lat"``.
+
+    Returns ``""`` when any of the four ``location.bounding_box.*`` values is
+    empty after cleaning.
+    """
+    south = _value(row, "location.bounding_box.minimum_latitude")
+    north = _value(row, "location.bounding_box.maximum_latitude")
+    west = _value(row, "location.bounding_box.minimum_longitude")
+    east = _value(row, "location.bounding_box.maximum_longitude")
+    if south and north and west and east:
+        return f"{west},{south},{east},{north}"
+    return ""
+
+
+def _normalize_countries(countries: Iterable[str] | None) -> list[str] | None:
+    """Normalize a country filter.
+
+    ``None`` selects ``DEFAULT_DISCOVERY_COUNTRIES``. Otherwise blank entries
+    are dropped and the rest are stripped and upper-cased; if any entry is
+    ``ALL`` the result is ``None``.
+    """
+    if countries is None:
+        return DEFAULT_DISCOVERY_COUNTRIES
+    codes: list[str] = []
+    for raw in countries:
+        if raw and raw.strip():
+            codes.append(raw.strip().upper())
+    return None if "ALL" in codes else codes
+
+
+def _normalize_url_key(url: str) -> str:
+    """Reduce ``url`` to a comparison key.
+
+    Scheme and host are lower-cased, trailing slashes are removed from the
+    path and the query string is kept verbatim. Params and the fragment are
+    not included.
+    """
+    parts = urlparse(url.strip())
+    key = f"{parts.scheme.lower()}://{parts.netloc.lower()}{parts.path.rstrip('/')}"
+    if parts.query:
+        key += f"?{parts.query}"
+    return key
+
+
+def _write_csv(path: Path, headers: list[str], rows: list[dict[str, str]]) -> None:
+    """Write ``rows`` to ``path`` as UTF-8 CSV with a header line.
+
+    Only the columns named in ``headers`` are written; extra keys in a row are
+    ignored. Any existing file is overwritten.
+    """
+    with path.open("w", encoding="utf-8", newline="") as out:
+        table = csv.DictWriter(out, fieldnames=headers, extrasaction="ignore")
+        table.writeheader()
+        for record in rows:
+            table.writerow(record)
+
+
+def _count_by(items: Iterable[FeedCandidate], key_fn) -> dict[str, int]:
+    """Count ``items`` per ``key_fn(item)`` and return the counts ordered by key."""
+    tally: dict[str, int] = {}
+    for item in items:
+        label = key_fn(item)
+        tally[label] = tally.get(label, 0) + 1
+    return {label: tally[label] for label in sorted(tally)}
+
+
+def _value(row: dict[str, str], key: str) -> str:
+    """Return ``row[key]`` passed through ``_clean_text``; a missing key reads as ``""``."""
+    raw = row.get(key, "")
+    return _clean_text(raw)
+
+
+def _string(value: object) -> str:
+    """Return ``str(value)``, mapping ``None`` to the empty string."""
+    if value is None:
+        return ""
+    return str(value)
+
+
+def _truncate(value: str, length: int) -> str:
+    """Return the first ``length`` characters of ``value``; falsy values give ``""``."""
+    if not value:
+        return ""
+    return value[:length]
diff --git a/app/geofabrik.py b/app/geofabrik.py
new file mode 100644
index 0000000..c5c011b
--- /dev/null
+++ b/app/geofabrik.py
@@ -0,0 +1,120 @@
+from __future__ import annotations
+
+from datetime import datetime, timedelta, timezone
+from typing import Any
+
+import requests
+from sqlalchemy import select
+from sqlalchemy.orm import Session
+
+from app.models import Source
+
+
+# URL of the Geofabrik extract index; _geofabrik_rows() downloads it and reads its "features" list.
+GEOFABRIK_INDEX_URL = "https://download.geofabrik.de/index-v1-nogeom.json"
+# Module-level cache for the normalized index rows: "rows" holds the list and
+# "expires_at" the datetime after which _geofabrik_rows() fetches the index again.
+_CACHE: dict[str, Any] = {"expires_at": None, "rows": None}
+
+
+def geofabrik_catalog(q: str | None = None, limit: int = 80) -> list[dict[str, Any]]:
+    """List Geofabrik extracts, optionally filtered by a search string.
+
+    Args:
+        q: Case-insensitive substring matched against each row's id, name,
+            parent and space-joined country codes. Blank or ``None`` disables
+            filtering.
+        limit: Maximum number of rows; clamped to the range 1..500.
+
+    Returns:
+        Rows sorted by ``(parent or "", name)``, cut to the clamped limit.
+    """
+    needle = (q or "").strip().casefold()
+
+    def _matches(row: dict[str, Any]) -> bool:
+        # Checked in the same order as the fields are listed in the docstring.
+        if needle in row["id"].casefold():
+            return True
+        if needle in row["name"].casefold():
+            return True
+        if needle in (row.get("parent") or "").casefold():
+            return True
+        return needle in " ".join(row.get("country_codes") or []).casefold()
+
+    rows = _geofabrik_rows()
+    if needle:
+        rows = [row for row in rows if _matches(row)]
+    ordered = sorted(rows, key=lambda row: (row.get("parent") or "", row["name"]))
+    cap = max(1, min(limit, 500))
+    return ordered[:cap]
+
+
+def geofabrik_entry(geofabrik_id: str) -> dict[str, Any] | None:
+    """Return the index row whose id equals ``geofabrik_id`` (trimmed, case-insensitive), else ``None``."""
+    wanted = geofabrik_id.strip().casefold()
+    return next((row for row in _geofabrik_rows() if row["id"].casefold() == wanted), None)
+
+
+def create_geofabrik_source(session: Session, geofabrik_id: str, *, import_updates: bool = False) -> Source:
+    """Register a Geofabrik extract as an ``osm_pbf`` source.
+
+    If a source with the same kind and PBF URL already exists it is returned
+    unchanged. Otherwise a new source is added and flushed; when
+    ``import_updates`` is true and the index entry has an updates URL, a
+    companion ``osm_diff`` source is added as well.
+
+    Args:
+        session: SQLAlchemy session the sources are added to (not committed here).
+        geofabrik_id: Id of the extract in the Geofabrik index.
+        import_updates: Whether to also register the replication-diff source.
+
+    Returns:
+        The existing or newly created ``osm_pbf`` source.
+
+    Raises:
+        ValueError: If the extract is unknown or has no PBF URL.
+    """
+    entry = geofabrik_entry(geofabrik_id)
+    if entry is None:
+        raise ValueError(f"Geofabrik extract not found: {geofabrik_id}")
+    pbf_url = entry.get("pbf_url")
+    if not pbf_url:
+        raise ValueError(f"Geofabrik extract has no PBF URL: {geofabrik_id}")
+
+    duplicate_query = select(Source).where(Source.kind == "osm_pbf", Source.url == pbf_url)
+    already_registered = session.scalar(duplicate_query)
+    if already_registered is not None:
+        return already_registered
+
+    country_codes = entry.get("country_codes") or []
+    source = Source(
+        name=f"Geofabrik {entry['name']}",
+        kind="osm_pbf",
+        url=pbf_url,
+        # The joined code list is cut to 8 characters.
+        country=",".join(country_codes)[:8] or None,
+        license="ODbL / Geofabrik extract terms",
+        priority="P0 fallback",
+        mode_scope="public transport OSM routes, stops, and infrastructure",
+        source_basis="OpenStreetMap / Geofabrik extracts",
+        notes=_geofabrik_notes(entry, import_updates=import_updates),
+    )
+    session.add(source)
+    session.flush()
+
+    updates_url = entry.get("updates_url")
+    if import_updates and updates_url:
+        session.add(
+            Source(
+                name=f"Geofabrik {entry['name']} updates",
+                kind="osm_diff",
+                url=updates_url,
+                country=source.country,
+                license=source.license,
+                priority=source.priority,
+                mode_scope=source.mode_scope,
+                source_basis="OpenStreetMap / Geofabrik replication diffs",
+                notes=f"Diff base for Geofabrik extract {entry['id']}; applying diffs to a local base extract is not implemented yet.",
+            )
+        )
+    return source
+
+
+def _geofabrik_rows() -> list[dict[str, Any]]:
+    """Return the normalized Geofabrik index rows, using the module cache for 12 hours.
+
+    On a cache miss the index is downloaded (45 s timeout), every feature is
+    normalized, and rows lacking an id or a PBF URL are dropped. The returned
+    list is a fresh list object, but the row dicts are shared with the cache.
+
+    Raises:
+        Whatever ``response.raise_for_status()`` raises for an error status.
+    """
+    now = datetime.now(timezone.utc)
+    cached_rows = _CACHE.get("rows")
+    expiry = _CACHE.get("expires_at")
+    cache_is_fresh = cached_rows is not None and isinstance(expiry, datetime) and expiry > now
+    if cache_is_fresh:
+        return list(cached_rows)
+
+    reply = requests.get(GEOFABRIK_INDEX_URL, timeout=45)
+    reply.raise_for_status()
+    rows: list[dict[str, Any]] = []
+    for feature in reply.json().get("features", []):
+        row = _normalize_feature(feature)
+        if row.get("id") and row.get("pbf_url"):
+            rows.append(row)
+    _CACHE["rows"] = rows
+    _CACHE["expires_at"] = now + timedelta(hours=12)
+    return list(rows)
+
+
+def _normalize_feature(feature: dict[str, Any]) -> dict[str, Any]:
+    """Flatten one feature of the Geofabrik index into a catalog row.
+
+    Missing ``properties`` / ``urls`` are treated as empty. ``id`` and ``name``
+    are always strings (``name`` falls back to the id), and a single country
+    code given as a string is wrapped in a list.
+    """
+    properties = feature.get("properties") or {}
+    url_map = properties.get("urls") or {}
+    codes = properties.get("iso3166-1:alpha2") or []
+    if isinstance(codes, str):
+        codes = [codes]
+    identifier = properties.get("id")
+    row: dict[str, Any] = {
+        "id": str(identifier or ""),
+        "name": str(properties.get("name") or identifier or ""),
+        "parent": properties.get("parent"),
+        "country_codes": codes,
+    }
+    for row_key, url_key in (("pbf_url", "pbf"), ("updates_url", "updates"), ("taginfo_url", "taginfo")):
+        row[row_key] = url_map.get(url_key)
+    row["urls"] = url_map
+    return row
+
+
+def _geofabrik_notes(entry: dict[str, Any], *, import_updates: bool) -> str:
+    """Compose the ``"; "``-joined notes string stored on a Geofabrik source.
+
+    It records the extract id, its parent (``root`` when none), the updates
+    URL (may be empty), whether a diff source was requested, and a fixed
+    remark about overlap handling.
+    """
+    requested_flag = "true" if import_updates else "false"
+    return "; ".join(
+        (
+            f"geofabrik_id={entry['id']}",
+            f"parent={entry.get('parent') or 'root'}",
+            f"updates_url={entry.get('updates_url') or ''}",
+            f"diff_source_requested={requested_flag}",
+            "Overlap dedupe is handled by OSM object identity in the route layer; source-specific map layers may still show both extracts.",
+        )
+    )
diff --git a/app/gtfs_storage.py b/app/gtfs_storage.py
new file mode 100644
index 0000000..80d9016
--- /dev/null
+++ b/app/gtfs_storage.py
@@ -0,0 +1,308 @@
+from __future__ import annotations
+
+import json
+import sqlite3
+from contextlib import contextmanager
+from pathlib import Path
+from typing import Iterator, Sequence
+
+from sqlalchemy import func, select
+from sqlalchemy.orm import Session
+
+from app.config import settings
+from app.models import Dataset, GtfsStopTime
+
+
+# Key in the dataset's metadata JSON under which the storage description dict is kept.
+GTFS_STORAGE_METADATA_KEY = "gtfs_storage"
+# Storage mode name: stop_times are read from the main database (GtfsStopTime table).
+GTFS_STORAGE_MAIN = "main"
+# Storage mode name: stop_times are read from a separate SQLite sidecar file.
+GTFS_STORAGE_SIDECAR_STOP_TIMES = "sidecar_stop_times"
+# Columns selected from the sidecar's gtfs_stop_times table when rebuilding GtfsStopTime objects.
+GTFS_STOP_TIME_COLUMNS = [
+ "trip_id",
+ "stop_id",
+ "stop_sequence",
+ "arrival_time",
+ "departure_time",
+ "arrival_seconds",
+ "departure_seconds",
+]
+# Maximum number of values bound into a single IN (...) clause.
+# NOTE(review): presumably chosen to stay below SQLite's bound-parameter limit — confirm.
+SQLITE_IN_CHUNK_SIZE = 800
+
+
+def effective_gtfs_timetable_storage(value: str | None = None) -> str:
+    """Resolve the stop_times storage mode to use.
+
+    The requested ``value`` (or, when empty, the configured setting, or the
+    sidecar default) is normalized to lower case. Any "main" alias selects
+    ``GTFS_STORAGE_MAIN``; so does a PostgreSQL database with sidecars turned
+    off. Everything else selects ``GTFS_STORAGE_SIDECAR_STOP_TIMES``.
+    """
+    raw_choice = value or settings.gtfs_timetable_storage or GTFS_STORAGE_SIDECAR_STOP_TIMES
+    configured = str(raw_choice).strip().lower()
+    main_aliases = {GTFS_STORAGE_MAIN, "main_db", "main_sqlite", "postgres", "postgresql"}
+    if configured in main_aliases:
+        return GTFS_STORAGE_MAIN
+    if settings.is_postgresql_database and not settings.postgres_use_sidecars:
+        return GTFS_STORAGE_MAIN
+    return GTFS_STORAGE_SIDECAR_STOP_TIMES
+
+
+class MissingGtfsSidecar(FileNotFoundError):
+    """Raised when a dataset's GTFS sidecar is not configured or not on disk.
+
+    Attributes are set in ``__init__``:
+      * ``dataset_id`` — id of the dataset concerned (may be ``None``);
+      * ``path`` — the configured sidecar path, or ``None`` when the dataset
+        does not reference one.
+    """
+
+    def __init__(self, dataset_id: int | None, path: Path | None) -> None:
+        self.dataset_id = dataset_id
+        self.path = path
+        message = (
+            f"dataset #{dataset_id} does not reference a GTFS sidecar"
+            if path is None
+            else f"GTFS sidecar does not exist: {path}"
+        )
+        super().__init__(message)
+
+
+def dataset_metadata(dataset: Dataset) -> dict:
+    """Parse ``dataset.metadata_json`` and return it as a dict.
+
+    An empty value, invalid JSON, or JSON that is not an object all yield ``{}``.
+    """
+    raw = dataset.metadata_json or "{}"
+    try:
+        parsed = json.loads(raw)
+    except json.JSONDecodeError:
+        return {}
+    if isinstance(parsed, dict):
+        return parsed
+    return {}
+
+
+def stop_times_are_sidecar(dataset: Dataset | None) -> bool:
+    """Tell whether the dataset's metadata says its stop_times live in a sidecar.
+
+    A per-table entry (``tables["gtfs_stop_times"] == "sidecar"``) takes
+    precedence when ``tables`` is a dict; otherwise the storage ``mode`` is
+    compared with ``GTFS_STORAGE_SIDECAR_STOP_TIMES``. A missing dataset or
+    missing/malformed storage metadata gives ``False``.
+    """
+    if dataset is None:
+        return False
+    storage = dataset_metadata(dataset).get(GTFS_STORAGE_METADATA_KEY)
+    if not isinstance(storage, dict):
+        return False
+    tables = storage.get("tables")
+    if not isinstance(tables, dict):
+        return storage.get("mode") == GTFS_STORAGE_SIDECAR_STOP_TIMES
+    return tables.get("gtfs_stop_times") == "sidecar"
+
+
+def sidecar_path(dataset: Dataset | None) -> Path | None:
+    """Return the sidecar path recorded in the dataset's storage metadata, if any.
+
+    ``None`` is returned for a missing dataset, missing/malformed storage
+    metadata, or an empty ``sidecar_path`` entry. Existence on disk is not
+    checked.
+    """
+    if dataset is None:
+        return None
+    storage = dataset_metadata(dataset).get(GTFS_STORAGE_METADATA_KEY)
+    configured = storage.get("sidecar_path") if isinstance(storage, dict) else None
+    return Path(str(configured)) if configured else None
+
+
+def dataset_sidecar_paths(dataset: Dataset) -> list[Path]:
+    """Return the dataset's sidecar path as a one-element list, or ``[]`` when it has none."""
+    path = sidecar_path(dataset)
+    if path is None:
+        return []
+    return [path]
+
+
+def missing_sidecar_paths(dataset: Dataset | None) -> list[str]:
+    """Describe what is missing for a dataset that should have a sidecar.
+
+    Returns ``[]`` when the dataset does not use a sidecar or the sidecar file
+    exists; the path as a string when the file is absent; or a message when no
+    path is configured at all.
+    """
+    if not stop_times_are_sidecar(dataset):
+        return []
+    path = sidecar_path(dataset)
+    if path is not None:
+        return [] if path.exists() else [str(path)]
+    # stop_times_are_sidecar() is False for None, so the "unknown" branch is only a safeguard.
+    dataset_id = "unknown" if dataset is None else str(dataset.id)
+    return [f"dataset #{dataset_id} has no configured GTFS sidecar path"]
+
+
+def uses_sidecar_stop_times(session: Session, dataset_id: int) -> bool:
+    """Load the dataset by id and report whether its stop_times are stored in a sidecar."""
+    dataset = session.get(Dataset, dataset_id)
+    return stop_times_are_sidecar(dataset)
+
+
+@contextmanager
+def sidecar_connection(dataset: Dataset) -> Iterator[sqlite3.Connection]:
+    """Open the dataset's GTFS sidecar read-only and close it on exit.
+
+    Rows are returned as ``sqlite3.Row`` so columns can be read by name.
+
+    Args:
+        dataset: Dataset whose storage metadata names the sidecar file.
+
+    Yields:
+        A read-only ``sqlite3.Connection`` to the sidecar.
+
+    Raises:
+        MissingGtfsSidecar: If no sidecar path is configured or the file does
+            not exist.
+    """
+    path = sidecar_path(dataset)
+    if path is None:
+        raise MissingGtfsSidecar(dataset.id, None)
+    if not path.exists():
+        raise MissingGtfsSidecar(dataset.id, path)
+    # Build the URI with pathlib so characters that are special in a URI
+    # ("?", "#", "%") are percent-encoded. Interpolating the raw path would
+    # let SQLite read part of such a file name as the query string.
+    # as_uri() needs an absolute path, hence resolve().
+    uri = f"{path.resolve().as_uri()}?mode=ro"
+    connection = sqlite3.connect(uri, uri=True)
+    connection.row_factory = sqlite3.Row
+    try:
+        yield connection
+    finally:
+        connection.close()
+
+
+def stop_time_count(session: Session, dataset_id: int) -> int:
+    """Count the stop_times rows of a dataset.
+
+    Sidecar datasets are counted in their SQLite sidecar (0 when the sidecar
+    is missing); all others are counted in the ``GtfsStopTime`` table.
+    """
+    dataset = session.get(Dataset, dataset_id)
+    if not stop_times_are_sidecar(dataset):
+        query = select(func.count()).select_from(GtfsStopTime).where(GtfsStopTime.dataset_id == dataset_id)
+        return session.scalar(query) or 0
+    try:
+        with sidecar_connection(dataset) as connection:
+            first_row = connection.execute("SELECT COUNT(*) FROM gtfs_stop_times").fetchone()
+            return int(first_row[0] or 0)
+    except MissingGtfsSidecar:
+        return 0
+
+
+def stop_time_counts_by_dataset(session: Session, dataset_ids: Sequence[int]) -> dict[int, int]:
+    """Map each dataset id (coerced to ``int``) to its ``stop_time_count``."""
+    return {int(dataset_id): stop_time_count(session, int(dataset_id)) for dataset_id in dataset_ids}
+
+
+def scheduled_stop_ids(session: Session, dataset_id: int, stop_ids: Sequence[str]) -> tuple[str, ...]:
+ """Return the subset of ``stop_ids`` that occur in the dataset's stop_times, sorted.
+
+ Sidecar datasets are queried in their SQLite sidecar, all others in the
+ ``GtfsStopTime`` table filtered by ``dataset_id``. An empty ``stop_ids`` or a
+ missing sidecar yields an empty tuple.
+ """
+ if not stop_ids:
+ return ()
+ dataset = session.get(Dataset, dataset_id)
+ requested = [str(stop_id) for stop_id in stop_ids]
+ found: set[str] = set()
+ if stop_times_are_sidecar(dataset):
+ try:
+ with sidecar_connection(dataset) as connection:
+ # Query in chunks so each statement binds at most SQLITE_IN_CHUNK_SIZE values.
+ for chunk in _chunks(requested, SQLITE_IN_CHUNK_SIZE):
+ placeholders = ", ".join(["?"] * len(chunk))
+ # Only "?" placeholders are interpolated into the SQL; the ids themselves are bound.
+ # NOTE(review): there is no dataset_id filter here — presumably one sidecar holds one dataset; confirm.
+ rows = connection.execute(
+ f"""
+ SELECT stop_id
+ FROM gtfs_stop_times
+ WHERE stop_id IN ({placeholders})
+ GROUP BY stop_id
+ """,
+ list(chunk),
+ ).fetchall()
+ found.update(str(row["stop_id"]) for row in rows)
+ except MissingGtfsSidecar:
+ return ()
+ else:
+ for chunk in _chunks(requested, SQLITE_IN_CHUNK_SIZE):
+ rows = session.scalars(
+ select(GtfsStopTime.stop_id)
+ .where(GtfsStopTime.dataset_id == dataset_id, GtfsStopTime.stop_id.in_(chunk))
+ .group_by(GtfsStopTime.stop_id)
+ ).all()
+ found.update(str(row) for row in rows)
+ return tuple(sorted(found))
+
+
+def all_scheduled_stop_ids(session: Session, dataset_id: int) -> set[str]:
+    """Return every distinct stop id that appears in the dataset's stop_times.
+
+    Sidecar datasets are read from their SQLite sidecar (empty set when it is
+    missing); all others from the ``GtfsStopTime`` table.
+    """
+    dataset = session.get(Dataset, dataset_id)
+    if not stop_times_are_sidecar(dataset):
+        query = (
+            select(GtfsStopTime.stop_id)
+            .where(GtfsStopTime.dataset_id == dataset_id)
+            .group_by(GtfsStopTime.stop_id)
+        )
+        return {str(stop_id) for stop_id in session.scalars(query).all()}
+    try:
+        with sidecar_connection(dataset) as connection:
+            rows = connection.execute("SELECT stop_id FROM gtfs_stop_times GROUP BY stop_id").fetchall()
+            return {str(row["stop_id"]) for row in rows}
+    except MissingGtfsSidecar:
+        return set()
+
+
+def scheduled_stop_ids_by_dataset(session: Session, dataset_ids: Sequence[int]) -> dict[int, set[str]]:
+    """Map each dataset id (coerced to ``int``) to the set from ``all_scheduled_stop_ids``."""
+    scheduled: dict[int, set[str]] = {}
+    for dataset_id in dataset_ids:
+        key = int(dataset_id)
+        scheduled[key] = all_scheduled_stop_ids(session, int(dataset_id))
+    return scheduled
+
+
+def has_scheduled_stop(session: Session, dataset_id: int, stop_id: str) -> bool:
+    """Report whether ``stop_id`` occurs in the dataset's stop_times."""
+    matches = scheduled_stop_ids(session, dataset_id, [stop_id])
+    return bool(matches)
+
+
+def stop_times_by_trip(
+ session: Session,
+ dataset_id: int,
+ trip_ids: Sequence[str],
+) -> dict[str, list[GtfsStopTime]]:
+ """Load the stop_times of the given trips, grouped by trip id.
+
+ Each statement orders its rows by ``trip_id`` and ``stop_sequence``. Trips
+ without rows get no key. An empty ``trip_ids`` or a missing sidecar yields ``{}``.
+ For sidecar datasets the returned objects are built by ``stop_time_from_row``
+ rather than loaded through the session.
+ """
+ if not trip_ids:
+ return {}
+ grouped: dict[str, list[GtfsStopTime]] = {}
+ dataset = session.get(Dataset, dataset_id)
+ requested = [str(trip_id) for trip_id in trip_ids]
+ if stop_times_are_sidecar(dataset):
+ column_sql = ", ".join(GTFS_STOP_TIME_COLUMNS)
+ try:
+ with sidecar_connection(dataset) as connection:
+ # Query in chunks so each statement binds at most SQLITE_IN_CHUNK_SIZE values.
+ for chunk in _chunks(requested, SQLITE_IN_CHUNK_SIZE):
+ placeholders = ", ".join(["?"] * len(chunk))
+ # Only the fixed column list and "?" placeholders are interpolated; trip ids are bound.
+ rows = connection.execute(
+ f"""
+ SELECT {column_sql}
+ FROM gtfs_stop_times
+ WHERE trip_id IN ({placeholders})
+ ORDER BY trip_id, stop_sequence
+ """,
+ list(chunk),
+ ).fetchall()
+ for row in rows:
+ stop_time = stop_time_from_row(dataset_id, row)
+ grouped.setdefault(stop_time.trip_id, []).append(stop_time)
+ except MissingGtfsSidecar:
+ # Rows collected from earlier chunks are discarded as well.
+ return {}
+ return grouped
+
+ for chunk in _chunks(requested, SQLITE_IN_CHUNK_SIZE):
+ rows = session.scalars(
+ select(GtfsStopTime)
+ .where(GtfsStopTime.dataset_id == dataset_id, GtfsStopTime.trip_id.in_(chunk))
+ .order_by(GtfsStopTime.trip_id, GtfsStopTime.stop_sequence)
+ ).all()
+ for row in rows:
+ grouped.setdefault(row.trip_id, []).append(row)
+ return grouped
+
+
+def stop_times_for_trip_range(
+ session: Session,
+ dataset_id: int,
+ trip_id: str,
+ start_sequence: int,
+ end_sequence: int,
+) -> list[GtfsStopTime]:
+ """Return one trip's stop_times with ``start_sequence <= stop_sequence <= end_sequence``.
+
+ Rows are ordered by ``stop_sequence``; both bounds are inclusive. Sidecar datasets are
+ read from their SQLite sidecar (empty list when it is missing), all others from the
+ ``GtfsStopTime`` table.
+ """
+ dataset = session.get(Dataset, dataset_id)
+ if stop_times_are_sidecar(dataset):
+ column_sql = ", ".join(GTFS_STOP_TIME_COLUMNS)
+ try:
+ with sidecar_connection(dataset) as connection:
+ # Only the fixed column list is interpolated; trip id and bounds are bound parameters.
+ rows = connection.execute(
+ f"""
+ SELECT {column_sql}
+ FROM gtfs_stop_times
+ WHERE trip_id = ?
+ AND stop_sequence >= ?
+ AND stop_sequence <= ?
+ ORDER BY stop_sequence
+ """,
+ (trip_id, int(start_sequence), int(end_sequence)),
+ ).fetchall()
+ return [stop_time_from_row(dataset_id, row) for row in rows]
+ except MissingGtfsSidecar:
+ return []
+
+ return list(
+ session.scalars(
+ select(GtfsStopTime)
+ .where(
+ GtfsStopTime.dataset_id == dataset_id,
+ GtfsStopTime.trip_id == trip_id,
+ GtfsStopTime.stop_sequence >= start_sequence,
+ GtfsStopTime.stop_sequence <= end_sequence,
+ )
+ .order_by(GtfsStopTime.stop_sequence)
+ ).all()
+ )
+
+
+def stop_time_from_row(dataset_id: int, row) -> GtfsStopTime:
+    """Build a ``GtfsStopTime`` from a sidecar row addressed by column name.
+
+    ``trip_id`` and ``stop_id`` are coerced to ``str`` and ``stop_sequence`` to
+    ``int``; the time columns are passed through unchanged. The object is only
+    constructed here, not added to any session.
+    """
+    fields = {
+        "dataset_id": dataset_id,
+        "trip_id": str(row["trip_id"]),
+        "stop_id": str(row["stop_id"]),
+        "stop_sequence": int(row["stop_sequence"]),
+    }
+    for passthrough in ("arrival_time", "departure_time", "arrival_seconds", "departure_seconds"):
+        fields[passthrough] = row[passthrough]
+    return GtfsStopTime(**fields)
+
+
+def execute_sidecar_query(session: Session, dataset_id: int, sql: str, params: Sequence[object]) -> list[sqlite3.Row]:
+    """Run ``sql`` with ``params`` against the dataset's sidecar and return all rows.
+
+    Returns ``[]`` when the sidecar is missing.
+
+    Raises:
+        ValueError: If the dataset does not keep its stop_times in a sidecar.
+    """
+    dataset = session.get(Dataset, dataset_id)
+    if not stop_times_are_sidecar(dataset):
+        raise ValueError(f"dataset #{dataset_id} does not use sidecar stop_times")
+    try:
+        with sidecar_connection(dataset) as connection:
+            cursor = connection.execute(sql, list(params))
+            return list(cursor.fetchall())
+    except MissingGtfsSidecar:
+        return []
+
+
+def _chunks[T](items: Sequence[T], size: int) -> Iterator[Sequence[T]]:
+    """Yield consecutive slices of ``items`` holding at most ``size`` elements each."""
+    total = len(items)
+    for start in range(0, total, size):
+        stop = start + size
+        yield items[start:stop]
diff --git a/app/harmonization.py b/app/harmonization.py
new file mode 100644
index 0000000..701425c
--- /dev/null
+++ b/app/harmonization.py
@@ -0,0 +1,394 @@
+from __future__ import annotations
+
+from datetime import date, datetime, timezone
+from typing import Any
+
+from sqlalchemy import and_, func, select
+from sqlalchemy.orm import Session, aliased
+
+from app.data_management import dataset_row_counts
+from app.models import (
+ CanonicalStopLink,
+ Dataset,
+ GtfsCalendar,
+ GtfsCalendarDate,
+ GtfsRoute,
+ GtfsStop,
+ GtfsStopTime,
+ GtfsTrip,
+ RouteMatch,
+ Source,
+)
+
+
+# Marker at the start of the line in Source.notes that carries the QA review fields
+# (parsed by _qa_review_payload as ";"-separated key=value pairs).
+GTFS_QA_NOTE_PREFIX = "[GTFS QA]"
+
+
+def gtfs_harmonization_inventory(session: Session) -> dict[str, Any]:
+    """Build the harmonization inventory for every GTFS source.
+
+    Returns:
+        ``{"summary": ..., "feeds": ...}`` where ``feeds`` has one inventory
+        item per source and ``summary`` counts sources, sources with an active
+        dataset, datasets, and feeds per QA status.
+    """
+    feeds = [_feed_inventory_item(session, source) for source in _gtfs_sources(session)]
+
+    def _with_status(status: str) -> int:
+        return sum(1 for feed in feeds if feed["qa_status"] == status)
+
+    return {
+        "summary": {
+            "sources": len(feeds),
+            "active_sources": sum(1 for feed in feeds if feed["active_dataset"] is not None),
+            "datasets": sum(len(feed["datasets"]) for feed in feeds),
+            "ready": _with_status("ready"),
+            "needs_review": _with_status("needs_review"),
+            "blocked": _with_status("blocked"),
+        },
+        "feeds": feeds,
+    }
+
+
+def gtfs_harmonization_feed_detail(session: Session, source_id: int) -> dict[str, Any] | None:
+    """Return the inventory item of one GTFS source plus its ``sections``.
+
+    ``None`` is returned when the source does not exist or is not of kind
+    ``gtfs``.
+    """
+    source = session.get(Source, source_id)
+    if source is None or source.kind != "gtfs":
+        return None
+    feed = _feed_inventory_item(session, source)
+    detail = dict(feed)
+    detail["sections"] = _feed_sections(feed)
+    return detail
+
+
+def _gtfs_sources(session: Session) -> list[Source]:
+    """Load all sources of kind ``gtfs`` ordered by country, priority, name and id."""
+    query = (
+        select(Source)
+        .where(Source.kind == "gtfs")
+        .order_by(Source.country, Source.priority, Source.name, Source.id)
+    )
+    return session.scalars(query).all()
+
+
+def _feed_inventory_item(session: Session, source: Source) -> dict[str, Any]:
+    """Assemble the QA inventory item for one GTFS source.
+
+    The source's GTFS datasets are ordered with active ones first, then by
+    creation time and id. Row counts, validation, service horizon, overlap and
+    license checks are computed for the first active dataset, and their issues
+    are combined into the overall ``qa_status``.
+
+    Args:
+        session: SQLAlchemy session used for all queries.
+        source: The GTFS source to describe.
+
+    Returns:
+        A dict with the keys ``source``, ``active_dataset``, ``datasets``,
+        ``counts``, ``validation``, ``service``, ``overlap``, ``license``,
+        ``issues`` and ``qa_status``.
+    """
+    datasets = sorted(
+        [dataset for dataset in source.datasets if dataset.kind == "gtfs"],
+        key=lambda item: (not item.is_active, item.created_at, item.id),
+    )
+    active_dataset = next((dataset for dataset in datasets if dataset.is_active), None)
+    counts = dataset_row_counts(session, active_dataset.id, active_dataset.kind) if active_dataset is not None else {}
+    validation = _validate_gtfs_dataset(session, source, active_dataset, counts)
+    overlap = _overlap_summary(session, active_dataset)
+    service = _service_horizon(session, active_dataset)
+    issues = [*validation["issues"], *service["issues"], *overlap["issues"], *_license_issues(source)]
+    qa_status = _qa_status(issues, active_dataset)
+    # The active dataset's row counts were already computed above; reuse them
+    # instead of running the same counting queries a second time.
+    dataset_payloads = [
+        _dataset_payload(
+            dataset,
+            counts if dataset is active_dataset else dataset_row_counts(session, dataset.id, dataset.kind),
+        )
+        for dataset in datasets
+    ]
+    return {
+        "source": _source_payload(source),
+        "active_dataset": None if active_dataset is None else _dataset_payload(active_dataset, counts),
+        "datasets": dataset_payloads,
+        "counts": counts,
+        "validation": validation,
+        "service": service,
+        "overlap": overlap,
+        "license": _license_payload(source),
+        "issues": issues,
+        "qa_status": qa_status,
+    }
+
+
+def _source_payload(source: Source) -> dict[str, Any]:
+    """Serialize a source for the inventory response.
+
+    Most attributes are copied as-is; ``last_run_at`` goes through ``_iso`` and
+    ``qa_review`` is derived from the source's notes.
+    """
+    payload: dict[str, Any] = {
+        field: getattr(source, field)
+        for field in (
+            "id",
+            "name",
+            "country",
+            "license",
+            "priority",
+            "mode_scope",
+            "source_basis",
+            "status",
+            "enabled",
+            "last_error",
+        )
+    }
+    payload["last_run_at"] = _iso(source.last_run_at)
+    for field in ("url", "catalog_entry_id", "notes"):
+        payload[field] = getattr(source, field)
+    payload["qa_review"] = _qa_review_payload(source.notes)
+    return payload
+
+
+def _dataset_payload(dataset: Dataset, counts: dict[str, Any]) -> dict[str, Any]:
+    """Serialize a dataset together with its row counts.
+
+    Attributes are copied as-is except ``created_at``, which goes through ``_iso``.
+    """
+    payload: dict[str, Any] = {
+        field: getattr(dataset, field)
+        for field in ("id", "kind", "is_active", "status", "sha256", "local_path")
+    }
+    payload["created_at"] = _iso(dataset.created_at)
+    payload["counts"] = counts
+    return payload
+
+
+def _validate_gtfs_dataset(session: Session, source: Source, dataset: Dataset | None, counts: dict[str, Any]) -> dict[str, Any]:
+ """Run the structural QA checks for one GTFS dataset.
+
+ Returns a dict with ``status`` (from ``_qa_status``), ``items`` (metrics for display)
+ and ``issues`` (findings built by ``_issue``). Without a dataset the result is
+ ``blocked`` with a single ``missing_active_dataset`` issue. ``source`` is not used
+ in this function body.
+ """
+ if dataset is None:
+ return {
+ "status": "blocked",
+ "items": [],
+ "issues": [_issue("missing_active_dataset", "bad", "No active GTFS dataset", "Import this source before harmonization.")],
+ }
+ # Table-size metrics: every table is required ("bad" when empty) except shapes ("warn").
+ items = [
+ _metric("Agencies", counts.get("agencies", 0), "bad" if not counts.get("agencies", 0) else "good"),
+ _metric("Stops", counts.get("stops", 0), "bad" if not counts.get("stops", 0) else "good"),
+ _metric("Routes", counts.get("routes", 0), "bad" if not counts.get("routes", 0) else "good"),
+ _metric("Trips", counts.get("trips", 0), "bad" if not counts.get("trips", 0) else "good"),
+ _metric("Stop times", counts.get("stop_times", 0), "bad" if not counts.get("stop_times", 0) else "good"),
+ _metric("Shapes", counts.get("shapes", 0), "warn" if not counts.get("shapes", 0) else "good"),
+ ]
+ # Consistency checks, each computed with its own query.
+ missing_coords = _count(session, GtfsStop, dataset.id, (GtfsStop.lat.is_(None) | GtfsStop.lon.is_(None)))
+ invalid_coords = _count(
+ session,
+ GtfsStop,
+ dataset.id,
+ (GtfsStop.lat < -90) | (GtfsStop.lat > 90) | (GtfsStop.lon < -180) | (GtfsStop.lon > 180),
+ )
+ routes_without_trips = _routes_without_trips(session, dataset.id)
+ # NOTE(review): the two stop_times helpers below query the GtfsStopTime table only; for datasets
+ # whose stop_times are kept in a sidecar (app/gtfs_storage.py) these figures may be wrong — confirm.
+ trips_without_stop_times = _trips_without_stop_times(session, dataset.id)
+ stop_times_without_seconds = _stop_times_without_seconds(session, dataset.id)
+ route_geometry_missing = _count(session, GtfsRoute, dataset.id, GtfsRoute.geometry_geojson.is_(None))
+ canonical_links = _count(session, CanonicalStopLink, dataset.id, CanonicalStopLink.object_type == "gtfs_stop")
+ match_counts = counts.get("match_counts", {}) if isinstance(counts.get("match_counts"), dict) else {}
+
+ items.extend(
+ [
+ _metric("Stops missing coordinates", missing_coords, "bad" if missing_coords else "good"),
+ _metric("Stops with invalid coordinates", invalid_coords, "bad" if invalid_coords else "good"),
+ _metric("Routes without trips", routes_without_trips, "bad" if routes_without_trips else "good"),
+ _metric("Trips without stop_times", trips_without_stop_times, "bad" if trips_without_stop_times else "good"),
+ _metric("Stop times without parsed seconds", stop_times_without_seconds, "warn" if stop_times_without_seconds else "good"),
+ _metric("Routes without geometry", route_geometry_missing, "warn" if route_geometry_missing else "good"),
+ _metric("Canonical stop links", canonical_links, "warn" if counts.get("stops", 0) and canonical_links == 0 else "good"),
+ _metric("Route matches", counts.get("matches", 0), "warn" if counts.get("routes", 0) and not counts.get("matches", 0) else "good"),
+ ]
+ )
+ issues: list[dict[str, str]] = []
+ if counts.get("missing_sidecar"):
+ issues.append(_issue("missing_sidecar", "bad", "GTFS sidecar is missing", "Queue a recovery import for this dataset."))
+ # One "bad" issue per required table that has no rows.
+ for key, label in [
+ ("agencies", "No agencies imported"),
+ ("stops", "No stops imported"),
+ ("routes", "No routes imported"),
+ ("trips", "No trips imported"),
+ ("stop_times", "No stop_times imported"),
+ ]:
+ if not counts.get(key, 0):
+ issues.append(_issue(f"missing_{key}", "bad", label, "Required GTFS content is absent or failed to import."))
+ if missing_coords:
+ issues.append(_issue("missing_stop_coordinates", "bad", f"{missing_coords:,} stops have no coordinates", "Stop coordinates are required for deduplication and routing access."))
+ if invalid_coords:
+ issues.append(_issue("invalid_stop_coordinates", "bad", f"{invalid_coords:,} stops have invalid coordinates", "Fix or exclude invalid stop coordinates before publication."))
+ # NOTE(review): this issue is "warn" while the matching metric above is "bad" — confirm which severity is intended.
+ if routes_without_trips:
+ issues.append(_issue("routes_without_trips", "warn", f"{routes_without_trips:,} routes have no trips", "These routes cannot contribute timetable service."))
+ if trips_without_stop_times:
+ issues.append(_issue("trips_without_stop_times", "bad", f"{trips_without_stop_times:,} trips have no stop_times", "These trips cannot be routed."))
+ if route_geometry_missing:
+ issues.append(_issue("route_geometry_missing", "warn", f"{route_geometry_missing:,} routes have no geometry", "Use GTFS shapes, route-layer matching, or stop-by-stop fallback."))
+ if counts.get("routes", 0) and not counts.get("shapes", 0):
+ issues.append(_issue("missing_shapes", "warn", "No GTFS shapes imported", "OSM route matching or generated geometry will be needed."))
+ # This issue keys off counts["match_counts"], whereas the "Route matches" metric uses counts["matches"].
+ if counts.get("routes", 0) and not match_counts:
+ issues.append(_issue("no_route_matching", "warn", "No route-match rows", "Run route matching before route-layer publication QA."))
+ return {
+ "status": _qa_status(issues, dataset),
+ "items": items,
+ "issues": issues,
+ }
+
+
+def _service_horizon(session: Session, dataset: Dataset | None) -> dict[str, Any]:
+    """Determine the service date range of a dataset and flag short or expired horizons.
+
+    The range spans the minimum/maximum over ``GtfsCalendar`` start/end dates
+    and ``GtfsCalendarDate`` dates. "Today" is the current UTC date. A missing
+    or expired end date is a ``bad`` issue; an end date less than 30 days away
+    is a ``warn`` issue.
+
+    Returns:
+        A dict with ``start_date`` / ``end_date`` (ISO strings or ``None``),
+        ``days_until_end``, display ``items`` and ``issues``. Without a dataset
+        every value is ``None`` or empty.
+    """
+    if dataset is None:
+        return {"start_date": None, "end_date": None, "days_until_end": None, "items": [], "issues": []}
+
+    calendar_query = select(func.min(GtfsCalendar.start_date), func.max(GtfsCalendar.end_date)).where(
+        GtfsCalendar.dataset_id == dataset.id
+    )
+    cal_min, cal_max = session.execute(calendar_query).one()
+    calendar_dates_query = select(func.min(GtfsCalendarDate.date), func.max(GtfsCalendarDate.date)).where(
+        GtfsCalendarDate.dataset_id == dataset.id
+    )
+    date_min, date_max = session.execute(calendar_dates_query).one()
+
+    start_int = _min_int(cal_min, date_min)
+    end_int = _max_int(cal_max, date_max)
+    start_date = _gtfs_date(start_int)
+    end_date = _gtfs_date(end_int)
+    today = datetime.now(timezone.utc).date()
+    days_until_end = None if end_date is None else (end_date - today).days
+
+    # The issue and the tone of the "Service ends" metric follow the same three cases.
+    issues: list[dict[str, str]] = []
+    end_tone = "good"
+    if end_date is None:
+        end_tone = "bad"
+        issues.append(_issue("service_horizon_missing", "bad", "No service calendar horizon", "calendar.txt or calendar_dates.txt is required for reliable routing."))
+    elif days_until_end is not None and days_until_end < 0:
+        end_tone = "bad"
+        issues.append(_issue("service_horizon_expired", "bad", f"Service expired {abs(days_until_end):,} days ago", "Update or exclude this feed."))
+    elif days_until_end is not None and days_until_end < 30:
+        end_tone = "warn"
+        issues.append(_issue("service_horizon_short", "warn", f"Service ends in {days_until_end:,} days", "Update cadence is too close for publication confidence."))
+
+    items = [
+        _metric("Service starts", start_date.isoformat() if start_date else "n/a", "info"),
+        _metric("Service ends", end_date.isoformat() if end_date else "n/a", end_tone),
+    ]
+    return {
+        "start_date": None if start_date is None else start_date.isoformat(),
+        "end_date": None if end_date is None else end_date.isoformat(),
+        "days_until_end": days_until_end,
+        "items": items,
+        "issues": issues,
+    }
+
+
+def _overlap_summary(session: Session, dataset: Dataset | None) -> dict[str, Any]:
+    """Summarize how much a dataset overlaps with other active GTFS feeds.
+
+    Counts shared route keys and shared canonical stops; each non-zero count
+    produces a ``warn`` issue. Without a dataset both lists are empty.
+    """
+    if dataset is None:
+        return {"items": [], "issues": []}
+    shared_routes = _shared_route_keys(session, dataset.id)
+    shared_stops = _shared_canonical_stops(session, dataset.id)
+    issues: list[dict[str, str]] = []
+    if shared_routes:
+        issues.append(_issue("shared_route_keys", "warn", f"{shared_routes:,} route keys also exist in another active feed", "Deduplicate or rank source authority for overlapping routes."))
+    if shared_stops:
+        issues.append(_issue("shared_canonical_stops", "warn", f"{shared_stops:,} canonical stops are shared with another active feed", "This is useful linking evidence, but conflicts need review."))
+    items = [
+        _metric("Shared route keys", shared_routes, "warn" if shared_routes else "good"),
+        _metric("Shared canonical stops", shared_stops, "warn" if shared_stops else "good"),
+    ]
+    return {"items": items, "issues": issues}
+
+
+def _license_payload(source: Source) -> dict[str, Any]:
+    """Describe the source's license for display.
+
+    An empty license, or one containing ``unknown`` (case-insensitive), is
+    reported with status ``unknown`` and tone ``warn``; any other text gets
+    status ``review_required`` and tone ``info``.
+    """
+    label = (source.license or "").strip()
+    is_unknown = (not label) or "unknown" in label.lower()
+    if is_unknown:
+        status, tone = "unknown", "warn"
+    else:
+        status, tone = "review_required", "info"
+    return {
+        "label": label or "unknown",
+        "redistribution_status": status,
+        "tone": tone,
+    }
+
+
+def _license_issues(source: Source) -> list[dict[str, str]]:
+    """Return a single ``license_unknown`` warning when the license status is unknown, else ``[]``."""
+    status = _license_payload(source)["redistribution_status"]
+    if status != "unknown":
+        return []
+    return [_issue("license_unknown", "warn", "License/redistribution status is unknown", "Publication needs explicit import, derivation, redistribution, and attribution flags.")]
+
+
+def _qa_review_payload(notes: str | None) -> dict[str, Any]:
+    """Extract the QA review fields from a source's notes.
+
+    The first line that starts with ``GTFS_QA_NOTE_PREFIX`` is parsed as
+    ``;``-separated ``key=value`` pairs; later marked lines are ignored.
+    Missing notes, no marked line, or missing fields fall back to status
+    ``unreviewed``, an empty note and ``updated_at`` of ``None``.
+    """
+    if not notes:
+        return {"status": "unreviewed", "note": "", "updated_at": None}
+    qa_line = next(
+        (line for line in str(notes).splitlines() if line.startswith(GTFS_QA_NOTE_PREFIX)),
+        None,
+    )
+    if qa_line is None:
+        return {"status": "unreviewed", "note": "", "updated_at": None}
+    fields: dict[str, str] = {}
+    for part in qa_line[len(GTFS_QA_NOTE_PREFIX) :].strip().split(";"):
+        key, separator, value = part.partition("=")
+        if not separator:
+            # Parts without "=" carry no field.
+            continue
+        fields[key.strip()] = value.strip()
+    return {
+        "status": fields.get("status") or "unreviewed",
+        "note": fields.get("note") or "",
+        "updated_at": fields.get("updated_at"),
+    }
+
+
+ def _routes_without_trips(session: Session, dataset_id: int) -> int:
+     """Count routes of the dataset that have no trip with the same ``route_id`` in that dataset."""
+     has_trip = (
+         select(GtfsTrip.id)
+         .where(GtfsTrip.dataset_id == dataset_id, GtfsTrip.route_id == GtfsRoute.route_id)
+         .exists()
+     )
+     stmt = select(func.count()).select_from(GtfsRoute).where(GtfsRoute.dataset_id == dataset_id, ~has_trip)
+     total = session.scalar(stmt)
+     return int(total or 0)
+
+
+ def _trips_without_stop_times(session: Session, dataset_id: int) -> int:
+     """Count trips of the dataset that have no stop time with the same ``trip_id`` in that dataset."""
+     has_stop_time = (
+         select(GtfsStopTime.id)
+         .where(GtfsStopTime.dataset_id == dataset_id, GtfsStopTime.trip_id == GtfsTrip.trip_id)
+         .exists()
+     )
+     stmt = select(func.count()).select_from(GtfsTrip).where(GtfsTrip.dataset_id == dataset_id, ~has_stop_time)
+     total = session.scalar(stmt)
+     return int(total or 0)
+
+
+ def _stop_times_without_seconds(session: Session, dataset_id: int) -> int:
+     """Count stop times of the dataset where both ``arrival_seconds`` and ``departure_seconds`` are NULL."""
+     stmt = select(func.count()).select_from(GtfsStopTime).where(
+         GtfsStopTime.dataset_id == dataset_id,
+         GtfsStopTime.arrival_seconds.is_(None),
+         GtfsStopTime.departure_seconds.is_(None),
+     )
+     total = session.scalar(stmt)
+     return int(total or 0)
+
+
+ def _shared_route_keys(session: Session, dataset_id: int) -> int:
+     """Count distinct non-empty route keys of this dataset that also occur in another active GTFS dataset."""
+     mine = aliased(GtfsRoute)
+     theirs = aliased(GtfsRoute)
+     their_dataset = aliased(Dataset)
+     # Same route key, but owned by a different dataset.
+     same_key_elsewhere = and_(theirs.route_key == mine.route_key, theirs.dataset_id != mine.dataset_id)
+     stmt = (
+         select(func.count(func.distinct(mine.route_key)))
+         .select_from(mine)
+         .join(theirs, same_key_elsewhere)
+         .join(their_dataset, their_dataset.id == theirs.dataset_id)
+         .where(
+             mine.dataset_id == dataset_id,
+             mine.route_key.is_not(None),
+             mine.route_key != "",
+             their_dataset.kind == "gtfs",
+             their_dataset.is_active.is_(True),
+         )
+     )
+     total = session.scalar(stmt)
+     return int(total or 0)
+
+
+ def _shared_canonical_stops(session: Session, dataset_id: int) -> int:
+     """Count distinct canonical stops linked from this dataset's GTFS stops and from another active GTFS dataset's."""
+     mine = aliased(CanonicalStopLink)
+     theirs = aliased(CanonicalStopLink)
+     their_dataset = aliased(Dataset)
+     # Same canonical stop, but linked from a different dataset.
+     same_stop_elsewhere = and_(
+         theirs.canonical_stop_id == mine.canonical_stop_id,
+         theirs.dataset_id != mine.dataset_id,
+     )
+     stmt = (
+         select(func.count(func.distinct(mine.canonical_stop_id)))
+         .select_from(mine)
+         .join(theirs, same_stop_elsewhere)
+         .join(their_dataset, their_dataset.id == theirs.dataset_id)
+         .where(
+             mine.dataset_id == dataset_id,
+             mine.object_type == "gtfs_stop",
+             theirs.object_type == "gtfs_stop",
+             their_dataset.kind == "gtfs",
+             their_dataset.is_active.is_(True),
+         )
+     )
+     total = session.scalar(stmt)
+     return int(total or 0)
+
+
+ def _count(session: Session, model: Any, dataset_id: int, *criteria: Any) -> int:
+     """Count rows of ``model`` for the dataset, optionally narrowed by extra filter criteria."""
+     filters = (model.dataset_id == dataset_id, *criteria)
+     total = session.scalar(select(func.count()).select_from(model).where(*filters))
+     return int(total or 0)
+
+
+ def _metric(label: str, value: Any, tone: str = "info", description: str = "") -> dict[str, Any]:
+     """Build one metric entry with ``label``, ``value``, ``tone`` and ``description`` keys."""
+     return dict(label=label, value=value, tone=tone, description=description)
+
+
+ def _issue(issue_id: str, severity: str, title: str, detail: str) -> dict[str, str]:
+     """Build one issue entry with ``id``, ``severity``, ``title`` and ``detail`` keys."""
+     return dict(id=issue_id, severity=severity, title=title, detail=detail)
+
+
+ def _qa_status(issues: list[dict[str, str]], dataset: Dataset | None) -> str:
+     """Derive the overall QA status from the issues.
+
+     ``blocked`` when there is no dataset or any issue has severity ``bad``;
+     ``needs_review`` when any issue has severity ``warn``; otherwise ``ready``.
+     """
+     if dataset is None:
+         return "blocked"
+     saw_warning = False
+     for entry in issues:
+         level = entry.get("severity")
+         if level == "bad":
+             return "blocked"
+         if level == "warn":
+             saw_warning = True
+     return "needs_review" if saw_warning else "ready"
+
+
+ def _feed_sections(feed: dict[str, Any]) -> list[dict[str, Any]]:
+     """Arrange a feed payload into the validation, service, overlap and license sections."""
+     titled_keys = (
+         ("validation", "GTFS Validation"),
+         ("service", "Service Horizon"),
+         ("overlap", "Overlap and Deduplication"),
+     )
+     sections = [{"id": key, "title": title, "items": feed[key]["items"]} for key, title in titled_keys]
+     # The license section has no stored items; its two metrics are built from the license payload.
+     license_info = feed["license"]
+     license_items = [
+         _metric("Redistribution", license_info["redistribution_status"], license_info["tone"]),
+         _metric("License", license_info["label"], license_info["tone"]),
+     ]
+     sections.append({"id": "license", "title": "License", "items": license_items})
+     return sections
+
+
+ def _gtfs_date(value: int | None) -> date | None:
+     """Convert a GTFS ``YYYYMMDD`` integer into a ``date``.
+
+     Returns None when the value is None, cannot be read as an integer, is not
+     exactly eight digits, or does not name a real calendar date.
+     """
+     if value is None:
+         return None
+     try:
+         digits = str(int(value))
+     except (TypeError, ValueError, OverflowError):
+         return None
+     # strptime accepts unpadded month/day fields, so "2024011" would parse as
+     # 2024-01-01. Require the full eight-digit form before parsing.
+     if len(digits) != 8 or not digits.isdigit():
+         return None
+     try:
+         return datetime.strptime(digits, "%Y%m%d").date()
+     except ValueError:
+         return None
+
+
+ def _min_int(*values: int | None) -> int | None:
+     """Return the smallest of the non-None values as an int, or None when there are none."""
+     return min((int(candidate) for candidate in values if candidate is not None), default=None)
+
+
+ def _max_int(*values: int | None) -> int | None:
+     """Return the largest of the non-None values as an int, or None when there are none."""
+     return max((int(candidate) for candidate in values if candidate is not None), default=None)
+
+
+ def _iso(value: datetime | None) -> str | None:
+     """Return the ISO 8601 string for ``value``, passing None through unchanged."""
+     if value is None:
+         return None
+     return value.isoformat()
diff --git a/app/itineraries.py b/app/itineraries.py
new file mode 100644
index 0000000..803eba0
--- /dev/null
+++ b/app/itineraries.py
@@ -0,0 +1,360 @@
+from __future__ import annotations
+
+import json
+from datetime import datetime, timezone
+from typing import Any
+
+from sqlalchemy import select
+from sqlalchemy.orm import Session
+
+from app.journey import duration_minutes_ceil, find_journeys, format_duration_label
+from app.models import Itinerary, ItineraryLeg, TravelRequest
+from app.routing import route_between_points
+
+
+ def generate_itineraries(
+     db: Session,
+     *,
+     from_stop_id: str,
+     to_stop_id: str,
+     via_stop_id: str | None,
+     departure: str,
+     service_date: str | None,
+     max_transfers: int,
+     transfer_seconds: int,
+     limit: int,
+     source_ids: list[int] | None,
+     preferences: dict[str, Any] | None = None,
+ ) -> dict:
+     """Store a travel request and build the candidate itineraries for it.
+
+     Saves a ``TravelRequest``, runs ``find_journeys`` and stores one public-transport
+     itinerary (with its legs) per returned journey. It then adds a car itinerary when
+     ``_car_itinerary`` produces one, followed by the placeholder itineraries.
+     Negative ``max_transfers`` / ``transfer_seconds`` are clamped to 0.
+
+     The session is only flushed here; no commit is issued in this function.
+
+     Returns a dict with ``request``, ``journey_context`` and ``itineraries`` payloads.
+     """
+     request = TravelRequest(
+         origin_stop_id=from_stop_id,
+         destination_stop_id=to_stop_id,
+         via_stop_id=via_stop_id or None,
+         departure_time=departure,
+         service_date=service_date or None,
+         max_transfers=max(0, max_transfers),
+         transfer_seconds=max(0, transfer_seconds),
+         # Source ids are stored as a comma-separated string; an empty or missing list becomes None.
+         source_filter=",".join(str(source_id) for source_id in source_ids or []) or None,
+         preferences_json=json.dumps(preferences or {}, separators=(",", ":")),
+     )
+     db.add(request)
+     # Flushed before request.id is read below.
+     db.flush()
+
+     journey_result = find_journeys(
+         db=db,
+         from_stop_id=from_stop_id,
+         to_stop_id=to_stop_id,
+         via_stop_id=via_stop_id,
+         departure=departure,
+         service_date=service_date,
+         max_transfers=max(0, max_transfers),
+         transfer_seconds=max(0, transfer_seconds),
+         limit=limit,
+         source_ids=source_ids,
+     )
+     itineraries: list[Itinerary] = []
+     # One stored itinerary per journey, titled with its 1-based position.
+     for index, journey in enumerate(journey_result.get("journeys", []), start=1):
+         itinerary = _journey_itinerary(request.id, journey, index)
+         db.add(itinerary)
+         # Flushed before itinerary.id is handed to the leg rows.
+         db.flush()
+         _add_journey_legs(db, itinerary.id, journey)
+         itineraries.append(itinerary)
+
+     # The car comparison is optional: _car_itinerary returns None when it cannot be built.
+     car_itinerary = _car_itinerary(db, request.id, journey_result.get("from"), journey_result.get("to"))
+     if car_itinerary is not None:
+         db.add(car_itinerary)
+         db.flush()
+         _add_routing_leg(db, car_itinerary.id, car_itinerary)
+         itineraries.append(car_itinerary)
+
+     placeholders = _placeholder_itineraries(
+         request.id,
+         journey_result.get("from"),
+         journey_result.get("to"),
+         service_date=service_date,
+         # Only add the "Car only" placeholder when no routed car itinerary was produced.
+         include_car=car_itinerary is None,
+     )
+     for itinerary in placeholders:
+         db.add(itinerary)
+         db.flush()
+         itineraries.append(itinerary)
+
+     db.flush()
+     return {
+         "request": travel_request_payload(request),
+         "journey_context": {
+             "from": journey_result.get("from"),
+             "to": journey_result.get("to"),
+             "via": journey_result.get("via"),
+             "sources": journey_result.get("sources", []),
+         },
+         "itineraries": [itinerary_payload(db, itinerary) for itinerary in itineraries],
+     }
+
+
+ def travel_request_payload(request: TravelRequest) -> dict[str, Any]:
+     """Serialise a stored travel request into a plain dict for API responses."""
+     copied_fields = (
+         "id",
+         "origin_stop_id",
+         "destination_stop_id",
+         "via_stop_id",
+         "departure_time",
+         "service_date",
+         "max_transfers",
+         "transfer_seconds",
+         "source_filter",
+     )
+     payload: dict[str, Any] = {name: getattr(request, name) for name in copied_fields}
+     payload["preferences"] = _json_dict(request.preferences_json)
+     created = request.created_at
+     payload["created_at"] = created.isoformat() if created else None
+     return payload
+
+
+ def itinerary_payload(db: Session, itinerary: Itinerary) -> dict[str, Any]:
+     """Serialise an itinerary and its legs (loaded here, ordered by ``sequence``) into a plain dict."""
+     leg_query = (
+         select(ItineraryLeg)
+         .where(ItineraryLeg.itinerary_id == itinerary.id)
+         .order_by(ItineraryLeg.sequence)
+     )
+     ordered_legs = db.scalars(leg_query).all()
+     created = itinerary.created_at
+     updated = itinerary.updated_at
+     payload: dict[str, Any] = {
+         name: getattr(itinerary, name)
+         for name in ("id", "request_id", "title", "family", "status", "saved")
+     }
+     payload["summary"] = _json_dict(itinerary.summary_json)
+     payload["score"] = _json_dict(itinerary.score_json)
+     payload["payload"] = _json_dict(itinerary.payload_json)
+     payload["legs"] = [itinerary_leg_payload(leg) for leg in ordered_legs]
+     payload["created_at"] = created.isoformat() if created else None
+     payload["updated_at"] = updated.isoformat() if updated else None
+     return payload
+
+
+ def itinerary_leg_payload(leg: ItineraryLeg) -> dict[str, Any]:
+     """Serialise one itinerary leg into a plain dict, decoding its JSON payload."""
+     copied_fields = (
+         "id",
+         "itinerary_id",
+         "sequence",
+         "mode",
+         "route_ref",
+         "route_name",
+         "from_name",
+         "to_name",
+         "departure_time",
+         "arrival_time",
+         "locked",
+     )
+     payload: dict[str, Any] = {name: getattr(leg, name) for name in copied_fields}
+     payload["payload"] = _json_dict(leg.payload_json)
+     return payload
+
+
+ def set_itinerary_saved(db: Session, itinerary: Itinerary, saved: bool) -> dict[str, Any]:
+     """Set the saved flag and matching status, stamp ``updated_at``, flush, and return the payload."""
+     new_status = "saved" if saved else "candidate"
+     itinerary.saved = saved
+     itinerary.status = new_status
+     itinerary.updated_at = datetime.now(timezone.utc)
+     db.flush()
+     return itinerary_payload(db, itinerary)
+
+
+ def set_leg_locked(db: Session, leg: ItineraryLeg, locked: bool) -> dict[str, Any]:
+     """Set the leg's lock flag, stamp the parent itinerary's ``updated_at`` if it exists, and return the leg payload."""
+     leg.locked = locked
+     parent = db.get(Itinerary, leg.itinerary_id)
+     if parent is not None:
+         parent.updated_at = datetime.now(timezone.utc)
+     db.flush()
+     serialised = itinerary_leg_payload(leg)
+     return serialised
+
+
+ def recent_itineraries(db: Session, *, saved_only: bool = False, limit: int = 30) -> list[dict[str, Any]]:
+     """List itinerary payloads, most recently updated first.
+
+     ``limit`` is clamped to the range 1..100; ``saved_only`` restricts the list to saved itineraries.
+     """
+     query = select(Itinerary).order_by(Itinerary.updated_at.desc(), Itinerary.id.desc())
+     if saved_only:
+         query = query.where(Itinerary.saved.is_(True))
+     row_cap = max(1, min(limit, 100))
+     found = db.scalars(query.limit(row_cap)).all()
+     return [itinerary_payload(db, row) for row in found]
+
+
+ def _journey_itinerary(request_id: int, journey: dict, index: int) -> Itinerary:
+     """Build an unsaved public-transport candidate itinerary from one journey result.
+
+     ``index`` is only used for the title ("Public transport option N").
+     """
+     score = _journey_score(journey)
+     legs = journey.get("legs", [])
+     summary = {
+         key: journey.get(key)
+         for key in ("departure_time", "arrival_time", "duration_minutes", "duration_label", "transfers")
+     }
+     summary["leg_count"] = len(legs)
+     # Each leg is identified by route_ref, falling back to route_id.
+     summary["route_refs"] = [leg.get("route_ref") or leg.get("route_id") for leg in legs]
+     compact = {"separators": (",", ":")}
+     return Itinerary(
+         request_id=request_id,
+         title=f"Public transport option {index}",
+         family="public_transport",
+         status="candidate",
+         saved=False,
+         summary_json=json.dumps(summary, **compact),
+         score_json=json.dumps(score, **compact),
+         payload_json=json.dumps({"journey": journey}, **compact),
+     )
+
+
+ def _add_journey_legs(db: Session, itinerary_id: int, journey: dict) -> None:
+     """Add one unlocked ``ItineraryLeg`` per journey leg, numbered from 1 in journey order."""
+     for position, leg in enumerate(journey.get("legs", []), start=1):
+         origin = leg.get("from") or {}
+         destination = leg.get("to") or {}
+         row = ItineraryLeg(
+             itinerary_id=itinerary_id,
+             sequence=position,
+             mode=leg.get("mode"),
+             route_ref=leg.get("route_ref"),
+             route_name=leg.get("route_name"),
+             # Stop names fall back to the stop id when no name is present.
+             from_name=origin.get("name") or origin.get("stop_id"),
+             to_name=destination.get("name") or destination.get("stop_id"),
+             departure_time=leg.get("departure_time"),
+             arrival_time=leg.get("arrival_time"),
+             locked=False,
+             payload_json=json.dumps({"journey_leg": leg}, separators=(",", ":")),
+         )
+         db.add(row)
+
+
+ def _car_itinerary(db: Session, request_id: int, from_stop: dict | None, to_stop: dict | None) -> Itinerary | None:
+     """Build a "Car only" candidate itinerary by routing between the two stops' coordinates.
+
+     Returns None when any coordinate is missing or unparsable, or when
+     ``route_between_points`` raises; the car comparison is optional.
+     """
+     origin = from_stop or {}
+     destination = to_stop or {}
+     from_lon = _float_or_none(origin.get("lon"))
+     from_lat = _float_or_none(origin.get("lat"))
+     to_lon = _float_or_none(destination.get("lon"))
+     to_lat = _float_or_none(destination.get("lat"))
+     if any(coordinate is None for coordinate in (from_lon, from_lat, to_lon, to_lat)):
+         return None
+     try:
+         route = route_between_points(
+             db,
+             from_lon=from_lon,
+             from_lat=from_lat,
+             to_lon=to_lon,
+             to_lat=to_lat,
+             mode="drive",
+             max_visited=300_000,
+         )
+     except Exception:  # noqa: BLE001 - car comparison is optional
+         return None
+     duration_seconds = _float_or_none(route.get("duration_seconds"))
+     duration_minutes = duration_minutes_ceil(duration_seconds)
+     distance_m = _float_or_none(route.get("distance_m"))
+     if distance_m is None:
+         distance_km = None
+     else:
+         distance_km = round(distance_m / 1000, 1)
+     summary = {
+         "from": origin.get("name") or origin.get("stop_id") or "origin",
+         "to": destination.get("name") or destination.get("stop_id") or "destination",
+         "duration_minutes": duration_minutes,
+         "duration_label": format_duration_label(duration_seconds),
+         "distance_km": distance_km,
+         "transfers": 0,
+         "engine": route.get("engine"),
+     }
+     score = {
+         "duration_minutes": duration_minutes,
+         "transfers": 0,
+         "complexity": 1,
+         "emissions": "high",
+         "estimated_cost": None,
+     }
+     compact = {"separators": (",", ":")}
+     return Itinerary(
+         request_id=request_id,
+         title="Car only",
+         family="car",
+         status="candidate",
+         saved=False,
+         summary_json=json.dumps(summary, **compact),
+         score_json=json.dumps(score, **compact),
+         payload_json=json.dumps({"routing": route}, **compact),
+     )
+
+
+ def _add_routing_leg(db: Session, itinerary_id: int, itinerary: Itinerary) -> None:
+     """Add the single road-route leg for an itinerary whose payload holds a ``routing`` dict.
+
+     Does nothing when the payload has no dict under ``routing``.
+     """
+     decoded = _json_dict(itinerary.payload_json)
+     route = decoded.get("routing") if isinstance(decoded, dict) else None
+     if not isinstance(route, dict):
+         return
+     start_node = route.get("start_node") or {}
+     target_node = route.get("target_node") or {}
+     leg = ItineraryLeg(
+         itinerary_id=itinerary_id,
+         sequence=1,
+         mode=str(route.get("mode") or "drive"),
+         route_ref=None,
+         route_name="Road route",
+         # Endpoints are labelled with the routed OSM node ids when present.
+         from_name=str(start_node.get("osm_node_id") or "origin"),
+         to_name=str(target_node.get("osm_node_id") or "destination"),
+         departure_time=None,
+         arrival_time=None,
+         locked=False,
+         payload_json=json.dumps({"routing_leg": route}, separators=(",", ":")),
+     )
+     db.add(leg)
+
+
+ def _placeholder_itineraries(
+     request_id: int,
+     from_stop: dict | None,
+     to_stop: dict | None,
+     *,
+     service_date: str | None,
+     include_car: bool = True,
+ ) -> list[Itinerary]:
+     """Build the unsaved ``placeholder`` itineraries for travel families that have no real result yet.
+
+     With ``include_car`` a "Car only" placeholder is listed first, ahead of the
+     ferry, flight and rail placeholders.
+     """
+     origin = from_stop or {}
+     destination = to_stop or {}
+     from_name = origin.get("name") or origin.get("stop_id") or "origin"
+     to_name = destination.get("name") or destination.get("stop_id") or "destination"
+     # (family, title, note, score) for each placeholder.
+     families = []
+     if include_car:
+         families.append(("car", "Car only", "Needs road-routing connector", {"complexity": 1, "emissions": "high"}))
+     families.extend(
+         [
+             ("car_ferry", "Car + ferry", "Needs ferry-port candidate graph", {"complexity": 3, "emissions": "medium_high"}),
+             ("flight_access", "Flight + airport access", "Needs airport/flight schedule connector", {"complexity": 4, "emissions": "high"}),
+             ("rail_long_stay", "Rail with adjustable city stop", "Use via stop and leg locking to refine", {"complexity": 3, "emissions": "low"}),
+         ]
+     )
+
+     def _compact(data: dict) -> str:
+         return json.dumps(data, separators=(",", ":"))
+
+     return [
+         Itinerary(
+             request_id=request_id,
+             title=title,
+             family=family,
+             status="placeholder",
+             saved=False,
+             summary_json=_compact(
+                 {
+                     "from": from_name,
+                     "to": to_name,
+                     "service_date": service_date,
+                     "note": note,
+                     "duration_minutes": None,
+                     "transfers": None,
+                 }
+             ),
+             score_json=_compact(score),
+             payload_json=_compact({"placeholder": True, "note": note}),
+         )
+         for family, title, note, score in families
+     ]
+
+
+ def _float_or_none(value: object) -> float | None:
+     """Convert ``value`` to a finite float, or return None when that is not possible.
+
+     None, unparsable input, integers too large for a float, and NaN / infinity
+     all yield None.
+     """
+     if value is None:
+         return None
+     try:
+         number = float(value)
+     except (TypeError, ValueError, OverflowError):
+         # OverflowError: float() raises it for integers beyond the float range.
+         return None
+     # NaN is the only float unequal to itself. NaN and the infinities are not
+     # usable coordinates or durations, and json.dumps would emit them as non-standard JSON.
+     if number != number or number in (float("inf"), float("-inf")):
+         return None
+     return number
+
+
+ def _journey_score(journey: dict) -> dict[str, Any]:
+     """Score a journey from its leg modes, duration and transfer count.
+
+     The emissions hint is ``low`` when rail-like legs are at least as many as
+     bus-like legs, otherwise ``medium``. Complexity is transfers plus the number of legs.
+     """
+     rail_modes = {"train", "subway", "tram", "light_rail"}
+     bus_modes = {"bus", "coach", "trolleybus"}
+     leg_modes = [leg.get("mode") for leg in journey.get("legs", [])]
+     transfer_count = int(journey.get("transfers") or 0)
+     rail_legs = 0
+     bus_legs = 0
+     for leg_mode in leg_modes:
+         if leg_mode in rail_modes:
+             rail_legs += 1
+         if leg_mode in bus_modes:
+             bus_legs += 1
+     if rail_legs >= bus_legs:
+         emissions = "low"
+     else:
+         emissions = "medium"
+     return {
+         "duration_minutes": journey.get("duration_minutes"),
+         "transfers": transfer_count,
+         "complexity": transfer_count + len(leg_modes),
+         "emissions": emissions,
+         "overnight": False,
+         "estimated_cost": None,
+     }
+
+
+ def _json_dict(value: str | None) -> dict[str, Any]:
+     """Decode a JSON object string; empty input, invalid JSON and non-object JSON all give ``{}``."""
+     raw = value if value else "{}"
+     try:
+         decoded = json.loads(raw)
+     except json.JSONDecodeError:
+         return {}
+     if isinstance(decoded, dict):
+         return decoded
+     return {}
diff --git a/app/jobs.py b/app/jobs.py
new file mode 100644
index 0000000..09cb1bb
--- /dev/null
+++ b/app/jobs.py
@@ -0,0 +1,1932 @@
+from __future__ import annotations
+
+import json
+import os
+import threading
+import time
+from contextlib import contextmanager
+from datetime import datetime, timedelta, timezone
+from typing import Any, Iterator
+from uuid import uuid4
+
+from sqlalchemy import func, select, text
+from sqlalchemy.orm import Session
+
+from app.config import settings
+from app.data_management import (
+ delete_dataset,
+ delete_source,
+ prune_inactive_datasets,
+ prune_unreferenced_cache_files,
+ unreferenced_cache_file_summary,
+)
+from app.db import SessionLocal, engine, init_db
+from app.db_lock import DatabaseWriteBusy, database_write_lock
+from app.gtfs_storage import missing_sidecar_paths as gtfs_missing_sidecar_paths
+from app.models import Dataset, Job, JobEvent, Source, SourceUpdateCheck
+from app.osm_storage import missing_sidecar_paths as osm_missing_sidecar_paths
+from app.pipeline.gtfs import backfill_gtfs_shapes
+from app.pipeline.matcher import run_route_matching
+from app.pipeline.osm_addresses import rebuild_address_index
+from app.pipeline.osm_labeling import relabel_osm_features
+from app.pipeline.osm_pbf import run_osm_pbf_source_staged
+from app.pipeline.route_layer import rebuild_route_layer
+from app.pipeline.run import run_source
+from app.pipeline.sample_data import clear_project_data, load_sample_project
+from app.source_catalog import import_ingestable_sources, import_source_catalog, source_catalog_summary
+
+
+ # Values stored in Job.kind, one per queued workflow.
+ ROUTE_MATCHING_JOB_KIND = "route_matching"
+ ROUTE_LAYER_JOB_KIND = "route_layer_rebuild"
+ ADDRESS_INDEX_JOB_KIND = "address_index_rebuild"
+ OSM_RELABEL_JOB_KIND = "osm_relabel"
+ SOURCE_IMPORT_JOB_KIND = "source_import"
+ SOURCE_DELETE_JOB_KIND = "source_delete"
+ DATASET_DELETE_JOB_KIND = "dataset_delete"
+ MAINTENANCE_JOB_KIND = "maintenance"
+ # Job.status values grouped by lifecycle phase; the active set is what the
+ # "is there already a job for this?" lookups in this module filter on.
+ TERMINAL_JOB_STATUSES = {"completed", "failed", "cancelled"}
+ ACTIVE_JOB_STATUSES = {"queued", "running", "paused"}
+ # Lease length comes from settings.queue_job_lease_seconds, with a floor of 300 seconds.
+ LEASE_SECONDS = max(300, int(settings.queue_job_lease_seconds))
+ # NOTE(review): presumably the interval at which a running worker refreshes its
+ # lease - confirm against the worker loop, which is not visible here.
+ HEARTBEAT_INTERVAL_SECONDS = 60
+
+
+ class JobPaused(Exception):
+     """Exception type for a paused job (presumably raised to unwind a running job on pause - confirm at the raise sites)."""
+
+
+ class JobCancelled(Exception):
+     """Exception type for a cancelled job (presumably raised to unwind a running job on cancel - confirm at the raise sites)."""
+
+
+ def create_route_layer_rebuild_job(session: Session, *, priority: int = 0) -> Job:
+     """Queue a route-layer rebuild job (4 progress steps) and record its ``queued`` event.
+
+     The job is added and flushed; no commit is issued here.
+     """
+     queued_job = Job(
+         kind=ROUTE_LAYER_JOB_KIND,
+         status="queued",
+         description="Rebuild visual route layer from active GTFS and OSM datasets",
+         progress_current=0,
+         progress_total=4,
+         priority=int(priority),
+     )
+     session.add(queued_job)
+     session.flush()
+     event_fields = {
+         "event_type": "queued",
+         "message": "Route-layer rebuild queued.",
+         "progress_current": 0,
+         "progress_total": 4,
+     }
+     add_job_event(session, queued_job, **event_fields)
+     return queued_job
+
+
+ def create_address_index_rebuild_job(session: Session, *, priority: int = 0) -> Job:
+     """Queue an OSM address-index rebuild job (4 progress steps) and record its ``queued`` event.
+
+     The job is added and flushed; no commit is issued here.
+     """
+     queued_job = Job(
+         kind=ADDRESS_INDEX_JOB_KIND,
+         status="queued",
+         description="Rebuild OSM address index from the active OSM PBF dataset",
+         progress_current=0,
+         progress_total=4,
+         priority=int(priority),
+     )
+     session.add(queued_job)
+     session.flush()
+     event_fields = {
+         "event_type": "queued",
+         "message": "OSM address index rebuild queued.",
+         "progress_current": 0,
+         "progress_total": 4,
+     }
+     add_job_event(session, queued_job, **event_fields)
+     return queued_job
+
+
+ def create_route_matching_job(session: Session, *, priority: int = 0) -> Job:
+     """Queue a route-matching job and record its ``queued`` event.
+
+     The progress total starts at 0. The job is added and flushed; no commit is issued here.
+     """
+     queued_job = Job(
+         kind=ROUTE_MATCHING_JOB_KIND,
+         status="queued",
+         description="Match active GTFS routes against active OSM route features",
+         progress_current=0,
+         progress_total=0,
+         priority=int(priority),
+     )
+     session.add(queued_job)
+     session.flush()
+     event_fields = {
+         "event_type": "queued",
+         "message": "Route matching queued.",
+         "progress_current": 0,
+         "progress_total": 0,
+     }
+     add_job_event(session, queued_job, **event_fields)
+     return queued_job
+
+
+ def create_osm_relabel_job(
+     session: Session,
+     *,
+     dataset_id: int | None = None,
+     build_route_layer: bool = True,
+     force: bool = False,
+     priority: int = 0,
+ ) -> Job:
+     """Queue an OSM relabel job and record its ``queued`` event.
+
+     The job options (``dataset_id``, ``build_route_layer``, ``force``) are stored as
+     JSON in ``result_json``. The progress total is 2 with a route-layer rebuild, else 1.
+     """
+     if dataset_id is None:
+         description = "Relabel active OSM features"
+     else:
+         description = f"Relabel OSM features for dataset {dataset_id}"
+     options = {
+         "dataset_id": dataset_id,
+         "build_route_layer": build_route_layer,
+         "force": force,
+     }
+     queued_job = Job(
+         kind=OSM_RELABEL_JOB_KIND,
+         status="queued",
+         description=description,
+         progress_current=0,
+         progress_total=2 if build_route_layer else 1,
+         priority=int(priority),
+         result_json=json.dumps(options, separators=(",", ":")),
+     )
+     session.add(queued_job)
+     session.flush()
+     add_job_event(
+         session,
+         queued_job,
+         event_type="queued",
+         message="OSM relabeling queued.",
+         progress_current=0,
+         progress_total=queued_job.progress_total,
+     )
+     return queued_job
+
+
+ def start_route_layer_rebuild_worker(job_id: int) -> None:
+     """Accept a job id and do nothing with it; the body is a no-op.
+
+     NOTE(review): presumably kept so existing callers keep working while queued
+     jobs are picked up by separate worker processes - confirm.
+     """
+     del job_id
+
+
+ def create_source_import_job(
+     session: Session,
+     source: Source,
+     *,
+     run_match: bool = True,
+     build_route_layer: bool = True,
+     priority: int = 0,
+     recovery_reason: str | None = None,
+ ) -> Job:
+     """Queue an import job for ``source``, or return the import job already active for it.
+
+     A truthy ``recovery_reason`` switches the description and event message to the
+     recovery wording and is stored in the job options. The source is marked
+     ``queued`` and its last error cleared.
+     """
+     existing = active_source_import_job(session, source.id)
+     if existing is not None:
+         return existing
+     if recovery_reason:
+         description = f"Recover source {source.id}: {source.name}"
+     else:
+         description = f"Import source {source.id}: {source.name}"
+     options = {
+         "source_id": source.id,
+         "source_name": source.name,
+         "run_match": run_match,
+         "build_route_layer": build_route_layer,
+         "queued_pid": os.getpid(),
+     }
+     if recovery_reason:
+         options["recovery_reason"] = recovery_reason
+     queued_job = Job(
+         kind=SOURCE_IMPORT_JOB_KIND,
+         status="queued",
+         description=description,
+         progress_current=0,
+         progress_total=4 if build_route_layer else 3,
+         priority=int(priority),
+         result_json=json.dumps(options, separators=(",", ":")),
+     )
+     session.add(queued_job)
+     source.status = "queued"
+     source.last_error = None
+     session.flush()
+     if recovery_reason:
+         message = f"Recovery import queued for {source.name}: {recovery_reason}"
+     else:
+         message = f"Source import queued for {source.name}."
+     add_job_event(
+         session,
+         queued_job,
+         event_type="queued",
+         message=message,
+         progress_current=0,
+         progress_total=queued_job.progress_total,
+     )
+     return queued_job
+
+
+ def queue_missing_gtfs_sidecar_recovery_jobs(session: Session, *, priority: int = 20) -> int:
+     """Queue recovery imports for active datasets whose sidecar files are missing.
+
+     Despite the name, both active ``gtfs`` and ``osm_geojson`` datasets are checked.
+     Every dataset with missing sidecar paths is marked ``missing_files``. At most one
+     recovery import is queued per enabled source. A source is skipped when it already
+     has an active import or delete job, or when the dataset has an active delete job.
+
+     Returns the number of recovery jobs queued.
+     """
+     queued = 0
+     seen_source_ids: set[int] = set()
+     datasets = session.scalars(
+         select(Dataset).where(Dataset.kind.in_(["gtfs", "osm_geojson"]), Dataset.is_active.is_(True))
+     ).all()
+     for dataset in datasets:
+         # Each dataset kind has its own sidecar check.
+         if dataset.kind == "gtfs":
+             storage_kind = "GTFS"
+             missing_paths = gtfs_missing_sidecar_paths(dataset)
+         elif dataset.kind == "osm_geojson":
+             storage_kind = "OSM"
+             missing_paths = osm_missing_sidecar_paths(dataset)
+         else:
+             continue
+         if not missing_paths:
+             continue
+         # The dataset is flagged even when no recovery job ends up being queued for it.
+         dataset.status = "missing_files"
+         source = session.get(Source, dataset.source_id)
+         if source is None or not source.enabled or source.id in seen_source_ids:
+             continue
+         # The source is recorded as seen before the active-job checks, so it is
+         # considered only once per call even when those checks skip it.
+         seen_source_ids.add(source.id)
+         if (
+             active_source_import_job(session, source.id) is not None
+             or active_source_delete_job(session, source.id) is not None
+             or active_dataset_delete_job(session, dataset.id) is not None
+         ):
+             continue
+         reason = f"{storage_kind} sidecar missing for dataset #{dataset.id}: {', '.join(missing_paths)}"
+         create_source_import_job(
+             session,
+             source,
+             run_match=True,
+             build_route_layer=True,
+             priority=priority,
+             recovery_reason=reason,
+         )
+         queued += 1
+     if queued:
+         session.flush()
+     return queued
+
+
+ def create_source_delete_job(session: Session, source: Source, *, priority: int = 50) -> Job:
+     """Queue a delete job for ``source``, or return the delete job already active for it.
+
+     The source is marked ``queued`` and its last error cleared.
+     """
+     existing = active_source_delete_job(session, source.id)
+     if existing is not None:
+         return existing
+     options = {
+         "source_id": source.id,
+         "source_name": source.name,
+         "queued_pid": os.getpid(),
+     }
+     queued_job = Job(
+         kind=SOURCE_DELETE_JOB_KIND,
+         status="queued",
+         description=f"Delete source {source.id}: {source.name}",
+         progress_current=0,
+         progress_total=3,
+         priority=int(priority),
+         result_json=json.dumps(options, separators=(",", ":")),
+     )
+     session.add(queued_job)
+     source.status = "queued"
+     source.last_error = None
+     session.flush()
+     add_job_event(
+         session,
+         queued_job,
+         event_type="queued",
+         message=f"Source deletion queued for {source.name}.",
+         progress_current=0,
+         progress_total=queued_job.progress_total,
+     )
+     return queued_job
+
+
+ def create_dataset_delete_job(session: Session, dataset: Dataset, *, priority: int = 50) -> Job:
+     """Queue a delete job for ``dataset``, or return the delete job already active for it.
+
+     The dataset's status at queue time is stored in the job options as
+     ``dataset_status`` before the dataset itself is marked ``queued``.
+     """
+     existing = active_dataset_delete_job(session, dataset.id)
+     if existing is not None:
+         return existing
+     # Built before dataset.status is overwritten below, so it captures the prior status.
+     options = {
+         "dataset_id": dataset.id,
+         "dataset_kind": dataset.kind,
+         "dataset_status": dataset.status,
+         "source_id": dataset.source_id,
+         "queued_pid": os.getpid(),
+     }
+     queued_job = Job(
+         kind=DATASET_DELETE_JOB_KIND,
+         status="queued",
+         description=f"Delete dataset {dataset.id}: {dataset.kind}",
+         progress_current=0,
+         progress_total=3,
+         priority=int(priority),
+         result_json=json.dumps(options, separators=(",", ":")),
+     )
+     session.add(queued_job)
+     dataset.status = "queued"
+     session.flush()
+     add_job_event(
+         session,
+         queued_job,
+         event_type="queued",
+         message=f"Dataset deletion queued for dataset #{dataset.id}.",
+         progress_current=0,
+         progress_total=queued_job.progress_total,
+     )
+     return queued_job
+
+
+ def create_maintenance_job(
+     session: Session,
+     action: str,
+     payload: dict[str, Any] | None = None,
+     *,
+     priority: int = 0,
+ ) -> Job:
+     """Queue a maintenance job for ``action``, or return the active job with the same action and payload.
+
+     The payload is normalised first; the normalised form is used both for the
+     duplicate check and for the stored job options.
+     """
+     clean_payload = _normalize_job_payload(payload)
+     existing = active_maintenance_job(session, action, clean_payload)
+     if existing is not None:
+         return existing
+     options = {
+         "action": action,
+         "payload": clean_payload,
+         "queued_pid": os.getpid(),
+     }
+     queued_job = Job(
+         kind=MAINTENANCE_JOB_KIND,
+         status="queued",
+         description=_maintenance_description(action, clean_payload),
+         progress_current=0,
+         progress_total=_maintenance_progress_total(action),
+         priority=int(priority),
+         result_json=json.dumps(options, separators=(",", ":")),
+     )
+     session.add(queued_job)
+     session.flush()
+     queued_message = f"{_maintenance_description(action, clean_payload)} queued."
+     add_job_event(
+         session,
+         queued_job,
+         event_type="queued",
+         message=queued_message,
+         progress_current=0,
+         progress_total=queued_job.progress_total,
+     )
+     return queued_job
+
+
+ def active_source_import_job(session: Session, source_id: int) -> Job | None:
+     """Return the newest active import job whose stored source id equals ``source_id``, or None."""
+     query = (
+         select(Job)
+         .where(Job.kind == SOURCE_IMPORT_JOB_KIND, Job.status.in_(ACTIVE_JOB_STATUSES))
+         .order_by(Job.created_at.desc(), Job.id.desc())
+     )
+     candidates = session.scalars(query).all()
+     return next((job for job in candidates if source_id_from_job(job) == source_id), None)
+
+
+ def active_source_delete_job(session: Session, source_id: int) -> Job | None:
+     """Return the newest active delete job whose stored source id equals ``source_id``, or None."""
+     query = (
+         select(Job)
+         .where(Job.kind == SOURCE_DELETE_JOB_KIND, Job.status.in_(ACTIVE_JOB_STATUSES))
+         .order_by(Job.created_at.desc(), Job.id.desc())
+     )
+     candidates = session.scalars(query).all()
+     return next((job for job in candidates if source_id_from_job(job) == source_id), None)
+
+
+ def active_dataset_delete_job(session: Session, dataset_id: int) -> Job | None:
+     """Return the newest active delete job for ``dataset_id``, or None."""
+     jobs_by_dataset = active_dataset_delete_jobs(session)
+     return jobs_by_dataset.get(dataset_id)
+
+
+ def active_dataset_delete_jobs(session: Session) -> dict[int, Job]:
+     """Map dataset id to its newest active dataset-delete job."""
+     query = (
+         select(Job)
+         .where(Job.kind == DATASET_DELETE_JOB_KIND, Job.status.in_(ACTIVE_JOB_STATUSES))
+         .order_by(Job.created_at.desc(), Job.id.desc())
+     )
+     newest_by_dataset: dict[int, Job] = {}
+     for job in session.scalars(query).all():
+         target_id = dataset_id_from_job(job)
+         if target_id is None:
+             continue
+         # Rows arrive newest first, so the first job seen per dataset wins.
+         newest_by_dataset.setdefault(target_id, job)
+     return newest_by_dataset
+
+
+ def active_maintenance_job(session: Session, action: str, payload: dict[str, Any] | None = None) -> Job | None:
+     """Return the newest active maintenance job with the same action and normalised payload, or None."""
+     wanted_payload = _normalize_job_payload(payload)
+     query = (
+         select(Job)
+         .where(Job.kind == MAINTENANCE_JOB_KIND, Job.status.in_(ACTIVE_JOB_STATUSES))
+         .order_by(Job.created_at.desc(), Job.id.desc())
+     )
+     for candidate in session.scalars(query).all():
+         stored = _json_object(candidate.result_json)
+         if stored.get("action") != action:
+             continue
+         if _normalize_job_payload(stored.get("payload")) == wanted_payload:
+             return candidate
+     return None
+
+
+ def active_source_import_jobs(session: Session) -> dict[int, Job]:
+     """Map source id to its newest active source-import job."""
+     query = (
+         select(Job)
+         .where(Job.kind == SOURCE_IMPORT_JOB_KIND, Job.status.in_(ACTIVE_JOB_STATUSES))
+         .order_by(Job.created_at.desc(), Job.id.desc())
+     )
+     newest_by_source: dict[int, Job] = {}
+     for job in session.scalars(query).all():
+         owner_id = source_id_from_job(job)
+         if owner_id is None:
+             continue
+         # Rows arrive newest first, so the first job seen per source wins.
+         newest_by_source.setdefault(owner_id, job)
+     return newest_by_source
+
+
+ def active_source_workflow_jobs(session: Session) -> dict[int, Job]:
+     """Map source id to its newest active import, source-delete or dataset-delete job."""
+     workflow_kinds = [SOURCE_IMPORT_JOB_KIND, SOURCE_DELETE_JOB_KIND, DATASET_DELETE_JOB_KIND]
+     query = (
+         select(Job)
+         .where(
+             Job.kind.in_(workflow_kinds),
+             Job.status.in_(ACTIVE_JOB_STATUSES),
+         )
+         .order_by(Job.created_at.desc(), Job.id.desc())
+     )
+     newest_by_source: dict[int, Job] = {}
+     for job in session.scalars(query).all():
+         owner_id = source_id_from_job(job)
+         if owner_id is None:
+             continue
+         # Rows arrive newest first, so the first job seen per source wins.
+         newest_by_source.setdefault(owner_id, job)
+     return newest_by_source
+
+
+ def active_address_index_rebuild_job(session: Session) -> Job | None:
+     """Return the newest active address-index rebuild job, or None when there is none."""
+     newest_first = (
+         select(Job)
+         .where(Job.kind == ADDRESS_INDEX_JOB_KIND, Job.status.in_(ACTIVE_JOB_STATUSES))
+         .order_by(Job.created_at.desc(), Job.id.desc())
+     )
+     return session.scalar(newest_first.limit(1))
+
+
+ def reconcile_interrupted_jobs(session: Session) -> int:
+     """Requeue interrupted jobs and repair stale active rows with terminal markers.
+
+     Looks at every ``queued`` or ``running`` job. A job that already carries a
+     terminal marker is moved to that terminal status. A ``running`` job is put back
+     on the queue when its recorded worker process is gone, or when it has no
+     recorded worker pid and its lease has expired. A live worker with an expired
+     lease gets the lease renewed instead.
+
+     Returns the number of jobs that were repaired or requeued.
+     """
+     recovered = 0
+     now = datetime.now(timezone.utc)
+     jobs = session.scalars(
+         select(Job).where(Job.status.in_(["queued", "running"]))
+     ).all()
+     for job in jobs:
+         # Terminal markers are checked first, for queued and running jobs alike.
+         if _reconcile_terminal_marker(session, job, now):
+             recovered += 1
+             continue
+         # Only running jobs can have lost their worker.
+         if job.status != "running":
+             continue
+         worker_pid = _worker_pid_from_job(job)
+         worker_alive = worker_pid is not None and _pid_running(worker_pid)
+         lease_expires_at = _as_utc(job.lease_expires_at)
+         # A missing lease timestamp counts as expired.
+         lease_expired = lease_expires_at is None or lease_expires_at < now
+         if worker_alive and lease_expired:
+             _renew_expired_live_worker_lease(session, job, worker_pid, now)
+             continue
+         # Leave the job alone when its worker is alive, or when no pid is recorded
+         # but the lease is still valid.
+         if worker_alive or ((not lease_expired) and worker_pid is None):
+             continue
+         reason = "worker_process_exited" if worker_pid is not None else "worker_lease_expired"
+         job.status = "queued"
+         job.requested_action = None
+         job.lease_owner = None
+         job.lease_expires_at = None
+         job.paused_at = None
+         job.updated_at = now
+         job.error = None
+         options = _json_object(job.result_json)
+         options.pop("worker_pid", None)
+         options.pop("worker_id", None)
+         # NOTE(review): when the worker ids were the only stored keys, `options` is
+         # empty here and result_json is left unchanged, so the stale worker ids
+         # remain stored - confirm this is intended.
+         if options:
+             job.result_json = json.dumps(options, separators=(",", ":"))
+         add_job_event(
+             session,
+             job,
+             event_type="lease_expired",
+             message="Worker is no longer active; job returned to the queue.",
+             progress_current=job.progress_current,
+             progress_total=job.progress_total,
+             metadata={"reason": reason, "worker_pid": worker_pid},
+         )
+         source = _job_source(session, job)
+         if source is not None:
+             source.status = "queued"
+             source.last_error = None
+         recovered += 1
+     if recovered:
+         session.flush()
+     reconcile_source_workflow_state(session)
+     return recovered
+
+
+ def _reconcile_terminal_marker(session: Session, job: Job, now: datetime) -> bool:
+     """Move a still-active job row to the terminal status its markers indicate.
+
+     Returns False, changing nothing, when ``_terminal_status_from_marker`` finds no
+     terminal status. Otherwise the job's lease and control fields are cleared, a
+     ``terminal_reconciled`` event is recorded, the related source and dataset
+     statuses are updated, and True is returned.
+     """
+     terminal_status = _terminal_status_from_marker(session, job)
+     if terminal_status is None:
+         return False
+     previous_status = job.status
+     job.status = terminal_status
+     job.requested_action = None
+     job.lease_owner = None
+     job.lease_expires_at = None
+     job.paused_at = None
+     job.updated_at = now
+     # An existing finish timestamp is kept; only a missing one is filled in.
+     if job.finished_at is None:
+         job.finished_at = now
+     if terminal_status == "completed":
+         job.error = None
+         # Show full progress for a completed job, unless its total is 0.
+         if job.progress_total > 0:
+             job.progress_current = job.progress_total
+     _clear_job_control_request(job.id)
+     add_job_event(
+         session,
+         job,
+         event_type="terminal_reconciled",
+         message=f"Stale {previous_status} job had already reached {terminal_status}; kept it out of the queue.",
+         progress_current=job.progress_current,
+         progress_total=job.progress_total,
+         metadata={"previous_status": previous_status, "terminal_status": terminal_status},
+     )
+     source = _job_source(session, job)
+     if source is not None:
+         if terminal_status == "completed":
+             source.status = _source_status_without_active_job(session, source)
+             source.last_error = None
+             source.last_run_at = job.finished_at
+         elif terminal_status == "failed":
+             source.status = "error"
+             source.last_error = job.error
+         elif terminal_status == "cancelled":
+             source.status = _source_status_without_active_job(session, source)
+     dataset = _job_dataset(session, job)
+     # Restore the status stored in the job options as "dataset_status" (set when a
+     # dataset delete job is queued), defaulting to "imported". A failed job leaves
+     # the dataset status untouched.
+     if dataset is not None and terminal_status in {"completed", "cancelled"}:
+         dataset.status = str(_json_object(job.result_json).get("dataset_status") or "imported")
+     return True
+
+
+ def _terminal_status_from_marker(session: Session, job: Job) -> str | None:  # Derive "completed"/"failed"/"cancelled" from the job's events and fields, or None if nothing marks it terminal.
+ latest_terminal_status = _latest_terminal_event_status(session, job.id)  # 1st choice: newest completed/failed/cancelled event
+ if latest_terminal_status is not None:
+ return latest_terminal_status
+ if job.finished_at is not None:  # 2nd choice: a finish timestamp; a truthy error decides failed vs completed
+ return "failed" if job.error else "completed"
+ if _job_has_completed_result_marker(job):  # 3rd choice: result_json already holds every expected output key
+ return "completed"
+ latest_event = session.scalar(  # last resort: newest event of any type
+ select(JobEvent).where(JobEvent.job_id == job.id).order_by(JobEvent.created_at.desc(), JobEvent.id.desc()).limit(1)
+ )
+ if latest_event is None:
+ return None
+ return _status_from_terminal_event(latest_event.event_type)  # NOTE(review): only terminal types map to a status and those were already searched above, so this looks like it always yields None - confirm
+
+
+ def _latest_terminal_event_status(session: Session, job_id: int) -> str | None:  # Status named by the job's newest completed/failed/cancelled event, or None if it has none.
+ event = session.scalar(
+ select(JobEvent)
+ .where(JobEvent.job_id == job_id, JobEvent.event_type.in_(["completed", "failed", "cancelled"]))  # NOTE(review): retry_job adds a "retried" event but older terminal events still match here - confirm that is intended
+ .order_by(JobEvent.created_at.desc(), JobEvent.id.desc())  # newest first, id breaks timestamp ties
+ .limit(1)
+ )
+ if event is None:
+ return None
+ return _status_from_terminal_event(event.event_type)
+
+
+ def _status_from_terminal_event(event_type: str) -> str | None:  # Map an event type to the job status of the same name; any non-terminal type gives None.
+ if event_type == "completed":
+ return "completed"
+ if event_type == "failed":
+ return "failed"
+ if event_type == "cancelled":
+ return "cancelled"
+ return None  # every other event type is not a terminal marker
+
+
+ def _job_has_completed_result_marker(job: Job) -> bool:  # True when result_json already contains every output key the job's options asked for (import and relabel jobs only).
+ options = _json_object(job.result_json)  # result_json holds both the job's options and its accumulated results
+ if job.kind == SOURCE_IMPORT_JOB_KIND:
+ if "dataset_id" not in options:  # the import step itself has not recorded a dataset
+ return False
+ if options.get("run_match") and "match_result" not in options:  # matching was requested but has no result yet
+ return False
+ if options.get("build_route_layer") and "route_layer_result" not in options:  # rebuild was requested but has no result yet
+ return False
+ return True
+ if job.kind == OSM_RELABEL_JOB_KIND:
+ if "relabel_result" not in options:
+ return False
+ if options.get("build_route_layer") and "route_layer_result" not in options:
+ return False
+ return True
+ return False  # other job kinds have no result marker recognised here
+
+
+ def reconcile_source_workflow_state(session: Session) -> int:  # Align Source.status with its active job (or lack of one); returns the number of field changes made.
+ active_jobs = active_source_workflow_jobs(session)  # used below as a mapping keyed by source id; helper is defined outside this chunk
+ changed = 0
+ sources = session.scalars(select(Source).where(Source.status.in_(["queued", "running", "paused"]))).all()  # only sources currently shown as in-progress are inspected
+ for source in sources:
+ active_job = active_jobs.get(source.id)
+ if active_job is not None:
+ expected = active_job.status  # the source mirrors the status of its active job
+ if source.status != expected:
+ source.status = expected
+ changed += 1
+ continue
+ replacement = _source_status_without_active_job(session, source)  # no active job: recompute the resting status
+ if source.status != replacement:
+ source.status = replacement
+ changed += 1
+ if replacement != "error" and source.last_error == "Job was interrupted before completion.":  # clears only this exact interruption message
+ source.last_error = None
+ changed += 1
+ if changed:
+ session.flush()  # flushes only; no commit is issued in this function
+ return changed
+
+
+ def source_id_from_job(job: Job) -> int | None:  # Read "source_id" from the job's result_json as an int; None when absent or not convertible.
+ value = _json_object(job.result_json).get("source_id")
+ try:
+ return None if value is None else int(value)
+ except (TypeError, ValueError):  # a value int() rejects is treated as "no source"
+ return None
+
+
+ def dataset_id_from_job(job: Job) -> int | None:  # Read "dataset_id" from the job's result_json as an int; None when absent or not convertible.
+ value = _json_object(job.result_json).get("dataset_id")
+ try:
+ return None if value is None else int(value)
+ except (TypeError, ValueError):  # a value int() rejects is treated as "no dataset"
+ return None
+
+
+ def _worker_pid_from_job(job: Job) -> int | None:  # Read "worker_pid" from the job's result_json; None unless it converts to a positive int.
+ value = _json_object(job.result_json).get("worker_pid")
+ try:
+ pid = int(value)  # a missing key gives None here, which raises TypeError and is handled below
+ except (TypeError, ValueError):
+ return None
+ return pid if pid > 0 else None  # zero and negative values are never returned as a pid
+
+
+ def _pid_running(pid: int) -> bool:  # Probe whether a process with this pid exists by sending signal 0; other exceptions from os.kill propagate.
+ try:
+ os.kill(pid, 0)  # NOTE(review): on POSIX signal 0 only checks existence/permission; on Windows value 0 is CTRL_C_EVENT - confirm target platforms
+ except ProcessLookupError:  # no such process
+ return False
+ except PermissionError:  # the process exists but this user may not signal it
+ return True
+ return True
+
+
+ def _renew_expired_live_worker_lease(session: Session, job: Job, worker_pid: int, now: datetime) -> None:  # Extend the job's lease by LEASE_SECONDS from `now` and log a "lease_renewed" event.
+ job.lease_expires_at = now + timedelta(seconds=LEASE_SECONDS)
+ job.updated_at = now
+ source = _job_source(session, job)
+ if source is not None:
+ source.status = "running"  # the linked source is shown as running again
+ source.last_error = None
+ add_job_event(
+ session,
+ job,
+ event_type="lease_renewed",
+ message="Worker process is still alive; renewed expired lease.",
+ progress_current=job.progress_current,
+ progress_total=job.progress_total,
+ metadata={"worker_pid": worker_pid},  # worker_pid is only recorded here; this function does not probe the process
+ )
+
+
+ def _source_status_without_active_job(session: Session, source: Source) -> str:  # Resting status for a source with no active job: "error", "new", "update_available", "up_to_date" or "ok".
+ active_dataset = session.scalar(
+ select(Dataset)
+ .where(Dataset.source_id == source.id, Dataset.is_active.is_(True), Dataset.status == "imported")  # newest active dataset that finished importing
+ .order_by(Dataset.created_at.desc(), Dataset.id.desc())
+ .limit(1)
+ )
+ if active_dataset is None:
+ return "error" if source.last_error else "new"  # nothing imported: a recorded error decides between the two
+ latest_check = session.scalar(
+ select(SourceUpdateCheck)
+ .where(SourceUpdateCheck.source_id == source.id)
+ .order_by(SourceUpdateCheck.checked_at.desc(), SourceUpdateCheck.id.desc())  # most recent update check first
+ .limit(1)
+ )
+ if latest_check is not None and latest_check.status == "checked":  # only a check with status "checked" influences the result
+ return "update_available" if latest_check.update_available else "up_to_date"
+ return "ok"
+
+
+ def start_source_import_worker(job_id: int) -> None:  # No-op: the argument is accepted and discarded.
+ _ = job_id  # presumably kept so existing callers keep working - confirm against call sites
+
+
+ def job_payload(job: Job) -> dict[str, Any]:  # Build a plain dict describing a Job row.
+ return {
+ "id": job.id,
+ "kind": job.kind,
+ "status": job.status,
+ "description": job.description,
+ "progress_current": job.progress_current,
+ "progress_total": job.progress_total,
+ "priority": job.priority,
+ "requested_action": job.requested_action,
+ "lease_owner": job.lease_owner,
+ "lease_expires_at": _iso(job.lease_expires_at),  # datetimes pass through _iso (defined outside this chunk)
+ "paused_at": _iso(job.paused_at),
+ "result": _json_object(job.result_json),  # stored JSON text is decoded via _json_object
+ "error": job.error,
+ "dismissed_at": _iso(job.dismissed_at),
+ "created_at": _iso(job.created_at),
+ "started_at": _iso(job.started_at),
+ "updated_at": _iso(job.updated_at),
+ "finished_at": _iso(job.finished_at),
+ "terminal": job.status in TERMINAL_JOB_STATUSES,  # derived flag, not a stored column
+ }
+
+
+ def job_event_payload(event: JobEvent) -> dict[str, Any]:  # Build a plain dict describing a JobEvent row.
+ return {
+ "id": event.id,
+ "job_id": event.job_id,
+ "level": event.level,
+ "event_type": event.event_type,
+ "message": event.message,
+ "progress_current": event.progress_current,
+ "progress_total": event.progress_total,
+ "metadata": _json_object(event.metadata_json),  # stored JSON text is decoded via _json_object
+ "created_at": _iso(event.created_at),
+ }
+
+
+ def add_job_event(  # Create a JobEvent for `job`, add it to the session and return it; no flush or commit happens here.
+ session: Session,
+ job: Job,
+ *,
+ event_type: str,
+ message: str,
+ level: str = "info",
+ progress_current: int | None = None,
+ progress_total: int | None = None,
+ metadata: dict[str, Any] | None = None,  # serialised to compact JSON; None is stored as NULL
+ ) -> JobEvent:
+ event = JobEvent(
+ job_id=job.id,
+ level=level,
+ event_type=event_type,
+ message=message,
+ progress_current=progress_current,
+ progress_total=progress_total,
+ metadata_json=None if metadata is None else json.dumps(metadata, separators=(",", ":")),
+ )
+ session.add(event)
+ return event
+
+
+ def latest_jobs(session: Session, *, limit: int = 20, kind: str | None = None, include_dismissed: bool = False) -> list[Job]:  # Newest jobs first, optionally filtered by kind; dismissed jobs are hidden unless requested.
+ stmt = select(Job).order_by(Job.created_at.desc(), Job.id.desc())
+ if kind:
+ stmt = stmt.where(Job.kind == kind)
+ if not include_dismissed:
+ stmt = stmt.where(Job.dismissed_at.is_(None))
+ return session.scalars(stmt.limit(max(1, min(limit, 100)))).all()  # limit is clamped to the range 1..100
+
+
+ def job_queue_revision(session: Session, *, include_dismissed: bool = False) -> dict[str, Any]:  # Summarise the job queue as counters plus a "revision" string built from aggregate values.
+ filters = []
+ if not include_dismissed:
+ filters.append(Job.dismissed_at.is_(None))
+ job_stmt = select(  # one aggregate row over the (optionally filtered) jobs table
+ func.count(Job.id),
+ func.coalesce(func.max(Job.id), 0),
+ func.max(Job.created_at),
+ func.max(Job.updated_at),
+ func.max(Job.finished_at),
+ func.max(Job.dismissed_at),
+ func.count(Job.id).filter(Job.status.in_(ACTIVE_JOB_STATUSES)),  # per-status counts use filtered aggregates
+ func.count(Job.id).filter(Job.status == "queued"),
+ func.count(Job.id).filter(Job.status == "running"),
+ func.count(Job.id).filter(Job.status == "paused"),
+ )
+ if filters:
+ job_stmt = job_stmt.where(*filters)
+ (  # unpacked in the same order as the select() columns above
+ job_count,
+ latest_job_id,
+ latest_job_created_at,
+ latest_job_updated_at,
+ latest_job_finished_at,
+ latest_job_dismissed_at,
+ active_count,
+ queued_count,
+ running_count,
+ paused_count,
+ ) = session.execute(job_stmt).one()
+
+ event_stmt = select(func.coalesce(func.max(JobEvent.id), 0), func.max(JobEvent.created_at)).select_from(JobEvent)
+ if not include_dismissed:
+ event_stmt = event_stmt.join(Job, Job.id == JobEvent.job_id).where(Job.dismissed_at.is_(None))  # events of dismissed jobs are excluded too
+ latest_event_id, latest_event_created_at = session.execute(event_stmt).one()
+
+ revision_parts = [  # joined with "|" below to form the revision string
+ int(job_count or 0),
+ int(latest_job_id or 0),
+ _revision_datetime(latest_job_created_at),
+ _revision_datetime(latest_job_updated_at),
+ _revision_datetime(latest_job_finished_at),
+ _revision_datetime(latest_job_dismissed_at),  # part of the revision string but not returned as its own key
+ int(latest_event_id or 0),
+ _revision_datetime(latest_event_created_at),
+ int(active_count or 0),
+ int(queued_count or 0),
+ int(running_count or 0),
+ int(paused_count or 0),
+ ]
+ return {
+ "revision": "|".join(str(part) for part in revision_parts),
+ "job_count": int(job_count or 0),  # `or 0` turns a NULL aggregate into 0
+ "latest_job_id": int(latest_job_id or 0),
+ "latest_event_id": int(latest_event_id or 0),
+ "active_count": int(active_count or 0),
+ "queued_count": int(queued_count or 0),
+ "running_count": int(running_count or 0),
+ "paused_count": int(paused_count or 0),
+ "latest_job_created_at": _iso(latest_job_created_at),
+ "latest_job_updated_at": _iso(latest_job_updated_at),
+ "latest_job_finished_at": _iso(latest_job_finished_at),
+ "latest_event_created_at": _iso(latest_event_created_at),
+ }
+
+
+ def job_events(session: Session, job_id: int, *, limit: int = 100) -> list[JobEvent]:  # Events of one job, oldest first.
+ return session.scalars(
+ select(JobEvent)
+ .where(JobEvent.job_id == job_id)
+ .order_by(JobEvent.created_at, JobEvent.id)  # ascending order; with the limit this keeps the earliest events
+ .limit(max(1, min(limit, 500)))  # limit is clamped to the range 1..500
+ ).all()
+
+
+ def request_job_control(job_id: int, action: str) -> dict[str, Any]:  # Record a pause/cancel request via _write_job_control_request without a database session; raises ValueError for other actions.
+ if action not in {"pause", "cancel"}:
+ raise ValueError(f"unsupported job control action: {action}")
+ requested_at = datetime.now(timezone.utc)
+ payload = {
+ "job_id": int(job_id),
+ "requested_action": action,
+ "requested_at": requested_at.isoformat(),
+ "request_pid": os.getpid(),  # pid of the process that made the request
+ }
+ _write_job_control_request(job_id, payload)  # storage of the request is handled by this helper, defined outside this chunk
+ return {
+ "id": int(job_id),
+ "status": "running",  # hard-coded; the job row is never read in this function
+ "requested_action": action,
+ "control_request_queued": True,
+ "terminal": False,
+ "updated_at": requested_at.isoformat(),
+ "result": {},
+ }
+
+
+ def run_worker_loop(  # Claim and run queued jobs in a loop; returns {"worker_id", "processed"} when `once` finds no job or `max_jobs` is reached.
+ *,
+ worker_id: str | None = None,  # generated from the pid and a random suffix when omitted
+ poll_interval: float = 2.0,  # idle sleep in seconds; values below 0.2 are raised to 0.2
+ max_jobs: int | None = None,  # stop after this many jobs; None means no cap
+ once: bool = False,  # return instead of sleeping when no job can be claimed
+ ) -> dict[str, int | str]:
+ init_db()
+ worker = worker_id or f"worker-{os.getpid()}-{uuid4().hex[:8]}"
+ processed = 0
+ while True:
+ with SessionLocal() as session:
+ reconcile_interrupted_jobs(session)  # runs on every iteration before a claim is attempted
+ session.commit()
+ try:
+ job_id = claim_next_job(worker)
+ except DatabaseWriteBusy:  # treated the same as "no job available"
+ if once:
+ return {"worker_id": worker, "processed": processed}
+ time.sleep(max(0.2, float(poll_interval)))
+ continue
+ if job_id is None:
+ if once:
+ return {"worker_id": worker, "processed": processed}
+ time.sleep(max(0.2, float(poll_interval)))
+ continue
+ run_claimed_job(job_id, worker)
+ processed += 1  # counted whatever the job's outcome was
+ if max_jobs is not None and processed >= max_jobs:
+ return {"worker_id": worker, "processed": processed}
+
+
+ def run_worker_once(*, worker_id: str | None = None) -> dict[str, int | str]:  # Run at most one job and return immediately if none can be claimed.
+ return run_worker_loop(worker_id=worker_id, once=True, max_jobs=1)
+
+
+ def claim_next_job(worker_id: str, *, lease_seconds: int = LEASE_SECONDS) -> int | None:  # Mark the next queued job as running for this worker and return its id, or None if the queue is empty.
+ with database_write_lock("job:claim", timeout=30):  # every claim uses the same lock name
+ with SessionLocal() as session:
+ reconcile_interrupted_jobs(session)
+ job = session.scalar(
+ select(Job)
+ .where(Job.status == "queued")
+ .order_by(Job.priority.desc(), Job.created_at, Job.id)  # highest priority first, then oldest, then lowest id
+ .limit(1)
+ )
+ if job is None:
+ session.commit()  # still commits whatever reconcile_interrupted_jobs changed
+ return None
+ now = datetime.now(timezone.utc)
+ job.status = "running"
+ job.requested_action = None
+ job.lease_owner = worker_id
+ job.lease_expires_at = now + timedelta(seconds=lease_seconds)
+ job.paused_at = None
+ job.error = None
+ if job.started_at is None:  # an existing start time is kept (e.g. a job claimed again)
+ job.started_at = now
+ job.updated_at = now
+ add_job_event(
+ session,
+ job,
+ event_type="claimed",
+ message=f"Job claimed by {worker_id}.",
+ progress_current=job.progress_current,
+ progress_total=job.progress_total,
+ metadata={"worker_id": worker_id},
+ )
+ source = _job_source(session, job)
+ if source is not None:
+ source.status = "running"
+ source.last_error = None
+ dataset = _job_dataset(session, job)
+ if dataset is not None:
+ dataset.status = "running"
+ session.commit()
+ return int(job.id)
+
+
+ def run_claimed_job(job_id: int, worker_id: str) -> None:  # Dispatch a claimed job to the runner for its kind while a heartbeat thread is active; outcomes are recorded, not raised.
+ init_db()
+ try:
+ with _job_heartbeat_context(job_id, worker_id):
+ with SessionLocal() as session:
+ job = session.get(Job, job_id)
+ if job is None:  # the job row no longer exists: nothing to run
+ return
+ if job.lease_owner != worker_id:
+ raise RuntimeError(f"job #{job_id} is not leased by this worker")  # caught by the generic handler below and recorded as a failure
+ if job.kind == ROUTE_MATCHING_JOB_KIND:
+ _run_route_matching_job(job_id, worker_id)
+ elif job.kind == ROUTE_LAYER_JOB_KIND:
+ _run_route_layer_rebuild_job(job_id, worker_id)
+ elif job.kind == ADDRESS_INDEX_JOB_KIND:
+ _run_address_index_rebuild_job(job_id, worker_id)
+ elif job.kind == OSM_RELABEL_JOB_KIND:
+ _run_osm_relabel_job(job_id, worker_id)
+ elif job.kind == SOURCE_IMPORT_JOB_KIND:
+ _run_source_import_job(job_id, worker_id)
+ elif job.kind == SOURCE_DELETE_JOB_KIND:
+ _run_source_delete_job(job_id, worker_id)
+ elif job.kind == DATASET_DELETE_JOB_KIND:
+ _run_dataset_delete_job(job_id, worker_id)
+ elif job.kind == MAINTENANCE_JOB_KIND:
+ _run_maintenance_job(job_id, worker_id)
+ else:
+ raise ValueError(f"unsupported job kind: {job.kind}")
+ except JobPaused:  # nothing further is recorded here for a paused job
+ return
+ except JobCancelled:
+ _mark_job_cancelled(job_id)
+ except Exception as exc: # noqa: BLE001 - surfaced through job status UI
+ _mark_job_failed(job_id, exc)
+
+
+ @contextmanager
+ def _job_heartbeat_context(job_id: int, worker_id: str) -> Iterator[None]:  # Run _job_heartbeat_loop in a background thread for the duration of the with-block.
+ stop_event = threading.Event()
+ interval = max(10.0, min(float(HEARTBEAT_INTERVAL_SECONDS), float(LEASE_SECONDS) / 3))  # at most a third of the lease, never below 10 s
+ thread = threading.Thread(
+ target=_job_heartbeat_loop,
+ args=(job_id, worker_id, stop_event, interval),
+ name=f"job-heartbeat-{job_id}",
+ daemon=True,  # daemon thread: does not keep the interpreter alive
+ )
+ thread.start()
+ try:
+ yield
+ finally:
+ stop_event.set()  # wakes the loop's wait() so it exits
+ thread.join(timeout=5)  # waits at most 5 s for the thread to finish
+
+
+ def _job_heartbeat_loop(job_id: int, worker_id: str, stop_event: threading.Event, interval: float) -> None:  # Every `interval` seconds, call _heartbeat_job for the job until stop_event is set.
+ while not stop_event.wait(interval):  # wait() returns True once the event is set, which ends the loop
+ try:
+ with SessionLocal() as session:
+ job = session.get(Job, job_id)
+ if job is None or job.status != "running" or job.lease_owner != worker_id:  # skip this tick unless this worker still runs the job
+ continue
+ _heartbeat_job(job, worker_id)
+ session.commit()
+ except Exception:
+ # Best-effort liveness refresh. A normal progress callback or the
+ # next loop will renew the lease if this short write collides.
+ continue
+
+
+ def pause_job(session: Session, job_id: int) -> Job:  # Pause a queued job directly or request a pause from a running one; raises ValueError for any other status except "paused".
+ job = _get_job_or_raise(session, job_id)
+ if job.status == "queued":
+ _mark_job_paused(session, job, "Job paused before it was claimed.")
+ elif job.status == "running":
+ job.requested_action = "pause"  # only a request is recorded here; the status stays "running"
+ job.updated_at = datetime.now(timezone.utc)
+ _write_job_control_request(
+ job.id,
+ {"job_id": job.id, "requested_action": "pause", "requested_at": job.updated_at.isoformat(), "request_pid": os.getpid()},
+ )
+ add_job_event(session, job, event_type="pause_requested", message="Pause requested.")
+ elif job.status != "paused":  # an already paused job is returned without changes
+ raise ValueError(f"cannot pause job in status {job.status}")
+ session.flush()
+ return job
+
+
+ def resume_job(session: Session, job_id: int) -> Job:  # Put a paused job back in the queue; raises ValueError if the job is not paused.
+ job = _get_job_or_raise(session, job_id)
+ if job.status != "paused":
+ raise ValueError(f"cannot resume job in status {job.status}")
+ job.status = "queued"
+ job.requested_action = None
+ job.lease_owner = None
+ job.lease_expires_at = None
+ job.paused_at = None
+ job.updated_at = datetime.now(timezone.utc)
+ _clear_job_control_request(job.id)  # presumably removes the stored pause request - confirm in the helper
+ add_job_event(session, job, event_type="resumed", message="Job returned to the queue.")
+ source = _job_source(session, job)
+ if source is not None:  # linked source and dataset are set to "queued" as well
+ source.status = "queued"
+ source.last_error = None
+ dataset = _job_dataset(session, job)
+ if dataset is not None:
+ dataset.status = "queued"
+ session.flush()  # flushes only; no commit is issued in this function
+ return job
+
+
+ def retry_job(session: Session, job_id: int) -> Job:  # Reset a terminal job and return it to the queue; raises ValueError if the job is not terminal.
+ job = _get_job_or_raise(session, job_id)
+ if job.status not in TERMINAL_JOB_STATUSES:
+ raise ValueError(f"cannot retry job in status {job.status}")
+ now = datetime.now(timezone.utc)
+ job.status = "queued"
+ job.requested_action = None
+ job.lease_owner = None
+ job.lease_expires_at = None
+ job.paused_at = None
+ job.error = None
+ job.dismissed_at = None  # a dismissed job becomes visible again
+ job.started_at = None
+ job.finished_at = None
+ job.progress_current = 0
+ job.updated_at = now
+ options = _json_object(job.result_json)
+ options.pop("worker_pid", None)  # forget which worker ran the previous attempt
+ options.pop("worker_id", None)
+ job.result_json = json.dumps(options, separators=(",", ":")) if options else None  # NOTE(review): result keys from the earlier attempt (e.g. dataset_id) are kept - confirm intended
+ _clear_job_control_request(job.id)
+ add_job_event(
+ session,
+ job,
+ event_type="retried",
+ message="Job returned to the queue for retry.",
+ progress_current=job.progress_current,
+ progress_total=job.progress_total,
+ )
+ source = _job_source(session, job)
+ if source is not None:
+ source.status = "queued"
+ source.last_error = None
+ dataset = _job_dataset(session, job)
+ if dataset is not None:
+ dataset.status = "queued"
+ session.flush()  # flushes only; no commit is issued in this function
+ return job
+
+
+ def cancel_job(session: Session, job_id: int) -> Job:  # Request cancellation of a running job, or cancel a non-running one directly; terminal jobs are returned unchanged.
+ job = _get_job_or_raise(session, job_id)
+ if job.status in TERMINAL_JOB_STATUSES:
+ return job
+ if job.status == "running":
+ job.requested_action = "cancel"  # only a request is recorded here; the status stays "running"
+ job.updated_at = datetime.now(timezone.utc)
+ _write_job_control_request(
+ job.id,
+ {"job_id": job.id, "requested_action": "cancel", "requested_at": job.updated_at.isoformat(), "request_pid": os.getpid()},
+ )
+ add_job_event(session, job, event_type="cancel_requested", message="Stop requested.")
+ else:  # any other non-terminal status is finished as cancelled right away
+ _finish_job_cancelled(session, job)
+ session.flush()
+ return job
+
+
+ def dismiss_job(session: Session, job_id: int) -> Job:  # Stamp dismissed_at on a terminal job; raises ValueError if the job is not terminal.
+ job = _get_job_or_raise(session, job_id)
+ if job.status not in TERMINAL_JOB_STATUSES:
+ raise ValueError(f"cannot dismiss job in status {job.status}")
+ if job.dismissed_at is None:  # an already dismissed job keeps its timestamp and gets no new event
+ now = datetime.now(timezone.utc)
+ job.dismissed_at = now
+ job.updated_at = now
+ add_job_event(session, job, event_type="dismissed", message="Job dismissed from the default jobs view.")
+ session.flush()
+ return job
+
+
+ def dismiss_terminal_jobs(session: Session) -> int:  # Dismiss every terminal job that is not yet dismissed; returns how many were affected.
+ now = datetime.now(timezone.utc)  # one shared timestamp for the whole batch
+ jobs = session.scalars(
+ select(Job)
+ .where(Job.status.in_(TERMINAL_JOB_STATUSES), Job.dismissed_at.is_(None))
+ .order_by(Job.created_at.desc(), Job.id.desc())
+ ).all()
+ for job in jobs:
+ job.dismissed_at = now
+ job.updated_at = now
+ add_job_event(session, job, event_type="dismissed", message="Job dismissed from the default jobs view.")
+ if jobs:
+ session.flush()
+ return len(jobs)
+
+
+ def set_job_priority(session: Session, job_id: int, priority: int) -> Job:  # Store a new priority on the job and log a "priority_changed" event; the job's status is not checked.
+ job = _get_job_or_raise(session, job_id)
+ job.priority = int(priority)
+ job.updated_at = datetime.now(timezone.utc)
+ add_job_event(session, job, event_type="priority_changed", message=f"Priority changed to {job.priority}.", metadata={"priority": job.priority})
+ session.flush()  # flushes only; no commit is issued in this function
+ return job
+
+
+ def _run_route_matching_job(job_id: int, worker_id: str) -> None:  # Run route matching for a claimed job and record the result as the job's completion.
+ init_db()
+ with database_write_lock(f"job:{ROUTE_MATCHING_JOB_KIND}:{job_id}", timeout=3600):  # lock name is specific to this job id
+ with SessionLocal() as session:
+ job = _job_for_worker(session, job_id, worker_id)
+ _job_running(session, job, worker_id, "started", "Route matching started.", 0)
+
+ with SessionLocal() as session:  # a second, separate session for the actual work
+ job = _job_for_worker(session, job_id, worker_id)
+ _check_job_control(session, job)  # presumably raises JobPaused/JobCancelled on a pending request - confirm in the helper
+ progress_callback = _job_progress_callback(session, job, worker_id, update_job_progress=True)
+ result = run_route_matching(session, progress_callback=progress_callback)
+ job = _job_for_worker(session, job_id, worker_id)  # job is fetched again after the long-running call
+ _complete_job(session, job, "Route matching completed.", result)
+ session.commit()
+
+
+ def _run_route_layer_rebuild_job(job_id: int, worker_id: str) -> None:  # Rebuild the route layer for a claimed job and record the result as the job's completion.
+ init_db()
+ with database_write_lock(f"job:{ROUTE_LAYER_JOB_KIND}:{job_id}", timeout=3600):  # lock name is specific to this job id
+ with SessionLocal() as session:
+ job = _job_for_worker(session, job_id, worker_id)
+ _job_running(session, job, worker_id, "started", "Route-layer rebuild started.", 1)
+
+ with SessionLocal() as session:
+ job = _job_for_worker(session, job_id, worker_id)
+ _job_running(session, job, worker_id, "rebuilding", "Extracting canonical stops and route patterns.", 2)  # second status update before the work begins
+
+ with SessionLocal() as session:
+ job = _job_for_worker(session, job_id, worker_id)
+ _check_job_control(session, job)  # presumably raises JobPaused/JobCancelled on a pending request - confirm in the helper
+ progress_callback = _job_progress_callback(session, job, worker_id)
+ result = rebuild_route_layer(session, progress_callback=progress_callback)
+ job = _job_for_worker(session, job_id, worker_id)  # job is fetched again after the long-running call
+ _complete_job(session, job, "Route-layer rebuild completed.", result)
+ session.commit()
+
+
+ def _run_address_index_rebuild_job(job_id: int, worker_id: str) -> None:  # Rebuild the OSM address index for a claimed job and record the result as the job's completion.
+ init_db()
+ with database_write_lock(f"job:{ADDRESS_INDEX_JOB_KIND}:{job_id}", timeout=3600):  # lock name is specific to this job id
+ with SessionLocal() as session:
+ job = _job_for_worker(session, job_id, worker_id)
+ _job_running(session, job, worker_id, "started", "OSM address index rebuild started.", 1)
+
+ with SessionLocal() as session:
+ job = _job_for_worker(session, job_id, worker_id)
+ _job_running(session, job, worker_id, "rebuilding", "Extracting searchable addresses from OSM.", 2)  # second status update before the work begins
+
+ with SessionLocal() as session:
+ job = _job_for_worker(session, job_id, worker_id)
+ _check_job_control(session, job)  # presumably raises JobPaused/JobCancelled on a pending request - confirm in the helper
+ progress_callback = _job_progress_callback(session, job, worker_id)
+ result = rebuild_address_index(session, progress_callback=progress_callback)
+ job = _job_for_worker(session, job_id, worker_id)  # job is fetched again after the long-running call
+ _complete_job(session, job, "OSM address index rebuild completed.", result)
+ session.commit()
+
+
+ def _run_osm_relabel_job(job_id: int, worker_id: str) -> None:  # Relabel OSM features, optionally rebuild the route layer, then complete the job; each stage stores its result in result_json.
+ init_db()
+ with database_write_lock(f"job:{OSM_RELABEL_JOB_KIND}:{job_id}", timeout=3600):  # lock name is specific to this job id
+ with SessionLocal() as session:  # stage 1: relabel
+ job = _job_for_worker(session, job_id, worker_id)
+ options = _json_object(job.result_json)  # job options are read from result_json
+ _job_running(session, job, worker_id, "started", "OSM feature relabeling started.", 0)
+ progress_callback = _job_progress_callback(session, job, worker_id, update_job_progress=False)
+ relabel_result = relabel_osm_features(
+ session,
+ dataset_id=_optional_int(options.get("dataset_id")),
+ force=bool(options.get("force")),
+ progress_callback=progress_callback,
+ job_id=job.id,
+ )
+ job = _job_for_worker(session, job_id, worker_id)  # job is fetched again after the long-running call
+ options = _json_object(job.result_json)
+ options["relabel_result"] = relabel_result  # this key is what _job_has_completed_result_marker looks for
+ job.result_json = json.dumps(options, separators=(",", ":"))
+ job.progress_current = 1
+ _heartbeat_job(job, worker_id)
+ add_job_event(
+ session,
+ job,
+ event_type="osm_relabel_completed",
+ message="OSM feature relabeling completed.",
+ progress_current=1,
+ progress_total=job.progress_total,
+ metadata=relabel_result,
+ )
+ _check_job_control(session, job)  # control check runs before the commit on the next line
+ session.commit()
+
+ with SessionLocal() as session:  # stage 2: optional route-layer rebuild
+ job = _job_for_worker(session, job_id, worker_id)
+ options = _json_object(job.result_json)
+ if options.get("build_route_layer"):
+ _job_running(session, job, worker_id, "rebuilding_route_layer", "Rebuilding route layer after OSM relabeling.", 1)
+ progress_callback = _job_progress_callback(session, job, worker_id)
+ route_layer_result = rebuild_route_layer(session, progress_callback=progress_callback)
+ job = _job_for_worker(session, job_id, worker_id)
+ options = _json_object(job.result_json)
+ options["route_layer_result"] = route_layer_result
+ job.result_json = json.dumps(options, separators=(",", ":"))
+ _heartbeat_job(job, worker_id)
+ add_job_event(
+ session,
+ job,
+ event_type="route_layer_rebuilt",
+ message="Route layer rebuilt after OSM relabeling.",
+ progress_current=job.progress_total,
+ progress_total=job.progress_total,
+ metadata=route_layer_result,
+ )
+ _check_job_control(session, job)
+ session.commit()
+
+ with SessionLocal() as session:  # stage 3: mark the job completed with the accumulated result
+ job = _job_for_worker(session, job_id, worker_id)
+ _complete_job(session, job, "OSM relabel job completed.", _json_object(job.result_json))
+ session.commit()
+
+
+ def _run_source_import_job(job_id: int, worker_id: str) -> None:  # Import a source, optionally run matching and a route-layer rebuild, then complete the job; raises ValueError if the source is missing.
+ init_db()  # NOTE(review): unlike the other _run_* jobs in this file, no database_write_lock is taken here - confirm intentional
+ with SessionLocal() as session:  # stage 1: import
+ job = _job_for_worker(session, job_id, worker_id)
+ options = _json_object(job.result_json)  # job options are read from result_json
+ source = session.get(Source, int(options.get("source_id") or 0))  # a missing source_id is looked up as id 0
+ if source is None:
+ raise ValueError("source not found for import job")
+ source.status = "running"
+ source.last_error = None
+ source.last_run_at = datetime.now(timezone.utc)
+ options["worker_pid"] = os.getpid()  # read back by _worker_pid_from_job
+ options["worker_id"] = worker_id
+ job.result_json = json.dumps(options, separators=(",", ":"))
+ _job_running(session, job, worker_id, "started", f"Importing source {source.name}.", 1)
+ progress_callback = _job_progress_callback(session, job, worker_id)
+ if source.kind == "osm_pbf":  # OSM PBF sources use the staged importer, which takes the source id rather than the session
+ dataset = run_osm_pbf_source_staged(source.id, progress_callback=progress_callback)
+ else:
+ dataset = run_source(session, source, progress_callback=progress_callback)
+ result = {**options, "dataset_id": dataset.id, "dataset_kind": dataset.kind, "dataset_status": dataset.status}
+ job = _job_for_worker(session, job_id, worker_id)  # job is fetched again after the long-running call
+ job.result_json = json.dumps(result, separators=(",", ":"))
+ job.progress_current = 2
+ _heartbeat_job(job, worker_id)
+ add_job_event(
+ session,
+ job,
+ event_type="source_imported",
+ message=f"Imported dataset #{dataset.id} from {source.name}.",
+ progress_current=2,
+ progress_total=job.progress_total,
+ metadata=result,
+ )
+ _check_job_control(session, job)  # control check runs before the commit on the next line
+ session.commit()
+
+ with SessionLocal() as session:  # stage 2: optional route matching
+ job = _job_for_worker(session, job_id, worker_id)
+ options = _json_object(job.result_json)
+ if options.get("run_match"):
+ _job_running(session, job, worker_id, "matching", "Running route matcher after import.", 3)
+ progress_callback = _job_progress_callback(session, job, worker_id)
+ match_result = run_route_matching(session, progress_callback=progress_callback)
+ job = _job_for_worker(session, job_id, worker_id)
+ options = _json_object(job.result_json)
+ options["match_result"] = match_result  # this key is what _job_has_completed_result_marker looks for
+ job.result_json = json.dumps(options, separators=(",", ":"))
+ _heartbeat_job(job, worker_id)
+ add_job_event(
+ session,
+ job,
+ event_type="matched",
+ message="Route matcher completed.",
+ progress_current=3,
+ progress_total=job.progress_total,
+ metadata=match_result,
+ )
+ _check_job_control(session, job)
+ session.commit()
+
+ with SessionLocal() as session:  # stage 3: optional route-layer rebuild
+ job = _job_for_worker(session, job_id, worker_id)
+ options = _json_object(job.result_json)
+ if options.get("build_route_layer"):
+ _job_running(session, job, worker_id, "rebuilding_route_layer", "Rebuilding route layer after source import.", job.progress_total - 1)  # progress is reported as one step before the total
+ progress_callback = _job_progress_callback(session, job, worker_id)
+ route_layer_result = rebuild_route_layer(session, progress_callback=progress_callback)
+ job = _job_for_worker(session, job_id, worker_id)
+ options = _json_object(job.result_json)
+ options["route_layer_result"] = route_layer_result
+ job.result_json = json.dumps(options, separators=(",", ":"))
+ _heartbeat_job(job, worker_id)
+ add_job_event(
+ session,
+ job,
+ event_type="route_layer_rebuilt",
+ message="Route layer rebuilt after import.",
+ progress_current=job.progress_total - 1,
+ progress_total=job.progress_total,
+ metadata=route_layer_result,
+ )
+ _check_job_control(session, job)
+ session.commit()
+
+ with SessionLocal() as session:  # stage 4: mark source and job as finished
+ job = _job_for_worker(session, job_id, worker_id)
+ source = _job_source(session, job)
+ if source is not None:
+ source.status = "ok"  # set to "ok" directly; _source_status_without_active_job is not used here
+ source.last_error = None
+ source.last_run_at = datetime.now(timezone.utc)
+ _complete_job(session, job, "Source import job completed.", _json_object(job.result_json))
+ session.commit()
+
+
+ def _run_source_delete_job(job_id: int, worker_id: str) -> None:  # Delete a source, prune cache files and complete the job; raises ValueError when the job carries no usable source_id.
+ init_db()
+ delete_result: dict[str, Any] = {}  # filled by the delete stage and reused in the final result
+ with database_write_lock(f"job:{SOURCE_DELETE_JOB_KIND}:{job_id}", timeout=3600):  # lock name is specific to this job id
+ with SessionLocal() as session:  # stage 1: validate and mark as started
+ job = _job_for_worker(session, job_id, worker_id)
+ options = _json_object(job.result_json)
+ source_id = _optional_int(options.get("source_id"))
+ if source_id is None:
+ raise ValueError("source not found for delete job")
+ source = session.get(Source, source_id)
+ if source is None:  # source row already gone: the job completes with deleted=False
+ result = {**options, "delete_result": {"deleted": False, "reason": "source not found", "source_id": source_id}}
+ _complete_job(session, job, "Source delete job completed; source was already absent.", result)
+ session.commit()
+ return
+ source.status = "running"
+ source.last_error = None
+ options["worker_pid"] = os.getpid()  # read back by _worker_pid_from_job
+ options["worker_id"] = worker_id
+ job.result_json = json.dumps(options, separators=(",", ":"))
+ _job_running(session, job, worker_id, "started", f"Deleting source {source.name}.", 1)
+
+ with SessionLocal() as session:  # stage 2: delete
+ job = _job_for_worker(session, job_id, worker_id)
+ options = _json_object(job.result_json)
+ source_id = int(options["source_id"])
+ delete_result = delete_source(session, source_id)
+ job.progress_current = 2
+ _heartbeat_job(job, worker_id)
+ add_job_event(
+ session,
+ job,
+ event_type="source_deleted",
+ message="Source rows and datasets deleted.",
+ progress_current=2,
+ progress_total=job.progress_total,
+ metadata=delete_result,
+ )
+ _check_job_control(session, job)  # control check runs before the commit on the next line
+ session.commit()
+
+ with SessionLocal() as session:  # stage 3: prune cache and complete
+ job = _job_for_worker(session, job_id, worker_id)
+ _job_running(session, job, worker_id, "pruning_cache", "Pruning unreferenced cache files.", 3)
+ cache = prune_unreferenced_cache_files(session)
+ result = {**_json_object(job.result_json), "delete_result": delete_result, "cache_pruned": cache}
+ _complete_job(session, job, "Source delete job completed.", result)
+ session.commit()
+
+
+ def _run_dataset_delete_job(job_id: int, worker_id: str) -> None:  # Delete a dataset, prune cache files and complete the job; raises ValueError when the job carries no usable dataset_id.
+ init_db()
+ delete_result: dict[str, Any] = {}  # filled by the delete stage and reused in the final result
+ with database_write_lock(f"job:{DATASET_DELETE_JOB_KIND}:{job_id}", timeout=3600):  # lock name is specific to this job id
+ with SessionLocal() as session:  # stage 1: validate and mark as started
+ job = _job_for_worker(session, job_id, worker_id)
+ options = _json_object(job.result_json)
+ dataset_id = _optional_int(options.get("dataset_id"))
+ if dataset_id is None:
+ raise ValueError("dataset not found for delete job")
+ dataset = session.get(Dataset, dataset_id)
+ if dataset is None:  # dataset row already gone: the job completes with deleted=False
+ result = {**options, "delete_result": {"deleted": False, "reason": "dataset not found", "dataset_id": dataset_id}}
+ _complete_job(session, job, "Dataset delete job completed; dataset was already absent.", result)
+ session.commit()
+ return
+ dataset.status = "running"
+ options["worker_pid"] = os.getpid()  # read back by _worker_pid_from_job
+ options["worker_id"] = worker_id
+ job.result_json = json.dumps(options, separators=(",", ":"))
+ _job_running(session, job, worker_id, "started", f"Deleting dataset #{dataset.id}.", 1)
+
+ with SessionLocal() as session:  # stage 2: delete
+ job = _job_for_worker(session, job_id, worker_id)
+ options = _json_object(job.result_json)
+ dataset_id = int(options["dataset_id"])
+ source_id = _optional_int(options.get("source_id"))
+ delete_result = delete_dataset(session, dataset_id)
+ if source_id is not None:
+ source = session.get(Source, source_id)
+ if source is not None:
+ source.status = _source_status_without_active_job(session, source)  # recompute the source status now that the dataset is gone
+ job.progress_current = 2
+ _heartbeat_job(job, worker_id)
+ add_job_event(
+ session,
+ job,
+ event_type="dataset_deleted",
+ message="Dataset rows and files deleted.",
+ progress_current=2,
+ progress_total=job.progress_total,
+ metadata=delete_result,
+ )
+ _check_job_control(session, job)  # control check runs before the commit on the next line
+ session.commit()
+
+ with SessionLocal() as session:  # stage 3: prune cache and complete
+ job = _job_for_worker(session, job_id, worker_id)
+ _job_running(session, job, worker_id, "pruning_cache", "Pruning unreferenced cache files.", 3)
+ cache = prune_unreferenced_cache_files(session)
+ result = {**_json_object(job.result_json), "delete_result": delete_result, "cache_pruned": cache}
+ _complete_job(session, job, "Dataset delete job completed.", result)
+ session.commit()
+
+
+def _run_maintenance_job(job_id: int, worker_id: str) -> None:
+ """Run a queued maintenance job: mark it started, run the action, store the result.
+
+ The action name and its payload are read from ``job.result_json``. A missing
+ or empty action raises ``ValueError``.
+ """
+ init_db()
+ with database_write_lock(f"job:{MAINTENANCE_JOB_KIND}:{job_id}", timeout=3600):
+ with SessionLocal() as session:
+ job = _job_for_worker(session, job_id, worker_id)
+ options = _json_object(job.result_json)
+ action = str(options.get("action") or "")
+ payload = _normalize_job_payload(options.get("payload"))
+ if not action:
+ raise ValueError("maintenance action is missing")
+ # Record which worker process picked the job up before reporting "started".
+ options["worker_pid"] = os.getpid()
+ options["worker_id"] = worker_id
+ job.result_json = json.dumps(options, separators=(",", ":"))
+ _job_running(session, job, worker_id, "started", f"{_maintenance_description(action, payload)} started.", 1)
+
+ # The action opens its own connection/session; only its result dict comes back.
+ result = _run_maintenance_action(job_id, worker_id, action, payload)
+
+ # Re-load the job in a fresh session to persist the final result.
+ with SessionLocal() as session:
+ job = _job_for_worker(session, job_id, worker_id)
+ _complete_job(session, job, f"{_maintenance_description(action, payload)} completed.", result)
+ session.commit()
+
+
+def _run_maintenance_action(job_id: int, worker_id: str, action: str, payload: dict[str, Any]) -> dict[str, Any]:
+    """Execute one maintenance action and return ``{"action", "payload", "result"}``.
+
+    ``init-db`` and ``vacuum-db`` run without a job session. Every other action
+    runs inside one session, with a job-control checkpoint before and after the
+    work, so a pending cancel or pause request interrupts the action at those
+    points (``JobCancelled`` / ``JobPaused``). Unknown actions raise
+    ``ValueError``.
+    """
+    if action == "init-db":
+        init_db()
+        return {"action": action, "payload": payload, "result": {"status": "initialized"}}
+    if action == "vacuum-db":
+        # VACUUM must not run inside a transaction, hence the AUTOCOMMIT connection.
+        with engine.connect().execution_options(isolation_level="AUTOCOMMIT") as connection:
+            connection.execute(text("VACUUM"))
+            # wal_checkpoint is a SQLite-only PRAGMA. The project can also run on
+            # PostgreSQL, where the statement would fail with a syntax error and
+            # turn a successful VACUUM into a failed job.
+            if engine.dialect.name == "sqlite":
+                connection.execute(text("PRAGMA wal_checkpoint(TRUNCATE)"))
+        return {"action": action, "payload": payload, "result": {"status": "vacuumed"}}
+
+    with SessionLocal() as session:
+        job = _job_for_worker(session, job_id, worker_id)
+        _check_job_control(session, job)
+        if action == "sample-reset":
+            result = load_sample_project(session, preserve_job_id=job_id)
+        elif action == "reset-db":
+            clear_project_data(session, preserve_job_id=job_id, preserve_catalog=False)
+            result = {"status": "reset", "preserved_job_id": job_id}
+        elif action == "backfill-gtfs-shapes":
+            result = backfill_gtfs_shapes(session, dataset_id=_optional_int(payload.get("dataset_id")))
+        elif action == "prune-cache":
+            # Dry run is the default: files are only deleted on an explicit dry_run=False.
+            dry_run = bool(payload.get("dry_run", True))
+            if dry_run:
+                result = {"dry_run": True, **unreferenced_cache_file_summary(session)}
+            else:
+                result = {"dry_run": False, **prune_unreferenced_cache_files(session)}
+        elif action == "prune-inactive-datasets":
+            result = prune_inactive_datasets(session, dry_run=bool(payload.get("dry_run", True)))
+        elif action == "source-catalog-import":
+            imported = import_source_catalog(session, payload.get("csv_path"), update_existing=bool(payload.get("update_existing", True)))
+            result = {**imported, "summary": source_catalog_summary(session)}
+        elif action == "source-catalog-import-ingestable":
+            imported = import_ingestable_sources(session, payload.get("csv_path"), update_existing=bool(payload.get("update_existing", True)))
+            result = {**imported, "summary": source_catalog_summary(session)}
+        else:
+            raise ValueError(f"unsupported maintenance action: {action}")
+        # Re-validate the lease and honour control requests that arrived while the action ran.
+        job = _job_for_worker(session, job_id, worker_id)
+        _check_job_control(session, job)
+        session.commit()
+        return {"action": action, "payload": payload, "result": result}
+
+
+def _job_for_worker(session: Session, job_id: int, worker_id: str) -> Job:
+    """Fetch a job on behalf of the worker holding its lease and renew that lease.
+
+    Raises ``ValueError`` if the job does not exist, and ``RuntimeError`` if it
+    is leased by someone else or is not in the ``running`` state.
+    """
+    leased = session.get(Job, job_id)
+    if leased is None:
+        raise ValueError(f"job not found: {job_id}")
+    # Ownership is checked before status, so a foreign lease is reported first.
+    owner, status = leased.lease_owner, leased.status
+    if owner != worker_id:
+        raise RuntimeError(f"job #{job_id} is not leased by this worker")
+    if status != "running":
+        raise RuntimeError(f"job #{job_id} is not running")
+    _heartbeat_job(leased, worker_id)
+    return leased
+
+
+def _job_running(session: Session, job: Job, worker_id: str, event_type: str, message: str, progress_current: int) -> None:
+    """Put ``job`` into the running state at ``progress_current``, log an event, and commit.
+
+    The lease is renewed with the same timestamp that is used for a first-time
+    ``started_at``. A pending cancel/pause request is honoured before the commit.
+    """
+    stamp = datetime.now(timezone.utc)
+    job.status = "running"
+    if job.started_at is None:
+        # First transition into "running": remember when the job started.
+        job.started_at = stamp
+    _heartbeat_job(job, worker_id, now=stamp)
+    job.progress_current = progress_current
+    event = {
+        "event_type": event_type,
+        "message": message,
+        "progress_current": progress_current,
+        "progress_total": job.progress_total,
+    }
+    add_job_event(session, job, **event)
+    _check_job_control(session, job)
+    session.commit()
+
+
+def _job_progress_callback(session: Session, job: Job, worker_id: str, *, update_job_progress: bool = False):
+    """Build a progress-reporting callback bound to ``job`` and ``worker_id``.
+
+    Each call re-reads the job, renews the lease, records an event, runs the
+    job-control checkpoint, and commits. The job's own progress counters are
+    only touched when ``update_job_progress`` is true.
+    """
+
+    def _callback(
+        event_type: str,
+        message: str,
+        progress_current: int | None = None,
+        progress_total: int | None = None,
+        metadata: dict[str, Any] | None = None,
+    ) -> None:
+        live = session.get(Job, job.id)
+        if live is None:
+            # The job row vanished; there is nothing left to report against.
+            return
+        session.refresh(live)
+        still_ours = live.status == "running" and live.lease_owner == worker_id
+        if not still_ours:
+            raise RuntimeError(f"job #{job.id} is no longer running under worker {worker_id}")
+        _heartbeat_job(live, worker_id)
+        if update_job_progress:
+            if progress_current is not None:
+                live.progress_current = progress_current
+            if progress_total is not None:
+                live.progress_total = progress_total
+        event = {
+            "event_type": event_type,
+            "message": message,
+            "progress_current": progress_current,
+            "progress_total": progress_total,
+            "metadata": metadata,
+        }
+        add_job_event(session, live, **event)
+        _check_job_control(session, live)
+        session.commit()
+
+    return _callback
+
+
+def _heartbeat_job(job: Job, worker_id: str, *, now: datetime | None = None) -> None:
+    """Assign the lease to ``worker_id`` and push its expiry ``LEASE_SECONDS`` ahead.
+
+    ``now`` lets callers reuse a timestamp they already took; otherwise the
+    current UTC time is used. Nothing is flushed or committed here.
+    """
+    moment = now if now else datetime.now(timezone.utc)
+    job.updated_at = moment
+    job.lease_expires_at = moment + timedelta(seconds=LEASE_SECONDS)
+    job.lease_owner = worker_id
+
+
+def _check_job_control(session: Session, job: Job) -> None:
+    """Cooperative checkpoint: act on a pending cancel or pause request.
+
+    The request is taken from ``job.requested_action`` or, failing that, from
+    the job's control file. Raises ``JobCancelled`` for a cancel request. For a
+    pause request the job is marked paused and committed, then ``JobPaused`` is
+    raised. Returns normally when nothing is pending.
+    """
+    session.flush()
+    session.refresh(job)
+    action = job.requested_action
+    if not action:
+        action = _job_control_request_action(job)
+    if action == "cancel":
+        _clear_job_control_request(job.id)
+        raise JobCancelled()
+    if action != "pause":
+        return
+    _clear_job_control_request(job.id)
+    _mark_job_paused(session, job, "Job paused at a cooperative checkpoint.")
+    session.commit()
+    raise JobPaused()
+
+
+def _mark_job_paused(session: Session, job: Job, message: str) -> None:
+    """Move ``job`` to ``paused``, release its lease, and mirror that on its source/dataset.
+
+    A ``paused`` event carrying ``message`` is recorded. The caller commits.
+    """
+    paused_at = datetime.now(timezone.utc)
+    job.status = "paused"
+    # Drop the pending request and the lease so the job is no longer owned by a worker.
+    for released in ("requested_action", "lease_owner", "lease_expires_at"):
+        setattr(job, released, None)
+    job.paused_at = paused_at
+    job.updated_at = paused_at
+    event = {
+        "event_type": "paused",
+        "message": message,
+        "progress_current": job.progress_current,
+        "progress_total": job.progress_total,
+    }
+    add_job_event(session, job, **event)
+    linked_source = _job_source(session, job)
+    if linked_source is not None:
+        linked_source.status = "paused"
+        linked_source.last_error = None
+    linked_dataset = _job_dataset(session, job)
+    if linked_dataset is not None:
+        linked_dataset.status = "paused"
+    _clear_job_control_request(job.id)
+
+
+def _complete_job(session: Session, job: Job, message: str, result: dict[str, Any]) -> None:
+    """Mark ``job`` completed, store ``result`` as compact JSON, and log a ``completed`` event.
+
+    Progress is set to the job's total, the lease and any pending control
+    request are cleared. The caller commits.
+    """
+    job.status = "completed"
+    for released in ("requested_action", "lease_owner", "lease_expires_at", "paused_at"):
+        setattr(job, released, None)
+    job.progress_current = job.progress_total
+    job.result_json = json.dumps(result, separators=(",", ":"))
+    finished = datetime.now(timezone.utc)
+    job.updated_at = finished
+    job.finished_at = finished
+    _clear_job_control_request(job.id)
+    event = {
+        "event_type": "completed",
+        "message": message,
+        "progress_current": job.progress_current,
+        "progress_total": job.progress_total,
+        "metadata": result,
+    }
+    add_job_event(session, job, **event)
+
+
+def _finish_job_cancelled(session: Session, job: Job) -> None:
+    """Mark ``job`` cancelled and restore the status of its source and dataset.
+
+    The source gets the status computed by ``_source_status_without_active_job``,
+    except that ``error`` without a recorded ``last_error`` becomes ``new``. The
+    dataset gets ``dataset_status`` from the job options, defaulting to
+    ``imported``. The caller commits.
+    """
+    stopped_at = datetime.now(timezone.utc)
+    job.status = "cancelled"
+    for released in ("requested_action", "lease_owner", "lease_expires_at", "paused_at"):
+        setattr(job, released, None)
+    job.updated_at = stopped_at
+    job.finished_at = stopped_at
+    job.error = None
+    _clear_job_control_request(job.id)
+    event = {
+        "event_type": "cancelled",
+        "message": "Job stopped.",
+        "progress_current": job.progress_current,
+        "progress_total": job.progress_total,
+    }
+    add_job_event(session, job, **event)
+    linked_source = _job_source(session, job)
+    if linked_source is not None:
+        restored = _source_status_without_active_job(session, linked_source)
+        if restored == "error" and linked_source.last_error is None:
+            # An "error" status with no recorded error is reported as "new" instead.
+            restored = "new"
+        linked_source.status = restored
+    linked_dataset = _job_dataset(session, job)
+    if linked_dataset is not None:
+        job_options = _json_object(job.result_json)
+        linked_dataset.status = str(job_options.get("dataset_status") or "imported")
+
+
+def _mark_job_cancelled(job_id: int) -> None:
+    """Persist the cancelled state for ``job_id`` under the database write lock.
+
+    Does nothing when the job row no longer exists.
+    """
+    with database_write_lock(f"job:{job_id}:cancelled", timeout=10):
+        with SessionLocal() as session:
+            target = session.get(Job, job_id)
+            if target is not None:
+                _finish_job_cancelled(session, target)
+                session.commit()
+
+
+def _get_job_or_raise(session: Session, job_id: int) -> Job:
+    """Return the job with ``job_id``; raise ``ValueError`` when it does not exist."""
+    found = session.get(Job, job_id)
+    if found is not None:
+        return found
+    raise ValueError(f"job not found: {job_id}")
+
+
+def _mark_job_failed(job_id: int, exc: Exception) -> None:
+    """Best-effort: record ``exc`` as the failure of ``job_id`` and flag its source/dataset.
+
+    Nothing is written when the job is missing, or when it is already in a
+    terminal status with no lease owner. If the write lock cannot be obtained
+    (``DatabaseWriteBusy``) the failure is silently left unrecorded.
+    """
+    reason = str(exc)
+    try:
+        # NOTE(review): the lock label uses ROUTE_LAYER_JOB_KIND for every job kind.
+        with database_write_lock(f"job:{ROUTE_LAYER_JOB_KIND}:{job_id}:failed", timeout=10):
+            with SessionLocal() as session:
+                failed = session.get(Job, job_id)
+                if failed is None:
+                    return
+                already_final = failed.status in TERMINAL_JOB_STATUSES and failed.lease_owner is None
+                if already_final:
+                    return
+                failed.status = "failed"
+                for released in ("requested_action", "lease_owner", "lease_expires_at", "paused_at"):
+                    setattr(failed, released, None)
+                failed.error = reason
+                finished = datetime.now(timezone.utc)
+                failed.updated_at = finished
+                failed.finished_at = finished
+                linked_source = _job_source(session, failed)
+                if linked_source is not None:
+                    linked_source.status = "error"
+                    linked_source.last_error = reason
+                    linked_source.last_run_at = failed.finished_at
+                linked_dataset = _job_dataset(session, failed)
+                if linked_dataset is not None:
+                    linked_dataset.status = "error"
+                event = {
+                    "level": "error",
+                    "event_type": "failed",
+                    "message": reason,
+                    "progress_current": failed.progress_current,
+                    "progress_total": failed.progress_total,
+                    "metadata": {"exception_type": exc.__class__.__name__},
+                }
+                add_job_event(session, failed, **event)
+                session.commit()
+    except DatabaseWriteBusy:
+        # Deliberately swallowed: failing to record a failure must not raise again.
+        pass
+
+
+def _job_source(session: Session, job: Job) -> Source | None:
+    """Return the ``Source`` referenced by ``job``, or ``None`` when it has none."""
+    referenced_id = source_id_from_job(job)
+    return None if referenced_id is None else session.get(Source, referenced_id)
+
+
+def _job_dataset(session: Session, job: Job) -> Dataset | None:
+    """Return the ``Dataset`` referenced by ``job``, or ``None`` when it has none."""
+    referenced_id = dataset_id_from_job(job)
+    return None if referenced_id is None else session.get(Dataset, referenced_id)
+
+
+def _job_control_dir():
+    """Return the ``job-control`` directory below the configured data directory."""
+    data_root = settings.data_dir
+    return data_root / "job-control"
+
+
+def _job_control_path(job_id: int):
+    """Return the control-file path for ``job_id`` (``job-<id>.json`` in the control directory)."""
+    control_dir = _job_control_dir()
+    return control_dir / f"job-{int(job_id)}.json"
+
+
+def _write_job_control_request(job_id: int, payload: dict[str, Any]) -> None:
+    """Atomically write ``payload`` as the control request file for ``job_id``.
+
+    The JSON is written to a temporary file named after the current process and
+    thread, then moved over the final path with ``os.replace`` so readers never
+    see a partially written file. If writing or replacing fails, the temporary
+    file is removed before the error propagates.
+    """
+    directory = _job_control_dir()
+    directory.mkdir(parents=True, exist_ok=True)
+    path = _job_control_path(job_id)
+    tmp_path = directory / f".job-{int(job_id)}-{os.getpid()}-{threading.get_ident()}.tmp"
+    try:
+        tmp_path.write_text(json.dumps(payload, separators=(",", ":")), encoding="utf-8")
+        os.replace(tmp_path, path)
+    except BaseException:
+        # Do not leave stray ".job-*.tmp" files behind on a failed write/replace.
+        try:
+            tmp_path.unlink()
+        except OSError:
+            pass
+        raise
+
+
+def _read_job_control_request(job_id: int) -> dict[str, Any]:
+    """Read the control request file for ``job_id``; return ``{}`` when unusable.
+
+    A missing or unreadable file, invalid UTF-8, invalid JSON, or a JSON value
+    that is not an object all yield an empty dict.
+    """
+    try:
+        data = json.loads(_job_control_path(job_id).read_text(encoding="utf-8"))
+    except (OSError, ValueError):
+        # ValueError covers json.JSONDecodeError and also UnicodeDecodeError,
+        # which read_text raises for bytes that are not valid UTF-8.
+        return {}
+    return data if isinstance(data, dict) else {}
+
+
+def _clear_job_control_request(job_id: int) -> None:
+    """Delete the control request file for ``job_id``, ignoring any filesystem error."""
+    control_file = _job_control_path(job_id)
+    try:
+        control_file.unlink()
+    except OSError:
+        # Best effort. FileNotFoundError is a subclass of OSError, so a file
+        # that is already gone is covered here as well.
+        pass
+
+
+def _job_control_request_action(job: Job) -> str | None:
+    """Return ``"pause"`` or ``"cancel"`` if the job's control file requests it, else ``None``.
+
+    A request stamped earlier than the job's ``created_at`` is treated as stale:
+    the control file is removed and ``None`` is returned.
+    """
+    payload = _read_job_control_request(job.id)
+    action = payload.get("requested_action")
+    # The value comes from a JSON file and may be any JSON type. Checking for
+    # str first avoids the TypeError that set membership raises for unhashable
+    # values such as lists or dicts.
+    if not isinstance(action, str) or action not in {"pause", "cancel"}:
+        return None
+    requested_at = _datetime_from_iso(payload.get("requested_at"))
+    if requested_at is not None and job.created_at is not None and requested_at < _as_utc(job.created_at):
+        _clear_job_control_request(job.id)
+        return None
+    return str(action)
+
+
+def _json_object(value: str | None) -> dict[str, Any]:
+    """Parse ``value`` as JSON and return it only if it is an object; otherwise ``{}``.
+
+    Empty or ``None`` input and invalid JSON also give ``{}``.
+    """
+    parsed: object = None
+    if value:
+        try:
+            parsed = json.loads(value)
+        except json.JSONDecodeError:
+            parsed = None
+    if isinstance(parsed, dict):
+        return parsed
+    return {}
+
+
+def _normalize_job_payload(value: object) -> dict[str, Any]:
+    """Return a cleaned copy of a job payload with string keys in sorted order.
+
+    Non-dict input gives ``{}``. ``None`` values are dropped, nested dicts are
+    normalized recursively, lists and ``str``/``int``/``float``/``bool`` values
+    are kept as they are, and anything else is converted with ``str()``.
+    """
+    if not isinstance(value, dict):
+        return {}
+    normalized: dict[str, Any] = {}
+    # Sort by the string form of the key: the output keys are strings anyway,
+    # and a plain sorted() raises TypeError when key types are mixed.
+    for key in sorted(value, key=str):
+        item = value[key]
+        if item is None:
+            continue
+        if isinstance(item, dict):
+            normalized[str(key)] = _normalize_job_payload(item)
+        elif isinstance(item, (list, str, int, float, bool)):
+            # Lists are passed through untouched (their items are not normalized).
+            normalized[str(key)] = item
+        else:
+            normalized[str(key)] = str(item)
+    return normalized
+
+
+def _maintenance_description(action: str, payload: dict[str, Any] | None = None) -> str:
+    """Return the human-readable label for a maintenance action.
+
+    The two prune actions read "Check ..." unless the payload sets ``dry_run``
+    to a falsy value, and the GTFS backfill label names the dataset when the
+    payload carries a ``dataset_id``. Unknown actions get a generic label.
+    """
+    normalized = _normalize_job_payload(payload)
+    fixed_labels = {
+        "init-db": "Initialize database schema",
+        "sample-reset": "Reset sample data",
+        "reset-db": "Reset database contents",
+        "vacuum-db": "Vacuum database",
+        "source-catalog-import": "Import source catalog",
+        "source-catalog-import-ingestable": "Import ingestable source seeds",
+    }
+    if action in fixed_labels:
+        return fixed_labels[action]
+    if action == "backfill-gtfs-shapes":
+        dataset_id = normalized.get("dataset_id")
+        if dataset_id is None:
+            return "Backfill GTFS shapes"
+        return f"Backfill GTFS shapes for dataset {dataset_id}"
+    # Dry run is assumed when the payload does not say otherwise.
+    checking_only = normalized.get("dry_run", True)
+    if action == "prune-cache":
+        return "Check unreferenced cache files" if checking_only else "Prune unreferenced cache files"
+    if action == "prune-inactive-datasets":
+        return "Check inactive datasets" if checking_only else "Prune inactive datasets"
+    return f"Maintenance action {action}"
+
+
+def _maintenance_progress_total(action: str) -> int:
+    """Return the number of progress steps for ``action``: 4 for the two resets, else 1."""
+    return 4 if action in {"sample-reset", "reset-db"} else 1
+
+
+def _optional_int(value: object) -> int | None:
+    """Convert ``value`` with ``int()``; return ``None`` for ``None`` or anything unconvertible.
+
+    Unconvertible includes non-numeric strings, unsupported types, NaN, and
+    infinite floats (``int()`` raises ``OverflowError`` for the latter, which
+    ``json.loads`` can produce from ``Infinity``).
+    """
+    if value is None:
+        return None
+    try:
+        return int(value)
+    except (TypeError, ValueError, OverflowError):
+        return None
+
+
+def _datetime_from_iso(value: object) -> datetime | None:
+    """Parse an ISO-8601 string into a UTC datetime; return ``None`` if that is not possible.
+
+    Non-string or empty input and strings rejected by
+    ``datetime.fromisoformat`` give ``None``.
+    """
+    if not (isinstance(value, str) and value):
+        return None
+    try:
+        moment = datetime.fromisoformat(value)
+    except ValueError:
+        moment = None
+    return None if moment is None else _as_utc(moment)
+
+
+def _iso(value: datetime | None) -> str | None:
+    """Return ``value.isoformat()``, or ``None`` when ``value`` is falsy."""
+    if value:
+        return value.isoformat()
+    return None
+
+
+def _revision_datetime(value: datetime | None) -> str:
+    """Return ``value.isoformat()``, or an empty string when ``value`` is ``None``."""
+    if value is None:
+        return ""
+    return value.isoformat()
+
+
+def _as_utc(value: datetime | None) -> datetime | None:
+    """Return ``value`` as a timezone-aware UTC datetime (``None`` stays ``None``).
+
+    Naive datetimes are taken to be UTC already and only get the tzinfo
+    attached; aware datetimes are converted.
+    """
+    if value is None:
+        return None
+    is_naive = value.tzinfo is None
+    return value.replace(tzinfo=timezone.utc) if is_naive else value.astimezone(timezone.utc)
diff --git a/app/journey.py b/app/journey.py
new file mode 100644
index 0000000..446fbe7
--- /dev/null
+++ b/app/journey.py
@@ -0,0 +1,5385 @@
+from __future__ import annotations
+
+import json
+import math
+import re
+import threading
+import time
+from dataclasses import dataclass
+from datetime import date, datetime
+from typing import Iterator, Optional
+
+from shapely.geometry import LineString, MultiLineString, Point, mapping, shape
+from shapely.ops import linemerge, substring
+from sqlalchemy import and_, bindparam, case, exists, func, or_, select, text
+from sqlalchemy.orm import Session, aliased
+
+from app.address_search import (
+ address_by_token,
+ address_point_by_token,
+ address_point_token,
+ address_token,
+ coordinate_token,
+ is_coordinate_token,
+ is_address_point_token,
+ is_address_token,
+ is_location_token,
+ parse_coordinate_token,
+)
+from app.config import settings
+from app.gtfs_storage import (
+ GTFS_STOP_TIME_COLUMNS,
+ SQLITE_IN_CHUNK_SIZE,
+ all_scheduled_stop_ids,
+ execute_sidecar_query,
+ has_scheduled_stop as storage_has_scheduled_stop,
+ scheduled_stop_ids as storage_scheduled_stop_ids,
+ stop_times_by_trip as storage_stop_times_by_trip,
+ stop_times_for_trip_range,
+ uses_sidecar_stop_times,
+)
+from app.models import (
+ CanonicalStop,
+ CanonicalStopLink,
+ Dataset,
+ GtfsCalendar,
+ GtfsCalendarDate,
+ GtfsRoute,
+ GtfsShape,
+ GtfsStop,
+ GtfsStopTime,
+ GtfsTrip,
+ OsmAddress,
+ OsmFeature,
+ RoutePattern,
+ Source,
+)
+from app.osm_storage import query_osm_features
+from app.pipeline.route_layer import (
+ canonical_stop_for_gtfs_stop,
+ logical_stop_group_id,
+ route_pattern_for_trip,
+)
+from app.routing import route_between_points, snap_point_to_routing_graph
+from app.serializers import feature_collection
+
+
+# Module-level limits and thresholds. Suffixes give the unit where one applies:
+# _M metres, _DEG degrees, _SECONDS seconds, _MPS metres per second.
+MAX_DIRECT_ROWS = 12000
+MAX_TRANSFER_BOARDINGS = 350
+MAX_TARGET_DESTINATION_ARRIVALS = 1400
+MAX_TARGET_SECOND_LEGS_PER_STOP = 48
+MAX_TARGET_TRANSFER_CANDIDATES = 4500
+MAX_BACKWARD_SECOND_LEG_OPTIONS = 160
+OSM_STOP_MATCH_RADIUS_DEG = 0.0012
+LEG_GEOMETRY_MAX_STOP_DISTANCE_DEG = 0.08
+MAX_STOP_SEARCH_ROWS = 700
+MAX_GROUP_STOP_IDS = 120
+MAX_ROUTER_BOARDING_CANDIDATES = 2200
+MAX_ROUTER_TRANSIT_LEGS = 6
+MAX_JOURNEY_DATASET_PAIRS = 40
+WALKING_TRANSFER_RADIUS_M = 450
+# Metres converted to degrees using 111,320 m per degree.
+WALKING_TRANSFER_RADIUS_DEG = WALKING_TRANSFER_RADIUS_M / 111_320
+WALKING_TRANSFER_SPEED_MPS = 1.25
+MAX_WALKING_TRANSFER_SOURCE_STOPS = 80
+MAX_WALKING_TRANSFER_NEIGHBORS_PER_STOP = 8
+# Durations below are written as minutes * 60, i.e. they are in seconds.
+ACCESS_TRANSFER_MAX_SECONDS = 45 * 60
+MAX_ACCESS_TRANSFER_CANDIDATES = 4
+PUBLIC_TRANSPORT_WALK_OPTION_MAX_SECONDS = 45 * 60
+ADDRESS_ACCESS_RADIUS_M = 1800
+ADDRESS_ACCESS_MAX_SECONDS = 30 * 60
+ADDRESS_ACCESS_STOP_CANDIDATES = 4
+ADDRESS_ACCESS_MAX_PAIR_CANDIDATES = 8
+ADDRESS_ACCESS_MAX_DEEP_PAIR_CANDIDATES = 4
+ADDRESS_ACCESS_SHORT_DIRECT_WALK_SECONDS = 20 * 60
+ADDRESS_ACCESS_LONG_DISTANCE_HUB_THRESHOLD_M = 50_000
+ADDRESS_ACCESS_MAJOR_HUB_RADIUS_M = 12_000
+ADDRESS_ACCESS_MAJOR_HUB_CANDIDATES = 3
+# NORMAL is also the default priority of _AccessStopCandidate.
+# NOTE(review): presumably a lower number means a more preferred candidate - confirm at the use site.
+ADDRESS_ACCESS_NORMAL_PRIORITY = 100
+ADDRESS_ACCESS_MAJOR_HUB_PRIORITY = 10
+WALK_GEOMETRY_CACHE_TTL_SECONDS = 10 * 60
+WALK_GEOMETRY_CACHE_MAX_ENTRIES = 1024
+LEG_GEOMETRY_CACHE_TTL_SECONDS = 10 * 60
+LEG_GEOMETRY_CACHE_MAX_ENTRIES = 2048
+# Prefixes of the string tokens that identify a stop selection.
+STOP_GROUP_PREFIX = "group:"
+STOP_EXACT_PREFIX = "stop:"
+STOP_PLACE_PREFIX = "place:"
+# Process-wide in-memory caches, each declared next to a re-entrant lock.
+# NOTE(review): the leading float of each cached value is presumably the insertion time used for the TTL - confirm.
+_walk_geometry_cache_lock = threading.RLock()
+_walk_geometry_cache: dict[tuple[float, float, float, float], tuple[float, tuple[dict | None, float, float | None]]] = {}
+_leg_geometry_cache_lock = threading.RLock()
+_leg_geometry_cache: dict[tuple[object, ...], tuple[float, dict | None, str, int | None]] = {}
+
+
+@dataclass(frozen=True)
+class StopSummary:
+ """Immutable summary of a single stop: identifiers, name, and coordinates."""
+ # NOTE(review): presumably the database primary key of the stop row - confirm.
+ id: int
+ dataset_id: int
+ # String stop identifier, distinct from the integer ``id`` above.
+ stop_id: str
+ name: str | None
+ # Coordinates are optional; either may be None.
+ lat: float | None
+ lon: float | None
+
+
+@dataclass(frozen=True)
+class StopSelection:
+ """A selected stop place: one display stop plus the stop ids it covers in each dataset."""
+ # Stop used for presentation (name, coordinates).
+ display: StopSummary
+ # Maps dataset id -> stop ids belonging to this selection in that dataset.
+ stop_ids_by_dataset: dict[int, tuple[str, ...]]
+ canonical_stop_id: int | None = None
+
+ @property
+ def dataset_id(self) -> int:
+ # First dataset id in mapping (insertion) order.
+ # NOTE(review): raises StopIteration when stop_ids_by_dataset is empty.
+ return next(iter(self.stop_ids_by_dataset))
+
+ @property
+ def stop_ids(self) -> tuple[str, ...]:
+ # Stop ids of the first dataset only; see dataset_id.
+ return self.stop_ids_by_dataset[self.dataset_id]
+
+ @property
+ def dataset_ids(self) -> tuple[int, ...]:
+ # All dataset ids covered by the selection, in mapping order.
+ return tuple(self.stop_ids_by_dataset)
+
+
+@dataclass(frozen=True)
+class _AccessStopCandidate:
+ """Immutable record pairing a stop selection with its token, a distance, and a priority."""
+ token: str
+ selection: StopSelection
+ # Distance in metres (per the _m suffix).
+ # NOTE(review): presumably measured from the address/coordinate being accessed - confirm.
+ distance_m: float
+ # Defaults to the normal priority; ADDRESS_ACCESS_MAJOR_HUB_PRIORITY is the other defined value.
+ priority: int = ADDRESS_ACCESS_NORMAL_PRIORITY
+
+
+def search_scheduled_stops(
+ db: Session,
+ query: str | None = None,
+ source_ids: list[int] | None = None,
+ limit: int = 25,
+ bbox: tuple[float, float, float, float] | None = None,
+) -> list[dict]:
+ """Return stops that have imported stop_times.
+
+ The importer may intentionally cap stop_times for large feeds. Searching only
+ scheduled stops prevents the UI from offering stops that cannot route yet.
+
+ ``query`` is matched case-insensitively against stop names and stop ids.
+ ``bbox`` only influences ordering in this function; it is not applied as a
+ filter here. At most ``limit`` results (clamped to 1..100) are returned, one
+ per logical stop group, or one per canonical stop after merging.
+ """
+ active_dataset_ids = _active_gtfs_dataset_ids(db, source_ids=source_ids)
+ if not active_dataset_ids:
+ return []
+
+ # Base query: every stop of the active GTFS datasets, joined to its source.
+ stmt = (
+ select(GtfsStop, Source.id, Source.name)
+ .join(Dataset, Dataset.id == GtfsStop.dataset_id)
+ .join(Source, Source.id == Dataset.source_id)
+ .where(GtfsStop.dataset_id.in_(active_dataset_ids))
+ )
+ q = (query or "").strip()
+ if q:
+ # Accept the whole phrase as a substring, or all of its tokens
+ # (split on whitespace, comma, semicolon, slash) in any order.
+ pattern = f"%{q}%"
+ tokens = [token for token in re.split(r"[\s,;/]+", q) if token]
+ token_filters = [
+ or_(GtfsStop.name.ilike(f"%{token}%"), GtfsStop.stop_id.ilike(f"%{token}%"))
+ for token in tokens
+ ]
+ where_parts = [GtfsStop.name.ilike(pattern), GtfsStop.stop_id.ilike(pattern)]
+ if token_filters:
+ where_parts.append(and_(*token_filters))
+ stmt = stmt.where(or_(*where_parts))
+ # SQL-side rank: whole-name match, name prefix, name substring, stop_id prefix, anything else.
+ rank = case(
+ (GtfsStop.name.ilike(q), 0),
+ (GtfsStop.name.ilike(f"{q}%"), 1),
+ (GtfsStop.name.ilike(pattern), 2),
+ (GtfsStop.stop_id.ilike(f"{q}%"), 3),
+ else_=4,
+ )
+ if bbox is not None:
+ stmt = stmt.order_by(rank, *_bbox_order_expressions(GtfsStop, bbox), GtfsStop.name, GtfsStop.id)
+ else:
+ stmt = stmt.order_by(rank, GtfsStop.name, GtfsStop.id)
+ else:
+ if bbox is not None:
+ stmt = stmt.order_by(*_bbox_order_expressions(GtfsStop, bbox), GtfsStop.name, GtfsStop.id)
+ else:
+ stmt = stmt.order_by(GtfsStop.name, GtfsStop.id)
+ # Cap the rows fetched; the cap is tripled when a bbox is given.
+ stmt = stmt.limit(MAX_STOP_SEARCH_ROWS * (3 if bbox is not None else 1))
+
+ # Collapse matched rows into logical stop groups per dataset, keeping the lowest rank seen.
+ groups: dict[tuple[int, str], dict] = {}
+ for stop, source_id, source_name in db.execute(stmt).all():
+ group_id = logical_stop_group_id(stop)
+ key = (stop.dataset_id, group_id)
+ rank_value = _stop_match_rank(stop, q)
+ group = groups.setdefault(
+ key,
+ {
+ "dataset_id": stop.dataset_id,
+ "group_id": group_id,
+ "source_id": source_id,
+ "source_name": source_name,
+ "rank": rank_value,
+ "matches": [],
+ },
+ )
+ group["rank"] = min(int(group["rank"]), rank_value)
+ group["matches"].append(stop)
+
+ if not groups:
+ return []
+
+ parents = _parent_stops_for_groups(db, groups.keys())
+ scheduled = _scheduled_stops_for_groups(db, groups.keys())
+ results = []
+ for key, group in groups.items():
+ scheduled_stops = scheduled.get(key, [])
+ # Groups without any scheduled stop are not offered at all.
+ if not scheduled_stops:
+ continue
+ parent = parents.get(key)
+ display_stop = parent or _best_display_stop(str(group["group_id"]), group["matches"], scheduled_stops)
+ canonical = _canonical_stop_for_group(db, scheduled_stops)
+ name = canonical.name if canonical is not None else display_stop.name
+ display_parts = _city_stop_display_parts(
+ name,
+ display_stop.name,
+ parent.name if parent is not None else None,
+ *(stop.name for stop in scheduled_stops),
+ )
+ display_name = display_parts["display_name"]
+ # Use a canonical place token when the group resolves to a canonical stop, else a group token.
+ result_id = (
+ _stop_place_token(canonical.id, display_stop.dataset_id)
+ if canonical is not None
+ else _stop_group_token(display_stop.dataset_id, str(group["group_id"]))
+ )
+ display_rank = _stop_match_rank(display_stop, q)
+ result_lat = canonical.lat if canonical is not None else display_stop.lat
+ result_lon = canonical.lon if canonical is not None else display_stop.lon
+ bbox_rank, bbox_distance_m = _bbox_rank(result_lat, result_lon, bbox)
+ # Keys starting with "_" are internal sort keys; they are removed before returning.
+ results.append(
+ {
+ "id": result_id,
+ "canonical_stop_id": None if canonical is None else canonical.id,
+ "dataset_id": display_stop.dataset_id,
+ "stop_id": str(group["group_id"]),
+ "name": name,
+ "display_name": display_name,
+ "city": display_parts["city"],
+ "local_name": display_parts["local_name"],
+ "lat": result_lat,
+ "lon": result_lon,
+ "source_id": group["source_id"],
+ "source_name": group["source_name"],
+ "scheduled": True,
+ "grouped": True,
+ "grouped_stop_count": len(scheduled_stops),
+ "sample_stop_ids": [stop.stop_id for stop in scheduled_stops[:5]],
+ "_display_rank": display_rank,
+ "_match_rank": group["rank"],
+ "_bbox_rank": bbox_rank,
+ "_bbox_distance_m": bbox_distance_m,
+ "_importance_rank": _station_importance_rank(
+ display_name,
+ name,
+ display_stop.name,
+ parent.name if parent is not None else None,
+ *(stop.name for stop in scheduled_stops),
+ ),
+ }
+ )
+
+ results.sort(
+ key=lambda item: (
+ item["_bbox_rank"],
+ item["_importance_rank"],
+ item["_display_rank"],
+ item["_match_rank"],
+ item["_bbox_distance_m"],
+ -(int(item["grouped_stop_count"])),
+ item["name"] or "",
+ item["stop_id"],
+ )
+ )
+ # Merge duplicates of the same canonical stop unless exactly one source was requested.
+ if not source_ids or len(source_ids) > 1:
+ results = _merge_canonical_stop_results(results)
+ _enrich_canonical_stop_sources(db, results, active_dataset_ids)
+ # Sort again with the same key: merging and enrichment can change the ranking fields.
+ results.sort(
+ key=lambda item: (
+ item["_bbox_rank"],
+ item["_importance_rank"],
+ item["_display_rank"],
+ item["_match_rank"],
+ item["_bbox_distance_m"],
+ -(int(item["grouped_stop_count"])),
+ item["name"] or "",
+ item["stop_id"],
+ )
+ )
+ # Clamp the page size to 1..100, then strip the internal sort keys.
+ selected = results[: max(1, min(limit, 100))]
+ for item in selected:
+ item.pop("_display_rank", None)
+ item.pop("_match_rank", None)
+ item.pop("_bbox_rank", None)
+ item.pop("_bbox_distance_m", None)
+ item.pop("_importance_rank", None)
+ return selected
+
+
+def nearest_scheduled_stops(
+ db: Session,
+ *,
+ lat: float,
+ lon: float,
+ source_ids: list[int] | None = None,
+ limit: int = 3,
+ radius_m: float = 900,
+) -> list[dict]:
+ """Return up to ``limit`` (clamped to 1..25) scheduled stops within ``radius_m`` of a point.
+
+ Results are ordered by ``distance_m`` ascending and carry ``kind: "stop"``.
+ On PostgreSQL the candidates come from two spatial queries (canonical stop
+ geometry and linked OSM stop features). Otherwise a bbox-ordered stop search
+ is filtered by the distance computed with ``_distance_m``.
+ """
+ active_dataset_ids = _active_gtfs_dataset_ids(db, source_ids=source_ids)
+ if not active_dataset_ids:
+ return []
+ selected_limit = max(1, min(int(limit), 25))
+ if settings.is_postgresql_database:
+ # Over-fetch (8x) from both queries; duplicates per canonical stop are dropped below.
+ rows = _nearest_canonical_stop_rows_postgresql(
+ db,
+ lat=lat,
+ lon=lon,
+ dataset_ids=active_dataset_ids,
+ limit=selected_limit * 8,
+ radius_m=radius_m,
+ )
+ rows.extend(
+ _nearest_visual_stop_rows_postgresql(
+ db,
+ lat=lat,
+ lon=lon,
+ dataset_ids=active_dataset_ids,
+ limit=selected_limit * 8,
+ radius_m=radius_m,
+ )
+ )
+ # Nearest first; ties broken by canonical stop id, then dataset id.
+ rows.sort(key=lambda item: (float(item[2] or 0), int(item[0]), int(item[1])))
+ results: list[dict] = []
+ seen: set[int] = set()
+ for canonical_stop_id, preferred_dataset_id, distance_m in rows:
+ if int(canonical_stop_id) in seen:
+ continue
+ try:
+ selection = _selection_for_canonical_stop(
+ db,
+ int(canonical_stop_id),
+ dataset_ids=active_dataset_ids,
+ preferred_dataset_id=int(preferred_dataset_id),
+ )
+ except ValueError:
+ # Not marked as seen, so a later row for the same stop (other dataset) is still tried.
+ continue
+ seen.add(int(canonical_stop_id))
+ source = _source_payload_for_dataset_id(db, selection.dataset_id) or {}
+ payload = _stop_payload(selection.display)
+ payload.update(
+ {
+ "id": _stop_place_token(int(canonical_stop_id), selection.dataset_id),
+ "kind": "stop",
+ "canonical_stop_id": int(canonical_stop_id),
+ "display_name": selection.display.name,
+ "source_id": source.get("id"),
+ "source_name": source.get("name"),
+ "scheduled": True,
+ "grouped": True,
+ "grouped_stop_count": sum(len(stop_ids) for stop_ids in selection.stop_ids_by_dataset.values()),
+ "distance_m": float(distance_m or 0),
+ }
+ )
+ results.append(payload)
+ if len(results) >= selected_limit:
+ break
+ return results
+
+ # Non-PostgreSQL fallback: bbox-ordered search, then exact distance filtering.
+ # NOTE(review): the same degree offset is used for longitude and latitude, so the box is
+ # narrower east-west away from the equator; search_scheduled_stops only orders by it.
+ radius_deg = float(radius_m) / 111_320
+ bbox = (float(lon) - radius_deg, float(lat) - radius_deg, float(lon) + radius_deg, float(lat) + radius_deg)
+ candidates = search_scheduled_stops(db, source_ids=source_ids, bbox=bbox, limit=selected_limit * 4)
+ for item in candidates:
+ if item.get("lat") is None or item.get("lon") is None:
+ # No coordinates: infinite distance, so the radius filter below drops the item.
+ item["distance_m"] = float("inf")
+ else:
+ item["distance_m"] = _distance_m(float(lat), float(lon), float(item["lat"]), float(item["lon"]))
+ item["kind"] = "stop"
+ candidates = [item for item in candidates if float(item.get("distance_m") or 0) <= radius_m]
+ candidates.sort(key=lambda item: (float(item.get("distance_m") or 0), item.get("display_name") or item.get("name") or ""))
+ return candidates[:selected_limit]
+
+
+def _nearest_canonical_stop_rows_postgresql(
+    db: Session,
+    *,
+    lat: float,
+    lon: float,
+    dataset_ids: list[int],
+    limit: int,
+    radius_m: float,
+) -> list[tuple[int, int, float]]:
+    """Return ``(canonical_stop_id, dataset_id, distance_m)`` rows near a point (PostGIS).
+
+    Only canonical stops linked to a GTFS stop of an active GTFS dataset in
+    ``dataset_ids`` and lying within ``radius_m`` metres are returned, nearest
+    first, at most ``limit`` rows (minimum 1).
+    """
+    # The bounding-box prefilter (``&& ST_Expand``) works in degrees. A degree of
+    # longitude covers cos(latitude) times the metres of a degree of latitude, so
+    # the box must be widened by 1/cos(latitude). Otherwise stops east or west of
+    # the point but still within radius_m are dropped before ST_DWithin sees them.
+    # Over-expanding is harmless because ST_DWithin does the exact check. The
+    # cosine is clamped so the box stays finite near the poles.
+    lon_scale = max(math.cos(math.radians(float(lat))), 0.01)
+    radius_deg = float(radius_m) / (111_320 * lon_scale)
+    # The SQL text is kept exactly as it was.
+    stmt = text(
+        """
+ WITH point AS (
+ SELECT ST_SetSRID(ST_MakePoint(:lon, :lat), 4326) AS geom
+ )
+ SELECT
+ canonical_stops.id AS canonical_stop_id,
+ canonical_stop_links.dataset_id AS dataset_id,
+ ST_DistanceSphere(canonical_stops.geom, point.geom) AS distance_m
+ FROM canonical_stops
+ JOIN canonical_stop_links
+ ON canonical_stop_links.canonical_stop_id = canonical_stops.id
+ AND canonical_stop_links.object_type = 'gtfs_stop'
+ JOIN datasets
+ ON datasets.id = canonical_stop_links.dataset_id
+ AND datasets.kind = 'gtfs'
+ AND datasets.is_active IS TRUE
+ CROSS JOIN point
+ WHERE canonical_stop_links.dataset_id IN :dataset_ids
+ AND canonical_stops.geom IS NOT NULL
+ AND canonical_stops.geom && ST_Expand(point.geom, :radius_deg)
+ AND ST_DWithin(canonical_stops.geom::geography, point.geom::geography, :radius_m)
+ GROUP BY canonical_stops.id, canonical_stop_links.dataset_id, canonical_stops.geom, point.geom
+ ORDER BY canonical_stops.geom <-> point.geom, canonical_stops.id
+ LIMIT :limit
+ """
+    ).bindparams(bindparam("dataset_ids", expanding=True))
+    rows = db.execute(
+        stmt,
+        {
+            "lat": float(lat),
+            "lon": float(lon),
+            "dataset_ids": tuple(dataset_ids),
+            "radius_deg": radius_deg,
+            "radius_m": float(radius_m),
+            "limit": max(1, int(limit)),
+        },
+    ).all()
+    return [(int(row[0]), int(row[1]), float(row[2] or 0)) for row in rows]
+
+
+def _nearest_visual_stop_rows_postgresql(
+    db: Session,
+    *,
+    lat: float,
+    lon: float,
+    dataset_ids: list[int],
+    limit: int,
+    radius_m: float,
+) -> list[tuple[int, int, float]]:
+    """Return ``(canonical_stop_id, dataset_id, distance_m)`` rows via nearby OSM stop features.
+
+    An OSM feature of kind stop/station/terminal within ``radius_m`` metres
+    counts for every canonical stop it is linked to, provided that canonical
+    stop also has a GTFS stop link in an active GTFS dataset from
+    ``dataset_ids``. The smallest distance per (canonical stop, dataset) is
+    reported, nearest first, at most ``limit`` rows (minimum 1).
+    """
+    # The bounding-box prefilter (``&& ST_Expand``) works in degrees. A degree of
+    # longitude covers cos(latitude) times the metres of a degree of latitude, so
+    # the box must be widened by 1/cos(latitude). Otherwise features east or west
+    # of the point but still within radius_m are dropped before ST_DWithin sees
+    # them. Over-expanding is harmless because ST_DWithin does the exact check.
+    # The cosine is clamped so the box stays finite near the poles.
+    lon_scale = max(math.cos(math.radians(float(lat))), 0.01)
+    radius_deg = float(radius_m) / (111_320 * lon_scale)
+    # The SQL text is kept exactly as it was.
+    stmt = text(
+        """
+ WITH point AS (
+ SELECT ST_SetSRID(ST_MakePoint(:lon, :lat), 4326) AS geom
+ ),
+ visual_hits AS (
+ SELECT
+ osm_link.canonical_stop_id AS canonical_stop_id,
+ gtfs_link.dataset_id AS dataset_id,
+ ST_DistanceSphere(osm_features.geom, point.geom) AS distance_m
+ FROM osm_features
+ JOIN canonical_stop_links AS osm_link
+ ON osm_link.object_type = 'osm_feature'
+ AND osm_link.object_id = osm_features.id
+ JOIN canonical_stop_links AS gtfs_link
+ ON gtfs_link.canonical_stop_id = osm_link.canonical_stop_id
+ AND gtfs_link.object_type = 'gtfs_stop'
+ JOIN datasets
+ ON datasets.id = gtfs_link.dataset_id
+ AND datasets.kind = 'gtfs'
+ AND datasets.is_active IS TRUE
+ CROSS JOIN point
+ WHERE gtfs_link.dataset_id IN :dataset_ids
+ AND osm_features.kind IN ('stop', 'station', 'terminal')
+ AND osm_features.geom IS NOT NULL
+ AND osm_features.geom && ST_Expand(point.geom, :radius_deg)
+ AND ST_DWithin(osm_features.geom::geography, point.geom::geography, :radius_m)
+ )
+ SELECT canonical_stop_id, dataset_id, MIN(distance_m) AS distance_m
+ FROM visual_hits
+ GROUP BY canonical_stop_id, dataset_id
+ ORDER BY MIN(distance_m), canonical_stop_id
+ LIMIT :limit
+ """
+    ).bindparams(bindparam("dataset_ids", expanding=True))
+    rows = db.execute(
+        stmt,
+        {
+            "lat": float(lat),
+            "lon": float(lon),
+            "dataset_ids": tuple(dataset_ids),
+            "radius_deg": radius_deg,
+            "radius_m": float(radius_m),
+            "limit": max(1, int(limit)),
+        },
+    ).all()
+    return [(int(row[0]), int(row[1]), float(row[2] or 0)) for row in rows]
+
+
+def _enrich_canonical_stop_sources(db: Session, results: list[dict], active_dataset_ids: list[int]) -> None:
+    """Fill source information on stop results in place, per canonical stop.
+
+    Results whose canonical stop has GTFS links in ``active_dataset_ids`` get
+    ``source_names`` (all linked sources), a joined ``source_name`` (first three
+    names plus a ``+N`` remainder), ``source_id`` (only when exactly one source
+    is linked, otherwise ``None``), and a ``grouped_stop_count`` raised to the
+    number of linked stops. All other results only get a default
+    ``source_names`` derived from their existing ``source_name``.
+    """
+
+    def _default_source_names(entry: dict) -> None:
+        entry.setdefault("source_names", [entry["source_name"]] if entry.get("source_name") else [])
+
+    canonical_ids = sorted(
+        {
+            int(entry["canonical_stop_id"])
+            for entry in results
+            if entry.get("canonical_stop_id") is not None
+        }
+    )
+    if not (canonical_ids and active_dataset_ids):
+        for entry in results:
+            _default_source_names(entry)
+        return
+
+    # One row per (canonical stop, source), with the number of linked GTFS stops.
+    link_query = (
+        select(CanonicalStopLink.canonical_stop_id, Source.id, Source.name, func.count(CanonicalStopLink.id))
+        .join(Dataset, Dataset.id == CanonicalStopLink.dataset_id)
+        .join(Source, Source.id == Dataset.source_id)
+        .where(
+            CanonicalStopLink.object_type == "gtfs_stop",
+            CanonicalStopLink.canonical_stop_id.in_(canonical_ids),
+            CanonicalStopLink.dataset_id.in_(active_dataset_ids),
+        )
+        .group_by(CanonicalStopLink.canonical_stop_id, Source.id, Source.name)
+        .order_by(CanonicalStopLink.canonical_stop_id, Source.name, Source.id)
+    )
+    per_stop: dict[int, dict] = {}
+    for stop_key, source_id, source_name, link_count in db.execute(link_query).all():
+        bucket = per_stop.setdefault(int(stop_key), {"source_ids": [], "source_names": [], "linked_stop_count": 0})
+        bucket["source_ids"].append(int(source_id))
+        bucket["source_names"].append(str(source_name))
+        bucket["linked_stop_count"] += int(link_count or 0)
+
+    for entry in results:
+        stop_key = entry.get("canonical_stop_id")
+        bucket = None if stop_key is None else per_stop.get(int(stop_key))
+        if not bucket:
+            _default_source_names(entry)
+            continue
+        names = bucket["source_names"]
+        entry["source_names"] = names
+        label = ", ".join(names[:3])
+        overflow = len(names) - 3
+        if overflow > 0:
+            label += f" +{overflow}"
+        entry["source_name"] = label
+        linked_source_ids = bucket["source_ids"]
+        entry["source_id"] = linked_source_ids[0] if len(linked_source_ids) == 1 else None
+        entry["grouped_stop_count"] = max(int(entry.get("grouped_stop_count") or 0), int(bucket["linked_stop_count"]))
+
+
+ def _merge_canonical_stop_results(results: list[dict]) -> list[dict]:
+     """Collapse results that refer to the same canonical stop and return them ranked.
+
+     Results are keyed by ``canonical_stop_id`` when present, otherwise by
+     ``(dataset_id, stop_id)``. The first result for a key is copied; later ones
+     fold into it by taking the best (lowest) rank values, summing
+     ``grouped_stop_count`` and merging sample stop ids and source names.
+     The input dictionaries are not modified.
+     """
+
+     def ranking(entry: dict) -> tuple:
+         return (
+             entry.get("_bbox_rank", 2),
+             entry.get("_importance_rank", 3),
+             entry["_display_rank"],
+             entry["_match_rank"],
+             entry.get("_bbox_distance_m", float("inf")),
+             -(int(entry["grouped_stop_count"])),
+             entry["name"] or "",
+             entry["stop_id"],
+         )
+
+     by_key: dict[tuple[object, ...], dict] = {}
+     for entry in results:
+         place_id = entry.get("canonical_stop_id")
+         if place_id is None:
+             key = ("group", entry.get("dataset_id"), entry.get("stop_id"))
+         else:
+             key = ("canonical", place_id)
+
+         if key not in by_key:
+             first = dict(entry)
+             first["source_names"] = [entry["source_name"]] if entry.get("source_name") else []
+             by_key[key] = first
+             continue
+
+         kept = by_key[key]
+         kept["_display_rank"] = min(int(kept["_display_rank"]), int(entry["_display_rank"]))
+         kept["_match_rank"] = min(int(kept["_match_rank"]), int(entry["_match_rank"]))
+         kept["_bbox_rank"] = min(int(kept.get("_bbox_rank", 2)), int(entry.get("_bbox_rank", 2)))
+         kept["_bbox_distance_m"] = min(
+             float(kept.get("_bbox_distance_m", float("inf"))),
+             float(entry.get("_bbox_distance_m", float("inf"))),
+         )
+         kept["_importance_rank"] = min(
+             int(kept.get("_importance_rank", 3)),
+             int(entry.get("_importance_rank", 3)),
+         )
+         kept["grouped_stop_count"] = int(kept.get("grouped_stop_count") or 0) + int(entry.get("grouped_stop_count") or 0)
+         kept["sample_stop_ids"] = _merge_sample_stop_ids(kept.get("sample_stop_ids", []), entry.get("sample_stop_ids", []))
+         names = _merge_source_names(kept.get("source_names", []), [entry["source_name"]] if entry.get("source_name") else [])
+         kept["source_names"] = names
+         label = ", ".join(names[:3])
+         if len(names) > 3:
+             label += f" +{len(names) - 3}"
+         kept["source_name"] = label
+         if len(names) > 1:
+             # More than one source contributes, so no single source id applies.
+             kept["source_id"] = None
+
+     return sorted(by_key.values(), key=ranking)
+
+
+ def _merge_sample_stop_ids(left: list[str], right: list[str]) -> list[str]:
+     """Return the first eight distinct stop ids of ``left`` followed by ``right``, order preserved."""
+     # dict keys keep first-seen order, which gives an order-preserving de-duplication.
+     distinct = dict.fromkeys([*left, *right])
+     return list(distinct)[:8]
+
+
+ def _merge_source_names(left: list[str], right: list[str]) -> list[str]:
+     """Concatenate two name lists, dropping empty entries and repeats (first occurrence wins)."""
+     unique: dict[str, None] = {}
+     for candidate in (*left, *right):
+         if candidate:
+             unique.setdefault(candidate, None)
+     return list(unique)
+
+
+ def _city_stop_display_parts(primary: str | None, *candidates: str | None) -> dict[str, str | None]:
+     """Derive ``display_name`` / ``city`` / ``local_name`` for a stop name.
+
+     The primary name is tried first: a "city, stop" comma form, then a leading
+     city followed by a main-station term. Failing that, each candidate name is
+     split into city/stop pairs and accepted when its stop part matches the
+     primary name or the locally shortened name. Without any match the cleaned
+     primary name is returned with no city.
+     """
+
+     def located(city: str, local_name: str) -> dict[str, str | None]:
+         return {"display_name": f"{city}, {local_name}", "city": city, "local_name": local_name}
+
+     cleaned = _clean_stop_name(primary)
+     if not cleaned:
+         return {"display_name": None, "city": None, "local_name": None}
+
+     own_pairs = _candidate_city_stop_pairs(cleaned) if "," in cleaned else []
+     if own_pairs:
+         return located(*own_pairs[0])
+
+     leading = _split_leading_city_stop_name(cleaned)
+     if leading is not None:
+         return located(*leading)
+
+     for candidate in candidates:
+         for city, stop_name in _candidate_city_stop_pairs(candidate):
+             local_name = _local_stop_name(cleaned, city, stop_name)
+             if not stop_name:
+                 continue
+             if _stop_names_match(cleaned, stop_name) or _stop_names_match(local_name, stop_name):
+                 return located(city, local_name)
+
+     return {"display_name": cleaned, "city": None, "local_name": cleaned}
+
+
+ def _normalize_city_stop_name(value: str) -> str:
+     """Format ``value`` as "city, stop" when it splits into both parts, else return it cleaned."""
+     city, stop_name = _split_city_stop_name(value)
+     if not (city and stop_name):
+         # Fall back to the raw value when cleaning leaves nothing.
+         return _clean_stop_name(value) or value
+     return f"{city}, {stop_name}"
+
+
+ def _split_city_stop_name(value: str | None, primary_name: str | None = None) -> tuple[str | None, str | None]:
+     """Return the best ``(city, stop_name)`` pair for ``value``, or ``(None, cleaned value)`` when none exists."""
+     candidates = _candidate_city_stop_pairs(value, primary_name=primary_name)
+     if not candidates:
+         return (None, _clean_stop_name(value))
+     return candidates[0]
+
+
+ def _candidate_city_stop_pairs(value: str | None, primary_name: str | None = None) -> list[tuple[str, str]]:
+     """List plausible ``(city, stop_name)`` readings of a comma-separated name.
+
+     The name is split at its first comma outside parentheses. When
+     ``primary_name`` is given, a half that matches it (while the other half does
+     not match and does not look like a stop) is taken as the stop part first.
+     Afterwards the half that looks like a stop name is treated as the stop; if
+     neither half does, the left half is the city. Duplicate pairs are dropped
+     and order is preserved.
+     """
+     halves = _split_first_comma_outside_parentheses(_clean_stop_name(value))
+     if halves is None:
+         return []
+     head, tail = (_clean_stop_name(part) for part in halves)
+     if not (head and tail):
+         return []
+
+     head_is_stop = _looks_like_stop_name(head)
+     tail_is_stop = _looks_like_stop_name(tail)
+     # Insertion-ordered dict doubles as an ordered set of pairs.
+     ordered: dict[tuple[str, str], None] = {}
+
+     if primary_name:
+         head_hit = _stop_names_match(primary_name, head)
+         tail_hit = _stop_names_match(primary_name, tail)
+         if head_hit and not tail_hit and not tail_is_stop:
+             ordered.setdefault((tail, head), None)
+         if tail_hit and not head_hit and not head_is_stop:
+             ordered.setdefault((head, tail), None)
+
+     if head_is_stop != tail_is_stop:
+         # Exactly one half looks like a stop: that half is the stop name.
+         ordered.setdefault((tail, head) if head_is_stop else (head, tail), None)
+     elif not head_is_stop:
+         # Neither half looks like a stop: assume "city, stop" order.
+         ordered.setdefault((head, tail), None)
+
+     return list(ordered)
+
+
+ def _split_first_comma_outside_parentheses(value: str | None) -> tuple[str, str] | None:
+     """Split ``value`` at its first comma that is not enclosed in parentheses.
+
+     Returns the text before and after that comma (comma excluded, whitespace
+     untouched), or ``None`` for empty input or when no such comma exists.
+     """
+     nesting = 0
+     for position, symbol in enumerate(value or ""):
+         if symbol == "(":
+             nesting += 1
+         elif symbol == ")":
+             # An unmatched closing parenthesis never drives the depth below zero.
+             nesting = max(nesting - 1, 0)
+         elif symbol == "," and nesting == 0:
+             return value[:position], value[position + 1 :]
+     return None
+
+
+ def _looks_like_stop_name(value: str) -> bool:
+     """Heuristically decide whether ``value`` names a stop/station rather than a city.
+
+     True when the normalized text contains a delimited "hbf" (optionally
+     "hbf.") or any of the station-related keywords as a substring.
+     """
+     text = _normalize_stop_search(value)
+     if re.search(r"(^|[\s,(/-])hbf\.?($|[\s,)/-])", text) is not None:
+         return True
+     keywords = (
+         "hauptbahnhof",
+         "bahnhof",
+         "station",
+         "central station",
+         "central train station",
+         "steig",
+         "tram",
+         "bus",
+         "zob",
+         "ostseite",
+         "westseite",
+     )
+     for keyword in keywords:
+         if keyword in text:
+             return True
+     return False
+
+
+ def _split_leading_city_stop_name(value: str) -> tuple[str, str] | None:
+     """Split names of the form "<city> <main-station term>" into ``(city, stop_name)``.
+
+     Returns ``None`` when the name is empty, does not end in one of the
+     recognised main-station terms, or the leading part itself looks like a
+     stop name.
+     """
+     cleaned = _clean_stop_name(value)
+     if not cleaned:
+         return None
+     pattern = re.compile(
+         r"^(.+?)\s+(central train station|central station|main station|hauptbahnhof(?:\s+.*)?|hbf\.?(?:\s+.*)?)$",
+         re.IGNORECASE,
+     )
+     found = pattern.match(cleaned)
+     if found is None:
+         return None
+     city = _clean_stop_name(found.group(1))
+     stop_name = _clean_stop_name(found.group(2))
+     if city and stop_name and not _looks_like_stop_name(city):
+         return city, stop_name
+     return None
+
+
+ def _local_stop_name(primary_name: str, city: str, candidate_stop_name: str | None) -> str:
+     """Return the stop name without its leading city, where that can be determined.
+
+     If the normalized ``primary_name`` starts with the normalized ``city``
+     followed by a space, the remainder (stripped of spaces and commas) is
+     returned. Otherwise ``candidate_stop_name`` is returned when it is a
+     station-synonym equivalent of ``primary_name``; else ``primary_name``.
+     """
+     city_norm = _normalize_stop_search(city)
+     if _normalize_stop_search(primary_name).startswith(f"{city_norm} "):
+         # The prefix test runs on normalized text, but the cut must happen on the
+         # raw string. casefold() can change length (e.g. "ß" -> "ss") and
+         # normalization collapses whitespace, so len(city) is not a reliable
+         # offset. Use the shortest raw prefix that normalizes to the city.
+         cut = next(
+             (
+                 end
+                 for end in range(1, len(primary_name) + 1)
+                 if _normalize_stop_search(primary_name[:end]) == city_norm
+             ),
+             None,
+         )
+         if cut is not None:
+             remainder = primary_name[cut:].strip(" ,")
+             if remainder:
+                 return remainder
+     if candidate_stop_name and _normalize_station_synonyms(primary_name) == _normalize_station_synonyms(candidate_stop_name):
+         return candidate_stop_name
+     return primary_name
+
+
+ def _stop_names_match(left: str | None, right: str | None) -> bool:
+     """Tell whether two stop names refer to the same stop after normalization.
+
+     Names match when both are non-empty and one normalized name contains the
+     other, or when they are equal once station synonyms are unified.
+     """
+     first = _normalize_stop_search(left or "")
+     second = _normalize_stop_search(right or "")
+     if not (first and second):
+         return False
+     # Containment in either direction also covers plain equality.
+     if first in second or second in first:
+         return True
+     return _normalize_station_synonyms(first) == _normalize_station_synonyms(second)
+
+
+ def _normalize_station_synonyms(value: str) -> str:
+     """Reduce a stop name to a comparison key in which main-station terms are unified.
+
+     "central train station", "central station", "main station", "hauptbahnhof"
+     and a delimited "hbf" / "hbf." all become ``mainstation``; afterwards every
+     character outside ``a-z0-9`` is removed.
+     """
+     normalized = _normalize_stop_search(value)
+     # Longer phrases first so "central train station" is not partially rewritten.
+     for phrase in (
+         r"\bcentral train station\b",
+         r"\bcentral station\b",
+         r"\bmain station\b",
+         r"\bhauptbahnhof\b",
+     ):
+         normalized = re.sub(phrase, "mainstation", normalized)
+     # Accept the abbreviation with a trailing period ("Hbf."), as the sibling
+     # helpers _looks_like_stop_name and _split_leading_city_stop_name already do.
+     normalized = re.sub(r"(^|[\s,(/-])hbf\.?($|[\s,)/-])", " mainstation ", normalized)
+     return re.sub(r"[^a-z0-9]+", "", normalized)
+
+
+ def _clean_stop_name(value: str | None) -> str | None:
+     """Collapse whitespace runs to single spaces and trim; return ``None`` when nothing remains."""
+     text = str(value) if value else ""
+     collapsed = re.sub(r"\s+", " ", text).strip()
+     if collapsed:
+         return collapsed
+     return None
+
+
+ def _stop_group_token(dataset_id: int, group_id: str) -> str:
+     """Build the selection token for a stop group: prefix, dataset id, colon, group id."""
+     return "{}{}:{}".format(STOP_GROUP_PREFIX, dataset_id, group_id)
+
+
+ def _stop_place_token(canonical_stop_id: int, dataset_id: int) -> str:
+     """Build the selection token for a canonical stop: prefix, canonical id, colon, dataset id."""
+     return "{}{}:{}".format(STOP_PLACE_PREFIX, canonical_stop_id, dataset_id)
+
+
+ def _canonical_stop_for_group(db: Session, stops: list[GtfsStop]) -> CanonicalStop | None:
+     """Return the canonical stop linked to any of ``stops``, or ``None`` if there is none.
+
+     When several links exist, the first one ordered by ``role`` and then link id decides.
+     """
+     member_ids = [stop.id for stop in stops]
+     if member_ids:
+         query = (
+             select(CanonicalStopLink)
+             .where(CanonicalStopLink.object_type == "gtfs_stop", CanonicalStopLink.object_id.in_(member_ids))
+             .order_by(CanonicalStopLink.role, CanonicalStopLink.id)
+         )
+         first_link = db.scalar(query)
+         if first_link is not None:
+             return db.get(CanonicalStop, first_link.canonical_stop_id)
+     return None
+
+
+ def _stop_match_rank(stop: GtfsStop, query: str) -> int:
+     """Rank how well ``stop`` matches a search query (lower is better).
+
+     0: query equals the normalized name or stop id
+     1: name or stop id starts with the query
+     2: query is a substring of name or stop id
+     3: every query token occurs somewhere in name or stop id
+     4: no match, or an empty / whitespace-only query
+     """
+     needle = _normalize_stop_search(query) if query else ""
+     if not needle:
+         # A whitespace-only query normalizes to ""; without this guard
+         # str.startswith("") would report a prefix match for every stop.
+         return 4
+     name = _normalize_stop_search(stop.name or "")
+     stop_id = _normalize_stop_search(stop.stop_id)
+     if needle in {name, stop_id}:
+         return 0
+     if name.startswith(needle) or stop_id.startswith(needle):
+         return 1
+     if needle in name or needle in stop_id:
+         return 2
+     tokens = [token for token in re.split(r"[\s,;/]+", needle) if token]
+     haystack = f"{name} {stop_id}"
+     if tokens and all(token in haystack for token in tokens):
+         return 3
+     return 4
+
+
+ def _bbox_order_expressions(model, bbox: tuple[float, float, float, float]):
+     """Build SQL ordering expressions that prefer rows inside ``bbox``.
+
+     ``bbox`` is ``(min_lon, min_lat, max_lon, max_lat)``. Returns a pair: a rank
+     expression (0 inside the box, 1 outside, 2 when lon or lat is NULL) and the
+     squared lon/lat offset from the box centre, with NULL coordinates coalesced
+     to the centre.
+     """
+     west, south, east, north = bbox
+     mid_lon = (west + east) / 2
+     mid_lat = (south + north) / 2
+     has_no_position = or_(model.lon.is_(None), model.lat.is_(None))
+     within_box = and_(model.lon >= west, model.lon <= east, model.lat >= south, model.lat <= north)
+     rank_expr = case((has_no_position, 2), (within_box, 0), else_=1)
+     d_lon = func.coalesce(model.lon, mid_lon) - mid_lon
+     d_lat = func.coalesce(model.lat, mid_lat) - mid_lat
+     return (rank_expr, d_lon * d_lon + d_lat * d_lat)
+
+
+ def _bbox_rank(
+     lat: float | None,
+     lon: float | None,
+     bbox: tuple[float, float, float, float] | None,
+ ) -> tuple[int, float]:
+     """Rank a point against ``bbox`` (``min_lon, min_lat, max_lon, max_lat``).
+
+     Returns ``(rank, distance_m)``: ``(1, 0.0)`` without a bbox, ``(2, inf)``
+     for a point with a missing coordinate, ``(0, 0.0)`` for a point inside the
+     box, and otherwise rank 1 with the distance to the nearest point of the box.
+     """
+     if bbox is None:
+         return (1, 0.0)
+     if lat is None or lon is None:
+         return (2, float("inf"))
+     west, south, east, north = bbox
+     inside = (west <= lon <= east) and (south <= lat <= north)
+     if inside:
+         return (0, 0.0)
+     # Clamp the point onto the box to find its nearest edge or corner.
+     nearest_lon = min(max(lon, west), east)
+     nearest_lat = min(max(lat, south), north)
+     return (1, _distance_m(lat, lon, nearest_lat, nearest_lon))
+
+
+ def _station_importance_rank(*names: str | None) -> int:
+     """Rank a stop by how important its names suggest it is (lower is more important).
+
+     0: a main station ("hbf", "hauptbahnhof", "central station", "central train station")
+     1: another railway station ("bf", "bahnhof", "station")
+     2: a bus station ("zob", "busbahnhof")
+     3: anything else
+
+     Empty or ``None`` names are ignored; the best rank over all names wins.
+     """
+     normalized_names = [_normalize_stop_search(name or "") for name in names if name]
+     # The delimiter classes must use a single-backslash \s: inside a raw string
+     # "\\s" is a literal backslash plus the letter "s", which never matches the
+     # space in names such as "berlin hbf".
+     if any(
+         re.search(r"(^|[\s,(/-])hbf($|[\s,)/-])", name)
+         or "hauptbahnhof" in name
+         or "central station" in name
+         or "central train station" in name
+         for name in normalized_names
+     ):
+         return 0
+     if any(
+         re.search(r"(^|[\s,(/-])bf($|[\s,)/-])", name)
+         or "bahnhof" in name
+         or "station" in name
+         for name in normalized_names
+     ):
+         return 1
+     if any("zob" in name or "busbahnhof" in name for name in normalized_names):
+         return 2
+     return 3
+
+
+ def _normalize_stop_search(value: str) -> str:
+     """Case-fold ``value``, trim it and collapse whitespace runs to single spaces."""
+     folded = value.casefold().strip()
+     return re.sub(r"\s+", " ", folded)
+
+
+ def _parent_stops_for_groups(db: Session, group_keys) -> dict[tuple[int, str], GtfsStop]:
+     """Load the stop whose ``(dataset_id, stop_id)`` equals each requested group key.
+
+     ``group_keys`` is an iterable of ``(dataset_id, group_id)`` pairs. Keys
+     without a matching stop are absent from the result.
+     """
+     wanted = set(group_keys)
+     if not wanted:
+         return {}
+     dataset_ids = {dataset_id for dataset_id, _ in wanted}
+     group_ids = {group_id for _, group_id in wanted}
+     query = select(GtfsStop).where(GtfsStop.dataset_id.in_(dataset_ids), GtfsStop.stop_id.in_(group_ids))
+     found: dict[tuple[int, str], GtfsStop] = {}
+     for stop in db.scalars(query).all():
+         key = (stop.dataset_id, stop.stop_id)
+         # The two IN filters are independent, so discard cross-combinations
+         # that were not actually requested.
+         if key in wanted:
+             found[key] = stop
+     return found
+
+
+ def _scheduled_stops_for_groups(db: Session, group_keys) -> dict[tuple[int, str], list[GtfsStop]]:
+ # Collect, per requested (dataset_id, group_id) key, the stops of that logical group
+ # that are considered scheduled. Each list holds at most MAX_GROUP_STOP_IDS stops,
+ # in the query order (name, then stop_id). Keys without scheduled stops are absent.
+ requested = set(group_keys)
+ if not requested:
+ return {}
+ dataset_ids = {dataset_id for dataset_id, _ in requested}
+ group_ids = {group_id for _, group_id in requested}
+ # A stop is a group candidate when it is the group stop itself, names the group as
+ # parent_station, or has a stop_id of the form "<group_id>::<suffix>".
+ if settings.is_postgresql_database:
+ group_condition = or_(
+ GtfsStop.stop_id.in_(group_ids),
+ GtfsStop.parent_station.in_(group_ids),
+ func.split_part(GtfsStop.stop_id, "::", 1).in_(group_ids),
+ )
+ else:
+ # Without split_part the "<group_id>::" prefix is matched with one ILIKE per group id.
+ # NOTE(review): group_id is not escaped for LIKE wildcards ("%", "_"); stray rows
+ # are dropped later by the `key not in requested` check.
+ inferred_child_filters = [GtfsStop.stop_id.ilike(f"{group_id}::%") for group_id in group_ids]
+ group_condition = or_(GtfsStop.stop_id.in_(group_ids), GtfsStop.parent_station.in_(group_ids), *inferred_child_filters)
+ rows = db.scalars(
+ select(GtfsStop)
+ .where(
+ GtfsStop.dataset_id.in_(dataset_ids),
+ group_condition,
+ # On PostgreSQL the "has stop_times" test runs in SQL as an EXISTS condition.
+ *(_scheduled_gtfs_stop_condition() if settings.is_postgresql_database else ()),
+ )
+ .order_by(GtfsStop.name, GtfsStop.stop_id)
+ ).all()
+ # On other databases the scheduled stop ids are loaded once per dataset and applied below.
+ scheduled_by_dataset = {} if settings.is_postgresql_database else {dataset_id: all_scheduled_stop_ids(db, dataset_id) for dataset_id in dataset_ids}
+ grouped: dict[tuple[int, str], list[GtfsStop]] = {}
+ for stop in rows:
+ # An empty scheduled_by_dataset (the PostgreSQL case) disables this Python-side filter.
+ if scheduled_by_dataset and stop.stop_id not in scheduled_by_dataset.get(stop.dataset_id, set()):
+ continue
+ group_id = logical_stop_group_id(stop)
+ key = (stop.dataset_id, group_id)
+ # The dataset and group filters are independent, so confirm the exact pair was requested.
+ if key not in requested:
+ continue
+ bucket = grouped.setdefault(key, [])
+ if len(bucket) < MAX_GROUP_STOP_IDS:
+ bucket.append(stop)
+ return grouped
+
+
+ def _scheduled_gtfs_stop_condition():
+     """Return a one-element tuple holding an EXISTS condition on stop_times.
+
+     The subquery correlates on ``dataset_id`` and ``stop_id`` with ``GtfsStop``;
+     the tuple form lets callers splat it into ``where(...)``.
+     """
+     same_dataset = GtfsStopTime.dataset_id == GtfsStop.dataset_id
+     same_stop = GtfsStopTime.stop_id == GtfsStop.stop_id
+     probe = select(GtfsStopTime.id).where(same_dataset, same_stop).limit(1)
+     return (probe.exists(),)
+
+
+ def _best_display_stop(group_id: str, matches: list[GtfsStop], scheduled_stops: list[GtfsStop]) -> GtfsStop:
+     """Pick the stop that best represents a group for display.
+
+     Preference order: the parentless stop whose id is the group id, a child of
+     the group, any stop with a parent, a stop with coordinates, then name and
+     stop id alphabetically. Raises ``ValueError`` when both lists are empty.
+     """
+
+     def preference(stop: GtfsStop) -> tuple:
+         is_group_root = stop.stop_id == group_id and stop.parent_station is None
+         has_position = stop.lat is not None and stop.lon is not None
+         return (
+             0 if is_group_root else 1,
+             0 if stop.parent_station == group_id else 1,
+             0 if stop.parent_station is not None else 1,
+             0 if has_position else 1,
+             stop.name or "",
+             stop.stop_id,
+         )
+
+     return min([*matches, *scheduled_stops], key=preference)
+
+
+ def _resolve_stop_selection(db: Session, value: int | str, source_ids: list[int] | None = None) -> StopSelection:
+     """Turn a user-supplied stop reference into a ``StopSelection``.
+
+     Accepted forms: a stop-place token, a stop-group token, an exact-id token,
+     an external GTFS ``stop_id`` or (for purely numeric values) a ``GtfsStop``
+     primary key. Group tokens are upgraded to their canonical stop when one is
+     linked.
+
+     Raises ``ValueError`` for location tokens and for references that do not
+     resolve to a stop.
+     """
+     token = str(value).strip()
+     if is_location_token(token):
+         raise ValueError("selected location must be routed through location-aware search")
+     active_dataset_ids = _active_gtfs_dataset_ids(db, source_ids=source_ids)
+
+     def via_canonical(canonical_stop_id: int, dataset_id: int) -> StopSelection:
+         return _selection_for_canonical_stop(
+             db,
+             canonical_stop_id,
+             dataset_ids=active_dataset_ids,
+             preferred_dataset_id=dataset_id,
+         )
+
+     if token.startswith(STOP_PLACE_PREFIX):
+         canonical_stop_id, dataset_id = _parse_stop_place_token(token)
+         return via_canonical(canonical_stop_id, dataset_id)
+
+     if token.startswith(STOP_GROUP_PREFIX):
+         dataset_id, group_id = _parse_stop_group_token(token)
+         group_selection = _selection_for_group(db, dataset_id, group_id)
+         if group_selection.canonical_stop_id is None:
+             return group_selection
+         return via_canonical(group_selection.canonical_stop_id, dataset_id)
+
+     exact_only = token.startswith(STOP_EXACT_PREFIX)
+     if exact_only:
+         token = token[len(STOP_EXACT_PREFIX) :]
+
+     stop = _active_stop_by_external_stop_id(db, token, active_dataset_ids) if token else None
+     if stop is None and not exact_only and token.isdigit():
+         # Numeric values may also be a GtfsStop primary key; exact-id tokens skip this fallback.
+         by_primary_key = db.get(GtfsStop, int(token))
+         if by_primary_key is not None and (not active_dataset_ids or by_primary_key.dataset_id in active_dataset_ids):
+             stop = by_primary_key
+     if stop is None:
+         raise ValueError("from_stop_id and to_stop_id must reference existing GTFS stops")
+
+     return _selection_for_stop(db, stop, active_dataset_ids)
+
+
+ def resolve_location_summary(db: Session, value: int | str, source_ids: list[int] | None = None) -> StopSummary:
+     """Resolve a coordinate, address-point, address or stop reference to a ``StopSummary``.
+
+     Token kinds are checked in that order; anything else is resolved as a stop
+     selection and its display summary is returned.
+     """
+     token = str(value).strip()
+
+     if is_coordinate_token(token):
+         lat, lon = parse_coordinate_token(token)
+         return _coordinate_summary(db, lat, lon)
+
+     if is_address_point_token(token):
+         address, lat, lon = address_point_by_token(db, token)
+         return _address_summary(address, db=db, lat=lat, lon=lon, street_level=True)
+
+     if is_address_token(token):
+         address = address_by_token(db, token)
+         return _address_summary(address, db=db)
+
+     selection = _resolve_stop_selection(db, token, source_ids=source_ids)
+     return selection.display
+
+
+ def _address_summary(
+     address: OsmAddress,
+     *,
+     db: Session | None = None,
+     lat: float | None = None,
+     lon: float | None = None,
+     street_level: bool = False,
+ ) -> StopSummary:
+     """Describe an OSM address as a ``StopSummary``.
+
+     Explicit ``lat`` / ``lon`` override the address coordinates, and the result
+     is replaced by the snapped walk location when snapping succeeds. With
+     ``street_level`` and known coordinates the summary carries an address-point
+     token and a street-style name; otherwise the plain address token and the
+     address display name.
+     """
+     resolved_lat = lat if lat is not None else address.lat
+     resolved_lon = lon if lon is not None else address.lon
+     snapped = _snap_walk_location(db, lat=resolved_lat, lon=resolved_lon)
+     if snapped is not None:
+         resolved_lat, resolved_lon = snapped
+
+     has_position = resolved_lat is not None and resolved_lon is not None
+     if street_level and has_position:
+         stop_id = address_point_token(address.id, float(resolved_lat), float(resolved_lon))
+         name = _street_address_name(address)
+     else:
+         stop_id = address_token(address.id)
+         name = address.display_name
+
+     return StopSummary(
+         id=address.id,
+         dataset_id=address.dataset_id,
+         stop_id=stop_id,
+         name=name,
+         lat=resolved_lat,
+         lon=resolved_lon,
+     )
+
+
+ def _coordinate_summary(db: Session, lat: float, lon: float) -> StopSummary:
+     """Describe a raw map coordinate as a ``StopSummary``.
+
+     The token and the label use the coordinate as given; the reported position
+     is the snapped walk location when snapping succeeds.
+     """
+     token = coordinate_token(lat, lon)
+     snapped = _snap_walk_location(db, lat=lat, lon=lon)
+     if snapped is None:
+         resolved_lat, resolved_lon = float(lat), float(lon)
+     else:
+         resolved_lat, resolved_lon = snapped
+     return StopSummary(
+         id=0,
+         dataset_id=0,
+         stop_id=token,
+         name=f"Map point {lat:.5f}, {lon:.5f}",
+         lat=resolved_lat,
+         lon=resolved_lon,
+     )
+
+
+ def _snap_walk_location(db: Session | None, *, lat: float | None, lon: float | None) -> tuple[float, float] | None:
+     """Snap a point to the walk routing graph (search radius 250 m).
+
+     Returns ``(lat, lon)`` of the snapped point, or ``None`` when an input is
+     missing, nothing is found, or the snapping call raises.
+     """
+     if any(argument is None for argument in (db, lat, lon)):
+         return None
+     try:
+         hit = snap_point_to_routing_graph(db, lon=float(lon), lat=float(lat), mode="walk", max_distance_m=250)
+     except Exception:  # noqa: BLE001 - snapping must not break address/coordinate routing
+         return None
+     return None if hit is None else (float(hit["lat"]), float(hit["lon"]))
+
+
+ def _street_address_name(address: OsmAddress) -> str:
+     """Format an address as "<street or place or name>, <postcode city>".
+
+     The locality suffix is omitted when neither postcode nor city is set.
+     """
+     label = address.street or address.place or address.name or address.display_name or "Address"
+     locality_parts = [str(part) for part in (address.postcode, address.city) if part]
+     locality = " ".join(locality_parts).strip()
+     if locality:
+         return f"{label}, {locality}"
+     return str(label)
+
+
+ def _active_stop_by_external_stop_id(db: Session, stop_id: str, active_dataset_ids: list[int]) -> GtfsStop | None:
+     """Find a stop by external ``stop_id`` among active GTFS datasets.
+
+     When ``active_dataset_ids`` is non-empty the lookup is further limited to
+     those datasets. Among several hits the lowest dataset id wins, then stops
+     without a parent station, then the lowest row id.
+     """
+     parentless_first = case((GtfsStop.parent_station.is_(None), 0), else_=1)
+     query = (
+         select(GtfsStop)
+         .join(Dataset, Dataset.id == GtfsStop.dataset_id)
+         .where(Dataset.is_active.is_(True), Dataset.kind == "gtfs", GtfsStop.stop_id == stop_id)
+         .order_by(GtfsStop.dataset_id, parentless_first, GtfsStop.id)
+     )
+     if active_dataset_ids:
+         query = query.where(GtfsStop.dataset_id.in_(active_dataset_ids))
+     return db.scalar(query)
+
+
+ def _selection_for_stop(db: Session, stop: GtfsStop, active_dataset_ids: list[int]) -> StopSelection:
+     """Build the selection for one concrete GTFS stop.
+
+     A scheduled stop is upgraded to its canonical stop when one exists,
+     otherwise it is selected on its own. A stop without schedule data is
+     resolved through its group (``parent_station`` or its own id), again
+     preferring the canonical stop when the group has one.
+     """
+
+     def via_canonical(canonical_stop_id: int) -> StopSelection:
+         return _selection_for_canonical_stop(
+             db,
+             canonical_stop_id,
+             dataset_ids=active_dataset_ids,
+             preferred_dataset_id=stop.dataset_id,
+         )
+
+     if not _has_scheduled_stop(db, stop):
+         group_selection = _selection_for_group(db, stop.dataset_id, stop.parent_station or stop.stop_id)
+         if group_selection.canonical_stop_id is None:
+             return group_selection
+         return via_canonical(group_selection.canonical_stop_id)
+
+     canonical = canonical_stop_for_gtfs_stop(db, stop)
+     if canonical is not None:
+         return via_canonical(canonical.id)
+     # No canonical stop is linked: select just this stop in its own dataset.
+     return StopSelection(
+         display=_stop_summary(stop),
+         stop_ids_by_dataset={stop.dataset_id: (stop.stop_id,)},
+         canonical_stop_id=None,
+     )
+
+
+ def _parse_stop_group_token(token: str) -> tuple[int, str]:
+     """Parse "<prefix><dataset_id>:<group_id>" into ``(dataset_id, group_id)``.
+
+     Raises ``ValueError`` when the colon is missing, the dataset id is not an
+     integer, or the group id is empty. The group id may itself contain colons.
+     """
+     payload = token[len(STOP_GROUP_PREFIX) :]
+     dataset_text, separator, group_id = payload.partition(":")
+     if not separator:
+         raise ValueError("invalid grouped stop token")
+     try:
+         dataset_id = int(dataset_text)
+     except ValueError as exc:
+         raise ValueError("invalid grouped stop token") from exc
+     if group_id:
+         return dataset_id, group_id
+     raise ValueError("invalid grouped stop token")
+
+
+ def _parse_stop_place_token(token: str) -> tuple[int, int]:
+     """Parse "<prefix><canonical_stop_id>:<dataset_id>" into two integers.
+
+     Raises ``ValueError`` when the colon is missing or either part is not an integer.
+     """
+     payload = token[len(STOP_PLACE_PREFIX) :]
+     canonical_text, separator, dataset_text = payload.partition(":")
+     if not separator:
+         raise ValueError("invalid canonical stop token")
+     try:
+         parsed = (int(canonical_text), int(dataset_text))
+     except ValueError as exc:
+         raise ValueError("invalid canonical stop token") from exc
+     return parsed
+
+
+ def _selection_for_canonical_stop(
+     db: Session,
+     canonical_stop_id: int,
+     dataset_ids: list[int] | None = None,
+     preferred_dataset_id: int | None = None,
+ ) -> StopSelection:
+     """Build the selection for a canonical stop across the datasets in scope.
+
+     ``dataset_ids=None`` means all active GTFS datasets. Only linked stops with
+     scheduled stop_times are kept. The display dataset is the preferred one if
+     it contributed stops, otherwise the first dataset that did.
+
+     Raises ``ValueError`` when the canonical stop does not exist or none of its
+     linked stops is scheduled within the scope.
+     """
+     canonical = db.get(CanonicalStop, canonical_stop_id)
+     if canonical is None:
+         raise ValueError("selected stop place does not exist")
+
+     scope = dataset_ids if dataset_ids is not None else _active_gtfs_dataset_ids(db)
+     linked = _gtfs_stop_ids_for_canonical_stop_by_dataset(db, canonical_stop_id, scope)
+
+     scheduled_by_dataset: dict[int, tuple[str, ...]] = {}
+     for dataset_id in _preferred_dataset_order(linked, preferred_dataset_id):
+         with_times = _scheduled_stop_ids(db, dataset_id, linked[dataset_id])
+         if with_times:
+             scheduled_by_dataset[dataset_id] = with_times
+     if not scheduled_by_dataset:
+         raise ValueError("selected stop place has no imported scheduled stop_times in the selected source scope")
+
+     if preferred_dataset_id in scheduled_by_dataset:
+         display_dataset_id = preferred_dataset_id
+     else:
+         display_dataset_id = next(iter(scheduled_by_dataset))
+
+     summary = StopSummary(
+         id=canonical.id,
+         dataset_id=display_dataset_id,
+         stop_id=f"canonical:{canonical.id}",
+         name=canonical.name,
+         lat=canonical.lat,
+         lon=canonical.lon,
+     )
+     return StopSelection(
+         display=summary,
+         stop_ids_by_dataset=scheduled_by_dataset,
+         canonical_stop_id=canonical.id,
+     )
+
+
+ def _selection_for_group(db: Session, dataset_id: int, group_id: str) -> StopSelection:
+     """Build the selection for one stop group inside one dataset.
+
+     The group's own stop is displayed when it exists, otherwise the best
+     scheduled member. Raises ``ValueError`` when the group has no scheduled stops.
+     """
+     key = (dataset_id, group_id)
+     scheduled = _scheduled_stops_for_groups(db, [key]).get(key, [])
+     if not scheduled:
+         raise ValueError("selected stop group has no imported scheduled stop_times")
+     parent = _parent_stops_for_groups(db, [key]).get(key)
+     display = parent or _best_display_stop(group_id, [], scheduled)
+     canonical = _canonical_stop_for_group(db, scheduled)
+     return StopSelection(
+         display=_stop_summary(display),
+         stop_ids_by_dataset={dataset_id: tuple(member.stop_id for member in scheduled[:MAX_GROUP_STOP_IDS])},
+         canonical_stop_id=canonical.id if canonical is not None else None,
+     )
+
+
+ def _gtfs_stop_ids_for_canonical_stop_by_dataset(
+     db: Session, canonical_stop_id: int, dataset_ids: list[int]
+ ) -> dict[int, tuple[str, ...]]:
+     """Map each dataset to the external GTFS stop ids linked to a canonical stop.
+
+     Only ``gtfs_stop`` links inside ``dataset_ids`` are read, ordered by dataset,
+     role and external id; at most ``MAX_GROUP_STOP_IDS`` ids are kept per dataset.
+     """
+     if not dataset_ids:
+         return {}
+     query = (
+         select(CanonicalStopLink.dataset_id, CanonicalStopLink.external_id)
+         .where(
+             CanonicalStopLink.canonical_stop_id == canonical_stop_id,
+             CanonicalStopLink.object_type == "gtfs_stop",
+             CanonicalStopLink.dataset_id.in_(dataset_ids),
+         )
+         .order_by(CanonicalStopLink.dataset_id, CanonicalStopLink.role, CanonicalStopLink.external_id)
+     )
+     per_dataset: dict[int, list[str]] = {}
+     for raw_dataset_id, external_id in db.execute(query).all():
+         collected = per_dataset.setdefault(int(raw_dataset_id), [])
+         if len(collected) >= MAX_GROUP_STOP_IDS:
+             continue
+         collected.append(str(external_id))
+     return {key: tuple(collected) for key, collected in per_dataset.items()}
+
+
+ def _preferred_dataset_order(stop_ids_by_dataset: dict[int, tuple[str, ...]], preferred_dataset_id: int | None) -> list[int]:
+     """Return the dataset ids in ascending order, with the preferred one (if present) moved to the front."""
+     ordered = sorted(stop_ids_by_dataset)
+     if preferred_dataset_id is not None and preferred_dataset_id in stop_ids_by_dataset:
+         ordered.remove(preferred_dataset_id)
+         ordered.insert(0, preferred_dataset_id)
+     return ordered
+
+
+ def _scheduled_stop_ids(db: Session, dataset_id: int, stop_ids: tuple[str, ...]) -> tuple[str, ...]:
+     """Return the scheduled subset of ``stop_ids`` as reported by storage, capped at ``MAX_GROUP_STOP_IDS``."""
+     with_times = storage_scheduled_stop_ids(db, dataset_id, stop_ids)
+     return with_times[:MAX_GROUP_STOP_IDS]
+
+
+ def _has_scheduled_stop(db: Session, stop: GtfsStop) -> bool:
+     """Ask storage whether ``stop`` (by dataset id and external stop id) is scheduled."""
+     dataset_id, external_stop_id = stop.dataset_id, stop.stop_id
+     return storage_has_scheduled_stop(db, dataset_id, external_stop_id)
+
+
+ def find_journeys(
+ db: Session,
+ from_stop_id: int | str,
+ to_stop_id: int | str,
+ departure: str,
+ max_transfers: int = 0,
+ limit: int = 5,
+ transfer_seconds: int = 120,
+ source_ids: list[int] | None = None,
+ via_stop_id: int | str | None = None,
+ service_date: str | date | None = None,
+ _allow_access_transfer: bool = True,
+ _allow_address_access: bool = True,
+ ) -> dict:
+ # Search journeys between two stop/location references and return a response payload.
+ #
+ # Parameters:
+ #   from_stop_id / to_stop_id: stop references, or location tokens (handled by the
+ #     address-access search when _allow_address_access is true).
+ #   departure: earliest departure, "HH:MM" or "HH:MM:SS".
+ #   max_transfers: 0 searches direct journeys only; > 0 adds one-transfer journeys;
+ #     > 1 additionally runs the round-based search.
+ #   limit: number of journeys wanted; clamped to the range 1..10.
+ #   transfer_seconds: transfer time passed to the transfer searches (negative -> 0).
+ #   source_ids: optional source scope used when resolving stops.
+ #   via_stop_id: when non-blank, the whole search is delegated to _find_journeys_via.
+ #   service_date: optional service day; parsed with parse_service_date.
+ #   _allow_access_transfer / _allow_address_access: internal switches for the two
+ #     delegating/fallback searches below.
+ #
+ # Returns a dict with the keys "from", "to", "source", "sources", "dataset_id",
+ # "dataset_ids", "departure_time", "departure_time_label", "service_date",
+ # "max_transfers" and "journeys".
+ #
+ # Raises ValueError when departure cannot be parsed; stop resolution may raise
+ # ValueError as well.
+ if via_stop_id is not None and str(via_stop_id).strip():
+ return _find_journeys_via(
+ db=db,
+ from_stop_id=from_stop_id,
+ via_stop_id=via_stop_id,
+ to_stop_id=to_stop_id,
+ departure=departure,
+ max_transfers=max_transfers,
+ transfer_seconds=transfer_seconds,
+ limit=limit,
+ source_ids=source_ids,
+ service_date=service_date,
+ )
+
+ if _allow_address_access and (is_location_token(from_stop_id) or is_location_token(to_stop_id)):
+ return _find_journeys_with_address_access(
+ db=db,
+ from_stop_id=from_stop_id,
+ to_stop_id=to_stop_id,
+ departure=departure,
+ max_transfers=max_transfers,
+ transfer_seconds=transfer_seconds,
+ limit=limit,
+ source_ids=source_ids,
+ service_date=service_date,
+ )
+
+ from_selection = _resolve_stop_selection(db, from_stop_id, source_ids=source_ids)
+ to_selection = _resolve_stop_selection(db, to_stop_id, source_ids=source_ids)
+ departure_seconds = parse_gtfs_time(departure)
+ if departure_seconds is None:
+ raise ValueError("departure must be HH:MM or HH:MM:SS")
+ parsed_service_date = parse_service_date(service_date)
+
+ # Pre-load summaries for every origin and destination stop, keyed by (dataset_id, stop_id).
+ stop_cache: dict[tuple[int, str], StopSummary] = {}
+ for dataset_id, stop_ids in from_selection.stop_ids_by_dataset.items():
+ for stop_id in stop_ids:
+ stop_cache.setdefault((dataset_id, stop_id), _stop_summary_for_stop_id(db, dataset_id, stop_id))
+ for dataset_id, stop_ids in to_selection.stop_ids_by_dataset.items():
+ for stop_id in stop_ids:
+ stop_cache.setdefault((dataset_id, stop_id), _stop_summary_for_stop_id(db, dataset_id, stop_id))
+ osm_stop_cache: dict[tuple[int, str], dict] = {}
+ max_journeys = max(1, min(limit, 10))
+ # Direct journeys are only possible inside a dataset that contains both ends.
+ common_dataset_ids = sorted(set(from_selection.stop_ids_by_dataset) & set(to_selection.stop_ids_by_dataset))
+ service_ids_by_dataset = _service_ids_by_dataset(db, sorted(set(from_selection.stop_ids_by_dataset) | set(to_selection.stop_ids_by_dataset)), parsed_service_date)
+ direct: list[dict] = []
+ for dataset_id in common_dataset_ids:
+ service_ids = service_ids_by_dataset.get(dataset_id)
+ # An empty set skips the dataset; None (no entry) is passed on unchanged.
+ # NOTE(review): None presumably means "no service-day filter" - confirm in _service_ids_by_dataset.
+ if service_ids == set():
+ continue
+ direct.extend(
+ _find_direct_journeys(
+ db=db,
+ dataset_id=dataset_id,
+ service_ids=service_ids,
+ from_stop_ids=from_selection.stop_ids_by_dataset[dataset_id],
+ to_stop_ids=to_selection.stop_ids_by_dataset[dataset_id],
+ earliest_departure=departure_seconds,
+ limit=max_journeys,
+ stop_cache=stop_cache,
+ osm_stop_cache=osm_stop_cache,
+ )
+ )
+ direct = sorted(direct, key=_journey_sort_key)[:max_journeys]
+ if max_transfers > 0:
+ # The arrival of the best direct journey is handed to the transfer search as latest_arrival.
+ direct_arrival = direct[0]["arrival_seconds"] if direct else None
+ transfer_journeys: list[dict] = []
+ for first_dataset_id, second_dataset_id in _journey_dataset_pairs(from_selection, to_selection):
+ first_service_ids = service_ids_by_dataset.get(first_dataset_id)
+ second_service_ids = service_ids_by_dataset.get(second_dataset_id)
+ if first_service_ids == set() or second_service_ids == set():
+ continue
+ transfer_journeys.extend(
+ _find_one_transfer_journeys(
+ db=db,
+ first_dataset_id=first_dataset_id,
+ second_dataset_id=second_dataset_id,
+ first_service_ids=first_service_ids,
+ second_service_ids=second_service_ids,
+ from_stop_ids=from_selection.stop_ids_by_dataset[first_dataset_id],
+ to_stop_ids=to_selection.stop_ids_by_dataset[second_dataset_id],
+ origin_canonical_stop_id=from_selection.canonical_stop_id,
+ target_canonical_stop_id=to_selection.canonical_stop_id,
+ earliest_departure=departure_seconds,
+ latest_arrival=direct_arrival,
+ transfer_seconds=max(0, transfer_seconds),
+ limit=max_journeys,
+ stop_cache=stop_cache,
+ osm_stop_cache=osm_stop_cache,
+ )
+ )
+ # Keep three times the requested amount so later filtering still has candidates to choose from.
+ transfer_journeys = sorted(
+ transfer_journeys,
+ key=_journey_sort_key,
+ )[: max_journeys * 3]
+ if max_transfers > 1:
+ # The round-based search receives the best arrival found so far as latest_arrival.
+ best_known_arrival = min(
+ (
+ int(journey["arrival_seconds"])
+ for journey in [*direct, *transfer_journeys]
+ if journey.get("arrival_seconds") is not None
+ ),
+ default=None,
+ )
+ round_journeys: list[dict] = []
+ # Round-based search only runs for datasets that contain both origin and destination.
+ for dataset_id in common_dataset_ids:
+ service_ids = service_ids_by_dataset.get(dataset_id)
+ if service_ids == set():
+ continue
+ round_journeys.extend(
+ _find_round_journeys(
+ db=db,
+ dataset_id=dataset_id,
+ service_ids=service_ids,
+ from_selection=from_selection,
+ to_selection=to_selection,
+ earliest_departure=departure_seconds,
+ max_transfers=max(0, max_transfers),
+ transfer_seconds=max(0, transfer_seconds),
+ latest_arrival=best_known_arrival,
+ limit=max_journeys,
+ stop_cache=stop_cache,
+ osm_stop_cache=osm_stop_cache,
+ )
+ )
+ transfer_journeys = sorted(
+ [*transfer_journeys, *round_journeys],
+ key=_journey_sort_key,
+ )[: max_journeys * 3]
+ else:
+ transfer_journeys = []
+ walk_journey = _find_walk_only_journey(
+ db,
+ from_selection=from_selection,
+ to_selection=to_selection,
+ departure_seconds=departure_seconds,
+ )
+ walk_journeys = [] if walk_journey is None else [walk_journey]
+ journeys = _filter_reasonable_journeys([*walk_journeys, *transfer_journeys, *direct])
+
+ # De-duplicate by leg signature; iterating in sort order with setdefault keeps
+ # the best-ranked journey for each signature.
+ unique: dict[tuple[str, ...], dict] = {}
+ for journey in sorted(journeys, key=_journey_sort_key):
+ key = tuple(_journey_leg_signature(leg) for leg in journey["legs"])
+ unique.setdefault(key, journey)
+
+ selected = _select_diverse_journeys(unique.values(), limit=max_journeys)
+ # Fallback: when nothing was found and transfers are allowed, try the access-transfer search.
+ if not selected and _allow_access_transfer and max_transfers > 0:
+ access_journeys = _find_access_transfer_journeys(
+ db=db,
+ from_selection=from_selection,
+ to_stop_id=to_stop_id,
+ earliest_departure=departure_seconds,
+ max_transfers=max_transfers,
+ transfer_seconds=max(0, transfer_seconds),
+ limit=max_journeys,
+ source_ids=source_ids,
+ service_date=parsed_service_date,
+ stop_cache=stop_cache,
+ osm_stop_cache=osm_stop_cache,
+ )
+ # NOTE(review): unlike the setdefault loop above, this dict comprehension keeps the
+ # LAST (worst-ranked) journey for a duplicated leg signature, at the position of
+ # the first one. Confirm whether keeping the first was intended.
+ selected = list(
+ {
+ tuple(_journey_leg_signature(leg) for leg in journey["legs"]): journey
+ for journey in sorted(access_journeys, key=_journey_sort_key)
+ }.values()
+ )[:max_journeys]
+ # Datasets actually used by the selected journeys' legs.
+ selected_dataset_ids = sorted(
+ {
+ int(leg["dataset_id"])
+ for journey in selected
+ for leg in journey.get("legs", [])
+ if leg.get("dataset_id") is not None
+ }
+ )
+ # With no selected journeys, the response reports the datasets that were searched instead.
+ searched_dataset_ids = sorted(set(from_selection.stop_ids_by_dataset) | set(to_selection.stop_ids_by_dataset))
+ source_payloads = _source_payloads_for_dataset_ids(db, selected_dataset_ids or searched_dataset_ids)
+ return {
+ "from": _stop_payload(from_selection.display),
+ "to": _stop_payload(to_selection.display),
+ # The singular "source" / "dataset_id" are only filled when exactly one applies.
+ "source": source_payloads[0] if len(source_payloads) == 1 else None,
+ "sources": source_payloads,
+ "dataset_id": selected_dataset_ids[0] if len(selected_dataset_ids) == 1 else None,
+ "dataset_ids": selected_dataset_ids or searched_dataset_ids,
+ "departure_time": format_gtfs_time(departure_seconds),
+ "departure_time_label": format_gtfs_time_label(departure_seconds),
+ "service_date": None if parsed_service_date is None else parsed_service_date.isoformat(),
+ "max_transfers": max(0, max_transfers),
+ "journeys": selected,
+ }
+
+
+def _find_journeys_with_address_access(
+ db: Session,
+ from_stop_id: int | str,
+ to_stop_id: int | str,
+ departure: str,
+ max_transfers: int,
+ transfer_seconds: int,
+ limit: int,
+ source_ids: list[int] | None,
+ service_date: str | date | None,
+) -> dict:
+ """Search journeys when the origin and/or destination may be a location token.
+
+ The search combines up to three parts: an optional access walk from the
+ origin location to a candidate stop, a stop-to-stop search delegated to
+ ``find_journeys``, and an optional egress walk from a candidate stop to
+ the destination location. A direct walk-only journey is also offered when
+ ``_walk_only_journey_between_summaries`` yields one.
+
+ Args:
+ db: Database session passed through to every helper.
+ from_stop_id: Origin token (stop or location).
+ to_stop_id: Destination token (stop or location).
+ departure: Departure time, ``HH:MM`` or ``HH:MM:SS``.
+ max_transfers: Forwarded to ``find_journeys``; reported clamped to >= 0.
+ transfer_seconds: Forwarded to ``find_journeys``.
+ limit: Requested number of journeys; clamped to the range 1..10.
+ source_ids: Optional source filter forwarded to the helpers.
+ service_date: Optional service date, parsed with ``parse_service_date``.
+
+ Returns:
+ A result payload with ``from``, ``to``, ``source``, ``sources``,
+ ``dataset_id``, ``dataset_ids``, departure/service-date fields,
+ ``max_transfers``, ``diagnostics`` and the selected ``journeys``.
+
+ Raises:
+ ValueError: If ``departure`` cannot be parsed by ``parse_gtfs_time``.
+ """
+ departure_seconds = parse_gtfs_time(departure)
+ if departure_seconds is None:
+ raise ValueError("departure must be HH:MM or HH:MM:SS")
+ parsed_service_date = parse_service_date(service_date)
+ active_dataset_ids = _active_gtfs_dataset_ids(db, source_ids=source_ids)
+ from_location = resolve_location_summary(db, from_stop_id, source_ids=source_ids)
+ to_location = resolve_location_summary(db, to_stop_id, source_ids=source_ids)
+ # Clamp the requested result count to 1..10.
+ max_journeys = max(1, min(limit, 10))
+
+ journeys: list[dict] = []
+ # Always try the walk-only option first; it is kept alongside transit results.
+ direct_walk = _walk_only_journey_between_summaries(
+ db,
+ from_location=from_location,
+ to_location=to_location,
+ departure_seconds=departure_seconds,
+ dataset_id=(active_dataset_ids[0] if active_dataset_ids else from_location.dataset_id),
+ route_geometry=True,
+ )
+ if direct_walk is not None:
+ journeys.append(direct_walk)
+
+ origin_is_address = is_location_token(from_stop_id)
+ destination_is_address = is_location_token(to_stop_id)
+ # True when exactly one endpoint is a location token and the direct walk is
+ # short enough; the transit search is then skipped entirely (see below).
+ short_direct_walk_only = (
+ direct_walk is not None
+ and origin_is_address != destination_is_address
+ and int(direct_walk.get("duration_seconds") or 0) <= ADDRESS_ACCESS_SHORT_DIRECT_WALK_SECONDS
+ )
+ # Distance between the two endpoints via _distance_m; 0 when any coordinate is missing.
+ access_distance_m = (
+ _distance_m(float(from_location.lat), float(from_location.lon), float(to_location.lat), float(to_location.lon))
+ if from_location.lat is not None
+ and from_location.lon is not None
+ and to_location.lat is not None
+ and to_location.lon is not None
+ else 0
+ )
+ # Major hubs are only added for location-to-location searches at or beyond
+ # the long-distance threshold.
+ include_major_hubs = (
+ origin_is_address
+ and destination_is_address
+ and access_distance_m >= ADDRESS_ACCESS_LONG_DISTANCE_HUB_THRESHOLD_M
+ )
+ origin_candidates = _location_stop_candidates(
+ db,
+ from_stop_id,
+ from_location,
+ active_dataset_ids,
+ source_ids=source_ids,
+ include_major_hubs=include_major_hubs,
+ )
+ destination_candidates = _location_stop_candidates(
+ db,
+ to_stop_id,
+ to_location,
+ active_dataset_ids,
+ source_ids=source_ids,
+ include_major_hubs=include_major_hubs,
+ )
+ if short_direct_walk_only:
+ # Discard the candidates so no transit search runs and the diagnostics report zero.
+ origin_candidates = []
+ destination_candidates = []
+ candidate_pairs = []
+ else:
+ candidate_pairs = _address_access_candidate_pairs(
+ origin_candidates,
+ destination_candidates,
+ origin_is_address=origin_is_address,
+ destination_is_address=destination_is_address,
+ max_pairs=ADDRESS_ACCESS_MAX_DEEP_PAIR_CANDIDATES if max_transfers > 1 else ADDRESS_ACCESS_MAX_PAIR_CANDIDATES,
+ )
+ # Per-origin caches keyed by candidate token: the access walk and the
+ # resulting transit departure are computed once per origin candidate.
+ access_leg_cache: dict[str, dict | None] = {}
+ transit_departure_cache: dict[str, int | None] = {}
+ for origin, destination in candidate_pairs:
+ access_leg = access_leg_cache.get(origin.token)
+ transit_departure_seconds = transit_departure_cache.get(origin.token)
+ if origin.token not in transit_departure_cache:
+ access_leg = None
+ transit_departure_seconds = departure_seconds
+ if origin_is_address:
+ access_leg = _walk_leg_between_summaries(
+ db,
+ from_stop=from_location,
+ to_stop=origin.selection.display,
+ departure_seconds=departure_seconds,
+ dataset_id=origin.selection.dataset_id,
+ max_duration_seconds=ADDRESS_ACCESS_MAX_SECONDS,
+ route_geometry=True,
+ )
+ if access_leg is None:
+ # No usable access walk: a cached None marks this origin as unusable.
+ transit_departure_seconds = None
+ else:
+ # Transit can start only once the access walk has arrived at the stop.
+ transit_departure_seconds = int(access_leg["arrival_seconds"])
+ access_leg_cache[origin.token] = access_leg
+ transit_departure_cache[origin.token] = transit_departure_seconds
+ if transit_departure_seconds is None:
+ continue
+ transit_departure = format_gtfs_time(transit_departure_seconds)
+ if transit_departure is None:
+ continue
+ try:
+ # NOTE(review): _allow_address_access=False presumably stops find_journeys
+ # from re-entering this function - confirm against find_journeys.
+ transit = find_journeys(
+ db=db,
+ from_stop_id=origin.token,
+ to_stop_id=destination.token,
+ departure=transit_departure,
+ max_transfers=max_transfers,
+ transfer_seconds=transfer_seconds,
+ limit=max(max_journeys, 6),
+ source_ids=source_ids,
+ service_date=parsed_service_date,
+ _allow_access_transfer=include_major_hubs,
+ _allow_address_access=False,
+ )
+ except ValueError:
+ # A failing stop pair is skipped; the remaining pairs are still searched.
+ continue
+ for transit_journey in transit.get("journeys", [])[: max_journeys * 2]:
+ egress_leg = None
+ if destination_is_address:
+ arrival_seconds = transit_journey.get("arrival_seconds")
+ if arrival_seconds is None:
+ continue
+ # The egress walk starts when the transit journey arrives.
+ egress_leg = _walk_leg_between_summaries(
+ db,
+ from_stop=destination.selection.display,
+ to_stop=to_location,
+ departure_seconds=int(arrival_seconds),
+ dataset_id=destination.selection.dataset_id,
+ max_duration_seconds=ADDRESS_ACCESS_MAX_SECONDS,
+ route_geometry=True,
+ )
+ if egress_leg is None:
+ continue
+ combined = _compose_address_access_journey(
+ transit_journey,
+ access_leg=access_leg,
+ egress_leg=egress_leg,
+ )
+ if combined is not None:
+ journeys.append(combined)
+ # Early exit applies only to hub searches: stop once enough journeys exist.
+ if include_major_hubs and len(journeys) >= max_journeys:
+ break
+ if include_major_hubs and len(journeys) >= max_journeys:
+ break
+
+ # Deduplicate by leg signature, keeping the best-ranked journey per signature.
+ unique: dict[tuple[str, ...], dict] = {}
+ for journey in sorted(_filter_reasonable_journeys(journeys), key=_journey_sort_key):
+ key = tuple(_journey_leg_signature(leg) for leg in journey["legs"])
+ unique.setdefault(key, journey)
+ selected = _select_diverse_journeys(unique.values(), limit=max_journeys)
+ # Dataset ids actually used by the selected journeys' legs.
+ selected_dataset_ids = sorted(
+ {
+ int(leg["dataset_id"])
+ for journey in selected
+ for leg in journey.get("legs", [])
+ if leg.get("dataset_id") is not None
+ }
+ )
+ searched_dataset_ids = sorted(active_dataset_ids)
+ # Fall back to the searched datasets when no journey was selected.
+ source_payloads = _source_payloads_for_dataset_ids(db, selected_dataset_ids or searched_dataset_ids)
+ diagnostics = {
+ "address_access": {
+ "origin_candidates": len(origin_candidates),
+ "destination_candidates": len(destination_candidates),
+ "searched_pairs": len(candidate_pairs),
+ "max_pairs": ADDRESS_ACCESS_MAX_DEEP_PAIR_CANDIDATES if max_transfers > 1 else ADDRESS_ACCESS_MAX_PAIR_CANDIDATES,
+ "major_hubs": include_major_hubs,
+ }
+ }
+ return {
+ "from": _stop_payload(from_location),
+ "to": _stop_payload(to_location),
+ "source": source_payloads[0] if len(source_payloads) == 1 else None,
+ "sources": source_payloads,
+ "dataset_id": selected_dataset_ids[0] if len(selected_dataset_ids) == 1 else None,
+ "dataset_ids": selected_dataset_ids or searched_dataset_ids,
+ "departure_time": format_gtfs_time(departure_seconds),
+ "departure_time_label": format_gtfs_time_label(departure_seconds),
+ "service_date": None if parsed_service_date is None else parsed_service_date.isoformat(),
+ "max_transfers": max(0, max_transfers),
+ "diagnostics": diagnostics,
+ "journeys": selected,
+ }
+
+
+def _address_access_candidate_pairs(
+    origins: list[_AccessStopCandidate],
+    destinations: list[_AccessStopCandidate],
+    *,
+    origin_is_address: bool,
+    destination_is_address: bool,
+    max_pairs: int,
+) -> list[tuple[_AccessStopCandidate, _AccessStopCandidate]]:
+    """Order the origin/destination stop pairs that should be searched.
+
+    Every pair is scored by the walking distance on its location-token sides
+    (the stop side contributes 0). Unless both endpoints are location tokens,
+    all pairs are returned in score order and ``max_pairs`` is not applied.
+    For location-to-location searches at most ``max_pairs`` pairs are kept,
+    filled in this order: priority-destination pairs up to their budget, the
+    closest pairs, remaining priority pairs, then remaining pairs by score.
+    """
+    scored: list[tuple[float, _AccessStopCandidate, _AccessStopCandidate]] = []
+    for origin in origins:
+        for destination in destinations:
+            walk_m = (origin.distance_m if origin_is_address else 0) + (
+                destination.distance_m if destination_is_address else 0
+            )
+            scored.append((walk_m, origin, destination))
+
+    def by_walk_distance(entry):
+        walk_m, origin, destination = entry
+        return (walk_m, origin.distance_m, destination.distance_m, origin.token, destination.token)
+
+    def by_destination_priority(entry):
+        walk_m, origin, destination = entry
+        return (
+            destination.priority,
+            origin.distance_m,
+            destination.distance_m,
+            walk_m,
+            origin.token,
+            destination.token,
+        )
+
+    scored.sort(key=by_walk_distance)
+    if not (origin_is_address and destination_is_address):
+        return [(origin, destination) for _, origin, destination in scored]
+
+    # Pairs whose destination outranks the normal priority, best priority first.
+    prioritized = sorted(
+        (entry for entry in scored if entry[2].priority < ADDRESS_ACCESS_NORMAL_PRIORITY),
+        key=by_destination_priority,
+    )
+    nearest_quota = max(2, max_pairs // 2)
+    priority_quota = max(0, max_pairs - nearest_quota)
+
+    chosen: list[tuple[float, _AccessStopCandidate, _AccessStopCandidate]] = []
+    taken: set[tuple[str, str]] = set()
+
+    def take(entries, stop_at):
+        # Append unseen pairs; the size check runs after every entry, including duplicates.
+        for entry in entries:
+            pair_key = (entry[1].token, entry[2].token)
+            if pair_key not in taken:
+                taken.add(pair_key)
+                chosen.append(entry)
+            if stop_at is not None and len(chosen) >= stop_at:
+                break
+
+    if priority_quota > 0:
+        take(prioritized, priority_quota)
+    take(scored[:nearest_quota], None)
+    take(prioritized, max_pairs)
+    take(scored, max_pairs)
+
+    return [(origin, destination) for _, origin, destination in chosen[:max_pairs]]
+
+
+def _location_stop_candidates(
+    db: Session,
+    token: int | str,
+    location: StopSummary,
+    active_dataset_ids: list[int],
+    *,
+    source_ids: list[int] | None,
+    include_major_hubs: bool = False,
+) -> list[_AccessStopCandidate]:
+    """Return the stop candidates through which ``token`` can enter the network.
+
+    A non-location token resolves to exactly one candidate at distance 0.
+    A location token yields the nearby canonical stops (one candidate per
+    canonical stop, capped by ``ADDRESS_ACCESS_STOP_CANDIDATES``), optionally
+    merged with major-hub candidates. The result is empty when the location
+    has no coordinates or there are no active datasets.
+    """
+    if is_location_token(token):
+        if location.lon is None or location.lat is None or not active_dataset_ids:
+            return []
+        if settings.is_postgresql_database:
+            nearby = _nearby_canonical_stops_postgresql(db, location, active_dataset_ids)
+        else:
+            nearby = _nearby_canonical_stops_sqlite(db, location, active_dataset_ids)
+
+        found: list[_AccessStopCandidate] = []
+        visited: set[int] = set()
+        for raw_stop_id, raw_dataset_id, raw_distance in nearby:
+            stop_key = int(raw_stop_id)
+            if stop_key in visited:
+                continue
+            visited.add(stop_key)
+            try:
+                resolved = _selection_for_canonical_stop(
+                    db,
+                    stop_key,
+                    dataset_ids=active_dataset_ids,
+                    preferred_dataset_id=int(raw_dataset_id),
+                )
+            except ValueError:
+                # Stops that cannot be resolved are skipped.
+                continue
+            found.append(
+                _AccessStopCandidate(
+                    token=_stop_place_token(stop_key, resolved.dataset_id),
+                    selection=resolved,
+                    distance_m=float(raw_distance or 0),
+                )
+            )
+            if len(found) >= ADDRESS_ACCESS_STOP_CANDIDATES:
+                break
+
+        if not include_major_hubs:
+            return found
+        hubs = _location_major_hub_stop_candidates(db, token, location, active_dataset_ids)
+        return _merge_access_stop_candidates(found, hubs)
+
+    # Plain stop token: a single zero-distance candidate.
+    resolved = _resolve_stop_selection(db, token, source_ids=source_ids)
+    if resolved.canonical_stop_id is None:
+        stop_token = str(token)
+    else:
+        stop_token = _stop_place_token(resolved.canonical_stop_id, resolved.dataset_id)
+    return [_AccessStopCandidate(token=stop_token, selection=resolved, distance_m=0)]
+
+
+def _merge_access_stop_candidates(
+    primary: list[_AccessStopCandidate],
+    extra: list[_AccessStopCandidate],
+) -> list[_AccessStopCandidate]:
+    """Append ``extra`` candidates to ``primary``, skipping duplicates.
+
+    A candidate is a duplicate when its token, or its non-None canonical stop
+    id, is already present. ``primary`` is copied, not mutated, and its
+    order is preserved.
+    """
+    combined = [*primary]
+    known_tokens = set()
+    known_stop_ids = set()
+    for existing in combined:
+        known_tokens.add(existing.token)
+        if existing.selection.canonical_stop_id is not None:
+            known_stop_ids.add(existing.selection.canonical_stop_id)
+
+    for addition in extra:
+        stop_id = addition.selection.canonical_stop_id
+        if addition.token in known_tokens:
+            continue
+        if stop_id is not None and stop_id in known_stop_ids:
+            continue
+        combined.append(addition)
+        known_tokens.add(addition.token)
+        if stop_id is not None:
+            known_stop_ids.add(stop_id)
+    return combined
+
+
+def _location_major_hub_stop_candidates(
+    db: Session,
+    token: int | str,
+    location: StopSummary,
+    active_dataset_ids: list[int],
+) -> list[_AccessStopCandidate]:
+    """Return major-hub stop candidates around ``location``.
+
+    The hub rows come from the PostgreSQL or SQLite query helper, depending
+    on the configured database. Each canonical stop appears at most once, the
+    list is capped by ``ADDRESS_ACCESS_MAJOR_HUB_CANDIDATES``, and every
+    candidate carries ``ADDRESS_ACCESS_MAJOR_HUB_PRIORITY``. The result is
+    empty when the location has no coordinates or no dataset is active.
+    """
+    has_position = location.lon is not None and location.lat is not None
+    if not (has_position and active_dataset_ids):
+        return []
+
+    locality = _address_city_for_token(db, token)
+    if settings.is_postgresql_database:
+        hub_rows = _major_hub_canonical_stops_postgresql(db, location, active_dataset_ids, locality=locality)
+    else:
+        hub_rows = _major_hub_canonical_stops_sqlite(db, location, active_dataset_ids, locality=locality)
+
+    hubs: list[_AccessStopCandidate] = []
+    visited: set[int] = set()
+    for raw_stop_id, raw_dataset_id, raw_distance in hub_rows:
+        stop_key = int(raw_stop_id)
+        if stop_key in visited:
+            continue
+        visited.add(stop_key)
+        try:
+            resolved = _selection_for_canonical_stop(
+                db,
+                stop_key,
+                dataset_ids=active_dataset_ids,
+                preferred_dataset_id=int(raw_dataset_id),
+            )
+        except ValueError:
+            # Hubs that cannot be resolved are skipped.
+            continue
+        hubs.append(
+            _AccessStopCandidate(
+                token=_stop_place_token(stop_key, resolved.dataset_id),
+                selection=resolved,
+                distance_m=float(raw_distance or 0),
+                priority=ADDRESS_ACCESS_MAJOR_HUB_PRIORITY,
+            )
+        )
+        if len(hubs) >= ADDRESS_ACCESS_MAJOR_HUB_CANDIDATES:
+            break
+    return hubs
+
+
+def _address_city_for_token(db: Session, token: int | str) -> str:
+    """Return the normalized city of the address behind ``token``.
+
+    Yields ``""`` for coordinate tokens, for tokens that are neither
+    address-point nor address tokens, and when the lookup raises ValueError.
+    """
+    try:
+        if is_coordinate_token(token):
+            return ""
+        is_point = is_address_point_token(token)
+        if not is_point and not is_address_token(token):
+            return ""
+        if is_point:
+            record, _unused_first, _unused_second = address_point_by_token(db, token)
+        else:
+            record = address_by_token(db, token)
+    except ValueError:
+        return ""
+    else:
+        city = record.city or ""
+    return _normalize_stop_search(city)
+
+
+def _is_major_station_name(value: str | None) -> bool:
+    """Tell whether a stop name looks like a main railway station.
+
+    The name is normalized with ``_normalize_stop_search`` and then matched
+    against a delimited ``hbf`` token or one of the long-form markers.
+    """
+    name = _normalize_stop_search(value or "")
+    if re.search(r"(^|[\s,(/-])hbf($|[\s,)/-])", name):
+        return True
+    for marker in ("hauptbahnhof", "central station", "central train station"):
+        if marker in name:
+            return True
+    return False
+
+
+def _major_hub_canonical_stops_postgresql(
+    db: Session,
+    location: StopSummary,
+    active_dataset_ids: list[int],
+    *,
+    locality: str,
+) -> list[tuple[int, int, float]]:
+    """Query major-hub canonical stops around ``location`` on PostgreSQL/PostGIS.
+
+    Stops must be linked to one of ``active_dataset_ids``, lie within
+    ``ADDRESS_ACCESS_MAJOR_HUB_RADIUS_M`` and carry a main-station style name.
+    Rows are ordered by locality match (stop name contains ``locality``),
+    then distance, then canonical stop id.
+
+    Returns:
+        ``(canonical_stop_id, dataset_id, distance_m)`` tuples, at most
+        ``ADDRESS_ACCESS_MAJOR_HUB_CANDIDATES * 6`` of them.
+    """
+    lat = float(location.lat)
+    radius_deg = ADDRESS_ACCESS_MAJOR_HUB_RADIUS_M / 111_320
+    # A degree of longitude spans only cos(latitude) times the metres of a
+    # degree of latitude, so the bounding-box prefilter must be wider in x.
+    # Using radius_deg for both axes dropped hubs that are within the radius
+    # to the east or west. The cosine is floored so the box stays finite.
+    radius_lon_deg = radius_deg / max(math.cos(math.radians(lat)), 0.01)
+    stmt = text(
+        """
+        WITH point AS (
+            SELECT ST_SetSRID(ST_MakePoint(:lon, :lat), 4326) AS geom
+        ),
+        hub_rows AS (
+            SELECT
+                canonical_stops.id AS canonical_stop_id,
+                canonical_stop_links.dataset_id AS dataset_id,
+                ST_DistanceSphere(canonical_stops.geom, point.geom) AS distance_m,
+                MIN(
+                    CASE
+                        WHEN :locality = '' THEN 1
+                        WHEN LOWER(COALESCE(canonical_stops.name, '')) LIKE :locality_pattern THEN 0
+                        ELSE 1
+                    END
+                ) AS locality_rank
+            FROM canonical_stops
+            JOIN canonical_stop_links
+                ON canonical_stop_links.canonical_stop_id = canonical_stops.id
+                AND canonical_stop_links.object_type = 'gtfs_stop'
+            JOIN datasets
+                ON datasets.id = canonical_stop_links.dataset_id
+                AND datasets.kind = 'gtfs'
+                AND datasets.is_active IS TRUE
+            CROSS JOIN point
+            WHERE canonical_stop_links.dataset_id IN :dataset_ids
+                AND canonical_stops.geom IS NOT NULL
+                AND canonical_stops.geom && ST_Expand(point.geom, :radius_lon_deg, :radius_deg)
+                AND ST_DWithin(canonical_stops.geom::geography, point.geom::geography, :radius_m)
+                AND (
+                    LOWER(COALESCE(canonical_stops.name, '')) ~ '(^|[[:space:],(/-])hbf($|[[:space:],)/-])'
+                    OR LOWER(COALESCE(canonical_stops.name, '')) LIKE '%hauptbahnhof%'
+                    OR LOWER(COALESCE(canonical_stops.name, '')) LIKE '%central station%'
+                    OR LOWER(COALESCE(canonical_stops.name, '')) LIKE '%central train station%'
+                )
+            GROUP BY canonical_stops.id, canonical_stop_links.dataset_id, canonical_stops.geom, point.geom
+        )
+        SELECT canonical_stop_id, dataset_id, distance_m
+        FROM hub_rows
+        ORDER BY locality_rank, distance_m, canonical_stop_id
+        LIMIT :limit
+        """
+    ).bindparams(bindparam("dataset_ids", expanding=True))
+    rows = db.execute(
+        stmt,
+        {
+            "lon": float(location.lon),
+            "lat": lat,
+            "dataset_ids": tuple(active_dataset_ids),
+            "radius_deg": radius_deg,
+            "radius_lon_deg": radius_lon_deg,
+            "radius_m": ADDRESS_ACCESS_MAJOR_HUB_RADIUS_M,
+            "locality": locality,
+            "locality_pattern": f"%{locality}%" if locality else "",
+            "limit": ADDRESS_ACCESS_MAJOR_HUB_CANDIDATES * 6,
+        },
+    ).all()
+    return [(int(row[0]), int(row[1]), float(row[2] or 0)) for row in rows]
+
+
+def _major_hub_canonical_stops_sqlite(
+    db: Session,
+    location: StopSummary,
+    active_dataset_ids: list[int],
+    *,
+    locality: str,
+) -> list[tuple[int, int, float]]:
+    """Find major-hub canonical stops around ``location`` without PostGIS.
+
+    Rows are prefiltered with a latitude/longitude bounding box in SQL, then
+    filtered in Python by station name and by the ``_distance_m`` distance
+    against ``ADDRESS_ACCESS_MAJOR_HUB_RADIUS_M``. Each canonical stop is
+    reported once, ordered by locality match, distance and id.
+
+    Returns:
+        ``(canonical_stop_id, dataset_id, distance_m)`` tuples.
+    """
+    lon = float(location.lon)
+    lat = float(location.lat)
+    radius_deg = ADDRESS_ACCESS_MAJOR_HUB_RADIUS_M / 111_320
+    # A degree of longitude covers only cos(latitude) times the metres of a
+    # degree of latitude. Reusing radius_deg for longitude made the box too
+    # narrow east/west and dropped hubs that are inside the radius. The
+    # cosine is floored so the box stays finite near the poles.
+    lon_radius_deg = radius_deg / max(math.cos(math.radians(lat)), 0.01)
+    rows = db.execute(
+        select(
+            CanonicalStop.id,
+            CanonicalStopLink.dataset_id,
+            CanonicalStop.name,
+            CanonicalStop.lat,
+            CanonicalStop.lon,
+        )
+        .join(CanonicalStopLink, CanonicalStopLink.canonical_stop_id == CanonicalStop.id)
+        .where(
+            CanonicalStopLink.object_type == "gtfs_stop",
+            CanonicalStopLink.dataset_id.in_(active_dataset_ids),
+            CanonicalStop.lat.is_not(None),
+            CanonicalStop.lon.is_not(None),
+            CanonicalStop.lat >= lat - radius_deg,
+            CanonicalStop.lat <= lat + radius_deg,
+            CanonicalStop.lon >= lon - lon_radius_deg,
+            CanonicalStop.lon <= lon + lon_radius_deg,
+        )
+        .limit(ADDRESS_ACCESS_MAJOR_HUB_CANDIDATES * 100)
+    ).all()
+    result: list[tuple[int, int, float, int]] = []
+    seen: set[int] = set()
+    for canonical_stop_id, dataset_id, canonical_name, stop_lat, stop_lon in rows:
+        if not _is_major_station_name(canonical_name):
+            continue
+        # The bounding box is only a prefilter; enforce the real radius here.
+        distance_m = _distance_m(lat, lon, float(stop_lat), float(stop_lon))
+        if distance_m > ADDRESS_ACCESS_MAJOR_HUB_RADIUS_M:
+            continue
+        locality_rank = (
+            0
+            if locality
+            and locality in _normalize_stop_search(canonical_name or "")
+            else 1
+        )
+        # The first link row seen for a canonical stop decides its dataset id.
+        if int(canonical_stop_id) in seen:
+            continue
+        seen.add(int(canonical_stop_id))
+        result.append((int(canonical_stop_id), int(dataset_id), distance_m, locality_rank))
+    result.sort(key=lambda item: (item[3], item[2], item[0]))
+    return [(canonical_stop_id, dataset_id, distance_m) for canonical_stop_id, dataset_id, distance_m, _ in result]
+
+
+def _nearby_canonical_stops_postgresql(
+    db: Session,
+    location: StopSummary,
+    active_dataset_ids: list[int],
+) -> list[tuple[int, int, float]]:
+    """Query canonical stops within ``ADDRESS_ACCESS_RADIUS_M`` on PostgreSQL/PostGIS.
+
+    Only stops linked as ``gtfs_stop`` to an active GTFS dataset listed in
+    ``active_dataset_ids`` are considered. Rows are ordered by geometry
+    distance, then canonical stop id.
+
+    Returns:
+        ``(canonical_stop_id, dataset_id, distance_m)`` tuples, at most
+        ``ADDRESS_ACCESS_STOP_CANDIDATES * 8`` of them.
+    """
+    lat = float(location.lat)
+    radius_deg = ADDRESS_ACCESS_RADIUS_M / 111_320
+    # A degree of longitude spans only cos(latitude) times the metres of a
+    # degree of latitude. Expanding the box by radius_deg on both axes cut
+    # off stops within the radius to the east or west, so the x expansion is
+    # widened. The cosine is floored so the box stays finite near the poles.
+    radius_lon_deg = radius_deg / max(math.cos(math.radians(lat)), 0.01)
+    stmt = text(
+        """
+        WITH point AS (
+            SELECT ST_SetSRID(ST_MakePoint(:lon, :lat), 4326) AS geom
+        )
+        SELECT
+            canonical_stops.id AS canonical_stop_id,
+            canonical_stop_links.dataset_id AS dataset_id,
+            ST_DistanceSphere(canonical_stops.geom, point.geom) AS distance_m
+        FROM canonical_stops
+        JOIN canonical_stop_links
+            ON canonical_stop_links.canonical_stop_id = canonical_stops.id
+            AND canonical_stop_links.object_type = 'gtfs_stop'
+        JOIN datasets
+            ON datasets.id = canonical_stop_links.dataset_id
+            AND datasets.kind = 'gtfs'
+            AND datasets.is_active IS TRUE
+        CROSS JOIN point
+        WHERE canonical_stop_links.dataset_id IN :dataset_ids
+            AND canonical_stops.geom IS NOT NULL
+            AND canonical_stops.geom && ST_Expand(point.geom, :radius_lon_deg, :radius_deg)
+            AND ST_DWithin(canonical_stops.geom::geography, point.geom::geography, :radius_m)
+        GROUP BY canonical_stops.id, canonical_stop_links.dataset_id, canonical_stops.geom, point.geom
+        ORDER BY canonical_stops.geom <-> point.geom, canonical_stops.id
+        LIMIT :limit
+        """
+    ).bindparams(bindparam("dataset_ids", expanding=True))
+    rows = db.execute(
+        stmt,
+        {
+            "lon": float(location.lon),
+            "lat": lat,
+            "dataset_ids": tuple(active_dataset_ids),
+            "radius_deg": radius_deg,
+            "radius_lon_deg": radius_lon_deg,
+            "radius_m": ADDRESS_ACCESS_RADIUS_M,
+            "limit": ADDRESS_ACCESS_STOP_CANDIDATES * 8,
+        },
+    ).all()
+    return [(int(row[0]), int(row[1]), float(row[2] or 0)) for row in rows]
+
+
+def _nearby_canonical_stops_sqlite(
+    db: Session,
+    location: StopSummary,
+    active_dataset_ids: list[int],
+) -> list[tuple[int, int, float]]:
+    """Find canonical stops within ``ADDRESS_ACCESS_RADIUS_M`` without PostGIS.
+
+    SQL picks the ``ADDRESS_ACCESS_STOP_CANDIDATES * 8`` closest link rows by
+    an approximate squared distance in degrees; Python then applies the
+    ``_distance_m`` radius filter and sorts by that distance.
+
+    Returns:
+        ``(canonical_stop_id, dataset_id, distance_m)`` tuples, nearest first.
+    """
+    lon = float(location.lon)
+    lat = float(location.lat)
+    # Scale the longitude difference by cos(latitude) so the SQL ordering
+    # approximates real distance. Without it east/west stops look farther
+    # than they are and can be cut off by the LIMIT before the exact filter.
+    lon_scale = math.cos(math.radians(lat))
+    lon_delta = (CanonicalStop.lon - lon) * lon_scale
+    lat_delta = CanonicalStop.lat - lat
+    distance_expr = lon_delta * lon_delta + lat_delta * lat_delta
+    rows = db.execute(
+        select(CanonicalStop.id, CanonicalStopLink.dataset_id, CanonicalStop.lat, CanonicalStop.lon)
+        .join(CanonicalStopLink, CanonicalStopLink.canonical_stop_id == CanonicalStop.id)
+        .where(
+            CanonicalStopLink.object_type == "gtfs_stop",
+            CanonicalStopLink.dataset_id.in_(active_dataset_ids),
+            CanonicalStop.lat.is_not(None),
+            CanonicalStop.lon.is_not(None),
+        )
+        # The id tie-break makes the truncated candidate set deterministic.
+        .order_by(distance_expr, CanonicalStop.id)
+        .limit(ADDRESS_ACCESS_STOP_CANDIDATES * 8)
+    ).all()
+    result = []
+    for canonical_stop_id, dataset_id, stop_lat, stop_lon in rows:
+        distance_m = _distance_m(lat, lon, float(stop_lat), float(stop_lon))
+        if distance_m <= ADDRESS_ACCESS_RADIUS_M:
+            result.append((int(canonical_stop_id), int(dataset_id), distance_m))
+    result.sort(key=lambda item: item[2])
+    return result
+
+
+def _walk_only_journey_between_summaries(
+    db: Session,
+    *,
+    from_location: StopSummary,
+    to_location: StopSummary,
+    departure_seconds: int,
+    dataset_id: int,
+    route_geometry: bool = True,
+) -> dict | None:
+    """Build a single-leg "Walk only" journey between two locations.
+
+    Returns None when a coordinate is missing, when the ``_distance_m``
+    distance exceeds ``PUBLIC_TRANSPORT_WALK_OPTION_MAX_SECONDS * 1.35``,
+    or when ``_walk_leg_between_summaries`` produces no leg.
+    """
+    coordinates = (from_location.lon, from_location.lat, to_location.lon, to_location.lat)
+    if any(value is None for value in coordinates):
+        return None
+
+    crow_flies_m = _distance_m(
+        float(from_location.lat),
+        float(from_location.lon),
+        float(to_location.lat),
+        float(to_location.lon),
+    )
+    # Cheap rejection before the walk leg is built.
+    if crow_flies_m > PUBLIC_TRANSPORT_WALK_OPTION_MAX_SECONDS * 1.35:
+        return None
+
+    walk_leg = _walk_leg_between_summaries(
+        db,
+        from_stop=from_location,
+        to_stop=to_location,
+        departure_seconds=departure_seconds,
+        dataset_id=dataset_id,
+        max_duration_seconds=PUBLIC_TRANSPORT_WALK_OPTION_MAX_SECONDS,
+        route_geometry=route_geometry,
+    )
+    if walk_leg is not None:
+        walk_leg["route_name"] = "Walk only"
+        return _journey_payload([walk_leg])
+    return None
+
+
+def _walk_leg_between_summaries(
+    db: Session,
+    *,
+    from_stop: StopSummary,
+    to_stop: StopSummary,
+    departure_seconds: int,
+    dataset_id: int,
+    max_duration_seconds: int,
+    route_geometry: bool = True,
+) -> dict | None:
+    """Build a walk leg payload between two summaries, or None if too long.
+
+    The duration is first estimated from the ``_distance_m`` distance at
+    1.35 per second; estimates above 1.5 times ``max_duration_seconds`` are
+    rejected before the leg is built. The built leg is rejected when its own
+    ``duration_seconds`` exceeds ``max_duration_seconds``. Missing
+    coordinates also yield None.
+    """
+    endpoints = (from_stop.lon, from_stop.lat, to_stop.lon, to_stop.lat)
+    if any(value is None for value in endpoints):
+        return None
+
+    straight_m = _distance_m(float(from_stop.lat), float(from_stop.lon), float(to_stop.lat), float(to_stop.lon))
+    walk_seconds = int(math.ceil(straight_m / 1.35))
+    if walk_seconds > max_duration_seconds * 1.5:
+        return None
+
+    backlink = _RouterWalkBacklink(
+        previous_label=_RouterLabel(canonical_stop_id=0, arrival_seconds=departure_seconds),
+        from_stop=from_stop,
+        to_stop=to_stop,
+        distance_m=straight_m,
+        departure_seconds=departure_seconds,
+        arrival_seconds=departure_seconds + walk_seconds,
+    )
+    leg = _walk_leg_payload(db, backlink, dataset_id, route_geometry=route_geometry)
+    reported_seconds = int(leg.get("duration_seconds") or 0)
+    return leg if reported_seconds <= max_duration_seconds else None
+
+
+def _compose_address_access_journey(
+    journey: dict,
+    *,
+    access_leg: dict | None,
+    egress_leg: dict | None,
+) -> dict | None:
+    """Wrap a transit journey with optional access and egress walk legs.
+
+    Legs and map features are concatenated in travel order, and each
+    feature's integer ``leg`` index is shifted to its position in the
+    combined journey. Returns None when there are no legs at all or when the
+    overall departure or arrival time is unknown.
+    """
+    core_legs = journey.get("legs") or []
+    legs: list[dict] = []
+    map_features: list[dict] = []
+
+    if access_leg is not None:
+        legs.append(_leg_public_payload(access_leg))
+        map_features += _offset_feature_legs(_feature_items(_journey_payload([access_leg])), 0)
+
+    # The transit legs start after the access leg, if there is one.
+    transit_offset = len(legs)
+    legs += core_legs
+    map_features += _offset_feature_legs(_feature_items(journey), transit_offset)
+
+    if egress_leg is not None:
+        egress_offset = len(legs)
+        legs.append(_leg_public_payload(egress_leg))
+        map_features += _offset_feature_legs(_feature_items(_journey_payload([egress_leg])), egress_offset)
+
+    if not legs:
+        return None
+
+    if access_leg is None:
+        start = journey.get("departure_seconds")
+    else:
+        start = access_leg["departure_seconds"]
+    if egress_leg is None:
+        end = journey.get("arrival_seconds")
+    else:
+        end = egress_leg["arrival_seconds"]
+    if start is None or end is None:
+        return None
+
+    # Only non-walk legs count towards transfers.
+    ridden = sum(1 for leg in legs if leg.get("mode") != "walk")
+    start_seconds = int(start)
+    end_seconds = int(end)
+    elapsed = max(0, end_seconds - start_seconds)
+    return {
+        "transfers": max(0, ridden - 1),
+        "departure_seconds": start_seconds,
+        "arrival_seconds": end_seconds,
+        "departure_time": format_gtfs_time(start_seconds),
+        "arrival_time": format_gtfs_time(end_seconds),
+        "departure_time_label": format_gtfs_time_label(start_seconds),
+        "arrival_time_label": format_gtfs_time_label(end_seconds),
+        "duration_seconds": elapsed,
+        "duration_minutes": duration_minutes_ceil(elapsed),
+        "duration_label": format_duration_label(elapsed),
+        "legs": legs,
+        "features": feature_collection(map_features),
+    }
+
+
+def _feature_items(payload: dict) -> list[dict]:
+    """Return a shallow copy of ``payload["features"]["features"]``.
+
+    Any deviation from the expected shape (payload or collection is not a
+    dict, items are not a list) yields an empty list.
+    """
+    if not isinstance(payload, dict):
+        return []
+    collection = payload.get("features")
+    if not isinstance(collection, dict):
+        return []
+    entries = collection.get("features")
+    if not isinstance(entries, list):
+        return []
+    return list(entries)
+
+
+def _offset_feature_legs(features: list[dict], offset: int) -> list[dict]:
+    """Return a deep copy of ``features`` with integer ``leg`` indexes shifted.
+
+    The copy is made through a JSON round trip. When ``offset`` is positive,
+    every feature whose ``properties["leg"]`` is an int gets ``offset`` added;
+    otherwise the copy is returned unchanged. The input is never mutated.
+    """
+    clones = json.loads(json.dumps(features))
+    if offset > 0:
+        for clone in clones:
+            if not isinstance(clone, dict):
+                continue
+            properties = clone.get("properties")
+            if isinstance(properties, dict) and isinstance(properties.get("leg"), int):
+                properties["leg"] = int(properties["leg"]) + offset
+    return clones
+
+
+def _select_diverse_journeys(journeys, *, limit: int) -> list[dict]:
+    """Pick up to ``limit`` journeys, preferring variety over near-duplicates.
+
+    Journeys are shallow-copied and ranked with ``_journey_sort_key``. Exact
+    duplicates (same leg signatures) are dropped. Once ``max(3, limit // 2)``
+    journeys are selected, further journeys sharing a diversity key
+    (``_journey_diversity_key``) with an earlier one are skipped. The result
+    is finally passed through ``_ensure_walk_only_option``.
+
+    Args:
+        journeys: Iterable of journey dicts.
+        limit: Maximum number of journeys to select.
+
+    Returns:
+        The selected journeys in ranked order, as adjusted by
+        ``_ensure_walk_only_option``.
+    """
+
+    def exact_signature(journey: dict) -> str:
+        return "||".join(_journey_leg_signature(leg) for leg in journey.get("legs") or [])
+
+    ranked = sorted((dict(journey) for journey in journeys), key=_journey_sort_key)
+    selected: list[dict] = []
+    seen_exact: set[str] = set()
+    seen_diversity: set[tuple[object, ...]] = set()
+    diversity_floor = max(3, limit // 2)
+    for journey in ranked:
+        exact = exact_signature(journey)
+        if exact in seen_exact:
+            continue
+        seen_exact.add(exact)
+        diversity_key = _journey_diversity_key(journey)
+        if diversity_key in seen_diversity and len(selected) >= diversity_floor:
+            continue
+        seen_diversity.add(diversity_key)
+        selected.append(journey)
+        if len(selected) >= limit:
+            break
+
+    # Top up to a minimum of min(limit, 3) with journeys not yet selected.
+    # The selected signatures are tracked in one set rather than being
+    # rebuilt from scratch for every ranked journey.
+    minimum = min(limit, 3)
+    if len(selected) < minimum:
+        selected_signatures = {exact_signature(existing) for existing in selected}
+        for journey in ranked:
+            exact = exact_signature(journey)
+            if exact in selected_signatures:
+                continue
+            selected.append(journey)
+            selected_signatures.add(exact)
+            if len(selected) >= minimum:
+                break
+    return _ensure_walk_only_option(selected, ranked, limit=limit)
+
+
+def _ensure_walk_only_option(selected: list[dict], ranked: list[dict], *, limit: int) -> list[dict]:
+    """Make sure a walk-only journey from ``ranked`` appears in the selection.
+
+    If ``selected`` already has a walk-only journey, or ``ranked`` has none,
+    ``selected`` is returned as is. Otherwise the best-ranked walk-only
+    journey is appended in a new list when there is room below ``limit``, or
+    replaces the last selected journey in place.
+    """
+    for journey in selected:
+        if _journey_is_walk_only(journey):
+            return selected
+
+    walk = None
+    for journey in ranked:
+        if _journey_is_walk_only(journey):
+            walk = journey
+            break
+    if walk is None:
+        return selected
+
+    if len(selected) < limit:
+        return selected + [walk]
+    if selected:
+        # No room left: overwrite the final entry of the caller's list.
+        selected[-1] = walk
+    return selected
+
+
+def _journey_is_walk_only(journey: dict) -> bool:
+    """Return True when the journey has legs and every leg has mode ``walk``."""
+    legs = journey.get("legs") or []
+    if not legs:
+        return False
+    for leg in legs:
+        if leg.get("mode") != "walk":
+            return False
+    return True
+
+
+def _journey_diversity_key(journey: dict) -> tuple[object, ...]:
+    """Return the key used to recognise near-duplicate journeys.
+
+    The key is ``(transfers, route signature, time band)``. The route
+    signature lists, for each non-walk leg, the first truthy value of
+    ``route_ref``, ``route_id`` and ``mode`` as a string. The time band is
+    the departure in whole 30-minute blocks, or None without a departure.
+    """
+    routes = []
+    for leg in journey.get("legs") or []:
+        if leg.get("mode") == "walk":
+            continue
+        routes.append(str(leg.get("route_ref") or leg.get("route_id") or leg.get("mode") or ""))
+
+    departure = journey.get("departure_seconds")
+    if departure is None:
+        half_hour_band = None
+    else:
+        half_hour_band = int(departure) // (30 * 60)
+    return (int(journey.get("transfers") or 0), tuple(routes), half_hour_band)
+
+
+def _find_journeys_via(
+    db: Session,
+    from_stop_id: int | str,
+    via_stop_id: int | str,
+    to_stop_id: int | str,
+    departure: str,
+    max_transfers: int,
+    transfer_seconds: int,
+    limit: int,
+    source_ids: list[int] | None,
+    service_date: str | date | None,
+) -> dict:
+    """Search journeys that are forced through ``via_stop_id``.
+
+    Two ``find_journeys`` searches are chained: origin to via, then via to
+    destination. The second search departs at the first part's arrival plus
+    ``max(0, transfer_seconds)``. Every first/second combination is merged
+    with ``_combine_via_journey``, deduplicated by leg signature, and the
+    best ``max(1, min(limit, 10))`` are returned.
+
+    Returns:
+        A result payload with ``from``, ``to``, ``via``, source and dataset
+        fields, departure fields, ``max_transfers``,
+        ``via_transfer_seconds`` and ``journeys``. ``to`` is None when no
+        journey was found.
+    """
+    max_journeys = max(1, min(limit, 10))
+    via_wait_seconds = max(0, transfer_seconds)
+    first_result = find_journeys(
+        db=db,
+        from_stop_id=from_stop_id,
+        to_stop_id=via_stop_id,
+        departure=departure,
+        max_transfers=max_transfers,
+        transfer_seconds=transfer_seconds,
+        limit=max_journeys,
+        source_ids=source_ids,
+        via_stop_id=None,
+        service_date=service_date,
+    )
+    combined = []
+    for first in first_result.get("journeys", [])[:max_journeys]:
+        first_arrival = first.get("arrival_seconds")
+        if first_arrival is None:
+            continue
+        onward_departure = format_gtfs_time(int(first_arrival) + via_wait_seconds)
+        if not onward_departure:
+            # Without a usable onward time the second part cannot be anchored
+            # after the first. Falling back to the original departure could
+            # return an onward journey that leaves before the first arrives.
+            continue
+        second_result = find_journeys(
+            db=db,
+            from_stop_id=via_stop_id,
+            to_stop_id=to_stop_id,
+            departure=onward_departure,
+            max_transfers=max_transfers,
+            transfer_seconds=transfer_seconds,
+            limit=max_journeys,
+            source_ids=source_ids,
+            via_stop_id=None,
+            service_date=service_date,
+        )
+        for second in second_result.get("journeys", [])[:max_journeys]:
+            combined.append(_combine_via_journey(first, second))
+
+    # Keep the best-ranked journey for each distinct sequence of legs.
+    unique: dict[tuple[str, ...], dict] = {}
+    for journey in sorted(combined, key=_journey_sort_key):
+        key = tuple(_journey_leg_signature(leg) for leg in journey["legs"])
+        unique.setdefault(key, journey)
+    selected = list(unique.values())[:max_journeys]
+    dataset_ids = sorted(
+        {
+            int(leg["dataset_id"])
+            for journey in selected
+            for leg in journey.get("legs", [])
+            if leg.get("dataset_id") is not None
+        }
+    )
+    searched_dataset_ids = sorted(set(first_result.get("dataset_ids") or []) | set(dataset_ids))
+    return {
+        "from": first_result.get("from"),
+        "to": selected[0]["legs"][-1]["to"] if selected else None,
+        "via": first_result.get("to"),
+        "source": None,
+        "sources": _source_payloads_for_dataset_ids(db, dataset_ids or searched_dataset_ids),
+        "dataset_id": dataset_ids[0] if len(dataset_ids) == 1 else None,
+        "dataset_ids": dataset_ids or searched_dataset_ids,
+        "departure_time": first_result.get("departure_time"),
+        # Included so this payload matches the other journey result payloads.
+        "departure_time_label": first_result.get("departure_time_label"),
+        "service_date": first_result.get("service_date"),
+        "max_transfers": max(0, max_transfers),
+        "via_transfer_seconds": via_wait_seconds,
+        "journeys": selected,
+    }
+
+
+def _combine_via_journey(first: dict, second: dict) -> dict:
+    """Join two journey payloads into one journey through a via stop.
+
+    Legs and map features are concatenated, with the second part's feature
+    leg indexes shifted past the first part's legs. The departure comes from
+    ``first`` and the arrival from ``second``; the duration is None when
+    either is unknown. The result is flagged with ``via_forced``.
+    """
+    legs = [*first.get("legs", []), *second.get("legs", [])]
+    departure = first.get("departure_seconds")
+    arrival = second.get("arrival_seconds")
+    duration_seconds = None if departure is None or arrival is None else max(0, int(arrival) - int(departure))
+    features = _combine_via_features(first.get("features") or {}, second.get("features") or {}, first_leg_count=len(first.get("legs", [])))
+    # Walk legs are not vehicle boardings, so only non-walk legs count. This
+    # matches how transfers are derived for address-access journeys and how
+    # transit legs are counted when journeys are ranked.
+    transit_leg_count = sum(1 for leg in legs if leg.get("mode") != "walk")
+    return {
+        "transfers": max(0, transit_leg_count - 1),
+        "departure_seconds": departure,
+        "arrival_seconds": arrival,
+        "departure_time": format_gtfs_time(departure),
+        "arrival_time": format_gtfs_time(arrival),
+        "departure_time_label": format_gtfs_time_label(departure),
+        "arrival_time_label": format_gtfs_time_label(arrival),
+        "duration_seconds": duration_seconds,
+        "duration_minutes": duration_minutes_ceil(duration_seconds),
+        "duration_label": format_duration_label(duration_seconds),
+        "legs": legs,
+        "features": feature_collection(features),
+        "via_forced": True,
+    }
+
+
+def _combine_via_features(first_features: dict, second_features: dict, first_leg_count: int) -> list[dict]:
+    """Concatenate copies of the map features of both parts of a via journey.
+
+    Features of the first part keep their leg index; features of the second
+    part are copied with ``first_leg_count`` as leg offset. Copying and role
+    rewriting is delegated to ``_copy_via_feature``.
+    """
+
+    def entries(collection):
+        listed = collection.get("features") if isinstance(collection, dict) else []
+        return listed or []
+
+    merged = [
+        _copy_via_feature(feature, leg_offset=0, first_part=True)
+        for feature in entries(first_features)
+    ]
+    merged.extend(
+        _copy_via_feature(feature, leg_offset=first_leg_count, first_part=False)
+        for feature in entries(second_features)
+    )
+    return merged
+
+
+def _copy_via_feature(feature: dict, *, leg_offset: int, first_part: bool) -> dict:
+    """Deep-copy a map feature for use inside a combined via journey.
+
+    The copy is made through a JSON round trip, so the input is not mutated.
+    An integer ``leg`` property is shifted by ``leg_offset``. For
+    ``journey_stop`` features, the end stop of the first part and the start
+    stop of the second part are relabelled as ``transfer``.
+
+    Args:
+        feature: GeoJSON-like feature dict.
+        leg_offset: Value added to an integer ``properties["leg"]``.
+        first_part: True for features of the part before the via stop.
+
+    Returns:
+        The adjusted copy. A missing ``properties`` key is added as ``{}``.
+    """
+    copied = json.loads(json.dumps(feature))
+    props = copied.setdefault("properties", {})
+    if not isinstance(props, dict):
+        # GeoJSON allows "properties": null; there is nothing to rewrite, and
+        # calling .get on it would raise AttributeError.
+        return copied
+    if isinstance(props.get("leg"), int):
+        props["leg"] = int(props["leg"]) + leg_offset
+    if props.get("feature_type") == "journey_stop":
+        if first_part and props.get("role") == "end":
+            props["role"] = "transfer"
+        elif not first_part and props.get("role") == "start":
+            props["role"] = "transfer"
+    return copied
+
+
+def _journey_dataset_pairs(from_selection: StopSelection, to_selection: StopSelection) -> list[tuple[int, int]]:
+    """List (origin dataset, destination dataset) combinations to search.
+
+    Same-dataset pairs come first, then the rest in ascending id order; the
+    list is capped at ``MAX_JOURNEY_DATASET_PAIRS``.
+    """
+    combos = []
+    for origin_dataset_id in from_selection.stop_ids_by_dataset:
+        for destination_dataset_id in to_selection.stop_ids_by_dataset:
+            combos.append((origin_dataset_id, destination_dataset_id))
+
+    def same_dataset_first(pair):
+        origin_dataset_id, destination_dataset_id = pair
+        return (origin_dataset_id != destination_dataset_id, origin_dataset_id, destination_dataset_id)
+
+    ordered = sorted(combos, key=same_dataset_first)
+    return ordered[:MAX_JOURNEY_DATASET_PAIRS]
+
+
+def _source_payloads_for_dataset_ids(db: Session, dataset_ids: list[int]) -> list[dict]:
+    """Return one ``{id, name, dataset_id}`` payload per source owning any of the given datasets."""
+    if not dataset_ids:
+        return []
+    query = (
+        select(Dataset.id, Source.id, Source.name)
+        .join(Source, Source.id == Dataset.source_id)
+        .where(Dataset.id.in_(dataset_ids))
+        .order_by(Source.name, Source.id)
+    )
+    rows = db.execute(query).all()
+    # Keep the first row per source; dict insertion order preserves the query ordering.
+    by_source: dict = {}
+    for dataset_id, source_id, source_name in rows:
+        if source_id not in by_source:
+            by_source[source_id] = {"id": source_id, "name": source_name, "dataset_id": dataset_id}
+    return list(by_source.values())
+
+
+def _journey_sort_key(journey: dict) -> tuple[float, float, float, int, int]:
+    """Rank journeys by penalised arrival, then arrival, later departure, transfers, walk-only last."""
+    unknown = float("inf")
+    arrival = journey.get("arrival_seconds")
+    departure = journey.get("departure_seconds")
+    transfers = int(journey.get("transfers") or 0)
+    legs = journey.get("legs") or []
+    walk_legs = [leg for leg in legs if leg.get("mode") == "walk"]
+    # Walking distance is turned into seconds at 1.35 m/s; every transfer adds 600 s.
+    walking_seconds = sum(float(leg.get("distance_m") or 0) / 1.35 for leg in walk_legs)
+    if arrival is None:
+        penalised_arrival, plain_arrival = unknown, unknown
+    else:
+        plain_arrival = float(arrival)
+        penalised_arrival = plain_arrival + transfers * 600 + walking_seconds
+    departure_rank = unknown if departure is None else -float(departure)
+    walk_only = 1 if len(walk_legs) == len(legs) else 0
+    return (penalised_arrival, plain_arrival, departure_rank, transfers, walk_only)
+
+
+def _filter_reasonable_journeys(journeys: list[dict]) -> list[dict]:
+    return list(filter(_journey_is_reasonable, journeys))  # input order kept; unreasonable journeys dropped
+
+
+def _journey_is_reasonable(journey: dict) -> bool:
+    """Return False when the journey comes back to a canonical stop it has already moved on from."""
+    visited: set[int] = set()
+    last_stop_id: int | None = None
+    for leg in journey.get("legs") or []:
+        for canonical_stop_id in _leg_endpoint_canonical_ids(leg):
+            if canonical_stop_id == last_stop_id:
+                # Consecutive repeats of the same id are collapsed, not counted as a revisit.
+                continue
+            if canonical_stop_id in visited:
+                return False
+            visited.add(canonical_stop_id)
+            last_stop_id = canonical_stop_id
+    return True
+
+
+def _leg_endpoint_canonical_ids(leg: dict) -> tuple[int, ...]:
+    """Collect, in order, the integer canonical stop id of every entry in ``leg["stops"]``."""
+    collected: list[int] = []
+    for stop in leg.get("stops") or []:
+        nested = stop.get("canonical_stop") or {}
+        raw_id = nested.get("id") or stop.get("canonical_stop_id")  # nested id first, flat key as fallback
+        if raw_id is not None:
+            try:
+                collected.append(int(raw_id))
+            except (TypeError, ValueError):
+                pass  # ids that cannot be converted to int are skipped
+    return tuple(collected)
+
+
+def parse_service_date(value: str | date | None) -> date | None:
+    """Parse an ISO ``YYYY-MM-DD`` service date; None or "" gives None, unparsable text raises ValueError."""
+    if value is None or value == "":
+        return None
+    if isinstance(value, date):  # datetime subclasses date: keep only its calendar day
+        return value.date() if isinstance(value, datetime) else value
+    try:
+        return date.fromisoformat(str(value).strip())
+    except ValueError as exc:
+        raise ValueError("service_date must be YYYY-MM-DD") from exc
+
+
+def _service_ids_by_dataset(db: Session, dataset_ids: list[int], service_date: date | None) -> dict[int, set[str] | None]:
+    # Without a service date every dataset maps to None; otherwise to its active service ids.
+    unfiltered = service_date is None or not dataset_ids
+    return {key: None if unfiltered else _active_service_ids(db, key, service_date) for key in dataset_ids}
+
+
+def _active_service_ids(db: Session, dataset_id: int, service_date: date) -> set[str] | None:
+    """Return the service ids active on ``service_date``, or None when the dataset has no calendar rows."""
+    has_calendar = bool(db.scalar(select(exists().where(GtfsCalendar.dataset_id == dataset_id))))
+    has_calendar_dates = bool(db.scalar(select(exists().where(GtfsCalendarDate.dataset_id == dataset_id))))
+    if not (has_calendar or has_calendar_dates):
+        return None
+
+    date_int = int(service_date.strftime("%Y%m%d"))
+    weekday_columns = (
+        GtfsCalendar.monday,
+        GtfsCalendar.tuesday,
+        GtfsCalendar.wednesday,
+        GtfsCalendar.thursday,
+        GtfsCalendar.friday,
+        GtfsCalendar.saturday,
+        GtfsCalendar.sunday,
+    )
+    weekday_column = weekday_columns[service_date.weekday()]
+    regular_stmt = select(GtfsCalendar.service_id).where(
+        GtfsCalendar.dataset_id == dataset_id,
+        GtfsCalendar.start_date <= date_int,
+        GtfsCalendar.end_date >= date_int,
+        weekday_column.is_(True),
+    )
+    active = set(db.scalars(regular_stmt).all())
+
+    exception_stmt = select(GtfsCalendarDate.service_id, GtfsCalendarDate.exception_type).where(
+        GtfsCalendarDate.dataset_id == dataset_id,
+        GtfsCalendarDate.date == date_int,
+    )
+    # Per-date exceptions adjust the weekly pattern: type 1 adds a service, type 2 removes it.
+    for service_id, exception_type in db.execute(exception_stmt).all():
+        kind = int(exception_type or 0)
+        if kind == 1:
+            active.add(str(service_id))
+        elif kind == 2:
+            active.discard(str(service_id))
+    return active
+
+
+def _where_trip_service_active(stmt, trip_model, service_ids: set[str] | None):
+    if service_ids is not None:  # None means no service restriction: the statement is returned as is
+        stmt = stmt.where(trip_model.service_id.in_(service_ids))
+    return stmt
+
+
+def _sidecar_service_filter(service_ids: set[str] | None, alias: str = "trips") -> tuple[str, list[object]]:
+    """Build the ``AND ...`` SQL suffix and its positional parameters for a service id restriction."""
+    if service_ids is None:
+        return "", []  # no restriction requested
+    ordered_ids: list[object] = sorted(map(str, service_ids))
+    if not ordered_ids:
+        return " AND 0", []  # empty set: the clause matches no row
+    return f" AND {alias}.service_id IN ({', '.join('?' * len(ordered_ids))})", ordered_ids
+
+
+def _sidecar_stop_time_columns(alias: str, prefix: str) -> str:
+    return ", ".join([f"{alias}.{name} AS {prefix}_{name}" for name in GTFS_STOP_TIME_COLUMNS])  # "a.col AS p_col, ..."
+
+
+def _sidecar_stop_time_from_row(dataset_id: int, row, prefix: str) -> GtfsStopTime:
+    """Rebuild a GtfsStopTime from the ``<prefix>_<column>`` aliases of a sidecar query row."""
+    def cell(name: str):
+        return row[f"{prefix}_{name}"]
+    # Identifiers are coerced to str and the sequence to int; time fields pass through as stored.
+    return GtfsStopTime(
+        dataset_id=dataset_id, trip_id=str(cell("trip_id")), stop_id=str(cell("stop_id")),
+        stop_sequence=int(cell("stop_sequence")),
+        arrival_time=cell("arrival_time"), departure_time=cell("departure_time"),
+        arrival_seconds=cell("arrival_seconds"), departure_seconds=cell("departure_seconds"),
+    )
+
+
+def _trip_route_lookup(
+    db: Session,
+    dataset_id: int,
+    trip_ids: list[str],
+    service_ids: set[str] | None = None,
+) -> dict[str, tuple[GtfsTrip, GtfsRoute]]:
+    """Map each requested trip id to its (trip, route) pair, leaving out trips outside ``service_ids``."""
+    if service_ids == set() or not trip_ids:
+        return {}
+    allowed = None if service_ids is None else {str(service_id) for service_id in service_ids}
+    same_route = and_(GtfsRoute.dataset_id == GtfsTrip.dataset_id, GtfsRoute.route_id == GtfsTrip.route_id)
+    base_stmt = select(GtfsTrip, GtfsRoute).join(GtfsRoute, same_route).where(GtfsTrip.dataset_id == dataset_id)
+    found: dict[str, tuple[GtfsTrip, GtfsRoute]] = {}
+    # Query in chunks of SQLITE_IN_CHUNK_SIZE ids so each IN list stays bounded.
+    for chunk in _chunks(sorted(set(trip_ids)), SQLITE_IN_CHUNK_SIZE):
+        for trip, route in db.execute(base_stmt.where(GtfsTrip.trip_id.in_(chunk))).all():
+            if allowed is not None and str(trip.service_id) not in allowed:
+                continue
+            # The first row returned for a trip id wins.
+            found.setdefault(trip.trip_id, (trip, route))
+    return found
+
+
+def _sidecar_direct_leg_rows(
+    db: Session,
+    dataset_id: int,
+    service_ids: set[str] | None,
+    from_stop_ids: tuple[str, ...],
+    to_stop_ids: tuple[str, ...],
+    earliest_departure: int,
+    limit: int,
+) -> list[tuple[GtfsStopTime, GtfsStopTime, GtfsTrip, GtfsRoute]]:  # rows of (origin call, dest call, trip, route)
+    service_sql, service_params = _sidecar_service_filter(service_ids)  # optional " AND ..." suffix plus its values
+    origin_columns = _sidecar_stop_time_columns("origin", "origin")
+    dest_columns = _sidecar_stop_time_columns("dest", "dest")
+    from_placeholders = ", ".join(["?"] * len(from_stop_ids))  # one "?" per origin stop id
+    to_placeholders = ", ".join(["?"] * len(to_stop_ids))  # one "?" per destination stop id
+    rows = execute_sidecar_query(
+        db,
+        dataset_id,
+        f"""
+        SELECT {origin_columns}, {dest_columns}, trips.trip_id AS lookup_trip_id
+        FROM gtfs_stop_times AS origin
+        JOIN gtfs_stop_times AS dest
+          ON dest.trip_id = origin.trip_id
+         AND dest.stop_sequence > origin.stop_sequence
+        JOIN gtfs_trips AS trips
+          ON trips.trip_id = origin.trip_id
+        WHERE origin.stop_id IN ({from_placeholders})
+          AND dest.stop_id IN ({to_placeholders})
+          AND (origin.departure_seconds IS NULL OR origin.departure_seconds >= ?)
+          {service_sql}
+        ORDER BY origin.departure_seconds, origin.departure_time, dest.arrival_seconds, dest.arrival_time, origin.trip_id
+        LIMIT ?
+        """,
+        [*from_stop_ids, *to_stop_ids, earliest_departure, *service_params, limit],  # same order as the ? marks above
+    )
+    trip_lookup = _trip_route_lookup(db, dataset_id, [str(row["lookup_trip_id"]) for row in rows], service_ids)
+    results = []
+    for row in rows:
+        trip_route = trip_lookup.get(str(row["lookup_trip_id"]))
+        if trip_route is None:
+            continue  # _trip_route_lookup resolved no (trip, route) pair for this row
+        trip, route = trip_route
+        results.append(
+            (
+                _sidecar_stop_time_from_row(dataset_id, row, "origin"),
+                _sidecar_stop_time_from_row(dataset_id, row, "dest"),
+                trip,
+                route,
+            )
+        )
+    return results
+
+
+def _sidecar_latest_direct_leg_rows(
+    db: Session,
+    dataset_id: int,
+    service_ids: set[str] | None,
+    from_stop_ids: tuple[str, ...],
+    to_stop_ids: tuple[str, ...],
+    earliest_departure: int,
+    latest_arrival: int,
+    excluded_trip_id: str | None,
+) -> list[tuple[GtfsStopTime, GtfsStopTime, GtfsTrip, GtfsRoute]]:  # latest departures first (ORDER BY ... DESC)
+    service_sql, service_params = _sidecar_service_filter(service_ids)
+    excluded_sql = " AND origin.trip_id != ?" if excluded_trip_id else ""  # optionally leave out one trip
+    origin_columns = _sidecar_stop_time_columns("origin", "origin")
+    dest_columns = _sidecar_stop_time_columns("dest", "dest")
+    from_placeholders = ", ".join(["?"] * len(from_stop_ids))
+    to_placeholders = ", ".join(["?"] * len(to_stop_ids))
+    params: list[object] = [*from_stop_ids, *to_stop_ids, earliest_departure, latest_arrival, *service_params]
+    if excluded_trip_id:
+        params.append(excluded_trip_id)  # only present when excluded_sql is part of the query
+    params.append(120)  # value bound to LIMIT ?: at most 120 rows are fetched
+    rows = execute_sidecar_query(
+        db,
+        dataset_id,
+        f"""
+        SELECT {origin_columns}, {dest_columns}, trips.trip_id AS lookup_trip_id
+        FROM gtfs_stop_times AS origin
+        JOIN gtfs_stop_times AS dest
+          ON dest.trip_id = origin.trip_id
+         AND dest.stop_sequence > origin.stop_sequence
+        JOIN gtfs_trips AS trips
+          ON trips.trip_id = origin.trip_id
+        WHERE origin.stop_id IN ({from_placeholders})
+          AND dest.stop_id IN ({to_placeholders})
+          AND (origin.departure_seconds IS NULL OR origin.departure_seconds >= ?)
+          AND (dest.arrival_seconds IS NULL OR dest.arrival_seconds <= ?)
+          {service_sql}
+          {excluded_sql}
+        ORDER BY origin.departure_seconds DESC, origin.departure_time DESC, dest.arrival_seconds DESC, dest.arrival_time DESC, origin.trip_id
+        LIMIT ?
+        """,
+        params,
+    )
+    trip_lookup = _trip_route_lookup(db, dataset_id, [str(row["lookup_trip_id"]) for row in rows], service_ids)
+    results = []
+    for row in rows:
+        trip_route = trip_lookup.get(str(row["lookup_trip_id"]))
+        if trip_route is None:
+            continue  # _trip_route_lookup resolved no (trip, route) pair for this row
+        trip, route = trip_route
+        results.append(
+            (
+                _sidecar_stop_time_from_row(dataset_id, row, "origin"),
+                _sidecar_stop_time_from_row(dataset_id, row, "dest"),
+                trip,
+                route,
+            )
+        )
+    return results
+
+
+def _sidecar_destination_arrival_rows(
+    db: Session,
+    dataset_id: int,
+    service_ids: set[str] | None,
+    stop_ids: tuple[str, ...],
+    earliest_departure: int,
+    latest_arrival: int | None,
+) -> list[tuple[GtfsStopTime, GtfsTrip, GtfsRoute]]:  # rows of (call, trip, route), earliest arrivals first
+    service_sql, service_params = _sidecar_service_filter(service_ids)
+    latest_sql = " AND (call.arrival_seconds IS NULL OR call.arrival_seconds <= ?)" if latest_arrival is not None else ""
+    call_columns = _sidecar_stop_time_columns("call", "call")
+    stop_placeholders = ", ".join(["?"] * len(stop_ids))
+    params: list[object] = [*stop_ids, earliest_departure]  # earliest_departure is the lower bound on arrival_seconds
+    if latest_arrival is not None:
+        params.append(latest_arrival)  # only present when latest_sql is part of the query
+    params.extend(service_params)
+    params.append(MAX_TARGET_DESTINATION_ARRIVALS)  # value bound to LIMIT ?
+    rows = execute_sidecar_query(
+        db,
+        dataset_id,
+        f"""
+        SELECT {call_columns}, trips.trip_id AS lookup_trip_id
+        FROM gtfs_stop_times AS call
+        JOIN gtfs_trips AS trips
+          ON trips.trip_id = call.trip_id
+        WHERE call.stop_id IN ({stop_placeholders})
+          AND (call.arrival_seconds IS NULL OR call.arrival_seconds >= ?)
+          {latest_sql}
+          {service_sql}
+        ORDER BY call.arrival_seconds, call.arrival_time, call.trip_id
+        LIMIT ?
+        """,
+        params,
+    )
+    trip_lookup = _trip_route_lookup(db, dataset_id, [str(row["lookup_trip_id"]) for row in rows], service_ids)
+    results = []
+    for row in rows:
+        trip_route = trip_lookup.get(str(row["lookup_trip_id"]))
+        if trip_route is None:
+            continue  # _trip_route_lookup resolved no (trip, route) pair for this row
+        trip, route = trip_route
+        results.append((_sidecar_stop_time_from_row(dataset_id, row, "call"), trip, route))
+    return results
+
+
+def _sidecar_boarding_rows(
+    db: Session,
+    dataset_id: int,
+    service_ids: set[str] | None,
+    stop_ids: tuple[str, ...],
+    earliest_departure: int,
+    limit: int,
+    latest_departure: int | None = None,
+) -> list[tuple[GtfsStopTime, GtfsTrip, GtfsRoute]]:  # rows of (call, trip, route), earliest departures first
+    service_sql, service_params = _sidecar_service_filter(service_ids)
+    latest_sql = " AND (call.departure_seconds IS NULL OR call.departure_seconds < ?)" if latest_departure is not None else ""
+    call_columns = _sidecar_stop_time_columns("call", "call")
+    stop_placeholders = ", ".join(["?"] * len(stop_ids))
+    params: list[object] = [*stop_ids, earliest_departure]  # positional values, in the order of the ? marks below
+    if latest_departure is not None:
+        params.append(latest_departure)  # exclusive upper bound; only present when latest_sql is used
+    params.extend(service_params)
+    params.append(limit)  # value bound to LIMIT ?
+    rows = execute_sidecar_query(
+        db,
+        dataset_id,
+        f"""
+        SELECT {call_columns}, trips.trip_id AS lookup_trip_id
+        FROM gtfs_stop_times AS call
+        JOIN gtfs_trips AS trips
+          ON trips.trip_id = call.trip_id
+        WHERE call.stop_id IN ({stop_placeholders})
+          AND (call.departure_seconds IS NULL OR call.departure_seconds >= ?)
+          {latest_sql}
+          {service_sql}
+        ORDER BY call.departure_seconds, call.departure_time, call.trip_id
+        LIMIT ?
+        """,
+        params,
+    )
+    trip_lookup = _trip_route_lookup(db, dataset_id, [str(row["lookup_trip_id"]) for row in rows], service_ids)
+    results = []
+    for row in rows:
+        trip_route = trip_lookup.get(str(row["lookup_trip_id"]))
+        if trip_route is None:
+            continue  # _trip_route_lookup resolved no (trip, route) pair for this row
+        trip, route = trip_route
+        results.append((_sidecar_stop_time_from_row(dataset_id, row, "call"), trip, route))
+    return results
+
+
+def _chunks[T](items: list[T], size: int) -> Iterator[list[T]]:
+    # Yield consecutive slices of at most ``size`` items; the final slice may be shorter.
+    yield from (items[start : start + size] for start in range(0, len(items), size))
+
+
+def parse_gtfs_time(value: str | None) -> int | None:
+    """Convert ``H:MM`` or ``H:MM:SS`` text to a number of seconds; return None if it is not valid."""
+    if not value:
+        return None
+    fields = value.strip().split(":")
+    if len(fields) not in (2, 3):
+        return None
+    try:
+        numbers = [int(field) for field in fields]
+    except ValueError:
+        return None
+    # A missing seconds field counts as zero; hours have no upper bound.
+    hours, minutes, seconds = numbers if len(numbers) == 3 else (*numbers, 0)
+    valid = hours >= 0 and 0 <= minutes <= 59 and 0 <= seconds <= 59
+    return hours * 3600 + minutes * 60 + seconds if valid else None
+
+
+def format_gtfs_time(seconds: int | None) -> str | None:
+    """Render a second count as zero-padded ``HH:MM:SS``; hours are not wrapped at 24."""
+    if seconds is None:
+        return None
+    total_minutes, secs = divmod(seconds, 60)
+    hours, minutes = divmod(total_minutes, 60)
+    return f"{hours:02d}:{minutes:02d}:{secs:02d}"
+
+
+def format_gtfs_time_label(seconds: int | None) -> str | None:
+    """Render seconds as ``HH:MM`` (``HH:MM:SS`` if needed), prefixed by a signed day offset when not day 0."""
+    if seconds is None:
+        return None
+    service_day, seconds_in_day = divmod(seconds, 86_400)
+    minutes_in_day, secs = divmod(seconds_in_day, 60)
+    hours, minutes = divmod(minutes_in_day, 60)
+    clock = f"{hours:02d}:{minutes:02d}" if secs == 0 else f"{hours:02d}:{minutes:02d}:{secs:02d}"
+    # "+d" writes the sign explicitly: "+1d" past midnight, "-1d" (not "+-1d") for negative input.
+    return clock if service_day == 0 else f"{service_day:+d}d {clock}"
+
+
+def duration_minutes_ceil(seconds: int | float | None) -> int | None:
+    # Whole minutes, rounded up and never negative; None is passed through unchanged.
+    minutes = None if seconds is None else int(math.ceil(float(seconds) / 60))
+    return minutes if minutes is None or minutes > 0 else 0
+
+
+def format_duration_label(seconds: int | float | None) -> str | None:
+    """Format a duration as ``N min``, ``H:MM`` or ``Dd HH:MM``, rounded up to whole minutes."""
+    minutes_total = duration_minutes_ceil(seconds)
+    if minutes_total is None:
+        return None
+    hours_total, minutes = divmod(minutes_total, 60)
+    days, hours = divmod(hours_total, 24)
+    # Larger units only appear when they are non-zero.
+    if days:
+        return f"{days}d {hours:02d}:{minutes:02d}"
+    if hours:
+        return f"{hours}:{minutes:02d}"
+    return f"{minutes} min"
+
+
+@dataclass
+class _RouterLabel:  # an arrival at one canonical stop, linked back to the step that produced it
+    canonical_stop_id: int
+    arrival_seconds: int
+    previous: "_RouterLegBacklink | _RouterWalkBacklink | None" = None  # None marks the start of the chain
+
+
+@dataclass(frozen=True)
+class _RouterLegBacklink:  # transit step behind a label: ride ``trip`` from the ``origin`` call to the ``dest`` call
+    previous_label: _RouterLabel  # label held at the boarding stop
+    route: GtfsRoute
+    trip: GtfsTrip
+    origin: GtfsStopTime  # call at which the trip is boarded
+    dest: GtfsStopTime  # later call of the same trip at which the label's stop is reached
+
+
+@dataclass(frozen=True)
+class _RouterWalkBacklink:  # walking step behind a label
+    previous_label: _RouterLabel  # label at the stop where the walk starts
+    from_stop: StopSummary
+    to_stop: StopSummary
+    distance_m: float
+    departure_seconds: int  # the router sets this to the previous label's arrival time
+    arrival_seconds: int
+
+
+@dataclass(frozen=True)
+class _RouterBoarding:  # a trip that can be boarded at a reached canonical stop
+    canonical_stop_id: int
+    call: GtfsStopTime  # stop time at which the trip is boarded
+    trip: GtfsTrip
+    route: GtfsRoute
+    ready_seconds: int  # earliest second at which boarding at this stop is possible
+
+
+def _find_round_journeys(
+    db: Session,
+    dataset_id: int,
+    service_ids: set[str] | None,
+    from_selection: StopSelection,
+    to_selection: StopSelection,
+    earliest_departure: int,
+    max_transfers: int,
+    transfer_seconds: int,
+    latest_arrival: int | None,
+    limit: int,
+    stop_cache: dict[tuple[int, str], StopSummary],
+    osm_stop_cache: dict[tuple[int, str], dict],
+) -> list[dict]:
+    """Search journeys round by round (one more transit leg per round) and return the best ``limit`` payloads."""
+    if from_selection.canonical_stop_id is None or to_selection.canonical_stop_id is None:
+        return []
+    origin_id, target_id = from_selection.canonical_stop_id, to_selection.canonical_stop_id
+    best: dict[int, _RouterLabel] = {origin_id: _RouterLabel(origin_id, earliest_departure)}
+    marked = {origin_id}
+    solutions: list[_RouterLabel] = []
+    max_legs = max(1, min(max_transfers + 1, MAX_ROUTER_TRANSIT_LEGS))
+    # ``best`` holds the earliest known label per canonical stop; ``marked`` the stops improved last round.
+    for round_index in range(max_legs):
+        if not marked:
+            break
+        boarding_labels = {
+            stop_id: label
+            for stop_id in marked
+            if (label := best.get(stop_id)) is not None
+        }
+        walking_labels = _walking_transfer_labels(
+            db,
+            dataset_id=dataset_id,
+            source_labels=boarding_labels,
+            latest_arrival=latest_arrival,
+        )
+        for stop_id, label in walking_labels.items():
+            current = best.get(stop_id)
+            accepted = current is None or label.arrival_seconds < current.arrival_seconds
+            if accepted:
+                best[stop_id] = label
+                boarding_labels[stop_id] = label
+                if stop_id == target_id:
+                    solutions.append(label)
+            elif stop_id not in boarding_labels:
+                boarding_labels[stop_id] = current
+        board_ready = {
+            stop_id: ready_seconds
+            for stop_id, label in boarding_labels.items()
+            if (ready_seconds := label.arrival_seconds + (0 if label.previous is None else transfer_seconds)) is not None
+            and (latest_arrival is None or ready_seconds < latest_arrival)
+        }
+        if not board_ready:
+            break
+        boardings = _router_boardings_for_marked_stops(
+            db=db,
+            dataset_id=dataset_id,
+            service_ids=service_ids,
+            board_ready=board_ready,
+            latest_arrival=latest_arrival,
+        )
+        if not boardings:
+            break
+        next_marked: set[int] = set()
+        calls_by_trip = _stop_times_by_trip(db, dataset_id, sorted({boarding.trip.trip_id for boarding in boardings}))
+        stop_to_canonical = _canonical_ids_for_trip_calls(db, dataset_id, calls_by_trip)
+        for boarding in boardings:
+            previous_label = boarding_labels.get(boarding.canonical_stop_id)  # round snapshot, not a label improved mid-round
+            if previous_label is None:
+                continue
+            calls = calls_by_trip.get(boarding.trip.trip_id, [])
+            for call in calls:
+                if call.stop_sequence <= boarding.call.stop_sequence:
+                    continue
+                canonical_stop_id = stop_to_canonical.get(call.stop_id)
+                if canonical_stop_id is None:
+                    continue
+                arrival = _arrival_seconds(call)
+                if arrival is None or arrival < boarding.ready_seconds:
+                    continue
+                if latest_arrival is not None and arrival >= latest_arrival:
+                    continue
+                current = best.get(canonical_stop_id)
+                if current is not None and current.arrival_seconds <= arrival:
+                    continue
+                label = _RouterLabel(
+                    canonical_stop_id=canonical_stop_id,
+                    arrival_seconds=arrival,
+                    previous=_RouterLegBacklink(
+                        previous_label=previous_label,
+                        route=boarding.route,
+                        trip=boarding.trip,
+                        origin=boarding.call,
+                        dest=call,
+                    ),
+                )
+                best[canonical_stop_id] = label
+                next_marked.add(canonical_stop_id)
+                if canonical_stop_id == target_id:
+                    solutions.append(label)
+        marked = next_marked
+        if len(solutions) >= limit and round_index > 0:
+            break
+
+    journeys = []
+    for label in sorted(solutions, key=lambda item: item.arrival_seconds)[: max(limit * 2, limit)]:
+        legs = _router_label_legs(db, dataset_id, label, stop_cache, osm_stop_cache)
+        if legs:
+            journeys.append(_journey_payload(legs))
+    return sorted(journeys, key=_journey_sort_key)[:limit]
+
+
+def _walking_transfer_labels(
+    db: Session,
+    dataset_id: int,
+    source_labels: dict[int, _RouterLabel],
+    latest_arrival: int | None,
+) -> dict[int, _RouterLabel]:
+    """Extend already reached stops by one walking transfer.
+
+    Only the ``MAX_WALKING_TRANSFER_SOURCE_STOPS`` earliest-reached entries of ``source_labels``
+    are used as walk origins. For each nearby stop the earliest walking arrival is kept, and
+    arrivals at or after ``latest_arrival`` are discarded. Every returned label points back to
+    its source label through a ``_RouterWalkBacklink``.
+    """
+    if not source_labels:
+        return {}
+    ranked = sorted(source_labels.items(), key=lambda item: (item[1].arrival_seconds, item[0]))
+    sources = dict(ranked[:MAX_WALKING_TRANSFER_SOURCE_STOPS])
+    source_ids = tuple(sources)
+    if settings.is_postgresql_database:
+        nearby_rows = _walking_transfer_rows_postgres(db, dataset_id, source_ids)
+    else:
+        nearby_rows = _walking_transfer_rows_sqlite(db, dataset_id, source_ids)
+    involved_ids = {int(stop_id) for row in nearby_rows for stop_id in (row[0], row[1])}
+    summaries = _canonical_stop_summaries(db, dataset_id, involved_ids)
+
+    reached: dict[int, _RouterLabel] = {}
+    for source_id, target_id, distance_m in nearby_rows:
+        origin_label = sources.get(source_id)
+        from_stop = summaries.get(int(source_id))
+        to_stop = summaries.get(int(target_id))
+        if origin_label is None or from_stop is None or to_stop is None:
+            continue
+        arrival = origin_label.arrival_seconds + _walking_transfer_seconds(distance_m)
+        if latest_arrival is not None and arrival >= latest_arrival:
+            continue
+        known = reached.get(target_id)
+        if known is not None and known.arrival_seconds <= arrival:
+            continue  # an equally early or earlier walk to this stop is already recorded
+        walk = _RouterWalkBacklink(
+            previous_label=origin_label,
+            from_stop=from_stop,
+            to_stop=to_stop,
+            distance_m=float(distance_m or 0),
+            departure_seconds=origin_label.arrival_seconds,
+            arrival_seconds=arrival,
+        )
+        reached[target_id] = _RouterLabel(
+            canonical_stop_id=target_id,
+            arrival_seconds=arrival,
+            previous=walk,
+        )
+    return reached
+
+
+def _canonical_stop_summaries(db: Session, dataset_id: int, canonical_stop_ids: set[int]) -> dict[int, StopSummary]:
+    """Load the given canonical stops and wrap each one in a StopSummary keyed by canonical id."""
+    if not canonical_stop_ids:
+        return {}
+    stmt = select(CanonicalStop).where(CanonicalStop.id.in_(canonical_stop_ids))
+    summaries: dict[int, StopSummary] = {}
+    for stop in db.scalars(stmt).all():
+        summaries[stop.id] = StopSummary(
+            id=stop.id,
+            dataset_id=dataset_id,  # taken from the caller, not from the canonical stop row
+            stop_id=f"canonical:{stop.id}",  # synthetic id derived from the canonical id
+            name=stop.name,
+            lat=stop.lat, lon=stop.lon,
+        )
+    return summaries
+
+
+def _walking_transfer_rows_postgres(
+    db: Session,
+    dataset_id: int,
+    source_ids: tuple[int, ...],
+) -> list[tuple[int, int, float]]:  # rows of (source canonical id, target canonical id, distance_m)
+    if not source_ids:
+        return []  # nothing to expand, so no query is issued
+    stmt = text(
+        """
+        WITH nearby AS (
+            SELECT
+                src.id AS source_id,
+                dest.id AS target_id,
+                ST_DistanceSphere(src.geom, dest.geom) AS distance_m,
+                row_number() OVER (
+                    PARTITION BY src.id
+                    ORDER BY ST_DistanceSphere(src.geom, dest.geom), dest.id
+                ) AS rn
+            FROM canonical_stops AS src
+            JOIN canonical_stops AS dest
+              ON dest.id != src.id
+             AND src.geom IS NOT NULL
+             AND dest.geom IS NOT NULL
+             AND dest.geom && ST_Expand(src.geom, :radius_deg)
+             AND ST_DWithin(src.geom, dest.geom, :radius_deg)
+            WHERE src.id IN :source_ids
+              AND EXISTS (
+                  SELECT 1
+                  FROM canonical_stop_links AS link
+                  WHERE link.canonical_stop_id = dest.id
+                    AND link.dataset_id = :dataset_id
+                    AND link.object_type = 'gtfs_stop'
+              )
+        )
+        SELECT source_id, target_id, distance_m
+        FROM nearby
+        WHERE rn <= :neighbor_limit
+        ORDER BY source_id, distance_m, target_id
+        """
+    ).bindparams(bindparam("source_ids", expanding=True))  # lets :source_ids accept the whole tuple of ids
+    rows = db.execute(
+        stmt,
+        {
+            "dataset_id": dataset_id,
+            "source_ids": source_ids,
+            "radius_deg": WALKING_TRANSFER_RADIUS_DEG,  # search radius in degrees, used by ST_Expand / ST_DWithin
+            "neighbor_limit": MAX_WALKING_TRANSFER_NEIGHBORS_PER_STOP,  # keeps the nearest rn rows per source
+        },
+    ).all()
+    return [(int(source_id), int(target_id), float(distance_m or 0)) for source_id, target_id, distance_m in rows]
+
+
+def _walking_transfer_rows_sqlite(
+    db: Session,
+    dataset_id: int,
+    source_ids: tuple[int, ...],
+) -> list[tuple[int, int, float]]:
+    """Return ``(source_id, target_id, distance_m)`` rows: the nearest in-radius linked stops of each source."""
+    if not source_ids:
+        return []
+    source_stmt = select(CanonicalStop.id, CanonicalStop.lat, CanonicalStop.lon).where(CanonicalStop.id.in_(source_ids))
+    sources = {
+        int(stop_id): (float(lat), float(lon))
+        for stop_id, lat, lon in db.execute(source_stmt).all()
+        if lat is not None and lon is not None
+    }
+    if not sources:
+        return []
+
+    lats = [lat for lat, _ in sources.values()]
+    lons = [lon for _, lon in sources.values()]
+    lat_delta = WALKING_TRANSFER_RADIUS_M / 111_320
+    # A degree of longitude spans only cos(latitude) * 111 320 m, so the east-west margin must be
+    # widened by 1 / cos(latitude) or in-radius neighbours fall outside the prefilter box.
+    widest_lat = min(89.0, max(abs(lat) for lat in lats) + lat_delta)
+    lon_delta = lat_delta / math.cos(math.radians(widest_lat))
+    dest_rows = db.execute(
+        select(CanonicalStop.id, CanonicalStop.lat, CanonicalStop.lon)
+        .join(CanonicalStopLink, CanonicalStopLink.canonical_stop_id == CanonicalStop.id)
+        .where(
+            CanonicalStopLink.dataset_id == dataset_id,
+            CanonicalStopLink.object_type == "gtfs_stop",
+            CanonicalStop.lat >= min(lats) - lat_delta,
+            CanonicalStop.lat <= max(lats) + lat_delta,
+            CanonicalStop.lon >= min(lons) - lon_delta,
+            CanonicalStop.lon <= max(lons) + lon_delta,
+        )
+        .distinct()
+    ).all()
+    rows: list[tuple[int, int, float]] = []
+    for source_id, (source_lat, source_lon) in sources.items():
+        candidates = []
+        for target_id, target_lat, target_lon in dest_rows:
+            if int(target_id) == source_id or target_lat is None or target_lon is None:
+                continue
+            distance_m = _distance_m(source_lat, source_lon, float(target_lat), float(target_lon))
+            if distance_m <= WALKING_TRANSFER_RADIUS_M:  # exact radius check after the coarse box
+                candidates.append((source_id, int(target_id), distance_m))
+        nearest = sorted(candidates, key=lambda item: (item[2], item[1]))
+        rows.extend(nearest[:MAX_WALKING_TRANSFER_NEIGHBORS_PER_STOP])
+    return rows
+
+
+def _walking_transfer_seconds(distance_m: float) -> int:
+    return max(30, math.ceil(float(distance_m or 0) / WALKING_TRANSFER_SPEED_MPS))  # rounded up, never below 30 s
+
+
+def _distance_m(lat_a: float, lon_a: float, lat_b: float, lon_b: float) -> float:
+    """Approximate the distance in metres between two lat/lon points on a locally flat grid."""
+    metres_per_degree = 111_320
+    lon_scale = math.cos(math.radians((lat_a + lat_b) / 2))  # longitude degrees shrink with the mean latitude
+    east = (lon_b - lon_a) * (metres_per_degree * lon_scale)
+    return math.hypot(east, (lat_b - lat_a) * metres_per_degree)
+
+
+def _router_boardings_for_marked_stops(
+    db: Session,
+    dataset_id: int,
+    service_ids: set[str] | None,
+    board_ready: dict[int, int],
+    latest_arrival: int | None = None,
+) -> list[_RouterBoarding]:
+    """Return boardable trips at the ``board_ready`` stops, one boarding per trip, sorted by departure."""
+    if not board_ready:
+        return []
+    stop_ids_by_canonical = _gtfs_stop_ids_for_canonical_ids(db, dataset_id, set(board_ready))
+    stop_to_canonical = {
+        stop_id: canonical_stop_id
+        for canonical_stop_id, stop_ids in stop_ids_by_canonical.items()
+        for stop_id in stop_ids
+    }
+    stop_ids = tuple(stop_to_canonical)
+    if not stop_ids:
+        return []
+    boardings: list[_RouterBoarding] = []
+    seen: set[str] = set()
+    earliest = min(board_ready.values())  # one query lower bound; the per-stop ready time is checked below
+    for call, trip, route in _router_boarding_rows(db, dataset_id, service_ids, stop_ids, earliest, latest_arrival):
+        canonical_stop_id = stop_to_canonical.get(call.stop_id)
+        if canonical_stop_id is None:
+            continue
+        ready = board_ready.get(canonical_stop_id)
+        departure = _departure_seconds(call)
+        if ready is None or departure is None or departure < ready:
+            continue
+        if trip.trip_id in seen:
+            continue  # keep only the first accepted call of each trip
+        seen.add(trip.trip_id)
+        boardings.append(
+            _RouterBoarding(canonical_stop_id=canonical_stop_id, call=call, trip=trip, route=route, ready_seconds=ready)
+        )
+        if len(boardings) >= MAX_ROUTER_BOARDING_CANDIDATES:
+            break
+
+    def departure_key(item: _RouterBoarding) -> tuple[int, str]:
+        departure = _departure_seconds(item.call)
+        # 0 is a real departure (00:00:00); only a missing value is pushed to the end.
+        return (10**9 if departure is None else departure, item.trip.trip_id)
+    return sorted(boardings, key=departure_key)
+
+
+def _router_boarding_rows(
+    db: Session,
+    dataset_id: int,
+    service_ids: set[str] | None,
+    stop_ids: tuple[str, ...],
+    earliest: int,
+    latest_departure: int | None = None,
+) -> list[tuple[GtfsStopTime, GtfsTrip, GtfsRoute]]:
+    """Fetch the earliest calls at ``stop_ids`` departing from ``earliest`` on, each with its trip and route."""
+    if service_ids == set():
+        return []  # an empty service set cannot match any trip
+    row_limit = MAX_ROUTER_BOARDING_CANDIDATES * 2
+    if uses_sidecar_stop_times(db, dataset_id):
+        return _sidecar_boarding_rows(
+            db=db,
+            dataset_id=dataset_id,
+            service_ids=service_ids,
+            stop_ids=stop_ids,
+            earliest_departure=earliest,
+            latest_departure=latest_departure,
+            limit=row_limit,
+        )
+    departure = GtfsStopTime.departure_seconds
+    same_trip = and_(GtfsTrip.dataset_id == GtfsStopTime.dataset_id, GtfsTrip.trip_id == GtfsStopTime.trip_id)
+    same_route = and_(GtfsRoute.dataset_id == GtfsTrip.dataset_id, GtfsRoute.route_id == GtfsTrip.route_id)
+    stmt = (
+        select(GtfsStopTime, GtfsTrip, GtfsRoute).join(GtfsTrip, same_trip).join(GtfsRoute, same_route)
+        .where(GtfsStopTime.dataset_id == dataset_id, GtfsStopTime.stop_id.in_(stop_ids))
+        .where(or_(departure.is_(None), departure >= earliest))
+    )
+    stmt = _where_trip_service_active(stmt, GtfsTrip, service_ids)
+    if latest_departure is not None:
+        stmt = stmt.where(or_(departure.is_(None), departure < latest_departure))
+    # Ordering and the row cap are attached last; they apply to the fully filtered statement.
+    stmt = stmt.order_by(departure, GtfsStopTime.departure_time, GtfsStopTime.trip_id).limit(row_limit)
+    return db.execute(stmt).all()
+
+
+def _gtfs_stop_ids_for_canonical_ids(
+    db: Session,
+    dataset_id: int,
+    canonical_stop_ids: set[int],
+) -> dict[int, tuple[str, ...]]:
+    """Group the dataset's linked GTFS stop ids by the canonical stop they belong to."""
+    if not canonical_stop_ids:
+        return {}
+    link = CanonicalStopLink
+    stmt = (
+        select(link.canonical_stop_id, link.external_id)
+        .where(link.object_type == "gtfs_stop", link.dataset_id == dataset_id)
+        .where(link.canonical_stop_id.in_(canonical_stop_ids))
+        .order_by(link.canonical_stop_id, link.external_id)
+    )
+    grouped: dict[int, list[str]] = {}
+    for canonical_stop_id, stop_id in db.execute(stmt).all():
+        grouped.setdefault(int(canonical_stop_id), []).append(str(stop_id))
+    # Freeze each group into a tuple for the caller.
+    return {key: tuple(values) for key, values in grouped.items()}
+
+
+def _canonical_ids_for_trip_calls(
+    db: Session,
+    dataset_id: int,
+    calls_by_trip: dict[str, list[GtfsStopTime]],
+) -> dict[str, int]:
+    """Map each GTFS stop id called by the given trips to its linked canonical stop id in this dataset."""
+    stop_ids = sorted({call.stop_id for calls in calls_by_trip.values() for call in calls})
+    link = CanonicalStopLink
+    base_stmt = select(link.external_id, link.canonical_stop_id).where(
+        link.object_type == "gtfs_stop", link.dataset_id == dataset_id
+    )
+    mapping: dict[str, int] = {}
+    # Chunk the IN list, as _trip_route_lookup does, so the bound-parameter count stays limited.
+    for chunk in _chunks(stop_ids, SQLITE_IN_CHUNK_SIZE):
+        rows = db.execute(base_stmt.where(link.external_id.in_(chunk))).all()
+        mapping.update((str(stop_id), int(canonical_stop_id)) for stop_id, canonical_stop_id in rows)
+    return mapping
+
+
+def _router_label_legs(
+    db: Session,
+    dataset_id: int,
+    label: _RouterLabel,
+    stop_cache: dict[tuple[int, str], StopSummary],
+    osm_stop_cache: dict[tuple[int, str], dict],
+) -> list[dict]:
+    """Follow the label's backlink chain to its start and emit one leg payload per step, in travel order."""
+    chain: list[_RouterLegBacklink | _RouterWalkBacklink] = []
+    step = label.previous
+    # Backlinks are reached newest-first, so collect them and replay them in reverse.
+    while step is not None:
+        chain.append(step)
+        step = step.previous_label.previous
+    legs: list[dict] = []
+    for backlink in reversed(chain):
+        if isinstance(backlink, _RouterWalkBacklink):
+            leg = _walk_leg_payload(db, backlink, dataset_id)
+        else:
+            leg = _leg_payload(
+                db=db,
+                dataset_id=dataset_id,
+                route=backlink.route,
+                trip=backlink.trip,
+                origin=backlink.origin,
+                dest=backlink.dest,
+                stop_cache=stop_cache,
+                osm_stop_cache=osm_stop_cache,
+            )
+        legs.append(leg)
+    return legs
+
+
+def _find_walk_only_journey(
+    db: Session,
+    *,
+    from_selection: StopSelection,
+    to_selection: StopSelection,
+    departure_seconds: int,
+) -> dict | None:
+    """Build a single-leg walking journey between the two selections, if walking is a viable option.
+
+    Returns None when either canonical stop is unknown, both selections are the same canonical
+    stop, a coordinate is missing, the straight-line distance is already too long, routing
+    fails, or the routed walk is non-positive or longer than
+    ``PUBLIC_TRANSPORT_WALK_OPTION_MAX_SECONDS``.
+    """
+    origin_id = from_selection.canonical_stop_id
+    target_id = to_selection.canonical_stop_id
+    if origin_id is None or target_id is None or origin_id == target_id:
+        return None
+
+    origin_stop, target_stop = from_selection.display, to_selection.display
+    coordinates = (origin_stop.lon, origin_stop.lat, target_stop.lon, target_stop.lat)
+    if any(value is None for value in coordinates):
+        return None
+    from_lon, from_lat, to_lon, to_lat = (float(value) for value in coordinates)
+
+    # Straight-line pre-check: skip routing when even the direct distance exceeds the walk budget.
+    if _distance_m(from_lat, from_lon, to_lat, to_lon) > PUBLIC_TRANSPORT_WALK_OPTION_MAX_SECONDS * 1.35:
+        return None
+    try:
+        route = route_between_points(
+            db,
+            from_lon=from_lon,
+            from_lat=from_lat,
+            to_lon=to_lon,
+            to_lat=to_lat,
+            mode="walk",
+            max_visited=80_000,
+        )
+    except Exception:  # noqa: BLE001 - walking comparison is optional
+        return None
+
+    duration_seconds = float(route.get("duration_seconds") or 0)
+    if duration_seconds <= 0 or duration_seconds > PUBLIC_TRANSPORT_WALK_OPTION_MAX_SECONDS:
+        return None
+
+    arrival_seconds = departure_seconds + int(math.ceil(duration_seconds))
+    walk = _RouterWalkBacklink(
+        previous_label=_RouterLabel(canonical_stop_id=origin_id, arrival_seconds=departure_seconds),
+        from_stop=origin_stop,
+        to_stop=target_stop,
+        distance_m=float(route.get("distance_m") or 0),
+        departure_seconds=departure_seconds,
+        arrival_seconds=arrival_seconds,
+    )
+    leg = _walk_leg_payload(db, walk, origin_stop.dataset_id)
+    # Overwrite the payload's route name and duration with the walk-only values.
+    leg["route_name"] = "Walk only"
+    leg["duration_seconds"] = duration_seconds
+    return _journey_payload([leg])
+
+
+def _find_direct_journeys(
+    db: Session,
+    dataset_id: int,
+    service_ids: set[str] | None,
+    from_stop_ids: tuple[str, ...],
+    to_stop_ids: tuple[str, ...],
+    earliest_departure: int,
+    limit: int,
+    stop_cache: dict[tuple[int, str], StopSummary],
+    osm_stop_cache: dict[tuple[int, str], dict],
+) -> list[dict]:
+    """Return up to ``limit`` single-leg journeys, ordered by ``_journey_sort_key``.
+
+    Asks ``_find_direct_legs`` for up to ``max(limit * 4, limit)`` legs, wraps each
+    leg in its own journey payload, sorts the journeys and truncates to ``limit``.
+    """
+    direct_legs = _find_direct_legs(
+        db,
+        dataset_id,
+        service_ids,
+        from_stop_ids,
+        to_stop_ids,
+        earliest_departure,
+        stop_cache,
+        osm_stop_cache,
+        max_legs=max(limit * 4, limit),
+    )
+    journeys = []
+    for leg in direct_legs:
+        journeys.append(_journey_payload([leg]))
+    journeys.sort(key=_journey_sort_key)
+    return journeys[:limit]
+
+
+def _find_direct_legs(
+    db: Session,
+    dataset_id: int,
+    service_ids: set[str] | None,
+    from_stop_ids: tuple[str, ...],
+    to_stop_ids: tuple[str, ...],
+    earliest_departure: int,
+    stop_cache: dict[tuple[int, str], StopSummary],
+    osm_stop_cache: dict[tuple[int, str], dict],
+    max_legs: int = 20,
+) -> list[dict]:
+    """Return leg payloads for rides on one trip from an origin stop to a destination stop.
+
+    Candidate rows come from ``_sidecar_direct_leg_rows`` when
+    ``uses_sidecar_stop_times`` is true for the dataset, otherwise from a self-join
+    of ``GtfsStopTime`` (destination call later in the same trip than the origin
+    call). Both sources are filtered, de-duplicated, capped and sorted by the same
+    helper, so the two storage back ends cannot drift apart.
+
+    Args:
+        db: Database session.
+        dataset_id: GTFS dataset to search.
+        service_ids: Service-id filter; an empty set short-circuits to ``[]``.
+            ``None`` is passed through to the row sources unchanged.
+        from_stop_ids: Stop ids where boarding is allowed; empty gives ``[]``.
+        to_stop_ids: Stop ids where alighting is allowed; empty gives ``[]``.
+        earliest_departure: Lower bound (seconds) on the origin departure.
+        stop_cache: Cache forwarded to ``_leg_payload``.
+        osm_stop_cache: Cache forwarded to ``_leg_payload``.
+        max_legs: Upper bound on returned legs; values below 1 are treated as 1.
+
+    Returns:
+        Leg payload dicts sorted by arrival (ascending), then departure (descending).
+    """
+    if not from_stop_ids or not to_stop_ids:
+        return []
+    if service_ids == set():
+        return []
+    if uses_sidecar_stop_times(db, dataset_id):
+        rows = _sidecar_direct_leg_rows(
+            db=db,
+            dataset_id=dataset_id,
+            service_ids=service_ids,
+            from_stop_ids=from_stop_ids,
+            to_stop_ids=to_stop_ids,
+            earliest_departure=earliest_departure,
+            limit=MAX_DIRECT_ROWS,
+        )
+    else:
+        Origin = aliased(GtfsStopTime)
+        Dest = aliased(GtfsStopTime)
+        stmt = (
+            select(Origin, Dest, GtfsTrip, GtfsRoute)
+            .join(
+                Dest,
+                and_(
+                    Dest.dataset_id == Origin.dataset_id,
+                    Dest.trip_id == Origin.trip_id,
+                    Dest.stop_sequence > Origin.stop_sequence,
+                ),
+            )
+            .join(GtfsTrip, and_(GtfsTrip.dataset_id == Origin.dataset_id, GtfsTrip.trip_id == Origin.trip_id))
+            .join(GtfsRoute, and_(GtfsRoute.dataset_id == GtfsTrip.dataset_id, GtfsRoute.route_id == GtfsTrip.route_id))
+            .where(Origin.dataset_id == dataset_id, Origin.stop_id.in_(from_stop_ids), Dest.stop_id.in_(to_stop_ids))
+            # Rows without a stored departure_seconds are kept here and re-checked
+            # in Python via _departure_seconds().
+            .where(or_(Origin.departure_seconds.is_(None), Origin.departure_seconds >= earliest_departure))
+            .order_by(Origin.departure_seconds, Origin.departure_time, Dest.arrival_seconds, Dest.arrival_time, Origin.trip_id)
+            .limit(MAX_DIRECT_ROWS)
+        )
+        stmt = _where_trip_service_active(stmt, GtfsTrip, service_ids)
+        rows = db.execute(stmt).all()
+    return _direct_leg_payloads_from_rows(
+        db,
+        dataset_id,
+        rows,
+        earliest_departure,
+        stop_cache,
+        osm_stop_cache,
+        max_legs,
+    )
+
+
+def _direct_leg_payloads_from_rows(
+    db: Session,
+    dataset_id: int,
+    rows,
+    earliest_departure: int,
+    stop_cache: dict[tuple[int, str], StopSummary],
+    osm_stop_cache: dict[tuple[int, str], dict],
+    max_legs: int,
+) -> list[dict]:
+    """Filter, de-duplicate and cap ``(origin, dest, trip, route)`` rows into sorted leg payloads."""
+    candidates: list[dict] = []
+    seen: set[tuple[object, ...]] = set()
+    for origin, dest, trip, route in rows:
+        dep_seconds = _departure_seconds(origin)
+        arr_seconds = _arrival_seconds(dest)
+        if dep_seconds is None or arr_seconds is None:
+            continue
+        if dep_seconds < earliest_departure or arr_seconds < dep_seconds:
+            continue
+        # Trips that serve the same route, stops and times are interchangeable for the rider.
+        key = (route.route_id, route.short_name, origin.stop_id, dest.stop_id, dep_seconds, arr_seconds)
+        if key in seen:
+            continue
+        seen.add(key)
+        candidates.append(_leg_payload(db, dataset_id, route, trip, origin, dest, stop_cache, osm_stop_cache))
+        if len(candidates) >= max(1, max_legs):
+            break
+    return sorted(candidates, key=lambda item: (item["arrival_seconds"], -(item["departure_seconds"] or -1)))
+
+
+# Immutable description of one candidate first (feeder) leg: a boarding call and a
+# later call on the same trip, together with the trip and route they belong to.
+@dataclass(frozen=True)
+class _FirstLegOption:
+ # Departure at `origin`, as produced by _departure_seconds() where this is built.
+ departure_seconds: int
+ # Arrival at `dest`, as produced by _arrival_seconds() where this is built.
+ arrival_seconds: int
+ # Stop-time row where the rider boards.
+ origin: GtfsStopTime
+ # Stop-time row where the rider alights.
+ dest: GtfsStopTime
+ trip: GtfsTrip
+ route: GtfsRoute
+
+
+# Immutable description of one candidate second (onward) leg that ends at a
+# destination stop; built in _targeted_second_leg_options().
+@dataclass(frozen=True)
+class _SecondLegOption:
+ # Canonical stop id of the boarding call (`origin`), i.e. the transfer stop.
+ canonical_stop_id: int
+ # Departure at `origin`.
+ departure_seconds: int
+ # Arrival at `dest`.
+ arrival_seconds: int
+ # Stop-time row where the rider boards the onward trip.
+ origin: GtfsStopTime
+ # Stop-time row at the destination stop.
+ dest: GtfsStopTime
+ trip: GtfsTrip
+ route: GtfsRoute
+
+
+# Immutable pairing of a first leg with a second leg (and optionally a final walk),
+# as assembled by _find_one_transfer_journeys().
+@dataclass(frozen=True)
+class _OneTransferCandidate:
+ # Arrival of the whole candidate: second-leg arrival plus the final walk, if any.
+ arrival_seconds: int
+ # Departure of the first leg.
+ departure_seconds: int
+ first_route: GtfsRoute
+ first_trip: GtfsTrip
+ # Boarding call of the first leg.
+ first_origin: GtfsStopTime
+ # Alighting call of the first leg (at the transfer stop).
+ first_dest: GtfsStopTime
+ second: _SecondLegOption
+ # Walk from the second leg's destination stop to the target stop; None when no walk is needed.
+ final_walk: _RouterWalkBacklink | None = None
+
+
+# Immutable candidate hub reachable by one ride from the origin, used by
+# _find_access_transfer_journeys() as the start of an onward search.
+@dataclass(frozen=True)
+class _AccessTransferCandidate:
+ # Canonical stop where the access ride ends.
+ canonical_stop_id: int
+ # The access ride itself.
+ option: _FirstLegOption
+ # Value returned by _station_importance_rank(); lower sorts first in _access_transfer_sort_key().
+ rank: int
+
+
+# Find journeys made of two transit legs joined at a shared canonical stop, optionally
+# followed by a final walk to the target stop.
+#
+# The search runs backwards from the destination: second legs that reach a destination
+# stop (or a stop with a walking link to the target) are collected first, then each is
+# paired with a first leg from the origin that arrives at the same canonical stop at
+# least `transfer_seconds` before the second leg departs.
+#
+# Returns at most `limit` journey payloads ordered by _journey_sort_key. Returns [] when
+# either service-id set is empty, when `latest_arrival` is not after
+# `earliest_departure`, or when no second leg is found.
+def _find_one_transfer_journeys(
+ db: Session,
+ first_dataset_id: int,
+ second_dataset_id: int,
+ first_service_ids: set[str] | None,
+ second_service_ids: set[str] | None,
+ from_stop_ids: tuple[str, ...],
+ to_stop_ids: tuple[str, ...],
+ origin_canonical_stop_id: int | None,
+ target_canonical_stop_id: int | None,
+ earliest_departure: int,
+ latest_arrival: int | None,
+ transfer_seconds: int,
+ limit: int,
+ stop_cache: dict[tuple[int, str], StopSummary],
+ osm_stop_cache: dict[tuple[int, str], dict],
+) -> list[dict]:
+ if first_service_ids == set() or second_service_ids == set():
+ return []
+ if latest_arrival is not None and latest_arrival <= earliest_departure:
+ return []
+ # Each group is (destination stop ids, {canonical stop id: final-walk template}).
+ destination_groups = _destination_stop_groups_with_final_walks(
+ db,
+ dataset_id=second_dataset_id,
+ to_stop_ids=to_stop_ids,
+ target_canonical_stop_id=target_canonical_stop_id,
+ )
+ # Second legs keyed by the canonical stop where they are boarded.
+ second_legs: dict[int, list[_SecondLegOption]] = {}
+ final_walk_by_canonical: dict[int, _RouterWalkBacklink] = {}
+ for destination_stop_ids, group_walks in destination_groups:
+ group_second_legs = _targeted_second_leg_options(
+ db,
+ second_dataset_id,
+ second_service_ids,
+ destination_stop_ids,
+ earliest_departure,
+ latest_arrival,
+ )
+ for canonical_stop_id, options in group_second_legs.items():
+ second_legs.setdefault(canonical_stop_id, []).extend(options)
+ final_walk_by_canonical.update(group_walks)
+ if not second_legs:
+ return []
+ # Canonical stop of every second-leg destination, used to look up its final walk.
+ second_dest_canonical = _canonical_ids_for_stop_ids(
+ db,
+ second_dataset_id,
+ {option.dest.stop_id for options in second_legs.values() for option in options},
+ )
+ # Stop ids in the FIRST dataset for each canonical stop where a second leg can be boarded.
+ transfer_stop_ids_by_canonical = _gtfs_stop_ids_for_canonical_ids(db, first_dataset_id, set(second_legs))
+ candidates: list[_OneTransferCandidate] = []
+ seen: set[tuple[object, ...]] = set()
+
+ # Flatten and order second legs: earliest arrival first, later departure first on ties.
+ second_leg_options = sorted(
+ [
+ (canonical_stop_id, option)
+ for canonical_stop_id, options in second_legs.items()
+ for option in options
+ ],
+ key=lambda item: (item[1].arrival_seconds, -item[1].departure_seconds),
+ )
+ # Latest moment any first leg may arrive and still make some second leg;
+ # falls back to earliest_departure when no second leg leaves enough slack.
+ latest_first_arrival_limit = max(
+ (
+ option.departure_seconds - transfer_seconds
+ for _, option in second_leg_options
+ if option.departure_seconds - transfer_seconds >= earliest_departure
+ ),
+ default=earliest_departure,
+ )
+ # One lookup of first legs for all transfer stops; reused for every second leg below.
+ first_options_by_canonical = _first_leg_options_to_transfer_stops(
+ db=db,
+ dataset_id=first_dataset_id,
+ service_ids=first_service_ids,
+ from_stop_ids=from_stop_ids,
+ transfer_stop_ids_by_canonical=transfer_stop_ids_by_canonical,
+ earliest_departure=earliest_departure,
+ latest_arrival=latest_first_arrival_limit,
+ )
+ searched_second_legs = 0
+ best_candidate_arrival: int | None = None
+ for canonical_stop_id, second in second_leg_options:
+ # The search budget only ends the loop once at least one candidate exists.
+ if searched_second_legs >= MAX_BACKWARD_SECOND_LEG_OPTIONS and candidates:
+ break
+ # Second legs are in arrival order, so once one arrives after the best
+ # candidate found so far the loop stops.
+ if best_candidate_arrival is not None and candidates and second.arrival_seconds > best_candidate_arrival:
+ break
+ searched_second_legs += 1
+ transfer_stop_ids = transfer_stop_ids_by_canonical.get(canonical_stop_id)
+ if not transfer_stop_ids:
+ continue
+ latest_first_arrival = second.departure_seconds - transfer_seconds
+ if latest_first_arrival < earliest_departure:
+ continue
+ # Within one dataset the same trip must not serve as both legs.
+ excluded_trip_id = second.trip.trip_id if first_dataset_id == second_dataset_id else None
+ first = _best_first_leg_for_second(
+ first_options_by_canonical.get(canonical_stop_id, []),
+ latest_arrival=latest_first_arrival,
+ excluded_trip_id=excluded_trip_id,
+ )
+ if first is None:
+ continue
+ # A transfer at the origin's own canonical stop is skipped.
+ if origin_canonical_stop_id is not None and canonical_stop_id == origin_canonical_stop_id:
+ continue
+ final_walk_template = final_walk_by_canonical.get(second_dest_canonical.get(second.dest.stop_id))
+ final_walk = None
+ candidate_arrival = second.arrival_seconds
+ if final_walk_template is not None:
+ # Skip walks that would start at the origin's own canonical stop.
+ if origin_canonical_stop_id is not None and final_walk_template.from_stop.id == origin_canonical_stop_id:
+ continue
+ candidate_arrival = second.arrival_seconds + _walking_transfer_seconds(final_walk_template.distance_m)
+ if latest_arrival is not None and candidate_arrival >= latest_arrival:
+ continue
+ # The template carries zeroed times; stamp this candidate's actual walk times.
+ final_walk = _RouterWalkBacklink(
+ previous_label=final_walk_template.previous_label,
+ from_stop=final_walk_template.from_stop,
+ to_stop=final_walk_template.to_stop,
+ distance_m=final_walk_template.distance_m,
+ departure_seconds=second.arrival_seconds,
+ arrival_seconds=candidate_arrival,
+ )
+ # De-duplication key over both legs and the final-walk target.
+ key = (
+ first_dataset_id,
+ first.trip.trip_id,
+ first.origin.stop_sequence,
+ first.dest.stop_id,
+ second_dataset_id,
+ second.trip.trip_id,
+ second.origin.stop_sequence,
+ second.dest.stop_sequence,
+ None if final_walk is None else final_walk.to_stop.stop_id,
+ )
+ if key in seen:
+ continue
+ seen.add(key)
+ best_candidate_arrival = candidate_arrival if best_candidate_arrival is None else min(best_candidate_arrival, candidate_arrival)
+ candidates.append(
+ _OneTransferCandidate(
+ arrival_seconds=candidate_arrival,
+ departure_seconds=first.departure_seconds,
+ first_route=first.route,
+ first_trip=first.trip,
+ first_origin=first.origin,
+ first_dest=first.dest,
+ second=second,
+ final_walk=final_walk,
+ )
+ )
+ if len(candidates) >= MAX_TARGET_TRANSFER_CANDIDATES:
+ break
+
+ # Keep a single candidate per onward leg (the one with the highest feeder rank).
+ tightened_candidates = _latest_feeder_by_onward_leg(candidates)
+ journeys: list[dict] = []
+ # Leg payloads are only built for the best max(limit * 4, limit) candidates.
+ for candidate in sorted(tightened_candidates, key=_one_transfer_candidate_sort_key)[
+ : max(limit * 4, limit)
+ ]:
+ first_leg = _leg_payload(
+ db,
+ first_dataset_id,
+ candidate.first_route,
+ candidate.first_trip,
+ candidate.first_origin,
+ candidate.first_dest,
+ stop_cache,
+ osm_stop_cache,
+ )
+ second_leg = _leg_payload(
+ db,
+ second_dataset_id,
+ candidate.second.route,
+ candidate.second.trip,
+ candidate.second.origin,
+ candidate.second.dest,
+ stop_cache,
+ osm_stop_cache,
+ )
+ legs = [first_leg, second_leg]
+ if candidate.final_walk is not None:
+ legs.append(_walk_leg_payload(db, candidate.final_walk, second_dataset_id))
+ journeys.append(_journey_payload(legs))
+
+ return sorted(journeys, key=_journey_sort_key)[:limit]
+
+
+def _destination_stop_groups_with_final_walks(
+    db: Session,
+    dataset_id: int,
+    to_stop_ids: tuple[str, ...],
+    target_canonical_stop_id: int | None,
+) -> list[tuple[tuple[str, ...], dict[int, _RouterWalkBacklink]]]:
+    """Group destination stop ids with the final walk needed to reach the target.
+
+    The first group is always ``(to_stop_ids, {})``: the requested stops, no walk.
+    For every canonical stop returned by the walking-transfer lookup for the target
+    that has a summary and GTFS stop ids in ``dataset_id``, one further group is
+    appended. It holds that stop's ids (at most ``MAX_GROUP_STOP_IDS``) and a
+    one-entry mapping from the canonical stop id to a walk template whose times
+    are zero.
+
+    Only the first group is returned when there is no target canonical stop, the
+    target has no summary, or the walking-transfer lookup returns no rows.
+    """
+
+    def _requested_stops_only() -> list:
+        return [(to_stop_ids, {})]
+
+    if target_canonical_stop_id is None:
+        return _requested_stops_only()
+    target_summary = _canonical_stop_summaries(db, dataset_id, {target_canonical_stop_id}).get(
+        target_canonical_stop_id
+    )
+    if target_summary is None:
+        return _requested_stops_only()
+
+    if settings.is_postgresql_database:
+        nearby_rows = _walking_transfer_rows_postgres(db, dataset_id, (target_canonical_stop_id,))
+    else:
+        nearby_rows = _walking_transfer_rows_sqlite(db, dataset_id, (target_canonical_stop_id,))
+    nearby_ids = {int(neighbour) for _, neighbour, _ in nearby_rows}
+    if not nearby_ids:
+        return _requested_stops_only()
+
+    summaries = _canonical_stop_summaries(db, dataset_id, nearby_ids)
+    walk_templates: dict[int, _RouterWalkBacklink] = {}
+    for _, raw_id, distance_m in nearby_rows:
+        neighbour_id = int(raw_id)
+        neighbour_summary = summaries.get(neighbour_id)
+        if neighbour_summary is None:
+            continue
+        walk_templates[neighbour_id] = _RouterWalkBacklink(
+            previous_label=_RouterLabel(neighbour_id, 0),
+            from_stop=neighbour_summary,
+            to_stop=target_summary,
+            distance_m=float(distance_m or 0),
+            departure_seconds=0,
+            arrival_seconds=0,
+        )
+
+    stop_ids_by_canonical = _gtfs_stop_ids_for_canonical_ids(db, dataset_id, set(walk_templates))
+    groups = _requested_stops_only()
+    for canonical_stop_id, stop_ids in stop_ids_by_canonical.items():
+        template = walk_templates.get(canonical_stop_id)
+        if stop_ids and template is not None:
+            groups.append((stop_ids[:MAX_GROUP_STOP_IDS], {canonical_stop_id: template}))
+    return groups
+
+
+def _canonical_ids_for_stop_ids(db: Session, dataset_id: int, stop_ids: set[str]) -> dict[str, int]:
+    """Map GTFS stop ids of one dataset to their linked canonical stop ids.
+
+    Only ``CanonicalStopLink`` rows with ``object_type == "gtfs_stop"`` are
+    considered. Stop ids without a link are absent from the result; an empty
+    ``stop_ids`` returns ``{}`` without querying.
+    """
+    if not stop_ids:
+        return {}
+    link_query = select(CanonicalStopLink.external_id, CanonicalStopLink.canonical_stop_id).where(
+        CanonicalStopLink.object_type == "gtfs_stop",
+        CanonicalStopLink.dataset_id == dataset_id,
+        CanonicalStopLink.external_id.in_(stop_ids),
+    )
+    canonical_by_stop: dict[str, int] = {}
+    for external_id, canonical_id in db.execute(link_query).all():
+        canonical_by_stop[str(external_id)] = int(canonical_id)
+    return canonical_by_stop
+
+
+# Compose journeys as "one access ride to a hub" + "onward journey from that hub".
+#
+# For every dataset the origin selection has stop ids in, candidate hubs come from
+# _access_transfer_candidates(). For each hub, find_journeys() is called again starting
+# at the hub no earlier than the access arrival plus `transfer_seconds`, with one fewer
+# transfer allowed and `_allow_access_transfer=False`. Each onward journey is then
+# prefixed with the access leg.
+#
+# Returns at most `limit` composed journeys ordered by _journey_sort_key.
+def _find_access_transfer_journeys(
+ db: Session,
+ from_selection: StopSelection,
+ to_stop_id: int | str,
+ earliest_departure: int,
+ max_transfers: int,
+ transfer_seconds: int,
+ limit: int,
+ source_ids: list[int] | None,
+ service_date: date | None,
+ stop_cache: dict[tuple[int, str], StopSummary],
+ osm_stop_cache: dict[tuple[int, str], dict],
+) -> list[dict]:
+ journeys: list[dict] = []
+ for dataset_id, from_stop_ids in from_selection.stop_ids_by_dataset.items():
+ service_ids = _service_ids_by_dataset(db, [dataset_id], service_date).get(dataset_id)
+ # An empty service-id set skips the dataset; None is passed on unchanged.
+ if service_ids == set():
+ continue
+ candidates = _access_transfer_candidates(
+ db=db,
+ dataset_id=dataset_id,
+ service_ids=service_ids,
+ from_selection=from_selection,
+ from_stop_ids=from_stop_ids,
+ earliest_departure=earliest_departure,
+ )
+ for candidate in candidates:
+ # Payload of the ride from the origin to the hub; reused for every onward journey.
+ access_leg = _leg_payload(
+ db=db,
+ dataset_id=dataset_id,
+ route=candidate.option.route,
+ trip=candidate.option.trip,
+ origin=candidate.option.origin,
+ dest=candidate.option.dest,
+ stop_cache=stop_cache,
+ osm_stop_cache=osm_stop_cache,
+ )
+ # Earliest onward departure, formatted as a GTFS time string for find_journeys().
+ onward_departure = format_gtfs_time(candidate.option.arrival_seconds + transfer_seconds)
+ if onward_departure is None:
+ continue
+ try:
+ onward = find_journeys(
+ db=db,
+ from_stop_id=_stop_place_token(candidate.canonical_stop_id, dataset_id),
+ to_stop_id=to_stop_id,
+ departure=onward_departure,
+ max_transfers=max(0, max_transfers - 1),
+ limit=limit,
+ transfer_seconds=transfer_seconds,
+ source_ids=source_ids,
+ service_date=service_date,
+ _allow_access_transfer=False,
+ )
+ except ValueError:
+ # A ValueError from the onward search only discards this hub.
+ continue
+ for onward_journey in onward.get("journeys", [])[:limit]:
+ journeys.append(_prepend_access_leg_to_journey(access_leg, onward_journey))
+ # Stop collecting once limit * 3 journeys exist.
+ # NOTE(review): the nesting level of the two checks below is not visible in this
+ # view (indentation is collapsed); confirm which loops they terminate.
+ if len(journeys) >= limit * 3:
+ break
+ if len(journeys) >= limit * 3:
+ break
+ return sorted(journeys, key=_journey_sort_key)[:limit]
+
+
+def _access_transfer_candidates(
+    db: Session,
+    dataset_id: int,
+    service_ids: set[str] | None,
+    from_selection: StopSelection,
+    from_stop_ids: tuple[str, ...],
+    earliest_departure: int,
+) -> list[_AccessTransferCandidate]:
+    """Pick hub stops reachable by a single ride from the origin.
+
+    Boardings at ``from_stop_ids`` departing within ``ACCESS_TRANSFER_MAX_SECONDS``
+    of ``earliest_departure`` are followed along their trips. A downstream call
+    becomes a candidate when it has a canonical stop other than the origin's and
+    ``_station_importance_rank`` returns at most 1 for it. Scanning of a trip stops
+    at the first call arriving more than ``ACCESS_TRANSFER_MAX_SECONDS`` after
+    ``earliest_departure``. Per canonical stop only the best candidate under
+    ``_access_transfer_sort_key`` is kept.
+
+    Returns:
+        At most ``MAX_ACCESS_TRANSFER_CANDIDATES`` candidates in sort-key order;
+        ``[]`` when there are no boardings.
+    """
+    boardings = _origin_boardings(
+        db=db,
+        dataset_id=dataset_id,
+        service_ids=service_ids,
+        stop_ids=from_stop_ids,
+        earliest_departure=earliest_departure,
+        latest_departure=earliest_departure + ACCESS_TRANSFER_MAX_SECONDS,
+    )
+    if not boardings:
+        return []
+    calls_by_trip = _stop_times_by_trip(db, dataset_id, sorted({boarding.trip.trip_id for boarding in boardings}))
+    stop_to_canonical = _canonical_ids_for_trip_calls(db, dataset_id, calls_by_trip)
+    canonical_ids = sorted(set(stop_to_canonical.values()))
+    canonical_names = {
+        int(canonical.id): canonical.name
+        for canonical in db.scalars(select(CanonicalStop).where(CanonicalStop.id.in_(canonical_ids))).all()
+    }
+    # The same stop shows up on many trips; remember each name so the per-stop
+    # lookup query runs once per stop id instead of once per downstream call.
+    stop_names: dict[str, str | None] = {}
+    candidates: dict[int, _AccessTransferCandidate] = {}
+    for boarding in boardings:
+        departure = _departure_seconds(boarding.call)
+        if departure is None or departure < earliest_departure:
+            continue
+        for call in calls_by_trip.get(boarding.trip.trip_id, []):
+            # Only calls after the boarding call can be alighting points.
+            if call.stop_sequence <= boarding.call.stop_sequence:
+                continue
+            arrival = _arrival_seconds(call)
+            if arrival is None or arrival < departure:
+                continue
+            # Past the access window: stop scanning this trip.
+            if arrival - earliest_departure > ACCESS_TRANSFER_MAX_SECONDS:
+                break
+            canonical_stop_id = stop_to_canonical.get(call.stop_id)
+            if canonical_stop_id is None or canonical_stop_id == from_selection.canonical_stop_id:
+                continue
+            if call.stop_id not in stop_names:
+                stop_names[call.stop_id] = _stop_name_for_stop_id(db, dataset_id, call.stop_id)
+            rank = _station_importance_rank(canonical_names.get(canonical_stop_id), stop_names[call.stop_id])
+            if rank > 1:
+                continue
+            option = _FirstLegOption(
+                departure_seconds=departure,
+                arrival_seconds=arrival,
+                origin=boarding.call,
+                dest=call,
+                trip=boarding.trip,
+                route=boarding.route,
+            )
+            current = candidates.get(canonical_stop_id)
+            candidate = _AccessTransferCandidate(canonical_stop_id=canonical_stop_id, option=option, rank=rank)
+            if current is None or _access_transfer_sort_key(candidate) < _access_transfer_sort_key(current):
+                candidates[canonical_stop_id] = candidate
+    return sorted(candidates.values(), key=_access_transfer_sort_key)[:MAX_ACCESS_TRANSFER_CANDIDATES]
+
+
+def _access_transfer_sort_key(candidate: _AccessTransferCandidate) -> tuple[int, int, int, str]:
+    """Ordering key for access-transfer candidates.
+
+    Compares by rank, then arrival at the hub, then ride duration, then the
+    alighting stop id as a final tie-breaker.
+    """
+    ride = candidate.option
+    ride_seconds = ride.arrival_seconds - ride.departure_seconds
+    return candidate.rank, ride.arrival_seconds, ride_seconds, ride.dest.stop_id
+
+
+def _stop_name_for_stop_id(db: Session, dataset_id: int, stop_id: str) -> str | None:
+    """Return the name of the GTFS stop ``stop_id`` in ``dataset_id``, or ``None`` if no such stop exists."""
+    lookup = select(GtfsStop).where(GtfsStop.dataset_id == dataset_id, GtfsStop.stop_id == stop_id)
+    match = db.scalar(lookup)
+    if match is None:
+        return None
+    return match.name
+
+
+def _prepend_access_leg_to_journey(access_leg: dict, onward_journey: dict) -> dict:
+    """Return a new journey payload with ``access_leg`` placed before ``onward_journey``.
+
+    The access leg is first normalised through ``_journey_payload``. Map features
+    of both parts are merged with ``_combine_via_features``. Departure comes from
+    the access leg and arrival from the onward journey; the duration is ``None``
+    when either is missing and never negative otherwise. Walking legs do not count
+    towards ``transfers``. The result is flagged with ``access_transfer_composed``.
+    """
+    access_journey = _journey_payload([access_leg])
+    merged_features = _combine_via_features(
+        access_journey.get("features") or {},
+        onward_journey.get("features") or {},
+        first_leg_count=1,
+    )
+    legs = [access_journey["legs"][0]]
+    legs.extend(onward_journey.get("legs") or [])
+
+    departure = access_leg.get("departure_seconds")
+    arrival = onward_journey.get("arrival_seconds")
+    transit_leg_count = sum(1 for leg in legs if leg.get("mode") != "walk")
+    if departure is None or arrival is None:
+        duration_seconds = None
+    else:
+        duration_seconds = max(0, int(arrival) - int(departure))
+
+    return {
+        "transfers": max(0, transit_leg_count - 1),
+        "departure_seconds": departure,
+        "arrival_seconds": arrival,
+        "departure_time": format_gtfs_time(departure),
+        "arrival_time": format_gtfs_time(arrival),
+        "departure_time_label": format_gtfs_time_label(departure),
+        "arrival_time_label": format_gtfs_time_label(arrival),
+        "duration_seconds": duration_seconds,
+        "duration_minutes": duration_minutes_ceil(duration_seconds),
+        "duration_label": format_duration_label(duration_seconds),
+        "legs": legs,
+        "features": feature_collection(merged_features),
+        "access_transfer_composed": True,
+    }
+
+
+def _first_leg_options_to_transfer_stops(
+    db: Session,
+    dataset_id: int,
+    service_ids: set[str] | None,
+    from_stop_ids: tuple[str, ...],
+    transfer_stop_ids_by_canonical: dict[int, tuple[str, ...]],
+    earliest_departure: int,
+    latest_arrival: int,
+) -> dict[int, list[_FirstLegOption]]:
+    """Collect first-leg rides from the origin stops to each transfer stop.
+
+    Boardings at ``from_stop_ids`` are followed along their trips; every later call
+    at one of the transfer stops that arrives no later than ``latest_arrival``
+    yields an option. Options are grouped by the transfer stop's canonical id,
+    ordered latest departure first (then latest arrival), and each group is cut to
+    ``MAX_TRANSFER_BOARDINGS`` entries.
+
+    Returns ``{}`` when there are no transfer stops or no boardings.
+    """
+    if not transfer_stop_ids_by_canonical:
+        return {}
+    stop_to_canonical: dict[str, int] = {}
+    for canonical_stop_id, stop_ids in transfer_stop_ids_by_canonical.items():
+        for stop_id in stop_ids:
+            stop_to_canonical[stop_id] = canonical_stop_id
+    if not stop_to_canonical:
+        return {}
+
+    boardings = _origin_boardings(
+        db=db,
+        dataset_id=dataset_id,
+        service_ids=service_ids,
+        stop_ids=from_stop_ids,
+        earliest_departure=earliest_departure,
+        latest_departure=latest_arrival,
+    )
+    if not boardings:
+        return {}
+
+    trip_ids = sorted({boarding.trip.trip_id for boarding in boardings})
+    calls_by_trip = _stop_times_by_trip(db, dataset_id, trip_ids)
+    options_by_canonical: dict[int, list[_FirstLegOption]] = {}
+    emitted: set[tuple[object, ...]] = set()
+    for boarding in boardings:
+        boarded_at = _departure_seconds(boarding.call)
+        if boarded_at is None or boarded_at < earliest_departure:
+            continue
+        for call in calls_by_trip.get(boarding.trip.trip_id, []):
+            # Only calls after the boarding call can be alighting points.
+            if call.stop_sequence <= boarding.call.stop_sequence:
+                continue
+            canonical_stop_id = stop_to_canonical.get(call.stop_id)
+            if canonical_stop_id is None:
+                continue
+            alighted_at = _arrival_seconds(call)
+            if alighted_at is None or alighted_at < boarded_at or alighted_at > latest_arrival:
+                continue
+            dedupe_key = (canonical_stop_id, boarding.trip.trip_id, boarding.call.stop_sequence, call.stop_sequence)
+            if dedupe_key in emitted:
+                continue
+            emitted.add(dedupe_key)
+            options_by_canonical.setdefault(canonical_stop_id, []).append(
+                _FirstLegOption(
+                    departure_seconds=boarded_at,
+                    arrival_seconds=alighted_at,
+                    origin=boarding.call,
+                    dest=call,
+                    trip=boarding.trip,
+                    route=boarding.route,
+                )
+            )
+
+    # Latest departures first, so callers scanning a group see the tightest feeder first.
+    return {
+        canonical_stop_id: sorted(
+            options,
+            key=lambda option: (option.departure_seconds, option.arrival_seconds),
+            reverse=True,
+        )[:MAX_TRANSFER_BOARDINGS]
+        for canonical_stop_id, options in options_by_canonical.items()
+    }
+
+
+def _best_first_leg_for_second(
+    options: list[_FirstLegOption],
+    latest_arrival: int,
+    excluded_trip_id: str | None,
+) -> _FirstLegOption | None:
+    """Return the first option, in list order, that can feed a given second leg.
+
+    An option qualifies when it arrives no later than ``latest_arrival`` and, if
+    ``excluded_trip_id`` is truthy, does not run on that trip. ``None`` is returned
+    when nothing qualifies.
+    """
+
+    def _usable(option: _FirstLegOption) -> bool:
+        if excluded_trip_id and option.trip.trip_id == excluded_trip_id:
+            return False
+        return option.arrival_seconds <= latest_arrival
+
+    return next((option for option in options if _usable(option)), None)
+
+
+def _latest_feeder_by_onward_leg(candidates: list[_OneTransferCandidate]) -> list[_OneTransferCandidate]:
+    """Keep one candidate per onward (second) leg: the one with the highest feeder rank.
+
+    Candidates sharing the same second leg are compared with
+    ``_one_transfer_feeder_rank``; on a tie the one seen first wins. The result
+    follows the order in which each second leg first appeared.
+    """
+    best_by_onward: dict[tuple[object, ...], _OneTransferCandidate] = {}
+    for candidate in candidates:
+        onward = candidate.second
+        onward_key = (
+            onward.canonical_stop_id,
+            onward.trip.dataset_id,
+            onward.trip.trip_id,
+            onward.origin.stop_sequence,
+            onward.dest.stop_sequence,
+            onward.departure_seconds,
+            onward.arrival_seconds,
+        )
+        incumbent = best_by_onward.get(onward_key)
+        if incumbent is not None and _one_transfer_feeder_rank(candidate) <= _one_transfer_feeder_rank(incumbent):
+            continue
+        best_by_onward[onward_key] = candidate
+    return list(best_by_onward.values())
+
+
+def _one_transfer_feeder_rank(candidate: _OneTransferCandidate) -> tuple[int, int]:
+    """Rank the feeder (first) leg of a candidate; a larger tuple means a later feeder.
+
+    The rank is ``(first-leg departure, first-leg arrival)``. A missing arrival is
+    replaced by ``-1`` so it ranks below every real arrival time.
+    """
+    first_arrival = _arrival_seconds(candidate.first_dest)
+    # Compare against None explicitly: `value or -1` would also replace a genuine
+    # arrival of 0 seconds with the sentinel.
+    if first_arrival is None:
+        first_arrival = -1
+    return (candidate.departure_seconds, first_arrival)
+
+
+def _one_transfer_candidate_sort_key(candidate: _OneTransferCandidate) -> tuple[float, float, int]:
+    """Ordering key for one-transfer candidates: earliest arrival first, then latest departure.
+
+    The trailing constant ``1`` is the same for every candidate.
+    """
+    arrival = float(candidate.arrival_seconds)
+    departure = float(candidate.departure_seconds)
+    return arrival, -departure, 1
+
+
+def _latest_direct_leg_to_stops(
+    db: Session,
+    dataset_id: int,
+    service_ids: set[str] | None,
+    from_stop_ids: tuple[str, ...],
+    to_stop_ids: tuple[str, ...],
+    earliest_departure: int,
+    latest_arrival: int,
+    excluded_trip_id: str | None = None,
+) -> _FirstLegOption | None:
+    """Return the latest-departing single-trip ride that fits a time window.
+
+    Candidate rows come from ``_sidecar_latest_direct_leg_rows`` when
+    ``uses_sidecar_stop_times`` is true for the dataset, otherwise from a
+    ``GtfsStopTime`` self-join ordered by departure descending and limited to 120
+    rows. The first row whose departure is not before ``earliest_departure`` and
+    whose arrival is not after ``latest_arrival`` wins; both back ends share the
+    same validation helper.
+
+    Args:
+        db: Database session.
+        dataset_id: GTFS dataset to search.
+        service_ids: Service-id filter; an empty set short-circuits to ``None``.
+        from_stop_ids: Stop ids where boarding is allowed; empty gives ``None``.
+        to_stop_ids: Stop ids where alighting is allowed; empty gives ``None``.
+        earliest_departure: Lower bound (seconds) on the boarding departure.
+        latest_arrival: Upper bound (seconds, inclusive) on the alighting arrival.
+        excluded_trip_id: Trip that must not be used, if any.
+
+    Returns:
+        The matching ``_FirstLegOption`` or ``None`` when no row qualifies.
+    """
+    if not from_stop_ids or not to_stop_ids:
+        return None
+    if service_ids == set():
+        return None
+    if uses_sidecar_stop_times(db, dataset_id):
+        rows = _sidecar_latest_direct_leg_rows(
+            db=db,
+            dataset_id=dataset_id,
+            service_ids=service_ids,
+            from_stop_ids=from_stop_ids,
+            to_stop_ids=to_stop_ids,
+            earliest_departure=earliest_departure,
+            latest_arrival=latest_arrival,
+            excluded_trip_id=excluded_trip_id,
+        )
+        return _first_leg_option_in_window(rows, earliest_departure, latest_arrival)
+
+    Origin = aliased(GtfsStopTime)
+    Dest = aliased(GtfsStopTime)
+    stmt = (
+        select(Origin, Dest, GtfsTrip, GtfsRoute)
+        .join(
+            Dest,
+            and_(
+                Dest.dataset_id == Origin.dataset_id,
+                Dest.trip_id == Origin.trip_id,
+                Dest.stop_sequence > Origin.stop_sequence,
+            ),
+        )
+        .join(GtfsTrip, and_(GtfsTrip.dataset_id == Origin.dataset_id, GtfsTrip.trip_id == Origin.trip_id))
+        .join(GtfsRoute, and_(GtfsRoute.dataset_id == GtfsTrip.dataset_id, GtfsRoute.route_id == GtfsTrip.route_id))
+        .where(
+            Origin.dataset_id == dataset_id,
+            Origin.stop_id.in_(from_stop_ids),
+            Dest.stop_id.in_(to_stop_ids),
+            # Rows without stored seconds are kept and re-checked in Python below.
+            or_(Origin.departure_seconds.is_(None), Origin.departure_seconds >= earliest_departure),
+            or_(Dest.arrival_seconds.is_(None), Dest.arrival_seconds <= latest_arrival),
+        )
+        .order_by(
+            Origin.departure_seconds.desc(),
+            Origin.departure_time.desc(),
+            Dest.arrival_seconds.desc(),
+            Dest.arrival_time.desc(),
+            Origin.trip_id,
+        )
+        .limit(120)
+    )
+    stmt = _where_trip_service_active(stmt, GtfsTrip, service_ids)
+    if excluded_trip_id:
+        stmt = stmt.where(GtfsTrip.trip_id != excluded_trip_id)
+    return _first_leg_option_in_window(db.execute(stmt).all(), earliest_departure, latest_arrival)
+
+
+def _first_leg_option_in_window(rows, earliest_departure: int, latest_arrival: int) -> _FirstLegOption | None:
+    """Return an option for the first ``(origin, dest, trip, route)`` row with usable times inside the window."""
+    for origin, dest, trip, route in rows:
+        departure = _departure_seconds(origin)
+        arrival = _arrival_seconds(dest)
+        if departure is None or arrival is None:
+            continue
+        if departure < earliest_departure or arrival > latest_arrival or arrival < departure:
+            continue
+        return _FirstLegOption(
+            departure_seconds=departure,
+            arrival_seconds=arrival,
+            origin=origin,
+            dest=dest,
+            trip=trip,
+            route=route,
+        )
+    return None
+
+
+def _targeted_second_leg_options(
+    db: Session,
+    dataset_id: int,
+    service_ids: set[str] | None,
+    to_stop_ids: tuple[str, ...],
+    earliest_departure: int,
+    latest_arrival: int | None,
+) -> dict[int, list[_SecondLegOption]]:
+    """Collect onward legs that end at one of ``to_stop_ids``, grouped by boarding stop.
+
+    Starting from the arrivals returned by ``_destination_arrivals``, every earlier
+    call on the same trip that is not itself a destination stop, departs at or after
+    ``earliest_departure`` and has a canonical stop becomes a boarding option. The
+    result maps the boarding stop's canonical id to its options ordered by
+    departure, then arrival, and cut to ``MAX_TARGET_SECOND_LEGS_PER_STOP`` each.
+
+    Returns ``{}`` for empty ``to_stop_ids``, an empty ``service_ids`` set, or when
+    no arrivals are found.
+    """
+    if not to_stop_ids or service_ids == set():
+        return {}
+    arrivals = _destination_arrivals(db, dataset_id, service_ids, to_stop_ids, earliest_departure, latest_arrival)
+    if not arrivals:
+        return {}
+
+    trip_ids = sorted({trip.trip_id for _, trip, _ in arrivals})
+    calls_by_trip = _stop_times_by_trip(db, dataset_id, trip_ids)
+    stop_to_canonical = _canonical_ids_for_trip_calls(db, dataset_id, calls_by_trip)
+    destination_stop_ids = set(to_stop_ids)
+    options_by_stop: dict[int, list[_SecondLegOption]] = {}
+    emitted: set[tuple[object, ...]] = set()
+    for dest, trip, route in arrivals:
+        arrives_at = _arrival_seconds(dest)
+        if arrives_at is None:
+            continue
+        for call in calls_by_trip.get(trip.trip_id, []):
+            # Stop at the destination call: only earlier calls are boarding points.
+            if call.stop_sequence >= dest.stop_sequence:
+                break
+            if call.stop_id in destination_stop_ids:
+                continue
+            departs_at = _departure_seconds(call)
+            if departs_at is None:
+                continue
+            if departs_at < earliest_departure or departs_at > arrives_at:
+                continue
+            canonical_stop_id = stop_to_canonical.get(call.stop_id)
+            if canonical_stop_id is None:
+                continue
+            dedupe_key = (canonical_stop_id, trip.trip_id, call.stop_sequence, dest.stop_sequence)
+            if dedupe_key in emitted:
+                continue
+            emitted.add(dedupe_key)
+            options_by_stop.setdefault(canonical_stop_id, []).append(
+                _SecondLegOption(
+                    canonical_stop_id=canonical_stop_id,
+                    departure_seconds=departs_at,
+                    arrival_seconds=arrives_at,
+                    origin=call,
+                    dest=dest,
+                    trip=trip,
+                    route=route,
+                )
+            )
+
+    capped: dict[int, list[_SecondLegOption]] = {}
+    for canonical_stop_id, options in options_by_stop.items():
+        options.sort(key=lambda item: (item.departure_seconds, item.arrival_seconds))
+        earliest_options = options[:MAX_TARGET_SECOND_LEGS_PER_STOP]
+        if earliest_options:
+            capped[canonical_stop_id] = earliest_options
+    return capped
+
+
+def _destination_arrivals(
+    db: Session,
+    dataset_id: int,
+    service_ids: set[str] | None,
+    stop_ids: tuple[str, ...],
+    earliest_departure: int,
+    latest_arrival: int | None,
+) -> list[tuple[GtfsStopTime, GtfsTrip, GtfsRoute]]:
+    """Return ``(stop_time, trip, route)`` rows for arrivals at ``stop_ids`` within a window.
+
+    Rows are read from ``_sidecar_destination_arrival_rows`` when
+    ``uses_sidecar_stop_times`` is true for the dataset, otherwise from a query
+    ordered by arrival and limited to ``MAX_TARGET_DESTINATION_ARRIVALS``. Either
+    way a row is kept only if its arrival is known, not before
+    ``earliest_departure`` and, when ``latest_arrival`` is given, strictly before it.
+    An empty ``service_ids`` set returns ``[]``.
+    """
+    if service_ids == set():
+        return []
+
+    def _inside_window(candidate_rows) -> list[tuple[GtfsStopTime, GtfsTrip, GtfsRoute]]:
+        kept = []
+        for stop_time, trip, route in candidate_rows:
+            arrival = _arrival_seconds(stop_time)
+            if arrival is None or arrival < earliest_departure:
+                continue
+            if latest_arrival is not None and arrival >= latest_arrival:
+                continue
+            kept.append((stop_time, trip, route))
+        return kept
+
+    if uses_sidecar_stop_times(db, dataset_id):
+        sidecar_rows = _sidecar_destination_arrival_rows(
+            db=db,
+            dataset_id=dataset_id,
+            service_ids=service_ids,
+            stop_ids=stop_ids,
+            earliest_departure=earliest_departure,
+            latest_arrival=latest_arrival,
+        )
+        return _inside_window(sidecar_rows)
+
+    trip_join = and_(GtfsTrip.dataset_id == GtfsStopTime.dataset_id, GtfsTrip.trip_id == GtfsStopTime.trip_id)
+    route_join = and_(GtfsRoute.dataset_id == GtfsTrip.dataset_id, GtfsRoute.route_id == GtfsTrip.route_id)
+    stmt = (
+        select(GtfsStopTime, GtfsTrip, GtfsRoute)
+        .join(GtfsTrip, trip_join)
+        .join(GtfsRoute, route_join)
+        .where(GtfsStopTime.dataset_id == dataset_id, GtfsStopTime.stop_id.in_(stop_ids))
+        # Rows without stored arrival_seconds are kept and re-checked in Python.
+        .where(or_(GtfsStopTime.arrival_seconds.is_(None), GtfsStopTime.arrival_seconds >= earliest_departure))
+        .order_by(GtfsStopTime.arrival_seconds, GtfsStopTime.arrival_time, GtfsStopTime.trip_id)
+        .limit(MAX_TARGET_DESTINATION_ARRIVALS)
+    )
+    stmt = _where_trip_service_active(stmt, GtfsTrip, service_ids)
+    if latest_arrival is not None:
+        stmt = stmt.where(or_(GtfsStopTime.arrival_seconds.is_(None), GtfsStopTime.arrival_seconds <= latest_arrival))
+    return _inside_window(db.execute(stmt).all())
+
+
+# Immutable record of one possible boarding: a stop-time row at an origin stop
+# together with the trip and route it belongs to (built in _origin_boardings()).
+@dataclass(frozen=True)
+class _Boarding:
+ # Stop-time row where the rider boards.
+ call: GtfsStopTime
+ trip: GtfsTrip
+ route: GtfsRoute
+
+
+def _origin_boardings(
+    db: Session,
+    dataset_id: int,
+    service_ids: set[str] | None,
+    stop_ids: tuple[str, ...],
+    earliest_departure: int,
+    latest_departure: int | None = None,
+) -> list[_Boarding]:
+    """Return boardings at ``stop_ids`` that depart within a time window.
+
+    Rows are read from ``_sidecar_boarding_rows`` when ``uses_sidecar_stop_times``
+    is true for the dataset, otherwise from a query ordered by departure and
+    limited to ``MAX_DIRECT_ROWS``. A row is kept only if its departure is known,
+    not before ``earliest_departure`` and, when ``latest_departure`` is given,
+    strictly before it. Collection stops after ``MAX_TRANSFER_BOARDINGS`` boardings.
+    Empty ``stop_ids`` or an empty ``service_ids`` set returns ``[]``.
+    """
+    if not stop_ids or service_ids == set():
+        return []
+
+    def _collect(candidate_rows) -> list[_Boarding]:
+        picked: list[_Boarding] = []
+        for call, trip, route in candidate_rows:
+            departure = _departure_seconds(call)
+            if departure is None or departure < earliest_departure:
+                continue
+            if latest_departure is not None and departure >= latest_departure:
+                continue
+            picked.append(_Boarding(call=call, trip=trip, route=route))
+            if len(picked) >= MAX_TRANSFER_BOARDINGS:
+                break
+        return picked
+
+    if uses_sidecar_stop_times(db, dataset_id):
+        sidecar_rows = _sidecar_boarding_rows(
+            db=db,
+            dataset_id=dataset_id,
+            service_ids=service_ids,
+            stop_ids=stop_ids,
+            earliest_departure=earliest_departure,
+            latest_departure=latest_departure,
+            limit=MAX_DIRECT_ROWS,
+        )
+        return _collect(sidecar_rows)
+
+    trip_join = and_(GtfsTrip.dataset_id == GtfsStopTime.dataset_id, GtfsTrip.trip_id == GtfsStopTime.trip_id)
+    route_join = and_(GtfsRoute.dataset_id == GtfsTrip.dataset_id, GtfsRoute.route_id == GtfsTrip.route_id)
+    stmt = (
+        select(GtfsStopTime, GtfsTrip, GtfsRoute)
+        .join(GtfsTrip, trip_join)
+        .join(GtfsRoute, route_join)
+        .where(GtfsStopTime.dataset_id == dataset_id, GtfsStopTime.stop_id.in_(stop_ids))
+        # Rows without stored departure_seconds are kept and re-checked in Python.
+        .where(or_(GtfsStopTime.departure_seconds.is_(None), GtfsStopTime.departure_seconds >= earliest_departure))
+        .order_by(GtfsStopTime.departure_seconds, GtfsStopTime.departure_time, GtfsStopTime.trip_id)
+        .limit(MAX_DIRECT_ROWS)
+    )
+    stmt = _where_trip_service_active(stmt, GtfsTrip, service_ids)
+    if latest_departure is not None:
+        stmt = stmt.where(or_(GtfsStopTime.departure_seconds.is_(None), GtfsStopTime.departure_seconds < latest_departure))
+    return _collect(db.execute(stmt).all())
+
+
+# Thin wrapper: delegates to storage_stop_times_by_trip() with the same arguments and
+# returns its result (a mapping from trip id to that trip's stop-time rows, per the
+# return annotation).
+def _stop_times_by_trip(db: Session, dataset_id: int, trip_ids: list[str]) -> dict[str, list[GtfsStopTime]]:
+ return storage_stop_times_by_trip(db, dataset_id, trip_ids)
+
+
+def _leg_payload(  # Build the dict describing one transit leg ridden on `trip` from `origin` to `dest`.
+ db: Session,
+ dataset_id: int,
+ route: GtfsRoute,
+ trip: GtfsTrip,
+ origin: GtfsStopTime,  # stop-time row that supplies the "from" stop and the departure
+ dest: GtfsStopTime,  # stop-time row that supplies the "to" stop and the arrival
+ stop_cache: dict[tuple[int, str], StopSummary],  # handed on to _stop_for_id and _leg_stop_payloads
+ osm_stop_cache: dict[tuple[int, str], dict],  # handed on to _leg_stop_payloads
+) -> dict:
+ from_stop = _stop_for_id(db, dataset_id, origin.stop_id, stop_cache)
+ to_stop = _stop_for_id(db, dataset_id, dest.stop_id, stop_cache)
+ departure_seconds = _departure_seconds(origin)
+ arrival_seconds = _arrival_seconds(dest)
+ linked_route_pattern = route_pattern_for_trip(db, route, trip)
+ stops = _leg_stop_payloads(  # stop payloads for the stop_sequence range origin..dest of this trip
+ db=db,
+ dataset_id=dataset_id,
+ trip_id=trip.trip_id,
+ start_sequence=origin.stop_sequence,
+ end_sequence=dest.stop_sequence,
+ stop_cache=stop_cache,
+ osm_stop_cache=osm_stop_cache,
+ )
+ geometry, geometry_source, route_pattern = _leg_geometry(db, linked_route_pattern, route, trip, from_stop, to_stop, stops)  # route_pattern is the pattern _leg_geometry reports, which need not be linked_route_pattern
+ source = _source_payload_for_dataset_id(db, dataset_id)  # None is tolerated: source_id/source_name then become None
+ stop_count = len(stops)
+ return {
+ "dataset_id": dataset_id,
+ "source_id": None if source is None else source["id"],
+ "source_name": None if source is None else source["name"],
+ "route_db_id": route.id,
+ "route_id": route.route_id,
+ "route_ref": route.short_name,
+ "route_name": route.long_name,
+ "mode": route.mode,
+ "operator": route.operator_name,
+ "trip_id": trip.trip_id,
+ "route_pattern_id": None if route_pattern is None else route_pattern.id,
+ "route_pattern_source": None if route_pattern is None else route_pattern.source_kind,
+ "route_pattern_status": None if route_pattern is None else route_pattern.status,
+ "from": _stop_payload(from_stop),
+ "to": _stop_payload(to_stop),
+ "departure_seconds": departure_seconds,  # dropped again by _leg_public_payload before legs are published
+ "arrival_seconds": arrival_seconds,  # dropped again by _leg_public_payload before legs are published
+ "departure_time": format_gtfs_time(departure_seconds),
+ "arrival_time": format_gtfs_time(arrival_seconds),
+ "departure_time_label": format_gtfs_time_label(departure_seconds),
+ "arrival_time_label": format_gtfs_time_label(arrival_seconds),
+ "stop_count": stop_count,
+ "intermediate_stop_count": max(0, stop_count - 2),  # stop count minus the two ends, never negative
+ "geometry": geometry,  # dropped again by _leg_public_payload; _journey_payload turns it into a feature
+ "geometry_source": geometry_source,
+ "stops": stops,
+ }
+
+
+def _journey_payload(legs: list[dict]) -> dict:  # Summarise ordered leg dicts as one journey; indexes legs[0]/legs[-1], so an empty list raises IndexError.
+ departure = legs[0]["departure_seconds"]
+ arrival = legs[-1]["arrival_seconds"]
+ duration_seconds = None if departure is None or arrival is None else max(0, int(arrival) - int(departure))  # unknown if either end is unknown, otherwise floored at 0
+ transit_legs = [leg for leg in legs if leg.get("mode") != "walk"]  # every leg whose mode is not "walk"
+ features = []
+ for index, leg in enumerate(legs, start=1):  # leg numbers in the feature properties are 1-based
+ if leg["geometry"] is None:
+ continue
+ features.append(
+ {
+ "type": "Feature",
+ "geometry": leg["geometry"],
+ "properties": {
+ "leg": index,
+ "route_id": leg["route_id"],
+ "route_ref": leg["route_ref"],
+ "mode": leg["mode"],
+ "trip_id": leg["trip_id"],
+ "route_pattern_id": leg.get("route_pattern_id"),
+ "route_pattern_source": leg.get("route_pattern_source"),
+ "route_pattern_status": leg.get("route_pattern_status"),
+ "geometry_source": leg["geometry_source"],
+ },
+ }
+ )
+ features.extend(_journey_stop_features(legs))  # stop features follow the leg line features
+ return {
+ "transfers": max(0, len(transit_legs) - 1),  # one less than the number of non-walk legs, never negative
+ "departure_seconds": departure,
+ "arrival_seconds": arrival,
+ "departure_time": format_gtfs_time(departure),
+ "arrival_time": format_gtfs_time(arrival),
+ "departure_time_label": format_gtfs_time_label(departure),
+ "arrival_time_label": format_gtfs_time_label(arrival),
+ "duration_seconds": duration_seconds,
+ "duration_minutes": duration_minutes_ceil(duration_seconds),
+ "duration_label": format_duration_label(duration_seconds),
+ "legs": [_leg_public_payload(leg) for leg in legs],  # published legs omit geometry and the raw second counters
+ "features": feature_collection(features),
+ }
+
+
+def _leg_public_payload(leg: dict) -> dict:  # Copy of ``leg`` without its three internal-only keys.
+    return {name: leg[name] for name in leg if name not in ("geometry", "departure_seconds", "arrival_seconds")}
+
+
+def _walk_leg_payload(db: Session, backlink: _RouterWalkBacklink, dataset_id: int, *, route_geometry: bool = True) -> dict:  # Build the leg dict for the walking transfer described by `backlink`.
+ geometry = None
+ geometry_source = "walking_transfer"  # label kept unless a routed geometry replaces it below
+ distance_m = round(float(backlink.distance_m or 0), 1)  # a missing backlink distance counts as 0
+ duration_seconds = max(0, int(backlink.arrival_seconds) - int(backlink.departure_seconds))  # backlink arrival minus departure, floored at 0
+ arrival_seconds = backlink.arrival_seconds
+ if (
+ backlink.from_stop.lon is not None
+ and backlink.from_stop.lat is not None
+ and backlink.to_stop.lon is not None
+ and backlink.to_stop.lat is not None
+ ):
+ if route_geometry:
+ routed_geometry, routed_distance, routed_duration_seconds = _walk_geometry_from_routing(db, backlink.from_stop, backlink.to_stop)
+ else:
+ routed_geometry, routed_distance, routed_duration_seconds = None, 0.0, None  # route_geometry=False skips the routing lookup
+ if routed_geometry is not None:
+ geometry = routed_geometry
+ geometry_source = "routing_layer:walk"
+ distance_m = routed_distance  # routed distance is used as returned, without rounding
+ if routed_duration_seconds is not None:
+ duration_seconds = max(0, int(math.ceil(routed_duration_seconds)))  # routed duration rounded up to whole seconds
+ arrival_seconds = backlink.departure_seconds + duration_seconds  # arrival re-derived from the routed duration
+ if geometry is None:
+ geometry = {  # straight line between the two stops
+ "type": "LineString",
+ "coordinates": [
+ [backlink.from_stop.lon, backlink.from_stop.lat],
+ [backlink.to_stop.lon, backlink.to_stop.lat],
+ ],
+ }
+ return {
+ "dataset_id": dataset_id,
+ "source_id": None,
+ "source_name": None,
+ "route_db_id": None,
+ "route_id": "walk",
+ "route_ref": "Walk",
+ "route_name": "Walking transfer",
+ "mode": "walk",  # _journey_payload uses this value to keep walk legs out of the transfer count
+ "operator": None,
+ "trip_id": None,
+ "route_pattern_id": None,
+ "route_pattern_source": None,
+ "route_pattern_status": None,
+ "from": _stop_payload(backlink.from_stop),
+ "to": _stop_payload(backlink.to_stop),
+ "departure_seconds": backlink.departure_seconds,
+ "arrival_seconds": arrival_seconds,
+ "departure_time": format_gtfs_time(backlink.departure_seconds),
+ "arrival_time": format_gtfs_time(arrival_seconds),
+ "departure_time_label": format_gtfs_time_label(backlink.departure_seconds),
+ "arrival_time_label": format_gtfs_time_label(arrival_seconds),
+ "distance_m": distance_m,
+ "duration_seconds": duration_seconds,
+ "geometry": geometry,
+ "geometry_source": geometry_source,
+ "stops": [  # a walk leg always lists exactly its two ends, numbered 1 and 2
+ _canonical_walk_stop_payload(backlink.from_stop, 1),
+ _canonical_walk_stop_payload(backlink.to_stop, 2),
+ ],
+ }
+
+
+def _walk_geometry_from_routing(db: Session, from_stop: StopSummary, to_stop: StopSummary) -> tuple[dict | None, float, float | None]:
+    """Route a walk between two stops; returns ``(geometry, distance_m, duration_seconds)``.
+
+    Gives ``(None, 0.0, None)`` when a coordinate is missing, routing raises, or no usable line
+    comes back. Successful results are cached under the coordinates rounded to six decimals.
+    """
+    if from_stop.lon is None or from_stop.lat is None or to_stop.lon is None or to_stop.lat is None:
+        return None, 0.0, None
+    cache_key = (
+        round(float(from_stop.lon), 6),
+        round(float(from_stop.lat), 6),
+        round(float(to_stop.lon), 6),
+        round(float(to_stop.lat), 6),
+    )
+    cached = _walk_geometry_cache_get(cache_key)
+    if cached is not None:
+        return cached
+    try:
+        route = route_between_points(
+            db,
+            from_lon=float(from_stop.lon),
+            from_lat=float(from_stop.lat),
+            to_lon=float(to_stop.lon),
+            to_lat=float(to_stop.lat),
+            mode="walk",
+            max_visited=5_000,
+        )
+    except Exception:  # noqa: BLE001 - routing graph may be unavailable during import
+        return None, 0.0, None
+    features = (route.get("features") or {}).get("features") if isinstance(route, dict) else None
+    if not isinstance(features, list):
+        return None, 0.0, None
+    coordinates = [  # coordinate arrays of every LineString feature that has at least two points
+        geometry["coordinates"]
+        for geometry in (feature.get("geometry") for feature in features if isinstance(feature, dict))
+        if isinstance(geometry, dict) and geometry.get("type") == "LineString" and len(geometry.get("coordinates") or []) >= 2
+    ]
+    if not coordinates:
+        return None, 0.0, None
+    # No/zero duration means "unknown": report None so the caller keeps its own timing, not a 0 s walk.
+    raw_duration = route.get("duration_seconds")
+    duration_seconds = float(raw_duration) if raw_duration else None
+    geometry_type = "LineString" if len(coordinates) == 1 else "MultiLineString"
+    geometry = {"type": geometry_type, "coordinates": coordinates[0] if len(coordinates) == 1 else coordinates}
+    result = (geometry, float(route.get("distance_m") or 0), duration_seconds)
+    _walk_geometry_cache_put(cache_key, result)
+    return _copy_walk_geometry_cache_value(result)
+
+
+def _walk_geometry_cache_get(key: tuple[float, float, float, float]) -> tuple[dict | None, float, float | None] | None:  # Return a copy of the unexpired cached value for `key`, else None.
+ now = time.monotonic()
+ with _walk_geometry_cache_lock:
+ cached = _walk_geometry_cache.get(key)
+ if cached is None:
+ return None
+ expires_at, value = cached  # entries are (monotonic expiry time, value) pairs
+ if expires_at <= now:
+ _walk_geometry_cache.pop(key, None)  # an expired entry is removed when it is read
+ return None
+ return _copy_walk_geometry_cache_value(value)  # the geometry is copied, so callers never receive the cached dict itself
+
+
+def _walk_geometry_cache_put(key: tuple[float, float, float, float], value: tuple[dict | None, float, float | None]) -> None:  # Store a copy of `value` under `key` with a TTL, then trim the cache to its size limit.
+ with _walk_geometry_cache_lock:
+ _walk_geometry_cache[key] = (time.monotonic() + WALK_GEOMETRY_CACHE_TTL_SECONDS, _copy_walk_geometry_cache_value(value))
+ if len(_walk_geometry_cache) <= WALK_GEOMETRY_CACHE_MAX_ENTRIES:
+ return
+ oldest = sorted(_walk_geometry_cache.items(), key=lambda item: item[1][0])[  # entries ordered by expiry time, soonest first
+ : len(_walk_geometry_cache) - WALK_GEOMETRY_CACHE_MAX_ENTRIES  # slice length equals the surplus over the limit
+ ]
+ for old_key, _ in oldest:
+ _walk_geometry_cache.pop(old_key, None)
+
+
+def _copy_walk_geometry_cache_value(value: tuple[dict | None, float, float | None]) -> tuple[dict | None, float, float | None]:
+    """Return the 3-tuple with its geometry deep-copied through a JSON round trip (None stays None)."""
+    shape_dict, metres, seconds = value
+    return (json.loads(json.dumps(shape_dict)) if shape_dict is not None else None), metres, seconds
+
+
+def _canonical_walk_stop_payload(stop: StopSummary, sequence: int) -> dict:
+    """Stop payload for one end of a walk leg, marked as an address or a canonical stop."""
+    payload = _stop_payload(stop)
+    external = is_location_token(stop.stop_id)
+    if external:
+        visual_source, canonical = "address", None
+    else:
+        visual_source, canonical = "canonical_stop", {"id": stop.id, "name": stop.name}
+    payload.update(stop_sequence=sequence, visual_source=visual_source, visual_lon=stop.lon, visual_lat=stop.lat, osm=None, canonical_stop=canonical)
+    return payload
+
+
+def _leg_geometry(  # Choose a geometry for one leg; returns (geometry or None, source label, route pattern used or None).
+ db: Session,
+ linked_route_pattern: RoutePattern | None,
+ route: GtfsRoute,
+ trip: GtfsTrip,
+ from_stop: StopSummary,
+ to_stop: StopSummary,
+ fallback_stops: list[dict],  # stop payloads used for stitching and for the stop-sequence fallback
+) -> tuple[dict | None, str, RoutePattern | None]:
+ cache_key = _leg_geometry_cache_key(route, trip, linked_route_pattern, from_stop, to_stop)
+ cached = _leg_geometry_cache_get(db, cache_key)
+ if cached is not None:
+ return cached
+
+ route_layer_candidates: list[tuple[str, str | None, RoutePattern | None]] = []  # (source label, GeoJSON text, pattern)
+ gtfs_shape_candidates: list[tuple[str, str | None, RoutePattern | None]] = []
+ legacy_candidates: list[tuple[str, str | None, RoutePattern | None]] = []
+ if linked_route_pattern is not None:
+ route_layer_candidates.append((f"route_layer:{linked_route_pattern.source_kind}", linked_route_pattern.geometry_geojson, linked_route_pattern))
+ route_layer_candidates.extend(_alternate_route_pattern_geometry_candidates(db, route, linked_route_pattern))
+ if trip.shape_id:
+ shape_row = db.scalar(
+ select(GtfsShape).where(
+ GtfsShape.dataset_id == trip.dataset_id,
+ GtfsShape.shape_id == trip.shape_id,
+ )
+ )
+ if shape_row is not None:
+ gtfs_shape_candidates.append(("gtfs_shape", shape_row.geometry_geojson, None))
+ legacy_candidates.append(("legacy_gtfs_route", route.geometry_geojson, None))  # the route row's own geometry
+
+ full_geometry_candidates = [*route_layer_candidates, *gtfs_shape_candidates]
+ usable_route_layer_candidates = _usable_geometry_candidates(route_layer_candidates)
+ for geometry_source, geometry_text, candidate_pattern in _usable_geometry_candidates(full_geometry_candidates):  # 1st: route-layer or GTFS-shape geometry that clips cleanly between the stops
+ geometry = _validated_leg_geometry(geometry_text, from_stop, to_stop)
+ if geometry is not None:
+ return _leg_geometry_cache_put(cache_key, geometry, geometry_source, candidate_pattern)
+
+ stop_coords = _stop_sequence_coords(fallback_stops, from_stop, to_stop)
+ for geometry_source, geometry_text, candidate_pattern in usable_route_layer_candidates:  # 2nd: route-layer geometry stitched together with stop coordinates
+ stitched = _stitched_partial_geometry(geometry_text, stop_coords)
+ if stitched is not None:
+ return _leg_geometry_cache_put(cache_key, stitched, f"{geometry_source}:stitched", candidate_pattern)
+
+ for geometry_source, geometry_text, candidate_pattern in _usable_geometry_candidates(legacy_candidates):  # 3rd: the route row's geometry, clipped between the stops
+ geometry = _validated_leg_geometry(geometry_text, from_stop, to_stop)
+ if geometry is not None:
+ return _leg_geometry_cache_put(cache_key, geometry, geometry_source, candidate_pattern)
+
+ for geometry_source, geometry_text, candidate_pattern in _usable_geometry_candidates(gtfs_shape_candidates):  # 4th: GTFS shape stitched together with stop coordinates
+ stitched = _stitched_partial_geometry(geometry_text, stop_coords)
+ if stitched is not None:
+ return _leg_geometry_cache_put(cache_key, stitched, f"{geometry_source}:stitched", candidate_pattern)
+
+ fallback_geometry, fallback_source = _stop_sequence_fallback_geometry(stop_coords)  # last resort: a line through the stop coordinates
+ if fallback_geometry is not None:
+ return _leg_geometry_cache_put(cache_key, fallback_geometry, fallback_source, None)
+ return _leg_geometry_cache_put(cache_key, None, "none", None)  # a miss is cached as well
+
+
+def _leg_geometry_cache_key(  # Build the tuple used as the key of the leg-geometry cache.
+ route: GtfsRoute,
+ trip: GtfsTrip,
+ linked_route_pattern: RoutePattern | None,
+ from_stop: StopSummary,
+ to_stop: StopSummary,
+) -> tuple[object, ...]:
+ return (
+ route.dataset_id,
+ route.route_id,
+ route.id,
+ _geometry_text_fingerprint(route.geometry_geojson),  # changes to the route geometry text change the key
+ trip.shape_id or "",  # only the shape id is keyed, not the shape's geometry text
+ None if linked_route_pattern is None else linked_route_pattern.id,
+ _geometry_text_fingerprint(None if linked_route_pattern is None else linked_route_pattern.geometry_geojson),
+ from_stop.id,
+ from_stop.stop_id,
+ to_stop.id,
+ to_stop.stop_id,
+ )
+
+
+def _geometry_text_fingerprint(value: str | None) -> tuple[int, str, str]:
+    """Cheap identity for a geometry string: (length, first 96 characters, last 96 characters)."""
+    rendered = str(value) if value else ""
+    # A None/empty geometry collapses to (0, "", ""), which is also what slicing "" yields.
+    return (len(rendered), rendered[:96], rendered[-96:])
+
+
+def _leg_geometry_cache_get(  # Return the unexpired cached (geometry, source label, pattern) for `cache_key`, else None.
+ db: Session,  # used to re-read the RoutePattern row for a cached pattern id
+ cache_key: tuple[object, ...],
+) -> tuple[dict | None, str, RoutePattern | None] | None:
+ now = time.monotonic()
+ with _leg_geometry_cache_lock:
+ cached = _leg_geometry_cache.get(cache_key)
+ if cached is None:
+ return None
+ expires_at, geometry, geometry_source, route_pattern_id = cached  # entries are (expiry, geometry, source label, pattern id)
+ if expires_at <= now:
+ _leg_geometry_cache.pop(cache_key, None)  # an expired entry is removed when it is read
+ return None
+ pattern = db.get(RoutePattern, route_pattern_id) if route_pattern_id is not None else None  # only the id is cached; the row comes from the caller's session
+ return json.loads(json.dumps(geometry)) if geometry is not None else None, geometry_source, pattern  # geometry is copied via a JSON round trip
+
+
+def _leg_geometry_cache_put(  # Cache the result under `cache_key` and hand the same three values back to the caller.
+ cache_key: tuple[object, ...],
+ geometry: dict | None,
+ geometry_source: str,
+ route_pattern: RoutePattern | None,
+) -> tuple[dict | None, str, RoutePattern | None]:
+ stored_geometry = json.loads(json.dumps(geometry)) if geometry is not None else None  # the cache keeps its own copy of the geometry
+ with _leg_geometry_cache_lock:
+ _leg_geometry_cache[cache_key] = (
+ time.monotonic() + LEG_GEOMETRY_CACHE_TTL_SECONDS,
+ stored_geometry,
+ geometry_source,
+ None if route_pattern is None else int(route_pattern.id),  # the pattern is stored by id, not as an object
+ )
+ if len(_leg_geometry_cache) > LEG_GEOMETRY_CACHE_MAX_ENTRIES:
+ oldest_keys = sorted(
+ _leg_geometry_cache,
+ key=lambda key: _leg_geometry_cache[key][0],  # order by expiry time, soonest first
+ )[: len(_leg_geometry_cache) - LEG_GEOMETRY_CACHE_MAX_ENTRIES]
+ for oldest_key in oldest_keys:
+ _leg_geometry_cache.pop(oldest_key, None)
+ return geometry, geometry_source, route_pattern  # the caller's own objects are returned, not the stored copy
+
+
+def _usable_geometry_candidates(
+    candidates: list[tuple[str, str | None, RoutePattern | None]]
+) -> list[tuple[str, str, RoutePattern | None]]:
+    """Keep candidates that carry geometry text, dropping later repeats of an identical text."""
+    # A dict keyed by the geometry text de-duplicates and remembers first-seen order in one step.
+    kept_by_text: dict[str, tuple[str, str, RoutePattern | None]] = {}
+    for source_label, text, pattern in candidates:
+        # Empty/None text is skipped; for repeated text the first candidate wins.
+        if text and text not in kept_by_text:
+            kept_by_text[text] = (source_label, text, pattern)
+    return list(kept_by_text.values())
+
+
+def _alternate_route_pattern_geometry_candidates(  # List geometry candidates from other RoutePattern rows whose route_ref matches this route.
+ db: Session,
+ route: GtfsRoute,
+ linked_route_pattern: RoutePattern | None,  # when given, this pattern is excluded from the result
+) -> list[tuple[str, str | None, RoutePattern | None]]:
+ route_refs = [value for value in [route.short_name, route.route_id] if value]  # non-empty values of short_name and route_id
+ if not route_refs:
+ return []
+ stmt = (
+ select(RoutePattern)
+ .where(RoutePattern.route_ref.in_(route_refs))
+ .order_by(
+ case((RoutePattern.source_kind == "osm", 0), else_=1),  # patterns with source_kind "osm" sort first
+ RoutePattern.confidence.desc(),
+ RoutePattern.id,
+ )
+ .limit(40)
+ )
+ if route.mode:
+ stmt = stmt.where(or_(RoutePattern.mode == route.mode, RoutePattern.mode.is_(None)))  # same mode, or a pattern with no mode set
+ if linked_route_pattern is not None:
+ stmt = stmt.where(RoutePattern.id != linked_route_pattern.id)
+ return [
+ (f"route_layer:{pattern.source_kind}:alternate", pattern.geometry_geojson, pattern)
+ for pattern in db.scalars(stmt).all()
+ ]
+
+
+def _validated_leg_geometry(geometry_text: str, from_stop: StopSummary, to_stop: StopSummary) -> dict | None:
+    """Parse ``geometry_text`` and clip it to the span between the two stops; None when unusable."""
+    try:
+        # Parse inside the guard: malformed stored GeoJSON now yields None instead of raising.
+        full_geometry = json.loads(geometry_text)
+        if from_stop.lon is None or from_stop.lat is None or to_stop.lon is None or to_stop.lat is None:
+            return full_geometry  # nothing to clip against, so the whole geometry is returned
+        segment = _segment_between_stops(shape(full_geometry), from_stop, to_stop)
+        return None if segment is None or segment.is_empty or segment.length == 0 else mapping(segment)
+    except Exception:  # noqa: BLE001 - route geometry clipping should not break journey search
+        return None
+
+
+def _stop_sequence_fallback_geometry(
+    coords: list[tuple[float, float]],
+) -> tuple[dict | None, str]:
+    """Line through ``coords`` plus its source label; ``(None, "none")`` for fewer than two points."""
+    if len(coords) >= 2:
+        return mapping(LineString(coords)), ("stop_straight_line_fallback" if len(coords) == 2 else "stop_sequence_fallback")
+    return None, "none"
+
+
+def _stop_sequence_coords(  # Collect the (lon, lat) sequence of a leg from its stop payloads, padded with the from/to stops.
+ stops: list[dict],
+ from_stop: StopSummary,
+ to_stop: StopSummary,
+) -> list[tuple[float, float]]:
+ coords: list[tuple[float, float]] = []
+ for stop in stops:
+ lon = _float_or_none(stop.get("visual_lon", stop.get("lon")))  # visual_* keys win over plain lon/lat when present
+ lat = _float_or_none(stop.get("visual_lat", stop.get("lat")))
+ _append_coord(coords, lon, lat)  # skips missing values and a repeat of the previous point
+
+ if not stops:
+ _append_coord(coords, from_stop.lon, from_stop.lat)
+ _append_coord(coords, to_stop.lon, to_stop.lat)
+ else:
+ if _stop_payload_coord(stops[0]) is None:  # first payload has no usable coordinate: lead with from_stop
+ _prepend_coord(coords, from_stop.lon, from_stop.lat)
+ if _stop_payload_coord(stops[-1]) is None:  # last payload has no usable coordinate: end with to_stop
+ _append_coord(coords, to_stop.lon, to_stop.lat)
+ if len(coords) < 2:
+ _prepend_coord(coords, from_stop.lon, from_stop.lat)
+ _append_coord(coords, to_stop.lon, to_stop.lat)
+ return coords  # may still hold fewer than two points when coordinates are missing
+
+
+def _stitched_partial_geometry(geometry_text: str, stop_coords: list[tuple[float, float]]) -> dict | None:
+    """Keep the part of ``geometry_text`` that the stops match and bridge unmatched stops with straight segments."""
+    if len(stop_coords) < 2:
+        return None
+    try:
+        geom = shape(json.loads(geometry_text))
+    except Exception:  # noqa: BLE001 - invalid geometry should not break routing
+        return None
+    line = _stitchable_line_for_geometry(geom, stop_coords)
+    if line is None or line.length == 0:
+        return None
+    matches = _stop_projection_matches(line, stop_coords)
+    if not matches:
+        return None
+
+    # The first/last matches are read inside _partial_line_measure_range; no local copies are needed here.
+    start_stop_index, start_measure, end_stop_index, end_measure = _partial_line_measure_range(line, stop_coords, matches)
+    if start_stop_index is None or end_stop_index is None or start_measure is None or end_measure is None:
+        return None
+    if abs(end_measure - start_measure) <= 1e-12:
+        return None
+
+    route_segment = substring(line, min(start_measure, end_measure), max(start_measure, end_measure))
+    if route_segment.is_empty or route_segment.length == 0 or not isinstance(route_segment, LineString):
+        return None
+    if start_measure > end_measure:  # the leg runs against the line's direction, so flip the cut piece
+        route_segment = LineString(list(route_segment.coords)[::-1])
+
+    coords: list[tuple[float, float]] = []
+    for coord in stop_coords[:start_stop_index]:  # stops before the matched part
+        _append_coord(coords, coord[0], coord[1])
+    for coord in route_segment.coords:  # the matched part of the line
+        _append_coord(coords, float(coord[0]), float(coord[1]))
+    for coord in stop_coords[end_stop_index + 1 :]:  # stops after the matched part
+        _append_coord(coords, coord[0], coord[1])
+
+    if len(coords) < 2:
+        return None
+    if len(coords) == len(stop_coords) and all(_coords_equal(left, right) for left, right in zip(coords, stop_coords)):
+        return None  # identical to the bare stop sequence, so stitching added nothing
+    return mapping(LineString(coords))
+
+
+def _stitchable_line_for_geometry(geom, stop_coords: list[tuple[float, float]]) -> LineString | None:  # Reduce `geom` to one LineString suitable for stitching, or None.
+ if isinstance(geom, LineString):
+ return geom
+ if not isinstance(geom, MultiLineString):
+ return None  # only LineString and MultiLineString inputs are handled
+ merged = linemerge(geom)
+ if isinstance(merged, LineString):
+ return merged
+ if not isinstance(merged, MultiLineString):
+ return None
+ stop_points = [Point(coord) for coord in stop_coords]
+
+ def score(line: LineString) -> tuple[int, float, float]:  # ranks a part: more near stops, then smaller total distance, then longer
+ distances = [line.distance(point) for point in stop_points]
+ near_count = sum(distance <= LEG_GEOMETRY_MAX_STOP_DISTANCE_DEG for distance in distances)
+ return (near_count, -sum(distances), line.length)
+
+ best = max(merged.geoms, key=score, default=None)
+ if best is None or score(best)[0] == 0:  # reject when even the best part has no stop within the threshold
+ return None
+ return best
+
+
+def _stop_projection_matches(line: LineString, stop_coords: list[tuple[float, float]]) -> list[tuple[int, float, float]]:
+    """List ``(stop index, measure along line, distance)`` for stops within LEG_GEOMETRY_MAX_STOP_DISTANCE_DEG of ``line``."""
+    located = ((position, Point(coord)) for position, coord in enumerate(stop_coords))
+    measured = ((position, point, line.distance(point)) for position, point in located)
+    return [
+        (position, line.project(point), gap)  # projection is only computed for stops that pass the filter
+        for position, point, gap in measured if gap <= LEG_GEOMETRY_MAX_STOP_DISTANCE_DEG
+    ]
+
+
+def _partial_line_measure_range(  # Work out (start stop index, start measure, end stop index, end measure) for the part of `line` to keep.
+ line: LineString,
+ stop_coords: list[tuple[float, float]],
+ matches: list[tuple[int, float, float]],  # (stop index, measure, distance) tuples; indexed at [0] and [-1], so it must not be empty
+) -> tuple[int | None, float | None, int | None, float | None]:
+ first_match = matches[0]
+ last_match = matches[-1]
+ direction = _projection_direction(matches)  # 1 or -1 from first vs. last measure, None when it cannot be determined
+
+ start_index = first_match[0]
+ end_index = last_match[0]
+ start_measure = first_match[1]
+ end_measure = last_match[1]
+
+ if direction is None:  # no direction available: work from the first match alone
+ only_index, only_measure, _ = first_match
+ if only_index < len(stop_coords) - 1:  # a later stop exists: extend toward it
+ endpoint = _line_endpoint_toward(line, only_measure, stop_coords[only_index + 1])
+ if endpoint is None:
+ return None, None, None, None
+ start_index = only_index
+ end_index = only_index
+ start_measure = only_measure
+ end_measure = endpoint
+ elif only_index > 0:  # otherwise extend toward the previous stop
+ endpoint = _line_endpoint_toward(line, only_measure, stop_coords[only_index - 1])
+ if endpoint is None:
+ return None, None, None, None
+ start_index = only_index
+ end_index = only_index
+ start_measure = endpoint
+ end_measure = only_measure
+ else:
+ return None, None, None, None
+ elif direction > 0:  # measures grow along the stop order
+ if start_index > 0:
+ endpoint = _line_endpoint_toward(line, start_measure, stop_coords[start_index - 1], preferred="before")
+ if endpoint is not None:
+ start_measure = endpoint  # widen the start to the line end when that end is closer to the previous stop
+ if end_index < len(stop_coords) - 1:
+ endpoint = _line_endpoint_toward(line, end_measure, stop_coords[end_index + 1], preferred="after")
+ if endpoint is not None:
+ end_measure = endpoint
+ else:  # measures shrink along the stop order, so the preferred line ends are swapped
+ if end_index < len(stop_coords) - 1:
+ endpoint = _line_endpoint_toward(line, end_measure, stop_coords[end_index + 1], preferred="before")
+ if endpoint is not None:
+ end_measure = endpoint
+ if start_index > 0:
+ endpoint = _line_endpoint_toward(line, start_measure, stop_coords[start_index - 1], preferred="after")
+ if endpoint is not None:
+ start_measure = endpoint
+
+ return start_index, start_measure, end_index, end_measure
+
+
+def _projection_direction(matches: list[tuple[int, float, float]]) -> int | None:
+    """1 when the last match lies further along the line than the first, -1 when earlier, else None."""
+    if len(matches) >= 2:
+        head, tail = matches[0][1], matches[-1][1]
+        if not abs(tail - head) <= 1e-12:
+            return 1 if tail > head else -1
+    # Fewer than two matches, or both ends share (numerically) the same measure.
+    return None
+
+
+def _line_endpoint_toward(  # Measure of a line end (0.0 or line.length) that is closer to `target_coord` than the point at `from_measure`, else None.
+ line: LineString,
+ from_measure: float,
+ target_coord: tuple[float, float],
+ preferred: str | None = None,  # "before" considers only measure 0.0, "after" only line.length, None both
+) -> float | None:
+ target = Point(target_coord)
+ candidates = []
+ if preferred in {None, "before"} and from_measure > 1e-12:  # the start is a candidate only when not already there
+ candidates.append(0.0)
+ if preferred in {None, "after"} and from_measure < line.length - 1e-12:  # likewise for the far end
+ candidates.append(float(line.length))
+ if not candidates:
+ return None
+ projected_point = line.interpolate(from_measure)
+ projected_distance = projected_point.distance(target)
+ endpoint = min(candidates, key=lambda measure: line.interpolate(measure).distance(target))  # the candidate end nearest the target
+ if line.interpolate(endpoint).distance(target) >= projected_distance:  # reject an end that is no closer than the starting point
+ return None
+ return endpoint
+
+
+def _coords_equal(left: tuple[float, float], right: tuple[float, float]) -> bool:  # True when both components differ by less than 1e-12.
+ return abs(left[0] - right[0]) < 1e-12 and abs(left[1] - right[1]) < 1e-12
+
+
+def _append_coord(coords: list[tuple[float, float]], lon: float | None, lat: float | None) -> None:
+    """Append ``(lon, lat)`` as floats unless either is None or it repeats the current last point."""
+    if lon is not None and lat is not None:
+        x, y = float(lon), float(lat)
+        tail = coords[-1] if coords else None
+        if tail is None or not (abs(tail[0] - x) < 1e-12 and abs(tail[1] - y) < 1e-12):
+            coords.append((x, y))
+
+
+def _prepend_coord(coords: list[tuple[float, float]], lon: float | None, lat: float | None) -> None:
+    """Insert ``(lon, lat)`` as floats at the front unless either is None or it repeats the current first point."""
+    if lon is not None and lat is not None:
+        x, y = float(lon), float(lat)
+        head = coords[0] if coords else None
+        if head is None or not (abs(head[0] - x) < 1e-12 and abs(head[1] - y) < 1e-12):
+            coords.insert(0, (x, y))
+
+
+def _stop_payload_coord(stop: dict) -> tuple[float, float] | None:
+    """``(lon, lat)`` of a stop payload, preferring the visual_* keys; None when either is not float-like."""
+    lon_raw = stop.get("visual_lon", stop.get("lon"))
+    lat_raw = stop.get("visual_lat", stop.get("lat"))
+    pair = (_float_or_none(lon_raw), _float_or_none(lat_raw))
+    return None if None in pair else pair
+
+
+def _float_or_none(value) -> float | None:
+    """Coerce ``value`` to float; None for None and for anything ``float()`` rejects."""
+    try:
+        # None is passed through as None, exactly like a value that cannot be converted.
+        return None if value is None else float(value)
+    except (TypeError, ValueError):
+        return None
+
+
+def _segment_between_stops(geom, from_stop: StopSummary, to_stop: StopSummary) -> LineString | None:  # Cut the part of `geom` between the two stops, or None when it cannot be done.
+ start_point = Point(from_stop.lon, from_stop.lat)
+ end_point = Point(to_stop.lon, to_stop.lat)
+ if geom.distance(start_point) > LEG_GEOMETRY_MAX_STOP_DISTANCE_DEG:  # both stops must lie within the threshold of the geometry
+ return None
+ if geom.distance(end_point) > LEG_GEOMETRY_MAX_STOP_DISTANCE_DEG:
+ return None
+ if isinstance(geom, LineString):
+ return _substring_for_points(geom, start_point, end_point)
+ if isinstance(geom, MultiLineString):
+ merged = linemerge(geom)
+ if isinstance(merged, LineString):
+ return _substring_for_points(merged, start_point, end_point)
+ if isinstance(merged, MultiLineString):  # parts that do not merge into one line: try a path across the vertex graph first
+ path = _network_path_for_points(merged, start_point, end_point)
+ if path is not None:
+ return path
+ line = _best_line_for_points(merged, start_point, end_point)  # otherwise fall back to the single part closest to both stops
+ if line is not None:
+ return _substring_for_points(line, start_point, end_point)
+ return None  # any other geometry type is not handled
+
+
+def _substring_for_points(line: LineString, start_point: Point, end_point: Point) -> LineString | None:  # Part of `line` between the projections of the two points, oriented start -> end.
+ if line.length == 0:
+ return None
+ start = line.project(start_point)
+ end = line.project(end_point)
+ if abs(start - end) <= 1e-12:  # both points project to the same measure, so there is nothing to cut
+ return None
+ segment = substring(line, min(start, end), max(start, end))
+ if segment.is_empty or segment.length == 0:
+ return None
+ if start > end and isinstance(segment, LineString):
+ segment = LineString(list(segment.coords)[::-1])  # reverse so the result runs from start_point toward end_point
+ return segment if isinstance(segment, LineString) else None
+
+
+def _network_path_for_points(geom: MultiLineString, start_point: Point, end_point: Point) -> LineString | None:  # Shortest path over the vertices of `geom` between the two points, as a LineString.
+ nodes: dict[tuple[float, float], tuple[float, float]] = {}  # rounded key -> first coordinate seen for that key
+ graph: dict[tuple[float, float], list[tuple[tuple[float, float], float]]] = {}  # rounded key -> [(neighbor key, edge length)]
+
+ def key(coord) -> tuple[float, float]:
+ return (round(float(coord[0]), 6), round(float(coord[1]), 6))  # vertices equal after rounding to 6 decimals share a node
+
+ def add_node(coord) -> tuple[float, float]:
+ node = key(coord)
+ nodes.setdefault(node, (float(coord[0]), float(coord[1])))
+ graph.setdefault(node, [])
+ return node
+
+ for line in geom.geoms:
+ coords = list(line.coords)
+ for left, right in zip(coords, coords[1:]):  # one edge per consecutive vertex pair
+ left_key = add_node(left)
+ right_key = add_node(right)
+ weight = Point(nodes[left_key]).distance(Point(nodes[right_key]))
+ if weight == 0:
+ continue
+ graph[left_key].append((right_key, weight))  # edges are added in both directions
+ graph[right_key].append((left_key, weight))
+ if not nodes:
+ return None
+
+ start_key = _nearest_graph_node(nodes, start_point)
+ end_key = _nearest_graph_node(nodes, end_point)
+ if start_key is None or end_key is None:
+ return None
+ path_keys = _shortest_path(graph, start_key, end_key)
+ if not path_keys:
+ return None
+ coords = [(start_point.x, start_point.y)]  # the path starts and ends at the given points themselves
+ coords.extend(nodes[node] for node in path_keys)
+ coords.append((end_point.x, end_point.y))
+ deduped = []
+ for coord in coords:
+ if not deduped or Point(deduped[-1]).distance(Point(coord)) > 1e-10:  # drop consecutive points that coincide
+ deduped.append(coord)
+ if len(deduped) < 2:
+ return None
+ return LineString(deduped)
+
+
+def _nearest_graph_node(nodes: dict[tuple[float, float], tuple[float, float]], point: Point) -> tuple[float, float] | None:
+    """Key of the node whose stored coordinate lies closest to ``point``; None for an empty mapping."""
+    # ``default`` covers the empty case; on ties ``min`` keeps the first key in iteration order.
+    return min(nodes, key=lambda candidate: Point(nodes[candidate]).distance(point), default=None)
+
+
+def _shortest_path(
+    graph: dict[tuple[float, float], list[tuple[tuple[float, float], float]]],
+    start: tuple[float, float],
+    end: tuple[float, float],
+) -> list[tuple[float, float]] | None:
+    """Dijkstra over ``graph`` (node -> [(neighbor, weight)]); node list start..end, or None if unreachable."""
+    import heapq  # function-scope import keeps this block self-contained
+    distances = {start: 0.0}
+    previous: dict[tuple[float, float], tuple[float, float]] = {}
+    visited: set[tuple[float, float]] = set()
+    frontier = [(0.0, start)]  # min-heap of (tentative distance, node); replaces a linear scan per step
+    while frontier:
+        current_distance, current = heapq.heappop(frontier)
+        if current in visited:
+            continue  # stale entry: the node was already settled through a shorter route
+        if current == end:
+            break
+        visited.add(current)
+        for neighbor, weight in graph.get(current, []):
+            candidate = current_distance + weight
+            if neighbor not in visited and candidate < distances.get(neighbor, float("inf")):
+                distances[neighbor] = candidate
+                previous[neighbor] = current
+                heapq.heappush(frontier, (candidate, neighbor))
+    if end not in distances:
+        return None
+    path = [end]
+    while path[-1] != start:  # walk the predecessor links back from end to start
+        parent = previous.get(path[-1])
+        if parent is None:
+            return None
+        path.append(parent)
+    return path[::-1]
+
+
+def _best_line_for_points(geom: MultiLineString, start: Point, end: Point) -> LineString | None:  # Part of ``geom`` with the smallest summed distance to both points; None if it has no parts.
+    return min(geom.geoms, default=None, key=lambda part: part.distance(start) + part.distance(end))
+
+
+def _leg_stop_payloads(  # Build one visual stop payload per stop-time row of the trip within the given sequence range.
+ db: Session,
+ dataset_id: int,
+ trip_id: str,
+ start_sequence: int,
+ end_sequence: int,
+ stop_cache: dict[tuple[int, str], StopSummary],  # handed on to _stop_for_id
+ osm_stop_cache: dict[tuple[int, str], dict],  # handed on to _visual_stop_payload
+) -> list[dict]:
+ rows = stop_times_for_trip_range(db, dataset_id, trip_id, start_sequence, end_sequence)
+ stops = []
+ for row in rows:  # payloads keep the order in which the rows are returned
+ stop = _stop_for_id(db, dataset_id, row.stop_id, stop_cache)
+ stops.append(_visual_stop_payload(db, stop, row.stop_sequence, osm_stop_cache))
+ return stops
+
+
+def _visual_stop_payload(db: Session, stop: StopSummary, stop_sequence: int, osm_stop_cache: dict[tuple[int, str], dict]) -> dict:  # Stop payload plus the position to draw it at: canonical stop first, then nearest OSM stop, else the GTFS position.
+ payload = _stop_payload(stop)
+ payload["stop_sequence"] = stop_sequence
+ payload["visual_source"] = "gtfs"  # defaults, overwritten below when a canonical or OSM match exists
+ payload["visual_lon"] = stop.lon
+ payload["visual_lat"] = stop.lat
+ payload["osm"] = None
+ payload["canonical_stop"] = None
+ canonical = _canonical_visual_stop(db, stop)
+ if canonical is not None:
+ payload["visual_source"] = "canonical_stop"
+ payload["visual_lon"] = canonical["lon"]
+ payload["visual_lat"] = canonical["lat"]
+ payload["canonical_stop"] = {
+ "id": canonical["id"],
+ "name": canonical["name"],
+ }
+ if canonical["name"]:
+ payload["name"] = canonical["name"]  # a non-empty canonical name replaces the payload name
+ return payload
+ cache_key = (stop.dataset_id, stop.stop_id)
+ if cache_key not in osm_stop_cache:
+ osm_stop_cache[cache_key] = _nearest_osm_stop(db, stop) or {}  # a miss is cached as {} so it is not looked up again
+ osm = osm_stop_cache[cache_key]
+ if osm:
+ payload["visual_source"] = "osm"
+ payload["visual_lon"] = osm["lon"]
+ payload["visual_lat"] = osm["lat"]
+ payload["osm"] = {
+ "id": osm["id"],
+ "dataset_id": osm["dataset_id"],
+ "osm_type": osm["osm_type"],
+ "osm_id": osm["osm_id"],
+ "name": osm["name"],
+ "distance_m": osm["distance_m"],
+ }
+ return payload
+
+
+def _canonical_visual_stop(db: Session, stop: StopSummary) -> dict | None:  # id, name and position of the canonical stop linked to this GTFS stop, or None.
+ if not stop.id:
+ return None
+ link = db.scalar(
+ select(CanonicalStopLink)
+ .where(CanonicalStopLink.object_type == "gtfs_stop", CanonicalStopLink.object_id == stop.id)
+ .order_by(CanonicalStopLink.id)  # with several links, the one with the lowest id is used
+ )
+ if link is None:
+ return None
+ canonical = db.get(CanonicalStop, link.canonical_stop_id)
+ if canonical is None or canonical.lon is None or canonical.lat is None:  # a canonical stop without a position is ignored
+ return None
+ return {
+ "id": canonical.id,
+ "name": canonical.name,
+ "lon": canonical.lon,
+ "lat": canonical.lat,
+ }
+
+
+def _nearest_osm_stop(db: Session, stop: StopSummary) -> dict | None:
+    """Find the closest OSM stop/station/terminal feature to a GTFS stop.
+
+    Candidates come from a bounding box of ``OSM_STOP_MATCH_RADIUS_DEG`` around the
+    stop in every active ``osm_geojson`` dataset. Distances are measured in metres
+    with longitude offsets scaled by ``cos(latitude)`` (equirectangular approximation).
+    Returns ``None`` when the stop has no coordinates, no OSM dataset is active, or no
+    candidate lies within the radius.
+    """
+    import math  # function-scope import keeps this change local to the block
+    if stop.lon is None or stop.lat is None:
+        return None
+    dataset_rows = db.execute(select(Dataset.id).where(Dataset.is_active.is_(True), Dataset.kind == "osm_geojson")).all()
+    active_osm_dataset_ids = [row[0] for row in dataset_rows]
+    if not active_osm_dataset_ids:
+        return None
+    radius_deg = OSM_STOP_MATCH_RADIUS_DEG
+    candidates = query_osm_features(
+        db,
+        active_osm_dataset_ids,
+        kinds=["stop", "station", "terminal"],
+        bbox=(stop.lon - radius_deg, stop.lat - radius_deg, stop.lon + radius_deg, stop.lat + radius_deg),
+        limit=80,
+    )
+    metres_per_degree = 111_320
+    lon_scale = math.cos(math.radians(stop.lat))  # shrink of one degree of longitude at this latitude
+    best = None
+    for candidate in candidates:
+        if not candidate.geometry_geojson:
+            continue
+        try:
+            geom = shape(json.loads(candidate.geometry_geojson))
+        except Exception:  # noqa: BLE001 - ignore malformed feature geometry in visual stop matching
+            continue
+        if geom.is_empty:
+            continue  # an empty geometry has no position to measure against
+        representative = geom if isinstance(geom, Point) else geom.representative_point()
+        distance_m = math.hypot((representative.x - stop.lon) * lon_scale, representative.y - stop.lat) * metres_per_degree
+        if best is None or distance_m < best["distance_m"]:
+            best = {
+                "id": candidate.id, "dataset_id": candidate.dataset_id, "osm_type": candidate.osm_type, "osm_id": candidate.osm_id, "name": candidate.name,
+                "lon": representative.x, "lat": representative.y, "distance_m": distance_m,
+            }
+    if best is None or best["distance_m"] > radius_deg * metres_per_degree:
+        return None
+    best["distance_m"] = round(best["distance_m"], 1)
+    return best
+
+
+def _journey_stop_features(legs: list[dict]) -> list[dict]:
+    """Turn the stops of all journey legs into de-duplicated GeoJSON point features.
+
+    Roles are ``start``, ``end``, ``transfer`` (a leg boundary inside the journey) or
+    ``passed``; when a stop occurs repeatedly, the higher-ranked role is kept.
+    """
+    features_by_key: dict[str, dict] = {}
+    last_leg = len(legs)
+    for leg_index, leg in enumerate(legs, start=1):
+        stops = leg.get("stops") or []  # a missing or None stop list means nothing to draw
+        last_stop = len(stops) - 1
+        for stop_index, stop in enumerate(stops):
+            lon, lat = stop.get("visual_lon"), stop.get("visual_lat")
+            if lon is None or lat is None:
+                continue
+            if leg_index == 1 and stop_index == 0:
+                role = "start"
+            elif leg_index == last_leg and stop_index == last_stop:
+                role = "end"
+            elif (stop_index == last_stop and leg_index < last_leg) or (stop_index == 0 and leg_index > 1):
+                role = "transfer"
+            else:
+                role = "passed"
+            # .get() keeps stops without ids from raising KeyError; the properties below tolerate them as well.
+            key = f"{stop.get('dataset_id')}:{stop.get('stop_id')}:{round(float(lon), 6)}:{round(float(lat), 6)}"
+            current = features_by_key.get(key)
+            if current is not None and _stop_role_rank(current["properties"]["role"]) >= _stop_role_rank(role):
+                continue
+            features_by_key[key] = {
+                "type": "Feature", "geometry": {"type": "Point", "coordinates": [lon, lat]},
+                "properties": {
+                    "feature_type": "journey_stop", "role": role, "leg": leg_index,
+                    "route_ref": leg.get("route_ref"), "mode": leg.get("mode"), "stop_id": stop.get("stop_id"), "name": stop.get("name"),
+                    "visual_source": stop.get("visual_source"), "canonical_stop_id": (stop.get("canonical_stop") or {}).get("id"), "osm_id": (stop.get("osm") or {}).get("osm_id"),
+                },
+            }
+    return list(features_by_key.values())
+
+
+def _stop_role_rank(role: str) -> int:
+    return 2 if role in ("start", "end") else int(role == "transfer")  # journey end points outrank transfers; anything else ranks 0
+
+
+def _arrival_seconds(stop_time: GtfsStopTime) -> int | None:
+    return parse_gtfs_time(stop_time.arrival_time or stop_time.departure_time) if stop_time.arrival_seconds is None else stop_time.arrival_seconds  # stored seconds win; otherwise parse the arrival (or, failing that, departure) text
+
+
+def _departure_seconds(stop_time: GtfsStopTime) -> int | None:
+    return parse_gtfs_time(stop_time.departure_time or stop_time.arrival_time) if stop_time.departure_seconds is None else stop_time.departure_seconds  # stored seconds win; otherwise parse the departure (or, failing that, arrival) text
+
+
+def _stop_for_id(db: Session, dataset_id: int, stop_id: str, stop_cache: dict[tuple[int, str], StopSummary]) -> StopSummary:
+    """Return the summary for a stop, loading it at most once per ``(dataset_id, stop_id)``."""
+    cache_key = (dataset_id, stop_id)
+    if cache_key not in stop_cache:
+        # First request for this stop: resolve it and remember the result in the caller's cache.
+        stop_cache[cache_key] = _stop_summary_for_stop_id(db, dataset_id, stop_id)
+    return stop_cache[cache_key]
+
+
+def _source_payload_for_dataset_id(db: Session, dataset_id: int) -> dict | None:
+    """Return ``id``, ``name`` and ``dataset_id`` of the source joined to a dataset, or ``None`` if no row matches."""
+    stmt = select(Source.id, Source.name).join(Dataset, Dataset.source_id == Source.id).where(Dataset.id == dataset_id)
+    match = db.execute(stmt).first()
+    if match is not None:
+        # Rows unpack positionally in the order of the selected columns.
+        source_id, source_name = match
+        return {"id": source_id, "name": source_name, "dataset_id": dataset_id}
+    # The join produced no row for this dataset id.
+    return None
+
+
+def _stop_summary_for_stop_id(db: Session, dataset_id: int, stop_id: str) -> StopSummary:
+    lookup = select(GtfsStop).where(GtfsStop.dataset_id == dataset_id, GtfsStop.stop_id == stop_id)
+    found = db.scalar(lookup)
+    # Unknown ids yield a placeholder summary (id 0, the stop id as name, no coordinates) instead of raising.
+    return _stop_summary(found) if found is not None else StopSummary(id=0, dataset_id=dataset_id, stop_id=stop_id, name=stop_id, lat=None, lon=None)
+
+
+def _stop_summary(stop: GtfsStop) -> StopSummary:
+    """Project a ``GtfsStop`` row onto a ``StopSummary`` value."""
+    # Only the identifying and positional attributes are carried over.
+    values = {
+        "id": stop.id, "dataset_id": stop.dataset_id, "stop_id": stop.stop_id,
+        "name": stop.name,
+        "lat": stop.lat, "lon": stop.lon,
+    }
+    return StopSummary(**values)
+
+
+def _stop_payload(stop: StopSummary) -> dict:
+    """Serialise a stop summary into a plain dict.
+
+    The keys are ``id``, ``dataset_id``, ``stop_id``, ``name``, ``lat`` and
+    ``lon``, in that order, each read from the attribute of the same name.
+    """
+    field_names = ("id", "dataset_id", "stop_id", "name", "lat", "lon")
+    # One attribute read per key keeps the dict layout in a single place.
+    return {field_name: getattr(stop, field_name) for field_name in field_names}
+
+
+def _active_gtfs_dataset_ids(db: Session, source_ids: Optional[list[int]] = None) -> list[int]:
+    # Ids of active GTFS datasets, optionally restricted to the given sources (None or an empty list means no restriction).
+    filters = [Dataset.is_active.is_(True), Dataset.kind == "gtfs"]
+    filters += [Dataset.source_id.in_(source_ids)] if source_ids else []
+    return [dataset_id for (dataset_id,) in db.execute(select(Dataset.id).where(*filters)).all()]
+
+
+def _journey_leg_signature(leg: dict) -> str:
+    """Build a ``|``-joined signature of a leg's dataset, route, end points and times.
+
+    End points use the stop name, falling back to the stop id. A ``from``/``to``
+    entry that is missing or ``None`` contributes an empty field instead of raising.
+    """
+    origin, destination = leg.get("from") or {}, leg.get("to") or {}
+    parts = [
+        leg.get("dataset_id"), leg.get("route_id"), leg.get("route_ref"),
+        origin.get("name") or origin.get("stop_id"), destination.get("name") or destination.get("stop_id"),
+        leg.get("departure_seconds") or leg.get("departure_time"), leg.get("arrival_seconds") or leg.get("arrival_time"),
+    ]
+    return "|".join(str(part or "") for part in parts)
diff --git a/app/journey_search.py b/app/journey_search.py
new file mode 100644
index 0000000..8dee844
--- /dev/null
+++ b/app/journey_search.py
@@ -0,0 +1,717 @@
+from __future__ import annotations
+
+import json
+import hashlib
+import threading
+import time
+import uuid
+from concurrent.futures import ThreadPoolExecutor
+from dataclasses import dataclass, field
+from datetime import date, datetime, timedelta, timezone
+from typing import Any
+
+from sqlalchemy import select
+
+from app.address_search import is_location_token
+from app.db import SessionLocal
+from app.journey import find_journeys, parse_service_date, resolve_location_summary
+from app.models import JourneySearchCache
+from app.routing import direct_route_between_points, route_between_points
+
+
+MAX_PROGRESSIVE_TRANSFERS = 5  # highest transfer count tried by the staged transit search
+TRANSIT_STAGE_CACHE_TTL_SECONDS = 5 * 60  # lifetime of one cached transit stage result
+TRANSIT_STAGE_CACHE_MAX_ENTRIES = 256  # the in-memory stage cache is pruned above this size
+PROGRESSIVE_SEARCH_CACHE_TTL_SECONDS = 10 * 60  # lifetime of a cached complete search payload
+PROGRESSIVE_SEARCH_CACHE_MAX_ENTRIES = 128  # the in-memory payload cache is pruned above this size
+JOURNEY_SEARCH_CACHE_VERSION = "journey-search-v7"  # hashed into every durable cache key
+_executor = ThreadPoolExecutor(max_workers=2, thread_name_prefix="journey-search")  # runs the _run_search jobs
+_lock = threading.RLock()  # taken around accesses to the module-level dictionaries below
+_searches: dict[str, "_SearchState"] = {}  # search id -> state
+_progressive_search_inflight: dict[tuple[object, ...], str] = {}  # cache key -> id of the search currently computing it
+_transit_stage_cache: dict[tuple[object, ...], tuple[float, dict[str, Any]]] = {}  # key -> (monotonic expiry, stage result)
+_progressive_search_cache: dict[tuple[object, ...], tuple[float, dict[str, Any]]] = {}  # key -> (monotonic expiry, payload)
+
+
+@dataclass
+class _SearchState:  # mutable record of one progressive search; instances are kept in ``_searches``
+    id: str  # search id (a uuid4 hex string when created by start_journey_search)
+    request: dict[str, Any]  # copy of the request parameters
+    cache_key: tuple[object, ...] | None = None  # key for the progressive cache and in-flight de-duplication
+    status: str = "queued"  # other values assigned in this module: running, complete, cancelled, error
+    message: str = "Queued."  # human-readable progress text
+    stage: str = "queued"  # current step, e.g. "transfers_2"
+    journeys: list[dict] = field(default_factory=list)  # latest ranked transit journeys
+    routing: dict[str, Any] | None = None  # point-to-point route result, if any
+    context: dict[str, Any] = field(default_factory=dict)  # extra keys merged into the payload
+    error: str | None = None  # error text once the search has failed
+    created_at: float = field(default_factory=time.time)  # wall-clock creation time
+    updated_at: float = field(default_factory=time.time)  # wall-clock time of the last change
+    complete: bool = False  # True once finished, failed or cancelled
+    cancelled: bool = False  # set by cancel_journey_search
+
+
+def start_journey_search(request: dict[str, Any]) -> dict[str, Any]:
+ key = _progressive_cache_key(request)
+ cached = _progressive_cache_get(key)
+ search_id = uuid.uuid4().hex
+ state = _SearchState(id=search_id, request=dict(request), cache_key=key)
+ if cached is not None:
+ _apply_cached_payload(state, cached)
+ with _lock:
+ _prune_old_searches()
+ if cached is None:
+ existing_search_id = _progressive_search_inflight.get(key)
+ existing_state = None if existing_search_id is None else _searches.get(existing_search_id)
+ if existing_state is not None and not existing_state.complete and not existing_state.cancelled:
+ return _payload(existing_state)
+ _progressive_search_inflight[key] = search_id
+ _searches[search_id] = state
+ if cached is None:
+ _executor.submit(_run_search, search_id)
+ return journey_search_payload(search_id)
+
+
+def journey_search_payload(search_id: str) -> dict[str, Any]:
+    """Return the current payload of a search; raises ``KeyError`` for an unknown id."""
+    with _lock:
+        if search_id not in _searches:
+            raise KeyError(search_id)
+        return _payload(_searches[search_id])
+
+
+def cancel_journey_search(search_id: str) -> dict[str, Any]:
+ with _lock:
+ state = _searches.get(search_id)
+ if state is None:
+ raise KeyError(search_id)
+ state.cancelled = True
+ if not state.complete:
+ state.status = "cancelled"
+ state.message = "Search cancelled."
+ state.complete = True
+ state.updated_at = time.time()
+ _clear_inflight_search_locked(state)
+ return _payload(state)
+
+
+def _run_search(search_id: str) -> None:
+ with _lock:
+ state = _searches.get(search_id)
+ if state is None or state.cancelled:
+ return
+ state.status = "running"
+ state.stage = "starting"
+ state.message = "Starting search..."
+ state.updated_at = time.time()
+ request = dict(state.request)
+ try:
+ mode = str(request.get("mode") or "transit")
+ if mode in {"walk", "drive", "car"}:
+ _run_point_route_search(search_id, "drive" if mode == "car" else mode, request)
+ else:
+ _run_transit_search(search_id, request)
+ except Exception as exc: # noqa: BLE001 - report progressive-search failure to client
+ _publish_error(search_id, str(exc))
+
+
+def _run_transit_search(search_id: str, request: dict[str, Any]) -> None:
+ direct_only = bool(request.get("direct_only"))
+ limit = max(3, min(int(request.get("limit") or 5), 10))
+ transfer_seconds = max(0, min(int(request.get("transfer_seconds") or 120), 3600))
+ source_ids = _csv_ints(request.get("source_id"))
+ service_date = request.get("service_date") or None
+ stages = [0] if direct_only else list(range(0, MAX_PROGRESSIVE_TRANSFERS + 1))
+ address_search = is_location_token(request.get("from_stop_id")) or is_location_token(request.get("to_stop_id"))
+ stage_limit = limit if address_search else max(limit, 10)
+ merged: dict[str, dict] = {}
+ context: dict[str, Any] = {}
+ diagnostics: dict[str, Any] = {"stages": []}
+ best_count = 0
+ stale_stages = 0
+ for transfers in stages:
+ if _is_cancelled(search_id):
+ return
+ label = "direct" if transfers == 0 else f"up to {transfers} transfer{'s' if transfers != 1 else ''}"
+ _publish_status(search_id, "running", f"Searching {label}...", f"transfers_{transfers}")
+ stage_started_at = time.monotonic()
+ with SessionLocal() as db:
+ result = _cached_find_journeys(
+ db,
+ from_stop_id=str(request.get("from_stop_id") or ""),
+ to_stop_id=str(request.get("to_stop_id") or ""),
+ via_stop_id=request.get("via_stop_id") or None,
+ source_ids=source_ids,
+ departure=str(request.get("departure") or "08:00"),
+ service_date=service_date,
+ max_transfers=transfers,
+ transfer_seconds=transfer_seconds,
+ limit=stage_limit,
+ )
+ cache_status = str(result.pop("_cache_status", "miss"))
+ elapsed_ms = int((time.monotonic() - stage_started_at) * 1000)
+ stage_diagnostics = {
+ "transfers": transfers,
+ "cache": cache_status,
+ "elapsed_ms": elapsed_ms,
+ "journeys": len(result.get("journeys") or []),
+ }
+ result_diagnostics = result.get("diagnostics")
+ if isinstance(result_diagnostics, dict):
+ stage_diagnostics["details"] = result_diagnostics
+ diagnostics["stages"].append(stage_diagnostics)
+ context = _context_from_result(result)
+ context["diagnostics"] = diagnostics
+ before = len(merged)
+ for journey in result.get("journeys") or []:
+ merged.setdefault(_journey_key(journey), journey)
+ ranked = _select_diverse_journeys(_rank_journeys(merged.values(), str(request.get("ranking") or "recommended")), limit=limit)
+ _publish_results(
+ search_id,
+ journeys=ranked,
+ context=context,
+ status="running",
+ stage=f"transfers_{transfers}",
+ message=f"Found {len(ranked)} option{'s' if len(ranked) != 1 else ''}; still searching..." if not direct_only else "Direct search complete.",
+ )
+ if len(merged) <= before and ranked:
+ stale_stages += 1
+ else:
+ stale_stages = 0
+ best_count = max(best_count, len(ranked))
+ if ranked and stale_stages >= 2 and transfers >= 2:
+ break
+ if _major_hub_address_stage_is_complete(result_diagnostics, ranked, transfers=transfers, limit=limit):
+ break
+ complete_message = (
+ f"Search complete. Found {best_count} option{'s' if best_count != 1 else ''}."
+ if best_count
+ else "Search complete. No route found in the imported timetable."
+ )
+ _publish_complete(search_id, message=complete_message)
+ payload = journey_search_payload(search_id)
+ if payload.get("status") == "complete" and not payload.get("error"):
+ _progressive_cache_put(_progressive_cache_key(request), payload)
+
+
+def _major_hub_address_stage_is_complete(
+    diagnostics: dict[str, Any] | None,
+    ranked: list[dict],
+    *,
+    transfers: int,
+    limit: int,
+) -> bool:
+    """Report whether a stage whose diagnostics list major hubs has enough options to stop early."""
+    # Only stages with at least one transfer, some result and dict diagnostics qualify.
+    eligible = transfers >= 1 and bool(ranked) and isinstance(diagnostics, dict)
+    access = diagnostics.get("address_access") if eligible else None
+    uses_major_hubs = isinstance(access, dict) and bool(access.get("major_hubs"))
+    return uses_major_hubs and len(ranked) >= min(3, limit)
+
+
+def _run_point_route_search(search_id: str, mode: str, request: dict[str, Any]) -> None:
+ _publish_status(search_id, "running", f"Searching {mode} route...", mode)
+ with SessionLocal() as db:
+ from_location = resolve_location_summary(db, str(request.get("from_stop_id") or ""), source_ids=_csv_ints(request.get("source_id")))
+ to_location = resolve_location_summary(db, str(request.get("to_stop_id") or ""), source_ids=_csv_ints(request.get("source_id")))
+ if from_location.lon is None or from_location.lat is None:
+ raise ValueError("Selected start has no coordinates.")
+ if to_location.lon is None or to_location.lat is None:
+ raise ValueError("Selected destination has no coordinates.")
+ try:
+ route = route_between_points(
+ db,
+ from_lon=float(from_location.lon),
+ from_lat=float(from_location.lat),
+ to_lon=float(to_location.lon),
+ to_lat=float(to_location.lat),
+ mode=mode,
+ max_visited=300_000,
+ )
+ message = f"{mode.title()} route found."
+ except Exception as exc: # noqa: BLE001 - point routing should still return an approximate connector
+ route = direct_route_between_points(
+ db,
+ from_lon=float(from_location.lon),
+ from_lat=float(from_location.lat),
+ to_lon=float(to_location.lon),
+ to_lat=float(to_location.lat),
+ mode=mode,
+ reason=str(exc),
+ )
+ message = f"{mode.title()} route approximated."
+ context = {
+ "from": _stop_payload(from_location),
+ "to": _stop_payload(to_location),
+ "mode": mode,
+ }
+ _publish_routing(search_id, route, context=context, message=message)
+ _publish_complete(search_id, message=f"{mode.title()} route complete.")
+ payload = journey_search_payload(search_id)
+ if payload.get("status") == "complete" and not payload.get("error"):
+ _progressive_cache_put(_progressive_cache_key(request), payload)
+
+
+def _cached_find_journeys(
+ db,
+ *,
+ from_stop_id: str,
+ to_stop_id: str,
+ via_stop_id: object,
+ source_ids: list[int] | None,
+ departure: str,
+ service_date: object,
+ max_transfers: int,
+ transfer_seconds: int,
+ limit: int,
+) -> dict[str, Any]:
+ key = (
+ from_stop_id,
+ to_stop_id,
+ str(via_stop_id or ""),
+ tuple(sorted(int(source_id) for source_id in source_ids or [])),
+ departure,
+ str(service_date or ""),
+ int(max_transfers),
+ int(transfer_seconds),
+ int(limit),
+ )
+ now = time.monotonic()
+ with _lock:
+ cached = _transit_stage_cache.get(key)
+ if cached is not None:
+ expires_at, payload = cached
+ if expires_at > now:
+ return _with_cache_status(payload, "memory")
+ _transit_stage_cache.pop(key, None)
+ durable = _durable_cache_get("transit_stage", key)
+ if durable is not None:
+ with _lock:
+ _transit_stage_cache[key] = (now + TRANSIT_STAGE_CACHE_TTL_SECONDS, json.loads(json.dumps(durable)))
+ _prune_timed_cache(_transit_stage_cache, TRANSIT_STAGE_CACHE_MAX_ENTRIES)
+ return _with_cache_status(durable, "persistent")
+ result = find_journeys(
+ db=db,
+ from_stop_id=from_stop_id,
+ to_stop_id=to_stop_id,
+ via_stop_id=via_stop_id,
+ source_ids=source_ids,
+ departure=departure,
+ service_date=service_date,
+ max_transfers=max_transfers,
+ transfer_seconds=transfer_seconds,
+ limit=limit,
+ )
+ stored_result = json.loads(json.dumps(result))
+ with _lock:
+ _transit_stage_cache[key] = (now + TRANSIT_STAGE_CACHE_TTL_SECONDS, stored_result)
+ _prune_timed_cache(_transit_stage_cache, TRANSIT_STAGE_CACHE_MAX_ENTRIES)
+ _durable_cache_put("transit_stage", key, stored_result, ttl_seconds=TRANSIT_STAGE_CACHE_TTL_SECONDS)
+ return _with_cache_status(result, "miss")
+
+
+def _with_cache_status(payload: dict[str, Any], cache_status: str) -> dict[str, Any]:
+    # Deep copy via a JSON round trip so the result shares nothing with ``payload``, then tag the copy.
+    tagged: dict[str, Any] = json.loads(json.dumps(payload))
+    return {**tagged, "_cache_status": cache_status}
+
+
+def _prune_timed_cache(cache: dict[tuple[object, ...], tuple[float, dict[str, Any]]], max_entries: int) -> None:
+    """Evict the entries with the earliest expiry until at most ``max_entries`` remain."""
+    overflow = len(cache) - max_entries
+    if overflow > 0:
+        for stale_key in sorted(cache, key=lambda cache_key: cache[cache_key][0])[:overflow]:
+            del cache[stale_key]
+
+
+def _durable_cache_get(cache_type: str, key: tuple[object, ...]) -> dict[str, Any] | None:
+ storage_key = _durable_cache_key(cache_type, key)
+ now = datetime.now(timezone.utc)
+ try:
+ with SessionLocal() as session:
+ row = session.scalar(select(JourneySearchCache).where(JourneySearchCache.cache_key == storage_key))
+ if row is None:
+ return None
+ expires_at = _as_utc(row.expires_at)
+ if expires_at is None or expires_at <= now:
+ session.delete(row)
+ session.commit()
+ return None
+ return json.loads(row.payload_json)
+ except Exception: # noqa: BLE001 - cache misses must not break journey search
+ return None
+
+
+def _durable_cache_put(cache_type: str, key: tuple[object, ...], payload: dict[str, Any], *, ttl_seconds: int) -> None:
+ storage_key = _durable_cache_key(cache_type, key)
+ now = datetime.now(timezone.utc)
+ expires_at = now + timedelta(seconds=max(1, int(ttl_seconds)))
+ try:
+ with SessionLocal() as session:
+ row = session.scalar(select(JourneySearchCache).where(JourneySearchCache.cache_key == storage_key))
+ if row is None:
+ row = JourneySearchCache(
+ cache_key=storage_key,
+ cache_type=cache_type,
+ payload_json=json.dumps(payload, separators=(",", ":")),
+ created_at=now,
+ updated_at=now,
+ expires_at=expires_at,
+ )
+ session.add(row)
+ else:
+ row.cache_type = cache_type
+ row.payload_json = json.dumps(payload, separators=(",", ":"))
+ row.updated_at = now
+ row.expires_at = expires_at
+ session.commit()
+ except Exception: # noqa: BLE001 - cache writes are best-effort
+ return
+
+
+def _durable_cache_key(cache_type: str, key: tuple[object, ...]) -> str:
+    """Derive the storage key of a durable cache entry.
+
+    The cache version, the cache type and the JSON-safe form of ``key`` are
+    serialised canonically (sorted keys, compact separators) and hashed with
+    SHA-256; the hex digest is returned.
+    """
+    envelope = {"version": JOURNEY_SEARCH_CACHE_VERSION, "cache_type": cache_type, "key": _json_safe(key)}
+    canonical = json.dumps(envelope, sort_keys=True, separators=(",", ":"))
+    digest = hashlib.sha256(canonical.encode("utf-8"))
+    return digest.hexdigest()
+
+
+def _json_safe(value: object) -> object:
+    """Recursively convert ``value`` into something ``json.dumps`` accepts."""
+    if value is None or isinstance(value, (str, int, float, bool)):
+        return value
+    if isinstance(value, (tuple, list)):
+        return [_json_safe(element) for element in value]
+    if isinstance(value, dict):
+        # Keys are stringified and emitted in the order of their string form.
+        ordered = sorted(value.items(), key=lambda pair: str(pair[0]))
+        return {str(name): _json_safe(element) for name, element in ordered}
+    # Dates and datetimes use their ISO format; any other type falls back to str().
+    return value.isoformat() if isinstance(value, (date, datetime)) else str(value)
+
+
+def _as_utc(value: datetime | None) -> datetime | None:
+    """Normalise a datetime to UTC, treating naive values as already being in UTC."""
+    if value is None:
+        return None
+    # Naive values are only labelled as UTC; aware values are converted.
+    return value.replace(tzinfo=timezone.utc) if value.tzinfo is None else value.astimezone(timezone.utc)
+
+
+def _progressive_cache_key(request: dict[str, Any]) -> tuple[object, ...]:
+    """Build the hashable cache key identifying a complete progressive search request.
+
+    Missing or falsy parameters fall back to the defaults used below. The
+    source ids are sorted and ``limit`` is clamped to the range 3..10.
+    """
+
+    def text(name: str, default: str = "") -> str:
+        return str(request.get(name) or default)
+
+    sources = tuple(sorted(int(source_id) for source_id in _csv_ints(request.get("source_id")) or []))
+    head = (text("mode", "transit"), text("from_stop_id"), text("to_stop_id"), text("via_stop_id"), sources)
+    wait_seconds = int(request.get("transfer_seconds") or 120)
+    limit = max(3, min(int(request.get("limit") or 5), 10))
+    return (*head, text("departure", "08:00"), text("service_date"), bool(request.get("direct_only")), text("ranking", "recommended"), wait_seconds, limit)
+
+
+def _progressive_cache_get(key: tuple[object, ...]) -> dict[str, Any] | None:
+ now = time.monotonic()
+ with _lock:
+ cached = _progressive_search_cache.get(key)
+ if cached is not None:
+ expires_at, payload = cached
+ if expires_at > now:
+ copied = json.loads(json.dumps(payload))
+ copied["cache_status"] = "memory"
+ return copied
+ _progressive_search_cache.pop(key, None)
+ durable = _durable_cache_get("progressive", key)
+ if durable is None:
+ return None
+ with _lock:
+ _progressive_search_cache[key] = (now + PROGRESSIVE_SEARCH_CACHE_TTL_SECONDS, json.loads(json.dumps(durable)))
+ _prune_timed_cache(_progressive_search_cache, PROGRESSIVE_SEARCH_CACHE_MAX_ENTRIES)
+ copied = json.loads(json.dumps(durable))
+ copied["cache_status"] = "persistent"
+ return copied
+
+
+def _progressive_cache_put(key: tuple[object, ...], payload: dict[str, Any]) -> None:
+ stored_payload = json.loads(json.dumps(payload))
+ stored_payload.pop("cache_status", None)
+ with _lock:
+ _progressive_search_cache[key] = (time.monotonic() + PROGRESSIVE_SEARCH_CACHE_TTL_SECONDS, stored_payload)
+ _prune_timed_cache(_progressive_search_cache, PROGRESSIVE_SEARCH_CACHE_MAX_ENTRIES)
+ _durable_cache_put("progressive", key, stored_payload, ttl_seconds=PROGRESSIVE_SEARCH_CACHE_TTL_SECONDS)
+
+
+def _apply_cached_payload(state: _SearchState, payload: dict[str, Any]) -> None:
+    """Fill ``state`` from a cached payload and mark it complete."""
+    core_keys = {"search_id", "status", "stage", "message", "complete", "error", "journeys", "routing", "created_at", "updated_at"}
+    cached_routing = payload.get("routing")
+    state.status = str(payload.get("status") or "complete")
+    state.message = "Cached result."
+    state.stage = str(payload.get("stage") or "cached")
+    # JSON round trips give the state its own deep copies of the cached data.
+    state.journeys = json.loads(json.dumps(payload.get("journeys") or []))
+    state.routing = None if cached_routing is None else json.loads(json.dumps(cached_routing))
+    # Everything that is not a core payload field is kept as context.
+    state.context = {name: item for name, item in payload.items() if name not in core_keys}
+    state.error, state.complete = None, True
+    state.updated_at = time.time()
+
+
+def _publish_status(search_id: str, status: str, message: str, stage: str) -> None:
+    """Update status, message and stage of a search; unknown or cancelled searches are ignored."""
+    with _lock:
+        state = _searches.get(search_id)
+        if state is not None and not state.cancelled:
+            # Progress only: journeys, routing and context stay as they are.
+            state.status = status
+            state.message, state.stage = message, stage
+            state.updated_at = time.time()
+
+
+def _publish_results(search_id: str, *, journeys: list[dict], context: dict[str, Any], status: str, stage: str, message: str) -> None:
+    """Store journeys and context on a search; unknown or cancelled searches are ignored."""
+    with _lock:
+        state = _searches.get(search_id)
+        if state is not None and not state.cancelled:
+            state.status, state.stage = status, stage
+            state.message = message
+            # Shallow copies decouple the state from the caller's containers.
+            state.journeys = [*journeys]
+            state.context = {**context}
+            state.updated_at = time.time()
+
+
+def _publish_routing(search_id: str, routing: dict[str, Any], *, context: dict[str, Any], message: str) -> None:
+    """Store a route result and its context on a search; unknown or cancelled searches are ignored."""
+    with _lock:
+        state = _searches.get(search_id)
+        if state is not None and not state.cancelled:
+            state.status = "running"
+            # The stage is named after the route's mode when it reports one.
+            state.stage = str(routing.get("mode") or "route")
+            state.message, state.routing = message, routing
+            state.context = {**context}
+            state.updated_at = time.time()
+
+
+def _publish_complete(search_id: str, *, message: str) -> None:
+    """Mark a search as complete and drop its in-flight marker; unknown or cancelled searches are ignored."""
+    with _lock:
+        state = _searches.get(search_id)
+        if state is not None and not state.cancelled:
+            state.status, state.message = "complete", message
+            state.complete = True
+            state.updated_at = time.time()
+            # After this call the cache key no longer points at this search as in flight.
+            _clear_inflight_search_locked(state)
+
+
+def _publish_error(search_id: str, message: str) -> None:
+    """Mark a search as failed with ``message``; unknown or cancelled searches are left untouched."""
+    with _lock:
+        state = _searches.get(search_id)
+        # A late worker failure must not turn a cancelled search into an error (same guard as the other _publish_* helpers).
+        if state is None or state.cancelled:
+            return
+        state.status = state.stage = "error"
+        state.message = state.error = message
+        state.complete = True
+        state.updated_at = time.time()
+        _clear_inflight_search_locked(state)
+
+
+def _clear_inflight_search_locked(state: _SearchState) -> None:
+ if state.cache_key is not None and _progressive_search_inflight.get(state.cache_key) == state.id:
+ _progressive_search_inflight.pop(state.cache_key, None)
+
+
+def _is_cancelled(search_id: str) -> bool:
+    # A search that is no longer registered is reported as cancelled as well.
+    with _lock:
+        return search_id not in _searches or _searches[search_id].cancelled
+
+
+def _payload(state: _SearchState) -> dict[str, Any]:
+    """Snapshot a search state as a JSON-compatible response dict.
+
+    Journeys, routing and context are deep-copied through a JSON round trip; context
+    keys are merged last and therefore replace core fields of the same name.
+    """
+    def clone(data: Any) -> Any:
+        return json.loads(json.dumps(data))
+    core = {
+        "search_id": state.id, "status": state.status, "stage": state.stage, "message": state.message,
+        "complete": state.complete, "error": state.error, "journeys": clone(state.journeys),
+        "routing": None if state.routing is None else clone(state.routing), "created_at": state.created_at, "updated_at": state.updated_at,
+    }
+    return {**core, **clone(state.context)}
+
+
+def _context_from_result(result: dict[str, Any]) -> dict[str, Any]:
+    """Return a shallow copy of ``result`` without its ``journeys`` and ``error`` entries."""
+    context = dict(result)
+    for dropped in ("journeys", "error"):
+        context.pop(dropped, None)
+    return context
+
+
+def _journey_key(journey: dict[str, Any]) -> str:
+    """Build an identity string for a journey from the identifying fields of its legs.
+
+    Every leg contributes dataset, mode, route, trip, end points (stop id, falling
+    back to the name) and times, joined with ``|``; legs are joined with ``||``.
+    Falsy field values are rendered as empty strings.
+    """
+
+    def endpoint(leg: dict[str, Any], side: str) -> object:
+        stop = leg.get(side) or {}
+        return stop.get("stop_id") or stop.get("name")
+
+    def leg_key(leg: dict[str, Any]) -> str:
+        fields = [leg.get("dataset_id"), leg.get("mode"), leg.get("route_id"), leg.get("trip_id")]
+        fields += [endpoint(leg, "from"), endpoint(leg, "to"), leg.get("departure_time"), leg.get("arrival_time")]
+        return "|".join(str(item or "") for item in fields)
+
+    # A journey without legs yields the empty string.
+    return "||".join(leg_key(leg) for leg in journey.get("legs") or [])
+
+
+def _rank_journeys(journeys, ranking: str) -> list[dict]:
+    """Return shallow copies of ``journeys`` sorted best-first for ``ranking``.
+
+    Supported rankings:
+
+    * ``"duration"``: duration, then arrival, transfers, walking distance
+    * ``"fewest_transfers"``: transfers, then arrival, duration, walking distance
+    * ``"earliest_arrival"``: arrival, then duration, transfers, walking distance
+    * anything else (recommended): arrival plus 600 per transfer plus the walking
+      distance divided by 1.35, then arrival, transfers, walking distance
+
+    Journeys without an arrival time or duration sort after those that have one.
+    The input and its journey dicts are not modified; only the top-level dict of
+    each journey is copied.
+    """
+
+    def number(value: object) -> float:
+        # A missing value ranks last.
+        return float("inf") if value is None else float(value)
+
+    def key(journey: dict[str, Any]) -> tuple[float, float, float, float]:
+        # Arrival, transfers and walking distance are needed by every ranking.
+        arrival = number(journey.get("arrival_seconds"))
+        transfers = int(journey.get("transfers") or 0)
+        # Only walk legs contribute to the walking distance.
+        walking = sum(float(leg.get("distance_m") or 0) for leg in journey.get("legs") or [] if leg.get("mode") == "walk")
+        if ranking == "duration":
+            return (number(journey.get("duration_minutes")), arrival, transfers, walking)
+        if ranking == "fewest_transfers":
+            return (transfers, arrival, number(journey.get("duration_minutes")), walking)
+        if ranking == "earliest_arrival":
+            return (arrival, number(journey.get("duration_minutes")), transfers, walking)
+        # Recommended order: each transfer and the walking effort push a journey back.
+        walking_seconds = walking / 1.35
+        return (arrival + transfers * 600 + walking_seconds, arrival, transfers, walking)
+
+    return sorted((dict(journey) for journey in journeys), key=key)
+
+
+def _select_diverse_journeys(journeys: list[dict], *, limit: int) -> list[dict]:
+ selected: list[dict] = []
+ selected_exact: set[str] = set()
+ selected_diversity: set[tuple[object, ...]] = set()
+ for journey in journeys:
+ exact_key = _journey_key(journey)
+ if exact_key in selected_exact:
+ continue
+ diversity_key = _journey_diversity_key(journey)
+ if diversity_key in selected_diversity and len(selected) >= 3:
+ continue
+ selected.append(journey)
+ selected_exact.add(exact_key)
+ selected_diversity.add(diversity_key)
+ if len(selected) >= limit:
+ return selected
+ if len(selected) >= min(3, limit):
+ return selected
+ for journey in journeys:
+ exact_key = _journey_key(journey)
+ if exact_key in selected_exact:
+ continue
+ selected.append(journey)
+ selected_exact.add(exact_key)
+ if len(selected) >= min(3, limit):
+ break
+ return _ensure_walk_only_option(selected, journeys, limit=limit)
+
+
+def _ensure_walk_only_option(selected: list[dict], ranked: list[dict], *, limit: int) -> list[dict]:
+    """Make sure a walk-only journey is offered when the ranked candidates contain one."""
+    if any(map(_journey_is_walk_only, selected)):
+        return selected
+    walk = next(filter(_journey_is_walk_only, ranked), None)
+    if walk is not None and len(selected) < limit:
+        return selected + [walk]
+    if walk is not None and selected:
+        # No room left: the walk-only journey takes the place of the last pick.
+        selected[-1] = walk
+    return selected
+
+
+def _journey_is_walk_only(journey: dict) -> bool:
+    legs = journey.get("legs") or []
+    return len(legs) > 0 and not any(leg.get("mode") != "walk" for leg in legs)  # a journey without legs is not walk-only
+
+
+def _journey_diversity_key(journey: dict[str, Any]) -> tuple[object, ...]:
+    """Group key for result diversity: transfer count, ridden routes and 30-minute departure band."""
+    ridden = [leg for leg in journey.get("legs") or [] if leg.get("mode") != "walk"]
+    # Each non-walk leg is labelled by its route ref, falling back to route id, then mode.
+    route_signature = tuple(str(leg.get("route_ref") or leg.get("route_id") or leg.get("mode") or "") for leg in ridden)
+    departure = journey.get("departure_seconds")
+    half_hour = 30 * 60
+    time_band = int(departure) // half_hour if departure is not None else None
+    return (int(journey.get("transfers") or 0), route_signature, time_band)
+
+
+def _csv_ints(value: object) -> list[int] | None:
+    """Parse comma-separated text (or a list/tuple of values) into ints; ``None`` when nothing usable is given."""
+    if value is None:
+        return None
+    parts = value if isinstance(value, (list, tuple)) else str(value).split(",")
+    items = [text for text in (str(part).strip() for part in parts) if text]
+    return [int(item) for item in items] or None
+
+
+def _stop_payload(stop) -> dict[str, Any]:
+    """Serialise a resolved location into a plain dict.
+
+    The result has the keys ``id``, ``dataset_id``, ``stop_id``, ``name``,
+    ``lat`` and ``lon``, each read from the attribute of the same name.
+    """
+    field_names = ("id", "dataset_id", "stop_id", "name", "lat", "lon")
+    # Reading the attributes by name keeps the key order in one place.
+    return {field_name: getattr(stop, field_name) for field_name in field_names}
+
+
+def _prune_old_searches() -> None:
+    """Forget searches idle for over 15 minutes, or complete and idle for over 3 minutes."""
+    now = time.time()
+    def expired(state: _SearchState) -> bool:
+        idle = now - state.updated_at
+        return idle > 15 * 60 or (state.complete and idle > 3 * 60)
+    # Collect the ids first: the registry must not change size while it is iterated.
+    for search_id in [search_id for search_id, state in _searches.items() if expired(state)]:
+        state = _searches.pop(search_id, None)
+        if state is not None:
+            _clear_inflight_search_locked(state)
diff --git a/app/main.py b/app/main.py
new file mode 100644
index 0000000..8a1eb93
--- /dev/null
+++ b/app/main.py
@@ -0,0 +1,2653 @@
+from __future__ import annotations
+
+import json
+from contextlib import asynccontextmanager
+from datetime import datetime, timezone
+from functools import wraps
+from pathlib import Path
+from typing import Optional
+from urllib.parse import urlparse
+
+from fastapi import Depends, FastAPI, HTTPException, Request, Response
+from fastapi.responses import HTMLResponse, JSONResponse
+from fastapi.staticfiles import StaticFiles
+from fastapi.templating import Jinja2Templates
+from pydantic import BaseModel, Field
+from sqlalchemy import and_, exists, func, not_, or_, select, text
+from sqlalchemy.exc import OperationalError
+from sqlalchemy.orm import Session, joinedload
+
+from app.address_search import address_at_point, search_addresses
+from app.config import settings
+from app.data_management import dataset_row_counts, source_row_counts
+from app.dataset_search import search_datasets
+from app.db import engine, get_db, init_db
+from app.db_lock import DatabaseWriteBusy, database_write_lock, database_write_status
+from app.geofabrik import create_geofabrik_source, geofabrik_catalog
+from app.gtfs_storage import scheduled_stop_ids
+from app.harmonization import GTFS_QA_NOTE_PREFIX, gtfs_harmonization_feed_detail, gtfs_harmonization_inventory
+from app.itineraries import generate_itineraries, itinerary_payload, recent_itineraries, set_itinerary_saved, set_leg_locked
+from app.journey import find_journeys, nearest_scheduled_stops, search_scheduled_stops
+from app.journey_search import cancel_journey_search, journey_search_payload, start_journey_search
+from app.jobs import (
+ active_address_index_rebuild_job,
+ active_source_import_job,
+ active_source_workflow_jobs,
+ active_dataset_delete_jobs,
+ cancel_job,
+ create_dataset_delete_job,
+ create_address_index_rebuild_job,
+ create_maintenance_job,
+ create_osm_relabel_job,
+ create_route_matching_job,
+ create_source_delete_job,
+ create_source_import_job,
+ create_route_layer_rebuild_job,
+ dismiss_job,
+ dismiss_terminal_jobs,
+ job_event_payload,
+ job_events,
+ job_payload,
+ latest_jobs,
+ job_queue_revision,
+ pause_job,
+ queue_missing_gtfs_sidecar_recovery_jobs,
+ reconcile_interrupted_jobs,
+ reconcile_source_workflow_state,
+ request_job_control,
+ resume_job,
+ retry_job,
+ set_job_priority,
+)
+from app.models import (
+ CanonicalStop,
+ CanonicalStopLink,
+ Dataset,
+ GtfsRoute,
+ GtfsStop,
+ Itinerary,
+ ItineraryLeg,
+ Job,
+ MatchRule,
+ OsmFeature,
+ PipelineRun,
+ RouteMatch,
+ RoutePattern,
+ Source,
+ SourceCatalogEntry,
+ SourceUpdateCheck,
+)
+from app.osm_storage import (
+ ensure_main_osm_feature,
+ osm_feature_count,
+ osm_feature_public_id,
+ query_osm_features,
+ resolve_osm_feature,
+)
+from app.pipeline.matcher import candidate_osm_routes_for_route, run_route_matching, score_route_pair
+from app.pipeline.osm_addresses import address_index_status
+from app.qa import qa_summary
+from app.pipeline.route_layer import rebuild_route_layer
+from app.pipeline.sample_data import load_sample_project
+from app.pipeline.state import pipeline_run_payload
+from app.routing import route_between_points, routing_status
+from app.serializers import feature_collection, gtfs_route_feature, gtfs_stop_feature, match_row, osm_feature_feature, route_pattern_feature
+from app.spatial import using_postgresql
+from app.source_updates import (
+ check_source_for_update,
+ latest_source_update_check,
+ update_check_payload,
+)
+from app.source_catalog import (
+ catalog_entry_payload,
+ import_ingestable_sources,
+ import_source_catalog,
+ linked_source_counts,
+ source_catalog_rows,
+ source_catalog_summary,
+)
+from app.worker_supervisor import queue_worker_status, start_queue_workers, stop_queue_workers
+
+@asynccontextmanager
+async def lifespan(_app: FastAPI):
+    """Application lifespan: prepare the database and job queue, then run the workers.
+
+    Startup calls ``init_db()``, then reconciles interrupted jobs and queues
+    missing GTFS sidecar recovery jobs inside one committed session, and finally
+    starts the queue workers. ``stop_queue_workers()`` is called on shutdown.
+    """
+    init_db()
+    with Session(engine) as session:
+        reconcile_interrupted_jobs(session)
+        queue_missing_gtfs_sidecar_recovery_jobs(session)
+        session.commit()
+    start_queue_workers()
+    try:
+        yield
+    finally:
+        # Also runs when the application exits with an error raised through the yield.
+        stop_queue_workers()
+
+
+# ASGI application; startup and shutdown work is delegated to ``lifespan``.
+app = FastAPI(title="Mobility Workbench", version="0.1.0", lifespan=lifespan)
+# Files in the package's ``static`` directory are served under /static.
+app.mount("/static", StaticFiles(directory=Path(__file__).parent / "static"), name="static")
+# Jinja2 templates are loaded from the package's ``templates`` directory.
+templates = Jinja2Templates(directory=str(Path(__file__).parent / "templates"))
+# NOTE(review): presumably the default and the upper bound for the number of
+# features returned by map endpoints — the use sites are outside this chunk.
+MAP_FEATURE_LIMIT = 5000
+MAP_FEATURE_LIMIT_MAX = 20000
+
+
+class SourceCreate(BaseModel):
+    """Request body for ``POST /api/sources``."""
+
+    name: str
+    # The endpoint accepts only gtfs, osm_geojson, osm_pbf or osm_diff.
+    kind: str
+    url: str
+    # When a catalog entry is linked, create_source fills the empty optional
+    # descriptive fields below from that entry.
+    country: Optional[str] = None
+    license: Optional[str] = None
+    priority: Optional[str] = None
+    mode_scope: Optional[str] = None
+    source_basis: Optional[str] = None
+    notes: Optional[str] = None
+    # Optional link to a SourceCatalogEntry row; the endpoint answers 404 if it does not exist.
+    catalog_entry_id: Optional[int] = None
+
+
+class GtfsFeedReviewUpdate(BaseModel):
+    """Request body for ``PATCH /api/harmonization/gtfs/sources/{source_id}/review``; every field is optional."""
+
+    license: Optional[str] = None
+    # The endpoint accepts unreviewed, approved, needs_review, blocked or rejected.
+    review_status: Optional[str] = None
+    review_note: Optional[str] = None
+    enabled: Optional[bool] = None
+
+
+class CatalogImportRequest(BaseModel):
+    """Request body shared by the source-catalog import endpoints."""
+
+    # Passed to the catalog import helpers as the CSV location; None is passed when omitted.
+    csv_path: Optional[str] = None
+    update_existing: bool = True
+
+
+class RuleCreate(BaseModel):
+    """Payload describing a rule: its type, a selector, an action and an optional note."""
+
+    # NOTE(review): presumably the body used to create a MatchRule — the
+    # endpoint that consumes this model is outside this chunk.
+    rule_type: str
+    selector: dict
+    action: dict
+    note: Optional[str] = None
+
+
+class CanonicalStopGtfsLinkRequest(BaseModel):
+    """Payload identifying a GTFS stop, with an optional note; every field is optional."""
+
+    # NOTE(review): presumably the stop is addressed either by gtfs_stop_id or
+    # by the (dataset_id, stop_id) pair — confirm in the consuming endpoint.
+    gtfs_stop_id: Optional[int] = None
+    dataset_id: Optional[int] = None
+    stop_id: Optional[str] = None
+    note: Optional[str] = None
+
+
+class ItineraryGenerateRequest(BaseModel):
+    """Parameters of an itinerary generation request."""
+
+    # NOTE(review): presumably forwarded to generate_itineraries — the
+    # consuming endpoint is outside this chunk.
+    from_stop_id: str
+    to_stop_id: str
+    via_stop_id: Optional[str] = None
+    source_id: Optional[str] = None
+    departure: str = "08:00"
+    service_date: Optional[str] = None
+    max_transfers: int = 1
+    transfer_seconds: int = 120
+    limit: int = 5
+    # default_factory gives every request its own empty dict.
+    preferences: dict = Field(default_factory=dict)
+
+
+class JourneySearchRequest(BaseModel):
+    """Parameters of a journey search request."""
+
+    # NOTE(review): presumably forwarded to start_journey_search — the
+    # consuming endpoint is outside this chunk.
+    from_stop_id: str
+    to_stop_id: str
+    via_stop_id: Optional[str] = None
+    source_id: Optional[str] = None
+    departure: str = "08:00"
+    service_date: Optional[str] = None
+    mode: str = "transit"
+    direct_only: bool = False
+    ranking: str = "recommended"
+    transfer_seconds: int = 120
+    limit: int = 5
+
+
+class ItinerarySaveRequest(BaseModel):
+    """Payload carrying the desired ``saved`` flag; defaults to True."""
+
+    saved: bool = True
+
+
+class ItineraryLegLockRequest(BaseModel):
+    """Payload carrying the desired ``locked`` flag; defaults to True."""
+
+    locked: bool = True
+
+
+class GeofabrikSourceRequest(BaseModel):
+    """Request body for ``POST /api/geofabrik/sources``."""
+
+    geofabrik_id: str
+    # Forwarded to create_geofabrik_source.
+    import_updates: bool = False
+    # When True the endpoint also queues an import job for the new source.
+    run_import: bool = False
+    # Options for that import job; only read when run_import is True.
+    run_match: bool = True
+    build_route_layer: bool = True
+
+
+class AdminActionRequest(BaseModel):
+    """Request body accepted by the admin maintenance endpoints."""
+
+    # NOTE(review): the meaning of these fields is decided by
+    # _queue_admin_maintenance_job, which is outside this chunk. Presumably
+    # dry_run previews the action and confirm is a safety phrase — confirm there.
+    dry_run: bool = True
+    confirm: Optional[str] = None
+    dataset_id: Optional[int] = None
+
+
+class JobPriorityRequest(BaseModel):
+    """Request body for ``POST /api/jobs/{job_id}/priority``."""
+
+    priority: int
+
+
+def write_endpoint(operation: str):
+    """Decorator factory that runs a handler while holding the database write lock.
+
+    ``operation`` is passed to ``database_write_lock``. When
+    ``DatabaseWriteBusy`` is raised, the request fails with HTTP 409 and a
+    ``Retry-After: 5`` header. The wrapper is synchronous, so it is meant for
+    non-async handlers.
+    """
+
+    def wrap_handler(handler):
+        @wraps(handler)
+        def locked_call(*args, **kwargs):
+            try:
+                with database_write_lock(operation):
+                    outcome = handler(*args, **kwargs)
+            except DatabaseWriteBusy as busy:
+                raise HTTPException(status_code=409, detail=str(busy), headers={"Retry-After": "5"}) from busy
+            return outcome
+
+        return locked_call
+
+    return wrap_handler
+
+
+@app.exception_handler(OperationalError)
+async def database_operational_error_handler(_request: Request, exc: OperationalError) -> JSONResponse:
+    """Translate an uncaught SQLAlchemy ``OperationalError`` into a JSON error response.
+
+    Errors recognised by ``_is_database_locked_error`` become HTTP 409 with a
+    ``Retry-After: 5`` header. Any other operational error becomes a generic
+    HTTP 500 and is logged with its traceback.
+    """
+    import logging
+
+    if _is_database_locked_error(exc):
+        return JSONResponse(
+            status_code=409,
+            content={"detail": "Database is busy with another operation. Try again when the current write finishes."},
+            headers={"Retry-After": "5"},
+        )
+    # This handler consumes the exception, so without a log entry the generic
+    # 500 below would leave no trace of what actually failed.
+    logging.getLogger(__name__).error("Database operation failed", exc_info=exc)
+    return JSONResponse(status_code=500, content={"detail": "Database operation failed."})
+
+
+@app.get("/", response_class=HTMLResponse)
+def index(request: Request) -> HTMLResponse:
+    """Render the ``index.html`` template for the root URL."""
+    return templates.TemplateResponse(request=request, name="index.html")
+
+
+@app.get("/api/sources")
+def list_sources(db: Session = Depends(get_db)) -> list[dict]:
+    """List every source, ordered by id, with its latest update check and active jobs."""
+    # Job housekeeping is committed first so the listing reflects the reconciled state.
+    reconcile_interrupted_jobs(db)
+    queue_missing_gtfs_sidecar_recovery_jobs(db)
+    db.commit()
+    sources = db.scalars(select(Source).order_by(Source.id)).all()
+    checks_by_source = {}
+    for source in sources:
+        checks_by_source[source.id] = latest_source_update_check(db, source.id)
+    workflow_jobs = active_source_workflow_jobs(db)
+    dataset_delete_jobs = active_dataset_delete_jobs(db)
+    payloads = []
+    for source in sources:
+        payloads.append(
+            _source_response(
+                db,
+                source,
+                checks_by_source.get(source.id),
+                workflow_jobs.get(source.id),
+                dataset_delete_jobs,
+            )
+        )
+    return payloads
+
+
+@app.post("/api/sources")
+@write_endpoint("create source")
+def create_source(payload: SourceCreate, db: Session = Depends(get_db)) -> dict:
+    """Create a source, optionally linked to a catalog entry.
+
+    Optional descriptive fields missing from the payload fall back to values
+    taken from the linked catalog entry. Responds with 400 for an unsupported
+    ``kind`` and 404 when ``catalog_entry_id`` does not resolve.
+    """
+    allowed_kinds = {"gtfs", "osm_geojson", "osm_pbf", "osm_diff"}
+    if payload.kind not in allowed_kinds:
+        raise HTTPException(status_code=400, detail="kind must be 'gtfs', 'osm_geojson', 'osm_pbf', or 'osm_diff'")
+    catalog_entry = None
+    if payload.catalog_entry_id:
+        catalog_entry = db.get(SourceCatalogEntry, payload.catalog_entry_id)
+        if catalog_entry is None:
+            raise HTTPException(status_code=404, detail="catalog entry not found")
+    # Name and license are cut to 255 characters via _truncate.
+    attributes = {
+        "catalog_entry_id": payload.catalog_entry_id,
+        "name": _truncate(payload.name, 255) or payload.name,
+        "kind": payload.kind,
+        "url": payload.url,
+        "country": payload.country or _catalog_country(catalog_entry),
+        "license": _truncate(payload.license, 255)
+        or _truncate(catalog_entry.access_license_notes if catalog_entry else None, 255),
+        "priority": payload.priority or (catalog_entry.priority if catalog_entry else None),
+        "mode_scope": payload.mode_scope or (catalog_entry.mode_scope if catalog_entry else None),
+        "source_basis": payload.source_basis or (catalog_entry.source_category if catalog_entry else None),
+        "notes": payload.notes or _catalog_notes(catalog_entry),
+    }
+    source = Source(**attributes)
+    db.add(source)
+    db.commit()
+    db.refresh(source)
+    return {"id": source.id, "status": source.status}
+
+
+@app.post("/api/sources/{source_id}/run")
+@write_endpoint("run source import")
+def run_source(source_id: int, db: Session = Depends(get_db)) -> dict:
+    """Queue an import-only job for a source (no matching, no route-layer build).
+
+    The status is "already_running" when ``_queue_source_import_job`` reports
+    that no new job was started.
+    """
+    source = db.get(Source, source_id)
+    if source is None:
+        raise HTTPException(status_code=404, detail="source not found")
+    job, started = _queue_source_import_job(db, source, run_match=False, build_route_layer=False)
+    status = "queued" if started else "already_running"
+    return {"source_id": source.id, "job_id": job.id, "status": status, "job": job_payload(job)}
+
+
+@app.post("/api/sources/{source_id}/update")
+@write_endpoint("update source")
+def update_source(source_id: int, force: bool = False, db: Session = Depends(get_db)) -> dict:
+    """Check a source for updates and queue a full import when one is available.
+
+    Returns early with "already_running" when an import job is active, answers
+    502 when the update check did not complete, and returns "skipped" when no
+    update is available and ``force`` is false. Otherwise an import job with
+    matching and route-layer build is queued.
+    """
+    source = db.get(Source, source_id)
+    if source is None:
+        raise HTTPException(status_code=404, detail="source not found")
+    running_job = active_source_import_job(db, source.id)
+    if running_job is not None:
+        return {"source_id": source.id, "status": "already_running", "job": job_payload(running_job)}
+    check = check_source_for_update(db, source)
+    if check.status != "checked":
+        # Commit before raising so the failed check is kept.
+        db.commit()
+        raise HTTPException(status_code=502, detail=check.reason or "update check failed")
+    if not (force or check.update_available):
+        db.commit()
+        skipped = {"source_id": source.id, "status": "skipped", "reason": check.reason}
+        skipped["check"] = update_check_payload(check)
+        return skipped
+    job, started = _queue_source_import_job(db, source, run_match=True, build_route_layer=True)
+    status = "queued" if started else "already_running"
+    return {"source_id": source.id, "status": status, "job": job_payload(job), "check": update_check_payload(check)}
+
+
+@app.post("/api/sources/{source_id}/check-update")
+@write_endpoint("check source update")
+def check_source_update(source_id: int, db: Session = Depends(get_db)) -> dict:
+    """Run an update check for a source, commit it, and return its payload ({} when the payload is empty)."""
+    source = db.get(Source, source_id)
+    if source is None:
+        raise HTTPException(status_code=404, detail="source not found")
+    check = check_source_for_update(db, source)
+    db.commit()
+    db.refresh(check)
+    payload = update_check_payload(check)
+    return payload if payload else {}
+
+
+@app.get("/api/sources/{source_id}/update-checks")
+def list_source_update_checks(source_id: int, limit: int = 20, db: Session = Depends(get_db)) -> dict:
+    """Return a source's update checks, newest first; ``limit`` is clamped to 1..100."""
+    source = db.get(Source, source_id)
+    if source is None:
+        raise HTTPException(status_code=404, detail="source not found")
+    page_size = max(1, min(limit, 100))
+    query = (
+        select(SourceUpdateCheck)
+        .where(SourceUpdateCheck.source_id == source.id)
+        .order_by(SourceUpdateCheck.checked_at.desc(), SourceUpdateCheck.id.desc())
+        .limit(page_size)
+    )
+    history = db.scalars(query).all()
+    return {"source_id": source.id, "checks": [update_check_payload(entry) for entry in history]}
+
+
+@app.get("/api/database/write-status")
+def database_write_status_endpoint() -> dict:
+    """Report the state returned by ``database_write_status()``; elapsed seconds are rounded to 0.1."""
+    state = database_write_status()
+    if state.elapsed_seconds is None:
+        elapsed = None
+    else:
+        elapsed = round(state.elapsed_seconds, 1)
+    return {
+        "locked": state.locked,
+        "operation": state.operation,
+        "pid": state.pid,
+        "elapsed_seconds": elapsed,
+    }
+
+
+# Action names of the admin maintenance jobs; each has a dedicated
+# /api/admin/<action> endpoint in this module.
+# NOTE(review): presumably _queue_admin_maintenance_job validates the requested
+# action against this set — confirm in that helper.
+ADMIN_JOB_ACTIONS = {
+    "init-db",
+    "reset-db",
+    "backfill-gtfs-shapes",
+    "prune-cache",
+    "prune-inactive-datasets",
+    "vacuum-db",
+}
+
+
+@app.post("/api/admin/init-db")
+@write_endpoint("admin init database")
+def admin_init_database(priority: int = 0, db: Session = Depends(get_db)) -> dict:
+    """Hand the ``init-db`` action (no request payload) to ``_queue_admin_maintenance_job``."""
+    return _queue_admin_maintenance_job(db, "init-db", None, priority=priority)
+
+
+@app.post("/api/admin/reset-db")
+@write_endpoint("admin reset database")
+def admin_reset_database(payload: AdminActionRequest, priority: int = 0, db: Session = Depends(get_db)) -> dict:
+    """Hand the ``reset-db`` action and its request payload to ``_queue_admin_maintenance_job``."""
+    return _queue_admin_maintenance_job(db, "reset-db", payload, priority=priority)
+
+
+@app.post("/api/admin/backfill-gtfs-shapes")
+@write_endpoint("admin backfill GTFS shapes")
+def admin_backfill_gtfs_shapes(payload: AdminActionRequest | None = None, priority: int = 0, db: Session = Depends(get_db)) -> dict:
+    """Hand the ``backfill-gtfs-shapes`` action to ``_queue_admin_maintenance_job``; the request body is optional."""
+    return _queue_admin_maintenance_job(db, "backfill-gtfs-shapes", payload, priority=priority)
+
+
+@app.post("/api/admin/prune-cache")
+@write_endpoint("admin prune cache")
+def admin_prune_cache(payload: AdminActionRequest, priority: int = 0, db: Session = Depends(get_db)) -> dict:
+    """Hand the ``prune-cache`` action and its request payload to ``_queue_admin_maintenance_job``."""
+    return _queue_admin_maintenance_job(db, "prune-cache", payload, priority=priority)
+
+
+@app.post("/api/admin/prune-inactive-datasets")
+@write_endpoint("admin prune inactive datasets")
+def admin_prune_inactive_datasets(payload: AdminActionRequest, priority: int = 0, db: Session = Depends(get_db)) -> dict:
+    """Hand the ``prune-inactive-datasets`` action and its request payload to ``_queue_admin_maintenance_job``."""
+    return _queue_admin_maintenance_job(db, "prune-inactive-datasets", payload, priority=priority)
+
+
+@app.post("/api/admin/vacuum-db")
+@write_endpoint("admin vacuum database")
+def admin_vacuum_database(payload: AdminActionRequest, priority: int = 0, db: Session = Depends(get_db)) -> dict:
+    """Hand the ``vacuum-db`` action and its request payload to ``_queue_admin_maintenance_job``."""
+    return _queue_admin_maintenance_job(db, "vacuum-db", payload, priority=priority)
+
+
+@app.post("/api/jobs/admin/{action}")
+@write_endpoint("queue admin maintenance")
+def queue_admin_maintenance(action: str, payload: AdminActionRequest | None = None, priority: int = 0, db: Session = Depends(get_db)) -> dict:
+    """Hand the action named in the URL path to ``_queue_admin_maintenance_job``; the request body is optional."""
+    # NOTE(review): ``action`` is not validated here; presumably the helper
+    # rejects names outside ADMIN_JOB_ACTIONS — confirm.
+    return _queue_admin_maintenance_job(db, action, payload, priority=priority)
+
+
+@app.delete("/api/sources/{source_id}")
+@write_endpoint("queue source delete")
+def remove_source(source_id: int, db: Session = Depends(get_db)) -> dict:
+    """Create a delete job for a source and return the job payload; 404 when the source is unknown."""
+    target = db.get(Source, source_id)
+    if target is None:
+        raise HTTPException(status_code=404, detail="source not found")
+    delete_job = create_source_delete_job(db, target)
+    db.commit()
+    db.refresh(delete_job)
+    return job_payload(delete_job)
+
+
+@app.get("/api/source-catalog")
+def list_source_catalog(
+    q: Optional[str] = None,
+    country: Optional[str] = None,
+    priority: Optional[str] = None,
+    status: Optional[str] = None,
+    limit: int = 100,
+    db: Session = Depends(get_db),
+) -> dict:
+    """Return filtered catalog entries with their linked-source counts, plus the catalog summary."""
+    entries = source_catalog_rows(db, q=q, country=country, priority=priority, status=status, limit=limit)
+    source_counts = linked_source_counts(db, entries)
+    summary = source_catalog_summary(db)
+    entry_payloads = []
+    for entry in entries:
+        # Entries without linked sources are reported with a count of 0.
+        linked = source_counts.get(entry.id, 0)
+        entry_payloads.append(catalog_entry_payload(entry, linked_source_count=linked))
+    return {"summary": summary, "entries": entry_payloads}
+
+
+@app.post("/api/source-catalog/import")
+@write_endpoint("import source catalog")
+def import_source_catalog_endpoint(payload: CatalogImportRequest | None = None, db: Session = Depends(get_db)) -> dict:
+    """Run ``import_source_catalog`` in the request and return its result merged with the catalog summary.
+
+    Without a request body the import gets ``None`` as the CSV path and ``update_existing=True``.
+    """
+    # NOTE(review): csv_path comes straight from the request body — confirm that
+    # import_source_catalog restricts which server-side paths may be read.
+    csv_path = payload.csv_path if payload is not None else None
+    update_existing = payload.update_existing if payload else True
+    outcome = import_source_catalog(db, csv_path, update_existing=update_existing)
+    db.commit()
+    return {**outcome, "summary": source_catalog_summary(db)}
+
+
+@app.post("/api/source-catalog/import-ingestable")
+@write_endpoint("import ingestable sources")
+def import_ingestable_sources_endpoint(payload: CatalogImportRequest | None = None, db: Session = Depends(get_db)) -> dict:
+    """Run ``import_ingestable_sources`` in the request and return its result merged with the catalog summary.
+
+    Without a request body the import gets ``None`` as the CSV path and ``update_existing=True``.
+    """
+    # NOTE(review): csv_path comes straight from the request body — confirm that
+    # import_ingestable_sources restricts which server-side paths may be read.
+    csv_path = payload.csv_path if payload is not None else None
+    update_existing = payload.update_existing if payload else True
+    outcome = import_ingestable_sources(db, csv_path, update_existing=update_existing)
+    db.commit()
+    return {**outcome, "summary": source_catalog_summary(db)}
+
+
+@app.post("/api/jobs/source-catalog/import")
+@write_endpoint("queue source catalog import")
+def queue_source_catalog_import(payload: CatalogImportRequest | None = None, priority: int = 0, db: Session = Depends(get_db)) -> dict:
+    """Create a ``source-catalog-import`` maintenance job; a missing body means default import options."""
+    options = payload or CatalogImportRequest()
+    queued = create_maintenance_job(db, "source-catalog-import", _request_model_payload(options), priority=priority)
+    db.commit()
+    db.refresh(queued)
+    return job_payload(queued)
+
+
+@app.post("/api/jobs/source-catalog/import-ingestable")
+@write_endpoint("queue ingestable source import")
+def queue_ingestable_source_import(payload: CatalogImportRequest | None = None, priority: int = 0, db: Session = Depends(get_db)) -> dict:
+    """Create a ``source-catalog-import-ingestable`` maintenance job; a missing body means default import options."""
+    options = payload or CatalogImportRequest()
+    queued = create_maintenance_job(db, "source-catalog-import-ingestable", _request_model_payload(options), priority=priority)
+    db.commit()
+    db.refresh(queued)
+    return job_payload(queued)
+
+
+@app.get("/api/geofabrik/catalog")
+def list_geofabrik_catalog(q: Optional[str] = None, limit: int = 80) -> dict:
+    """Return Geofabrik catalog entries; any failure while fetching them is reported as HTTP 502."""
+    try:
+        entries = geofabrik_catalog(q=q, limit=limit)
+    except Exception as exc:  # noqa: BLE001 - remote catalog errors should be visible in UI
+        raise HTTPException(status_code=502, detail=f"Geofabrik catalog fetch failed: {exc}") from exc
+    return {"entries": entries}
+
+
+@app.post("/api/geofabrik/sources")
+@write_endpoint("create Geofabrik source")
+def create_geofabrik_source_endpoint(payload: GeofabrikSourceRequest, db: Session = Depends(get_db)) -> dict:
+    """Create a source from a Geofabrik id and optionally queue its import.
+
+    Returns ``{"source": ..., "job": ...}``; ``job`` is ``None`` unless
+    ``run_import`` is set. A ``ValueError`` from ``create_geofabrik_source`` is
+    reported as HTTP 400.
+    """
+    try:
+        source = create_geofabrik_source(db, payload.geofabrik_id, import_updates=payload.import_updates)
+    except ValueError as exc:
+        raise HTTPException(status_code=400, detail=str(exc)) from exc
+    job_data = None
+    if payload.run_import:
+        job, _started = _queue_source_import_job(db, source, run_match=payload.run_match, build_route_layer=payload.build_route_layer)
+        job_data = job_payload(job)
+    else:
+        db.commit()
+        db.refresh(source)
+    # Serialize the source once, after the state-changing step, so the payload
+    # is not a snapshot taken before the job was queued or the row committed.
+    return {"source": _source_response(db, source), "job": job_data}
+
+
+@app.post("/api/jobs/sources/{source_id}/import")
+@write_endpoint("queue source import")
+def queue_source_import(
+    source_id: int,
+    run_match: bool = True,
+    build_route_layer: bool = True,
+    priority: int = 0,
+    db: Session = Depends(get_db),
+) -> dict:
+    """Queue an import job for a source with the requested follow-up steps and priority; 404 when the source is unknown."""
+    target = db.get(Source, source_id)
+    if target is None:
+        raise HTTPException(status_code=404, detail="source not found")
+    queued, _started = _queue_source_import_job(
+        db,
+        target,
+        run_match=run_match,
+        build_route_layer=build_route_layer,
+        priority=priority,
+    )
+    return job_payload(queued)
+
+
+@app.delete("/api/datasets/{dataset_id}")
+@write_endpoint("delete dataset")
+def remove_dataset(dataset_id: int, db: Session = Depends(get_db)) -> dict:
+    """Create a delete job for a dataset and return the job payload; 404 when the dataset is unknown."""
+    target = db.get(Dataset, dataset_id)
+    if target is None:
+        raise HTTPException(status_code=404, detail="dataset not found")
+    delete_job = create_dataset_delete_job(db, target)
+    db.commit()
+    db.refresh(delete_job)
+    return job_payload(delete_job)
+
+
+@app.get("/api/datasets/search")
+def dataset_search(q: str = "", active_only: bool = False, limit: int = 80, db: Session = Depends(get_db)) -> dict:
+    """Pass the query, ``active_only`` flag and ``limit`` to ``search_datasets`` and return its result."""
+    return search_datasets(db, q, active_only=active_only, limit=limit)
+
+
+@app.get("/api/datasets/search/feature.geojson")
+def dataset_search_feature(type: str, id: str, db: Session = Depends(get_db)) -> JSONResponse:
+    """Return one search result as a GeoJSON feature collection.
+
+    ``type`` selects the lookup: ``gtfs_route``, ``osm_route`` or
+    ``route_pattern``. Any other value is a 400, and an unknown ``id`` is a 404.
+    The collection is empty when the serializer returns no feature.
+    """
+    if type not in ("gtfs_route", "osm_route", "route_pattern"):
+        raise HTTPException(status_code=400, detail="type must be gtfs_route, osm_route, or route_pattern")
+    extra_properties = {"search_result_type": type}
+    if type == "gtfs_route":
+        route = db.get(GtfsRoute, _path_int(id, "GTFS route id"))
+        if route is None:
+            raise HTTPException(status_code=404, detail="GTFS route not found")
+        feature = gtfs_route_feature(route, extra_properties)
+    elif type == "osm_route":
+        osm_feature = resolve_osm_feature(db, id)
+        # A resolved feature that is not a route is treated as not found.
+        if osm_feature is None or osm_feature.kind != "route":
+            raise HTTPException(status_code=404, detail="OSM route not found")
+        feature = osm_feature_feature(osm_feature, extra_properties)
+    else:
+        pattern = db.get(RoutePattern, _path_int(id, "route pattern id"))
+        if pattern is None:
+            raise HTTPException(status_code=404, detail="route pattern not found")
+        feature = route_pattern_feature(pattern, extra_properties)
+    features = [feature] if feature is not None else []
+    return JSONResponse(feature_collection(features))
+
+
+@app.post("/api/sample/reset")
+@write_endpoint("reset sample data")
+def reset_and_load_sample(db: Session = Depends(get_db)) -> dict:
+    """Run ``load_sample_project`` inside the request, commit, and return its result."""
+    summary = load_sample_project(db)
+    db.commit()
+    return summary
+
+
+@app.post("/api/jobs/sample-reset")
+@write_endpoint("queue sample reset")
+def queue_sample_reset(priority: int = 0, db: Session = Depends(get_db)) -> dict:
+    """Create a ``sample-reset`` maintenance job with an empty payload and return the job payload."""
+    queued = create_maintenance_job(db, "sample-reset", {}, priority=priority)
+    db.commit()
+    db.refresh(queued)
+    return job_payload(queued)
+
+
+@app.post("/api/match/run")
+@write_endpoint("run route matching")
+def run_matching(db: Session = Depends(get_db)) -> dict:
+    """Run ``run_route_matching`` inside the request, commit, and return its result."""
+    summary = run_route_matching(db)
+    db.commit()
+    return summary
+
+
+@app.post("/api/route-layer/build")
+@write_endpoint("build route layer")
+def build_route_layer(db: Session = Depends(get_db)) -> dict:
+    """Run ``rebuild_route_layer`` inside the request, commit, and return its result."""
+    summary = rebuild_route_layer(db)
+    db.commit()
+    return summary
+
+
+@app.post("/api/jobs/route-layer-build")
+@write_endpoint("queue route layer build")
+def queue_route_layer_build(priority: int = 0, db: Session = Depends(get_db)) -> dict:
+    """Create a route-layer rebuild job with the given priority and return the job payload."""
+    queued = create_route_layer_rebuild_job(db, priority=priority)
+    db.commit()
+    db.refresh(queued)
+    return job_payload(queued)
+
+
+@app.post("/api/jobs/address-index-build")
+@write_endpoint("queue OSM address index build")
+def queue_address_index_build(priority: int = 0, db: Session = Depends(get_db)) -> dict:
+    """Create an address-index rebuild job with the given priority and return the job payload."""
+    queued = create_address_index_rebuild_job(db, priority=priority)
+    db.commit()
+    db.refresh(queued)
+    return job_payload(queued)
+
+
+@app.post("/api/jobs/match-run")
+@write_endpoint("queue route matching")
+def queue_route_matching(priority: int = 0, db: Session = Depends(get_db)) -> dict:
+    """Create a route-matching job with the given priority and return the job payload."""
+    queued = create_route_matching_job(db, priority=priority)
+    db.commit()
+    db.refresh(queued)
+    return job_payload(queued)
+
+
+@app.post("/api/jobs/osm-relabel")
+@write_endpoint("queue OSM relabeling")
+def queue_osm_relabel(
+    dataset_id: Optional[int] = None,
+    build_route_layer: bool = True,
+    force: bool = False,
+    priority: int = 0,
+    db: Session = Depends(get_db),
+) -> dict:
+    """Create an OSM relabel job from the query parameters and return the job payload."""
+    queued = create_osm_relabel_job(
+        db,
+        dataset_id=dataset_id,
+        build_route_layer=build_route_layer,
+        force=force,
+        priority=priority,
+    )
+    db.commit()
+    db.refresh(queued)
+    return job_payload(queued)
+
+
+@app.get("/api/jobs")
+def list_jobs(
+    response: Response,
+    kind: Optional[str] = None,
+    limit: int = 20,
+    include_dismissed: bool = False,
+    db: Session = Depends(get_db),
+) -> dict:
+    """Return the latest jobs, the worker status and the queue revision.
+
+    The revision is also passed to ``_set_etag`` for the response.
+    """
+    # Interrupted jobs are reconciled and committed before anything is read.
+    reconcile_interrupted_jobs(db)
+    db.commit()
+    workers = queue_worker_status()
+    revision_info = _job_queue_revision_payload(db, workers=workers, include_dismissed=include_dismissed)
+    revision = revision_info["revision"]
+    _set_etag(response, revision)
+    recent = latest_jobs(db, limit=limit, kind=kind, include_dismissed=include_dismissed)
+    return {
+        "jobs": [job_payload(entry) for entry in recent],
+        "workers": workers,
+        "revision": revision,
+    }
+
+
+@app.get("/api/jobs/revision")
+def get_jobs_revision(
+    response: Response,
+    since: Optional[str] = None,
+    include_dismissed: bool = False,
+    db: Session = Depends(get_db),
+) -> dict:
+    """Return the job-queue revision payload plus a ``changed`` flag.
+
+    ``changed`` is true whenever ``since`` differs from the current revision,
+    which includes the case where ``since`` is omitted.
+    """
+    reconcile_interrupted_jobs(db)
+    db.commit()
+    revision_info = _job_queue_revision_payload(db, include_dismissed=include_dismissed)
+    current = revision_info["revision"]
+    _set_etag(response, current)
+    has_changed = since != current
+    return {**revision_info, "changed": has_changed}
+
+
+@app.get("/api/jobs/{job_id}")
+def get_job(job_id: int, db: Session = Depends(get_db)) -> dict:
+    """Return the payload of one job; 404 when the id is unknown."""
+    found = db.get(Job, job_id)
+    if found is not None:
+        return job_payload(found)
+    raise HTTPException(status_code=404, detail="job not found")
+
+
+@app.get("/api/jobs/{job_id}/events")
+def get_job_events(job_id: int, limit: int = 100, db: Session = Depends(get_db)) -> dict:
+    """Return a job's payload together with its events; 404 when the id is unknown."""
+    found = db.get(Job, job_id)
+    if found is None:
+        raise HTTPException(status_code=404, detail="job not found")
+    job_data = job_payload(found)
+    event_rows = job_events(db, job_id, limit=limit)
+    return {"job": job_data, "events": [job_event_payload(row) for row in event_rows]}
+
+
+@app.get("/api/pipeline-runs")
+def list_pipeline_runs(
+    stage: Optional[str] = None,
+    source_id: Optional[int] = None,
+    dataset_id: Optional[int] = None,
+    limit: int = 50,
+    db: Session = Depends(get_db),
+) -> dict:
+    """Return pipeline runs, newest first, optionally filtered by stage, source and dataset.
+
+    ``limit`` is clamped to 1..200.
+    """
+    query = select(PipelineRun).order_by(PipelineRun.started_at.desc(), PipelineRun.id.desc())
+    # An empty stage string means "no stage filter", while the ids filter on any non-None value.
+    filters = []
+    if stage:
+        filters.append(PipelineRun.stage == stage)
+    if source_id is not None:
+        filters.append(PipelineRun.source_id == source_id)
+    if dataset_id is not None:
+        filters.append(PipelineRun.dataset_id == dataset_id)
+    for condition in filters:
+        query = query.where(condition)
+    page_size = max(1, min(int(limit), 200))
+    runs = db.scalars(query.limit(page_size)).all()
+    return {"runs": [pipeline_run_payload(run) for run in runs]}
+
+
+@app.post("/api/jobs/{job_id}/pause")
+def pause_job_endpoint(job_id: int, db: Session = Depends(get_db)) -> dict:
+    """Pause a job by passing the "pause" action to ``_job_control_response``."""
+    # Not wrapped in write_endpoint: _job_control_response handles a locked database itself.
+    return _job_control_response(db, job_id, "pause")
+
+
+@app.post("/api/jobs/{job_id}/resume")
+@write_endpoint("resume job")
+def resume_job_endpoint(job_id: int, db: Session = Depends(get_db)) -> dict:
+    """Resume a job and return its payload.
+
+    A ``ValueError`` from ``resume_job`` is reported as 404 when its message
+    starts with "job not found", otherwise as 400.
+    """
+    try:
+        job = resume_job(db, job_id)
+    except ValueError as exc:
+        detail = str(exc)
+        # Same mapping as the retry and dismiss endpoints: a missing job is 404, not 400.
+        status_code = 404 if detail.startswith("job not found") else 400
+        raise HTTPException(status_code=status_code, detail=detail) from exc
+    db.commit()
+    db.refresh(job)
+    return job_payload(job)
+
+
+@app.post("/api/jobs/{job_id}/retry")
+@write_endpoint("retry job")
+def retry_job_endpoint(job_id: int, db: Session = Depends(get_db)) -> dict:
+    """Retry a job and return its payload.
+
+    A ``ValueError`` from ``retry_job`` is reported as 404 when its message
+    starts with "job not found", otherwise as 400.
+    """
+    try:
+        retried = retry_job(db, job_id)
+    except ValueError as exc:
+        message = str(exc)
+        if message.startswith("job not found"):
+            raise HTTPException(status_code=404, detail=message) from exc
+        raise HTTPException(status_code=400, detail=message) from exc
+    db.commit()
+    db.refresh(retried)
+    return job_payload(retried)
+
+
+@app.post("/api/jobs/{job_id}/cancel")
+def cancel_job_endpoint(job_id: int, db: Session = Depends(get_db)) -> dict:
+    """Cancel a job by passing the "cancel" action to ``_job_control_response``."""
+    # Not wrapped in write_endpoint: _job_control_response handles a locked database itself.
+    return _job_control_response(db, job_id, "cancel")
+
+
+@app.post("/api/jobs/{job_id}/stop")
+def stop_job_endpoint(job_id: int, db: Session = Depends(get_db)) -> dict:
+    """Stop a job; this issues the same "cancel" action as the cancel endpoint."""
+    return _job_control_response(db, job_id, "cancel")
+
+
+def _job_control_response(db: Session, job_id: int, action: str) -> dict:
+    """Pause or cancel a job and return its payload.
+
+    ``action == "pause"`` calls ``pause_job``; any other value calls
+    ``cancel_job``. When the write fails because the database is locked, the
+    transaction is rolled back and the request is handed to
+    ``request_job_control`` instead. Other ``OperationalError``s propagate.
+
+    Raises:
+        HTTPException: 404 when the job layer reports "job not found", 400 for
+            any other ``ValueError``.
+    """
+    try:
+        job = pause_job(db, job_id) if action == "pause" else cancel_job(db, job_id)
+        db.commit()
+        db.refresh(job)
+        return job_payload(job)
+    except OperationalError as exc:
+        db.rollback()
+        if not _is_database_locked_error(exc):
+            raise
+        return request_job_control(job_id, action)
+    except ValueError as exc:
+        detail = str(exc)
+        # Same mapping as the retry and dismiss endpoints: a missing job is 404, not 400.
+        status_code = 404 if detail.startswith("job not found") else 400
+        raise HTTPException(status_code=status_code, detail=detail) from exc
+
+
+@app.post("/api/jobs/{job_id}/priority")
+@write_endpoint("set job priority")
+def set_job_priority_endpoint(job_id: int, payload: JobPriorityRequest, db: Session = Depends(get_db)) -> dict:
+    """Set a job's priority and return its payload.
+
+    A ``ValueError`` from ``set_job_priority`` is reported as 404 when its
+    message starts with "job not found", otherwise as 400.
+    """
+    try:
+        job = set_job_priority(db, job_id, payload.priority)
+    except ValueError as exc:
+        detail = str(exc)
+        # Same mapping as the retry and dismiss endpoints: a missing job is 404, not 400.
+        status_code = 404 if detail.startswith("job not found") else 400
+        raise HTTPException(status_code=status_code, detail=detail) from exc
+    db.commit()
+    db.refresh(job)
+    return job_payload(job)
+
+
+@app.post("/api/jobs/{job_id}/dismiss")
+@write_endpoint("dismiss job")
+def dismiss_job_endpoint(job_id: int, db: Session = Depends(get_db)) -> dict:
+    """Dismiss a job and return its payload.
+
+    A ``ValueError`` from ``dismiss_job`` is reported as 404 when its message
+    starts with "job not found", otherwise as 400.
+    """
+    try:
+        dismissed = dismiss_job(db, job_id)
+    except ValueError as exc:
+        message = str(exc)
+        if message.startswith("job not found"):
+            raise HTTPException(status_code=404, detail=message) from exc
+        raise HTTPException(status_code=400, detail=message) from exc
+    db.commit()
+    db.refresh(dismissed)
+    return job_payload(dismissed)
+
+
+@app.post("/api/jobs/dismiss-terminal")
+@write_endpoint("dismiss terminal jobs")
+def dismiss_terminal_jobs_endpoint(db: Session = Depends(get_db)) -> dict:
+    """Call ``dismiss_terminal_jobs``, commit, and report the value it returned as ``dismissed``."""
+    dismissed_count = dismiss_terminal_jobs(db)
+    db.commit()
+    return {"dismissed": dismissed_count}
+
+
+@app.get("/api/stats")
+def stats(db: Session = Depends(get_db)) -> dict:
+    """Return counters for sources, active datasets, OSM features and route matches.
+
+    GTFS route and stop counts cover active datasets only. Match statistics
+    cover routes of active ``gtfs`` datasets, and OSM feature counts cover
+    active ``osm_geojson`` datasets. Coverage percentages are rounded to one
+    decimal place and are 0 when nothing has been reviewed.
+    """
+    queue_missing_gtfs_sidecar_recovery_jobs(db)
+    db.commit()
+
+    def active_ids(*extra_filters):
+        """Ids of active datasets that also satisfy ``extra_filters``."""
+        query = select(Dataset.id).where(Dataset.is_active.is_(True), *extra_filters)
+        return [row[0] for row in db.execute(query).all()]
+
+    def count_active(model):
+        """Number of ``model`` rows that belong to an active dataset."""
+        if not active_dataset_ids:
+            return 0
+        query = select(func.count()).select_from(model).where(model.dataset_id.in_(active_dataset_ids))
+        return db.scalar(query) or 0
+
+    def percent(part, whole):
+        return round((part / whole) * 100, 1) if whole else 0
+
+    # NOTE(review): the three id lists come from three separate queries over
+    # the same active-dataset rows.
+    active_dataset_ids = active_ids()
+    active_gtfs_dataset_ids = active_ids(Dataset.kind == "gtfs")
+    active_osm_dataset_ids = active_ids(Dataset.kind == "osm_geojson")
+
+    match_counts = {}
+    if active_gtfs_dataset_ids:
+        status_query = (
+            select(RouteMatch.status, func.count())
+            .join(RouteMatch.gtfs_route)
+            .where(GtfsRoute.dataset_id.in_(active_gtfs_dataset_ids))
+            .group_by(RouteMatch.status)
+        )
+        for status, count in db.execute(status_query).all():
+            match_counts[status] = count
+    matched_total = (match_counts.get("matched", 0) or 0) + (match_counts.get("accepted", 0) or 0)
+    review_total = sum(match_counts.values())
+
+    scope_summary = _match_scope_summary(db, active_gtfs_dataset_ids)
+    in_scope = scope_summary.get("in_osm_scope", {})
+    scoped_review_total = sum(in_scope.values())
+    scoped_matched_total = in_scope.get("matched", 0) + in_scope.get("accepted", 0)
+
+    stop_like_kinds = ["stop", "station", "terminal"]
+    return {
+        "sources": db.scalar(select(func.count()).select_from(Source)) or 0,
+        "source_catalog": source_catalog_summary(db),
+        "active_datasets": len(active_dataset_ids),
+        "gtfs_routes": count_active(GtfsRoute),
+        "gtfs_stops": count_active(GtfsStop),
+        "osm_routes": sum(osm_feature_count(db, dataset_id, kind="route") for dataset_id in active_osm_dataset_ids),
+        "osm_stops_terminals": sum(
+            osm_feature_count(db, dataset_id, kind=stop_like_kinds) for dataset_id in active_osm_dataset_ids
+        ),
+        "route_patterns": db.scalar(select(func.count()).select_from(RoutePattern)) or 0,
+        "matches": match_counts,
+        "match_summary": {
+            "reviewed_routes": review_total,
+            "matched_or_accepted": matched_total,
+            "probable": match_counts.get("probable", 0) or 0,
+            "weak": match_counts.get("weak", 0) or 0,
+            "missing": match_counts.get("missing", 0) or 0,
+            "coverage_percent": percent(matched_total, review_total),
+            "in_scope_reviewed_routes": scoped_review_total,
+            "in_scope_matched_or_accepted": scoped_matched_total,
+            "in_scope_coverage_percent": percent(scoped_matched_total, scoped_review_total),
+        },
+        "match_scope_summary": scope_summary,
+        "by_source": _source_stats(db),
+    }
+
+
+@app.get("/api/qa/summary")
+def qa_summary_endpoint(db: Session = Depends(get_db)) -> dict:
+    """Return the result of ``qa_summary`` for the current database session."""
+    return qa_summary(db)
+
+
+@app.get("/api/harmonization/gtfs/inventory")
+def gtfs_harmonization_inventory_endpoint(db: Session = Depends(get_db)) -> dict:
+    """Return the result of ``gtfs_harmonization_inventory`` for the current database session."""
+    return gtfs_harmonization_inventory(db)
+
+
+@app.get("/api/harmonization/gtfs/sources/{source_id}")
+def gtfs_harmonization_feed_detail_endpoint(source_id: int, db: Session = Depends(get_db)) -> dict:
+    """Return the harmonization detail of one GTFS source; 404 when the helper returns ``None``."""
+    feed_detail = gtfs_harmonization_feed_detail(db, source_id)
+    if feed_detail is not None:
+        return feed_detail
+    raise HTTPException(status_code=404, detail="GTFS source not found")
+
+
+@app.patch("/api/harmonization/gtfs/sources/{source_id}/review")
+@write_endpoint("update GTFS feed QA review")
+def update_gtfs_harmonization_feed_review(source_id: int, payload: GtfsFeedReviewUpdate, db: Session = Depends(get_db)) -> dict:
+    """Update license, enabled flag and QA review note of a GTFS source, then return its feed detail.
+
+    Responds with 404 when the source is missing or is not a GTFS source, and
+    with 400 for an unknown ``review_status``.
+    """
+    source = db.get(Source, source_id)
+    if source is None or source.kind != "gtfs":
+        raise HTTPException(status_code=404, detail="GTFS source not found")
+    allowed_statuses = {"unreviewed", "approved", "needs_review", "blocked", "rejected"}
+    if payload.review_status is not None and payload.review_status not in allowed_statuses:
+        raise HTTPException(status_code=400, detail="review_status must be unreviewed, approved, needs_review, blocked, or rejected")
+    if payload.license is not None:
+        source.license = _truncate(payload.license, 255)
+    if payload.enabled is not None:
+        source.enabled = bool(payload.enabled)
+    review_touched = payload.review_status is not None or payload.review_note is not None
+    if review_touched:
+        # NOTE(review): when only one of status/note is sent, the other is
+        # passed as its default ("unreviewed" / ""). Unless
+        # _upsert_gtfs_qa_note merges with the stored value, a note-only update
+        # resets the review status — confirm in that helper.
+        source.notes = _upsert_gtfs_qa_note(
+            source.notes,
+            status=payload.review_status or "unreviewed",
+            note=payload.review_note or "",
+        )
+    db.commit()
+    feed_detail = gtfs_harmonization_feed_detail(db, source_id)
+    if feed_detail is None:
+        raise HTTPException(status_code=404, detail="GTFS source not found")
+    return feed_detail
+
+
+@app.get("/api/matches")
+def list_matches(
+    status: Optional[str] = None,
+    source_id: Optional[str] = None,
+    limit: int = 250,
+    db: Session = Depends(get_db),
+) -> list[dict]:
+    # Lists route matches whose GTFS route belongs to an active GTFS dataset,
+    # optionally narrowed by comma-separated source ids and an exact status.
+    # Rows are ordered by confidence (highest first), then by match id, and
+    # rendered with match_row. Returns [] when no active GTFS dataset matches.
+    source_ids = _csv_ints(source_id, "source_id")
+    gtfs_dataset_ids = _active_dataset_ids(db, source_ids=source_ids, dataset_kinds=["gtfs"])
+    if not gtfs_dataset_ids:
+        return []
+    stmt = select(RouteMatch).options(joinedload(RouteMatch.gtfs_route), joinedload(RouteMatch.osm_feature)).order_by(
+        RouteMatch.confidence.desc(), RouteMatch.id
+    ).join(RouteMatch.gtfs_route).where(GtfsRoute.dataset_id.in_(gtfs_dataset_ids))
+    if status:
+        stmt = stmt.where(RouteMatch.status == status)
+    # Clamp to 0..1000. Without the lower bound a negative `limit` reached the
+    # database unchanged: SQLite treats a negative LIMIT as "no upper bound"
+    # (bypassing the 1000-row cap) and PostgreSQL rejects it with an error.
+    row_limit = max(0, min(limit, 1000))
+    matches = db.scalars(stmt.limit(row_limit)).all()
+    return [match_row(m) for m in matches]
+
+
+@app.post("/api/matches/{match_id}/accept")
+@write_endpoint("accept route match")
+def accept_match(match_id: int, db: Session = Depends(get_db)) -> dict:
+    # Flags the match as manually accepted, records an "accept_match" rule via
+    # _persist_match_rule, commits, and returns the match id with its status.
+    # Raises 404 when the match id is unknown.
+    route_match = db.get(RouteMatch, match_id)
+    if route_match is None:
+        raise HTTPException(status_code=404, detail="match not found")
+
+    route_match.status = "accepted"
+    route_match.rule_source = "manual"
+    route_match.updated_at = datetime.now(timezone.utc)
+    _persist_match_rule(db, route_match, "accept_match")
+    db.commit()
+
+    response = {"id": route_match.id, "status": route_match.status}
+    return response
+
+
+@app.post("/api/matches/{match_id}/reject")
+@write_endpoint("reject route match")
+def reject_match(match_id: int, db: Session = Depends(get_db)) -> dict:
+    # Flags the match as manually rejected, records a "reject_match" rule via
+    # _persist_match_rule, commits, and returns the match id with its status.
+    # Raises 404 when the match id is unknown.
+    route_match = db.get(RouteMatch, match_id)
+    if route_match is None:
+        raise HTTPException(status_code=404, detail="match not found")
+
+    route_match.status = "rejected"
+    route_match.rule_source = "manual"
+    route_match.updated_at = datetime.now(timezone.utc)
+    _persist_match_rule(db, route_match, "reject_match")
+    db.commit()
+
+    response = {"id": route_match.id, "status": route_match.status}
+    return response
+
+
+@app.get("/api/matches/{match_id}/candidates")
+def match_candidates(match_id: int, limit: int = 25, db: Session = Depends(get_db)) -> dict:
+    # Scores OSM route candidates for the GTFS route behind a match and returns
+    # the best ones together with a preview payload.
+    # `limit` is clamped to 1..100. Raises 404 when the match or its GTFS route
+    # cannot be loaded.
+    match = db.get(RouteMatch, match_id)
+    if match is None:
+        raise HTTPException(status_code=404, detail="match not found")
+    route = db.get(GtfsRoute, match.gtfs_route_id)
+    if route is None:
+        raise HTTPException(status_code=404, detail="GTFS route for match not found")
+
+    osm_dataset_ids = _active_dataset_ids(db, dataset_kinds=["osm_geojson"])
+    # Without an active OSM dataset there is nothing to score. Falling through
+    # with an empty feature list keeps the response shape identical in both
+    # cases; the former early return omitted current_status/current_confidence.
+    features = candidate_osm_routes_for_route(db, route, osm_dataset_ids) if osm_dataset_ids else []
+
+    candidate_rows = []
+    for feature in features:
+        score, reasons = score_route_pair(route, feature)
+        candidate_rows.append((float(score), feature.ref or "", feature.name or "", feature, reasons))
+    # Descending on (score, ref, name); ref and name only act as tie-breakers.
+    candidate_rows.sort(key=lambda row: (row[0], row[1], row[2]), reverse=True)
+    selected_rows = candidate_rows[: max(1, min(limit, 100))]
+
+    candidates = [
+        {
+            "score": round(score, 2),
+            "status": _status_from_candidate_score(score),
+            "osm": _osm_route_summary(feature),
+            "reasons": reasons,
+            "current_match": feature.id == match.osm_feature_id,
+        }
+        for score, _, _, feature, reasons in selected_rows
+    ]
+    return {
+        "match_id": match.id,
+        "current_status": match.status,
+        "current_confidence": match.confidence,
+        "route": _gtfs_route_summary(route),
+        "candidates": candidates,
+        "preview": _match_candidate_preview(
+            route,
+            match,
+            [(feature, score, reasons) for score, _, _, feature, reasons in selected_rows],
+        ),
+    }
+
+
+@app.post("/api/matches/{match_id}/candidates/{osm_feature_id}/accept")
+@write_endpoint("accept route match candidate")
+def accept_match_candidate(match_id: int, osm_feature_id: str, db: Session = Depends(get_db)) -> dict:
+    # Re-points a match at the chosen OSM route feature, re-scores the pair,
+    # marks the match as manually accepted, persists an "accept_match" rule,
+    # and returns the updated match.
+    # Raises 404 when the match, its GTFS route, or an OSM feature of kind
+    # "route" cannot be resolved.
+    route_match = db.get(RouteMatch, match_id)
+    if route_match is None:
+        raise HTTPException(status_code=404, detail="match not found")
+
+    gtfs_route = db.get(GtfsRoute, route_match.gtfs_route_id)
+    candidate = resolve_osm_feature(db, osm_feature_id)
+    if gtfs_route is None:
+        raise HTTPException(status_code=404, detail="GTFS route for match not found")
+    if candidate is None or candidate.kind != "route":
+        raise HTTPException(status_code=404, detail="OSM route candidate not found")
+
+    # From here on, work with the feature object returned by ensure_main_osm_feature.
+    candidate = ensure_main_osm_feature(db, candidate)
+    pair_score, pair_reasons = score_route_pair(gtfs_route, candidate)
+
+    route_match.osm_feature = candidate
+    route_match.osm_feature_id = candidate.id
+    route_match.confidence = float(pair_score)
+    route_match.status = "accepted"
+    route_match.rule_source = "manual"
+    route_match.reasons_json = json.dumps(pair_reasons, separators=(",", ":"))
+    route_match.updated_at = datetime.now(timezone.utc)
+    _persist_match_rule(db, route_match, "accept_match")
+    db.commit()
+    db.refresh(route_match)
+
+    return {
+        "id": route_match.id,
+        "status": route_match.status,
+        "confidence": route_match.confidence,
+        "match": match_row(route_match),
+    }
+
+
+@app.post("/api/rules")
+@write_endpoint("create match rule")
+def create_rule(payload: RuleCreate, db: Session = Depends(get_db)) -> dict:
+    # Stores a new MatchRule; selector and action are serialized as compact
+    # JSON. Returns the new rule's id and rule_type.
+    compact = (",", ":")
+    selector_json = json.dumps(payload.selector, separators=compact)
+    action_json = json.dumps(payload.action, separators=compact)
+    new_rule = MatchRule(
+        rule_type=payload.rule_type,
+        selector_json=selector_json,
+        action_json=action_json,
+        note=payload.note,
+    )
+    db.add(new_rule)
+    db.commit()
+    db.refresh(new_rule)
+    return {"id": new_rule.id, "rule_type": new_rule.rule_type}
+
+
+@app.get("/api/rules")
+def list_rules(db: Session = Depends(get_db)) -> list[dict]:
+    # Returns every MatchRule, newest id first, with selector/action decoded
+    # from their stored JSON text.
+    newest_first = select(MatchRule).order_by(MatchRule.id.desc())
+    listing = []
+    for rule in db.scalars(newest_first).all():
+        created = rule.created_at.isoformat() if rule.created_at else None
+        listing.append(
+            {
+                "id": rule.id,
+                "rule_type": rule.rule_type,
+                "selector": json.loads(rule.selector_json),
+                "action": json.loads(rule.action_json),
+                "note": rule.note,
+                "active": rule.active,
+                "created_at": created,
+            }
+        )
+    return listing
+
+
+@app.get("/api/canonical-stops/{canonical_stop_id}")
+def canonical_stop_detail(canonical_stop_id: int, db: Session = Depends(get_db)) -> dict:
+    # Loads one canonical stop and returns its detail payload; 404 if unknown.
+    canonical_stop = db.get(CanonicalStop, canonical_stop_id)
+    if canonical_stop is not None:
+        return _canonical_stop_detail(db, canonical_stop)
+    raise HTTPException(status_code=404, detail="canonical stop not found")
+
+
+@app.get("/api/canonical-stops/{canonical_stop_id}/gtfs-candidates")
+def canonical_stop_gtfs_candidates(
+    canonical_stop_id: int,
+    q: Optional[str] = None,
+    source_id: Optional[str] = None,
+    limit: int = 30,
+    db: Session = Depends(get_db),
+) -> dict:
+    # Returns the canonical stop summary plus GTFS stops that could be linked
+    # to it. `source_id` is a comma-separated list parsed by _csv_ints.
+    # Raises 404 when the canonical stop does not exist.
+    canonical_stop = db.get(CanonicalStop, canonical_stop_id)
+    if canonical_stop is None:
+        raise HTTPException(status_code=404, detail="canonical stop not found")
+
+    wanted_sources = _csv_ints(source_id, "source_id")
+    found = _gtfs_stop_candidates_for_canonical_stop(db, canonical_stop, q=q, source_ids=wanted_sources, limit=limit)
+    return {
+        "canonical_stop": _canonical_stop_summary(canonical_stop),
+        "candidates": found,
+    }
+
+
+@app.post("/api/canonical-stops/{canonical_stop_id}/link-gtfs-stop")
+@write_endpoint("link GTFS stop to canonical stop")
+def link_gtfs_stop_to_canonical_stop(
+    canonical_stop_id: int,
+    payload: CanonicalStopGtfsLinkRequest,
+    db: Session = Depends(get_db),
+) -> dict:
+    # Manually attaches a GTFS stop to a canonical stop. An existing link for
+    # the same (dataset, stop) is re-pointed, otherwise a new one is inserted.
+    # A "link_canonical_stop" rule is persisted and the refreshed canonical
+    # stop detail is returned.
+    # Raises 404 when the canonical stop is unknown; stop resolution is
+    # delegated to _resolve_gtfs_stop_link_payload.
+    canonical = db.get(CanonicalStop, canonical_stop_id)
+    if canonical is None:
+        raise HTTPException(status_code=404, detail="canonical stop not found")
+    stop = _resolve_gtfs_stop_link_payload(db, payload)
+    link = db.scalar(
+        select(CanonicalStopLink).where(
+            CanonicalStopLink.object_type == "gtfs_stop",
+            CanonicalStopLink.dataset_id == stop.dataset_id,
+            CanonicalStopLink.object_id == stop.id,
+        )
+    )
+    # A stop without parent_station is linked as "parent", any other as "platform".
+    role = "parent" if stop.parent_station is None else "platform"
+    manual_metadata = json.dumps({"manual_rule": "link_canonical_stop"}, separators=(",", ":"))
+    if link is None:
+        link = CanonicalStopLink(
+            canonical_stop_id=canonical.id,
+            layer="timetable",
+            object_type="gtfs_stop",
+            dataset_id=stop.dataset_id,
+            object_id=stop.id,
+            external_id=stop.stop_id,
+            role=role,
+            confidence=1.0,
+            distance_m=None,
+            metadata_json=manual_metadata,
+        )
+        db.add(link)
+    else:
+        link.canonical_stop_id = canonical.id
+        link.role = role
+        link.confidence = 1.0
+        # Manual links are stored with distance_m=None (see the insert branch).
+        # Reset it here too, so a re-pointed link does not keep the value that
+        # was recorded while it belonged to its previous canonical stop.
+        link.distance_m = None
+        link.metadata_json = manual_metadata
+    _persist_canonical_stop_rule(db, "link_canonical_stop", stop, canonical, payload.note)
+    db.commit()
+    db.refresh(canonical)
+    return _canonical_stop_detail(db, canonical)
+
+
+@app.post("/api/canonical-stop-links/{link_id}/unlink")
+@write_endpoint("unlink GTFS stop from canonical stop")
+def unlink_gtfs_stop_from_canonical_stop(link_id: int, db: Session = Depends(get_db)) -> dict:
+    # Detaches a GTFS stop from its canonical stop by moving the link onto the
+    # canonical stop returned by _manual_standalone_canonical_stop for it.
+    # An "unlink_canonical_stop" rule is persisted alongside.
+    # Raises 404 for an unknown link or missing GTFS stop, and 400 when the
+    # link's object_type is not "gtfs_stop".
+    stop_link = db.get(CanonicalStopLink, link_id)
+    if stop_link is None:
+        raise HTTPException(status_code=404, detail="canonical stop link not found")
+    if stop_link.object_type != "gtfs_stop":
+        raise HTTPException(status_code=400, detail="only GTFS stop links can be unlinked here")
+
+    gtfs_stop = db.get(GtfsStop, stop_link.object_id)
+    if gtfs_stop is None:
+        raise HTTPException(status_code=404, detail="linked GTFS stop not found")
+
+    target = _manual_standalone_canonical_stop(db, gtfs_stop)
+    unlink_metadata = {"manual_rule": "unlink_canonical_stop"}
+    stop_link.canonical_stop_id = target.id
+    stop_link.confidence = 1.0
+    stop_link.metadata_json = json.dumps(unlink_metadata, separators=(",", ":"))
+    _persist_canonical_stop_rule(db, "unlink_canonical_stop", gtfs_stop, target, None)
+    db.commit()
+    db.refresh(target)
+
+    return {
+        "canonical_stop": _canonical_stop_summary(target),
+        "moved_link_id": stop_link.id,
+    }
+
+
+@app.get("/api/map/osm_routes.geojson")
+def map_osm_routes(
+    mode: Optional[str] = None,
+    route_scope: Optional[str] = None,
+    source_id: Optional[str] = None,
+    dataset_id: Optional[str] = None,
+    bbox: Optional[str] = None,
+    zoom: Optional[int] = None,
+    limit: int = MAP_FEATURE_LIMIT,
+    db: Session = Depends(get_db),
+) -> JSONResponse:
+    # OSM features of kind "route"; all filtering is delegated to
+    # _osm_features_response. `zoom` is accepted but not forwarded.
+    response = _osm_features_response(
+        db=db,
+        kind="route",
+        mode=mode,
+        route_scope=route_scope,
+        source_id=source_id,
+        dataset_id=dataset_id,
+        bbox=bbox,
+        limit=limit,
+    )
+    return response
+
+
+@app.get("/api/map/osm_stops.geojson")
+def map_osm_stops(
+    geometry: Optional[str] = None,
+    source_id: Optional[str] = None,
+    dataset_id: Optional[str] = None,
+    bbox: Optional[str] = None,
+    zoom: Optional[int] = None,
+    limit: int = MAP_FEATURE_LIMIT,
+    db: Session = Depends(get_db),
+) -> JSONResponse:
+    # OSM features of kind stop, station or terminal; all filtering is
+    # delegated to _osm_features_response. `zoom` is accepted but not forwarded.
+    filters = {
+        "geometry": geometry,
+        "source_id": source_id,
+        "dataset_id": dataset_id,
+        "bbox": bbox,
+        "limit": limit,
+    }
+    return _osm_features_response(db=db, kind="stop,station,terminal", **filters)
+
+
+@app.get("/api/map/osm_features.geojson")
+def map_osm_features(
+    kind: Optional[str] = None,
+    mode: Optional[str] = None,
+    route_scope: Optional[str] = None,
+    geometry: Optional[str] = None,
+    source_id: Optional[str] = None,
+    dataset_id: Optional[str] = None,
+    bbox: Optional[str] = None,
+    zoom: Optional[int] = None,
+    limit: int = MAP_FEATURE_LIMIT,
+    db: Session = Depends(get_db),
+) -> JSONResponse:
+    # Generic OSM feature layer: every filter except `zoom` is forwarded
+    # unchanged to _osm_features_response.
+    filters = {
+        "kind": kind,
+        "mode": mode,
+        "route_scope": route_scope,
+        "geometry": geometry,
+        "source_id": source_id,
+        "dataset_id": dataset_id,
+        "bbox": bbox,
+        "limit": limit,
+    }
+    return _osm_features_response(db=db, **filters)
+
+
+@app.get("/api/map/gtfs_routes.geojson")
+def map_gtfs_routes(
+    mode: Optional[str] = None,
+    route_scope: Optional[str] = None,
+    source_id: Optional[str] = None,
+    dataset_id: Optional[str] = None,
+    bbox: Optional[str] = None,
+    zoom: Optional[int] = None,
+    limit: int = MAP_FEATURE_LIMIT,
+    db: Session = Depends(get_db),
+) -> JSONResponse:
+    # GTFS routes from the active GTFS datasets as a feature collection,
+    # optionally filtered by mode, route scope and bounding box. Routes for
+    # which gtfs_route_feature returns None are left out. `zoom` is unused.
+    requested_sources = _csv_ints(source_id, "source_id")
+    requested_datasets = _csv_ints(dataset_id, "dataset_id")
+    active_dataset_ids = _active_dataset_ids(
+        db,
+        source_ids=requested_sources,
+        dataset_ids=requested_datasets,
+        dataset_kinds=["gtfs"],
+    )
+    if not active_dataset_ids:
+        return JSONResponse(feature_collection([]))
+
+    query = select(GtfsRoute).where(GtfsRoute.dataset_id.in_(active_dataset_ids))
+    wanted_modes = _csv_values(mode)
+    if wanted_modes:
+        query = query.where(GtfsRoute.mode.in_(wanted_modes))
+    wanted_scopes = _csv_values(route_scope)
+    if wanted_scopes:
+        query = query.where(_gtfs_route_scope_condition(wanted_scopes))
+    window = _parse_bbox(bbox)
+    if window:
+        query = _where_bbox_overlaps(query, GtfsRoute, window)
+
+    query = query.order_by(GtfsRoute.mode, GtfsRoute.short_name, GtfsRoute.id).limit(_clamp_limit(limit))
+    routes = db.scalars(query).all()
+    rendered = (gtfs_route_feature(route) for route in routes)
+    return JSONResponse(feature_collection(feature for feature in rendered if feature is not None))
+
+
+@app.get("/api/map/route_patterns.geojson")
+def map_route_patterns(
+    mode: Optional[str] = None,
+    route_scope: Optional[str] = None,
+    source_kind: Optional[str] = None,
+    status: Optional[str] = None,
+    bbox: Optional[str] = None,
+    zoom: Optional[int] = None,
+    limit: int = MAP_FEATURE_LIMIT,
+    db: Session = Depends(get_db),
+) -> JSONResponse:
+    # Route patterns as a feature collection. Each comma-separated filter is
+    # applied only when it yields at least one value. Patterns for which
+    # route_pattern_feature returns None are left out. `zoom` is unused.
+    query = select(RoutePattern)
+    wanted_modes = _csv_values(mode)
+    if wanted_modes:
+        query = query.where(RoutePattern.mode.in_(wanted_modes))
+    wanted_scopes = _csv_values(route_scope)
+    if wanted_scopes:
+        query = query.where(_route_pattern_scope_condition(wanted_scopes))
+    wanted_source_kinds = _csv_values(source_kind)
+    if wanted_source_kinds:
+        query = query.where(RoutePattern.source_kind.in_(wanted_source_kinds))
+    wanted_statuses = _csv_values(status)
+    if wanted_statuses:
+        query = query.where(RoutePattern.status.in_(wanted_statuses))
+    window = _parse_bbox(bbox)
+    if window:
+        query = _where_bbox_overlaps(query, RoutePattern, window)
+
+    ordered = query.order_by(RoutePattern.mode, RoutePattern.route_ref, RoutePattern.source_kind, RoutePattern.id)
+    patterns = db.scalars(ordered.limit(_clamp_limit(limit))).all()
+    rendered = (route_pattern_feature(pattern) for pattern in patterns)
+    return JSONResponse(feature_collection(feature for feature in rendered if feature is not None))
+
+
+@app.get("/api/map/gtfs_stops.geojson")
+def map_gtfs_stops(
+    source_id: Optional[str] = None,
+    dataset_id: Optional[str] = None,
+    bbox: Optional[str] = None,
+    zoom: Optional[int] = None,
+    limit: int = MAP_FEATURE_LIMIT,
+    db: Session = Depends(get_db),
+) -> JSONResponse:
+    # GTFS stops from the active GTFS datasets as a feature collection, ordered
+    # by name then id and optionally restricted to a bounding box. Stops for
+    # which gtfs_stop_feature returns None are left out. `zoom` is unused.
+    requested_sources = _csv_ints(source_id, "source_id")
+    requested_datasets = _csv_ints(dataset_id, "dataset_id")
+    active_dataset_ids = _active_dataset_ids(
+        db,
+        source_ids=requested_sources,
+        dataset_ids=requested_datasets,
+        dataset_kinds=["gtfs"],
+    )
+    if not active_dataset_ids:
+        return JSONResponse(feature_collection([]))
+
+    query = select(GtfsStop).where(GtfsStop.dataset_id.in_(active_dataset_ids))
+    window = _parse_bbox(bbox)
+    if window:
+        query = _where_point_bbox(query, GtfsStop, window)
+
+    query = query.order_by(GtfsStop.name, GtfsStop.id).limit(_clamp_limit(limit))
+    stops = db.scalars(query).all()
+    rendered = (gtfs_stop_feature(stop) for stop in stops)
+    return JSONResponse(feature_collection(feature for feature in rendered if feature is not None))
+
+
+@app.get("/api/map/matched_gtfs_routes.geojson")
+def map_matched_gtfs_routes(
+    status: Optional[str] = None,
+    source_id: Optional[str] = None,
+    dataset_id: Optional[str] = None,
+    bbox: Optional[str] = None,
+    zoom: Optional[int] = None,
+    limit: int = MAP_FEATURE_LIMIT,
+    db: Session = Depends(get_db),
+) -> JSONResponse:
+    # Route matches as map features. A match is drawn with its OSM feature when
+    # one is attached and with the GTFS route otherwise; the match properties
+    # are passed to the feature builder in both cases. `zoom` is unused.
+    query = select(RouteMatch).options(joinedload(RouteMatch.gtfs_route), joinedload(RouteMatch.osm_feature))
+    if status:
+        query = query.where(RouteMatch.status == status)
+
+    requested_sources = _csv_ints(source_id, "source_id")
+    requested_datasets = _csv_ints(dataset_id, "dataset_id")
+    gtfs_dataset_ids = _active_dataset_ids(
+        db,
+        source_ids=requested_sources,
+        dataset_ids=requested_datasets,
+        dataset_kinds=["gtfs"],
+    )
+    if not gtfs_dataset_ids:
+        return JSONResponse(feature_collection([]))
+
+    query = query.join(RouteMatch.gtfs_route).where(GtfsRoute.dataset_id.in_(gtfs_dataset_ids))
+    window = _parse_bbox(bbox)
+    if window:
+        query = _where_bbox_overlaps(query, GtfsRoute, window)
+    query = query.order_by(RouteMatch.confidence.desc(), RouteMatch.id).limit(_clamp_limit(limit))
+
+    drawn = []
+    for route_match in db.scalars(query).all():
+        gtfs_route = route_match.gtfs_route
+        osm_feature = route_match.osm_feature
+        properties = {
+            "match_id": route_match.id,
+            "match_status": route_match.status,
+            "confidence": route_match.confidence,
+            "gtfs_route_id": gtfs_route.route_id,
+            "gtfs_ref": gtfs_route.short_name,
+            "gtfs_mode": gtfs_route.mode,
+            "visual_source": "osm" if osm_feature else "gtfs",
+        }
+        rendered = osm_feature_feature(osm_feature, properties) if osm_feature else gtfs_route_feature(gtfs_route, properties)
+        if rendered:
+            drawn.append(rendered)
+    return JSONResponse(feature_collection(drawn))
+
+
+@app.get("/api/journey/stops")
+def journey_stops(
+    q: Optional[str] = None,
+    source_id: Optional[str] = None,
+    bbox: Optional[str] = None,
+    limit: int = 25,
+    db: Session = Depends(get_db),
+) -> dict:
+    # Type-ahead search over scheduled stops and, when no source filter is
+    # given, addresses as well. `limit` is clamped to 1..100.
+    # On PostgreSQL a 4000ms statement timeout is set for the transaction. A
+    # timed-out stop search returns an empty result flagged timed_out; a
+    # timed-out address search returns the stop results found so far.
+    if settings.is_postgresql_database:
+        db.execute(text("SET LOCAL statement_timeout = '4000ms'"))
+    parsed_bbox = _parse_bbox(bbox)
+    selected_limit = max(1, min(limit, 100))
+    if source_id:
+        # A source filter only applies to stops, so no slots go to addresses.
+        stop_limit, address_limit = selected_limit, 0
+    else:
+        # About 70% of the slots go to stops, the rest (at least one) to addresses.
+        stop_limit = max(1, int(selected_limit * 0.7))
+        address_limit = max(1, selected_limit - stop_limit)
+
+    try:
+        results = search_scheduled_stops(
+            db=db,
+            query=q,
+            source_ids=_csv_ints(source_id, "source_id"),
+            bbox=parsed_bbox,
+            limit=stop_limit,
+        )
+    except OperationalError as exc:
+        if not _is_statement_timeout_error(exc):
+            raise
+        db.rollback()
+        return {"stops": [], "timed_out": True, "message": "Search timed out; keep typing or try a more specific query."}
+
+    if address_limit:
+        try:
+            address_hits = search_addresses(db=db, query=q, bbox=parsed_bbox, limit=address_limit)
+            results.extend(address_hits)
+        except OperationalError as exc:
+            if not _is_statement_timeout_error(exc):
+                raise
+            db.rollback()
+            return {
+                "stops": results[:selected_limit],
+                "timed_out": True,
+                "message": "Address search timed out; showing stop results.",
+            }
+    return {"stops": results[:selected_limit]}
+
+
+@app.get("/api/journey/nearest-location")
+def journey_nearest_location(
+    lat: float,
+    lon: float,
+    source_id: Optional[str] = None,
+    limit: int = 4,
+    stop_radius_m: float = 35,
+    db: Session = Depends(get_db),
+) -> dict:
+    # Resolves a clicked coordinate to, in order of preference: the nearest
+    # scheduled stop within the radius (clamped to 5..120 m), the address at
+    # the point, or the bare coordinate. Address lookup is skipped while
+    # active_address_index_rebuild_job reports a job.
+    # `limit` is accepted but unused: the stop lookup always asks for one stop.
+    # Raises 400 when lat/lon are out of range or not a number.
+
+    # NaN fails every comparison, so the previous `lat < -90 or lat > 90 ...`
+    # test let it through; the chained form rejects NaN as well.
+    if not (-90 <= lat <= 90 and -180 <= lon <= 180):
+        raise HTTPException(status_code=400, detail="lat/lon out of range")
+
+    def _selection(selection_kind, location, **extra):
+        # Common response envelope shared by every outcome below.
+        return {
+            "lat": lat,
+            "lon": lon,
+            "location": location,
+            "locations": [] if location is None else [location],
+            "selection_kind": selection_kind,
+            **extra,
+        }
+
+    stops = nearest_scheduled_stops(
+        db=db,
+        lat=lat,
+        lon=lon,
+        source_ids=_csv_ints(source_id, "source_id"),
+        limit=1,
+        radius_m=max(5, min(float(stop_radius_m), 120)),
+    )
+    location = stops[0] if stops else None
+    if location is not None:
+        return _selection("stop", location)
+
+    address_job = active_address_index_rebuild_job(db)
+    if address_job is not None:
+        return _selection(
+            "coordinate",
+            None,
+            address_lookup_skipped=True,
+            message=f"Address index rebuild job #{address_job.id} is {address_job.status}; using coordinates.",
+        )
+
+    address = address_at_point(db=db, lat=lat, lon=lon)
+    if address is not None:
+        return _selection("address", address)
+    return _selection("coordinate", None)
+
+
+@app.get("/api/addresses/status")
+def addresses_status(db: Session = Depends(get_db)) -> dict:
+    # HTTP wrapper: the response body is exactly what address_index_status(db) returns.
+    index_status = address_index_status(db)
+    return index_status
+
+
+@app.get("/api/addresses/search")
+def addresses_search(
+    q: Optional[str] = None,
+    bbox: Optional[str] = None,
+    limit: int = 25,
+    db: Session = Depends(get_db),
+) -> dict:
+    # Address search wrapped in an {"addresses": [...]} envelope.
+    # NOTE(review): `limit` is forwarded unclamped -- confirm search_addresses
+    # bounds it itself.
+    window = _parse_bbox(bbox)
+    hits = search_addresses(db=db, query=q, bbox=window, limit=limit)
+    return {"addresses": hits}
+
+
+@app.get("/api/journey/search")
+def journey_search(
+    from_stop_id: str,
+    to_stop_id: str,
+    via_stop_id: Optional[str] = None,
+    source_id: Optional[str] = None,
+    departure: str = "08:00",
+    service_date: Optional[str] = None,
+    max_transfers: int = 0,
+    transfer_seconds: int = 120,
+    limit: int = 5,
+    db: Session = Depends(get_db),
+) -> dict:
+    # Synchronous journey search delegated to find_journeys.
+    # max_transfers is clamped to 0..5 and transfer_seconds to 0..3600; a
+    # ValueError raised during the search is reported as HTTP 400.
+    # NOTE(review): `limit` is forwarded unclamped, unlike the other two
+    # numeric inputs -- confirm find_journeys bounds it.
+    try:
+        search_args = {
+            "db": db,
+            "from_stop_id": from_stop_id,
+            "to_stop_id": to_stop_id,
+            "departure": departure,
+            "max_transfers": max(0, min(max_transfers, 5)),
+            "transfer_seconds": max(0, min(transfer_seconds, 3600)),
+            "limit": limit,
+            "source_ids": _csv_ints(source_id, "source_id"),
+            "via_stop_id": via_stop_id,
+            "service_date": service_date,
+        }
+        return find_journeys(**search_args)
+    except ValueError as exc:
+        raise HTTPException(status_code=400, detail=str(exc)) from exc
+
+
+@app.post("/api/journey/searches")
+def start_progressive_journey_search(payload: JourneySearchRequest) -> dict:
+    # Starts a progressive journey search and returns what start_journey_search
+    # returns. Unknown mode/ranking values fall back to "transit" and
+    # "recommended"; transfer_seconds is clamped to 0..3600 and limit to 1..10.
+    # A ValueError from the search layer is reported as HTTP 400.
+    known_modes = {"transit", "walk", "drive", "car"}
+    known_rankings = {"recommended", "earliest_arrival", "duration", "fewest_transfers"}
+    mode = payload.mode if payload.mode in known_modes else "transit"
+    ranking = payload.ranking if payload.ranking in known_rankings else "recommended"
+    try:
+        request = {
+            "from_stop_id": payload.from_stop_id,
+            "to_stop_id": payload.to_stop_id,
+            "via_stop_id": payload.via_stop_id,
+            "source_id": payload.source_id,
+            "departure": payload.departure,
+            "service_date": payload.service_date,
+            "mode": mode,
+            "direct_only": bool(payload.direct_only),
+            "ranking": ranking,
+            "transfer_seconds": max(0, min(payload.transfer_seconds, 3600)),
+            "limit": max(1, min(payload.limit, 10)),
+        }
+        return start_journey_search(request)
+    except ValueError as exc:
+        raise HTTPException(status_code=400, detail=str(exc)) from exc
+
+
+@app.get("/api/journey/searches/{search_id}")
+def progressive_journey_search_status(search_id: str) -> dict:
+    # Returns the payload for a progressive search; a KeyError from the lookup
+    # is reported as HTTP 404.
+    try:
+        status_payload = journey_search_payload(search_id)
+    except KeyError as exc:
+        raise HTTPException(status_code=404, detail="journey search not found") from exc
+    return status_payload
+
+
+@app.delete("/api/journey/searches/{search_id}")
+def cancel_progressive_journey_search(search_id: str) -> dict:
+    # Cancels a progressive search; a KeyError from the lookup is reported as
+    # HTTP 404.
+    try:
+        cancelled = cancel_journey_search(search_id)
+    except KeyError as exc:
+        raise HTTPException(status_code=404, detail="journey search not found") from exc
+    return cancelled
+
+
+@app.post("/api/itineraries/generate")
+@write_endpoint("generate itineraries")
+def generate_itineraries_endpoint(payload: ItineraryGenerateRequest, db: Session = Depends(get_db)) -> dict:
+    # Generates itineraries for the request and commits the session.
+    # max_transfers is clamped to 0..5, transfer_seconds to 0..3600 and limit
+    # to 1..10. A ValueError rolls the session back and is reported as HTTP 400.
+    try:
+        transfers_cap = max(0, min(payload.max_transfers, 5))
+        transfer_time = max(0, min(payload.transfer_seconds, 3600))
+        result_cap = max(1, min(payload.limit, 10))
+        generated = generate_itineraries(
+            db,
+            from_stop_id=payload.from_stop_id,
+            to_stop_id=payload.to_stop_id,
+            via_stop_id=payload.via_stop_id,
+            departure=payload.departure,
+            service_date=payload.service_date,
+            max_transfers=transfers_cap,
+            transfer_seconds=transfer_time,
+            limit=result_cap,
+            source_ids=_csv_ints(payload.source_id, "source_id"),
+            preferences=payload.preferences,
+        )
+        db.commit()
+        return generated
+    except ValueError as exc:
+        db.rollback()
+        raise HTTPException(status_code=400, detail=str(exc)) from exc
+
+
+@app.get("/api/itineraries")
+def list_itineraries(saved_only: bool = False, limit: int = 30, db: Session = Depends(get_db)) -> dict:
+    # Recent itineraries wrapped in an {"itineraries": [...]} envelope; both
+    # filters are forwarded unchanged to recent_itineraries.
+    recent = recent_itineraries(db, saved_only=saved_only, limit=limit)
+    return {"itineraries": recent}
+
+
+@app.post("/api/itineraries/{itinerary_id}/save")
+@write_endpoint("save itinerary")
+def save_itinerary(itinerary_id: int, payload: ItinerarySaveRequest, db: Session = Depends(get_db)) -> dict:
+    # Applies payload.saved to the itinerary via set_itinerary_saved, commits,
+    # and returns that helper's result. Raises 404 for an unknown itinerary.
+    target = db.get(Itinerary, itinerary_id)
+    if target is None:
+        raise HTTPException(status_code=404, detail="itinerary not found")
+    outcome = set_itinerary_saved(db, target, payload.saved)
+    db.commit()
+    return outcome
+
+
+@app.get("/api/routing/status")
+def routing_status_endpoint(db: Session = Depends(get_db)) -> dict:
+    # HTTP wrapper: the response body is exactly what routing_status(db) returns.
+    status_payload = routing_status(db)
+    return status_payload
+
+
+@app.get("/api/routing/route")
+def routing_route_endpoint(
+    from_lon: float,
+    from_lat: float,
+    to_lon: float,
+    to_lat: float,
+    mode: str = "walk",
+    dataset_id: Optional[int] = None,
+    max_visited: int = 160_000,
+    db: Session = Depends(get_db),
+) -> dict:
+    # Routes between two coordinates via route_between_points. max_visited is
+    # clamped to 1_000..1_000_000; a ValueError is reported as HTTP 400.
+    visit_budget = max(1_000, min(max_visited, 1_000_000))
+    try:
+        routed = route_between_points(
+            db,
+            from_lon=from_lon,
+            from_lat=from_lat,
+            to_lon=to_lon,
+            to_lat=to_lat,
+            mode=mode,
+            dataset_id=dataset_id,
+            max_visited=visit_budget,
+        )
+    except ValueError as exc:
+        raise HTTPException(status_code=400, detail=str(exc)) from exc
+    return routed
+
+
+@app.post("/api/itinerary-legs/{leg_id}/lock")
+@write_endpoint("lock itinerary leg")
+def lock_itinerary_leg(leg_id: int, payload: ItineraryLegLockRequest, db: Session = Depends(get_db)) -> dict:
+    # Applies payload.locked to the leg via set_leg_locked, commits, and
+    # returns that helper's result. Raises 404 for an unknown leg.
+    target_leg = db.get(ItineraryLeg, leg_id)
+    if target_leg is None:
+        raise HTTPException(status_code=404, detail="itinerary leg not found")
+    outcome = set_leg_locked(db, target_leg, payload.locked)
+    db.commit()
+    return outcome
+
+
+def _canonical_stop_summary(canonical: CanonicalStop) -> dict:
+    """Return the flat JSON summary of a canonical stop.
+
+    ``metadata`` is decoded from ``metadata_json`` via ``_json_object`` and
+    ``created_at`` is ISO-formatted, or None when unset.
+    """
+    created = canonical.created_at
+    summary = {
+        "id": canonical.id,
+        "stop_key": canonical.stop_key,
+        "name": canonical.name,
+        "normalized_name": canonical.normalized_name,
+        "lat": canonical.lat,
+        "lon": canonical.lon,
+        "mode": canonical.mode,
+        "metadata": _json_object(canonical.metadata_json),
+        "created_at": created.isoformat() if created else None,
+    }
+    return summary
+
+
+def _canonical_stop_detail(db: Session, canonical: CanonicalStop) -> dict:
+    """Build the detail payload for a canonical stop.
+
+    Loads every link of the stop together with its dataset and source, splits
+    the links into GTFS stops and OSM features (links of any other object type
+    are left out), and adds the rules returned by
+    ``_canonical_stop_rule_payloads``.
+    """
+    link_query = (
+        select(CanonicalStopLink, Dataset, Source)
+        .join(Dataset, Dataset.id == CanonicalStopLink.dataset_id)
+        .join(Source, Source.id == Dataset.source_id)
+        .where(CanonicalStopLink.canonical_stop_id == canonical.id)
+        .order_by(CanonicalStopLink.layer, Source.name, Dataset.id, CanonicalStopLink.role, CanonicalStopLink.external_id)
+    )
+    gtfs_links, osm_links = [], []
+    for link, dataset, source in db.execute(link_query).all():
+        linked_type = link.object_type
+        if linked_type == "gtfs_stop":
+            gtfs_stop = db.get(GtfsStop, link.object_id)
+            gtfs_links.append(_canonical_gtfs_stop_link_payload(link, gtfs_stop, dataset, source))
+        elif linked_type == "osm_feature":
+            osm_feature = db.get(OsmFeature, link.object_id)
+            osm_links.append(_canonical_osm_feature_link_payload(link, osm_feature, dataset, source))
+
+    detail = {
+        "canonical_stop": _canonical_stop_summary(canonical),
+        "gtfs_stops": gtfs_links,
+        "osm_features": osm_links,
+        "rules": _canonical_stop_rule_payloads(db, canonical, gtfs_links),
+    }
+    return detail
+
+
+def _canonical_gtfs_stop_link_payload(link: CanonicalStopLink, stop: GtfsStop | None, dataset: Dataset, source: Source) -> dict:
+    """Serialize one canonical-stop link that points at a GTFS stop.
+
+    The ``stop`` entry is None when the linked GTFS stop could not be loaded.
+    """
+    if stop is None:
+        stop_payload = None
+    else:
+        stop_payload = {
+            "id": stop.id,
+            "dataset_id": stop.dataset_id,
+            "stop_id": stop.stop_id,
+            "name": stop.name,
+            "lat": stop.lat,
+            "lon": stop.lon,
+            "parent_station": stop.parent_station,
+        }
+    payload = {
+        "link_id": link.id,
+        "layer": link.layer,
+        "object_type": link.object_type,
+        "dataset_id": link.dataset_id,
+        "source_id": source.id,
+        "source_name": source.name,
+        "source_kind": source.kind,
+        "dataset_active": dataset.is_active,
+        "external_id": link.external_id,
+        "role": link.role,
+        "confidence": link.confidence,
+        "distance_m": link.distance_m,
+        "metadata": _json_object(link.metadata_json),
+        "stop": stop_payload,
+    }
+    return payload
+
+
+def _canonical_osm_feature_link_payload(
+    link: CanonicalStopLink,
+    feature: OsmFeature | None,
+    dataset: Dataset,
+    source: Source,
+) -> dict:
+    """Serialize one canonical-stop link that points at an OSM feature.
+
+    The ``feature`` entry is None when the linked feature could not be loaded.
+    Its geometry is summarized as a presence flag plus the stored bounding box
+    rather than the full GeoJSON.
+    """
+    if feature is None:
+        feature_payload = None
+    else:
+        geometry_summary = {
+            "present": bool(feature.geometry_geojson),
+            "bbox": [feature.min_lon, feature.min_lat, feature.max_lon, feature.max_lat],
+        }
+        feature_payload = {
+            "id": feature.id,
+            "dataset_id": feature.dataset_id,
+            "osm_type": feature.osm_type,
+            "osm_id": feature.osm_id,
+            "kind": feature.kind,
+            "mode": feature.mode,
+            "name": feature.name,
+            "ref": feature.ref,
+            "operator": feature.operator,
+            "network": feature.network,
+            "geometry": geometry_summary,
+        }
+    payload = {
+        "link_id": link.id,
+        "layer": link.layer,
+        "object_type": link.object_type,
+        "dataset_id": link.dataset_id,
+        "source_id": source.id,
+        "source_name": source.name,
+        "source_kind": source.kind,
+        "dataset_active": dataset.is_active,
+        "external_id": link.external_id,
+        "role": link.role,
+        "confidence": link.confidence,
+        "distance_m": link.distance_m,
+        "metadata": _json_object(link.metadata_json),
+        "feature": feature_payload,
+    }
+    return payload
+
+
+def _canonical_stop_rule_payloads(db: Session, canonical: CanonicalStop, gtfs_links: list[dict]) -> list[dict]:
+    """Return up to 50 link/unlink rules that concern this canonical stop.
+
+    Only the 300 newest ``link_canonical_stop`` / ``unlink_canonical_stop``
+    rules are inspected. A rule is kept when its action targets this stop's
+    ``stop_key`` or when its selector names a (source_id, external_id) pair
+    that appears in ``gtfs_links``.
+    """
+    recent_rules = db.scalars(
+        select(MatchRule)
+        .where(MatchRule.rule_type.in_(["link_canonical_stop", "unlink_canonical_stop"]))
+        .order_by(MatchRule.id.desc())
+        .limit(300)
+    ).all()
+
+    linked_keys = set()
+    for entry in gtfs_links:
+        if entry.get("source_id") is not None and entry.get("external_id"):
+            linked_keys.add((int(entry["source_id"]), str(entry["external_id"])))
+
+    relevant = []
+    for rule in recent_rules:
+        selector = _json_object(rule.selector_json)
+        action = _json_object(rule.action_json)
+        # The selector may carry its ids at top level or nested under "gtfs_stop".
+        raw_source_id = selector.get("source_id") or _json_object_value(selector.get("gtfs_stop"), "source_id")
+        raw_external_id = selector.get("external_id") or _json_object_value(selector.get("gtfs_stop"), "external_id") or ""
+        selector_key = (_optional_int(raw_source_id), str(raw_external_id))
+
+        targets_this_stop = action.get("target_stop_key") == canonical.stop_key
+        if not (targets_this_stop or selector_key in linked_keys):
+            continue
+        created = rule.created_at
+        relevant.append(
+            {
+                "id": rule.id,
+                "rule_type": rule.rule_type,
+                "selector": selector,
+                "action": action,
+                "note": rule.note,
+                "active": rule.active,
+                "created_at": created.isoformat() if created else None,
+            }
+        )
+    return relevant[:50]
+
+
+def _gtfs_stop_candidates_for_canonical_stop(
+    db: Session,
+    canonical: CanonicalStop,
+    *,
+    q: str | None,
+    source_ids: list[int] | None,
+    limit: int,
+) -> list[dict]:
+    """Find GTFS stops from active GTFS datasets that could be linked to ``canonical``.
+
+    With a non-blank ``q`` the stops are matched by text: the whole query as a
+    substring of name or stop_id, or every token matching one of the two.
+    Without a query the search is a +/-0.015 degree box around the canonical
+    stop's coordinates; with neither a query nor coordinates the result is
+    empty.
+
+    Up to four times the clamped limit (1..100) rows are fetched. They are then
+    ordered so that stops already linked to this canonical stop come first,
+    then stops reported by ``scheduled_stop_ids``, then by distance, and cut to
+    the clamped limit.
+    """
+    active_dataset_ids = _active_dataset_ids(db, source_ids=source_ids, dataset_kinds=["gtfs"])
+    if not active_dataset_ids:
+        return []
+
+    base_query = (
+        select(GtfsStop, Dataset, Source)
+        .join(Dataset, Dataset.id == GtfsStop.dataset_id)
+        .join(Source, Source.id == Dataset.source_id)
+        .where(GtfsStop.dataset_id.in_(active_dataset_ids))
+    )
+    search_text = (q or "").strip()
+    if search_text:
+        whole = f"%{search_text}%"
+        alternatives = [GtfsStop.name.ilike(whole), GtfsStop.stop_id.ilike(whole)]
+        # Commas and slashes separate tokens just like whitespace does.
+        words = search_text.replace(",", " ").replace("/", " ").split()
+        per_word = [or_(GtfsStop.name.ilike(f"%{word}%"), GtfsStop.stop_id.ilike(f"%{word}%")) for word in words]
+        if per_word:
+            alternatives.append(and_(*per_word))
+        base_query = base_query.where(or_(*alternatives))
+    elif canonical.lon is not None and canonical.lat is not None:
+        radius = 0.015
+        base_query = base_query.where(
+            GtfsStop.lon >= canonical.lon - radius,
+            GtfsStop.lon <= canonical.lon + radius,
+            GtfsStop.lat >= canonical.lat - radius,
+            GtfsStop.lat <= canonical.lat + radius,
+        )
+    else:
+        return []
+
+    capped_limit = max(1, min(limit, 100))
+    ordered_query = base_query.order_by(Source.name, GtfsStop.name, GtfsStop.stop_id).limit(capped_limit * 4)
+    rows = db.execute(ordered_query).all()
+    if not rows:
+        return []
+
+    # Existing canonical links of the fetched stops, keyed by GtfsStop.id.
+    fetched_ids = [stop.id for stop, _, _ in rows]
+    existing_links = db.scalars(
+        select(CanonicalStopLink).where(
+            CanonicalStopLink.object_type == "gtfs_stop",
+            CanonicalStopLink.object_id.in_(fetched_ids),
+        )
+    ).all()
+    link_by_stop_id = {existing.object_id: existing for existing in existing_links}
+
+    # One scheduled_stop_ids lookup per dataset; keep only ids that were asked for.
+    wanted_by_dataset: dict[int, set[str]] = {}
+    for stop, _, _ in rows:
+        wanted_by_dataset.setdefault(stop.dataset_id, set()).add(stop.stop_id)
+    scheduled_keys = set()
+    for ds_id, wanted in wanted_by_dataset.items():
+        for scheduled_id in set(scheduled_stop_ids(db, ds_id, sorted(wanted))):
+            if scheduled_id in wanted:
+                scheduled_keys.add((ds_id, scheduled_id))
+
+    candidates = []
+    for stop, dataset, source in rows:
+        current = link_by_stop_id.get(stop.id)
+        candidates.append(
+            {
+                "id": stop.id,
+                "dataset_id": stop.dataset_id,
+                "source_id": source.id,
+                "source_name": source.name,
+                "dataset_active": dataset.is_active,
+                "stop_id": stop.stop_id,
+                "name": stop.name,
+                "lat": stop.lat,
+                "lon": stop.lon,
+                "parent_station": stop.parent_station,
+                "scheduled": (stop.dataset_id, stop.stop_id) in scheduled_keys,
+                "distance_m": _distance_m(canonical.lon, canonical.lat, stop.lon, stop.lat),
+                "current_canonical_stop_id": None if current is None else current.canonical_stop_id,
+                "current_link_id": None if current is None else current.id,
+            }
+        )
+
+    def _rank(item: dict) -> tuple:
+        # Already-linked first, then scheduled, then nearest; unknown distance sorts last.
+        distance = item["distance_m"]
+        return (
+            0 if item["current_canonical_stop_id"] == canonical.id else 1,
+            0 if item["scheduled"] else 1,
+            float("inf") if distance is None else distance,
+            item["source_name"] or "",
+            item["name"] or "",
+            item["stop_id"],
+        )
+
+    candidates.sort(key=_rank)
+    return candidates[:capped_limit]
+
+
+def _resolve_gtfs_stop_link_payload(db: Session, payload: CanonicalStopGtfsLinkRequest) -> GtfsStop:
+    """Find the GTFS stop that a link request points at.
+
+    The primary key ``payload.gtfs_stop_id`` is tried first. When that yields
+    nothing and both ``dataset_id`` and ``stop_id`` are supplied, the stop is
+    looked up by that pair instead.
+
+    Raises:
+        HTTPException: 404 when neither lookup finds a stop.
+    """
+    found = None
+    if payload.gtfs_stop_id is not None:
+        found = db.get(GtfsStop, payload.gtfs_stop_id)
+    if found is None and payload.dataset_id is not None and payload.stop_id:
+        by_external_id = select(GtfsStop).where(
+            GtfsStop.dataset_id == payload.dataset_id,
+            GtfsStop.stop_id == payload.stop_id,
+        )
+        found = db.scalar(by_external_id)
+    if found is not None:
+        return found
+    raise HTTPException(status_code=404, detail="GTFS stop not found")
+
+
+def _manual_standalone_canonical_stop(db: Session, stop: GtfsStop) -> CanonicalStop:
+    """Return the standalone canonical stop for ``stop``, creating it if missing.
+
+    The canonical stop is identified by a ``manual:gtfs_stop:<dataset>:<stop_id>``
+    key. A newly created row is added to the session and flushed (not committed)
+    before being returned.
+    """
+    stop_key = f"manual:gtfs_stop:{stop.dataset_id}:{stop.stop_id}"
+    existing = db.scalar(select(CanonicalStop).where(CanonicalStop.stop_key == stop_key))
+    if existing is None:
+        display_name = stop.name or stop.stop_id
+        provenance = {"source": "manual_unlink", "dataset_id": stop.dataset_id, "stop_id": stop.stop_id}
+        existing = CanonicalStop(
+            stop_key=stop_key,
+            name=display_name,
+            # Case-folded, whitespace-collapsed form of the display name.
+            normalized_name=" ".join(display_name.casefold().split()),
+            lat=stop.lat,
+            lon=stop.lon,
+            metadata_json=json.dumps(provenance, separators=(",", ":")),
+        )
+        db.add(existing)
+        db.flush()
+    return existing
+
+
+def _persist_canonical_stop_rule(
+    db: Session,
+    rule_type: str,
+    stop: GtfsStop,
+    canonical: CanonicalStop,
+    note: str | None,
+) -> None:
+    """Add a ``MatchRule`` recording a decision about ``stop`` and ``canonical``.
+
+    The selector identifies the GTFS stop (source, dataset, external id); the
+    action captures the canonical stop's key, name, position, mode and the GTFS
+    stops currently linked to it. The rule is only added to the session; the
+    caller is responsible for committing.
+    """
+    owning_dataset = db.get(Dataset, stop.dataset_id)
+    source_id = owning_dataset.source_id if owning_dataset is not None else None
+    selector = dict(
+        object_type="gtfs_stop",
+        source_id=source_id,
+        dataset_id=stop.dataset_id,
+        external_id=stop.stop_id,
+    )
+    action = dict(
+        target_stop_key=canonical.stop_key,
+        target_name=canonical.name,
+        target_lat=canonical.lat,
+        target_lon=canonical.lon,
+        target_mode=canonical.mode,
+        target_gtfs_stops=_target_gtfs_stop_rule_refs(db, canonical.id),
+    )
+    compact = (",", ":")
+    rule = MatchRule(
+        rule_type=rule_type,
+        selector_json=json.dumps(selector, separators=compact),
+        action_json=json.dumps(action, separators=compact),
+        note=note or f"Created from canonical stop {canonical.id}",
+    )
+    db.add(rule)
+
+
+def _target_gtfs_stop_rule_refs(db: Session, canonical_stop_id: int) -> list[dict]:
+    """List up to 50 GTFS stop references linked to a canonical stop.
+
+    Rows are ordered with active datasets first, then by source id, link role
+    and external id. Each entry carries ``source_id``, ``dataset_id`` and
+    ``external_id``.
+    """
+    query = (
+        select(CanonicalStopLink, Dataset)
+        .join(Dataset, Dataset.id == CanonicalStopLink.dataset_id)
+        .where(
+            CanonicalStopLink.canonical_stop_id == canonical_stop_id,
+            CanonicalStopLink.object_type == "gtfs_stop",
+        )
+        .order_by(Dataset.is_active.desc(), Dataset.source_id, CanonicalStopLink.role, CanonicalStopLink.external_id)
+        .limit(50)
+    )
+    refs = []
+    for link, dataset in db.execute(query).all():
+        refs.append(
+            {
+                "source_id": dataset.source_id,
+                "dataset_id": link.dataset_id,
+                "external_id": link.external_id,
+            }
+        )
+    return refs
+
+
+def _distance_m(left_lon: float | None, left_lat: float | None, right_lon: float | None, right_lat: float | None) -> float | None:
+    """Approximate ground distance in metres between two lon/lat points.
+
+    Uses an equirectangular approximation: degree differences are converted to
+    metres with 111,320 m per degree, and the longitude difference is scaled by
+    the cosine of the mean latitude because meridians converge towards the
+    poles. Without that scaling east-west distances are overestimated (by
+    roughly 1.5x at 50 degrees north).
+
+    Args:
+        left_lon, left_lat: First point in decimal degrees, or ``None``.
+        right_lon, right_lat: Second point in decimal degrees, or ``None``.
+
+    Returns:
+        The distance rounded to 0.1 m, or ``None`` when any coordinate is missing.
+    """
+    import math  # function-scope import: the file's import block is outside this block
+
+    if any(value is None for value in (left_lon, left_lat, right_lon, right_lat)):
+        return None
+    lon_a, lat_a = float(left_lon), float(left_lat)
+    lon_b, lat_b = float(right_lon), float(right_lat)
+    # Shrink the east-west component to the length of a longitude degree at this latitude.
+    lon_scale = math.cos(math.radians((lat_a + lat_b) / 2.0))
+    delta_lon = (lon_a - lon_b) * lon_scale
+    delta_lat = lat_a - lat_b
+    return round(math.hypot(delta_lon, delta_lat) * 111_320, 1)
+
+
+def _optional_int(value: object) -> int | None:
+    """Convert ``value`` to ``int``, returning ``None`` when that is not possible.
+
+    ``None`` passes through as ``None``. Values that ``int()`` rejects also
+    yield ``None``; this includes infinite floats, for which ``int()`` raises
+    ``OverflowError`` rather than ``ValueError``.
+    """
+    if value is None:
+        return None
+    try:
+        return int(value)
+    except (TypeError, ValueError, OverflowError):
+        return None
+
+
+def _json_object_value(value: object, key: str) -> object:
+    """Return ``value[key]`` when ``value`` is a dict that has it, else ``None``."""
+    if not isinstance(value, dict):
+        return None
+    return value.get(key)
+
+
+def _job_queue_revision_payload(
+    db: Session,
+    *,
+    workers: list[dict] | None = None,
+    include_dismissed: bool = False,
+) -> dict:
+    """Combine the job-queue revision with a worker revision into one payload.
+
+    The result starts from ``job_queue_revision(...)`` and then adds
+    ``job_revision`` (the queue's own revision), ``worker_revision``, a combined
+    ``revision`` string and the worker status list. When ``workers`` is not
+    given, the status list comes from ``queue_worker_status()``.
+    """
+    payload = dict(job_queue_revision(db, include_dismissed=include_dismissed))
+    worker_status = queue_worker_status() if workers is None else workers
+    worker_revision = _worker_revision(worker_status)
+    job_revision = payload["revision"]
+    payload["job_revision"] = job_revision
+    payload["worker_revision"] = worker_revision
+    payload["revision"] = f"{job_revision}|workers:{worker_revision}"
+    payload["workers"] = worker_status
+    return payload
+
+
+def _worker_revision(workers: list[dict]) -> str:
+    """Build a deterministic fingerprint string for a list of worker status dicts.
+
+    Workers are ordered by their ``worker_id`` (as text); each contributes
+    ``worker_id:running:pid:log_file`` and the entries are joined with ``|``.
+    An empty list yields ``"none"``.
+    """
+    if not workers:
+        return "none"
+
+    def sort_key(entry: dict) -> str:
+        return str(entry.get("worker_id") or "")
+
+    def fingerprint(entry: dict) -> str:
+        fields = (
+            str(entry.get("worker_id") or ""),
+            "1" if entry.get("running") else "0",
+            str(entry.get("pid") or 0),
+            str(entry.get("log_file") or ""),
+        )
+        return ":".join(fields)
+
+    return "|".join(fingerprint(entry) for entry in sorted(workers, key=sort_key))
+
+
+def _set_etag(response: Response, revision: str) -> None:
+    """Set a weak ``ETag`` header on ``response`` derived from ``revision``."""
+    weak_validator = 'W/"{}"'.format(revision)
+    response.headers["ETag"] = weak_validator
+
+
+def _upsert_gtfs_qa_note(notes: str | None, *, status: str, note: str) -> str | None:
+    """Replace the GTFS QA marker line inside a free-text notes field.
+
+    Existing lines that start with ``GTFS_QA_NOTE_PREFIX`` and blank lines are
+    dropped. A fresh marker line (status, UTC timestamp, optional note) is put
+    first unless the status is ``unreviewed`` and no note text was given.
+
+    Returns:
+        The rebuilt notes text, or ``None`` when nothing is left.
+    """
+    status_text = (status or "unreviewed").strip() or "unreviewed"
+    # Collapse all internal whitespace so the note stays on the marker line.
+    note_text = " ".join((note or "").split())
+    fields = [
+        f"status={status_text}",
+        f"updated_at={datetime.now(timezone.utc).isoformat()}",
+    ]
+    if note_text:
+        fields.append(f"note={note_text}")
+    marker = f"{GTFS_QA_NOTE_PREFIX} " + "; ".join(fields)
+
+    kept = []
+    for line in str(notes or "").splitlines():
+        if not line.strip() or line.startswith(GTFS_QA_NOTE_PREFIX):
+            continue
+        kept.append(line)
+    if status_text != "unreviewed" or note_text:
+        kept = [marker, *kept]
+    return "\n".join(kept) or None
+
+
+def _source_response(
+    db: Session,
+    source: Source,
+    update_check: SourceUpdateCheck | None = None,
+    active_job: Job | None = None,
+    active_dataset_jobs: dict[int, Job] | None = None,
+) -> dict:
+    """Serialize a source, its row counts and its datasets for the API.
+
+    ``is_online`` is true when the source URL uses the ``http`` or ``https``
+    scheme. ``active_dataset_jobs`` maps dataset ids to their running job;
+    datasets without an entry report ``active_job`` as ``None``.
+    """
+    is_online = urlparse(source.url).scheme in {"http", "https"}
+    jobs_by_dataset = active_dataset_jobs or {}
+    last_run = source.last_run_at.isoformat() if source.last_run_at else None
+    payload = {
+        "id": source.id,
+        "name": source.name,
+        "kind": source.kind,
+        "url": source.url,
+        "is_online": is_online,
+        "country": source.country,
+        "license": source.license,
+        "priority": source.priority,
+        "mode_scope": source.mode_scope,
+        "source_basis": source.source_basis,
+        "notes": source.notes,
+        "catalog_entry_id": source.catalog_entry_id,
+        "enabled": source.enabled,
+        "status": source.status,
+        "last_error": source.last_error,
+        "last_run_at": last_run,
+        "active_job": job_payload(active_job) if active_job is not None else None,
+        "latest_update_check": update_check_payload(update_check),
+        "stats": source_row_counts(db, source),
+    }
+
+    dataset_items = []
+    for dataset in source.datasets:
+        created = dataset.created_at.isoformat() if dataset.created_at else None
+        entry = {
+            "id": dataset.id,
+            "kind": dataset.kind,
+            "is_active": dataset.is_active,
+            "status": dataset.status,
+            "local_path": dataset.local_path,
+            "sha256": dataset.sha256,
+            "created_at": created,
+            "metadata": _json_object(dataset.metadata_json),
+            "stats": dataset_row_counts(db, dataset.id, dataset.kind),
+        }
+        if dataset.id in jobs_by_dataset:
+            entry["active_job"] = job_payload(jobs_by_dataset[dataset.id])
+        else:
+            entry["active_job"] = None
+        dataset_items.append(entry)
+    payload["datasets"] = dataset_items
+    return payload
+
+
+def _queue_source_import_job(db: Session, source: Source, *, run_match: bool, build_route_layer: bool, priority: int = 0) -> tuple[Job, bool]:
+    """Queue an import job for ``source`` unless one is already active.
+
+    Returns:
+        ``(job, created)``: ``created`` is ``False`` when an already-active job
+        was returned instead of a new one. The session is committed either way.
+    """
+    existing = active_source_import_job(db, source.id)
+    if existing is None:
+        created = create_source_import_job(
+            db,
+            source,
+            run_match=run_match,
+            build_route_layer=build_route_layer,
+            priority=priority,
+        )
+        db.commit()
+        db.refresh(created)
+        return created, True
+    db.commit()
+    return existing, False
+
+
+def _queue_admin_maintenance_job(
+    db: Session,
+    action: str,
+    payload: AdminActionRequest | None,
+    *,
+    priority: int = 0,
+) -> dict:
+    """Validate an admin action request and queue it as a maintenance job.
+
+    Destructive actions need a confirmation word in ``payload.confirm``:
+    ``RESET`` for ``reset-db``, ``VACUUM`` for ``vacuum-db`` and ``PRUNE`` for
+    the prune actions when they are not dry runs. The confirmation is never
+    stored in the job payload.
+
+    Raises:
+        HTTPException: 404 for an unknown action, 400 for a missing or wrong
+            confirmation word.
+    """
+    if action not in ADMIN_JOB_ACTIONS:
+        raise HTTPException(status_code=404, detail="admin action not found")
+    request_payload = payload or AdminActionRequest()
+
+    required_word = None
+    if action == "reset-db":
+        required_word = "RESET"
+    elif action == "vacuum-db":
+        required_word = "VACUUM"
+    elif action in {"prune-cache", "prune-inactive-datasets"} and not request_payload.dry_run:
+        required_word = "PRUNE"
+    if required_word is not None and request_payload.confirm != required_word:
+        raise HTTPException(status_code=400, detail=f"confirmation text {required_word} is required")
+
+    job_args = _request_model_payload(request_payload)
+    job_args.pop("confirm", None)
+    # These actions have no dry-run mode, so the flag is not forwarded.
+    if action in {"init-db", "vacuum-db", "reset-db"}:
+        job_args.pop("dry_run", None)
+    queued = create_maintenance_job(db, action, job_args, priority=priority)
+    db.commit()
+    db.refresh(queued)
+    return job_payload(queued)
+
+
+def _request_model_payload(model: BaseModel) -> dict:
+    """Dump a request model to a dict, leaving out fields that are ``None``.
+
+    Uses ``model_dump`` when the model has it and falls back to ``dict``
+    otherwise.
+    """
+    try:
+        dump = model.model_dump
+    except AttributeError:
+        dump = model.dict
+    return dump(exclude_none=True)
+
+
+def _active_dataset_ids(
+    db: Session,
+    source_ids: list[int] | None = None,
+    dataset_ids: list[int] | None = None,
+    dataset_kinds: list[str] | None = None,
+) -> list[int]:
+    """Return ids of active datasets, optionally narrowed by source, id or kind.
+
+    A filter that is ``None`` or empty is not applied.
+    """
+    stmt = select(Dataset.id).where(Dataset.is_active.is_(True))
+    optional_filters = (
+        (Dataset.source_id, source_ids),
+        (Dataset.id, dataset_ids),
+        (Dataset.kind, dataset_kinds),
+    )
+    for column, allowed in optional_filters:
+        if allowed:
+            stmt = stmt.where(column.in_(allowed))
+    found = []
+    for row in db.execute(stmt).all():
+        found.append(row[0])
+    return found
+
+
+def _gtfs_route_scope_condition(route_scopes: list[str]):
+    """Build the SQL condition selecting GTFS routes in the given scopes.
+
+    A route matches when its stored ``route_scope`` is requested or when the
+    mode/ref/name fallback heuristic puts it in a requested scope. When
+    ``local`` is requested, bus and trolleybus routes that the heuristic
+    classifies as long-distance or regional are excluded from the stored match.
+    """
+    build_fallback = _route_scope_fallback_condition(
+        mode_column=GtfsRoute.mode,
+        ref_column=GtfsRoute.short_name,
+        name_column=GtfsRoute.long_name,
+    )
+    heuristic_match = build_fallback(route_scopes)
+    stored_match = GtfsRoute.route_scope.in_(route_scopes)
+    if "local" not in route_scopes:
+        return or_(stored_match, heuristic_match)
+    looks_non_local = build_fallback(["long_distance", "regional"])
+    non_local_bus = and_(GtfsRoute.mode.in_(["bus", "trolleybus"]), looks_non_local)
+    return or_(and_(stored_match, not_(non_local_bus)), heuristic_match)
+
+
+def _route_pattern_scope_condition(route_scopes: list[str]):
+    """Build the SQL condition selecting route patterns in the given scopes.
+
+    A pattern matches when its stored ``route_scope`` is requested or when the
+    mode/ref/name fallback heuristic puts it in a requested scope. When
+    ``local`` is requested, bus and trolleybus patterns that the heuristic
+    classifies as long-distance or regional are excluded from the stored match.
+    """
+    build_fallback = _route_scope_fallback_condition(
+        mode_column=RoutePattern.mode,
+        ref_column=RoutePattern.route_ref,
+        name_column=RoutePattern.route_name,
+    )
+    heuristic_match = build_fallback(route_scopes)
+    stored_match = RoutePattern.route_scope.in_(route_scopes)
+    if "local" not in route_scopes:
+        return or_(stored_match, heuristic_match)
+    looks_non_local = build_fallback(["long_distance", "regional"])
+    non_local_bus = and_(RoutePattern.mode.in_(["bus", "trolleybus"]), looks_non_local)
+    return or_(and_(stored_match, not_(non_local_bus)), heuristic_match)
+
+
+def _route_scope_fallback_condition(*, mode_column, ref_column, name_column):
+ """Return a builder for heuristic route-scope SQL conditions.
+
+ The returned ``condition(route_scopes)`` callable classifies rows from the
+ given mode, ref and name columns into ``long_distance``, ``regional``,
+ ``local`` and ``unknown`` and ORs together the classes that were requested.
+ When none of those four names is in ``route_scopes`` it returns
+ ``mode_column.is_(None)``.
+ """
+ def condition(route_scopes: list[str]):
+ # Compare case-insensitively; NULL refs/names are treated as empty strings.
+ ref = func.upper(func.coalesce(ref_column, ""))
+ name = func.upper(func.coalesce(name_column, ""))
+ # NOTE(review): "IC%" already matches every "ICE%" ref, so some of the
+ # more specific prefixes below are redundant for matching purposes.
+ train_long_distance = and_(
+ mode_column == "train",
+ or_(
+ ref.like("ICE%"),
+ ref.like("IC%"),
+ ref.like("EC%"),
+ ref.like("ECE%"),
+ ref.like("EN%"),
+ ref.like("NJ%"),
+ ref.like("RJ%"),
+ ref.like("RJX%"),
+ ref.like("TGV%"),
+ ref.like("THA%"),
+ ref.like("FLX%"),
+ name.like("%INTERCITY%"),
+ name.like("%EUROCITY%"),
+ name.like("%NIGHTJET%"),
+ name.like("%FLIXTRAIN%"),
+ ),
+ )
+ bus_long_distance = and_(
+ mode_column.in_(["bus", "trolleybus"]),
+ or_(
+ ref.like("FLX%"),
+ name.like("%FLIXBUS%"),
+ name.like("%EUROLINES%"),
+ name.like("%INTERCITYBUS%"),
+ name.like("%IC BUS%"),
+ name.like("%FERNBUS%"),
+ name.like("%LONG DISTANCE%"),
+ ),
+ )
+ # Every "coach" row counts as long distance regardless of ref or name.
+ long_distance = or_(mode_column == "coach", train_long_distance, bus_long_distance)
+ # Regional buses: name-based, and never a bus already classed as long distance.
+ bus_regional = and_(
+ mode_column.in_(["bus", "trolleybus"]),
+ not_(bus_long_distance),
+ or_(
+ name.like("%REGIONALBUS%"),
+ name.like("%REGIOBUS%"),
+ name.like("%REGIONAL BUS%"),
+ name.like("%REGIONALVERKEHR%"),
+ ),
+ )
+ # Local: the listed urban modes, any bus not classed as long distance or
+ # regional, and trains whose ref starts with "S" or whose name contains "S-BAHN".
+ local = or_(
+ mode_column.in_(["tram", "light_rail", "subway", "ferry", "funicular", "aerialway", "monorail"]),
+ and_(mode_column.in_(["bus", "trolleybus"]), not_(or_(bus_long_distance, bus_regional))),
+ and_(mode_column == "train", or_(ref.like("S%"), name.like("%S-BAHN%"))),
+ )
+ # Regional trains exclude anything already classed as long distance.
+ # NOTE(review): the trailing "R%" pattern matches every other R-prefixed
+ # pattern in this list as well.
+ train_regional = and_(
+ mode_column == "train",
+ not_(train_long_distance),
+ or_(
+ ref.like("IRE%"),
+ ref.like("RE%"),
+ ref.like("RB%"),
+ ref.like("RER%"),
+ ref.like("TER%"),
+ ref.like("REX%"),
+ ref.like("MEX%"),
+ ref.like("ALX%"),
+ ref.like("WFB%"),
+ ref.like("R%"),
+ name.like("%REGIONAL%"),
+ name.like("%REGIO%"),
+ ),
+ )
+ regional = or_(train_regional, bus_regional)
+ conditions = []
+ if "long_distance" in route_scopes:
+ conditions.append(long_distance)
+ if "regional" in route_scopes:
+ conditions.append(regional)
+ if "local" in route_scopes:
+ conditions.append(local)
+ if "unknown" in route_scopes:
+ # "unknown" only covers trains that none of the three classes matched.
+ conditions.append(and_(mode_column == "train", not_(or_(long_distance, regional, local))))
+ # With no recognised scope requested, fall back to "mode IS NULL".
+ return or_(*conditions) if conditions else mode_column.is_(None)
+
+ return condition
+
+
+def _osm_features_response(
+    db: Session,
+    kind: Optional[str],
+    mode: Optional[str] = None,
+    route_scope: Optional[str] = None,
+    geometry: Optional[str] = None,
+    source_id: Optional[str] = None,
+    dataset_id: Optional[str] = None,
+    bbox: Optional[str] = None,
+    limit: int = MAP_FEATURE_LIMIT,
+) -> JSONResponse:
+    """Answer an OSM feature query with a GeoJSON feature collection.
+
+    ``kind``, ``mode``, ``route_scope``, ``source_id`` and ``dataset_id`` are
+    comma-separated lists. Only active ``osm_geojson`` datasets are searched.
+    When neither ``source_id`` nor ``dataset_id`` is given, rows describing the
+    same OSM object are de-duplicated. ``geometry`` filters the resulting
+    features by geometry family.
+    """
+    searchable_ids = _active_dataset_ids(
+        db,
+        source_ids=_csv_ints(source_id, "source_id"),
+        dataset_ids=_csv_ints(dataset_id, "dataset_id"),
+        dataset_kinds=["osm_geojson"],
+    )
+    if not searchable_ids:
+        return JSONResponse(feature_collection([]))
+
+    wanted_kinds = _csv_values(kind)
+    wanted_modes = _csv_values(mode)
+    wanted_scopes = _csv_values(route_scope)
+    area = _parse_bbox(bbox)
+    rows = query_osm_features(
+        db,
+        searchable_ids,
+        kinds=wanted_kinds or None,
+        modes=wanted_modes or None,
+        route_scopes=wanted_scopes or None,
+        bbox=area,
+        limit=_clamp_limit(limit),
+    )
+    explicitly_scoped = bool(source_id) or bool(dataset_id)
+    if not explicitly_scoped:
+        rows = _dedupe_osm_feature_rows(rows)
+
+    matching = []
+    for row in rows:
+        candidate = osm_feature_feature(row)
+        if candidate is None:
+            continue
+        if _geometry_matches(candidate["geometry"], geometry):
+            matching.append(candidate)
+    return JSONResponse(feature_collection(matching))
+
+
+def _dedupe_osm_feature_rows(rows: list[OsmFeature]) -> list[OsmFeature]:
+    """Keep one row per ``(kind, osm_type, osm_id)``.
+
+    Among duplicates the row with the smallest ``_osm_feature_preference`` value
+    wins; on ties the earlier row is kept. Output follows first-seen key order.
+    """
+    best: dict[tuple[str, str, str], OsmFeature] = {}
+    for candidate in rows:
+        identity = (candidate.kind, candidate.osm_type, candidate.osm_id)
+        incumbent = best.get(identity)
+        if incumbent is not None and not (_osm_feature_preference(candidate) < _osm_feature_preference(incumbent)):
+            continue
+        best[identity] = candidate
+    return list(best.values())
+
+
+def _osm_feature_preference(row: OsmFeature) -> tuple[int, int, int]:
+    """Sort key used to pick between duplicate OSM feature rows (smaller wins).
+
+    The key has three parts, which is why the return annotation is a 3-tuple:
+
+    1. ``0`` when the row has a complete bounding box, ``1`` otherwise.
+    2. The negated bounding-box span (width + height in degrees, scaled by
+       1,000,000 and truncated), so a larger span sorts first.
+    3. ``row.dataset_id`` as the final tie-breaker.
+    """
+    span = None
+    if None not in {row.min_lon, row.min_lat, row.max_lon, row.max_lat}:
+        span = abs(float(row.max_lon) - float(row.min_lon)) + abs(float(row.max_lat) - float(row.min_lat))
+    return (0 if span is not None else 1, -int((span or 0) * 1_000_000), row.dataset_id)
+
+
+def _csv_values(value: Optional[str]) -> list[str]:
+    """Split a comma-separated string into trimmed, non-empty parts."""
+    if not value:
+        return []
+    trimmed = (piece.strip() for piece in value.split(","))
+    return [piece for piece in trimmed if piece]
+
+
+def _csv_ints(value: Optional[str], name: str) -> list[int]:
+    """Parse a comma-separated string of integers, ignoring empty parts.
+
+    Raises:
+        HTTPException: 400 naming ``name`` when a part is not an integer.
+    """
+    if not value:
+        return []
+    parsed = []
+    for chunk in (piece.strip() for piece in value.split(",")):
+        if chunk:
+            try:
+                parsed.append(int(chunk))
+            except ValueError as exc:
+                raise HTTPException(status_code=400, detail=f"{name} values must be integers") from exc
+    return parsed
+
+
+def _path_int(value: str, name: str) -> int:
+    """Convert a path parameter to ``int``.
+
+    Raises:
+        HTTPException: 400 naming ``name`` when the value is not an integer.
+    """
+    try:
+        parsed = int(value)
+    except (TypeError, ValueError) as exc:
+        raise HTTPException(status_code=400, detail=f"{name} must be an integer") from exc
+    else:
+        return parsed
+
+
+def _json_object(value: str | None) -> dict:
+    """Decode a JSON string and return it only if it is an object.
+
+    Empty or ``None`` input, invalid JSON and non-object JSON all yield ``{}``.
+    """
+    try:
+        decoded = json.loads(value if value else "{}")
+    except json.JSONDecodeError:
+        return {}
+    if isinstance(decoded, dict):
+        return decoded
+    return {}
+
+
+def _source_stats(db: Session) -> list[dict]:
+    """Summarize row counts for every active dataset whose source still exists.
+
+    GTFS datasets report routes, stops, trips, stop times and match counts;
+    ``osm_geojson`` datasets report routes, stops and features. Other kinds
+    keep zeroed counters.
+    """
+    fields_by_kind = {
+        "gtfs": ("routes", "stops", "trips", "stop_times"),
+        "osm_geojson": ("routes", "stops", "features"),
+    }
+    summaries = []
+    for dataset in db.scalars(select(Dataset).where(Dataset.is_active.is_(True))).all():
+        source = db.get(Source, dataset.source_id)
+        if source is None:
+            continue
+        summary = {
+            "source_id": source.id,
+            "source_name": source.name,
+            "source_kind": source.kind,
+            "dataset_id": dataset.id,
+            "dataset_kind": dataset.kind,
+            "routes": 0,
+            "stops": 0,
+            "trips": 0,
+            "stop_times": 0,
+            "features": 0,
+            "match_counts": {},
+        }
+        counts = dataset_row_counts(db, dataset.id, dataset.kind)
+        for field in fields_by_kind.get(dataset.kind, ()):
+            summary[field] = counts.get(field, 0)
+        if dataset.kind == "gtfs":
+            summary["match_counts"] = counts.get("match_counts", {})
+        summaries.append(summary)
+    return summaries
+
+
+def _match_scope_summary(db: Session, active_gtfs_dataset_ids: list[int]) -> dict[str, dict[str, int]]:
+    """Count route matches per scope and status for the given GTFS datasets.
+
+    The scope is read from the ``scope`` key of each match's ``reasons_json``.
+    Matches whose reasons are missing, invalid JSON, not a JSON object, or lack
+    a scope are counted under ``unknown_scope``.
+
+    Returns:
+        ``{scope: {status: count}}``; empty when no dataset ids are given.
+    """
+    if not active_gtfs_dataset_ids:
+        return {}
+    rows = db.execute(
+        select(RouteMatch.status, RouteMatch.reasons_json)
+        .join(RouteMatch.gtfs_route)
+        .where(GtfsRoute.dataset_id.in_(active_gtfs_dataset_ids))
+    ).all()
+    summary: dict[str, dict[str, int]] = {}
+    for status, reasons_json in rows:
+        # _json_object tolerates invalid JSON and non-object payloads (lists,
+        # strings, numbers), which would otherwise break the .get() below.
+        reasons = _json_object(reasons_json)
+        scope = str(reasons.get("scope") or "unknown_scope")
+        status_key = str(status)
+        scope_counts = summary.setdefault(scope, {})
+        scope_counts[status_key] = scope_counts.get(status_key, 0) + 1
+    return summary
+
+
+def _parse_bbox(value: Optional[str]) -> tuple[float, float, float, float] | None:
+    """Parse a ``min_lon,min_lat,max_lon,max_lat`` string into a float tuple.
+
+    Returns:
+        The four bounds, or ``None`` when ``value`` is empty.
+
+    Raises:
+        HTTPException: 400 when there are not exactly four parts, a part is not
+            a finite number, or a minimum exceeds its maximum.
+    """
+    import math  # function-scope import: the file's import block is outside this block
+
+    if not value:
+        return None
+    parts = value.split(",")
+    if len(parts) != 4:
+        raise HTTPException(status_code=400, detail="bbox must be min_lon,min_lat,max_lon,max_lat")
+    try:
+        min_lon, min_lat, max_lon, max_lat = [float(part) for part in parts]
+    except ValueError as exc:
+        raise HTTPException(status_code=400, detail="bbox values must be numbers") from exc
+    # float() accepts "nan" and "inf"; NaN would slip past the ordering check
+    # below because every comparison with NaN is false.
+    if not all(math.isfinite(bound) for bound in (min_lon, min_lat, max_lon, max_lat)):
+        raise HTTPException(status_code=400, detail="bbox values must be finite numbers")
+    if min_lon > max_lon or min_lat > max_lat:
+        raise HTTPException(status_code=400, detail="bbox minimums must be less than maximums")
+    return min_lon, min_lat, max_lon, max_lat
+
+
+def _where_bbox_overlaps(stmt, model, bbox: tuple[float, float, float, float]):
+ """Restrict ``stmt`` to rows of ``model`` whose extent overlaps ``bbox``.
+
+ ``bbox`` is ``(min_lon, min_lat, max_lon, max_lat)``. On PostgreSQL, for the
+ ``gtfs_routes`` and ``route_patterns`` tables, the ``geom`` column is tested
+ against an SRID 4326 envelope, with the min/max columns used only for rows
+ whose ``geom`` is NULL. In every other case the min/max columns are compared
+ directly.
+ """
+ min_lon, min_lat, max_lon, max_lat = bbox
+ table_name = getattr(model, "__tablename__", "")
+ # table_name is only interpolated into SQL after matching this fixed set;
+ # the bbox values themselves are passed as bound parameters.
+ if using_postgresql() and table_name in {"gtfs_routes", "route_patterns"}:
+ return stmt.where(
+ text(
+ f"""
+ (
+ {table_name}.geom && ST_MakeEnvelope(:bbox_min_lon, :bbox_min_lat, :bbox_max_lon, :bbox_max_lat, 4326)
+ OR (
+ {table_name}.geom IS NULL
+ AND {table_name}.min_lon <= :bbox_max_lon
+ AND {table_name}.max_lon >= :bbox_min_lon
+ AND {table_name}.min_lat <= :bbox_max_lat
+ AND {table_name}.max_lat >= :bbox_min_lat
+ )
+ )
+ """
+ )
+ ).params(
+ bbox_min_lon=min_lon,
+ bbox_min_lat=min_lat,
+ bbox_max_lon=max_lon,
+ bbox_max_lat=max_lat,
+ )
+ # Generic path: two boxes overlap when each one starts before the other ends.
+ return stmt.where(model.min_lon <= max_lon, model.max_lon >= min_lon, model.min_lat <= max_lat, model.max_lat >= min_lat)
+
+
+def _where_point_bbox(stmt, model, bbox: tuple[float, float, float, float]):
+ """Restrict ``stmt`` to rows of ``model`` whose point lies inside ``bbox``.
+
+ ``bbox`` is ``(min_lon, min_lat, max_lon, max_lat)``. On PostgreSQL, for the
+ ``gtfs_stops`` and ``canonical_stops`` tables, the ``geom`` column is tested
+ against an SRID 4326 envelope, with the ``lon``/``lat`` columns used only for
+ rows whose ``geom`` is NULL. In every other case ``lon``/``lat`` are compared
+ directly, bounds inclusive.
+ """
+ min_lon, min_lat, max_lon, max_lat = bbox
+ table_name = getattr(model, "__tablename__", "")
+ # table_name is only interpolated into SQL after matching this fixed set;
+ # the bbox values themselves are passed as bound parameters.
+ if using_postgresql() and table_name in {"gtfs_stops", "canonical_stops"}:
+ return stmt.where(
+ text(
+ f"""
+ (
+ {table_name}.geom && ST_MakeEnvelope(:bbox_min_lon, :bbox_min_lat, :bbox_max_lon, :bbox_max_lat, 4326)
+ OR (
+ {table_name}.geom IS NULL
+ AND {table_name}.lon >= :bbox_min_lon
+ AND {table_name}.lon <= :bbox_max_lon
+ AND {table_name}.lat >= :bbox_min_lat
+ AND {table_name}.lat <= :bbox_max_lat
+ )
+ )
+ """
+ )
+ ).params(
+ bbox_min_lon=min_lon,
+ bbox_min_lat=min_lat,
+ bbox_max_lon=max_lon,
+ bbox_max_lat=max_lat,
+ )
+ return stmt.where(model.lon >= min_lon, model.lon <= max_lon, model.lat >= min_lat, model.lat <= max_lat)
+
+
+def _geometry_matches(geometry: dict, requested: Optional[str]) -> bool:
+    """Tell whether a GeoJSON geometry belongs to the requested family.
+
+    ``requested`` may be ``point``, ``line``, ``polygon`` or ``nonpoint``; each
+    family includes its ``Multi*`` variant. A falsy ``requested`` accepts
+    everything.
+
+    Raises:
+        HTTPException: 400 for any other ``requested`` value.
+    """
+    if not requested:
+        return True
+    geometry_type = geometry.get("type")
+    families = {
+        "point": ("Point", "MultiPoint"),
+        "line": ("LineString", "MultiLineString"),
+        "polygon": ("Polygon", "MultiPolygon"),
+    }
+    if requested in families:
+        single, multi = families[requested]
+        return geometry_type == single or geometry_type == multi
+    if requested == "nonpoint":
+        return geometry_type not in {"Point", "MultiPoint"}
+    raise HTTPException(status_code=400, detail="geometry must be point, line, polygon, or nonpoint")
+
+
+def _clamp_limit(value: int) -> int:
+    """Clamp ``value`` into the range ``1..MAP_FEATURE_LIMIT_MAX``."""
+    capped = MAP_FEATURE_LIMIT_MAX if MAP_FEATURE_LIMIT_MAX < value else value
+    return capped if capped > 1 else 1
+
+
+def _persist_match_rule(db: Session, match: RouteMatch, rule_type: str) -> None:
+    """Add a ``MatchRule`` that captures a reviewed route match.
+
+    The selector describes the GTFS route (and the OSM feature id when one is
+    attached); the action records the match status and, when present, the OSM
+    feature's identifying fields. Nothing is added when the GTFS route no
+    longer exists. The caller is responsible for committing.
+    """
+    route = db.get(GtfsRoute, match.gtfs_route_id)
+    if route is None:
+        return
+    route_dataset = db.get(Dataset, route.dataset_id)
+    feature = None
+    if match.osm_feature_id:
+        feature = db.get(OsmFeature, match.osm_feature_id)
+    feature_dataset = None
+    if feature is not None:
+        feature_dataset = db.get(Dataset, feature.dataset_id)
+
+    gtfs_selector = {
+        "source_id": route_dataset.source_id if route_dataset is not None else None,
+        "dataset_id": route.dataset_id,
+        "route_id": route.route_id,
+        "route_key": route.route_key,
+        "ref": route.short_name,
+        "mode": route.mode,
+    }
+    selector = {"gtfs": gtfs_selector, "gtfs_route_id": match.gtfs_route_id}
+    action = {"status": match.status}
+    if feature is not None:
+        selector["osm_feature_id"] = match.osm_feature_id
+        action["osm"] = {
+            "source_id": feature_dataset.source_id if feature_dataset is not None else None,
+            "dataset_id": feature.dataset_id,
+            "osm_type": feature.osm_type,
+            "osm_id": feature.osm_id,
+            "route_key": feature.route_key,
+            "ref": feature.ref,
+            "mode": feature.mode,
+        }
+
+    compact = (",", ":")
+    rule = MatchRule(
+        rule_type=rule_type,
+        selector_json=json.dumps(selector, separators=compact),
+        action_json=json.dumps(action, separators=compact),
+        note=f"Created from match {match.id}",
+    )
+    db.add(rule)
+
+
+def _status_from_candidate_score(score: float) -> str:
+    """Map a candidate score to a match status label.
+
+    Thresholds are inclusive lower bounds: 85 for ``matched``, 65 for
+    ``probable``, 40 for ``weak``; anything lower is ``below_threshold``.
+    """
+    thresholds = ((85, "matched"), (65, "probable"), (40, "weak"))
+    for floor, label in thresholds:
+        if score >= floor:
+            return label
+    return "below_threshold"
+
+
+def _catalog_country(entry: SourceCatalogEntry | None) -> str | None:
+    """Return the entry's country code upper-cased when it is two letters.
+
+    Surrounding whitespace is ignored. A missing entry, an empty code or a code
+    that is not exactly two alphabetic characters yields ``None``.
+    """
+    raw_code = entry.country_code if entry is not None else None
+    if not raw_code:
+        return None
+    code = raw_code.strip()
+    is_two_letter = len(code) == 2 and code.isalpha()
+    return code.upper() if is_two_letter else None
+
+
+def _catalog_notes(entry: SourceCatalogEntry | None) -> str | None:
+    """Join the entry's pipeline, coverage and geometry notes into one string.
+
+    Empty parts are skipped, the rest are space-separated, and the result is
+    cut to 2000 characters. ``None`` is returned for a missing entry or when
+    every part is empty.
+    """
+    if entry is None:
+        return None
+    candidates = (entry.next_pipeline_action, entry.coverage_notes, entry.geometry_notes)
+    present = [text_part for text_part in candidates if text_part]
+    return _truncate(" ".join(present), 2000)
+
+
+def _truncate(value: str | None, length: int) -> str | None:
+    """Cut ``value`` to at most ``length`` characters; empty input gives ``None``."""
+    return value[:length] if value else None
+
+
+def _is_database_locked_error(exc: OperationalError) -> bool:
+    """Tell whether an error message reports a locked or busy database.
+
+    Both the exception text and the text of its ``orig`` attribute (when
+    present) are searched, case-insensitively.
+    """
+    sources = (exc, getattr(exc, "orig", ""))
+    haystack = " ".join(str(item).lower() for item in sources)
+    needles = ("database is locked", "database table is locked", "database is busy")
+    return any(needle in haystack for needle in needles)
+
+
+def _is_statement_timeout_error(exc: OperationalError) -> bool:
+    """Tell whether an error message reports a statement timeout.
+
+    Both the exception text and the text of its ``orig`` attribute (when
+    present) are searched, case-insensitively.
+    """
+    sources = (exc, getattr(exc, "orig", ""))
+    haystack = " ".join(str(item).lower() for item in sources)
+    needles = ("statement timeout", "canceling statement due to statement timeout")
+    return any(needle in haystack for needle in needles)
+
+
+def _gtfs_route_summary(route: GtfsRoute) -> dict:
+    """Build a compact dict describing a GTFS route.
+
+    ``geometry.present`` reports whether the route has stored GeoJSON and
+    ``geometry.bbox`` is ``[min_lon, min_lat, max_lon, max_lat]``.
+    """
+    bbox = [route.min_lon, route.min_lat, route.max_lon, route.max_lat]
+    geometry = {"present": bool(route.geometry_geojson), "bbox": bbox}
+    summary = {
+        "id": route.id,
+        "dataset_id": route.dataset_id,
+        "route_id": route.route_id,
+        "ref": route.short_name,
+        "name": route.long_name,
+        "mode": route.mode,
+        "operator": route.operator_name,
+    }
+    summary["geometry"] = geometry
+    return summary
+
+
+def _osm_route_summary(feature: OsmFeature) -> dict:
+    """Build a compact dict describing an OSM route feature.
+
+    ``id`` is the value from ``osm_feature_public_id`` while ``row_id`` is the
+    database row id. ``geometry.present`` reports whether the feature has
+    stored GeoJSON and ``geometry.bbox`` is
+    ``[min_lon, min_lat, max_lon, max_lat]``.
+    """
+    summary = {
+        "id": osm_feature_public_id(feature),
+        "row_id": feature.id,
+        "dataset_id": feature.dataset_id,
+        "osm_type": feature.osm_type,
+        "osm_id": feature.osm_id,
+        "ref": feature.ref,
+        "name": feature.name,
+        "mode": feature.mode,
+        "operator": feature.operator,
+        "network": feature.network,
+    }
+    bbox = [feature.min_lon, feature.min_lat, feature.max_lon, feature.max_lat]
+    summary["geometry"] = {"present": bool(feature.geometry_geojson), "bbox": bbox}
+    return summary
+
+
+def _match_candidate_preview(
+    route: GtfsRoute,
+    match: RouteMatch,
+    candidate_rows: list[tuple[OsmFeature, float, dict[str, object]]],
+) -> dict:
+    """Build a feature collection previewing a GTFS route and its OSM candidates.
+
+    The GTFS route comes first (when it can be turned into a feature), followed
+    by each candidate in the given order. Candidates are ranked from 1 and
+    carry their rounded score, the status derived from it and whether they are
+    the match's current OSM feature.
+    """
+    route_properties = {
+        "preview_role": "gtfs_route",
+        "match_id": match.id,
+        "match_status": match.status,
+        "label": route.short_name or route.route_id,
+    }
+    route_feature = gtfs_route_feature(route, route_properties)
+    preview: list[dict] = [] if route_feature is None else [route_feature]
+
+    rank = 0
+    for feature, score, _reasons in candidate_rows:
+        rank += 1
+        candidate_properties = {
+            "preview_role": "candidate",
+            "match_id": match.id,
+            "candidate_rank": rank,
+            "candidate_score": round(float(score), 2),
+            "candidate_status": _status_from_candidate_score(score),
+            "current_match": feature.id == match.osm_feature_id,
+            "label": feature.ref or feature.name or feature.osm_id,
+        }
+        candidate_feature = osm_feature_feature(feature, candidate_properties)
+        if candidate_feature is not None:
+            preview.append(candidate_feature)
+    return feature_collection(preview)
diff --git a/app/models.py b/app/models.py
new file mode 100644
index 0000000..823374d
--- /dev/null
+++ b/app/models.py
@@ -0,0 +1,612 @@
+from __future__ import annotations
+
+from datetime import datetime, timezone
+from typing import Optional
+
+from sqlalchemy import BigInteger, Boolean, DateTime, Float, ForeignKey, Integer, String, Text, UniqueConstraint
+from sqlalchemy.orm import Mapped, mapped_column, relationship
+
+from app.db import Base
+
+
+def now_utc() -> datetime:
+    return datetime.now(tz=timezone.utc)  # timezone-aware current time in UTC; used below as a column default
+
+
+class Source(Base):
+ __tablename__ = "sources"
+ # ORM model for the "sources" table: one named source of a given kind with its URL, an enabled flag and last-run status/error.
+ id: Mapped[int] = mapped_column(Integer, primary_key=True)
+ catalog_entry_id: Mapped[Optional[int]] = mapped_column(ForeignKey("source_catalog_entries.id"), nullable=True, index=True)
+ name: Mapped[str] = mapped_column(String(255), nullable=False)
+ kind: Mapped[str] = mapped_column(String(64), nullable=False) # gtfs, osm_geojson, osm_pbf, osm_diff
+ url: Mapped[str] = mapped_column(Text, nullable=False)
+ country: Mapped[Optional[str]] = mapped_column(String(8), nullable=True)
+ license: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
+ priority: Mapped[Optional[str]] = mapped_column(String(16), nullable=True, index=True)
+ mode_scope: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
+ source_basis: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
+ notes: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
+ enabled: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False)
+ status: Mapped[str] = mapped_column(String(64), default="new", nullable=False)
+ last_error: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
+ last_run_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
+ created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=now_utc, nullable=False)
+ # catalog_entry is a one-way link (no back_populates); datasets and update_checks use the "all, delete-orphan" cascade.
+ catalog_entry: Mapped[Optional["SourceCatalogEntry"]] = relationship()
+ datasets: Mapped[list["Dataset"]] = relationship(back_populates="source", cascade="all, delete-orphan")
+ update_checks: Mapped[list["SourceUpdateCheck"]] = relationship(back_populates="source", cascade="all, delete-orphan")
+
+
+class SourceCatalogEntry(Base):
+ __tablename__ = "source_catalog_entries"
+ __table_args__ = (UniqueConstraint("catalog_key", name="uq_source_catalog_entry_key"),)
+ # ORM model for "source_catalog_entries": mostly free-text descriptive metadata about a source, unique per catalog_key; status defaults to "backlog".
+ id: Mapped[int] = mapped_column(Integer, primary_key=True)
+ catalog_key: Mapped[str] = mapped_column(String(255), nullable=False, index=True)  # NOTE(review): index=True next to the unique constraint above likely creates a second, redundant index; confirm
+ geography: Mapped[Optional[str]] = mapped_column(String(128), nullable=True, index=True)
+ country_code: Mapped[Optional[str]] = mapped_column(String(64), nullable=True, index=True)
+ mode_scope: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
+ source_name: Mapped[str] = mapped_column(Text, nullable=False)
+ source_category: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
+ formats_apis: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
+ availability: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
+ coverage_notes: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
+ geometry_notes: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
+ disruptions_closures: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
+ operator_list_use: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
+ access_license_notes: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
+ priority: Mapped[Optional[str]] = mapped_column(String(32), nullable=True, index=True)
+ source_url: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
+ evidence_url: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
+ next_pipeline_action: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
+ status: Mapped[str] = mapped_column(String(64), default="backlog", nullable=False, index=True)
+ created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=now_utc, nullable=False)
+ updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=now_utc, nullable=False)  # NOTE(review): insert-time default only (no onupdate); presumably refreshed by callers, confirm
+
+
+class Dataset(Base):
+ __tablename__ = "datasets"
+ # ORM model for "datasets": a record owned by one Source carrying a local_path, a sha256 digest, an is_active flag and a status (default "imported").
+ id: Mapped[int] = mapped_column(Integer, primary_key=True)
+ source_id: Mapped[int] = mapped_column(ForeignKey("sources.id"), nullable=False, index=True)
+ kind: Mapped[str] = mapped_column(String(64), nullable=False)
+ local_path: Mapped[str] = mapped_column(Text, nullable=False)
+ sha256: Mapped[str] = mapped_column(String(64), nullable=False)  # 64 characters fits a hex-encoded SHA-256 digest
+ is_active: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False)
+ status: Mapped[str] = mapped_column(String(64), default="imported", nullable=False)
+ metadata_json: Mapped[Optional[str]] = mapped_column(Text, nullable=True)  # stored as text; presumably serialized JSON, confirm
+ created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=now_utc, nullable=False)
+ # Inverse side of Source.datasets.
+ source: Mapped[Source] = relationship(back_populates="datasets")
+
+
+class SourceUpdateCheck(Base):
+    __tablename__ = "source_update_checks"
+    # ORM model for "source_update_checks": one recorded update check of a Source, with remote metadata, local file facts and the active dataset at that time.
+    id: Mapped[int] = mapped_column(Integer, primary_key=True)
+    source_id: Mapped[int] = mapped_column(ForeignKey("sources.id"), nullable=False, index=True)
+    checked_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=now_utc, nullable=False, index=True)
+    status: Mapped[str] = mapped_column(String(64), nullable=False, default="checked", index=True)
+    update_available: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
+    reason: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
+    remote_url: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
+    etag: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
+    last_modified: Mapped[Optional[str]] = mapped_column(Text, nullable=True)  # kept as text, not parsed into a DateTime
+    content_length: Mapped[Optional[int]] = mapped_column(BigInteger, nullable=True)  # 64-bit: a 32-bit INTEGER overflows above 2,147,483,647 (sizes assumed to be bytes)
+    content_type: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
+    local_mtime: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
+    local_size: Mapped[Optional[int]] = mapped_column(BigInteger, nullable=True)  # 64-bit for the same reason as content_length
+    local_sha256: Mapped[Optional[str]] = mapped_column(String(64), nullable=True)
+    active_dataset_id: Mapped[Optional[int]] = mapped_column(ForeignKey("datasets.id"), nullable=True, index=True)
+    active_dataset_sha256: Mapped[Optional[str]] = mapped_column(String(64), nullable=True)
+    metadata_json: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
+    # source is the inverse side of Source.update_checks; active_dataset is a one-way optional lookup.
+    source: Mapped[Source] = relationship(back_populates="update_checks")
+    active_dataset: Mapped[Optional[Dataset]] = relationship()
+
+
+class OsmDiffState(Base):
+ __tablename__ = "osm_diff_states"
+ # ORM model for "osm_diff_states": a per-source position (sequence_number, timestamp) against an updates_url; presumably OSM replication state, confirm.
+ id: Mapped[int] = mapped_column(Integer, primary_key=True)
+ source_id: Mapped[int] = mapped_column(ForeignKey("sources.id"), nullable=False, index=True)
+ raw_dataset_id: Mapped[Optional[int]] = mapped_column(ForeignKey("datasets.id"), nullable=True, index=True)
+ updates_url: Mapped[str] = mapped_column(Text, nullable=False)
+ sequence_number: Mapped[int] = mapped_column(Integer, nullable=False, index=True)
+ timestamp: Mapped[Optional[str]] = mapped_column(String(64), nullable=True, index=True)  # stored as text, not as a DateTime
+ status: Mapped[str] = mapped_column(String(64), nullable=False, default="active", index=True)
+ metadata_json: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
+ created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=now_utc, nullable=False, index=True)
+ updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=now_utc, nullable=False)  # NOTE(review): insert-time default only (no onupdate); presumably refreshed by callers, confirm
+ # Plain many-to-one lookups; neither relationship declares back_populates.
+ source: Mapped[Source] = relationship()
+ raw_dataset: Mapped[Optional[Dataset]] = relationship()
+
+
+class Job(Base):
+ __tablename__ = "jobs"
+ # ORM model for "jobs": a unit of work with progress counters, priority, lease fields and result/error payloads; status defaults to "queued".
+ id: Mapped[int] = mapped_column(Integer, primary_key=True)
+ kind: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
+ status: Mapped[str] = mapped_column(String(64), nullable=False, default="queued", index=True)
+ description: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
+ progress_current: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
+ progress_total: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
+ priority: Mapped[int] = mapped_column(Integer, nullable=False, default=0, index=True)  # ordering direction is decided by the queue code, not visible here
+ requested_action: Mapped[Optional[str]] = mapped_column(String(32), nullable=True, index=True)
+ lease_owner: Mapped[Optional[str]] = mapped_column(String(255), nullable=True, index=True)  # presumably the worker currently holding the job; confirm
+ lease_expires_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True, index=True)
+ paused_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
+ result_json: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
+ error: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
+ dismissed_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True, index=True)
+ created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=now_utc, nullable=False, index=True)
+ started_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
+ updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=now_utc, nullable=False)  # NOTE(review): insert-time default only (no onupdate); presumably refreshed by callers, confirm
+ finished_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
+ # Events use the "all, delete-orphan" cascade, so they are removed together with the job.
+ events: Mapped[list["JobEvent"]] = relationship(back_populates="job", cascade="all, delete-orphan")
+
+
+class JobEvent(Base):
+ __tablename__ = "job_events"
+ # ORM model for "job_events": one log entry (level, event_type, message, optional progress values) attached to a Job.
+ id: Mapped[int] = mapped_column(Integer, primary_key=True)
+ job_id: Mapped[int] = mapped_column(ForeignKey("jobs.id"), nullable=False, index=True)
+ level: Mapped[str] = mapped_column(String(32), nullable=False, default="info", index=True)
+ event_type: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
+ message: Mapped[str] = mapped_column(Text, nullable=False)
+ progress_current: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)  # nullable here, unlike the counters on Job
+ progress_total: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
+ metadata_json: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
+ created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=now_utc, nullable=False, index=True)
+ # Inverse side of Job.events.
+ job: Mapped[Job] = relationship(back_populates="events")
+
+
+class PipelineRun(Base):
+ __tablename__ = "pipeline_runs"
+ # ORM model for "pipeline_runs": one run of a stage/version with a dependency_hash, optionally tied to a source, dataset and job; status defaults to "running".
+ id: Mapped[int] = mapped_column(Integer, primary_key=True)
+ stage: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
+ version: Mapped[str] = mapped_column(String(128), nullable=False, index=True)
+ dependency_hash: Mapped[str] = mapped_column(String(64), nullable=False, index=True)  # hashing scheme is defined by the pipeline code, not visible here
+ status: Mapped[str] = mapped_column(String(64), nullable=False, default="running", index=True)
+ source_id: Mapped[Optional[int]] = mapped_column(ForeignKey("sources.id"), nullable=True, index=True)
+ dataset_id: Mapped[Optional[int]] = mapped_column(ForeignKey("datasets.id"), nullable=True, index=True)
+ job_id: Mapped[Optional[int]] = mapped_column(ForeignKey("jobs.id"), nullable=True, index=True)
+ input_json: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
+ output_json: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
+ error: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
+ started_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=now_utc, nullable=False, index=True)
+ updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=now_utc, nullable=False)  # NOTE(review): insert-time default only (no onupdate); presumably refreshed by callers, confirm
+ finished_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
+ # Optional many-to-one lookups; none of them declares back_populates.
+ source: Mapped[Optional[Source]] = relationship()
+ dataset: Mapped[Optional[Dataset]] = relationship()
+ job: Mapped[Optional[Job]] = relationship()
+
+
+class GtfsAgency(Base):
+ __tablename__ = "gtfs_agencies"
+ __table_args__ = (UniqueConstraint("dataset_id", "agency_id", name="uq_gtfs_agency_dataset_id"),)
+ # ORM model for "gtfs_agencies": an agency row scoped to one dataset; (dataset_id, agency_id) is unique.
+ id: Mapped[int] = mapped_column(Integer, primary_key=True)
+ dataset_id: Mapped[int] = mapped_column(ForeignKey("datasets.id"), nullable=False, index=True)
+ agency_id: Mapped[str] = mapped_column(String(255), nullable=False)
+ name: Mapped[str] = mapped_column(Text, nullable=False)
+ url: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
+ timezone: Mapped[Optional[str]] = mapped_column(String(128), nullable=True)  # class attribute only; now_utc() still resolves the module-level datetime.timezone
+
+
+class GtfsStop(Base):
+ __tablename__ = "gtfs_stops"
+ __table_args__ = (UniqueConstraint("dataset_id", "stop_id", name="uq_gtfs_stop_dataset_id"),)
+ # ORM model for "gtfs_stops": a stop scoped to one dataset, unique per (dataset_id, stop_id); name, coordinates and parent_station are optional.
+ id: Mapped[int] = mapped_column(Integer, primary_key=True)
+ dataset_id: Mapped[int] = mapped_column(ForeignKey("datasets.id"), nullable=False, index=True)
+ stop_id: Mapped[str] = mapped_column(String(255), nullable=False)
+ name: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
+ lat: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
+ lon: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
+ parent_station: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)  # plain string, no ForeignKey; presumably another stop_id, confirm
+
+
+class GtfsRoute(Base):
+ __tablename__ = "gtfs_routes"
+ __table_args__ = (UniqueConstraint("dataset_id", "route_id", name="uq_gtfs_route_dataset_id"),)
+ # ORM model for "gtfs_routes": a route scoped to one dataset, unique per (dataset_id, route_id), with optional geometry text and bounding box.
+ id: Mapped[int] = mapped_column(Integer, primary_key=True)
+ dataset_id: Mapped[int] = mapped_column(ForeignKey("datasets.id"), nullable=False, index=True)
+ route_id: Mapped[str] = mapped_column(String(255), nullable=False)
+ agency_id: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)  # plain string, no ForeignKey to gtfs_agencies
+ short_name: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
+ long_name: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
+ route_type: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
+ mode: Mapped[Optional[str]] = mapped_column(String(64), nullable=True, index=True)
+ route_scope: Mapped[Optional[str]] = mapped_column(String(64), nullable=True, index=True)
+ operator_name: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
+ geometry_geojson: Mapped[Optional[str]] = mapped_column(Text, nullable=True)  # geometry kept as text; presumably a GeoJSON string, confirm
+ min_lon: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
+ min_lat: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
+ max_lon: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
+ max_lat: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
+ route_key: Mapped[Optional[str]] = mapped_column(Text, nullable=True, index=True)  # presumably a normalized matching key; verify against the code that fills it
+ operator_key: Mapped[Optional[str]] = mapped_column(Text, nullable=True, index=True)
+
+
+class GtfsTrip(Base):
+ __tablename__ = "gtfs_trips"
+ __table_args__ = (UniqueConstraint("dataset_id", "trip_id", name="uq_gtfs_trip_dataset_id"),)
+ # ORM model for "gtfs_trips": a trip unique per (dataset_id, trip_id); route_id, service_id and shape_id are plain strings without foreign keys.
+ id: Mapped[int] = mapped_column(Integer, primary_key=True)
+ dataset_id: Mapped[int] = mapped_column(ForeignKey("datasets.id"), nullable=False, index=True)
+ route_id: Mapped[str] = mapped_column(String(255), nullable=False, index=True)
+ trip_id: Mapped[str] = mapped_column(String(255), nullable=False)
+ service_id: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
+ shape_id: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
+
+
+class GtfsCalendar(Base):
+ __tablename__ = "gtfs_calendars"
+ __table_args__ = (UniqueConstraint("dataset_id", "service_id", name="uq_gtfs_calendar_dataset_service"),)
+ # ORM model for "gtfs_calendars": seven weekday flags plus an integer start_date/end_date pair, unique per (dataset_id, service_id).
+ id: Mapped[int] = mapped_column(Integer, primary_key=True)
+ dataset_id: Mapped[int] = mapped_column(ForeignKey("datasets.id"), nullable=False, index=True)
+ service_id: Mapped[str] = mapped_column(String(255), nullable=False, index=True)
+ monday: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
+ tuesday: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
+ wednesday: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
+ thursday: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
+ friday: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
+ saturday: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
+ sunday: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
+ start_date: Mapped[int] = mapped_column(Integer, nullable=False, index=True)  # integer-encoded date; presumably YYYYMMDD as in GTFS, confirm
+ end_date: Mapped[int] = mapped_column(Integer, nullable=False, index=True)  # same encoding as start_date
+
+
+class GtfsCalendarDate(Base):
+ __tablename__ = "gtfs_calendar_dates"
+ __table_args__ = (UniqueConstraint("dataset_id", "service_id", "date", name="uq_gtfs_calendar_date_dataset_service_date"),)
+ # ORM model for "gtfs_calendar_dates": a per-date entry for a service, unique per (dataset_id, service_id, date).
+ id: Mapped[int] = mapped_column(Integer, primary_key=True)
+ dataset_id: Mapped[int] = mapped_column(ForeignKey("datasets.id"), nullable=False, index=True)
+ service_id: Mapped[str] = mapped_column(String(255), nullable=False, index=True)
+ date: Mapped[int] = mapped_column(Integer, nullable=False, index=True)  # integer-encoded date; presumably YYYYMMDD as in GTFS, confirm
+ exception_type: Mapped[int] = mapped_column(Integer, nullable=False)  # presumably the GTFS codes (1 = service added, 2 = removed); confirm
+
+
+class GtfsShape(Base):
+ __tablename__ = "gtfs_shapes"
+ __table_args__ = (UniqueConstraint("dataset_id", "shape_id", name="uq_gtfs_shape_dataset_id"),)
+ # ORM model for "gtfs_shapes": required geometry text plus an optional bounding box, unique per (dataset_id, shape_id).
+ id: Mapped[int] = mapped_column(Integer, primary_key=True)
+ dataset_id: Mapped[int] = mapped_column(ForeignKey("datasets.id"), nullable=False, index=True)
+ shape_id: Mapped[str] = mapped_column(String(255), nullable=False, index=True)
+ geometry_geojson: Mapped[str] = mapped_column(Text, nullable=False)  # required here, unlike the nullable geometry on GtfsRoute
+ min_lon: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
+ min_lat: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
+ max_lon: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
+ max_lat: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
+
+
+class GtfsStopTime(Base):
+ __tablename__ = "gtfs_stop_times"
+ # ORM model for "gtfs_stop_times": one stop visit of a trip, with times kept both as text and as integer *_seconds; no uniqueness constraint is declared.
+ id: Mapped[int] = mapped_column(Integer, primary_key=True)
+ dataset_id: Mapped[int] = mapped_column(ForeignKey("datasets.id"), nullable=False, index=True)
+ trip_id: Mapped[str] = mapped_column(String(255), nullable=False, index=True)
+ stop_id: Mapped[str] = mapped_column(String(255), nullable=False)
+ stop_sequence: Mapped[int] = mapped_column(Integer, nullable=False)
+ arrival_time: Mapped[Optional[str]] = mapped_column(String(32), nullable=True)
+ departure_time: Mapped[Optional[str]] = mapped_column(String(32), nullable=True)
+ arrival_seconds: Mapped[Optional[int]] = mapped_column(Integer, nullable=True, index=True)  # presumably arrival_time converted to seconds; reference point not visible here, confirm
+ departure_seconds: Mapped[Optional[int]] = mapped_column(Integer, nullable=True, index=True)
+
+
+class CanonicalStop(Base):
+ __tablename__ = "canonical_stops"
+ # ORM model for "canonical_stops": a stop identified by a unique stop_key, with a display name, a normalized name and an optional position.
+ id: Mapped[int] = mapped_column(Integer, primary_key=True)
+ stop_key: Mapped[str] = mapped_column(String(255), nullable=False, unique=True, index=True)
+ name: Mapped[str] = mapped_column(Text, nullable=False)
+ normalized_name: Mapped[str] = mapped_column(Text, nullable=False, index=True)  # normalization rules live outside this module
+ lat: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
+ lon: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
+ mode: Mapped[Optional[str]] = mapped_column(String(64), nullable=True, index=True)
+ metadata_json: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
+ created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=now_utc, nullable=False)
+
+
+class CanonicalStopLink(Base):
+ __tablename__ = "canonical_stop_links"
+ __table_args__ = (
+ UniqueConstraint("object_type", "dataset_id", "object_id", name="uq_canonical_stop_link_object"),
+ )
+ # ORM model for "canonical_stop_links": ties one source object (object_type + dataset_id + object_id, unique) to a CanonicalStop.
+ id: Mapped[int] = mapped_column(Integer, primary_key=True)
+ canonical_stop_id: Mapped[int] = mapped_column(ForeignKey("canonical_stops.id"), nullable=False, index=True)
+ layer: Mapped[str] = mapped_column(String(64), nullable=False, index=True) # timetable, visual
+ object_type: Mapped[str] = mapped_column(String(64), nullable=False, index=True) # gtfs_stop, osm_feature
+ dataset_id: Mapped[int] = mapped_column(ForeignKey("datasets.id"), nullable=False, index=True)
+ object_id: Mapped[int] = mapped_column(Integer, nullable=False, index=True)  # integer id without a ForeignKey; target table presumably depends on object_type, confirm
+ external_id: Mapped[str] = mapped_column(Text, nullable=False)
+ role: Mapped[Optional[str]] = mapped_column(String(64), nullable=True)
+ confidence: Mapped[float] = mapped_column(Float, nullable=False, default=1.0)
+ distance_m: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
+ metadata_json: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
+ # One-way many-to-one; CanonicalStop declares no reverse collection.
+ canonical_stop: Mapped[CanonicalStop] = relationship()
+
+
+class RoutePattern(Base):
+ __tablename__ = "route_patterns"
+ # ORM model for "route_patterns": a route geometry record identified by a unique pattern_key, optionally backed by an OsmFeature and/or a GtfsRoute.
+ id: Mapped[int] = mapped_column(Integer, primary_key=True)
+ pattern_key: Mapped[str] = mapped_column(String(255), nullable=False, unique=True, index=True)
+ route_ref: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
+ route_name: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
+ mode: Mapped[Optional[str]] = mapped_column(String(64), nullable=True, index=True)
+ route_scope: Mapped[Optional[str]] = mapped_column(String(64), nullable=True, index=True)
+ operator_name: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
+ source_kind: Mapped[str] = mapped_column(String(64), nullable=False, index=True) # osm, gtfs_proposed
+ status: Mapped[str] = mapped_column(String(64), nullable=False, default="active", index=True)
+ osm_feature_id: Mapped[Optional[int]] = mapped_column(ForeignKey("osm_features.id"), nullable=True, index=True)
+ gtfs_route_id: Mapped[Optional[int]] = mapped_column(ForeignKey("gtfs_routes.id"), nullable=True, index=True)
+ gtfs_shape_id: Mapped[Optional[str]] = mapped_column(String(255), nullable=True, index=True)  # string identifier, no ForeignKey to gtfs_shapes
+ geometry_geojson: Mapped[str] = mapped_column(Text, nullable=False)
+ min_lon: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
+ min_lat: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
+ max_lon: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
+ max_lat: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
+ confidence: Mapped[float] = mapped_column(Float, nullable=False, default=1.0)
+ metadata_json: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
+ created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=now_utc, nullable=False)
+ # One-way many-to-one lookups; OsmFeature is referenced by name because it is defined further down in this module.
+ osm_feature: Mapped[Optional["OsmFeature"]] = relationship()
+ gtfs_route: Mapped[Optional[GtfsRoute]] = relationship()
+
+
+class RoutePatternStop(Base):
+ __tablename__ = "route_pattern_stops"
+ __table_args__ = (UniqueConstraint("route_pattern_id", "sequence", name="uq_route_pattern_stop_sequence"),)
+ # ORM model for "route_pattern_stops": a CanonicalStop placed at a sequence position of a RoutePattern; sequence is unique within a pattern.
+ id: Mapped[int] = mapped_column(Integer, primary_key=True)
+ route_pattern_id: Mapped[int] = mapped_column(ForeignKey("route_patterns.id"), nullable=False, index=True)
+ canonical_stop_id: Mapped[int] = mapped_column(ForeignKey("canonical_stops.id"), nullable=False, index=True)
+ sequence: Mapped[int] = mapped_column(Integer, nullable=False)
+ distance_along: Mapped[Optional[float]] = mapped_column(Float, nullable=True)  # units are not visible here; TODO confirm
+ source_kind: Mapped[str] = mapped_column(String(64), nullable=False, default="timetable")
+ confidence: Mapped[float] = mapped_column(Float, nullable=False, default=1.0)
+ # One-way many-to-one lookups without back_populates.
+ route_pattern: Mapped[RoutePattern] = relationship()
+ canonical_stop: Mapped[CanonicalStop] = relationship()
+
+
+class GtfsRoutePatternLink(Base):
+ __tablename__ = "gtfs_route_pattern_links"
+ __table_args__ = (UniqueConstraint("dataset_id", "route_id", "shape_id", name="uq_gtfs_route_pattern_shape"),)
+ # ORM model for "gtfs_route_pattern_links": links a GTFS route/shape pair to a RoutePattern, unique per (dataset_id, route_id, shape_id).
+ id: Mapped[int] = mapped_column(Integer, primary_key=True)
+ dataset_id: Mapped[int] = mapped_column(ForeignKey("datasets.id"), nullable=False, index=True)
+ gtfs_route_id: Mapped[int] = mapped_column(ForeignKey("gtfs_routes.id"), nullable=False, index=True)  # row id in gtfs_routes
+ route_id: Mapped[str] = mapped_column(String(255), nullable=False, index=True)  # string identifier, no ForeignKey
+ shape_id: Mapped[str] = mapped_column(String(255), nullable=False, index=True)
+ route_pattern_id: Mapped[int] = mapped_column(ForeignKey("route_patterns.id"), nullable=False, index=True)
+ confidence: Mapped[float] = mapped_column(Float, nullable=False, default=0)
+ status: Mapped[str] = mapped_column(String(64), nullable=False, default="linked", index=True)
+ source_kind: Mapped[str] = mapped_column(String(64), nullable=False)
+ reasons_json: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
+ # One-way many-to-one lookups without back_populates.
+ gtfs_route: Mapped[GtfsRoute] = relationship()
+ route_pattern: Mapped[RoutePattern] = relationship()
+
+
+class GtfsTripRoutePatternLink(Base):
+ __tablename__ = "gtfs_trip_route_pattern_links"
+ __table_args__ = (UniqueConstraint("dataset_id", "trip_id", name="uq_gtfs_trip_route_pattern"),)
+ # ORM model for "gtfs_trip_route_pattern_links": at most one RoutePattern assignment per (dataset_id, trip_id).
+ id: Mapped[int] = mapped_column(Integer, primary_key=True)
+ dataset_id: Mapped[int] = mapped_column(ForeignKey("datasets.id"), nullable=False, index=True)
+ trip_id: Mapped[str] = mapped_column(String(255), nullable=False, index=True)
+ route_id: Mapped[str] = mapped_column(String(255), nullable=False, index=True)
+ shape_id: Mapped[str] = mapped_column(String(255), nullable=False, index=True)  # required here, although GtfsTrip.shape_id is nullable
+ route_pattern_id: Mapped[int] = mapped_column(ForeignKey("route_patterns.id"), nullable=False, index=True)
+ source_kind: Mapped[str] = mapped_column(String(64), nullable=False)
+ confidence: Mapped[float] = mapped_column(Float, nullable=False, default=0)
+ status: Mapped[str] = mapped_column(String(64), nullable=False, default="linked", index=True)
+ # One-way many-to-one lookup without back_populates.
+ route_pattern: Mapped[RoutePattern] = relationship()
+
+
+class OsmFeature(Base):
+ __tablename__ = "osm_features"
+ __table_args__ = (UniqueConstraint("dataset_id", "osm_type", "osm_id", name="uq_osm_feature_dataset_type_id"),)
+ # ORM model for "osm_features": an OSM object unique per (dataset_id, osm_type, osm_id), classified by kind and mode, with optional geometry, bbox and tags text.
+ id: Mapped[int] = mapped_column(Integer, primary_key=True)
+ dataset_id: Mapped[int] = mapped_column(ForeignKey("datasets.id"), nullable=False, index=True)
+ osm_type: Mapped[str] = mapped_column(String(32), nullable=False)
+ osm_id: Mapped[str] = mapped_column(String(64), nullable=False)  # kept as a string here; the routing tables store OSM ids as BigInteger
+ kind: Mapped[str] = mapped_column(String(64), nullable=False, index=True) # route, stop, terminal, station, infra
+ mode: Mapped[Optional[str]] = mapped_column(String(64), nullable=True, index=True)
+ route_scope: Mapped[Optional[str]] = mapped_column(String(64), nullable=True, index=True)
+ name: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
+ ref: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
+ operator: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
+ network: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
+ geometry_geojson: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
+ min_lon: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
+ min_lat: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
+ max_lon: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
+ max_lat: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
+ tags_json: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
+ route_key: Mapped[Optional[str]] = mapped_column(Text, nullable=True, index=True)  # same column names as on GtfsRoute; presumably the shared matching keys, confirm
+ operator_key: Mapped[Optional[str]] = mapped_column(Text, nullable=True, index=True)
+
+
+class OsmAddress(Base):
+ __tablename__ = "osm_addresses"
+ __table_args__ = (UniqueConstraint("dataset_id", "osm_type", "osm_id", name="uq_osm_address_dataset_type_id"),)
+ # ORM model for "osm_addresses": an address record unique per (dataset_id, osm_type, osm_id) with required display_name, search_text and lat/lon.
+ id: Mapped[int] = mapped_column(Integer, primary_key=True)
+ dataset_id: Mapped[int] = mapped_column(ForeignKey("datasets.id"), nullable=False, index=True)
+ osm_type: Mapped[str] = mapped_column(String(32), nullable=False)
+ osm_id: Mapped[str] = mapped_column(String(64), nullable=False)
+ housenumber: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
+ street: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
+ place: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
+ postcode: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
+ city: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
+ country: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
+ unit: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
+ name: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
+ display_name: Mapped[str] = mapped_column(Text, nullable=False)
+ search_text: Mapped[str] = mapped_column(Text, nullable=False)  # no index declared on this column in the model
+ lat: Mapped[float] = mapped_column(Float, nullable=False)
+ lon: Mapped[float] = mapped_column(Float, nullable=False)
+ min_lon: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
+ min_lat: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
+ max_lon: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
+ max_lat: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
+ geometry_geojson: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
+ tags_json: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
+ created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=now_utc, nullable=False)
+
+
+class RoutingNode(Base):
+ __tablename__ = "routing_nodes"
+ __table_args__ = (UniqueConstraint("dataset_id", "osm_node_id", name="uq_routing_node_dataset_osm"),)
+ # ORM model for "routing_nodes": a positioned node unique per (dataset_id, osm_node_id).
+ id: Mapped[int] = mapped_column(Integer, primary_key=True)
+ dataset_id: Mapped[int] = mapped_column(ForeignKey("datasets.id"), nullable=False, index=True)
+ osm_node_id: Mapped[int] = mapped_column(BigInteger, nullable=False, index=True)  # 64-bit column for the OSM node id
+ lat: Mapped[float] = mapped_column(Float, nullable=False)
+ lon: Mapped[float] = mapped_column(Float, nullable=False)
+ created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=now_utc, nullable=False)
+
+
+class RoutingEdge(Base):
+ __tablename__ = "routing_edges"
+ # ORM model for "routing_edges": a segment between a source and a target OSM node with its length, optional forward/reverse walk and drive costs, geometry and a required bbox.
+ id: Mapped[int] = mapped_column(Integer, primary_key=True)
+ dataset_id: Mapped[int] = mapped_column(ForeignKey("datasets.id"), nullable=False, index=True)
+ osm_way_id: Mapped[int] = mapped_column(BigInteger, nullable=False, index=True)
+ source_osm_node_id: Mapped[int] = mapped_column(BigInteger, nullable=False, index=True)  # OSM node id, not a ForeignKey to routing_nodes
+ target_osm_node_id: Mapped[int] = mapped_column(BigInteger, nullable=False, index=True)
+ source_lat: Mapped[float] = mapped_column(Float, nullable=False)
+ source_lon: Mapped[float] = mapped_column(Float, nullable=False)
+ target_lat: Mapped[float] = mapped_column(Float, nullable=False)
+ target_lon: Mapped[float] = mapped_column(Float, nullable=False)
+ highway: Mapped[Optional[str]] = mapped_column(String(64), nullable=True, index=True)
+ name: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
+ length_m: Mapped[float] = mapped_column(Float, nullable=False)
+ walk_cost_s: Mapped[Optional[float]] = mapped_column(Float, nullable=True)  # nullable; presumably None means not usable in that mode/direction, confirm
+ reverse_walk_cost_s: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
+ drive_cost_s: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
+ reverse_drive_cost_s: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
+ geometry_geojson: Mapped[str] = mapped_column(Text, nullable=False)
+ min_lon: Mapped[float] = mapped_column(Float, nullable=False)
+ min_lat: Mapped[float] = mapped_column(Float, nullable=False)
+ max_lon: Mapped[float] = mapped_column(Float, nullable=False)
+ max_lat: Mapped[float] = mapped_column(Float, nullable=False)
+ tags_json: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
+ created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=now_utc, nullable=False)
+
+
+class RouteMatch(Base):
+ __tablename__ = "route_matches"
+ # ORM model for "route_matches": a matching outcome for one GtfsRoute, optionally pointing at an OsmFeature; no uniqueness on gtfs_route_id is declared.
+ id: Mapped[int] = mapped_column(Integer, primary_key=True)
+ gtfs_route_id: Mapped[int] = mapped_column(ForeignKey("gtfs_routes.id"), nullable=False, index=True)
+ osm_feature_id: Mapped[Optional[int]] = mapped_column(ForeignKey("osm_features.id"), nullable=True, index=True)
+ confidence: Mapped[float] = mapped_column(Float, nullable=False, default=0)
+ status: Mapped[str] = mapped_column(String(64), nullable=False) # matched, probable, weak, missing, accepted, rejected
+ rule_source: Mapped[str] = mapped_column(String(64), default="auto", nullable=False) # auto, manual
+ reasons_json: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
+ created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=now_utc, nullable=False)
+ updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=now_utc, nullable=False)  # NOTE(review): insert-time default only (no onupdate); presumably refreshed by callers, confirm
+ # One-way many-to-one lookups without back_populates.
+ gtfs_route: Mapped[GtfsRoute] = relationship()
+ osm_feature: Mapped[Optional[OsmFeature]] = relationship()
+
+
+class MatchRule(Base):  # Stored matching rule: a selector and an action (both text columns), a free-form note and an on/off flag.
+    __tablename__ = "match_rules"
+
+    id: Mapped[int] = mapped_column(Integer, primary_key=True)
+    rule_type: Mapped[str] = mapped_column(String(64), nullable=False)  # accept_match, reject_match, alias, force_operator
+    selector_json: Mapped[str] = mapped_column(Text, nullable=False)  # presumably JSON text; schema is not visible in this module
+    action_json: Mapped[str] = mapped_column(Text, nullable=False)  # presumably JSON text; schema is not visible in this module
+    note: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
+    active: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False)  # new rules default to active
+    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=now_utc, nullable=False)
+
+
+class JourneySearchCache(Base):  # Keyed cache row: a text payload stored under a unique cache_key with an explicit expiry timestamp.
+    __tablename__ = "journey_search_cache"
+
+    id: Mapped[int] = mapped_column(Integer, primary_key=True)
+    cache_key: Mapped[str] = mapped_column(String(128), nullable=False, unique=True, index=True)  # unique lookup key
+    cache_type: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
+    payload_json: Mapped[str] = mapped_column(Text, nullable=False)
+    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=now_utc, nullable=False, index=True)
+    updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=now_utc, onupdate=now_utc, nullable=False)  # onupdate keeps the timestamp current when a cached payload is rewritten
+    expires_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, index=True)  # required; no default, the writer must supply it
+
+
+class TravelRequest(Base):  # One stored journey query (origin, destination, optional via, timing and transfer limits) owning its itineraries.
+    __tablename__ = "travel_requests"
+
+    id: Mapped[int] = mapped_column(Integer, primary_key=True)
+    origin_stop_id: Mapped[str] = mapped_column(Text, nullable=False)
+    destination_stop_id: Mapped[str] = mapped_column(Text, nullable=False)
+    via_stop_id: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
+    departure_time: Mapped[str] = mapped_column(String(32), nullable=False)  # stored as text; the exact format is not defined here
+    service_date: Mapped[Optional[str]] = mapped_column(String(10), nullable=True, index=True)  # 10 chars; presumably an ISO date (YYYY-MM-DD) -- TODO confirm
+    max_transfers: Mapped[int] = mapped_column(Integer, nullable=False, default=1)
+    transfer_seconds: Mapped[int] = mapped_column(Integer, nullable=False, default=120)
+    source_filter: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
+    preferences_json: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
+    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=now_utc, nullable=False, index=True)
+
+    itineraries: Mapped[list["Itinerary"]] = relationship(back_populates="request", cascade="all, delete-orphan")  # ORM-level cascade: itineraries are deleted with their request
+
+
+class Itinerary(Base):  # One journey option produced for a TravelRequest; owns an ordered list of legs.
+    __tablename__ = "itineraries"
+
+    id: Mapped[int] = mapped_column(Integer, primary_key=True)
+    request_id: Mapped[int] = mapped_column(ForeignKey("travel_requests.id"), nullable=False, index=True)
+    title: Mapped[str] = mapped_column(Text, nullable=False)
+    family: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
+    status: Mapped[str] = mapped_column(String(64), nullable=False, default="candidate", index=True)  # starts as "candidate"
+    saved: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False, index=True)
+    summary_json: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
+    score_json: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
+    payload_json: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
+    created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=now_utc, nullable=False, index=True)
+    updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=now_utc, onupdate=now_utc, nullable=False)  # onupdate keeps the timestamp current on every UPDATE (e.g. status/saved changes)
+
+    request: Mapped[TravelRequest] = relationship(back_populates="itineraries")
+    legs: Mapped[list["ItineraryLeg"]] = relationship(back_populates="itinerary", cascade="all, delete-orphan")  # ORM-level cascade: legs are deleted with their itinerary
+
+
+class ItineraryLeg(Base):  # One leg of an itinerary, positioned by `sequence`.
+    __tablename__ = "itinerary_legs"
+    __table_args__ = (UniqueConstraint("itinerary_id", "sequence", name="uq_itinerary_leg_sequence"),)  # a sequence number may appear only once per itinerary
+
+    id: Mapped[int] = mapped_column(Integer, primary_key=True)
+    itinerary_id: Mapped[int] = mapped_column(ForeignKey("itineraries.id"), nullable=False, index=True)
+    sequence: Mapped[int] = mapped_column(Integer, nullable=False)
+    mode: Mapped[Optional[str]] = mapped_column(String(64), nullable=True, index=True)
+    route_ref: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
+    route_name: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
+    from_name: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
+    to_name: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
+    departure_time: Mapped[Optional[str]] = mapped_column(String(32), nullable=True)  # stored as text; format not defined here
+    arrival_time: Mapped[Optional[str]] = mapped_column(String(32), nullable=True)  # stored as text; format not defined here
+    locked: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False, index=True)
+    payload_json: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
+
+    itinerary: Mapped[Itinerary] = relationship(back_populates="legs")
diff --git a/app/osm_classification.py b/app/osm_classification.py
new file mode 100644
index 0000000..7569c8f
--- /dev/null
+++ b/app/osm_classification.py
@@ -0,0 +1,111 @@
+from __future__ import annotations
+
+import json
+import re
+from typing import Mapping
+
+
+LOCAL_SCOPE = "local"
+REGIONAL_SCOPE = "regional"
+LONG_DISTANCE_SCOPE = "long_distance"
+UNKNOWN_SCOPE = "unknown"  # returned for trains that match none of the heuristics
+OSM_ROUTE_SCOPE_CLASSIFIER_VERSION = "route_scope_v2"
+
+BUS_MODES = {"bus", "trolleybus"}  # classified by text markers, falling back to local
+LOCAL_MODES = {"tram", "light_rail", "subway", "ferry", "funicular", "aerialway", "monorail"}  # always local unless a service tag says otherwise
+LONG_DISTANCE_MODES = {"coach"}  # always long-distance unless a service tag says otherwise
+
+LONG_DISTANCE_SERVICE_VALUES = {  # lower-cased tag values that mark a route as long-distance
+    "high_speed",
+    "long_distance",
+    "intercity",
+    "international",
+    "night",
+    "sleeper",
+}
+REGIONAL_SERVICE_VALUES = {"regional", "interurban", "commuter", "branch", "suburban"}
+LOCAL_SERVICE_VALUES = {"local", "urban", "city", "subway", "tram", "light_rail", "s-bahn", "sbahn"}
+
+LONG_DISTANCE_PREFIX_RE = re.compile(r"^(ICE|IC|EC|ECE|EN|NJ|RJ|RJX|TGV|THA|EST|FLX|WB)\b|^(ICE|IC|EC|ECE|EN|NJ|RJ|RJX|TGV|THA|EST|FLX|WB)\d")  # prefix at the start of the text, followed by a word boundary or a digit
+REGIONAL_PREFIX_RE = re.compile(r"^(IRE|RE|RB|RER|TER|REX|MEX|ALX|WFB|R)\b|^(IRE|RE|RB|RER|TER|REX|MEX|ALX|WFB|R)\d")  # same shape for regional prefixes
+LOCAL_TRAIN_PREFIX_RE = re.compile(r"^(S|S-BAHN)\b|^S\d")  # "S", "S-BAHN" or "S<digit>" at the start of the text
+
+
+def infer_osm_route_scope(
+    *,
+    mode: str | None,
+    ref: str | None = None,
+    name: str | None = None,
+    network: str | None = None,
+    tags: Mapping[str, object] | str | None = None,  # mapping or JSON object text; anything else counts as "no tags"
+) -> str | None:
+    """Classify a public-transport route as "local", "regional", "long_distance" or "unknown".
+
+    Explicit service tags win (long-distance, then local, then regional); semicolon-separated OSM
+    multi-values such as "regional;night" are split so each part is checked. Otherwise the mode and
+    conservative ref/name/network heuristics decide. Returns None for untagged modes it has no rule for.
+    """
+    normalized_mode = (mode or "").strip().lower()
+    tags_dict = _tags_dict(tags)
+    values = {
+        part.strip().lower()
+        for key in ("service", "train", "bus", "passenger", "network:type", "route_scope")
+        for part in str(tags_dict.get(key) or "").split(";")  # OSM packs multiple values as "a;b"
+        if part.strip()
+    }
+    if values & LONG_DISTANCE_SERVICE_VALUES:
+        return LONG_DISTANCE_SCOPE
+    if values & LOCAL_SERVICE_VALUES:
+        return LOCAL_SCOPE
+    if values & REGIONAL_SERVICE_VALUES:
+        return REGIONAL_SCOPE
+    if normalized_mode in LOCAL_MODES:
+        return LOCAL_SCOPE
+    if normalized_mode in LONG_DISTANCE_MODES:
+        return LONG_DISTANCE_SCOPE
+
+    text = _classification_text(ref, name, network, tags_dict)  # upper-cased "ref name network ..." haystack
+    if normalized_mode in BUS_MODES:
+        if any(marker in text for marker in ("FLIXBUS", "EUROLINES", "INTERCITYBUS", "IC BUS", "LONG DISTANCE", "FERNBUS")):
+            return LONG_DISTANCE_SCOPE
+        if any(marker in text for marker in ("REGIONALBUS", "REGIOBUS", "REGIONAL BUS", "REGIONALVERKEHR", "REGIONAL VERKEHR")):
+            return REGIONAL_SCOPE
+        return LOCAL_SCOPE  # buses without a marker default to local
+
+    if normalized_mode == "train":
+        if LONG_DISTANCE_PREFIX_RE.search(text) or any(marker in text for marker in ("INTERCITY", "EUROCITY", "NIGHTJET", "FLIXTRAIN")):
+            return LONG_DISTANCE_SCOPE
+        if LOCAL_TRAIN_PREFIX_RE.search(text) or "S-BAHN" in text or "SBAHN" in text:
+            return LOCAL_SCOPE
+        if REGIONAL_PREFIX_RE.search(text) or any(marker in text for marker in ("REGIONAL", "REGIO", "REGIONALBAHN", "REGIONALEXPRESS")):
+            return REGIONAL_SCOPE
+        return UNKNOWN_SCOPE  # a train that matched nothing is reported as unknown rather than guessed
+
+    return None  # no rule for this mode
+
+
+def infer_osm_route_scope_from_tags(mode: str | None, ref: str | None, name: str | None, network: str | None, tags_json: str | None) -> str | None:  # Positional-argument wrapper around infer_osm_route_scope; tags_json is passed through as `tags`.
+    return infer_osm_route_scope(mode=mode, ref=ref, name=name, network=network, tags=tags_json)
+
+
+def _tags_dict(tags: Mapping[str, object] | str | None) -> dict[str, object]:  # Normalise tags given as JSON text, a mapping or None into a plain dict.
+    if isinstance(tags, str):
+        try:
+            data = json.loads(tags or "{}")  # empty string is read as an empty object
+        except json.JSONDecodeError:
+            return {}  # malformed JSON is treated as "no tags"
+        return data if isinstance(data, dict) else {}  # JSON arrays/scalars are ignored
+    if isinstance(tags, Mapping):
+        return dict(tags)  # shallow copy
+    return {}
+
+
+def _classification_text(ref: str | None, name: str | None, network: str | None, tags: Mapping[str, object]) -> str:  # Build the upper-cased haystack used by the marker and prefix heuristics.
+    parts = [
+        ref or "",  # explicit arguments come first, so prefix regexes anchor on ref when it is set
+        name or "",
+        network or "",
+        str(tags.get("ref") or ""),
+        str(tags.get("name") or ""),
+        str(tags.get("network") or ""),
+        str(tags.get("network:short") or ""),
+    ]
+    return " ".join(parts).strip().upper().replace("_", " ")  # underscores become spaces; empty parts leave repeated spaces in the middle
diff --git a/app/osm_storage.py b/app/osm_storage.py
new file mode 100644
index 0000000..a6c9928
--- /dev/null
+++ b/app/osm_storage.py
@@ -0,0 +1,981 @@
+from __future__ import annotations
+
+import json
+import sqlite3
+from contextlib import contextmanager
+from pathlib import Path
+from typing import Iterator, Sequence
+
+from sqlalchemy import and_, func, insert, not_, or_, select, text
+from sqlalchemy.dialects.postgresql import insert as postgresql_insert
+from sqlalchemy.orm import Session
+
+from app.config import settings
+from app.models import Dataset, OsmFeature
+from app.spatial import refresh_postgis_geometries
+
+
+OSM_STORAGE_METADATA_KEY = "osm_storage"  # key inside Dataset.metadata_json that holds the storage descriptor
+OSM_STORAGE_MAIN = "main"
+OSM_STORAGE_SIDECAR_FEATURES = "sidecar_features"
+SQLITE_IN_CHUNK_SIZE = 800  # not used in the visible part of this module; presumably caps bound values per IN (...) -- TODO confirm
+OSM_SIDECAR_ROUTE_SCOPE_INDEXES = ["ix_osm_sidecar_scope_bbox"]  # indexes dropped by drop_osm_sidecar_route_scope_indexes
+OSM_FEATURE_COLUMNS = [  # sidecar column order used for INSERT (after the leading "id")
+    "dataset_id",
+    "osm_type",
+    "osm_id",
+    "kind",
+    "mode",
+    "route_scope",
+    "name",
+    "ref",
+    "operator",
+    "network",
+    "geometry_geojson",
+    "min_lon",
+    "min_lat",
+    "max_lon",
+    "max_lat",
+    "tags_json",
+    "route_key",
+    "operator_key",
+]
+
+
+def effective_osm_feature_storage(value: str | None = None) -> str:  # Resolve the storage mode from `value`, then settings, then the sidecar default.
+    configured = str(value or settings.osm_feature_storage or OSM_STORAGE_SIDECAR_FEATURES).strip().lower()
+    if configured in {OSM_STORAGE_MAIN, "main", "main_db", "postgres", "postgresql"}:  # aliases for main-table storage (OSM_STORAGE_MAIN is itself "main")
+        return OSM_STORAGE_MAIN
+    if settings.is_postgresql_database and not settings.postgres_use_sidecars:  # PostgreSQL without sidecars forces main storage
+        return OSM_STORAGE_MAIN
+    return OSM_STORAGE_SIDECAR_FEATURES  # every other value, including unrecognised ones, means sidecar storage
+
+
+class MissingOsmSidecar(FileNotFoundError):  # Raised when a dataset references no sidecar file or the file does not exist.
+    pass
+
+
+def dataset_metadata(dataset: Dataset) -> dict:  # Parse dataset.metadata_json; always returns a dict.
+    try:
+        metadata = json.loads(dataset.metadata_json or "{}")  # missing/empty metadata reads as {}
+    except json.JSONDecodeError:
+        return {}  # malformed JSON is treated as empty metadata
+    return metadata if isinstance(metadata, dict) else {}  # non-object JSON is ignored
+
+
+def features_are_sidecar(dataset: Dataset | None) -> bool:  # True when the dataset's metadata says its osm_features live in a sidecar file.
+    if dataset is None:
+        return False
+    storage = dataset_metadata(dataset).get(OSM_STORAGE_METADATA_KEY)
+    if not isinstance(storage, dict):
+        return False
+    tables = storage.get("tables")
+    if isinstance(tables, dict):
+        return tables.get("osm_features") == "sidecar"  # the per-table map wins over "mode" when present
+    return storage.get("mode") == OSM_STORAGE_SIDECAR_FEATURES
+
+
+def sidecar_path(dataset: Dataset | None) -> Path | None:  # Return the sidecar file path recorded in the dataset metadata, or None.
+    if dataset is None:
+        return None
+    storage = dataset_metadata(dataset).get(OSM_STORAGE_METADATA_KEY)
+    if not isinstance(storage, dict):
+        return None
+    value = storage.get("sidecar_path")
+    if not value:
+        return None
+    return Path(str(value))  # the path is returned as stored; existence is not checked here
+
+
+def dataset_sidecar_paths(dataset: Dataset) -> list[Path]:  # List form of sidecar_path: zero or one entry.
+    path = sidecar_path(dataset)
+    return [] if path is None else [path]
+
+
+def missing_sidecar_paths(dataset: Dataset | None) -> list[str]:  # Return the recorded sidecar path when it does not exist on disk, else an empty list.
+    if not features_are_sidecar(dataset):
+        return []
+    path = sidecar_path(dataset)
+    if path is None or path.exists():  # a sidecar dataset with no recorded path is not reported as missing
+        return []
+    return [str(path)]
+
+
+@contextmanager
+def sidecar_connection(dataset: Dataset) -> Iterator[sqlite3.Connection]:  # Yield a read-only sqlite3 connection (Row factory) to the dataset's sidecar; raises MissingOsmSidecar if absent.
+    path = sidecar_path(dataset)
+    if path is None:
+        raise MissingOsmSidecar(f"dataset #{dataset.id} does not reference an OSM sidecar")
+    if not path.exists():
+        raise MissingOsmSidecar(f"OSM sidecar does not exist: {path}")
+    connection = sqlite3.connect(f"{path.resolve().as_uri()}?mode=ro", uri=True)  # as_uri() percent-encodes "?", "#", "%" and spaces so they cannot corrupt the URI
+    connection.row_factory = sqlite3.Row
+    try:
+        yield connection
+    finally:
+        connection.close()  # always closed, even when the caller's block raises
+
+
+@contextmanager
+def writable_sidecar_connection(dataset: Dataset) -> Iterator[sqlite3.Connection]:  # Yield a read-write sqlite3 connection (Row factory) to the dataset's sidecar; raises MissingOsmSidecar if absent.
+    path = sidecar_path(dataset)
+    if path is None:
+        raise MissingOsmSidecar(f"dataset #{dataset.id} does not reference an OSM sidecar")
+    if not path.exists():
+        raise MissingOsmSidecar(f"OSM sidecar does not exist: {path}")
+    connection = sqlite3.connect(path)
+    connection.row_factory = sqlite3.Row
+    try:
+        connection.execute(f"PRAGMA busy_timeout={int(settings.sqlite_busy_timeout_ms)}")  # int() keeps the interpolated value numeric
+        connection.execute("PRAGMA synchronous=NORMAL")
+        yield connection  # committing is left to the caller
+    finally:
+        connection.close()
+
+
+def create_osm_sidecar(dataset: Dataset, rows: Sequence[dict[str, object]], *, source_hash: str | None = None) -> dict:  # Build a fresh SQLite sidecar from `rows` and return the storage descriptor dict for the dataset metadata.
+    path = _new_sidecar_path(dataset, source_hash or dataset.sha256)
+    path.parent.mkdir(parents=True, exist_ok=True)
+    if path.exists():
+        path.unlink()  # always rebuild from scratch
+    connection = sqlite3.connect(path)
+    try:
+        connection.execute("PRAGMA journal_mode=OFF")  # bulk-load pragmas; NOTE(review): a failure mid-load leaves a partial file at `path`
+        connection.execute("PRAGMA synchronous=OFF")
+        _create_schema(connection)
+        deduped_rows, duplicate_count = dedupe_osm_feature_rows(rows)
+        inserted = 0
+        counts = {"route": 0, "stop": 0, "station": 0, "terminal": 0, "infra": 0, "feature": 0}  # per-kind totals; unknown kinds are added on the fly
+        insert_sql = f"""
+            INSERT INTO osm_features
+            ({", ".join(["id", *OSM_FEATURE_COLUMNS])})
+            VALUES
+            ({", ".join(["?"] * (len(OSM_FEATURE_COLUMNS) + 1))})
+        """
+        batch = []
+        for index, row in enumerate(deduped_rows, start=1):  # `index` becomes the sidecar row id (1-based)
+            kind = str(row.get("kind") or "feature")
+            counts[kind] = counts.get(kind, 0) + 1
+            batch.append((index, *[row.get(column) for column in OSM_FEATURE_COLUMNS]))  # missing columns are stored as NULL
+            if len(batch) >= 5000:  # flush in batches of 5000 rows
+                connection.executemany(insert_sql, batch)
+                inserted += len(batch)
+                batch.clear()
+        if batch:
+            connection.executemany(insert_sql, batch)
+            inserted += len(batch)
+        connection.commit()
+        _create_indexes(connection)  # indexes are built after the load
+        connection.commit()
+    finally:
+        connection.close()
+    return {
+        "mode": OSM_STORAGE_SIDECAR_FEATURES,
+        "tables": {"osm_features": "sidecar"},
+        "sidecar_path": str(path),
+        "features": inserted,
+        "duplicate_features_skipped": duplicate_count,
+        "counts": counts,
+    }
+
+
+def ensure_osm_sidecar_schema(connection: sqlite3.Connection) -> None:  # Add the route_scope column to a sidecar that does not have it yet.
+    columns = _sidecar_columns(connection)
+    if "route_scope" not in columns:
+        connection.execute("ALTER TABLE osm_features ADD COLUMN route_scope TEXT")
+        connection.commit()
+
+
+def drop_osm_sidecar_route_scope_indexes(connection: sqlite3.Connection) -> None:  # Drop the indexes named in OSM_SIDECAR_ROUTE_SCOPE_INDEXES; does not commit.
+    for index_name in OSM_SIDECAR_ROUTE_SCOPE_INDEXES:
+        connection.execute(f"DROP INDEX IF EXISTS {index_name}")  # names come from the module constant, not from callers
+
+
+def rebuild_osm_sidecar_indexes(connection: sqlite3.Connection) -> None:  # Public wrapper around _create_indexes (CREATE INDEX IF NOT EXISTS); does not commit.
+    _create_indexes(connection)
+
+
+def osm_feature_count(session: Session, dataset_id: int, *, kind: str | Sequence[str] | None = None) -> int:  # Count a dataset's OSM features, optionally restricted to one or more kinds.
+    dataset = session.get(Dataset, dataset_id)
+    if features_are_sidecar(dataset):
+        kinds = _as_list(kind)
+        sql = "SELECT COUNT(*) FROM osm_features"  # no dataset_id filter: every row of the sidecar file is counted
+        params: list[object] = []
+        if kinds:
+            placeholders = ", ".join(["?"] * len(kinds))
+            sql += f" WHERE kind IN ({placeholders})"
+            params.extend(kinds)
+        try:
+            with sidecar_connection(dataset) as connection:
+                return int(connection.execute(sql, params).fetchone()[0] or 0)
+        except MissingOsmSidecar:
+            return 0  # a missing sidecar counts as zero features
+    stmt = select(func.count()).select_from(OsmFeature).where(OsmFeature.dataset_id == dataset_id)
+    kinds = _as_list(kind)
+    if kinds:
+        stmt = stmt.where(OsmFeature.kind.in_(kinds))
+    return int(session.scalar(stmt) or 0)
+
+
+def osm_feature_bbox(  # Combined (min_lon, min_lat, max_lon, max_lat) over main-table and sidecar features; all None when nothing matches.
+    session: Session,
+    dataset_ids: Sequence[int],
+    *,
+    kinds: Sequence[str] | None = None,
+) -> tuple[float | None, float | None, float | None, float | None]:
+    if not dataset_ids:
+        return (None, None, None, None)
+    datasets = {
+        dataset.id: dataset
+        for dataset in session.scalars(select(Dataset).where(Dataset.id.in_([int(value) for value in dataset_ids]))).all()
+    }
+    boxes: list[tuple[float, float, float, float]] = []  # one box per backend that has data
+    main_dataset_ids = [dataset_id for dataset_id, dataset in datasets.items() if not features_are_sidecar(dataset)]
+    if main_dataset_ids:
+        stmt = select(func.min(OsmFeature.min_lon), func.min(OsmFeature.min_lat), func.max(OsmFeature.max_lon), func.max(OsmFeature.max_lat)).where(
+            OsmFeature.dataset_id.in_(main_dataset_ids)
+        )
+        if kinds:
+            stmt = stmt.where(OsmFeature.kind.in_(list(kinds)))
+        row = session.execute(stmt).one()
+        if None not in row:  # the aggregates are NULL when no row has a bbox
+            boxes.append((float(row[0]), float(row[1]), float(row[2]), float(row[3])))
+    for dataset in datasets.values():
+        if not features_are_sidecar(dataset):
+            continue
+        where = []
+        params: list[object] = []
+        if kinds:
+            placeholders = ", ".join(["?"] * len(kinds))
+            where.append(f"kind IN ({placeholders})")
+            params.extend(list(kinds))
+        sql = "SELECT MIN(min_lon), MIN(min_lat), MAX(max_lon), MAX(max_lat) FROM osm_features"
+        if where:
+            sql += " WHERE " + " AND ".join(where)
+        try:
+            with sidecar_connection(dataset) as connection:
+                row = connection.execute(sql, params).fetchone()
+                if row is not None and None not in row:
+                    boxes.append((float(row[0]), float(row[1]), float(row[2]), float(row[3])))
+        except MissingOsmSidecar:
+            continue  # missing sidecars contribute nothing
+    if not boxes:
+        return (None, None, None, None)
+    return (
+        min(box[0] for box in boxes),
+        min(box[1] for box in boxes),
+        max(box[2] for box in boxes),
+        max(box[3] for box in boxes),
+    )
+
+
+def query_osm_features(
+    session: Session,
+    dataset_ids: Sequence[int],
+    *,
+    kinds: Sequence[str] | None = None,
+    modes: Sequence[str] | None = None,
+    bbox: tuple[float, float, float, float] | None = None,
+    geometry_required: bool | None = None,
+    search: str | None = None,
+    route_key: str | None = None,
+    route_scopes: Sequence[str] | None = None,
+    ref: str | None = None,
+    osm_type: str | None = None,
+    osm_id: str | None = None,
+    limit: int | None = None,
+    offset: int | None = None,
+    prefer_materialized_ids: bool = True,
+) -> list[OsmFeature]:
+    """Return OSM features for the given datasets, merged from the main table and per-dataset sidecars.
+
+    Filters are forwarded unchanged to _apply_main_filters and _query_sidecar_features. The result is
+    sorted by (kind, mode, ref, name, id) with missing values treated as "" / 0. limit is clamped to
+    at least 1 for every backend (matching the sidecar query and the final slice), offset to at least 0.
+
+    When more than one backend contributes rows, offset is applied once to the merged ordering: each
+    backend is asked for its first offset + limit rows. Applying offset inside every backend would
+    skip rows that then appear on no page. A single backend still paginates in its own query.
+    """
+    if not dataset_ids:
+        return []
+    datasets = {
+        dataset.id: dataset
+        for dataset in session.scalars(select(Dataset).where(Dataset.id.in_([int(value) for value in dataset_ids]))).all()
+    }
+    materialized_ids = _materialized_ids_by_identity(session, list(datasets)) if prefer_materialized_ids else {}
+    filters = dict(
+        kinds=kinds, modes=modes,
+        bbox=bbox, geometry_required=geometry_required,
+        search=search, route_key=route_key,
+        route_scopes=route_scopes, ref=ref,
+        osm_type=osm_type, osm_id=osm_id,
+    )
+    main_dataset_ids = [dataset_id for dataset_id, dataset in datasets.items() if not features_are_sidecar(dataset)]
+    sidecar_datasets = [dataset for dataset in datasets.values() if features_are_sidecar(dataset)]
+    # Normalise the paging inputs once so every backend sees the same clamped values.
+    take = None if limit is None else max(1, int(limit))
+    skip = max(0, int(offset or 0))
+    # Page after merging only when several backends contribute; a single backend's own OFFSET is exact.
+    merge_paging = skip > 0 and (len(sidecar_datasets) + (1 if main_dataset_ids else 0)) > 1
+    source_limit = take + skip if (merge_paging and take is not None) else take
+    source_offset = 0 if merge_paging else skip
+    rows: list[OsmFeature] = []
+    if main_dataset_ids:
+        stmt = _apply_main_filters(select(OsmFeature).where(OsmFeature.dataset_id.in_(main_dataset_ids)), **filters)
+        stmt = stmt.order_by(OsmFeature.kind, OsmFeature.mode, OsmFeature.ref, OsmFeature.name, OsmFeature.id)
+        if source_offset:
+            stmt = stmt.offset(source_offset)
+        rows.extend(session.scalars(stmt.limit(source_limit)).all())  # limit(None) means "no limit"
+    for dataset in sidecar_datasets:
+        rows.extend(
+            _query_sidecar_features(
+                dataset,
+                **filters,
+                limit=source_limit,
+                offset=source_offset or None,
+                materialized_ids=materialized_ids,
+            )
+        )
+    # Re-sort in Python so rows coming from different backends interleave in one ordering.
+    rows.sort(key=lambda row: (row.kind or "", row.mode or "", row.ref or "", row.name or "", int(row.id or 0)))
+    # Apply the offset once, on the merged ordering.
+    if merge_paging:
+        rows = rows[skip:]
+    return rows if take is None else rows[:take]
+
+
+def get_osm_feature(session: Session, feature_id: int) -> OsmFeature | None:  # Primary-key lookup in the main osm_features table only; sidecars are not consulted.
+    return session.get(OsmFeature, feature_id)
+
+
+def osm_feature_identity_key(feature: OsmFeature) -> str:  # "dataset_id|osm_type|osm_id" -- the format parse_osm_feature_identity_key reads back.
+    return f"{feature.dataset_id}|{feature.osm_type}|{feature.osm_id}"
+
+
+def osm_feature_public_id(feature: OsmFeature) -> int | str | None:  # Identifier to expose: the identity key for sidecar-sourced features, else the row id.
+    if getattr(feature, "_osm_sidecar_source", False):  # marker attribute set outside the visible code (presumably when a row is built from a sidecar)
+        return osm_feature_identity_key(feature)
+    return feature.id
+
+
+def resolve_osm_feature(session: Session, value: int | str) -> OsmFeature | None:  # Resolve a main-table id or a "dataset_id|osm_type|osm_id" key to a feature; None when nothing matches.
+    int_value = _safe_int(value)
+    if int_value is not None:
+        feature = session.get(OsmFeature, int_value)
+        if feature is not None:
+            return feature
+    parsed = parse_osm_feature_identity_key(str(value))  # not a known row id: try the identity-key form
+    if parsed is None:
+        return None
+    dataset_id, osm_type, osm_id = parsed
+    existing = session.scalar(
+        select(OsmFeature).where(
+            OsmFeature.dataset_id == dataset_id,
+            OsmFeature.osm_type == osm_type,
+            OsmFeature.osm_id == osm_id,
+        )
+    )
+    if existing is not None:
+        return existing  # a main-table row is preferred over the sidecar copy
+    dataset = session.get(Dataset, dataset_id)
+    if not features_are_sidecar(dataset):
+        return None
+    try:
+        with sidecar_connection(dataset) as connection:
+            select_columns = ", ".join(_sidecar_select_columns(_sidecar_columns(connection)))  # column list depends on what this sidecar file actually has
+            row = connection.execute(
+                f"""
+                SELECT id, {select_columns}
+                FROM osm_features
+                WHERE dataset_id = ?
+                AND osm_type = ?
+                AND osm_id = ?
+                """,
+                (dataset_id, osm_type, osm_id),
+            ).fetchone()
+    except MissingOsmSidecar:
+        return None
+    if row is None:
+        return None
+    return _feature_from_row(row, {})  # empty materialized-id map is passed
+
+
+def parse_osm_feature_identity_key(value: str) -> tuple[int, str, str] | None:  # Parse "dataset_id|osm_type|osm_id"; None when the key is malformed.
+    parts = value.split("|", 2)  # at most two splits, so any further "|" stays inside osm_id
+    if len(parts) != 3:
+        return None
+    dataset_id = _safe_int(parts[0])
+    if dataset_id is None:
+        return None
+    osm_type = parts[1].strip()
+    osm_id = parts[2].strip()
+    if not osm_type or not osm_id:
+        return None
+    return dataset_id, osm_type, osm_id
+
+
+def ensure_main_osm_feature(session: Session, feature: OsmFeature) -> OsmFeature:  # Return the main-table row for this feature's identity, inserting a copy first if none exists.
+    existing = session.scalar(
+        select(OsmFeature).where(
+            OsmFeature.dataset_id == feature.dataset_id,
+            OsmFeature.osm_type == feature.osm_type,
+            OsmFeature.osm_id == feature.osm_id,
+        )
+    )
+    if existing is not None:
+        return existing
+    values = dict(  # every column except the primary key is copied
+        dataset_id=feature.dataset_id,
+        osm_type=feature.osm_type,
+        osm_id=feature.osm_id,
+        kind=feature.kind,
+        mode=feature.mode,
+        route_scope=feature.route_scope,
+        name=feature.name,
+        ref=feature.ref,
+        operator=feature.operator,
+        network=feature.network,
+        geometry_geojson=feature.geometry_geojson,
+        min_lon=feature.min_lon,
+        min_lat=feature.min_lat,
+        max_lon=feature.max_lon,
+        max_lat=feature.max_lat,
+        tags_json=feature.tags_json,
+        route_key=feature.route_key,
+        operator_key=feature.operator_key,
+    )
+    if settings.is_postgresql_database:
+        session.execute(
+            postgresql_insert(OsmFeature)
+            .values(**values)
+            .on_conflict_do_nothing(index_elements=["dataset_id", "osm_type", "osm_id"])  # tolerate a concurrent insert of the same identity
+        )
+    else:
+        session.execute(insert(OsmFeature).values(**values).prefix_with("OR IGNORE"))  # SQLite "INSERT OR IGNORE"; assumes the non-PostgreSQL backend is SQLite
+    session.flush()
+    refresh_postgis_geometries(session, dataset_id=feature.dataset_id, tables=["osm_features"])
+    existing = session.scalar(  # re-read: the row now exists whether this call or a concurrent one inserted it
+        select(OsmFeature).where(
+            OsmFeature.dataset_id == feature.dataset_id,
+            OsmFeature.osm_type == feature.osm_type,
+            OsmFeature.osm_id == feature.osm_id,
+        )
+    )
+    if existing is None:
+        raise RuntimeError(f"Could not materialize OSM feature {feature.dataset_id}:{feature.osm_type}:{feature.osm_id}")
+    return existing
+
+
+def materialize_osm_features(session: Session, features: Sequence[OsmFeature]) -> list[OsmFeature]:  # ensure_main_osm_feature for each input, in order (one lookup/insert per feature).
+    return [ensure_main_osm_feature(session, feature) for feature in features]
+
+
+def _new_sidecar_path(dataset: Dataset, source_hash: str | None) -> Path:  # <data_dir>/sidecars/source_<source_id>/osm_dataset_<id>_<suffix>.sqlite
+    suffix = (source_hash or dataset.sha256 or str(dataset.id))[:12]  # first 12 chars of the hash, or the dataset id when no hash is known
+    return settings.data_dir / "sidecars" / f"source_{dataset.source_id}" / f"osm_dataset_{dataset.id}_{suffix}.sqlite"
+
+
+def dedupe_osm_feature_rows(rows: Sequence[dict[str, object]]) -> tuple[list[dict[str, object]], int]:  # Keep one row per (dataset_id, osm_type, osm_id); returns (kept rows, number dropped).
+    selected: dict[tuple[int, str, str], dict[str, object]] = {}
+    for row in rows:
+        key = (int(row["dataset_id"]), str(row["osm_type"]), str(row["osm_id"]))  # these three keys are required; a missing one raises KeyError
+        current = selected.get(key)
+        if current is None or _feature_row_preference(row) < _feature_row_preference(current):  # lower preference tuple wins; ties keep the earlier row
+            selected[key] = dict(row)  # copy so the caller's dict is not shared
+    return list(selected.values()), max(0, len(rows) - len(selected))
+
+
+def _feature_row_preference(row: dict[str, object]) -> tuple[int, int, int]:  # Sort key for duplicate resolution: smaller tuples are preferred.
+    kind_rank = {
+        "route": 0,
+        "station": 1,
+        "terminal": 2,
+        "stop": 3,
+        "infra": 4,
+        "feature": 5,
+    }.get(str(row.get("kind") or "feature"), 6)  # unrecognised kinds rank last
+    has_geometry = 0 if row.get("geometry_geojson") else 1  # rows with geometry sort first
+    geometry_size = -len(str(row.get("geometry_geojson") or ""))  # negated so longer GeoJSON text sorts first
+    return (kind_rank, has_geometry, geometry_size)
+
+
+def _create_schema(connection: sqlite3.Connection) -> None:  # Create the sidecar osm_features table (plain CREATE TABLE, so it fails if the table already exists); identity is unique per (dataset_id, osm_type, osm_id).
+    connection.execute(
+        """
+        CREATE TABLE osm_features (
+            id INTEGER PRIMARY KEY,
+            dataset_id INTEGER NOT NULL,
+            osm_type TEXT NOT NULL,
+            osm_id TEXT NOT NULL,
+            kind TEXT NOT NULL,
+            mode TEXT,
+            route_scope TEXT,
+            name TEXT,
+            ref TEXT,
+            operator TEXT,
+            network TEXT,
+            geometry_geojson TEXT,
+            min_lon REAL,
+            min_lat REAL,
+            max_lon REAL,
+            max_lat REAL,
+            tags_json TEXT,
+            route_key TEXT,
+            operator_key TEXT,
+            UNIQUE(dataset_id, osm_type, osm_id)
+        )
+        """
+    )
+
+
+def _create_indexes(connection: sqlite3.Connection) -> None:  # Create the sidecar lookup indexes if they do not exist yet; does not commit.
+    statements = [
+        "CREATE INDEX IF NOT EXISTS ix_osm_sidecar_kind_mode_bbox ON osm_features (kind, mode, min_lon, max_lon, min_lat, max_lat)",
+        "CREATE INDEX IF NOT EXISTS ix_osm_sidecar_scope_bbox ON osm_features (kind, mode, route_scope, min_lon, max_lon, min_lat, max_lat)",  # the one index listed in OSM_SIDECAR_ROUTE_SCOPE_INDEXES
+        "CREATE INDEX IF NOT EXISTS ix_osm_sidecar_route_key ON osm_features (route_key)",
+        "CREATE INDEX IF NOT EXISTS ix_osm_sidecar_ref ON osm_features (ref)",
+        "CREATE INDEX IF NOT EXISTS ix_osm_sidecar_identity ON osm_features (dataset_id, osm_type, osm_id)",  # same columns as the table's UNIQUE constraint
+        "CREATE INDEX IF NOT EXISTS ix_osm_sidecar_kind_ref_mode ON osm_features (kind, ref, mode)",
+    ]
+    for statement in statements:
+        connection.execute(statement)
+
+
+def _apply_main_filters(stmt, *, kinds, modes, bbox, geometry_required, search, route_key, route_scopes, ref, osm_type, osm_id):  # Add a WHERE clause to `stmt` for every filter that is set; falsy filters are skipped.
+    if kinds:
+        stmt = stmt.where(OsmFeature.kind.in_(list(kinds)))
+    if modes:
+        stmt = stmt.where(OsmFeature.mode.in_(list(modes)))
+    if route_scopes:
+        stmt = stmt.where(_main_route_scope_condition([str(scope) for scope in route_scopes]))
+    if bbox:
+        min_lon, min_lat, max_lon, max_lat = bbox
+        if settings.is_postgresql_database:  # PostGIS: use geom when populated, else fall back to the stored bbox columns
+            stmt = stmt.where(
+                text(
+                    """
+                    (
+                        osm_features.geom && ST_MakeEnvelope(:bbox_min_lon, :bbox_min_lat, :bbox_max_lon, :bbox_max_lat, 4326)
+                        OR (
+                            osm_features.geom IS NULL
+                            AND osm_features.min_lon <= :bbox_max_lon
+                            AND osm_features.max_lon >= :bbox_min_lon
+                            AND osm_features.min_lat <= :bbox_max_lat
+                            AND osm_features.max_lat >= :bbox_min_lat
+                        )
+                    )
+                    """
+                )
+            ).params(
+                bbox_min_lon=min_lon,
+                bbox_min_lat=min_lat,
+                bbox_max_lon=max_lon,
+                bbox_max_lat=max_lat,
+            )
+        else:
+            stmt = stmt.where(OsmFeature.min_lon <= max_lon, OsmFeature.max_lon >= min_lon, OsmFeature.min_lat <= max_lat, OsmFeature.max_lat >= min_lat)  # rectangle-overlap test
+    if geometry_required is True:  # tri-state: True = must have geometry, False = must not, None = either
+        stmt = stmt.where(OsmFeature.geometry_geojson.is_not(None))
+    elif geometry_required is False:
+        stmt = stmt.where(OsmFeature.geometry_geojson.is_(None))
+    if search:  # NOTE(review): "%" and "_" in `search` are not escaped and act as LIKE wildcards
+        if settings.is_postgresql_database:
+            stmt = stmt.where(
+                text(
+                    """
+                    (
+                        LOWER(COALESCE(osm_features.ref, '')) LIKE :search_pattern
+                        OR LOWER(COALESCE(osm_features.name, '')) LIKE :search_pattern
+                        OR LOWER(COALESCE(osm_features.tags_json, '')) LIKE :search_pattern
+                    )
+                    """
+                )
+            ).params(search_pattern=f"%{search.lower()}%")
+        else:
+            pattern = f"%{search}%"
+            stmt = stmt.where(
+                (OsmFeature.ref.ilike(pattern))
+                | (OsmFeature.name.ilike(pattern))
+                | (OsmFeature.tags_json.ilike(pattern))
+            )
+    if route_key:
+        stmt = stmt.where(OsmFeature.route_key == route_key)
+    if ref:
+        stmt = stmt.where(OsmFeature.ref == ref)
+    if osm_type:
+        stmt = stmt.where(OsmFeature.osm_type == osm_type)
+    if osm_id:
+        stmt = stmt.where(OsmFeature.osm_id == osm_id)
+    return stmt
+
+
+def _main_route_scope_condition(route_scopes: list[str]):  # SQL condition: stored route_scope is one of the requested scopes, OR the text heuristics select the row.
+    fallback = _main_route_scope_fallback_condition(route_scopes)
+    stored = OsmFeature.route_scope.in_(route_scopes)
+    if "local" in route_scopes:
+        non_local_bus_fallback = _main_route_scope_fallback_condition(["long_distance", "regional"])
+        stored = and_(stored, not_(and_(OsmFeature.mode.in_(["bus", "trolleybus"]), non_local_bus_fallback)))  # a stored match is discarded for buses the heuristics class as regional/long-distance
+    return or_(stored, fallback)  # the heuristic branch is not limited to rows whose stored scope is NULL
+
+
+def _main_route_scope_fallback_condition(route_scopes: list[str]):  # SQL approximation of the route-scope heuristics, built from LIKE patterns on ref/name/network/tags_json.
+    ref = func.upper(func.coalesce(OsmFeature.ref, ""))
+    name = func.upper(func.coalesce(OsmFeature.name, ""))
+    network = func.upper(func.coalesce(OsmFeature.network, ""))
+    tags = func.lower(func.coalesce(OsmFeature.tags_json, ""))  # NOTE(review): the tag patterns below only match compact JSON ("key":"value", no space after the colon) -- confirm how tags_json is serialized
+    train_long_distance = and_(
+        OsmFeature.mode == "train",
+        or_(
+            ref.like("ICE%"),
+            ref.like("IC%"),  # plain prefix match: broader than LONG_DISTANCE_PREFIX_RE in app/osm_classification.py, which needs a word boundary or digit
+            ref.like("EC%"),
+            ref.like("ECE%"),
+            ref.like("EN%"),
+            ref.like("NJ%"),
+            ref.like("RJ%"),
+            ref.like("RJX%"),
+            ref.like("TGV%"),
+            ref.like("THA%"),
+            ref.like("FLX%"),  # NOTE(review): the EST and WB prefixes from LONG_DISTANCE_PREFIX_RE have no pattern here -- confirm whether intended
+            name.like("%INTERCITY%"),
+            name.like("%EUROCITY%"),
+            name.like("%NIGHTJET%"),
+            name.like("%FLIXTRAIN%"),
+            tags.like('%"service":"long_distance"%'),
+            tags.like('%"train":"long_distance"%'),
+            tags.like('%"train":"high_speed"%'),
+            tags.like('%"train":"intercity"%'),
+        ),
+    )
+    bus_long_distance = and_(
+        OsmFeature.mode.in_(["bus", "trolleybus"]),
+        or_(
+            name.like("%FLIXBUS%"),
+            network.like("%FLIXBUS%"),
+            name.like("%EUROLINES%"),
+            network.like("%EUROLINES%"),
+            name.like("%INTERCITYBUS%"),
+            name.like("%IC BUS%"),
+            name.like("%FERNBUS%"),
+            tags.like('%"service":"long_distance"%'),
+            tags.like('%"bus":"long_distance"%'),
+            tags.like('%"bus":"intercity"%'),
+            tags.like('%"network:type":"long_distance"%'),
+        ),
+    )
+    long_distance = or_(OsmFeature.mode == "coach", train_long_distance, bus_long_distance)
+    bus_regional = and_(
+        OsmFeature.mode.in_(["bus", "trolleybus"]),
+        not_(bus_long_distance),  # long-distance takes precedence over regional
+        or_(
+            name.like("%REGIONALBUS%"),
+            name.like("%REGIOBUS%"),
+            name.like("%REGIONAL BUS%"),
+            name.like("%REGIONALVERKEHR%"),
+            network.like("%REGIONALBUS%"),
+            network.like("%REGIOBUS%"),
+            network.like("%REGIONALVERKEHR%"),
+            tags.like('%"service":"regional"%'),
+            tags.like('%"bus":"regional"%'),
+            tags.like('%"bus":"interurban"%'),
+            tags.like('%"network:type":"regional"%'),
+        ),
+    )
+    local = or_(
+        OsmFeature.mode.in_(["tram", "light_rail", "subway", "ferry", "funicular", "aerialway", "monorail"]),
+        and_(OsmFeature.mode.in_(["bus", "trolleybus"]), not_(or_(bus_long_distance, bus_regional))),  # buses default to local
+        and_(
+            OsmFeature.mode == "train",
+            or_(ref.like("S%"), name.like("%S-BAHN%"), network.like("%S-BAHN%"), tags.like('%"train":"commuter"%')),  # NOTE(review): "S%" matches every ref starting with S, and infer_osm_route_scope treats "commuter" as regional -- confirm which is intended
+        ),
+    )
+    train_regional = and_(
+        OsmFeature.mode == "train",
+        not_(train_long_distance),
+        or_(
+            ref.like("IRE%"),
+            ref.like("RE%"),
+            ref.like("RB%"),
+            ref.like("RER%"),
+            ref.like("TER%"),
+            ref.like("REX%"),
+            ref.like("MEX%"),
+            ref.like("ALX%"),
+            ref.like("WFB%"),
+            ref.like("R%"),  # matches every ref starting with R; long-distance refs such as RJ are excluded by the not_() above
+            name.like("%REGIONAL%"),
+            name.like("%REGIO%"),
+            tags.like('%"service":"regional"%'),
+            tags.like('%"train":"regional"%'),
+        ),
+    )
+    regional = or_(train_regional, bus_regional)
+    conditions = []
+    if "long_distance" in route_scopes:
+        conditions.append(long_distance)
+    if "regional" in route_scopes:
+        conditions.append(regional)
+    if "local" in route_scopes:
+        conditions.append(local)
+    if "unknown" in route_scopes:
+        conditions.append(and_(OsmFeature.mode == "train", not_(or_(long_distance, regional, local))))  # only trains can be "unknown"
+    return or_(*conditions) if conditions else OsmFeature.route_scope.is_(None)  # no recognised scope requested: select rows without a stored route_scope
+
+
+def _query_sidecar_features(
+ dataset: Dataset,
+ *,
+ kinds: Sequence[str] | None,
+ modes: Sequence[str] | None,
+ bbox: tuple[float, float, float, float] | None,
+ geometry_required: bool | None,
+ search: str | None,
+ route_key: str | None,
+ route_scopes: Sequence[str] | None,
+ ref: str | None,
+ osm_type: str | None,
+ osm_id: str | None,
+ limit: int | None,
+ offset: int | None,
+ materialized_ids: dict[tuple[int, str, str], int],
+) -> list[OsmFeature]:
+ where = []
+ params: list[object] = []
+ try:
+ with sidecar_connection(dataset) as connection:
+ available_columns = _sidecar_columns(connection)
+ if kinds:
+ placeholders = ", ".join(["?"] * len(kinds))
+ where.append(f"kind IN ({placeholders})")
+ params.extend(list(kinds))
+ if modes:
+ placeholders = ", ".join(["?"] * len(modes))
+ where.append(f"mode IN ({placeholders})")
+ params.extend(list(modes))
+ if bbox:
+ min_lon, min_lat, max_lon, max_lat = bbox
+ where.extend(["min_lon <= ?", "max_lon >= ?", "min_lat <= ?", "max_lat >= ?"])
+ params.extend([max_lon, min_lon, max_lat, min_lat])
+ if geometry_required is True:
+ where.append("geometry_geojson IS NOT NULL")
+ elif geometry_required is False:
+ where.append("geometry_geojson IS NULL")
+ if search:
+ where.append("(LOWER(COALESCE(ref, '')) LIKE ? OR LOWER(COALESCE(name, '')) LIKE ? OR LOWER(COALESCE(tags_json, '')) LIKE ?)")
+ pattern = f"%{search.lower()}%"
+ params.extend([pattern, pattern, pattern])
+ if route_key:
+ where.append("route_key = ?")
+ params.append(route_key)
+ if route_scopes:
+ condition, condition_params = _sidecar_route_scope_condition([str(scope) for scope in route_scopes], has_route_scope="route_scope" in available_columns)
+ where.append(condition)
+ params.extend(condition_params)
+ if ref:
+ where.append("ref = ?")
+ params.append(ref)
+ if osm_type:
+ where.append("osm_type = ?")
+ params.append(osm_type)
+ if osm_id:
+ where.append("osm_id = ?")
+ params.append(osm_id)
+ select_columns = ", ".join(_sidecar_select_columns(available_columns))
+ sql = f"SELECT id, {select_columns} FROM osm_features"
+ if where:
+ sql += " WHERE " + " AND ".join(where)
+ sql += " ORDER BY kind, mode, ref, name, id"
+ if limit is not None:
+ sql += " LIMIT ?"
+ params.append(max(1, int(limit)))
+ if offset:
+ if limit is None:
+ sql += " LIMIT -1"
+ sql += " OFFSET ?"
+ params.append(max(0, int(offset)))
+ return [_feature_from_row(row, materialized_ids) for row in connection.execute(sql, params).fetchall()]
+ except MissingOsmSidecar:
+ return []
+
+
+def _sidecar_columns(connection: sqlite3.Connection) -> set[str]:
+ return {str(row["name"]) for row in connection.execute("PRAGMA table_info(osm_features)").fetchall()}
+
+
+def _sidecar_select_columns(available_columns: set[str]) -> list[str]:
+ return [column if column in available_columns else f"NULL AS {column}" for column in OSM_FEATURE_COLUMNS]
+
+
+def _sidecar_route_scope_condition(route_scopes: list[str], *, has_route_scope: bool) -> tuple[str, list[object]]:
+ fallback_sql, fallback_params = _sidecar_route_scope_fallback_condition(route_scopes)
+ if has_route_scope:
+ placeholders = ", ".join(["?"] * len(route_scopes))
+ stored_sql = f"route_scope IN ({placeholders})"
+ params: list[object] = [*route_scopes]
+ if "local" in route_scopes:
+ non_local_sql, non_local_params = _sidecar_route_scope_fallback_condition(["long_distance", "regional"])
+ stored_sql = f"({stored_sql} AND NOT (mode IN ('bus', 'trolleybus') AND {non_local_sql}))"
+ params.extend(non_local_params)
+ return f"({stored_sql} OR {fallback_sql})", [*params, *fallback_params]
+ return fallback_sql, fallback_params
+
+
+def _sidecar_route_scope_fallback_condition(route_scopes: list[str]) -> tuple[str, list[object]]:
+    """Build a SQL predicate that infers route scope from mode, ref, name, network and tags.
+
+    The predicate is the OR of one sub-expression per requested scope
+    (``long_distance``, ``regional``, ``local``, ``unknown``). Scope names
+    other than these four are ignored; if none is requested the predicate is
+    ``(0)``, which matches no row. All patterns are inlined as SQL literals, so
+    the returned parameter list is always empty.
+    """
+    # Train refs/names/tags treated as long-distance services.
+    # NOTE(review): several prefixes are subsumed by shorter ones ('IC%' covers
+    # 'ICE%', 'EC%' covers 'ECE%', 'RJ%' covers 'RJX%'); harmless but redundant.
+    # NOTE(review): the tags_json patterns only match JSON written without a
+    # space after ':' — confirm against the code that serializes tags_json.
+    train_long_distance = """(
+        mode = 'train'
+        AND (
+            UPPER(COALESCE(ref, '')) LIKE 'ICE%'
+            OR UPPER(COALESCE(ref, '')) LIKE 'IC%'
+            OR UPPER(COALESCE(ref, '')) LIKE 'EC%'
+            OR UPPER(COALESCE(ref, '')) LIKE 'ECE%'
+            OR UPPER(COALESCE(ref, '')) LIKE 'EN%'
+            OR UPPER(COALESCE(ref, '')) LIKE 'NJ%'
+            OR UPPER(COALESCE(ref, '')) LIKE 'RJ%'
+            OR UPPER(COALESCE(ref, '')) LIKE 'RJX%'
+            OR UPPER(COALESCE(ref, '')) LIKE 'TGV%'
+            OR UPPER(COALESCE(ref, '')) LIKE 'THA%'
+            OR UPPER(COALESCE(ref, '')) LIKE 'FLX%'
+            OR UPPER(COALESCE(name, '')) LIKE '%INTERCITY%'
+            OR UPPER(COALESCE(name, '')) LIKE '%EUROCITY%'
+            OR UPPER(COALESCE(name, '')) LIKE '%NIGHTJET%'
+            OR UPPER(COALESCE(name, '')) LIKE '%FLIXTRAIN%'
+            OR LOWER(COALESCE(tags_json, '')) LIKE '%"service":"long_distance"%'
+            OR LOWER(COALESCE(tags_json, '')) LIKE '%"train":"long_distance"%'
+            OR LOWER(COALESCE(tags_json, '')) LIKE '%"train":"high_speed"%'
+            OR LOWER(COALESCE(tags_json, '')) LIKE '%"train":"intercity"%'
+        )
+    )"""
+    # Bus/trolleybus rows whose name, network or tags mark them as long-distance.
+    bus_long_distance = """(
+        mode IN ('bus', 'trolleybus')
+        AND (
+            UPPER(COALESCE(name, '')) LIKE '%FLIXBUS%'
+            OR UPPER(COALESCE(network, '')) LIKE '%FLIXBUS%'
+            OR UPPER(COALESCE(name, '')) LIKE '%EUROLINES%'
+            OR UPPER(COALESCE(network, '')) LIKE '%EUROLINES%'
+            OR UPPER(COALESCE(name, '')) LIKE '%INTERCITYBUS%'
+            OR UPPER(COALESCE(name, '')) LIKE '%IC BUS%'
+            OR UPPER(COALESCE(name, '')) LIKE '%FERNBUS%'
+            OR LOWER(COALESCE(tags_json, '')) LIKE '%"service":"long_distance"%'
+            OR LOWER(COALESCE(tags_json, '')) LIKE '%"bus":"long_distance"%'
+            OR LOWER(COALESCE(tags_json, '')) LIKE '%"bus":"intercity"%'
+            OR LOWER(COALESCE(tags_json, '')) LIKE '%"network:type":"long_distance"%'
+        )
+    )"""
+    # Every 'coach' row counts as long-distance regardless of its other fields.
+    long_distance = f"(mode = 'coach' OR {train_long_distance} OR {bus_long_distance})"
+    # Regional buses: explicitly not long-distance, and matching a regional marker.
+    bus_regional = f"""(
+        mode IN ('bus', 'trolleybus')
+        AND NOT {bus_long_distance}
+        AND (
+            UPPER(COALESCE(name, '')) LIKE '%REGIONALBUS%'
+            OR UPPER(COALESCE(name, '')) LIKE '%REGIOBUS%'
+            OR UPPER(COALESCE(name, '')) LIKE '%REGIONAL BUS%'
+            OR UPPER(COALESCE(name, '')) LIKE '%REGIONALVERKEHR%'
+            OR UPPER(COALESCE(network, '')) LIKE '%REGIONALBUS%'
+            OR UPPER(COALESCE(network, '')) LIKE '%REGIOBUS%'
+            OR UPPER(COALESCE(network, '')) LIKE '%REGIONALVERKEHR%'
+            OR LOWER(COALESCE(tags_json, '')) LIKE '%"service":"regional"%'
+            OR LOWER(COALESCE(tags_json, '')) LIKE '%"bus":"regional"%'
+            OR LOWER(COALESCE(tags_json, '')) LIKE '%"bus":"interurban"%'
+            OR LOWER(COALESCE(tags_json, '')) LIKE '%"network:type":"regional"%'
+        )
+    )"""
+    # Regional trains: not long-distance, and matching a regional ref/name/tag.
+    # NOTE(review): 'R%' matches every ref starting with R, so the more specific
+    # 'RE%', 'RB%', 'RER%', 'REX%' prefixes add nothing; likewise '%REGIO%'
+    # covers '%REGIONAL%'.
+    train_regional = f"""(
+        mode = 'train'
+        AND NOT {train_long_distance}
+        AND (
+            UPPER(COALESCE(ref, '')) LIKE 'IRE%'
+            OR UPPER(COALESCE(ref, '')) LIKE 'RE%'
+            OR UPPER(COALESCE(ref, '')) LIKE 'RB%'
+            OR UPPER(COALESCE(ref, '')) LIKE 'RER%'
+            OR UPPER(COALESCE(ref, '')) LIKE 'TER%'
+            OR UPPER(COALESCE(ref, '')) LIKE 'REX%'
+            OR UPPER(COALESCE(ref, '')) LIKE 'MEX%'
+            OR UPPER(COALESCE(ref, '')) LIKE 'ALX%'
+            OR UPPER(COALESCE(ref, '')) LIKE 'WFB%'
+            OR UPPER(COALESCE(ref, '')) LIKE 'R%'
+            OR UPPER(COALESCE(name, '')) LIKE '%REGIONAL%'
+            OR UPPER(COALESCE(name, '')) LIKE '%REGIO%'
+            OR LOWER(COALESCE(tags_json, '')) LIKE '%"service":"regional"%'
+            OR LOWER(COALESCE(tags_json, '')) LIKE '%"train":"regional"%'
+        )
+    )"""
+    regional = f"({train_regional} OR {bus_regional})"
+    # Local: the listed urban modes, buses that are neither long-distance nor
+    # regional, and trains whose ref starts with 'S' or that are marked S-Bahn/commuter.
+    local = f"""(
+        mode IN ('tram', 'light_rail', 'subway', 'ferry', 'funicular', 'aerialway', 'monorail')
+        OR (mode IN ('bus', 'trolleybus') AND NOT ({bus_long_distance} OR {bus_regional}))
+        OR (
+            mode = 'train'
+            AND (
+                UPPER(COALESCE(ref, '')) LIKE 'S%'
+                OR UPPER(COALESCE(name, '')) LIKE '%S-BAHN%'
+                OR UPPER(COALESCE(network, '')) LIKE '%S-BAHN%'
+                OR LOWER(COALESCE(tags_json, '')) LIKE '%"train":"commuter"%'
+            )
+        )
+    )"""
+    parts = []
+    if "long_distance" in route_scopes:
+        parts.append(long_distance)
+    if "regional" in route_scopes:
+        parts.append(regional)
+    if "local" in route_scopes:
+        parts.append(local)
+    if "unknown" in route_scopes:
+        # "unknown" only ever selects trains that match none of the three scopes.
+        parts.append(f"(mode = 'train' AND NOT ({long_distance} OR {regional} OR {local}))")
+    # "0" is a constant-false predicate used when no recognised scope was requested.
+    return "(" + " OR ".join(parts or ["0"]) + ")", []
+
+
+def _feature_from_row(row: sqlite3.Row, materialized_ids: dict[tuple[int, str, str], int]) -> OsmFeature:
+ dataset_id = int(row["dataset_id"])
+ osm_type = str(row["osm_type"])
+ osm_id = str(row["osm_id"])
+ feature_id = materialized_ids.get((dataset_id, osm_type, osm_id), int(row["id"]))
+ feature = OsmFeature(
+ id=feature_id,
+ dataset_id=dataset_id,
+ osm_type=osm_type,
+ osm_id=osm_id,
+ kind=str(row["kind"]),
+ mode=row["mode"],
+ route_scope=row["route_scope"],
+ name=row["name"],
+ ref=row["ref"],
+ operator=row["operator"],
+ network=row["network"],
+ geometry_geojson=row["geometry_geojson"],
+ min_lon=row["min_lon"],
+ min_lat=row["min_lat"],
+ max_lon=row["max_lon"],
+ max_lat=row["max_lat"],
+ tags_json=row["tags_json"],
+ route_key=row["route_key"],
+ operator_key=row["operator_key"],
+ )
+ setattr(feature, "_osm_sidecar_source", True)
+ setattr(feature, "_osm_sidecar_row_id", int(row["id"]))
+ return feature
+
+
+def _materialized_ids_by_identity(session: Session, dataset_ids: Sequence[int]) -> dict[tuple[int, str, str], int]:
+ if not dataset_ids:
+ return {}
+ rows = session.execute(
+ select(OsmFeature.dataset_id, OsmFeature.osm_type, OsmFeature.osm_id, OsmFeature.id).where(OsmFeature.dataset_id.in_(dataset_ids))
+ ).all()
+ return {(int(dataset_id), str(osm_type), str(osm_id)): int(feature_id) for dataset_id, osm_type, osm_id, feature_id in rows}
+
+
+def _as_list(value: str | Sequence[str] | None) -> list[str]:
+ if value is None:
+ return []
+ if isinstance(value, str):
+ return [value]
+ return [str(item) for item in value]
+
+
+def _safe_int(value: object) -> int | None:
+ try:
+ return int(value) # type: ignore[arg-type]
+ except (TypeError, ValueError):
+ return None
diff --git a/app/performance.py b/app/performance.py
new file mode 100644
index 0000000..5a125d0
--- /dev/null
+++ b/app/performance.py
@@ -0,0 +1,61 @@
+from __future__ import annotations
+
+import json
+import time
+from contextlib import contextmanager
+from datetime import datetime, timezone
+from pathlib import Path
+from typing import Iterator
+
+from app.config import settings
+
+
+@contextmanager
+def measure_pipeline_phase(
+ phase: str,
+ *,
+ source_id: int | None = None,
+ dataset_id: int | None = None,
+ metadata: dict[str, object] | None = None,
+) -> Iterator[dict[str, object]]:
+ start = time.perf_counter()
+ payload: dict[str, object] = dict(metadata or {})
+ try:
+ yield payload
+ finally:
+ duration = round(time.perf_counter() - start, 3)
+ payload["duration_seconds"] = duration
+ record_pipeline_metric(
+ phase,
+ source_id=source_id,
+ dataset_id=dataset_id,
+ duration_seconds=duration,
+ metadata=payload,
+ )
+
+
+def record_pipeline_metric(
+ phase: str,
+ *,
+ source_id: int | None = None,
+ dataset_id: int | None = None,
+ duration_seconds: float | None = None,
+ metadata: dict[str, object] | None = None,
+) -> None:
+ path = _metric_path()
+ path.parent.mkdir(parents=True, exist_ok=True)
+ row = {
+ "timestamp": datetime.now(timezone.utc).isoformat(),
+ "phase": phase,
+ "source_id": source_id,
+ "dataset_id": dataset_id,
+ "duration_seconds": duration_seconds,
+ "metadata": metadata or {},
+ }
+ with path.open("a", encoding="utf-8") as handle:
+ handle.write(json.dumps(row, separators=(",", ":"), default=str))
+ handle.write("\n")
+
+
+def _metric_path() -> Path:
+ return settings.data_dir / "metrics" / "pipeline_metrics.jsonl"
diff --git a/app/pipeline/download.py b/app/pipeline/download.py
new file mode 100644
index 0000000..30cbf35
--- /dev/null
+++ b/app/pipeline/download.py
@@ -0,0 +1,111 @@
+from __future__ import annotations
+
+import shutil
+import time
+from pathlib import Path
+from urllib.parse import urlparse
+
+import requests
+
+from app.config import settings
+from app.models import Source
+from app.pipeline.utils import sha256_file
+
+
+def materialize_source(source: Source) -> Path:
+ """Download/copy a source into the local cache and return the file path.
+
+ Files are stored by content hash per source. Re-running an unchanged source
+ reuses the existing cached file instead of creating another timestamped copy.
+ """
+ source_dir = settings.data_dir / "sources" / f"source_{source.id}"
+ source_dir.mkdir(parents=True, exist_ok=True)
+ suffix = _guess_suffix(source.url, source.kind)
+
+ parsed = urlparse(source.url)
+ if parsed.scheme in {"http", "https"}:
+ temp_path = _download_temp_path(source_dir, suffix)
+ existing_size = temp_path.stat().st_size if temp_path.exists() else 0
+ headers = {"Range": f"bytes={existing_size}-"} if existing_size > 0 else None
+ with requests.get(source.url, stream=True, timeout=120, headers=headers) as r:
+ r.raise_for_status()
+ mode = "ab" if existing_size > 0 and r.status_code == 206 else "wb"
+ with temp_path.open(mode) as f:
+ for chunk in r.iter_content(chunk_size=1024 * 1024):
+ if chunk:
+ f.write(chunk)
+ return _store_or_reuse_cached_file(source_dir=source_dir, source_path=temp_path, suffix=suffix, move=True)
+
+ if parsed.scheme == "file":
+ source_path = Path(parsed.path)
+ else:
+ source_path = Path(source.url)
+
+ if not source_path.exists():
+ raise FileNotFoundError(f"Source file does not exist: {source.url}")
+ if _is_relative_to(source_path.resolve(), source_dir.resolve()):
+ return source_path
+ return _store_or_reuse_cached_file(source_dir=source_dir, source_path=source_path, suffix=suffix, move=False)
+
+
+def _download_temp_path(source_dir: Path, suffix: str) -> Path:
+ candidates = sorted(
+ source_dir.glob(f"*.download{suffix}"),
+ key=lambda path: path.stat().st_mtime if path.exists() else 0,
+ reverse=True,
+ )
+ if candidates:
+ return candidates[0]
+ return source_dir / f"{int(time.time())}.download{suffix}"
+
+
+def _guess_suffix(url: str, kind: str) -> str:
+ path = urlparse(url).path or url
+ lower = path.lower()
+ for suffix in (".zip", ".geojson", ".json", ".osm.pbf", ".pbf", ".osm", ".osm.xml", ".osc.gz", ".osc", ".csv"):
+ if lower.endswith(suffix):
+ return suffix
+ if kind == "gtfs":
+ return ".zip"
+ if kind == "osm_geojson":
+ return ".geojson"
+ return ".dat"
+
+
+def _store_or_reuse_cached_file(source_dir: Path, source_path: Path, suffix: str, move: bool) -> Path:
+ source_hash = sha256_file(source_path)
+ target = source_dir / f"{source_hash[:16]}{suffix}"
+
+ if target.exists() and sha256_file(target) == source_hash:
+ if move and source_path != target:
+ source_path.unlink(missing_ok=True)
+ return target
+
+ existing = _find_existing_cached_file(source_dir, source_hash, suffix, exclude=source_path)
+ if existing is not None:
+ if move and source_path != existing:
+ source_path.unlink(missing_ok=True)
+ return existing
+
+ if move:
+ source_path.replace(target)
+ else:
+ shutil.copyfile(source_path, target)
+ return target
+
+
+def _find_existing_cached_file(source_dir: Path, source_hash: str, suffix: str, exclude: Path | None = None) -> Path | None:
+ for candidate in sorted(source_dir.glob(f"*{suffix}")):
+ if exclude is not None and candidate.resolve() == exclude.resolve():
+ continue
+ if candidate.is_file() and sha256_file(candidate) == source_hash:
+ return candidate
+ return None
+
+
+def _is_relative_to(path: Path, parent: Path) -> bool:
+ try:
+ path.relative_to(parent)
+ return True
+ except ValueError:
+ return False
diff --git a/app/pipeline/gtfs.py b/app/pipeline/gtfs.py
new file mode 100644
index 0000000..09f5979
--- /dev/null
+++ b/app/pipeline/gtfs.py
@@ -0,0 +1,1327 @@
+from __future__ import annotations
+
+import csv
+import io
+import json
+import sqlite3
+import zipfile
+from collections import defaultdict
+from collections.abc import Callable
+from pathlib import Path
+from typing import Any, Iterator, Optional
+
+from shapely.geometry import LineString
+from sqlalchemy import func, select, text
+from sqlalchemy.orm import Session
+
+from app.config import settings
+from app.gtfs_storage import GTFS_STORAGE_MAIN, GTFS_STORAGE_METADATA_KEY, GTFS_STORAGE_SIDECAR_STOP_TIMES, effective_gtfs_timetable_storage
+from app.models import (
+ Dataset,
+ GtfsAgency,
+ GtfsCalendar,
+ GtfsCalendarDate,
+ GtfsRoute,
+ GtfsShape,
+ GtfsStop,
+ GtfsStopTime,
+ GtfsTrip,
+ Source,
+)
+from app.osm_classification import infer_osm_route_scope
+from app.performance import measure_pipeline_phase
+from app.pipeline.download import materialize_source
+from app.pipeline.utils import first_nonempty, geometry_json_and_bbox, norm_ref, norm_text, sha256_file
+from app.spatial import analyze_postgresql_tables, refresh_postgis_geometries
+
+
+GTFS_MODE = {
+ 0: "tram",
+ 1: "subway",
+ 2: "train",
+ 3: "bus",
+ 4: "ferry",
+ 5: "cable_tram",
+ 6: "aerialway",
+ 7: "funicular",
+ 11: "trolleybus",
+ 12: "monorail",
+}
+
+GTFS_EXTENDED_MODE_RANGES = [
+ (100, 199, "train"),
+ (400, 499, "subway"),
+ (700, 799, "bus"),
+ (900, 999, "tram"),
+ (1000, 1099, "ferry"),
+ (1100, 1199, "aerialway"),
+ (1200, 1299, "funicular"),
+ (1300, 1399, "aerialway"),
+ (1400, 1499, "monorail"),
+ (1500, 1599, "trolleybus"),
+]
+
+GTFS_IMPORTER_VERSION = "gtfs_import_v6_sidecar_stop_times"
+
+REQUIRED_FILES = {"agency.txt", "stops.txt", "routes.txt", "trips.txt", "stop_times.txt"}
+GTFS_STAGE_BATCH_SIZE = 50_000
+ProgressCallback = Callable[[str, str, int | None, int | None, dict[str, Any] | None], None]
+
+
+def run_gtfs_source(session: Session, source: Source, progress_callback: ProgressCallback | None = None) -> Dataset:
+ local_path = materialize_source(source)
+ source_hash = sha256_file(local_path)
+ existing = session.scalar(
+ select(Dataset)
+ .where(
+ Dataset.source_id == source.id,
+ Dataset.kind == "gtfs",
+ Dataset.sha256 == source_hash,
+ Dataset.is_active.is_(True),
+ Dataset.status == "imported",
+ )
+ .order_by(Dataset.id.desc())
+ )
+ if existing is not None and _dataset_importer_version(existing) == GTFS_IMPORTER_VERSION:
+ return existing
+ return import_gtfs_zip(session=session, source=source, zip_path=local_path, source_hash=source_hash, progress_callback=progress_callback)
+
+
+def import_gtfs_zip(
+ session: Session,
+ source: Source,
+ zip_path: Path,
+ source_hash: str | None = None,
+ progress_callback: ProgressCallback | None = None,
+) -> Dataset:
+ if not zipfile.is_zipfile(zip_path):
+ raise ValueError(f"GTFS source is not a zip file: {zip_path}")
+
+ dataset = Dataset(
+ source_id=source.id,
+ kind="gtfs",
+ local_path=str(zip_path),
+ sha256=source_hash or sha256_file(zip_path),
+ is_active=False,
+ status="staging",
+ )
+ session.add(dataset)
+ session.flush()
+ session.commit()
+
+ stage_path = _gtfs_stage_path(source, dataset, zip_path)
+ _emit_progress(progress_callback, "gtfs_staging_started", f"Staging GTFS zip {zip_path.name}.", 0, None, {"stage_path": str(stage_path)})
+ try:
+ with measure_pipeline_phase("gtfs_staging", source_id=source.id, dataset_id=dataset.id, metadata={"zip_path": str(zip_path), "stage_path": str(stage_path)}) as metric:
+ stage_summary = _stage_gtfs_zip(zip_path, stage_path, progress_callback=progress_callback)
+ metric.update(stage_summary)
+ activation_path = _prepare_gtfs_activation_path(source, dataset, stage_path, stage_summary)
+ _emit_progress(progress_callback, "gtfs_activation_started", "Activating staged GTFS dataset.", None, None, {"stage_path": str(activation_path)})
+ with measure_pipeline_phase("gtfs_activation", source_id=source.id, dataset_id=dataset.id, metadata={"stage_path": str(activation_path)}) as metric:
+ _activate_staged_gtfs(session, source, dataset, activation_path, stage_summary, progress_callback=progress_callback)
+ metric.update(stage_summary)
+ except BaseException:
+ session.rollback()
+ failed = session.get(Dataset, dataset.id)
+ if failed is not None:
+ failed.status = "failed"
+ failed.is_active = False
+ session.commit()
+ raise
+
+ source.status = "ok"
+ source.last_error = None
+ session.flush()
+ _emit_progress(progress_callback, "gtfs_activation_completed", f"Activated GTFS dataset #{dataset.id}.", None, None, {"dataset_id": dataset.id})
+ return dataset
+
+
+def backfill_gtfs_shapes(session: Session, dataset_id: int | None = None) -> dict:
+ stmt = select(Dataset).where(Dataset.kind == "gtfs")
+ if dataset_id is not None:
+ stmt = stmt.where(Dataset.id == dataset_id)
+ else:
+ stmt = stmt.where(Dataset.is_active.is_(True))
+ datasets = session.scalars(stmt.order_by(Dataset.id)).all()
+ results = []
+ for dataset in datasets:
+ existing = session.scalar(select(func.count()).select_from(GtfsShape).where(GtfsShape.dataset_id == dataset.id)) or 0
+ if existing:
+ results.append({"dataset_id": dataset.id, "status": "skipped", "shapes": existing})
+ continue
+ zip_path = Path(dataset.local_path)
+ if not zip_path.exists() or not zipfile.is_zipfile(zip_path):
+ results.append({"dataset_id": dataset.id, "status": "missing_zip", "path": str(zip_path)})
+ continue
+ with zipfile.ZipFile(zip_path) as zf:
+ names = {Path(name).name: name for name in zf.namelist() if not name.endswith("/")}
+ if "shapes.txt" not in names:
+ results.append({"dataset_id": dataset.id, "status": "no_shapes_txt", "shapes": 0})
+ continue
+ shapes_by_id = _read_shapes(zf, names)
+ imported = _import_shapes(session, dataset.id, shapes_by_id)
+ _record_importer_metadata(dataset, shapes_count=imported)
+ session.flush()
+ results.append({"dataset_id": dataset.id, "status": "imported", "shapes": imported})
+ return {"datasets": results}
+
+
+def _gtfs_stage_path(source: Source, dataset: Dataset, zip_path: Path) -> Path:
+ source_hash = dataset.sha256 or sha256_file(zip_path)
+ return settings.data_dir / "staging" / f"source_{source.id}" / f"gtfs_dataset_{dataset.id}_{source_hash[:12]}.sqlite"
+
+
+def _gtfs_sidecar_path(source: Source, dataset: Dataset) -> Path:
+ source_hash = dataset.sha256 or "unknown"
+ return settings.data_dir / "sidecars" / f"source_{source.id}" / f"gtfs_dataset_{dataset.id}_{source_hash[:12]}.sqlite"
+
+
+def _gtfs_timetable_storage_mode() -> str:
+ return effective_gtfs_timetable_storage()
+
+
+def _prepare_gtfs_activation_path(source: Source, dataset: Dataset, stage_path: Path, summary: dict[str, Any]) -> Path:
+ storage_mode = _gtfs_timetable_storage_mode()
+ if storage_mode == GTFS_STORAGE_SIDECAR_STOP_TIMES:
+ sidecar_path = _gtfs_sidecar_path(source, dataset)
+ sidecar_path.parent.mkdir(parents=True, exist_ok=True)
+ if sidecar_path.exists():
+ sidecar_path.unlink()
+ stage_path.replace(sidecar_path)
+ summary["stage_path"] = str(sidecar_path)
+ summary["staging"] = "sqlite_promoted_to_sidecar"
+ summary[GTFS_STORAGE_METADATA_KEY] = {
+ "mode": GTFS_STORAGE_SIDECAR_STOP_TIMES,
+ "sidecar_path": str(sidecar_path),
+ "tables": {
+ "gtfs_stop_times": "sidecar",
+ "gtfs_agencies": "main",
+ "gtfs_stops": "main",
+ "gtfs_routes": "main",
+ "gtfs_trips": "main",
+ "gtfs_calendars": "main",
+ "gtfs_calendar_dates": "main",
+ "gtfs_shapes": "main",
+ },
+ }
+ return sidecar_path
+
+ summary[GTFS_STORAGE_METADATA_KEY] = {
+ "mode": GTFS_STORAGE_MAIN,
+ "tables": {
+ "gtfs_stop_times": "main",
+ "gtfs_agencies": "main",
+ "gtfs_stops": "main",
+ "gtfs_routes": "main",
+ "gtfs_trips": "main",
+ "gtfs_calendars": "main",
+ "gtfs_calendar_dates": "main",
+ "gtfs_shapes": "main",
+ },
+ }
+ return stage_path
+
+
+def _stage_gtfs_zip(zip_path: Path, stage_path: Path, progress_callback: ProgressCallback | None = None) -> dict[str, Any]:
+ if stage_path.exists():
+ stage_path.unlink()
+ stage_path.parent.mkdir(parents=True, exist_ok=True)
+ connection = sqlite3.connect(stage_path)
+ try:
+ _configure_stage_connection(connection)
+ _create_gtfs_stage_schema(connection)
+ with zipfile.ZipFile(zip_path) as zf:
+ names = {Path(name).name: name for name in zf.namelist() if not name.endswith("/")}
+ missing = sorted(REQUIRED_FILES - set(names.keys()))
+ agency_names = _stage_agencies(connection, zf, names, progress_callback)
+ calendars_count = _stage_calendars(connection, zf, names, progress_callback)
+ calendar_dates_count = _stage_calendar_dates(connection, zf, names, progress_callback)
+ stops_by_id, stops_count = _stage_stops(connection, zf, names, progress_callback)
+ trips_by_route, first_shape_by_route, first_trip_by_route, trips_count = _stage_trips(connection, zf, names, progress_callback)
+ shapes_by_id = _read_shapes_with_progress(zf, names, progress_callback)
+ shapes_count = _stage_shapes(connection, shapes_by_id, progress_callback)
+ stopseq_by_trip, stop_times_seen, stop_times_imported = _stage_stop_times(
+ connection,
+ zf,
+ names,
+ first_trip_ids=set(first_trip_by_route.values()),
+ progress_callback=progress_callback,
+ )
+ routes_count = _stage_routes(
+ connection=connection,
+ routes_raw=list(_read_gtfs_csv(zf, names, "routes.txt")),
+ agency_names=agency_names,
+ stops_by_id=stops_by_id,
+ trips_by_route=trips_by_route,
+ first_shape_by_route=first_shape_by_route,
+ first_trip_by_route=first_trip_by_route,
+ shapes_by_id=shapes_by_id,
+ stopseq_by_trip=stopseq_by_trip,
+ progress_callback=progress_callback,
+ )
+ _create_gtfs_stage_indexes(connection, progress_callback)
+ connection.commit()
+ summary = {
+ "importer": GTFS_IMPORTER_VERSION,
+ "stage_path": str(stage_path),
+ "missing_required_files": missing,
+ "agencies": agency_names and len(agency_names) or 0,
+ "stops": stops_count,
+ "routes": routes_count,
+ "trips": trips_count,
+ "calendars": calendars_count,
+ "calendar_dates": calendar_dates_count,
+ "shapes": shapes_count,
+ "stop_times_seen": stop_times_seen,
+ "stop_times_imported": stop_times_imported,
+ "stop_times_import_limit": settings.gtfs_stop_times_import_limit,
+ "staging": "sqlite",
+ }
+ _emit_progress(progress_callback, "gtfs_staging_completed", "GTFS staging completed.", None, None, summary)
+ return summary
+ finally:
+ connection.close()
+
+
+def _configure_stage_connection(connection: sqlite3.Connection) -> None:
+ connection.execute("PRAGMA journal_mode=OFF")
+ connection.execute("PRAGMA synchronous=OFF")
+ connection.execute("PRAGMA temp_store=MEMORY")
+ connection.execute("PRAGMA locking_mode=EXCLUSIVE")
+
+
+def _create_gtfs_stage_schema(connection: sqlite3.Connection) -> None:
+    """Create the eight GTFS staging tables in the staging database.
+
+    The tables are created with plain ``CREATE TABLE`` (no ``IF NOT EXISTS``),
+    so this raises if any of them already exists. No primary keys or indexes
+    are declared here.
+    """
+    # Calendar start/end dates and calendar_dates.date are stored as INTEGER.
+    # NOTE(review): presumably the GTFS YYYYMMDD value as a number — confirm
+    # against the code that stages calendars.
+    connection.executescript(
+        """
+        CREATE TABLE gtfs_agencies (
+            agency_id TEXT NOT NULL,
+            name TEXT NOT NULL,
+            url TEXT,
+            timezone TEXT
+        );
+        CREATE TABLE gtfs_stops (
+            stop_id TEXT NOT NULL,
+            name TEXT,
+            lat REAL,
+            lon REAL,
+            parent_station TEXT
+        );
+        CREATE TABLE gtfs_routes (
+            route_id TEXT NOT NULL,
+            agency_id TEXT,
+            short_name TEXT,
+            long_name TEXT,
+            route_type INTEGER,
+            mode TEXT,
+            route_scope TEXT,
+            operator_name TEXT,
+            geometry_geojson TEXT,
+            min_lon REAL,
+            min_lat REAL,
+            max_lon REAL,
+            max_lat REAL,
+            route_key TEXT,
+            operator_key TEXT
+        );
+        CREATE TABLE gtfs_trips (
+            route_id TEXT NOT NULL,
+            trip_id TEXT NOT NULL,
+            service_id TEXT,
+            shape_id TEXT
+        );
+        CREATE TABLE gtfs_calendars (
+            service_id TEXT NOT NULL,
+            monday INTEGER NOT NULL,
+            tuesday INTEGER NOT NULL,
+            wednesday INTEGER NOT NULL,
+            thursday INTEGER NOT NULL,
+            friday INTEGER NOT NULL,
+            saturday INTEGER NOT NULL,
+            sunday INTEGER NOT NULL,
+            start_date INTEGER NOT NULL,
+            end_date INTEGER NOT NULL
+        );
+        CREATE TABLE gtfs_calendar_dates (
+            service_id TEXT NOT NULL,
+            date INTEGER NOT NULL,
+            exception_type INTEGER NOT NULL
+        );
+        CREATE TABLE gtfs_shapes (
+            shape_id TEXT NOT NULL,
+            geometry_geojson TEXT NOT NULL,
+            min_lon REAL,
+            min_lat REAL,
+            max_lon REAL,
+            max_lat REAL
+        );
+        CREATE TABLE gtfs_stop_times (
+            trip_id TEXT NOT NULL,
+            stop_id TEXT NOT NULL,
+            stop_sequence INTEGER NOT NULL,
+            arrival_time TEXT,
+            departure_time TEXT,
+            arrival_seconds INTEGER,
+            departure_seconds INTEGER
+        );
+        """
+    )
+
+
+def _create_gtfs_stage_indexes(connection: sqlite3.Connection, progress_callback: ProgressCallback | None = None) -> None:
+ _emit_progress(progress_callback, "gtfs_stage_indexes_started", "Building GTFS stage indexes.", None, None, None)
+ for statement in [
+ "CREATE INDEX IF NOT EXISTS ix_stage_gtfs_stop_times_stop_depart_trip ON gtfs_stop_times (stop_id, departure_seconds, trip_id, stop_sequence)",
+ "CREATE INDEX IF NOT EXISTS ix_stage_gtfs_stop_times_stop_arrive_trip ON gtfs_stop_times (stop_id, arrival_seconds, trip_id, stop_sequence)",
+ "CREATE INDEX IF NOT EXISTS ix_stage_gtfs_stop_times_trip_seq ON gtfs_stop_times (trip_id, stop_sequence)",
+ "CREATE INDEX IF NOT EXISTS ix_stage_gtfs_stop_times_trip_stop_seq ON gtfs_stop_times (trip_id, stop_id, stop_sequence)",
+ "CREATE INDEX IF NOT EXISTS ix_stage_gtfs_trips_trip ON gtfs_trips (trip_id)",
+ "CREATE INDEX IF NOT EXISTS ix_stage_gtfs_trips_service_trip ON gtfs_trips (service_id, trip_id)",
+ "CREATE INDEX IF NOT EXISTS ix_stage_gtfs_trips_route_service ON gtfs_trips (route_id, service_id)",
+ ]:
+ connection.execute(statement)
+ _emit_progress(progress_callback, "gtfs_stage_indexes_completed", "Built GTFS stage indexes.", None, None, None)
+
+
+def _activate_staged_gtfs(
+ session: Session,
+ source: Source,
+ dataset: Dataset,
+ stage_path: Path,
+ summary: dict[str, Any],
+ progress_callback: ProgressCallback | None = None,
+) -> None:
+ if not stage_path.exists():
+ raise FileNotFoundError(f"GTFS staging database is missing: {stage_path}")
+ dataset = session.get(Dataset, dataset.id) or dataset
+ source = session.get(Source, source.id) or source
+ replaced_datasets = [existing for existing in list(source.datasets) if existing.id != dataset.id and existing.kind == "gtfs"]
+ for existing in source.datasets:
+ if existing.id != dataset.id:
+ existing.is_active = False
+ copy_stop_times = _copy_stop_times_to_main(summary)
+ heavy_index_drop = copy_stop_times and _should_drop_indexes_for_activation(stage_path)
+ if heavy_index_drop:
+ _emit_progress(progress_callback, "gtfs_activation_indexes_dropped", "Dropping heavy GTFS lookup indexes before bulk activation.", None, None, None)
+ _drop_gtfs_bulk_indexes(session.connection())
+ try:
+ if replaced_datasets:
+ _emit_progress(
+ progress_callback,
+ "gtfs_activation_pruning_replaced",
+ f"Pruning {len(replaced_datasets)} replaced GTFS dataset(s) before activation.",
+ None,
+ None,
+ {"dataset_ids": [dataset.id for dataset in replaced_datasets]},
+ )
+ from app.data_management import _delete_dataset_files, _delete_dataset_rows, _detach_update_checks_for_dataset
+
+ for old_dataset in replaced_datasets:
+ _detach_update_checks_for_dataset(session, old_dataset.id)
+ _delete_dataset_rows(session, old_dataset)
+ _delete_dataset_files(old_dataset)
+ session.delete(old_dataset)
+ with sqlite3.connect(stage_path) as stage_connection:
+ _copy_stage_table(
+ session,
+ stage_connection,
+ dataset.id,
+ "gtfs_agencies",
+ ["agency_id", "name", "url", "timezone"],
+ progress_callback,
+ )
+ _copy_stage_table(
+ session,
+ stage_connection,
+ dataset.id,
+ "gtfs_stops",
+ ["stop_id", "name", "lat", "lon", "parent_station"],
+ progress_callback,
+ )
+ _copy_stage_table(
+ session,
+ stage_connection,
+ dataset.id,
+ "gtfs_calendars",
+ ["service_id", "monday", "tuesday", "wednesday", "thursday", "friday", "saturday", "sunday", "start_date", "end_date"],
+ progress_callback,
+ )
+ _copy_stage_table(
+ session,
+ stage_connection,
+ dataset.id,
+ "gtfs_calendar_dates",
+ ["service_id", "date", "exception_type"],
+ progress_callback,
+ )
+ _copy_stage_table(
+ session,
+ stage_connection,
+ dataset.id,
+ "gtfs_trips",
+ ["route_id", "trip_id", "service_id", "shape_id"],
+ progress_callback,
+ )
+ _copy_stage_table(
+ session,
+ stage_connection,
+ dataset.id,
+ "gtfs_shapes",
+ ["shape_id", "geometry_geojson", "min_lon", "min_lat", "max_lon", "max_lat"],
+ progress_callback,
+ )
+ if copy_stop_times:
+ _copy_stage_table(
+ session,
+ stage_connection,
+ dataset.id,
+ "gtfs_stop_times",
+ ["trip_id", "stop_id", "stop_sequence", "arrival_time", "departure_time", "arrival_seconds", "departure_seconds"],
+ progress_callback,
+ )
+ else:
+ _emit_progress(
+ progress_callback,
+ "gtfs_activation_sidecar_stop_times",
+ "Kept gtfs_stop_times in sidecar storage.",
+ None,
+ None,
+ {"table": "gtfs_stop_times", "sidecar_path": str(stage_path)},
+ )
+ _copy_stage_table(
+ session,
+ stage_connection,
+ dataset.id,
+ "gtfs_routes",
+ [
+ "route_id",
+ "agency_id",
+ "short_name",
+ "long_name",
+ "route_type",
+ "mode",
+ "route_scope",
+ "operator_name",
+ "geometry_geojson",
+ "min_lon",
+ "min_lat",
+ "max_lon",
+ "max_lat",
+ "route_key",
+ "operator_key",
+ ],
+ progress_callback,
+ )
+ finally:
+ if heavy_index_drop:
+ _emit_progress(progress_callback, "gtfs_activation_indexes_rebuilding", "Rebuilding GTFS lookup indexes after bulk activation.", None, None, None)
+ _create_gtfs_bulk_indexes(session.connection())
+ dataset.status = "imported"
+ dataset.is_active = True
+ dataset.metadata_json = json.dumps(summary, indent=2)
+ source.status = "ok"
+ source.last_error = None
+ session.flush()
+ refresh_postgis_geometries(session, dataset_id=dataset.id, tables=["gtfs_stops", "gtfs_routes", "gtfs_shapes"])
+ analyze_postgresql_tables(session, ["gtfs_stops", "gtfs_routes", "gtfs_shapes", "gtfs_trips", "gtfs_stop_times"])
+ if copy_stop_times and not settings.gtfs_keep_activation_stage:
+ try:
+ stage_path.unlink()
+ except FileNotFoundError:
+ pass
+
+
+def _copy_stop_times_to_main(summary: dict[str, Any]) -> bool:
+    """Decide whether ``gtfs_stop_times`` rows should be copied into the main database.
+
+    Looks up the storage descriptor stored under ``GTFS_STORAGE_METADATA_KEY`` in
+    ``summary``. A per-table entry takes precedence over the global ``mode``.
+
+    Returns:
+        ``True`` when no descriptor dict is present, or when the descriptor does
+        not mark stop times as sidecar-backed; ``False`` otherwise.
+    """
+    descriptor = summary.get(GTFS_STORAGE_METADATA_KEY)
+    if isinstance(descriptor, dict):
+        per_table = descriptor.get("tables")
+        if isinstance(per_table, dict):
+            return per_table.get("gtfs_stop_times") != "sidecar"
+        return descriptor.get("mode") != GTFS_STORAGE_SIDECAR_STOP_TIMES
+    # No usable descriptor: copying is the default.
+    return True
+
+
+def _copy_stage_table(
+    session: Session,
+    stage_connection: sqlite3.Connection,
+    dataset_id: int,
+    table: str,
+    columns: list[str],
+    progress_callback: ProgressCallback | None,
+) -> None:
+    """Copy every row of a staged table into the same-named main table.
+
+    Rows are read from ``stage_connection`` in batches of
+    ``GTFS_STAGE_BATCH_SIZE``, tagged with ``dataset_id`` and inserted through
+    ``session``. A ``gtfs_activation_chunk`` progress event carrying the
+    running row total is emitted after each batch.
+    """
+    select_list = ", ".join(columns)
+    bind_list = ", ".join(f":{name}" for name in ["dataset_id", *columns])
+    insert_sql = f"INSERT INTO {table} (dataset_id, {select_list}) VALUES ({bind_list})"
+    cursor = stage_connection.execute(f"SELECT {select_list} FROM {table}")
+    total_copied = 0
+    # fetchmany() yields an empty list once the cursor is exhausted, which ends the loop.
+    for batch in iter(lambda: cursor.fetchmany(GTFS_STAGE_BATCH_SIZE), []):
+        parameters = [{"dataset_id": dataset_id, **dict(zip(columns, record))} for record in batch]
+        session.execute(text(insert_sql), parameters)
+        total_copied += len(batch)
+        _emit_progress(
+            progress_callback,
+            "gtfs_activation_chunk",
+            f"Activated {table} chunk.",
+            total_copied,
+            None,
+            {"table": table, "rows": total_copied},
+        )
+
+
+def _should_drop_indexes_for_activation(stage_path: Path) -> bool:
+    """Tell whether the GTFS lookup indexes should be dropped before bulk activation.
+
+    Never drops on PostgreSQL. Otherwise the staged SQLite file is inspected and
+    the answer is ``True`` once it holds at least 250,000 stop times or at
+    least 100,000 trips. Any SQLite error while inspecting the stage yields
+    ``False``.
+    """
+    if settings.is_postgresql_database:
+        return False
+    try:
+        connection = sqlite3.connect(stage_path)
+    except sqlite3.Error:
+        return False
+    try:
+        stop_times = connection.execute("SELECT COUNT(*) FROM gtfs_stop_times").fetchone()[0]
+        trips = connection.execute("SELECT COUNT(*) FROM gtfs_trips").fetchone()[0]
+    except sqlite3.Error:
+        return False
+    finally:
+        # A sqlite3 connection used as a context manager only commits or rolls
+        # back; it never closes. Close explicitly so the handle on the stage
+        # file is released.
+        connection.close()
+    return int(stop_times or 0) >= 250_000 or int(trips or 0) >= 100_000
+
+
+def _drop_gtfs_bulk_indexes(connection) -> None:
+    """Issue ``DROP INDEX IF EXISTS`` for each GTFS lookup index, in a fixed order.
+
+    Args:
+        connection: object exposing ``exec_driver_sql(statement)``.
+    """
+    index_names = (
+        "ix_gtfs_stop_times_stop",
+        "ix_gtfs_stop_times_stop_depart_trip",
+        "ix_gtfs_stop_times_stop_arrival",
+        "ix_gtfs_stop_times_stop_arrive_trip",
+        "ix_gtfs_stop_times_trip_seq",
+        "ix_gtfs_stop_times_trip_stop_seq",
+        "ix_gtfs_trips_dataset_trip",
+        "ix_gtfs_trips_dataset_route",
+        "ix_gtfs_trips_dataset_service",
+        "ix_gtfs_trips_dataset_route_service",
+        "ix_gtfs_routes_dataset_route",
+        "ix_gtfs_shapes_dataset_shape",
+        "ix_gtfs_calendars_dataset_service_dates",
+        "ix_gtfs_calendar_dates_dataset_date",
+    )
+    drop_statements = [f"DROP INDEX IF EXISTS {name}" for name in index_names]
+    for statement in drop_statements:
+        connection.exec_driver_sql(statement)
+
+
+def _create_gtfs_bulk_indexes(connection) -> None:
+    """Issue ``CREATE INDEX IF NOT EXISTS`` for each GTFS lookup index, in a fixed order.
+
+    Args:
+        connection: object exposing ``exec_driver_sql(statement)``.
+    """
+    # (index name, table, indexed columns); one CREATE statement is built per entry.
+    index_specs = (
+        ("ix_gtfs_stop_times_stop", "gtfs_stop_times", "dataset_id, stop_id, departure_seconds, trip_id, stop_sequence"),
+        ("ix_gtfs_stop_times_stop_depart_trip", "gtfs_stop_times", "dataset_id, stop_id, departure_seconds, trip_id"),
+        ("ix_gtfs_stop_times_stop_arrival", "gtfs_stop_times", "dataset_id, stop_id, arrival_seconds, trip_id, stop_sequence"),
+        ("ix_gtfs_stop_times_stop_arrive_trip", "gtfs_stop_times", "dataset_id, stop_id, arrival_seconds, trip_id"),
+        ("ix_gtfs_stop_times_trip_seq", "gtfs_stop_times", "dataset_id, trip_id, stop_sequence"),
+        ("ix_gtfs_stop_times_trip_stop_seq", "gtfs_stop_times", "dataset_id, trip_id, stop_id, stop_sequence"),
+        ("ix_gtfs_trips_dataset_trip", "gtfs_trips", "dataset_id, trip_id"),
+        ("ix_gtfs_trips_dataset_route", "gtfs_trips", "dataset_id, route_id"),
+        ("ix_gtfs_trips_dataset_service", "gtfs_trips", "dataset_id, service_id, trip_id"),
+        ("ix_gtfs_trips_dataset_route_service", "gtfs_trips", "dataset_id, route_id, service_id"),
+        ("ix_gtfs_routes_dataset_route", "gtfs_routes", "dataset_id, route_id"),
+        ("ix_gtfs_shapes_dataset_shape", "gtfs_shapes", "dataset_id, shape_id"),
+        ("ix_gtfs_calendars_dataset_service_dates", "gtfs_calendars", "dataset_id, service_id, start_date, end_date"),
+        ("ix_gtfs_calendar_dates_dataset_date", "gtfs_calendar_dates", "dataset_id, date, service_id, exception_type"),
+    )
+    for index_name, table_name, column_list in index_specs:
+        connection.exec_driver_sql(f"CREATE INDEX IF NOT EXISTS {index_name} ON {table_name} ({column_list})")
+
+
+def _stage_agencies(
+    connection: sqlite3.Connection,
+    zf: zipfile.ZipFile,
+    names: dict[str, str],
+    progress_callback: ProgressCallback | None,
+) -> dict[str, str]:
+    """Stage ``agency.txt`` into ``gtfs_agencies`` and return an id-to-name map.
+
+    A row without ``agency_id`` gets the synthetic id ``agency_<row index>``;
+    a row without ``agency_name`` is named after its id.
+    """
+    _emit_progress(progress_callback, "gtfs_file_started", "Reading agency.txt.", None, None, {"file": "agency.txt"})
+    name_by_agency: dict[str, str] = {}
+    staged = []
+    for position, record in enumerate(_read_gtfs_csv(zf, names, "agency.txt")):
+        agency_id = first_nonempty(record.get("agency_id"), f"agency_{position}")
+        display_name = first_nonempty(record.get("agency_name"), agency_id)
+        name_by_agency[agency_id] = display_name
+        url = record.get("agency_url") or None
+        tz_name = record.get("agency_timezone") or None
+        staged.append((agency_id, display_name, url, tz_name))
+    connection.executemany("INSERT INTO gtfs_agencies (agency_id, name, url, timezone) VALUES (?, ?, ?, ?)", staged)
+    total = len(staged)
+    _emit_progress(progress_callback, "gtfs_file_completed", "Imported agency.txt.", total, None, {"file": "agency.txt", "rows": total})
+    return name_by_agency
+
+
+def _stage_calendars(
+    connection: sqlite3.Connection,
+    zf: zipfile.ZipFile,
+    names: dict[str, str],
+    progress_callback: ProgressCallback | None,
+) -> int:
+    """Stage ``calendar.txt`` into ``gtfs_calendars`` and return the stored row count.
+
+    Rows lacking a ``service_id`` or a parseable ``start_date``/``end_date``
+    are skipped. Weekday flags are stored as 0/1 integers.
+    """
+    _emit_progress(progress_callback, "gtfs_file_started", "Reading calendar.txt.", None, None, {"file": "calendar.txt"})
+    weekdays = ("monday", "tuesday", "wednesday", "thursday", "friday", "saturday", "sunday")
+    staged = []
+    for record in _read_gtfs_csv(zf, names, "calendar.txt"):
+        service_id = record.get("service_id") or ""
+        start_date = _int_or_none(record.get("start_date"))
+        end_date = _int_or_none(record.get("end_date"))
+        if not (service_id and start_date is not None and end_date is not None):
+            continue
+        day_flags = [int(_bool_flag(record.get(day))) for day in weekdays]
+        staged.append((service_id, *day_flags, start_date, end_date))
+    connection.executemany(
+        """
+        INSERT INTO gtfs_calendars
+        (service_id, monday, tuesday, wednesday, thursday, friday, saturday, sunday, start_date, end_date)
+        VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
+        """,
+        staged,
+    )
+    total = len(staged)
+    _emit_progress(progress_callback, "gtfs_file_completed", "Imported calendar.txt.", total, None, {"file": "calendar.txt", "rows": total})
+    return total
+
+
+def _stage_calendar_dates(
+    connection: sqlite3.Connection,
+    zf: zipfile.ZipFile,
+    names: dict[str, str],
+    progress_callback: ProgressCallback | None,
+) -> int:
+    """Stage ``calendar_dates.txt`` into ``gtfs_calendar_dates`` via ``_stage_chunked_rows``.
+
+    Rows without a ``service_id`` or with an unparseable ``date`` or
+    ``exception_type`` are rejected. Returns the number of rows stored.
+    """
+
+    def build_row(record):
+        return (
+            record.get("service_id") or "",
+            _int_or_none(record.get("date")),
+            _int_or_none(record.get("exception_type")),
+        )
+
+    def is_complete(candidate):
+        service_id, date, exception_type = candidate
+        return bool(service_id) and date is not None and exception_type is not None
+
+    return _stage_chunked_rows(
+        connection=connection,
+        zf=zf,
+        names=names,
+        basename="calendar_dates.txt",
+        insert_sql="INSERT INTO gtfs_calendar_dates (service_id, date, exception_type) VALUES (?, ?, ?)",
+        row_factory=build_row,
+        validator=is_complete,
+        progress_callback=progress_callback,
+    )
+
+
+def _stage_stops(
+    connection: sqlite3.Connection,
+    zf: zipfile.ZipFile,
+    names: dict[str, str],
+    progress_callback: ProgressCallback | None,
+) -> tuple[dict[str, tuple[float, float, str]], int]:
+    """Stage ``stops.txt`` into ``gtfs_stops``.
+
+    Returns:
+        A pair ``(stops_by_id, stored)``. ``stops_by_id`` maps a stop id to
+        ``(lon, lat, label)`` and only contains stops with both coordinates;
+        ``label`` falls back to the stop id when the name is empty. ``stored``
+        counts every inserted row, with or without coordinates.
+    """
+    _emit_progress(progress_callback, "gtfs_file_started", "Reading stops.txt.", None, None, {"file": "stops.txt"})
+    located: dict[str, tuple[float, float, str]] = {}
+    staged = []
+    for record in _read_gtfs_csv(zf, names, "stops.txt"):
+        stop_id = record.get("stop_id", "")
+        if not stop_id:
+            continue
+        lat = _float_or_none(record.get("stop_lat"))
+        lon = _float_or_none(record.get("stop_lon"))
+        label = record.get("stop_name") or None
+        has_position = lat is not None and lon is not None
+        if has_position:
+            located[stop_id] = (lon, lat, label or stop_id)
+        parent = record.get("parent_station") or None
+        staged.append((stop_id, label, lat, lon, parent))
+    connection.executemany("INSERT INTO gtfs_stops (stop_id, name, lat, lon, parent_station) VALUES (?, ?, ?, ?, ?)", staged)
+    total = len(staged)
+    _emit_progress(progress_callback, "gtfs_file_completed", "Imported stops.txt.", total, None, {"file": "stops.txt", "rows": total})
+    return located, total
+
+
+def _stage_trips(
+    connection: sqlite3.Connection,
+    zf: zipfile.ZipFile,
+    names: dict[str, str],
+    progress_callback: ProgressCallback | None,
+) -> tuple[dict[str, list[str]], dict[str, str], dict[str, str], int]:
+    """Stage ``trips.txt`` into ``gtfs_trips`` in batches of ``GTFS_STAGE_BATCH_SIZE``.
+
+    Rows missing ``route_id`` or ``trip_id`` are skipped.
+
+    Returns:
+        ``(trips_by_route, first_shape_by_route, first_trip_by_route, stored)``:
+        every trip id per route, the first non-empty shape id seen per route,
+        the first trip id seen per route, and the number of rows stored.
+    """
+    _emit_progress(progress_callback, "gtfs_file_started", "Reading trips.txt.", None, None, {"file": "trips.txt"})
+    insert_sql = "INSERT INTO gtfs_trips (route_id, trip_id, service_id, shape_id) VALUES (?, ?, ?, ?)"
+    route_trips: dict[str, list[str]] = {}
+    first_shape: dict[str, str] = {}
+    first_trip: dict[str, str] = {}
+    pending = []
+    stored = 0
+    for record in _read_gtfs_csv(zf, names, "trips.txt"):
+        route_id = record.get("route_id", "")
+        trip_id = record.get("trip_id", "")
+        if not (route_id and trip_id):
+            continue
+        route_trips.setdefault(route_id, []).append(trip_id)
+        first_trip.setdefault(route_id, trip_id)
+        shape_id = record.get("shape_id") or ""
+        if shape_id:
+            first_shape.setdefault(route_id, shape_id)
+        pending.append((route_id, trip_id, record.get("service_id") or None, shape_id or None))
+        stored += 1
+        if len(pending) < GTFS_STAGE_BATCH_SIZE:
+            continue
+        connection.executemany(insert_sql, pending)
+        pending.clear()
+        _emit_progress(progress_callback, "gtfs_file_chunk", "Imported trips.txt chunk.", stored, None, {"file": "trips.txt", "rows": stored})
+    if pending:
+        connection.executemany(insert_sql, pending)
+    _emit_progress(progress_callback, "gtfs_file_completed", "Imported trips.txt.", stored, None, {"file": "trips.txt", "rows": stored})
+    return route_trips, first_shape, first_trip, stored
+
+
+def _read_shapes_with_progress(
+    zf: zipfile.ZipFile,
+    names: dict[str, str],
+    progress_callback: ProgressCallback | None,
+) -> dict[str, list[tuple[float, float]]]:
+    """Run ``_read_shapes`` bracketed by started/completed progress events."""
+    file_info = {"file": "shapes.txt"}
+    _emit_progress(progress_callback, "gtfs_file_started", "Reading shapes.txt.", None, None, file_info)
+    shapes_by_id = _read_shapes(zf, names)
+    shape_total = len(shapes_by_id)
+    _emit_progress(progress_callback, "gtfs_file_completed", "Read shapes.txt.", shape_total, None, {"file": "shapes.txt", "shapes": shape_total})
+    return shapes_by_id
+
+
+def _stage_shapes(
+    connection: sqlite3.Connection,
+    shapes_by_id: dict[str, list[tuple[float, float]]],
+    progress_callback: ProgressCallback | None,
+) -> int:
+    """Stage already-parsed shapes into ``gtfs_shapes`` in batches of 5000.
+
+    Shapes with fewer than two coordinates, or for which
+    ``geometry_json_and_bbox`` yields no geometry text, are skipped.
+    Returns the number of rows stored.
+    """
+    insert_sql = "INSERT INTO gtfs_shapes (shape_id, geometry_geojson, min_lon, min_lat, max_lon, max_lat) VALUES (?, ?, ?, ?, ?, ?)"
+    pending = []
+    stored = 0
+    for shape_id, coords in shapes_by_id.items():
+        if len(coords) < 2:
+            continue
+        geometry_text, bbox = geometry_json_and_bbox(LineString(coords))
+        if geometry_text is None:
+            continue
+        pending.append((shape_id, geometry_text, bbox[0], bbox[1], bbox[2], bbox[3]))
+        stored += 1
+        if len(pending) < 5000:
+            continue
+        connection.executemany(insert_sql, pending)
+        pending.clear()
+        _emit_progress(progress_callback, "gtfs_file_chunk", "Imported shapes chunk.", stored, None, {"file": "shapes.txt", "rows": stored})
+    if pending:
+        connection.executemany(insert_sql, pending)
+    _emit_progress(progress_callback, "gtfs_file_completed", "Imported shapes.", stored, None, {"file": "shapes.txt", "rows": stored})
+    return stored
+
+
+def _stage_stop_times(
+    connection: sqlite3.Connection,
+    zf: zipfile.ZipFile,
+    names: dict[str, str],
+    first_trip_ids: set[str],
+    progress_callback: ProgressCallback | None,
+) -> tuple[dict[str, list[str]], int, int]:
+    """Stage ``stop_times.txt`` into ``gtfs_stop_times`` and collect stop sequences.
+
+    The whole file is always scanned. At most
+    ``settings.gtfs_stop_times_import_limit`` rows are stored; a limit of zero
+    or less means unlimited. Rows missing ``trip_id``, ``stop_id`` or a
+    parseable ``stop_sequence`` are counted as seen but otherwise ignored.
+
+    Returns:
+        ``(stopseq_by_trip, seen, stored)``, where ``stopseq_by_trip`` maps
+        each trip in ``first_trip_ids`` to its stop ids in sequence order.
+    """
+    _emit_progress(progress_callback, "gtfs_file_started", "Reading stop_times.txt.", None, None, {"file": "stop_times.txt"})
+    insert_sql = """
+        INSERT INTO gtfs_stop_times
+        (trip_id, stop_id, stop_sequence, arrival_time, departure_time, arrival_seconds, departure_seconds)
+        VALUES (?, ?, ?, ?, ?, ?, ?)
+        """
+    sequences: dict[str, list[tuple[int, str]]] = defaultdict(list)
+    pending = []
+    seen = 0
+    stored = 0
+    limit = settings.gtfs_stop_times_import_limit
+    for record in _read_gtfs_csv(zf, names, "stop_times.txt"):
+        seen += 1
+        trip_id = record.get("trip_id", "")
+        stop_id = record.get("stop_id", "")
+        sequence = _int_or_none(record.get("stop_sequence"))
+        if not (trip_id and stop_id) or sequence is None:
+            continue
+        if trip_id in first_trip_ids:
+            sequences[trip_id].append((sequence, stop_id))
+        if limit > 0 and stored >= limit:
+            # Storage cap reached: keep scanning for sequences and the seen count only.
+            continue
+        arrival = record.get("arrival_time") or None
+        departure = record.get("departure_time") or None
+        pending.append((trip_id, stop_id, sequence, arrival, departure, _time_seconds(arrival), _time_seconds(departure)))
+        stored += 1
+        if len(pending) < GTFS_STAGE_BATCH_SIZE:
+            continue
+        connection.executemany(insert_sql, pending)
+        pending.clear()
+        _emit_progress(progress_callback, "gtfs_file_chunk", "Imported stop_times.txt chunk.", stored, None, {"file": "stop_times.txt", "rows": stored, "seen": seen})
+    if pending:
+        connection.executemany(insert_sql, pending)
+    _emit_progress(progress_callback, "gtfs_file_completed", "Imported stop_times.txt.", stored, None, {"file": "stop_times.txt", "rows": stored, "seen": seen})
+    ordered_stops = {}
+    for trip_id, entries in sequences.items():
+        ordered_stops[trip_id] = [stop_id for _, stop_id in sorted(entries)]
+    return ordered_stops, seen, stored
+
+
+def _stage_routes(
+    connection: sqlite3.Connection,
+    routes_raw: list[dict[str, str]],
+    agency_names: dict[str, str],
+    stops_by_id: dict[str, tuple[float, float, str]],
+    trips_by_route: dict[str, list[str]],
+    first_shape_by_route: dict[str, str],
+    first_trip_by_route: dict[str, str],
+    shapes_by_id: dict[str, list[tuple[float, float]]],
+    stopseq_by_trip: dict[str, list[str]],
+    progress_callback: ProgressCallback | None,
+) -> int:
+    """Stage pre-read ``routes.txt`` rows into ``gtfs_routes``.
+
+    Each route gets a derived mode, scope, operator name, geometry (from
+    ``_route_geometry``) with bounding box, and normalized route/operator keys.
+    Rows without ``route_id`` are skipped. Returns the number of rows stored.
+    """
+    _emit_progress(progress_callback, "gtfs_file_started", "Reading routes.txt.", None, None, {"file": "routes.txt"})
+    staged = []
+    for record in routes_raw:
+        route_id = record.get("route_id", "")
+        if not route_id:
+            continue
+        route_type = _int_or_none(record.get("route_type"))
+        mode = _gtfs_mode(route_type)
+        agency_id = record.get("agency_id") or None
+        agency_key = agency_id or ""
+        # Unknown agencies fall back to the raw agency id as operator name.
+        operator = agency_names.get(agency_key, agency_key)
+        short_name = record.get("route_short_name") or None
+        long_name = record.get("route_long_name") or None
+        route_scope = infer_osm_route_scope(mode=mode, ref=short_name, name=long_name, network=operator)
+        geometry = _route_geometry(route_id, first_shape_by_route, first_trip_by_route, shapes_by_id, stopseq_by_trip, stops_by_id)
+        if geometry is None:
+            geometry_text, bbox = None, (None, None, None, None)
+        else:
+            geometry_text, bbox = geometry_json_and_bbox(geometry)
+        route_key = norm_ref(short_name) or norm_text(long_name) or norm_ref(route_id)
+        operator_key = norm_text(operator)
+        staged.append(
+            (
+                route_id,
+                agency_id,
+                short_name,
+                long_name,
+                route_type,
+                mode,
+                route_scope,
+                operator or None,
+                geometry_text,
+                bbox[0],
+                bbox[1],
+                bbox[2],
+                bbox[3],
+                route_key,
+                operator_key,
+            )
+        )
+    connection.executemany(
+        """
+        INSERT INTO gtfs_routes
+        (route_id, agency_id, short_name, long_name, route_type, mode, route_scope, operator_name, geometry_geojson, min_lon, min_lat, max_lon, max_lat, route_key, operator_key)
+        VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
+        """,
+        staged,
+    )
+    total = len(staged)
+    _emit_progress(progress_callback, "gtfs_file_completed", "Imported routes.txt.", total, None, {"file": "routes.txt", "rows": total})
+    return total
+
+
+def _stage_chunked_rows(
+    connection: sqlite3.Connection,
+    zf: zipfile.ZipFile,
+    names: dict[str, str],
+    basename: str,
+    insert_sql: str,
+    row_factory,
+    validator,
+    progress_callback: ProgressCallback | None,
+) -> int:
+    """Generic batched staging of one GTFS CSV file.
+
+    Args:
+        basename: GTFS file name to read (e.g. ``calendar_dates.txt``).
+        insert_sql: parameterized INSERT executed with each batch.
+        row_factory: maps a raw CSV dict to the tuple to insert.
+        validator: predicate on that tuple; falsy results drop the row.
+
+    Returns:
+        Number of rows stored. Batches hold ``GTFS_STAGE_BATCH_SIZE`` rows.
+    """
+    _emit_progress(progress_callback, "gtfs_file_started", f"Reading {basename}.", None, None, {"file": basename})
+    pending = []
+    stored = 0
+    for record in _read_gtfs_csv(zf, names, basename):
+        candidate = row_factory(record)
+        if validator(candidate):
+            pending.append(candidate)
+            stored += 1
+            if len(pending) >= GTFS_STAGE_BATCH_SIZE:
+                connection.executemany(insert_sql, pending)
+                pending.clear()
+                _emit_progress(progress_callback, "gtfs_file_chunk", f"Imported {basename} chunk.", stored, None, {"file": basename, "rows": stored})
+    if pending:
+        connection.executemany(insert_sql, pending)
+    _emit_progress(progress_callback, "gtfs_file_completed", f"Imported {basename}.", stored, None, {"file": basename, "rows": stored})
+    return stored
+
+
+def _emit_progress(
+    progress_callback: ProgressCallback | None,
+    event_type: str,
+    message: str,
+    progress_current: int | None = None,
+    progress_total: int | None = None,
+    metadata: dict[str, Any] | None = None,
+) -> None:
+    """Forward a progress event to ``progress_callback``; do nothing when it is ``None``."""
+    if progress_callback is None:
+        return
+    progress_callback(event_type, message, progress_current, progress_total, metadata)
+
+
+def _read_gtfs_csv(zf: zipfile.ZipFile, names: dict[str, str], basename: str) -> Iterator[dict[str, str]]:
+    """Lazily iterate the rows of one CSV member of a GTFS archive.
+
+    ``names`` maps a GTFS base name to the archive member name. An empty
+    iterator is returned when ``basename`` is not in ``names``. The member is
+    decoded as UTF-8 with an optional BOM; header names and cell values are
+    stripped, missing cells become ``""`` and surplus unnamed cells are dropped.
+    """
+
+    def _rows() -> Iterator[dict[str, str]]:
+        with zf.open(names[basename], "r") as raw:
+            decoded = io.TextIOWrapper(raw, encoding="utf-8-sig", newline="")
+            for record in csv.DictReader(decoded):
+                cleaned: dict[str, str] = {}
+                for key, cell in record.items():
+                    # DictReader files surplus cells under the key None.
+                    if key is None:
+                        continue
+                    cleaned[str(key).strip()] = (cell or "").strip()
+                yield cleaned
+
+    if basename in names:
+        return _rows()
+    return iter(())
+
+
+def _record_importer_metadata(dataset: Dataset, shapes_count: int | None = None) -> None:
+    """Stamp the importer version (and optionally the shape count) into the dataset metadata.
+
+    Existing ``dataset.metadata_json`` is kept when it decodes to a JSON
+    object. Anything else (invalid JSON, or valid JSON that is not an object)
+    is replaced by a fresh object.
+
+    Args:
+        dataset: object whose ``metadata_json`` attribute is read and rewritten.
+        shapes_count: stored under ``"shapes"`` when not ``None``.
+    """
+    metadata: dict[str, Any] = {}
+    if dataset.metadata_json:
+        try:
+            parsed = json.loads(dataset.metadata_json)
+        except json.JSONDecodeError:
+            parsed = None
+        # "[]", "null" or a bare scalar parse fine but cannot take item
+        # assignment by key; treat them like unreadable metadata.
+        if isinstance(parsed, dict):
+            metadata = parsed
+    metadata["importer"] = GTFS_IMPORTER_VERSION
+    if shapes_count is not None:
+        metadata["shapes"] = shapes_count
+    dataset.metadata_json = json.dumps(metadata, indent=2)
+
+
+def _import_agencies(session: Session, dataset_id: int, rows: list[dict[str, str]]) -> dict[str, str]:
+    """Persist agency rows as ``GtfsAgency`` objects and return an id-to-name map.
+
+    A row without ``agency_id`` gets the synthetic id ``agency_<row index>``;
+    a row without ``agency_name`` is named after its id.
+    """
+    name_by_agency: dict[str, str] = {}
+    pending: list[GtfsAgency] = []
+    for position, record in enumerate(rows):
+        agency_id = first_nonempty(record.get("agency_id"), f"agency_{position}")
+        display_name = first_nonempty(record.get("agency_name"), agency_id)
+        name_by_agency[agency_id] = display_name
+        agency = GtfsAgency(
+            dataset_id=dataset_id,
+            agency_id=agency_id,
+            name=display_name,
+            url=record.get("agency_url") or None,
+            timezone=record.get("agency_timezone") or None,
+        )
+        pending.append(agency)
+    if pending:
+        session.bulk_save_objects(pending)
+    return name_by_agency
+
+
+def _import_calendars(session: Session, dataset_id: int, rows: list[dict[str, str]]) -> int:
+    """Persist calendar rows as ``GtfsCalendar`` objects and return how many were saved.
+
+    Rows lacking a ``service_id`` or a parseable ``start_date``/``end_date``
+    are skipped.
+    """
+    weekdays = ("monday", "tuesday", "wednesday", "thursday", "friday", "saturday", "sunday")
+    pending: list[GtfsCalendar] = []
+    for record in rows:
+        service_id = record.get("service_id") or ""
+        start_date = _int_or_none(record.get("start_date"))
+        end_date = _int_or_none(record.get("end_date"))
+        if not (service_id and start_date is not None and end_date is not None):
+            continue
+        day_flags = {day: _bool_flag(record.get(day)) for day in weekdays}
+        pending.append(
+            GtfsCalendar(
+                dataset_id=dataset_id,
+                service_id=service_id,
+                start_date=start_date,
+                end_date=end_date,
+                **day_flags,
+            )
+        )
+    if pending:
+        session.bulk_save_objects(pending)
+    return len(pending)
+
+
+def _import_calendar_dates(session: Session, dataset_id: int, rows: list[dict[str, str]]) -> int:
+    """Persist calendar exceptions as ``GtfsCalendarDate`` objects, 5000 per bulk save.
+
+    Rows without a ``service_id`` or with an unparseable ``date`` or
+    ``exception_type`` are skipped. Returns the number of objects saved.
+    """
+    pending: list[GtfsCalendarDate] = []
+    for record in rows:
+        service_id = record.get("service_id") or ""
+        date = _int_or_none(record.get("date"))
+        exception_type = _int_or_none(record.get("exception_type"))
+        if not (service_id and date is not None and exception_type is not None):
+            continue
+        pending.append(
+            GtfsCalendarDate(
+                dataset_id=dataset_id,
+                service_id=service_id,
+                date=date,
+                exception_type=exception_type,
+            )
+        )
+    batch_size = 5000
+    offset = 0
+    while offset < len(pending):
+        session.bulk_save_objects(pending[offset : offset + batch_size])
+        offset += batch_size
+    return len(pending)
+
+
+def _import_stops(session: Session, dataset_id: int, rows: list[dict[str, str]]) -> dict[str, tuple[float, float, str]]:
+    """Persist stop rows as ``GtfsStop`` objects.
+
+    Returns:
+        Mapping of stop id to ``(lon, lat, label)`` for the stops that have
+        both coordinates; ``label`` falls back to the stop id when the name is
+        empty. Stops without coordinates are still saved.
+    """
+    located: dict[str, tuple[float, float, str]] = {}
+    pending: list[GtfsStop] = []
+    for record in rows:
+        stop_id = record.get("stop_id", "")
+        if not stop_id:
+            continue
+        lat = _float_or_none(record.get("stop_lat"))
+        lon = _float_or_none(record.get("stop_lon"))
+        label = record.get("stop_name") or None
+        has_position = lat is not None and lon is not None
+        if has_position:
+            located[stop_id] = (lon, lat, label or stop_id)
+        stop = GtfsStop(
+            dataset_id=dataset_id,
+            stop_id=stop_id,
+            name=label,
+            lat=lat,
+            lon=lon,
+            parent_station=record.get("parent_station") or None,
+        )
+        pending.append(stop)
+    if pending:
+        session.bulk_save_objects(pending)
+    return located
+
+
+def _import_trips(
+    session: Session, dataset_id: int, rows: list[dict[str, str]]
+) -> tuple[dict[str, list[str]], dict[str, str], dict[str, str]]:
+    """Persist trip rows as ``GtfsTrip`` objects, 5000 per bulk save.
+
+    Rows missing ``route_id`` or ``trip_id`` are skipped.
+
+    Returns:
+        ``(trips_by_route, first_shape_by_route, first_trip_by_route)``: every
+        trip id per route, the first non-empty shape id seen per route, and
+        the first trip id seen per route.
+    """
+    route_trips: dict[str, list[str]] = {}
+    first_shape: dict[str, str] = {}
+    first_trip: dict[str, str] = {}
+    pending: list[GtfsTrip] = []
+    for record in rows:
+        route_id = record.get("route_id", "")
+        trip_id = record.get("trip_id", "")
+        if not (route_id and trip_id):
+            continue
+        route_trips.setdefault(route_id, []).append(trip_id)
+        first_trip.setdefault(route_id, trip_id)
+        shape_id = record.get("shape_id") or ""
+        if shape_id:
+            first_shape.setdefault(route_id, shape_id)
+        trip = GtfsTrip(
+            dataset_id=dataset_id,
+            route_id=route_id,
+            trip_id=trip_id,
+            service_id=record.get("service_id") or None,
+            shape_id=shape_id or None,
+        )
+        pending.append(trip)
+    batch_size = 5000
+    offset = 0
+    while offset < len(pending):
+        session.bulk_save_objects(pending[offset : offset + batch_size])
+        offset += batch_size
+    return route_trips, first_shape, first_trip
+
+
+def _read_shapes(zf: zipfile.ZipFile, names: dict[str, str]) -> dict[str, list[tuple[float, float]]]:
+    """Read ``shapes.txt`` and return each shape's ``(lon, lat)`` points in sequence order.
+
+    Rows without a ``shape_id`` or without parseable coordinates are ignored.
+    A missing or unparseable ``shape_pt_sequence`` is treated as 0.
+    """
+    by_shape: dict[str, list[tuple[int, float, float]]] = defaultdict(list)
+    for row in _read_gtfs_csv(zf, names, "shapes.txt"):
+        shape_id = row.get("shape_id", "")
+        lat = _float_or_none(row.get("shape_pt_lat"))
+        lon = _float_or_none(row.get("shape_pt_lon"))
+        seq = _int_or_none(row.get("shape_pt_sequence"))
+        if shape_id and lat is not None and lon is not None:
+            by_shape[shape_id].append((seq if seq is not None else 0, lon, lat))
+    # Sort on the sequence number alone. Sorting whole tuples would break ties
+    # by longitude/latitude and scramble points whose sequence is missing or
+    # duplicated; the stable sort keeps such points in file order instead.
+    return {
+        shape_id: [(lon, lat) for _, lon, lat in sorted(points, key=lambda point: point[0])]
+        for shape_id, points in by_shape.items()
+    }
+
+
+def _import_shapes(session: Session, dataset_id: int, shapes_by_id: dict[str, list[tuple[float, float]]]) -> int:
+    """Persist parsed shapes as ``GtfsShape`` objects, 1000 per bulk save.
+
+    Shapes with fewer than two coordinates, or for which
+    ``geometry_json_and_bbox`` yields no geometry text, are skipped.
+    Returns the number of shapes saved.
+    """
+    pending: list[GtfsShape] = []
+    saved = 0
+    for shape_id, coords in shapes_by_id.items():
+        if len(coords) < 2:
+            continue
+        geometry_text, bbox = geometry_json_and_bbox(LineString(coords))
+        if geometry_text is None:
+            continue
+        min_lon, min_lat, max_lon, max_lat = bbox[0], bbox[1], bbox[2], bbox[3]
+        pending.append(
+            GtfsShape(
+                dataset_id=dataset_id,
+                shape_id=shape_id,
+                geometry_geojson=geometry_text,
+                min_lon=min_lon,
+                min_lat=min_lat,
+                max_lon=max_lon,
+                max_lat=max_lat,
+            )
+        )
+        saved += 1
+        if len(pending) < 1000:
+            continue
+        session.bulk_save_objects(pending)
+        pending.clear()
+    if pending:
+        session.bulk_save_objects(pending)
+    return saved
+
+
+def _import_stop_times(
+    session: Session,
+    dataset_id: int,
+    zf: zipfile.ZipFile,
+    names: dict[str, str],
+    first_trip_ids: set[str],
+) -> tuple[dict[str, list[str]], int, int]:
+    """Persist ``stop_times.txt`` rows as ``GtfsStopTime`` objects and collect stop sequences.
+
+    The whole file is always scanned. At most
+    ``settings.gtfs_stop_times_import_limit`` rows are saved (zero or less
+    means unlimited), 5000 per bulk save. Rows missing ``trip_id``,
+    ``stop_id`` or a parseable ``stop_sequence`` are counted as seen but
+    otherwise ignored.
+
+    Returns:
+        ``(stopseq_by_trip, seen, saved)``, where ``stopseq_by_trip`` maps
+        each trip in ``first_trip_ids`` to its stop ids in sequence order.
+    """
+    sequences: dict[str, list[tuple[int, str]]] = defaultdict(list)
+    pending: list[GtfsStopTime] = []
+    seen = 0
+    saved = 0
+    limit = settings.gtfs_stop_times_import_limit
+    for record in _read_gtfs_csv(zf, names, "stop_times.txt"):
+        seen += 1
+        trip_id = record.get("trip_id", "")
+        stop_id = record.get("stop_id", "")
+        sequence = _int_or_none(record.get("stop_sequence"))
+        if not (trip_id and stop_id) or sequence is None:
+            continue
+        if trip_id in first_trip_ids:
+            sequences[trip_id].append((sequence, stop_id))
+        if limit > 0 and saved >= limit:
+            # Storage cap reached: keep scanning for sequences and the seen count only.
+            continue
+        arrival = record.get("arrival_time") or None
+        departure = record.get("departure_time") or None
+        pending.append(
+            GtfsStopTime(
+                dataset_id=dataset_id,
+                trip_id=trip_id,
+                stop_id=stop_id,
+                stop_sequence=sequence,
+                arrival_time=arrival,
+                departure_time=departure,
+                arrival_seconds=_time_seconds(arrival),
+                departure_seconds=_time_seconds(departure),
+            )
+        )
+        saved += 1
+        if len(pending) < 5000:
+            continue
+        session.bulk_save_objects(pending)
+        pending.clear()
+    if pending:
+        session.bulk_save_objects(pending)
+    ordered_stops = {}
+    for trip_id, entries in sequences.items():
+        ordered_stops[trip_id] = [stop_id for _, stop_id in sorted(entries)]
+    return ordered_stops, seen, saved
+
+
+def _import_routes(
+    session: Session,
+    dataset_id: int,
+    routes_raw: list[dict[str, str]],
+    agency_names: dict[str, str],
+    stops_by_id: dict[str, tuple[float, float, str]],
+    trips_by_route: dict[str, list[str]],
+    first_shape_by_route: dict[str, str],
+    first_trip_by_route: dict[str, str],
+    shapes_by_id: dict[str, list[tuple[float, float]]],
+    stopseq_by_trip: dict[str, list[str]],
+) -> int:
+    """Persist route rows as ``GtfsRoute`` objects and return how many were saved.
+
+    Each route gets a derived mode, scope, operator name, geometry (from
+    ``_route_geometry``) with bounding box, and normalized route/operator keys.
+    Rows without ``route_id`` are skipped.
+    """
+    pending: list[GtfsRoute] = []
+    for record in routes_raw:
+        route_id = record.get("route_id", "")
+        if not route_id:
+            continue
+        route_type = _int_or_none(record.get("route_type"))
+        mode = _gtfs_mode(route_type)
+        agency_id = record.get("agency_id") or None
+        agency_key = agency_id or ""
+        # Unknown agencies fall back to the raw agency id as operator name.
+        operator = agency_names.get(agency_key, agency_key)
+        short_name = record.get("route_short_name") or None
+        long_name = record.get("route_long_name") or None
+        route_scope = infer_osm_route_scope(mode=mode, ref=short_name, name=long_name, network=operator)
+        geometry = _route_geometry(route_id, first_shape_by_route, first_trip_by_route, shapes_by_id, stopseq_by_trip, stops_by_id)
+        if geometry is None:
+            geometry_text, bbox = None, (None, None, None, None)
+        else:
+            geometry_text, bbox = geometry_json_and_bbox(geometry)
+        route_key = norm_ref(short_name) or norm_text(long_name) or norm_ref(route_id)
+        route = GtfsRoute(
+            dataset_id=dataset_id,
+            route_id=route_id,
+            agency_id=agency_id,
+            short_name=short_name,
+            long_name=long_name,
+            route_type=route_type,
+            mode=mode,
+            route_scope=route_scope,
+            operator_name=operator or None,
+            geometry_geojson=geometry_text,
+            min_lon=bbox[0],
+            min_lat=bbox[1],
+            max_lon=bbox[2],
+            max_lat=bbox[3],
+            route_key=route_key,
+            operator_key=norm_text(operator),
+        )
+        pending.append(route)
+    if pending:
+        session.bulk_save_objects(pending)
+    return len(pending)
+
+
+def _route_geometry(
+    route_id: str,
+    first_shape_by_route: dict[str, str],
+    first_trip_by_route: dict[str, str],
+    shapes_by_id: dict[str, list[tuple[float, float]]],
+    stopseq_by_trip: dict[str, list[str]],
+    stops_by_id: dict[str, tuple[float, float, str]],
+) -> Optional[LineString]:
+    """Build a line geometry for a route.
+
+    Prefers the route's first shape. Failing that, connects the known stops of
+    the route's first trip. Returns ``None`` when neither source provides at
+    least two points.
+    """
+    shape_coords = shapes_by_id.get(first_shape_by_route.get(route_id) or "", [])
+    if len(shape_coords) >= 2:
+        return LineString(shape_coords)
+    trip_stops = stopseq_by_trip.get(first_trip_by_route.get(route_id) or "", [])
+    stop_coords = []
+    for stop_id in trip_stops:
+        if stop_id not in stops_by_id:
+            continue
+        stop = stops_by_id[stop_id]
+        stop_coords.append((stop[0], stop[1]))
+    if len(stop_coords) < 2:
+        return None
+    return LineString(stop_coords)
+
+
+def _float_or_none(value: object) -> Optional[float]:
+    """Parse ``value`` as a finite float, returning ``None`` when that is not possible.
+
+    ``None``, blank text, unparseable text and the non-finite spellings
+    ("nan", "inf", "-Infinity", ...) all yield ``None``.
+    """
+    import math
+
+    if value is None:
+        return None
+    raw = str(value).strip()
+    if not raw:
+        return None
+    try:
+        number = float(raw)
+    except ValueError:
+        return None
+    # float() accepts "nan" and "inf"; such values pass "is not None" checks
+    # downstream and would corrupt coordinates and bounding boxes.
+    return number if math.isfinite(number) else None
+
+
+def _int_or_none(value: object) -> Optional[int]:
+    """Parse ``value`` as an integer, returning ``None`` when that is not possible.
+
+    Plain integer text is converted exactly. Decimal or exponent text
+    ("3.9", "1e3") is parsed as a float and truncated toward zero. ``None``,
+    blank text, unparseable text and non-finite values yield ``None``.
+    """
+    if value is None:
+        return None
+    raw = str(value).strip()
+    if not raw:
+        return None
+    try:
+        # Direct conversion first: going through float would silently corrupt
+        # integers above 2**53.
+        return int(raw)
+    except ValueError:
+        pass
+    try:
+        return int(float(raw))
+    except (ValueError, OverflowError):
+        # OverflowError: int(float("inf")); ValueError: garbage text or NaN.
+        return None
+
+
+def _bool_flag(value: object) -> bool:
+    """Interpret a CSV cell as a boolean flag.
+
+    ``True`` for "1", "true" or "yes" in any letter case, ignoring surrounding
+    whitespace. Everything else, including ``None`` and blank text, is ``False``.
+    """
+    # Case-insensitive on purpose: the previous whitelist accepted
+    # "true"/"True"/"TRUE" but only the lowercase spelling of "yes".
+    return str(value or "").strip().lower() in {"1", "true", "yes"}
+
+
+def _time_seconds(value: str | None) -> Optional[int]:
+    """Convert an ``H:MM[:SS]`` clock string into seconds.
+
+    Hours have no upper bound. A two-field value is read as hours and minutes
+    with zero seconds. Returns ``None`` for empty input, a wrong field count,
+    non-integer fields, negative hours, or minutes/seconds outside 0-59.
+    """
+    if not value:
+        return None
+    fields = value.strip().split(":")
+    if len(fields) == 2:
+        fields = [*fields, "0"]
+    if len(fields) != 3:
+        return None
+    try:
+        hh, mm, ss = (int(field) for field in fields)
+    except ValueError:
+        return None
+    if hh < 0 or not 0 <= mm <= 59 or not 0 <= ss <= 59:
+        return None
+    return hh * 3600 + mm * 60 + ss
+
+
+def _gtfs_mode(route_type: Optional[int]) -> str:
+    """Translate a GTFS ``route_type`` code into a mode name.
+
+    Exact matches in ``GTFS_MODE`` win; otherwise the first inclusive range in
+    ``GTFS_EXTENDED_MODE_RANGES`` containing the code is used. ``None`` and
+    unmatched codes give ``"unknown"``.
+    """
+    if route_type is None:
+        return "unknown"
+    if route_type in GTFS_MODE:
+        return GTFS_MODE[route_type]
+    range_match = (mode for low, high, mode in GTFS_EXTENDED_MODE_RANGES if low <= route_type <= high)
+    return next(range_match, "unknown")
+
+
+def _dataset_importer_version(dataset: Dataset) -> str:
+    """Return the ``"importer"`` value recorded in ``dataset.metadata_json``.
+
+    Returns ``""`` when the metadata is empty, is not valid JSON, is valid JSON
+    but not an object, or has no truthy ``"importer"`` entry.
+    """
+    try:
+        metadata = json.loads(dataset.metadata_json or "{}")
+    except json.JSONDecodeError:
+        return ""
+    # "[]", "null" or a bare scalar decode fine but have no .get(); treat them
+    # as "no version recorded" instead of raising AttributeError.
+    if not isinstance(metadata, dict):
+        return ""
+    return str(metadata.get("importer") or "")
diff --git a/app/pipeline/matcher.py b/app/pipeline/matcher.py
new file mode 100644
index 0000000..b9d1fe3
--- /dev/null
+++ b/app/pipeline/matcher.py
@@ -0,0 +1,995 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from datetime import datetime, timezone
+import json
+from typing import Callable, Optional
+
+from shapely.geometry import LineString, MultiLineString, Point, shape
+from sqlalchemy import delete, select
+from sqlalchemy.orm import Session
+
+from app.config import settings
+from app.models import Dataset, GtfsRoute, MatchRule, OsmFeature, RouteMatch
+from app.osm_storage import ensure_main_osm_feature, osm_feature_bbox, query_osm_features
+from app.pipeline.state import STAGE_MATCH_ROUTES, dependency_hash, finish_pipeline_run, start_pipeline_run
+from app.pipeline.utils import approx_bbox_center_distance_deg, bbox_overlap, norm_ref, norm_text
+
+# Mode name -> set of mode names treated as interchangeable with it.
+# _mode_compatible looks a (GTFS, OSM) pair up in both directions.
+MODE_GROUPS = {
+ "train": {"train", "rail", "railway"},
+ "subway": {"subway", "metro"},
+ "tram": {"tram", "light_rail"},
+ "light_rail": {"light_rail", "tram"},
+ "bus": {"bus", "coach", "trolleybus"},
+ "coach": {"coach", "bus"},
+ "trolleybus": {"trolleybus", "bus"},
+ "ferry": {"ferry"},
+ "funicular": {"funicular"},
+ "aerialway": {"aerialway", "cable_car"},
+ "monorail": {"monorail"},
+}
+# Caps on how many OSM candidates the candidate selectors keep per route.
+MAX_FALLBACK_CANDIDATES_WITH_REF = 40
+MAX_FALLBACK_CANDIDATES_WITHOUT_REF = 80
+MAX_EXACT_REF_CANDIDATES = 120
+# Bbox-centre distance below which a route counts as "near" (route_match_scope, _spatial_rank).
+OSM_SCOPE_NEAR_DISTANCE_DEG = 0.15
+# Maximum distance at which a sampled point counts as lying on the other geometry.
+GEOMETRY_PROXIMITY_DEG = 0.0035
+# Number of points sampled along a geometry when a profile is built.
+GEOMETRY_SAMPLE_POINTS = 24
+# Used as the pipeline-run version and embedded in the dependency payload.
+MATCHER_VERSION = "matcher_v4_scope_spatial_manual_rules"
+# Called as (event_type, message, progress_current, progress_total, metadata).
+ProgressCallback = Callable[[str, str, int | None, int | None, dict[str, object] | None], None]
+
+
+@dataclass(frozen=True)
+class _ManualMatchRule:
+ """Parsed, immutable form of one accept/reject MatchRule row (built by _manual_match_rules)."""
+ # Primary key of the originating MatchRule row.
+ id: int
+ # "accept_match" or "reject_match".
+ rule_type: str
+ # Selector that decides which GTFS routes the rule applies to.
+ route_selector: dict[str, object]
+ # Selector for the OSM feature the rule refers to; None when the rule carries none.
+ osm_selector: dict[str, object] | None
+ # Status from the rule action, defaulting to "accepted"/"rejected" by rule type.
+ status: str
+
+
+@dataclass(frozen=True)
+class _OsmRouteIndex:
+ """In-memory lookup tables over a list of OSM route features (built by _build_osm_route_index)."""
+ # The full input list, in its original order.
+ all_routes: list[OsmFeature]
+ # Features grouped by norm_ref(feature.ref); features without a ref are absent.
+ by_ref: dict[str, list[OsmFeature]]
+ # Features grouped by feature.route_key.
+ by_route_key: dict[str, list[OsmFeature]]
+ # Features grouped by feature.mode.
+ by_mode: dict[str, list[OsmFeature]]
+
+
+@dataclass(frozen=True)
+class _GeometryProfile:
+ """Pre-computed geometry data for one route or feature (built by _geometry_profile)."""
+ # Geometry object returned by shapely's shape() for the stored GeoJSON.
+ geom: object
+ # Line parts extracted from geom by _iter_lines.
+ lines: list[LineString]
+ # Sum of the lengths of ``lines``, in the geometry's own coordinate units.
+ length: float
+ # Points sampled along ``lines`` by _sample_line_points.
+ sample_points: list[Point]
+
+
+@dataclass(frozen=True)
+class _RouteMatchPayload:
+ """Desired RouteMatch column values for one GTFS route (consumed by _apply_route_match_payloads)."""
+ # Primary key of the GtfsRoute row the match belongs to.
+ gtfs_route_id: int
+ # Primary key of the matched OsmFeature row; None when the route is "missing".
+ osm_feature_id: int | None
+ confidence: float
+ status: str
+ # "manual" for rule-driven matches, "auto" for scored ones.
+ rule_source: str
+ # Compact JSON text explaining the decision.
+ reasons_json: str | None
+
+
+def run_route_matching(
+ session: Session,
+ *,
+ progress_callback: ProgressCallback | None = None,
+ batch_size: int | None = None,
+) -> dict[str, object]:
+ """Match active GTFS routes against active OSM route features.
+
+ Routes are processed in batches; every batch is committed and reported
+ through ``progress_callback`` before the next one starts.
+
+ Args:
+ session: SQLAlchemy session used for all reads, writes and commits.
+ progress_callback: Optional sink for progress events.
+ batch_size: Routes per batch; falls back to ``settings.route_matching_batch_size``
+ and is clamped to at least 1.
+
+ Returns:
+ Counters keyed ``routes``, ``matches``, ``missing``, ``manual``, ``created``,
+ ``updated``, ``unchanged`` plus a ``scope`` breakdown. When there are no active
+ datasets, no active GTFS datasets or no routes, only zeroed ``routes``,
+ ``matches`` and ``missing`` are returned and no pipeline run is started.
+ """
+ active_datasets = session.execute(
+ select(Dataset.id, Dataset.kind, Dataset.source_id).where(Dataset.is_active.is_(True))
+ ).all()
+ if not active_datasets:
+ return {"routes": 0, "matches": 0, "missing": 0}
+ dataset_source_ids = {int(dataset_id): int(source_id) for dataset_id, _, source_id in active_datasets}
+ gtfs_dataset_ids = [int(dataset_id) for dataset_id, kind, _ in active_datasets if kind == "gtfs"]
+ osm_dataset_ids = [int(dataset_id) for dataset_id, kind, _ in active_datasets if kind == "osm_geojson"]
+ if not gtfs_dataset_ids:
+ return {"routes": 0, "matches": 0, "missing": 0}
+
+ # Only primary keys are loaded up front; full rows are fetched per batch below.
+ route_row_ids = session.scalars(
+ select(GtfsRoute.id)
+ .where(GtfsRoute.dataset_id.in_(gtfs_dataset_ids))
+ .order_by(GtfsRoute.dataset_id, GtfsRoute.route_id, GtfsRoute.id)
+ ).all()
+ # Reconcile current match rows from auto scoring plus durable manual rules.
+ total_routes = len(route_row_ids)
+ if total_routes == 0:
+ return {"routes": 0, "matches": 0, "missing": 0}
+
+ # Record the run together with a hash of everything that influences the result.
+ dependency = _route_matching_dependency(session, active_datasets)
+ run = start_pipeline_run(
+ session,
+ stage=STAGE_MATCH_ROUTES,
+ version=MATCHER_VERSION,
+ dependency_hash_value=dependency_hash(dependency),
+ inputs=dependency,
+ )
+ session.commit()
+ effective_batch_size = max(1, int(batch_size or settings.route_matching_batch_size))
+ _emit_progress(
+ progress_callback,
+ "route_matching_started",
+ f"Matching {total_routes} GTFS routes in batches of {effective_batch_size}.",
+ 0,
+ total_routes,
+ {"gtfs_datasets": gtfs_dataset_ids, "osm_datasets": osm_dataset_ids, "batch_size": effective_batch_size},
+ )
+ # Manual rules and the overall OSM route bbox are loaded once and shared by all batches.
+ manual_rules = _manual_match_rules(session)
+ osm_scope_bbox = osm_feature_bbox(session, osm_dataset_ids, kinds=["route"])
+ counts = {"routes": total_routes, "matches": 0, "missing": 0, "manual": 0, "created": 0, "updated": 0, "unchanged": 0}
+ # Updated in place by _match_route_batch.
+ scoped_counts = {"in_osm_scope": 0, "near_osm_scope": 0, "outside_osm_scope": 0, "unknown_scope": 0}
+ processed = 0
+ for chunk in _chunks_int(route_row_ids, effective_batch_size):
+ routes = session.scalars(
+ select(GtfsRoute)
+ .where(GtfsRoute.id.in_(chunk))
+ .order_by(GtfsRoute.dataset_id, GtfsRoute.route_id, GtfsRoute.id)
+ ).all()
+ batch_counts = _match_route_batch(
+ session=session,
+ routes=routes,
+ osm_dataset_ids=osm_dataset_ids,
+ dataset_source_ids=dataset_source_ids,
+ manual_rules=manual_rules,
+ osm_scope_bbox=osm_scope_bbox,
+ scoped_counts=scoped_counts,
+ )
+ counts["matches"] += batch_counts["matches"]
+ counts["missing"] += batch_counts["missing"]
+ counts["manual"] += batch_counts["manual"]
+ counts["created"] += batch_counts["created"]
+ counts["updated"] += batch_counts["updated"]
+ counts["unchanged"] += batch_counts["unchanged"]
+ processed += len(routes)
+ # Commit per batch so finished work is persisted before the next batch starts.
+ session.commit()
+ _emit_progress(
+ progress_callback,
+ "route_matching_batch",
+ f"Matched {processed}/{total_routes} GTFS routes.",
+ processed,
+ total_routes,
+ {
+ "processed": processed,
+ "matches": counts["matches"],
+ "missing": counts["missing"],
+ "manual": counts["manual"],
+ "created": counts["created"],
+ "updated": counts["updated"],
+ "unchanged": counts["unchanged"],
+ # Copied so later batches do not mutate the emitted metadata.
+ "scope": dict(scoped_counts),
+ },
+ )
+ result = {**counts, "scope": scoped_counts}
+ finish_pipeline_run(session, run, outputs=result)
+ session.commit()
+ _emit_progress(
+ progress_callback,
+ "route_matching_completed",
+ "Route matching completed.",
+ total_routes,
+ total_routes,
+ result,
+ )
+ return result
+
+
+def _route_matching_dependency(session: Session, active_datasets) -> dict[str, object]:
+    """Describe every input that influences route matching.
+
+    The payload holds the matcher version, one entry per active dataset
+    (id, kind, source id, sha256) and one entry per MatchRule row ordered by
+    id, including inactive rules.
+    """
+    dataset_entries: list[dict[str, object]] = []
+    for dataset_id, kind, source_id in active_datasets:
+        dataset_entries.append(
+            {
+                "id": int(dataset_id),
+                "kind": str(kind),
+                "source_id": int(source_id),
+                "sha256": _dataset_sha(session, int(dataset_id)),
+            }
+        )
+
+    rule_entries: list[dict[str, object]] = []
+    for rule in session.scalars(select(MatchRule).order_by(MatchRule.id)).all():
+        rule_entries.append(
+            {
+                "id": int(rule.id),
+                "type": rule.rule_type,
+                "active": bool(rule.active),
+                "selector": rule.selector_json,
+                "action": rule.action_json,
+            }
+        )
+
+    return {"version": MATCHER_VERSION, "active_datasets": dataset_entries, "manual_rules": rule_entries}
+
+
+def _dataset_sha(session: Session, dataset_id: int) -> str | None:
+    """Return the stored ``sha256`` of a dataset, or ``None`` if the row does not exist."""
+    record = session.get(Dataset, dataset_id)
+    if record is None:
+        return None
+    return record.sha256
+
+
+def _match_route_batch(
+ *,
+ session: Session,
+ routes: list[GtfsRoute],
+ osm_dataset_ids: list[int],
+ dataset_source_ids: dict[int, int],
+ manual_rules: list[_ManualMatchRule],
+ osm_scope_bbox: tuple[float | None, float | None, float | None, float | None],
+ scoped_counts: dict[str, int],
+) -> dict[str, int]:
+ """Decide one match payload per route and reconcile them with stored RouteMatch rows.
+
+ Decision order per route: an accepted manual rule whose OSM feature can be
+ resolved, then the "outside OSM scope" shortcut, then automatic scoring of
+ candidate features. ``scoped_counts`` is updated in place. Returns the
+ ``matches``/``missing``/``manual`` counters merged with the
+ ``created``/``updated``/``unchanged`` counters from _apply_route_match_payloads.
+ """
+ matches = 0
+ missing = 0
+ manual = 0
+ payloads: list[_RouteMatchPayload] = []
+ for route in routes:
+ scope = route_match_scope(route, osm_scope_bbox)
+ scoped_counts[scope] = scoped_counts.get(scope, 0) + 1
+ route_source_id = dataset_source_ids.get(route.dataset_id)
+ # 1) Manual accept rules take precedence over scoring and over the scope check.
+ accepted_rule = _accepted_rule_for_route(manual_rules, route, route_source_id)
+ if accepted_rule is not None:
+ accepted_feature = _feature_for_rule_from_storage(session, osm_dataset_ids, dataset_source_ids, accepted_rule)
+ # If the rule's feature cannot be resolved, fall through to automatic handling.
+ if accepted_feature is not None:
+ accepted_feature = ensure_main_osm_feature(session, accepted_feature)
+ payloads.append(
+ _RouteMatchPayload(
+ gtfs_route_id=route.id,
+ osm_feature_id=accepted_feature.id,
+ confidence=100.0,
+ status="accepted",
+ rule_source="manual",
+ reasons_json=json.dumps(
+ {"manual_rule_id": accepted_rule.id, "manual": "accepted_match", "scope": scope},
+ separators=(",", ":"),
+ ),
+ )
+ )
+ matches += 1
+ manual += 1
+ continue
+
+ # 2) Routes outside the loaded OSM extent are recorded as missing without scoring.
+ if scope == "outside_osm_scope":
+ missing += 1
+ payloads.append(
+ _RouteMatchPayload(
+ gtfs_route_id=route.id,
+ osm_feature_id=None,
+ confidence=0.0,
+ status="missing",
+ rule_source="auto",
+ reasons_json=json.dumps(
+ {
+ "reason": "outside loaded OSM route scope",
+ "scope": scope,
+ },
+ separators=(",", ":"),
+ ),
+ )
+ )
+ continue
+
+ # 3) Automatic scoring: keep the candidate with the strictly highest score.
+ best_feature: Optional[OsmFeature] = None
+ best_score = 0.0
+ best_reasons: dict[str, object] = {}
+ # The route's geometry profile is built once and reused for every candidate.
+ route_geometry_profile = _geometry_profile(route.geometry_geojson)
+ for feature in candidate_osm_routes_for_route(session, route, osm_dataset_ids):
+ # Pairs rejected by a manual rule are never scored.
+ if _is_rejected_pair(manual_rules, route, route_source_id, feature, dataset_source_ids.get(feature.dataset_id)):
+ continue
+ feature_geometry_profile = _geometry_profile(feature.geometry_geojson)
+ score, reasons = score_route_pair(
+ route,
+ feature,
+ route_geometry_profile=route_geometry_profile,
+ feature_geometry_profile=feature_geometry_profile,
+ )
+ if score > best_score:
+ best_score = score
+ best_feature = feature
+ best_reasons = reasons
+ status = _status_from_score(best_score)
+ if best_feature is None or status == "missing":
+ missing += 1
+ best_feature_id = None
+ # Keep the losing candidate's score and reasons for diagnostics, then zero the confidence.
+ best_reasons = {
+ "reason": "no OSM candidate above threshold",
+ "scope": scope,
+ "best_score_below_threshold": round(float(best_score), 2) if best_score else 0,
+ "best_reasons": best_reasons,
+ }
+ best_score = 0
+ else:
+ matches += 1
+ best_feature = ensure_main_osm_feature(session, best_feature)
+ best_feature_id = best_feature.id
+ best_reasons["scope"] = scope
+ payloads.append(
+ _RouteMatchPayload(
+ gtfs_route_id=route.id,
+ osm_feature_id=best_feature_id,
+ confidence=round(float(best_score), 2),
+ status=status,
+ rule_source="auto",
+ reasons_json=json.dumps(best_reasons, separators=(",", ":")),
+ )
+ )
+ changes = _apply_route_match_payloads(session, payloads)
+ # Flush only; committing is left to the caller.
+ session.flush()
+ return {"matches": matches, "missing": missing, "manual": manual, **changes}
+
+
+def _apply_route_match_payloads(session: Session, payloads: list[_RouteMatchPayload]) -> dict[str, int]:
+ """Insert or update one RouteMatch row per payload and delete surplus rows.
+
+ For each route the preferred existing row (see _preferred_existing_match) is
+ kept and updated when it differs from the payload; any other rows for the
+ same route are deleted. Returns ``created``/``updated``/``unchanged`` counts.
+ """
+ if not payloads:
+ return {"created": 0, "updated": 0, "unchanged": 0}
+ route_ids = [payload.gtfs_route_id for payload in payloads]
+ # Load all existing rows for the batch in one query and group them by route.
+ existing_rows = session.scalars(
+ select(RouteMatch).where(RouteMatch.gtfs_route_id.in_(route_ids)).order_by(RouteMatch.gtfs_route_id, RouteMatch.id)
+ ).all()
+ existing_by_route: dict[int, list[RouteMatch]] = {}
+ for row in existing_rows:
+ existing_by_route.setdefault(row.gtfs_route_id, []).append(row)
+
+ created = 0
+ updated = 0
+ unchanged = 0
+ duplicate_ids: list[int] = []
+ # One timestamp is shared by every row updated in this call.
+ now = datetime.now(timezone.utc)
+ for payload in payloads:
+ existing = existing_by_route.get(payload.gtfs_route_id, [])
+ current = _preferred_existing_match(existing)
+ if current is None:
+ session.add(
+ RouteMatch(
+ gtfs_route_id=payload.gtfs_route_id,
+ osm_feature_id=payload.osm_feature_id,
+ confidence=payload.confidence,
+ status=payload.status,
+ rule_source=payload.rule_source,
+ reasons_json=payload.reasons_json,
+ )
+ )
+ created += 1
+ continue
+
+ # Surplus rows are queued for deletion even when the kept row is unchanged.
+ duplicate_ids.extend(row.id for row in existing if row.id != current.id)
+ if _route_match_payload_equal(current, payload):
+ unchanged += 1
+ continue
+ current.osm_feature_id = payload.osm_feature_id
+ current.confidence = payload.confidence
+ current.status = payload.status
+ current.rule_source = payload.rule_source
+ current.reasons_json = payload.reasons_json
+ current.updated_at = now
+ updated += 1
+
+ # Delete in chunks of 1000 ids to keep each IN (...) list bounded.
+ for chunk in _chunks_int(duplicate_ids, 1000):
+ session.execute(delete(RouteMatch).where(RouteMatch.id.in_(chunk)))
+ return {"created": created, "updated": updated, "unchanged": unchanged}
+
+
+def _preferred_existing_match(rows: list[RouteMatch]) -> RouteMatch | None:
+    """Choose the row to keep: the first manual row if any, otherwise the first row.
+
+    Returns ``None`` for an empty list.
+    """
+    if not rows:
+        return None
+    for candidate in rows:
+        if candidate.rule_source == "manual":
+            return candidate
+    return rows[0]
+
+
+def _route_match_payload_equal(row: RouteMatch, payload: _RouteMatchPayload) -> bool:
+    """Tell whether a stored row already carries the values of ``payload``.
+
+    Confidence is compared after rounding to two decimals, and an empty
+    ``reasons_json`` is treated the same as ``None``.
+    """
+    if row.osm_feature_id != payload.osm_feature_id:
+        return False
+    stored_confidence = round(float(row.confidence or 0), 2)
+    wanted_confidence = round(float(payload.confidence or 0), 2)
+    if stored_confidence != wanted_confidence:
+        return False
+    if row.status != payload.status:
+        return False
+    if row.rule_source != payload.rule_source:
+        return False
+    return (row.reasons_json or None) == (payload.reasons_json or None)
+
+
+def _build_osm_route_index(osm_routes: list[OsmFeature]) -> _OsmRouteIndex:
+    """Group OSM route features by normalized ref, by route key and by mode.
+
+    Features lacking a value for one of the keys are simply left out of that
+    particular table; ``all_routes`` always holds the full input list.
+    """
+    ref_buckets: dict[str, list[OsmFeature]] = {}
+    key_buckets: dict[str, list[OsmFeature]] = {}
+    mode_buckets: dict[str, list[OsmFeature]] = {}
+    for osm_route in osm_routes:
+        normalized_ref = norm_ref(osm_route.ref or "")
+        if normalized_ref:
+            ref_buckets.setdefault(normalized_ref, []).append(osm_route)
+        if osm_route.route_key:
+            key_buckets.setdefault(osm_route.route_key, []).append(osm_route)
+        if osm_route.mode:
+            mode_buckets.setdefault(osm_route.mode, []).append(osm_route)
+    return _OsmRouteIndex(
+        all_routes=osm_routes,
+        by_ref=ref_buckets,
+        by_route_key=key_buckets,
+        by_mode=mode_buckets,
+    )
+
+
+def _candidate_osm_routes(route: GtfsRoute, index: _OsmRouteIndex) -> list[OsmFeature]:
+ """Select OSM candidates for ``route`` from an in-memory index.
+
+ Features sharing the route's normalized ref or route key are preferred. Only
+ when none exist does the function fall back to mode-compatible features that
+ overlap or lie near the route's bounding box.
+ """
+ selected: list[OsmFeature] = []
+ seen: set[int] = set()
+
+ # Appends unseen features, optionally dropping mode-incompatible ones.
+ def add(features: list[OsmFeature], *, require_compatible_mode: bool = True) -> None:
+ for feature in features:
+ if feature.id in seen:
+ continue
+ if require_compatible_mode and not _mode_compatible(route.mode or "", feature.mode or ""):
+ continue
+ seen.add(feature.id)
+ selected.append(feature)
+
+ # The short name is the preferred ref; route_id is the fallback.
+ route_ref = norm_ref(route.short_name or route.route_id)
+ if route_ref:
+ add(index.by_ref.get(route_ref, []))
+ if route.route_key:
+ add(index.by_route_key.get(route.route_key, []))
+ if selected:
+ return _spatially_ranked_candidates(route, selected, MAX_EXACT_REF_CANDIDATES)
+
+ # Fallback: gather features of every compatible mode, or all routes if none are indexed.
+ compatible_modes = MODE_GROUPS.get(route.mode or "", {route.mode or ""})
+ mode_candidates: list[OsmFeature] = []
+ for mode in compatible_modes:
+ if mode:
+ mode_candidates.extend(index.by_mode.get(mode, []))
+ if not mode_candidates:
+ mode_candidates = index.all_routes
+
+ gtfs_bbox = (route.min_lon, route.min_lat, route.max_lon, route.max_lat)
+ near_candidates: list[tuple[float, OsmFeature]] = []
+ for feature in mode_candidates:
+ osm_bbox = (feature.min_lon, feature.min_lat, feature.max_lon, feature.max_lat)
+ distance = approx_bbox_center_distance_deg(gtfs_bbox, osm_bbox)
+ # Overlapping boxes sort first (distance 0); otherwise keep only centres closer than 0.12.
+ if bbox_overlap(gtfs_bbox, osm_bbox):
+ near_candidates.append((0.0, feature))
+ elif distance is not None and distance < 0.12:
+ near_candidates.append((distance, feature))
+ fallback_limit = MAX_FALLBACK_CANDIDATES_WITH_REF if route_ref else MAX_FALLBACK_CANDIDATES_WITHOUT_REF
+ fallback = [feature for _, feature in sorted(near_candidates, key=lambda item: item[0])[:fallback_limit]]
+ # Nothing nearby: take the first mode candidates in index order.
+ if not fallback:
+ fallback = mode_candidates[:fallback_limit]
+ add(fallback)
+ return _spatially_ranked_candidates(route, selected, fallback_limit)
+
+
+def candidate_osm_routes_for_route(session: Session, route: GtfsRoute, osm_dataset_ids: list[int]) -> list[OsmFeature]:
+ """Query storage for OSM route features worth scoring against ``route``.
+
+ Lookup order: features whose route key equals the route's key or normalized
+ ref; otherwise mode-compatible features inside the route's padded bounding
+ box; otherwise any mode-compatible features up to a fixed limit. The result
+ is spatially ranked and truncated. Returns ``[]`` when no OSM datasets are given.
+ """
+ if not osm_dataset_ids:
+ return []
+ selected: list[OsmFeature] = []
+ # De-duplicate on (dataset, OSM type, OSM id) rather than on the row id.
+ seen: set[tuple[int, str, str]] = set()
+
+ # Appends unseen features, optionally dropping mode-incompatible ones.
+ def add(features: list[OsmFeature], *, require_compatible_mode: bool = True) -> None:
+ for feature in features:
+ key = (feature.dataset_id, feature.osm_type, feature.osm_id)
+ if key in seen:
+ continue
+ if require_compatible_mode and not _mode_compatible(route.mode or "", feature.mode or ""):
+ continue
+ seen.add(key)
+ selected.append(feature)
+
+ route_ref = norm_ref(route.short_name or route.route_id)
+ route_keys = [key for key in [route.route_key, route_ref] if key]
+ # dict.fromkeys drops a duplicate key while keeping order, so each key is queried once.
+ for route_key in dict.fromkeys(route_keys):
+ add(
+ query_osm_features(
+ session,
+ osm_dataset_ids,
+ kinds=["route"],
+ route_key=route_key,
+ )
+ )
+ if selected:
+ return _spatially_ranked_candidates(route, selected, MAX_EXACT_REF_CANDIDATES)
+
+ gtfs_bbox = (route.min_lon, route.min_lat, route.max_lon, route.max_lat)
+ compatible_modes = sorted(MODE_GROUPS.get(route.mode or "", {route.mode or ""}) - {""})
+ # Spatial fallback needs a complete route bbox; the box is padded by 0.10 on every side.
+ if not any(value is None for value in gtfs_bbox):
+ bbox = _expanded_bbox(gtfs_bbox, 0.10)
+ add(
+ query_osm_features(
+ session,
+ osm_dataset_ids,
+ kinds=["route"],
+ modes=compatible_modes or None,
+ bbox=bbox,
+ limit=MAX_FALLBACK_CANDIDATES_WITHOUT_REF * 4,
+ ),
+ # The query already filters by mode, so the Python-side check is skipped.
+ require_compatible_mode=False,
+ )
+ # Last resort: mode-filtered features without any spatial restriction.
+ if not selected:
+ add(
+ query_osm_features(
+ session,
+ osm_dataset_ids,
+ kinds=["route"],
+ modes=compatible_modes or None,
+ limit=MAX_FALLBACK_CANDIDATES_WITHOUT_REF,
+ ),
+ require_compatible_mode=False,
+ )
+ fallback_limit = MAX_FALLBACK_CANDIDATES_WITH_REF if route_ref else MAX_FALLBACK_CANDIDATES_WITHOUT_REF
+ return _spatially_ranked_candidates(route, selected, fallback_limit)
+
+
+def score_route_pair(
+ route: GtfsRoute,
+ feature: OsmFeature,
+ route_geometry_profile: _GeometryProfile | None = None,
+ feature_geometry_profile: _GeometryProfile | None = None,
+) -> tuple[float, dict[str, object]]:
+ """Score how well an OSM route feature matches a GTFS route.
+
+ Additive components (mode, ref, name, operator, bounding box, geometry,
+ route key) are summed first. "Strong identity" floors and a spatial cap are
+ then applied in a fixed order, and the result is limited to at most 100.
+
+ Args:
+ route: GTFS route being matched.
+ feature: OSM route feature under consideration.
+ route_geometry_profile: Optional pre-built profile of the route geometry.
+ feature_geometry_profile: Optional pre-built profile of the feature geometry.
+ When either profile is missing, geometry metrics are computed from the
+ stored GeoJSON instead.
+
+ Returns:
+ ``(score, reasons)`` where ``reasons`` records each contributing factor.
+ A known mode mismatch short-circuits to ``(0.0, reasons)``.
+ """
+ score = 0.0
+ reasons: dict[str, object] = {}
+
+ # Mode: +25 when compatible (or unknown on either side); a definite mismatch rejects the pair.
+ gtfs_mode = route.mode or ""
+ osm_mode = feature.mode or ""
+ if _mode_compatible(gtfs_mode, osm_mode):
+ score += 25
+ reasons["mode"] = "compatible"
+ elif gtfs_mode and osm_mode:
+ reasons["mode"] = f"mismatch: {gtfs_mode} != {osm_mode}"
+ return 0.0, reasons
+
+ # Ref: +25 for an exact normalized match, +15 when one ref contains the other.
+ gtfs_ref = norm_ref(route.short_name or route.route_id)
+ osm_ref = norm_ref(feature.ref or "")
+ if gtfs_ref and osm_ref:
+ if gtfs_ref == osm_ref:
+ score += 25
+ reasons["ref"] = "exact"
+ elif gtfs_ref in osm_ref or osm_ref in gtfs_ref:
+ score += 15
+ reasons["ref"] = "partial"
+
+ # Name: up to +20, scaled by the similarity of the combined name strings.
+ gtfs_name = norm_text(" ".join(v for v in [route.long_name, route.short_name, route.route_id] if v))
+ osm_name = norm_text(" ".join(v for v in [feature.name, feature.ref] if v))
+ name_similarity = _ratio(gtfs_name, osm_name)
+ score += 20 * name_similarity
+ reasons["name_similarity"] = round(name_similarity, 3)
+
+ # Operator: up to +15; contributes nothing unless both sides have operator text.
+ gtfs_operator = norm_text(route.operator_name or "")
+ osm_operator = norm_text(" ".join(v for v in [feature.operator, feature.network] if v))
+ operator_similarity = _ratio(gtfs_operator, osm_operator) if gtfs_operator and osm_operator else 0
+ score += 15 * operator_similarity
+ reasons["operator_similarity"] = round(operator_similarity, 3)
+
+ # Bounding boxes: +14 for overlap (+8 more with exact ref); otherwise a graded
+ # bonus by centre distance, or a -8 penalty for an exact ref that is far away.
+ gtfs_bbox = (route.min_lon, route.min_lat, route.max_lon, route.max_lat)
+ osm_bbox = (feature.min_lon, feature.min_lat, feature.max_lon, feature.max_lat)
+ center_distance = None
+ if bbox_overlap(gtfs_bbox, osm_bbox):
+ score += 14
+ reasons["bbox"] = "overlap"
+ if gtfs_ref and osm_ref and gtfs_ref == osm_ref and _mode_compatible(gtfs_mode, osm_mode):
+ score += 8
+ reasons["line_identity"] = "exact_ref_mode_bbox_overlap"
+ else:
+ center_distance = approx_bbox_center_distance_deg(gtfs_bbox, osm_bbox)
+ if center_distance is not None:
+ if center_distance < 0.01:
+ score += 12
+ elif center_distance < 0.03:
+ score += 8
+ elif center_distance < 0.08:
+ score += 4
+ elif gtfs_ref and osm_ref and gtfs_ref == osm_ref and center_distance > OSM_SCOPE_NEAR_DISTANCE_DEG:
+ score -= 8
+ reasons["spatial_penalty"] = "exact_ref_far_bbox_center"
+ reasons["bbox_center_distance_deg"] = round(center_distance, 5)
+
+ # Geometry: reuse the supplied profiles when both are present, else parse the GeoJSON.
+ geometry_metrics = (
+ _geometry_match_metrics_from_profiles(route_geometry_profile, feature_geometry_profile)
+ if route_geometry_profile is not None and feature_geometry_profile is not None
+ else _geometry_match_metrics(route.geometry_geojson, feature.geometry_geojson)
+ )
+ if geometry_metrics is not None:
+ reasons["geometry"] = geometry_metrics
+ geometry_score = 34 * float(geometry_metrics["gtfs_on_osm_ratio"]) + 8 * float(geometry_metrics["osm_on_gtfs_ratio"])
+ if float(geometry_metrics["endpoint_distance_deg"]) < GEOMETRY_PROXIMITY_DEG * 2:
+ geometry_score += 6
+ if float(geometry_metrics["length_ratio"]) < 0.35 or float(geometry_metrics["length_ratio"]) > 2.8:
+ geometry_score -= 8
+ reasons["geometry_length"] = "implausible_ratio"
+ # The geometry contribution is clamped to the range 0..42.
+ score += max(0.0, min(42.0, geometry_score))
+
+ # Extra small boost for same normalized route key.
+ if route.route_key and feature.route_key and route.route_key == feature.route_key:
+ score += 5
+ reasons["route_key"] = "same"
+
+ # Floor: exact ref + compatible mode with overlapping (88) or very close (82) boxes.
+ if gtfs_ref and osm_ref and gtfs_ref == osm_ref and _mode_compatible(gtfs_mode, osm_mode):
+ if bbox_overlap(gtfs_bbox, osm_bbox):
+ score = max(score, 88.0)
+ reasons["strong_identity"] = "exact_ref_mode_bbox_overlap"
+ elif center_distance is not None and center_distance < 0.02:
+ score = max(score, 82.0)
+ reasons["strong_identity"] = "exact_ref_mode_near_bbox_center"
+
+ # Floor: same route key + compatible mode + overlapping boxes (86); keeps an earlier label.
+ if route.route_key and feature.route_key and route.route_key == feature.route_key and _mode_compatible(gtfs_mode, osm_mode):
+ if bbox_overlap(gtfs_bbox, osm_bbox):
+ score = max(score, 86.0)
+ reasons.setdefault("strong_identity", "same_route_key_mode_bbox_overlap")
+
+ # Floor: strong geometric agreement with an exact (90) or partial (82) ref match.
+ if geometry_metrics is not None:
+ gtfs_on_osm = float(geometry_metrics["gtfs_on_osm_ratio"])
+ endpoint_distance = float(geometry_metrics["endpoint_distance_deg"])
+ if gtfs_on_osm >= 0.82 and endpoint_distance < GEOMETRY_PROXIMITY_DEG * 3 and _mode_compatible(gtfs_mode, osm_mode):
+ if gtfs_ref and osm_ref and gtfs_ref == osm_ref:
+ score = max(score, 90.0)
+ reasons["strong_identity"] = "exact_ref_mode_geometry_overlap"
+ elif gtfs_ref and osm_ref and (gtfs_ref in osm_ref or osm_ref in gtfs_ref):
+ score = max(score, 82.0)
+ reasons["strong_identity"] = "partial_ref_mode_geometry_overlap"
+
+ # Cap: an exact ref that is far away with little or no geometric overlap is limited to 58.
+ if (
+ gtfs_ref
+ and osm_ref
+ and gtfs_ref == osm_ref
+ and center_distance is not None
+ and center_distance > OSM_SCOPE_NEAR_DISTANCE_DEG
+ and not bbox_overlap(gtfs_bbox, osm_bbox)
+ and (
+ geometry_metrics is None
+ or float(geometry_metrics.get("gtfs_on_osm_ratio", 0.0)) < 0.25
+ )
+ ):
+ score = min(score, 58.0)
+ reasons["spatial_cap"] = "exact_ref_far_without_geometry_overlap"
+
+ return min(score, 100.0), reasons
+
+
+def route_match_scope(route: GtfsRoute, osm_scope_bbox: tuple[float | None, float | None, float | None, float | None]) -> str:
+    """Classify a route relative to the bounding box of the loaded OSM routes.
+
+    Returns ``"unknown_scope"`` when either box is incomplete,
+    ``"in_osm_scope"`` when the boxes overlap, ``"near_osm_scope"`` when their
+    centres are closer than ``OSM_SCOPE_NEAR_DISTANCE_DEG`` and
+    ``"outside_osm_scope"`` otherwise.
+    """
+    route_bbox = (route.min_lon, route.min_lat, route.max_lon, route.max_lat)
+    for corner in (*route_bbox, *osm_scope_bbox):
+        if corner is None:
+            return "unknown_scope"
+    if bbox_overlap(route_bbox, osm_scope_bbox):
+        return "in_osm_scope"
+    gap = approx_bbox_center_distance_deg(route_bbox, osm_scope_bbox)
+    is_near = gap is not None and gap < OSM_SCOPE_NEAR_DISTANCE_DEG
+    return "near_osm_scope" if is_near else "outside_osm_scope"
+
+
+def _combined_bbox(features: list[OsmFeature]) -> tuple[float | None, float | None, float | None, float | None]:
+    """Compute the bounding box enclosing every feature that has a complete box.
+
+    Features with any missing coordinate are ignored; if none remain the
+    result is four ``None`` values.
+    """
+    complete_boxes = []
+    for feature in features:
+        box = (feature.min_lon, feature.min_lat, feature.max_lon, feature.max_lat)
+        if None not in box:
+            complete_boxes.append(box)
+    if not complete_boxes:
+        return (None, None, None, None)
+    west, south, east, north = zip(*complete_boxes)
+    return (
+        min(map(float, west)),
+        min(map(float, south)),
+        max(map(float, east)),
+        max(map(float, north)),
+    )
+
+
+def _spatially_ranked_candidates(route: GtfsRoute, candidates: list[OsmFeature], limit: int) -> list[OsmFeature]:
+    """Order candidates by ``_spatial_rank`` and keep the best ``limit`` (at least one)."""
+    keep = max(1, limit)
+    ordered = sorted(candidates, key=lambda feature: _spatial_rank(route, feature))
+    return ordered[:keep]
+
+
+def _spatial_rank(route: GtfsRoute, feature: OsmFeature) -> tuple[int, float, str]:
+    """Build the sort key used to rank an OSM feature against a route.
+
+    The key is ``(tier, distance, osm_id)``. Tier 0 means the bounding boxes
+    overlap, 1 means the centres are within ``OSM_SCOPE_NEAR_DISTANCE_DEG``,
+    2 means farther away and 3 means the distance is unknown (sorted with a
+    placeholder distance of 999.0).
+    """
+    route_bbox = (route.min_lon, route.min_lat, route.max_lon, route.max_lat)
+    feature_bbox = (feature.min_lon, feature.min_lat, feature.max_lon, feature.max_lat)
+    gap = approx_bbox_center_distance_deg(route_bbox, feature_bbox)
+    if bbox_overlap(route_bbox, feature_bbox):
+        tier = 0
+    elif gap is None:
+        tier = 3
+    elif gap < OSM_SCOPE_NEAR_DISTANCE_DEG:
+        tier = 1
+    else:
+        tier = 2
+    return (tier, 999.0 if gap is None else gap, feature.osm_id)
+
+
+def _expanded_bbox(
+    bbox: tuple[float | None, float | None, float | None, float | None],
+    padding: float,
+) -> tuple[float, float, float, float] | None:
+    """Grow ``bbox`` by ``padding`` on every side; ``None`` if any edge is missing."""
+    west, south, east, north = bbox
+    if None in (west, south, east, north):
+        return None
+    lower_left = (float(west) - padding, float(south) - padding)
+    upper_right = (float(east) + padding, float(north) + padding)
+    return (*lower_left, *upper_right)
+
+
+def _chunks_int(values: list[int], size: int) -> list[list[int]]:
+    """Split ``values`` into consecutive chunks of at most ``size`` items.
+
+    A ``size`` below 1 is treated as 1. An empty input yields an empty list.
+    """
+    # Clamp once and use the same value for both the stride and the slice
+    # width; previously only the stride was clamped, so size <= 0 produced
+    # empty or truncated chunks.
+    step = max(1, size)
+    return [values[start : start + step] for start in range(0, len(values), step)]
+
+
+def _emit_progress(
+    progress_callback: ProgressCallback | None,
+    event_type: str,
+    message: str,
+    progress_current: int | None,
+    progress_total: int | None,
+    metadata: dict[str, object] | None = None,
+) -> None:
+    """Forward one progress event to ``progress_callback``; do nothing when it is ``None``."""
+    if progress_callback is None:
+        return
+    progress_callback(event_type, message, progress_current, progress_total, metadata)
+
+
+def _geometry_match_metrics(route_geometry: str | None, feature_geometry: str | None) -> dict[str, float] | None:
+    """Build geometry profiles from two GeoJSON strings and compare them.
+
+    Returns ``None`` when either geometry cannot be profiled.
+    """
+    return _geometry_match_metrics_from_profiles(
+        _geometry_profile(route_geometry),
+        _geometry_profile(feature_geometry),
+    )
+
+
+def _geometry_profile(geometry_text: str | None) -> _GeometryProfile | None:
+    """Parse GeoJSON geometry text into a reusable profile.
+
+    Returns ``None`` when the text is empty or malformed, when the geometry
+    contains no usable lines, when its total length is zero, or when no
+    sample points could be produced.
+    """
+    if not geometry_text:
+        return None
+    try:
+        parsed = shape(json.loads(geometry_text))
+    except Exception:  # noqa: BLE001 - malformed geometry should not break matching
+        return None
+    segments = _iter_lines(parsed)
+    if not segments:
+        return None
+    total_length = sum(segment.length for segment in segments)
+    if total_length == 0:
+        return None
+    samples = _sample_line_points(segments, GEOMETRY_SAMPLE_POINTS)
+    if not samples:
+        return None
+    return _GeometryProfile(geom=parsed, lines=segments, length=total_length, sample_points=samples)
+
+
+def _geometry_match_metrics_from_profiles(
+    route_profile: _GeometryProfile | None, feature_profile: _GeometryProfile | None
+) -> dict[str, float] | None:
+    """Compare two geometry profiles and return rounded similarity metrics.
+
+    The result holds the share of route samples near the feature geometry,
+    the share of feature samples near the route geometry, the summed distance
+    of the route's end points to the feature geometry, and the route/feature
+    length ratio. Returns ``None`` when either profile is missing.
+    """
+    if route_profile is None or feature_profile is None:
+        return None
+    route_share = _near_point_ratio(route_profile.sample_points, feature_profile.geom, GEOMETRY_PROXIMITY_DEG)
+    feature_share = _near_point_ratio(feature_profile.sample_points, route_profile.geom, GEOMETRY_PROXIMITY_DEG)
+    end_gap = _endpoint_distance(route_profile.lines, feature_profile.geom)
+    if feature_profile.length:
+        relative_length = route_profile.length / feature_profile.length
+    else:
+        relative_length = 0.0
+    metrics = {
+        "gtfs_on_osm_ratio": round(route_share, 3),
+        "osm_on_gtfs_ratio": round(feature_share, 3),
+        "endpoint_distance_deg": round(end_gap, 6),
+        "length_ratio": round(relative_length, 3),
+    }
+    return metrics
+
+
+def _iter_lines(geom) -> list[LineString]:
+    """Return the line parts of a geometry.
+
+    A ``LineString`` is returned as a one-element list; a ``MultiLineString``
+    contributes its non-empty line members; any other geometry gives ``[]``.
+    """
+    if isinstance(geom, LineString):
+        return [geom]
+    if not isinstance(geom, MultiLineString):
+        return []
+    parts = []
+    for part in geom.geoms:
+        if isinstance(part, LineString) and part.length > 0:
+            parts.append(part)
+    return parts
+
+
+def _sample_line_points(lines: list[LineString], count: int) -> list[Point]:
+    """Sample ``count`` points evenly along the concatenated length of ``lines``.
+
+    The lines are treated as laid end to end in list order. Returns ``[]``
+    when their total length is zero.
+    """
+    combined_length = sum(segment.length for segment in lines)
+    if combined_length == 0:
+        return []
+    final_segment = lines[-1]
+    samples = []
+    for step in range(count):
+        target = combined_length * (step / max(1, count - 1))
+        consumed = 0.0
+        for segment in lines:
+            reach = consumed + segment.length
+            # The last line always absorbs the target so rounding error
+            # cannot leave a sample unplaced.
+            if target <= reach or segment is final_segment:
+                offset = max(0.0, min(segment.length, target - consumed))
+                samples.append(segment.interpolate(offset))
+                break
+            consumed = reach
+    return samples
+
+
+def _near_point_ratio(points: list[Point], geom, max_distance: float) -> float:
+    """Return the fraction of ``points`` lying within ``max_distance`` of ``geom``.
+
+    An empty point list gives ``0.0``.
+    """
+    if not points:
+        return 0.0
+    hits = 0
+    for point in points:
+        if geom.distance(point) <= max_distance:
+            hits += 1
+    return hits / len(points)
+
+
+def _endpoint_distance(gtfs_lines: list[LineString], osm_geom) -> float:
+    """Sum the distances from both ends of the longest GTFS line to ``osm_geom``.
+
+    Returns the placeholder ``999.0`` when the longest line has fewer than
+    two coordinates.
+    """
+    main_line = max(gtfs_lines, key=lambda line: line.length)
+    vertices = list(main_line.coords)
+    if len(vertices) < 2:
+        return 999.0
+    start_vertex, end_vertex = vertices[0], vertices[-1]
+    return osm_geom.distance(Point(start_vertex)) + osm_geom.distance(Point(end_vertex))
+
+
+def _manual_match_rules(session: Session) -> list[_ManualMatchRule]:
+    """Load active accept/reject match rules, newest id first, in parsed form.
+
+    Rules whose selector or action is not valid JSON, or whose selector is
+    valid JSON but not an object, are skipped so that one malformed row
+    cannot abort a whole matching run. A non-object action is treated as an
+    empty action.
+    """
+    rules = session.scalars(
+        select(MatchRule)
+        .where(MatchRule.active.is_(True), MatchRule.rule_type.in_(["accept_match", "reject_match"]))
+        .order_by(MatchRule.id.desc())
+    ).all()
+    parsed: list[_ManualMatchRule] = []
+    for rule in rules:
+        try:
+            selector = json.loads(rule.selector_json or "{}")
+            action = json.loads(rule.action_json or "{}")
+        except (json.JSONDecodeError, TypeError):
+            continue
+        if not isinstance(selector, dict):
+            # "null", "[]" and similar parse fine but cannot select anything.
+            continue
+        if not isinstance(action, dict):
+            action = {}
+        # A nested "gtfs" object is the route selector; otherwise the whole
+        # selector is used (flat layout).
+        nested_route_selector = selector.get("gtfs")
+        route_selector = nested_route_selector if isinstance(nested_route_selector, dict) else selector
+        # The OSM selector is taken from the action first, then the selector,
+        # then a bare "osm_feature_id" on the selector.
+        action_osm = action.get("osm")
+        osm_selector = action_osm if isinstance(action_osm, dict) else selector.get("osm")
+        if not isinstance(osm_selector, dict) and selector.get("osm_feature_id") is not None:
+            osm_selector = {"osm_feature_id": selector.get("osm_feature_id")}
+        status = str(action.get("status") or ("accepted" if rule.rule_type == "accept_match" else "rejected"))
+        parsed.append(
+            _ManualMatchRule(
+                id=rule.id,
+                rule_type=rule.rule_type,
+                route_selector=route_selector,
+                osm_selector=osm_selector if isinstance(osm_selector, dict) else None,
+                status=status,
+            )
+        )
+    return parsed
+
+
+def _accepted_rule_for_route(
+    rules: list[_ManualMatchRule], route: GtfsRoute, route_source_id: int | None
+) -> _ManualMatchRule | None:
+    """Return the first accept rule with status ``accepted`` whose selector matches ``route``.
+
+    Rules are checked in list order; ``None`` means no rule applies.
+    """
+    for rule in rules:
+        eligible = rule.rule_type == "accept_match" and rule.status == "accepted"
+        if eligible and _route_matches_selector(route, route_source_id, rule.route_selector):
+            return rule
+    return None
+
+
+def _feature_for_rule(
+    features: list[OsmFeature], dataset_source_ids: dict[int, int], rule: _ManualMatchRule
+) -> OsmFeature | None:
+    """Find the first in-memory feature matching the rule's OSM selector.
+
+    Returns ``None`` when the rule has no OSM selector or nothing matches.
+    """
+    selector = rule.osm_selector
+    if not selector:
+        return None
+    return next(
+        (
+            candidate
+            for candidate in features
+            if _feature_matches_selector(candidate, dataset_source_ids.get(candidate.dataset_id), selector)
+        ),
+        None,
+    )
+
+
+def _feature_for_rule_from_storage(
+ session: Session,
+ osm_dataset_ids: list[int],
+ dataset_source_ids: dict[int, int],
+ rule: _ManualMatchRule,
+) -> OsmFeature | None:
+ """Resolve the OSM feature referenced by a manual rule by querying storage.
+
+ Lookup order: the legacy ``osm_feature_id`` primary key, then OSM type + id,
+ then ``route_key``, then normalized ``ref``. The last three are restricted to
+ datasets allowed by the selector's ``source_id``/``dataset_id``. Every candidate
+ must also pass _feature_matches_selector. Returns ``None`` when the rule has no
+ OSM selector or nothing matches.
+ """
+ if not rule.osm_selector:
+ return None
+ selector = rule.osm_selector
+ # Fast path: direct primary-key lookup for selectors that carry a row id.
+ legacy_id = _safe_int(selector.get("osm_feature_id"))
+ if legacy_id is not None:
+ feature = session.get(OsmFeature, legacy_id)
+ if feature is not None and _feature_matches_selector(feature, dataset_source_ids.get(feature.dataset_id), selector):
+ return feature
+ # Narrow the searchable datasets by the selector's source and dataset constraints.
+ scoped_dataset_ids = list(osm_dataset_ids)
+ expected_source = selector.get("source_id")
+ if expected_source is not None:
+ expected_source_id = _safe_int(expected_source)
+ # A source_id that is not an integer is ignored rather than excluding everything.
+ if expected_source_id is not None:
+ scoped_dataset_ids = [
+ dataset_id
+ for dataset_id in scoped_dataset_ids
+ if dataset_source_ids.get(dataset_id) == expected_source_id
+ ]
+ dataset_id = _safe_int(selector.get("dataset_id"))
+ if dataset_id is not None:
+ scoped_dataset_ids = [value for value in scoped_dataset_ids if value == dataset_id]
+ if not scoped_dataset_ids:
+ return None
+
+ # Try progressively weaker identifiers until one query returns rows.
+ features: list[OsmFeature] = []
+ osm_type = selector.get("osm_type")
+ osm_id = selector.get("osm_id")
+ if osm_type and osm_id:
+ features = query_osm_features(
+ session,
+ scoped_dataset_ids,
+ kinds=["route"],
+ osm_type=str(osm_type),
+ osm_id=str(osm_id),
+ limit=10,
+ )
+ if not features:
+ route_key = selector.get("route_key")
+ if route_key:
+ features = query_osm_features(session, scoped_dataset_ids, kinds=["route"], route_key=str(route_key))
+ if not features:
+ ref = norm_ref(selector.get("ref"))
+ if ref:
+ # The normalized ref is passed as the route_key filter here.
+ features = query_osm_features(session, scoped_dataset_ids, kinds=["route"], route_key=ref)
+ for feature in features:
+ if _feature_matches_selector(feature, dataset_source_ids.get(feature.dataset_id), selector):
+ return feature
+ return None
+
+
+def _is_rejected_pair(
+    rules: list[_ManualMatchRule],
+    route: GtfsRoute,
+    route_source_id: int | None,
+    feature: OsmFeature,
+    feature_source_id: int | None,
+) -> bool:
+    """Tell whether any reject rule forbids pairing ``route`` with ``feature``.
+
+    A rule applies only if its route selector matches the route and it has an
+    OSM selector that matches the feature.
+    """
+    return any(
+        rule.rule_type == "reject_match"
+        and _route_matches_selector(route, route_source_id, rule.route_selector)
+        and bool(rule.osm_selector)
+        and _feature_matches_selector(feature, feature_source_id, rule.osm_selector)
+        for rule in rules
+    )
+
+
+def _route_matches_selector(route: GtfsRoute, source_id: int | None, selector: dict[str, object]) -> bool:
+    """Check whether a GTFS route satisfies a rule's route selector.
+
+    Checks run in this order: a ``gtfs_route_id`` equal to the row id matches
+    immediately; a conflicting ``source_id`` rejects; then ``route_id``,
+    ``route_key`` and finally normalized ``ref`` (optionally constrained by
+    ``mode``) are tried.
+    """
+    pinned_row_id = selector.get("gtfs_route_id")
+    if pinned_row_id is not None and _safe_int(pinned_row_id) == route.id:
+        return True
+
+    wanted_source = selector.get("source_id")
+    source_conflict = (
+        wanted_source is not None and source_id is not None and _safe_int(wanted_source) != source_id
+    )
+    if source_conflict:
+        return False
+
+    wanted_route_id = selector.get("route_id")
+    if wanted_route_id and str(wanted_route_id) == route.route_id:
+        return True
+
+    wanted_key = selector.get("route_key")
+    if wanted_key and route.route_key and str(wanted_key) == route.route_key:
+        return True
+
+    wanted_ref = norm_ref(selector.get("ref"))
+    wanted_mode = selector.get("mode")
+    if not wanted_ref or wanted_ref != norm_ref(route.short_name or route.route_id):
+        return False
+    return not wanted_mode or _mode_compatible(str(wanted_mode), route.mode or "")
+
+
+def _feature_matches_selector(feature: OsmFeature, source_id: int | None, selector: dict[str, object]) -> bool:
+    """Check whether an OSM feature satisfies a rule's OSM selector.
+
+    Checks run in this order: an ``osm_feature_id`` equal to the row id
+    matches immediately; a conflicting ``source_id`` rejects; then OSM
+    type + id, ``route_key`` and finally normalized ``ref`` (optionally
+    constrained by ``mode``) are tried.
+    """
+    pinned_row_id = selector.get("osm_feature_id")
+    if pinned_row_id is not None and _safe_int(pinned_row_id) == feature.id:
+        return True
+
+    wanted_source = selector.get("source_id")
+    source_conflict = (
+        wanted_source is not None and source_id is not None and _safe_int(wanted_source) != source_id
+    )
+    if source_conflict:
+        return False
+
+    wanted_type = selector.get("osm_type")
+    wanted_osm_id = selector.get("osm_id")
+    if wanted_type and wanted_osm_id and str(wanted_type) == feature.osm_type and str(wanted_osm_id) == feature.osm_id:
+        return True
+
+    wanted_key = selector.get("route_key")
+    if wanted_key and feature.route_key and str(wanted_key) == feature.route_key:
+        return True
+
+    wanted_ref = norm_ref(selector.get("ref"))
+    wanted_mode = selector.get("mode")
+    if not wanted_ref or wanted_ref != norm_ref(feature.ref or ""):
+        return False
+    return not wanted_mode or _mode_compatible(str(wanted_mode), feature.mode or "")
+
+
+def _safe_int(value: object) -> int | None:
+    """Convert ``value`` with ``int()``, returning ``None`` when that is not possible.
+
+    Covers wrong types (``None``, lists), unparsable text and non-finite
+    floats such as the ``Infinity``/``NaN`` values Python's JSON parser accepts.
+    """
+    try:
+        return int(value)  # type: ignore[arg-type]
+    except (TypeError, ValueError, OverflowError):
+        # OverflowError: int(float("inf")); ValueError also covers int(float("nan")).
+        return None
+
+
+def _mode_compatible(gtfs_mode: str, osm_mode: str) -> bool:
+    """Tell whether a GTFS mode and an OSM mode may describe the same service.
+
+    An empty mode on either side is compatible with anything. Otherwise the
+    modes must be equal or related through ``MODE_GROUPS`` in either direction.
+    """
+    if not gtfs_mode or not osm_mode or gtfs_mode == osm_mode:
+        return True
+    if osm_mode in MODE_GROUPS.get(gtfs_mode, {gtfs_mode}):
+        return True
+    return gtfs_mode in MODE_GROUPS.get(osm_mode, {osm_mode})
+
+
+def _ratio(a: str, b: str) -> float:
+    """Score the similarity of two normalized strings between 0.0 and 1.0.
+
+    Empty input gives 0.0 and identical strings give 1.0. Otherwise the token
+    similarity is used, raised to at least 0.82 when one string contains the
+    other.
+    """
+    if not a or not b:
+        return 0.0
+    if a == b:
+        return 1.0
+    overlap = _token_similarity(a, b)
+    one_contains_other = a in b or b in a
+    return max(overlap, 0.82) if one_contains_other else overlap
+
+
+def _token_similarity(a: str, b: str) -> float:
+    """Return the shared-to-total ratio of the whitespace-separated token sets of ``a`` and ``b``.
+
+    Gives 0.0 when either string has no tokens.
+    """
+    left_tokens, right_tokens = set(a.split()), set(b.split())
+    if not (left_tokens and right_tokens):
+        return 0.0
+    shared = left_tokens.intersection(right_tokens)
+    combined = left_tokens.union(right_tokens)
+    return len(shared) / len(combined)
+
+
+def _status_from_score(score: float) -> str:
+    """Map a match score to a status label.
+
+    ``>= 85`` is ``matched``, ``>= 65`` is ``probable``, ``>= 40`` is ``weak``
+    and anything lower is ``missing``.
+    """
+    thresholds = ((85, "matched"), (65, "probable"), (40, "weak"))
+    for floor, label in thresholds:
+        if score >= floor:
+            return label
+    return "missing"
diff --git a/app/pipeline/osm_addresses.py b/app/pipeline/osm_addresses.py
new file mode 100644
index 0000000..8b0f0b2
--- /dev/null
+++ b/app/pipeline/osm_addresses.py
@@ -0,0 +1,508 @@
+from __future__ import annotations
+
+import json
+import math
+import re
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Callable
+
+import osmium
+from sqlalchemy import delete, func, select, text
+from sqlalchemy.orm import Session
+
+from app.config import settings
+from app.models import Dataset, OsmAddress
+from app.pipeline.routing_layer import active_routing_dataset
+from app.spatial import analyze_postgresql_tables, refresh_postgis_geometries
+
+
+# Progress hook signature, as invoked by _emit():
+# (event_type, message, progress_current, progress_total, metadata).
+ProgressCallback = Callable[[str, str, int | None, int | None, dict[str, object] | None], None]
+# Version tag stored in the dataset's "address_index" metadata by finalize_address_index();
+# address_index_status() reports the index as stale when the stored tag differs from this value.
+ADDRESS_INDEX_VERSION = "osm_addresses_v2_nodes_ways_area_geometry"
+# OSM tag keys that _address_row() copies into each row's tags_json (when present on the object).
+ADDRESS_TAGS = {
+ "addr:housenumber",
+ "addr:housename",
+ "addr:street",
+ "addr:place",
+ "addr:postcode",
+ "addr:city",
+ "addr:country",
+ "addr:unit",
+ "addr:suburb",
+ "addr:district",
+ "addr:municipality",
+ "entrance",
+ "name",
+}
+
+
+@dataclass
+class AddressIndexResult:
+    """Summary of one address-index build for a dataset."""
+
+    dataset_id: int
+    input_path: str
+    addresses: int
+    node_addresses: int
+    way_addresses: int
+    skipped: int
+    version: str = ADDRESS_INDEX_VERSION
+
+    def as_dict(self) -> dict[str, object]:
+        """Return the summary as a plain dict, with ``version`` as the first key."""
+        ordered_keys = (
+            "version",
+            "dataset_id",
+            "input_path",
+            "addresses",
+            "node_addresses",
+            "way_addresses",
+            "skipped",
+        )
+        return {key: getattr(self, key) for key in ordered_keys}
+
+
+def rebuild_address_index(
+ session: Session,
+ *,
+ dataset_id: int | None = None,
+ input_path: str | Path | None = None,
+ reset: bool = True,
+ batch_size: int = 20_000,
+ progress_callback: ProgressCallback | None = None,
+) -> dict[str, object]:
+ """Import address-tagged OSM nodes and ways from a PBF file into the address index.
+
+ Args:
+ session: SQLAlchemy session used for all reads, writes and commits.
+ dataset_id: Dataset to index; when omitted, active_routing_dataset() chooses it.
+ input_path: PBF file to read; defaults to the dataset's local_path.
+ reset: When true, existing address rows of the dataset are cleared first.
+ batch_size: Rows buffered per bulk insert (the handler enforces a minimum of 1,000).
+ progress_callback: Optional hook that receives progress events.
+
+ Returns:
+ The summary dict produced by finalize_address_index().
+
+ Raises:
+ ValueError: No dataset could be resolved.
+ FileNotFoundError: The PBF path does not exist.
+ """
+ # An explicit dataset_id wins; otherwise fall back to the active routing dataset.
+ dataset = session.get(Dataset, dataset_id) if dataset_id is not None else active_routing_dataset(session)
+ if dataset is None:
+ raise ValueError("No OSM PBF dataset is available for address indexing.")
+ path = Path(input_path or dataset.local_path)
+ if not path.exists():
+ raise FileNotFoundError(f"Address index PBF does not exist: {path}")
+
+ if reset:
+ _emit(progress_callback, "address_index_clear_started", "Clearing existing OSM address index.", None, None, {"dataset_id": dataset.id})
+ _clear_address_rows(session, dataset_id=int(dataset.id))
+ session.commit()
+
+ # NOTE(review): indentation is collapsed in this patch; confirm whether this
+ # PostgreSQL-only index drop is nested under `if reset:` or runs independently.
+ if settings.is_postgresql_database:
+ _emit(progress_callback, "address_index_indexes_dropped", "Dropping address lookup indexes before bulk import.", None, None, {"dataset_id": dataset.id})
+ _drop_address_indexes(session)
+ session.commit()
+
+ _emit(progress_callback, "address_index_import_started", "Importing OSM address nodes and ways.", None, None, {"dataset_id": dataset.id, "path": str(path)})
+ handler = _AddressHandler(
+ session=session,
+ dataset_id=dataset.id,
+ batch_size=batch_size,
+ progress_callback=progress_callback,
+ )
+ # Use the FileProcessor-based reader when this osmium build provides it;
+ # otherwise fall back to SimpleHandler.apply_file with node locations enabled.
+ if hasattr(osmium, "FileProcessor"):
+ _apply_address_file_processor(handler, path)
+ else:
+ handler.apply_file(str(path), locations=True)
+ # Insert whatever is still buffered below the batch size.
+ handler.flush()
+
+ return finalize_address_index(
+ session,
+ dataset_id=dataset.id,
+ input_path=path,
+ node_addresses=handler.node_address_count,
+ way_addresses=handler.way_address_count,
+ skipped=handler.skipped_count,
+ progress_callback=progress_callback,
+ )
+
+
+def finalize_address_index(
+ session: Session,
+ *,
+ dataset_id: int,
+ input_path: str | Path,
+ node_addresses: int = 0,
+ way_addresses: int = 0,
+ skipped: int = 0,
+ progress_callback: ProgressCallback | None = None,
+) -> dict[str, object]:
+ """Finish an address import: maintain the table, record a summary, report completion.
+
+ Counts the dataset's OsmAddress rows, stores the summary under the
+ "address_index" key of the dataset's metadata_json, commits, emits the
+ "address_index_import_completed" event and returns the summary as a dict.
+
+ Raises:
+ ValueError: The dataset does not exist.
+ """
+ dataset = session.get(Dataset, dataset_id)
+ if dataset is None:
+ raise ValueError("Address index dataset does not exist.")
+ # NOTE(review): indentation is collapsed in this patch; confirm how many of the
+ # following maintenance steps (geometry refresh, index rebuild, ANALYZE) sit
+ # inside this PostgreSQL-only branch.
+ if settings.is_postgresql_database:
+ _emit(progress_callback, "address_index_geometry_started", "Refreshing address point geometries.", None, None, {"dataset_id": dataset.id})
+ refresh_postgis_geometries(session, dataset_id=dataset.id, tables=["osm_addresses"], only_missing=False)
+ session.commit()
+ _emit(progress_callback, "address_index_indexes_started", "Rebuilding address lookup indexes.", None, None, {"dataset_id": dataset.id})
+ _create_address_indexes(session)
+ session.commit()
+ analyze_postgresql_tables(session, ["osm_addresses"])
+ # The reported total comes from the table itself, not from the handler counters.
+ address_count = int(session.scalar(select(func.count()).select_from(OsmAddress).where(OsmAddress.dataset_id == dataset.id)) or 0)
+ metadata = _metadata(dataset)
+ metadata["address_index"] = {
+ "version": ADDRESS_INDEX_VERSION,
+ "addresses": address_count,
+ "node_addresses": int(node_addresses),
+ "way_addresses": int(way_addresses),
+ "skipped": int(skipped),
+ "input_path": str(input_path),
+ }
+ dataset.metadata_json = json.dumps(metadata, indent=2)
+ session.commit()
+ result = AddressIndexResult(
+ dataset_id=dataset.id,
+ input_path=str(input_path),
+ addresses=address_count,
+ node_addresses=node_addresses,
+ way_addresses=way_addresses,
+ skipped=skipped,
+ ).as_dict()
+ _emit(progress_callback, "address_index_import_completed", "OSM address index import completed.", address_count, address_count, result)
+ return result
+
+
+def _clear_address_rows(session: Session, *, dataset_id: int) -> None:
+    """Remove the stored address rows of ``dataset_id``.
+
+    On PostgreSQL, when no other dataset has rows in ``osm_addresses``, the
+    whole table is truncated and its identity restarted. In every other case
+    only the rows of this dataset are deleted.
+    """
+    target_id = int(dataset_id)
+    if settings.is_postgresql_database:
+        foreign_datasets = session.scalar(
+            select(func.count(func.distinct(OsmAddress.dataset_id))).where(OsmAddress.dataset_id != target_id)
+        )
+        if int(foreign_datasets or 0) == 0:
+            session.execute(text("TRUNCATE TABLE osm_addresses RESTART IDENTITY"))
+            return
+    session.execute(delete(OsmAddress).where(OsmAddress.dataset_id == target_id))
+
+
+def address_index_status(session: Session) -> dict[str, object]:
+ """Report the state of the address index for the active routing dataset.
+
+ The address count is taken from the dataset's "address_index" metadata and,
+ when that yields zero, from a COUNT query on OsmAddress. The returned dict
+ has the keys dataset_id, addresses, available, version, current_version,
+ stale and input_path. With no active dataset the count is 0 and
+ dataset_id is None.
+ """
+ dataset = active_routing_dataset(session)
+ dataset_id = None if dataset is None else int(dataset.id)
+ address_count = 0
+ metadata: dict[str, object] = {}
+ if dataset is not None:
+ metadata = _metadata(dataset).get("address_index") or {}
+ if isinstance(metadata, dict):
+ try:
+ address_count = int(metadata.get("addresses") or 0)
+ except (TypeError, ValueError):
+ address_count = 0
+ # NOTE(review): indentation is collapsed in this patch; confirm whether this
+ # fallback count also runs when the stored metadata is not a dict.
+ if not address_count:
+ address_count = int(session.scalar(select(func.count()).select_from(OsmAddress).where(OsmAddress.dataset_id == dataset.id)) or 0)
+ installed_version = metadata.get("version") if isinstance(metadata, dict) else None
+ return {
+ "dataset_id": dataset_id,
+ "addresses": address_count,
+ "available": address_count > 0,
+ "version": installed_version,
+ "current_version": ADDRESS_INDEX_VERSION,
+ # Stale means: rows exist but were built by a different index version.
+ "stale": bool(address_count and installed_version != ADDRESS_INDEX_VERSION),
+ "input_path": metadata.get("input_path") if isinstance(metadata, dict) else None,
+ }
+
+
+class _AddressHandler(osmium.SimpleHandler):
+ """Collects address-tagged nodes and ways and bulk-inserts them as OsmAddress rows.
+
+ Rows are buffered in ``self.rows`` and written (and committed) by flush().
+ The counters track accepted node/way addresses, skipped objects and all
+ objects seen.
+ """
+
+ def __init__(
+ self,
+ *,
+ session: Session,
+ dataset_id: int,
+ batch_size: int,
+ progress_callback: ProgressCallback | None,
+ ) -> None:
+ super().__init__()
+ self.session = session
+ self.dataset_id = int(dataset_id)
+ # Batches never go below 1,000 rows, whatever the caller asks for.
+ self.batch_size = max(1_000, int(batch_size))
+ self.progress_callback = progress_callback
+ self.rows: list[dict[str, object]] = []
+ self.address_count = 0
+ self.node_address_count = 0
+ self.way_address_count = 0
+ self.skipped_count = 0
+ self.processed_count = 0
+
+ def node(self, node) -> None:
+ # Handler callback for nodes; delegates to process_node().
+ self.process_node(node)
+
+ def way(self, way) -> None:
+ # Handler callback for ways; delegates to process_way().
+ self.process_way(way)
+
+ def process_object(self, obj) -> None:
+ """Dispatch by shape: objects with ``nodes`` are ways, objects with ``location`` are nodes."""
+ if hasattr(obj, "nodes"):
+ self.process_way(obj)
+ elif hasattr(obj, "location"):
+ self.process_node(obj)
+
+ def process_node(self, node) -> None:
+ """Buffer an address row for a node; nodes without a usable address or location are not stored."""
+ self.processed_count += 1
+ tags = {tag.k: tag.v for tag in node.tags}
+ if not _has_address(tags):
+ return
+ if not node.location.valid():
+ self.skipped_count += 1
+ return
+ row = _address_row(
+ dataset_id=self.dataset_id,
+ osm_type="node",
+ osm_id=str(node.id),
+ tags=tags,
+ lon=float(node.location.lon),
+ lat=float(node.location.lat),
+ # A node's bounding box collapses to its own coordinates.
+ bounds=(float(node.location.lon), float(node.location.lat), float(node.location.lon), float(node.location.lat)),
+ geometry_geojson=None,
+ )
+ if row is None:
+ self.skipped_count += 1
+ return
+ self.rows.append(row)
+ self.node_address_count += 1
+ self._after_address()
+
+ def process_way(self, way) -> None:
+ """Buffer an address row for a way, using its centroid, bounding box and (if closed) area geometry."""
+ self.processed_count += 1
+ tags = {tag.k: tag.v for tag in way.tags}
+ if not _has_address(tags):
+ return
+ # Only member nodes with a valid location contribute coordinates.
+ coords = [
+ (float(node.location.lon), float(node.location.lat))
+ for node in way.nodes
+ if node.location.valid()
+ ]
+ if not coords:
+ self.skipped_count += 1
+ return
+ lon, lat = _centroid(coords)
+ min_lon = min(coord[0] for coord in coords)
+ max_lon = max(coord[0] for coord in coords)
+ min_lat = min(coord[1] for coord in coords)
+ max_lat = max(coord[1] for coord in coords)
+ row = _address_row(
+ dataset_id=self.dataset_id,
+ osm_type="way",
+ osm_id=str(way.id),
+ tags=tags,
+ lon=lon,
+ lat=lat,
+ bounds=(min_lon, min_lat, max_lon, max_lat),
+ geometry_geojson=_address_area_geometry_geojson(coords, closed=_way_is_closed(way)),
+ )
+ if row is None:
+ self.skipped_count += 1
+ return
+ self.rows.append(row)
+ self.way_address_count += 1
+ self._after_address()
+
+ def _after_address(self) -> None:
+ """Count one accepted address; flush a full buffer and report progress on multiples of 50,000."""
+ self.address_count += 1
+ if len(self.rows) >= self.batch_size:
+ self.flush()
+ if self.address_count % 50_000 == 0:
+ _emit(
+ self.progress_callback,
+ "address_index_import_batch",
+ f"Imported {self.address_count:,} OSM addresses.",
+ self.address_count,
+ None,
+ {"processed": self.processed_count, "skipped": self.skipped_count},
+ )
+
+ def flush(self) -> None:
+ """Bulk-insert the buffered rows, commit, and empty the buffer; no-op when nothing is buffered."""
+ if not self.rows:
+ return
+ self.session.bulk_insert_mappings(OsmAddress, self.rows)
+ self.session.commit()
+ self.rows = []
+
+
+def _apply_address_file_processor(handler: _AddressHandler, path: Path) -> None:
+    """Feed ``handler`` from an ``osmium.FileProcessor`` over ``path``.
+
+    The processor is limited to nodes and ways, has locations enabled and
+    only passes objects carrying an ``addr:housenumber`` or
+    ``addr:housename`` key.
+    """
+    reader = osmium.FileProcessor(str(path), osmium.osm.NODE | osmium.osm.WAY)
+    reader = reader.with_locations()
+    reader = reader.with_filter(osmium.filter.KeyFilter("addr:housenumber", "addr:housename"))
+    for osm_object in reader:
+        handler.process_object(osm_object)
+
+
+def _has_address(tags: dict[str, str]) -> bool:
+    """Return whether ``tags`` describe an indexable address.
+
+    Requires a non-blank house number (or house name as fallback) plus at
+    least one non-blank street, place, city or postcode tag.
+    """
+    number_or_name = _clean(tags.get("addr:housenumber") or tags.get("addr:housename"))
+    if not number_or_name:
+        return False
+    for locator_key in ("addr:street", "addr:place", "addr:city", "addr:postcode"):
+        if _clean(tags.get(locator_key)):
+            return True
+    return False
+
+
+def _address_row(
+    *,
+    dataset_id: int,
+    osm_type: str,
+    osm_id: str,
+    tags: dict[str, str],
+    lon: float,
+    lat: float,
+    bounds: tuple[float, float, float, float],
+    geometry_geojson: str | None = None,
+) -> dict[str, object] | None:
+    """Build the insert mapping for one address, or ``None`` without a display name.
+
+    Address parts are read from the ``addr:*`` and ``name`` tags and
+    whitespace-normalized by ``_clean``. ``bounds`` is
+    ``(min_lon, min_lat, max_lon, max_lat)``. Only tags listed in
+    ``ADDRESS_TAGS`` are kept in ``tags_json``.
+    """
+
+    def cleaned(key: str) -> str | None:
+        return _clean(tags.get(key))
+
+    housenumber = _clean(tags.get("addr:housenumber") or tags.get("addr:housename"))
+    street = cleaned("addr:street")
+    place = cleaned("addr:place")
+    postcode = cleaned("addr:postcode")
+    city = _clean(tags.get("addr:city") or tags.get("addr:municipality"))
+    country = cleaned("addr:country")
+    unit = cleaned("addr:unit")
+    name = cleaned("name")
+
+    display_name = _display_name(
+        housenumber=housenumber,
+        street=street,
+        place=place,
+        postcode=postcode,
+        city=city,
+        name=name,
+    )
+    if not display_name:
+        return None
+
+    search_text = _search_text(display_name, housenumber, street, place, postcode, city, country, unit, name)
+    kept_tags: dict[str, str] = {}
+    for key in sorted(ADDRESS_TAGS):
+        if key in tags:
+            kept_tags[key] = tags[key]
+    tags_json = json.dumps(kept_tags, separators=(",", ":")) if kept_tags else None
+    west, south, east, north = bounds
+    return dict(
+        dataset_id=dataset_id,
+        osm_type=osm_type,
+        osm_id=osm_id,
+        housenumber=housenumber,
+        street=street,
+        place=place,
+        postcode=postcode,
+        city=city,
+        country=country,
+        unit=unit,
+        name=name,
+        display_name=display_name,
+        search_text=search_text,
+        lon=lon,
+        lat=lat,
+        min_lon=west,
+        min_lat=south,
+        max_lon=east,
+        max_lat=north,
+        geometry_geojson=geometry_geojson,
+        tags_json=tags_json,
+    )
+
+
+def _address_area_geometry_geojson(coords: list[tuple[float, float]], *, closed: bool | None = None) -> str | None:
+    """Serialize ``coords`` as a compact GeoJSON Polygon, or return ``None``.
+
+    ``closed=False`` never yields a polygon. ``closed=True`` closes an open
+    coordinate list by repeating its first point. ``closed=None`` accepts
+    only lists whose first and last point already coincide. Rings with
+    fewer than four points or fewer than three distinct vertices are
+    rejected.
+    """
+    if closed is False or len(coords) < 3:
+        return None
+    ring_points = [*coords]
+    start, end = ring_points[0], ring_points[-1]
+    ends_meet = abs(start[0] - end[0]) <= 1e-12 and abs(start[1] - end[1]) <= 1e-12
+    if not ends_meet:
+        if closed is not True:
+            return None
+        ring_points.append(start)
+    if len(ring_points) < 4:
+        return None
+    ring = [[float(lon), float(lat)] for lon, lat in ring_points]
+    # The closing point repeats the first one, so leave it out when counting vertices.
+    distinct_vertices = {(round(lon, 12), round(lat, 12)) for lon, lat in ring_points[:-1]}
+    if len(distinct_vertices) < 3:
+        return None
+    polygon = {"type": "Polygon", "coordinates": [ring]}
+    return json.dumps(polygon, separators=(",", ":"))
+
+
+def _way_is_closed(way) -> bool:
+    """Return whether ``way`` has at least three nodes and starts and ends on the same node ref.
+
+    Objects without a usable ``nodes`` sequence count as not closed.
+    """
+    try:
+        node_refs = way.nodes
+        if len(node_refs) < 3:
+            return False
+        return node_refs[0].ref == node_refs[-1].ref
+    except (AttributeError, IndexError, TypeError):
+        return False
+
+
+def _display_name(
+    *,
+    housenumber: str | None,
+    street: str | None,
+    place: str | None,
+    postcode: str | None,
+    city: str | None,
+    name: str | None,
+) -> str | None:
+    """Compose a one-line label such as ``"Main St 5, 12345 Town"``.
+
+    The road part is the first available of street, place and name, followed
+    by the house number. The locality part is postcode and city joined by a
+    space. The two parts are joined by ``", "`` when both exist; otherwise
+    whichever is present is returned, and the result is falsy when nothing
+    is available.
+    """
+    road = street or place or name
+    first = f"{road} {housenumber}" if (road and housenumber) else (road or housenumber)
+    locality_parts = [part for part in (postcode, city) if part]
+    locality = " ".join(locality_parts)
+    if not (first and locality):
+        return first or locality
+    return f"{first}, {locality}"
+
+
+def _search_text(*parts: str | None) -> str:
+    """Join the non-empty ``parts`` into one case-folded string with single spaces."""
+    folded = [part.casefold() for part in parts if part]
+    joined = " ".join(folded)
+    collapsed = re.sub(r"\s+", " ", joined)
+    return collapsed.strip()
+
+
+def _clean(value: object) -> str | None:
+    """Stringify ``value``, collapse whitespace runs and trim; return ``None`` when nothing remains.
+
+    Falsy inputs (``None``, ``""``, ``0``) are treated as empty.
+    """
+    text_value = str(value or "")
+    normalized = re.sub(r"\s+", " ", text_value).strip()
+    if normalized:
+        return normalized
+    return None
+
+
+def _centroid(coords: list[tuple[float, float]]) -> tuple[float, float]:
+    """Return a representative ``(lon, lat)`` point for ``coords``.
+
+    A closed ring (at least four points, first equal to last) with
+    non-negligible signed area yields its area-weighted polygon centroid.
+    Everything else yields the arithmetic mean of the points.
+    """
+    is_ring = len(coords) >= 4 and coords[0] == coords[-1]
+    if is_ring:
+        # doubled_area accumulates the shoelace cross products (twice the signed area).
+        doubled_area = 0.0
+        sum_x = 0.0
+        sum_y = 0.0
+        for index in range(len(coords) - 1):
+            x1, y1 = coords[index]
+            x2, y2 = coords[index + 1]
+            cross = x1 * y2 - x2 * y1
+            doubled_area += cross
+            sum_x += (x1 + x2) * cross
+            sum_y += (y1 + y2) * cross
+        if abs(doubled_area) > 1e-18:
+            scale = 1 / (3 * doubled_area)
+            return sum_x * scale, sum_y * scale
+    count = len(coords)
+    mean_lon = math.fsum(point[0] for point in coords) / count
+    mean_lat = math.fsum(point[1] for point in coords) / count
+    return (mean_lon, mean_lat)
+
+
+def _drop_address_indexes(session: Session) -> None:
+    """Issue ``DROP INDEX IF EXISTS`` for each address lookup index, in a fixed order."""
+    index_names = (
+        "ix_osm_addresses_dataset_city_street",
+        "ix_osm_addresses_dataset_postcode",
+        "ix_osm_addresses_bbox",
+        "ix_osm_addresses_geom_gist",
+        "ix_osm_addresses_area_geom_gist",
+        "ix_osm_addresses_search_trgm",
+        "ix_osm_addresses_display_trgm",
+        "ix_osm_addresses_street_key_house",
+        "ix_osm_addresses_street_key_trgm",
+    )
+    for index_name in index_names:
+        session.execute(text(f"DROP INDEX IF EXISTS {index_name}"))
+
+
+def _create_address_indexes(session: Session) -> None:
+ """Create the address lookup indexes with ``CREATE INDEX IF NOT EXISTS``.
+
+ Three plain b-tree indexes are always listed; on PostgreSQL the list is
+ extended with GiST geometry indexes, trigram GIN indexes and the
+ street-key indexes.
+ """
+ statements = [
+ "CREATE INDEX IF NOT EXISTS ix_osm_addresses_dataset_city_street ON osm_addresses (dataset_id, city, street, housenumber)",
+ "CREATE INDEX IF NOT EXISTS ix_osm_addresses_dataset_postcode ON osm_addresses (dataset_id, postcode)",
+ "CREATE INDEX IF NOT EXISTS ix_osm_addresses_bbox ON osm_addresses (dataset_id, min_lon, max_lon, min_lat, max_lat)",
+ ]
+ if settings.is_postgresql_database:
+ statements.extend(
+ [
+ "CREATE INDEX IF NOT EXISTS ix_osm_addresses_geom_gist ON osm_addresses USING GIST (geom)",
+ "CREATE INDEX IF NOT EXISTS ix_osm_addresses_area_geom_gist ON osm_addresses USING GIST (area_geom)",
+ "CREATE INDEX IF NOT EXISTS ix_osm_addresses_search_trgm ON osm_addresses USING GIN (LOWER(COALESCE(search_text, '')) gin_trgm_ops)",
+ "CREATE INDEX IF NOT EXISTS ix_osm_addresses_display_trgm ON osm_addresses USING GIN (LOWER(COALESCE(display_name, '')) gin_trgm_ops)",
+ # The street key lower-cases street (or place) and folds 'ß' to 'ss'.
+ "CREATE INDEX IF NOT EXISTS ix_osm_addresses_street_key_house ON osm_addresses (dataset_id, REPLACE(LOWER(COALESCE(NULLIF(street, ''), NULLIF(place, ''), '')), 'ß', 'ss'), housenumber)",
+ "CREATE INDEX IF NOT EXISTS ix_osm_addresses_street_key_trgm ON osm_addresses USING GIN (REPLACE(LOWER(COALESCE(NULLIF(street, ''), NULLIF(place, ''), '')), 'ß', 'ss') gin_trgm_ops)",
+ ]
+ )
+ # NOTE(review): indentation is collapsed in this patch; presumably this loop runs
+ # for every backend, not only inside the PostgreSQL branch - confirm.
+ for statement in statements:
+ session.execute(text(statement))
+
+
+def _metadata(dataset: Dataset) -> dict[str, object]:
+    """Decode ``dataset.metadata_json`` into a dict.
+
+    Missing or empty metadata, invalid JSON and JSON that is not an object
+    all yield an empty dict.
+    """
+    raw = dataset.metadata_json or "{}"
+    try:
+        parsed = json.loads(raw)
+    except json.JSONDecodeError:
+        return {}
+    if isinstance(parsed, dict):
+        return parsed
+    return {}
+
+
+def _emit(
+    progress_callback: ProgressCallback | None,
+    event_type: str,
+    message: str,
+    progress_current: int | None,
+    progress_total: int | None,
+    metadata: dict[str, object] | None = None,
+) -> None:
+    """Forward one progress event to ``progress_callback``; do nothing when no callback is set."""
+    if progress_callback is None:
+        return
+    progress_callback(event_type, message, progress_current, progress_total, metadata)
diff --git a/app/pipeline/osm_diff.py b/app/pipeline/osm_diff.py
new file mode 100644
index 0000000..f101b56
--- /dev/null
+++ b/app/pipeline/osm_diff.py
@@ -0,0 +1,100 @@
+from __future__ import annotations
+
+import json
+from urllib.parse import urlparse
+
+from sqlalchemy import select
+from sqlalchemy.orm import Session
+
+from app.config import settings
+from app.models import Dataset, Source
+from app.pipeline.download import materialize_source
+from app.pipeline.osm_pbf import _raw_format
+from app.pipeline.osm_replication import fetch_replication_state
+from app.pipeline.utils import sha256_file
+
+
+def run_osm_diff_source(session: Session, source: Source) -> Dataset:
+    """Record an OSM change file as a raw, inactive update artifact.
+
+    Sources whose URL looks like an ``-updates`` directory are handed to
+    ``_commit_update_directory_state``. Otherwise the file is materialized
+    and hashed. An existing ``osm_diff_raw`` dataset of the same source
+    with the same SHA-256 is returned unchanged; if there is none, a new
+    dataset row is added and flushed.
+
+    Applying the diff to an authoritative OSM base extract is a separate
+    step; this importer only registers the file and deliberately does not
+    treat it as a complete visual route layer.
+    """
+    if _looks_like_update_directory(source.url):
+        return _commit_update_directory_state(session, source)
+
+    raw_path = materialize_source(source)
+    raw_hash = sha256_file(raw_path)
+    duplicate_query = (
+        select(Dataset)
+        .where(Dataset.source_id == source.id, Dataset.kind == "osm_diff_raw", Dataset.sha256 == raw_hash)
+        .order_by(Dataset.id.desc())
+    )
+    already_committed = session.scalar(duplicate_query)
+    if already_committed is not None:
+        return already_committed
+
+    metadata = {
+        "stage": "raw_osm_diff",
+        "raw_format": _raw_format(raw_path),
+        "source_url": source.url,
+    }
+    dataset = Dataset(
+        source_id=source.id,
+        kind="osm_diff_raw",
+        local_path=str(raw_path),
+        sha256=raw_hash,
+        is_active=False,
+        status="committed",
+        metadata_json=json.dumps(metadata, indent=2),
+    )
+    session.add(dataset)
+    session.flush()
+    return dataset
+
+
+def _commit_update_directory_state(session: Session, source: Source) -> Dataset:
+    """Snapshot the replication state of an updates directory as an ``osm_diff_state`` dataset.
+
+    The fetched state is written as sorted ``key=value`` lines to
+    ``<data_dir>/sources/source_<id>/state_<sequence>.txt`` and hashed. A
+    dataset of the same source with the same hash is reused; otherwise a
+    new inactive, committed dataset is added and flushed.
+    """
+    state = fetch_replication_state(source.url, timeout=settings.osm_diff_state_timeout_seconds)
+    target_dir = settings.data_dir / "sources" / f"source_{source.id}"
+    target_dir.mkdir(parents=True, exist_ok=True)
+    state_path = target_dir / f"state_{state.sequence_number}.txt"
+    state_lines = [f"{key}={value}" for key, value in sorted(state.raw.items())]
+    state_path.write_text("\n".join(state_lines) + "\n", encoding="utf-8")
+    state_hash = sha256_file(state_path)
+
+    duplicate_query = (
+        select(Dataset)
+        .where(Dataset.source_id == source.id, Dataset.kind == "osm_diff_state", Dataset.sha256 == state_hash)
+        .order_by(Dataset.id.desc())
+    )
+    already_committed = session.scalar(duplicate_query)
+    if already_committed is not None:
+        return already_committed
+
+    metadata = {
+        "stage": "osm_diff_state",
+        "updates_url": source.url,
+        "sequence_number": state.sequence_number,
+        "timestamp": state.timestamp,
+        "state": state.raw,
+    }
+    dataset = Dataset(
+        source_id=source.id,
+        kind="osm_diff_state",
+        local_path=str(state_path),
+        sha256=state_hash,
+        is_active=False,
+        status="committed",
+        metadata_json=json.dumps(metadata, indent=2),
+    )
+    session.add(dataset)
+    session.flush()
+    return dataset
+
+
+def _looks_like_update_directory(url: str) -> bool:
+    """Return whether the URL path ends in ``-updates`` or ``-updates/`` (case-insensitive)."""
+    path = urlparse(url).path
+    return path.lower().endswith(("-updates", "-updates/"))
diff --git a/app/pipeline/osm_geojson.py b/app/pipeline/osm_geojson.py
new file mode 100644
index 0000000..8aa50a8
--- /dev/null
+++ b/app/pipeline/osm_geojson.py
@@ -0,0 +1,248 @@
+from __future__ import annotations
+
+import json
+from pathlib import Path
+from typing import Any
+
+from sqlalchemy import select
+from sqlalchemy.orm import Session
+
+from app.config import settings
+from app.models import Dataset, OsmFeature, Source
+from app.osm_classification import infer_osm_route_scope
+from app.osm_storage import (
+ OSM_STORAGE_METADATA_KEY,
+ OSM_STORAGE_MAIN,
+ OSM_STORAGE_SIDECAR_FEATURES,
+ create_osm_sidecar,
+ dedupe_osm_feature_rows,
+ effective_osm_feature_storage,
+)
+from app.pipeline.download import materialize_source
+from app.pipeline.utils import first_nonempty, geometry_json_and_bbox, norm_ref, norm_text, sha256_file
+from app.spatial import analyze_postgresql_tables, refresh_postgis_geometries
+
+# Transport mode values accepted by _infer_mode() in the "mode", "route" and
+# "route_master" properties ("railway" is reported as "train"); _infer_kind()
+# also treats a "route" property with one of these values as a route.
+ROUTE_MODES = {
+ "train",
+ "railway",
+ "light_rail",
+ "subway",
+ "tram",
+ "bus",
+ "trolleybus",
+ "coach",
+ "ferry",
+ "monorail",
+ "funicular",
+ "aerialway",
+}
+
+
+def run_osm_geojson_source(session: Session, source: Source) -> Dataset:
+    """Import an OSM GeoJSON source unless an identical import is already active.
+
+    The source file is materialized and hashed. The newest active, imported
+    ``osm_geojson`` dataset of this source with the same SHA-256 is reused;
+    otherwise ``import_osm_geojson`` performs a fresh import.
+    """
+    local_path = materialize_source(source)
+    source_hash = sha256_file(local_path)
+    conditions = (
+        Dataset.source_id == source.id,
+        Dataset.kind == "osm_geojson",
+        Dataset.sha256 == source_hash,
+        Dataset.is_active.is_(True),
+        Dataset.status == "imported",
+    )
+    reusable = session.scalar(select(Dataset).where(*conditions).order_by(Dataset.id.desc()))
+    if reusable is None:
+        return import_osm_geojson(session=session, source=source, path=local_path, source_hash=source_hash)
+    return reusable
+
+
+def import_osm_geojson(
+    session: Session,
+    source: Source,
+    path: Path,
+    source_hash: str | None = None,
+    *,
+    storage_mode: str | None = None,
+) -> Dataset:
+    """Import a GeoJSON file as the new active ``osm_geojson`` dataset of ``source``.
+
+    Args:
+        session: SQLAlchemy session; changes are flushed, not committed.
+        source: Source that owns the dataset. Its existing datasets are
+            deactivated, and its status and last_error are reset on success.
+        path: Local GeoJSON file.
+        source_hash: SHA-256 of ``path`` if already known; computed otherwise.
+        storage_mode: Optional feature storage mode passed on to
+            ``prepare_osm_geojson_storage``.
+
+    Returns:
+        The new dataset, with status "imported" and storage metadata set.
+    """
+    for previous in source.datasets:
+        previous.is_active = False
+
+    # Hash the file once and reuse the digest for both the dataset row and
+    # the storage step.
+    source_hash = source_hash or sha256_file(path)
+
+    dataset = Dataset(
+        source_id=source.id,
+        kind="osm_geojson",
+        local_path=str(path),
+        sha256=source_hash,
+        is_active=True,
+        status="importing",
+    )
+    session.add(dataset)
+    # Flush so dataset.id is assigned before feature rows reference it.
+    session.flush()
+
+    dataset.metadata_json = json.dumps(
+        prepare_osm_geojson_storage(
+            session=session,
+            dataset=dataset,
+            path=path,
+            source_hash=source_hash,
+            storage_mode=storage_mode,
+        ),
+        indent=2,
+    )
+
+    dataset.status = "imported"
+    source.status = "ok"
+    source.last_error = None
+    session.flush()
+    return dataset
+
+
+def prepare_osm_geojson_storage(
+    *,
+    session: Session,
+    dataset: Dataset,
+    path: Path,
+    source_hash: str | None = None,
+    storage_mode: str | None = None,
+) -> dict[str, object]:
+    """Parse the GeoJSON at ``path`` and store its features for ``dataset``.
+
+    In sidecar mode the rows go to an OSM sidecar. In main mode they are
+    inserted into ``osm_features``, followed by a geometry refresh and
+    table analysis. The returned metadata holds the feature count and the
+    storage description.
+
+    Raises:
+        ValueError: The effective storage mode is not supported, or the
+            file is not a supported GeoJSON shape.
+    """
+    payload = json.loads(path.read_text(encoding="utf-8"))
+    parsed_features = _as_features(payload)
+    feature_rows = [_feature_row(dataset.id, position, item) for position, item in enumerate(parsed_features)]
+
+    storage = effective_osm_feature_storage(storage_mode)
+    if storage not in {OSM_STORAGE_MAIN, OSM_STORAGE_SIDECAR_FEATURES}:
+        raise ValueError(f"Unsupported OSM feature storage mode: {storage}")
+
+    if storage == OSM_STORAGE_SIDECAR_FEATURES:
+        feature_total = len(feature_rows)
+        sidecar_info = create_osm_sidecar(dataset, feature_rows, source_hash=source_hash or dataset.sha256)
+        return {"features": feature_total, OSM_STORAGE_METADATA_KEY: sidecar_info}
+
+    _insert_main_features(session, feature_rows)
+    session.flush()
+    refresh_postgis_geometries(session, dataset_id=dataset.id, tables=["osm_features"])
+    analyze_postgresql_tables(session, ["osm_features"])
+    storage_info = {"mode": OSM_STORAGE_MAIN}
+    return {"features": len(feature_rows), OSM_STORAGE_METADATA_KEY: storage_info}
+
+
+def _insert_main_features(session: Session, feature_rows: list[dict[str, object]]) -> None:
+    """Bulk-save the de-duplicated feature rows as ``OsmFeature`` objects, 5000 at a time."""
+    columns = (
+        "dataset_id",
+        "osm_type",
+        "osm_id",
+        "kind",
+        "mode",
+        "route_scope",
+        "name",
+        "ref",
+        "operator",
+        "network",
+        "geometry_geojson",
+        "min_lon",
+        "min_lat",
+        "max_lon",
+        "max_lat",
+        "tags_json",
+        "route_key",
+        "operator_key",
+    )
+    unique_rows, _ignored_duplicates = dedupe_osm_feature_rows(feature_rows)
+    pending: list[OsmFeature] = []
+    for row in unique_rows:
+        pending.append(OsmFeature(**{column: row[column] for column in columns}))
+        if len(pending) >= 5000:
+            session.bulk_save_objects(pending)
+            pending.clear()
+    if pending:
+        session.bulk_save_objects(pending)
+
+
+def _feature_row(dataset_id: int, idx: int, feature: dict[str, Any]) -> dict[str, object]:
+    """Flatten one GeoJSON feature into the column mapping used for OSM feature rows.
+
+    Identity, name, ref, operator and network are each taken from the first
+    non-empty of several candidate properties. ``idx`` only serves to build
+    the fallback id ``feature_<idx>``. The route key falls back from the
+    normalized ref to the normalized name to the normalized OSM id.
+    """
+    props = feature.get("properties") or {}
+    geometry_text, bbox = geometry_json_and_bbox(feature.get("geometry"))
+
+    def lookup(*keys: str) -> list[Any]:
+        return [props.get(key) for key in keys]
+
+    osm_type = str(first_nonempty(*lookup("osm_type", "@type", "type"), "feature"))
+    osm_id = str(first_nonempty(*lookup("osm_id", "@id", "id"), f"feature_{idx}"))
+    mode = _infer_mode(props)
+    kind = _infer_kind(props, mode)
+    name = first_nonempty(*lookup("name", "official_name")) or None
+    ref = first_nonempty(*lookup("ref", "route_ref", "line")) or None
+    operator = first_nonempty(*lookup("operator", "agency", "brand")) or None
+    network = first_nonempty(*lookup("network", "network:short")) or None
+    route_scope = infer_osm_route_scope(mode=mode, ref=ref, name=name, network=network, tags=props)
+    route_key = norm_ref(ref) or norm_text(name) or norm_ref(osm_id)
+    operator_key = norm_text(operator or network or "")
+    return dict(
+        dataset_id=dataset_id,
+        osm_type=osm_type,
+        osm_id=osm_id,
+        kind=kind,
+        mode=mode,
+        route_scope=route_scope,
+        name=name,
+        ref=ref,
+        operator=operator,
+        network=network,
+        geometry_geojson=geometry_text,
+        min_lon=bbox[0],
+        min_lat=bbox[1],
+        max_lon=bbox[2],
+        max_lat=bbox[3],
+        tags_json=json.dumps(props, separators=(",", ":")),
+        route_key=route_key,
+        operator_key=operator_key,
+    )
+
+
+def _as_features(data: Any) -> list[dict[str, Any]]:
+    """Normalize decoded GeoJSON into a list of feature dicts.
+
+    Accepts a FeatureCollection, a single Feature, or a bare list of
+    features. Entries that are not dicts are dropped. A FeatureCollection
+    whose ``features`` member is missing, ``null`` or otherwise not a list
+    yields no features instead of failing with ``TypeError``.
+
+    Raises:
+        ValueError: ``data`` is none of the supported shapes.
+    """
+    if isinstance(data, dict):
+        geojson_type = data.get("type")
+        if geojson_type == "FeatureCollection":
+            candidates = data.get("features")
+            if not isinstance(candidates, list):
+                # JSON null, numbers, strings and objects carry no features.
+                candidates = []
+            return [item for item in candidates if isinstance(item, dict)]
+        if geojson_type == "Feature":
+            return [data]
+    if isinstance(data, list):
+        return [item for item in data if isinstance(item, dict)]
+    raise ValueError("OSM source must be GeoJSON FeatureCollection, Feature, or list of Features")
+
+
+def _infer_mode(props: dict[str, Any]) -> str | None:
+    """Guess the transport mode of an OSM feature from its properties.
+
+    An explicit ``mode``, ``route`` or ``route_master`` value listed in
+    ``ROUTE_MODES`` wins ("railway" is reported as "train"). Otherwise the
+    mode is derived from stop and station tagging; ``None`` means unknown.
+    """
+    for key in ("mode", "route", "route_master"):
+        candidate = str(props.get(key) or "").strip()
+        if candidate not in ROUTE_MODES:
+            continue
+        if candidate == "railway":
+            return "train"
+        return candidate
+
+    railway = str(props.get("railway") or "").strip()
+    railway_modes = {"station": "train", "halt": "train", "tram_stop": "tram", "subway_entrance": "subway"}
+    if railway in railway_modes:
+        return railway_modes[railway]
+
+    highway = str(props.get("highway") or "")
+    amenity = str(props.get("amenity") or "")
+    if highway == "bus_stop" or amenity == "bus_station":
+        return "bus"
+    if amenity == "ferry_terminal":
+        return "ferry"
+    if str(props.get("aerialway") or "") == "station":
+        return "aerialway"
+    return None
+
+
+def _infer_kind(props: dict[str, Any], mode: str | None) -> str:
+    """Classify an OSM feature as route, terminal, station, stop, infra or feature.
+
+    A recognized explicit ``kind`` property wins. Route tagging is checked
+    next, then terminals, stations and stops. Anything else is "infra" when
+    a mode is known and "feature" otherwise.
+    """
+
+    def tag(key: str) -> str:
+        return str(props.get(key) or "")
+
+    declared = tag("kind").strip()
+    if declared in {"route", "stop", "station", "terminal", "infra", "feature"}:
+        return declared
+    if tag("type") in {"route", "route_master"} or tag("route") in ROUTE_MODES:
+        return "route"
+    if tag("amenity") in {"ferry_terminal", "bus_station"}:
+        return "terminal"
+    if tag("railway") in {"station", "halt"} or tag("aerialway") == "station":
+        return "station"
+    if tag("public_transport") in {"platform", "stop_position", "station"} or tag("highway") == "bus_stop":
+        return "stop"
+    return "infra" if mode else "feature"
diff --git a/app/pipeline/osm_labeling.py b/app/pipeline/osm_labeling.py
new file mode 100644
index 0000000..cc7f23e
--- /dev/null
+++ b/app/pipeline/osm_labeling.py
@@ -0,0 +1,456 @@
+from __future__ import annotations
+
+from datetime import datetime, timezone
+import json
+from pathlib import Path
+import sqlite3
+from typing import Callable
+
+from sqlalchemy import func, select, text
+from sqlalchemy.orm import Session
+
+from app.models import Dataset, OsmFeature
+from app.osm_classification import OSM_ROUTE_SCOPE_CLASSIFIER_VERSION, infer_osm_route_scope_from_tags
+from app.osm_storage import (
+ dataset_metadata,
+ drop_osm_sidecar_route_scope_indexes,
+ ensure_osm_sidecar_schema,
+ features_are_sidecar,
+ rebuild_osm_sidecar_indexes,
+ sidecar_path,
+ writable_sidecar_connection,
+)
+from app.pipeline.state import (
+ STAGE_BUILD_INDEXES,
+ STAGE_LABEL_FEATURES,
+ dependency_hash,
+ finish_pipeline_run,
+ latest_completed_run,
+ start_pipeline_run,
+)
+
+
+# The labeling stage is versioned by the route-scope classifier it applies, so a
+# classifier bump changes the stage version recorded in runs and metadata.
+OSM_LABEL_FEATURES_VERSION = OSM_ROUTE_SCOPE_CLASSIFIER_VERSION
+# Name of the main-table index that _relabel_main_dataset drops and recreates.
+MAIN_ROUTE_SCOPE_INDEX = "ix_osm_features_scope_bbox"
+# Minimum feature count at which relabeling drops the route-scope index first
+# and rebuilds it afterwards (main table and sidecar respectively).
+MAIN_INDEX_REBUILD_THRESHOLD = 10_000
+SIDECAR_INDEX_REBUILD_THRESHOLD = 10_000
+# Progress hook signature: (event_type, message, current, total, metadata).
+ProgressCallback = Callable[[str, str, int | None, int | None, dict[str, object] | None], None]
+
+
+def relabel_osm_features(
+    session: Session,
+    *,
+    dataset_id: int | None = None,
+    chunk_size: int = 5000,
+    force: bool = False,
+    rebuild_indexes: bool = True,
+    progress_callback: ProgressCallback | None = None,
+    job_id: int | None = None,
+) -> dict[str, object]:
+    """Relabel the route scope of every targeted OSM dataset and summarise the outcome.
+
+    Targets come from ``_target_datasets``; each one is handed to
+    ``relabel_osm_dataset``. The returned summary carries the labeling version,
+    the number of datasets, aggregate counters and the per-dataset results.
+    Progress events are emitted at start, after each dataset and at the end.
+    """
+    targets = _target_datasets(session, dataset_id)
+    target_count = len(targets)
+    totals = {"processed": 0, "changed": 0, "skipped": 0, "missing": 0, "index_rebuilds": 0}
+    per_dataset: list[dict[str, object]] = []
+
+    _emit_progress(
+        progress_callback,
+        "osm_labeling_started",
+        f"Relabeling {target_count} OSM dataset(s).",
+        0,
+        target_count,
+        {"dataset_id": dataset_id, "force": force, "version": OSM_LABEL_FEATURES_VERSION},
+    )
+
+    for position, target in enumerate(targets, start=1):
+        outcome = relabel_osm_dataset(
+            session,
+            target,
+            chunk_size=chunk_size,
+            force=force,
+            rebuild_indexes=rebuild_indexes,
+            progress_callback=progress_callback,
+            job_id=job_id,
+        )
+        # Numeric counters are summed; status values are tallied per dataset.
+        for counter in ("processed", "changed", "index_rebuilds"):
+            totals[counter] += int(outcome.get(counter, 0) or 0)
+        status = outcome.get("status")
+        if status == "skipped":
+            totals["skipped"] += 1
+        if status == "missing_sidecar":
+            totals["missing"] += 1
+        per_dataset.append(outcome)
+        _emit_progress(
+            progress_callback,
+            "osm_labeling_dataset_completed",
+            f"Relabeled {position}/{target_count} OSM dataset(s).",
+            position,
+            target_count,
+            outcome,
+        )
+
+    summary: dict[str, object] = {
+        "version": OSM_LABEL_FEATURES_VERSION,
+        "datasets": target_count,
+        **totals,
+        "dataset_results": per_dataset,
+    }
+    _emit_progress(progress_callback, "osm_labeling_completed", "OSM feature relabeling completed.", target_count, target_count, summary)
+    return summary
+
+
+def relabel_osm_dataset(
+    session: Session,
+    dataset: Dataset,
+    *,
+    chunk_size: int = 5000,
+    force: bool = False,
+    rebuild_indexes: bool = True,
+    progress_callback: ProgressCallback | None = None,
+    job_id: int | None = None,
+) -> dict[str, object]:
+    """Relabel one OSM dataset and record the attempt as a pipeline run.
+
+    Returns a result dict whose ``status`` is ``"skipped"`` (labels already
+    current and ``force`` is false), ``"completed"``, or ``"missing_sidecar"``
+    (a ``FileNotFoundError`` was raised while relabeling). Any other exception
+    marks the run as failed and is re-raised.
+
+    The session is committed by this function: once after the run is started
+    and once after the run is finished (successfully or not).
+    """
+    dependency = _label_dependency(dataset)
+    dependency_hash_value = dependency_hash(dependency)
+    # Read the identifiers once: the failure paths roll the session back, which
+    # expires ORM attributes, and the result dicts should not depend on a reload.
+    dataset_id = dataset.id
+    source_id = dataset.source_id
+    if not force and _dataset_label_is_current(session, dataset, dependency_hash_value):
+        return {
+            "dataset_id": dataset_id,
+            "source_id": source_id,
+            "status": "skipped",
+            "reason": "label_features dependency is current",
+            "dependency_hash": dependency_hash_value,
+            "version": OSM_LABEL_FEATURES_VERSION,
+            "processed": 0,
+            "changed": 0,
+            "index_rebuilds": 0,
+        }
+
+    run = start_pipeline_run(
+        session,
+        stage=STAGE_LABEL_FEATURES,
+        version=OSM_LABEL_FEATURES_VERSION,
+        dependency_hash_value=dependency_hash_value,
+        source_id=source_id,
+        dataset_id=dataset_id,
+        job_id=job_id,
+        inputs=dependency,
+    )
+    # Persist the run row before the long-running work starts.
+    session.commit()
+    try:
+        if features_are_sidecar(dataset):
+            counts = _relabel_sidecar_dataset(dataset, chunk_size=chunk_size, rebuild_indexes=rebuild_indexes, progress_callback=progress_callback)
+        else:
+            counts = _relabel_main_dataset(session, dataset, chunk_size=chunk_size, rebuild_indexes=rebuild_indexes, progress_callback=progress_callback)
+        output = {
+            "dataset_id": dataset_id,
+            "source_id": source_id,
+            "status": "completed",
+            "dependency_hash": dependency_hash_value,
+            "version": OSM_LABEL_FEATURES_VERSION,
+            **counts,
+        }
+        _stamp_dataset_metadata(session, dataset, dependency_hash_value, output)
+        finish_pipeline_run(session, run, outputs=output)
+        session.commit()
+        return output
+    except FileNotFoundError as exc:
+        # Discard any half-applied state so the failure record can be written.
+        session.rollback()
+        output = {
+            "dataset_id": dataset_id,
+            "source_id": source_id,
+            "status": "missing_sidecar",
+            "dependency_hash": dependency_hash_value,
+            "version": OSM_LABEL_FEATURES_VERSION,
+            "processed": 0,
+            "changed": 0,
+            "index_rebuilds": 0,
+            "error": str(exc),
+        }
+        finish_pipeline_run(session, run, status="failed", outputs=output, error=str(exc))
+        session.commit()
+        return output
+    except Exception as exc:
+        # After a failed flush/statement the session refuses further work until
+        # it is rolled back; without this the bookkeeping below would raise a
+        # second error and hide the original one.
+        session.rollback()
+        finish_pipeline_run(session, run, status="failed", error=str(exc))
+        session.commit()
+        raise
+
+
+def _target_datasets(session: Session, dataset_id: int | None) -> list[Dataset]:
+    """Select the imported ``osm_geojson`` datasets to relabel.
+
+    Without ``dataset_id`` only active datasets are returned; with it, only the
+    dataset with that id. Results are ordered by source id, then dataset id.
+    """
+    selector = Dataset.is_active.is_(True) if dataset_id is None else Dataset.id == dataset_id
+    query = (
+        select(Dataset)
+        .where(Dataset.kind == "osm_geojson", Dataset.status == "imported")
+        .where(selector)
+        .order_by(Dataset.source_id, Dataset.id)
+    )
+    return session.scalars(query).all()
+
+
+def _dataset_label_is_current(session: Session, dataset: Dataset, dependency_hash_value: str) -> bool:
+    """Report whether the dataset was already labeled for this version and dependency hash.
+
+    Both conditions must hold: the dataset metadata carries a matching
+    ``label_features`` stamp, and a completed pipeline run exists for the same
+    stage, version, dependency hash, source and dataset.
+    """
+    stamp = dataset_metadata(dataset).get("label_features")
+    if not isinstance(stamp, dict):
+        return False
+    if stamp.get("version") != OSM_LABEL_FEATURES_VERSION:
+        return False
+    if stamp.get("dependency_hash") != dependency_hash_value:
+        return False
+    # Only consult the run history once the cheap metadata check has passed.
+    completed_run = latest_completed_run(
+        session,
+        stage=STAGE_LABEL_FEATURES,
+        version=OSM_LABEL_FEATURES_VERSION,
+        dependency_hash_value=dependency_hash_value,
+        source_id=dataset.source_id,
+        dataset_id=dataset.id,
+    )
+    return completed_run is not None
+
+
+def _relabel_sidecar_dataset(
+    dataset: Dataset,
+    *,
+    chunk_size: int,
+    rebuild_indexes: bool,
+    progress_callback: ProgressCallback | None,
+) -> dict[str, int | str]:
+    """Recompute ``route_scope`` for every feature stored in the dataset's sidecar.
+
+    Features are read in id order in batches of ``chunk_size`` (at least 1),
+    and only rows whose normalized scope changed are updated. Each batch is
+    committed and reported through ``progress_callback``.
+
+    When ``rebuild_indexes`` is true and the sidecar holds at least
+    ``SIDECAR_INDEX_REBUILD_THRESHOLD`` features, the route-scope indexes are
+    dropped before the loop and rebuilt afterwards, even if the loop fails.
+
+    Raises:
+        FileNotFoundError: the sidecar path is unknown or does not exist.
+
+    Returns:
+        ``storage``, ``processed``, ``changed`` and ``index_rebuilds`` counters.
+    """
+    path = sidecar_path(dataset)
+    if path is None or not path.exists():
+        raise FileNotFoundError(f"OSM sidecar does not exist: {path}")
+    # Validate the batch size before touching the sidecar so a bad value cannot
+    # fail after the indexes have already been dropped.
+    batch_limit = max(1, int(chunk_size))
+    with writable_sidecar_connection(dataset) as connection:
+        ensure_osm_sidecar_schema(connection)
+        total = int(connection.execute("SELECT COUNT(*) FROM osm_features").fetchone()[0] or 0)
+        should_rebuild_index = rebuild_indexes and total >= SIDECAR_INDEX_REBUILD_THRESHOLD
+        if should_rebuild_index:
+            drop_osm_sidecar_route_scope_indexes(connection)
+            connection.commit()
+        processed = 0
+        changed = 0
+        last_id = 0
+        succeeded = False
+        try:
+            while True:
+                # Keyset pagination: resume after the last id seen.
+                rows = connection.execute(
+                    """
+                    SELECT id, mode, ref, name, network, tags_json, route_scope
+                    FROM osm_features
+                    WHERE id > ?
+                    ORDER BY id
+                    LIMIT ?
+                    """,
+                    (last_id, batch_limit),
+                ).fetchall()
+                if not rows:
+                    break
+                updates: list[tuple[str | None, int]] = []
+                for row in rows:
+                    last_id = int(row["id"])
+                    new_scope = _classified_scope(row["mode"], row["ref"], row["name"], row["network"], row["tags_json"])
+                    if _normalize_scope(row["route_scope"]) != new_scope:
+                        updates.append((new_scope, last_id))
+                if updates:
+                    connection.executemany("UPDATE osm_features SET route_scope = ? WHERE id = ?", updates)
+                processed += len(rows)
+                changed += len(updates)
+                connection.commit()
+                _emit_progress(
+                    progress_callback,
+                    "osm_labeling_batch",
+                    f"Relabeled {processed}/{total} OSM sidecar features.",
+                    processed,
+                    total,
+                    {"dataset_id": dataset.id, "changed": changed, "storage": "sidecar"},
+                )
+            succeeded = True
+        finally:
+            index_rebuilds = 0
+            if should_rebuild_index:
+                # The indexes were dropped above; restore them even on failure.
+                rebuild_osm_sidecar_indexes(connection)
+                connection.commit()
+                index_rebuilds = 1
+                _record_sidecar_index_build(connection, dataset, path)
+            if succeeded:
+                # Stamp the sidecar as labeled only after a complete pass; a
+                # partial pass must not advertise the current classifier version.
+                _record_sidecar_label(connection, dataset, processed=processed, changed=changed)
+            connection.commit()
+        return {"storage": "sidecar", "processed": processed, "changed": changed, "index_rebuilds": index_rebuilds}
+
+
+def _relabel_main_dataset(
+    session: Session,
+    dataset: Dataset,
+    *,
+    chunk_size: int,
+    rebuild_indexes: bool,
+    progress_callback: ProgressCallback | None,
+) -> dict[str, int | str]:
+    """Recompute ``route_scope`` for the dataset's features stored in the main table.
+
+    Features are read in id order in batches of ``chunk_size`` (at least 1),
+    and only rows whose normalized scope changed are updated. Each batch is
+    committed and reported through ``progress_callback``.
+
+    When ``rebuild_indexes`` is true and the dataset has at least
+    ``MAIN_INDEX_REBUILD_THRESHOLD`` features, ``MAIN_ROUTE_SCOPE_INDEX`` is
+    dropped before the loop and recreated afterwards, even if the loop fails.
+
+    Returns:
+        ``storage``, ``processed``, ``changed`` and ``index_rebuilds`` counters.
+    """
+    # Read once: a rollback in the failure path expires ORM attributes.
+    dataset_id = dataset.id
+    # Validate the batch size before the index is dropped.
+    batch_limit = max(1, int(chunk_size))
+    total = int(session.scalar(select(func.count()).select_from(OsmFeature).where(OsmFeature.dataset_id == dataset_id)) or 0)
+    should_rebuild_index = rebuild_indexes and total >= MAIN_INDEX_REBUILD_THRESHOLD
+    index_rebuilds = 0
+    if should_rebuild_index:
+        session.execute(text(f"DROP INDEX IF EXISTS {MAIN_ROUTE_SCOPE_INDEX}"))
+        session.commit()
+    processed = 0
+    changed = 0
+    last_id = 0
+    try:
+        while True:
+            # Keyset pagination: resume after the last id seen.
+            rows = session.scalars(
+                select(OsmFeature)
+                .where(OsmFeature.dataset_id == dataset_id, OsmFeature.id > last_id)
+                .order_by(OsmFeature.id)
+                .limit(batch_limit)
+            ).all()
+            if not rows:
+                break
+            updates: list[dict[str, object]] = []
+            for feature in rows:
+                last_id = int(feature.id)
+                new_scope = _classified_scope(feature.mode, feature.ref, feature.name, feature.network, feature.tags_json)
+                if _normalize_scope(feature.route_scope) != new_scope:
+                    updates.append({"id": feature.id, "route_scope": new_scope})
+            if updates:
+                session.bulk_update_mappings(OsmFeature, updates)
+            processed += len(rows)
+            changed += len(updates)
+            session.commit()
+            _emit_progress(
+                progress_callback,
+                "osm_labeling_batch",
+                f"Relabeled {processed}/{total} main-table OSM features.",
+                processed,
+                total,
+                {"dataset_id": dataset_id, "changed": changed, "storage": "main"},
+            )
+    except Exception:
+        # A failed statement leaves the session unusable until rolled back;
+        # the index rebuild in the finally block needs a working session.
+        session.rollback()
+        raise
+    finally:
+        if should_rebuild_index:
+            # Use the same constant as the DROP above so the two cannot drift.
+            session.execute(
+                text(
+                    f"CREATE INDEX IF NOT EXISTS {MAIN_ROUTE_SCOPE_INDEX} "
+                    "ON osm_features (dataset_id, kind, mode, route_scope, min_lon, max_lon, min_lat, max_lat)"
+                )
+            )
+            session.commit()
+            index_rebuilds = 1
+            _record_main_index_build(session, dataset)
+    return {"storage": "main", "processed": processed, "changed": changed, "index_rebuilds": index_rebuilds}
+
+
+def _classified_scope(mode: object, ref: object, name: object, network: object, tags_json: object) -> str | None:
+    """Classify a feature's route scope and return it in normalized form.
+
+    Each argument is passed to ``infer_osm_route_scope_from_tags`` as ``str``
+    (``None`` stays ``None``); the classifier result goes through
+    ``_normalize_scope``.
+    """
+    classifier_args = [None if item is None else str(item) for item in (mode, ref, name, network, tags_json)]
+    return _normalize_scope(infer_osm_route_scope_from_tags(*classifier_args))
+
+
+def _normalize_scope(value: object) -> str | None:
+    """Return ``value`` as a stripped string, or ``None`` when it is falsy or blank."""
+    if not value:
+        return None
+    stripped = str(value).strip()
+    return stripped if stripped else None
+
+
+def _label_dependency(dataset: Dataset) -> dict[str, object]:
+    """Describe the inputs that determine whether a dataset needs relabeling.
+
+    The result covers the dataset identity and checksum, its ``osm_storage``
+    metadata entry, the sidecar path with an exists/missing marker (``None``
+    when there is no sidecar path) and the classifier version.
+    """
+    metadata = dataset_metadata(dataset)
+    storage = metadata.get("osm_storage") if isinstance(metadata, dict) else None
+
+    sidecar_fingerprint: dict[str, object] | None = None
+    raw_path = sidecar_path(dataset)
+    if raw_path is not None:
+        resolved = Path(raw_path)
+        # Same shape as before: {"path": ..., "exists": True} or {"path": ..., "missing": True}.
+        state_key = "exists" if resolved.exists() else "missing"
+        sidecar_fingerprint = {"path": str(resolved), state_key: True}
+
+    return {
+        "dataset_id": dataset.id,
+        "source_id": dataset.source_id,
+        "kind": dataset.kind,
+        "dataset_sha256": dataset.sha256,
+        "storage": storage,
+        "sidecar": sidecar_fingerprint,
+        "classifier_version": OSM_LABEL_FEATURES_VERSION,
+    }
+
+
+def _stamp_dataset_metadata(session: Session, dataset: Dataset, dependency_hash_value: str, output: dict[str, object]) -> None:
+    """Write a ``label_features`` stamp into the dataset's metadata JSON and flush.
+
+    The dataset is re-fetched through the session by id; nothing happens when
+    it no longer exists.
+    """
+    current = session.get(Dataset, dataset.id)
+    if current is not None:
+        metadata = dataset_metadata(current)
+        stamp = {
+            "stage": STAGE_LABEL_FEATURES,
+            "version": OSM_LABEL_FEATURES_VERSION,
+            "dependency_hash": dependency_hash_value,
+            "labeled_at": datetime.now(timezone.utc).isoformat(),
+            "processed": output.get("processed", 0),
+            "changed": output.get("changed", 0),
+            "storage": output.get("storage"),
+        }
+        metadata["label_features"] = stamp
+        current.metadata_json = json.dumps(metadata, indent=2)
+        session.flush()
+
+
+def _record_sidecar_label(connection: sqlite3.Connection, dataset: Dataset, *, processed: int, changed: int) -> None:
+    """Upsert the ``label_features`` entry in the sidecar's ``pipeline_metadata`` table.
+
+    The table is created on demand. The caller is responsible for committing.
+    """
+    connection.execute("CREATE TABLE IF NOT EXISTS pipeline_metadata (key TEXT PRIMARY KEY, value TEXT NOT NULL)")
+    record = {
+        "stage": STAGE_LABEL_FEATURES,
+        "version": OSM_LABEL_FEATURES_VERSION,
+        "dataset_id": dataset.id,
+        "processed": processed,
+        "changed": changed,
+        "updated_at": datetime.now(timezone.utc).isoformat(),
+    }
+    # Compact, key-sorted JSON keeps the stored value stable across runs.
+    encoded = json.dumps(record, sort_keys=True, separators=(",", ":"))
+    connection.execute(
+        "INSERT OR REPLACE INTO pipeline_metadata (key, value) VALUES (?, ?)",
+        ("label_features", encoded),
+    )
+
+
+def _record_sidecar_index_build(connection: sqlite3.Connection, dataset: Dataset, path: Path) -> None:
+    """Upsert the ``build_indexes:route_scope`` entry in the sidecar's ``pipeline_metadata`` table.
+
+    The table is created on demand. The caller is responsible for committing.
+    """
+    connection.execute("CREATE TABLE IF NOT EXISTS pipeline_metadata (key TEXT PRIMARY KEY, value TEXT NOT NULL)")
+    record = {
+        "stage": STAGE_BUILD_INDEXES,
+        "version": "osm_sidecar_indexes_v1",
+        "dataset_id": dataset.id,
+        "path": str(path),
+        "updated_at": datetime.now(timezone.utc).isoformat(),
+    }
+    # Compact, key-sorted JSON keeps the stored value stable across runs.
+    encoded = json.dumps(record, sort_keys=True, separators=(",", ":"))
+    connection.execute(
+        "INSERT OR REPLACE INTO pipeline_metadata (key, value) VALUES (?, ?)",
+        ("build_indexes:route_scope", encoded),
+    )
+
+
+def _record_main_index_build(session: Session, dataset: Dataset) -> None:
+    """Record a pipeline run for the main-table route-scope index rebuild and commit it."""
+    index_version = "osm_main_indexes_v1"
+    inputs = {
+        "dataset_id": dataset.id,
+        "index": MAIN_ROUTE_SCOPE_INDEX,
+        "version": index_version,
+    }
+    # The run is opened and closed back to back: the rebuild has already happened.
+    index_run = start_pipeline_run(
+        session,
+        stage=STAGE_BUILD_INDEXES,
+        version=index_version,
+        dependency_hash_value=dependency_hash(inputs),
+        source_id=dataset.source_id,
+        dataset_id=dataset.id,
+        inputs=inputs,
+    )
+    finish_pipeline_run(session, index_run, outputs={"index": MAIN_ROUTE_SCOPE_INDEX})
+    session.commit()
+
+
+def _emit_progress(
+    callback: ProgressCallback | None,
+    event_type: str,
+    message: str,
+    current: int | None,
+    total: int | None,
+    metadata: dict[str, object] | None,
+) -> None:
+    """Forward a progress event to ``callback``; do nothing when no callback is set."""
+    if callback is None:
+        return
+    callback(event_type, message, current, total, metadata)
diff --git a/app/pipeline/osm_pbf.py b/app/pipeline/osm_pbf.py
new file mode 100644
index 0000000..0b67b35
--- /dev/null
+++ b/app/pipeline/osm_pbf.py
@@ -0,0 +1,1581 @@
+from __future__ import annotations
+
+import json
+import shutil
+import subprocess
+from dataclasses import dataclass
+from datetime import datetime, timezone
+from pathlib import Path
+from typing import Any
+
+import osmium
+from sqlalchemy import select
+from sqlalchemy.orm import Session
+
+from app.config import settings
+from app.db import SessionLocal
+from app.db_lock import database_write_lock
+from app.models import Dataset, OsmDiffState, Source
+from app.osm_storage import OSM_STORAGE_MAIN, OSM_STORAGE_SIDECAR_FEATURES, effective_osm_feature_storage
+from app.performance import measure_pipeline_phase
+from app.pipeline.download import materialize_source
+from app.pipeline.osm_geojson import import_osm_geojson, prepare_osm_geojson_storage
+from app.pipeline.osm_replication import ReplicationState, apply_osm_changes, download_diff, fetch_replication_state
+from app.pipeline.state import (
+ STAGE_ACQUIRE_RAW,
+ STAGE_BUILD_INDEXES,
+ STAGE_EXTRACT_GEOMETRY,
+ STAGE_FILTER_TRANSPORT,
+ dependency_hash,
+ finish_pipeline_run,
+ start_pipeline_run,
+)
+from app.pipeline.utils import sha256_file
+
+# Mode strings this module treats as transport route modes.
+ROUTE_MODES = {
+    "train",
+    "railway",
+    "light_rail",
+    "subway",
+    "tram",
+    "bus",
+    "trolleybus",
+    "coach",
+    "ferry",
+    "monorail",
+    "funicular",
+    "aerialway",
+}
+
+# Maps a rail tag value to the mode it is reported as ("rail" becomes "train").
+# NOTE(review): presumably keyed by the OSM `railway=*` value; the lookup site is
+# outside this section, confirm there.
+RAILWAY_MODE_BY_TAG = {
+    "rail": "train",
+    "light_rail": "light_rail",
+    "subway": "subway",
+    "tram": "tram",
+    "monorail": "monorail",
+    "funicular": "funicular",
+}
+
+# Version tags written into dataset metadata / pipeline runs. Cached artefacts
+# are compared against these (e.g. the transport filter cache checks "filter").
+EXTRACTOR_VERSION = "osmium_transport_geojson_v2_ordered_relation_members"
+TRANSPORT_FILTER_VERSION = "osmium_transport_filter_v1"
+RAW_ACQUIRE_VERSION = "osm_raw_acquire_v1"
+OSM_SIDECAR_INDEX_VERSION = "osm_sidecar_indexes_v1"
+
+
+@dataclass(frozen=True)
+class _SourceRef:
+    """Immutable copy of the ``Source`` columns the staged import needs.
+
+    Built by ``_load_source_ref`` so the values stay usable after the session
+    that loaded the row has been closed.
+    """
+
+    id: int
+    name: str
+    kind: str
+    url: str
+    country: str | None = None
+    license: str | None = None
+    notes: str | None = None
+
+
+@dataclass(frozen=True)
+class _DatasetRef:
+    """Immutable reference to a dataset row, passed between staged import steps.
+
+    NOTE(review): instances come from ``_dataset_ref``, which is defined outside
+    this section; ``metadata`` is presumably the decoded ``metadata_json``.
+    """
+
+    id: int
+    source_id: int
+    kind: str
+    local_path: str
+    sha256: str
+    status: str
+    metadata: dict[str, Any]
+
+
+@dataclass(frozen=True)
+class _PreparedRawFile:
+    """A raw OSM file that is ready to be registered as a raw dataset.
+
+    Produced either from a full snapshot or by applying replication diffs to
+    an existing raw file; consumed by ``_reserve_raw_dataset``.
+    """
+
+    path: Path
+    sha256: str
+    # Metadata stored on (or merged into) the raw dataset row.
+    metadata: dict[str, Any]
+    # Replication state to record for the file; None when none was obtained.
+    replication_state: ReplicationState | None = None
+    # Extra metadata passed along when the replication state is recorded.
+    diff_state_metadata: dict[str, Any] | None = None
+
+
+@dataclass(frozen=True)
+class _PreparedTransportFile:
+    """Result of ``_prepare_transport_file_staged``: the filtered transport file."""
+
+    path: Path
+    sha256: str
+    metadata: dict[str, Any]
+    # Reported in the "osm_transport_dataset_reserved" progress event.
+    reused: bool
+
+
+@dataclass(frozen=True)
+class _PreparedExtract:
+    """Result of ``_extract_transport_geojson_staged``: the extracted GeoJSON file."""
+
+    path: Path
+    sha256: str
+    summary: dict[str, Any]
+    # Reported as "extract_reused" in the "osm_derived_dataset_reserved" progress event.
+    reused: bool
+
+
+def run_osm_pbf_source(session: Session, source: Source, progress_callback=None) -> Dataset:
+    """Import an OSM PBF source inside the caller's session and return the derived dataset.
+
+    Steps: prepare the raw dataset, optionally pre-filter it to a transport
+    PBF, reuse an existing derived dataset when one is found, otherwise extract
+    transport GeoJSON, import it, and link the derived dataset back to its
+    inputs through metadata. The session is flushed but not committed here.
+    """
+    raw = _prepare_raw_osm_dataset(session, source, progress_callback=progress_callback)
+    extract_input = raw
+    extract_input_path = Path(raw.local_path)
+    if _should_prefilter(extract_input_path):
+        extract_input = _prepare_transport_pbf(session, source, raw, extract_input_path)
+        extract_input_path = Path(extract_input.local_path)
+
+    reusable = _find_existing_derived(session, source, extract_input)
+    if reusable is not None:
+        return reusable
+
+    geojson_path = (
+        settings.data_dir
+        / "derived"
+        / f"source_{source.id}"
+        / f"extract_dataset_{extract_input.id}"
+        / "transport.geojson"
+    )
+    summary = extract_osm_transport_geojson(extract_input_path, geojson_path)
+
+    extract_input.status = "extracted"
+    _update_dataset_metadata(extract_input, extractor=EXTRACTOR_VERSION, extract_summary=summary)
+    if extract_input.id != raw.id:
+        # A separate filtered dataset was used: mark the raw one accordingly.
+        raw.status = "filtered"
+        _update_dataset_metadata(raw, filtered_dataset_id=extract_input.id)
+    session.flush()
+
+    derived = import_osm_geojson(session=session, source=source, path=geojson_path)
+    lineage = {
+        "stage": "derived_osm_transport_geojson",
+        "derived_from_dataset_id": extract_input.id,
+        "raw_dataset_id": raw.id,
+        "filtered_dataset_id": extract_input.id if extract_input.id != raw.id else None,
+        "extractor": EXTRACTOR_VERSION,
+        "extract_summary": summary,
+    }
+    merged_metadata = {**json.loads(derived.metadata_json or "{}"), **lineage}
+    derived.metadata_json = json.dumps(merged_metadata, indent=2)
+    session.flush()
+    return derived
+
+
+def run_osm_pbf_source_staged(source_id: int, progress_callback=None) -> Dataset:
+    """Import a large OSM PBF source while holding the DB write lock only briefly.
+
+    The heavy file work (materializing the raw source, the optional osmium
+    transport filter, GeoJSON extraction and sidecar creation) is deterministic,
+    resumable from cached files, and runs outside the global SQLite write lock.
+    Dataset rows are reserved and activated in short transactions in between.
+    """
+    total_steps = 7
+
+    def report(event: str, message: str, step: int, details: dict[str, Any]) -> None:
+        # Every event of this function is reported against the same step total.
+        _emit_progress(progress_callback, event, message, step, total_steps, details)
+
+    src = _load_source_ref(source_id)
+    _mark_source_running(src.id)
+    report("osm_staged_import_started", f"Preparing staged OSM import for {src.name}.", 0, {"source_id": src.id})
+
+    raw_file = _prepare_raw_file_staged(src, progress_callback=progress_callback)
+    raw_ref = _reserve_raw_dataset(src, raw_file)
+    report(
+        "osm_raw_dataset_reserved",
+        f"Reserved raw OSM dataset #{raw_ref.id}.",
+        2,
+        {"dataset_id": raw_ref.id, "path": raw_ref.local_path, "sha256": raw_ref.sha256},
+    )
+
+    # Extraction reads the raw file unless a pre-filtered transport file is produced.
+    filtered_ref: _DatasetRef | None = None
+    extract_input = raw_ref
+    extract_input_path = Path(raw_ref.local_path)
+    if _should_prefilter(extract_input_path):
+        transport_file = _prepare_transport_file_staged(src, raw_ref, extract_input_path, progress_callback=progress_callback)
+        filtered_ref = _reserve_transport_dataset(src, raw_ref, transport_file)
+        extract_input = filtered_ref
+        extract_input_path = Path(filtered_ref.local_path)
+        report(
+            "osm_transport_dataset_reserved",
+            f"Reserved filtered OSM transport dataset #{filtered_ref.id}.",
+            3,
+            {"dataset_id": filtered_ref.id, "path": filtered_ref.local_path, "sha256": filtered_ref.sha256, "reused": transport_file.reused},
+        )
+
+    already_derived = _existing_active_derived_ref(src.id, extract_input.id)
+    if already_derived is not None:
+        _activate_existing_derived(src.id, already_derived.id)
+        report("osm_staged_import_reused", f"Reused active OSM transport dataset #{already_derived.id}.", 7, {"dataset_id": already_derived.id})
+        return _load_dataset(already_derived.id)
+
+    extract = _extract_transport_geojson_staged(src, extract_input, extract_input_path, progress_callback=progress_callback)
+    derived_ref = _reserve_derived_dataset(
+        source_ref=src,
+        raw_dataset=raw_ref,
+        input_dataset=extract_input,
+        filtered_dataset=filtered_ref,
+        extract=extract,
+    )
+    report(
+        "osm_derived_dataset_reserved",
+        f"Reserved derived OSM dataset #{derived_ref.id}.",
+        5,
+        {"dataset_id": derived_ref.id, "path": derived_ref.local_path, "sha256": derived_ref.sha256, "extract_reused": extract.reused},
+    )
+
+    sidecar_metadata = _prepare_derived_storage_staged(derived_ref, extract, progress_callback=progress_callback)
+    activated_id = _activate_staged_osm_import(
+        source_ref=src,
+        raw_dataset=raw_ref,
+        filtered_dataset=filtered_ref,
+        input_dataset=extract_input,
+        derived_dataset=derived_ref,
+        extract=extract,
+        sidecar_metadata=sidecar_metadata,
+    )
+    report("osm_staged_import_completed", f"Activated OSM dataset #{activated_id}.", 7, {"dataset_id": activated_id})
+    return _load_dataset(activated_id)
+
+
+def _load_source_ref(source_id: int) -> _SourceRef:
+    """Load a source in a short-lived session and return it as a detached ``_SourceRef``.
+
+    Raises:
+        ValueError: the source does not exist or is not of kind ``osm_pbf``.
+    """
+    copied_fields = ("id", "name", "kind", "url", "country", "license", "notes")
+    with SessionLocal() as session:
+        row = session.get(Source, source_id)
+        if row is None:
+            raise ValueError(f"source not found: {source_id}")
+        if row.kind != "osm_pbf":
+            raise ValueError(f"staged OSM import requires source kind osm_pbf, got {row.kind}")
+        return _SourceRef(**{field: getattr(row, field) for field in copied_fields})
+
+
+def _load_dataset(dataset_id: int) -> Dataset:
+    """Fetch a dataset by id in a short-lived session.
+
+    Raises:
+        ValueError: no dataset with that id exists.
+    """
+    with SessionLocal() as session:
+        found = session.get(Dataset, dataset_id)
+        if found is not None:
+            return found
+        raise ValueError(f"dataset not found after staged import: {dataset_id}")
+
+
+def _mark_source_running(source_id: int) -> None:
+    """Set the source to ``running``, clear its last error and stamp the run time.
+
+    Runs under the database write lock (30 second timeout) and commits.
+
+    Raises:
+        ValueError: the source does not exist.
+    """
+    lock_name = f"osm_staged_import:{source_id}:start"
+    with database_write_lock(lock_name, timeout=30), SessionLocal() as session:
+        row = session.get(Source, source_id)
+        if row is None:
+            raise ValueError(f"source not found: {source_id}")
+        row.status = "running"
+        row.last_error = None
+        row.last_run_at = datetime.now(timezone.utc)
+        session.commit()
+
+
+def _prepare_raw_file_staged(source: _SourceRef, progress_callback=None) -> _PreparedRawFile:
+    """Produce the raw OSM file for a staged import.
+
+    The diff-based path is tried first; when it yields nothing, the full
+    snapshot is materialized, hashed, and annotated with the current
+    replication state if that can be fetched.
+    """
+    from_diffs = _try_prepare_raw_file_from_diffs_staged(source, progress_callback=progress_callback)
+    if from_diffs is not None:
+        return from_diffs
+
+    _emit_progress(progress_callback, "osm_full_snapshot_started", f"Downloading/copying full OSM snapshot for {source.name}.", 1, 7, {"source_id": source.id})
+    with measure_pipeline_phase("osm_full_snapshot", source_id=source.id, metadata={"url": source.url}) as metric:
+        snapshot_path = materialize_source(source)  # type: ignore[arg-type]
+        snapshot_hash = sha256_file(snapshot_path)
+        snapshot_bytes = snapshot_path.stat().st_size if snapshot_path.exists() else None
+        metric.update({"path": str(snapshot_path), "sha256": snapshot_hash, "bytes": snapshot_bytes})
+
+    metadata = {
+        "stage": "raw_osm",
+        "raw_format": _raw_format(snapshot_path),
+        "source_url": source.url,
+        "import_mode": "staged_short_lock",
+    }
+    replication_state = _fetch_current_replication_state_for_snapshot(source, progress_callback=progress_callback)
+    diff_state_metadata = None
+    if replication_state is not None:
+        # Remember where the snapshot sits in the replication stream.
+        metadata["replication_state"] = {
+            "updates_url": _source_updates_url(source),  # type: ignore[arg-type]
+            "sequence_number": replication_state.sequence_number,
+            "timestamp": replication_state.timestamp,
+        }
+        diff_state_metadata = {"source": "full_snapshot"}
+
+    _emit_progress(progress_callback, "osm_full_snapshot_completed", "Prepared raw OSM snapshot file.", 1, 7, {"path": str(snapshot_path), "sha256": snapshot_hash})
+    return _PreparedRawFile(
+        path=snapshot_path,
+        sha256=snapshot_hash,
+        metadata=metadata,
+        replication_state=replication_state,
+        diff_state_metadata=diff_state_metadata,
+    )
+
+
+def _try_prepare_raw_file_from_diffs_staged(source: _SourceRef, progress_callback=None) -> _PreparedRawFile | None:
+    """Try to bring the local raw OSM file up to date through replication diffs.
+
+    Returns ``None`` whenever the caller should use a full snapshot instead:
+    no updates URL, no local replication state or base file, unreadable remote
+    state, a sequence gap above ``settings.osm_diff_max_sequence_gap``, a
+    missing host tool, or a failure while applying the diffs. When the local
+    file is already at the remote sequence, the existing raw file is returned
+    unchanged.
+    """
+    updates_url = _source_updates_url(source)  # type: ignore[arg-type]
+    if not updates_url:
+        return None
+
+    def fall_back(message: str, details: dict[str, Any], current: int | None = None, total: int | None = None) -> None:
+        # Report why the diff path was abandoned; the caller then returns None.
+        _emit_progress(progress_callback, "osm_diff_fallback", message, current, total, details)
+        return None
+
+    with SessionLocal() as session:
+        diff_state = _latest_diff_state(session, source.id)
+        if diff_state is None or diff_state.raw_dataset_id is None:
+            return fall_back("No local OSM replication state yet; using full snapshot.", {"updates_url": updates_url})
+        base_row = session.get(Dataset, diff_state.raw_dataset_id)
+        if base_row is None or not Path(base_row.local_path).exists():
+            return fall_back("Local raw OSM base is missing; using full snapshot.", {"updates_url": updates_url})
+        base_ref = _dataset_ref(base_row)
+        local_sequence = diff_state.sequence_number
+
+    try:
+        remote_state = fetch_replication_state(updates_url, timeout=settings.osm_diff_state_timeout_seconds)
+    except Exception as exc:  # noqa: BLE001 - correctness fallback
+        return fall_back(f"Could not read OSM replication state; using full snapshot: {exc}", {"updates_url": updates_url})
+
+    remote_sequence = remote_state.sequence_number
+    if remote_sequence <= local_sequence:
+        _emit_progress(
+            progress_callback,
+            "osm_diff_up_to_date",
+            "Local raw OSM extract is already at the latest known replication sequence.",
+            remote_sequence,
+            remote_sequence,
+            {"updates_url": updates_url, "sequence_number": remote_sequence},
+        )
+        return _PreparedRawFile(
+            path=Path(base_ref.local_path),
+            sha256=base_ref.sha256,
+            metadata=base_ref.metadata,
+            replication_state=remote_state,
+            diff_state_metadata={"source": "existing_raw_dataset", "raw_dataset_id": base_ref.id},
+        )
+
+    gap = remote_sequence - local_sequence
+    if gap > settings.osm_diff_max_sequence_gap:
+        return fall_back(
+            "OSM replication gap is too large; using full snapshot.",
+            {"gap": gap, "max_gap": settings.osm_diff_max_sequence_gap, "updates_url": updates_url},
+            local_sequence,
+            remote_sequence,
+        )
+
+    host_tool = _host_tool_path()
+    if not host_tool.exists():
+        return fall_back("host_tool.sh is missing; using full snapshot.", {"host_tool": str(host_tool)})
+
+    try:
+        return _apply_diff_range_files_staged(
+            source=source,
+            base_dataset=base_ref,
+            updates_url=updates_url,
+            local_sequence=local_sequence,
+            remote_state=remote_state,
+            host_tool=host_tool,
+            progress_callback=progress_callback,
+        )
+    except Exception as exc:  # noqa: BLE001 - fall back to full snapshot rather than risk a bad base
+        return fall_back(f"OSM diff application failed; using full snapshot: {exc}", {"updates_url": updates_url})
+
+
+def _apply_diff_range_files_staged(
+    *,
+    source: _SourceRef,
+    base_dataset: _DatasetRef,
+    updates_url: str,
+    local_sequence: int,
+    remote_state: ReplicationState,
+    host_tool: Path,
+    progress_callback=None,
+) -> _PreparedRawFile:
+    """Download and apply every diff after ``local_sequence`` up to the remote sequence.
+
+    Diffs are applied in batches of ``settings.osm_diff_apply_batch_size`` (at
+    least 1). Each batch is applied to the output of the previous one, starting
+    from the base dataset's file. Returns the final raw file with its hash and
+    metadata describing the replication state and the applied sequences.
+    """
+    source_root = settings.data_dir / "sources" / f"source_{source.id}"
+    update_root = source_root / "updates"
+    work_root = source_root / "diff_work"
+    work_root.mkdir(parents=True, exist_ok=True)
+
+    latest_path = Path(base_dataset.local_path)
+    batch_size = max(1, int(settings.osm_diff_apply_batch_size))
+    first_sequence = local_sequence + 1
+    target_sequence = remote_state.sequence_number
+    pending = list(range(first_sequence, target_sequence + 1))
+    applied: list[int] = []
+
+    _emit_progress(
+        progress_callback,
+        "osm_diff_started",
+        f"Applying {len(pending)} OSM replication diffs.",
+        local_sequence,
+        target_sequence,
+        {"updates_url": updates_url, "from_sequence": first_sequence, "to_sequence": target_sequence},
+    )
+    with measure_pipeline_phase("osm_diff_apply", source_id=source.id, metadata={"from_sequence": first_sequence, "to_sequence": target_sequence}) as metric:
+        for offset in range(0, len(pending), batch_size):
+            batch = pending[offset : offset + batch_size]
+            first_in_batch, last_in_batch = batch[0], batch[-1]
+
+            downloaded = []
+            for sequence in batch:
+                diff_file = download_diff(updates_url, sequence, update_root)
+                downloaded.append(diff_file)
+                _emit_progress(
+                    progress_callback,
+                    "osm_diff_downloaded",
+                    f"Downloaded OSM diff sequence {sequence}.",
+                    sequence,
+                    target_sequence,
+                    {"path": str(diff_file), "sequence_number": sequence},
+                )
+
+            scratch_output = work_root / f"source_{source.id}_{first_in_batch}_{last_in_batch}.tmp.osm.pbf"
+            completed = apply_osm_changes(latest_path, downloaded, scratch_output, host_tool)
+            # The stored result becomes the base for the next batch.
+            latest_path = _store_updated_raw_pbf(source, scratch_output)  # type: ignore[arg-type]
+            applied.extend(batch)
+            _emit_progress(
+                progress_callback,
+                "osm_diff_applied",
+                f"Applied OSM diff sequences {first_in_batch}-{last_in_batch}.",
+                last_in_batch,
+                target_sequence,
+                {
+                    "output_path": str(latest_path),
+                    "stdout": completed.stdout.strip(),
+                    "stderr": completed.stderr.strip(),
+                    "batch_start": first_in_batch,
+                    "batch_end": last_in_batch,
+                },
+            )
+
+        final_hash = sha256_file(latest_path)
+        final_bytes = latest_path.stat().st_size if latest_path.exists() else None
+        metric.update({"applied_sequences": applied, "path": str(latest_path), "sha256": final_hash, "bytes": final_bytes})
+
+    metadata = {
+        "stage": "raw_osm",
+        "raw_format": _raw_format(latest_path),
+        "source_url": source.url,
+        "import_mode": "staged_short_lock",
+        "replication_state": {
+            "updates_url": updates_url,
+            "sequence_number": target_sequence,
+            "timestamp": remote_state.timestamp,
+        },
+        "diff_update": {
+            "base_dataset_id": base_dataset.id,
+            "base_sequence_number": local_sequence,
+            "applied_sequences": applied,
+        },
+    }
+    return _PreparedRawFile(
+        path=latest_path,
+        sha256=final_hash,
+        metadata=metadata,
+        replication_state=remote_state,
+        diff_state_metadata={"base_dataset_id": base_dataset.id, "applied_sequences": applied},
+    )
+
+
+def _fetch_current_replication_state_for_snapshot(source: _SourceRef, progress_callback=None) -> ReplicationState | None:
+    """Best-effort fetch of the current replication state for a freshly taken snapshot.
+
+    Returns ``None`` when the source has no updates URL or the state cannot be
+    fetched; a fetch failure is reported as a progress event, not raised.
+    """
+    updates_url = _source_updates_url(source)  # type: ignore[arg-type]
+    if not updates_url:
+        return None
+    state: ReplicationState | None = None
+    try:
+        state = fetch_replication_state(updates_url, timeout=settings.osm_diff_state_timeout_seconds)
+    except Exception as exc:  # noqa: BLE001 - full snapshot is still usable without diff state
+        _emit_progress(progress_callback, "osm_diff_state_unavailable", f"Could not record OSM replication state: {exc}", None, None, {"updates_url": updates_url})
+    return state
+
+
+def _reserve_raw_dataset(source_ref: _SourceRef, prepared: _PreparedRawFile) -> _DatasetRef:
+    """Create or refresh the raw dataset row for a prepared raw file.
+
+    Runs under the database write lock (60 second timeout) in its own session
+    and commits. An existing raw dataset with the same checksum is reused and
+    updated; otherwise a new inactive ``osm_pbf_raw`` dataset is inserted. The
+    replication state (when present) and an ``acquire_raw`` pipeline stage are
+    recorded, and the source is marked ``running``.
+
+    Raises:
+        ValueError: the source does not exist.
+
+    Returns:
+        A detached ``_DatasetRef`` for the raw dataset.
+    """
+    with database_write_lock(f"osm_staged_import:{source_ref.id}:reserve_raw", timeout=60):
+        with SessionLocal() as session:
+            source = session.get(Source, source_ref.id)
+            if source is None:
+                raise ValueError(f"source not found: {source_ref.id}")
+            dataset = _find_raw_dataset(session, source, prepared.sha256)
+            if dataset is None:
+                dataset = Dataset(
+                    source_id=source.id,
+                    kind="osm_pbf_raw",
+                    local_path=str(prepared.path),
+                    sha256=prepared.sha256,
+                    is_active=False,
+                    status="committed",
+                    metadata_json=json.dumps(prepared.metadata, indent=2),
+                )
+                session.add(dataset)
+                # Flush so the new row has an id for the records below.
+                session.flush()
+            else:
+                dataset.local_path = str(prepared.path)
+                dataset.status = "committed"
+                # Prepared values win over whatever was stored before.
+                dataset.metadata_json = json.dumps({**_metadata(dataset), **prepared.metadata}, indent=2)
+            if prepared.replication_state is not None:
+                # The metadata may come from a stored dataset row, so the
+                # "replication_state" entry can be absent, null, or not a dict;
+                # treat anything but a dict as "no recorded URL".
+                recorded_state = prepared.metadata.get("replication_state")
+                recorded_url = recorded_state.get("updates_url") if isinstance(recorded_state, dict) else None
+                _record_diff_state(
+                    session,
+                    source=source,
+                    raw_dataset=dataset,
+                    updates_url=str(recorded_url or _source_updates_url(source) or ""),
+                    state=prepared.replication_state,
+                    metadata=prepared.diff_state_metadata,
+                )
+            _record_pipeline_stage(
+                session,
+                stage=STAGE_ACQUIRE_RAW,
+                version=RAW_ACQUIRE_VERSION,
+                source_id=source.id,
+                dataset=dataset,
+                inputs={
+                    "source_url": source.url,
+                    "source_kind": source.kind,
+                    "remote": prepared.metadata.get("replication_state") or prepared.metadata.get("source_url"),
+                },
+                outputs={
+                    "path": str(prepared.path),
+                    "sha256": prepared.sha256,
+                    "raw_format": prepared.metadata.get("raw_format"),
+                    "diff_update": prepared.metadata.get("diff_update"),
+                },
+            )
+            source.status = "running"
+            source.last_error = None
+            session.commit()
+            return _dataset_ref(dataset)
+
+
+def _prepare_transport_file_staged(source: _SourceRef, raw_dataset: _DatasetRef, raw_path: Path, progress_callback=None) -> _PreparedTransportFile:
+    """Produce (or reuse) the transport-only extract of a raw OSM file.
+
+    The extract is reused when the output file exists and its
+    ``.metadata.json`` companion records the same input SHA-256 and the
+    current ``TRANSPORT_FILTER_VERSION``. Otherwise the prefilter script is
+    run as ``<script> <raw_path> <output_path>`` and a new metadata file is
+    written next to the output.
+
+    Raises:
+        FileNotFoundError: if the prefilter script does not exist.
+        RuntimeError: if the prefilter script exits with a non-zero status.
+    """
+    output_path = _transport_filter_path_for_raw_id(source.id, raw_dataset.id, raw_path)
+    output_path.parent.mkdir(parents=True, exist_ok=True)
+    metadata_path = output_path.with_suffix(output_path.suffix + ".metadata.json")
+    existing_metadata = _read_json_file(metadata_path)
+
+    can_reuse = (
+        output_path.exists()
+        and existing_metadata.get("input_sha256") == raw_dataset.sha256
+        and existing_metadata.get("filter") == TRANSPORT_FILTER_VERSION
+    )
+    if can_reuse:
+        filtered_hash = sha256_file(output_path)
+        _emit_progress(progress_callback, "osm_transport_filter_reused", "Reusing existing filtered OSM transport extract.", 3, 7, {"path": str(output_path), "sha256": filtered_hash})
+        return _PreparedTransportFile(path=output_path, sha256=filtered_hash, metadata=existing_metadata, reused=True)
+
+    script_path = _prefilter_script_path()
+    if not script_path.exists():
+        raise FileNotFoundError(f"OSM transport filter script not found: {script_path}")
+
+    _emit_progress(progress_callback, "osm_transport_filter_started", "Filtering OSM PBF to public-transport objects.", 2, 7, {"input_path": str(raw_path), "output_path": str(output_path)})
+    phase_metadata = {"input_path": str(raw_path), "output_path": str(output_path)}
+    with measure_pipeline_phase("osm_transport_filter", source_id=source.id, dataset_id=raw_dataset.id, metadata=phase_metadata) as metric:
+        command = [str(script_path), str(raw_path), str(output_path)]
+        try:
+            completed = subprocess.run(command, check=True, capture_output=True, text=True)
+        except subprocess.CalledProcessError as exc:
+            # Prefer stderr, then stdout, then the bare exit code as the failure detail.
+            stderr = (exc.stderr or "").strip()
+            stdout = (exc.stdout or "").strip()
+            details = stderr or stdout or f"exit code {exc.returncode}"
+            raise RuntimeError(f"OSM transport filter failed for {raw_path}: {details}") from exc
+        filtered_hash = sha256_file(output_path)
+        output_size = output_path.stat().st_size if output_path.exists() else None
+        metric.update({"sha256": filtered_hash, "bytes": output_size})
+
+    metadata = {
+        "stage": "filtered_osm_transport_pbf",
+        "raw_format": _raw_format(output_path),
+        "derived_from_dataset_id": raw_dataset.id,
+        "source_url": source.url,
+        "filter": TRANSPORT_FILTER_VERSION,
+        "filter_script": str(script_path),
+        "input_path": str(raw_path),
+        "input_sha256": raw_dataset.sha256,
+        "output_path": str(output_path),
+        "stdout": completed.stdout.strip(),
+        "stderr": completed.stderr.strip(),
+        "import_mode": "staged_short_lock",
+    }
+    # The metadata file is what makes the output eligible for reuse next time.
+    metadata_path.write_text(json.dumps(metadata, indent=2), encoding="utf-8")
+    _emit_progress(progress_callback, "osm_transport_filter_completed", "Filtered OSM transport extract.", 3, 7, {"path": str(output_path), "sha256": filtered_hash})
+    return _PreparedTransportFile(path=output_path, sha256=filtered_hash, metadata=metadata, reused=False)
+
+
+def _reserve_transport_dataset(source_ref: _SourceRef, raw_dataset: _DatasetRef, prepared: _PreparedTransportFile) -> _DatasetRef:
+    """Insert or refresh the ``osm_pbf_transport`` dataset row for a filtered extract.
+
+    Runs in one transaction under a 60-second database write lock. The raw
+    dataset row is marked ``filtered`` and linked to the transport dataset via
+    ``filtered_dataset_id`` in its metadata, and the filter-transport pipeline
+    stage is recorded.
+
+    Raises:
+        ValueError: if the source row or the raw dataset row no longer exists.
+    """
+    lock_name = f"osm_staged_import:{source_ref.id}:reserve_transport"
+    with database_write_lock(lock_name, timeout=60), SessionLocal() as session:
+        source = session.get(Source, source_ref.id)
+        raw = session.get(Dataset, raw_dataset.id)
+        if source is None or raw is None:
+            raise ValueError("source or raw dataset disappeared during staged import")
+
+        serialized_metadata = json.dumps(prepared.metadata, indent=2)
+        dataset = _find_transport_dataset_by_raw_id(session, source.id, raw_dataset.id)
+        if dataset is not None:
+            dataset.local_path = str(prepared.path)
+            dataset.sha256 = prepared.sha256
+            dataset.status = "filtered"
+            dataset.metadata_json = serialized_metadata
+        else:
+            dataset = Dataset(
+                source_id=source.id,
+                kind="osm_pbf_transport",
+                local_path=str(prepared.path),
+                sha256=prepared.sha256,
+                is_active=False,
+                status="filtered",
+                metadata_json=serialized_metadata,
+            )
+            session.add(dataset)
+            # Flush so dataset.id is available for the raw row's back-reference.
+            session.flush()
+
+        raw.status = "filtered"
+        raw_metadata = {**_metadata(raw), "filtered_dataset_id": dataset.id}
+        raw.metadata_json = json.dumps(raw_metadata, indent=2)
+
+        _record_pipeline_stage(
+            session,
+            stage=STAGE_FILTER_TRANSPORT,
+            version=TRANSPORT_FILTER_VERSION,
+            source_id=source.id,
+            dataset=dataset,
+            inputs={
+                "raw_dataset_id": raw_dataset.id,
+                "raw_sha256": raw_dataset.sha256,
+                "filter_script": prepared.metadata.get("filter_script"),
+            },
+            outputs={"path": str(prepared.path), "sha256": prepared.sha256, "reused": prepared.reused},
+        )
+        session.commit()
+        return _dataset_ref(dataset)
+
+
+def _extract_transport_geojson_staged(source: _SourceRef, input_dataset: _DatasetRef, input_path: Path, progress_callback=None) -> _PreparedExtract:
+    """Extract (or reuse) the transport GeoJSON for an input OSM dataset.
+
+    The GeoJSON is written to
+    ``<data_dir>/derived/source_<id>/extract_dataset_<id>/transport.geojson``
+    with a ``transport.summary.json`` companion. An existing output is reused
+    only when the summary records the same input SHA-256, the current
+    ``EXTRACTOR_VERSION``, and a dict-valued ``extract_summary``. A summary
+    file missing that entry triggers a fresh extraction instead of a
+    ``KeyError``.
+
+    Returns:
+        A ``_PreparedExtract`` with the output path, its SHA-256, the extract
+        summary, and whether the output was reused.
+    """
+    output_dir = settings.data_dir / "derived" / f"source_{source.id}" / f"extract_dataset_{input_dataset.id}"
+    output_path = output_dir / "transport.geojson"
+    summary_path = output_path.with_suffix(".summary.json")
+    existing_summary = _read_json_file(summary_path)
+    cached_extract_summary = existing_summary.get("extract_summary")
+    if (
+        output_path.exists()
+        and existing_summary.get("input_sha256") == input_dataset.sha256
+        and existing_summary.get("extractor") == EXTRACTOR_VERSION
+        # Without a usable extract summary there is nothing to hand back, so
+        # fall through and rebuild rather than fail on a missing key.
+        and isinstance(cached_extract_summary, dict)
+    ):
+        output_hash = sha256_file(output_path)
+        _emit_progress(progress_callback, "osm_extract_reused", "Reusing existing extracted OSM transport GeoJSON.", 4, 7, {"path": str(output_path), "sha256": output_hash})
+        return _PreparedExtract(path=output_path, sha256=output_hash, summary=cached_extract_summary, reused=True)
+
+    _emit_progress(progress_callback, "osm_extract_started", "Extracting route, stop, and infrastructure geometry from OSM.", 4, 7, {"input_path": str(input_path), "output_path": str(output_path)})
+    with measure_pipeline_phase("osm_transport_extract", source_id=source.id, dataset_id=input_dataset.id, metadata={"input_path": str(input_path), "output_path": str(output_path)}) as metric:
+        extract_summary = extract_osm_transport_geojson(input_path, output_path)
+        output_hash = sha256_file(output_path)
+        metric.update({**extract_summary, "sha256": output_hash, "bytes": output_path.stat().st_size if output_path.exists() else None})
+    summary = {
+        "input_dataset_id": input_dataset.id,
+        "input_sha256": input_dataset.sha256,
+        "extractor": EXTRACTOR_VERSION,
+        "extract_summary": extract_summary,
+    }
+    # The summary file is the reuse marker checked at the top of this function.
+    summary_path.write_text(json.dumps(summary, indent=2), encoding="utf-8")
+    _emit_progress(progress_callback, "osm_extract_completed", "Extracted OSM transport GeoJSON.", 4, 7, {"path": str(output_path), "sha256": output_hash, **extract_summary})
+    return _PreparedExtract(path=output_path, sha256=output_hash, summary=extract_summary, reused=False)
+
+
+def _existing_active_derived_ref(source_id: int, input_dataset_id: int) -> _DatasetRef | None:
+    """Return a ref to the derived dataset already found for this input, if any.
+
+    Opens its own session, delegates the lookup to ``_find_existing_derived``,
+    and returns ``None`` when either the source or a matching derived dataset
+    is missing.
+    """
+    with SessionLocal() as session:
+        source = session.get(Source, source_id)
+        if source is not None:
+            # Only the input dataset id is known here, so a transient Dataset
+            # carrying just that id is handed to the lookup helper.
+            match = _find_existing_derived(session, source, Dataset(id=input_dataset_id))
+            if match is not None:
+                return _dataset_ref(match)
+        return None
+
+
+def _activate_existing_derived(source_id: int, derived_dataset_id: int) -> None:
+    """Make an existing derived dataset the only active dataset of its source.
+
+    Under a 60-second database write lock, every dataset of the source gets
+    ``is_active`` set according to whether its id equals
+    ``derived_dataset_id``; the source is then marked ``ok`` with a fresh
+    ``last_run_at``. Does nothing when the source or dataset row is missing.
+    """
+    lock_name = f"osm_staged_import:{source_id}:reuse_existing"
+    with database_write_lock(lock_name, timeout=60), SessionLocal() as session:
+        source = session.get(Source, source_id)
+        target = session.get(Dataset, derived_dataset_id)
+        if source is None or target is None:
+            return
+        for candidate in source.datasets:
+            candidate.is_active = candidate.id == target.id
+        source.status = "ok"
+        source.last_error = None
+        source.last_run_at = datetime.now(timezone.utc)
+        session.commit()
+
+
+def _reserve_derived_dataset(
+    *,
+    source_ref: _SourceRef,
+    raw_dataset: _DatasetRef,
+    input_dataset: _DatasetRef,
+    filtered_dataset: _DatasetRef | None,
+    extract: _PreparedExtract,
+) -> _DatasetRef:
+    """Insert or refresh the staging ``osm_geojson`` dataset row for an extract.
+
+    Runs in one transaction under a 60-second database write lock. A matching
+    staged row (see ``_find_staged_derived_dataset``) is updated in place with
+    the new metadata merged over the stored metadata; otherwise a new inactive
+    row is added. Either way the row ends in status ``sidecar_staging`` with
+    ``sidecar_status`` set to ``pending``, and the extract-geometry pipeline
+    stage is recorded.
+
+    Raises:
+        ValueError: if the source row cannot be loaded.
+    """
+    filtered_id = filtered_dataset.id if filtered_dataset is not None else None
+    metadata = {
+        "stage": "derived_osm_transport_geojson",
+        "derived_from_dataset_id": input_dataset.id,
+        "raw_dataset_id": raw_dataset.id,
+        "filtered_dataset_id": filtered_id,
+        "extractor": EXTRACTOR_VERSION,
+        "extract_summary": extract.summary,
+        "import_mode": "staged_short_lock",
+        "sidecar_status": "pending",
+    }
+    lock_name = f"osm_staged_import:{source_ref.id}:reserve_derived"
+    with database_write_lock(lock_name, timeout=60), SessionLocal() as session:
+        source = session.get(Source, source_ref.id)
+        if source is None:
+            raise ValueError(f"source not found: {source_ref.id}")
+
+        dataset = _find_staged_derived_dataset(session, source.id, input_dataset.id, extract.sha256)
+        if dataset is not None:
+            dataset.local_path = str(extract.path)
+            dataset.sha256 = extract.sha256
+            dataset.status = "sidecar_staging"
+            merged_metadata = {**_metadata(dataset), **metadata}
+            dataset.metadata_json = json.dumps(merged_metadata, indent=2)
+        else:
+            dataset = Dataset(
+                source_id=source.id,
+                kind="osm_geojson",
+                local_path=str(extract.path),
+                sha256=extract.sha256,
+                is_active=False,
+                status="sidecar_staging",
+                metadata_json=json.dumps(metadata, indent=2),
+            )
+            session.add(dataset)
+            session.flush()
+
+        _record_pipeline_stage(
+            session,
+            stage=STAGE_EXTRACT_GEOMETRY,
+            version=EXTRACTOR_VERSION,
+            source_id=source.id,
+            dataset=dataset,
+            inputs={
+                "input_dataset_id": input_dataset.id,
+                "input_sha256": input_dataset.sha256,
+                "extractor": EXTRACTOR_VERSION,
+            },
+            outputs={"path": str(extract.path), "sha256": extract.sha256, "summary": extract.summary, "reused": extract.reused},
+        )
+        session.commit()
+        return _dataset_ref(dataset)
+
+
+def _prepare_derived_storage_staged(derived_dataset: _DatasetRef, extract: _PreparedExtract, progress_callback=None) -> dict[str, object]:
+    """Build (or reuse) the feature storage for a staged derived dataset.
+
+    Existing storage is reused, and the dataset's current metadata returned
+    unchanged, when the metadata's ``osm_storage`` entry is a dict and either
+    it is in main-table mode with ``storage_status == "ready"`` or it names an
+    existing sidecar file with ``sidecar_status == "ready"``. Otherwise
+    storage is built via ``prepare_osm_geojson_storage`` in its own session,
+    using a transient ``Dataset`` that mirrors the ref.
+
+    Returns:
+        The dataset metadata merged with the storage metadata and the
+        resulting ``sidecar_status`` / ``storage_status`` values.
+    """
+    current_metadata = derived_dataset.metadata
+    storage = current_metadata.get("osm_storage")
+    if isinstance(storage, dict):
+        main_ready = storage.get("mode") == OSM_STORAGE_MAIN and current_metadata.get("storage_status") == "ready"
+        if main_ready:
+            _emit_progress(progress_callback, "osm_storage_reused", "Reusing existing OSM main-table storage.", 6, 7, {"dataset_id": derived_dataset.id})
+            return derived_dataset.metadata
+        sidecar = storage.get("sidecar_path")
+        if sidecar and Path(str(sidecar)).exists() and current_metadata.get("sidecar_status") == "ready":
+            _emit_progress(progress_callback, "osm_sidecar_reused", "Reusing existing OSM feature sidecar.", 6, 7, {"dataset_id": derived_dataset.id, "sidecar_path": str(sidecar)})
+            return derived_dataset.metadata
+
+    storage_mode = effective_osm_feature_storage()
+    if storage_mode == OSM_STORAGE_MAIN:
+        storage_label = "main-table OSM feature storage"
+        started_event = "osm_storage_started"
+        completed_event = "osm_storage_completed"
+    else:
+        storage_label = "OSM feature sidecar"
+        started_event = "osm_sidecar_started"
+        completed_event = "osm_sidecar_completed"
+    _emit_progress(progress_callback, started_event, f"Building {storage_label}.", 5, 7, {"dataset_id": derived_dataset.id, "path": str(extract.path), "storage_mode": storage_mode})
+
+    # Stand-in ORM object carrying the ref's values; it is only passed to the
+    # storage builder and never added to a session in this function.
+    transient_dataset = Dataset(
+        id=derived_dataset.id,
+        source_id=derived_dataset.source_id,
+        kind=derived_dataset.kind,
+        local_path=derived_dataset.local_path,
+        sha256=derived_dataset.sha256,
+        is_active=False,
+        status=derived_dataset.status,
+        metadata_json=json.dumps(derived_dataset.metadata, indent=2),
+    )
+    with measure_pipeline_phase("osm_sidecar_build", source_id=derived_dataset.source_id, dataset_id=derived_dataset.id, metadata={"path": str(extract.path)}) as metric:
+        with SessionLocal() as session:
+            sidecar_metadata = prepare_osm_geojson_storage(
+                session=session,
+                dataset=transient_dataset,
+                path=extract.path,
+                source_hash=derived_dataset.sha256,
+                storage_mode=storage_mode,
+            )
+            session.commit()
+        metric.update(sidecar_metadata)
+
+    sidecar_status = "ready" if storage_mode == OSM_STORAGE_SIDECAR_FEATURES else "not_used"
+    metadata = {**derived_dataset.metadata, **sidecar_metadata, "sidecar_status": sidecar_status, "storage_status": "ready"}
+    _emit_progress(progress_callback, completed_event, f"Built {storage_label}.", 6, 7, {"dataset_id": derived_dataset.id, **sidecar_metadata})
+    return metadata
+
+
+def _activate_staged_osm_import(
+    *,
+    source_ref: _SourceRef,
+    raw_dataset: _DatasetRef,
+    filtered_dataset: _DatasetRef | None,
+    input_dataset: _DatasetRef,
+    derived_dataset: _DatasetRef,
+    extract: _PreparedExtract,
+    sidecar_metadata: dict[str, object],
+) -> int:
+    """Promote a staged derived dataset to the source's single active dataset.
+
+    In one transaction under a 60-second database write lock: every dataset of
+    the source is deactivated, the raw (and, if present, filtered) rows get
+    their final status plus extractor details in their metadata, the derived
+    row is marked ``imported`` and active, the build-indexes pipeline stage is
+    recorded, and the source is marked ``ok``.
+
+    Returns:
+        The id of the activated derived dataset.
+
+    Raises:
+        ValueError: if the source, raw, or derived row no longer exists.
+    """
+    filtered_id = filtered_dataset.id if filtered_dataset is not None else None
+    # NOTE(review): "sidecar_status" is forced to "ready" here, overriding any
+    # value carried in sidecar_metadata — confirm this is intended for every
+    # storage mode.
+    metadata = {
+        **sidecar_metadata,
+        "stage": "derived_osm_transport_geojson",
+        "derived_from_dataset_id": input_dataset.id,
+        "raw_dataset_id": raw_dataset.id,
+        "filtered_dataset_id": filtered_id,
+        "extractor": EXTRACTOR_VERSION,
+        "extract_summary": extract.summary,
+        "import_mode": "staged_short_lock",
+        "sidecar_status": "ready",
+    }
+    lock_name = f"osm_staged_import:{source_ref.id}:activate"
+    with database_write_lock(lock_name, timeout=60), SessionLocal() as session:
+        source = session.get(Source, source_ref.id)
+        raw = session.get(Dataset, raw_dataset.id)
+        filtered = None
+        if filtered_dataset is not None:
+            filtered = session.get(Dataset, filtered_dataset.id)
+        derived = session.get(Dataset, derived_dataset.id)
+        if source is None or raw is None or derived is None:
+            raise ValueError("staged OSM activation lost source or dataset rows")
+
+        # Deactivate everything first; only the derived row is re-activated below.
+        for dataset in source.datasets:
+            dataset.is_active = False
+
+        extraction_details = {"extractor": EXTRACTOR_VERSION, "extract_summary": extract.summary}
+
+        raw.status = "extracted" if filtered is None else "filtered"
+        raw.is_active = False
+        raw.metadata_json = json.dumps({**_metadata(raw), **extraction_details}, indent=2)
+
+        if filtered is not None:
+            filtered.status = "extracted"
+            filtered.is_active = False
+            filtered.metadata_json = json.dumps({**_metadata(filtered), **extraction_details}, indent=2)
+
+        derived.status = "imported"
+        derived.is_active = True
+        derived.local_path = str(extract.path)
+        derived.sha256 = extract.sha256
+        derived.metadata_json = json.dumps(metadata, indent=2)
+
+        # Prefer the nested storage description as the stage output when present.
+        storage_info = sidecar_metadata.get("osm_storage")
+        stage_outputs = storage_info if isinstance(storage_info, dict) else sidecar_metadata
+        _record_pipeline_stage(
+            session,
+            stage=STAGE_BUILD_INDEXES,
+            version=OSM_SIDECAR_INDEX_VERSION,
+            source_id=source.id,
+            dataset=derived,
+            inputs={
+                "dataset_id": derived.id,
+                "dataset_sha256": derived.sha256,
+                "sidecar_schema": "osm_features_v1",
+                "indexed_columns": ["kind", "mode", "route_scope", "bbox", "route_key", "ref", "identity"],
+            },
+            outputs=stage_outputs,
+        )
+
+        source.status = "ok"
+        source.last_error = None
+        source.last_run_at = datetime.now(timezone.utc)
+        session.commit()
+        return derived.id
+
+
+def _find_transport_dataset_by_raw_id(session: Session, source_id: int, raw_dataset_id: int) -> Dataset | None:
+    """Return the newest transport dataset filtered from the given raw dataset.
+
+    Candidates are the source's ``osm_pbf_transport`` rows, highest id first.
+    A row matches when its metadata records ``derived_from_dataset_id`` equal
+    to ``raw_dataset_id`` and the current ``TRANSPORT_FILTER_VERSION``.
+    """
+    query = (
+        select(Dataset)
+        .where(Dataset.source_id == source_id, Dataset.kind == "osm_pbf_transport")
+        .order_by(Dataset.id.desc())
+    )
+    for candidate in session.scalars(query).all():
+        details = _metadata(candidate)
+        if details.get("derived_from_dataset_id") != raw_dataset_id:
+            continue
+        if details.get("filter") == TRANSPORT_FILTER_VERSION:
+            return candidate
+    return None
+
+
+def _find_staged_derived_dataset(session: Session, source_id: int, input_dataset_id: int, extract_hash: str) -> Dataset | None:
+    """Return the newest in-progress derived dataset for this input and extract hash.
+
+    Candidates are the source's ``osm_geojson`` rows whose status is
+    ``sidecar_staging`` or ``importing`` and whose SHA-256 equals
+    ``extract_hash``, highest id first. A row matches when its metadata
+    records the same input dataset id and the current ``EXTRACTOR_VERSION``.
+    """
+    query = (
+        select(Dataset)
+        .where(
+            Dataset.source_id == source_id,
+            Dataset.kind == "osm_geojson",
+            Dataset.status.in_(["sidecar_staging", "importing"]),
+            Dataset.sha256 == extract_hash,
+        )
+        .order_by(Dataset.id.desc())
+    )
+    for candidate in session.scalars(query).all():
+        details = _metadata(candidate)
+        if details.get("derived_from_dataset_id") != input_dataset_id:
+            continue
+        if details.get("extractor") == EXTRACTOR_VERSION:
+            return candidate
+    return None
+
+
+def _dataset_ref(dataset: Dataset) -> _DatasetRef:
+    """Copy the fields of a ``Dataset`` row into a ``_DatasetRef``.
+
+    ``id`` and ``source_id`` are coerced to ``int``; ``metadata`` is the
+    parsed form of the row's ``metadata_json``.
+    """
+    fields = {
+        "id": int(dataset.id),
+        "source_id": int(dataset.source_id),
+        "kind": dataset.kind,
+        "local_path": dataset.local_path,
+        "sha256": dataset.sha256,
+        "status": dataset.status,
+        "metadata": _metadata(dataset),
+    }
+    return _DatasetRef(**fields)
+
+
+def _transport_filter_path_for_raw_id(source_id: int, raw_dataset_id: int, raw_path: Path) -> Path:
+    """Return where the transport-filtered extract of a raw dataset is stored.
+
+    The file is ``transport.osm.pbf`` when ``_raw_format(raw_path)`` is
+    ``"osm_pbf"`` and ``transport.osm`` otherwise, under
+    ``<data_dir>/derived/source_<source_id>/raw_dataset_<raw_dataset_id>/``.
+    """
+    if _raw_format(raw_path) == "osm_pbf":
+        suffix = ".osm.pbf"
+    else:
+        suffix = ".osm"
+    dataset_dir = settings.data_dir / "derived" / f"source_{source_id}" / f"raw_dataset_{raw_dataset_id}"
+    return dataset_dir / f"transport{suffix}"
+
+
+def _read_json_file(path: Path) -> dict[str, Any]:
+    """Read a JSON object from ``path``, treating any problem as "no data".
+
+    Returns ``{}`` when the file does not exist, cannot be read, is not valid
+    UTF-8, is not valid JSON, or holds a JSON value that is not an object.
+    """
+    if not path.exists():
+        return {}
+    try:
+        data = json.loads(path.read_text(encoding="utf-8"))
+    except (OSError, ValueError):
+        # ValueError covers both json.JSONDecodeError and the
+        # UnicodeDecodeError raised by read_text on non-UTF-8 bytes.
+        return {}
+    return data if isinstance(data, dict) else {}
+
+
+def _prepare_raw_osm_dataset(session: Session, source: Source, progress_callback=None) -> Dataset:
+    """Return the raw OSM dataset for ``source``, preferring diff updates.
+
+    If ``_try_prepare_raw_from_diffs`` yields a dataset it is returned as is.
+    Otherwise the full snapshot is materialized and hashed, the raw dataset
+    row with that hash is looked up or created, and the current replication
+    state is recorded against it.
+    """
+    from_diffs = _try_prepare_raw_from_diffs(session, source, progress_callback=progress_callback)
+    if from_diffs is not None:
+        return from_diffs
+
+    _emit_progress(progress_callback, "osm_full_snapshot_started", f"Downloading/copying full OSM snapshot for {source.name}.", None, None, {"source_id": source.id})
+    snapshot_path = materialize_source(source)
+    snapshot_hash = sha256_file(snapshot_path)
+    raw_dataset = _find_raw_dataset(session, source, snapshot_hash)
+    if not raw_dataset:
+        raw_dataset = _commit_raw_dataset(session, source, snapshot_path, snapshot_hash)
+    _record_current_replication_state_for_snapshot(session, source, raw_dataset, progress_callback=progress_callback)
+    _emit_progress(progress_callback, "osm_full_snapshot_completed", f"Prepared raw OSM dataset #{raw_dataset.id}.", None, None, {"dataset_id": raw_dataset.id})
+    return raw_dataset
+
+
+def extract_osm_transport_geojson(input_path: Path, output_path: Path) -> dict[str, Any]:
+    """Extract transport features from an OSM file into a GeoJSON FeatureCollection.
+
+    Reads the input twice: a scan pass collects route relations and their
+    member way ids, then a geometry pass (with node locations enabled) builds
+    the features. The collection is written to ``output_path``, whose parent
+    directory is created if needed.
+
+    Returns:
+        Summary counts. ``stop_station_features`` is the feature total minus
+        the route and infrastructure counts.
+    """
+    scan = _TransportScanHandler()
+    scan.apply_file(str(input_path))
+
+    geometry = _TransportGeometryHandler(scan.route_relations, scan.route_way_ids)
+    geometry.apply_file(str(input_path), locations=True)
+    features = geometry.features()
+
+    output_path.parent.mkdir(parents=True, exist_ok=True)
+    collection = {"type": "FeatureCollection", "features": features}
+    output_path.write_text(json.dumps(collection), encoding="utf-8")
+
+    # Routes and infrastructure are counted independently of each other.
+    route_features = 0
+    infra_features = 0
+    for feature in features:
+        properties = feature["properties"]
+        if properties.get("type") == "route":
+            route_features += 1
+        if properties.get("kind") == "infra":
+            infra_features += 1
+    total_features = len(features)
+
+    return {
+        "input_path": str(input_path),
+        "output_path": str(output_path),
+        "route_relations_seen": len(scan.route_relations),
+        "route_relation_member_ways": len(scan.route_way_ids),
+        "features": total_features,
+        "route_features": route_features,
+        "infrastructure_features": infra_features,
+        "stop_station_features": total_features - route_features - infra_features,
+        "route_relations_without_geometry": geometry.route_relations_without_geometry,
+    }
+
+
+def _commit_raw_dataset(session: Session, source: Source, path: Path, source_hash: str) -> Dataset:
+    """Add a new ``osm_pbf_raw`` dataset row for ``path`` and flush it.
+
+    Every existing dataset of the source is marked inactive first. The new
+    row is itself inactive, has status ``committed``, and is flushed but not
+    committed.
+    """
+    for existing in source.datasets:
+        existing.is_active = False
+
+    raw_metadata = {
+        "stage": "raw_osm",
+        "raw_format": _raw_format(path),
+        "source_url": source.url,
+    }
+    new_dataset = Dataset(
+        source_id=source.id,
+        kind="osm_pbf_raw",
+        local_path=str(path),
+        sha256=source_hash,
+        is_active=False,
+        status="committed",
+        metadata_json=json.dumps(raw_metadata, indent=2),
+    )
+    session.add(new_dataset)
+    session.flush()
+    return new_dataset
+
+
+def _try_prepare_raw_from_diffs(session: Session, source: Source, progress_callback=None) -> Dataset | None:
+    """Try to bring the local raw OSM file up to date with replication diffs.
+
+    Returns:
+        The existing raw dataset when it is already at the remote sequence,
+        the dataset produced by applying the missing diffs, or ``None`` when
+        the caller should use a full snapshot instead: no updates URL, no
+        usable local state or base file, unreadable remote state, a sequence
+        gap above ``settings.osm_diff_max_sequence_gap``, a missing host tool,
+        or a failure while applying diffs.
+    """
+    updates_url = _source_updates_url(source)
+    if not updates_url:
+        return None
+
+    def _fallback(message: str, details: dict[str, Any], current=None, total=None) -> None:
+        # Every fallback reason is reported under the same event type.
+        _emit_progress(progress_callback, "osm_diff_fallback", message, current, total, details)
+        return None
+
+    current_state = _latest_diff_state(session, source.id)
+    if current_state is None or current_state.raw_dataset_id is None:
+        return _fallback("No local OSM replication state yet; using full snapshot.", {"updates_url": updates_url})
+
+    raw_dataset = session.get(Dataset, current_state.raw_dataset_id)
+    if raw_dataset is None or not Path(raw_dataset.local_path).exists():
+        return _fallback("Local raw OSM base is missing; using full snapshot.", {"updates_url": updates_url})
+
+    try:
+        remote_state = fetch_replication_state(updates_url, timeout=settings.osm_diff_state_timeout_seconds)
+    except Exception as exc:  # noqa: BLE001 - correctness fallback
+        return _fallback(f"Could not read OSM replication state; using full snapshot: {exc}", {"updates_url": updates_url})
+
+    local_sequence = current_state.sequence_number
+    remote_sequence = remote_state.sequence_number
+    if remote_sequence <= local_sequence:
+        _emit_progress(
+            progress_callback,
+            "osm_diff_up_to_date",
+            "Local raw OSM extract is already at the latest known replication sequence.",
+            remote_sequence,
+            remote_sequence,
+            {"updates_url": updates_url, "sequence_number": remote_sequence},
+        )
+        return raw_dataset
+
+    gap = remote_sequence - local_sequence
+    if gap > settings.osm_diff_max_sequence_gap:
+        return _fallback(
+            "OSM replication gap is too large; using full snapshot.",
+            {"gap": gap, "max_gap": settings.osm_diff_max_sequence_gap, "updates_url": updates_url},
+            local_sequence,
+            remote_sequence,
+        )
+
+    host_tool = _host_tool_path()
+    if not host_tool.exists():
+        return _fallback("host_tool.sh is missing; using full snapshot.", {"host_tool": str(host_tool)})
+
+    try:
+        return _apply_diff_range(
+            session=session,
+            source=source,
+            base_dataset=raw_dataset,
+            updates_url=updates_url,
+            local_sequence=local_sequence,
+            remote_state=remote_state,
+            host_tool=host_tool,
+            progress_callback=progress_callback,
+        )
+    except Exception as exc:  # noqa: BLE001 - fall back to full snapshot rather than risk a bad base
+        return _fallback(f"OSM diff application failed; using full snapshot: {exc}", {"updates_url": updates_url})
+
+
+def _apply_diff_range(
+    session: Session,
+    source: Source,
+    base_dataset: Dataset,
+    updates_url: str,
+    local_sequence: int,
+    remote_state: ReplicationState,
+    host_tool: Path,
+    progress_callback=None,
+) -> Dataset:
+    """Apply replication diffs ``local_sequence + 1 .. remote_state.sequence_number``.
+
+    Diffs are downloaded and applied in batches of
+    ``settings.osm_diff_apply_batch_size`` (at least 1). Each batch result is
+    stored via ``_store_updated_raw_pbf`` and becomes the input of the next
+    batch. The final file is hashed, a raw dataset row with that hash is
+    looked up or created, and the new replication state is recorded.
+
+    Returns:
+        The raw dataset row for the fully updated file.
+    """
+    source_root = settings.data_dir / "sources" / f"source_{source.id}"
+    update_root = source_root / "updates"
+    work_root = source_root / "diff_work"
+    work_root.mkdir(parents=True, exist_ok=True)
+
+    current_path = Path(base_dataset.local_path)
+    batch_size = max(1, int(settings.osm_diff_apply_batch_size))
+    target_sequence = remote_state.sequence_number
+    sequences = list(range(local_sequence + 1, target_sequence + 1))
+    applied_sequences: list[int] = []
+
+    _emit_progress(
+        progress_callback,
+        "osm_diff_started",
+        f"Applying {len(sequences)} OSM replication diffs.",
+        local_sequence,
+        remote_state.sequence_number,
+        {"updates_url": updates_url, "from_sequence": local_sequence + 1, "to_sequence": remote_state.sequence_number},
+    )
+
+    for offset in range(0, len(sequences), batch_size):
+        batch = sequences[offset : offset + batch_size]
+        first_sequence, last_sequence = batch[0], batch[-1]
+
+        diff_paths = []
+        for sequence in batch:
+            diff_path = download_diff(updates_url, sequence, update_root)
+            diff_paths.append(diff_path)
+            _emit_progress(
+                progress_callback,
+                "osm_diff_downloaded",
+                f"Downloaded OSM diff sequence {sequence}.",
+                sequence,
+                remote_state.sequence_number,
+                {"path": str(diff_path), "sequence_number": sequence},
+            )
+
+        temp_output = work_root / f"source_{source.id}_{first_sequence}_{last_sequence}.tmp.osm.pbf"
+        completed = apply_osm_changes(current_path, diff_paths, temp_output, host_tool)
+        # The stored batch result is the base for the next batch.
+        current_path = _store_updated_raw_pbf(source, temp_output)
+        applied_sequences.extend(batch)
+        _emit_progress(
+            progress_callback,
+            "osm_diff_applied",
+            f"Applied OSM diff sequences {first_sequence}-{last_sequence}.",
+            last_sequence,
+            remote_state.sequence_number,
+            {
+                "output_path": str(current_path),
+                "stdout": completed.stdout.strip(),
+                "stderr": completed.stderr.strip(),
+                "batch_start": first_sequence,
+                "batch_end": last_sequence,
+            },
+        )
+
+    raw_hash = sha256_file(current_path)
+    dataset = _find_raw_dataset(session, source, raw_hash)
+    if not dataset:
+        dataset = _commit_raw_dataset(session, source, current_path, raw_hash)
+
+    replication_info = {
+        "updates_url": updates_url,
+        "sequence_number": remote_state.sequence_number,
+        "timestamp": remote_state.timestamp,
+    }
+    diff_info = {
+        "base_dataset_id": base_dataset.id,
+        "base_sequence_number": local_sequence,
+        "applied_sequences": applied_sequences,
+    }
+    _update_dataset_metadata(dataset, replication_state=replication_info, diff_update=diff_info)
+    _record_diff_state(
+        session,
+        source=source,
+        raw_dataset=dataset,
+        updates_url=updates_url,
+        state=remote_state,
+        metadata={"base_dataset_id": base_dataset.id, "applied_sequences": applied_sequences},
+    )
+    return dataset
+
+
+def _record_current_replication_state_for_snapshot(session: Session, source: Source, raw_dataset: Dataset, progress_callback=None) -> None:
+    """Record the upstream replication state against a full-snapshot raw dataset.
+
+    Does nothing when the source has no updates URL. A failure to fetch the
+    state is reported as an ``osm_diff_state_unavailable`` progress event and
+    otherwise ignored. On success the state is written into the dataset's
+    metadata and stored as a new diff-state row tagged ``full_snapshot``.
+    """
+    updates_url = _source_updates_url(source)
+    if not updates_url:
+        return
+
+    try:
+        state = fetch_replication_state(updates_url, timeout=settings.osm_diff_state_timeout_seconds)
+    except Exception as exc:  # noqa: BLE001 - full snapshot is still usable without diff state
+        _emit_progress(progress_callback, "osm_diff_state_unavailable", f"Could not record OSM replication state: {exc}", None, None, {"updates_url": updates_url})
+        return
+
+    replication_info = {
+        "updates_url": updates_url,
+        "sequence_number": state.sequence_number,
+        "timestamp": state.timestamp,
+    }
+    _update_dataset_metadata(raw_dataset, replication_state=replication_info)
+    _record_diff_state(
+        session,
+        source=source,
+        raw_dataset=raw_dataset,
+        updates_url=updates_url,
+        state=state,
+        metadata={"source": "full_snapshot"},
+    )
+
+
+def _record_diff_state(
+    session: Session,
+    source: Source,
+    raw_dataset: Dataset,
+    updates_url: str,
+    state: ReplicationState,
+    metadata: dict[str, Any] | None = None,
+) -> OsmDiffState:
+    """Store ``state`` as the source's active diff state and return the new row.
+
+    Every diff-state row of the source that is currently ``active`` is marked
+    ``superseded`` first. The new row is added and flushed, not committed; its
+    metadata is compact JSON holding the raw state under ``"state"`` plus any
+    extra ``metadata`` entries.
+    """
+    active_rows = select(OsmDiffState).where(OsmDiffState.source_id == source.id, OsmDiffState.status == "active")
+    for previous in session.scalars(active_rows).all():
+        previous.status = "superseded"
+
+    payload = {"state": state.raw}
+    payload.update(metadata or {})
+    new_row = OsmDiffState(
+        source_id=source.id,
+        raw_dataset_id=raw_dataset.id,
+        updates_url=updates_url,
+        sequence_number=state.sequence_number,
+        timestamp=state.timestamp,
+        status="active",
+        metadata_json=json.dumps(payload, separators=(",", ":")),
+    )
+    session.add(new_row)
+    session.flush()
+    return new_row
+
+
+def _latest_diff_state(session: Session, source_id: int) -> OsmDiffState | None:
+    """Return the source's active diff state with the highest sequence number.
+
+    Ties on sequence number are broken by the highest row id. Returns
+    ``None`` when the source has no active diff state.
+    """
+    newest_active = (
+        select(OsmDiffState)
+        .where(OsmDiffState.source_id == source_id, OsmDiffState.status == "active")
+        .order_by(OsmDiffState.sequence_number.desc(), OsmDiffState.id.desc())
+        # Only the first row is consumed, so do not ask the database for the
+        # remaining matches.
+        .limit(1)
+    )
+    return session.scalar(newest_active)
+
+
+def _store_updated_raw_pbf(source: Source, temp_path: Path) -> Path:
+    """Move a freshly built PBF into the source directory under a hash-based name.
+
+    The target is ``<data_dir>/sources/source_<id>/<first 16 hex of sha256>.osm.pbf``.
+    When a file with identical content is already there, the temporary file
+    is deleted instead of moved.
+
+    Returns:
+        The path of the stored file.
+    """
+    source_dir = settings.data_dir / "sources" / f"source_{source.id}"
+    source_dir.mkdir(parents=True, exist_ok=True)
+
+    content_hash = sha256_file(temp_path)
+    target = source_dir / f"{content_hash[:16]}.osm.pbf"
+
+    already_stored = target.exists() and sha256_file(target) == content_hash
+    if already_stored:
+        temp_path.unlink(missing_ok=True)
+    else:
+        shutil.move(str(temp_path), str(target))
+    return target
+
+
+def _source_updates_url(source: Source) -> str | None:
+    """Return the replication updates URL configured for ``source``, if any.
+
+    ``source.notes`` is read as ``;``-separated ``key=value`` parts; the first
+    ``updates_url`` part with a non-blank value wins. Failing that, a source
+    of kind ``osm_diff`` falls back to its own ``url``. Otherwise ``None``.
+    """
+    for part in (source.notes or "").split(";"):
+        key, separator, value = part.strip().partition("=")
+        if not separator:
+            # No "=" in this part, so it is not a key=value entry.
+            continue
+        value = value.strip()
+        if key.strip() == "updates_url" and value:
+            return value
+
+    if source.kind == "osm_diff" and source.url:
+        return source.url
+    return None
+
+
+def _host_tool_path() -> Path:
+    """Return ``scripts/host_tool.sh`` under the directory three levels above this module."""
+    module_file = Path(__file__).resolve()
+    return module_file.parents[2].joinpath("scripts", "host_tool.sh")
+
+
+def _find_raw_dataset(session: Session, source: Source, raw_hash: str) -> Dataset | None:
+    """Return the newest ``osm_pbf_raw`` dataset of ``source`` with the given SHA-256.
+
+    "Newest" means the highest row id. Returns ``None`` when no such dataset
+    exists.
+    """
+    newest_match = (
+        select(Dataset)
+        .where(
+            Dataset.source_id == source.id,
+            Dataset.kind == "osm_pbf_raw",
+            Dataset.sha256 == raw_hash,
+        )
+        .order_by(Dataset.id.desc())
+        # Only the first row is consumed, so do not ask the database for the
+        # remaining matches.
+        .limit(1)
+    )
+    return session.scalar(newest_match)
+
+
+def _prepare_transport_pbf(session: Session, source: Source, raw_dataset: Dataset, raw_path: Path) -> Dataset:
+    """Return the transport-filtered dataset for a raw dataset, building it if needed.
+
+    An existing transport dataset whose file is still on disk is returned
+    unchanged. Otherwise the prefilter script is run as
+    ``<script> <raw_path> <output_path>`` and the result is stored in a new
+    row, or in the existing row when only its file was missing. The raw
+    dataset is marked ``filtered`` and the session is flushed, not committed.
+
+    Raises:
+        FileNotFoundError: if the prefilter script does not exist.
+        RuntimeError: if the prefilter script exits with a non-zero status.
+    """
+    existing = _find_transport_dataset(session, source, raw_dataset)
+    if existing is not None and Path(existing.local_path).exists():
+        return existing
+
+    output_path = _transport_filter_path(source, raw_dataset, raw_path)
+    output_path.parent.mkdir(parents=True, exist_ok=True)
+    script_path = _prefilter_script_path()
+    if not script_path.exists():
+        raise FileNotFoundError(f"OSM transport filter script not found: {script_path}")
+
+    command = [str(script_path), str(raw_path), str(output_path)]
+    try:
+        completed = subprocess.run(command, check=True, capture_output=True, text=True)
+    except subprocess.CalledProcessError as exc:
+        # Prefer stderr, then stdout, then the bare exit code as the failure detail.
+        stderr = (exc.stderr or "").strip()
+        stdout = (exc.stdout or "").strip()
+        details = stderr or stdout or f"exit code {exc.returncode}"
+        raise RuntimeError(f"OSM transport filter failed for {raw_path}: {details}") from exc
+    filtered_hash = sha256_file(output_path)
+
+    metadata = {
+        "stage": "filtered_osm_transport_pbf",
+        "raw_format": _raw_format(output_path),
+        "derived_from_dataset_id": raw_dataset.id,
+        "source_url": source.url,
+        "filter": TRANSPORT_FILTER_VERSION,
+        "filter_script": str(script_path),
+        "input_path": str(raw_path),
+        "input_sha256": raw_dataset.sha256,
+        "output_path": str(output_path),
+        "stdout": completed.stdout.strip(),
+        "stderr": completed.stderr.strip(),
+    }
+    serialized_metadata = json.dumps(metadata, indent=2)
+
+    if existing is not None:
+        # The row survived but its file was gone: repoint it at the new output.
+        dataset = existing
+        dataset.local_path = str(output_path)
+        dataset.sha256 = filtered_hash
+        dataset.status = "filtered"
+        dataset.metadata_json = serialized_metadata
+    else:
+        dataset = Dataset(
+            source_id=source.id,
+            kind="osm_pbf_transport",
+            local_path=str(output_path),
+            sha256=filtered_hash,
+            is_active=False,
+            status="filtered",
+            metadata_json=serialized_metadata,
+        )
+        session.add(dataset)
+
+    raw_dataset.status = "filtered"
+    session.flush()
+    return dataset
+
+
+def _find_transport_dataset(session: Session, source: Source, raw_dataset: Dataset) -> Dataset | None:
+    """Return the newest transport dataset filtered from ``raw_dataset``, if any.
+
+    Thin ORM-object wrapper around ``_find_transport_dataset_by_raw_id``, which
+    holds the single copy of the lookup: the source's ``osm_pbf_transport``
+    rows, highest id first, matched on ``derived_from_dataset_id`` and the
+    current ``TRANSPORT_FILTER_VERSION`` in their metadata.
+    """
+    return _find_transport_dataset_by_raw_id(session, source.id, raw_dataset.id)
+
+
+def _find_existing_derived(session: Session, source: Source, input_dataset: Dataset) -> Dataset | None:
+    """Return the newest active, imported GeoJSON dataset derived from ``input_dataset``.
+
+    Candidates are the source's ``osm_geojson`` rows that are active and have
+    status ``imported``, highest id first. A row matches when its metadata
+    records ``derived_from_dataset_id`` equal to ``input_dataset.id`` and the
+    current ``EXTRACTOR_VERSION``.
+    """
+    query = (
+        select(Dataset)
+        .where(
+            Dataset.source_id == source.id,
+            Dataset.kind == "osm_geojson",
+            Dataset.status == "imported",
+            Dataset.is_active.is_(True),
+        )
+        .order_by(Dataset.id.desc())
+    )
+    for candidate in session.scalars(query).all():
+        details = _metadata(candidate)
+        if details.get("derived_from_dataset_id") != input_dataset.id:
+            continue
+        if details.get("extractor") == EXTRACTOR_VERSION:
+            return candidate
+    return None
+
+
+def _metadata(dataset: Dataset) -> dict[str, Any]:
+    """Parse ``dataset.metadata_json`` into a dict.
+
+    Returns ``{}`` when the column is empty, is not valid JSON, or holds a
+    JSON value that is not an object (for example ``null`` or a list), so
+    callers can always use ``.get`` and ``**`` on the result.
+    """
+    try:
+        parsed = json.loads(dataset.metadata_json or "{}")
+    except json.JSONDecodeError:
+        return {}
+    return parsed if isinstance(parsed, dict) else {}
+
+
+def _update_dataset_metadata(dataset: Dataset, **values: Any) -> None:
+    """Merge ``values`` into the dataset's JSON metadata and write it back.
+
+    Existing keys are overwritten; the result is serialized with ``indent=2``.
+    """
+    merged = _metadata(dataset)
+    merged.update(values)
+    serialized = json.dumps(merged, indent=2)
+    dataset.metadata_json = serialized
+
+
+def _emit_progress(
+    progress_callback,
+    event_type: str,
+    message: str,
+    progress_current=None,
+    progress_total=None,
+    metadata: dict[str, Any] | None = None,
+) -> None:
+    """Forward one progress event to ``progress_callback``; do nothing when it is ``None``."""
+    if progress_callback is None:
+        return
+    progress_callback(event_type, message, progress_current, progress_total, metadata)
+
+
+def _should_prefilter(path: Path) -> bool:
+    """Tell whether ``path`` should go through the transport prefilter.
+
+    Requires ``settings.osm_pbf_prefilter_enabled`` to be truthy and the raw
+    format of ``path`` to be one of the configured prefilter formats.
+    """
+    if settings.osm_pbf_prefilter_enabled:
+        return _raw_format(path) in _prefilter_formats()
+    return False
+
+
+def _prefilter_formats() -> set[str]:
+    """Return the comma-separated ``settings.osm_pbf_prefilter_formats`` as a set.
+
+    Entries are stripped of surrounding whitespace and empty entries are dropped.
+    """
+    configured = str(settings.osm_pbf_prefilter_formats or "")
+    stripped_items = (item.strip() for item in configured.split(","))
+    return {item for item in stripped_items if item}
+
+
+def _prefilter_script_path() -> Path:
+    """Return the configured prefilter script path, resolving relative paths against the cwd."""
+    script = settings.osm_pbf_prefilter_script
+    return script if script.is_absolute() else Path.cwd() / script
+
+
+def _transport_filter_path(source: Source, raw_dataset: Dataset, raw_path: Path) -> Path:
+    """Build the output path for the transport-filtered copy of ``raw_path``.
+
+    The file is placed under
+    ``<data_dir>/derived/source_<id>/raw_dataset_<id>/`` and is named
+    ``transport.osm.pbf`` for PBF input and ``transport.osm`` otherwise.
+    """
+    extension = ".osm.pbf" if _raw_format(raw_path) == "osm_pbf" else ".osm"
+    target_dir = settings.data_dir / "derived" / f"source_{source.id}" / f"raw_dataset_{raw_dataset.id}"
+    return target_dir / f"transport{extension}"
+
+
+class _TransportScanHandler(osmium.SimpleHandler):
+    """Collect route relations and the ids of the ways they reference.
+
+    The two attributes have the same shapes that ``_TransportGeometryHandler``
+    accepts in its constructor.
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+        # relation id -> {"tags": <tag dict>, "way_refs": <member way ids>}
+        self.route_relations: dict[int, dict[str, Any]] = {}
+        # ids of every way referenced by a recorded route relation
+        self.route_way_ids: set[int] = set()
+
+    def relation(self, relation: osmium.osm.Relation) -> None:
+        """Record ``relation`` if it is a ``type=route`` relation of a known mode with way members."""
+        tags = _tags_dict(relation.tags)
+        if _route_mode(tags) is None or tags.get("type") != "route":
+            return
+
+        member_way_ids = [member.ref for member in relation.members if member.type == "w"]
+        if member_way_ids:
+            self.route_relations[relation.id] = {"tags": tags, "way_refs": member_way_ids}
+            self.route_way_ids.update(member_way_ids)
+
+
+class _TransportGeometryHandler(osmium.SimpleHandler):
+    """Collect stop, infrastructure and route-way geometry and build GeoJSON features.
+
+    ``route_relations`` maps a relation id to ``{"tags", "way_refs"}`` and
+    ``route_way_ids`` names the ways whose coordinates must be kept so that
+    route geometry can be assembled in ``features()``.
+    """
+
+    def __init__(self, route_relations: dict[int, dict[str, Any]], route_way_ids: set[int]) -> None:
+        super().__init__()
+        self.route_relations = route_relations
+        self.route_way_ids = route_way_ids
+        # way id -> list of [lon, lat] pairs for ways referenced by a route
+        self.route_way_lines: dict[int, list[list[float]]] = {}
+        self.infrastructure_features: list[dict[str, Any]] = []
+        self.stop_features: list[dict[str, Any]] = []
+        # Number of route relations skipped by the latest features() call.
+        self.route_relations_without_geometry = 0
+
+    def node(self, node: osmium.osm.Node) -> None:
+        """Keep stop/station nodes that have usable coordinates as Point features."""
+        tags = _tags_dict(node.tags)
+        if not _is_stop_or_station(tags):
+            return
+        coords = _node_coords(node)
+        if coords is None:
+            return
+        props = {
+            **tags,
+            "osm_type": "node",
+            "osm_id": str(node.id),
+        }
+        self.stop_features.append({"type": "Feature", "geometry": {"type": "Point", "coordinates": coords}, "properties": props})
+
+    def way(self, way: osmium.osm.Way) -> None:
+        """Record route-way coordinates, infrastructure lines and stop/station ways.
+
+        One way can contribute to more than one of the three collections.
+        """
+        tags = _tags_dict(way.tags)
+        coords = _way_coords(way)
+
+        if coords is not None and way.id in self.route_way_ids:
+            self.route_way_lines[way.id] = coords
+
+        if coords is not None and _is_transport_infrastructure(tags):
+            props = {
+                **tags,
+                "osm_type": "way",
+                "osm_id": str(way.id),
+                "kind": "infra",
+            }
+            mode = _infrastructure_mode(tags)
+            if mode:
+                # An explicit "mode" tag on the way wins over the inferred one.
+                props.setdefault("mode", mode)
+            self.infrastructure_features.append(
+                {"type": "Feature", "geometry": {"type": "LineString", "coordinates": coords}, "properties": props}
+            )
+
+        if _is_stop_or_station(tags):
+            feature = _way_area_or_line_feature(way, tags, coords)
+            if feature is not None:
+                self.stop_features.append(feature)
+
+    def features(self) -> list[dict[str, Any]]:
+        """Return route features followed by infrastructure and stop features.
+
+        A route relation becomes a ``LineString`` when its member ways join into
+        one part and a ``MultiLineString`` otherwise. Relations with no stored
+        member-way geometry are skipped and counted in
+        ``route_relations_without_geometry``.
+        """
+        # Recount from zero so that calling features() more than once does not
+        # inflate the counter.
+        self.route_relations_without_geometry = 0
+        route_features = []
+        for relation_id, route in self.route_relations.items():
+            way_refs = route["way_refs"]
+            if not any(self.route_way_lines.get(way_ref) for way_ref in way_refs):
+                self.route_relations_without_geometry += 1
+                continue
+
+            geometry: dict[str, Any]
+            ordered_lines = _ordered_route_lines(way_refs, self.route_way_lines)
+            if len(ordered_lines) == 1:
+                geometry = {"type": "LineString", "coordinates": ordered_lines[0]}
+            else:
+                geometry = {"type": "MultiLineString", "coordinates": ordered_lines}
+
+            props = {
+                **route["tags"],
+                "osm_type": "relation",
+                "osm_id": str(relation_id),
+                "member_way_count": len(way_refs),
+                "geometry_source": "ordered_route_relation_member_ways",
+                "geometry_part_count": len(ordered_lines),
+            }
+            route_features.append({"type": "Feature", "geometry": geometry, "properties": props})
+        return route_features + self.infrastructure_features + self.stop_features
+
+
+def _ordered_route_lines(way_refs: list[int], route_way_lines: dict[int, list[list[float]]]) -> list[list[list[float]]]:
+    """Join member-way lines into as few connected parts as possible.
+
+    Ways are visited in ``way_refs`` order. Each usable line (at least two
+    coordinates) is attached to an existing part that shares an endpoint with
+    it, trying the newest part first and then older parts from newest to
+    oldest. A line that touches no part starts a new part. Coordinates are
+    copied, so ``route_way_lines`` is not modified.
+    """
+    chains: list[list[list[float]]] = []
+    for ref in way_refs:
+        source_line = route_way_lines.get(ref)
+        if not source_line:
+            continue
+        segment = [list(point) for point in source_line]
+        if len(segment) < 2:
+            continue
+        # reversed(chains) yields the newest part first, then the older ones.
+        if not any(_append_connected(chain, segment) for chain in reversed(chains)):
+            chains.append(segment)
+    return chains
+
+
+def _append_connected(part: list[list[float]], coords: list[list[float]]) -> bool:
+    """Attach ``coords`` to ``part`` in place if they share an endpoint.
+
+    The four end-to-end combinations are tried in a fixed order; ``coords`` is
+    reversed when needed and the shared coordinate is not duplicated. Returns
+    ``True`` when ``part`` was extended and ``False`` when nothing touches.
+    """
+    part_start, part_end = part[0], part[-1]
+    line_start, line_end = coords[0], coords[-1]
+
+    if _same_coord(part_end, line_start):
+        # part ... -> coords as given
+        part += coords[1:]
+        return True
+    if _same_coord(part_end, line_end):
+        # part ... -> coords walked backwards
+        part += coords[-2::-1]
+        return True
+    if _same_coord(part_start, line_end):
+        # coords as given -> ... part
+        part[:0] = coords[:-1]
+        return True
+    if _same_coord(part_start, line_start):
+        # coords walked backwards -> ... part
+        part[:0] = coords[:0:-1]
+        return True
+    return False
+
+
+def _same_coord(left: list[float], right: list[float]) -> bool:
+    """Return True when both coordinates have two components that differ by less than 1e-9."""
+    if len(left) < 2 or len(right) < 2:
+        return False
+    tolerance = 1e-9
+    return abs(left[0] - right[0]) < tolerance and abs(left[1] - right[1]) < tolerance
+
+
+def _tags_dict(tags: osmium.osm.TagList) -> dict[str, str]:
+    """Copy a tag list into a plain dict keyed by each tag's ``k`` with its ``v`` as value."""
+    return dict((tag.k, tag.v) for tag in tags)
+
+
+def _route_mode(tags: dict[str, str]) -> str | None:
+    """Return the mode named by the ``route`` tag, or ``None`` if it is not in ``ROUTE_MODES``.
+
+    The value ``railway`` is reported as ``train``; other values are returned unchanged.
+    """
+    route_value = tags.get("route")
+    if route_value not in ROUTE_MODES:
+        return None
+    return "train" if route_value == "railway" else route_value
+
+
+def _is_transport_infrastructure(tags: dict[str, str]) -> bool:
+    """Return True when ``_infrastructure_mode`` recognizes a mode for ``tags``."""
+    inferred_mode = _infrastructure_mode(tags)
+    return inferred_mode is not None
+
+
+def _infrastructure_mode(tags: dict[str, str]) -> str | None:
+    """Infer the transport mode of an infrastructure way from its tags.
+
+    Checked in order: a ``railway`` value listed in ``RAILWAY_MODE_BY_TAG``,
+    ``route=ferry``, then any non-empty ``aerialway`` value other than
+    ``station``. Returns ``None`` when none of these apply.
+    """
+    railway_tag = tags.get("railway")
+    if railway_tag in RAILWAY_MODE_BY_TAG:
+        return RAILWAY_MODE_BY_TAG[railway_tag]
+    if tags.get("route") == "ferry":
+        return "ferry"
+    aerialway_tag = tags.get("aerialway")
+    is_aerial_line = bool(aerialway_tag) and aerialway_tag != "station"
+    return "aerialway" if is_aerial_line else None
+
+
+def _is_stop_or_station(tags: dict[str, str]) -> bool:
+    """Return True when any tag marks the object as a stop, platform, station or terminal."""
+    accepted_values_by_key = (
+        ("public_transport", {"platform", "stop_position", "station"}),
+        ("railway", {"station", "halt", "tram_stop", "subway_entrance", "platform"}),
+        ("highway", {"bus_stop"}),
+        ("amenity", {"bus_station", "ferry_terminal"}),
+        ("aerialway", {"station"}),
+    )
+    return any(tags.get(key) in accepted for key, accepted in accepted_values_by_key)
+
+
+def _node_coords(node: osmium.osm.Node) -> list[float] | None:
+    """Return ``[lon, lat]`` for ``node``, or ``None`` if its location is invalid or unreadable.
+
+    Any exception raised while reading the location is treated as "no coordinates".
+    """
+    try:
+        if node.location.valid():
+            return [float(node.location.lon), float(node.location.lat)]
+    except Exception:
+        pass
+    return None
+
+
+def _way_coords(way: osmium.osm.Way) -> list[list[float]] | None:
+    """Return the way's node coordinates as ``[lon, lat]`` pairs.
+
+    Returns ``None`` when any node has an invalid location, when reading the
+    nodes raises, or when fewer than two coordinates are available.
+    """
+    points: list[list[float]] = []
+    try:
+        for member in way.nodes:
+            if not member.location.valid():
+                return None
+            points.append([float(member.location.lon), float(member.location.lat)])
+    except Exception:
+        return None
+    if len(points) < 2:
+        return None
+    return points
+
+
+def _way_area_or_line_feature(way: osmium.osm.Way, tags: dict[str, str], coords: list[list[float]] | None) -> dict[str, Any] | None:
+    """Build a GeoJSON feature for a stop/station way.
+
+    A closed way (first coordinate equals the last) with at least four
+    coordinates becomes a ``Polygon``; any other way becomes a ``LineString``.
+    Returns ``None`` when ``coords`` is ``None``.
+    """
+    if coords is None:
+        return None
+    is_closed_ring = len(coords) >= 4 and coords[0] == coords[-1]
+    if is_closed_ring:
+        geometry = {"type": "Polygon", "coordinates": [coords]}
+    else:
+        geometry = {"type": "LineString", "coordinates": coords}
+    properties = {**tags, "osm_type": "way", "osm_id": str(way.id)}
+    return {"type": "Feature", "geometry": geometry, "properties": properties}
+
+
+def _record_pipeline_stage(
+    session: Session,
+    *,
+    stage: str,
+    version: str,
+    source_id: int,
+    dataset: Dataset,
+    inputs: dict[str, Any],
+    outputs: dict[str, Any] | None,
+) -> None:
+    """Record a pipeline run for ``stage`` by starting it and finishing it right away.
+
+    The dependency hash is computed from ``inputs``; a falsy ``outputs`` is
+    recorded as an empty dict.
+    """
+    pipeline_run = start_pipeline_run(
+        session,
+        stage=stage,
+        version=version,
+        dependency_hash_value=dependency_hash(inputs),
+        source_id=source_id,
+        dataset_id=dataset.id,
+        inputs=inputs,
+    )
+    recorded_outputs = outputs if outputs else {}
+    finish_pipeline_run(session, pipeline_run, outputs=recorded_outputs)
+
+
+def _raw_format(path: Path) -> str:
+    """Classify an OSM file by its (case-insensitive) file name suffix.
+
+    Returns ``osm_pbf``, ``osm_xml`` or ``osm_change``; any other name yields ``osm``.
+    """
+    lowered = path.name.lower()
+    formats_by_suffix = (
+        ((".osm.pbf", ".pbf"), "osm_pbf"),
+        ((".osm", ".osm.xml", ".xml"), "osm_xml"),
+        ((".osc", ".osc.gz"), "osm_change"),
+    )
+    for suffixes, label in formats_by_suffix:
+        if lowered.endswith(suffixes):
+            return label
+    return "osm"
diff --git a/app/pipeline/osm_replication.py b/app/pipeline/osm_replication.py
new file mode 100644
index 0000000..c7ef43e
--- /dev/null
+++ b/app/pipeline/osm_replication.py
@@ -0,0 +1,105 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from pathlib import Path
+import subprocess
+from urllib.parse import urljoin, urlparse
+
+import requests
+
+
+@dataclass(frozen=True)
+class ReplicationState:
+ """Parsed contents of an OSM replication ``state.txt`` document."""
+ # Value of the ``sequenceNumber`` key, converted to int.
+ sequence_number: int
+ # Value of the ``timestamp`` key after unescaping, or None when the key is absent.
+ timestamp: str | None
+ # Every key/value pair read from the state text, with values unescaped.
+ raw: dict[str, str]
+
+
+def fetch_replication_state(updates_url: str, *, timeout: float = 30) -> ReplicationState:
+    """Download ``state.txt`` from below ``updates_url`` and parse it.
+
+    Raises ``requests.HTTPError`` on a non-success response and ``ValueError``
+    when the state text has no usable ``sequenceNumber``.
+    """
+    reply = requests.get(_state_url(updates_url), timeout=timeout)
+    reply.raise_for_status()
+    return parse_replication_state_text(reply.text)
+
+
+def parse_replication_state_text(text: str) -> ReplicationState:
+    """Parse the ``key=value`` lines of a replication state document.
+
+    Blank lines, ``#`` comment lines and lines without ``=`` are ignored. Keys
+    and values are stripped and values are unescaped; a repeated key keeps its
+    last value.
+
+    Raises ``ValueError`` when ``sequenceNumber`` is missing or is not an integer.
+    """
+    pairs: dict[str, str] = {}
+    for raw_line in text.splitlines():
+        entry = raw_line.strip()
+        is_blank_or_comment = not entry or entry.startswith("#")
+        if is_blank_or_comment or "=" not in entry:
+            continue
+        name, _, raw_value = entry.partition("=")
+        pairs[name.strip()] = _unescape_state_value(raw_value.strip())
+
+    if "sequenceNumber" not in pairs:
+        raise ValueError("replication state is missing sequenceNumber")
+    sequence = pairs["sequenceNumber"]
+    try:
+        sequence_number = int(sequence)
+    except ValueError as exc:
+        raise ValueError(f"invalid replication sequenceNumber: {sequence}") from exc
+    return ReplicationState(sequence_number=sequence_number, timestamp=pairs.get("timestamp"), raw=pairs)
+
+
+def diff_url_for_sequence(updates_url: str, sequence_number: int) -> str:
+    """Return the URL of the ``.osc.gz`` diff for ``sequence_number``.
+
+    The number is zero-padded to at least nine digits (and always to a multiple
+    of three), then split into three-digit path segments, e.g. ``4123789`` ->
+    ``004/123/789.osc.gz`` below ``updates_url``.
+    """
+    digits = str(sequence_number)
+    group_count = (len(digits) + 2) // 3
+    padded = digits.zfill(max(9, group_count * 3))
+    segments = [padded[start : start + 3] for start in range(0, len(padded), 3)]
+    relative_path = "/".join(segments) + ".osc.gz"
+    return urljoin(_directory_url(updates_url), relative_path)
+
+
+def download_diff(updates_url: str, sequence_number: int, output_dir: Path, *, timeout: float = 120) -> Path:
+    """Download the ``.osc.gz`` diff for ``sequence_number`` into ``output_dir``.
+
+    An already present file is reused without a request: first
+    ``output_dir/<file name>``, then ``output_dir/<parent dir name>/<file name>``.
+    Otherwise the diff is streamed to a temporary ``<sequence_number>.download``
+    file that is moved into place once the transfer is complete.
+
+    Raises ``requests.HTTPError`` on a non-success response; other request and
+    OS errors propagate unchanged. The temporary file is removed on failure.
+    """
+    url = diff_url_for_sequence(updates_url, sequence_number)
+    parsed_path = Path(urlparse(url).path)
+    output_path = output_dir / parsed_path.name
+    nested = output_dir / parsed_path.parent.name / output_path.name
+    # NOTE(review): the cached file name is only the last three digits of the
+    # sequence (e.g. "789.osc.gz"), so sequences that differ by a multiple of
+    # 1000 map to the same path when they share ``output_dir``. Confirm that
+    # callers use one directory per run or clear it between runs.
+    if output_path.exists():
+        return output_path
+    if nested.exists():
+        return nested
+    output_dir.mkdir(parents=True, exist_ok=True)
+    temp_path = output_dir / f"{sequence_number}.download"
+    try:
+        with requests.get(url, stream=True, timeout=timeout) as response:
+            response.raise_for_status()
+            with temp_path.open("wb") as handle:
+                for chunk in response.iter_content(chunk_size=1024 * 1024):
+                    if chunk:
+                        handle.write(chunk)
+        temp_path.replace(output_path)
+    except BaseException:
+        # Do not leave a partial download behind when the transfer or the
+        # final rename fails; the original error is re-raised untouched.
+        temp_path.unlink(missing_ok=True)
+        raise
+    return output_path
+
+
+def apply_osm_changes(base_path: Path, diff_paths: list[Path], output_path: Path, host_tool_path: Path) -> subprocess.CompletedProcess[str]:
+    """Run ``osmium apply-changes`` through ``host_tool_path`` and return the completed process.
+
+    The change files are applied to ``base_path`` and the result is written to
+    ``output_path`` (overwriting it); the parent directory is created if needed.
+
+    Raises ``ValueError`` when ``diff_paths`` is empty and
+    ``subprocess.CalledProcessError`` when the command exits with a non-zero status.
+    """
+    if not diff_paths:
+        raise ValueError("no OSM change files supplied")
+    output_path.parent.mkdir(parents=True, exist_ok=True)
+    command = [
+        str(host_tool_path),
+        "osmium",
+        "apply-changes",
+        "--output",
+        str(output_path),
+        "--overwrite",
+        str(base_path),
+    ]
+    command.extend(str(diff_path) for diff_path in diff_paths)
+    return subprocess.run(command, check=True, capture_output=True, text=True)
+
+
+def _state_url(updates_url: str) -> str:
+    """Return the URL of ``state.txt`` directly below ``updates_url``."""
+    base = _directory_url(updates_url)
+    return urljoin(base, "state.txt")
+
+
+def _directory_url(url: str) -> str:
+    """Return ``url`` with a trailing slash, adding one only if it is missing."""
+    if url.endswith("/"):
+        return url
+    return url + "/"
+
+
+def _unescape_state_value(value: str) -> str:
+    """Undo the backslash escapes ``\\:``, ``\\=``, ``\\ `` and ``\\\\`` in a state value.
+
+    The replacements are applied one after another in that order.
+    """
+    replacements = (
+        ("\\:", ":"),
+        ("\\=", "="),
+        ("\\ ", " "),
+        ("\\\\", "\\"),
+    )
+    for escaped, plain in replacements:
+        value = value.replace(escaped, plain)
+    return value
diff --git a/app/pipeline/route_layer.py b/app/pipeline/route_layer.py
new file mode 100644
index 0000000..7e3b3d9
--- /dev/null
+++ b/app/pipeline/route_layer.py
@@ -0,0 +1,1903 @@
+from __future__ import annotations
+
+import json
+from dataclasses import dataclass
+from typing import Callable, Iterable
+
+from shapely.geometry import LineString, MultiLineString, Point, shape
+from shapely.ops import linemerge
+from sqlalchemy import and_, delete, func, or_, select, text
+from sqlalchemy.orm import Session
+
+from app.config import settings
+from app.gtfs_storage import all_scheduled_stop_ids, stop_times_by_trip as storage_stop_times_by_trip
+from app.models import (
+ CanonicalStop,
+ CanonicalStopLink,
+ Dataset,
+ GtfsRoute,
+ GtfsRoutePatternLink,
+ GtfsShape,
+ GtfsStop,
+ GtfsStopTime,
+ GtfsTrip,
+ GtfsTripRoutePatternLink,
+ MatchRule,
+ OsmFeature,
+ RouteMatch,
+ RoutePattern,
+ RoutePatternStop,
+)
+from app.osm_classification import infer_osm_route_scope_from_tags
+from app.osm_storage import ensure_main_osm_feature, features_are_sidecar, osm_feature_count, query_osm_features
+from app.pipeline.matcher import MODE_GROUPS
+from app.pipeline.state import STAGE_BUILD_ROUTE_LAYER, dependency_hash, finish_pipeline_run, start_pipeline_run
+from app.pipeline.utils import bbox_overlap, geometry_json_and_bbox, norm_ref, norm_text
+from app.spatial import analyze_postgresql_tables, refresh_postgis_geometries, using_postgresql
+
+
+# Version tag stored on the pipeline run and inside the route-layer dependency payload.
+ROUTE_LAYER_VERSION = "route_layer_v3_stop_alias_matching"
+# Shape key used for route-pattern links when a trip has no shape_id.
+GTFS_ROUTE_PATTERN_NULL_SHAPE = "__route__"
+# NOTE(review): the OSM_* radii look like lon/lat-degree search radii for linking
+# OSM stops; their use is not visible in this part of the file - confirm.
+OSM_STOP_LINK_RADIUS_DEG = 0.0018
+OSM_STOP_NAME_LINK_RADIUS_DEG = 0.0032
+# Maximum distance (in lon/lat degrees) for merging a GTFS stop group into an
+# existing canonical stop, chosen by how well the names match: exact match,
+# name similarity >= 0.5, name similarity >= 0.25.
+GTFS_STOP_EXACT_NAME_LINK_RADIUS_DEG = 0.006
+GTFS_STOP_NAME_LINK_RADIUS_DEG = 0.0032
+GTFS_STOP_PARTIAL_NAME_LINK_RADIUS_DEG = 0.0014
+# NOTE(review): presumably the minimum score for accepting an OSM route
+# candidate; its use is not visible in this part of the file - confirm.
+OSM_ROUTE_MIN_SCORE = 62.0
+# Progress callbacks are called as
+# (event_type, message, progress_current, progress_total, metadata).
+ProgressCallback = Callable[[str, str, int | None, int | None, dict[str, object] | None], None]
+# NOTE(review): presumably tokens ignored when comparing stop names; the code
+# that reads this set is not visible in this part of the file - confirm.
+STOP_MATCH_NOISE_TOKENS = {
+ "s",
+ "u",
+ "bhf",
+ "station",
+ "train",
+ "flixtrain",
+ "flixbus",
+}
+
+
+@dataclass(frozen=True)
+class _GtfsPatternSeed:
+ """Immutable bundle of a GTFS route with one shape/trip and its geometry summary.
+
+ NOTE(review): the code that builds and consumes seeds is not visible in this
+ part of the file; the field notes below are read off names and types only.
+ """
+ route: GtfsRoute
+ # Either id may be None.
+ shape_id: str | None
+ trip_id: str | None
+ # Serialized geometry, or None; presumably geometry_source names where it came from.
+ geometry_text: str | None
+ geometry_source: str
+ # Four optional floats; the component order is not shown here - TODO confirm.
+ bbox: tuple[float | None, float | None, float | None, float | None]
+ start_point: Point | None
+ end_point: Point | None
+ center_point: Point | None
+
+
+@dataclass(frozen=True)
+class _OsmRouteCandidate:
+ """Immutable bundle of an OSM feature with its geometry and lookup keys.
+
+ NOTE(review): construction and use are not visible in this part of the file.
+ """
+ feature: OsmFeature
+ # Typed as ``object``; presumably a parsed shapely geometry - TODO confirm.
+ geom: object
+ geometry_text: str
+ # Four optional floats; the component order is not shown here - TODO confirm.
+ bbox: tuple[float | None, float | None, float | None, float | None]
+ ref_key: str
+ mode: str | None
+
+
+@dataclass(frozen=True)
+class _OsmRouteCandidateIndex:
+ """Two lookup tables over ``_OsmRouteCandidate`` objects."""
+ # Candidates grouped under a (str, str) key; by the field name presumably
+ # (ref key, mode) - TODO confirm against the builder.
+ by_ref_mode: dict[tuple[str, str], list[_OsmRouteCandidate]]
+ # Candidate by integer id; which id is used is not shown here - TODO confirm.
+ by_id: dict[int, _OsmRouteCandidate]
+
+
+@dataclass(frozen=True)
+class _RouteLayerOverrides:
+ """Manual accept/reject decisions keyed by GTFS route id.
+
+ NOTE(review): by the names, the int values are presumably OSM feature ids;
+ the code that fills these maps is not visible here - confirm.
+ """
+ # GTFS route id -> single accepted id.
+ accepted_by_gtfs_route_id: dict[int, int]
+ # GTFS route id -> set of rejected ids.
+ rejected_by_gtfs_route_id: dict[int, set[int]]
+
+
+@dataclass(frozen=True)
+class _CanonicalStopLinkOverrides:
+ """Manual canonical-stop rules per GTFS stop, keyed by (dataset_id, stop_id).
+
+ Built from active ``link_canonical_stop`` / ``unlink_canonical_stop`` match
+ rules; a stop appears in at most one of the two maps (the later rule wins).
+ """
+ # (dataset_id, stop_id) -> action dict of the link rule.
+ link_by_stop: dict[tuple[int, str], dict[str, object]]
+ # (dataset_id, stop_id) -> action dict of the unlink rule.
+ unlink_by_stop: dict[tuple[int, str], dict[str, object]]
+
+
+@dataclass(frozen=True)
+class _PatternBuildItem:
+ """Immutable bundle of a pattern seed, its route pattern and match details.
+
+ NOTE(review): construction and use are not visible in this part of the file.
+ """
+ seed: _GtfsPatternSeed
+ pattern: RoutePattern
+ confidence: float
+ # Free-form strings; the allowed values are not shown here - TODO confirm.
+ source_kind: str
+ status: str
+ reasons: dict[str, object]
+
+
+def rebuild_route_layer(
+ session: Session,
+ *,
+ progress_callback: ProgressCallback | None = None,
+ commit_between_steps: bool = True,
+) -> dict[str, object]:
+ """Rebuild the visual route layer from active GTFS and OSM datasets.
+
+ Steps, in order: start a pipeline run, clear the derived link tables (route
+ patterns themselves are preserved), build canonical GTFS stops, link OSM
+ stops to them, build route patterns, then finish the pipeline run with the
+ result summary.
+
+ After each step the session is committed when ``commit_between_steps`` is
+ true and only flushed otherwise. ``progress_callback``, when given, receives
+ an event at the start, after most steps and at completion, with progress
+ counted out of 4.
+
+ Returns a dict of counts taken from the individual step results, plus the
+ route-layer version.
+ """
+ dependency = _route_layer_dependency(session)
+ run = start_pipeline_run(
+ session,
+ stage=STAGE_BUILD_ROUTE_LAYER,
+ version=ROUTE_LAYER_VERSION,
+ dependency_hash_value=dependency_hash(dependency),
+ inputs=dependency,
+ )
+ _commit_or_flush(session, commit_between_steps)
+ _emit_progress(progress_callback, "route_layer_started", "Rebuilding visual route layer.", 0, 4, {"version": ROUTE_LAYER_VERSION})
+ _clear_route_layer(session, preserve_route_patterns=True)
+ _commit_or_flush(session, commit_between_steps)
+ _emit_progress(progress_callback, "route_layer_cleared", "Cleared derived route-layer link tables.", 1, 4, None)
+ canonical_result = _build_canonical_stops(session)
+ _commit_or_flush(session, commit_between_steps)
+ _emit_progress(progress_callback, "route_layer_canonical_stops", "Built canonical GTFS stops.", 2, 4, canonical_result)
+ osm_link_result = _link_osm_stops(session, progress_callback=progress_callback, commit_batches=commit_between_steps)
+ _commit_or_flush(session, commit_between_steps)
+ _emit_progress(progress_callback, "route_layer_osm_stop_links", "Linked OSM visual stops to canonical stops.", 3, 4, osm_link_result)
+ pattern_result = _build_route_patterns(session, progress_callback=progress_callback)
+ _commit_or_flush(session, commit_between_steps)
+ # The created/updated/reused/removed counters are optional in pattern_result
+ # and default to 0; every other key is required.
+ result = {
+ "version": ROUTE_LAYER_VERSION,
+ "canonical_stops": canonical_result["canonical_stops"],
+ "canonical_stop_links": canonical_result["canonical_stop_links"] + osm_link_result["canonical_stop_links"],
+ "route_patterns": pattern_result["route_patterns"],
+ "route_patterns_created": pattern_result.get("route_patterns_created", 0),
+ "route_patterns_updated": pattern_result.get("route_patterns_updated", 0),
+ "route_patterns_reused": pattern_result.get("route_patterns_reused", 0),
+ "route_patterns_removed": pattern_result.get("route_patterns_removed", 0),
+ "route_pattern_links": pattern_result["route_pattern_links"],
+ "trip_pattern_links": pattern_result["trip_pattern_links"],
+ "route_pattern_stops": pattern_result["route_pattern_stops"],
+ "gtfs_proposed_patterns": pattern_result["gtfs_proposed_patterns"],
+ }
+ finish_pipeline_run(session, run, outputs=result)
+ _commit_or_flush(session, commit_between_steps)
+ _emit_progress(progress_callback, "route_layer_completed", "Visual route layer rebuilt.", 4, 4, result)
+ return result
+
+
+def _route_layer_dependency(session: Session) -> dict[str, object]:
+    """Describe the inputs the route layer depends on.
+
+    The payload holds the route-layer version, one entry per active dataset
+    (ordered by kind, then id) and the count plus a hash signature of all route
+    matches (ordered by id).
+    """
+    dataset_query = select(Dataset).where(Dataset.is_active.is_(True)).order_by(Dataset.kind, Dataset.id)
+    active_datasets = []
+    for dataset in session.scalars(dataset_query).all():
+        active_datasets.append(
+            {
+                "id": int(dataset.id),
+                "source_id": int(dataset.source_id),
+                "kind": dataset.kind,
+                "sha256": dataset.sha256,
+                "metadata": dataset.metadata_json,
+            }
+        )
+
+    match_query = select(
+        RouteMatch.id,
+        RouteMatch.gtfs_route_id,
+        RouteMatch.osm_feature_id,
+        RouteMatch.status,
+        RouteMatch.updated_at,
+    ).order_by(RouteMatch.id)
+    match_rows = session.execute(match_query).all()
+    signature_rows = []
+    for row in match_rows:
+        osm_feature_id = None if row.osm_feature_id is None else int(row.osm_feature_id)
+        updated_at = row.updated_at.isoformat() if row.updated_at else None
+        signature_rows.append([int(row.id), int(row.gtfs_route_id), osm_feature_id, row.status, updated_at])
+
+    return {
+        "version": ROUTE_LAYER_VERSION,
+        "active_datasets": active_datasets,
+        "route_matches": {"count": len(match_rows), "signature": dependency_hash(signature_rows)},
+    }
+
+
+def logical_stop_group_id(stop: GtfsStop) -> str:
+    """Return the id that groups ``stop`` with related stops.
+
+    This is the parent station when one is set; otherwise the part of
+    ``stop_id`` before the first ``::``, or the whole ``stop_id`` when it
+    contains no ``::``.
+    """
+    # partition() returns the whole string as the head when "::" is absent.
+    return stop.parent_station or stop.stop_id.partition("::")[0]
+
+
+def route_pattern_for_trip(session: Session, route: GtfsRoute, trip: GtfsTrip) -> RoutePattern | None:
+    """Return the route pattern to use for ``trip``, or ``None`` if there is none.
+
+    A trip-level link (highest confidence, then lowest id) is preferred. Without
+    one, the route-level link for the trip's shape is used; trips with no
+    ``shape_id`` use the ``GTFS_ROUTE_PATTERN_NULL_SHAPE`` key.
+    """
+    # Only the best row is used, so let the database stop after one row
+    # instead of selecting every matching link.
+    trip_link = session.scalar(
+        select(GtfsTripRoutePatternLink)
+        .where(
+            GtfsTripRoutePatternLink.dataset_id == trip.dataset_id,
+            GtfsTripRoutePatternLink.trip_id == trip.trip_id,
+        )
+        .order_by(GtfsTripRoutePatternLink.confidence.desc(), GtfsTripRoutePatternLink.id)
+        .limit(1)
+    )
+    if trip_link is not None:
+        return session.get(RoutePattern, trip_link.route_pattern_id)
+
+    shape_key = trip.shape_id or GTFS_ROUTE_PATTERN_NULL_SHAPE
+    link = session.scalar(
+        select(GtfsRoutePatternLink)
+        .where(
+            GtfsRoutePatternLink.dataset_id == route.dataset_id,
+            GtfsRoutePatternLink.route_id == route.route_id,
+            GtfsRoutePatternLink.shape_id == shape_key,
+        )
+        .order_by(GtfsRoutePatternLink.confidence.desc(), GtfsRoutePatternLink.id)
+        .limit(1)
+    )
+    if link is None:
+        return None
+    return session.get(RoutePattern, link.route_pattern_id)
+
+
+def canonical_stop_for_gtfs_stop(session: Session, stop: GtfsStop) -> CanonicalStop | None:
+    """Return the canonical stop linked to ``stop``, or ``None`` when no link exists."""
+    link_query = select(CanonicalStopLink).where(
+        CanonicalStopLink.object_type == "gtfs_stop",
+        CanonicalStopLink.dataset_id == stop.dataset_id,
+        CanonicalStopLink.object_id == stop.id,
+    )
+    link = session.scalar(link_query)
+    return None if link is None else session.get(CanonicalStop, link.canonical_stop_id)
+
+
+def gtfs_stop_ids_for_canonical_stop(session: Session, canonical_stop_id: int, dataset_id: int) -> tuple[str, ...]:
+    """Return the GTFS stop ids of ``dataset_id`` linked to a canonical stop.
+
+    The ids are the links' ``external_id`` values as strings, ordered by link
+    role and then by id.
+    """
+    query = (
+        select(CanonicalStopLink.external_id)
+        .where(
+            CanonicalStopLink.canonical_stop_id == canonical_stop_id,
+            CanonicalStopLink.object_type == "gtfs_stop",
+            CanonicalStopLink.dataset_id == dataset_id,
+        )
+        .order_by(CanonicalStopLink.role, CanonicalStopLink.external_id)
+    )
+    return tuple(str(external_id) for external_id in session.scalars(query).all())
+
+
+def _clear_route_layer(session: Session, *, preserve_route_patterns: bool = False) -> None:
+    """Delete all rows of the derived route-layer tables, then flush.
+
+    Link tables are emptied before the tables they point to. ``RoutePattern``
+    rows are kept when ``preserve_route_patterns`` is true.
+    """
+    pattern_models = () if preserve_route_patterns else (RoutePattern,)
+    deletion_order = (
+        GtfsTripRoutePatternLink,
+        GtfsRoutePatternLink,
+        RoutePatternStop,
+        *pattern_models,
+        CanonicalStopLink,
+        CanonicalStop,
+    )
+    for model in deletion_order:
+        session.execute(delete(model))
+    session.flush()
+
+
+def _commit_or_flush(session: Session, should_commit: bool) -> None:
+    """Commit the session when ``should_commit`` is true; otherwise only flush it."""
+    finalize = session.commit if should_commit else session.flush
+    finalize()
+
+
+def _emit_progress(
+    progress_callback: ProgressCallback | None,
+    event_type: str,
+    message: str,
+    progress_current: int | None,
+    progress_total: int | None,
+    metadata: dict[str, object] | None = None,
+) -> None:
+    """Forward one progress event to ``progress_callback``; do nothing when it is ``None``."""
+    if progress_callback is None:
+        return
+    progress_callback(event_type, message, progress_current, progress_total, metadata)
+
+
+def _build_canonical_stops(session: Session) -> dict[str, int]:
+ """Create canonical stops for the scheduled stops of all active GTFS datasets.
+
+ Stops are grouped by (dataset id, logical group id). Each group either joins
+ a nearby canonical stop found through the grid lookup or gets a new one.
+ Every stop is then linked to a canonical stop; manual link/unlink rules can
+ redirect an individual stop. Links are saved in bulk, in batches of 5000.
+
+ Returns the number of canonical stops and of linked GTFS stops; both are 0
+ when there is no active GTFS dataset.
+ """
+ active_gtfs_dataset_ids = _active_dataset_ids(session, "gtfs")
+ if not active_gtfs_dataset_ids:
+ return {"canonical_stops": 0, "canonical_stop_links": 0}
+ overrides = _canonical_stop_link_overrides(session)
+ source_id_by_dataset = {
+ int(dataset_id): int(source_id)
+ for dataset_id, source_id in session.execute(
+ select(Dataset.id, Dataset.source_id).where(Dataset.id.in_(active_gtfs_dataset_ids))
+ ).all()
+ }
+
+ stops = _scheduled_gtfs_stops(session, active_gtfs_dataset_ids)
+ groups: dict[tuple[int, str], list[GtfsStop]] = {}
+ for stop in stops:
+ groups.setdefault((stop.dataset_id, logical_stop_group_id(stop)), []).append(stop)
+
+ canonical_by_group: dict[tuple[int, str], CanonicalStop] = {}
+ # (confidence, distance in metres or None) of each group's canonical match.
+ link_quality_by_group: dict[tuple[int, str], tuple[float, float | None]] = {}
+ canonical_grid: dict[tuple[int, int], list[CanonicalStop]] = {}
+ for (dataset_id, group_id), group_stops in groups.items():
+ display = _best_display_stop(group_id, group_stops)
+ canonical, distance_m, confidence = _nearest_gtfs_canonical_from_grid(canonical_grid, display)
+ if canonical is None:
+ canonical = CanonicalStop(
+ stop_key=f"gtfs:{dataset_id}:{group_id}",
+ name=display.name or group_id,
+ normalized_name=norm_text(display.name or group_id),
+ lat=display.lat,
+ lon=display.lon,
+ metadata_json=json.dumps({"dataset_id": dataset_id, "group_id": group_id}, separators=(",", ":")),
+ )
+ _add_canonical_to_gtfs_grid(canonical_grid, canonical)
+ confidence = 1.0
+ distance_m = None
+ else:
+ _maybe_update_canonical_stop_display(canonical, display)
+ canonical_by_group[(dataset_id, group_id)] = canonical
+ link_quality_by_group[(dataset_id, group_id)] = (confidence, distance_m)
+ # Several groups can share one canonical stop; add each object only once.
+ unique_canonicals = list(dict.fromkeys(canonical_by_group.values()))
+ session.add_all(unique_canonicals)
+ # Flush so the canonical stops have ids before links reference them.
+ session.flush()
+ canonical_by_stop_key = {canonical.stop_key: canonical for canonical in unique_canonicals}
+ canonical_by_gtfs_stop = {
+ (stop.dataset_id, stop.stop_id): canonical_by_group[(dataset_id, group_id)]
+ for (dataset_id, group_id), group_stops in groups.items()
+ for stop in group_stops
+ }
+
+ link_objects: list[CanonicalStopLink] = []
+ for (dataset_id, group_id), group_stops in groups.items():
+ group_canonical = canonical_by_group[(dataset_id, group_id)]
+ confidence, distance_m = link_quality_by_group[(dataset_id, group_id)]
+ for stop in group_stops:
+ canonical = _canonical_for_gtfs_stop_link(
+ session=session,
+ stop=stop,
+ group_canonical=group_canonical,
+ overrides=overrides,
+ canonical_by_stop_key=canonical_by_stop_key,
+ canonical_by_gtfs_stop=canonical_by_gtfs_stop,
+ source_id_by_dataset=source_id_by_dataset,
+ )
+ role = "parent" if stop.stop_id == group_id and stop.parent_station is None else "platform"
+ metadata = None
+ if (stop.dataset_id, stop.stop_id) in overrides.link_by_stop:
+ metadata = json.dumps({"manual_rule": "link_canonical_stop"}, separators=(",", ":"))
+ elif (stop.dataset_id, stop.stop_id) in overrides.unlink_by_stop:
+ metadata = json.dumps({"manual_rule": "unlink_canonical_stop"}, separators=(",", ":"))
+ link_objects.append(
+ CanonicalStopLink(
+ canonical_stop_id=canonical.id,
+ layer="timetable",
+ object_type="gtfs_stop",
+ dataset_id=stop.dataset_id,
+ object_id=stop.id,
+ external_id=stop.stop_id,
+ role=role,
+ confidence=confidence,
+ distance_m=distance_m,
+ metadata_json=metadata,
+ )
+ )
+ if len(link_objects) >= 5000:
+ session.bulk_save_objects(link_objects)
+ link_objects.clear()
+ if link_objects:
+ session.bulk_save_objects(link_objects)
+ session.flush()
+ refresh_postgis_geometries(session, tables=["canonical_stops"])
+ analyze_postgresql_tables(session, ["canonical_stops", "canonical_stop_links"])
+ linked_stop_count = sum(len(group_stops) for group_stops in groups.values())
+ # canonical_by_stop_key also contains canonical stops created for manual rules.
+ return {"canonical_stops": len(canonical_by_stop_key), "canonical_stop_links": linked_stop_count}
+
+
+def _scheduled_gtfs_stops(session: Session, active_gtfs_dataset_ids: list[int]) -> list[GtfsStop]:
+    """Return the stops of the given datasets that appear in at least one stop time.
+
+    On PostgreSQL the filter is an ``EXISTS`` subquery. Otherwise the scheduled
+    stop ids are loaded per dataset and the stops are filtered in Python. Both
+    paths order the stops by dataset id, name and stop id.
+    """
+    stop_order = (GtfsStop.dataset_id, GtfsStop.name, GtfsStop.stop_id)
+    in_active_datasets = GtfsStop.dataset_id.in_(active_gtfs_dataset_ids)
+
+    if using_postgresql():
+        has_stop_time = (
+            select(GtfsStopTime.id)
+            .where(
+                GtfsStopTime.dataset_id == GtfsStop.dataset_id,
+                GtfsStopTime.stop_id == GtfsStop.stop_id,
+            )
+            .limit(1)
+            .exists()
+        )
+        scheduled_query = select(GtfsStop).where(in_active_datasets, has_stop_time).order_by(*stop_order)
+        return session.scalars(scheduled_query).all()
+
+    scheduled_by_dataset = {}
+    for dataset_id in active_gtfs_dataset_ids:
+        scheduled_by_dataset[dataset_id] = all_scheduled_stop_ids(session, dataset_id)
+    candidates = session.scalars(select(GtfsStop).where(in_active_datasets).order_by(*stop_order)).all()
+    scheduled = []
+    for stop in candidates:
+        if stop.stop_id in scheduled_by_dataset.get(stop.dataset_id, set()):
+            scheduled.append(stop)
+    return scheduled
+
+
+def _canonical_for_gtfs_stop_link(
+    *,
+    session: Session,
+    stop: GtfsStop,
+    group_canonical: CanonicalStop,
+    overrides: _CanonicalStopLinkOverrides,
+    canonical_by_stop_key: dict[str, CanonicalStop],
+    canonical_by_gtfs_stop: dict[tuple[int, str], CanonicalStop],
+    source_id_by_dataset: dict[int, int],
+) -> CanonicalStop:
+    """Pick the canonical stop that ``stop`` should be linked to.
+
+    An unlink rule wins: the stop gets its own manual canonical stop (or the
+    rule's ``target_stop_key``). A link rule resolves its target through the
+    referenced GTFS stops first, then through ``target_stop_key``, creating a
+    manual canonical stop when the key is unknown. Without a rule the group's
+    canonical stop is used.
+    """
+    key = (stop.dataset_id, stop.stop_id)
+
+    if key in overrides.unlink_by_stop:
+        unlink_action = overrides.unlink_by_stop[key]
+        own_stop_key = f"manual:gtfs_stop:{stop.dataset_id}:{stop.stop_id}"
+        return _manual_canonical_stop(
+            session=session,
+            stop=stop,
+            stop_key=str(unlink_action.get("target_stop_key") or own_stop_key),
+            action=unlink_action,
+            canonical_by_stop_key=canonical_by_stop_key,
+            metadata_type="manual_unlink",
+        )
+
+    if key not in overrides.link_by_stop:
+        return group_canonical
+
+    link_action = overrides.link_by_stop[key]
+    referenced = _canonical_from_target_gtfs_refs(link_action, canonical_by_gtfs_stop, source_id_by_dataset)
+    if referenced is not None:
+        return referenced
+    target_stop_key = str(link_action.get("target_stop_key") or group_canonical.stop_key)
+    known = canonical_by_stop_key.get(target_stop_key)
+    if known is not None:
+        return known
+    return _manual_canonical_stop(
+        session=session,
+        stop=stop,
+        stop_key=target_stop_key,
+        action=link_action,
+        canonical_by_stop_key=canonical_by_stop_key,
+        metadata_type="manual_link_target",
+    )
+
+
+def _canonical_from_target_gtfs_refs(
+    action: dict[str, object],
+    canonical_by_gtfs_stop: dict[tuple[int, str], CanonicalStop],
+    source_id_by_dataset: dict[int, int],
+) -> CanonicalStop | None:
+    """Resolve a rule's ``target_gtfs_stops`` references to a canonical stop.
+
+    References are tried in order. A reference names a stop through
+    ``external_id`` (or ``stop_id``) and may restrict it to a ``source_id``. The
+    first known GTFS stop that satisfies a reference decides the result.
+    Returns ``None`` when the list is missing or nothing matches.
+    """
+    refs = action.get("target_gtfs_stops")
+    if not isinstance(refs, list):
+        return None
+    for ref in refs:
+        if not isinstance(ref, dict):
+            continue
+        external_id = ref.get("external_id") or ref.get("stop_id")
+        if not external_id:
+            continue
+        wanted_stop_id = str(external_id)
+        wanted_source = ref.get("source_id")
+        for (dataset_id, stop_id), canonical in canonical_by_gtfs_stop.items():
+            if stop_id != wanted_stop_id:
+                continue
+            if wanted_source is None:
+                return canonical
+            try:
+                source_matches = source_id_by_dataset.get(dataset_id) == int(wanted_source)
+            except (TypeError, ValueError):
+                # An unusable source id can never match.
+                continue
+            if source_matches:
+                return canonical
+    return None
+
+
+def _manual_canonical_stop(
+    *,
+    session: Session,
+    stop: GtfsStop,
+    stop_key: str,
+    action: dict[str, object],
+    canonical_by_stop_key: dict[str, CanonicalStop],
+    metadata_type: str,
+) -> CanonicalStop:
+    """Return the canonical stop registered under ``stop_key``, creating it if needed.
+
+    A new stop takes its name, position and mode from the rule ``action``
+    (``target_name``, ``target_lat``, ``target_lon``, ``target_mode``) and falls
+    back to the GTFS stop's name and position. It is added to the session,
+    flushed, and registered in ``canonical_by_stop_key``.
+    """
+    existing = canonical_by_stop_key.get(stop_key)
+    if existing is not None:
+        return existing
+
+    display_name = str(action.get("target_name") or stop.name or stop.stop_id)
+    mode_text = str(action.get("target_mode") or "")
+    provenance = {
+        "source": metadata_type,
+        "dataset_id": stop.dataset_id,
+        "stop_id": stop.stop_id,
+    }
+    created = CanonicalStop(
+        stop_key=stop_key,
+        name=display_name,
+        normalized_name=norm_text(display_name),
+        lat=_float_or_default(action.get("target_lat"), stop.lat),
+        lon=_float_or_default(action.get("target_lon"), stop.lon),
+        mode=mode_text or None,
+        metadata_json=json.dumps(provenance, separators=(",", ":")),
+    )
+    session.add(created)
+    session.flush()
+    canonical_by_stop_key[stop_key] = created
+    return created
+
+
+def _canonical_stop_link_overrides(session: Session) -> _CanonicalStopLinkOverrides:
+    """Load the active manual link/unlink rules for GTFS stops.
+
+    Rules are applied in id order. For each stop the latest rule wins, and a
+    stop is never present in both result maps at once.
+    """
+    active_dataset_ids_by_source: dict[int, list[int]] = {}
+    active_gtfs_query = select(Dataset.source_id, Dataset.id).where(Dataset.is_active.is_(True), Dataset.kind == "gtfs")
+    for source_id, dataset_id in session.execute(active_gtfs_query).all():
+        active_dataset_ids_by_source.setdefault(int(source_id), []).append(int(dataset_id))
+
+    rule_query = (
+        select(MatchRule)
+        .where(
+            MatchRule.active.is_(True),
+            MatchRule.rule_type.in_(["link_canonical_stop", "unlink_canonical_stop"]),
+        )
+        .order_by(MatchRule.id)
+    )
+    link_by_stop: dict[tuple[int, str], dict[str, object]] = {}
+    unlink_by_stop: dict[tuple[int, str], dict[str, object]] = {}
+    for rule in session.scalars(rule_query).all():
+        selector = _json_dict(rule.selector_json)
+        action = _json_dict(rule.action_json)
+        if rule.rule_type == "link_canonical_stop":
+            chosen, other = link_by_stop, unlink_by_stop
+        elif rule.rule_type == "unlink_canonical_stop":
+            chosen, other = unlink_by_stop, link_by_stop
+        else:
+            chosen = other = None
+        for key in _gtfs_stop_rule_keys(selector, active_dataset_ids_by_source):
+            if chosen is None:
+                continue
+            chosen[key] = action
+            other.pop(key, None)
+    return _CanonicalStopLinkOverrides(link_by_stop=link_by_stop, unlink_by_stop=unlink_by_stop)
+
+
+def _gtfs_stop_rule_keys(
+    selector: dict[str, object],
+    active_dataset_ids_by_source: dict[int, list[int]],
+) -> list[tuple[int, str]]:
+    """Translate a rule selector into ``(dataset_id, stop_id)`` keys.
+
+    Fields are read from the selector itself and, when absent there, from its
+    nested ``gtfs_stop`` dict. A ``dataset_id`` yields one key; a ``source_id``
+    yields one key per active dataset of that source. Duplicates are removed,
+    keeping first occurrence order. Returns an empty list for a selector of
+    another object type, one without a stop id, or one with non-numeric ids.
+    """
+    if selector.get("object_type") not in {None, "gtfs_stop"}:
+        return []
+    nested = selector.get("gtfs_stop")
+    fallback = nested if isinstance(nested, dict) else {}
+
+    external_id = selector.get("external_id", fallback.get("external_id", fallback.get("stop_id")))
+    if external_id is None:
+        return []
+    stop_id = str(external_id)
+    dataset_id = selector.get("dataset_id", fallback.get("dataset_id"))
+    source_id = selector.get("source_id", fallback.get("source_id"))
+
+    found: list[tuple[int, str]] = []
+    try:
+        if dataset_id is not None:
+            found.append((int(dataset_id), stop_id))
+        if source_id is not None:
+            for active_dataset_id in active_dataset_ids_by_source.get(int(source_id), []):
+                found.append((active_dataset_id, stop_id))
+    except (TypeError, ValueError):
+        return []
+    return list(dict.fromkeys(found))
+
+
+def _json_dict(value: str | None) -> dict[str, object]:
+    """Decode a JSON object, falling back to an empty dict on any bad input.
+
+    Args:
+        value: JSON text, or ``None``/empty for "no data".
+
+    Returns:
+        The decoded object when ``value`` is a JSON object; ``{}`` when it is
+        empty, malformed, not text, or decodes to something other than an object.
+    """
+    try:
+        data = json.loads(value or "{}")
+    except (TypeError, ValueError):
+        # ValueError covers json.JSONDecodeError as well as the
+        # UnicodeDecodeError raised for undecodable bytes; TypeError covers
+        # values that are not str/bytes at all.
+        return {}
+    return data if isinstance(data, dict) else {}
+
+
+def _float_or_default(value: object, default: float | None) -> float | None:
+    """Convert ``value`` to ``float``, returning ``default`` when that is impossible.
+
+    Args:
+        value: Any object; ``None`` means "not provided".
+        default: Value returned for ``None`` and for unconvertible input.
+
+    Returns:
+        ``float(value)`` or ``default``.
+    """
+    if value is None:
+        return default
+    try:
+        return float(value)
+    except (TypeError, ValueError, OverflowError):
+        # OverflowError: integers too large for a float (e.g. 10**400).
+        return default
+
+
+def _nearest_gtfs_canonical_from_grid(
+    grid: dict[tuple[int, int], list[CanonicalStop]], display: GtfsStop
+) -> tuple[CanonicalStop | None, float | None, float]:
+    """Find the best canonical stop for a GTFS stop using the spatial grid.
+
+    A candidate is eligible only when its name resembles the GTFS stop name;
+    the allowed distance depends on how strong that resemblance is (exact
+    token match, similarity >= 0.5, similarity >= 0.25). Eligible candidates
+    are scored as ``0.62 * distance_score + 0.38 * name_similarity`` and the
+    first candidate with the highest score wins.
+
+    Args:
+        grid: Canonical stops bucketed by ``_gtfs_grid_cell``.
+        display: GTFS stop to link.
+
+    Returns:
+        ``(canonical, distance_m, score)`` for the best candidate, where the
+        distance is the planar degree distance multiplied by 111 320 and
+        rounded to 0.1, and the score is rounded to 3 decimals; or
+        ``(None, None, 0.0)`` when the stop has no coordinates or nothing
+        qualifies.
+    """
+    import math
+
+    if display.lon is None or display.lat is None:
+        return None, None, 0.0
+    normalized_name = norm_text(display.name or display.stop_id)
+    # Loop-invariant values: the stop's own match key and point.
+    display_key = _stop_match_key(normalized_name)
+    display_point = Point(display.lon, display.lat)
+    cell_x, cell_y = _gtfs_grid_cell(display.lon, display.lat)
+    # Grid cells are GTFS_STOP_EXACT_NAME_LINK_RADIUS_DEG wide. Search as many
+    # neighbouring cells as the widest link radius requires, so no eligible
+    # candidate is missed if another radius is ever configured larger than
+    # the cell size. With the exact-name radius being the largest this is the
+    # usual 3x3 neighbourhood.
+    widest_radius = max(
+        GTFS_STOP_EXACT_NAME_LINK_RADIUS_DEG,
+        GTFS_STOP_NAME_LINK_RADIUS_DEG,
+        GTFS_STOP_PARTIAL_NAME_LINK_RADIUS_DEG,
+    )
+    reach = max(1, math.ceil(widest_radius / GTFS_STOP_EXACT_NAME_LINK_RADIUS_DEG))
+    offsets = range(-reach, reach + 1)
+    candidates = [
+        stop
+        for dx in offsets
+        for dy in offsets
+        for stop in grid.get((cell_x + dx, cell_y + dy), [])
+    ]
+    best = None
+    best_score = -1.0
+    for candidate in candidates:
+        if candidate.lon is None or candidate.lat is None:
+            continue
+        distance_deg = Point(candidate.lon, candidate.lat).distance(display_point)
+        name_overlap = _stop_name_similarity(normalized_name, candidate.normalized_name)
+        exact_name = bool(display_key) and display_key == _stop_match_key(candidate.normalized_name)
+        if exact_name:
+            max_radius = GTFS_STOP_EXACT_NAME_LINK_RADIUS_DEG
+        elif name_overlap >= 0.5:
+            max_radius = GTFS_STOP_NAME_LINK_RADIUS_DEG
+        elif name_overlap >= 0.25:
+            max_radius = GTFS_STOP_PARTIAL_NAME_LINK_RADIUS_DEG
+        else:
+            # Names too different: never link on proximity alone.
+            continue
+        if distance_deg > max_radius:
+            continue
+        distance_score = max(0.0, 1.0 - (distance_deg / max_radius))
+        score = distance_score * 0.62 + name_overlap * 0.38
+        if score > best_score:
+            best = (candidate, round(distance_deg * 111_320, 1), round(score, 3))
+            best_score = score
+    if best is None:
+        return None, None, 0.0
+    return best
+
+
+def _stop_name_similarity(left: str, right: str) -> float:
+    """Return the Jaccard similarity of the match tokens of two stop names.
+
+    The result is 0.0 when either name yields no tokens and 1.0 when both
+    token sets are identical.
+    """
+    tokens_a = _stop_match_tokens(left)
+    tokens_b = _stop_match_tokens(right)
+    if not (tokens_a and tokens_b):
+        return 0.0
+    if tokens_a == tokens_b:
+        return 1.0
+    shared = tokens_a & tokens_b
+    combined = tokens_a | tokens_b
+    return len(shared) / len(combined)
+
+
+def _stop_match_key(value: str) -> str:
+    """Return the match tokens of ``value`` sorted and joined by single spaces."""
+    ordered_tokens = sorted(_stop_match_tokens(value))
+    return " ".join(ordered_tokens)
+
+
+def _stop_match_tokens(value: str) -> set[str]:
+    """Tokenise a stop name for matching.
+
+    The normalised name is split on whitespace and tokens listed in
+    ``STOP_MATCH_NOISE_TOKENS`` are dropped. Names that denote a main station
+    ("hauptbahnhof", "hbf", or "central"/"main" together with "station" and
+    without "bus") have those words replaced by the single token
+    ``"mainstation"`` so that the different spellings compare equal.
+    """
+    tokens = set(norm_text(value).split())
+    if not tokens:
+        return set()
+
+    without_bus = "bus" not in tokens
+    is_main_station = (
+        "hauptbahnhof" in tokens
+        or "hbf" in tokens
+        or (tokens >= {"central", "station"} and without_bus)
+        or (tokens >= {"main", "station"} and without_bus)
+    )
+    kept = {word for word in tokens if word not in STOP_MATCH_NOISE_TOKENS}
+    if not is_main_station:
+        return kept
+    kept -= {"hauptbahnhof", "hbf", "central", "main", "station"}
+    kept.add("mainstation")
+    return kept
+
+
+def _maybe_update_canonical_stop_display(canonical: CanonicalStop, display: GtfsStop) -> None:
+    """Adopt the GTFS stop's name on the canonical stop when it scores strictly better.
+
+    Both ``name`` and ``normalized_name`` of ``canonical`` are replaced; nothing
+    changes when the new name is of equal or lower quality.
+    """
+    candidate_name = display.name or display.stop_id
+    if _stop_display_name_quality(candidate_name) > _stop_display_name_quality(canonical.name):
+        canonical.name = candidate_name
+        canonical.normalized_name = norm_text(candidate_name)
+
+
+def _stop_display_name_quality(name: str | None) -> int:
+    """Score how suitable a stop name is for display (higher is better).
+
+    An empty name scores 0. Any other name starts at 100 and is adjusted:
+    -35 for "flixtrain"/"flixbus", -5 for "central" together with "station",
+    +8 for "hauptbahnhof"/"hbf" and +1 for "berlin".
+    """
+    normalized = norm_text(name or "")
+    if not normalized:
+        return 0
+    tokens = set(normalized.split())
+    adjustments = (
+        (-35, bool({"flixtrain", "flixbus"} & tokens)),
+        (-5, "central" in tokens and "station" in tokens),
+        (8, "hauptbahnhof" in tokens or "hbf" in tokens),
+        (1, "berlin" in tokens),
+    )
+    return 100 + sum(delta for delta, applies in adjustments if applies)
+
+
+def _add_canonical_to_gtfs_grid(grid: dict[tuple[int, int], list[CanonicalStop]], canonical: CanonicalStop) -> None:
+    """Register ``canonical`` in its GTFS grid cell; stops without coordinates are ignored."""
+    lon, lat = canonical.lon, canonical.lat
+    if lon is None or lat is None:
+        return
+    cell = _gtfs_grid_cell(lon, lat)
+    grid.setdefault(cell, []).append(canonical)
+
+
+def _gtfs_grid_cell(lon: float, lat: float) -> tuple[int, int]:
+    """Map a coordinate to its GTFS grid cell.
+
+    Cells are ``GTFS_STOP_EXACT_NAME_LINK_RADIUS_DEG`` wide; ``int()`` truncates
+    toward zero.
+    """
+    cell_size = GTFS_STOP_EXACT_NAME_LINK_RADIUS_DEG
+    return int(lon / cell_size), int(lat / cell_size)
+
+
+def _link_osm_stops(
+    session: Session,
+    *,
+    progress_callback: ProgressCallback | None = None,
+    commit_batches: bool = False,
+) -> dict[str, int]:
+    """Link OSM stop/station/terminal features of active OSM datasets to canonical stops.
+
+    When PostgreSQL is in use, no active dataset is a sidecar dataset and
+    ``settings.osm_sidecar_create_visual_only_stops`` is off, the work is
+    delegated to ``_link_osm_stops_postgis``. Otherwise features are paged
+    through in Python and matched against an in-memory grid of canonical stops.
+
+    A feature without a matching canonical stop gets a new "visual only"
+    canonical stop (keyed ``osm:<dataset_id>:<feature_id>``) linked with
+    confidence 1.0, unless it comes from a sidecar dataset while the
+    visual-only setting is off, in which case it is skipped.
+
+    Args:
+        session: Database session used for all reads and writes.
+        progress_callback: Optional callback passed to ``_emit_progress``
+            after every processed page.
+        commit_batches: Passed to ``_commit_or_flush`` after each stored batch.
+
+    Returns:
+        ``{"canonical_stop_links": <number of link rows stored>}``.
+    """
+    active_osm_dataset_ids = _active_dataset_ids(session, "osm_geojson")
+    if not active_osm_dataset_ids:
+        return {"canonical_stop_links": 0}
+    sidecar_dataset_ids = {
+        dataset.id
+        for dataset in session.scalars(select(Dataset).where(Dataset.id.in_(active_osm_dataset_ids))).all()
+        if features_are_sidecar(dataset)
+    }
+    if using_postgresql() and not sidecar_dataset_ids and not settings.osm_sidecar_create_visual_only_stops:
+        return _link_osm_stops_postgis(
+            session,
+            active_osm_dataset_ids,
+            progress_callback=progress_callback,
+            commit_batches=commit_batches,
+        )
+    canonical_grid = _canonical_stop_grid(session)
+    # Links waiting to be written, and new canonical stops whose links can only
+    # be built once the stops have received a primary key.
+    link_objects: list[CanonicalStopLink] = []
+    visual_only: list[tuple[OsmFeature, CanonicalStop, Point]] = []
+    link_count = 0
+    total_features = sum(osm_feature_count(session, dataset_id, kind=["stop", "station", "terminal"]) for dataset_id in active_osm_dataset_ids)
+    processed = 0
+    # Pages never hold fewer than 100 features, whatever the setting says.
+    batch_size = max(100, int(settings.route_layer_osm_stop_batch_size))
+
+    def flush_links() -> None:
+        """Persist pending visual-only stops and all pending links, then reset the buffers."""
+        nonlocal link_count
+        if visual_only:
+            session.add_all([canonical for _, canonical, _ in visual_only])
+            # Flush first so that canonical.id is populated for the links below.
+            session.flush()
+            for feature, canonical, _ in visual_only:
+                link_objects.append(
+                    CanonicalStopLink(
+                        canonical_stop_id=canonical.id,
+                        layer="visual",
+                        object_type="osm_feature",
+                        dataset_id=feature.dataset_id,
+                        object_id=feature.id,
+                        external_id=f"{feature.osm_type}:{feature.osm_id}",
+                        role=feature.kind,
+                        confidence=1.0,
+                        distance_m=None,
+                    )
+                )
+            visual_only.clear()
+        if not link_objects:
+            return
+        for chunk in _chunks_objects(link_objects, 5000):
+            session.bulk_save_objects(chunk)
+        link_count += len(link_objects)
+        link_objects.clear()
+        _commit_or_flush(session, commit_batches)
+
+    for dataset_id in active_osm_dataset_ids:
+        offset = 0
+        while True:
+            features = query_osm_features(
+                session,
+                [dataset_id],
+                kinds=["stop", "station", "terminal"],
+                geometry_required=True,
+                limit=batch_size,
+                offset=offset,
+            )
+            if not features:
+                break
+            for feature in features:
+                point = _representative_point(feature.geometry_geojson)
+                if point is None:
+                    # Unusable geometry: the feature cannot be placed.
+                    continue
+                canonical, distance_m, confidence = _nearest_canonical_stop_from_grid(canonical_grid, feature, point)
+                if canonical is None:
+                    if feature.dataset_id in sidecar_dataset_ids and not settings.osm_sidecar_create_visual_only_stops:
+                        continue
+                    feature = ensure_main_osm_feature(session, feature)
+                    canonical = CanonicalStop(
+                        stop_key=f"osm:{feature.dataset_id}:{feature.id}",
+                        name=feature.name or feature.ref or f"OSM {feature.osm_type} {feature.osm_id}",
+                        normalized_name=norm_text(feature.name or feature.ref or feature.osm_id),
+                        lat=point.y,
+                        lon=point.x,
+                        mode=feature.mode,
+                        metadata_json=json.dumps({"osm_feature_id": feature.id}, separators=(",", ":")),
+                    )
+                    # The link is created in flush_links() once the stop has an id.
+                    # NOTE(review): the new stop is not added to canonical_grid, so later
+                    # features near it cannot match it during this run - confirm intended.
+                    visual_only.append((feature, canonical, point))
+                    continue
+                feature = ensure_main_osm_feature(session, feature)
+                link_objects.append(
+                    CanonicalStopLink(
+                        canonical_stop_id=canonical.id,
+                        layer="visual",
+                        object_type="osm_feature",
+                        dataset_id=feature.dataset_id,
+                        object_id=feature.id,
+                        external_id=f"{feature.osm_type}:{feature.osm_id}",
+                        role=feature.kind,
+                        confidence=confidence,
+                        distance_m=distance_m,
+                    )
+                )
+            processed += len(features)
+            offset += len(features)
+            flush_links()
+            _emit_progress(
+                progress_callback,
+                "route_layer_osm_stop_batch",
+                f"Linked OSM stops for dataset #{dataset_id}.",
+                processed,
+                total_features or None,
+                {"dataset_id": dataset_id, "processed": processed, "links": link_count},
+            )
+            # A short page means the dataset is exhausted.
+            if len(features) < batch_size:
+                break
+    flush_links()
+    session.flush()
+    return {"canonical_stop_links": link_count}
+
+
+def _link_osm_stops_postgis(
+    session: Session,
+    active_osm_dataset_ids: list[int],
+    *,
+    progress_callback: ProgressCallback | None,
+    commit_batches: bool,
+) -> dict[str, int]:
+    """Link OSM stop features to canonical stops with one PostGIS spatial join.
+
+    For every stop/station/terminal feature of the given datasets, up to 12
+    nearest canonical stops within ``OSM_STOP_NAME_LINK_RADIUS_DEG`` are scored.
+    A candidate is eligible when it lies within ``OSM_STOP_LINK_RADIUS_DEG``,
+    or within the name radius with a trigram name score of at least 0.25.
+    Among the eligible candidates the one with the lowest
+    ``distance * 111320 - name_score * 120`` is linked (ties: lowest id).
+    Existing links are left untouched (``ON CONFLICT ... DO NOTHING``).
+
+    Args:
+        session: Database session; must be connected to PostgreSQL/PostGIS.
+        active_osm_dataset_ids: Ids of the OSM datasets to link.
+        progress_callback: Optional callback passed to ``_emit_progress``.
+        commit_batches: Passed to ``_commit_or_flush`` after the insert.
+
+    Returns:
+        ``{"canonical_stop_links": n}`` where ``n`` is the number of visual
+        OSM-feature links that exist for the datasets after the insert.
+    """
+    if not active_osm_dataset_ids:
+        # An empty list would render as "IN ()", which is invalid SQL.
+        return {"canonical_stop_links": 0}
+    refresh_postgis_geometries(session, tables=["canonical_stops", "osm_features"])
+    # Ids pass through int() before being inlined, so the interpolated
+    # IN (...) list can only ever contain integer literals.
+    dataset_sql = ", ".join(str(int(dataset_id)) for dataset_id in active_osm_dataset_ids)
+    total_features = sum(osm_feature_count(session, dataset_id, kind=["stop", "station", "terminal"]) for dataset_id in active_osm_dataset_ids)
+    _emit_progress(
+        progress_callback,
+        "route_layer_osm_stop_postgis_started",
+        "Linking OSM stops with PostGIS spatial join.",
+        0,
+        total_features or None,
+        {"datasets": active_osm_dataset_ids},
+    )
+    params = {
+        "base_radius_deg": OSM_STOP_LINK_RADIUS_DEG,
+        "name_radius_deg": OSM_STOP_NAME_LINK_RADIUS_DEG,
+        "name_threshold": 0.25,
+    }
+    # The eligibility filter is applied BEFORE ranking. Ranking first and
+    # filtering rn = 1 afterwards would drop a feature entirely whenever its
+    # top-ranked candidate is ineligible although a lower-ranked one qualifies.
+    session.execute(
+        text(
+            f"""
+            WITH scored AS (
+                SELECT
+                    o.dataset_id,
+                    o.id AS osm_feature_id,
+                    o.osm_type,
+                    o.osm_id,
+                    o.kind,
+                    c.id AS canonical_stop_id,
+                    ST_Distance(o.geom, c.geom) AS distance_deg,
+                    ST_Distance(o.geom::geography, c.geom::geography) AS distance_m,
+                    GREATEST(
+                        similarity(LOWER(COALESCE(o.name, '')), LOWER(COALESCE(c.normalized_name, ''))),
+                        similarity(LOWER(COALESCE(o.ref, '')), LOWER(COALESCE(c.normalized_name, '')))
+                    ) AS name_score
+                FROM osm_features AS o
+                JOIN LATERAL (
+                    SELECT candidate.*
+                    FROM canonical_stops AS candidate
+                    WHERE candidate.geom IS NOT NULL
+                      AND candidate.geom && ST_Expand(o.geom, :name_radius_deg)
+                      AND ST_DWithin(candidate.geom, o.geom, :name_radius_deg)
+                    ORDER BY o.geom <-> candidate.geom
+                    LIMIT 12
+                ) AS c ON TRUE
+                WHERE o.dataset_id IN ({dataset_sql})
+                  AND o.kind IN ('stop', 'station', 'terminal')
+                  AND o.geom IS NOT NULL
+            ),
+            ranked AS (
+                SELECT
+                    scored.*,
+                    ROW_NUMBER() OVER (
+                        PARTITION BY scored.dataset_id, scored.osm_feature_id
+                        ORDER BY
+                            (scored.distance_deg * 111320.0) - (scored.name_score * 120.0),
+                            scored.canonical_stop_id
+                    ) AS rn
+                FROM scored
+                WHERE scored.distance_deg <= :base_radius_deg
+                   OR (scored.name_score >= :name_threshold AND scored.distance_deg <= :name_radius_deg)
+            )
+            INSERT INTO canonical_stop_links
+                (canonical_stop_id, layer, object_type, dataset_id, object_id, external_id, role, confidence, distance_m)
+            SELECT
+                canonical_stop_id,
+                'visual',
+                'osm_feature',
+                dataset_id,
+                osm_feature_id,
+                osm_type || ':' || osm_id,
+                kind,
+                ROUND(
+                    LEAST(
+                        1.0::double precision,
+                        GREATEST(
+                            0.0::double precision,
+                            (
+                                1.0
+                                - distance_deg
+                                / CASE WHEN name_score >= :name_threshold THEN :name_radius_deg ELSE :base_radius_deg END
+                            ) * 0.6
+                            + name_score * 0.4
+                        )
+                    )::numeric,
+                    3
+                )::double precision,
+                ROUND(distance_m::numeric, 1)::double precision
+            FROM ranked
+            WHERE rn = 1
+            ON CONFLICT ON CONSTRAINT uq_canonical_stop_link_object DO NOTHING
+            """
+        ),
+        params,
+    )
+    _commit_or_flush(session, commit_batches)
+    # Count what exists rather than what was inserted: conflicting rows are
+    # skipped by the insert but still represent links for these datasets.
+    link_count = int(
+        session.scalar(
+            text(
+                f"""
+                SELECT COUNT(*)
+                FROM canonical_stop_links
+                WHERE layer = 'visual'
+                  AND object_type = 'osm_feature'
+                  AND dataset_id IN ({dataset_sql})
+                """
+            )
+        )
+        or 0
+    )
+    analyze_postgresql_tables(session, ["canonical_stop_links"])
+    _emit_progress(
+        progress_callback,
+        "route_layer_osm_stop_postgis_completed",
+        "Linked OSM stops with PostGIS spatial join.",
+        total_features,
+        total_features or None,
+        {"datasets": active_osm_dataset_ids, "links": link_count},
+    )
+    return {"canonical_stop_links": link_count}
+
+
+def _build_route_patterns(
+    session: Session,
+    *,
+    progress_callback: ProgressCallback | None = None,
+) -> dict[str, int]:
+    """Materialise route patterns, their GTFS links and their stop sequences.
+
+    Every GTFS pattern seed (route + shape) that has geometry is matched against
+    the indexed OSM route candidates. A matched seed reuses or creates the
+    pattern of that OSM feature (``source_kind="osm"``, ``status="active"``);
+    an unmatched seed gets a pattern built from its own GTFS geometry
+    (``source_kind="gtfs_proposed"``, ``status="needs_visual_review"``,
+    confidence 0.0). Patterns that exist in the database but were not
+    produced by this run are deleted.
+
+    Afterwards one ``GtfsRoutePatternLink`` per seed is stored, stop rows are
+    generated once per pattern (from the first seed that produced it) and
+    the trip-level links are rebuilt.
+
+    Args:
+        session: Database session used for all reads and writes.
+        progress_callback: Optional callback passed to ``_emit_progress``.
+
+    Returns:
+        Counters: ``route_patterns``, ``route_patterns_created``,
+        ``route_patterns_updated``, ``route_patterns_reused``,
+        ``route_patterns_removed``, ``route_pattern_links``,
+        ``trip_pattern_links``, ``route_pattern_stops`` and
+        ``gtfs_proposed_patterns``.
+    """
+    osm_candidates = _osm_route_candidates(session, progress_callback=progress_callback)
+    overrides = _route_layer_overrides(session)
+    seeds = _gtfs_pattern_seeds(session)
+    _emit_progress(
+        progress_callback,
+        "route_layer_pattern_seeds",
+        f"Loaded {len(seeds)} GTFS route-pattern seeds.",
+        0,
+        len(seeds),
+        {"seeds": len(seeds)},
+    )
+    link_count = 0
+    stop_count = 0
+    proposed_count = 0
+    existing_patterns_by_key = {
+        pattern.pattern_key: pattern
+        for pattern in session.scalars(select(RoutePattern).order_by(RoutePattern.id)).all()
+    }
+    patterns_by_key: dict[str, RoutePattern] = {}
+    pattern_usage: dict[str, int] = {}
+    pattern_confidence_by_key: dict[str, float] = {}
+    created_pattern_count = 0
+    updated_pattern_keys: set[str] = set()
+    pending: list[_PatternBuildItem] = []
+
+    def materialize_pattern(pattern_key: str, fields: dict[str, object], initial_confidence: float) -> RoutePattern:
+        """Create the pattern for ``pattern_key`` or refresh the one already known."""
+        nonlocal created_pattern_count
+        pattern = patterns_by_key.get(pattern_key) or existing_patterns_by_key.get(pattern_key)
+        if pattern is None:
+            pattern = RoutePattern(pattern_key=pattern_key, confidence=initial_confidence, **fields)
+            session.add(pattern)
+            created_pattern_count += 1
+        elif _update_route_pattern(pattern, **fields) and pattern_key in existing_patterns_by_key:
+            # Only patterns that existed before this run count as "updated".
+            # A pattern created by an earlier seed of this run is already
+            # counted as created; counting it twice would push
+            # route_patterns_reused below its true value.
+            updated_pattern_keys.add(pattern_key)
+        patterns_by_key[pattern_key] = pattern
+        return pattern
+
+    for index, seed in enumerate(seeds, start=1):
+        if not seed.geometry_text:
+            continue
+        shape_key = seed.shape_id or GTFS_ROUTE_PATTERN_NULL_SHAPE
+        chosen, score, reasons = _choose_osm_candidate(seed, osm_candidates, overrides)
+        if chosen is not None:
+            chosen_feature = ensure_main_osm_feature(session, chosen.feature)
+            pattern_key = _osm_pattern_key(chosen_feature)
+            source_kind = "osm"
+            status = "active"
+            confidence = score
+            bbox = chosen.bbox
+            pattern = materialize_pattern(
+                pattern_key,
+                {
+                    "route_ref": chosen_feature.ref or seed.route.short_name or seed.route.route_id,
+                    "route_name": chosen_feature.name or seed.route.long_name,
+                    "mode": chosen_feature.mode or seed.route.mode,
+                    "route_scope": chosen_feature.route_scope
+                    or infer_osm_route_scope_from_tags(
+                        chosen_feature.mode,
+                        chosen_feature.ref,
+                        chosen_feature.name,
+                        chosen_feature.network,
+                        chosen_feature.tags_json,
+                    ),
+                    "operator_name": chosen_feature.operator or seed.route.operator_name,
+                    "source_kind": source_kind,
+                    "status": status,
+                    "osm_feature_id": chosen_feature.id,
+                    "gtfs_route_id": seed.route.id,
+                    "gtfs_shape_id": None,
+                    "geometry_geojson": chosen.geometry_text,
+                    "min_lon": bbox[0],
+                    "min_lat": bbox[1],
+                    "max_lon": bbox[2],
+                    "max_lat": bbox[3],
+                    "metadata_json": json.dumps(
+                        {
+                            "version": ROUTE_LAYER_VERSION,
+                            "visual_source": "osm_feature",
+                            "osm_feature_id": chosen_feature.id,
+                            "osm_type": chosen_feature.osm_type,
+                            "osm_id": chosen_feature.osm_id,
+                        },
+                        separators=(",", ":"),
+                    ),
+                },
+                confidence,
+            )
+            # Several seeds may share one OSM pattern: keep the best score seen.
+            next_confidence = max(pattern_confidence_by_key.get(pattern_key, confidence), confidence)
+            pattern_confidence_by_key[pattern_key] = next_confidence
+            if pattern_key in existing_patterns_by_key and float(pattern.confidence or 0) != float(next_confidence):
+                updated_pattern_keys.add(pattern_key)
+            pattern.confidence = next_confidence
+            link_reasons = _link_reasons(seed, chosen, reasons)
+        else:
+            pattern_key = f"gtfs:{seed.route.dataset_id}:{seed.route.route_id}:{shape_key}"
+            source_kind = "gtfs_proposed"
+            status = "needs_visual_review"
+            confidence = 0.0
+            proposed_count += 1
+            pattern = materialize_pattern(
+                pattern_key,
+                {
+                    "route_ref": seed.route.short_name or seed.route.route_id,
+                    "route_name": seed.route.long_name,
+                    "mode": seed.route.mode,
+                    "route_scope": seed.route.route_scope,
+                    "operator_name": seed.route.operator_name,
+                    "source_kind": source_kind,
+                    "status": status,
+                    "osm_feature_id": None,
+                    "gtfs_route_id": seed.route.id,
+                    "gtfs_shape_id": seed.shape_id,
+                    "geometry_geojson": seed.geometry_text,
+                    "min_lon": seed.bbox[0],
+                    "min_lat": seed.bbox[1],
+                    "max_lon": seed.bbox[2],
+                    "max_lat": seed.bbox[3],
+                    "metadata_json": json.dumps(
+                        {
+                            "version": ROUTE_LAYER_VERSION,
+                            "visual_source": "gtfs_shape",
+                            "gtfs_geometry_source": seed.geometry_source,
+                            "match_reasons": reasons,
+                        },
+                        separators=(",", ":"),
+                    ),
+                },
+                confidence,
+            )
+            pattern_confidence_by_key[pattern_key] = confidence
+            if pattern_key in existing_patterns_by_key and float(pattern.confidence or 0) != float(confidence):
+                updated_pattern_keys.add(pattern_key)
+            pattern.confidence = confidence
+            link_reasons = reasons
+        pattern_usage[pattern_key] = pattern_usage.get(pattern_key, 0) + 1
+        pending.append(
+            _PatternBuildItem(
+                seed=seed,
+                pattern=pattern,
+                confidence=confidence,
+                source_kind=source_kind,
+                status=status,
+                reasons=link_reasons,
+            )
+        )
+        if index % 500 == 0:
+            session.flush()
+            _emit_progress(
+                progress_callback,
+                "route_layer_pattern_batch",
+                f"Built {index}/{len(seeds)} route-pattern candidates.",
+                index,
+                len(seeds),
+                {"patterns": len(patterns_by_key), "links_pending": len(pending), "gtfs_proposed_patterns": proposed_count},
+            )
+    # Flush so that every new pattern has its primary key before it is referenced.
+    session.flush()
+    obsolete_pattern_ids = [
+        pattern.id
+        for pattern_key, pattern in existing_patterns_by_key.items()
+        if pattern_key not in patterns_by_key and pattern.id is not None
+    ]
+    for chunk in _chunks_objects(obsolete_pattern_ids, 1000):
+        session.execute(delete(RoutePattern).where(RoutePattern.id.in_(chunk)))
+    if obsolete_pattern_ids:
+        session.flush()
+    refresh_postgis_geometries(session, tables=["route_patterns"])
+    analyze_postgresql_tables(session, ["route_patterns"])
+    _emit_progress(
+        progress_callback,
+        "route_layer_patterns_materialized",
+        "Materialized route-pattern rows.",
+        len(seeds),
+        len(seeds),
+        {
+            "route_patterns": len(patterns_by_key),
+            "route_patterns_created": created_pattern_count,
+            "route_patterns_updated": len(updated_pattern_keys),
+            "route_patterns_reused": max(0, len(patterns_by_key) - created_pattern_count - len(updated_pattern_keys)),
+            "route_patterns_removed": len(obsolete_pattern_ids),
+            "gtfs_proposed_patterns": proposed_count,
+        },
+    )
+
+    for pattern_key, count in pattern_usage.items():
+        _update_pattern_metadata(patterns_by_key[pattern_key], linked_gtfs_patterns=count)
+
+    link_objects: list[GtfsRoutePatternLink] = []
+    for item in pending:
+        seed = item.seed
+        shape_key = seed.shape_id or GTFS_ROUTE_PATTERN_NULL_SHAPE
+        link_objects.append(
+            GtfsRoutePatternLink(
+                dataset_id=seed.route.dataset_id,
+                gtfs_route_id=seed.route.id,
+                route_id=seed.route.route_id,
+                shape_id=shape_key,
+                route_pattern_id=item.pattern.id,
+                confidence=item.confidence,
+                status=item.status,
+                source_kind=item.source_kind,
+                reasons_json=json.dumps(item.reasons, separators=(",", ":")),
+            )
+        )
+        link_count += 1
+    for chunk in _chunks_objects(link_objects, 5000):
+        session.bulk_save_objects(chunk)
+    _emit_progress(
+        progress_callback,
+        "route_layer_pattern_links",
+        "Stored GTFS route-pattern links.",
+        link_count,
+        link_count,
+        {"route_pattern_links": link_count},
+    )
+
+    stop_times_by_trip = _representative_stop_times(session, pending)
+    canonical_lookup = _canonical_link_lookup(session, stop_times_by_trip)
+    stop_objects: list[RoutePatternStop] = []
+    # The first seed that produced a pattern supplies its stop sequence.
+    representative_stop_items: dict[int, _PatternBuildItem] = {}
+    for item in pending:
+        if item.pattern.id is not None:
+            representative_stop_items.setdefault(item.pattern.id, item)
+    for item in representative_stop_items.values():
+        seed = item.seed
+        objects = _route_pattern_stop_objects(
+            pattern=item.pattern,
+            dataset_id=seed.route.dataset_id,
+            trip_id=seed.trip_id,
+            rows=stop_times_by_trip.get((seed.route.dataset_id, seed.trip_id or ""), []),
+            canonical_lookup=canonical_lookup,
+        )
+        stop_objects.extend(objects)
+        stop_count += len(objects)
+        if len(stop_objects) >= 10000:
+            session.bulk_save_objects(stop_objects)
+            stop_objects.clear()
+            _emit_progress(
+                progress_callback,
+                "route_layer_pattern_stop_batch",
+                "Stored route-pattern stop links.",
+                stop_count,
+                None,
+                {"route_pattern_stops": stop_count},
+            )
+    if stop_objects:
+        session.bulk_save_objects(stop_objects)
+    trip_link_count = _build_trip_route_pattern_links(session)
+    session.flush()
+    result = {
+        "route_patterns": len(patterns_by_key),
+        "route_patterns_created": created_pattern_count,
+        "route_patterns_updated": len(updated_pattern_keys),
+        "route_patterns_reused": max(0, len(patterns_by_key) - created_pattern_count - len(updated_pattern_keys)),
+        "route_patterns_removed": len(obsolete_pattern_ids),
+        "route_pattern_links": link_count,
+        "trip_pattern_links": trip_link_count,
+        "route_pattern_stops": stop_count,
+        "gtfs_proposed_patterns": proposed_count,
+    }
+    _emit_progress(
+        progress_callback,
+        "route_layer_patterns_completed",
+        "Route-pattern build completed.",
+        len(seeds),
+        len(seeds),
+        result,
+    )
+    return result
+
+
+def _update_route_pattern(pattern: RoutePattern, **fields) -> bool:
+    """Assign the given attribute values to ``pattern`` and report whether any differed.
+
+    ``metadata_json`` is first merged with the derived values already stored on
+    the pattern, so the comparison and the assignment both use the merged text.
+
+    Returns:
+        ``True`` when at least one attribute was changed, ``False`` otherwise.
+    """
+    dirty = False
+    for attribute, new_value in fields.items():
+        if attribute == "metadata_json":
+            new_value = _route_pattern_metadata_with_existing_derived_values(pattern.metadata_json, new_value)
+        if not (getattr(pattern, attribute) == new_value):
+            setattr(pattern, attribute, new_value)
+            dirty = True
+    return dirty
+
+
+def _route_pattern_metadata_with_existing_derived_values(existing_json: str | None, next_json: str | None) -> str | None:
+    """Carry the derived ``linked_gtfs_patterns`` value over into new metadata.
+
+    Args:
+        existing_json: Metadata currently stored on the pattern.
+        next_json: Freshly generated metadata about to replace it.
+
+    Returns:
+        ``next_json`` unchanged when it is empty or when either document cannot
+        be parsed; otherwise the new metadata re-serialised in compact form,
+        with ``linked_gtfs_patterns`` copied from the existing metadata if
+        both documents are JSON objects and the existing one has that key.
+    """
+    if not next_json:
+        return next_json
+    try:
+        existing = json.loads(existing_json or "{}")
+        next_metadata = json.loads(next_json)
+    except (TypeError, ValueError):
+        # ValueError includes json.JSONDecodeError; TypeError covers non-text input.
+        return next_json
+    # Stored metadata such as "null" or "[]" parses fine but is not an object;
+    # the membership test / item assignment below would raise on it.
+    if isinstance(existing, dict) and isinstance(next_metadata, dict) and "linked_gtfs_patterns" in existing:
+        next_metadata["linked_gtfs_patterns"] = existing["linked_gtfs_patterns"]
+    return json.dumps(next_metadata, separators=(",", ":"))
+
+
+def _build_trip_route_pattern_links(session: Session) -> int:
+    """Rebuild the per-trip route-pattern links from the route/shape level links.
+
+    All existing ``GtfsTripRoutePatternLink`` rows are deleted, then every GTFS
+    trip is joined to the route-pattern link that shares its dataset, route and
+    shape (a missing shape is represented by ``GTFS_ROUTE_PATTERN_NULL_SHAPE``).
+
+    Returns:
+        The row count reported for the insert (0 when none is reported).
+    """
+    session.flush()
+    session.execute(delete(GtfsTripRoutePatternLink))
+    insert_statement = text(
+        """
+        INSERT INTO gtfs_trip_route_pattern_links
+            (dataset_id, trip_id, route_id, shape_id, route_pattern_id, source_kind, confidence, status)
+        SELECT
+            trips.dataset_id,
+            trips.trip_id,
+            trips.route_id,
+            COALESCE(trips.shape_id, :null_shape) AS shape_id,
+            links.route_pattern_id,
+            links.source_kind,
+            links.confidence,
+            links.status
+        FROM gtfs_trips AS trips
+        JOIN gtfs_route_pattern_links AS links
+          ON links.dataset_id = trips.dataset_id
+         AND links.route_id = trips.route_id
+         AND links.shape_id = COALESCE(trips.shape_id, :null_shape)
+        """
+    )
+    bind_values = {"null_shape": GTFS_ROUTE_PATTERN_NULL_SHAPE}
+    outcome = session.execute(insert_statement, bind_values)
+    return int(outcome.rowcount or 0)
+
+
+def _active_dataset_ids(session: Session, kind: str) -> list[int]:
+    """Return the ids of all active datasets of the given ``kind``."""
+    query = select(Dataset.id).where(Dataset.is_active.is_(True), Dataset.kind == kind)
+    dataset_ids = []
+    for row in session.execute(query).all():
+        dataset_ids.append(row[0])
+    return dataset_ids
+
+
+def _best_display_stop(group_id: str, stops: list[GtfsStop]) -> GtfsStop:
+    """Pick the stop that best represents a stop group for display.
+
+    Preference order: the parentless stop whose id equals ``group_id``, then
+    stops whose parent is ``group_id``, then stops with any parent, then stops
+    with coordinates; remaining ties are broken by name and finally stop id.
+    """
+
+    def preference(stop: GtfsStop) -> tuple:
+        is_group_root = stop.stop_id == group_id and stop.parent_station is None
+        is_group_child = stop.parent_station == group_id
+        has_parent = stop.parent_station is not None
+        has_position = stop.lat is not None and stop.lon is not None
+        # False sorts before True, so every flag is negated to rank preferred stops first.
+        return (
+            int(not is_group_root),
+            int(not is_group_child),
+            int(not has_parent),
+            int(not has_position),
+            stop.name or "",
+            stop.stop_id,
+        )
+
+    return min(stops, key=preference)
+
+
+def _canonical_stop_grid(session: Session) -> dict[tuple[int, int], list[CanonicalStop]]:
+    """Load every canonical stop that has coordinates and bucket it by ``_grid_cell``."""
+    located_stops = select(CanonicalStop).where(CanonicalStop.lon.is_not(None), CanonicalStop.lat.is_not(None))
+    cells: dict[tuple[int, int], list[CanonicalStop]] = {}
+    for canonical in session.scalars(located_stops).all():
+        cell = _grid_cell(canonical.lon, canonical.lat)
+        cells.setdefault(cell, []).append(canonical)
+    return cells
+
+
+def _nearest_canonical_stop_from_grid(
+    grid: dict[tuple[int, int], list[CanonicalStop]], feature: OsmFeature, point: Point
+) -> tuple[CanonicalStop | None, float | None, float]:
+    """Find the best canonical stop for an OSM stop feature using the spatial grid.
+
+    A candidate is eligible within ``OSM_STOP_LINK_RADIUS_DEG``, or within
+    ``OSM_STOP_NAME_LINK_RADIUS_DEG`` when its name overlap with the feature
+    is at least 0.25. Eligible candidates are scored as
+    ``0.6 * distance_score + 0.4 * name_overlap``; the first candidate with
+    the highest score wins.
+
+    Args:
+        grid: Canonical stops bucketed by ``_grid_cell``.
+        feature: OSM feature whose name (or ref) is compared.
+        point: Representative point of the feature.
+
+    Returns:
+        ``(canonical, distance_m, score)`` with the distance being the planar
+        degree distance multiplied by 111 320 and rounded to 0.1 and the score
+        rounded to 3 decimals, or ``(None, None, 0.0)`` when nothing qualifies.
+    """
+    import math
+
+    cell_x, cell_y = _grid_cell(point.x, point.y)
+    # Grid cells are OSM_STOP_LINK_RADIUS_DEG wide, but name matches are
+    # accepted up to OSM_STOP_NAME_LINK_RADIUS_DEG. Scan enough neighbouring
+    # cells to cover the widest radius; a fixed 3x3 block would miss eligible
+    # name matches whenever the name radius exceeds the cell size.
+    # NOTE(review): presumes the name radius may be the larger one, as the
+    # PostGIS variant uses it as its outer search radius - confirm constants.
+    widest_radius = max(OSM_STOP_LINK_RADIUS_DEG, OSM_STOP_NAME_LINK_RADIUS_DEG)
+    reach = max(1, math.ceil(widest_radius / OSM_STOP_LINK_RADIUS_DEG))
+    offsets = range(-reach, reach + 1)
+    candidates = [
+        stop
+        for dx in offsets
+        for dy in offsets
+        for stop in grid.get((cell_x + dx, cell_y + dy), [])
+    ]
+    best = None
+    best_score = -1.0
+    feature_name = norm_text(feature.name or feature.ref or "")
+    for candidate in candidates:
+        if candidate.lon is None or candidate.lat is None:
+            continue
+        distance_deg = Point(candidate.lon, candidate.lat).distance(point)
+        distance_m = distance_deg * 111_320
+        name_overlap = _name_overlap(feature_name, candidate.normalized_name)
+        max_radius = OSM_STOP_NAME_LINK_RADIUS_DEG if name_overlap >= 0.25 else OSM_STOP_LINK_RADIUS_DEG
+        if distance_deg > max_radius:
+            continue
+        distance_score = max(0.0, 1.0 - (distance_deg / max_radius))
+        score = distance_score * 0.6 + name_overlap * 0.4
+        if score > best_score:
+            best = (candidate, round(distance_m, 1), round(score, 3))
+            best_score = score
+    if best is None:
+        return None, None, 0.0
+    return best
+
+
+def _grid_cell(lon: float, lat: float) -> tuple[int, int]:
+    """Map a coordinate to its OSM stop grid cell.
+
+    Cells are ``OSM_STOP_LINK_RADIUS_DEG`` wide; ``int()`` truncates toward zero.
+    """
+    cell_size = OSM_STOP_LINK_RADIUS_DEG
+    return int(lon / cell_size), int(lat / cell_size)
+
+
+def _name_overlap(left: str, right: str) -> float:
+    """Return the Jaccard overlap of the whitespace-separated tokens of two names.
+
+    The result is 0.0 when either name is empty or contains no tokens.
+    """
+    if not (left and right):
+        return 0.0
+    tokens_a = set(left.split())
+    tokens_b = set(right.split())
+    if not (tokens_a and tokens_b):
+        return 0.0
+    shared = tokens_a.intersection(tokens_b)
+    combined = tokens_a.union(tokens_b)
+    return len(shared) / len(combined)
+
+
+def _representative_point(geometry_text: str | None) -> Point | None:
+    """Return a point that represents a GeoJSON geometry, or ``None`` if unusable.
+
+    Args:
+        geometry_text: GeoJSON geometry as text.
+
+    Returns:
+        The geometry itself when it is a point, otherwise its
+        ``representative_point()``; ``None`` for missing, malformed or empty
+        geometry.
+    """
+    if not geometry_text:
+        return None
+    try:
+        geom = shape(json.loads(geometry_text))
+        # An empty geometry has no coordinates; callers read point.x / point.y.
+        if geom.is_empty:
+            return None
+        point = geom if isinstance(geom, Point) else geom.representative_point()
+    except Exception:  # noqa: BLE001 - malformed source geometry should not stop extraction
+        return None
+    return None if point.is_empty else point
+
+
+def _osm_route_candidates(
+    session: Session,
+    *,
+    progress_callback: ProgressCallback | None = None,
+) -> _OsmRouteCandidateIndex:
+    """Index the route features of all active OSM datasets for route matching.
+
+    Features are read page by page. A feature is skipped when its geometry
+    cannot be parsed or when neither its ref nor its name yields a normalised
+    ref. Every remaining feature becomes a candidate that is indexed both by
+    ``(normalised ref, mode or "")`` and by feature id.
+
+    Returns:
+        The candidate index; empty when there is no active OSM dataset.
+    """
+    dataset_ids = _active_dataset_ids(session, "osm_geojson")
+    if not dataset_ids:
+        return _OsmRouteCandidateIndex(by_ref_mode={}, by_id={})
+    by_ref_mode: dict[tuple[str, str], list[_OsmRouteCandidate]] = {}
+    by_feature_id: dict[int, _OsmRouteCandidate] = {}
+    expected_total = sum(osm_feature_count(session, dataset_id, kind="route") for dataset_id in dataset_ids)
+    seen = 0
+    # Pages never hold fewer than 100 features, whatever the setting says.
+    page_size = max(100, int(settings.route_layer_osm_route_batch_size))
+    for dataset_id in dataset_ids:
+        page_start = 0
+        while True:
+            page = query_osm_features(
+                session,
+                [dataset_id],
+                kinds=["route"],
+                geometry_required=True,
+                limit=page_size,
+                offset=page_start,
+            )
+            if not page:
+                break
+            for feature in page:
+                try:
+                    geometry_text = _normalized_geometry_text(feature.geometry_geojson) or feature.geometry_geojson
+                    route_geometry = shape(json.loads(geometry_text))
+                except Exception:  # noqa: BLE001 - ignore malformed route geometry
+                    continue
+                ref_key = norm_ref(feature.ref or feature.name or "")
+                if not ref_key:
+                    continue
+                _, bounds = geometry_json_and_bbox(json.loads(geometry_text))
+                candidate = _OsmRouteCandidate(
+                    feature=feature,
+                    geom=route_geometry,
+                    geometry_text=geometry_text,
+                    bbox=bounds,
+                    ref_key=ref_key,
+                    mode=feature.mode,
+                )
+                index_key = (ref_key, feature.mode or "")
+                by_ref_mode.setdefault(index_key, []).append(candidate)
+                by_feature_id[feature.id] = candidate
+            page_length = len(page)
+            seen += page_length
+            page_start += page_length
+            _emit_progress(
+                progress_callback,
+                "route_layer_osm_route_batch",
+                f"Indexed OSM route candidates for dataset #{dataset_id}.",
+                seen,
+                expected_total or None,
+                {"dataset_id": dataset_id, "processed": seen, "candidate_refs": len(by_ref_mode), "candidates": len(by_feature_id)},
+            )
+            # A short page means the dataset is exhausted.
+            if page_length < page_size:
+                break
+    _emit_progress(
+        progress_callback,
+        "route_layer_osm_routes_indexed",
+        "Indexed OSM route candidates.",
+        seen,
+        expected_total or None,
+        {"candidate_refs": len(by_ref_mode), "candidates": len(by_feature_id)},
+    )
+    return _OsmRouteCandidateIndex(by_ref_mode=by_ref_mode, by_id=by_feature_id)
+
+
+def _route_layer_overrides(session: Session) -> _RouteLayerOverrides:
+    """Collect reviewed route matches as overrides for route-pattern building.
+
+    Accepted matches map a GTFS route id to one OSM feature id (a later row
+    replaces an earlier one); rejected matches accumulate a set of OSM feature
+    ids per GTFS route id. Matches without an OSM feature are ignored.
+    """
+    reviewed_query = select(RouteMatch).where(RouteMatch.status.in_(["accepted", "rejected"]))
+    accepted_feature_by_route: dict[int, int] = {}
+    rejected_features_by_route: dict[int, set[int]] = {}
+    for match in session.scalars(reviewed_query).all():
+        if match.osm_feature_id is None:
+            continue
+        if match.status == "accepted":
+            accepted_feature_by_route[match.gtfs_route_id] = match.osm_feature_id
+        elif match.status == "rejected":
+            rejected_for_route = rejected_features_by_route.setdefault(match.gtfs_route_id, set())
+            rejected_for_route.add(match.osm_feature_id)
+    return _RouteLayerOverrides(
+        accepted_by_gtfs_route_id=accepted_feature_by_route,
+        rejected_by_gtfs_route_id=rejected_features_by_route,
+    )
+
+
+def _gtfs_pattern_seeds(session: Session) -> list[_GtfsPatternSeed]:
+    """Build one pattern seed per (GTFS route, shape id) of the active GTFS datasets.
+
+    Each seed carries the route, the shape id (possibly ``None``), the smallest
+    trip id of the group as representative trip, and the geometry to match
+    with. The shape's geometry is preferred; the route's own geometry is the
+    fallback. Bounding box, start/end points and geometry source always
+    describe the geometry that was actually chosen.
+
+    Returns:
+        Seeds ordered by dataset id, route id and shape id; empty when there
+        is no active GTFS dataset.
+    """
+    active_gtfs_dataset_ids = _active_dataset_ids(session, "gtfs")
+    if not active_gtfs_dataset_ids:
+        return []
+    rows = session.execute(
+        select(GtfsRoute, GtfsTrip.shape_id, func.min(GtfsTrip.trip_id))
+        .join(GtfsTrip, and_(GtfsTrip.dataset_id == GtfsRoute.dataset_id, GtfsTrip.route_id == GtfsRoute.route_id))
+        .where(GtfsRoute.dataset_id.in_(active_gtfs_dataset_ids))
+        .group_by(GtfsRoute.id, GtfsTrip.shape_id)
+        .order_by(GtfsRoute.dataset_id, GtfsRoute.route_id, GtfsTrip.shape_id)
+    ).all()
+    shape_rows = session.execute(
+        select(
+            GtfsShape.dataset_id,
+            GtfsShape.shape_id,
+            GtfsShape.geometry_geojson,
+            GtfsShape.min_lon,
+            GtfsShape.min_lat,
+            GtfsShape.max_lon,
+            GtfsShape.max_lat,
+        ).where(GtfsShape.dataset_id.in_(active_gtfs_dataset_ids))
+    ).all()
+    shapes = {
+        (dataset_id, shape_id): {
+            "geometry": geometry,
+            "bbox": (min_lon, min_lat, max_lon, max_lat),
+            "points": _geometry_points_from_text(geometry),
+        }
+        for dataset_id, shape_id, geometry, min_lon, min_lat, max_lon, max_lat in shape_rows
+    }
+    seeds = []
+    for route, shape_id, trip_id in rows:
+        geometry_text = None
+        geometry_source = "none"
+        shape_row = shapes.get((route.dataset_id, shape_id)) if shape_id else None
+        if shape_row is not None and shape_row["geometry"]:
+            geometry_text = shape_row["geometry"]
+            bbox = shape_row["bbox"]
+            points = shape_row["points"]
+            geometry_source = "gtfs_shape"
+        else:
+            # No usable shape geometry: bbox and points must come from the
+            # route as well, otherwise a shape row without geometry would
+            # pair the route geometry with the shape's (empty) extent.
+            bbox = (route.min_lon, route.min_lat, route.max_lon, route.max_lat)
+            points = _geometry_points_from_text(route.geometry_geojson)
+            if route.geometry_geojson:
+                geometry_text = route.geometry_geojson
+                geometry_source = "gtfs_route"
+        start_point = Point(points[0]) if points else None
+        end_point = Point(points[-1]) if points else None
+        center_point = _bbox_center_point(bbox)
+        seeds.append(
+            _GtfsPatternSeed(
+                route=route,
+                shape_id=shape_id,
+                trip_id=trip_id,
+                geometry_text=geometry_text,
+                geometry_source=geometry_source,
+                bbox=bbox,
+                start_point=start_point,
+                end_point=end_point,
+                center_point=center_point,
+            )
+        )
+    return seeds
+
+
+# Pick the OSM route candidate that best matches a GTFS pattern seed.
+#
+# Returns a (candidate, score, reasons) triple. The candidate is None when the
+# seed has no geometry, when no normalised route ref can be derived, when no
+# candidate shares the ref with a compatible mode, or when the best score is
+# below OSM_ROUTE_MIN_SCORE. A manually accepted match skips scoring and is
+# returned with a fixed score of 100.0.
+def _choose_osm_candidate(
+ seed: _GtfsPatternSeed,
+ candidate_index: _OsmRouteCandidateIndex,
+ overrides: _RouteLayerOverrides,
+) -> tuple[_OsmRouteCandidate | None, float, dict[str, object]]:
+ if not seed.geometry_text:
+ return None, 0.0, {"reason": "no GTFS geometry available"}
+ # A manual override wins outright, but only while the accepted feature is
+ # still present in the candidate index; otherwise normal scoring continues.
+ accepted_feature_id = overrides.accepted_by_gtfs_route_id.get(seed.route.id)
+ if accepted_feature_id is not None:
+ accepted = candidate_index.by_id.get(accepted_feature_id)
+ if accepted is not None:
+ return (
+ accepted,
+ 100.0,
+ {
+ "manual": "accepted_route_match",
+ "osm_feature_id": accepted.feature.id,
+ "osm_id": accepted.feature.osm_id,
+ },
+ )
+ # The GTFS route_id stands in for the ref when the route has no short name.
+ route_ref = norm_ref(seed.route.short_name or seed.route.route_id)
+ if not route_ref:
+ return None, 0.0, {"reason": "no GTFS route ref"}
+ candidate_pool = []
+ rejected_feature_ids = overrides.rejected_by_gtfs_route_id.get(seed.route.id, set())
+ # Collect every candidate with the same ref and a compatible mode, leaving
+ # out features that were manually rejected for this GTFS route.
+ for (ref_key, mode), candidates in candidate_index.by_ref_mode.items():
+ if ref_key != route_ref:
+ continue
+ if _mode_compatible(seed.route.mode or "", mode):
+ candidate_pool.extend(candidate for candidate in candidates if candidate.feature.id not in rejected_feature_ids)
+ if not candidate_pool:
+ return None, 0.0, {"reason": "no OSM route candidate with same ref and mode"}
+
+ best = None
+ best_rank_score = 0.0
+ best_score = 0.0
+ best_reasons: dict[str, object] = {}
+ for candidate in candidate_pool:
+ # Every pooled candidate starts at 50 for the ref + mode agreement;
+ # geometric evidence below only ever adds to that.
+ score = 50.0
+ reasons: dict[str, object] = {"ref": "exact", "mode": "compatible"}
+ if bbox_overlap(seed.bbox, candidate.bbox):
+ score += 20
+ reasons["bbox"] = "overlap"
+ if seed.start_point is not None and seed.end_point is not None:
+ # Sum of the distances from the OSM geometry to both GTFS endpoints,
+ # in the geometry's coordinate units (the key name suggests degrees).
+ endpoint_distance = candidate.geom.distance(seed.start_point) + candidate.geom.distance(seed.end_point)
+ reasons["endpoint_distance_deg"] = round(endpoint_distance, 6)
+ if endpoint_distance < 0.002:
+ score += 30
+ elif endpoint_distance < 0.01:
+ score += 22
+ elif endpoint_distance < 0.03:
+ score += 10
+ direction_metrics = _candidate_direction_metrics(seed, candidate)
+ if direction_metrics:
+ direction_score = _direction_alignment_score(direction_metrics)
+ score += direction_score
+ reasons["directional_match"] = {**direction_metrics, "score": direction_score}
+ if seed.center_point is not None:
+ centroid_distance = candidate.geom.distance(seed.center_point)
+ reasons["center_distance_deg"] = round(centroid_distance, 6)
+ if centroid_distance < 0.004:
+ score += 10
+ elif centroid_distance < 0.015:
+ score += 5
+ # Ranking uses the uncapped score so candidates above 100 stay ordered;
+ # the reported score is capped at 100. On a tie the earlier candidate wins.
+ if score > best_rank_score:
+ best = candidate
+ best_rank_score = score
+ best_score = min(score, 100.0)
+ best_reasons = reasons
+ # Below the threshold no OSM candidate is returned; the best attempt's
+ # reasons are kept and tagged with the fallback marker.
+ if best is None or best_score < OSM_ROUTE_MIN_SCORE:
+ reasons = best_reasons or {"reason": "no OSM candidate above threshold"}
+ reasons["fallback"] = "gtfs_proposed_route_layer_pattern"
+ return None, best_score, reasons
+ best_reasons["osm_feature_id"] = best.feature.id
+ best_reasons["osm_id"] = best.feature.osm_id
+ return best, best_score, best_reasons
+
+
+def _osm_pattern_key(feature: OsmFeature) -> str:
+    """Build the route-pattern key of an OSM feature from its element type and id."""
+    osm_type, osm_id = feature.osm_type, feature.osm_id
+    return f"osm:{osm_type}:{osm_id}"
+
+
+def _link_reasons(seed: _GtfsPatternSeed, chosen: _OsmRouteCandidate, reasons: dict[str, object]) -> dict[str, object]:
+    """Return a copy of the match reasons extended with link-level evidence.
+
+    The input mapping is left untouched. The copy gains the seed's geometry
+    source and the direction evidence computed for the chosen candidate.
+    """
+    return {
+        **reasons,
+        "gtfs_geometry_source": seed.geometry_source,
+        "direction": _direction_evidence(seed, chosen),
+    }
+
+
+def _direction_evidence(seed: _GtfsPatternSeed, candidate: _OsmRouteCandidate) -> dict[str, object]:
+    """Describe whether the candidate geometry runs with or against the GTFS pattern.
+
+    Two signals are recorded when available: the positions of the GTFS start
+    and end points projected onto the candidate geometry, and the summed
+    distances between the two geometries' endpoints in forward and reverse
+    pairing. The returned mapping always has a "direction" entry; the
+    projection signal takes precedence over the endpoint signal.
+    """
+    gtfs_start, gtfs_end = seed.start_point, seed.end_point
+    if gtfs_start is None or gtfs_end is None:
+        return {"direction": "unknown", "reason": "missing GTFS shape endpoints"}
+
+    tolerance = 1e-9
+    result: dict[str, object] = {}
+
+    along_start = _project_point_on_geometry(candidate.geom, gtfs_start)
+    along_end = _project_point_on_geometry(candidate.geom, gtfs_end)
+    if along_start is not None and along_end is not None:
+        result["start_projection"] = round(along_start, 6)
+        result["end_projection"] = round(along_end, 6)
+        if abs(along_start - along_end) > tolerance:
+            result["direction"] = "forward" if along_start < along_end else "reverse"
+
+    osm_ends = _geometry_endpoints(candidate.geom)
+    if osm_ends is not None:
+        first, last = osm_ends
+        same_way = first.distance(gtfs_start) + last.distance(gtfs_end)
+        opposite_way = first.distance(gtfs_end) + last.distance(gtfs_start)
+        result["endpoint_forward_distance_deg"] = round(same_way, 6)
+        result["endpoint_reverse_distance_deg"] = round(opposite_way, 6)
+        if abs(same_way - opposite_way) > tolerance:
+            result["endpoint_direction"] = "forward" if same_way < opposite_way else "reverse"
+        # Only used when the projections could not settle the direction.
+        if "direction" not in result:
+            result["direction"] = result.get("endpoint_direction", "unknown")
+
+    if "direction" not in result:
+        result["direction"] = "unknown"
+    return result
+
+
+def _candidate_direction_metrics(seed: _GtfsPatternSeed, candidate: _OsmRouteCandidate) -> dict[str, object] | None:
+    """Collect the direction measurements used to score a candidate.
+
+    Returns None when the seed has no endpoints or when neither the projection
+    nor the endpoint comparison produced anything. Otherwise the mapping holds
+    the projection delta and, when the candidate has usable endpoints, the
+    forward/reverse endpoint distances with their margin.
+    """
+    gtfs_start, gtfs_end = seed.start_point, seed.end_point
+    if gtfs_start is None or gtfs_end is None:
+        return None
+
+    tolerance = 1e-9
+    collected: dict[str, object] = {}
+
+    along_start = _project_point_on_geometry(candidate.geom, gtfs_start)
+    along_end = _project_point_on_geometry(candidate.geom, gtfs_end)
+    if along_start is not None and along_end is not None:
+        delta = along_end - along_start
+        collected["projection_delta"] = round(delta, 6)
+        if abs(delta) > tolerance:
+            collected["projection_direction"] = "forward" if delta > 0 else "reverse"
+
+    osm_ends = _geometry_endpoints(candidate.geom)
+    if osm_ends is not None:
+        first, last = osm_ends
+        same_way = first.distance(gtfs_start) + last.distance(gtfs_end)
+        opposite_way = first.distance(gtfs_end) + last.distance(gtfs_start)
+        collected["endpoint_forward_distance_deg"] = round(same_way, 6)
+        collected["endpoint_reverse_distance_deg"] = round(opposite_way, 6)
+        collected["endpoint_margin_deg"] = round(abs(opposite_way - same_way), 6)
+        if abs(same_way - opposite_way) > tolerance:
+            collected["endpoint_direction"] = "forward" if same_way < opposite_way else "reverse"
+
+    return collected if collected else None
+
+
+def _direction_alignment_score(metrics: dict[str, object]) -> float:
+    """Convert direction metrics into a bonus score between 0 and 28.
+
+    Args:
+        metrics: Mapping as produced by ``_candidate_direction_metrics``. The
+            keys read are ``projection_direction``, ``endpoint_direction``,
+            ``endpoint_forward_distance_deg`` and ``endpoint_margin_deg``.
+
+    Returns:
+        16 points for a forward projection, plus up to 12 points for a small
+        forward endpoint distance and up to 4 points for a clear margin over
+        the reverse pairing. The total is capped at 28.
+    """
+    score = 0.0
+    if metrics.get("projection_direction") == "forward":
+        score += 16.0
+    if metrics.get("endpoint_direction") == "forward":
+        raw_forward = metrics.get("endpoint_forward_distance_deg")
+        raw_margin = metrics.get("endpoint_margin_deg")
+        # Test for None explicitly: a forward distance of exactly 0.0 is the
+        # best possible match, and `value or 999.0` would treat it as missing.
+        forward_distance = 999.0 if raw_forward is None else float(raw_forward)
+        margin = 0.0 if raw_margin is None else float(raw_margin)
+        if forward_distance < 0.004:
+            score += 12.0
+        elif forward_distance < 0.015:
+            score += 7.0
+        elif forward_distance < 0.04:
+            score += 3.0
+        if margin > 0.01:
+            score += 4.0
+        elif margin > 0.002:
+            score += 2.0
+    return min(score, 28.0)
+
+
+def _update_pattern_metadata(pattern: RoutePattern, **values: object) -> None:
+    """Merge keyword values into the JSON metadata stored on a route pattern.
+
+    The existing ``metadata_json`` is parsed, updated with ``values`` and
+    written back in compact form. Metadata that is missing, is not valid
+    JSON, or is valid JSON but not an object is replaced by a fresh mapping.
+
+    Args:
+        pattern: Object whose ``metadata_json`` attribute is read and rewritten.
+        **values: Entries to add or overwrite in the metadata.
+    """
+    try:
+        metadata = json.loads(pattern.metadata_json or "{}")
+    except json.JSONDecodeError:
+        metadata = {}
+    # `null`, lists and scalars are valid JSON but have no `.update`.
+    if not isinstance(metadata, dict):
+        metadata = {}
+    metadata.update(values)
+    pattern.metadata_json = json.dumps(metadata, separators=(",", ":"))
+
+
+def _representative_stop_times(
+    session: Session, pending: list[_PatternBuildItem]
+) -> dict[tuple[int, str], list[GtfsStopTime]]:
+    """Load the stop times of each pending item's representative trip.
+
+    Trip ids are gathered per dataset, fetched in sorted chunks of 600 and
+    returned grouped by ``(dataset_id, trip_id)``. Items without a trip id are
+    ignored.
+    """
+    wanted: dict[int, set[str]] = {}
+    for build_item in pending:
+        trip_id = build_item.seed.trip_id
+        if not trip_id:
+            continue
+        wanted.setdefault(build_item.seed.route.dataset_id, set()).add(trip_id)
+
+    result: dict[tuple[int, str], list[GtfsStopTime]] = {}
+    for dataset_id, trip_ids in wanted.items():
+        for batch in _chunks(sorted(trip_ids), 600):
+            by_trip = storage_stop_times_by_trip(session, dataset_id, batch)
+            # Walk the batch in its sorted order so the grouping order is stable.
+            for requested_trip_id in batch:
+                for stop_time in by_trip.get(requested_trip_id, []):
+                    result.setdefault((dataset_id, stop_time.trip_id), []).append(stop_time)
+    return result
+
+
+def _canonical_link_lookup(
+    session: Session, stop_times_by_trip: dict[tuple[int, str], list[GtfsStopTime]]
+) -> dict[tuple[int, str], int]:
+    """Map ``(dataset_id, gtfs_stop_id)`` to the linked canonical stop id.
+
+    Only stops referenced by the given stop times are looked up. The links are
+    queried per dataset in sorted chunks of 900 stop ids and restricted to
+    links of object type ``gtfs_stop``.
+    """
+    wanted: dict[int, set[str]] = {}
+    for (dataset_id, _trip_id), stop_times in stop_times_by_trip.items():
+        bucket = wanted.setdefault(dataset_id, set())
+        for stop_time in stop_times:
+            bucket.add(stop_time.stop_id)
+
+    resolved = {}
+    for dataset_id, stop_ids in wanted.items():
+        for batch in _chunks(sorted(stop_ids), 900):
+            query = select(CanonicalStopLink).where(
+                CanonicalStopLink.object_type == "gtfs_stop",
+                CanonicalStopLink.dataset_id == dataset_id,
+                CanonicalStopLink.external_id.in_(batch),
+            )
+            for link in session.scalars(query).all():
+                resolved[(link.dataset_id, link.external_id)] = link.canonical_stop_id
+    return resolved
+
+
+def _route_pattern_stop_objects(
+    pattern: RoutePattern,
+    dataset_id: int,
+    trip_id: str | None,
+    rows: list[GtfsStopTime],
+    canonical_lookup: dict[tuple[int, str], int],
+) -> list[RoutePatternStop]:
+    """Build the stop list of a route pattern from one trip's stop times.
+
+    Stop times without a canonical stop link are skipped, and each canonical
+    stop appears once, at its first occurrence. The result is empty when there
+    is no trip id or no stop times.
+    """
+    if not trip_id or not rows:
+        return []
+
+    # Dicts keep insertion order, so this holds the first row per canonical stop.
+    first_row_by_stop: dict[int, GtfsStopTime] = {}
+    for row in rows:
+        stop = canonical_lookup.get((dataset_id, row.stop_id))
+        if stop is not None and stop not in first_row_by_stop:
+            first_row_by_stop[stop] = row
+
+    return [
+        RoutePatternStop(
+            route_pattern_id=pattern.id,
+            canonical_stop_id=stop,
+            sequence=row.stop_sequence,
+            distance_along=None,
+            source_kind="timetable_link",
+            confidence=0.75 if pattern.source_kind == "osm" else 0.45,
+        )
+        for stop, row in first_row_by_stop.items()
+    ]
+
+
+def _chunks(values: list[str], size: int) -> Iterable[list[str]]:
+    """Yield consecutive slices of ``values`` holding at most ``size`` items each."""
+    yield from (values[offset : offset + size] for offset in range(0, len(values), size))
+
+
+def _chunks_objects(values: list, size: int) -> Iterable[list]:
+    """Yield consecutive slices of a list of arbitrary objects, ``size`` items at a time."""
+    starts = range(0, len(values), size)
+    for begin in starts:
+        end = begin + size
+        yield values[begin:end]
+
+
+def _normalized_geometry_text(geometry_text: str | None) -> str | None:
+    """Return the geometry as compact GeoJSON, merging multi-line parts where possible.
+
+    Empty input yields None. A MultiLineString is passed through ``linemerge``
+    and the merged result is used when it is a non-empty line geometry. If
+    anything fails along the way the original text is returned unchanged.
+    """
+    if not geometry_text:
+        return None
+    try:
+        parsed = shape(json.loads(geometry_text))
+        if isinstance(parsed, MultiLineString):
+            joined = linemerge(parsed)
+            usable = isinstance(joined, (LineString, MultiLineString)) and not joined.is_empty
+            parsed = joined if usable else parsed
+        normalized = json.dumps(parsed.__geo_interface__, separators=(",", ":"))
+    except Exception:  # noqa: BLE001 - preserve source geometry if normalization fails
+        return geometry_text
+    return normalized
+
+
+def _geometry_points_from_text(geometry_text: str | None) -> list[tuple[float, float]]:
+    """Extract the ``(lon, lat)`` vertices of a GeoJSON line geometry.
+
+    Args:
+        geometry_text: GeoJSON geometry as text, or None.
+
+    Returns:
+        The vertices of a LineString, or of the part with the most vertices
+        for a MultiLineString (parts with fewer than two vertices are
+        ignored). Any extra coordinate dimensions are dropped. The list is
+        empty for missing text, invalid JSON, JSON that is not an object,
+        other geometry types, and malformed coordinates.
+    """
+    if not geometry_text:
+        return []
+    try:
+        geometry = json.loads(geometry_text)
+    except json.JSONDecodeError:
+        return []
+    # `null`, lists and scalars are valid JSON but are not geometries.
+    if not isinstance(geometry, dict):
+        return []
+    geometry_type = geometry.get("type")
+    coords = geometry.get("coordinates") or []
+    try:
+        if geometry_type == "LineString":
+            return [(float(lon), float(lat)) for lon, lat, *_ in coords]
+        if geometry_type == "MultiLineString":
+            lines = [
+                [(float(lon), float(lat)) for lon, lat, *_ in line]
+                for line in coords
+                if len(line) >= 2
+            ]
+            return max(lines, key=len) if lines else []
+    except (TypeError, ValueError):
+        # Positions with fewer than two values or non-numeric entries: treat
+        # the geometry as unusable, like the invalid-JSON case above.
+        return []
+    return []
+
+
+def _bbox_center_point(bbox: tuple[float | None, float | None, float | None, float | None]) -> Point | None:
+    """Return the midpoint of a ``(min_lon, min_lat, max_lon, max_lat)`` box, or None if any bound is missing."""
+    west, south, east, north = bbox
+    if None in bbox:
+        return None
+    mid_lon = (float(west) + float(east)) / 2
+    mid_lat = (float(south) + float(north)) / 2
+    return Point(mid_lon, mid_lat)
+
+
+def _geometry_endpoints(geom) -> tuple[Point, Point] | None:
+    """Return the first and last vertex of the longest line in ``geom``.
+
+    None is returned when ``geom`` contains no line or the longest line has
+    fewer than two vertices.
+    """
+    longest = max(_iter_lines(geom), key=lambda part: part.length, default=None)
+    if longest is None:
+        return None
+    vertices = list(longest.coords)
+    if len(vertices) < 2:
+        return None
+    return Point(vertices[0]), Point(vertices[-1])
+
+
+def _iter_lines(geom) -> Iterable[LineString]:
+    """Yield the line parts of ``geom``: itself for a LineString, each member for a MultiLineString, nothing otherwise."""
+    if isinstance(geom, LineString):
+        parts = (geom,)
+    elif isinstance(geom, MultiLineString):
+        parts = geom.geoms
+    else:
+        parts = ()
+    yield from parts
+
+
+def _project_point_on_geometry(geom, point: Point) -> float | None:
+    """Project ``point`` onto the line part of ``geom`` closest to it.
+
+    Returns the value of ``project`` on that line as a float, or None when
+    ``geom`` has no line parts. When several parts are equally close the
+    first one is used.
+    """
+    nearest = min(_iter_lines(geom), key=lambda line: line.distance(point), default=None)
+    if nearest is None:
+        return None
+    return float(nearest.project(point))
+
+
+def _bounds_tuple(geom) -> tuple[float | None, float | None, float | None, float | None]:
+    """Return the geometry's bounds as a 4-tuple, or four Nones for an empty geometry."""
+    if not geom.is_empty:
+        west, south, east, north = geom.bounds
+        return west, south, east, north
+    return (None, None, None, None)
+
+
+def _mode_compatible(gtfs_mode: str, osm_mode: str) -> bool:
+    """Tell whether a GTFS mode and an OSM mode may describe the same service.
+
+    An unknown mode on either side is compatible with anything. Otherwise the
+    modes must be equal or one must belong to the other's ``MODE_GROUPS``
+    entry; a mode without an entry only groups with itself.
+    """
+    if not (gtfs_mode and osm_mode):
+        return True
+    if gtfs_mode == osm_mode:
+        return True
+    gtfs_group = MODE_GROUPS.get(gtfs_mode, {gtfs_mode})
+    if osm_mode in gtfs_group:
+        return True
+    osm_group = MODE_GROUPS.get(osm_mode, {osm_mode})
+    return gtfs_mode in osm_group
diff --git a/app/pipeline/routing_layer.py b/app/pipeline/routing_layer.py
new file mode 100644
index 0000000..fb10c88
--- /dev/null
+++ b/app/pipeline/routing_layer.py
@@ -0,0 +1,473 @@
+from __future__ import annotations
+
+import json
+import math
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Callable
+
+import osmium
+from sqlalchemy import delete, func, select, text
+from sqlalchemy.dialects.postgresql import insert as postgresql_insert
+from sqlalchemy.orm import Session
+
+from app.config import settings
+from app.models import Dataset, RoutingEdge, RoutingNode
+from app.spatial import analyze_postgresql_tables, refresh_postgis_geometries
+
+
+# Progress hook signature, called by _emit with
+# (event_type, message, progress_current, progress_total, metadata).
+ProgressCallback = Callable[[str, str, int | None, int | None, dict[str, object] | None], None]
+# Version label written to the dataset metadata and to every import result.
+ROUTING_LAYER_VERSION = "routing_layer_v2_osm_highway_segments_service_tags"
+
+# highway=* values a way must carry before _drivable considers it for driving.
+DRIVE_HIGHWAYS = {
+ "motorway",
+ "motorway_link",
+ "trunk",
+ "trunk_link",
+ "primary",
+ "primary_link",
+ "secondary",
+ "secondary_link",
+ "tertiary",
+ "tertiary_link",
+ "unclassified",
+ "residential",
+ "living_street",
+ "service",
+ "road",
+ "track",
+}
+# highway=* values a way must carry before _walkable considers it for walking.
+WALK_HIGHWAYS = {
+ "pedestrian",
+ "footway",
+ "path",
+ "steps",
+ "cycleway",
+ "bridleway",
+ "living_street",
+ "residential",
+ "service",
+ "track",
+ "unclassified",
+ "tertiary",
+ "tertiary_link",
+ "secondary",
+ "secondary_link",
+ "primary",
+ "primary_link",
+ "road",
+}
+# highway=* values that are skipped outright by _RoutingGraphHandler.way.
+EXCLUDED_HIGHWAYS = {"construction", "proposed", "abandoned", "platform", "raceway"}
+# Access tag values treated as "not allowed" and "explicitly allowed" by
+# _walkable and _drivable.
+NO_VALUES = {"no", "private", "agricultural", "forestry", "delivery", "customers"}
+YES_VALUES = {"yes", "designated", "permissive", "destination"}
+# oneway=* values recognised by _oneway_direction: forward keeps the way's node
+# order, reverse makes the handler swap source and target of every segment.
+ONEWAY_FORWARD = {"yes", "true", "1"}
+ONEWAY_REVERSE = {"-1", "reverse"}
+# Fallback driving speed in km/h per highway class, used by _drive_speed_mps
+# when maxspeed is missing or cannot be parsed.
+DEFAULT_DRIVE_SPEED_KMH = {
+ "motorway": 110,
+ "motorway_link": 50,
+ "trunk": 90,
+ "trunk_link": 45,
+ "primary": 70,
+ "primary_link": 40,
+ "secondary": 60,
+ "secondary_link": 35,
+ "tertiary": 50,
+ "tertiary_link": 30,
+ "unclassified": 40,
+ "residential": 30,
+ "living_street": 10,
+ "service": 15,
+ "road": 30,
+ "track": 15,
+}
+# Walking speeds in metres per second; the slower one applies to highway=steps.
+DEFAULT_WALK_SPEED_MPS = 1.35
+STEP_WALK_SPEED_MPS = 0.65
+
+
+@dataclass
+class RoutingImportResult:
+    """Summary of a routing graph import for one dataset."""
+
+    dataset_id: int
+    input_path: str
+    nodes: int
+    edges: int
+    walk_edges: int
+    drive_edges: int
+    skipped_ways: int
+    version: str = ROUTING_LAYER_VERSION
+
+    def as_dict(self) -> dict[str, object]:
+        """Return the summary as a plain mapping, with the version first."""
+        ordered_fields = (
+            "version",
+            "dataset_id",
+            "input_path",
+            "nodes",
+            "edges",
+            "walk_edges",
+            "drive_edges",
+            "skipped_ways",
+        )
+        return {field_name: getattr(self, field_name) for field_name in ordered_fields}
+
+
+def active_routing_dataset(session: Session) -> Dataset | None:
+    """Find the raw OSM PBF dataset to build the routing graph from.
+
+    The newest active ``osm_geojson`` dataset is consulted first: if its
+    metadata names a ``raw_dataset_id`` whose dataset exists and whose file is
+    present on disk, that dataset is returned. Otherwise the result is the
+    first ``osm_pbf_raw`` dataset, preferring active ones and then higher ids,
+    or None when there is none.
+    """
+    geojson_query = (
+        select(Dataset)
+        .where(Dataset.kind == "osm_geojson", Dataset.is_active.is_(True))
+        .order_by(Dataset.id.desc())
+    )
+    active_osm = session.scalar(geojson_query)
+    raw_dataset_id = None if active_osm is None else _metadata(active_osm).get("raw_dataset_id")
+    if raw_dataset_id is not None:
+        raw = session.get(Dataset, int(raw_dataset_id))
+        if raw is not None and Path(raw.local_path).exists():
+            return raw
+
+    fallback_query = (
+        select(Dataset)
+        .where(Dataset.kind == "osm_pbf_raw")
+        .order_by(Dataset.is_active.desc(), Dataset.id.desc())
+    )
+    return session.scalar(fallback_query)
+
+
+# Import the routable highway graph of an OSM PBF file into the routing tables.
+#
+# The dataset is the one with the given id, or the one chosen by
+# active_routing_dataset. The file read is input_path when given, otherwise
+# the dataset's local_path. Returns the summary mapping produced by
+# finalize_routing_layer.
+#
+# Raises RuntimeError when the database is not PostgreSQL, ValueError when no
+# dataset can be found, and FileNotFoundError when the PBF file is missing.
+def rebuild_routing_layer(
+ session: Session,
+ *,
+ dataset_id: int | None = None,
+ input_path: str | Path | None = None,
+ reset: bool = True,
+ batch_size: int = 5000,
+ progress_callback: ProgressCallback | None = None,
+) -> dict[str, object]:
+ if not settings.is_postgresql_database:
+ raise RuntimeError("The routing layer importer requires PostgreSQL/PostGIS.")
+ dataset = session.get(Dataset, dataset_id) if dataset_id is not None else active_routing_dataset(session)
+ if dataset is None:
+ raise ValueError("No OSM PBF dataset is available for routing import.")
+ path = Path(input_path or dataset.local_path)
+ if not path.exists():
+ raise FileNotFoundError(f"Routing import PBF does not exist: {path}")
+
+ # With reset, this dataset's existing edges and nodes are deleted and the
+ # deletion is committed before the import starts.
+ if reset:
+ _emit(progress_callback, "routing_layer_clear_started", "Clearing existing routing graph.", None, None, {"dataset_id": dataset.id})
+ session.execute(delete(RoutingEdge).where(RoutingEdge.dataset_id == dataset.id))
+ session.execute(delete(RoutingNode).where(RoutingNode.dataset_id == dataset.id))
+ session.commit()
+
+ _emit(progress_callback, "routing_layer_import_started", "Importing routable OSM highway graph.", None, None, {"dataset_id": dataset.id, "path": str(path)})
+ handler = _RoutingGraphHandler(session=session, dataset_id=dataset.id, batch_size=batch_size, progress_callback=progress_callback)
+ # locations=True asks osmium to attach node coordinates to way nodes, which
+ # the handler reads through node.location.
+ handler.apply_file(str(path), locations=True)
+ # Write whatever is left in the last, partially filled batch.
+ handler.flush()
+
+ return finalize_routing_layer(
+ session,
+ dataset_id=dataset.id,
+ input_path=str(path),
+ skipped_way_count=handler.skipped_way_count,
+ progress_callback=progress_callback,
+ )
+
+
+# Finish a routing import: refresh node geometries, rebuild the geometry
+# indexes, analyze the tables, count the stored graph and record the counts in
+# the dataset metadata under "routing_layer".
+#
+# Returns the RoutingImportResult mapping. Raises RuntimeError when the
+# database is not PostgreSQL and ValueError when no dataset can be found.
+def finalize_routing_layer(
+ session: Session,
+ *,
+ dataset_id: int | None = None,
+ input_path: str | Path | None = None,
+ skipped_way_count: int = 0,
+ progress_callback: ProgressCallback | None = None,
+) -> dict[str, object]:
+ if not settings.is_postgresql_database:
+ raise RuntimeError("The routing layer finalizer requires PostgreSQL/PostGIS.")
+ dataset = session.get(Dataset, dataset_id) if dataset_id is not None else active_routing_dataset(session)
+ if dataset is None:
+ raise ValueError("No routing dataset is available to finalize.")
+ path = Path(input_path or dataset.local_path)
+ # The geometry indexes are dropped before the bulk geometry refresh and
+ # recreated afterwards; each step is committed on its own.
+ # NOTE(review): ix_routing_edges_geom_gist is dropped here but is not among
+ # the indexes recreated by _create_routing_geometry_indexes - confirm that
+ # index is meant to be retired.
+ _emit(progress_callback, "routing_layer_geometry_indexes_dropped", "Dropping routing geometry indexes before bulk refresh.", None, None, {"dataset_id": dataset.id})
+ _drop_routing_geometry_indexes(session)
+ session.commit()
+ _emit(progress_callback, "routing_layer_geometry_started", "Refreshing routing node PostGIS geometries.", None, None, {"dataset_id": dataset.id})
+ refresh_postgis_geometries(session, dataset_id=dataset.id, tables=["routing_nodes"], only_missing=False)
+ session.commit()
+ _emit(progress_callback, "routing_layer_geometry_indexes_started", "Rebuilding routing geometry indexes.", None, None, {"dataset_id": dataset.id})
+ _create_routing_geometry_indexes(session)
+ session.commit()
+ analyze_postgresql_tables(session, ["routing_nodes", "routing_edges"])
+ # Counts come from the database, not from the import handler's counters.
+ # Walk and drive edges are edges whose forward cost column is not NULL.
+ node_count = int(session.scalar(select(func.count()).select_from(RoutingNode).where(RoutingNode.dataset_id == dataset.id)) or 0)
+ edge_count = int(session.scalar(select(func.count()).select_from(RoutingEdge).where(RoutingEdge.dataset_id == dataset.id)) or 0)
+ walk_edge_count = int(session.scalar(select(func.count()).select_from(RoutingEdge).where(RoutingEdge.dataset_id == dataset.id, RoutingEdge.walk_cost_s.is_not(None))) or 0)
+ drive_edge_count = int(session.scalar(select(func.count()).select_from(RoutingEdge).where(RoutingEdge.dataset_id == dataset.id, RoutingEdge.drive_cost_s.is_not(None))) or 0)
+ dataset_metadata = _metadata(dataset)
+ dataset_metadata["routing_layer"] = {
+ "version": ROUTING_LAYER_VERSION,
+ "nodes": node_count,
+ "edges": edge_count,
+ "walk_edges": walk_edge_count,
+ "drive_edges": drive_edge_count,
+ "input_path": str(path),
+ }
+ dataset.metadata_json = json.dumps(dataset_metadata, indent=2)
+ session.commit()
+ result = RoutingImportResult(
+ dataset_id=dataset.id,
+ input_path=str(path),
+ nodes=node_count,
+ edges=edge_count,
+ walk_edges=walk_edge_count,
+ drive_edges=drive_edge_count,
+ skipped_ways=skipped_way_count,
+ ).as_dict()
+ _emit(progress_callback, "routing_layer_import_completed", "Routing graph import completed.", edge_count, edge_count, result)
+ return result
+
+
+def _drop_routing_geometry_indexes(session: Session) -> None:
+    """Drop the routing geometry indexes if they exist; the caller commits."""
+    statements = (
+        "DROP INDEX IF EXISTS ix_routing_nodes_geom_gist",
+        "DROP INDEX IF EXISTS ix_routing_edges_geom_gist",
+        "DROP INDEX IF EXISTS ix_routing_edges_bbox_box_gist",
+    )
+    for statement in statements:
+        session.execute(text(statement))
+
+
+def _create_routing_geometry_indexes(session: Session) -> None:
+    """Create the node geometry index and the edge bounding-box index if missing; the caller commits."""
+    statements = (
+        "CREATE INDEX IF NOT EXISTS ix_routing_nodes_geom_gist ON routing_nodes USING GIST (geom)",
+        "CREATE INDEX IF NOT EXISTS ix_routing_edges_bbox_box_gist ON routing_edges USING GIST (box(point(max_lon, max_lat), point(min_lon, min_lat)))",
+    )
+    for statement in statements:
+        session.execute(text(statement))
+
+
+# osmium handler that turns routable highway ways into one routing edge per
+# pair of consecutive way nodes, plus the nodes at both ends, and writes them
+# to the database in batches.
+class _RoutingGraphHandler(osmium.SimpleHandler):
+ def __init__(
+ self,
+ *,
+ session: Session,
+ dataset_id: int,
+ batch_size: int,
+ progress_callback: ProgressCallback | None,
+ ) -> None:
+ super().__init__()
+ self.session = session
+ self.dataset_id = dataset_id
+ # Batches never go below 500 edges, whatever the caller asks for.
+ self.batch_size = max(500, int(batch_size))
+ self.progress_callback = progress_callback
+ # Pending nodes are keyed by OSM node id, so a node shared by several
+ # segments is queued only once per batch.
+ self.nodes: dict[int, dict[str, object]] = {}
+ self.edges: list[dict[str, object]] = []
+ # Node and edge counters start from what is already stored for the dataset.
+ self.node_count = int(
+ session.scalar(select(func.count()).select_from(RoutingNode).where(RoutingNode.dataset_id == dataset_id)) or 0
+ )
+ self.edge_count = int(
+ session.scalar(select(func.count()).select_from(RoutingEdge).where(RoutingEdge.dataset_id == dataset_id)) or 0
+ )
+ self.walk_edge_count = 0
+ self.drive_edge_count = 0
+ self.skipped_way_count = 0
+ self.processed_way_count = 0
+
+ # Called by osmium for every way in the file.
+ def way(self, way) -> None:
+ tags = {tag.k: tag.v for tag in way.tags}
+ highway = tags.get("highway")
+ # Ways without a usable highway tag, or that nobody may walk or drive on,
+ # only increase the skipped counter.
+ if not highway or highway in EXCLUDED_HIGHWAYS:
+ self.skipped_way_count += 1
+ return
+ walkable = _walkable(tags, highway)
+ drivable = _drivable(tags, highway)
+ if not walkable and not drivable:
+ self.skipped_way_count += 1
+ return
+
+ # Node refs without a valid location are dropped, so their neighbours
+ # end up joined directly by one segment.
+ nodes = []
+ for node in way.nodes:
+ if not node.location.valid():
+ continue
+ nodes.append((int(node.ref), float(node.location.lon), float(node.location.lat)))
+ if len(nodes) < 2:
+ self.skipped_way_count += 1
+ return
+
+ oneway = _oneway_direction(tags, highway)
+ drive_speed_mps = _drive_speed_mps(tags, highway)
+ walk_speed_mps = STEP_WALK_SPEED_MPS if highway == "steps" else DEFAULT_WALK_SPEED_MPS
+ for left, right in zip(nodes, nodes[1:]):
+ source_id, source_lon, source_lat = left
+ target_id, target_lon, target_lat = right
+ if source_id == target_id:
+ continue
+ length_m = _distance_m(source_lat, source_lon, target_lat, target_lon)
+ if length_m <= 0:
+ continue
+ # For a reverse one-way the segment is flipped, so that source -> target
+ # is the direction in which driving is allowed.
+ if oneway == "reverse":
+ source_id, target_id = target_id, source_id
+ source_lon, target_lon = target_lon, source_lon
+ source_lat, target_lat = target_lat, source_lat
+
+ walk_cost = length_m / walk_speed_mps if walkable else None
+ drive_cost = length_m / drive_speed_mps if drivable and drive_speed_mps > 0 else None
+ # Walking costs the same both ways; driving against a one-way segment
+ # gets no cost at all.
+ reverse_walk_cost = walk_cost
+ reverse_drive_cost = None if oneway in {"forward", "reverse"} else drive_cost
+ self.nodes[source_id] = {"dataset_id": self.dataset_id, "osm_node_id": source_id, "lon": source_lon, "lat": source_lat}
+ self.nodes[target_id] = {"dataset_id": self.dataset_id, "osm_node_id": target_id, "lon": target_lon, "lat": target_lat}
+ self.edges.append(
+ {
+ "dataset_id": self.dataset_id,
+ "osm_way_id": int(way.id),
+ "source_osm_node_id": source_id,
+ "target_osm_node_id": target_id,
+ "source_lon": source_lon,
+ "source_lat": source_lat,
+ "target_lon": target_lon,
+ "target_lat": target_lat,
+ "highway": highway,
+ "name": tags.get("name"),
+ "length_m": length_m,
+ "walk_cost_s": walk_cost,
+ "reverse_walk_cost_s": reverse_walk_cost,
+ "drive_cost_s": drive_cost,
+ "reverse_drive_cost_s": reverse_drive_cost,
+ "geometry_geojson": json.dumps({"type": "LineString", "coordinates": [[source_lon, source_lat], [target_lon, target_lat]]}, separators=(",", ":")),
+ "min_lon": min(source_lon, target_lon),
+ "min_lat": min(source_lat, target_lat),
+ "max_lon": max(source_lon, target_lon),
+ "max_lat": max(source_lat, target_lat),
+ "tags_json": _routing_tags_json(tags),
+ }
+ )
+ self.edge_count += 1
+ if walk_cost is not None:
+ self.walk_edge_count += 1
+ if drive_cost is not None:
+ self.drive_edge_count += 1
+
+ self.processed_way_count += 1
+ # The batch size is checked once per way, so a batch can exceed
+ # batch_size by the segments of the last way added.
+ if len(self.edges) >= self.batch_size:
+ self.flush()
+ if self.processed_way_count % 100_000 == 0:
+ _emit(
+ self.progress_callback,
+ "routing_layer_import_batch",
+ f"Imported {self.edge_count:,} routing edges.",
+ self.edge_count,
+ None,
+ {"processed_ways": self.processed_way_count, "nodes_pending": len(self.nodes), "edges": self.edge_count},
+ )
+
+ # Write the pending nodes and edges to the database and commit.
+ def flush(self) -> None:
+ if not self.nodes and not self.edges:
+ return
+ node_rows = list(self.nodes.values())
+ edge_rows = self.edges
+ if node_rows:
+ # Nodes already stored for this dataset (same dataset_id and
+ # osm_node_id) are skipped by the database without an error.
+ stmt = postgresql_insert(RoutingNode).values(node_rows)
+ stmt = stmt.on_conflict_do_nothing(index_elements=["dataset_id", "osm_node_id"])
+ self.session.execute(stmt)
+ # NOTE(review): this adds the number of rows submitted, not the number
+ # actually inserted, so node_count can overcount when conflicts are skipped.
+ self.node_count += len(node_rows)
+ self.nodes.clear()
+ if edge_rows:
+ self.session.bulk_insert_mappings(RoutingEdge, edge_rows)
+ self.edges = []
+ self.session.commit()
+
+
+def _walkable(tags: dict[str, str], highway: str) -> bool:
+    """Decide whether pedestrians may use a way.
+
+    Args:
+        tags: OSM tags of the way.
+        highway: The way's ``highway`` value.
+
+    Returns:
+        False when ``foot`` is one of ``NO_VALUES``. Motorways, trunk roads
+        and their links are walkable only when ``foot`` is one of
+        ``YES_VALUES``. Any other class must be listed in ``WALK_HIGHWAYS``
+        and must not be closed by ``access`` unless ``foot`` allows it.
+    """
+    foot = _tag_value(tags, "foot")
+    if foot in NO_VALUES:
+        return False
+    if highway in {"motorway", "motorway_link", "trunk", "trunk_link"}:
+        # This rule must run before the WALK_HIGHWAYS test: these classes are
+        # not in that set, so checking membership first made it unreachable
+        # and an explicit foot=yes on a trunk road was ignored. An explicit
+        # foot permission also overrides a general access restriction.
+        return foot in YES_VALUES
+    if highway not in WALK_HIGHWAYS:
+        return False
+    access = _tag_value(tags, "access")
+    if access in NO_VALUES and foot not in YES_VALUES:
+        return False
+    return True
+
+
+def _drivable(tags: dict[str, str], highway: str) -> bool:
+    """Decide whether motor vehicles may use a way.
+
+    Args:
+        tags: OSM tags of the way.
+        highway: The way's ``highway`` value.
+
+    Returns:
+        False when ``motorcar``, ``motor_vehicle`` or ``vehicle`` is one of
+        ``NO_VALUES``, or when ``access`` is and neither ``motorcar`` nor
+        ``motor_vehicle`` is one of ``YES_VALUES``. Footways, paths,
+        pedestrian ways, steps, cycleways and bridleways are drivable only
+        with such an explicit permission. Every other class must be listed in
+        ``DRIVE_HIGHWAYS``.
+    """
+    non_motor_classes = {"footway", "path", "pedestrian", "steps", "cycleway", "bridleway"}
+    # The non-motor classes are not in DRIVE_HIGHWAYS, so rejecting on that
+    # set alone made the explicit-permission rule below unreachable.
+    if highway not in DRIVE_HIGHWAYS and highway not in non_motor_classes:
+        return False
+    access = _tag_value(tags, "access")
+    motor_vehicle = _tag_value(tags, "motor_vehicle")
+    motorcar = _tag_value(tags, "motorcar")
+    vehicle = _tag_value(tags, "vehicle")
+    if motorcar in NO_VALUES or motor_vehicle in NO_VALUES or vehicle in NO_VALUES:
+        return False
+    if access in NO_VALUES and motorcar not in YES_VALUES and motor_vehicle not in YES_VALUES:
+        return False
+    if highway in non_motor_classes:
+        return motorcar in YES_VALUES or motor_vehicle in YES_VALUES
+    return True
+
+
+def _oneway_direction(tags: dict[str, str], highway: str) -> str:
+    """Classify the allowed driving direction of a way.
+
+    Args:
+        tags: OSM tags of the way.
+        highway: The way's ``highway`` value.
+
+    Returns:
+        ``"reverse"`` or ``"forward"`` for an explicit one-way tag,
+        ``"forward"`` for roundabouts and motorways without an explicit tag,
+        and ``"both"`` otherwise.
+    """
+    oneway = _tag_value(tags, "oneway")
+    if oneway in ONEWAY_REVERSE:
+        return "reverse"
+    if oneway in ONEWAY_FORWARD:
+        return "forward"
+    # An explicit "not one-way" overrides the default implied by a roundabout
+    # or a motorway; without this check such ways were forced one-way.
+    if oneway in {"no", "false", "0"}:
+        return "both"
+    if tags.get("junction") == "roundabout" or highway == "motorway":
+        return "forward"
+    return "both"
+
+
+def _drive_speed_mps(tags: dict[str, str], highway: str) -> float:
+    """Return the assumed driving speed on a way in metres per second.
+
+    The ``maxspeed`` tag is used when it parses to a non-zero value; otherwise
+    the class default from ``DEFAULT_DRIVE_SPEED_KMH`` applies, or 30 km/h for
+    a class without a default. The speed never goes below 5 km/h.
+    """
+    maxspeed = _parse_maxspeed(tags.get("maxspeed"))
+    kmh = maxspeed or DEFAULT_DRIVE_SPEED_KMH.get(highway, 30)
+    # Apply the floor in km/h, before converting. Comparing 5.0 with the m/s
+    # value put the floor at 18 km/h, which overrode the 10 and 15 km/h
+    # defaults for living streets, service roads and tracks.
+    return max(5.0, float(kmh)) / 3.6
+
+
+def _parse_maxspeed(value: str | None) -> float | None:
+    """Parse an OSM ``maxspeed`` value into km/h.
+
+    Returns None for missing values, for the keywords ``signals``, ``none``,
+    ``walk`` and ``variable``, and for values without a number. Values ending
+    in ``mph`` are converted from miles per hour.
+    """
+    if not value:
+        return None
+    cleaned = value.strip().lower()
+    non_numeric = {"signals", "none", "walk", "variable"}
+    if cleaned in non_numeric:
+        return None
+    if not cleaned.endswith("mph"):
+        return _leading_float(cleaned)
+    miles = _leading_float(cleaned[:-3])
+    if miles is None:
+        return None
+    return miles * 1.60934
+
+
+def _leading_float(value: str) -> float | None:
+    """Parse the first run of digits and dots found in ``value``.
+
+    Characters before the first digit or dot are skipped and the run ends at
+    the next other character. Returns None when there is no such run or when
+    the run is not a valid number (for example ``1.2.3``).
+    """
+    stripped = value.strip()
+
+    def is_numeric(char: str) -> bool:
+        return char.isdigit() or char == "."
+
+    start = next((index for index, char in enumerate(stripped) if is_numeric(char)), None)
+    if start is None:
+        return None
+    end = start
+    while end < len(stripped) and is_numeric(stripped[end]):
+        end += 1
+    try:
+        return float(stripped[start:end])
+    except ValueError:
+        return None
+
+
+def _routing_tags_json(tags: dict[str, str]) -> str:
+    """Serialise the routing-relevant subset of a way's tags as compact JSON, keeping the tags' order."""
+    kept_keys = {"access", "bicycle", "bridge", "foot", "highway", "junction", "maxspeed", "motor_vehicle", "motorcar", "name", "oneway", "service", "surface", "tunnel", "vehicle"}
+    selected = {}
+    for key, value in tags.items():
+        if key in kept_keys:
+            selected[key] = value
+    return json.dumps(selected, separators=(",", ":"))
+
+
+def _tag_value(tags: dict[str, str], key: str) -> str:
+    """Return the tag's value trimmed and lower-cased, or an empty string when it is missing or empty."""
+    raw = tags.get(key)
+    if not raw:
+        return ""
+    return str(raw).strip().lower()
+
+
+def _distance_m(lat_a: float, lon_a: float, lat_b: float, lon_b: float) -> float:
+    """Return the haversine distance in metres between two lat/lon points given in degrees."""
+    earth_radius_m = 6_371_000.0
+    lat_a_rad, lat_b_rad = math.radians(lat_a), math.radians(lat_b)
+    half_dlat = math.radians(lat_b - lat_a) / 2
+    half_dlon = math.radians(lon_b - lon_a) / 2
+    hav = math.sin(half_dlat) ** 2 + math.cos(lat_a_rad) * math.cos(lat_b_rad) * math.sin(half_dlon) ** 2
+    central_angle = 2 * math.atan2(math.sqrt(hav), math.sqrt(1 - hav))
+    return earth_radius_m * central_angle
+
+
+def _metadata(dataset: Dataset) -> dict[str, object]:
+    """Parse the dataset's ``metadata_json``; anything that is not a JSON object yields an empty mapping."""
+    raw = dataset.metadata_json or "{}"
+    try:
+        parsed = json.loads(raw)
+    except json.JSONDecodeError:
+        return {}
+    if isinstance(parsed, dict):
+        return parsed
+    return {}
+
+
+def _emit(
+    progress_callback: ProgressCallback | None,
+    event_type: str,
+    message: str,
+    progress_current: int | None,
+    progress_total: int | None,
+    metadata: dict[str, object] | None = None,
+) -> None:
+    """Forward a progress event to the callback; do nothing when no callback is set."""
+    if progress_callback is None:
+        return
+    progress_callback(event_type, message, progress_current, progress_total, metadata)
diff --git a/app/pipeline/run.py b/app/pipeline/run.py
new file mode 100644
index 0000000..2f0e57e
--- /dev/null
+++ b/app/pipeline/run.py
@@ -0,0 +1,40 @@
+from __future__ import annotations
+
+from datetime import datetime, timezone
+from typing import Callable, Any
+
+from sqlalchemy.orm import Session
+
+from app.models import Source
+from app.pipeline.gtfs import run_gtfs_source
+from app.pipeline.osm_diff import run_osm_diff_source
+from app.pipeline.osm_geojson import run_osm_geojson_source
+from app.pipeline.osm_pbf import run_osm_pbf_source
+
+
+# Signature of the optional progress hook that run_source hands on to the
+# GTFS and OSM PBF importers: four positional values (two strings, two
+# optional ints) followed by an optional metadata mapping.
+ProgressCallback = Callable[[str, str, int | None, int | None, dict[str, Any] | None], None]
+
+
+def run_source(session: Session, source: Source, progress_callback: ProgressCallback | None = None):
+    """Run the importer that matches the source's kind and record the outcome on the source.
+
+    The source is marked ``running`` and flushed before the import starts. On
+    success it is marked ``ok`` and the importer's dataset is returned. On any
+    failure it is marked ``error`` with the exception text and the exception
+    is re-raised. Only the GTFS and OSM PBF importers receive the progress
+    callback.
+
+    Raises:
+        ValueError: If the source kind has no importer.
+    """
+    source.status = "running"
+    source.last_run_at = datetime.now(timezone.utc)
+    source.last_error = None
+    session.flush()
+
+    # Lambdas keep each importer call lazy, so only the selected one runs.
+    runners = {
+        "gtfs": lambda: run_gtfs_source(session, source, progress_callback=progress_callback),
+        "osm_geojson": lambda: run_osm_geojson_source(session, source),
+        "osm_pbf": lambda: run_osm_pbf_source(session, source, progress_callback=progress_callback),
+        "osm_diff": lambda: run_osm_diff_source(session, source),
+    }
+    try:
+        runner = runners.get(source.kind)
+        if runner is None:
+            raise ValueError(f"Unsupported source kind: {source.kind}")
+        dataset = runner()
+        source.status = "ok"
+        source.last_error = None
+        return dataset
+    except Exception as exc:  # noqa: BLE001 - persist pipeline error for UI
+        source.status = "error"
+        source.last_error = str(exc)
+        raise
diff --git a/app/pipeline/sample_data.py b/app/pipeline/sample_data.py
new file mode 100644
index 0000000..2f8dbe4
--- /dev/null
+++ b/app/pipeline/sample_data.py
@@ -0,0 +1,294 @@
+from __future__ import annotations
+
+import csv
+import io
+import json
+import zipfile
+from pathlib import Path
+from datetime import datetime, timezone
+
+from sqlalchemy import delete, select
+from sqlalchemy.orm import Session
+
+from app.config import settings
+from app.db import init_db
+from app.models import (
+ Dataset,
+ CanonicalStop,
+ CanonicalStopLink,
+ GtfsAgency,
+ GtfsCalendar,
+ GtfsCalendarDate,
+ GtfsRoute,
+ GtfsRoutePatternLink,
+ GtfsShape,
+ GtfsStop,
+ GtfsStopTime,
+ GtfsTripRoutePatternLink,
+ GtfsTrip,
+ Itinerary,
+ ItineraryLeg,
+ Job,
+ JobEvent,
+ MatchRule,
+ OsmDiffState,
+ OsmFeature,
+ PipelineRun,
+ RouteMatch,
+ RoutePattern,
+ RoutePatternStop,
+ RoutingEdge,
+ RoutingNode,
+ Source,
+ SourceCatalogEntry,
+ SourceUpdateCheck,
+ TravelRequest,
+)
+from app.pipeline.matcher import run_route_matching
+from app.pipeline.route_layer import rebuild_route_layer
+from app.pipeline.run import run_source
+
+
def load_sample_project(session: Session, *, preserve_job_id: int | None = None) -> dict:
    """Reset the project and rebuild it from a generated Berlin-like sample.

    Steps, in order: initialise the DB, clear project data (the source catalog
    is always preserved; ``preserve_job_id`` is forwarded to the reset), write
    the sample GTFS zip and OSM GeoJSON under ``<data_dir>/sample``, register
    one source per file, import both, run route matching, and rebuild the
    route layer.

    Returns a dict with ``status``, the two dataset ids, and the results of
    the matching and route-layer steps.
    """
    init_db()
    clear_project_data(session, preserve_job_id=preserve_job_id, preserve_catalog=True)

    target_dir = settings.data_dir / "sample"
    target_dir.mkdir(parents=True, exist_ok=True)
    gtfs_file = target_dir / "sample_berlin.gtfs.zip"
    osm_file = target_dir / "sample_berlin_osm.geojson"
    create_sample_gtfs(gtfs_file)
    create_sample_osm_geojson(osm_file)

    # (name, kind, file) for the two sample sources; both share country and license.
    source_specs = (
        ("Sample Berlin GTFS", "gtfs", gtfs_file),
        ("Sample Berlin OSM transport", "osm_geojson", osm_file),
    )
    gtfs_source, osm_source = [
        Source(name=name, kind=kind, url=str(file), country="DE", license="sample")
        for name, kind, file in source_specs
    ]
    session.add_all([gtfs_source, osm_source])
    session.flush()

    imported_gtfs = run_source(session, gtfs_source)
    imported_osm = run_source(session, osm_source)
    matching = run_route_matching(session)
    route_layer = rebuild_route_layer(session)

    summary: dict = {"status": "ok"}
    summary["gtfs_dataset_id"] = imported_gtfs.id
    summary["osm_dataset_id"] = imported_osm.id
    summary["match_result"] = matching
    summary["route_layer_result"] = route_layer
    return summary
+
+
def clear_project_data(
    session: Session,
    *,
    preserve_job_id: int | None = None,
    preserve_catalog: bool = True,
) -> None:
    """Delete project data, optionally keeping one queue job and the source catalog.

    Pipeline runs are always deleted. When ``preserve_job_id`` is given, jobs
    and job events are kept and the other active jobs are cancelled; otherwise
    all job events and jobs are deleted. The remaining tables are then emptied
    in the fixed order below, the catalog is deleted only when
    ``preserve_catalog`` is false, and the session is flushed.
    """
    session.execute(delete(PipelineRun))

    if preserve_job_id is not None:
        _cancel_other_jobs_for_reset(session, preserve_job_id)
    else:
        session.execute(delete(JobEvent))
        session.execute(delete(Job))

    # Deletion order is kept exactly as given.
    # NOTE(review): looks like dependents-before-parents so foreign keys hold — confirm against the models.
    deletion_order = (
        ItineraryLeg,
        Itinerary,
        TravelRequest,
        SourceUpdateCheck,
        OsmDiffState,
        MatchRule,
        RouteMatch,
        GtfsTripRoutePatternLink,
        GtfsRoutePatternLink,
        RoutePatternStop,
        RoutePattern,
        CanonicalStopLink,
        CanonicalStop,
        RoutingEdge,
        RoutingNode,
        GtfsStopTime,
        GtfsCalendarDate,
        GtfsCalendar,
        GtfsShape,
        GtfsTrip,
        GtfsRoute,
        GtfsStop,
        GtfsAgency,
        OsmFeature,
        Dataset,
        Source,
    )
    for table_model in deletion_order:
        session.execute(delete(table_model))

    if not preserve_catalog:
        session.execute(delete(SourceCatalogEntry))
    session.flush()
+
+
def _cancel_other_jobs_for_reset(session: Session, preserve_job_id: int) -> None:
    """Cancel every queued, running or paused job except ``preserve_job_id``.

    Each cancelled job has its requested action, lease, pause timestamp and
    error cleared, gets the same ``updated_at``/``finished_at`` timestamp, and
    receives a ``cancelled_by_reset`` job event carrying its current progress.
    """
    cancelled_at = datetime.now(timezone.utc)
    still_active = select(Job).where(
        Job.id != preserve_job_id,
        Job.status.in_({"queued", "running", "paused"}),
    )
    reset_fields = ("requested_action", "lease_owner", "lease_expires_at", "paused_at", "error")

    for other in session.scalars(still_active).all():
        other.status = "cancelled"
        for field_name in reset_fields:
            setattr(other, field_name, None)
        other.updated_at = cancelled_at
        other.finished_at = cancelled_at
        event = JobEvent(
            job_id=other.id,
            event_type="cancelled_by_reset",
            message=f"Job cancelled by reset job #{preserve_job_id}.",
            progress_current=other.progress_current,
            progress_total=other.progress_total,
        )
        session.add(event)
+
+
def create_sample_gtfs(path: Path) -> None:
    """Write a small sample GTFS feed to ``path`` as a deflate-compressed zip.

    The archive contains agency, stops, routes, trips, stop_times, calendar
    and shapes tables. There is one trip per route; each trip's shape is the
    sequence of its stop coordinates. All values are written as strings.
    """
    agency_fields = ["agency_id", "agency_name", "agency_url", "agency_timezone"]
    stop_fields = ["stop_id", "stop_name", "stop_lat", "stop_lon"]
    route_fields = ["route_id", "agency_id", "route_short_name", "route_long_name", "route_type"]
    trip_fields = ["route_id", "service_id", "trip_id", "shape_id"]
    stop_time_fields = ["trip_id", "arrival_time", "departure_time", "stop_id", "stop_sequence"]
    weekdays = ["monday", "tuesday", "wednesday", "thursday", "friday", "saturday", "sunday"]
    calendar_fields = ["service_id", *weekdays, "start_date", "end_date"]
    shape_fields = ["shape_id", "shape_pt_lat", "shape_pt_lon", "shape_pt_sequence"]

    agencies = [
        dict(zip(agency_fields, values))
        for values in (
            ("BVG", "BVG", "https://example.invalid/bvg", "Europe/Berlin"),
            ("DB", "DB Regio", "https://example.invalid/db", "Europe/Berlin"),
            ("XAIR", "Example Airport Shuttle", "https://example.invalid/xair", "Europe/Berlin"),
        )
    ]
    stops = [
        dict(zip(stop_fields, values))
        for values in (
            ("hbf", "Berlin Hauptbahnhof", "52.5251", "13.3696"),
            ("friedrich", "Friedrichstraße", "52.5201", "13.3862"),
            ("alex", "Alexanderplatz", "52.5219", "13.4132"),
            ("ost", "Ostbahnhof", "52.5100", "13.4344"),
            ("zoo", "Zoologischer Garten", "52.5069", "13.3320"),
            ("wittenberg", "Wittenbergplatz", "52.5020", "13.3430"),
            ("potsdamer", "Potsdamer Platz", "52.5096", "13.3760"),
            ("stadtmitte", "Stadtmitte", "52.5113", "13.3907"),
            ("reichstag", "Reichstag", "52.5186", "13.3763"),
            ("hackescher", "Hackescher Markt", "52.5220", "13.4023"),
            ("naturkunde", "Naturkundemuseum", "52.5300", "13.3790"),
            ("wannsee", "Wannsee", "52.4210", "13.1797"),
            ("kladow", "Kladow", "52.4547", "13.1439"),
            ("airport", "Example Airport Terminal", "52.3650", "13.5100"),
        )
    ]
    routes = [
        dict(zip(route_fields, values))
        for values in (
            ("u2", "BVG", "U2", "Pankow - Ruhleben", "1"),
            ("re1", "DB", "RE1", "Magdeburg - Frankfurt Oder", "2"),
            ("m5", "BVG", "M5", "Hauptbahnhof - Hohenschönhausen", "0"),
            ("bus100", "BVG", "100", "Zoologischer Garten - Alexanderplatz", "3"),
            ("f10", "BVG", "F10", "Wannsee - Kladow", "4"),
            ("x99", "XAIR", "X99", "Airport Express Sample", "3"),
        )
    ]

    # One trip per route; trip and shape ids are derived from the route id.
    trips = []
    for route in routes:
        route_id = route["route_id"]
        trips.append(
            {
                "route_id": route_id,
                "service_id": "daily",
                "trip_id": f"{route_id}_trip",
                "shape_id": f"{route_id}_shape",
            }
        )

    stop_sequences = {
        "u2_trip": ["zoo", "wittenberg", "potsdamer", "stadtmitte", "alex"],
        "re1_trip": ["hbf", "friedrich", "alex", "ost"],
        "m5_trip": ["hbf", "naturkunde", "hackescher", "alex"],
        "bus100_trip": ["zoo", "reichstag", "alex"],
        "f10_trip": ["wannsee", "kladow"],
        "x99_trip": ["alex", "airport"],
    }
    lon_lat_by_stop = {stop["stop_id"]: (stop["stop_lon"], stop["stop_lat"]) for stop in stops}

    stop_times = []
    shapes = []
    for trip in trips:
        trip_id = trip["trip_id"]
        shape_id = trip["shape_id"]
        for sequence, stop_id in enumerate(stop_sequences[trip_id], start=1):
            # Stops are five minutes apart, with a one-minute dwell at each.
            minute = sequence * 5
            stop_times.append(
                {
                    "trip_id": trip_id,
                    "arrival_time": f"08:{minute:02d}:00",
                    "departure_time": f"08:{minute + 1:02d}:00",
                    "stop_id": stop_id,
                    "stop_sequence": str(sequence),
                }
            )
            stop_lon, stop_lat = lon_lat_by_stop[stop_id]
            shapes.append(
                {
                    "shape_id": shape_id,
                    "shape_pt_lat": stop_lat,
                    "shape_pt_lon": stop_lon,
                    "shape_pt_sequence": str(sequence),
                }
            )

    daily_service = {"service_id": "daily", "start_date": "20260101", "end_date": "20261231"}
    daily_service.update({day: "1" for day in weekdays})
    calendar = [daily_service]

    tables = [
        ("agency.txt", agency_fields, agencies),
        ("stops.txt", stop_fields, stops),
        ("routes.txt", route_fields, routes),
        ("trips.txt", trip_fields, trips),
        ("stop_times.txt", stop_time_fields, stop_times),
        ("calendar.txt", calendar_fields, calendar),
        ("shapes.txt", shape_fields, shapes),
    ]
    with zipfile.ZipFile(path, "w", compression=zipfile.ZIP_DEFLATED) as archive:
        for member_name, header, records in tables:
            _write_csv(archive, member_name, header, records)
+
+
+def _write_csv(zf: zipfile.ZipFile, name: str, fields: list[str], rows: list[dict[str, str]]) -> None:
+ buffer = io.StringIO(newline="")
+ writer = csv.DictWriter(buffer, fieldnames=fields)
+ writer.writeheader()
+ writer.writerows(rows)
+ zf.writestr(name, buffer.getvalue())
+
+
def create_sample_osm_geojson(path: Path) -> None:
    """Write the sample OSM transport layer to ``path`` as an indented GeoJSON FeatureCollection.

    The collection holds six route relations (LineString features) followed by
    five station/terminal nodes (Point features). Coordinates are ``[lon, lat]``.
    """

    def route_feature(fid, mode, ref, name, operator, coords):
        # Routes operated by "BVG" are tagged with network "VBB"; every other operator gets "DB".
        network = "VBB" if operator == "BVG" else "DB"
        properties = {
            "osm_type": "relation",
            "osm_id": str(fid),
            "type": "route",
            "route": mode,
            "ref": ref,
            "name": name,
            "operator": operator,
            "network": network,
        }
        geometry = {"type": "LineString", "coordinates": coords}
        return {"type": "Feature", "geometry": geometry, "properties": properties}

    def node_feature(fid, kind, name, coords, props=None):
        # ``kind`` is accepted at the call sites but is not written into the feature.
        properties = {**(props or {}), "osm_type": "node", "osm_id": str(fid), "name": name}
        geometry = {"type": "Point", "coordinates": coords}
        return {"type": "Feature", "geometry": geometry, "properties": properties}

    route_features = [
        route_feature(1002, "subway", "U2", "U2 Ruhleben - Pankow", "BVG", [[13.3320, 52.5069], [13.3430, 52.5020], [13.3760, 52.5096], [13.3907, 52.5113], [13.4132, 52.5219]]),
        route_feature(2001, "train", "RE1", "RE1 Magdeburg - Frankfurt Oder", "DB Regio", [[13.3696, 52.5251], [13.3862, 52.5201], [13.4132, 52.5219], [13.4344, 52.5100]]),
        route_feature(5005, "tram", "M5", "M5 Hauptbahnhof - Hohenschönhausen", "BVG", [[13.3696, 52.5251], [13.3790, 52.5300], [13.4023, 52.5220], [13.4132, 52.5219]]),
        route_feature(6100, "bus", "100", "Bus 100 Zoologischer Garten - Alexanderplatz", "BVG", [[13.3320, 52.5069], [13.3763, 52.5186], [13.4132, 52.5219]]),
        route_feature(7010, "ferry", "F10", "F10 Wannsee - Kladow", "BVG", [[13.1797, 52.4210], [13.1439, 52.4547]]),
        route_feature(5010, "tram", "M10", "M10 Warschauer Straße - Hauptbahnhof", "BVG", [[13.4500, 52.5050], [13.4020, 52.5300], [13.3696, 52.5251]]),
    ]
    node_features = [
        node_feature(9001, "station", "Berlin Hauptbahnhof", [13.3696, 52.5251], {"railway": "station"}),
        node_feature(9002, "station", "Alexanderplatz", [13.4132, 52.5219], {"railway": "station"}),
        node_feature(9003, "stop", "Zoologischer Garten", [13.3320, 52.5069], {"public_transport": "station", "railway": "station"}),
        node_feature(9004, "terminal", "Wannsee Ferry Terminal", [13.1797, 52.4210], {"amenity": "ferry_terminal"}),
        node_feature(9005, "terminal", "Kladow Ferry Terminal", [13.1439, 52.4547], {"amenity": "ferry_terminal"}),
    ]

    collection = {"type": "FeatureCollection", "features": route_features + node_features}
    path.write_text(json.dumps(collection, indent=2), encoding="utf-8")
diff --git a/app/pipeline/state.py b/app/pipeline/state.py
new file mode 100644
index 0000000..5c7e865
--- /dev/null
+++ b/app/pipeline/state.py
@@ -0,0 +1,135 @@
+from __future__ import annotations
+
+from datetime import datetime, timezone
+import hashlib
+import json
+from typing import Any
+
+from sqlalchemy import select
+from sqlalchemy.orm import Session
+
+from app.models import PipelineRun
+
+
# Pipeline stage identifiers.
# NOTE(review): presumably these are the values passed as ``stage`` to the
# pipeline-run helpers in this module — confirm at the call sites.
STAGE_ACQUIRE_RAW = "acquire_raw"
STAGE_FILTER_TRANSPORT = "filter_transport"
STAGE_EXTRACT_GEOMETRY = "extract_geometry"
STAGE_LABEL_FEATURES = "label_features"
STAGE_BUILD_INDEXES = "build_indexes"
STAGE_MATCH_ROUTES = "match_routes"
STAGE_BUILD_ROUTE_LAYER = "build_route_layer"
+
+
def stable_json(value: Any) -> str:
    """Serialize ``value`` to compact JSON with keys sorted.

    Values the JSON encoder cannot handle natively are converted with ``str``.
    """
    return json.dumps(
        value,
        default=str,
        separators=(",", ":"),
        sort_keys=True,
    )
+
+
def dependency_hash(value: Any) -> str:
    """Return the SHA-256 hex digest of ``stable_json(value)`` encoded as UTF-8."""
    canonical_bytes = stable_json(value).encode("utf-8")
    digest = hashlib.sha256()
    digest.update(canonical_bytes)
    return digest.hexdigest()
+
+
def latest_completed_run(
    session: Session,
    *,
    stage: str,
    version: str,
    dependency_hash_value: str,
    source_id: int | None = None,
    dataset_id: int | None = None,
) -> PipelineRun | None:
    """Return the newest completed run matching the given key, or ``None``.

    A run matches when stage, version and dependency hash are equal and its
    status is ``"completed"``. ``source_id`` and ``dataset_id`` are matched
    exactly: ``None`` selects only rows where the column is NULL. "Newest" is
    ordered by ``finished_at`` descending, then ``id`` descending.
    """
    if source_id is None:
        source_filter = PipelineRun.source_id.is_(None)
    else:
        source_filter = PipelineRun.source_id == source_id

    if dataset_id is None:
        dataset_filter = PipelineRun.dataset_id.is_(None)
    else:
        dataset_filter = PipelineRun.dataset_id == dataset_id

    query = (
        select(PipelineRun)
        .where(
            PipelineRun.stage == stage,
            PipelineRun.version == version,
            PipelineRun.dependency_hash == dependency_hash_value,
            PipelineRun.status == "completed",
        )
        .where(source_filter)
        .where(dataset_filter)
        .order_by(PipelineRun.finished_at.desc(), PipelineRun.id.desc())
        .limit(1)
    )
    return session.scalar(query)
+
+
def start_pipeline_run(
    session: Session,
    *,
    stage: str,
    version: str,
    dependency_hash_value: str,
    source_id: int | None = None,
    dataset_id: int | None = None,
    job_id: int | None = None,
    inputs: dict[str, Any] | None = None,
) -> PipelineRun:
    """Create, add and flush a ``PipelineRun`` row in status ``"running"``.

    ``started_at`` and ``updated_at`` share one UTC timestamp. ``inputs`` is
    stored as stable JSON in ``input_json``, or as ``None`` when not supplied.
    """
    created_at = datetime.now(timezone.utc)
    serialized_inputs = stable_json(inputs) if inputs is not None else None
    new_run = PipelineRun(
        stage=stage,
        version=version,
        dependency_hash=dependency_hash_value,
        status="running",
        source_id=source_id,
        dataset_id=dataset_id,
        job_id=job_id,
        input_json=serialized_inputs,
        started_at=created_at,
        updated_at=created_at,
    )
    session.add(new_run)
    session.flush()
    return new_run
+
+
def finish_pipeline_run(
    session: Session,
    run: PipelineRun,
    *,
    status: str = "completed",
    outputs: dict[str, Any] | None = None,
    error: str | None = None,
) -> PipelineRun:
    """Record the final state of ``run``, flush the session, and return ``run``.

    Sets ``status``, ``error`` and ``output_json`` (stable JSON of ``outputs``,
    or ``None`` when not supplied), and stamps ``updated_at`` and
    ``finished_at`` with the same UTC timestamp.
    """
    finished = datetime.now(timezone.utc)
    serialized_outputs = stable_json(outputs) if outputs is not None else None

    run.status = status
    run.output_json = serialized_outputs
    run.error = error
    run.updated_at = finished
    run.finished_at = finished
    session.flush()
    return run
+
+
def pipeline_run_payload(run: PipelineRun) -> dict[str, Any]:
    """Convert ``run`` into a plain dict.

    Scalar columns are copied as-is, ``input_json``/``output_json`` are decoded
    to dicts under ``"input"``/``"output"``, and the three timestamps are
    rendered with ``isoformat()`` or left as ``None`` when unset.
    """

    def iso_or_none(moment):
        return moment.isoformat() if moment else None

    copied_columns = (
        "id",
        "stage",
        "version",
        "dependency_hash",
        "status",
        "source_id",
        "dataset_id",
        "job_id",
    )
    payload: dict[str, Any] = {column: getattr(run, column) for column in copied_columns}
    payload["input"] = _json_object(run.input_json)
    payload["output"] = _json_object(run.output_json)
    payload["error"] = run.error
    payload["started_at"] = iso_or_none(run.started_at)
    payload["updated_at"] = iso_or_none(run.updated_at)
    payload["finished_at"] = iso_or_none(run.finished_at)
    return payload
+
+
+def _json_object(text: str | None) -> dict[str, Any]:
+ if not text:
+ return {}
+ try:
+ value = json.loads(text)
+ except json.JSONDecodeError:
+ return {}
+ return value if isinstance(value, dict) else {}
diff --git a/app/pipeline/utils.py b/app/pipeline/utils.py
new file mode 100644
index 0000000..da1d374
--- /dev/null
+++ b/app/pipeline/utils.py
@@ -0,0 +1,89 @@
+from __future__ import annotations
+
+import hashlib
+import json
+import re
+from pathlib import Path
+from typing import Iterable, Optional
+
+from shapely.geometry import shape
+
+
def sha256_file(path: Path) -> str:
    """Return the SHA-256 hex digest of the file at ``path``, read in 1 MiB chunks."""
    chunk_size = 1024 * 1024
    digest = hashlib.sha256()
    with path.open("rb") as stream:
        while block := stream.read(chunk_size):
            digest.update(block)
    return digest.hexdigest()
+
+
def norm_text(value: object) -> str:
    """Normalize free text to lowercase ASCII words separated by single spaces.

    ``None`` yields ``""``. The value is stringified, lowercased and stripped;
    ``ß`` becomes ``ss``; accented letters are folded to their base letter
    (``ö`` -> ``o``, ``é`` -> ``e``); every remaining run of characters outside
    ``a-z0-9`` becomes one space.

    Folding happens before the ASCII filter because the filter alone turns an
    accented letter into a space and splits the word ("Hohenschönhausen" would
    become "hohensch nhausen"). Letters with no Unicode decomposition, such as
    ``ø`` or ``ł``, and non-Latin scripts are still replaced by spaces.
    """
    if value is None:
        return ""
    import unicodedata  # local import keeps this function self-contained

    text = str(value).lower().strip().replace("ß", "ss")
    # NFKD splits "ö" into "o" + a combining mark; dropping the marks keeps the base letter.
    decomposed = unicodedata.normalize("NFKD", text)
    text = "".join(ch for ch in decomposed if not unicodedata.combining(ch))
    # Whitespace is outside a-z0-9 too, so this also collapses runs of spaces.
    text = re.sub(r"[^a-z0-9]+", " ", text)
    return text.strip()
+
+
def norm_ref(value: object) -> str:
    """Normalize a line reference: lowercase it and keep only ``a-z0-9``.

    ``None`` yields ``""``; any other value is stringified first.
    """
    if value is None:
        return ""
    lowered = str(value).lower()
    kept_pieces = re.split(r"[^a-z0-9]+", lowered)
    return "".join(kept_pieces)
+
+
def first_nonempty(*values: object) -> str:
    """Return the first argument whose stripped string form is non-empty, else ``""``.

    ``None`` arguments are skipped; the returned text is the stripped form.
    """
    stripped_candidates = (str(candidate).strip() for candidate in values if candidate is not None)
    return next((text for text in stripped_candidates if text), "")
+
+
def geometry_json_and_bbox(geometry: object) -> tuple[Optional[str], tuple[Optional[float], Optional[float], Optional[float], Optional[float]]]:
    """Return ``(geojson_text, (min_lon, min_lat, max_lon, max_lat))`` for a geometry.

    A dict input is converted with ``shape``; any other object is used
    directly and must provide ``is_empty``, ``bounds`` and
    ``__geo_interface__``. ``None``, empty geometries and any failure while
    converting all yield ``(None, (None, None, None, None))``. This is
    deliberately best-effort and never raises.
    """
    nothing = (None, (None, None, None, None))
    if geometry is None:
        return nothing
    try:
        geom = geometry if not isinstance(geometry, dict) else shape(geometry)
        if geom.is_empty:
            return nothing
        west, south, east, north = geom.bounds
        encoded = json.dumps(geom.__geo_interface__, separators=(",", ":"))
        return encoded, (west, south, east, north)
    except Exception:
        return nothing
+
+
def bbox_overlap(a: tuple[float | None, float | None, float | None, float | None], b: tuple[float | None, float | None, float | None, float | None]) -> bool:
    """Report whether two ``(minx, miny, maxx, maxy)`` boxes intersect.

    Boxes that merely touch count as overlapping. If any bound of either box
    is ``None`` the result is ``False``.
    """
    for bound in (*a, *b):
        if bound is None:
            return False
    a_minx, a_miny, a_maxx, a_maxy = a  # type: ignore[misc]
    b_minx, b_miny, b_maxx, b_maxy = b  # type: ignore[misc]
    apart_in_x = a_maxx < b_minx or b_maxx < a_minx
    apart_in_y = a_maxy < b_miny or b_maxy < a_miny
    return not (apart_in_x or apart_in_y)
+
+
def bbox_center(b: tuple[float | None, float | None, float | None, float | None]) -> Optional[tuple[float, float]]:
    """Return the midpoint ``(x, y)`` of a ``(minx, miny, maxx, maxy)`` box, or ``None`` if any bound is ``None``."""
    for bound in b:
        if bound is None:
            return None
    left, bottom, right, top = b  # type: ignore[misc]
    mid_x = (left + right) / 2
    mid_y = (bottom + top) / 2
    return (mid_x, mid_y)
+
+
def approx_bbox_center_distance_deg(a: tuple[float | None, float | None, float | None, float | None], b: tuple[float | None, float | None, float | None, float | None]) -> Optional[float]:
    """Return the straight-line distance between the centers of two boxes.

    The result is in the same units as the box coordinates (the name suggests
    degrees). Returns ``None`` when either box has no center.
    """
    center_a, center_b = bbox_center(a), bbox_center(b)
    if center_a is None or center_b is None:
        return None
    delta_x = center_a[0] - center_b[0]
    delta_y = center_a[1] - center_b[1]
    return (delta_x ** 2 + delta_y ** 2) ** 0.5
+
+
def batched(iterable: Iterable[dict], batch_size: int = 1000) -> Iterable[list[dict]]:
    """Yield consecutive lists of up to ``batch_size`` items from ``iterable``.

    The final list may be shorter; nothing is yielded for an empty input. A
    fresh list is created for every batch.
    """
    pending: list[dict] = []
    for element in iterable:
        pending.append(element)
        if len(pending) < batch_size:
            continue
        yield pending
        pending = []
    if pending:
        yield pending
diff --git a/app/qa.py b/app/qa.py
new file mode 100644
index 0000000..1f04b94
--- /dev/null
+++ b/app/qa.py
@@ -0,0 +1,393 @@
+from __future__ import annotations
+
+import json
+from datetime import datetime, timezone
+from typing import Any
+
+from sqlalchemy import func, select
+from sqlalchemy.orm import Session
+
+from app.gtfs_storage import missing_sidecar_paths as missing_gtfs_sidecar_paths
+from app.models import (
+ CanonicalStop,
+ CanonicalStopLink,
+ Dataset,
+ GtfsAgency,
+ GtfsCalendar,
+ GtfsCalendarDate,
+ GtfsRoute,
+ GtfsShape,
+ GtfsStop,
+ GtfsTrip,
+ Job,
+ OsmFeature,
+ RouteMatch,
+ RoutePattern,
+ RoutePatternStop,
+ Source,
+ SourceCatalogEntry,
+)
+from app.osm_storage import missing_sidecar_paths as missing_osm_sidecar_paths
+from app.pipeline.osm_addresses import ADDRESS_INDEX_VERSION
+from app.pipeline.routing_layer import active_routing_dataset
+
+
def qa_summary(session: Session) -> dict[str, Any]:
    """Build the display-ready QA summary payload.

    Gathers counts for active GTFS/OSM datasets, the source catalog, job
    status, missing sidecars, GTFS validation, stop links, route matching and
    the address index, then arranges them into titled sections of items. Each
    item carries a label, a value, a tone (``info``/``good``/``warn``/``bad``)
    and a description. The payload also includes a generation timestamp, a
    static ``decision`` block and static ``next_actions``.
    """

    def active_datasets(kind: str):
        return session.scalars(
            select(Dataset).where(Dataset.kind == kind, Dataset.is_active.is_(True)).order_by(Dataset.id)
        ).all()

    def tone(value: object, when_truthy: str, when_falsy: str) -> str:
        return when_truthy if value else when_falsy

    gtfs_datasets = active_datasets("gtfs")
    osm_datasets = active_datasets("osm_geojson")
    gtfs_ids = [int(entry.id) for entry in gtfs_datasets]
    osm_ids = [int(entry.id) for entry in osm_datasets]

    catalog_total = _count(session, SourceCatalogEntry)
    source_total = _count(session, Source)
    linked_query = select(func.count(func.distinct(Source.catalog_entry_id))).where(
        Source.catalog_entry_id.is_not(None)
    )
    catalog_linked = int(session.scalar(linked_query) or 0)
    backlog = _priority_catalog_backlog(session)
    failed_query = (
        select(func.count())
        .select_from(Source)
        .where((Source.last_error.is_not(None)) | Source.status.in_(["failed", "error"]))
    )
    sources_failed = int(session.scalar(failed_query) or 0)

    jobs_by_status = _job_status_counts(session)
    gtfs_sidecars_missing = len([entry for entry in gtfs_datasets if missing_gtfs_sidecar_paths(entry)])
    osm_sidecars_missing = len([entry for entry in osm_datasets if missing_osm_sidecar_paths(entry)])

    gtfs = _gtfs_validation_counts(session, gtfs_ids)
    links = _link_quality_counts(session, gtfs_ids, osm_ids)
    routes = _route_quality_counts(session, gtfs_ids)
    address_index = _lightweight_address_index_status(session)
    license_query = (
        select(func.count())
        .select_from(Source)
        .where(Source.kind == "gtfs", (Source.license.is_(None)) | (func.lower(Source.license).in_(["", "unknown"])))
    )
    licenses_unknown = int(session.scalar(license_query) or 0)

    running_jobs = jobs_by_status.get("running", 0)
    queued_jobs = jobs_by_status.get("queued", 0)
    index_stale = address_index.get("stale")

    source_discovery = {
        "id": "source_discovery",
        "title": "Source Discovery",
        "items": [
            _item("Identified sources", catalog_total, "info", "Rows in the source catalog."),
            _item("Registered sources", source_total, "info", "Sources known to the importer."),
            _item("Catalog entries linked", catalog_linked, tone(catalog_linked, "good", "warn"), "Catalog rows connected to importer sources."),
            _item("Priority catalog backlog", backlog, tone(backlog, "warn", "good"), "P0/P1 catalog rows without a registered source."),
        ],
    }
    import_health = {
        "id": "import_health",
        "title": "Import Health",
        "items": [
            _item("Active GTFS datasets", len(gtfs_ids), tone(gtfs_ids, "good", "warn"), "Feeds currently participating in harmonization."),
            _item("Active OSM datasets", len(osm_ids), tone(osm_ids, "good", "warn"), "Visual/spatial datasets currently active."),
            _item("Running jobs", running_jobs, tone(running_jobs, "warn", "info"), "Currently running queued work."),
            _item("Queued jobs", queued_jobs, "info", "Outstanding queued work."),
            _item("Failed sources", sources_failed, tone(sources_failed, "bad", "good"), "Sources with failed status or last_error."),
            _item("Missing GTFS sidecars", gtfs_sidecars_missing, tone(gtfs_sidecars_missing, "bad", "good"), "Active GTFS datasets whose sidecar is unavailable."),
            _item("Missing OSM sidecars", osm_sidecars_missing, tone(osm_sidecars_missing, "bad", "good"), "Active OSM datasets whose sidecar is unavailable."),
        ],
    }
    gtfs_validation = {
        "id": "gtfs_validation",
        "title": "GTFS Validation",
        "items": [
            _item("Agencies", gtfs["agencies"], "info", "Imported agency.txt rows."),
            _item("Stops", gtfs["stops"], "info", "Imported stops."),
            _item("Routes", gtfs["routes"], "info", "Imported routes."),
            _item("Trips", gtfs["trips"], "info", "Imported trips."),
            _item("Shapes", gtfs["shapes"], "info", "Imported shape records."),
            _item("Stops without coordinates", gtfs["stops_without_coordinates"], tone(gtfs["stops_without_coordinates"], "bad", "good"), "Stops that cannot be spatially linked or routed."),
            _item("Routes without geometry", gtfs["routes_without_geometry"], tone(gtfs["routes_without_geometry"], "warn", "good"), "Routes with no stored GTFS shape geometry."),
            _item("Routes without agency", gtfs["routes_without_agency"], tone(gtfs["routes_without_agency"], "warn", "good"), "Routes missing agency/operator references."),
            _item("Calendar range", gtfs["calendar_range"], "info", "Min/max imported service dates from calendars and exceptions."),
        ],
    }
    deduplication = {
        "id": "deduplication",
        "title": "Deduplication and Stop Links",
        "items": [
            _item("Canonical stops", links["canonical_stops"], "info", "Current normalized stop/station records."),
            _item("GTFS stop links", links["gtfs_stop_links"], tone(links["gtfs_stop_links"], "good", "warn"), "Timetable stops linked into canonical stops."),
            _item("GTFS stops without canonical link", links["gtfs_stops_without_canonical"], tone(links["gtfs_stops_without_canonical"], "bad", "good"), "Imported active stops that still need deduplication/linking."),
            _item("OSM visual stop links", links["osm_stop_links"], tone(links["osm_stop_links"], "good", "warn"), "OSM stop/station features linked to canonical stops."),
            _item("OSM stops without canonical link", links["osm_stops_without_canonical"], tone(links["osm_stops_without_canonical"], "warn", "good"), "Visual stops that are not yet linked to GTFS/canonical stops."),
            _item("Multi-source stop groups", links["multi_source_stop_groups"], "info", "Canonical stops that merge GTFS stops from multiple datasets."),
            _item("Long-distance OSM links", links["long_distance_osm_links"], tone(links["long_distance_osm_links"], "warn", "good"), "OSM stop links over 150m from the canonical stop."),
        ],
    }
    route_quality = {
        "id": "route_quality",
        "title": "Route Matching and Geometry",
        "items": [
            _item("Matched/accepted routes", routes["matched_or_accepted"], tone(routes["matched_or_accepted"], "good", "warn"), "GTFS routes with accepted or automatic OSM matches."),
            _item("Probable matches", routes["probable"], tone(routes["probable"], "warn", "info"), "Potential conflicts needing review."),
            _item("Weak matches", routes["weak"], tone(routes["weak"], "warn", "good"), "Low-confidence route links."),
            _item("Missing route matches", routes["missing"], tone(routes["missing"], "bad", "good"), "Routes with no visual match."),
            _item("Unreviewed GTFS routes", routes["routes_without_match"], tone(routes["routes_without_match"], "warn", "good"), "Active GTFS routes without a RouteMatch row."),
            _item("Route patterns", routes["route_patterns"], "info", "Published visual route-layer patterns."),
            _item("Route patterns without stops", routes["route_patterns_without_stops"], tone(routes["route_patterns_without_stops"], "warn", "good"), "Visual patterns missing canonical stop sequence evidence."),
        ],
    }
    publication_readiness = {
        "id": "publication_readiness",
        "title": "Publication Readiness",
        "items": [
            _item("Address index stale", tone(index_stale, "yes", "no"), tone(index_stale, "warn", "good"), "Address polygons/search index version status."),
            _item("GTFS licenses unknown", licenses_unknown, tone(licenses_unknown, "warn", "good"), "GTFS sources without explicit redistribution/license status."),
            _item("Canonical export", "draft", "warn", "Canonical Europe dataset export tables/API are not versioned yet."),
            _item("Third-party API", "later", "info", "Accounts, billing, quotas, and API backend are intentionally out of scope for this step."),
        ],
    }

    return {
        "generated_at": datetime.now(timezone.utc).isoformat(),
        "decision": {
            "deployment": "same_workbench_for_now",
            "database": "same_postgresql_database_for_now",
            "split_trigger": "Split when third-party API, accounts/billing, heavy export jobs, or independent scaling are needed.",
            "api_contract": "/api/qa/summary is intentionally display-ready but stable enough to become a harmonization-service summary endpoint.",
        },
        "sections": [
            source_discovery,
            import_health,
            gtfs_validation,
            deduplication,
            route_quality,
            publication_readiness,
        ],
        "next_actions": [
            "Add review queues for each non-zero bad/warn metric.",
            "Persist source authority and redistribution policy before publishing third-party exports.",
            "Create versioned canonical snapshots and export manifests.",
        ],
    }
+
+
+def _item(label: str, value: object, tone: str, description: str) -> dict[str, object]:
+ return {"label": label, "value": value, "tone": tone, "description": description}
+
+
def _lightweight_address_index_status(session: Session) -> dict[str, object]:
    """Report whether the active routing dataset's address index is stale.

    Only the dataset's ``metadata_json`` is inspected. The index is stale when
    the metadata holds a non-empty ``address_index`` object whose ``version``
    differs from ``ADDRESS_INDEX_VERSION``. A missing dataset, missing or
    unparseable metadata, or an absent index entry all report ``stale`` as false.
    """
    dataset = active_routing_dataset(session)
    raw_metadata = None if dataset is None else dataset.metadata_json
    if not raw_metadata:
        return {"stale": False, "version": None, "current_version": ADDRESS_INDEX_VERSION}

    try:
        parsed = json.loads(raw_metadata)
    except json.JSONDecodeError:
        parsed = {}

    index_entry = parsed.get("address_index") if isinstance(parsed, dict) else {}
    if not isinstance(index_entry, dict):
        index_entry = {}
    recorded_version = index_entry.get("version")
    is_stale = bool(index_entry and recorded_version != ADDRESS_INDEX_VERSION)
    return {
        "stale": is_stale,
        "version": recorded_version,
        "current_version": ADDRESS_INDEX_VERSION,
    }
+
+
def _count(session: Session, model, *where) -> int:
    """Return ``COUNT(*)`` over ``model``, restricted by the optional ``where`` criteria."""
    query = select(func.count()).select_from(model)
    restricted = query.where(*where) if where else query
    return int(session.scalar(restricted) or 0)
+
+
def _priority_catalog_backlog(session: Session) -> int:
    """Count catalog entries with priority P0, "P0 fallback" or P1 that no ``Source`` row references."""
    has_registered_source = (
        select(Source.id).where(Source.catalog_entry_id == SourceCatalogEntry.id).exists()
    )
    backlog_query = (
        select(func.count())
        .select_from(SourceCatalogEntry)
        .where(SourceCatalogEntry.priority.in_(["P0", "P0 fallback", "P1"]), ~has_registered_source)
    )
    backlog = session.scalar(backlog_query)
    return int(backlog or 0)
+
+
def _job_status_counts(session: Session) -> dict[str, int]:
    """Return counts of non-dismissed jobs per status.

    Only the statuses queued, running, paused and failed are counted; a status
    with no jobs is absent from the result.
    """
    grouped = (
        select(Job.status, func.count())
        .where(Job.dismissed_at.is_(None), Job.status.in_(["queued", "running", "paused", "failed"]))
        .group_by(Job.status)
    )
    counts: dict[str, int] = {}
    for job_status, total in session.execute(grouped).all():
        counts[str(job_status)] = int(total)
    return counts
+
+
def _gtfs_validation_counts(session: Session, dataset_ids: list[int]) -> dict[str, object]:
    """Collect GTFS row counts and basic validation counts for the given datasets.

    Args:
        session: Database session used for all queries.
        dataset_ids: Ids of the GTFS datasets to include.

    Returns:
        A dict with row counts (agencies, stops, routes, trips, shapes),
        problem counts (stops without coordinates, routes without geometry,
        routes without agency) and ``calendar_range``, a ``"min -> max"``
        string over calendar start/end dates and calendar exception dates.
        With no dataset ids every count is 0 and ``calendar_range`` is ``"none"``.
    """
    if not dataset_ids:
        return {
            "agencies": 0,
            "stops": 0,
            "routes": 0,
            "trips": 0,
            "shapes": 0,
            "stops_without_coordinates": 0,
            "routes_without_geometry": 0,
            "routes_without_agency": 0,
            "calendar_range": "none",
        }
    calendar_min, calendar_max = session.execute(
        select(func.min(GtfsCalendar.start_date), func.max(GtfsCalendar.end_date)).where(GtfsCalendar.dataset_id.in_(dataset_ids))
    ).one()
    exception_min, exception_max = session.execute(
        select(func.min(GtfsCalendarDate.date), func.max(GtfsCalendarDate.date)).where(GtfsCalendarDate.dataset_id.in_(dataset_ids))
    ).one()
    # Skip both NULL and empty values. Previously the guard tested truthiness
    # while the filter tested ``is not None``, so an empty-string minimum from
    # one table could win ``min()`` and hide a real date from the other.
    min_date = min((value for value in (calendar_min, exception_min) if value), default=None)
    max_date = max((value for value in (calendar_max, exception_max) if value), default=None)
    return {
        "agencies": _count(session, GtfsAgency, GtfsAgency.dataset_id.in_(dataset_ids)),
        "stops": _count(session, GtfsStop, GtfsStop.dataset_id.in_(dataset_ids)),
        "routes": _count(session, GtfsRoute, GtfsRoute.dataset_id.in_(dataset_ids)),
        "trips": _count(session, GtfsTrip, GtfsTrip.dataset_id.in_(dataset_ids)),
        "shapes": _count(session, GtfsShape, GtfsShape.dataset_id.in_(dataset_ids)),
        "stops_without_coordinates": _count(
            session,
            GtfsStop,
            GtfsStop.dataset_id.in_(dataset_ids),
            (GtfsStop.lat.is_(None)) | (GtfsStop.lon.is_(None)),
        ),
        "routes_without_geometry": _count(
            session,
            GtfsRoute,
            GtfsRoute.dataset_id.in_(dataset_ids),
            (GtfsRoute.geometry_geojson.is_(None)) | (GtfsRoute.geometry_geojson == ""),
        ),
        "routes_without_agency": _count(
            session,
            GtfsRoute,
            GtfsRoute.dataset_id.in_(dataset_ids),
            (GtfsRoute.agency_id.is_(None)) | (GtfsRoute.agency_id == ""),
        ),
        "calendar_range": f"{min_date or 'unknown'} -> {max_date or 'unknown'}",
    }
+
+
def _link_quality_counts(session: Session, gtfs_dataset_ids: list[int], osm_dataset_ids: list[int]) -> dict[str, int]:
    """Count canonical-stop link coverage for the given GTFS and OSM datasets.

    GTFS metrics (links, stops without a link, canonical stops linked from more
    than one dataset) are computed only when ``gtfs_dataset_ids`` is non-empty.
    OSM metrics (links, stop/station/terminal features without a link, links
    with ``distance_m`` over 150) are computed only when ``osm_dataset_ids`` is
    non-empty. Skipped metrics are 0. ``canonical_stops`` is always the total
    number of canonical stops.
    """
    unlinked_gtfs_stops = 0
    gtfs_links = 0
    multi_source_groups = 0
    if gtfs_dataset_ids:
        is_gtfs_link = CanonicalStopLink.object_type == "gtfs_stop"
        stop_has_link = (
            select(CanonicalStopLink.id)
            .where(
                CanonicalStopLink.object_type == "gtfs_stop",
                CanonicalStopLink.dataset_id == GtfsStop.dataset_id,
                CanonicalStopLink.object_id == GtfsStop.id,
            )
            .exists()
        )
        unlinked_gtfs_stops = _count(
            session,
            GtfsStop,
            GtfsStop.dataset_id.in_(gtfs_dataset_ids),
            ~stop_has_link,
        )
        gtfs_links = _count(
            session,
            CanonicalStopLink,
            is_gtfs_link,
            CanonicalStopLink.dataset_id.in_(gtfs_dataset_ids),
        )
        # Canonical stops whose GTFS links come from more than one dataset.
        shared_canonical_stops = (
            select(CanonicalStopLink.canonical_stop_id)
            .where(CanonicalStopLink.object_type == "gtfs_stop", CanonicalStopLink.dataset_id.in_(gtfs_dataset_ids))
            .group_by(CanonicalStopLink.canonical_stop_id)
            .having(func.count(func.distinct(CanonicalStopLink.dataset_id)) > 1)
            .subquery()
        )
        multi_source_groups = int(session.scalar(select(func.count()).select_from(shared_canonical_stops)) or 0)

    unlinked_osm_stops = 0
    osm_links = 0
    distant_osm_links = 0
    if osm_dataset_ids:
        is_osm_link = CanonicalStopLink.object_type == "osm_feature"
        feature_has_link = (
            select(CanonicalStopLink.id)
            .where(
                CanonicalStopLink.object_type == "osm_feature",
                CanonicalStopLink.dataset_id == OsmFeature.dataset_id,
                CanonicalStopLink.object_id == OsmFeature.id,
            )
            .exists()
        )
        unlinked_osm_stops = _count(
            session,
            OsmFeature,
            OsmFeature.dataset_id.in_(osm_dataset_ids),
            OsmFeature.kind.in_(["stop", "station", "terminal"]),
            ~feature_has_link,
        )
        osm_links = _count(
            session,
            CanonicalStopLink,
            is_osm_link,
            CanonicalStopLink.dataset_id.in_(osm_dataset_ids),
        )
        distant_osm_links = _count(
            session,
            CanonicalStopLink,
            is_osm_link,
            CanonicalStopLink.dataset_id.in_(osm_dataset_ids),
            CanonicalStopLink.distance_m > 150,
        )

    canonical_total = _count(session, CanonicalStop)
    return {
        "canonical_stops": canonical_total,
        "gtfs_stop_links": gtfs_links,
        "gtfs_stops_without_canonical": unlinked_gtfs_stops,
        "osm_stop_links": osm_links,
        "osm_stops_without_canonical": unlinked_osm_stops,
        "multi_source_stop_groups": multi_source_groups,
        "long_distance_osm_links": distant_osm_links,
    }
+
+
def _route_quality_counts(session: Session, gtfs_dataset_ids: list[int]) -> dict[str, int]:
    """Count route-match outcomes and route-pattern coverage.

    Route-pattern totals are computed over all patterns. Match counts are
    grouped by ``RouteMatch.status`` for routes in ``gtfs_dataset_ids``, and
    ``routes_without_match`` counts routes in those datasets with no
    ``RouteMatch`` row. With no dataset ids, all match-related counts are 0.
    """
    pattern_total = _count(session, RoutePattern)
    pattern_has_stop = (
        select(RoutePatternStop.id)
        .where(RoutePatternStop.route_pattern_id == RoutePattern.id)
        .exists()
    )
    patterns_missing_stops = _count(session, RoutePattern, ~pattern_has_stop)

    by_status: dict[str, int] = {}
    unmatched_routes = 0
    if gtfs_dataset_ids:
        status_query = (
            select(RouteMatch.status, func.count())
            .join(GtfsRoute, GtfsRoute.id == RouteMatch.gtfs_route_id)
            .where(GtfsRoute.dataset_id.in_(gtfs_dataset_ids))
            .group_by(RouteMatch.status)
        )
        for match_status, total in session.execute(status_query).all():
            by_status[str(match_status)] = int(total)
        route_has_match = select(RouteMatch.id).where(RouteMatch.gtfs_route_id == GtfsRoute.id).exists()
        unmatched_routes = _count(session, GtfsRoute, GtfsRoute.dataset_id.in_(gtfs_dataset_ids), ~route_has_match)

    return {
        "matched_or_accepted": by_status.get("matched", 0) + by_status.get("accepted", 0),
        "probable": by_status.get("probable", 0),
        "weak": by_status.get("weak", 0),
        "missing": by_status.get("missing", 0),
        "routes_without_match": unmatched_routes,
        "route_patterns": pattern_total,
        "route_patterns_without_stops": patterns_missing_stops,
    }
diff --git a/app/routing.py b/app/routing.py
new file mode 100644
index 0000000..331a9da
--- /dev/null
+++ b/app/routing.py
@@ -0,0 +1,911 @@
+from __future__ import annotations
+
+import copy
+import heapq
+import json
+import math
+import threading
+import time
+from collections import OrderedDict
+from dataclasses import dataclass
+
+from sqlalchemy import func, select, text
+from sqlalchemy.exc import SQLAlchemyError
+from sqlalchemy.orm import Session
+
+from app.config import settings
+from app.models import Dataset, RoutingEdge, RoutingNode
+from app.pipeline.routing_layer import active_routing_dataset
+from app.serializers import feature_collection
+
+
+WALK_HEURISTIC_MPS = 1.6
+DRIVE_HEURISTIC_MPS = 36.0
+DEFAULT_MAX_VISITED = 160_000
+PGR_WALK_BBOX_PADDING_KM = [0.5, 1.5, 4, 10, 25]
+PGR_DRIVE_BBOX_PADDING_KM = [2, 8, 25, 75, 200]
+PGR_WALK_STATEMENT_TIMEOUT_MS = 2_500
+PGR_DRIVE_STATEMENT_TIMEOUT_MS = 7_500
+ROUTE_CACHE_TTL_SECONDS = 15 * 60
+ROUTE_CACHE_MAX_ENTRIES = 512
+_route_cache_lock = threading.RLock()
+_route_cache: OrderedDict[tuple[object, ...], tuple[float, dict[str, object]]] = OrderedDict()
+
+
+@dataclass(frozen=True)
+class _GraphNode:
+ osm_node_id: int
+ lon: float
+ lat: float
+ distance_m: float
+
+
+@dataclass(frozen=True)
+class _Traversal:
+ edge_id: int
+ from_node: int
+ to_node: int
+ from_lon: float
+ from_lat: float
+ to_lon: float
+ to_lat: float
+ cost_s: float
+ length_m: float
+ highway: str | None
+ name: str | None
+ geometry_geojson: str
+ reversed: bool
+
+
+def routing_status(db: Session) -> dict[str, object]:
+    """Describe the active routing dataset and the engine that would serve route requests.
+
+    Returns the dataset id, node/edge counts, an availability flag and the pgRouting extension flags.
+    """
+    dataset = active_routing_dataset(db)
+    dataset_id = None if dataset is None else int(dataset.id)
+    node_count = edge_count = 0
+    if dataset_id is not None:
+        node_count, edge_count = _routing_status_counts(db, dataset, dataset_id)
+    pgrouting_available = pgrouting_installed = False
+    if settings.is_postgresql_database:  # the pg_* extension catalogs only exist on PostgreSQL
+        pgrouting_available = bool(
+            db.execute(text("SELECT EXISTS (SELECT 1 FROM pg_available_extensions WHERE name = 'pgrouting')")).scalar()
+        )
+        pgrouting_installed = _pgrouting_installed(db)  # same probe route_between_points() relies on
+    return {
+        "dataset_id": dataset_id,
+        "nodes": node_count,
+        "edges": edge_count,
+        "available": edge_count > 0,
+        "engine": "pgrouting" if pgrouting_installed else "python_astar",
+        "pgrouting_available": pgrouting_available,
+        "pgrouting_installed": pgrouting_installed,
+    }
+
+
+def _routing_status_counts(db: Session, dataset: Dataset, dataset_id: int) -> tuple[int, int]:
+    """Return (nodes, edges): stored metadata first, then table-wide PostgreSQL estimates, then exact per-dataset counts."""
+    routing_layer = _metadata(dataset).get("routing_layer")
+    if isinstance(routing_layer, dict):
+        try:
+            nodes = int(routing_layer.get("nodes") or 0)
+            edges = int(routing_layer.get("edges") or 0)
+        except (TypeError, ValueError):
+            nodes = edges = 0
+        if nodes or edges:
+            return nodes, edges
+    if settings.is_postgresql_database:
+        rows = db.execute(
+            text(
+                """
+                SELECT relname, GREATEST(COALESCE(reltuples, 0), 0)::bigint AS estimate
+                FROM pg_class
+                WHERE oid IN ('routing_nodes'::regclass, 'routing_edges'::regclass)
+                """
+            )
+        ).mappings()
+        estimates = {str(row["relname"]): int(row["estimate"] or 0) for row in rows}
+        if any(estimates.values()):  # reltuples is -1 (clamped to 0 above) until a table is analyzed: then count exactly below
+            return estimates.get("routing_nodes", 0), estimates.get("routing_edges", 0)
+    node_count = int(db.scalar(select(func.count()).select_from(RoutingNode).where(RoutingNode.dataset_id == dataset_id)) or 0)
+    edge_count = int(db.scalar(select(func.count()).select_from(RoutingEdge).where(RoutingEdge.dataset_id == dataset_id)) or 0)
+    return node_count, edge_count
+
+
+def _metadata(dataset: Dataset) -> dict[str, object]:
+    """Decode dataset.metadata_json, returning {} when it is empty, invalid JSON, or not a JSON object."""
+    raw = dataset.metadata_json
+    try:
+        parsed = json.loads(raw) if raw else None
+    except json.JSONDecodeError:
+        return {}
+    return parsed if isinstance(parsed, dict) else {}
+
+
+def route_between_points(
+    db: Session,
+    *,
+    from_lon: float,
+    from_lat: float,
+    to_lon: float,
+    to_lat: float,
+    mode: str = "walk",
+    dataset_id: int | None = None,
+    max_visited: int = DEFAULT_MAX_VISITED,
+) -> dict[str, object]:
+    """Route between two lon/lat points using the route cache, then pgRouting when installed, else in-process A*.
+
+    Raises ValueError for an unknown mode, a missing dataset, no nearby graph nodes, or when no route is found.
+    """
+    if mode not in {"walk", "drive"}:
+        raise ValueError("mode must be walk or drive")
+    dataset = db.get(Dataset, dataset_id) if dataset_id is not None else active_routing_dataset(db)
+    if dataset is None:
+        raise ValueError("No routing dataset is available.")
+    dataset_id = int(dataset.id)
+    cache_key = _route_cache_key(dataset_id, mode, from_lon, from_lat, to_lon, to_lat)
+    cached = _route_cache_get(cache_key)
+    if cached is not None:
+        return cached
+    start = _nearest_node(db, dataset_id, from_lon, from_lat, mode)
+    target = _nearest_node(db, dataset_id, to_lon, to_lat, mode)
+    if start is None or target is None:
+        raise ValueError("Routing graph has no nearby nodes for the requested mode.")
+    if start.osm_node_id == target.osm_node_id:
+        payload = _single_point_route(start, from_lon, from_lat, to_lon, to_lat, mode, dataset_id)
+        _route_cache_put(cache_key, payload)
+        return payload
+    if settings.is_postgresql_database and _pgrouting_installed(db):
+        try:
+            payload = _route_with_pgrouting(
+                db,
+                dataset_id=dataset_id,
+                mode=mode,
+                start=start,
+                target=target,
+                from_lon=from_lon,
+                from_lat=from_lat,
+                to_lon=to_lon,
+                to_lat=to_lat,
+            )
+            _route_cache_put(cache_key, payload)
+            return payload
+        except ValueError:
+            pass  # pgRouting found nothing inside its bounded search area; fall back to the A* search below
+        except SQLAlchemyError:
+            db.rollback()
+
+    heuristic_mps = WALK_HEURISTIC_MPS if mode == "walk" else DRIVE_HEURISTIC_MPS
+    queue: list[tuple[float, float, int]] = [(0.0, 0.0, start.osm_node_id)]  # (priority, cost so far, node); one item is a valid heap
+    costs: dict[int, float] = {start.osm_node_id: 0.0}
+    previous: dict[int, tuple[int, _Traversal]] = {}
+    visited: set[int] = set()
+
+    while queue and len(visited) < max(1, max_visited):
+        _, cost, node_id = heapq.heappop(queue)
+        if node_id in visited:
+            continue
+        visited.add(node_id)
+        if node_id == target.osm_node_id:
+            payload = _route_payload(
+                dataset_id=dataset_id,
+                mode=mode,
+                start=start,
+                target=target,
+                from_lon=from_lon,
+                from_lat=from_lat,
+                to_lon=to_lon,
+                to_lat=to_lat,
+                previous=previous,
+                total_cost_s=cost,
+                visited=len(visited),
+            )
+            _route_cache_put(cache_key, payload)
+            return payload
+        for edge in _outgoing_edges(db, dataset_id, node_id, mode):  # each node is expanded once, so no adjacency cache is needed
+            next_cost = cost + edge.cost_s
+            if next_cost >= costs.get(edge.to_node, float("inf")):
+                continue
+            costs[edge.to_node] = next_cost
+            previous[edge.to_node] = (node_id, edge)
+            heuristic = _distance_m(edge.to_lat, edge.to_lon, target.lat, target.lon) / heuristic_mps
+            heapq.heappush(queue, (next_cost + heuristic, next_cost, edge.to_node))
+
+    raise ValueError(f"No {mode} route found within {max_visited:,} visited graph nodes.")
+
+
+def direct_route_between_points(
+ db: Session,
+ *,
+ from_lon: float,
+ from_lat: float,
+ to_lon: float,
+ to_lat: float,
+ mode: str = "walk",
+ dataset_id: int | None = None,
+ reason: str | None = None,
+) -> dict[str, object]:
+ if mode not in {"walk", "drive"}:
+ raise ValueError("mode must be walk or drive")
+ dataset = db.get(Dataset, dataset_id) if dataset_id is not None else active_routing_dataset(db)
+ payload = _direct_route_payload(
+ dataset_id=0 if dataset is None else int(dataset.id),
+ mode=mode,
+ from_lon=float(from_lon),
+ from_lat=float(from_lat),
+ to_lon=float(to_lon),
+ to_lat=float(to_lat),
+ )
+ if reason:
+ payload["warning"] = reason
+ return payload
+
+
+def snap_point_to_routing_graph(
+ db: Session,
+ *,
+ lon: float,
+ lat: float,
+ mode: str = "walk",
+ dataset_id: int | None = None,
+ max_distance_m: float = 250,
+) -> dict[str, object] | None:
+ if mode not in {"walk", "drive"}:
+ raise ValueError("mode must be walk or drive")
+ dataset = db.get(Dataset, dataset_id) if dataset_id is not None else active_routing_dataset(db)
+ if dataset is None:
+ return None
+ dataset_id = int(dataset.id)
+ if settings.is_postgresql_database:
+ return _snap_point_to_routing_edge_postgresql(
+ db,
+ dataset_id=dataset_id,
+ lon=float(lon),
+ lat=float(lat),
+ mode=mode,
+ max_distance_m=float(max_distance_m),
+ )
+ node = _nearest_node(db, dataset_id, float(lon), float(lat), mode)
+ if node is None or node.distance_m > max_distance_m:
+ return None
+ return {
+ "dataset_id": dataset_id,
+ "lon": node.lon,
+ "lat": node.lat,
+ "distance_m": round(node.distance_m, 1),
+ "source": "routing_node",
+ "osm_node_id": node.osm_node_id,
+ }
+
+
+def _snap_point_to_routing_edge_postgresql(
+    db: Session,
+    *,
+    dataset_id: int,
+    lon: float,
+    lat: float,
+    mode: str,
+    max_distance_m: float,
+) -> dict[str, object] | None:
+    """Snap a point to the best usable routing edge within max_distance_m (PostGIS query); None when nothing qualifies."""
+    cost_column, reverse_cost_column = ("walk_cost_s", "reverse_walk_cost_s") if mode == "walk" else ("drive_cost_s", "reverse_drive_cost_s")
+    radius_deg = max_distance_m / (111_320 * max(0.01, math.cos(math.radians(lat))))  # bbox prefilter: a longitude degree shrinks by cos(lat), so size the window for that wider case
+    row = db.execute(
+        text(
+            f"""
+            WITH point AS (
+                SELECT ST_SetSRID(ST_MakePoint(:lon, :lat), 4326) AS geom
+            ),
+            edges AS MATERIALIZED (
+                SELECT
+                    edge.id,
+                    edge.highway,
+                    edge.name,
+                    CASE
+                        WHEN edge.tags_json IS NULL OR edge.tags_json = '' THEN NULL
+                        ELSE edge.tags_json::jsonb ->> 'service'
+                    END AS service,
+                    edge.source_osm_node_id,
+                    edge.target_osm_node_id,
+                    ST_SetSRID(
+                        ST_MakeLine(
+                            ST_MakePoint(edge.source_lon, edge.source_lat),
+                            ST_MakePoint(edge.target_lon, edge.target_lat)
+                        ),
+                        4326
+                    ) AS edge_geom
+                FROM routing_edges AS edge
+                CROSS JOIN point
+                WHERE edge.dataset_id = :dataset_id
+                  AND (edge.{cost_column} IS NOT NULL OR edge.{reverse_cost_column} IS NOT NULL)
+                  AND box(point(edge.max_lon, edge.max_lat), point(edge.min_lon, edge.min_lat))
+                      && box(
+                          point(:lon + :radius_deg, :lat + :radius_deg),
+                          point(:lon - :radius_deg, :lat - :radius_deg)
+                      )
+            ),
+            candidate AS (
+                SELECT
+                    edges.id,
+                    edges.highway,
+                    edges.name,
+                    edges.service,
+                    edges.source_osm_node_id,
+                    edges.target_osm_node_id,
+                    ST_ClosestPoint(edges.edge_geom, point.geom) AS snapped_geom,
+                    ST_DistanceSphere(edges.edge_geom, point.geom) AS distance_m,
+                    CASE
+                        WHEN edges.highway IN ('footway', 'pedestrian', 'steps') THEN 0
+                        WHEN edges.highway IN ('path', 'cycleway', 'bridleway') THEN 1
+                        WHEN edges.highway IN ('living_street', 'residential') THEN 2
+                        WHEN edges.highway = 'service' THEN 3
+                        ELSE 4
+                    END AS highway_rank,
+                    CASE
+                        WHEN :mode != 'walk' THEN 0
+                        WHEN edges.highway = 'service' THEN 20
+                        WHEN edges.highway IN ('primary', 'primary_link', 'secondary', 'secondary_link') THEN 10
+                        WHEN edges.highway IN ('tertiary', 'tertiary_link', 'unclassified', 'road') THEN 5
+                        ELSE 0
+                    END AS snap_penalty_m
+                FROM edges
+                CROSS JOIN point
+                WHERE ST_DWithin(edges.edge_geom::geography, point.geom::geography, :max_distance_m)
+                  AND NOT (
+                      :mode = 'walk'
+                      AND edges.highway = 'service'
+                      AND COALESCE(edges.service, '') IN ('driveway', 'parking_aisle', 'drive-through')
+                  )
+                ORDER BY
+                    ST_DistanceSphere(edges.edge_geom, point.geom) + CASE
+                        WHEN :mode != 'walk' THEN 0
+                        WHEN edges.highway = 'service' THEN 20
+                        WHEN edges.highway IN ('primary', 'primary_link', 'secondary', 'secondary_link') THEN 10
+                        WHEN edges.highway IN ('tertiary', 'tertiary_link', 'unclassified', 'road') THEN 5
+                        ELSE 0
+                    END,
+                    ST_DistanceSphere(edges.edge_geom, point.geom),
+                    highway_rank,
+                    edges.id
+                LIMIT 1
+            )
+            SELECT
+                id,
+                highway,
+                name,
+                source_osm_node_id,
+                target_osm_node_id,
+                ST_X(snapped_geom) AS lon,
+                ST_Y(snapped_geom) AS lat,
+                distance_m
+            FROM candidate
+            """
+        ),
+        {
+            "dataset_id": dataset_id,
+            "lon": lon,
+            "lat": lat,
+            "radius_deg": radius_deg,
+            "max_distance_m": max_distance_m,
+            "mode": mode,
+        },
+    ).mappings().first()
+    if row is None:
+        return None
+    return {
+        "dataset_id": dataset_id,
+        "lon": float(row["lon"]),
+        "lat": float(row["lat"]),
+        "distance_m": round(float(row["distance_m"] or 0), 1),
+        "source": "routing_edge",
+        "edge_id": int(row["id"]),
+        "highway": row["highway"],
+        "name": row["name"],
+        "source_osm_node_id": int(row["source_osm_node_id"]),
+        "target_osm_node_id": int(row["target_osm_node_id"]),
+    }
+
+
+def _route_cache_key(dataset_id: int, mode: str, from_lon: float, from_lat: float, to_lon: float, to_lat: float) -> tuple[object, ...]:
+    """Build the route-cache key: dataset id, mode and the four coordinates rounded to 6 decimal places."""
+    # Coordinates are coerced to float and rounded so requests that differ only past 6 decimals share a key.
+    endpoints = (from_lon, from_lat, to_lon, to_lat)
+    return (
+        int(dataset_id),
+        mode,
+        *(round(float(value), 6) for value in endpoints),
+    )
+
+
+def _route_cache_get(key: tuple[object, ...]) -> dict[str, object] | None:
+ now = time.monotonic()
+ with _route_cache_lock:
+ cached = _route_cache.get(key)
+ if cached is None:
+ return None
+ expires_at, payload = cached
+ if expires_at <= now:
+ _route_cache.pop(key, None)
+ return None
+ _route_cache.move_to_end(key)
+ return copy.deepcopy(payload)
+
+
+def _route_cache_put(key: tuple[object, ...], payload: dict[str, object]) -> None:
+ with _route_cache_lock:
+ _route_cache[key] = (time.monotonic() + ROUTE_CACHE_TTL_SECONDS, copy.deepcopy(payload))
+ _route_cache.move_to_end(key)
+ while len(_route_cache) > ROUTE_CACHE_MAX_ENTRIES:
+ _route_cache.popitem(last=False)
+
+
+def _pgrouting_installed(db: Session) -> bool:
+ return bool(db.execute(text("SELECT EXISTS (SELECT 1 FROM pg_extension WHERE extname = 'pgrouting')")).scalar())
+
+
+def _route_with_pgrouting(
+ db: Session,
+ *,
+ dataset_id: int,
+ mode: str,
+ start: _GraphNode,
+ target: _GraphNode,
+ from_lon: float,
+ from_lat: float,
+ to_lon: float,
+ to_lat: float,
+) -> dict[str, object]:
+ cost_column = "walk_cost_s" if mode == "walk" else "drive_cost_s"
+ reverse_cost_column = "reverse_walk_cost_s" if mode == "walk" else "reverse_drive_cost_s"
+ routing_cost = _routing_cost_expression(cost_column, mode)
+ reverse_routing_cost = _routing_cost_expression(reverse_cost_column, mode)
+ for padding_km in PGR_WALK_BBOX_PADDING_KM if mode == "walk" else PGR_DRIVE_BBOX_PADDING_KM:
+ _set_local_statement_timeout(
+ db,
+ PGR_WALK_STATEMENT_TIMEOUT_MS if mode == "walk" else PGR_DRIVE_STATEMENT_TIMEOUT_MS,
+ )
+ bbox = _expanded_bbox(
+ min(from_lon, to_lon, start.lon, target.lon),
+ min(from_lat, to_lat, start.lat, target.lat),
+ max(from_lon, to_lon, start.lon, target.lon),
+ max(from_lat, to_lat, start.lat, target.lat),
+ padding_km,
+ )
+ edge_sql = f"""
+ SELECT
+ id,
+ source_osm_node_id AS source,
+ target_osm_node_id AS target,
+ COALESCE({routing_cost}, -1)::float8 AS cost,
+ COALESCE({reverse_routing_cost}, -1)::float8 AS reverse_cost
+ FROM routing_edges
+ WHERE dataset_id = {int(dataset_id)}
+ AND ({cost_column} IS NOT NULL OR {reverse_cost_column} IS NOT NULL)
+ AND box(point(max_lon, max_lat), point(min_lon, min_lat))
+ && box(point({bbox[2]:.8f}, {bbox[3]:.8f}), point({bbox[0]:.8f}, {bbox[1]:.8f}))
+ """
+ rows = db.execute(
+ text(
+ f"""
+ WITH route AS (
+ SELECT *
+ FROM pgr_dijkstra(:edge_sql, :start_node, :target_node, directed := true)
+ ),
+ steps AS (
+ SELECT
+ route.path_seq,
+ route.node AS from_node,
+ LEAD(route.node) OVER (ORDER BY route.path_seq) AS to_node,
+ route.edge,
+ route.cost
+ FROM route
+ )
+ SELECT
+ steps.path_seq,
+ steps.from_node,
+ steps.to_node,
+ steps.cost,
+ edge.id,
+ edge.source_osm_node_id,
+ edge.target_osm_node_id,
+ edge.source_lon,
+ edge.source_lat,
+ edge.target_lon,
+ edge.target_lat,
+ edge.length_m,
+ edge.highway,
+ edge.name,
+ edge.geometry_geojson,
+ CASE
+ WHEN steps.from_node = edge.source_osm_node_id THEN edge.{cost_column}
+ ELSE edge.{reverse_cost_column}
+ END AS actual_cost_s
+ FROM steps
+ JOIN routing_edges AS edge ON edge.id = steps.edge
+ WHERE steps.edge <> -1
+ ORDER BY steps.path_seq
+ """
+ ),
+ {"edge_sql": edge_sql, "start_node": start.osm_node_id, "target_node": target.osm_node_id},
+ ).all()
+ if rows:
+ return _pgrouting_payload(
+ dataset_id=dataset_id,
+ mode=mode,
+ start=start,
+ target=target,
+ from_lon=from_lon,
+ from_lat=from_lat,
+ to_lon=to_lon,
+ to_lat=to_lat,
+ rows=rows,
+ padding_km=padding_km,
+ )
+ raise ValueError("pgRouting did not find a route in the bounded search area.")
+
+
+def _set_local_statement_timeout(db: Session, timeout_ms: int) -> None:
+ db.execute(text("SELECT set_config('statement_timeout', :timeout, true)"), {"timeout": f"{int(timeout_ms)}ms"})
+
+
+def _pgrouting_payload(
+ *,
+ dataset_id: int,
+ mode: str,
+ start: _GraphNode,
+ target: _GraphNode,
+ from_lon: float,
+ from_lat: float,
+ to_lon: float,
+ to_lat: float,
+ rows,
+ padding_km: float,
+) -> dict[str, object]:
+ previous: dict[int, tuple[int, _Traversal]] = {}
+ total_cost = 0.0
+ for row in rows:
+ if row.to_node is None:
+ continue
+ from_node = int(row.from_node)
+ to_node = int(row.to_node)
+ source_node = int(row.source_osm_node_id)
+ target_node = int(row.target_osm_node_id)
+ actual_cost = float(row.actual_cost_s if row.actual_cost_s is not None else row.cost or 0)
+ reversed_edge = from_node == target_node and to_node == source_node
+ if reversed_edge:
+ from_lon_edge, from_lat_edge = float(row.target_lon), float(row.target_lat)
+ to_lon_edge, to_lat_edge = float(row.source_lon), float(row.source_lat)
+ else:
+ from_lon_edge, from_lat_edge = float(row.source_lon), float(row.source_lat)
+ to_lon_edge, to_lat_edge = float(row.target_lon), float(row.target_lat)
+ total_cost += actual_cost
+ previous[to_node] = (
+ from_node,
+ _Traversal(
+ edge_id=int(row.id),
+ from_node=from_node,
+ to_node=to_node,
+ from_lon=from_lon_edge,
+ from_lat=from_lat_edge,
+ to_lon=to_lon_edge,
+ to_lat=to_lat_edge,
+ cost_s=actual_cost,
+ length_m=float(row.length_m),
+ highway=row.highway,
+ name=row.name,
+ geometry_geojson=str(row.geometry_geojson),
+ reversed=reversed_edge,
+ ),
+ )
+ payload = _route_payload(
+ dataset_id=dataset_id,
+ mode=mode,
+ start=start,
+ target=target,
+ from_lon=from_lon,
+ from_lat=from_lat,
+ to_lon=to_lon,
+ to_lat=to_lat,
+ previous=previous,
+ total_cost_s=total_cost,
+ visited=len(rows),
+ )
+ payload["engine"] = "pgrouting"
+ payload["bbox_padding_km"] = padding_km
+ return payload
+
+
+def _routing_cost_expression(column: str, mode: str) -> str:
+ if mode != "walk":
+ return column
+ return f"""
+ CASE
+ WHEN {column} IS NULL THEN NULL
+ ELSE {column} * CASE
+ WHEN highway IN ('footway', 'pedestrian') THEN 0.70
+ WHEN highway = 'path' THEN 0.78
+ WHEN highway = 'steps' THEN 0.95
+ WHEN highway = 'cycleway' THEN 1.05
+ WHEN highway = 'bridleway' THEN 1.10
+ WHEN highway IN ('living_street', 'track') THEN 1.15
+ WHEN highway IN ('residential', 'service') THEN 1.35
+ WHEN highway IN ('unclassified', 'road') THEN 1.55
+ WHEN highway IN ('tertiary', 'tertiary_link') THEN 1.80
+ WHEN highway IN ('secondary', 'secondary_link') THEN 2.15
+ WHEN highway IN ('primary', 'primary_link') THEN 2.50
+ ELSE 1.30
+ END
+ END
+ """
+
+
+def _nearest_node(db: Session, dataset_id: int, lon: float, lat: float, mode: str) -> _GraphNode | None:
+ cost_column = "walk_cost_s" if mode == "walk" else "drive_cost_s"
+ reverse_cost_column = "reverse_walk_cost_s" if mode == "walk" else "reverse_drive_cost_s"
+ row = None
+ for candidate_limit in (64, 512, 4096):
+ row = db.execute(
+ text(
+ f"""
+ WITH nearest AS MATERIALIZED (
+ SELECT node.osm_node_id, node.lon, node.lat, node.geom
+ FROM routing_nodes AS node
+ WHERE node.dataset_id = :dataset_id
+ AND node.geom IS NOT NULL
+ ORDER BY node.geom <-> ST_SetSRID(ST_MakePoint(:lon, :lat), 4326)
+ LIMIT :candidate_limit
+ ),
+ candidate AS (
+ SELECT nearest.osm_node_id, nearest.lon, nearest.lat, nearest.geom
+ FROM nearest
+ WHERE EXISTS (
+ SELECT 1
+ FROM routing_edges AS edge
+ WHERE edge.dataset_id = :dataset_id
+ AND (
+ (edge.source_osm_node_id = nearest.osm_node_id AND edge.{cost_column} IS NOT NULL)
+ OR (edge.target_osm_node_id = nearest.osm_node_id AND edge.{reverse_cost_column} IS NOT NULL)
+ )
+ LIMIT 1
+ )
+ ORDER BY nearest.geom <-> ST_SetSRID(ST_MakePoint(:lon, :lat), 4326)
+ LIMIT 1
+ )
+ SELECT osm_node_id, lon, lat, ST_DistanceSphere(geom, ST_SetSRID(ST_MakePoint(:lon, :lat), 4326)) AS distance_m
+ FROM candidate
+ """
+ ),
+ {"dataset_id": dataset_id, "lon": lon, "lat": lat, "candidate_limit": candidate_limit},
+ ).first()
+ if row is not None:
+ break
+ if row is None:
+ return None
+ return _GraphNode(osm_node_id=int(row.osm_node_id), lon=float(row.lon), lat=float(row.lat), distance_m=float(row.distance_m or 0))
+
+
+def _outgoing_edges(db: Session, dataset_id: int, node_id: int, mode: str) -> list[_Traversal]:
+ cost_column = "walk_cost_s" if mode == "walk" else "drive_cost_s"
+ reverse_cost_column = "reverse_walk_cost_s" if mode == "walk" else "reverse_drive_cost_s"
+ rows = db.execute(
+ text(
+ f"""
+ SELECT
+ id, source_osm_node_id, target_osm_node_id,
+ source_lon, source_lat, target_lon, target_lat,
+ length_m, highway, name, geometry_geojson,
+ CASE
+ WHEN source_osm_node_id = :node_id THEN {cost_column}
+ ELSE {reverse_cost_column}
+ END AS cost_s,
+ target_osm_node_id != :node_id AS forward
+ FROM routing_edges
+ WHERE dataset_id = :dataset_id
+ AND (
+ (source_osm_node_id = :node_id AND {cost_column} IS NOT NULL)
+ OR (target_osm_node_id = :node_id AND {reverse_cost_column} IS NOT NULL)
+ )
+ """
+ ),
+ {"dataset_id": dataset_id, "node_id": node_id},
+ ).all()
+ edges = []
+ for row in rows:
+ forward = bool(row.forward)
+ if forward:
+ to_node = int(row.target_osm_node_id)
+ from_lon, from_lat = float(row.source_lon), float(row.source_lat)
+ to_lon, to_lat = float(row.target_lon), float(row.target_lat)
+ else:
+ to_node = int(row.source_osm_node_id)
+ from_lon, from_lat = float(row.target_lon), float(row.target_lat)
+ to_lon, to_lat = float(row.source_lon), float(row.source_lat)
+ edges.append(
+ _Traversal(
+ edge_id=int(row.id),
+ from_node=node_id,
+ to_node=to_node,
+ from_lon=from_lon,
+ from_lat=from_lat,
+ to_lon=to_lon,
+ to_lat=to_lat,
+ cost_s=float(row.cost_s),
+ length_m=float(row.length_m),
+ highway=row.highway,
+ name=row.name,
+ geometry_geojson=str(row.geometry_geojson),
+ reversed=not forward,
+ )
+ )
+ return edges
+
+
+def _route_payload(
+ *,
+ dataset_id: int,
+ mode: str,
+ start: _GraphNode,
+ target: _GraphNode,
+ from_lon: float,
+ from_lat: float,
+ to_lon: float,
+ to_lat: float,
+ previous: dict[int, tuple[int, _Traversal]],
+ total_cost_s: float,
+ visited: int,
+) -> dict[str, object]:
+ edges: list[_Traversal] = []
+ current = target.osm_node_id
+ while current != start.osm_node_id:
+ prior, edge = previous[current]
+ edges.append(edge)
+ current = prior
+ edges.reverse()
+ network_distance = sum(edge.length_m for edge in edges)
+ access_distance = start.distance_m + target.distance_m
+ features = []
+ if start.distance_m:
+ features.append(_connector_feature("access", mode, [[from_lon, from_lat], [start.lon, start.lat]], start.distance_m))
+ for index, edge in enumerate(edges, start=1):
+ geometry = json.loads(edge.geometry_geojson)
+ if edge.reversed:
+ geometry["coordinates"] = list(reversed(geometry.get("coordinates", [])))
+ features.append(
+ {
+ "type": "Feature",
+ "geometry": geometry,
+ "properties": {
+ "feature_type": "routing_edge",
+ "sequence": index,
+ "mode": mode,
+ "edge_id": edge.edge_id,
+ "highway": edge.highway,
+ "name": edge.name,
+ "length_m": edge.length_m,
+ "cost_s": edge.cost_s,
+ },
+ }
+ )
+ if target.distance_m:
+ features.append(_connector_feature("egress", mode, [[target.lon, target.lat], [to_lon, to_lat]], target.distance_m))
+ duration_seconds = total_cost_s + _connector_seconds(access_distance, mode)
+ return {
+ "dataset_id": dataset_id,
+ "mode": mode,
+ "engine": "python_astar",
+ "distance_m": round(network_distance + access_distance, 1),
+ "network_distance_m": round(network_distance, 1),
+ "access_distance_m": round(access_distance, 1),
+ "duration_seconds": round(duration_seconds, 1),
+ "duration_minutes": _duration_minutes_ceil(duration_seconds),
+ "duration_label": _duration_label(duration_seconds),
+ "visited_nodes": visited,
+ "start_node": {"osm_node_id": start.osm_node_id, "distance_m": round(start.distance_m, 1)},
+ "target_node": {"osm_node_id": target.osm_node_id, "distance_m": round(target.distance_m, 1)},
+ "features": feature_collection(features),
+ }
+
+
+def _single_point_route(start: _GraphNode, from_lon: float, from_lat: float, to_lon: float, to_lat: float, mode: str, dataset_id: int) -> dict[str, object]:
+    """Build the straight-line payload used when origin and destination snap to the same graph node."""
+    # `start` is that shared node; its distance to the destination is measured from (to_lon, to_lat) rather than reused from the origin.
+    return _direct_route_payload(
+        dataset_id=dataset_id,
+        mode=mode,
+        from_lon=from_lon, from_lat=from_lat,
+        to_lon=to_lon, to_lat=to_lat,
+        engine="python_astar",
+        start_node={"osm_node_id": start.osm_node_id, "distance_m": round(start.distance_m, 1)},
+        target_node={"osm_node_id": start.osm_node_id, "distance_m": round(_distance_m(to_lat, to_lon, start.lat, start.lon), 1)},
+        visited_nodes=1,
+    )
+
+
+def _direct_route_payload(
+ *,
+ dataset_id: int,
+ mode: str,
+ from_lon: float,
+ from_lat: float,
+ to_lon: float,
+ to_lat: float,
+ engine: str = "direct_fallback",
+ start_node: dict[str, object] | None = None,
+ target_node: dict[str, object] | None = None,
+ visited_nodes: int = 0,
+) -> dict[str, object]:
+ distance = _distance_m(from_lat, from_lon, to_lat, to_lon)
+ duration_seconds = _connector_seconds(distance, mode)
+ return {
+ "dataset_id": dataset_id,
+ "mode": mode,
+ "engine": engine,
+ "distance_m": round(distance, 1),
+ "network_distance_m": 0,
+ "access_distance_m": round(distance, 1),
+ "duration_seconds": round(duration_seconds, 1),
+ "duration_minutes": _duration_minutes_ceil(duration_seconds),
+ "duration_label": _duration_label(duration_seconds),
+ "visited_nodes": visited_nodes,
+ "start_node": start_node,
+ "target_node": target_node,
+ "features": feature_collection([_connector_feature("direct", mode, [[from_lon, from_lat], [to_lon, to_lat]], distance)]),
+ }
+
+
+def _connector_feature(kind: str, mode: str, coordinates: list[list[float]], distance_m: float) -> dict:
+ return {
+ "type": "Feature",
+ "geometry": {"type": "LineString", "coordinates": coordinates},
+ "properties": {
+ "feature_type": "routing_connector",
+ "connector": kind,
+ "mode": mode,
+ "length_m": distance_m,
+ "cost_s": _connector_seconds(distance_m, mode),
+ },
+ }
+
+
+def _connector_seconds(distance_m: float, mode: str) -> float:
+    """Time to cover distance_m at a fixed speed: 1.35 for "walk", 8.0 for any other mode (presumably metres per second)."""
+    return float(distance_m) / (1.35 if mode == "walk" else 8.0)
+
+
+def _duration_minutes_ceil(seconds: int | float | None) -> int | None:
+    """Round a duration in seconds up to whole minutes, never below zero; None passes through unchanged."""
+    whole_minutes = None if seconds is None else math.ceil(float(seconds) / 60)
+    return None if whole_minutes is None else max(0, int(whole_minutes))
+
+
+def _duration_label(seconds: int | float | None) -> str | None:
+    """Format a duration as "N min", "H:MM" or "Dd HH:MM" after rounding up to whole minutes; None stays None."""
+    total = _duration_minutes_ceil(seconds)
+    if total is None:
+        return None
+    # Split the whole minutes into days / hours / minutes (the helper never returns a negative total).
+    days, within_day = divmod(total, 24 * 60)
+    hours, minutes = divmod(within_day, 60)
+    if days:
+        return f"{days}d {hours:02d}:{minutes:02d}"
+    if not hours:
+        return f"{minutes} min"
+    return f"{hours}:{minutes:02d}"
+
+
+def _expanded_bbox(min_lon: float, min_lat: float, max_lon: float, max_lat: float, padding_km: float) -> tuple[float, float, float, float]:
+    """Grow a lon/lat box by padding_km on every side, scaling the longitude padding by the box's mid latitude."""
+    lat_pad = padding_km / 111.0
+    lon_pad = padding_km / max(1.0, 111.0 * math.cos(math.radians((min_lat + max_lat) / 2)))  # floor of 1.0 avoids a blow-up near the poles
+    return (min_lon - lon_pad, min_lat - lat_pad, max_lon + lon_pad, max_lat + lat_pad)
+
+
+def _distance_m(lat_a: float, lon_a: float, lat_b: float, lon_b: float) -> float:
+    """Return the haversine great-circle distance between two points given as latitude/longitude in degrees."""
+    radius = 6_371_000.0  # sphere radius used for the result, i.e. the distance is in metres
+    phi_a, phi_b = math.radians(lat_a), math.radians(lat_b)
+    delta_phi = math.radians(lat_b - lat_a)
+    delta_lambda = math.radians(lon_b - lon_a)
+    hav = math.sin(delta_phi / 2) ** 2 + math.cos(phi_a) * math.cos(phi_b) * math.sin(delta_lambda / 2) ** 2
+    hav = min(1.0, max(0.0, hav))  # rounding can nudge hav past 1 near antipodes, which would make sqrt(1 - hav) raise
+    return radius * 2 * math.atan2(math.sqrt(hav), math.sqrt(1 - hav))
diff --git a/app/serializers.py b/app/serializers.py
new file mode 100644
index 0000000..389ad9a
--- /dev/null
+++ b/app/serializers.py
@@ -0,0 +1,130 @@
+from __future__ import annotations
+
+import json
+from typing import Any, Iterable
+
+from app.models import GtfsRoute, GtfsStop, OsmFeature, RouteMatch, RoutePattern
+from app.osm_storage import osm_feature_public_id
+
+
+def feature_collection(features: Iterable[dict[str, Any]]) -> dict[str, Any]:
+ return {"type": "FeatureCollection", "features": list(features)}
+
+
+def gtfs_route_feature(route: GtfsRoute, extra: dict[str, Any] | None = None) -> dict[str, Any] | None:
+ if not route.geometry_geojson:
+ return None
+ props = {
+ "id": route.id,
+ "dataset_id": route.dataset_id,
+ "route_id": route.route_id,
+ "mode": route.mode,
+ "route_scope": route.route_scope,
+ "ref": route.short_name,
+ "name": route.long_name,
+ "operator": route.operator_name,
+ "source": "gtfs",
+ }
+ if extra:
+ props.update(extra)
+ return {"type": "Feature", "geometry": json.loads(route.geometry_geojson), "properties": props}
+
+
+def osm_feature_feature(feature: OsmFeature, extra: dict[str, Any] | None = None) -> dict[str, Any] | None:
+ if not feature.geometry_geojson:
+ return None
+ props = {
+ "id": osm_feature_public_id(feature),
+ "row_id": feature.id,
+ "dataset_id": feature.dataset_id,
+ "osm_type": feature.osm_type,
+ "osm_id": feature.osm_id,
+ "kind": feature.kind,
+ "mode": feature.mode,
+ "route_scope": feature.route_scope,
+ "ref": feature.ref,
+ "name": feature.name,
+ "operator": feature.operator,
+ "network": feature.network,
+ "source": "osm",
+ }
+ if extra:
+ props.update(extra)
+ return {"type": "Feature", "geometry": json.loads(feature.geometry_geojson), "properties": props}
+
+
+def route_pattern_feature(pattern: RoutePattern, extra: dict[str, Any] | None = None) -> dict[str, Any] | None:
+ if not pattern.geometry_geojson:
+ return None
+ props = {
+ "id": pattern.id,
+ "route_pattern_id": pattern.id,
+ "route_ref": pattern.route_ref,
+ "ref": pattern.route_ref,
+ "name": pattern.route_name,
+ "mode": pattern.mode,
+ "route_scope": pattern.route_scope,
+ "operator": pattern.operator_name,
+ "source": "route_layer",
+ "source_kind": pattern.source_kind,
+ "status": pattern.status,
+ "confidence": pattern.confidence,
+ "osm_feature_id": pattern.osm_feature_id,
+ "gtfs_route_id": pattern.gtfs_route_id,
+ "gtfs_shape_id": pattern.gtfs_shape_id,
+ }
+ if extra:
+ props.update(extra)
+ return {"type": "Feature", "geometry": json.loads(pattern.geometry_geojson), "properties": props}
+
+
+def gtfs_stop_feature(stop: GtfsStop) -> dict[str, Any] | None:
+ if stop.lon is None or stop.lat is None:
+ return None
+ return {
+ "type": "Feature",
+ "geometry": {"type": "Point", "coordinates": [stop.lon, stop.lat]},
+ "properties": {
+ "id": stop.id,
+ "dataset_id": stop.dataset_id,
+ "stop_id": stop.stop_id,
+ "name": stop.name,
+ "source": "gtfs",
+ },
+ }
+
+
+def match_row(match: RouteMatch) -> dict[str, Any]:
+ route = match.gtfs_route
+ feature = match.osm_feature
+ return {
+ "id": match.id,
+ "status": match.status,
+ "confidence": match.confidence,
+ "rule_source": match.rule_source,
+ "gtfs": {
+ "id": route.id,
+ "dataset_id": route.dataset_id,
+ "route_id": route.route_id,
+ "mode": route.mode,
+ "route_scope": route.route_scope,
+ "ref": route.short_name,
+ "name": route.long_name,
+ "operator": route.operator_name,
+ },
+ "osm": None
+ if feature is None
+ else {
+ "id": feature.id,
+ "dataset_id": feature.dataset_id,
+ "osm_type": feature.osm_type,
+ "osm_id": feature.osm_id,
+ "mode": feature.mode,
+ "route_scope": feature.route_scope,
+ "ref": feature.ref,
+ "name": feature.name,
+ "operator": feature.operator,
+ "network": feature.network,
+ },
+ "reasons": json.loads(match.reasons_json or "{}"),
+ }
diff --git a/app/source_catalog.py b/app/source_catalog.py
new file mode 100644
index 0000000..b2b8d79
--- /dev/null
+++ b/app/source_catalog.py
@@ -0,0 +1,309 @@
+from __future__ import annotations
+
+import csv
+import hashlib
+from datetime import datetime, timezone
+from pathlib import Path
+from typing import Iterable
+
+from sqlalchemy import func, or_, select
+from sqlalchemy.orm import Session
+
+from app.models import Source, SourceCatalogEntry
+
+
+# Source kinds that import_ingestable_sources() will seed; rows whose
+# lower-cased "kind" column is not in this set are counted as skipped.
+DIRECT_INGEST_KINDS = {"gtfs", "osm_geojson", "osm_pbf"}
+
+
+def default_source_catalog_path() -> Path:
+    """Return the bundled catalog seed: ``docs/source_catalog_seed.csv`` one level above this package."""
+    project_root = Path(__file__).resolve().parent.parent
+    return project_root.joinpath("docs", "source_catalog_seed.csv")
+
+
+def default_ingestable_sources_path() -> Path:
+    """Return the bundled ingestable-source seed: ``docs/ingestable_sources_seed.csv`` one level above this package."""
+    project_root = Path(__file__).resolve().parent.parent
+    return project_root.joinpath("docs", "ingestable_sources_seed.csv")
+
+
+def import_source_catalog(session: Session, path: Path | str | None = None, *, update_existing: bool = True) -> dict[str, int]:
+    """Upsert catalog entries from the source-catalog CSV.
+
+    Args:
+        session: Session used for lookups and writes; it is flushed, not committed.
+        path: CSV to read. ``None`` selects the bundled seed file; relative
+            paths are resolved against the current working directory.
+        update_existing: When false, rows whose catalog key already exists are
+            counted as skipped instead of being overwritten.
+
+    Returns:
+        Row counts under the keys ``created``, ``updated`` and ``skipped``.
+
+    Raises:
+        FileNotFoundError: If the CSV file does not exist.
+    """
+    # Model attribute -> CSV column header.
+    columns = (
+        ("geography", "Geography"),
+        ("country_code", "Country code"),
+        ("mode_scope", "Mode scope"),
+        ("source_name", "Source name"),
+        ("source_category", "Source category"),
+        ("formats_apis", "Formats / APIs"),
+        ("availability", "Availability"),
+        ("coverage_notes", "Coverage notes"),
+        ("geometry_notes", "Supersedes OSM for"),
+        ("disruptions_closures", "Disruptions / closures"),
+        ("operator_list_use", "Operator-list use"),
+        ("access_license_notes", "Access / licence notes"),
+        ("priority", "Priority"),
+        ("source_url", "Source URL"),
+        ("evidence_url", "Evidence URL"),
+        ("next_pipeline_action", "Next pipeline action"),
+    )
+    counts = {"created": 0, "updated": 0, "skipped": 0}
+    for row in _read_csv(_resolve_path(path, default_source_catalog_path())):
+        if not _value(row, "Source name"):
+            # Rows without a source name are ignored.
+            counts["skipped"] += 1
+            continue
+        payload = {"catalog_key": _catalog_key(row)}
+        payload.update((field, _value(row, header)) for field, header in columns)
+        entry = session.scalar(select(SourceCatalogEntry).where(SourceCatalogEntry.catalog_key == payload["catalog_key"]))
+        if entry is None:
+            session.add(SourceCatalogEntry(**payload))
+            counts["created"] += 1
+        elif not update_existing:
+            counts["skipped"] += 1
+        else:
+            for field, value in payload.items():
+                setattr(entry, field, value)
+            entry.updated_at = datetime.now(timezone.utc)
+            counts["updated"] += 1
+    session.flush()
+    return counts
+
+
+def import_ingestable_sources(
+    session: Session,
+    path: Path | str | None = None,
+    *,
+    update_existing: bool = True,
+) -> dict[str, int]:
+    """Upsert directly ingestable ``Source`` rows from the seed CSV.
+
+    A row is skipped unless it has a name, a URL and a kind listed in
+    ``DIRECT_INGEST_KINDS``. An existing source is looked up first by
+    (kind, url) and then by (name, url). Updated sources are re-enabled.
+
+    Args:
+        session: Session used for lookups and writes; it is flushed, not committed.
+        path: CSV to read. ``None`` selects the bundled seed file; relative
+            paths are resolved against the current working directory.
+        update_existing: When false, rows matching an existing source are
+            counted as skipped and left untouched.
+
+    Returns:
+        Counts under ``created``, ``updated``, ``skipped`` and
+        ``linked_catalog`` (created/updated rows that resolved a catalog entry).
+
+    Raises:
+        FileNotFoundError: If the CSV file does not exist.
+    """
+    optional_columns = ("country", "license", "priority", "mode_scope", "source_basis", "notes")
+    counts = {"created": 0, "updated": 0, "skipped": 0, "linked_catalog": 0}
+    for row in _read_csv(_resolve_path(path, default_ingestable_sources_path())):
+        name = _value(row, "name")
+        kind = (_value(row, "kind") or "").lower()
+        url = _value(row, "url")
+        if not (name and url and kind in DIRECT_INGEST_KINDS):
+            counts["skipped"] += 1
+            continue
+        catalog_entry = _catalog_entry_for_ingestable_row(session, row)
+        payload = {"name": name, "kind": kind, "url": url}
+        payload.update((column, _value(row, column)) for column in optional_columns)
+        payload["catalog_entry_id"] = catalog_entry.id if catalog_entry is not None else None
+
+        # Identity lookup: same kind + URL first, then same name + URL.
+        source = None
+        for identity in (Source.kind == kind, Source.name == name):
+            source = session.scalar(
+                select(Source).where(identity, Source.url == url).order_by(Source.id).limit(1)
+            )
+            if source is not None:
+                break
+
+        if source is None:
+            session.add(Source(**payload))
+            counts["created"] += 1
+        elif not update_existing:
+            counts["skipped"] += 1
+            continue
+        else:
+            for field, value in payload.items():
+                setattr(source, field, value)
+            source.enabled = True
+            counts["updated"] += 1
+        if catalog_entry is not None:
+            counts["linked_catalog"] += 1
+    session.flush()
+    return counts
+
+
+def source_catalog_summary(session: Session) -> dict[str, object]:
+    """Return headline counts for the source catalog.
+
+    The result holds the total number of catalog entries, entry counts grouped
+    by priority and by status, and the number of ``Source`` rows that have a
+    ``source_basis`` or a ``priority`` set.
+    """
+
+    def _counts_by(column) -> dict[str, int]:
+        grouped = session.execute(select(column, func.count()).group_by(column)).all()
+        # NULL/empty group values are reported under "unknown".
+        return {value or "unknown": total for value, total in grouped}
+
+    by_priority = _counts_by(SourceCatalogEntry.priority)
+    by_status = _counts_by(SourceCatalogEntry.status)
+    seeded_filter = or_(Source.source_basis.is_not(None), Source.priority.is_not(None))
+    seeded = session.scalar(select(func.count()).select_from(Source).where(seeded_filter)) or 0
+    total = session.scalar(select(func.count()).select_from(SourceCatalogEntry)) or 0
+    return {
+        "catalog_entries": total,
+        "catalog_by_priority": by_priority,
+        "catalog_by_status": by_status,
+        "seeded_ingestable_sources": seeded,
+    }
+
+
+def source_catalog_rows(
+    session: Session,
+    *,
+    q: str | None = None,
+    country: str | None = None,
+    priority: str | None = None,
+    status: str | None = None,
+    limit: int = 100,
+) -> list[SourceCatalogEntry]:
+    """List catalog entries, optionally filtered, ordered by priority, country, name and id.
+
+    Args:
+        session: Session to query with.
+        q: Case-insensitive substring searched in source name, category,
+            formats/APIs, coverage notes and next pipeline action.
+        country: Case-insensitive substring of the country code.
+        priority: Exact priority value.
+        status: Exact status value.
+        limit: Maximum number of rows; clamped to the range 1..500.
+
+    Filters are stripped before use, so a whitespace-only value is treated as
+    "no filter". ``%``, ``_`` and ``\\`` in ``q``/``country`` are matched
+    literally rather than as LIKE wildcards.
+    """
+
+    def _contains(term: str) -> str:
+        # Escape LIKE metacharacters so user input cannot widen the match.
+        escaped = term.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
+        return f"%{escaped}%"
+
+    q = (q or "").strip()
+    country = (country or "").strip()
+    priority = (priority or "").strip()
+    status = (status or "").strip()
+
+    stmt = select(SourceCatalogEntry).order_by(
+        SourceCatalogEntry.priority,
+        SourceCatalogEntry.country_code,
+        SourceCatalogEntry.source_name,
+        SourceCatalogEntry.id,
+    )
+    if q:
+        pattern = _contains(q)
+        searched = (
+            SourceCatalogEntry.source_name,
+            SourceCatalogEntry.source_category,
+            SourceCatalogEntry.formats_apis,
+            SourceCatalogEntry.coverage_notes,
+            SourceCatalogEntry.next_pipeline_action,
+        )
+        stmt = stmt.where(or_(*(column.ilike(pattern, escape="\\") for column in searched)))
+    if country:
+        stmt = stmt.where(SourceCatalogEntry.country_code.ilike(_contains(country), escape="\\"))
+    if priority:
+        stmt = stmt.where(SourceCatalogEntry.priority == priority)
+    if status:
+        stmt = stmt.where(SourceCatalogEntry.status == status)
+    return session.scalars(stmt.limit(max(1, min(limit, 500)))).all()
+
+
+def catalog_entry_payload(entry: SourceCatalogEntry, *, linked_source_count: int = 0) -> dict[str, object]:
+    """Serialise a catalog entry to a plain dict.
+
+    Column values are copied as-is, ``created_at``/``updated_at`` become ISO
+    8601 strings (or ``None`` when unset), and ``linked_source_count`` is
+    passed through from the caller.
+    """
+    copied_fields = (
+        "id",
+        "geography",
+        "country_code",
+        "mode_scope",
+        "source_name",
+        "source_category",
+        "formats_apis",
+        "availability",
+        "coverage_notes",
+        "geometry_notes",
+        "disruptions_closures",
+        "operator_list_use",
+        "access_license_notes",
+        "priority",
+        "source_url",
+        "evidence_url",
+        "next_pipeline_action",
+        "status",
+    )
+    payload: dict[str, object] = {field: getattr(entry, field) for field in copied_fields}
+    payload["linked_source_count"] = linked_source_count
+    for stamp in ("created_at", "updated_at"):
+        moment = getattr(entry, stamp)
+        payload[stamp] = moment.isoformat() if moment else None
+    return payload
+
+
+def linked_source_counts(session: Session, entries: Iterable[SourceCatalogEntry]) -> dict[int, int]:
+    """Count ``Source`` rows linked to each of the given catalog entries.
+
+    Returns a mapping of catalog entry id to source count. Entries without
+    any linked source are absent from the mapping, and an empty ``entries``
+    yields ``{}`` without querying.
+    """
+    ids = [entry.id for entry in entries]
+    if not ids:
+        return {}
+    stmt = (
+        select(Source.catalog_entry_id, func.count())
+        .where(Source.catalog_entry_id.in_(ids))
+        .group_by(Source.catalog_entry_id)
+    )
+    counts: dict[int, int] = {}
+    for catalog_entry_id, total in session.execute(stmt).all():
+        if catalog_entry_id is not None:
+            counts[catalog_entry_id] = total
+    return counts
+
+
+def _catalog_entry_for_ingestable_row(session: Session, row: dict[str, str]) -> SourceCatalogEntry | None:
+    """Find the catalog entry that best describes an ingestable-source CSV row.
+
+    Resolution order:
+      1. Case-insensitive exact match of the row's ``name`` on ``source_name``.
+      2. A text match (``source_basis`` tokens on source name / coverage
+         notes, or the first word of ``name`` on source name) restricted to
+         the row's country.
+      3. The same text match without the country restriction.
+      4. A country-only match, used only when the row yields no text clause.
+
+    Within each step the entry with the lowest (priority, id) wins. Returns
+    ``None`` when nothing matches.
+
+    The country filter used to be OR'ed with the text clauses, so every row
+    with a country linked to that country's top-priority entry whether or not
+    any name/basis token matched. Country now only narrows a text match.
+    """
+    country = _value(row, "country")
+    source_basis = _value(row, "source_basis")
+    name = _value(row, "name")
+    if not country and not source_basis and not name:
+        return None
+    if name:
+        exact = session.scalar(
+            select(SourceCatalogEntry)
+            .where(func.lower(SourceCatalogEntry.source_name) == name.lower())
+            .order_by(SourceCatalogEntry.id)
+            .limit(1)
+        )
+        if exact is not None:
+            return exact
+
+    text_clauses = []
+    if source_basis:
+        for token in _basis_tokens(source_basis):
+            text_clauses.append(SourceCatalogEntry.source_name.ilike(f"%{token}%"))
+            text_clauses.append(SourceCatalogEntry.coverage_notes.ilike(f"%{token}%"))
+    if name:
+        first_word = name.split()[0]
+        if len(first_word) > 2:
+            text_clauses.append(SourceCatalogEntry.source_name.ilike(f"%{first_word}%"))
+    country_clause = SourceCatalogEntry.country_code.ilike(f"%{country}%") if country else None
+
+    def _best(*conditions) -> SourceCatalogEntry | None:
+        return session.scalar(
+            select(SourceCatalogEntry)
+            .where(*conditions)
+            .order_by(SourceCatalogEntry.priority, SourceCatalogEntry.id)
+            .limit(1)
+        )
+
+    if text_clauses:
+        text_match = or_(*text_clauses)
+        if country_clause is not None:
+            scoped = _best(country_clause, text_match)
+            if scoped is not None:
+                return scoped
+        # Country codes may be written differently in the two CSVs, so fall
+        # back to the text match alone before giving up.
+        return _best(text_match)
+    if country_clause is not None:
+        return _best(country_clause)
+    return None
+
+
+def _basis_tokens(value: str) -> list[str]:
+    """Extract up to four search tokens from a free-text ``source_basis`` value.
+
+    ``/`` and ``-`` are treated as word separators, surrounding punctuation
+    (`` ,.;()``) is trimmed, and a word is kept only if it has at least five
+    characters and is not one of a few generic terms (compared lower-cased).
+    Original casing and order are preserved.
+    """
+    generic_terms = {"official", "mirror", "feeds", "transport"}
+    words = value.replace("/", " ").replace("-", " ").split()
+    trimmed = (word.strip(" ,.;()") for word in words)
+    kept = [word for word in trimmed if len(word) >= 5 and word.lower() not in generic_terms]
+    return kept[:4]
+
+
+def _catalog_key(row: dict[str, str]) -> str:
+    """Derive the stable SHA-256 key that identifies a catalog CSV row.
+
+    The key hashes the lower-cased, ``|``-joined country code, source name,
+    source URL and formats/APIs columns, skipping any that are empty. When all
+    four are empty it hashes the ``repr`` of the row's sorted items instead.
+    """
+    identity_columns = ("Country code", "Source name", "Source URL", "Formats / APIs")
+    present = [cell.lower() for cell in (_value(row, column) for column in identity_columns) if cell]
+    text = "|".join(present)
+    if not text:
+        text = repr(sorted(row.items()))
+    return hashlib.sha256(text.encode("utf-8")).hexdigest()
+
+
+def _read_csv(path: Path) -> list[dict[str, str]]:
+    """Load a CSV file with a header row into a list of plain dicts.
+
+    The file is read as UTF-8 and a leading byte-order mark is dropped.
+
+    Raises:
+        FileNotFoundError: If ``path`` does not exist.
+    """
+    if not path.exists():
+        raise FileNotFoundError(path)
+    with path.open("r", encoding="utf-8-sig", newline="") as handle:
+        return list(map(dict, csv.DictReader(handle)))
+
+
+def _resolve_path(path: Path | str | None, default_path: Path) -> Path:
+    """Pick the CSV path to use.
+
+    ``None`` yields ``default_path``; an absolute path is returned unchanged;
+    a relative path is anchored at the current working directory.
+    """
+    if path is None:
+        return default_path
+    requested = Path(path)
+    return requested if requested.is_absolute() else Path.cwd() / requested
+
+
+def _value(row: dict[str, str], key: str) -> str | None:
+    """Return the stripped cell for ``key``, or ``None`` if it is missing or blank."""
+    cell = row.get(key)
+    if cell is None:
+        return None
+    return cell.strip() or None
diff --git a/app/source_updates.py b/app/source_updates.py
new file mode 100644
index 0000000..effd330
--- /dev/null
+++ b/app/source_updates.py
@@ -0,0 +1,256 @@
+from __future__ import annotations
+
+import json
+from datetime import datetime, timezone
+from pathlib import Path
+from urllib.parse import urlparse
+
+import requests
+from sqlalchemy import select
+from sqlalchemy.orm import Session
+
+from app.config import settings
+from app.models import Dataset, Source, SourceUpdateCheck
+from app.pipeline.utils import norm_text, sha256_file
+
+
+def check_source_for_update(session: Session, source: Source) -> SourceUpdateCheck:
+    """Probe a source for changes and record the outcome.
+
+    Looks up the source's newest active dataset, gathers metadata for the
+    source URL (HTTP headers or local file stats), decides whether an update
+    is available, stores a ``SourceUpdateCheck`` row and updates
+    ``source.status`` / ``source.last_error``. The session is flushed, not
+    committed.
+
+    Side effect: if the URL points at a missing managed cache file and a seed
+    URL can be found, ``source.url`` is rewritten before probing and both
+    URLs are recorded in the check's metadata.
+    """
+    newest_active = (
+        select(Dataset)
+        .where(Dataset.source_id == source.id, Dataset.is_active.is_(True))
+        .order_by(Dataset.created_at.desc(), Dataset.id.desc())
+    )
+    active_dataset = session.scalar(newest_active)
+
+    # Recovery must run before the probe because it may change source.url.
+    recovery = _recover_missing_managed_cache_url(source)
+    remote = _source_remote_metadata(source)
+    if recovery is not None:
+        remote.update(recovered_source_url=recovery["url"], previous_source_url=recovery["previous_url"])
+
+    update_available, reason = _update_decision(active_dataset, remote)
+    dataset_id = dataset_sha256 = None
+    if active_dataset is not None:
+        dataset_id, dataset_sha256 = active_dataset.id, active_dataset.sha256
+    check = SourceUpdateCheck(
+        source_id=source.id,
+        status=remote["status"],
+        update_available=update_available,
+        reason=reason,
+        remote_url=source.url,
+        etag=remote.get("etag"),
+        last_modified=remote.get("last_modified"),
+        content_length=remote.get("content_length"),
+        content_type=remote.get("content_type"),
+        local_mtime=remote.get("local_mtime"),
+        local_size=remote.get("local_size"),
+        local_sha256=remote.get("local_sha256"),
+        active_dataset_id=dataset_id,
+        active_dataset_sha256=dataset_sha256,
+        metadata_json=json.dumps(remote, separators=(",", ":"), default=_json_default),
+    )
+    session.add(check)
+
+    if remote["status"] != "checked":
+        source.status = "update_check_error"
+        source.last_error = reason
+    else:
+        source.status = "update_available" if update_available else "up_to_date"
+        source.last_error = None
+    session.flush()
+    return check
+
+
+def latest_source_update_check(session: Session, source_id: int) -> SourceUpdateCheck | None:
+    """Return the most recent update check for ``source_id`` (by ``checked_at``, then id), or ``None``."""
+    newest_first = (
+        select(SourceUpdateCheck)
+        .where(SourceUpdateCheck.source_id == source_id)
+        .order_by(SourceUpdateCheck.checked_at.desc(), SourceUpdateCheck.id.desc())
+    )
+    return session.scalar(newest_first)
+
+
+def update_check_payload(check: SourceUpdateCheck | None) -> dict | None:
+    """Serialise an update check to a plain dict, or return ``None`` for ``None``.
+
+    Timestamps are rendered as ISO 8601 strings (``None`` when unset) and
+    ``metadata`` is the decoded ``metadata_json``; undecodable or empty JSON
+    yields ``{}``.
+    """
+    if check is None:
+        return None
+    try:
+        metadata = json.loads(check.metadata_json or "{}")
+    except json.JSONDecodeError:
+        metadata = {}
+
+    payload: dict = {
+        "id": check.id,
+        "source_id": check.source_id,
+        "checked_at": check.checked_at.isoformat() if check.checked_at else None,
+    }
+    for field in ("status", "update_available", "reason", "etag", "last_modified", "content_length", "content_type"):
+        payload[field] = getattr(check, field)
+    payload["local_mtime"] = check.local_mtime.isoformat() if check.local_mtime else None
+    for field in ("local_size", "local_sha256", "active_dataset_id", "active_dataset_sha256"):
+        payload[field] = getattr(check, field)
+    payload["metadata"] = metadata
+    return payload
+
+
+def record_dataset_update_metadata(dataset: Dataset, check: SourceUpdateCheck | None) -> None:
+    """Store a snapshot of ``check`` under ``source_update_check`` in the dataset's metadata JSON.
+
+    Does nothing when ``check`` is ``None``. Other keys already present in
+    ``dataset.metadata_json`` are kept. Existing metadata that cannot be
+    decoded, or that decodes to something other than a JSON object (for
+    example ``null`` or a list), is replaced by a fresh object instead of
+    raising.
+    """
+    if check is None:
+        return
+    try:
+        metadata = json.loads(dataset.metadata_json or "{}")
+    except json.JSONDecodeError:
+        metadata = {}
+    if not isinstance(metadata, dict):
+        # Valid JSON but not an object: it cannot hold the key assigned below.
+        metadata = {}
+    check_payload = update_check_payload(check) or {}
+    metadata["source_update_check"] = {
+        "id": check.id,
+        "checked_at": check.checked_at.isoformat() if check.checked_at else None,
+        "etag": check.etag,
+        "last_modified": check.last_modified,
+        "content_length": check.content_length,
+        "content_type": check.content_type,
+        "local_mtime": check.local_mtime.isoformat() if check.local_mtime else None,
+        "local_size": check.local_size,
+        "local_sha256": check.local_sha256,
+        "metadata": check_payload.get("metadata", {}),
+    }
+    dataset.metadata_json = json.dumps(metadata, indent=2, default=_json_default)
+
+
+def _source_remote_metadata(source: Source) -> dict:
+    """Collect update metadata for ``source.url``.
+
+    ``http``/``https`` URLs are probed over HTTP. ``file:`` URLs use their
+    path component, and anything else is treated as a filesystem path.
+    """
+    location = source.url
+    parsed = urlparse(location)
+    if parsed.scheme in ("http", "https"):
+        return _http_metadata(location)
+    if parsed.scheme == "file":
+        return _local_metadata(Path(parsed.path))
+    return _local_metadata(Path(location))
+
+
+def _recover_missing_managed_cache_url(source: Source) -> dict | None:
+    """Repoint a source whose managed cache file has disappeared at a seed URL.
+
+    Applies only when ``source.url`` is a local path (or ``file:`` URL) that
+    does not exist, lies in this source's managed cache directory, and a
+    matching seed URL is found. In that case ``source.url`` is overwritten and
+    ``{"previous_url", "url"}`` is returned; otherwise returns ``None`` and
+    leaves the source untouched.
+    """
+    original_url = source.url
+    parsed = urlparse(original_url)
+    if parsed.scheme in ("http", "https"):
+        return None
+    local_path = Path(parsed.path) if parsed.scheme == "file" else Path(original_url)
+    if local_path.exists():
+        return None
+    if not _is_managed_source_cache_path(local_path, source.id):
+        return None
+    replacement = _seed_source_url_for(source)
+    if replacement is None:
+        return None
+    source.url = replacement
+    return {"previous_url": original_url, "url": replacement}
+
+
+def _is_managed_source_cache_path(path: Path, source_id: int) -> bool:
+    """Tell whether ``path`` lies in the cache directory managed for ``source_id``.
+
+    True when the resolved path is inside ``<data_dir>/sources/source_<id>``,
+    or, failing that, when the unresolved path contains the consecutive
+    components ``sources`` / ``source_<id>`` anywhere.
+    """
+    source_dir = f"source_{source_id}"
+    try:
+        managed_root = (settings.data_dir / "sources" / source_dir).resolve()
+        path.resolve().relative_to(managed_root)
+    except ValueError:
+        # Not under the configured data dir; fall through to the component check.
+        pass
+    else:
+        return True
+    parts = path.parts
+    return any(parent == "sources" and child == source_dir for parent, child in zip(parts, parts[1:]))
+
+
+def _seed_source_url_for(source: Source) -> str | None:
+    """Look up an HTTP(S) URL for ``source`` in ``scripts/example_sources.json``.
+
+    A seed row qualifies when its URL is http/https, its ``kind`` equals the
+    source's kind, its country (if both sides have one) equals the source's,
+    and the normalised name tokens of one side are a subset of the other's.
+    The first qualifying row wins. Returns ``None`` when the seed file is
+    missing, unreadable, not a JSON list, or has no qualifying row.
+
+    A source whose name normalises to no tokens never matches: an empty set
+    is a subset of every token set, so it would otherwise adopt the first
+    seed row of the same kind.
+    """
+    seed_path = Path(__file__).resolve().parents[1] / "scripts" / "example_sources.json"
+    if not seed_path.exists():
+        return None
+    try:
+        rows = json.loads(seed_path.read_text(encoding="utf-8"))
+    except (OSError, json.JSONDecodeError):
+        return None
+    if not isinstance(rows, list):
+        return None
+    source_tokens = set(norm_text(source.name).split())
+    if not source_tokens:
+        return None
+    for row in rows:
+        if not isinstance(row, dict):
+            continue
+        url = str(row.get("url") or "")
+        if urlparse(url).scheme not in {"http", "https"}:
+            continue
+        if row.get("kind") != source.kind:
+            continue
+        if source.country and row.get("country") and str(row.get("country")) != source.country:
+            continue
+        row_tokens = set(norm_text(row.get("name")).split())
+        if row_tokens and (row_tokens <= source_tokens or source_tokens <= row_tokens):
+            return url
+    return None
+
+
+def _http_metadata(url: str) -> dict:
+    """Fetch response headers for ``url`` without downloading the body.
+
+    Sends a HEAD request (following redirects) and retries with a streamed
+    GET when the server answers 405 or 501. Any failure, including an HTTP
+    error status, is returned as ``{"status": "error", "error": ...}`` rather
+    than raised. On success the dict has ``status == "checked"`` plus ETag,
+    Last-Modified, Content-Length (as int when numeric), Content-Type, the
+    final URL and the artifact classification.
+    """
+    response = None
+    try:
+        response = requests.head(url, allow_redirects=True, timeout=30)
+        if response.status_code in (405, 501):
+            # HEAD is not supported here; a streamed GET exposes the same headers.
+            response.close()
+            response = requests.get(url, stream=True, timeout=30)
+        response.raise_for_status()
+    except Exception as exc:  # noqa: BLE001 - persisted as update-check status
+        return {"status": "error", "error": str(exc)}
+    finally:
+        if response is not None:
+            response.close()
+
+    headers = response.headers
+    content_type = headers.get("Content-Type")
+    raw_length = headers.get("Content-Length")
+    size = None
+    if raw_length and raw_length.isdigit():
+        size = int(raw_length)
+    return {
+        "status": "checked",
+        "etag": headers.get("ETag"),
+        "last_modified": headers.get("Last-Modified"),
+        "content_length": size,
+        "content_type": content_type,
+        "final_url": response.url,
+        "update_artifact": _update_artifact(url, content_type),
+    }
+
+
+def _local_metadata(path: Path) -> dict:
+    """Describe a local source file for update checks.
+
+    Returns ``status == "checked"`` with the file's modification time (UTC),
+    size, SHA-256 and artifact classification. A missing path, a path that is
+    not a regular file (e.g. a directory), or an OS error while reading it is
+    reported as ``{"status": "error", "error": ...}`` instead of raising, so
+    the failure is persisted like HTTP probe failures are.
+    """
+    if not path.exists():
+        return {"status": "error", "error": f"Source file does not exist: {path}"}
+    if not path.is_file():
+        # exists() is also true for directories, which cannot be hashed.
+        return {"status": "error", "error": f"Source path is not a regular file: {path}"}
+    try:
+        stat = path.stat()
+        digest = sha256_file(path)
+    except OSError as exc:
+        return {"status": "error", "error": str(exc)}
+    return {
+        "status": "checked",
+        "local_mtime": datetime.fromtimestamp(stat.st_mtime, tz=timezone.utc),
+        "local_size": stat.st_size,
+        "local_sha256": digest,
+        "update_artifact": _update_artifact(str(path), None),
+    }
+
+
+def _update_decision(active_dataset: Dataset | None, remote: dict) -> tuple[bool, str]:
+    """Decide whether the probed source differs from the active dataset.
+
+    Returns ``(update_available, reason)``:
+      * a failed probe never reports an update and carries the error text;
+      * with no active dataset an update is always available;
+      * for local files the SHA-256 is compared with the dataset's;
+      * for remote sources ETag, Last-Modified and Content-Length are compared
+        (as strings) with the values recorded on the dataset, considering only
+        fields present on both sides. If none is comparable an update is
+        reported.
+    """
+    if remote["status"] != "checked":
+        return False, remote.get("error") or "update check failed"
+    if active_dataset is None:
+        return True, "no active dataset imported"
+
+    local_hash = remote.get("local_sha256")
+    if local_hash:
+        if local_hash != active_dataset.sha256:
+            return True, "local file hash differs from active dataset"
+        return False, "local file hash matches active dataset"
+
+    recorded = _dataset_update_metadata(active_dataset)
+    compared_any = False
+    for field in ("etag", "last_modified", "content_length"):
+        now, before = remote.get(field), recorded.get(field)
+        if now is None or before is None:
+            continue
+        compared_any = True
+        if str(now) != str(before):
+            return True, f"remote {field} changed"
+    if compared_any:
+        return False, "remote metadata matches active dataset"
+    return True, "no previous remote metadata recorded"
+
+
+def _dataset_update_metadata(dataset: Dataset) -> dict:
+    """Return the ``source_update_check`` object stored in the dataset's metadata JSON.
+
+    Always returns a dict. ``{}`` is returned when the metadata is empty,
+    cannot be decoded, is not a JSON object (e.g. ``null`` or a list), or when
+    the stored value itself is missing or not an object. Callers call
+    ``.get`` on the result, so anything but a dict would raise there.
+    """
+    try:
+        metadata = json.loads(dataset.metadata_json or "{}")
+    except json.JSONDecodeError:
+        return {}
+    if not isinstance(metadata, dict):
+        return {}
+    recorded = metadata.get("source_update_check")
+    return recorded if isinstance(recorded, dict) else {}
+
+
+def _json_default(value):
+    """``json.dumps`` fallback: datetimes become ISO 8601 strings, anything else raises ``TypeError``."""
+    if not isinstance(value, datetime):
+        raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")
+    return value.isoformat()
+
+
+def _update_artifact(url_or_path: str, content_type: str | None) -> dict:
+    """Classify what a source location delivers, from its suffix and Content-Type.
+
+    Returns ``kind`` (``"osm_diff"`` for ``.osc``/``.osc.gz``,
+    ``"gtfs_or_archive"`` for ``.zip`` or a zip media type, otherwise
+    ``"full_snapshot"``), ``is_diff`` and the unmodified ``content_type``.
+
+    For http(s) URLs the suffix is checked on the URL path as well as on the
+    whole string, so ``feed.zip?token=...`` is still recognised. Parameters
+    after ``;`` in the Content-Type (e.g. ``application/zip; charset=binary``)
+    are ignored when matching the media type.
+    """
+    candidates = [url_or_path.lower()]
+    parsed = urlparse(url_or_path)
+    if parsed.scheme in {"http", "https"}:
+        # Only real URLs have a query/fragment; local paths may contain "?" or "#".
+        candidates.append(parsed.path.lower())
+    media_type = (content_type or "").split(";", 1)[0].strip().lower()
+    is_osm_diff = any(name.endswith((".osc", ".osc.gz")) for name in candidates)
+    is_gtfs_zip = any(name.endswith(".zip") for name in candidates) or media_type in {
+        "application/zip",
+        "application/x-zip-compressed",
+    }
+    return {
+        "kind": "osm_diff" if is_osm_diff else "gtfs_or_archive" if is_gtfs_zip else "full_snapshot",
+        "is_diff": is_osm_diff,
+        "content_type": content_type,
+    }
diff --git a/app/spatial.py b/app/spatial.py
new file mode 100644
index 0000000..368daa8
--- /dev/null
+++ b/app/spatial.py
@@ -0,0 +1,158 @@
+from __future__ import annotations
+
+from collections.abc import Iterable
+
+from sqlalchemy import text
+from sqlalchemy.orm import Session
+
+from app.config import settings
+
+
+# Tables that refresh_postgis_geometries() knows how to refresh; any other
+# name passed through its `tables` argument raises ValueError.
+POSTGIS_GEOMETRY_TABLES = {
+    "osm_features",
+    "gtfs_routes",
+    "gtfs_shapes",
+    "gtfs_stops",
+    "canonical_stops",
+    "route_patterns",
+    "osm_addresses",
+    "routing_nodes",
+    "routing_edges",
+}
+
+
+def using_postgresql() -> bool:
+    """Report the ``is_postgresql_database`` flag from the application settings."""
+    return settings.is_postgresql_database
+
+
+def refresh_postgis_geometries(
+    session: Session,
+    *,
+    dataset_id: int | None = None,
+    tables: Iterable[str] | None = None,
+    only_missing: bool = True,
+) -> None:
+    """Rebuild PostGIS geometry columns from the stored GeoJSON / coordinate columns.
+
+    Does nothing unless PostgreSQL is in use.
+
+    Args:
+        session: Session the UPDATE statements are executed on.
+        dataset_id: Restrict the refresh to one dataset. ``route_patterns``
+            and ``canonical_stops`` are always refreshed without this filter.
+        tables: Subset of ``POSTGIS_GEOMETRY_TABLES`` to refresh; ``None`` (or
+            an empty collection) selects all of them.
+        only_missing: When true, only rows whose geometry is still NULL are
+            updated.
+
+    Raises:
+        ValueError: If ``tables`` names a table outside ``POSTGIS_GEOMETRY_TABLES``.
+    """
+    if not using_postgresql():
+        return
+    selected = set(tables or POSTGIS_GEOMETRY_TABLES)
+    unknown = selected - POSTGIS_GEOMETRY_TABLES
+    if unknown:
+        raise ValueError(f"Unsupported PostGIS geometry table(s): {', '.join(sorted(unknown))}")
+
+    # (table, refresher) pairs in execution order.
+    steps = (
+        ("osm_features", lambda: _refresh_geojson_geometry(session, "osm_features", dataset_id=dataset_id, only_missing=only_missing)),
+        ("gtfs_routes", lambda: _refresh_geojson_geometry(session, "gtfs_routes", dataset_id=dataset_id, only_missing=only_missing)),
+        ("gtfs_shapes", lambda: _refresh_geojson_geometry(session, "gtfs_shapes", dataset_id=dataset_id, only_missing=only_missing)),
+        ("route_patterns", lambda: _refresh_geojson_geometry(session, "route_patterns", dataset_id=None, only_missing=only_missing)),
+        ("osm_addresses", lambda: _refresh_address_geometry(session, dataset_id=dataset_id, only_missing=only_missing)),
+        ("gtfs_stops", lambda: _refresh_point_geometry(session, "gtfs_stops", dataset_id=dataset_id, only_missing=only_missing)),
+        ("canonical_stops", lambda: _refresh_point_geometry(session, "canonical_stops", dataset_id=None, only_missing=only_missing)),
+        ("routing_nodes", lambda: _refresh_point_geometry(session, "routing_nodes", dataset_id=dataset_id, only_missing=only_missing)),
+        ("routing_edges", lambda: _refresh_routing_edge_geometry(session, dataset_id=dataset_id, only_missing=only_missing)),
+    )
+    for table, refresh in steps:
+        if table in selected:
+            refresh()
+
+
+def analyze_postgresql_tables(session: Session, tables: Iterable[str]) -> None:
+    """Run ``ANALYZE`` on each named table; a no-op unless PostgreSQL is in use.
+
+    Table names cannot be bound as SQL parameters, so each one is checked to
+    be a plain (optionally schema-qualified) identifier before it is
+    interpolated into the statement.
+
+    Raises:
+        ValueError: If a name is not of the form ``name`` or ``schema.name``
+            made of letters, digits and underscores.
+    """
+    if not using_postgresql():
+        return
+    import re
+
+    identifier = re.compile(r"[A-Za-z_][A-Za-z0-9_]*(?:\.[A-Za-z_][A-Za-z0-9_]*)?")
+    for table in tables:
+        name = str(table)
+        if not identifier.fullmatch(name):
+            raise ValueError(f"Unsafe table name for ANALYZE: {name!r}")
+        session.execute(text(f"ANALYZE {name}"))
+
+
+def _refresh_geojson_geometry(session: Session, table: str, *, dataset_id: int | None, only_missing: bool) -> None:
+    """Fill ``geom`` on ``table`` from its ``geometry_geojson`` column (SRID 4326).
+
+    Rows with NULL or empty GeoJSON are left alone. ``dataset_id`` limits the
+    update to one dataset and ``only_missing`` to rows whose ``geom`` is NULL.
+    ``table`` is interpolated into the SQL text, so it must be a trusted name.
+    """
+    conditions = ["geometry_geojson IS NOT NULL", "geometry_geojson <> ''"]
+    bind: dict[str, object] = {}
+    if dataset_id is not None:
+        conditions.append("dataset_id = :dataset_id")
+        bind["dataset_id"] = int(dataset_id)
+    if only_missing:
+        conditions.append("geom IS NULL")
+    statement = text(
+        f"""
+        UPDATE {table}
+        SET geom = ST_SetSRID(ST_GeomFromGeoJSON(geometry_geojson), 4326)
+        WHERE {" AND ".join(conditions)}
+        """
+    )
+    session.execute(statement, bind)
+
+
+def _refresh_point_geometry(session: Session, table: str, *, dataset_id: int | None, only_missing: bool) -> None:
+    """Fill ``geom`` on ``table`` with a point built from its ``lon``/``lat`` columns (SRID 4326).
+
+    Rows missing either coordinate are left alone. ``dataset_id`` limits the
+    update to one dataset and ``only_missing`` to rows whose ``geom`` is NULL.
+    ``table`` is interpolated into the SQL text, so it must be a trusted name.
+    """
+    conditions = ["lon IS NOT NULL", "lat IS NOT NULL"]
+    bind: dict[str, object] = {}
+    if dataset_id is not None:
+        conditions.append("dataset_id = :dataset_id")
+        bind["dataset_id"] = int(dataset_id)
+    if only_missing:
+        conditions.append("geom IS NULL")
+    statement = text(
+        f"""
+        UPDATE {table}
+        SET geom = ST_SetSRID(ST_MakePoint(lon, lat), 4326)
+        WHERE {" AND ".join(conditions)}
+        """
+    )
+    session.execute(statement, bind)
+
+
+def _refresh_address_geometry(session: Session, *, dataset_id: int | None, only_missing: bool) -> None:
+    """Refresh both geometry columns of ``osm_addresses``.
+
+    ``geom`` is rebuilt from ``lon``/``lat`` first, then ``area_geom`` from
+    ``geometry_geojson`` (SRID 4326). ``dataset_id`` limits both updates to
+    one dataset; ``only_missing`` limits each to rows where the respective
+    column is NULL.
+    """
+    _refresh_point_geometry(session, "osm_addresses", dataset_id=dataset_id, only_missing=only_missing)
+    conditions = ["geometry_geojson IS NOT NULL", "geometry_geojson <> ''"]
+    bind: dict[str, object] = {}
+    if dataset_id is not None:
+        conditions.append("dataset_id = :dataset_id")
+        bind["dataset_id"] = int(dataset_id)
+    if only_missing:
+        conditions.append("area_geom IS NULL")
+    statement = text(
+        f"""
+        UPDATE osm_addresses
+        SET area_geom = ST_SetSRID(ST_GeomFromGeoJSON(geometry_geojson), 4326)
+        WHERE {" AND ".join(conditions)}
+        """
+    )
+    session.execute(statement, bind)
+
+
+def _refresh_routing_edge_geometry(session: Session, *, dataset_id: int | None, only_missing: bool) -> None:
+    """Fill ``geom`` on ``routing_edges`` with a straight line from source to target (SRID 4326).
+
+    Edges missing any of the four endpoint coordinates are left alone.
+    ``dataset_id`` limits the update to one dataset and ``only_missing`` to
+    rows whose ``geom`` is NULL.
+    """
+    endpoint_columns = ("source_lon", "source_lat", "target_lon", "target_lat")
+    conditions = [f"{column} IS NOT NULL" for column in endpoint_columns]
+    bind: dict[str, object] = {}
+    if dataset_id is not None:
+        conditions.append("dataset_id = :dataset_id")
+        bind["dataset_id"] = int(dataset_id)
+    if only_missing:
+        conditions.append("geom IS NULL")
+    statement = text(
+        f"""
+        UPDATE routing_edges
+        SET geom = ST_SetSRID(
+            ST_MakeLine(
+                ST_MakePoint(source_lon, source_lat),
+                ST_MakePoint(target_lon, target_lat)
+            ),
+            4326
+        )
+        WHERE {" AND ".join(conditions)}
+        """
+    )
+    session.execute(statement, bind)
diff --git a/app/static/app.js b/app/static/app.js
new file mode 100644
index 0000000..0d92783
--- /dev/null
+++ b/app/static/app.js
@@ -0,0 +1,4090 @@
+// Module-level UI state. Everything declared with `let` below is mutable and
+// shared by the functions in this file.
+//
+// Leaflet map instance; assigned in initMap().
+let map;
+let layerLoadTimer;
+let layerLoadSequence = 0;
+// Layer name -> layer object currently added to `map` (see clearLayer()).
+let layers = {};
+// Layer config id -> boolean "enabled" flag (see initializeLayerState()).
+let layerState = {};
+let layerCounts = {};
+let layerLoading = {};
+// Layer group definitions; replaced wholesale by setLayerGroupsFromSources().
+let layerGroups = [];
+// Layer state parsed from localStorage; stays undefined until
+// initializeLayerState() runs for the first time.
+let savedLayerState;
+let journeyLayer;
+let datasetSearchLayer;
+let candidatePreviewLayer;
+let candidatePreviewData;
+let selectedCandidatePreviewId;
+let journeySearchTimers = {};
+let journeyStopAbortControllers = {};
+let datasetSearchTimer;
+let datasetSearchAbortController;
+let datasetSearchSequence = 0;
+let activeJourneySearchId;
+let journeySearchPollTimer;
+let lastJourneyResponse;
+let lastJourneyDrawSignature;
+let lastItineraries = [];
+let journeyStopSearchSequence = {};
+let journeyContextPopup;
+let selectedDatasetSearchKey;
+let activeJobPollTimer;
+let activeJobDetailsId;
+let jobDetailsPollTimer;
+let jobListRevision;
+let jobListRefreshTimer;
+let jobListRefreshInFlight = false;
+let jobListRefreshFailureShown = false;
+let layerLoadAbortController;
+let allSources = [];
+let sourceCatalogEntries = [];
+let sourceCatalogSummary = {};
+// Timing constants; the _MS suffix indicates milliseconds.
+const JOB_DETAILS_POLL_MS = 4000;
+const JOB_LIST_REFRESH_MS = 5000;
+const JOB_LIST_REFRESH_HIDDEN_MS = 15000;
+const JOURNEY_STOP_SEARCH_DEBOUNCE_MS = 400;
+// NOTE(review): presumably a localStorage key like MAP_VIEW_STORAGE_KEY - confirm at its use site.
+const SIDEBAR_COLLAPSED_STORAGE_KEY = 'mobilitySidebarCollapsed';
+// localStorage key read by loadSavedMapView() and written by saveMapViewport().
+const MAP_VIEW_STORAGE_KEY = 'mobilityMapView';
+// Fallback view returned by loadSavedMapView(); `center` is [lat, lon].
+const DEFAULT_MAP_VIEW = { center: [52.52, 13.405], zoom: 11 };
+
+// One entry per OSM route layer created for each OSM source in
+// buildLayerGroupsFromSources(). Each entry is forwarded to routeLayer():
+//   label          - display label; also stripped to [A-Za-z0-9] to build the layer id
+//   mode           - comma-separated mode list placed in the layer's `params.mode`
+//   routeScope     - placed in `params.route_scope`
+//   color          - line color
+//   enabled        - becomes the layer's `defaultEnabled`
+//   minZoom        - becomes the layer's `minZoom`
+//   baseWeight / detailWeight - line weights passed on to zoomLineStyle()
+//   tooltipMinZoom - becomes the layer's `tooltipMinZoom`
+const osmRouteModes = [
+ { label: 'Rail: long-distance', mode: 'train', routeScope: 'long_distance', color: '#1d4ed8', enabled: true, minZoom: 5, baseWeight: 3.4, detailWeight: 6, tooltipMinZoom: 10 },
+ { label: 'Rail: regional', mode: 'train', routeScope: 'regional', color: '#2563eb', enabled: true, minZoom: 7, baseWeight: 3, detailWeight: 5.4, tooltipMinZoom: 11 },
+ { label: 'Rail: local/S-Bahn', mode: 'train', routeScope: 'local', color: '#0f766e', enabled: true, minZoom: 10, baseWeight: 2.6, detailWeight: 4.8, tooltipMinZoom: 12 },
+ { label: 'Rail: unknown', mode: 'train', routeScope: 'unknown', color: '#64748b', enabled: false, minZoom: 10, baseWeight: 2.4, detailWeight: 4.4, tooltipMinZoom: 13 },
+ { label: 'Bus: long-distance', mode: 'bus,coach', routeScope: 'long_distance', color: '#9333ea', enabled: false, minZoom: 7, baseWeight: 2.6, detailWeight: 5, tooltipMinZoom: 11 },
+ { label: 'Bus: regional', mode: 'bus,trolleybus', routeScope: 'regional', color: '#ea580c', enabled: true, minZoom: 10, baseWeight: 2.2, detailWeight: 4.6, tooltipMinZoom: 13 },
+ { label: 'Bus: local', mode: 'bus,trolleybus', routeScope: 'local', color: '#d97706', enabled: true, minZoom: 12, baseWeight: 2, detailWeight: 4.2, tooltipMinZoom: 14 },
+ { label: 'Tram/light rail', mode: 'tram,light_rail', routeScope: 'local', color: '#7c3aed', enabled: true, minZoom: 11, baseWeight: 2.4, detailWeight: 4.8, tooltipMinZoom: 13 },
+ { label: 'Subway', mode: 'subway', routeScope: 'local', color: '#dc2626', enabled: true, minZoom: 10, baseWeight: 2.8, detailWeight: 5.2, tooltipMinZoom: 12 },
+ { label: 'Ferry', mode: 'ferry', routeScope: 'local', color: '#0891b2', enabled: true, minZoom: 10, baseWeight: 2.4, detailWeight: 4.6, tooltipMinZoom: 13 },
+ { label: 'Other routes', mode: 'monorail,funicular,aerialway', routeScope: 'local', color: '#64748b', enabled: false, minZoom: 11, baseWeight: 2.2, detailWeight: 4.2, tooltipMinZoom: 13 }
+];
+
+// [label, status, color] tuples; buildLayerGroupsFromSources() turns each one
+// into a "Match status" layer per GTFS source, using `status` both in the
+// layer id and as the `status` request parameter.
+const matchStatusLayers = [
+ ['Matched', 'matched', '#16a34a'],
+ ['Accepted', 'accepted', '#15803d'],
+ ['Probable', 'probable', '#ca8a04'],
+ ['Weak', 'weak', '#ea580c'],
+ ['Missing', 'missing', '#dc2626']
+];
+
+function zoomLineStyle(color, { baseWeight = 3, detailWeight = 5, opacity = 0.72, detailOpacity = 0.9, dashArray } = {}) {
+ return { color, weight: baseWeight, detailWeight, opacity, detailOpacity, dashArray, zoomResponsive: true };
+}
+
+function routeLayer(id, label, mode, color, sourceId, defaultEnabled = true, options = {}) {
+ const params = { kind: 'route', mode, source_id: String(sourceId) };
+ if (options.routeScope) params.route_scope = options.routeScope;
+ return {
+ id,
+ label,
+ category: 'osm-route',
+ endpoint: '/api/map/osm_features.geojson',
+ params,
+ minZoom: options.minZoom ?? 9,
+ defaultEnabled,
+ style: zoomLineStyle(color, {
+ baseWeight: options.baseWeight ?? 3,
+ detailWeight: options.detailWeight ?? 5,
+ opacity: options.opacity ?? 0.68,
+ detailOpacity: options.detailOpacity ?? 0.86,
+ dashArray: options.dashArray
+ }),
+ tooltipMinZoom: options.tooltipMinZoom,
+ limit: 5000
+ };
+}
+
+function routePatternLayer(id, label, mode, sourceKind, color, defaultEnabled = true, options = {}) {
+ const params = { mode, source_kind: sourceKind };
+ if (options.routeScope) params.route_scope = options.routeScope;
+ return {
+ id,
+ label,
+ category: 'route-layer',
+ endpoint: '/api/map/route_patterns.geojson',
+ params,
+ minZoom: options.minZoom ?? 9,
+ defaultEnabled,
+ style: zoomLineStyle(color, {
+ baseWeight: options.baseWeight ?? (sourceKind === 'gtfs_proposed' ? 2.8 : 3.4),
+ detailWeight: options.detailWeight ?? (sourceKind === 'gtfs_proposed' ? 4.2 : 5.6),
+ opacity: sourceKind === 'gtfs_proposed' ? 0.48 : 0.78,
+ detailOpacity: sourceKind === 'gtfs_proposed' ? 0.64 : 0.92,
+ dashArray: sourceKind === 'gtfs_proposed' ? '7 5' : undefined
+ }),
+ tooltipMinZoom: options.tooltipMinZoom,
+ limit: 7000
+ };
+}
+
+function setLayerGroupsFromSources(sources) {
+ layerGroups = buildLayerGroupsFromSources(sources);
+ initializeLayerState();
+ renderLayerControls();
+}
+
+function buildLayerGroupsFromSources(sources) {
+ const groups = [
+ {
+ id: 'routeLayer',
+ label: 'Route layer',
+ children: [
+ routePatternLayer('routeLayerRailLongDistance', 'Rail: long-distance', 'train', 'osm', '#1d4ed8', true, { routeScope: 'long_distance', minZoom: 5, baseWeight: 3.6, detailWeight: 6.2, tooltipMinZoom: 10 }),
+ routePatternLayer('routeLayerRailRegional', 'Rail: regional', 'train', 'osm', '#2563eb', true, { routeScope: 'regional', minZoom: 7, baseWeight: 3.2, detailWeight: 5.8, tooltipMinZoom: 11 }),
+ routePatternLayer('routeLayerRailLocal', 'Rail: local/S-Bahn', 'train', 'osm', '#0f766e', true, { routeScope: 'local', minZoom: 10, baseWeight: 2.8, detailWeight: 5, tooltipMinZoom: 12 }),
+ routePatternLayer('routeLayerRailUnknown', 'Rail: unknown', 'train', 'osm', '#64748b', false, { routeScope: 'unknown', minZoom: 10, baseWeight: 2.4, detailWeight: 4.4, tooltipMinZoom: 13 }),
+ routePatternLayer('routeLayerBusLongDistance', 'Bus: long-distance', 'bus,coach', 'osm', '#9333ea', false, { routeScope: 'long_distance', minZoom: 7, baseWeight: 2.8, detailWeight: 5, tooltipMinZoom: 11 }),
+ routePatternLayer('routeLayerBusRegional', 'Bus: regional', 'bus,trolleybus', 'osm', '#ea580c', true, { routeScope: 'regional', minZoom: 10, baseWeight: 2.4, detailWeight: 4.8, tooltipMinZoom: 13 }),
+ routePatternLayer('routeLayerBusLocal', 'Bus: local', 'bus,trolleybus', 'osm', '#d97706', true, { routeScope: 'local', minZoom: 12, baseWeight: 2.2, detailWeight: 4.4, tooltipMinZoom: 14 }),
+ routePatternLayer('routeLayerTram', 'Tram/light rail', 'tram,light_rail', 'osm', '#7c3aed', true, { routeScope: 'local', minZoom: 11, baseWeight: 2.6, detailWeight: 5, tooltipMinZoom: 13 }),
+ routePatternLayer('routeLayerSubway', 'Subway', 'subway', 'osm', '#dc2626', true, { routeScope: 'local', minZoom: 10, baseWeight: 3, detailWeight: 5.4, tooltipMinZoom: 12 }),
+ routePatternLayer('routeLayerFerry', 'Ferry', 'ferry', 'osm', '#0891b2', true, { routeScope: 'local', minZoom: 10, baseWeight: 2.4, detailWeight: 4.6, tooltipMinZoom: 13 }),
+ routePatternLayer('routeLayerProposed', 'GTFS proposed', 'train,subway,tram,bus,coach,trolleybus,ferry,light_rail', 'gtfs_proposed', '#111827', false)
+ ]
+ }
+ ];
+ sources.filter(hasActiveGtfsDataset).forEach(source => {
+ const suffix = `Source${source.id}`;
+ groups.push({
+ id: `gtfs${suffix}`,
+ label: `GTFS: ${source.name}`,
+ children: [
+ {
+ id: `gtfsRoutes${suffix}`,
+ label: 'Routes',
+ category: 'gtfs-route',
+ endpoint: '/api/map/gtfs_routes.geojson',
+ params: { source_id: String(source.id) },
+ minZoom: 8,
+ defaultEnabled: true,
+ style: { color: '#18864b', weight: 4, opacity: 0.74 },
+ limit: 5000
+ },
+ {
+ id: `gtfsStops${suffix}`,
+ label: 'Stops',
+ category: 'gtfs-stop',
+ endpoint: '/api/map/gtfs_stops.geojson',
+ params: { source_id: String(source.id) },
+ minZoom: 13,
+ defaultEnabled: false,
+ pointStyle: { radius: 4, weight: 1, color: '#14532d', fillOpacity: 0.82 },
+ limit: 4000
+ }
+ ]
+ });
+ });
+
+ sources.filter(hasActiveOsmDataset).forEach(source => {
+ const suffix = `Source${source.id}`;
+ groups.push({
+ id: `osm${suffix}`,
+ label: `OSM: ${source.name}`,
+ children: [
+ ...osmRouteModes.map(config =>
+ routeLayer(
+ `osm${config.label.replace(/[^A-Za-z0-9]+/g, '')}Routes${suffix}`,
+ config.label,
+ config.mode,
+ config.color,
+ source.id,
+ config.enabled,
+ {
+ routeScope: config.routeScope,
+ minZoom: config.minZoom,
+ baseWeight: config.baseWeight,
+ detailWeight: config.detailWeight,
+ tooltipMinZoom: config.tooltipMinZoom
+ }
+ )
+ ),
+ {
+ id: `osmRailPaths${suffix}`,
+ label: 'Rail/tram paths',
+ category: 'osm-infra',
+ endpoint: '/api/map/osm_features.geojson',
+ params: { source_id: String(source.id), kind: 'infra', mode: 'train,light_rail,subway,tram,monorail,funicular' },
+ minZoom: 13,
+ defaultEnabled: false,
+ style: { color: '#475569', weight: 2, opacity: 0.62 },
+ limit: 8000
+ },
+ {
+ id: `osmFerryPaths${suffix}`,
+ label: 'Ferry paths',
+ category: 'osm-infra',
+ endpoint: '/api/map/osm_features.geojson',
+ params: { source_id: String(source.id), kind: 'infra', mode: 'ferry' },
+ minZoom: 13,
+ defaultEnabled: false,
+ style: { color: '#0e7490', weight: 2, opacity: 0.62, dashArray: '5 5' },
+ limit: 4000
+ },
+ {
+ id: `osmStops${suffix}`,
+ label: 'Stops',
+ category: 'osm-stop',
+ endpoint: '/api/map/osm_features.geojson',
+ params: { source_id: String(source.id), kind: 'stop,station,terminal', geometry: 'point' },
+ minZoom: 14,
+ defaultEnabled: false,
+ pointStyle: { radius: 4, weight: 1, color: '#334155', fillOpacity: 0.62 },
+ limit: 5000
+ },
+ {
+ id: `osmStopWays${suffix}`,
+ label: 'Stop ways',
+ category: 'osm-stop',
+ endpoint: '/api/map/osm_features.geojson',
+ params: { source_id: String(source.id), kind: 'stop,station,terminal', geometry: 'nonpoint' },
+ minZoom: 15,
+ defaultEnabled: false,
+ style: { color: '#111827', weight: 2, opacity: 0.54, fillOpacity: 0.12 },
+ limit: 5000
+ }
+ ]
+ });
+ });
+
+ sources.filter(hasActiveGtfsDataset).forEach(source => {
+ const suffix = `Source${source.id}`;
+ groups.push({
+ id: `review${suffix}`,
+ label: `Match status: ${source.name}`,
+ children: matchStatusLayers.map(([label, status, color]) => {
+ const style = { color, weight: status === 'missing' ? 6 : 5, opacity: 0.88 };
+ if (status === 'missing') style.dashArray = '8 6';
+ return {
+ id: `match${status[0].toUpperCase()}${status.slice(1)}${suffix}`,
+ label,
+ category: 'match-status',
+ status,
+ endpoint: '/api/map/matched_gtfs_routes.geojson',
+ params: { source_id: String(source.id), status },
+ minZoom: 8,
+ defaultEnabled: false,
+ style,
+ limit: 5000
+ };
+ })
+ });
+ });
+ return groups;
+}
+
+function hasActiveGtfsDataset(source) {
+ return (source.datasets || []).some(dataset => dataset.kind === 'gtfs' && dataset.is_active);
+}
+
+function hasActiveOsmDataset(source) {
+ return (source.datasets || []).some(dataset => dataset.kind === 'osm_geojson' && dataset.is_active);
+}
+
+function initMap() {
+ const view = loadSavedMapView();
+ map = L.map('map', { preferCanvas: true }).setView(view.center, view.zoom);
+ map.createPane('searchPane');
+ map.getPane('searchPane').style.zIndex = 450;
+ map.createPane('candidatePane');
+ map.getPane('candidatePane').style.zIndex = 470;
+ map.createPane('journeyPane');
+ map.getPane('journeyPane').style.zIndex = 490;
+ L.tileLayer('https://tile.openstreetmap.org/{z}/{x}/{y}.png', {
+ maxZoom: 19,
+ attribution: '© OpenStreetMap contributors'
+ }).addTo(map);
+ map.on('moveend zoomend', scheduleMapLayerLoad);
+ map.on('moveend zoomend', saveMapViewport);
+ map.on('contextmenu', showJourneyContextMenu);
+ map.getContainer().addEventListener('contextmenu', showJourneyContainerContextMenu, true);
+}
+
+function loadSavedMapView() {
+ try {
+ const saved = JSON.parse(localStorage.getItem(MAP_VIEW_STORAGE_KEY) || 'null');
+ const lat = Number(saved?.center?.[0]);
+ const lon = Number(saved?.center?.[1]);
+ const zoom = Number(saved?.zoom);
+ if (
+ Number.isFinite(lat) &&
+ Number.isFinite(lon) &&
+ Number.isFinite(zoom) &&
+ lat >= -90 &&
+ lat <= 90 &&
+ lon >= -180 &&
+ lon <= 180 &&
+ zoom >= 0 &&
+ zoom <= 22
+ ) {
+ return { center: [lat, lon], zoom };
+ }
+ } catch (_) {}
+ return DEFAULT_MAP_VIEW;
+}
+
+function saveMapViewport() {
+ if (!map) return;
+ const center = map.getCenter();
+ const zoom = map.getZoom();
+ try {
+ localStorage.setItem(MAP_VIEW_STORAGE_KEY, JSON.stringify({
+ center: [Number(center.lat.toFixed(6)), Number(center.lng.toFixed(6))],
+ zoom: Number(zoom)
+ }));
+ } catch (_) {}
+}
+
+async function api(path, options = {}) {
+ const response = await fetch(path, {
+ headers: { 'Content-Type': 'application/json', ...(options.headers || {}) },
+ ...options
+ });
+ if (!response.ok) {
+ let detail = response.statusText;
+ try { detail = (await response.json()).detail || detail; } catch (_) {}
+ if (response.status === 409) updateMapStatus(detail);
+ throw new Error(detail);
+ }
+ return response.json();
+}
+
+function clearLayer(name) {
+ if (layers[name]) {
+ map.removeLayer(layers[name]);
+ delete layers[name];
+ }
+}
+
+function allLayerConfigs() {
+ return layerGroups.flatMap(group => group.children);
+}
+
+/**
+ * Give every known layer config a boolean entry in `layerState`.
+ *
+ * On first use the persisted state is loaded from localStorage key
+ * 'mobilityLayerState'. A layer that has no boolean state yet takes its saved
+ * boolean when one exists, otherwise it is enabled unless its config sets
+ * `defaultEnabled: false`. Existing boolean entries are left untouched.
+ */
+function initializeLayerState() {
+  if (savedLayerState === undefined) {
+    let parsed;
+    try {
+      parsed = JSON.parse(localStorage.getItem('mobilityLayerState') || '{}');
+    } catch (_) {
+      parsed = {};
+    }
+    // A stored literal `null` parses successfully but would make the property
+    // lookups below throw, so anything that is not an object becomes {}.
+    savedLayerState = parsed !== null && typeof parsed === 'object' ? parsed : {};
+  }
+  allLayerConfigs().forEach(config => {
+    if (typeof layerState[config.id] === 'boolean') return;
+    const saved = savedLayerState[config.id];
+    layerState[config.id] = typeof saved === 'boolean' ? saved : config.defaultEnabled !== false;
+  });
+}
+
+function saveLayerState() {
+ localStorage.setItem('mobilityLayerState', JSON.stringify(layerState));
+}
+
+function renderLayerControls() {
+ const container = document.getElementById('layerControls');
+ container.innerHTML = layerGroups.map(group => `
+
+
+
+
${escapeHtml(err.message)}
`; + } +} + +function renderQaSummary(data) { + const container = document.getElementById('qaDashboard'); + if (!container) return; + const sections = data.sections || []; + const decision = data.decision || {}; + container.classList.remove('muted'); + container.innerHTML = ` +${escapeHtml(err.message)}
`; + } +} + +function renderGtfsHarmonizationInventory(data) { + const container = document.getElementById('gtfsHarmonizationInventory'); + if (!container) return; + const summary = data.summary || {}; + const feeds = data.feeds || []; + container.classList.remove('muted'); + if (!feeds.length) { + container.classList.add('muted'); + container.innerHTML = 'No GTFS sources registered yet.'; + return; + } + container.innerHTML = ` +Loading feed QA...
'); + try { + const data = await api(`/api/harmonization/gtfs/sources/${encodeURIComponent(sourceId)}`); + document.getElementById('overlayTitle').textContent = `GTFS QA: ${data.source?.name || `source #${sourceId}`}`; + document.getElementById('overlayContent').innerHTML = renderGtfsHarmonizationDetail(data); + } catch (err) { + document.getElementById('overlayContent').innerHTML = `${escapeHtml(err.message)}
`; + } +} + +function renderGtfsHarmonizationDetail(feed) { + const source = feed.source || {}; + const dataset = feed.active_dataset; + const issues = feed.issues || []; + const review = source.qa_review || {}; + return ` +No validation issue detected for this first-pass QA.
No GTFS datasets.
'} +${escapeHtml(JSON.stringify(event.metadata, null, 2))}`;
+}
+
+function renderJobDetails(data, queueData = {}) {
+ const job = data.job || {};
+ const events = data.events || [];
+ const queueJobs = queueData.jobs || [];
+ const latestEvent = events.length ? events[events.length - 1] : null;
+ const progressMax = Number(job.progress_total || 1);
+ const progressValue = Number(job.progress_current || 0);
+ const resultHtml = job.result && Object.keys(job.result).length
+ ? `${escapeHtml(JSON.stringify(job.result, null, 2))}No phase template for this job kind; use the event log below.
'} +No queue rows returned.
'} +No events yet.
'} +Loading job details...
'); + await loadJobDetails(jobId); +} + +async function loadJobDetails(jobId) { + if (jobDetailsPollTimer) { + window.clearTimeout(jobDetailsPollTimer); + jobDetailsPollTimer = undefined; + } + try { + const [details, queue] = await Promise.all([ + api(`/api/jobs/${encodeURIComponent(jobId)}/events?limit=200`), + api('/api/jobs?limit=20') + ]); + if (activeJobDetailsId !== String(jobId)) return; + document.getElementById('overlayTitle').textContent = `Job #${jobId}`; + document.getElementById('overlayContent').innerHTML = renderJobDetails(details, queue); + if (!details.job?.terminal && !document.getElementById('overlay')?.hidden) { + jobDetailsPollTimer = window.setTimeout(() => loadJobDetails(jobId), JOB_DETAILS_POLL_MS); + } + } catch (err) { + if (activeJobDetailsId === String(jobId)) { + document.getElementById('overlayContent').innerHTML = `${escapeHtml(err.message)}
`; + } + } +} + +function jobKindLabel(job) { + if (job.kind === 'source_import') return 'Source import'; + if (job.kind === 'source_delete') return 'Source delete'; + if (job.kind === 'dataset_delete') return 'Dataset delete'; + if (job.kind === 'maintenance') return job.description || 'Maintenance'; + if (job.kind === 'route_layer_rebuild') return 'Route-layer rebuild'; + if (job.kind === 'route_matching') return 'Route matching'; + if (job.kind === 'osm_relabel') return 'OSM relabeling'; + return job.description || job.kind || 'Job'; +} + +function renderWorkerStatus(workers) { + if (!workers.length) { + return '${escapeHtml(emptyMessage)}
`; + return; + } + container.innerHTML = sources.map(sourceCard).join(''); +} + +function sourceCard(source) { + return ` +Searching datasets...
'; + try { + const data = await api(`/api/datasets/search?${params.toString()}`, { signal: controller.signal }); + if (sequence !== datasetSearchSequence) return; + renderDatasetSearchResults(data); + } catch (err) { + if (err.name === 'AbortError' || sequence !== datasetSearchSequence) return; + results.innerHTML = `${escapeHtml(err.message)}
`; + } finally { + if (datasetSearchAbortController === controller) { + datasetSearchAbortController = undefined; + } + } +} + +function renderDatasetSearchResults(data) { + const results = document.getElementById('datasetSearchResults'); + if (!results) return; + const gtfs = data.gtfs_routes || []; + const osm = data.osm_routes || []; + const patterns = data.route_patterns || []; + if (!gtfs.length && !osm.length && !patterns.length) { + results.innerHTML = 'No dataset entries found.
'; + return; + } + results.innerHTML = ` + ${datasetResultSection('GTFS timetable routes', gtfs, renderGtfsSearchHit)} + ${datasetResultSection('OSM visual routes', osm, renderOsmSearchHit)} + ${datasetResultSection('Extracted route-layer patterns', patterns, renderRoutePatternSearchHit)} + `; +} + +function datasetResultSection(title, rows, renderer) { + if (!rows.length) return ''; + return ` +No catalog entries loaded.
'; + return; + } + container.innerHTML = sourceCatalogEntries.map(sourceCatalogEntryCard).join(''); +} + +function sourceCatalogEntryCard(entry) { + const kind = inferCatalogSourceKind(entry); + const actionLabel = kind && kind.startsWith('osm_') ? 'Use as map source' : 'Use as GTFS source'; + return ` +Loading Geofabrik catalog...
'; + try { + const data = await api(`/api/geofabrik/catalog?${params.toString()}`); + renderGeofabrikResults(data.entries || []); + } catch (err) { + container.innerHTML = `${escapeHtml(err.message)}
`; + } +} + +function renderGeofabrikResults(entries) { + const container = document.getElementById('geofabrikResults'); + if (!container) return; + if (!entries.length) { + container.classList.add('muted'); + container.innerHTML = 'No Geofabrik extracts found.'; + return; + } + container.classList.remove('muted'); + container.innerHTML = entries.map(entry => ` +No matches yet. Run the matcher.
'; + return; + } + container.innerHTML = matches.map(match => ` +Generating travel options...
'; + } + const payload = journeyPlannerPayload(); + const data = await api('/api/itineraries/generate', { + method: 'POST', + body: JSON.stringify(payload) + }); + renderItineraryResults(data.itineraries || []); +} + +async function loadItineraries() { + const data = await api('/api/itineraries?limit=20'); + renderItineraryResults(data.itineraries || []); +} + +function renderItineraryResults(itineraries) { + const container = document.getElementById('itineraryResults'); + if (!container) return; + lastItineraries = itineraries; + if (!itineraries.length) { + container.classList.add('muted'); + container.innerHTML = 'No itinerary options yet.'; + return; + } + container.classList.remove('muted'); + container.innerHTML = itineraries.map(itinerary => ` +Loading candidates...
', { mapReview: true }); + try { + const data = await api(`/api/matches/${matchId}/candidates?limit=30`); + renderCandidateOverlay(data); + } catch (err) { + openOverlay('Matching candidates', `${escapeHtml(err.message)}
`); + } +} + +function renderCandidateOverlay(data) { + const route = data.route || {}; + const rows = data.candidates || []; + const title = `Candidates for ${route.ref || route.route_id || `match ${data.match_id}`}`; + const currentOrFirst = rows.find(candidate => candidate.current_match) || rows[0]; + const content = ` +${escapeHtml(JSON.stringify(candidate.reasons, null, 2))}
+ No OSM route candidates.
'} + `; + openOverlay(title, content, { mapReview: true }); + drawCandidatePreview(data.preview); +} + +async function acceptCandidate(matchId, osmFeatureId, button) { + const originalText = button?.textContent; + if (button) { + button.disabled = true; + button.textContent = 'Saving...'; + } + try { + await api(`/api/matches/${encodeURIComponent(matchId)}/candidates/${encodeURIComponent(osmFeatureId)}/accept`, { method: 'POST' }); + await Promise.all([loadMatches(), loadStats(), loadMapLayers()]); + await showCandidates(matchId); + } catch (err) { + alert(err.message); + } finally { + if (button) { + button.disabled = false; + button.textContent = originalText; + } + } +} + +async function showCanonicalStop(canonicalStopId) { + openOverlay('Stop detail', 'Loading stop detail...
'); + try { + const data = await api(`/api/canonical-stops/${canonicalStopId}`); + renderCanonicalStopOverlay(data); + } catch (err) { + openOverlay('Stop detail', `${escapeHtml(err.message)}
`); + } +} + +function renderCanonicalStopOverlay(data) { + const stop = data.canonical_stop || {}; + const gtfsStops = data.gtfs_stops || []; + const osmFeatures = data.osm_features || []; + const rules = data.rules || []; + const html = ` +No timetable stops linked.
'} +No OSM visual stops linked.
'} +${escapeHtml(JSON.stringify({ selector: rule.selector, action: rule.action }, null, 2))}
+ No stored manual stop decisions found for this stop.
'} +${escapeHtml(err.message)}
`; + }); + const input = document.getElementById('canonicalCandidateQuery'); + if (input) { + input.addEventListener('keydown', event => { + if (event.key !== 'Enter') return; + event.preventDefault(); + loadCanonicalStopCandidates(stop.id).catch(err => alert(err.message)); + }); + } +} + +async function loadCanonicalStopCandidates(canonicalStopId) { + const results = document.getElementById('canonicalCandidateResults'); + if (!results) return; + const query = (document.getElementById('canonicalCandidateQuery')?.value || '').trim(); + const params = new URLSearchParams({ limit: '40' }); + if (query) params.set('q', query); + results.classList.remove('muted'); + results.innerHTML = 'Loading candidates...
'; + const data = await api(`/api/canonical-stops/${canonicalStopId}/gtfs-candidates?${params.toString()}`); + const candidates = data.candidates || []; + if (!candidates.length) { + results.classList.add('muted'); + results.innerHTML = 'No candidate GTFS stops found.'; + return; + } + results.classList.remove('muted'); + results.innerHTML = candidates.map(candidate => ` +Searching datasets...
'; + } else if (results) { + results.classList.add('muted'); + results.innerHTML = 'Search all imported datasets by label, route ID, and route-layer reference.'; + } + datasetSearchTimer = window.setTimeout(() => searchDatasets(), 100); + }); + document.getElementById('datasetSearchActiveOnly').addEventListener('change', () => { + window.clearTimeout(datasetSearchTimer); + searchDatasets(); + }); + document.getElementById('datasetSearchResults').addEventListener('click', event => { + const row = event.target.closest('[data-search-feature-type]'); + if (!row) return; + showDatasetSearchFeature(row.dataset.searchFeatureType, row.dataset.searchFeatureId, row); + }); + document.getElementById('datasetSearchResults').addEventListener('keydown', event => { + if (event.key !== 'Enter' && event.key !== ' ') return; + const row = event.target.closest('[data-search-feature-type]'); + if (!row) return; + event.preventDefault(); + showDatasetSearchFeature(row.dataset.searchFeatureType, row.dataset.searchFeatureId, row); + }); + document.getElementById('sources')?.addEventListener('click', handleSourceAction); + document.getElementById('gtfsHarmonizationInventory')?.addEventListener('click', event => { + const button = event.target.closest('[data-gtfs-feed-detail]'); + if (!button) return; + showGtfsHarmonizationDetail(button.dataset.gtfsFeedDetail); + }); + document.getElementById('mappingSources')?.addEventListener('click', handleSourceAction); + document.getElementById('importSourceCatalogBtn').addEventListener('click', () => importSourceCatalog().catch(err => alert(err.message))); + document.getElementById('importIngestableSourcesBtn').addEventListener('click', () => importIngestableSources().catch(err => alert(err.message))); + document.getElementById('sourceCatalogSearch').addEventListener('input', () => loadSourceCatalog().catch(err => console.warn(err))); + document.getElementById('sourceCatalogCountry').addEventListener('input', () => loadSourceCatalog().catch(err => 
console.warn(err))); + document.getElementById('sourceCatalogPriority').addEventListener('change', () => loadSourceCatalog().catch(err => console.warn(err))); + document.getElementById('sourceCatalog').addEventListener('click', event => { + const button = event.target.closest('[data-fill-source-from-catalog]'); + if (!button) return; + fillSourceFormFromCatalog(button.dataset.fillSourceFromCatalog); + }); + document.getElementById('geofabrikSearchBtn').addEventListener('click', () => searchGeofabrik()); + document.getElementById('geofabrikSearch').addEventListener('keydown', event => { + if (event.key !== 'Enter') return; + event.preventDefault(); + searchGeofabrik(); + }); + document.getElementById('geofabrikResults').addEventListener('click', event => { + const addButton = event.target.closest('[data-geofabrik-add]'); + if (addButton) { + createGeofabrikSource(addButton.dataset.geofabrikAdd, false, addButton); + return; + } + const importButton = event.target.closest('[data-geofabrik-import]'); + if (importButton) { + createGeofabrikSource(importButton.dataset.geofabrikImport, true, importButton); + } + }); + document.getElementById('overlayCloseBtn').addEventListener('click', closeOverlay); + document.getElementById('overlay').addEventListener('submit', event => { + const form = event.target.closest('[data-gtfs-review-form]'); + if (!form) return; + event.preventDefault(); + const button = form.querySelector('button[type="submit"]'); + saveGtfsFeedReview(form.dataset.sourceId, gtfsReviewPayloadFromForm(form), button); + }); + document.getElementById('overlay').addEventListener('click', event => { + const approveButton = event.target.closest('[data-gtfs-review-approve]'); + if (approveButton) { + const form = approveButton.closest('[data-gtfs-review-form]'); + const payload = gtfsReviewPayloadFromForm(form); + payload.review_status = 'approved'; + saveGtfsFeedReview(approveButton.dataset.gtfsReviewApprove, payload, approveButton); + return; + } + const 
addRelatedButton = event.target.closest('[data-gtfs-add-related-source]'); + if (addRelatedButton) { + prepareRelatedGtfsSource(addRelatedButton.dataset.gtfsAddRelatedSource); + return; + } + const previewButton = event.target.closest('[data-preview-candidate]'); + if (previewButton) { + focusCandidatePreview(previewButton.dataset.previewCandidate); + return; + } + const candidateButton = event.target.closest('[data-accept-candidate]'); + if (candidateButton) { + acceptCandidate(candidateButton.dataset.matchId, candidateButton.dataset.acceptCandidate, candidateButton); + return; + } + const canonicalSearchButton = event.target.closest('[data-canonical-candidate-search]'); + if (canonicalSearchButton) { + loadCanonicalStopCandidates(canonicalSearchButton.dataset.canonicalCandidateSearch).catch(err => alert(err.message)); + return; + } + const canonicalLinkButton = event.target.closest('[data-canonical-link-candidate]'); + if (canonicalLinkButton) { + linkCanonicalStopCandidate( + canonicalLinkButton.dataset.canonicalStopTarget, + canonicalLinkButton.dataset.canonicalLinkCandidate, + canonicalLinkButton + ); + return; + } + const canonicalUnlinkButton = event.target.closest('[data-canonical-unlink]'); + if (canonicalUnlinkButton) { + unlinkCanonicalStopLink(canonicalUnlinkButton.dataset.canonicalUnlink, canonicalUnlinkButton); + return; + } + if (event.target.id === 'overlay') closeOverlay(); + }); + document.getElementById('sourceForm')?.addEventListener('submit', submitSourceForm); + document.getElementById('mappingSourceForm')?.addEventListener('submit', submitSourceForm); + document.getElementById('journeyEarlierBtn').addEventListener('click', () => shiftJourneyTime(-15)); + document.getElementById('journeySwapBtn').addEventListener('click', () => swapJourneyEndpoints()); + document.getElementById('journeyLaterBtn').addEventListener('click', () => shiftJourneyTime(15)); + document.getElementById('generateItinerariesBtn').addEventListener('click', async () => { + 
try { + await generateItinerariesFromForm(); + } catch (err) { + const container = document.getElementById('itineraryResults'); + if (container) container.innerHTML = `${escapeHtml(err.message)}
`; + } + }); + document.getElementById('reloadItinerariesBtn').addEventListener('click', () => loadItineraries().catch(err => alert(err.message))); + document.getElementById('itineraryResults').addEventListener('click', event => { + const saveButton = event.target.closest('[data-itinerary-save]'); + if (saveButton) { + saveItinerary(saveButton.dataset.itinerarySave, saveButton.dataset.itinerarySaved === 'true', saveButton); + return; + } + const legButton = event.target.closest('[data-itinerary-leg-lock]'); + if (legButton) { + lockItineraryLeg(legButton.dataset.itineraryLegLock, legButton.dataset.itineraryLegLocked === 'true', legButton); + return; + } + const showButton = event.target.closest('[data-itinerary-show]'); + if (showButton) { + showItinerary(showButton.dataset.itineraryShow); + } + }); + document.getElementById('journeyFromQuery').addEventListener('input', () => scheduleJourneyStopSearch('from')); + document.getElementById('journeyToQuery').addEventListener('input', () => scheduleJourneyStopSearch('to')); + document.getElementById('journeyViaQuery').addEventListener('input', () => scheduleJourneyStopSearch('via')); + document.querySelectorAll('input[name="journeyMode"]').forEach(input => { + input.addEventListener('change', () => { + stopActiveJourneySearch().catch(err => console.warn(err)); + updateJourneyModeControls(); + }); + }); + document.getElementById('journeyDirectOnly').addEventListener('change', () => stopActiveJourneySearch().catch(err => console.warn(err))); + document.getElementById('journeyRanking').addEventListener('change', () => stopActiveJourneySearch().catch(err => console.warn(err))); + updateJourneyModeControls(); + ['journeyFromSuggestions', 'journeyToSuggestions', 'journeyViaSuggestions'].forEach(id => { + document.getElementById(id).addEventListener('click', event => { + const button = event.target.closest('[data-stop-id]'); + if (!button) return; + selectJourneyStop(button.dataset.stopRole, button.dataset.stopId, 
button.dataset.stopLabel); + }); + }); + document.getElementById('journeyForm').addEventListener('submit', async (event) => { + event.preventDefault(); + try { + await searchJourney(); + } catch (err) { + renderJourneyMessage(err.message); + } + }); +} + +window.addEventListener('load', async () => { /* Page boot: initMap() and setupEvents() run first, then the initial refreshAll() is awaited. */ + initMap(); + setupEvents(); + try { + await refreshAll(); + } catch (err) { + console.error(err); + alert(`Startup refresh failed: ${err.message}`); + } finally { /* startJobListRefresh() is called whether or not refreshAll() succeeded. */ + startJobListRefresh(); + } +}); diff --git a/app/static/style.css b/app/static/style.css new file mode 100644 index 0000000..9eb4e6a --- /dev/null +++ b/app/static/style.css @@ -0,0 +1,1498 @@ +:root { /* Base font stack, text colour and page background. */ + font-family: Inter, system-ui, -apple-system, BlinkMacSystemFont, "Segoe UI", sans-serif; + color: #16202a; + background: #f4f6f8; +} +* { box-sizing: border-box; } +body { margin: 0; } +header { + display: flex; + justify-content: space-between; + gap: 16px; + align-items: center; + padding: 14px 18px; + background: #101820; + color: #fff; +} +h1 { margin: 0; font-size: 20px; } +header p { margin: 4px 0 0; color: #c5d0da; font-size: 13px; } +.actions { display: flex; gap: 8px; flex-wrap: wrap; } +.workflow-actions, +.maintenance-grid { + display: grid; + grid-template-columns: repeat(2, minmax(0, 1fr)); + gap: 8px; +} +.admin-status { + margin-top: 8px; + font-size: 12px; + line-height: 1.35; + overflow-wrap: anywhere; +} +button { + cursor: pointer; + border: 1px solid #9da9b4; + background: #fff; + color: #16202a; + border-radius: 6px; + padding: 7px 10px; + font-weight: 600; + min-width: 0; + overflow-wrap: anywhere; +} +button:hover { background: #edf2f7; } +button.danger { border-color: #c14141; color: #a32222; } +button.primary { border-color: #2f7d4f; color: #21643d; } +main { /* Two-column grid: a fixed 420px first column plus a flexible second column. NOTE(review): the 70px subtracted from the viewport height is presumably the header height - confirm. */ + display: grid; + grid-template-columns: 420px 1fr; + height: calc(100vh - 70px); + transition: grid-template-columns .18s ease; +} +aside { + position: relative; + overflow: hidden; + padding: 0; + border-right: 1px solid 
#d4dce3; + min-width: 0; + background: #f4f6f8; +} +.sidebar-content { /* Scrolls within its own box. NOTE(review): the 54px bottom padding presumably leaves room for the collapse handle - confirm. */ + height: 100%; + overflow: auto; + padding: 12px 12px 54px; + min-width: 0; + transition: opacity .12s ease, visibility .12s ease; +} +.sidebar-collapse-handle { /* Round 28px button, absolutely positioned at the bottom-right of its nearest positioned ancestor. */ + position: absolute; + bottom: 10px; + right: 8px; + z-index: 40; + display: grid; + place-items: center; + width: 28px; + height: 28px; + min-width: 28px; + padding: 0; + border-radius: 999px; + border-color: #b9c4ce; + background: #fff; + box-shadow: 0 2px 7px rgba(15, 23, 42, .16); + font-size: 20px; + line-height: 1; +} +.sidebar-collapse-handle:hover { + background: #edf2f7; +} +main.sidebar-collapsed { /* Collapsed state: the first grid column shrinks to a 38px strip. */ + grid-template-columns: 38px 1fr; +} +main.sidebar-collapsed aside { + overflow: hidden; + padding: 0; +} +main.sidebar-collapsed .sidebar-content { /* Hidden and non-interactive; opacity and visibility use the transition declared on .sidebar-content. */ + opacity: 0; + visibility: hidden; + pointer-events: none; +} +main.sidebar-collapsed .sidebar-collapse-handle { + left: 5px; + right: auto; + bottom: 12px; +} +.map-panel { position: relative; min-width: 0; } +#map { width: 100%; height: 100%; background: #e4e9ee; } +.route-line-label { + border: 0; + background: rgba(255, 255, 255, .86); + color: #111827; + box-shadow: 0 1px 2px rgba(15, 23, 42, .16); + font-size: 11px; + font-weight: 700; + line-height: 1; + padding: 2px 4px; +} +.route-line-label::before { display: none; /* NOTE(review): presumably hides the map library's tooltip arrow - confirm. */ } +.card { + background: #fff; + border: 1px solid #d7dee6; + border-radius: 8px; + padding: 12px; + margin-bottom: 12px; + box-shadow: 0 1px 3px rgba(16, 24, 32, .05); +} +.card h2 { font-size: 15px; margin: 0; } +.sidebar-section { + padding: 0; + min-width: 0; + max-width: 100%; +} +.sidebar-section > summary, +.nested-section > summary { /* list-style: none removes the default disclosure marker; a custom chevron is drawn by the ::after rule further down. */ + cursor: pointer; + list-style: none; + display: flex; + justify-content: space-between; + gap: 10px; + align-items: center; +} +.sidebar-section > summary { + min-height: 42px; + padding: 11px 12px; +} +.nested-section > summary { + min-height: 34px; + padding: 8px 0; +} +.sidebar-section > summary::-webkit-details-marker, 
+.nested-section > summary::-webkit-details-marker { display: none; } +.sidebar-section > summary::after, +.nested-section > summary::after { /* Chevron: a 7px box showing only its right and bottom borders, rotated -45deg while closed. */ + content: ""; + width: 7px; + height: 7px; + border: solid #607080; + border-width: 0 2px 2px 0; + transform: rotate(-45deg); + transition: transform .15s ease; +} +.sidebar-section[open] > summary::after, +.nested-section[open] > summary::after { /* Rotated to 45deg while the section is open. */ + transform: rotate(45deg); +} +.sidebar-section > summary h2, +.nested-section > summary h3 { + margin: 0; +} +.nested-section > summary h3 { + color: #273646; + font-size: 13px; +} +.sidebar-section-body { + padding: 0 12px 12px; + min-width: 0; + max-width: 100%; +} +.sidebar-section[open] > summary { + border-bottom: 1px solid #edf1f5; + margin-bottom: 10px; +} +.nested-section { + border-top: 1px solid #edf1f5; + min-width: 0; + max-width: 100%; +} +.nested-section:first-child { + border-top: none; +} +.nested-section-body { + display: grid; + gap: 8px; + padding-bottom: 10px; + min-width: 0; +} +.nested-section[open] > summary { + margin-bottom: 6px; +} +form { display: grid; gap: 8px; } +label { display: grid; gap: 4px; font-size: 12px; color: #52606d; } +input, select, textarea { + border: 1px solid #c6d0d9; + border-radius: 6px; + padding: 7px; + font: inherit; + color: #16202a; + min-width: 0; +} +textarea { + resize: vertical; + min-height: 52px; +} +input.journey-selected-location { /* Left padding reserves space for the 17px background icon placed 8px from the left edge. */ + padding-left: 31px; + background-repeat: no-repeat; + background-position: 8px center; + background-size: 17px 17px; +} +input.journey-selected-address { /* Inline SVG data URI: a single path filled with #2d5f7f. */ + background-image: url("data:image/svg+xml,%3Csvg width='17' height='17' viewBox='0 0 17 17' xmlns='http://www.w3.org/2000/svg'%3E%3Cpath fill='%232d5f7f' d='M8.5 1.4A5.1 5.1 0 0 0 3.4 6.5c0 3.7 5.1 9.1 5.1 9.1s5.1-5.4 5.1-9.1A5.1 5.1 0 0 0 8.5 1.4zm0 7.1a2 2 0 1 1 0-4 2 2 0 0 1 0 4z'/%3E%3C/svg%3E"); +} +input.journey-selected-stop { /* Inline SVG data URI: a white circle outlined in #2d5f7f containing the letter "H". */ + background-image: url("data:image/svg+xml,%3Csvg width='17' height='17' viewBox='0 0 17 17' 
xmlns='http://www.w3.org/2000/svg'%3E%3Ccircle cx='8.5' cy='8.5' r='6.7' fill='white' stroke='%232d5f7f' stroke-width='1.7'/%3E%3Ctext x='8.5' y='11.5' text-anchor='middle' font-family='Arial,sans-serif' font-size='8.4' font-weight='800' fill='%232d5f7f'%3EH%3C/text%3E%3C/svg%3E"); +} +.stats { display: grid; grid-template-columns: repeat(2, 1fr); gap: 8px; } +.stat { /* Tile with a 4px left accent; the .info/.good/.warn/.bad modifiers below set the accent colour. */ + border: 1px solid #e1e7ee; + border-left-width: 4px; + border-radius: 8px; + padding: 8px; + background: #f8fafc; +} +.stat.info { border-left-color: #64748b; } +.stat.good { border-left-color: #23864f; background: #f5fbf7; } +.stat.warn { border-left-color: #c47a12; background: #fffaf0; } +.stat.bad { border-left-color: #c24141; background: #fff7f7; } +.stat strong { display: block; font-size: 18px; } +.stat span { font-size: 11px; color: #5a6875; } +.qa-dashboard { + display: grid; + gap: 10px; + font-size: 12px; +} +.qa-toolbar { + display: flex; + justify-content: flex-end; + margin-bottom: 8px; +} +.qa-decision { + display: grid; + gap: 3px; + border: 1px solid #d9e2ec; + border-radius: 8px; + background: #f8fafc; + padding: 8px; +} +.qa-decision strong { + color: #17212b; +} +.qa-decision span { + color: #607080; +} +.qa-section { + display: grid; + gap: 6px; +} +.qa-section h3 { + margin: 0; + font-size: 12px; + color: #273646; +} +.qa-grid { + display: grid; + grid-template-columns: repeat(2, minmax(0, 1fr)); + gap: 6px; +} +.qa-item { /* Same accent-tile treatment as .stat, with the same four colour modifiers. */ + min-width: 0; + border: 1px solid #e1e7ee; + border-left-width: 4px; + border-radius: 8px; + background: #f8fafc; + padding: 7px; +} +.qa-item.info { border-left-color: #64748b; } +.qa-item.good { border-left-color: #23864f; background: #f5fbf7; } +.qa-item.warn { border-left-color: #c47a12; background: #fffaf0; } +.qa-item.bad { border-left-color: #c24141; background: #fff7f7; } +.qa-item strong, +.qa-item span { + display: block; + min-width: 0; + overflow-wrap: anywhere; +} +.qa-item strong { + font-size: 15px; + color: #17212b; +} +.qa-item span { + 
font-size: 11px; + color: #5a6875; +} +.qa-actions { + margin: 0; + padding-left: 18px; + color: #52606d; +} +.harmonization-inventory, +.harmonization-feed-list, +.harmonization-detail, +.harmonization-review-list, +.harmonization-dataset-list { + display: grid; + gap: 8px; + min-width: 0; +} +.harmonization-summary { /* Five equal columns; minmax(0, 1fr) lets a column shrink below its content width. */ + display: grid; + grid-template-columns: repeat(5, minmax(0, 1fr)); + gap: 5px; +} +.harmonization-summary div { + border: 1px solid #e2e8f0; + border-radius: 6px; + background: #f8fafc; + padding: 6px; + min-width: 0; +} +.harmonization-summary strong, +.harmonization-summary span { + display: block; + min-width: 0; + overflow-wrap: anywhere; +} +.harmonization-summary strong { + font-size: 13px; + color: #17212b; +} +.harmonization-summary span { + font-size: 10px; + color: #5a6875; +} +.harmonization-feed { + display: grid; + gap: 6px; + border-top: 1px solid #e1e7ee; + padding-top: 8px; + font-size: 12px; + min-width: 0; +} +.harmonization-feed:first-child { + border-top: none; +} +.harmonization-feed-title { + display: flex; + justify-content: space-between; + align-items: start; + gap: 8px; + min-width: 0; +} +.harmonization-feed-title > * { + min-width: 0; + overflow-wrap: anywhere; +} +.harmonization-issues { + display: flex; + flex-wrap: wrap; + gap: 4px; +} +.harmonization-inventory { /* Capped at 360px tall with vertical scrolling only. */ + max-height: 360px; + overflow-x: hidden; + overflow-y: auto; + padding-right: 4px; +} +.harmonization-detail { + align-content: start; + padding-right: 4px; +} +.harmonization-review-form { + display: grid; + grid-template-columns: minmax(0, 1fr) minmax(140px, .45fr); + gap: 8px; +} +.harmonization-review-form label:has(textarea), +.harmonization-review-form .source-actions, +.harmonization-review-form > .muted { /* These children span both form columns. */ + grid-column: 1 / -1; +} +.harmonization-review-item { + display: grid; + gap: 2px; + border-left: 4px solid #64748b; + border-radius: 6px; + background: #f8fafc; + padding: 8px; + min-width: 0; +} +.harmonization-review-item.error { + 
border-left-color: #c24141; + background: #fff7f7; +} +.harmonization-review-item.probable { + border-left-color: #c47a12; + background: #fffaf0; +} +.harmonization-review-item.ok { + border-left-color: #23864f; + background: #f5fbf7; +} +.harmonization-review-item strong, +.harmonization-review-item span { + min-width: 0; + overflow-wrap: anywhere; +} +.harmonization-review-item span { + color: #52606d; +} +.source, .match, .catalog-entry { /* List rows separated by a top rule; the :first-child rule below removes it on the first row. */ + border-top: 1px solid #e1e7ee; + padding: 8px 0; + font-size: 12px; + min-width: 0; + overflow-wrap: anywhere; +} +.source:first-child, .match:first-child, .catalog-entry:first-child { border-top: none; } +.source-title, .match-title { font-weight: 700; color: #17212b; } +.match-title { + display: flex; + flex-wrap: wrap; + gap: 4px 6px; + align-items: center; +} +.source, +.catalog-entry { + display: grid; + gap: 7px; +} +.source-title, +.catalog-title { + display: flex; + justify-content: space-between; + gap: 8px; + align-items: start; + font-weight: 700; + min-width: 0; +} +.catalog-title { + flex-wrap: wrap; +} +.source-title > *, +.catalog-title > *, +.dataset-title > *, +.dataset-result-title > *, +.job-title > *, +.job-progress > *, +.worker-row > *, +.layer-row > span { /* Lets these children shrink and break long unbroken strings instead of overflowing their row. */ + min-width: 0; + overflow-wrap: anywhere; +} +.source-actions, +.dataset-actions, +.candidate-actions, +.source-datasets { + display: flex; + gap: 6px; + flex-wrap: wrap; +} +.source-meta, +.source-update-row, +.source-job-row, +.source-warning, +.dataset-row { + display: grid; + gap: 3px; +} +.source-warning { + border: 1px solid #f0c36a; + border-radius: 6px; + background: #fff8e1; + color: #6f4c00; + padding: 7px 8px; +} +.dataset-row { + width: 100%; + border: 1px solid #e2e8f0; + border-radius: 6px; + background: #f8fafc; + padding: 7px; + min-width: 0; + overflow-wrap: anywhere; +} +.dataset-title { + display: flex; + flex-wrap: wrap; + justify-content: space-between; + gap: 8px; +} +.metric-row { + display: flex; + gap: 6px; + flex-wrap: wrap; 
+} +.metric { /* Single-line pill; text that does not fit is clipped with an ellipsis. */ + border: 1px solid #dde5ee; + border-radius: 999px; + padding: 2px 7px; + background: #fff; + color: #3d4b58; + font-size: 11px; + max-width: 100%; + min-width: 0; + overflow: hidden; + text-overflow: ellipsis; + white-space: nowrap; +} +.muted { color: #687683; } +.badge { /* Neutral pill; the modifier classes below override background and text colour. */ + display: inline-block; + border-radius: 999px; + padding: 2px 7px; + background: #e6edf5; + color: #273646; + font-size: 11px; + margin-left: 4px; +} +.badge.ok, .badge.matched, .badge.accepted { background: #dff2e7; color: #145f35; } +.badge.error, .badge.rejected, .badge.missing { background: #fde5e5; color: #9b1c1c; } +.badge.probable { background: #fff2cc; color: #7d5700; } +.badge.weak { background: #ffe8d6; color: #8a4300; } +.badge.queued { background: #e6edf5; color: #273646; } +.badge.running { background: #dbeafe; color: #1d4ed8; } +.badge.paused { background: #ede9fe; color: #5b21b6; } +.badge.completed { background: #dff2e7; color: #145f35; } +.badge.failed, .badge.cancelled { background: #fde5e5; color: #9b1c1c; } +.match-actions { + display: flex; + flex-wrap: wrap; + gap: 6px; + margin-top: 6px; +} +.filter-row { + display: flex; + gap: 8px; + margin-bottom: 8px; + min-width: 0; + max-width: 100%; +} +.filter-row select, +.filter-row input { flex: 1; } +.source-catalog-card { + min-width: 0; + max-width: 100%; + overflow: hidden; +} +.source-catalog-card > .nested-section-body { + padding-right: 4px; + min-width: 0; + overflow-x: hidden; +} +.dataset-search-form { + display: grid; + gap: 8px; +} +.inline-check { + display: flex; + align-items: center; + gap: 6px; + color: #52606d; +} +.inline-check input { width: auto; } +.dataset-search-results { + margin-top: 8px; + display: grid; + gap: 8px; + font-size: 12px; + min-width: 0; + max-height: 300px; + overflow-x: hidden; + overflow-y: auto; + padding-right: 4px; +} +.dataset-result-section { + display: grid; + gap: 6px; +} +.dataset-result-section h3 { + margin: 4px 0 0; + font-size: 12px; + color: #273646; +} 
+.dataset-result-row { /* Bordered result row; the state classes below (.clickable, .selected, .no-geometry, .loading) restyle it. */ + display: grid; + gap: 4px; + border: 1px solid #e2e8f0; + border-radius: 6px; + background: #f8fafc; + padding: 7px; +} +.dataset-result-row.clickable { + cursor: pointer; +} +.dataset-result-row.clickable:hover, +.dataset-result-row.clickable:focus { /* The default focus outline is removed; hover and focus are indicated by the border and background change instead. */ + border-color: #74a99b; + background: #eef8f5; + outline: none; +} +.dataset-result-row.selected { + border-color: #0f766e; + background: #e6f5f1; + box-shadow: 0 0 0 1px rgba(15, 118, 110, .18); +} +.dataset-result-row.no-geometry { + opacity: .68; +} +.dataset-result-row.loading { + border-color: #0f766e; +} +.dataset-result-title { + display: flex; + flex-wrap: wrap; + justify-content: space-between; + gap: 8px; +} +.geometry-badge { + display: inline-flex; + align-items: center; + gap: 4px; + border-radius: 999px; + padding: 2px 7px; + font-size: 11px; + background: #e8eef5; + color: #314151; +} +.geometry-badge.ok { + background: #dff2e7; + color: #145f35; +} +.geometry-badge.missing { + background: #eef1f4; + color: #687683; +} +.geometry-dot { /* currentColor makes the dot follow the inherited text colour. */ + width: 6px; + height: 6px; + border-radius: 50%; + background: currentColor; +} +.source-catalog-actions { + display: grid; + grid-template-columns: 1fr 1fr; + gap: 8px; + margin-bottom: 8px; + min-width: 0; +} +.matches-card { + overflow: visible; +} +.matches-card > .nested-section-body { + overflow-x: hidden; + padding-right: 4px; +} +#matches { + min-width: 0; + max-height: 340px; + overflow-x: hidden; + overflow-y: auto; + padding-right: 4px; +} +#matches .muted, +#sourceCatalog, +#geofabrikResults, +#mappingSources { + min-width: 0; + overflow-x: hidden; + overflow-wrap: anywhere; +} +#sourceCatalog .catalog-entry, +#geofabrikResults .catalog-entry { + width: 100%; + max-width: 100%; + overflow-x: hidden; +} +.source-catalog-filter { + display: grid; + grid-template-columns: minmax(0, 1.4fr) minmax(0, .7fr) minmax(0, .9fr); + gap: 8px; + min-width: 0; +} +.source-catalog-filter > * { + width: 100%; +} +.geofabrik-filter { + display: grid; 
+ grid-template-columns: minmax(0, 1fr) auto; + align-items: stretch; + gap: 8px; + margin-bottom: 0; +} +.geofabrik-filter input { + width: 100%; +} +.geofabrik-filter button { + width: auto; + min-width: 74px; + white-space: nowrap; +} +.catalog-entry .muted, +.catalog-entry .metric-row, +.catalog-entry .source-actions { + min-width: 0; + width: 100%; + max-width: 100%; + overflow-wrap: anywhere; +} +#sourceCatalog, +#geofabrikResults, +#sources, +#mappingSources { /* Scrollable lists capped at 320px tall. */ + width: 100%; + max-height: 320px; + min-width: 0; + max-width: 100%; + overflow-y: auto; + padding-right: 4px; +} +#geofabrikResults.dataset-search-results { + gap: 10px; + align-content: start; + min-height: 96px; + margin-top: 0; +} +#geofabrikResults .catalog-entry { + display: grid; + gap: 7px; + min-height: 96px; + padding: 10px; + border: 1px solid #e2e8f0; + border-radius: 6px; + background: #f8fafc; +} +#geofabrikResults .catalog-entry:first-child { /* Restates the top border for the first card; the generic .catalog-entry:first-child rule elsewhere in this stylesheet sets border-top to none. */ + border-top: 1px solid #e2e8f0; +} +.preset-row { + display: grid; + grid-template-columns: 1fr 1fr; + gap: 6px; + margin-bottom: 10px; +} +.layer-controls { + display: grid; + gap: 8px; + font-size: 12px; + max-height: 360px; + overflow-x: hidden; + overflow-y: auto; + padding-right: 4px; +} +.layer-group { + border: 1px solid #e1e7ee; + border-radius: 8px; + background: #f8fafc; +} +.layer-group summary { + cursor: pointer; + list-style: none; + padding: 8px; + font-weight: 700; +} +.layer-group summary::-webkit-details-marker { display: none; } +.layer-group label, +.layer-row { + display: flex; + align-items: center; + gap: 7px; + color: #26323e; +} +.layer-group input { + width: 15px; + height: 15px; + margin: 0; +} +.layer-children { + display: grid; + gap: 2px; + padding: 0 8px 8px 24px; +} +.layer-row { + min-height: 24px; + justify-content: space-between; + border-radius: 6px; + padding: 1px 3px; +} +.layer-row.loading { background: #edf6fb; } +.layer-row span:nth-child(2) { + flex: 1; +} +.layer-count { + min-width: 42px; + text-align: 
right; + color: #6b7885; + font-variant-numeric: tabular-nums; +} +.map-status { + margin-top: 8px; + min-height: 16px; + font-size: 12px; +} +.jobs { /* Job list capped at 280px tall with vertical scrolling. */ + display: grid; + gap: 8px; + font-size: 12px; + min-width: 0; + max-height: 280px; + overflow-x: hidden; + overflow-y: auto; + padding-right: 4px; +} +.jobs-toolbar { + display: flex; + flex-wrap: wrap; + justify-content: flex-end; + gap: 6px; + min-width: 0; +} +.worker-list, +.worker-row { + display: grid; + gap: 4px; +} +.worker-row { + border: 1px solid #e1e7ee; + border-radius: 6px; + background: #f8fafc; + padding: 7px; +} +.job-row { + border-top: 1px solid #e1e7ee; + padding-top: 8px; + min-width: 0; + overflow-wrap: anywhere; +} +.job-row:first-child { + border-top: none; + padding-top: 0; +} +.job-title, +.job-progress, +.job-actions { + display: flex; + justify-content: space-between; + gap: 8px; + align-items: center; + min-width: 0; + flex-wrap: wrap; +} +.job-actions { + justify-content: flex-start; + flex-wrap: wrap; + margin-top: 6px; +} +.job-title { + font-weight: 700; +} +.job-progress progress { + width: min(120px, 100%); + max-width: 100%; + height: 9px; +} +.job-detail { + display: grid; + gap: 14px; + font-size: 12px; +} +.job-detail section { + display: grid; + gap: 8px; +} +.job-detail h3 { + margin: 0; + font-size: 13px; +} +.job-detail pre, +.job-event-row pre { /* Preformatted output: long lines wrap, and the box scrolls once it is taller than 220px. */ + max-height: 220px; + overflow: auto; + margin: 6px 0 0; + border: 1px solid #e2e8f0; + border-radius: 6px; + background: #f8fafc; + padding: 8px; + font-size: 11px; + white-space: pre-wrap; +} +.job-detail-summary, +.job-detail-progress, +.job-event-title { + display: flex; + justify-content: space-between; + gap: 10px; + align-items: center; + min-width: 0; + flex-wrap: wrap; +} +.job-detail-progress progress { + flex: 1 1 180px; + height: 10px; +} +.job-current-event { + border: 1px solid #cbd5e1; + border-radius: 6px; + background: #f8fafc; + padding: 8px; +} +.job-step-list, +.job-event-list, +.job-queue-snapshot { + 
display: grid; + gap: 8px; +} +.job-step, +.job-event-row, +.job-queue-item { + display: grid; + grid-template-columns: auto minmax(0, 1fr); + gap: 9px; + align-items: start; + min-width: 0; +} +.job-step, +.job-event-row { + border: 1px solid #e2e8f0; + border-radius: 6px; + background: #f8fafc; + padding: 8px; +} +.job-step-index { /* 24px circular index marker; recoloured by the .done and .current step states below. */ + display: inline-grid; + place-items: center; + width: 24px; + height: 24px; + border-radius: 999px; + border: 1px solid #cbd5e1; + background: #fff; + color: #475569; + font-weight: 700; + font-size: 11px; +} +.job-step.done .job-step-index { + border-color: #16a34a; + background: #ecfdf5; + color: #166534; +} +.job-step.current { + border-color: #93c5fd; + background: #eff6ff; +} +.job-step.current .job-step-index { + border-color: #2563eb; + background: #dbeafe; + color: #1d4ed8; +} +.job-step.failed, +.job-step.cancelled { + border-color: #fecaca; + background: #fff7f7; +} +.job-step.pending { + opacity: .72; +} +.job-queue-item { /* Overrides the shared two-column template above with a third auto-sized column. */ + grid-template-columns: auto minmax(0, 1fr) auto; + align-items: center; + border-top: 1px solid #e2e8f0; + padding-top: 6px; +} +.job-queue-item:first-child { + border-top: none; + padding-top: 0; +} +.job-queue-item.selected { + color: #0f172a; + font-weight: 700; +} +.spinner { /* Ring with one contrasting edge, rotated continuously by the spin keyframes below. */ + display: inline-block; + width: 14px; + height: 14px; + border: 2px solid #c8d4df; + border-top-color: #2563eb; + border-radius: 50%; + animation: spin .8s linear infinite; + vertical-align: -2px; +} +.spinner-small { + width: 12px; + height: 12px; + border-width: 2px; +} +@keyframes spin { + to { transform: rotate(360deg); } +} +.journey-options { + display: grid; + grid-template-columns: 1fr 1fr; + gap: 8px; +} +.journey-snapshot { + display: grid; + gap: 2px; + border: 1px solid #d8e0e8; + border-radius: 6px; + background: #f8fafc; + padding: 7px 8px; + font-size: 12px; + min-width: 0; +} +.journey-snapshot strong, +.journey-snapshot span { + min-width: 0; + overflow-wrap: anywhere; +} +.journey-snapshot strong { + color: 
#17212b; +} +.journey-mode { + display: grid; + grid-template-columns: repeat(3, 1fr); + gap: 6px; + margin: 8px 0; +} +.journey-mode label, +.journey-direct { + display: flex; + align-items: center; + gap: 6px; + font-size: 12px; +} +.journey-mode label { + justify-content: center; + min-height: 30px; + border: 1px solid #d8e0e8; + border-radius: 6px; + background: #fff; + padding: 4px 6px; +} +.journey-direct { + margin: 8px 0; +} +.journey-message { + display: flex; + align-items: center; + gap: 6px; +} +.journey-actions { + display: grid; + grid-template-columns: 1fr 1fr 1fr; + gap: 8px; +} +.journey-swap { + margin: -2px 0 4px; + justify-self: start; +} +.stop-suggestions { + display: grid; + gap: 4px; + margin-top: -4px; +} +.stop-suggestion { /* Full-width, left-aligned row: a 22px icon column followed by a shrinkable text column. */ + display: grid; + grid-template-columns: 22px minmax(0, 1fr); + align-items: center; + gap: 6px; + width: 100%; + text-align: left; + border: 1px solid #d8e0e8; + border-radius: 6px; + background: #fff; + padding: 6px 8px; + font-size: 12px; + font-weight: 500; +} +.stop-suggestion-text { + display: grid; + min-width: 0; + gap: 2px; +} +.stop-suggestion-text strong { + overflow-wrap: anywhere; +} +.stop-suggestion-icon { + display: inline-grid; + place-items: center; + width: 18px; + height: 18px; + color: #2d5f7f; + font-size: 14px; + line-height: 1; +} +.stop-place-icon { + border: 1.5px solid #2d5f7f; + border-radius: 50%; + font-size: 11px; + font-weight: 800; +} +.stop-suggestion:hover { + background: #eef4f8; +} +.stop-suggestion-text span { + color: #667482; + font-size: 11px; + overflow-wrap: anywhere; +} +.journey-results { + margin-top: 8px; + font-size: 12px; + max-height: 260px; + overflow: auto; +} +.itinerary-panel { + margin-top: 10px; + border-top: 1px solid #dbe3eb; + padding-top: 9px; +} +.itinerary-results { + display: grid; + gap: 8px; + margin-top: 8px; + font-size: 12px; + max-height: 280px; + overflow: auto; +} +.itinerary { + border: 1px solid #dbe3eb; + border-radius: 8px; + padding: 8px; + 
background: #fff; +} +.itinerary.saved { + border-color: #86efac; + background: #f0fdf4; +} +.itinerary-leg { + display: flex; + justify-content: space-between; + gap: 8px; + align-items: center; + margin-top: 5px; + border-top: 1px solid #edf2f7; + padding-top: 5px; +} +.itinerary-leg span { + min-width: 0; + overflow-wrap: anywhere; +} +.journey { + border-top: 1px solid #e1e7ee; + padding: 8px 0; +} +.journey:first-child { border-top: none; } +.journey-title { + display: flex; + justify-content: space-between; + gap: 8px; + align-items: center; + font-weight: 700; +} +.journey-leg { + margin-top: 4px; + color: #2f3b46; + display: flex; + align-items: center; + flex-wrap: wrap; + gap: 4px; +} +.journey-leg strong { + color: #17212b; +} +.mode-icon { /* Fixed 20px square that neither grows nor shrinks in a flex row; the .mode-* classes below set its colours. */ + display: inline-grid; + place-items: center; + width: 20px; + height: 20px; + border-radius: 4px; + background: #e5e7eb; + color: #17212b; + font-size: 10px; + font-weight: 800; + line-height: 1; + vertical-align: middle; + flex: 0 0 auto; +} +.mode-train { background: #ede9fe; color: #7c3aed; } +.mode-light_rail, +.mode-tram { background: #fee2e2; color: #dc2626; } +.mode-subway { background: #fee2e2; color: #ef4444; } +.mode-bus, +.mode-trolleybus { background: #fef3c7; color: #ca8a04; } +.mode-coach { background: #fef3c7; color: #a16207; } +.mode-ferry { background: #dbeafe; color: #0284c7; } +.mode-walk { background: #dcfce7; color: #16a34a; } +.mode-drive, +.mode-car { background: #ffedd5; color: #f97316; } +.mode-monorail, +.mode-funicular, +.mode-aerialway { background: #ede9fe; color: #7c3aed; } +.inline-link { /* Resets padding, border and background and inherits the font, so the element renders as underlined inline text. */ + display: inline; + width: auto; + padding: 0; + border: none; + background: transparent; + color: #1d4ed8; + font: inherit; + text-align: left; + text-decoration: underline; + text-underline-offset: 2px; +} +.inline-link:hover { + color: #0f766e; +} +.map-floating { /* Floating panel in the top-right corner; its z-index of 600 is above .legend (500) and below .map-loading (650). */ + position: absolute; + top: 12px; + right: 12px; + z-index: 600; + width: min(390px, calc(100% - 24px)); + max-height: calc(100% - 
28px); + overflow: auto; + background: rgba(255,255,255,.96); + border: 1px solid #cfd8e2; + border-radius: 8px; + padding: 11px; + box-shadow: 0 8px 24px rgba(16, 24, 32, .16); +} +.map-floating h2 { + font-size: 15px; + margin: 0 0 10px; +} +.map-loading { /* Horizontally centred: left: 50% combined with translateX(-50%). */ + position: absolute; + top: 12px; + left: 50%; + transform: translateX(-50%); + z-index: 650; + display: flex; + align-items: center; + gap: 8px; + padding: 8px 11px; + border: 1px solid #c7d3df; + border-radius: 8px; + background: rgba(255,255,255,.96); + box-shadow: 0 8px 24px rgba(16, 24, 32, .14); + color: #273646; + font-size: 12px; + font-weight: 700; +} +.map-loading[hidden] { display: none; /* Needed because the display: flex rule above would otherwise override the hidden attribute. */ } +.journey-context-popup .leaflet-popup-content { + margin: 9px 10px; +} +.journey-context-menu { + display: grid; + gap: 8px; + min-width: 240px; + font-size: 12px; +} +.journey-context-title { + display: grid; + grid-template-columns: 22px minmax(0, 1fr); + gap: 7px; + align-items: center; +} +.journey-context-title span { + display: grid; + gap: 2px; + min-width: 0; +} +.journey-context-title strong, +.journey-context-title small { + overflow-wrap: anywhere; +} +.journey-context-title small { + color: #667482; +} +.journey-context-actions { + display: grid; + grid-template-columns: 1fr 1fr; + gap: 6px; +} +.journey-context-status { + display: flex; + align-items: center; + gap: 6px; +} +.legend { + position: absolute; + right: 12px; + bottom: 12px; + display: grid; + gap: 6px; + background: rgba(255,255,255,.94); + border: 1px solid #d1d8e0; + border-radius: 8px; + padding: 9px; + font-size: 12px; + z-index: 500; +} +.overlay { /* Full-viewport dimmed backdrop that centres its content. */ + position: fixed; + inset: 0; + z-index: 1000; + background: rgba(15, 23, 32, .34); + display: grid; + place-items: center; + padding: 20px; +} +.overlay[hidden] { display: none; } +.overlay.map-review { /* Variant with a transparent backdrop: content is aligned to the top and inline-end, and the backdrop itself ignores pointer events. */ + background: transparent; + place-items: start end; + pointer-events: none; + padding: 12px; +} +.overlay-panel { + display: flex; + flex-direction: column; + width: min(900px, 100%); + 
max-height: min(680px, 84vh); + overflow: hidden; + background: #fff; + border: 1px solid #cfd8e2; + border-radius: 8px; + box-shadow: 0 18px 48px rgba(16, 24, 32, .24); + padding: 14px; +} +.overlay.map-review .overlay-panel { /* Re-enables pointer events on the panel itself. */ + pointer-events: auto; + width: min(560px, calc(100vw - 24px)); + max-height: calc(100vh - 24px); +} +#overlayContent { /* NOTE(review): presumably a child of the flex-column .overlay-panel, where min-height: 0 lets it shrink and scroll - confirm against the template. */ + min-height: 0; + overflow: auto; + padding-right: 4px; +} +.overlay-title { + z-index: 1; + display: flex; + justify-content: space-between; + gap: 12px; + align-items: center; + background: #fff; + margin-bottom: 10px; + padding: 0 0 10px; + border-bottom: 1px solid #e2e8f0; +} +.overlay-title h2 { + margin: 0; + font-size: 16px; +} +.candidate { + border-top: 1px solid #e2e8f0; + padding: 9px 0; + font-size: 12px; +} +.candidate:first-child { border-top: none; } +.candidate.selected { /* Negative side margins offset by equal padding extend the highlight 8px past the row on each side without shifting its content. */ + margin: 0 -8px; + padding: 9px 8px; + border-radius: 8px; + background: #fff7ed; + border-top-color: transparent; + box-shadow: inset 0 0 0 1px #fed7aa; +} +.candidate-context { + display: grid; + gap: 7px; + margin-bottom: 8px; +} +.candidate-preview-legend { + display: flex; + flex-wrap: wrap; + gap: 8px; + font-size: 11px; + color: #52606d; +} +.candidate-swatch { /* Zero-height box: only its 4px top border is visible, giving a short line sample. */ + display: inline-block; + width: 22px; + height: 0; + border-top: 4px solid #64748b; + margin-right: 4px; + vertical-align: middle; +} +.candidate-swatch.gtfs { + border-top-color: #0f766e; + border-top-style: dashed; +} +.candidate-swatch.selected { border-top-color: #f97316; } +.candidate-title { + display: flex; + justify-content: space-between; + gap: 8px; + font-weight: 700; +} +.candidate-actions { margin-top: 6px; } +.candidate pre { + margin: 5px 0 0; + padding: 7px; + border-radius: 6px; + background: #f7fafc; + white-space: pre-wrap; + overflow-wrap: anywhere; +} +.canonical-stop-detail { + display: grid; + gap: 14px; + font-size: 12px; +} +.canonical-stop-detail h3 { + margin: 0 0 6px; + font-size: 13px; +} +.canonical-summary, +.canonical-link-row, 
+.canonical-candidate-row, +.rule-row { + border: 1px solid #e1e7ee; + border-radius: 8px; + padding: 9px; + background: #fbfdff; +} +.canonical-summary { + display: grid; + gap: 3px; +} +.canonical-link-row, +.canonical-candidate-row { + display: flex; + justify-content: space-between; + gap: 12px; + align-items: start; + margin-top: 6px; +} +.canonical-candidates { + display: grid; + gap: 6px; + margin-top: 8px; +} +.rule-row { + margin-top: 6px; +} +.rule-row pre { + margin: 6px 0 0; + padding: 7px; + border-radius: 6px; + background: #f7fafc; + white-space: pre-wrap; + overflow-wrap: anywhere; +} +.line { display: inline-block; width: 24px; height: 0; border-top: 4px solid #555; margin-right: 6px; vertical-align: middle; } +.line.osm { border-color: #6b7280; } +.line.gtfs { border-color: #18864b; } +.line.missing { border-color: #d03030; } +.dot { display: inline-block; width: 10px; height: 10px; border-radius: 50%; background: #334155; margin-right: 6px; } +@media (max-width: 900px) { /* At 900px and narrower: a single column, with the aside and the map panel each 50vh tall. */ + main { grid-template-columns: 1fr; height: auto; } + aside { height: 50vh; } + main.sidebar-collapsed { grid-template-columns: 1fr; } + main.sidebar-collapsed aside { height: 42px; } + main.sidebar-collapsed .sidebar-collapse-handle { + left: auto; + right: 8px; + bottom: 7px; + } + .map-panel { height: 50vh; } + .map-floating { + top: 8px; + right: 8px; + width: calc(100% - 16px); + max-height: calc(100% - 16px); + } + .journey-results { + max-height: 140px; + } +} +@media (max-width: 520px) { /* At 520px and narrower the catalog filter and action grids collapse to one column. */ + .source-catalog-filter { + grid-template-columns: 1fr; + } + .source-catalog-actions { + grid-template-columns: 1fr; + } +} diff --git a/app/templates/index.html b/app/templates/index.html new file mode 100644 index 0000000..f58f6e1 --- /dev/null +++ b/app/templates/index.html @@ -0,0 +1,329 @@ + + + + + +Harmonized transit, mapping data, route layer, map review, and journey tests.
+| DE-BE-VBB | +Verkehrsverbund Berlin-Brandenburg | +VBB Verkehrsverbund Berlin-Brandenburg GmbH | +2026-01-01 | +2026-12-12 | +20260603 | +2026-06-03 | +2026-06-03 | +Details, ... | +
| Release Url | https://example.test/gtfs.zip |
| Publisher's License | CC BY 4.0 |
| License given for use in OSM | Attribution on contributor page is sufficient. |
| "network:guid" | DE-BE-VBB |