mirror of
https://github.com/zvx-echo6/central.git
synced 2026-06-10 11:54:37 +02:00
chore: normalize line endings to LF
This commit is contained in:
parent
43088d7fbb
commit
374a8c067f
26 changed files with 5357 additions and 5346 deletions
|
|
@ -1,18 +1,18 @@
|
|||
# Central Tests
|
||||
|
||||
## Test Database
|
||||
|
||||
Some tests (notably `test_config_store.py`) require a real PostgreSQL database.
|
||||
By default, tests connect to:
|
||||
|
||||
```
|
||||
postgresql://central_test:testpass@localhost/central_test
|
||||
```
|
||||
|
||||
If your test database uses different credentials, set the `CENTRAL_TEST_DB_DSN`
|
||||
environment variable:
|
||||
|
||||
```bash
|
||||
export CENTRAL_TEST_DB_DSN="postgresql://myuser:mypass@localhost/mydb"
|
||||
uv run pytest tests/test_config_store.py
|
||||
```
|
||||
# Central Tests
|
||||
|
||||
## Test Database
|
||||
|
||||
Some tests (notably `test_config_store.py`) require a real PostgreSQL database.
|
||||
By default, tests connect to:
|
||||
|
||||
```
|
||||
postgresql://central_test:testpass@localhost/central_test
|
||||
```
|
||||
|
||||
If your test database uses different credentials, set the `CENTRAL_TEST_DB_DSN`
|
||||
environment variable:
|
||||
|
||||
```bash
|
||||
export CENTRAL_TEST_DB_DSN="postgresql://myuser:mypass@localhost/mydb"
|
||||
uv run pytest tests/test_config_store.py
|
||||
```
|
||||
|
|
|
|||
|
|
@ -1,123 +1,123 @@
|
|||
"""Tests for bootstrap configuration."""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from tempfile import NamedTemporaryFile
|
||||
|
||||
import pytest
|
||||
|
||||
from central.bootstrap_config import Settings, get_settings
|
||||
|
||||
|
||||
class TestSettingsFromEnv:
|
||||
"""Test loading settings from environment variables."""
|
||||
|
||||
def test_reads_from_env(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Settings are read from CENTRAL_* environment variables."""
|
||||
monkeypatch.setenv("CENTRAL_DB_DSN", "postgresql://test:pass@localhost/testdb")
|
||||
monkeypatch.setenv("CENTRAL_NATS_URL", "nats://10.0.0.1:4222")
|
||||
monkeypatch.setenv("CENTRAL_MASTER_KEY_PATH", "/tmp/test.key")
|
||||
monkeypatch.setenv("CENTRAL_LOG_LEVEL", "DEBUG")
|
||||
|
||||
settings = Settings()
|
||||
|
||||
assert settings.db_dsn == "postgresql://test:pass@localhost/testdb"
|
||||
assert settings.nats_url == "nats://10.0.0.1:4222"
|
||||
assert settings.master_key_path == Path("/tmp/test.key")
|
||||
assert settings.log_level == "DEBUG"
|
||||
|
||||
def test_defaults_applied(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Default values are used when env vars not set."""
|
||||
monkeypatch.setenv("CENTRAL_DB_DSN", "postgresql://x:y@localhost/db")
|
||||
# Clear any existing env vars that might interfere
|
||||
monkeypatch.delenv("CENTRAL_NATS_URL", raising=False)
|
||||
monkeypatch.delenv("CENTRAL_MASTER_KEY_PATH", raising=False)
|
||||
monkeypatch.delenv("CENTRAL_LOG_LEVEL", raising=False)
|
||||
|
||||
settings = Settings()
|
||||
|
||||
assert settings.nats_url == "nats://localhost:4222"
|
||||
assert settings.master_key_path == Path("/etc/central/master.key")
|
||||
assert settings.log_level == "INFO"
|
||||
|
||||
|
||||
class TestSettingsFromFile:
|
||||
"""Test loading settings from .env file."""
|
||||
|
||||
def test_reads_from_env_file(self, tmp_path: Path) -> None:
|
||||
"""Settings are read from .env file when env vars not present."""
|
||||
env_file = tmp_path / ".env"
|
||||
env_file.write_text(
|
||||
"CENTRAL_DB_DSN=postgresql://file:pass@localhost/filedb\n"
|
||||
"CENTRAL_NATS_URL=nats://file.local:4222\n"
|
||||
"CENTRAL_LOG_LEVEL=WARNING\n"
|
||||
)
|
||||
|
||||
# Create settings pointing to the temp .env file
|
||||
settings = Settings(_env_file=env_file)
|
||||
|
||||
assert settings.db_dsn == "postgresql://file:pass@localhost/filedb"
|
||||
assert settings.nats_url == "nats://file.local:4222"
|
||||
assert settings.log_level == "WARNING"
|
||||
|
||||
def test_env_vars_override_file(
|
||||
self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Environment variables take precedence over .env file."""
|
||||
env_file = tmp_path / ".env"
|
||||
env_file.write_text("CENTRAL_DB_DSN=postgresql://file@localhost/filedb\n")
|
||||
monkeypatch.setenv("CENTRAL_DB_DSN", "postgresql://env@localhost/envdb")
|
||||
|
||||
settings = Settings(_env_file=env_file)
|
||||
|
||||
assert settings.db_dsn == "postgresql://env@localhost/envdb"
|
||||
|
||||
|
||||
class TestSettingsValidation:
|
||||
"""Test settings validation and error handling."""
|
||||
|
||||
def test_fails_if_required_var_missing(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Clear error when required CENTRAL_DB_DSN is missing."""
|
||||
# Ensure no env vars or .env file provides the DSN
|
||||
monkeypatch.delenv("CENTRAL_DB_DSN", raising=False)
|
||||
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
# Use a non-existent .env file path to ensure no fallback
|
||||
Settings(_env_file=Path("/nonexistent/.env"))
|
||||
|
||||
# pydantic-settings raises ValidationError for missing required fields
|
||||
assert "db_dsn" in str(exc_info.value).lower() or "validation" in str(exc_info.value).lower()
|
||||
|
||||
def test_invalid_log_level_rejected(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Invalid log level values are rejected."""
|
||||
monkeypatch.setenv("CENTRAL_DB_DSN", "postgresql://x@localhost/db")
|
||||
monkeypatch.setenv("CENTRAL_LOG_LEVEL", "INVALID")
|
||||
|
||||
with pytest.raises(Exception):
|
||||
Settings()
|
||||
|
||||
|
||||
class TestGetSettings:
|
||||
"""Test the cached settings loader."""
|
||||
|
||||
def test_caches_result(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""get_settings() returns cached instance."""
|
||||
monkeypatch.setenv("CENTRAL_DB_DSN", "postgresql://cached@localhost/db")
|
||||
get_settings.cache_clear()
|
||||
|
||||
s1 = get_settings()
|
||||
s2 = get_settings()
|
||||
|
||||
assert s1 is s2
|
||||
|
||||
def test_cache_clear_reloads(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""cache_clear() forces reload on next call."""
|
||||
monkeypatch.setenv("CENTRAL_DB_DSN", "postgresql://first@localhost/db")
|
||||
get_settings.cache_clear()
|
||||
s1 = get_settings()
|
||||
|
||||
monkeypatch.setenv("CENTRAL_DB_DSN", "postgresql://second@localhost/db")
|
||||
get_settings.cache_clear()
|
||||
s2 = get_settings()
|
||||
|
||||
assert s1.db_dsn != s2.db_dsn
|
||||
"""Tests for bootstrap configuration."""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from tempfile import NamedTemporaryFile
|
||||
|
||||
import pytest
|
||||
|
||||
from central.bootstrap_config import Settings, get_settings
|
||||
|
||||
|
||||
class TestSettingsFromEnv:
|
||||
"""Test loading settings from environment variables."""
|
||||
|
||||
def test_reads_from_env(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Settings are read from CENTRAL_* environment variables."""
|
||||
monkeypatch.setenv("CENTRAL_DB_DSN", "postgresql://test:pass@localhost/testdb")
|
||||
monkeypatch.setenv("CENTRAL_NATS_URL", "nats://10.0.0.1:4222")
|
||||
monkeypatch.setenv("CENTRAL_MASTER_KEY_PATH", "/tmp/test.key")
|
||||
monkeypatch.setenv("CENTRAL_LOG_LEVEL", "DEBUG")
|
||||
|
||||
settings = Settings()
|
||||
|
||||
assert settings.db_dsn == "postgresql://test:pass@localhost/testdb"
|
||||
assert settings.nats_url == "nats://10.0.0.1:4222"
|
||||
assert settings.master_key_path == Path("/tmp/test.key")
|
||||
assert settings.log_level == "DEBUG"
|
||||
|
||||
def test_defaults_applied(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Default values are used when env vars not set."""
|
||||
monkeypatch.setenv("CENTRAL_DB_DSN", "postgresql://x:y@localhost/db")
|
||||
# Clear any existing env vars that might interfere
|
||||
monkeypatch.delenv("CENTRAL_NATS_URL", raising=False)
|
||||
monkeypatch.delenv("CENTRAL_MASTER_KEY_PATH", raising=False)
|
||||
monkeypatch.delenv("CENTRAL_LOG_LEVEL", raising=False)
|
||||
|
||||
settings = Settings()
|
||||
|
||||
assert settings.nats_url == "nats://localhost:4222"
|
||||
assert settings.master_key_path == Path("/etc/central/master.key")
|
||||
assert settings.log_level == "INFO"
|
||||
|
||||
|
||||
class TestSettingsFromFile:
|
||||
"""Test loading settings from .env file."""
|
||||
|
||||
def test_reads_from_env_file(self, tmp_path: Path) -> None:
|
||||
"""Settings are read from .env file when env vars not present."""
|
||||
env_file = tmp_path / ".env"
|
||||
env_file.write_text(
|
||||
"CENTRAL_DB_DSN=postgresql://file:pass@localhost/filedb\n"
|
||||
"CENTRAL_NATS_URL=nats://file.local:4222\n"
|
||||
"CENTRAL_LOG_LEVEL=WARNING\n"
|
||||
)
|
||||
|
||||
# Create settings pointing to the temp .env file
|
||||
settings = Settings(_env_file=env_file)
|
||||
|
||||
assert settings.db_dsn == "postgresql://file:pass@localhost/filedb"
|
||||
assert settings.nats_url == "nats://file.local:4222"
|
||||
assert settings.log_level == "WARNING"
|
||||
|
||||
def test_env_vars_override_file(
|
||||
self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""Environment variables take precedence over .env file."""
|
||||
env_file = tmp_path / ".env"
|
||||
env_file.write_text("CENTRAL_DB_DSN=postgresql://file@localhost/filedb\n")
|
||||
monkeypatch.setenv("CENTRAL_DB_DSN", "postgresql://env@localhost/envdb")
|
||||
|
||||
settings = Settings(_env_file=env_file)
|
||||
|
||||
assert settings.db_dsn == "postgresql://env@localhost/envdb"
|
||||
|
||||
|
||||
class TestSettingsValidation:
|
||||
"""Test settings validation and error handling."""
|
||||
|
||||
def test_fails_if_required_var_missing(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Clear error when required CENTRAL_DB_DSN is missing."""
|
||||
# Ensure no env vars or .env file provides the DSN
|
||||
monkeypatch.delenv("CENTRAL_DB_DSN", raising=False)
|
||||
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
# Use a non-existent .env file path to ensure no fallback
|
||||
Settings(_env_file=Path("/nonexistent/.env"))
|
||||
|
||||
# pydantic-settings raises ValidationError for missing required fields
|
||||
assert "db_dsn" in str(exc_info.value).lower() or "validation" in str(exc_info.value).lower()
|
||||
|
||||
def test_invalid_log_level_rejected(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Invalid log level values are rejected."""
|
||||
monkeypatch.setenv("CENTRAL_DB_DSN", "postgresql://x@localhost/db")
|
||||
monkeypatch.setenv("CENTRAL_LOG_LEVEL", "INVALID")
|
||||
|
||||
with pytest.raises(Exception):
|
||||
Settings()
|
||||
|
||||
|
||||
class TestGetSettings:
|
||||
"""Test the cached settings loader."""
|
||||
|
||||
def test_caches_result(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""get_settings() returns cached instance."""
|
||||
monkeypatch.setenv("CENTRAL_DB_DSN", "postgresql://cached@localhost/db")
|
||||
get_settings.cache_clear()
|
||||
|
||||
s1 = get_settings()
|
||||
s2 = get_settings()
|
||||
|
||||
assert s1 is s2
|
||||
|
||||
def test_cache_clear_reloads(self, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""cache_clear() forces reload on next call."""
|
||||
monkeypatch.setenv("CENTRAL_DB_DSN", "postgresql://first@localhost/db")
|
||||
get_settings.cache_clear()
|
||||
s1 = get_settings()
|
||||
|
||||
monkeypatch.setenv("CENTRAL_DB_DSN", "postgresql://second@localhost/db")
|
||||
get_settings.cache_clear()
|
||||
s2 = get_settings()
|
||||
|
||||
assert s1.db_dsn != s2.db_dsn
|
||||
|
|
|
|||
|
|
@ -1,132 +1,132 @@
|
|||
"""Tests for configuration source abstraction."""
|
||||
|
||||
import base64
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import asyncpg
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from central.config_source import (
|
||||
ConfigSource,
|
||||
DbConfigSource,
|
||||
)
|
||||
from central.crypto import KEY_SIZE, clear_key_cache
|
||||
|
||||
# Test database DSN
|
||||
TEST_DB_DSN = os.environ.get(
|
||||
"CENTRAL_TEST_DB_DSN",
|
||||
"postgresql://central_test:testpass@localhost/central_test",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def master_key_path(tmp_path_factory: pytest.TempPathFactory) -> Path:
|
||||
"""Create a master key file for the test session."""
|
||||
key = os.urandom(KEY_SIZE)
|
||||
key_path = tmp_path_factory.mktemp("keys") / "master.key"
|
||||
key_path.write_text(base64.b64encode(key).decode())
|
||||
return key_path
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_master_key(master_key_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Configure master key path for all tests."""
|
||||
clear_key_cache()
|
||||
monkeypatch.setenv("CENTRAL_DB_DSN", TEST_DB_DSN)
|
||||
monkeypatch.setenv("CENTRAL_MASTER_KEY_PATH", str(master_key_path))
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def db_conn() -> asyncpg.Connection:
|
||||
"""Get a direct database connection for setup/teardown."""
|
||||
conn = await asyncpg.connect(TEST_DB_DSN)
|
||||
yield conn
|
||||
await conn.close()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def clean_config_schema(db_conn: asyncpg.Connection) -> None:
|
||||
"""Ensure config schema exists and is clean before each test."""
|
||||
await db_conn.execute("CREATE SCHEMA IF NOT EXISTS config")
|
||||
await db_conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS config.adapters (
|
||||
name TEXT PRIMARY KEY,
|
||||
enabled BOOLEAN NOT NULL DEFAULT true,
|
||||
cadence_s INTEGER NOT NULL,
|
||||
settings JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||
paused_at TIMESTAMPTZ,
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||
)
|
||||
""")
|
||||
await db_conn.execute("DELETE FROM config.adapters")
|
||||
|
||||
|
||||
class TestDbConfigSource:
|
||||
"""Tests for database-backed config source."""
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def db_source(self, clean_config_schema: None) -> DbConfigSource:
|
||||
"""Create a DbConfigSource for testing."""
|
||||
source = await DbConfigSource.create(TEST_DB_DSN)
|
||||
yield source
|
||||
await source.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_enabled_adapters_empty(self, db_source: DbConfigSource) -> None:
|
||||
"""list_enabled_adapters returns empty list when no adapters."""
|
||||
adapters = await db_source.list_enabled_adapters()
|
||||
assert adapters == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_enabled_adapters(
|
||||
self, db_source: DbConfigSource, db_conn: asyncpg.Connection
|
||||
) -> None:
|
||||
"""list_enabled_adapters returns only enabled, non-paused adapters."""
|
||||
# Insert test adapters
|
||||
await db_conn.execute("""
|
||||
INSERT INTO config.adapters (name, enabled, cadence_s, settings)
|
||||
VALUES
|
||||
('enabled_adapter', true, 60, '{"key": "value"}'::jsonb),
|
||||
('disabled_adapter', false, 60, '{}'::jsonb),
|
||||
('paused_adapter', true, 60, '{}'::jsonb)
|
||||
""")
|
||||
await db_conn.execute("""
|
||||
UPDATE config.adapters
|
||||
SET paused_at = now()
|
||||
WHERE name = 'paused_adapter'
|
||||
""")
|
||||
|
||||
adapters = await db_source.list_enabled_adapters()
|
||||
|
||||
assert len(adapters) == 1
|
||||
assert adapters[0].name == "enabled_adapter"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_adapter(
|
||||
self, db_source: DbConfigSource, db_conn: asyncpg.Connection
|
||||
) -> None:
|
||||
"""get_adapter returns correct adapter config."""
|
||||
await db_conn.execute("""
|
||||
INSERT INTO config.adapters (name, enabled, cadence_s, settings)
|
||||
VALUES ('test_adapter', true, 120, '{"states": ["ID"]}'::jsonb)
|
||||
""")
|
||||
|
||||
adapter = await db_source.get_adapter("test_adapter")
|
||||
|
||||
assert adapter is not None
|
||||
assert adapter.name == "test_adapter"
|
||||
assert adapter.cadence_s == 120
|
||||
assert adapter.settings == {"states": ["ID"]}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_nonexistent_adapter(self, db_source: DbConfigSource) -> None:
|
||||
"""get_adapter returns None for nonexistent adapter."""
|
||||
adapter = await db_source.get_adapter("does_not_exist")
|
||||
assert adapter is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_implements_protocol(self, db_source: DbConfigSource) -> None:
|
||||
"""DbConfigSource implements ConfigSource protocol."""
|
||||
assert isinstance(db_source, ConfigSource)
|
||||
"""Tests for configuration source abstraction."""
|
||||
|
||||
import base64
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import asyncpg
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from central.config_source import (
|
||||
ConfigSource,
|
||||
DbConfigSource,
|
||||
)
|
||||
from central.crypto import KEY_SIZE, clear_key_cache
|
||||
|
||||
# Test database DSN
|
||||
TEST_DB_DSN = os.environ.get(
|
||||
"CENTRAL_TEST_DB_DSN",
|
||||
"postgresql://central_test:testpass@localhost/central_test",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def master_key_path(tmp_path_factory: pytest.TempPathFactory) -> Path:
|
||||
"""Create a master key file for the test session."""
|
||||
key = os.urandom(KEY_SIZE)
|
||||
key_path = tmp_path_factory.mktemp("keys") / "master.key"
|
||||
key_path.write_text(base64.b64encode(key).decode())
|
||||
return key_path
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_master_key(master_key_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Configure master key path for all tests."""
|
||||
clear_key_cache()
|
||||
monkeypatch.setenv("CENTRAL_DB_DSN", TEST_DB_DSN)
|
||||
monkeypatch.setenv("CENTRAL_MASTER_KEY_PATH", str(master_key_path))
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def db_conn() -> asyncpg.Connection:
|
||||
"""Get a direct database connection for setup/teardown."""
|
||||
conn = await asyncpg.connect(TEST_DB_DSN)
|
||||
yield conn
|
||||
await conn.close()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def clean_config_schema(db_conn: asyncpg.Connection) -> None:
|
||||
"""Ensure config schema exists and is clean before each test."""
|
||||
await db_conn.execute("CREATE SCHEMA IF NOT EXISTS config")
|
||||
await db_conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS config.adapters (
|
||||
name TEXT PRIMARY KEY,
|
||||
enabled BOOLEAN NOT NULL DEFAULT true,
|
||||
cadence_s INTEGER NOT NULL,
|
||||
settings JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||
paused_at TIMESTAMPTZ,
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||
)
|
||||
""")
|
||||
await db_conn.execute("DELETE FROM config.adapters")
|
||||
|
||||
|
||||
class TestDbConfigSource:
|
||||
"""Tests for database-backed config source."""
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def db_source(self, clean_config_schema: None) -> DbConfigSource:
|
||||
"""Create a DbConfigSource for testing."""
|
||||
source = await DbConfigSource.create(TEST_DB_DSN)
|
||||
yield source
|
||||
await source.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_enabled_adapters_empty(self, db_source: DbConfigSource) -> None:
|
||||
"""list_enabled_adapters returns empty list when no adapters."""
|
||||
adapters = await db_source.list_enabled_adapters()
|
||||
assert adapters == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_enabled_adapters(
|
||||
self, db_source: DbConfigSource, db_conn: asyncpg.Connection
|
||||
) -> None:
|
||||
"""list_enabled_adapters returns only enabled, non-paused adapters."""
|
||||
# Insert test adapters
|
||||
await db_conn.execute("""
|
||||
INSERT INTO config.adapters (name, enabled, cadence_s, settings)
|
||||
VALUES
|
||||
('enabled_adapter', true, 60, '{"key": "value"}'::jsonb),
|
||||
('disabled_adapter', false, 60, '{}'::jsonb),
|
||||
('paused_adapter', true, 60, '{}'::jsonb)
|
||||
""")
|
||||
await db_conn.execute("""
|
||||
UPDATE config.adapters
|
||||
SET paused_at = now()
|
||||
WHERE name = 'paused_adapter'
|
||||
""")
|
||||
|
||||
adapters = await db_source.list_enabled_adapters()
|
||||
|
||||
assert len(adapters) == 1
|
||||
assert adapters[0].name == "enabled_adapter"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_adapter(
|
||||
self, db_source: DbConfigSource, db_conn: asyncpg.Connection
|
||||
) -> None:
|
||||
"""get_adapter returns correct adapter config."""
|
||||
await db_conn.execute("""
|
||||
INSERT INTO config.adapters (name, enabled, cadence_s, settings)
|
||||
VALUES ('test_adapter', true, 120, '{"states": ["ID"]}'::jsonb)
|
||||
""")
|
||||
|
||||
adapter = await db_source.get_adapter("test_adapter")
|
||||
|
||||
assert adapter is not None
|
||||
assert adapter.name == "test_adapter"
|
||||
assert adapter.cadence_s == 120
|
||||
assert adapter.settings == {"states": ["ID"]}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_nonexistent_adapter(self, db_source: DbConfigSource) -> None:
|
||||
"""get_adapter returns None for nonexistent adapter."""
|
||||
adapter = await db_source.get_adapter("does_not_exist")
|
||||
assert adapter is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_implements_protocol(self, db_source: DbConfigSource) -> None:
|
||||
"""DbConfigSource implements ConfigSource protocol."""
|
||||
assert isinstance(db_source, ConfigSource)
|
||||
|
|
|
|||
|
|
@ -1,339 +1,339 @@
|
|||
"""Tests for database-backed configuration store.
|
||||
|
||||
These tests require a real Postgres database. Set CENTRAL_TEST_DB_DSN
|
||||
environment variable to override the default test database connection.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import asyncpg
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from central.config_store import ConfigStore
|
||||
from central.crypto import KEY_SIZE, clear_key_cache
|
||||
|
||||
# Test database DSN - uses central_test database with well-known test password.
|
||||
# Override via CENTRAL_TEST_DB_DSN env var if your test DB differs.
|
||||
TEST_DB_DSN = os.environ.get(
|
||||
"CENTRAL_TEST_DB_DSN",
|
||||
"postgresql://central_test:testpass@localhost/central_test",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def master_key_path(tmp_path_factory: pytest.TempPathFactory) -> Path:
|
||||
"""Create a master key file for the test session."""
|
||||
key = os.urandom(KEY_SIZE)
|
||||
key_path = tmp_path_factory.mktemp("keys") / "master.key"
|
||||
key_path.write_text(base64.b64encode(key).decode())
|
||||
return key_path
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_master_key(master_key_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Configure master key path for all tests."""
|
||||
clear_key_cache()
|
||||
monkeypatch.setenv("CENTRAL_DB_DSN", TEST_DB_DSN)
|
||||
monkeypatch.setenv("CENTRAL_MASTER_KEY_PATH", str(master_key_path))
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def db_conn() -> asyncpg.Connection:
|
||||
"""Get a direct database connection for setup/teardown."""
|
||||
conn = await asyncpg.connect(TEST_DB_DSN)
|
||||
yield conn
|
||||
await conn.close()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def clean_config_schema(db_conn: asyncpg.Connection) -> None:
|
||||
"""Ensure config schema exists and is clean before each test."""
|
||||
# Create schema if not exists
|
||||
await db_conn.execute("CREATE SCHEMA IF NOT EXISTS config")
|
||||
|
||||
# Create tables if not exist
|
||||
await db_conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS config.adapters (
|
||||
name TEXT PRIMARY KEY,
|
||||
enabled BOOLEAN NOT NULL DEFAULT true,
|
||||
cadence_s INTEGER NOT NULL,
|
||||
settings JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||
paused_at TIMESTAMPTZ,
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||
)
|
||||
""")
|
||||
await db_conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS config.api_keys (
|
||||
alias TEXT PRIMARY KEY,
|
||||
encrypted_value BYTEA NOT NULL,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
rotated_at TIMESTAMPTZ,
|
||||
last_used_at TIMESTAMPTZ
|
||||
)
|
||||
""")
|
||||
|
||||
# Create notify function with proper key detection
|
||||
await db_conn.execute("""
|
||||
CREATE OR REPLACE FUNCTION config.notify_config_change()
|
||||
RETURNS trigger AS $$
|
||||
DECLARE
|
||||
key_value TEXT;
|
||||
BEGIN
|
||||
IF TG_TABLE_NAME = 'adapters' THEN
|
||||
key_value := COALESCE(NEW.name, OLD.name, '');
|
||||
ELSIF TG_TABLE_NAME = 'api_keys' THEN
|
||||
key_value := COALESCE(NEW.alias, OLD.alias, '');
|
||||
ELSE
|
||||
key_value := '';
|
||||
END IF;
|
||||
|
||||
PERFORM pg_notify('config_changed', TG_TABLE_NAME || ':' || key_value);
|
||||
RETURN COALESCE(NEW, OLD);
|
||||
END;
|
||||
$$ LANGUAGE plpgsql
|
||||
""")
|
||||
|
||||
# Create triggers if not exist
|
||||
await db_conn.execute("""
|
||||
DROP TRIGGER IF EXISTS adapters_notify ON config.adapters;
|
||||
CREATE TRIGGER adapters_notify
|
||||
AFTER INSERT OR UPDATE OR DELETE ON config.adapters
|
||||
FOR EACH ROW EXECUTE FUNCTION config.notify_config_change()
|
||||
""")
|
||||
await db_conn.execute("""
|
||||
DROP TRIGGER IF EXISTS api_keys_notify ON config.api_keys;
|
||||
CREATE TRIGGER api_keys_notify
|
||||
AFTER INSERT OR UPDATE OR DELETE ON config.api_keys
|
||||
FOR EACH ROW EXECUTE FUNCTION config.notify_config_change()
|
||||
""")
|
||||
|
||||
# Clean tables
|
||||
await db_conn.execute("DELETE FROM config.adapters")
|
||||
await db_conn.execute("DELETE FROM config.api_keys")
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def config_store(clean_config_schema: None) -> ConfigStore:
|
||||
"""Create a ConfigStore connected to the test database."""
|
||||
store = await ConfigStore.create(TEST_DB_DSN)
|
||||
yield store
|
||||
await store.close()
|
||||
|
||||
|
||||
class TestAdapterConfig:
|
||||
"""Tests for adapter configuration operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upsert_and_get(self, config_store: ConfigStore) -> None:
|
||||
"""Can insert and retrieve adapter config."""
|
||||
await config_store.upsert_adapter(
|
||||
name="test_adapter",
|
||||
enabled=True,
|
||||
cadence_s=120,
|
||||
settings={"key": "value"},
|
||||
)
|
||||
|
||||
adapter = await config_store.get_adapter("test_adapter")
|
||||
|
||||
assert adapter is not None
|
||||
assert adapter.name == "test_adapter"
|
||||
assert adapter.enabled is True
|
||||
assert adapter.cadence_s == 120
|
||||
assert adapter.settings == {"key": "value"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_nonexistent(self, config_store: ConfigStore) -> None:
|
||||
"""Getting nonexistent adapter returns None."""
|
||||
adapter = await config_store.get_adapter("does_not_exist")
|
||||
assert adapter is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_adapters(self, config_store: ConfigStore) -> None:
|
||||
"""Can list all adapters."""
|
||||
await config_store.upsert_adapter("adapter_a", True, 60, {})
|
||||
await config_store.upsert_adapter("adapter_b", False, 300, {"x": 1})
|
||||
|
||||
adapters = await config_store.list_adapters()
|
||||
|
||||
assert len(adapters) == 2
|
||||
names = [a.name for a in adapters]
|
||||
assert "adapter_a" in names
|
||||
assert "adapter_b" in names
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upsert_updates_existing(self, config_store: ConfigStore) -> None:
|
||||
"""Upsert updates existing adapter."""
|
||||
await config_store.upsert_adapter("updater", True, 60, {"v": 1})
|
||||
await config_store.upsert_adapter("updater", False, 120, {"v": 2})
|
||||
|
||||
adapter = await config_store.get_adapter("updater")
|
||||
|
||||
assert adapter is not None
|
||||
assert adapter.enabled is False
|
||||
assert adapter.cadence_s == 120
|
||||
assert adapter.settings == {"v": 2}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pause_unpause(self, config_store: ConfigStore) -> None:
|
||||
"""Can pause and unpause adapter."""
|
||||
await config_store.upsert_adapter("pausable", True, 60, {})
|
||||
|
||||
await config_store.pause_adapter("pausable")
|
||||
adapter = await config_store.get_adapter("pausable")
|
||||
assert adapter is not None
|
||||
assert adapter.is_paused is True
|
||||
|
||||
await config_store.unpause_adapter("pausable")
|
||||
adapter = await config_store.get_adapter("pausable")
|
||||
assert adapter is not None
|
||||
assert adapter.is_paused is False
|
||||
|
||||
|
||||
class TestApiKeys:
|
||||
"""Tests for API key operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_and_get_key(self, config_store: ConfigStore) -> None:
|
||||
"""Can store and retrieve encrypted API key."""
|
||||
await config_store.set_api_key("test_key", "super_secret_value")
|
||||
|
||||
value = await config_store.get_api_key("test_key")
|
||||
|
||||
assert value == "super_secret_value"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_nonexistent_key(self, config_store: ConfigStore) -> None:
|
||||
"""Getting nonexistent key returns None."""
|
||||
value = await config_store.get_api_key("does_not_exist")
|
||||
assert value is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_key_rotation(self, config_store: ConfigStore) -> None:
|
||||
"""Updating key sets rotated_at."""
|
||||
await config_store.set_api_key("rotate_me", "value1")
|
||||
await config_store.set_api_key("rotate_me", "value2")
|
||||
|
||||
value = await config_store.get_api_key("rotate_me")
|
||||
assert value == "value2"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_key(self, config_store: ConfigStore) -> None:
|
||||
"""Can delete API key."""
|
||||
await config_store.set_api_key("delete_me", "value")
|
||||
|
||||
deleted = await config_store.delete_api_key("delete_me")
|
||||
assert deleted is True
|
||||
|
||||
value = await config_store.get_api_key("delete_me")
|
||||
assert value is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_nonexistent(self, config_store: ConfigStore) -> None:
|
||||
"""Deleting nonexistent key returns False."""
|
||||
deleted = await config_store.delete_api_key("never_existed")
|
||||
assert deleted is False
|
||||
|
||||
|
||||
class TestNotifications:
|
||||
"""Tests for LISTEN/NOTIFY functionality."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_notify_on_adapter_change(self, config_store: ConfigStore) -> None:
|
||||
"""NOTIFY fires when adapter is changed."""
|
||||
notifications: list[tuple[str, str]] = []
|
||||
notification_received = asyncio.Event()
|
||||
|
||||
async def callback(table: str, key: str) -> None:
|
||||
notifications.append((table, key))
|
||||
notification_received.set()
|
||||
|
||||
# Start listener in background
|
||||
listen_task = asyncio.create_task(config_store.listen_for_changes(callback))
|
||||
|
||||
try:
|
||||
# Give listener time to subscribe
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# Trigger a change
|
||||
await config_store.upsert_adapter("notify_test", True, 60, {})
|
||||
|
||||
# Wait for notification (with timeout)
|
||||
try:
|
||||
await asyncio.wait_for(notification_received.wait(), timeout=5.0)
|
||||
except asyncio.TimeoutError:
|
||||
pytest.fail("Notification not received within timeout")
|
||||
|
||||
assert len(notifications) >= 1
|
||||
assert notifications[0][0] == "adapters"
|
||||
assert notifications[0][1] == "notify_test"
|
||||
|
||||
finally:
|
||||
listen_task.cancel()
|
||||
try:
|
||||
await listen_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_notify_on_api_key_change(self, config_store: ConfigStore) -> None:
|
||||
"""NOTIFY fires when API key is changed."""
|
||||
notifications: list[tuple[str, str]] = []
|
||||
notification_received = asyncio.Event()
|
||||
|
||||
async def callback(table: str, key: str) -> None:
|
||||
notifications.append((table, key))
|
||||
notification_received.set()
|
||||
|
||||
listen_task = asyncio.create_task(config_store.listen_for_changes(callback))
|
||||
|
||||
try:
|
||||
await asyncio.sleep(0.1)
|
||||
await config_store.set_api_key("notify_key", "secret")
|
||||
|
||||
try:
|
||||
await asyncio.wait_for(notification_received.wait(), timeout=5.0)
|
||||
except asyncio.TimeoutError:
|
||||
pytest.fail("Notification not received within timeout")
|
||||
|
||||
assert len(notifications) >= 1
|
||||
assert notifications[0][0] == "api_keys"
|
||||
assert notifications[0][1] == "notify_key"
|
||||
|
||||
finally:
|
||||
listen_task.cancel()
|
||||
try:
|
||||
await listen_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
|
||||
class TestListenerReconnect:
|
||||
"""Tests for listener reconnection on connection loss."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_listener_cancellation_propagates(
|
||||
self, config_store: ConfigStore
|
||||
) -> None:
|
||||
"""Cancellation cleanly stops the listener without reconnect loop."""
|
||||
async def callback(table: str, key: str) -> None:
|
||||
pass
|
||||
|
||||
listen_task = asyncio.create_task(config_store.listen_for_changes(callback))
|
||||
|
||||
# Give listener time to start
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# Cancel and verify it stops
|
||||
listen_task.cancel()
|
||||
try:
|
||||
await asyncio.wait_for(listen_task, timeout=2.0)
|
||||
except asyncio.CancelledError:
|
||||
pass # Expected
|
||||
except asyncio.TimeoutError:
|
||||
pytest.fail("Listener did not stop after cancellation")
|
||||
|
||||
assert listen_task.cancelled() or listen_task.done()
|
||||
"""Tests for database-backed configuration store.
|
||||
|
||||
These tests require a real Postgres database. Set CENTRAL_TEST_DB_DSN
|
||||
environment variable to override the default test database connection.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import asyncpg
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from central.config_store import ConfigStore
|
||||
from central.crypto import KEY_SIZE, clear_key_cache
|
||||
|
||||
# Test database DSN - uses central_test database with well-known test password.
|
||||
# Override via CENTRAL_TEST_DB_DSN env var if your test DB differs.
|
||||
TEST_DB_DSN = os.environ.get(
|
||||
"CENTRAL_TEST_DB_DSN",
|
||||
"postgresql://central_test:testpass@localhost/central_test",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def master_key_path(tmp_path_factory: pytest.TempPathFactory) -> Path:
|
||||
"""Create a master key file for the test session."""
|
||||
key = os.urandom(KEY_SIZE)
|
||||
key_path = tmp_path_factory.mktemp("keys") / "master.key"
|
||||
key_path.write_text(base64.b64encode(key).decode())
|
||||
return key_path
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_master_key(master_key_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Configure master key path for all tests."""
|
||||
clear_key_cache()
|
||||
monkeypatch.setenv("CENTRAL_DB_DSN", TEST_DB_DSN)
|
||||
monkeypatch.setenv("CENTRAL_MASTER_KEY_PATH", str(master_key_path))
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def db_conn() -> asyncpg.Connection:
|
||||
"""Get a direct database connection for setup/teardown."""
|
||||
conn = await asyncpg.connect(TEST_DB_DSN)
|
||||
yield conn
|
||||
await conn.close()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def clean_config_schema(db_conn: asyncpg.Connection) -> None:
|
||||
"""Ensure config schema exists and is clean before each test."""
|
||||
# Create schema if not exists
|
||||
await db_conn.execute("CREATE SCHEMA IF NOT EXISTS config")
|
||||
|
||||
# Create tables if not exist
|
||||
await db_conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS config.adapters (
|
||||
name TEXT PRIMARY KEY,
|
||||
enabled BOOLEAN NOT NULL DEFAULT true,
|
||||
cadence_s INTEGER NOT NULL,
|
||||
settings JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||
paused_at TIMESTAMPTZ,
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||
)
|
||||
""")
|
||||
await db_conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS config.api_keys (
|
||||
alias TEXT PRIMARY KEY,
|
||||
encrypted_value BYTEA NOT NULL,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
rotated_at TIMESTAMPTZ,
|
||||
last_used_at TIMESTAMPTZ
|
||||
)
|
||||
""")
|
||||
|
||||
# Create notify function with proper key detection
|
||||
await db_conn.execute("""
|
||||
CREATE OR REPLACE FUNCTION config.notify_config_change()
|
||||
RETURNS trigger AS $$
|
||||
DECLARE
|
||||
key_value TEXT;
|
||||
BEGIN
|
||||
IF TG_TABLE_NAME = 'adapters' THEN
|
||||
key_value := COALESCE(NEW.name, OLD.name, '');
|
||||
ELSIF TG_TABLE_NAME = 'api_keys' THEN
|
||||
key_value := COALESCE(NEW.alias, OLD.alias, '');
|
||||
ELSE
|
||||
key_value := '';
|
||||
END IF;
|
||||
|
||||
PERFORM pg_notify('config_changed', TG_TABLE_NAME || ':' || key_value);
|
||||
RETURN COALESCE(NEW, OLD);
|
||||
END;
|
||||
$$ LANGUAGE plpgsql
|
||||
""")
|
||||
|
||||
# Create triggers if not exist
|
||||
await db_conn.execute("""
|
||||
DROP TRIGGER IF EXISTS adapters_notify ON config.adapters;
|
||||
CREATE TRIGGER adapters_notify
|
||||
AFTER INSERT OR UPDATE OR DELETE ON config.adapters
|
||||
FOR EACH ROW EXECUTE FUNCTION config.notify_config_change()
|
||||
""")
|
||||
await db_conn.execute("""
|
||||
DROP TRIGGER IF EXISTS api_keys_notify ON config.api_keys;
|
||||
CREATE TRIGGER api_keys_notify
|
||||
AFTER INSERT OR UPDATE OR DELETE ON config.api_keys
|
||||
FOR EACH ROW EXECUTE FUNCTION config.notify_config_change()
|
||||
""")
|
||||
|
||||
# Clean tables
|
||||
await db_conn.execute("DELETE FROM config.adapters")
|
||||
await db_conn.execute("DELETE FROM config.api_keys")
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def config_store(clean_config_schema: None) -> ConfigStore:
|
||||
"""Create a ConfigStore connected to the test database."""
|
||||
store = await ConfigStore.create(TEST_DB_DSN)
|
||||
yield store
|
||||
await store.close()
|
||||
|
||||
|
||||
class TestAdapterConfig:
|
||||
"""Tests for adapter configuration operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upsert_and_get(self, config_store: ConfigStore) -> None:
|
||||
"""Can insert and retrieve adapter config."""
|
||||
await config_store.upsert_adapter(
|
||||
name="test_adapter",
|
||||
enabled=True,
|
||||
cadence_s=120,
|
||||
settings={"key": "value"},
|
||||
)
|
||||
|
||||
adapter = await config_store.get_adapter("test_adapter")
|
||||
|
||||
assert adapter is not None
|
||||
assert adapter.name == "test_adapter"
|
||||
assert adapter.enabled is True
|
||||
assert adapter.cadence_s == 120
|
||||
assert adapter.settings == {"key": "value"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_nonexistent(self, config_store: ConfigStore) -> None:
|
||||
"""Getting nonexistent adapter returns None."""
|
||||
adapter = await config_store.get_adapter("does_not_exist")
|
||||
assert adapter is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_adapters(self, config_store: ConfigStore) -> None:
|
||||
"""Can list all adapters."""
|
||||
await config_store.upsert_adapter("adapter_a", True, 60, {})
|
||||
await config_store.upsert_adapter("adapter_b", False, 300, {"x": 1})
|
||||
|
||||
adapters = await config_store.list_adapters()
|
||||
|
||||
assert len(adapters) == 2
|
||||
names = [a.name for a in adapters]
|
||||
assert "adapter_a" in names
|
||||
assert "adapter_b" in names
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upsert_updates_existing(self, config_store: ConfigStore) -> None:
|
||||
"""Upsert updates existing adapter."""
|
||||
await config_store.upsert_adapter("updater", True, 60, {"v": 1})
|
||||
await config_store.upsert_adapter("updater", False, 120, {"v": 2})
|
||||
|
||||
adapter = await config_store.get_adapter("updater")
|
||||
|
||||
assert adapter is not None
|
||||
assert adapter.enabled is False
|
||||
assert adapter.cadence_s == 120
|
||||
assert adapter.settings == {"v": 2}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pause_unpause(self, config_store: ConfigStore) -> None:
|
||||
"""Can pause and unpause adapter."""
|
||||
await config_store.upsert_adapter("pausable", True, 60, {})
|
||||
|
||||
await config_store.pause_adapter("pausable")
|
||||
adapter = await config_store.get_adapter("pausable")
|
||||
assert adapter is not None
|
||||
assert adapter.is_paused is True
|
||||
|
||||
await config_store.unpause_adapter("pausable")
|
||||
adapter = await config_store.get_adapter("pausable")
|
||||
assert adapter is not None
|
||||
assert adapter.is_paused is False
|
||||
|
||||
|
||||
class TestApiKeys:
|
||||
"""Tests for API key operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_and_get_key(self, config_store: ConfigStore) -> None:
|
||||
"""Can store and retrieve encrypted API key."""
|
||||
await config_store.set_api_key("test_key", "super_secret_value")
|
||||
|
||||
value = await config_store.get_api_key("test_key")
|
||||
|
||||
assert value == "super_secret_value"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_nonexistent_key(self, config_store: ConfigStore) -> None:
|
||||
"""Getting nonexistent key returns None."""
|
||||
value = await config_store.get_api_key("does_not_exist")
|
||||
assert value is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_key_rotation(self, config_store: ConfigStore) -> None:
|
||||
"""Updating key sets rotated_at."""
|
||||
await config_store.set_api_key("rotate_me", "value1")
|
||||
await config_store.set_api_key("rotate_me", "value2")
|
||||
|
||||
value = await config_store.get_api_key("rotate_me")
|
||||
assert value == "value2"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_key(self, config_store: ConfigStore) -> None:
|
||||
"""Can delete API key."""
|
||||
await config_store.set_api_key("delete_me", "value")
|
||||
|
||||
deleted = await config_store.delete_api_key("delete_me")
|
||||
assert deleted is True
|
||||
|
||||
value = await config_store.get_api_key("delete_me")
|
||||
assert value is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_nonexistent(self, config_store: ConfigStore) -> None:
|
||||
"""Deleting nonexistent key returns False."""
|
||||
deleted = await config_store.delete_api_key("never_existed")
|
||||
assert deleted is False
|
||||
|
||||
|
||||
class TestNotifications:
|
||||
"""Tests for LISTEN/NOTIFY functionality."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_notify_on_adapter_change(self, config_store: ConfigStore) -> None:
|
||||
"""NOTIFY fires when adapter is changed."""
|
||||
notifications: list[tuple[str, str]] = []
|
||||
notification_received = asyncio.Event()
|
||||
|
||||
async def callback(table: str, key: str) -> None:
|
||||
notifications.append((table, key))
|
||||
notification_received.set()
|
||||
|
||||
# Start listener in background
|
||||
listen_task = asyncio.create_task(config_store.listen_for_changes(callback))
|
||||
|
||||
try:
|
||||
# Give listener time to subscribe
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# Trigger a change
|
||||
await config_store.upsert_adapter("notify_test", True, 60, {})
|
||||
|
||||
# Wait for notification (with timeout)
|
||||
try:
|
||||
await asyncio.wait_for(notification_received.wait(), timeout=5.0)
|
||||
except asyncio.TimeoutError:
|
||||
pytest.fail("Notification not received within timeout")
|
||||
|
||||
assert len(notifications) >= 1
|
||||
assert notifications[0][0] == "adapters"
|
||||
assert notifications[0][1] == "notify_test"
|
||||
|
||||
finally:
|
||||
listen_task.cancel()
|
||||
try:
|
||||
await listen_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_notify_on_api_key_change(self, config_store: ConfigStore) -> None:
|
||||
"""NOTIFY fires when API key is changed."""
|
||||
notifications: list[tuple[str, str]] = []
|
||||
notification_received = asyncio.Event()
|
||||
|
||||
async def callback(table: str, key: str) -> None:
|
||||
notifications.append((table, key))
|
||||
notification_received.set()
|
||||
|
||||
listen_task = asyncio.create_task(config_store.listen_for_changes(callback))
|
||||
|
||||
try:
|
||||
await asyncio.sleep(0.1)
|
||||
await config_store.set_api_key("notify_key", "secret")
|
||||
|
||||
try:
|
||||
await asyncio.wait_for(notification_received.wait(), timeout=5.0)
|
||||
except asyncio.TimeoutError:
|
||||
pytest.fail("Notification not received within timeout")
|
||||
|
||||
assert len(notifications) >= 1
|
||||
assert notifications[0][0] == "api_keys"
|
||||
assert notifications[0][1] == "notify_key"
|
||||
|
||||
finally:
|
||||
listen_task.cancel()
|
||||
try:
|
||||
await listen_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
|
||||
class TestListenerReconnect:
|
||||
"""Tests for listener reconnection on connection loss."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_listener_cancellation_propagates(
|
||||
self, config_store: ConfigStore
|
||||
) -> None:
|
||||
"""Cancellation cleanly stops the listener without reconnect loop."""
|
||||
async def callback(table: str, key: str) -> None:
|
||||
pass
|
||||
|
||||
listen_task = asyncio.create_task(config_store.listen_for_changes(callback))
|
||||
|
||||
# Give listener time to start
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# Cancel and verify it stops
|
||||
listen_task.cancel()
|
||||
try:
|
||||
await asyncio.wait_for(listen_task, timeout=2.0)
|
||||
except asyncio.CancelledError:
|
||||
pass # Expected
|
||||
except asyncio.TimeoutError:
|
||||
pytest.fail("Listener did not stop after cancellation")
|
||||
|
||||
assert listen_task.cancelled() or listen_task.done()
|
||||
|
|
|
|||
|
|
@ -1,175 +1,175 @@
|
|||
"""Tests for cryptographic primitives."""
|
||||
|
||||
import base64
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from central.crypto import (
|
||||
KEY_SIZE,
|
||||
DecryptionError,
|
||||
KeyLoadError,
|
||||
clear_key_cache,
|
||||
decrypt,
|
||||
encrypt,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def master_key(tmp_path: Path) -> Path:
|
||||
"""Create a valid master key file."""
|
||||
key = os.urandom(KEY_SIZE)
|
||||
key_path = tmp_path / "master.key"
|
||||
key_path.write_text(base64.b64encode(key).decode())
|
||||
clear_key_cache()
|
||||
return key_path
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def wrong_key(tmp_path: Path) -> Path:
|
||||
"""Create a different master key file."""
|
||||
key = os.urandom(KEY_SIZE)
|
||||
key_path = tmp_path / "wrong.key"
|
||||
key_path.write_text(base64.b64encode(key).decode())
|
||||
return key_path
|
||||
|
||||
|
||||
class TestEncryptDecrypt:
|
||||
"""Test encrypt/decrypt round-trip."""
|
||||
|
||||
def test_round_trip(self, master_key: Path) -> None:
|
||||
"""Encrypting then decrypting returns original plaintext."""
|
||||
plaintext = b"Hello, Central!"
|
||||
|
||||
ciphertext = encrypt(plaintext, key_path=master_key)
|
||||
decrypted = decrypt(ciphertext, key_path=master_key)
|
||||
|
||||
assert decrypted == plaintext
|
||||
|
||||
def test_round_trip_empty(self, master_key: Path) -> None:
|
||||
"""Empty plaintext encrypts and decrypts correctly."""
|
||||
plaintext = b""
|
||||
|
||||
ciphertext = encrypt(plaintext, key_path=master_key)
|
||||
decrypted = decrypt(ciphertext, key_path=master_key)
|
||||
|
||||
assert decrypted == plaintext
|
||||
|
||||
def test_round_trip_large(self, master_key: Path) -> None:
|
||||
"""Large plaintext encrypts and decrypts correctly."""
|
||||
plaintext = os.urandom(1024 * 1024) # 1MB
|
||||
|
||||
ciphertext = encrypt(plaintext, key_path=master_key)
|
||||
decrypted = decrypt(ciphertext, key_path=master_key)
|
||||
|
||||
assert decrypted == plaintext
|
||||
|
||||
def test_ciphertext_different_each_time(self, master_key: Path) -> None:
|
||||
"""Same plaintext produces different ciphertext (random nonce)."""
|
||||
plaintext = b"test"
|
||||
|
||||
ct1 = encrypt(plaintext, key_path=master_key)
|
||||
ct2 = encrypt(plaintext, key_path=master_key)
|
||||
|
||||
assert ct1 != ct2
|
||||
# But both decrypt to same plaintext
|
||||
assert decrypt(ct1, key_path=master_key) == plaintext
|
||||
assert decrypt(ct2, key_path=master_key) == plaintext
|
||||
|
||||
|
||||
class TestDecryptionFailures:
|
||||
"""Test AEAD authentication catches tampering."""
|
||||
|
||||
def test_wrong_key_fails(self, master_key: Path, wrong_key: Path) -> None:
|
||||
"""Decryption with wrong key raises DecryptionError."""
|
||||
plaintext = b"secret"
|
||||
ciphertext = encrypt(plaintext, key_path=master_key)
|
||||
|
||||
clear_key_cache() # Clear cache so wrong_key is loaded
|
||||
with pytest.raises(DecryptionError):
|
||||
decrypt(ciphertext, key_path=wrong_key)
|
||||
|
||||
def test_tampered_ciphertext_fails(self, master_key: Path) -> None:
|
||||
"""Modified ciphertext is detected and rejected."""
|
||||
plaintext = b"secret"
|
||||
ciphertext = encrypt(plaintext, key_path=master_key)
|
||||
|
||||
# Flip a bit in the ciphertext (after nonce, before tag)
|
||||
tampered = bytearray(ciphertext)
|
||||
tampered[15] ^= 0x01 # Flip one bit
|
||||
tampered = bytes(tampered)
|
||||
|
||||
with pytest.raises(DecryptionError):
|
||||
decrypt(tampered, key_path=master_key)
|
||||
|
||||
def test_tampered_tag_fails(self, master_key: Path) -> None:
|
||||
"""Modified authentication tag is detected and rejected."""
|
||||
plaintext = b"secret"
|
||||
ciphertext = encrypt(plaintext, key_path=master_key)
|
||||
|
||||
# Flip a bit in the last byte (part of the tag)
|
||||
tampered = bytearray(ciphertext)
|
||||
tampered[-1] ^= 0x01
|
||||
tampered = bytes(tampered)
|
||||
|
||||
with pytest.raises(DecryptionError):
|
||||
decrypt(tampered, key_path=master_key)
|
||||
|
||||
def test_truncated_ciphertext_fails(self, master_key: Path) -> None:
|
||||
"""Truncated ciphertext is rejected."""
|
||||
ciphertext = b"tooshort"
|
||||
|
||||
with pytest.raises(DecryptionError, match="too short"):
|
||||
decrypt(ciphertext, key_path=master_key)
|
||||
|
||||
|
||||
class TestKeyLoading:
|
||||
"""Test master key loading."""
|
||||
|
||||
def test_missing_key_file(self, tmp_path: Path) -> None:
|
||||
"""Missing key file raises KeyLoadError."""
|
||||
clear_key_cache()
|
||||
missing = tmp_path / "nonexistent.key"
|
||||
|
||||
with pytest.raises(KeyLoadError, match="not found"):
|
||||
encrypt(b"test", key_path=missing)
|
||||
|
||||
def test_invalid_key_size(self, tmp_path: Path) -> None:
|
||||
"""Key file with wrong size raises KeyLoadError."""
|
||||
clear_key_cache()
|
||||
bad_key = tmp_path / "bad.key"
|
||||
bad_key.write_text(base64.b64encode(b"tooshort").decode())
|
||||
|
||||
with pytest.raises(KeyLoadError, match="Invalid master key size"):
|
||||
encrypt(b"test", key_path=bad_key)
|
||||
|
||||
def test_invalid_base64(self, tmp_path: Path) -> None:
|
||||
"""Invalid base64 in key file raises KeyLoadError."""
|
||||
clear_key_cache()
|
||||
bad_key = tmp_path / "bad.key"
|
||||
bad_key.write_text("not valid base64!!!")
|
||||
|
||||
with pytest.raises(KeyLoadError):
|
||||
encrypt(b"test", key_path=bad_key)
|
||||
|
||||
def test_key_cached(self, master_key: Path) -> None:
|
||||
"""Key is cached after first load."""
|
||||
# First encryption loads the key
|
||||
encrypt(b"test1", key_path=master_key)
|
||||
|
||||
# Delete the file
|
||||
master_key.unlink()
|
||||
|
||||
# Second encryption should still work (cached)
|
||||
ciphertext = encrypt(b"test2", key_path=master_key)
|
||||
assert len(ciphertext) > 0
|
||||
|
||||
def test_cache_clear(self, master_key: Path) -> None:
|
||||
"""clear_key_cache forces reload."""
|
||||
encrypt(b"test", key_path=master_key)
|
||||
master_key.unlink()
|
||||
clear_key_cache()
|
||||
|
||||
with pytest.raises(KeyLoadError, match="not found"):
|
||||
encrypt(b"test", key_path=master_key)
|
||||
"""Tests for cryptographic primitives."""
|
||||
|
||||
import base64
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from central.crypto import (
|
||||
KEY_SIZE,
|
||||
DecryptionError,
|
||||
KeyLoadError,
|
||||
clear_key_cache,
|
||||
decrypt,
|
||||
encrypt,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def master_key(tmp_path: Path) -> Path:
|
||||
"""Create a valid master key file."""
|
||||
key = os.urandom(KEY_SIZE)
|
||||
key_path = tmp_path / "master.key"
|
||||
key_path.write_text(base64.b64encode(key).decode())
|
||||
clear_key_cache()
|
||||
return key_path
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def wrong_key(tmp_path: Path) -> Path:
|
||||
"""Create a different master key file."""
|
||||
key = os.urandom(KEY_SIZE)
|
||||
key_path = tmp_path / "wrong.key"
|
||||
key_path.write_text(base64.b64encode(key).decode())
|
||||
return key_path
|
||||
|
||||
|
||||
class TestEncryptDecrypt:
|
||||
"""Test encrypt/decrypt round-trip."""
|
||||
|
||||
def test_round_trip(self, master_key: Path) -> None:
|
||||
"""Encrypting then decrypting returns original plaintext."""
|
||||
plaintext = b"Hello, Central!"
|
||||
|
||||
ciphertext = encrypt(plaintext, key_path=master_key)
|
||||
decrypted = decrypt(ciphertext, key_path=master_key)
|
||||
|
||||
assert decrypted == plaintext
|
||||
|
||||
def test_round_trip_empty(self, master_key: Path) -> None:
|
||||
"""Empty plaintext encrypts and decrypts correctly."""
|
||||
plaintext = b""
|
||||
|
||||
ciphertext = encrypt(plaintext, key_path=master_key)
|
||||
decrypted = decrypt(ciphertext, key_path=master_key)
|
||||
|
||||
assert decrypted == plaintext
|
||||
|
||||
def test_round_trip_large(self, master_key: Path) -> None:
|
||||
"""Large plaintext encrypts and decrypts correctly."""
|
||||
plaintext = os.urandom(1024 * 1024) # 1MB
|
||||
|
||||
ciphertext = encrypt(plaintext, key_path=master_key)
|
||||
decrypted = decrypt(ciphertext, key_path=master_key)
|
||||
|
||||
assert decrypted == plaintext
|
||||
|
||||
def test_ciphertext_different_each_time(self, master_key: Path) -> None:
|
||||
"""Same plaintext produces different ciphertext (random nonce)."""
|
||||
plaintext = b"test"
|
||||
|
||||
ct1 = encrypt(plaintext, key_path=master_key)
|
||||
ct2 = encrypt(plaintext, key_path=master_key)
|
||||
|
||||
assert ct1 != ct2
|
||||
# But both decrypt to same plaintext
|
||||
assert decrypt(ct1, key_path=master_key) == plaintext
|
||||
assert decrypt(ct2, key_path=master_key) == plaintext
|
||||
|
||||
|
||||
class TestDecryptionFailures:
|
||||
"""Test AEAD authentication catches tampering."""
|
||||
|
||||
def test_wrong_key_fails(self, master_key: Path, wrong_key: Path) -> None:
|
||||
"""Decryption with wrong key raises DecryptionError."""
|
||||
plaintext = b"secret"
|
||||
ciphertext = encrypt(plaintext, key_path=master_key)
|
||||
|
||||
clear_key_cache() # Clear cache so wrong_key is loaded
|
||||
with pytest.raises(DecryptionError):
|
||||
decrypt(ciphertext, key_path=wrong_key)
|
||||
|
||||
def test_tampered_ciphertext_fails(self, master_key: Path) -> None:
|
||||
"""Modified ciphertext is detected and rejected."""
|
||||
plaintext = b"secret"
|
||||
ciphertext = encrypt(plaintext, key_path=master_key)
|
||||
|
||||
# Flip a bit in the ciphertext (after nonce, before tag)
|
||||
tampered = bytearray(ciphertext)
|
||||
tampered[15] ^= 0x01 # Flip one bit
|
||||
tampered = bytes(tampered)
|
||||
|
||||
with pytest.raises(DecryptionError):
|
||||
decrypt(tampered, key_path=master_key)
|
||||
|
||||
def test_tampered_tag_fails(self, master_key: Path) -> None:
|
||||
"""Modified authentication tag is detected and rejected."""
|
||||
plaintext = b"secret"
|
||||
ciphertext = encrypt(plaintext, key_path=master_key)
|
||||
|
||||
# Flip a bit in the last byte (part of the tag)
|
||||
tampered = bytearray(ciphertext)
|
||||
tampered[-1] ^= 0x01
|
||||
tampered = bytes(tampered)
|
||||
|
||||
with pytest.raises(DecryptionError):
|
||||
decrypt(tampered, key_path=master_key)
|
||||
|
||||
def test_truncated_ciphertext_fails(self, master_key: Path) -> None:
|
||||
"""Truncated ciphertext is rejected."""
|
||||
ciphertext = b"tooshort"
|
||||
|
||||
with pytest.raises(DecryptionError, match="too short"):
|
||||
decrypt(ciphertext, key_path=master_key)
|
||||
|
||||
|
||||
class TestKeyLoading:
|
||||
"""Test master key loading."""
|
||||
|
||||
def test_missing_key_file(self, tmp_path: Path) -> None:
|
||||
"""Missing key file raises KeyLoadError."""
|
||||
clear_key_cache()
|
||||
missing = tmp_path / "nonexistent.key"
|
||||
|
||||
with pytest.raises(KeyLoadError, match="not found"):
|
||||
encrypt(b"test", key_path=missing)
|
||||
|
||||
def test_invalid_key_size(self, tmp_path: Path) -> None:
|
||||
"""Key file with wrong size raises KeyLoadError."""
|
||||
clear_key_cache()
|
||||
bad_key = tmp_path / "bad.key"
|
||||
bad_key.write_text(base64.b64encode(b"tooshort").decode())
|
||||
|
||||
with pytest.raises(KeyLoadError, match="Invalid master key size"):
|
||||
encrypt(b"test", key_path=bad_key)
|
||||
|
||||
def test_invalid_base64(self, tmp_path: Path) -> None:
|
||||
"""Invalid base64 in key file raises KeyLoadError."""
|
||||
clear_key_cache()
|
||||
bad_key = tmp_path / "bad.key"
|
||||
bad_key.write_text("not valid base64!!!")
|
||||
|
||||
with pytest.raises(KeyLoadError):
|
||||
encrypt(b"test", key_path=bad_key)
|
||||
|
||||
def test_key_cached(self, master_key: Path) -> None:
|
||||
"""Key is cached after first load."""
|
||||
# First encryption loads the key
|
||||
encrypt(b"test1", key_path=master_key)
|
||||
|
||||
# Delete the file
|
||||
master_key.unlink()
|
||||
|
||||
# Second encryption should still work (cached)
|
||||
ciphertext = encrypt(b"test2", key_path=master_key)
|
||||
assert len(ciphertext) > 0
|
||||
|
||||
def test_cache_clear(self, master_key: Path) -> None:
|
||||
"""clear_key_cache forces reload."""
|
||||
encrypt(b"test", key_path=master_key)
|
||||
master_key.unlink()
|
||||
clear_key_cache()
|
||||
|
||||
with pytest.raises(KeyLoadError, match="not found"):
|
||||
encrypt(b"test", key_path=master_key)
|
||||
|
|
|
|||
|
|
@ -1,161 +1,161 @@
|
|||
"""Smoke tests for Central models and CloudEvents wire format."""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import pytest
|
||||
|
||||
from central.models import Event, Geo, subject_for_event
|
||||
from central.config import NWSAdapterConfig, CloudEventsConfig, NATSConfig, PostgresConfig, Config
|
||||
from central.cloudevents_wire import wrap_event
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_geo() -> Geo:
|
||||
"""Sample Geo object for testing."""
|
||||
return Geo(
|
||||
centroid=(-116.2, 43.6),
|
||||
bbox=(-116.5, 43.4, -115.9, 43.8),
|
||||
regions=["US-ID-Ada", "US-ID-Canyon"],
|
||||
primary_region="US-ID-Ada",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_event(sample_geo: Geo) -> Event:
|
||||
"""Sample Event object for testing."""
|
||||
return Event(
|
||||
id="urn:central:nws:alert:KBOI-202401151200-SVR",
|
||||
source="central/adapters/nws",
|
||||
category="wx.alert.severe_thunderstorm_warning",
|
||||
time=datetime(2024, 1, 15, 12, 0, 0, tzinfo=timezone.utc),
|
||||
expires=datetime(2024, 1, 15, 13, 0, 0, tzinfo=timezone.utc),
|
||||
severity=3,
|
||||
geo=sample_geo,
|
||||
data={"headline": "Severe Thunderstorm Warning", "urgency": "Immediate"},
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_config() -> Config:
|
||||
"""Sample Config object for testing."""
|
||||
return Config(
|
||||
adapters={
|
||||
"nws": NWSAdapterConfig(
|
||||
enabled=True,
|
||||
cadence_s=60,
|
||||
states=["ID", "MT"],
|
||||
contact_email="test@example.com",
|
||||
)
|
||||
},
|
||||
cloudevents=CloudEventsConfig(
|
||||
type_prefix="central",
|
||||
source="central.local",
|
||||
schema_version="1.0",
|
||||
),
|
||||
nats=NATSConfig(url="nats://localhost:4222"),
|
||||
postgres=PostgresConfig(dsn="postgresql://user:pass@localhost/db"),
|
||||
)
|
||||
|
||||
|
||||
class TestSubjectForEvent:
|
||||
"""Tests for subject_for_event helper."""
|
||||
|
||||
def test_county_subject(self, sample_event: Event) -> None:
|
||||
"""County codes produce county subject."""
|
||||
subject = subject_for_event(sample_event)
|
||||
assert subject == "central.wx.alert.us.id.county.ada"
|
||||
|
||||
def test_zone_subject(self, sample_geo: Geo) -> None:
|
||||
"""Zone codes produce zone subject."""
|
||||
geo = Geo(
|
||||
centroid=sample_geo.centroid,
|
||||
bbox=sample_geo.bbox,
|
||||
regions=["US-ID-Z033"],
|
||||
primary_region="US-ID-Z033",
|
||||
)
|
||||
event = Event(
|
||||
id="test-zone",
|
||||
source="test",
|
||||
category="wx.alert.winter_storm_warning",
|
||||
time=datetime(2024, 1, 15, 12, 0, 0, tzinfo=timezone.utc),
|
||||
geo=geo,
|
||||
data={},
|
||||
)
|
||||
subject = subject_for_event(event)
|
||||
assert subject == "central.wx.alert.us.id.zone.z033"
|
||||
|
||||
def test_unknown_subject(self, sample_event: Event) -> None:
|
||||
"""Missing primary_region produces unknown subject."""
|
||||
geo = Geo(regions=[], primary_region=None)
|
||||
event = Event(
|
||||
id="test-unknown",
|
||||
source="test",
|
||||
category="wx.alert.test",
|
||||
time=datetime(2024, 1, 15, 12, 0, 0, tzinfo=timezone.utc),
|
||||
geo=geo,
|
||||
data={},
|
||||
)
|
||||
subject = subject_for_event(event)
|
||||
assert subject == "central.wx.alert.us.unknown"
|
||||
|
||||
def test_custom_prefix(self, sample_event: Event) -> None:
|
||||
"""Custom prefix is used in subject."""
|
||||
subject = subject_for_event(sample_event, prefix="myapp.events")
|
||||
assert subject == "myapp.events.alert.us.id.county.ada"
|
||||
|
||||
|
||||
class TestCloudEventsWire:
|
||||
"""Tests for CloudEvents wire format."""
|
||||
|
||||
def test_required_fields_present(
|
||||
self, sample_event: Event, sample_config: Config
|
||||
) -> None:
|
||||
"""Required CloudEvents fields are present."""
|
||||
envelope, msg_id = wrap_event(sample_event, sample_config)
|
||||
|
||||
assert msg_id == sample_event.id
|
||||
assert envelope["id"] == sample_event.id
|
||||
assert envelope["source"] == sample_config.cloudevents.source
|
||||
assert envelope["type"] == "central.wx.alert.severe_thunderstorm_warning.v1"
|
||||
assert envelope["specversion"] == "1.0"
|
||||
assert "time" in envelope
|
||||
assert envelope["datacontenttype"] == "application/json"
|
||||
assert "data" in envelope
|
||||
|
||||
def test_extension_attributes_lowercase(
|
||||
self, sample_event: Event, sample_config: Config
|
||||
) -> None:
|
||||
"""Extension attributes are lowercase with no underscores."""
|
||||
envelope, _ = wrap_event(sample_event, sample_config)
|
||||
|
||||
# Check that extension attributes exist and are lowercase
|
||||
assert envelope["centralschemaversion"] == "1.0"
|
||||
assert envelope["centralcategory"] == "wx.alert.severe_thunderstorm_warning"
|
||||
assert envelope["centralseverity"] == 3
|
||||
|
||||
# Verify no uppercase or underscores in extension names
|
||||
for key in ["centralschemaversion", "centralcategory", "centralseverity"]:
|
||||
assert key.islower()
|
||||
assert "_" not in key
|
||||
|
||||
def test_severity_none_omits_centralseverity(
|
||||
self, sample_geo: Geo, sample_config: Config
|
||||
) -> None:
|
||||
"""When severity is None, centralseverity is omitted entirely."""
|
||||
event = Event(
|
||||
id="test-no-severity",
|
||||
source="test",
|
||||
category="wx.alert.test",
|
||||
time=datetime(2024, 1, 15, 12, 0, 0, tzinfo=timezone.utc),
|
||||
severity=None, # Explicitly None
|
||||
geo=sample_geo,
|
||||
data={},
|
||||
)
|
||||
|
||||
envelope, _ = wrap_event(event, sample_config)
|
||||
|
||||
# centralseverity should not be present at all
|
||||
assert "centralseverity" not in envelope
|
||||
# Other extensions should still be present
|
||||
assert "centralschemaversion" in envelope
|
||||
assert "centralcategory" in envelope
|
||||
"""Smoke tests for Central models and CloudEvents wire format."""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import pytest
|
||||
|
||||
from central.models import Event, Geo, subject_for_event
|
||||
from central.config import NWSAdapterConfig, CloudEventsConfig, NATSConfig, PostgresConfig, Config
|
||||
from central.cloudevents_wire import wrap_event
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_geo() -> Geo:
|
||||
"""Sample Geo object for testing."""
|
||||
return Geo(
|
||||
centroid=(-116.2, 43.6),
|
||||
bbox=(-116.5, 43.4, -115.9, 43.8),
|
||||
regions=["US-ID-Ada", "US-ID-Canyon"],
|
||||
primary_region="US-ID-Ada",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_event(sample_geo: Geo) -> Event:
|
||||
"""Sample Event object for testing."""
|
||||
return Event(
|
||||
id="urn:central:nws:alert:KBOI-202401151200-SVR",
|
||||
source="central/adapters/nws",
|
||||
category="wx.alert.severe_thunderstorm_warning",
|
||||
time=datetime(2024, 1, 15, 12, 0, 0, tzinfo=timezone.utc),
|
||||
expires=datetime(2024, 1, 15, 13, 0, 0, tzinfo=timezone.utc),
|
||||
severity=3,
|
||||
geo=sample_geo,
|
||||
data={"headline": "Severe Thunderstorm Warning", "urgency": "Immediate"},
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_config() -> Config:
|
||||
"""Sample Config object for testing."""
|
||||
return Config(
|
||||
adapters={
|
||||
"nws": NWSAdapterConfig(
|
||||
enabled=True,
|
||||
cadence_s=60,
|
||||
states=["ID", "MT"],
|
||||
contact_email="test@example.com",
|
||||
)
|
||||
},
|
||||
cloudevents=CloudEventsConfig(
|
||||
type_prefix="central",
|
||||
source="central.local",
|
||||
schema_version="1.0",
|
||||
),
|
||||
nats=NATSConfig(url="nats://localhost:4222"),
|
||||
postgres=PostgresConfig(dsn="postgresql://user:pass@localhost/db"),
|
||||
)
|
||||
|
||||
|
||||
class TestSubjectForEvent:
|
||||
"""Tests for subject_for_event helper."""
|
||||
|
||||
def test_county_subject(self, sample_event: Event) -> None:
|
||||
"""County codes produce county subject."""
|
||||
subject = subject_for_event(sample_event)
|
||||
assert subject == "central.wx.alert.us.id.county.ada"
|
||||
|
||||
def test_zone_subject(self, sample_geo: Geo) -> None:
|
||||
"""Zone codes produce zone subject."""
|
||||
geo = Geo(
|
||||
centroid=sample_geo.centroid,
|
||||
bbox=sample_geo.bbox,
|
||||
regions=["US-ID-Z033"],
|
||||
primary_region="US-ID-Z033",
|
||||
)
|
||||
event = Event(
|
||||
id="test-zone",
|
||||
source="test",
|
||||
category="wx.alert.winter_storm_warning",
|
||||
time=datetime(2024, 1, 15, 12, 0, 0, tzinfo=timezone.utc),
|
||||
geo=geo,
|
||||
data={},
|
||||
)
|
||||
subject = subject_for_event(event)
|
||||
assert subject == "central.wx.alert.us.id.zone.z033"
|
||||
|
||||
def test_unknown_subject(self, sample_event: Event) -> None:
|
||||
"""Missing primary_region produces unknown subject."""
|
||||
geo = Geo(regions=[], primary_region=None)
|
||||
event = Event(
|
||||
id="test-unknown",
|
||||
source="test",
|
||||
category="wx.alert.test",
|
||||
time=datetime(2024, 1, 15, 12, 0, 0, tzinfo=timezone.utc),
|
||||
geo=geo,
|
||||
data={},
|
||||
)
|
||||
subject = subject_for_event(event)
|
||||
assert subject == "central.wx.alert.us.unknown"
|
||||
|
||||
def test_custom_prefix(self, sample_event: Event) -> None:
|
||||
"""Custom prefix is used in subject."""
|
||||
subject = subject_for_event(sample_event, prefix="myapp.events")
|
||||
assert subject == "myapp.events.alert.us.id.county.ada"
|
||||
|
||||
|
||||
class TestCloudEventsWire:
|
||||
"""Tests for CloudEvents wire format."""
|
||||
|
||||
def test_required_fields_present(
|
||||
self, sample_event: Event, sample_config: Config
|
||||
) -> None:
|
||||
"""Required CloudEvents fields are present."""
|
||||
envelope, msg_id = wrap_event(sample_event, sample_config)
|
||||
|
||||
assert msg_id == sample_event.id
|
||||
assert envelope["id"] == sample_event.id
|
||||
assert envelope["source"] == sample_config.cloudevents.source
|
||||
assert envelope["type"] == "central.wx.alert.severe_thunderstorm_warning.v1"
|
||||
assert envelope["specversion"] == "1.0"
|
||||
assert "time" in envelope
|
||||
assert envelope["datacontenttype"] == "application/json"
|
||||
assert "data" in envelope
|
||||
|
||||
def test_extension_attributes_lowercase(
|
||||
self, sample_event: Event, sample_config: Config
|
||||
) -> None:
|
||||
"""Extension attributes are lowercase with no underscores."""
|
||||
envelope, _ = wrap_event(sample_event, sample_config)
|
||||
|
||||
# Check that extension attributes exist and are lowercase
|
||||
assert envelope["centralschemaversion"] == "1.0"
|
||||
assert envelope["centralcategory"] == "wx.alert.severe_thunderstorm_warning"
|
||||
assert envelope["centralseverity"] == 3
|
||||
|
||||
# Verify no uppercase or underscores in extension names
|
||||
for key in ["centralschemaversion", "centralcategory", "centralseverity"]:
|
||||
assert key.islower()
|
||||
assert "_" not in key
|
||||
|
||||
def test_severity_none_omits_centralseverity(
|
||||
self, sample_geo: Geo, sample_config: Config
|
||||
) -> None:
|
||||
"""When severity is None, centralseverity is omitted entirely."""
|
||||
event = Event(
|
||||
id="test-no-severity",
|
||||
source="test",
|
||||
category="wx.alert.test",
|
||||
time=datetime(2024, 1, 15, 12, 0, 0, tzinfo=timezone.utc),
|
||||
severity=None, # Explicitly None
|
||||
geo=sample_geo,
|
||||
data={},
|
||||
)
|
||||
|
||||
envelope, _ = wrap_event(event, sample_config)
|
||||
|
||||
# centralseverity should not be present at all
|
||||
assert "centralseverity" not in envelope
|
||||
# Other extensions should still be present
|
||||
assert "centralschemaversion" in envelope
|
||||
assert "centralcategory" in envelope
|
||||
|
|
|
|||
|
|
@ -1,357 +1,357 @@
|
|||
"""Tests for supervisor hot-reload and rate-limiting behavior."""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import os
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import asyncpg
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from central.config_models import AdapterConfig
|
||||
from central.config_source import DbConfigSource
|
||||
from central.config_store import ConfigStore
|
||||
from central.crypto import KEY_SIZE, clear_key_cache
|
||||
|
||||
# Test database DSN
|
||||
TEST_DB_DSN = os.environ.get(
|
||||
"CENTRAL_TEST_DB_DSN",
|
||||
"postgresql://central_test:testpass@localhost/central_test",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def master_key_path(tmp_path_factory: pytest.TempPathFactory) -> Path:
|
||||
"""Create a master key file for the test session."""
|
||||
key = os.urandom(KEY_SIZE)
|
||||
key_path = tmp_path_factory.mktemp("keys") / "master.key"
|
||||
key_path.write_text(base64.b64encode(key).decode())
|
||||
return key_path
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_master_key(master_key_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Configure master key path for all tests."""
|
||||
clear_key_cache()
|
||||
monkeypatch.setenv("CENTRAL_DB_DSN", TEST_DB_DSN)
|
||||
monkeypatch.setenv("CENTRAL_MASTER_KEY_PATH", str(master_key_path))
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def db_conn() -> asyncpg.Connection:
|
||||
"""Get a direct database connection for setup/teardown."""
|
||||
conn = await asyncpg.connect(TEST_DB_DSN)
|
||||
yield conn
|
||||
await conn.close()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def clean_config_schema(db_conn: asyncpg.Connection) -> None:
|
||||
"""Ensure config schema exists and is clean before each test."""
|
||||
await db_conn.execute("CREATE SCHEMA IF NOT EXISTS config")
|
||||
await db_conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS config.adapters (
|
||||
name TEXT PRIMARY KEY,
|
||||
enabled BOOLEAN NOT NULL DEFAULT true,
|
||||
cadence_s INTEGER NOT NULL,
|
||||
settings JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||
paused_at TIMESTAMPTZ,
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||
)
|
||||
""")
|
||||
# Create notify trigger
|
||||
await db_conn.execute("""
|
||||
CREATE OR REPLACE FUNCTION config.notify_config_change()
|
||||
RETURNS trigger AS $$
|
||||
DECLARE
|
||||
key_value TEXT;
|
||||
BEGIN
|
||||
IF TG_TABLE_NAME = 'adapters' THEN
|
||||
key_value := COALESCE(NEW.name, OLD.name, '');
|
||||
ELSE
|
||||
key_value := '';
|
||||
END IF;
|
||||
PERFORM pg_notify('config_changed', TG_TABLE_NAME || ':' || key_value);
|
||||
RETURN COALESCE(NEW, OLD);
|
||||
END;
|
||||
$$ LANGUAGE plpgsql
|
||||
""")
|
||||
await db_conn.execute("""
|
||||
DROP TRIGGER IF EXISTS adapters_notify ON config.adapters;
|
||||
CREATE TRIGGER adapters_notify
|
||||
AFTER INSERT OR UPDATE OR DELETE ON config.adapters
|
||||
FOR EACH ROW EXECUTE FUNCTION config.notify_config_change()
|
||||
""")
|
||||
await db_conn.execute("DELETE FROM config.adapters")
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def config_store(clean_config_schema: None) -> ConfigStore:
|
||||
"""Create a ConfigStore connected to the test database."""
|
||||
store = await ConfigStore.create(TEST_DB_DSN)
|
||||
yield store
|
||||
await store.close()
|
||||
|
||||
|
||||
class TestDbConfigSourceNotifications:
|
||||
"""Tests for DbConfigSource NOTIFY integration."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_watch_receives_notifications(
|
||||
self,
|
||||
config_store: ConfigStore,
|
||||
db_conn: asyncpg.Connection,
|
||||
) -> None:
|
||||
"""watch_for_changes receives NOTIFY when adapter changes."""
|
||||
source = DbConfigSource(config_store)
|
||||
notifications: list[tuple[str, str]] = []
|
||||
notification_received = asyncio.Event()
|
||||
|
||||
async def callback(table: str, key: str) -> None:
|
||||
notifications.append((table, key))
|
||||
notification_received.set()
|
||||
|
||||
# Start watching in background
|
||||
watch_task = asyncio.create_task(source.watch_for_changes(callback))
|
||||
|
||||
try:
|
||||
# Wait for listener to connect
|
||||
await asyncio.sleep(0.2)
|
||||
|
||||
# Insert an adapter via direct connection (not through store)
|
||||
# This triggers the NOTIFY
|
||||
await db_conn.execute("""
|
||||
INSERT INTO config.adapters (name, enabled, cadence_s, settings)
|
||||
VALUES ('test_adapter', true, 60, '{}'::jsonb)
|
||||
""")
|
||||
|
||||
# Wait for notification
|
||||
await asyncio.wait_for(notification_received.wait(), timeout=5.0)
|
||||
|
||||
assert len(notifications) >= 1
|
||||
assert notifications[0] == ("adapters", "test_adapter")
|
||||
|
||||
finally:
|
||||
watch_task.cancel()
|
||||
try:
|
||||
await watch_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
|
||||
class TestRateLimitGuarantee:
|
||||
"""Tests for rate-limit guarantees during hot-reload.
|
||||
|
||||
These tests verify the critical invariant: cadence changes must not
|
||||
cause extra API calls before (last_poll + new_cadence).
|
||||
"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cadence_change_respects_last_poll_time(self) -> None:
|
||||
"""Changing cadence mid-cycle schedules next poll at last_poll + new_cadence.
|
||||
|
||||
This is the core rate-limit guarantee test (gate 3).
|
||||
"""
|
||||
# Import supervisor module to access AdapterState
|
||||
from central.supervisor import AdapterState
|
||||
|
||||
# Mock adapter
|
||||
mock_adapter = MagicMock()
|
||||
mock_adapter.name = "test"
|
||||
mock_adapter.cadence_s = 60
|
||||
|
||||
# Create adapter state with a known last_completed_poll time
|
||||
last_poll = datetime.now(timezone.utc) - timedelta(seconds=30)
|
||||
|
||||
config = AdapterConfig(
|
||||
name="test",
|
||||
enabled=True,
|
||||
cadence_s=60, # Original cadence
|
||||
settings={},
|
||||
paused_at=None,
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
state = AdapterState(
|
||||
name="test",
|
||||
adapter=mock_adapter,
|
||||
config=config,
|
||||
last_completed_poll=last_poll,
|
||||
)
|
||||
|
||||
# Simulate cadence change to 90 seconds
|
||||
new_config = AdapterConfig(
|
||||
name="test",
|
||||
enabled=True,
|
||||
cadence_s=90, # New cadence
|
||||
settings={},
|
||||
paused_at=None,
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
# Update state as reschedule would
|
||||
state.config = new_config
|
||||
state.adapter.cadence_s = 90
|
||||
|
||||
# Calculate expected next poll time
|
||||
expected_next_poll = last_poll + timedelta(seconds=90)
|
||||
now = datetime.now(timezone.utc)
|
||||
expected_wait = max(0, (expected_next_poll - now).total_seconds())
|
||||
|
||||
# The wait time should be based on last_poll + new_cadence
|
||||
# Since last_poll was 30 seconds ago and new cadence is 90,
|
||||
# we should wait 60 more seconds (90 - 30 = 60)
|
||||
actual_next_poll = last_poll.timestamp() + new_config.cadence_s
|
||||
actual_wait = max(0, actual_next_poll - now.timestamp())
|
||||
|
||||
# Allow 1 second tolerance for timing
|
||||
assert abs(actual_wait - 60) < 2, (
|
||||
f"Expected ~60s wait, got {actual_wait}s. "
|
||||
f"Rate limit violated: poll would happen before last_poll + new_cadence"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cadence_increase_after_gap_polls_immediately(self) -> None:
|
||||
"""When last_poll + new_cadence is already past, poll immediately.
|
||||
|
||||
If operator increases cadence to 120s after a gap of 150s,
|
||||
the poll should happen now (not wait another 120s).
|
||||
"""
|
||||
from central.supervisor import AdapterState
|
||||
|
||||
mock_adapter = MagicMock()
|
||||
mock_adapter.name = "test"
|
||||
mock_adapter.cadence_s = 60
|
||||
|
||||
# Last poll was 150 seconds ago
|
||||
last_poll = datetime.now(timezone.utc) - timedelta(seconds=150)
|
||||
|
||||
config = AdapterConfig(
|
||||
name="test",
|
||||
enabled=True,
|
||||
cadence_s=120, # Increased cadence
|
||||
settings={},
|
||||
paused_at=None,
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
state = AdapterState(
|
||||
name="test",
|
||||
adapter=mock_adapter,
|
||||
config=config,
|
||||
last_completed_poll=last_poll,
|
||||
)
|
||||
|
||||
# Calculate next poll time
|
||||
now = datetime.now(timezone.utc)
|
||||
next_poll_at = last_poll.timestamp() + config.cadence_s
|
||||
wait_time = max(0, next_poll_at - now.timestamp())
|
||||
|
||||
# Since 150 > 120, next poll should be immediate (wait_time ~= 0)
|
||||
assert wait_time < 1, (
|
||||
f"Expected immediate poll (wait ~0s), got {wait_time}s. "
|
||||
f"After a gap exceeding new cadence, poll should happen now."
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enable_disable_enable_respects_rate_limit(self) -> None:
|
||||
"""Re-enabling adapter schedules poll at last_poll + cadence.
|
||||
|
||||
If adapter was disabled for a while and then re-enabled, the next
|
||||
poll should be at (last_completed_poll + cadence_s), not immediately
|
||||
(unless that time has already passed).
|
||||
"""
|
||||
from central.supervisor import AdapterState
|
||||
|
||||
mock_adapter = MagicMock()
|
||||
mock_adapter.name = "test"
|
||||
mock_adapter.cadence_s = 60
|
||||
|
||||
# Last poll was 30 seconds ago, then adapter was disabled
|
||||
last_poll = datetime.now(timezone.utc) - timedelta(seconds=30)
|
||||
|
||||
# Re-enabled config
|
||||
config = AdapterConfig(
|
||||
name="test",
|
||||
enabled=True,
|
||||
cadence_s=60,
|
||||
settings={},
|
||||
paused_at=None,
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
state = AdapterState(
|
||||
name="test",
|
||||
adapter=mock_adapter,
|
||||
config=config,
|
||||
last_completed_poll=last_poll,
|
||||
)
|
||||
|
||||
# Calculate next poll time
|
||||
now = datetime.now(timezone.utc)
|
||||
next_poll_at = last_poll.timestamp() + config.cadence_s
|
||||
wait_time = max(0, next_poll_at - now.timestamp())
|
||||
|
||||
# Should wait ~30 more seconds (60 - 30 = 30)
|
||||
assert abs(wait_time - 30) < 2, (
|
||||
f"Expected ~30s wait after re-enable, got {wait_time}s. "
|
||||
f"Rate limit violated on enable→disable→enable sequence."
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_rapid_cadence_changes_no_extra_polls(self) -> None:
|
||||
"""Multiple rapid cadence changes don't cause extra polls.
|
||||
|
||||
If NOTIFY fires rapidly (60→90→120→90), the final schedule should
|
||||
still be based on last_completed_poll + final_cadence.
|
||||
"""
|
||||
from central.supervisor import AdapterState
|
||||
|
||||
mock_adapter = MagicMock()
|
||||
mock_adapter.name = "test"
|
||||
mock_adapter.cadence_s = 60
|
||||
|
||||
# Last poll was 20 seconds ago
|
||||
last_poll = datetime.now(timezone.utc) - timedelta(seconds=20)
|
||||
|
||||
state = AdapterState(
|
||||
name="test",
|
||||
adapter=mock_adapter,
|
||||
config=AdapterConfig(
|
||||
name="test",
|
||||
enabled=True,
|
||||
cadence_s=60,
|
||||
settings={},
|
||||
paused_at=None,
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
),
|
||||
last_completed_poll=last_poll,
|
||||
)
|
||||
|
||||
# Simulate rapid cadence changes
|
||||
for cadence in [90, 120, 90]: # Final cadence is 90
|
||||
state.config = AdapterConfig(
|
||||
name="test",
|
||||
enabled=True,
|
||||
cadence_s=cadence,
|
||||
settings={},
|
||||
paused_at=None,
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
)
|
||||
state.adapter.cadence_s = cadence
|
||||
|
||||
# Final schedule should be last_poll + 90
|
||||
now = datetime.now(timezone.utc)
|
||||
final_cadence = 90
|
||||
next_poll_at = last_poll.timestamp() + final_cadence
|
||||
wait_time = max(0, next_poll_at - now.timestamp())
|
||||
|
||||
# Should wait ~70 seconds (90 - 20 = 70)
|
||||
assert abs(wait_time - 70) < 2, (
|
||||
f"Expected ~70s wait after rapid changes, got {wait_time}s. "
|
||||
f"Multiple NOTIFYs should not cause extra polls."
|
||||
)
|
||||
|
||||
"""Tests for supervisor hot-reload and rate-limiting behavior."""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import os
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import asyncpg
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from central.config_models import AdapterConfig
|
||||
from central.config_source import DbConfigSource
|
||||
from central.config_store import ConfigStore
|
||||
from central.crypto import KEY_SIZE, clear_key_cache
|
||||
|
||||
# Test database DSN
|
||||
TEST_DB_DSN = os.environ.get(
|
||||
"CENTRAL_TEST_DB_DSN",
|
||||
"postgresql://central_test:testpass@localhost/central_test",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def master_key_path(tmp_path_factory: pytest.TempPathFactory) -> Path:
|
||||
"""Create a master key file for the test session."""
|
||||
key = os.urandom(KEY_SIZE)
|
||||
key_path = tmp_path_factory.mktemp("keys") / "master.key"
|
||||
key_path.write_text(base64.b64encode(key).decode())
|
||||
return key_path
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_master_key(master_key_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Configure master key path for all tests."""
|
||||
clear_key_cache()
|
||||
monkeypatch.setenv("CENTRAL_DB_DSN", TEST_DB_DSN)
|
||||
monkeypatch.setenv("CENTRAL_MASTER_KEY_PATH", str(master_key_path))
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def db_conn() -> asyncpg.Connection:
|
||||
"""Get a direct database connection for setup/teardown."""
|
||||
conn = await asyncpg.connect(TEST_DB_DSN)
|
||||
yield conn
|
||||
await conn.close()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def clean_config_schema(db_conn: asyncpg.Connection) -> None:
|
||||
"""Ensure config schema exists and is clean before each test."""
|
||||
await db_conn.execute("CREATE SCHEMA IF NOT EXISTS config")
|
||||
await db_conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS config.adapters (
|
||||
name TEXT PRIMARY KEY,
|
||||
enabled BOOLEAN NOT NULL DEFAULT true,
|
||||
cadence_s INTEGER NOT NULL,
|
||||
settings JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||
paused_at TIMESTAMPTZ,
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||
)
|
||||
""")
|
||||
# Create notify trigger
|
||||
await db_conn.execute("""
|
||||
CREATE OR REPLACE FUNCTION config.notify_config_change()
|
||||
RETURNS trigger AS $$
|
||||
DECLARE
|
||||
key_value TEXT;
|
||||
BEGIN
|
||||
IF TG_TABLE_NAME = 'adapters' THEN
|
||||
key_value := COALESCE(NEW.name, OLD.name, '');
|
||||
ELSE
|
||||
key_value := '';
|
||||
END IF;
|
||||
PERFORM pg_notify('config_changed', TG_TABLE_NAME || ':' || key_value);
|
||||
RETURN COALESCE(NEW, OLD);
|
||||
END;
|
||||
$$ LANGUAGE plpgsql
|
||||
""")
|
||||
await db_conn.execute("""
|
||||
DROP TRIGGER IF EXISTS adapters_notify ON config.adapters;
|
||||
CREATE TRIGGER adapters_notify
|
||||
AFTER INSERT OR UPDATE OR DELETE ON config.adapters
|
||||
FOR EACH ROW EXECUTE FUNCTION config.notify_config_change()
|
||||
""")
|
||||
await db_conn.execute("DELETE FROM config.adapters")
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def config_store(clean_config_schema: None) -> ConfigStore:
|
||||
"""Create a ConfigStore connected to the test database."""
|
||||
store = await ConfigStore.create(TEST_DB_DSN)
|
||||
yield store
|
||||
await store.close()
|
||||
|
||||
|
||||
class TestDbConfigSourceNotifications:
|
||||
"""Tests for DbConfigSource NOTIFY integration."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_watch_receives_notifications(
|
||||
self,
|
||||
config_store: ConfigStore,
|
||||
db_conn: asyncpg.Connection,
|
||||
) -> None:
|
||||
"""watch_for_changes receives NOTIFY when adapter changes."""
|
||||
source = DbConfigSource(config_store)
|
||||
notifications: list[tuple[str, str]] = []
|
||||
notification_received = asyncio.Event()
|
||||
|
||||
async def callback(table: str, key: str) -> None:
|
||||
notifications.append((table, key))
|
||||
notification_received.set()
|
||||
|
||||
# Start watching in background
|
||||
watch_task = asyncio.create_task(source.watch_for_changes(callback))
|
||||
|
||||
try:
|
||||
# Wait for listener to connect
|
||||
await asyncio.sleep(0.2)
|
||||
|
||||
# Insert an adapter via direct connection (not through store)
|
||||
# This triggers the NOTIFY
|
||||
await db_conn.execute("""
|
||||
INSERT INTO config.adapters (name, enabled, cadence_s, settings)
|
||||
VALUES ('test_adapter', true, 60, '{}'::jsonb)
|
||||
""")
|
||||
|
||||
# Wait for notification
|
||||
await asyncio.wait_for(notification_received.wait(), timeout=5.0)
|
||||
|
||||
assert len(notifications) >= 1
|
||||
assert notifications[0] == ("adapters", "test_adapter")
|
||||
|
||||
finally:
|
||||
watch_task.cancel()
|
||||
try:
|
||||
await watch_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
|
||||
class TestRateLimitGuarantee:
|
||||
"""Tests for rate-limit guarantees during hot-reload.
|
||||
|
||||
These tests verify the critical invariant: cadence changes must not
|
||||
cause extra API calls before (last_poll + new_cadence).
|
||||
"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cadence_change_respects_last_poll_time(self) -> None:
|
||||
"""Changing cadence mid-cycle schedules next poll at last_poll + new_cadence.
|
||||
|
||||
This is the core rate-limit guarantee test (gate 3).
|
||||
"""
|
||||
# Import supervisor module to access AdapterState
|
||||
from central.supervisor import AdapterState
|
||||
|
||||
# Mock adapter
|
||||
mock_adapter = MagicMock()
|
||||
mock_adapter.name = "test"
|
||||
mock_adapter.cadence_s = 60
|
||||
|
||||
# Create adapter state with a known last_completed_poll time
|
||||
last_poll = datetime.now(timezone.utc) - timedelta(seconds=30)
|
||||
|
||||
config = AdapterConfig(
|
||||
name="test",
|
||||
enabled=True,
|
||||
cadence_s=60, # Original cadence
|
||||
settings={},
|
||||
paused_at=None,
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
state = AdapterState(
|
||||
name="test",
|
||||
adapter=mock_adapter,
|
||||
config=config,
|
||||
last_completed_poll=last_poll,
|
||||
)
|
||||
|
||||
# Simulate cadence change to 90 seconds
|
||||
new_config = AdapterConfig(
|
||||
name="test",
|
||||
enabled=True,
|
||||
cadence_s=90, # New cadence
|
||||
settings={},
|
||||
paused_at=None,
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
# Update state as reschedule would
|
||||
state.config = new_config
|
||||
state.adapter.cadence_s = 90
|
||||
|
||||
# Calculate expected next poll time
|
||||
expected_next_poll = last_poll + timedelta(seconds=90)
|
||||
now = datetime.now(timezone.utc)
|
||||
expected_wait = max(0, (expected_next_poll - now).total_seconds())
|
||||
|
||||
# The wait time should be based on last_poll + new_cadence
|
||||
# Since last_poll was 30 seconds ago and new cadence is 90,
|
||||
# we should wait 60 more seconds (90 - 30 = 60)
|
||||
actual_next_poll = last_poll.timestamp() + new_config.cadence_s
|
||||
actual_wait = max(0, actual_next_poll - now.timestamp())
|
||||
|
||||
# Allow 1 second tolerance for timing
|
||||
assert abs(actual_wait - 60) < 2, (
|
||||
f"Expected ~60s wait, got {actual_wait}s. "
|
||||
f"Rate limit violated: poll would happen before last_poll + new_cadence"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cadence_increase_after_gap_polls_immediately(self) -> None:
|
||||
"""When last_poll + new_cadence is already past, poll immediately.
|
||||
|
||||
If operator increases cadence to 120s after a gap of 150s,
|
||||
the poll should happen now (not wait another 120s).
|
||||
"""
|
||||
from central.supervisor import AdapterState
|
||||
|
||||
mock_adapter = MagicMock()
|
||||
mock_adapter.name = "test"
|
||||
mock_adapter.cadence_s = 60
|
||||
|
||||
# Last poll was 150 seconds ago
|
||||
last_poll = datetime.now(timezone.utc) - timedelta(seconds=150)
|
||||
|
||||
config = AdapterConfig(
|
||||
name="test",
|
||||
enabled=True,
|
||||
cadence_s=120, # Increased cadence
|
||||
settings={},
|
||||
paused_at=None,
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
state = AdapterState(
|
||||
name="test",
|
||||
adapter=mock_adapter,
|
||||
config=config,
|
||||
last_completed_poll=last_poll,
|
||||
)
|
||||
|
||||
# Calculate next poll time
|
||||
now = datetime.now(timezone.utc)
|
||||
next_poll_at = last_poll.timestamp() + config.cadence_s
|
||||
wait_time = max(0, next_poll_at - now.timestamp())
|
||||
|
||||
# Since 150 > 120, next poll should be immediate (wait_time ~= 0)
|
||||
assert wait_time < 1, (
|
||||
f"Expected immediate poll (wait ~0s), got {wait_time}s. "
|
||||
f"After a gap exceeding new cadence, poll should happen now."
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enable_disable_enable_respects_rate_limit(self) -> None:
|
||||
"""Re-enabling adapter schedules poll at last_poll + cadence.
|
||||
|
||||
If adapter was disabled for a while and then re-enabled, the next
|
||||
poll should be at (last_completed_poll + cadence_s), not immediately
|
||||
(unless that time has already passed).
|
||||
"""
|
||||
from central.supervisor import AdapterState
|
||||
|
||||
mock_adapter = MagicMock()
|
||||
mock_adapter.name = "test"
|
||||
mock_adapter.cadence_s = 60
|
||||
|
||||
# Last poll was 30 seconds ago, then adapter was disabled
|
||||
last_poll = datetime.now(timezone.utc) - timedelta(seconds=30)
|
||||
|
||||
# Re-enabled config
|
||||
config = AdapterConfig(
|
||||
name="test",
|
||||
enabled=True,
|
||||
cadence_s=60,
|
||||
settings={},
|
||||
paused_at=None,
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
state = AdapterState(
|
||||
name="test",
|
||||
adapter=mock_adapter,
|
||||
config=config,
|
||||
last_completed_poll=last_poll,
|
||||
)
|
||||
|
||||
# Calculate next poll time
|
||||
now = datetime.now(timezone.utc)
|
||||
next_poll_at = last_poll.timestamp() + config.cadence_s
|
||||
wait_time = max(0, next_poll_at - now.timestamp())
|
||||
|
||||
# Should wait ~30 more seconds (60 - 30 = 30)
|
||||
assert abs(wait_time - 30) < 2, (
|
||||
f"Expected ~30s wait after re-enable, got {wait_time}s. "
|
||||
f"Rate limit violated on enable→disable→enable sequence."
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_rapid_cadence_changes_no_extra_polls(self) -> None:
|
||||
"""Multiple rapid cadence changes don't cause extra polls.
|
||||
|
||||
If NOTIFY fires rapidly (60→90→120→90), the final schedule should
|
||||
still be based on last_completed_poll + final_cadence.
|
||||
"""
|
||||
from central.supervisor import AdapterState
|
||||
|
||||
mock_adapter = MagicMock()
|
||||
mock_adapter.name = "test"
|
||||
mock_adapter.cadence_s = 60
|
||||
|
||||
# Last poll was 20 seconds ago
|
||||
last_poll = datetime.now(timezone.utc) - timedelta(seconds=20)
|
||||
|
||||
state = AdapterState(
|
||||
name="test",
|
||||
adapter=mock_adapter,
|
||||
config=AdapterConfig(
|
||||
name="test",
|
||||
enabled=True,
|
||||
cadence_s=60,
|
||||
settings={},
|
||||
paused_at=None,
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
),
|
||||
last_completed_poll=last_poll,
|
||||
)
|
||||
|
||||
# Simulate rapid cadence changes
|
||||
for cadence in [90, 120, 90]: # Final cadence is 90
|
||||
state.config = AdapterConfig(
|
||||
name="test",
|
||||
enabled=True,
|
||||
cadence_s=cadence,
|
||||
settings={},
|
||||
paused_at=None,
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
)
|
||||
state.adapter.cadence_s = cadence
|
||||
|
||||
# Final schedule should be last_poll + 90
|
||||
now = datetime.now(timezone.utc)
|
||||
final_cadence = 90
|
||||
next_poll_at = last_poll.timestamp() + final_cadence
|
||||
wait_time = max(0, next_poll_at - now.timestamp())
|
||||
|
||||
# Should wait ~70 seconds (90 - 20 = 70)
|
||||
assert abs(wait_time - 70) < 2, (
|
||||
f"Expected ~70s wait after rapid changes, got {wait_time}s. "
|
||||
f"Multiple NOTIFYs should not cause extra polls."
|
||||
)
|
||||
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -1,482 +1,482 @@
|
|||
"""Tests for USGS earthquake adapter."""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from pathlib import Path
|
||||
import tempfile
|
||||
|
||||
from central.adapters.usgs_quake import (
|
||||
USGSQuakeAdapter,
|
||||
magnitude_tier,
|
||||
magnitude_to_severity,
|
||||
)
|
||||
from central.config_models import AdapterConfig, RegionConfig
|
||||
from central.models import Event, Geo
|
||||
|
||||
|
||||
# Sample USGS GeoJSON response
|
||||
SAMPLE_GEOJSON = {
|
||||
"type": "FeatureCollection",
|
||||
"metadata": {
|
||||
"generated": 1715878800000,
|
||||
"url": "https://earthquake.usgs.gov/earthquakes/feed/v1.0/summary/all_hour.geojson",
|
||||
"title": "USGS All Earthquakes, Past Hour",
|
||||
"status": 200,
|
||||
"api": "1.10.3",
|
||||
"count": 3
|
||||
},
|
||||
"features": [
|
||||
{
|
||||
"type": "Feature",
|
||||
"properties": {
|
||||
"mag": 2.5,
|
||||
"place": "10km N of Boise, Idaho",
|
||||
"time": 1715878500000,
|
||||
"updated": 1715878600000,
|
||||
"tz": None,
|
||||
"url": "https://earthquake.usgs.gov/earthquakes/eventpage/us1234",
|
||||
"detail": "https://earthquake.usgs.gov/earthquakes/feed/v1.0/detail/us1234.geojson",
|
||||
"felt": None,
|
||||
"cdi": None,
|
||||
"mmi": None,
|
||||
"alert": None,
|
||||
"status": "automatic",
|
||||
"tsunami": 0,
|
||||
"sig": 100,
|
||||
"net": "us",
|
||||
"code": "1234",
|
||||
"ids": ",us1234,",
|
||||
"sources": ",us,",
|
||||
"types": ",origin,",
|
||||
"nst": 10,
|
||||
"dmin": 0.5,
|
||||
"rms": 0.3,
|
||||
"gap": 100,
|
||||
"magType": "ml",
|
||||
"type": "earthquake",
|
||||
"title": "M 2.5 - 10km N of Boise, Idaho"
|
||||
},
|
||||
"geometry": {
|
||||
"type": "Point",
|
||||
"coordinates": [-116.2, 43.7, 10.5]
|
||||
},
|
||||
"id": "us1234"
|
||||
},
|
||||
{
|
||||
"type": "Feature",
|
||||
"properties": {
|
||||
"mag": 4.5,
|
||||
"place": "20km S of Portland, Oregon",
|
||||
"time": 1715878400000,
|
||||
"updated": 1715878500000,
|
||||
"tz": None,
|
||||
"url": "https://earthquake.usgs.gov/earthquakes/eventpage/us5678",
|
||||
"detail": "https://earthquake.usgs.gov/earthquakes/feed/v1.0/detail/us5678.geojson",
|
||||
"felt": 50,
|
||||
"cdi": 4.0,
|
||||
"mmi": 3.5,
|
||||
"alert": "green",
|
||||
"status": "reviewed",
|
||||
"tsunami": 0,
|
||||
"sig": 300,
|
||||
"net": "us",
|
||||
"code": "5678",
|
||||
"ids": ",us5678,",
|
||||
"sources": ",us,",
|
||||
"types": ",origin,shakemap,",
|
||||
"nst": 25,
|
||||
"dmin": 0.2,
|
||||
"rms": 0.2,
|
||||
"gap": 50,
|
||||
"magType": "mw",
|
||||
"type": "earthquake",
|
||||
"title": "M 4.5 - 20km S of Portland, Oregon"
|
||||
},
|
||||
"geometry": {
|
||||
"type": "Point",
|
||||
"coordinates": [-122.6, 45.3, 15.0]
|
||||
},
|
||||
"id": "us5678"
|
||||
},
|
||||
{
|
||||
"type": "Feature",
|
||||
"properties": {
|
||||
"mag": 3.0,
|
||||
"place": "50km E of San Francisco, California",
|
||||
"time": 1715878300000,
|
||||
"updated": 1715878400000,
|
||||
"tz": None,
|
||||
"url": "https://earthquake.usgs.gov/earthquakes/eventpage/us9999",
|
||||
"detail": "https://earthquake.usgs.gov/earthquakes/feed/v1.0/detail/us9999.geojson",
|
||||
"felt": None,
|
||||
"cdi": None,
|
||||
"mmi": None,
|
||||
"alert": None,
|
||||
"status": "automatic",
|
||||
"tsunami": 0,
|
||||
"sig": 150,
|
||||
"net": "us",
|
||||
"code": "9999",
|
||||
"ids": ",us9999,",
|
||||
"sources": ",us,",
|
||||
"types": ",origin,",
|
||||
"nst": 15,
|
||||
"dmin": 0.3,
|
||||
"rms": 0.25,
|
||||
"gap": 80,
|
||||
"magType": "ml",
|
||||
"type": "earthquake",
|
||||
"title": "M 3.0 - 50km E of San Francisco, California"
|
||||
},
|
||||
"geometry": {
|
||||
"type": "Point",
|
||||
"coordinates": [-121.5, 37.8, 8.0]
|
||||
},
|
||||
"id": "us9999"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
# Sample with null magnitude
|
||||
SAMPLE_NULL_MAG = {
|
||||
"type": "FeatureCollection",
|
||||
"metadata": {"count": 1},
|
||||
"features": [
|
||||
{
|
||||
"type": "Feature",
|
||||
"properties": {
|
||||
"mag": None,
|
||||
"place": "Quarry blast",
|
||||
"time": 1715878500000,
|
||||
"type": "quarry blast"
|
||||
},
|
||||
"geometry": {
|
||||
"type": "Point",
|
||||
"coordinates": [-116.0, 44.0, 0.0]
|
||||
},
|
||||
"id": "usquarry1"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
def make_adapter_config(
|
||||
region: dict | None = None,
|
||||
feed: str = "all_hour",
|
||||
) -> AdapterConfig:
|
||||
"""Create an AdapterConfig for testing."""
|
||||
settings = {"feed": feed}
|
||||
if region:
|
||||
settings["region"] = region
|
||||
else:
|
||||
settings["region"] = {
|
||||
"north": 49.5,
|
||||
"south": 40.0,
|
||||
"east": -110.0,
|
||||
"west": -125.0,
|
||||
}
|
||||
|
||||
return AdapterConfig(
|
||||
name="usgs_quake",
|
||||
enabled=True,
|
||||
cadence_s=60,
|
||||
settings=settings,
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_db_path():
|
||||
"""Create a temporary database path for testing."""
|
||||
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
|
||||
yield Path(f.name)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_config_store():
|
||||
"""Create a mock ConfigStore."""
|
||||
return MagicMock()
|
||||
|
||||
|
||||
class TestMagnitudeTier:
|
||||
"""Test magnitude tier classification."""
|
||||
|
||||
def test_minor(self):
|
||||
assert magnitude_tier(0.5) == "minor"
|
||||
assert magnitude_tier(2.9) == "minor"
|
||||
|
||||
def test_light(self):
|
||||
assert magnitude_tier(3.0) == "light"
|
||||
assert magnitude_tier(3.9) == "light"
|
||||
|
||||
def test_moderate(self):
|
||||
assert magnitude_tier(4.0) == "moderate"
|
||||
assert magnitude_tier(4.9) == "moderate"
|
||||
|
||||
def test_strong(self):
|
||||
assert magnitude_tier(5.0) == "strong"
|
||||
assert magnitude_tier(5.9) == "strong"
|
||||
|
||||
def test_major(self):
|
||||
assert magnitude_tier(6.0) == "major"
|
||||
assert magnitude_tier(6.9) == "major"
|
||||
|
||||
def test_great(self):
|
||||
assert magnitude_tier(7.0) == "great"
|
||||
assert magnitude_tier(9.5) == "great"
|
||||
|
||||
|
||||
class TestMagnitudeToSeverity:
|
||||
"""Test magnitude to severity mapping."""
|
||||
|
||||
def test_severity_levels(self):
|
||||
assert magnitude_to_severity(2.0) == 0
|
||||
assert magnitude_to_severity(3.5) == 1
|
||||
assert magnitude_to_severity(4.5) == 2
|
||||
assert magnitude_to_severity(5.5) == 3
|
||||
assert magnitude_to_severity(6.5) == 4
|
||||
assert magnitude_to_severity(7.5) == 5
|
||||
|
||||
|
||||
class TestRegionFiltering:
|
||||
"""Test region/bbox filtering."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_filters_out_of_bbox(self, temp_db_path, mock_config_store):
|
||||
"""Test that quakes outside bbox are filtered."""
|
||||
# Region covers PNW only (north of 40, west of -110)
|
||||
config = make_adapter_config(
|
||||
region={"north": 49.5, "south": 40.0, "east": -110.0, "west": -125.0}
|
||||
)
|
||||
adapter = USGSQuakeAdapter(
|
||||
config=config,
|
||||
config_store=mock_config_store,
|
||||
cursor_db_path=temp_db_path,
|
||||
)
|
||||
await adapter.startup()
|
||||
|
||||
with patch.object(adapter, "_fetch_geojson", new_callable=AsyncMock) as mock_fetch:
|
||||
mock_fetch.return_value = SAMPLE_GEOJSON
|
||||
|
||||
events = []
|
||||
async for event in adapter.poll():
|
||||
events.append(event)
|
||||
|
||||
# us1234 (Boise) and us5678 (Portland) are in bbox
|
||||
# us9999 (SF, lat 37.8) is outside bbox (south < 40)
|
||||
assert len(events) == 2
|
||||
event_ids = {e.id for e in events}
|
||||
assert "us1234" in event_ids
|
||||
assert "us5678" in event_ids
|
||||
assert "us9999" not in event_ids
|
||||
|
||||
await adapter.shutdown()
|
||||
|
||||
|
||||
class TestDeduplication:
|
||||
"""Test deduplication logic."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dedup_marks_published(self, temp_db_path, mock_config_store):
|
||||
config = make_adapter_config()
|
||||
adapter = USGSQuakeAdapter(
|
||||
config=config,
|
||||
config_store=mock_config_store,
|
||||
cursor_db_path=temp_db_path,
|
||||
)
|
||||
await adapter.startup()
|
||||
|
||||
event_id = "us1234"
|
||||
|
||||
assert not adapter.is_published(event_id)
|
||||
adapter.mark_published(event_id)
|
||||
assert adapter.is_published(event_id)
|
||||
|
||||
await adapter.shutdown()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_second_poll_no_duplicates(self, temp_db_path, mock_config_store):
|
||||
"""Test that second poll with same events yields nothing."""
|
||||
config = make_adapter_config()
|
||||
adapter = USGSQuakeAdapter(
|
||||
config=config,
|
||||
config_store=mock_config_store,
|
||||
cursor_db_path=temp_db_path,
|
||||
)
|
||||
await adapter.startup()
|
||||
|
||||
with patch.object(adapter, "_fetch_geojson", new_callable=AsyncMock) as mock_fetch:
|
||||
mock_fetch.return_value = SAMPLE_GEOJSON
|
||||
|
||||
# First poll
|
||||
events1 = []
|
||||
async for event in adapter.poll():
|
||||
events1.append(event)
|
||||
|
||||
# Second poll - same data
|
||||
events2 = []
|
||||
async for event in adapter.poll():
|
||||
events2.append(event)
|
||||
|
||||
# First poll should have events (2 in bbox)
|
||||
assert len(events1) == 2
|
||||
# Second poll should have 0 (all deduped)
|
||||
assert len(events2) == 0
|
||||
|
||||
await adapter.shutdown()
|
||||
|
||||
|
||||
class TestNullMagnitude:
|
||||
"""Test handling of null magnitude events."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skips_null_magnitude(self, temp_db_path, mock_config_store):
|
||||
"""Test that events with null magnitude are skipped."""
|
||||
config = make_adapter_config()
|
||||
adapter = USGSQuakeAdapter(
|
||||
config=config,
|
||||
config_store=mock_config_store,
|
||||
cursor_db_path=temp_db_path,
|
||||
)
|
||||
await adapter.startup()
|
||||
|
||||
with patch.object(adapter, "_fetch_geojson", new_callable=AsyncMock) as mock_fetch:
|
||||
mock_fetch.return_value = SAMPLE_NULL_MAG
|
||||
|
||||
events = []
|
||||
async for event in adapter.poll():
|
||||
events.append(event)
|
||||
|
||||
# Should skip the null-magnitude event
|
||||
assert len(events) == 0
|
||||
|
||||
await adapter.shutdown()
|
||||
|
||||
|
||||
class TestEventGeneration:
|
||||
"""Test Event generation from features."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_event_category(self, temp_db_path, mock_config_store):
|
||||
config = make_adapter_config()
|
||||
adapter = USGSQuakeAdapter(
|
||||
config=config,
|
||||
config_store=mock_config_store,
|
||||
cursor_db_path=temp_db_path,
|
||||
)
|
||||
await adapter.startup()
|
||||
|
||||
with patch.object(adapter, "_fetch_geojson", new_callable=AsyncMock) as mock_fetch:
|
||||
mock_fetch.return_value = SAMPLE_GEOJSON
|
||||
|
||||
events = []
|
||||
async for event in adapter.poll():
|
||||
events.append(event)
|
||||
|
||||
# Check categories
|
||||
categories = {e.category for e in events}
|
||||
# us1234 is M2.5 -> minor, us5678 is M4.5 -> moderate
|
||||
assert "quake.event.minor" in categories
|
||||
assert "quake.event.moderate" in categories
|
||||
|
||||
await adapter.shutdown()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_event_severity(self, temp_db_path, mock_config_store):
|
||||
config = make_adapter_config()
|
||||
adapter = USGSQuakeAdapter(
|
||||
config=config,
|
||||
config_store=mock_config_store,
|
||||
cursor_db_path=temp_db_path,
|
||||
)
|
||||
await adapter.startup()
|
||||
|
||||
with patch.object(adapter, "_fetch_geojson", new_callable=AsyncMock) as mock_fetch:
|
||||
mock_fetch.return_value = SAMPLE_GEOJSON
|
||||
|
||||
events = []
|
||||
async for event in adapter.poll():
|
||||
events.append(event)
|
||||
|
||||
# Find events by ID
|
||||
events_by_id = {e.id: e for e in events}
|
||||
|
||||
# M2.5 -> severity 0
|
||||
assert events_by_id["us1234"].severity == 0
|
||||
# M4.5 -> severity 2
|
||||
assert events_by_id["us5678"].severity == 2
|
||||
|
||||
await adapter.shutdown()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_event_geo(self, temp_db_path, mock_config_store):
|
||||
config = make_adapter_config()
|
||||
adapter = USGSQuakeAdapter(
|
||||
config=config,
|
||||
config_store=mock_config_store,
|
||||
cursor_db_path=temp_db_path,
|
||||
)
|
||||
await adapter.startup()
|
||||
|
||||
with patch.object(adapter, "_fetch_geojson", new_callable=AsyncMock) as mock_fetch:
|
||||
mock_fetch.return_value = SAMPLE_GEOJSON
|
||||
|
||||
events = []
|
||||
async for event in adapter.poll():
|
||||
events.append(event)
|
||||
|
||||
events_by_id = {e.id: e for e in events}
|
||||
|
||||
# Check Boise quake coordinates
|
||||
boise = events_by_id["us1234"]
|
||||
assert boise.geo.centroid == (-116.2, 43.7)
|
||||
|
||||
await adapter.shutdown()
|
||||
|
||||
|
||||
class TestApplyConfig:
|
||||
"""Test hot-reload configuration application."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_apply_config_updates_region(self, temp_db_path, mock_config_store):
|
||||
config = make_adapter_config(
|
||||
region={"north": 49.5, "south": 40.0, "east": -110.0, "west": -125.0}
|
||||
)
|
||||
adapter = USGSQuakeAdapter(
|
||||
config=config,
|
||||
config_store=mock_config_store,
|
||||
cursor_db_path=temp_db_path,
|
||||
)
|
||||
await adapter.startup()
|
||||
|
||||
assert adapter.region.north == 49.5
|
||||
|
||||
new_config = make_adapter_config(
|
||||
region={"north": 48.0, "south": 45.0, "east": -115.0, "west": -125.0}
|
||||
)
|
||||
await adapter.apply_config(new_config)
|
||||
|
||||
assert adapter.region.north == 48.0
|
||||
assert adapter.region.south == 45.0
|
||||
|
||||
await adapter.shutdown()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_apply_config_updates_feed(self, temp_db_path, mock_config_store):
|
||||
config = make_adapter_config(feed="all_hour")
|
||||
adapter = USGSQuakeAdapter(
|
||||
config=config,
|
||||
config_store=mock_config_store,
|
||||
cursor_db_path=temp_db_path,
|
||||
)
|
||||
await adapter.startup()
|
||||
|
||||
assert adapter._feed == "all_hour"
|
||||
|
||||
new_config = make_adapter_config(feed="all_day")
|
||||
await adapter.apply_config(new_config)
|
||||
|
||||
assert adapter._feed == "all_day"
|
||||
|
||||
await adapter.shutdown()
|
||||
"""Tests for USGS earthquake adapter."""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from pathlib import Path
|
||||
import tempfile
|
||||
|
||||
from central.adapters.usgs_quake import (
|
||||
USGSQuakeAdapter,
|
||||
magnitude_tier,
|
||||
magnitude_to_severity,
|
||||
)
|
||||
from central.config_models import AdapterConfig, RegionConfig
|
||||
from central.models import Event, Geo
|
||||
|
||||
|
||||
# Sample USGS GeoJSON response
|
||||
SAMPLE_GEOJSON = {
|
||||
"type": "FeatureCollection",
|
||||
"metadata": {
|
||||
"generated": 1715878800000,
|
||||
"url": "https://earthquake.usgs.gov/earthquakes/feed/v1.0/summary/all_hour.geojson",
|
||||
"title": "USGS All Earthquakes, Past Hour",
|
||||
"status": 200,
|
||||
"api": "1.10.3",
|
||||
"count": 3
|
||||
},
|
||||
"features": [
|
||||
{
|
||||
"type": "Feature",
|
||||
"properties": {
|
||||
"mag": 2.5,
|
||||
"place": "10km N of Boise, Idaho",
|
||||
"time": 1715878500000,
|
||||
"updated": 1715878600000,
|
||||
"tz": None,
|
||||
"url": "https://earthquake.usgs.gov/earthquakes/eventpage/us1234",
|
||||
"detail": "https://earthquake.usgs.gov/earthquakes/feed/v1.0/detail/us1234.geojson",
|
||||
"felt": None,
|
||||
"cdi": None,
|
||||
"mmi": None,
|
||||
"alert": None,
|
||||
"status": "automatic",
|
||||
"tsunami": 0,
|
||||
"sig": 100,
|
||||
"net": "us",
|
||||
"code": "1234",
|
||||
"ids": ",us1234,",
|
||||
"sources": ",us,",
|
||||
"types": ",origin,",
|
||||
"nst": 10,
|
||||
"dmin": 0.5,
|
||||
"rms": 0.3,
|
||||
"gap": 100,
|
||||
"magType": "ml",
|
||||
"type": "earthquake",
|
||||
"title": "M 2.5 - 10km N of Boise, Idaho"
|
||||
},
|
||||
"geometry": {
|
||||
"type": "Point",
|
||||
"coordinates": [-116.2, 43.7, 10.5]
|
||||
},
|
||||
"id": "us1234"
|
||||
},
|
||||
{
|
||||
"type": "Feature",
|
||||
"properties": {
|
||||
"mag": 4.5,
|
||||
"place": "20km S of Portland, Oregon",
|
||||
"time": 1715878400000,
|
||||
"updated": 1715878500000,
|
||||
"tz": None,
|
||||
"url": "https://earthquake.usgs.gov/earthquakes/eventpage/us5678",
|
||||
"detail": "https://earthquake.usgs.gov/earthquakes/feed/v1.0/detail/us5678.geojson",
|
||||
"felt": 50,
|
||||
"cdi": 4.0,
|
||||
"mmi": 3.5,
|
||||
"alert": "green",
|
||||
"status": "reviewed",
|
||||
"tsunami": 0,
|
||||
"sig": 300,
|
||||
"net": "us",
|
||||
"code": "5678",
|
||||
"ids": ",us5678,",
|
||||
"sources": ",us,",
|
||||
"types": ",origin,shakemap,",
|
||||
"nst": 25,
|
||||
"dmin": 0.2,
|
||||
"rms": 0.2,
|
||||
"gap": 50,
|
||||
"magType": "mw",
|
||||
"type": "earthquake",
|
||||
"title": "M 4.5 - 20km S of Portland, Oregon"
|
||||
},
|
||||
"geometry": {
|
||||
"type": "Point",
|
||||
"coordinates": [-122.6, 45.3, 15.0]
|
||||
},
|
||||
"id": "us5678"
|
||||
},
|
||||
{
|
||||
"type": "Feature",
|
||||
"properties": {
|
||||
"mag": 3.0,
|
||||
"place": "50km E of San Francisco, California",
|
||||
"time": 1715878300000,
|
||||
"updated": 1715878400000,
|
||||
"tz": None,
|
||||
"url": "https://earthquake.usgs.gov/earthquakes/eventpage/us9999",
|
||||
"detail": "https://earthquake.usgs.gov/earthquakes/feed/v1.0/detail/us9999.geojson",
|
||||
"felt": None,
|
||||
"cdi": None,
|
||||
"mmi": None,
|
||||
"alert": None,
|
||||
"status": "automatic",
|
||||
"tsunami": 0,
|
||||
"sig": 150,
|
||||
"net": "us",
|
||||
"code": "9999",
|
||||
"ids": ",us9999,",
|
||||
"sources": ",us,",
|
||||
"types": ",origin,",
|
||||
"nst": 15,
|
||||
"dmin": 0.3,
|
||||
"rms": 0.25,
|
||||
"gap": 80,
|
||||
"magType": "ml",
|
||||
"type": "earthquake",
|
||||
"title": "M 3.0 - 50km E of San Francisco, California"
|
||||
},
|
||||
"geometry": {
|
||||
"type": "Point",
|
||||
"coordinates": [-121.5, 37.8, 8.0]
|
||||
},
|
||||
"id": "us9999"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
# Sample with null magnitude
|
||||
SAMPLE_NULL_MAG = {
|
||||
"type": "FeatureCollection",
|
||||
"metadata": {"count": 1},
|
||||
"features": [
|
||||
{
|
||||
"type": "Feature",
|
||||
"properties": {
|
||||
"mag": None,
|
||||
"place": "Quarry blast",
|
||||
"time": 1715878500000,
|
||||
"type": "quarry blast"
|
||||
},
|
||||
"geometry": {
|
||||
"type": "Point",
|
||||
"coordinates": [-116.0, 44.0, 0.0]
|
||||
},
|
||||
"id": "usquarry1"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
def make_adapter_config(
|
||||
region: dict | None = None,
|
||||
feed: str = "all_hour",
|
||||
) -> AdapterConfig:
|
||||
"""Create an AdapterConfig for testing."""
|
||||
settings = {"feed": feed}
|
||||
if region:
|
||||
settings["region"] = region
|
||||
else:
|
||||
settings["region"] = {
|
||||
"north": 49.5,
|
||||
"south": 40.0,
|
||||
"east": -110.0,
|
||||
"west": -125.0,
|
||||
}
|
||||
|
||||
return AdapterConfig(
|
||||
name="usgs_quake",
|
||||
enabled=True,
|
||||
cadence_s=60,
|
||||
settings=settings,
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_db_path():
|
||||
"""Create a temporary database path for testing."""
|
||||
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
|
||||
yield Path(f.name)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_config_store():
|
||||
"""Create a mock ConfigStore."""
|
||||
return MagicMock()
|
||||
|
||||
|
||||
class TestMagnitudeTier:
|
||||
"""Test magnitude tier classification."""
|
||||
|
||||
def test_minor(self):
|
||||
assert magnitude_tier(0.5) == "minor"
|
||||
assert magnitude_tier(2.9) == "minor"
|
||||
|
||||
def test_light(self):
|
||||
assert magnitude_tier(3.0) == "light"
|
||||
assert magnitude_tier(3.9) == "light"
|
||||
|
||||
def test_moderate(self):
|
||||
assert magnitude_tier(4.0) == "moderate"
|
||||
assert magnitude_tier(4.9) == "moderate"
|
||||
|
||||
def test_strong(self):
|
||||
assert magnitude_tier(5.0) == "strong"
|
||||
assert magnitude_tier(5.9) == "strong"
|
||||
|
||||
def test_major(self):
|
||||
assert magnitude_tier(6.0) == "major"
|
||||
assert magnitude_tier(6.9) == "major"
|
||||
|
||||
def test_great(self):
|
||||
assert magnitude_tier(7.0) == "great"
|
||||
assert magnitude_tier(9.5) == "great"
|
||||
|
||||
|
||||
class TestMagnitudeToSeverity:
|
||||
"""Test magnitude to severity mapping."""
|
||||
|
||||
def test_severity_levels(self):
|
||||
assert magnitude_to_severity(2.0) == 0
|
||||
assert magnitude_to_severity(3.5) == 1
|
||||
assert magnitude_to_severity(4.5) == 2
|
||||
assert magnitude_to_severity(5.5) == 3
|
||||
assert magnitude_to_severity(6.5) == 4
|
||||
assert magnitude_to_severity(7.5) == 5
|
||||
|
||||
|
||||
class TestRegionFiltering:
|
||||
"""Test region/bbox filtering."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_filters_out_of_bbox(self, temp_db_path, mock_config_store):
|
||||
"""Test that quakes outside bbox are filtered."""
|
||||
# Region covers PNW only (north of 40, west of -110)
|
||||
config = make_adapter_config(
|
||||
region={"north": 49.5, "south": 40.0, "east": -110.0, "west": -125.0}
|
||||
)
|
||||
adapter = USGSQuakeAdapter(
|
||||
config=config,
|
||||
config_store=mock_config_store,
|
||||
cursor_db_path=temp_db_path,
|
||||
)
|
||||
await adapter.startup()
|
||||
|
||||
with patch.object(adapter, "_fetch_geojson", new_callable=AsyncMock) as mock_fetch:
|
||||
mock_fetch.return_value = SAMPLE_GEOJSON
|
||||
|
||||
events = []
|
||||
async for event in adapter.poll():
|
||||
events.append(event)
|
||||
|
||||
# us1234 (Boise) and us5678 (Portland) are in bbox
|
||||
# us9999 (SF, lat 37.8) is outside bbox (south < 40)
|
||||
assert len(events) == 2
|
||||
event_ids = {e.id for e in events}
|
||||
assert "us1234" in event_ids
|
||||
assert "us5678" in event_ids
|
||||
assert "us9999" not in event_ids
|
||||
|
||||
await adapter.shutdown()
|
||||
|
||||
|
||||
class TestDeduplication:
|
||||
"""Test deduplication logic."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dedup_marks_published(self, temp_db_path, mock_config_store):
|
||||
config = make_adapter_config()
|
||||
adapter = USGSQuakeAdapter(
|
||||
config=config,
|
||||
config_store=mock_config_store,
|
||||
cursor_db_path=temp_db_path,
|
||||
)
|
||||
await adapter.startup()
|
||||
|
||||
event_id = "us1234"
|
||||
|
||||
assert not adapter.is_published(event_id)
|
||||
adapter.mark_published(event_id)
|
||||
assert adapter.is_published(event_id)
|
||||
|
||||
await adapter.shutdown()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_second_poll_no_duplicates(self, temp_db_path, mock_config_store):
|
||||
"""Test that second poll with same events yields nothing."""
|
||||
config = make_adapter_config()
|
||||
adapter = USGSQuakeAdapter(
|
||||
config=config,
|
||||
config_store=mock_config_store,
|
||||
cursor_db_path=temp_db_path,
|
||||
)
|
||||
await adapter.startup()
|
||||
|
||||
with patch.object(adapter, "_fetch_geojson", new_callable=AsyncMock) as mock_fetch:
|
||||
mock_fetch.return_value = SAMPLE_GEOJSON
|
||||
|
||||
# First poll
|
||||
events1 = []
|
||||
async for event in adapter.poll():
|
||||
events1.append(event)
|
||||
|
||||
# Second poll - same data
|
||||
events2 = []
|
||||
async for event in adapter.poll():
|
||||
events2.append(event)
|
||||
|
||||
# First poll should have events (2 in bbox)
|
||||
assert len(events1) == 2
|
||||
# Second poll should have 0 (all deduped)
|
||||
assert len(events2) == 0
|
||||
|
||||
await adapter.shutdown()
|
||||
|
||||
|
||||
class TestNullMagnitude:
|
||||
"""Test handling of null magnitude events."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skips_null_magnitude(self, temp_db_path, mock_config_store):
|
||||
"""Test that events with null magnitude are skipped."""
|
||||
config = make_adapter_config()
|
||||
adapter = USGSQuakeAdapter(
|
||||
config=config,
|
||||
config_store=mock_config_store,
|
||||
cursor_db_path=temp_db_path,
|
||||
)
|
||||
await adapter.startup()
|
||||
|
||||
with patch.object(adapter, "_fetch_geojson", new_callable=AsyncMock) as mock_fetch:
|
||||
mock_fetch.return_value = SAMPLE_NULL_MAG
|
||||
|
||||
events = []
|
||||
async for event in adapter.poll():
|
||||
events.append(event)
|
||||
|
||||
# Should skip the null-magnitude event
|
||||
assert len(events) == 0
|
||||
|
||||
await adapter.shutdown()
|
||||
|
||||
|
||||
class TestEventGeneration:
|
||||
"""Test Event generation from features."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_event_category(self, temp_db_path, mock_config_store):
|
||||
config = make_adapter_config()
|
||||
adapter = USGSQuakeAdapter(
|
||||
config=config,
|
||||
config_store=mock_config_store,
|
||||
cursor_db_path=temp_db_path,
|
||||
)
|
||||
await adapter.startup()
|
||||
|
||||
with patch.object(adapter, "_fetch_geojson", new_callable=AsyncMock) as mock_fetch:
|
||||
mock_fetch.return_value = SAMPLE_GEOJSON
|
||||
|
||||
events = []
|
||||
async for event in adapter.poll():
|
||||
events.append(event)
|
||||
|
||||
# Check categories
|
||||
categories = {e.category for e in events}
|
||||
# us1234 is M2.5 -> minor, us5678 is M4.5 -> moderate
|
||||
assert "quake.event.minor" in categories
|
||||
assert "quake.event.moderate" in categories
|
||||
|
||||
await adapter.shutdown()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_event_severity(self, temp_db_path, mock_config_store):
|
||||
config = make_adapter_config()
|
||||
adapter = USGSQuakeAdapter(
|
||||
config=config,
|
||||
config_store=mock_config_store,
|
||||
cursor_db_path=temp_db_path,
|
||||
)
|
||||
await adapter.startup()
|
||||
|
||||
with patch.object(adapter, "_fetch_geojson", new_callable=AsyncMock) as mock_fetch:
|
||||
mock_fetch.return_value = SAMPLE_GEOJSON
|
||||
|
||||
events = []
|
||||
async for event in adapter.poll():
|
||||
events.append(event)
|
||||
|
||||
# Find events by ID
|
||||
events_by_id = {e.id: e for e in events}
|
||||
|
||||
# M2.5 -> severity 0
|
||||
assert events_by_id["us1234"].severity == 0
|
||||
# M4.5 -> severity 2
|
||||
assert events_by_id["us5678"].severity == 2
|
||||
|
||||
await adapter.shutdown()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_event_geo(self, temp_db_path, mock_config_store):
|
||||
config = make_adapter_config()
|
||||
adapter = USGSQuakeAdapter(
|
||||
config=config,
|
||||
config_store=mock_config_store,
|
||||
cursor_db_path=temp_db_path,
|
||||
)
|
||||
await adapter.startup()
|
||||
|
||||
with patch.object(adapter, "_fetch_geojson", new_callable=AsyncMock) as mock_fetch:
|
||||
mock_fetch.return_value = SAMPLE_GEOJSON
|
||||
|
||||
events = []
|
||||
async for event in adapter.poll():
|
||||
events.append(event)
|
||||
|
||||
events_by_id = {e.id: e for e in events}
|
||||
|
||||
# Check Boise quake coordinates
|
||||
boise = events_by_id["us1234"]
|
||||
assert boise.geo.centroid == (-116.2, 43.7)
|
||||
|
||||
await adapter.shutdown()
|
||||
|
||||
|
||||
class TestApplyConfig:
|
||||
"""Test hot-reload configuration application."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_apply_config_updates_region(self, temp_db_path, mock_config_store):
|
||||
config = make_adapter_config(
|
||||
region={"north": 49.5, "south": 40.0, "east": -110.0, "west": -125.0}
|
||||
)
|
||||
adapter = USGSQuakeAdapter(
|
||||
config=config,
|
||||
config_store=mock_config_store,
|
||||
cursor_db_path=temp_db_path,
|
||||
)
|
||||
await adapter.startup()
|
||||
|
||||
assert adapter.region.north == 49.5
|
||||
|
||||
new_config = make_adapter_config(
|
||||
region={"north": 48.0, "south": 45.0, "east": -115.0, "west": -125.0}
|
||||
)
|
||||
await adapter.apply_config(new_config)
|
||||
|
||||
assert adapter.region.north == 48.0
|
||||
assert adapter.region.south == 45.0
|
||||
|
||||
await adapter.shutdown()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_apply_config_updates_feed(self, temp_db_path, mock_config_store):
|
||||
config = make_adapter_config(feed="all_hour")
|
||||
adapter = USGSQuakeAdapter(
|
||||
config=config,
|
||||
config_store=mock_config_store,
|
||||
cursor_db_path=temp_db_path,
|
||||
)
|
||||
await adapter.startup()
|
||||
|
||||
assert adapter._feed == "all_hour"
|
||||
|
||||
new_config = make_adapter_config(feed="all_day")
|
||||
await adapter.apply_config(new_config)
|
||||
|
||||
assert adapter._feed == "all_day"
|
||||
|
||||
await adapter.shutdown()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue