mirror of
https://github.com/pacnpal/Pac-cogs.git
synced 2025-12-20 10:51:05 -05:00
367 lines
13 KiB
Python
367 lines
13 KiB
Python
"""Module for managing queue state"""
|
|
|
|
import logging
|
|
import asyncio
|
|
from enum import Enum
|
|
from dataclasses import dataclass
|
|
from typing import Dict, Set, List, Optional, Any
|
|
from datetime import datetime
|
|
|
|
from models import QueueItem, QueueMetrics
|
|
|
|
# Module logger. It is registered under the fixed name "QueueStateManager"
# (not __name__), so logging configuration must use that name to target it.
logger = logging.getLogger("QueueStateManager")
|
|
|
|
class ItemState(Enum):
    """Lifecycle states a queue item can be in.

    Member order matters to callers that iterate the enum (for example to
    build one counter per state), so it is kept stable.
    """

    # Waiting in the queue to be handed out.
    PENDING = "pending"
    # Handed out for work and not yet finished.
    PROCESSING = "processing"
    # Finished successfully.
    COMPLETED = "completed"
    # Finished unsuccessfully.
    FAILED = "failed"
    # Intermediate state while a failed item is being re-queued.
    RETRYING = "retrying"
|
|
|
|
@dataclass
class StateTransition:
    """One recorded move of a queue item from one state to another.

    Fields are positional in declaration order; only ``reason`` has a default.
    """

    # URL identifying the queue item that moved.
    item_url: str
    # State the item left.
    from_state: ItemState
    # State the item entered.
    to_state: ItemState
    # Moment the transition was recorded.
    timestamp: datetime
    # Optional explanation (e.g. an error message); None when not given.
    reason: Optional[str] = None
|
|
|
|
class StateSnapshot:
    """Represents a point-in-time snapshot of queue state.

    A fresh snapshot is empty; ``StateTracker.take_snapshot`` fills the
    containers with shallow copies of the manager's containers, so the
    ``QueueItem`` objects themselves are shared with the live state.
    """

    def __init__(self):
        # Naive UTC time at which this snapshot object was created.
        self.timestamp = datetime.utcnow()
        self.queue: List[QueueItem] = []
        self.processing: Dict[str, QueueItem] = {}
        self.completed: Dict[str, QueueItem] = {}
        self.failed: Dict[str, QueueItem] = {}
        # guild/channel id -> set of item URLs
        self.guild_queues: Dict[int, Set[str]] = {}
        self.channel_queues: Dict[int, Set[str]] = {}

    @staticmethod
    def _item_to_dict(item: "QueueItem") -> Dict[str, Any]:
        """Return a copy of *item*'s attribute dict.

        Returning ``item.__dict__`` itself would let callers that post-process
        the result (e.g. converting datetimes for JSON) silently mutate the
        live queue item.
        """
        return dict(vars(item))

    def to_dict(self) -> Dict[str, Any]:
        """Convert the snapshot to plain dicts/lists.

        Returns:
            A dict with an ISO-8601 ``timestamp``, item attribute dicts for
            ``queue``/``processing``/``completed``/``failed``, and the URL sets
            of ``guild_queues``/``channel_queues`` converted to lists.
        """
        as_dict = self._item_to_dict
        return {
            "timestamp": self.timestamp.isoformat(),
            "queue": [as_dict(item) for item in self.queue],
            "processing": {url: as_dict(item) for url, item in self.processing.items()},
            "completed": {url: as_dict(item) for url, item in self.completed.items()},
            "failed": {url: as_dict(item) for url, item in self.failed.items()},
            "guild_queues": {gid: list(urls) for gid, urls in self.guild_queues.items()},
            "channel_queues": {cid: list(urls) for cid, urls in self.channel_queues.items()},
        }
|
|
|
|
class StateValidator:
    """Validates queue items and state transitions."""

    # Allowed moves between item states; COMPLETED is terminal.
    # Built once at class creation rather than on every validate_transition call.
    _VALID_TRANSITIONS = {
        ItemState.PENDING: frozenset({ItemState.PROCESSING, ItemState.FAILED}),
        ItemState.PROCESSING: frozenset(
            {ItemState.COMPLETED, ItemState.FAILED, ItemState.RETRYING}
        ),
        ItemState.FAILED: frozenset({ItemState.RETRYING}),
        ItemState.RETRYING: frozenset({ItemState.PENDING}),
        ItemState.COMPLETED: frozenset(),  # No transitions from completed
    }

    @staticmethod
    def _is_strict_int(value: Any) -> bool:
        """Return True for ints, excluding ``bool`` (which subclasses ``int``)."""
        return isinstance(value, int) and not isinstance(value, bool)

    @staticmethod
    def validate_item(item: QueueItem) -> bool:
        """Validate a queue item's core fields.

        Requires a non-empty ``str`` url, positive integer ``guild_id`` and
        ``channel_id``, an integer ``priority`` in [0, 10], a ``datetime``
        ``added_at`` and a ``str`` ``status``. ``True``/``False`` are rejected
        for the integer fields. An object missing any of these attributes
        (e.g. a plain dict) is reported invalid instead of raising
        ``AttributeError``.

        Returns:
            True if every check passes, otherwise False.
        """
        is_int = StateValidator._is_strict_int
        url = getattr(item, "url", None)
        guild_id = getattr(item, "guild_id", None)
        channel_id = getattr(item, "channel_id", None)
        priority = getattr(item, "priority", None)
        return (
            isinstance(url, str) and url != ""
            and is_int(guild_id) and guild_id > 0
            and is_int(channel_id) and channel_id > 0
            and is_int(priority) and 0 <= priority <= 10
            and isinstance(getattr(item, "added_at", None), datetime)
            and isinstance(getattr(item, "status", None), str)
        )

    @staticmethod
    def validate_transition(
        item: QueueItem,
        from_state: ItemState,
        to_state: ItemState
    ) -> bool:
        """Return True if moving from *from_state* to *to_state* is allowed.

        ``item`` is accepted for interface compatibility but is not consulted.
        """
        allowed = StateValidator._VALID_TRANSITIONS.get(from_state, frozenset())
        return to_state in allowed
|
|
|
|
class StateTracker:
    """Tracks state changes and transitions.

    Keeps a bounded history of ``StateTransition`` records and
    ``StateSnapshot`` objects (each capped at ``max_history`` entries, oldest
    dropped first) plus a running per-state item count.
    """

    def __init__(self, max_history: int = 1000):
        self.max_history = max_history
        self.transitions: List[StateTransition] = []
        self.snapshots: List[StateSnapshot] = []
        # Net number of items per state: +1 on entering, -1 on leaving,
        # as reported through record_transition.
        self.state_counts: Dict[ItemState, int] = {state: 0 for state in ItemState}

    def _trim(self, history: List[Any]) -> None:
        """Drop the oldest entries so *history* holds at most ``max_history`` items.

        Deleting all excess entries (not just one) keeps the bound correct even
        if ``max_history`` is lowered after entries were recorded.
        """
        excess = len(history) - max(self.max_history, 0)
        if excess > 0:
            del history[:excess]

    def record_transition(
        self,
        transition: StateTransition
    ) -> None:
        """Record a state transition and update ``state_counts``.

        A transition whose ``from_state`` equals its ``to_state`` marks an
        item entering the system. ``QueueStateManager.add_item`` records its
        "Initial add" as PENDING -> PENDING. There is no prior state to leave
        in that case, so only the destination count is incremented. Otherwise
        newly added items would never be counted, and the PENDING count would
        go negative once they are dequeued.
        """
        self.transitions.append(transition)
        self._trim(self.transitions)

        if transition.from_state is not transition.to_state:
            self.state_counts[transition.from_state] -= 1
        self.state_counts[transition.to_state] += 1

    def take_snapshot(self, state_manager: 'QueueStateManager') -> None:
        """Take a snapshot of *state_manager*'s current containers.

        The list, dicts and per-id URL sets are copied; the queue items inside
        them are shared with the live state (shallow copy).
        """
        snapshot = StateSnapshot()
        snapshot.queue = state_manager._queue.copy()
        snapshot.processing = state_manager._processing.copy()
        snapshot.completed = state_manager._completed.copy()
        snapshot.failed = state_manager._failed.copy()
        snapshot.guild_queues = {
            gid: urls.copy() for gid, urls in state_manager._guild_queues.items()
        }
        snapshot.channel_queues = {
            cid: urls.copy() for cid, urls in state_manager._channel_queues.items()
        }

        self.snapshots.append(snapshot)
        self._trim(self.snapshots)

    def get_state_history(self) -> Dict[str, Any]:
        """Get state history statistics.

        Returns:
            Number of retained transitions and snapshots, the per-state counts
            keyed by state value, and the newest snapshot as a dict (or None).
        """
        return {
            "transitions": len(self.transitions),
            "snapshots": len(self.snapshots),
            "state_counts": {
                state.value: count
                for state, count in self.state_counts.items()
            },
            "latest_snapshot": (
                self.snapshots[-1].to_dict()
                if self.snapshots
                else None
            )
        }
|
|
|
|
class QueueStateManager:
    """Manages the state of the queue system.

    Item lifecycle:

    * Pending items live in ``_queue``. It is sorted by priority (highest
      first), then by ``added_at`` (oldest first).
    * ``get_next_items`` moves items to ``_processing``.
    * ``mark_completed`` moves them to ``_completed`` or ``_failed``.
    * ``retry_item`` puts them back into ``_queue``.

    The three dicts are keyed by item URL. ``_guild_queues`` and
    ``_channel_queues`` map a guild/channel id to the URLs of its *pending*
    items, the same definition ``_rebuild_tracking`` uses. The public
    coroutines mutate state only while holding ``self._lock``.
    """

    def __init__(self, max_queue_size: int = 1000):
        # add_item rejects new items once this many are pending.
        self.max_queue_size = max_queue_size

        # Queue storage
        self._queue: List[QueueItem] = []
        self._processing: Dict[str, QueueItem] = {}
        self._completed: Dict[str, QueueItem] = {}
        self._failed: Dict[str, QueueItem] = {}

        # Tracking: guild/channel id -> URLs of pending items
        self._guild_queues: Dict[int, Set[str]] = {}
        self._channel_queues: Dict[int, Set[str]] = {}

        # State management
        self._lock = asyncio.Lock()
        self.validator = StateValidator()
        self.tracker = StateTracker()

    # Internal helpers: callers must already hold self._lock.

    def _sort_queue(self) -> None:
        """Order pending items by priority (highest first), then age (oldest first)."""
        self._queue.sort(key=lambda x: (-x.priority, x.added_at))

    def _track_item(self, item: QueueItem) -> None:
        """Register *item*'s URL under its guild and channel."""
        self._guild_queues.setdefault(item.guild_id, set()).add(item.url)
        self._channel_queues.setdefault(item.channel_id, set()).add(item.url)

    async def add_item(self, item: QueueItem) -> bool:
        """Validate *item* and add it to the pending queue.

        Returns:
            True if queued; False if the item is invalid or the queue already
            holds ``max_queue_size`` items.
        """
        if not self.validator.validate_item(item):
            logger.error(f"Invalid queue item: {item}")
            return False

        async with self._lock:
            if len(self._queue) >= self.max_queue_size:
                return False

            # Initial entry has no prior state, so it is recorded as PENDING -> PENDING.
            self.tracker.record_transition(StateTransition(
                item_url=item.url,
                from_state=ItemState.PENDING,
                to_state=ItemState.PENDING,
                timestamp=datetime.utcnow(),
                reason="Initial add"
            ))

            self._queue.append(item)
            self._sort_queue()
            self._track_item(item)

            # Snapshot whenever the pending count reaches a multiple of 100.
            if len(self._queue) % 100 == 0:
                self.tracker.take_snapshot(self)

            return True

    async def get_next_items(self, count: int = 5) -> List[QueueItem]:
        """Move up to *count* items from the front of the queue to processing.

        Args:
            count: Maximum number of items to return; values <= 0 yield [].

        Returns:
            The dequeued items in queue (priority) order.
        """
        if count <= 0:
            return []

        async with self._lock:
            # One slice instead of repeated pop(0), each of which is O(n).
            items = self._queue[:count]
            del self._queue[:count]

            for item in items:
                self._processing[item.url] = item
                self.tracker.record_transition(StateTransition(
                    item_url=item.url,
                    from_state=ItemState.PENDING,
                    to_state=ItemState.PROCESSING,
                    timestamp=datetime.utcnow()
                ))

            if items:
                # Tracking mirrors the pending queue. Rebuilding (rather than
                # discarding URLs) stays correct if a duplicate URL is still queued.
                self._rebuild_tracking()

            return items

    async def mark_completed(
        self,
        item: QueueItem,
        success: bool,
        error: Optional[str] = None
    ) -> None:
        """Move *item* out of processing into completed or failed.

        Args:
            item: The item being finished.
            success: True stores it in ``_completed``, False in ``_failed``.
            error: Optional reason attached to the recorded transition.
        """
        async with self._lock:
            self._processing.pop(item.url, None)

            to_state = ItemState.COMPLETED if success else ItemState.FAILED
            self.tracker.record_transition(StateTransition(
                item_url=item.url,
                from_state=ItemState.PROCESSING,
                to_state=to_state,
                timestamp=datetime.utcnow(),
                reason=error or None
            ))

            destination = self._completed if success else self._failed
            destination[item.url] = item

    async def retry_item(self, item: QueueItem) -> None:
        """Put *item* back into the pending queue for another attempt.

        The item is removed from processing and from the failed set. Its
        ``status`` is reset to ``"pending"``, ``last_retry`` is stamped and its
        priority is lowered by one (not below 0).
        """
        if not self.validator.validate_transition(
            item,
            ItemState.FAILED,
            ItemState.RETRYING
        ):
            logger.error(f"Invalid retry transition for item: {item}")
            return

        async with self._lock:
            self._processing.pop(item.url, None)
            # Otherwise a retried item is counted as failed while also pending.
            self._failed.pop(item.url, None)
            item.status = ItemState.PENDING.value
            item.last_retry = datetime.utcnow()
            item.priority = max(0, item.priority - 1)

            for from_state, to_state in (
                (ItemState.FAILED, ItemState.RETRYING),
                (ItemState.RETRYING, ItemState.PENDING),
            ):
                self.tracker.record_transition(StateTransition(
                    item_url=item.url,
                    from_state=from_state,
                    to_state=to_state,
                    timestamp=datetime.utcnow()
                ))

            self._queue.append(item)
            self._sort_queue()
            self._track_item(item)

    async def get_guild_status(self, guild_id: int) -> Dict[str, int]:
        """Count this guild's items in each container.

        Returns:
            ``{"pending", "processing", "completed", "failed"}`` -> count.
        """
        async with self._lock:
            def count_for_guild(items) -> int:
                return sum(1 for entry in items if entry.guild_id == guild_id)

            return {
                "pending": count_for_guild(self._queue),
                "processing": count_for_guild(self._processing.values()),
                "completed": count_for_guild(self._completed.values()),
                "failed": count_for_guild(self._failed.values())
            }

    async def clear_state(self) -> None:
        """Clear all state data, keeping a snapshot of what was cleared."""
        async with self._lock:
            # Take final snapshot before clearing so the history keeps the pre-clear state.
            self.tracker.take_snapshot(self)

            self._queue.clear()
            self._processing.clear()
            self._completed.clear()
            self._failed.clear()
            self._guild_queues.clear()
            self._channel_queues.clear()

    async def get_state_for_persistence(self) -> Dict[str, Any]:
        """Get current state for persistence.

        The containers are returned as shallow copies, so later queue
        operations do not change the returned data and callers cannot mutate
        internal state through it.
        """
        async with self._lock:
            # Take snapshot before persistence
            self.tracker.take_snapshot(self)

            return {
                "queue": list(self._queue),
                "processing": dict(self._processing),
                "completed": dict(self._completed),
                "failed": dict(self._failed),
                "history": self.tracker.get_state_history()
            }

    async def restore_state(self, state: Dict[str, Any]) -> None:
        """Restore state from persisted data.

        Reads the ``"queue"``, ``"processing"``, ``"completed"`` and
        ``"failed"`` keys that ``get_state_for_persistence`` emits; missing or
        None values restore as empty. Queue items failing ``validate_item`` are
        dropped with a warning. The queue is re-sorted and tracking rebuilt.
        """
        async with self._lock:
            # Build a new list instead of removing from the list being iterated
            # (which skipped the element after each removal) and copy the
            # containers so the caller's objects are not aliased.
            restored_queue: List[QueueItem] = []
            for item in state.get("queue") or []:
                if self.validator.validate_item(item):
                    restored_queue.append(item)
                else:
                    logger.warning(f"Removing invalid restored item: {item}")

            self._queue = restored_queue
            self._sort_queue()
            self._processing = dict(state.get("processing") or {})
            self._completed = dict(state.get("completed") or {})
            self._failed = dict(state.get("failed") or {})

            # Rebuild tracking
            self._rebuild_tracking()

    def _rebuild_tracking(self) -> None:
        """Rebuild guild and channel tracking from the pending queue."""
        self._guild_queues.clear()
        self._channel_queues.clear()
        for item in self._queue:
            self._track_item(item)

    def get_state_stats(self) -> Dict[str, Any]:
        """Get container sizes, tracked guild/channel counts and tracker history.

        Reads state without taking the lock (this is a synchronous method).
        """
        return {
            "queue_size": len(self._queue),
            "processing_count": len(self._processing),
            "completed_count": len(self._completed),
            "failed_count": len(self._failed),
            "guild_count": len(self._guild_queues),
            "channel_count": len(self._channel_queues),
            "history": self.tracker.get_state_history()
        }
|