diff --git a/src/central/gui/__init__.py b/src/central/gui/__init__.py
index 20d79aa..7501a10 100644
--- a/src/central/gui/__init__.py
+++ b/src/central/gui/__init__.py
@@ -143,39 +143,205 @@ def _create_app() -> FastAPI:
@app.exception_handler(CsrfProtectError)
async def csrf_exception_handler(request, exc: CsrfProtectError):
from fastapi_csrf_protect import CsrfProtect
-
+ from central.gui.db import get_pool
+
csrf_protect = CsrfProtect()
csrf_token, signed_token = csrf_protect.generate_csrf_tokens()
-
+ error_msg = "Your session expired. Please try again."
+
if request.url.path == "/login":
response = templates.TemplateResponse(
request=request,
name="login.html",
- context={"csrf_token": csrf_token, "error": "Your session expired. Please try again."},
+ context={"csrf_token": csrf_token, "error": error_msg},
)
csrf_protect.set_csrf_cookie(signed_token, response)
return response
+
elif request.url.path == "/setup":
+ # /setup is a redirect path now, not a form
+ return RedirectResponse("/setup", status_code=302)
+
+ elif request.url.path == "/setup/operator":
response = templates.TemplateResponse(
request=request,
- name="setup.html",
- context={"csrf_token": csrf_token, "error": "Your session expired. Please try again."},
+ name="setup_operator.html",
+ context={"csrf_token": csrf_token, "error": error_msg, "form_data": None},
)
csrf_protect.set_csrf_cookie(signed_token, response)
return response
+
+ elif request.url.path == "/setup/system":
+ pool = get_pool()
+ system = {
+ "map_tile_url": "https://tile.openstreetmap.org/{z}/{x}/{y}.png",
+ "map_attribution": "© OpenStreetMap contributors",
+ }
+ if pool:
+ try:
+ async with pool.acquire() as conn:
+ row = await conn.fetchrow(
+ "SELECT map_tile_url, map_attribution FROM config.system WHERE id = true"
+ )
+ if row:
+ system = {
+ "map_tile_url": row["map_tile_url"],
+ "map_attribution": row["map_attribution"],
+ }
+ except Exception:
+ pass
+ response = templates.TemplateResponse(
+ request=request,
+ name="setup_system.html",
+ context={
+ "csrf_token": csrf_token,
+ "error": error_msg,
+ "errors": None,
+ "form_data": None,
+ "system": system,
+ },
+ )
+ csrf_protect.set_csrf_cookie(signed_token, response)
+ return response
+
+ elif request.url.path == "/setup/keys":
+ pool = get_pool()
+ keys = []
+ if pool:
+ try:
+ async with pool.acquire() as conn:
+ rows = await conn.fetch(
+ "SELECT alias, created_at FROM config.api_keys ORDER BY alias"
+ )
+ keys = [{"alias": row["alias"], "created_at": row["created_at"]} for row in rows]
+ except Exception:
+ pass
+ response = templates.TemplateResponse(
+ request=request,
+ name="setup_keys.html",
+ context={
+ "csrf_token": csrf_token,
+ "keys": keys,
+ "errors": None,
+ "form_data": None,
+ "success": None,
+ "error": error_msg,
+ },
+ )
+ csrf_protect.set_csrf_cookie(signed_token, response)
+ return response
+
+ elif request.url.path == "/setup/adapters":
+ pool = get_pool()
+ adapters = []
+ api_keys = []
+ tile_url = "https://tile.openstreetmap.org/{z}/{x}/{y}.png"
+ tile_attribution = "© OpenStreetMap contributors"
+ if pool:
+ try:
+ async with pool.acquire() as conn:
+ rows = await conn.fetch(
+ "SELECT name, enabled, cadence_s, settings FROM config.adapters ORDER BY name"
+ )
+ for row in rows:
+ settings = row["settings"] or {}
+ adapters.append({
+ "name": row["name"],
+ "enabled": row["enabled"],
+ "cadence_s": row["cadence_s"],
+ "settings": settings,
+ })
+ key_rows = await conn.fetch(
+ "SELECT alias FROM config.api_keys ORDER BY alias"
+ )
+ api_keys = [{"alias": k["alias"]} for k in key_rows]
+ sys_row = await conn.fetchrow(
+ "SELECT map_tile_url, map_attribution FROM config.system WHERE id = true"
+ )
+ if sys_row:
+ tile_url = sys_row["map_tile_url"]
+ tile_attribution = sys_row["map_attribution"]
+ except Exception:
+ pass
+
+ # Import helper functions for valid values
+ from central.gui.routes import _get_valid_satellites, _get_valid_feeds
+
+ response = templates.TemplateResponse(
+ request=request,
+ name="setup_adapters.html",
+ context={
+ "csrf_token": csrf_token,
+ "adapters": adapters,
+ "api_keys": api_keys,
+ "valid_satellites": _get_valid_satellites(),
+ "valid_feeds": sorted(_get_valid_feeds()),
+ "tile_url": tile_url,
+ "tile_attribution": tile_attribution,
+ "error": error_msg,
+ "errors": None,
+ "form_data": None,
+ },
+ )
+ csrf_protect.set_csrf_cookie(signed_token, response)
+ return response
+
+ elif request.url.path == "/setup/finish":
+ pool = get_pool()
+ operator_count = 0
+ key_count = 0
+ system = {"map_tile_url": ""}
+ adapters = []
+ if pool:
+ try:
+ async with pool.acquire() as conn:
+ operator_count = await conn.fetchval("SELECT COUNT(*) FROM config.operators")
+ key_count = await conn.fetchval("SELECT COUNT(*) FROM config.api_keys")
+ sys_row = await conn.fetchrow(
+ "SELECT map_tile_url FROM config.system WHERE id = true"
+ )
+ if sys_row:
+ system = {"map_tile_url": sys_row["map_tile_url"]}
+ rows = await conn.fetch(
+ "SELECT name, enabled, cadence_s FROM config.adapters ORDER BY name"
+ )
+ adapters = [
+ {"name": row["name"], "enabled": row["enabled"], "cadence_s": row["cadence_s"]}
+ for row in rows
+ ]
+ except Exception:
+ pass
+ response = templates.TemplateResponse(
+ request=request,
+ name="setup_finish.html",
+ context={
+ "csrf_token": csrf_token,
+ "operator_count": operator_count,
+ "key_count": key_count,
+ "system": system,
+ "adapters": adapters,
+ "error": error_msg,
+ },
+ )
+ csrf_protect.set_csrf_cookie(signed_token, response)
+ return response
+
elif request.url.path == "/logout":
return RedirectResponse("/login", status_code=302)
+
elif request.url.path == "/change-password":
response = templates.TemplateResponse(
request=request,
name="change_password.html",
- context={"csrf_token": csrf_token, "error": "Your session expired. Please try again."},
+ context={"csrf_token": csrf_token, "error": error_msg},
)
csrf_protect.set_csrf_cookie(signed_token, response)
return response
+
elif request.url.path.startswith("/adapters/"):
# Redirect back to adapters list
return RedirectResponse("/adapters", status_code=302)
+
else:
# Fallback: redirect to login
return RedirectResponse("/login", status_code=302)
diff --git a/src/central/gui/templates/setup_finish.html b/src/central/gui/templates/setup_finish.html
index 7e0ac7e..bc250bc 100644
--- a/src/central/gui/templates/setup_finish.html
+++ b/src/central/gui/templates/setup_finish.html
@@ -13,6 +13,10 @@
Review your configuration and finish the setup wizard.
+ {% if error %}
+ {{ error }}
+ {% endif %}
+
Summary
diff --git a/src/central/gui/templates/setup_keys.html b/src/central/gui/templates/setup_keys.html
index 28457cc..4c3f125 100644
--- a/src/central/gui/templates/setup_keys.html
+++ b/src/central/gui/templates/setup_keys.html
@@ -13,6 +13,10 @@
Add API keys for adapters that require external service credentials (e.g., FIRMS).
+ {% if error %}
+ {{ error }}
+ {% endif %}
+
{% if success %}
{{ success }}
{% endif %}
diff --git a/tests/test_csrf_handler.py b/tests/test_csrf_handler.py
index 58456e3..85d3089 100644
--- a/tests/test_csrf_handler.py
+++ b/tests/test_csrf_handler.py
@@ -107,3 +107,166 @@ class TestCsrfHandlerNoTraceback:
assert handler is not None
assert inspect.iscoroutinefunction(handler)
+
+
+class TestCsrfHandlerWizardPaths:
+ """Test CSRF exception handler for wizard paths."""
+
+ @pytest.mark.asyncio
+ async def test_setup_operator_csrf_error_renders_form_with_error(self):
+ """CSRF error on /setup/operator re-renders form with error message."""
+ from central.gui import _create_app
+ from fastapi_csrf_protect.exceptions import TokenValidationError
+ from fastapi.responses import HTMLResponse
+
+ app = _create_app()
+ from fastapi_csrf_protect.exceptions import CsrfProtectError
+ handler = app.exception_handlers.get(CsrfProtectError)
+
+ mock_request = MagicMock()
+ mock_request.url.path = "/setup/operator"
+
+ exc = TokenValidationError("Invalid token")
+
+ result = await handler(mock_request, exc)
+
+ # Should be HTML response, not redirect
+ assert hasattr(result, "body")
+ assert result.status_code == 200
+ body = result.body.decode() if hasattr(result.body, "decode") else str(result.body)
+ assert "session expired" in body.lower()
+
+ @pytest.mark.asyncio
+ async def test_setup_system_csrf_error_renders_form_with_error(self):
+ """CSRF error on /setup/system re-renders form with error message."""
+ from central.gui import _create_app
+ from fastapi_csrf_protect.exceptions import TokenValidationError
+
+ app = _create_app()
+ from fastapi_csrf_protect.exceptions import CsrfProtectError
+ handler = app.exception_handlers.get(CsrfProtectError)
+
+ mock_request = MagicMock()
+ mock_request.url.path = "/setup/system"
+
+ exc = TokenValidationError("Invalid token")
+
+ with patch("central.gui.db.get_pool", return_value=None):
+ result = await handler(mock_request, exc)
+
+ assert hasattr(result, "body")
+ assert result.status_code == 200
+ body = result.body.decode() if hasattr(result.body, "decode") else str(result.body)
+ assert "session expired" in body.lower()
+
+ @pytest.mark.asyncio
+ async def test_setup_keys_csrf_error_renders_form_with_error(self):
+ """CSRF error on /setup/keys re-renders form with error message."""
+ from central.gui import _create_app
+ from fastapi_csrf_protect.exceptions import TokenValidationError
+
+ app = _create_app()
+ from fastapi_csrf_protect.exceptions import CsrfProtectError
+ handler = app.exception_handlers.get(CsrfProtectError)
+
+ mock_request = MagicMock()
+ mock_request.url.path = "/setup/keys"
+
+ exc = TokenValidationError("Invalid token")
+
+ with patch("central.gui.db.get_pool", return_value=None):
+ result = await handler(mock_request, exc)
+
+ assert hasattr(result, "body")
+ assert result.status_code == 200
+ body = result.body.decode() if hasattr(result.body, "decode") else str(result.body)
+ assert "session expired" in body.lower()
+
+ @pytest.mark.asyncio
+ async def test_setup_adapters_csrf_error_renders_form_with_error(self):
+ """CSRF error on /setup/adapters re-renders form with error message."""
+ from central.gui import _create_app
+ from fastapi_csrf_protect.exceptions import TokenValidationError
+
+ app = _create_app()
+ from fastapi_csrf_protect.exceptions import CsrfProtectError
+ handler = app.exception_handlers.get(CsrfProtectError)
+
+ mock_request = MagicMock()
+ mock_request.url.path = "/setup/adapters"
+
+ exc = TokenValidationError("Invalid token")
+
+ with patch("central.gui.db.get_pool", return_value=None):
+ result = await handler(mock_request, exc)
+
+ assert hasattr(result, "body")
+ assert result.status_code == 200
+ body = result.body.decode() if hasattr(result.body, "decode") else str(result.body)
+ assert "session expired" in body.lower()
+
+ @pytest.mark.asyncio
+ async def test_setup_finish_csrf_error_renders_form_with_error(self):
+ """CSRF error on /setup/finish re-renders form with error message."""
+ from central.gui import _create_app
+ from fastapi_csrf_protect.exceptions import TokenValidationError
+
+ app = _create_app()
+ from fastapi_csrf_protect.exceptions import CsrfProtectError
+ handler = app.exception_handlers.get(CsrfProtectError)
+
+ mock_request = MagicMock()
+ mock_request.url.path = "/setup/finish"
+
+ exc = TokenValidationError("Invalid token")
+
+ with patch("central.gui.db.get_pool", return_value=None):
+ result = await handler(mock_request, exc)
+
+ assert hasattr(result, "body")
+ assert result.status_code == 200
+ body = result.body.decode() if hasattr(result.body, "decode") else str(result.body)
+ assert "session expired" in body.lower()
+
+ @pytest.mark.asyncio
+ async def test_setup_base_csrf_error_redirects_to_setup(self):
+ """CSRF error on /setup redirects to /setup (middleware routes to step)."""
+ from central.gui import _create_app
+ from fastapi_csrf_protect.exceptions import TokenValidationError
+ from fastapi.responses import RedirectResponse
+
+ app = _create_app()
+ from fastapi_csrf_protect.exceptions import CsrfProtectError
+ handler = app.exception_handlers.get(CsrfProtectError)
+
+ mock_request = MagicMock()
+ mock_request.url.path = "/setup"
+
+ exc = TokenValidationError("Invalid token")
+
+ result = await handler(mock_request, exc)
+
+ assert isinstance(result, RedirectResponse)
+ assert result.status_code == 302
+
+ @pytest.mark.asyncio
+ async def test_login_csrf_error_still_works(self):
+ """CSRF error on /login still renders login form with error (regression test)."""
+ from central.gui import _create_app
+ from fastapi_csrf_protect.exceptions import TokenValidationError
+
+ app = _create_app()
+ from fastapi_csrf_protect.exceptions import CsrfProtectError
+ handler = app.exception_handlers.get(CsrfProtectError)
+
+ mock_request = MagicMock()
+ mock_request.url.path = "/login"
+
+ exc = TokenValidationError("Invalid token")
+
+ result = await handler(mock_request, exc)
+
+ assert hasattr(result, "body")
+ assert result.status_code == 200
+ body = result.body.decode() if hasattr(result.body, "decode") else str(result.body)
+ assert "session expired" in body.lower()