Add train from file and non-clean training option.

Closes #31
This commit is contained in:
Charlie Laabs
2022-05-31 22:52:37 -05:00
parent c742bee965
commit 9adf741b5f
7 changed files with 220 additions and 55 deletions

View File

@@ -85,6 +85,20 @@ export const trainCommand = new SlashCommandBuilder()
.setName('train')
.setDescription(
'Train from past messages from the configured listened channels. This takes a while.'
)
.addBooleanOption((clean) =>
clean
.setName('clean')
.setDescription(
'Whether the database should be emptied before training. Default is true (recommended).'
)
.setRequired(false)
)
.addAttachmentOption((json) =>
json
.setName('json')
.setDescription('Train from a provided JSON file rather than channel history.')
.setRequired(false)
);
const commands = [

View File

@@ -13,6 +13,7 @@ import makeEta from 'simple-eta';
import formatDistanceToNow from 'date-fns/formatDistanceToNow';
import addSeconds from 'date-fns/addSeconds';
import type { APIInteractionGuildMember, APISelectMenuComponent } from 'discord-api-types/v9';
import fetch from 'node-fetch';
import L from './logger';
import { Channel } from './entity/Channel';
import { Guild } from './entity/Guild';
@@ -253,7 +254,8 @@ function messageToData(message: Discord.Message): AddDataProps {
* Recursively gets all messages in a text channel's history.
*/
async function saveGuildMessageHistory(
interaction: Discord.Message | Discord.CommandInteraction
interaction: Discord.Message | Discord.CommandInteraction,
clean = true
): Promise<string> {
if (!isModerator(interaction.member)) return INVALID_PERMISSIONS_MESSAGE;
if (!interaction.guildId || !interaction.guild) return INVALID_GUILD_MESSAGE;
@@ -265,8 +267,12 @@ async function saveGuildMessageHistory(
return 'No channels configured to learn from. Set some with `/listen add`.';
}
L.debug('Deleting old data');
await markov.delete();
if (clean) {
L.debug('Deleting old data');
await markov.delete();
} else {
L.debug('Not deleting old data during training');
}
const channelIds = channels.map((c) => c.id);
L.debug({ channelIds }, `Training from text channels`);
@@ -440,6 +446,69 @@ async function saveGuildMessageHistory(
return `Trained from ${messagesCount} past human authored messages.`;
}
/**
 * Shape of a single entry in a JSON training file, as consumed by
 * `trainFromAttachmentJson` (the downloaded file is treated as an array of these).
 */
interface JSONImport {
/** Text to train on; must be a non-empty string or the import is rejected. */
message: string;
/**
 * Optional list of strings stored alongside the message as its custom `attachments` data.
 * NOTE(review): presumably attachment URLs — confirm against how generated responses use them.
 */
attachments?: string[];
}
/**
 * Train from an attached JSON file.
 *
 * Downloads the attachment, validates that it is a JSON array of {@link JSONImport}
 * entries, converts every entry to training data tagged with the guild ID, and adds
 * it to the guild's Markov chain.
 *
 * @param attachment The uploaded file whose URL is fetched and parsed as JSON.
 * @param interaction The command interaction; used for the permission and guild checks.
 * @param clean When true (default), existing training data is deleted before the import.
 * @returns A user-facing status message: a permission/guild error, a formatting error,
 * or a summary of how many entries were trained.
 */
async function trainFromAttachmentJson(
  attachment: Discord.MessageAttachment,
  interaction: Discord.CommandInteraction,
  clean = true
): Promise<string> {
  if (!isModerator(interaction.member)) return INVALID_PERMISSIONS_MESSAGE;
  if (!interaction.guildId || !interaction.guild) return INVALID_GUILD_MESSAGE;
  const { guildId } = interaction;
  const markov = await getMarkovByGuildId(guildId);

  let trainingData: AddDataProps[];

  try {
    const importAttachmentUrl = attachment.attachment.toString();
    const getResp = await fetch(importAttachmentUrl);
    if (!getResp.ok) throw new Error(getResp.statusText);

    // The payload is untrusted user input: check its shape before treating it as entries.
    const importData: unknown = await getResp.json();
    if (!Array.isArray(importData)) {
      throw new Error('The top level of the JSON file must be an array of entries');
    }

    trainingData = (importData as JSONImport[]).map((datum, index) => {
      if (!datum?.message) {
        throw new Error(`Entry at index ${index} must have a "message"`);
      }
      if (typeof datum.message !== 'string') {
        throw new Error(`Entry at index ${index} must have a "message" with a type of string`);
      }

      const { attachments } = datum;
      // Reject when ANY attachment is not a string. `some` (rather than `every`) is required
      // here: `every` only fired when all items were invalid and, being vacuously true for an
      // empty array, wrongly rejected `"attachments": []`.
      if (
        attachments != null &&
        (!Array.isArray(attachments) || attachments.some((a) => typeof a !== 'string'))
      ) {
        throw new Error(
          `Entry at index ${index} must have all "attachments" each with a type of string`
        );
      }

      // Only attach custom data when there is at least one attachment to store.
      let custom: MarkovDataCustom | undefined;
      if (attachments?.length) custom = { attachments };
      return {
        string: datum.message,
        custom,
        tags: [guildId],
      };
    });
  } catch (err) {
    L.error(err);
    return 'The provided attachment file has invalid formatting. See the logs for details.';
  }

  // Existing data is only removed after the import has been fully validated, so a bad
  // file never wipes the database.
  if (clean) {
    L.debug('Deleting old data');
    await markov.delete();
  } else {
    L.debug('Not deleting old data during training');
  }

  await markov.addData(trainingData);

  L.info(`Trained from ${trainingData.length} past human authored messages.`);
  return `Trained from ${trainingData.length} past human authored messages.`;
}
interface GenerateResponse {
message?: AgnosticReplyOptions;
debug?: AgnosticReplyOptions;
@@ -846,10 +915,18 @@ client.on('interactionCreate', async (interaction) => {
}
} else if (interaction.commandName === trainCommand.name) {
await interaction.deferReply();
const reply = (await interaction.fetchReply()) as Discord.Message; // Must fetch the reply ASAP
const responseMessage = await saveGuildMessageHistory(interaction);
// Send a message in reply to the reply to avoid the 15 minute webhook token timeout
await reply.reply({ content: responseMessage });
const clean = interaction.options.getBoolean('clean') ?? true;
const trainingJSON = interaction.options.getAttachment('json');
if (trainingJSON) {
const responseMessage = await trainFromAttachmentJson(trainingJSON, interaction, clean);
await interaction.followUp(responseMessage);
} else {
const reply = (await interaction.fetchReply()) as Discord.Message; // Must fetch the reply ASAP
const responseMessage = await saveGuildMessageHistory(interaction, clean);
// Send a message in reply to the reply to avoid the 15 minute webhook token timeout
await reply.reply({ content: responseMessage });
}
}
} else if (interaction.isSelectMenu()) {
if (interaction.customId === 'listen-modify-select') {