Initial support for channel-based data storage

This commit is contained in:
Charlie Laabs
2021-12-21 23:19:14 -06:00
parent a2ae99d75d
commit 3b946b72ec
6 changed files with 177 additions and 79 deletions

View File

@@ -37,11 +37,21 @@ export class AppConfig {
* The command prefix used to trigger the bot commands (when not using slash commands)
* @example !bot
* @default !mark
* @env CRON_SCHEDULE
* @env MESSAGE_COMMAND_PREFIX
*/
@IsOptional()
@IsString()
commandPrefix = process.env.COMMAND_PREFIX || '!mark';
messageCommandPrefix = process.env.MESSAGE_COMMAND_PREFIX || '!mark';
/**
* The slash command name to generate a message from the bot. (e.g. `/mark`)
* @example message
* @default mark
* @env SLASH_COMMAND_NAME
*/
@IsOptional()
@IsString()
slashCommandName = process.env.SLASH_COMMAND_NAME || 'mark';
/**
* The activity status shown under the bot's name in the user list

View File

@@ -6,10 +6,27 @@ import { packageJson } from './util';
const CHANNEL_OPTIONS_MAX = 25;
const helpSlashCommand = new SlashCommandBuilder()
export const helpCommand = new SlashCommandBuilder()
.setName('help')
.setDescription(`How to use ${packageJson().name}`);
export const inviteCommand = new SlashCommandBuilder()
.setName('invite')
.setDescription('Get the invite link for this bot.');
export const messageCommand = new SlashCommandBuilder()
.setName(config.slashCommandName)
.setDescription('Generate a message from learned past messages')
.addBooleanOption((tts) =>
tts.setName('tts').setDescription('Read the message via text-to-speech.').setRequired(false)
)
.addBooleanOption((debug) =>
debug
.setName('debug')
.setDescription('Follow up the generated message with the detailed sources that inspired it.')
.setRequired(false)
);
/**
* Helps generate a list of parameters for channel options
*/
@@ -20,15 +37,15 @@ const channelOptionsGenerator = (builder: SlashCommandChannelOption, index: numb
.setRequired(index === 0)
.addChannelType(ChannelType.GuildText as any);
const listenChannelCommand = new SlashCommandBuilder()
export const listenChannelCommand = new SlashCommandBuilder()
.setName('listen')
.setDescription('Change what channels the bot actively listens to and learns from.')
.addSubcommand((sub) => {
sub
.setName('add')
.setDescription(
`Add channels to learn from. Doesn't add the channel's past messages; re-train to do that.`
);
Array.from(Array(CHANNEL_OPTIONS_MAX).keys()).forEach((index) =>
sub.addChannelOption((opt) => channelOptionsGenerator(opt, index))
);
@@ -45,9 +62,25 @@ const listenChannelCommand = new SlashCommandBuilder()
);
return sub;
})
.setDescription(`How to use ${packageJson().name}`);
.addSubcommand((sub) =>
sub
.setName('list')
.setDescription(`List the channels the bot is currently actively listening to.`)
);
const commands = [helpSlashCommand.toJSON(), listenChannelCommand.toJSON()];
export const trainCommand = new SlashCommandBuilder()
.setName('train')
.setDescription(
'Train from past messages from the configured listened channels. This takes a while.'
);
const commands = [
helpCommand.toJSON(),
inviteCommand.toJSON(),
messageCommand.toJSON(),
listenChannelCommand.toJSON(),
trainCommand.toJSON(),
];
export async function deployCommands(clientId: string) {
const rest = new REST({ version: '9' }).setToken(config.token);

View File

@@ -16,7 +16,14 @@ import L from './logger';
import { Channel } from './entity/Channel';
import { Guild } from './entity/Guild';
import { config } from './config';
import { deployCommands } from './deploy-commands';
import {
deployCommands,
helpCommand,
inviteCommand,
listenChannelCommand,
messageCommand,
trainCommand,
} from './deploy-commands';
import { getRandomElement, getVersion, packageJson } from './util';
interface MarkovDataCustom {
@@ -71,6 +78,18 @@ async function isValidChannel(channelId: string): Promise<boolean> {
return channel.listen;
}
/**
* #v3-complete
*/
async function getValidChannels(guildId: string): Promise<Discord.TextChannel[]> {
const id = parseInt(guildId, 10);
const dbChannels = await Channel.find({ guild: Guild.create({ id }), listen: true });
const channels = (
await Promise.all(dbChannels.map(async (dbc) => client.channels.fetch(dbc.id.toString())))
).filter((c): c is Discord.TextChannel => c !== null && c instanceof Discord.TextChannel);
return channels;
}
/**
* Checks if the author of a message as moderator-like permissions.
* @param {GuildMember} member Sender of the message
@@ -106,10 +125,10 @@ type MessageCommands = 'respond' | 'train' | 'help' | 'invite' | 'debug' | 'tts'
function validateMessage(message: Discord.Message): MessageCommands {
const messageText = message.content.toLowerCase();
let command: MessageCommands = null;
const thisPrefix = messageText.substring(0, config.commandPrefix.length);
if (thisPrefix === config.commandPrefix) {
const thisPrefix = messageText.substring(0, config.messageCommandPrefix.length);
if (thisPrefix === config.messageCommandPrefix) {
const split = messageText.split(' ');
if (split[0] === config.commandPrefix && split.length === 1) {
if (split[0] === config.messageCommandPrefix && split.length === 1) {
command = 'respond';
} else if (split[1] === 'train') {
command = 'train';
@@ -140,44 +159,59 @@ function messageToData(message: Discord.Message): AddDataProps {
* Recursively gets all messages in a text channel's history.
* #v3-complete
*/
async function saveChannelMessageHistory(
channel: Discord.TextChannel,
async function saveGuildMessageHistory(
interaction: Discord.Message | Discord.CommandInteraction
): Promise<void> {
if (!isModerator(interaction.member as any)) return;
const markov = await getMarkovByGuildId(channel.guildId);
L.debug({ channelId: channel.id }, `Training from text channel`);
): Promise<string> {
if (!isModerator(interaction.member as any))
return 'You do not have the permissions for this action.';
if (!interaction.guildId) return 'This action must be performed within a server.';
const markov = await getMarkovByGuildId(interaction.guildId);
const channels = await getValidChannels(interaction.guildId);
if (!channels.length) {
L.warn({ guildId: interaction.guildId }, 'No channels to train from');
return 'No channels configured to learn from. Set some with `/listen add`.';
}
const channelIds = channels.map((c) => c.id);
L.debug({ channelIds }, `Training from text channels`);
const PAGE_SIZE = 100;
let keepGoing = true;
let oldestMessageID: string | undefined;
let messagesCount = 0;
// eslint-disable-next-line no-restricted-syntax
for (const channel of channels) {
let oldestMessageID: string | undefined;
let keepGoing = true;
L.debug({ channelId: channel.id, messagesCount }, `Training from channel`);
let channelMessagesCount = 0;
while (keepGoing) {
// eslint-disable-next-line no-await-in-loop
const messages = await channel.messages.fetch({
before: oldestMessageID,
limit: PAGE_SIZE,
});
const nonBotMessageFormatted = messages.filter((elem) => !elem.author.bot).map(messageToData);
L.debug({ oldestMessageID }, `Saving ${nonBotMessageFormatted.length} messages`);
// eslint-disable-next-line no-await-in-loop
await markov.addData(nonBotMessageFormatted);
L.trace('Finished saving messages');
channelMessagesCount += nonBotMessageFormatted.length;
const lastMessage = messages.last();
if (!lastMessage || messages.size < PAGE_SIZE) {
keepGoing = false;
} else {
oldestMessageID = lastMessage.id;
while (keepGoing) {
// eslint-disable-next-line no-await-in-loop
const messages = await channel.messages.fetch({
before: oldestMessageID,
limit: PAGE_SIZE,
});
const nonBotMessageFormatted = messages.filter((elem) => !elem.author.bot).map(messageToData);
L.trace({ oldestMessageID }, `Saving ${nonBotMessageFormatted.length} messages`);
// eslint-disable-next-line no-await-in-loop
await markov.addData(nonBotMessageFormatted);
L.trace('Finished saving messages');
messagesCount += nonBotMessageFormatted.length;
const lastMessage = messages.last();
if (!lastMessage || messages.size < PAGE_SIZE) {
keepGoing = false;
} else {
oldestMessageID = lastMessage.id;
}
}
}
L.info(
{ channelId: channel.id },
`Trained from ${channelMessagesCount} past human authored messages.`
);
await interaction.reply(`Trained from ${channelMessagesCount} past human authored messages.`);
L.info({ channelIds }, `Trained from ${messagesCount} past human authored messages.`);
return `Trained from ${messagesCount} past human authored messages.`;
}
interface GenerateResponse {
message?: Discord.MessageOptions;
debug?: Discord.MessageOptions;
}
/**
@@ -192,23 +226,15 @@ async function generateResponse(
interaction: Discord.Message | Discord.CommandInteraction,
debug = false,
tts = false
): Promise<void> {
): Promise<GenerateResponse> {
L.debug('Responding...');
if (!interaction.guildId) {
L.debug('Received an interaction without a guildId');
return;
L.warn('Received an interaction without a guildId');
return { message: { content: 'This action must be performed within a server.' } };
}
if (!interaction.channelId) {
L.debug('Received an interaction without a channelId');
return;
}
const isValid = await isValidChannel(interaction.channelId);
if (!isValid) {
L.debug(
{ channelId: interaction.channelId },
'Channel is not enabled for listening. Ignoring...'
);
return;
L.warn('Received an interaction without a channelId');
return { message: { content: 'This action must be performed within a text channel.' } };
}
const markov = await getMarkovByGuildId(interaction.guildId);
@@ -251,26 +277,19 @@ async function generateResponse(
response.string = response.string.replace(/@everyone/g, '@everyοne'); // Replace @everyone with a homoglyph 'o'
messageOpts.content = response.string;
if (interaction instanceof Discord.Message) {
await interaction.channel.send(messageOpts);
if (debug) {
await interaction.channel.send(`\`\`\`\n${JSON.stringify(response, null, 2)}\n\`\`\``);
}
} else if (interaction instanceof Discord.CommandInteraction) {
await interaction.editReply(messageOpts);
if (debug) {
await interaction.followUp(`\`\`\`\n${JSON.stringify(response, null, 2)}\n\`\`\``);
}
const responseMessages: GenerateResponse = {
message: messageOpts,
};
if (debug) {
responseMessages.debug = { content: `\`\`\`\n${JSON.stringify(response, null, 2)}\n\`\`\`` };
}
return responseMessages;
} catch (err) {
L.error(err);
if (debug) {
if (interaction instanceof Discord.Message) {
await interaction.channel.send(`\n\`\`\`\nERROR: ${err}\n\`\`\``);
} else if (interaction instanceof Discord.CommandInteraction) {
await interaction.editReply(`\n\`\`\`\nERROR: ${err}\n\`\`\``);
}
return { debug: { content: `\n\`\`\`\nERROR: ${err}\n\`\`\`` } };
}
return {};
}
}
@@ -281,25 +300,29 @@ function helpMessage(): Discord.MessageOptions {
.setThumbnail(avatarURL as string)
.setDescription('A Markov chain chatbot that speaks based on previous chat input.')
.addField(
`${config.commandPrefix}`,
`${config.messageCommandPrefix} or /${messageCommand.name}`,
'Generates a sentence to say based on the chat database. Send your ' +
'message as TTS to recieve it as TTS.'
)
.addField(
`${config.commandPrefix} train`,
`${config.messageCommandPrefix} train or /${trainCommand.name}`,
'Fetches the maximum amount of previous messages in the current ' +
'text channel, adds it to the database, and regenerates the corpus. Takes some time.'
)
.addField(
`${config.commandPrefix} invite`,
`${config.messageCommandPrefix} invite or /${inviteCommand.name}`,
"Don't invite this bot to other servers. The database is shared " +
'between all servers and text channels.'
)
.addField(
`${config.commandPrefix} debug`,
`Runs the ${config.commandPrefix} command and follows it up with debug info.`
`${config.messageCommandPrefix} debug or /${messageCommand.name} debug: True`,
`Runs the ${config.messageCommandPrefix} command and follows it up with debug info.`
)
.setFooter(`Markov Discord ${getVersion()} by ${packageJson().author}`);
.addField(
`${config.messageCommandPrefix} tts or /${messageCommand.name} tts: True`,
`Runs the ${config.messageCommandPrefix} command and reads it with text-to-speech.`
)
.setFooter(`${packageJson().name} ${getVersion()} by ${packageJson().author}`);
return {
embeds: [embed],
};
@@ -342,10 +365,13 @@ client.on('messageCreate', async (message) => {
await message.channel.send(inviteMessage());
}
if (command === 'train') {
await saveChannelMessageHistory(message.channel, message);
const response = await saveGuildMessageHistory(message);
await message.reply(response);
}
if (command === 'respond') {
await generateResponse(message);
const generatedResponse = await generateResponse(message);
if (generatedResponse.message) await message.reply(generatedResponse.message);
if (generatedResponse.debug) await message.reply(generatedResponse.debug);
}
if (command === 'tts') {
await generateResponse(message, false, true);
@@ -370,6 +396,7 @@ client.on('messageCreate', async (message) => {
* #v3-complete
*/
client.on('messageDelete', async (message) => {
if (message.author?.bot) return;
L.info(`Deleting message ${message.id}`);
if (!(message.guildId && message.content)) {
return;
@@ -382,6 +409,7 @@ client.on('messageDelete', async (message) => {
* #v3-complete
*/
client.on('messageUpdate', async (oldMessage, newMessage) => {
if (oldMessage.author?.bot) return;
L.info(`Editing message ${oldMessage.id}`);
if (!(oldMessage.guildId && oldMessage.content && newMessage.content)) {
return;
@@ -391,6 +419,33 @@ client.on('messageUpdate', async (oldMessage, newMessage) => {
await markov.addData([newMessage.content]);
});
client.on('interactionCreate', async (interaction) => {
if (!interaction.isCommand()) return;
// Unprivileged commands
if (interaction.commandName === helpCommand.name) {
await interaction.reply(helpMessage());
} else if (interaction.commandName === inviteCommand.name) {
await interaction.reply(inviteMessage());
} else if (interaction.commandName === messageCommand.name) {
await interaction.deferReply();
const tts = interaction.options.getBoolean('tts') || false;
const debug = interaction.options.getBoolean('debug') || false;
const generatedResponse = await generateResponse(interaction, debug, tts);
if (generatedResponse.message) await interaction.editReply(generatedResponse.message);
if (generatedResponse.debug) await interaction.followUp(generatedResponse.debug);
if (!Object.keys(generatedResponse).length) await interaction.deleteReply();
}
// Privileged commands
if (interaction.commandName === listenChannelCommand.name) {
await interaction.reply('Pong!');
} else if (interaction.commandName === trainCommand.name) {
await interaction.deferReply();
const responseMessage = await saveGuildMessageHistory(interaction);
await interaction.editReply(responseMessage);
}
});
/**
* Loads the config settings from disk
*/