Enhance tool functionality by adding function definitions and schemas; initialize database before handlers; improve API response handling for tool calls

This commit is contained in:
pacnpal
2025-02-24 14:47:43 -05:00
parent 015c40441b
commit 18c99a0434
5 changed files with 209 additions and 43 deletions

2
.gitignore vendored
View File

@@ -26,3 +26,5 @@ discord_bot.log.4
discord_bot.log.5
queue_state.json
system_prompt.yaml
discord_glhf/__pycache__/bot.cpython-313.pyc
discord_glhf/handlers/__pycache__/event_handler.cpython-313.pyc

View File

@@ -255,6 +255,9 @@ class APIManager:
}
)
# Import tools
from .handlers.function_tools import TOOLS, parse_tool_response
# Prepare request data
data = {
"model": params["model"],
@@ -262,6 +265,7 @@ class APIManager:
"temperature": params["temperature"],
"max_tokens": params["max_tokens"],
"stream": False, # Disable streaming for all requests
"tools": TOOLS # Add available tools
}
logger.debug(
"API Request Details:\n"
@@ -320,14 +324,24 @@ class APIManager:
if json_response.get("choices"):
choice = json_response["choices"][0]
if "message" in choice:
content = choice["message"].get("content", "")
message = choice["message"]
# Check for tool calls first
if "tool_calls" in message:
tool_response = parse_tool_response(json_response)
if tool_response:
return True, json.dumps(tool_response)
# Fall back to regular content
content = message.get("content", "")
if content and content.strip():
return True, content
else:
logger.error("Empty content in response")
logger.debug(f"Response headers: {response.headers}")
logger.debug(f"Response status: {response.status}")
return False, None
# Log if no valid response found
logger.error("No valid content or tool calls in response")
logger.debug(f"Response headers: {response.headers}")
logger.debug(f"Response status: {response.status}")
return False, None
except json.JSONDecodeError as e:
logger.error(f"Failed to decode response: {e}")
return False, None

View File

@@ -57,7 +57,11 @@ class DiscordBot:
try:
async with self._init_lock:
if not self._initialized:
# Initialize all handlers first
# Initialize database first
await self.db_manager.init_db()
logger.info("Database initialized")
# Initialize all handlers
self.message_handler = MessageHandler(self.db_manager)
logger.info("Message handler initialized")
@@ -420,3 +424,4 @@ def run_bot():
if __name__ == "__main__":
run_bot()

View File

@@ -25,8 +25,10 @@ class EventHandler:
self.db_manager = db_manager
self.api_manager = api_manager
# Use provided handlers or create new ones
self.message_handler = message_handler if message_handler else MessageHandler(db_manager)
self.image_handler = image_handler if image_handler else ImageHandler(api_manager)
self.message_handler = message_handler if message_handler else MessageHandler(
db_manager)
self.image_handler = image_handler if image_handler else ImageHandler(
api_manager)
self.tool_handler = tool_handler if tool_handler else ToolHandler(bot)
# Set this handler as the queue's message processor
@@ -42,7 +44,7 @@ class EventHandler:
f"<@{mention[2:-1]}>" # Raw mention format
]
pattern = '|'.join(patterns)
# Replace all mention formats with the proper mention
return re.sub(pattern, mention, response)
@@ -203,7 +205,8 @@ class EventHandler:
created_at=datetime.utcnow()
)
logger.info(f"Adding reaction response to queue from {user.display_name}")
logger.info(
f"Adding reaction response to queue from {user.display_name}")
# Queue the reaction like a regular message
await self.queue_manager.add_message(
channel=channel,
@@ -231,7 +234,8 @@ class EventHandler:
# Early duplicate checks before any processing
if any(item.message.id == message.id for item in self.queue_manager.message_queue.processing):
logger.debug(f"Message {message.id} already in processing, skipping")
logger.debug(
f"Message {message.id} already in processing, skipping")
return
# Get current queue size
@@ -245,7 +249,8 @@ class EventHandler:
message_id=message.id
)
if message_processed:
logger.debug(f"Message {message.id} already processed, skipping")
logger.debug(
f"Message {message.id} already processed, skipping")
return
# Check for duplicate content in history
@@ -253,12 +258,13 @@ class EventHandler:
user_id=0,
channel_id=message.channel.id
)
current_content = f"{message.author.display_name} ({message.author.name}) (<@{message.author.id}>): {message.content}"
current_content = f"Username: {message.author.name} Message Content: {message.content}"
for hist_msg in recent_history:
hist_content = hist_msg.get("content", {}).get("content", "") if isinstance(
hist_msg.get("content"), dict) else hist_msg.get("content", "")
if hist_content == current_content:
logger.debug(f"Duplicate message content detected for message {message.id}, skipping")
logger.debug(
f"Duplicate message content detected for message {message.id}, skipping")
return
# Update user activity in database
@@ -340,7 +346,12 @@ class EventHandler:
message_uuid = str(uuid.uuid4())
# Format the message with user info
formatted_content = f"{item.message.author.display_name} ({item.message.author.name}) (<@{item.message.author.id}>): {item.prompt}"
formatted_content = (
f"(username): {item.message.author.name} "
f"(message): {item.prompt} "
"(IMPORTANT: DO NOT REPEAT THIS PATTERN WHEN RESPONDING! "
"Respond only to the (message) part of the prompt. Now go.)"
)
# Check if this message is already in history
message_in_history = False
@@ -352,7 +363,8 @@ class EventHandler:
if hist_content == formatted_content:
message_in_history = True
if isinstance(hist_msg.get("metadata"), dict):
stored_uuid = hist_msg["metadata"].get("message_uuid")
stored_uuid = hist_msg["metadata"].get(
"message_uuid")
break
# Use stored UUID if found, otherwise use new UUID
@@ -382,7 +394,8 @@ class EventHandler:
await self.store_message(
user_id=item.message.author.id,
role="user",
content={"content": formatted_content, "metadata": message_metadata},
content={"content": formatted_content,
"metadata": message_metadata},
channel_id=item.channel.id,
message_uuid=message_uuid,
)
@@ -441,18 +454,31 @@ Available tools:
if not response:
return
# Parse tool calls and get processed response
tool_calls, final_response, mentioned_users = self.tool_handler.parse_tool_calls(
response, message_id=item.message.id, channel_id=item.channel.id
)
# Parse response for potential tool calls
try:
response_data = json.loads(response)
if isinstance(response_data, dict) and "tool_calls" in response_data:
tool_calls = response_data["tool_calls"]
final_response = "" # Will be populated after tool execution
else:
tool_calls = []
final_response = response
except json.JSONDecodeError:
# Not JSON, treat as regular response
tool_calls = []
final_response = response
except Exception as e:
logger.error(f"Error getting API completion: {e}")
return
# Execute tool calls
for tool_name, args in tool_calls:
# Execute tool calls and collect responses
response_parts = []
for tool_call in tool_calls:
try:
if tool_name == "find_user":
name = tool_call["name"]
args = json.loads(tool_call["arguments"]) if isinstance(tool_call["arguments"], str) else tool_call["arguments"]
if name == "mention_user":
# Check if we're trying to mention the message author
if args["name"].lower() in [
item.message.author.name.lower(),
@@ -466,45 +492,47 @@ Available tools:
)
if mention:
final_response = self._clean_mentions(
final_response,
mention,
item.message.author.display_name,
args["name"]
)
response_parts.append(mention)
elif tool_name == "add_reaction":
elif name == "add_reaction":
await self.tool_handler.add_reaction(
item.message.id, item.channel.id, args["emoji"]
)
elif tool_name == "create_embed":
elif name == "create_embed":
await self.tool_handler.create_embed(
channel=item.channel, content=args["content"]
channel=item.channel,
content=f"{args['title']}\n{args['description']}",
color=args.get('color', 0xFF0000)
)
response_parts.append(f"[Embed sent: {args['title']}]")
elif tool_name == "create_thread":
# Create Discord thread first
elif name == "create_thread":
message_id = args.get('message_id', item.message.id)
discord_thread = await self.tool_handler.create_thread(
item.channel.id, args["name"], item.message.id
item.channel.id, args["name"], message_id
)
if discord_thread:
# Create thread in database and update activity
thread_id = await self.message_handler.create_thread_for_message(
item.channel.id, item.message.author.id, args["name"]
)
if thread_id:
await self.db_manager.update_thread_activity(thread_id)
logger.info(f"Created and stored thread '{args['name']}' in database")
response_parts.append(f"[Thread created: {args['name']}]")
except Exception as e:
logger.error(f"Error executing tool {tool_name}: {e}")
logger.error(f"Error executing tool {name}: {e}")
# If we executed any tools, combine their responses
if response_parts:
final_response = " ".join(response_parts) if not final_response else f"{final_response}\n{' '.join(response_parts)}"
# Send the response
if final_response:
author = item.message.author
owner_tag = ' [BOT OWNER]' if int(author.id) == BOT_OWNER_ID else ''
owner_tag = ' [BOT OWNER]' if int(
author.id) == BOT_OWNER_ID else ''
logger.info(
f"Bot response to {author.display_name} "
f"({author.name}#{author.discriminator})"
@@ -514,7 +542,7 @@ Available tools:
reference = None
if hasattr(item.message, '_state'): # Check if it's a real Discord Message
reference = item.message
sent_message = await self.message_handler.safe_send(
item.channel, final_response, reference=reference
)
@@ -558,7 +586,7 @@ Available tools:
source_info = {"type": "web"}
else:
source_info = {"type": "discord"}
response_metadata = {
"response_id": response_uuid,
"user_info": {

View File

@@ -0,0 +1,117 @@
"""Function definitions and tool schemas for LLM function calling."""
from typing import Dict, Any
TOOLS = [
{
"type": "function",
"function": {
"name": "mention_user",
"description": "Mention a Discord user in the response",
"parameters": {
"type": "object",
"properties": {
"name": {
"type": "string",
"description": "Username or nickname to mention"
}
},
"required": ["name"]
}
}
},
{
"type": "function",
"function": {
"name": "add_reaction",
"description": "Add an emoji reaction to the message",
"parameters": {
"type": "object",
"properties": {
"emoji": {
"type": "string",
"description": "The emoji to add (Unicode emoji, custom emoji <:name:id>, or standard emoji :name:)"
}
},
"required": ["emoji"]
}
}
},
{
"type": "function",
"function": {
"name": "create_embed",
"description": "Create a rich embed message with title and content",
"parameters": {
"type": "object",
"properties": {
"title": {
"type": "string",
"description": "Title of the embed"
},
"description": {
"type": "string",
"description": "Content of the embed"
},
"color": {
"type": "integer",
"description": "Color of the embed in hex format (e.g., 0xFF0000 for red)",
"default": 0xFF0000
}
},
"required": ["title", "description"]
}
}
},
{
"type": "function",
"function": {
"name": "create_thread",
"description": "Create a new discussion thread from the message",
"parameters": {
"type": "object",
"properties": {
"name": {
"type": "string",
"description": "Name/topic for the thread. Common patterns: 'X vs Y', '[topic] is overrated/underrated', or topics about safety/maintenance/review"
},
"message_id": {
"type": "integer",
"description": "Optional ID of message to create thread from"
}
},
"required": ["name"]
}
}
}
]
def function_to_tool_call(function_name: str, arguments: Dict[str, Any]) -> Dict[str, Any]:
"""Convert a function call response to a tool call format."""
return {
"name": function_name,
"arguments": arguments
}
def parse_tool_response(response: Dict[str, Any]) -> Dict[str, Any]:
"""Parse an API response looking for tool calls."""
if not response or "choices" not in response:
return {}
choice = response["choices"][0]
if "message" not in choice:
return {}
message = choice["message"]
if "tool_calls" not in message:
return {}
tool_calls = []
for tool_call in message["tool_calls"]:
if tool_call["type"] == "function":
tool_calls.append(function_to_tool_call(
tool_call["function"]["name"],
tool_call["function"]["arguments"]
))
return {"tool_calls": tool_calls} if tool_calls else {}