mirror of
https://github.com/zvx-echo6/meshai.git
synced 2026-05-21 23:24:44 +02:00
Remove dead modules: safety, rate_limiter, personality, webhook, web_status, announcements, log_setup
These modules were wired up but never actually functional in the running bot. Strips all imports and usage from main.py and router.py. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
9a628724ce
commit
10bc94b273
9 changed files with 22 additions and 1234 deletions
|
|
@ -1,109 +0,0 @@
|
|||
"""Periodic announcements/broadcasts for MeshAI."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import random
|
||||
from typing import Awaitable, Callable, Optional
|
||||
|
||||
from .config import AnnouncementsConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AnnouncementScheduler:
|
||||
"""Scheduler for periodic announcements."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: AnnouncementsConfig,
|
||||
send_callback: Callable[[str, int], Awaitable[None]],
|
||||
):
|
||||
"""Initialize the announcement scheduler.
|
||||
|
||||
Args:
|
||||
config: Announcements configuration
|
||||
send_callback: Async callback to send messages: (text, channel) -> None
|
||||
"""
|
||||
self.config = config
|
||||
self._send_callback = send_callback
|
||||
self._task: Optional[asyncio.Task] = None
|
||||
self._message_index = 0
|
||||
self._running = False
|
||||
|
||||
async def start(self):
|
||||
"""Start the announcement scheduler."""
|
||||
if not self.config.enabled or not self.config.messages:
|
||||
logger.debug("Announcements disabled or no messages configured")
|
||||
return
|
||||
|
||||
self._running = True
|
||||
self._task = asyncio.create_task(self._run_loop())
|
||||
logger.info(
|
||||
f"Announcement scheduler started (every {self.config.interval_hours}h)"
|
||||
)
|
||||
|
||||
async def stop(self):
|
||||
"""Stop the announcement scheduler."""
|
||||
self._running = False
|
||||
if self._task:
|
||||
self._task.cancel()
|
||||
try:
|
||||
await self._task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._task = None
|
||||
logger.info("Announcement scheduler stopped")
|
||||
|
||||
async def _run_loop(self):
|
||||
"""Main loop for sending periodic announcements."""
|
||||
# Wait a bit before first announcement
|
||||
await asyncio.sleep(60) # 1 minute initial delay
|
||||
|
||||
while self._running:
|
||||
try:
|
||||
# Get next message
|
||||
message = self._get_next_message()
|
||||
if message:
|
||||
logger.info(f"Sending announcement to channel {self.config.channel}")
|
||||
await self._send_callback(message, self.config.channel)
|
||||
|
||||
# Wait for next interval
|
||||
await asyncio.sleep(self.config.interval_hours * 3600)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Error in announcement loop: {e}")
|
||||
await asyncio.sleep(300) # Wait 5 min on error
|
||||
|
||||
def _get_next_message(self) -> Optional[str]:
|
||||
"""Get the next announcement message."""
|
||||
if not self.config.messages:
|
||||
return None
|
||||
|
||||
if self.config.random_order:
|
||||
return random.choice(self.config.messages)
|
||||
else:
|
||||
message = self.config.messages[self._message_index]
|
||||
self._message_index = (self._message_index + 1) % len(self.config.messages)
|
||||
return message
|
||||
|
||||
async def send_now(self, message: Optional[str] = None) -> bool:
|
||||
"""Send an announcement immediately.
|
||||
|
||||
Args:
|
||||
message: Optional specific message, or use next in rotation
|
||||
|
||||
Returns:
|
||||
True if sent successfully
|
||||
"""
|
||||
msg = message or self._get_next_message()
|
||||
if not msg:
|
||||
return False
|
||||
|
||||
try:
|
||||
await self._send_callback(msg, self.config.channel)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send announcement: {e}")
|
||||
return False
|
||||
|
|
@ -1,122 +0,0 @@
|
|||
"""Enhanced logging setup for MeshAI."""
|
||||
|
||||
import logging
|
||||
import sys
|
||||
from logging.handlers import RotatingFileHandler
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from .config import LoggingConfig
|
||||
|
||||
# Custom log levels for message tracking
|
||||
MESSAGE_IN = 25 # Between INFO (20) and WARNING (30)
|
||||
MESSAGE_OUT = 26
|
||||
API_CALL = 15 # Between DEBUG (10) and INFO (20)
|
||||
|
||||
logging.addLevelName(MESSAGE_IN, "MSG_IN")
|
||||
logging.addLevelName(MESSAGE_OUT, "MSG_OUT")
|
||||
logging.addLevelName(API_CALL, "API")
|
||||
|
||||
|
||||
class MeshAILogger(logging.Logger):
|
||||
"""Custom logger with message tracking methods."""
|
||||
|
||||
def message_in(self, sender: str, text: str, channel: int = 0):
|
||||
"""Log an incoming message."""
|
||||
if self.isEnabledFor(MESSAGE_IN):
|
||||
self._log(MESSAGE_IN, f"[CH{channel}] {sender}: {text}", ())
|
||||
|
||||
def message_out(self, recipient: str, text: str, channel: int = 0):
|
||||
"""Log an outgoing message."""
|
||||
if self.isEnabledFor(MESSAGE_OUT):
|
||||
self._log(MESSAGE_OUT, f"[CH{channel}] -> {recipient}: {text}", ())
|
||||
|
||||
def api_call(self, backend: str, model: str, tokens: Optional[int] = None):
|
||||
"""Log an API call."""
|
||||
if self.isEnabledFor(API_CALL):
|
||||
msg = f"API call to {backend}/{model}"
|
||||
if tokens:
|
||||
msg += f" ({tokens} tokens)"
|
||||
self._log(API_CALL, msg, ())
|
||||
|
||||
|
||||
# Set the custom logger class
|
||||
logging.setLoggerClass(MeshAILogger)
|
||||
|
||||
|
||||
def setup_logging(config: LoggingConfig, verbose: bool = False) -> logging.Logger:
|
||||
"""Configure logging based on config.
|
||||
|
||||
Args:
|
||||
config: Logging configuration
|
||||
verbose: Override to enable DEBUG level
|
||||
|
||||
Returns:
|
||||
The configured root logger
|
||||
"""
|
||||
# Determine log level
|
||||
if verbose:
|
||||
level = logging.DEBUG
|
||||
else:
|
||||
level_name = config.level.upper()
|
||||
level = getattr(logging, level_name, logging.INFO)
|
||||
|
||||
# Create formatter
|
||||
formatter = logging.Formatter(
|
||||
fmt="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
|
||||
# Get root logger
|
||||
root_logger = logging.getLogger()
|
||||
root_logger.setLevel(level)
|
||||
|
||||
# Clear existing handlers
|
||||
root_logger.handlers.clear()
|
||||
|
||||
# Console handler (always)
|
||||
console_handler = logging.StreamHandler(sys.stdout)
|
||||
console_handler.setLevel(level)
|
||||
console_handler.setFormatter(formatter)
|
||||
root_logger.addHandler(console_handler)
|
||||
|
||||
# File handler (if configured)
|
||||
if config.file:
|
||||
log_path = Path(config.file)
|
||||
log_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
file_handler = RotatingFileHandler(
|
||||
log_path,
|
||||
maxBytes=config.max_size_mb * 1024 * 1024,
|
||||
backupCount=config.backup_count,
|
||||
)
|
||||
file_handler.setLevel(level)
|
||||
file_handler.setFormatter(formatter)
|
||||
root_logger.addHandler(file_handler)
|
||||
|
||||
# Configure message logging levels based on config
|
||||
meshai_logger = logging.getLogger("meshai")
|
||||
|
||||
if not config.log_messages:
|
||||
# Disable message logging
|
||||
meshai_logger.addFilter(lambda r: r.levelno not in (MESSAGE_IN, MESSAGE_OUT))
|
||||
|
||||
if not config.log_api_calls:
|
||||
# Disable API call logging (it's DEBUG level anyway)
|
||||
meshai_logger.addFilter(lambda r: r.levelno != API_CALL)
|
||||
|
||||
return root_logger
|
||||
|
||||
|
||||
def get_logger(name: str = "meshai") -> MeshAILogger:
|
||||
"""Get a MeshAI logger instance.
|
||||
|
||||
Args:
|
||||
name: Logger name (will be prefixed with 'meshai.')
|
||||
|
||||
Returns:
|
||||
Configured logger instance
|
||||
"""
|
||||
if not name.startswith("meshai"):
|
||||
name = f"meshai.{name}"
|
||||
return logging.getLogger(name)
|
||||
144
meshai/main.py
144
meshai/main.py
|
|
@ -11,8 +11,7 @@ from pathlib import Path
|
|||
from typing import Optional
|
||||
|
||||
from . import __version__
|
||||
from .announcements import AnnouncementScheduler
|
||||
from .backends import AnthropicBackend, FallbackBackend, GoogleBackend, LLMBackend, OpenAIBackend
|
||||
from .backends import AnthropicBackend, GoogleBackend, LLMBackend, OpenAIBackend
|
||||
from .cli import run_configurator
|
||||
from .commands import CommandDispatcher
|
||||
from .commands.dispatcher import create_dispatcher
|
||||
|
|
@ -21,13 +20,8 @@ from .config import Config, load_config
|
|||
from .connector import MeshConnector, MeshMessage
|
||||
from .history import ConversationHistory
|
||||
from .memory import ConversationSummary
|
||||
from .personality import PersonalityManager
|
||||
from .rate_limiter import RateLimiter
|
||||
from .responder import Responder
|
||||
from .router import MessageRouter, RouteType
|
||||
from .safety import SafetyFilter, UserFilter
|
||||
from .web_status import WebStatusServer, get_status_data
|
||||
from .webhook import WebhookClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -43,13 +37,6 @@ class MeshAI:
|
|||
self.llm: Optional[LLMBackend] = None
|
||||
self.router: Optional[MessageRouter] = None
|
||||
self.responder: Optional[Responder] = None
|
||||
self.personality: Optional[PersonalityManager] = None
|
||||
self.safety_filter: Optional[SafetyFilter] = None
|
||||
self.user_filter: Optional[UserFilter] = None
|
||||
self.rate_limiter: Optional[RateLimiter] = None
|
||||
self.webhook: Optional[WebhookClient] = None
|
||||
self.web_status: Optional[WebStatusServer] = None
|
||||
self.announcements: Optional[AnnouncementScheduler] = None
|
||||
self._running = False
|
||||
self._loop: Optional[asyncio.AbstractEventLoop] = None
|
||||
self._last_cleanup: float = 0.0
|
||||
|
|
@ -70,14 +57,6 @@ class MeshAI:
|
|||
self._loop = asyncio.get_event_loop()
|
||||
self._last_cleanup = time.time()
|
||||
|
||||
# Start async services
|
||||
await self.webhook.start()
|
||||
await self.webhook.on_startup()
|
||||
await self.announcements.start()
|
||||
|
||||
# Start sync services
|
||||
self.web_status.start()
|
||||
|
||||
# Write PID file
|
||||
self._write_pid()
|
||||
|
||||
|
|
@ -97,16 +76,6 @@ class MeshAI:
|
|||
logger.info("Stopping MeshAI...")
|
||||
self._running = False
|
||||
|
||||
if self.webhook:
|
||||
await self.webhook.on_shutdown()
|
||||
await self.webhook.stop()
|
||||
|
||||
if self.announcements:
|
||||
await self.announcements.stop()
|
||||
|
||||
if self.web_status:
|
||||
self.web_status.stop()
|
||||
|
||||
if self.connector:
|
||||
self.connector.disconnect()
|
||||
|
||||
|
|
@ -125,30 +94,8 @@ class MeshAI:
|
|||
self.history = ConversationHistory(self.config.history)
|
||||
await self.history.initialize()
|
||||
|
||||
# Command dispatcher (2h: pass config)
|
||||
self.dispatcher = create_dispatcher(
|
||||
prefix=self.config.commands.prefix,
|
||||
disabled_commands=self.config.commands.disabled_commands,
|
||||
custom_commands=self.config.commands.custom_commands,
|
||||
)
|
||||
|
||||
# Safety and user filters (2a)
|
||||
self.user_filter = UserFilter(
|
||||
blocklist=self.config.users.blocklist,
|
||||
allowlist=self.config.users.allowlist,
|
||||
allowlist_only=self.config.users.allowlist_only,
|
||||
admin_nodes=self.config.users.admin_nodes,
|
||||
)
|
||||
self.safety_filter = SafetyFilter(self.config.safety)
|
||||
|
||||
# Rate limiter (2b)
|
||||
self.rate_limiter = RateLimiter(
|
||||
self.config.rate_limits,
|
||||
vip_nodes=self.config.users.vip_nodes,
|
||||
)
|
||||
|
||||
# Personality manager (2c)
|
||||
self.personality = PersonalityManager(self.config.personality)
|
||||
# Command dispatcher
|
||||
self.dispatcher = create_dispatcher()
|
||||
|
||||
# LLM backend
|
||||
api_key = self.config.resolve_api_key()
|
||||
|
|
@ -160,100 +107,52 @@ class MeshAI:
|
|||
window_size = mem_cfg.window_size if mem_cfg.enabled else 0
|
||||
summarize_threshold = mem_cfg.summarize_threshold
|
||||
|
||||
# Create primary backend
|
||||
# Create backend
|
||||
backend = self.config.llm.backend.lower()
|
||||
if backend == "openai":
|
||||
primary = OpenAIBackend(
|
||||
self.llm = OpenAIBackend(
|
||||
self.config.llm, api_key, window_size, summarize_threshold
|
||||
)
|
||||
elif backend == "anthropic":
|
||||
primary = AnthropicBackend(
|
||||
self.llm = AnthropicBackend(
|
||||
self.config.llm, api_key, window_size, summarize_threshold
|
||||
)
|
||||
elif backend == "google":
|
||||
primary = GoogleBackend(
|
||||
self.llm = GoogleBackend(
|
||||
self.config.llm, api_key, window_size, summarize_threshold
|
||||
)
|
||||
else:
|
||||
logger.warning(f"Unknown backend '{backend}', defaulting to OpenAI")
|
||||
primary = OpenAIBackend(
|
||||
self.llm = OpenAIBackend(
|
||||
self.config.llm, api_key, window_size, summarize_threshold
|
||||
)
|
||||
|
||||
# Wrap in FallbackBackend if fallback is configured (2g)
|
||||
if self.config.llm.fallback:
|
||||
self.llm = FallbackBackend(
|
||||
self.config.llm, api_key, window_size, summarize_threshold
|
||||
)
|
||||
else:
|
||||
self.llm = primary
|
||||
|
||||
# Load persisted summaries into memory cache
|
||||
await self._load_summaries()
|
||||
|
||||
# Meshtastic connector
|
||||
self.connector = MeshConnector(self.config.connection)
|
||||
|
||||
# Message router (pass personality manager)
|
||||
# Message router
|
||||
self.router = MessageRouter(
|
||||
self.config, self.connector, self.history, self.dispatcher, self.llm,
|
||||
personality=self.personality,
|
||||
)
|
||||
|
||||
# Responder
|
||||
self.responder = Responder(self.config.response, self.connector)
|
||||
|
||||
# Webhook client (2d)
|
||||
self.webhook = WebhookClient(self.config.integrations.webhook)
|
||||
|
||||
# Web status server (2e)
|
||||
self.web_status = WebStatusServer(self.config.web_status)
|
||||
|
||||
# Announcement scheduler (2f)
|
||||
async def _send_announcement(text: str, channel: int) -> None:
|
||||
self.connector.send_message(text=text, channel=channel)
|
||||
|
||||
self.announcements = AnnouncementScheduler(
|
||||
self.config.announcements,
|
||||
send_callback=_send_announcement,
|
||||
)
|
||||
|
||||
async def _on_message(self, message: MeshMessage) -> None:
|
||||
"""Handle incoming message."""
|
||||
try:
|
||||
# Check user filter (2a)
|
||||
allowed, reason = self.user_filter.is_allowed(message.sender_id)
|
||||
if not allowed:
|
||||
logger.debug(f"Blocked message from {message.sender_id}: {reason}")
|
||||
return
|
||||
|
||||
# Check if we should respond
|
||||
if not self.router.should_respond(message):
|
||||
return
|
||||
|
||||
# Check rate limiter (2b)
|
||||
allowed, reason = self.rate_limiter.is_allowed(message.sender_id)
|
||||
if not allowed:
|
||||
logger.debug(f"Rate limited {message.sender_id}: {reason}")
|
||||
return
|
||||
|
||||
logger.info(
|
||||
f"Processing message from {message.sender_name} ({message.sender_id}): "
|
||||
f"{message.text[:50]}..."
|
||||
)
|
||||
|
||||
# Record in web status (2e)
|
||||
get_status_data().record_message(message.sender_id, message.sender_name)
|
||||
|
||||
# Send webhook event (2d)
|
||||
await self.webhook.on_message_received(
|
||||
sender_id=message.sender_id,
|
||||
sender_name=message.sender_name,
|
||||
text=message.text,
|
||||
channel=message.channel,
|
||||
is_dm=message.is_dm,
|
||||
)
|
||||
|
||||
# Route the message
|
||||
result = await self.router.route(message)
|
||||
|
||||
|
|
@ -271,10 +170,6 @@ class MeshAI:
|
|||
if not response:
|
||||
return
|
||||
|
||||
# Apply safety filter to LLM responses (2a)
|
||||
if result.route_type == RouteType.LLM:
|
||||
response = self.safety_filter.filter_response(response)
|
||||
|
||||
# Send response
|
||||
if message.is_dm:
|
||||
await self.responder.send_response(
|
||||
|
|
@ -292,21 +187,8 @@ class MeshAI:
|
|||
channel=message.channel,
|
||||
)
|
||||
|
||||
# Record response in rate limiter and status (2b, 2e)
|
||||
self.rate_limiter.record_message(message.sender_id)
|
||||
get_status_data().record_response()
|
||||
|
||||
# Send webhook event (2d)
|
||||
await self.webhook.on_response_sent(
|
||||
recipient_id=message.sender_id if message.is_dm else None,
|
||||
text=response,
|
||||
channel=message.channel,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling message: {e}", exc_info=True)
|
||||
get_status_data().record_error(str(e))
|
||||
await self.webhook.on_error(str(e))
|
||||
|
||||
async def _load_summaries(self) -> None:
|
||||
"""Load persisted summaries from database into memory cache."""
|
||||
|
|
@ -417,14 +299,6 @@ def main() -> None:
|
|||
signal.signal(signal.SIGINT, signal_handler)
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
|
||||
# Handle SIGHUP for config reload
|
||||
def reload_handler(sig, frame):
|
||||
logger.info("Received SIGHUP - reloading config")
|
||||
# For now, just log - full reload would require more work
|
||||
# Could reload config and reinitialize components
|
||||
|
||||
signal.signal(signal.SIGHUP, reload_handler)
|
||||
|
||||
try:
|
||||
loop.run_until_complete(bot.start())
|
||||
except KeyboardInterrupt:
|
||||
|
|
|
|||
|
|
@ -1,119 +0,0 @@
|
|||
"""Personality and prompt template handling for MeshAI."""
|
||||
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from .config import PersonalityConfig
|
||||
|
||||
|
||||
class PersonalityManager:
|
||||
"""Manages personality switching and prompt templating."""
|
||||
|
||||
def __init__(self, config: PersonalityConfig):
|
||||
self.config = config
|
||||
self._current_persona: Optional[str] = None
|
||||
self._persona_prompts: dict[str, str] = {}
|
||||
|
||||
# Parse personas from config
|
||||
for name, persona_data in config.personas.items():
|
||||
if isinstance(persona_data, dict):
|
||||
self._persona_prompts[name] = persona_data.get("prompt", "")
|
||||
else:
|
||||
self._persona_prompts[name] = str(persona_data)
|
||||
|
||||
def get_system_prompt(
|
||||
self,
|
||||
sender_name: str = "",
|
||||
channel: int = 0,
|
||||
extra_context: Optional[dict] = None,
|
||||
) -> str:
|
||||
"""Get the current system prompt with context injection.
|
||||
|
||||
Args:
|
||||
sender_name: Name of the message sender
|
||||
channel: Channel number
|
||||
extra_context: Additional context variables
|
||||
|
||||
Returns:
|
||||
Formatted system prompt
|
||||
"""
|
||||
# Start with base prompt or persona prompt
|
||||
if self._current_persona and self._current_persona in self._persona_prompts:
|
||||
base_prompt = self._persona_prompts[self._current_persona]
|
||||
else:
|
||||
base_prompt = self.config.system_prompt
|
||||
|
||||
# Apply context injection if configured
|
||||
if self.config.context_injection:
|
||||
context_vars = {
|
||||
"time": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
||||
"sender_name": sender_name,
|
||||
"channel": str(channel),
|
||||
"persona": self._current_persona or "default",
|
||||
}
|
||||
if extra_context:
|
||||
context_vars.update(extra_context)
|
||||
|
||||
try:
|
||||
context = self.config.context_injection.format(**context_vars)
|
||||
base_prompt = f"{base_prompt}\n\n{context}"
|
||||
except KeyError as e:
|
||||
# Ignore missing context variables
|
||||
pass
|
||||
|
||||
return base_prompt
|
||||
|
||||
def check_persona_trigger(self, text: str) -> Optional[str]:
|
||||
"""Check if text contains a persona switch trigger.
|
||||
|
||||
Args:
|
||||
text: Message text to check
|
||||
|
||||
Returns:
|
||||
Persona name if triggered, None otherwise
|
||||
"""
|
||||
text_lower = text.lower().strip()
|
||||
|
||||
for name, persona_data in self.config.personas.items():
|
||||
trigger = None
|
||||
if isinstance(persona_data, dict):
|
||||
trigger = persona_data.get("trigger", f"!{name}")
|
||||
else:
|
||||
trigger = f"!{name}"
|
||||
|
||||
if trigger and text_lower.startswith(trigger.lower()):
|
||||
return name
|
||||
|
||||
return None
|
||||
|
||||
def switch_persona(self, persona_name: Optional[str]) -> bool:
|
||||
"""Switch to a different persona.
|
||||
|
||||
Args:
|
||||
persona_name: Name of persona to switch to, or None for default
|
||||
|
||||
Returns:
|
||||
True if switch was successful
|
||||
"""
|
||||
if persona_name is None:
|
||||
self._current_persona = None
|
||||
return True
|
||||
|
||||
if persona_name in self._persona_prompts:
|
||||
self._current_persona = persona_name
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def get_current_persona(self) -> Optional[str]:
|
||||
"""Get the name of the current persona."""
|
||||
return self._current_persona
|
||||
|
||||
def list_personas(self) -> list[str]:
|
||||
"""List available persona names."""
|
||||
return list(self._persona_prompts.keys())
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset to default persona."""
|
||||
self._current_persona = None
|
||||
|
|
@ -1,115 +0,0 @@
|
|||
"""Rate limiting for MeshAI."""
|
||||
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
from .config import RateLimitsConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
class UserRateState:
|
||||
"""Rate limit state for a single user."""
|
||||
|
||||
message_times: list[float] = field(default_factory=list)
|
||||
last_response_time: float = 0.0
|
||||
burst_count: int = 0
|
||||
|
||||
|
||||
class RateLimiter:
|
||||
"""Rate limiter for message processing."""
|
||||
|
||||
def __init__(self, config: RateLimitsConfig, vip_nodes: Optional[list[str]] = None):
|
||||
self.config = config
|
||||
self.vip_nodes = set(vip_nodes or [])
|
||||
self._user_states: dict[str, UserRateState] = defaultdict(UserRateState)
|
||||
self._global_times: list[float] = []
|
||||
|
||||
def is_allowed(self, user_id: str) -> tuple[bool, Optional[str]]:
|
||||
"""Check if a message from user is allowed.
|
||||
|
||||
Args:
|
||||
user_id: The user's node ID
|
||||
|
||||
Returns:
|
||||
Tuple of (allowed, reason). If not allowed, reason explains why.
|
||||
"""
|
||||
# VIP users bypass rate limits
|
||||
if user_id in self.vip_nodes:
|
||||
return True, None
|
||||
|
||||
now = time.time()
|
||||
state = self._user_states[user_id]
|
||||
|
||||
# Clean old timestamps (older than 1 minute)
|
||||
cutoff = now - 60.0
|
||||
state.message_times = [t for t in state.message_times if t > cutoff]
|
||||
self._global_times = [t for t in self._global_times if t > cutoff]
|
||||
|
||||
# Check cooldown (minimum time between responses to same user)
|
||||
if state.last_response_time > 0:
|
||||
elapsed = now - state.last_response_time
|
||||
if elapsed < self.config.cooldown_seconds:
|
||||
remaining = self.config.cooldown_seconds - elapsed
|
||||
return False, f"Cooldown: wait {remaining:.1f}s"
|
||||
|
||||
# Check per-user rate limit
|
||||
if len(state.message_times) >= self.config.messages_per_minute:
|
||||
# Check burst allowance
|
||||
if state.burst_count >= self.config.burst_allowance:
|
||||
return False, "Rate limit exceeded (per-user)"
|
||||
state.burst_count += 1
|
||||
else:
|
||||
state.burst_count = 0
|
||||
|
||||
# Check global rate limit
|
||||
if len(self._global_times) >= self.config.global_messages_per_minute:
|
||||
return False, "Rate limit exceeded (global)"
|
||||
|
||||
return True, None
|
||||
|
||||
def record_message(self, user_id: str) -> None:
|
||||
"""Record that a message was processed for a user."""
|
||||
now = time.time()
|
||||
state = self._user_states[user_id]
|
||||
state.message_times.append(now)
|
||||
state.last_response_time = now
|
||||
self._global_times.append(now)
|
||||
|
||||
def get_user_stats(self, user_id: str) -> dict:
|
||||
"""Get rate limit stats for a user."""
|
||||
now = time.time()
|
||||
state = self._user_states[user_id]
|
||||
|
||||
cutoff = now - 60.0
|
||||
recent_count = len([t for t in state.message_times if t > cutoff])
|
||||
|
||||
return {
|
||||
"messages_last_minute": recent_count,
|
||||
"limit": self.config.messages_per_minute,
|
||||
"remaining": max(0, self.config.messages_per_minute - recent_count),
|
||||
"is_vip": user_id in self.vip_nodes,
|
||||
}
|
||||
|
||||
def get_global_stats(self) -> dict:
|
||||
"""Get global rate limit stats."""
|
||||
now = time.time()
|
||||
cutoff = now - 60.0
|
||||
recent_count = len([t for t in self._global_times if t > cutoff])
|
||||
|
||||
return {
|
||||
"messages_last_minute": recent_count,
|
||||
"limit": self.config.global_messages_per_minute,
|
||||
"remaining": max(0, self.config.global_messages_per_minute - recent_count),
|
||||
}
|
||||
|
||||
def reset_user(self, user_id: str) -> None:
|
||||
"""Reset rate limit state for a user."""
|
||||
if user_id in self._user_states:
|
||||
del self._user_states[user_id]
|
||||
|
||||
def reset_all(self) -> None:
|
||||
"""Reset all rate limit state."""
|
||||
self._user_states.clear()
|
||||
self._global_times.clear()
|
||||
|
|
@ -11,7 +11,6 @@ from .commands import CommandContext, CommandDispatcher
|
|||
from .config import Config
|
||||
from .connector import MeshConnector, MeshMessage
|
||||
from .history import ConversationHistory
|
||||
from .personality import PersonalityManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -62,14 +61,12 @@ class MessageRouter:
|
|||
history: ConversationHistory,
|
||||
dispatcher: CommandDispatcher,
|
||||
llm_backend: LLMBackend,
|
||||
personality: Optional[PersonalityManager] = None,
|
||||
):
|
||||
self.config = config
|
||||
self.connector = connector
|
||||
self.history = history
|
||||
self.dispatcher = dispatcher
|
||||
self.llm = llm_backend
|
||||
self.personality = personality
|
||||
|
||||
# Compile mention pattern
|
||||
bot_name = re.escape(config.bot.name)
|
||||
|
|
@ -160,15 +157,9 @@ class MessageRouter:
|
|||
# Get conversation history
|
||||
history = await self.history.get_history_for_llm(message.sender_id)
|
||||
|
||||
# Get system prompt from personality manager or config
|
||||
# Get system prompt from config
|
||||
system_prompt = ""
|
||||
if getattr(self.config.llm, 'use_system_prompt', True):
|
||||
if self.personality:
|
||||
system_prompt = self.personality.get_system_prompt(
|
||||
sender_name=message.sender_name,
|
||||
channel=message.channel,
|
||||
)
|
||||
else:
|
||||
system_prompt = self.config.llm.system_prompt
|
||||
|
||||
try:
|
||||
|
|
@ -217,14 +208,12 @@ class MessageRouter:
|
|||
cleaned = " ".join(cleaned.split())
|
||||
cleaned = cleaned.strip()
|
||||
|
||||
# Check for prompt injection if guard is enabled
|
||||
if self.config.safety.prompt_injection_guard:
|
||||
# Check for prompt injection
|
||||
for pattern in _INJECTION_PATTERNS:
|
||||
if pattern.search(cleaned):
|
||||
logger.warning(
|
||||
f"Possible prompt injection detected: {cleaned[:80]}..."
|
||||
)
|
||||
# Truncate to just the part before the injection pattern
|
||||
match = pattern.search(cleaned)
|
||||
cleaned = cleaned[:match.start()].strip()
|
||||
if not cleaned:
|
||||
|
|
|
|||
142
meshai/safety.py
142
meshai/safety.py
|
|
@ -1,142 +0,0 @@
|
|||
"""Response filtering and safety for MeshAI."""
|
||||
|
||||
import re
|
||||
from typing import Optional
|
||||
|
||||
from .config import SafetyConfig
|
||||
|
||||
# Basic profanity list (expand as needed)
|
||||
PROFANITY_PATTERNS = [
|
||||
r"\bf+u+c+k+\w*\b",
|
||||
r"\bs+h+i+t+\w*\b",
|
||||
r"\ba+s+s+h+o+l+e+\w*\b",
|
||||
r"\bb+i+t+c+h+\w*\b",
|
||||
r"\bc+u+n+t+\w*\b",
|
||||
r"\bd+a+m+n+\w*\b",
|
||||
]
|
||||
|
||||
|
||||
class SafetyFilter:
|
||||
"""Filter for response safety and content moderation."""
|
||||
|
||||
def __init__(self, config: SafetyConfig):
|
||||
self.config = config
|
||||
self._profanity_regex = None
|
||||
if config.filter_profanity:
|
||||
self._profanity_regex = re.compile(
|
||||
"|".join(PROFANITY_PATTERNS), re.IGNORECASE
|
||||
)
|
||||
|
||||
def filter_response(self, text: str) -> str:
|
||||
"""Filter a response for safety.
|
||||
|
||||
Args:
|
||||
text: The response text to filter
|
||||
|
||||
Returns:
|
||||
Filtered text
|
||||
"""
|
||||
# Truncate to max length
|
||||
if len(text) > self.config.max_response_length:
|
||||
text = text[: self.config.max_response_length - 3] + "..."
|
||||
|
||||
# Filter profanity
|
||||
if self._profanity_regex:
|
||||
text = self._profanity_regex.sub("***", text)
|
||||
|
||||
# Filter blocked phrases
|
||||
for phrase in self.config.blocked_phrases:
|
||||
text = text.replace(phrase, "[filtered]")
|
||||
text = text.replace(phrase.lower(), "[filtered]")
|
||||
text = text.replace(phrase.upper(), "[filtered]")
|
||||
|
||||
return text
|
||||
|
||||
def should_respond(
|
||||
self,
|
||||
text: str,
|
||||
sender_id: str,
|
||||
own_id: str,
|
||||
is_mentioned: bool,
|
||||
is_dm: bool,
|
||||
) -> tuple[bool, Optional[str]]:
|
||||
"""Check if we should respond to this message.
|
||||
|
||||
Args:
|
||||
text: Message text
|
||||
sender_id: Sender's node ID
|
||||
own_id: Our own node ID
|
||||
is_mentioned: Whether our name is mentioned
|
||||
is_dm: Whether this is a direct message
|
||||
|
||||
Returns:
|
||||
Tuple of (should_respond, reason). Reason is None if we should respond.
|
||||
"""
|
||||
# Never respond to self
|
||||
if self.config.ignore_self and sender_id == own_id:
|
||||
return False, "Self message"
|
||||
|
||||
# Check for emergency keywords (always respond)
|
||||
text_lower = text.lower()
|
||||
for keyword in self.config.emergency_keywords:
|
||||
if keyword.lower() in text_lower:
|
||||
return True, None
|
||||
|
||||
# Check mention requirement
|
||||
if self.config.require_mention and not is_mentioned and not is_dm:
|
||||
return False, "Not mentioned"
|
||||
|
||||
return True, None
|
||||
|
||||
def contains_emergency(self, text: str) -> bool:
|
||||
"""Check if text contains emergency keywords."""
|
||||
text_lower = text.lower()
|
||||
return any(kw.lower() in text_lower for kw in self.config.emergency_keywords)
|
||||
|
||||
|
||||
class UserFilter:
|
||||
"""Filter for user access control."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
blocklist: list[str],
|
||||
allowlist: list[str],
|
||||
allowlist_only: bool,
|
||||
admin_nodes: list[str],
|
||||
):
|
||||
self.blocklist = set(blocklist)
|
||||
self.allowlist = set(allowlist)
|
||||
self.allowlist_only = allowlist_only
|
||||
self.admin_nodes = set(admin_nodes)
|
||||
|
||||
def is_allowed(self, user_id: str) -> tuple[bool, Optional[str]]:
|
||||
"""Check if user is allowed to interact.
|
||||
|
||||
Args:
|
||||
user_id: The user's node ID
|
||||
|
||||
Returns:
|
||||
Tuple of (allowed, reason)
|
||||
"""
|
||||
# Check blocklist first
|
||||
if user_id in self.blocklist:
|
||||
return False, "User is blocked"
|
||||
|
||||
# If allowlist_only mode, check allowlist
|
||||
if self.allowlist_only:
|
||||
if user_id not in self.allowlist:
|
||||
return False, "User not in allowlist"
|
||||
|
||||
return True, None
|
||||
|
||||
def is_admin(self, user_id: str) -> bool:
|
||||
"""Check if user is an admin."""
|
||||
return user_id in self.admin_nodes
|
||||
|
||||
def add_to_blocklist(self, user_id: str) -> None:
|
||||
"""Add a user to the blocklist."""
|
||||
self.blocklist.add(user_id)
|
||||
|
||||
def remove_from_blocklist(self, user_id: str) -> None:
|
||||
"""Remove a user from the blocklist."""
|
||||
self.blocklist.discard(user_id)
|
||||
|
|
@ -1,292 +0,0 @@
|
|||
"""Simple web status page for MeshAI."""
|
||||
|
||||
import asyncio
|
||||
import html as html_module
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from datetime import datetime
|
||||
from http.server import BaseHTTPRequestHandler, HTTPServer
|
||||
from threading import Thread
|
||||
from typing import Callable, Optional
|
||||
|
||||
from .config import WebStatusConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class StatusData:
|
||||
"""Container for status information."""
|
||||
|
||||
def __init__(self):
|
||||
self._lock = threading.Lock()
|
||||
self.start_time = time.time()
|
||||
self.message_count = 0
|
||||
self.response_count = 0
|
||||
self.error_count = 0
|
||||
self.connected_nodes: set[str] = set()
|
||||
self.recent_activity: list[dict] = []
|
||||
self.last_message_time: Optional[float] = None
|
||||
self.using_fallback = False
|
||||
|
||||
def record_message(self, sender_id: str, sender_name: str):
|
||||
"""Record an incoming message."""
|
||||
with self._lock:
|
||||
self.message_count += 1
|
||||
self.last_message_time = time.time()
|
||||
self.connected_nodes.add(sender_id)
|
||||
|
||||
self.recent_activity.append({
|
||||
"type": "message",
|
||||
"time": datetime.now().isoformat(),
|
||||
"sender": sender_name,
|
||||
})
|
||||
# Keep only last 20 activities
|
||||
self.recent_activity = self.recent_activity[-20:]
|
||||
|
||||
def record_response(self):
|
||||
"""Record an outgoing response."""
|
||||
with self._lock:
|
||||
self.response_count += 1
|
||||
|
||||
def record_error(self, error: str):
|
||||
"""Record an error."""
|
||||
with self._lock:
|
||||
self.error_count += 1
|
||||
self.recent_activity.append({
|
||||
"type": "error",
|
||||
"time": datetime.now().isoformat(),
|
||||
"error": error[:100],
|
||||
})
|
||||
self.recent_activity = self.recent_activity[-20:]
|
||||
|
||||
def get_uptime(self) -> str:
|
||||
"""Get formatted uptime string."""
|
||||
elapsed = int(time.time() - self.start_time)
|
||||
days, remainder = divmod(elapsed, 86400)
|
||||
hours, remainder = divmod(remainder, 3600)
|
||||
minutes, seconds = divmod(remainder, 60)
|
||||
|
||||
parts = []
|
||||
if days:
|
||||
parts.append(f"{days}d")
|
||||
if hours:
|
||||
parts.append(f"{hours}h")
|
||||
if minutes:
|
||||
parts.append(f"{minutes}m")
|
||||
parts.append(f"{seconds}s")
|
||||
|
||||
return " ".join(parts)
|
||||
|
||||
def to_dict(self, include_activity: bool = False) -> dict:
|
||||
"""Convert to dictionary for JSON response."""
|
||||
with self._lock:
|
||||
data = {
|
||||
"status": "online",
|
||||
"uptime": self.get_uptime(),
|
||||
"uptime_seconds": int(time.time() - self.start_time),
|
||||
"messages_received": self.message_count,
|
||||
"responses_sent": self.response_count,
|
||||
"errors": self.error_count,
|
||||
"connected_nodes": len(self.connected_nodes),
|
||||
"using_fallback": self.using_fallback,
|
||||
}
|
||||
|
||||
if self.last_message_time:
|
||||
data["last_message_ago"] = int(time.time() - self.last_message_time)
|
||||
|
||||
if include_activity:
|
||||
data["recent_activity"] = list(self.recent_activity)
|
||||
|
||||
return data
|
||||
|
||||
|
||||
# Global status data instance
|
||||
_status_data = StatusData()
|
||||
|
||||
|
||||
def get_status_data() -> StatusData:
|
||||
"""Get the global status data instance."""
|
||||
return _status_data
|
||||
|
||||
|
||||
class StatusRequestHandler(BaseHTTPRequestHandler):
|
||||
"""HTTP request handler for status page."""
|
||||
|
||||
config: WebStatusConfig = None
|
||||
|
||||
def log_message(self, format, *args):
|
||||
"""Suppress default logging."""
|
||||
pass
|
||||
|
||||
def do_GET(self):
|
||||
"""Handle GET requests."""
|
||||
if self.path == "/" or self.path == "/status":
|
||||
self._serve_status_page()
|
||||
elif self.path == "/api/status":
|
||||
self._serve_json_status()
|
||||
elif self.path == "/health":
|
||||
self._serve_health()
|
||||
else:
|
||||
self.send_error(404)
|
||||
|
||||
def _check_auth(self) -> bool:
|
||||
"""Check authentication if required."""
|
||||
if not self.config or not self.config.require_auth:
|
||||
return True
|
||||
|
||||
auth_header = self.headers.get("Authorization", "")
|
||||
if auth_header.startswith("Basic "):
|
||||
import base64
|
||||
try:
|
||||
decoded = base64.b64decode(auth_header[6:]).decode()
|
||||
_, password = decoded.split(":", 1)
|
||||
return password == self.config.auth_password
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return False
|
||||
|
||||
def _serve_status_page(self):
|
||||
"""Serve HTML status page."""
|
||||
if not self._check_auth():
|
||||
self.send_response(401)
|
||||
self.send_header("WWW-Authenticate", 'Basic realm="MeshAI Status"')
|
||||
self.end_headers()
|
||||
return
|
||||
|
||||
status = _status_data.to_dict(
|
||||
include_activity=self.config.show_recent_activity if self.config else False
|
||||
)
|
||||
|
||||
esc = html_module.escape
|
||||
|
||||
# Build optional stat rows
|
||||
rows = ""
|
||||
if self.config and self.config.show_uptime:
|
||||
rows += (
|
||||
'<div class="stat"><span class="stat-label">Uptime</span>'
|
||||
f'<span class="stat-value">{esc(str(status["uptime"]))}</span></div>'
|
||||
)
|
||||
if self.config and self.config.show_message_count:
|
||||
rows += (
|
||||
'<div class="stat"><span class="stat-label">Messages</span>'
|
||||
f'<span class="stat-value">{esc(str(status["messages_received"]))}</span></div>'
|
||||
'<div class="stat"><span class="stat-label">Responses</span>'
|
||||
f'<span class="stat-value">{esc(str(status["responses_sent"]))}</span></div>'
|
||||
)
|
||||
if self.config and self.config.show_connected_nodes:
|
||||
rows += (
|
||||
'<div class="stat"><span class="stat-label">Connected Nodes</span>'
|
||||
f'<span class="stat-value">{esc(str(status["connected_nodes"]))}</span></div>'
|
||||
)
|
||||
|
||||
status_class = "status-fallback" if status.get("using_fallback") else "status-online"
|
||||
status_text = "ONLINE (Fallback)" if status.get("using_fallback") else "ONLINE"
|
||||
|
||||
html = f"""<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>MeshAI Status</title>
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1">
|
||||
<style>
|
||||
body {{
|
||||
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, monospace;
|
||||
background: #0d1117;
|
||||
color: #c9d1d9;
|
||||
margin: 0;
|
||||
padding: 20px;
|
||||
}}
|
||||
.container {{ max-width: 600px; margin: 0 auto; }}
|
||||
h1 {{ color: #58a6ff; border-bottom: 1px solid #30363d; padding-bottom: 10px; }}
|
||||
.stat {{
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
padding: 10px 0;
|
||||
border-bottom: 1px solid #21262d;
|
||||
}}
|
||||
.stat-label {{ color: #8b949e; }}
|
||||
.stat-value {{ color: #58a6ff; font-weight: bold; }}
|
||||
.status-online {{ color: #3fb950; }}
|
||||
.status-fallback {{ color: #d29922; }}
|
||||
.footer {{ margin-top: 20px; color: #484f58; font-size: 12px; }}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<h1>MeshAI Status</h1>
|
||||
<div class="stat">
|
||||
<span class="stat-label">Status</span>
|
||||
<span class="stat-value {esc(status_class)}">{esc(status_text)}</span>
|
||||
</div>
|
||||
{rows}
|
||||
<div class="stat">
|
||||
<span class="stat-label">Errors</span>
|
||||
<span class="stat-value">{esc(str(status["errors"]))}</span>
|
||||
</div>
|
||||
<div class="footer">Auto-refresh in 30s</div>
|
||||
</div>
|
||||
<script>setTimeout(() => location.reload(), 30000);</script>
|
||||
</body>
|
||||
</html>"""
|
||||
|
||||
self.send_response(200)
|
||||
self.send_header("Content-Type", "text/html")
|
||||
self.end_headers()
|
||||
self.wfile.write(html.encode())
|
||||
|
||||
def _serve_json_status(self):
|
||||
"""Serve JSON status."""
|
||||
if not self._check_auth():
|
||||
self.send_response(401)
|
||||
self.end_headers()
|
||||
return
|
||||
|
||||
status = _status_data.to_dict(
|
||||
include_activity=self.config.show_recent_activity if self.config else False
|
||||
)
|
||||
|
||||
self.send_response(200)
|
||||
self.send_header("Content-Type", "application/json")
|
||||
self.end_headers()
|
||||
self.wfile.write(json.dumps(status, indent=2).encode())
|
||||
|
||||
def _serve_health(self):
|
||||
"""Serve simple health check."""
|
||||
self.send_response(200)
|
||||
self.send_header("Content-Type", "text/plain")
|
||||
self.end_headers()
|
||||
self.wfile.write(b"OK")
|
||||
|
||||
|
||||
class WebStatusServer:
|
||||
"""Web status server manager."""
|
||||
|
||||
def __init__(self, config: WebStatusConfig):
|
||||
self.config = config
|
||||
self._server: Optional[HTTPServer] = None
|
||||
self._thread: Optional[Thread] = None
|
||||
|
||||
def start(self):
|
||||
"""Start the web status server."""
|
||||
if not self.config.enabled:
|
||||
return
|
||||
|
||||
StatusRequestHandler.config = self.config
|
||||
|
||||
try:
|
||||
self._server = HTTPServer(("0.0.0.0", self.config.port), StatusRequestHandler)
|
||||
self._thread = Thread(target=self._server.serve_forever, daemon=True)
|
||||
self._thread.start()
|
||||
logger.info(f"Web status server started on port {self.config.port}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start web status server: {e}")
|
||||
|
||||
def stop(self):
|
||||
"""Stop the web status server."""
|
||||
if self._server:
|
||||
self._server.shutdown()
|
||||
self._server = None
|
||||
self._thread = None
|
||||
logger.info("Web status server stopped")
|
||||
|
|
@ -1,176 +0,0 @@
|
|||
"""Webhook integration for MeshAI."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from .config import WebhookConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WebhookClient:
|
||||
"""Client for sending webhook notifications."""
|
||||
|
||||
def __init__(self, config: WebhookConfig):
|
||||
self.config = config
|
||||
self._client: Optional[httpx.AsyncClient] = None
|
||||
self._queue: asyncio.Queue = asyncio.Queue()
|
||||
self._task: Optional[asyncio.Task] = None
|
||||
|
||||
async def start(self):
|
||||
"""Start the webhook client."""
|
||||
if not self.config.enabled or not self.config.url:
|
||||
logger.debug("Webhooks disabled or no URL configured")
|
||||
return
|
||||
|
||||
self._client = httpx.AsyncClient(timeout=10.0)
|
||||
self._task = asyncio.create_task(self._process_queue())
|
||||
logger.info(f"Webhook client started: {self.config.url}")
|
||||
|
||||
async def stop(self):
|
||||
"""Stop the webhook client."""
|
||||
if self._task:
|
||||
self._task.cancel()
|
||||
try:
|
||||
await self._task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._task = None
|
||||
|
||||
if self._client:
|
||||
await self._client.aclose()
|
||||
self._client = None
|
||||
|
||||
logger.info("Webhook client stopped")
|
||||
|
||||
def _should_send(self, event_type: str) -> bool:
|
||||
"""Check if this event type should be sent."""
|
||||
if not self.config.enabled:
|
||||
return False
|
||||
return event_type in self.config.events
|
||||
|
||||
async def send_event(
|
||||
self,
|
||||
event_type: str,
|
||||
data: dict[str, Any],
|
||||
immediate: bool = False,
|
||||
):
|
||||
"""Send a webhook event.
|
||||
|
||||
Args:
|
||||
event_type: Type of event (message_received, response_sent, error)
|
||||
data: Event data
|
||||
immediate: If True, send immediately instead of queuing
|
||||
"""
|
||||
if not self._should_send(event_type):
|
||||
return
|
||||
|
||||
payload = {
|
||||
"event": event_type,
|
||||
"timestamp": datetime.utcnow().isoformat() + "Z",
|
||||
"data": data,
|
||||
}
|
||||
|
||||
if immediate:
|
||||
await self._send_payload(payload)
|
||||
else:
|
||||
await self._queue.put(payload)
|
||||
|
||||
async def _process_queue(self):
|
||||
"""Process queued webhook payloads."""
|
||||
while True:
|
||||
try:
|
||||
payload = await self._queue.get()
|
||||
await self._send_payload(payload)
|
||||
self._queue.task_done()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing webhook queue: {e}")
|
||||
|
||||
async def _send_payload(self, payload: dict):
|
||||
"""Send a webhook payload."""
|
||||
if not self._client:
|
||||
return
|
||||
|
||||
try:
|
||||
response = await self._client.post(
|
||||
self.config.url,
|
||||
json=payload,
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
if response.status_code >= 400:
|
||||
logger.warning(
|
||||
f"Webhook returned {response.status_code}: {response.text[:100]}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send webhook: {e}")
|
||||
|
||||
# Convenience methods for common events
|
||||
|
||||
async def on_message_received(
|
||||
self,
|
||||
sender_id: str,
|
||||
sender_name: str,
|
||||
text: str,
|
||||
channel: int,
|
||||
is_dm: bool,
|
||||
):
|
||||
"""Send message_received event."""
|
||||
await self.send_event(
|
||||
"message_received",
|
||||
{
|
||||
"sender_id": sender_id,
|
||||
"sender_name": sender_name,
|
||||
"text": text,
|
||||
"channel": channel,
|
||||
"is_dm": is_dm,
|
||||
},
|
||||
)
|
||||
|
||||
async def on_response_sent(
|
||||
self,
|
||||
recipient_id: Optional[str],
|
||||
text: str,
|
||||
channel: int,
|
||||
):
|
||||
"""Send response_sent event."""
|
||||
await self.send_event(
|
||||
"response_sent",
|
||||
{
|
||||
"recipient_id": recipient_id,
|
||||
"text": text,
|
||||
"channel": channel,
|
||||
},
|
||||
)
|
||||
|
||||
async def on_error(self, error: str, context: Optional[dict] = None):
|
||||
"""Send error event."""
|
||||
await self.send_event(
|
||||
"error",
|
||||
{
|
||||
"error": error,
|
||||
"context": context or {},
|
||||
},
|
||||
)
|
||||
|
||||
async def on_startup(self):
|
||||
"""Send startup event."""
|
||||
await self.send_event(
|
||||
"startup",
|
||||
{"message": "MeshAI started"},
|
||||
immediate=True,
|
||||
)
|
||||
|
||||
async def on_shutdown(self):
|
||||
"""Send shutdown event."""
|
||||
await self.send_event(
|
||||
"shutdown",
|
||||
{"message": "MeshAI stopping"},
|
||||
immediate=True,
|
||||
)
|
||||
Loading…
Add table
Add a link
Reference in a new issue