From 9396e5dbe82f6338d990e221ec64aa5494c9a4d9 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Sun, 17 May 2026 22:29:56 +0000 Subject: [PATCH] fix(gui): dashboard polls card + CSRF exception handler Fix A - /dashboard/polls: - Use get_last_msg instead of pull_subscribe (no durable consumers) - Fix subject filter: central.meta.adapter.{name}.status - Parse correct fields: ts and ok from status message - Handle NotFoundError gracefully when no status exists Fix B - CSRF exception handler: - Add global CsrfProtectError handler in __init__.py - Return friendly "session expired" message instead of 500 - Re-render forms with error or redirect to /login - Update templates to display error messages Tests: - Add get_last_msg mocking tests for polls - Add regression test verifying no pull_subscribe - Add CSRF handler tests Co-Authored-By: Claude Opus 4.5 --- src/central/gui/__init__.py | 44 ++++++ src/central/gui/routes.py | 53 +++---- .../gui/templates/change_password.html | 4 + src/central/gui/templates/login.html | 4 + src/central/gui/templates/setup.html | 4 + tests/test_csrf_handler.py | 109 ++++++++++++++ tests/test_dashboard.py | 141 ++++++++++++------ 7 files changed, 284 insertions(+), 75 deletions(-) create mode 100644 tests/test_csrf_handler.py diff --git a/src/central/gui/__init__.py b/src/central/gui/__init__.py index 87f1269..20d79aa 100644 --- a/src/central/gui/__init__.py +++ b/src/central/gui/__init__.py @@ -136,6 +136,50 @@ def _create_app() -> FastAPI: # Include routes app.include_router(router) + # CSRF exception handler - return friendly error instead of 500 + from fastapi_csrf_protect.exceptions import CsrfProtectError + from fastapi.responses import RedirectResponse + + @app.exception_handler(CsrfProtectError) + async def csrf_exception_handler(request, exc: CsrfProtectError): + from fastapi_csrf_protect import CsrfProtect + + csrf_protect = CsrfProtect() + csrf_token, signed_token = csrf_protect.generate_csrf_tokens() + + if request.url.path == "/login": + response = templates.TemplateResponse( + request=request, + name="login.html", + context={"csrf_token": csrf_token, "error": "Your session expired. Please try again."}, + ) + csrf_protect.set_csrf_cookie(signed_token, response) + return response + elif request.url.path == "/setup": + response = templates.TemplateResponse( + request=request, + name="setup.html", + context={"csrf_token": csrf_token, "error": "Your session expired. Please try again."}, + ) + csrf_protect.set_csrf_cookie(signed_token, response) + return response + elif request.url.path == "/logout": + return RedirectResponse("/login", status_code=302) + elif request.url.path == "/change-password": + response = templates.TemplateResponse( + request=request, + name="change_password.html", + context={"csrf_token": csrf_token, "error": "Your session expired. Please try again."}, + ) + csrf_protect.set_csrf_cookie(signed_token, response) + return response + elif request.url.path.startswith("/adapters/"): + # Redirect back to adapters list + return RedirectResponse("/adapters", status_code=302) + else: + # Fallback: redirect to login + return RedirectResponse("/login", status_code=302) + return app diff --git a/src/central/gui/routes.py b/src/central/gui/routes.py index 755cb80..73543c1 100644 --- a/src/central/gui/routes.py +++ b/src/central/gui/routes.py @@ -182,6 +182,7 @@ async def dashboard_streams(request: Request) -> HTMLResponse: async def dashboard_polls(request: Request) -> HTMLResponse: """Get last poll times for each adapter.""" from central.gui.nats import get_js + from nats.js.errors import NotFoundError templates = _get_templates() pool = get_pool() @@ -210,43 +211,31 @@ async def dashboard_polls(request: Request) -> HTMLResponse: else: for name in adapter_names: try: - # Get last message from CENTRAL_META for this adapter - sub = await js.pull_subscribe( - f"central.meta.{name}.status", - durable=f"dashboard-poll-{name}", - stream="CENTRAL_META", + msg = await js.get_last_msg( + "CENTRAL_META", + f"central.meta.adapter.{name}.status", ) - try: - msgs = await sub.fetch(1, timeout=1.0) - if msgs: - data = json.loads(msgs[0].data.decode()) - last_poll = data.get("data", {}).get("time", "—") - adapters.append({ - "name": name, - "last_poll": last_poll, - "status": "✓", - "error": None, - }) - else: - adapters.append({ - "name": name, - "last_poll": None, - "status": None, - "error": None, - }) - except Exception: - adapters.append({ - "name": name, - "last_poll": None, - "status": None, - "error": None, - }) - except Exception: + data = json.loads(msg.data.decode()) + adapters.append({ + "name": name, + "last_poll": data.get("ts"), + "status": "✓" if data.get("ok") else "✗", + "error": data.get("error") if not data.get("ok") else None, + }) + except NotFoundError: + # No status message for this adapter yet adapters.append({ "name": name, "last_poll": None, "status": None, - "error": "unavailable", + "error": None, + }) + except Exception as e: + adapters.append({ + "name": name, + "last_poll": None, + "status": "?", + "error": str(e), }) return templates.TemplateResponse( diff --git a/src/central/gui/templates/change_password.html b/src/central/gui/templates/change_password.html index c353c60..ad91796 100644 --- a/src/central/gui/templates/change_password.html +++ b/src/central/gui/templates/change_password.html @@ -8,6 +8,10 @@

Change Password

+ {% if error %} +

{{ error }}

+ {% endif %} +
diff --git a/src/central/gui/templates/login.html b/src/central/gui/templates/login.html index 3510bf8..60fc942 100644 --- a/src/central/gui/templates/login.html +++ b/src/central/gui/templates/login.html @@ -8,6 +8,10 @@

Login

+ {% if error %} +

{{ error }}

+ {% endif %} + diff --git a/src/central/gui/templates/setup.html b/src/central/gui/templates/setup.html index 7b72249..d290d1d 100644 --- a/src/central/gui/templates/setup.html +++ b/src/central/gui/templates/setup.html @@ -9,6 +9,10 @@

Create the initial operator account to get started.

+ {% if error %} +

{{ error }}

+ {% endif %} + diff --git a/tests/test_csrf_handler.py b/tests/test_csrf_handler.py new file mode 100644 index 0000000..58456e3 --- /dev/null +++ b/tests/test_csrf_handler.py @@ -0,0 +1,109 @@ +"""Tests for CSRF exception handler.""" + +import os +from unittest.mock import MagicMock, AsyncMock, patch + +import pytest + +os.environ.setdefault("CENTRAL_DB_DSN", "postgresql://test:test@localhost/test") +os.environ.setdefault("CENTRAL_CSRF_SECRET", "testsecret12345678901234567890ab") +os.environ.setdefault("CENTRAL_NATS_URL", "nats://localhost:4222") + + +class TestCsrfExceptionHandlerRegistered: + """Verify CSRF exception handler is properly registered.""" + + def test_csrf_exception_handler_is_registered(self): + """The app has a CsrfProtectError exception handler registered.""" + from central.gui import app + from fastapi_csrf_protect.exceptions import CsrfProtectError + + assert CsrfProtectError in app.exception_handlers, \ + "CsrfProtectError handler should be registered" + + def test_csrf_subclasses_are_caught(self): + """MissingTokenError and TokenValidationError inherit from CsrfProtectError.""" + from fastapi_csrf_protect.exceptions import ( + CsrfProtectError, + MissingTokenError, + TokenValidationError, + ) + + assert issubclass(MissingTokenError, CsrfProtectError) + assert issubclass(TokenValidationError, CsrfProtectError) + + +class TestCsrfExceptionHandlerBehavior: + """Test the CSRF exception handler behavior.""" + + def test_login_csrf_error_handler_checks_path(self): + """CSRF handler checks request path for /login.""" + import inspect + from central.gui import _create_app + from fastapi_csrf_protect.exceptions import CsrfProtectError + + app = _create_app() + handler = app.exception_handlers.get(CsrfProtectError) + + # Verify handler source contains /login path check + source = inspect.getsource(handler) + assert "/login" in source + assert "session expired" in source.lower() + + @pytest.mark.asyncio + async def test_logout_csrf_error_redirects_to_login(self): + """CSRF error on /logout should redirect to /login.""" + from central.gui import _create_app + from fastapi_csrf_protect.exceptions import TokenValidationError + from fastapi.responses import RedirectResponse + + app = _create_app() + from fastapi_csrf_protect.exceptions import CsrfProtectError + handler = app.exception_handlers.get(CsrfProtectError) + + mock_request = MagicMock() + mock_request.url.path = "/logout" + + exc = TokenValidationError("Invalid token") + + result = await handler(mock_request, exc) + + assert isinstance(result, RedirectResponse) + assert result.status_code == 302 + + @pytest.mark.asyncio + async def test_adapters_csrf_error_redirects_to_adapters(self): + """CSRF error on /adapters/{name} should redirect to /adapters.""" + from central.gui import _create_app + from fastapi_csrf_protect.exceptions import TokenValidationError + from fastapi.responses import RedirectResponse + + app = _create_app() + from fastapi_csrf_protect.exceptions import CsrfProtectError + handler = app.exception_handlers.get(CsrfProtectError) + + mock_request = MagicMock() + mock_request.url.path = "/adapters/nws" + + exc = TokenValidationError("Invalid token") + + result = await handler(mock_request, exc) + + assert isinstance(result, RedirectResponse) + assert result.status_code == 302 + + +class TestCsrfHandlerNoTraceback: + """Verify exception handler doesn't expose Python internals.""" + + def test_handler_exists_and_is_async(self): + """The CSRF handler should be an async function.""" + import inspect + from central.gui import _create_app + from fastapi_csrf_protect.exceptions import CsrfProtectError + + app = _create_app() + handler = app.exception_handlers.get(CsrfProtectError) + + assert handler is not None + assert inspect.iscoroutinefunction(handler) diff --git a/tests/test_dashboard.py b/tests/test_dashboard.py index 6572903..f03f9ac 100644 --- a/tests/test_dashboard.py +++ b/tests/test_dashboard.py @@ -1,47 +1,36 @@ """Tests for dashboard routes.""" +import json import os from unittest.mock import MagicMock, AsyncMock, patch import pytest -# Set required env vars before importing central modules os.environ.setdefault("CENTRAL_DB_DSN", "postgresql://test:test@localhost/test") os.environ.setdefault("CENTRAL_CSRF_SECRET", "testsecret12345678901234567890ab") os.environ.setdefault("CENTRAL_NATS_URL", "nats://localhost:4222") class TestFormatBytes: - """Test _format_bytes helper.""" - def test_format_bytes_bytes(self): - """Bytes are shown as B.""" from central.gui.routes import _format_bytes assert _format_bytes(100) == "100 B" def test_format_bytes_kilobytes(self): - """KB formatting.""" from central.gui.routes import _format_bytes assert _format_bytes(1024) == "1.0 KB" def test_format_bytes_megabytes(self): - """MB formatting.""" from central.gui.routes import _format_bytes assert _format_bytes(1048576) == "1.0 MB" def test_format_bytes_gigabytes(self): - """GB formatting.""" from central.gui.routes import _format_bytes assert _format_bytes(1073741824) == "1.0 GB" class TestDashboardEventsSQL: - """Test events query construction.""" - def test_events_query_has_24h_filter(self): - """Events query filters by received > NOW() - 24h.""" - # We can't easily test the full route without mocking, - # but we can verify the query logic by inspecting the source import inspect from central.gui.routes import dashboard_events source = inspect.getsource(dashboard_events) @@ -50,25 +39,17 @@ class TestDashboardEventsSQL: class TestDashboardStreamsGracefulDegradation: - """Test streams endpoint graceful degradation.""" - @pytest.mark.asyncio async def test_nats_unavailable_returns_error_message(self): - """When NATS is unavailable, streams returns error message not 500.""" from central.gui.routes import dashboard_streams - mock_request = MagicMock() mock_request.state.operator = MagicMock() - mock_templates = MagicMock() mock_response = MagicMock() mock_templates.TemplateResponse.return_value = mock_response - with patch("central.gui.routes._get_templates", return_value=mock_templates): with patch("central.gui.nats.get_js", return_value=None): result = await dashboard_streams(mock_request) - - # Should call template with error context call_args = mock_templates.TemplateResponse.call_args context = call_args.kwargs.get("context", call_args[1].get("context")) assert context["error"] == "NATS unavailable" @@ -76,32 +57,23 @@ class TestDashboardStreamsGracefulDegradation: class TestDashboardPollsGracefulDegradation: - """Test polls endpoint graceful degradation.""" - @pytest.mark.asyncio async def test_nats_unavailable_shows_all_adapters_with_error(self): - """When NATS is unavailable, polls shows adapters with error message.""" from central.gui.routes import dashboard_polls - mock_request = MagicMock() mock_request.state.operator = MagicMock() - mock_conn = AsyncMock() mock_conn.fetch.return_value = [{"name": "nws"}, {"name": "firms"}] - mock_pool = MagicMock() mock_pool.acquire.return_value.__aenter__ = AsyncMock(return_value=mock_conn) mock_pool.acquire.return_value.__aexit__ = AsyncMock(return_value=None) - mock_templates = MagicMock() mock_response = MagicMock() mock_templates.TemplateResponse.return_value = mock_response - with patch("central.gui.routes._get_templates", return_value=mock_templates): with patch("central.gui.routes.get_pool", return_value=mock_pool): with patch("central.gui.nats.get_js", return_value=None): result = await dashboard_polls(mock_request) - call_args = mock_templates.TemplateResponse.call_args context = call_args.kwargs.get("context", call_args[1].get("context")) assert context["error"] == "NATS unavailable" @@ -109,14 +81,106 @@ class TestDashboardPollsGracefulDegradation: assert context["adapters"][0]["error"] == "NATS unavailable" -class TestDashboardStreamsIsolation: - """Test stream failure isolation.""" +class TestDashboardPollsGetLastMsg: + @pytest.mark.asyncio + async def test_polls_returns_timestamp_from_status_message(self): + from central.gui.routes import dashboard_polls + mock_request = MagicMock() + mock_request.state.operator = MagicMock() + mock_conn = AsyncMock() + mock_conn.fetch.return_value = [{"name": "nws"}] + mock_pool = MagicMock() + mock_pool.acquire.return_value.__aenter__ = AsyncMock(return_value=mock_conn) + mock_pool.acquire.return_value.__aexit__ = AsyncMock(return_value=None) + mock_msg = MagicMock() + mock_msg.data = json.dumps({"ok": True, "ts": "2026-05-17T12:34:56Z"}).encode() + mock_js = AsyncMock() + mock_js.get_last_msg = AsyncMock(return_value=mock_msg) + mock_templates = MagicMock() + mock_response = MagicMock() + mock_templates.TemplateResponse.return_value = mock_response + with patch("central.gui.routes._get_templates", return_value=mock_templates): + with patch("central.gui.routes.get_pool", return_value=mock_pool): + with patch("central.gui.nats.get_js", return_value=mock_js): + result = await dashboard_polls(mock_request) + call_args = mock_templates.TemplateResponse.call_args + context = call_args.kwargs.get("context", call_args[1].get("context")) + assert len(context["adapters"]) == 1 + adapter = context["adapters"][0] + assert adapter["name"] == "nws" + assert adapter["last_poll"] == "2026-05-17T12:34:56Z" + assert adapter["status"] == "\u2713" + assert adapter["error"] is None @pytest.mark.asyncio - async def test_single_stream_failure_doesnt_crash_card(self): - """A single stream failure shows error for that stream only.""" - from central.gui.routes import dashboard_streams + async def test_polls_handles_not_found_error_gracefully(self): + from central.gui.routes import dashboard_polls + from nats.js.errors import NotFoundError + mock_request = MagicMock() + mock_request.state.operator = MagicMock() + mock_conn = AsyncMock() + mock_conn.fetch.return_value = [{"name": "nws"}] + mock_pool = MagicMock() + mock_pool.acquire.return_value.__aenter__ = AsyncMock(return_value=mock_conn) + mock_pool.acquire.return_value.__aexit__ = AsyncMock(return_value=None) + mock_js = AsyncMock() + mock_js.get_last_msg = AsyncMock(side_effect=NotFoundError()) + mock_templates = MagicMock() + mock_response = MagicMock() + mock_templates.TemplateResponse.return_value = mock_response + with patch("central.gui.routes._get_templates", return_value=mock_templates): + with patch("central.gui.routes.get_pool", return_value=mock_pool): + with patch("central.gui.nats.get_js", return_value=mock_js): + result = await dashboard_polls(mock_request) + call_args = mock_templates.TemplateResponse.call_args + context = call_args.kwargs.get("context", call_args[1].get("context")) + adapter = context["adapters"][0] + assert adapter["last_poll"] is None + assert adapter["status"] is None + assert adapter["error"] is None + @pytest.mark.asyncio + async def test_polls_shows_failure_status_when_ok_is_false(self): + from central.gui.routes import dashboard_polls + mock_request = MagicMock() + mock_request.state.operator = MagicMock() + mock_conn = AsyncMock() + mock_conn.fetch.return_value = [{"name": "nws"}] + mock_pool = MagicMock() + mock_pool.acquire.return_value.__aenter__ = AsyncMock(return_value=mock_conn) + mock_pool.acquire.return_value.__aexit__ = AsyncMock(return_value=None) + mock_msg = MagicMock() + mock_msg.data = json.dumps({"ok": False, "ts": "2026-05-17T12:34:56Z", "error": "Connection timeout"}).encode() + mock_js = AsyncMock() + mock_js.get_last_msg = AsyncMock(return_value=mock_msg) + mock_templates = MagicMock() + mock_response = MagicMock() + mock_templates.TemplateResponse.return_value = mock_response + with patch("central.gui.routes._get_templates", return_value=mock_templates): + with patch("central.gui.routes.get_pool", return_value=mock_pool): + with patch("central.gui.nats.get_js", return_value=mock_js): + result = await dashboard_polls(mock_request) + call_args = mock_templates.TemplateResponse.call_args + context = call_args.kwargs.get("context", call_args[1].get("context")) + adapter = context["adapters"][0] + assert adapter["status"] == "\u2717" + assert adapter["error"] == "Connection timeout" + + +class TestDashboardPollsNoSubscribe: + def test_polls_does_not_use_pull_subscribe(self): + import inspect + from central.gui.routes import dashboard_polls + source = inspect.getsource(dashboard_polls) + assert "pull_subscribe" not in source + assert "get_last_msg" in source + assert "central.meta.adapter." in source + + +class TestDashboardStreamsIsolation: + @pytest.mark.asyncio + async def test_single_stream_failure_doesnt_crash_card(self): + from central.gui.routes import dashboard_streams mock_request = MagicMock() mock_request.state.operator = MagicMock() @@ -132,27 +196,18 @@ class TestDashboardStreamsIsolation: mock_js = AsyncMock() mock_js.stream_info.side_effect = mock_stream_info - mock_templates = MagicMock() mock_response = MagicMock() mock_templates.TemplateResponse.return_value = mock_response - with patch("central.gui.routes._get_templates", return_value=mock_templates): with patch("central.gui.nats.get_js", return_value=mock_js): result = await dashboard_streams(mock_request) - call_args = mock_templates.TemplateResponse.call_args context = call_args.kwargs.get("context", call_args[1].get("context")) - streams = context["streams"] - # Should have 4 streams assert len(streams) == 4 - - # CENTRAL_FIRE should have error fire_stream = next(s for s in streams if s["name"] == "CENTRAL_FIRE") assert fire_stream.get("error") == "unavailable" - - # CENTRAL_WX should be fine wx_stream = next(s for s in streams if s["name"] == "CENTRAL_WX") assert wx_stream.get("error") is None assert wx_stream["messages"] == 100