diff --git a/src/central/gui/routes.py b/src/central/gui/routes.py index 19a8b6d..47e7dba 100644 --- a/src/central/gui/routes.py +++ b/src/central/gui/routes.py @@ -24,6 +24,19 @@ from central.gui.db import get_pool router = APIRouter() +async def _validate_csrf_form(request, csrf_protect): + """Validate CSRF token from form data.""" + form = await request.form() + csrf_token = form.get("csrf_token") + if csrf_token: + cookie_token = request.cookies.get("fastapi-csrf-token") + if not cookie_token or cookie_token != csrf_token: + from fastapi_csrf_protect.exceptions import TokenValidationError + raise TokenValidationError("CSRF token mismatch") + else: + from fastapi_csrf_protect.exceptions import MissingTokenError + raise MissingTokenError("Missing CSRF token in form") + def _get_templates(): """Get templates instance (deferred import to avoid circular).""" from central.gui import templates @@ -62,13 +75,18 @@ async def health() -> dict: @router.get("/", response_class=HTMLResponse) -async def index(request: Request) -> HTMLResponse: +async def index(request: Request, csrf_protect: CsrfProtect = Depends()) -> HTMLResponse: """Render the index page.""" templates = _get_templates() - return templates.TemplateResponse( + operator = getattr(request.state, "operator", None) + csrf_token, signed_token = csrf_protect.generate_csrf_tokens() + response = templates.TemplateResponse( request=request, name="index.html", + context={"operator": operator, "csrf_token": signed_token}, ) + csrf_protect.set_csrf_cookie(signed_token, response) + return response @router.get("/setup", response_class=HTMLResponse) @@ -101,7 +119,7 @@ async def setup_submit( pool = get_pool() # Validate CSRF - await csrf_protect.validate_csrf(request) + await _validate_csrf_form(request, csrf_protect) # Validate input error = None @@ -195,7 +213,7 @@ async def login_submit( pool = get_pool() # Validate CSRF - await csrf_protect.validate_csrf(request) + await _validate_csrf_form(request, csrf_protect) # Look up operator async with pool.acquire() as conn: @@ -261,7 +279,7 @@ async def logout( pool = get_pool() # Validate CSRF - await csrf_protect.validate_csrf(request) + await _validate_csrf_form(request, csrf_protect) # Get current session session_token = request.cookies.get("central_session") @@ -310,7 +328,7 @@ async def change_password_submit( operator = request.state.operator # Validate CSRF - await csrf_protect.validate_csrf(request) + await _validate_csrf_form(request, csrf_protect) # Get current password hash async with pool.acquire() as conn: