mirror of
https://github.com/zvx-echo6/central.git
synced 2026-05-21 18:14:44 +02:00
fix(gui): handle CSRF errors on wizard paths
Update csrf_exception_handler to re-render wizard forms with error message instead of redirecting to /login when CSRF validation fails. - /setup/operator: re-render with error - /setup/system: re-render with current system values + error - /setup/keys: re-render with current keys list + error - /setup/adapters: re-render with current adapter config + error - /setup/finish: re-render with summary data + error - /setup: redirect to /setup (middleware routes to appropriate step) Add error display to setup_keys.html and setup_finish.html templates. Add 7 new CSRF handler tests for wizard paths. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
parent
62116ca6a4
commit
616452c1df
4 changed files with 343 additions and 6 deletions
|
|
@ -143,39 +143,205 @@ def _create_app() -> FastAPI:
|
|||
@app.exception_handler(CsrfProtectError)
|
||||
async def csrf_exception_handler(request, exc: CsrfProtectError):
|
||||
from fastapi_csrf_protect import CsrfProtect
|
||||
|
||||
from central.gui.db import get_pool
|
||||
|
||||
csrf_protect = CsrfProtect()
|
||||
csrf_token, signed_token = csrf_protect.generate_csrf_tokens()
|
||||
|
||||
error_msg = "Your session expired. Please try again."
|
||||
|
||||
if request.url.path == "/login":
|
||||
response = templates.TemplateResponse(
|
||||
request=request,
|
||||
name="login.html",
|
||||
context={"csrf_token": csrf_token, "error": "Your session expired. Please try again."},
|
||||
context={"csrf_token": csrf_token, "error": error_msg},
|
||||
)
|
||||
csrf_protect.set_csrf_cookie(signed_token, response)
|
||||
return response
|
||||
|
||||
elif request.url.path == "/setup":
|
||||
# /setup is a redirect path now, not a form
|
||||
return RedirectResponse("/setup", status_code=302)
|
||||
|
||||
elif request.url.path == "/setup/operator":
|
||||
response = templates.TemplateResponse(
|
||||
request=request,
|
||||
name="setup.html",
|
||||
context={"csrf_token": csrf_token, "error": "Your session expired. Please try again."},
|
||||
name="setup_operator.html",
|
||||
context={"csrf_token": csrf_token, "error": error_msg, "form_data": None},
|
||||
)
|
||||
csrf_protect.set_csrf_cookie(signed_token, response)
|
||||
return response
|
||||
|
||||
elif request.url.path == "/setup/system":
|
||||
pool = get_pool()
|
||||
system = {
|
||||
"map_tile_url": "https://tile.openstreetmap.org/{z}/{x}/{y}.png",
|
||||
"map_attribution": "© OpenStreetMap contributors",
|
||||
}
|
||||
if pool:
|
||||
try:
|
||||
async with pool.acquire() as conn:
|
||||
row = await conn.fetchrow(
|
||||
"SELECT map_tile_url, map_attribution FROM config.system WHERE id = true"
|
||||
)
|
||||
if row:
|
||||
system = {
|
||||
"map_tile_url": row["map_tile_url"],
|
||||
"map_attribution": row["map_attribution"],
|
||||
}
|
||||
except Exception:
|
||||
pass
|
||||
response = templates.TemplateResponse(
|
||||
request=request,
|
||||
name="setup_system.html",
|
||||
context={
|
||||
"csrf_token": csrf_token,
|
||||
"error": error_msg,
|
||||
"errors": None,
|
||||
"form_data": None,
|
||||
"system": system,
|
||||
},
|
||||
)
|
||||
csrf_protect.set_csrf_cookie(signed_token, response)
|
||||
return response
|
||||
|
||||
elif request.url.path == "/setup/keys":
|
||||
pool = get_pool()
|
||||
keys = []
|
||||
if pool:
|
||||
try:
|
||||
async with pool.acquire() as conn:
|
||||
rows = await conn.fetch(
|
||||
"SELECT alias, created_at FROM config.api_keys ORDER BY alias"
|
||||
)
|
||||
keys = [{"alias": row["alias"], "created_at": row["created_at"]} for row in rows]
|
||||
except Exception:
|
||||
pass
|
||||
response = templates.TemplateResponse(
|
||||
request=request,
|
||||
name="setup_keys.html",
|
||||
context={
|
||||
"csrf_token": csrf_token,
|
||||
"keys": keys,
|
||||
"errors": None,
|
||||
"form_data": None,
|
||||
"success": None,
|
||||
"error": error_msg,
|
||||
},
|
||||
)
|
||||
csrf_protect.set_csrf_cookie(signed_token, response)
|
||||
return response
|
||||
|
||||
elif request.url.path == "/setup/adapters":
|
||||
pool = get_pool()
|
||||
adapters = []
|
||||
api_keys = []
|
||||
tile_url = "https://tile.openstreetmap.org/{z}/{x}/{y}.png"
|
||||
tile_attribution = "© OpenStreetMap contributors"
|
||||
if pool:
|
||||
try:
|
||||
async with pool.acquire() as conn:
|
||||
rows = await conn.fetch(
|
||||
"SELECT name, enabled, cadence_s, settings FROM config.adapters ORDER BY name"
|
||||
)
|
||||
for row in rows:
|
||||
settings = row["settings"] or {}
|
||||
adapters.append({
|
||||
"name": row["name"],
|
||||
"enabled": row["enabled"],
|
||||
"cadence_s": row["cadence_s"],
|
||||
"settings": settings,
|
||||
})
|
||||
key_rows = await conn.fetch(
|
||||
"SELECT alias FROM config.api_keys ORDER BY alias"
|
||||
)
|
||||
api_keys = [{"alias": k["alias"]} for k in key_rows]
|
||||
sys_row = await conn.fetchrow(
|
||||
"SELECT map_tile_url, map_attribution FROM config.system WHERE id = true"
|
||||
)
|
||||
if sys_row:
|
||||
tile_url = sys_row["map_tile_url"]
|
||||
tile_attribution = sys_row["map_attribution"]
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Import helper functions for valid values
|
||||
from central.gui.routes import _get_valid_satellites, _get_valid_feeds
|
||||
|
||||
response = templates.TemplateResponse(
|
||||
request=request,
|
||||
name="setup_adapters.html",
|
||||
context={
|
||||
"csrf_token": csrf_token,
|
||||
"adapters": adapters,
|
||||
"api_keys": api_keys,
|
||||
"valid_satellites": _get_valid_satellites(),
|
||||
"valid_feeds": sorted(_get_valid_feeds()),
|
||||
"tile_url": tile_url,
|
||||
"tile_attribution": tile_attribution,
|
||||
"error": error_msg,
|
||||
"errors": None,
|
||||
"form_data": None,
|
||||
},
|
||||
)
|
||||
csrf_protect.set_csrf_cookie(signed_token, response)
|
||||
return response
|
||||
|
||||
elif request.url.path == "/setup/finish":
|
||||
pool = get_pool()
|
||||
operator_count = 0
|
||||
key_count = 0
|
||||
system = {"map_tile_url": ""}
|
||||
adapters = []
|
||||
if pool:
|
||||
try:
|
||||
async with pool.acquire() as conn:
|
||||
operator_count = await conn.fetchval("SELECT COUNT(*) FROM config.operators")
|
||||
key_count = await conn.fetchval("SELECT COUNT(*) FROM config.api_keys")
|
||||
sys_row = await conn.fetchrow(
|
||||
"SELECT map_tile_url FROM config.system WHERE id = true"
|
||||
)
|
||||
if sys_row:
|
||||
system = {"map_tile_url": sys_row["map_tile_url"]}
|
||||
rows = await conn.fetch(
|
||||
"SELECT name, enabled, cadence_s FROM config.adapters ORDER BY name"
|
||||
)
|
||||
adapters = [
|
||||
{"name": row["name"], "enabled": row["enabled"], "cadence_s": row["cadence_s"]}
|
||||
for row in rows
|
||||
]
|
||||
except Exception:
|
||||
pass
|
||||
response = templates.TemplateResponse(
|
||||
request=request,
|
||||
name="setup_finish.html",
|
||||
context={
|
||||
"csrf_token": csrf_token,
|
||||
"operator_count": operator_count,
|
||||
"key_count": key_count,
|
||||
"system": system,
|
||||
"adapters": adapters,
|
||||
"error": error_msg,
|
||||
},
|
||||
)
|
||||
csrf_protect.set_csrf_cookie(signed_token, response)
|
||||
return response
|
||||
|
||||
elif request.url.path == "/logout":
|
||||
return RedirectResponse("/login", status_code=302)
|
||||
|
||||
elif request.url.path == "/change-password":
|
||||
response = templates.TemplateResponse(
|
||||
request=request,
|
||||
name="change_password.html",
|
||||
context={"csrf_token": csrf_token, "error": "Your session expired. Please try again."},
|
||||
context={"csrf_token": csrf_token, "error": error_msg},
|
||||
)
|
||||
csrf_protect.set_csrf_cookie(signed_token, response)
|
||||
return response
|
||||
|
||||
elif request.url.path.startswith("/adapters/"):
|
||||
# Redirect back to adapters list
|
||||
return RedirectResponse("/adapters", status_code=302)
|
||||
|
||||
else:
|
||||
# Fallback: redirect to login
|
||||
return RedirectResponse("/login", status_code=302)
|
||||
|
|
|
|||
|
|
@ -13,6 +13,10 @@
|
|||
<p>Review your configuration and finish the setup wizard.</p>
|
||||
</header>
|
||||
|
||||
{% if error %}
|
||||
<p style="color: var(--pico-color-red-500);">{{ error }}</p>
|
||||
{% endif %}
|
||||
|
||||
<h2>Summary</h2>
|
||||
|
||||
<table>
|
||||
|
|
|
|||
|
|
@ -13,6 +13,10 @@
|
|||
<p>Add API keys for adapters that require external service credentials (e.g., FIRMS).</p>
|
||||
</header>
|
||||
|
||||
{% if error %}
|
||||
<p style="color: var(--pico-color-red-500);">{{ error }}</p>
|
||||
{% endif %}
|
||||
|
||||
{% if success %}
|
||||
<p style="color: var(--pico-color-green-500);">{{ success }}</p>
|
||||
{% endif %}
|
||||
|
|
|
|||
|
|
@ -107,3 +107,166 @@ class TestCsrfHandlerNoTraceback:
|
|||
|
||||
assert handler is not None
|
||||
assert inspect.iscoroutinefunction(handler)
|
||||
|
||||
|
||||
class TestCsrfHandlerWizardPaths:
|
||||
"""Test CSRF exception handler for wizard paths."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_setup_operator_csrf_error_renders_form_with_error(self):
|
||||
"""CSRF error on /setup/operator re-renders form with error message."""
|
||||
from central.gui import _create_app
|
||||
from fastapi_csrf_protect.exceptions import TokenValidationError
|
||||
from fastapi.responses import HTMLResponse
|
||||
|
||||
app = _create_app()
|
||||
from fastapi_csrf_protect.exceptions import CsrfProtectError
|
||||
handler = app.exception_handlers.get(CsrfProtectError)
|
||||
|
||||
mock_request = MagicMock()
|
||||
mock_request.url.path = "/setup/operator"
|
||||
|
||||
exc = TokenValidationError("Invalid token")
|
||||
|
||||
result = await handler(mock_request, exc)
|
||||
|
||||
# Should be HTML response, not redirect
|
||||
assert hasattr(result, "body")
|
||||
assert result.status_code == 200
|
||||
body = result.body.decode() if hasattr(result.body, "decode") else str(result.body)
|
||||
assert "session expired" in body.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_setup_system_csrf_error_renders_form_with_error(self):
|
||||
"""CSRF error on /setup/system re-renders form with error message."""
|
||||
from central.gui import _create_app
|
||||
from fastapi_csrf_protect.exceptions import TokenValidationError
|
||||
|
||||
app = _create_app()
|
||||
from fastapi_csrf_protect.exceptions import CsrfProtectError
|
||||
handler = app.exception_handlers.get(CsrfProtectError)
|
||||
|
||||
mock_request = MagicMock()
|
||||
mock_request.url.path = "/setup/system"
|
||||
|
||||
exc = TokenValidationError("Invalid token")
|
||||
|
||||
with patch("central.gui.db.get_pool", return_value=None):
|
||||
result = await handler(mock_request, exc)
|
||||
|
||||
assert hasattr(result, "body")
|
||||
assert result.status_code == 200
|
||||
body = result.body.decode() if hasattr(result.body, "decode") else str(result.body)
|
||||
assert "session expired" in body.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_setup_keys_csrf_error_renders_form_with_error(self):
|
||||
"""CSRF error on /setup/keys re-renders form with error message."""
|
||||
from central.gui import _create_app
|
||||
from fastapi_csrf_protect.exceptions import TokenValidationError
|
||||
|
||||
app = _create_app()
|
||||
from fastapi_csrf_protect.exceptions import CsrfProtectError
|
||||
handler = app.exception_handlers.get(CsrfProtectError)
|
||||
|
||||
mock_request = MagicMock()
|
||||
mock_request.url.path = "/setup/keys"
|
||||
|
||||
exc = TokenValidationError("Invalid token")
|
||||
|
||||
with patch("central.gui.db.get_pool", return_value=None):
|
||||
result = await handler(mock_request, exc)
|
||||
|
||||
assert hasattr(result, "body")
|
||||
assert result.status_code == 200
|
||||
body = result.body.decode() if hasattr(result.body, "decode") else str(result.body)
|
||||
assert "session expired" in body.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_setup_adapters_csrf_error_renders_form_with_error(self):
|
||||
"""CSRF error on /setup/adapters re-renders form with error message."""
|
||||
from central.gui import _create_app
|
||||
from fastapi_csrf_protect.exceptions import TokenValidationError
|
||||
|
||||
app = _create_app()
|
||||
from fastapi_csrf_protect.exceptions import CsrfProtectError
|
||||
handler = app.exception_handlers.get(CsrfProtectError)
|
||||
|
||||
mock_request = MagicMock()
|
||||
mock_request.url.path = "/setup/adapters"
|
||||
|
||||
exc = TokenValidationError("Invalid token")
|
||||
|
||||
with patch("central.gui.db.get_pool", return_value=None):
|
||||
result = await handler(mock_request, exc)
|
||||
|
||||
assert hasattr(result, "body")
|
||||
assert result.status_code == 200
|
||||
body = result.body.decode() if hasattr(result.body, "decode") else str(result.body)
|
||||
assert "session expired" in body.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_setup_finish_csrf_error_renders_form_with_error(self):
|
||||
"""CSRF error on /setup/finish re-renders form with error message."""
|
||||
from central.gui import _create_app
|
||||
from fastapi_csrf_protect.exceptions import TokenValidationError
|
||||
|
||||
app = _create_app()
|
||||
from fastapi_csrf_protect.exceptions import CsrfProtectError
|
||||
handler = app.exception_handlers.get(CsrfProtectError)
|
||||
|
||||
mock_request = MagicMock()
|
||||
mock_request.url.path = "/setup/finish"
|
||||
|
||||
exc = TokenValidationError("Invalid token")
|
||||
|
||||
with patch("central.gui.db.get_pool", return_value=None):
|
||||
result = await handler(mock_request, exc)
|
||||
|
||||
assert hasattr(result, "body")
|
||||
assert result.status_code == 200
|
||||
body = result.body.decode() if hasattr(result.body, "decode") else str(result.body)
|
||||
assert "session expired" in body.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_setup_base_csrf_error_redirects_to_setup(self):
|
||||
"""CSRF error on /setup redirects to /setup (middleware routes to step)."""
|
||||
from central.gui import _create_app
|
||||
from fastapi_csrf_protect.exceptions import TokenValidationError
|
||||
from fastapi.responses import RedirectResponse
|
||||
|
||||
app = _create_app()
|
||||
from fastapi_csrf_protect.exceptions import CsrfProtectError
|
||||
handler = app.exception_handlers.get(CsrfProtectError)
|
||||
|
||||
mock_request = MagicMock()
|
||||
mock_request.url.path = "/setup"
|
||||
|
||||
exc = TokenValidationError("Invalid token")
|
||||
|
||||
result = await handler(mock_request, exc)
|
||||
|
||||
assert isinstance(result, RedirectResponse)
|
||||
assert result.status_code == 302
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_login_csrf_error_still_works(self):
|
||||
"""CSRF error on /login still renders login form with error (regression test)."""
|
||||
from central.gui import _create_app
|
||||
from fastapi_csrf_protect.exceptions import TokenValidationError
|
||||
|
||||
app = _create_app()
|
||||
from fastapi_csrf_protect.exceptions import CsrfProtectError
|
||||
handler = app.exception_handlers.get(CsrfProtectError)
|
||||
|
||||
mock_request = MagicMock()
|
||||
mock_request.url.path = "/login"
|
||||
|
||||
exc = TokenValidationError("Invalid token")
|
||||
|
||||
result = await handler(mock_request, exc)
|
||||
|
||||
assert hasattr(result, "body")
|
||||
assert result.status_code == 200
|
||||
body = result.body.decode() if hasattr(result.body, "decode") else str(result.body)
|
||||
assert "session expired" in body.lower()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue