chore: normalize line endings to LF

This commit is contained in:
Matt Johnson 2026-05-16 21:27:30 +00:00
commit 374a8c067f
26 changed files with 5357 additions and 5346 deletions

View file

@ -1,18 +1,18 @@
# Central Tests
## Test Database
Some tests (notably `test_config_store.py`) require a real PostgreSQL database.
By default, tests connect to:
```
postgresql://central_test:testpass@localhost/central_test
```
If your test database uses different credentials, set the `CENTRAL_TEST_DB_DSN`
environment variable:
```bash
export CENTRAL_TEST_DB_DSN="postgresql://myuser:mypass@localhost/mydb"
uv run pytest tests/test_config_store.py
```
# Central Tests
## Test Database
Some tests (notably `test_config_store.py`) require a real PostgreSQL database.
By default, tests connect to:
```
postgresql://central_test:testpass@localhost/central_test
```
If your test database uses different credentials, set the `CENTRAL_TEST_DB_DSN`
environment variable:
```bash
export CENTRAL_TEST_DB_DSN="postgresql://myuser:mypass@localhost/mydb"
uv run pytest tests/test_config_store.py
```

View file

@ -1,123 +1,123 @@
"""Tests for bootstrap configuration."""
import os
from pathlib import Path
from tempfile import NamedTemporaryFile
import pytest
from central.bootstrap_config import Settings, get_settings
class TestSettingsFromEnv:
"""Test loading settings from environment variables."""
def test_reads_from_env(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""Settings are read from CENTRAL_* environment variables."""
monkeypatch.setenv("CENTRAL_DB_DSN", "postgresql://test:pass@localhost/testdb")
monkeypatch.setenv("CENTRAL_NATS_URL", "nats://10.0.0.1:4222")
monkeypatch.setenv("CENTRAL_MASTER_KEY_PATH", "/tmp/test.key")
monkeypatch.setenv("CENTRAL_LOG_LEVEL", "DEBUG")
settings = Settings()
assert settings.db_dsn == "postgresql://test:pass@localhost/testdb"
assert settings.nats_url == "nats://10.0.0.1:4222"
assert settings.master_key_path == Path("/tmp/test.key")
assert settings.log_level == "DEBUG"
def test_defaults_applied(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""Default values are used when env vars not set."""
monkeypatch.setenv("CENTRAL_DB_DSN", "postgresql://x:y@localhost/db")
# Clear any existing env vars that might interfere
monkeypatch.delenv("CENTRAL_NATS_URL", raising=False)
monkeypatch.delenv("CENTRAL_MASTER_KEY_PATH", raising=False)
monkeypatch.delenv("CENTRAL_LOG_LEVEL", raising=False)
settings = Settings()
assert settings.nats_url == "nats://localhost:4222"
assert settings.master_key_path == Path("/etc/central/master.key")
assert settings.log_level == "INFO"
class TestSettingsFromFile:
"""Test loading settings from .env file."""
def test_reads_from_env_file(self, tmp_path: Path) -> None:
"""Settings are read from .env file when env vars not present."""
env_file = tmp_path / ".env"
env_file.write_text(
"CENTRAL_DB_DSN=postgresql://file:pass@localhost/filedb\n"
"CENTRAL_NATS_URL=nats://file.local:4222\n"
"CENTRAL_LOG_LEVEL=WARNING\n"
)
# Create settings pointing to the temp .env file
settings = Settings(_env_file=env_file)
assert settings.db_dsn == "postgresql://file:pass@localhost/filedb"
assert settings.nats_url == "nats://file.local:4222"
assert settings.log_level == "WARNING"
def test_env_vars_override_file(
self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Environment variables take precedence over .env file."""
env_file = tmp_path / ".env"
env_file.write_text("CENTRAL_DB_DSN=postgresql://file@localhost/filedb\n")
monkeypatch.setenv("CENTRAL_DB_DSN", "postgresql://env@localhost/envdb")
settings = Settings(_env_file=env_file)
assert settings.db_dsn == "postgresql://env@localhost/envdb"
class TestSettingsValidation:
"""Test settings validation and error handling."""
def test_fails_if_required_var_missing(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""Clear error when required CENTRAL_DB_DSN is missing."""
# Ensure no env vars or .env file provides the DSN
monkeypatch.delenv("CENTRAL_DB_DSN", raising=False)
with pytest.raises(Exception) as exc_info:
# Use a non-existent .env file path to ensure no fallback
Settings(_env_file=Path("/nonexistent/.env"))
# pydantic-settings raises ValidationError for missing required fields
assert "db_dsn" in str(exc_info.value).lower() or "validation" in str(exc_info.value).lower()
def test_invalid_log_level_rejected(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""Invalid log level values are rejected."""
monkeypatch.setenv("CENTRAL_DB_DSN", "postgresql://x@localhost/db")
monkeypatch.setenv("CENTRAL_LOG_LEVEL", "INVALID")
with pytest.raises(Exception):
Settings()
class TestGetSettings:
"""Test the cached settings loader."""
def test_caches_result(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""get_settings() returns cached instance."""
monkeypatch.setenv("CENTRAL_DB_DSN", "postgresql://cached@localhost/db")
get_settings.cache_clear()
s1 = get_settings()
s2 = get_settings()
assert s1 is s2
def test_cache_clear_reloads(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""cache_clear() forces reload on next call."""
monkeypatch.setenv("CENTRAL_DB_DSN", "postgresql://first@localhost/db")
get_settings.cache_clear()
s1 = get_settings()
monkeypatch.setenv("CENTRAL_DB_DSN", "postgresql://second@localhost/db")
get_settings.cache_clear()
s2 = get_settings()
assert s1.db_dsn != s2.db_dsn
"""Tests for bootstrap configuration."""
import os
from pathlib import Path
from tempfile import NamedTemporaryFile
import pytest
from central.bootstrap_config import Settings, get_settings
class TestSettingsFromEnv:
"""Test loading settings from environment variables."""
def test_reads_from_env(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""Settings are read from CENTRAL_* environment variables."""
monkeypatch.setenv("CENTRAL_DB_DSN", "postgresql://test:pass@localhost/testdb")
monkeypatch.setenv("CENTRAL_NATS_URL", "nats://10.0.0.1:4222")
monkeypatch.setenv("CENTRAL_MASTER_KEY_PATH", "/tmp/test.key")
monkeypatch.setenv("CENTRAL_LOG_LEVEL", "DEBUG")
settings = Settings()
assert settings.db_dsn == "postgresql://test:pass@localhost/testdb"
assert settings.nats_url == "nats://10.0.0.1:4222"
assert settings.master_key_path == Path("/tmp/test.key")
assert settings.log_level == "DEBUG"
def test_defaults_applied(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""Default values are used when env vars not set."""
monkeypatch.setenv("CENTRAL_DB_DSN", "postgresql://x:y@localhost/db")
# Clear any existing env vars that might interfere
monkeypatch.delenv("CENTRAL_NATS_URL", raising=False)
monkeypatch.delenv("CENTRAL_MASTER_KEY_PATH", raising=False)
monkeypatch.delenv("CENTRAL_LOG_LEVEL", raising=False)
settings = Settings()
assert settings.nats_url == "nats://localhost:4222"
assert settings.master_key_path == Path("/etc/central/master.key")
assert settings.log_level == "INFO"
class TestSettingsFromFile:
"""Test loading settings from .env file."""
def test_reads_from_env_file(self, tmp_path: Path) -> None:
"""Settings are read from .env file when env vars not present."""
env_file = tmp_path / ".env"
env_file.write_text(
"CENTRAL_DB_DSN=postgresql://file:pass@localhost/filedb\n"
"CENTRAL_NATS_URL=nats://file.local:4222\n"
"CENTRAL_LOG_LEVEL=WARNING\n"
)
# Create settings pointing to the temp .env file
settings = Settings(_env_file=env_file)
assert settings.db_dsn == "postgresql://file:pass@localhost/filedb"
assert settings.nats_url == "nats://file.local:4222"
assert settings.log_level == "WARNING"
def test_env_vars_override_file(
self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
"""Environment variables take precedence over .env file."""
env_file = tmp_path / ".env"
env_file.write_text("CENTRAL_DB_DSN=postgresql://file@localhost/filedb\n")
monkeypatch.setenv("CENTRAL_DB_DSN", "postgresql://env@localhost/envdb")
settings = Settings(_env_file=env_file)
assert settings.db_dsn == "postgresql://env@localhost/envdb"
class TestSettingsValidation:
"""Test settings validation and error handling."""
def test_fails_if_required_var_missing(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""Clear error when required CENTRAL_DB_DSN is missing."""
# Ensure no env vars or .env file provides the DSN
monkeypatch.delenv("CENTRAL_DB_DSN", raising=False)
with pytest.raises(Exception) as exc_info:
# Use a non-existent .env file path to ensure no fallback
Settings(_env_file=Path("/nonexistent/.env"))
# pydantic-settings raises ValidationError for missing required fields
assert "db_dsn" in str(exc_info.value).lower() or "validation" in str(exc_info.value).lower()
def test_invalid_log_level_rejected(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""Invalid log level values are rejected."""
monkeypatch.setenv("CENTRAL_DB_DSN", "postgresql://x@localhost/db")
monkeypatch.setenv("CENTRAL_LOG_LEVEL", "INVALID")
with pytest.raises(Exception):
Settings()
class TestGetSettings:
"""Test the cached settings loader."""
def test_caches_result(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""get_settings() returns cached instance."""
monkeypatch.setenv("CENTRAL_DB_DSN", "postgresql://cached@localhost/db")
get_settings.cache_clear()
s1 = get_settings()
s2 = get_settings()
assert s1 is s2
def test_cache_clear_reloads(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""cache_clear() forces reload on next call."""
monkeypatch.setenv("CENTRAL_DB_DSN", "postgresql://first@localhost/db")
get_settings.cache_clear()
s1 = get_settings()
monkeypatch.setenv("CENTRAL_DB_DSN", "postgresql://second@localhost/db")
get_settings.cache_clear()
s2 = get_settings()
assert s1.db_dsn != s2.db_dsn

View file

@ -1,132 +1,132 @@
"""Tests for configuration source abstraction."""
import base64
import os
from pathlib import Path
import asyncpg
import pytest
import pytest_asyncio
from central.config_source import (
ConfigSource,
DbConfigSource,
)
from central.crypto import KEY_SIZE, clear_key_cache
# Test database DSN
TEST_DB_DSN = os.environ.get(
"CENTRAL_TEST_DB_DSN",
"postgresql://central_test:testpass@localhost/central_test",
)
@pytest.fixture(scope="session")
def master_key_path(tmp_path_factory: pytest.TempPathFactory) -> Path:
"""Create a master key file for the test session."""
key = os.urandom(KEY_SIZE)
key_path = tmp_path_factory.mktemp("keys") / "master.key"
key_path.write_text(base64.b64encode(key).decode())
return key_path
@pytest.fixture(autouse=True)
def setup_master_key(master_key_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
"""Configure master key path for all tests."""
clear_key_cache()
monkeypatch.setenv("CENTRAL_DB_DSN", TEST_DB_DSN)
monkeypatch.setenv("CENTRAL_MASTER_KEY_PATH", str(master_key_path))
@pytest_asyncio.fixture
async def db_conn() -> asyncpg.Connection:
"""Get a direct database connection for setup/teardown."""
conn = await asyncpg.connect(TEST_DB_DSN)
yield conn
await conn.close()
@pytest_asyncio.fixture
async def clean_config_schema(db_conn: asyncpg.Connection) -> None:
"""Ensure config schema exists and is clean before each test."""
await db_conn.execute("CREATE SCHEMA IF NOT EXISTS config")
await db_conn.execute("""
CREATE TABLE IF NOT EXISTS config.adapters (
name TEXT PRIMARY KEY,
enabled BOOLEAN NOT NULL DEFAULT true,
cadence_s INTEGER NOT NULL,
settings JSONB NOT NULL DEFAULT '{}'::jsonb,
paused_at TIMESTAMPTZ,
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
)
""")
await db_conn.execute("DELETE FROM config.adapters")
class TestDbConfigSource:
"""Tests for database-backed config source."""
@pytest_asyncio.fixture
async def db_source(self, clean_config_schema: None) -> DbConfigSource:
"""Create a DbConfigSource for testing."""
source = await DbConfigSource.create(TEST_DB_DSN)
yield source
await source.close()
@pytest.mark.asyncio
async def test_list_enabled_adapters_empty(self, db_source: DbConfigSource) -> None:
"""list_enabled_adapters returns empty list when no adapters."""
adapters = await db_source.list_enabled_adapters()
assert adapters == []
@pytest.mark.asyncio
async def test_list_enabled_adapters(
self, db_source: DbConfigSource, db_conn: asyncpg.Connection
) -> None:
"""list_enabled_adapters returns only enabled, non-paused adapters."""
# Insert test adapters
await db_conn.execute("""
INSERT INTO config.adapters (name, enabled, cadence_s, settings)
VALUES
('enabled_adapter', true, 60, '{"key": "value"}'::jsonb),
('disabled_adapter', false, 60, '{}'::jsonb),
('paused_adapter', true, 60, '{}'::jsonb)
""")
await db_conn.execute("""
UPDATE config.adapters
SET paused_at = now()
WHERE name = 'paused_adapter'
""")
adapters = await db_source.list_enabled_adapters()
assert len(adapters) == 1
assert adapters[0].name == "enabled_adapter"
@pytest.mark.asyncio
async def test_get_adapter(
self, db_source: DbConfigSource, db_conn: asyncpg.Connection
) -> None:
"""get_adapter returns correct adapter config."""
await db_conn.execute("""
INSERT INTO config.adapters (name, enabled, cadence_s, settings)
VALUES ('test_adapter', true, 120, '{"states": ["ID"]}'::jsonb)
""")
adapter = await db_source.get_adapter("test_adapter")
assert adapter is not None
assert adapter.name == "test_adapter"
assert adapter.cadence_s == 120
assert adapter.settings == {"states": ["ID"]}
@pytest.mark.asyncio
async def test_get_nonexistent_adapter(self, db_source: DbConfigSource) -> None:
"""get_adapter returns None for nonexistent adapter."""
adapter = await db_source.get_adapter("does_not_exist")
assert adapter is None
@pytest.mark.asyncio
async def test_implements_protocol(self, db_source: DbConfigSource) -> None:
"""DbConfigSource implements ConfigSource protocol."""
assert isinstance(db_source, ConfigSource)
"""Tests for configuration source abstraction."""
import base64
import os
from pathlib import Path
import asyncpg
import pytest
import pytest_asyncio
from central.config_source import (
ConfigSource,
DbConfigSource,
)
from central.crypto import KEY_SIZE, clear_key_cache
# Test database DSN
TEST_DB_DSN = os.environ.get(
"CENTRAL_TEST_DB_DSN",
"postgresql://central_test:testpass@localhost/central_test",
)
@pytest.fixture(scope="session")
def master_key_path(tmp_path_factory: pytest.TempPathFactory) -> Path:
"""Create a master key file for the test session."""
key = os.urandom(KEY_SIZE)
key_path = tmp_path_factory.mktemp("keys") / "master.key"
key_path.write_text(base64.b64encode(key).decode())
return key_path
@pytest.fixture(autouse=True)
def setup_master_key(master_key_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
"""Configure master key path for all tests."""
clear_key_cache()
monkeypatch.setenv("CENTRAL_DB_DSN", TEST_DB_DSN)
monkeypatch.setenv("CENTRAL_MASTER_KEY_PATH", str(master_key_path))
@pytest_asyncio.fixture
async def db_conn() -> asyncpg.Connection:
"""Get a direct database connection for setup/teardown."""
conn = await asyncpg.connect(TEST_DB_DSN)
yield conn
await conn.close()
@pytest_asyncio.fixture
async def clean_config_schema(db_conn: asyncpg.Connection) -> None:
"""Ensure config schema exists and is clean before each test."""
await db_conn.execute("CREATE SCHEMA IF NOT EXISTS config")
await db_conn.execute("""
CREATE TABLE IF NOT EXISTS config.adapters (
name TEXT PRIMARY KEY,
enabled BOOLEAN NOT NULL DEFAULT true,
cadence_s INTEGER NOT NULL,
settings JSONB NOT NULL DEFAULT '{}'::jsonb,
paused_at TIMESTAMPTZ,
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
)
""")
await db_conn.execute("DELETE FROM config.adapters")
class TestDbConfigSource:
"""Tests for database-backed config source."""
@pytest_asyncio.fixture
async def db_source(self, clean_config_schema: None) -> DbConfigSource:
"""Create a DbConfigSource for testing."""
source = await DbConfigSource.create(TEST_DB_DSN)
yield source
await source.close()
@pytest.mark.asyncio
async def test_list_enabled_adapters_empty(self, db_source: DbConfigSource) -> None:
"""list_enabled_adapters returns empty list when no adapters."""
adapters = await db_source.list_enabled_adapters()
assert adapters == []
@pytest.mark.asyncio
async def test_list_enabled_adapters(
self, db_source: DbConfigSource, db_conn: asyncpg.Connection
) -> None:
"""list_enabled_adapters returns only enabled, non-paused adapters."""
# Insert test adapters
await db_conn.execute("""
INSERT INTO config.adapters (name, enabled, cadence_s, settings)
VALUES
('enabled_adapter', true, 60, '{"key": "value"}'::jsonb),
('disabled_adapter', false, 60, '{}'::jsonb),
('paused_adapter', true, 60, '{}'::jsonb)
""")
await db_conn.execute("""
UPDATE config.adapters
SET paused_at = now()
WHERE name = 'paused_adapter'
""")
adapters = await db_source.list_enabled_adapters()
assert len(adapters) == 1
assert adapters[0].name == "enabled_adapter"
@pytest.mark.asyncio
async def test_get_adapter(
self, db_source: DbConfigSource, db_conn: asyncpg.Connection
) -> None:
"""get_adapter returns correct adapter config."""
await db_conn.execute("""
INSERT INTO config.adapters (name, enabled, cadence_s, settings)
VALUES ('test_adapter', true, 120, '{"states": ["ID"]}'::jsonb)
""")
adapter = await db_source.get_adapter("test_adapter")
assert adapter is not None
assert adapter.name == "test_adapter"
assert adapter.cadence_s == 120
assert adapter.settings == {"states": ["ID"]}
@pytest.mark.asyncio
async def test_get_nonexistent_adapter(self, db_source: DbConfigSource) -> None:
"""get_adapter returns None for nonexistent adapter."""
adapter = await db_source.get_adapter("does_not_exist")
assert adapter is None
@pytest.mark.asyncio
async def test_implements_protocol(self, db_source: DbConfigSource) -> None:
"""DbConfigSource implements ConfigSource protocol."""
assert isinstance(db_source, ConfigSource)

View file

@ -1,339 +1,339 @@
"""Tests for database-backed configuration store.
These tests require a real Postgres database. Set CENTRAL_TEST_DB_DSN
environment variable to override the default test database connection.
"""
import asyncio
import base64
import os
from pathlib import Path
import asyncpg
import pytest
import pytest_asyncio
from central.config_store import ConfigStore
from central.crypto import KEY_SIZE, clear_key_cache
# Test database DSN - uses central_test database with well-known test password.
# Override via CENTRAL_TEST_DB_DSN env var if your test DB differs.
TEST_DB_DSN = os.environ.get(
"CENTRAL_TEST_DB_DSN",
"postgresql://central_test:testpass@localhost/central_test",
)
@pytest.fixture(scope="session")
def master_key_path(tmp_path_factory: pytest.TempPathFactory) -> Path:
"""Create a master key file for the test session."""
key = os.urandom(KEY_SIZE)
key_path = tmp_path_factory.mktemp("keys") / "master.key"
key_path.write_text(base64.b64encode(key).decode())
return key_path
@pytest.fixture(autouse=True)
def setup_master_key(master_key_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
"""Configure master key path for all tests."""
clear_key_cache()
monkeypatch.setenv("CENTRAL_DB_DSN", TEST_DB_DSN)
monkeypatch.setenv("CENTRAL_MASTER_KEY_PATH", str(master_key_path))
@pytest_asyncio.fixture
async def db_conn() -> asyncpg.Connection:
"""Get a direct database connection for setup/teardown."""
conn = await asyncpg.connect(TEST_DB_DSN)
yield conn
await conn.close()
@pytest_asyncio.fixture
async def clean_config_schema(db_conn: asyncpg.Connection) -> None:
"""Ensure config schema exists and is clean before each test."""
# Create schema if not exists
await db_conn.execute("CREATE SCHEMA IF NOT EXISTS config")
# Create tables if not exist
await db_conn.execute("""
CREATE TABLE IF NOT EXISTS config.adapters (
name TEXT PRIMARY KEY,
enabled BOOLEAN NOT NULL DEFAULT true,
cadence_s INTEGER NOT NULL,
settings JSONB NOT NULL DEFAULT '{}'::jsonb,
paused_at TIMESTAMPTZ,
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
)
""")
await db_conn.execute("""
CREATE TABLE IF NOT EXISTS config.api_keys (
alias TEXT PRIMARY KEY,
encrypted_value BYTEA NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
rotated_at TIMESTAMPTZ,
last_used_at TIMESTAMPTZ
)
""")
# Create notify function with proper key detection
await db_conn.execute("""
CREATE OR REPLACE FUNCTION config.notify_config_change()
RETURNS trigger AS $$
DECLARE
key_value TEXT;
BEGIN
IF TG_TABLE_NAME = 'adapters' THEN
key_value := COALESCE(NEW.name, OLD.name, '');
ELSIF TG_TABLE_NAME = 'api_keys' THEN
key_value := COALESCE(NEW.alias, OLD.alias, '');
ELSE
key_value := '';
END IF;
PERFORM pg_notify('config_changed', TG_TABLE_NAME || ':' || key_value);
RETURN COALESCE(NEW, OLD);
END;
$$ LANGUAGE plpgsql
""")
# Create triggers if not exist
await db_conn.execute("""
DROP TRIGGER IF EXISTS adapters_notify ON config.adapters;
CREATE TRIGGER adapters_notify
AFTER INSERT OR UPDATE OR DELETE ON config.adapters
FOR EACH ROW EXECUTE FUNCTION config.notify_config_change()
""")
await db_conn.execute("""
DROP TRIGGER IF EXISTS api_keys_notify ON config.api_keys;
CREATE TRIGGER api_keys_notify
AFTER INSERT OR UPDATE OR DELETE ON config.api_keys
FOR EACH ROW EXECUTE FUNCTION config.notify_config_change()
""")
# Clean tables
await db_conn.execute("DELETE FROM config.adapters")
await db_conn.execute("DELETE FROM config.api_keys")
@pytest_asyncio.fixture
async def config_store(clean_config_schema: None) -> ConfigStore:
"""Create a ConfigStore connected to the test database."""
store = await ConfigStore.create(TEST_DB_DSN)
yield store
await store.close()
class TestAdapterConfig:
"""Tests for adapter configuration operations."""
@pytest.mark.asyncio
async def test_upsert_and_get(self, config_store: ConfigStore) -> None:
"""Can insert and retrieve adapter config."""
await config_store.upsert_adapter(
name="test_adapter",
enabled=True,
cadence_s=120,
settings={"key": "value"},
)
adapter = await config_store.get_adapter("test_adapter")
assert adapter is not None
assert adapter.name == "test_adapter"
assert adapter.enabled is True
assert adapter.cadence_s == 120
assert adapter.settings == {"key": "value"}
@pytest.mark.asyncio
async def test_get_nonexistent(self, config_store: ConfigStore) -> None:
"""Getting nonexistent adapter returns None."""
adapter = await config_store.get_adapter("does_not_exist")
assert adapter is None
@pytest.mark.asyncio
async def test_list_adapters(self, config_store: ConfigStore) -> None:
"""Can list all adapters."""
await config_store.upsert_adapter("adapter_a", True, 60, {})
await config_store.upsert_adapter("adapter_b", False, 300, {"x": 1})
adapters = await config_store.list_adapters()
assert len(adapters) == 2
names = [a.name for a in adapters]
assert "adapter_a" in names
assert "adapter_b" in names
@pytest.mark.asyncio
async def test_upsert_updates_existing(self, config_store: ConfigStore) -> None:
"""Upsert updates existing adapter."""
await config_store.upsert_adapter("updater", True, 60, {"v": 1})
await config_store.upsert_adapter("updater", False, 120, {"v": 2})
adapter = await config_store.get_adapter("updater")
assert adapter is not None
assert adapter.enabled is False
assert adapter.cadence_s == 120
assert adapter.settings == {"v": 2}
@pytest.mark.asyncio
async def test_pause_unpause(self, config_store: ConfigStore) -> None:
"""Can pause and unpause adapter."""
await config_store.upsert_adapter("pausable", True, 60, {})
await config_store.pause_adapter("pausable")
adapter = await config_store.get_adapter("pausable")
assert adapter is not None
assert adapter.is_paused is True
await config_store.unpause_adapter("pausable")
adapter = await config_store.get_adapter("pausable")
assert adapter is not None
assert adapter.is_paused is False
class TestApiKeys:
"""Tests for API key operations."""
@pytest.mark.asyncio
async def test_set_and_get_key(self, config_store: ConfigStore) -> None:
"""Can store and retrieve encrypted API key."""
await config_store.set_api_key("test_key", "super_secret_value")
value = await config_store.get_api_key("test_key")
assert value == "super_secret_value"
@pytest.mark.asyncio
async def test_get_nonexistent_key(self, config_store: ConfigStore) -> None:
"""Getting nonexistent key returns None."""
value = await config_store.get_api_key("does_not_exist")
assert value is None
@pytest.mark.asyncio
async def test_key_rotation(self, config_store: ConfigStore) -> None:
"""Updating key sets rotated_at."""
await config_store.set_api_key("rotate_me", "value1")
await config_store.set_api_key("rotate_me", "value2")
value = await config_store.get_api_key("rotate_me")
assert value == "value2"
@pytest.mark.asyncio
async def test_delete_key(self, config_store: ConfigStore) -> None:
"""Can delete API key."""
await config_store.set_api_key("delete_me", "value")
deleted = await config_store.delete_api_key("delete_me")
assert deleted is True
value = await config_store.get_api_key("delete_me")
assert value is None
@pytest.mark.asyncio
async def test_delete_nonexistent(self, config_store: ConfigStore) -> None:
"""Deleting nonexistent key returns False."""
deleted = await config_store.delete_api_key("never_existed")
assert deleted is False
class TestNotifications:
"""Tests for LISTEN/NOTIFY functionality."""
@pytest.mark.asyncio
async def test_notify_on_adapter_change(self, config_store: ConfigStore) -> None:
"""NOTIFY fires when adapter is changed."""
notifications: list[tuple[str, str]] = []
notification_received = asyncio.Event()
async def callback(table: str, key: str) -> None:
notifications.append((table, key))
notification_received.set()
# Start listener in background
listen_task = asyncio.create_task(config_store.listen_for_changes(callback))
try:
# Give listener time to subscribe
await asyncio.sleep(0.1)
# Trigger a change
await config_store.upsert_adapter("notify_test", True, 60, {})
# Wait for notification (with timeout)
try:
await asyncio.wait_for(notification_received.wait(), timeout=5.0)
except asyncio.TimeoutError:
pytest.fail("Notification not received within timeout")
assert len(notifications) >= 1
assert notifications[0][0] == "adapters"
assert notifications[0][1] == "notify_test"
finally:
listen_task.cancel()
try:
await listen_task
except asyncio.CancelledError:
pass
@pytest.mark.asyncio
async def test_notify_on_api_key_change(self, config_store: ConfigStore) -> None:
"""NOTIFY fires when API key is changed."""
notifications: list[tuple[str, str]] = []
notification_received = asyncio.Event()
async def callback(table: str, key: str) -> None:
notifications.append((table, key))
notification_received.set()
listen_task = asyncio.create_task(config_store.listen_for_changes(callback))
try:
await asyncio.sleep(0.1)
await config_store.set_api_key("notify_key", "secret")
try:
await asyncio.wait_for(notification_received.wait(), timeout=5.0)
except asyncio.TimeoutError:
pytest.fail("Notification not received within timeout")
assert len(notifications) >= 1
assert notifications[0][0] == "api_keys"
assert notifications[0][1] == "notify_key"
finally:
listen_task.cancel()
try:
await listen_task
except asyncio.CancelledError:
pass
class TestListenerReconnect:
"""Tests for listener reconnection on connection loss."""
@pytest.mark.asyncio
async def test_listener_cancellation_propagates(
self, config_store: ConfigStore
) -> None:
"""Cancellation cleanly stops the listener without reconnect loop."""
async def callback(table: str, key: str) -> None:
pass
listen_task = asyncio.create_task(config_store.listen_for_changes(callback))
# Give listener time to start
await asyncio.sleep(0.1)
# Cancel and verify it stops
listen_task.cancel()
try:
await asyncio.wait_for(listen_task, timeout=2.0)
except asyncio.CancelledError:
pass # Expected
except asyncio.TimeoutError:
pytest.fail("Listener did not stop after cancellation")
assert listen_task.cancelled() or listen_task.done()
"""Tests for database-backed configuration store.
These tests require a real Postgres database. Set CENTRAL_TEST_DB_DSN
environment variable to override the default test database connection.
"""
import asyncio
import base64
import os
from pathlib import Path
import asyncpg
import pytest
import pytest_asyncio
from central.config_store import ConfigStore
from central.crypto import KEY_SIZE, clear_key_cache
# Test database DSN - uses central_test database with well-known test password.
# Override via CENTRAL_TEST_DB_DSN env var if your test DB differs.
TEST_DB_DSN = os.environ.get(
"CENTRAL_TEST_DB_DSN",
"postgresql://central_test:testpass@localhost/central_test",
)
@pytest.fixture(scope="session")
def master_key_path(tmp_path_factory: pytest.TempPathFactory) -> Path:
"""Create a master key file for the test session."""
key = os.urandom(KEY_SIZE)
key_path = tmp_path_factory.mktemp("keys") / "master.key"
key_path.write_text(base64.b64encode(key).decode())
return key_path
@pytest.fixture(autouse=True)
def setup_master_key(master_key_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
"""Configure master key path for all tests."""
clear_key_cache()
monkeypatch.setenv("CENTRAL_DB_DSN", TEST_DB_DSN)
monkeypatch.setenv("CENTRAL_MASTER_KEY_PATH", str(master_key_path))
@pytest_asyncio.fixture
async def db_conn() -> asyncpg.Connection:
"""Get a direct database connection for setup/teardown."""
conn = await asyncpg.connect(TEST_DB_DSN)
yield conn
await conn.close()
@pytest_asyncio.fixture
async def clean_config_schema(db_conn: asyncpg.Connection) -> None:
"""Ensure config schema exists and is clean before each test."""
# Create schema if not exists
await db_conn.execute("CREATE SCHEMA IF NOT EXISTS config")
# Create tables if not exist
await db_conn.execute("""
CREATE TABLE IF NOT EXISTS config.adapters (
name TEXT PRIMARY KEY,
enabled BOOLEAN NOT NULL DEFAULT true,
cadence_s INTEGER NOT NULL,
settings JSONB NOT NULL DEFAULT '{}'::jsonb,
paused_at TIMESTAMPTZ,
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
)
""")
await db_conn.execute("""
CREATE TABLE IF NOT EXISTS config.api_keys (
alias TEXT PRIMARY KEY,
encrypted_value BYTEA NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
rotated_at TIMESTAMPTZ,
last_used_at TIMESTAMPTZ
)
""")
# Create notify function with proper key detection
await db_conn.execute("""
CREATE OR REPLACE FUNCTION config.notify_config_change()
RETURNS trigger AS $$
DECLARE
key_value TEXT;
BEGIN
IF TG_TABLE_NAME = 'adapters' THEN
key_value := COALESCE(NEW.name, OLD.name, '');
ELSIF TG_TABLE_NAME = 'api_keys' THEN
key_value := COALESCE(NEW.alias, OLD.alias, '');
ELSE
key_value := '';
END IF;
PERFORM pg_notify('config_changed', TG_TABLE_NAME || ':' || key_value);
RETURN COALESCE(NEW, OLD);
END;
$$ LANGUAGE plpgsql
""")
# Create triggers if not exist
await db_conn.execute("""
DROP TRIGGER IF EXISTS adapters_notify ON config.adapters;
CREATE TRIGGER adapters_notify
AFTER INSERT OR UPDATE OR DELETE ON config.adapters
FOR EACH ROW EXECUTE FUNCTION config.notify_config_change()
""")
await db_conn.execute("""
DROP TRIGGER IF EXISTS api_keys_notify ON config.api_keys;
CREATE TRIGGER api_keys_notify
AFTER INSERT OR UPDATE OR DELETE ON config.api_keys
FOR EACH ROW EXECUTE FUNCTION config.notify_config_change()
""")
# Clean tables
await db_conn.execute("DELETE FROM config.adapters")
await db_conn.execute("DELETE FROM config.api_keys")
@pytest_asyncio.fixture
async def config_store(clean_config_schema: None) -> ConfigStore:
"""Create a ConfigStore connected to the test database."""
store = await ConfigStore.create(TEST_DB_DSN)
yield store
await store.close()
class TestAdapterConfig:
"""Tests for adapter configuration operations."""
@pytest.mark.asyncio
async def test_upsert_and_get(self, config_store: ConfigStore) -> None:
"""Can insert and retrieve adapter config."""
await config_store.upsert_adapter(
name="test_adapter",
enabled=True,
cadence_s=120,
settings={"key": "value"},
)
adapter = await config_store.get_adapter("test_adapter")
assert adapter is not None
assert adapter.name == "test_adapter"
assert adapter.enabled is True
assert adapter.cadence_s == 120
assert adapter.settings == {"key": "value"}
@pytest.mark.asyncio
async def test_get_nonexistent(self, config_store: ConfigStore) -> None:
"""Getting nonexistent adapter returns None."""
adapter = await config_store.get_adapter("does_not_exist")
assert adapter is None
@pytest.mark.asyncio
async def test_list_adapters(self, config_store: ConfigStore) -> None:
"""Can list all adapters."""
await config_store.upsert_adapter("adapter_a", True, 60, {})
await config_store.upsert_adapter("adapter_b", False, 300, {"x": 1})
adapters = await config_store.list_adapters()
assert len(adapters) == 2
names = [a.name for a in adapters]
assert "adapter_a" in names
assert "adapter_b" in names
@pytest.mark.asyncio
async def test_upsert_updates_existing(self, config_store: ConfigStore) -> None:
"""Upsert updates existing adapter."""
await config_store.upsert_adapter("updater", True, 60, {"v": 1})
await config_store.upsert_adapter("updater", False, 120, {"v": 2})
adapter = await config_store.get_adapter("updater")
assert adapter is not None
assert adapter.enabled is False
assert adapter.cadence_s == 120
assert adapter.settings == {"v": 2}
@pytest.mark.asyncio
async def test_pause_unpause(self, config_store: ConfigStore) -> None:
"""Can pause and unpause adapter."""
await config_store.upsert_adapter("pausable", True, 60, {})
await config_store.pause_adapter("pausable")
adapter = await config_store.get_adapter("pausable")
assert adapter is not None
assert adapter.is_paused is True
await config_store.unpause_adapter("pausable")
adapter = await config_store.get_adapter("pausable")
assert adapter is not None
assert adapter.is_paused is False
class TestApiKeys:
"""Tests for API key operations."""
@pytest.mark.asyncio
async def test_set_and_get_key(self, config_store: ConfigStore) -> None:
"""Can store and retrieve encrypted API key."""
await config_store.set_api_key("test_key", "super_secret_value")
value = await config_store.get_api_key("test_key")
assert value == "super_secret_value"
@pytest.mark.asyncio
async def test_get_nonexistent_key(self, config_store: ConfigStore) -> None:
"""Getting nonexistent key returns None."""
value = await config_store.get_api_key("does_not_exist")
assert value is None
@pytest.mark.asyncio
async def test_key_rotation(self, config_store: ConfigStore) -> None:
"""Updating key sets rotated_at."""
await config_store.set_api_key("rotate_me", "value1")
await config_store.set_api_key("rotate_me", "value2")
value = await config_store.get_api_key("rotate_me")
assert value == "value2"
@pytest.mark.asyncio
async def test_delete_key(self, config_store: ConfigStore) -> None:
"""Can delete API key."""
await config_store.set_api_key("delete_me", "value")
deleted = await config_store.delete_api_key("delete_me")
assert deleted is True
value = await config_store.get_api_key("delete_me")
assert value is None
@pytest.mark.asyncio
async def test_delete_nonexistent(self, config_store: ConfigStore) -> None:
"""Deleting nonexistent key returns False."""
deleted = await config_store.delete_api_key("never_existed")
assert deleted is False
class TestNotifications:
"""Tests for LISTEN/NOTIFY functionality."""
@pytest.mark.asyncio
async def test_notify_on_adapter_change(self, config_store: ConfigStore) -> None:
"""NOTIFY fires when adapter is changed."""
notifications: list[tuple[str, str]] = []
notification_received = asyncio.Event()
async def callback(table: str, key: str) -> None:
notifications.append((table, key))
notification_received.set()
# Start listener in background
listen_task = asyncio.create_task(config_store.listen_for_changes(callback))
try:
# Give listener time to subscribe
await asyncio.sleep(0.1)
# Trigger a change
await config_store.upsert_adapter("notify_test", True, 60, {})
# Wait for notification (with timeout)
try:
await asyncio.wait_for(notification_received.wait(), timeout=5.0)
except asyncio.TimeoutError:
pytest.fail("Notification not received within timeout")
assert len(notifications) >= 1
assert notifications[0][0] == "adapters"
assert notifications[0][1] == "notify_test"
finally:
listen_task.cancel()
try:
await listen_task
except asyncio.CancelledError:
pass
@pytest.mark.asyncio
async def test_notify_on_api_key_change(self, config_store: ConfigStore) -> None:
"""NOTIFY fires when API key is changed."""
notifications: list[tuple[str, str]] = []
notification_received = asyncio.Event()
async def callback(table: str, key: str) -> None:
notifications.append((table, key))
notification_received.set()
listen_task = asyncio.create_task(config_store.listen_for_changes(callback))
try:
await asyncio.sleep(0.1)
await config_store.set_api_key("notify_key", "secret")
try:
await asyncio.wait_for(notification_received.wait(), timeout=5.0)
except asyncio.TimeoutError:
pytest.fail("Notification not received within timeout")
assert len(notifications) >= 1
assert notifications[0][0] == "api_keys"
assert notifications[0][1] == "notify_key"
finally:
listen_task.cancel()
try:
await listen_task
except asyncio.CancelledError:
pass
class TestListenerReconnect:
"""Tests for listener reconnection on connection loss."""
@pytest.mark.asyncio
async def test_listener_cancellation_propagates(
self, config_store: ConfigStore
) -> None:
"""Cancellation cleanly stops the listener without reconnect loop."""
async def callback(table: str, key: str) -> None:
pass
listen_task = asyncio.create_task(config_store.listen_for_changes(callback))
# Give listener time to start
await asyncio.sleep(0.1)
# Cancel and verify it stops
listen_task.cancel()
try:
await asyncio.wait_for(listen_task, timeout=2.0)
except asyncio.CancelledError:
pass # Expected
except asyncio.TimeoutError:
pytest.fail("Listener did not stop after cancellation")
assert listen_task.cancelled() or listen_task.done()

View file

@ -1,175 +1,175 @@
"""Tests for cryptographic primitives."""
import base64
import os
from pathlib import Path
import pytest
from central.crypto import (
KEY_SIZE,
DecryptionError,
KeyLoadError,
clear_key_cache,
decrypt,
encrypt,
)
@pytest.fixture
def master_key(tmp_path: Path) -> Path:
"""Create a valid master key file."""
key = os.urandom(KEY_SIZE)
key_path = tmp_path / "master.key"
key_path.write_text(base64.b64encode(key).decode())
clear_key_cache()
return key_path
@pytest.fixture
def wrong_key(tmp_path: Path) -> Path:
"""Create a different master key file."""
key = os.urandom(KEY_SIZE)
key_path = tmp_path / "wrong.key"
key_path.write_text(base64.b64encode(key).decode())
return key_path
class TestEncryptDecrypt:
"""Test encrypt/decrypt round-trip."""
def test_round_trip(self, master_key: Path) -> None:
"""Encrypting then decrypting returns original plaintext."""
plaintext = b"Hello, Central!"
ciphertext = encrypt(plaintext, key_path=master_key)
decrypted = decrypt(ciphertext, key_path=master_key)
assert decrypted == plaintext
def test_round_trip_empty(self, master_key: Path) -> None:
"""Empty plaintext encrypts and decrypts correctly."""
plaintext = b""
ciphertext = encrypt(plaintext, key_path=master_key)
decrypted = decrypt(ciphertext, key_path=master_key)
assert decrypted == plaintext
def test_round_trip_large(self, master_key: Path) -> None:
"""Large plaintext encrypts and decrypts correctly."""
plaintext = os.urandom(1024 * 1024) # 1MB
ciphertext = encrypt(plaintext, key_path=master_key)
decrypted = decrypt(ciphertext, key_path=master_key)
assert decrypted == plaintext
def test_ciphertext_different_each_time(self, master_key: Path) -> None:
"""Same plaintext produces different ciphertext (random nonce)."""
plaintext = b"test"
ct1 = encrypt(plaintext, key_path=master_key)
ct2 = encrypt(plaintext, key_path=master_key)
assert ct1 != ct2
# But both decrypt to same plaintext
assert decrypt(ct1, key_path=master_key) == plaintext
assert decrypt(ct2, key_path=master_key) == plaintext
class TestDecryptionFailures:
"""Test AEAD authentication catches tampering."""
def test_wrong_key_fails(self, master_key: Path, wrong_key: Path) -> None:
"""Decryption with wrong key raises DecryptionError."""
plaintext = b"secret"
ciphertext = encrypt(plaintext, key_path=master_key)
clear_key_cache() # Clear cache so wrong_key is loaded
with pytest.raises(DecryptionError):
decrypt(ciphertext, key_path=wrong_key)
def test_tampered_ciphertext_fails(self, master_key: Path) -> None:
"""Modified ciphertext is detected and rejected."""
plaintext = b"secret"
ciphertext = encrypt(plaintext, key_path=master_key)
# Flip a bit in the ciphertext (after nonce, before tag)
tampered = bytearray(ciphertext)
tampered[15] ^= 0x01 # Flip one bit
tampered = bytes(tampered)
with pytest.raises(DecryptionError):
decrypt(tampered, key_path=master_key)
def test_tampered_tag_fails(self, master_key: Path) -> None:
"""Modified authentication tag is detected and rejected."""
plaintext = b"secret"
ciphertext = encrypt(plaintext, key_path=master_key)
# Flip a bit in the last byte (part of the tag)
tampered = bytearray(ciphertext)
tampered[-1] ^= 0x01
tampered = bytes(tampered)
with pytest.raises(DecryptionError):
decrypt(tampered, key_path=master_key)
def test_truncated_ciphertext_fails(self, master_key: Path) -> None:
"""Truncated ciphertext is rejected."""
ciphertext = b"tooshort"
with pytest.raises(DecryptionError, match="too short"):
decrypt(ciphertext, key_path=master_key)
class TestKeyLoading:
"""Test master key loading."""
def test_missing_key_file(self, tmp_path: Path) -> None:
"""Missing key file raises KeyLoadError."""
clear_key_cache()
missing = tmp_path / "nonexistent.key"
with pytest.raises(KeyLoadError, match="not found"):
encrypt(b"test", key_path=missing)
def test_invalid_key_size(self, tmp_path: Path) -> None:
"""Key file with wrong size raises KeyLoadError."""
clear_key_cache()
bad_key = tmp_path / "bad.key"
bad_key.write_text(base64.b64encode(b"tooshort").decode())
with pytest.raises(KeyLoadError, match="Invalid master key size"):
encrypt(b"test", key_path=bad_key)
def test_invalid_base64(self, tmp_path: Path) -> None:
"""Invalid base64 in key file raises KeyLoadError."""
clear_key_cache()
bad_key = tmp_path / "bad.key"
bad_key.write_text("not valid base64!!!")
with pytest.raises(KeyLoadError):
encrypt(b"test", key_path=bad_key)
def test_key_cached(self, master_key: Path) -> None:
"""Key is cached after first load."""
# First encryption loads the key
encrypt(b"test1", key_path=master_key)
# Delete the file
master_key.unlink()
# Second encryption should still work (cached)
ciphertext = encrypt(b"test2", key_path=master_key)
assert len(ciphertext) > 0
def test_cache_clear(self, master_key: Path) -> None:
"""clear_key_cache forces reload."""
encrypt(b"test", key_path=master_key)
master_key.unlink()
clear_key_cache()
with pytest.raises(KeyLoadError, match="not found"):
encrypt(b"test", key_path=master_key)
"""Tests for cryptographic primitives."""
import base64
import os
from pathlib import Path
import pytest
from central.crypto import (
KEY_SIZE,
DecryptionError,
KeyLoadError,
clear_key_cache,
decrypt,
encrypt,
)
@pytest.fixture
def master_key(tmp_path: Path) -> Path:
"""Create a valid master key file."""
key = os.urandom(KEY_SIZE)
key_path = tmp_path / "master.key"
key_path.write_text(base64.b64encode(key).decode())
clear_key_cache()
return key_path
@pytest.fixture
def wrong_key(tmp_path: Path) -> Path:
"""Create a different master key file."""
key = os.urandom(KEY_SIZE)
key_path = tmp_path / "wrong.key"
key_path.write_text(base64.b64encode(key).decode())
return key_path
class TestEncryptDecrypt:
"""Test encrypt/decrypt round-trip."""
def test_round_trip(self, master_key: Path) -> None:
"""Encrypting then decrypting returns original plaintext."""
plaintext = b"Hello, Central!"
ciphertext = encrypt(plaintext, key_path=master_key)
decrypted = decrypt(ciphertext, key_path=master_key)
assert decrypted == plaintext
def test_round_trip_empty(self, master_key: Path) -> None:
"""Empty plaintext encrypts and decrypts correctly."""
plaintext = b""
ciphertext = encrypt(plaintext, key_path=master_key)
decrypted = decrypt(ciphertext, key_path=master_key)
assert decrypted == plaintext
def test_round_trip_large(self, master_key: Path) -> None:
"""Large plaintext encrypts and decrypts correctly."""
plaintext = os.urandom(1024 * 1024) # 1MB
ciphertext = encrypt(plaintext, key_path=master_key)
decrypted = decrypt(ciphertext, key_path=master_key)
assert decrypted == plaintext
def test_ciphertext_different_each_time(self, master_key: Path) -> None:
"""Same plaintext produces different ciphertext (random nonce)."""
plaintext = b"test"
ct1 = encrypt(plaintext, key_path=master_key)
ct2 = encrypt(plaintext, key_path=master_key)
assert ct1 != ct2
# But both decrypt to same plaintext
assert decrypt(ct1, key_path=master_key) == plaintext
assert decrypt(ct2, key_path=master_key) == plaintext
class TestDecryptionFailures:
"""Test AEAD authentication catches tampering."""
def test_wrong_key_fails(self, master_key: Path, wrong_key: Path) -> None:
"""Decryption with wrong key raises DecryptionError."""
plaintext = b"secret"
ciphertext = encrypt(plaintext, key_path=master_key)
clear_key_cache() # Clear cache so wrong_key is loaded
with pytest.raises(DecryptionError):
decrypt(ciphertext, key_path=wrong_key)
def test_tampered_ciphertext_fails(self, master_key: Path) -> None:
"""Modified ciphertext is detected and rejected."""
plaintext = b"secret"
ciphertext = encrypt(plaintext, key_path=master_key)
# Flip a bit in the ciphertext (after nonce, before tag)
tampered = bytearray(ciphertext)
tampered[15] ^= 0x01 # Flip one bit
tampered = bytes(tampered)
with pytest.raises(DecryptionError):
decrypt(tampered, key_path=master_key)
def test_tampered_tag_fails(self, master_key: Path) -> None:
"""Modified authentication tag is detected and rejected."""
plaintext = b"secret"
ciphertext = encrypt(plaintext, key_path=master_key)
# Flip a bit in the last byte (part of the tag)
tampered = bytearray(ciphertext)
tampered[-1] ^= 0x01
tampered = bytes(tampered)
with pytest.raises(DecryptionError):
decrypt(tampered, key_path=master_key)
def test_truncated_ciphertext_fails(self, master_key: Path) -> None:
"""Truncated ciphertext is rejected."""
ciphertext = b"tooshort"
with pytest.raises(DecryptionError, match="too short"):
decrypt(ciphertext, key_path=master_key)
class TestKeyLoading:
"""Test master key loading."""
def test_missing_key_file(self, tmp_path: Path) -> None:
"""Missing key file raises KeyLoadError."""
clear_key_cache()
missing = tmp_path / "nonexistent.key"
with pytest.raises(KeyLoadError, match="not found"):
encrypt(b"test", key_path=missing)
def test_invalid_key_size(self, tmp_path: Path) -> None:
"""Key file with wrong size raises KeyLoadError."""
clear_key_cache()
bad_key = tmp_path / "bad.key"
bad_key.write_text(base64.b64encode(b"tooshort").decode())
with pytest.raises(KeyLoadError, match="Invalid master key size"):
encrypt(b"test", key_path=bad_key)
def test_invalid_base64(self, tmp_path: Path) -> None:
"""Invalid base64 in key file raises KeyLoadError."""
clear_key_cache()
bad_key = tmp_path / "bad.key"
bad_key.write_text("not valid base64!!!")
with pytest.raises(KeyLoadError):
encrypt(b"test", key_path=bad_key)
def test_key_cached(self, master_key: Path) -> None:
"""Key is cached after first load."""
# First encryption loads the key
encrypt(b"test1", key_path=master_key)
# Delete the file
master_key.unlink()
# Second encryption should still work (cached)
ciphertext = encrypt(b"test2", key_path=master_key)
assert len(ciphertext) > 0
def test_cache_clear(self, master_key: Path) -> None:
"""clear_key_cache forces reload."""
encrypt(b"test", key_path=master_key)
master_key.unlink()
clear_key_cache()
with pytest.raises(KeyLoadError, match="not found"):
encrypt(b"test", key_path=master_key)

View file

@ -1,161 +1,161 @@
"""Smoke tests for Central models and CloudEvents wire format."""
from datetime import datetime, timezone
import pytest
from central.models import Event, Geo, subject_for_event
from central.config import NWSAdapterConfig, CloudEventsConfig, NATSConfig, PostgresConfig, Config
from central.cloudevents_wire import wrap_event
@pytest.fixture
def sample_geo() -> Geo:
"""Sample Geo object for testing."""
return Geo(
centroid=(-116.2, 43.6),
bbox=(-116.5, 43.4, -115.9, 43.8),
regions=["US-ID-Ada", "US-ID-Canyon"],
primary_region="US-ID-Ada",
)
@pytest.fixture
def sample_event(sample_geo: Geo) -> Event:
"""Sample Event object for testing."""
return Event(
id="urn:central:nws:alert:KBOI-202401151200-SVR",
source="central/adapters/nws",
category="wx.alert.severe_thunderstorm_warning",
time=datetime(2024, 1, 15, 12, 0, 0, tzinfo=timezone.utc),
expires=datetime(2024, 1, 15, 13, 0, 0, tzinfo=timezone.utc),
severity=3,
geo=sample_geo,
data={"headline": "Severe Thunderstorm Warning", "urgency": "Immediate"},
)
@pytest.fixture
def sample_config() -> Config:
"""Sample Config object for testing."""
return Config(
adapters={
"nws": NWSAdapterConfig(
enabled=True,
cadence_s=60,
states=["ID", "MT"],
contact_email="test@example.com",
)
},
cloudevents=CloudEventsConfig(
type_prefix="central",
source="central.local",
schema_version="1.0",
),
nats=NATSConfig(url="nats://localhost:4222"),
postgres=PostgresConfig(dsn="postgresql://user:pass@localhost/db"),
)
class TestSubjectForEvent:
"""Tests for subject_for_event helper."""
def test_county_subject(self, sample_event: Event) -> None:
"""County codes produce county subject."""
subject = subject_for_event(sample_event)
assert subject == "central.wx.alert.us.id.county.ada"
def test_zone_subject(self, sample_geo: Geo) -> None:
"""Zone codes produce zone subject."""
geo = Geo(
centroid=sample_geo.centroid,
bbox=sample_geo.bbox,
regions=["US-ID-Z033"],
primary_region="US-ID-Z033",
)
event = Event(
id="test-zone",
source="test",
category="wx.alert.winter_storm_warning",
time=datetime(2024, 1, 15, 12, 0, 0, tzinfo=timezone.utc),
geo=geo,
data={},
)
subject = subject_for_event(event)
assert subject == "central.wx.alert.us.id.zone.z033"
def test_unknown_subject(self, sample_event: Event) -> None:
"""Missing primary_region produces unknown subject."""
geo = Geo(regions=[], primary_region=None)
event = Event(
id="test-unknown",
source="test",
category="wx.alert.test",
time=datetime(2024, 1, 15, 12, 0, 0, tzinfo=timezone.utc),
geo=geo,
data={},
)
subject = subject_for_event(event)
assert subject == "central.wx.alert.us.unknown"
def test_custom_prefix(self, sample_event: Event) -> None:
"""Custom prefix is used in subject."""
subject = subject_for_event(sample_event, prefix="myapp.events")
assert subject == "myapp.events.alert.us.id.county.ada"
class TestCloudEventsWire:
"""Tests for CloudEvents wire format."""
def test_required_fields_present(
self, sample_event: Event, sample_config: Config
) -> None:
"""Required CloudEvents fields are present."""
envelope, msg_id = wrap_event(sample_event, sample_config)
assert msg_id == sample_event.id
assert envelope["id"] == sample_event.id
assert envelope["source"] == sample_config.cloudevents.source
assert envelope["type"] == "central.wx.alert.severe_thunderstorm_warning.v1"
assert envelope["specversion"] == "1.0"
assert "time" in envelope
assert envelope["datacontenttype"] == "application/json"
assert "data" in envelope
def test_extension_attributes_lowercase(
self, sample_event: Event, sample_config: Config
) -> None:
"""Extension attributes are lowercase with no underscores."""
envelope, _ = wrap_event(sample_event, sample_config)
# Check that extension attributes exist and are lowercase
assert envelope["centralschemaversion"] == "1.0"
assert envelope["centralcategory"] == "wx.alert.severe_thunderstorm_warning"
assert envelope["centralseverity"] == 3
# Verify no uppercase or underscores in extension names
for key in ["centralschemaversion", "centralcategory", "centralseverity"]:
assert key.islower()
assert "_" not in key
def test_severity_none_omits_centralseverity(
self, sample_geo: Geo, sample_config: Config
) -> None:
"""When severity is None, centralseverity is omitted entirely."""
event = Event(
id="test-no-severity",
source="test",
category="wx.alert.test",
time=datetime(2024, 1, 15, 12, 0, 0, tzinfo=timezone.utc),
severity=None, # Explicitly None
geo=sample_geo,
data={},
)
envelope, _ = wrap_event(event, sample_config)
# centralseverity should not be present at all
assert "centralseverity" not in envelope
# Other extensions should still be present
assert "centralschemaversion" in envelope
assert "centralcategory" in envelope
"""Smoke tests for Central models and CloudEvents wire format."""
from datetime import datetime, timezone
import pytest
from central.models import Event, Geo, subject_for_event
from central.config import NWSAdapterConfig, CloudEventsConfig, NATSConfig, PostgresConfig, Config
from central.cloudevents_wire import wrap_event
@pytest.fixture
def sample_geo() -> Geo:
"""Sample Geo object for testing."""
return Geo(
centroid=(-116.2, 43.6),
bbox=(-116.5, 43.4, -115.9, 43.8),
regions=["US-ID-Ada", "US-ID-Canyon"],
primary_region="US-ID-Ada",
)
@pytest.fixture
def sample_event(sample_geo: Geo) -> Event:
"""Sample Event object for testing."""
return Event(
id="urn:central:nws:alert:KBOI-202401151200-SVR",
source="central/adapters/nws",
category="wx.alert.severe_thunderstorm_warning",
time=datetime(2024, 1, 15, 12, 0, 0, tzinfo=timezone.utc),
expires=datetime(2024, 1, 15, 13, 0, 0, tzinfo=timezone.utc),
severity=3,
geo=sample_geo,
data={"headline": "Severe Thunderstorm Warning", "urgency": "Immediate"},
)
@pytest.fixture
def sample_config() -> Config:
"""Sample Config object for testing."""
return Config(
adapters={
"nws": NWSAdapterConfig(
enabled=True,
cadence_s=60,
states=["ID", "MT"],
contact_email="test@example.com",
)
},
cloudevents=CloudEventsConfig(
type_prefix="central",
source="central.local",
schema_version="1.0",
),
nats=NATSConfig(url="nats://localhost:4222"),
postgres=PostgresConfig(dsn="postgresql://user:pass@localhost/db"),
)
class TestSubjectForEvent:
"""Tests for subject_for_event helper."""
def test_county_subject(self, sample_event: Event) -> None:
"""County codes produce county subject."""
subject = subject_for_event(sample_event)
assert subject == "central.wx.alert.us.id.county.ada"
def test_zone_subject(self, sample_geo: Geo) -> None:
"""Zone codes produce zone subject."""
geo = Geo(
centroid=sample_geo.centroid,
bbox=sample_geo.bbox,
regions=["US-ID-Z033"],
primary_region="US-ID-Z033",
)
event = Event(
id="test-zone",
source="test",
category="wx.alert.winter_storm_warning",
time=datetime(2024, 1, 15, 12, 0, 0, tzinfo=timezone.utc),
geo=geo,
data={},
)
subject = subject_for_event(event)
assert subject == "central.wx.alert.us.id.zone.z033"
def test_unknown_subject(self, sample_event: Event) -> None:
"""Missing primary_region produces unknown subject."""
geo = Geo(regions=[], primary_region=None)
event = Event(
id="test-unknown",
source="test",
category="wx.alert.test",
time=datetime(2024, 1, 15, 12, 0, 0, tzinfo=timezone.utc),
geo=geo,
data={},
)
subject = subject_for_event(event)
assert subject == "central.wx.alert.us.unknown"
def test_custom_prefix(self, sample_event: Event) -> None:
"""Custom prefix is used in subject."""
subject = subject_for_event(sample_event, prefix="myapp.events")
assert subject == "myapp.events.alert.us.id.county.ada"
class TestCloudEventsWire:
"""Tests for CloudEvents wire format."""
def test_required_fields_present(
self, sample_event: Event, sample_config: Config
) -> None:
"""Required CloudEvents fields are present."""
envelope, msg_id = wrap_event(sample_event, sample_config)
assert msg_id == sample_event.id
assert envelope["id"] == sample_event.id
assert envelope["source"] == sample_config.cloudevents.source
assert envelope["type"] == "central.wx.alert.severe_thunderstorm_warning.v1"
assert envelope["specversion"] == "1.0"
assert "time" in envelope
assert envelope["datacontenttype"] == "application/json"
assert "data" in envelope
def test_extension_attributes_lowercase(
self, sample_event: Event, sample_config: Config
) -> None:
"""Extension attributes are lowercase with no underscores."""
envelope, _ = wrap_event(sample_event, sample_config)
# Check that extension attributes exist and are lowercase
assert envelope["centralschemaversion"] == "1.0"
assert envelope["centralcategory"] == "wx.alert.severe_thunderstorm_warning"
assert envelope["centralseverity"] == 3
# Verify no uppercase or underscores in extension names
for key in ["centralschemaversion", "centralcategory", "centralseverity"]:
assert key.islower()
assert "_" not in key
def test_severity_none_omits_centralseverity(
self, sample_geo: Geo, sample_config: Config
) -> None:
"""When severity is None, centralseverity is omitted entirely."""
event = Event(
id="test-no-severity",
source="test",
category="wx.alert.test",
time=datetime(2024, 1, 15, 12, 0, 0, tzinfo=timezone.utc),
severity=None, # Explicitly None
geo=sample_geo,
data={},
)
envelope, _ = wrap_event(event, sample_config)
# centralseverity should not be present at all
assert "centralseverity" not in envelope
# Other extensions should still be present
assert "centralschemaversion" in envelope
assert "centralcategory" in envelope

View file

@ -1,357 +1,357 @@
"""Tests for supervisor hot-reload and rate-limiting behavior."""
import asyncio
import base64
import os
from datetime import datetime, timedelta, timezone
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch
import asyncpg
import pytest
import pytest_asyncio
from central.config_models import AdapterConfig
from central.config_source import DbConfigSource
from central.config_store import ConfigStore
from central.crypto import KEY_SIZE, clear_key_cache
# Test database DSN
TEST_DB_DSN = os.environ.get(
"CENTRAL_TEST_DB_DSN",
"postgresql://central_test:testpass@localhost/central_test",
)
@pytest.fixture(scope="session")
def master_key_path(tmp_path_factory: pytest.TempPathFactory) -> Path:
"""Create a master key file for the test session."""
key = os.urandom(KEY_SIZE)
key_path = tmp_path_factory.mktemp("keys") / "master.key"
key_path.write_text(base64.b64encode(key).decode())
return key_path
@pytest.fixture(autouse=True)
def setup_master_key(master_key_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
"""Configure master key path for all tests."""
clear_key_cache()
monkeypatch.setenv("CENTRAL_DB_DSN", TEST_DB_DSN)
monkeypatch.setenv("CENTRAL_MASTER_KEY_PATH", str(master_key_path))
@pytest_asyncio.fixture
async def db_conn() -> asyncpg.Connection:
"""Get a direct database connection for setup/teardown."""
conn = await asyncpg.connect(TEST_DB_DSN)
yield conn
await conn.close()
@pytest_asyncio.fixture
async def clean_config_schema(db_conn: asyncpg.Connection) -> None:
"""Ensure config schema exists and is clean before each test."""
await db_conn.execute("CREATE SCHEMA IF NOT EXISTS config")
await db_conn.execute("""
CREATE TABLE IF NOT EXISTS config.adapters (
name TEXT PRIMARY KEY,
enabled BOOLEAN NOT NULL DEFAULT true,
cadence_s INTEGER NOT NULL,
settings JSONB NOT NULL DEFAULT '{}'::jsonb,
paused_at TIMESTAMPTZ,
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
)
""")
# Create notify trigger
await db_conn.execute("""
CREATE OR REPLACE FUNCTION config.notify_config_change()
RETURNS trigger AS $$
DECLARE
key_value TEXT;
BEGIN
IF TG_TABLE_NAME = 'adapters' THEN
key_value := COALESCE(NEW.name, OLD.name, '');
ELSE
key_value := '';
END IF;
PERFORM pg_notify('config_changed', TG_TABLE_NAME || ':' || key_value);
RETURN COALESCE(NEW, OLD);
END;
$$ LANGUAGE plpgsql
""")
await db_conn.execute("""
DROP TRIGGER IF EXISTS adapters_notify ON config.adapters;
CREATE TRIGGER adapters_notify
AFTER INSERT OR UPDATE OR DELETE ON config.adapters
FOR EACH ROW EXECUTE FUNCTION config.notify_config_change()
""")
await db_conn.execute("DELETE FROM config.adapters")
@pytest_asyncio.fixture
async def config_store(clean_config_schema: None) -> ConfigStore:
"""Create a ConfigStore connected to the test database."""
store = await ConfigStore.create(TEST_DB_DSN)
yield store
await store.close()
class TestDbConfigSourceNotifications:
"""Tests for DbConfigSource NOTIFY integration."""
@pytest.mark.asyncio
async def test_watch_receives_notifications(
self,
config_store: ConfigStore,
db_conn: asyncpg.Connection,
) -> None:
"""watch_for_changes receives NOTIFY when adapter changes."""
source = DbConfigSource(config_store)
notifications: list[tuple[str, str]] = []
notification_received = asyncio.Event()
async def callback(table: str, key: str) -> None:
notifications.append((table, key))
notification_received.set()
# Start watching in background
watch_task = asyncio.create_task(source.watch_for_changes(callback))
try:
# Wait for listener to connect
await asyncio.sleep(0.2)
# Insert an adapter via direct connection (not through store)
# This triggers the NOTIFY
await db_conn.execute("""
INSERT INTO config.adapters (name, enabled, cadence_s, settings)
VALUES ('test_adapter', true, 60, '{}'::jsonb)
""")
# Wait for notification
await asyncio.wait_for(notification_received.wait(), timeout=5.0)
assert len(notifications) >= 1
assert notifications[0] == ("adapters", "test_adapter")
finally:
watch_task.cancel()
try:
await watch_task
except asyncio.CancelledError:
pass
class TestRateLimitGuarantee:
"""Tests for rate-limit guarantees during hot-reload.
These tests verify the critical invariant: cadence changes must not
cause extra API calls before (last_poll + new_cadence).
"""
@pytest.mark.asyncio
async def test_cadence_change_respects_last_poll_time(self) -> None:
"""Changing cadence mid-cycle schedules next poll at last_poll + new_cadence.
This is the core rate-limit guarantee test (gate 3).
"""
# Import supervisor module to access AdapterState
from central.supervisor import AdapterState
# Mock adapter
mock_adapter = MagicMock()
mock_adapter.name = "test"
mock_adapter.cadence_s = 60
# Create adapter state with a known last_completed_poll time
last_poll = datetime.now(timezone.utc) - timedelta(seconds=30)
config = AdapterConfig(
name="test",
enabled=True,
cadence_s=60, # Original cadence
settings={},
paused_at=None,
updated_at=datetime.now(timezone.utc),
)
state = AdapterState(
name="test",
adapter=mock_adapter,
config=config,
last_completed_poll=last_poll,
)
# Simulate cadence change to 90 seconds
new_config = AdapterConfig(
name="test",
enabled=True,
cadence_s=90, # New cadence
settings={},
paused_at=None,
updated_at=datetime.now(timezone.utc),
)
# Update state as reschedule would
state.config = new_config
state.adapter.cadence_s = 90
# Calculate expected next poll time
expected_next_poll = last_poll + timedelta(seconds=90)
now = datetime.now(timezone.utc)
expected_wait = max(0, (expected_next_poll - now).total_seconds())
# The wait time should be based on last_poll + new_cadence
# Since last_poll was 30 seconds ago and new cadence is 90,
# we should wait 60 more seconds (90 - 30 = 60)
actual_next_poll = last_poll.timestamp() + new_config.cadence_s
actual_wait = max(0, actual_next_poll - now.timestamp())
# Allow 1 second tolerance for timing
assert abs(actual_wait - 60) < 2, (
f"Expected ~60s wait, got {actual_wait}s. "
f"Rate limit violated: poll would happen before last_poll + new_cadence"
)
@pytest.mark.asyncio
async def test_cadence_increase_after_gap_polls_immediately(self) -> None:
"""When last_poll + new_cadence is already past, poll immediately.
If operator increases cadence to 120s after a gap of 150s,
the poll should happen now (not wait another 120s).
"""
from central.supervisor import AdapterState
mock_adapter = MagicMock()
mock_adapter.name = "test"
mock_adapter.cadence_s = 60
# Last poll was 150 seconds ago
last_poll = datetime.now(timezone.utc) - timedelta(seconds=150)
config = AdapterConfig(
name="test",
enabled=True,
cadence_s=120, # Increased cadence
settings={},
paused_at=None,
updated_at=datetime.now(timezone.utc),
)
state = AdapterState(
name="test",
adapter=mock_adapter,
config=config,
last_completed_poll=last_poll,
)
# Calculate next poll time
now = datetime.now(timezone.utc)
next_poll_at = last_poll.timestamp() + config.cadence_s
wait_time = max(0, next_poll_at - now.timestamp())
# Since 150 > 120, next poll should be immediate (wait_time ~= 0)
assert wait_time < 1, (
f"Expected immediate poll (wait ~0s), got {wait_time}s. "
f"After a gap exceeding new cadence, poll should happen now."
)
@pytest.mark.asyncio
async def test_enable_disable_enable_respects_rate_limit(self) -> None:
"""Re-enabling adapter schedules poll at last_poll + cadence.
If adapter was disabled for a while and then re-enabled, the next
poll should be at (last_completed_poll + cadence_s), not immediately
(unless that time has already passed).
"""
from central.supervisor import AdapterState
mock_adapter = MagicMock()
mock_adapter.name = "test"
mock_adapter.cadence_s = 60
# Last poll was 30 seconds ago, then adapter was disabled
last_poll = datetime.now(timezone.utc) - timedelta(seconds=30)
# Re-enabled config
config = AdapterConfig(
name="test",
enabled=True,
cadence_s=60,
settings={},
paused_at=None,
updated_at=datetime.now(timezone.utc),
)
state = AdapterState(
name="test",
adapter=mock_adapter,
config=config,
last_completed_poll=last_poll,
)
# Calculate next poll time
now = datetime.now(timezone.utc)
next_poll_at = last_poll.timestamp() + config.cadence_s
wait_time = max(0, next_poll_at - now.timestamp())
# Should wait ~30 more seconds (60 - 30 = 30)
assert abs(wait_time - 30) < 2, (
f"Expected ~30s wait after re-enable, got {wait_time}s. "
f"Rate limit violated on enable→disable→enable sequence."
)
@pytest.mark.asyncio
async def test_multiple_rapid_cadence_changes_no_extra_polls(self) -> None:
"""Multiple rapid cadence changes don't cause extra polls.
If NOTIFY fires rapidly (609012090), the final schedule should
still be based on last_completed_poll + final_cadence.
"""
from central.supervisor import AdapterState
mock_adapter = MagicMock()
mock_adapter.name = "test"
mock_adapter.cadence_s = 60
# Last poll was 20 seconds ago
last_poll = datetime.now(timezone.utc) - timedelta(seconds=20)
state = AdapterState(
name="test",
adapter=mock_adapter,
config=AdapterConfig(
name="test",
enabled=True,
cadence_s=60,
settings={},
paused_at=None,
updated_at=datetime.now(timezone.utc),
),
last_completed_poll=last_poll,
)
# Simulate rapid cadence changes
for cadence in [90, 120, 90]: # Final cadence is 90
state.config = AdapterConfig(
name="test",
enabled=True,
cadence_s=cadence,
settings={},
paused_at=None,
updated_at=datetime.now(timezone.utc),
)
state.adapter.cadence_s = cadence
# Final schedule should be last_poll + 90
now = datetime.now(timezone.utc)
final_cadence = 90
next_poll_at = last_poll.timestamp() + final_cadence
wait_time = max(0, next_poll_at - now.timestamp())
# Should wait ~70 seconds (90 - 20 = 70)
assert abs(wait_time - 70) < 2, (
f"Expected ~70s wait after rapid changes, got {wait_time}s. "
f"Multiple NOTIFYs should not cause extra polls."
)
"""Tests for supervisor hot-reload and rate-limiting behavior."""
import asyncio
import base64
import os
from datetime import datetime, timedelta, timezone
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch
import asyncpg
import pytest
import pytest_asyncio
from central.config_models import AdapterConfig
from central.config_source import DbConfigSource
from central.config_store import ConfigStore
from central.crypto import KEY_SIZE, clear_key_cache
# Test database DSN
TEST_DB_DSN = os.environ.get(
"CENTRAL_TEST_DB_DSN",
"postgresql://central_test:testpass@localhost/central_test",
)
@pytest.fixture(scope="session")
def master_key_path(tmp_path_factory: pytest.TempPathFactory) -> Path:
"""Create a master key file for the test session."""
key = os.urandom(KEY_SIZE)
key_path = tmp_path_factory.mktemp("keys") / "master.key"
key_path.write_text(base64.b64encode(key).decode())
return key_path
@pytest.fixture(autouse=True)
def setup_master_key(master_key_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
"""Configure master key path for all tests."""
clear_key_cache()
monkeypatch.setenv("CENTRAL_DB_DSN", TEST_DB_DSN)
monkeypatch.setenv("CENTRAL_MASTER_KEY_PATH", str(master_key_path))
@pytest_asyncio.fixture
async def db_conn() -> asyncpg.Connection:
"""Get a direct database connection for setup/teardown."""
conn = await asyncpg.connect(TEST_DB_DSN)
yield conn
await conn.close()
@pytest_asyncio.fixture
async def clean_config_schema(db_conn: asyncpg.Connection) -> None:
"""Ensure config schema exists and is clean before each test."""
await db_conn.execute("CREATE SCHEMA IF NOT EXISTS config")
await db_conn.execute("""
CREATE TABLE IF NOT EXISTS config.adapters (
name TEXT PRIMARY KEY,
enabled BOOLEAN NOT NULL DEFAULT true,
cadence_s INTEGER NOT NULL,
settings JSONB NOT NULL DEFAULT '{}'::jsonb,
paused_at TIMESTAMPTZ,
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
)
""")
# Create notify trigger
await db_conn.execute("""
CREATE OR REPLACE FUNCTION config.notify_config_change()
RETURNS trigger AS $$
DECLARE
key_value TEXT;
BEGIN
IF TG_TABLE_NAME = 'adapters' THEN
key_value := COALESCE(NEW.name, OLD.name, '');
ELSE
key_value := '';
END IF;
PERFORM pg_notify('config_changed', TG_TABLE_NAME || ':' || key_value);
RETURN COALESCE(NEW, OLD);
END;
$$ LANGUAGE plpgsql
""")
await db_conn.execute("""
DROP TRIGGER IF EXISTS adapters_notify ON config.adapters;
CREATE TRIGGER adapters_notify
AFTER INSERT OR UPDATE OR DELETE ON config.adapters
FOR EACH ROW EXECUTE FUNCTION config.notify_config_change()
""")
await db_conn.execute("DELETE FROM config.adapters")
@pytest_asyncio.fixture
async def config_store(clean_config_schema: None) -> ConfigStore:
"""Create a ConfigStore connected to the test database."""
store = await ConfigStore.create(TEST_DB_DSN)
yield store
await store.close()
class TestDbConfigSourceNotifications:
"""Tests for DbConfigSource NOTIFY integration."""
@pytest.mark.asyncio
async def test_watch_receives_notifications(
self,
config_store: ConfigStore,
db_conn: asyncpg.Connection,
) -> None:
"""watch_for_changes receives NOTIFY when adapter changes."""
source = DbConfigSource(config_store)
notifications: list[tuple[str, str]] = []
notification_received = asyncio.Event()
async def callback(table: str, key: str) -> None:
notifications.append((table, key))
notification_received.set()
# Start watching in background
watch_task = asyncio.create_task(source.watch_for_changes(callback))
try:
# Wait for listener to connect
await asyncio.sleep(0.2)
# Insert an adapter via direct connection (not through store)
# This triggers the NOTIFY
await db_conn.execute("""
INSERT INTO config.adapters (name, enabled, cadence_s, settings)
VALUES ('test_adapter', true, 60, '{}'::jsonb)
""")
# Wait for notification
await asyncio.wait_for(notification_received.wait(), timeout=5.0)
assert len(notifications) >= 1
assert notifications[0] == ("adapters", "test_adapter")
finally:
watch_task.cancel()
try:
await watch_task
except asyncio.CancelledError:
pass
class TestRateLimitGuarantee:
"""Tests for rate-limit guarantees during hot-reload.
These tests verify the critical invariant: cadence changes must not
cause extra API calls before (last_poll + new_cadence).
"""
@pytest.mark.asyncio
async def test_cadence_change_respects_last_poll_time(self) -> None:
"""Changing cadence mid-cycle schedules next poll at last_poll + new_cadence.
This is the core rate-limit guarantee test (gate 3).
"""
# Import supervisor module to access AdapterState
from central.supervisor import AdapterState
# Mock adapter
mock_adapter = MagicMock()
mock_adapter.name = "test"
mock_adapter.cadence_s = 60
# Create adapter state with a known last_completed_poll time
last_poll = datetime.now(timezone.utc) - timedelta(seconds=30)
config = AdapterConfig(
name="test",
enabled=True,
cadence_s=60, # Original cadence
settings={},
paused_at=None,
updated_at=datetime.now(timezone.utc),
)
state = AdapterState(
name="test",
adapter=mock_adapter,
config=config,
last_completed_poll=last_poll,
)
# Simulate cadence change to 90 seconds
new_config = AdapterConfig(
name="test",
enabled=True,
cadence_s=90, # New cadence
settings={},
paused_at=None,
updated_at=datetime.now(timezone.utc),
)
# Update state as reschedule would
state.config = new_config
state.adapter.cadence_s = 90
# Calculate expected next poll time
expected_next_poll = last_poll + timedelta(seconds=90)
now = datetime.now(timezone.utc)
expected_wait = max(0, (expected_next_poll - now).total_seconds())
# The wait time should be based on last_poll + new_cadence
# Since last_poll was 30 seconds ago and new cadence is 90,
# we should wait 60 more seconds (90 - 30 = 60)
actual_next_poll = last_poll.timestamp() + new_config.cadence_s
actual_wait = max(0, actual_next_poll - now.timestamp())
# Allow 1 second tolerance for timing
assert abs(actual_wait - 60) < 2, (
f"Expected ~60s wait, got {actual_wait}s. "
f"Rate limit violated: poll would happen before last_poll + new_cadence"
)
@pytest.mark.asyncio
async def test_cadence_increase_after_gap_polls_immediately(self) -> None:
"""When last_poll + new_cadence is already past, poll immediately.
If operator increases cadence to 120s after a gap of 150s,
the poll should happen now (not wait another 120s).
"""
from central.supervisor import AdapterState
mock_adapter = MagicMock()
mock_adapter.name = "test"
mock_adapter.cadence_s = 60
# Last poll was 150 seconds ago
last_poll = datetime.now(timezone.utc) - timedelta(seconds=150)
config = AdapterConfig(
name="test",
enabled=True,
cadence_s=120, # Increased cadence
settings={},
paused_at=None,
updated_at=datetime.now(timezone.utc),
)
state = AdapterState(
name="test",
adapter=mock_adapter,
config=config,
last_completed_poll=last_poll,
)
# Calculate next poll time
now = datetime.now(timezone.utc)
next_poll_at = last_poll.timestamp() + config.cadence_s
wait_time = max(0, next_poll_at - now.timestamp())
# Since 150 > 120, next poll should be immediate (wait_time ~= 0)
assert wait_time < 1, (
f"Expected immediate poll (wait ~0s), got {wait_time}s. "
f"After a gap exceeding new cadence, poll should happen now."
)
@pytest.mark.asyncio
async def test_enable_disable_enable_respects_rate_limit(self) -> None:
"""Re-enabling adapter schedules poll at last_poll + cadence.
If adapter was disabled for a while and then re-enabled, the next
poll should be at (last_completed_poll + cadence_s), not immediately
(unless that time has already passed).
"""
from central.supervisor import AdapterState
mock_adapter = MagicMock()
mock_adapter.name = "test"
mock_adapter.cadence_s = 60
# Last poll was 30 seconds ago, then adapter was disabled
last_poll = datetime.now(timezone.utc) - timedelta(seconds=30)
# Re-enabled config
config = AdapterConfig(
name="test",
enabled=True,
cadence_s=60,
settings={},
paused_at=None,
updated_at=datetime.now(timezone.utc),
)
state = AdapterState(
name="test",
adapter=mock_adapter,
config=config,
last_completed_poll=last_poll,
)
# Calculate next poll time
now = datetime.now(timezone.utc)
next_poll_at = last_poll.timestamp() + config.cadence_s
wait_time = max(0, next_poll_at - now.timestamp())
# Should wait ~30 more seconds (60 - 30 = 30)
assert abs(wait_time - 30) < 2, (
f"Expected ~30s wait after re-enable, got {wait_time}s. "
f"Rate limit violated on enable→disable→enable sequence."
)
@pytest.mark.asyncio
async def test_multiple_rapid_cadence_changes_no_extra_polls(self) -> None:
"""Multiple rapid cadence changes don't cause extra polls.
If NOTIFY fires rapidly (609012090), the final schedule should
still be based on last_completed_poll + final_cadence.
"""
from central.supervisor import AdapterState
mock_adapter = MagicMock()
mock_adapter.name = "test"
mock_adapter.cadence_s = 60
# Last poll was 20 seconds ago
last_poll = datetime.now(timezone.utc) - timedelta(seconds=20)
state = AdapterState(
name="test",
adapter=mock_adapter,
config=AdapterConfig(
name="test",
enabled=True,
cadence_s=60,
settings={},
paused_at=None,
updated_at=datetime.now(timezone.utc),
),
last_completed_poll=last_poll,
)
# Simulate rapid cadence changes
for cadence in [90, 120, 90]: # Final cadence is 90
state.config = AdapterConfig(
name="test",
enabled=True,
cadence_s=cadence,
settings={},
paused_at=None,
updated_at=datetime.now(timezone.utc),
)
state.adapter.cadence_s = cadence
# Final schedule should be last_poll + 90
now = datetime.now(timezone.utc)
final_cadence = 90
next_poll_at = last_poll.timestamp() + final_cadence
wait_time = max(0, next_poll_at - now.timestamp())
# Should wait ~70 seconds (90 - 20 = 70)
assert abs(wait_time - 70) < 2, (
f"Expected ~70s wait after rapid changes, got {wait_time}s. "
f"Multiple NOTIFYs should not cause extra polls."
)

File diff suppressed because it is too large Load diff

View file

@ -1,482 +1,482 @@
"""Tests for USGS earthquake adapter."""
import pytest
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch
from pathlib import Path
import tempfile
from central.adapters.usgs_quake import (
USGSQuakeAdapter,
magnitude_tier,
magnitude_to_severity,
)
from central.config_models import AdapterConfig, RegionConfig
from central.models import Event, Geo
# Sample USGS GeoJSON response
SAMPLE_GEOJSON = {
"type": "FeatureCollection",
"metadata": {
"generated": 1715878800000,
"url": "https://earthquake.usgs.gov/earthquakes/feed/v1.0/summary/all_hour.geojson",
"title": "USGS All Earthquakes, Past Hour",
"status": 200,
"api": "1.10.3",
"count": 3
},
"features": [
{
"type": "Feature",
"properties": {
"mag": 2.5,
"place": "10km N of Boise, Idaho",
"time": 1715878500000,
"updated": 1715878600000,
"tz": None,
"url": "https://earthquake.usgs.gov/earthquakes/eventpage/us1234",
"detail": "https://earthquake.usgs.gov/earthquakes/feed/v1.0/detail/us1234.geojson",
"felt": None,
"cdi": None,
"mmi": None,
"alert": None,
"status": "automatic",
"tsunami": 0,
"sig": 100,
"net": "us",
"code": "1234",
"ids": ",us1234,",
"sources": ",us,",
"types": ",origin,",
"nst": 10,
"dmin": 0.5,
"rms": 0.3,
"gap": 100,
"magType": "ml",
"type": "earthquake",
"title": "M 2.5 - 10km N of Boise, Idaho"
},
"geometry": {
"type": "Point",
"coordinates": [-116.2, 43.7, 10.5]
},
"id": "us1234"
},
{
"type": "Feature",
"properties": {
"mag": 4.5,
"place": "20km S of Portland, Oregon",
"time": 1715878400000,
"updated": 1715878500000,
"tz": None,
"url": "https://earthquake.usgs.gov/earthquakes/eventpage/us5678",
"detail": "https://earthquake.usgs.gov/earthquakes/feed/v1.0/detail/us5678.geojson",
"felt": 50,
"cdi": 4.0,
"mmi": 3.5,
"alert": "green",
"status": "reviewed",
"tsunami": 0,
"sig": 300,
"net": "us",
"code": "5678",
"ids": ",us5678,",
"sources": ",us,",
"types": ",origin,shakemap,",
"nst": 25,
"dmin": 0.2,
"rms": 0.2,
"gap": 50,
"magType": "mw",
"type": "earthquake",
"title": "M 4.5 - 20km S of Portland, Oregon"
},
"geometry": {
"type": "Point",
"coordinates": [-122.6, 45.3, 15.0]
},
"id": "us5678"
},
{
"type": "Feature",
"properties": {
"mag": 3.0,
"place": "50km E of San Francisco, California",
"time": 1715878300000,
"updated": 1715878400000,
"tz": None,
"url": "https://earthquake.usgs.gov/earthquakes/eventpage/us9999",
"detail": "https://earthquake.usgs.gov/earthquakes/feed/v1.0/detail/us9999.geojson",
"felt": None,
"cdi": None,
"mmi": None,
"alert": None,
"status": "automatic",
"tsunami": 0,
"sig": 150,
"net": "us",
"code": "9999",
"ids": ",us9999,",
"sources": ",us,",
"types": ",origin,",
"nst": 15,
"dmin": 0.3,
"rms": 0.25,
"gap": 80,
"magType": "ml",
"type": "earthquake",
"title": "M 3.0 - 50km E of San Francisco, California"
},
"geometry": {
"type": "Point",
"coordinates": [-121.5, 37.8, 8.0]
},
"id": "us9999"
}
]
}
# Sample with null magnitude
SAMPLE_NULL_MAG = {
"type": "FeatureCollection",
"metadata": {"count": 1},
"features": [
{
"type": "Feature",
"properties": {
"mag": None,
"place": "Quarry blast",
"time": 1715878500000,
"type": "quarry blast"
},
"geometry": {
"type": "Point",
"coordinates": [-116.0, 44.0, 0.0]
},
"id": "usquarry1"
}
]
}
def make_adapter_config(
region: dict | None = None,
feed: str = "all_hour",
) -> AdapterConfig:
"""Create an AdapterConfig for testing."""
settings = {"feed": feed}
if region:
settings["region"] = region
else:
settings["region"] = {
"north": 49.5,
"south": 40.0,
"east": -110.0,
"west": -125.0,
}
return AdapterConfig(
name="usgs_quake",
enabled=True,
cadence_s=60,
settings=settings,
updated_at=datetime.now(timezone.utc),
)
@pytest.fixture
def temp_db_path():
"""Create a temporary database path for testing."""
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
yield Path(f.name)
@pytest.fixture
def mock_config_store():
"""Create a mock ConfigStore."""
return MagicMock()
class TestMagnitudeTier:
"""Test magnitude tier classification."""
def test_minor(self):
assert magnitude_tier(0.5) == "minor"
assert magnitude_tier(2.9) == "minor"
def test_light(self):
assert magnitude_tier(3.0) == "light"
assert magnitude_tier(3.9) == "light"
def test_moderate(self):
assert magnitude_tier(4.0) == "moderate"
assert magnitude_tier(4.9) == "moderate"
def test_strong(self):
assert magnitude_tier(5.0) == "strong"
assert magnitude_tier(5.9) == "strong"
def test_major(self):
assert magnitude_tier(6.0) == "major"
assert magnitude_tier(6.9) == "major"
def test_great(self):
assert magnitude_tier(7.0) == "great"
assert magnitude_tier(9.5) == "great"
class TestMagnitudeToSeverity:
"""Test magnitude to severity mapping."""
def test_severity_levels(self):
assert magnitude_to_severity(2.0) == 0
assert magnitude_to_severity(3.5) == 1
assert magnitude_to_severity(4.5) == 2
assert magnitude_to_severity(5.5) == 3
assert magnitude_to_severity(6.5) == 4
assert magnitude_to_severity(7.5) == 5
class TestRegionFiltering:
"""Test region/bbox filtering."""
@pytest.mark.asyncio
async def test_filters_out_of_bbox(self, temp_db_path, mock_config_store):
"""Test that quakes outside bbox are filtered."""
# Region covers PNW only (north of 40, west of -110)
config = make_adapter_config(
region={"north": 49.5, "south": 40.0, "east": -110.0, "west": -125.0}
)
adapter = USGSQuakeAdapter(
config=config,
config_store=mock_config_store,
cursor_db_path=temp_db_path,
)
await adapter.startup()
with patch.object(adapter, "_fetch_geojson", new_callable=AsyncMock) as mock_fetch:
mock_fetch.return_value = SAMPLE_GEOJSON
events = []
async for event in adapter.poll():
events.append(event)
# us1234 (Boise) and us5678 (Portland) are in bbox
# us9999 (SF, lat 37.8) is outside bbox (south < 40)
assert len(events) == 2
event_ids = {e.id for e in events}
assert "us1234" in event_ids
assert "us5678" in event_ids
assert "us9999" not in event_ids
await adapter.shutdown()
class TestDeduplication:
"""Test deduplication logic."""
@pytest.mark.asyncio
async def test_dedup_marks_published(self, temp_db_path, mock_config_store):
config = make_adapter_config()
adapter = USGSQuakeAdapter(
config=config,
config_store=mock_config_store,
cursor_db_path=temp_db_path,
)
await adapter.startup()
event_id = "us1234"
assert not adapter.is_published(event_id)
adapter.mark_published(event_id)
assert adapter.is_published(event_id)
await adapter.shutdown()
@pytest.mark.asyncio
async def test_second_poll_no_duplicates(self, temp_db_path, mock_config_store):
"""Test that second poll with same events yields nothing."""
config = make_adapter_config()
adapter = USGSQuakeAdapter(
config=config,
config_store=mock_config_store,
cursor_db_path=temp_db_path,
)
await adapter.startup()
with patch.object(adapter, "_fetch_geojson", new_callable=AsyncMock) as mock_fetch:
mock_fetch.return_value = SAMPLE_GEOJSON
# First poll
events1 = []
async for event in adapter.poll():
events1.append(event)
# Second poll - same data
events2 = []
async for event in adapter.poll():
events2.append(event)
# First poll should have events (2 in bbox)
assert len(events1) == 2
# Second poll should have 0 (all deduped)
assert len(events2) == 0
await adapter.shutdown()
class TestNullMagnitude:
"""Test handling of null magnitude events."""
@pytest.mark.asyncio
async def test_skips_null_magnitude(self, temp_db_path, mock_config_store):
"""Test that events with null magnitude are skipped."""
config = make_adapter_config()
adapter = USGSQuakeAdapter(
config=config,
config_store=mock_config_store,
cursor_db_path=temp_db_path,
)
await adapter.startup()
with patch.object(adapter, "_fetch_geojson", new_callable=AsyncMock) as mock_fetch:
mock_fetch.return_value = SAMPLE_NULL_MAG
events = []
async for event in adapter.poll():
events.append(event)
# Should skip the null-magnitude event
assert len(events) == 0
await adapter.shutdown()
class TestEventGeneration:
"""Test Event generation from features."""
@pytest.mark.asyncio
async def test_event_category(self, temp_db_path, mock_config_store):
config = make_adapter_config()
adapter = USGSQuakeAdapter(
config=config,
config_store=mock_config_store,
cursor_db_path=temp_db_path,
)
await adapter.startup()
with patch.object(adapter, "_fetch_geojson", new_callable=AsyncMock) as mock_fetch:
mock_fetch.return_value = SAMPLE_GEOJSON
events = []
async for event in adapter.poll():
events.append(event)
# Check categories
categories = {e.category for e in events}
# us1234 is M2.5 -> minor, us5678 is M4.5 -> moderate
assert "quake.event.minor" in categories
assert "quake.event.moderate" in categories
await adapter.shutdown()
@pytest.mark.asyncio
async def test_event_severity(self, temp_db_path, mock_config_store):
config = make_adapter_config()
adapter = USGSQuakeAdapter(
config=config,
config_store=mock_config_store,
cursor_db_path=temp_db_path,
)
await adapter.startup()
with patch.object(adapter, "_fetch_geojson", new_callable=AsyncMock) as mock_fetch:
mock_fetch.return_value = SAMPLE_GEOJSON
events = []
async for event in adapter.poll():
events.append(event)
# Find events by ID
events_by_id = {e.id: e for e in events}
# M2.5 -> severity 0
assert events_by_id["us1234"].severity == 0
# M4.5 -> severity 2
assert events_by_id["us5678"].severity == 2
await adapter.shutdown()
@pytest.mark.asyncio
async def test_event_geo(self, temp_db_path, mock_config_store):
config = make_adapter_config()
adapter = USGSQuakeAdapter(
config=config,
config_store=mock_config_store,
cursor_db_path=temp_db_path,
)
await adapter.startup()
with patch.object(adapter, "_fetch_geojson", new_callable=AsyncMock) as mock_fetch:
mock_fetch.return_value = SAMPLE_GEOJSON
events = []
async for event in adapter.poll():
events.append(event)
events_by_id = {e.id: e for e in events}
# Check Boise quake coordinates
boise = events_by_id["us1234"]
assert boise.geo.centroid == (-116.2, 43.7)
await adapter.shutdown()
class TestApplyConfig:
"""Test hot-reload configuration application."""
@pytest.mark.asyncio
async def test_apply_config_updates_region(self, temp_db_path, mock_config_store):
config = make_adapter_config(
region={"north": 49.5, "south": 40.0, "east": -110.0, "west": -125.0}
)
adapter = USGSQuakeAdapter(
config=config,
config_store=mock_config_store,
cursor_db_path=temp_db_path,
)
await adapter.startup()
assert adapter.region.north == 49.5
new_config = make_adapter_config(
region={"north": 48.0, "south": 45.0, "east": -115.0, "west": -125.0}
)
await adapter.apply_config(new_config)
assert adapter.region.north == 48.0
assert adapter.region.south == 45.0
await adapter.shutdown()
@pytest.mark.asyncio
async def test_apply_config_updates_feed(self, temp_db_path, mock_config_store):
config = make_adapter_config(feed="all_hour")
adapter = USGSQuakeAdapter(
config=config,
config_store=mock_config_store,
cursor_db_path=temp_db_path,
)
await adapter.startup()
assert adapter._feed == "all_hour"
new_config = make_adapter_config(feed="all_day")
await adapter.apply_config(new_config)
assert adapter._feed == "all_day"
await adapter.shutdown()
"""Tests for USGS earthquake adapter."""
import pytest
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch
from pathlib import Path
import tempfile
from central.adapters.usgs_quake import (
USGSQuakeAdapter,
magnitude_tier,
magnitude_to_severity,
)
from central.config_models import AdapterConfig, RegionConfig
from central.models import Event, Geo
# Sample USGS GeoJSON response
SAMPLE_GEOJSON = {
"type": "FeatureCollection",
"metadata": {
"generated": 1715878800000,
"url": "https://earthquake.usgs.gov/earthquakes/feed/v1.0/summary/all_hour.geojson",
"title": "USGS All Earthquakes, Past Hour",
"status": 200,
"api": "1.10.3",
"count": 3
},
"features": [
{
"type": "Feature",
"properties": {
"mag": 2.5,
"place": "10km N of Boise, Idaho",
"time": 1715878500000,
"updated": 1715878600000,
"tz": None,
"url": "https://earthquake.usgs.gov/earthquakes/eventpage/us1234",
"detail": "https://earthquake.usgs.gov/earthquakes/feed/v1.0/detail/us1234.geojson",
"felt": None,
"cdi": None,
"mmi": None,
"alert": None,
"status": "automatic",
"tsunami": 0,
"sig": 100,
"net": "us",
"code": "1234",
"ids": ",us1234,",
"sources": ",us,",
"types": ",origin,",
"nst": 10,
"dmin": 0.5,
"rms": 0.3,
"gap": 100,
"magType": "ml",
"type": "earthquake",
"title": "M 2.5 - 10km N of Boise, Idaho"
},
"geometry": {
"type": "Point",
"coordinates": [-116.2, 43.7, 10.5]
},
"id": "us1234"
},
{
"type": "Feature",
"properties": {
"mag": 4.5,
"place": "20km S of Portland, Oregon",
"time": 1715878400000,
"updated": 1715878500000,
"tz": None,
"url": "https://earthquake.usgs.gov/earthquakes/eventpage/us5678",
"detail": "https://earthquake.usgs.gov/earthquakes/feed/v1.0/detail/us5678.geojson",
"felt": 50,
"cdi": 4.0,
"mmi": 3.5,
"alert": "green",
"status": "reviewed",
"tsunami": 0,
"sig": 300,
"net": "us",
"code": "5678",
"ids": ",us5678,",
"sources": ",us,",
"types": ",origin,shakemap,",
"nst": 25,
"dmin": 0.2,
"rms": 0.2,
"gap": 50,
"magType": "mw",
"type": "earthquake",
"title": "M 4.5 - 20km S of Portland, Oregon"
},
"geometry": {
"type": "Point",
"coordinates": [-122.6, 45.3, 15.0]
},
"id": "us5678"
},
{
"type": "Feature",
"properties": {
"mag": 3.0,
"place": "50km E of San Francisco, California",
"time": 1715878300000,
"updated": 1715878400000,
"tz": None,
"url": "https://earthquake.usgs.gov/earthquakes/eventpage/us9999",
"detail": "https://earthquake.usgs.gov/earthquakes/feed/v1.0/detail/us9999.geojson",
"felt": None,
"cdi": None,
"mmi": None,
"alert": None,
"status": "automatic",
"tsunami": 0,
"sig": 150,
"net": "us",
"code": "9999",
"ids": ",us9999,",
"sources": ",us,",
"types": ",origin,",
"nst": 15,
"dmin": 0.3,
"rms": 0.25,
"gap": 80,
"magType": "ml",
"type": "earthquake",
"title": "M 3.0 - 50km E of San Francisco, California"
},
"geometry": {
"type": "Point",
"coordinates": [-121.5, 37.8, 8.0]
},
"id": "us9999"
}
]
}
# Sample with null magnitude
SAMPLE_NULL_MAG = {
"type": "FeatureCollection",
"metadata": {"count": 1},
"features": [
{
"type": "Feature",
"properties": {
"mag": None,
"place": "Quarry blast",
"time": 1715878500000,
"type": "quarry blast"
},
"geometry": {
"type": "Point",
"coordinates": [-116.0, 44.0, 0.0]
},
"id": "usquarry1"
}
]
}
def make_adapter_config(
region: dict | None = None,
feed: str = "all_hour",
) -> AdapterConfig:
"""Create an AdapterConfig for testing."""
settings = {"feed": feed}
if region:
settings["region"] = region
else:
settings["region"] = {
"north": 49.5,
"south": 40.0,
"east": -110.0,
"west": -125.0,
}
return AdapterConfig(
name="usgs_quake",
enabled=True,
cadence_s=60,
settings=settings,
updated_at=datetime.now(timezone.utc),
)
@pytest.fixture
def temp_db_path():
"""Create a temporary database path for testing."""
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
yield Path(f.name)
@pytest.fixture
def mock_config_store():
"""Create a mock ConfigStore."""
return MagicMock()
class TestMagnitudeTier:
"""Test magnitude tier classification."""
def test_minor(self):
assert magnitude_tier(0.5) == "minor"
assert magnitude_tier(2.9) == "minor"
def test_light(self):
assert magnitude_tier(3.0) == "light"
assert magnitude_tier(3.9) == "light"
def test_moderate(self):
assert magnitude_tier(4.0) == "moderate"
assert magnitude_tier(4.9) == "moderate"
def test_strong(self):
assert magnitude_tier(5.0) == "strong"
assert magnitude_tier(5.9) == "strong"
def test_major(self):
assert magnitude_tier(6.0) == "major"
assert magnitude_tier(6.9) == "major"
def test_great(self):
assert magnitude_tier(7.0) == "great"
assert magnitude_tier(9.5) == "great"
class TestMagnitudeToSeverity:
"""Test magnitude to severity mapping."""
def test_severity_levels(self):
assert magnitude_to_severity(2.0) == 0
assert magnitude_to_severity(3.5) == 1
assert magnitude_to_severity(4.5) == 2
assert magnitude_to_severity(5.5) == 3
assert magnitude_to_severity(6.5) == 4
assert magnitude_to_severity(7.5) == 5
class TestRegionFiltering:
"""Test region/bbox filtering."""
@pytest.mark.asyncio
async def test_filters_out_of_bbox(self, temp_db_path, mock_config_store):
"""Test that quakes outside bbox are filtered."""
# Region covers PNW only (north of 40, west of -110)
config = make_adapter_config(
region={"north": 49.5, "south": 40.0, "east": -110.0, "west": -125.0}
)
adapter = USGSQuakeAdapter(
config=config,
config_store=mock_config_store,
cursor_db_path=temp_db_path,
)
await adapter.startup()
with patch.object(adapter, "_fetch_geojson", new_callable=AsyncMock) as mock_fetch:
mock_fetch.return_value = SAMPLE_GEOJSON
events = []
async for event in adapter.poll():
events.append(event)
# us1234 (Boise) and us5678 (Portland) are in bbox
# us9999 (SF, lat 37.8) is outside bbox (south < 40)
assert len(events) == 2
event_ids = {e.id for e in events}
assert "us1234" in event_ids
assert "us5678" in event_ids
assert "us9999" not in event_ids
await adapter.shutdown()
class TestDeduplication:
"""Test deduplication logic."""
@pytest.mark.asyncio
async def test_dedup_marks_published(self, temp_db_path, mock_config_store):
config = make_adapter_config()
adapter = USGSQuakeAdapter(
config=config,
config_store=mock_config_store,
cursor_db_path=temp_db_path,
)
await adapter.startup()
event_id = "us1234"
assert not adapter.is_published(event_id)
adapter.mark_published(event_id)
assert adapter.is_published(event_id)
await adapter.shutdown()
@pytest.mark.asyncio
async def test_second_poll_no_duplicates(self, temp_db_path, mock_config_store):
"""Test that second poll with same events yields nothing."""
config = make_adapter_config()
adapter = USGSQuakeAdapter(
config=config,
config_store=mock_config_store,
cursor_db_path=temp_db_path,
)
await adapter.startup()
with patch.object(adapter, "_fetch_geojson", new_callable=AsyncMock) as mock_fetch:
mock_fetch.return_value = SAMPLE_GEOJSON
# First poll
events1 = []
async for event in adapter.poll():
events1.append(event)
# Second poll - same data
events2 = []
async for event in adapter.poll():
events2.append(event)
# First poll should have events (2 in bbox)
assert len(events1) == 2
# Second poll should have 0 (all deduped)
assert len(events2) == 0
await adapter.shutdown()
class TestNullMagnitude:
"""Test handling of null magnitude events."""
@pytest.mark.asyncio
async def test_skips_null_magnitude(self, temp_db_path, mock_config_store):
"""Test that events with null magnitude are skipped."""
config = make_adapter_config()
adapter = USGSQuakeAdapter(
config=config,
config_store=mock_config_store,
cursor_db_path=temp_db_path,
)
await adapter.startup()
with patch.object(adapter, "_fetch_geojson", new_callable=AsyncMock) as mock_fetch:
mock_fetch.return_value = SAMPLE_NULL_MAG
events = []
async for event in adapter.poll():
events.append(event)
# Should skip the null-magnitude event
assert len(events) == 0
await adapter.shutdown()
class TestEventGeneration:
"""Test Event generation from features."""
@pytest.mark.asyncio
async def test_event_category(self, temp_db_path, mock_config_store):
config = make_adapter_config()
adapter = USGSQuakeAdapter(
config=config,
config_store=mock_config_store,
cursor_db_path=temp_db_path,
)
await adapter.startup()
with patch.object(adapter, "_fetch_geojson", new_callable=AsyncMock) as mock_fetch:
mock_fetch.return_value = SAMPLE_GEOJSON
events = []
async for event in adapter.poll():
events.append(event)
# Check categories
categories = {e.category for e in events}
# us1234 is M2.5 -> minor, us5678 is M4.5 -> moderate
assert "quake.event.minor" in categories
assert "quake.event.moderate" in categories
await adapter.shutdown()
@pytest.mark.asyncio
async def test_event_severity(self, temp_db_path, mock_config_store):
config = make_adapter_config()
adapter = USGSQuakeAdapter(
config=config,
config_store=mock_config_store,
cursor_db_path=temp_db_path,
)
await adapter.startup()
with patch.object(adapter, "_fetch_geojson", new_callable=AsyncMock) as mock_fetch:
mock_fetch.return_value = SAMPLE_GEOJSON
events = []
async for event in adapter.poll():
events.append(event)
# Find events by ID
events_by_id = {e.id: e for e in events}
# M2.5 -> severity 0
assert events_by_id["us1234"].severity == 0
# M4.5 -> severity 2
assert events_by_id["us5678"].severity == 2
await adapter.shutdown()
@pytest.mark.asyncio
async def test_event_geo(self, temp_db_path, mock_config_store):
config = make_adapter_config()
adapter = USGSQuakeAdapter(
config=config,
config_store=mock_config_store,
cursor_db_path=temp_db_path,
)
await adapter.startup()
with patch.object(adapter, "_fetch_geojson", new_callable=AsyncMock) as mock_fetch:
mock_fetch.return_value = SAMPLE_GEOJSON
events = []
async for event in adapter.poll():
events.append(event)
events_by_id = {e.id: e for e in events}
# Check Boise quake coordinates
boise = events_by_id["us1234"]
assert boise.geo.centroid == (-116.2, 43.7)
await adapter.shutdown()
class TestApplyConfig:
"""Test hot-reload configuration application."""
@pytest.mark.asyncio
async def test_apply_config_updates_region(self, temp_db_path, mock_config_store):
config = make_adapter_config(
region={"north": 49.5, "south": 40.0, "east": -110.0, "west": -125.0}
)
adapter = USGSQuakeAdapter(
config=config,
config_store=mock_config_store,
cursor_db_path=temp_db_path,
)
await adapter.startup()
assert adapter.region.north == 49.5
new_config = make_adapter_config(
region={"north": 48.0, "south": 45.0, "east": -115.0, "west": -125.0}
)
await adapter.apply_config(new_config)
assert adapter.region.north == 48.0
assert adapter.region.south == 45.0
await adapter.shutdown()
@pytest.mark.asyncio
async def test_apply_config_updates_feed(self, temp_db_path, mock_config_store):
config = make_adapter_config(feed="all_hour")
adapter = USGSQuakeAdapter(
config=config,
config_store=mock_config_store,
cursor_db_path=temp_db_path,
)
await adapter.startup()
assert adapter._feed == "all_hour"
new_config = make_adapter_config(feed="all_day")
await adapter.apply_config(new_config)
assert adapter._feed == "all_day"
await adapter.shutdown()