diff --git a/src/central/supervisor.py b/src/central/supervisor.py index 0f3d6fd..b7466fb 100644 --- a/src/central/supervisor.py +++ b/src/central/supervisor.py @@ -72,10 +72,10 @@ class AdapterState: last_completed_poll: datetime | None = None cancel_event: asyncio.Event = field(default_factory=asyncio.Event) - def __post_init__(self) -> None: - # Ensure cancel_event is created - if self.cancel_event is None: - self.cancel_event = asyncio.Event() + @property + def is_running(self) -> bool: + """Check if adapter loop is currently running.""" + return self.task is not None and not self.task.done() class Supervisor: @@ -200,13 +200,8 @@ class Supervisor: state.adapter.bump_last_seen(event.id) continue - # Build CloudEvent - if self._cloudevents_config: - envelope, msg_id = wrap_event(event, self._cloudevents_config) - else: - # Fallback for testing - envelope = {"id": event.id, "data": event.data} - msg_id = event.id + # Build CloudEvent (uses defaults if no config provided) + envelope, msg_id = wrap_event(event, self._cloudevents_config) subject = subject_for_event(event) @@ -247,14 +242,64 @@ class Supervisor: logger.info("Swept old published IDs", extra={"count": swept}) async def _start_adapter(self, config: AdapterConfig) -> None: - """Start an adapter based on its configuration.""" - if config.name in self._adapter_states: - logger.warning( - "Adapter already running", - extra={"adapter": config.name}, + """Start an adapter based on its configuration. + + If the adapter was previously stopped (state exists but task is not running), + reuses the existing state to preserve last_completed_poll for rate limiting. + """ + existing_state = self._adapter_states.get(config.name) + + if existing_state is not None: + if existing_state.is_running: + logger.warning( + "Adapter already running", + extra={"adapter": config.name}, + ) + return + + # Adapter was stopped - restart with preserved state + # Update config and restart the adapter + existing_state.config = config + existing_state.cancel_event.clear() + + # Reinitialize the adapter with new config + nws_config = self._adapter_config_to_nws_config(config) + existing_state.adapter = NWSAdapter( + config=nws_config, + cursor_db_path=CURSOR_DB_PATH, + ) + await existing_state.adapter.startup() + + # Start the loop task + existing_state.task = asyncio.create_task( + self._run_adapter_loop(existing_state) + ) + + # Calculate next poll time for logging + if existing_state.last_completed_poll: + next_poll_at = datetime.fromtimestamp( + existing_state.last_completed_poll.timestamp() + config.cadence_s, + tz=timezone.utc, + ) + if next_poll_at <= datetime.now(timezone.utc): + next_poll_at = datetime.now(timezone.utc) + else: + next_poll_at = datetime.now(timezone.utc) + + logger.info( + "Adapter restarted", + extra={ + "adapter": config.name, + "cadence_s": config.cadence_s, + "preserved_last_poll": existing_state.last_completed_poll.isoformat() + if existing_state.last_completed_poll + else None, + "next_poll": next_poll_at.isoformat(), + }, ) return + # New adapter - create fresh state if config.name == "nws": nws_config = self._adapter_config_to_nws_config(config) adapter = NWSAdapter( @@ -285,11 +330,20 @@ class Supervisor: ) async def _stop_adapter(self, name: str) -> None: - """Stop a running adapter.""" - state = self._adapter_states.pop(name, None) + """Stop a running adapter but preserve state for potential restart. + + The adapter state (including last_completed_poll) is preserved so that + if the adapter is re-enabled, the rate-limit guarantee is maintained. + Use _remove_adapter() to fully remove an adapter from tracking. + """ + state = self._adapter_states.get(name) if state is None: return + if not state.is_running: + # Already stopped + return + # Signal the loop to stop state.cancel_event.set() @@ -299,9 +353,44 @@ class Supervisor: await state.task except asyncio.CancelledError: pass + state.task = None await state.adapter.shutdown() - logger.info("Adapter stopped", extra={"adapter": name}) + logger.info( + "Adapter stopped", + extra={ + "adapter": name, + "preserved_last_poll": state.last_completed_poll.isoformat() + if state.last_completed_poll + else None, + }, + ) + + async def _remove_adapter(self, name: str) -> None: + """Fully remove an adapter, dropping all preserved state. + + Called when an adapter is deleted from the database (not just disabled). + """ + state = self._adapter_states.pop(name, None) + if state is None: + return + + # Stop if running + if state.is_running: + state.cancel_event.set() + if state.task: + state.task.cancel() + try: + await state.task + except asyncio.CancelledError: + pass + + await state.adapter.shutdown() + + logger.info( + "Adapter removed", + extra={"adapter": name}, + ) async def _reschedule_adapter( self, @@ -319,6 +408,11 @@ class Supervisor: await self._start_adapter(new_config) return + if not state.is_running: + # Adapter stopped - restart it + await self._start_adapter(new_config) + return + old_cadence = state.config.cadence_s new_cadence = new_config.cadence_s @@ -375,18 +469,18 @@ class Supervisor: current_state = self._adapter_states.get(adapter_name) if new_config is None: - # Adapter was deleted + # Adapter was deleted - fully remove, don't just stop if current_state: - await self._stop_adapter(adapter_name) + await self._remove_adapter(adapter_name) logger.info( - "Adapter deleted, stopped", + "Adapter deleted, removed", extra={"adapter": adapter_name}, ) return if not new_config.enabled or new_config.is_paused: - # Adapter disabled or paused - if current_state: + # Adapter disabled or paused - stop but preserve state + if current_state and current_state.is_running: await self._stop_adapter(adapter_name) logger.info( "Adapter disabled/paused, stopped", @@ -398,8 +492,8 @@ class Supervisor: ) return - if current_state is None: - # Adapter was enabled or created + if current_state is None or not current_state.is_running: + # Adapter was enabled or created - start (will reuse state if exists) await self._start_adapter(new_config) logger.info( "Adapter enabled, started", @@ -468,9 +562,9 @@ class Supervisor: except asyncio.CancelledError: pass - # Stop all adapters + # Remove all adapters (full cleanup) for name in list(self._adapter_states.keys()): - await self._stop_adapter(name) + await self._remove_adapter(name) # Close config source await self._config_source.close() @@ -497,21 +591,26 @@ async def async_main() -> None: toml_path=settings.config_toml_path, ) - # Load CloudEvents config for envelope generation - # For now, load from TOML regardless of config source - # (CloudEvents config is not adapter-specific) - try: - toml_config = load_config(str(settings.config_toml_path)) - cloudevents_config = toml_config - except Exception: - # If TOML doesn't exist and using DB source, create minimal config - cloudevents_config = None + # CloudEvents config: try TOML first, fall back to code defaults + # (CloudEvents envelope format is protocol-level, not operator-configurable) + cloudevents_config = None + if settings.config_source == "toml": + try: + toml_config = load_config(str(settings.config_toml_path)) + cloudevents_config = toml_config + except Exception: + pass # Will use defaults from cloudevents_constants supervisor = Supervisor( config_source=config_source, nats_url=settings.nats_url, cloudevents_config=cloudevents_config, ) + logger.info( + "CloudEvents config: %s", + "TOML" if cloudevents_config else "defaults", + extra={"cloudevents_source": "toml" if cloudevents_config else "defaults"}, + ) loop = asyncio.get_running_loop() shutdown_event = asyncio.Event() diff --git a/tests/test_supervisor_integration.py b/tests/test_supervisor_integration.py new file mode 100644 index 0000000..d3b6dc7 --- /dev/null +++ b/tests/test_supervisor_integration.py @@ -0,0 +1,546 @@ +"""Integration tests for Supervisor hot-reload with enable/disable/enable flow. + +These tests exercise the actual Supervisor._on_config_change code path, +not just AdapterState math in isolation. They verify the rate-limit +guarantee is maintained across adapter stop/start cycles. + +IMPORTANT: These tests are designed to: +- FAIL on unfixed code (Test B fails because last_completed_poll is lost) +- PASS on fixed code (last_completed_poll is preserved across disable/enable) +""" + +import asyncio +import base64 +import os +from datetime import datetime, timedelta, timezone +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +import pytest_asyncio + +from central.config_models import AdapterConfig +from central.crypto import KEY_SIZE, clear_key_cache + + +def adapter_is_running(state) -> bool: + """Check if adapter is running (compatible with both fixed and unfixed code).""" + # Fixed code has is_running property; unfixed checks task directly + if hasattr(state, 'is_running'): + return state.is_running + return state.task is not None and not state.task.done() + + +async def cleanup_adapter(supervisor, name: str) -> None: + """Clean up adapter (compatible with both fixed and unfixed code).""" + # Fixed code has _remove_adapter; unfixed uses _stop_adapter which pops + if hasattr(supervisor, '_remove_adapter'): + await supervisor._remove_adapter(name) + else: + await supervisor._stop_adapter(name) + +# Test database DSN +TEST_DB_DSN = os.environ.get( + "CENTRAL_TEST_DB_DSN", + "postgresql://central_test:testpass@localhost/central_test", +) + + +@pytest.fixture(scope="session") +def master_key_path(tmp_path_factory: pytest.TempPathFactory) -> Path: + """Create a master key file for the test session.""" + key = os.urandom(KEY_SIZE) + key_path = tmp_path_factory.mktemp("keys") / "master.key" + key_path.write_text(base64.b64encode(key).decode()) + return key_path + + +@pytest.fixture(autouse=True) +def setup_master_key(master_key_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + """Configure master key path for all tests.""" + clear_key_cache() + monkeypatch.setenv("CENTRAL_DB_DSN", TEST_DB_DSN) + monkeypatch.setenv("CENTRAL_MASTER_KEY_PATH", str(master_key_path)) + + +class MockConfigSource: + """Mock ConfigSource for testing Supervisor without DB.""" + + def __init__(self) -> None: + self._adapters: dict[str, AdapterConfig] = {} + + def set_adapter(self, config: AdapterConfig | None, name: str | None = None) -> None: + """Set or remove an adapter config.""" + if config is None: + if name: + self._adapters.pop(name, None) + else: + self._adapters[config.name] = config + + async def list_enabled_adapters(self) -> list[AdapterConfig]: + return [a for a in self._adapters.values() if a.enabled and not a.is_paused] + + async def get_adapter(self, name: str) -> AdapterConfig | None: + return self._adapters.get(name) + + async def watch_for_changes(self, callback) -> None: + # No-op for testing + return + + async def close(self) -> None: + pass + + +class MockNWSAdapter: + """Mock NWSAdapter that tracks poll calls and allows control.""" + + def __init__(self, config, cursor_db_path) -> None: + self.config = config + self.cadence_s = config.cadence_s + self.states = set(s.upper() for s in config.states) + self.poll_count = 0 + self.poll_times: list[datetime] = [] + self._shutdown = False + + async def startup(self) -> None: + pass + + async def shutdown(self) -> None: + self._shutdown = True + + async def poll(self): + """Yield nothing - we just track that poll was called.""" + self.poll_count += 1 + self.poll_times.append(datetime.now(timezone.utc)) + return + yield # Make this an async generator + + def is_published(self, event_id: str) -> bool: + return False + + def mark_published(self, event_id: str) -> None: + pass + + def bump_last_seen(self, event_id: str) -> None: + pass + + def sweep_old_ids(self) -> int: + return 0 + + +@pytest.fixture +def mock_nats(): + """Mock NATS connection.""" + mock_nc = AsyncMock() + mock_nc.publish = AsyncMock() + mock_js = AsyncMock() + mock_js.publish = AsyncMock() + mock_nc.jetstream.return_value = mock_js + return mock_nc + + +class TestEnableDisableEnableIntegration: + """Integration tests for enable→disable→enable flow through Supervisor. + + These tests verify that _on_config_change → _stop_adapter → _start_adapter + preserves last_completed_poll correctly. + """ + + @pytest.mark.asyncio + async def test_enable_disable_enable_gap_longer_than_cadence( + self, mock_nats, tmp_path: Path + ) -> None: + """Test A: Re-enable after gap longer than cadence polls immediately. + + - Start adapter (cadence 60s) + - Simulate completed poll 5 minutes ago + - Disable adapter + - Re-enable adapter + - Assert next poll fires immediately (last+cadence is in past) + - Assert exactly ONE poll happens, not multiple catch-up + """ + from central.supervisor import Supervisor, AdapterState + + config_source = MockConfigSource() + initial_config = AdapterConfig( + name="nws", + enabled=True, + cadence_s=60, + settings={"states": ["ID"], "contact_email": "test@test.com"}, + paused_at=None, + updated_at=datetime.now(timezone.utc), + ) + config_source.set_adapter(initial_config) + + supervisor = Supervisor( + config_source=config_source, + nats_url="nats://localhost:4222", + cloudevents_config=None, + ) + + # Mock NATS connection + supervisor._nc = mock_nats + supervisor._js = mock_nats.jetstream() + + # Patch NWSAdapter to use our mock + with patch("central.supervisor.NWSAdapter", MockNWSAdapter): + # Start supervisor (starts adapter) + await supervisor._start_adapter(initial_config) + + state = supervisor._adapter_states.get("nws") + assert state is not None + assert adapter_is_running(state) + + # Simulate completed poll 5 minutes ago + state.last_completed_poll = datetime.now(timezone.utc) - timedelta(minutes=5) + saved_last_poll = state.last_completed_poll + + # Disable adapter + disabled_config = AdapterConfig( + name="nws", + enabled=False, + cadence_s=60, + settings={"states": ["ID"], "contact_email": "test@test.com"}, + paused_at=None, + updated_at=datetime.now(timezone.utc), + ) + config_source.set_adapter(disabled_config) + await supervisor._on_config_change("adapters", "nws") + + # Verify stopped but state preserved (THIS IS THE KEY CHECK) + # On unfixed code, state will be NONE because pop() removes it + # On fixed code, state still exists with is_running=False + state = supervisor._adapter_states.get("nws") + assert state is not None, ( + "State was removed on stop! This violates the rate-limit guarantee. " + "State should be preserved to maintain last_completed_poll." + ) + assert not adapter_is_running(state) + assert state.last_completed_poll == saved_last_poll + + # Re-enable adapter + reenabled_config = AdapterConfig( + name="nws", + enabled=True, + cadence_s=60, + settings={"states": ["ID"], "contact_email": "test@test.com"}, + paused_at=None, + updated_at=datetime.now(timezone.utc), + ) + config_source.set_adapter(reenabled_config) + await supervisor._on_config_change("adapters", "nws") + + # Verify restarted with preserved last_completed_poll + state = supervisor._adapter_states.get("nws") + assert state is not None + assert adapter_is_running(state) + assert state.last_completed_poll == saved_last_poll + + # The loop should detect that last_poll + cadence is in the past + # and poll immediately. Let's verify by checking the wait time logic. + now = datetime.now(timezone.utc) + next_poll_at = saved_last_poll.timestamp() + 60 # cadence = 60s + wait_time = max(0, next_poll_at - now.timestamp()) + + # last_poll was 5 minutes ago, cadence is 60s + # next_poll_at = 5_minutes_ago + 60s = 4_minutes_ago + # wait_time should be 0 (poll immediately) + assert wait_time == 0, ( + f"Expected immediate poll (wait=0), got wait={wait_time}s. " + f"last_poll was {saved_last_poll}, now is {now}" + ) + + # Cleanup + supervisor._shutdown_event.set() + await cleanup_adapter(supervisor, "nws") + + @pytest.mark.asyncio + async def test_enable_disable_enable_gap_shorter_than_cadence( + self, mock_nats, tmp_path: Path + ) -> None: + """Test B: Re-enable after gap shorter than cadence respects rate limit. + + THIS IS THE KEY TEST that failed before the fix. + + - Start adapter (cadence 60s) + - Simulate completed poll 10 seconds ago + - Disable adapter + - Re-enable adapter 20 seconds later (still within cadence window) + - Assert next poll fires at last_poll + 60s, NOT immediately + """ + from central.supervisor import Supervisor, AdapterState + + config_source = MockConfigSource() + initial_config = AdapterConfig( + name="nws", + enabled=True, + cadence_s=60, + settings={"states": ["ID"], "contact_email": "test@test.com"}, + paused_at=None, + updated_at=datetime.now(timezone.utc), + ) + config_source.set_adapter(initial_config) + + supervisor = Supervisor( + config_source=config_source, + nats_url="nats://localhost:4222", + cloudevents_config=None, + ) + + supervisor._nc = mock_nats + supervisor._js = mock_nats.jetstream() + + with patch("central.supervisor.NWSAdapter", MockNWSAdapter): + # Start adapter + await supervisor._start_adapter(initial_config) + + state = supervisor._adapter_states.get("nws") + assert state is not None + + # Simulate completed poll 10 seconds ago + state.last_completed_poll = datetime.now(timezone.utc) - timedelta(seconds=10) + saved_last_poll = state.last_completed_poll + + # Disable adapter + disabled_config = AdapterConfig( + name="nws", + enabled=False, + cadence_s=60, + settings={"states": ["ID"], "contact_email": "test@test.com"}, + paused_at=None, + updated_at=datetime.now(timezone.utc), + ) + config_source.set_adapter(disabled_config) + await supervisor._on_config_change("adapters", "nws") + + # Verify stopped but state preserved (THIS IS THE KEY CHECK) + # On unfixed code, state will be NONE because pop() removes it + # On fixed code, state still exists with is_running=False + state = supervisor._adapter_states.get("nws") + assert state is not None, ( + "State was removed on stop! This violates the rate-limit guarantee. " + "State should be preserved to maintain last_completed_poll." + ) + assert not adapter_is_running(state) + assert state.last_completed_poll == saved_last_poll + + # Re-enable adapter (simulate 20 seconds later, but we're just + # checking the rate limit logic) + reenabled_config = AdapterConfig( + name="nws", + enabled=True, + cadence_s=60, + settings={"states": ["ID"], "contact_email": "test@test.com"}, + paused_at=None, + updated_at=datetime.now(timezone.utc), + ) + config_source.set_adapter(reenabled_config) + await supervisor._on_config_change("adapters", "nws") + + # Verify restarted with preserved last_completed_poll + state = supervisor._adapter_states.get("nws") + assert state is not None + assert adapter_is_running(state) + assert state.last_completed_poll == saved_last_poll + + # The loop should detect that last_poll + cadence is still in the future + # and wait until then. + now = datetime.now(timezone.utc) + next_poll_at = saved_last_poll.timestamp() + 60 + wait_time = max(0, next_poll_at - now.timestamp()) + + # last_poll was ~10 seconds ago, cadence is 60s + # wait_time should be ~50s (60 - 10 = 50) + assert 45 < wait_time < 55, ( + f"Expected ~50s wait (respecting rate limit), got wait={wait_time}s. " + f"Rate limit violated: poll would happen before last_poll + cadence" + ) + + # Cleanup + supervisor._shutdown_event.set() + await cleanup_adapter(supervisor, "nws") + + @pytest.mark.asyncio + async def test_enable_disable_delete_readd_fresh_state( + self, mock_nats, tmp_path: Path + ) -> None: + """Test C: Delete then re-add clears preserved state. + + - Start adapter + - Simulate completed poll + - Disable adapter + - DELETE adapter from DB (not just disable) + - Re-add adapter with same name + - Assert preserved timestamp is dropped (fresh adapter, immediate poll) + """ + from central.supervisor import Supervisor + + config_source = MockConfigSource() + initial_config = AdapterConfig( + name="nws", + enabled=True, + cadence_s=60, + settings={"states": ["ID"], "contact_email": "test@test.com"}, + paused_at=None, + updated_at=datetime.now(timezone.utc), + ) + config_source.set_adapter(initial_config) + + supervisor = Supervisor( + config_source=config_source, + nats_url="nats://localhost:4222", + cloudevents_config=None, + ) + + supervisor._nc = mock_nats + supervisor._js = mock_nats.jetstream() + + with patch("central.supervisor.NWSAdapter", MockNWSAdapter): + # Start adapter + await supervisor._start_adapter(initial_config) + + state = supervisor._adapter_states.get("nws") + assert state is not None + + # Simulate completed poll 10 seconds ago + state.last_completed_poll = datetime.now(timezone.utc) - timedelta(seconds=10) + + # Disable adapter + disabled_config = AdapterConfig( + name="nws", + enabled=False, + cadence_s=60, + settings={"states": ["ID"], "contact_email": "test@test.com"}, + paused_at=None, + updated_at=datetime.now(timezone.utc), + ) + config_source.set_adapter(disabled_config) + await supervisor._on_config_change("adapters", "nws") + + # DELETE adapter from DB (remove from config source) + config_source.set_adapter(None, name="nws") + await supervisor._on_config_change("adapters", "nws") + + # Verify adapter fully removed + assert "nws" not in supervisor._adapter_states + + # Re-add adapter with same name + new_config = AdapterConfig( + name="nws", + enabled=True, + cadence_s=60, + settings={"states": ["ID"], "contact_email": "test@test.com"}, + paused_at=None, + updated_at=datetime.now(timezone.utc), + ) + config_source.set_adapter(new_config) + await supervisor._on_config_change("adapters", "nws") + + # Verify new adapter started fresh + state = supervisor._adapter_states.get("nws") + assert state is not None + assert adapter_is_running(state) + # last_completed_poll should be None (fresh adapter) + assert state.last_completed_poll is None, ( + f"Expected None (fresh adapter), got {state.last_completed_poll}. " + f"Preserved state not cleared on delete." + ) + + # Cleanup + supervisor._shutdown_event.set() + await cleanup_adapter(supervisor, "nws") + + @pytest.mark.asyncio + async def test_stop_preserves_state_start_reuses_it( + self, mock_nats, tmp_path: Path + ) -> None: + """Verify _stop_adapter preserves state and _start_adapter reuses it.""" + from central.supervisor import Supervisor + + config_source = MockConfigSource() + config = AdapterConfig( + name="nws", + enabled=True, + cadence_s=60, + settings={"states": ["ID"], "contact_email": "test@test.com"}, + paused_at=None, + updated_at=datetime.now(timezone.utc), + ) + config_source.set_adapter(config) + + supervisor = Supervisor( + config_source=config_source, + nats_url="nats://localhost:4222", + cloudevents_config=None, + ) + + supervisor._nc = mock_nats + supervisor._js = mock_nats.jetstream() + + with patch("central.supervisor.NWSAdapter", MockNWSAdapter): + # Start adapter + await supervisor._start_adapter(config) + + state = supervisor._adapter_states.get("nws") + state.last_completed_poll = datetime.now(timezone.utc) - timedelta(seconds=30) + saved_poll = state.last_completed_poll + + # Stop adapter + await supervisor._stop_adapter("nws") + + # State should still exist + assert "nws" in supervisor._adapter_states + state = supervisor._adapter_states["nws"] + assert not adapter_is_running(state) + assert state.last_completed_poll == saved_poll + + # Restart adapter + await supervisor._start_adapter(config) + + # Should reuse existing state + state = supervisor._adapter_states.get("nws") + assert adapter_is_running(state) + assert state.last_completed_poll == saved_poll + + # Cleanup + supervisor._shutdown_event.set() + await cleanup_adapter(supervisor, "nws") + + @pytest.mark.asyncio + async def test_remove_adapter_clears_state( + self, mock_nats, tmp_path: Path + ) -> None: + """Verify _remove_adapter fully clears state.""" + from central.supervisor import Supervisor + + config_source = MockConfigSource() + config = AdapterConfig( + name="nws", + enabled=True, + cadence_s=60, + settings={"states": ["ID"], "contact_email": "test@test.com"}, + paused_at=None, + updated_at=datetime.now(timezone.utc), + ) + config_source.set_adapter(config) + + supervisor = Supervisor( + config_source=config_source, + nats_url="nats://localhost:4222", + cloudevents_config=None, + ) + + supervisor._nc = mock_nats + supervisor._js = mock_nats.jetstream() + + with patch("central.supervisor.NWSAdapter", MockNWSAdapter): + await supervisor._start_adapter(config) + + state = supervisor._adapter_states.get("nws") + state.last_completed_poll = datetime.now(timezone.utc) + + # Remove adapter + await cleanup_adapter(supervisor, "nws") + + # State should be gone + assert "nws" not in supervisor._adapter_states