feat(supervisor): add hot-reload support with rate-limit guarantee

Refactors supervisor to use ConfigSource abstraction:
- AdapterState tracks last_completed_poll for rate limiting
- Hot-reload via NOTIFY: cadence/enable/disable changes take effect
- Rate-limit guarantee: next poll at last_completed_poll + new_cadence, not now + new_cadence
- Logs config source at startup (toml or db)
- Logs reschedule decisions with next poll timestamp

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Ubuntu 2026-05-16 01:55:33 +00:00
commit 29fa49c5c2

View file

@ -5,6 +5,7 @@ import json
import logging import logging
import signal import signal
import sys import sys
from dataclasses import dataclass, field
from datetime import datetime, timezone from datetime import datetime, timezone
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
@ -14,11 +15,13 @@ from nats.js import JetStreamContext
from central.adapters.nws import NWSAdapter from central.adapters.nws import NWSAdapter
from central.cloudevents_wire import wrap_event from central.cloudevents_wire import wrap_event
from central.config import load_config, Config from central.config import load_config, Config, NWSAdapterConfig
from central.config_models import AdapterConfig
from central.config_source import ConfigSource, create_config_source
from central.bootstrap_config import get_settings
from central.models import subject_for_event from central.models import subject_for_event
CURSOR_DB_PATH = Path("/var/lib/central/cursors.db") CURSOR_DB_PATH = Path("/var/lib/central/cursors.db")
CONFIG_PATH = "/etc/central/central.toml"
class JsonFormatter(logging.Formatter): class JsonFormatter(logging.Formatter):
@ -35,7 +38,6 @@ class JsonFormatter(logging.Formatter):
log_obj["exc"] = self.formatException(record.exc_info) log_obj["exc"] = self.formatException(record.exc_info)
if hasattr(record, "extra"): if hasattr(record, "extra"):
log_obj.update(record.extra) log_obj.update(record.extra)
# Include any extra fields passed via extra={}
for key in record.__dict__: for key in record.__dict__:
if key not in ( if key not in (
"name", "msg", "args", "created", "filename", "funcName", "name", "msg", "args", "created", "filename", "funcName",
@ -59,23 +61,49 @@ def setup_logging() -> None:
logger = logging.getLogger("central.supervisor") logger = logging.getLogger("central.supervisor")
@dataclass
class AdapterState:
    """Runtime state for a scheduled adapter.

    One instance exists per running adapter. It is shared between the
    adapter's poll loop and the supervisor code that starts, stops and
    reschedules that adapter.
    """

    # Adapter name; the supervisor also uses it as the key for this state.
    name: str
    # Adapter instance whose poll() is driven by the poll loop.
    adapter: NWSAdapter
    # Most recently applied configuration (cadence, enabled/paused, settings).
    config: AdapterConfig
    # Task running the poll loop; None until the loop has been scheduled.
    task: asyncio.Task[None] | None = None
    # UTC time at which the last poll finished, whether it succeeded or
    # failed; None until the first poll has completed.
    last_completed_poll: datetime | None = None
    # Set to wake the poll loop out of its inter-poll wait so it can exit or
    # recompute its schedule. default_factory gives every state its own
    # Event, so no __post_init__ fallback is required.
    cancel_event: asyncio.Event = field(default_factory=asyncio.Event)
class Supervisor: class Supervisor:
"""Main supervisor process.""" """Main supervisor process."""
def __init__(self, config: Config) -> None: def __init__(
self.config = config self,
config_source: ConfigSource,
nats_url: str,
cloudevents_config: Any = None,
) -> None:
self._config_source = config_source
self._nats_url = nats_url
self._cloudevents_config = cloudevents_config
self._nc: nats.NATS | None = None self._nc: nats.NATS | None = None
self._js: JetStreamContext | None = None self._js: JetStreamContext | None = None
self._adapters: list[NWSAdapter] = [] self._adapter_states: dict[str, AdapterState] = {}
self._tasks: list[asyncio.Task[None]] = [] self._tasks: list[asyncio.Task[None]] = []
self._shutdown_event = asyncio.Event() self._shutdown_event = asyncio.Event()
self._start_time = datetime.now(timezone.utc) self._start_time = datetime.now(timezone.utc)
self._config_watch_task: asyncio.Task[None] | None = None
self._lock = asyncio.Lock()
async def connect(self) -> None: async def connect(self) -> None:
"""Connect to NATS.""" """Connect to NATS."""
self._nc = await nats.connect(self.config.nats.url) self._nc = await nats.connect(self._nats_url)
self._js = self._nc.jetstream() self._js = self._nc.jetstream()
logger.info("Connected to NATS", extra={"url": self.config.nats.url}) logger.info("Connected to NATS", extra={"url": self._nats_url})
async def disconnect(self) -> None: async def disconnect(self) -> None:
"""Disconnect from NATS.""" """Disconnect from NATS."""
@ -104,24 +132,87 @@ class Supervisor:
headers={"Nats-Msg-Id": msg_id}, headers={"Nats-Msg-Id": msg_id},
) )
async def _run_adapter(self, adapter: NWSAdapter) -> None: def _adapter_config_to_nws_config(self, config: AdapterConfig) -> NWSAdapterConfig:
"""Run an adapter poll loop.""" """Convert unified AdapterConfig to NWSAdapterConfig."""
return NWSAdapterConfig(
enabled=config.enabled,
cadence_s=config.cadence_s,
states=config.settings.get("states", []),
contact_email=config.settings.get("contact_email", ""),
)
async def _run_adapter_loop(self, state: AdapterState) -> None:
"""Run an adapter poll loop with rate-limit aware scheduling."""
while not self._shutdown_event.is_set(): while not self._shutdown_event.is_set():
# Calculate next poll time based on rate-limit guarantee
now = datetime.now(timezone.utc)
if state.last_completed_poll is not None:
next_poll_at = state.last_completed_poll.timestamp() + state.config.cadence_s
wait_time = max(0, next_poll_at - now.timestamp())
else:
# First poll - run immediately
wait_time = 0
if wait_time > 0:
logger.debug(
"Waiting for next poll",
extra={
"adapter": state.name,
"wait_s": wait_time,
"next_poll": datetime.fromtimestamp(
now.timestamp() + wait_time, tz=timezone.utc
).isoformat(),
},
)
# Wait for either timeout or cancel signal
try:
await asyncio.wait_for(
state.cancel_event.wait(),
timeout=wait_time,
)
# Cancel event was set - check if we should exit or reschedule
if self._shutdown_event.is_set():
break
# Clear the cancel event and re-evaluate schedule
state.cancel_event.clear()
continue
except asyncio.TimeoutError:
pass
# Check shutdown before polling
if self._shutdown_event.is_set():
break
# Check if adapter is still enabled
if not state.config.enabled or state.config.is_paused:
logger.info(
"Adapter disabled/paused, stopping loop",
extra={"adapter": state.name},
)
break
poll_start = datetime.now(timezone.utc) poll_start = datetime.now(timezone.utc)
try: try:
async for event in adapter.poll(): async for event in state.adapter.poll():
# Dedup check # Dedup check
if adapter.is_published(event.id): if state.adapter.is_published(event.id):
adapter.bump_last_seen(event.id) state.adapter.bump_last_seen(event.id)
continue continue
# Build CloudEvent # Build CloudEvent
envelope, msg_id = wrap_event(event, self.config) if self._cloudevents_config:
envelope, msg_id = wrap_event(event, self._cloudevents_config)
else:
# Fallback for testing
envelope = {"id": event.id, "data": event.data}
msg_id = event.id
subject = subject_for_event(event) subject = subject_for_event(event)
# Publish # Publish
await self._publish_event(subject, envelope, msg_id) await self._publish_event(subject, envelope, msg_id)
adapter.mark_published(event.id) state.adapter.mark_published(event.id)
logger.info( logger.info(
"Published event", "Published event",
@ -130,37 +221,194 @@ class Supervisor:
# Publish success status # Publish success status
await self._publish_meta( await self._publish_meta(
f"central.meta.adapter.{adapter.name}.status", f"central.meta.adapter.{state.name}.status",
{"ok": True, "ts": datetime.now(timezone.utc).isoformat()} {"ok": True, "ts": datetime.now(timezone.utc).isoformat()}
) )
# Mark poll completion time for rate limiting
state.last_completed_poll = datetime.now(timezone.utc)
except Exception as e: except Exception as e:
logger.exception("Adapter poll failed", extra={"adapter": adapter.name}) logger.exception("Adapter poll failed", extra={"adapter": state.name})
await self._publish_meta( await self._publish_meta(
f"central.meta.adapter.{adapter.name}.status", f"central.meta.adapter.{state.name}.status",
{ {
"ok": False, "ok": False,
"error": str(e), "error": str(e),
"ts": datetime.now(timezone.utc).isoformat() "ts": datetime.now(timezone.utc).isoformat()
} }
) )
# Still mark completion time to avoid tight retry loops
state.last_completed_poll = datetime.now(timezone.utc)
# Sweep old IDs # Sweep old IDs
swept = adapter.sweep_old_ids() swept = state.adapter.sweep_old_ids()
if swept > 0: if swept > 0:
logger.info("Swept old published IDs", extra={"count": swept}) logger.info("Swept old published IDs", extra={"count": swept})
# Sleep until next cadence async def _start_adapter(self, config: AdapterConfig) -> None:
elapsed = (datetime.now(timezone.utc) - poll_start).total_seconds() """Start an adapter based on its configuration."""
sleep_time = max(0, adapter.cadence_s - elapsed) if config.name in self._adapter_states:
logger.warning(
"Adapter already running",
extra={"adapter": config.name},
)
return
if config.name == "nws":
nws_config = self._adapter_config_to_nws_config(config)
adapter = NWSAdapter(
config=nws_config,
cursor_db_path=CURSOR_DB_PATH,
)
await adapter.startup()
state = AdapterState(
name=config.name,
adapter=adapter,
config=config,
)
state.task = asyncio.create_task(self._run_adapter_loop(state))
self._adapter_states[config.name] = state
logger.info(
"Adapter started",
extra={
"adapter": config.name,
"cadence_s": config.cadence_s,
},
)
else:
logger.warning(
"Unknown adapter type",
extra={"adapter": config.name},
)
async def _stop_adapter(self, name: str) -> None:
"""Stop a running adapter."""
state = self._adapter_states.pop(name, None)
if state is None:
return
# Signal the loop to stop
state.cancel_event.set()
if state.task:
state.task.cancel()
try: try:
await asyncio.wait_for( await state.task
self._shutdown_event.wait(), except asyncio.CancelledError:
timeout=sleep_time
)
except asyncio.TimeoutError:
pass pass
await state.adapter.shutdown()
logger.info("Adapter stopped", extra={"adapter": name})
async def _reschedule_adapter(
self,
name: str,
new_config: AdapterConfig,
) -> None:
"""Reschedule an adapter with new configuration.
Maintains rate-limit guarantee: next poll at
(last_completed_poll + new_cadence_s), not now + new_cadence_s.
"""
state = self._adapter_states.get(name)
if state is None:
# Adapter not running - just start it
await self._start_adapter(new_config)
return
old_cadence = state.config.cadence_s
new_cadence = new_config.cadence_s
# Update config
state.config = new_config
# Update adapter's cadence
state.adapter.cadence_s = new_cadence
# Update adapter settings if needed (e.g., states list)
if name == "nws":
nws_config = self._adapter_config_to_nws_config(new_config)
state.adapter.states = set(s.upper() for s in nws_config.states)
# Calculate next poll time for logging
if state.last_completed_poll:
next_poll_at = datetime.fromtimestamp(
state.last_completed_poll.timestamp() + new_cadence,
tz=timezone.utc,
)
else:
next_poll_at = datetime.now(timezone.utc)
logger.info(
"Rescheduled adapter",
extra={
"adapter": name,
"old_cadence_s": old_cadence,
"new_cadence_s": new_cadence,
"next_poll": next_poll_at.isoformat(),
},
)
# Signal the loop to re-evaluate its schedule
state.cancel_event.set()
async def _on_config_change(self, table: str, key: str) -> None:
"""Handle a configuration change notification.
Called when NOTIFY fires for config changes.
"""
if table != "adapters":
return
adapter_name = key
logger.info(
"Config change received",
extra={"table": table, "key": key},
)
async with self._lock:
# Fetch the current config for this adapter
new_config = await self._config_source.get_adapter(adapter_name)
current_state = self._adapter_states.get(adapter_name)
if new_config is None:
# Adapter was deleted
if current_state:
await self._stop_adapter(adapter_name)
logger.info(
"Adapter deleted, stopped",
extra={"adapter": adapter_name},
)
return
if not new_config.enabled or new_config.is_paused:
# Adapter disabled or paused
if current_state:
await self._stop_adapter(adapter_name)
logger.info(
"Adapter disabled/paused, stopped",
extra={
"adapter": adapter_name,
"enabled": new_config.enabled,
"paused": new_config.is_paused,
},
)
return
if current_state is None:
# Adapter was enabled or created
await self._start_adapter(new_config)
logger.info(
"Adapter enabled, started",
extra={"adapter": adapter_name},
)
else:
# Adapter config changed (cadence, settings)
await self._reschedule_adapter(adapter_name, new_config)
async def _heartbeat_loop(self) -> None: async def _heartbeat_loop(self) -> None:
"""Publish periodic heartbeats.""" """Publish periodic heartbeats."""
while not self._shutdown_event.is_set(): while not self._shutdown_event.is_set():
@ -181,32 +429,38 @@ class Supervisor:
"""Start the supervisor.""" """Start the supervisor."""
await self.connect() await self.connect()
# Initialize adapters # Load and start enabled adapters
if self.config.adapters.get("nws") and self.config.adapters["nws"].enabled: enabled_adapters = await self._config_source.list_enabled_adapters()
adapter = NWSAdapter( for config in enabled_adapters:
config=self.config.adapters["nws"], await self._start_adapter(config)
cursor_db_path=CURSOR_DB_PATH,
)
await adapter.startup()
self._adapters.append(adapter)
logger.info("NWS adapter initialized")
# Start adapter tasks # Start config watcher (for DB source, this runs forever; for TOML, returns immediately)
for adapter in self._adapters: self._config_watch_task = asyncio.create_task(
task = asyncio.create_task(self._run_adapter(adapter)) self._config_source.watch_for_changes(self._on_config_change)
self._tasks.append(task) )
# Start heartbeat # Start heartbeat
self._tasks.append(asyncio.create_task(self._heartbeat_loop())) self._tasks.append(asyncio.create_task(self._heartbeat_loop()))
logger.info("Supervisor started", extra={"adapters": [a.name for a in self._adapters]}) logger.info(
"Supervisor started",
extra={"adapters": list(self._adapter_states.keys())},
)
async def stop(self) -> None: async def stop(self) -> None:
"""Stop the supervisor gracefully.""" """Stop the supervisor gracefully."""
logger.info("Supervisor shutting down") logger.info("Supervisor shutting down")
self._shutdown_event.set() self._shutdown_event.set()
# Cancel tasks # Cancel config watcher
if self._config_watch_task:
self._config_watch_task.cancel()
try:
await self._config_watch_task
except asyncio.CancelledError:
pass
# Cancel heartbeat and other tasks
for task in self._tasks: for task in self._tasks:
task.cancel() task.cancel()
try: try:
@ -214,9 +468,12 @@ class Supervisor:
except asyncio.CancelledError: except asyncio.CancelledError:
pass pass
# Shutdown adapters # Stop all adapters
for adapter in self._adapters: for name in list(self._adapter_states.keys()):
await adapter.shutdown() await self._stop_adapter(name)
# Close config source
await self._config_source.close()
await self.disconnect() await self.disconnect()
logger.info("Supervisor stopped") logger.info("Supervisor stopped")
@ -226,8 +483,35 @@ async def async_main() -> None:
"""Async entry point.""" """Async entry point."""
setup_logging() setup_logging()
config = load_config(CONFIG_PATH) settings = get_settings()
supervisor = Supervisor(config) logger.info(
"Config source: %s",
settings.config_source,
extra={"config_source": settings.config_source},
)
# Create config source based on setting
config_source = await create_config_source(
source_type=settings.config_source,
dsn=settings.db_dsn,
toml_path=settings.config_toml_path,
)
# Load CloudEvents config for envelope generation
# For now, load from TOML regardless of config source
# (CloudEvents config is not adapter-specific)
try:
toml_config = load_config(str(settings.config_toml_path))
cloudevents_config = toml_config
except Exception:
# If TOML doesn't exist and using DB source, create minimal config
cloudevents_config = None
supervisor = Supervisor(
config_source=config_source,
nats_url=settings.nats_url,
cloudevents_config=cloudevents_config,
)
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
shutdown_event = asyncio.Event() shutdown_event = asyncio.Event()