fix(gui): dashboard polls card + CSRF exception handler

Fix A - /dashboard/polls:
- Use get_last_msg instead of pull_subscribe (no durable consumers)
- Fix subject filter: central.meta.adapter.{name}.status
- Parse correct fields: ts and ok from status message
- Handle NotFoundError gracefully when no status exists

Fix B - CSRF exception handler:
- Add global CsrfProtectError handler in __init__.py
- Return friendly "session expired" message instead of 500
- Re-render forms with error or redirect to /login
- Update templates to display error messages

Tests:
- Add get_last_msg mocking tests for polls
- Add regression test verifying no pull_subscribe
- Add CSRF handler tests

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Ubuntu 2026-05-17 22:29:56 +00:00
commit 9396e5dbe8
7 changed files with 283 additions and 74 deletions

View file

@ -136,6 +136,50 @@ def _create_app() -> FastAPI:
# Include routes # Include routes
app.include_router(router) app.include_router(router)
# CSRF exception handler - return friendly error instead of 500
from fastapi_csrf_protect.exceptions import CsrfProtectError
from fastapi.responses import RedirectResponse
@app.exception_handler(CsrfProtectError)
async def csrf_exception_handler(request, exc: CsrfProtectError):
from fastapi_csrf_protect import CsrfProtect
csrf_protect = CsrfProtect()
csrf_token, signed_token = csrf_protect.generate_csrf_tokens()
if request.url.path == "/login":
response = templates.TemplateResponse(
request=request,
name="login.html",
context={"csrf_token": csrf_token, "error": "Your session expired. Please try again."},
)
csrf_protect.set_csrf_cookie(signed_token, response)
return response
elif request.url.path == "/setup":
response = templates.TemplateResponse(
request=request,
name="setup.html",
context={"csrf_token": csrf_token, "error": "Your session expired. Please try again."},
)
csrf_protect.set_csrf_cookie(signed_token, response)
return response
elif request.url.path == "/logout":
return RedirectResponse("/login", status_code=302)
elif request.url.path == "/change-password":
response = templates.TemplateResponse(
request=request,
name="change_password.html",
context={"csrf_token": csrf_token, "error": "Your session expired. Please try again."},
)
csrf_protect.set_csrf_cookie(signed_token, response)
return response
elif request.url.path.startswith("/adapters/"):
# Redirect back to adapters list
return RedirectResponse("/adapters", status_code=302)
else:
# Fallback: redirect to login
return RedirectResponse("/login", status_code=302)
return app return app

View file

@ -182,6 +182,7 @@ async def dashboard_streams(request: Request) -> HTMLResponse:
async def dashboard_polls(request: Request) -> HTMLResponse: async def dashboard_polls(request: Request) -> HTMLResponse:
"""Get last poll times for each adapter.""" """Get last poll times for each adapter."""
from central.gui.nats import get_js from central.gui.nats import get_js
from nats.js.errors import NotFoundError
templates = _get_templates() templates = _get_templates()
pool = get_pool() pool = get_pool()
@ -210,43 +211,31 @@ async def dashboard_polls(request: Request) -> HTMLResponse:
else: else:
for name in adapter_names: for name in adapter_names:
try: try:
# Get last message from CENTRAL_META for this adapter msg = await js.get_last_msg(
sub = await js.pull_subscribe( "CENTRAL_META",
f"central.meta.{name}.status", f"central.meta.adapter.{name}.status",
durable=f"dashboard-poll-{name}",
stream="CENTRAL_META",
) )
try: data = json.loads(msg.data.decode())
msgs = await sub.fetch(1, timeout=1.0)
if msgs:
data = json.loads(msgs[0].data.decode())
last_poll = data.get("data", {}).get("time", "")
adapters.append({ adapters.append({
"name": name, "name": name,
"last_poll": last_poll, "last_poll": data.get("ts"),
"status": "", "status": "" if data.get("ok") else "",
"error": None, "error": data.get("error") if not data.get("ok") else None,
}) })
else: except NotFoundError:
# No status message for this adapter yet
adapters.append({ adapters.append({
"name": name, "name": name,
"last_poll": None, "last_poll": None,
"status": None, "status": None,
"error": None, "error": None,
}) })
except Exception: except Exception as e:
adapters.append({ adapters.append({
"name": name, "name": name,
"last_poll": None, "last_poll": None,
"status": None, "status": "?",
"error": None, "error": str(e),
})
except Exception:
adapters.append({
"name": name,
"last_poll": None,
"status": None,
"error": "unavailable",
}) })
return templates.TemplateResponse( return templates.TemplateResponse(

View file

@ -8,6 +8,10 @@
<h1>Change Password</h1> <h1>Change Password</h1>
</header> </header>
{% if error %}
<p style="color: var(--pico-color-red-500);">{{ error }}</p>
{% endif %}
<form action="/change-password" method="post"> <form action="/change-password" method="post">
<input type="hidden" name="csrf_token" value="{{ csrf_token }}"> <input type="hidden" name="csrf_token" value="{{ csrf_token }}">

View file

@ -8,6 +8,10 @@
<h1>Login</h1> <h1>Login</h1>
</header> </header>
{% if error %}
<p style="color: var(--pico-color-red-500);">{{ error }}</p>
{% endif %}
<form action="/login" method="post"> <form action="/login" method="post">
<input type="hidden" name="csrf_token" value="{{ csrf_token }}"> <input type="hidden" name="csrf_token" value="{{ csrf_token }}">

View file

@ -9,6 +9,10 @@
<p>Create the initial operator account to get started.</p> <p>Create the initial operator account to get started.</p>
</header> </header>
{% if error %}
<p style="color: var(--pico-color-red-500);">{{ error }}</p>
{% endif %}
<form action="/setup" method="post"> <form action="/setup" method="post">
<input type="hidden" name="csrf_token" value="{{ csrf_token }}"> <input type="hidden" name="csrf_token" value="{{ csrf_token }}">

109
tests/test_csrf_handler.py Normal file
View file

@ -0,0 +1,109 @@
"""Tests for CSRF exception handler."""
import os
from unittest.mock import MagicMock, AsyncMock, patch
import pytest
os.environ.setdefault("CENTRAL_DB_DSN", "postgresql://test:test@localhost/test")
os.environ.setdefault("CENTRAL_CSRF_SECRET", "testsecret12345678901234567890ab")
os.environ.setdefault("CENTRAL_NATS_URL", "nats://localhost:4222")
class TestCsrfExceptionHandlerRegistered:
"""Verify CSRF exception handler is properly registered."""
def test_csrf_exception_handler_is_registered(self):
"""The app has a CsrfProtectError exception handler registered."""
from central.gui import app
from fastapi_csrf_protect.exceptions import CsrfProtectError
assert CsrfProtectError in app.exception_handlers, \
"CsrfProtectError handler should be registered"
def test_csrf_subclasses_are_caught(self):
"""MissingTokenError and TokenValidationError inherit from CsrfProtectError."""
from fastapi_csrf_protect.exceptions import (
CsrfProtectError,
MissingTokenError,
TokenValidationError,
)
assert issubclass(MissingTokenError, CsrfProtectError)
assert issubclass(TokenValidationError, CsrfProtectError)
class TestCsrfExceptionHandlerBehavior:
"""Test the CSRF exception handler behavior."""
def test_login_csrf_error_handler_checks_path(self):
"""CSRF handler checks request path for /login."""
import inspect
from central.gui import _create_app
from fastapi_csrf_protect.exceptions import CsrfProtectError
app = _create_app()
handler = app.exception_handlers.get(CsrfProtectError)
# Verify handler source contains /login path check
source = inspect.getsource(handler)
assert "/login" in source
assert "session expired" in source.lower()
@pytest.mark.asyncio
async def test_logout_csrf_error_redirects_to_login(self):
"""CSRF error on /logout should redirect to /login."""
from central.gui import _create_app
from fastapi_csrf_protect.exceptions import TokenValidationError
from fastapi.responses import RedirectResponse
app = _create_app()
from fastapi_csrf_protect.exceptions import CsrfProtectError
handler = app.exception_handlers.get(CsrfProtectError)
mock_request = MagicMock()
mock_request.url.path = "/logout"
exc = TokenValidationError("Invalid token")
result = await handler(mock_request, exc)
assert isinstance(result, RedirectResponse)
assert result.status_code == 302
@pytest.mark.asyncio
async def test_adapters_csrf_error_redirects_to_adapters(self):
"""CSRF error on /adapters/{name} should redirect to /adapters."""
from central.gui import _create_app
from fastapi_csrf_protect.exceptions import TokenValidationError
from fastapi.responses import RedirectResponse
app = _create_app()
from fastapi_csrf_protect.exceptions import CsrfProtectError
handler = app.exception_handlers.get(CsrfProtectError)
mock_request = MagicMock()
mock_request.url.path = "/adapters/nws"
exc = TokenValidationError("Invalid token")
result = await handler(mock_request, exc)
assert isinstance(result, RedirectResponse)
assert result.status_code == 302
class TestCsrfHandlerNoTraceback:
"""Verify exception handler doesn't expose Python internals."""
def test_handler_exists_and_is_async(self):
"""The CSRF handler should be an async function."""
import inspect
from central.gui import _create_app
from fastapi_csrf_protect.exceptions import CsrfProtectError
app = _create_app()
handler = app.exception_handlers.get(CsrfProtectError)
assert handler is not None
assert inspect.iscoroutinefunction(handler)

View file

@ -1,47 +1,36 @@
"""Tests for dashboard routes.""" """Tests for dashboard routes."""
import json
import os import os
from unittest.mock import MagicMock, AsyncMock, patch from unittest.mock import MagicMock, AsyncMock, patch
import pytest import pytest
# Set required env vars before importing central modules
os.environ.setdefault("CENTRAL_DB_DSN", "postgresql://test:test@localhost/test") os.environ.setdefault("CENTRAL_DB_DSN", "postgresql://test:test@localhost/test")
os.environ.setdefault("CENTRAL_CSRF_SECRET", "testsecret12345678901234567890ab") os.environ.setdefault("CENTRAL_CSRF_SECRET", "testsecret12345678901234567890ab")
os.environ.setdefault("CENTRAL_NATS_URL", "nats://localhost:4222") os.environ.setdefault("CENTRAL_NATS_URL", "nats://localhost:4222")
class TestFormatBytes: class TestFormatBytes:
"""Test _format_bytes helper."""
def test_format_bytes_bytes(self): def test_format_bytes_bytes(self):
"""Bytes are shown as B."""
from central.gui.routes import _format_bytes from central.gui.routes import _format_bytes
assert _format_bytes(100) == "100 B" assert _format_bytes(100) == "100 B"
def test_format_bytes_kilobytes(self): def test_format_bytes_kilobytes(self):
"""KB formatting."""
from central.gui.routes import _format_bytes from central.gui.routes import _format_bytes
assert _format_bytes(1024) == "1.0 KB" assert _format_bytes(1024) == "1.0 KB"
def test_format_bytes_megabytes(self): def test_format_bytes_megabytes(self):
"""MB formatting."""
from central.gui.routes import _format_bytes from central.gui.routes import _format_bytes
assert _format_bytes(1048576) == "1.0 MB" assert _format_bytes(1048576) == "1.0 MB"
def test_format_bytes_gigabytes(self): def test_format_bytes_gigabytes(self):
"""GB formatting."""
from central.gui.routes import _format_bytes from central.gui.routes import _format_bytes
assert _format_bytes(1073741824) == "1.0 GB" assert _format_bytes(1073741824) == "1.0 GB"
class TestDashboardEventsSQL: class TestDashboardEventsSQL:
"""Test events query construction."""
def test_events_query_has_24h_filter(self): def test_events_query_has_24h_filter(self):
"""Events query filters by received > NOW() - 24h."""
# We can't easily test the full route without mocking,
# but we can verify the query logic by inspecting the source
import inspect import inspect
from central.gui.routes import dashboard_events from central.gui.routes import dashboard_events
source = inspect.getsource(dashboard_events) source = inspect.getsource(dashboard_events)
@ -50,25 +39,17 @@ class TestDashboardEventsSQL:
class TestDashboardStreamsGracefulDegradation: class TestDashboardStreamsGracefulDegradation:
"""Test streams endpoint graceful degradation."""
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_nats_unavailable_returns_error_message(self): async def test_nats_unavailable_returns_error_message(self):
"""When NATS is unavailable, streams returns error message not 500."""
from central.gui.routes import dashboard_streams from central.gui.routes import dashboard_streams
mock_request = MagicMock() mock_request = MagicMock()
mock_request.state.operator = MagicMock() mock_request.state.operator = MagicMock()
mock_templates = MagicMock() mock_templates = MagicMock()
mock_response = MagicMock() mock_response = MagicMock()
mock_templates.TemplateResponse.return_value = mock_response mock_templates.TemplateResponse.return_value = mock_response
with patch("central.gui.routes._get_templates", return_value=mock_templates): with patch("central.gui.routes._get_templates", return_value=mock_templates):
with patch("central.gui.nats.get_js", return_value=None): with patch("central.gui.nats.get_js", return_value=None):
result = await dashboard_streams(mock_request) result = await dashboard_streams(mock_request)
# Should call template with error context
call_args = mock_templates.TemplateResponse.call_args call_args = mock_templates.TemplateResponse.call_args
context = call_args.kwargs.get("context", call_args[1].get("context")) context = call_args.kwargs.get("context", call_args[1].get("context"))
assert context["error"] == "NATS unavailable" assert context["error"] == "NATS unavailable"
@ -76,32 +57,23 @@ class TestDashboardStreamsGracefulDegradation:
class TestDashboardPollsGracefulDegradation: class TestDashboardPollsGracefulDegradation:
"""Test polls endpoint graceful degradation."""
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_nats_unavailable_shows_all_adapters_with_error(self): async def test_nats_unavailable_shows_all_adapters_with_error(self):
"""When NATS is unavailable, polls shows adapters with error message."""
from central.gui.routes import dashboard_polls from central.gui.routes import dashboard_polls
mock_request = MagicMock() mock_request = MagicMock()
mock_request.state.operator = MagicMock() mock_request.state.operator = MagicMock()
mock_conn = AsyncMock() mock_conn = AsyncMock()
mock_conn.fetch.return_value = [{"name": "nws"}, {"name": "firms"}] mock_conn.fetch.return_value = [{"name": "nws"}, {"name": "firms"}]
mock_pool = MagicMock() mock_pool = MagicMock()
mock_pool.acquire.return_value.__aenter__ = AsyncMock(return_value=mock_conn) mock_pool.acquire.return_value.__aenter__ = AsyncMock(return_value=mock_conn)
mock_pool.acquire.return_value.__aexit__ = AsyncMock(return_value=None) mock_pool.acquire.return_value.__aexit__ = AsyncMock(return_value=None)
mock_templates = MagicMock() mock_templates = MagicMock()
mock_response = MagicMock() mock_response = MagicMock()
mock_templates.TemplateResponse.return_value = mock_response mock_templates.TemplateResponse.return_value = mock_response
with patch("central.gui.routes._get_templates", return_value=mock_templates): with patch("central.gui.routes._get_templates", return_value=mock_templates):
with patch("central.gui.routes.get_pool", return_value=mock_pool): with patch("central.gui.routes.get_pool", return_value=mock_pool):
with patch("central.gui.nats.get_js", return_value=None): with patch("central.gui.nats.get_js", return_value=None):
result = await dashboard_polls(mock_request) result = await dashboard_polls(mock_request)
call_args = mock_templates.TemplateResponse.call_args call_args = mock_templates.TemplateResponse.call_args
context = call_args.kwargs.get("context", call_args[1].get("context")) context = call_args.kwargs.get("context", call_args[1].get("context"))
assert context["error"] == "NATS unavailable" assert context["error"] == "NATS unavailable"
@ -109,14 +81,106 @@ class TestDashboardPollsGracefulDegradation:
assert context["adapters"][0]["error"] == "NATS unavailable" assert context["adapters"][0]["error"] == "NATS unavailable"
class TestDashboardStreamsIsolation: class TestDashboardPollsGetLastMsg:
"""Test stream failure isolation.""" @pytest.mark.asyncio
async def test_polls_returns_timestamp_from_status_message(self):
from central.gui.routes import dashboard_polls
mock_request = MagicMock()
mock_request.state.operator = MagicMock()
mock_conn = AsyncMock()
mock_conn.fetch.return_value = [{"name": "nws"}]
mock_pool = MagicMock()
mock_pool.acquire.return_value.__aenter__ = AsyncMock(return_value=mock_conn)
mock_pool.acquire.return_value.__aexit__ = AsyncMock(return_value=None)
mock_msg = MagicMock()
mock_msg.data = json.dumps({"ok": True, "ts": "2026-05-17T12:34:56Z"}).encode()
mock_js = AsyncMock()
mock_js.get_last_msg = AsyncMock(return_value=mock_msg)
mock_templates = MagicMock()
mock_response = MagicMock()
mock_templates.TemplateResponse.return_value = mock_response
with patch("central.gui.routes._get_templates", return_value=mock_templates):
with patch("central.gui.routes.get_pool", return_value=mock_pool):
with patch("central.gui.nats.get_js", return_value=mock_js):
result = await dashboard_polls(mock_request)
call_args = mock_templates.TemplateResponse.call_args
context = call_args.kwargs.get("context", call_args[1].get("context"))
assert len(context["adapters"]) == 1
adapter = context["adapters"][0]
assert adapter["name"] == "nws"
assert adapter["last_poll"] == "2026-05-17T12:34:56Z"
assert adapter["status"] == "\u2713"
assert adapter["error"] is None
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_single_stream_failure_doesnt_crash_card(self): async def test_polls_handles_not_found_error_gracefully(self):
"""A single stream failure shows error for that stream only.""" from central.gui.routes import dashboard_polls
from central.gui.routes import dashboard_streams from nats.js.errors import NotFoundError
mock_request = MagicMock()
mock_request.state.operator = MagicMock()
mock_conn = AsyncMock()
mock_conn.fetch.return_value = [{"name": "nws"}]
mock_pool = MagicMock()
mock_pool.acquire.return_value.__aenter__ = AsyncMock(return_value=mock_conn)
mock_pool.acquire.return_value.__aexit__ = AsyncMock(return_value=None)
mock_js = AsyncMock()
mock_js.get_last_msg = AsyncMock(side_effect=NotFoundError())
mock_templates = MagicMock()
mock_response = MagicMock()
mock_templates.TemplateResponse.return_value = mock_response
with patch("central.gui.routes._get_templates", return_value=mock_templates):
with patch("central.gui.routes.get_pool", return_value=mock_pool):
with patch("central.gui.nats.get_js", return_value=mock_js):
result = await dashboard_polls(mock_request)
call_args = mock_templates.TemplateResponse.call_args
context = call_args.kwargs.get("context", call_args[1].get("context"))
adapter = context["adapters"][0]
assert adapter["last_poll"] is None
assert adapter["status"] is None
assert adapter["error"] is None
@pytest.mark.asyncio
async def test_polls_shows_failure_status_when_ok_is_false(self):
from central.gui.routes import dashboard_polls
mock_request = MagicMock()
mock_request.state.operator = MagicMock()
mock_conn = AsyncMock()
mock_conn.fetch.return_value = [{"name": "nws"}]
mock_pool = MagicMock()
mock_pool.acquire.return_value.__aenter__ = AsyncMock(return_value=mock_conn)
mock_pool.acquire.return_value.__aexit__ = AsyncMock(return_value=None)
mock_msg = MagicMock()
mock_msg.data = json.dumps({"ok": False, "ts": "2026-05-17T12:34:56Z", "error": "Connection timeout"}).encode()
mock_js = AsyncMock()
mock_js.get_last_msg = AsyncMock(return_value=mock_msg)
mock_templates = MagicMock()
mock_response = MagicMock()
mock_templates.TemplateResponse.return_value = mock_response
with patch("central.gui.routes._get_templates", return_value=mock_templates):
with patch("central.gui.routes.get_pool", return_value=mock_pool):
with patch("central.gui.nats.get_js", return_value=mock_js):
result = await dashboard_polls(mock_request)
call_args = mock_templates.TemplateResponse.call_args
context = call_args.kwargs.get("context", call_args[1].get("context"))
adapter = context["adapters"][0]
assert adapter["status"] == "\u2717"
assert adapter["error"] == "Connection timeout"
class TestDashboardPollsNoSubscribe:
def test_polls_does_not_use_pull_subscribe(self):
import inspect
from central.gui.routes import dashboard_polls
source = inspect.getsource(dashboard_polls)
assert "pull_subscribe" not in source
assert "get_last_msg" in source
assert "central.meta.adapter." in source
class TestDashboardStreamsIsolation:
@pytest.mark.asyncio
async def test_single_stream_failure_doesnt_crash_card(self):
from central.gui.routes import dashboard_streams
mock_request = MagicMock() mock_request = MagicMock()
mock_request.state.operator = MagicMock() mock_request.state.operator = MagicMock()
@ -132,27 +196,18 @@ class TestDashboardStreamsIsolation:
mock_js = AsyncMock() mock_js = AsyncMock()
mock_js.stream_info.side_effect = mock_stream_info mock_js.stream_info.side_effect = mock_stream_info
mock_templates = MagicMock() mock_templates = MagicMock()
mock_response = MagicMock() mock_response = MagicMock()
mock_templates.TemplateResponse.return_value = mock_response mock_templates.TemplateResponse.return_value = mock_response
with patch("central.gui.routes._get_templates", return_value=mock_templates): with patch("central.gui.routes._get_templates", return_value=mock_templates):
with patch("central.gui.nats.get_js", return_value=mock_js): with patch("central.gui.nats.get_js", return_value=mock_js):
result = await dashboard_streams(mock_request) result = await dashboard_streams(mock_request)
call_args = mock_templates.TemplateResponse.call_args call_args = mock_templates.TemplateResponse.call_args
context = call_args.kwargs.get("context", call_args[1].get("context")) context = call_args.kwargs.get("context", call_args[1].get("context"))
streams = context["streams"] streams = context["streams"]
# Should have 4 streams
assert len(streams) == 4 assert len(streams) == 4
# CENTRAL_FIRE should have error
fire_stream = next(s for s in streams if s["name"] == "CENTRAL_FIRE") fire_stream = next(s for s in streams if s["name"] == "CENTRAL_FIRE")
assert fire_stream.get("error") == "unavailable" assert fire_stream.get("error") == "unavailable"
# CENTRAL_WX should be fine
wx_stream = next(s for s in streams if s["name"] == "CENTRAL_WX") wx_stream = next(s for s in streams if s["name"] == "CENTRAL_WX")
assert wx_stream.get("error") is None assert wx_stream.get("error") is None
assert wx_stream["messages"] == 100 assert wx_stream["messages"] == 100