From 29fa49c5c2e51d8f04fb8269d6bf0d4165cc102d Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Sat, 16 May 2026 01:55:33 +0000 Subject: [PATCH] feat(supervisor): add hot-reload support with rate-limit guarantee Refactors supervisor to use ConfigSource abstraction: - AdapterState tracks last_completed_poll for rate limiting - Hot-reload via NOTIFY: cadence/enable/disable changes take effect - Rate-limit guarantee: next poll at last_poll + new_cadence, not now - Logs config source at startup (toml or db) - Logs reschedule decisions with next poll timestamp Co-Authored-By: Claude Opus 4.5 --- src/central/supervisor.py | 794 ++++++++++++++++++++++++++------------ 1 file changed, 539 insertions(+), 255 deletions(-) diff --git a/src/central/supervisor.py b/src/central/supervisor.py index 9139507..0f3d6fd 100644 --- a/src/central/supervisor.py +++ b/src/central/supervisor.py @@ -1,255 +1,539 @@ -"""Central supervisor - adapter scheduler and event publisher.""" - -import asyncio -import json -import logging -import signal -import sys -from datetime import datetime, timezone -from pathlib import Path -from typing import Any - -import nats -from nats.js import JetStreamContext - -from central.adapters.nws import NWSAdapter -from central.cloudevents_wire import wrap_event -from central.config import load_config, Config -from central.models import subject_for_event - -CURSOR_DB_PATH = Path("/var/lib/central/cursors.db") -CONFIG_PATH = "/etc/central/central.toml" - - -class JsonFormatter(logging.Formatter): - """JSON log formatter for structured logging.""" - - def format(self, record: logging.LogRecord) -> str: - log_obj: dict[str, Any] = { - "ts": datetime.now(timezone.utc).isoformat(), - "level": record.levelname, - "logger": record.name, - "msg": record.getMessage(), - } - if record.exc_info: - log_obj["exc"] = self.formatException(record.exc_info) - if hasattr(record, "extra"): - log_obj.update(record.extra) - # Include any extra fields passed via extra={} - for key in record.__dict__: - if key not in ( - "name", "msg", "args", "created", "filename", "funcName", - "levelname", "levelno", "lineno", "module", "msecs", - "pathname", "process", "processName", "relativeCreated", - "stack_info", "exc_info", "exc_text", "thread", "threadName", - "taskName", "message", - ): - log_obj[key] = record.__dict__[key] - return json.dumps(log_obj) - - -def setup_logging() -> None: - """Configure JSON logging to stdout.""" - handler = logging.StreamHandler(sys.stdout) - handler.setFormatter(JsonFormatter()) - logging.root.handlers = [handler] - logging.root.setLevel(logging.INFO) - - -logger = logging.getLogger("central.supervisor") - - -class Supervisor: - """Main supervisor process.""" - - def __init__(self, config: Config) -> None: - self.config = config - self._nc: nats.NATS | None = None - self._js: JetStreamContext | None = None - self._adapters: list[NWSAdapter] = [] - self._tasks: list[asyncio.Task[None]] = [] - self._shutdown_event = asyncio.Event() - self._start_time = datetime.now(timezone.utc) - - async def connect(self) -> None: - """Connect to NATS.""" - self._nc = await nats.connect(self.config.nats.url) - self._js = self._nc.jetstream() - logger.info("Connected to NATS", extra={"url": self.config.nats.url}) - - async def disconnect(self) -> None: - """Disconnect from NATS.""" - if self._nc: - await self._nc.drain() - await self._nc.close() - self._nc = None - self._js = None - logger.info("Disconnected from NATS") - - async def _publish_meta(self, subject: str, data: dict[str, Any]) -> None: - """Publish a meta event (no Nats-Msg-Id).""" - if not self._nc: - return - payload = json.dumps(data).encode() - await self._nc.publish(subject, payload) - - async def _publish_event(self, subject: str, envelope: dict[str, Any], msg_id: str) -> None: - """Publish an event with dedup header.""" - if not self._js: - return - payload = json.dumps(envelope).encode() - await self._js.publish( - subject, - payload, - headers={"Nats-Msg-Id": msg_id}, - ) - - async def _run_adapter(self, adapter: NWSAdapter) -> None: - """Run an adapter poll loop.""" - while not self._shutdown_event.is_set(): - poll_start = datetime.now(timezone.utc) - try: - async for event in adapter.poll(): - # Dedup check - if adapter.is_published(event.id): - adapter.bump_last_seen(event.id) - continue - - # Build CloudEvent - envelope, msg_id = wrap_event(event, self.config) - subject = subject_for_event(event) - - # Publish - await self._publish_event(subject, envelope, msg_id) - adapter.mark_published(event.id) - - logger.info( - "Published event", - extra={"id": event.id, "subject": subject, "category": event.category} - ) - - # Publish success status - await self._publish_meta( - f"central.meta.adapter.{adapter.name}.status", - {"ok": True, "ts": datetime.now(timezone.utc).isoformat()} - ) - - except Exception as e: - logger.exception("Adapter poll failed", extra={"adapter": adapter.name}) - await self._publish_meta( - f"central.meta.adapter.{adapter.name}.status", - { - "ok": False, - "error": str(e), - "ts": datetime.now(timezone.utc).isoformat() - } - ) - - # Sweep old IDs - swept = adapter.sweep_old_ids() - if swept > 0: - logger.info("Swept old published IDs", extra={"count": swept}) - - # Sleep until next cadence - elapsed = (datetime.now(timezone.utc) - poll_start).total_seconds() - sleep_time = max(0, adapter.cadence_s - elapsed) - try: - await asyncio.wait_for( - self._shutdown_event.wait(), - timeout=sleep_time - ) - except asyncio.TimeoutError: - pass - - async def _heartbeat_loop(self) -> None: - """Publish periodic heartbeats.""" - while not self._shutdown_event.is_set(): - uptime = (datetime.now(timezone.utc) - self._start_time).total_seconds() - await self._publish_meta( - "central.meta.heartbeat", - {"ts": datetime.now(timezone.utc).isoformat(), "uptime_s": uptime} - ) - try: - await asyncio.wait_for( - self._shutdown_event.wait(), - timeout=30 - ) - except asyncio.TimeoutError: - pass - - async def start(self) -> None: - """Start the supervisor.""" - await self.connect() - - # Initialize adapters - if self.config.adapters.get("nws") and self.config.adapters["nws"].enabled: - adapter = NWSAdapter( - config=self.config.adapters["nws"], - cursor_db_path=CURSOR_DB_PATH, - ) - await adapter.startup() - self._adapters.append(adapter) - logger.info("NWS adapter initialized") - - # Start adapter tasks - for adapter in self._adapters: - task = asyncio.create_task(self._run_adapter(adapter)) - self._tasks.append(task) - - # Start heartbeat - self._tasks.append(asyncio.create_task(self._heartbeat_loop())) - - logger.info("Supervisor started", extra={"adapters": [a.name for a in self._adapters]}) - - async def stop(self) -> None: - """Stop the supervisor gracefully.""" - logger.info("Supervisor shutting down") - self._shutdown_event.set() - - # Cancel tasks - for task in self._tasks: - task.cancel() - try: - await task - except asyncio.CancelledError: - pass - - # Shutdown adapters - for adapter in self._adapters: - await adapter.shutdown() - - await self.disconnect() - logger.info("Supervisor stopped") - - -async def async_main() -> None: - """Async entry point.""" - setup_logging() - - config = load_config(CONFIG_PATH) - supervisor = Supervisor(config) - - loop = asyncio.get_running_loop() - shutdown_event = asyncio.Event() - - def handle_signal() -> None: - shutdown_event.set() - - for sig in (signal.SIGTERM, signal.SIGINT): - loop.add_signal_handler(sig, handle_signal) - - await supervisor.start() - - # Wait for shutdown signal - await shutdown_event.wait() - - await supervisor.stop() - - -def main() -> None: - """Entry point.""" - asyncio.run(async_main()) - - -if __name__ == "__main__": - main() +"""Central supervisor - adapter scheduler and event publisher.""" + +import asyncio +import json +import logging +import signal +import sys +from dataclasses import dataclass, field +from datetime import datetime, timezone +from pathlib import Path +from typing import Any + +import nats +from nats.js import JetStreamContext + +from central.adapters.nws import NWSAdapter +from central.cloudevents_wire import wrap_event +from central.config import load_config, Config, NWSAdapterConfig +from central.config_models import AdapterConfig +from central.config_source import ConfigSource, create_config_source +from central.bootstrap_config import get_settings +from central.models import subject_for_event + +CURSOR_DB_PATH = Path("/var/lib/central/cursors.db") + + +class JsonFormatter(logging.Formatter): + """JSON log formatter for structured logging.""" + + def format(self, record: logging.LogRecord) -> str: + log_obj: dict[str, Any] = { + "ts": datetime.now(timezone.utc).isoformat(), + "level": record.levelname, + "logger": record.name, + "msg": record.getMessage(), + } + if record.exc_info: + log_obj["exc"] = self.formatException(record.exc_info) + if hasattr(record, "extra"): + log_obj.update(record.extra) + for key in record.__dict__: + if key not in ( + "name", "msg", "args", "created", "filename", "funcName", + "levelname", "levelno", "lineno", "module", "msecs", + "pathname", "process", "processName", "relativeCreated", + "stack_info", "exc_info", "exc_text", "thread", "threadName", + "taskName", "message", + ): + log_obj[key] = record.__dict__[key] + return json.dumps(log_obj) + + +def setup_logging() -> None: + """Configure JSON logging to stdout.""" + handler = logging.StreamHandler(sys.stdout) + handler.setFormatter(JsonFormatter()) + logging.root.handlers = [handler] + logging.root.setLevel(logging.INFO) + + +logger = logging.getLogger("central.supervisor") + + +@dataclass +class AdapterState: + """Runtime state for a scheduled adapter.""" + + name: str + adapter: NWSAdapter + config: AdapterConfig + task: asyncio.Task[None] | None = None + last_completed_poll: datetime | None = None + cancel_event: asyncio.Event = field(default_factory=asyncio.Event) + + def __post_init__(self) -> None: + # Ensure cancel_event is created + if self.cancel_event is None: + self.cancel_event = asyncio.Event() + + +class Supervisor: + """Main supervisor process.""" + + def __init__( + self, + config_source: ConfigSource, + nats_url: str, + cloudevents_config: Any = None, + ) -> None: + self._config_source = config_source + self._nats_url = nats_url + self._cloudevents_config = cloudevents_config + self._nc: nats.NATS | None = None + self._js: JetStreamContext | None = None + self._adapter_states: dict[str, AdapterState] = {} + self._tasks: list[asyncio.Task[None]] = [] + self._shutdown_event = asyncio.Event() + self._start_time = datetime.now(timezone.utc) + self._config_watch_task: asyncio.Task[None] | None = None + self._lock = asyncio.Lock() + + async def connect(self) -> None: + """Connect to NATS.""" + self._nc = await nats.connect(self._nats_url) + self._js = self._nc.jetstream() + logger.info("Connected to NATS", extra={"url": self._nats_url}) + + async def disconnect(self) -> None: + """Disconnect from NATS.""" + if self._nc: + await self._nc.drain() + await self._nc.close() + self._nc = None + self._js = None + logger.info("Disconnected from NATS") + + async def _publish_meta(self, subject: str, data: dict[str, Any]) -> None: + """Publish a meta event (no Nats-Msg-Id).""" + if not self._nc: + return + payload = json.dumps(data).encode() + await self._nc.publish(subject, payload) + + async def _publish_event(self, subject: str, envelope: dict[str, Any], msg_id: str) -> None: + """Publish an event with dedup header.""" + if not self._js: + return + payload = json.dumps(envelope).encode() + await self._js.publish( + subject, + payload, + headers={"Nats-Msg-Id": msg_id}, + ) + + def _adapter_config_to_nws_config(self, config: AdapterConfig) -> NWSAdapterConfig: + """Convert unified AdapterConfig to NWSAdapterConfig.""" + return NWSAdapterConfig( + enabled=config.enabled, + cadence_s=config.cadence_s, + states=config.settings.get("states", []), + contact_email=config.settings.get("contact_email", ""), + ) + + async def _run_adapter_loop(self, state: AdapterState) -> None: + """Run an adapter poll loop with rate-limit aware scheduling.""" + while not self._shutdown_event.is_set(): + # Calculate next poll time based on rate-limit guarantee + now = datetime.now(timezone.utc) + + if state.last_completed_poll is not None: + next_poll_at = state.last_completed_poll.timestamp() + state.config.cadence_s + wait_time = max(0, next_poll_at - now.timestamp()) + else: + # First poll - run immediately + wait_time = 0 + + if wait_time > 0: + logger.debug( + "Waiting for next poll", + extra={ + "adapter": state.name, + "wait_s": wait_time, + "next_poll": datetime.fromtimestamp( + now.timestamp() + wait_time, tz=timezone.utc + ).isoformat(), + }, + ) + # Wait for either timeout or cancel signal + try: + await asyncio.wait_for( + state.cancel_event.wait(), + timeout=wait_time, + ) + # Cancel event was set - check if we should exit or reschedule + if self._shutdown_event.is_set(): + break + # Clear the cancel event and re-evaluate schedule + state.cancel_event.clear() + continue + except asyncio.TimeoutError: + pass + + # Check shutdown before polling + if self._shutdown_event.is_set(): + break + + # Check if adapter is still enabled + if not state.config.enabled or state.config.is_paused: + logger.info( + "Adapter disabled/paused, stopping loop", + extra={"adapter": state.name}, + ) + break + + poll_start = datetime.now(timezone.utc) + try: + async for event in state.adapter.poll(): + # Dedup check + if state.adapter.is_published(event.id): + state.adapter.bump_last_seen(event.id) + continue + + # Build CloudEvent + if self._cloudevents_config: + envelope, msg_id = wrap_event(event, self._cloudevents_config) + else: + # Fallback for testing + envelope = {"id": event.id, "data": event.data} + msg_id = event.id + + subject = subject_for_event(event) + + # Publish + await self._publish_event(subject, envelope, msg_id) + state.adapter.mark_published(event.id) + + logger.info( + "Published event", + extra={"id": event.id, "subject": subject, "category": event.category} + ) + + # Publish success status + await self._publish_meta( + f"central.meta.adapter.{state.name}.status", + {"ok": True, "ts": datetime.now(timezone.utc).isoformat()} + ) + + # Mark poll completion time for rate limiting + state.last_completed_poll = datetime.now(timezone.utc) + + except Exception as e: + logger.exception("Adapter poll failed", extra={"adapter": state.name}) + await self._publish_meta( + f"central.meta.adapter.{state.name}.status", + { + "ok": False, + "error": str(e), + "ts": datetime.now(timezone.utc).isoformat() + } + ) + # Still mark completion time to avoid tight retry loops + state.last_completed_poll = datetime.now(timezone.utc) + + # Sweep old IDs + swept = state.adapter.sweep_old_ids() + if swept > 0: + logger.info("Swept old published IDs", extra={"count": swept}) + + async def _start_adapter(self, config: AdapterConfig) -> None: + """Start an adapter based on its configuration.""" + if config.name in self._adapter_states: + logger.warning( + "Adapter already running", + extra={"adapter": config.name}, + ) + return + + if config.name == "nws": + nws_config = self._adapter_config_to_nws_config(config) + adapter = NWSAdapter( + config=nws_config, + cursor_db_path=CURSOR_DB_PATH, + ) + await adapter.startup() + + state = AdapterState( + name=config.name, + adapter=adapter, + config=config, + ) + state.task = asyncio.create_task(self._run_adapter_loop(state)) + self._adapter_states[config.name] = state + + logger.info( + "Adapter started", + extra={ + "adapter": config.name, + "cadence_s": config.cadence_s, + }, + ) + else: + logger.warning( + "Unknown adapter type", + extra={"adapter": config.name}, + ) + + async def _stop_adapter(self, name: str) -> None: + """Stop a running adapter.""" + state = self._adapter_states.pop(name, None) + if state is None: + return + + # Signal the loop to stop + state.cancel_event.set() + + if state.task: + state.task.cancel() + try: + await state.task + except asyncio.CancelledError: + pass + + await state.adapter.shutdown() + logger.info("Adapter stopped", extra={"adapter": name}) + + async def _reschedule_adapter( + self, + name: str, + new_config: AdapterConfig, + ) -> None: + """Reschedule an adapter with new configuration. + + Maintains rate-limit guarantee: next poll at + (last_completed_poll + new_cadence_s), not now + new_cadence_s. + """ + state = self._adapter_states.get(name) + if state is None: + # Adapter not running - just start it + await self._start_adapter(new_config) + return + + old_cadence = state.config.cadence_s + new_cadence = new_config.cadence_s + + # Update config + state.config = new_config + + # Update adapter's cadence + state.adapter.cadence_s = new_cadence + + # Update adapter settings if needed (e.g., states list) + if name == "nws": + nws_config = self._adapter_config_to_nws_config(new_config) + state.adapter.states = set(s.upper() for s in nws_config.states) + + # Calculate next poll time for logging + if state.last_completed_poll: + next_poll_at = datetime.fromtimestamp( + state.last_completed_poll.timestamp() + new_cadence, + tz=timezone.utc, + ) + else: + next_poll_at = datetime.now(timezone.utc) + + logger.info( + "Rescheduled adapter", + extra={ + "adapter": name, + "old_cadence_s": old_cadence, + "new_cadence_s": new_cadence, + "next_poll": next_poll_at.isoformat(), + }, + ) + + # Signal the loop to re-evaluate its schedule + state.cancel_event.set() + + async def _on_config_change(self, table: str, key: str) -> None: + """Handle a configuration change notification. + + Called when NOTIFY fires for config changes. + """ + if table != "adapters": + return + + adapter_name = key + logger.info( + "Config change received", + extra={"table": table, "key": key}, + ) + + async with self._lock: + # Fetch the current config for this adapter + new_config = await self._config_source.get_adapter(adapter_name) + current_state = self._adapter_states.get(adapter_name) + + if new_config is None: + # Adapter was deleted + if current_state: + await self._stop_adapter(adapter_name) + logger.info( + "Adapter deleted, stopped", + extra={"adapter": adapter_name}, + ) + return + + if not new_config.enabled or new_config.is_paused: + # Adapter disabled or paused + if current_state: + await self._stop_adapter(adapter_name) + logger.info( + "Adapter disabled/paused, stopped", + extra={ + "adapter": adapter_name, + "enabled": new_config.enabled, + "paused": new_config.is_paused, + }, + ) + return + + if current_state is None: + # Adapter was enabled or created + await self._start_adapter(new_config) + logger.info( + "Adapter enabled, started", + extra={"adapter": adapter_name}, + ) + else: + # Adapter config changed (cadence, settings) + await self._reschedule_adapter(adapter_name, new_config) + + async def _heartbeat_loop(self) -> None: + """Publish periodic heartbeats.""" + while not self._shutdown_event.is_set(): + uptime = (datetime.now(timezone.utc) - self._start_time).total_seconds() + await self._publish_meta( + "central.meta.heartbeat", + {"ts": datetime.now(timezone.utc).isoformat(), "uptime_s": uptime} + ) + try: + await asyncio.wait_for( + self._shutdown_event.wait(), + timeout=30 + ) + except asyncio.TimeoutError: + pass + + async def start(self) -> None: + """Start the supervisor.""" + await self.connect() + + # Load and start enabled adapters + enabled_adapters = await self._config_source.list_enabled_adapters() + for config in enabled_adapters: + await self._start_adapter(config) + + # Start config watcher (for DB source, this runs forever; for TOML, returns immediately) + self._config_watch_task = asyncio.create_task( + self._config_source.watch_for_changes(self._on_config_change) + ) + + # Start heartbeat + self._tasks.append(asyncio.create_task(self._heartbeat_loop())) + + logger.info( + "Supervisor started", + extra={"adapters": list(self._adapter_states.keys())}, + ) + + async def stop(self) -> None: + """Stop the supervisor gracefully.""" + logger.info("Supervisor shutting down") + self._shutdown_event.set() + + # Cancel config watcher + if self._config_watch_task: + self._config_watch_task.cancel() + try: + await self._config_watch_task + except asyncio.CancelledError: + pass + + # Cancel heartbeat and other tasks + for task in self._tasks: + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + # Stop all adapters + for name in list(self._adapter_states.keys()): + await self._stop_adapter(name) + + # Close config source + await self._config_source.close() + + await self.disconnect() + logger.info("Supervisor stopped") + + +async def async_main() -> None: + """Async entry point.""" + setup_logging() + + settings = get_settings() + logger.info( + "Config source: %s", + settings.config_source, + extra={"config_source": settings.config_source}, + ) + + # Create config source based on setting + config_source = await create_config_source( + source_type=settings.config_source, + dsn=settings.db_dsn, + toml_path=settings.config_toml_path, + ) + + # Load CloudEvents config for envelope generation + # For now, load from TOML regardless of config source + # (CloudEvents config is not adapter-specific) + try: + toml_config = load_config(str(settings.config_toml_path)) + cloudevents_config = toml_config + except Exception: + # If TOML doesn't exist and using DB source, create minimal config + cloudevents_config = None + + supervisor = Supervisor( + config_source=config_source, + nats_url=settings.nats_url, + cloudevents_config=cloudevents_config, + ) + + loop = asyncio.get_running_loop() + shutdown_event = asyncio.Event() + + def handle_signal() -> None: + shutdown_event.set() + + for sig in (signal.SIGTERM, signal.SIGINT): + loop.add_signal_handler(sig, handle_signal) + + await supervisor.start() + + # Wait for shutdown signal + await shutdown_event.wait() + + await supervisor.stop() + + +def main() -> None: + """Entry point.""" + asyncio.run(async_main()) + + +if __name__ == "__main__": + main()