From ab7126ec8da73a42723d67f55aeb4c3c2236b4aa Mon Sep 17 00:00:00 2001 From: Matt Johnson Date: Sat, 16 May 2026 18:49:53 +0000 Subject: [PATCH] refactor(supervisor): remove adapter-specific branches, add stream wiring - Replace if name == nws with generic apply_config call - Add _create_adapter factory method - Add stream management: ensure_stream, retention recompute loop - Handle streams config changes via NOTIFY Co-Authored-By: Claude Opus 4.5 --- src/central/supervisor.py | 231 ++++++++++++++++++++++++++++++-------- 1 file changed, 185 insertions(+), 46 deletions(-) diff --git a/src/central/supervisor.py b/src/central/supervisor.py index fad8588..2543291 100644 --- a/src/central/supervisor.py +++ b/src/central/supervisor.py @@ -13,16 +13,27 @@ from typing import Any import nats from nats.js import JetStreamContext +from central.adapter import SourceAdapter from central.adapters.nws import NWSAdapter from central.cloudevents_wire import wrap_event -from central.config import NWSAdapterConfig from central.config_models import AdapterConfig from central.config_source import ConfigSource, DbConfigSource +from central.config_store import ConfigStore from central.bootstrap_config import get_settings from central.models import subject_for_event +from central.stream_manager import StreamManager CURSOR_DB_PATH = Path("/var/lib/central/cursors.db") +# Stream subject mappings +STREAM_SUBJECTS = { + "CENTRAL_WX": ["central.wx.>"], + "CENTRAL_META": ["central.meta.>"], +} + +# Recompute interval for stream max_bytes (1 hour) +STREAM_RECOMPUTE_INTERVAL_S = 3600 + class JsonFormatter(logging.Formatter): """JSON log formatter for structured logging.""" @@ -66,7 +77,7 @@ class AdapterState: """Runtime state for a scheduled adapter.""" name: str - adapter: NWSAdapter + adapter: SourceAdapter config: AdapterConfig task: asyncio.Task[None] | None = None last_completed_poll: datetime | None = None @@ -84,14 +95,17 @@ class Supervisor: def __init__( self, config_source: ConfigSource, + config_store: ConfigStore, nats_url: str, cloudevents_config: Any = None, ) -> None: self._config_source = config_source + self._config_store = config_store self._nats_url = nats_url self._cloudevents_config = cloudevents_config self._nc: nats.NATS | None = None self._js: JetStreamContext | None = None + self._stream_manager: StreamManager | None = None self._adapter_states: dict[str, AdapterState] = {} self._tasks: list[asyncio.Task[None]] = [] self._shutdown_event = asyncio.Event() @@ -103,6 +117,7 @@ class Supervisor: """Connect to NATS.""" self._nc = await nats.connect(self._nats_url) self._js = self._nc.jetstream() + self._stream_manager = StreamManager(self._js) logger.info("Connected to NATS", extra={"url": self._nats_url}) async def disconnect(self) -> None: @@ -112,6 +127,7 @@ class Supervisor: await self._nc.close() self._nc = None self._js = None + self._stream_manager = None logger.info("Disconnected from NATS") async def _publish_meta(self, subject: str, data: dict[str, Any]) -> None: @@ -132,14 +148,12 @@ class Supervisor: headers={"Nats-Msg-Id": msg_id}, ) - def _adapter_config_to_nws_config(self, config: AdapterConfig) -> NWSAdapterConfig: - """Convert unified AdapterConfig to NWSAdapterConfig.""" - return NWSAdapterConfig( - enabled=config.enabled, - cadence_s=config.cadence_s, - states=config.settings.get("states", []), - contact_email=config.settings.get("contact_email", ""), - ) + def _create_adapter(self, config: AdapterConfig) -> SourceAdapter: + """Create an adapter instance based on config name.""" + if config.name == "nws": + return NWSAdapter(config=config, cursor_db_path=CURSOR_DB_PATH) + else: + raise ValueError(f"Unknown adapter type: {config.name}") async def _run_adapter_loop(self, state: AdapterState) -> None: """Run an adapter poll loop with rate-limit aware scheduling.""" @@ -263,11 +277,7 @@ class Supervisor: existing_state.cancel_event.clear() # Reinitialize the adapter with new config - nws_config = self._adapter_config_to_nws_config(config) - existing_state.adapter = NWSAdapter( - config=nws_config, - cursor_db_path=CURSOR_DB_PATH, - ) + existing_state.adapter = self._create_adapter(config) await existing_state.adapter.startup() # Start the loop task @@ -300,34 +310,29 @@ class Supervisor: return # New adapter - create fresh state - if config.name == "nws": - nws_config = self._adapter_config_to_nws_config(config) - adapter = NWSAdapter( - config=nws_config, - cursor_db_path=CURSOR_DB_PATH, - ) - await adapter.startup() + try: + adapter = self._create_adapter(config) + except ValueError as e: + logger.warning(str(e), extra={"adapter": config.name}) + return - state = AdapterState( - name=config.name, - adapter=adapter, - config=config, - ) - state.task = asyncio.create_task(self._run_adapter_loop(state)) - self._adapter_states[config.name] = state + await adapter.startup() - logger.info( - "Adapter started", - extra={ - "adapter": config.name, - "cadence_s": config.cadence_s, - }, - ) - else: - logger.warning( - "Unknown adapter type", - extra={"adapter": config.name}, - ) + state = AdapterState( + name=config.name, + adapter=adapter, + config=config, + ) + state.task = asyncio.create_task(self._run_adapter_loop(state)) + self._adapter_states[config.name] = state + + logger.info( + "Adapter started", + extra={ + "adapter": config.name, + "cadence_s": config.cadence_s, + }, + ) async def _stop_adapter(self, name: str) -> None: """Stop a running adapter but preserve state for potential restart. @@ -423,10 +428,8 @@ class Supervisor: # Update config state.config = new_config - # Update adapter settings if needed (e.g., states list) - if name == "nws": - nws_config = self._adapter_config_to_nws_config(new_config) - state.adapter.states = set(s.upper() for s in nws_config.states) + # Apply config to adapter (generic - each adapter handles its own settings) + await state.adapter.apply_config(new_config) # Calculate next poll time for logging if state.last_completed_poll: @@ -451,11 +454,139 @@ class Supervisor: # This ensures immediate event delivery to the sleeping loop. return state + async def _ensure_streams(self) -> None: + """Ensure all configured streams exist with correct settings.""" + if not self._stream_manager: + return + + streams = await self._config_store.list_streams() + for stream_config in streams: + subjects = STREAM_SUBJECTS.get(stream_config.name, []) + if not subjects: + logger.warning( + "No subjects configured for stream", + extra={"stream": stream_config.name}, + ) + continue + + try: + await self._stream_manager.ensure_stream( + stream_config.name, + subjects, + stream_config, + ) + except Exception as e: + logger.error( + "Failed to ensure stream", + extra={"stream": stream_config.name, "error": str(e)}, + ) + + async def _handle_stream_change(self, stream_name: str) -> None: + """Handle a stream configuration change.""" + if not self._stream_manager: + return + + stream_config = await self._config_store.get_stream(stream_name) + if stream_config is None: + logger.warning( + "Stream config not found", + extra={"stream": stream_name}, + ) + return + + try: + # Apply retention settings + await self._stream_manager.apply_retention(stream_name, stream_config) + + # Immediate recompute of max_bytes + new_max_bytes = await self._stream_manager.recompute_max_bytes( + stream_name, stream_config.max_age_s + ) + + # Update database (won't trigger NOTIFY due to column filter) + await self._config_store.update_stream_max_bytes(stream_name, new_max_bytes) + + logger.info( + "Stream retention updated", + extra={ + "stream": stream_name, + "max_age_s": stream_config.max_age_s, + "new_max_bytes": new_max_bytes, + }, + ) + except Exception as e: + logger.error( + "Failed to handle stream change", + extra={"stream": stream_name, "error": str(e)}, + ) + + async def _stream_retention_recompute_loop(self) -> None: + """Periodically recompute max_bytes for all streams.""" + while not self._shutdown_event.is_set(): + try: + await asyncio.wait_for( + self._shutdown_event.wait(), + timeout=STREAM_RECOMPUTE_INTERVAL_S, + ) + # Shutdown requested + break + except asyncio.TimeoutError: + pass + + # Recompute for all streams + if not self._stream_manager: + continue + + streams = await self._config_store.list_streams() + for stream_config in streams: + if not stream_config.managed_max_bytes: + continue + + try: + new_max_bytes = await self._stream_manager.recompute_max_bytes( + stream_config.name, stream_config.max_age_s + ) + + # Only update if change > 10% + change_ratio = abs(new_max_bytes - stream_config.max_bytes) / max(stream_config.max_bytes, 1) + if change_ratio > 0.10: + await self._config_store.update_stream_max_bytes( + stream_config.name, new_max_bytes + ) + await self._stream_manager.apply_retention( + stream_config.name, + await self._config_store.get_stream(stream_config.name), + ) + logger.info( + "Recomputed stream max_bytes", + extra={ + "stream": stream_config.name, + "old_max_bytes": stream_config.max_bytes, + "new_max_bytes": new_max_bytes, + "change_ratio": change_ratio, + }, + ) + except Exception as e: + logger.error( + "Failed to recompute stream max_bytes", + extra={"stream": stream_config.name, "error": str(e)}, + ) + async def _on_config_change(self, table: str, key: str) -> None: """Handle a configuration change notification. Called when NOTIFY fires for config changes. """ + # Handle stream changes + if table == "streams": + stream_name = key + logger.info( + "Stream config change received", + extra={"stream": stream_name}, + ) + await self._handle_stream_change(stream_name) + return + if table != "adapters": return @@ -534,6 +665,9 @@ class Supervisor: """Start the supervisor.""" await self.connect() + # Ensure streams exist with correct configuration + await self._ensure_streams() + # Load and start enabled adapters enabled_adapters = await self._config_source.list_enabled_adapters() for config in enabled_adapters: @@ -547,6 +681,9 @@ class Supervisor: # Start heartbeat self._tasks.append(asyncio.create_task(self._heartbeat_loop())) + # Start stream retention recompute loop + self._tasks.append(asyncio.create_task(self._stream_retention_recompute_loop())) + logger.info( "Supervisor started", extra={"adapters": list(self._adapter_states.keys())}, @@ -594,11 +731,13 @@ async def async_main() -> None: extra={"config_source": "db"}, ) - # Create database config source + # Create database config source and store config_source = await DbConfigSource.create(settings.db_dsn) + config_store = await ConfigStore.create(settings.db_dsn) supervisor = Supervisor( config_source=config_source, + config_store=config_store, nats_url=settings.nats_url, # CloudEvents uses protocol-level defaults from cloudevents_constants cloudevents_config=None,