Unify memory implementations: delete AnthropicMemory and GoogleMemory

RollingSummaryMemory now accepts a generic async summarize callable
instead of hardcoding AsyncOpenAI. Each backend provides its own
_summarize_messages() method that uses the appropriate API client.

- Removed AnthropicMemory class from anthropic_backend.py
- Removed GoogleMemory class from google_backend.py
- Changed RollingSummaryMemory.__init__ signature to accept
  summarize_fn: Callable[[list[dict]], Awaitable[str]]
- All three backends (OpenAI, Anthropic, Google) now instantiate
  the same RollingSummaryMemory class with backend-specific callables
- Extracted shared summarize prompt to module-level _SUMMARIZE_PROMPT

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Ubuntu 2026-02-23 20:14:50 +00:00
commit 2945031071
4 changed files with 101 additions and 211 deletions

View file

@ -1,68 +1,17 @@
"""Anthropic (Claude) LLM backend with rolling summary memory.""" """Anthropic (Claude) LLM backend with rolling summary memory."""
import logging import logging
import time
from typing import Optional from typing import Optional
from anthropic import AsyncAnthropic from anthropic import AsyncAnthropic
from ..config import LLMConfig from ..config import LLMConfig
from ..memory import ConversationSummary from ..memory import RollingSummaryMemory
from .base import LLMBackend from .base import LLMBackend
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_SUMMARIZE_PROMPT = """Summarize this conversation in 2-3 concise sentences. Focus on:
class AnthropicMemory:
"""Rolling summary memory for Anthropic backend."""
def __init__(self, client: AsyncAnthropic, model: str, window_size: int = 4, summarize_threshold: int = 8):
self._client = client
self._model = model
self._window_size = window_size
self._summarize_threshold = summarize_threshold
self._summaries: dict[str, ConversationSummary] = {}
async def get_context_messages(
self, user_id: str, full_history: list[dict]
) -> tuple[Optional[str], list[dict]]:
"""Get optimized context: summary + recent messages."""
if len(full_history) <= self._window_size * 2:
return None, full_history
split_point = -(self._window_size * 2)
old_messages = full_history[:split_point]
recent_messages = full_history[split_point:]
summary = await self._get_or_create_summary(user_id, old_messages)
return summary.summary, recent_messages
async def _get_or_create_summary(self, user_id: str, messages: list[dict]) -> ConversationSummary:
"""Get cached summary or create new one."""
if user_id in self._summaries:
cached = self._summaries[user_id]
if abs(cached.message_count - len(messages)) < self._summarize_threshold:
return cached
logger.debug(f"Generating summary for {user_id} ({len(messages)} messages)")
summary_text = await self._summarize(messages)
summary = ConversationSummary(
summary=summary_text,
last_updated=time.time(),
message_count=len(messages),
)
self._summaries[user_id] = summary
return summary
async def _summarize(self, messages: list[dict]) -> str:
"""Generate summary using Anthropic."""
if not messages:
return "No previous conversation."
conversation = "\n".join([f"{msg['role'].upper()}: {msg['content']}" for msg in messages])
prompt = f"""Summarize this conversation in 2-3 concise sentences. Focus on:
- Main topics discussed - Main topics discussed
- Important context or user preferences - Important context or user preferences
- Key information to remember - Key information to remember
@ -72,30 +21,6 @@ Conversation:
Summary (2-3 sentences):""" Summary (2-3 sentences):"""
try:
response = await self._client.messages.create(
model=self._model,
max_tokens=150,
messages=[{"role": "user", "content": prompt}],
)
content = response.content[0].text if response.content else ""
return content.strip() if content else f"Previous conversation: {len(messages)} messages."
except Exception as e:
logger.warning(f"Failed to generate summary: {e}")
return f"Previous conversation: {len(messages)} messages about various topics."
def load_summary(self, user_id: str, summary: ConversationSummary) -> None:
"""Load summary from database into cache."""
self._summaries[user_id] = summary
def clear_summary(self, user_id: str) -> None:
"""Clear cached summary for user."""
self._summaries.pop(user_id, None)
def get_cached_summary(self, user_id: str) -> Optional[ConversationSummary]:
"""Get cached summary for user."""
return self._summaries.get(user_id)
class AnthropicBackend(LLMBackend): class AnthropicBackend(LLMBackend):
"""Anthropic Claude backend with rolling summary memory.""" """Anthropic Claude backend with rolling summary memory."""
@ -117,13 +42,36 @@ class AnthropicBackend(LLMBackend):
""" """
self.config = config self.config = config
self._client = AsyncAnthropic(api_key=api_key) self._client = AsyncAnthropic(api_key=api_key)
self._memory = AnthropicMemory(
client=self._client, # Initialize rolling summary memory with Anthropic summarize function
model=config.model, self._memory = RollingSummaryMemory(
summarize_fn=self._summarize_messages,
window_size=window_size, window_size=window_size,
summarize_threshold=summarize_threshold, summarize_threshold=summarize_threshold,
) )
async def _summarize_messages(self, messages: list[dict]) -> str:
"""Summarize messages using Anthropic API."""
if not messages:
return "No previous conversation."
conversation = "\n".join(
[f"{msg['role'].upper()}: {msg['content']}" for msg in messages]
)
prompt = _SUMMARIZE_PROMPT.format(conversation=conversation)
try:
response = await self._client.messages.create(
model=self.config.model,
max_tokens=150,
messages=[{"role": "user", "content": prompt}],
)
content = response.content[0].text if response.content else ""
return content.strip() if content else f"Previous conversation: {len(messages)} messages."
except Exception as e:
logger.warning(f"Failed to generate summary: {e}")
return f"Previous conversation: {len(messages)} messages about various topics."
async def generate( async def generate(
self, self,
messages: list[dict], messages: list[dict],
@ -181,7 +129,7 @@ class AnthropicBackend(LLMBackend):
logger.error(f"Anthropic API error: {e}") logger.error(f"Anthropic API error: {e}")
raise raise
def get_memory(self) -> AnthropicMemory: def get_memory(self) -> RollingSummaryMemory:
"""Get the memory manager instance.""" """Get the memory manager instance."""
return self._memory return self._memory

View file

@ -1,67 +1,17 @@
"""Google Gemini LLM backend with rolling summary memory.""" """Google Gemini LLM backend with rolling summary memory."""
import logging import logging
import time
from typing import Optional from typing import Optional
import google.generativeai as genai import google.generativeai as genai
from ..config import LLMConfig from ..config import LLMConfig
from ..memory import ConversationSummary from ..memory import RollingSummaryMemory
from .base import LLMBackend from .base import LLMBackend
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_SUMMARIZE_PROMPT = """Summarize this conversation in 2-3 concise sentences. Focus on:
class GoogleMemory:
"""Rolling summary memory for Google backend."""
def __init__(self, model: genai.GenerativeModel, window_size: int = 4, summarize_threshold: int = 8):
self._model = model
self._window_size = window_size
self._summarize_threshold = summarize_threshold
self._summaries: dict[str, ConversationSummary] = {}
async def get_context_messages(
self, user_id: str, full_history: list[dict]
) -> tuple[Optional[str], list[dict]]:
"""Get optimized context: summary + recent messages."""
if len(full_history) <= self._window_size * 2:
return None, full_history
split_point = -(self._window_size * 2)
old_messages = full_history[:split_point]
recent_messages = full_history[split_point:]
summary = await self._get_or_create_summary(user_id, old_messages)
return summary.summary, recent_messages
async def _get_or_create_summary(self, user_id: str, messages: list[dict]) -> ConversationSummary:
"""Get cached summary or create new one."""
if user_id in self._summaries:
cached = self._summaries[user_id]
if abs(cached.message_count - len(messages)) < self._summarize_threshold:
return cached
logger.debug(f"Generating summary for {user_id} ({len(messages)} messages)")
summary_text = await self._summarize(messages)
summary = ConversationSummary(
summary=summary_text,
last_updated=time.time(),
message_count=len(messages),
)
self._summaries[user_id] = summary
return summary
async def _summarize(self, messages: list[dict]) -> str:
"""Generate summary using Google Gemini."""
if not messages:
return "No previous conversation."
conversation = "\n".join([f"{msg['role'].upper()}: {msg['content']}" for msg in messages])
prompt = f"""Summarize this conversation in 2-3 concise sentences. Focus on:
- Main topics discussed - Main topics discussed
- Important context or user preferences - Important context or user preferences
- Key information to remember - Key information to remember
@ -71,31 +21,6 @@ Conversation:
Summary (2-3 sentences):""" Summary (2-3 sentences):"""
try:
response = await self._model.generate_content_async(
prompt,
generation_config=genai.types.GenerationConfig(
max_output_tokens=150,
temperature=0.3,
),
)
return response.text.strip() if response.text else f"Previous conversation: {len(messages)} messages."
except Exception as e:
logger.warning(f"Failed to generate summary: {e}")
return f"Previous conversation: {len(messages)} messages about various topics."
def load_summary(self, user_id: str, summary: ConversationSummary) -> None:
"""Load summary from database into cache."""
self._summaries[user_id] = summary
def clear_summary(self, user_id: str) -> None:
"""Clear cached summary for user."""
self._summaries.pop(user_id, None)
def get_cached_summary(self, user_id: str) -> Optional[ConversationSummary]:
"""Get cached summary for user."""
return self._summaries.get(user_id)
class GoogleBackend(LLMBackend): class GoogleBackend(LLMBackend):
"""Google Gemini backend with rolling summary memory.""" """Google Gemini backend with rolling summary memory."""
@ -118,12 +43,37 @@ class GoogleBackend(LLMBackend):
self.config = config self.config = config
genai.configure(api_key=api_key) genai.configure(api_key=api_key)
self._model = genai.GenerativeModel(config.model) self._model = genai.GenerativeModel(config.model)
self._memory = GoogleMemory(
model=self._model, # Initialize rolling summary memory with Gemini summarize function
self._memory = RollingSummaryMemory(
summarize_fn=self._summarize_messages,
window_size=window_size, window_size=window_size,
summarize_threshold=summarize_threshold, summarize_threshold=summarize_threshold,
) )
async def _summarize_messages(self, messages: list[dict]) -> str:
"""Summarize messages using Google Gemini API."""
if not messages:
return "No previous conversation."
conversation = "\n".join(
[f"{msg['role'].upper()}: {msg['content']}" for msg in messages]
)
prompt = _SUMMARIZE_PROMPT.format(conversation=conversation)
try:
response = await self._model.generate_content_async(
prompt,
generation_config=genai.types.GenerationConfig(
max_output_tokens=150,
temperature=0.3,
),
)
return response.text.strip() if response.text else f"Previous conversation: {len(messages)} messages."
except Exception as e:
logger.warning(f"Failed to generate summary: {e}")
return f"Previous conversation: {len(messages)} messages about various topics."
async def generate( async def generate(
self, self,
messages: list[dict], messages: list[dict],
@ -196,7 +146,7 @@ class GoogleBackend(LLMBackend):
logger.error(f"Google API error: {e}") logger.error(f"Google API error: {e}")
raise raise
def get_memory(self) -> GoogleMemory: def get_memory(self) -> RollingSummaryMemory:
"""Get the memory manager instance.""" """Get the memory manager instance."""
return self._memory return self._memory

View file

@ -6,11 +6,21 @@ from typing import Optional
from openai import AsyncOpenAI from openai import AsyncOpenAI
from ..config import LLMConfig from ..config import LLMConfig
from ..memory import ConversationSummary, RollingSummaryMemory from ..memory import RollingSummaryMemory
from .base import LLMBackend from .base import LLMBackend
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_SUMMARIZE_PROMPT = """Summarize this conversation in 2-3 concise sentences. Focus on:
- Main topics discussed
- Important context or user preferences
- Key information to remember
Conversation:
{conversation}
Summary (2-3 sentences):"""
class OpenAIBackend(LLMBackend): class OpenAIBackend(LLMBackend):
"""OpenAI-compatible backend (works with OpenAI, LiteLLM, local models).""" """OpenAI-compatible backend (works with OpenAI, LiteLLM, local models)."""
@ -36,14 +46,36 @@ class OpenAIBackend(LLMBackend):
base_url=config.base_url, base_url=config.base_url,
) )
# Initialize rolling summary memory for context optimization # Initialize rolling summary memory with OpenAI summarize function
self._memory = RollingSummaryMemory( self._memory = RollingSummaryMemory(
client=self._client, summarize_fn=self._summarize_messages,
model=config.model,
window_size=window_size, window_size=window_size,
summarize_threshold=summarize_threshold, summarize_threshold=summarize_threshold,
) )
async def _summarize_messages(self, messages: list[dict]) -> str:
"""Summarize messages using OpenAI API."""
if not messages:
return "No previous conversation."
conversation = "\n".join(
[f"{msg['role'].upper()}: {msg['content']}" for msg in messages]
)
prompt = _SUMMARIZE_PROMPT.format(conversation=conversation)
try:
response = await self._client.chat.completions.create(
model=self.config.model,
messages=[{"role": "user", "content": prompt}],
max_tokens=150,
temperature=0.3,
)
content = response.choices[0].message.content
return content.strip() if content else f"Previous conversation: {len(messages)} messages."
except Exception as e:
logger.warning(f"Failed to generate summary: {e}")
return f"Previous conversation: {len(messages)} messages about various topics."
async def generate( async def generate(
self, self,
messages: list[dict], messages: list[dict],

View file

@ -3,9 +3,7 @@
import logging import logging
import time import time
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional from typing import Awaitable, Callable, Optional
from openai import AsyncOpenAI
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -39,21 +37,19 @@ class RollingSummaryMemory:
def __init__( def __init__(
self, self,
client: AsyncOpenAI, summarize_fn: Callable[[list[dict]], Awaitable[str]],
model: str,
window_size: int = 4, window_size: int = 4,
summarize_threshold: int = 8, summarize_threshold: int = 8,
): ):
"""Initialize rolling summary memory. """Initialize rolling summary memory.
Args: Args:
client: AsyncOpenAI client for generating summaries summarize_fn: Async callable that takes a list of messages and returns
model: Model name to use for summarization a summary string. Each backend provides its own implementation.
window_size: Number of recent message pairs to keep in full window_size: Number of recent message pairs to keep in full
summarize_threshold: Messages to accumulate before re-summarizing summarize_threshold: Messages to accumulate before re-summarizing
""" """
self._client = client self._summarize_fn = summarize_fn
self._model = model
self._window_size = window_size self._window_size = window_size
self._summarize_threshold = summarize_threshold self._summarize_threshold = summarize_threshold
@ -105,7 +101,7 @@ class RollingSummaryMemory:
# Generate new summary # Generate new summary
logger.debug(f"Generating summary for {user_id} ({len(messages)} messages)") logger.debug(f"Generating summary for {user_id} ({len(messages)} messages)")
summary_text = await self._summarize(messages) summary_text = await self._summarize_fn(messages)
summary = ConversationSummary( summary = ConversationSummary(
summary=summary_text, summary=summary_text,
@ -116,42 +112,6 @@ class RollingSummaryMemory:
self._summaries[user_id] = summary self._summaries[user_id] = summary
return summary return summary
async def _summarize(self, messages: list[dict]) -> str:
"""Generate summary using LLM."""
if not messages:
return "No previous conversation."
# Format conversation
conversation = "\n".join(
[f"{msg['role'].upper()}: {msg['content']}" for msg in messages]
)
prompt = f"""Summarize this conversation in 2-3 concise sentences. Focus on:
- Main topics discussed
- Important context or user preferences
- Key information to remember
Conversation:
{conversation}
Summary (2-3 sentences):"""
try:
response = await self._client.chat.completions.create(
model=self._model,
messages=[{"role": "user", "content": prompt}],
max_tokens=150,
temperature=0.3,
)
content = response.choices[0].message.content
return content.strip() if content else f"Previous conversation: {len(messages)} messages."
except Exception as e:
logger.warning(f"Failed to generate summary: {e}")
# Fallback - provide basic context
return f"Previous conversation: {len(messages)} messages about various topics."
def load_summary(self, user_id: str, summary: ConversationSummary) -> None: def load_summary(self, user_id: str, summary: ConversationSummary) -> None:
"""Load summary from database into cache.""" """Load summary from database into cache."""
self._summaries[user_id] = summary self._summaries[user_id] = summary