Compare commits
7 Commits
de3b09bf46
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1a4b9b184a | ||
|
|
5550df9434 | ||
|
|
5c38774e2f | ||
|
|
e8f57b976a | ||
|
|
937bc18ca7 | ||
|
|
f40d19294f | ||
|
|
898937acb4 |
14
.gitignore
vendored
14
.gitignore
vendored
@@ -51,3 +51,17 @@ discord_glhf/handlers/__pycache__/message_handler.cpython-313.pyc
|
||||
discord_glhf/handlers/__pycache__/tool_handler.cpython-313.pyc
|
||||
discord_glhf/web/__pycache__/app.cpython-313.pyc
|
||||
.env
|
||||
discord_glhf/__pycache__/__init__.cpython-313.pyc
|
||||
discord_glhf/__pycache__/config.cpython-313.pyc
|
||||
discord_glhf/__pycache__/database.cpython-313.pyc
|
||||
discord_glhf/__pycache__/main.cpython-313.pyc
|
||||
discord_glhf/__pycache__/queue_manager.cpython-313.pyc
|
||||
discord_glhf/__pycache__/queue_state.cpython-313.pyc
|
||||
discord_glhf/__pycache__/queue.cpython-313.pyc
|
||||
discord_glhf/__pycache__/training.cpython-313.pyc
|
||||
discord_glhf/handlers/event_handler.py
|
||||
discord_glhf/handlers/__pycache__/__init__.cpython-313.pyc
|
||||
discord_glhf/handlers/__pycache__/image_handler.cpython-313.pyc
|
||||
discord_glhf/web/__pycache__/__init__.cpython-313.pyc
|
||||
.DS_Store
|
||||
discord_bot.log.*
|
||||
|
||||
@@ -40,6 +40,7 @@ class DiscordBot:
|
||||
self.db_pool = DatabasePool()
|
||||
self.db_manager = DatabaseManager(self.db_pool)
|
||||
self._initialized = False
|
||||
self._running = True
|
||||
self._init_lock = asyncio.Lock()
|
||||
|
||||
# Initialize handler references
|
||||
@@ -123,10 +124,13 @@ class DiscordBot:
|
||||
max_retries = 5
|
||||
base_delay = 1.0
|
||||
|
||||
while retry_count < max_retries:
|
||||
while retry_count < max_retries and self._running: # Check if we're still running
|
||||
try:
|
||||
await self.bot.connect()
|
||||
# Add a timeout to connection
|
||||
await asyncio.wait_for(self.bot.connect(), timeout=60.0)
|
||||
return
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("Connection attempt timed out, retrying...")
|
||||
except (aiohttp.ClientError, socket.gaierror) as e:
|
||||
retry_count += 1
|
||||
if retry_count == max_retries:
|
||||
@@ -185,11 +189,15 @@ class DiscordBot:
|
||||
web_port = int(os.getenv('WEB_PORT', '8080'))
|
||||
config = Config()
|
||||
config.bind = [f"0.0.0.0:{web_port}"]
|
||||
# Allow signals to propagate to main process
|
||||
config.use_reloader = False
|
||||
config.shutdown_timeout = 3.0
|
||||
self.web_app = init_app(self.event_handler)
|
||||
|
||||
# Start web interface in background task
|
||||
# Start web interface in background task with signal handling disabled
|
||||
loop = asyncio.get_event_loop()
|
||||
loop.create_task(serve(self.web_app, config))
|
||||
web_task = loop.create_task(serve(self.web_app, config))
|
||||
web_task.set_name('hypercorn_web') # Name task for identification during shutdown
|
||||
logger.info(f"Web interface starting at http://localhost:{web_port}")
|
||||
|
||||
# Start API manager
|
||||
@@ -251,36 +259,55 @@ class DiscordBot:
|
||||
logger.error(
|
||||
f"Error before event_handler initialization: {exc_value}")
|
||||
|
||||
self._running = True
|
||||
try:
|
||||
async with self.bot:
|
||||
await self.bot.start(token)
|
||||
while True:
|
||||
while self._running:
|
||||
try:
|
||||
await self._handle_connection(token)
|
||||
except (aiohttp.ClientError, socket.gaierror) as e:
|
||||
logger.error(f"Connection error: {e}")
|
||||
await asyncio.sleep(5) # Wait before reconnecting
|
||||
if self._running: # Only log and retry if we're still meant to be running
|
||||
logger.error(f"Connection error: {e}")
|
||||
await asyncio.sleep(5) # Wait before reconnecting
|
||||
else:
|
||||
break
|
||||
except KeyboardInterrupt:
|
||||
raise
|
||||
self._running = False
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error: {e}")
|
||||
await asyncio.sleep(5) # Wait before reconnecting
|
||||
if self._running:
|
||||
logger.error(f"Unexpected error: {e}")
|
||||
await asyncio.sleep(5) # Wait before reconnecting
|
||||
else:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start bot: {e}")
|
||||
raise
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the bot."""
|
||||
"""Stop the bot and set running flag to False."""
|
||||
logger.info("Initiating shutdown...")
|
||||
self._running = False
|
||||
|
||||
try:
|
||||
async with self._init_lock:
|
||||
# Stop queue processor first
|
||||
# Cancel Hypercorn web tasks first
|
||||
tasks = [t for t in asyncio.all_tasks() if t.get_name() == 'hypercorn_web']
|
||||
for task in tasks:
|
||||
task.cancel()
|
||||
if tasks:
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
logger.info("Web server tasks cancelled")
|
||||
|
||||
if self.queue_manager and self.queue_manager.is_running:
|
||||
await self.queue_manager.stop()
|
||||
logger.info("Queue processor stopped")
|
||||
|
||||
# Stop training manager
|
||||
if self.api_manager and self.api_manager.is_running:
|
||||
await self.api_manager.shutdown()
|
||||
logger.info("Stopped API health check loop")
|
||||
|
||||
if self.training_manager and self.training_manager.is_running:
|
||||
await self.training_manager.stop()
|
||||
logger.info("Training manager stopped")
|
||||
@@ -324,6 +351,7 @@ async def shutdown(
|
||||
) -> None:
|
||||
"""Handle shutdown signals."""
|
||||
logger.info(f"Received {signal_name}")
|
||||
|
||||
try:
|
||||
# Set a flag to prevent new tasks from starting
|
||||
bot.queue_manager.set_shutting_down()
|
||||
@@ -363,16 +391,21 @@ def run_bot():
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
# Set up signal handlers
|
||||
for sig in (signal.SIGTERM, signal.SIGINT):
|
||||
loop.add_signal_handler(
|
||||
sig, lambda s=sig: asyncio.create_task(shutdown(s.name, bot, loop))
|
||||
)
|
||||
# Set up signal handlers for both SIGTERM and SIGINT
|
||||
loop.add_signal_handler(
|
||||
signal.SIGTERM,
|
||||
lambda: asyncio.create_task(shutdown('SIGTERM', bot, loop))
|
||||
)
|
||||
loop.add_signal_handler(
|
||||
signal.SIGINT,
|
||||
lambda: asyncio.create_task(shutdown('SIGINT', bot, loop))
|
||||
)
|
||||
|
||||
# Remove the line that sets the standard signal handler:
|
||||
# signal.signal(signal.SIGINT, signal.default_int_handler)
|
||||
|
||||
try:
|
||||
loop.run_until_complete(bot.start(token))
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Received keyboard interrupt")
|
||||
except Exception as e:
|
||||
logger.error(f"Bot crashed: {e}")
|
||||
raise
|
||||
|
||||
@@ -198,6 +198,11 @@ BOT_OWNER_ID = int(os.getenv("BOT_OWNER_ID")) # Required
|
||||
AUTO_RESPONSE_CHANNEL_ID = int(
|
||||
os.getenv("AUTO_RESPONSE_CHANNEL_ID")) # Required
|
||||
|
||||
# Get allowed users for reset command
|
||||
RESET_ALLOWED_USERS = [int(id.strip()) for id in os.getenv("RESET_ALLOWED_USERS", "").split(",") if id.strip()]
|
||||
if not RESET_ALLOWED_USERS:
|
||||
RESET_ALLOWED_USERS = [BOT_OWNER_ID] # Default to bot owner if not configured
|
||||
|
||||
# Load system prompt
|
||||
SYSTEM_PROMPT = os.getenv("SYSTEM_PROMPT") or load_system_prompt()
|
||||
|
||||
|
||||
@@ -605,6 +605,18 @@ class DatabaseManager:
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update thread activity: {e}")
|
||||
|
||||
async def clear_all_messages(self) -> None:
|
||||
"""Clear all messages from the database."""
|
||||
try:
|
||||
async with self.pool.acquire() as conn:
|
||||
async with conn.cursor() as cursor:
|
||||
await cursor.execute("DELETE FROM messages")
|
||||
await conn.commit()
|
||||
logger.info("All message history cleared")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to clear message history: {e}")
|
||||
raise
|
||||
|
||||
async def cleanup_old_messages(self):
|
||||
"""Clean up messages older than MESSAGE_CLEANUP_DAYS."""
|
||||
cleanup_date = datetime.now() - timedelta(days=MESSAGE_CLEANUP_DAYS)
|
||||
|
||||
@@ -9,7 +9,7 @@ from datetime import datetime
|
||||
from discord import Message, RawReactionActionEvent
|
||||
|
||||
from ..config import (
|
||||
logger, AUTO_RESPONSE_CHANNEL_ID, SYSTEM_PROMPT, BOT_OWNER_ID
|
||||
logger, AUTO_RESPONSE_CHANNEL_ID, SYSTEM_PROMPT, BOT_OWNER_ID, RESET_ALLOWED_USERS
|
||||
)
|
||||
from .message_handler import MessageHandler
|
||||
from .image_handler import ImageHandler
|
||||
@@ -41,7 +41,7 @@ class EventHandler:
|
||||
f"<@{mention[2:-1]}>" # Raw mention format
|
||||
]
|
||||
pattern = '|'.join(patterns)
|
||||
|
||||
|
||||
# Replace all mention formats with the proper mention
|
||||
return re.sub(pattern, mention, response)
|
||||
|
||||
@@ -150,6 +150,7 @@ class EventHandler:
|
||||
history = await self.db_manager.get_conversation_history(
|
||||
user_id=0,
|
||||
channel_id=payload.channel_id,
|
||||
limit=50 # Limit to recent context
|
||||
)
|
||||
logger.debug(f"Retrieved {len(history)} messages for context")
|
||||
|
||||
@@ -202,7 +203,8 @@ class EventHandler:
|
||||
created_at=datetime.utcnow()
|
||||
)
|
||||
|
||||
logger.info(f"Adding reaction response to queue from {user.display_name}")
|
||||
logger.info(
|
||||
f"Adding reaction response to queue from {user.display_name}")
|
||||
# Queue the reaction like a regular message
|
||||
await self.queue_manager.add_message(
|
||||
channel=channel,
|
||||
@@ -222,6 +224,33 @@ class EventHandler:
|
||||
return
|
||||
|
||||
try:
|
||||
# Handle !reset_personality command
|
||||
if message.content.strip() == "!reset_personality":
|
||||
# Check if user is allowed (in allowed list, has admin, or is bot owner)
|
||||
is_allowed = (
|
||||
message.author.id in RESET_ALLOWED_USERS or
|
||||
message.author.id == BOT_OWNER_ID or
|
||||
(hasattr(message.author, 'guild_permissions')
|
||||
and message.author.guild_permissions.administrator)
|
||||
)
|
||||
|
||||
if not is_allowed:
|
||||
logger.warning(
|
||||
f"Unauthorized reset attempt by {message.author.name} ({message.author.id})")
|
||||
return
|
||||
|
||||
try:
|
||||
# React with checkmark
|
||||
await message.add_reaction("✅")
|
||||
# Clear all messages
|
||||
await self.db_manager.clear_all_messages()
|
||||
logger.info(
|
||||
f"Personality reset by {message.author.name} ({message.author.id})")
|
||||
return
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to reset personality: {e}")
|
||||
return
|
||||
|
||||
# Only respond in configured channel or its threads
|
||||
if (message.channel.id != AUTO_RESPONSE_CHANNEL_ID and
|
||||
(not hasattr(message.channel, 'parent_id') or
|
||||
@@ -230,7 +259,8 @@ class EventHandler:
|
||||
|
||||
# Early duplicate checks before any processing
|
||||
if any(item.message.id == message.id for item in self.queue_manager.message_queue.processing):
|
||||
logger.debug(f"Message {message.id} already in processing, skipping")
|
||||
logger.debug(
|
||||
f"Message {message.id} already in processing, skipping")
|
||||
return
|
||||
|
||||
# Get current queue size
|
||||
@@ -244,7 +274,8 @@ class EventHandler:
|
||||
message_id=message.id
|
||||
)
|
||||
if message_processed:
|
||||
logger.debug(f"Message {message.id} already processed, skipping")
|
||||
logger.debug(
|
||||
f"Message {message.id} already processed, skipping")
|
||||
return
|
||||
|
||||
# Check for duplicate content in history
|
||||
@@ -252,12 +283,13 @@ class EventHandler:
|
||||
user_id=0,
|
||||
channel_id=message.channel.id
|
||||
)
|
||||
current_content = f"{message.author.display_name} ({message.author.name}) (<@{message.author.id}>): {message.content}"
|
||||
current_content = f"({message.author.name}): {message.content}"
|
||||
for hist_msg in recent_history:
|
||||
hist_content = hist_msg.get("content", {}).get("content", "") if isinstance(
|
||||
hist_msg.get("content"), dict) else hist_msg.get("content", "")
|
||||
if hist_content == current_content:
|
||||
logger.debug(f"Duplicate message content detected for message {message.id}, skipping")
|
||||
logger.debug(
|
||||
f"Duplicate message content detected for message {message.id}, skipping")
|
||||
return
|
||||
|
||||
# Update user activity in database
|
||||
@@ -339,7 +371,7 @@ class EventHandler:
|
||||
message_uuid = str(uuid.uuid4())
|
||||
|
||||
# Format the message with user info
|
||||
formatted_content = f"{item.message.author.display_name} ({item.message.author.name}) (<@{item.message.author.id}>): {item.prompt}"
|
||||
formatted_content = f"({item.message.author.name}): {item.prompt}"
|
||||
|
||||
# Check if this message is already in history
|
||||
message_in_history = False
|
||||
@@ -351,7 +383,8 @@ class EventHandler:
|
||||
if hist_content == formatted_content:
|
||||
message_in_history = True
|
||||
if isinstance(hist_msg.get("metadata"), dict):
|
||||
stored_uuid = hist_msg["metadata"].get("message_uuid")
|
||||
stored_uuid = hist_msg["metadata"].get(
|
||||
"message_uuid")
|
||||
break
|
||||
|
||||
# Use stored UUID if found, otherwise use new UUID
|
||||
@@ -381,7 +414,8 @@ class EventHandler:
|
||||
await self.store_message(
|
||||
user_id=item.message.author.id,
|
||||
role="user",
|
||||
content={"content": formatted_content, "metadata": message_metadata},
|
||||
content={"content": formatted_content,
|
||||
"metadata": message_metadata},
|
||||
channel_id=item.channel.id,
|
||||
message_uuid=message_uuid,
|
||||
)
|
||||
@@ -404,7 +438,10 @@ class EventHandler:
|
||||
}
|
||||
|
||||
messages = [system_message]
|
||||
messages.extend(history)
|
||||
|
||||
# Include conversation history
|
||||
if history:
|
||||
messages.extend(history)
|
||||
|
||||
# Always add current message to the API call
|
||||
# Add timeout_env to the context
|
||||
@@ -483,7 +520,8 @@ class EventHandler:
|
||||
)
|
||||
if thread_id:
|
||||
await self.db_manager.update_thread_activity(thread_id)
|
||||
logger.info(f"Created and stored thread '{args['name']}' in database")
|
||||
logger.info(
|
||||
f"Created and stored thread '{args['name']}' in database")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing tool {tool_name}: {e}")
|
||||
@@ -491,7 +529,8 @@ class EventHandler:
|
||||
# Send the response
|
||||
if final_response:
|
||||
author = item.message.author
|
||||
owner_tag = ' [BOT OWNER]' if int(author.id) == BOT_OWNER_ID else ''
|
||||
owner_tag = ' [BOT OWNER]' if int(
|
||||
author.id) == BOT_OWNER_ID else ''
|
||||
logger.info(
|
||||
f"Bot response to {author.display_name} "
|
||||
f"({author.name}#{author.discriminator})"
|
||||
@@ -501,7 +540,7 @@ class EventHandler:
|
||||
reference = None
|
||||
if hasattr(item.message, '_state'): # Check if it's a real Discord Message
|
||||
reference = item.message
|
||||
|
||||
|
||||
sent_message = await self.message_handler.safe_send(
|
||||
item.channel, final_response, reference=reference
|
||||
)
|
||||
@@ -545,7 +584,7 @@ class EventHandler:
|
||||
source_info = {"type": "web"}
|
||||
else:
|
||||
source_info = {"type": "discord"}
|
||||
|
||||
|
||||
response_metadata = {
|
||||
"response_id": response_uuid,
|
||||
"user_info": {
|
||||
@@ -597,7 +636,7 @@ class EventHandler:
|
||||
},
|
||||
)
|
||||
|
||||
async def send_prompt_to_channel(self, prompt: str, channel_id: int) -> None:
|
||||
async def send_prompt_to_channel(self, prompt: str, channel_id: int, username: str = "Web User") -> None:
|
||||
"""Send a prompt to the LLM and post response in Discord channel."""
|
||||
try:
|
||||
# Get the channel
|
||||
@@ -622,8 +661,8 @@ class EventHandler:
|
||||
"id": str(self.bot.user.id),
|
||||
},
|
||||
"user_info": {
|
||||
"name": "web_user",
|
||||
"display_name": "Web User",
|
||||
"name": username.lower().replace(" ", "_"),
|
||||
"display_name": username,
|
||||
"id": "0",
|
||||
},
|
||||
"timeout_env": "GLHF_TIMEOUT"
|
||||
@@ -636,8 +675,8 @@ class EventHandler:
|
||||
id=str(uuid.uuid4()),
|
||||
author=SimpleNamespace(
|
||||
id=0,
|
||||
name="web_user",
|
||||
display_name="Web User",
|
||||
name=username.lower().replace(" ", "_"),
|
||||
display_name=username,
|
||||
bot=False,
|
||||
discriminator="0000"
|
||||
),
|
||||
|
||||
@@ -39,6 +39,10 @@ async def send_prompt():
|
||||
if not data or 'prompt' not in data:
|
||||
return jsonify({'error': 'Missing prompt'}), 400
|
||||
|
||||
username = data.get('username', 'Web User')
|
||||
if not username.strip():
|
||||
username = 'Web User'
|
||||
|
||||
try:
|
||||
channel_id = int(str(data.get('channel_id', '1198637345701285999')))
|
||||
except (ValueError, TypeError):
|
||||
@@ -48,8 +52,8 @@ async def send_prompt():
|
||||
return jsonify({'error': 'Channel ID is required'}), 400
|
||||
|
||||
try:
|
||||
# Attempt to send prompt to channel
|
||||
await event_handler.send_prompt_to_channel(data['prompt'], channel_id)
|
||||
# Attempt to send prompt to channel with username
|
||||
await event_handler.send_prompt_to_channel(data['prompt'], channel_id, username=username)
|
||||
return jsonify({
|
||||
'status': 'processing',
|
||||
'message': f'Prompt sent to channel {channel_id}'
|
||||
|
||||
@@ -25,6 +25,20 @@
|
||||
placeholder="Enter your prompt here..."
|
||||
></textarea>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<label for="username" class="block text-sm font-medium text-gray-700 mb-1">
|
||||
Username
|
||||
</label>
|
||||
<input
|
||||
type="text"
|
||||
id="username"
|
||||
name="username"
|
||||
class="w-full px-3 py-2 border border-gray-300 rounded-md shadow-sm focus:ring-blue-500 focus:border-blue-500"
|
||||
placeholder="Enter your username..."
|
||||
required
|
||||
>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<label for="channelId" class="block text-sm font-medium text-gray-700 mb-1">
|
||||
@@ -66,6 +80,7 @@
|
||||
e.preventDefault();
|
||||
|
||||
const prompt = document.getElementById('prompt').value;
|
||||
const username = document.getElementById('username').value;
|
||||
const channelId = document.getElementById('channelId').value;
|
||||
const resultDiv = document.getElementById('result');
|
||||
const resultMessage = document.getElementById('resultMessage');
|
||||
@@ -78,6 +93,7 @@
|
||||
},
|
||||
body: JSON.stringify({
|
||||
prompt: prompt,
|
||||
username: username,
|
||||
channel_id: channelId ? String(channelId) : "1198637345701285999"
|
||||
})
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user