feat(supervisor): add hot-reload support with rate-limit guarantee

Refactors supervisor to use ConfigSource abstraction:
- AdapterState tracks last_completed_poll for rate limiting
- Hot-reload via NOTIFY: cadence/enable/disable changes take effect
- Rate-limit guarantee: next poll at last_poll + new_cadence, not now
- Logs config source at startup (toml or db)
- Logs reschedule decisions with next poll timestamp

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Ubuntu 2026-05-16 01:55:33 +00:00
commit 29fa49c5c2

View file

@ -1,255 +1,539 @@
"""Central supervisor - adapter scheduler and event publisher.""" """Central supervisor - adapter scheduler and event publisher."""
import asyncio import asyncio
import json import json
import logging import logging
import signal import signal
import sys import sys
from datetime import datetime, timezone from dataclasses import dataclass, field
from pathlib import Path from datetime import datetime, timezone
from typing import Any from pathlib import Path
from typing import Any
import nats
from nats.js import JetStreamContext import nats
from nats.js import JetStreamContext
from central.adapters.nws import NWSAdapter
from central.cloudevents_wire import wrap_event from central.adapters.nws import NWSAdapter
from central.config import load_config, Config from central.cloudevents_wire import wrap_event
from central.models import subject_for_event from central.config import load_config, Config, NWSAdapterConfig
from central.config_models import AdapterConfig
CURSOR_DB_PATH = Path("/var/lib/central/cursors.db") from central.config_source import ConfigSource, create_config_source
CONFIG_PATH = "/etc/central/central.toml" from central.bootstrap_config import get_settings
from central.models import subject_for_event
class JsonFormatter(logging.Formatter): CURSOR_DB_PATH = Path("/var/lib/central/cursors.db")
"""JSON log formatter for structured logging."""
def format(self, record: logging.LogRecord) -> str: class JsonFormatter(logging.Formatter):
log_obj: dict[str, Any] = { """JSON log formatter for structured logging."""
"ts": datetime.now(timezone.utc).isoformat(),
"level": record.levelname, def format(self, record: logging.LogRecord) -> str:
"logger": record.name, log_obj: dict[str, Any] = {
"msg": record.getMessage(), "ts": datetime.now(timezone.utc).isoformat(),
} "level": record.levelname,
if record.exc_info: "logger": record.name,
log_obj["exc"] = self.formatException(record.exc_info) "msg": record.getMessage(),
if hasattr(record, "extra"): }
log_obj.update(record.extra) if record.exc_info:
# Include any extra fields passed via extra={} log_obj["exc"] = self.formatException(record.exc_info)
for key in record.__dict__: if hasattr(record, "extra"):
if key not in ( log_obj.update(record.extra)
"name", "msg", "args", "created", "filename", "funcName", for key in record.__dict__:
"levelname", "levelno", "lineno", "module", "msecs", if key not in (
"pathname", "process", "processName", "relativeCreated", "name", "msg", "args", "created", "filename", "funcName",
"stack_info", "exc_info", "exc_text", "thread", "threadName", "levelname", "levelno", "lineno", "module", "msecs",
"taskName", "message", "pathname", "process", "processName", "relativeCreated",
): "stack_info", "exc_info", "exc_text", "thread", "threadName",
log_obj[key] = record.__dict__[key] "taskName", "message",
return json.dumps(log_obj) ):
log_obj[key] = record.__dict__[key]
return json.dumps(log_obj)
def setup_logging() -> None:
"""Configure JSON logging to stdout."""
handler = logging.StreamHandler(sys.stdout) def setup_logging() -> None:
handler.setFormatter(JsonFormatter()) """Configure JSON logging to stdout."""
logging.root.handlers = [handler] handler = logging.StreamHandler(sys.stdout)
logging.root.setLevel(logging.INFO) handler.setFormatter(JsonFormatter())
logging.root.handlers = [handler]
logging.root.setLevel(logging.INFO)
logger = logging.getLogger("central.supervisor")
logger = logging.getLogger("central.supervisor")
class Supervisor:
"""Main supervisor process."""
@dataclass
def __init__(self, config: Config) -> None: class AdapterState:
self.config = config """Runtime state for a scheduled adapter."""
self._nc: nats.NATS | None = None
self._js: JetStreamContext | None = None name: str
self._adapters: list[NWSAdapter] = [] adapter: NWSAdapter
self._tasks: list[asyncio.Task[None]] = [] config: AdapterConfig
self._shutdown_event = asyncio.Event() task: asyncio.Task[None] | None = None
self._start_time = datetime.now(timezone.utc) last_completed_poll: datetime | None = None
cancel_event: asyncio.Event = field(default_factory=asyncio.Event)
async def connect(self) -> None:
"""Connect to NATS.""" def __post_init__(self) -> None:
self._nc = await nats.connect(self.config.nats.url) # Ensure cancel_event is created
self._js = self._nc.jetstream() if self.cancel_event is None:
logger.info("Connected to NATS", extra={"url": self.config.nats.url}) self.cancel_event = asyncio.Event()
async def disconnect(self) -> None:
"""Disconnect from NATS.""" class Supervisor:
if self._nc: """Main supervisor process."""
await self._nc.drain()
await self._nc.close() def __init__(
self._nc = None self,
self._js = None config_source: ConfigSource,
logger.info("Disconnected from NATS") nats_url: str,
cloudevents_config: Any = None,
async def _publish_meta(self, subject: str, data: dict[str, Any]) -> None: ) -> None:
"""Publish a meta event (no Nats-Msg-Id).""" self._config_source = config_source
if not self._nc: self._nats_url = nats_url
return self._cloudevents_config = cloudevents_config
payload = json.dumps(data).encode() self._nc: nats.NATS | None = None
await self._nc.publish(subject, payload) self._js: JetStreamContext | None = None
self._adapter_states: dict[str, AdapterState] = {}
async def _publish_event(self, subject: str, envelope: dict[str, Any], msg_id: str) -> None: self._tasks: list[asyncio.Task[None]] = []
"""Publish an event with dedup header.""" self._shutdown_event = asyncio.Event()
if not self._js: self._start_time = datetime.now(timezone.utc)
return self._config_watch_task: asyncio.Task[None] | None = None
payload = json.dumps(envelope).encode() self._lock = asyncio.Lock()
await self._js.publish(
subject, async def connect(self) -> None:
payload, """Connect to NATS."""
headers={"Nats-Msg-Id": msg_id}, self._nc = await nats.connect(self._nats_url)
) self._js = self._nc.jetstream()
logger.info("Connected to NATS", extra={"url": self._nats_url})
async def _run_adapter(self, adapter: NWSAdapter) -> None:
"""Run an adapter poll loop.""" async def disconnect(self) -> None:
while not self._shutdown_event.is_set(): """Disconnect from NATS."""
poll_start = datetime.now(timezone.utc) if self._nc:
try: await self._nc.drain()
async for event in adapter.poll(): await self._nc.close()
# Dedup check self._nc = None
if adapter.is_published(event.id): self._js = None
adapter.bump_last_seen(event.id) logger.info("Disconnected from NATS")
continue
async def _publish_meta(self, subject: str, data: dict[str, Any]) -> None:
# Build CloudEvent """Publish a meta event (no Nats-Msg-Id)."""
envelope, msg_id = wrap_event(event, self.config) if not self._nc:
subject = subject_for_event(event) return
payload = json.dumps(data).encode()
# Publish await self._nc.publish(subject, payload)
await self._publish_event(subject, envelope, msg_id)
adapter.mark_published(event.id) async def _publish_event(self, subject: str, envelope: dict[str, Any], msg_id: str) -> None:
"""Publish an event with dedup header."""
logger.info( if not self._js:
"Published event", return
extra={"id": event.id, "subject": subject, "category": event.category} payload = json.dumps(envelope).encode()
) await self._js.publish(
subject,
# Publish success status payload,
await self._publish_meta( headers={"Nats-Msg-Id": msg_id},
f"central.meta.adapter.{adapter.name}.status", )
{"ok": True, "ts": datetime.now(timezone.utc).isoformat()}
) def _adapter_config_to_nws_config(self, config: AdapterConfig) -> NWSAdapterConfig:
"""Convert unified AdapterConfig to NWSAdapterConfig."""
except Exception as e: return NWSAdapterConfig(
logger.exception("Adapter poll failed", extra={"adapter": adapter.name}) enabled=config.enabled,
await self._publish_meta( cadence_s=config.cadence_s,
f"central.meta.adapter.{adapter.name}.status", states=config.settings.get("states", []),
{ contact_email=config.settings.get("contact_email", ""),
"ok": False, )
"error": str(e),
"ts": datetime.now(timezone.utc).isoformat() async def _run_adapter_loop(self, state: AdapterState) -> None:
} """Run an adapter poll loop with rate-limit aware scheduling."""
) while not self._shutdown_event.is_set():
# Calculate next poll time based on rate-limit guarantee
# Sweep old IDs now = datetime.now(timezone.utc)
swept = adapter.sweep_old_ids()
if swept > 0: if state.last_completed_poll is not None:
logger.info("Swept old published IDs", extra={"count": swept}) next_poll_at = state.last_completed_poll.timestamp() + state.config.cadence_s
wait_time = max(0, next_poll_at - now.timestamp())
# Sleep until next cadence else:
elapsed = (datetime.now(timezone.utc) - poll_start).total_seconds() # First poll - run immediately
sleep_time = max(0, adapter.cadence_s - elapsed) wait_time = 0
try:
await asyncio.wait_for( if wait_time > 0:
self._shutdown_event.wait(), logger.debug(
timeout=sleep_time "Waiting for next poll",
) extra={
except asyncio.TimeoutError: "adapter": state.name,
pass "wait_s": wait_time,
"next_poll": datetime.fromtimestamp(
async def _heartbeat_loop(self) -> None: now.timestamp() + wait_time, tz=timezone.utc
"""Publish periodic heartbeats.""" ).isoformat(),
while not self._shutdown_event.is_set(): },
uptime = (datetime.now(timezone.utc) - self._start_time).total_seconds() )
await self._publish_meta( # Wait for either timeout or cancel signal
"central.meta.heartbeat", try:
{"ts": datetime.now(timezone.utc).isoformat(), "uptime_s": uptime} await asyncio.wait_for(
) state.cancel_event.wait(),
try: timeout=wait_time,
await asyncio.wait_for( )
self._shutdown_event.wait(), # Cancel event was set - check if we should exit or reschedule
timeout=30 if self._shutdown_event.is_set():
) break
except asyncio.TimeoutError: # Clear the cancel event and re-evaluate schedule
pass state.cancel_event.clear()
continue
async def start(self) -> None: except asyncio.TimeoutError:
"""Start the supervisor.""" pass
await self.connect()
# Check shutdown before polling
# Initialize adapters if self._shutdown_event.is_set():
if self.config.adapters.get("nws") and self.config.adapters["nws"].enabled: break
adapter = NWSAdapter(
config=self.config.adapters["nws"], # Check if adapter is still enabled
cursor_db_path=CURSOR_DB_PATH, if not state.config.enabled or state.config.is_paused:
) logger.info(
await adapter.startup() "Adapter disabled/paused, stopping loop",
self._adapters.append(adapter) extra={"adapter": state.name},
logger.info("NWS adapter initialized") )
break
# Start adapter tasks
for adapter in self._adapters: poll_start = datetime.now(timezone.utc)
task = asyncio.create_task(self._run_adapter(adapter)) try:
self._tasks.append(task) async for event in state.adapter.poll():
# Dedup check
# Start heartbeat if state.adapter.is_published(event.id):
self._tasks.append(asyncio.create_task(self._heartbeat_loop())) state.adapter.bump_last_seen(event.id)
continue
logger.info("Supervisor started", extra={"adapters": [a.name for a in self._adapters]})
# Build CloudEvent
async def stop(self) -> None: if self._cloudevents_config:
"""Stop the supervisor gracefully.""" envelope, msg_id = wrap_event(event, self._cloudevents_config)
logger.info("Supervisor shutting down") else:
self._shutdown_event.set() # Fallback for testing
envelope = {"id": event.id, "data": event.data}
# Cancel tasks msg_id = event.id
for task in self._tasks:
task.cancel() subject = subject_for_event(event)
try:
await task # Publish
except asyncio.CancelledError: await self._publish_event(subject, envelope, msg_id)
pass state.adapter.mark_published(event.id)
# Shutdown adapters logger.info(
for adapter in self._adapters: "Published event",
await adapter.shutdown() extra={"id": event.id, "subject": subject, "category": event.category}
)
await self.disconnect()
logger.info("Supervisor stopped") # Publish success status
await self._publish_meta(
f"central.meta.adapter.{state.name}.status",
async def async_main() -> None: {"ok": True, "ts": datetime.now(timezone.utc).isoformat()}
"""Async entry point.""" )
setup_logging()
# Mark poll completion time for rate limiting
config = load_config(CONFIG_PATH) state.last_completed_poll = datetime.now(timezone.utc)
supervisor = Supervisor(config)
except Exception as e:
loop = asyncio.get_running_loop() logger.exception("Adapter poll failed", extra={"adapter": state.name})
shutdown_event = asyncio.Event() await self._publish_meta(
f"central.meta.adapter.{state.name}.status",
def handle_signal() -> None: {
shutdown_event.set() "ok": False,
"error": str(e),
for sig in (signal.SIGTERM, signal.SIGINT): "ts": datetime.now(timezone.utc).isoformat()
loop.add_signal_handler(sig, handle_signal) }
)
await supervisor.start() # Still mark completion time to avoid tight retry loops
state.last_completed_poll = datetime.now(timezone.utc)
# Wait for shutdown signal
await shutdown_event.wait() # Sweep old IDs
swept = state.adapter.sweep_old_ids()
await supervisor.stop() if swept > 0:
logger.info("Swept old published IDs", extra={"count": swept})
def main() -> None: async def _start_adapter(self, config: AdapterConfig) -> None:
"""Entry point.""" """Start an adapter based on its configuration."""
asyncio.run(async_main()) if config.name in self._adapter_states:
logger.warning(
"Adapter already running",
if __name__ == "__main__": extra={"adapter": config.name},
main() )
return
if config.name == "nws":
nws_config = self._adapter_config_to_nws_config(config)
adapter = NWSAdapter(
config=nws_config,
cursor_db_path=CURSOR_DB_PATH,
)
await adapter.startup()
state = AdapterState(
name=config.name,
adapter=adapter,
config=config,
)
state.task = asyncio.create_task(self._run_adapter_loop(state))
self._adapter_states[config.name] = state
logger.info(
"Adapter started",
extra={
"adapter": config.name,
"cadence_s": config.cadence_s,
},
)
else:
logger.warning(
"Unknown adapter type",
extra={"adapter": config.name},
)
async def _stop_adapter(self, name: str) -> None:
    """Stop a running adapter and shut it down.

    The adapter is removed from the registry, its poll task is cancelled
    and awaited, and ``adapter.shutdown()`` is called. Does nothing when no
    adapter named *name* is registered.

    If the poll task had already died with an exception, that exception is
    logged here instead of being re-raised, so the adapter is still shut
    down and the caller's own shutdown sequence is not interrupted.
    """
    state = self._adapter_states.pop(name, None)
    if state is None:
        return

    # Wake the loop in case it is sleeping between polls.
    state.cancel_event.set()

    if state.task is not None:
        state.task.cancel()
        try:
            await state.task
        except asyncio.CancelledError:
            pass
        except Exception:
            # The loop crashed earlier; awaiting it surfaces that error.
            logger.exception(
                "Adapter task had failed before stop",
                extra={"adapter": name},
            )

    await state.adapter.shutdown()
    logger.info("Adapter stopped", extra={"adapter": name})
async def _reschedule_adapter(
    self,
    name: str,
    new_config: AdapterConfig,
) -> None:
    """Apply *new_config* to a running adapter and wake its poll loop.

    Maintains the rate-limit guarantee: the next poll is due at
    ``last_completed_poll + new_cadence_s``, not ``now + new_cadence_s``.
    If the adapter is not running it is simply started with *new_config*.

    The ``next_poll`` value that is logged is clamped to the current time,
    because a due time that already lies in the past means "poll now".
    """
    state = self._adapter_states.get(name)
    if state is None:
        # Adapter not running - just start it.
        await self._start_adapter(new_config)
        return

    old_cadence = state.config.cadence_s
    new_cadence = new_config.cadence_s

    # The poll loop reads cadence from state.config on every iteration.
    state.config = new_config
    state.adapter.cadence_s = new_cadence

    # Push adapter-specific settings that can change at runtime.
    if name == "nws":
        nws_config = self._adapter_config_to_nws_config(new_config)
        state.adapter.states = {s.upper() for s in nws_config.states}

    # Work out the effective next poll time, for logging only.
    now = datetime.now(timezone.utc)
    if state.last_completed_poll is not None:
        due_ts = state.last_completed_poll.timestamp() + new_cadence
        next_poll_at = datetime.fromtimestamp(
            max(due_ts, now.timestamp()),
            tz=timezone.utc,
        )
    else:
        # No poll has completed yet, so the first poll runs immediately.
        next_poll_at = now

    logger.info(
        "Rescheduled adapter",
        extra={
            "adapter": name,
            "old_cadence_s": old_cadence,
            "new_cadence_s": new_cadence,
            "next_poll": next_poll_at.isoformat(),
        },
    )

    # Signal the loop to re-evaluate its schedule.
    state.cancel_event.set()
async def _on_config_change(self, table: str, key: str) -> None:
    """Handle a configuration change notification.

    Called when NOTIFY fires for config changes. Only the ``"adapters"``
    table is handled; *key* is the adapter name. The current config is
    re-read from the config source and the adapter is stopped, started or
    rescheduled to match it. The whole reconciliation runs under
    ``self._lock`` so overlapping notifications are applied one at a time.
    """
    if table != "adapters":
        return

    adapter_name = key
    logger.info(
        "Config change received",
        extra={"table": table, "key": key},
    )

    async with self._lock:
        # Fetch the current config for this adapter.
        new_config = await self._config_source.get_adapter(adapter_name)
        current_state = self._adapter_states.get(adapter_name)

        if new_config is None:
            # Adapter was deleted.
            if current_state:
                await self._stop_adapter(adapter_name)
                logger.info(
                    "Adapter deleted, stopped",
                    extra={"adapter": adapter_name},
                )
            return

        if not new_config.enabled or new_config.is_paused:
            # Adapter disabled or paused.
            if current_state:
                await self._stop_adapter(adapter_name)
                logger.info(
                    "Adapter disabled/paused, stopped",
                    extra={
                        "adapter": adapter_name,
                        "enabled": new_config.enabled,
                        "paused": new_config.is_paused,
                    },
                )
            return

        if current_state is None:
            # Adapter was enabled or created.
            await self._start_adapter(new_config)
            # _start_adapter can decline (e.g. unknown adapter type); only
            # report success when the adapter was actually registered.
            if adapter_name in self._adapter_states:
                logger.info(
                    "Adapter enabled, started",
                    extra={"adapter": adapter_name},
                )
        else:
            # Adapter config changed (cadence, settings).
            await self._reschedule_adapter(adapter_name, new_config)
async def _heartbeat_loop(self) -> None:
    """Publish a heartbeat every 30 seconds until shutdown is requested.

    Each heartbeat goes to ``central.meta.heartbeat`` with the current UTC
    timestamp and the supervisor uptime in seconds. A failed publish is
    logged and the loop carries on, so one publish error does not end
    heartbeats for the rest of the process lifetime.
    """
    while not self._shutdown_event.is_set():
        now = datetime.now(timezone.utc)
        uptime = (now - self._start_time).total_seconds()
        try:
            await self._publish_meta(
                "central.meta.heartbeat",
                {"ts": now.isoformat(), "uptime_s": uptime},
            )
        except Exception:
            logger.exception("Heartbeat publish failed")

        # Sleep for the interval, but wake immediately on shutdown.
        try:
            await asyncio.wait_for(
                self._shutdown_event.wait(),
                timeout=30,
            )
        except asyncio.TimeoutError:
            pass
async def start(self) -> None:
    """Connect to NATS, launch enabled adapters, then start background tasks.

    Order: NATS connection, one poll loop per enabled adapter, the config
    watcher, and finally the heartbeat.
    """
    await self.connect()

    # Bring up every adapter the config source reports as enabled.
    for adapter_config in await self._config_source.list_enabled_adapters():
        await self._start_adapter(adapter_config)

    # Config watcher (for DB source, this runs forever; for TOML, returns
    # immediately).
    watcher = self._config_source.watch_for_changes(self._on_config_change)
    self._config_watch_task = asyncio.create_task(watcher)

    heartbeat = asyncio.create_task(self._heartbeat_loop())
    self._tasks.append(heartbeat)

    running = list(self._adapter_states)
    logger.info("Supervisor started", extra={"adapters": running})
async def stop(self) -> None:
    """Stop the supervisor gracefully.

    Sets the shutdown event, cancels the config watcher and the other
    background tasks, stops every adapter, closes the config source and
    disconnects from NATS.

    A failure in one step is logged and the remaining steps still run.
    Without this, awaiting a background task that had already failed would
    re-raise its exception and abort shutdown before the adapters were
    stopped and the NATS connection drained.
    """
    logger.info("Supervisor shutting down")
    self._shutdown_event.set()

    # Config watcher first, then heartbeat and any other tasks.
    background: list[asyncio.Task[None]] = []
    if self._config_watch_task is not None:
        background.append(self._config_watch_task)
    background.extend(self._tasks)

    for task in background:
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass
        except Exception:
            # The task died earlier; awaiting it surfaces that error.
            logger.exception("Background task failed before shutdown")
    self._config_watch_task = None
    self._tasks.clear()

    # Snapshot the names: _stop_adapter removes entries from the dict.
    for name in list(self._adapter_states):
        try:
            await self._stop_adapter(name)
        except Exception:
            logger.exception("Failed to stop adapter", extra={"adapter": name})

    try:
        await self._config_source.close()
    except Exception:
        logger.exception("Failed to close config source")

    await self.disconnect()
    logger.info("Supervisor stopped")
async def async_main() -> None:
    """Async entry point: build the supervisor and run it until signalled.

    Reads bootstrap settings, creates the configured config source, tries
    to load the CloudEvents config from the TOML file, then runs the
    supervisor until SIGTERM or SIGINT arrives. ``supervisor.stop()`` runs
    even when startup fails, so the config source and any NATS connection
    are released on every exit path.
    """
    setup_logging()

    settings = get_settings()
    logger.info(
        "Config source: %s",
        settings.config_source,
        extra={"config_source": settings.config_source},
    )

    # Create config source based on setting.
    config_source = await create_config_source(
        source_type=settings.config_source,
        dsn=settings.db_dsn,
        toml_path=settings.config_toml_path,
    )

    # CloudEvents config is not adapter-specific, so for now it is loaded
    # from TOML regardless of the config source.
    try:
        cloudevents_config = load_config(str(settings.config_toml_path))
    except Exception:
        # Without this config the supervisor publishes its minimal fallback
        # envelope, so make the degradation visible rather than silent.
        logger.warning(
            "Could not load CloudEvents config from TOML; "
            "events will use the minimal fallback envelope",
            extra={"toml_path": str(settings.config_toml_path)},
            exc_info=True,
        )
        cloudevents_config = None

    supervisor = Supervisor(
        config_source=config_source,
        nats_url=settings.nats_url,
        cloudevents_config=cloudevents_config,
    )

    loop = asyncio.get_running_loop()
    shutdown_event = asyncio.Event()

    def handle_signal() -> None:
        shutdown_event.set()

    for sig in (signal.SIGTERM, signal.SIGINT):
        loop.add_signal_handler(sig, handle_signal)

    try:
        await supervisor.start()
        # Wait for shutdown signal.
        await shutdown_event.wait()
    finally:
        await supervisor.stop()
def main() -> None:
    """Synchronous entry point: run ``async_main`` on a fresh event loop."""
    entry_coro = async_main()
    asyncio.run(entry_coro)


# Start the supervisor only when this file is executed as a script.
if __name__ == "__main__":
    main()