diff --git a/config/db/db.sqlite3 b/config/db/db.sqlite3 deleted file mode 100644 index cfd0d7d..0000000 Binary files a/config/db/db.sqlite3 and /dev/null differ diff --git a/config/db/db.sqlite3-shm b/config/db/db.sqlite3-shm deleted file mode 100644 index 308f44c..0000000 Binary files a/config/db/db.sqlite3-shm and /dev/null differ diff --git a/config/db/db.sqlite3-wal b/config/db/db.sqlite3-wal deleted file mode 100644 index 6be2482..0000000 Binary files a/config/db/db.sqlite3-wal and /dev/null differ diff --git a/src/config/classes.ts b/src/config/classes.ts index 40b6980..039ae43 100644 --- a/src/config/classes.ts +++ b/src/config/classes.ts @@ -37,11 +37,21 @@ export class AppConfig { * The command prefix used to trigger the bot commands (when not using slash commands) * @example !bot * @default !mark - * @env CRON_SCHEDULE + * @env MESSAGE_COMMAND_PREFIX */ @IsOptional() @IsString() - commandPrefix = process.env.COMMAND_PREFIX || '!mark'; + messageCommandPrefix = process.env.MESSAGE_COMMAND_PREFIX || '!mark'; + + /** + * The slash command name to generate a message from the bot. (e.g. `/mark`) + * @example message + * @default mark + * @env SLASH_COMMAND_NAME + */ + @IsOptional() + @IsString() + slashCommandName = process.env.SLASH_COMMAND_NAME || 'mark'; /** * The activity status shown under the bot's name in the user list diff --git a/src/deploy-commands.ts b/src/deploy-commands.ts index 8f5de61..0449453 100644 --- a/src/deploy-commands.ts +++ b/src/deploy-commands.ts @@ -6,10 +6,27 @@ import { packageJson } from './util'; const CHANNEL_OPTIONS_MAX = 25; -const helpSlashCommand = new SlashCommandBuilder() +export const helpCommand = new SlashCommandBuilder() .setName('help') .setDescription(`How to use ${packageJson().name}`); +export const inviteCommand = new SlashCommandBuilder() + .setName('invite') + .setDescription('Get the invite link for this bot.'); + +export const messageCommand = new SlashCommandBuilder() + .setName(config.slashCommandName) + .setDescription('Generate a message from learned past messages') + .addBooleanOption((tts) => + tts.setName('tts').setDescription('Read the message via text-to-speech.').setRequired(false) + ) + .addBooleanOption((debug) => + debug + .setName('debug') + .setDescription('Follow up the generated message with the detailed sources that inspired it.') + .setRequired(false) + ); + /** * Helps generate a list of parameters for channel options */ @@ -20,15 +37,15 @@ const channelOptionsGenerator = (builder: SlashCommandChannelOption, index: numb .setRequired(index === 0) .addChannelType(ChannelType.GuildText as any); -const listenChannelCommand = new SlashCommandBuilder() +export const listenChannelCommand = new SlashCommandBuilder() .setName('listen') + .setDescription('Change what channels the bot actively listens to and learns from.') .addSubcommand((sub) => { sub .setName('add') .setDescription( `Add channels to learn from. Doesn't add the channel's past messages; re-train to do that.` ); - Array.from(Array(CHANNEL_OPTIONS_MAX).keys()).forEach((index) => sub.addChannelOption((opt) => channelOptionsGenerator(opt, index)) ); @@ -45,9 +62,25 @@ const listenChannelCommand = new SlashCommandBuilder() ); return sub; }) - .setDescription(`How to use ${packageJson().name}`); + .addSubcommand((sub) => + sub + .setName('list') + .setDescription(`List the channels the bot is currently actively listening to.`) + ); -const commands = [helpSlashCommand.toJSON(), listenChannelCommand.toJSON()]; +export const trainCommand = new SlashCommandBuilder() + .setName('train') + .setDescription( + 'Train from past messages from the configured listened channels. This takes a while.' + ); + +const commands = [ + helpCommand.toJSON(), + inviteCommand.toJSON(), + messageCommand.toJSON(), + listenChannelCommand.toJSON(), + trainCommand.toJSON(), +]; export async function deployCommands(clientId: string) { const rest = new REST({ version: '9' }).setToken(config.token); diff --git a/src/index.ts b/src/index.ts index 485efe4..138986c 100644 --- a/src/index.ts +++ b/src/index.ts @@ -16,7 +16,14 @@ import L from './logger'; import { Channel } from './entity/Channel'; import { Guild } from './entity/Guild'; import { config } from './config'; -import { deployCommands } from './deploy-commands'; +import { + deployCommands, + helpCommand, + inviteCommand, + listenChannelCommand, + messageCommand, + trainCommand, +} from './deploy-commands'; import { getRandomElement, getVersion, packageJson } from './util'; interface MarkovDataCustom { @@ -71,6 +78,18 @@ async function isValidChannel(channelId: string): Promise { return channel.listen; } +/** + * #v3-complete + */ +async function getValidChannels(guildId: string): Promise { + const id = parseInt(guildId, 10); + const dbChannels = await Channel.find({ guild: Guild.create({ id }), listen: true }); + const channels = ( + await Promise.all(dbChannels.map(async (dbc) => client.channels.fetch(dbc.id.toString()))) + ).filter((c): c is Discord.TextChannel => c !== null && c instanceof Discord.TextChannel); + return channels; +} + /** * Checks if the author of a message as moderator-like permissions. * @param {GuildMember} member Sender of the message @@ -106,10 +125,10 @@ type MessageCommands = 'respond' | 'train' | 'help' | 'invite' | 'debug' | 'tts' function validateMessage(message: Discord.Message): MessageCommands { const messageText = message.content.toLowerCase(); let command: MessageCommands = null; - const thisPrefix = messageText.substring(0, config.commandPrefix.length); - if (thisPrefix === config.commandPrefix) { + const thisPrefix = messageText.substring(0, config.messageCommandPrefix.length); + if (thisPrefix === config.messageCommandPrefix) { const split = messageText.split(' '); - if (split[0] === config.commandPrefix && split.length === 1) { + if (split[0] === config.messageCommandPrefix && split.length === 1) { command = 'respond'; } else if (split[1] === 'train') { command = 'train'; @@ -140,44 +159,59 @@ function messageToData(message: Discord.Message): AddDataProps { * Recursively gets all messages in a text channel's history. * #v3-complete */ -async function saveChannelMessageHistory( - channel: Discord.TextChannel, +async function saveGuildMessageHistory( interaction: Discord.Message | Discord.CommandInteraction -): Promise { - if (!isModerator(interaction.member as any)) return; - const markov = await getMarkovByGuildId(channel.guildId); - L.debug({ channelId: channel.id }, `Training from text channel`); +): Promise { + if (!isModerator(interaction.member as any)) + return 'You do not have the permissions for this action.'; + if (!interaction.guildId) return 'This action must be performed within a server.'; + const markov = await getMarkovByGuildId(interaction.guildId); + const channels = await getValidChannels(interaction.guildId); + + if (!channels.length) { + L.warn({ guildId: interaction.guildId }, 'No channels to train from'); + return 'No channels configured to learn from. Set some with `/listen add`.'; + } + + const channelIds = channels.map((c) => c.id); + L.debug({ channelIds }, `Training from text channels`); + const PAGE_SIZE = 100; - let keepGoing = true; - let oldestMessageID: string | undefined; + let messagesCount = 0; + // eslint-disable-next-line no-restricted-syntax + for (const channel of channels) { + let oldestMessageID: string | undefined; + let keepGoing = true; + L.debug({ channelId: channel.id, messagesCount }, `Training from channel`); - let channelMessagesCount = 0; - - while (keepGoing) { - // eslint-disable-next-line no-await-in-loop - const messages = await channel.messages.fetch({ - before: oldestMessageID, - limit: PAGE_SIZE, - }); - const nonBotMessageFormatted = messages.filter((elem) => !elem.author.bot).map(messageToData); - L.debug({ oldestMessageID }, `Saving ${nonBotMessageFormatted.length} messages`); - // eslint-disable-next-line no-await-in-loop - await markov.addData(nonBotMessageFormatted); - L.trace('Finished saving messages'); - channelMessagesCount += nonBotMessageFormatted.length; - const lastMessage = messages.last(); - if (!lastMessage || messages.size < PAGE_SIZE) { - keepGoing = false; - } else { - oldestMessageID = lastMessage.id; + while (keepGoing) { + // eslint-disable-next-line no-await-in-loop + const messages = await channel.messages.fetch({ + before: oldestMessageID, + limit: PAGE_SIZE, + }); + const nonBotMessageFormatted = messages.filter((elem) => !elem.author.bot).map(messageToData); + L.trace({ oldestMessageID }, `Saving ${nonBotMessageFormatted.length} messages`); + // eslint-disable-next-line no-await-in-loop + await markov.addData(nonBotMessageFormatted); + L.trace('Finished saving messages'); + messagesCount += nonBotMessageFormatted.length; + const lastMessage = messages.last(); + if (!lastMessage || messages.size < PAGE_SIZE) { + keepGoing = false; + } else { + oldestMessageID = lastMessage.id; + } } } - L.info( - { channelId: channel.id }, - `Trained from ${channelMessagesCount} past human authored messages.` - ); - await interaction.reply(`Trained from ${channelMessagesCount} past human authored messages.`); + L.info({ channelIds }, `Trained from ${messagesCount} past human authored messages.`); + return `Trained from ${messagesCount} past human authored messages.`; +} + +interface GenerateResponse { + message?: Discord.MessageOptions; + debug?: Discord.MessageOptions; } /** @@ -192,23 +226,15 @@ async function generateResponse( interaction: Discord.Message | Discord.CommandInteraction, debug = false, tts = false -): Promise { +): Promise { L.debug('Responding...'); if (!interaction.guildId) { - L.debug('Received an interaction without a guildId'); - return; + L.warn('Received an interaction without a guildId'); + return { message: { content: 'This action must be performed within a server.' } }; } if (!interaction.channelId) { - L.debug('Received an interaction without a channelId'); - return; - } - const isValid = await isValidChannel(interaction.channelId); - if (!isValid) { - L.debug( - { channelId: interaction.channelId }, - 'Channel is not enabled for listening. Ignoring...' - ); - return; + L.warn('Received an interaction without a channelId'); + return { message: { content: 'This action must be performed within a text channel.' } }; } const markov = await getMarkovByGuildId(interaction.guildId); @@ -251,26 +277,19 @@ async function generateResponse( response.string = response.string.replace(/@everyone/g, '@everyοne'); // Replace @everyone with a homoglyph 'o' messageOpts.content = response.string; - if (interaction instanceof Discord.Message) { - await interaction.channel.send(messageOpts); - if (debug) { - await interaction.channel.send(`\`\`\`\n${JSON.stringify(response, null, 2)}\n\`\`\``); - } - } else if (interaction instanceof Discord.CommandInteraction) { - await interaction.editReply(messageOpts); - if (debug) { - await interaction.followUp(`\`\`\`\n${JSON.stringify(response, null, 2)}\n\`\`\``); - } + const responseMessages: GenerateResponse = { + message: messageOpts, + }; + if (debug) { + responseMessages.debug = { content: `\`\`\`\n${JSON.stringify(response, null, 2)}\n\`\`\`` }; } + return responseMessages; } catch (err) { L.error(err); if (debug) { - if (interaction instanceof Discord.Message) { - await interaction.channel.send(`\n\`\`\`\nERROR: ${err}\n\`\`\``); - } else if (interaction instanceof Discord.CommandInteraction) { - await interaction.editReply(`\n\`\`\`\nERROR: ${err}\n\`\`\``); - } + return { debug: { content: `\n\`\`\`\nERROR: ${err}\n\`\`\`` } }; } + return {}; } } @@ -281,25 +300,29 @@ function helpMessage(): Discord.MessageOptions { .setThumbnail(avatarURL as string) .setDescription('A Markov chain chatbot that speaks based on previous chat input.') .addField( - `${config.commandPrefix}`, + `${config.messageCommandPrefix} or /${messageCommand.name}`, 'Generates a sentence to say based on the chat database. Send your ' + 'message as TTS to recieve it as TTS.' ) .addField( - `${config.commandPrefix} train`, + `${config.messageCommandPrefix} train or /${trainCommand.name}`, 'Fetches the maximum amount of previous messages in the current ' + 'text channel, adds it to the database, and regenerates the corpus. Takes some time.' ) .addField( - `${config.commandPrefix} invite`, + `${config.messageCommandPrefix} invite or /${inviteCommand.name}`, "Don't invite this bot to other servers. The database is shared " + 'between all servers and text channels.' ) .addField( - `${config.commandPrefix} debug`, - `Runs the ${config.commandPrefix} command and follows it up with debug info.` + `${config.messageCommandPrefix} debug or /${messageCommand.name} debug: True`, + `Runs the ${config.messageCommandPrefix} command and follows it up with debug info.` ) - .setFooter(`Markov Discord ${getVersion()} by ${packageJson().author}`); + .addField( + `${config.messageCommandPrefix} tts or /${messageCommand.name} tts: True`, + `Runs the ${config.messageCommandPrefix} command and reads it with text-to-speech.` + ) + .setFooter(`${packageJson().name} ${getVersion()} by ${packageJson().author}`); return { embeds: [embed], }; @@ -342,10 +365,13 @@ client.on('messageCreate', async (message) => { await message.channel.send(inviteMessage()); } if (command === 'train') { - await saveChannelMessageHistory(message.channel, message); + const response = await saveGuildMessageHistory(message); + await message.reply(response); } if (command === 'respond') { - await generateResponse(message); + const generatedResponse = await generateResponse(message); + if (generatedResponse.message) await message.reply(generatedResponse.message); + if (generatedResponse.debug) await message.reply(generatedResponse.debug); } if (command === 'tts') { await generateResponse(message, false, true); @@ -370,6 +396,7 @@ client.on('messageCreate', async (message) => { * #v3-complete */ client.on('messageDelete', async (message) => { + if (message.author?.bot) return; L.info(`Deleting message ${message.id}`); if (!(message.guildId && message.content)) { return; @@ -382,6 +409,7 @@ client.on('messageDelete', async (message) => { * #v3-complete */ client.on('messageUpdate', async (oldMessage, newMessage) => { + if (oldMessage.author?.bot) return; L.info(`Editing message ${oldMessage.id}`); if (!(oldMessage.guildId && oldMessage.content && newMessage.content)) { return; @@ -391,6 +419,33 @@ client.on('messageUpdate', async (oldMessage, newMessage) => { await markov.addData([newMessage.content]); }); +client.on('interactionCreate', async (interaction) => { + if (!interaction.isCommand()) return; + + // Unprivileged commands + if (interaction.commandName === helpCommand.name) { + await interaction.reply(helpMessage()); + } else if (interaction.commandName === inviteCommand.name) { + await interaction.reply(inviteMessage()); + } else if (interaction.commandName === messageCommand.name) { + await interaction.deferReply(); + const tts = interaction.options.getBoolean('tts') || false; + const debug = interaction.options.getBoolean('debug') || false; + const generatedResponse = await generateResponse(interaction, debug, tts); + if (generatedResponse.message) await interaction.editReply(generatedResponse.message); + if (generatedResponse.debug) await interaction.followUp(generatedResponse.debug); + if (!Object.keys(generatedResponse).length) await interaction.deleteReply(); + } + // Privileged commands + if (interaction.commandName === listenChannelCommand.name) { + await interaction.reply('Pong!'); + } else if (interaction.commandName === trainCommand.name) { + await interaction.deferReply(); + const responseMessage = await saveGuildMessageHistory(interaction); + await interaction.editReply(responseMessage); + } +}); + /** * Loads the config settings from disk */