Support thread messages. Resolves #23.

Improved edit and delete support (switched to tags).
Remove embed when train command is finished.
This commit is contained in:
Charlie Laabs
2022-01-09 16:25:55 -06:00
parent 6b29681123
commit 457baee96d

View File

@@ -72,9 +72,25 @@ async function getMarkovByGuildId(guildId: string): Promise<Markov> {
return markov; return markov;
} }
async function isValidChannel(channelId: string): Promise<boolean> { /**
const channel = await Channel.findOne(channelId); * Returns a thread channels parent guild channel ID, otherwise it just returns a channel ID
return channel?.listen || false; */
function getGuildChannelId(channel: Discord.TextBasedChannel): string | null {
if (channel.isThread()) {
return channel.parentId;
}
return channel.id;
}
async function isValidChannel(channel: Discord.TextBasedChannel): Promise<boolean> {
const channelId = getGuildChannelId(channel);
if (!channelId) return false;
const dbChannel = await Channel.findOne(channelId);
return dbChannel?.listen || false;
}
function isHumanAuthoredMessage(message: Discord.Message | Discord.PartialMessage): boolean {
return !(message.author?.bot || message.system);
} }
async function getValidChannels(guild: Discord.Guild): Promise<Discord.TextChannel[]> { async function getValidChannels(guild: Discord.Guild): Promise<Discord.TextChannel[]> {
@@ -206,8 +222,10 @@ function messageToData(message: Discord.Message): AddDataProps {
let custom: MarkovDataCustom | undefined; let custom: MarkovDataCustom | undefined;
if (attachmentUrls.length) custom = { attachments: attachmentUrls }; if (attachmentUrls.length) custom = { attachments: attachmentUrls };
const tags: string[] = [message.id]; const tags: string[] = [message.id];
if (message.channelId) tags.push(message.channelId); if (message.channel.isThread()) tags.push(message.channelId); // Add thread channel ID
if (message.guildId) tags.push(message.guildId); const channelId = getGuildChannelId(message.channel);
if (channelId) tags.push(channelId); // Add guild channel ID
if (message.guildId) tags.push(message.guildId); // Add guild ID
return { return {
string: message.content, string: message.content,
custom, custom,
@@ -286,28 +304,79 @@ async function saveGuildMessageHistory(
const channelEta = makeEta({ autostart: true, min: 0, max: 1, historyTimeConstant: 30 }); const channelEta = makeEta({ autostart: true, min: 0, max: 1, historyTimeConstant: 30 });
while (keepGoing) { while (keepGoing) {
let messages; let allBatchMessages = new Discord.Collection<string, Discord.Message<boolean>>();
let channelBatchMessages: Discord.Collection<string, Discord.Message<boolean>>;
try { try {
// eslint-disable-next-line no-await-in-loop // eslint-disable-next-line no-await-in-loop
messages = await channel.messages.fetch({ channelBatchMessages = await channel.messages.fetch({
before: oldestMessageID, before: oldestMessageID,
limit: PAGE_SIZE, limit: PAGE_SIZE,
}); });
} catch (err) { } catch (err) {
L.error({ before: oldestMessageID, limit: PAGE_SIZE }, 'Error retreiving messages'); L.error({ before: oldestMessageID, limit: PAGE_SIZE }, 'Error retreiving messages');
L.error(err); L.error(err);
break; break; // Give up on this channel
} }
const nonBotMessageFormatted = messages.filter((elem) => !elem.author.bot).map(messageToData);
L.trace({ oldestMessageID }, `Saving ${nonBotMessageFormatted.length} messages`); // Gather any thread messages if present in this message batch
const threadChannels = channelBatchMessages
.filter((m) => m.hasThread)
.map((m) => m.thread)
.filter((c): c is Discord.ThreadChannel => c !== null);
if (threadChannels.length > 0) {
L.debug(`Found ${threadChannels.length} threads. Reading into them.`);
// eslint-disable-next-line no-restricted-syntax
for (const threadChannel of threadChannels) {
let oldestThreadMessageID: string | undefined;
let keepGoingThread = true;
L.debug({ channelId: threadChannel.id }, `Training from thread`);
while (keepGoingThread) {
let threadBatchMessages: Discord.Collection<string, Discord.Message<boolean>>;
try {
// eslint-disable-next-line no-await-in-loop // eslint-disable-next-line no-await-in-loop
await markov.addData(nonBotMessageFormatted); threadBatchMessages = await threadChannel.messages.fetch({
before: oldestThreadMessageID,
limit: PAGE_SIZE,
});
} catch (err) {
L.error(
{ before: oldestThreadMessageID, limit: PAGE_SIZE },
'Error retreiving thread messages'
);
L.error(err);
break; // Give up on this thread
}
L.trace(
{ threadMessagesCount: threadBatchMessages.size },
`Found some thread messages`
);
const lastThreadMessage = threadBatchMessages.last();
allBatchMessages = allBatchMessages.concat(threadBatchMessages); // Add the thread messages to this message batch to be included in later processing
if (!lastThreadMessage || threadBatchMessages.size < PAGE_SIZE) {
keepGoingThread = false;
} else {
oldestThreadMessageID = lastThreadMessage.id;
}
}
}
}
allBatchMessages = allBatchMessages.concat(channelBatchMessages);
// Filter and data map messages to be ready for addition to the corpus
const humanAuthoredMessages = allBatchMessages
.filter((m) => isHumanAuthoredMessage(m))
.map(messageToData);
L.trace({ oldestMessageID }, `Saving ${humanAuthoredMessages.length} messages`);
// eslint-disable-next-line no-await-in-loop
await markov.addData(humanAuthoredMessages);
L.trace('Finished saving messages'); L.trace('Finished saving messages');
messagesCount += nonBotMessageFormatted.length; messagesCount += humanAuthoredMessages.length;
const lastMessage = messages.last(); const lastMessage = channelBatchMessages.last();
// Update tracking metrics // Update tracking metrics
if (!lastMessage || messages.size < PAGE_SIZE) { if (!lastMessage || channelBatchMessages.size < PAGE_SIZE) {
keepGoing = false; keepGoing = false;
if (completedChannelsField.value === 'None') completedChannelsField.value = ''; if (completedChannelsField.value === 'None') completedChannelsField.value = '';
completedChannelsField.value += `\n • <#${channel.id}>`; completedChannelsField.value += `\n • <#${channel.id}>`;
@@ -315,7 +384,7 @@ async function saveGuildMessageHistory(
oldestMessageID = lastMessage.id; oldestMessageID = lastMessage.id;
} }
currentChannelField.value = `<#${channel.id}>`; currentChannelField.value = `<#${channel.id}>`;
if (!firstMessageDate) firstMessageDate = messages.first()?.createdTimestamp; if (!firstMessageDate) firstMessageDate = channelBatchMessages.first()?.createdTimestamp;
const oldestMessageDate = lastMessage?.createdTimestamp; const oldestMessageDate = lastMessage?.createdTimestamp;
if (firstMessageDate && oldestMessageDate) { if (firstMessageDate && oldestMessageDate) {
const channelAge = firstMessageDate - channelCreateDate; const channelAge = firstMessageDate - channelCreateDate;
@@ -378,10 +447,6 @@ async function generateResponse(
L.warn('Received an interaction without a guildId'); L.warn('Received an interaction without a guildId');
return { error: { content: INVALID_GUILD_MESSAGE } }; return { error: { content: INVALID_GUILD_MESSAGE } };
} }
if (!interaction.channelId) {
L.warn('Received an interaction without a channelId');
return { error: { content: 'This action must be performed within a text channel.' } };
}
if (!isAllowedUser(interaction.member)) { if (!isAllowedUser(interaction.member)) {
L.info('Member does not have permissions to generate a response'); L.info('Member does not have permissions to generate a response');
return { error: { content: INVALID_PERMISSIONS_MESSAGE } }; return { error: { content: INVALID_PERMISSIONS_MESSAGE } };
@@ -573,7 +638,14 @@ client.on('warn', (m) => L.warn(m));
client.on('error', (m) => L.error(m)); client.on('error', (m) => L.error(m));
client.on('messageCreate', async (message) => { client.on('messageCreate', async (message) => {
if (!(message.guild && message.channel instanceof Discord.TextChannel)) return; if (
!(
message.guild &&
(message.channel instanceof Discord.TextChannel ||
message.channel instanceof Discord.ThreadChannel)
)
)
return;
const command = validateMessage(message); const command = validateMessage(message);
if (command !== null) L.info({ command }, 'Recieved message command'); if (command !== null) L.info({ command }, 'Recieved message command');
if (command === 'help') { if (command === 'help') {
@@ -602,7 +674,7 @@ client.on('messageCreate', async (message) => {
await handleResponseMessage(generatedResponse, message); await handleResponseMessage(generatedResponse, message);
} }
if (command === null) { if (command === null) {
if (!message.author.bot) { if (isHumanAuthoredMessage(message)) {
if (client.user && message.mentions.has(client.user)) { if (client.user && message.mentions.has(client.user)) {
L.debug('Responding to mention'); L.debug('Responding to mention');
// <@!278354154563567636> how are you doing? // <@!278354154563567636> how are you doing?
@@ -611,7 +683,7 @@ client.on('messageCreate', async (message) => {
await handleResponseMessage(generatedResponse, message); await handleResponseMessage(generatedResponse, message);
} }
if (await isValidChannel(message.channelId)) { if (await isValidChannel(message.channel)) {
L.debug('Listening'); L.debug('Listening');
const markov = await getMarkovByGuildId(message.channel.guildId); const markov = await getMarkovByGuildId(message.channel.guildId);
await markov.addData([messageToData(message)]); await markov.addData([messageToData(message)]);
@@ -621,26 +693,35 @@ client.on('messageCreate', async (message) => {
}); });
client.on('messageDelete', async (message) => { client.on('messageDelete', async (message) => {
if (message.author?.bot) return; if (!isHumanAuthoredMessage(message)) return;
if (!(await isValidChannel(message.channelId))) return; if (!(await isValidChannel(message.channel))) return;
if (!(message.guildId && message.content)) return; if (!message.guildId) return;
L.debug(`Deleting message ${message.id}`); L.debug(`Deleting message ${message.id}`);
const markov = await getMarkovByGuildId(message.guildId); const markov = await getMarkovByGuildId(message.guildId);
await markov.removeStrings([message.content]); await markov.removeTags([message.id]);
}); });
client.on('messageUpdate', async (oldMessage, newMessage) => { client.on('messageUpdate', async (oldMessage, newMessage) => {
if (oldMessage.author?.bot) return; if (!isHumanAuthoredMessage(oldMessage)) return;
if (!(await isValidChannel(oldMessage.channelId))) return; if (!(await isValidChannel(oldMessage.channel))) return;
if (!(oldMessage.guildId && oldMessage.content && newMessage.content)) return; if (!(oldMessage.guildId && newMessage.content)) return;
L.debug(`Editing message ${oldMessage.id}`); L.debug(`Editing message ${oldMessage.id}`);
const markov = await getMarkovByGuildId(oldMessage.guildId); const markov = await getMarkovByGuildId(oldMessage.guildId);
await markov.removeStrings([oldMessage.content]); await markov.removeTags([oldMessage.id]);
await markov.addData([newMessage.content]); await markov.addData([newMessage.content]);
}); });
client.on('threadDelete', async (thread) => {
if (!(await isValidChannel(thread))) return;
if (!thread.guildId) return;
L.debug(`Deleting thread messages ${thread.id}`);
const markov = await getMarkovByGuildId(thread.guildId);
await markov.removeTags([thread.id]);
});
// eslint-disable-next-line consistent-return // eslint-disable-next-line consistent-return
client.on('interactionCreate', async (interaction) => { client.on('interactionCreate', async (interaction) => {
if (interaction.isCommand()) { if (interaction.isCommand()) {
@@ -725,7 +806,7 @@ client.on('interactionCreate', async (interaction) => {
} else if (interaction.commandName === trainCommand.name) { } else if (interaction.commandName === trainCommand.name) {
await interaction.deferReply(); await interaction.deferReply();
const responseMessage = await saveGuildMessageHistory(interaction); const responseMessage = await saveGuildMessageHistory(interaction);
await interaction.editReply({ content: responseMessage }); await interaction.editReply({ content: responseMessage, embeds: [] });
} }
} else if (interaction.isSelectMenu()) { } else if (interaction.isSelectMenu()) {
if (interaction.customId === 'listen-modify-select') { if (interaction.customId === 'listen-modify-select') {