mirror of
https://github.com/zvx-echo6/central.git
synced 2026-05-21 18:14:44 +02:00
Compare commits
14 commits
4c9ca176a9
...
4368c83613
| Author | SHA1 | Date | |
|---|---|---|---|
|
4368c83613 |
|||
|
|
5be002cb03 | ||
|
|
0f127399b3 | ||
|
|
dec8ce8545 | ||
|
dbe7f8f868 |
|||
|
|
736b637d31 | ||
|
c31de2499d |
|||
|
|
6b5f6709e4 | ||
|
a1e16547a0 |
|||
|
|
83b1e45fa8 | ||
|
71c73b4eb1 |
|||
|
|
a25b4af4e8 | ||
|
|
98e9d95810 | ||
|
|
8601a19f60 |
30 changed files with 2149 additions and 87 deletions
|
|
@ -15,6 +15,10 @@ Phase 0 — scaffold. Not yet operational.
|
|||
- One archive consumer process persisting events to TimescaleDB
|
||||
- Both processes systemd-managed
|
||||
|
||||
## Testing
|
||||
|
||||
See [docs/test-database.md](docs/test-database.md) for test database setup.
|
||||
|
||||
## License
|
||||
|
||||
MIT. See LICENSE.
|
||||
|
|
|
|||
|
|
@ -1,5 +1,11 @@
|
|||
# Migration policy
|
||||
|
||||
## Migrations are the sole source of truth
|
||||
|
||||
The `sql/migrations/` directory contains all schema definitions. There is
|
||||
no separate schema.sql file; use `pg_dump -s central` to generate a
|
||||
human-readable snapshot of the current schema when needed.
|
||||
|
||||
## Migrations must be idempotent
|
||||
|
||||
New migration files (007+) must use guards so they can be safely
|
||||
|
|
@ -20,8 +26,24 @@ Direct `psql` execution bypasses the `schema_migrations` tracker and
|
|||
was the cause of the v0.2.0 reconcile. If a migration needs to be
|
||||
applied on the live system, run:
|
||||
|
||||
sudo -u central /opt/central/.venv/bin/python -m scripts.migrate
|
||||
sudo -u central /opt/central/.venv/bin/python -m central.migrate
|
||||
|
||||
Never apply migration SQL directly via `psql`, even as a superuser,
|
||||
even "just to test." If migrate.py has a bug that's blocking you, fix
|
||||
migrate.py.
|
||||
|
||||
## Extensions are not in migrations
|
||||
|
||||
PostgreSQL extensions like PostGIS require superuser privileges to
|
||||
install. The production `central` role is intentionally not a superuser.
|
||||
Therefore, extensions live outside the migration system:
|
||||
|
||||
- **Production bootstrap:** A DBA runs `CREATE EXTENSION postgis` once
|
||||
before the first `migrate.py` run.
|
||||
- **Test database:** The `central_test` role is a superuser, allowing
|
||||
test fixtures to self-bootstrap extensions.
|
||||
|
||||
This is documented in [docs/test-database.md](test-database.md).
|
||||
|
||||
Do not add `CREATE EXTENSION` statements to migrations — they will fail
|
||||
in production where migrations run as the non-superuser `central` role.
|
||||
|
|
|
|||
83
docs/test-database.md
Normal file
83
docs/test-database.md
Normal file
|
|
@ -0,0 +1,83 @@
|
|||
# Test Database Setup
|
||||
|
||||
Central's integration tests require a PostgreSQL database. This document
|
||||
covers one-time setup and maintenance of the test database.
|
||||
|
||||
## DSN Convention
|
||||
|
||||
Tests default to:
|
||||
|
||||
```
|
||||
postgresql://central_test:testpass@localhost/central_test
|
||||
```
|
||||
|
||||
Override via the `CENTRAL_TEST_DB_DSN` environment variable:
|
||||
|
||||
```bash
|
||||
export CENTRAL_TEST_DB_DSN="postgresql://myuser:mypass@localhost/mydb"
|
||||
```
|
||||
|
||||
## One-Time Setup
|
||||
|
||||
Run these commands once on a fresh PostgreSQL installation:
|
||||
|
||||
```bash
|
||||
# Create the test user (as postgres superuser)
|
||||
sudo -u postgres createuser -s central_test
|
||||
sudo -u postgres psql -c "ALTER USER central_test PASSWORD 'testpass'"
|
||||
|
||||
# Create the test database
|
||||
sudo -u postgres createdb -O central_test central_test
|
||||
|
||||
# Install required extensions
|
||||
sudo -u postgres psql central_test -c "CREATE EXTENSION IF NOT EXISTS postgis"
|
||||
```
|
||||
|
||||
**Note:** The `central_test` role is created as a superuser (`-s` flag).
|
||||
This allows test fixtures to self-bootstrap extensions like PostGIS via
|
||||
`CREATE EXTENSION IF NOT EXISTS`. Production uses a non-superuser role.
|
||||
|
||||
## Required Extensions
|
||||
|
||||
| Extension | Version | Purpose |
|
||||
|-----------|---------|---------|
|
||||
| postgis | 3.4+ | Geometry types for geospatial event data |
|
||||
|
||||
## Why PostGIS Is Not in Migrations
|
||||
|
||||
PostGIS requires superuser privileges to install. The production `central`
|
||||
role is intentionally not a superuser for security reasons. Therefore:
|
||||
|
||||
- **Production:** A DBA must run `CREATE EXTENSION postgis` before the
|
||||
first `migrate.py` run. This is a one-time bootstrap step.
|
||||
- **Test:** The `central_test` role is a superuser, so test fixtures can
|
||||
self-bootstrap PostGIS via `CREATE EXTENSION IF NOT EXISTS`.
|
||||
|
||||
This divergence is documented rather than "fixed" because granting
|
||||
superuser to production roles creates security risk, and the PostgreSQL
|
||||
packaging on Ubuntu does not mark PostGIS as a trusted extension.
|
||||
|
||||
## Resetting the Test Database
|
||||
|
||||
If the test database gets into a bad state:
|
||||
|
||||
```bash
|
||||
# Drop and recreate
|
||||
sudo -u postgres dropdb central_test
|
||||
sudo -u postgres createdb -O central_test central_test
|
||||
sudo -u postgres psql central_test -c "CREATE EXTENSION IF NOT EXISTS postgis"
|
||||
```
|
||||
|
||||
Test fixtures handle their own table creation and cleanup, so this is
|
||||
rarely needed.
|
||||
|
||||
## Running Tests
|
||||
|
||||
```bash
|
||||
cd /opt/central
|
||||
uv run pytest tests/ # all tests
|
||||
uv run pytest tests/test_config_store.py -v # specific file
|
||||
```
|
||||
|
||||
Tests that require the database will skip gracefully if the connection
|
||||
fails, though most integration tests will fail without a working DB.
|
||||
46
sql/migrations/011_events_add_adapter_column.sql
Normal file
46
sql/migrations/011_events_add_adapter_column.sql
Normal file
|
|
@ -0,0 +1,46 @@
|
|||
-- Migration 011: Add adapter column to events, drop source column
|
||||
-- Replaces module-path-based source with stable adapter identifier
|
||||
|
||||
-- Add adapter column (idempotent)
|
||||
ALTER TABLE public.events ADD COLUMN IF NOT EXISTS adapter TEXT;
|
||||
|
||||
-- Backfill from existing source values
|
||||
UPDATE public.events
|
||||
SET adapter = REPLACE(source, 'central/adapters/', '')
|
||||
WHERE adapter IS NULL AND source IS NOT NULL;
|
||||
|
||||
-- Make NOT NULL after backfill (idempotent check)
|
||||
DO $$
|
||||
BEGIN
|
||||
IF EXISTS (
|
||||
SELECT 1 FROM information_schema.columns
|
||||
WHERE table_schema = 'public'
|
||||
AND table_name = 'events'
|
||||
AND column_name = 'adapter'
|
||||
AND is_nullable = 'YES'
|
||||
) THEN
|
||||
ALTER TABLE public.events ALTER COLUMN adapter SET NOT NULL;
|
||||
END IF;
|
||||
END $$;
|
||||
|
||||
-- Add FK constraint (idempotent check)
|
||||
DO $$
|
||||
BEGIN
|
||||
IF NOT EXISTS (
|
||||
SELECT 1 FROM information_schema.table_constraints
|
||||
WHERE constraint_name = 'events_adapter_fkey'
|
||||
AND table_name = 'events'
|
||||
) THEN
|
||||
ALTER TABLE public.events
|
||||
ADD CONSTRAINT events_adapter_fkey
|
||||
FOREIGN KEY (adapter) REFERENCES config.adapters(name)
|
||||
ON DELETE RESTRICT;
|
||||
END IF;
|
||||
END $$;
|
||||
|
||||
-- Add index for dashboard queries (idempotent)
|
||||
CREATE INDEX IF NOT EXISTS events_adapter_received_idx
|
||||
ON public.events (adapter, received DESC);
|
||||
|
||||
-- Drop deprecated source column (idempotent)
|
||||
ALTER TABLE public.events DROP COLUMN IF EXISTS source;
|
||||
|
|
@ -1,30 +0,0 @@
|
|||
-- Central Data Hub schema
|
||||
-- PostgreSQL 16 + TimescaleDB + PostGIS
|
||||
|
||||
CREATE TABLE IF NOT EXISTS events (
|
||||
id TEXT NOT NULL, -- CloudEvent id
|
||||
source TEXT NOT NULL, -- adapter identity
|
||||
category TEXT NOT NULL, -- "wx.alert.<type>"
|
||||
time TIMESTAMPTZ NOT NULL, -- event-time UTC
|
||||
expires TIMESTAMPTZ,
|
||||
severity SMALLINT, -- 0..4 or NULL
|
||||
geom GEOMETRY(Geometry, 4326), -- centroid or bbox as Polygon
|
||||
regions TEXT[], -- ["US-ID-Ada", ...]
|
||||
primary_region TEXT,
|
||||
payload JSONB NOT NULL, -- full Event as JSON
|
||||
received TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
PRIMARY KEY (id, time) -- composite PK for TimescaleDB
|
||||
);
|
||||
|
||||
SELECT create_hypertable('events', 'time', if_not_exists => TRUE);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS events_category_time_idx
|
||||
ON events (category, time DESC);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS events_geom_gist
|
||||
ON events USING GIST (geom);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS events_regions_gin
|
||||
ON events USING GIN (regions);
|
||||
|
||||
-- Dedup on insert via ON CONFLICT (id, time) in the archive consumer.
|
||||
|
|
@ -330,7 +330,7 @@ class FIRMSAdapter(SourceAdapter):
|
|||
|
||||
return Event(
|
||||
id=stable_id,
|
||||
source="central/adapters/firms",
|
||||
adapter="firms",
|
||||
category=f"fire.hotspot.{satellite_short}.{confidence}",
|
||||
time=time,
|
||||
expires=None,
|
||||
|
|
|
|||
|
|
@ -475,7 +475,7 @@ class NWSAdapter(SourceAdapter):
|
|||
|
||||
return Event(
|
||||
id=event_id,
|
||||
source="central/adapters/nws",
|
||||
adapter="nws",
|
||||
category=category,
|
||||
time=time,
|
||||
expires=expires,
|
||||
|
|
|
|||
|
|
@ -348,7 +348,7 @@ class USGSQuakeAdapter(SourceAdapter):
|
|||
|
||||
return Event(
|
||||
id=event_id,
|
||||
source="central/adapters/usgs_quake",
|
||||
adapter="usgs_quake",
|
||||
category=f"quake.event.{tier}",
|
||||
time=event_time,
|
||||
expires=None,
|
||||
|
|
|
|||
|
|
@ -1,4 +1,8 @@
|
|||
"""Central archive consumer - JetStream to TimescaleDB."""
|
||||
"""Central archive consumer - JetStream to TimescaleDB.
|
||||
|
||||
Consumes events from multiple NATS JetStream streams and archives them
|
||||
to TimescaleDB. One durable consumer per stream for independent ack tracking.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
|
|
@ -12,17 +16,27 @@ import asyncpg
|
|||
import nats
|
||||
from nats.js import JetStreamContext
|
||||
from nats.js.api import ConsumerConfig, DeliverPolicy, AckPolicy
|
||||
from nats.js.errors import NotFoundError
|
||||
|
||||
from central.bootstrap_config import get_settings
|
||||
|
||||
CONSUMER_NAME = "archive"
|
||||
STREAM_NAME = "CENTRAL_WX"
|
||||
SUBJECT_FILTER = "central.wx.>"
|
||||
# Event-bearing streams to consume (skip CENTRAL_META - status messages only)
|
||||
STREAMS = [
|
||||
("CENTRAL_WX", "central.wx.>"),
|
||||
("CENTRAL_FIRE", "central.fire.>"),
|
||||
("CENTRAL_QUAKE", "central.quake.>"),
|
||||
]
|
||||
|
||||
BATCH_SIZE = 100
|
||||
FETCH_TIMEOUT = 5.0
|
||||
ACK_WAIT = 30
|
||||
|
||||
|
||||
def consumer_name_for(stream: str) -> str:
|
||||
"""Generate consumer name for a stream."""
|
||||
return f"archive-{stream.lower()}"
|
||||
|
||||
|
||||
class JsonFormatter(logging.Formatter):
|
||||
"""JSON log formatter for structured logging."""
|
||||
|
||||
|
|
@ -125,24 +139,49 @@ class ArchiveConsumer:
|
|||
self._js = None
|
||||
logger.info("Disconnected")
|
||||
|
||||
async def _ensure_consumer(self) -> None:
|
||||
"""Ensure the durable consumer exists."""
|
||||
async def _cleanup_orphaned_consumer(self) -> None:
|
||||
"""Remove orphaned 'archive' consumer from CENTRAL_WX if it exists.
|
||||
|
||||
The old single-stream code used a consumer named 'archive' on CENTRAL_WX.
|
||||
Now we use 'archive-central_wx' instead. Clean up the old one.
|
||||
"""
|
||||
if not self._js:
|
||||
return
|
||||
|
||||
try:
|
||||
await self._js.consumer_info(STREAM_NAME, CONSUMER_NAME)
|
||||
logger.info("Consumer exists", extra={"consumer": CONSUMER_NAME})
|
||||
except nats.js.errors.NotFoundError:
|
||||
await self._js.consumer_info("CENTRAL_WX", "archive")
|
||||
await self._js.delete_consumer("CENTRAL_WX", "archive")
|
||||
logger.info("Removed orphaned 'archive' consumer from CENTRAL_WX")
|
||||
except NotFoundError:
|
||||
pass # Already gone or never existed
|
||||
|
||||
async def _ensure_consumer(
|
||||
self, stream_name: str, subject_filter: str, consumer_name: str
|
||||
) -> None:
|
||||
"""Ensure the durable consumer exists for a stream."""
|
||||
if not self._js:
|
||||
return
|
||||
|
||||
try:
|
||||
await self._js.consumer_info(stream_name, consumer_name)
|
||||
logger.info(
|
||||
"Consumer exists",
|
||||
extra={"stream": stream_name, "consumer": consumer_name}
|
||||
)
|
||||
except NotFoundError:
|
||||
consumer_config = ConsumerConfig(
|
||||
durable_name=CONSUMER_NAME,
|
||||
durable_name=consumer_name,
|
||||
deliver_policy=DeliverPolicy.ALL,
|
||||
ack_policy=AckPolicy.EXPLICIT,
|
||||
ack_wait=ACK_WAIT,
|
||||
filter_subject=SUBJECT_FILTER,
|
||||
max_deliver=5,
|
||||
filter_subject=subject_filter,
|
||||
)
|
||||
await self._js.add_consumer(stream_name, consumer_config)
|
||||
logger.info(
|
||||
"Consumer created",
|
||||
extra={"stream": stream_name, "consumer": consumer_name}
|
||||
)
|
||||
await self._js.add_consumer(STREAM_NAME, consumer_config)
|
||||
logger.info("Consumer created", extra={"consumer": CONSUMER_NAME})
|
||||
|
||||
async def _process_message(self, msg: Any, conn: asyncpg.Connection) -> None:
|
||||
"""Process a single message and insert into database."""
|
||||
|
|
@ -157,7 +196,7 @@ class ArchiveConsumer:
|
|||
geo_data = event_data.get("geo")
|
||||
|
||||
event_id = envelope.get("id")
|
||||
source = event_data.get("source", "")
|
||||
adapter = event_data.get("adapter", "")
|
||||
category = event_data.get("category", "")
|
||||
time_str = event_data.get("time")
|
||||
expires_str = event_data.get("expires")
|
||||
|
|
@ -194,12 +233,12 @@ class ArchiveConsumer:
|
|||
if geom_json:
|
||||
await conn.execute(
|
||||
"""
|
||||
INSERT INTO events (id, source, category, time, expires, severity,
|
||||
INSERT INTO events (id, adapter, category, time, expires, severity,
|
||||
geom, regions, primary_region, payload)
|
||||
VALUES ($1, $2, $3, $4, $5, $6,
|
||||
ST_GeomFromGeoJSON($7), $8, $9, $10)
|
||||
ON CONFLICT (id, time) DO UPDATE SET
|
||||
source = EXCLUDED.source,
|
||||
adapter = EXCLUDED.adapter,
|
||||
category = EXCLUDED.category,
|
||||
expires = EXCLUDED.expires,
|
||||
severity = EXCLUDED.severity,
|
||||
|
|
@ -208,17 +247,17 @@ class ArchiveConsumer:
|
|||
primary_region = EXCLUDED.primary_region,
|
||||
payload = EXCLUDED.payload
|
||||
""",
|
||||
event_id, source, category, event_time, expires_time, severity,
|
||||
event_id, adapter, category, event_time, expires_time, severity,
|
||||
geom_json, regions, primary_region, json.dumps(envelope)
|
||||
)
|
||||
else:
|
||||
await conn.execute(
|
||||
"""
|
||||
INSERT INTO events (id, source, category, time, expires, severity,
|
||||
INSERT INTO events (id, adapter, category, time, expires, severity,
|
||||
geom, regions, primary_region, payload)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, NULL, $7, $8, $9)
|
||||
ON CONFLICT (id, time) DO UPDATE SET
|
||||
source = EXCLUDED.source,
|
||||
adapter = EXCLUDED.adapter,
|
||||
category = EXCLUDED.category,
|
||||
expires = EXCLUDED.expires,
|
||||
severity = EXCLUDED.severity,
|
||||
|
|
@ -227,7 +266,7 @@ class ArchiveConsumer:
|
|||
primary_region = EXCLUDED.primary_region,
|
||||
payload = EXCLUDED.payload
|
||||
""",
|
||||
event_id, source, category, event_time, expires_time, severity,
|
||||
event_id, adapter, category, event_time, expires_time, severity,
|
||||
regions, primary_region, json.dumps(envelope)
|
||||
)
|
||||
|
||||
|
|
@ -241,22 +280,24 @@ class ArchiveConsumer:
|
|||
)
|
||||
# Don't ack - let it be redelivered
|
||||
|
||||
async def _consume_loop(self) -> None:
|
||||
"""Main consume loop."""
|
||||
async def _consume_stream(
|
||||
self, stream_name: str, subject_filter: str, consumer_name: str
|
||||
) -> None:
|
||||
"""Consume loop for a single stream."""
|
||||
if not self._js or not self._pool:
|
||||
return
|
||||
|
||||
await self._ensure_consumer()
|
||||
await self._ensure_consumer(stream_name, subject_filter, consumer_name)
|
||||
|
||||
sub = await self._js.pull_subscribe(
|
||||
SUBJECT_FILTER,
|
||||
durable=CONSUMER_NAME,
|
||||
stream=STREAM_NAME,
|
||||
subject_filter,
|
||||
durable=consumer_name,
|
||||
stream=stream_name,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Subscribed to stream",
|
||||
extra={"stream": STREAM_NAME, "filter": SUBJECT_FILTER}
|
||||
extra={"stream": stream_name, "filter": subject_filter}
|
||||
)
|
||||
|
||||
while not self._shutdown_event.is_set():
|
||||
|
|
@ -277,19 +318,62 @@ class ArchiveConsumer:
|
|||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.exception("Error in consume loop", extra={"error": str(e)})
|
||||
logger.exception(
|
||||
"Error in consume loop",
|
||||
extra={"stream": stream_name, "error": str(e)}
|
||||
)
|
||||
await asyncio.sleep(1)
|
||||
|
||||
logger.info("Consume loop stopped")
|
||||
logger.info("Consume loop stopped", extra={"stream": stream_name})
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the consumer."""
|
||||
await self.connect()
|
||||
await self._cleanup_orphaned_consumer()
|
||||
logger.info("Archive consumer ready")
|
||||
|
||||
async def run(self) -> None:
|
||||
"""Run the consume loop until shutdown."""
|
||||
await self._consume_loop()
|
||||
"""Run consume loops for all streams until shutdown."""
|
||||
tasks = []
|
||||
for stream_name, subject_filter in STREAMS:
|
||||
consumer_name = consumer_name_for(stream_name)
|
||||
task = asyncio.create_task(
|
||||
self._consume_stream(stream_name, subject_filter, consumer_name),
|
||||
name=f"consume-{stream_name}",
|
||||
)
|
||||
tasks.append(task)
|
||||
|
||||
try:
|
||||
# Wait for all tasks; if one fails, cancel the others
|
||||
done, pending = await asyncio.wait(
|
||||
tasks,
|
||||
return_when=asyncio.FIRST_EXCEPTION,
|
||||
)
|
||||
|
||||
# Check for exceptions in completed tasks
|
||||
for task in done:
|
||||
if task.exception():
|
||||
logger.error(
|
||||
"Stream consumer failed",
|
||||
extra={"task": task.get_name(), "error": str(task.exception())}
|
||||
)
|
||||
|
||||
# Cancel any remaining tasks
|
||||
for task in pending:
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
except asyncio.CancelledError:
|
||||
# Shutdown requested, cancel all tasks
|
||||
for task in tasks:
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the consumer gracefully."""
|
||||
|
|
@ -308,7 +392,6 @@ async def async_main() -> None:
|
|||
"Archive starting",
|
||||
extra={
|
||||
"nats_url": settings.nats_url,
|
||||
|
||||
},
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -80,12 +80,16 @@ async def lifespan(app: FastAPI):
|
|||
|
||||
from central.bootstrap_config import get_settings
|
||||
from central.gui.db import close_pool, init_pool
|
||||
from central.gui.nats import close_nats, init_nats
|
||||
|
||||
settings = get_settings()
|
||||
|
||||
# Initialize database pool
|
||||
await init_pool(settings.db_dsn)
|
||||
|
||||
# Initialize NATS connection
|
||||
await init_nats(settings.nats_url)
|
||||
|
||||
# Start session cleanup task
|
||||
_shutdown_event = asyncio.Event()
|
||||
_cleanup_task = asyncio.create_task(_session_cleanup_loop())
|
||||
|
|
@ -103,6 +107,7 @@ async def lifespan(app: FastAPI):
|
|||
except asyncio.TimeoutError:
|
||||
_cleanup_task.cancel()
|
||||
|
||||
await close_nats()
|
||||
await close_pool()
|
||||
logger.info("Central GUI stopped")
|
||||
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ AUTH_LOGIN_FAILED = "auth.login_failed"
|
|||
AUTH_LOGOUT = "auth.logout"
|
||||
AUTH_PASSWORD_CHANGE = "auth.password_change"
|
||||
OPERATOR_CREATE = "operator.create"
|
||||
ADAPTER_UPDATE = "adapter.update"
|
||||
|
||||
|
||||
async def write_audit(
|
||||
|
|
@ -20,18 +21,15 @@ async def write_audit(
|
|||
after: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""Write an audit log entry."""
|
||||
# Serialize before/after as JSON strings if provided
|
||||
before_json = json.dumps(before) if before else None
|
||||
after_json = json.dumps(after) if after else None
|
||||
|
||||
# asyncpg handles dict -> jsonb conversion automatically
|
||||
await conn.execute(
|
||||
"""
|
||||
INSERT INTO config.audit_log (operator_id, action, target, before, after)
|
||||
VALUES ($1, $2, $3, $4::jsonb, $5::jsonb)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
""",
|
||||
operator_id,
|
||||
action,
|
||||
target,
|
||||
before_json,
|
||||
after_json,
|
||||
before,
|
||||
after,
|
||||
)
|
||||
|
|
|
|||
46
src/central/gui/nats.py
Normal file
46
src/central/gui/nats.py
Normal file
|
|
@ -0,0 +1,46 @@
|
|||
"""NATS connection for GUI."""
|
||||
|
||||
import logging
|
||||
|
||||
import nats
|
||||
from nats.js import JetStreamContext
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_nc: nats.NATS | None = None
|
||||
_js: JetStreamContext | None = None
|
||||
|
||||
|
||||
async def init_nats(url: str) -> JetStreamContext | None:
|
||||
"""Initialize the NATS connection and JetStream context."""
|
||||
global _nc, _js
|
||||
if _nc is None:
|
||||
try:
|
||||
_nc = await nats.connect(url)
|
||||
_js = _nc.jetstream()
|
||||
logger.info("Connected to NATS", extra={"url": url})
|
||||
except Exception as e:
|
||||
logger.warning("Failed to connect to NATS", extra={"error": str(e)})
|
||||
_nc = None
|
||||
_js = None
|
||||
return _js
|
||||
|
||||
|
||||
def get_js() -> JetStreamContext | None:
|
||||
"""Get the JetStream context. Returns None if not connected."""
|
||||
return _js
|
||||
|
||||
|
||||
async def close_nats() -> None:
|
||||
"""Close the NATS connection."""
|
||||
global _nc, _js
|
||||
if _nc is not None:
|
||||
try:
|
||||
await _nc.drain()
|
||||
await _nc.close()
|
||||
logger.info("Disconnected from NATS")
|
||||
except Exception as e:
|
||||
logger.warning("Error closing NATS", extra={"error": str(e)})
|
||||
finally:
|
||||
_nc = None
|
||||
_js = None
|
||||
|
|
@ -1,5 +1,9 @@
|
|||
"""Route handlers for Central GUI."""
|
||||
|
||||
import json
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Depends, Form, Request
|
||||
from fastapi.responses import HTMLResponse, RedirectResponse, Response
|
||||
from fastapi_csrf_protect import CsrfProtect
|
||||
|
|
@ -12,6 +16,7 @@ from central.gui.auth import (
|
|||
verify_password,
|
||||
)
|
||||
from central.gui.audit import (
|
||||
ADAPTER_UPDATE,
|
||||
AUTH_LOGIN,
|
||||
AUTH_LOGIN_FAILED,
|
||||
AUTH_LOGOUT,
|
||||
|
|
@ -23,6 +28,24 @@ from central.gui.db import get_pool
|
|||
|
||||
router = APIRouter()
|
||||
|
||||
# Streams to display on dashboard
|
||||
DASHBOARD_STREAMS = ["CENTRAL_WX", "CENTRAL_FIRE", "CENTRAL_QUAKE", "CENTRAL_META"]
|
||||
|
||||
# Email validation regex (simple but effective)
|
||||
EMAIL_REGEX = re.compile(r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$")
|
||||
|
||||
|
||||
def _get_valid_satellites() -> list[str]:
|
||||
"""Get valid satellite identifiers from firms adapter."""
|
||||
from central.adapters.firms import SATELLITE_SHORT
|
||||
return list(SATELLITE_SHORT.keys())
|
||||
|
||||
|
||||
def _get_valid_feeds() -> set[str]:
|
||||
"""Get valid feed values from usgs_quake adapter."""
|
||||
from central.adapters.usgs_quake import VALID_FEEDS
|
||||
return VALID_FEEDS
|
||||
|
||||
|
||||
def _get_templates():
|
||||
"""Get templates instance (deferred import to avoid circular)."""
|
||||
|
|
@ -30,6 +53,15 @@ def _get_templates():
|
|||
return templates
|
||||
|
||||
|
||||
def _format_bytes(size: int) -> str:
|
||||
"""Format bytes as human-readable string."""
|
||||
for unit in ["B", "KB", "MB", "GB", "TB"]:
|
||||
if size < 1024:
|
||||
return f"{size:.1f} {unit}" if unit != "B" else f"{size} {unit}"
|
||||
size /= 1024
|
||||
return f"{size:.1f} PB"
|
||||
|
||||
|
||||
def _set_session_cookie(
|
||||
response: Response,
|
||||
token: str,
|
||||
|
|
@ -76,6 +108,154 @@ async def index(request: Request, csrf_protect: CsrfProtect = Depends()) -> HTML
|
|||
return response
|
||||
|
||||
|
||||
@router.get("/dashboard/events", response_class=HTMLResponse)
|
||||
async def dashboard_events(request: Request) -> HTMLResponse:
|
||||
"""Get events by adapter for the last 24 hours."""
|
||||
templates = _get_templates()
|
||||
pool = get_pool()
|
||||
|
||||
events = []
|
||||
error = None
|
||||
|
||||
try:
|
||||
async with pool.acquire() as conn:
|
||||
rows = await conn.fetch(
|
||||
"""
|
||||
SELECT adapter, COUNT(*) as count
|
||||
FROM events
|
||||
WHERE received > NOW() - INTERVAL '24 hours'
|
||||
GROUP BY adapter
|
||||
ORDER BY count DESC
|
||||
"""
|
||||
)
|
||||
events = [{"adapter": row["adapter"], "count": row["count"]} for row in rows]
|
||||
except Exception as e:
|
||||
error = f"Database error: {str(e)}"
|
||||
|
||||
return templates.TemplateResponse(
|
||||
request=request,
|
||||
name="_dashboard_events.html",
|
||||
context={"events": events, "error": error},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/dashboard/streams", response_class=HTMLResponse)
|
||||
async def dashboard_streams(request: Request) -> HTMLResponse:
|
||||
"""Get stream sizes from NATS JetStream."""
|
||||
from central.gui.nats import get_js
|
||||
|
||||
templates = _get_templates()
|
||||
js = get_js()
|
||||
|
||||
streams = None
|
||||
error = None
|
||||
|
||||
if js is None:
|
||||
error = "NATS unavailable"
|
||||
else:
|
||||
streams = []
|
||||
for stream_name in DASHBOARD_STREAMS:
|
||||
try:
|
||||
info = await js.stream_info(stream_name)
|
||||
streams.append({
|
||||
"name": stream_name,
|
||||
"messages": info.state.messages,
|
||||
"size": _format_bytes(info.state.bytes),
|
||||
"error": None,
|
||||
})
|
||||
except Exception:
|
||||
streams.append({
|
||||
"name": stream_name,
|
||||
"messages": 0,
|
||||
"size": "0 B",
|
||||
"error": "unavailable",
|
||||
})
|
||||
|
||||
return templates.TemplateResponse(
|
||||
request=request,
|
||||
name="_dashboard_streams.html",
|
||||
context={"streams": streams, "error": error},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/dashboard/polls", response_class=HTMLResponse)
|
||||
async def dashboard_polls(request: Request) -> HTMLResponse:
|
||||
"""Get last poll times for each adapter."""
|
||||
from central.gui.nats import get_js
|
||||
|
||||
templates = _get_templates()
|
||||
pool = get_pool()
|
||||
js = get_js()
|
||||
|
||||
adapters = []
|
||||
error = None
|
||||
|
||||
try:
|
||||
async with pool.acquire() as conn:
|
||||
rows = await conn.fetch(
|
||||
"SELECT name FROM config.adapters ORDER BY name"
|
||||
)
|
||||
adapter_names = [row["name"] for row in rows]
|
||||
except Exception as e:
|
||||
error = f"Database error: {str(e)}"
|
||||
return templates.TemplateResponse(
|
||||
request=request,
|
||||
name="_dashboard_polls.html",
|
||||
context={"adapters": [], "error": error},
|
||||
)
|
||||
|
||||
if js is None:
|
||||
error = "NATS unavailable"
|
||||
adapters = [{"name": name, "last_poll": None, "status": None, "error": "NATS unavailable"} for name in adapter_names]
|
||||
else:
|
||||
for name in adapter_names:
|
||||
try:
|
||||
# Get last message from CENTRAL_META for this adapter
|
||||
sub = await js.pull_subscribe(
|
||||
f"central.meta.{name}.status",
|
||||
durable=f"dashboard-poll-{name}",
|
||||
stream="CENTRAL_META",
|
||||
)
|
||||
try:
|
||||
msgs = await sub.fetch(1, timeout=1.0)
|
||||
if msgs:
|
||||
data = json.loads(msgs[0].data.decode())
|
||||
last_poll = data.get("data", {}).get("time", "—")
|
||||
adapters.append({
|
||||
"name": name,
|
||||
"last_poll": last_poll,
|
||||
"status": "✓",
|
||||
"error": None,
|
||||
})
|
||||
else:
|
||||
adapters.append({
|
||||
"name": name,
|
||||
"last_poll": None,
|
||||
"status": None,
|
||||
"error": None,
|
||||
})
|
||||
except Exception:
|
||||
adapters.append({
|
||||
"name": name,
|
||||
"last_poll": None,
|
||||
"status": None,
|
||||
"error": None,
|
||||
})
|
||||
except Exception:
|
||||
adapters.append({
|
||||
"name": name,
|
||||
"last_poll": None,
|
||||
"status": None,
|
||||
"error": "unavailable",
|
||||
})
|
||||
|
||||
return templates.TemplateResponse(
|
||||
request=request,
|
||||
name="_dashboard_polls.html",
|
||||
context={"adapters": adapters, "error": error},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/setup", response_class=HTMLResponse)
|
||||
async def setup_form(
|
||||
request: Request,
|
||||
|
|
@ -370,3 +550,285 @@ async def change_password_submit(
|
|||
|
||||
# Redirect to index
|
||||
return RedirectResponse(url="/", status_code=302)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Adapters routes
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@router.get("/adapters", response_class=HTMLResponse)
|
||||
async def adapters_list(
|
||||
request: Request,
|
||||
csrf_protect: CsrfProtect = Depends(),
|
||||
) -> HTMLResponse:
|
||||
"""List all adapters."""
|
||||
templates = _get_templates()
|
||||
pool = get_pool()
|
||||
operator = request.state.operator
|
||||
|
||||
async with pool.acquire() as conn:
|
||||
rows = await conn.fetch(
|
||||
"""
|
||||
SELECT name, enabled, cadence_s, settings, paused_at, updated_at
|
||||
FROM config.adapters
|
||||
ORDER BY name
|
||||
"""
|
||||
)
|
||||
|
||||
adapters = []
|
||||
for row in rows:
|
||||
settings = row["settings"] or {}
|
||||
adapters.append({
|
||||
"name": row["name"],
|
||||
"enabled": row["enabled"],
|
||||
"cadence_s": row["cadence_s"],
|
||||
"settings": settings,
|
||||
"paused_at": row["paused_at"],
|
||||
"updated_at": row["updated_at"],
|
||||
})
|
||||
|
||||
csrf_token, signed_token = csrf_protect.generate_csrf_tokens()
|
||||
response = templates.TemplateResponse(
|
||||
request=request,
|
||||
name="adapters_list.html",
|
||||
context={
|
||||
"operator": operator,
|
||||
"csrf_token": csrf_token,
|
||||
"adapters": adapters,
|
||||
},
|
||||
)
|
||||
csrf_protect.set_csrf_cookie(signed_token, response)
|
||||
return response
|
||||
|
||||
|
||||
@router.get("/adapters/{name}", response_class=HTMLResponse)
|
||||
async def adapters_edit_form(
|
||||
request: Request,
|
||||
name: str,
|
||||
csrf_protect: CsrfProtect = Depends(),
|
||||
) -> Response:
|
||||
"""Render the adapter edit form."""
|
||||
templates = _get_templates()
|
||||
pool = get_pool()
|
||||
operator = request.state.operator
|
||||
|
||||
async with pool.acquire() as conn:
|
||||
row = await conn.fetchrow(
|
||||
"""
|
||||
SELECT name, enabled, cadence_s, settings, paused_at, updated_at
|
||||
FROM config.adapters
|
||||
WHERE name = $1
|
||||
""",
|
||||
name,
|
||||
)
|
||||
|
||||
if row is None:
|
||||
return Response(status_code=404, content="Adapter not found")
|
||||
|
||||
# Get API keys for firms dropdown
|
||||
api_keys = await conn.fetch(
|
||||
"SELECT alias FROM config.api_keys ORDER BY alias"
|
||||
)
|
||||
|
||||
settings = row["settings"] or {}
|
||||
adapter = {
|
||||
"name": row["name"],
|
||||
"enabled": row["enabled"],
|
||||
"cadence_s": row["cadence_s"],
|
||||
"settings": settings,
|
||||
"paused_at": row["paused_at"],
|
||||
"updated_at": row["updated_at"],
|
||||
}
|
||||
|
||||
csrf_token, signed_token = csrf_protect.generate_csrf_tokens()
|
||||
response = templates.TemplateResponse(
|
||||
request=request,
|
||||
name="adapters_edit.html",
|
||||
context={
|
||||
"operator": operator,
|
||||
"csrf_token": csrf_token,
|
||||
"adapter": adapter,
|
||||
"errors": None,
|
||||
"form_data": None,
|
||||
"api_keys": [{"alias": k["alias"]} for k in api_keys],
|
||||
"valid_satellites": _get_valid_satellites(),
|
||||
"valid_feeds": sorted(_get_valid_feeds()),
|
||||
},
|
||||
)
|
||||
csrf_protect.set_csrf_cookie(signed_token, response)
|
||||
return response
|
||||
|
||||
|
||||
@router.post("/adapters/{name}")
|
||||
async def adapters_edit_submit(
|
||||
request: Request,
|
||||
name: str,
|
||||
csrf_protect: CsrfProtect = Depends(),
|
||||
) -> Response:
|
||||
"""Process the adapter edit form."""
|
||||
templates = _get_templates()
|
||||
pool = get_pool()
|
||||
operator = request.state.operator
|
||||
|
||||
# Validate CSRF
|
||||
await csrf_protect.validate_csrf(request)
|
||||
|
||||
# Parse form data
|
||||
form = await request.form()
|
||||
enabled = "enabled" in form
|
||||
cadence_s_str = form.get("cadence_s", "")
|
||||
|
||||
# Build form_data for re-render on error
|
||||
form_data: dict[str, Any] = {
|
||||
"enabled": enabled,
|
||||
"cadence_s": cadence_s_str,
|
||||
}
|
||||
|
||||
errors: dict[str, str] = {}
|
||||
|
||||
# Validate cadence_s
|
||||
try:
|
||||
cadence_s = int(cadence_s_str)
|
||||
if cadence_s < 60 or cadence_s > 3600:
|
||||
errors["cadence_s"] = "Cadence must be between 60 and 3600 seconds"
|
||||
except ValueError:
|
||||
errors["cadence_s"] = "Cadence must be a valid integer"
|
||||
cadence_s = 0
|
||||
|
||||
async with pool.acquire() as conn:
|
||||
# Get current adapter state
|
||||
row = await conn.fetchrow(
|
||||
"""
|
||||
SELECT name, enabled, cadence_s, settings, paused_at, updated_at
|
||||
FROM config.adapters
|
||||
WHERE name = $1
|
||||
""",
|
||||
name,
|
||||
)
|
||||
|
||||
if row is None:
|
||||
return Response(status_code=404, content="Adapter not found")
|
||||
|
||||
current_settings = row["settings"] or {}
|
||||
new_settings = dict(current_settings)
|
||||
|
||||
# Adapter-specific validation and settings update
|
||||
if name == "nws":
|
||||
contact_email = form.get("contact_email", "").strip()
|
||||
form_data["contact_email"] = contact_email
|
||||
if not contact_email:
|
||||
errors["contact_email"] = "Contact email is required"
|
||||
elif not EMAIL_REGEX.match(contact_email):
|
||||
errors["contact_email"] = "Invalid email format"
|
||||
else:
|
||||
new_settings["contact_email"] = contact_email
|
||||
|
||||
elif name == "firms":
|
||||
api_key_alias = form.get("api_key_alias", "").strip()
|
||||
satellites = form.getlist("satellites")
|
||||
form_data["api_key_alias"] = api_key_alias
|
||||
form_data["satellites"] = satellites
|
||||
|
||||
# Validate api_key_alias if set
|
||||
if api_key_alias:
|
||||
key_exists = await conn.fetchrow(
|
||||
"SELECT 1 FROM config.api_keys WHERE alias = $1",
|
||||
api_key_alias,
|
||||
)
|
||||
if not key_exists:
|
||||
errors["api_key_alias"] = f"API key alias '{api_key_alias}' does not exist"
|
||||
else:
|
||||
new_settings["api_key_alias"] = api_key_alias
|
||||
else:
|
||||
new_settings["api_key_alias"] = None
|
||||
|
||||
# Validate satellites
|
||||
valid_sats = set(_get_valid_satellites())
|
||||
invalid_sats = [s for s in satellites if s not in valid_sats]
|
||||
if invalid_sats:
|
||||
errors["satellites"] = f"Invalid satellites: {', '.join(invalid_sats)}"
|
||||
else:
|
||||
new_settings["satellites"] = satellites
|
||||
|
||||
elif name == "usgs_quake":
|
||||
feed = form.get("feed", "").strip()
|
||||
form_data["feed"] = feed
|
||||
valid_feeds = _get_valid_feeds()
|
||||
if feed not in valid_feeds:
|
||||
errors["feed"] = f"Invalid feed. Must be one of: {', '.join(sorted(valid_feeds))}"
|
||||
else:
|
||||
new_settings["feed"] = feed
|
||||
|
||||
# If there are errors, re-render the form
|
||||
if errors:
|
||||
adapter = {
|
||||
"name": row["name"],
|
||||
"enabled": row["enabled"],
|
||||
"cadence_s": row["cadence_s"],
|
||||
"settings": current_settings,
|
||||
"paused_at": row["paused_at"],
|
||||
"updated_at": row["updated_at"],
|
||||
}
|
||||
|
||||
api_keys = await conn.fetch(
|
||||
"SELECT alias FROM config.api_keys ORDER BY alias"
|
||||
)
|
||||
|
||||
csrf_token, signed_token = csrf_protect.generate_csrf_tokens()
|
||||
response = templates.TemplateResponse(
|
||||
request=request,
|
||||
name="adapters_edit.html",
|
||||
context={
|
||||
"operator": operator,
|
||||
"csrf_token": csrf_token,
|
||||
"adapter": adapter,
|
||||
"errors": errors,
|
||||
"form_data": form_data,
|
||||
"api_keys": [{"alias": k["alias"]} for k in api_keys],
|
||||
"valid_satellites": _get_valid_satellites(),
|
||||
"valid_feeds": sorted(_get_valid_feeds()),
|
||||
},
|
||||
status_code=200,
|
||||
)
|
||||
csrf_protect.set_csrf_cookie(signed_token, response)
|
||||
return response
|
||||
|
||||
# Build before state for audit
|
||||
before = {
|
||||
"enabled": row["enabled"],
|
||||
"cadence_s": row["cadence_s"],
|
||||
"settings": current_settings,
|
||||
}
|
||||
|
||||
# Build after state for audit
|
||||
after = {
|
||||
"enabled": enabled,
|
||||
"cadence_s": cadence_s,
|
||||
"settings": new_settings,
|
||||
}
|
||||
|
||||
# Update the adapter
|
||||
await conn.execute(
|
||||
"""
|
||||
UPDATE config.adapters
|
||||
SET enabled = $1, cadence_s = $2, settings = $3, updated_at = now()
|
||||
WHERE name = $4
|
||||
""",
|
||||
enabled,
|
||||
cadence_s,
|
||||
new_settings,
|
||||
name,
|
||||
)
|
||||
|
||||
# Write audit log
|
||||
await write_audit(
|
||||
conn,
|
||||
ADAPTER_UPDATE,
|
||||
operator_id=operator.id,
|
||||
target=name,
|
||||
before=before,
|
||||
after=after,
|
||||
)
|
||||
|
||||
return RedirectResponse(url="/adapters", status_code=302)
|
||||
|
|
|
|||
22
src/central/gui/templates/_dashboard_events.html
Normal file
22
src/central/gui/templates/_dashboard_events.html
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
{% if error %}
|
||||
<p class="error">{{ error }}</p>
|
||||
{% elif events %}
|
||||
<table>
|
||||
<thead>
|
||||
<tr>
|
||||
<th>Adapter</th>
|
||||
<th>Count</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{% for event in events %}
|
||||
<tr>
|
||||
<td>{{ event.adapter }}</td>
|
||||
<td>{{ event.count }}</td>
|
||||
</tr>
|
||||
{% endfor %}
|
||||
</tbody>
|
||||
</table>
|
||||
{% else %}
|
||||
<p>No events in the last 24 hours.</p>
|
||||
{% endif %}
|
||||
31
src/central/gui/templates/_dashboard_polls.html
Normal file
31
src/central/gui/templates/_dashboard_polls.html
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
{% if error %}
|
||||
<p class="error">{{ error }}</p>
|
||||
{% elif adapters %}
|
||||
<table>
|
||||
<thead>
|
||||
<tr>
|
||||
<th>Adapter</th>
|
||||
<th>Last Poll</th>
|
||||
<th>Status</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{% for adapter in adapters %}
|
||||
<tr>
|
||||
<td>{{ adapter.name }}</td>
|
||||
{% if adapter.error %}
|
||||
<td colspan="2" class="error">{{ adapter.error }}</td>
|
||||
{% elif adapter.last_poll %}
|
||||
<td>{{ adapter.last_poll }}</td>
|
||||
<td>{{ adapter.status }}</td>
|
||||
{% else %}
|
||||
<td>—</td>
|
||||
<td>—</td>
|
||||
{% endif %}
|
||||
</tr>
|
||||
{% endfor %}
|
||||
</tbody>
|
||||
</table>
|
||||
{% else %}
|
||||
<p>No adapter data available.</p>
|
||||
{% endif %}
|
||||
28
src/central/gui/templates/_dashboard_streams.html
Normal file
28
src/central/gui/templates/_dashboard_streams.html
Normal file
|
|
@ -0,0 +1,28 @@
|
|||
{% if error %}
|
||||
<p class="error">{{ error }}</p>
|
||||
{% elif streams %}
|
||||
<table>
|
||||
<thead>
|
||||
<tr>
|
||||
<th>Stream</th>
|
||||
<th>Messages</th>
|
||||
<th>Size</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{% for stream in streams %}
|
||||
<tr>
|
||||
<td>{{ stream.name }}</td>
|
||||
{% if stream.error %}
|
||||
<td colspan="2" class="error">{{ stream.error }}</td>
|
||||
{% else %}
|
||||
<td>{{ stream.messages }}</td>
|
||||
<td>{{ stream.size }}</td>
|
||||
{% endif %}
|
||||
</tr>
|
||||
{% endfor %}
|
||||
</tbody>
|
||||
</table>
|
||||
{% else %}
|
||||
<p>No stream data available.</p>
|
||||
{% endif %}
|
||||
49
src/central/gui/templates/adapters_edit.html
Normal file
49
src/central/gui/templates/adapters_edit.html
Normal file
|
|
@ -0,0 +1,49 @@
|
|||
{% extends "base.html" %}
|
||||
|
||||
{% block title %}Central — Edit {{ adapter.name }}{% endblock %}
|
||||
|
||||
{% block content %}
|
||||
<h1>Edit Adapter: {{ adapter.name }}</h1>
|
||||
|
||||
<form method="post" action="/adapters/{{ adapter.name }}">
|
||||
<input type="hidden" name="csrf_token" value="{{ csrf_token }}">
|
||||
|
||||
<fieldset>
|
||||
<legend>Universal Settings</legend>
|
||||
|
||||
<label>
|
||||
<input type="checkbox" name="enabled" {% if adapter.enabled %}checked{% endif %}>
|
||||
Enabled
|
||||
</label>
|
||||
|
||||
<label for="cadence_s">Cadence (seconds)</label>
|
||||
<input type="number" id="cadence_s" name="cadence_s" value="{{ form_data.cadence_s if form_data else adapter.cadence_s }}" min="60" max="3600" required>
|
||||
{% if errors and errors.cadence_s %}
|
||||
<small style="color: var(--pico-color-red-500);">{{ errors.cadence_s }}</small>
|
||||
{% endif %}
|
||||
</fieldset>
|
||||
|
||||
<fieldset>
|
||||
<legend>Adapter-Specific Settings</legend>
|
||||
{% include "adapters_edit_" + adapter.name + ".html" %}
|
||||
</fieldset>
|
||||
|
||||
<fieldset>
|
||||
<legend>Region (read-only)</legend>
|
||||
{% if adapter.settings.region %}
|
||||
<p>
|
||||
<strong>North:</strong> {{ adapter.settings.region.north }}<br>
|
||||
<strong>South:</strong> {{ adapter.settings.region.south }}<br>
|
||||
<strong>East:</strong> {{ adapter.settings.region.east }}<br>
|
||||
<strong>West:</strong> {{ adapter.settings.region.west }}
|
||||
</p>
|
||||
{% else %}
|
||||
<p>No region configured.</p>
|
||||
{% endif %}
|
||||
<small>Region editing comes in 1b-5.</small>
|
||||
</fieldset>
|
||||
|
||||
<button type="submit">Save Changes</button>
|
||||
<a href="/adapters" role="button" class="outline">Cancel</a>
|
||||
</form>
|
||||
{% endblock %}
|
||||
21
src/central/gui/templates/adapters_edit_firms.html
Normal file
21
src/central/gui/templates/adapters_edit_firms.html
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
<label for="api_key_alias">API Key Alias</label>
|
||||
<select id="api_key_alias" name="api_key_alias">
|
||||
<option value="" {% if not (form_data.api_key_alias if form_data else adapter.settings.api_key_alias) %}selected{% endif %}>(none)</option>
|
||||
{% for key in api_keys %}
|
||||
<option value="{{ key.alias }}" {% if (form_data.api_key_alias if form_data else adapter.settings.api_key_alias) == key.alias %}selected{% endif %}>{{ key.alias }}</option>
|
||||
{% endfor %}
|
||||
</select>
|
||||
{% if errors and errors.api_key_alias %}
|
||||
<small style="color: var(--pico-color-red-500);">{{ errors.api_key_alias }}</small>
|
||||
{% endif %}
|
||||
|
||||
<label>Satellites</label>
|
||||
{% for sat in valid_satellites %}
|
||||
<label>
|
||||
<input type="checkbox" name="satellites" value="{{ sat }}" {% if sat in (form_data.satellites if form_data else adapter.settings.satellites or []) %}checked{% endif %}>
|
||||
{{ sat }}
|
||||
</label>
|
||||
{% endfor %}
|
||||
{% if errors and errors.satellites %}
|
||||
<small style="color: var(--pico-color-red-500);">{{ errors.satellites }}</small>
|
||||
{% endif %}
|
||||
5
src/central/gui/templates/adapters_edit_nws.html
Normal file
5
src/central/gui/templates/adapters_edit_nws.html
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
<label for="contact_email">Contact Email</label>
|
||||
<input type="email" id="contact_email" name="contact_email" value="{{ form_data.contact_email if form_data else adapter.settings.contact_email }}" required>
|
||||
{% if errors and errors.contact_email %}
|
||||
<small style="color: var(--pico-color-red-500);">{{ errors.contact_email }}</small>
|
||||
{% endif %}
|
||||
9
src/central/gui/templates/adapters_edit_usgs_quake.html
Normal file
9
src/central/gui/templates/adapters_edit_usgs_quake.html
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
<label for="feed">Feed</label>
|
||||
<select id="feed" name="feed" required>
|
||||
{% for f in valid_feeds %}
|
||||
<option value="{{ f }}" {% if (form_data.feed if form_data else adapter.settings.feed) == f %}selected{% endif %}>{{ f }}</option>
|
||||
{% endfor %}
|
||||
</select>
|
||||
{% if errors and errors.feed %}
|
||||
<small style="color: var(--pico-color-red-500);">{{ errors.feed }}</small>
|
||||
{% endif %}
|
||||
29
src/central/gui/templates/adapters_list.html
Normal file
29
src/central/gui/templates/adapters_list.html
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
{% extends "base.html" %}
|
||||
|
||||
{% block title %}Central — Adapters{% endblock %}
|
||||
|
||||
{% block content %}
|
||||
<h1>Adapters</h1>
|
||||
<table>
|
||||
<thead>
|
||||
<tr>
|
||||
<th>Name</th>
|
||||
<th>Enabled</th>
|
||||
<th>Cadence</th>
|
||||
<th>Last Updated</th>
|
||||
<th></th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{% for adapter in adapters %}
|
||||
<tr>
|
||||
<td>{{ adapter.name }}</td>
|
||||
<td>{% if adapter.enabled %}Yes{% else %}No{% endif %}</td>
|
||||
<td>{{ adapter.cadence_s }}s</td>
|
||||
<td>{{ adapter.updated_at.strftime('%Y-%m-%d %H:%M') if adapter.updated_at else '—' }}</td>
|
||||
<td><a href="/adapters/{{ adapter.name }}">Edit</a></td>
|
||||
</tr>
|
||||
{% endfor %}
|
||||
</tbody>
|
||||
</table>
|
||||
{% endblock %}
|
||||
|
|
@ -15,6 +15,8 @@
|
|||
</ul>
|
||||
<ul>
|
||||
{% if operator %}
|
||||
<li><a href="/">Dashboard</a></li>
|
||||
<li><a href="/adapters">Adapters</a></li>
|
||||
<li>{{ operator.username }}</li>
|
||||
<li><a href="/change-password">Change Password</a></li>
|
||||
<li>
|
||||
|
|
|
|||
|
|
@ -1,12 +1,27 @@
|
|||
{% extends "base.html" %}
|
||||
|
||||
{% block title %}Central — Coming Soon{% endblock %}
|
||||
{% block title %}Central — Dashboard{% endblock %}
|
||||
|
||||
{% block content %}
|
||||
<article>
|
||||
<header>
|
||||
<h1>Central</h1>
|
||||
</header>
|
||||
<p>Data hub GUI — coming soon.</p>
|
||||
</article>
|
||||
<h1>Dashboard</h1>
|
||||
<div class="grid">
|
||||
<article>
|
||||
<header>Events (24h)</header>
|
||||
<div hx-get="/dashboard/events" hx-trigger="load, every 10s" hx-swap="innerHTML">
|
||||
Loading...
|
||||
</div>
|
||||
</article>
|
||||
<article>
|
||||
<header>Stream Sizes</header>
|
||||
<div hx-get="/dashboard/streams" hx-trigger="load, every 10s" hx-swap="innerHTML">
|
||||
Loading...
|
||||
</div>
|
||||
</article>
|
||||
<article>
|
||||
<header>Last Poll Times</header>
|
||||
<div hx-get="/dashboard/polls" hx-trigger="load, every 10s" hx-swap="innerHTML">
|
||||
Loading...
|
||||
</div>
|
||||
</article>
|
||||
</div>
|
||||
{% endblock %}
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ class Event(BaseModel):
|
|||
model_config = ConfigDict(extra="forbid", frozen=True)
|
||||
|
||||
id: str # unique, stable across republish
|
||||
source: str # adapter identity, e.g. "central/adapters/nws"
|
||||
adapter: str # adapter identity, e.g. "nws"
|
||||
category: str # e.g. "wx.alert.severe_thunderstorm_warning" or "fire.hotspot.viirs_snpp.high"
|
||||
time: datetime # event-time UTC, not processing-time
|
||||
expires: datetime | None = None
|
||||
|
|
|
|||
511
tests/test_adapters.py
Normal file
511
tests/test_adapters.py
Normal file
|
|
@ -0,0 +1,511 @@
|
|||
"""Tests for adapter list and edit routes."""
|
||||
|
||||
import json
|
||||
import os
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
# Set required env vars before importing central modules
|
||||
os.environ.setdefault("CENTRAL_DB_DSN", "postgresql://test:test@localhost/test")
|
||||
os.environ.setdefault("CENTRAL_CSRF_SECRET", "testsecret12345678901234567890ab")
|
||||
os.environ.setdefault("CENTRAL_NATS_URL", "nats://localhost:4222")
|
||||
|
||||
|
||||
class TestAdaptersListUnauthenticated:
|
||||
"""Test adapters list without authentication."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_adapters_list_unauthenticated_redirects(self):
|
||||
"""GET /adapters without auth redirects to /login."""
|
||||
from central.gui.routes import adapters_list
|
||||
|
||||
mock_request = MagicMock()
|
||||
mock_request.state.operator = None
|
||||
|
||||
# The middleware handles the redirect, so we test the route expects operator
|
||||
# In practice, middleware returns 302 before route is called
|
||||
# This test verifies the route structure expects authentication
|
||||
assert adapters_list is not None
|
||||
|
||||
|
||||
class TestAdaptersListAuthenticated:
|
||||
"""Test adapters list with authentication."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_adapters_list_returns_all_adapters(self):
|
||||
"""GET /adapters authenticated returns 200 with all adapters."""
|
||||
from central.gui.routes import adapters_list
|
||||
|
||||
mock_request = MagicMock()
|
||||
mock_request.state.operator = MagicMock(id=1, username="testop")
|
||||
|
||||
mock_conn = AsyncMock()
|
||||
mock_conn.fetch.return_value = [
|
||||
{"name": "firms", "enabled": True, "cadence_s": 300, "settings": {"api_key_alias": "firms"}, "paused_at": None, "updated_at": None},
|
||||
{"name": "nws", "enabled": True, "cadence_s": 60, "settings": {"contact_email": "test@test.com"}, "paused_at": None, "updated_at": None},
|
||||
{"name": "usgs_quake", "enabled": True, "cadence_s": 120, "settings": {"feed": "all_hour"}, "paused_at": None, "updated_at": None},
|
||||
]
|
||||
|
||||
mock_pool = MagicMock()
|
||||
mock_pool.acquire.return_value.__aenter__ = AsyncMock(return_value=mock_conn)
|
||||
mock_pool.acquire.return_value.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
mock_templates = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_templates.TemplateResponse.return_value = mock_response
|
||||
|
||||
mock_csrf = MagicMock()
|
||||
mock_csrf.generate_csrf_tokens.return_value = ("token", "signed")
|
||||
mock_csrf.set_csrf_cookie = MagicMock()
|
||||
|
||||
with patch("central.gui.routes._get_templates", return_value=mock_templates):
|
||||
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
||||
result = await adapters_list(mock_request, mock_csrf)
|
||||
|
||||
# Verify template was called with adapters
|
||||
call_args = mock_templates.TemplateResponse.call_args
|
||||
context = call_args.kwargs.get("context", call_args[1].get("context"))
|
||||
assert len(context["adapters"]) == 3
|
||||
assert context["adapters"][0]["name"] == "firms"
|
||||
assert context["adapters"][1]["name"] == "nws"
|
||||
assert context["adapters"][2]["name"] == "usgs_quake"
|
||||
|
||||
|
||||
class TestAdaptersEditForm:
|
||||
"""Test adapter edit form GET."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_adapters_edit_nws_shows_form(self):
|
||||
"""GET /adapters/nws authenticated returns 200 with form."""
|
||||
from central.gui.routes import adapters_edit_form
|
||||
|
||||
mock_request = MagicMock()
|
||||
mock_request.state.operator = MagicMock(id=1, username="testop")
|
||||
|
||||
mock_conn = AsyncMock()
|
||||
mock_conn.fetchrow.return_value = {
|
||||
"name": "nws",
|
||||
"enabled": True,
|
||||
"cadence_s": 60,
|
||||
"settings": {"contact_email": "test@example.com", "region": {"north": 49, "south": 24, "east": -66, "west": -125}},
|
||||
"paused_at": None,
|
||||
"updated_at": None,
|
||||
}
|
||||
mock_conn.fetch.return_value = [] # No API keys
|
||||
|
||||
mock_pool = MagicMock()
|
||||
mock_pool.acquire.return_value.__aenter__ = AsyncMock(return_value=mock_conn)
|
||||
mock_pool.acquire.return_value.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
mock_templates = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_templates.TemplateResponse.return_value = mock_response
|
||||
|
||||
mock_csrf = MagicMock()
|
||||
mock_csrf.generate_csrf_tokens.return_value = ("token", "signed")
|
||||
mock_csrf.set_csrf_cookie = MagicMock()
|
||||
|
||||
with patch("central.gui.routes._get_templates", return_value=mock_templates):
|
||||
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
||||
result = await adapters_edit_form(mock_request, "nws", mock_csrf)
|
||||
|
||||
call_args = mock_templates.TemplateResponse.call_args
|
||||
context = call_args.kwargs.get("context", call_args[1].get("context"))
|
||||
assert context["adapter"]["name"] == "nws"
|
||||
assert context["adapter"]["settings"]["contact_email"] == "test@example.com"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_adapters_edit_nonexistent_returns_404(self):
|
||||
"""GET /adapters/nonexistent returns 404."""
|
||||
from central.gui.routes import adapters_edit_form
|
||||
|
||||
mock_request = MagicMock()
|
||||
mock_request.state.operator = MagicMock(id=1, username="testop")
|
||||
|
||||
mock_conn = AsyncMock()
|
||||
mock_conn.fetchrow.return_value = None
|
||||
|
||||
mock_pool = MagicMock()
|
||||
mock_pool.acquire.return_value.__aenter__ = AsyncMock(return_value=mock_conn)
|
||||
mock_pool.acquire.return_value.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
mock_csrf = MagicMock()
|
||||
mock_csrf.generate_csrf_tokens.return_value = ("token", "signed")
|
||||
|
||||
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
||||
result = await adapters_edit_form(mock_request, "nonexistent", mock_csrf)
|
||||
|
||||
assert result.status_code == 404
|
||||
|
||||
|
||||
class TestAdaptersEditSubmit:
|
||||
"""Test adapter edit form POST."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_adapters_edit_valid_changes_updates_db(self):
|
||||
"""POST /adapters/nws with valid changes updates DB and redirects."""
|
||||
from central.gui.routes import adapters_edit_submit
|
||||
|
||||
mock_request = MagicMock()
|
||||
mock_request.state.operator = MagicMock(id=1, username="testop")
|
||||
mock_request.cookies = {}
|
||||
|
||||
# Mock form data
|
||||
mock_form = MagicMock()
|
||||
mock_form.get.side_effect = lambda k, d="": {
|
||||
"cadence_s": "120",
|
||||
"contact_email": "new@example.com",
|
||||
}.get(k, d)
|
||||
mock_form.getlist.return_value = []
|
||||
mock_form.__contains__ = lambda self, k: k == "enabled"
|
||||
mock_request.form = AsyncMock(return_value=mock_form)
|
||||
|
||||
mock_conn = AsyncMock()
|
||||
mock_conn.fetchrow.return_value = {
|
||||
"name": "nws",
|
||||
"enabled": True,
|
||||
"cadence_s": 60,
|
||||
"settings": {"contact_email": "old@example.com"},
|
||||
"paused_at": None,
|
||||
"updated_at": None,
|
||||
}
|
||||
mock_conn.execute = AsyncMock()
|
||||
|
||||
mock_pool = MagicMock()
|
||||
mock_pool.acquire.return_value.__aenter__ = AsyncMock(return_value=mock_conn)
|
||||
mock_pool.acquire.return_value.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
mock_csrf = MagicMock()
|
||||
mock_csrf.validate_csrf = AsyncMock()
|
||||
|
||||
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
||||
with patch("central.gui.routes.write_audit", new_callable=AsyncMock) as mock_audit:
|
||||
result = await adapters_edit_submit(mock_request, "nws", mock_csrf)
|
||||
|
||||
assert result.status_code == 302
|
||||
assert result.headers["location"] == "/adapters"
|
||||
mock_conn.execute.assert_called()
|
||||
mock_audit.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_adapters_edit_invalid_cadence_shows_error(self):
|
||||
"""POST /adapters/nws with cadence_s=30 shows error, no DB update."""
|
||||
from central.gui.routes import adapters_edit_submit
|
||||
|
||||
mock_request = MagicMock()
|
||||
mock_request.state.operator = MagicMock(id=1, username="testop")
|
||||
|
||||
mock_form = MagicMock()
|
||||
mock_form.get.side_effect = lambda k, d="": {
|
||||
"cadence_s": "30",
|
||||
"contact_email": "test@example.com",
|
||||
}.get(k, d)
|
||||
mock_form.getlist.return_value = []
|
||||
mock_form.__contains__ = lambda self, k: k == "enabled"
|
||||
mock_request.form = AsyncMock(return_value=mock_form)
|
||||
|
||||
mock_conn = AsyncMock()
|
||||
mock_conn.fetchrow.return_value = {
|
||||
"name": "nws",
|
||||
"enabled": True,
|
||||
"cadence_s": 60,
|
||||
"settings": {"contact_email": "test@example.com"},
|
||||
"paused_at": None,
|
||||
"updated_at": None,
|
||||
}
|
||||
mock_conn.fetch.return_value = []
|
||||
|
||||
mock_pool = MagicMock()
|
||||
mock_pool.acquire.return_value.__aenter__ = AsyncMock(return_value=mock_conn)
|
||||
mock_pool.acquire.return_value.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
mock_templates = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_templates.TemplateResponse.return_value = mock_response
|
||||
|
||||
mock_csrf = MagicMock()
|
||||
mock_csrf.validate_csrf = AsyncMock()
|
||||
mock_csrf.generate_csrf_tokens.return_value = ("token", "signed")
|
||||
mock_csrf.set_csrf_cookie = MagicMock()
|
||||
|
||||
with patch("central.gui.routes._get_templates", return_value=mock_templates):
|
||||
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
||||
result = await adapters_edit_submit(mock_request, "nws", mock_csrf)
|
||||
|
||||
# Should re-render form with error
|
||||
call_args = mock_templates.TemplateResponse.call_args
|
||||
context = call_args.kwargs.get("context", call_args[1].get("context"))
|
||||
assert "cadence_s" in context["errors"]
|
||||
assert "60" in context["errors"]["cadence_s"] or "3600" in context["errors"]["cadence_s"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_adapters_edit_firms_unknown_api_key_shows_error(self):
|
||||
"""POST /adapters/firms with unknown api_key_alias shows error."""
|
||||
from central.gui.routes import adapters_edit_submit
|
||||
|
||||
mock_request = MagicMock()
|
||||
mock_request.state.operator = MagicMock(id=1, username="testop")
|
||||
|
||||
mock_form = MagicMock()
|
||||
mock_form.get.side_effect = lambda k, d="": {
|
||||
"cadence_s": "300",
|
||||
"api_key_alias": "nonexistent_key",
|
||||
}.get(k, d)
|
||||
mock_form.getlist.return_value = ["VIIRS_SNPP_NRT"]
|
||||
mock_form.__contains__ = lambda self, k: k == "enabled"
|
||||
mock_request.form = AsyncMock(return_value=mock_form)
|
||||
|
||||
mock_conn = AsyncMock()
|
||||
mock_conn.fetchrow.side_effect = [
|
||||
{ # First call: get adapter
|
||||
"name": "firms",
|
||||
"enabled": True,
|
||||
"cadence_s": 300,
|
||||
"settings": {"api_key_alias": "firms", "satellites": ["VIIRS_SNPP_NRT"]},
|
||||
"paused_at": None,
|
||||
"updated_at": None,
|
||||
},
|
||||
None, # Second call: check api_key exists - returns None
|
||||
]
|
||||
mock_conn.fetch.return_value = []
|
||||
|
||||
mock_pool = MagicMock()
|
||||
mock_pool.acquire.return_value.__aenter__ = AsyncMock(return_value=mock_conn)
|
||||
mock_pool.acquire.return_value.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
mock_templates = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_templates.TemplateResponse.return_value = mock_response
|
||||
|
||||
mock_csrf = MagicMock()
|
||||
mock_csrf.validate_csrf = AsyncMock()
|
||||
mock_csrf.generate_csrf_tokens.return_value = ("token", "signed")
|
||||
mock_csrf.set_csrf_cookie = MagicMock()
|
||||
|
||||
with patch("central.gui.routes._get_templates", return_value=mock_templates):
|
||||
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
||||
result = await adapters_edit_submit(mock_request, "firms", mock_csrf)
|
||||
|
||||
call_args = mock_templates.TemplateResponse.call_args
|
||||
context = call_args.kwargs.get("context", call_args[1].get("context"))
|
||||
assert "api_key_alias" in context["errors"]
|
||||
assert "nonexistent_key" in context["errors"]["api_key_alias"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_adapters_edit_usgs_unknown_feed_shows_error(self):
|
||||
"""POST /adapters/usgs_quake with unknown feed shows error."""
|
||||
from central.gui.routes import adapters_edit_submit
|
||||
|
||||
mock_request = MagicMock()
|
||||
mock_request.state.operator = MagicMock(id=1, username="testop")
|
||||
|
||||
mock_form = MagicMock()
|
||||
mock_form.get.side_effect = lambda k, d="": {
|
||||
"cadence_s": "120",
|
||||
"feed": "invalid_feed",
|
||||
}.get(k, d)
|
||||
mock_form.getlist.return_value = []
|
||||
mock_form.__contains__ = lambda self, k: k == "enabled"
|
||||
mock_request.form = AsyncMock(return_value=mock_form)
|
||||
|
||||
mock_conn = AsyncMock()
|
||||
mock_conn.fetchrow.return_value = {
|
||||
"name": "usgs_quake",
|
||||
"enabled": True,
|
||||
"cadence_s": 120,
|
||||
"settings": {"feed": "all_hour"},
|
||||
"paused_at": None,
|
||||
"updated_at": None,
|
||||
}
|
||||
mock_conn.fetch.return_value = []
|
||||
|
||||
mock_pool = MagicMock()
|
||||
mock_pool.acquire.return_value.__aenter__ = AsyncMock(return_value=mock_conn)
|
||||
mock_pool.acquire.return_value.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
mock_templates = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_templates.TemplateResponse.return_value = mock_response
|
||||
|
||||
mock_csrf = MagicMock()
|
||||
mock_csrf.validate_csrf = AsyncMock()
|
||||
mock_csrf.generate_csrf_tokens.return_value = ("token", "signed")
|
||||
mock_csrf.set_csrf_cookie = MagicMock()
|
||||
|
||||
with patch("central.gui.routes._get_templates", return_value=mock_templates):
|
||||
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
||||
result = await adapters_edit_submit(mock_request, "usgs_quake", mock_csrf)
|
||||
|
||||
call_args = mock_templates.TemplateResponse.call_args
|
||||
context = call_args.kwargs.get("context", call_args[1].get("context"))
|
||||
assert "feed" in context["errors"]
|
||||
|
||||
|
||||
class TestAdaptersAudit:
|
||||
"""Test adapter audit logging."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_audit_row_has_before_after(self):
|
||||
"""Audit row has before/after JSONB populated correctly."""
|
||||
from central.gui.routes import adapters_edit_submit
|
||||
|
||||
mock_request = MagicMock()
|
||||
mock_request.state.operator = MagicMock(id=1, username="testop")
|
||||
|
||||
mock_form = MagicMock()
|
||||
mock_form.get.side_effect = lambda k, d="": {
|
||||
"cadence_s": "120",
|
||||
"contact_email": "new@example.com",
|
||||
}.get(k, d)
|
||||
mock_form.getlist.return_value = []
|
||||
mock_form.__contains__ = lambda self, k: k == "enabled"
|
||||
mock_request.form = AsyncMock(return_value=mock_form)
|
||||
|
||||
mock_conn = AsyncMock()
|
||||
mock_conn.fetchrow.return_value = {
|
||||
"name": "nws",
|
||||
"enabled": True,
|
||||
"cadence_s": 60,
|
||||
"settings": {"contact_email": "old@example.com"},
|
||||
"paused_at": None,
|
||||
"updated_at": None,
|
||||
}
|
||||
mock_conn.execute = AsyncMock()
|
||||
|
||||
mock_pool = MagicMock()
|
||||
mock_pool.acquire.return_value.__aenter__ = AsyncMock(return_value=mock_conn)
|
||||
mock_pool.acquire.return_value.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
mock_csrf = MagicMock()
|
||||
mock_csrf.validate_csrf = AsyncMock()
|
||||
|
||||
captured_audit = {}
|
||||
|
||||
async def capture_audit(conn, action, operator_id=None, target=None, before=None, after=None):
|
||||
captured_audit["action"] = action
|
||||
captured_audit["target"] = target
|
||||
captured_audit["before"] = before
|
||||
captured_audit["after"] = after
|
||||
|
||||
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
||||
with patch("central.gui.routes.write_audit", side_effect=capture_audit):
|
||||
result = await adapters_edit_submit(mock_request, "nws", mock_csrf)
|
||||
|
||||
assert captured_audit["action"] == "adapter.update"
|
||||
assert captured_audit["target"] == "nws"
|
||||
assert captured_audit["before"]["cadence_s"] == 60
|
||||
assert captured_audit["after"]["cadence_s"] == 120
|
||||
assert captured_audit["before"]["settings"]["contact_email"] == "old@example.com"
|
||||
assert captured_audit["after"]["settings"]["contact_email"] == "new@example.com"
|
||||
|
||||
|
||||
class TestAdaptersJsonbRegression:
|
||||
"""Regression tests for JSONB double-encoding bug."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_settings_passed_as_dict_not_string(self):
|
||||
"""Verify settings is passed to UPDATE as dict, not json.dumps string.
|
||||
|
||||
Regression test for double-encoding bug where json.dumps() was called
|
||||
on settings before passing to asyncpg, which already handles dict->jsonb.
|
||||
"""
|
||||
from central.gui.routes import adapters_edit_submit
|
||||
|
||||
mock_request = MagicMock()
|
||||
mock_request.state.operator = MagicMock(id=1, username="testop")
|
||||
|
||||
mock_form = MagicMock()
|
||||
mock_form.get.side_effect = lambda k, d="": {
|
||||
"cadence_s": "120",
|
||||
"contact_email": "test@example.com",
|
||||
}.get(k, d)
|
||||
mock_form.getlist.return_value = []
|
||||
mock_form.__contains__ = lambda self, k: k == "enabled"
|
||||
mock_request.form = AsyncMock(return_value=mock_form)
|
||||
|
||||
mock_conn = AsyncMock()
|
||||
mock_conn.fetchrow.return_value = {
|
||||
"name": "nws",
|
||||
"enabled": True,
|
||||
"cadence_s": 60,
|
||||
"settings": {"contact_email": "old@example.com"}, # dict, as asyncpg returns
|
||||
"paused_at": None,
|
||||
"updated_at": None,
|
||||
}
|
||||
mock_conn.execute = AsyncMock()
|
||||
|
||||
mock_pool = MagicMock()
|
||||
mock_pool.acquire.return_value.__aenter__ = AsyncMock(return_value=mock_conn)
|
||||
mock_pool.acquire.return_value.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
mock_csrf = MagicMock()
|
||||
mock_csrf.validate_csrf = AsyncMock()
|
||||
|
||||
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
||||
with patch("central.gui.routes.write_audit", new_callable=AsyncMock):
|
||||
await adapters_edit_submit(mock_request, "nws", mock_csrf)
|
||||
|
||||
# Get the settings argument passed to execute (3rd positional arg after query)
|
||||
call_args = mock_conn.execute.call_args
|
||||
# args[0] is the query, args[1:] are the parameters
|
||||
settings_arg = call_args[0][3] # enabled=$1, cadence=$2, settings=$3
|
||||
|
||||
# CRITICAL: settings must be a dict, NOT a string
|
||||
# If json.dumps() was called, this would be a str like {contact_email: ...}
|
||||
assert isinstance(settings_arg, dict), f"settings should be dict, got {type(settings_arg)}: {settings_arg}"
|
||||
assert settings_arg["contact_email"] == "test@example.com"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_audit_before_after_passed_as_dict(self):
|
||||
"""Verify audit before/after are passed as dicts, not json.dumps strings."""
|
||||
from central.gui.routes import adapters_edit_submit
|
||||
|
||||
mock_request = MagicMock()
|
||||
mock_request.state.operator = MagicMock(id=1, username="testop")
|
||||
|
||||
mock_form = MagicMock()
|
||||
mock_form.get.side_effect = lambda k, d="": {
|
||||
"cadence_s": "120",
|
||||
"contact_email": "new@example.com",
|
||||
}.get(k, d)
|
||||
mock_form.getlist.return_value = []
|
||||
mock_form.__contains__ = lambda self, k: k == "enabled"
|
||||
mock_request.form = AsyncMock(return_value=mock_form)
|
||||
|
||||
mock_conn = AsyncMock()
|
||||
mock_conn.fetchrow.return_value = {
|
||||
"name": "nws",
|
||||
"enabled": True,
|
||||
"cadence_s": 60,
|
||||
"settings": {"contact_email": "old@example.com"}, # dict
|
||||
"paused_at": None,
|
||||
"updated_at": None,
|
||||
}
|
||||
mock_conn.execute = AsyncMock()
|
||||
|
||||
mock_pool = MagicMock()
|
||||
mock_pool.acquire.return_value.__aenter__ = AsyncMock(return_value=mock_conn)
|
||||
mock_pool.acquire.return_value.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
mock_csrf = MagicMock()
|
||||
mock_csrf.validate_csrf = AsyncMock()
|
||||
|
||||
captured_audit = {}
|
||||
|
||||
async def capture_audit(conn, action, operator_id=None, target=None, before=None, after=None):
|
||||
captured_audit["before"] = before
|
||||
captured_audit["after"] = after
|
||||
|
||||
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
||||
with patch("central.gui.routes.write_audit", side_effect=capture_audit):
|
||||
await adapters_edit_submit(mock_request, "nws", mock_csrf)
|
||||
|
||||
# CRITICAL: before and after must be dicts, NOT strings
|
||||
assert isinstance(captured_audit["before"], dict), f"before should be dict, got {type(captured_audit['before'])}"
|
||||
assert isinstance(captured_audit["after"], dict), f"after should be dict, got {type(captured_audit['after'])}"
|
||||
assert isinstance(captured_audit["before"]["settings"], dict), "before.settings should be dict"
|
||||
assert isinstance(captured_audit["after"]["settings"], dict), "after.settings should be dict"
|
||||
150
tests/test_archive_multi_stream.py
Normal file
150
tests/test_archive_multi_stream.py
Normal file
|
|
@ -0,0 +1,150 @@
|
|||
"""Tests for multi-stream archive consumer."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from central.archive import (
|
||||
STREAMS,
|
||||
consumer_name_for,
|
||||
ArchiveConsumer,
|
||||
)
|
||||
|
||||
|
||||
class TestConsumerNaming:
|
||||
"""Test consumer naming convention."""
|
||||
|
||||
def test_consumer_name_for_central_wx(self):
|
||||
"""Consumer name for CENTRAL_WX is archive-central_wx."""
|
||||
assert consumer_name_for("CENTRAL_WX") == "archive-central_wx"
|
||||
|
||||
def test_consumer_name_for_central_fire(self):
|
||||
"""Consumer name for CENTRAL_FIRE is archive-central_fire."""
|
||||
assert consumer_name_for("CENTRAL_FIRE") == "archive-central_fire"
|
||||
|
||||
def test_consumer_name_for_central_quake(self):
|
||||
"""Consumer name for CENTRAL_QUAKE is archive-central_quake."""
|
||||
assert consumer_name_for("CENTRAL_QUAKE") == "archive-central_quake"
|
||||
|
||||
|
||||
class TestStreamsConfiguration:
|
||||
"""Test streams configuration."""
|
||||
|
||||
def test_streams_list_has_three_entries(self):
|
||||
"""STREAMS list has three event-bearing streams."""
|
||||
assert len(STREAMS) == 3
|
||||
|
||||
def test_streams_contains_central_wx(self):
|
||||
"""STREAMS contains CENTRAL_WX with correct filter."""
|
||||
assert ("CENTRAL_WX", "central.wx.>") in STREAMS
|
||||
|
||||
def test_streams_contains_central_fire(self):
|
||||
"""STREAMS contains CENTRAL_FIRE with correct filter."""
|
||||
assert ("CENTRAL_FIRE", "central.fire.>") in STREAMS
|
||||
|
||||
def test_streams_contains_central_quake(self):
|
||||
"""STREAMS contains CENTRAL_QUAKE with correct filter."""
|
||||
assert ("CENTRAL_QUAKE", "central.quake.>") in STREAMS
|
||||
|
||||
def test_streams_excludes_central_meta(self):
|
||||
"""STREAMS does not contain CENTRAL_META (status messages only)."""
|
||||
stream_names = [s[0] for s in STREAMS]
|
||||
assert "CENTRAL_META" not in stream_names
|
||||
|
||||
|
||||
class TestOrphanedConsumerCleanup:
|
||||
"""Test cleanup of orphaned 'archive' consumer."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_removes_orphaned_consumer_when_exists(self):
|
||||
"""Cleanup removes 'archive' consumer from CENTRAL_WX when it exists."""
|
||||
consumer = ArchiveConsumer(
|
||||
nats_url="nats://localhost:4222",
|
||||
postgres_dsn="postgresql://test:test@localhost/test",
|
||||
)
|
||||
|
||||
mock_js = AsyncMock()
|
||||
mock_js.consumer_info = AsyncMock(return_value=MagicMock())
|
||||
mock_js.delete_consumer = AsyncMock()
|
||||
consumer._js = mock_js
|
||||
|
||||
await consumer._cleanup_orphaned_consumer()
|
||||
|
||||
mock_js.consumer_info.assert_called_once_with("CENTRAL_WX", "archive")
|
||||
mock_js.delete_consumer.assert_called_once_with("CENTRAL_WX", "archive")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_handles_not_found_gracefully(self):
|
||||
"""Cleanup handles NotFoundError when 'archive' consumer doesn't exist."""
|
||||
from nats.js.errors import NotFoundError
|
||||
|
||||
consumer = ArchiveConsumer(
|
||||
nats_url="nats://localhost:4222",
|
||||
postgres_dsn="postgresql://test:test@localhost/test",
|
||||
)
|
||||
|
||||
mock_js = AsyncMock()
|
||||
mock_js.consumer_info = AsyncMock(side_effect=NotFoundError())
|
||||
mock_js.delete_consumer = AsyncMock()
|
||||
consumer._js = mock_js
|
||||
|
||||
# Should not raise
|
||||
await consumer._cleanup_orphaned_consumer()
|
||||
|
||||
mock_js.consumer_info.assert_called_once_with("CENTRAL_WX", "archive")
|
||||
mock_js.delete_consumer.assert_not_called()
|
||||
|
||||
|
||||
class TestEnsureConsumer:
|
||||
"""Test consumer creation for each stream."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ensure_consumer_creates_when_not_exists(self):
|
||||
"""_ensure_consumer creates consumer when it doesn't exist."""
|
||||
from nats.js.errors import NotFoundError
|
||||
|
||||
consumer = ArchiveConsumer(
|
||||
nats_url="nats://localhost:4222",
|
||||
postgres_dsn="postgresql://test:test@localhost/test",
|
||||
)
|
||||
|
||||
mock_js = AsyncMock()
|
||||
mock_js.consumer_info = AsyncMock(side_effect=NotFoundError())
|
||||
mock_js.add_consumer = AsyncMock()
|
||||
consumer._js = mock_js
|
||||
|
||||
await consumer._ensure_consumer(
|
||||
"CENTRAL_FIRE", "central.fire.>", "archive-central_fire"
|
||||
)
|
||||
|
||||
mock_js.consumer_info.assert_called_once_with(
|
||||
"CENTRAL_FIRE", "archive-central_fire"
|
||||
)
|
||||
mock_js.add_consumer.assert_called_once()
|
||||
# Verify the consumer config
|
||||
call_args = mock_js.add_consumer.call_args
|
||||
assert call_args[0][0] == "CENTRAL_FIRE"
|
||||
config = call_args[0][1]
|
||||
assert config.durable_name == "archive-central_fire"
|
||||
assert config.filter_subject == "central.fire.>"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ensure_consumer_skips_when_exists(self):
|
||||
"""_ensure_consumer does nothing when consumer already exists."""
|
||||
consumer = ArchiveConsumer(
|
||||
nats_url="nats://localhost:4222",
|
||||
postgres_dsn="postgresql://test:test@localhost/test",
|
||||
)
|
||||
|
||||
mock_js = AsyncMock()
|
||||
mock_js.consumer_info = AsyncMock(return_value=MagicMock())
|
||||
mock_js.add_consumer = AsyncMock()
|
||||
consumer._js = mock_js
|
||||
|
||||
await consumer._ensure_consumer(
|
||||
"CENTRAL_QUAKE", "central.quake.>", "archive-central_quake"
|
||||
)
|
||||
|
||||
mock_js.consumer_info.assert_called_once_with(
|
||||
"CENTRAL_QUAKE", "archive-central_quake"
|
||||
)
|
||||
mock_js.add_consumer.assert_not_called()
|
||||
158
tests/test_dashboard.py
Normal file
158
tests/test_dashboard.py
Normal file
|
|
@ -0,0 +1,158 @@
|
|||
"""Tests for dashboard routes."""
|
||||
|
||||
import os
|
||||
from unittest.mock import MagicMock, AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
# Set required env vars before importing central modules
|
||||
os.environ.setdefault("CENTRAL_DB_DSN", "postgresql://test:test@localhost/test")
|
||||
os.environ.setdefault("CENTRAL_CSRF_SECRET", "testsecret12345678901234567890ab")
|
||||
os.environ.setdefault("CENTRAL_NATS_URL", "nats://localhost:4222")
|
||||
|
||||
|
||||
class TestFormatBytes:
|
||||
"""Test _format_bytes helper."""
|
||||
|
||||
def test_format_bytes_bytes(self):
|
||||
"""Bytes are shown as B."""
|
||||
from central.gui.routes import _format_bytes
|
||||
assert _format_bytes(100) == "100 B"
|
||||
|
||||
def test_format_bytes_kilobytes(self):
|
||||
"""KB formatting."""
|
||||
from central.gui.routes import _format_bytes
|
||||
assert _format_bytes(1024) == "1.0 KB"
|
||||
|
||||
def test_format_bytes_megabytes(self):
|
||||
"""MB formatting."""
|
||||
from central.gui.routes import _format_bytes
|
||||
assert _format_bytes(1048576) == "1.0 MB"
|
||||
|
||||
def test_format_bytes_gigabytes(self):
|
||||
"""GB formatting."""
|
||||
from central.gui.routes import _format_bytes
|
||||
assert _format_bytes(1073741824) == "1.0 GB"
|
||||
|
||||
|
||||
class TestDashboardEventsSQL:
|
||||
"""Test events query construction."""
|
||||
|
||||
def test_events_query_has_24h_filter(self):
|
||||
"""Events query filters by received > NOW() - 24h."""
|
||||
# We can't easily test the full route without mocking,
|
||||
# but we can verify the query logic by inspecting the source
|
||||
import inspect
|
||||
from central.gui.routes import dashboard_events
|
||||
source = inspect.getsource(dashboard_events)
|
||||
assert "24 hours" in source
|
||||
assert "received > NOW()" in source
|
||||
|
||||
|
||||
class TestDashboardStreamsGracefulDegradation:
|
||||
"""Test streams endpoint graceful degradation."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_nats_unavailable_returns_error_message(self):
|
||||
"""When NATS is unavailable, streams returns error message not 500."""
|
||||
from central.gui.routes import dashboard_streams
|
||||
|
||||
mock_request = MagicMock()
|
||||
mock_request.state.operator = MagicMock()
|
||||
|
||||
mock_templates = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_templates.TemplateResponse.return_value = mock_response
|
||||
|
||||
with patch("central.gui.routes._get_templates", return_value=mock_templates):
|
||||
with patch("central.gui.nats.get_js", return_value=None):
|
||||
result = await dashboard_streams(mock_request)
|
||||
|
||||
# Should call template with error context
|
||||
call_args = mock_templates.TemplateResponse.call_args
|
||||
context = call_args.kwargs.get("context", call_args[1].get("context"))
|
||||
assert context["error"] == "NATS unavailable"
|
||||
assert context["streams"] is None
|
||||
|
||||
|
||||
class TestDashboardPollsGracefulDegradation:
|
||||
"""Test polls endpoint graceful degradation."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_nats_unavailable_shows_all_adapters_with_error(self):
|
||||
"""When NATS is unavailable, polls shows adapters with error message."""
|
||||
from central.gui.routes import dashboard_polls
|
||||
|
||||
mock_request = MagicMock()
|
||||
mock_request.state.operator = MagicMock()
|
||||
|
||||
mock_conn = AsyncMock()
|
||||
mock_conn.fetch.return_value = [{"name": "nws"}, {"name": "firms"}]
|
||||
|
||||
mock_pool = MagicMock()
|
||||
mock_pool.acquire.return_value.__aenter__ = AsyncMock(return_value=mock_conn)
|
||||
mock_pool.acquire.return_value.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
mock_templates = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_templates.TemplateResponse.return_value = mock_response
|
||||
|
||||
with patch("central.gui.routes._get_templates", return_value=mock_templates):
|
||||
with patch("central.gui.routes.get_pool", return_value=mock_pool):
|
||||
with patch("central.gui.nats.get_js", return_value=None):
|
||||
result = await dashboard_polls(mock_request)
|
||||
|
||||
call_args = mock_templates.TemplateResponse.call_args
|
||||
context = call_args.kwargs.get("context", call_args[1].get("context"))
|
||||
assert context["error"] == "NATS unavailable"
|
||||
assert len(context["adapters"]) == 2
|
||||
assert context["adapters"][0]["error"] == "NATS unavailable"
|
||||
|
||||
|
||||
class TestDashboardStreamsIsolation:
|
||||
"""Test stream failure isolation."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_stream_failure_doesnt_crash_card(self):
|
||||
"""A single stream failure shows error for that stream only."""
|
||||
from central.gui.routes import dashboard_streams
|
||||
|
||||
mock_request = MagicMock()
|
||||
mock_request.state.operator = MagicMock()
|
||||
|
||||
async def mock_stream_info(name):
|
||||
if name == "CENTRAL_FIRE":
|
||||
raise Exception("Not found")
|
||||
state = MagicMock()
|
||||
state.messages = 100
|
||||
state.bytes = 1024
|
||||
info = MagicMock()
|
||||
info.state = state
|
||||
return info
|
||||
|
||||
mock_js = AsyncMock()
|
||||
mock_js.stream_info.side_effect = mock_stream_info
|
||||
|
||||
mock_templates = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_templates.TemplateResponse.return_value = mock_response
|
||||
|
||||
with patch("central.gui.routes._get_templates", return_value=mock_templates):
|
||||
with patch("central.gui.nats.get_js", return_value=mock_js):
|
||||
result = await dashboard_streams(mock_request)
|
||||
|
||||
call_args = mock_templates.TemplateResponse.call_args
|
||||
context = call_args.kwargs.get("context", call_args[1].get("context"))
|
||||
|
||||
streams = context["streams"]
|
||||
# Should have 4 streams
|
||||
assert len(streams) == 4
|
||||
|
||||
# CENTRAL_FIRE should have error
|
||||
fire_stream = next(s for s in streams if s["name"] == "CENTRAL_FIRE")
|
||||
assert fire_stream.get("error") == "unavailable"
|
||||
|
||||
# CENTRAL_WX should be fine
|
||||
wx_stream = next(s for s in streams if s["name"] == "CENTRAL_WX")
|
||||
assert wx_stream.get("error") is None
|
||||
assert wx_stream["messages"] == 100
|
||||
313
tests/test_events_adapter_column.py
Normal file
313
tests/test_events_adapter_column.py
Normal file
|
|
@ -0,0 +1,313 @@
|
|||
"""Tests for events.adapter column migration (011).
|
||||
|
||||
These tests exercise the actual migration SQL against a test database,
|
||||
verifying backfill logic, FK constraints, NOT NULL enforcement, and
|
||||
source column removal.
|
||||
|
||||
Requires CENTRAL_TEST_DB_DSN or uses default central_test database.
|
||||
The test database must have PostGIS installed, or the central_test role
|
||||
must be a superuser (which it is by default) to self-bootstrap PostGIS.
|
||||
"""
|
||||
|
||||
import os
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
|
||||
import asyncpg
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from central.models import Event, Geo
|
||||
|
||||
|
||||
# Test database DSN - matches test_config_store.py pattern
|
||||
TEST_DB_DSN = os.environ.get(
|
||||
"CENTRAL_TEST_DB_DSN",
|
||||
"postgresql://central_test:testpass@localhost/central_test",
|
||||
)
|
||||
|
||||
# Path to migration file
|
||||
MIGRATION_011_PATH = Path(__file__).parent.parent / "sql" / "migrations" / "011_events_add_adapter_column.sql"
|
||||
|
||||
|
||||
class TestEventModelAdapterField:
|
||||
"""Test Event model has adapter field (not source)."""
|
||||
|
||||
def test_event_has_adapter_field(self):
|
||||
"""Event model has adapter field."""
|
||||
event = Event(
|
||||
id="test-1",
|
||||
adapter="nws",
|
||||
category="wx.alert.test",
|
||||
time=datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc),
|
||||
geo=Geo(),
|
||||
data={},
|
||||
)
|
||||
assert event.adapter == "nws"
|
||||
|
||||
def test_event_model_no_source_field(self):
|
||||
"""Event model does not have source field."""
|
||||
assert "source" not in Event.model_fields
|
||||
assert "adapter" in Event.model_fields
|
||||
|
||||
def test_event_adapter_accepts_all_adapter_names(self):
|
||||
"""Event adapter field accepts all known adapter names."""
|
||||
for adapter_name in ["nws", "firms", "usgs_quake"]:
|
||||
event = Event(
|
||||
id=f"test-{adapter_name}",
|
||||
adapter=adapter_name,
|
||||
category="test.category",
|
||||
time=datetime.now(timezone.utc),
|
||||
geo=Geo(),
|
||||
data={},
|
||||
)
|
||||
assert event.adapter == adapter_name
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def db_conn() -> asyncpg.Connection:
|
||||
"""Get a database connection for migration tests."""
|
||||
conn = await asyncpg.connect(TEST_DB_DSN)
|
||||
yield conn
|
||||
await conn.close()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def pre_migration_events_table(db_conn: asyncpg.Connection) -> None:
|
||||
"""Create events table with pre-migration schema (source column, no adapter).
|
||||
|
||||
Also ensures config.adapters exists with test adapters.
|
||||
Self-bootstraps PostGIS if not already installed (central_test is superuser).
|
||||
"""
|
||||
# Self-bootstrap PostGIS extension (central_test role is superuser)
|
||||
await db_conn.execute("CREATE EXTENSION IF NOT EXISTS postgis")
|
||||
|
||||
# Ensure config schema and adapters table exist
|
||||
await db_conn.execute("CREATE SCHEMA IF NOT EXISTS config")
|
||||
await db_conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS config.adapters (
|
||||
name TEXT PRIMARY KEY,
|
||||
enabled BOOLEAN NOT NULL DEFAULT true,
|
||||
cadence_s INTEGER NOT NULL,
|
||||
settings JSONB NOT NULL DEFAULT '{}'::jsonb,
|
||||
paused_at TIMESTAMPTZ,
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
|
||||
)
|
||||
""")
|
||||
|
||||
# Insert test adapters (idempotent)
|
||||
await db_conn.execute("""
|
||||
INSERT INTO config.adapters (name, cadence_s)
|
||||
VALUES ('nws', 60), ('firms', 300), ('usgs_quake', 60)
|
||||
ON CONFLICT (name) DO NOTHING
|
||||
""")
|
||||
|
||||
# Drop events table if exists (clean slate)
|
||||
await db_conn.execute("DROP TABLE IF EXISTS public.events CASCADE")
|
||||
|
||||
# Create events table with PRE-MIGRATION schema (has source, no adapter)
|
||||
# Matches production schema including geom column
|
||||
await db_conn.execute("""
|
||||
CREATE TABLE public.events (
|
||||
id TEXT NOT NULL,
|
||||
source TEXT NOT NULL,
|
||||
category TEXT NOT NULL,
|
||||
time TIMESTAMPTZ NOT NULL,
|
||||
expires TIMESTAMPTZ,
|
||||
severity SMALLINT,
|
||||
geom GEOMETRY(Geometry, 4326),
|
||||
regions TEXT[],
|
||||
primary_region TEXT,
|
||||
payload JSONB NOT NULL,
|
||||
received TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||
PRIMARY KEY (id, time)
|
||||
)
|
||||
""")
|
||||
|
||||
# Insert test rows with different source values
|
||||
# geom is NULL (production schema permits this)
|
||||
test_rows = [
|
||||
("event-nws-1", "central/adapters/nws", "wx.alert.tornado_warning"),
|
||||
("event-nws-2", "central/adapters/nws", "wx.alert.flood_warning"),
|
||||
("event-firms-1", "central/adapters/firms", "fire.hotspot"),
|
||||
("event-usgs-1", "central/adapters/usgs_quake", "seismic.earthquake"),
|
||||
]
|
||||
|
||||
for event_id, source, category in test_rows:
|
||||
await db_conn.execute("""
|
||||
INSERT INTO public.events (id, source, category, time, payload)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
""", event_id, source, category, datetime.now(timezone.utc), "{}")
|
||||
|
||||
yield
|
||||
|
||||
# Cleanup
|
||||
await db_conn.execute("DROP TABLE IF EXISTS public.events CASCADE")
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def run_migration_011(
|
||||
db_conn: asyncpg.Connection,
|
||||
pre_migration_events_table: None,
|
||||
) -> None:
|
||||
"""Run migration 011 against the pre-migration events table."""
|
||||
migration_sql = MIGRATION_011_PATH.read_text()
|
||||
await db_conn.execute(migration_sql)
|
||||
yield
|
||||
|
||||
|
||||
class TestMigration011Backfill:
|
||||
"""Test migration 011 backfill logic."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_backfill_nws_source_to_adapter(
|
||||
self, db_conn: asyncpg.Connection, run_migration_011: None
|
||||
):
|
||||
"""Backfill converts 'central/adapters/nws' to 'nws'."""
|
||||
rows = await db_conn.fetch(
|
||||
"SELECT id, adapter FROM public.events WHERE id LIKE 'event-nws-%'"
|
||||
)
|
||||
assert len(rows) == 2
|
||||
for row in rows:
|
||||
assert row["adapter"] == "nws"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_backfill_firms_source_to_adapter(
|
||||
self, db_conn: asyncpg.Connection, run_migration_011: None
|
||||
):
|
||||
"""Backfill converts 'central/adapters/firms' to 'firms'."""
|
||||
row = await db_conn.fetchrow(
|
||||
"SELECT adapter FROM public.events WHERE id = 'event-firms-1'"
|
||||
)
|
||||
assert row["adapter"] == "firms"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_backfill_usgs_quake_source_to_adapter(
|
||||
self, db_conn: asyncpg.Connection, run_migration_011: None
|
||||
):
|
||||
"""Backfill converts 'central/adapters/usgs_quake' to 'usgs_quake'."""
|
||||
row = await db_conn.fetchrow(
|
||||
"SELECT adapter FROM public.events WHERE id = 'event-usgs-1'"
|
||||
)
|
||||
assert row["adapter"] == "usgs_quake"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_backfill_all_rows_have_adapter(
|
||||
self, db_conn: asyncpg.Connection, run_migration_011: None
|
||||
):
|
||||
"""All rows have non-NULL adapter after backfill."""
|
||||
count = await db_conn.fetchval(
|
||||
"SELECT COUNT(*) FROM public.events WHERE adapter IS NULL"
|
||||
)
|
||||
assert count == 0
|
||||
|
||||
|
||||
class TestMigration011Constraints:
|
||||
"""Test migration 011 constraint enforcement."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fk_restrict_prevents_adapter_deletion(
|
||||
self, db_conn: asyncpg.Connection, run_migration_011: None
|
||||
):
|
||||
"""FK RESTRICT prevents deleting adapter with existing events."""
|
||||
with pytest.raises(asyncpg.ForeignKeyViolationError):
|
||||
await db_conn.execute(
|
||||
"DELETE FROM config.adapters WHERE name = 'nws'"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_not_null_rejects_null_adapter(
|
||||
self, db_conn: asyncpg.Connection, run_migration_011: None
|
||||
):
|
||||
"""NOT NULL constraint rejects NULL adapter on insert."""
|
||||
with pytest.raises(asyncpg.NotNullViolationError):
|
||||
await db_conn.execute("""
|
||||
INSERT INTO public.events (id, adapter, category, time, payload)
|
||||
VALUES ('null-test', NULL, 'test', now(), '{}')
|
||||
""")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fk_rejects_unknown_adapter(
|
||||
self, db_conn: asyncpg.Connection, run_migration_011: None
|
||||
):
|
||||
"""FK constraint rejects unknown adapter values."""
|
||||
with pytest.raises(asyncpg.ForeignKeyViolationError):
|
||||
await db_conn.execute("""
|
||||
INSERT INTO public.events (id, adapter, category, time, payload)
|
||||
VALUES ('bad-adapter', 'nonexistent_adapter', 'test', now(), '{}')
|
||||
""")
|
||||
|
||||
|
||||
class TestMigration011SchemaChanges:
|
||||
"""Test migration 011 schema changes."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_source_column_dropped(
|
||||
self, db_conn: asyncpg.Connection, run_migration_011: None
|
||||
):
|
||||
"""Source column no longer exists after migration."""
|
||||
row = await db_conn.fetchrow("""
|
||||
SELECT COUNT(*) as count
|
||||
FROM information_schema.columns
|
||||
WHERE table_schema = 'public'
|
||||
AND table_name = 'events'
|
||||
AND column_name = 'source'
|
||||
""")
|
||||
assert row["count"] == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_adapter_column_exists(
|
||||
self, db_conn: asyncpg.Connection, run_migration_011: None
|
||||
):
|
||||
"""Adapter column exists after migration."""
|
||||
row = await db_conn.fetchrow("""
|
||||
SELECT COUNT(*) as count
|
||||
FROM information_schema.columns
|
||||
WHERE table_schema = 'public'
|
||||
AND table_name = 'events'
|
||||
AND column_name = 'adapter'
|
||||
""")
|
||||
assert row["count"] == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_adapter_column_is_not_null(
|
||||
self, db_conn: asyncpg.Connection, run_migration_011: None
|
||||
):
|
||||
"""Adapter column has NOT NULL constraint."""
|
||||
row = await db_conn.fetchrow("""
|
||||
SELECT is_nullable
|
||||
FROM information_schema.columns
|
||||
WHERE table_schema = 'public'
|
||||
AND table_name = 'events'
|
||||
AND column_name = 'adapter'
|
||||
""")
|
||||
assert row["is_nullable"] == "NO"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_adapter_fk_constraint_exists(
|
||||
self, db_conn: asyncpg.Connection, run_migration_011: None
|
||||
):
|
||||
"""FK constraint events_adapter_fkey exists."""
|
||||
row = await db_conn.fetchrow("""
|
||||
SELECT COUNT(*) as count
|
||||
FROM information_schema.table_constraints
|
||||
WHERE table_schema = 'public'
|
||||
AND table_name = 'events'
|
||||
AND constraint_name = 'events_adapter_fkey'
|
||||
AND constraint_type = 'FOREIGN KEY'
|
||||
""")
|
||||
assert row["count"] == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_adapter_received_index_exists(
|
||||
self, db_conn: asyncpg.Connection, run_migration_011: None
|
||||
):
|
||||
"""Index events_adapter_received_idx exists."""
|
||||
row = await db_conn.fetchrow("""
|
||||
SELECT COUNT(*) as count
|
||||
FROM pg_indexes
|
||||
WHERE schemaname = 'public'
|
||||
AND tablename = 'events'
|
||||
AND indexname = 'events_adapter_received_idx'
|
||||
""")
|
||||
assert row["count"] == 1
|
||||
|
|
@ -288,7 +288,7 @@ class TestSubjectGeneration:
|
|||
def test_subject_format(self):
|
||||
event = Event(
|
||||
id="test",
|
||||
source="central/adapters/firms",
|
||||
adapter="firms",
|
||||
category="fire.hotspot.viirs_snpp.high",
|
||||
time=datetime.now(timezone.utc),
|
||||
severity=3,
|
||||
|
|
@ -302,7 +302,7 @@ class TestSubjectGeneration:
|
|||
def test_subject_nominal_confidence(self):
|
||||
event = Event(
|
||||
id="test",
|
||||
source="central/adapters/firms",
|
||||
adapter="firms",
|
||||
category="fire.hotspot.viirs_noaa20.nominal",
|
||||
time=datetime.now(timezone.utc),
|
||||
severity=2,
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ def sample_event(sample_geo: Geo) -> Event:
|
|||
"""Sample Event object for testing."""
|
||||
return Event(
|
||||
id="urn:central:nws:alert:KBOI-202401151200-SVR",
|
||||
source="central/adapters/nws",
|
||||
adapter="nws",
|
||||
category="wx.alert.severe_thunderstorm_warning",
|
||||
time=datetime(2024, 1, 15, 12, 0, 0, tzinfo=timezone.utc),
|
||||
expires=datetime(2024, 1, 15, 13, 0, 0, tzinfo=timezone.utc),
|
||||
|
|
@ -75,7 +75,7 @@ class TestSubjectForEvent:
|
|||
)
|
||||
event = Event(
|
||||
id="test-zone",
|
||||
source="test",
|
||||
adapter="nws",
|
||||
category="wx.alert.winter_storm_warning",
|
||||
time=datetime(2024, 1, 15, 12, 0, 0, tzinfo=timezone.utc),
|
||||
geo=geo,
|
||||
|
|
@ -89,7 +89,7 @@ class TestSubjectForEvent:
|
|||
geo = Geo(regions=[], primary_region=None)
|
||||
event = Event(
|
||||
id="test-unknown",
|
||||
source="test",
|
||||
adapter="nws",
|
||||
category="wx.alert.test",
|
||||
time=datetime(2024, 1, 15, 12, 0, 0, tzinfo=timezone.utc),
|
||||
geo=geo,
|
||||
|
|
@ -144,7 +144,7 @@ class TestCloudEventsWire:
|
|||
"""When severity is None, centralseverity is omitted entirely."""
|
||||
event = Event(
|
||||
id="test-no-severity",
|
||||
source="test",
|
||||
adapter="nws",
|
||||
category="wx.alert.test",
|
||||
time=datetime(2024, 1, 15, 12, 0, 0, tzinfo=timezone.utc),
|
||||
severity=None, # Explicitly None
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue