diff --git a/src/central/gui/__init__.py b/src/central/gui/__init__.py index 4d2372c..1907d44 100644 --- a/src/central/gui/__init__.py +++ b/src/central/gui/__init__.py @@ -37,6 +37,8 @@ def _configure_csrf() -> None: class CsrfSettings(BaseModel): secret_key: str + token_location: str = "body" + token_key: str = "csrf_token" @CsrfProtect.load_config def get_csrf_config(): diff --git a/src/central/gui/routes.py b/src/central/gui/routes.py index 47e7dba..993f21a 100644 --- a/src/central/gui/routes.py +++ b/src/central/gui/routes.py @@ -24,19 +24,6 @@ from central.gui.db import get_pool router = APIRouter() -async def _validate_csrf_form(request, csrf_protect): - """Validate CSRF token from form data.""" - form = await request.form() - csrf_token = form.get("csrf_token") - if csrf_token: - cookie_token = request.cookies.get("fastapi-csrf-token") - if not cookie_token or cookie_token != csrf_token: - from fastapi_csrf_protect.exceptions import TokenValidationError - raise TokenValidationError("CSRF token mismatch") - else: - from fastapi_csrf_protect.exceptions import MissingTokenError - raise MissingTokenError("Missing CSRF token in form") - def _get_templates(): """Get templates instance (deferred import to avoid circular).""" from central.gui import templates @@ -119,7 +106,7 @@ async def setup_submit( pool = get_pool() # Validate CSRF - await _validate_csrf_form(request, csrf_protect) + await csrf_protect.validate_csrf(request) # Validate input error = None @@ -213,7 +200,7 @@ async def login_submit( pool = get_pool() # Validate CSRF - await _validate_csrf_form(request, csrf_protect) + await csrf_protect.validate_csrf(request) # Look up operator async with pool.acquire() as conn: @@ -279,7 +266,7 @@ async def logout( pool = get_pool() # Validate CSRF - await _validate_csrf_form(request, csrf_protect) + await csrf_protect.validate_csrf(request) # Get current session session_token = request.cookies.get("central_session") @@ -328,7 +315,7 @@ async def change_password_submit( operator = request.state.operator # Validate CSRF - await _validate_csrf_form(request, csrf_protect) + await csrf_protect.validate_csrf(request) # Get current password hash async with pool.acquire() as conn: