diff --git a/package-lock.json b/package-lock.json index ca41785..d13d22c 100644 --- a/package-lock.json +++ b/package-lock.json @@ -2669,7 +2669,7 @@ "node_modules/markov-strings-db": { "version": "4.0.0", "resolved": "file:../markov-strings/markov-strings-db-4.0.0.tgz", - "integrity": "sha512-CBYNkqUqj0XVohyBLz6kJL81VKzh+8xLcN6vp0ojps/AjqmycKHmj/xZWdCZjc72X7r85UaLnJ6L7QqnW+xPEw==", + "integrity": "sha512-AB1Sp0ukD+DpjeYFeiPhRgZXou6tUrmNn85dFBI2wAcCj2mzlolsTWV1zBhL0jmPtMoX7xwwf4FhDefMtY+E7A==", "license": "MIT", "dependencies": { "reflect-metadata": "^0.1.13", @@ -6409,7 +6409,7 @@ }, "markov-strings-db": { "version": "file:../markov-strings/markov-strings-db-4.0.0.tgz", - "integrity": "sha512-CBYNkqUqj0XVohyBLz6kJL81VKzh+8xLcN6vp0ojps/AjqmycKHmj/xZWdCZjc72X7r85UaLnJ6L7QqnW+xPEw==", + "integrity": "sha512-AB1Sp0ukD+DpjeYFeiPhRgZXou6tUrmNn85dFBI2wAcCj2mzlolsTWV1zBhL0jmPtMoX7xwwf4FhDefMtY+E7A==", "requires": { "reflect-metadata": "^0.1.13", "typeorm": "^0.2.41" diff --git a/src/deploy-commands.ts b/src/deploy-commands.ts index 0449453..317d990 100644 --- a/src/deploy-commands.ts +++ b/src/deploy-commands.ts @@ -4,7 +4,7 @@ import { ChannelType, Routes } from 'discord-api-types/v9'; import { config } from './config'; import { packageJson } from './util'; -const CHANNEL_OPTIONS_MAX = 25; +export const CHANNEL_OPTIONS_MAX = 25; export const helpCommand = new SlashCommandBuilder() .setName('help') diff --git a/src/entity/Channel.ts b/src/entity/Channel.ts index 080469a..caa7fbf 100644 --- a/src/entity/Channel.ts +++ b/src/entity/Channel.ts @@ -4,8 +4,8 @@ import { Guild } from './Guild'; @Entity() export class Channel extends BaseEntity { - @PrimaryColumn() - id: number; + @PrimaryColumn({ type: 'text' }) + id: string; @Column({ default: false, diff --git a/src/entity/Guild.ts b/src/entity/Guild.ts index 1bd8f41..d6f34d1 100644 --- a/src/entity/Guild.ts +++ b/src/entity/Guild.ts @@ -4,8 +4,8 @@ import { Channel } from './Channel'; @Entity() export class Guild extends BaseEntity { - @PrimaryColumn() - id: number; + @PrimaryColumn({ type: 'text' }) + id: string; @OneToMany(() => Channel, (channel) => channel.guild, { onDelete: 'CASCADE', cascade: true }) channels: Channel[]; diff --git a/src/index.ts b/src/index.ts index 138986c..1aa5fe0 100644 --- a/src/index.ts +++ b/src/index.ts @@ -17,6 +17,7 @@ import { Channel } from './entity/Channel'; import { Guild } from './entity/Guild'; import { config } from './config'; import { + CHANNEL_OPTIONS_MAX, deployCommands, helpCommand, inviteCommand, @@ -30,6 +31,9 @@ interface MarkovDataCustom { attachments: string[]; } +const INVALID_PERMISSIONS_MESSAGE = 'You do not have the permissions for this action.'; +const INVALID_GUILD_MESSAGE = 'This action must be performed within a server.'; + const client = new Discord.Client({ intents: [Discord.Intents.FLAGS.GUILD_MESSAGES, Discord.Intents.FLAGS.GUILDS], presence: { @@ -54,47 +58,46 @@ const markovGenerateOptions: MarkovGenerateOptions = { maxTries: config.maxTries, }; -/** - * #v3-complete - */ async function getMarkovByGuildId(guildId: string): Promise { - const id = parseInt(guildId, 10); - const markov = new Markov({ id, options: markovOpts }); + const markov = new Markov({ id: guildId, options: { ...markovOpts, id: guildId } }); await markov.setup(); // Connect the markov instance to the DB to assign it an ID return markov; } -/** - * #v3-complete - */ -async function isValidChannel(channelId: string): Promise { - const id = parseInt(channelId, 10); - const channel = await Channel.findOne(id); - if (!channel) { - L.warn({ channelId }, 'Channel does not exist, setting to valid'); - await Channel.create({ id }).save(); - return false; - } - return channel.listen; -} - -/** - * #v3-complete - */ -async function getValidChannels(guildId: string): Promise { - const id = parseInt(guildId, 10); - const dbChannels = await Channel.find({ guild: Guild.create({ id }), listen: true }); +async function getValidChannels(guild: Discord.Guild): Promise { + const dbChannels = await Channel.find({ guild: Guild.create({ id: guild.id }), listen: true }); const channels = ( - await Promise.all(dbChannels.map(async (dbc) => client.channels.fetch(dbc.id.toString()))) + await Promise.all( + dbChannels.map(async (dbc) => { + return guild.channels.fetch(dbc.id.toString()); + }) + ) ).filter((c): c is Discord.TextChannel => c !== null && c instanceof Discord.TextChannel); return channels; } +async function addValidChannels(channels: Discord.TextChannel[], guildId: string): Promise { + const dbChannels = channels.map((c) => { + return Channel.create({ id: c.id, guild: Guild.create({ id: guildId }), listen: true }); + }); + await Channel.save(dbChannels); +} + +async function removeValidChannels( + channels: Discord.TextChannel[], + guildId: string +): Promise { + const dbChannels = channels.map((c) => { + return Channel.create({ id: c.id, guild: Guild.create({ id: guildId }), listen: false }); + }); + await Channel.save(dbChannels); +} + /** * Checks if the author of a message as moderator-like permissions. * @param {GuildMember} member Sender of the message * @return {Boolean} True if the sender is a moderator. - * #v3-complete + * */ function isModerator(member: Discord.GuildMember | APIInteractionGuildMember | null): boolean { const MODERATOR_PERMISSIONS: Discord.PermissionResolvable[] = [ @@ -157,16 +160,14 @@ function messageToData(message: Discord.Message): AddDataProps { /** * Recursively gets all messages in a text channel's history. - * #v3-complete */ async function saveGuildMessageHistory( interaction: Discord.Message | Discord.CommandInteraction ): Promise { - if (!isModerator(interaction.member as any)) - return 'You do not have the permissions for this action.'; - if (!interaction.guildId) return 'This action must be performed within a server.'; + if (!isModerator(interaction.member as any)) return INVALID_PERMISSIONS_MESSAGE; + if (!interaction.guildId || !interaction.guild) return INVALID_GUILD_MESSAGE; const markov = await getMarkovByGuildId(interaction.guildId); - const channels = await getValidChannels(interaction.guildId); + const channels = await getValidChannels(interaction.guild); if (!channels.length) { L.warn({ guildId: interaction.guildId }, 'No channels to train from'); @@ -212,6 +213,7 @@ async function saveGuildMessageHistory( interface GenerateResponse { message?: Discord.MessageOptions; debug?: Discord.MessageOptions; + error?: Discord.MessageOptions; } /** @@ -220,7 +222,6 @@ interface GenerateResponse { * @param debug Sends debug info as a message if true. * @param tts If the message should be sent as TTS. Defaults to the TTS setting of the * invoking message. - * #v3-complete */ async function generateResponse( interaction: Discord.Message | Discord.CommandInteraction, @@ -230,7 +231,7 @@ async function generateResponse( L.debug('Responding...'); if (!interaction.guildId) { L.warn('Received an interaction without a guildId'); - return { message: { content: 'This action must be performed within a server.' } }; + return { message: { content: INVALID_GUILD_MESSAGE } }; } if (!interaction.channelId) { L.warn('Received an interaction without a channelId'); @@ -286,13 +287,31 @@ async function generateResponse( return responseMessages; } catch (err) { L.error(err); - if (debug) { - return { debug: { content: `\n\`\`\`\nERROR: ${err}\n\`\`\`` } }; - } - return {}; + return { error: { content: `\n\`\`\`\nERROR: ${err}\n\`\`\`` } }; } } +async function listValidChannels(interaction: Discord.CommandInteraction): Promise { + if (!interaction.guildId || !interaction.guild) return INVALID_GUILD_MESSAGE; + const channels = await getValidChannels(interaction.guild); + const channelText = channels.reduce((list, channel) => { + return `${list}\n • <#${channel.id}>`; + }, ''); + return `This bot is currently listening and learning from ${channels.length} channel(s).${channelText}`; +} + +function getChannelsFromInteraction( + interaction: Discord.CommandInteraction +): Discord.TextChannel[] { + const channels = Array.from(Array(CHANNEL_OPTIONS_MAX).keys()).map((index) => + interaction.options.getChannel(`channel-${index + 1}`, index === 0) + ); + const textChannels = channels.filter( + (c): c is Discord.TextChannel => c !== null && c instanceof Discord.TextChannel + ); + return textChannels; +} + function helpMessage(): Discord.MessageOptions { const avatarURL = client.user.avatarURL() || undefined; const embed = new Discord.MessageEmbed() @@ -345,12 +364,15 @@ client.on('ready', async (readyClient) => { await deployCommands(readyClient.user.id); - const guildsToSave = readyClient.guilds - .valueOf() - .map((guild) => Guild.create({ id: parseInt(guild.id, 10) })); + const guildsToSave = readyClient.guilds.valueOf().map((guild) => Guild.create({ id: guild.id })); await Guild.upsert(guildsToSave, ['id']); }); +client.on('guildCreate', async (guild) => { + L.info({ guildId: guild.id }, 'Adding new guild'); + await Guild.upsert(Guild.create({ id: guild.id }), ['id']); +}); + client.on('error', (err) => { L.error(err); }); @@ -372,6 +394,7 @@ client.on('messageCreate', async (message) => { const generatedResponse = await generateResponse(message); if (generatedResponse.message) await message.reply(generatedResponse.message); if (generatedResponse.debug) await message.reply(generatedResponse.debug); + if (generatedResponse.error) await message.reply(generatedResponse.error); } if (command === 'tts') { await generateResponse(message, false, true); @@ -380,8 +403,8 @@ client.on('messageCreate', async (message) => { await generateResponse(message, true); } if (command === null) { - L.debug('Listening...'); if (!message.author.bot) { + L.debug('Listening...'); const markov = await getMarkovByGuildId(message.channel.guildId); await markov.addData([messageToData(message)]); @@ -392,9 +415,6 @@ client.on('messageCreate', async (message) => { } }); -/** - * #v3-complete - */ client.on('messageDelete', async (message) => { if (message.author?.bot) return; L.info(`Deleting message ${message.id}`); @@ -405,9 +425,6 @@ client.on('messageDelete', async (message) => { await markov.removeData([message.content]); }); -/** - * #v3-complete - */ client.on('messageUpdate', async (oldMessage, newMessage) => { if (oldMessage.author?.bot) return; L.info(`Editing message ${oldMessage.id}`); @@ -422,7 +439,6 @@ client.on('messageUpdate', async (oldMessage, newMessage) => { client.on('interactionCreate', async (interaction) => { if (!interaction.isCommand()) return; - // Unprivileged commands if (interaction.commandName === helpCommand.name) { await interaction.reply(helpMessage()); } else if (interaction.commandName === inviteCommand.name) { @@ -433,12 +449,40 @@ client.on('interactionCreate', async (interaction) => { const debug = interaction.options.getBoolean('debug') || false; const generatedResponse = await generateResponse(interaction, debug, tts); if (generatedResponse.message) await interaction.editReply(generatedResponse.message); + else await interaction.deleteReply(); if (generatedResponse.debug) await interaction.followUp(generatedResponse.debug); - if (!Object.keys(generatedResponse).length) await interaction.deleteReply(); - } - // Privileged commands - if (interaction.commandName === listenChannelCommand.name) { - await interaction.reply('Pong!'); + if (generatedResponse.error) { + await interaction.followUp({ ...generatedResponse.error, ephemeral: true }); + } + } else if (interaction.commandName === listenChannelCommand.name) { + await interaction.deferReply(); + const subCommand = interaction.options.getSubcommand(true) as 'add' | 'remove' | 'list'; + if (subCommand === 'list') { + const reply = await listValidChannels(interaction); + await interaction.editReply(reply); + } else if (subCommand === 'add') { + if (!isModerator(interaction.member as any)) { + await interaction.deleteReply(); + await interaction.followUp({ content: INVALID_PERMISSIONS_MESSAGE, ephemeral: true }); + return; + } + const channels = getChannelsFromInteraction(interaction); + await addValidChannels(channels, interaction.guildId); + await interaction.editReply( + `Added ${channels.length} text channels to the list. Use \`/train\` to update the past known messages.` + ); + } else if (subCommand === 'remove') { + if (!isModerator(interaction.member as any)) { + await interaction.deleteReply(); + await interaction.followUp({ content: INVALID_PERMISSIONS_MESSAGE, ephemeral: true }); + return; + } + const channels = getChannelsFromInteraction(interaction); + await removeValidChannels(channels, interaction.guildId); + await interaction.editReply( + `Removed ${channels.length} text channels from the list. Use \`/train\` to remove these channels from the past known messages.` + ); + } } else if (interaction.commandName === trainCommand.name) { await interaction.deferReply(); const responseMessage = await saveGuildMessageHistory(interaction);