diff --git a/src/central/gui/routes.py b/src/central/gui/routes.py index 96c9791..80a61c9 100644 --- a/src/central/gui/routes.py +++ b/src/central/gui/routes.py @@ -866,31 +866,31 @@ async def setup_adapters_submit(request: Request) -> Response: new_settings[field.name] = values elif field.widget == "region": - # Region validation + # Region validation via RegionConfig model + from central.config_models import RegionConfig region_north_str = form.get(f"{adapter_name}_{field.name}_north", "").strip() region_south_str = form.get(f"{adapter_name}_{field.name}_south", "").strip() region_east_str = form.get(f"{adapter_name}_{field.name}_east", "").strip() region_west_str = form.get(f"{adapter_name}_{field.name}_west", "").strip() try: - region_north = float(region_north_str) - region_south = float(region_south_str) - region_east = float(region_east_str) - region_west = float(region_west_str) + region_model = RegionConfig( + north=float(region_north_str), + south=float(region_south_str), + east=float(region_east_str), + west=float(region_west_str), + ) + new_settings[field.name] = region_model.model_dump() + except (ValueError, ValidationError) as e: + errors[f"{adapter_name}_{field.name}"] = str(e) - if not (-90 <= region_south < region_north <= 90): - errors[f"{adapter_name}_{field.name}"] = "Invalid latitude: south < north, both -90 to 90" - elif not (-180 <= region_west < region_east <= 180): - errors[f"{adapter_name}_{field.name}"] = "Invalid longitude: west < east, both -180 to 180" - else: - new_settings[field.name] = { - "north": region_north, - "south": region_south, - "east": region_east, - "west": region_west, - } - except ValueError: - errors[f"{adapter_name}_{field.name}"] = "Region coordinates must be valid numbers" + # Run Pydantic validation on assembled settings to catch Literal violations etc. + try: + adapter_cls.settings_schema(**new_settings) + except ValidationError as e: + for err in e.errors(): + loc = err["loc"][0] if err["loc"] else "unknown" + errors[f"{adapter_name}_{loc}"] = err["msg"] new_adapters[adapter_name] = { "enabled": enabled, @@ -900,12 +900,18 @@ async def setup_adapters_submit(request: Request) -> Response: # If errors, re-render if errors: - adapters = [ - {"name": name, "enabled": new_adapters[name]["enabled"], - "cadence_s": new_adapters[name]["cadence_s"], - "settings": new_adapters[name]["settings"]} - for name in ["firms", "nws", "usgs_quake"] - ] + adapters = [] + for name, cls in wizard_adapters: + settings_dict = new_adapters[name]["settings"] + fields = describe_fields(cls.settings_schema, settings_dict) + adapters.append({ + "name": name, + "display_name": cls.display_name, + "enabled": new_adapters[name]["enabled"], + "cadence_s": new_adapters[name]["cadence_s"], + "settings": settings_dict, + "fields": fields, + }) api_keys = [{"alias": k["alias"]} for k in state.api_keys] if state.system: @@ -923,8 +929,6 @@ async def setup_adapters_submit(request: Request) -> Response: "csrf_token": csrf_token, "adapters": adapters, "api_keys": api_keys, - "valid_satellites": _get_valid_satellites(), - "valid_feeds": sorted(_get_valid_feeds()), "tile_url": tile_url, "tile_attribution": tile_attribution, "error": "Please fix the errors below.", diff --git a/tests/test_wizard.py b/tests/test_wizard.py index 0492276..fa2efa2 100644 --- a/tests/test_wizard.py +++ b/tests/test_wizard.py @@ -199,3 +199,184 @@ class TestSetupGateMiddlewareWizard: response = client.get("/setup/operator") assert response.status_code == 302 assert response.headers["location"] == "/" + +class TestSetupAdaptersErrorRerender: + """Test wizard adapters form error re-render path.""" + + @pytest.mark.asyncio + async def test_invalid_cadence_rerenders_with_error(self): + """POST /setup/adapters with cadence_s=5 re-renders form with error, no DB write.""" + from central.gui.routes import setup_adapters_submit + + mock_request = MagicMock() + mock_request.cookies = {} + mock_request.state = MagicMock() + + # Mock form data with invalid cadence + mock_form = MagicMock() + mock_form.get.side_effect = lambda k, d="": { + "csrf_token": "test_csrf_token", + "nws_enabled": "on", + "nws_cadence_s": "5", # Invalid: below ge=10 + "nws_contact_email": "test@example.com", + "nws_region_north": "49.0", + "nws_region_south": "31.0", + "nws_region_east": "-102.0", + "nws_region_west": "-124.0", + "firms_cadence_s": "300", + "firms_region_north": "49.0", + "firms_region_south": "31.0", + "firms_region_east": "-102.0", + "firms_region_west": "-124.0", + "usgs_quake_cadence_s": "300", + "usgs_quake_feed": "all_hour", + "usgs_quake_region_north": "49.0", + "usgs_quake_region_south": "31.0", + "usgs_quake_region_east": "-102.0", + "usgs_quake_region_west": "-124.0", + }.get(k, d) + mock_form.getlist.side_effect = lambda k: { + "firms_satellites": ["VIIRS_SNPP_NRT"], + }.get(k, []) + mock_form.__contains__ = lambda self, k: k in ["nws_enabled"] + + mock_request.form = AsyncMock(return_value=mock_form) + + # Mock wizard state + mock_state = MagicMock() + mock_state.operator = {"username": "test", "password_hash": "hash"} + mock_state.api_keys = [] + mock_state.adapters = None + mock_state.system = None + + # Mock pool with no actual DB access (should not be called for writes) + mock_pool = MagicMock() + mock_conn = MagicMock() + mock_conn.fetch = AsyncMock(return_value=[ + {"name": "nws", "enabled": False, "cadence_s": 300, "settings": {}}, + {"name": "firms", "enabled": False, "cadence_s": 300, "settings": {}}, + {"name": "usgs_quake", "enabled": False, "cadence_s": 300, "settings": {}}, + ]) + mock_conn.__aenter__ = AsyncMock(return_value=mock_conn) + mock_conn.__aexit__ = AsyncMock() + mock_pool.acquire = MagicMock(return_value=mock_conn) + + mock_templates = MagicMock() + mock_response = MagicMock() + mock_response.status_code = 200 + mock_templates.TemplateResponse.return_value = mock_response + + with patch("central.gui.routes._get_templates", return_value=mock_templates): + with patch("central.gui.routes.get_pool", return_value=mock_pool): + with patch("central.gui.routes.get_settings") as mock_settings: + mock_settings.return_value.csrf_secret = "testsecret12345678901234567890ab" + with patch("central.gui.routes.validate_pre_auth_csrf", return_value=True): + with patch("central.gui.wizard.get_wizard_state", return_value=mock_state): + with patch("central.gui.routes.reuse_or_generate_pre_auth_csrf", return_value=("csrf", None)): + result = await setup_adapters_submit(mock_request) + + # Should return 200 (re-render), not 302 (redirect) + assert result.status_code == 200 + + # Check that template was called with errors + call_args = mock_templates.TemplateResponse.call_args + context = call_args.kwargs.get("context", call_args[1].get("context")) + + assert context["error"] == "Please fix the errors below." + assert "errors" in context + assert context["errors"] is not None + assert "nws_cadence_s" in context["errors"] + assert "10" in context["errors"]["nws_cadence_s"] # Should mention min value + + # Verify adapters have correct shape (with fields) + assert "adapters" in context + for adapter in context["adapters"]: + assert "name" in adapter + assert "display_name" in adapter + assert "enabled" in adapter + assert "cadence_s" in adapter + assert "settings" in adapter + assert "fields" in adapter + + # Verify no DB execute was called (no writes) + mock_conn.execute.assert_not_called() + + @pytest.mark.asyncio + async def test_invalid_region_bounds_shows_pydantic_error(self): + """POST /setup/adapters with inverted region bounds shows RegionConfig error.""" + from central.gui.routes import setup_adapters_submit + + mock_request = MagicMock() + mock_request.cookies = {} + mock_request.state = MagicMock() + + # Mock form data with inverted region (south > north) + mock_form = MagicMock() + mock_form.get.side_effect = lambda k, d="": { + "csrf_token": "test_csrf_token", + "nws_cadence_s": "300", + "nws_contact_email": "test@example.com", + "nws_region_north": "10.0", # Invalid: north < south + "nws_region_south": "20.0", + "nws_region_east": "-102.0", + "nws_region_west": "-124.0", + "firms_cadence_s": "300", + "firms_region_north": "49.0", + "firms_region_south": "31.0", + "firms_region_east": "-102.0", + "firms_region_west": "-124.0", + "usgs_quake_cadence_s": "300", + "usgs_quake_feed": "all_hour", + "usgs_quake_region_north": "49.0", + "usgs_quake_region_south": "31.0", + "usgs_quake_region_east": "-102.0", + "usgs_quake_region_west": "-124.0", + }.get(k, d) + mock_form.getlist.side_effect = lambda k: { + "firms_satellites": ["VIIRS_SNPP_NRT"], + }.get(k, []) + mock_form.__contains__ = lambda self, k: False + + mock_request.form = AsyncMock(return_value=mock_form) + + mock_state = MagicMock() + mock_state.operator = {"username": "test", "password_hash": "hash"} + mock_state.api_keys = [] + mock_state.adapters = None + mock_state.system = None + + mock_pool = MagicMock() + mock_conn = MagicMock() + mock_conn.fetch = AsyncMock(return_value=[ + {"name": "nws", "enabled": False, "cadence_s": 300, "settings": {}}, + {"name": "firms", "enabled": False, "cadence_s": 300, "settings": {}}, + {"name": "usgs_quake", "enabled": False, "cadence_s": 300, "settings": {}}, + ]) + mock_conn.__aenter__ = AsyncMock(return_value=mock_conn) + mock_conn.__aexit__ = AsyncMock() + mock_pool.acquire = MagicMock(return_value=mock_conn) + + mock_templates = MagicMock() + mock_response = MagicMock() + mock_response.status_code = 200 + mock_templates.TemplateResponse.return_value = mock_response + + with patch("central.gui.routes._get_templates", return_value=mock_templates): + with patch("central.gui.routes.get_pool", return_value=mock_pool): + with patch("central.gui.routes.get_settings") as mock_settings: + mock_settings.return_value.csrf_secret = "testsecret12345678901234567890ab" + with patch("central.gui.routes.validate_pre_auth_csrf", return_value=True): + with patch("central.gui.wizard.get_wizard_state", return_value=mock_state): + with patch("central.gui.routes.reuse_or_generate_pre_auth_csrf", return_value=("csrf", None)): + result = await setup_adapters_submit(mock_request) + + assert result.status_code == 200 + + call_args = mock_templates.TemplateResponse.call_args + context = call_args.kwargs.get("context", call_args[1].get("context")) + + assert context["errors"] is not None + assert "nws_region" in context["errors"] + # Error should come from RegionConfig validator, mentioning bounds + assert "north" in context["errors"]["nws_region"].lower() or "south" in context["errors"]["nws_region"].lower() +