mirror of
https://github.com/pacnpal/markov-discord.git
synced 2025-12-19 18:51:05 -05:00
Added several things: Parser to import JSON from DiscordChatExporter, ability to train without bot running and more.
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -68,3 +68,6 @@ error.json
|
||||
markovDB.json
|
||||
|
||||
/config/
|
||||
/exports/
|
||||
/build/
|
||||
/dist/
|
||||
60
README.md
60
README.md
@@ -14,11 +14,67 @@ A Markov chain bot using markov-strings.
|
||||
* User: `/mark`
|
||||
* Bot: 
|
||||
|
||||
### Training from a file
|
||||
### Training from files
|
||||
|
||||
Using the `json` option in the `/train` command, you can import a list of messages.
|
||||
You can train the bot using JSON files in two ways:
|
||||
|
||||
1. Using the `json` option in the `/train` command to import a single file of messages.
|
||||
2. Using the command line to train from either a single file or an entire directory of JSON files.
|
||||
|
||||
#### Using the Discord Command
|
||||
Use the `json` option in the `/train` command to import a single file of messages.
|
||||
An example JSON file can be seen [here](img/example-training.json).
|
||||
|
||||
#### Using the Command Line
|
||||
For bulk training from multiple files, you can use the command line interface. First, build the training script:
|
||||
|
||||
```bash
|
||||
# Build the TypeScript files
|
||||
npm run build
|
||||
```
|
||||
|
||||
Then you can use the training script:
|
||||
|
||||
```bash
|
||||
# Train from a single JSON file
|
||||
node build/train.js <guildId> <jsonPath> [--keep-existing]
|
||||
|
||||
# Train from all JSON files in a directory
|
||||
node build/train.js <guildId> <directoryPath> --directory [--keep-existing] [--expose-gc]
|
||||
```
|
||||
|
||||
Options:
|
||||
- `--keep-existing`: Don't clear existing training data before importing
|
||||
- `--directory`: Process all JSON files in the specified directory
|
||||
- `--expose-gc`: Enable garbage collection for better memory management (recommended for large directories)
|
||||
|
||||
Each JSON file should contain an array of messages in this format:
|
||||
```json
|
||||
[
|
||||
{
|
||||
"message": "Message content",
|
||||
"attachments": ["optional", "attachment", "urls"]
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
When training from a directory:
|
||||
- All .json files in the directory will be processed
|
||||
- Files are processed sequentially to manage memory usage
|
||||
- Progress is shown for each file
|
||||
- A total count of processed messages is provided at the end
|
||||
|
||||
Security and Performance Notes:
|
||||
- The directory must be within the project's working directory for security
|
||||
- The process will create lock files in the config directory to prevent concurrent training
|
||||
- Memory usage is monitored and managed automatically
|
||||
- For large directories, use the `--expose-gc` flag for better memory management:
|
||||
```bash
|
||||
node --expose-gc build/train.js <guildId> <directoryPath> --directory
|
||||
```
|
||||
- Training can be safely interrupted with Ctrl+C; state will be preserved
|
||||
- Use `--keep-existing` to resume interrupted training
|
||||
|
||||
## Setup
|
||||
|
||||
This bot stores your Discord server's entire message history, so a public instance to invite to your server is not available due to obvious data privacy concerns. Instead, you can host it yourself.
|
||||
|
||||
21
example-training.json
Normal file
21
example-training.json
Normal file
@@ -0,0 +1,21 @@
|
||||
[
|
||||
{
|
||||
"message": "Hello world!",
|
||||
"attachments": []
|
||||
},
|
||||
{
|
||||
"message": "This is an example message with an attachment",
|
||||
"attachments": ["https://example.com/image.jpg"]
|
||||
},
|
||||
{
|
||||
"message": "Another training message",
|
||||
"attachments": []
|
||||
},
|
||||
{
|
||||
"message": "Messages can have multiple attachments",
|
||||
"attachments": [
|
||||
"https://example.com/image1.jpg",
|
||||
"https://example.com/image2.jpg"
|
||||
]
|
||||
}
|
||||
]
|
||||
124
imports/discord-parser.py
Normal file
124
imports/discord-parser.py
Normal file
@@ -0,0 +1,124 @@
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from glob import glob
|
||||
import argparse
|
||||
|
||||
def strip_mentions(content):
|
||||
# Pattern for Discord mentions: <@!?[0-9]+> or <@&[0-9]+> or <#[0-9]+>
|
||||
mention_pattern = r'<@!?\d+>|<@&\d+>|<#\d+>'
|
||||
return re.sub(mention_pattern, '', content).strip()
|
||||
|
||||
def extract_channel_name(filename):
|
||||
# Extract channel name and part number from Discord export filename format
|
||||
# Pattern: channel name before [numbers] and optional [part X]
|
||||
channel_match = re.search(r'- ([^[\]]+) \[\d+\]', filename)
|
||||
part_match = re.search(r'\[part (\d+)\]', filename, re.IGNORECASE)
|
||||
|
||||
if channel_match:
|
||||
channel_name = channel_match.group(1).strip().lower()
|
||||
# Remove any category names (text after last hyphen)
|
||||
channel_name = channel_name.split(' - ')[-1].strip()
|
||||
|
||||
# Add part number if it exists
|
||||
if part_match:
|
||||
part_num = part_match.group(1)
|
||||
return f"{channel_name}_part{part_num}.json"
|
||||
return f"{channel_name}.json"
|
||||
|
||||
return "unknown_channel.json"
|
||||
|
||||
def parse_discord_export(input_file, output_dir):
|
||||
# Generate output filename based on channel name
|
||||
output_filename = extract_channel_name(input_file)
|
||||
output_file = os.path.join(output_dir, output_filename)
|
||||
|
||||
try:
|
||||
# Read the input file
|
||||
with open(input_file, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
|
||||
# Extract just the messages and their attachments
|
||||
output_messages = []
|
||||
|
||||
for msg in data['messages']:
|
||||
# Skip empty messages after stripping mentions
|
||||
stripped_content = strip_mentions(msg['content'])
|
||||
if not stripped_content and not msg['attachments']:
|
||||
continue
|
||||
|
||||
message_obj = {
|
||||
'message': stripped_content
|
||||
}
|
||||
|
||||
# If there are attachments, add them to the message object
|
||||
if msg['attachments']:
|
||||
message_obj['attachments'] = [
|
||||
attachment['url'] for attachment in msg['attachments']
|
||||
]
|
||||
|
||||
# Only add messages that have content or attachments
|
||||
if message_obj['message'] or 'attachments' in message_obj:
|
||||
output_messages.append(message_obj)
|
||||
|
||||
# Create output directory if it doesn't exist
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
# Write the output file
|
||||
with open(output_file, 'w', encoding='utf-8') as f:
|
||||
json.dump(output_messages, f, indent=4, ensure_ascii=False)
|
||||
|
||||
print(f"Successfully parsed: {input_file} -> {output_file}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing {input_file}: {str(e)}")
|
||||
return False
|
||||
|
||||
def process_directory(input_dir, output_dir):
|
||||
# Find all Discord export JSON files in the input directory
|
||||
pattern = os.path.join(input_dir, "*Discord*.json")
|
||||
files = glob(pattern)
|
||||
|
||||
if not files:
|
||||
print(f"No Discord export files found in: {input_dir}")
|
||||
return
|
||||
|
||||
success_count = 0
|
||||
failure_count = 0
|
||||
|
||||
for file in files:
|
||||
if parse_discord_export(file, output_dir):
|
||||
success_count += 1
|
||||
else:
|
||||
failure_count += 1
|
||||
|
||||
print(f"\nProcessing complete:")
|
||||
print(f"Successfully processed: {success_count} files")
|
||||
print(f"Failed to process: {failure_count} files")
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='Process Discord export JSON files.')
|
||||
parser.add_argument('-i', '--input', default='.',
|
||||
help='Input directory containing Discord export files (default: current directory)')
|
||||
parser.add_argument('-o', '--output', default='output',
|
||||
help='Output directory for processed files (default: output)')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Convert to absolute paths
|
||||
input_dir = os.path.abspath(args.input)
|
||||
output_dir = os.path.abspath(args.output)
|
||||
|
||||
# Check if input directory exists
|
||||
if not os.path.exists(input_dir):
|
||||
print(f"Error: Input directory does not exist: {input_dir}")
|
||||
return
|
||||
|
||||
print(f"Processing files from: {input_dir}")
|
||||
print(f"Saving output to: {output_dir}")
|
||||
|
||||
process_directory(input_dir, output_dir)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
9346
package-lock.json
generated
9346
package-lock.json
generated
File diff suppressed because it is too large
Load Diff
13
package.json
13
package.json
@@ -4,14 +4,14 @@
|
||||
"description": "A conversational Markov chain bot for Discord",
|
||||
"main": "dist/index.js",
|
||||
"scripts": {
|
||||
"start": "NODE_ENV=production pm2 start --no-daemon dist/index.js",
|
||||
"start": "./node_modules/.bin/pm2 start --no-daemon dist/index.js",
|
||||
"start:ts": "ts-node src/index.ts",
|
||||
"build": "rimraf dist && tsc",
|
||||
"lint": "tsc --noEmit && eslint .",
|
||||
"build": "rimraf build && /opt/homebrew/bin/node ./node_modules/.bin/tsc",
|
||||
"lint": "./node-v20.18.1-darwin-x64/bin/node ./node_modules/.bin/tsc --noEmit && ./node-v20.18.1-darwin-x64/bin/node ./node_modules/.bin/eslint .",
|
||||
"docker:build": "docker build . -t charlocharlie/markov-discord:latest --target deploy",
|
||||
"docker:run": "docker run --rm -ti -v $(pwd)/config:/usr/app/config charlocharlie/markov-discord:latest",
|
||||
"typeorm": "node --require ts-node/register ./node_modules/typeorm/cli.js",
|
||||
"docs": "typedoc --out docs src/config/classes.ts"
|
||||
"typeorm": "./node-v20.18.1-darwin-x64/bin/node --require ts-node/register ./node_modules/typeorm/cli.js",
|
||||
"docs": "./node-v20.18.1-darwin-x64/bin/node ./node_modules/.bin/typedoc --out docs src/config/classes.ts"
|
||||
},
|
||||
"repository": "https://github.com/claabs/markov-discord.git",
|
||||
"keywords": [
|
||||
@@ -31,7 +31,7 @@
|
||||
},
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"better-sqlite3": "^8.7.0",
|
||||
"better-sqlite3": "^9.6.0",
|
||||
"bufferutil": "^4.0.8",
|
||||
"class-transformer": "^0.5.1",
|
||||
"class-validator": "^0.14.1",
|
||||
@@ -42,6 +42,7 @@
|
||||
"json5": "^2.2.3",
|
||||
"markov-strings-db": "^4.2.0",
|
||||
"node-fetch": "^2.6.7",
|
||||
"node-gyp": "^11.0.0",
|
||||
"pino": "^7.11.0",
|
||||
"pino-pretty": "^7.6.1",
|
||||
"reflect-metadata": "^0.2.2",
|
||||
|
||||
@@ -163,4 +163,19 @@ export class AppConfig {
|
||||
@IsOptional()
|
||||
@IsString()
|
||||
devGuildId = process.env.DEV_GUILD_ID;
|
||||
|
||||
/**
|
||||
* A list of channel IDs where the bot will respond to mentions.
|
||||
* If empty, the bot will respond to mentions in any channel.
|
||||
* @example ["734548250895319070"]
|
||||
* @default []
|
||||
* @env RESPONSE_CHANNEL_IDS (comma separated)
|
||||
*/
|
||||
@IsArray()
|
||||
@IsString({ each: true })
|
||||
@Type(() => String)
|
||||
@IsOptional()
|
||||
responseChannelIds = process.env.RESPONSE_CHANNEL_IDS
|
||||
? process.env.RESPONSE_CHANNEL_IDS.split(',').map((id) => id.trim())
|
||||
: [];
|
||||
}
|
||||
|
||||
@@ -21,9 +21,6 @@ export const inviteCommand = new SlashCommandBuilder()
|
||||
export const messageCommand = new SlashCommandBuilder()
|
||||
.setName(config.slashCommandName)
|
||||
.setDescription('Generate a message from learned past messages')
|
||||
.addBooleanOption((tts) =>
|
||||
tts.setName('tts').setDescription('Read the message via text-to-speech.').setRequired(false),
|
||||
)
|
||||
.addBooleanOption((debug) =>
|
||||
debug
|
||||
.setName('debug')
|
||||
@@ -49,6 +46,38 @@ const channelOptionsGenerator = (builder: SlashCommandChannelOption, index: numb
|
||||
.setRequired(index === 0)
|
||||
.addChannelTypes(ChannelType.GuildText);
|
||||
|
||||
export const autoRespondCommand = new SlashCommandBuilder()
|
||||
.setName('autorespond')
|
||||
.setDescription('Configure channels where the bot will automatically respond to all messages')
|
||||
.addSubcommand((sub) => {
|
||||
sub
|
||||
.setName('add')
|
||||
.setDescription('Add channels where the bot will automatically respond to all messages');
|
||||
Array.from(Array(CHANNEL_OPTIONS_MAX).keys()).forEach((index) =>
|
||||
sub.addChannelOption((opt) => channelOptionsGenerator(opt, index)),
|
||||
);
|
||||
return sub;
|
||||
})
|
||||
.addSubcommand((sub) => {
|
||||
sub
|
||||
.setName('remove')
|
||||
.setDescription('Remove channels from auto-response');
|
||||
Array.from(Array(CHANNEL_OPTIONS_MAX).keys()).forEach((index) =>
|
||||
sub.addChannelOption((opt) => channelOptionsGenerator(opt, index)),
|
||||
);
|
||||
return sub;
|
||||
})
|
||||
.addSubcommand((sub) =>
|
||||
sub
|
||||
.setName('list')
|
||||
.setDescription('List the channels where the bot auto-responds to messages'),
|
||||
)
|
||||
.addSubcommand((sub) =>
|
||||
sub
|
||||
.setName('modify')
|
||||
.setDescription('Add or remove auto-respond channels via select menu UI (first 25 text channels only)'),
|
||||
);
|
||||
|
||||
export const listenChannelCommand = new SlashCommandBuilder()
|
||||
.setName('listen')
|
||||
.setDescription('Change what channels the bot actively listens to and learns from.')
|
||||
@@ -110,7 +139,8 @@ const commands = [
|
||||
inviteCommand.toJSON(),
|
||||
messageCommand.toJSON(),
|
||||
listenChannelCommand.toJSON(),
|
||||
trainCommand.toJSON(),
|
||||
autoRespondCommand.toJSON(),
|
||||
trainCommand.toJSON()
|
||||
];
|
||||
|
||||
export async function deployCommands(clientId: string) {
|
||||
|
||||
@@ -12,6 +12,11 @@ export class Channel extends BaseEntity {
|
||||
})
|
||||
listen: boolean;
|
||||
|
||||
@Column({
|
||||
default: false,
|
||||
})
|
||||
autoRespond: boolean;
|
||||
|
||||
@ManyToOne(() => Guild, (guild) => guild.channels)
|
||||
guild: Guild;
|
||||
}
|
||||
|
||||
424
src/index.ts
424
src/index.ts
@@ -1,6 +1,8 @@
|
||||
import 'source-map-support/register';
|
||||
import { CONFIG_DIR } from './config/setup';
|
||||
import 'reflect-metadata';
|
||||
import * as Discord from 'discord.js';
|
||||
|
||||
import Markov, {
|
||||
MarkovGenerateOptions,
|
||||
MarkovConstructorOptions,
|
||||
@@ -24,6 +26,7 @@ import {
|
||||
listenChannelCommand,
|
||||
messageCommand,
|
||||
trainCommand,
|
||||
autoRespondCommand,
|
||||
} from './deploy-commands';
|
||||
import { getRandomElement, getVersion, packageJson } from './util';
|
||||
import ormconfig from './ormconfig';
|
||||
@@ -35,6 +38,7 @@ interface MarkovDataCustom {
|
||||
interface SelectMenuChannel {
|
||||
id: string;
|
||||
listen?: boolean;
|
||||
autoRespond?: boolean;
|
||||
name?: string;
|
||||
}
|
||||
|
||||
@@ -53,11 +57,19 @@ type AgnosticReplyOptions = Omit<Discord.MessageCreateOptions, 'reply' | 'sticke
|
||||
const INVALID_PERMISSIONS_MESSAGE = 'You do not have the permissions for this action.';
|
||||
const INVALID_GUILD_MESSAGE = 'This action must be performed within a server.';
|
||||
|
||||
const rest = new Discord.REST({ version: '10' }).setToken(config.token);
|
||||
const rest = new Discord.REST({
|
||||
version: '10',
|
||||
timeout: 120000, // 120 seconds
|
||||
retries: 3
|
||||
}).setToken(config.token);
|
||||
|
||||
const client = new Discord.Client<true>({
|
||||
failIfNotExists: false,
|
||||
intents: [Discord.GatewayIntentBits.GuildMessages, Discord.GatewayIntentBits.Guilds],
|
||||
intents: [
|
||||
Discord.GatewayIntentBits.GuildMessages,
|
||||
Discord.GatewayIntentBits.Guilds,
|
||||
Discord.GatewayIntentBits.GuildMembers
|
||||
],
|
||||
presence: {
|
||||
activities: [
|
||||
{
|
||||
@@ -114,6 +126,53 @@ async function isValidChannel(channel: Discord.TextBasedChannel): Promise<boolea
|
||||
return dbChannel?.listen || false;
|
||||
}
|
||||
|
||||
async function isAutoRespondChannel(channel: Discord.TextBasedChannel): Promise<boolean> {
|
||||
const channelId = getGuildChannelId(channel);
|
||||
if (!channelId) return false;
|
||||
const dbChannel = await Channel.findOneBy({ id: channelId });
|
||||
return dbChannel?.autoRespond || false;
|
||||
}
|
||||
|
||||
async function getAutoRespondChannels(guild: Discord.Guild): Promise<Discord.TextChannel[]> {
|
||||
const dbChannels = await Channel.findBy({ guild: { id: guild.id }, autoRespond: true });
|
||||
const channels = (
|
||||
await Promise.all(
|
||||
dbChannels.map(async (dbc) => {
|
||||
try {
|
||||
return guild.channels.fetch(dbc.id);
|
||||
} catch (err) {
|
||||
L.error({ erroredChannel: dbc, channelId: dbc.id }, 'Error fetching channel');
|
||||
throw err;
|
||||
}
|
||||
}),
|
||||
)
|
||||
).filter((c): c is Discord.TextChannel => c !== null && c instanceof Discord.TextChannel);
|
||||
return channels;
|
||||
}
|
||||
|
||||
async function addAutoRespondChannels(channels: Discord.TextChannel[], guildId: string): Promise<void> {
|
||||
const dbChannels = channels.map((c) => {
|
||||
return Channel.create({ id: c.id, guild: Guild.create({ id: guildId }), autoRespond: true });
|
||||
});
|
||||
await Channel.save(dbChannels);
|
||||
}
|
||||
|
||||
async function removeAutoRespondChannels(channels: Discord.TextChannel[], guildId: string): Promise<void> {
|
||||
const dbChannels = channels.map((c) => {
|
||||
return Channel.create({ id: c.id, guild: Guild.create({ id: guildId }), autoRespond: false });
|
||||
});
|
||||
await Channel.save(dbChannels);
|
||||
}
|
||||
|
||||
async function listAutoRespondChannels(interaction: Discord.CommandInteraction): Promise<string> {
|
||||
if (!interaction.guildId || !interaction.guild) return INVALID_GUILD_MESSAGE;
|
||||
const channels = await getAutoRespondChannels(interaction.guild);
|
||||
const channelText = channels.reduce((list, channel) => {
|
||||
return `${list}\n • <#${channel.id}>`;
|
||||
}, '');
|
||||
return `The bot will automatically respond to all messages in ${channels.length} channel(s).${channelText}`;
|
||||
}
|
||||
|
||||
function isHumanAuthoredMessage(message: Discord.Message | Discord.PartialMessage): boolean {
|
||||
return !(message.author?.bot || message.system);
|
||||
}
|
||||
@@ -151,7 +210,12 @@ async function getTextChannels(guild: Discord.Guild): Promise<SelectMenuChannel[
|
||||
}));
|
||||
const notFoundDbChannels: SelectMenuChannel[] = textChannels
|
||||
.filter((c) => !foundDbChannels.find((d) => d.id === c.id))
|
||||
.map((c) => ({ id: c.id, listen: false, name: textChannels.find((t) => t.id === c.id)?.name }));
|
||||
.map((c) => ({
|
||||
id: c.id,
|
||||
listen: false,
|
||||
autoRespond: false,
|
||||
name: textChannels.find((t) => t.id === c.id)?.name
|
||||
}));
|
||||
const limitedDbChannels = foundDbChannelsWithName
|
||||
.concat(notFoundDbChannels)
|
||||
.slice(0, MAX_SELECT_OPTIONS);
|
||||
@@ -223,7 +287,7 @@ function isAllowedUser(
|
||||
return true;
|
||||
}
|
||||
|
||||
type MessageCommands = 'respond' | 'train' | 'help' | 'invite' | 'debug' | 'tts' | null;
|
||||
type MessageCommands = 'respond' | 'train' | 'help' | 'invite' | 'debug' | null;
|
||||
|
||||
/**
|
||||
* Reads a new message and checks if and which command it is.
|
||||
@@ -246,8 +310,6 @@ function validateMessage(message: Discord.Message): MessageCommands {
|
||||
command = 'invite';
|
||||
} else if (split[1] === 'debug') {
|
||||
command = 'debug';
|
||||
} else if (split[1] === 'tts') {
|
||||
command = 'tts';
|
||||
}
|
||||
}
|
||||
return command;
|
||||
@@ -272,12 +334,23 @@ function messageToData(message: Discord.Message): AddDataProps {
|
||||
/**
|
||||
* Recursively gets all messages in a text channel's history.
|
||||
*/
|
||||
import { TrainingStateManager } from './training-state';
|
||||
|
||||
async function saveGuildMessageHistory(
|
||||
interaction: Discord.Message | Discord.CommandInteraction,
|
||||
clean = true,
|
||||
): Promise<string> {
|
||||
if (!isModerator(interaction.member)) return INVALID_PERMISSIONS_MESSAGE;
|
||||
if (!interaction.guildId || !interaction.guild) return INVALID_GUILD_MESSAGE;
|
||||
|
||||
const stateManager = new TrainingStateManager(interaction.guildId, CONFIG_DIR);
|
||||
|
||||
// Check if training is already in progress
|
||||
const currentState = stateManager.getState();
|
||||
if (currentState.inProgress) {
|
||||
return `Training is already in progress. Last update: ${currentState.lastUpdate}. Use /train with clean=true to restart.`;
|
||||
}
|
||||
|
||||
const markov = await getMarkovByGuildId(interaction.guildId);
|
||||
const channels = await getValidChannels(interaction.guild);
|
||||
|
||||
@@ -287,12 +360,23 @@ async function saveGuildMessageHistory(
|
||||
}
|
||||
|
||||
if (clean) {
|
||||
L.debug('Deleting old data');
|
||||
L.debug('Deleting old data and resetting state');
|
||||
await markov.delete();
|
||||
stateManager.reset();
|
||||
} else {
|
||||
L.debug('Not deleting old data during training');
|
||||
// Filter out already processed channels when not cleaning
|
||||
const unprocessedChannels = channels.filter(
|
||||
channel => !stateManager.isChannelProcessed(channel.id)
|
||||
);
|
||||
if (unprocessedChannels.length === 0) {
|
||||
return 'All channels have been processed. Use clean=true to retrain.';
|
||||
}
|
||||
channels.splice(0, channels.length, ...unprocessedChannels);
|
||||
}
|
||||
|
||||
stateManager.startTraining();
|
||||
|
||||
const channelIds = channels.map((c) => c.id);
|
||||
L.debug({ channelIds }, `Training from text channels`);
|
||||
|
||||
@@ -332,20 +416,42 @@ async function saveGuildMessageHistory(
|
||||
progressMessage = (await interaction.followUp(updateMessageData)) as Discord.Message;
|
||||
}
|
||||
|
||||
const PAGE_SIZE = 100;
|
||||
const UPDATE_RATE = 1000; // In number of messages processed
|
||||
const PAGE_SIZE = 50; // Reduced page size for better stability
|
||||
const UPDATE_RATE = 500; // More frequent updates
|
||||
const BATCH_SIZE = 100; // Number of messages to process before a small delay
|
||||
const BATCH_DELAY = 100; // Milliseconds to wait between batches
|
||||
const MAX_MEMORY_USAGE = 1024 * 1024 * 1024; // 1GB memory limit
|
||||
|
||||
let lastUpdate = 0;
|
||||
let messagesCount = 0;
|
||||
let firstMessageDate: number | undefined;
|
||||
// eslint-disable-next-line no-restricted-syntax
|
||||
for (const channel of channels) {
|
||||
let oldestMessageID: string | undefined;
|
||||
let keepGoing = true;
|
||||
L.debug({ channelId: channel.id, messagesCount }, `Training from channel`);
|
||||
const channelCreateDate = channel.createdTimestamp;
|
||||
const channelEta = makeEta({ autostart: true, min: 0, max: 1, historyTimeConstant: 30 });
|
||||
let batchCount = 0;
|
||||
|
||||
while (keepGoing) {
|
||||
// Monitor memory usage
|
||||
const getMemoryUsage = () => {
|
||||
const used = process.memoryUsage();
|
||||
return used.heapUsed;
|
||||
};
|
||||
|
||||
// Add delay between batches
|
||||
const processingDelay = () => new Promise(resolve => setTimeout(resolve, BATCH_DELAY));
|
||||
|
||||
try {
|
||||
// eslint-disable-next-line no-restricted-syntax
|
||||
for (const channel of channels) {
|
||||
try {
|
||||
// Check if we should skip this channel (already processed)
|
||||
if (stateManager.isChannelProcessed(channel.id)) {
|
||||
L.debug({ channelId: channel.id }, 'Skipping already processed channel');
|
||||
continue;
|
||||
}
|
||||
let keepGoing = true;
|
||||
let oldestMessageID = stateManager.shouldResumeFromMessage(channel.id);
|
||||
L.debug({ channelId: channel.id, messagesCount }, `Training from channel`);
|
||||
const channelCreateDate = channel.createdTimestamp;
|
||||
const channelEta = makeEta({ autostart: true, min: 0, max: 1, historyTimeConstant: 30 });
|
||||
|
||||
while (keepGoing) {
|
||||
let allBatchMessages = new Discord.Collection<string, Discord.Message<boolean>>();
|
||||
let channelBatchMessages: Discord.Collection<string, Discord.Message<boolean>>;
|
||||
try {
|
||||
@@ -407,15 +513,55 @@ async function saveGuildMessageHistory(
|
||||
|
||||
allBatchMessages = allBatchMessages.concat(channelBatchMessages);
|
||||
|
||||
// Filter and data map messages to be ready for addition to the corpus
|
||||
const humanAuthoredMessages = allBatchMessages
|
||||
.filter((m) => isHumanAuthoredMessage(m))
|
||||
.map(messageToData);
|
||||
L.trace({ oldestMessageID }, `Saving ${humanAuthoredMessages.length} messages`);
|
||||
// eslint-disable-next-line no-await-in-loop
|
||||
await markov.addData(humanAuthoredMessages);
|
||||
L.trace('Finished saving messages');
|
||||
messagesCount += humanAuthoredMessages.length;
|
||||
try {
|
||||
// Check memory usage before processing
|
||||
const memoryUsage = getMemoryUsage();
|
||||
if (memoryUsage > MAX_MEMORY_USAGE) {
|
||||
L.warn('Memory usage too high, waiting for garbage collection');
|
||||
await processingDelay();
|
||||
global.gc?.(); // Optional garbage collection if --expose-gc flag is used
|
||||
}
|
||||
|
||||
// Filter and data map messages to be ready for addition to the corpus
|
||||
const humanAuthoredMessages = allBatchMessages
|
||||
.filter((m) => isHumanAuthoredMessage(m))
|
||||
.map(messageToData);
|
||||
|
||||
// Process messages in smaller batches for stability
|
||||
for (let i = 0; i < humanAuthoredMessages.length; i += BATCH_SIZE) {
|
||||
const batch = humanAuthoredMessages.slice(i, i + BATCH_SIZE);
|
||||
L.trace({ oldestMessageID, batchSize: batch.length }, `Saving batch of messages`);
|
||||
|
||||
try {
|
||||
// eslint-disable-next-line no-await-in-loop
|
||||
await markov.addData(batch);
|
||||
batchCount++;
|
||||
messagesCount += batch.length;
|
||||
|
||||
// Update state after successful batch
|
||||
const lastMessage = allBatchMessages.last();
|
||||
if (lastMessage) {
|
||||
stateManager.updateProgress(channel.id, lastMessage.id, messagesCount);
|
||||
}
|
||||
|
||||
// Add delay between batches
|
||||
if (batchCount % 5 === 0) { // Every 5 batches
|
||||
await processingDelay();
|
||||
}
|
||||
} catch (err) {
|
||||
stateManager.recordError(err as Error, channel.id, oldestMessageID);
|
||||
L.error({ err, batchSize: batch.length }, 'Error saving batch of messages');
|
||||
// Continue with next batch instead of failing completely
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
L.trace('Finished processing message batches');
|
||||
} catch (err) {
|
||||
L.error({ err }, 'Error processing messages');
|
||||
// Wait a bit before continuing to next batch of messages
|
||||
await processingDelay();
|
||||
}
|
||||
const lastMessage = channelBatchMessages.last();
|
||||
|
||||
// Update tracking metrics
|
||||
@@ -457,12 +603,24 @@ async function saveGuildMessageHistory(
|
||||
...updateMessageData,
|
||||
embeds: [new Discord.EmbedBuilder(embedOptions)],
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (err) {
|
||||
L.error({ err }, 'Error processing channel');
|
||||
stateManager.recordError(err as Error);
|
||||
// Continue with next channel
|
||||
}
|
||||
}
|
||||
|
||||
L.info({ channelIds }, `Trained from ${messagesCount} past human authored messages.`);
|
||||
return `Trained from ${messagesCount} past human authored messages.`;
|
||||
L.info({ channelIds }, `Trained from ${messagesCount} past human authored messages.`);
|
||||
stateManager.finishTraining();
|
||||
return `Trained from ${messagesCount} past human authored messages.`;
|
||||
} catch (err) {
|
||||
const error = err as Error;
|
||||
L.error({ err }, 'Error during training completion');
|
||||
stateManager.recordError(error);
|
||||
return `Training encountered an error: ${error.message}. Use clean=false to resume from last checkpoint.`;
|
||||
}
|
||||
}
|
||||
|
||||
interface JSONImport {
|
||||
@@ -481,7 +639,17 @@ async function trainFromAttachmentJson(
|
||||
if (!isModerator(interaction.member)) return INVALID_PERMISSIONS_MESSAGE;
|
||||
if (!interaction.guildId || !interaction.guild) return INVALID_GUILD_MESSAGE;
|
||||
const { guildId } = interaction;
|
||||
|
||||
const stateManager = new TrainingStateManager(guildId, CONFIG_DIR);
|
||||
|
||||
// Check if training is already in progress
|
||||
const currentState = stateManager.getState();
|
||||
if (currentState.inProgress) {
|
||||
return `Training is already in progress. Last update: ${currentState.lastUpdate}. Use clean=true to restart.`;
|
||||
}
|
||||
|
||||
const markov = await getMarkovByGuildId(guildId);
|
||||
stateManager.startTraining();
|
||||
|
||||
let trainingData: AddDataProps[];
|
||||
try {
|
||||
@@ -517,14 +685,49 @@ async function trainFromAttachmentJson(
|
||||
if (clean) {
|
||||
L.debug('Deleting old data');
|
||||
await markov.delete();
|
||||
stateManager.reset();
|
||||
} else {
|
||||
L.debug('Not deleting old data during training');
|
||||
}
|
||||
|
||||
await markov.addData(trainingData);
|
||||
const BATCH_SIZE = 100;
|
||||
const BATCH_DELAY = 100;
|
||||
let processedCount = 0;
|
||||
let batchCount = 0;
|
||||
|
||||
L.info(`Trained from ${trainingData.length} past human authored messages.`);
|
||||
return `Trained from ${trainingData.length} past human authored messages.`;
|
||||
try {
|
||||
// Process messages in batches
|
||||
for (let i = 0; i < trainingData.length; i += BATCH_SIZE) {
|
||||
const batch = trainingData.slice(i, i + BATCH_SIZE);
|
||||
try {
|
||||
await markov.addData(batch);
|
||||
processedCount += batch.length;
|
||||
batchCount++;
|
||||
|
||||
// Update state after successful batch
|
||||
stateManager.updateProgress('json-import', i.toString(), processedCount);
|
||||
|
||||
// Add delay between batches
|
||||
if (batchCount % 5 === 0) {
|
||||
await new Promise(resolve => setTimeout(resolve, BATCH_DELAY));
|
||||
}
|
||||
} catch (err) {
|
||||
L.error({ err, batchIndex: i }, 'Error processing JSON batch');
|
||||
stateManager.recordError(err as Error, 'json-import', i.toString());
|
||||
// Continue with next batch instead of failing completely
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
L.info(`Successfully trained from ${processedCount} messages from JSON.`);
|
||||
stateManager.finishTraining();
|
||||
return `Successfully trained from ${processedCount} messages from JSON.`;
|
||||
} catch (err) {
|
||||
const error = err as Error;
|
||||
L.error({ err }, 'Error during JSON training completion');
|
||||
stateManager.recordError(error);
|
||||
return `Training encountered an error: ${error.message}. Use clean=false to resume from last checkpoint.`;
|
||||
}
|
||||
}
|
||||
|
||||
interface GenerateResponse {
|
||||
@@ -534,7 +737,6 @@ interface GenerateResponse {
|
||||
}
|
||||
|
||||
interface GenerateOptions {
|
||||
tts?: boolean;
|
||||
debug?: boolean;
|
||||
startSeed?: string;
|
||||
}
|
||||
@@ -551,7 +753,7 @@ async function generateResponse(
|
||||
options?: GenerateOptions,
|
||||
): Promise<GenerateResponse> {
|
||||
L.debug({ options }, 'Responding...');
|
||||
const { tts = false, debug = false, startSeed } = options || {};
|
||||
const { debug = false, startSeed } = options || {};
|
||||
if (!interaction.guildId) {
|
||||
L.warn('Received an interaction without a guildId');
|
||||
return { error: { content: INVALID_GUILD_MESSAGE } };
|
||||
@@ -568,7 +770,6 @@ async function generateResponse(
|
||||
L.info({ string: response.string }, 'Generated response text');
|
||||
L.debug({ response }, 'Generated response object');
|
||||
const messageOpts: AgnosticReplyOptions = {
|
||||
tts,
|
||||
allowedMentions: { repliedUser: false, parse: [] },
|
||||
};
|
||||
const attachmentUrls = response.refs
|
||||
@@ -652,12 +853,17 @@ function helpMessage(): AgnosticReplyOptions {
|
||||
.addFields([
|
||||
{
|
||||
name: `${config.messageCommandPrefix} or /${messageCommand.name}`,
|
||||
value: `Generates a sentence to say based on the chat database. Send your message as TTS to recieve it as TTS.`,
|
||||
value: `Generates a sentence based on the chat database.`,
|
||||
},
|
||||
|
||||
{
|
||||
name: `/${listenChannelCommand.name}`,
|
||||
value: `Add, remove, list, or modify the list of channels the bot listens to.`,
|
||||
value: `Add, remove, list, or modify the list of channels the bot listens to and learns from.`,
|
||||
},
|
||||
|
||||
{
|
||||
name: `/autorespond`,
|
||||
value: `Add, remove, list, or modify the list of channels where the bot will automatically respond to all messages.`,
|
||||
},
|
||||
|
||||
{
|
||||
@@ -674,11 +880,6 @@ function helpMessage(): AgnosticReplyOptions {
|
||||
name: `${config.messageCommandPrefix} debug or /${messageCommand.name} debug: True`,
|
||||
value: `Runs the ${config.messageCommandPrefix} command and follows it up with debug info.`,
|
||||
},
|
||||
|
||||
{
|
||||
name: `${config.messageCommandPrefix} tts or /${messageCommand.name} tts: True`,
|
||||
value: `Runs the ${config.messageCommandPrefix} command and reads it with text-to-speech.`,
|
||||
},
|
||||
])
|
||||
.setFooter({
|
||||
text: `${packageJson().name} ${getVersion()} by ${
|
||||
@@ -694,12 +895,11 @@ function generateInviteUrl(): string {
|
||||
return client.generateInvite({
|
||||
scopes: [Discord.OAuth2Scopes.Bot, Discord.OAuth2Scopes.ApplicationsCommands],
|
||||
permissions: [
|
||||
'ViewChannel',
|
||||
'SendMessages',
|
||||
'SendTTSMessages',
|
||||
'AttachFiles',
|
||||
'ReadMessageHistory',
|
||||
],
|
||||
'ViewChannel',
|
||||
'SendMessages',
|
||||
'AttachFiles',
|
||||
'ReadMessageHistory'
|
||||
],
|
||||
});
|
||||
}
|
||||
|
||||
@@ -789,11 +989,6 @@ client.on('messageCreate', async (message) => {
|
||||
const generatedResponse = await generateResponse(message);
|
||||
await handleResponseMessage(generatedResponse, message);
|
||||
}
|
||||
if (command === 'tts') {
|
||||
L.debug('Responding to legacy command tts');
|
||||
const generatedResponse = await generateResponse(message, { tts: true });
|
||||
await handleResponseMessage(generatedResponse, message);
|
||||
}
|
||||
if (command === 'debug') {
|
||||
L.debug('Responding to legacy command debug');
|
||||
const generatedResponse = await generateResponse(message, { debug: true });
|
||||
@@ -802,11 +997,23 @@ client.on('messageCreate', async (message) => {
|
||||
if (command === null) {
|
||||
if (isHumanAuthoredMessage(message)) {
|
||||
if (client.user && message.mentions.has(client.user)) {
|
||||
// Check if response channels are configured and if this channel is allowed
|
||||
if (config.responseChannelIds.length > 0 && !config.responseChannelIds.includes(message.channel.id)) {
|
||||
L.debug('Ignoring mention in non-response channel');
|
||||
return;
|
||||
}
|
||||
|
||||
L.debug('Responding to mention');
|
||||
// <@!278354154563567636> how are you doing?
|
||||
const startSeed = message.content.replace(/<@!\d+>/g, '').trim();
|
||||
const generatedResponse = await generateResponse(message, { startSeed });
|
||||
await handleResponseMessage(generatedResponse, message);
|
||||
} else if (await isAutoRespondChannel(message.channel)) {
|
||||
// Auto-respond to all messages in configured channels using message content as context
|
||||
L.debug('Auto-responding in configured channel with context');
|
||||
const startSeed = message.content.trim();
|
||||
const generatedResponse = await generateResponse(message, { startSeed });
|
||||
await handleResponseMessage(generatedResponse, message);
|
||||
}
|
||||
|
||||
if (await isValidChannel(message.channel)) {
|
||||
@@ -848,7 +1055,7 @@ client.on('threadDelete', async (thread) => {
|
||||
await markov.removeTags([thread.id]);
|
||||
});
|
||||
|
||||
// eslint-disable-next-line consistent-return
|
||||
|
||||
client.on('interactionCreate', async (interaction) => {
|
||||
if (interaction.isChatInputCommand()) {
|
||||
L.info({ command: interaction.commandName }, 'Recieved slash command');
|
||||
@@ -859,23 +1066,12 @@ client.on('interactionCreate', async (interaction) => {
|
||||
await interaction.reply(inviteMessage());
|
||||
} else if (interaction.commandName === messageCommand.name) {
|
||||
await interaction.deferReply();
|
||||
const tts = interaction.options.getBoolean('tts') || false;
|
||||
const debug = interaction.options.getBoolean('debug') || false;
|
||||
const startSeed = interaction.options.getString('seed')?.trim() || undefined;
|
||||
const generatedResponse = await generateResponse(interaction, { tts, debug, startSeed });
|
||||
const generatedResponse = await generateResponse(interaction, { debug, startSeed });
|
||||
|
||||
/**
|
||||
* TTS doesn't work when using editReply, so instead we use delete + followUp
|
||||
* However, delete + followUp is ugly and shows the bot replying to "Message could not be loaded.",
|
||||
* so we avoid it if possible
|
||||
*/
|
||||
if (generatedResponse.message) {
|
||||
if (generatedResponse.message.tts) {
|
||||
await interaction.deleteReply();
|
||||
await interaction.followUp(generatedResponse.message);
|
||||
} else {
|
||||
await interaction.editReply(generatedResponse.message);
|
||||
}
|
||||
await interaction.editReply(generatedResponse.message);
|
||||
} else {
|
||||
await interaction.deleteReply();
|
||||
}
|
||||
@@ -943,6 +1139,67 @@ client.on('interactionCreate', async (interaction) => {
|
||||
ephemeral: true,
|
||||
});
|
||||
}
|
||||
} else if (interaction.commandName === autoRespondCommand.name) {
|
||||
await interaction.deferReply();
|
||||
const subCommand = interaction.options.getSubcommand(true) as 'add' | 'remove' | 'list' | 'modify';
|
||||
|
||||
if (subCommand === 'list') {
|
||||
const reply = await listAutoRespondChannels(interaction);
|
||||
await interaction.editReply(reply);
|
||||
} else if (subCommand === 'add') {
|
||||
if (!isModerator(interaction.member)) {
|
||||
return handleUnprivileged(interaction);
|
||||
}
|
||||
if (!interaction.guildId) {
|
||||
return handleNoGuild(interaction);
|
||||
}
|
||||
const channels = getChannelsFromInteraction(interaction);
|
||||
await addAutoRespondChannels(channels, interaction.guildId);
|
||||
await interaction.editReply(
|
||||
`Added ${channels.length} text channels to auto-respond list.`
|
||||
);
|
||||
} else if (subCommand === 'remove') {
|
||||
if (!isModerator(interaction.member)) {
|
||||
return handleUnprivileged(interaction);
|
||||
}
|
||||
if (!interaction.guildId) {
|
||||
return handleNoGuild(interaction);
|
||||
}
|
||||
const channels = getChannelsFromInteraction(interaction);
|
||||
await removeAutoRespondChannels(channels, interaction.guildId);
|
||||
await interaction.editReply(
|
||||
`Removed ${channels.length} text channels from auto-respond list.`
|
||||
);
|
||||
} else if (subCommand === 'modify') {
|
||||
if (!interaction.guild) {
|
||||
return handleNoGuild(interaction);
|
||||
}
|
||||
if (!isModerator(interaction.member)) {
|
||||
await handleUnprivileged(interaction);
|
||||
}
|
||||
await interaction.deleteReply();
|
||||
const dbTextChannels = await getTextChannels(interaction.guild);
|
||||
const row = new Discord.ActionRowBuilder<Discord.StringSelectMenuBuilder>().addComponents(
|
||||
new Discord.StringSelectMenuBuilder()
|
||||
.setCustomId('autorespond-modify-select')
|
||||
.setPlaceholder('Nothing selected')
|
||||
.setMinValues(0)
|
||||
.setMaxValues(dbTextChannels.length)
|
||||
.addOptions(
|
||||
dbTextChannels.map((c) => ({
|
||||
label: `#${c.name}` || c.id,
|
||||
value: c.id,
|
||||
default: c.autoRespond || false,
|
||||
})),
|
||||
),
|
||||
);
|
||||
|
||||
await interaction.followUp({
|
||||
content: 'Select which channels you would like the bot to auto-respond in',
|
||||
components: [row],
|
||||
ephemeral: true,
|
||||
});
|
||||
}
|
||||
} else if (interaction.commandName === trainCommand.name) {
|
||||
await interaction.deferReply();
|
||||
const clean = interaction.options.getBoolean('clean') ?? true;
|
||||
@@ -990,6 +1247,37 @@ client.on('interactionCreate', async (interaction) => {
|
||||
content: 'Updated actively listened to channels list.',
|
||||
ephemeral: true,
|
||||
});
|
||||
} else if (interaction.customId === 'autorespond-modify-select') {
|
||||
await interaction.deferUpdate();
|
||||
const { guild } = interaction;
|
||||
if (!isModerator(interaction.member)) {
|
||||
return handleUnprivileged(interaction, false);
|
||||
}
|
||||
if (!guild) {
|
||||
return handleNoGuild(interaction, false);
|
||||
}
|
||||
|
||||
const allChannels =
|
||||
(interaction.component as Discord.StringSelectMenuComponent).options?.map((o) => o.value) ||
|
||||
[];
|
||||
const selectedChannelIds = interaction.values;
|
||||
|
||||
const textChannels = (
|
||||
await Promise.all(
|
||||
allChannels.map(async (c) => {
|
||||
return guild.channels.fetch(c);
|
||||
}),
|
||||
)
|
||||
).filter((c): c is Discord.TextChannel => c !== null && c instanceof Discord.TextChannel);
|
||||
const unselectedChannels = textChannels.filter((t) => !selectedChannelIds.includes(t.id));
|
||||
const selectedChannels = textChannels.filter((t) => selectedChannelIds.includes(t.id));
|
||||
await addAutoRespondChannels(selectedChannels, guild.id);
|
||||
await removeAutoRespondChannels(unselectedChannels, guild.id);
|
||||
|
||||
await interaction.followUp({
|
||||
content: 'Updated auto-respond channels list.',
|
||||
ephemeral: true,
|
||||
});
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
391
src/train.ts
Normal file
391
src/train.ts
Normal file
@@ -0,0 +1,391 @@
|
||||
import 'source-map-support/register';
|
||||
import 'reflect-metadata';
|
||||
import Markov, { MarkovConstructorOptions, AddDataProps } from 'markov-strings-db';
|
||||
import { DataSource } from 'typeorm';
|
||||
import { promises as fs } from 'fs';
|
||||
import path from 'path';
|
||||
import { config } from './config';
|
||||
import ormconfig from './ormconfig';
|
||||
import { Guild } from './entity/Guild';
|
||||
import { Channel } from './entity/Channel';
|
||||
import L from './logger';
|
||||
import { MarkovDataCustom } from './types';
|
||||
import { TrainingStateManager } from './training-state';
|
||||
import { CONFIG_DIR } from './config/setup';
|
||||
|
||||
const markovOpts: MarkovConstructorOptions = {
|
||||
stateSize: config.stateSize,
|
||||
};
|
||||
|
||||
// Constants for batch processing
|
||||
const BATCH_SIZE = 100; // Process messages in batches
|
||||
const BATCH_DELAY = 100; // Milliseconds to wait between batches
|
||||
const MAX_MEMORY_USAGE = 1024 * 1024 * 1024; // 1GB memory limit
|
||||
|
||||
// Monitor memory usage
|
||||
const getMemoryUsage = () => {
|
||||
const used = process.memoryUsage();
|
||||
return used.heapUsed;
|
||||
};
|
||||
|
||||
// Add delay between batches
|
||||
const processingDelay = () => new Promise(resolve => setTimeout(resolve, BATCH_DELAY));
|
||||
|
||||
async function getMarkovByGuildId(guildId: string): Promise<Markov> {
|
||||
const markov = new Markov({ id: guildId, options: { ...markovOpts, id: guildId } });
|
||||
L.trace({ guildId }, 'Setting up markov instance');
|
||||
await markov.setup(); // Connect the markov instance to the DB to assign it an ID
|
||||
return markov;
|
||||
}
|
||||
|
||||
interface JSONImport {
|
||||
message: string;
|
||||
attachments?: string[];
|
||||
}
|
||||
|
||||
/**
|
||||
* Train from a JSON file containing messages
|
||||
*/
|
||||
|
||||
async function trainFromJson(
|
||||
guildId: string,
|
||||
jsonPath: string,
|
||||
clean = true,
|
||||
): Promise<string> {
|
||||
const markov = await getMarkovByGuildId(guildId);
|
||||
|
||||
let trainingData: AddDataProps[];
|
||||
try {
|
||||
const fileContent = await fs.readFile(jsonPath, 'utf-8');
|
||||
const importData = JSON.parse(fileContent) as JSONImport[];
|
||||
|
||||
// Filter out invalid entries first
|
||||
const validData = importData.filter((datum, index) => {
|
||||
if (!datum.message || typeof datum.message !== 'string') {
|
||||
L.debug({ index }, 'Skipping entry without valid message');
|
||||
return false;
|
||||
}
|
||||
if (datum.attachments?.some(a => typeof a !== 'string')) {
|
||||
L.debug({ index }, 'Skipping entry with invalid attachments');
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
});
|
||||
|
||||
// Map valid entries to training data
|
||||
trainingData = validData.map(datum => {
|
||||
let custom: MarkovDataCustom | undefined;
|
||||
if (datum.attachments?.length) {
|
||||
custom = { attachments: datum.attachments };
|
||||
}
|
||||
return {
|
||||
string: datum.message,
|
||||
custom,
|
||||
tags: [guildId]
|
||||
};
|
||||
});
|
||||
} catch (err) {
|
||||
L.error(err);
|
||||
if (err instanceof SyntaxError) {
|
||||
return 'The provided JSON file has invalid formatting. See the logs for details.';
|
||||
}
|
||||
return `Error reading file: ${err instanceof Error ? err.message : 'Unknown error'}`;
|
||||
}
|
||||
|
||||
if (clean) {
|
||||
L.debug('Deleting old data');
|
||||
await markov.delete();
|
||||
} else {
|
||||
L.debug('Not deleting old data during training');
|
||||
}
|
||||
|
||||
let processedCount = 0;
|
||||
let batchCount = 0;
|
||||
const totalMessages = trainingData.length;
|
||||
|
||||
// Process messages in batches
|
||||
for (let i = 0; i < trainingData.length; i += BATCH_SIZE) {
|
||||
try {
|
||||
// Check memory usage
|
||||
const memoryUsage = getMemoryUsage();
|
||||
if (memoryUsage > MAX_MEMORY_USAGE) {
|
||||
L.warn('Memory usage too high, waiting for garbage collection');
|
||||
await processingDelay();
|
||||
global.gc?.(); // Optional garbage collection if --expose-gc flag is used
|
||||
}
|
||||
|
||||
const batch = trainingData.slice(i, i + BATCH_SIZE);
|
||||
await markov.addData(batch);
|
||||
|
||||
processedCount += batch.length;
|
||||
batchCount++;
|
||||
|
||||
// Log progress
|
||||
if (batchCount % 5 === 0) {
|
||||
const progress = (processedCount / totalMessages * 100).toFixed(2);
|
||||
L.info(`Progress: ${progress}% (${processedCount}/${totalMessages} messages)`);
|
||||
await processingDelay(); // Add delay every 5 batches
|
||||
}
|
||||
} catch (err) {
|
||||
L.error({ err, batchIndex: i }, 'Error processing batch');
|
||||
// Continue with next batch instead of failing completely
|
||||
await processingDelay(); // Wait a bit longer after an error
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
L.info(`Successfully trained from ${processedCount} messages.`);
|
||||
return `Successfully trained from ${processedCount} messages.`;
|
||||
}
|
||||
|
||||
/**
|
||||
* Train from all JSON files in a directory
|
||||
*/
|
||||
/**
|
||||
* Train from all JSON files in a directory
|
||||
* @param guildId The Discord guild ID
|
||||
* @param dirPath Path to directory containing JSON files
|
||||
* @param clean Whether to clean existing data before training
|
||||
*/
|
||||
/**
|
||||
* Acquire a lock file for training to prevent concurrent processes
|
||||
*/
|
||||
async function acquireTrainingLock(guildId: string): Promise<boolean> {
|
||||
const lockPath = path.join(CONFIG_DIR, `${guildId}_training.lock`);
|
||||
try {
|
||||
await fs.writeFile(lockPath, process.pid.toString(), { flag: 'wx' });
|
||||
return true;
|
||||
} catch (err) {
|
||||
if ((err as NodeJS.ErrnoException).code === 'EEXIST') {
|
||||
try {
|
||||
const pid = parseInt(await fs.readFile(lockPath, 'utf-8'));
|
||||
try {
|
||||
// Check if process is still running
|
||||
process.kill(pid, 0);
|
||||
return false; // Process is still running
|
||||
} catch {
|
||||
// Process is not running, safe to remove lock
|
||||
await fs.unlink(lockPath);
|
||||
await fs.writeFile(lockPath, process.pid.toString());
|
||||
return true;
|
||||
}
|
||||
} catch {
|
||||
// Error reading/writing lock file
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Release the training lock file
|
||||
*/
|
||||
async function releaseTrainingLock(guildId: string): Promise<void> {
|
||||
const lockPath = path.join(CONFIG_DIR, `${guildId}_training.lock`);
|
||||
try {
|
||||
await fs.unlink(lockPath);
|
||||
} catch {
|
||||
// Ignore errors during cleanup
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Sanitize and validate a directory path
|
||||
*/
|
||||
async function validateDirectoryPath(dirPath: string): Promise<string> {
|
||||
// Resolve to absolute path
|
||||
const absolutePath = path.resolve(dirPath);
|
||||
|
||||
// Prevent directory traversal
|
||||
const normalizedPath = path.normalize(absolutePath);
|
||||
if (!normalizedPath.startsWith(process.cwd())) {
|
||||
throw new Error('Directory must be within current working directory');
|
||||
}
|
||||
|
||||
// Verify directory exists and is accessible
|
||||
try {
|
||||
const stats = await fs.stat(normalizedPath);
|
||||
if (!stats.isDirectory()) {
|
||||
throw new Error('Path is not a directory');
|
||||
}
|
||||
await fs.access(normalizedPath, fs.constants.R_OK);
|
||||
return normalizedPath;
|
||||
} catch (err) {
|
||||
throw new Error(`Invalid directory path: ${err instanceof Error ? err.message : 'Unknown error'}`);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Train from all JSON files in a directory
|
||||
*/
|
||||
async function trainFromDirectory(
|
||||
guildId: string,
|
||||
dirPath: string,
|
||||
clean = true,
|
||||
): Promise<string> {
|
||||
L.debug({ guildId, dirPath, clean }, 'Starting directory training');
|
||||
const stateManager = new TrainingStateManager(guildId, CONFIG_DIR);
|
||||
|
||||
// Set up cleanup handler
|
||||
const cleanup = async () => {
|
||||
try {
|
||||
await releaseTrainingLock(guildId);
|
||||
stateManager.finishTraining();
|
||||
} catch (err) {
|
||||
L.error({ err }, 'Error during cleanup');
|
||||
}
|
||||
};
|
||||
|
||||
// Handle process termination
|
||||
process.once('SIGINT', cleanup);
|
||||
process.once('SIGTERM', cleanup);
|
||||
|
||||
try {
|
||||
// Try to acquire lock
|
||||
if (!await acquireTrainingLock(guildId)) {
|
||||
return 'Another training process is already running. Please wait for it to complete.';
|
||||
}
|
||||
|
||||
// Always reset state at the start of training
|
||||
stateManager.reset();
|
||||
|
||||
try {
|
||||
// Validate and normalize directory path
|
||||
const absolutePath = await validateDirectoryPath(dirPath);
|
||||
|
||||
// Get all JSON files in the directory
|
||||
L.trace({ dirPath: absolutePath }, 'Reading directory');
|
||||
const files = await fs.readdir(absolutePath);
|
||||
const jsonFiles = files.filter(file => file.toLowerCase().endsWith('.json'));
|
||||
|
||||
if (jsonFiles.length === 0) {
|
||||
L.warn({ dirPath: absolutePath }, 'No JSON files found in directory');
|
||||
return 'No JSON files found in the specified directory.';
|
||||
}
|
||||
|
||||
let totalProcessed = 0;
|
||||
let batchCount = 0;
|
||||
L.info({ fileCount: jsonFiles.length }, 'Found JSON files to process');
|
||||
|
||||
stateManager.startTraining();
|
||||
|
||||
// Process first file with clean flag, subsequent files without cleaning
|
||||
for (let i = 0; i < jsonFiles.length; i++) {
|
||||
const jsonPath = path.join(absolutePath, jsonFiles[i]);
|
||||
const fileNumber = i + 1;
|
||||
L.debug(
|
||||
{ file: jsonFiles[i], progress: `${fileNumber}/${jsonFiles.length}` },
|
||||
'Processing file'
|
||||
);
|
||||
|
||||
try {
|
||||
// Check memory usage before processing file
|
||||
const memoryUsage = getMemoryUsage();
|
||||
if (memoryUsage > MAX_MEMORY_USAGE) {
|
||||
L.warn('Memory usage too high, waiting for garbage collection');
|
||||
await processingDelay();
|
||||
global.gc?.(); // Optional garbage collection if --expose-gc flag is used
|
||||
}
|
||||
|
||||
// Check if we should skip this file (already processed)
|
||||
if (!clean && stateManager.isChannelProcessed(jsonFiles[i])) {
|
||||
L.debug({ file: jsonFiles[i] }, 'Skipping already processed file');
|
||||
continue;
|
||||
}
|
||||
|
||||
const result = await trainFromJson(
|
||||
guildId,
|
||||
jsonPath,
|
||||
i === 0 ? clean : false // Only clean on first file
|
||||
);
|
||||
|
||||
// Extract number of processed messages from result string
|
||||
const processed = parseInt(result.match(/\d+/)?.[0] || '0');
|
||||
totalProcessed += processed;
|
||||
batchCount++;
|
||||
|
||||
// Update state after each file
|
||||
stateManager.updateProgress('json-import', jsonFiles[i], totalProcessed);
|
||||
L.trace(
|
||||
{ file: jsonFiles[i], processed, totalProcessed },
|
||||
'File processing complete'
|
||||
);
|
||||
|
||||
// Add delay between files
|
||||
if (batchCount % 5 === 0) {
|
||||
await processingDelay();
|
||||
}
|
||||
|
||||
// Clear any references that might be held
|
||||
if (global.gc) {
|
||||
global.gc();
|
||||
}
|
||||
} catch (err) {
|
||||
const error = err as Error;
|
||||
L.error(
|
||||
{ error: error.message, file: jsonFiles[i], stack: error.stack },
|
||||
'Error processing JSON file'
|
||||
);
|
||||
stateManager.recordError(error, 'json-import', jsonFiles[i]);
|
||||
// Add longer delay after error
|
||||
await processingDelay();
|
||||
// Continue with next file instead of failing completely
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
const summary = { totalProcessed, fileCount: jsonFiles.length };
|
||||
L.info(summary, 'Directory training complete');
|
||||
return `Successfully trained from ${totalProcessed} messages across ${jsonFiles.length} files.`;
|
||||
} finally {
|
||||
// Clean up regardless of success/failure
|
||||
await cleanup();
|
||||
// Remove process termination handlers
|
||||
process.off('SIGINT', cleanup);
|
||||
process.off('SIGTERM', cleanup);
|
||||
}
|
||||
} catch (err) {
|
||||
const error = err as Error;
|
||||
L.error(
|
||||
{ error: error.message, stack: error.stack, dirPath },
|
||||
'Error during directory training'
|
||||
);
|
||||
stateManager.recordError(error);
|
||||
return `Training encountered an error: ${error.message}. Use clean=false to resume from last checkpoint.`;
|
||||
}
|
||||
}
|
||||
|
||||
async function main(): Promise<void> {
|
||||
const args = process.argv.slice(2);
|
||||
if (args.length < 2) {
|
||||
console.log('Usage: node train.js <guildId> <path> [--keep-existing] [--directory]');
|
||||
console.log('Options:');
|
||||
console.log(' --keep-existing Keep existing training data');
|
||||
console.log(' --directory Process all JSON files in the specified directory');
|
||||
process.exit(1);
|
||||
}
|
||||
|
||||
const guildId = args[0];
|
||||
const inputPath = args[1];
|
||||
const keepExisting = args.includes('--keep-existing');
|
||||
const isDirectory = args.includes('--directory');
|
||||
|
||||
const dataSourceOptions = Markov.extendDataSourceOptions(ormconfig);
|
||||
const dataSource = new DataSource(dataSourceOptions);
|
||||
await dataSource.initialize();
|
||||
|
||||
// Ensure guild exists in DB
|
||||
await Guild.upsert(Guild.create({ id: guildId }), ['id']);
|
||||
|
||||
const result = isDirectory
|
||||
? await trainFromDirectory(guildId, inputPath, !keepExisting)
|
||||
: await trainFromJson(guildId, inputPath, !keepExisting);
|
||||
console.log(result);
|
||||
|
||||
await dataSource.destroy();
|
||||
}
|
||||
|
||||
if (require.main === module) {
|
||||
main().catch(console.error);
|
||||
}
|
||||
113
src/training-state.ts
Normal file
113
src/training-state.ts
Normal file
@@ -0,0 +1,113 @@
|
||||
import fs from 'fs-extra';
|
||||
import path from 'path';
|
||||
import { TrainingState } from './types';
|
||||
import L from './logger';
|
||||
|
||||
export class TrainingStateManager {
|
||||
private stateFile: string;
|
||||
private state: TrainingState;
|
||||
|
||||
constructor(guildId: string, configDir: string = 'config') {
|
||||
this.stateFile = path.join(configDir, 'training-state', `${guildId}.json`);
|
||||
|
||||
// Initialize with default state
|
||||
this.state = {
|
||||
guildId,
|
||||
processedChannels: [],
|
||||
totalMessages: 0,
|
||||
lastUpdate: new Date().toISOString(),
|
||||
inProgress: false
|
||||
};
|
||||
|
||||
// Ensure directory exists
|
||||
fs.ensureDirSync(path.dirname(this.stateFile));
|
||||
|
||||
// Load existing state if available
|
||||
this.loadState();
|
||||
}
|
||||
|
||||
private loadState(): void {
|
||||
try {
|
||||
if (fs.existsSync(this.stateFile)) {
|
||||
const savedState = fs.readJsonSync(this.stateFile);
|
||||
this.state = { ...this.state, ...savedState };
|
||||
L.info({ guildId: this.state.guildId }, 'Loaded existing training state');
|
||||
}
|
||||
} catch (err) {
|
||||
L.error({ err }, 'Error loading training state');
|
||||
// Keep using default state if load fails
|
||||
}
|
||||
}
|
||||
|
||||
private saveState(): void {
|
||||
try {
|
||||
fs.writeJsonSync(this.stateFile, this.state, { spaces: 2 });
|
||||
} catch (err) {
|
||||
L.error({ err }, 'Error saving training state');
|
||||
}
|
||||
}
|
||||
|
||||
public startTraining(): void {
|
||||
this.state.inProgress = true;
|
||||
this.state.error = undefined;
|
||||
this.state.lastUpdate = new Date().toISOString();
|
||||
this.saveState();
|
||||
}
|
||||
|
||||
public finishTraining(): void {
|
||||
this.state.inProgress = false;
|
||||
this.state.lastUpdate = new Date().toISOString();
|
||||
this.saveState();
|
||||
}
|
||||
|
||||
public updateProgress(channelId: string, messageId: string, messagesProcessed: number): void {
|
||||
this.state.lastChannelId = channelId;
|
||||
this.state.lastMessageId = messageId;
|
||||
this.state.totalMessages = messagesProcessed;
|
||||
this.state.lastUpdate = new Date().toISOString();
|
||||
this.saveState();
|
||||
}
|
||||
|
||||
public markChannelComplete(channelId: string): void {
|
||||
if (!this.state.processedChannels.includes(channelId)) {
|
||||
this.state.processedChannels.push(channelId);
|
||||
this.saveState();
|
||||
}
|
||||
}
|
||||
|
||||
public recordError(error: Error, channelId?: string, messageId?: string): void {
|
||||
this.state.error = {
|
||||
message: error.message,
|
||||
channelId,
|
||||
messageId,
|
||||
timestamp: new Date().toISOString()
|
||||
};
|
||||
this.saveState();
|
||||
}
|
||||
|
||||
public isChannelProcessed(channelId: string): boolean {
|
||||
return this.state.processedChannels.includes(channelId);
|
||||
}
|
||||
|
||||
public shouldResumeFromMessage(channelId: string): string | undefined {
|
||||
if (this.state.inProgress && this.state.lastChannelId === channelId) {
|
||||
return this.state.lastMessageId;
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
||||
public getState(): TrainingState {
|
||||
return { ...this.state };
|
||||
}
|
||||
|
||||
public reset(): void {
|
||||
this.state = {
|
||||
guildId: this.state.guildId,
|
||||
processedChannels: [],
|
||||
totalMessages: 0,
|
||||
lastUpdate: new Date().toISOString(),
|
||||
inProgress: false
|
||||
};
|
||||
this.saveState();
|
||||
}
|
||||
}
|
||||
19
src/types.ts
Normal file
19
src/types.ts
Normal file
@@ -0,0 +1,19 @@
|
||||
export interface MarkovDataCustom {
|
||||
attachments: string[];
|
||||
}
|
||||
|
||||
export interface TrainingState {
|
||||
guildId: string;
|
||||
lastMessageId?: string;
|
||||
lastChannelId?: string;
|
||||
processedChannels: string[];
|
||||
totalMessages: number;
|
||||
lastUpdate: string;
|
||||
inProgress: boolean;
|
||||
error?: {
|
||||
message: string;
|
||||
channelId?: string;
|
||||
messageId?: string;
|
||||
timestamp: string;
|
||||
};
|
||||
}
|
||||
@@ -3,7 +3,7 @@
|
||||
"compilerOptions": {
|
||||
"target": "es2021", /* Specify ECMAScript target version: 'ES3' (default), 'ES5', 'ES2015', 'ES2016', 'ES2017', 'ES2018', 'ES2019' or 'ESNEXT'. */
|
||||
"module": "commonjs", /* Specify module code generation: 'none', 'commonjs', 'amd', 'system', 'umd', 'es2015', or 'ESNext'. */
|
||||
"outDir": "./dist", /* Redirect output structure to the directory. */
|
||||
"outDir": "./build", /* Redirect output structure to the directory. */
|
||||
"removeComments": true, /* Do not emit comments to output. */
|
||||
"esModuleInterop": true,
|
||||
"strict": true, /* Enable all strict type-checking options. */
|
||||
|
||||
Reference in New Issue
Block a user