mirror of
https://github.com/pacnpal/markov-discord.git
synced 2025-12-22 11:51:05 -05:00
Added several things: Parser to import JSON from DiscordChatExporter, ability to train without bot running and more.
This commit is contained in:
@@ -163,4 +163,19 @@ export class AppConfig {
|
||||
@IsOptional()
|
||||
@IsString()
|
||||
devGuildId = process.env.DEV_GUILD_ID;
|
||||
|
||||
/**
 * A list of channel IDs where the bot will respond to mentions.
 * If empty, the bot will respond to mentions in any channel.
 * Blank entries (from a trailing comma, doubled commas, or a whitespace-only
 * value) are dropped so they cannot turn an unset list into a non-empty one.
 * @example ["734548250895319070"]
 * @default []
 * @env RESPONSE_CHANNEL_IDS (comma separated)
 */
@IsArray()
@IsString({ each: true })
@Type(() => String)
@IsOptional()
responseChannelIds = process.env.RESPONSE_CHANNEL_IDS
  ? process.env.RESPONSE_CHANNEL_IDS.split(',')
      .map((id) => id.trim())
      // An empty string can never match a channel ID; keeping it would make the
      // list look configured and restrict mentions to no channel at all.
      .filter((id) => id.length > 0)
  : [];
|
||||
}
|
||||
|
||||
@@ -21,9 +21,6 @@ export const inviteCommand = new SlashCommandBuilder()
|
||||
export const messageCommand = new SlashCommandBuilder()
|
||||
.setName(config.slashCommandName)
|
||||
.setDescription('Generate a message from learned past messages')
|
||||
.addBooleanOption((tts) =>
|
||||
tts.setName('tts').setDescription('Read the message via text-to-speech.').setRequired(false),
|
||||
)
|
||||
.addBooleanOption((debug) =>
|
||||
debug
|
||||
.setName('debug')
|
||||
@@ -49,6 +46,38 @@ const channelOptionsGenerator = (builder: SlashCommandChannelOption, index: numb
|
||||
.setRequired(index === 0)
|
||||
.addChannelTypes(ChannelType.GuildText);
|
||||
|
||||
/**
 * Slash command `/autorespond` with four subcommands:
 * `add` and `remove` each take up to CHANNEL_OPTIONS_MAX channel options,
 * while `list` and `modify` take no options.
 */
export const autoRespondCommand = new SlashCommandBuilder()
  .setName('autorespond')
  .setDescription('Configure channels where the bot will automatically respond to all messages')
  .addSubcommand((sub) => {
    sub.setName('add');
    sub.setDescription('Add channels where the bot will automatically respond to all messages');
    // One channel option per slot; channelOptionsGenerator configures each by index.
    for (let index = 0; index < CHANNEL_OPTIONS_MAX; index += 1) {
      sub.addChannelOption((opt) => channelOptionsGenerator(opt, index));
    }
    return sub;
  })
  .addSubcommand((sub) => {
    sub.setName('remove');
    sub.setDescription('Remove channels from auto-response');
    for (let index = 0; index < CHANNEL_OPTIONS_MAX; index += 1) {
      sub.addChannelOption((opt) => channelOptionsGenerator(opt, index));
    }
    return sub;
  })
  .addSubcommand((sub) => {
    return sub
      .setName('list')
      .setDescription('List the channels where the bot auto-responds to messages');
  })
  .addSubcommand((sub) => {
    return sub
      .setName('modify')
      .setDescription('Add or remove auto-respond channels via select menu UI (first 25 text channels only)');
  });
|
||||
|
||||
export const listenChannelCommand = new SlashCommandBuilder()
|
||||
.setName('listen')
|
||||
.setDescription('Change what channels the bot actively listens to and learns from.')
|
||||
@@ -110,7 +139,8 @@ const commands = [
|
||||
inviteCommand.toJSON(),
|
||||
messageCommand.toJSON(),
|
||||
listenChannelCommand.toJSON(),
|
||||
trainCommand.toJSON(),
|
||||
autoRespondCommand.toJSON(),
|
||||
trainCommand.toJSON()
|
||||
];
|
||||
|
||||
export async function deployCommands(clientId: string) {
|
||||
|
||||
@@ -12,6 +12,11 @@ export class Channel extends BaseEntity {
|
||||
})
|
||||
listen: boolean;
|
||||
|
||||
@Column({
|
||||
default: false,
|
||||
})
|
||||
autoRespond: boolean;
|
||||
|
||||
@ManyToOne(() => Guild, (guild) => guild.channels)
|
||||
guild: Guild;
|
||||
}
|
||||
|
||||
424
src/index.ts
424
src/index.ts
@@ -1,6 +1,8 @@
|
||||
import 'source-map-support/register';
|
||||
import { CONFIG_DIR } from './config/setup';
|
||||
import 'reflect-metadata';
|
||||
import * as Discord from 'discord.js';
|
||||
|
||||
import Markov, {
|
||||
MarkovGenerateOptions,
|
||||
MarkovConstructorOptions,
|
||||
@@ -24,6 +26,7 @@ import {
|
||||
listenChannelCommand,
|
||||
messageCommand,
|
||||
trainCommand,
|
||||
autoRespondCommand,
|
||||
} from './deploy-commands';
|
||||
import { getRandomElement, getVersion, packageJson } from './util';
|
||||
import ormconfig from './ormconfig';
|
||||
@@ -35,6 +38,7 @@ interface MarkovDataCustom {
|
||||
/**
 * A channel entry used to build select-menu options.
 */
interface SelectMenuChannel {
  /** Channel ID; used as the option value. */
  id: string;
  /** Listen flag of the channel, when known. */
  listen?: boolean;
  /** Auto-respond flag of the channel, when known; used as the option's default state. */
  autoRespond?: boolean;
  /** Channel name, when known; used for the option label. */
  name?: string;
}
|
||||
|
||||
@@ -53,11 +57,19 @@ type AgnosticReplyOptions = Omit<Discord.MessageCreateOptions, 'reply' | 'sticke
|
||||
const INVALID_PERMISSIONS_MESSAGE = 'You do not have the permissions for this action.';
|
||||
const INVALID_GUILD_MESSAGE = 'This action must be performed within a server.';
|
||||
|
||||
const rest = new Discord.REST({ version: '10' }).setToken(config.token);
|
||||
const rest = new Discord.REST({
|
||||
version: '10',
|
||||
timeout: 120000, // 120 seconds
|
||||
retries: 3
|
||||
}).setToken(config.token);
|
||||
|
||||
const client = new Discord.Client<true>({
|
||||
failIfNotExists: false,
|
||||
intents: [Discord.GatewayIntentBits.GuildMessages, Discord.GatewayIntentBits.Guilds],
|
||||
intents: [
|
||||
Discord.GatewayIntentBits.GuildMessages,
|
||||
Discord.GatewayIntentBits.Guilds,
|
||||
Discord.GatewayIntentBits.GuildMembers
|
||||
],
|
||||
presence: {
|
||||
activities: [
|
||||
{
|
||||
@@ -114,6 +126,53 @@ async function isValidChannel(channel: Discord.TextBasedChannel): Promise<boolea
|
||||
return dbChannel?.listen || false;
|
||||
}
|
||||
|
||||
/**
 * Tells whether the given channel is stored with its `autoRespond` flag set.
 *
 * @param channel Channel the message was received in.
 * @returns `false` when no guild channel ID can be derived or no stored
 *   channel is found; otherwise the stored `autoRespond` value.
 */
async function isAutoRespondChannel(channel: Discord.TextBasedChannel): Promise<boolean> {
  const guildChannelId = getGuildChannelId(channel);
  if (!guildChannelId) {
    return false;
  }
  const stored = await Channel.findOneBy({ id: guildChannelId });
  if (!stored) {
    return false;
  }
  return stored.autoRespond || false;
}
|
||||
|
||||
/**
 * Resolves every channel of the guild that is stored with `autoRespond: true`
 * into its Discord channel object.
 *
 * @param guild Guild whose auto-respond channels are looked up.
 * @returns The fetched channels that are text channels; `null` results and
 *   channels of any other type are dropped.
 * @throws Re-throws the error of any failed channel fetch after logging it.
 */
async function getAutoRespondChannels(guild: Discord.Guild): Promise<Discord.TextChannel[]> {
  const dbChannels = await Channel.findBy({ guild: { id: guild.id }, autoRespond: true });
  const channels = (
    await Promise.all(
      dbChannels.map(async (dbc) => {
        try {
          // `await` is required: returning the bare promise leaves the try block
          // before a rejection occurs, so the catch below would never run.
          return await guild.channels.fetch(dbc.id);
        } catch (err) {
          L.error({ err, erroredChannel: dbc, channelId: dbc.id }, 'Error fetching channel');
          throw err;
        }
      }),
    )
  ).filter((c): c is Discord.TextChannel => c !== null && c instanceof Discord.TextChannel);
  return channels;
}
|
||||
|
||||
/**
 * Saves one Channel record per given channel with `autoRespond` set to `true`.
 *
 * @param channels Text channels to flag for auto-response.
 * @param guildId ID of the guild the channels are attached to.
 */
async function addAutoRespondChannels(channels: Discord.TextChannel[], guildId: string): Promise<void> {
  const flagged = channels.map(({ id }) =>
    Channel.create({ id, guild: Guild.create({ id: guildId }), autoRespond: true }),
  );
  await Channel.save(flagged);
}
|
||||
|
||||
/**
 * Saves one Channel record per given channel with `autoRespond` set to `false`.
 *
 * @param channels Text channels to unflag from auto-response.
 * @param guildId ID of the guild the channels are attached to.
 */
async function removeAutoRespondChannels(channels: Discord.TextChannel[], guildId: string): Promise<void> {
  const unflagged = channels.map(({ id }) =>
    Channel.create({ id, guild: Guild.create({ id: guildId }), autoRespond: false }),
  );
  await Channel.save(unflagged);
}
|
||||
|
||||
/**
 * Builds the reply text listing the guild's auto-respond channels.
 *
 * @param interaction Command interaction that requested the list.
 * @returns INVALID_GUILD_MESSAGE when the interaction has no guild; otherwise
 *   a summary line followed by one bullet with a channel mention per channel.
 */
async function listAutoRespondChannels(interaction: Discord.CommandInteraction): Promise<string> {
  const { guildId, guild } = interaction;
  if (!guildId || !guild) {
    return INVALID_GUILD_MESSAGE;
  }
  const channels = await getAutoRespondChannels(guild);
  const bullets = channels.map((channel) => `\n • <#${channel.id}>`).join('');
  return `The bot will automatically respond to all messages in ${channels.length} channel(s).${bullets}`;
}
|
||||
|
||||
/**
 * Tells whether a message was written by a human.
 *
 * @param message Message (possibly partial) to inspect.
 * @returns `false` when the author is flagged as a bot or the message is a
 *   system message, `true` otherwise (including when the author is missing).
 */
function isHumanAuthoredMessage(message: Discord.Message | Discord.PartialMessage): boolean {
  const fromBot = message.author?.bot;
  if (fromBot) {
    return false;
  }
  return !message.system;
}
|
||||
@@ -151,7 +210,12 @@ async function getTextChannels(guild: Discord.Guild): Promise<SelectMenuChannel[
|
||||
}));
|
||||
const notFoundDbChannels: SelectMenuChannel[] = textChannels
|
||||
.filter((c) => !foundDbChannels.find((d) => d.id === c.id))
|
||||
.map((c) => ({ id: c.id, listen: false, name: textChannels.find((t) => t.id === c.id)?.name }));
|
||||
.map((c) => ({
|
||||
id: c.id,
|
||||
listen: false,
|
||||
autoRespond: false,
|
||||
name: textChannels.find((t) => t.id === c.id)?.name
|
||||
}));
|
||||
const limitedDbChannels = foundDbChannelsWithName
|
||||
.concat(notFoundDbChannels)
|
||||
.slice(0, MAX_SELECT_OPTIONS);
|
||||
@@ -223,7 +287,7 @@ function isAllowedUser(
|
||||
return true;
|
||||
}
|
||||
|
||||
type MessageCommands = 'respond' | 'train' | 'help' | 'invite' | 'debug' | 'tts' | null;
|
||||
type MessageCommands = 'respond' | 'train' | 'help' | 'invite' | 'debug' | null;
|
||||
|
||||
/**
|
||||
* Reads a new message and checks if and which command it is.
|
||||
@@ -246,8 +310,6 @@ function validateMessage(message: Discord.Message): MessageCommands {
|
||||
command = 'invite';
|
||||
} else if (split[1] === 'debug') {
|
||||
command = 'debug';
|
||||
} else if (split[1] === 'tts') {
|
||||
command = 'tts';
|
||||
}
|
||||
}
|
||||
return command;
|
||||
@@ -272,12 +334,23 @@ function messageToData(message: Discord.Message): AddDataProps {
|
||||
/**
|
||||
* Recursively gets all messages in a text channel's history.
|
||||
*/
|
||||
import { TrainingStateManager } from './training-state';
|
||||
|
||||
async function saveGuildMessageHistory(
|
||||
interaction: Discord.Message | Discord.CommandInteraction,
|
||||
clean = true,
|
||||
): Promise<string> {
|
||||
if (!isModerator(interaction.member)) return INVALID_PERMISSIONS_MESSAGE;
|
||||
if (!interaction.guildId || !interaction.guild) return INVALID_GUILD_MESSAGE;
|
||||
|
||||
const stateManager = new TrainingStateManager(interaction.guildId, CONFIG_DIR);
|
||||
|
||||
// Check if training is already in progress
|
||||
const currentState = stateManager.getState();
|
||||
if (currentState.inProgress) {
|
||||
return `Training is already in progress. Last update: ${currentState.lastUpdate}. Use /train with clean=true to restart.`;
|
||||
}
|
||||
|
||||
const markov = await getMarkovByGuildId(interaction.guildId);
|
||||
const channels = await getValidChannels(interaction.guild);
|
||||
|
||||
@@ -287,12 +360,23 @@ async function saveGuildMessageHistory(
|
||||
}
|
||||
|
||||
if (clean) {
|
||||
L.debug('Deleting old data');
|
||||
L.debug('Deleting old data and resetting state');
|
||||
await markov.delete();
|
||||
stateManager.reset();
|
||||
} else {
|
||||
L.debug('Not deleting old data during training');
|
||||
// Filter out already processed channels when not cleaning
|
||||
const unprocessedChannels = channels.filter(
|
||||
channel => !stateManager.isChannelProcessed(channel.id)
|
||||
);
|
||||
if (unprocessedChannels.length === 0) {
|
||||
return 'All channels have been processed. Use clean=true to retrain.';
|
||||
}
|
||||
channels.splice(0, channels.length, ...unprocessedChannels);
|
||||
}
|
||||
|
||||
stateManager.startTraining();
|
||||
|
||||
const channelIds = channels.map((c) => c.id);
|
||||
L.debug({ channelIds }, `Training from text channels`);
|
||||
|
||||
@@ -332,20 +416,42 @@ async function saveGuildMessageHistory(
|
||||
progressMessage = (await interaction.followUp(updateMessageData)) as Discord.Message;
|
||||
}
|
||||
|
||||
const PAGE_SIZE = 100;
|
||||
const UPDATE_RATE = 1000; // In number of messages processed
|
||||
const PAGE_SIZE = 50; // Reduced page size for better stability
|
||||
const UPDATE_RATE = 500; // More frequent updates
|
||||
const BATCH_SIZE = 100; // Number of messages to process before a small delay
|
||||
const BATCH_DELAY = 100; // Milliseconds to wait between batches
|
||||
const MAX_MEMORY_USAGE = 1024 * 1024 * 1024; // 1GB memory limit
|
||||
|
||||
let lastUpdate = 0;
|
||||
let messagesCount = 0;
|
||||
let firstMessageDate: number | undefined;
|
||||
// eslint-disable-next-line no-restricted-syntax
|
||||
for (const channel of channels) {
|
||||
let oldestMessageID: string | undefined;
|
||||
let keepGoing = true;
|
||||
L.debug({ channelId: channel.id, messagesCount }, `Training from channel`);
|
||||
const channelCreateDate = channel.createdTimestamp;
|
||||
const channelEta = makeEta({ autostart: true, min: 0, max: 1, historyTimeConstant: 30 });
|
||||
let batchCount = 0;
|
||||
|
||||
while (keepGoing) {
|
||||
// Monitor memory usage
|
||||
const getMemoryUsage = () => {
|
||||
const used = process.memoryUsage();
|
||||
return used.heapUsed;
|
||||
};
|
||||
|
||||
// Add delay between batches
|
||||
const processingDelay = () => new Promise(resolve => setTimeout(resolve, BATCH_DELAY));
|
||||
|
||||
try {
|
||||
// eslint-disable-next-line no-restricted-syntax
|
||||
for (const channel of channels) {
|
||||
try {
|
||||
// Check if we should skip this channel (already processed)
|
||||
if (stateManager.isChannelProcessed(channel.id)) {
|
||||
L.debug({ channelId: channel.id }, 'Skipping already processed channel');
|
||||
continue;
|
||||
}
|
||||
let keepGoing = true;
|
||||
let oldestMessageID = stateManager.shouldResumeFromMessage(channel.id);
|
||||
L.debug({ channelId: channel.id, messagesCount }, `Training from channel`);
|
||||
const channelCreateDate = channel.createdTimestamp;
|
||||
const channelEta = makeEta({ autostart: true, min: 0, max: 1, historyTimeConstant: 30 });
|
||||
|
||||
while (keepGoing) {
|
||||
let allBatchMessages = new Discord.Collection<string, Discord.Message<boolean>>();
|
||||
let channelBatchMessages: Discord.Collection<string, Discord.Message<boolean>>;
|
||||
try {
|
||||
@@ -407,15 +513,55 @@ async function saveGuildMessageHistory(
|
||||
|
||||
allBatchMessages = allBatchMessages.concat(channelBatchMessages);
|
||||
|
||||
// Filter and data map messages to be ready for addition to the corpus
|
||||
const humanAuthoredMessages = allBatchMessages
|
||||
.filter((m) => isHumanAuthoredMessage(m))
|
||||
.map(messageToData);
|
||||
L.trace({ oldestMessageID }, `Saving ${humanAuthoredMessages.length} messages`);
|
||||
// eslint-disable-next-line no-await-in-loop
|
||||
await markov.addData(humanAuthoredMessages);
|
||||
L.trace('Finished saving messages');
|
||||
messagesCount += humanAuthoredMessages.length;
|
||||
try {
|
||||
// Check memory usage before processing
|
||||
const memoryUsage = getMemoryUsage();
|
||||
if (memoryUsage > MAX_MEMORY_USAGE) {
|
||||
L.warn('Memory usage too high, waiting for garbage collection');
|
||||
await processingDelay();
|
||||
global.gc?.(); // Optional garbage collection if --expose-gc flag is used
|
||||
}
|
||||
|
||||
// Filter and data map messages to be ready for addition to the corpus
|
||||
const humanAuthoredMessages = allBatchMessages
|
||||
.filter((m) => isHumanAuthoredMessage(m))
|
||||
.map(messageToData);
|
||||
|
||||
// Process messages in smaller batches for stability
|
||||
for (let i = 0; i < humanAuthoredMessages.length; i += BATCH_SIZE) {
|
||||
const batch = humanAuthoredMessages.slice(i, i + BATCH_SIZE);
|
||||
L.trace({ oldestMessageID, batchSize: batch.length }, `Saving batch of messages`);
|
||||
|
||||
try {
|
||||
// eslint-disable-next-line no-await-in-loop
|
||||
await markov.addData(batch);
|
||||
batchCount++;
|
||||
messagesCount += batch.length;
|
||||
|
||||
// Update state after successful batch
|
||||
const lastMessage = allBatchMessages.last();
|
||||
if (lastMessage) {
|
||||
stateManager.updateProgress(channel.id, lastMessage.id, messagesCount);
|
||||
}
|
||||
|
||||
// Add delay between batches
|
||||
if (batchCount % 5 === 0) { // Every 5 batches
|
||||
await processingDelay();
|
||||
}
|
||||
} catch (err) {
|
||||
stateManager.recordError(err as Error, channel.id, oldestMessageID);
|
||||
L.error({ err, batchSize: batch.length }, 'Error saving batch of messages');
|
||||
// Continue with next batch instead of failing completely
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
L.trace('Finished processing message batches');
|
||||
} catch (err) {
|
||||
L.error({ err }, 'Error processing messages');
|
||||
// Wait a bit before continuing to next batch of messages
|
||||
await processingDelay();
|
||||
}
|
||||
const lastMessage = channelBatchMessages.last();
|
||||
|
||||
// Update tracking metrics
|
||||
@@ -457,12 +603,24 @@ async function saveGuildMessageHistory(
|
||||
...updateMessageData,
|
||||
embeds: [new Discord.EmbedBuilder(embedOptions)],
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (err) {
|
||||
L.error({ err }, 'Error processing channel');
|
||||
stateManager.recordError(err as Error);
|
||||
// Continue with next channel
|
||||
}
|
||||
}
|
||||
|
||||
L.info({ channelIds }, `Trained from ${messagesCount} past human authored messages.`);
|
||||
return `Trained from ${messagesCount} past human authored messages.`;
|
||||
L.info({ channelIds }, `Trained from ${messagesCount} past human authored messages.`);
|
||||
stateManager.finishTraining();
|
||||
return `Trained from ${messagesCount} past human authored messages.`;
|
||||
} catch (err) {
|
||||
const error = err as Error;
|
||||
L.error({ err }, 'Error during training completion');
|
||||
stateManager.recordError(error);
|
||||
return `Training encountered an error: ${error.message}. Use clean=false to resume from last checkpoint.`;
|
||||
}
|
||||
}
|
||||
|
||||
interface JSONImport {
|
||||
@@ -481,7 +639,17 @@ async function trainFromAttachmentJson(
|
||||
if (!isModerator(interaction.member)) return INVALID_PERMISSIONS_MESSAGE;
|
||||
if (!interaction.guildId || !interaction.guild) return INVALID_GUILD_MESSAGE;
|
||||
const { guildId } = interaction;
|
||||
|
||||
const stateManager = new TrainingStateManager(guildId, CONFIG_DIR);
|
||||
|
||||
// Check if training is already in progress
|
||||
const currentState = stateManager.getState();
|
||||
if (currentState.inProgress) {
|
||||
return `Training is already in progress. Last update: ${currentState.lastUpdate}. Use clean=true to restart.`;
|
||||
}
|
||||
|
||||
const markov = await getMarkovByGuildId(guildId);
|
||||
stateManager.startTraining();
|
||||
|
||||
let trainingData: AddDataProps[];
|
||||
try {
|
||||
@@ -517,14 +685,49 @@ async function trainFromAttachmentJson(
|
||||
if (clean) {
|
||||
L.debug('Deleting old data');
|
||||
await markov.delete();
|
||||
stateManager.reset();
|
||||
} else {
|
||||
L.debug('Not deleting old data during training');
|
||||
}
|
||||
|
||||
await markov.addData(trainingData);
|
||||
const BATCH_SIZE = 100;
|
||||
const BATCH_DELAY = 100;
|
||||
let processedCount = 0;
|
||||
let batchCount = 0;
|
||||
|
||||
L.info(`Trained from ${trainingData.length} past human authored messages.`);
|
||||
return `Trained from ${trainingData.length} past human authored messages.`;
|
||||
try {
|
||||
// Process messages in batches
|
||||
for (let i = 0; i < trainingData.length; i += BATCH_SIZE) {
|
||||
const batch = trainingData.slice(i, i + BATCH_SIZE);
|
||||
try {
|
||||
await markov.addData(batch);
|
||||
processedCount += batch.length;
|
||||
batchCount++;
|
||||
|
||||
// Update state after successful batch
|
||||
stateManager.updateProgress('json-import', i.toString(), processedCount);
|
||||
|
||||
// Add delay between batches
|
||||
if (batchCount % 5 === 0) {
|
||||
await new Promise(resolve => setTimeout(resolve, BATCH_DELAY));
|
||||
}
|
||||
} catch (err) {
|
||||
L.error({ err, batchIndex: i }, 'Error processing JSON batch');
|
||||
stateManager.recordError(err as Error, 'json-import', i.toString());
|
||||
// Continue with next batch instead of failing completely
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
L.info(`Successfully trained from ${processedCount} messages from JSON.`);
|
||||
stateManager.finishTraining();
|
||||
return `Successfully trained from ${processedCount} messages from JSON.`;
|
||||
} catch (err) {
|
||||
const error = err as Error;
|
||||
L.error({ err }, 'Error during JSON training completion');
|
||||
stateManager.recordError(error);
|
||||
return `Training encountered an error: ${error.message}. Use clean=false to resume from last checkpoint.`;
|
||||
}
|
||||
}
|
||||
|
||||
interface GenerateResponse {
|
||||
@@ -534,7 +737,6 @@ interface GenerateResponse {
|
||||
}
|
||||
|
||||
interface GenerateOptions {
|
||||
tts?: boolean;
|
||||
debug?: boolean;
|
||||
startSeed?: string;
|
||||
}
|
||||
@@ -551,7 +753,7 @@ async function generateResponse(
|
||||
options?: GenerateOptions,
|
||||
): Promise<GenerateResponse> {
|
||||
L.debug({ options }, 'Responding...');
|
||||
const { tts = false, debug = false, startSeed } = options || {};
|
||||
const { debug = false, startSeed } = options || {};
|
||||
if (!interaction.guildId) {
|
||||
L.warn('Received an interaction without a guildId');
|
||||
return { error: { content: INVALID_GUILD_MESSAGE } };
|
||||
@@ -568,7 +770,6 @@ async function generateResponse(
|
||||
L.info({ string: response.string }, 'Generated response text');
|
||||
L.debug({ response }, 'Generated response object');
|
||||
const messageOpts: AgnosticReplyOptions = {
|
||||
tts,
|
||||
allowedMentions: { repliedUser: false, parse: [] },
|
||||
};
|
||||
const attachmentUrls = response.refs
|
||||
@@ -652,12 +853,17 @@ function helpMessage(): AgnosticReplyOptions {
|
||||
.addFields([
|
||||
{
|
||||
name: `${config.messageCommandPrefix} or /${messageCommand.name}`,
|
||||
value: `Generates a sentence to say based on the chat database. Send your message as TTS to recieve it as TTS.`,
|
||||
value: `Generates a sentence based on the chat database.`,
|
||||
},
|
||||
|
||||
{
|
||||
name: `/${listenChannelCommand.name}`,
|
||||
value: `Add, remove, list, or modify the list of channels the bot listens to.`,
|
||||
value: `Add, remove, list, or modify the list of channels the bot listens to and learns from.`,
|
||||
},
|
||||
|
||||
{
|
||||
name: `/autorespond`,
|
||||
value: `Add, remove, list, or modify the list of channels where the bot will automatically respond to all messages.`,
|
||||
},
|
||||
|
||||
{
|
||||
@@ -674,11 +880,6 @@ function helpMessage(): AgnosticReplyOptions {
|
||||
name: `${config.messageCommandPrefix} debug or /${messageCommand.name} debug: True`,
|
||||
value: `Runs the ${config.messageCommandPrefix} command and follows it up with debug info.`,
|
||||
},
|
||||
|
||||
{
|
||||
name: `${config.messageCommandPrefix} tts or /${messageCommand.name} tts: True`,
|
||||
value: `Runs the ${config.messageCommandPrefix} command and reads it with text-to-speech.`,
|
||||
},
|
||||
])
|
||||
.setFooter({
|
||||
text: `${packageJson().name} ${getVersion()} by ${
|
||||
@@ -694,12 +895,11 @@ function generateInviteUrl(): string {
|
||||
return client.generateInvite({
|
||||
scopes: [Discord.OAuth2Scopes.Bot, Discord.OAuth2Scopes.ApplicationsCommands],
|
||||
permissions: [
|
||||
'ViewChannel',
|
||||
'SendMessages',
|
||||
'SendTTSMessages',
|
||||
'AttachFiles',
|
||||
'ReadMessageHistory',
|
||||
],
|
||||
'ViewChannel',
|
||||
'SendMessages',
|
||||
'AttachFiles',
|
||||
'ReadMessageHistory'
|
||||
],
|
||||
});
|
||||
}
|
||||
|
||||
@@ -789,11 +989,6 @@ client.on('messageCreate', async (message) => {
|
||||
const generatedResponse = await generateResponse(message);
|
||||
await handleResponseMessage(generatedResponse, message);
|
||||
}
|
||||
if (command === 'tts') {
|
||||
L.debug('Responding to legacy command tts');
|
||||
const generatedResponse = await generateResponse(message, { tts: true });
|
||||
await handleResponseMessage(generatedResponse, message);
|
||||
}
|
||||
if (command === 'debug') {
|
||||
L.debug('Responding to legacy command debug');
|
||||
const generatedResponse = await generateResponse(message, { debug: true });
|
||||
@@ -802,11 +997,23 @@ client.on('messageCreate', async (message) => {
|
||||
if (command === null) {
|
||||
if (isHumanAuthoredMessage(message)) {
|
||||
if (client.user && message.mentions.has(client.user)) {
|
||||
// Check if response channels are configured and if this channel is allowed
|
||||
if (config.responseChannelIds.length > 0 && !config.responseChannelIds.includes(message.channel.id)) {
|
||||
L.debug('Ignoring mention in non-response channel');
|
||||
return;
|
||||
}
|
||||
|
||||
L.debug('Responding to mention');
|
||||
// <@!278354154563567636> how are you doing?
|
||||
const startSeed = message.content.replace(/<@!\d+>/g, '').trim();
|
||||
const generatedResponse = await generateResponse(message, { startSeed });
|
||||
await handleResponseMessage(generatedResponse, message);
|
||||
} else if (await isAutoRespondChannel(message.channel)) {
|
||||
// Auto-respond to all messages in configured channels using message content as context
|
||||
L.debug('Auto-responding in configured channel with context');
|
||||
const startSeed = message.content.trim();
|
||||
const generatedResponse = await generateResponse(message, { startSeed });
|
||||
await handleResponseMessage(generatedResponse, message);
|
||||
}
|
||||
|
||||
if (await isValidChannel(message.channel)) {
|
||||
@@ -848,7 +1055,7 @@ client.on('threadDelete', async (thread) => {
|
||||
await markov.removeTags([thread.id]);
|
||||
});
|
||||
|
||||
// eslint-disable-next-line consistent-return
|
||||
|
||||
client.on('interactionCreate', async (interaction) => {
|
||||
if (interaction.isChatInputCommand()) {
|
||||
L.info({ command: interaction.commandName }, 'Recieved slash command');
|
||||
@@ -859,23 +1066,12 @@ client.on('interactionCreate', async (interaction) => {
|
||||
await interaction.reply(inviteMessage());
|
||||
} else if (interaction.commandName === messageCommand.name) {
|
||||
await interaction.deferReply();
|
||||
const tts = interaction.options.getBoolean('tts') || false;
|
||||
const debug = interaction.options.getBoolean('debug') || false;
|
||||
const startSeed = interaction.options.getString('seed')?.trim() || undefined;
|
||||
const generatedResponse = await generateResponse(interaction, { tts, debug, startSeed });
|
||||
const generatedResponse = await generateResponse(interaction, { debug, startSeed });
|
||||
|
||||
/**
|
||||
* TTS doesn't work when using editReply, so instead we use delete + followUp
|
||||
* However, delete + followUp is ugly and shows the bot replying to "Message could not be loaded.",
|
||||
* so we avoid it if possible
|
||||
*/
|
||||
if (generatedResponse.message) {
|
||||
if (generatedResponse.message.tts) {
|
||||
await interaction.deleteReply();
|
||||
await interaction.followUp(generatedResponse.message);
|
||||
} else {
|
||||
await interaction.editReply(generatedResponse.message);
|
||||
}
|
||||
await interaction.editReply(generatedResponse.message);
|
||||
} else {
|
||||
await interaction.deleteReply();
|
||||
}
|
||||
@@ -943,6 +1139,67 @@ client.on('interactionCreate', async (interaction) => {
|
||||
ephemeral: true,
|
||||
});
|
||||
}
|
||||
} else if (interaction.commandName === autoRespondCommand.name) {
|
||||
await interaction.deferReply();
|
||||
const subCommand = interaction.options.getSubcommand(true) as 'add' | 'remove' | 'list' | 'modify';
|
||||
|
||||
if (subCommand === 'list') {
|
||||
const reply = await listAutoRespondChannels(interaction);
|
||||
await interaction.editReply(reply);
|
||||
} else if (subCommand === 'add') {
|
||||
if (!isModerator(interaction.member)) {
|
||||
return handleUnprivileged(interaction);
|
||||
}
|
||||
if (!interaction.guildId) {
|
||||
return handleNoGuild(interaction);
|
||||
}
|
||||
const channels = getChannelsFromInteraction(interaction);
|
||||
await addAutoRespondChannels(channels, interaction.guildId);
|
||||
await interaction.editReply(
|
||||
`Added ${channels.length} text channels to auto-respond list.`
|
||||
);
|
||||
} else if (subCommand === 'remove') {
|
||||
if (!isModerator(interaction.member)) {
|
||||
return handleUnprivileged(interaction);
|
||||
}
|
||||
if (!interaction.guildId) {
|
||||
return handleNoGuild(interaction);
|
||||
}
|
||||
const channels = getChannelsFromInteraction(interaction);
|
||||
await removeAutoRespondChannels(channels, interaction.guildId);
|
||||
await interaction.editReply(
|
||||
`Removed ${channels.length} text channels from auto-respond list.`
|
||||
);
|
||||
} else if (subCommand === 'modify') {
|
||||
if (!interaction.guild) {
|
||||
return handleNoGuild(interaction);
|
||||
}
|
||||
if (!isModerator(interaction.member)) {
|
||||
await handleUnprivileged(interaction);
|
||||
}
|
||||
await interaction.deleteReply();
|
||||
const dbTextChannels = await getTextChannels(interaction.guild);
|
||||
const row = new Discord.ActionRowBuilder<Discord.StringSelectMenuBuilder>().addComponents(
|
||||
new Discord.StringSelectMenuBuilder()
|
||||
.setCustomId('autorespond-modify-select')
|
||||
.setPlaceholder('Nothing selected')
|
||||
.setMinValues(0)
|
||||
.setMaxValues(dbTextChannels.length)
|
||||
.addOptions(
|
||||
dbTextChannels.map((c) => ({
|
||||
label: `#${c.name}` || c.id,
|
||||
value: c.id,
|
||||
default: c.autoRespond || false,
|
||||
})),
|
||||
),
|
||||
);
|
||||
|
||||
await interaction.followUp({
|
||||
content: 'Select which channels you would like the bot to auto-respond in',
|
||||
components: [row],
|
||||
ephemeral: true,
|
||||
});
|
||||
}
|
||||
} else if (interaction.commandName === trainCommand.name) {
|
||||
await interaction.deferReply();
|
||||
const clean = interaction.options.getBoolean('clean') ?? true;
|
||||
@@ -990,6 +1247,37 @@ client.on('interactionCreate', async (interaction) => {
|
||||
content: 'Updated actively listened to channels list.',
|
||||
ephemeral: true,
|
||||
});
|
||||
} else if (interaction.customId === 'autorespond-modify-select') {
|
||||
await interaction.deferUpdate();
|
||||
const { guild } = interaction;
|
||||
if (!isModerator(interaction.member)) {
|
||||
return handleUnprivileged(interaction, false);
|
||||
}
|
||||
if (!guild) {
|
||||
return handleNoGuild(interaction, false);
|
||||
}
|
||||
|
||||
const allChannels =
|
||||
(interaction.component as Discord.StringSelectMenuComponent).options?.map((o) => o.value) ||
|
||||
[];
|
||||
const selectedChannelIds = interaction.values;
|
||||
|
||||
const textChannels = (
|
||||
await Promise.all(
|
||||
allChannels.map(async (c) => {
|
||||
return guild.channels.fetch(c);
|
||||
}),
|
||||
)
|
||||
).filter((c): c is Discord.TextChannel => c !== null && c instanceof Discord.TextChannel);
|
||||
const unselectedChannels = textChannels.filter((t) => !selectedChannelIds.includes(t.id));
|
||||
const selectedChannels = textChannels.filter((t) => selectedChannelIds.includes(t.id));
|
||||
await addAutoRespondChannels(selectedChannels, guild.id);
|
||||
await removeAutoRespondChannels(unselectedChannels, guild.id);
|
||||
|
||||
await interaction.followUp({
|
||||
content: 'Updated auto-respond channels list.',
|
||||
ephemeral: true,
|
||||
});
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
391
src/train.ts
Normal file
391
src/train.ts
Normal file
@@ -0,0 +1,391 @@
|
||||
import 'source-map-support/register';
|
||||
import 'reflect-metadata';
|
||||
import Markov, { MarkovConstructorOptions, AddDataProps } from 'markov-strings-db';
|
||||
import { DataSource } from 'typeorm';
|
||||
import { promises as fs } from 'fs';
|
||||
import path from 'path';
|
||||
import { config } from './config';
|
||||
import ormconfig from './ormconfig';
|
||||
import { Guild } from './entity/Guild';
|
||||
import { Channel } from './entity/Channel';
|
||||
import L from './logger';
|
||||
import { MarkovDataCustom } from './types';
|
||||
import { TrainingStateManager } from './training-state';
|
||||
import { CONFIG_DIR } from './config/setup';
|
||||
|
||||
// Options shared by every Markov instance created in this module;
// the state size is taken from the application config.
const markovOpts: MarkovConstructorOptions = { stateSize: config.stateSize };
|
||||
|
||||
// Constants for batch processing
|
||||
const BATCH_SIZE = 100; // Process messages in batches
|
||||
const BATCH_DELAY = 100; // Milliseconds to wait between batches
|
||||
const MAX_MEMORY_USAGE = 1024 * 1024 * 1024; // 1GB memory limit
|
||||
|
||||
/** Returns the number of heap bytes currently in use by this process. */
const getMemoryUsage = () => process.memoryUsage().heapUsed;
|
||||
|
||||
/** Resolves after BATCH_DELAY milliseconds; used to throttle work between batches. */
const processingDelay = () =>
  new Promise((wake) => {
    setTimeout(wake, BATCH_DELAY);
  });
|
||||
|
||||
/**
 * Builds the Markov instance for a guild and connects it to the database.
 *
 * @param guildId Discord guild ID, used as the Markov instance ID.
 * @returns The Markov instance once its `setup()` call has completed.
 */
async function getMarkovByGuildId(guildId: string): Promise<Markov> {
  const options = { ...markovOpts, id: guildId };
  const instance = new Markov({ id: guildId, options });
  L.trace({ guildId }, 'Setting up markov instance');
  // setup() connects the markov instance to the DB to assign it an ID
  await instance.setup();
  return instance;
}
|
||||
|
||||
/** Shape of a single entry in an imported JSON training file. */
interface JSONImport {
  /** Message text used as the training string. */
  message: string;
  /** Optional attachment strings stored with the message as custom data. */
  attachments?: string[];
}
|
||||
|
||||
/** Returns true when `value` is an array whose elements are all strings. */
function isStringArray(value: unknown): value is string[] {
  return Array.isArray(value) && value.every((item) => typeof item === 'string');
}

/**
 * Converts raw parsed JSON entries into Markov training data, skipping invalid ones.
 *
 * An entry is skipped (with a debug log) when it is not an object, has no non-empty
 * string `message`, or has an `attachments` value that is not an array of strings.
 */
function toTrainingData(entries: readonly unknown[], guildId: string): AddDataProps[] {
  const trainingData: AddDataProps[] = [];
  entries.forEach((entry, index) => {
    // Entries come from untrusted JSON: never assume they are objects.
    const candidate = (typeof entry === 'object' && entry !== null ? entry : {}) as {
      message?: unknown;
      attachments?: unknown;
    };
    const { message, attachments } = candidate;

    if (!message || typeof message !== 'string') {
      L.debug({ index }, 'Skipping entry without valid message');
      return;
    }
    if (attachments != null && !isStringArray(attachments)) {
      L.debug({ index }, 'Skipping entry with invalid attachments');
      return;
    }

    let custom: MarkovDataCustom | undefined;
    if (isStringArray(attachments) && attachments.length) {
      custom = { attachments };
    }
    trainingData.push({ string: message, custom, tags: [guildId] });
  });
  return trainingData;
}

/**
 * Train a guild's Markov chain from a JSON file containing messages.
 *
 * The file must contain a JSON array of `{ message, attachments? }` entries; invalid
 * entries are skipped. Messages are added in batches of BATCH_SIZE, and a batch that
 * fails is logged and skipped so the remaining batches are still processed.
 *
 * @param guildId Discord guild ID; also used as the tag on every imported entry.
 * @param jsonPath Path of the JSON file to import.
 * @param clean When true (default), the guild's existing Markov data is deleted first.
 * @returns A human-readable status message. On success it starts with
 *   "Successfully trained from" followed by the message count. Read and parse failures
 *   are reported through the returned message rather than thrown.
 */
async function trainFromJson(guildId: string, jsonPath: string, clean = true): Promise<string> {
  const markov = await getMarkovByGuildId(guildId);

  let trainingData: AddDataProps[];
  try {
    const fileContent = await fs.readFile(jsonPath, 'utf-8');
    const importData: unknown = JSON.parse(fileContent);

    // Valid JSON is not necessarily an array; reject other shapes explicitly.
    if (!Array.isArray(importData)) {
      L.error({ jsonPath }, 'Expected the JSON file to contain an array of messages');
      return 'The provided JSON file has invalid formatting. See the logs for details.';
    }

    trainingData = toTrainingData(importData, guildId);
  } catch (err) {
    L.error(err);
    if (err instanceof SyntaxError) {
      return 'The provided JSON file has invalid formatting. See the logs for details.';
    }
    return `Error reading file: ${err instanceof Error ? err.message : 'Unknown error'}`;
  }

  if (clean) {
    L.debug('Deleting old data');
    await markov.delete();
  } else {
    L.debug('Not deleting old data during training');
  }

  let processedCount = 0;
  let batchCount = 0;
  let failedBatches = 0;
  const totalMessages = trainingData.length;

  // Process messages in batches
  for (let i = 0; i < totalMessages; i += BATCH_SIZE) {
    try {
      // Check memory usage
      if (getMemoryUsage() > MAX_MEMORY_USAGE) {
        L.warn('Memory usage too high, waiting for garbage collection');
        await processingDelay();
        global.gc?.(); // Optional garbage collection if --expose-gc flag is used
      }

      const batch = trainingData.slice(i, i + BATCH_SIZE);
      await markov.addData(batch);

      processedCount += batch.length;
      batchCount++;

      // Log progress and yield every 5 batches
      if (batchCount % 5 === 0) {
        const progress = ((processedCount / totalMessages) * 100).toFixed(2);
        L.info(`Progress: ${progress}% (${processedCount}/${totalMessages} messages)`);
        await processingDelay();
      }
    } catch (err) {
      // Keep going with the next batch instead of failing the whole import.
      failedBatches++;
      L.error({ err, batchIndex: i }, 'Error processing batch');
      await processingDelay();
    }
  }

  if (failedBatches > 0) {
    L.warn({ failedBatches }, 'Some batches could not be processed and were skipped');
  }

  // NOTE: trainFromDirectory parses the count out of this message; keep the format stable.
  L.info(`Successfully trained from ${processedCount} messages.`);
  return `Successfully trained from ${processedCount} messages.`;
}
|
||||
|
||||
/*
 * Documentation for trainFromDirectory (defined further below):
 * Train from all JSON files in a directory
 * @param guildId The Discord guild ID
 * @param dirPath Path to directory containing JSON files
 * @param clean Whether to clean existing data before training
 */
|
||||
/**
 * Reports whether a process with the given PID currently exists.
 *
 * Non-positive or non-integer PIDs (e.g. from a corrupt lock file) are treated as
 * not running; passing them to process.kill would probe a process group instead.
 */
function isProcessAlive(pid: number): boolean {
  if (!Number.isInteger(pid) || pid <= 0) {
    return false;
  }
  try {
    // Signal 0 performs the existence/permission check without sending a signal.
    process.kill(pid, 0);
    return true;
  } catch (err) {
    // EPERM means the process exists but belongs to another user: still alive.
    return (err as NodeJS.ErrnoException).code === 'EPERM';
  }
}

/**
 * Acquire a lock file for training to prevent concurrent processes.
 *
 * The lock file lives in CONFIG_DIR, is named after the guild, and contains the PID
 * of the owning process. A lock whose recorded process no longer exists is taken over.
 *
 * @param guildId Discord guild ID the lock is scoped to.
 * @returns `true` if this process now holds the lock; `false` if another live process
 *   holds it or the lock file could not be read or written.
 */
async function acquireTrainingLock(guildId: string): Promise<boolean> {
  const lockPath = path.join(CONFIG_DIR, `${guildId}_training.lock`);
  const ownPid = process.pid.toString();

  try {
    // 'wx' fails with EEXIST when the file already exists, so creation is exclusive.
    await fs.writeFile(lockPath, ownPid, { flag: 'wx' });
    return true;
  } catch (err) {
    if ((err as NodeJS.ErrnoException).code !== 'EEXIST') {
      return false;
    }
  }

  // A lock file already exists: find out whether its owner is still around.
  try {
    const pid = Number.parseInt(await fs.readFile(lockPath, 'utf-8'), 10);
    if (isProcessAlive(pid)) {
      return false; // Process is still running
    }

    // Stale lock: remove it and re-create it exclusively, so that if another process
    // takes it over at the same moment only one of the two writes succeeds.
    await fs.unlink(lockPath);
    await fs.writeFile(lockPath, ownPid, { flag: 'wx' });
    return true;
  } catch {
    // Error reading/writing lock file, or another process won the takeover race
    return false;
  }
}
|
||||
|
||||
/**
 * Release the training lock file for a guild.
 *
 * Removal is best-effort: a failure to delete the file is ignored.
 *
 * @param guildId Discord guild ID whose lock file should be removed.
 */
async function releaseTrainingLock(guildId: string): Promise<void> {
  const lockFile = path.join(CONFIG_DIR, `${guildId}_training.lock`);
  await fs.unlink(lockFile).catch(() => {
    // Ignore errors during cleanup
  });
}
|
||||
|
||||
/**
 * Sanitize and validate a directory path.
 *
 * @param dirPath Directory path, absolute or relative to the current working directory.
 * @returns The normalized absolute path of the directory.
 * @throws Error if the path lies outside the current working directory, does not
 *   exist, is not a directory, or is not readable.
 */
async function validateDirectoryPath(dirPath: string): Promise<string> {
  // Resolve to a normalized absolute path
  const normalizedPath = path.normalize(path.resolve(dirPath));

  // Prevent directory traversal. A plain string-prefix test would also accept sibling
  // directories that merely share a name prefix with the cwd (e.g. "/srv/app-evil"
  // for cwd "/srv/app"), so compare path segments via path.relative instead.
  const relativeToCwd = path.relative(process.cwd(), normalizedPath);
  const escapesCwd =
    relativeToCwd === '..' ||
    relativeToCwd.startsWith(`..${path.sep}`) ||
    path.isAbsolute(relativeToCwd); // e.g. a different drive on Windows
  if (escapesCwd) {
    throw new Error('Directory must be within current working directory');
  }

  // Verify directory exists and is accessible
  try {
    const stats = await fs.stat(normalizedPath);
    if (!stats.isDirectory()) {
      throw new Error('Path is not a directory');
    }
    await fs.access(normalizedPath, fs.constants.R_OK);
    return normalizedPath;
  } catch (err) {
    throw new Error(`Invalid directory path: ${err instanceof Error ? err.message : 'Unknown error'}`);
  }
}
|
||||
|
||||
/**
 * Train from all JSON files in a directory.
 *
 * Takes the per-guild training lock, then passes every file whose name ends in
 * ".json" to trainFromJson. A file that fails is logged and recorded on the training
 * state, and processing continues with the next file. The lock is released and the
 * training state marked finished when the function completes or the process receives
 * SIGINT/SIGTERM.
 *
 * @param guildId The Discord guild ID
 * @param dirPath Path to directory containing JSON files
 * @param clean Whether to clean existing data before training; the deletion is
 *   applied with the first file that is imported successfully.
 * @returns A human-readable status message; failures are reported through the message
 *   rather than thrown.
 */
async function trainFromDirectory(
  guildId: string,
  dirPath: string,
  clean = true,
): Promise<string> {
  L.debug({ guildId, dirPath, clean }, 'Starting directory training');
  const stateManager = new TrainingStateManager(guildId, CONFIG_DIR);

  // Releases the lock and marks the persisted state as no longer in progress.
  const cleanup = async (): Promise<void> => {
    try {
      await releaseTrainingLock(guildId);
      stateManager.finishTraining();
    } catch (err) {
      L.error({ err }, 'Error during cleanup');
    }
  };

  // Installing a SIGINT/SIGTERM listener disables Node's default "exit on signal"
  // behavior, so the handler must exit explicitly. Otherwise training would keep
  // running after the lock has been released.
  const onTerminate = (signal: NodeJS.Signals): void => {
    void cleanup().finally(() => {
      // Conventional exit status: 128 + signal number (SIGINT = 2, SIGTERM = 15).
      process.exit(signal === 'SIGINT' ? 130 : 143);
    });
  };

  // trainFromJson reports its outcome as text; only this exact prefix means success.
  const successPattern = /^Successfully trained from (\d+) messages/;

  let lockHeld = false;
  try {
    // Try to acquire lock
    if (!(await acquireTrainingLock(guildId))) {
      return 'Another training process is already running. Please wait for it to complete.';
    }
    lockHeld = true;

    // Only a process that owns the lock may release it on termination.
    process.once('SIGINT', onTerminate);
    process.once('SIGTERM', onTerminate);

    // Always reset state at the start of training
    stateManager.reset();

    // Validate and normalize directory path
    const absolutePath = await validateDirectoryPath(dirPath);

    // Get all JSON files in the directory
    L.trace({ dirPath: absolutePath }, 'Reading directory');
    const files = await fs.readdir(absolutePath);
    const jsonFiles = files.filter((file) => file.toLowerCase().endsWith('.json'));

    if (jsonFiles.length === 0) {
      L.warn({ dirPath: absolutePath }, 'No JSON files found in directory');
      return 'No JSON files found in the specified directory.';
    }

    let totalProcessed = 0;
    let batchCount = 0;
    // Existing data must be cleaned exactly once: with the first file that is
    // actually imported, which is not necessarily the first file in the list.
    let cleanPending = clean;
    L.info({ fileCount: jsonFiles.length }, 'Found JSON files to process');

    stateManager.startTraining();

    for (const [i, file] of jsonFiles.entries()) {
      const jsonPath = path.join(absolutePath, file);
      L.debug({ file, progress: `${i + 1}/${jsonFiles.length}` }, 'Processing file');

      try {
        // Check memory usage before processing file
        if (getMemoryUsage() > MAX_MEMORY_USAGE) {
          L.warn('Memory usage too high, waiting for garbage collection');
          await processingDelay();
          global.gc?.(); // Optional garbage collection if --expose-gc flag is used
        }

        // Check if we should skip this file (already processed).
        // NOTE(review): the state is reset above and nothing in this function marks a
        // file as complete, so this check looks like it can never match — confirm the
        // intended resume behavior.
        if (!clean && stateManager.isChannelProcessed(file)) {
          L.debug({ file }, 'Skipping already processed file');
          continue;
        }

        const result = await trainFromJson(guildId, jsonPath, cleanPending);

        // Extract the number of processed messages from the result string. Anything
        // that is not the success message is a failure report: taking "the first
        // digits" from it would count numbers from an error text or a file path.
        const match = successPattern.exec(result);
        if (!match) {
          throw new Error(result);
        }
        cleanPending = false;

        const processed = Number.parseInt(match[1] ?? '0', 10);
        totalProcessed += processed;
        batchCount++;

        // Update state after each file
        stateManager.updateProgress('json-import', file, totalProcessed);
        L.trace({ file, processed, totalProcessed }, 'File processing complete');

        // Add delay between files
        if (batchCount % 5 === 0) {
          await processingDelay();
        }

        // Give the garbage collector a chance between files when it is exposed
        global.gc?.();
      } catch (err) {
        const error = err instanceof Error ? err : new Error(String(err));
        L.error(
          { error: error.message, file, stack: error.stack },
          'Error processing JSON file',
        );
        stateManager.recordError(error, 'json-import', file);
        // Pause after an error, then continue with the next file instead of failing
        // completely
        await processingDelay();
      }
    }

    const summary = { totalProcessed, fileCount: jsonFiles.length };
    L.info(summary, 'Directory training complete');
    return `Successfully trained from ${totalProcessed} messages across ${jsonFiles.length} files.`;
  } catch (err) {
    const error = err instanceof Error ? err : new Error(String(err));
    L.error(
      { error: error.message, stack: error.stack, dirPath },
      'Error during directory training',
    );
    stateManager.recordError(error);
    return `Training encountered an error: ${error.message}. Use clean=false to resume from last checkpoint.`;
  } finally {
    if (lockHeld) {
      // Clean up regardless of success/failure
      await cleanup();
      // Remove process termination handlers
      process.off('SIGINT', onTerminate);
      process.off('SIGTERM', onTerminate);
    }
  }
}
|
||||
|
||||
/**
 * Command-line entry point: trains a guild's Markov chain from a JSON file or from a
 * directory of JSON files.
 *
 * Reads the guild ID and input path from the first two command-line arguments, plus
 * the optional `--keep-existing` and `--directory` flags. Prints usage and exits with
 * status 1 when fewer than two arguments are given.
 */
async function main(): Promise<void> {
  const args = process.argv.slice(2);
  if (args.length < 2) {
    console.log('Usage: node train.js <guildId> <path> [--keep-existing] [--directory]');
    console.log('Options:');
    console.log(' --keep-existing Keep existing training data');
    console.log(' --directory Process all JSON files in the specified directory');
    process.exit(1);
  }

  const guildId = args[0];
  const inputPath = args[1];
  const keepExisting = args.includes('--keep-existing');
  const isDirectory = args.includes('--directory');

  const dataSourceOptions = Markov.extendDataSourceOptions(ormconfig);
  const dataSource = new DataSource(dataSourceOptions);
  await dataSource.initialize();

  try {
    // Ensure guild exists in DB
    await Guild.upsert(Guild.create({ id: guildId }), ['id']);

    const result = isDirectory
      ? await trainFromDirectory(guildId, inputPath, !keepExisting)
      : await trainFromJson(guildId, inputPath, !keepExisting);
    console.log(result);
  } finally {
    // Close the database connection even when training or the upsert throws.
    await dataSource.destroy();
  }
}
|
||||
|
||||
// Run the CLI only when this file is executed directly, not when it is imported.
if (require.main === module) {
  main().catch((err: unknown) => {
    console.error(err);
    // Report the failure to the calling shell instead of exiting with status 0.
    process.exitCode = 1;
  });
}
|
||||
113
src/training-state.ts
Normal file
113
src/training-state.ts
Normal file
@@ -0,0 +1,113 @@
|
||||
import fs from 'fs-extra';
|
||||
import path from 'path';
|
||||
import { TrainingState } from './types';
|
||||
import L from './logger';
|
||||
|
||||
/**
 * Persists training progress for a single guild as a JSON file, so progress and
 * errors are still available after the process ends.
 *
 * The state file is stored under the config directory in a "training-state" folder
 * and is named after the guild ID. Every mutating method writes the state
 * synchronously; write failures are logged, not thrown.
 */
export class TrainingStateManager {
  private readonly stateFile: string;

  private state: TrainingState;

  /**
   * @param guildId Discord guild ID the state belongs to.
   * @param configDir Directory that contains the "training-state" folder.
   */
  constructor(guildId: string, configDir: string = 'config') {
    this.stateFile = path.join(configDir, 'training-state', `${guildId}.json`);

    // Initialize with default state
    this.state = TrainingStateManager.createDefaultState(guildId);

    // Ensure directory exists
    fs.ensureDirSync(path.dirname(this.stateFile));

    // Load existing state if available
    this.loadState();
  }

  /** Builds the initial state for a guild: nothing processed, not in progress. */
  private static createDefaultState(guildId: string): TrainingState {
    return {
      guildId,
      processedChannels: [],
      totalMessages: 0,
      lastUpdate: new Date().toISOString(),
      inProgress: false,
    };
  }

  /**
   * Merges a previously saved state file over the defaults, if one exists.
   * A missing, unreadable, or malformed file leaves the default state in place.
   */
  private loadState(): void {
    try {
      if (!fs.existsSync(this.stateFile)) {
        return;
      }

      const savedState: unknown = fs.readJsonSync(this.stateFile);
      // The file is external input: only merge it when it is a plain JSON object.
      if (typeof savedState !== 'object' || savedState === null || Array.isArray(savedState)) {
        L.warn({ stateFile: this.stateFile }, 'Ignoring malformed training state file');
        return;
      }

      const merged = { ...this.state, ...(savedState as Partial<TrainingState>) };
      // Other methods call includes()/push() on this field, so it must be an array.
      if (!Array.isArray(merged.processedChannels)) {
        merged.processedChannels = [];
      }
      this.state = merged;
      L.info({ guildId: this.state.guildId }, 'Loaded existing training state');
    } catch (err) {
      L.error({ err }, 'Error loading training state');
      // Keep using default state if load fails
    }
  }

  /** Writes the current state to the state file; failures are logged, not thrown. */
  private saveState(): void {
    try {
      fs.writeJsonSync(this.stateFile, this.state, { spaces: 2 });
    } catch (err) {
      L.error({ err }, 'Error saving training state');
    }
  }

  /** Marks training as in progress and clears any previously recorded error. */
  public startTraining(): void {
    this.state.inProgress = true;
    this.state.error = undefined;
    this.state.lastUpdate = new Date().toISOString();
    this.saveState();
  }

  /** Marks training as no longer in progress. */
  public finishTraining(): void {
    this.state.inProgress = false;
    this.state.lastUpdate = new Date().toISOString();
    this.saveState();
  }

  /**
   * Records the most recent position reached by the training run.
   *
   * @param channelId Identifier stored as the last channel.
   * @param messageId Identifier stored as the last message.
   * @param messagesProcessed Total stored as the message count (replaces, not adds).
   */
  public updateProgress(channelId: string, messageId: string, messagesProcessed: number): void {
    this.state.lastChannelId = channelId;
    this.state.lastMessageId = messageId;
    this.state.totalMessages = messagesProcessed;
    this.state.lastUpdate = new Date().toISOString();
    this.saveState();
  }

  /** Adds `channelId` to the list of processed channels if it is not already listed. */
  public markChannelComplete(channelId: string): void {
    if (!this.state.processedChannels.includes(channelId)) {
      this.state.processedChannels.push(channelId);
      this.saveState();
    }
  }

  /**
   * Stores the message of `error`, together with the optional channel and message
   * identifiers and the current timestamp, as the state's error.
   */
  public recordError(error: Error, channelId?: string, messageId?: string): void {
    this.state.error = {
      message: error.message,
      channelId,
      messageId,
      timestamp: new Date().toISOString(),
    };
    this.saveState();
  }

  /** Returns whether `channelId` has been marked complete. */
  public isChannelProcessed(channelId: string): boolean {
    return this.state.processedChannels.includes(channelId);
  }

  /**
   * Returns the last recorded message ID when training is in progress and
   * `channelId` is the last recorded channel; otherwise `undefined`.
   */
  public shouldResumeFromMessage(channelId: string): string | undefined {
    if (this.state.inProgress && this.state.lastChannelId === channelId) {
      return this.state.lastMessageId;
    }
    return undefined;
  }

  /**
   * Returns a copy of the current state. The nested array and error object are copied
   * too, so callers cannot mutate the manager's internal state through the result.
   */
  public getState(): TrainingState {
    const copy: TrainingState = {
      ...this.state,
      processedChannels: [...this.state.processedChannels],
    };
    if (this.state.error) {
      copy.error = { ...this.state.error };
    }
    return copy;
  }

  /** Replaces the state with the defaults for this guild and persists it. */
  public reset(): void {
    this.state = TrainingStateManager.createDefaultState(this.state.guildId);
    this.saveState();
  }
}
|
||||
19
src/types.ts
Normal file
19
src/types.ts
Normal file
@@ -0,0 +1,19 @@
|
||||
/** Custom data stored alongside a Markov training string. */
export interface MarkovDataCustom {
  /** Attachment strings that accompanied the message. */
  attachments: string[];
}
|
||||
|
||||
/** Persisted progress of a training run for one guild. */
export interface TrainingState {
  /** Discord guild ID this state belongs to. */
  guildId: string;
  /** Identifiers that have been marked as completely processed. */
  processedChannels: string[];
  /** Total number of messages processed so far. */
  totalMessages: number;
  /** ISO-8601 timestamp of the last state change. */
  lastUpdate: string;
  /** Whether a training run is currently marked as in progress. */
  inProgress: boolean;
  /** Last message identifier recorded by a progress update. */
  lastMessageId?: string;
  /** Last channel identifier recorded by a progress update. */
  lastChannelId?: string;
  /** Most recently recorded error, if any. */
  error?: {
    /** Message of the recorded error. */
    message: string;
    /** Channel identifier supplied with the error, if any. */
    channelId?: string;
    /** Message identifier supplied with the error, if any. */
    messageId?: string;
    /** ISO-8601 timestamp of when the error was recorded. */
    timestamp: string;
  };
}
|
||||
Reference in New Issue
Block a user