mirror of
https://github.com/zvx-echo6/central.git
synced 2026-05-21 18:14:44 +02:00
Merge refactor/a3b-requires-api-key: requires_api_key enforcement
This commit is contained in:
commit
51be59ee02
8 changed files with 479 additions and 19 deletions
|
|
@ -241,6 +241,14 @@ class ConfigStore:
|
||||||
)
|
)
|
||||||
return result == "DELETE 1"
|
return result == "DELETE 1"
|
||||||
|
|
||||||
|
async def set_adapter_last_error(self, name: str, error: str | None) -> None:
|
||||||
|
"""Set or clear the last_error field on an adapter row."""
|
||||||
|
async with self._pool.acquire() as conn:
|
||||||
|
await conn.execute(
|
||||||
|
"UPDATE config.adapters SET last_error = $1 WHERE name = $2",
|
||||||
|
error, name,
|
||||||
|
)
|
||||||
|
|
||||||
# -------------------------------------------------------------------------
|
# -------------------------------------------------------------------------
|
||||||
# Change notifications
|
# Change notifications
|
||||||
# -------------------------------------------------------------------------
|
# -------------------------------------------------------------------------
|
||||||
|
|
|
||||||
|
|
@ -1318,27 +1318,45 @@ async def adapters_list(
|
||||||
templates = _get_templates()
|
templates = _get_templates()
|
||||||
pool = get_pool()
|
pool = get_pool()
|
||||||
operator = request.state.operator
|
operator = request.state.operator
|
||||||
|
adapter_classes = _adapter_classes()
|
||||||
|
|
||||||
async with pool.acquire() as conn:
|
async with pool.acquire() as conn:
|
||||||
rows = await conn.fetch(
|
rows = await conn.fetch(
|
||||||
"""
|
"""
|
||||||
SELECT name, enabled, cadence_s, settings, paused_at, updated_at
|
SELECT name, enabled, cadence_s, settings, paused_at, updated_at, last_error
|
||||||
FROM config.adapters
|
FROM config.adapters
|
||||||
ORDER BY name
|
ORDER BY name
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
adapters = []
|
adapters = []
|
||||||
for row in rows:
|
for row in rows:
|
||||||
settings = row["settings"] or {}
|
settings = row["settings"] or {}
|
||||||
adapters.append({
|
adapter_cls = adapter_classes.get(row["name"])
|
||||||
"name": row["name"],
|
|
||||||
"enabled": row["enabled"],
|
# Check if required API key is missing
|
||||||
"cadence_s": row["cadence_s"],
|
api_key_missing = False
|
||||||
"settings": settings,
|
requires_api_key_alias = None
|
||||||
"paused_at": row["paused_at"],
|
if adapter_cls and adapter_cls.requires_api_key is not None:
|
||||||
"updated_at": row["updated_at"],
|
requires_api_key_alias = adapter_cls.requires_api_key
|
||||||
})
|
has_key = await conn.fetchval(
|
||||||
|
"SELECT 1 FROM config.api_keys WHERE alias = $1",
|
||||||
|
requires_api_key_alias,
|
||||||
|
)
|
||||||
|
api_key_missing = not has_key
|
||||||
|
|
||||||
|
adapters.append({
|
||||||
|
"name": row["name"],
|
||||||
|
"display_name": getattr(adapter_cls, "display_name", row["name"]) if adapter_cls else row["name"],
|
||||||
|
"enabled": row["enabled"],
|
||||||
|
"cadence_s": row["cadence_s"],
|
||||||
|
"settings": settings,
|
||||||
|
"paused_at": row["paused_at"],
|
||||||
|
"updated_at": row["updated_at"],
|
||||||
|
"last_error": row["last_error"],
|
||||||
|
"api_key_missing": api_key_missing,
|
||||||
|
"requires_api_key_alias": requires_api_key_alias,
|
||||||
|
})
|
||||||
|
|
||||||
csrf_token = request.state.csrf_token
|
csrf_token = request.state.csrf_token
|
||||||
response = templates.TemplateResponse(
|
response = templates.TemplateResponse(
|
||||||
|
|
@ -1419,6 +1437,18 @@ async def adapters_edit_form(
|
||||||
api_key_rows = await conn.fetch("SELECT alias FROM config.api_keys ORDER BY alias")
|
api_key_rows = await conn.fetch("SELECT alias FROM config.api_keys ORDER BY alias")
|
||||||
api_keys = [{"alias": r["alias"]} for r in api_key_rows]
|
api_keys = [{"alias": r["alias"]} for r in api_key_rows]
|
||||||
|
|
||||||
|
# Check if required API key is missing
|
||||||
|
api_key_missing = False
|
||||||
|
requires_api_key_alias = None
|
||||||
|
if adapter_cls and adapter_cls.requires_api_key is not None:
|
||||||
|
requires_api_key_alias = adapter_cls.requires_api_key
|
||||||
|
async with pool.acquire() as conn:
|
||||||
|
has_key = await conn.fetchval(
|
||||||
|
"SELECT 1 FROM config.api_keys WHERE alias = $1",
|
||||||
|
requires_api_key_alias,
|
||||||
|
)
|
||||||
|
api_key_missing = not has_key
|
||||||
|
|
||||||
csrf_token = request.state.csrf_token
|
csrf_token = request.state.csrf_token
|
||||||
response = templates.TemplateResponse(
|
response = templates.TemplateResponse(
|
||||||
request=request,
|
request=request,
|
||||||
|
|
@ -1433,6 +1463,8 @@ async def adapters_edit_form(
|
||||||
"form_data": None,
|
"form_data": None,
|
||||||
"tile_url": tile_url,
|
"tile_url": tile_url,
|
||||||
"tile_attribution": tile_attribution,
|
"tile_attribution": tile_attribution,
|
||||||
|
"api_key_missing": api_key_missing,
|
||||||
|
"requires_api_key_alias": requires_api_key_alias,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
return response
|
return response
|
||||||
|
|
@ -1634,6 +1666,17 @@ async def adapters_edit_submit(
|
||||||
api_key_rows = await conn.fetch("SELECT alias FROM config.api_keys ORDER BY alias")
|
api_key_rows = await conn.fetch("SELECT alias FROM config.api_keys ORDER BY alias")
|
||||||
api_keys = [{"alias": r["alias"]} for r in api_key_rows]
|
api_keys = [{"alias": r["alias"]} for r in api_key_rows]
|
||||||
|
|
||||||
|
# Check if required API key is missing
|
||||||
|
api_key_missing = False
|
||||||
|
requires_api_key_alias = None
|
||||||
|
if adapter_cls and adapter_cls.requires_api_key is not None:
|
||||||
|
requires_api_key_alias = adapter_cls.requires_api_key
|
||||||
|
has_key = await conn.fetchval(
|
||||||
|
"SELECT 1 FROM config.api_keys WHERE alias = $1",
|
||||||
|
requires_api_key_alias,
|
||||||
|
)
|
||||||
|
api_key_missing = not has_key
|
||||||
|
|
||||||
csrf_token = request.state.csrf_token
|
csrf_token = request.state.csrf_token
|
||||||
response = templates.TemplateResponse(
|
response = templates.TemplateResponse(
|
||||||
request=request,
|
request=request,
|
||||||
|
|
@ -1648,6 +1691,8 @@ async def adapters_edit_submit(
|
||||||
"form_data": form_data,
|
"form_data": form_data,
|
||||||
"tile_url": tile_url,
|
"tile_url": tile_url,
|
||||||
"tile_attribution": tile_attribution,
|
"tile_attribution": tile_attribution,
|
||||||
|
"api_key_missing": api_key_missing,
|
||||||
|
"requires_api_key_alias": requires_api_key_alias,
|
||||||
},
|
},
|
||||||
status_code=200,
|
status_code=200,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -25,6 +25,13 @@
|
||||||
</article>
|
</article>
|
||||||
{% endif %}
|
{% endif %}
|
||||||
|
|
||||||
|
{% if api_key_missing %}
|
||||||
|
<article aria-label="API Key Required" style="background-color: var(--pico-mark-background-color); margin-bottom: 1rem;">
|
||||||
|
<strong>⚠️ API Key Required:</strong> This adapter requires the <code>{{ requires_api_key_alias }}</code> API key to be configured before it can be enabled.
|
||||||
|
<a href="/api-keys">Configure API Keys</a>
|
||||||
|
</article>
|
||||||
|
{% endif %}
|
||||||
|
|
||||||
<form method="post" action="/adapters/{{ adapter.name }}">
|
<form method="post" action="/adapters/{{ adapter.name }}">
|
||||||
<input type="hidden" name="csrf_token" value="{{ csrf_token }}">
|
<input type="hidden" name="csrf_token" value="{{ csrf_token }}">
|
||||||
|
|
||||||
|
|
@ -32,8 +39,8 @@
|
||||||
<legend>Core Settings</legend>
|
<legend>Core Settings</legend>
|
||||||
|
|
||||||
<label>
|
<label>
|
||||||
<input type="checkbox" name="enabled" {% if form_data %}{% if form_data.enabled %}checked{% endif %}{% elif adapter.enabled %}checked{% endif %}>
|
<input type="checkbox" name="enabled" {% if form_data %}{% if form_data.enabled %}checked{% endif %}{% elif adapter.enabled %}checked{% endif %}{% if api_key_missing %} disabled{% endif %}>
|
||||||
Enabled
|
Enabled{% if api_key_missing %} <small>(requires API key)</small>{% endif %}
|
||||||
</label>
|
</label>
|
||||||
|
|
||||||
<label for="cadence_s">Cadence (seconds)</label>
|
<label for="cadence_s">Cadence (seconds)</label>
|
||||||
|
|
|
||||||
|
|
@ -17,7 +17,12 @@
|
||||||
<tbody>
|
<tbody>
|
||||||
{% for adapter in adapters %}
|
{% for adapter in adapters %}
|
||||||
<tr>
|
<tr>
|
||||||
<td>{{ adapter.name }}</td>
|
<td>
|
||||||
|
{{ adapter.display_name or adapter.name }}
|
||||||
|
{% if adapter.api_key_missing %}
|
||||||
|
<span style="color: var(--pico-color-orange-500); margin-left: 0.5rem;" title="Missing API key: {{ adapter.requires_api_key_alias }}">⚠️ API Key Missing</span>
|
||||||
|
{% endif %}
|
||||||
|
</td>
|
||||||
<td>{% if adapter.enabled %}Yes{% else %}No{% endif %}</td>
|
<td>{% if adapter.enabled %}Yes{% else %}No{% endif %}</td>
|
||||||
<td>{{ adapter.cadence_s }}s</td>
|
<td>{{ adapter.cadence_s }}s</td>
|
||||||
<td>{{ adapter.updated_at.strftime('%Y-%m-%d %H:%M') if adapter.updated_at else '—' }}</td>
|
<td>{{ adapter.updated_at.strftime('%Y-%m-%d %H:%M') if adapter.updated_at else '—' }}</td>
|
||||||
|
|
|
||||||
|
|
@ -266,6 +266,23 @@ class Supervisor:
|
||||||
If the adapter was previously stopped (state exists but task is not running),
|
If the adapter was previously stopped (state exists but task is not running),
|
||||||
reuses the existing state to preserve last_completed_poll for rate limiting.
|
reuses the existing state to preserve last_completed_poll for rate limiting.
|
||||||
"""
|
"""
|
||||||
|
# API key precondition
|
||||||
|
adapter_cls = self._adapters.get(config.name)
|
||||||
|
if adapter_cls is not None and adapter_cls.requires_api_key is not None:
|
||||||
|
alias = adapter_cls.requires_api_key
|
||||||
|
key_value = await self._config_store.get_api_key(alias)
|
||||||
|
if not key_value:
|
||||||
|
error_msg = f"missing api key: {alias}"
|
||||||
|
logger.warning(
|
||||||
|
"Adapter cannot start - api key missing",
|
||||||
|
extra={"adapter": config.name, "alias": alias},
|
||||||
|
)
|
||||||
|
await self._config_store.set_adapter_last_error(config.name, error_msg)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Clear any stale last_error before proceeding
|
||||||
|
await self._config_store.set_adapter_last_error(config.name, None)
|
||||||
|
|
||||||
existing_state = self._adapter_states.get(config.name)
|
existing_state = self._adapter_states.get(config.name)
|
||||||
|
|
||||||
if existing_state is not None:
|
if existing_state is not None:
|
||||||
|
|
|
||||||
|
|
@ -42,9 +42,9 @@ class TestAdaptersListAuthenticated:
|
||||||
|
|
||||||
mock_conn = AsyncMock()
|
mock_conn = AsyncMock()
|
||||||
mock_conn.fetch.return_value = [
|
mock_conn.fetch.return_value = [
|
||||||
{"name": "firms", "enabled": True, "cadence_s": 300, "settings": {"api_key_alias": "firms"}, "paused_at": None, "updated_at": None},
|
{"name": "firms", "enabled": True, "cadence_s": 300, "settings": {"api_key_alias": "firms"}, "paused_at": None, "updated_at": None, "last_error": None},
|
||||||
{"name": "nws", "enabled": True, "cadence_s": 60, "settings": {"contact_email": "test@test.com"}, "paused_at": None, "updated_at": None},
|
{"name": "nws", "enabled": True, "cadence_s": 60, "settings": {"contact_email": "test@test.com"}, "paused_at": None, "updated_at": None, "last_error": None},
|
||||||
{"name": "usgs_quake", "enabled": True, "cadence_s": 120, "settings": {"feed": "all_hour"}, "paused_at": None, "updated_at": None},
|
{"name": "usgs_quake", "enabled": True, "cadence_s": 120, "settings": {"feed": "all_hour"}, "paused_at": None, "updated_at": None, "last_error": None},
|
||||||
]
|
]
|
||||||
|
|
||||||
mock_pool = MagicMock()
|
mock_pool = MagicMock()
|
||||||
|
|
@ -55,9 +55,22 @@ class TestAdaptersListAuthenticated:
|
||||||
mock_response = MagicMock()
|
mock_response = MagicMock()
|
||||||
mock_templates.TemplateResponse.return_value = mock_response
|
mock_templates.TemplateResponse.return_value = mock_response
|
||||||
|
|
||||||
|
# Mock adapter classes
|
||||||
|
mock_firms_cls = MagicMock()
|
||||||
|
mock_firms_cls.requires_api_key = "firms"
|
||||||
|
mock_firms_cls.display_name = "FIRMS"
|
||||||
|
mock_nws_cls = MagicMock()
|
||||||
|
mock_nws_cls.requires_api_key = None
|
||||||
|
mock_nws_cls.display_name = "NWS"
|
||||||
|
mock_usgs_cls = MagicMock()
|
||||||
|
mock_usgs_cls.requires_api_key = None
|
||||||
|
mock_usgs_cls.display_name = "USGS Quake"
|
||||||
|
mock_adapter_classes = {"firms": mock_firms_cls, "nws": mock_nws_cls, "usgs_quake": mock_usgs_cls}
|
||||||
|
|
||||||
with patch("central.gui.routes._get_templates", return_value=mock_templates):
|
with patch("central.gui.routes._get_templates", return_value=mock_templates):
|
||||||
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
||||||
result = await adapters_list(mock_request)
|
with patch("central.gui.routes._adapter_classes", return_value=mock_adapter_classes):
|
||||||
|
result = await adapters_list(mock_request)
|
||||||
|
|
||||||
# Verify template was called with adapters
|
# Verify template was called with adapters
|
||||||
call_args = mock_templates.TemplateResponse.call_args
|
call_args = mock_templates.TemplateResponse.call_args
|
||||||
|
|
|
||||||
361
tests/test_requires_api_key.py
Normal file
361
tests/test_requires_api_key.py
Normal file
|
|
@ -0,0 +1,361 @@
|
||||||
|
"""Tests for requires_api_key enforcement."""
|
||||||
|
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import MagicMock, AsyncMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from central.config_models import AdapterConfig
|
||||||
|
|
||||||
|
|
||||||
|
class TestConfigStoreSetAdapterLastError:
|
||||||
|
"""Tests for ConfigStore.set_adapter_last_error method."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_set_adapter_last_error_updates_row(self):
|
||||||
|
"""set_adapter_last_error should update the last_error column."""
|
||||||
|
from central.config_store import ConfigStore
|
||||||
|
|
||||||
|
mock_pool = MagicMock()
|
||||||
|
mock_conn = MagicMock()
|
||||||
|
mock_conn.execute = AsyncMock()
|
||||||
|
mock_conn.__aenter__ = AsyncMock(return_value=mock_conn)
|
||||||
|
mock_conn.__aexit__ = AsyncMock()
|
||||||
|
mock_pool.acquire = MagicMock(return_value=mock_conn)
|
||||||
|
|
||||||
|
config_store = ConfigStore.__new__(ConfigStore)
|
||||||
|
config_store._pool = mock_pool
|
||||||
|
|
||||||
|
await config_store.set_adapter_last_error("firms", "missing api key: firms")
|
||||||
|
|
||||||
|
mock_conn.execute.assert_called_once()
|
||||||
|
call_args = mock_conn.execute.call_args[0]
|
||||||
|
assert "UPDATE config.adapters SET last_error" in call_args[0]
|
||||||
|
assert call_args[1] == "missing api key: firms"
|
||||||
|
assert call_args[2] == "firms"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_clear_adapter_last_error(self):
|
||||||
|
"""set_adapter_last_error with None should clear the error."""
|
||||||
|
from central.config_store import ConfigStore
|
||||||
|
|
||||||
|
mock_pool = MagicMock()
|
||||||
|
mock_conn = MagicMock()
|
||||||
|
mock_conn.execute = AsyncMock()
|
||||||
|
mock_conn.__aenter__ = AsyncMock(return_value=mock_conn)
|
||||||
|
mock_conn.__aexit__ = AsyncMock()
|
||||||
|
mock_pool.acquire = MagicMock(return_value=mock_conn)
|
||||||
|
|
||||||
|
config_store = ConfigStore.__new__(ConfigStore)
|
||||||
|
config_store._pool = mock_pool
|
||||||
|
|
||||||
|
await config_store.set_adapter_last_error("firms", None)
|
||||||
|
|
||||||
|
mock_conn.execute.assert_called_once()
|
||||||
|
call_args = mock_conn.execute.call_args[0]
|
||||||
|
assert call_args[1] is None
|
||||||
|
assert call_args[2] == "firms"
|
||||||
|
|
||||||
|
|
||||||
|
class TestRoutesApiKeyMissing:
|
||||||
|
"""Tests for routes api_key_missing computation."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_adapters_list_includes_api_key_missing_flag(self):
|
||||||
|
"""adapters_list should compute api_key_missing for each adapter."""
|
||||||
|
from central.gui.routes import adapters_list
|
||||||
|
|
||||||
|
mock_request = MagicMock()
|
||||||
|
mock_request.state = MagicMock()
|
||||||
|
mock_request.state.operator = {"username": "test"}
|
||||||
|
mock_request.state.csrf_token = "test_token"
|
||||||
|
|
||||||
|
mock_pool = MagicMock()
|
||||||
|
mock_conn = MagicMock()
|
||||||
|
mock_conn.fetch = AsyncMock(return_value=[
|
||||||
|
{"name": "firms", "enabled": False, "cadence_s": 300, "settings": {}, "paused_at": None, "updated_at": None, "last_error": None},
|
||||||
|
])
|
||||||
|
mock_conn.fetchval = AsyncMock(return_value=None) # No API key exists
|
||||||
|
mock_conn.__aenter__ = AsyncMock(return_value=mock_conn)
|
||||||
|
mock_conn.__aexit__ = AsyncMock()
|
||||||
|
mock_pool.acquire = MagicMock(return_value=mock_conn)
|
||||||
|
|
||||||
|
# Mock adapter class with requires_api_key
|
||||||
|
mock_firms_cls = MagicMock()
|
||||||
|
mock_firms_cls.requires_api_key = "firms"
|
||||||
|
mock_firms_cls.display_name = "FIRMS"
|
||||||
|
|
||||||
|
with patch("central.gui.routes._get_templates") as mock_templates:
|
||||||
|
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
||||||
|
with patch("central.gui.routes._adapter_classes", return_value={"firms": mock_firms_cls}):
|
||||||
|
mock_template_response = MagicMock()
|
||||||
|
mock_templates.return_value.TemplateResponse = MagicMock(return_value=mock_template_response)
|
||||||
|
|
||||||
|
await adapters_list(mock_request)
|
||||||
|
|
||||||
|
# Check the context passed to template
|
||||||
|
call_kwargs = mock_templates.return_value.TemplateResponse.call_args[1]
|
||||||
|
adapters = call_kwargs["context"]["adapters"]
|
||||||
|
|
||||||
|
assert len(adapters) == 1
|
||||||
|
assert adapters[0]["api_key_missing"] is True
|
||||||
|
assert adapters[0]["requires_api_key_alias"] == "firms"
|
||||||
|
|
||||||
|
|
||||||
|
class TestAdapterClassRequiresApiKey:
|
||||||
|
"""Tests for adapter class requires_api_key attribute."""
|
||||||
|
|
||||||
|
def test_firms_adapter_requires_api_key(self):
|
||||||
|
"""FIRMS adapter should declare requires_api_key."""
|
||||||
|
from central.adapters.firms import FIRMSAdapter
|
||||||
|
assert FIRMSAdapter.requires_api_key == "firms"
|
||||||
|
|
||||||
|
def test_nws_adapter_no_requires_api_key(self):
|
||||||
|
"""NWS adapter should not require an API key."""
|
||||||
|
from central.adapters.nws import NWSAdapter
|
||||||
|
assert NWSAdapter.requires_api_key is None
|
||||||
|
|
||||||
|
def test_usgs_quake_adapter_no_requires_api_key(self):
|
||||||
|
"""USGS Quake adapter should not require an API key."""
|
||||||
|
from central.adapters.usgs_quake import USGSQuakeAdapter
|
||||||
|
assert USGSQuakeAdapter.requires_api_key is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestSupervisorApiKeyPrecondition:
|
||||||
|
"""Tests for supervisor API key precondition check in _start_adapter."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_start_adapter_refuses_when_required_key_missing(self, tmp_path: Path):
|
||||||
|
"""Adapter with requires_api_key but missing key should not start."""
|
||||||
|
from central.supervisor import Supervisor
|
||||||
|
from central.adapters.firms import FIRMSAdapter
|
||||||
|
|
||||||
|
# Create mock config store
|
||||||
|
mock_config_store = MagicMock()
|
||||||
|
mock_config_store.get_api_key = AsyncMock(return_value=None) # Key missing
|
||||||
|
mock_config_store.set_adapter_last_error = AsyncMock()
|
||||||
|
|
||||||
|
# Create mock NATS
|
||||||
|
mock_nats = MagicMock()
|
||||||
|
mock_nats.publish = AsyncMock()
|
||||||
|
|
||||||
|
# Build supervisor with FIRMS adapter
|
||||||
|
supervisor = Supervisor.__new__(Supervisor)
|
||||||
|
supervisor._config_store = mock_config_store
|
||||||
|
supervisor._adapters = {"firms": FIRMSAdapter}
|
||||||
|
supervisor._adapter_states = {}
|
||||||
|
supervisor._nats = mock_nats
|
||||||
|
supervisor._cursor_db_path = tmp_path / "cursors.db"
|
||||||
|
supervisor._log = MagicMock()
|
||||||
|
|
||||||
|
config = AdapterConfig(
|
||||||
|
name="firms",
|
||||||
|
enabled=True,
|
||||||
|
cadence_s=300,
|
||||||
|
settings={"api_key_alias": "firms", "satellites": ["VIIRS_SNPP_NRT"]},
|
||||||
|
updated_at=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
|
||||||
|
await supervisor._start_adapter(config)
|
||||||
|
|
||||||
|
# Should have checked for key
|
||||||
|
mock_config_store.get_api_key.assert_called_once_with("firms")
|
||||||
|
|
||||||
|
# Should have set error
|
||||||
|
mock_config_store.set_adapter_last_error.assert_called_once()
|
||||||
|
args = mock_config_store.set_adapter_last_error.call_args[0]
|
||||||
|
assert args[0] == "firms"
|
||||||
|
assert "missing api key" in args[1].lower()
|
||||||
|
|
||||||
|
# Should NOT have created adapter state (adapter did not start)
|
||||||
|
assert "firms" not in supervisor._adapter_states
|
||||||
|
|
||||||
|
# Should NOT have published to NATS
|
||||||
|
mock_nats.publish.assert_not_called()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_start_adapter_succeeds_after_key_added_and_clears_last_error(self, tmp_path: Path):
|
||||||
|
"""Adapter with requires_api_key and key present should start and clear last_error."""
|
||||||
|
from central.supervisor import Supervisor
|
||||||
|
from central.adapters.firms import FIRMSAdapter
|
||||||
|
|
||||||
|
# Create mock config store with key present
|
||||||
|
mock_config_store = MagicMock()
|
||||||
|
mock_config_store.get_api_key = AsyncMock(return_value="encrypted-firms-key")
|
||||||
|
mock_config_store.set_adapter_last_error = AsyncMock()
|
||||||
|
|
||||||
|
# Create mock NATS
|
||||||
|
mock_nats = MagicMock()
|
||||||
|
mock_nats.publish = AsyncMock()
|
||||||
|
|
||||||
|
# Build supervisor with FIRMS adapter
|
||||||
|
supervisor = Supervisor.__new__(Supervisor)
|
||||||
|
supervisor._config_store = mock_config_store
|
||||||
|
supervisor._adapters = {"firms": FIRMSAdapter}
|
||||||
|
supervisor._adapter_states = {}
|
||||||
|
supervisor._nats = mock_nats
|
||||||
|
supervisor._cursor_db_path = tmp_path / "cursors.db"
|
||||||
|
supervisor._log = MagicMock()
|
||||||
|
|
||||||
|
config = AdapterConfig(
|
||||||
|
name="firms",
|
||||||
|
enabled=True,
|
||||||
|
cadence_s=300,
|
||||||
|
settings={"api_key_alias": "firms", "satellites": ["VIIRS_SNPP_NRT"]},
|
||||||
|
updated_at=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mock the adapter instantiation to avoid actual HTTP calls
|
||||||
|
with patch.object(FIRMSAdapter, "__init__", return_value=None):
|
||||||
|
with patch.object(FIRMSAdapter, "startup", new_callable=AsyncMock):
|
||||||
|
await supervisor._start_adapter(config)
|
||||||
|
|
||||||
|
# Should have checked for key
|
||||||
|
mock_config_store.get_api_key.assert_called_once_with("firms")
|
||||||
|
|
||||||
|
# Should have cleared any stale error (called with None)
|
||||||
|
mock_config_store.set_adapter_last_error.assert_called_once_with("firms", None)
|
||||||
|
|
||||||
|
# Should have created adapter state
|
||||||
|
assert "firms" in supervisor._adapter_states
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_start_adapter_does_not_check_when_no_requires_api_key(self, tmp_path: Path):
|
||||||
|
"""Adapter without requires_api_key should skip the API key check."""
|
||||||
|
from central.supervisor import Supervisor
|
||||||
|
from central.adapters.nws import NWSAdapter
|
||||||
|
|
||||||
|
# Create mock config store
|
||||||
|
mock_config_store = MagicMock()
|
||||||
|
mock_config_store.get_api_key = AsyncMock()
|
||||||
|
mock_config_store.set_adapter_last_error = AsyncMock()
|
||||||
|
|
||||||
|
# Create mock NATS
|
||||||
|
mock_nats = MagicMock()
|
||||||
|
mock_nats.publish = AsyncMock()
|
||||||
|
|
||||||
|
# Build supervisor with NWS adapter (no requires_api_key)
|
||||||
|
supervisor = Supervisor.__new__(Supervisor)
|
||||||
|
supervisor._config_store = mock_config_store
|
||||||
|
supervisor._adapters = {"nws": NWSAdapter}
|
||||||
|
supervisor._adapter_states = {}
|
||||||
|
supervisor._nats = mock_nats
|
||||||
|
supervisor._cursor_db_path = tmp_path / "cursors.db"
|
||||||
|
supervisor._log = MagicMock()
|
||||||
|
|
||||||
|
config = AdapterConfig(
|
||||||
|
name="nws",
|
||||||
|
enabled=True,
|
||||||
|
cadence_s=300,
|
||||||
|
settings={"contact_email": "test@example.com"},
|
||||||
|
updated_at=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mock the adapter instantiation to avoid actual HTTP calls
|
||||||
|
with patch.object(NWSAdapter, "__init__", return_value=None):
|
||||||
|
with patch.object(NWSAdapter, "startup", new_callable=AsyncMock):
|
||||||
|
await supervisor._start_adapter(config)
|
||||||
|
|
||||||
|
# Should NOT have called get_api_key (no requires_api_key)
|
||||||
|
mock_config_store.get_api_key.assert_not_called()
|
||||||
|
|
||||||
|
# Should have cleared stale error (routine clear)
|
||||||
|
mock_config_store.set_adapter_last_error.assert_called_once_with("nws", None)
|
||||||
|
|
||||||
|
# Should have created adapter state
|
||||||
|
assert "nws" in supervisor._adapter_states
|
||||||
|
|
||||||
|
|
||||||
|
class TestAdaptersEditSubmitErrorRerender:
|
||||||
|
"""Tests for adapters_edit_submit error re-render including api_key_missing."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_adapters_edit_submit_error_rerender_includes_api_key_missing(self):
|
||||||
|
"""Error re-render on /adapters/firms should include api_key_missing in context."""
|
||||||
|
from central.gui.routes import adapters_edit_submit
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
mock_request = MagicMock()
|
||||||
|
mock_request.state = MagicMock()
|
||||||
|
mock_request.state.operator = {"username": "test"}
|
||||||
|
mock_request.state.csrf_token = "test_token"
|
||||||
|
|
||||||
|
# Mock form with invalid cadence (below minimum of 10)
|
||||||
|
mock_form = MagicMock()
|
||||||
|
def form_get(k, d=""):
|
||||||
|
values = {
|
||||||
|
"csrf_token": "test_token",
|
||||||
|
"cadence_s": "5", # Invalid - below minimum
|
||||||
|
"api_key_alias": "firms",
|
||||||
|
"satellites": "",
|
||||||
|
"region_north": "",
|
||||||
|
"region_south": "",
|
||||||
|
"region_east": "",
|
||||||
|
"region_west": "",
|
||||||
|
}
|
||||||
|
return values.get(k, d)
|
||||||
|
mock_form.get = MagicMock(side_effect=form_get)
|
||||||
|
mock_form.getlist = MagicMock(return_value=["VIIRS_SNPP_NRT"])
|
||||||
|
mock_form.__contains__ = lambda self, k: k == "enabled"
|
||||||
|
mock_request.form = AsyncMock(return_value=mock_form)
|
||||||
|
|
||||||
|
mock_pool = MagicMock()
|
||||||
|
mock_conn = MagicMock()
|
||||||
|
mock_conn.fetchrow = AsyncMock(side_effect=[
|
||||||
|
# First call: adapter row
|
||||||
|
{
|
||||||
|
"name": "firms",
|
||||||
|
"enabled": False,
|
||||||
|
"cadence_s": 300,
|
||||||
|
"settings": {"api_key_alias": "firms", "satellites": ["VIIRS_SNPP_NRT"]},
|
||||||
|
"paused_at": None,
|
||||||
|
"updated_at": datetime.now(timezone.utc),
|
||||||
|
"last_error": None,
|
||||||
|
},
|
||||||
|
# Second call: system row for map tiles
|
||||||
|
{"map_tile_url": "https://tile.example.com/{z}/{x}/{y}.png", "map_attribution": "Test"},
|
||||||
|
])
|
||||||
|
mock_conn.fetchval = AsyncMock(return_value=None) # No API key exists
|
||||||
|
mock_conn.fetch = AsyncMock(return_value=[]) # No API keys
|
||||||
|
mock_conn.__aenter__ = AsyncMock(return_value=mock_conn)
|
||||||
|
mock_conn.__aexit__ = AsyncMock()
|
||||||
|
mock_pool.acquire = MagicMock(return_value=mock_conn)
|
||||||
|
|
||||||
|
# Mock FIRMS adapter class
|
||||||
|
class MockFIRMSSettings(BaseModel):
|
||||||
|
api_key_alias: str = ""
|
||||||
|
satellites: list[Literal["VIIRS_SNPP_NRT", "VIIRS_NOAA20_NRT"]] = []
|
||||||
|
|
||||||
|
mock_firms_cls = MagicMock()
|
||||||
|
mock_firms_cls.requires_api_key = "firms"
|
||||||
|
mock_firms_cls.api_key_field = "api_key_alias"
|
||||||
|
mock_firms_cls.display_name = "FIRMS"
|
||||||
|
mock_firms_cls.description = "Fire detection"
|
||||||
|
mock_firms_cls.settings_schema = MockFIRMSSettings
|
||||||
|
|
||||||
|
with patch("central.gui.routes._get_templates") as mock_templates:
|
||||||
|
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
||||||
|
with patch("central.gui.routes._adapter_classes", return_value={"firms": mock_firms_cls}):
|
||||||
|
with patch("central.gui.routes.describe_fields", return_value=[]):
|
||||||
|
mock_template_response = MagicMock()
|
||||||
|
mock_template_response.status_code = 200
|
||||||
|
mock_templates.return_value.TemplateResponse = MagicMock(return_value=mock_template_response)
|
||||||
|
|
||||||
|
result = await adapters_edit_submit(mock_request, "firms")
|
||||||
|
|
||||||
|
# Verify TemplateResponse was called (error re-render)
|
||||||
|
assert mock_templates.return_value.TemplateResponse.called
|
||||||
|
|
||||||
|
# Check the context passed to template
|
||||||
|
call_kwargs = mock_templates.return_value.TemplateResponse.call_args[1]
|
||||||
|
context = call_kwargs["context"]
|
||||||
|
|
||||||
|
# Should have errors (invalid cadence)
|
||||||
|
assert context.get("errors") is not None
|
||||||
|
assert "cadence_s" in context["errors"]
|
||||||
|
|
||||||
|
# Should include api_key_missing
|
||||||
|
assert context["api_key_missing"] is True
|
||||||
|
assert context["requires_api_key_alias"] == "firms"
|
||||||
|
|
@ -94,6 +94,8 @@ class MockConfigSource:
|
||||||
class MockNWSAdapter:
|
class MockNWSAdapter:
|
||||||
"""Mock NWSAdapter that tracks poll calls and allows control."""
|
"""Mock NWSAdapter that tracks poll calls and allows control."""
|
||||||
|
|
||||||
|
requires_api_key = None # Mock adapters don't require API keys
|
||||||
|
|
||||||
def __init__(self, config, config_store, cursor_db_path) -> None:
|
def __init__(self, config, config_store, cursor_db_path) -> None:
|
||||||
self.config = config
|
self.config = config
|
||||||
self._config_store = config_store
|
self._config_store = config_store
|
||||||
|
|
@ -152,6 +154,8 @@ def mock_config_store():
|
||||||
store = MagicMock()
|
store = MagicMock()
|
||||||
store.list_streams = AsyncMock(return_value=[])
|
store.list_streams = AsyncMock(return_value=[])
|
||||||
store.get_stream = AsyncMock(return_value=None)
|
store.get_stream = AsyncMock(return_value=None)
|
||||||
|
store.set_adapter_last_error = AsyncMock()
|
||||||
|
store.get_api_key = AsyncMock(return_value=None)
|
||||||
return store
|
return store
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue