diff --git a/tests/conftest.py b/tests/conftest.py index ad93825..97a4f5d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -48,3 +48,49 @@ def mock_conn(): conn.fetchval = AsyncMock() conn.execute = AsyncMock() return conn + + +# CSRF fixtures for route tests + +@pytest.fixture +def bypass_pre_auth_csrf(): + """Patch pre-auth CSRF validation to always pass. + + Use for tests of pre-auth routes: /login, /setup/operator + """ + with patch("central.gui.routes.validate_pre_auth_csrf", return_value=True): + with patch("central.gui.routes.generate_pre_auth_csrf", return_value=("test_csrf_token", "test_signed_token")): + yield + + +@pytest.fixture +def bypass_session_csrf(): + """Create a mock request with session CSRF properly configured. + + Use for tests of authenticated routes that check request.state.csrf_token. + Returns a configured mock_request. + """ + request = MagicMock() + request.state.csrf_token = "test_csrf_token_12345" + request.state.operator = MagicMock() + request.state.operator.id = 1 + request.state.operator.username = "testuser" + + # Mock form() to return dict with matching CSRF token + form_data = {"csrf_token": "test_csrf_token_12345"} + + async def mock_form(): + return form_data + + request.form = mock_form + request._form_data = form_data # Allow tests to modify form data + + return request + + +@pytest.fixture +def patch_route_settings(): + """Patch get_settings in routes module.""" + with patch("central.gui.routes.get_settings") as mock: + mock.return_value.csrf_secret = "test-csrf-secret-for-testing-only-32chars" + yield mock diff --git a/tests/test_adapters.py b/tests/test_adapters.py index 17352f0..fa25c8b 100644 --- a/tests/test_adapters.py +++ b/tests/test_adapters.py @@ -55,13 +55,9 @@ class TestAdaptersListAuthenticated: mock_response = MagicMock() mock_templates.TemplateResponse.return_value = mock_response - mock_csrf = MagicMock() - mock_csrf.generate_csrf_tokens.return_value = ("token", "signed") - mock_csrf.set_csrf_cookie = MagicMock() - with patch("central.gui.routes._get_templates", return_value=mock_templates): with patch("central.gui.routes.get_pool", return_value=mock_pool): - result = await adapters_list(mock_request, mock_csrf) + result = await adapters_list(mock_request) # Verify template was called with adapters call_args = mock_templates.TemplateResponse.call_args @@ -105,13 +101,9 @@ class TestAdaptersEditForm: mock_response = MagicMock() mock_templates.TemplateResponse.return_value = mock_response - mock_csrf = MagicMock() - mock_csrf.generate_csrf_tokens.return_value = ("token", "signed") - mock_csrf.set_csrf_cookie = MagicMock() - with patch("central.gui.routes._get_templates", return_value=mock_templates): with patch("central.gui.routes.get_pool", return_value=mock_pool): - result = await adapters_edit_form(mock_request, "nws", mock_csrf) + result = await adapters_edit_form(mock_request, "nws") call_args = mock_templates.TemplateResponse.call_args context = call_args.kwargs.get("context", call_args[1].get("context")) @@ -133,11 +125,8 @@ class TestAdaptersEditForm: mock_pool.acquire.return_value.__aenter__ = AsyncMock(return_value=mock_conn) mock_pool.acquire.return_value.__aexit__ = AsyncMock(return_value=None) - mock_csrf = MagicMock() - mock_csrf.generate_csrf_tokens.return_value = ("token", "signed") - with patch("central.gui.routes.get_pool", return_value=mock_pool): - result = await adapters_edit_form(mock_request, "nonexistent", mock_csrf) + result = await adapters_edit_form(mock_request, "nonexistent") assert result.status_code == 404 @@ -156,7 +145,9 @@ class TestAdaptersEditSubmit: # Mock form data mock_form = MagicMock() + mock_request.state.csrf_token = "test_csrf_token" mock_form.get.side_effect = lambda k, d="": { + "csrf_token": "test_csrf_token", "cadence_s": "120", "contact_email": "new@example.com", "region_north": "49.0", @@ -183,12 +174,9 @@ class TestAdaptersEditSubmit: mock_pool.acquire.return_value.__aenter__ = AsyncMock(return_value=mock_conn) mock_pool.acquire.return_value.__aexit__ = AsyncMock(return_value=None) - mock_csrf = MagicMock() - mock_csrf.validate_csrf = AsyncMock() - with patch("central.gui.routes.get_pool", return_value=mock_pool): with patch("central.gui.routes.write_audit", new_callable=AsyncMock) as mock_audit: - result = await adapters_edit_submit(mock_request, "nws", mock_csrf) + result = await adapters_edit_submit(mock_request, "nws") assert result.status_code == 302 assert result.headers["location"] == "/adapters" @@ -204,7 +192,9 @@ class TestAdaptersEditSubmit: mock_request.state.operator = MagicMock(id=1, username="testop") mock_form = MagicMock() + mock_request.state.csrf_token = "test_csrf_token" mock_form.get.side_effect = lambda k, d="": { + "csrf_token": "test_csrf_token", "cadence_s": "30", "contact_email": "test@example.com", "region_north": "49.0", @@ -239,14 +229,9 @@ class TestAdaptersEditSubmit: mock_response.status_code = 200 mock_templates.TemplateResponse.return_value = mock_response - mock_csrf = MagicMock() - mock_csrf.validate_csrf = AsyncMock() - mock_csrf.generate_csrf_tokens.return_value = ("token", "signed") - mock_csrf.set_csrf_cookie = MagicMock() - with patch("central.gui.routes._get_templates", return_value=mock_templates): with patch("central.gui.routes.get_pool", return_value=mock_pool): - result = await adapters_edit_submit(mock_request, "nws", mock_csrf) + result = await adapters_edit_submit(mock_request, "nws") # Should re-render form with error call_args = mock_templates.TemplateResponse.call_args @@ -263,7 +248,9 @@ class TestAdaptersEditSubmit: mock_request.state.operator = MagicMock(id=1, username="testop") mock_form = MagicMock() + mock_request.state.csrf_token = "test_csrf_token" mock_form.get.side_effect = lambda k, d="": { + "csrf_token": "test_csrf_token", "cadence_s": "300", "api_key_alias": "nonexistent_key", "region_north": "49.5", @@ -299,14 +286,9 @@ class TestAdaptersEditSubmit: mock_response.status_code = 200 mock_templates.TemplateResponse.return_value = mock_response - mock_csrf = MagicMock() - mock_csrf.validate_csrf = AsyncMock() - mock_csrf.generate_csrf_tokens.return_value = ("token", "signed") - mock_csrf.set_csrf_cookie = MagicMock() - with patch("central.gui.routes._get_templates", return_value=mock_templates): with patch("central.gui.routes.get_pool", return_value=mock_pool): - result = await adapters_edit_submit(mock_request, "firms", mock_csrf) + result = await adapters_edit_submit(mock_request, "firms") call_args = mock_templates.TemplateResponse.call_args context = call_args.kwargs.get("context", call_args[1].get("context")) @@ -322,7 +304,9 @@ class TestAdaptersEditSubmit: mock_request.state.operator = MagicMock(id=1, username="testop") mock_form = MagicMock() + mock_request.state.csrf_token = "test_csrf_token" mock_form.get.side_effect = lambda k, d="": { + "csrf_token": "test_csrf_token", "cadence_s": "120", "feed": "invalid_feed", "region_north": "49.0", @@ -357,14 +341,9 @@ class TestAdaptersEditSubmit: mock_response.status_code = 200 mock_templates.TemplateResponse.return_value = mock_response - mock_csrf = MagicMock() - mock_csrf.validate_csrf = AsyncMock() - mock_csrf.generate_csrf_tokens.return_value = ("token", "signed") - mock_csrf.set_csrf_cookie = MagicMock() - with patch("central.gui.routes._get_templates", return_value=mock_templates): with patch("central.gui.routes.get_pool", return_value=mock_pool): - result = await adapters_edit_submit(mock_request, "usgs_quake", mock_csrf) + result = await adapters_edit_submit(mock_request, "usgs_quake") call_args = mock_templates.TemplateResponse.call_args context = call_args.kwargs.get("context", call_args[1].get("context")) @@ -383,7 +362,9 @@ class TestAdaptersAudit: mock_request.state.operator = MagicMock(id=1, username="testop") mock_form = MagicMock() + mock_request.state.csrf_token = "test_csrf_token" mock_form.get.side_effect = lambda k, d="": { + "csrf_token": "test_csrf_token", "cadence_s": "120", "contact_email": "new@example.com", "region_north": "49.0", @@ -410,9 +391,6 @@ class TestAdaptersAudit: mock_pool.acquire.return_value.__aenter__ = AsyncMock(return_value=mock_conn) mock_pool.acquire.return_value.__aexit__ = AsyncMock(return_value=None) - mock_csrf = MagicMock() - mock_csrf.validate_csrf = AsyncMock() - captured_audit = {} async def capture_audit(conn, action, operator_id=None, target=None, before=None, after=None): @@ -423,7 +401,7 @@ class TestAdaptersAudit: with patch("central.gui.routes.get_pool", return_value=mock_pool): with patch("central.gui.routes.write_audit", side_effect=capture_audit): - result = await adapters_edit_submit(mock_request, "nws", mock_csrf) + result = await adapters_edit_submit(mock_request, "nws") assert captured_audit["action"] == "adapter.update" assert captured_audit["target"] == "nws" @@ -449,7 +427,9 @@ class TestAdaptersJsonbRegression: mock_request.state.operator = MagicMock(id=1, username="testop") mock_form = MagicMock() + mock_request.state.csrf_token = "test_csrf_token" mock_form.get.side_effect = lambda k, d="": { + "csrf_token": "test_csrf_token", "cadence_s": "120", "contact_email": "test@example.com", "region_north": "49.0", @@ -476,12 +456,9 @@ class TestAdaptersJsonbRegression: mock_pool.acquire.return_value.__aenter__ = AsyncMock(return_value=mock_conn) mock_pool.acquire.return_value.__aexit__ = AsyncMock(return_value=None) - mock_csrf = MagicMock() - mock_csrf.validate_csrf = AsyncMock() - with patch("central.gui.routes.get_pool", return_value=mock_pool): with patch("central.gui.routes.write_audit", new_callable=AsyncMock): - await adapters_edit_submit(mock_request, "nws", mock_csrf) + await adapters_edit_submit(mock_request, "nws") # Get the settings argument passed to execute (3rd positional arg after query) call_args = mock_conn.execute.call_args @@ -502,7 +479,9 @@ class TestAdaptersJsonbRegression: mock_request.state.operator = MagicMock(id=1, username="testop") mock_form = MagicMock() + mock_request.state.csrf_token = "test_csrf_token" mock_form.get.side_effect = lambda k, d="": { + "csrf_token": "test_csrf_token", "cadence_s": "120", "contact_email": "new@example.com", "region_north": "49.0", @@ -529,9 +508,6 @@ class TestAdaptersJsonbRegression: mock_pool.acquire.return_value.__aenter__ = AsyncMock(return_value=mock_conn) mock_pool.acquire.return_value.__aexit__ = AsyncMock(return_value=None) - mock_csrf = MagicMock() - mock_csrf.validate_csrf = AsyncMock() - captured_audit = {} async def capture_audit(conn, action, operator_id=None, target=None, before=None, after=None): @@ -540,7 +516,7 @@ class TestAdaptersJsonbRegression: with patch("central.gui.routes.get_pool", return_value=mock_pool): with patch("central.gui.routes.write_audit", side_effect=capture_audit): - await adapters_edit_submit(mock_request, "nws", mock_csrf) + await adapters_edit_submit(mock_request, "nws") # CRITICAL: before and after must be dicts, NOT strings assert isinstance(captured_audit["before"], dict), f"before should be dict, got {type(captured_audit['before'])}" diff --git a/tests/test_api_keys.py b/tests/test_api_keys.py index 6bd43be..674231b 100644 --- a/tests/test_api_keys.py +++ b/tests/test_api_keys.py @@ -75,13 +75,9 @@ class TestApiKeysListAuthenticated: mock_pool.acquire.return_value.__aenter__.return_value = mock_conn mock_pool.acquire.return_value.__aexit__.return_value = None - mock_csrf = MagicMock() - mock_csrf.generate_csrf_tokens.return_value = ("token", "signed") - mock_csrf.set_csrf_cookie = MagicMock() - with patch("central.gui.routes._get_templates", return_value=mock_templates): with patch("central.gui.routes.get_pool", return_value=mock_pool): - result = await api_keys_list(mock_request, mock_csrf) + result = await api_keys_list(mock_request) # Check template was called with correct context call_args = mock_templates.TemplateResponse.call_args @@ -104,7 +100,8 @@ class TestApiKeysCreate: mock_request = MagicMock() mock_request.state.operator = MagicMock(id=1, username="admin") - form_data = {"alias": "test1", "plaintext_key": "secret-api-key-123"} + mock_request.state.csrf_token = "test_csrf_token" + form_data = {"csrf_token": "test_csrf_token", "alias": "test1", "plaintext_key": "secret-api-key-123"} mock_request.form = AsyncMock(return_value=form_data) mock_conn = AsyncMock() @@ -119,13 +116,10 @@ class TestApiKeysCreate: mock_pool.acquire.return_value.__aenter__.return_value = mock_conn mock_pool.acquire.return_value.__aexit__.return_value = None - mock_csrf = MagicMock() - mock_csrf.validate_csrf = AsyncMock() - with patch("central.gui.routes.get_pool", return_value=mock_pool): with patch("central.crypto.encrypt", return_value=b"encrypted_data"): with patch("central.gui.routes.write_audit", new_callable=AsyncMock): - result = await api_keys_create(mock_request, mock_csrf) + result = await api_keys_create(mock_request) assert result.status_code == 302 assert result.headers["location"] == "/api-keys" @@ -136,7 +130,8 @@ class TestApiKeysCreate: mock_request = MagicMock() mock_request.state.operator = MagicMock(id=1, username="admin") - form_data = {"alias": "firms", "plaintext_key": "secret-key"} + mock_request.state.csrf_token = "test_csrf_token" + form_data = {"csrf_token": "test_csrf_token", "alias": "firms", "plaintext_key": "secret-key"} mock_request.form = AsyncMock(return_value=form_data) mock_templates = MagicMock() @@ -150,15 +145,10 @@ class TestApiKeysCreate: mock_pool.acquire.return_value.__aenter__.return_value = mock_conn mock_pool.acquire.return_value.__aexit__.return_value = None - mock_csrf = MagicMock() - mock_csrf.validate_csrf = AsyncMock() - mock_csrf.generate_csrf_tokens.return_value = ("token", "signed") - mock_csrf.set_csrf_cookie = MagicMock() - with patch("central.gui.routes._get_templates", return_value=mock_templates): with patch("central.gui.routes.get_pool", return_value=mock_pool): with patch("central.crypto.encrypt", return_value=b"encrypted"): - result = await api_keys_create(mock_request, mock_csrf) + result = await api_keys_create(mock_request) # Should re-render form with error call_args = mock_templates.TemplateResponse.call_args @@ -172,7 +162,8 @@ class TestApiKeysCreate: mock_request = MagicMock() mock_request.state.operator = MagicMock(id=1, username="admin") - form_data = {"alias": "", "plaintext_key": "secret-key"} + mock_request.state.csrf_token = "test_csrf_token" + form_data = {"csrf_token": "test_csrf_token", "alias": "", "plaintext_key": "secret-key"} mock_request.form = AsyncMock(return_value=form_data) mock_templates = MagicMock() @@ -183,14 +174,9 @@ class TestApiKeysCreate: mock_pool.acquire.return_value.__aenter__.return_value = mock_conn mock_pool.acquire.return_value.__aexit__.return_value = None - mock_csrf = MagicMock() - mock_csrf.validate_csrf = AsyncMock() - mock_csrf.generate_csrf_tokens.return_value = ("token", "signed") - mock_csrf.set_csrf_cookie = MagicMock() - with patch("central.gui.routes._get_templates", return_value=mock_templates): with patch("central.gui.routes.get_pool", return_value=mock_pool): - result = await api_keys_create(mock_request, mock_csrf) + result = await api_keys_create(mock_request) call_args = mock_templates.TemplateResponse.call_args context = call_args.kwargs.get("context", call_args[1].get("context")) @@ -203,7 +189,8 @@ class TestApiKeysCreate: mock_request.state.operator = MagicMock(id=1, username="admin") # Test with space - form_data = {"alias": "test key", "plaintext_key": "secret-key"} + mock_request.state.csrf_token = "test_csrf_token" + form_data = {"csrf_token": "test_csrf_token", "alias": "test key", "plaintext_key": "secret-key"} mock_request.form = AsyncMock(return_value=form_data) mock_templates = MagicMock() @@ -214,14 +201,9 @@ class TestApiKeysCreate: mock_pool.acquire.return_value.__aenter__.return_value = mock_conn mock_pool.acquire.return_value.__aexit__.return_value = None - mock_csrf = MagicMock() - mock_csrf.validate_csrf = AsyncMock() - mock_csrf.generate_csrf_tokens.return_value = ("token", "signed") - mock_csrf.set_csrf_cookie = MagicMock() - with patch("central.gui.routes._get_templates", return_value=mock_templates): with patch("central.gui.routes.get_pool", return_value=mock_pool): - result = await api_keys_create(mock_request, mock_csrf) + result = await api_keys_create(mock_request) call_args = mock_templates.TemplateResponse.call_args context = call_args.kwargs.get("context", call_args[1].get("context")) @@ -233,7 +215,8 @@ class TestApiKeysCreate: mock_request = MagicMock() mock_request.state.operator = MagicMock(id=1, username="admin") - form_data = {"alias": "test-key", "plaintext_key": "secret-key"} + mock_request.state.csrf_token = "test_csrf_token" + form_data = {"csrf_token": "test_csrf_token", "alias": "test-key", "plaintext_key": "secret-key"} mock_request.form = AsyncMock(return_value=form_data) mock_templates = MagicMock() @@ -244,14 +227,9 @@ class TestApiKeysCreate: mock_pool.acquire.return_value.__aenter__.return_value = mock_conn mock_pool.acquire.return_value.__aexit__.return_value = None - mock_csrf = MagicMock() - mock_csrf.validate_csrf = AsyncMock() - mock_csrf.generate_csrf_tokens.return_value = ("token", "signed") - mock_csrf.set_csrf_cookie = MagicMock() - with patch("central.gui.routes._get_templates", return_value=mock_templates): with patch("central.gui.routes.get_pool", return_value=mock_pool): - result = await api_keys_create(mock_request, mock_csrf) + result = await api_keys_create(mock_request) call_args = mock_templates.TemplateResponse.call_args context = call_args.kwargs.get("context", call_args[1].get("context")) @@ -267,7 +245,8 @@ class TestApiKeysRotate: mock_request = MagicMock() mock_request.state.operator = MagicMock(id=1, username="admin") - form_data = {"new_plaintext_key": "new-secret-key-456"} + mock_request.state.csrf_token = "test_csrf_token" + form_data = {"csrf_token": "test_csrf_token", "new_plaintext_key": "new-secret-key-456"} mock_request.form = AsyncMock(return_value=form_data) mock_conn = AsyncMock() @@ -290,13 +269,10 @@ class TestApiKeysRotate: mock_pool.acquire.return_value.__aenter__.return_value = mock_conn mock_pool.acquire.return_value.__aexit__.return_value = None - mock_csrf = MagicMock() - mock_csrf.validate_csrf = AsyncMock() - with patch("central.gui.routes.get_pool", return_value=mock_pool): with patch("central.crypto.encrypt", return_value=b"new_encrypted"): with patch("central.gui.routes.write_audit", new_callable=AsyncMock) as mock_audit: - result = await api_keys_rotate(mock_request, "test1", mock_csrf) + result = await api_keys_rotate(mock_request, "test1") assert result.status_code == 302 # Check audit was called with no plaintext @@ -313,6 +289,8 @@ class TestApiKeysDelete: """POST /api-keys/{alias}/delete with references shows error.""" mock_request = MagicMock() mock_request.state.operator = MagicMock(id=1, username="admin") + mock_request.state.csrf_token = "test_csrf_token" + mock_request.form = AsyncMock(return_value={"csrf_token": "test_csrf_token"}) mock_templates = MagicMock() mock_templates.TemplateResponse.return_value = MagicMock() @@ -331,14 +309,9 @@ class TestApiKeysDelete: mock_pool.acquire.return_value.__aenter__.return_value = mock_conn mock_pool.acquire.return_value.__aexit__.return_value = None - mock_csrf = MagicMock() - mock_csrf.validate_csrf = AsyncMock() - mock_csrf.generate_csrf_tokens.return_value = ("token", "signed") - mock_csrf.set_csrf_cookie = MagicMock() - with patch("central.gui.routes._get_templates", return_value=mock_templates): with patch("central.gui.routes.get_pool", return_value=mock_pool): - result = await api_keys_delete(mock_request, "firms", mock_csrf) + result = await api_keys_delete(mock_request, "firms") # Should re-render with error call_args = mock_templates.TemplateResponse.call_args @@ -351,6 +324,8 @@ class TestApiKeysDelete: """POST /api-keys/{alias}/delete without references deletes and redirects.""" mock_request = MagicMock() mock_request.state.operator = MagicMock(id=1, username="admin") + mock_request.state.csrf_token = "test_csrf_token" + mock_request.form = AsyncMock(return_value={"csrf_token": "test_csrf_token"}) mock_conn = AsyncMock() mock_conn.fetchrow.return_value = { @@ -367,12 +342,9 @@ class TestApiKeysDelete: mock_pool.acquire.return_value.__aenter__.return_value = mock_conn mock_pool.acquire.return_value.__aexit__.return_value = None - mock_csrf = MagicMock() - mock_csrf.validate_csrf = AsyncMock() - with patch("central.gui.routes.get_pool", return_value=mock_pool): with patch("central.gui.routes.write_audit", new_callable=AsyncMock) as mock_audit: - result = await api_keys_delete(mock_request, "test1", mock_csrf) + result = await api_keys_delete(mock_request, "test1") assert result.status_code == 302 assert result.headers["location"] == "/api-keys" @@ -388,7 +360,8 @@ class TestApiKeysAuditNoPlaintext: mock_request = MagicMock() mock_request.state.operator = MagicMock(id=1, username="admin") - form_data = {"alias": "newkey", "plaintext_key": "super-secret-value"} + mock_request.state.csrf_token = "test_csrf_token" + form_data = {"csrf_token": "test_csrf_token", "alias": "newkey", "plaintext_key": "super-secret-value"} mock_request.form = AsyncMock(return_value=form_data) mock_conn = AsyncMock() @@ -401,13 +374,10 @@ class TestApiKeysAuditNoPlaintext: mock_pool.acquire.return_value.__aenter__.return_value = mock_conn mock_pool.acquire.return_value.__aexit__.return_value = None - mock_csrf = MagicMock() - mock_csrf.validate_csrf = AsyncMock() - with patch("central.gui.routes.get_pool", return_value=mock_pool): with patch("central.crypto.encrypt", return_value=b"encrypted"): with patch("central.gui.routes.write_audit", new_callable=AsyncMock) as mock_audit: - await api_keys_create(mock_request, mock_csrf) + await api_keys_create(mock_request) # Check audit call arguments call_kwargs = mock_audit.call_args.kwargs diff --git a/tests/test_config_store.py b/tests/test_config_store.py index 797a221..4653e32 100644 --- a/tests/test_config_store.py +++ b/tests/test_config_store.py @@ -39,6 +39,7 @@ def setup_master_key(master_key_path: Path, monkeypatch: pytest.MonkeyPatch) -> clear_key_cache() monkeypatch.setenv("CENTRAL_DB_DSN", TEST_DB_DSN) monkeypatch.setenv("CENTRAL_MASTER_KEY_PATH", str(master_key_path)) + monkeypatch.setenv("CENTRAL_CSRF_SECRET", "test-csrf-secret-for-testing-only-32chars") @pytest_asyncio.fixture diff --git a/tests/test_region_picker.py b/tests/test_region_picker.py index f5a8816..63683ea 100644 --- a/tests/test_region_picker.py +++ b/tests/test_region_picker.py @@ -51,13 +51,9 @@ class TestRegionPickerInTemplate: mock_response = MagicMock() mock_templates.TemplateResponse.return_value = mock_response - mock_csrf = MagicMock() - mock_csrf.generate_csrf_tokens.return_value = ("token", "signed") - mock_csrf.set_csrf_cookie = MagicMock() - with patch("central.gui.routes._get_templates", return_value=mock_templates): with patch("central.gui.routes.get_pool", return_value=mock_pool): - result = await adapters_edit_form(mock_request, "firms", mock_csrf) + result = await adapters_edit_form(mock_request, "firms") call_args = mock_templates.TemplateResponse.call_args context = call_args.kwargs.get("context", call_args[1].get("context")) @@ -79,7 +75,9 @@ class TestRegionValidation: mock_request.state.operator = MagicMock(id=1, username="testop") mock_form = MagicMock() + mock_request.state.csrf_token = "test_csrf_token" mock_form.get.side_effect = lambda k, d="": { + "csrf_token": "test_csrf_token", "cadence_s": "300", "api_key_alias": "firms", "region_north": "45.0", @@ -109,9 +107,6 @@ class TestRegionValidation: mock_pool.acquire.return_value.__aenter__ = AsyncMock(return_value=mock_conn) mock_pool.acquire.return_value.__aexit__ = AsyncMock(return_value=None) - mock_csrf = MagicMock() - mock_csrf.validate_csrf = AsyncMock() - captured_settings = {} async def capture_execute(query, *args): @@ -122,7 +117,7 @@ class TestRegionValidation: with patch("central.gui.routes.get_pool", return_value=mock_pool): with patch("central.gui.routes.write_audit", new_callable=AsyncMock): - result = await adapters_edit_submit(mock_request, "firms", mock_csrf) + result = await adapters_edit_submit(mock_request, "firms") assert result.status_code == 302 assert captured_settings["settings"]["region"]["north"] == 45.0 @@ -139,7 +134,9 @@ class TestRegionValidation: mock_request.state.operator = MagicMock(id=1, username="testop") mock_form = MagicMock() + mock_request.state.csrf_token = "test_csrf_token" mock_form.get.side_effect = lambda k, d="": { + "csrf_token": "test_csrf_token", "cadence_s": "300", "api_key_alias": "firms", "region_north": "30.0", # Less than south! @@ -175,14 +172,9 @@ class TestRegionValidation: mock_response.status_code = 200 mock_templates.TemplateResponse.return_value = mock_response - mock_csrf = MagicMock() - mock_csrf.validate_csrf = AsyncMock() - mock_csrf.generate_csrf_tokens.return_value = ("token", "signed") - mock_csrf.set_csrf_cookie = MagicMock() - with patch("central.gui.routes._get_templates", return_value=mock_templates): with patch("central.gui.routes.get_pool", return_value=mock_pool): - result = await adapters_edit_submit(mock_request, "firms", mock_csrf) + result = await adapters_edit_submit(mock_request, "firms") call_args = mock_templates.TemplateResponse.call_args context = call_args.kwargs.get("context", call_args[1].get("context")) @@ -198,7 +190,9 @@ class TestRegionValidation: mock_request.state.operator = MagicMock(id=1, username="testop") mock_form = MagicMock() + mock_request.state.csrf_token = "test_csrf_token" mock_form.get.side_effect = lambda k, d="": { + "csrf_token": "test_csrf_token", "cadence_s": "300", "api_key_alias": "firms", "region_north": "45.0", @@ -234,14 +228,9 @@ class TestRegionValidation: mock_response.status_code = 200 mock_templates.TemplateResponse.return_value = mock_response - mock_csrf = MagicMock() - mock_csrf.validate_csrf = AsyncMock() - mock_csrf.generate_csrf_tokens.return_value = ("token", "signed") - mock_csrf.set_csrf_cookie = MagicMock() - with patch("central.gui.routes._get_templates", return_value=mock_templates): with patch("central.gui.routes.get_pool", return_value=mock_pool): - result = await adapters_edit_submit(mock_request, "firms", mock_csrf) + result = await adapters_edit_submit(mock_request, "firms") call_args = mock_templates.TemplateResponse.call_args context = call_args.kwargs.get("context", call_args[1].get("context")) @@ -257,7 +246,9 @@ class TestRegionValidation: mock_request.state.operator = MagicMock(id=1, username="testop") mock_form = MagicMock() + mock_request.state.csrf_token = "test_csrf_token" mock_form.get.side_effect = lambda k, d="": { + "csrf_token": "test_csrf_token", "cadence_s": "300", "api_key_alias": "firms", "region_north": "95.0", # > 90! @@ -293,14 +284,9 @@ class TestRegionValidation: mock_response.status_code = 200 mock_templates.TemplateResponse.return_value = mock_response - mock_csrf = MagicMock() - mock_csrf.validate_csrf = AsyncMock() - mock_csrf.generate_csrf_tokens.return_value = ("token", "signed") - mock_csrf.set_csrf_cookie = MagicMock() - with patch("central.gui.routes._get_templates", return_value=mock_templates): with patch("central.gui.routes.get_pool", return_value=mock_pool): - result = await adapters_edit_submit(mock_request, "firms", mock_csrf) + result = await adapters_edit_submit(mock_request, "firms") call_args = mock_templates.TemplateResponse.call_args context = call_args.kwargs.get("context", call_args[1].get("context")) @@ -319,7 +305,9 @@ class TestRegionAuditLog: mock_request.state.operator = MagicMock(id=1, username="testop") mock_form = MagicMock() + mock_request.state.csrf_token = "test_csrf_token" mock_form.get.side_effect = lambda k, d="": { + "csrf_token": "test_csrf_token", "cadence_s": "300", "api_key_alias": "firms", "region_north": "45.0", @@ -352,9 +340,6 @@ class TestRegionAuditLog: mock_pool.acquire.return_value.__aenter__ = AsyncMock(return_value=mock_conn) mock_pool.acquire.return_value.__aexit__ = AsyncMock(return_value=None) - mock_csrf = MagicMock() - mock_csrf.validate_csrf = AsyncMock() - captured_audit = {} async def capture_audit(conn, action, operator_id=None, target=None, before=None, after=None): @@ -363,7 +348,7 @@ class TestRegionAuditLog: with patch("central.gui.routes.get_pool", return_value=mock_pool): with patch("central.gui.routes.write_audit", side_effect=capture_audit): - result = await adapters_edit_submit(mock_request, "firms", mock_csrf) + result = await adapters_edit_submit(mock_request, "firms") # Before should have old region assert captured_audit["before"]["settings"]["region"]["north"] == 49.5 diff --git a/tests/test_streams.py b/tests/test_streams.py index c2346fa..528e0ed 100644 --- a/tests/test_streams.py +++ b/tests/test_streams.py @@ -52,10 +52,6 @@ class TestStreamsListAuthenticated: mock_response = MagicMock() mock_templates.TemplateResponse.return_value = mock_response - mock_csrf = MagicMock() - mock_csrf.generate_csrf_tokens.return_value = ("token", "signed") - mock_csrf.set_csrf_cookie = MagicMock() - # Mock JetStream with proper state fields mock_js = AsyncMock() mock_stream_info = MagicMock() @@ -74,7 +70,7 @@ class TestStreamsListAuthenticated: with patch("central.gui.routes._get_templates", return_value=mock_templates): with patch("central.gui.routes.get_pool", return_value=mock_pool): with patch("central.gui.nats.get_js", return_value=mock_js): - result = await streams_list(mock_request, mock_csrf) + result = await streams_list(mock_request) call_args = mock_templates.TemplateResponse.call_args context = call_args.kwargs.get("context", call_args[1].get("context")) @@ -117,14 +113,10 @@ class TestStreamsListNatsUnavailable: mock_response = MagicMock() mock_templates.TemplateResponse.return_value = mock_response - mock_csrf = MagicMock() - mock_csrf.generate_csrf_tokens.return_value = ("token", "signed") - mock_csrf.set_csrf_cookie = MagicMock() - with patch("central.gui.routes._get_templates", return_value=mock_templates): with patch("central.gui.routes.get_pool", return_value=mock_pool): with patch("central.gui.nats.get_js", return_value=None): - result = await streams_list(mock_request, mock_csrf) + result = await streams_list(mock_request) call_args = mock_templates.TemplateResponse.call_args context = call_args.kwargs.get("context", call_args[1].get("context")) @@ -157,10 +149,6 @@ class TestStreamsListPartialFailure: mock_response = MagicMock() mock_templates.TemplateResponse.return_value = mock_response - mock_csrf = MagicMock() - mock_csrf.generate_csrf_tokens.return_value = ("token", "signed") - mock_csrf.set_csrf_cookie = MagicMock() - # Mock JetStream - CENTRAL_FIRE raises ValueError, CENTRAL_WX works mock_js = AsyncMock() test_ts = datetime(2026, 5, 17, 12, 0, 0, tzinfo=timezone.utc) @@ -184,7 +172,7 @@ class TestStreamsListPartialFailure: with patch("central.gui.routes._get_templates", return_value=mock_templates): with patch("central.gui.routes.get_pool", return_value=mock_pool): with patch("central.gui.nats.get_js", return_value=mock_js): - result = await streams_list(mock_request, mock_csrf) + result = await streams_list(mock_request) call_args = mock_templates.TemplateResponse.call_args context = call_args.kwargs.get("context", call_args[1].get("context")) @@ -222,10 +210,6 @@ class TestStreamsListEmptyStream: mock_response = MagicMock() mock_templates.TemplateResponse.return_value = mock_response - mock_csrf = MagicMock() - mock_csrf.generate_csrf_tokens.return_value = ("token", "signed") - mock_csrf.set_csrf_cookie = MagicMock() - # Mock JetStream with empty stream (first_seq = 0) mock_js = AsyncMock() mock_stream_info = MagicMock() @@ -239,7 +223,7 @@ class TestStreamsListEmptyStream: with patch("central.gui.routes._get_templates", return_value=mock_templates): with patch("central.gui.routes.get_pool", return_value=mock_pool): with patch("central.gui.nats.get_js", return_value=mock_js): - result = await streams_list(mock_request, mock_csrf) + result = await streams_list(mock_request) call_args = mock_templates.TemplateResponse.call_args context = call_args.kwargs.get("context", call_args[1].get("context")) @@ -278,10 +262,6 @@ class TestStreamsListSingleMessage: mock_response = MagicMock() mock_templates.TemplateResponse.return_value = mock_response - mock_csrf = MagicMock() - mock_csrf.generate_csrf_tokens.return_value = ("token", "signed") - mock_csrf.set_csrf_cookie = MagicMock() - # Mock JetStream with single message (first_seq == last_seq) mock_js = AsyncMock() mock_stream_info = MagicMock() @@ -299,7 +279,7 @@ class TestStreamsListSingleMessage: with patch("central.gui.routes._get_templates", return_value=mock_templates): with patch("central.gui.routes.get_pool", return_value=mock_pool): with patch("central.gui.nats.get_js", return_value=mock_js): - result = await streams_list(mock_request, mock_csrf) + result = await streams_list(mock_request) call_args = mock_templates.TemplateResponse.call_args context = call_args.kwargs.get("context", call_args[1].get("context")) @@ -337,10 +317,6 @@ class TestStreamsListGetMsgFailure: mock_response = MagicMock() mock_templates.TemplateResponse.return_value = mock_response - mock_csrf = MagicMock() - mock_csrf.generate_csrf_tokens.return_value = ("token", "signed") - mock_csrf.set_csrf_cookie = MagicMock() - # Mock JetStream mock_js = AsyncMock() mock_stream_info = MagicMock() @@ -365,7 +341,7 @@ class TestStreamsListGetMsgFailure: with patch("central.gui.routes._get_templates", return_value=mock_templates): with patch("central.gui.routes.get_pool", return_value=mock_pool): with patch("central.gui.nats.get_js", return_value=mock_js): - result = await streams_list(mock_request, mock_csrf) + result = await streams_list(mock_request) call_args = mock_templates.TemplateResponse.call_args context = call_args.kwargs.get("context", call_args[1].get("context")) @@ -394,8 +370,12 @@ class TestStreamsUpdate: mock_request = MagicMock() mock_request.state.operator = MagicMock(id=1, username="testop") + mock_request.state.csrf_token = "test_csrf_token" mock_form = MagicMock() - mock_form.get.return_value = "1209600" # 14 days + mock_form.get.side_effect = lambda k, d="": { + "csrf_token": "test_csrf_token", + "max_age_s": "1209600", + }.get(k, d) mock_request.form = AsyncMock(return_value=mock_form) mock_conn = AsyncMock() @@ -406,9 +386,6 @@ class TestStreamsUpdate: mock_pool.acquire.return_value.__aenter__ = AsyncMock(return_value=mock_conn) mock_pool.acquire.return_value.__aexit__ = AsyncMock(return_value=None) - mock_csrf = MagicMock() - mock_csrf.validate_csrf = AsyncMock() - captured_audit = {} async def capture_audit(conn, action, operator_id=None, target=None, before=None, after=None): @@ -419,7 +396,7 @@ class TestStreamsUpdate: with patch("central.gui.routes.get_pool", return_value=mock_pool): with patch("central.gui.routes.write_audit", side_effect=capture_audit): - result = await streams_update(mock_request, "CENTRAL_WX", mock_csrf) + result = await streams_update(mock_request, "CENTRAL_WX") assert result.status_code == 302 assert result.headers["location"] == "/streams" @@ -438,8 +415,12 @@ class TestStreamsUpdate: mock_request = MagicMock() mock_request.state.operator = MagicMock(id=1, username="testop") + mock_request.state.csrf_token = "test_csrf_token" mock_form = MagicMock() - mock_form.get.return_value = "60" # 1 minute - too small + mock_form.get.side_effect = lambda k, d="": { + "csrf_token": "test_csrf_token", + "max_age_s": "60", + }.get(k, d) mock_request.form = AsyncMock(return_value=mock_form) mock_conn = AsyncMock() @@ -458,15 +439,10 @@ class TestStreamsUpdate: mock_response.status_code = 200 mock_templates.TemplateResponse.return_value = mock_response - mock_csrf = MagicMock() - mock_csrf.validate_csrf = AsyncMock() - mock_csrf.generate_csrf_tokens.return_value = ("token", "signed") - mock_csrf.set_csrf_cookie = MagicMock() - with patch("central.gui.routes._get_templates", return_value=mock_templates): with patch("central.gui.routes.get_pool", return_value=mock_pool): with patch("central.gui.nats.get_js", return_value=None): - result = await streams_update(mock_request, "CENTRAL_WX", mock_csrf) + result = await streams_update(mock_request, "CENTRAL_WX") call_args = mock_templates.TemplateResponse.call_args context = call_args.kwargs.get("context", call_args[1].get("context")) @@ -480,9 +456,10 @@ class TestStreamsUpdate: mock_request = MagicMock() mock_request.state.operator = MagicMock(id=1, username="testop") + mock_request.state.csrf_token = "test_csrf_token" mock_form = MagicMock() - mock_form.get.return_value = "999999999" # Way too large + mock_form.get.side_effect = lambda k, d="": {"csrf_token": "test_csrf_token", "max_age_s": "999999999"}.get(k, d) # Way too large mock_request.form = AsyncMock(return_value=mock_form) mock_conn = AsyncMock() @@ -501,15 +478,10 @@ class TestStreamsUpdate: mock_response.status_code = 200 mock_templates.TemplateResponse.return_value = mock_response - mock_csrf = MagicMock() - mock_csrf.validate_csrf = AsyncMock() - mock_csrf.generate_csrf_tokens.return_value = ("token", "signed") - mock_csrf.set_csrf_cookie = MagicMock() - with patch("central.gui.routes._get_templates", return_value=mock_templates): with patch("central.gui.routes.get_pool", return_value=mock_pool): with patch("central.gui.nats.get_js", return_value=None): - result = await streams_update(mock_request, "CENTRAL_WX", mock_csrf) + result = await streams_update(mock_request, "CENTRAL_WX") call_args = mock_templates.TemplateResponse.call_args context = call_args.kwargs.get("context", call_args[1].get("context")) @@ -523,8 +495,12 @@ class TestStreamsUpdate: mock_request = MagicMock() mock_request.state.operator = MagicMock(id=1, username="testop") + mock_request.state.csrf_token = "test_csrf_token" mock_form = MagicMock() - mock_form.get.return_value = "604800" + mock_form.get.side_effect = lambda k, d="": { + "csrf_token": "test_csrf_token", + "max_age_s": "604800", + }.get(k, d) mock_request.form = AsyncMock(return_value=mock_form) mock_conn = AsyncMock() @@ -534,11 +510,8 @@ class TestStreamsUpdate: mock_pool.acquire.return_value.__aenter__ = AsyncMock(return_value=mock_conn) mock_pool.acquire.return_value.__aexit__ = AsyncMock(return_value=None) - mock_csrf = MagicMock() - mock_csrf.validate_csrf = AsyncMock() - with patch("central.gui.routes.get_pool", return_value=mock_pool): - result = await streams_update(mock_request, "nonexistent", mock_csrf) + result = await streams_update(mock_request, "nonexistent") assert result.status_code == 404 @@ -554,8 +527,12 @@ class TestStreamsAudit: mock_request = MagicMock() mock_request.state.operator = MagicMock(id=1, username="testop") + mock_request.state.csrf_token = "test_csrf_token" mock_form = MagicMock() - mock_form.get.return_value = "1209600" # 14 days + mock_form.get.side_effect = lambda k, d="": { + "csrf_token": "test_csrf_token", + "max_age_s": "1209600", + }.get(k, d) mock_request.form = AsyncMock(return_value=mock_form) mock_conn = AsyncMock() @@ -566,9 +543,6 @@ class TestStreamsAudit: mock_pool.acquire.return_value.__aenter__ = AsyncMock(return_value=mock_conn) mock_pool.acquire.return_value.__aexit__ = AsyncMock(return_value=None) - mock_csrf = MagicMock() - mock_csrf.validate_csrf = AsyncMock() - captured_audit = {} async def capture_audit(conn, action, operator_id=None, target=None, before=None, after=None): @@ -580,7 +554,7 @@ class TestStreamsAudit: with patch("central.gui.routes.get_pool", return_value=mock_pool): with patch("central.gui.routes.write_audit", side_effect=capture_audit): - await streams_update(mock_request, "CENTRAL_QUAKE", mock_csrf) + await streams_update(mock_request, "CENTRAL_QUAKE") assert captured_audit["action"] == "stream.update" assert captured_audit["operator_id"] == 1 diff --git a/tests/test_wizard.py b/tests/test_wizard.py index f2b3c21..dcaa7fe 100644 --- a/tests/test_wizard.py +++ b/tests/test_wizard.py @@ -4,7 +4,6 @@ from datetime import datetime, timezone from unittest.mock import AsyncMock, MagicMock, patch import pytest - from central.gui.routes import ( setup_operator_form, setup_operator_submit, @@ -87,18 +86,17 @@ class TestSetupOperatorForm: mock_pool.acquire.return_value.__aenter__.return_value = mock_conn mock_pool.acquire.return_value.__aexit__.return_value = None - mock_csrf = MagicMock() - mock_csrf.generate_csrf_tokens.return_value = ("token", "signed") - mock_csrf.set_csrf_cookie = MagicMock() - with patch("central.gui.routes._get_templates", return_value=mock_templates): with patch("central.gui.routes.get_pool", return_value=mock_pool): - result = await setup_operator_form(mock_request, mock_csrf) + with patch("central.gui.routes.get_settings") as mock_settings: + mock_settings.return_value.csrf_secret = "testsecret" + with patch("central.gui.routes.generate_pre_auth_csrf", return_value=("test_token", "signed_token")): + result = await setup_operator_form(mock_request) mock_templates.TemplateResponse.assert_called_once() call_args = mock_templates.TemplateResponse.call_args context = call_args.kwargs.get("context", call_args[1].get("context")) - assert context["csrf_token"] == "token" + assert "csrf_token" in context and context["csrf_token"] assert context["error"] is None assert context["existing_operator"] is None @@ -119,13 +117,12 @@ class TestSetupOperatorForm: mock_pool.acquire.return_value.__aenter__.return_value = mock_conn mock_pool.acquire.return_value.__aexit__.return_value = None - mock_csrf = MagicMock() - mock_csrf.generate_csrf_tokens.return_value = ("token", "signed") - mock_csrf.set_csrf_cookie = MagicMock() - with patch("central.gui.routes._get_templates", return_value=mock_templates): with patch("central.gui.routes.get_pool", return_value=mock_pool): - result = await setup_operator_form(mock_request, mock_csrf) + with patch("central.gui.routes.get_settings") as mock_settings: + mock_settings.return_value.csrf_secret = "testsecret" + with patch("central.gui.routes.generate_pre_auth_csrf", return_value=("test_token", "signed_token")): + result = await setup_operator_form(mock_request) mock_templates.TemplateResponse.assert_called_once() call_args = mock_templates.TemplateResponse.call_args @@ -141,6 +138,13 @@ class TestSetupOperatorSubmit: async def test_password_mismatch_shows_error(self): """POST with password mismatch re-renders with error.""" mock_request = MagicMock() + mock_request.state.csrf_token = "test_csrf" + mock_request.form = AsyncMock(return_value={ + "csrf_token": "test_csrf", + "username": "testuser", + "password": "password1", + "confirm_password": "password2", # Mismatch + }) mock_templates = MagicMock() mock_templates.TemplateResponse.return_value = MagicMock() @@ -151,20 +155,17 @@ class TestSetupOperatorSubmit: mock_pool.acquire.return_value.__aenter__.return_value = mock_conn mock_pool.acquire.return_value.__aexit__.return_value = None - mock_csrf = MagicMock() - mock_csrf.validate_csrf = AsyncMock() - mock_csrf.generate_csrf_tokens.return_value = ("token", "signed") - mock_csrf.set_csrf_cookie = MagicMock() - with patch("central.gui.routes._get_templates", return_value=mock_templates): with patch("central.gui.routes.get_pool", return_value=mock_pool): - result = await setup_operator_submit( - mock_request, - username="admin", - password="password123", - confirm_password="different", - csrf_protect=mock_csrf, - ) + with patch("central.gui.routes.validate_pre_auth_csrf", return_value=True): + with patch("central.gui.routes.get_settings") as mock_settings: + mock_settings.return_value.csrf_secret = "testsecret" + result = await setup_operator_submit( + mock_request, + username="testuser", + password="password1", + confirm_password="password2", + ) call_args = mock_templates.TemplateResponse.call_args context = call_args.kwargs.get("context", call_args[1].get("context")) @@ -174,6 +175,13 @@ class TestSetupOperatorSubmit: async def test_valid_creates_operator_and_redirects(self): """POST with valid data creates operator and redirects to /setup/system.""" mock_request = MagicMock() + mock_request.state.csrf_token = "test_csrf" + mock_request.form = AsyncMock(return_value={ + "csrf_token": "test_csrf", + "username": "testuser", + "password": "password123", + "confirm_password": "password123", + }) mock_conn = AsyncMock() mock_conn.fetchval.return_value = 0 # No existing operators @@ -186,21 +194,20 @@ class TestSetupOperatorSubmit: mock_pool.acquire.return_value.__aenter__.return_value = mock_conn mock_pool.acquire.return_value.__aexit__.return_value = None - mock_csrf = MagicMock() - mock_csrf.validate_csrf = AsyncMock() - with patch("central.gui.routes.get_pool", return_value=mock_pool): - with patch("central.gui.routes.hash_password", return_value="hashed"): - with patch("central.gui.routes.create_session", new_callable=AsyncMock) as mock_session: - mock_session.return_value = ("session_token", datetime.now()) - with patch("central.gui.routes.write_audit", new_callable=AsyncMock): - result = await setup_operator_submit( - mock_request, - username="admin", - password="password123", - confirm_password="password123", - csrf_protect=mock_csrf, - ) + with patch("central.gui.routes.validate_pre_auth_csrf", return_value=True): + with patch("central.gui.routes.get_settings") as mock_settings: + mock_settings.return_value.csrf_secret = "testsecret" + with patch("central.gui.routes.hash_password", return_value="hashed"): + with patch("central.gui.routes.create_session", new_callable=AsyncMock) as mock_session: + mock_session.return_value = ("session_token", datetime.now(), "csrf_token") + with patch("central.gui.routes.write_audit", new_callable=AsyncMock): + result = await setup_operator_submit( + mock_request, + username="testuser", + password="password123", + confirm_password="password123", + ) assert result.status_code == 302 assert result.headers["location"] == "/setup/system" @@ -209,6 +216,12 @@ class TestSetupOperatorSubmit: async def test_post_when_operator_exists_shows_confirmation(self): """POST when operator exists returns 200 with confirmation, no insert.""" mock_request = MagicMock() + mock_request.form = AsyncMock(return_value={ + "csrf_token": "test_csrf", + "username": "testuser", + "password": "password123", + "confirm_password": "password123", + }) mock_templates = MagicMock() mock_response = MagicMock() @@ -223,21 +236,19 @@ class TestSetupOperatorSubmit: mock_pool.acquire.return_value.__aenter__.return_value = mock_conn mock_pool.acquire.return_value.__aexit__.return_value = None - mock_csrf = MagicMock() - mock_csrf.validate_csrf = AsyncMock() - mock_csrf.generate_csrf_tokens.return_value = ("token", "signed") - mock_csrf.set_csrf_cookie = MagicMock() - + mock_request.state.csrf_token = "test_csrf" with patch("central.gui.routes._get_templates", return_value=mock_templates): with patch("central.gui.routes.get_pool", return_value=mock_pool): - with patch("central.gui.routes.write_audit", new_callable=AsyncMock) as mock_audit: - result = await setup_operator_submit( - mock_request, - username="newadmin", - password="password123", - confirm_password="password123", - csrf_protect=mock_csrf, - ) + with patch("central.gui.routes.validate_pre_auth_csrf", return_value=True): + with patch("central.gui.routes.get_settings") as mock_settings: + mock_settings.return_value.csrf_secret = "testsecret" + with patch("central.gui.routes.write_audit", new_callable=AsyncMock) as mock_audit: + result = await setup_operator_submit( + mock_request, + username="testuser", + password="password123", + confirm_password="password123", + ) # Should return 200, not 500 or redirect assert result.status_code == 200 @@ -259,10 +270,7 @@ class TestSetupSystemForm: """GET /setup/system without auth redirects to /setup/operator.""" mock_request = MagicMock() mock_request.state.operator = None - - mock_csrf = MagicMock() - - result = await setup_system_form(mock_request, mock_csrf) + result = await setup_system_form(mock_request) assert result.status_code == 302 assert result.headers["location"] == "/setup/operator" @@ -285,13 +293,9 @@ class TestSetupSystemForm: mock_pool.acquire.return_value.__aenter__.return_value = mock_conn mock_pool.acquire.return_value.__aexit__.return_value = None - mock_csrf = MagicMock() - mock_csrf.generate_csrf_tokens.return_value = ("token", "signed") - mock_csrf.set_csrf_cookie = MagicMock() - with patch("central.gui.routes._get_templates", return_value=mock_templates): with patch("central.gui.routes.get_pool", return_value=mock_pool): - result = await setup_system_form(mock_request, mock_csrf) + result = await setup_system_form(mock_request) mock_templates.TemplateResponse.assert_called_once() @@ -304,9 +308,11 @@ class TestSetupSystemSubmit: """POST without {z},{x},{y} placeholders shows error.""" mock_request = MagicMock() mock_request.state.operator = MagicMock(id=1, username="admin") + mock_request.state.csrf_token = "test_csrf_token" form_data = MagicMock() form_data.get = lambda k, default="": { + "csrf_token": "test_csrf_token", "map_tile_url": "https://example.com/tiles", "map_attribution": "Test", }.get(k, default) @@ -325,14 +331,9 @@ class TestSetupSystemSubmit: mock_pool.acquire.return_value.__aenter__.return_value = mock_conn mock_pool.acquire.return_value.__aexit__.return_value = None - mock_csrf = MagicMock() - mock_csrf.validate_csrf = AsyncMock() - mock_csrf.generate_csrf_tokens.return_value = ("token", "signed") - mock_csrf.set_csrf_cookie = MagicMock() - with patch("central.gui.routes._get_templates", return_value=mock_templates): with patch("central.gui.routes.get_pool", return_value=mock_pool): - result = await setup_system_submit(mock_request, mock_csrf) + result = await setup_system_submit(mock_request) call_args = mock_templates.TemplateResponse.call_args context = call_args.kwargs.get("context", call_args[1].get("context")) @@ -343,9 +344,11 @@ class TestSetupSystemSubmit: """POST with valid data updates system and redirects to /setup/keys.""" mock_request = MagicMock() mock_request.state.operator = MagicMock(id=1, username="admin") + mock_request.state.csrf_token = "test_csrf_token" form_data = MagicMock() form_data.get = lambda k, default="": { + "csrf_token": "test_csrf_token", "map_tile_url": "https://example.com/{z}/{x}/{y}.png", "map_attribution": "Test Attribution", }.get(k, default) @@ -362,12 +365,9 @@ class TestSetupSystemSubmit: mock_pool.acquire.return_value.__aenter__.return_value = mock_conn mock_pool.acquire.return_value.__aexit__.return_value = None - mock_csrf = MagicMock() - mock_csrf.validate_csrf = AsyncMock() - with patch("central.gui.routes.get_pool", return_value=mock_pool): with patch("central.gui.routes.write_audit", new_callable=AsyncMock): - result = await setup_system_submit(mock_request, mock_csrf) + result = await setup_system_submit(mock_request) assert result.status_code == 302 assert result.headers["location"] == "/setup/keys" @@ -381,10 +381,7 @@ class TestSetupKeysForm: """GET /setup/keys without auth redirects to /setup/operator.""" mock_request = MagicMock() mock_request.state.operator = None - - mock_csrf = MagicMock() - - result = await setup_keys_form(mock_request, mock_csrf) + result = await setup_keys_form(mock_request) assert result.status_code == 302 assert result.headers["location"] == "/setup/operator" @@ -397,16 +394,17 @@ class TestSetupKeysSubmit: """POST with action=next redirects to /setup/adapters.""" mock_request = MagicMock() mock_request.state.operator = MagicMock(id=1, username="admin") + mock_request.state.csrf_token = "test_csrf_token" form_data = MagicMock() - form_data.get = lambda k, default="": {"action": "next"}.get(k, default) + form_data.get = lambda k, default="": { + "csrf_token": "test_csrf_token", + "action": "next", + }.get(k, default) mock_request.form = AsyncMock(return_value=form_data) - mock_csrf = MagicMock() - mock_csrf.validate_csrf = AsyncMock() - # No need to mock get_pool since action="next" returns before it's called - result = await setup_keys_submit(mock_request, mock_csrf) + result = await setup_keys_submit(mock_request) assert result.status_code == 302 assert result.headers["location"] == "/setup/adapters" @@ -415,9 +413,11 @@ class TestSetupKeysSubmit: """POST with action=add creates key and re-renders with success.""" mock_request = MagicMock() mock_request.state.operator = MagicMock(id=1, username="admin") + mock_request.state.csrf_token = "test_csrf_token" form_data = MagicMock() form_data.get = lambda k, default="": { + "csrf_token": "test_csrf_token", "action": "add", "alias": "testkey", "plaintext_key": "secret123", @@ -441,16 +441,11 @@ class TestSetupKeysSubmit: mock_pool.acquire.return_value.__aenter__.return_value = mock_conn mock_pool.acquire.return_value.__aexit__.return_value = None - mock_csrf = MagicMock() - mock_csrf.validate_csrf = AsyncMock() - mock_csrf.generate_csrf_tokens.return_value = ("token", "signed") - mock_csrf.set_csrf_cookie = MagicMock() - with patch("central.gui.routes._get_templates", return_value=mock_templates): with patch("central.gui.routes.get_pool", return_value=mock_pool): with patch("central.crypto.encrypt", return_value=b"encrypted"): with patch("central.gui.routes.write_audit", new_callable=AsyncMock): - result = await setup_keys_submit(mock_request, mock_csrf) + result = await setup_keys_submit(mock_request) call_args = mock_templates.TemplateResponse.call_args context = call_args.kwargs.get("context", call_args[1].get("context")) @@ -465,10 +460,7 @@ class TestSetupAdaptersForm: """GET /setup/adapters without auth redirects to /setup/operator.""" mock_request = MagicMock() mock_request.state.operator = None - - mock_csrf = MagicMock() - - result = await setup_adapters_form(mock_request, mock_csrf) + result = await setup_adapters_form(mock_request) assert result.status_code == 302 assert result.headers["location"] == "/setup/operator" @@ -481,10 +473,7 @@ class TestSetupFinishForm: """GET /setup/finish without auth redirects to /setup/operator.""" mock_request = MagicMock() mock_request.state.operator = None - - mock_csrf = MagicMock() - - result = await setup_finish_form(mock_request, mock_csrf) + result = await setup_finish_form(mock_request) assert result.status_code == 302 assert result.headers["location"] == "/setup/operator" @@ -509,13 +498,9 @@ class TestSetupFinishForm: mock_pool.acquire.return_value.__aenter__.return_value = mock_conn mock_pool.acquire.return_value.__aexit__.return_value = None - mock_csrf = MagicMock() - mock_csrf.generate_csrf_tokens.return_value = ("token", "signed") - mock_csrf.set_csrf_cookie = MagicMock() - with patch("central.gui.routes._get_templates", return_value=mock_templates): with patch("central.gui.routes.get_pool", return_value=mock_pool): - result = await setup_finish_form(mock_request, mock_csrf) + result = await setup_finish_form(mock_request) call_args = mock_templates.TemplateResponse.call_args context = call_args.kwargs.get("context", call_args[1].get("context")) @@ -532,6 +517,12 @@ class TestSetupFinishSubmit: """POST /setup/finish marks setup_complete=true and redirects to /.""" mock_request = MagicMock() mock_request.state.operator = MagicMock(id=1, username="admin") + mock_request.state.csrf_token = "test_csrf_token" + + # Mock form with CSRF token + form_data = MagicMock() + form_data.get = lambda k, default="": {"csrf_token": "test_csrf_token"}.get(k, default) + mock_request.form = AsyncMock(return_value=form_data) mock_conn = AsyncMock() mock_conn.execute = AsyncMock() @@ -540,12 +531,9 @@ class TestSetupFinishSubmit: mock_pool.acquire.return_value.__aenter__.return_value = mock_conn mock_pool.acquire.return_value.__aexit__.return_value = None - mock_csrf = MagicMock() - mock_csrf.validate_csrf = AsyncMock() - with patch("central.gui.routes.get_pool", return_value=mock_pool): with patch("central.gui.routes.write_audit", new_callable=AsyncMock) as mock_audit: - result = await setup_finish_submit(mock_request, mock_csrf) + result = await setup_finish_submit(mock_request) assert result.status_code == 302 assert result.headers["location"] == "/"