fix(gui): handle CSRF errors on wizard paths

Update csrf_exception_handler to re-render wizard forms with error
message instead of redirecting to /login when CSRF validation fails.

- /setup/operator: re-render with error
- /setup/system: re-render with current system values + error
- /setup/keys: re-render with current keys list + error
- /setup/adapters: re-render with current adapter config + error
- /setup/finish: re-render with summary data + error
- /setup: redirect to /setup (middleware routes to appropriate step)

Add error display to setup_keys.html and setup_finish.html templates.
Add 7 new CSRF handler tests for wizard paths.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
zvx-echo6 2026-05-17 19:41:58 -06:00
commit 616452c1df
4 changed files with 343 additions and 6 deletions

View file

@ -143,39 +143,205 @@ def _create_app() -> FastAPI:
@app.exception_handler(CsrfProtectError) @app.exception_handler(CsrfProtectError)
async def csrf_exception_handler(request, exc: CsrfProtectError): async def csrf_exception_handler(request, exc: CsrfProtectError):
from fastapi_csrf_protect import CsrfProtect from fastapi_csrf_protect import CsrfProtect
from central.gui.db import get_pool
csrf_protect = CsrfProtect() csrf_protect = CsrfProtect()
csrf_token, signed_token = csrf_protect.generate_csrf_tokens() csrf_token, signed_token = csrf_protect.generate_csrf_tokens()
error_msg = "Your session expired. Please try again."
if request.url.path == "/login": if request.url.path == "/login":
response = templates.TemplateResponse( response = templates.TemplateResponse(
request=request, request=request,
name="login.html", name="login.html",
context={"csrf_token": csrf_token, "error": "Your session expired. Please try again."}, context={"csrf_token": csrf_token, "error": error_msg},
) )
csrf_protect.set_csrf_cookie(signed_token, response) csrf_protect.set_csrf_cookie(signed_token, response)
return response return response
elif request.url.path == "/setup": elif request.url.path == "/setup":
# /setup is a redirect path now, not a form
return RedirectResponse("/setup", status_code=302)
elif request.url.path == "/setup/operator":
response = templates.TemplateResponse( response = templates.TemplateResponse(
request=request, request=request,
name="setup.html", name="setup_operator.html",
context={"csrf_token": csrf_token, "error": "Your session expired. Please try again."}, context={"csrf_token": csrf_token, "error": error_msg, "form_data": None},
) )
csrf_protect.set_csrf_cookie(signed_token, response) csrf_protect.set_csrf_cookie(signed_token, response)
return response return response
elif request.url.path == "/setup/system":
pool = get_pool()
system = {
"map_tile_url": "https://tile.openstreetmap.org/{z}/{x}/{y}.png",
"map_attribution": "&copy; OpenStreetMap contributors",
}
if pool:
try:
async with pool.acquire() as conn:
row = await conn.fetchrow(
"SELECT map_tile_url, map_attribution FROM config.system WHERE id = true"
)
if row:
system = {
"map_tile_url": row["map_tile_url"],
"map_attribution": row["map_attribution"],
}
except Exception:
pass
response = templates.TemplateResponse(
request=request,
name="setup_system.html",
context={
"csrf_token": csrf_token,
"error": error_msg,
"errors": None,
"form_data": None,
"system": system,
},
)
csrf_protect.set_csrf_cookie(signed_token, response)
return response
elif request.url.path == "/setup/keys":
pool = get_pool()
keys = []
if pool:
try:
async with pool.acquire() as conn:
rows = await conn.fetch(
"SELECT alias, created_at FROM config.api_keys ORDER BY alias"
)
keys = [{"alias": row["alias"], "created_at": row["created_at"]} for row in rows]
except Exception:
pass
response = templates.TemplateResponse(
request=request,
name="setup_keys.html",
context={
"csrf_token": csrf_token,
"keys": keys,
"errors": None,
"form_data": None,
"success": None,
"error": error_msg,
},
)
csrf_protect.set_csrf_cookie(signed_token, response)
return response
elif request.url.path == "/setup/adapters":
pool = get_pool()
adapters = []
api_keys = []
tile_url = "https://tile.openstreetmap.org/{z}/{x}/{y}.png"
tile_attribution = "&copy; OpenStreetMap contributors"
if pool:
try:
async with pool.acquire() as conn:
rows = await conn.fetch(
"SELECT name, enabled, cadence_s, settings FROM config.adapters ORDER BY name"
)
for row in rows:
settings = row["settings"] or {}
adapters.append({
"name": row["name"],
"enabled": row["enabled"],
"cadence_s": row["cadence_s"],
"settings": settings,
})
key_rows = await conn.fetch(
"SELECT alias FROM config.api_keys ORDER BY alias"
)
api_keys = [{"alias": k["alias"]} for k in key_rows]
sys_row = await conn.fetchrow(
"SELECT map_tile_url, map_attribution FROM config.system WHERE id = true"
)
if sys_row:
tile_url = sys_row["map_tile_url"]
tile_attribution = sys_row["map_attribution"]
except Exception:
pass
# Import helper functions for valid values
from central.gui.routes import _get_valid_satellites, _get_valid_feeds
response = templates.TemplateResponse(
request=request,
name="setup_adapters.html",
context={
"csrf_token": csrf_token,
"adapters": adapters,
"api_keys": api_keys,
"valid_satellites": _get_valid_satellites(),
"valid_feeds": sorted(_get_valid_feeds()),
"tile_url": tile_url,
"tile_attribution": tile_attribution,
"error": error_msg,
"errors": None,
"form_data": None,
},
)
csrf_protect.set_csrf_cookie(signed_token, response)
return response
elif request.url.path == "/setup/finish":
pool = get_pool()
operator_count = 0
key_count = 0
system = {"map_tile_url": ""}
adapters = []
if pool:
try:
async with pool.acquire() as conn:
operator_count = await conn.fetchval("SELECT COUNT(*) FROM config.operators")
key_count = await conn.fetchval("SELECT COUNT(*) FROM config.api_keys")
sys_row = await conn.fetchrow(
"SELECT map_tile_url FROM config.system WHERE id = true"
)
if sys_row:
system = {"map_tile_url": sys_row["map_tile_url"]}
rows = await conn.fetch(
"SELECT name, enabled, cadence_s FROM config.adapters ORDER BY name"
)
adapters = [
{"name": row["name"], "enabled": row["enabled"], "cadence_s": row["cadence_s"]}
for row in rows
]
except Exception:
pass
response = templates.TemplateResponse(
request=request,
name="setup_finish.html",
context={
"csrf_token": csrf_token,
"operator_count": operator_count,
"key_count": key_count,
"system": system,
"adapters": adapters,
"error": error_msg,
},
)
csrf_protect.set_csrf_cookie(signed_token, response)
return response
elif request.url.path == "/logout": elif request.url.path == "/logout":
return RedirectResponse("/login", status_code=302) return RedirectResponse("/login", status_code=302)
elif request.url.path == "/change-password": elif request.url.path == "/change-password":
response = templates.TemplateResponse( response = templates.TemplateResponse(
request=request, request=request,
name="change_password.html", name="change_password.html",
context={"csrf_token": csrf_token, "error": "Your session expired. Please try again."}, context={"csrf_token": csrf_token, "error": error_msg},
) )
csrf_protect.set_csrf_cookie(signed_token, response) csrf_protect.set_csrf_cookie(signed_token, response)
return response return response
elif request.url.path.startswith("/adapters/"): elif request.url.path.startswith("/adapters/"):
# Redirect back to adapters list # Redirect back to adapters list
return RedirectResponse("/adapters", status_code=302) return RedirectResponse("/adapters", status_code=302)
else: else:
# Fallback: redirect to login # Fallback: redirect to login
return RedirectResponse("/login", status_code=302) return RedirectResponse("/login", status_code=302)

View file

@ -13,6 +13,10 @@
<p>Review your configuration and finish the setup wizard.</p> <p>Review your configuration and finish the setup wizard.</p>
</header> </header>
{% if error %}
<p style="color: var(--pico-color-red-500);">{{ error }}</p>
{% endif %}
<h2>Summary</h2> <h2>Summary</h2>
<table> <table>

View file

@ -13,6 +13,10 @@
<p>Add API keys for adapters that require external service credentials (e.g., FIRMS).</p> <p>Add API keys for adapters that require external service credentials (e.g., FIRMS).</p>
</header> </header>
{% if error %}
<p style="color: var(--pico-color-red-500);">{{ error }}</p>
{% endif %}
{% if success %} {% if success %}
<p style="color: var(--pico-color-green-500);">{{ success }}</p> <p style="color: var(--pico-color-green-500);">{{ success }}</p>
{% endif %} {% endif %}

View file

@ -107,3 +107,166 @@ class TestCsrfHandlerNoTraceback:
assert handler is not None assert handler is not None
assert inspect.iscoroutinefunction(handler) assert inspect.iscoroutinefunction(handler)
class TestCsrfHandlerWizardPaths:
"""Test CSRF exception handler for wizard paths."""
@pytest.mark.asyncio
async def test_setup_operator_csrf_error_renders_form_with_error(self):
"""CSRF error on /setup/operator re-renders form with error message."""
from central.gui import _create_app
from fastapi_csrf_protect.exceptions import TokenValidationError
from fastapi.responses import HTMLResponse
app = _create_app()
from fastapi_csrf_protect.exceptions import CsrfProtectError
handler = app.exception_handlers.get(CsrfProtectError)
mock_request = MagicMock()
mock_request.url.path = "/setup/operator"
exc = TokenValidationError("Invalid token")
result = await handler(mock_request, exc)
# Should be HTML response, not redirect
assert hasattr(result, "body")
assert result.status_code == 200
body = result.body.decode() if hasattr(result.body, "decode") else str(result.body)
assert "session expired" in body.lower()
@pytest.mark.asyncio
async def test_setup_system_csrf_error_renders_form_with_error(self):
"""CSRF error on /setup/system re-renders form with error message."""
from central.gui import _create_app
from fastapi_csrf_protect.exceptions import TokenValidationError
app = _create_app()
from fastapi_csrf_protect.exceptions import CsrfProtectError
handler = app.exception_handlers.get(CsrfProtectError)
mock_request = MagicMock()
mock_request.url.path = "/setup/system"
exc = TokenValidationError("Invalid token")
with patch("central.gui.db.get_pool", return_value=None):
result = await handler(mock_request, exc)
assert hasattr(result, "body")
assert result.status_code == 200
body = result.body.decode() if hasattr(result.body, "decode") else str(result.body)
assert "session expired" in body.lower()
@pytest.mark.asyncio
async def test_setup_keys_csrf_error_renders_form_with_error(self):
"""CSRF error on /setup/keys re-renders form with error message."""
from central.gui import _create_app
from fastapi_csrf_protect.exceptions import TokenValidationError
app = _create_app()
from fastapi_csrf_protect.exceptions import CsrfProtectError
handler = app.exception_handlers.get(CsrfProtectError)
mock_request = MagicMock()
mock_request.url.path = "/setup/keys"
exc = TokenValidationError("Invalid token")
with patch("central.gui.db.get_pool", return_value=None):
result = await handler(mock_request, exc)
assert hasattr(result, "body")
assert result.status_code == 200
body = result.body.decode() if hasattr(result.body, "decode") else str(result.body)
assert "session expired" in body.lower()
@pytest.mark.asyncio
async def test_setup_adapters_csrf_error_renders_form_with_error(self):
"""CSRF error on /setup/adapters re-renders form with error message."""
from central.gui import _create_app
from fastapi_csrf_protect.exceptions import TokenValidationError
app = _create_app()
from fastapi_csrf_protect.exceptions import CsrfProtectError
handler = app.exception_handlers.get(CsrfProtectError)
mock_request = MagicMock()
mock_request.url.path = "/setup/adapters"
exc = TokenValidationError("Invalid token")
with patch("central.gui.db.get_pool", return_value=None):
result = await handler(mock_request, exc)
assert hasattr(result, "body")
assert result.status_code == 200
body = result.body.decode() if hasattr(result.body, "decode") else str(result.body)
assert "session expired" in body.lower()
@pytest.mark.asyncio
async def test_setup_finish_csrf_error_renders_form_with_error(self):
"""CSRF error on /setup/finish re-renders form with error message."""
from central.gui import _create_app
from fastapi_csrf_protect.exceptions import TokenValidationError
app = _create_app()
from fastapi_csrf_protect.exceptions import CsrfProtectError
handler = app.exception_handlers.get(CsrfProtectError)
mock_request = MagicMock()
mock_request.url.path = "/setup/finish"
exc = TokenValidationError("Invalid token")
with patch("central.gui.db.get_pool", return_value=None):
result = await handler(mock_request, exc)
assert hasattr(result, "body")
assert result.status_code == 200
body = result.body.decode() if hasattr(result.body, "decode") else str(result.body)
assert "session expired" in body.lower()
@pytest.mark.asyncio
async def test_setup_base_csrf_error_redirects_to_setup(self):
"""CSRF error on /setup redirects to /setup (middleware routes to step)."""
from central.gui import _create_app
from fastapi_csrf_protect.exceptions import TokenValidationError
from fastapi.responses import RedirectResponse
app = _create_app()
from fastapi_csrf_protect.exceptions import CsrfProtectError
handler = app.exception_handlers.get(CsrfProtectError)
mock_request = MagicMock()
mock_request.url.path = "/setup"
exc = TokenValidationError("Invalid token")
result = await handler(mock_request, exc)
assert isinstance(result, RedirectResponse)
assert result.status_code == 302
@pytest.mark.asyncio
async def test_login_csrf_error_still_works(self):
"""CSRF error on /login still renders login form with error (regression test)."""
from central.gui import _create_app
from fastapi_csrf_protect.exceptions import TokenValidationError
app = _create_app()
from fastapi_csrf_protect.exceptions import CsrfProtectError
handler = app.exception_handlers.get(CsrfProtectError)
mock_request = MagicMock()
mock_request.url.path = "/login"
exc = TokenValidationError("Invalid token")
result = await handler(mock_request, exc)
assert hasattr(result, "body")
assert result.status_code == 200
body = result.body.decode() if hasattr(result.body, "decode") else str(result.body)
assert "session expired" in body.lower()