diff --git a/src/central/adapter.py b/src/central/adapter.py index 76ab04e..b923c7c 100644 --- a/src/central/adapter.py +++ b/src/central/adapter.py @@ -40,7 +40,14 @@ class SourceAdapter(ABC): the wizard validates it against staged api_keys state.""" wizard_order: int | None = None default_cadence_s: int - + + enrichment_locations: list[tuple[str, str]] = [] + """Coordinate field paths the supervisor enriches, as (lat_field, lon_field) + tuples into Event.data. Empty (the default) means publish as-is — no + enrichment. Each tuple names top-level keys in Event.data holding a + latitude and longitude; the supervisor extracts them, runs registered + enrichers, and attaches results under Event.data["_enriched"].""" + @abstractmethod async def poll(self) -> AsyncIterator[Event]: """ diff --git a/src/central/adapters/firms.py b/src/central/adapters/firms.py index c9b4efb..46a0d81 100644 --- a/src/central/adapters/firms.py +++ b/src/central/adapters/firms.py @@ -70,6 +70,9 @@ class FIRMSAdapter(SourceAdapter): wizard_order = 2 default_cadence_s = 300 + # Enrichment pilot (PR J): FIRMS rows carry top-level latitude/longitude. + enrichment_locations = [("latitude", "longitude")] + def __init__( self, config: AdapterConfig, diff --git a/src/central/api_key_resolver.py b/src/central/api_key_resolver.py new file mode 100644 index 0000000..e5c3ffa --- /dev/null +++ b/src/central/api_key_resolver.py @@ -0,0 +1,82 @@ +"""Resolve an adapter row's effective api-key alias and check existence. + +Four call sites share this answer: the /adapters list view, the +/adapters/{name} edit form (GET), the same edit form's POST error +re-render path, and the supervisor's adapter-start precondition. +Previously each inlined a predicate that compared the hardcoded +SourceAdapter class attribute `requires_api_key` against `config.api_keys`, +ignoring the per-row `settings[api_key_field]` value the operator +actually selected. An operator could pick alias "firms_production" via +the form, store a working key under that alias, and still see the +"⚠️ API Key Missing" chip + a disabled enable checkbox + the supervisor +refusing to start the adapter — all from the same root cause. + +Two entry points: + +- `resolve_api_key_alias` — sync. Pure function of (cls, settings). + Returns the alias the row should consult, or None when no key is + required (no SQL needed). Supervisor uses this directly because it + already has a ConfigStore and does its own key fetch. + +- `adapter_has_resolved_api_key` — async. Wraps the sync resolver with + the existence SELECT against config.api_keys. The three GUI route + call sites use this. + +Resolution order (same as the adapter's runtime read in __init__): + 1. requires_api_key is None -> no key needed. + 2. settings[api_key_field] is a non-empty string -> use that. + 3. fall back to the requires_api_key class-attribute default. +""" + +from typing import Any + +from central.adapter import SourceAdapter + + +def resolve_api_key_alias( + adapter_cls: type[SourceAdapter] | None, + settings: dict | None, +) -> str | None: + """Return the api-key alias for this adapter row, or None when no key is required. + + Pure function — no IO, no SQL. Same resolution order the adapter's own + __init__ uses when reading the alias out of config.settings. + """ + if adapter_cls is None or adapter_cls.requires_api_key is None: + return None + field = adapter_cls.api_key_field + if field and settings: + value = settings.get(field) + if isinstance(value, str) and value.strip(): + return value.strip() + return adapter_cls.requires_api_key + + +async def adapter_has_resolved_api_key( + conn: Any, + adapter_cls: type[SourceAdapter] | None, + settings: dict | None, +) -> tuple[bool, str | None]: + """Return (has_key, resolved_alias) for one adapter row. + + Args: + conn: asyncpg connection or any object exposing an awaitable fetchval + with the same signature. + adapter_cls: The discovered SourceAdapter subclass for this row, or + None when the row references an adapter no longer in the codebase. + settings: The row's jsonb settings dict, or None when the row has no + settings stored yet. + + Returns: + (has_key, resolved_alias). When the adapter does not require an api + key (requires_api_key is None) or no adapter class is available, + returns (True, None) and no SQL is issued. + """ + alias = resolve_api_key_alias(adapter_cls, settings) + if alias is None: + return True, None + has_key = await conn.fetchval( + "SELECT 1 FROM config.api_keys WHERE alias = $1", + alias, + ) + return bool(has_key), alias diff --git a/src/central/config_models.py b/src/central/config_models.py index 5516bc1..557423a 100644 --- a/src/central/config_models.py +++ b/src/central/config_models.py @@ -47,6 +47,28 @@ class AdapterConfig(BaseModel): return self.paused_at is not None +class EnrichmentConfig(BaseModel): + """Configuration for the supervisor's enrichment stage. + + Read once at supervisor startup (hot-reload is out of scope for PR J). + Defaults wire the GeocoderEnricher to the NoOpBackend (all-null bundle); + real backends arrive in PR K via backend_class + backend_settings. + """ + + enricher_class: str = Field( + default="GeocoderEnricher", description="Enricher class name to instantiate" + ) + backend_class: str = Field( + default="NoOpBackend", description="Backend class name to instantiate" + ) + backend_settings: dict[str, Any] = Field( + default_factory=dict, description="Keyword args passed to the backend constructor" + ) + cache_ttl_s: int = Field( + default=86400, description="Enrichment cache TTL in seconds (default 24h)" + ) + + class StreamConfig(BaseModel): """Configuration for a JetStream stream.""" diff --git a/src/central/enrichment/__init__.py b/src/central/enrichment/__init__.py new file mode 100644 index 0000000..0e32625 --- /dev/null +++ b/src/central/enrichment/__init__.py @@ -0,0 +1,28 @@ +"""Central enrichment framework. + +The supervisor runs registered enrichers over each event whose adapter +declares `enrichment_locations`, attaching results under +`event.data["_enriched"][]`. Provenance is explicit: anything +under `_enriched` is Central-derived; everything else in `data` is upstream +verbatim. +""" + +from central.enrichment.backends.no_op import NoOpBackend +from central.enrichment.base import Enricher +from central.enrichment.cache import EnrichmentCache +from central.enrichment.geocoder import ( + GEOCODER_FIELDS, + GeocoderBackend, + GeocoderEnricher, + all_null_bundle, +) + +__all__ = [ + "Enricher", + "EnrichmentCache", + "GeocoderEnricher", + "GeocoderBackend", + "GEOCODER_FIELDS", + "all_null_bundle", + "NoOpBackend", +] diff --git a/src/central/enrichment/backends/__init__.py b/src/central/enrichment/backends/__init__.py new file mode 100644 index 0000000..309c273 --- /dev/null +++ b/src/central/enrichment/backends/__init__.py @@ -0,0 +1,5 @@ +"""Geocoder backend implementations.""" + +from central.enrichment.backends.no_op import NoOpBackend + +__all__ = ["NoOpBackend"] diff --git a/src/central/enrichment/backends/no_op.py b/src/central/enrichment/backends/no_op.py new file mode 100644 index 0000000..77b0567 --- /dev/null +++ b/src/central/enrichment/backends/no_op.py @@ -0,0 +1,17 @@ +"""No-op geocoder backend — returns an all-null bundle for every input. + +The default backend in PR J. Real backends (Navi, Photon, Nominatim) land in +PR K; until then the framework is exercisable end-to-end with NoOpBackend, +which satisfies the GeocoderBackend contract while resolving nothing. +""" + +from typing import Any + +from central.enrichment.geocoder import all_null_bundle + + +class NoOpBackend: + """GeocoderBackend that resolves no fields.""" + + async def reverse(self, lat: float, lon: float) -> dict[str, Any]: + return all_null_bundle() diff --git a/src/central/enrichment/base.py b/src/central/enrichment/base.py new file mode 100644 index 0000000..7249355 --- /dev/null +++ b/src/central/enrichment/base.py @@ -0,0 +1,30 @@ +"""Enricher protocol — the framework-level contract. + +An Enricher takes a single location and returns a flat dict of enrichment +fields. The supervisor attaches each enricher's result under +`event.data["_enriched"][enricher.name]` before publishing. Everything under +`_enriched` is Central-provenance; everything else in `data` is upstream +verbatim. +""" + +from typing import Any, Protocol, runtime_checkable + + +@runtime_checkable +class Enricher(Protocol): + """Pluggable enrichment unit. + + name: short identifier, used as the key under event.data["_enriched"]. + """ + + name: str + + async def enrich(self, location: dict[str, float]) -> dict[str, Any]: + """Given a location ({"lat": float, "lon": float}), return enrichment fields. + + Fields the enricher can't resolve are present with value None (NOT + omitted) so consumers see a stable field set. Implementations must + NEVER raise — they handle their own failures and return an all-null + bundle on total failure. + """ + ... diff --git a/src/central/enrichment/cache.py b/src/central/enrichment/cache.py new file mode 100644 index 0000000..06be7af --- /dev/null +++ b/src/central/enrichment/cache.py @@ -0,0 +1,126 @@ +"""SQLite-backed enrichment cache with rounded-coords keys + TTL. + +Keyed on (enricher_name, lat_rounded, lon_rounded) where coordinates are +rounded to 4 decimal places (~11 m). Uses stdlib sqlite3 off the event loop +via asyncio.to_thread (no async-sqlite dependency in the project). A fresh +connection is opened per operation — sqlite3 connections are not safe to +share across threads, and to_thread may run ops on different pool threads. +""" + +import asyncio +import json +import logging +import sqlite3 +from datetime import datetime, timezone +from pathlib import Path +from typing import Any + +logger = logging.getLogger(__name__) + +_COORD_PRECISION = 4 + +_SCHEMA = """ +CREATE TABLE IF NOT EXISTS enrichment_cache ( + enricher_name TEXT NOT NULL, + lat_rounded REAL NOT NULL, + lon_rounded REAL NOT NULL, + payload_json TEXT NOT NULL, + cached_at TEXT NOT NULL, + PRIMARY KEY (enricher_name, lat_rounded, lon_rounded) +) +""" + + +def round_coord(value: float) -> float: + """Round a coordinate to the cache-key precision (4 dp).""" + return round(float(value), _COORD_PRECISION) + + +class EnrichmentCache: + """Thread-offloaded sqlite cache for enrichment bundles.""" + + def __init__(self, db_path: str | Path, ttl_s: int = 86400) -> None: + self._db_path = Path(db_path) + self._ttl_s = ttl_s + self._db_path.parent.mkdir(parents=True, exist_ok=True) + self._init_db() + + def _connect(self) -> sqlite3.Connection: + return sqlite3.connect(self._db_path, timeout=30) + + def _init_db(self) -> None: + conn = self._connect() + try: + conn.execute(_SCHEMA) + conn.commit() + finally: + conn.close() + + # --- sync bodies (run inside asyncio.to_thread) ------------------------ + + def _get_sync(self, enricher_name: str, lat: float, lon: float) -> dict[str, Any] | None: + lat_r = round_coord(lat) + lon_r = round_coord(lon) + conn = self._connect() + try: + cur = conn.execute( + """ + SELECT payload_json, cached_at FROM enrichment_cache + WHERE enricher_name = ? AND lat_rounded = ? AND lon_rounded = ? + """, + (enricher_name, lat_r, lon_r), + ) + row = cur.fetchone() + finally: + conn.close() + if row is None: + return None + payload_json, cached_at_iso = row + if self._is_expired(cached_at_iso): + return None + return json.loads(payload_json) + + def _set_sync( + self, enricher_name: str, lat: float, lon: float, payload: dict[str, Any] + ) -> None: + lat_r = round_coord(lat) + lon_r = round_coord(lon) + now_iso = datetime.now(timezone.utc).isoformat() + conn = self._connect() + try: + conn.execute( + """ + INSERT INTO enrichment_cache + (enricher_name, lat_rounded, lon_rounded, payload_json, cached_at) + VALUES (?, ?, ?, ?, ?) + ON CONFLICT (enricher_name, lat_rounded, lon_rounded) DO UPDATE SET + payload_json = excluded.payload_json, + cached_at = excluded.cached_at + """, + (enricher_name, lat_r, lon_r, json.dumps(payload), now_iso), + ) + conn.commit() + finally: + conn.close() + + def _is_expired(self, cached_at_iso: str) -> bool: + try: + cached_at = datetime.fromisoformat(cached_at_iso) + except ValueError: + return True + if cached_at.tzinfo is None: + cached_at = cached_at.replace(tzinfo=timezone.utc) + age_s = (datetime.now(timezone.utc) - cached_at).total_seconds() + return age_s > self._ttl_s + + # --- async surface ----------------------------------------------------- + + async def get(self, enricher_name: str, lat: float, lon: float) -> dict[str, Any] | None: + """Return the cached bundle, or None on miss / expiry.""" + return await asyncio.to_thread(self._get_sync, enricher_name, lat, lon) + + async def set( + self, enricher_name: str, lat: float, lon: float, payload: dict[str, Any] + ) -> None: + """Cache a bundle (idempotent upsert on the rounded-coords key).""" + await asyncio.to_thread(self._set_sync, enricher_name, lat, lon, payload) diff --git a/src/central/enrichment/geocoder.py b/src/central/enrichment/geocoder.py new file mode 100644 index 0000000..ab75017 --- /dev/null +++ b/src/central/enrichment/geocoder.py @@ -0,0 +1,92 @@ +"""Reverse-geocoding enricher + the pluggable backend contract. + +GeocoderEnricher is the only enricher in PR J. It owns the cache + the +all-null normalization; the actual lookup is delegated to a GeocoderBackend. +PR J ships NoOpBackend only (all-null); real backends (Navi/Photon/Nominatim) +land in PR K. +""" + +import logging +from typing import Any, Protocol, runtime_checkable + +from central.enrichment.cache import EnrichmentCache + +logger = logging.getLogger(__name__) + +# Locked canonical geocoder field set. The single source of truth for what a +# geocoder enrichment bundle looks like — backends fill what they can and +# return None for the rest; NoOpBackend returns all None. +GEOCODER_FIELDS: tuple[str, ...] = ( + "name", + "city", + "county", + "state", + "country", + "postal_code", + "timezone", + "landclass", + "elevation_m", +) + + +def all_null_bundle() -> dict[str, Any]: + """A geocoder bundle with every locked field present and None.""" + return {field: None for field in GEOCODER_FIELDS} + + +@runtime_checkable +class GeocoderBackend(Protocol): + """The pluggable reverse-geocoding layer beneath GeocoderEnricher.""" + + async def reverse(self, lat: float, lon: float) -> dict[str, Any]: + """Return canonical geocoder fields (see GEOCODER_FIELDS). + + Fields the backend can't resolve return None. Must never raise. + """ + ... + + +class GeocoderEnricher: + """Reverse-geocode a location into the canonical geocoder field set. + + Resolution: cache hit -> return cached; cache miss -> call backend, cache + the (normalized) result even when all-null, return it. Backend failure + (any exception escaping the backend's "never raise" contract) -> return + all-null and DO NOT cache, so the next call retries. + """ + + name = "geocoder" + + def __init__( + self, + backend: GeocoderBackend, + cache: EnrichmentCache | None = None, + ) -> None: + self._backend = backend + self._cache = cache + + async def enrich(self, location: dict[str, float]) -> dict[str, Any]: + lat = location.get("lat") + lon = location.get("lon") + if lat is None or lon is None: + return all_null_bundle() + + if self._cache is not None: + cached = await self._cache.get(self.name, lat, lon) + if cached is not None: + return cached + + try: + raw = await self._backend.reverse(lat, lon) + except Exception: + # Backend broke its "never raise" contract. Return all-null and do + # NOT cache, so a transient failure doesn't get pinned for the TTL. + logger.exception("geocoder backend raised; returning all-null bundle") + return all_null_bundle() + + # Normalize to the locked field set: every field present, extras dropped. + normalized = {field: raw.get(field) for field in GEOCODER_FIELDS} + + if self._cache is not None: + await self._cache.set(self.name, lat, lon, normalized) + return normalized diff --git a/src/central/gui/routes.py b/src/central/gui/routes.py index a171c47..ef7ba49 100644 --- a/src/central/gui/routes.py +++ b/src/central/gui/routes.py @@ -51,6 +51,7 @@ from pathlib import Path from central.config_models import AdapterConfig from central.gui.db import get_pool from central.gui.form_descriptors import describe_fields, FieldDescriptor +from central.api_key_resolver import adapter_has_resolved_api_key from central.adapter_discovery import discover_adapters from central.streams import STREAMS as STREAM_REGISTRY from pydantic import ValidationError @@ -1348,16 +1349,13 @@ async def adapters_list( settings = row["settings"] or {} adapter_cls = adapter_classes.get(row["name"]) - # Check if required API key is missing - api_key_missing = False - requires_api_key_alias = None - if adapter_cls and adapter_cls.requires_api_key is not None: - requires_api_key_alias = adapter_cls.requires_api_key - has_key = await conn.fetchval( - "SELECT 1 FROM config.api_keys WHERE alias = $1", - requires_api_key_alias, - ) - api_key_missing = not has_key + # Check if required API key is missing — resolve via the per-row + # settings[api_key_field] (operator-selected alias), falling back + # to the class-attribute default when settings hasn't been set. + has_key, requires_api_key_alias = await adapter_has_resolved_api_key( + conn, adapter_cls, settings, + ) + api_key_missing = not has_key adapters.append({ "name": row["name"], @@ -1445,23 +1443,15 @@ async def adapters_edit_form( if f.name == adapter_cls.api_key_field: f.widget = "api_key_select" - # Fetch API keys for api_key_select widget - api_keys = [] + # Fetch API keys for api_key_select widget + resolve the per-adapter + # alias against the operator-set settings, not the class-attr default. async with pool.acquire() as conn: api_key_rows = await conn.fetch("SELECT alias FROM config.api_keys ORDER BY alias") api_keys = [{"alias": r["alias"]} for r in api_key_rows] - - # Check if required API key is missing - api_key_missing = False - requires_api_key_alias = None - if adapter_cls and adapter_cls.requires_api_key is not None: - requires_api_key_alias = adapter_cls.requires_api_key - async with pool.acquire() as conn: - has_key = await conn.fetchval( - "SELECT 1 FROM config.api_keys WHERE alias = $1", - requires_api_key_alias, - ) - api_key_missing = not has_key + has_key, requires_api_key_alias = await adapter_has_resolved_api_key( + conn, adapter_cls, settings, + ) + api_key_missing = not has_key # Generic settings-driven preview. Adapters opt in by overriding # SourceAdapter.preview_for_settings; the framework is duck-typed on the @@ -1700,20 +1690,15 @@ async def adapters_edit_submit( if f.name == adapter_cls.api_key_field: f.widget = "api_key_select" - # Fetch API keys for api_key_select widget + # Fetch API keys for api_key_select widget + resolve the per-adapter + # alias against the pre-edit settings (form validation failed, so + # the stored settings haven't been replaced). api_key_rows = await conn.fetch("SELECT alias FROM config.api_keys ORDER BY alias") api_keys = [{"alias": r["alias"]} for r in api_key_rows] - - # Check if required API key is missing - api_key_missing = False - requires_api_key_alias = None - if adapter_cls and adapter_cls.requires_api_key is not None: - requires_api_key_alias = adapter_cls.requires_api_key - has_key = await conn.fetchval( - "SELECT 1 FROM config.api_keys WHERE alias = $1", - requires_api_key_alias, - ) - api_key_missing = not has_key + has_key, requires_api_key_alias = await adapter_has_resolved_api_key( + conn, adapter_cls, current_settings, + ) + api_key_missing = not has_key csrf_token = request.state.csrf_token response = templates.TemplateResponse( diff --git a/src/central/supervisor.py b/src/central/supervisor.py index bdc68e6..8ca6be9 100644 --- a/src/central/supervisor.py +++ b/src/central/supervisor.py @@ -20,9 +20,66 @@ from central.config_models import AdapterConfig from central.config_source import ConfigSource, DbConfigSource from central.config_store import ConfigStore from central.bootstrap_config import get_settings +from central.api_key_resolver import resolve_api_key_alias +from central.config_models import EnrichmentConfig +from central.enrichment.base import Enricher +from central.enrichment.cache import EnrichmentCache +from central.enrichment.backends.no_op import NoOpBackend +from central.enrichment.geocoder import GeocoderEnricher +from central.models import Event from central.stream_manager import StreamManager from central.streams import STREAMS as STREAM_REGISTRY CURSOR_DB_PATH = Path("/var/lib/central/cursors.db") +ENRICHMENT_CACHE_DB_PATH = Path("/var/lib/central/enrichment_cache.db") + +# Enricher / backend class-name registries for EnrichmentConfig resolution. +# PR J ships GeocoderEnricher + NoOpBackend only; PR K extends these. +_ENRICHER_REGISTRY: dict[str, type] = {"GeocoderEnricher": GeocoderEnricher} +_BACKEND_REGISTRY: dict[str, type] = {"NoOpBackend": NoOpBackend} + + +def build_enrichers( + enrichment_config: EnrichmentConfig, + cache_db_path: Path = ENRICHMENT_CACHE_DB_PATH, +) -> list[Enricher]: + """Instantiate the configured enricher(s) with their backend + cache. + + Read once at supervisor startup — enrichment config is NOT hot-reloaded + in PR J (see EnrichmentConfig docstring). + """ + backend_cls = _BACKEND_REGISTRY[enrichment_config.backend_class] + backend = backend_cls(**enrichment_config.backend_settings) + cache = EnrichmentCache(cache_db_path, ttl_s=enrichment_config.cache_ttl_s) + enricher_cls = _ENRICHER_REGISTRY[enrichment_config.enricher_class] + return [enricher_cls(backend, cache=cache)] + + +async def apply_enrichment( + event: Event, + enrichment_locations: list[tuple[str, str]], + enrichers: list[Enricher], +) -> None: + """Attach enrichment results to event.data["_enriched"] in place. + + No-op when the adapter declares no enrichment_locations or no enrichers + are registered. Uses the first (lat_path, lon_path) tuple that resolves to + a non-null coordinate pair in event.data. Each enricher's result is keyed + by enricher.name. Mutates the data dict in place (Event is frozen, but its + data dict is not — this avoids a model_copy on every published event). + """ + if not enrichment_locations or not enrichers: + return + for lat_path, lon_path in enrichment_locations: + lat = event.data.get(lat_path) + lon = event.data.get(lon_path) + if lat is None or lon is None: + continue + location = {"lat": float(lat), "lon": float(lon)} + enriched: dict[str, Any] = {} + for enricher in enrichers: + enriched[enricher.name] = await enricher.enrich(location) + event.data["_enriched"] = enriched + return # Stream subject mappings -- derived from the registry; every stream is included # (META too: supervisor must create it in JetStream even though archive skips it). @@ -95,11 +152,16 @@ class Supervisor: config_store: ConfigStore, nats_url: str, cloudevents_config: Any = None, + enrichment_config: EnrichmentConfig | None = None, ) -> None: self._config_source = config_source self._config_store = config_store self._nats_url = nats_url self._cloudevents_config = cloudevents_config + # Enrichment is read once at startup (no hot-reload in PR J). + self._enrichers: list[Enricher] = build_enrichers( + enrichment_config or EnrichmentConfig() + ) self._adapters = discover_adapters() self._nc: nats.NATS | None = None self._js: JetStreamContext | None = None @@ -216,6 +278,16 @@ class Supervisor: state.adapter.bump_last_seen(event.id) continue + # Enrichment (no-op unless the adapter declares + # enrichment_locations). Runs after dedup so we don't + # enrich events we'd skip, and before wrap_event so the + # _enriched block lands in the published payload. + await apply_enrichment( + event, + state.adapter.enrichment_locations, + self._enrichers, + ) + # Build CloudEvent (uses defaults if no config provided) envelope, msg_id = wrap_event(event, self._cloudevents_config) @@ -263,10 +335,13 @@ class Supervisor: If the adapter was previously stopped (state exists but task is not running), reuses the existing state to preserve last_completed_poll for rate limiting. """ - # API key precondition + # API key precondition — resolve via per-row settings[api_key_field] + # (operator-selected alias), falling back to the class-attribute + # default when settings hasn't been set. Returns None when no key + # is required. adapter_cls = self._adapters.get(config.name) - if adapter_cls is not None and adapter_cls.requires_api_key is not None: - alias = adapter_cls.requires_api_key + alias = resolve_api_key_alias(adapter_cls, config.settings) + if alias is not None: key_value = await self._config_store.get_api_key(alias) if not key_value: error_msg = f"missing api key: {alias}" @@ -760,6 +835,9 @@ async def async_main() -> None: nats_url=settings.nats_url, # CloudEvents uses protocol-level defaults from cloudevents_constants cloudevents_config=None, + # Enrichment defaults: GeocoderEnricher + NoOpBackend (all-null). Read + # once here at startup; PR K wires real backends + DB-backed config. + enrichment_config=EnrichmentConfig(), ) logger.info( "CloudEvents config: defaults", diff --git a/tests/test_adapters.py b/tests/test_adapters.py index aa3cd90..8ed0259 100644 --- a/tests/test_adapters.py +++ b/tests/test_adapters.py @@ -452,8 +452,17 @@ class TestAdaptersJsonbRegression: {"alias": "firms_key"}, {"alias": "other_key"}, ]) + # The /adapters/{name} edit handler also issues a fetchval against + # config.api_keys to resolve whether the adapter's key is present. + # Return 1 (truthy) so the handler proceeds — this test only asserts + # api_keys reaches the template context, not the warning state. + mock_conn.fetchval = AsyncMock(return_value=1) mock_conn.__aenter__ = AsyncMock(return_value=mock_conn) - mock_conn.__aexit__ = AsyncMock() + # AsyncMock() with no return_value yields a MagicMock — which is truthy, + # and the async context manager protocol reads a truthy __aexit__ return + # as "exception suppressed." That silently swallows any error inside the + # `async with` block. Pin return_value=None so exceptions propagate. + mock_conn.__aexit__ = AsyncMock(return_value=None) mock_pool = MagicMock() mock_pool.acquire = MagicMock(return_value=mock_conn) diff --git a/tests/test_api_key_resolver.py b/tests/test_api_key_resolver.py new file mode 100644 index 0000000..a137f5b --- /dev/null +++ b/tests/test_api_key_resolver.py @@ -0,0 +1,249 @@ +"""Tests for central.gui._api_key_resolver.adapter_has_resolved_api_key. + +The /adapters list, /adapters/{name} edit, and POST error re-render flows +previously inlined a predicate that compared the SourceAdapter class +attribute `requires_api_key` (default literal string) against +`config.api_keys` — ignoring the per-row `settings[api_key_field]` alias +the operator actually selected. The helper consolidates the resolution +logic; these tests pin it. + +No hardcoded adapter-name list literals: keyless-adapter case picks a +real one via central.adapter_discovery. The settings-driven cases use +inline minimal SourceAdapter subclasses because no live adapter today +has the pertinent class-attr/field combinations. +""" + +from collections.abc import AsyncIterator +from typing import Any + +import pytest +from pydantic import BaseModel + +from central.adapter import SourceAdapter +from central.adapter_discovery import discover_adapters +from central.api_key_resolver import ( + adapter_has_resolved_api_key, + resolve_api_key_alias, +) +from central.models import Event + + +class _FakeConn: + """Captures the (sql, alias) of each fetchval call and returns a scripted + result. Mirrors the asyncpg connection surface adapter_has_resolved_api_key + actually uses.""" + + def __init__(self, result: Any) -> None: + self.result = result + self.calls: list[tuple[str, tuple[Any, ...]]] = [] + + async def fetchval(self, sql: str, *params: Any) -> Any: + self.calls.append((sql, params)) + return self.result + + +def _pick_keyless_adapter() -> type[SourceAdapter] | None: + """First discovered adapter with requires_api_key is None. + + Lets test 1 exercise the no-SQL short-circuit against a real subclass + without hardcoding the adapter's name. + """ + for cls in discover_adapters().values(): + if cls.requires_api_key is None: + return cls + return None + + +class _BareKeyAdapter(SourceAdapter): + """Requires a key but does NOT expose an operator-overridable field. + + Forces the helper to fall back to the class-attribute default. + """ + + name = "_bare_key" + display_name = "_BareKey" + description = "Test fixture: class-attr fallback path." + settings_schema = BaseModel + requires_api_key = "default" + api_key_field = None + default_cadence_s = 60 + + async def poll(self) -> AsyncIterator[Event]: # pragma: no cover + if False: + yield # type: ignore[unreachable] + return + + async def apply_config(self, new_config) -> None: # pragma: no cover + pass + + def subject_for(self, event: Event) -> str: # pragma: no cover + return "central.test._bare_key" + + +class _OperatorOverridableAdapter(SourceAdapter): + """Requires a key AND exposes an operator-overridable alias field. + + Models the FIRMS-shaped contract that previously misfired. + """ + + name = "_overridable" + display_name = "_Overridable" + description = "Test fixture: operator-overridable alias field." + settings_schema = BaseModel + requires_api_key = "default" + api_key_field = "api_key_alias" + default_cadence_s = 60 + + async def poll(self) -> AsyncIterator[Event]: # pragma: no cover + if False: + yield # type: ignore[unreachable] + return + + async def apply_config(self, new_config) -> None: # pragma: no cover + pass + + def subject_for(self, event: Event) -> str: # pragma: no cover + return "central.test._overridable" + + +@pytest.mark.asyncio +async def test_keyless_adapter_short_circuits_without_sql(): + """An adapter with requires_api_key=None returns (True, None) and the + helper issues NO SQL — caller can assume the gate is open.""" + keyless = _pick_keyless_adapter() + if keyless is None: + pytest.skip("no keyless adapter discovered in central.adapters") + conn = _FakeConn(result=1) + has_key, alias = await adapter_has_resolved_api_key(conn, keyless, {}) + assert (has_key, alias) == (True, None) + assert conn.calls == [], "no SQL expected for keyless adapter" + + +@pytest.mark.asyncio +async def test_class_attr_fallback_when_no_api_key_field(): + """Adapter exposes requires_api_key but no api_key_field — helper falls + back to the class-attribute default for the lookup alias.""" + conn = _FakeConn(result=1) + has_key, alias = await adapter_has_resolved_api_key(conn, _BareKeyAdapter, {}) + assert alias == "default" + assert has_key is True + assert len(conn.calls) == 1 + sql, params = conn.calls[0] + assert "config.api_keys" in sql + assert params == ("default",) + + +@pytest.mark.asyncio +async def test_settings_override_takes_precedence_over_class_attr(): + """Operator-selected alias in settings[api_key_field] wins over the + class-attribute default. Regression guard for the FIRMS-shaped bug: + requires_api_key='default' but operator stored a key under 'custom'.""" + conn = _FakeConn(result=1) + settings = {"api_key_alias": "custom"} + has_key, alias = await adapter_has_resolved_api_key( + conn, _OperatorOverridableAdapter, settings, + ) + assert alias == "custom", "operator-selected alias should win" + assert has_key is True + assert conn.calls[0][1] == ("custom",), ( + f"SQL must query the operator-selected alias, got {conn.calls[0][1]!r}" + ) + + +@pytest.mark.asyncio +async def test_missing_settings_field_falls_back_to_class_attr(): + """Operator hasn't visited the form yet — settings dict lacks the + api_key_field entry. Class-attr default is the safe fallback.""" + conn = _FakeConn(result=1) + has_key, alias = await adapter_has_resolved_api_key( + conn, _OperatorOverridableAdapter, {"unrelated": "noise"}, + ) + assert alias == "default" + assert conn.calls[0][1] == ("default",) + + # Same path when settings is None entirely. + conn2 = _FakeConn(result=1) + has_key2, alias2 = await adapter_has_resolved_api_key( + conn2, _OperatorOverridableAdapter, None, + ) + assert alias2 == "default" + assert conn2.calls[0][1] == ("default",) + + +@pytest.mark.asyncio +async def test_empty_string_settings_field_falls_back_to_class_attr(): + """Empty string in settings does NOT mean "look up alias '' "; it means + "not set, use default" — same shape the form parser yields for cleared + text inputs.""" + conn = _FakeConn(result=1) + has_key, alias = await adapter_has_resolved_api_key( + conn, _OperatorOverridableAdapter, {"api_key_alias": " "}, + ) + assert alias == "default", "whitespace-only must fall through to class attr" + assert conn.calls[0][1] == ("default",) + + conn2 = _FakeConn(result=1) + has_key2, alias2 = await adapter_has_resolved_api_key( + conn2, _OperatorOverridableAdapter, {"api_key_alias": ""}, + ) + assert alias2 == "default" + assert conn2.calls[0][1] == ("default",) + + +@pytest.mark.asyncio +async def test_resolved_alias_absent_returns_has_key_false(): + """If the resolved alias has no matching row, helper returns + (False, alias) so caller renders the warning chip with the right alias + shown in the title attribute.""" + conn = _FakeConn(result=None) + has_key, alias = await adapter_has_resolved_api_key( + conn, _OperatorOverridableAdapter, {"api_key_alias": "ghost"}, + ) + assert (has_key, alias) == (False, "ghost") + + +@pytest.mark.asyncio +async def test_unknown_adapter_cls_short_circuits(): + """Row references a deleted adapter — adapter_cls is None. Helper must + not crash, must not issue SQL, and must report no missing key (the warning + is for known adapters; orphaned rows are a separate concern).""" + conn = _FakeConn(result=None) + has_key, alias = await adapter_has_resolved_api_key(conn, None, {}) + assert (has_key, alias) == (True, None) + assert conn.calls == [] + + +# --------------------------------------------------------------------------- +# Supervisor uses the sync resolver directly (no DB round-trip — supervisor +# has its own ConfigStore and calls get_api_key on the resolved alias). +# These tests mirror the route coverage on the sync entry point. +# --------------------------------------------------------------------------- + + +def test_supervisor_uses_operator_selected_alias(): + """Supervisor's adapter-start precondition must resolve to the same alias + the GUI's warning predicate resolves to. Regression guard for the second + half of the FIRMS-shaped bug: routes opened the gate but supervisor still + refused to start with `missing api key: `.""" + alias = resolve_api_key_alias( + _OperatorOverridableAdapter, {"api_key_alias": "firms_production"}, + ) + assert alias == "firms_production" + + +def test_supervisor_falls_back_to_class_attr_when_settings_unset(): + """Operator hasn't touched settings yet — supervisor still finds the + class-attribute default (preserves pre-bug behavior for fresh installs).""" + assert resolve_api_key_alias(_OperatorOverridableAdapter, None) == "default" + assert resolve_api_key_alias(_OperatorOverridableAdapter, {}) == "default" + assert resolve_api_key_alias(_BareKeyAdapter, {}) == "default" + + +def test_supervisor_returns_none_for_keyless_adapter(): + """When the adapter doesn't need a key, supervisor should skip the + `get_api_key` call entirely — the sync resolver returns None to signal.""" + keyless = _pick_keyless_adapter() + if keyless is None: + pytest.skip("no keyless adapter discovered in central.adapters") + assert resolve_api_key_alias(keyless, {}) is None + assert resolve_api_key_alias(None, {}) is None diff --git a/tests/test_enrichment_framework.py b/tests/test_enrichment_framework.py new file mode 100644 index 0000000..98ac7c9 --- /dev/null +++ b/tests/test_enrichment_framework.py @@ -0,0 +1,160 @@ +"""Tests for the enrichment cache + framework wiring. + +Covers cache hit/miss/TTL/rounding, idempotent concurrent writes, and the +"backend failure -> all-null, not cached" contract via GeocoderEnricher. +""" + +import asyncio +import json +import sqlite3 +from pathlib import Path +from typing import Any + +import pytest + +from central.enrichment.cache import EnrichmentCache, round_coord +from central.enrichment.geocoder import GEOCODER_FIELDS, GeocoderEnricher, all_null_bundle + + +@pytest.fixture +def cache_path(tmp_path: Path) -> Path: + return tmp_path / "nested" / "enrichment_cache.db" + + +def test_init_creates_parent_dir_and_table(cache_path: Path): + assert not cache_path.parent.exists() + cache = EnrichmentCache(cache_path, ttl_s=60) + assert cache_path.parent.is_dir() + # Table exists and is queryable. + conn = sqlite3.connect(cache_path) + try: + cur = conn.execute( + "SELECT name FROM sqlite_master WHERE type='table' AND name='enrichment_cache'" + ) + assert cur.fetchone() is not None + finally: + conn.close() + + +@pytest.mark.asyncio +async def test_cache_miss_then_hit(cache_path: Path): + cache = EnrichmentCache(cache_path, ttl_s=3600) + assert await cache.get("geocoder", 45.0, -116.0) is None # miss + payload = {"name": "Somewhere", "state": "ID"} + await cache.set("geocoder", 45.0, -116.0, payload) + hit = await cache.get("geocoder", 45.0, -116.0) + assert hit == payload + + +@pytest.mark.asyncio +async def test_ttl_expiry_returns_miss(cache_path: Path): + cache = EnrichmentCache(cache_path, ttl_s=0) # everything immediately stale + await cache.set("geocoder", 1.0, 2.0, {"name": "x"}) + # ttl_s=0 -> age (>0) always exceeds ttl -> treated as expired. + assert await cache.get("geocoder", 1.0, 2.0) is None + + +def test_round_coord_4dp(): + assert round_coord(45.123456789) == 45.1235 + assert round_coord(-116.000049) == -116.0 + assert round_coord(12.99995) == 13.0 + + +@pytest.mark.asyncio +async def test_rounding_collapses_nearby_coords_to_same_key(cache_path: Path): + cache = EnrichmentCache(cache_path, ttl_s=3600) + await cache.set("geocoder", 45.12341, -116.45678, {"name": "rounded"}) + # 45.123413 / -116.456784 round to the same 4dp key -> same row. + hit = await cache.get("geocoder", 45.123413, -116.456784) + assert hit == {"name": "rounded"} + + +@pytest.mark.asyncio +async def test_concurrent_sets_do_not_double_write(cache_path: Path): + cache = EnrichmentCache(cache_path, ttl_s=3600) + await asyncio.gather( + *[cache.set("geocoder", 10.0, 20.0, {"n": i}) for i in range(20)] + ) + conn = sqlite3.connect(cache_path) + try: + count = conn.execute( + "SELECT COUNT(*) FROM enrichment_cache WHERE enricher_name='geocoder' " + "AND lat_rounded=? AND lon_rounded=?", + (10.0, 20.0), + ).fetchone()[0] + finally: + conn.close() + assert count == 1, "PRIMARY KEY must collapse concurrent writes to one row" + + +class _CountingBackend: + """Backend that counts reverse() calls; lets tests prove cache hits.""" + + def __init__(self) -> None: + self.calls = 0 + + async def reverse(self, lat: float, lon: float) -> dict[str, Any]: + self.calls += 1 + return {**all_null_bundle(), "name": "Counted", "state": "ID"} + + +class _ExplodingBackend: + """Backend that violates the never-raise contract.""" + + def __init__(self) -> None: + self.calls = 0 + + async def reverse(self, lat: float, lon: float) -> dict[str, Any]: + self.calls += 1 + raise RuntimeError("upstream geocoder down") + + +@pytest.mark.asyncio +async def test_backend_failure_returns_all_null_and_does_not_cache(cache_path: Path): + cache = EnrichmentCache(cache_path, ttl_s=3600) + backend = _ExplodingBackend() + enricher = GeocoderEnricher(backend, cache=cache) + + result = await enricher.enrich({"lat": 5.0, "lon": 6.0}) + assert result == all_null_bundle() + + # Nothing cached -> a second call retries the backend (calls increments). + assert await cache.get("geocoder", 5.0, 6.0) is None + await enricher.enrich({"lat": 5.0, "lon": 6.0}) + assert backend.calls == 2, "failed lookups must not be cached (must retry)" + + +@pytest.mark.asyncio +async def test_successful_result_is_cached_and_avoids_second_backend_call(cache_path: Path): + cache = EnrichmentCache(cache_path, ttl_s=3600) + backend = _CountingBackend() + enricher = GeocoderEnricher(backend, cache=cache) + + first = await enricher.enrich({"lat": 7.5, "lon": 8.5}) + second = await enricher.enrich({"lat": 7.5, "lon": 8.5}) + assert first == second + assert backend.calls == 1, "second call with same coords must hit cache" + + +@pytest.mark.asyncio +async def test_all_null_result_is_cached(cache_path: Path): + """A backend that resolves nothing still gets cached — the contract says + cache even all-null so we don't re-hammer the backend for known-empty + coordinates.""" + + class _NullCounting: + def __init__(self) -> None: + self.calls = 0 + + async def reverse(self, lat: float, lon: float) -> dict[str, Any]: + self.calls += 1 + return all_null_bundle() + + cache = EnrichmentCache(cache_path, ttl_s=3600) + backend = _NullCounting() + enricher = GeocoderEnricher(backend, cache=cache) + await enricher.enrich({"lat": 1.0, "lon": 1.0}) + await enricher.enrich({"lat": 1.0, "lon": 1.0}) + assert backend.calls == 1 + cached = await cache.get("geocoder", 1.0, 1.0) + assert cached == all_null_bundle() diff --git a/tests/test_firms.py b/tests/test_firms.py index bfe629a..2ab1e1d 100644 --- a/tests/test_firms.py +++ b/tests/test_firms.py @@ -421,3 +421,58 @@ class TestApplyConfig: assert adapter._satellites == ["VIIRS_NOAA20_NRT"] await adapter.shutdown() + + +class TestEnrichmentIntegration: + """FIRMS is the PR J enrichment pilot.""" + + def test_enrichment_locations_declared_and_resolvable(self, temp_db_path, mock_config_store): + """FIRMS declares enrichment_locations and the declared paths actually + resolve to coordinates in a real event's data — verified structurally, + not by hardcoding the literal tuple.""" + locations = getattr(FIRMSAdapter, "enrichment_locations") + assert isinstance(locations, list) and len(locations) >= 1 + for tup in locations: + assert isinstance(tup, tuple) and len(tup) == 2 + assert all(isinstance(p, str) for p in tup) + + config = make_adapter_config() + adapter = FIRMSAdapter( + config=config, + config_store=mock_config_store, + cursor_db_path=temp_db_path, + ) + rows = adapter._parse_csv(SAMPLE_CSV, "VIIRS_SNPP_NRT") + event = adapter._row_to_event(rows[0], "VIIRS_SNPP_NRT") + # Every declared (lat_path, lon_path) must resolve to a float in data. + for lat_path, lon_path in locations: + assert isinstance(event.data.get(lat_path), float) + assert isinstance(event.data.get(lon_path), float) + + @pytest.mark.asyncio + async def test_event_passes_through_supervisor_enrichment( + self, tmp_path, temp_db_path, mock_config_store + ): + """A FIRMS event run through the supervisor's enrichment stage emerges + with data._enriched.geocoder populated (all-null under NoOpBackend).""" + from central.config_models import EnrichmentConfig + from central.enrichment.geocoder import all_null_bundle + from central.supervisor import apply_enrichment, build_enrichers + + config = make_adapter_config() + adapter = FIRMSAdapter( + config=config, + config_store=mock_config_store, + cursor_db_path=temp_db_path, + ) + rows = adapter._parse_csv(SAMPLE_CSV, "VIIRS_SNPP_NRT") + event = adapter._row_to_event(rows[0], "VIIRS_SNPP_NRT") + assert "_enriched" not in event.data + + enrichers = build_enrichers( + EnrichmentConfig(), cache_db_path=tmp_path / "enrichment_cache.db" + ) + await apply_enrichment(event, adapter.enrichment_locations, enrichers) + + assert "_enriched" in event.data + assert event.data["_enriched"]["geocoder"] == all_null_bundle() diff --git a/tests/test_geocoder_enricher.py b/tests/test_geocoder_enricher.py new file mode 100644 index 0000000..df0e1fc --- /dev/null +++ b/tests/test_geocoder_enricher.py @@ -0,0 +1,65 @@ +"""Tests for GeocoderEnricher with the default NoOpBackend.""" + +from typing import Any + +import pytest + +from central.enrichment.backends.no_op import NoOpBackend +from central.enrichment.cache import EnrichmentCache +from central.enrichment.geocoder import ( + GEOCODER_FIELDS, + GeocoderEnricher, + all_null_bundle, +) + + +@pytest.mark.asyncio +async def test_noop_backend_returns_all_null_bundle(): + enricher = GeocoderEnricher(NoOpBackend()) + result = await enricher.enrich({"lat": 45.0, "lon": -116.0}) + assert result == all_null_bundle() + assert all(v is None for v in result.values()) + + +@pytest.mark.asyncio +async def test_field_set_matches_locked_protocol(): + """Every field in the locked GEOCODER_FIELDS set is present (all None for + NoOpBackend), and no extra keys leak through — bidirectional equality.""" + enricher = GeocoderEnricher(NoOpBackend()) + result = await enricher.enrich({"lat": 1.0, "lon": 2.0}) + assert set(result.keys()) == set(GEOCODER_FIELDS) + + +@pytest.mark.asyncio +async def test_missing_coords_returns_all_null_without_backend_call(): + class _Tripwire: + async def reverse(self, lat: float, lon: float) -> dict[str, Any]: + raise AssertionError("backend must not be called for null coords") + + enricher = GeocoderEnricher(_Tripwire()) + assert await enricher.enrich({"lat": None, "lon": None}) == all_null_bundle() # type: ignore[dict-item] + assert await enricher.enrich({}) == all_null_bundle() + + +@pytest.mark.asyncio +async def test_enricher_name_is_geocoder(): + """The name keys the result under event.data['_enriched'][name].""" + assert GeocoderEnricher(NoOpBackend()).name == "geocoder" + + +@pytest.mark.asyncio +async def test_sequential_calls_same_coords_hit_cache(tmp_path): + class _CountingNoOp: + def __init__(self) -> None: + self.calls = 0 + + async def reverse(self, lat: float, lon: float) -> dict[str, Any]: + self.calls += 1 + return all_null_bundle() + + cache = EnrichmentCache(tmp_path / "c.db", ttl_s=3600) + backend = _CountingNoOp() + enricher = GeocoderEnricher(backend, cache=cache) + for _ in range(5): + await enricher.enrich({"lat": 33.5, "lon": -111.9}) + assert backend.calls == 1, "repeated identical coords must collapse to one backend call"