diff --git a/src/index.ts b/src/index.ts index 3adf757..6595540 100644 --- a/src/index.ts +++ b/src/index.ts @@ -72,9 +72,25 @@ async function getMarkovByGuildId(guildId: string): Promise { return markov; } -async function isValidChannel(channelId: string): Promise { - const channel = await Channel.findOne(channelId); - return channel?.listen || false; +/** + * Returns a thread channels parent guild channel ID, otherwise it just returns a channel ID + */ +function getGuildChannelId(channel: Discord.TextBasedChannel): string | null { + if (channel.isThread()) { + return channel.parentId; + } + return channel.id; +} + +async function isValidChannel(channel: Discord.TextBasedChannel): Promise { + const channelId = getGuildChannelId(channel); + if (!channelId) return false; + const dbChannel = await Channel.findOne(channelId); + return dbChannel?.listen || false; +} + +function isHumanAuthoredMessage(message: Discord.Message | Discord.PartialMessage): boolean { + return !(message.author?.bot || message.system); } async function getValidChannels(guild: Discord.Guild): Promise { @@ -206,8 +222,10 @@ function messageToData(message: Discord.Message): AddDataProps { let custom: MarkovDataCustom | undefined; if (attachmentUrls.length) custom = { attachments: attachmentUrls }; const tags: string[] = [message.id]; - if (message.channelId) tags.push(message.channelId); - if (message.guildId) tags.push(message.guildId); + if (message.channel.isThread()) tags.push(message.channelId); // Add thread channel ID + const channelId = getGuildChannelId(message.channel); + if (channelId) tags.push(channelId); // Add guild channel ID + if (message.guildId) tags.push(message.guildId); // Add guild ID return { string: message.content, custom, @@ -286,28 +304,79 @@ async function saveGuildMessageHistory( const channelEta = makeEta({ autostart: true, min: 0, max: 1, historyTimeConstant: 30 }); while (keepGoing) { - let messages; + let allBatchMessages = new Discord.Collection>(); + let channelBatchMessages: Discord.Collection>; try { // eslint-disable-next-line no-await-in-loop - messages = await channel.messages.fetch({ + channelBatchMessages = await channel.messages.fetch({ before: oldestMessageID, limit: PAGE_SIZE, }); } catch (err) { L.error({ before: oldestMessageID, limit: PAGE_SIZE }, 'Error retreiving messages'); L.error(err); - break; + break; // Give up on this channel } - const nonBotMessageFormatted = messages.filter((elem) => !elem.author.bot).map(messageToData); - L.trace({ oldestMessageID }, `Saving ${nonBotMessageFormatted.length} messages`); + + // Gather any thread messages if present in this message batch + const threadChannels = channelBatchMessages + .filter((m) => m.hasThread) + .map((m) => m.thread) + .filter((c): c is Discord.ThreadChannel => c !== null); + if (threadChannels.length > 0) { + L.debug(`Found ${threadChannels.length} threads. Reading into them.`); + // eslint-disable-next-line no-restricted-syntax + for (const threadChannel of threadChannels) { + let oldestThreadMessageID: string | undefined; + let keepGoingThread = true; + L.debug({ channelId: threadChannel.id }, `Training from thread`); + + while (keepGoingThread) { + let threadBatchMessages: Discord.Collection>; + try { + // eslint-disable-next-line no-await-in-loop + threadBatchMessages = await threadChannel.messages.fetch({ + before: oldestThreadMessageID, + limit: PAGE_SIZE, + }); + } catch (err) { + L.error( + { before: oldestThreadMessageID, limit: PAGE_SIZE }, + 'Error retreiving thread messages' + ); + L.error(err); + break; // Give up on this thread + } + L.trace( + { threadMessagesCount: threadBatchMessages.size }, + `Found some thread messages` + ); + const lastThreadMessage = threadBatchMessages.last(); + allBatchMessages = allBatchMessages.concat(threadBatchMessages); // Add the thread messages to this message batch to be included in later processing + if (!lastThreadMessage || threadBatchMessages.size < PAGE_SIZE) { + keepGoingThread = false; + } else { + oldestThreadMessageID = lastThreadMessage.id; + } + } + } + } + + allBatchMessages = allBatchMessages.concat(channelBatchMessages); + + // Filter and data map messages to be ready for addition to the corpus + const humanAuthoredMessages = allBatchMessages + .filter((m) => isHumanAuthoredMessage(m)) + .map(messageToData); + L.trace({ oldestMessageID }, `Saving ${humanAuthoredMessages.length} messages`); // eslint-disable-next-line no-await-in-loop - await markov.addData(nonBotMessageFormatted); + await markov.addData(humanAuthoredMessages); L.trace('Finished saving messages'); - messagesCount += nonBotMessageFormatted.length; - const lastMessage = messages.last(); + messagesCount += humanAuthoredMessages.length; + const lastMessage = channelBatchMessages.last(); // Update tracking metrics - if (!lastMessage || messages.size < PAGE_SIZE) { + if (!lastMessage || channelBatchMessages.size < PAGE_SIZE) { keepGoing = false; if (completedChannelsField.value === 'None') completedChannelsField.value = ''; completedChannelsField.value += `\n • <#${channel.id}>`; @@ -315,7 +384,7 @@ async function saveGuildMessageHistory( oldestMessageID = lastMessage.id; } currentChannelField.value = `<#${channel.id}>`; - if (!firstMessageDate) firstMessageDate = messages.first()?.createdTimestamp; + if (!firstMessageDate) firstMessageDate = channelBatchMessages.first()?.createdTimestamp; const oldestMessageDate = lastMessage?.createdTimestamp; if (firstMessageDate && oldestMessageDate) { const channelAge = firstMessageDate - channelCreateDate; @@ -378,10 +447,6 @@ async function generateResponse( L.warn('Received an interaction without a guildId'); return { error: { content: INVALID_GUILD_MESSAGE } }; } - if (!interaction.channelId) { - L.warn('Received an interaction without a channelId'); - return { error: { content: 'This action must be performed within a text channel.' } }; - } if (!isAllowedUser(interaction.member)) { L.info('Member does not have permissions to generate a response'); return { error: { content: INVALID_PERMISSIONS_MESSAGE } }; @@ -573,7 +638,14 @@ client.on('warn', (m) => L.warn(m)); client.on('error', (m) => L.error(m)); client.on('messageCreate', async (message) => { - if (!(message.guild && message.channel instanceof Discord.TextChannel)) return; + if ( + !( + message.guild && + (message.channel instanceof Discord.TextChannel || + message.channel instanceof Discord.ThreadChannel) + ) + ) + return; const command = validateMessage(message); if (command !== null) L.info({ command }, 'Recieved message command'); if (command === 'help') { @@ -602,7 +674,7 @@ client.on('messageCreate', async (message) => { await handleResponseMessage(generatedResponse, message); } if (command === null) { - if (!message.author.bot) { + if (isHumanAuthoredMessage(message)) { if (client.user && message.mentions.has(client.user)) { L.debug('Responding to mention'); // <@!278354154563567636> how are you doing? @@ -611,7 +683,7 @@ client.on('messageCreate', async (message) => { await handleResponseMessage(generatedResponse, message); } - if (await isValidChannel(message.channelId)) { + if (await isValidChannel(message.channel)) { L.debug('Listening'); const markov = await getMarkovByGuildId(message.channel.guildId); await markov.addData([messageToData(message)]); @@ -621,26 +693,35 @@ client.on('messageCreate', async (message) => { }); client.on('messageDelete', async (message) => { - if (message.author?.bot) return; - if (!(await isValidChannel(message.channelId))) return; - if (!(message.guildId && message.content)) return; + if (!isHumanAuthoredMessage(message)) return; + if (!(await isValidChannel(message.channel))) return; + if (!message.guildId) return; L.debug(`Deleting message ${message.id}`); const markov = await getMarkovByGuildId(message.guildId); - await markov.removeStrings([message.content]); + await markov.removeTags([message.id]); }); client.on('messageUpdate', async (oldMessage, newMessage) => { - if (oldMessage.author?.bot) return; - if (!(await isValidChannel(oldMessage.channelId))) return; - if (!(oldMessage.guildId && oldMessage.content && newMessage.content)) return; + if (!isHumanAuthoredMessage(oldMessage)) return; + if (!(await isValidChannel(oldMessage.channel))) return; + if (!(oldMessage.guildId && newMessage.content)) return; L.debug(`Editing message ${oldMessage.id}`); const markov = await getMarkovByGuildId(oldMessage.guildId); - await markov.removeStrings([oldMessage.content]); + await markov.removeTags([oldMessage.id]); await markov.addData([newMessage.content]); }); +client.on('threadDelete', async (thread) => { + if (!(await isValidChannel(thread))) return; + if (!thread.guildId) return; + + L.debug(`Deleting thread messages ${thread.id}`); + const markov = await getMarkovByGuildId(thread.guildId); + await markov.removeTags([thread.id]); +}); + // eslint-disable-next-line consistent-return client.on('interactionCreate', async (interaction) => { if (interaction.isCommand()) { @@ -725,7 +806,7 @@ client.on('interactionCreate', async (interaction) => { } else if (interaction.commandName === trainCommand.name) { await interaction.deferReply(); const responseMessage = await saveGuildMessageHistory(interaction); - await interaction.editReply({ content: responseMessage }); + await interaction.editReply({ content: responseMessage, embeds: [] }); } } else if (interaction.isSelectMenu()) { if (interaction.customId === 'listen-modify-select') {