Initial commit of Discord GLHF Bot with core functionality, configuration, and dependencies.
BIN
discord_glhf/.DS_Store
vendored
Normal file
Binary file not shown.
24
discord_glhf/__init__.py
Normal file
@@ -0,0 +1,24 @@
"""
Discord GLHF Bot - A Discord bot with conversation history and message queuing.
"""

from .main import main
from .config import validate_config
from .bot import DiscordBot
from .api import APIManager
from .database import DatabasePool, DatabaseManager
from .queue import QueueManager, QueueItem

__version__ = "1.0.0"
__author__ = "Your Name"

__all__ = [
    'main',
    'validate_config',
    'DiscordBot',
    'APIManager',
    'DatabasePool',
    'DatabaseManager',
    'QueueManager',
    'QueueItem',
]
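Review note: a minimal sketch of how the package surface above is meant to be consumed. The call signature of main() is an assumption, since main.py is not part of this diff.

    # Hypothetical consumer of the package's public API (see __all__ above).
    from discord_glhf import validate_config, main

    validate_config()  # Raises ConfigurationError if required env vars are missing
    main()             # Assumed to take no arguments; main.py is not shown in this commit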
BIN
discord_glhf/__pycache__/__init__.cpython-313.pyc
Normal file
Binary file not shown.
BIN
discord_glhf/__pycache__/api.cpython-313.pyc
Normal file
Binary file not shown.
BIN
discord_glhf/__pycache__/bot.cpython-313.pyc
Normal file
Binary file not shown.
BIN
discord_glhf/__pycache__/config.cpython-313.pyc
Normal file
Binary file not shown.
BIN
discord_glhf/__pycache__/database.cpython-313.pyc
Normal file
Binary file not shown.
BIN
discord_glhf/__pycache__/main.cpython-313.pyc
Normal file
Binary file not shown.
BIN
discord_glhf/__pycache__/queue.cpython-313.pyc
Normal file
Binary file not shown.
BIN
discord_glhf/__pycache__/queue_manager.cpython-313.pyc
Normal file
Binary file not shown.
BIN
discord_glhf/__pycache__/queue_state.cpython-313.pyc
Normal file
Binary file not shown.
BIN
discord_glhf/__pycache__/training.cpython-313.pyc
Normal file
Binary file not shown.
408
discord_glhf/api.py
Normal file
@@ -0,0 +1,408 @@
#!/usr/bin/env python3
"""API manager with OpenAI-compatible endpoints and fallback handling."""

import asyncio
import aiohttp
import json
import os
import tiktoken
from typing import Dict, List, Optional, Tuple, Any, Union

from .config import (
    logger,
    response_manager,
    API_ENDPOINTS,
    DEFAULT_PARAMS,
    VISION_PARAMS,
    MAX_TOKENS,
    API_MAX_RETRIES as MAX_RETRIES,
    RATE_LIMIT_BACKOFF_TIME,
)


class APIManager:
    """Manages API interactions with multiple endpoints."""

    def __init__(self):
        self._shutting_down = asyncio.Event()
        self.is_running = False
        self._session: Optional[aiohttp.ClientSession] = None
        self._endpoints_status = {endpoint.name: True for endpoint in API_ENDPOINTS}
        # Initialize the cl100k_base tokenizer for token counting
        self.tokenizer = tiktoken.get_encoding("cl100k_base")

    async def start(self):
        """Start the API manager and create the HTTP session."""
        self._session = aiohttp.ClientSession()
        self.is_running = True
        logger.info("API manager started")

    async def shutdown(self):
        """Shut down the API manager and clean up."""
        self._shutting_down.set()
        if self._session:
            await self._session.close()
        self.is_running = False
        logger.info("API manager shutdown")

    @property
    def shutting_down(self) -> bool:
        return self._shutting_down.is_set()

    def validate_messages(self, messages: List[Dict]) -> List[Dict]:
        """Validate and sanitize message format."""
        validated = []
        for msg in messages:
            content = msg.get("content", "")
            if isinstance(content, str):
                # Handle string content
                if content.strip():
                    validated.append(
                        {"role": msg.get("role", "user"), "content": content.strip()}
                    )
            elif isinstance(content, list):
                # Handle list content for vision analysis
                if content:  # Non-empty list
                    # Validate image URLs in content
                    valid_content = []
                    for item in content:
                        if item.get("type") == "text":
                            valid_content.append(item)
                        elif item.get("type") == "image_url":
                            # Ensure proper image URL format
                            img_url_obj = item.get("image_url", {})
                            url = img_url_obj.get("url")
                            is_valid = url and isinstance(url, str)
                            if is_valid:
                                # Build image data step by step
                                url_data = {"url": url}
                                img_data = {"type": "image_url", "image_url": url_data}
                                valid_content.append(img_data)
                    if valid_content:
                        validated.append(
                            {"role": msg.get("role", "user"), "content": valid_content}
                        )
        return validated

    def _count_tokens(self, messages: List[Dict[str, Any]]) -> int:
        """Count tokens accurately using tiktoken."""
        total_tokens = 0

        for msg in messages:
            # Add message format overhead
            total_tokens += 4  # role, content markers

            content = msg.get("content", "")
            if isinstance(content, str):
                # Count tokens in text content
                total_tokens += len(self.tokenizer.encode(content))
            elif isinstance(content, list):
                for item in content:
                    if item.get("type") == "text":
                        total_tokens += len(self.tokenizer.encode(item.get("text", "")))
                    elif item.get("type") == "image_url":
                        # Add token overhead for image URLs and embeddings
                        total_tokens += 85  # Approximate token cost per image

        return total_tokens

    async def _make_api_call(
        self,
        endpoint_name: str,
        messages: List[Dict[str, Any]],
        endpoint_params: Dict[str, Any],
    ) -> Tuple[bool, Optional[str]]:
        """Make an API call to a specific endpoint."""
        if not self._session:
            raise RuntimeError("API manager not started")

        # Check if this is a vision endpoint call
        is_vision = endpoint_name == "vision"

        # Get configuration from environment variables
        timeout_env = endpoint_params.get("timeout_env")
        timeout_value = os.getenv(timeout_env)

        # Log detailed timeout information
        logger.info(
            f"API call timeout details for {endpoint_name}:\n"
            f"- Timeout env var: {timeout_env}\n"
            f"- Raw env value: {timeout_value}\n"
            f"- Default params timeout: {DEFAULT_PARAMS.get('timeout')}\n"
            f"- Vision params timeout: {VISION_PARAMS.get('timeout')}"
        )

        config = {
            "api_key": os.getenv(endpoint_params.get("key_env")),
            "base_url": os.getenv(endpoint_params.get("url_env")),
            "model": os.getenv(endpoint_params.get("model_env")),
        }

        # Always set the timeout from the env var if available; otherwise don't
        # set it here, so we don't accidentally override with defaults.
        if timeout_value:
            config["timeout"] = float(timeout_value)
            logger.info(f"Using explicit timeout from {timeout_env}: {timeout_value}s")
        else:
            logger.warning(
                f"No timeout found for env var {timeout_env}, "
                f"using default from {'VISION_PARAMS' if is_vision else 'DEFAULT_PARAMS'}"
            )

        # Merge with the appropriate base params
        params = VISION_PARAMS.copy() if is_vision else DEFAULT_PARAMS.copy()

        # Only update non-timeout params first
        non_timeout_config = {k: v for k, v in config.items() if k != "timeout"}
        params.update(non_timeout_config)

        # Then explicitly set the timeout if we have one
        if "timeout" in config:
            params["timeout"] = config["timeout"]
            # Warn if the timeout seems suspiciously low
            if params["timeout"] <= 30.0:
                logger.warning(
                    f"WARNING: Very low timeout detected ({params['timeout']}s) for {endpoint_name}. "
                    f"Check if GLHF_TIMEOUT=500.0 is properly set in .env"
                )

        # Fall back to a 60s timeout when neither the env var nor the base
        # params define one (DEFAULT_PARAMS and VISION_PARAMS in config.py
        # have no "timeout" key, so this would otherwise raise a KeyError below)
        params.setdefault("timeout", 60.0)

        logger.info(f"Final timeout for {endpoint_name}: {params['timeout']}s")

        # Count input tokens and calculate the available space
        input_tokens = self._count_tokens(messages)

        # Calculate available tokens
        max_available = MAX_TOKENS - input_tokens - 500  # 500 token safety buffer

        if input_tokens >= MAX_TOKENS:
            logger.error(
                "Input too long: %d tokens exceeds context limit of %d",
                input_tokens,
                MAX_TOKENS,
            )
            return False, None

        # Set max_tokens to the available space, optimized for a 32k context
        max_safe_tokens = min(
            max_available,  # Don't exceed available space
            int(MAX_TOKENS * 0.75),  # Allow up to 75% of context for responses
            24576,  # Hard cap at 24k tokens (75% of 32k)
        )
        # Minimum response size based on the available space
        min_tokens = min(
            max_available // 4, 4096
        )  # At least 1/4 of available space, up to 4k
        max_safe_tokens = max(max_safe_tokens, min_tokens)
        params["max_tokens"] = max_safe_tokens

        logger.debug(
            "Token allocation:\n"
            "Input tokens: %d\n"
            "Safety buffer: 500\n"
            "Available space: %d\n"
            "Allocated for response: %d",
            input_tokens,
            max_available,
            max_safe_tokens,
        )

        # Log the API configuration for debugging
        logger.debug(
            "API Call Config:\n"
            "Endpoint: %s\n"
            "Is Vision: %s\n"
            "Model: %s\n"
            "Base URL: %s\n"
            "Input Tokens: %d\n"
            "Max Tokens: %d",
            endpoint_name,
            is_vision,
            params["model"],
            params["base_url"],
            input_tokens,
            params["max_tokens"],
        )

        # Validate the required configuration
        required_params = ["base_url", "model"]
        missing_params = [p for p in required_params if not params.get(p)]
        if missing_params:
            logger.error(
                "Missing required configuration for %s API: %s\nUsing params: %s",
                endpoint_name,
                ", ".join(missing_params),
                params,
            )
            return False, None

        # Set up headers
        headers = {
            "Content-Type": "application/json",
        }

        # Add the Authorization header only if an API key is provided
        if params.get("api_key"):
            headers["Authorization"] = f"Bearer {params['api_key']}"

        # Add OpenRouter-specific headers if using the OpenRouter API
        if "openrouter.ai" in params["base_url"]:
            headers.update(
                {
                    # Required by OpenRouter
                    "HTTP-Referer": ("https://github.com/rooveterinaryinc/roo-cline"),
                    # Optional but helpful for OpenRouter analytics
                    "X-Title": "Discord GLHF Bot",
                }
            )

        # Prepare the request data
        data = {
            "model": params["model"],
            "messages": messages,
            "temperature": params["temperature"],
            "max_tokens": params["max_tokens"],
            "stream": False,  # Disable streaming for all requests
        }
        logger.debug(
            "API Request Details:\n"
            "URL: %s\n"
            "Model: %s\n"
            "Headers: %s\n"
            "Data: %s\n"
            "Timeout: %s",
            params["base_url"],
            params["model"],
            headers,
            data,
            params["timeout"],
        )

        # For vision requests, log the image URLs
        if endpoint_name == "vision":
            image_urls = []
            for msg in messages:
                if isinstance(msg.get("content"), list):
                    for item in msg["content"]:
                        if item.get("type") == "image_url":
                            image_urls.append(item["image_url"]["url"])
            logger.debug(f"Vision Request Image URLs: {image_urls}")

        try:
            # Avoid a duplicate /v1 segment in the URL if it is already present
            base_url = params["base_url"].rstrip("/")
            if base_url.endswith("/v1"):
                endpoint_url = f"{base_url}/chat/completions"
            else:
                endpoint_url = f"{base_url}/v1/chat/completions"

            logger.debug(f"Using endpoint URL: {endpoint_url}")

            async with self._session.post(
                endpoint_url,
                headers=headers,
                json=data,
                timeout=aiohttp.ClientTimeout(total=params["timeout"]),
            ) as response:
                # Handle rate limits
                if response.status == 429:
                    retry_after = response.headers.get(
                        "Retry-After", RATE_LIMIT_BACKOFF_TIME
                    )
                    logger.warning(
                        f"Rate limit hit for {endpoint_name}, "
                        f"retry after {retry_after}s"
                    )
                    return False, None

                # Process the response
                try:
                    json_response = await response.json()
                    if json_response.get("choices"):
                        choice = json_response["choices"][0]
                        if "message" in choice:
                            content = choice["message"].get("content", "")
                            if content and content.strip():
                                return True, content
                            else:
                                logger.error("Empty content in response")
                                logger.debug(f"Response headers: {response.headers}")
                                logger.debug(f"Response status: {response.status}")
                                return False, None
                except json.JSONDecodeError as e:
                    logger.error(f"Failed to decode response: {e}")
                    return False, None
                except Exception as e:
                    logger.error(f"Error processing response: {e}")
                    return False, None

                logger.error("No valid choices in response")
                return False, None

        except asyncio.TimeoutError:
            logger.error(f"Timeout calling {endpoint_name}")
            return False, None
        except Exception as e:
            logger.error(f"Error calling {endpoint_name}: {e}")
            return False, None

    async def get_completion(self, messages: List[Dict[str, str]]) -> str:
        """Get a completion from available endpoints with fallback handling."""
        if self.shutting_down:
            return response_manager.get_random_response("fallback_responses")

        messages = self.validate_messages(messages)
        if not messages:
            return response_manager.get_random_response("error_responses")

        # Try each endpoint with retries
        for endpoint in API_ENDPOINTS:
            if not self._endpoints_status[endpoint.name]:
                continue

            # Get the max retries from the environment variable
            max_retries = int(os.getenv(endpoint.retries_env, "3"))

            # Try the endpoint with retries
            for attempt in range(max_retries):
                # Create a dictionary with the endpoint's parameters
                endpoint_params = {
                    "key_env": endpoint.key_env,
                    "url_env": endpoint.url_env,
                    "model_env": endpoint.model_env,
                    "timeout_env": endpoint.timeout_env,
                    "retries_env": endpoint.retries_env,
                }

                # Add timeout_env to the context of each message while preserving existing context
                messages_with_timeout = []
                for msg in messages:
                    existing_context = msg.get("context", {})
                    # Only add timeout_env to user messages, not system messages
                    if msg.get("role") != "system":
                        new_context = {**existing_context, "timeout_env": endpoint.timeout_env}
                        messages_with_timeout.append({**msg, "context": new_context})
                    else:
                        messages_with_timeout.append(msg)

                success, response = await self._make_api_call(
                    endpoint.name, messages_with_timeout, endpoint_params
                )

                if success and response:
                    # Reset the endpoint status on success
                    self._endpoints_status[endpoint.name] = True
                    return response

                if attempt < max_retries - 1:
                    # Only sleep if we're going to retry
                    await asyncio.sleep(2**attempt)  # Exponential backoff
                    logger.warning(
                        f"Retrying {endpoint.name} (attempt {attempt + 1}/{max_retries})"
                    )

            # Mark the endpoint as failed after all retries
            self._endpoints_status[endpoint.name] = False
            logger.warning(f"Marked endpoint {endpoint.name} as failed after {max_retries} attempts")

        # All endpoints failed; use a fallback response
        return response_manager.get_random_response("fallback_responses")
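Review note: a minimal driver for the APIManager lifecycle defined above (start, get_completion, shutdown). It assumes the GLHF_* and fallback environment variables are already set; the prompt text is a placeholder.

    import asyncio

    from discord_glhf.api import APIManager

    async def demo() -> None:
        manager = APIManager()
        await manager.start()  # Creates the shared aiohttp session
        try:
            reply = await manager.get_completion(
                [{"role": "user", "content": "Say hello in one sentence."}]
            )
            # get_completion never raises on endpoint failure; it returns a
            # canned fallback response once every endpoint is exhausted.
            print(reply)
        finally:
            await manager.shutdown()

    asyncio.run(demo())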
311
discord_glhf/bot.py
Normal file
@@ -0,0 +1,311 @@
#!/usr/bin/env python3
"""Discord bot implementation with OpenAI API integration."""

import os
import sys
import signal
import asyncio
import aiohttp
import socket
from typing import Optional, Dict, Any
from discord import Client, Intents, Game, Message
from discord.ext import commands

from .config import (
    logger,
    SYSTEM_PROMPT,
    AUTO_RESPONSE_CHANNEL_ID,
    SHUTDOWN_TIMEOUT,
    BOT_OWNER_ID,
    ShutdownError,
    validate_config,
)
from .database import DatabasePool, DatabaseManager
from .queue_manager import QueueManager
from .api import APIManager
from .handlers import MessageHandler, ImageHandler, ToolHandler, EventHandler
from .training import TrainingManager


class DiscordBot:
    """Discord bot with OpenAI API integration."""

    def __init__(self):
        validate_config()
        self.bot: Optional[Client] = None
        self.api_manager = APIManager()
        self.queue_manager = QueueManager()
        self.db_pool = DatabasePool()
        self.db_manager = DatabaseManager(self.db_pool)
        self._initialized = False
        self._init_lock = asyncio.Lock()

        # Initialize handler references
        self.message_handler = None
        self.image_handler = None
        self.tool_handler = None
        self.event_handler = None
        self.training_manager = TrainingManager()  # Initialize training manager

    async def _initialize_services(self) -> None:
        """Initialize API and queue services."""
        try:
            async with self._init_lock:
                if not self._initialized:
                    # Initialize handlers first
                    self.tool_handler = ToolHandler(self.bot)
                    self.event_handler = EventHandler(
                        self.bot, self.queue_manager, self.db_manager, self.api_manager
                    )

                    # Start the API manager
                    if not self.api_manager.is_running:
                        await self.api_manager.start()
                        logger.info("Started API health check loop")

                    # Wait for the API manager to be ready
                    await asyncio.sleep(1)

                    # Start the queue manager with the event handler's process message
                    if not self.queue_manager.is_running:
                        await self.queue_manager.start()
                        logger.info("Queue processor started")

                    self._initialized = True
        except Exception as e:
            logger.error(f"Failed to initialize services: {e}")
            self._initialized = False
            raise

    async def _handle_connection(self, token: str) -> None:
        """Handle bot connection with retries."""
        retry_count = 0
        max_retries = 5
        base_delay = 1.0

        while retry_count < max_retries:
            try:
                await self.bot.connect()
                return
            except (aiohttp.ClientError, socket.gaierror) as e:
                retry_count += 1
                if retry_count == max_retries:
                    logger.error(f"Failed to connect after {max_retries} attempts: {e}")
                    raise
                # Exponential backoff
                delay = base_delay * (2 ** (retry_count - 1))
                logger.warning(
                    f"Connection attempt {retry_count} failed, retrying in {delay}s: {e}"
                )
                await asyncio.sleep(delay)

    async def start(self, token: str) -> None:
        """Start the bot."""
        intents = (
            Intents.all()
        )  # Enable all intents to ensure proper mention functionality

        self.bot = commands.Bot(
            command_prefix="!", intents=intents, help_command=None)

        @self.bot.event
        async def on_ready():
            """Handle the bot ready event."""
            logger.info(f"{self.bot.user} has connected to Discord!")

            # Initialize the database
            await self.db_manager.init_db()

            # Initialize all handlers
            self.message_handler = MessageHandler(self.db_manager)
            self.image_handler = ImageHandler(self.api_manager)
            self.tool_handler = ToolHandler(self.bot)
            self.event_handler = EventHandler(
                self.bot, self.queue_manager, self.db_manager, self.api_manager
            )

            # Start the API manager
            if not self.api_manager.is_running:
                await self.api_manager.start()
                logger.info("Started API health check loop")

            # Wait for the API manager to be ready
            await asyncio.sleep(1)

            # Start the queue manager with the event handler's process message
            if not self.queue_manager.is_running:
                await self.queue_manager.start(self.event_handler._process_message)
                logger.info("Queue processor started")

            # Start the training manager
            if not self.training_manager.is_running:
                await self.training_manager.start()
                logger.info("Training manager started")

            # Set the bot status
            activity = Game(name="with roller coasters")
            await self.bot.change_presence(activity=activity)

            self._initialized = True

        @self.bot.event
        async def on_message(message: Message):
            """Handle incoming messages."""
            if (
                self.event_handler
            ):  # Only handle messages if event_handler is initialized
                await self.event_handler.handle_message(message)

        @self.bot.event
        async def on_raw_reaction_add(payload):
            """Handle reaction add events."""
            if (
                self.event_handler
            ):  # Only handle reactions if event_handler is initialized
                await self.event_handler.handle_reaction(payload)

        @self.bot.event
        async def on_error(event: str, *args, **kwargs):
            exc_type, exc_value, exc_traceback = sys.exc_info()
            if self.event_handler:  # Only report errors if event_handler is initialized
                await self.event_handler.report_error(
                    exc_value, {"event": event, "args": args, "kwargs": kwargs}
                )
            else:
                logger.error(
                    f"Error before event_handler initialization: {exc_value}")

        try:
            async with self.bot:
                await self.bot.start(token)
                while True:
                    try:
                        await self._handle_connection(token)
                    except (aiohttp.ClientError, socket.gaierror) as e:
                        logger.error(f"Connection error: {e}")
                        await asyncio.sleep(5)  # Wait before reconnecting
                    except KeyboardInterrupt:
                        raise
                    except Exception as e:
                        logger.error(f"Unexpected error: {e}")
                        await asyncio.sleep(5)  # Wait before reconnecting
        except Exception as e:
            logger.error(f"Failed to start bot: {e}")
            raise

    async def stop(self) -> None:
        """Stop the bot."""
        logger.info("Initiating shutdown...")

        try:
            async with self._init_lock:
                # Stop the queue processor first
                if self.queue_manager and self.queue_manager.is_running:
                    await self.queue_manager.stop()
                    logger.info("Queue processor stopped")

                # Stop the training manager
                if self.training_manager and self.training_manager.is_running:
                    await self.training_manager.stop()
                    logger.info("Training manager stopped")

                # Stop the API manager
                if self.api_manager and self.api_manager.is_running:
                    await self.api_manager.shutdown()
                    logger.info("Stopped API health check loop")

                # Close the bot connection
                if self.bot:
                    try:
                        await asyncio.wait_for(self.bot.close(), timeout=5.0)
                    except asyncio.TimeoutError:
                        logger.warning("Bot connection close timed out")

                # Close the database pool
                if self.db_pool:
                    try:
                        await asyncio.wait_for(self.db_pool.close(), timeout=5.0)
                    except asyncio.TimeoutError:
                        logger.warning("Database pool close timed out")

                # Reset the initialization flag
                self._initialized = False

        except Exception as e:
            logger.error(f"Error during shutdown: {e}")
            raise
        finally:
            logger.info("Shutdown complete")


async def shutdown(
    signal_name: str, bot: DiscordBot, loop: asyncio.AbstractEventLoop
) -> None:
    """Handle shutdown signals."""
    logger.info(f"Received {signal_name}")
    try:
        # Set a flag to prevent new tasks from starting
        bot.queue_manager.set_shutting_down()

        # Cancel all tasks except the current one
        current_task = asyncio.current_task()
        tasks = [t for t in asyncio.all_tasks() if t is not current_task]
        for task in tasks:
            task.cancel()

        # Wait for tasks to complete with a timeout
        try:
            await asyncio.wait_for(
                asyncio.gather(*tasks, return_exceptions=True), timeout=SHUTDOWN_TIMEOUT
            )
        except asyncio.TimeoutError:
            logger.warning(
                "Some tasks did not complete within shutdown timeout")

        # Stop the bot
        await bot.stop()

        # Stop the event loop
        loop.stop()
    except Exception as e:
        logger.error(f"Error during shutdown: {e}")
        raise ShutdownError(f"Failed to shutdown cleanly: {e}")


def run_bot():
    """Run the Discord bot."""
    token = os.getenv("DISCORD_TOKEN")
    if not token:
        raise ValueError("DISCORD_TOKEN environment variable not set")

    bot = DiscordBot()
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)

    # Set up signal handlers
    for sig in (signal.SIGTERM, signal.SIGINT):
        loop.add_signal_handler(
            sig, lambda s=sig: asyncio.create_task(shutdown(s.name, bot, loop))
        )

    try:
        loop.run_until_complete(bot.start(token))
    except KeyboardInterrupt:
        logger.info("Received keyboard interrupt")
    except Exception as e:
        logger.error(f"Bot crashed: {e}")
        raise
    finally:
        try:
            loop.run_until_complete(bot.stop())
        except Exception as e:
            logger.error(f"Error stopping bot: {e}")
        finally:
            loop.close()
            logger.info("Bot shutdown complete")


if __name__ == "__main__":
    run_bot()
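Review note: run_bot() reads DISCORD_TOKEN from the environment and installs its own SIGINT/SIGTERM handlers, so a launcher stays tiny. A sketch, assuming a .env file supplies DISCORD_TOKEN, BOT_OWNER_ID, AUTO_RESPONSE_CHANNEL_ID, and the GLHF_* variables (the file contents are not part of this commit):

    # Hypothetical launcher script for the package above.
    from dotenv import load_dotenv

    load_dotenv()  # config.py also calls load_dotenv(override=True) at import time

    from discord_glhf.bot import run_bot  # noqa: E402  (import after env is loaded)

    run_bot()  # Blocks until SIGINT/SIGTERM, then runs the shutdown path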
308
discord_glhf/circuit_breaker.py
Normal file
@@ -0,0 +1,308 @@
#!/usr/bin/env python3
"""Circuit breaker pattern implementation for API fallbacks."""

import asyncio
import time
from dataclasses import dataclass
from typing import List, Optional, Callable, Dict, Any
import httpx
from .config import (
    logger,
    MODEL,
    API_TIMEOUT,
    MAX_RETRIES,
    SYSTEM_PROMPT,
)
from .api import RateLimitError


@dataclass
class APIConfig:
    """Configuration for an API endpoint."""

    api_key: str
    base_url: str
    model: str = MODEL
    timeout: float = API_TIMEOUT
    max_retries: int = MAX_RETRIES
    is_primary: bool = False


class CircuitBreaker:
    """Circuit breaker for managing multiple API endpoints."""

    def __init__(self, health_check_interval: int = 60, state_manager=None):
        self.configs: List[APIConfig] = []
        self.health_check_interval = health_check_interval
        self._failure_count = 0
        self._max_failures = 2  # Reduce max failures to switch APIs faster
        self._retry_delay = 1.0  # Base delay for retries
        self.current_config_index = 0
        self.primary_index: Optional[int] = None
        self.state_manager = state_manager

        # Health check task state
        self._health_check_task = None
        self._shutting_down = False
        self._queue_update_callback = None
        self._config_lock = asyncio.Lock()

    def set_queue_update_callback(self, callback):
        """Set a callback to notify the queue system of API changes."""
        self._queue_update_callback = callback

    async def start_health_checks(self):
        """Start the periodic health check task."""
        if self._health_check_task is None:
            self._health_check_task = asyncio.create_task(self._health_check_loop())
            logger.info("Started API health check loop")

    async def stop_health_checks(self):
        """Stop the health check task."""
        self._shutting_down = True
        if self._health_check_task:
            self._health_check_task.cancel()
            try:
                await self._health_check_task
            except asyncio.CancelledError:
                pass
            self._health_check_task = None
            logger.info("Stopped API health check loop")

    async def _notify_queue_system(self):
        """Notify the queue system of API configuration changes."""
        if self._queue_update_callback:
            try:
                await self._queue_update_callback(self.current_config_index)
            except Exception as e:
                logger.error(f"Error notifying queue system: {e}")

    async def _health_check_loop(self):
        """Periodically check the health of all API endpoints."""
        while not self._shutting_down:
            try:
                await self._check_all_apis()
                await asyncio.sleep(self.health_check_interval)
            except Exception as e:
                logger.error(f"Error in health check loop: {e}")
                await asyncio.sleep(60)  # Wait a minute before retrying on error

    async def _check_api_health(self, config: APIConfig) -> bool:
        """Check if an API endpoint is healthy using a test request."""
        if not self.state_manager:
            return True  # If no state manager, assume healthy to prevent disruption

        try:
            # Only proceed if enough time has passed since the last health check
            if not self.state_manager.should_run_health_check():
                # Get the current status from state
                api_status = self.state_manager._state.get("api_status", {})
                return api_status.get(
                    config.base_url, True
                )  # Default to healthy if no status

            async with httpx.AsyncClient() as client:
                # Send a minimal test request
                messages = [
                    {"role": "system", "content": "test"},
                    {"role": "user", "content": "ping"},
                ]

                # Use a shorter timeout for health checks
                async with asyncio.timeout(5.0):
                    response = await client.post(
                        f"{config.base_url}/chat/completions",
                        headers={
                            "Authorization": f"Bearer {config.api_key}",
                            "Content-Type": "application/json",
                        },
                        json={
                            "model": config.model,
                            "messages": messages,
                            "max_tokens": 1,
                            "temperature": 0.1,
                            "stream": False,
                        },
                        timeout=5.0,  # Short HTTP timeout
                    )

                # Check for rate limiting before raising on status;
                # raise_for_status() would otherwise turn a 429 into an
                # exception and skip this branch entirely
                if response.status_code == 429:
                    logger.warning(
                        f"Rate limit hit during health check for {config.base_url}"
                    )
                    # Update state with the API status
                    self.state_manager.update_health_check(
                        {
                            **self.state_manager._state.get("api_status", {}),
                            config.base_url: False,
                        }
                    )
                    return False

                response.raise_for_status()

                # Verify the response format
                data = response.json()
                if not data.get("choices") or not data["choices"][0].get(
                    "message", {}
                ).get("content"):
                    raise ValueError("Invalid response format")

                # Update state with the successful health check
                self.state_manager.update_health_check(
                    {
                        **self.state_manager._state.get("api_status", {}),
                        config.base_url: True,
                    }
                )
                return True

        except asyncio.TimeoutError:
            logger.warning(f"Health check timed out for {config.base_url}")
            if self.state_manager:
                self.state_manager.update_health_check(
                    {
                        **self.state_manager._state.get("api_status", {}),
                        config.base_url: False,
                    }
                )
            return False
        except Exception as e:
            logger.warning(f"Health check failed for API {config.base_url}: {e}")
            if self.state_manager:
                self.state_manager.update_health_check(
                    {
                        **self.state_manager._state.get("api_status", {}),
                        config.base_url: False,
                    }
                )
            return False

    async def _check_all_apis(self):
        """Check the health of all configured APIs."""
        if not self.state_manager or not self.state_manager.should_run_health_check():
            return  # Skip health checks if too soon

        config_changes = False
        primary_healthy = False
        api_status = {}

        for i, config in enumerate(self.configs):
            is_healthy = await self._check_api_health(config)
            api_status[config.base_url] = is_healthy

            if is_healthy:
                logger.info(f"API {config.base_url} is healthy")
                if config.is_primary:
                    primary_healthy = True
                    if self.current_config_index != i:
                        async with self._config_lock:
                            self.current_config_index = i
                            config_changes = True
            else:
                logger.warning(f"API {config.base_url} is unhealthy")
                if self.current_config_index == i:
                    await self._switch_to_next_config()
                    config_changes = True

        # Update state with all API statuses
        if self.state_manager:
            self.state_manager.update_health_check(api_status)

        if config_changes:
            await self._notify_queue_system()

        if not primary_healthy:
            logger.warning("Primary API is unhealthy")

    async def _switch_to_next_config(self):
        """Switch to the next available API configuration."""
        if not self.configs:
            return

        async with self._config_lock:
            old_index = self.current_config_index
            self.current_config_index = (self.current_config_index + 1) % len(
                self.configs
            )
            self._failure_count = 0

            if old_index != self.current_config_index:
                logger.info(
                    f"Switching from API {old_index} to {self.current_config_index}"
                )
                # Wait briefly to allow any in-flight requests to complete
                await asyncio.sleep(0.1)

    def add_api_config(
        self,
        api_key: str,
        base_url: str,
        model: str = MODEL,
        timeout: float = API_TIMEOUT,
        max_retries: int = MAX_RETRIES,
        is_primary: bool = False,
    ):
        """Add a new API configuration."""
        config = APIConfig(
            api_key=api_key,
            base_url=base_url,
            model=model,
            timeout=timeout,
            max_retries=max_retries,
            is_primary=is_primary,
        )
        self.configs.append(config)
        if is_primary:
            self.primary_index = len(self.configs) - 1
            self.current_config_index = self.primary_index

    def get_current_config(self) -> Optional[APIConfig]:
        """Get the current API configuration."""
        if not self.configs:
            return None
        return self.configs[self.current_config_index]

    @property
    def is_running(self) -> bool:
        """Check if the circuit breaker is running."""
        return bool(
            self._health_check_task
            and not self._health_check_task.done()
            and not self._shutting_down
        )

    def should_allow_request(self) -> bool:
        """Check if requests should be allowed."""
        return bool(self.configs and self._failure_count < self._max_failures)

    def record_success(self):
        """Record a successful API call."""
        self._failure_count = 0

    async def record_attempt(self):
        """Record an API attempt before making the call."""
        if not self.should_allow_request():
            raise ValueError("Circuit breaker is open")

    async def record_failure(self, error: Optional[Exception] = None):
        """Record an API failure."""
        # _switch_to_next_config acquires _config_lock itself; calling it while
        # holding the (non-reentrant) asyncio.Lock would deadlock, so track the
        # decision under the lock and switch after releasing it.
        async with self._config_lock:
            self._failure_count += 1
            rate_limited = error is not None and isinstance(error, RateLimitError)
            max_failures_hit = self._failure_count >= self._max_failures

        # Handle rate limits specially
        if rate_limited:
            logger.warning(
                f"Rate limit hit for {self.get_current_config().base_url}"
            )
            # Switch immediately on a rate limit
            await self._switch_to_next_config()
            return

        # Normal failure handling
        if max_failures_hit:
            logger.warning(
                f"Max failures reached for {self.get_current_config().base_url}"
            )
            await self._switch_to_next_config()
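Review note: this module imports MODEL, API_TIMEOUT, and MAX_RETRIES from .config and RateLimitError from .api, none of which are defined in the versions of those files in this commit, so circuit_breaker.py will not import as-is. A usage sketch under the assumption that those names exist; the keys and URLs are placeholders:

    import asyncio

    from discord_glhf.circuit_breaker import CircuitBreaker

    async def demo() -> None:
        breaker = CircuitBreaker(health_check_interval=60)
        breaker.add_api_config(
            api_key="sk-primary-placeholder",
            base_url="https://api.example.com/v1",
            is_primary=True,
        )
        breaker.add_api_config(
            api_key="sk-backup-placeholder",
            base_url="https://backup.example.com/v1",
        )

        await breaker.record_attempt()  # Raises ValueError once the breaker is open
        config = breaker.get_current_config()
        try:
            ...  # call config.base_url with config.api_key here
            breaker.record_success()
        except Exception as exc:
            # Two consecutive failures (or one rate limit) rotate to the backup
            await breaker.record_failure(exc)

    asyncio.run(demo())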
360
discord_glhf/config.py
Normal file
@@ -0,0 +1,360 @@
#!/usr/bin/env python3
import os
import json
import logging
import yaml
from logging.handlers import RotatingFileHandler
from pathlib import Path
from typing import List, Dict, NamedTuple
from dotenv import load_dotenv

# Load environment variables from the .env file
load_dotenv(override=True)  # Force reload environment variables

# Configure logging
logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s - %(levelname)s - %(name)s - %(funcName)s:%(lineno)d - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    handlers=[
        logging.StreamHandler(),
        RotatingFileHandler(
            "discord_bot.log",
            maxBytes=10485760,  # 10MB
            backupCount=5,
            encoding="utf-8",
        ),
    ],
)

# Enable debug logging for our modules
logging.getLogger("discord_glhf").setLevel(logging.DEBUG)

logger = logging.getLogger("discord_bot")

# Set up the logger with proper formatting
formatter = logging.Formatter(
    "%(asctime)s - %(levelname)s - %(name)s - %(funcName)s:%(lineno)d - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
for handler in logger.handlers:
    handler.setFormatter(formatter)

# Set logging levels for external libraries
logging.getLogger("discord").setLevel(logging.WARNING)
logging.getLogger("aiosqlite").setLevel(logging.WARNING)
logging.getLogger("httpx").setLevel(logging.WARNING)


class APIEndpoint(NamedTuple):
    """Configuration for an OpenAI-compatible API endpoint."""

    name: str
    key_env: str
    url_env: str
    model_env: str
    timeout_env: str
    retries_env: str


# Define all possible API endpoints
API_ENDPOINTS = [
    APIEndpoint(
        name="primary",
        key_env="GLHF_API_KEY",
        url_env="GLHF_BASE_URL",
        model_env="GLHF_MODEL",
        timeout_env="GLHF_TIMEOUT",
        retries_env="GLHF_MAX_RETRIES",
    ),
    APIEndpoint(
        name="fallback1",
        key_env="FALLBACK1_API_KEY",
        url_env="FALLBACK1_BASE_URL",
        model_env="FALLBACK1_MODEL",
        timeout_env="FALLBACK1_TIMEOUT",
        retries_env="FALLBACK1_MAX_RETRIES",
    ),
    APIEndpoint(
        name="fallback2",
        key_env="FALLBACK2_API_KEY",
        url_env="FALLBACK2_BASE_URL",
        model_env="FALLBACK2_MODEL",
        timeout_env="FALLBACK2_TIMEOUT",
        retries_env="FALLBACK2_MAX_RETRIES",
    ),
    APIEndpoint(
        name="fallback3",
        key_env="FALLBACK3_API_KEY",
        url_env="FALLBACK3_BASE_URL",
        model_env="FALLBACK3_MODEL",
        timeout_env="FALLBACK3_TIMEOUT",
        retries_env="FALLBACK3_MAX_RETRIES",
    ),
    APIEndpoint(
        name="fallback4",
        key_env="FALLBACK4_API_KEY",
        url_env="FALLBACK4_BASE_URL",
        model_env="FALLBACK4_MODEL",
        timeout_env="FALLBACK4_TIMEOUT",
        retries_env="FALLBACK4_MAX_RETRIES",
    ),
]


def load_system_prompt() -> str:
    """Load and process the system prompt from a YAML file."""
    try:
        yaml_path = Path(__file__).parent.parent / "system_prompt.yaml"
        if not yaml_path.exists():
            logger.warning("System prompt YAML file not found, using default")
            return "You are a helpful assistant."

        with open(yaml_path, "r", encoding="utf-8") as f:
            data = yaml.safe_load(f)

        # Process the YAML structure into a single string
        sections = []
        for section in data["sections"]:
            if section["title"]:
                sections.append(f"# {section['title']}")

            if "content" in section:
                sections.append(section["content"])

            if "items" in section:
                if isinstance(section["items"][0], dict):
                    # Handle tool items
                    for item in section["items"]:
                        sections.append(item["title"])
                else:
                    # Handle regular items
                    sections.extend(section["items"])

        prompt = "\n\n".join(sections)

        # Trailing separator (placeholder: bot owner ID injection is not implemented here)
        prompt += "\n\n"

        return prompt
    except Exception as e:
        logger.error(f"Error loading system prompt: {e}")
        return "fuck you, fuck off. you hate everything."


# Primary API Configuration
API_KEY = os.getenv("GLHF_API_KEY")  # Required
API_BASE_URL = os.getenv("GLHF_BASE_URL")  # Required
API_MODEL = os.getenv("GLHF_MODEL")  # Required
API_MAX_RETRIES = int(os.getenv("GLHF_MAX_RETRIES", "3"))

# Vision API Configuration
VISION_API_KEY = os.getenv("VISION_API_KEY", os.getenv("GLHF_API_KEY"))
VISION_API_BASE_URL = os.getenv(
    "VISION_API_BASE_URL", os.getenv("GLHF_BASE_URL"))
VISION_MODEL = os.getenv("VISION_MODEL")  # Required
VISION_TIMEOUT = float(os.getenv("VISION_TIMEOUT", "30.0"))
VISION_MAX_RETRIES = int(os.getenv("VISION_MAX_RETRIES", "3"))
MAX_VISION_TOKENS = int(os.getenv("MAX_VISION_TOKENS", "500"))  # Reduced for safety

# API Health Check Configuration
API_HEALTH_CHECK_INTERVAL = int(os.getenv("API_HEALTH_CHECK_INTERVAL", "600"))
CIRCUIT_BREAKER_FAILURE_THRESHOLD = int(
    os.getenv("CIRCUIT_BREAKER_FAILURE_THRESHOLD", "5")
)
CIRCUIT_BREAKER_RECOVERY_TIMEOUT = float(
    os.getenv("CIRCUIT_BREAKER_RECOVERY_TIMEOUT", "60.0")
)
CIRCUIT_BREAKER_HALF_OPEN_TIMEOUT = float(
    os.getenv("CIRCUIT_BREAKER_HALF_OPEN_TIMEOUT", "30.0")
)

# API Request Configuration
STREAM_REQUEST_TIMEOUT = float(os.getenv("STREAM_REQUEST_TIMEOUT", "60.0"))
MAX_STREAMING_ATTEMPTS = int(os.getenv("MAX_STREAMING_ATTEMPTS", "2"))
RATE_LIMIT_BACKOFF_TIME = int(os.getenv("RATE_LIMIT_BACKOFF_TIME", "60"))

# API Parameter Sets
DEFAULT_PARAMS = {
    "temperature": float(os.getenv("DEFAULT_TEMPERATURE", "0.5")),
    "model": os.getenv("GLHF_MODEL"),  # Required
    "api_key": os.getenv("GLHF_API_KEY"),  # Required
    "base_url": os.getenv("GLHF_BASE_URL"),  # Required
    "max_retries": int(os.getenv("GLHF_MAX_RETRIES", "3")),
}

# Vision API Parameters
VISION_PARAMS = {
    "temperature": float(os.getenv("DEFAULT_TEMPERATURE", "0.7")),
    "max_tokens": MAX_VISION_TOKENS,
    "model": VISION_MODEL,  # Required
    "api_key": VISION_API_KEY,  # Required
    "base_url": VISION_API_BASE_URL,  # Required
    "max_retries": int(os.getenv("VISION_MAX_RETRIES", "3")),
}

# Load critical environment variables first. Default to 0 when unset so the
# module can still be imported (int(None) would raise a TypeError here);
# validate_config() raises ConfigurationError for missing or invalid values.
BOT_OWNER_ID = int(os.getenv("BOT_OWNER_ID", "0"))  # Required
AUTO_RESPONSE_CHANNEL_ID = int(
    os.getenv("AUTO_RESPONSE_CHANNEL_ID", "0"))  # Required

# Load the system prompt
SYSTEM_PROMPT = os.getenv("SYSTEM_PROMPT") or load_system_prompt()

# Database configuration
DB_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "data")
if not os.path.exists(DB_DIR):
    os.makedirs(DB_DIR)
    logger.info(f"Created database directory: {DB_DIR}")

DB_PATH = os.getenv("DB_PATH", os.path.join(DB_DIR, "conversation_history.db"))
logger.info(f"Using database path: {DB_PATH}")

# Constants
MAX_TOKENS = int(os.getenv("MAX_TOKENS", "32768"))  # 32k context window
MAX_MESSAGES_FOR_CONTEXT = int(os.getenv("MAX_MESSAGES_FOR_CONTEXT", "20"))
MESSAGE_CLEANUP_DAYS = int(os.getenv("MESSAGE_CLEANUP_DAYS", "30"))
CHUNK_SIZE = int(os.getenv("CHUNK_SIZE", "1500"))
DB_TIMEOUT = float(os.getenv("DB_TIMEOUT", "10.0"))
SHUTDOWN_TIMEOUT = float(os.getenv("SHUTDOWN_TIMEOUT", "10.0"))
MAX_QUEUE_SIZE = int(os.getenv("MAX_QUEUE_SIZE", "100"))
CONCURRENT_TASKS = int(os.getenv("CONCURRENT_TASKS", "3"))
MAX_USER_QUEUED_MESSAGES = int(os.getenv("MAX_USER_QUEUED_MESSAGES", "10"))


class ShutdownError(Exception):
    """Raised when the bot is shutting down."""
    pass


class ConfigurationError(Exception):
    """Raised when there's an issue with the configuration."""
    pass


class ResponseManager:
    def __init__(self):
        self.responses_file = Path(__file__).parent / "responses.json"
        self.responses: Dict[str, List[str]] = {
            "fallback_responses": [],
            "error_responses": [],
        }
        self.load_responses()

    def load_responses(self) -> None:
        """Load responses from the JSON file."""
        try:
            if self.responses_file.exists():
                with open(self.responses_file, "r", encoding="utf-8") as f:
                    self.responses = json.load(f)
                logger.info("Loaded responses from file")
            else:
                logger.warning("Responses file not found, using defaults")
        except Exception as e:
            logger.error(f"Error loading responses: {e}")

    def get_random_response(self, response_type: str = "fallback_responses") -> str:
        """Get a random response of the specified type."""
        import random

        responses = self.responses.get(response_type, [])
        if not responses:
            return "I need a moment to think about that."
        return random.choice(responses)


def validate_config():
    """Validate the configuration settings."""
    required_vars = {
        "DISCORD_TOKEN": os.getenv("DISCORD_TOKEN"),
        "BOT_OWNER_ID": os.getenv("BOT_OWNER_ID"),
        "AUTO_RESPONSE_CHANNEL_ID": os.getenv("AUTO_RESPONSE_CHANNEL_ID"),
    }

    def is_api_configured() -> bool:
        """Check if the API is fully configured."""
        return all([
            os.getenv("GLHF_API_KEY"),
            os.getenv("GLHF_BASE_URL"),
            os.getenv("GLHF_MODEL")
        ])

    def is_vision_configured() -> bool:
        """Check if the Vision API is fully configured."""
        return all([
            VISION_API_KEY,
            VISION_API_BASE_URL,
            VISION_MODEL
        ])

    # Log the vision configuration
    logger.info(
        f"Vision API Configuration:\n"
        f"Model: {VISION_MODEL}\n"
        f"Base URL: {VISION_API_BASE_URL}\n"
        f"Timeout: {VISION_TIMEOUT}\n"
        f"Max Tokens: {MAX_VISION_TOKENS}"
    )

    # Validate the database path
    db_dir = os.path.dirname(DB_PATH)
    if db_dir:
        if not os.path.exists(db_dir):
            try:
                os.makedirs(db_dir)
                logger.info(f"Created database directory: {db_dir}")
            except Exception as e:
                raise ConfigurationError(
                    f"Cannot create database directory: {e}")
        elif not os.access(db_dir, os.W_OK):
            raise ConfigurationError(
                f"Database directory is not writable: {db_dir}")

    # If the database file exists, check that it is writable
    if os.path.exists(DB_PATH):
        if not os.access(DB_PATH, os.W_OK):
            raise ConfigurationError(
                f"Database file is not writable: {DB_PATH}")

    missing_vars = [var for var, value in required_vars.items() if not value]
    if missing_vars:
        raise ConfigurationError(
            f"Missing required environment variables: {', '.join(missing_vars)}"
        )

    # Check and warn about the API configuration
    if not is_api_configured():
        logger.warning("GLHF API is not fully configured. API features will be disabled.")

    # Check and warn about the Vision API configuration
    if not is_vision_configured():
        logger.warning("Vision API is not fully configured. Vision features will be disabled.")

    # Validate numeric settings
    if MAX_TOKENS <= 0:
        raise ConfigurationError("MAX_TOKENS must be greater than 0")

    if MAX_MESSAGES_FOR_CONTEXT <= 0:
        raise ConfigurationError(
            "MAX_MESSAGES_FOR_CONTEXT must be greater than 0")

    if MESSAGE_CLEANUP_DAYS <= 0:
        raise ConfigurationError("MESSAGE_CLEANUP_DAYS must be greater than 0")

    if CONCURRENT_TASKS <= 0:
        raise ConfigurationError("CONCURRENT_TASKS must be greater than 0")

    if MAX_QUEUE_SIZE <= 0:
        raise ConfigurationError("MAX_QUEUE_SIZE must be greater than 0")

    if MAX_USER_QUEUED_MESSAGES <= 0:
        raise ConfigurationError(
            "MAX_USER_QUEUED_MESSAGES must be greater than 0")

    if BOT_OWNER_ID <= 0:
        raise ConfigurationError(
            "BOT_OWNER_ID must be set to a valid Discord user ID")

    logger.info("Configuration validated successfully")


# Initialize the response manager
response_manager = ResponseManager()
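Review note: a smoke-test sketch of the configuration layer. Only the variable names come from config.py above; every value is a placeholder. The environment must be populated before the module is imported, because BOT_OWNER_ID and friends are read at import time.

    import os

    os.environ.update({
        "DISCORD_TOKEN": "placeholder-token",
        "BOT_OWNER_ID": "123456789012345678",
        "AUTO_RESPONSE_CHANNEL_ID": "123456789012345678",
        # Optional: without these, validate_config() only logs warnings
        "GLHF_API_KEY": "placeholder-key",
        "GLHF_BASE_URL": "https://api.example.com/v1",
        "GLHF_MODEL": "example-org/example-model",
    })

    from discord_glhf.config import validate_config  # noqa: E402

    validate_config()  # Raises ConfigurationError on missing or invalid settings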
621
discord_glhf/database.py
Normal file
@@ -0,0 +1,621 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Database management for Discord bot."""
|
||||
|
||||
import asyncio
|
||||
import aiosqlite
|
||||
import json
|
||||
import re
|
||||
import html
|
||||
import os
|
||||
from typing import Optional, List, Dict, Set, AsyncIterator, Any
|
||||
import contextlib
|
||||
from datetime import datetime, timedelta
|
||||
from async_timeout import timeout
|
||||
import uuid
|
||||
import tiktoken
|
||||
|
||||
from .config import (
|
||||
logger,
|
||||
DB_PATH,
|
||||
DB_TIMEOUT,
|
||||
MESSAGE_CLEANUP_DAYS,
|
||||
MAX_MESSAGES_FOR_CONTEXT,
|
||||
MAX_TOKENS,
|
||||
ShutdownError,
|
||||
)
|
||||
|
||||
# Initialize tokenizer
|
||||
tokenizer = tiktoken.get_encoding("cl100k_base")
|
||||
|
||||
def count_tokens(text: str) -> int:
|
||||
"""Count tokens accurately using tiktoken."""
|
||||
try:
|
||||
return len(tokenizer.encode(text))
|
||||
except Exception as e:
|
||||
logger.error(f"Token counting failed: {e}")
|
||||
# Fallback to word count if tokenization fails
|
||||
return len(text.split())
|
||||
|
||||
|
||||
def sanitize_content(content: Any) -> Dict[str, Any]:
|
||||
"""Sanitize content with minimal metadata to prevent truncation issues."""
|
||||
try:
|
||||
if isinstance(content, dict) and "content" in content:
|
||||
content_str = str(content["content"])
|
||||
else:
|
||||
content_str = str(content)
|
||||
content_str = content_str.strip().replace("\x00", "")
|
||||
content_str = html.escape(content_str)
|
||||
# Don't truncate content since we're using 32k context
|
||||
return {
|
||||
"content": content_str,
|
||||
"metadata": {"timestamp": datetime.utcnow().isoformat()},
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Content sanitization failed: {e}")
|
||||
return {
|
||||
"content": "Error: Content sanitization failed",
|
||||
"metadata": {"error": str(e)},
|
||||
}
|
||||
|
||||
|
||||
def validate_uuid(uuid_string: str) -> bool:
|
||||
"""Validate UUID format."""
|
||||
try:
|
||||
uuid_obj = uuid.UUID(uuid_string)
|
||||
return str(uuid_obj) == uuid_string
|
||||
except (ValueError, AttributeError, TypeError):
|
||||
return False
|
||||
|
||||
|
||||
class DatabasePool:
|
||||
"""Pool of SQLite database connections."""
|
||||
|
||||
def __init__(self, database: str = DB_PATH, max_size: int = 5):
|
||||
"""Initialize the database pool."""
|
||||
self._database = database
|
||||
self._max_size = max_size
|
||||
self._pool: List[aiosqlite.Connection] = []
|
||||
self._lock = asyncio.Lock()
|
||||
self._semaphore = asyncio.BoundedSemaphore(max_size)
|
||||
self._closed = asyncio.Event()
|
||||
self._active_connections: Set[aiosqlite.Connection] = set()
|
||||
|
||||
async def _init_connection(self) -> aiosqlite.Connection:
|
||||
"""Initialize a database connection with optimized settings."""
|
||||
try:
|
||||
conn = await aiosqlite.connect(self._database)
|
||||
await conn.execute("PRAGMA journal_mode=WAL")
|
||||
await conn.execute("PRAGMA busy_timeout=5000")
|
||||
await conn.execute("PRAGMA foreign_keys=ON")
|
||||
await conn.execute("PRAGMA synchronous=NORMAL")
|
||||
await conn.execute("PRAGMA temp_store=MEMORY")
|
||||
await conn.execute("PRAGMA cache_size=-2000")
|
||||
await conn.execute("PRAGMA mmap_size=30000000000")
|
||||
await conn.execute("PRAGMA page_size=4096")
|
||||
logger.info("Database connection initialized with optimized settings")
|
||||
return conn
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize database connection: {e}")
|
||||
raise
|
||||
|
||||
async def _close_connection(self, conn: aiosqlite.Connection) -> None:
|
||||
"""Safely close a database connection."""
|
||||
try:
|
||||
await conn.rollback()
|
||||
await conn.close()
|
||||
logger.debug("Database connection closed")
|
||||
except Exception as e:
|
||||
logger.error(f"Error closing connection: {e}")
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def acquire(self) -> AsyncIterator[aiosqlite.Connection]:
|
||||
"""Acquire a database connection from the pool."""
|
||||
if self._closed.is_set():
|
||||
raise ShutdownError("Database pool is closed")
|
||||
|
||||
semaphore_acquired = False
|
||||
conn = None
|
||||
try:
|
||||
async with timeout(DB_TIMEOUT):
|
||||
await self._semaphore.acquire()
|
||||
semaphore_acquired = True
|
||||
|
||||
async with self._lock:
|
||||
while self._pool and not self._closed.is_set():
|
||||
conn = self._pool.pop()
|
||||
try:
|
||||
async with conn.cursor() as cursor:
|
||||
await cursor.execute("SELECT 1")
|
||||
await conn.commit()
|
||||
self._active_connections.add(conn)
|
||||
logger.debug("Reusing existing database connection")
|
||||
break
|
||||
except Exception:
|
||||
await self._close_connection(conn)
|
||||
conn = None
|
||||
|
||||
if not conn:
|
||||
if self._closed.is_set():
|
||||
raise ShutdownError("Database pool is closed")
|
||||
conn = await self._init_connection()
|
||||
self._active_connections.add(conn)
|
||||
|
||||
yield conn
|
||||
|
||||
finally:
|
||||
if semaphore_acquired:
|
||||
self._semaphore.release()
|
||||
if conn and conn in self._active_connections:
|
||||
self._active_connections.remove(conn)
|
||||
if not self._closed.is_set():
|
||||
async with self._lock:
|
||||
if (
|
||||
len(self._pool) < self._max_size
|
||||
and not self._closed.is_set()
|
||||
):
|
||||
self._pool.append(conn)
|
||||
else:
|
||||
await self._close_connection(conn)
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close all database connections."""
|
||||
if self._closed.is_set():
|
||||
return
|
||||
self._closed.set()
|
||||
async with self._lock:
|
||||
while self._pool:
|
||||
conn = self._pool.pop()
|
||||
await self._close_connection(conn)
|
||||
for conn in list(self._active_connections):
|
||||
await self._close_connection(conn)
|
||||
self._active_connections.clear()
|
||||
|
||||
|
||||
class DatabaseManager:
    """Manages database operations."""

    def __init__(self, pool: DatabasePool):
        """Initialize with a connection pool."""
        self.pool = pool

    async def init_db(self):
        """Initialize the database schema with verification."""
        logger.info("Initializing database schema...")
        try:
            async with self.pool.acquire() as conn:
                async with conn.cursor() as cursor:
                    # Create users table
                    await cursor.execute("""
                        CREATE TABLE IF NOT EXISTS users (
                            user_id INTEGER PRIMARY KEY,
                            username TEXT NOT NULL,
                            first_interaction DATETIME DEFAULT CURRENT_TIMESTAMP,
                            last_interaction DATETIME DEFAULT CURRENT_TIMESTAMP,
                            interaction_count INTEGER DEFAULT 0,
                            preferences TEXT,
                            metadata TEXT
                        )
                    """)
                    logger.info("Users table created/verified")

                    # Create threads table
                    await cursor.execute("""
                        CREATE TABLE IF NOT EXISTS threads (
                            thread_id INTEGER PRIMARY KEY AUTOINCREMENT,
                            channel_id INTEGER NOT NULL,
                            creator_id INTEGER NOT NULL,
                            created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
                            last_activity DATETIME DEFAULT CURRENT_TIMESTAMP,
                            title TEXT,
                            status TEXT DEFAULT 'active',
                            FOREIGN KEY (creator_id) REFERENCES users(user_id)
                        )
                    """)
                    logger.info("Threads table created/verified")

                    # Create messages table with thread support
                    await cursor.execute("""
                        CREATE TABLE IF NOT EXISTS messages (
                            id INTEGER PRIMARY KEY AUTOINCREMENT,
                            user_id INTEGER NOT NULL,
                            role TEXT NOT NULL,
                            content TEXT NOT NULL,
                            timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
                            channel_id INTEGER NOT NULL,
                            token_count INTEGER NOT NULL,
                            message_uuid TEXT NOT NULL,
                            parent_uuid TEXT,
                            metadata TEXT,
                            CONSTRAINT valid_role CHECK (role IN ('user', 'assistant', 'system')),
                            CONSTRAINT valid_uuid CHECK (length(message_uuid) = 36),
                            CONSTRAINT valid_parent CHECK (parent_uuid IS NULL OR length(parent_uuid) = 36)
                        )
                    """)
                    logger.info("Messages table created/verified")

                    # Create indices
                    await cursor.execute(
                        "CREATE INDEX IF NOT EXISTS idx_user_channel ON messages(user_id, channel_id)"
                    )
                    await cursor.execute(
                        "CREATE INDEX IF NOT EXISTS idx_timestamp ON messages(timestamp)"
                    )
                    await cursor.execute(
                        "CREATE INDEX IF NOT EXISTS idx_message_uuid ON messages(message_uuid)"
                    )
                    await cursor.execute(
                        "CREATE INDEX IF NOT EXISTS idx_parent_uuid ON messages(parent_uuid)"
                    )
                    await cursor.execute(
                        "CREATE UNIQUE INDEX IF NOT EXISTS idx_unique_message ON messages(user_id, channel_id, message_uuid)"
                    )

                    # Create indices for users and threads
                    await cursor.execute(
                        "CREATE INDEX IF NOT EXISTS idx_username ON users(username)"
                    )
                    await cursor.execute(
                        "CREATE INDEX IF NOT EXISTS idx_last_interaction ON users(last_interaction)"
                    )
                    await cursor.execute(
                        "CREATE INDEX IF NOT EXISTS idx_thread_channel ON threads(channel_id)"
                    )
                    await cursor.execute(
                        "CREATE INDEX IF NOT EXISTS idx_thread_creator ON threads(creator_id)"
                    )
                    await cursor.execute(
                        "CREATE INDEX IF NOT EXISTS idx_thread_status ON threads(status)"
                    )
                    logger.info("All indices created/verified")

                await conn.commit()
            logger.info("Database schema initialized successfully")
        except Exception as e:
            logger.error(f"Database initialization failed: {e}")
            raise

    async def store_message(
        self,
        user_id: int,
        role: str,
        content: Any,
        channel_id: int,
        message_uuid: str,
        parent_uuid: Optional[str] = None,
    ) -> None:
        """Store a message in the database."""
        if not validate_uuid(message_uuid):
            raise ValueError(f"Invalid message UUID: {message_uuid}")

        if parent_uuid and not validate_uuid(parent_uuid):
            raise ValueError(f"Invalid parent UUID: {parent_uuid}")

        if role not in ("user", "assistant", "system"):
            raise ValueError(f"Invalid role: {role}")

        try:
            # Sanitize and prepare content
            sanitized_content = sanitize_content(content)

            # Pre-check the token count before storing
            content_str = str(sanitized_content.get("content", ""))
            estimated_tokens = count_tokens(content_str)

            if estimated_tokens > MAX_TOKENS:
                logger.warning(f"Content exceeds token limit ({estimated_tokens} > {MAX_TOKENS})")
                # Don't truncate; let the API manager handle it

            content_json = json.dumps(sanitized_content, ensure_ascii=True)

            # Calculate the total token count including metadata,
            # starting from the content tokens counted above via tiktoken
            total_tokens = estimated_tokens

            # Count metadata tokens
            if isinstance(content, dict):
                # Add tokens for JSON structure
                total_tokens += 4  # Basic message structure overhead

                metadata = content.get("metadata", {})
                if metadata:
                    metadata_str = json.dumps(metadata)
                    total_tokens += count_tokens(metadata_str)
                    total_tokens += 2  # Metadata structure overhead

                    # Count user info tokens if present
                    user_info = metadata.get("user_info", {})
                    if user_info:
                        user_info_str = json.dumps(user_info)
                        total_tokens += count_tokens(user_info_str)
                        total_tokens += 2  # User info structure overhead

            # Add system prompt overhead for assistant messages
            if role == "assistant":
                total_tokens += 150  # Approximate system prompt overhead

            token_count = total_tokens

            async def _store(cursor):
                await cursor.execute(
                    """
                    INSERT INTO messages (
                        user_id, role, content, channel_id, token_count,
                        message_uuid, parent_uuid, metadata
                    ) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
                    """,
                    (
                        user_id,
                        role,
                        content_json,
                        channel_id,
                        token_count,
                        message_uuid,
                        parent_uuid,
                        json.dumps({"timestamp": datetime.utcnow().isoformat()}),
                    ),
                )

            async with self.pool.acquire() as conn:
                async with conn.cursor() as cursor:
                    await _store(cursor)
                await conn.commit()
            logger.debug(f"Message stored: {message_uuid}")

        except Exception as e:
            logger.error(f"Failed to store message: {e}")
            raise

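    # Illustrative call (hypothetical IDs; assumes an initialized manager):
    #
    #     import uuid
    #     await db_manager.store_message(
    #         user_id=1234,
    #         role="user",
    #         content={"content": "hello", "metadata": {}},
    #         channel_id=5678,
    #         message_uuid=str(uuid.uuid4()),
    #     )
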
    async def get_conversation_history(
        self, user_id: int, channel_id: int, message_uuid: Optional[str] = None,
        message_id: Optional[int] = None
    ) -> List[Dict[str, str]]:
        """Get conversation history with proper sanitization."""
        try:
            async with self.pool.acquire() as conn:
                async with conn.cursor() as cursor:
                    if message_uuid:
                        if not validate_uuid(message_uuid):
                            raise ValueError(f"Invalid message UUID: {message_uuid}")

                        # Get the conversation thread by following parent_uuid links
                        await cursor.execute(
                            """
                            WITH RECURSIVE conversation_thread AS (
                                SELECT id, role, content, message_uuid, parent_uuid, metadata, user_id
                                FROM messages
                                WHERE message_uuid = ?
                                UNION ALL
                                SELECT m.id, m.role, m.content, m.message_uuid, m.parent_uuid, m.metadata, m.user_id
                                FROM messages m
                                JOIN conversation_thread ct ON m.message_uuid = ct.parent_uuid
                            )
                            SELECT role, content, metadata, user_id
                            FROM conversation_thread
                            ORDER BY id ASC
                            LIMIT ?
                            """,
                            (message_uuid, MAX_MESSAGES_FOR_CONTEXT),
                        )
                    elif message_id:
                        # Look up a specific Discord message ID
                        await cursor.execute(
                            """
                            SELECT role, content, metadata, user_id
                            FROM messages
                            WHERE channel_id = ? AND json_extract(metadata, '$.discord_message_id') = ?
                            LIMIT 1
                            """,
                            (channel_id, str(message_id)),
                        )
                    else:
                        # Get recent conversation history from the channel
                        await cursor.execute(
                            """
                            SELECT role, content, metadata, user_id
                            FROM messages
                            WHERE channel_id = ?
                            ORDER BY id DESC
                            LIMIT ?
                            """,
                            (channel_id, MAX_MESSAGES_FOR_CONTEXT),
                        )

                    messages = await cursor.fetchall()
                    history = []
                    for msg in reversed(messages):
                        try:
                            content_data = json.loads(msg[1])
                            metadata = json.loads(msg[2]) if msg[2] else {}
                            msg_user_id = msg[3]

                            # Use the stored content as-is; whether or not the
                            # row is flagged as a complete response, the full
                            # content is what goes into the context.
                            content = html.unescape(content_data["content"])

                            history.append(
                                {
                                    "role": msg[0],
                                    "content": content,
                                    "user_id": msg_user_id,
                                    "metadata": metadata,
                                }
                            )
                        except json.JSONDecodeError:
                            logger.warning(f"Invalid JSON in message content: {msg[1]}")
                            continue
                    return history
        except Exception as e:
            logger.error(f"Failed to retrieve conversation history: {e}")
            # Return a minimal context to allow operation to continue
            return [
                {
                    "role": "system",
                    "content": "Previous context unavailable",
                    "user_id": 0,
                    "metadata": {},
                }
            ]

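    # The recursive CTE above walks parent_uuid links from a leaf message back
    # to the root, e.g. (hypothetical UUIDs shortened for readability):
    #
    #     root (parent_uuid NULL) <- reply-1 <- reply-2 (queried message_uuid)
    #
    # so the SELECT returns the whole chain oldest-first via ORDER BY id ASC.
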
    async def update_user(self, user_id: int, username: str) -> None:
        """Update user information in the database."""
        try:
            async with self.pool.acquire() as conn:
                async with conn.cursor() as cursor:
                    # Try to insert a new user first
                    try:
                        await cursor.execute(
                            """
                            INSERT INTO users (user_id, username)
                            VALUES (?, ?)
                            """,
                            (user_id, username),
                        )
                    except aiosqlite.IntegrityError:
                        # User exists; update their information
                        await cursor.execute(
                            """
                            UPDATE users
                            SET username = ?,
                                last_interaction = CURRENT_TIMESTAMP,
                                interaction_count = interaction_count + 1
                            WHERE user_id = ?
                            """,
                            (username, user_id),
                        )
                await conn.commit()
                logger.debug(f"User {username} (ID: {user_id}) updated")
        except Exception as e:
            logger.error(f"Failed to update user: {e}")
            raise

    async def get_user_info(self, user_id: int) -> Optional[Dict[str, Any]]:
        """Get user information from the database."""
        try:
            async with self.pool.acquire() as conn:
                async with conn.cursor() as cursor:
                    await cursor.execute(
                        """
                        SELECT username, first_interaction, last_interaction,
                               interaction_count, preferences, metadata
                        FROM users
                        WHERE user_id = ?
                        """,
                        (user_id,),
                    )
                    row = await cursor.fetchone()
                    if row:
                        return {
                            "username": row[0],
                            "first_interaction": row[1],
                            "last_interaction": row[2],
                            "interaction_count": row[3],
                            "preferences": json.loads(row[4]) if row[4] else {},
                            "metadata": json.loads(row[5]) if row[5] else {},
                        }
                    return None
        except Exception as e:
            logger.error(f"Failed to get user info: {e}")
            return None

    async def create_thread(
        self, channel_id: int, creator_id: int, title: Optional[str] = None
    ) -> Optional[int]:
        """Create a new thread and return its ID."""
        try:
            async with self.pool.acquire() as conn:
                async with conn.cursor() as cursor:
                    await cursor.execute(
                        """
                        INSERT INTO threads (channel_id, creator_id, title)
                        VALUES (?, ?, ?)
                        """,
                        (channel_id, creator_id, title),
                    )
                    await conn.commit()
                    return cursor.lastrowid
        except Exception as e:
            logger.error(f"Failed to create thread: {e}")
            return None

    async def get_thread_info(self, thread_id: int) -> Optional[Dict[str, Any]]:
        """Get thread information."""
        try:
            async with self.pool.acquire() as conn:
                async with conn.cursor() as cursor:
                    await cursor.execute(
                        """
                        SELECT t.*, u.username as creator_name
                        FROM threads t
                        JOIN users u ON t.creator_id = u.user_id
                        WHERE t.thread_id = ?
                        """,
                        (thread_id,),
                    )
                    row = await cursor.fetchone()
                    if row:
                        return {
                            "thread_id": row[0],
                            "channel_id": row[1],
                            "creator_id": row[2],
                            "created_at": row[3],
                            "last_activity": row[4],
                            "title": row[5],
                            "status": row[6],
                            "creator_name": row[7],
                        }
                    return None
        except Exception as e:
            logger.error(f"Failed to get thread info: {e}")
            return None

    async def update_thread_activity(self, thread_id: int) -> None:
        """Update a thread's last-activity timestamp."""
        try:
            async with self.pool.acquire() as conn:
                async with conn.cursor() as cursor:
                    await cursor.execute(
                        """
                        UPDATE threads
                        SET last_activity = CURRENT_TIMESTAMP
                        WHERE thread_id = ?
                        """,
                        (thread_id,),
                    )
                await conn.commit()
        except Exception as e:
            logger.error(f"Failed to update thread activity: {e}")

    async def cleanup_old_messages(self):
        """Clean up messages older than MESSAGE_CLEANUP_DAYS."""
        cleanup_date = datetime.now() - timedelta(days=MESSAGE_CLEANUP_DAYS)
        try:
            async with self.pool.acquire() as conn:
                async with conn.cursor() as cursor:
                    await cursor.execute(
                        "DELETE FROM messages WHERE timestamp < ?",
                        (cleanup_date.strftime("%Y-%m-%d %H:%M:%S"),),
                    )
                await conn.commit()
                logger.debug("Old messages cleaned up")
        except Exception as e:
            logger.error(f"Message cleanup failed: {e}")
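
A periodic background loop is one natural way to drive cleanup_old_messages; the sketch below is illustrative only (the interval and entry point are assumptions, not part of this commit):

import asyncio

async def cleanup_loop(db_manager, interval_hours: int = 24):
    """Hypothetical background task: prune old rows once a day."""
    while True:
        await db_manager.cleanup_old_messages()
        await asyncio.sleep(interval_hours * 3600)

# e.g. schedule alongside the bot: asyncio.create_task(cleanup_loop(db_manager))
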
BIN
discord_glhf/handlers/.DS_Store
vendored
Normal file
Binary file not shown.
13
discord_glhf/handlers/__init__.py
Normal file
@@ -0,0 +1,13 @@
"""Discord bot handlers package."""

from .message_handler import MessageHandler
from .image_handler import ImageHandler
from .tool_handler import ToolHandler
from .event_handler import EventHandler

__all__ = [
    "MessageHandler",
    "ImageHandler",
    "ToolHandler",
    "EventHandler",
]
BIN
discord_glhf/handlers/__pycache__/__init__.cpython-313.pyc
Normal file
Binary file not shown.
BIN
discord_glhf/handlers/__pycache__/event_handler.cpython-313.pyc
Normal file
Binary file not shown.
BIN
discord_glhf/handlers/__pycache__/image_handler.cpython-313.pyc
Normal file
Binary file not shown.
BIN
discord_glhf/handlers/__pycache__/tool_handler.cpython-313.pyc
Normal file
Binary file not shown.
589
discord_glhf/handlers/event_handler.py
Normal file
@@ -0,0 +1,589 @@
#!/usr/bin/env python3
"""Discord event handling logic."""

import re
import uuid
import asyncio
from typing import Dict, Any, Optional
from discord import Message, RawReactionActionEvent

from ..config import (
    logger, AUTO_RESPONSE_CHANNEL_ID, SYSTEM_PROMPT, BOT_OWNER_ID
)
from .message_handler import MessageHandler
from .image_handler import ImageHandler
from .tool_handler import ToolHandler


class EventHandler:
    """Handles Discord events and their processing."""

    def __init__(self, bot, queue_manager, db_manager, api_manager):
        self.bot = bot
        self.queue_manager = queue_manager
        self.db_manager = db_manager
        self.api_manager = api_manager
        self.message_handler = MessageHandler(db_manager)
        self.image_handler = ImageHandler(api_manager)
        self.tool_handler = ToolHandler(bot)

        # Set this handler as the queue's message processor
        self.queue_manager._message_handler = self._process_message

    def _clean_mentions(self, response: str, mention: str, display_name: str, username: str) -> str:
        """Clean up mentions in response text."""
        # Create a pattern that matches all possible mention formats
        patterns = [
            re.escape(display_name),    # Full display name
            f"@{re.escape(username)}",  # @username
            f"@<@{mention[2:-1]}>",     # Double mention format
            f"<@{mention[2:-1]}>",      # Raw mention format
        ]
        pattern = '|'.join(patterns)

        # Replace all mention formats with the proper mention
        return re.sub(pattern, mention, response)

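    # Behavior sketch (hypothetical values): with mention="<@42>",
    # display_name="Ada", username="ada",
    #     _clean_mentions("thanks @ada!", "<@42>", "Ada", "ada")
    # returns "thanks <@42>!" -- every textual variant collapses to the
    # canonical Discord mention so the ping actually fires.
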
    async def handle_reaction(self, payload: RawReactionActionEvent) -> None:
        """Handle reaction events on bot messages."""
        # Ignore our own reactions
        if payload.user_id == self.bot.user.id:
            return

        try:
            # Get the channel
            channel = self.bot.get_channel(payload.channel_id)
            if not channel:
                return

            # Get the message
            message = await channel.fetch_message(payload.message_id)
            if not message:
                return

            # Only respond to reactions on our own messages
            if message.author.id != self.bot.user.id:
                return

            # Get the user who reacted
            user = await self.bot.fetch_user(payload.user_id)
            if not user:
                return

            # Convert the emoji to a string representation
            emoji_str = str(payload.emoji)

            is_owner = user.id == BOT_OWNER_ID
            logger.info(
                "Reaction received (message %s): %s from %s%s",
                message.id,
                emoji_str,
                user.display_name,
                " [BOT OWNER]" if is_owner else ""
            )

            # Build user metadata
            user_metadata = {
                "user_id": str(user.id),
                "is_owner": is_owner,
                "name": user.name,
                "display_name": user.display_name,
            }

            # Build the reaction context
            reaction_context = (
                f"This user reacted with {emoji_str} to your message that "
                f"said: {message.content}\n\n"
                "You can respond however you want - with reactions, "
                "a message, or both. Stay in character and react naturally "
                "based on the emoji and your personality."
            )

            # Build context for the API call
            messages = [
                {
                    "role": "system",
                    "content": SYSTEM_PROMPT,
                    "metadata": {
                        "bot_owner_id": str(BOT_OWNER_ID),
                        "current_user": {
                            "user_id": str(user.id)
                        }
                    },
                },
                {
                    "role": "user",
                    "content": reaction_context,
                    "metadata": user_metadata,
                    "context": {
                        "timeout_env": "GLHF_TIMEOUT"  # Use primary API timeout for reactions
                    }
                },
            ]

            # Get a response from the API
            response = await self.api_manager.get_completion(messages)
            if not response:
                return

            # Parse tool calls and get the processed response
            tool_calls, final_response, mentioned_users = self.tool_handler.parse_tool_calls(
                response, message_id=message.id, channel_id=channel.id
            )

            # Execute tool calls
            for tool_name, args in tool_calls:
                try:
                    if tool_name == "find_user":
                        # Check whether we're trying to mention the user who reacted
                        if args["name"].lower() in [
                            user.name.lower(),
                            user.display_name.lower(),
                        ]:
                            mention = f"<@{user.id}>"
                        else:
                            mention = await self.tool_handler.find_user_by_name(
                                args["name"],
                                message.guild.id if message.guild else None,
                            )

                        if mention:
                            final_response = self._clean_mentions(
                                final_response,
                                mention,
                                user.display_name,
                                args["name"]
                            )

                    elif tool_name == "add_reaction":
                        await self.tool_handler.add_reaction(
                            message.id, channel.id, args["emoji"]
                        )

                    elif tool_name == "create_embed":
                        await self.tool_handler.create_embed(
                            channel=channel, content=args["content"]
                        )

                    elif tool_name == "create_thread":
                        await self.tool_handler.create_thread(
                            channel.id, args["name"], message.id
                        )

                except Exception as e:
                    logger.error(f"Error executing tool {tool_name}: {e}")

            # If there's any text response left, send it
            if final_response:
                logger.info(
                    f"Bot response to {user.display_name} ({user.name}#{user.discriminator})"
                    f"{' [BOT OWNER]' if is_owner else ''}'s reaction: {final_response}"
                )
                await self.message_handler.safe_send(
                    channel, final_response, reference=message
                )

        except Exception as e:
            logger.error(f"Error handling reaction: {e}")

    async def handle_message(self, message: Message) -> None:
        """Handle incoming messages."""
        # Ignore our own messages
        if message.author == self.bot.user:
            return

        try:
            # Only respond in the configured channel or its threads
            if (message.channel.id != AUTO_RESPONSE_CHANNEL_ID and
                    (not hasattr(message.channel, 'parent_id') or
                     message.channel.parent_id != AUTO_RESPONSE_CHANNEL_ID)):
                return

            # Early duplicate checks before any processing
            if any(item.message.id == message.id for item in self.queue_manager.message_queue.processing):
                logger.debug(f"Message {message.id} already in processing, skipping")
                return

            # Get the current queue size
            queue_size = self.queue_manager.message_queue.queue.qsize()
            logger.debug(f"Current queue size: {queue_size}")

            # Check whether the message was already processed
            message_processed = await self.db_manager.get_conversation_history(
                user_id=0,
                channel_id=message.channel.id,
                message_id=message.id
            )
            if message_processed:
                logger.debug(f"Message {message.id} already processed, skipping")
                return

            # Check for duplicate content in history
            recent_history = await self.db_manager.get_conversation_history(
                user_id=0,
                channel_id=message.channel.id
            )
            current_content = f"{message.author.display_name} ({message.author.name}) (<@{message.author.id}>): {message.content}"
            for hist_msg in recent_history:
                hist_content = hist_msg.get("content", {}).get("content", "") if isinstance(
                    hist_msg.get("content"), dict) else hist_msg.get("content", "")
                if hist_content == current_content:
                    logger.debug(f"Duplicate message content detected for message {message.id}, skipping")
                    return

            # Update user activity in the database
            await self.message_handler.update_user_activity(
                message.author.id, message.author.name
            )

            # If we get here, the message is not a duplicate
            is_owner = message.author.id == BOT_OWNER_ID
            logger.info(
                f"Channel message from {message.author.display_name} ({message.author.name}#{message.author.discriminator})"
                f"{' [BOT OWNER]' if is_owner else ''}: {message.content}"
            )

            # Process message content and handle images
            prompt = message.content
            enhanced_prompt, has_images = await self.image_handler.process_message_with_images(message, prompt)
            if has_images:
                logger.info("Enhanced prompt with image analysis")
                prompt = enhanced_prompt

            # Get message history for context
            history = await self.db_manager.get_conversation_history(
                user_id=0,
                channel_id=message.channel.id,
            )
            logger.debug(f"Retrieved {len(history)} messages for context")

            # Build context
            context = {
                "history": history,
                "bot_info": {
                    "name": self.bot.user.name,
                    "display_name": self.bot.user.display_name,
                    "id": str(self.bot.user.id),
                },
                "user_info": {
                    "name": message.author.name,
                    "display_name": message.author.display_name,
                    "id": str(message.author.id),
                },
                "timeout_env": "GLHF_TIMEOUT"  # Use primary API timeout for all messages
            }

            logger.info(f"Adding message {message.id} to queue")
            await self.queue_manager.add_message(
                channel=message.channel,
                message=message,
                prompt=prompt,
                context=context,
                priority=2,  # Default priority
            )

        except Exception as e:
            logger.error(f"Error handling message: {e}")
            await self.report_error(
                e,
                {
                    "action": "handle_message",
                    "channel_id": message.channel.id,
                    "message_id": message.id,
                    "content": message.content,
                },
            )

    async def _process_message(self, item: Any) -> None:
        """Process a message from the queue."""
        try:
            # Start the typing indicator
            async with item.channel.typing():
                # Get fresh conversation history first
                history = await self.db_manager.get_conversation_history(
                    user_id=0,
                    channel_id=item.channel.id,
                )
                logger.debug(f"Retrieved {len(history)} messages for context")

                # Generate a message UUID upfront
                message_uuid = str(uuid.uuid4())

                # Format the message with user info
                formatted_content = f"{item.message.author.display_name} ({item.message.author.name}) (<@{item.message.author.id}>): {item.prompt}"

                # Check whether this message is already in history
                message_in_history = False
                stored_uuid = None

                for hist_msg in history:
                    hist_content = hist_msg.get("content", {}).get("content", "") if isinstance(
                        hist_msg.get("content"), dict) else hist_msg.get("content", "")
                    if hist_content == formatted_content:
                        message_in_history = True
                        if isinstance(hist_msg.get("metadata"), dict):
                            stored_uuid = hist_msg["metadata"].get("message_uuid")
                        break

                # Use the stored UUID if found, otherwise the new UUID
                parent_uuid = stored_uuid if stored_uuid else message_uuid

                if not message_in_history:
                    # Only store if not already in history
                    message_metadata = {
                        **item.context,
                        "message_uuid": message_uuid,
                        "user_info": {
                            "name": item.message.author.name,
                            "display_name": item.message.author.display_name,
                            "id": str(item.message.author.id)
                        }
                    }
                    await self.store_message(
                        user_id=item.message.author.id,
                        role="user",
                        content={"content": formatted_content, "metadata": message_metadata},
                        channel_id=item.channel.id,
                        message_uuid=message_uuid,
                    )

                # Build the messages array with system prompt and history.
                # Extra tool-usage instructions are appended to the system
                # prompt; intentionally left empty in this commit.
                tool_instruction = """
"""

                # Add the system message with proper metadata structure
                user_info = item.context.get("user_info", {})
                system_message = {
                    "role": "system",
                    "content": SYSTEM_PROMPT + tool_instruction,
                    "metadata": {
                        "bot_owner_id": str(BOT_OWNER_ID),
                        "current_user": {
                            "user_id": str(user_info.get("id", "0"))
                        }
                    }
                }

                messages = [system_message]
                messages.extend(history)

                # Always add the current message to the API call,
                # with timeout_env added to the context
                api_context = {
                    "timeout_env": item.context.get("timeout_env"),
                    "image_urls": item.context.get("image_urls")
                }
                current_message = await self.message_handler.build_api_message(
                    formatted_content,
                    item.context.get("image_urls"),
                    None,
                    self.tool_handler.mentioned_users,
                    item.message.author.id,
                    context=api_context
                )
                messages.append(current_message)

                try:
                    # Process the response
                    response = await self.api_manager.get_completion(messages)
                    if not response:
                        return

                    # Parse tool calls and get the processed response
                    tool_calls, final_response, mentioned_users = self.tool_handler.parse_tool_calls(
                        response, message_id=item.message.id, channel_id=item.channel.id
                    )
                except Exception as e:
                    logger.error(f"Error getting API completion: {e}")
                    return

                # Execute tool calls
                for tool_name, args in tool_calls:
                    try:
                        if tool_name == "find_user":
                            # Check whether we're trying to mention the message author
                            if args["name"].lower() in [
                                item.message.author.name.lower(),
                                item.message.author.display_name.lower(),
                            ]:
                                mention = f"<@{item.message.author.id}>"
                            else:
                                mention = await self.tool_handler.find_user_by_name(
                                    args["name"],
                                    item.message.guild.id if item.message.guild else None,
                                )

                            if mention:
                                final_response = self._clean_mentions(
                                    final_response,
                                    mention,
                                    item.message.author.display_name,
                                    args["name"]
                                )

                        elif tool_name == "add_reaction":
                            await self.tool_handler.add_reaction(
                                item.message.id, item.channel.id, args["emoji"]
                            )

                        elif tool_name == "create_embed":
                            await self.tool_handler.create_embed(
                                channel=item.channel, content=args["content"]
                            )

                        elif tool_name == "create_thread":
                            # Create the Discord thread first
                            discord_thread = await self.tool_handler.create_thread(
                                item.channel.id, args["name"], item.message.id
                            )

                            if discord_thread:
                                # Create the thread in the database and update activity
                                thread_id = await self.message_handler.create_thread_for_message(
                                    item.channel.id, item.message.author.id, args["name"]
                                )
                                if thread_id:
                                    await self.db_manager.update_thread_activity(thread_id)
                                    logger.info(f"Created and stored thread '{args['name']}' in database")

                    except Exception as e:
                        logger.error(f"Error executing tool {tool_name}: {e}")

                # Send the response
                if final_response:
                    author = item.message.author
                    owner_tag = ' [BOT OWNER]' if int(author.id) == BOT_OWNER_ID else ''
                    logger.info(
                        f"Bot response to {author.display_name} "
                        f"({author.name}#{author.discriminator})"
                        f"{owner_tag}: {final_response}"
                    )
                    sent_message = await self.message_handler.safe_send(
                        item.channel, final_response, reference=item.message
                    )

                    if sent_message:
                        # Store the complete response in the database
                        response_uuid = str(uuid.uuid4())

                        # Wait a moment for all messages to be sent
                        await asyncio.sleep(1)

                        # Get all messages sent in this response
                        messages = []
                        async for message in item.channel.history(limit=20, oldest_first=True, after=sent_message.created_at):
                            if message.author == self.bot.user:
                                messages.append(message)

                        # Build the complete response
                        complete_response = [sent_message.content]
                        for message in messages:
                            if message.content.startswith("⤷ "):
                                complete_response.append(message.content[2:])
                            else:
                                complete_response.append(message.content)

                        # Combine all parts into the final response
                        final_response = "\n".join(complete_response)

                        # Log the message collection
                        logger.debug(
                            f"Collected {len(messages) + 1} messages for response:\n"
                            f"Initial: {sent_message.id}\n"
                            f"Continuations: {[m.id for m in messages]}"
                        )

                        response_metadata = {
                            "response_id": response_uuid,
                            "user_info": {
                                "name": "bot",
                                "display_name": self.bot.user.display_name,
                                "id": str(self.bot.user.id)
                            },
                            "complete_response": True,
                            "discord_message_id": str(sent_message.id),
                            "continuation_messages": len(messages) > 0,
                            "message_count": len(messages) + 1,
                            "message_ids": [str(sent_message.id)] + [str(m.id) for m in messages]
                        }

                        # Store the complete response
                        await self.store_message(
                            user_id=item.message.author.id,
                            role="assistant",
                            content={
                                "content": final_response,
                                "metadata": response_metadata,
                            },
                            channel_id=item.channel.id,
                            message_uuid=response_uuid,
                            parent_uuid=parent_uuid,
                        )

        except Exception as e:
            logger.error(f"Error processing message: {e}")
            await self.report_error(
                e,
                {
                    "action": "process_message",
                    "message_id": item.message.id,
                    "channel_id": item.channel.id,
                    "content": item.prompt,
                },
            )

    async def store_message(
        self,
        user_id: int,
        role: str,
        content: Dict[str, Any],
        channel_id: int,
        message_uuid: str,
        parent_uuid: Optional[str] = None,
    ) -> None:
        """Store a message in the database."""
        try:
            await self.db_manager.store_message(
                user_id=user_id,
                role=role,
                content=content,
                channel_id=channel_id,
                message_uuid=message_uuid,
                parent_uuid=parent_uuid,
            )
        except Exception as e:
            logger.error(f"Failed to store message: {e}")

    async def report_error(self, error: Exception, context: Dict[str, Any]) -> None:
        """Report an error with context."""
        error_msg = self.format_error_message(error, context)
        logger.error(error_msg)

    @staticmethod
    def format_error_message(error: Exception, context: Optional[Dict] = None) -> str:
        """Format an error message with context."""
        import traceback

        error_details = [
            "🚨 **Bot Error Report**",
            f"**Error Type:** {type(error).__name__}",
            f"**Error Message:** {str(error)}",
            "",
            "**Context:**",
        ]

        if context:
            for key, value in context.items():
                error_details.append(f"- {key}: {value}")

        error_details.extend(
            [
                "",
                "**Traceback:**",
                "```python",
                "".join(traceback.format_tb(error.__traceback__)),
                "```",
            ]
        )

        return "\n".join(error_details)
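
The handler above is driven from the bot's Discord event hooks. A minimal wiring sketch follows; it is illustrative only, assumes discord.py 2.x, and every name other than EventHandler is hypothetical (the actual DiscordBot class lives in bot.py):

import discord

intents = discord.Intents.default()
intents.message_content = True

class SketchBot(discord.Client):
    """Hypothetical client that forwards events to an EventHandler."""

    def __init__(self, handler_factory, **kwargs):
        super().__init__(intents=intents, **kwargs)
        self.handler_factory = handler_factory
        self.events = None

    async def on_ready(self):
        # Build the handler once the bot user is known
        self.events = self.handler_factory(self)

    async def on_message(self, message: discord.Message):
        if self.events:
            await self.events.handle_message(message)

    async def on_raw_reaction_add(self, payload: discord.RawReactionActionEvent):
        if self.events:
            await self.events.handle_reaction(payload)
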
112
discord_glhf/handlers/image_handler.py
Normal file
@@ -0,0 +1,112 @@
#!/usr/bin/env python3
"""Image detection and processing logic."""

from typing import List, Dict, Any, Optional, Tuple
from discord import Message

from ..config import logger, VISION_MODEL, VISION_API_KEY, VISION_API_BASE_URL

# Image extensions accepted from attachments and raw URLs
IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".gif", ".webp")


class ImageHandler:
    """Handles image detection and processing."""

    def __init__(self, api_manager=None):
        self.api_manager = api_manager

    @staticmethod
    def detect_images(message: Message) -> List[str]:
        """Detect images in a message from attachments and URLs."""
        image_urls = []

        # Check attachments
        for attachment in message.attachments:
            if attachment.content_type and attachment.content_type.startswith("image/"):
                image_urls.append(attachment.url)

        # Check URLs in the message content
        words = message.content.split()
        for word in words:
            if word.startswith(("http://", "https://")):
                if word.lower().endswith(IMAGE_EXTENSIONS):
                    image_urls.append(word)

        return image_urls

    @staticmethod
    def is_valid_image_url(url: str) -> bool:
        """Check if a URL points to a valid image."""
        return url.lower().endswith(IMAGE_EXTENSIONS)

    async def analyze_images(self, prompt: str, image_urls: List[str]) -> Optional[str]:
        """Analyze images using the vision API and return the analysis."""
        if not self.api_manager or not image_urls:
            return None

        try:
            # Format messages for vision analysis
            messages = [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "What do you see in these images?"},
                        *[
                            {
                                "type": "image_url",
                                "image_url": {"url": url, "detail": "auto"},
                            }
                            for url in image_urls
                        ],
                    ],
                }
            ]

            logger.debug(f"Vision request messages: {messages}")

            # Get the vision analysis
            analysis = await self.api_manager._make_api_call("vision", messages)
            logger.debug(f"Vision API response: {analysis}")

            if not analysis or not analysis[1]:
                logger.error(f"Failed to analyze images. Response: {analysis}")
                return None

            return analysis[1]

        except Exception as e:
            logger.error(f"Error analyzing images: {e}")
            logger.error(
                f"Vision API config - Model: {VISION_MODEL}, Base URL: {VISION_API_BASE_URL}"
            )
            logger.error(f"Image URLs being processed: {image_urls}")
            return None

    def format_prompt_with_analysis(self, prompt: str, analysis: Optional[str]) -> str:
        """Format the prompt with image analysis for the main conversation."""
        if not analysis:
            return prompt

        # Append the image analysis to the prompt
        return f"{prompt}\n\n[Image Analysis: {analysis}]"

    async def process_message_with_images(
        self, message: Message, prompt: str
    ) -> Tuple[str, bool]:
        """Process a message with images, returning the enhanced prompt and whether images were processed."""
        image_urls = self.detect_images(message)
        if not image_urls:
            return prompt, False

        # First analyze the images
        analysis = await self.analyze_images(prompt, image_urls)
        if not analysis:
            return prompt, False

        # Then include the analysis in the prompt
        enhanced_prompt = self.format_prompt_with_analysis(prompt, analysis)
        return enhanced_prompt, True
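
End to end, the image path is: detect URLs, ask the vision endpoint for a description, then splice that description into the text prompt. A hedged sketch of the flow (api_manager is the APIManager from api.py; the function name is hypothetical):

# Illustrative only -- message is a discord.Message.
async def describe_attachments(api_manager, message, prompt: str) -> str:
    handler = ImageHandler(api_manager)
    enhanced, had_images = await handler.process_message_with_images(message, prompt)
    # With no images (or a failed analysis) the original prompt comes back.
    return enhanced if had_images else prompt
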
337
discord_glhf/handlers/message_handler.py
Normal file
@@ -0,0 +1,337 @@
#!/usr/bin/env python3
"""Message handling and processing logic."""

import asyncio
from typing import Dict, List, Any, Optional
from datetime import datetime
from discord import Message, TextChannel

from ..config import logger


class MessageHandler:
    """Handles message processing and formatting."""

    def __init__(self, db_manager):
        """Initialize with a database manager."""
        self.db = db_manager

    async def update_user_activity(self, user_id: int, username: str) -> None:
        """Update user activity in the database."""
        await self.db.update_user(user_id, username)

    async def create_thread_for_message(
        self, channel_id: int, user_id: int, title: Optional[str] = None
    ) -> Optional[int]:
        """Create a new thread for the message."""
        # Ensure the user exists in the database
        user_info = await self.db.get_user_info(user_id)
        if not user_info:
            logger.warning(f"User {user_id} not found in database")
            return None

        # Create the thread
        thread_id = await self.db.create_thread(channel_id, user_id, title)
        if thread_id:
            logger.info(f"Created thread {thread_id} for user {user_id}")
        return thread_id

    @staticmethod
    def format_message_preview(
        role: str,
        content: str,
        timestamp: Optional[str] = None,
        display_name: Optional[str] = None,
    ) -> str:
        """Format a message preview with timestamp and user display name."""
        ts = ""
        if timestamp:
            try:
                dt = datetime.fromisoformat(timestamp)
                ts = f"[{dt.strftime('%H:%M:%S')}] "
            except (ValueError, TypeError):
                pass

        # Show up to 100 chars of content, preserving word boundaries
        if len(content) > 100:
            words = content[:100].split()
            preview = " ".join(words[:-1]) + "..."
        else:
            preview = content

        # For bot messages, just use "bot";
        # for user messages, use their display name if available
        name = "bot" if role == "assistant" else (display_name or "user")

        return f"{ts}{name}: {preview}"

    async def safe_send(
        self,
        channel: TextChannel,
        content: str,
        reference: Optional[Message] = None,
        stream: bool = True
    ) -> Optional[Message]:
        """Safely send a message to a channel with proper error handling."""
        try:
            from discord import AllowedMentions

            # Ensure mentions are properly handled
            allowed_mentions = AllowedMentions(users=True, roles=False, everyone=False)

            if stream:
                # For streaming, send the first chunk immediately and follow up
                # with continuation messages instead of a "Processing..." placeholder
                chunks = []
                current_chunk = ""
                code_block = False

                for line in content.split('\n'):
                    # Handle code block fences
                    if line.startswith('```'):
                        current_chunk += line + '\n'
                        code_block = not code_block
                    else:
                        # Regular line
                        if len(current_chunk + line + '\n') > 1900:
                            chunks.append(current_chunk)
                            current_chunk = line + '\n'
                        else:
                            current_chunk += line + '\n'

                if current_chunk:
                    chunks.append(current_chunk)

                # Send the first chunk immediately
                first_chunk = chunks[0] if chunks else content
                initial_message = await channel.send(
                    first_chunk,
                    reference=reference,
                    allowed_mentions=allowed_mentions,
                    mention_author=True
                )

                # Process the remaining chunks
                for i, chunk in enumerate(chunks[1:], 1):
                    try:
                        # Send each chunk as a new message
                        await channel.send(
                            f"⤷ {chunk}",  # Add continuation marker
                            allowed_mentions=allowed_mentions
                        )
                        # Small delay to maintain message order
                        await asyncio.sleep(0.05)
                    except Exception as e:
                        logger.error(f"Failed to send chunk {i}: {e}")

                return initial_message
            else:
                # Handle non-streaming long messages by splitting them intelligently
                if len(content) > 2000:
                    # Split into chunks while preserving code blocks and mentions
                    chunks = []
                    current_chunk = ""
                    code_block = False
                    code_lang = ""

                    lines = content.split('\n')

                    for line in lines:
                        # Handle code blocks
                        if line.startswith('```'):
                            if code_block:
                                # End of a code block
                                if len(current_chunk + line + '\n') > 1900:
                                    # If adding the closing ``` would exceed the limit,
                                    # close the block in this chunk and start a new one
                                    chunks.append(current_chunk + '```')
                                    current_chunk = f'```{code_lang}\n{line[3:]}\n'
                                else:
                                    current_chunk += line + '\n'
                                code_block = False
                            else:
                                # Start of a code block
                                code_block = True
                                code_lang = line[3:].strip()
                                if current_chunk:
                                    chunks.append(current_chunk)
                                current_chunk = line + '\n'
                        else:
                            # Regular line
                            if len(current_chunk + line + '\n') > 1900:
                                if code_block:
                                    # Close the code block in this chunk
                                    chunks.append(current_chunk + '```')
                                    current_chunk = f'```{code_lang}\n{line}\n'
                                else:
                                    chunks.append(current_chunk)
                                    current_chunk = line + '\n'
                            else:
                                current_chunk += line + '\n'

                    # Add the final chunk
                    if current_chunk:
                        if code_block:
                            current_chunk += '```'
                        chunks.append(current_chunk)

                    # Send chunks with improved formatting
                    first_message = None
                    total_chunks = len(chunks)

                    for i, chunk in enumerate(chunks):
                        # Add clear continuation markers
                        if total_chunks > 1:
                            marker = f"[Part {i+1}/{total_chunks}]\n"
                            if i > 0:
                                marker += "⤷ "  # Add continuation indicator
                        else:
                            marker = ""

                        try:
                            # Ensure we're not exceeding Discord's limit
                            final_content = marker + chunk
                            if len(final_content) > 2000:
                                final_content = final_content[:1997] + "..."

                            # Send with proper reference handling
                            msg = await channel.send(
                                final_content,
                                reference=reference if i == 0 else None,
                                allowed_mentions=allowed_mentions,
                                mention_author=True if i == 0 else False,
                            )

                            # Store the first message for reference
                            if i == 0:
                                first_message = msg

                            # Minimal delay to prevent rate limiting
                            if i < total_chunks - 1:
                                await asyncio.sleep(0.05)

                        except Exception as e:
                            logger.error(
                                f"Failed to send message chunk {i+1}/{total_chunks}: {e}"
                            )
                            logger.error(f"Chunk content length: {len(chunk)}")

                    return first_message
                else:
                    # Send a short message normally
                    return await channel.send(
                        content,
                        reference=reference,
                        allowed_mentions=allowed_mentions,
                        mention_author=True,
                    )
        except Exception as e:
            logger.error(f"Failed to send message: {e}")
            return None

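    # Chunking sketch: Discord caps messages at 2000 chars, so safe_send splits
    # on line boundaries with a 1900-char soft limit to leave headroom for the
    # "[Part i/n]" / "⤷ " markers, and re-opens ``` fences across chunks so
    # code blocks stay rendered. E.g. a 4500-char reply becomes roughly three
    # messages: the first sent as a reply, the rest tagged as continuations.
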
    async def build_api_message(
        self, content: str, image_urls: List[str] = None, metadata: Dict[str, Any] = None,
        mentioned_users: Dict[str, str] = None, user_id: Optional[int] = None,
        context: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """Build a properly formatted message for API consumption with user context."""
        message = {"role": "user"}

        # Initialize the context if not provided
        message_context = context.copy() if context else {}

        # Add metadata if provided
        if metadata:
            message_context["metadata"] = metadata

        # Add mentioned users to metadata if available
        if mentioned_users:
            # Guard against a missing metadata key so this also works
            # when no metadata argument was passed
            message_context.setdefault("metadata", {})
            if "user_info" not in message_context["metadata"]:
                message_context["metadata"]["user_info"] = {}
            message_context["metadata"]["user_info"]["mentioned_users"] = mentioned_users

        # Add user history and preferences if available
        if user_id:
            user_info = await self.db.get_user_info(user_id)
            if user_info:
                if "metadata" not in message_context:
                    message_context["metadata"] = {}
                message_context["metadata"]["user_history"] = {
                    "first_interaction": user_info["first_interaction"],
                    "last_interaction": user_info["last_interaction"],
                    "interaction_count": user_info["interaction_count"],
                    "preferences": user_info["preferences"]
                }

        # Add the context to the message
        if message_context:
            message["context"] = message_context

        # Handle content based on whether there are images
        if not image_urls:
            message["content"] = content
        else:
            message["content"] = [
                {"type": "text", "text": content},
                *[
                    {"type": "image_url", "image_url": {"url": url}}
                    for url in image_urls
                ],
            ]

        return message

    @staticmethod
    def extract_message_content(msg: Dict[str, Any]) -> str:
        """Extract content from a message object."""
        content = ""
        if isinstance(msg.get("content"), dict):
            content = msg["content"].get("content", "")
            # Prefix the username from metadata if available
            metadata = msg["content"].get("metadata", {})
            user_info = metadata.get("user_info", {})
            if "name" in user_info:
                content = f"{user_info['name']}: {content}"
        else:
            content = msg.get("content", "")
            # Prefix the username from metadata if available
            if "metadata" in msg and "user_info" in msg["metadata"]:
                user_info = msg["metadata"]["user_info"]
                if "name" in user_info:
                    content = f"{user_info['name']}: {content}"
        return content

    @staticmethod
    def build_message_context(
        history: List[Dict[str, Any]], current_message: Dict[str, Any]
    ) -> List[Dict[str, Any]]:
        """Build message context for API calls."""
        messages = []

        # Add history messages
        for msg in history:
            content = MessageHandler.extract_message_content(msg)
            # Preserve the original context if it exists
            message_data = {
                "role": msg["role"],
                "content": content
            }
            if "context" in msg:
                message_data["context"] = msg["context"]
            messages.append(message_data)

        # Add the current message with its context
        messages.append(current_message)

        return messages
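
Taken together, extract_message_content and build_message_context flatten stored rows into the plain role/content list the API expects. A small illustration with hypothetical data:

# Illustrative only -- the dicts mirror what get_conversation_history returns.
history = [
    {"role": "user",
     "content": {"content": "hi there",
                 "metadata": {"user_info": {"name": "ada"}}}},
    {"role": "assistant", "content": "hello!"},
]
current = {"role": "user", "content": "how are you?"}

messages = MessageHandler.build_message_context(history, current)
# -> [{'role': 'user', 'content': 'ada: hi there'},
#     {'role': 'assistant', 'content': 'hello!'},
#     {'role': 'user', 'content': 'how are you?'}]
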
329
discord_glhf/handlers/tool_handler.py
Normal file
@@ -0,0 +1,329 @@
#!/usr/bin/env python3
"""Tool execution and response handling logic."""

import re
from typing import Dict, Any, Optional, List, Tuple
from discord import TextChannel, Message, Thread, utils, Embed, PartialEmoji

from ..config import logger


class ToolHandler:
    """Handles tool execution and responses."""

    def __init__(self, bot):
        self.bot = bot
        self.mentioned_users = {}  # Map usernames to their mention formats

    async def find_user_by_name(
        self, name: str, guild_id: Optional[int] = None
    ) -> Optional[str]:
        """Find a user by their name and return their mention string."""
        try:
            if not name:
                logger.warning("Tool Use - Find User: Empty username provided")
                return None

            # Get the guild
            guild = None
            if guild_id:
                guild = self.bot.get_guild(guild_id)
                if not guild:
                    logger.error(f"Guild {guild_id} not found")
                    return None
            else:
                # Use the first available guild if none is specified
                guild = next(iter(self.bot.guilds), None)
                if not guild:
                    logger.error("No guilds available")
                    return None

            # Try an exact match first
            member = guild.get_member_named(name)
            if not member:
                # Fall back to a fuzzy match on username, display name, or nickname
                search_name = name.lower()
                for m in guild.members:
                    # Check the username first (most reliable)
                    if search_name in m.name.lower():
                        member = m
                        break
                    # Then check the display name
                    if search_name in m.display_name.lower():
                        member = m
                        break
                    # Finally check the nickname if it exists
                    if m.nick and search_name in m.nick.lower():
                        member = m
                        break

            if member:
                logger.info(f"Tool Use - Find User: Found user '{member.name}'")
                # Store the member ID for mention validation
                self.mentioned_users[name] = str(member.id)
                # Return the mention format
                mention = f"<@{member.id}>"
                logger.info(f"Generated mention: {mention}")
                return mention

            logger.warning(f"Tool Use - Find User: No match found for '{name}'")
            return None
        except Exception as e:
            logger.error(f"Tool Use - Find User: Error finding user: {e}")
            return None

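    # Result sketch (hypothetical IDs): find_user_by_name("ada") returns
    # "<@123456789012345678>" and records {"ada": "123456789012345678"} in
    # self.mentioned_users, which EventHandler._clean_mentions later uses to
    # normalize textual name-drops into real pings.
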
    async def add_reaction(self, message_id: int, channel_id: int, emoji: str) -> None:
        """Add a reaction to a message."""
        try:
            logger.info(
                f"Tool Use - Add Reaction: Adding '{emoji}' to message {message_id} in channel {channel_id}"
            )
            channel = self.bot.get_channel(channel_id)
            if not channel:
                logger.error(f"Channel {channel_id} not found")
                return

            message = await channel.fetch_message(message_id)
            if not message:
                logger.error(f"Message {message_id} not found")
                return

            # Validate and add the emoji
            try:
                # Custom emoji (<:name:id> or <a:name:id>)
                if emoji.startswith('<') and emoji.endswith('>'):
                    partial_emoji = PartialEmoji.from_str(emoji)
                    await message.add_reaction(partial_emoji)
                # Unicode emoji
                elif any(ord(c) in range(0x1F300, 0x1FAF6) for c in emoji):
                    await message.add_reaction(emoji)
                # Standard Discord emoji (:name:)
                elif emoji.startswith(':') and emoji.endswith(':'):
                    emoji_name = emoji.strip(':')
                    # Try to find the emoji in the guild's emoji list
                    guild_emoji = utils.get(message.guild.emojis, name=emoji_name)
                    if guild_emoji:
                        await message.add_reaction(guild_emoji)
                    else:
                        logger.warning(f"Tool Use - Add Reaction: Could not find emoji '{emoji_name}' in guild emojis")
                else:
                    logger.warning(f"Tool Use - Add Reaction: Invalid emoji format '{emoji}' - must be Unicode emoji, custom emoji (<:name:id>), or standard emoji (:name:)")
                    return

                logger.info(f"Tool Use - Add Reaction: Successfully added '{emoji}' to message {message_id}")
            except Exception as e:
                logger.error(f"Tool Use - Add Reaction: Failed to add reaction: {e}")
        except Exception as e:
            logger.error(f"Tool Use - Add Reaction: Failed to add reaction: {e}")

    async def create_embed(
        self, channel: TextChannel, content: str, color: int = 0xFF0000
    ) -> None:
        """Create and send a rich embed."""
        try:
            lines = content.strip().split("\n")
            if not lines:
                return

            # Extract title and description
            title = lines[0]
            description = "\n".join(lines[1:]) if len(lines) > 1 else ""

            logger.info(
                f"Tool Use - Create Embed: Creating embed in channel {channel.id}\n"
                + f"Title: {title}\n"
                + f"Description length: {len(description)} chars\n"
                + f"Fields: {description.count('-')} items"
            )

            # Create and send the embed
            embed = Embed(title=title, description=description, color=color)
            await channel.send(embed=embed)
            logger.info(
                f"Tool Use - Create Embed: Successfully sent embed to channel {channel.id}\n"
                + f"Title: {title}\n"
                + f"Description length: {len(description)} chars\n"
                + f"Fields: {description.count('-')} items"
            )
        except Exception as e:
            logger.error(f"Tool Use - Create Embed: Failed to create embed: {e}")


    async def create_thread(
        self, channel_id: int, name: str, message_id: Optional[int] = None
    ) -> Optional[Any]:
        """Create a new thread and return it, or None on failure."""
        try:
            logger.info(
                f"Tool Use - Create Thread: Creating thread '{name}' in channel {channel_id}"
            )
            channel = self.bot.get_channel(channel_id)
            if not channel:
                logger.error(f"Channel {channel_id} not found")
                return None

            try:
                # Clean and format thread name:
                # remove mentions and normalize whitespace
                formatted_name = re.sub(r'<@!?\d+>', '', name)  # Remove mentions
                formatted_name = re.sub(r'\s+', ' ', formatted_name)  # Normalize whitespace
                formatted_name = formatted_name.strip()

                # Extract the core topic by looking for common comparison/topic patterns
                if "vs" in formatted_name.lower() or "versus" in formatted_name.lower():
                    # Extract the comparison parts
                    parts = re.split(r'\bvs\.?\b|\bversus\b', formatted_name, flags=re.IGNORECASE)
                    if len(parts) >= 2:
                        formatted_name = f"{parts[0].strip()} vs {parts[1].strip()}"
                        if not formatted_name.lower().endswith(" debate"):
                            formatted_name += " debate"
                elif "overrated" in formatted_name.lower() or "underrated" in formatted_name.lower():
                    # Extract just the coaster name and rating type
                    match = re.search(r'(.*?)\b(over|under)rated\b', formatted_name, re.IGNORECASE)
                    if match:
                        formatted_name = f"{match.group(1).strip()} {match.group(2)}rated discussion"
                elif not any(formatted_name.lower().endswith(suffix) for suffix in ("discussion", "debate", "thread", "talk")):
                    formatted_name += " discussion"

                # Ensure name length is within Discord's 100-character limit
                if len(formatted_name) > 100:
                    formatted_name = formatted_name[:97] + "..."

                # Create the thread
                if message_id:
                    message = await channel.fetch_message(message_id)
                    if message:
                        thread = await message.create_thread(
                            name=formatted_name,
                            auto_archive_duration=1440,  # 24 hours
                        )
                        if thread:
                            logger.info(f"Tool Use - Create Thread: Successfully created thread '{formatted_name}' from message")
                            return thread
                else:
                    thread = await channel.create_thread(
                        name=formatted_name,
                        type=None,  # Let Discord choose based on channel type
                        auto_archive_duration=1440,  # 24 hours
                    )
                    if thread:
                        logger.info(f"Tool Use - Create Thread: Successfully created thread '{formatted_name}'")
                        return thread

                logger.error("Tool Use - Create Thread: Failed to create thread - no thread object returned")
                return None

            except Exception as e:
                logger.error(f"Tool Use - Create Thread: Failed to create thread: {e}")
                return None
        except Exception as e:
            logger.error(f"Tool Use - Create Thread: Failed to create thread: {e}")
            return None

    def parse_tool_calls(
        self,
        response: str,
        message_id: Optional[int] = None,
        channel_id: Optional[int] = None,
    ) -> Tuple[List[Tuple[str, Dict[str, Any]]], str, Dict[str, Any]]:
        """Parse all tool calls from a response string.

        Args:
            response: The response string to parse
            message_id: Optional ID of the message being responded to
            channel_id: Optional ID of the channel the message is in

        Returns:
            A tuple of (tool calls, cleaned response, mentioned users).
        """
        try:
            tool_calls = []
            lines = response.split("\n")
            response_lines = []
            i = 0

            while i < len(lines):
                line = lines[i].strip()
                if not line:
                    i += 1
                    continue

                # Convert line to lowercase for command checking
                line_lower = line.lower()
                command_found = False

                # Process tools based on natural language patterns.
                # Simple pattern matching covers basic tool usage;
                # the LLM decides most tool usage through natural language.
                if "@" in line:  # Basic mention support
                    for match in re.finditer(r"@(\w+(?:\s+\w+)*)", line):
                        name = match.group(1).strip()
                        tool_calls.append(("find_user", {"name": name}))
                        command_found = True

                # Thread creation patterns
                thread_patterns = [
                    (r'(?i)(.*?\bvs\.?\b.*?(?:debate|discussion)?)', r'\1 debate'),  # X vs Y
                    (r'(?i)(.*?\b(?:over|under)rated\b.*?(?:discussion)?)', r'\1 discussion'),  # overrated/underrated
                    (r'(?i)(.*?\b(?:safety|maintenance|review)\b.*?(?:discussion|thread)?)', r'\1 discussion'),  # Specific topics
                ]

                for pattern, name_format in thread_patterns:
                    if re.search(pattern, line_lower):
                        # Extract the topic from the line
                        match = re.search(pattern, line, re.IGNORECASE)
                        if match:
                            thread_name = re.sub(pattern, name_format, match.group(1))
                            tool_calls.append(("create_thread", {
                                "channel_id": channel_id,
                                "name": thread_name,
                                "message_id": message_id,
                            }))
                            command_found = True
                            break

                # Support emoji reactions (both explicit and from text)
                # 1. Match Unicode emojis, custom emojis, and standard Discord emojis
                emoji_pattern = r'([😀-🙏🌀-🗿]|<a?:[a-zA-Z0-9_]+:\d+>|:[a-zA-Z0-9_]+:)'
                emoji_matches = re.finditer(emoji_pattern, line)
                for match in emoji_matches:
                    emoji = match.group(1)
                    if emoji.strip():
                        tool_calls.append(("add_reaction", {
                            "emoji": emoji,
                            "message_id": message_id,
                            "channel_id": channel_id,
                        }))
                        command_found = True

                # 2. Also detect emoticons and convert them to emojis
                emoticon_map = {
                    r'(?:^|\s)[:;]-?[)D](?:\s|$)': '😊',  # :) ;) :-) :D
                    r'(?:^|\s)[:;]-?[(\[](?:\s|$)': '😢',  # :( ;( :-( :[
                    r'(?:^|\s)[:;]-?[pP](?:\s|$)': '😛',  # :p ;p :-p
                    r'(?:^|\s)[:;]-?[oO](?:\s|$)': '😮',  # :o ;o :-o
                    r'(?:^|\s)[xX][dD](?:\s|$)': '😂',  # xD XD
                }
                for pattern, emoji in emoticon_map.items():
                    if re.search(pattern, line):
                        tool_calls.append(("add_reaction", {
                            "emoji": emoji,
                            "message_id": message_id,
                            "channel_id": channel_id,
                        }))
                        command_found = True

                # Always include the line in the response, regardless of commands
                response_lines.append(line)

                i += 1

            # Join response lines, removing empty lines at start/end
            final_response = "\n".join(response_lines).strip()
            return tool_calls, final_response, self.mentioned_users

        except Exception as e:
            logger.error(f"Tool Use - Parse: Error parsing tool calls: {e}")
            # Return empty tool calls but preserve original response
            return [], response, {}
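
    # A minimal usage sketch (illustrative only, not part of this commit),
    # assuming `bot` is an instance exposing parse_tool_calls as above:
    #
    #   calls, text, mentioned = bot.parse_tool_calls(
    #       "Great ride! 🎢", message_id=123, channel_id=456
    #   )
    #   # calls -> [("add_reaction", {"emoji": "🎢", "message_id": 123, "channel_id": 456})]
    #   # text  -> "Great ride! 🎢"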
20
discord_glhf/main.py
Normal file
20
discord_glhf/main.py
Normal file
@@ -0,0 +1,20 @@
#!/usr/bin/env python3
import sys

from .bot import run_bot
from .config import logger


def main():
    try:
        logger.info("Starting Discord GLHF Bot...")
        run_bot()
    except KeyboardInterrupt:
        logger.info("Bot shutdown requested by user")
        sys.exit(0)
    except Exception as e:
        logger.critical(f"Fatal error: {e}", exc_info=True)
        sys.exit(1)


if __name__ == "__main__":
    main()
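
# Typically this entry point would be invoked as a module so the relative
# imports resolve (an assumption; exact invocation depends on packaging):
#
#   python -m discord_glhf.main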
94
discord_glhf/markov.py
Normal file
94
discord_glhf/markov.py
Normal file
@@ -0,0 +1,94 @@
import json
import random
from collections import defaultdict
from typing import Dict, List, Set, Tuple
import pickle
from pathlib import Path


class MarkovModel:
    """A Markov chain model for generating text based on training data."""

    def __init__(self, state_size: int = 2):
        self.state_size = state_size
        self.model: Dict[Tuple[str, ...], Dict[str, int]] = defaultdict(
            lambda: defaultdict(int)
        )
        self.start_states: List[Tuple[str, ...]] = []
        self.save_path = Path("data/markov_model.pkl")
        self._load_model()

    def _get_states(self, words: List[str]) -> List[Tuple[str, ...]]:
        """Convert a list of words into state tuples."""
        if len(words) < self.state_size:
            return []
        return list(zip(*[words[i:] for i in range(self.state_size)]))

    def train(self, text: str) -> None:
        """Train the model on a piece of text."""
        # Split text into words
        words = text.split()
        if len(words) <= self.state_size:
            return

        # Get states
        states = self._get_states(words)

        # Add first state as a possible start state
        if states:
            self.start_states.append(states[0])

        # Train model on word transitions
        for i, state in enumerate(states[:-1]):
            next_word = words[i + self.state_size]
            self.model[state][next_word] += 1

    def generate(self, max_words: int = 50) -> str:
        """Generate text using the trained model."""
        if not self.start_states:
            return ""

        # Start with a random starting state
        current_state = random.choice(self.start_states)
        result = list(current_state)

        for _ in range(max_words - self.state_size):
            if current_state not in self.model:
                break

            # Get possible next words and their counts
            next_words = self.model[current_state]
            if not next_words:
                break

            # Choose next word based on frequency
            words, counts = zip(*next_words.items())
            next_word = random.choices(words, weights=counts)[0]
            result.append(next_word)

            # Update state
            current_state = tuple(result[-self.state_size:])

        return " ".join(result)

    def save_model(self) -> None:
        """Save the model to disk."""
        self.save_path.parent.mkdir(parents=True, exist_ok=True)
        with open(self.save_path, 'wb') as f:
            pickle.dump({
                'model': dict(self.model),
                'start_states': self.start_states,
                'state_size': self.state_size,
            }, f)

    def _load_model(self) -> None:
        """Load the model from disk if it exists."""
        if self.save_path.exists():
            try:
                with open(self.save_path, 'rb') as f:
                    data = pickle.load(f)
                self.model = defaultdict(
                    lambda: defaultdict(int), data['model'])
                self.start_states = data['start_states']
                self.state_size = data['state_size']
            except Exception as e:
                print(f"Error loading model: {e}")
682
discord_glhf/queue.py
Normal file
682
discord_glhf/queue.py
Normal file
@@ -0,0 +1,682 @@
#!/usr/bin/env python3
import asyncio
import time
import html
import re
import uuid
from dataclasses import dataclass
from collections import defaultdict
import discord
from typing import Optional, Set, Dict, Any
import json

from .config import (
    logger,
    MAX_QUEUE_SIZE,
    CONCURRENT_TASKS,
    MAX_USER_QUEUED_MESSAGES,
    ShutdownError,
)


@dataclass(frozen=True)  # Make the class immutable and hashable
class QueueItem:
    channel: discord.TextChannel
    message: discord.Message
    prompt: str
    priority: int
    timestamp: float
    user_id: int
    context: Dict[str, Any]  # Store message context

    def __hash__(self):
        # Use message ID as part of hash since it's unique
        return hash((self.message.id, self.user_id, self.timestamp))

    @staticmethod
    def sanitize_prompt(prompt: str) -> str:
        """Sanitize the prompt content."""
        # Remove any control characters
        prompt = "".join(char for char in prompt if ord(char) >= 32 or char in "\n\t")

        # HTML escape to prevent injection
        prompt = html.escape(prompt)

        # Remove any potential Discord markdown exploits
        prompt = re.sub(
            r"(`{1,3}|~{2}|\|{2}|\*{1,3}|_{1,3})", lambda m: "\\" + m.group(0), prompt
        )

        return prompt.strip()

    @classmethod
    def create(
        cls,
        channel: discord.TextChannel,
        message: discord.Message,
        prompt: str,
        priority: int = 2,
        context: Optional[Dict[str, Any]] = None,
    ) -> "QueueItem":
        """Create a new QueueItem with sanitized content and preserved context."""
        sanitized_prompt = cls.sanitize_prompt(prompt)

        # Base message context
        message_context = {
            "message_id": str(message.id),
            "channel_id": str(channel.id),
            "guild_id": str(message.guild.id) if message.guild else None,
            "author": {
                "id": str(message.author.id),
                "name": message.author.name,
                "display_name": message.author.display_name,
                "nick": message.author.nick
                if hasattr(message.author, "nick")
                else None,
                "bot": message.author.bot,
                "discriminator": message.author.discriminator,
                "roles": [str(role.id) for role in message.author.roles]
                if hasattr(message.author, "roles")
                else [],
                "color": str(message.author.color)
                if hasattr(message.author, "color")
                else None,
            },
            "timestamp": message.created_at.isoformat() if message.created_at else None,
            "referenced_message": None,
        }

        # Add reference information if this is a reply
        if message.reference and message.reference.resolved:
            ref_msg = message.reference.resolved
            message_context["referenced_message"] = {
                "id": str(ref_msg.id),
                "content": cls.sanitize_prompt(ref_msg.content),
                "author": {
                    "id": str(ref_msg.author.id),
                    "name": ref_msg.author.name,
                    "display_name": ref_msg.author.display_name,
                    "nick": ref_msg.author.nick
                    if hasattr(ref_msg.author, "nick")
                    else None,
                    "bot": ref_msg.author.bot,
                    "discriminator": ref_msg.author.discriminator,
                    "roles": [str(role.id) for role in ref_msg.author.roles]
                    if hasattr(ref_msg.author, "roles")
                    else [],
                    "color": str(ref_msg.author.color)
                    if hasattr(ref_msg.author, "color")
                    else None,
                },
            }

        # Merge provided context with message context
        if context:
            message_context.update(context)

        return cls(
            channel=channel,
            message=message,
            prompt=sanitized_prompt,
            priority=priority,
            timestamp=time.time(),
            user_id=message.author.id,
            context=message_context,
        )
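
# A minimal sketch of what sanitize_prompt does (illustrative only):
#
#   QueueItem.sanitize_prompt("**hi** <script>")
#   # -> \**hi\** &lt;script&gt;   (markdown escaped, HTML entities encoded)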


class MessageQueue:
    def __init__(self):
        self.queue = asyncio.PriorityQueue()
        self.processing: Set[QueueItem] = set()
        self.user_queues = defaultdict(int)
        self._lock = asyncio.Lock()
        self.semaphore = asyncio.Semaphore(CONCURRENT_TASKS)
        self._cleanup_task: Optional[asyncio.Task] = None
        self._last_cleanup = time.time()
        self._active = True  # Track if queue is active

        # Initialize stats
        self._total_processed = 0
        self._failed_messages = 0

    async def start_cleanup(self):
        """Start the periodic cleanup task."""
        if not self._cleanup_task:
            self._cleanup_task = asyncio.create_task(self._cleanup_loop())

    async def stop_cleanup(self):
        """Stop the cleanup task."""
        if self._cleanup_task:
            self._cleanup_task.cancel()
            try:
                await self._cleanup_task
            except asyncio.CancelledError:
                pass
            self._cleanup_task = None

    async def _cleanup_loop(self):
        """Periodically clean up abandoned messages."""
        while True:
            try:
                await asyncio.sleep(300)  # Run every 5 minutes
                await self._cleanup_abandoned_messages()
            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"Error in cleanup loop: {e}")
                await asyncio.sleep(60)

    async def _cleanup_abandoned_messages(self):
        """Clean up messages that have been in the queue too long."""
        current_time = time.time()
        timeout = 3600  # 1 hour timeout

        try:
            async with self._lock:
                if not self._active:  # Don't clean up if queue is not active
                    return

                # Create a new queue with only valid items
                new_queue = asyncio.PriorityQueue()
                cleaned_count = 0
                retained_count = 0

                while not self.queue.empty():
                    try:
                        priority, timestamp, item = await self.queue.get()

                        # Skip items currently being processed
                        if item in self.processing:
                            await new_queue.put((priority, timestamp, item))
                            retained_count += 1
                            continue

                        # Keep recent messages
                        if current_time - timestamp < timeout:
                            await new_queue.put((priority, timestamp, item))
                            retained_count += 1
                        else:
                            self.user_queues[item.user_id] = max(
                                0, self.user_queues[item.user_id] - 1
                            )
                            cleaned_count += 1

                    except Exception as e:
                        logger.error(f"Error during queue cleanup: {e}")

                self.queue = new_queue
                self._last_cleanup = current_time

                if cleaned_count > 0:
                    logger.info(
                        f"Cleanup complete: Removed {cleaned_count} old messages, "
                        f"retained {retained_count} messages"
                    )

        except Exception as e:
            logger.error(f"Failed to clean up abandoned messages: {e}")

    async def put(self, item: QueueItem) -> None:
        """Add a message to the queue."""
        async with self._lock:
            if self.user_queues[item.user_id] >= MAX_USER_QUEUED_MESSAGES:
                logger.warning(f"User {item.user_id} has too many queued messages")
                return

            # More aggressive queue management
            if self.queue.qsize() >= MAX_QUEUE_SIZE:
                logger.warning("Queue is full, clearing old messages")
                # Clear queue except for most recent messages; use a separate
                # name so the dropped entry does not shadow the new item
                while self.queue.qsize() > MAX_QUEUE_SIZE // 2:
                    try:
                        _, _, dropped = self.queue.get_nowait()
                        self.user_queues[dropped.user_id] = max(
                            0, self.user_queues[dropped.user_id] - 1
                        )
                    except asyncio.QueueEmpty:
                        break

            self.user_queues[item.user_id] += 1
            # Store the raw timestamp so items of equal priority leave the
            # queue oldest-first (FIFO); this also keeps the age check in
            # _cleanup_abandoned_messages meaningful
            await self.queue.put((item.priority, item.timestamp, item))
            qsize = self.queue.qsize()
            logger.info(
                f"Added message to queue (size: {qsize}, user: {item.user_id}, priority: {item.priority})"
            )

            # Log queue status periodically
            if qsize > 0 and qsize % 5 == 0:
                logger.warning(f"Queue size has reached {qsize} messages")

    async def get(self) -> Optional[QueueItem]:
        """Get the next message from the queue."""
        try:
            _, _, item = await self.queue.get()
            self.processing.add(item)
            logger.debug(f"Retrieved message from queue for user {item.user_id}")
            return item
        except asyncio.CancelledError:
            raise
        except Exception as e:
            logger.error(f"Error getting item from queue: {e}")
            return None

    async def task_done(self, item: QueueItem) -> None:
        """Mark a task as completed."""
        async with self._lock:
            if item in self.processing:
                self.processing.discard(item)
                self.user_queues[item.user_id] = max(
                    0, self.user_queues[item.user_id] - 1
                )
                try:
                    self.queue.task_done()
                except ValueError:
                    logger.warning(
                        f"Task already marked as done for user {item.user_id}"
                    )
            else:
                logger.debug(f"Marked task as done for user {item.user_id}")

    async def clear(self) -> None:
        """Clear all items from the queue."""
        # asyncio.Lock is not reentrant, so acquire it only once
        async with self._lock:
            while not self.queue.empty():
                try:
                    self.queue.get_nowait()  # get_nowait is synchronous
                except asyncio.QueueEmpty:
                    break
            self.processing.clear()
            self.user_queues.clear()
            logger.debug("Queue cleared")


class QueueManager:
    def __init__(self, state_file: str = "queue_state.json"):
        self._shutting_down = asyncio.Event()
        self._active_tasks: Set[asyncio.Task] = set()
        self._lock = asyncio.Lock()
        self.message_queue = MessageQueue()
        self._queue_processor: Optional[asyncio.Task] = None
        self._health_check_task: Optional[asyncio.Task] = None
        self._message_handler = None  # Store the message handler

        # Initialize state management
        from .queue_state import QueueState

        self._state = QueueState(state_file)

        # Load state
        self._active = self._state.get_state("active")
        self._total_processed = self._state.get_state("total_processed")
        self._failed_messages = self._state.get_state("failed_messages")
        self._last_processed_time = self._state.get_state("last_processed_time")

    async def _health_check(self) -> None:
        """Monitor queue processor health and restart if needed."""
        logger.info("Queue health check started")
        check_interval = 60  # Check every minute
        restart_threshold = 300  # Restart if no processing for 5 minutes

        while not self.shutting_down and self._active:
            try:
                current_time = time.time()
                time_since_last_process = current_time - self._last_processed_time

                # Check if processor is running and processing messages
                if (
                    time_since_last_process > restart_threshold
                    or not self._queue_processor
                    or self._queue_processor.done()
                ):
                    logger.warning(
                        f"Queue appears stalled - "
                        f"Time since last process: {time_since_last_process:.1f}s, "
                        f"Processor running: {bool(self._queue_processor and not self._queue_processor.done())}"
                    )

                    # Restart the processor
                    if self._message_handler:
                        logger.info("Restarting queue processor")
                        if self._queue_processor:
                            self._queue_processor.cancel()
                            try:
                                await self._queue_processor
                            except asyncio.CancelledError:
                                pass
                        self._queue_processor = asyncio.create_task(
                            self._process_queue(self._message_handler)
                        )
                        await self.register_task(self._queue_processor)
                        self._last_processed_time = (
                            time.time()
                        )  # Reset timer after restart
                    else:
                        logger.error("Cannot restart queue - no message handler")

                await asyncio.sleep(check_interval)

            except asyncio.CancelledError:
                logger.info("Queue health check cancelled")
                break
            except Exception as e:
                logger.error(f"Error in queue health check: {e}")
                await asyncio.sleep(check_interval)

    @property
    def shutting_down(self) -> bool:
        return self._shutting_down.is_set()

    def set_shutting_down(self, value: bool = True) -> None:
        """Set the shutting down state."""
        if value:
            self._active = False
            self._shutting_down.set()
        else:
            self._active = True
            self._shutting_down.clear()

    @property
    def is_running(self) -> bool:
        """Check if the queue manager is running."""
        return bool(
            self._queue_processor
            and not self._queue_processor.done()
            and not self.shutting_down
        )

    async def start(self, message_handler) -> None:
        """Start the queue processor and cleanup tasks."""
        async with self._lock:
            self._message_handler = message_handler  # Store the handler
            self._active = True  # Reset active state
            self._shutting_down.clear()  # Clear shutdown flag

            # Cancel existing processor if running
            if self._queue_processor and not self._queue_processor.done():
                self._queue_processor.cancel()
                try:
                    await self._queue_processor
                except asyncio.CancelledError:
                    pass

            try:
                # Start cleanup task
                await self.message_queue.start_cleanup()

                # Start new processor
                self._queue_processor = asyncio.create_task(
                    self._process_queue(message_handler)
                )
                await self.register_task(self._queue_processor)
                logger.info("Queue processor started")

                # Add error callback with automatic restart
                def on_processor_done(future: asyncio.Future):
                    try:
                        future.result()  # Raise any exceptions
                    except asyncio.CancelledError:
                        pass
                    except Exception as e:
                        logger.error(f"Queue processor failed: {e}")
                        # Only attempt restart if not shutting down and still active
                        if not self.shutting_down and self._active:
                            logger.info("Automatically restarting queue processor...")
                            asyncio.create_task(self._restart_processor())

                self._queue_processor.add_done_callback(on_processor_done)
            except Exception as e:
                logger.error(f"Failed to start queue processor: {e}")
                self._active = False  # Mark as inactive on startup failure
                raise

    async def _restart_processor(self) -> None:
        """Restart the queue processor."""
        try:
            if self._message_handler and not self.shutting_down and self._active:
                logger.info("Restarting queue processor...")
                await self.start(self._message_handler)
        except Exception as e:
            logger.error(f"Failed to restart queue processor: {e}")

    async def stop(self) -> None:
        """Stop the queue processor and cleanup tasks."""
        logger.info("Stopping queue processor...")
        self._shutting_down.set()
        self.message_queue._active = (
            False  # Signal queue to stop accepting new messages
        )

        # Stop cleanup task
        await self.message_queue.stop_cleanup()

        # Cancel queue processor
        if self._queue_processor:
            self._queue_processor.cancel()
            try:
                await self._queue_processor
            except asyncio.CancelledError:
                pass

        # Clear remaining items
        await self.message_queue.clear()

        # Log final stats
        logger.info(
            f"Queue processor stopped. Final stats: "
            f"Processed: {self._total_processed}, "
            f"Failed: {self._failed_messages}, "
            f"Active tasks: {len(self._active_tasks)}"
        )

    async def register_task(self, task: asyncio.Task) -> None:
        """Register a task with the queue manager."""
        if self.shutting_down:
            raise ShutdownError("Queue manager is shutting down")
        async with self._lock:
            self._active_tasks.add(task)

            def cleanup_callback(t):
                try:
                    t.result()  # Get result to prevent "Task destroyed but it is pending" warnings
                except (asyncio.CancelledError, Exception):
                    pass
                finally:
                    asyncio.create_task(self._remove_task(t))

            task.add_done_callback(cleanup_callback)
            logger.debug(f"Task registered: {task}")

    async def _remove_task(self, task: asyncio.Task) -> None:
        """Remove a task from the active tasks set."""
        async with self._lock:
            self._active_tasks.discard(task)
            logger.debug(f"Task removed: {task}")

    async def wait_for_tasks(self, timeout: Optional[float] = None) -> None:
        """Wait for all active tasks to complete."""
        async with self._lock:
            if self._active_tasks:
                logger.debug(f"Waiting for tasks: {self._active_tasks}")
                await asyncio.wait(
                    list(self._active_tasks),
                    timeout=timeout,
                    return_when=asyncio.ALL_COMPLETED,
                )

    async def _process_queue(self, message_handler) -> None:
        """Process messages from the queue."""
        logger.info("Queue processor starting")
        last_heartbeat = time.time()
        heartbeat_interval = 60  # Log heartbeat every minute
        processor_id = str(uuid.uuid4())[:8]  # Unique ID for this processor instance

        while True:
            current_time = time.time()

            # Regular heartbeat logging
            if current_time - last_heartbeat >= heartbeat_interval:
                logger.info(
                    f"Queue processor {processor_id} heartbeat - "
                    f"Active: {self._active}, "
                    f"Shutting down: {self.shutting_down}, "
                    f"Queue size: {self.message_queue.queue.qsize()}, "
                    f"Processing: {len(self.message_queue.processing)}, "
                    f"Total processed: {self._total_processed}"
                )
                last_heartbeat = current_time

            if self.shutting_down or not self._active:
                logger.warning(
                    f"Queue processor {processor_id} stopping - "
                    f"Active: {self._active}, "
                    f"Shutting down: {self.shutting_down}"
                )
                break

            try:
                # Get next message from queue with timeout
                try:
                    async with asyncio.timeout(30):  # 30 second timeout on queue.get
                        item = await self.message_queue.get()
                        if not item:
                            logger.debug(
                                f"Queue processor {processor_id} - Empty item received"
                            )
                            continue
                except asyncio.TimeoutError:
                    logger.debug(
                        f"Queue processor {processor_id} - No messages for 30s"
                    )
                    continue
                except Exception as e:
                    logger.error(
                        f"Queue processor {processor_id} - Error getting message: {e}"
                    )
                    await asyncio.sleep(1)
                    continue

                logger.info(
                    f"Queue processor {processor_id} - Processing message: {item.prompt[:100]} "
                    f"(Queue size: {self.message_queue.queue.qsize()})"
                )

                processing_start = time.time()
                try:
                    # Process message with timeout; acquire the semaphore once
                    async with self.message_queue.semaphore, asyncio.timeout(120):  # 2 minute timeout for processing
                        logger.debug(
                            f"Queue processor {processor_id} - Starting message handler for {item.message.id}"
                        )
                        await message_handler(item)
                    self._total_processed += 1
                    processing_time = time.time() - processing_start
                    # Update last processed time after successful processing
                    self._state.update_processed_time()
                    logger.info(
                        f"Queue processor {processor_id} - Successfully processed message {item.message.id} "
                        f"in {processing_time:.2f}s "
                        f"(Total: {self._total_processed}, Queue: {self.message_queue.queue.qsize()})"
                    )
                except asyncio.TimeoutError:
                    self._failed_messages += 1
                    processing_time = time.time() - processing_start
                    logger.error(
                        f"Queue processor {processor_id} - Message {item.message.id} timed out "
                        f"after {processing_time:.2f}s - "
                        f"User: {item.user_id}, Channel: {item.channel.id}. "
                        f"Failed messages: {self._failed_messages}"
                    )
                except Exception as e:
                    self._failed_messages += 1
                    processing_time = time.time() - processing_start
                    logger.error(
                        f"Queue processor {processor_id} - Failed to process message {item.message.id} "
                        f"after {processing_time:.2f}s - "
                        f"User: {item.user_id}, Channel: {item.channel.id}, "
                        f"Error: {str(e)}, Failed: {self._failed_messages}",
                        exc_info=True,
                    )
                finally:
                    # Always mark task as done and log state; task_done is
                    # called only here so it runs exactly once per item
                    try:
                        await self.message_queue.task_done(item)
                        qsize = self.message_queue.queue.qsize()
                        processing_count = len(self.message_queue.processing)
                        logger.info(
                            f"Queue processor {processor_id} - Task completed. "
                            f"Queue size: {qsize}, Currently processing: {processing_count}"
                        )
                    except ValueError:
                        logger.warning(
                            f"Queue processor {processor_id} - "
                            f"Task already marked as done for message {item.message.id}"
                        )

            except asyncio.CancelledError:
                logger.info("Queue processor cancelled")
                break
            except Exception as e:
                logger.error(
                    f"Queue processor {processor_id} encountered error: {e} - "
                    f"Active: {self._active}, Shutting down: {self.shutting_down}, "
                    f"Queue size: {self.message_queue.queue.qsize()}, "
                    f"Processing: {len(self.message_queue.processing)}"
                )
                if self._active and not self.shutting_down:
                    await asyncio.sleep(1)  # Only sleep if we should continue
                else:
                    logger.warning(
                        f"Queue processor {processor_id} stopping due to inactive state"
                    )
                    break

        # Log final state when exiting the loop
        logger.warning(
            f"Queue processor {processor_id} exiting - "
            f"Processed: {self._total_processed}, Failed: {self._failed_messages}, "
            f"Final queue size: {self.message_queue.queue.qsize()}"
        )

    async def add_message(
        self,
        channel: discord.TextChannel,
        message: discord.Message,
        prompt: str,
        priority: int = 2,
        context: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Add a message to the queue with the specified priority."""
        if self.shutting_down:
            return

        try:
            item = QueueItem.create(
                channel=channel,
                message=message,
                prompt=prompt,
                priority=priority,
                context=context,
            )
            await self.message_queue.put(item)
            logger.info(f"Added message to queue with priority {priority}")

            # Ensure queue processor is running
            if not self._queue_processor or self._queue_processor.done():
                if not self._message_handler:
                    logger.error(
                        "Message handler not initialized, cannot restart queue processor"
                    )
                    return
                logger.warning("Queue processor not running, restarting...")
                self._queue_processor = asyncio.create_task(
                    self._process_queue(self._message_handler)
                )
                await self.register_task(self._queue_processor)

        except Exception as e:
            logger.error(f"Failed to add message to queue: {e}")
523
discord_glhf/queue_manager.py
Normal file
523
discord_glhf/queue_manager.py
Normal file
@@ -0,0 +1,523 @@
#!/usr/bin/env python3
"""Queue management with state persistence."""

import asyncio
import os
import time
import uuid
import html
import re
from typing import Optional, Set, Dict, Any
from collections import defaultdict
from dataclasses import dataclass
import discord

from .config import (
    logger,
    MAX_QUEUE_SIZE,
    CONCURRENT_TASKS,
    MAX_USER_QUEUED_MESSAGES,
    ShutdownError,
)
from .queue_state import QueueState


@dataclass(frozen=True)
class QueueItem:
    channel: discord.TextChannel
    message: discord.Message
    prompt: str
    priority: int
    timestamp: float
    user_id: int
    context: Dict[str, Any]

    def __hash__(self):
        return hash((self.message.id, self.user_id, self.timestamp))

    @staticmethod
    def sanitize_prompt(prompt: str) -> str:
        prompt = "".join(char for char in prompt if ord(char) >= 32 or char in "\n\t")
        prompt = html.escape(prompt)
        prompt = re.sub(
            r"(`{1,3}|~{2}|\|{2}|\*{1,3}|_{1,3})", lambda m: "\\" + m.group(0), prompt
        )
        return prompt.strip()

    @classmethod
    def create(
        cls,
        channel: discord.TextChannel,
        message: discord.Message,
        prompt: str,
        priority: int = 2,
        context: Optional[Dict[str, Any]] = None,
    ) -> "QueueItem":
        sanitized_prompt = cls.sanitize_prompt(prompt)
        message_context = {
            "message_id": str(message.id),
            "channel_id": str(channel.id),
            "guild_id": str(message.guild.id) if hasattr(message, 'guild') and message.guild else None,
            "author": {
                "id": str(message.author.id),
                "name": message.author.name,
                "display_name": message.author.display_name,
                "nick": message.author.nick
                if hasattr(message.author, "nick")
                else None,
                "bot": message.author.bot,
                "discriminator": message.author.discriminator,
                "roles": [str(role.id) for role in message.author.roles]
                if hasattr(message.author, "roles")
                else [],
                "color": str(message.author.color)
                if hasattr(message.author, "color")
                else None,
            },
            "timestamp": message.created_at.isoformat() if message.created_at else None,
            "referenced_message": None,
        }

        if message.reference and message.reference.resolved:
            ref_msg = message.reference.resolved
            message_context["referenced_message"] = {
                "id": str(ref_msg.id),
                "content": cls.sanitize_prompt(ref_msg.content),
                "author": {
                    "id": str(ref_msg.author.id),
                    "name": ref_msg.author.name,
                    "display_name": ref_msg.author.display_name,
                    "nick": ref_msg.author.nick
                    if hasattr(ref_msg.author, "nick")
                    else None,
                    "bot": ref_msg.author.bot,
                    "discriminator": ref_msg.author.discriminator,
                    "roles": [str(role.id) for role in ref_msg.author.roles]
                    if hasattr(ref_msg.author, "roles")
                    else [],
                    "color": str(ref_msg.author.color)
                    if hasattr(ref_msg.author, "color")
                    else None,
                },
            }

        if context:
            message_context.update(context)

        return cls(
            channel=channel,
            message=message,
            prompt=sanitized_prompt,
            priority=priority,
            timestamp=time.time(),
            user_id=message.author.id,
            context=message_context,
        )


class MessageQueue:
    def __init__(self, state: QueueState):
        self.queue = asyncio.PriorityQueue()
        self.processing: Set[Any] = set()
        self.user_queues = defaultdict(int)
        self._lock = asyncio.Lock()
        self.semaphore = asyncio.Semaphore(CONCURRENT_TASKS)
        self._cleanup_task: Optional[asyncio.Task] = None
        self._last_cleanup = time.time()
        self._active = True
        self._state = state
        # Persisted counts are keyed by string user IDs; convert back to int
        # so they line up with the runtime int keys used below
        self.user_queues.update(
            {int(k): v for k, v in (self._state.get_state("user_queues") or {}).items()}
        )

    async def put(self, item: Any) -> None:
        async with self._lock:
            if self.user_queues[item.user_id] >= MAX_USER_QUEUED_MESSAGES:
                logger.warning(f"User {item.user_id} has too many queued messages")
                return

            if self.queue.qsize() >= MAX_QUEUE_SIZE:
                logger.warning("Queue is full, clearing old messages")
                while self.queue.qsize() > MAX_QUEUE_SIZE // 2:
                    try:
                        # Use a separate name so the dropped entry does not
                        # shadow the item being added
                        _, _, dropped = self.queue.get_nowait()
                        self.user_queues[dropped.user_id] = max(
                            0, self.user_queues[dropped.user_id] - 1
                        )
                        self._state.update_user_queue(
                            str(dropped.user_id), self.user_queues[dropped.user_id]
                        )
                    except asyncio.QueueEmpty:
                        break

            self.user_queues[item.user_id] += 1
            self._state.update_user_queue(
                str(item.user_id), self.user_queues[item.user_id]
            )
            await self.queue.put((item.priority, item.timestamp, item))
            logger.info(
                f"Added message to queue (size: {self.queue.qsize()}, "
                f"user: {item.user_id}, priority: {item.priority})"
            )

    async def get(self) -> Optional[Any]:
        try:
            _, _, item = await self.queue.get()
            self.processing.add(item)
            logger.debug(f"Retrieved message from queue for user {item.user_id}")
            return item
        except asyncio.CancelledError:
            raise
        except Exception as e:
            logger.error(f"Error getting item from queue: {e}")
            return None

    async def task_done(self, item: Any) -> None:
        async with self._lock:
            if item in self.processing:
                self.processing.discard(item)
                self.user_queues[item.user_id] = max(
                    0, self.user_queues[item.user_id] - 1
                )
                self._state.update_user_queue(
                    str(item.user_id), self.user_queues[item.user_id]
                )
                try:
                    self.queue.task_done()
                except ValueError:
                    logger.warning(
                        f"Task already marked as done for user {item.user_id}"
                    )


class QueueManager:
    def __init__(self, state_file: str = "queue_state.json"):
        self._state = QueueState(state_file)
        self._shutting_down = asyncio.Event()
        self._active_tasks: Set[asyncio.Task] = set()
        self._lock = asyncio.Lock()
        self.message_queue = MessageQueue(self._state)
        self._queue_processor: Optional[asyncio.Task] = None
        self._health_check_task: Optional[asyncio.Task] = None
        self._message_handler = None
        self._active = self._state.get_state("active")
        self._total_processed = self._state.get_state("total_processed") or 0
        self._failed_messages = self._state.get_state("failed_messages") or 0

    async def _process_queue(self, message_handler) -> None:
        processor_id = str(uuid.uuid4())[:8]
        logger.info(f"Queue processor {processor_id} starting")

        self._state.set_processor_active(processor_id, True)
        last_heartbeat = time.time()
        heartbeat_interval = 60

        async def process_message(item):
            processing_start = time.time()
            try:
                # Add to pending before processing
                self._state.add_pending_message(
                    {
                        "message_id": str(item.message.id),
                        "channel_id": str(item.channel.id),
                        "user_id": str(item.user_id),
                        "prompt": item.prompt,
                        "context": item.context,
                        "timestamp": item.timestamp,
                    }
                )

                async with self.message_queue.semaphore:
                    # Resolve the processing timeout from the endpoint's
                    # environment variable named in the item context,
                    # falling back to 120 seconds
                    timeout_env = item.context.get('timeout_env', 'DEFAULT_TIMEOUT')
                    raw_timeout = os.getenv(timeout_env)
                    api_timeout = float(raw_timeout) if raw_timeout else 120.0

                    # Log detailed timeout information
                    logger.info(
                        f"Message {item.message.id} timeout details:\n"
                        f"- Timeout env var: {timeout_env}\n"
                        f"- Raw env value: {raw_timeout}\n"
                        f"- Final timeout: {api_timeout}s\n"
                        f"- Context dump: {item.context}"
                    )
                    await asyncio.wait_for(message_handler(item), timeout=api_timeout)
                    self._total_processed += 1
                    self._state.increment_counter("total_processed")
                    self._state.update_processed_time()

                    # Success - remove from pending
                    self._state.remove_pending_message(str(item.message.id))

                    processing_time = time.time() - processing_start
                    logger.info(
                        f"Queue processor {processor_id} - "
                        f"Processed message {item.message.id} in {processing_time:.2f}s "
                        f"(Total: {self._total_processed})"
                    )

            except asyncio.TimeoutError:
                self._failed_messages += 1
                self._state.increment_counter("failed_messages")
                logger.error(
                    f"Message {item.message.id} processing timed out "
                    f"after {time.time() - processing_start:.2f}s"
                )

            except Exception as e:
                error_msg = str(e)
                self._failed_messages += 1
                self._state.increment_counter("failed_messages")
                logger.error(f"Failed to process message: {error_msg}")

                if "Invalid response from" in error_msg:
                    logger.warning(f"API endpoints failed for message {item.message.id}")

            finally:
                # Mark task as done and remove from pending
                await self.message_queue.task_done(item)
                # Remove from pending if not already removed; compare against
                # the stored message IDs, not the raw list of dicts
                pending_ids = {
                    msg.get("message_id") for msg in self._state.get_pending_messages()
                }
                if str(item.message.id) in pending_ids:
                    self._state.remove_pending_message(str(item.message.id))

        while True:
            if self.shutting_down or not self._active:
                logger.warning(
                    f"Queue processor {processor_id} stopping - "
                    f"Active: {self._active}, Shutting down: {self.shutting_down}"
                )
                break

            try:
                current_time = time.time()
                if current_time - last_heartbeat >= heartbeat_interval:
                    state_info = self._state.get_processor_info()
                    logger.info(
                        f"Queue processor {processor_id} heartbeat - "
                        f"Active: {state_info['active']}, "
                        f"Queue size: {self.message_queue.queue.qsize()}, "
                        f"Processing: {len(self.message_queue.processing)}, "
                        f"Total processed: {self._total_processed}"
                    )
                    last_heartbeat = current_time
                    self._state.update_processed_time()

                try:
                    item = await asyncio.wait_for(self.message_queue.get(), timeout=30)
                    if not item:
                        continue
                except asyncio.TimeoutError:
                    continue
                except Exception as e:
                    logger.error(f"Error getting message from queue: {e}")
                    await asyncio.sleep(1)
                    continue

                # Process message directly instead of creating a task
                await process_message(item)

            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"Queue processor error: {e}")
                if self._active and not self.shutting_down:
                    await asyncio.sleep(1)
                else:
                    break

        self._state.set_processor_active(processor_id, False)
        logger.warning(
            f"Queue processor {processor_id} exited - "
            f"Processed: {self._total_processed}, Failed: {self._failed_messages}"
        )

    async def _restore_pending_messages(self) -> None:
        pending_messages = self._state.get_pending_messages()
        if pending_messages:
            logger.info(f"Restoring {len(pending_messages)} pending messages")
            for msg_data in list(pending_messages):
                try:
                    # The original Discord objects cannot be rebuilt from the
                    # saved data, so drop the entry from the pending list
                    logger.info(f"Removing unrestorable message {msg_data.get('message_id')} from pending list")
                    self._state.remove_pending_message(msg_data.get('message_id'))
                except Exception as e:
                    logger.error(f"Failed to restore message: {e}")

    async def start(self, message_handler=None) -> None:
        async with self._lock:
            if message_handler:
                self._message_handler = message_handler
            self._active = True
            self._shutting_down.clear()
            self._state.update_state(active=True)

            for task in [self._queue_processor, self._health_check_task]:
                if task and not task.done():
                    task.cancel()
                    try:
                        await task
                    except asyncio.CancelledError:
                        pass

            try:
                await self._restore_pending_messages()

                if self._message_handler:
                    self._queue_processor = asyncio.create_task(
                        self._process_queue(self._message_handler)
                    )
                    await self.register_task(self._queue_processor)
                    logger.info("Queue processor started")

                    self._health_check_task = asyncio.create_task(self._health_check())
                    await self.register_task(self._health_check_task)
                    logger.info("Queue health check started")
                else:
                    logger.warning("Queue started without message handler")

            except Exception as e:
                logger.error(f"Failed to start queue processor: {e}")
                self._active = False
                self._state.update_state(active=False)
                raise

    async def stop(self) -> None:
        logger.info("Stopping queue processor...")
        self._shutting_down.set()
        self._active = False
        self._state.update_state(active=False)

        for task in [self._health_check_task, self._queue_processor]:
            if task:
                task.cancel()
                try:
                    await task
                except asyncio.CancelledError:
                    pass

        self._state.update_state(
            total_processed=self._total_processed, failed_messages=self._failed_messages
        )

        logger.info(
            f"Queue processor stopped - "
            f"Processed: {self._total_processed}, "
            f"Failed: {self._failed_messages}"
        )

    @property
    def shutting_down(self) -> bool:
        return self._shutting_down.is_set()

    def set_shutting_down(self, value: bool = True) -> None:
        """Set the shutting down state."""
        if value:
            self._active = False
            self._shutting_down.set()
        else:
            self._active = True
            self._shutting_down.clear()

    @property
    def is_running(self) -> bool:
        return bool(
            self._queue_processor
            and not self._queue_processor.done()
            and not self.shutting_down
        )

    async def add_message(
        self,
        channel: discord.TextChannel,
        message: discord.Message,
        prompt: str,
        context: Optional[Dict[str, Any]] = None,
        priority: int = 2,
    ) -> None:
        """Add a message to the queue with the specified priority."""
        if self.shutting_down:
            return

        try:
            item = QueueItem.create(
                channel=channel,
                message=message,
                prompt=prompt,
                priority=priority,
                context=context,
            )
            await self.message_queue.put(item)
            logger.info(f"Added message to queue with priority {priority}")

            # Ensure queue processor is running
            if not self._queue_processor or self._queue_processor.done():
                if not self._message_handler:
                    logger.error(
                        "Message handler not initialized, cannot restart queue processor"
                    )
                    return
                logger.warning("Queue processor not running, restarting...")
                self._queue_processor = asyncio.create_task(
                    self._process_queue(self._message_handler)
                )
                await self.register_task(self._queue_processor)

        except Exception as e:
            logger.error(f"Failed to add message to queue: {e}")

    async def register_task(self, task: asyncio.Task) -> None:
        if self.shutting_down:
            raise ShutdownError("Queue manager is shutting down")
        async with self._lock:
            self._active_tasks.add(task)
            task.add_done_callback(lambda t: asyncio.create_task(self._remove_task(t)))

    async def _remove_task(self, task: asyncio.Task) -> None:
        async with self._lock:
            self._active_tasks.discard(task)

    async def _health_check(self) -> None:
        logger.info("Queue health check started")
        check_interval = 60
        restart_threshold = 300

        while not self.shutting_down and self._active:
            try:
                state_info = self._state.get_processor_info()
                time_since_last_process = (
                    time.time() - state_info["last_processed_time"]
                )

                if (
                    time_since_last_process > restart_threshold
                    or not self._queue_processor
                    or self._queue_processor.done()
                ):
                    logger.warning(
                        f"Queue appears stalled - "
                        f"Time since last process: {time_since_last_process:.1f}s"
                    )

                    if self._message_handler:
                        logger.info("Restarting queue processor")
                        if self._queue_processor:
                            self._queue_processor.cancel()
                            try:
                                await self._queue_processor
                            except asyncio.CancelledError:
                                pass
                        self._queue_processor = asyncio.create_task(
                            self._process_queue(self._message_handler)
                        )
                        await self.register_task(self._queue_processor)
                    else:
                        logger.error("Cannot restart queue - no message handler")

                await asyncio.sleep(check_interval)

            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"Health check error: {e}")
                await asyncio.sleep(check_interval)
148
discord_glhf/queue_state.py
Normal file
148
discord_glhf/queue_state.py
Normal file
@@ -0,0 +1,148 @@
#!/usr/bin/env python3
"""Queue state persistence."""

import json
import time
from pathlib import Path
from typing import Dict, Any, Optional, List

from .config import logger


class QueueState:
    """Manages queue state persistence."""

    def __init__(self, state_file: str = "queue_state.json"):
        self.state_file = Path(state_file)
        self.state: Dict[str, Any] = {
            "total_processed": 0,
            "failed_messages": 0,
            "last_processed_time": time.time(),
            "user_queues": {},
            "last_save": time.time(),
            "processor_id": None,
            "active": True,
            "pending_messages": [],  # List of unprocessed message IDs and their data
            "last_channel_id": None,  # Last channel we were processing
        }
        self.load_state()

    def load_state(self) -> None:
        """Load queue state from file."""
        try:
            if self.state_file.exists():
                with open(self.state_file, "r") as f:
                    saved_state = json.load(f)
                # Update state while preserving structure
                for key in self.state:
                    if key in saved_state:
                        self.state[key] = saved_state[key]
                logger.info(
                    f"Queue state loaded from file - "
                    f"Pending messages: {len(self.state['pending_messages'])}"
                )
        except Exception as e:
            logger.error(f"Failed to load queue state: {e}")

    def save_state(self) -> None:
        """Save current queue state to file."""
        try:
            # Update last save time
            self.state["last_save"] = time.time()

            # Ensure directory exists
            self.state_file.parent.mkdir(parents=True, exist_ok=True)

            # Save state atomically using a temporary file
            temp_file = self.state_file.with_suffix(".tmp")
            with open(temp_file, "w") as f:
                json.dump(self.state, f, indent=2)
            temp_file.replace(self.state_file)

            logger.debug(
                f"Queue state saved - "
                f"Pending messages: {len(self.state['pending_messages'])}"
            )
        except Exception as e:
            logger.error(f"Failed to save queue state: {e}")

    def add_pending_message(self, message_data: Dict[str, Any]) -> None:
        """Add a message to the pending list."""
        if message_data not in self.state["pending_messages"]:
            self.state["pending_messages"].append(message_data)
            self.state["last_channel_id"] = message_data.get("channel_id")
            self.save_state()
            logger.info(
                f"Added message {message_data.get('message_id')} to pending list "
                f"(Total pending: {len(self.state['pending_messages'])})"
            )

    def remove_pending_message(self, message_id: str) -> None:
        """Remove a message from the pending list."""
        self.state["pending_messages"] = [
            msg
            for msg in self.state["pending_messages"]
            if msg.get("message_id") != message_id
        ]
        self.save_state()
        logger.info(
            f"Removed message {message_id} from pending list "
            f"(Total pending: {len(self.state['pending_messages'])})"
        )

    def get_pending_messages(self) -> List[Dict[str, Any]]:
        """Get all pending messages."""
        return self.state["pending_messages"]

    def update_state(self, **kwargs) -> None:
        """Update queue state with new values."""
        for key, value in kwargs.items():
            if key in self.state:
                self.state[key] = value
        self.save_state()

    def get_state(self, key: str) -> Optional[Any]:
        """Get a value from the queue state."""
        return self.state.get(key)

    def increment_counter(self, counter: str, amount: int = 1) -> None:
        """Increment a counter in the state."""
        if counter in self.state and isinstance(self.state[counter], (int, float)):
            self.state[counter] += amount
            self.save_state()

    def update_user_queue(self, user_id: str, count: int) -> None:
        """Update user queue count."""
        if count > 0:
            self.state["user_queues"][user_id] = count
        else:
            self.state["user_queues"].pop(user_id, None)
        self.save_state()

    def clear_user_queues(self) -> None:
        """Clear all user queue counts."""
        self.state["user_queues"] = {}
        self.save_state()

    def set_processor_active(self, processor_id: str, active: bool = True) -> None:
        """Set the current processor ID and active state."""
        self.state["processor_id"] = processor_id if active else None
        self.state["active"] = active
        self.state["last_processed_time"] = time.time()
        self.save_state()

    def update_processed_time(self) -> None:
        """Update the last processed time."""
        self.state["last_processed_time"] = time.time()
        # Only save periodically to reduce I/O
        if time.time() - self.state["last_save"] > 60:  # Save at most once per minute
            self.save_state()

    def get_processor_info(self) -> Dict[str, Any]:
        """Get current processor information."""
        return {
            "processor_id": self.state["processor_id"],
            "active": self.state["active"],
            "last_processed_time": self.state["last_processed_time"],
            "pending_messages": len(self.state["pending_messages"]),
        }
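
# A minimal usage sketch (illustrative only):
#
#   state = QueueState("queue_state.json")
#   state.add_pending_message({"message_id": "1", "channel_id": "2"})
#   state.increment_counter("total_processed")
#   state.remove_pending_message("1")
#   print(state.get_state("total_processed"))  # -> 1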
21
discord_glhf/responses.json
Normal file
21
discord_glhf/responses.json
Normal file
@@ -0,0 +1,21 @@
{
  "fallback_responses": [
    "Hey there! CobraSilver here - I'm processing your message.",
    "CobraSilver speaking - let me think about that.",
    "Your bot CobraSilver is analyzing that input.",
    "Processing that request. -CobraSilver",
    "CobraSilver here, working on your request.",
    "Let me compute that for you. -CobraSilver",
    "CobraSilver processing... please stand by.",
    "Your friendly bot CobraSilver is on it.",
    "Computing response... -CobraSilver",
    "CobraSilver.exe is running calculations."
  ],
  "error_responses": [
    "CobraSilver here - one moment please.",
    "Processing request... -CobraSilver",
    "CobraSilver needs a quick moment.",
    "Your bot CobraSilver is computing.",
    "System processing... -CobraSilver"
  ]
}
|
||||
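
For reviewers, a sketch of how this file can be consumed. The helpers below are illustrative; the actual response_manager in .config may expose a different interface:

import json
import random

# Illustrative reader for responses.json; shows the intended shape of the
# data, not the real loader used by the bot.
def load_responses(path: str = "discord_glhf/responses.json") -> dict:
    """Load the canned-response pools from disk."""
    with open(path, "r") as f:
        return json.load(f)

def pick_fallback(responses: dict) -> str:
    """Random holding line for when no model reply is available."""
    return random.choice(responses["fallback_responses"])

def pick_error(responses: dict) -> str:
    """Random holding line for when an API call fails outright."""
    return random.choice(responses["error_responses"])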
139
discord_glhf/state.py
Normal file
@@ -0,0 +1,139 @@
#!/usr/bin/env python3
"""State management for the Discord bot."""

import os
import json
import time
from pathlib import Path
from typing import Dict, Any, Optional
from datetime import datetime, timezone

from .config import logger, API_HEALTH_CHECK_INTERVAL


class StateManager:
    """Manages persistent state for the Discord bot."""

    def __init__(self, state_file: Optional[str] = None):
        """Initialize the state manager."""
        self.state_file = Path(state_file or "bot_state.json")
        self._state: Dict[str, Any] = self._load_state()

    def _load_state(self) -> Dict[str, Any]:
        """Load state from file or create default state."""
        if self.state_file.exists():
            try:
                with open(self.state_file, "r") as f:
                    state = json.load(f)
                logger.info(f"Found existing state file at {self.state_file}")
                logger.info(
                    f"Last health check was at: {datetime.fromtimestamp(state.get('last_health_check', 0))}"
                )
                return state
            except Exception as e:
                logger.error(f"Error loading state file: {e}")

        # Create default state
        logger.info(f"No state file found at {self.state_file}, creating new one")
        default_state = {
            "last_health_check": 0,  # Unix timestamp of last health check
            "api_status": {},  # Status of each API endpoint
            "uptime_start": int(time.time()),  # Bot start time
            "total_messages_processed": 0,
            "last_updated": int(time.time()),
        }

        # Ensure the state file is created immediately
        try:
            self.state_file.parent.mkdir(parents=True, exist_ok=True)
            with open(self.state_file, "w") as f:
                json.dump(default_state, f, indent=2)
            logger.info("Created new state file with default configuration")
        except Exception as e:
            logger.error(f"Error creating state file: {e}")

        return default_state

    def _save_state(self) -> None:
        """Save current state to file."""
        try:
            # Ensure directory exists
            self.state_file.parent.mkdir(parents=True, exist_ok=True)

            # Update last_updated timestamp
            self._state["last_updated"] = int(time.time())

            # Write state to file
            with open(self.state_file, "w") as f:
                json.dump(self._state, f, indent=2)
        except Exception as e:
            logger.error(f"Error saving state file: {e}")

    def should_run_health_check(self) -> bool:
        """Check if enough time has passed since the last health check."""
        current_time = int(time.time())
        last_check = self._state.get("last_health_check", 0)

        # Check if enough time has passed since last health check
        return (current_time - last_check) >= API_HEALTH_CHECK_INTERVAL

    def update_health_check(self, api_status: Dict[str, bool]) -> None:
        """Update health check state."""
        current_time = int(time.time())
        self._state["last_health_check"] = current_time
        self._state["api_status"] = api_status

        # Log the health check update
        logger.info(
            f"Updating health check state at: {datetime.fromtimestamp(current_time)}"
        )
        for api_url, is_healthy in api_status.items():
            logger.info(
                f"API Status - {api_url}: {'healthy' if is_healthy else 'unhealthy'}"
            )

        self._save_state()
        logger.info(f"Health check state saved to {self.state_file}")

    def increment_messages_processed(self) -> None:
        """Increment the total messages processed counter."""
        self._state["total_messages_processed"] += 1
        total = self._state["total_messages_processed"]
        logger.debug(f"Messages processed: {total}")

        # Only save every 10 messages to reduce disk I/O
        if total % 10 == 0:
            logger.info(f"Milestone: {total} messages processed")
            self._save_state()

    def get_uptime(self) -> str:
        """Get bot uptime as a formatted string."""
        uptime_seconds = int(time.time()) - self._state["uptime_start"]
        days = uptime_seconds // 86400
        hours = (uptime_seconds % 86400) // 3600
        minutes = (uptime_seconds % 3600) // 60
        seconds = uptime_seconds % 60

        parts = []
        if days > 0:
            parts.append(f"{days}d")
        if hours > 0:
            parts.append(f"{hours}h")
        if minutes > 0:
            parts.append(f"{minutes}m")
        parts.append(f"{seconds}s")

        return " ".join(parts)

    def get_state_summary(self) -> Dict[str, Any]:
        """Get a summary of the current state."""
        return {
            "uptime": self.get_uptime(),
            "messages_processed": self._state["total_messages_processed"],
            "last_health_check": datetime.fromtimestamp(
                self._state["last_health_check"], tz=timezone.utc
            ).isoformat(),
            "api_status": self._state["api_status"],
            "last_updated": datetime.fromtimestamp(
                self._state["last_updated"], tz=timezone.utc
            ).isoformat(),
        }
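
A sketch of how StateManager can gate periodic health checks. probe_endpoint and the endpoint list are stand-ins, not code from this commit:

import asyncio

async def probe_endpoint(url: str) -> bool:
    # Placeholder probe; a real check would issue an HTTP request (e.g. aiohttp).
    return True

async def health_check_loop(state) -> None:
    """Run health checks only when the configured interval has elapsed."""
    while True:
        if state.should_run_health_check():
            urls = ["https://api.example.com/v1"]  # assumed endpoint list
            api_status = {url: await probe_endpoint(url) for url in urls}
            state.update_health_check(api_status)
        await asyncio.sleep(60)  # re-evaluate the gate once a minute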
115
discord_glhf/training.py
Normal file
@@ -0,0 +1,115 @@
import os
import asyncio
import json
from datetime import datetime
import logging
from pathlib import Path

logger = logging.getLogger(__name__)


class TrainingManager:
    def __init__(self, exports_dir: str = "data/chat_exports"):
        self.exports_dir = Path(exports_dir)
        self.processed_files_path = self.exports_dir / ".processed_files.json"
        self.processed_files = self._load_processed_files()
        self.batch_size = 10  # Number of files to process in one batch
        self.training_interval = 3600  # Run every hour
        self.is_running = False
        self._training_task = None

    def _load_processed_files(self) -> set:
        """Load the set of already processed files."""
        if self.processed_files_path.exists():
            try:
                with open(self.processed_files_path, 'r') as f:
                    return set(json.load(f))
            except json.JSONDecodeError:
                logger.error("Error loading processed files list, starting fresh")
                return set()
        return set()

    def _save_processed_files(self):
        """Save the set of processed files."""
        with open(self.processed_files_path, 'w') as f:
            json.dump(list(self.processed_files), f)

    async def process_batch(self):
        """Process a batch of export files."""
        if not self.exports_dir.exists():
            logger.warning(f"Exports directory {self.exports_dir} does not exist")
            return

        # Get list of unprocessed files
        all_files = [f for f in self.exports_dir.glob("*.json")
                     if f.name != ".processed_files.json"]
        unprocessed = [f for f in all_files
                       if f.name not in self.processed_files]

        if not unprocessed:
            return  # No new files to process

        # Take a batch of files
        batch = unprocessed[:self.batch_size]

        try:
            # Here you would implement the actual training logic
            # For example:
            for file in batch:
                logger.info(f"Processing file: {file.name}")
                # TODO: Implement actual training logic here
                # Example:
                # with open(file, 'r') as f:
                #     data = json.load(f)
                # await train_model(data)

                # Mark as processed
                self.processed_files.add(file.name)

            # Save progress
            self._save_processed_files()

            if batch:
                logger.info(f"Successfully processed batch of {len(batch)} files")

        except Exception as e:
            logger.error(f"Error processing batch: {str(e)}")

    async def start(self):
        """Start the periodic training loop."""
        if self.is_running:
            return

        logger.info("Starting periodic training manager")
        self.is_running = True
        self._training_task = asyncio.create_task(self._run())

    async def _run(self):
        """Internal run loop."""
        while self.is_running:
            try:
                await self.process_batch()
            except Exception as e:
                logger.error(f"Error in training loop: {str(e)}")

            # Wait for next interval
            await asyncio.sleep(self.training_interval)

    async def stop(self):
        """Stop the training manager."""
        if not self.is_running:
            return

        logger.info("Stopping training manager")
        self.is_running = False

        if self._training_task:
            self._training_task.cancel()
            try:
                await self._training_task
            except asyncio.CancelledError:
                pass
            self._training_task = None
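
A minimal driver sketch for TrainingManager; the two-hour run window and the top-level entry point are illustrative, not part of this commit:

import asyncio

async def main() -> None:
    manager = TrainingManager(exports_dir="data/chat_exports")
    await manager.start()  # spawns the hourly batch-processing task
    try:
        await asyncio.sleep(7200)  # let roughly two cycles run (assumed window)
    finally:
        await manager.stop()  # cancels the loop task and awaits cleanup

if __name__ == "__main__":
    asyncio.run(main())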