Initial support for channel-based data storage

This commit is contained in:
Charlie Laabs
2021-12-21 23:19:14 -06:00
parent a2ae99d75d
commit 3b946b72ec
6 changed files with 177 additions and 79 deletions

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -37,11 +37,21 @@ export class AppConfig {
* The command prefix used to trigger the bot commands (when not using slash commands) * The command prefix used to trigger the bot commands (when not using slash commands)
* @example !bot * @example !bot
* @default !mark * @default !mark
* @env CRON_SCHEDULE * @env MESSAGE_COMMAND_PREFIX
*/ */
@IsOptional() @IsOptional()
@IsString() @IsString()
commandPrefix = process.env.COMMAND_PREFIX || '!mark'; messageCommandPrefix = process.env.MESSAGE_COMMAND_PREFIX || '!mark';
/**
* The slash command name to generate a message from the bot. (e.g. `/mark`)
* @example message
* @default mark
* @env SLASH_COMMAND_NAME
*/
@IsOptional()
@IsString()
slashCommandName = process.env.SLASH_COMMAND_NAME || 'mark';
/** /**
* The activity status shown under the bot's name in the user list * The activity status shown under the bot's name in the user list

View File

@@ -6,10 +6,27 @@ import { packageJson } from './util';
const CHANNEL_OPTIONS_MAX = 25; const CHANNEL_OPTIONS_MAX = 25;
const helpSlashCommand = new SlashCommandBuilder() export const helpCommand = new SlashCommandBuilder()
.setName('help') .setName('help')
.setDescription(`How to use ${packageJson().name}`); .setDescription(`How to use ${packageJson().name}`);
export const inviteCommand = new SlashCommandBuilder()
.setName('invite')
.setDescription('Get the invite link for this bot.');
export const messageCommand = new SlashCommandBuilder()
.setName(config.slashCommandName)
.setDescription('Generate a message from learned past messages')
.addBooleanOption((tts) =>
tts.setName('tts').setDescription('Read the message via text-to-speech.').setRequired(false)
)
.addBooleanOption((debug) =>
debug
.setName('debug')
.setDescription('Follow up the generated message with the detailed sources that inspired it.')
.setRequired(false)
);
/** /**
* Helps generate a list of parameters for channel options * Helps generate a list of parameters for channel options
*/ */
@@ -20,15 +37,15 @@ const channelOptionsGenerator = (builder: SlashCommandChannelOption, index: numb
.setRequired(index === 0) .setRequired(index === 0)
.addChannelType(ChannelType.GuildText as any); .addChannelType(ChannelType.GuildText as any);
const listenChannelCommand = new SlashCommandBuilder() export const listenChannelCommand = new SlashCommandBuilder()
.setName('listen') .setName('listen')
.setDescription('Change what channels the bot actively listens to and learns from.')
.addSubcommand((sub) => { .addSubcommand((sub) => {
sub sub
.setName('add') .setName('add')
.setDescription( .setDescription(
`Add channels to learn from. Doesn't add the channel's past messages; re-train to do that.` `Add channels to learn from. Doesn't add the channel's past messages; re-train to do that.`
); );
Array.from(Array(CHANNEL_OPTIONS_MAX).keys()).forEach((index) => Array.from(Array(CHANNEL_OPTIONS_MAX).keys()).forEach((index) =>
sub.addChannelOption((opt) => channelOptionsGenerator(opt, index)) sub.addChannelOption((opt) => channelOptionsGenerator(opt, index))
); );
@@ -45,9 +62,25 @@ const listenChannelCommand = new SlashCommandBuilder()
); );
return sub; return sub;
}) })
.setDescription(`How to use ${packageJson().name}`); .addSubcommand((sub) =>
sub
.setName('list')
.setDescription(`List the channels the bot is currently actively listening to.`)
);
const commands = [helpSlashCommand.toJSON(), listenChannelCommand.toJSON()]; export const trainCommand = new SlashCommandBuilder()
.setName('train')
.setDescription(
'Train from past messages from the configured listened channels. This takes a while.'
);
const commands = [
helpCommand.toJSON(),
inviteCommand.toJSON(),
messageCommand.toJSON(),
listenChannelCommand.toJSON(),
trainCommand.toJSON(),
];
export async function deployCommands(clientId: string) { export async function deployCommands(clientId: string) {
const rest = new REST({ version: '9' }).setToken(config.token); const rest = new REST({ version: '9' }).setToken(config.token);

View File

@@ -16,7 +16,14 @@ import L from './logger';
import { Channel } from './entity/Channel'; import { Channel } from './entity/Channel';
import { Guild } from './entity/Guild'; import { Guild } from './entity/Guild';
import { config } from './config'; import { config } from './config';
import { deployCommands } from './deploy-commands'; import {
deployCommands,
helpCommand,
inviteCommand,
listenChannelCommand,
messageCommand,
trainCommand,
} from './deploy-commands';
import { getRandomElement, getVersion, packageJson } from './util'; import { getRandomElement, getVersion, packageJson } from './util';
interface MarkovDataCustom { interface MarkovDataCustom {
@@ -71,6 +78,18 @@ async function isValidChannel(channelId: string): Promise<boolean> {
return channel.listen; return channel.listen;
} }
/**
* #v3-complete
*/
async function getValidChannels(guildId: string): Promise<Discord.TextChannel[]> {
const id = parseInt(guildId, 10);
const dbChannels = await Channel.find({ guild: Guild.create({ id }), listen: true });
const channels = (
await Promise.all(dbChannels.map(async (dbc) => client.channels.fetch(dbc.id.toString())))
).filter((c): c is Discord.TextChannel => c !== null && c instanceof Discord.TextChannel);
return channels;
}
/** /**
* Checks if the author of a message as moderator-like permissions. * Checks if the author of a message as moderator-like permissions.
* @param {GuildMember} member Sender of the message * @param {GuildMember} member Sender of the message
@@ -106,10 +125,10 @@ type MessageCommands = 'respond' | 'train' | 'help' | 'invite' | 'debug' | 'tts'
function validateMessage(message: Discord.Message): MessageCommands { function validateMessage(message: Discord.Message): MessageCommands {
const messageText = message.content.toLowerCase(); const messageText = message.content.toLowerCase();
let command: MessageCommands = null; let command: MessageCommands = null;
const thisPrefix = messageText.substring(0, config.commandPrefix.length); const thisPrefix = messageText.substring(0, config.messageCommandPrefix.length);
if (thisPrefix === config.commandPrefix) { if (thisPrefix === config.messageCommandPrefix) {
const split = messageText.split(' '); const split = messageText.split(' ');
if (split[0] === config.commandPrefix && split.length === 1) { if (split[0] === config.messageCommandPrefix && split.length === 1) {
command = 'respond'; command = 'respond';
} else if (split[1] === 'train') { } else if (split[1] === 'train') {
command = 'train'; command = 'train';
@@ -140,44 +159,59 @@ function messageToData(message: Discord.Message): AddDataProps {
* Recursively gets all messages in a text channel's history. * Recursively gets all messages in a text channel's history.
* #v3-complete * #v3-complete
*/ */
async function saveChannelMessageHistory( async function saveGuildMessageHistory(
channel: Discord.TextChannel,
interaction: Discord.Message | Discord.CommandInteraction interaction: Discord.Message | Discord.CommandInteraction
): Promise<void> { ): Promise<string> {
if (!isModerator(interaction.member as any)) return; if (!isModerator(interaction.member as any))
const markov = await getMarkovByGuildId(channel.guildId); return 'You do not have the permissions for this action.';
L.debug({ channelId: channel.id }, `Training from text channel`); if (!interaction.guildId) return 'This action must be performed within a server.';
const markov = await getMarkovByGuildId(interaction.guildId);
const channels = await getValidChannels(interaction.guildId);
if (!channels.length) {
L.warn({ guildId: interaction.guildId }, 'No channels to train from');
return 'No channels configured to learn from. Set some with `/listen add`.';
}
const channelIds = channels.map((c) => c.id);
L.debug({ channelIds }, `Training from text channels`);
const PAGE_SIZE = 100; const PAGE_SIZE = 100;
let keepGoing = true; let messagesCount = 0;
let oldestMessageID: string | undefined; // eslint-disable-next-line no-restricted-syntax
for (const channel of channels) {
let oldestMessageID: string | undefined;
let keepGoing = true;
L.debug({ channelId: channel.id, messagesCount }, `Training from channel`);
let channelMessagesCount = 0; while (keepGoing) {
// eslint-disable-next-line no-await-in-loop
while (keepGoing) { const messages = await channel.messages.fetch({
// eslint-disable-next-line no-await-in-loop before: oldestMessageID,
const messages = await channel.messages.fetch({ limit: PAGE_SIZE,
before: oldestMessageID, });
limit: PAGE_SIZE, const nonBotMessageFormatted = messages.filter((elem) => !elem.author.bot).map(messageToData);
}); L.trace({ oldestMessageID }, `Saving ${nonBotMessageFormatted.length} messages`);
const nonBotMessageFormatted = messages.filter((elem) => !elem.author.bot).map(messageToData); // eslint-disable-next-line no-await-in-loop
L.debug({ oldestMessageID }, `Saving ${nonBotMessageFormatted.length} messages`); await markov.addData(nonBotMessageFormatted);
// eslint-disable-next-line no-await-in-loop L.trace('Finished saving messages');
await markov.addData(nonBotMessageFormatted); messagesCount += nonBotMessageFormatted.length;
L.trace('Finished saving messages'); const lastMessage = messages.last();
channelMessagesCount += nonBotMessageFormatted.length; if (!lastMessage || messages.size < PAGE_SIZE) {
const lastMessage = messages.last(); keepGoing = false;
if (!lastMessage || messages.size < PAGE_SIZE) { } else {
keepGoing = false; oldestMessageID = lastMessage.id;
} else { }
oldestMessageID = lastMessage.id;
} }
} }
L.info( L.info({ channelIds }, `Trained from ${messagesCount} past human authored messages.`);
{ channelId: channel.id }, return `Trained from ${messagesCount} past human authored messages.`;
`Trained from ${channelMessagesCount} past human authored messages.` }
);
await interaction.reply(`Trained from ${channelMessagesCount} past human authored messages.`); interface GenerateResponse {
message?: Discord.MessageOptions;
debug?: Discord.MessageOptions;
} }
/** /**
@@ -192,23 +226,15 @@ async function generateResponse(
interaction: Discord.Message | Discord.CommandInteraction, interaction: Discord.Message | Discord.CommandInteraction,
debug = false, debug = false,
tts = false tts = false
): Promise<void> { ): Promise<GenerateResponse> {
L.debug('Responding...'); L.debug('Responding...');
if (!interaction.guildId) { if (!interaction.guildId) {
L.debug('Received an interaction without a guildId'); L.warn('Received an interaction without a guildId');
return; return { message: { content: 'This action must be performed within a server.' } };
} }
if (!interaction.channelId) { if (!interaction.channelId) {
L.debug('Received an interaction without a channelId'); L.warn('Received an interaction without a channelId');
return; return { message: { content: 'This action must be performed within a text channel.' } };
}
const isValid = await isValidChannel(interaction.channelId);
if (!isValid) {
L.debug(
{ channelId: interaction.channelId },
'Channel is not enabled for listening. Ignoring...'
);
return;
} }
const markov = await getMarkovByGuildId(interaction.guildId); const markov = await getMarkovByGuildId(interaction.guildId);
@@ -251,26 +277,19 @@ async function generateResponse(
response.string = response.string.replace(/@everyone/g, '@everyοne'); // Replace @everyone with a homoglyph 'o' response.string = response.string.replace(/@everyone/g, '@everyοne'); // Replace @everyone with a homoglyph 'o'
messageOpts.content = response.string; messageOpts.content = response.string;
if (interaction instanceof Discord.Message) { const responseMessages: GenerateResponse = {
await interaction.channel.send(messageOpts); message: messageOpts,
if (debug) { };
await interaction.channel.send(`\`\`\`\n${JSON.stringify(response, null, 2)}\n\`\`\``); if (debug) {
} responseMessages.debug = { content: `\`\`\`\n${JSON.stringify(response, null, 2)}\n\`\`\`` };
} else if (interaction instanceof Discord.CommandInteraction) {
await interaction.editReply(messageOpts);
if (debug) {
await interaction.followUp(`\`\`\`\n${JSON.stringify(response, null, 2)}\n\`\`\``);
}
} }
return responseMessages;
} catch (err) { } catch (err) {
L.error(err); L.error(err);
if (debug) { if (debug) {
if (interaction instanceof Discord.Message) { return { debug: { content: `\n\`\`\`\nERROR: ${err}\n\`\`\`` } };
await interaction.channel.send(`\n\`\`\`\nERROR: ${err}\n\`\`\``);
} else if (interaction instanceof Discord.CommandInteraction) {
await interaction.editReply(`\n\`\`\`\nERROR: ${err}\n\`\`\``);
}
} }
return {};
} }
} }
@@ -281,25 +300,29 @@ function helpMessage(): Discord.MessageOptions {
.setThumbnail(avatarURL as string) .setThumbnail(avatarURL as string)
.setDescription('A Markov chain chatbot that speaks based on previous chat input.') .setDescription('A Markov chain chatbot that speaks based on previous chat input.')
.addField( .addField(
`${config.commandPrefix}`, `${config.messageCommandPrefix} or /${messageCommand.name}`,
'Generates a sentence to say based on the chat database. Send your ' + 'Generates a sentence to say based on the chat database. Send your ' +
'message as TTS to recieve it as TTS.' 'message as TTS to recieve it as TTS.'
) )
.addField( .addField(
`${config.commandPrefix} train`, `${config.messageCommandPrefix} train or /${trainCommand.name}`,
'Fetches the maximum amount of previous messages in the current ' + 'Fetches the maximum amount of previous messages in the current ' +
'text channel, adds it to the database, and regenerates the corpus. Takes some time.' 'text channel, adds it to the database, and regenerates the corpus. Takes some time.'
) )
.addField( .addField(
`${config.commandPrefix} invite`, `${config.messageCommandPrefix} invite or /${inviteCommand.name}`,
"Don't invite this bot to other servers. The database is shared " + "Don't invite this bot to other servers. The database is shared " +
'between all servers and text channels.' 'between all servers and text channels.'
) )
.addField( .addField(
`${config.commandPrefix} debug`, `${config.messageCommandPrefix} debug or /${messageCommand.name} debug: True`,
`Runs the ${config.commandPrefix} command and follows it up with debug info.` `Runs the ${config.messageCommandPrefix} command and follows it up with debug info.`
) )
.setFooter(`Markov Discord ${getVersion()} by ${packageJson().author}`); .addField(
`${config.messageCommandPrefix} tts or /${messageCommand.name} tts: True`,
`Runs the ${config.messageCommandPrefix} command and reads it with text-to-speech.`
)
.setFooter(`${packageJson().name} ${getVersion()} by ${packageJson().author}`);
return { return {
embeds: [embed], embeds: [embed],
}; };
@@ -342,10 +365,13 @@ client.on('messageCreate', async (message) => {
await message.channel.send(inviteMessage()); await message.channel.send(inviteMessage());
} }
if (command === 'train') { if (command === 'train') {
await saveChannelMessageHistory(message.channel, message); const response = await saveGuildMessageHistory(message);
await message.reply(response);
} }
if (command === 'respond') { if (command === 'respond') {
await generateResponse(message); const generatedResponse = await generateResponse(message);
if (generatedResponse.message) await message.reply(generatedResponse.message);
if (generatedResponse.debug) await message.reply(generatedResponse.debug);
} }
if (command === 'tts') { if (command === 'tts') {
await generateResponse(message, false, true); await generateResponse(message, false, true);
@@ -370,6 +396,7 @@ client.on('messageCreate', async (message) => {
* #v3-complete * #v3-complete
*/ */
client.on('messageDelete', async (message) => { client.on('messageDelete', async (message) => {
if (message.author?.bot) return;
L.info(`Deleting message ${message.id}`); L.info(`Deleting message ${message.id}`);
if (!(message.guildId && message.content)) { if (!(message.guildId && message.content)) {
return; return;
@@ -382,6 +409,7 @@ client.on('messageDelete', async (message) => {
* #v3-complete * #v3-complete
*/ */
client.on('messageUpdate', async (oldMessage, newMessage) => { client.on('messageUpdate', async (oldMessage, newMessage) => {
if (oldMessage.author?.bot) return;
L.info(`Editing message ${oldMessage.id}`); L.info(`Editing message ${oldMessage.id}`);
if (!(oldMessage.guildId && oldMessage.content && newMessage.content)) { if (!(oldMessage.guildId && oldMessage.content && newMessage.content)) {
return; return;
@@ -391,6 +419,33 @@ client.on('messageUpdate', async (oldMessage, newMessage) => {
await markov.addData([newMessage.content]); await markov.addData([newMessage.content]);
}); });
client.on('interactionCreate', async (interaction) => {
if (!interaction.isCommand()) return;
// Unprivileged commands
if (interaction.commandName === helpCommand.name) {
await interaction.reply(helpMessage());
} else if (interaction.commandName === inviteCommand.name) {
await interaction.reply(inviteMessage());
} else if (interaction.commandName === messageCommand.name) {
await interaction.deferReply();
const tts = interaction.options.getBoolean('tts') || false;
const debug = interaction.options.getBoolean('debug') || false;
const generatedResponse = await generateResponse(interaction, debug, tts);
if (generatedResponse.message) await interaction.editReply(generatedResponse.message);
if (generatedResponse.debug) await interaction.followUp(generatedResponse.debug);
if (!Object.keys(generatedResponse).length) await interaction.deleteReply();
}
// Privileged commands
if (interaction.commandName === listenChannelCommand.name) {
await interaction.reply('Pong!');
} else if (interaction.commandName === trainCommand.name) {
await interaction.deferReply();
const responseMessage = await saveGuildMessageHistory(interaction);
await interaction.editReply(responseMessage);
}
});
/** /**
* Loads the config settings from disk * Loads the config settings from disk
*/ */