diff --git a/src/central/adapter.py b/src/central/adapter.py index 276a9cf..0322d3e 100644 --- a/src/central/adapter.py +++ b/src/central/adapter.py @@ -34,6 +34,10 @@ class SourceAdapter(ABC): description: str settings_schema: type[BaseModel] requires_api_key: str | None = None + api_key_field: str | None = None + """Names the settings_schema field that holds an api_key alias reference, if any. + The GUI renders this field as a select populated from config.api_keys; + the wizard validates it against staged api_keys state.""" wizard_order: int | None = None default_cadence_s: int diff --git a/src/central/adapters/firms.py b/src/central/adapters/firms.py index 0b1647b..c9b4efb 100644 --- a/src/central/adapters/firms.py +++ b/src/central/adapters/firms.py @@ -66,6 +66,7 @@ class FIRMSAdapter(SourceAdapter): description = "Near-real-time satellite-detected fire hotspots from NASA FIRMS." settings_schema = FIRMSSettings requires_api_key = "firms" + api_key_field = "api_key_alias" wizard_order = 2 default_cadence_s = 300 diff --git a/src/central/adapters/nws.py b/src/central/adapters/nws.py index ce95d3a..8205a1f 100644 --- a/src/central/adapters/nws.py +++ b/src/central/adapters/nws.py @@ -19,7 +19,7 @@ from tenacity import ( from central import __version__ from central.adapter import SourceAdapter -from pydantic import BaseModel +from pydantic import BaseModel, Field from central.config_models import AdapterConfig, RegionConfig from central.config_store import ConfigStore @@ -193,7 +193,11 @@ def _build_regions(same_codes: list[str], ugc_codes: list[str]) -> list[str]: class NWSSettings(BaseModel): """Settings schema for NWS adapter.""" - contact_email: str = "" + contact_email: str = Field( + default="", + pattern=r"^$|^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$", + description="Contact email for NWS API User-Agent header", + ) region: RegionConfig | None = None diff --git a/src/central/gui/routes.py b/src/central/gui/routes.py index 80a61c9..f2d4e28 100644 --- a/src/central/gui/routes.py +++ b/src/central/gui/routes.py @@ -652,6 +652,11 @@ async def setup_adapters_form(request: Request) -> HTMLResponse: else: settings_dict = {} fields = describe_fields(cls.settings_schema, settings_dict) + # Swap widget for api_key_field to api_key_select + if cls.api_key_field is not None: + for f in fields: + if f.name == cls.api_key_field: + f.widget = "api_key_select" adapters.append({ "name": name, "display_name": cls.display_name, @@ -683,6 +688,11 @@ async def setup_adapters_form(request: Request) -> HTMLResponse: enabled = False cadence_s = 300 fields = describe_fields(cls.settings_schema, settings_dict) + # Swap widget for api_key_field to api_key_select + if cls.api_key_field is not None: + for f in fields: + if f.name == cls.api_key_field: + f.widget = "api_key_select" adapters.append({ "name": name, "display_name": cls.display_name, @@ -803,28 +813,12 @@ async def setup_adapters_submit(request: Request) -> Response: if field.widget == "text": value = form.get(form_key, "").strip() - # Special validation for contact_email - if field.name == "contact_email": - if enabled: - if not value: - errors[form_key] = "Contact email is required when enabled" - elif not EMAIL_REGEX.match(value): - errors[form_key] = "Invalid email format" - else: - new_settings[field.name] = value - else: - new_settings[field.name] = value if value else current_settings.get(field.name) - # Special validation for api_key_alias - elif field.name == "api_key_alias": - if value: - if not any(k["alias"] == value for k in state.api_keys): - errors[form_key] = "API key alias does not exist" - else: - new_settings[field.name] = value - else: - new_settings[field.name] = None - else: - new_settings[field.name] = value if value else current_settings.get(field.name) + new_settings[field.name] = value if value else current_settings.get(field.name) + + elif field.widget == "api_key_select": + # API key alias field - stored as text, validated post-loop + value = form.get(form_key, "").strip() + new_settings[field.name] = value if value else None elif field.widget == "number": value_str = form.get(form_key, "").strip() @@ -892,6 +886,15 @@ async def setup_adapters_submit(request: Request) -> Response: loc = err["loc"][0] if err["loc"] else "unknown" errors[f"{adapter_name}_{loc}"] = err["msg"] + # Generic api_key_field validation against wizard state + if adapter_cls.api_key_field is not None: + field_value = new_settings.get(adapter_cls.api_key_field) + if field_value: + if not any(k["alias"] == field_value for k in state.api_keys): + errors[f"{adapter_name}_{adapter_cls.api_key_field}"] = ( + "API key alias does not exist" + ) + new_adapters[adapter_name] = { "enabled": enabled, "cadence_s": cadence_s, @@ -904,6 +907,11 @@ async def setup_adapters_submit(request: Request) -> Response: for name, cls in wizard_adapters: settings_dict = new_adapters[name]["settings"] fields = describe_fields(cls.settings_schema, settings_dict) + # Swap widget for api_key_field to api_key_select + if cls.api_key_field is not None: + for f in fields: + if f.name == cls.api_key_field: + f.widget = "api_key_select" adapters.append({ "name": name, "display_name": cls.display_name, @@ -1399,6 +1407,17 @@ async def adapters_edit_form( fields = [] if adapter_cls and hasattr(adapter_cls, "settings_schema"): fields = describe_fields(adapter_cls.settings_schema, settings) + # Swap widget for api_key_field to api_key_select + if adapter_cls.api_key_field is not None: + for f in fields: + if f.name == adapter_cls.api_key_field: + f.widget = "api_key_select" + + # Fetch API keys for api_key_select widget + api_keys = [] + async with pool.acquire() as conn: + api_key_rows = await conn.fetch("SELECT alias FROM config.api_keys ORDER BY alias") + api_keys = [{"alias": r["alias"]} for r in api_key_rows] csrf_token = request.state.csrf_token response = templates.TemplateResponse( @@ -1409,6 +1428,7 @@ async def adapters_edit_form( "csrf_token": csrf_token, "adapter": adapter, "fields": fields, + "api_keys": api_keys, "errors": None, "form_data": None, "tile_url": tile_url, @@ -1520,6 +1540,10 @@ async def adapters_edit_submit( parsed_values[field.name] = values else: parsed_values[field.name] = values + elif field.widget == "api_key_select": + # API key select - validate against existing keys + value = raw.strip() if raw else None + parsed_values[field.name] = value elif field.widget == "region": # Region handled separately below pass @@ -1600,6 +1624,15 @@ async def adapters_edit_submit( fields = [] if adapter_cls and hasattr(adapter_cls, "settings_schema"): fields = describe_fields(adapter_cls.settings_schema, current_settings) + # Swap widget for api_key_field to api_key_select + if adapter_cls.api_key_field is not None: + for f in fields: + if f.name == adapter_cls.api_key_field: + f.widget = "api_key_select" + + # Fetch API keys for api_key_select widget + api_key_rows = await conn.fetch("SELECT alias FROM config.api_keys ORDER BY alias") + api_keys = [{"alias": r["alias"]} for r in api_key_rows] csrf_token = request.state.csrf_token response = templates.TemplateResponse( @@ -1610,6 +1643,7 @@ async def adapters_edit_submit( "csrf_token": csrf_token, "adapter": adapter, "fields": fields, + "api_keys": api_keys, "errors": errors, "form_data": form_data, "tile_url": tile_url, diff --git a/src/central/gui/templates/adapters_edit.html b/src/central/gui/templates/adapters_edit.html index 4397467..3085cba 100644 --- a/src/central/gui/templates/adapters_edit.html +++ b/src/central/gui/templates/adapters_edit.html @@ -134,6 +134,24 @@ {% if errors and errors[field.name] %} {{ errors[field.name] }} {% endif %} + + {% elif field.widget == "api_key_select" %} + + + {% if field.description %} + {{ field.description }} + {% endif %} + {% if errors and errors[field.name] %} + {{ errors[field.name] }} + {% endif %} {% endif %} {% endfor %} diff --git a/src/central/gui/templates/setup_adapters.html b/src/central/gui/templates/setup_adapters.html index c92d0c1..e0cc977 100644 --- a/src/central/gui/templates/setup_adapters.html +++ b/src/central/gui/templates/setup_adapters.html @@ -53,8 +53,18 @@ {% set form_key = adapter.name + '_' + field.name %} {% if field.widget == "text" %} - {# Special handling for api_key_alias - render as select from wizard API keys #} - {% if field.name == "api_key_alias" %} + + + {% if field.description %} + {{ field.description }} + {% endif %} + {% if errors and errors.get(form_key) %} + {{ errors[form_key] }} + {% endif %} + + {% elif field.widget == "api_key_select" %} - {% else %} - - - {% endif %} {% if field.description %} {{ field.description }} {% endif %} diff --git a/tests/test_adapters.py b/tests/test_adapters.py index 80ac48b..beaeae0 100644 --- a/tests/test_adapters.py +++ b/tests/test_adapters.py @@ -417,3 +417,47 @@ class TestAdaptersJsonbRegression: assert isinstance(captured_audit["after"], dict), f"after should be dict, got {type(captured_audit['after'])}" assert isinstance(captured_audit["before"]["settings"], dict), "before.settings should be dict" assert isinstance(captured_audit["after"]["settings"], dict), "after.settings should be dict" + + @pytest.mark.asyncio + async def test_adapters_edit_fetches_api_keys_into_context(self): + """GET /adapters/firms includes api_keys from database in context.""" + from central.gui.routes import adapters_edit_form + + mock_request = MagicMock() + mock_request.state.operator = MagicMock(id=1, username="testop") + mock_request.state.csrf_token = "test_csrf_token" + + mock_conn = MagicMock() + mock_conn.fetchrow = AsyncMock(side_effect=[ + # Adapter row + {"name": "firms", "enabled": True, "cadence_s": 300, "settings": {}, + "paused_at": None, "updated_at": None, "last_error": None}, + # System row + {"map_tile_url": "https://tile.example.com", "map_attribution": "Test"}, + ]) + mock_conn.fetch = AsyncMock(return_value=[ + {"alias": "firms_key"}, + {"alias": "other_key"}, + ]) + mock_conn.__aenter__ = AsyncMock(return_value=mock_conn) + mock_conn.__aexit__ = AsyncMock() + + mock_pool = MagicMock() + mock_pool.acquire = MagicMock(return_value=mock_conn) + + mock_templates = MagicMock() + mock_response = MagicMock() + mock_templates.TemplateResponse.return_value = mock_response + + with patch("central.gui.routes._get_templates", return_value=mock_templates): + with patch("central.gui.routes.get_pool", return_value=mock_pool): + result = await adapters_edit_form(mock_request, "firms") + + call_args = mock_templates.TemplateResponse.call_args + context = call_args.kwargs.get("context", call_args[1].get("context")) + + assert "api_keys" in context + assert len(context["api_keys"]) == 2 + assert context["api_keys"][0]["alias"] == "firms_key" + assert context["api_keys"][1]["alias"] == "other_key" + diff --git a/tests/test_wizard.py b/tests/test_wizard.py index fa2efa2..56ddb00 100644 --- a/tests/test_wizard.py +++ b/tests/test_wizard.py @@ -380,3 +380,176 @@ class TestSetupAdaptersErrorRerender: # Error should come from RegionConfig validator, mentioning bounds assert "north" in context["errors"]["nws_region"].lower() or "south" in context["errors"]["nws_region"].lower() + @pytest.mark.asyncio + async def test_invalid_contact_email_via_pydantic_pattern(self): + """POST /setup/adapters with NWS contact_email='not-an-email' shows Pydantic pattern error.""" + from central.gui.routes import setup_adapters_submit + + mock_request = MagicMock() + mock_request.cookies = {} + mock_request.state = MagicMock() + + mock_form = MagicMock() + mock_form.get.side_effect = lambda k, d="": { + "csrf_token": "test_csrf_token", + "nws_enabled": "on", + "nws_cadence_s": "300", + "nws_contact_email": "not-an-email", # Invalid email format + "nws_region_north": "49.0", + "nws_region_south": "31.0", + "nws_region_east": "-102.0", + "nws_region_west": "-124.0", + "firms_cadence_s": "300", + "firms_region_north": "49.0", + "firms_region_south": "31.0", + "firms_region_east": "-102.0", + "firms_region_west": "-124.0", + "usgs_quake_cadence_s": "300", + "usgs_quake_feed": "all_hour", + "usgs_quake_region_north": "49.0", + "usgs_quake_region_south": "31.0", + "usgs_quake_region_east": "-102.0", + "usgs_quake_region_west": "-124.0", + }.get(k, d) + mock_form.getlist.side_effect = lambda k: { + "firms_satellites": ["VIIRS_SNPP_NRT"], + }.get(k, []) + mock_form.__contains__ = lambda self, k: k in ["nws_enabled"] + + mock_request.form = AsyncMock(return_value=mock_form) + + mock_state = MagicMock() + mock_state.operator = {"username": "test", "password_hash": "hash"} + mock_state.api_keys = [] + mock_state.adapters = None + mock_state.system = None + + mock_pool = MagicMock() + mock_conn = MagicMock() + mock_conn.fetch = AsyncMock(return_value=[ + {"name": "nws", "enabled": False, "cadence_s": 300, "settings": {}}, + {"name": "firms", "enabled": False, "cadence_s": 300, "settings": {}}, + {"name": "usgs_quake", "enabled": False, "cadence_s": 300, "settings": {}}, + ]) + mock_conn.__aenter__ = AsyncMock(return_value=mock_conn) + mock_conn.__aexit__ = AsyncMock() + mock_pool.acquire = MagicMock(return_value=mock_conn) + + mock_templates = MagicMock() + mock_response = MagicMock() + mock_response.status_code = 200 + mock_templates.TemplateResponse.return_value = mock_response + + with patch("central.gui.routes._get_templates", return_value=mock_templates): + with patch("central.gui.routes.get_pool", return_value=mock_pool): + with patch("central.gui.routes.get_settings") as mock_settings: + mock_settings.return_value.csrf_secret = "testsecret12345678901234567890ab" + with patch("central.gui.routes.validate_pre_auth_csrf", return_value=True): + with patch("central.gui.wizard.get_wizard_state", return_value=mock_state): + with patch("central.gui.routes.reuse_or_generate_pre_auth_csrf", return_value=("csrf", None)): + result = await setup_adapters_submit(mock_request) + + assert result.status_code == 200 + + call_args = mock_templates.TemplateResponse.call_args + context = call_args.kwargs.get("context", call_args[1].get("context")) + + assert context["errors"] is not None + assert "nws_contact_email" in context["errors"] + # Error should be from Pydantic pattern validation + error_msg = context["errors"]["nws_contact_email"].lower() + assert "pattern" in error_msg or "string" in error_msg or "match" in error_msg + + @pytest.mark.asyncio + async def test_invalid_api_key_alias_generic(self): + """POST /setup/adapters with FIRMS api_key_alias='bogus' shows generic error.""" + from central.gui.routes import setup_adapters_submit + + mock_request = MagicMock() + mock_request.cookies = {} + mock_request.state = MagicMock() + + mock_form = MagicMock() + mock_form.get.side_effect = lambda k, d="": { + "csrf_token": "test_csrf_token", + "nws_cadence_s": "300", + "nws_contact_email": "test@example.com", + "nws_region_north": "49.0", + "nws_region_south": "31.0", + "nws_region_east": "-102.0", + "nws_region_west": "-124.0", + "firms_cadence_s": "300", + "firms_api_key_alias": "bogus-alias-not-in-state", # Invalid alias + "firms_region_north": "49.0", + "firms_region_south": "31.0", + "firms_region_east": "-102.0", + "firms_region_west": "-124.0", + "usgs_quake_cadence_s": "300", + "usgs_quake_feed": "all_hour", + "usgs_quake_region_north": "49.0", + "usgs_quake_region_south": "31.0", + "usgs_quake_region_east": "-102.0", + "usgs_quake_region_west": "-124.0", + }.get(k, d) + mock_form.getlist.side_effect = lambda k: { + "firms_satellites": ["VIIRS_SNPP_NRT"], + }.get(k, []) + mock_form.__contains__ = lambda self, k: False + + mock_request.form = AsyncMock(return_value=mock_form) + + mock_state = MagicMock() + mock_state.operator = {"username": "test", "password_hash": "hash"} + mock_state.api_keys = [{"alias": "valid_key"}] # Only valid_key exists + mock_state.adapters = None + mock_state.system = None + + mock_pool = MagicMock() + mock_conn = MagicMock() + mock_conn.fetch = AsyncMock(return_value=[ + {"name": "nws", "enabled": False, "cadence_s": 300, "settings": {}}, + {"name": "firms", "enabled": False, "cadence_s": 300, "settings": {}}, + {"name": "usgs_quake", "enabled": False, "cadence_s": 300, "settings": {}}, + ]) + mock_conn.__aenter__ = AsyncMock(return_value=mock_conn) + mock_conn.__aexit__ = AsyncMock() + mock_pool.acquire = MagicMock(return_value=mock_conn) + + mock_templates = MagicMock() + mock_response = MagicMock() + mock_response.status_code = 200 + mock_templates.TemplateResponse.return_value = mock_response + + with patch("central.gui.routes._get_templates", return_value=mock_templates): + with patch("central.gui.routes.get_pool", return_value=mock_pool): + with patch("central.gui.routes.get_settings") as mock_settings: + mock_settings.return_value.csrf_secret = "testsecret12345678901234567890ab" + with patch("central.gui.routes.validate_pre_auth_csrf", return_value=True): + with patch("central.gui.wizard.get_wizard_state", return_value=mock_state): + with patch("central.gui.routes.reuse_or_generate_pre_auth_csrf", return_value=("csrf", None)): + result = await setup_adapters_submit(mock_request) + + assert result.status_code == 200 + + call_args = mock_templates.TemplateResponse.call_args + context = call_args.kwargs.get("context", call_args[1].get("context")) + + assert context["errors"] is not None + assert "firms_api_key_alias" in context["errors"] + assert "API key alias does not exist" in context["errors"]["firms_api_key_alias"] + + @pytest.mark.asyncio + async def test_api_key_field_none_no_check(self): + """Adapters with api_key_field=None do not trigger the api_key check.""" + # Verify that NWSAdapter has api_key_field=None + from central.adapters.nws import NWSAdapter + from central.adapters.firms import FIRMSAdapter + from central.adapters.usgs_quake import USGSQuakeAdapter + + # NWS and USGS should have api_key_field=None + assert NWSAdapter.api_key_field is None + assert USGSQuakeAdapter.api_key_field is None + + # FIRMS should have api_key_field set + assert FIRMSAdapter.api_key_field == "api_key_alias" +