From 616452c1dfbf31e5a8ae53e6939baeee6a178d65 Mon Sep 17 00:00:00 2001 From: zvx-echo6 Date: Sun, 17 May 2026 19:41:58 -0600 Subject: [PATCH] fix(gui): handle CSRF errors on wizard paths Update csrf_exception_handler to re-render wizard forms with error message instead of redirecting to /login when CSRF validation fails. - /setup/operator: re-render with error - /setup/system: re-render with current system values + error - /setup/keys: re-render with current keys list + error - /setup/adapters: re-render with current adapter config + error - /setup/finish: re-render with summary data + error - /setup: redirect to /setup (middleware routes to appropriate step) Add error display to setup_keys.html and setup_finish.html templates. Add 7 new CSRF handler tests for wizard paths. Co-Authored-By: Claude Opus 4.5 --- src/central/gui/__init__.py | 178 +++++++++++++++++++- src/central/gui/templates/setup_finish.html | 4 + src/central/gui/templates/setup_keys.html | 4 + tests/test_csrf_handler.py | 163 ++++++++++++++++++ 4 files changed, 343 insertions(+), 6 deletions(-) diff --git a/src/central/gui/__init__.py b/src/central/gui/__init__.py index 20d79aa..7501a10 100644 --- a/src/central/gui/__init__.py +++ b/src/central/gui/__init__.py @@ -143,39 +143,205 @@ def _create_app() -> FastAPI: @app.exception_handler(CsrfProtectError) async def csrf_exception_handler(request, exc: CsrfProtectError): from fastapi_csrf_protect import CsrfProtect - + from central.gui.db import get_pool + csrf_protect = CsrfProtect() csrf_token, signed_token = csrf_protect.generate_csrf_tokens() - + error_msg = "Your session expired. Please try again." + if request.url.path == "/login": response = templates.TemplateResponse( request=request, name="login.html", - context={"csrf_token": csrf_token, "error": "Your session expired. Please try again."}, + context={"csrf_token": csrf_token, "error": error_msg}, ) csrf_protect.set_csrf_cookie(signed_token, response) return response + elif request.url.path == "/setup": + # /setup is a redirect path now, not a form + return RedirectResponse("/setup", status_code=302) + + elif request.url.path == "/setup/operator": response = templates.TemplateResponse( request=request, - name="setup.html", - context={"csrf_token": csrf_token, "error": "Your session expired. Please try again."}, + name="setup_operator.html", + context={"csrf_token": csrf_token, "error": error_msg, "form_data": None}, ) csrf_protect.set_csrf_cookie(signed_token, response) return response + + elif request.url.path == "/setup/system": + pool = get_pool() + system = { + "map_tile_url": "https://tile.openstreetmap.org/{z}/{x}/{y}.png", + "map_attribution": "© OpenStreetMap contributors", + } + if pool: + try: + async with pool.acquire() as conn: + row = await conn.fetchrow( + "SELECT map_tile_url, map_attribution FROM config.system WHERE id = true" + ) + if row: + system = { + "map_tile_url": row["map_tile_url"], + "map_attribution": row["map_attribution"], + } + except Exception: + pass + response = templates.TemplateResponse( + request=request, + name="setup_system.html", + context={ + "csrf_token": csrf_token, + "error": error_msg, + "errors": None, + "form_data": None, + "system": system, + }, + ) + csrf_protect.set_csrf_cookie(signed_token, response) + return response + + elif request.url.path == "/setup/keys": + pool = get_pool() + keys = [] + if pool: + try: + async with pool.acquire() as conn: + rows = await conn.fetch( + "SELECT alias, created_at FROM config.api_keys ORDER BY alias" + ) + keys = [{"alias": row["alias"], "created_at": row["created_at"]} for row in rows] + except Exception: + pass + response = templates.TemplateResponse( + request=request, + name="setup_keys.html", + context={ + "csrf_token": csrf_token, + "keys": keys, + "errors": None, + "form_data": None, + "success": None, + "error": error_msg, + }, + ) + csrf_protect.set_csrf_cookie(signed_token, response) + return response + + elif request.url.path == "/setup/adapters": + pool = get_pool() + adapters = [] + api_keys = [] + tile_url = "https://tile.openstreetmap.org/{z}/{x}/{y}.png" + tile_attribution = "© OpenStreetMap contributors" + if pool: + try: + async with pool.acquire() as conn: + rows = await conn.fetch( + "SELECT name, enabled, cadence_s, settings FROM config.adapters ORDER BY name" + ) + for row in rows: + settings = row["settings"] or {} + adapters.append({ + "name": row["name"], + "enabled": row["enabled"], + "cadence_s": row["cadence_s"], + "settings": settings, + }) + key_rows = await conn.fetch( + "SELECT alias FROM config.api_keys ORDER BY alias" + ) + api_keys = [{"alias": k["alias"]} for k in key_rows] + sys_row = await conn.fetchrow( + "SELECT map_tile_url, map_attribution FROM config.system WHERE id = true" + ) + if sys_row: + tile_url = sys_row["map_tile_url"] + tile_attribution = sys_row["map_attribution"] + except Exception: + pass + + # Import helper functions for valid values + from central.gui.routes import _get_valid_satellites, _get_valid_feeds + + response = templates.TemplateResponse( + request=request, + name="setup_adapters.html", + context={ + "csrf_token": csrf_token, + "adapters": adapters, + "api_keys": api_keys, + "valid_satellites": _get_valid_satellites(), + "valid_feeds": sorted(_get_valid_feeds()), + "tile_url": tile_url, + "tile_attribution": tile_attribution, + "error": error_msg, + "errors": None, + "form_data": None, + }, + ) + csrf_protect.set_csrf_cookie(signed_token, response) + return response + + elif request.url.path == "/setup/finish": + pool = get_pool() + operator_count = 0 + key_count = 0 + system = {"map_tile_url": ""} + adapters = [] + if pool: + try: + async with pool.acquire() as conn: + operator_count = await conn.fetchval("SELECT COUNT(*) FROM config.operators") + key_count = await conn.fetchval("SELECT COUNT(*) FROM config.api_keys") + sys_row = await conn.fetchrow( + "SELECT map_tile_url FROM config.system WHERE id = true" + ) + if sys_row: + system = {"map_tile_url": sys_row["map_tile_url"]} + rows = await conn.fetch( + "SELECT name, enabled, cadence_s FROM config.adapters ORDER BY name" + ) + adapters = [ + {"name": row["name"], "enabled": row["enabled"], "cadence_s": row["cadence_s"]} + for row in rows + ] + except Exception: + pass + response = templates.TemplateResponse( + request=request, + name="setup_finish.html", + context={ + "csrf_token": csrf_token, + "operator_count": operator_count, + "key_count": key_count, + "system": system, + "adapters": adapters, + "error": error_msg, + }, + ) + csrf_protect.set_csrf_cookie(signed_token, response) + return response + elif request.url.path == "/logout": return RedirectResponse("/login", status_code=302) + elif request.url.path == "/change-password": response = templates.TemplateResponse( request=request, name="change_password.html", - context={"csrf_token": csrf_token, "error": "Your session expired. Please try again."}, + context={"csrf_token": csrf_token, "error": error_msg}, ) csrf_protect.set_csrf_cookie(signed_token, response) return response + elif request.url.path.startswith("/adapters/"): # Redirect back to adapters list return RedirectResponse("/adapters", status_code=302) + else: # Fallback: redirect to login return RedirectResponse("/login", status_code=302) diff --git a/src/central/gui/templates/setup_finish.html b/src/central/gui/templates/setup_finish.html index 7e0ac7e..bc250bc 100644 --- a/src/central/gui/templates/setup_finish.html +++ b/src/central/gui/templates/setup_finish.html @@ -13,6 +13,10 @@

Review your configuration and finish the setup wizard.

+ {% if error %} +

{{ error }}

+ {% endif %} +

Summary

diff --git a/src/central/gui/templates/setup_keys.html b/src/central/gui/templates/setup_keys.html index 28457cc..4c3f125 100644 --- a/src/central/gui/templates/setup_keys.html +++ b/src/central/gui/templates/setup_keys.html @@ -13,6 +13,10 @@

Add API keys for adapters that require external service credentials (e.g., FIRMS).

+ {% if error %} +

{{ error }}

+ {% endif %} + {% if success %}

{{ success }}

{% endif %} diff --git a/tests/test_csrf_handler.py b/tests/test_csrf_handler.py index 58456e3..85d3089 100644 --- a/tests/test_csrf_handler.py +++ b/tests/test_csrf_handler.py @@ -107,3 +107,166 @@ class TestCsrfHandlerNoTraceback: assert handler is not None assert inspect.iscoroutinefunction(handler) + + +class TestCsrfHandlerWizardPaths: + """Test CSRF exception handler for wizard paths.""" + + @pytest.mark.asyncio + async def test_setup_operator_csrf_error_renders_form_with_error(self): + """CSRF error on /setup/operator re-renders form with error message.""" + from central.gui import _create_app + from fastapi_csrf_protect.exceptions import TokenValidationError + from fastapi.responses import HTMLResponse + + app = _create_app() + from fastapi_csrf_protect.exceptions import CsrfProtectError + handler = app.exception_handlers.get(CsrfProtectError) + + mock_request = MagicMock() + mock_request.url.path = "/setup/operator" + + exc = TokenValidationError("Invalid token") + + result = await handler(mock_request, exc) + + # Should be HTML response, not redirect + assert hasattr(result, "body") + assert result.status_code == 200 + body = result.body.decode() if hasattr(result.body, "decode") else str(result.body) + assert "session expired" in body.lower() + + @pytest.mark.asyncio + async def test_setup_system_csrf_error_renders_form_with_error(self): + """CSRF error on /setup/system re-renders form with error message.""" + from central.gui import _create_app + from fastapi_csrf_protect.exceptions import TokenValidationError + + app = _create_app() + from fastapi_csrf_protect.exceptions import CsrfProtectError + handler = app.exception_handlers.get(CsrfProtectError) + + mock_request = MagicMock() + mock_request.url.path = "/setup/system" + + exc = TokenValidationError("Invalid token") + + with patch("central.gui.db.get_pool", return_value=None): + result = await handler(mock_request, exc) + + assert hasattr(result, "body") + assert result.status_code == 200 + body = result.body.decode() if hasattr(result.body, "decode") else str(result.body) + assert "session expired" in body.lower() + + @pytest.mark.asyncio + async def test_setup_keys_csrf_error_renders_form_with_error(self): + """CSRF error on /setup/keys re-renders form with error message.""" + from central.gui import _create_app + from fastapi_csrf_protect.exceptions import TokenValidationError + + app = _create_app() + from fastapi_csrf_protect.exceptions import CsrfProtectError + handler = app.exception_handlers.get(CsrfProtectError) + + mock_request = MagicMock() + mock_request.url.path = "/setup/keys" + + exc = TokenValidationError("Invalid token") + + with patch("central.gui.db.get_pool", return_value=None): + result = await handler(mock_request, exc) + + assert hasattr(result, "body") + assert result.status_code == 200 + body = result.body.decode() if hasattr(result.body, "decode") else str(result.body) + assert "session expired" in body.lower() + + @pytest.mark.asyncio + async def test_setup_adapters_csrf_error_renders_form_with_error(self): + """CSRF error on /setup/adapters re-renders form with error message.""" + from central.gui import _create_app + from fastapi_csrf_protect.exceptions import TokenValidationError + + app = _create_app() + from fastapi_csrf_protect.exceptions import CsrfProtectError + handler = app.exception_handlers.get(CsrfProtectError) + + mock_request = MagicMock() + mock_request.url.path = "/setup/adapters" + + exc = TokenValidationError("Invalid token") + + with patch("central.gui.db.get_pool", return_value=None): + result = await handler(mock_request, exc) + + assert hasattr(result, "body") + assert result.status_code == 200 + body = result.body.decode() if hasattr(result.body, "decode") else str(result.body) + assert "session expired" in body.lower() + + @pytest.mark.asyncio + async def test_setup_finish_csrf_error_renders_form_with_error(self): + """CSRF error on /setup/finish re-renders form with error message.""" + from central.gui import _create_app + from fastapi_csrf_protect.exceptions import TokenValidationError + + app = _create_app() + from fastapi_csrf_protect.exceptions import CsrfProtectError + handler = app.exception_handlers.get(CsrfProtectError) + + mock_request = MagicMock() + mock_request.url.path = "/setup/finish" + + exc = TokenValidationError("Invalid token") + + with patch("central.gui.db.get_pool", return_value=None): + result = await handler(mock_request, exc) + + assert hasattr(result, "body") + assert result.status_code == 200 + body = result.body.decode() if hasattr(result.body, "decode") else str(result.body) + assert "session expired" in body.lower() + + @pytest.mark.asyncio + async def test_setup_base_csrf_error_redirects_to_setup(self): + """CSRF error on /setup redirects to /setup (middleware routes to step).""" + from central.gui import _create_app + from fastapi_csrf_protect.exceptions import TokenValidationError + from fastapi.responses import RedirectResponse + + app = _create_app() + from fastapi_csrf_protect.exceptions import CsrfProtectError + handler = app.exception_handlers.get(CsrfProtectError) + + mock_request = MagicMock() + mock_request.url.path = "/setup" + + exc = TokenValidationError("Invalid token") + + result = await handler(mock_request, exc) + + assert isinstance(result, RedirectResponse) + assert result.status_code == 302 + + @pytest.mark.asyncio + async def test_login_csrf_error_still_works(self): + """CSRF error on /login still renders login form with error (regression test).""" + from central.gui import _create_app + from fastapi_csrf_protect.exceptions import TokenValidationError + + app = _create_app() + from fastapi_csrf_protect.exceptions import CsrfProtectError + handler = app.exception_handlers.get(CsrfProtectError) + + mock_request = MagicMock() + mock_request.url.path = "/login" + + exc = TokenValidationError("Invalid token") + + result = await handler(mock_request, exc) + + assert hasattr(result, "body") + assert result.status_code == 200 + body = result.body.decode() if hasattr(result.body, "decode") else str(result.body) + assert "session expired" in body.lower()