central/tests/test_session_auth.py
Matt Johnson c317c9ab01 fix(csrf): replace fastapi-csrf-protect with session-bound CSRF
Fixes a CSRF race condition in which every GET rotated the CSRF token,
causing POST failures when users had multiple tabs open or slow connections.

Changes:
- Remove fastapi-csrf-protect dependency
- Add session-bound CSRF tokens stored in config.sessions table
- Add pre-auth CSRF for unauthenticated routes (/login, /setup/operator)
- Add csrf.py module for pre-auth token generation/validation
- Update routes to use new CSRF token handling
- Add migration 013 to add csrf_token column to sessions

The session-bound approach ensures CSRF tokens remain stable for the
duration of a session, eliminating the race condition.

Note: Route tests (test_wizard.py, test_adapters.py, etc.) need
refactoring to mock get_settings() instead of the CsrfProtect dependency.
Core auth/CSRF handler tests pass (74 tests).

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-05-18 03:16:37 +00:00

175 lines
6.7 KiB
Python

"""Tests for session authentication middleware."""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from datetime import datetime, timezone
from starlette.testclient import TestClient
from fastapi import FastAPI, Request
from central.gui.middleware import SessionMiddleware
from central.gui.auth import Operator
_SESSION_COOKIE = "central_session"
_GET_POOL = "central.gui.middleware.get_pool"


def _session_row():
    """Return the fake row that ``conn.fetchrow`` yields for a valid session."""
    now = datetime.now(timezone.utc)
    return {
        "id": 1,
        "username": "admin",
        "created_at": now,
        "password_changed_at": now,
        "csrf_token": "mock_csrf_token_12345",
    }


def _mock_pool(*, row=None, error=None):
    """Build ``(pool, conn)`` mocks where ``async with pool.acquire() as conn`` yields *conn*.

    Args:
        row: Value returned by ``await conn.fetchrow(...)``.
        error: If given, raised by ``conn.fetchrow`` instead of returning *row*.

    Returns:
        A ``(pool, conn)`` tuple so tests can assert on ``conn.fetchrow``.
    """
    conn = MagicMock()
    conn.fetchrow = AsyncMock(return_value=row, side_effect=error)
    conn.__aenter__ = AsyncMock(return_value=conn)
    # Must return False so exceptions raised inside the ``async with`` body
    # propagate as they would with a real pool. A bare ``AsyncMock()``
    # resolves to a truthy mock, which tells Python to *suppress* the
    # exception. That would let the DB-error test pass without exercising
    # the middleware's own error handling.
    conn.__aexit__ = AsyncMock(return_value=False)
    pool = MagicMock()
    pool.acquire = MagicMock(return_value=conn)
    return pool, conn


def _get(app, path, pool, *, cookie=None):
    """Install SessionMiddleware on *app* and GET *path* with ``get_pool`` patched to *pool*.

    Args:
        app: FastAPI app with the routes under test already registered.
        path: Request path.
        pool: Mock returned by the patched ``get_pool``.
        cookie: Optional session-cookie value to send.

    Returns:
        The response. Redirects are never followed, so callers can assert on 302s.
    """
    cookies = {_SESSION_COOKIE: cookie} if cookie is not None else None
    # Middleware is added inside the patch so that any get_pool() lookup made
    # while the middleware stack is built or run sees the mock.
    with patch(_GET_POOL, return_value=pool):
        app.add_middleware(SessionMiddleware)
        client = TestClient(app, cookies=cookies, follow_redirects=False)
        return client.get(path)


class TestSessionMiddleware:
    """Tests for SessionMiddleware.

    These are plain synchronous tests. ``TestClient`` is a synchronous client
    that runs the ASGI app itself, so no ``pytest.mark.asyncio`` is needed.
    """

    def test_no_cookie_sets_none_on_exempt_path(self):
        """SessionMiddleware sets operator=None when no session cookie on exempt path."""
        pool, _ = _mock_pool()
        app = FastAPI()

        @app.get("/health")
        async def health(request: Request):
            return {"operator": getattr(request.state, "operator", "missing")}

        response = _get(app, "/health", pool)

        assert response.status_code == 200
        assert response.json()["operator"] is None

    def test_valid_cookie_sets_operator_on_exempt_path(self):
        """SessionMiddleware sets operator when valid session cookie on exempt path."""
        pool, _ = _mock_pool(row=_session_row())
        app = FastAPI()

        @app.get("/health")
        async def health(request: Request):
            op = getattr(request.state, "operator", None)
            if op:
                return {"username": op.username}
            return {"operator": None}

        response = _get(app, "/health", pool, cookie="valid-token")

        assert response.status_code == 200
        # Compare the whole body so a missing operator fails with a readable diff
        # rather than a KeyError.
        assert response.json() == {"username": "admin"}

    def test_no_cookie_redirects_on_protected_path(self):
        """SessionMiddleware redirects to /login when no cookie on protected path."""
        pool, _ = _mock_pool()
        app = FastAPI()

        @app.get("/")
        async def index(request: Request):
            return {"message": "home"}

        @app.get("/login")
        async def login():
            return {"message": "login"}

        response = _get(app, "/", pool)

        assert response.status_code == 302
        assert response.headers["location"] == "/login"

    def test_valid_cookie_allows_protected_path(self):
        """SessionMiddleware allows protected path with valid session."""
        pool, _ = _mock_pool(row=_session_row())
        app = FastAPI()

        @app.get("/")
        async def index(request: Request):
            op = request.state.operator
            return {"message": "home", "user": op.username}

        response = _get(app, "/", pool, cookie="valid-token")

        assert response.status_code == 200
        assert response.json()["user"] == "admin"

    def test_invalid_cookie_redirects_on_protected_path(self):
        """SessionMiddleware redirects when session is invalid/expired."""
        pool, conn = _mock_pool(row=None)  # No session found
        app = FastAPI()

        @app.get("/")
        async def index(request: Request):
            return {"operator": getattr(request.state, "operator", "missing")}

        @app.get("/login")
        async def login():
            return {"message": "login"}

        response = _get(app, "/", pool, cookie="expired-token")

        assert response.status_code == 302
        assert response.headers["location"] == "/login"
        # The redirect must come from the failed session lookup, not from the
        # cookie being ignored.
        conn.fetchrow.assert_awaited()

    def test_middleware_handles_db_error(self):
        """SessionMiddleware handles database errors gracefully on exempt path."""
        pool, conn = _mock_pool(error=Exception("DB error"))
        app = FastAPI()

        @app.get("/health")
        async def health(request: Request):
            return {"operator": getattr(request.state, "operator", "missing")}

        response = _get(app, "/health", pool, cookie="some-token")

        # Should not crash, just set operator to None.
        assert response.status_code == 200
        assert response.json()["operator"] is None
        # Ensure the failing query actually ran, so the test covers the error path.
        conn.fetchrow.assert_awaited()