Added several things: Parser to import JSON from DiscordChatExporter, ability to train without bot running and more.

This commit is contained in:
pacnpal
2024-12-27 11:17:21 -05:00
parent 44ddad6b58
commit ec0c4e6c84
14 changed files with 4549 additions and 6025 deletions

3
.gitignore vendored
View File

@@ -68,3 +68,6 @@ error.json
markovDB.json
/config/
/exports/
/build/
/dist/

View File

@@ -14,11 +14,67 @@ A Markov chain bot using markov-strings.
* User: `/mark`
* Bot: ![worms are not baby snakes, by the way](img/respond.png)
### Training from a file
### Training from files
Using the `json` option in the `/train` command, you can import a list of messages.
You can train the bot using JSON files in two ways:
1. Using the `json` option in the `/train` command to import a single file of messages.
2. Using the command line to train from either a single file or an entire directory of JSON files.
#### Using the Discord Command
Use the `json` option in the `/train` command to import a single file of messages.
An example JSON file can be seen [here](img/example-training.json).
#### Using the Command Line
For bulk training from multiple files, you can use the command line interface. First, build the training script:
```bash
# Build the TypeScript files
npm run build
```
Then you can use the training script:
```bash
# Train from a single JSON file
node build/train.js <guildId> <jsonPath> [--keep-existing]
# Train from all JSON files in a directory
node build/train.js <guildId> <directoryPath> --directory [--keep-existing] [--expose-gc]
```
Options:
- `--keep-existing`: Don't clear existing training data before importing
- `--directory`: Process all JSON files in the specified directory
- `--expose-gc`: Enable garbage collection for better memory management (recommended for large directories)
Each JSON file should contain an array of messages in this format:
```json
[
{
"message": "Message content",
"attachments": ["optional", "attachment", "urls"]
}
]
```
When training from a directory:
- All .json files in the directory will be processed
- Files are processed sequentially to manage memory usage
- Progress is shown for each file
- A total count of processed messages is provided at the end
Security and Performance Notes:
- The directory must be within the project's working directory for security
- The process will create lock files in the config directory to prevent concurrent training
- Memory usage is monitored and managed automatically
- For large directories, use the `--expose-gc` flag for better memory management:
```bash
node --expose-gc build/train.js <guildId> <directoryPath> --directory
```
- Training can be safely interrupted with Ctrl+C; state will be preserved
- Use `--keep-existing` to resume interrupted training
## Setup
This bot stores your Discord server's entire message history, so a public instance to invite to your server is not available due to obvious data privacy concerns. Instead, you can host it yourself.

21
example-training.json Normal file
View File

@@ -0,0 +1,21 @@
[
{
"message": "Hello world!",
"attachments": []
},
{
"message": "This is an example message with an attachment",
"attachments": ["https://example.com/image.jpg"]
},
{
"message": "Another training message",
"attachments": []
},
{
"message": "Messages can have multiple attachments",
"attachments": [
"https://example.com/image1.jpg",
"https://example.com/image2.jpg"
]
}
]

124
imports/discord-parser.py Normal file
View File

@@ -0,0 +1,124 @@
import json
import os
import re
from glob import glob
import argparse
def strip_mentions(content):
# Pattern for Discord mentions: <@!?[0-9]+> or <@&[0-9]+> or <#[0-9]+>
mention_pattern = r'<@!?\d+>|<@&\d+>|<#\d+>'
return re.sub(mention_pattern, '', content).strip()
def extract_channel_name(filename):
# Extract channel name and part number from Discord export filename format
# Pattern: channel name before [numbers] and optional [part X]
channel_match = re.search(r'- ([^[\]]+) \[\d+\]', filename)
part_match = re.search(r'\[part (\d+)\]', filename, re.IGNORECASE)
if channel_match:
channel_name = channel_match.group(1).strip().lower()
# Remove any category names (text after last hyphen)
channel_name = channel_name.split(' - ')[-1].strip()
# Add part number if it exists
if part_match:
part_num = part_match.group(1)
return f"{channel_name}_part{part_num}.json"
return f"{channel_name}.json"
return "unknown_channel.json"
def parse_discord_export(input_file, output_dir):
# Generate output filename based on channel name
output_filename = extract_channel_name(input_file)
output_file = os.path.join(output_dir, output_filename)
try:
# Read the input file
with open(input_file, 'r', encoding='utf-8') as f:
data = json.load(f)
# Extract just the messages and their attachments
output_messages = []
for msg in data['messages']:
# Skip empty messages after stripping mentions
stripped_content = strip_mentions(msg['content'])
if not stripped_content and not msg['attachments']:
continue
message_obj = {
'message': stripped_content
}
# If there are attachments, add them to the message object
if msg['attachments']:
message_obj['attachments'] = [
attachment['url'] for attachment in msg['attachments']
]
# Only add messages that have content or attachments
if message_obj['message'] or 'attachments' in message_obj:
output_messages.append(message_obj)
# Create output directory if it doesn't exist
os.makedirs(output_dir, exist_ok=True)
# Write the output file
with open(output_file, 'w', encoding='utf-8') as f:
json.dump(output_messages, f, indent=4, ensure_ascii=False)
print(f"Successfully parsed: {input_file} -> {output_file}")
return True
except Exception as e:
print(f"Error processing {input_file}: {str(e)}")
return False
def process_directory(input_dir, output_dir):
# Find all Discord export JSON files in the input directory
pattern = os.path.join(input_dir, "*Discord*.json")
files = glob(pattern)
if not files:
print(f"No Discord export files found in: {input_dir}")
return
success_count = 0
failure_count = 0
for file in files:
if parse_discord_export(file, output_dir):
success_count += 1
else:
failure_count += 1
print(f"\nProcessing complete:")
print(f"Successfully processed: {success_count} files")
print(f"Failed to process: {failure_count} files")
def main():
parser = argparse.ArgumentParser(description='Process Discord export JSON files.')
parser.add_argument('-i', '--input', default='.',
help='Input directory containing Discord export files (default: current directory)')
parser.add_argument('-o', '--output', default='output',
help='Output directory for processed files (default: output)')
args = parser.parse_args()
# Convert to absolute paths
input_dir = os.path.abspath(args.input)
output_dir = os.path.abspath(args.output)
# Check if input directory exists
if not os.path.exists(input_dir):
print(f"Error: Input directory does not exist: {input_dir}")
return
print(f"Processing files from: {input_dir}")
print(f"Saving output to: {output_dir}")
process_directory(input_dir, output_dir)
if __name__ == "__main__":
main()

9348
package-lock.json generated

File diff suppressed because it is too large Load Diff

View File

@@ -4,14 +4,14 @@
"description": "A conversational Markov chain bot for Discord",
"main": "dist/index.js",
"scripts": {
"start": "NODE_ENV=production pm2 start --no-daemon dist/index.js",
"start": "./node_modules/.bin/pm2 start --no-daemon dist/index.js",
"start:ts": "ts-node src/index.ts",
"build": "rimraf dist && tsc",
"lint": "tsc --noEmit && eslint .",
"build": "rimraf build && /opt/homebrew/bin/node ./node_modules/.bin/tsc",
"lint": "./node-v20.18.1-darwin-x64/bin/node ./node_modules/.bin/tsc --noEmit && ./node-v20.18.1-darwin-x64/bin/node ./node_modules/.bin/eslint .",
"docker:build": "docker build . -t charlocharlie/markov-discord:latest --target deploy",
"docker:run": "docker run --rm -ti -v $(pwd)/config:/usr/app/config charlocharlie/markov-discord:latest",
"typeorm": "node --require ts-node/register ./node_modules/typeorm/cli.js",
"docs": "typedoc --out docs src/config/classes.ts"
"typeorm": "./node-v20.18.1-darwin-x64/bin/node --require ts-node/register ./node_modules/typeorm/cli.js",
"docs": "./node-v20.18.1-darwin-x64/bin/node ./node_modules/.bin/typedoc --out docs src/config/classes.ts"
},
"repository": "https://github.com/claabs/markov-discord.git",
"keywords": [
@@ -31,7 +31,7 @@
},
"license": "MIT",
"dependencies": {
"better-sqlite3": "^8.7.0",
"better-sqlite3": "^9.6.0",
"bufferutil": "^4.0.8",
"class-transformer": "^0.5.1",
"class-validator": "^0.14.1",
@@ -42,6 +42,7 @@
"json5": "^2.2.3",
"markov-strings-db": "^4.2.0",
"node-fetch": "^2.6.7",
"node-gyp": "^11.0.0",
"pino": "^7.11.0",
"pino-pretty": "^7.6.1",
"reflect-metadata": "^0.2.2",

View File

@@ -163,4 +163,19 @@ export class AppConfig {
@IsOptional()
@IsString()
devGuildId = process.env.DEV_GUILD_ID;
/**
* A list of channel IDs where the bot will respond to mentions.
* If empty, the bot will respond to mentions in any channel.
* @example ["734548250895319070"]
* @default []
* @env RESPONSE_CHANNEL_IDS (comma separated)
*/
@IsArray()
@IsString({ each: true })
@Type(() => String)
@IsOptional()
responseChannelIds = process.env.RESPONSE_CHANNEL_IDS
? process.env.RESPONSE_CHANNEL_IDS.split(',').map((id) => id.trim())
: [];
}

View File

@@ -21,9 +21,6 @@ export const inviteCommand = new SlashCommandBuilder()
export const messageCommand = new SlashCommandBuilder()
.setName(config.slashCommandName)
.setDescription('Generate a message from learned past messages')
.addBooleanOption((tts) =>
tts.setName('tts').setDescription('Read the message via text-to-speech.').setRequired(false),
)
.addBooleanOption((debug) =>
debug
.setName('debug')
@@ -49,6 +46,38 @@ const channelOptionsGenerator = (builder: SlashCommandChannelOption, index: numb
.setRequired(index === 0)
.addChannelTypes(ChannelType.GuildText);
export const autoRespondCommand = new SlashCommandBuilder()
.setName('autorespond')
.setDescription('Configure channels where the bot will automatically respond to all messages')
.addSubcommand((sub) => {
sub
.setName('add')
.setDescription('Add channels where the bot will automatically respond to all messages');
Array.from(Array(CHANNEL_OPTIONS_MAX).keys()).forEach((index) =>
sub.addChannelOption((opt) => channelOptionsGenerator(opt, index)),
);
return sub;
})
.addSubcommand((sub) => {
sub
.setName('remove')
.setDescription('Remove channels from auto-response');
Array.from(Array(CHANNEL_OPTIONS_MAX).keys()).forEach((index) =>
sub.addChannelOption((opt) => channelOptionsGenerator(opt, index)),
);
return sub;
})
.addSubcommand((sub) =>
sub
.setName('list')
.setDescription('List the channels where the bot auto-responds to messages'),
)
.addSubcommand((sub) =>
sub
.setName('modify')
.setDescription('Add or remove auto-respond channels via select menu UI (first 25 text channels only)'),
);
export const listenChannelCommand = new SlashCommandBuilder()
.setName('listen')
.setDescription('Change what channels the bot actively listens to and learns from.')
@@ -110,7 +139,8 @@ const commands = [
inviteCommand.toJSON(),
messageCommand.toJSON(),
listenChannelCommand.toJSON(),
trainCommand.toJSON(),
autoRespondCommand.toJSON(),
trainCommand.toJSON()
];
export async function deployCommands(clientId: string) {

View File

@@ -12,6 +12,11 @@ export class Channel extends BaseEntity {
})
listen: boolean;
@Column({
default: false,
})
autoRespond: boolean;
@ManyToOne(() => Guild, (guild) => guild.channels)
guild: Guild;
}

View File

@@ -1,6 +1,8 @@
import 'source-map-support/register';
import { CONFIG_DIR } from './config/setup';
import 'reflect-metadata';
import * as Discord from 'discord.js';
import Markov, {
MarkovGenerateOptions,
MarkovConstructorOptions,
@@ -24,6 +26,7 @@ import {
listenChannelCommand,
messageCommand,
trainCommand,
autoRespondCommand,
} from './deploy-commands';
import { getRandomElement, getVersion, packageJson } from './util';
import ormconfig from './ormconfig';
@@ -35,6 +38,7 @@ interface MarkovDataCustom {
interface SelectMenuChannel {
id: string;
listen?: boolean;
autoRespond?: boolean;
name?: string;
}
@@ -53,11 +57,19 @@ type AgnosticReplyOptions = Omit<Discord.MessageCreateOptions, 'reply' | 'sticke
const INVALID_PERMISSIONS_MESSAGE = 'You do not have the permissions for this action.';
const INVALID_GUILD_MESSAGE = 'This action must be performed within a server.';
const rest = new Discord.REST({ version: '10' }).setToken(config.token);
const rest = new Discord.REST({
version: '10',
timeout: 120000, // 120 seconds
retries: 3
}).setToken(config.token);
const client = new Discord.Client<true>({
failIfNotExists: false,
intents: [Discord.GatewayIntentBits.GuildMessages, Discord.GatewayIntentBits.Guilds],
intents: [
Discord.GatewayIntentBits.GuildMessages,
Discord.GatewayIntentBits.Guilds,
Discord.GatewayIntentBits.GuildMembers
],
presence: {
activities: [
{
@@ -114,6 +126,53 @@ async function isValidChannel(channel: Discord.TextBasedChannel): Promise<boolea
return dbChannel?.listen || false;
}
async function isAutoRespondChannel(channel: Discord.TextBasedChannel): Promise<boolean> {
const channelId = getGuildChannelId(channel);
if (!channelId) return false;
const dbChannel = await Channel.findOneBy({ id: channelId });
return dbChannel?.autoRespond || false;
}
async function getAutoRespondChannels(guild: Discord.Guild): Promise<Discord.TextChannel[]> {
const dbChannels = await Channel.findBy({ guild: { id: guild.id }, autoRespond: true });
const channels = (
await Promise.all(
dbChannels.map(async (dbc) => {
try {
return guild.channels.fetch(dbc.id);
} catch (err) {
L.error({ erroredChannel: dbc, channelId: dbc.id }, 'Error fetching channel');
throw err;
}
}),
)
).filter((c): c is Discord.TextChannel => c !== null && c instanceof Discord.TextChannel);
return channels;
}
async function addAutoRespondChannels(channels: Discord.TextChannel[], guildId: string): Promise<void> {
const dbChannels = channels.map((c) => {
return Channel.create({ id: c.id, guild: Guild.create({ id: guildId }), autoRespond: true });
});
await Channel.save(dbChannels);
}
async function removeAutoRespondChannels(channels: Discord.TextChannel[], guildId: string): Promise<void> {
const dbChannels = channels.map((c) => {
return Channel.create({ id: c.id, guild: Guild.create({ id: guildId }), autoRespond: false });
});
await Channel.save(dbChannels);
}
async function listAutoRespondChannels(interaction: Discord.CommandInteraction): Promise<string> {
if (!interaction.guildId || !interaction.guild) return INVALID_GUILD_MESSAGE;
const channels = await getAutoRespondChannels(interaction.guild);
const channelText = channels.reduce((list, channel) => {
return `${list}\n • <#${channel.id}>`;
}, '');
return `The bot will automatically respond to all messages in ${channels.length} channel(s).${channelText}`;
}
function isHumanAuthoredMessage(message: Discord.Message | Discord.PartialMessage): boolean {
return !(message.author?.bot || message.system);
}
@@ -151,7 +210,12 @@ async function getTextChannels(guild: Discord.Guild): Promise<SelectMenuChannel[
}));
const notFoundDbChannels: SelectMenuChannel[] = textChannels
.filter((c) => !foundDbChannels.find((d) => d.id === c.id))
.map((c) => ({ id: c.id, listen: false, name: textChannels.find((t) => t.id === c.id)?.name }));
.map((c) => ({
id: c.id,
listen: false,
autoRespond: false,
name: textChannels.find((t) => t.id === c.id)?.name
}));
const limitedDbChannels = foundDbChannelsWithName
.concat(notFoundDbChannels)
.slice(0, MAX_SELECT_OPTIONS);
@@ -223,7 +287,7 @@ function isAllowedUser(
return true;
}
type MessageCommands = 'respond' | 'train' | 'help' | 'invite' | 'debug' | 'tts' | null;
type MessageCommands = 'respond' | 'train' | 'help' | 'invite' | 'debug' | null;
/**
* Reads a new message and checks if and which command it is.
@@ -246,8 +310,6 @@ function validateMessage(message: Discord.Message): MessageCommands {
command = 'invite';
} else if (split[1] === 'debug') {
command = 'debug';
} else if (split[1] === 'tts') {
command = 'tts';
}
}
return command;
@@ -272,12 +334,23 @@ function messageToData(message: Discord.Message): AddDataProps {
/**
* Recursively gets all messages in a text channel's history.
*/
import { TrainingStateManager } from './training-state';
async function saveGuildMessageHistory(
interaction: Discord.Message | Discord.CommandInteraction,
clean = true,
): Promise<string> {
if (!isModerator(interaction.member)) return INVALID_PERMISSIONS_MESSAGE;
if (!interaction.guildId || !interaction.guild) return INVALID_GUILD_MESSAGE;
const stateManager = new TrainingStateManager(interaction.guildId, CONFIG_DIR);
// Check if training is already in progress
const currentState = stateManager.getState();
if (currentState.inProgress) {
return `Training is already in progress. Last update: ${currentState.lastUpdate}. Use /train with clean=true to restart.`;
}
const markov = await getMarkovByGuildId(interaction.guildId);
const channels = await getValidChannels(interaction.guild);
@@ -287,11 +360,22 @@ async function saveGuildMessageHistory(
}
if (clean) {
L.debug('Deleting old data');
L.debug('Deleting old data and resetting state');
await markov.delete();
stateManager.reset();
} else {
L.debug('Not deleting old data during training');
// Filter out already processed channels when not cleaning
const unprocessedChannels = channels.filter(
channel => !stateManager.isChannelProcessed(channel.id)
);
if (unprocessedChannels.length === 0) {
return 'All channels have been processed. Use clean=true to retrain.';
}
channels.splice(0, channels.length, ...unprocessedChannels);
}
stateManager.startTraining();
const channelIds = channels.map((c) => c.id);
L.debug({ channelIds }, `Training from text channels`);
@@ -332,15 +416,37 @@ async function saveGuildMessageHistory(
progressMessage = (await interaction.followUp(updateMessageData)) as Discord.Message;
}
const PAGE_SIZE = 100;
const UPDATE_RATE = 1000; // In number of messages processed
const PAGE_SIZE = 50; // Reduced page size for better stability
const UPDATE_RATE = 500; // More frequent updates
const BATCH_SIZE = 100; // Number of messages to process before a small delay
const BATCH_DELAY = 100; // Milliseconds to wait between batches
const MAX_MEMORY_USAGE = 1024 * 1024 * 1024; // 1GB memory limit
let lastUpdate = 0;
let messagesCount = 0;
let firstMessageDate: number | undefined;
let batchCount = 0;
// Monitor memory usage
const getMemoryUsage = () => {
const used = process.memoryUsage();
return used.heapUsed;
};
// Add delay between batches
const processingDelay = () => new Promise(resolve => setTimeout(resolve, BATCH_DELAY));
try {
// eslint-disable-next-line no-restricted-syntax
for (const channel of channels) {
let oldestMessageID: string | undefined;
try {
// Check if we should skip this channel (already processed)
if (stateManager.isChannelProcessed(channel.id)) {
L.debug({ channelId: channel.id }, 'Skipping already processed channel');
continue;
}
let keepGoing = true;
let oldestMessageID = stateManager.shouldResumeFromMessage(channel.id);
L.debug({ channelId: channel.id, messagesCount }, `Training from channel`);
const channelCreateDate = channel.createdTimestamp;
const channelEta = makeEta({ autostart: true, min: 0, max: 1, historyTimeConstant: 30 });
@@ -407,15 +513,55 @@ async function saveGuildMessageHistory(
allBatchMessages = allBatchMessages.concat(channelBatchMessages);
try {
// Check memory usage before processing
const memoryUsage = getMemoryUsage();
if (memoryUsage > MAX_MEMORY_USAGE) {
L.warn('Memory usage too high, waiting for garbage collection');
await processingDelay();
global.gc?.(); // Optional garbage collection if --expose-gc flag is used
}
// Filter and data map messages to be ready for addition to the corpus
const humanAuthoredMessages = allBatchMessages
.filter((m) => isHumanAuthoredMessage(m))
.map(messageToData);
L.trace({ oldestMessageID }, `Saving ${humanAuthoredMessages.length} messages`);
// Process messages in smaller batches for stability
for (let i = 0; i < humanAuthoredMessages.length; i += BATCH_SIZE) {
const batch = humanAuthoredMessages.slice(i, i + BATCH_SIZE);
L.trace({ oldestMessageID, batchSize: batch.length }, `Saving batch of messages`);
try {
// eslint-disable-next-line no-await-in-loop
await markov.addData(humanAuthoredMessages);
L.trace('Finished saving messages');
messagesCount += humanAuthoredMessages.length;
await markov.addData(batch);
batchCount++;
messagesCount += batch.length;
// Update state after successful batch
const lastMessage = allBatchMessages.last();
if (lastMessage) {
stateManager.updateProgress(channel.id, lastMessage.id, messagesCount);
}
// Add delay between batches
if (batchCount % 5 === 0) { // Every 5 batches
await processingDelay();
}
} catch (err) {
stateManager.recordError(err as Error, channel.id, oldestMessageID);
L.error({ err, batchSize: batch.length }, 'Error saving batch of messages');
// Continue with next batch instead of failing completely
continue;
}
}
L.trace('Finished processing message batches');
} catch (err) {
L.error({ err }, 'Error processing messages');
// Wait a bit before continuing to next batch of messages
await processingDelay();
}
const lastMessage = channelBatchMessages.last();
// Update tracking metrics
@@ -459,10 +605,22 @@ async function saveGuildMessageHistory(
});
}
}
}
} catch (err) {
L.error({ err }, 'Error processing channel');
stateManager.recordError(err as Error);
// Continue with next channel
}
}
L.info({ channelIds }, `Trained from ${messagesCount} past human authored messages.`);
return `Trained from ${messagesCount} past human authored messages.`;
L.info({ channelIds }, `Trained from ${messagesCount} past human authored messages.`);
stateManager.finishTraining();
return `Trained from ${messagesCount} past human authored messages.`;
} catch (err) {
const error = err as Error;
L.error({ err }, 'Error during training completion');
stateManager.recordError(error);
return `Training encountered an error: ${error.message}. Use clean=false to resume from last checkpoint.`;
}
}
interface JSONImport {
@@ -481,7 +639,17 @@ async function trainFromAttachmentJson(
if (!isModerator(interaction.member)) return INVALID_PERMISSIONS_MESSAGE;
if (!interaction.guildId || !interaction.guild) return INVALID_GUILD_MESSAGE;
const { guildId } = interaction;
const stateManager = new TrainingStateManager(guildId, CONFIG_DIR);
// Check if training is already in progress
const currentState = stateManager.getState();
if (currentState.inProgress) {
return `Training is already in progress. Last update: ${currentState.lastUpdate}. Use clean=true to restart.`;
}
const markov = await getMarkovByGuildId(guildId);
stateManager.startTraining();
let trainingData: AddDataProps[];
try {
@@ -517,14 +685,49 @@ async function trainFromAttachmentJson(
if (clean) {
L.debug('Deleting old data');
await markov.delete();
stateManager.reset();
} else {
L.debug('Not deleting old data during training');
}
await markov.addData(trainingData);
const BATCH_SIZE = 100;
const BATCH_DELAY = 100;
let processedCount = 0;
let batchCount = 0;
L.info(`Trained from ${trainingData.length} past human authored messages.`);
return `Trained from ${trainingData.length} past human authored messages.`;
try {
// Process messages in batches
for (let i = 0; i < trainingData.length; i += BATCH_SIZE) {
const batch = trainingData.slice(i, i + BATCH_SIZE);
try {
await markov.addData(batch);
processedCount += batch.length;
batchCount++;
// Update state after successful batch
stateManager.updateProgress('json-import', i.toString(), processedCount);
// Add delay between batches
if (batchCount % 5 === 0) {
await new Promise(resolve => setTimeout(resolve, BATCH_DELAY));
}
} catch (err) {
L.error({ err, batchIndex: i }, 'Error processing JSON batch');
stateManager.recordError(err as Error, 'json-import', i.toString());
// Continue with next batch instead of failing completely
continue;
}
}
L.info(`Successfully trained from ${processedCount} messages from JSON.`);
stateManager.finishTraining();
return `Successfully trained from ${processedCount} messages from JSON.`;
} catch (err) {
const error = err as Error;
L.error({ err }, 'Error during JSON training completion');
stateManager.recordError(error);
return `Training encountered an error: ${error.message}. Use clean=false to resume from last checkpoint.`;
}
}
interface GenerateResponse {
@@ -534,7 +737,6 @@ interface GenerateResponse {
}
interface GenerateOptions {
tts?: boolean;
debug?: boolean;
startSeed?: string;
}
@@ -551,7 +753,7 @@ async function generateResponse(
options?: GenerateOptions,
): Promise<GenerateResponse> {
L.debug({ options }, 'Responding...');
const { tts = false, debug = false, startSeed } = options || {};
const { debug = false, startSeed } = options || {};
if (!interaction.guildId) {
L.warn('Received an interaction without a guildId');
return { error: { content: INVALID_GUILD_MESSAGE } };
@@ -568,7 +770,6 @@ async function generateResponse(
L.info({ string: response.string }, 'Generated response text');
L.debug({ response }, 'Generated response object');
const messageOpts: AgnosticReplyOptions = {
tts,
allowedMentions: { repliedUser: false, parse: [] },
};
const attachmentUrls = response.refs
@@ -652,12 +853,17 @@ function helpMessage(): AgnosticReplyOptions {
.addFields([
{
name: `${config.messageCommandPrefix} or /${messageCommand.name}`,
value: `Generates a sentence to say based on the chat database. Send your message as TTS to recieve it as TTS.`,
value: `Generates a sentence based on the chat database.`,
},
{
name: `/${listenChannelCommand.name}`,
value: `Add, remove, list, or modify the list of channels the bot listens to.`,
value: `Add, remove, list, or modify the list of channels the bot listens to and learns from.`,
},
{
name: `/autorespond`,
value: `Add, remove, list, or modify the list of channels where the bot will automatically respond to all messages.`,
},
{
@@ -674,11 +880,6 @@ function helpMessage(): AgnosticReplyOptions {
name: `${config.messageCommandPrefix} debug or /${messageCommand.name} debug: True`,
value: `Runs the ${config.messageCommandPrefix} command and follows it up with debug info.`,
},
{
name: `${config.messageCommandPrefix} tts or /${messageCommand.name} tts: True`,
value: `Runs the ${config.messageCommandPrefix} command and reads it with text-to-speech.`,
},
])
.setFooter({
text: `${packageJson().name} ${getVersion()} by ${
@@ -696,9 +897,8 @@ function generateInviteUrl(): string {
permissions: [
'ViewChannel',
'SendMessages',
'SendTTSMessages',
'AttachFiles',
'ReadMessageHistory',
'ReadMessageHistory'
],
});
}
@@ -789,11 +989,6 @@ client.on('messageCreate', async (message) => {
const generatedResponse = await generateResponse(message);
await handleResponseMessage(generatedResponse, message);
}
if (command === 'tts') {
L.debug('Responding to legacy command tts');
const generatedResponse = await generateResponse(message, { tts: true });
await handleResponseMessage(generatedResponse, message);
}
if (command === 'debug') {
L.debug('Responding to legacy command debug');
const generatedResponse = await generateResponse(message, { debug: true });
@@ -802,11 +997,23 @@ client.on('messageCreate', async (message) => {
if (command === null) {
if (isHumanAuthoredMessage(message)) {
if (client.user && message.mentions.has(client.user)) {
// Check if response channels are configured and if this channel is allowed
if (config.responseChannelIds.length > 0 && !config.responseChannelIds.includes(message.channel.id)) {
L.debug('Ignoring mention in non-response channel');
return;
}
L.debug('Responding to mention');
// <@!278354154563567636> how are you doing?
const startSeed = message.content.replace(/<@!\d+>/g, '').trim();
const generatedResponse = await generateResponse(message, { startSeed });
await handleResponseMessage(generatedResponse, message);
} else if (await isAutoRespondChannel(message.channel)) {
// Auto-respond to all messages in configured channels using message content as context
L.debug('Auto-responding in configured channel with context');
const startSeed = message.content.trim();
const generatedResponse = await generateResponse(message, { startSeed });
await handleResponseMessage(generatedResponse, message);
}
if (await isValidChannel(message.channel)) {
@@ -848,7 +1055,7 @@ client.on('threadDelete', async (thread) => {
await markov.removeTags([thread.id]);
});
// eslint-disable-next-line consistent-return
client.on('interactionCreate', async (interaction) => {
if (interaction.isChatInputCommand()) {
L.info({ command: interaction.commandName }, 'Recieved slash command');
@@ -859,23 +1066,12 @@ client.on('interactionCreate', async (interaction) => {
await interaction.reply(inviteMessage());
} else if (interaction.commandName === messageCommand.name) {
await interaction.deferReply();
const tts = interaction.options.getBoolean('tts') || false;
const debug = interaction.options.getBoolean('debug') || false;
const startSeed = interaction.options.getString('seed')?.trim() || undefined;
const generatedResponse = await generateResponse(interaction, { tts, debug, startSeed });
const generatedResponse = await generateResponse(interaction, { debug, startSeed });
/**
* TTS doesn't work when using editReply, so instead we use delete + followUp
* However, delete + followUp is ugly and shows the bot replying to "Message could not be loaded.",
* so we avoid it if possible
*/
if (generatedResponse.message) {
if (generatedResponse.message.tts) {
await interaction.deleteReply();
await interaction.followUp(generatedResponse.message);
} else {
await interaction.editReply(generatedResponse.message);
}
} else {
await interaction.deleteReply();
}
@@ -943,6 +1139,67 @@ client.on('interactionCreate', async (interaction) => {
ephemeral: true,
});
}
} else if (interaction.commandName === autoRespondCommand.name) {
await interaction.deferReply();
const subCommand = interaction.options.getSubcommand(true) as 'add' | 'remove' | 'list' | 'modify';
if (subCommand === 'list') {
const reply = await listAutoRespondChannels(interaction);
await interaction.editReply(reply);
} else if (subCommand === 'add') {
if (!isModerator(interaction.member)) {
return handleUnprivileged(interaction);
}
if (!interaction.guildId) {
return handleNoGuild(interaction);
}
const channels = getChannelsFromInteraction(interaction);
await addAutoRespondChannels(channels, interaction.guildId);
await interaction.editReply(
`Added ${channels.length} text channels to auto-respond list.`
);
} else if (subCommand === 'remove') {
if (!isModerator(interaction.member)) {
return handleUnprivileged(interaction);
}
if (!interaction.guildId) {
return handleNoGuild(interaction);
}
const channels = getChannelsFromInteraction(interaction);
await removeAutoRespondChannels(channels, interaction.guildId);
await interaction.editReply(
`Removed ${channels.length} text channels from auto-respond list.`
);
} else if (subCommand === 'modify') {
if (!interaction.guild) {
return handleNoGuild(interaction);
}
if (!isModerator(interaction.member)) {
await handleUnprivileged(interaction);
}
await interaction.deleteReply();
const dbTextChannels = await getTextChannels(interaction.guild);
const row = new Discord.ActionRowBuilder<Discord.StringSelectMenuBuilder>().addComponents(
new Discord.StringSelectMenuBuilder()
.setCustomId('autorespond-modify-select')
.setPlaceholder('Nothing selected')
.setMinValues(0)
.setMaxValues(dbTextChannels.length)
.addOptions(
dbTextChannels.map((c) => ({
label: `#${c.name}` || c.id,
value: c.id,
default: c.autoRespond || false,
})),
),
);
await interaction.followUp({
content: 'Select which channels you would like the bot to auto-respond in',
components: [row],
ephemeral: true,
});
}
} else if (interaction.commandName === trainCommand.name) {
await interaction.deferReply();
const clean = interaction.options.getBoolean('clean') ?? true;
@@ -990,6 +1247,37 @@ client.on('interactionCreate', async (interaction) => {
content: 'Updated actively listened to channels list.',
ephemeral: true,
});
} else if (interaction.customId === 'autorespond-modify-select') {
await interaction.deferUpdate();
const { guild } = interaction;
if (!isModerator(interaction.member)) {
return handleUnprivileged(interaction, false);
}
if (!guild) {
return handleNoGuild(interaction, false);
}
const allChannels =
(interaction.component as Discord.StringSelectMenuComponent).options?.map((o) => o.value) ||
[];
const selectedChannelIds = interaction.values;
const textChannels = (
await Promise.all(
allChannels.map(async (c) => {
return guild.channels.fetch(c);
}),
)
).filter((c): c is Discord.TextChannel => c !== null && c instanceof Discord.TextChannel);
const unselectedChannels = textChannels.filter((t) => !selectedChannelIds.includes(t.id));
const selectedChannels = textChannels.filter((t) => selectedChannelIds.includes(t.id));
await addAutoRespondChannels(selectedChannels, guild.id);
await removeAutoRespondChannels(unselectedChannels, guild.id);
await interaction.followUp({
content: 'Updated auto-respond channels list.',
ephemeral: true,
});
}
}
});

391
src/train.ts Normal file
View File

@@ -0,0 +1,391 @@
import 'source-map-support/register';
import 'reflect-metadata';
import Markov, { MarkovConstructorOptions, AddDataProps } from 'markov-strings-db';
import { DataSource } from 'typeorm';
import { promises as fs } from 'fs';
import path from 'path';
import { config } from './config';
import ormconfig from './ormconfig';
import { Guild } from './entity/Guild';
import { Channel } from './entity/Channel';
import L from './logger';
import { MarkovDataCustom } from './types';
import { TrainingStateManager } from './training-state';
import { CONFIG_DIR } from './config/setup';
const markovOpts: MarkovConstructorOptions = {
stateSize: config.stateSize,
};
// Constants for batch processing
const BATCH_SIZE = 100; // Process messages in batches
const BATCH_DELAY = 100; // Milliseconds to wait between batches
const MAX_MEMORY_USAGE = 1024 * 1024 * 1024; // 1GB memory limit
// Monitor memory usage
const getMemoryUsage = () => {
const used = process.memoryUsage();
return used.heapUsed;
};
// Add delay between batches
const processingDelay = () => new Promise(resolve => setTimeout(resolve, BATCH_DELAY));
async function getMarkovByGuildId(guildId: string): Promise<Markov> {
const markov = new Markov({ id: guildId, options: { ...markovOpts, id: guildId } });
L.trace({ guildId }, 'Setting up markov instance');
await markov.setup(); // Connect the markov instance to the DB to assign it an ID
return markov;
}
interface JSONImport {
message: string;
attachments?: string[];
}
/**
* Train from a JSON file containing messages
*/
async function trainFromJson(
guildId: string,
jsonPath: string,
clean = true,
): Promise<string> {
const markov = await getMarkovByGuildId(guildId);
let trainingData: AddDataProps[];
try {
const fileContent = await fs.readFile(jsonPath, 'utf-8');
const importData = JSON.parse(fileContent) as JSONImport[];
// Filter out invalid entries first
const validData = importData.filter((datum, index) => {
if (!datum.message || typeof datum.message !== 'string') {
L.debug({ index }, 'Skipping entry without valid message');
return false;
}
if (datum.attachments?.some(a => typeof a !== 'string')) {
L.debug({ index }, 'Skipping entry with invalid attachments');
return false;
}
return true;
});
// Map valid entries to training data
trainingData = validData.map(datum => {
let custom: MarkovDataCustom | undefined;
if (datum.attachments?.length) {
custom = { attachments: datum.attachments };
}
return {
string: datum.message,
custom,
tags: [guildId]
};
});
} catch (err) {
L.error(err);
if (err instanceof SyntaxError) {
return 'The provided JSON file has invalid formatting. See the logs for details.';
}
return `Error reading file: ${err instanceof Error ? err.message : 'Unknown error'}`;
}
if (clean) {
L.debug('Deleting old data');
await markov.delete();
} else {
L.debug('Not deleting old data during training');
}
let processedCount = 0;
let batchCount = 0;
const totalMessages = trainingData.length;
// Process messages in batches
for (let i = 0; i < trainingData.length; i += BATCH_SIZE) {
try {
// Check memory usage
const memoryUsage = getMemoryUsage();
if (memoryUsage > MAX_MEMORY_USAGE) {
L.warn('Memory usage too high, waiting for garbage collection');
await processingDelay();
global.gc?.(); // Optional garbage collection if --expose-gc flag is used
}
const batch = trainingData.slice(i, i + BATCH_SIZE);
await markov.addData(batch);
processedCount += batch.length;
batchCount++;
// Log progress
if (batchCount % 5 === 0) {
const progress = (processedCount / totalMessages * 100).toFixed(2);
L.info(`Progress: ${progress}% (${processedCount}/${totalMessages} messages)`);
await processingDelay(); // Add delay every 5 batches
}
} catch (err) {
L.error({ err, batchIndex: i }, 'Error processing batch');
// Continue with next batch instead of failing completely
await processingDelay(); // Wait a bit longer after an error
continue;
}
}
L.info(`Successfully trained from ${processedCount} messages.`);
return `Successfully trained from ${processedCount} messages.`;
}
/**
* Train from all JSON files in a directory
*/
/**
* Train from all JSON files in a directory
* @param guildId The Discord guild ID
* @param dirPath Path to directory containing JSON files
* @param clean Whether to clean existing data before training
*/
/**
* Acquire a lock file for training to prevent concurrent processes
*/
async function acquireTrainingLock(guildId: string): Promise<boolean> {
const lockPath = path.join(CONFIG_DIR, `${guildId}_training.lock`);
try {
await fs.writeFile(lockPath, process.pid.toString(), { flag: 'wx' });
return true;
} catch (err) {
if ((err as NodeJS.ErrnoException).code === 'EEXIST') {
try {
const pid = parseInt(await fs.readFile(lockPath, 'utf-8'));
try {
// Check if process is still running
process.kill(pid, 0);
return false; // Process is still running
} catch {
// Process is not running, safe to remove lock
await fs.unlink(lockPath);
await fs.writeFile(lockPath, process.pid.toString());
return true;
}
} catch {
// Error reading/writing lock file
return false;
}
}
return false;
}
}
/**
* Release the training lock file
*/
async function releaseTrainingLock(guildId: string): Promise<void> {
const lockPath = path.join(CONFIG_DIR, `${guildId}_training.lock`);
try {
await fs.unlink(lockPath);
} catch {
// Ignore errors during cleanup
}
}
/**
* Sanitize and validate a directory path
*/
async function validateDirectoryPath(dirPath: string): Promise<string> {
// Resolve to absolute path
const absolutePath = path.resolve(dirPath);
// Prevent directory traversal
const normalizedPath = path.normalize(absolutePath);
if (!normalizedPath.startsWith(process.cwd())) {
throw new Error('Directory must be within current working directory');
}
// Verify directory exists and is accessible
try {
const stats = await fs.stat(normalizedPath);
if (!stats.isDirectory()) {
throw new Error('Path is not a directory');
}
await fs.access(normalizedPath, fs.constants.R_OK);
return normalizedPath;
} catch (err) {
throw new Error(`Invalid directory path: ${err instanceof Error ? err.message : 'Unknown error'}`);
}
}
/**
* Train from all JSON files in a directory
*/
async function trainFromDirectory(
guildId: string,
dirPath: string,
clean = true,
): Promise<string> {
L.debug({ guildId, dirPath, clean }, 'Starting directory training');
const stateManager = new TrainingStateManager(guildId, CONFIG_DIR);
// Set up cleanup handler
const cleanup = async () => {
try {
await releaseTrainingLock(guildId);
stateManager.finishTraining();
} catch (err) {
L.error({ err }, 'Error during cleanup');
}
};
// Handle process termination
process.once('SIGINT', cleanup);
process.once('SIGTERM', cleanup);
try {
// Try to acquire lock
if (!await acquireTrainingLock(guildId)) {
return 'Another training process is already running. Please wait for it to complete.';
}
// Always reset state at the start of training
stateManager.reset();
try {
// Validate and normalize directory path
const absolutePath = await validateDirectoryPath(dirPath);
// Get all JSON files in the directory
L.trace({ dirPath: absolutePath }, 'Reading directory');
const files = await fs.readdir(absolutePath);
const jsonFiles = files.filter(file => file.toLowerCase().endsWith('.json'));
if (jsonFiles.length === 0) {
L.warn({ dirPath: absolutePath }, 'No JSON files found in directory');
return 'No JSON files found in the specified directory.';
}
let totalProcessed = 0;
let batchCount = 0;
L.info({ fileCount: jsonFiles.length }, 'Found JSON files to process');
stateManager.startTraining();
// Process first file with clean flag, subsequent files without cleaning
for (let i = 0; i < jsonFiles.length; i++) {
const jsonPath = path.join(absolutePath, jsonFiles[i]);
const fileNumber = i + 1;
L.debug(
{ file: jsonFiles[i], progress: `${fileNumber}/${jsonFiles.length}` },
'Processing file'
);
try {
// Check memory usage before processing file
const memoryUsage = getMemoryUsage();
if (memoryUsage > MAX_MEMORY_USAGE) {
L.warn('Memory usage too high, waiting for garbage collection');
await processingDelay();
global.gc?.(); // Optional garbage collection if --expose-gc flag is used
}
// Check if we should skip this file (already processed)
if (!clean && stateManager.isChannelProcessed(jsonFiles[i])) {
L.debug({ file: jsonFiles[i] }, 'Skipping already processed file');
continue;
}
const result = await trainFromJson(
guildId,
jsonPath,
i === 0 ? clean : false // Only clean on first file
);
// Extract number of processed messages from result string
const processed = parseInt(result.match(/\d+/)?.[0] || '0');
totalProcessed += processed;
batchCount++;
// Update state after each file
stateManager.updateProgress('json-import', jsonFiles[i], totalProcessed);
L.trace(
{ file: jsonFiles[i], processed, totalProcessed },
'File processing complete'
);
// Add delay between files
if (batchCount % 5 === 0) {
await processingDelay();
}
// Clear any references that might be held
if (global.gc) {
global.gc();
}
} catch (err) {
const error = err as Error;
L.error(
{ error: error.message, file: jsonFiles[i], stack: error.stack },
'Error processing JSON file'
);
stateManager.recordError(error, 'json-import', jsonFiles[i]);
// Add longer delay after error
await processingDelay();
// Continue with next file instead of failing completely
continue;
}
}
const summary = { totalProcessed, fileCount: jsonFiles.length };
L.info(summary, 'Directory training complete');
return `Successfully trained from ${totalProcessed} messages across ${jsonFiles.length} files.`;
} finally {
// Clean up regardless of success/failure
await cleanup();
// Remove process termination handlers
process.off('SIGINT', cleanup);
process.off('SIGTERM', cleanup);
}
} catch (err) {
const error = err as Error;
L.error(
{ error: error.message, stack: error.stack, dirPath },
'Error during directory training'
);
stateManager.recordError(error);
return `Training encountered an error: ${error.message}. Use clean=false to resume from last checkpoint.`;
}
}
async function main(): Promise<void> {
const args = process.argv.slice(2);
if (args.length < 2) {
console.log('Usage: node train.js <guildId> <path> [--keep-existing] [--directory]');
console.log('Options:');
console.log(' --keep-existing Keep existing training data');
console.log(' --directory Process all JSON files in the specified directory');
process.exit(1);
}
const guildId = args[0];
const inputPath = args[1];
const keepExisting = args.includes('--keep-existing');
const isDirectory = args.includes('--directory');
const dataSourceOptions = Markov.extendDataSourceOptions(ormconfig);
const dataSource = new DataSource(dataSourceOptions);
await dataSource.initialize();
// Ensure guild exists in DB
await Guild.upsert(Guild.create({ id: guildId }), ['id']);
const result = isDirectory
? await trainFromDirectory(guildId, inputPath, !keepExisting)
: await trainFromJson(guildId, inputPath, !keepExisting);
console.log(result);
await dataSource.destroy();
}
if (require.main === module) {
main().catch(console.error);
}

113
src/training-state.ts Normal file
View File

@@ -0,0 +1,113 @@
import fs from 'fs-extra';
import path from 'path';
import { TrainingState } from './types';
import L from './logger';
export class TrainingStateManager {
private stateFile: string;
private state: TrainingState;
constructor(guildId: string, configDir: string = 'config') {
this.stateFile = path.join(configDir, 'training-state', `${guildId}.json`);
// Initialize with default state
this.state = {
guildId,
processedChannels: [],
totalMessages: 0,
lastUpdate: new Date().toISOString(),
inProgress: false
};
// Ensure directory exists
fs.ensureDirSync(path.dirname(this.stateFile));
// Load existing state if available
this.loadState();
}
private loadState(): void {
try {
if (fs.existsSync(this.stateFile)) {
const savedState = fs.readJsonSync(this.stateFile);
this.state = { ...this.state, ...savedState };
L.info({ guildId: this.state.guildId }, 'Loaded existing training state');
}
} catch (err) {
L.error({ err }, 'Error loading training state');
// Keep using default state if load fails
}
}
private saveState(): void {
try {
fs.writeJsonSync(this.stateFile, this.state, { spaces: 2 });
} catch (err) {
L.error({ err }, 'Error saving training state');
}
}
public startTraining(): void {
this.state.inProgress = true;
this.state.error = undefined;
this.state.lastUpdate = new Date().toISOString();
this.saveState();
}
public finishTraining(): void {
this.state.inProgress = false;
this.state.lastUpdate = new Date().toISOString();
this.saveState();
}
public updateProgress(channelId: string, messageId: string, messagesProcessed: number): void {
this.state.lastChannelId = channelId;
this.state.lastMessageId = messageId;
this.state.totalMessages = messagesProcessed;
this.state.lastUpdate = new Date().toISOString();
this.saveState();
}
public markChannelComplete(channelId: string): void {
if (!this.state.processedChannels.includes(channelId)) {
this.state.processedChannels.push(channelId);
this.saveState();
}
}
public recordError(error: Error, channelId?: string, messageId?: string): void {
this.state.error = {
message: error.message,
channelId,
messageId,
timestamp: new Date().toISOString()
};
this.saveState();
}
public isChannelProcessed(channelId: string): boolean {
return this.state.processedChannels.includes(channelId);
}
public shouldResumeFromMessage(channelId: string): string | undefined {
if (this.state.inProgress && this.state.lastChannelId === channelId) {
return this.state.lastMessageId;
}
return undefined;
}
public getState(): TrainingState {
return { ...this.state };
}
public reset(): void {
this.state = {
guildId: this.state.guildId,
processedChannels: [],
totalMessages: 0,
lastUpdate: new Date().toISOString(),
inProgress: false
};
this.saveState();
}
}

19
src/types.ts Normal file
View File

@@ -0,0 +1,19 @@
export interface MarkovDataCustom {
attachments: string[];
}
export interface TrainingState {
guildId: string;
lastMessageId?: string;
lastChannelId?: string;
processedChannels: string[];
totalMessages: number;
lastUpdate: string;
inProgress: boolean;
error?: {
message: string;
channelId?: string;
messageId?: string;
timestamp: string;
};
}

View File

@@ -3,7 +3,7 @@
"compilerOptions": {
"target": "es2021", /* Specify ECMAScript target version: 'ES3' (default), 'ES5', 'ES2015', 'ES2016', 'ES2017', 'ES2018', 'ES2019' or 'ESNEXT'. */
"module": "commonjs", /* Specify module code generation: 'none', 'commonjs', 'amd', 'system', 'umd', 'es2015', or 'ESNext'. */
"outDir": "./dist", /* Redirect output structure to the directory. */
"outDir": "./build", /* Redirect output structure to the directory. */
"removeComments": true, /* Do not emit comments to output. */
"esModuleInterop": true,
"strict": true, /* Enable all strict type-checking options. */