Added several things: Parser to import JSON from DiscordChatExporter, ability to train without bot running and more.

This commit is contained in:
pacnpal
2024-12-27 11:17:21 -05:00
parent 44ddad6b58
commit ec0c4e6c84
14 changed files with 4549 additions and 6025 deletions

3
.gitignore vendored
View File

@@ -68,3 +68,6 @@ error.json
markovDB.json markovDB.json
/config/ /config/
/exports/
/build/
/dist/

View File

@@ -14,11 +14,67 @@ A Markov chain bot using markov-strings.
* User: `/mark` * User: `/mark`
* Bot: ![worms are not baby snakes, by the way](img/respond.png) * Bot: ![worms are not baby snakes, by the way](img/respond.png)
### Training from a file ### Training from files
Using the `json` option in the `/train` command, you can import a list of messages. You can train the bot using JSON files in two ways:
1. Using the `json` option in the `/train` command to import a single file of messages.
2. Using the command line to train from either a single file or an entire directory of JSON files.
#### Using the Discord Command
Use the `json` option in the `/train` command to import a single file of messages.
An example JSON file can be seen [here](img/example-training.json). An example JSON file can be seen [here](img/example-training.json).
#### Using the Command Line
For bulk training from multiple files, you can use the command line interface. First, build the training script:
```bash
# Build the TypeScript files
npm run build
```
Then you can use the training script:
```bash
# Train from a single JSON file
node build/train.js <guildId> <jsonPath> [--keep-existing]
# Train from all JSON files in a directory
node build/train.js <guildId> <directoryPath> --directory [--keep-existing] [--expose-gc]
```
Options:
- `--keep-existing`: Don't clear existing training data before importing
- `--directory`: Process all JSON files in the specified directory
- `--expose-gc`: Enable garbage collection for better memory management (recommended for large directories)
Each JSON file should contain an array of messages in this format:
```json
[
{
"message": "Message content",
"attachments": ["optional", "attachment", "urls"]
}
]
```
When training from a directory:
- All .json files in the directory will be processed
- Files are processed sequentially to manage memory usage
- Progress is shown for each file
- A total count of processed messages is provided at the end
Security and Performance Notes:
- The directory must be within the project's working directory for security
- The process will create lock files in the config directory to prevent concurrent training
- Memory usage is monitored and managed automatically
- For large directories, use the `--expose-gc` flag for better memory management:
```bash
node --expose-gc build/train.js <guildId> <directoryPath> --directory
```
- Training can be safely interrupted with Ctrl+C; state will be preserved
- Use `--keep-existing` to resume interrupted training
## Setup ## Setup
This bot stores your Discord server's entire message history, so a public instance to invite to your server is not available due to obvious data privacy concerns. Instead, you can host it yourself. This bot stores your Discord server's entire message history, so a public instance to invite to your server is not available due to obvious data privacy concerns. Instead, you can host it yourself.

21
example-training.json Normal file
View File

@@ -0,0 +1,21 @@
[
{
"message": "Hello world!",
"attachments": []
},
{
"message": "This is an example message with an attachment",
"attachments": ["https://example.com/image.jpg"]
},
{
"message": "Another training message",
"attachments": []
},
{
"message": "Messages can have multiple attachments",
"attachments": [
"https://example.com/image1.jpg",
"https://example.com/image2.jpg"
]
}
]

124
imports/discord-parser.py Normal file
View File

@@ -0,0 +1,124 @@
import json
import os
import re
from glob import glob
import argparse
def strip_mentions(content):
# Pattern for Discord mentions: <@!?[0-9]+> or <@&[0-9]+> or <#[0-9]+>
mention_pattern = r'<@!?\d+>|<@&\d+>|<#\d+>'
return re.sub(mention_pattern, '', content).strip()
def extract_channel_name(filename):
# Extract channel name and part number from Discord export filename format
# Pattern: channel name before [numbers] and optional [part X]
channel_match = re.search(r'- ([^[\]]+) \[\d+\]', filename)
part_match = re.search(r'\[part (\d+)\]', filename, re.IGNORECASE)
if channel_match:
channel_name = channel_match.group(1).strip().lower()
# Remove any category names (text after last hyphen)
channel_name = channel_name.split(' - ')[-1].strip()
# Add part number if it exists
if part_match:
part_num = part_match.group(1)
return f"{channel_name}_part{part_num}.json"
return f"{channel_name}.json"
return "unknown_channel.json"
def parse_discord_export(input_file, output_dir):
# Generate output filename based on channel name
output_filename = extract_channel_name(input_file)
output_file = os.path.join(output_dir, output_filename)
try:
# Read the input file
with open(input_file, 'r', encoding='utf-8') as f:
data = json.load(f)
# Extract just the messages and their attachments
output_messages = []
for msg in data['messages']:
# Skip empty messages after stripping mentions
stripped_content = strip_mentions(msg['content'])
if not stripped_content and not msg['attachments']:
continue
message_obj = {
'message': stripped_content
}
# If there are attachments, add them to the message object
if msg['attachments']:
message_obj['attachments'] = [
attachment['url'] for attachment in msg['attachments']
]
# Only add messages that have content or attachments
if message_obj['message'] or 'attachments' in message_obj:
output_messages.append(message_obj)
# Create output directory if it doesn't exist
os.makedirs(output_dir, exist_ok=True)
# Write the output file
with open(output_file, 'w', encoding='utf-8') as f:
json.dump(output_messages, f, indent=4, ensure_ascii=False)
print(f"Successfully parsed: {input_file} -> {output_file}")
return True
except Exception as e:
print(f"Error processing {input_file}: {str(e)}")
return False
def process_directory(input_dir, output_dir):
# Find all Discord export JSON files in the input directory
pattern = os.path.join(input_dir, "*Discord*.json")
files = glob(pattern)
if not files:
print(f"No Discord export files found in: {input_dir}")
return
success_count = 0
failure_count = 0
for file in files:
if parse_discord_export(file, output_dir):
success_count += 1
else:
failure_count += 1
print(f"\nProcessing complete:")
print(f"Successfully processed: {success_count} files")
print(f"Failed to process: {failure_count} files")
def main():
parser = argparse.ArgumentParser(description='Process Discord export JSON files.')
parser.add_argument('-i', '--input', default='.',
help='Input directory containing Discord export files (default: current directory)')
parser.add_argument('-o', '--output', default='output',
help='Output directory for processed files (default: output)')
args = parser.parse_args()
# Convert to absolute paths
input_dir = os.path.abspath(args.input)
output_dir = os.path.abspath(args.output)
# Check if input directory exists
if not os.path.exists(input_dir):
print(f"Error: Input directory does not exist: {input_dir}")
return
print(f"Processing files from: {input_dir}")
print(f"Saving output to: {output_dir}")
process_directory(input_dir, output_dir)
if __name__ == "__main__":
main()

9348
package-lock.json generated

File diff suppressed because it is too large Load Diff

View File

@@ -4,14 +4,14 @@
"description": "A conversational Markov chain bot for Discord", "description": "A conversational Markov chain bot for Discord",
"main": "dist/index.js", "main": "dist/index.js",
"scripts": { "scripts": {
"start": "NODE_ENV=production pm2 start --no-daemon dist/index.js", "start": "./node_modules/.bin/pm2 start --no-daemon dist/index.js",
"start:ts": "ts-node src/index.ts", "start:ts": "ts-node src/index.ts",
"build": "rimraf dist && tsc", "build": "rimraf build && /opt/homebrew/bin/node ./node_modules/.bin/tsc",
"lint": "tsc --noEmit && eslint .", "lint": "./node-v20.18.1-darwin-x64/bin/node ./node_modules/.bin/tsc --noEmit && ./node-v20.18.1-darwin-x64/bin/node ./node_modules/.bin/eslint .",
"docker:build": "docker build . -t charlocharlie/markov-discord:latest --target deploy", "docker:build": "docker build . -t charlocharlie/markov-discord:latest --target deploy",
"docker:run": "docker run --rm -ti -v $(pwd)/config:/usr/app/config charlocharlie/markov-discord:latest", "docker:run": "docker run --rm -ti -v $(pwd)/config:/usr/app/config charlocharlie/markov-discord:latest",
"typeorm": "node --require ts-node/register ./node_modules/typeorm/cli.js", "typeorm": "./node-v20.18.1-darwin-x64/bin/node --require ts-node/register ./node_modules/typeorm/cli.js",
"docs": "typedoc --out docs src/config/classes.ts" "docs": "./node-v20.18.1-darwin-x64/bin/node ./node_modules/.bin/typedoc --out docs src/config/classes.ts"
}, },
"repository": "https://github.com/claabs/markov-discord.git", "repository": "https://github.com/claabs/markov-discord.git",
"keywords": [ "keywords": [
@@ -31,7 +31,7 @@
}, },
"license": "MIT", "license": "MIT",
"dependencies": { "dependencies": {
"better-sqlite3": "^8.7.0", "better-sqlite3": "^9.6.0",
"bufferutil": "^4.0.8", "bufferutil": "^4.0.8",
"class-transformer": "^0.5.1", "class-transformer": "^0.5.1",
"class-validator": "^0.14.1", "class-validator": "^0.14.1",
@@ -42,6 +42,7 @@
"json5": "^2.2.3", "json5": "^2.2.3",
"markov-strings-db": "^4.2.0", "markov-strings-db": "^4.2.0",
"node-fetch": "^2.6.7", "node-fetch": "^2.6.7",
"node-gyp": "^11.0.0",
"pino": "^7.11.0", "pino": "^7.11.0",
"pino-pretty": "^7.6.1", "pino-pretty": "^7.6.1",
"reflect-metadata": "^0.2.2", "reflect-metadata": "^0.2.2",

View File

@@ -163,4 +163,19 @@ export class AppConfig {
@IsOptional() @IsOptional()
@IsString() @IsString()
devGuildId = process.env.DEV_GUILD_ID; devGuildId = process.env.DEV_GUILD_ID;
/**
* A list of channel IDs where the bot will respond to mentions.
* If empty, the bot will respond to mentions in any channel.
* @example ["734548250895319070"]
* @default []
* @env RESPONSE_CHANNEL_IDS (comma separated)
*/
@IsArray()
@IsString({ each: true })
@Type(() => String)
@IsOptional()
responseChannelIds = process.env.RESPONSE_CHANNEL_IDS
? process.env.RESPONSE_CHANNEL_IDS.split(',').map((id) => id.trim())
: [];
} }

View File

@@ -21,9 +21,6 @@ export const inviteCommand = new SlashCommandBuilder()
export const messageCommand = new SlashCommandBuilder() export const messageCommand = new SlashCommandBuilder()
.setName(config.slashCommandName) .setName(config.slashCommandName)
.setDescription('Generate a message from learned past messages') .setDescription('Generate a message from learned past messages')
.addBooleanOption((tts) =>
tts.setName('tts').setDescription('Read the message via text-to-speech.').setRequired(false),
)
.addBooleanOption((debug) => .addBooleanOption((debug) =>
debug debug
.setName('debug') .setName('debug')
@@ -49,6 +46,38 @@ const channelOptionsGenerator = (builder: SlashCommandChannelOption, index: numb
.setRequired(index === 0) .setRequired(index === 0)
.addChannelTypes(ChannelType.GuildText); .addChannelTypes(ChannelType.GuildText);
export const autoRespondCommand = new SlashCommandBuilder()
.setName('autorespond')
.setDescription('Configure channels where the bot will automatically respond to all messages')
.addSubcommand((sub) => {
sub
.setName('add')
.setDescription('Add channels where the bot will automatically respond to all messages');
Array.from(Array(CHANNEL_OPTIONS_MAX).keys()).forEach((index) =>
sub.addChannelOption((opt) => channelOptionsGenerator(opt, index)),
);
return sub;
})
.addSubcommand((sub) => {
sub
.setName('remove')
.setDescription('Remove channels from auto-response');
Array.from(Array(CHANNEL_OPTIONS_MAX).keys()).forEach((index) =>
sub.addChannelOption((opt) => channelOptionsGenerator(opt, index)),
);
return sub;
})
.addSubcommand((sub) =>
sub
.setName('list')
.setDescription('List the channels where the bot auto-responds to messages'),
)
.addSubcommand((sub) =>
sub
.setName('modify')
.setDescription('Add or remove auto-respond channels via select menu UI (first 25 text channels only)'),
);
export const listenChannelCommand = new SlashCommandBuilder() export const listenChannelCommand = new SlashCommandBuilder()
.setName('listen') .setName('listen')
.setDescription('Change what channels the bot actively listens to and learns from.') .setDescription('Change what channels the bot actively listens to and learns from.')
@@ -110,7 +139,8 @@ const commands = [
inviteCommand.toJSON(), inviteCommand.toJSON(),
messageCommand.toJSON(), messageCommand.toJSON(),
listenChannelCommand.toJSON(), listenChannelCommand.toJSON(),
trainCommand.toJSON(), autoRespondCommand.toJSON(),
trainCommand.toJSON()
]; ];
export async function deployCommands(clientId: string) { export async function deployCommands(clientId: string) {

View File

@@ -12,6 +12,11 @@ export class Channel extends BaseEntity {
}) })
listen: boolean; listen: boolean;
@Column({
default: false,
})
autoRespond: boolean;
@ManyToOne(() => Guild, (guild) => guild.channels) @ManyToOne(() => Guild, (guild) => guild.channels)
guild: Guild; guild: Guild;
} }

View File

@@ -1,6 +1,8 @@
import 'source-map-support/register'; import 'source-map-support/register';
import { CONFIG_DIR } from './config/setup';
import 'reflect-metadata'; import 'reflect-metadata';
import * as Discord from 'discord.js'; import * as Discord from 'discord.js';
import Markov, { import Markov, {
MarkovGenerateOptions, MarkovGenerateOptions,
MarkovConstructorOptions, MarkovConstructorOptions,
@@ -24,6 +26,7 @@ import {
listenChannelCommand, listenChannelCommand,
messageCommand, messageCommand,
trainCommand, trainCommand,
autoRespondCommand,
} from './deploy-commands'; } from './deploy-commands';
import { getRandomElement, getVersion, packageJson } from './util'; import { getRandomElement, getVersion, packageJson } from './util';
import ormconfig from './ormconfig'; import ormconfig from './ormconfig';
@@ -35,6 +38,7 @@ interface MarkovDataCustom {
interface SelectMenuChannel { interface SelectMenuChannel {
id: string; id: string;
listen?: boolean; listen?: boolean;
autoRespond?: boolean;
name?: string; name?: string;
} }
@@ -53,11 +57,19 @@ type AgnosticReplyOptions = Omit<Discord.MessageCreateOptions, 'reply' | 'sticke
const INVALID_PERMISSIONS_MESSAGE = 'You do not have the permissions for this action.'; const INVALID_PERMISSIONS_MESSAGE = 'You do not have the permissions for this action.';
const INVALID_GUILD_MESSAGE = 'This action must be performed within a server.'; const INVALID_GUILD_MESSAGE = 'This action must be performed within a server.';
const rest = new Discord.REST({ version: '10' }).setToken(config.token); const rest = new Discord.REST({
version: '10',
timeout: 120000, // 120 seconds
retries: 3
}).setToken(config.token);
const client = new Discord.Client<true>({ const client = new Discord.Client<true>({
failIfNotExists: false, failIfNotExists: false,
intents: [Discord.GatewayIntentBits.GuildMessages, Discord.GatewayIntentBits.Guilds], intents: [
Discord.GatewayIntentBits.GuildMessages,
Discord.GatewayIntentBits.Guilds,
Discord.GatewayIntentBits.GuildMembers
],
presence: { presence: {
activities: [ activities: [
{ {
@@ -114,6 +126,53 @@ async function isValidChannel(channel: Discord.TextBasedChannel): Promise<boolea
return dbChannel?.listen || false; return dbChannel?.listen || false;
} }
async function isAutoRespondChannel(channel: Discord.TextBasedChannel): Promise<boolean> {
const channelId = getGuildChannelId(channel);
if (!channelId) return false;
const dbChannel = await Channel.findOneBy({ id: channelId });
return dbChannel?.autoRespond || false;
}
async function getAutoRespondChannels(guild: Discord.Guild): Promise<Discord.TextChannel[]> {
const dbChannels = await Channel.findBy({ guild: { id: guild.id }, autoRespond: true });
const channels = (
await Promise.all(
dbChannels.map(async (dbc) => {
try {
return guild.channels.fetch(dbc.id);
} catch (err) {
L.error({ erroredChannel: dbc, channelId: dbc.id }, 'Error fetching channel');
throw err;
}
}),
)
).filter((c): c is Discord.TextChannel => c !== null && c instanceof Discord.TextChannel);
return channels;
}
async function addAutoRespondChannels(channels: Discord.TextChannel[], guildId: string): Promise<void> {
const dbChannels = channels.map((c) => {
return Channel.create({ id: c.id, guild: Guild.create({ id: guildId }), autoRespond: true });
});
await Channel.save(dbChannels);
}
async function removeAutoRespondChannels(channels: Discord.TextChannel[], guildId: string): Promise<void> {
const dbChannels = channels.map((c) => {
return Channel.create({ id: c.id, guild: Guild.create({ id: guildId }), autoRespond: false });
});
await Channel.save(dbChannels);
}
async function listAutoRespondChannels(interaction: Discord.CommandInteraction): Promise<string> {
if (!interaction.guildId || !interaction.guild) return INVALID_GUILD_MESSAGE;
const channels = await getAutoRespondChannels(interaction.guild);
const channelText = channels.reduce((list, channel) => {
return `${list}\n • <#${channel.id}>`;
}, '');
return `The bot will automatically respond to all messages in ${channels.length} channel(s).${channelText}`;
}
function isHumanAuthoredMessage(message: Discord.Message | Discord.PartialMessage): boolean { function isHumanAuthoredMessage(message: Discord.Message | Discord.PartialMessage): boolean {
return !(message.author?.bot || message.system); return !(message.author?.bot || message.system);
} }
@@ -151,7 +210,12 @@ async function getTextChannels(guild: Discord.Guild): Promise<SelectMenuChannel[
})); }));
const notFoundDbChannels: SelectMenuChannel[] = textChannels const notFoundDbChannels: SelectMenuChannel[] = textChannels
.filter((c) => !foundDbChannels.find((d) => d.id === c.id)) .filter((c) => !foundDbChannels.find((d) => d.id === c.id))
.map((c) => ({ id: c.id, listen: false, name: textChannels.find((t) => t.id === c.id)?.name })); .map((c) => ({
id: c.id,
listen: false,
autoRespond: false,
name: textChannels.find((t) => t.id === c.id)?.name
}));
const limitedDbChannels = foundDbChannelsWithName const limitedDbChannels = foundDbChannelsWithName
.concat(notFoundDbChannels) .concat(notFoundDbChannels)
.slice(0, MAX_SELECT_OPTIONS); .slice(0, MAX_SELECT_OPTIONS);
@@ -223,7 +287,7 @@ function isAllowedUser(
return true; return true;
} }
type MessageCommands = 'respond' | 'train' | 'help' | 'invite' | 'debug' | 'tts' | null; type MessageCommands = 'respond' | 'train' | 'help' | 'invite' | 'debug' | null;
/** /**
* Reads a new message and checks if and which command it is. * Reads a new message and checks if and which command it is.
@@ -246,8 +310,6 @@ function validateMessage(message: Discord.Message): MessageCommands {
command = 'invite'; command = 'invite';
} else if (split[1] === 'debug') { } else if (split[1] === 'debug') {
command = 'debug'; command = 'debug';
} else if (split[1] === 'tts') {
command = 'tts';
} }
} }
return command; return command;
@@ -272,12 +334,23 @@ function messageToData(message: Discord.Message): AddDataProps {
/** /**
* Recursively gets all messages in a text channel's history. * Recursively gets all messages in a text channel's history.
*/ */
import { TrainingStateManager } from './training-state';
async function saveGuildMessageHistory( async function saveGuildMessageHistory(
interaction: Discord.Message | Discord.CommandInteraction, interaction: Discord.Message | Discord.CommandInteraction,
clean = true, clean = true,
): Promise<string> { ): Promise<string> {
if (!isModerator(interaction.member)) return INVALID_PERMISSIONS_MESSAGE; if (!isModerator(interaction.member)) return INVALID_PERMISSIONS_MESSAGE;
if (!interaction.guildId || !interaction.guild) return INVALID_GUILD_MESSAGE; if (!interaction.guildId || !interaction.guild) return INVALID_GUILD_MESSAGE;
const stateManager = new TrainingStateManager(interaction.guildId, CONFIG_DIR);
// Check if training is already in progress
const currentState = stateManager.getState();
if (currentState.inProgress) {
return `Training is already in progress. Last update: ${currentState.lastUpdate}. Use /train with clean=true to restart.`;
}
const markov = await getMarkovByGuildId(interaction.guildId); const markov = await getMarkovByGuildId(interaction.guildId);
const channels = await getValidChannels(interaction.guild); const channels = await getValidChannels(interaction.guild);
@@ -287,11 +360,22 @@ async function saveGuildMessageHistory(
} }
if (clean) { if (clean) {
L.debug('Deleting old data'); L.debug('Deleting old data and resetting state');
await markov.delete(); await markov.delete();
stateManager.reset();
} else { } else {
L.debug('Not deleting old data during training'); L.debug('Not deleting old data during training');
// Filter out already processed channels when not cleaning
const unprocessedChannels = channels.filter(
channel => !stateManager.isChannelProcessed(channel.id)
);
if (unprocessedChannels.length === 0) {
return 'All channels have been processed. Use clean=true to retrain.';
} }
channels.splice(0, channels.length, ...unprocessedChannels);
}
stateManager.startTraining();
const channelIds = channels.map((c) => c.id); const channelIds = channels.map((c) => c.id);
L.debug({ channelIds }, `Training from text channels`); L.debug({ channelIds }, `Training from text channels`);
@@ -332,15 +416,37 @@ async function saveGuildMessageHistory(
progressMessage = (await interaction.followUp(updateMessageData)) as Discord.Message; progressMessage = (await interaction.followUp(updateMessageData)) as Discord.Message;
} }
const PAGE_SIZE = 100; const PAGE_SIZE = 50; // Reduced page size for better stability
const UPDATE_RATE = 1000; // In number of messages processed const UPDATE_RATE = 500; // More frequent updates
const BATCH_SIZE = 100; // Number of messages to process before a small delay
const BATCH_DELAY = 100; // Milliseconds to wait between batches
const MAX_MEMORY_USAGE = 1024 * 1024 * 1024; // 1GB memory limit
let lastUpdate = 0; let lastUpdate = 0;
let messagesCount = 0; let messagesCount = 0;
let firstMessageDate: number | undefined; let firstMessageDate: number | undefined;
let batchCount = 0;
// Monitor memory usage
const getMemoryUsage = () => {
const used = process.memoryUsage();
return used.heapUsed;
};
// Add delay between batches
const processingDelay = () => new Promise(resolve => setTimeout(resolve, BATCH_DELAY));
try {
// eslint-disable-next-line no-restricted-syntax // eslint-disable-next-line no-restricted-syntax
for (const channel of channels) { for (const channel of channels) {
let oldestMessageID: string | undefined; try {
// Check if we should skip this channel (already processed)
if (stateManager.isChannelProcessed(channel.id)) {
L.debug({ channelId: channel.id }, 'Skipping already processed channel');
continue;
}
let keepGoing = true; let keepGoing = true;
let oldestMessageID = stateManager.shouldResumeFromMessage(channel.id);
L.debug({ channelId: channel.id, messagesCount }, `Training from channel`); L.debug({ channelId: channel.id, messagesCount }, `Training from channel`);
const channelCreateDate = channel.createdTimestamp; const channelCreateDate = channel.createdTimestamp;
const channelEta = makeEta({ autostart: true, min: 0, max: 1, historyTimeConstant: 30 }); const channelEta = makeEta({ autostart: true, min: 0, max: 1, historyTimeConstant: 30 });
@@ -407,15 +513,55 @@ async function saveGuildMessageHistory(
allBatchMessages = allBatchMessages.concat(channelBatchMessages); allBatchMessages = allBatchMessages.concat(channelBatchMessages);
try {
// Check memory usage before processing
const memoryUsage = getMemoryUsage();
if (memoryUsage > MAX_MEMORY_USAGE) {
L.warn('Memory usage too high, waiting for garbage collection');
await processingDelay();
global.gc?.(); // Optional garbage collection if --expose-gc flag is used
}
// Filter and data map messages to be ready for addition to the corpus // Filter and data map messages to be ready for addition to the corpus
const humanAuthoredMessages = allBatchMessages const humanAuthoredMessages = allBatchMessages
.filter((m) => isHumanAuthoredMessage(m)) .filter((m) => isHumanAuthoredMessage(m))
.map(messageToData); .map(messageToData);
L.trace({ oldestMessageID }, `Saving ${humanAuthoredMessages.length} messages`);
// Process messages in smaller batches for stability
for (let i = 0; i < humanAuthoredMessages.length; i += BATCH_SIZE) {
const batch = humanAuthoredMessages.slice(i, i + BATCH_SIZE);
L.trace({ oldestMessageID, batchSize: batch.length }, `Saving batch of messages`);
try {
// eslint-disable-next-line no-await-in-loop // eslint-disable-next-line no-await-in-loop
await markov.addData(humanAuthoredMessages); await markov.addData(batch);
L.trace('Finished saving messages'); batchCount++;
messagesCount += humanAuthoredMessages.length; messagesCount += batch.length;
// Update state after successful batch
const lastMessage = allBatchMessages.last();
if (lastMessage) {
stateManager.updateProgress(channel.id, lastMessage.id, messagesCount);
}
// Add delay between batches
if (batchCount % 5 === 0) { // Every 5 batches
await processingDelay();
}
} catch (err) {
stateManager.recordError(err as Error, channel.id, oldestMessageID);
L.error({ err, batchSize: batch.length }, 'Error saving batch of messages');
// Continue with next batch instead of failing completely
continue;
}
}
L.trace('Finished processing message batches');
} catch (err) {
L.error({ err }, 'Error processing messages');
// Wait a bit before continuing to next batch of messages
await processingDelay();
}
const lastMessage = channelBatchMessages.last(); const lastMessage = channelBatchMessages.last();
// Update tracking metrics // Update tracking metrics
@@ -459,10 +605,22 @@ async function saveGuildMessageHistory(
}); });
} }
} }
} } catch (err) {
L.error({ err }, 'Error processing channel');
stateManager.recordError(err as Error);
// Continue with next channel
}
}
L.info({ channelIds }, `Trained from ${messagesCount} past human authored messages.`); L.info({ channelIds }, `Trained from ${messagesCount} past human authored messages.`);
return `Trained from ${messagesCount} past human authored messages.`; stateManager.finishTraining();
return `Trained from ${messagesCount} past human authored messages.`;
} catch (err) {
const error = err as Error;
L.error({ err }, 'Error during training completion');
stateManager.recordError(error);
return `Training encountered an error: ${error.message}. Use clean=false to resume from last checkpoint.`;
}
} }
interface JSONImport { interface JSONImport {
@@ -481,7 +639,17 @@ async function trainFromAttachmentJson(
if (!isModerator(interaction.member)) return INVALID_PERMISSIONS_MESSAGE; if (!isModerator(interaction.member)) return INVALID_PERMISSIONS_MESSAGE;
if (!interaction.guildId || !interaction.guild) return INVALID_GUILD_MESSAGE; if (!interaction.guildId || !interaction.guild) return INVALID_GUILD_MESSAGE;
const { guildId } = interaction; const { guildId } = interaction;
const stateManager = new TrainingStateManager(guildId, CONFIG_DIR);
// Check if training is already in progress
const currentState = stateManager.getState();
if (currentState.inProgress) {
return `Training is already in progress. Last update: ${currentState.lastUpdate}. Use clean=true to restart.`;
}
const markov = await getMarkovByGuildId(guildId); const markov = await getMarkovByGuildId(guildId);
stateManager.startTraining();
let trainingData: AddDataProps[]; let trainingData: AddDataProps[];
try { try {
@@ -517,14 +685,49 @@ async function trainFromAttachmentJson(
if (clean) { if (clean) {
L.debug('Deleting old data'); L.debug('Deleting old data');
await markov.delete(); await markov.delete();
stateManager.reset();
} else { } else {
L.debug('Not deleting old data during training'); L.debug('Not deleting old data during training');
} }
await markov.addData(trainingData); const BATCH_SIZE = 100;
const BATCH_DELAY = 100;
let processedCount = 0;
let batchCount = 0;
L.info(`Trained from ${trainingData.length} past human authored messages.`); try {
return `Trained from ${trainingData.length} past human authored messages.`; // Process messages in batches
for (let i = 0; i < trainingData.length; i += BATCH_SIZE) {
const batch = trainingData.slice(i, i + BATCH_SIZE);
try {
await markov.addData(batch);
processedCount += batch.length;
batchCount++;
// Update state after successful batch
stateManager.updateProgress('json-import', i.toString(), processedCount);
// Add delay between batches
if (batchCount % 5 === 0) {
await new Promise(resolve => setTimeout(resolve, BATCH_DELAY));
}
} catch (err) {
L.error({ err, batchIndex: i }, 'Error processing JSON batch');
stateManager.recordError(err as Error, 'json-import', i.toString());
// Continue with next batch instead of failing completely
continue;
}
}
L.info(`Successfully trained from ${processedCount} messages from JSON.`);
stateManager.finishTraining();
return `Successfully trained from ${processedCount} messages from JSON.`;
} catch (err) {
const error = err as Error;
L.error({ err }, 'Error during JSON training completion');
stateManager.recordError(error);
return `Training encountered an error: ${error.message}. Use clean=false to resume from last checkpoint.`;
}
} }
interface GenerateResponse { interface GenerateResponse {
@@ -534,7 +737,6 @@ interface GenerateResponse {
} }
interface GenerateOptions { interface GenerateOptions {
tts?: boolean;
debug?: boolean; debug?: boolean;
startSeed?: string; startSeed?: string;
} }
@@ -551,7 +753,7 @@ async function generateResponse(
options?: GenerateOptions, options?: GenerateOptions,
): Promise<GenerateResponse> { ): Promise<GenerateResponse> {
L.debug({ options }, 'Responding...'); L.debug({ options }, 'Responding...');
const { tts = false, debug = false, startSeed } = options || {}; const { debug = false, startSeed } = options || {};
if (!interaction.guildId) { if (!interaction.guildId) {
L.warn('Received an interaction without a guildId'); L.warn('Received an interaction without a guildId');
return { error: { content: INVALID_GUILD_MESSAGE } }; return { error: { content: INVALID_GUILD_MESSAGE } };
@@ -568,7 +770,6 @@ async function generateResponse(
L.info({ string: response.string }, 'Generated response text'); L.info({ string: response.string }, 'Generated response text');
L.debug({ response }, 'Generated response object'); L.debug({ response }, 'Generated response object');
const messageOpts: AgnosticReplyOptions = { const messageOpts: AgnosticReplyOptions = {
tts,
allowedMentions: { repliedUser: false, parse: [] }, allowedMentions: { repliedUser: false, parse: [] },
}; };
const attachmentUrls = response.refs const attachmentUrls = response.refs
@@ -652,12 +853,17 @@ function helpMessage(): AgnosticReplyOptions {
.addFields([ .addFields([
{ {
name: `${config.messageCommandPrefix} or /${messageCommand.name}`, name: `${config.messageCommandPrefix} or /${messageCommand.name}`,
value: `Generates a sentence to say based on the chat database. Send your message as TTS to recieve it as TTS.`, value: `Generates a sentence based on the chat database.`,
}, },
{ {
name: `/${listenChannelCommand.name}`, name: `/${listenChannelCommand.name}`,
value: `Add, remove, list, or modify the list of channels the bot listens to.`, value: `Add, remove, list, or modify the list of channels the bot listens to and learns from.`,
},
{
name: `/autorespond`,
value: `Add, remove, list, or modify the list of channels where the bot will automatically respond to all messages.`,
}, },
{ {
@@ -674,11 +880,6 @@ function helpMessage(): AgnosticReplyOptions {
name: `${config.messageCommandPrefix} debug or /${messageCommand.name} debug: True`, name: `${config.messageCommandPrefix} debug or /${messageCommand.name} debug: True`,
value: `Runs the ${config.messageCommandPrefix} command and follows it up with debug info.`, value: `Runs the ${config.messageCommandPrefix} command and follows it up with debug info.`,
}, },
{
name: `${config.messageCommandPrefix} tts or /${messageCommand.name} tts: True`,
value: `Runs the ${config.messageCommandPrefix} command and reads it with text-to-speech.`,
},
]) ])
.setFooter({ .setFooter({
text: `${packageJson().name} ${getVersion()} by ${ text: `${packageJson().name} ${getVersion()} by ${
@@ -696,9 +897,8 @@ function generateInviteUrl(): string {
permissions: [ permissions: [
'ViewChannel', 'ViewChannel',
'SendMessages', 'SendMessages',
'SendTTSMessages',
'AttachFiles', 'AttachFiles',
'ReadMessageHistory', 'ReadMessageHistory'
], ],
}); });
} }
@@ -789,11 +989,6 @@ client.on('messageCreate', async (message) => {
const generatedResponse = await generateResponse(message); const generatedResponse = await generateResponse(message);
await handleResponseMessage(generatedResponse, message); await handleResponseMessage(generatedResponse, message);
} }
if (command === 'tts') {
L.debug('Responding to legacy command tts');
const generatedResponse = await generateResponse(message, { tts: true });
await handleResponseMessage(generatedResponse, message);
}
if (command === 'debug') { if (command === 'debug') {
L.debug('Responding to legacy command debug'); L.debug('Responding to legacy command debug');
const generatedResponse = await generateResponse(message, { debug: true }); const generatedResponse = await generateResponse(message, { debug: true });
@@ -802,11 +997,23 @@ client.on('messageCreate', async (message) => {
if (command === null) { if (command === null) {
if (isHumanAuthoredMessage(message)) { if (isHumanAuthoredMessage(message)) {
if (client.user && message.mentions.has(client.user)) { if (client.user && message.mentions.has(client.user)) {
// Check if response channels are configured and if this channel is allowed
if (config.responseChannelIds.length > 0 && !config.responseChannelIds.includes(message.channel.id)) {
L.debug('Ignoring mention in non-response channel');
return;
}
L.debug('Responding to mention'); L.debug('Responding to mention');
// <@!278354154563567636> how are you doing? // <@!278354154563567636> how are you doing?
const startSeed = message.content.replace(/<@!\d+>/g, '').trim(); const startSeed = message.content.replace(/<@!\d+>/g, '').trim();
const generatedResponse = await generateResponse(message, { startSeed }); const generatedResponse = await generateResponse(message, { startSeed });
await handleResponseMessage(generatedResponse, message); await handleResponseMessage(generatedResponse, message);
} else if (await isAutoRespondChannel(message.channel)) {
// Auto-respond to all messages in configured channels using message content as context
L.debug('Auto-responding in configured channel with context');
const startSeed = message.content.trim();
const generatedResponse = await generateResponse(message, { startSeed });
await handleResponseMessage(generatedResponse, message);
} }
if (await isValidChannel(message.channel)) { if (await isValidChannel(message.channel)) {
@@ -848,7 +1055,7 @@ client.on('threadDelete', async (thread) => {
await markov.removeTags([thread.id]); await markov.removeTags([thread.id]);
}); });
// eslint-disable-next-line consistent-return
client.on('interactionCreate', async (interaction) => { client.on('interactionCreate', async (interaction) => {
if (interaction.isChatInputCommand()) { if (interaction.isChatInputCommand()) {
L.info({ command: interaction.commandName }, 'Recieved slash command'); L.info({ command: interaction.commandName }, 'Recieved slash command');
@@ -859,23 +1066,12 @@ client.on('interactionCreate', async (interaction) => {
await interaction.reply(inviteMessage()); await interaction.reply(inviteMessage());
} else if (interaction.commandName === messageCommand.name) { } else if (interaction.commandName === messageCommand.name) {
await interaction.deferReply(); await interaction.deferReply();
const tts = interaction.options.getBoolean('tts') || false;
const debug = interaction.options.getBoolean('debug') || false; const debug = interaction.options.getBoolean('debug') || false;
const startSeed = interaction.options.getString('seed')?.trim() || undefined; const startSeed = interaction.options.getString('seed')?.trim() || undefined;
const generatedResponse = await generateResponse(interaction, { tts, debug, startSeed }); const generatedResponse = await generateResponse(interaction, { debug, startSeed });
/**
* TTS doesn't work when using editReply, so instead we use delete + followUp
* However, delete + followUp is ugly and shows the bot replying to "Message could not be loaded.",
* so we avoid it if possible
*/
if (generatedResponse.message) { if (generatedResponse.message) {
if (generatedResponse.message.tts) {
await interaction.deleteReply();
await interaction.followUp(generatedResponse.message);
} else {
await interaction.editReply(generatedResponse.message); await interaction.editReply(generatedResponse.message);
}
} else { } else {
await interaction.deleteReply(); await interaction.deleteReply();
} }
@@ -943,6 +1139,67 @@ client.on('interactionCreate', async (interaction) => {
ephemeral: true, ephemeral: true,
}); });
} }
} else if (interaction.commandName === autoRespondCommand.name) {
await interaction.deferReply();
const subCommand = interaction.options.getSubcommand(true) as 'add' | 'remove' | 'list' | 'modify';
if (subCommand === 'list') {
const reply = await listAutoRespondChannels(interaction);
await interaction.editReply(reply);
} else if (subCommand === 'add') {
if (!isModerator(interaction.member)) {
return handleUnprivileged(interaction);
}
if (!interaction.guildId) {
return handleNoGuild(interaction);
}
const channels = getChannelsFromInteraction(interaction);
await addAutoRespondChannels(channels, interaction.guildId);
await interaction.editReply(
`Added ${channels.length} text channels to auto-respond list.`
);
} else if (subCommand === 'remove') {
if (!isModerator(interaction.member)) {
return handleUnprivileged(interaction);
}
if (!interaction.guildId) {
return handleNoGuild(interaction);
}
const channels = getChannelsFromInteraction(interaction);
await removeAutoRespondChannels(channels, interaction.guildId);
await interaction.editReply(
`Removed ${channels.length} text channels from auto-respond list.`
);
} else if (subCommand === 'modify') {
if (!interaction.guild) {
return handleNoGuild(interaction);
}
if (!isModerator(interaction.member)) {
await handleUnprivileged(interaction);
}
await interaction.deleteReply();
const dbTextChannels = await getTextChannels(interaction.guild);
const row = new Discord.ActionRowBuilder<Discord.StringSelectMenuBuilder>().addComponents(
new Discord.StringSelectMenuBuilder()
.setCustomId('autorespond-modify-select')
.setPlaceholder('Nothing selected')
.setMinValues(0)
.setMaxValues(dbTextChannels.length)
.addOptions(
dbTextChannels.map((c) => ({
label: `#${c.name}` || c.id,
value: c.id,
default: c.autoRespond || false,
})),
),
);
await interaction.followUp({
content: 'Select which channels you would like the bot to auto-respond in',
components: [row],
ephemeral: true,
});
}
} else if (interaction.commandName === trainCommand.name) { } else if (interaction.commandName === trainCommand.name) {
await interaction.deferReply(); await interaction.deferReply();
const clean = interaction.options.getBoolean('clean') ?? true; const clean = interaction.options.getBoolean('clean') ?? true;
@@ -990,6 +1247,37 @@ client.on('interactionCreate', async (interaction) => {
content: 'Updated actively listened to channels list.', content: 'Updated actively listened to channels list.',
ephemeral: true, ephemeral: true,
}); });
} else if (interaction.customId === 'autorespond-modify-select') {
await interaction.deferUpdate();
const { guild } = interaction;
if (!isModerator(interaction.member)) {
return handleUnprivileged(interaction, false);
}
if (!guild) {
return handleNoGuild(interaction, false);
}
const allChannels =
(interaction.component as Discord.StringSelectMenuComponent).options?.map((o) => o.value) ||
[];
const selectedChannelIds = interaction.values;
const textChannels = (
await Promise.all(
allChannels.map(async (c) => {
return guild.channels.fetch(c);
}),
)
).filter((c): c is Discord.TextChannel => c !== null && c instanceof Discord.TextChannel);
const unselectedChannels = textChannels.filter((t) => !selectedChannelIds.includes(t.id));
const selectedChannels = textChannels.filter((t) => selectedChannelIds.includes(t.id));
await addAutoRespondChannels(selectedChannels, guild.id);
await removeAutoRespondChannels(unselectedChannels, guild.id);
await interaction.followUp({
content: 'Updated auto-respond channels list.',
ephemeral: true,
});
} }
} }
}); });

391
src/train.ts Normal file
View File

@@ -0,0 +1,391 @@
import 'source-map-support/register';
import 'reflect-metadata';
import Markov, { MarkovConstructorOptions, AddDataProps } from 'markov-strings-db';
import { DataSource } from 'typeorm';
import { promises as fs } from 'fs';
import path from 'path';
import { config } from './config';
import ormconfig from './ormconfig';
import { Guild } from './entity/Guild';
import { Channel } from './entity/Channel';
import L from './logger';
import { MarkovDataCustom } from './types';
import { TrainingStateManager } from './training-state';
import { CONFIG_DIR } from './config/setup';
const markovOpts: MarkovConstructorOptions = {
stateSize: config.stateSize,
};
// Constants for batch processing
const BATCH_SIZE = 100; // Process messages in batches
const BATCH_DELAY = 100; // Milliseconds to wait between batches
const MAX_MEMORY_USAGE = 1024 * 1024 * 1024; // 1GB memory limit
// Monitor memory usage
const getMemoryUsage = () => {
const used = process.memoryUsage();
return used.heapUsed;
};
// Add delay between batches
const processingDelay = () => new Promise(resolve => setTimeout(resolve, BATCH_DELAY));
async function getMarkovByGuildId(guildId: string): Promise<Markov> {
const markov = new Markov({ id: guildId, options: { ...markovOpts, id: guildId } });
L.trace({ guildId }, 'Setting up markov instance');
await markov.setup(); // Connect the markov instance to the DB to assign it an ID
return markov;
}
interface JSONImport {
message: string;
attachments?: string[];
}
/**
* Train from a JSON file containing messages
*/
async function trainFromJson(
guildId: string,
jsonPath: string,
clean = true,
): Promise<string> {
const markov = await getMarkovByGuildId(guildId);
let trainingData: AddDataProps[];
try {
const fileContent = await fs.readFile(jsonPath, 'utf-8');
const importData = JSON.parse(fileContent) as JSONImport[];
// Filter out invalid entries first
const validData = importData.filter((datum, index) => {
if (!datum.message || typeof datum.message !== 'string') {
L.debug({ index }, 'Skipping entry without valid message');
return false;
}
if (datum.attachments?.some(a => typeof a !== 'string')) {
L.debug({ index }, 'Skipping entry with invalid attachments');
return false;
}
return true;
});
// Map valid entries to training data
trainingData = validData.map(datum => {
let custom: MarkovDataCustom | undefined;
if (datum.attachments?.length) {
custom = { attachments: datum.attachments };
}
return {
string: datum.message,
custom,
tags: [guildId]
};
});
} catch (err) {
L.error(err);
if (err instanceof SyntaxError) {
return 'The provided JSON file has invalid formatting. See the logs for details.';
}
return `Error reading file: ${err instanceof Error ? err.message : 'Unknown error'}`;
}
if (clean) {
L.debug('Deleting old data');
await markov.delete();
} else {
L.debug('Not deleting old data during training');
}
let processedCount = 0;
let batchCount = 0;
const totalMessages = trainingData.length;
// Process messages in batches
for (let i = 0; i < trainingData.length; i += BATCH_SIZE) {
try {
// Check memory usage
const memoryUsage = getMemoryUsage();
if (memoryUsage > MAX_MEMORY_USAGE) {
L.warn('Memory usage too high, waiting for garbage collection');
await processingDelay();
global.gc?.(); // Optional garbage collection if --expose-gc flag is used
}
const batch = trainingData.slice(i, i + BATCH_SIZE);
await markov.addData(batch);
processedCount += batch.length;
batchCount++;
// Log progress
if (batchCount % 5 === 0) {
const progress = (processedCount / totalMessages * 100).toFixed(2);
L.info(`Progress: ${progress}% (${processedCount}/${totalMessages} messages)`);
await processingDelay(); // Add delay every 5 batches
}
} catch (err) {
L.error({ err, batchIndex: i }, 'Error processing batch');
// Continue with next batch instead of failing completely
await processingDelay(); // Wait a bit longer after an error
continue;
}
}
L.info(`Successfully trained from ${processedCount} messages.`);
return `Successfully trained from ${processedCount} messages.`;
}
/**
* Train from all JSON files in a directory
*/
/**
* Train from all JSON files in a directory
* @param guildId The Discord guild ID
* @param dirPath Path to directory containing JSON files
* @param clean Whether to clean existing data before training
*/
/**
* Acquire a lock file for training to prevent concurrent processes
*/
async function acquireTrainingLock(guildId: string): Promise<boolean> {
const lockPath = path.join(CONFIG_DIR, `${guildId}_training.lock`);
try {
await fs.writeFile(lockPath, process.pid.toString(), { flag: 'wx' });
return true;
} catch (err) {
if ((err as NodeJS.ErrnoException).code === 'EEXIST') {
try {
const pid = parseInt(await fs.readFile(lockPath, 'utf-8'));
try {
// Check if process is still running
process.kill(pid, 0);
return false; // Process is still running
} catch {
// Process is not running, safe to remove lock
await fs.unlink(lockPath);
await fs.writeFile(lockPath, process.pid.toString());
return true;
}
} catch {
// Error reading/writing lock file
return false;
}
}
return false;
}
}
/**
* Release the training lock file
*/
async function releaseTrainingLock(guildId: string): Promise<void> {
const lockPath = path.join(CONFIG_DIR, `${guildId}_training.lock`);
try {
await fs.unlink(lockPath);
} catch {
// Ignore errors during cleanup
}
}
/**
* Sanitize and validate a directory path
*/
async function validateDirectoryPath(dirPath: string): Promise<string> {
// Resolve to absolute path
const absolutePath = path.resolve(dirPath);
// Prevent directory traversal
const normalizedPath = path.normalize(absolutePath);
if (!normalizedPath.startsWith(process.cwd())) {
throw new Error('Directory must be within current working directory');
}
// Verify directory exists and is accessible
try {
const stats = await fs.stat(normalizedPath);
if (!stats.isDirectory()) {
throw new Error('Path is not a directory');
}
await fs.access(normalizedPath, fs.constants.R_OK);
return normalizedPath;
} catch (err) {
throw new Error(`Invalid directory path: ${err instanceof Error ? err.message : 'Unknown error'}`);
}
}
/**
* Train from all JSON files in a directory
*/
async function trainFromDirectory(
guildId: string,
dirPath: string,
clean = true,
): Promise<string> {
L.debug({ guildId, dirPath, clean }, 'Starting directory training');
const stateManager = new TrainingStateManager(guildId, CONFIG_DIR);
// Set up cleanup handler
const cleanup = async () => {
try {
await releaseTrainingLock(guildId);
stateManager.finishTraining();
} catch (err) {
L.error({ err }, 'Error during cleanup');
}
};
// Handle process termination
process.once('SIGINT', cleanup);
process.once('SIGTERM', cleanup);
try {
// Try to acquire lock
if (!await acquireTrainingLock(guildId)) {
return 'Another training process is already running. Please wait for it to complete.';
}
// Always reset state at the start of training
stateManager.reset();
try {
// Validate and normalize directory path
const absolutePath = await validateDirectoryPath(dirPath);
// Get all JSON files in the directory
L.trace({ dirPath: absolutePath }, 'Reading directory');
const files = await fs.readdir(absolutePath);
const jsonFiles = files.filter(file => file.toLowerCase().endsWith('.json'));
if (jsonFiles.length === 0) {
L.warn({ dirPath: absolutePath }, 'No JSON files found in directory');
return 'No JSON files found in the specified directory.';
}
let totalProcessed = 0;
let batchCount = 0;
L.info({ fileCount: jsonFiles.length }, 'Found JSON files to process');
stateManager.startTraining();
// Process first file with clean flag, subsequent files without cleaning
for (let i = 0; i < jsonFiles.length; i++) {
const jsonPath = path.join(absolutePath, jsonFiles[i]);
const fileNumber = i + 1;
L.debug(
{ file: jsonFiles[i], progress: `${fileNumber}/${jsonFiles.length}` },
'Processing file'
);
try {
// Check memory usage before processing file
const memoryUsage = getMemoryUsage();
if (memoryUsage > MAX_MEMORY_USAGE) {
L.warn('Memory usage too high, waiting for garbage collection');
await processingDelay();
global.gc?.(); // Optional garbage collection if --expose-gc flag is used
}
// Check if we should skip this file (already processed)
if (!clean && stateManager.isChannelProcessed(jsonFiles[i])) {
L.debug({ file: jsonFiles[i] }, 'Skipping already processed file');
continue;
}
const result = await trainFromJson(
guildId,
jsonPath,
i === 0 ? clean : false // Only clean on first file
);
// Extract number of processed messages from result string
const processed = parseInt(result.match(/\d+/)?.[0] || '0');
totalProcessed += processed;
batchCount++;
// Update state after each file
stateManager.updateProgress('json-import', jsonFiles[i], totalProcessed);
L.trace(
{ file: jsonFiles[i], processed, totalProcessed },
'File processing complete'
);
// Add delay between files
if (batchCount % 5 === 0) {
await processingDelay();
}
// Clear any references that might be held
if (global.gc) {
global.gc();
}
} catch (err) {
const error = err as Error;
L.error(
{ error: error.message, file: jsonFiles[i], stack: error.stack },
'Error processing JSON file'
);
stateManager.recordError(error, 'json-import', jsonFiles[i]);
// Add longer delay after error
await processingDelay();
// Continue with next file instead of failing completely
continue;
}
}
const summary = { totalProcessed, fileCount: jsonFiles.length };
L.info(summary, 'Directory training complete');
return `Successfully trained from ${totalProcessed} messages across ${jsonFiles.length} files.`;
} finally {
// Clean up regardless of success/failure
await cleanup();
// Remove process termination handlers
process.off('SIGINT', cleanup);
process.off('SIGTERM', cleanup);
}
} catch (err) {
const error = err as Error;
L.error(
{ error: error.message, stack: error.stack, dirPath },
'Error during directory training'
);
stateManager.recordError(error);
return `Training encountered an error: ${error.message}. Use clean=false to resume from last checkpoint.`;
}
}
async function main(): Promise<void> {
const args = process.argv.slice(2);
if (args.length < 2) {
console.log('Usage: node train.js <guildId> <path> [--keep-existing] [--directory]');
console.log('Options:');
console.log(' --keep-existing Keep existing training data');
console.log(' --directory Process all JSON files in the specified directory');
process.exit(1);
}
const guildId = args[0];
const inputPath = args[1];
const keepExisting = args.includes('--keep-existing');
const isDirectory = args.includes('--directory');
const dataSourceOptions = Markov.extendDataSourceOptions(ormconfig);
const dataSource = new DataSource(dataSourceOptions);
await dataSource.initialize();
// Ensure guild exists in DB
await Guild.upsert(Guild.create({ id: guildId }), ['id']);
const result = isDirectory
? await trainFromDirectory(guildId, inputPath, !keepExisting)
: await trainFromJson(guildId, inputPath, !keepExisting);
console.log(result);
await dataSource.destroy();
}
if (require.main === module) {
main().catch(console.error);
}

113
src/training-state.ts Normal file
View File

@@ -0,0 +1,113 @@
import fs from 'fs-extra';
import path from 'path';
import { TrainingState } from './types';
import L from './logger';
export class TrainingStateManager {
private stateFile: string;
private state: TrainingState;
constructor(guildId: string, configDir: string = 'config') {
this.stateFile = path.join(configDir, 'training-state', `${guildId}.json`);
// Initialize with default state
this.state = {
guildId,
processedChannels: [],
totalMessages: 0,
lastUpdate: new Date().toISOString(),
inProgress: false
};
// Ensure directory exists
fs.ensureDirSync(path.dirname(this.stateFile));
// Load existing state if available
this.loadState();
}
private loadState(): void {
try {
if (fs.existsSync(this.stateFile)) {
const savedState = fs.readJsonSync(this.stateFile);
this.state = { ...this.state, ...savedState };
L.info({ guildId: this.state.guildId }, 'Loaded existing training state');
}
} catch (err) {
L.error({ err }, 'Error loading training state');
// Keep using default state if load fails
}
}
private saveState(): void {
try {
fs.writeJsonSync(this.stateFile, this.state, { spaces: 2 });
} catch (err) {
L.error({ err }, 'Error saving training state');
}
}
public startTraining(): void {
this.state.inProgress = true;
this.state.error = undefined;
this.state.lastUpdate = new Date().toISOString();
this.saveState();
}
public finishTraining(): void {
this.state.inProgress = false;
this.state.lastUpdate = new Date().toISOString();
this.saveState();
}
public updateProgress(channelId: string, messageId: string, messagesProcessed: number): void {
this.state.lastChannelId = channelId;
this.state.lastMessageId = messageId;
this.state.totalMessages = messagesProcessed;
this.state.lastUpdate = new Date().toISOString();
this.saveState();
}
public markChannelComplete(channelId: string): void {
if (!this.state.processedChannels.includes(channelId)) {
this.state.processedChannels.push(channelId);
this.saveState();
}
}
public recordError(error: Error, channelId?: string, messageId?: string): void {
this.state.error = {
message: error.message,
channelId,
messageId,
timestamp: new Date().toISOString()
};
this.saveState();
}
public isChannelProcessed(channelId: string): boolean {
return this.state.processedChannels.includes(channelId);
}
public shouldResumeFromMessage(channelId: string): string | undefined {
if (this.state.inProgress && this.state.lastChannelId === channelId) {
return this.state.lastMessageId;
}
return undefined;
}
public getState(): TrainingState {
return { ...this.state };
}
public reset(): void {
this.state = {
guildId: this.state.guildId,
processedChannels: [],
totalMessages: 0,
lastUpdate: new Date().toISOString(),
inProgress: false
};
this.saveState();
}
}

19
src/types.ts Normal file
View File

@@ -0,0 +1,19 @@
export interface MarkovDataCustom {
attachments: string[];
}
export interface TrainingState {
guildId: string;
lastMessageId?: string;
lastChannelId?: string;
processedChannels: string[];
totalMessages: number;
lastUpdate: string;
inProgress: boolean;
error?: {
message: string;
channelId?: string;
messageId?: string;
timestamp: string;
};
}

View File

@@ -3,7 +3,7 @@
"compilerOptions": { "compilerOptions": {
"target": "es2021", /* Specify ECMAScript target version: 'ES3' (default), 'ES5', 'ES2015', 'ES2016', 'ES2017', 'ES2018', 'ES2019' or 'ESNEXT'. */ "target": "es2021", /* Specify ECMAScript target version: 'ES3' (default), 'ES5', 'ES2015', 'ES2016', 'ES2017', 'ES2018', 'ES2019' or 'ESNEXT'. */
"module": "commonjs", /* Specify module code generation: 'none', 'commonjs', 'amd', 'system', 'umd', 'es2015', or 'ESNext'. */ "module": "commonjs", /* Specify module code generation: 'none', 'commonjs', 'amd', 'system', 'umd', 'es2015', or 'ESNext'. */
"outDir": "./dist", /* Redirect output structure to the directory. */ "outDir": "./build", /* Redirect output structure to the directory. */
"removeComments": true, /* Do not emit comments to output. */ "removeComments": true, /* Do not emit comments to output. */
"esModuleInterop": true, "esModuleInterop": true,
"strict": true, /* Enable all strict type-checking options. */ "strict": true, /* Enable all strict type-checking options. */