central/tests/test_supervisor_hotreload.py

357 lines
12 KiB
Python
Raw Normal View History

"""Tests for supervisor hot-reload and rate-limiting behavior."""
import asyncio
import base64
import os
from collections.abc import AsyncIterator, Iterator
from datetime import datetime, timedelta, timezone
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch

import asyncpg
import pytest
import pytest_asyncio

from central.config_models import AdapterConfig
from central.config_source import DbConfigSource
from central.config_store import ConfigStore
from central.crypto import KEY_SIZE, clear_key_cache
# DSN of the throwaway PostgreSQL database used by these tests.
# Export CENTRAL_TEST_DB_DSN to aim the suite at a different server.
TEST_DB_DSN = os.getenv(
    "CENTRAL_TEST_DB_DSN",
    "postgresql://central_test:testpass@localhost/central_test",
)
@pytest.fixture(scope="session")
def master_key_path(tmp_path_factory: pytest.TempPathFactory) -> Path:
    """Write one random master key for the whole session and return its path.

    The key is ``KEY_SIZE`` random bytes, stored base64-encoded as text in
    ``<session tmp>/keys*/master.key``.
    """
    raw_key = os.urandom(KEY_SIZE)
    encoded_key = base64.b64encode(raw_key).decode()
    keys_dir = tmp_path_factory.mktemp("keys")
    path = keys_dir / "master.key"
    path.write_text(encoded_key)
    return path
@pytest.fixture(autouse=True)
def setup_master_key(
    master_key_path: Path, monkeypatch: pytest.MonkeyPatch
) -> Iterator[None]:
    """Point every test at the test database and the session master key.

    Sets ``CENTRAL_DB_DSN`` and ``CENTRAL_MASTER_KEY_PATH`` via ``monkeypatch``
    (restored automatically after the test).

    ``clear_key_cache()`` is called both before and after the test. The call
    before ensures the test loads the key from the path configured here. The
    call after ensures nothing cached during the test outlives the environment
    that produced it and leaks into later tests.
    NOTE(review): assumes clear_key_cache() drops a cached master key.
    """
    clear_key_cache()
    monkeypatch.setenv("CENTRAL_DB_DSN", TEST_DB_DSN)
    monkeypatch.setenv("CENTRAL_MASTER_KEY_PATH", str(master_key_path))
    yield
    # Runs before monkeypatch's own teardown (pytest tears fixtures down in
    # reverse setup order), i.e. while the test env vars are still set.
    clear_key_cache()
@pytest_asyncio.fixture
async def db_conn() -> AsyncIterator[asyncpg.Connection]:
    """Yield a direct asyncpg connection to the test database.

    Used for schema setup and for writes that bypass ``ConfigStore``.
    The connection is closed during fixture teardown, after the test.
    """
    conn = await asyncpg.connect(TEST_DB_DSN)
    yield conn
    await conn.close()
@pytest_asyncio.fixture
async def clean_config_schema(db_conn: asyncpg.Connection) -> None:
    """Make sure the ``config`` schema objects exist, then empty the table.

    Every statement is safe to re-run (IF NOT EXISTS / OR REPLACE / IF EXISTS).
    The final DELETE leaves ``config.adapters`` empty so each test starts from
    a known state.
    """
    setup_sql = (
        "CREATE SCHEMA IF NOT EXISTS config",
        """
        CREATE TABLE IF NOT EXISTS config.adapters (
            name TEXT PRIMARY KEY,
            enabled BOOLEAN NOT NULL DEFAULT true,
            cadence_s INTEGER NOT NULL,
            settings JSONB NOT NULL DEFAULT '{}'::jsonb,
            paused_at TIMESTAMPTZ,
            updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
        )
        """,
        # Trigger function: sends "<table>:<adapter name>" on the
        # 'config_changed' channel ("<table>:" for tables other than adapters).
        """
        CREATE OR REPLACE FUNCTION config.notify_config_change()
        RETURNS trigger AS $$
        DECLARE
            key_value TEXT;
        BEGIN
            IF TG_TABLE_NAME = 'adapters' THEN
                key_value := COALESCE(NEW.name, OLD.name, '');
            ELSE
                key_value := '';
            END IF;
            PERFORM pg_notify('config_changed', TG_TABLE_NAME || ':' || key_value);
            RETURN COALESCE(NEW, OLD);
        END;
        $$ LANGUAGE plpgsql
        """,
        # Two statements in one execute() call: re-attach the row trigger.
        """
        DROP TRIGGER IF EXISTS adapters_notify ON config.adapters;
        CREATE TRIGGER adapters_notify
            AFTER INSERT OR UPDATE OR DELETE ON config.adapters
            FOR EACH ROW EXECUTE FUNCTION config.notify_config_change()
        """,
        "DELETE FROM config.adapters",
    )
    for statement in setup_sql:
        await db_conn.execute(statement)
@pytest_asyncio.fixture
async def config_store(clean_config_schema: None) -> AsyncIterator[ConfigStore]:
    """Yield a ConfigStore connected to the (freshly cleaned) test database.

    Depends on ``clean_config_schema`` so the schema and trigger exist and
    ``config.adapters`` is empty. The store is closed after the test.
    """
    store = await ConfigStore.create(TEST_DB_DSN)
    yield store
    await store.close()
class TestDbConfigSourceNotifications:
    """DbConfigSource must surface PostgreSQL NOTIFY events to its callback."""

    @pytest.mark.asyncio
    async def test_watch_receives_notifications(
        self,
        config_store: ConfigStore,
        db_conn: asyncpg.Connection,
    ) -> None:
        """watch_for_changes receives NOTIFY when adapter changes."""
        source = DbConfigSource(config_store)
        received: list[tuple[str, str]] = []
        got_notify = asyncio.Event()

        async def on_change(table: str, key: str) -> None:
            # Remember each (table, key) pair and wake the waiting test body.
            received.append((table, key))
            got_notify.set()

        watcher = asyncio.create_task(source.watch_for_changes(on_change))
        try:
            # Give the background listener time to connect before writing.
            await asyncio.sleep(0.2)

            # Write through the raw connection, not the store; the row trigger
            # on config.adapters emits the notification.
            await db_conn.execute("""
                INSERT INTO config.adapters (name, enabled, cadence_s, settings)
                VALUES ('test_adapter', true, 60, '{}'::jsonb)
            """)

            await asyncio.wait_for(got_notify.wait(), timeout=5.0)

            assert len(received) >= 1
            first_event = received[0]
            assert first_event == ("adapters", "test_adapter")
        finally:
            # Stop the watcher and swallow only the cancellation we caused.
            watcher.cancel()
            try:
                await watcher
            except asyncio.CancelledError:
                pass
class TestRateLimitGuarantee:
    """Tests for rate-limit guarantees during hot-reload.

    These tests verify the critical invariant: cadence changes must not
    cause extra API calls before (last_poll + new_cadence).

    Each test builds an ``AdapterState`` and applies the config change a
    reschedule would apply. It then derives the wait from the values stored
    on that state, not from test locals, so the assertions actually depend on
    the state's ``last_completed_poll`` and ``config.cadence_s``.
    """

    @staticmethod
    def _mock_adapter(cadence_s: int) -> MagicMock:
        """Return a MagicMock adapter named "test" with the given cadence."""
        adapter = MagicMock()
        adapter.name = "test"
        adapter.cadence_s = cadence_s
        return adapter

    @staticmethod
    def _config(cadence_s: int) -> AdapterConfig:
        """Return an enabled, unpaused AdapterConfig named "test"."""
        return AdapterConfig(
            name="test",
            enabled=True,
            cadence_s=cadence_s,
            settings={},
            paused_at=None,
            updated_at=datetime.now(timezone.utc),
        )

    @staticmethod
    def _seconds_until_next_poll(state, now: datetime) -> float:
        """Seconds from *now* until the state's next allowed poll, floored at 0.

        The next allowed poll is ``state.last_completed_poll`` plus
        ``state.config.cadence_s`` seconds.
        NOTE(review): assumes AdapterState exposes ``last_completed_poll``
        as an attribute (it is a constructor keyword here).
        """
        next_poll_at = state.last_completed_poll + timedelta(
            seconds=state.config.cadence_s
        )
        return max(0.0, (next_poll_at - now).total_seconds())

    @pytest.mark.asyncio
    async def test_cadence_change_respects_last_poll_time(self) -> None:
        """Changing cadence mid-cycle schedules next poll at last_poll + new_cadence.

        This is the core rate-limit guarantee test (gate 3).
        """
        # Import supervisor module to access AdapterState
        from central.supervisor import AdapterState

        last_poll = datetime.now(timezone.utc) - timedelta(seconds=30)
        state = AdapterState(
            name="test",
            adapter=self._mock_adapter(60),
            config=self._config(60),  # Original cadence
            last_completed_poll=last_poll,
        )

        # Apply a cadence change to 90 seconds the way a reschedule would.
        state.config = self._config(90)
        state.adapter.cadence_s = 90

        actual_wait = self._seconds_until_next_poll(
            state, datetime.now(timezone.utc)
        )

        # last_poll was 30s ago and the new cadence is 90s, so ~60s remain
        # (90 - 30). A 2s tolerance absorbs time elapsed during the test.
        assert abs(actual_wait - 60) < 2, (
            f"Expected ~60s wait, got {actual_wait}s. "
            "Rate limit violated: poll would happen before last_poll + new_cadence"
        )

    @pytest.mark.asyncio
    async def test_cadence_increase_after_gap_polls_immediately(self) -> None:
        """When last_poll + new_cadence is already past, poll immediately.

        If operator increases cadence to 120s after a gap of 150s,
        the poll should happen now (not wait another 120s).
        """
        from central.supervisor import AdapterState

        # Last poll was 150 seconds ago; adapter still carries its old 60s cadence.
        last_poll = datetime.now(timezone.utc) - timedelta(seconds=150)
        state = AdapterState(
            name="test",
            adapter=self._mock_adapter(60),
            config=self._config(120),  # Increased cadence
            last_completed_poll=last_poll,
        )

        wait_time = self._seconds_until_next_poll(
            state, datetime.now(timezone.utc)
        )

        # Since 150 > 120, next poll should be immediate (wait_time ~= 0)
        assert wait_time < 1, (
            f"Expected immediate poll (wait ~0s), got {wait_time}s. "
            "After a gap exceeding new cadence, poll should happen now."
        )

    @pytest.mark.asyncio
    async def test_enable_disable_enable_respects_rate_limit(self) -> None:
        """Re-enabling adapter schedules poll at last_poll + cadence.

        If adapter was disabled for a while and then re-enabled, the next
        poll should be at (last_completed_poll + cadence_s), not immediately
        (unless that time has already passed).
        """
        from central.supervisor import AdapterState

        # Last poll was 30 seconds ago, then the adapter was disabled; the
        # state below carries the re-enabled config.
        last_poll = datetime.now(timezone.utc) - timedelta(seconds=30)
        state = AdapterState(
            name="test",
            adapter=self._mock_adapter(60),
            config=self._config(60),
            last_completed_poll=last_poll,
        )

        wait_time = self._seconds_until_next_poll(
            state, datetime.now(timezone.utc)
        )

        # Should wait ~30 more seconds (60 - 30 = 30)
        assert abs(wait_time - 30) < 2, (
            f"Expected ~30s wait after re-enable, got {wait_time}s. "
            "Rate limit violated on enable→disable→enable sequence."
        )

    @pytest.mark.asyncio
    async def test_multiple_rapid_cadence_changes_no_extra_polls(self) -> None:
        """Multiple rapid cadence changes don't cause extra polls.

        If NOTIFY fires rapidly (60→90→120→90), the final schedule should
        still be based on last_completed_poll + final_cadence.
        """
        from central.supervisor import AdapterState

        # Last poll was 20 seconds ago
        last_poll = datetime.now(timezone.utc) - timedelta(seconds=20)
        state = AdapterState(
            name="test",
            adapter=self._mock_adapter(60),
            config=self._config(60),
            last_completed_poll=last_poll,
        )

        # Simulate rapid cadence changes; the last one must win.
        for cadence in (90, 120, 90):
            state.config = self._config(cadence)
            state.adapter.cadence_s = cadence

        assert state.config.cadence_s == 90
        wait_time = self._seconds_until_next_poll(
            state, datetime.now(timezone.utc)
        )

        # Should wait ~70 seconds (90 - 20 = 70)
        assert abs(wait_time - 70) < 2, (
            f"Expected ~70s wait after rapid changes, got {wait_time}s. "
            "Multiple NOTIFYs should not cause extra polls."
        )