This commit is contained in:
pacnpal
2024-11-17 01:03:45 +00:00
parent d2576df988
commit 2557978cf3
6 changed files with 143 additions and 451 deletions

View File

@@ -8,7 +8,7 @@ from typing import Optional, Dict, Any, List, Tuple, Set, TypedDict, ClassVar, C
from datetime import datetime
import discord
from ..utils.progress_tracker import ProgressTracker
from ..utils import progress_tracker
from ..database.video_archive_db import VideoArchiveDB
from ..utils.download_manager import DownloadManager
from ..utils.message_manager import MessageManager
@@ -57,7 +57,6 @@ class QueueHandler:
self._unloading = False
self._active_downloads: Dict[str, asyncio.Task] = {}
self._active_downloads_lock = asyncio.Lock()
self.progress_tracker = ProgressTracker()
self._stats: QueueStats = {
"active_downloads": 0,
"processing_items": 0,
@@ -71,13 +70,13 @@ class QueueHandler:
async def process_video(self, item: QueueItem) -> Tuple[bool, Optional[str]]:
"""
Process a video from the queue.
Args:
item: Queue item to process
Returns:
Tuple of (success, error_message)
Raises:
QueueHandlerError: If there's an error during processing
"""
@@ -105,12 +104,16 @@ class QueueHandler:
message_manager = components.get("message_manager")
if not downloader or not message_manager:
raise QueueHandlerError(f"Missing required components for guild {item.guild_id}")
raise QueueHandlerError(
f"Missing required components for guild {item.guild_id}"
)
# Get original message and update reactions
original_message = await self._get_original_message(item)
if original_message:
await self._update_message_reactions(original_message, QueueItemStatus.PROCESSING)
await self._update_message_reactions(
original_message, QueueItemStatus.PROCESSING
)
# Download and archive video
file_path = await self._process_video_file(
@@ -121,7 +124,9 @@ class QueueHandler:
self._update_stats(True, start_time)
item.finish_processing(True)
if original_message:
await self._update_message_reactions(original_message, QueueItemStatus.COMPLETED)
await self._update_message_reactions(
original_message, QueueItemStatus.COMPLETED
)
return True, None
except QueueHandlerError as e:
@@ -143,7 +148,9 @@ class QueueHandler:
if self.db.is_url_archived(item.url):
logger.info(f"Video already archived: {item.url}")
if original_message := await self._get_original_message(item):
await self._update_message_reactions(original_message, QueueItemStatus.COMPLETED)
await self._update_message_reactions(
original_message, QueueItemStatus.COMPLETED
)
archived_info = self.db.get_archived_video(item.url)
if archived_info:
await original_message.reply(
@@ -153,10 +160,7 @@ class QueueHandler:
return True
return False
async def _get_components(
self,
guild_id: int
) -> Dict[str, Any]:
async def _get_components(self, guild_id: int) -> Dict[str, Any]:
"""Get required components for processing"""
if guild_id not in self.components:
raise QueueHandlerError(f"No components found for guild {guild_id}")
@@ -167,7 +171,7 @@ class QueueHandler:
downloader: DownloadManager,
message_manager: MessageManager,
item: QueueItem,
original_message: Optional[discord.Message]
original_message: Optional[discord.Message],
) -> Optional[str]:
"""Download and process video file"""
# Create progress callback
@@ -182,11 +186,7 @@ class QueueHandler:
# Archive video
success, error = await self._archive_video(
item.guild_id,
original_message,
message_manager,
item.url,
file_path
item.guild_id, original_message, message_manager, item.url, file_path
)
if not success:
raise QueueHandlerError(f"Failed to archive video: {error}")
@@ -194,16 +194,15 @@ class QueueHandler:
return file_path
def _handle_processing_error(
self,
item: QueueItem,
message: Optional[discord.Message],
error: str
self, item: QueueItem, message: Optional[discord.Message], error: str
) -> None:
"""Handle processing error"""
self._update_stats(False, datetime.utcnow())
item.finish_processing(False, error)
if message:
asyncio.create_task(self._update_message_reactions(message, QueueItemStatus.FAILED))
asyncio.create_task(
self._update_message_reactions(message, QueueItemStatus.FAILED)
)
def _update_stats(self, success: bool, start_time: datetime) -> None:
"""Update queue statistics"""
@@ -213,19 +212,19 @@ class QueueHandler:
self._stats["completed_items"] += 1
else:
self._stats["failed_items"] += 1
# Update average processing time
total_items = self._stats["completed_items"] + self._stats["failed_items"]
if total_items > 0:
current_total = self._stats["average_processing_time"] * (total_items - 1)
self._stats["average_processing_time"] = (current_total + processing_time) / total_items
self._stats["average_processing_time"] = (
current_total + processing_time
) / total_items
self._stats["last_processed"] = datetime.utcnow().isoformat()
async def _update_message_reactions(
self,
message: discord.Message,
status: QueueItemStatus
self, message: discord.Message, status: QueueItemStatus
) -> None:
"""Update message reactions based on status"""
try:
@@ -234,7 +233,7 @@ class QueueHandler:
REACTIONS["queued"],
REACTIONS["processing"],
REACTIONS["success"],
REACTIONS["error"]
REACTIONS["error"],
]:
try:
await message.remove_reaction(reaction, self.bot.user)
@@ -265,21 +264,21 @@ class QueueHandler:
original_message: Optional[discord.Message],
message_manager: MessageManager,
url: str,
file_path: str
file_path: str,
) -> Tuple[bool, Optional[str]]:
"""
Archive downloaded video.
Args:
guild_id: Discord guild ID
original_message: Original message containing the video
message_manager: Message manager instance
url: Video URL
file_path: Path to downloaded video file
Returns:
Tuple of (success, error_message)
Raises:
QueueHandlerError: If archiving fails
"""
@@ -308,19 +307,14 @@ class QueueHandler:
raise QueueHandlerError("Processed file not found")
archive_message = await archive_channel.send(
content=message,
file=discord.File(file_path)
content=message, file=discord.File(file_path)
)
# Store in database if available
if self.db and archive_message.attachments:
discord_url = archive_message.attachments[0].url
self.db.add_archived_video(
url,
discord_url,
archive_message.id,
archive_channel.id,
guild_id
url, discord_url, archive_message.id, archive_channel.id, guild_id
)
logger.info(f"Added video to archive database: {url} -> {discord_url}")
@@ -333,16 +327,13 @@ class QueueHandler:
logger.error(f"Failed to archive video: {str(e)}")
raise QueueHandlerError(f"Failed to archive video: {str(e)}")
async def _get_original_message(
self,
item: QueueItem
) -> Optional[discord.Message]:
async def _get_original_message(self, item: QueueItem) -> Optional[discord.Message]:
"""
Retrieve the original message.
Args:
item: Queue item containing message details
Returns:
Original Discord message or None if not found
"""
@@ -358,57 +349,61 @@ class QueueHandler:
return None
def _create_progress_callback(
self,
message: Optional[discord.Message],
url: str
self, message: Optional[discord.Message], url: str
) -> Callable[[float], None]:
"""
Create progress callback function for download tracking.
Args:
message: Discord message to update with progress
url: URL being downloaded
Returns:
Callback function for progress updates
"""
def progress_callback(progress: float) -> None:
if message:
try:
loop = asyncio.get_running_loop()
if not loop.is_running():
logger.warning("Event loop is not running, skipping progress update")
logger.warning(
"Event loop is not running, skipping progress update"
)
return
# Update progress tracking
self.progress_tracker.update_download_progress(url, {
'percent': progress,
'last_update': datetime.utcnow().isoformat()
})
progress_tracker.update_download_progress(
url,
{
"percent": progress,
"last_update": datetime.utcnow().isoformat(),
},
)
# Create task to update reaction
asyncio.run_coroutine_threadsafe(
self._update_download_progress_reaction(message, progress),
loop
self._update_download_progress_reaction(message, progress), loop
)
except Exception as e:
logger.error(f"Error in progress callback: {e}")
return progress_callback
async def _download_video(
self,
downloader: DownloadManager,
url: str,
progress_callback: Callable[[float], None]
progress_callback: Callable[[float], None],
) -> Tuple[bool, Optional[str], Optional[str]]:
"""
Download video with progress tracking.
Args:
downloader: Download manager instance
url: URL to download
progress_callback: Callback for progress updates
Returns:
Tuple of (success, file_path, error_message)
"""
@@ -422,13 +417,12 @@ class QueueHandler:
try:
success, file_path, error = await asyncio.wait_for(
download_task,
timeout=self.DOWNLOAD_TIMEOUT
download_task, timeout=self.DOWNLOAD_TIMEOUT
)
if success:
self.progress_tracker.complete_download(url)
progress_tracker.complete_download(url)
else:
self.progress_tracker.increment_download_retries(url)
progress_tracker.increment_download_retries(url)
return success, file_path, error
except asyncio.TimeoutError:
@@ -448,7 +442,7 @@ class QueueHandler:
async def cleanup(self) -> None:
"""
Clean up resources and stop processing.
Raises:
QueueHandlerError: If cleanup fails
"""
@@ -466,7 +460,9 @@ class QueueHandler:
except asyncio.CancelledError:
pass
except Exception as e:
logger.error(f"Error cancelling download task for {url}: {e}")
logger.error(
f"Error cancelling download task for {url}: {e}"
)
self._active_downloads.clear()
self._stats["active_downloads"] = 0
@@ -492,12 +488,12 @@ class QueueHandler:
logger.info("QueueHandler force cleanup completed")
except Exception as e:
logger.error(f"Error during QueueHandler force cleanup: {str(e)}", exc_info=True)
logger.error(
f"Error during QueueHandler force cleanup: {str(e)}", exc_info=True
)
async def _update_download_progress_reaction(
self,
message: discord.Message,
progress: float
self, message: discord.Message, progress: float
) -> None:
"""Update download progress reaction on message"""
if not message:
@@ -535,7 +531,7 @@ class QueueHandler:
def is_healthy(self) -> bool:
"""
Check if handler is healthy.
Returns:
True if handler is healthy, False otherwise
"""
@@ -543,9 +539,13 @@ class QueueHandler:
# Check if any downloads are stuck
current_time = datetime.utcnow()
for url, task in self._active_downloads.items():
if not task.done() and task.get_coro().cr_frame.f_locals.get('start_time'):
start_time = task.get_coro().cr_frame.f_locals['start_time']
if (current_time - start_time).total_seconds() > self.DOWNLOAD_TIMEOUT:
if not task.done() and task.get_coro().cr_frame.f_locals.get(
"start_time"
):
start_time = task.get_coro().cr_frame.f_locals["start_time"]
if (
current_time - start_time
).total_seconds() > self.DOWNLOAD_TIMEOUT:
self._stats["is_healthy"] = False
return False
@@ -566,7 +566,7 @@ class QueueHandler:
def get_stats(self) -> QueueStats:
"""
Get queue handler statistics.
Returns:
Dictionary containing queue statistics
"""