Merge refactor/a1-self-describing-adapters: self-describing adapter pattern

- Add display_name, description, settings_schema (Pydantic), requires_api_key, wizard_order, default_cadence_s to SourceAdapter ABC
- Implement in NWSAdapter, FIRMSAdapter, USGSQuakeAdapter
- Auto-discovery via pkgutil.iter_modules
- Fix quake stream bug (events now route to CENTRAL_QUAKE)
- 308 tests pass, live verified on CT104

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Matt Johnson 2026-05-18 22:49:42 +00:00
commit 87f46e8b35
11 changed files with 512 additions and 342 deletions

View file

@ -4,6 +4,8 @@ from abc import ABC, abstractmethod
from collections.abc import AsyncIterator from collections.abc import AsyncIterator
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from pydantic import BaseModel
if TYPE_CHECKING: if TYPE_CHECKING:
from central.config_models import AdapterConfig from central.config_models import AdapterConfig
@ -16,9 +18,24 @@ class SourceAdapter(ABC):
Adapters yield Events. The supervisor handles scheduling, Adapters yield Events. The supervisor handles scheduling,
CloudEvents wrapping, publish, and metadata heartbeats. CloudEvents wrapping, publish, and metadata heartbeats.
Class attributes that subclasses must define:
name: Short identifier, e.g. "nws"
display_name: Human-readable name for GUI
description: Short description of the adapter
settings_schema: Pydantic model class for adapter settings
requires_api_key: Key alias if API key required, else None
wizard_order: Order in setup wizard (None = not in wizard)
default_cadence_s: Default polling interval in seconds
""" """
name: str # short identifier, e.g. "nws" name: str
display_name: str
description: str
settings_schema: type[BaseModel]
requires_api_key: str | None = None
wizard_order: int | None = None
default_cadence_s: int
@abstractmethod @abstractmethod
async def poll(self) -> AsyncIterator[Event]: async def poll(self) -> AsyncIterator[Event]:
@ -40,6 +57,16 @@ class SourceAdapter(ABC):
""" """
... ...
@abstractmethod
def subject_for(self, event: Event) -> str:
"""
Compute the NATS subject for an event.
Each adapter knows its own subject hierarchy. The supervisor
calls this to determine where to publish each event.
"""
...
async def startup(self) -> None: async def startup(self) -> None:
"""Optional lifecycle hook called before first poll.""" """Optional lifecycle hook called before first poll."""
pass pass

View file

@ -18,6 +18,8 @@ from tenacity import (
) )
from central.adapter import SourceAdapter from central.adapter import SourceAdapter
from pydantic import BaseModel
from central.config_models import AdapterConfig, RegionConfig from central.config_models import AdapterConfig, RegionConfig
from central.config_store import ConfigStore from central.config_store import ConfigStore
from central.models import Event, Geo from central.models import Event, Geo
@ -49,10 +51,23 @@ SEVERITY_MAP = {
} }
class FIRMSSettings(BaseModel):
"""Settings schema for FIRMS adapter."""
api_key_alias: str = "firms"
satellites: list[str] = ["VIIRS_SNPP_NRT", "VIIRS_NOAA20_NRT"]
region: RegionConfig | None = None
class FIRMSAdapter(SourceAdapter): class FIRMSAdapter(SourceAdapter):
"""NASA FIRMS fire hotspot adapter.""" """NASA FIRMS fire hotspot adapter."""
name = "firms" name = "firms"
display_name = "NASA FIRMS Fire Hotspots"
description = "Near-real-time satellite-detected fire hotspots from NASA FIRMS."
settings_schema = FIRMSSettings
requires_api_key = "firms"
wizard_order = 2
default_cadence_s = 300
def __init__( def __init__(
self, self,
@ -116,6 +131,15 @@ class FIRMSAdapter(SourceAdapter):
}, },
) )
def subject_for(self, event: Event) -> str:
"""Compute NATS subject for a fire hotspot event.
Subject format: central.fire.hotspot.<satellite>.<confidence>
The category already contains this structure.
"""
return f"central.{event.category}"
async def startup(self) -> None: async def startup(self) -> None:
"""Initialize HTTP session, dedup tracker, and fetch API key.""" """Initialize HTTP session, dedup tracker, and fetch API key."""
# Fetch API key # Fetch API key
@ -417,14 +441,3 @@ class FIRMSAdapter(SourceAdapter):
}, },
) )
def subject_for_fire_hotspot(ev: Event) -> str:
"""Compute the NATS subject for a fire hotspot event.
Subject format: central.fire.hotspot.<satellite>.<confidence>
The category already contains the satellite and confidence info,
so we just prefix with 'central.'.
"""
# category is "fire.hotspot.<satellite>.<confidence>"
return f"central.{ev.category}"

View file

@ -19,6 +19,8 @@ from tenacity import (
from central import __version__ from central import __version__
from central.adapter import SourceAdapter from central.adapter import SourceAdapter
from pydantic import BaseModel
from central.config_models import AdapterConfig, RegionConfig from central.config_models import AdapterConfig, RegionConfig
from central.config_store import ConfigStore from central.config_store import ConfigStore
from central.models import Event, Geo from central.models import Event, Geo
@ -189,10 +191,22 @@ def _build_regions(same_codes: list[str], ugc_codes: list[str]) -> list[str]:
return sorted(regions) return sorted(regions)
class NWSSettings(BaseModel):
"""Settings schema for NWS adapter."""
contact_email: str = ""
region: RegionConfig | None = None
class NWSAdapter(SourceAdapter): class NWSAdapter(SourceAdapter):
"""National Weather Service alerts adapter.""" """National Weather Service alerts adapter."""
name = "nws" name = "nws"
display_name = "NWS Weather Alerts"
description = "National Weather Service active alerts via api.weather.gov."
settings_schema = NWSSettings
requires_api_key = None
wizard_order = 1
default_cadence_s = 60
def __init__( def __init__(
self, self,
@ -234,6 +248,35 @@ class NWSAdapter(SourceAdapter):
}, },
) )
def subject_for(self, event: Event) -> str:
"""Compute NATS subject for a weather alert.
Subject format: central.wx.alert.us.<state>.<type>.<code>
where type is 'county' or 'zone' based on primary_region format.
"""
prefix = "central.wx"
if event.geo.primary_region is None:
return f"{prefix}.alert.us.unknown"
region = event.geo.primary_region
# Parse US-<STATE>-<CODE> format
parts = region.split("-")
if len(parts) < 3 or parts[0] != "US":
return f"{prefix}.alert.us.unknown"
state = parts[1].lower()
code = "-".join(parts[2:]) # Handle multi-part names
if code.startswith("Z") and len(code) >= 2 and code[1:].isdigit():
# Zone code like Z033
return f"{prefix}.alert.us.{state}.zone.{code.lower()}"
else:
# County name
return f"{prefix}.alert.us.{state}.county.{code.lower()}"
def _geometry_intersects_region(self, geometry: dict[str, Any] | None) -> bool: def _geometry_intersects_region(self, geometry: dict[str, Any] | None) -> bool:
"""Check if feature geometry intersects configured region bbox. """Check if feature geometry intersects configured region bbox.

View file

@ -17,6 +17,8 @@ from tenacity import (
) )
from central.adapter import SourceAdapter from central.adapter import SourceAdapter
from pydantic import BaseModel
from central.config_models import AdapterConfig, RegionConfig from central.config_models import AdapterConfig, RegionConfig
from central.config_store import ConfigStore from central.config_store import ConfigStore
from central.models import Event, Geo from central.models import Event, Geo
@ -60,10 +62,22 @@ def magnitude_to_severity(mag: float) -> int:
return 5 return 5
class USGSQuakeSettings(BaseModel):
"""Settings schema for USGS quake adapter."""
feed: str = "all_hour"
region: RegionConfig | None = None
class USGSQuakeAdapter(SourceAdapter): class USGSQuakeAdapter(SourceAdapter):
"""USGS Earthquake Hazards Program adapter.""" """USGS Earthquake Hazards Program adapter."""
name = "usgs_quake" name = "usgs_quake"
display_name = "USGS Earthquakes"
description = "USGS earthquake feed (configurable window)."
settings_schema = USGSQuakeSettings
requires_api_key = None
wizard_order = 3
default_cadence_s = 60
def __init__( def __init__(
self, self,
@ -398,3 +412,9 @@ class USGSQuakeAdapter(SourceAdapter):
new_count += 1 new_count += 1
logger.info("USGS quake yielded events", extra={"count": new_count}) logger.info("USGS quake yielded events", extra={"count": new_count})
def subject_for(self, event: Event) -> str:
"""Return NATS subject for quake event."""
return f"central.{event.category}"

View file

@ -32,48 +32,3 @@ class Event(BaseModel):
data: dict[str, Any] # adapter-specific payload data: dict[str, Any] # adapter-specific payload
def subject_for_event(ev: Event) -> str:
"""
Compute the NATS subject for an event based on its category.
Dispatch by category prefix:
- fire.*: returns central.<category> directly
- wx.*: uses weather alert subject logic
Weather alert subjects:
central.wx.alert.us.<state_lower>.county.<county_lower>
or
central.wx.alert.us.<state_lower>.zone.<zone_lower>
based on whether the primary_region encodes a county or a zone.
Fire hotspot subjects:
central.fire.hotspot.<satellite>.<confidence>
"""
# Fire events: subject is just central.<category>
if ev.category.startswith("fire."):
return f"central.{ev.category}"
# Weather events: use geo-based subject logic
prefix = "central.wx"
if ev.geo.primary_region is None:
return f"{prefix}.alert.us.unknown"
region = ev.geo.primary_region
# Parse US-<STATE>-<CODE> format
# County codes are like "Ada", "Canyon" (names)
# Zone codes start with "Z" like "Z033"
parts = region.split("-")
if len(parts) < 3 or parts[0] != "US":
return f"{prefix}.alert.us.unknown"
state = parts[1].lower()
code = "-".join(parts[2:]) # Handle multi-part names like "Payette-Washington"
if code.startswith("Z") and len(code) >= 2 and code[1:].isdigit():
# Zone code like Z033
return f"{prefix}.alert.us.{state}.zone.{code.lower()}"
else:
# County name
return f"{prefix}.alert.us.{state}.county.{code.lower()}"

View file

@ -13,24 +13,40 @@ from typing import Any
import nats import nats
from nats.js import JetStreamContext from nats.js import JetStreamContext
import importlib
import pkgutil
from central.adapter import SourceAdapter from central.adapter import SourceAdapter
from central.adapters.nws import NWSAdapter
from central.adapters.firms import FIRMSAdapter
from central.adapters.usgs_quake import USGSQuakeAdapter
from central.cloudevents_wire import wrap_event from central.cloudevents_wire import wrap_event
from central.config_models import AdapterConfig from central.config_models import AdapterConfig
from central.config_source import ConfigSource, DbConfigSource from central.config_source import ConfigSource, DbConfigSource
from central.config_store import ConfigStore from central.config_store import ConfigStore
from central.bootstrap_config import get_settings from central.bootstrap_config import get_settings
from central.models import subject_for_event
from central.stream_manager import StreamManager from central.stream_manager import StreamManager
import central.adapters
# Adapter registry - add new adapters here def discover_adapters() -> dict[str, type[SourceAdapter]]:
_ADAPTER_REGISTRY: dict[str, type[SourceAdapter]] = { """Auto-discover adapter classes from central.adapters package."""
"nws": NWSAdapter, registry: dict[str, type[SourceAdapter]] = {}
"firms": FIRMSAdapter, for module_info in pkgutil.iter_modules(central.adapters.__path__):
"usgs_quake": USGSQuakeAdapter, try:
} module = importlib.import_module(f"central.adapters.{module_info.name}")
except Exception as e:
logger.error(
"Failed to import adapter module",
extra={"module": module_info.name, "error": str(e)},
)
continue
for attr_name in dir(module):
attr = getattr(module, attr_name)
if (
isinstance(attr, type)
and issubclass(attr, SourceAdapter)
and attr is not SourceAdapter
and hasattr(attr, "name")
):
registry[attr.name] = attr
return registry
CURSOR_DB_PATH = Path("/var/lib/central/cursors.db") CURSOR_DB_PATH = Path("/var/lib/central/cursors.db")
@ -114,6 +130,7 @@ class Supervisor:
self._config_store = config_store self._config_store = config_store
self._nats_url = nats_url self._nats_url = nats_url
self._cloudevents_config = cloudevents_config self._cloudevents_config = cloudevents_config
self._adapters = discover_adapters()
self._nc: nats.NATS | None = None self._nc: nats.NATS | None = None
self._js: JetStreamContext | None = None self._js: JetStreamContext | None = None
self._stream_manager: StreamManager | None = None self._stream_manager: StreamManager | None = None
@ -161,7 +178,7 @@ class Supervisor:
def _create_adapter(self, config: AdapterConfig) -> SourceAdapter: def _create_adapter(self, config: AdapterConfig) -> SourceAdapter:
"""Create an adapter instance based on config name.""" """Create an adapter instance based on config name."""
cls = _ADAPTER_REGISTRY.get(config.name) cls = self._adapters.get(config.name)
if cls is None: if cls is None:
raise ValueError(f"Unknown adapter type: {config.name}") raise ValueError(f"Unknown adapter type: {config.name}")
return cls( return cls(
@ -232,7 +249,7 @@ class Supervisor:
# Build CloudEvent (uses defaults if no config provided) # Build CloudEvent (uses defaults if no config provided)
envelope, msg_id = wrap_event(event, self._cloudevents_config) envelope, msg_id = wrap_event(event, self._cloudevents_config)
subject = subject_for_event(event) subject = state.adapter.subject_for(event)
# Publish # Publish
await self._publish_event(subject, envelope, msg_id) await self._publish_event(subject, envelope, msg_id)

View file

@ -10,7 +10,6 @@ from central.adapters.firms import (
FIRMSAdapter, FIRMSAdapter,
CONFIDENCE_MAP, CONFIDENCE_MAP,
SATELLITE_SHORT, SATELLITE_SHORT,
subject_for_fire_hotspot,
) )
from central.config_models import AdapterConfig, RegionConfig from central.config_models import AdapterConfig, RegionConfig
from central.models import Event, Geo from central.models import Event, Geo
@ -285,7 +284,14 @@ class TestDeduplication:
class TestSubjectGeneration: class TestSubjectGeneration:
"""Test subject generation for fire hotspots.""" """Test subject generation for fire hotspots."""
def test_subject_format(self): @pytest.mark.asyncio
async def test_subject_format(self, temp_db_path, mock_config_store):
config = make_adapter_config()
adapter = FIRMSAdapter(
config=config,
config_store=mock_config_store,
cursor_db_path=temp_db_path,
)
event = Event( event = Event(
id="test", id="test",
adapter="firms", adapter="firms",
@ -296,10 +302,17 @@ class TestSubjectGeneration:
data={}, data={},
) )
subject = subject_for_fire_hotspot(event) subject = adapter.subject_for(event)
assert subject == "central.fire.hotspot.viirs_snpp.high" assert subject == "central.fire.hotspot.viirs_snpp.high"
def test_subject_nominal_confidence(self): @pytest.mark.asyncio
async def test_subject_nominal_confidence(self, temp_db_path, mock_config_store):
config = make_adapter_config()
adapter = FIRMSAdapter(
config=config,
config_store=mock_config_store,
cursor_db_path=temp_db_path,
)
event = Event( event = Event(
id="test", id="test",
adapter="firms", adapter="firms",
@ -310,7 +323,7 @@ class TestSubjectGeneration:
data={}, data={},
) )
subject = subject_for_fire_hotspot(event) subject = adapter.subject_for(event)
assert subject == "central.fire.hotspot.viirs_noaa20.nominal" assert subject == "central.fire.hotspot.viirs_noaa20.nominal"

View file

@ -4,7 +4,7 @@ from datetime import datetime, timezone
import pytest import pytest
from central.models import Event, Geo, subject_for_event from central.models import Event, Geo
from central.config import NWSAdapterConfig, CloudEventsConfig, NATSConfig, PostgresConfig, Config from central.config import NWSAdapterConfig, CloudEventsConfig, NATSConfig, PostgresConfig, Config
from central.cloudevents_wire import wrap_event from central.cloudevents_wire import wrap_event
@ -57,47 +57,6 @@ def sample_config() -> Config:
) )
class TestSubjectForEvent:
"""Tests for subject_for_event helper."""
def test_county_subject(self, sample_event: Event) -> None:
"""County codes produce county subject."""
subject = subject_for_event(sample_event)
assert subject == "central.wx.alert.us.id.county.ada"
def test_zone_subject(self, sample_geo: Geo) -> None:
"""Zone codes produce zone subject."""
geo = Geo(
centroid=sample_geo.centroid,
bbox=sample_geo.bbox,
regions=["US-ID-Z033"],
primary_region="US-ID-Z033",
)
event = Event(
id="test-zone",
adapter="nws",
category="wx.alert.winter_storm_warning",
time=datetime(2024, 1, 15, 12, 0, 0, tzinfo=timezone.utc),
geo=geo,
data={},
)
subject = subject_for_event(event)
assert subject == "central.wx.alert.us.id.zone.z033"
def test_unknown_subject(self, sample_event: Event) -> None:
"""Missing primary_region produces unknown subject."""
geo = Geo(regions=[], primary_region=None)
event = Event(
id="test-unknown",
adapter="nws",
category="wx.alert.test",
time=datetime(2024, 1, 15, 12, 0, 0, tzinfo=timezone.utc),
geo=geo,
data={},
)
subject = subject_for_event(event)
assert subject == "central.wx.alert.us.unknown"
class TestCloudEventsWire: class TestCloudEventsWire:
"""Tests for CloudEvents wire format.""" """Tests for CloudEvents wire format."""

View file

@ -17,7 +17,6 @@ from central.adapters.nws import (
SEVERITY_MAP, SEVERITY_MAP,
) )
from central.config_models import AdapterConfig from central.config_models import AdapterConfig
from central.models import subject_for_event
# Sample NWS GeoJSON features for testing # Sample NWS GeoJSON features for testing
@ -272,7 +271,7 @@ class TestSubjectDerivation:
def test_county_subject(self, adapter: NWSAdapter) -> None: def test_county_subject(self, adapter: NWSAdapter) -> None:
event = adapter._normalize_feature(SAMPLE_FEATURE_ID) event = adapter._normalize_feature(SAMPLE_FEATURE_ID)
assert event is not None assert event is not None
subject = subject_for_event(event) subject = adapter.subject_for(event)
# Primary region should be alphabetically first # Primary region should be alphabetically first
# Could be county or zone depending on sort order # Could be county or zone depending on sort order
assert subject.startswith("central.wx.alert.us.id.") assert subject.startswith("central.wx.alert.us.id.")
@ -294,7 +293,7 @@ class TestSubjectDerivation:
} }
event = adapter._normalize_feature(feature) event = adapter._normalize_feature(feature)
assert event is not None assert event is not None
subject = subject_for_event(event) subject = adapter.subject_for(event)
assert "zone" in subject assert "zone" in subject

View file

@ -200,76 +200,77 @@ class TestEnableDisableEnableIntegration:
supervisor._js = mock_nats.jetstream() supervisor._js = mock_nats.jetstream()
# Patch NWSAdapter to use our mock # Patch NWSAdapter to use our mock
with patch("central.supervisor.NWSAdapter", MockNWSAdapter): # Inject mock adapter into supervisor's registry
# Start supervisor (starts adapter) supervisor._adapters["nws"] = MockNWSAdapter
await supervisor._start_adapter(initial_config) # Start supervisor (starts adapter)
await supervisor._start_adapter(initial_config)
state = supervisor._adapter_states.get("nws") state = supervisor._adapter_states.get("nws")
assert state is not None assert state is not None
assert adapter_is_running(state) assert adapter_is_running(state)
# Simulate completed poll 5 minutes ago # Simulate completed poll 5 minutes ago
state.last_completed_poll = datetime.now(timezone.utc) - timedelta(minutes=5) state.last_completed_poll = datetime.now(timezone.utc) - timedelta(minutes=5)
saved_last_poll = state.last_completed_poll saved_last_poll = state.last_completed_poll
# Disable adapter # Disable adapter
disabled_config = AdapterConfig( disabled_config = AdapterConfig(
name="nws", name="nws",
enabled=False, enabled=False,
cadence_s=60, cadence_s=60,
settings={"states": ["ID"], "contact_email": "test@test.com"}, settings={"states": ["ID"], "contact_email": "test@test.com"},
paused_at=None, paused_at=None,
updated_at=datetime.now(timezone.utc), updated_at=datetime.now(timezone.utc),
) )
config_source.set_adapter(disabled_config) config_source.set_adapter(disabled_config)
await supervisor._on_config_change("adapters", "nws") await supervisor._on_config_change("adapters", "nws")
# Verify stopped but state preserved (THIS IS THE KEY CHECK) # Verify stopped but state preserved (THIS IS THE KEY CHECK)
# On unfixed code, state will be NONE because pop() removes it # On unfixed code, state will be NONE because pop() removes it
# On fixed code, state still exists with is_running=False # On fixed code, state still exists with is_running=False
state = supervisor._adapter_states.get("nws") state = supervisor._adapter_states.get("nws")
assert state is not None, ( assert state is not None, (
"State was removed on stop! This violates the rate-limit guarantee. " "State was removed on stop! This violates the rate-limit guarantee. "
"State should be preserved to maintain last_completed_poll." "State should be preserved to maintain last_completed_poll."
) )
assert not adapter_is_running(state) assert not adapter_is_running(state)
assert state.last_completed_poll == saved_last_poll assert state.last_completed_poll == saved_last_poll
# Re-enable adapter # Re-enable adapter
reenabled_config = AdapterConfig( reenabled_config = AdapterConfig(
name="nws", name="nws",
enabled=True, enabled=True,
cadence_s=60, cadence_s=60,
settings={"states": ["ID"], "contact_email": "test@test.com"}, settings={"states": ["ID"], "contact_email": "test@test.com"},
paused_at=None, paused_at=None,
updated_at=datetime.now(timezone.utc), updated_at=datetime.now(timezone.utc),
) )
config_source.set_adapter(reenabled_config) config_source.set_adapter(reenabled_config)
await supervisor._on_config_change("adapters", "nws") await supervisor._on_config_change("adapters", "nws")
# Verify restarted with preserved last_completed_poll # Verify restarted with preserved last_completed_poll
state = supervisor._adapter_states.get("nws") state = supervisor._adapter_states.get("nws")
assert state is not None assert state is not None
assert adapter_is_running(state) assert adapter_is_running(state)
assert state.last_completed_poll == saved_last_poll assert state.last_completed_poll == saved_last_poll
# The loop should detect that last_poll + cadence is in the past # The loop should detect that last_poll + cadence is in the past
# and poll immediately. Let's verify by checking the wait time logic. # and poll immediately. Let's verify by checking the wait time logic.
now = datetime.now(timezone.utc) now = datetime.now(timezone.utc)
next_poll_at = saved_last_poll.timestamp() + 60 # cadence = 60s next_poll_at = saved_last_poll.timestamp() + 60 # cadence = 60s
wait_time = max(0, next_poll_at - now.timestamp()) wait_time = max(0, next_poll_at - now.timestamp())
# last_poll was 5 minutes ago, cadence is 60s # last_poll was 5 minutes ago, cadence is 60s
# next_poll_at = 5_minutes_ago + 60s = 4_minutes_ago # next_poll_at = 5_minutes_ago + 60s = 4_minutes_ago
# wait_time should be 0 (poll immediately) # wait_time should be 0 (poll immediately)
assert wait_time == 0, ( assert wait_time == 0, (
f"Expected immediate poll (wait=0), got wait={wait_time}s. " f"Expected immediate poll (wait=0), got wait={wait_time}s. "
f"last_poll was {saved_last_poll}, now is {now}" f"last_poll was {saved_last_poll}, now is {now}"
) )
# Cleanup # Cleanup
supervisor._shutdown_event.set() supervisor._shutdown_event.set()
await cleanup_adapter(supervisor, "nws") await cleanup_adapter(supervisor, "nws")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_enable_disable_enable_gap_shorter_than_cadence( async def test_enable_disable_enable_gap_shorter_than_cadence(
@ -308,75 +309,76 @@ class TestEnableDisableEnableIntegration:
supervisor._nc = mock_nats supervisor._nc = mock_nats
supervisor._js = mock_nats.jetstream() supervisor._js = mock_nats.jetstream()
with patch("central.supervisor.NWSAdapter", MockNWSAdapter): # Inject mock adapter into supervisor's registry
# Start adapter supervisor._adapters["nws"] = MockNWSAdapter
await supervisor._start_adapter(initial_config) # Start adapter
await supervisor._start_adapter(initial_config)
state = supervisor._adapter_states.get("nws") state = supervisor._adapter_states.get("nws")
assert state is not None assert state is not None
# Simulate completed poll 10 seconds ago # Simulate completed poll 10 seconds ago
state.last_completed_poll = datetime.now(timezone.utc) - timedelta(seconds=10) state.last_completed_poll = datetime.now(timezone.utc) - timedelta(seconds=10)
saved_last_poll = state.last_completed_poll saved_last_poll = state.last_completed_poll
# Disable adapter # Disable adapter
disabled_config = AdapterConfig( disabled_config = AdapterConfig(
name="nws", name="nws",
enabled=False, enabled=False,
cadence_s=60, cadence_s=60,
settings={"states": ["ID"], "contact_email": "test@test.com"}, settings={"states": ["ID"], "contact_email": "test@test.com"},
paused_at=None, paused_at=None,
updated_at=datetime.now(timezone.utc), updated_at=datetime.now(timezone.utc),
) )
config_source.set_adapter(disabled_config) config_source.set_adapter(disabled_config)
await supervisor._on_config_change("adapters", "nws") await supervisor._on_config_change("adapters", "nws")
# Verify stopped but state preserved (THIS IS THE KEY CHECK) # Verify stopped but state preserved (THIS IS THE KEY CHECK)
# On unfixed code, state will be NONE because pop() removes it # On unfixed code, state will be NONE because pop() removes it
# On fixed code, state still exists with is_running=False # On fixed code, state still exists with is_running=False
state = supervisor._adapter_states.get("nws") state = supervisor._adapter_states.get("nws")
assert state is not None, ( assert state is not None, (
"State was removed on stop! This violates the rate-limit guarantee. " "State was removed on stop! This violates the rate-limit guarantee. "
"State should be preserved to maintain last_completed_poll." "State should be preserved to maintain last_completed_poll."
) )
assert not adapter_is_running(state) assert not adapter_is_running(state)
assert state.last_completed_poll == saved_last_poll assert state.last_completed_poll == saved_last_poll
# Re-enable adapter (simulate 20 seconds later, but we're just # Re-enable adapter (simulate 20 seconds later, but we're just
# checking the rate limit logic) # checking the rate limit logic)
reenabled_config = AdapterConfig( reenabled_config = AdapterConfig(
name="nws", name="nws",
enabled=True, enabled=True,
cadence_s=60, cadence_s=60,
settings={"states": ["ID"], "contact_email": "test@test.com"}, settings={"states": ["ID"], "contact_email": "test@test.com"},
paused_at=None, paused_at=None,
updated_at=datetime.now(timezone.utc), updated_at=datetime.now(timezone.utc),
) )
config_source.set_adapter(reenabled_config) config_source.set_adapter(reenabled_config)
await supervisor._on_config_change("adapters", "nws") await supervisor._on_config_change("adapters", "nws")
# Verify restarted with preserved last_completed_poll # Verify restarted with preserved last_completed_poll
state = supervisor._adapter_states.get("nws") state = supervisor._adapter_states.get("nws")
assert state is not None assert state is not None
assert adapter_is_running(state) assert adapter_is_running(state)
assert state.last_completed_poll == saved_last_poll assert state.last_completed_poll == saved_last_poll
# The loop should detect that last_poll + cadence is still in the future # The loop should detect that last_poll + cadence is still in the future
# and wait until then. # and wait until then.
now = datetime.now(timezone.utc) now = datetime.now(timezone.utc)
next_poll_at = saved_last_poll.timestamp() + 60 next_poll_at = saved_last_poll.timestamp() + 60
wait_time = max(0, next_poll_at - now.timestamp()) wait_time = max(0, next_poll_at - now.timestamp())
# last_poll was ~10 seconds ago, cadence is 60s # last_poll was ~10 seconds ago, cadence is 60s
# wait_time should be ~50s (60 - 10 = 50) # wait_time should be ~50s (60 - 10 = 50)
assert 45 < wait_time < 55, ( assert 45 < wait_time < 55, (
f"Expected ~50s wait (respecting rate limit), got wait={wait_time}s. " f"Expected ~50s wait (respecting rate limit), got wait={wait_time}s. "
f"Rate limit violated: poll would happen before last_poll + cadence" f"Rate limit violated: poll would happen before last_poll + cadence"
) )
# Cleanup # Cleanup
supervisor._shutdown_event.set() supervisor._shutdown_event.set()
await cleanup_adapter(supervisor, "nws") await cleanup_adapter(supervisor, "nws")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_enable_disable_delete_readd_fresh_state( async def test_enable_disable_delete_readd_fresh_state(
@ -414,60 +416,61 @@ class TestEnableDisableEnableIntegration:
supervisor._nc = mock_nats supervisor._nc = mock_nats
supervisor._js = mock_nats.jetstream() supervisor._js = mock_nats.jetstream()
with patch("central.supervisor.NWSAdapter", MockNWSAdapter): # Inject mock adapter into supervisor's registry
# Start adapter supervisor._adapters["nws"] = MockNWSAdapter
await supervisor._start_adapter(initial_config) # Start adapter
await supervisor._start_adapter(initial_config)
state = supervisor._adapter_states.get("nws") state = supervisor._adapter_states.get("nws")
assert state is not None assert state is not None
# Simulate completed poll 10 seconds ago # Simulate completed poll 10 seconds ago
state.last_completed_poll = datetime.now(timezone.utc) - timedelta(seconds=10) state.last_completed_poll = datetime.now(timezone.utc) - timedelta(seconds=10)
# Disable adapter # Disable adapter
disabled_config = AdapterConfig( disabled_config = AdapterConfig(
name="nws", name="nws",
enabled=False, enabled=False,
cadence_s=60, cadence_s=60,
settings={"states": ["ID"], "contact_email": "test@test.com"}, settings={"states": ["ID"], "contact_email": "test@test.com"},
paused_at=None, paused_at=None,
updated_at=datetime.now(timezone.utc), updated_at=datetime.now(timezone.utc),
) )
config_source.set_adapter(disabled_config) config_source.set_adapter(disabled_config)
await supervisor._on_config_change("adapters", "nws") await supervisor._on_config_change("adapters", "nws")
# DELETE adapter from DB (remove from config source) # DELETE adapter from DB (remove from config source)
config_source.set_adapter(None, name="nws") config_source.set_adapter(None, name="nws")
await supervisor._on_config_change("adapters", "nws") await supervisor._on_config_change("adapters", "nws")
# Verify adapter fully removed # Verify adapter fully removed
assert "nws" not in supervisor._adapter_states assert "nws" not in supervisor._adapter_states
# Re-add adapter with same name # Re-add adapter with same name
new_config = AdapterConfig( new_config = AdapterConfig(
name="nws", name="nws",
enabled=True, enabled=True,
cadence_s=60, cadence_s=60,
settings={"states": ["ID"], "contact_email": "test@test.com"}, settings={"states": ["ID"], "contact_email": "test@test.com"},
paused_at=None, paused_at=None,
updated_at=datetime.now(timezone.utc), updated_at=datetime.now(timezone.utc),
) )
config_source.set_adapter(new_config) config_source.set_adapter(new_config)
await supervisor._on_config_change("adapters", "nws") await supervisor._on_config_change("adapters", "nws")
# Verify new adapter started fresh # Verify new adapter started fresh
state = supervisor._adapter_states.get("nws") state = supervisor._adapter_states.get("nws")
assert state is not None assert state is not None
assert adapter_is_running(state) assert adapter_is_running(state)
# last_completed_poll should be None (fresh adapter) # last_completed_poll should be None (fresh adapter)
assert state.last_completed_poll is None, ( assert state.last_completed_poll is None, (
f"Expected None (fresh adapter), got {state.last_completed_poll}. " f"Expected None (fresh adapter), got {state.last_completed_poll}. "
f"Preserved state not cleared on delete." f"Preserved state not cleared on delete."
) )
# Cleanup # Cleanup
supervisor._shutdown_event.set() supervisor._shutdown_event.set()
await cleanup_adapter(supervisor, "nws") await cleanup_adapter(supervisor, "nws")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_stop_preserves_state_start_reuses_it( async def test_stop_preserves_state_start_reuses_it(
@ -497,34 +500,35 @@ class TestEnableDisableEnableIntegration:
supervisor._nc = mock_nats supervisor._nc = mock_nats
supervisor._js = mock_nats.jetstream() supervisor._js = mock_nats.jetstream()
with patch("central.supervisor.NWSAdapter", MockNWSAdapter): # Inject mock adapter into supervisor's registry
# Start adapter supervisor._adapters["nws"] = MockNWSAdapter
await supervisor._start_adapter(config) # Start adapter
await supervisor._start_adapter(config)
state = supervisor._adapter_states.get("nws") state = supervisor._adapter_states.get("nws")
state.last_completed_poll = datetime.now(timezone.utc) - timedelta(seconds=30) state.last_completed_poll = datetime.now(timezone.utc) - timedelta(seconds=30)
saved_poll = state.last_completed_poll saved_poll = state.last_completed_poll
# Stop adapter # Stop adapter
await supervisor._stop_adapter("nws") await supervisor._stop_adapter("nws")
# State should still exist # State should still exist
assert "nws" in supervisor._adapter_states assert "nws" in supervisor._adapter_states
state = supervisor._adapter_states["nws"] state = supervisor._adapter_states["nws"]
assert not adapter_is_running(state) assert not adapter_is_running(state)
assert state.last_completed_poll == saved_poll assert state.last_completed_poll == saved_poll
# Restart adapter # Restart adapter
await supervisor._start_adapter(config) await supervisor._start_adapter(config)
# Should reuse existing state # Should reuse existing state
state = supervisor._adapter_states.get("nws") state = supervisor._adapter_states.get("nws")
assert adapter_is_running(state) assert adapter_is_running(state)
assert state.last_completed_poll == saved_poll assert state.last_completed_poll == saved_poll
# Cleanup # Cleanup
supervisor._shutdown_event.set() supervisor._shutdown_event.set()
await cleanup_adapter(supervisor, "nws") await cleanup_adapter(supervisor, "nws")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_remove_adapter_clears_state( async def test_remove_adapter_clears_state(
@ -554,14 +558,15 @@ class TestEnableDisableEnableIntegration:
supervisor._nc = mock_nats supervisor._nc = mock_nats
supervisor._js = mock_nats.jetstream() supervisor._js = mock_nats.jetstream()
with patch("central.supervisor.NWSAdapter", MockNWSAdapter): # Inject mock adapter into supervisor's registry
await supervisor._start_adapter(config) supervisor._adapters["nws"] = MockNWSAdapter
await supervisor._start_adapter(config)
state = supervisor._adapter_states.get("nws") state = supervisor._adapter_states.get("nws")
state.last_completed_poll = datetime.now(timezone.utc) state.last_completed_poll = datetime.now(timezone.utc)
# Remove adapter # Remove adapter
await cleanup_adapter(supervisor, "nws") await cleanup_adapter(supervisor, "nws")
# State should be gone # State should be gone
assert "nws" not in supervisor._adapter_states assert "nws" not in supervisor._adapter_states

View file

@ -480,3 +480,122 @@ class TestApplyConfig:
assert adapter._feed == "all_day" assert adapter._feed == "all_day"
await adapter.shutdown() await adapter.shutdown()
class TestSubjectFor:
"""Test subject_for method for all magnitude tiers."""
@pytest.mark.asyncio
async def test_subject_minor(self, temp_db_path, mock_config_store):
config = make_adapter_config()
adapter = USGSQuakeAdapter(
config=config,
config_store=mock_config_store,
cursor_db_path=temp_db_path,
)
event = Event(
id="test-minor",
adapter="usgs_quake",
category="quake.event.minor",
time=datetime.now(timezone.utc),
severity=0,
geo=Geo(centroid=(-116.0, 45.0)),
data={},
)
assert adapter.subject_for(event) == "central.quake.event.minor"
@pytest.mark.asyncio
async def test_subject_light(self, temp_db_path, mock_config_store):
config = make_adapter_config()
adapter = USGSQuakeAdapter(
config=config,
config_store=mock_config_store,
cursor_db_path=temp_db_path,
)
event = Event(
id="test-light",
adapter="usgs_quake",
category="quake.event.light",
time=datetime.now(timezone.utc),
severity=1,
geo=Geo(centroid=(-116.0, 45.0)),
data={},
)
assert adapter.subject_for(event) == "central.quake.event.light"
@pytest.mark.asyncio
async def test_subject_moderate(self, temp_db_path, mock_config_store):
config = make_adapter_config()
adapter = USGSQuakeAdapter(
config=config,
config_store=mock_config_store,
cursor_db_path=temp_db_path,
)
event = Event(
id="test-moderate",
adapter="usgs_quake",
category="quake.event.moderate",
time=datetime.now(timezone.utc),
severity=2,
geo=Geo(centroid=(-116.0, 45.0)),
data={},
)
assert adapter.subject_for(event) == "central.quake.event.moderate"
@pytest.mark.asyncio
async def test_subject_strong(self, temp_db_path, mock_config_store):
config = make_adapter_config()
adapter = USGSQuakeAdapter(
config=config,
config_store=mock_config_store,
cursor_db_path=temp_db_path,
)
event = Event(
id="test-strong",
adapter="usgs_quake",
category="quake.event.strong",
time=datetime.now(timezone.utc),
severity=3,
geo=Geo(centroid=(-116.0, 45.0)),
data={},
)
assert adapter.subject_for(event) == "central.quake.event.strong"
@pytest.mark.asyncio
async def test_subject_major(self, temp_db_path, mock_config_store):
config = make_adapter_config()
adapter = USGSQuakeAdapter(
config=config,
config_store=mock_config_store,
cursor_db_path=temp_db_path,
)
event = Event(
id="test-major",
adapter="usgs_quake",
category="quake.event.major",
time=datetime.now(timezone.utc),
severity=4,
geo=Geo(centroid=(-116.0, 45.0)),
data={},
)
assert adapter.subject_for(event) == "central.quake.event.major"
@pytest.mark.asyncio
async def test_subject_great(self, temp_db_path, mock_config_store):
config = make_adapter_config()
adapter = USGSQuakeAdapter(
config=config,
config_store=mock_config_store,
cursor_db_path=temp_db_path,
)
event = Event(
id="test-great",
adapter="usgs_quake",
category="quake.event.great",
time=datetime.now(timezone.utc),
severity=5,
geo=Geo(centroid=(-116.0, 45.0)),
data={},
)
assert adapter.subject_for(event) == "central.quake.event.great"