Support thread messages. Resolves #23.

Improved edit and delete support (switched to tags).
Remove embed when train command is finished.
This commit is contained in:
Charlie Laabs
2022-01-09 16:25:55 -06:00
parent 6b29681123
commit 457baee96d

View File

@@ -72,9 +72,25 @@ async function getMarkovByGuildId(guildId: string): Promise<Markov> {
return markov;
}
async function isValidChannel(channelId: string): Promise<boolean> {
const channel = await Channel.findOne(channelId);
return channel?.listen || false;
/**
* Returns a thread channels parent guild channel ID, otherwise it just returns a channel ID
*/
function getGuildChannelId(channel: Discord.TextBasedChannel): string | null {
if (channel.isThread()) {
return channel.parentId;
}
return channel.id;
}
async function isValidChannel(channel: Discord.TextBasedChannel): Promise<boolean> {
const channelId = getGuildChannelId(channel);
if (!channelId) return false;
const dbChannel = await Channel.findOne(channelId);
return dbChannel?.listen || false;
}
function isHumanAuthoredMessage(message: Discord.Message | Discord.PartialMessage): boolean {
return !(message.author?.bot || message.system);
}
async function getValidChannels(guild: Discord.Guild): Promise<Discord.TextChannel[]> {
@@ -206,8 +222,10 @@ function messageToData(message: Discord.Message): AddDataProps {
let custom: MarkovDataCustom | undefined;
if (attachmentUrls.length) custom = { attachments: attachmentUrls };
const tags: string[] = [message.id];
if (message.channelId) tags.push(message.channelId);
if (message.guildId) tags.push(message.guildId);
if (message.channel.isThread()) tags.push(message.channelId); // Add thread channel ID
const channelId = getGuildChannelId(message.channel);
if (channelId) tags.push(channelId); // Add guild channel ID
if (message.guildId) tags.push(message.guildId); // Add guild ID
return {
string: message.content,
custom,
@@ -286,28 +304,79 @@ async function saveGuildMessageHistory(
const channelEta = makeEta({ autostart: true, min: 0, max: 1, historyTimeConstant: 30 });
while (keepGoing) {
let messages;
let allBatchMessages = new Discord.Collection<string, Discord.Message<boolean>>();
let channelBatchMessages: Discord.Collection<string, Discord.Message<boolean>>;
try {
// eslint-disable-next-line no-await-in-loop
messages = await channel.messages.fetch({
channelBatchMessages = await channel.messages.fetch({
before: oldestMessageID,
limit: PAGE_SIZE,
});
} catch (err) {
L.error({ before: oldestMessageID, limit: PAGE_SIZE }, 'Error retreiving messages');
L.error(err);
break;
break; // Give up on this channel
}
const nonBotMessageFormatted = messages.filter((elem) => !elem.author.bot).map(messageToData);
L.trace({ oldestMessageID }, `Saving ${nonBotMessageFormatted.length} messages`);
// Gather any thread messages if present in this message batch
const threadChannels = channelBatchMessages
.filter((m) => m.hasThread)
.map((m) => m.thread)
.filter((c): c is Discord.ThreadChannel => c !== null);
if (threadChannels.length > 0) {
L.debug(`Found ${threadChannels.length} threads. Reading into them.`);
// eslint-disable-next-line no-restricted-syntax
for (const threadChannel of threadChannels) {
let oldestThreadMessageID: string | undefined;
let keepGoingThread = true;
L.debug({ channelId: threadChannel.id }, `Training from thread`);
while (keepGoingThread) {
let threadBatchMessages: Discord.Collection<string, Discord.Message<boolean>>;
try {
// eslint-disable-next-line no-await-in-loop
threadBatchMessages = await threadChannel.messages.fetch({
before: oldestThreadMessageID,
limit: PAGE_SIZE,
});
} catch (err) {
L.error(
{ before: oldestThreadMessageID, limit: PAGE_SIZE },
'Error retreiving thread messages'
);
L.error(err);
break; // Give up on this thread
}
L.trace(
{ threadMessagesCount: threadBatchMessages.size },
`Found some thread messages`
);
const lastThreadMessage = threadBatchMessages.last();
allBatchMessages = allBatchMessages.concat(threadBatchMessages); // Add the thread messages to this message batch to be included in later processing
if (!lastThreadMessage || threadBatchMessages.size < PAGE_SIZE) {
keepGoingThread = false;
} else {
oldestThreadMessageID = lastThreadMessage.id;
}
}
}
}
allBatchMessages = allBatchMessages.concat(channelBatchMessages);
// Filter and data map messages to be ready for addition to the corpus
const humanAuthoredMessages = allBatchMessages
.filter((m) => isHumanAuthoredMessage(m))
.map(messageToData);
L.trace({ oldestMessageID }, `Saving ${humanAuthoredMessages.length} messages`);
// eslint-disable-next-line no-await-in-loop
await markov.addData(nonBotMessageFormatted);
await markov.addData(humanAuthoredMessages);
L.trace('Finished saving messages');
messagesCount += nonBotMessageFormatted.length;
const lastMessage = messages.last();
messagesCount += humanAuthoredMessages.length;
const lastMessage = channelBatchMessages.last();
// Update tracking metrics
if (!lastMessage || messages.size < PAGE_SIZE) {
if (!lastMessage || channelBatchMessages.size < PAGE_SIZE) {
keepGoing = false;
if (completedChannelsField.value === 'None') completedChannelsField.value = '';
completedChannelsField.value += `\n • <#${channel.id}>`;
@@ -315,7 +384,7 @@ async function saveGuildMessageHistory(
oldestMessageID = lastMessage.id;
}
currentChannelField.value = `<#${channel.id}>`;
if (!firstMessageDate) firstMessageDate = messages.first()?.createdTimestamp;
if (!firstMessageDate) firstMessageDate = channelBatchMessages.first()?.createdTimestamp;
const oldestMessageDate = lastMessage?.createdTimestamp;
if (firstMessageDate && oldestMessageDate) {
const channelAge = firstMessageDate - channelCreateDate;
@@ -378,10 +447,6 @@ async function generateResponse(
L.warn('Received an interaction without a guildId');
return { error: { content: INVALID_GUILD_MESSAGE } };
}
if (!interaction.channelId) {
L.warn('Received an interaction without a channelId');
return { error: { content: 'This action must be performed within a text channel.' } };
}
if (!isAllowedUser(interaction.member)) {
L.info('Member does not have permissions to generate a response');
return { error: { content: INVALID_PERMISSIONS_MESSAGE } };
@@ -573,7 +638,14 @@ client.on('warn', (m) => L.warn(m));
client.on('error', (m) => L.error(m));
client.on('messageCreate', async (message) => {
if (!(message.guild && message.channel instanceof Discord.TextChannel)) return;
if (
!(
message.guild &&
(message.channel instanceof Discord.TextChannel ||
message.channel instanceof Discord.ThreadChannel)
)
)
return;
const command = validateMessage(message);
if (command !== null) L.info({ command }, 'Recieved message command');
if (command === 'help') {
@@ -602,7 +674,7 @@ client.on('messageCreate', async (message) => {
await handleResponseMessage(generatedResponse, message);
}
if (command === null) {
if (!message.author.bot) {
if (isHumanAuthoredMessage(message)) {
if (client.user && message.mentions.has(client.user)) {
L.debug('Responding to mention');
// <@!278354154563567636> how are you doing?
@@ -611,7 +683,7 @@ client.on('messageCreate', async (message) => {
await handleResponseMessage(generatedResponse, message);
}
if (await isValidChannel(message.channelId)) {
if (await isValidChannel(message.channel)) {
L.debug('Listening');
const markov = await getMarkovByGuildId(message.channel.guildId);
await markov.addData([messageToData(message)]);
@@ -621,26 +693,35 @@ client.on('messageCreate', async (message) => {
});
client.on('messageDelete', async (message) => {
if (message.author?.bot) return;
if (!(await isValidChannel(message.channelId))) return;
if (!(message.guildId && message.content)) return;
if (!isHumanAuthoredMessage(message)) return;
if (!(await isValidChannel(message.channel))) return;
if (!message.guildId) return;
L.debug(`Deleting message ${message.id}`);
const markov = await getMarkovByGuildId(message.guildId);
await markov.removeStrings([message.content]);
await markov.removeTags([message.id]);
});
client.on('messageUpdate', async (oldMessage, newMessage) => {
if (oldMessage.author?.bot) return;
if (!(await isValidChannel(oldMessage.channelId))) return;
if (!(oldMessage.guildId && oldMessage.content && newMessage.content)) return;
if (!isHumanAuthoredMessage(oldMessage)) return;
if (!(await isValidChannel(oldMessage.channel))) return;
if (!(oldMessage.guildId && newMessage.content)) return;
L.debug(`Editing message ${oldMessage.id}`);
const markov = await getMarkovByGuildId(oldMessage.guildId);
await markov.removeStrings([oldMessage.content]);
await markov.removeTags([oldMessage.id]);
await markov.addData([newMessage.content]);
});
client.on('threadDelete', async (thread) => {
if (!(await isValidChannel(thread))) return;
if (!thread.guildId) return;
L.debug(`Deleting thread messages ${thread.id}`);
const markov = await getMarkovByGuildId(thread.guildId);
await markov.removeTags([thread.id]);
});
// eslint-disable-next-line consistent-return
client.on('interactionCreate', async (interaction) => {
if (interaction.isCommand()) {
@@ -725,7 +806,7 @@ client.on('interactionCreate', async (interaction) => {
} else if (interaction.commandName === trainCommand.name) {
await interaction.deferReply();
const responseMessage = await saveGuildMessageHistory(interaction);
await interaction.editReply({ content: responseMessage });
await interaction.editReply({ content: responseMessage, embeds: [] });
}
} else if (interaction.isSelectMenu()) {
if (interaction.customId === 'listen-modify-select') {