fix: Change enable_sites command to use string parameter instead of *args for hybrid command compatibility

This commit is contained in:
pacnpal
2024-11-15 00:07:57 +00:00
parent fc40e994fe
commit a373c455a9

View File

@@ -154,26 +154,27 @@ class VideoArchiverCommands(commands.Cog):
await ctx.send(f"Archive message template set to:\n{template}") await ctx.send(f"Archive message template set to:\n{template}")
@videoarchiver.command(name="enablesites") @videoarchiver.command(name="enablesites")
async def enable_sites(self, ctx: commands.Context, *sites: str): async def enable_sites(self, ctx: commands.Context, *, sites: Optional[str] = None):
"""Enable specific sites (leave empty for all sites)""" """Enable specific sites (leave empty for all sites). Separate multiple sites with spaces."""
sites = [s.lower() for s in sites] if sites is None:
if not sites:
await self.config.update_setting(ctx.guild.id, "enabled_sites", []) await self.config.update_setting(ctx.guild.id, "enabled_sites", [])
await ctx.send("All sites enabled") await ctx.send("All sites enabled")
return return
site_list = [s.strip().lower() for s in sites.split()]
# Verify sites are valid # Verify sites are valid
with yt_dlp.YoutubeDL() as ydl: with yt_dlp.YoutubeDL() as ydl:
valid_sites = set(ie.IE_NAME.lower() for ie in ydl._ies) valid_sites = set(ie.IE_NAME.lower() for ie in ydl._ies)
invalid_sites = [s for s in sites if s not in valid_sites] invalid_sites = [s for s in site_list if s not in valid_sites]
if invalid_sites: if invalid_sites:
await ctx.send( await ctx.send(
f"Invalid sites: {', '.join(invalid_sites)}\nValid sites: {', '.join(valid_sites)}" f"Invalid sites: {', '.join(invalid_sites)}\nValid sites: {', '.join(valid_sites)}"
) )
return return
await self.config.update_setting(ctx.guild.id, "enabled_sites", sites) await self.config.update_setting(ctx.guild.id, "enabled_sites", site_list)
await ctx.send(f"Enabled sites: {', '.join(sites)}") await ctx.send(f"Enabled sites: {', '.join(site_list)}")
@videoarchiver.command(name="listsites") @videoarchiver.command(name="listsites")
async def list_sites(self, ctx: commands.Context): async def list_sites(self, ctx: commands.Context):