Enhance tool functionality by adding function definitions and schemas; initialize database before handlers; improve API response handling for tool calls
This commit is contained in:
2
.gitignore
vendored
2
.gitignore
vendored
@@ -26,3 +26,5 @@ discord_bot.log.4
|
||||
discord_bot.log.5
|
||||
queue_state.json
|
||||
system_prompt.yaml
|
||||
discord_glhf/__pycache__/bot.cpython-313.pyc
|
||||
discord_glhf/handlers/__pycache__/event_handler.cpython-313.pyc
|
||||
|
||||
@@ -255,6 +255,9 @@ class APIManager:
|
||||
}
|
||||
)
|
||||
|
||||
# Import tools
|
||||
from .handlers.function_tools import TOOLS, parse_tool_response
|
||||
|
||||
# Prepare request data
|
||||
data = {
|
||||
"model": params["model"],
|
||||
@@ -262,6 +265,7 @@ class APIManager:
|
||||
"temperature": params["temperature"],
|
||||
"max_tokens": params["max_tokens"],
|
||||
"stream": False, # Disable streaming for all requests
|
||||
"tools": TOOLS # Add available tools
|
||||
}
|
||||
logger.debug(
|
||||
"API Request Details:\n"
|
||||
@@ -320,14 +324,24 @@ class APIManager:
|
||||
if json_response.get("choices"):
|
||||
choice = json_response["choices"][0]
|
||||
if "message" in choice:
|
||||
content = choice["message"].get("content", "")
|
||||
message = choice["message"]
|
||||
|
||||
# Check for tool calls first
|
||||
if "tool_calls" in message:
|
||||
tool_response = parse_tool_response(json_response)
|
||||
if tool_response:
|
||||
return True, json.dumps(tool_response)
|
||||
|
||||
# Fall back to regular content
|
||||
content = message.get("content", "")
|
||||
if content and content.strip():
|
||||
return True, content
|
||||
else:
|
||||
logger.error("Empty content in response")
|
||||
logger.debug(f"Response headers: {response.headers}")
|
||||
logger.debug(f"Response status: {response.status}")
|
||||
return False, None
|
||||
|
||||
# Log if no valid response found
|
||||
logger.error("No valid content or tool calls in response")
|
||||
logger.debug(f"Response headers: {response.headers}")
|
||||
logger.debug(f"Response status: {response.status}")
|
||||
return False, None
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"Failed to decode response: {e}")
|
||||
return False, None
|
||||
|
||||
@@ -57,7 +57,11 @@ class DiscordBot:
|
||||
try:
|
||||
async with self._init_lock:
|
||||
if not self._initialized:
|
||||
# Initialize all handlers first
|
||||
# Initialize database first
|
||||
await self.db_manager.init_db()
|
||||
logger.info("Database initialized")
|
||||
|
||||
# Initialize all handlers
|
||||
self.message_handler = MessageHandler(self.db_manager)
|
||||
logger.info("Message handler initialized")
|
||||
|
||||
@@ -420,3 +424,4 @@ def run_bot():
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_bot()
|
||||
|
||||
|
||||
@@ -25,8 +25,10 @@ class EventHandler:
|
||||
self.db_manager = db_manager
|
||||
self.api_manager = api_manager
|
||||
# Use provided handlers or create new ones
|
||||
self.message_handler = message_handler if message_handler else MessageHandler(db_manager)
|
||||
self.image_handler = image_handler if image_handler else ImageHandler(api_manager)
|
||||
self.message_handler = message_handler if message_handler else MessageHandler(
|
||||
db_manager)
|
||||
self.image_handler = image_handler if image_handler else ImageHandler(
|
||||
api_manager)
|
||||
self.tool_handler = tool_handler if tool_handler else ToolHandler(bot)
|
||||
|
||||
# Set this handler as the queue's message processor
|
||||
@@ -42,7 +44,7 @@ class EventHandler:
|
||||
f"<@{mention[2:-1]}>" # Raw mention format
|
||||
]
|
||||
pattern = '|'.join(patterns)
|
||||
|
||||
|
||||
# Replace all mention formats with the proper mention
|
||||
return re.sub(pattern, mention, response)
|
||||
|
||||
@@ -203,7 +205,8 @@ class EventHandler:
|
||||
created_at=datetime.utcnow()
|
||||
)
|
||||
|
||||
logger.info(f"Adding reaction response to queue from {user.display_name}")
|
||||
logger.info(
|
||||
f"Adding reaction response to queue from {user.display_name}")
|
||||
# Queue the reaction like a regular message
|
||||
await self.queue_manager.add_message(
|
||||
channel=channel,
|
||||
@@ -231,7 +234,8 @@ class EventHandler:
|
||||
|
||||
# Early duplicate checks before any processing
|
||||
if any(item.message.id == message.id for item in self.queue_manager.message_queue.processing):
|
||||
logger.debug(f"Message {message.id} already in processing, skipping")
|
||||
logger.debug(
|
||||
f"Message {message.id} already in processing, skipping")
|
||||
return
|
||||
|
||||
# Get current queue size
|
||||
@@ -245,7 +249,8 @@ class EventHandler:
|
||||
message_id=message.id
|
||||
)
|
||||
if message_processed:
|
||||
logger.debug(f"Message {message.id} already processed, skipping")
|
||||
logger.debug(
|
||||
f"Message {message.id} already processed, skipping")
|
||||
return
|
||||
|
||||
# Check for duplicate content in history
|
||||
@@ -253,12 +258,13 @@ class EventHandler:
|
||||
user_id=0,
|
||||
channel_id=message.channel.id
|
||||
)
|
||||
current_content = f"{message.author.display_name} ({message.author.name}) (<@{message.author.id}>): {message.content}"
|
||||
current_content = f"Username: {message.author.name} Message Content: {message.content}"
|
||||
for hist_msg in recent_history:
|
||||
hist_content = hist_msg.get("content", {}).get("content", "") if isinstance(
|
||||
hist_msg.get("content"), dict) else hist_msg.get("content", "")
|
||||
if hist_content == current_content:
|
||||
logger.debug(f"Duplicate message content detected for message {message.id}, skipping")
|
||||
logger.debug(
|
||||
f"Duplicate message content detected for message {message.id}, skipping")
|
||||
return
|
||||
|
||||
# Update user activity in database
|
||||
@@ -340,7 +346,12 @@ class EventHandler:
|
||||
message_uuid = str(uuid.uuid4())
|
||||
|
||||
# Format the message with user info
|
||||
formatted_content = f"{item.message.author.display_name} ({item.message.author.name}) (<@{item.message.author.id}>): {item.prompt}"
|
||||
formatted_content = (
|
||||
f"(username): {item.message.author.name} "
|
||||
f"(message): {item.prompt} "
|
||||
"(IMPORTANT: DO NOT REPEAT THIS PATTERN WHEN RESPONDING! "
|
||||
"Respond only to the (message) part of the prompt. Now go.)"
|
||||
)
|
||||
|
||||
# Check if this message is already in history
|
||||
message_in_history = False
|
||||
@@ -352,7 +363,8 @@ class EventHandler:
|
||||
if hist_content == formatted_content:
|
||||
message_in_history = True
|
||||
if isinstance(hist_msg.get("metadata"), dict):
|
||||
stored_uuid = hist_msg["metadata"].get("message_uuid")
|
||||
stored_uuid = hist_msg["metadata"].get(
|
||||
"message_uuid")
|
||||
break
|
||||
|
||||
# Use stored UUID if found, otherwise use new UUID
|
||||
@@ -382,7 +394,8 @@ class EventHandler:
|
||||
await self.store_message(
|
||||
user_id=item.message.author.id,
|
||||
role="user",
|
||||
content={"content": formatted_content, "metadata": message_metadata},
|
||||
content={"content": formatted_content,
|
||||
"metadata": message_metadata},
|
||||
channel_id=item.channel.id,
|
||||
message_uuid=message_uuid,
|
||||
)
|
||||
@@ -441,18 +454,31 @@ Available tools:
|
||||
if not response:
|
||||
return
|
||||
|
||||
# Parse tool calls and get processed response
|
||||
tool_calls, final_response, mentioned_users = self.tool_handler.parse_tool_calls(
|
||||
response, message_id=item.message.id, channel_id=item.channel.id
|
||||
)
|
||||
# Parse response for potential tool calls
|
||||
try:
|
||||
response_data = json.loads(response)
|
||||
if isinstance(response_data, dict) and "tool_calls" in response_data:
|
||||
tool_calls = response_data["tool_calls"]
|
||||
final_response = "" # Will be populated after tool execution
|
||||
else:
|
||||
tool_calls = []
|
||||
final_response = response
|
||||
except json.JSONDecodeError:
|
||||
# Not JSON, treat as regular response
|
||||
tool_calls = []
|
||||
final_response = response
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting API completion: {e}")
|
||||
return
|
||||
|
||||
# Execute tool calls
|
||||
for tool_name, args in tool_calls:
|
||||
# Execute tool calls and collect responses
|
||||
response_parts = []
|
||||
for tool_call in tool_calls:
|
||||
try:
|
||||
if tool_name == "find_user":
|
||||
name = tool_call["name"]
|
||||
args = json.loads(tool_call["arguments"]) if isinstance(tool_call["arguments"], str) else tool_call["arguments"]
|
||||
|
||||
if name == "mention_user":
|
||||
# Check if we're trying to mention the message author
|
||||
if args["name"].lower() in [
|
||||
item.message.author.name.lower(),
|
||||
@@ -466,45 +492,47 @@ Available tools:
|
||||
)
|
||||
|
||||
if mention:
|
||||
final_response = self._clean_mentions(
|
||||
final_response,
|
||||
mention,
|
||||
item.message.author.display_name,
|
||||
args["name"]
|
||||
)
|
||||
response_parts.append(mention)
|
||||
|
||||
elif tool_name == "add_reaction":
|
||||
elif name == "add_reaction":
|
||||
await self.tool_handler.add_reaction(
|
||||
item.message.id, item.channel.id, args["emoji"]
|
||||
)
|
||||
|
||||
elif tool_name == "create_embed":
|
||||
elif name == "create_embed":
|
||||
await self.tool_handler.create_embed(
|
||||
channel=item.channel, content=args["content"]
|
||||
channel=item.channel,
|
||||
content=f"{args['title']}\n{args['description']}",
|
||||
color=args.get('color', 0xFF0000)
|
||||
)
|
||||
response_parts.append(f"[Embed sent: {args['title']}]")
|
||||
|
||||
elif tool_name == "create_thread":
|
||||
# Create Discord thread first
|
||||
elif name == "create_thread":
|
||||
message_id = args.get('message_id', item.message.id)
|
||||
discord_thread = await self.tool_handler.create_thread(
|
||||
item.channel.id, args["name"], item.message.id
|
||||
item.channel.id, args["name"], message_id
|
||||
)
|
||||
|
||||
if discord_thread:
|
||||
# Create thread in database and update activity
|
||||
thread_id = await self.message_handler.create_thread_for_message(
|
||||
item.channel.id, item.message.author.id, args["name"]
|
||||
)
|
||||
if thread_id:
|
||||
await self.db_manager.update_thread_activity(thread_id)
|
||||
logger.info(f"Created and stored thread '{args['name']}' in database")
|
||||
response_parts.append(f"[Thread created: {args['name']}]")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing tool {tool_name}: {e}")
|
||||
logger.error(f"Error executing tool {name}: {e}")
|
||||
|
||||
# If we executed any tools, combine their responses
|
||||
if response_parts:
|
||||
final_response = " ".join(response_parts) if not final_response else f"{final_response}\n{' '.join(response_parts)}"
|
||||
|
||||
# Send the response
|
||||
if final_response:
|
||||
author = item.message.author
|
||||
owner_tag = ' [BOT OWNER]' if int(author.id) == BOT_OWNER_ID else ''
|
||||
owner_tag = ' [BOT OWNER]' if int(
|
||||
author.id) == BOT_OWNER_ID else ''
|
||||
logger.info(
|
||||
f"Bot response to {author.display_name} "
|
||||
f"({author.name}#{author.discriminator})"
|
||||
@@ -514,7 +542,7 @@ Available tools:
|
||||
reference = None
|
||||
if hasattr(item.message, '_state'): # Check if it's a real Discord Message
|
||||
reference = item.message
|
||||
|
||||
|
||||
sent_message = await self.message_handler.safe_send(
|
||||
item.channel, final_response, reference=reference
|
||||
)
|
||||
@@ -558,7 +586,7 @@ Available tools:
|
||||
source_info = {"type": "web"}
|
||||
else:
|
||||
source_info = {"type": "discord"}
|
||||
|
||||
|
||||
response_metadata = {
|
||||
"response_id": response_uuid,
|
||||
"user_info": {
|
||||
|
||||
117
discord_glhf/handlers/function_tools.py
Normal file
117
discord_glhf/handlers/function_tools.py
Normal file
@@ -0,0 +1,117 @@
|
||||
"""Function definitions and tool schemas for LLM function calling."""
|
||||
|
||||
from typing import Dict, Any
|
||||
|
||||
TOOLS = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "mention_user",
|
||||
"description": "Mention a Discord user in the response",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {
|
||||
"type": "string",
|
||||
"description": "Username or nickname to mention"
|
||||
}
|
||||
},
|
||||
"required": ["name"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "add_reaction",
|
||||
"description": "Add an emoji reaction to the message",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"emoji": {
|
||||
"type": "string",
|
||||
"description": "The emoji to add (Unicode emoji, custom emoji <:name:id>, or standard emoji :name:)"
|
||||
}
|
||||
},
|
||||
"required": ["emoji"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "create_embed",
|
||||
"description": "Create a rich embed message with title and content",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"title": {
|
||||
"type": "string",
|
||||
"description": "Title of the embed"
|
||||
},
|
||||
"description": {
|
||||
"type": "string",
|
||||
"description": "Content of the embed"
|
||||
},
|
||||
"color": {
|
||||
"type": "integer",
|
||||
"description": "Color of the embed in hex format (e.g., 0xFF0000 for red)",
|
||||
"default": 0xFF0000
|
||||
}
|
||||
},
|
||||
"required": ["title", "description"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "create_thread",
|
||||
"description": "Create a new discussion thread from the message",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {
|
||||
"type": "string",
|
||||
"description": "Name/topic for the thread. Common patterns: 'X vs Y', '[topic] is overrated/underrated', or topics about safety/maintenance/review"
|
||||
},
|
||||
"message_id": {
|
||||
"type": "integer",
|
||||
"description": "Optional ID of message to create thread from"
|
||||
}
|
||||
},
|
||||
"required": ["name"]
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
def function_to_tool_call(function_name: str, arguments: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Convert a function call response to a tool call format."""
|
||||
return {
|
||||
"name": function_name,
|
||||
"arguments": arguments
|
||||
}
|
||||
|
||||
def parse_tool_response(response: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Parse an API response looking for tool calls."""
|
||||
if not response or "choices" not in response:
|
||||
return {}
|
||||
|
||||
choice = response["choices"][0]
|
||||
if "message" not in choice:
|
||||
return {}
|
||||
|
||||
message = choice["message"]
|
||||
if "tool_calls" not in message:
|
||||
return {}
|
||||
|
||||
tool_calls = []
|
||||
for tool_call in message["tool_calls"]:
|
||||
if tool_call["type"] == "function":
|
||||
tool_calls.append(function_to_tool_call(
|
||||
tool_call["function"]["name"],
|
||||
tool_call["function"]["arguments"]
|
||||
))
|
||||
|
||||
return {"tool_calls": tool_calls} if tool_calls else {}
|
||||
Reference in New Issue
Block a user