Files
Pac-cogs/videoarchiver/queue/state_manager.py

367 lines
13 KiB
Python

"""Module for managing queue state"""
import logging
import asyncio
from enum import Enum
from dataclasses import dataclass
from typing import Dict, Set, List, Optional, Any
from datetime import datetime
from .models import QueueItem, QueueMetrics
logger = logging.getLogger("QueueStateManager")
class ItemState(Enum):
    """Lifecycle states a queue item can be in.

    Each member's value is the lowercase string form of its name.
    """

    # Waiting in the queue to be picked up.
    PENDING = "pending"
    # Handed out for processing.
    PROCESSING = "processing"
    # Finished successfully.
    COMPLETED = "completed"
    # Finished unsuccessfully.
    FAILED = "failed"
    # On its way back to the queue for another attempt.
    RETRYING = "retrying"
@dataclass
class StateTransition:
    """A single recorded move of a queue item from one state to another."""

    # URL identifying the item whose state changed.
    item_url: str
    # State the item was in before the change.
    from_state: ItemState
    # State the item is in after the change.
    to_state: ItemState
    # When the change was recorded.
    timestamp: datetime
    # Optional free-text explanation (e.g. an error message); None when absent.
    reason: Optional[str] = None
class StateSnapshot:
    """Point-in-time capture of the queue manager's containers.

    A new snapshot is empty and stamped with the naive UTC time of its
    creation; whoever creates it is expected to fill in the containers.
    """

    def __init__(self):
        # Naive UTC timestamp of when the snapshot object was created.
        self.timestamp = datetime.utcnow()
        # Items waiting to be processed, in queue order.
        self.queue: List[QueueItem] = []
        # Items keyed by URL, one mapping per state container.
        self.processing: Dict[str, QueueItem] = {}
        self.completed: Dict[str, QueueItem] = {}
        self.failed: Dict[str, QueueItem] = {}
        # URL sets keyed by guild id / channel id.
        self.guild_queues: Dict[int, Set[str]] = {}
        self.channel_queues: Dict[int, Set[str]] = {}

    def to_dict(self) -> Dict[str, Any]:
        """Convert the snapshot into a plain dictionary.

        Returns:
            A dictionary with the ISO-formatted timestamp, the attribute
            dictionaries of every captured item, and the guild/channel URL
            sets rendered as sorted lists.

        Each item's attribute dictionary is copied (shallowly), so editing
        the returned structure cannot alter the captured items. URL sets are
        sorted so that equal snapshots always produce equal output.
        """

        def fields_of(item) -> Dict[str, Any]:
            # Copy rather than expose the item's live ``__dict__``.
            return dict(item.__dict__)

        def export_items(items: Dict[str, Any]) -> Dict[str, Dict[str, Any]]:
            return {url: fields_of(item) for url, item in items.items()}

        return {
            "timestamp": self.timestamp.isoformat(),
            "queue": [fields_of(item) for item in self.queue],
            "processing": export_items(self.processing),
            "completed": export_items(self.completed),
            "failed": export_items(self.failed),
            "guild_queues": {
                gid: sorted(urls) for gid, urls in self.guild_queues.items()
            },
            "channel_queues": {
                cid: sorted(urls) for cid, urls in self.channel_queues.items()
            },
        }
class StateValidator:
    """Stateless checks for queue items and for state transitions."""

    @staticmethod
    def validate_item(item: QueueItem) -> bool:
        """Return True when every checked field of *item* is well-formed.

        Requirements: ``url`` is a non-empty string; ``guild_id`` and
        ``channel_id`` are positive integers; ``priority`` is an integer in
        the range 0..10; ``added_at`` is a ``datetime``; ``status`` is a
        string.
        """
        # All fields are inspected before the verdict is combined, so an item
        # that lacks one of these attributes raises AttributeError instead of
        # being reported as invalid.
        url_ok = isinstance(item.url, str) and bool(item.url)
        guild_ok = isinstance(item.guild_id, int) and item.guild_id > 0
        channel_ok = isinstance(item.channel_id, int) and item.channel_id > 0
        priority_ok = isinstance(item.priority, int) and 0 <= item.priority <= 10
        added_ok = isinstance(item.added_at, datetime)
        status_ok = isinstance(item.status, str)
        return (
            url_ok
            and guild_ok
            and channel_ok
            and priority_ok
            and added_ok
            and status_ok
        )

    @staticmethod
    def validate_transition(
        item: QueueItem,
        from_state: ItemState,
        to_state: ItemState
    ) -> bool:
        """Return True when moving from *from_state* to *to_state* is allowed.

        *item* is accepted for interface symmetry but is not consulted; only
        the two states decide the outcome.
        """
        allowed_targets = {
            ItemState.PENDING: {ItemState.PROCESSING, ItemState.FAILED},
            ItemState.PROCESSING: {
                ItemState.COMPLETED,
                ItemState.FAILED,
                ItemState.RETRYING,
            },
            ItemState.FAILED: {ItemState.RETRYING},
            ItemState.RETRYING: {ItemState.PENDING},
            # COMPLETED is terminal: nothing may follow it.
            ItemState.COMPLETED: set(),
        }
        reachable = allowed_targets.get(from_state, set())
        return to_state in reachable
class StateTracker:
    """Keeps a bounded history of state transitions and snapshots.

    It also maintains ``state_counts``, a running tally of how many items are
    in each ``ItemState`` according to the transitions recorded so far.
    """

    def __init__(self, max_history: int = 1000):
        """Create an empty tracker.

        Args:
            max_history: Maximum number of transitions, and separately of
                snapshots, that are retained; the oldest are dropped first.
        """
        self.max_history = max_history
        self.transitions: List[StateTransition] = []
        self.snapshots: List[StateSnapshot] = []
        self.state_counts: Dict[ItemState, int] = {state: 0 for state in ItemState}

    def _trim(self, history: list) -> None:
        """Drop the oldest entries of *history* beyond ``max_history``."""
        excess = len(history) - self.max_history
        if excess > 0:
            del history[:excess]

    def record_transition(
        self,
        transition: StateTransition
    ) -> None:
        """Record a state transition and update the per-state tally.

        A transition whose ``from_state`` equals its ``to_state`` is treated
        as an item entering that state for the first time (this is how a
        newly added item is reported), so it only increments the tally.
        Decrementing and re-incrementing the same counter would leave new
        items uncounted and push the count negative once they move on.
        """
        self.transitions.append(transition)
        self._trim(self.transitions)

        if transition.from_state != transition.to_state:
            # A population count can never be negative; clamp so a transition
            # for an item the tracker never saw cannot corrupt the tally.
            remaining = self.state_counts[transition.from_state] - 1
            self.state_counts[transition.from_state] = max(0, remaining)
        self.state_counts[transition.to_state] += 1

    def take_snapshot(self, state_manager: 'QueueStateManager') -> None:
        """Capture the current containers of *state_manager*.

        The containers are copied shallowly: the snapshot owns its own list,
        dicts and sets, but shares the item objects with the manager.
        """
        snapshot = StateSnapshot()
        snapshot.queue = state_manager._queue.copy()
        snapshot.processing = state_manager._processing.copy()
        snapshot.completed = state_manager._completed.copy()
        snapshot.failed = state_manager._failed.copy()
        snapshot.guild_queues = {
            gid: urls.copy() for gid, urls in state_manager._guild_queues.items()
        }
        snapshot.channel_queues = {
            cid: urls.copy() for cid, urls in state_manager._channel_queues.items()
        }
        self.snapshots.append(snapshot)
        self._trim(self.snapshots)

    def get_state_history(self) -> Dict[str, Any]:
        """Return summary statistics about the recorded history.

        Returns:
            A dictionary with the number of retained transitions and
            snapshots, the per-state tally keyed by state value, and the
            most recent snapshot as a dictionary (``None`` if there is none).
        """
        latest = self.snapshots[-1].to_dict() if self.snapshots else None
        return {
            "transitions": len(self.transitions),
            "snapshots": len(self.snapshots),
            "state_counts": {
                state.value: count
                for state, count in self.state_counts.items()
            },
            "latest_snapshot": latest,
        }
class QueueStateManager:
    """Owns the queue's containers and serialises changes to them.

    Items wait in ``_queue`` (kept ordered by descending priority, then by
    ``added_at``), move to ``_processing`` when handed out, and end up in
    ``_completed`` or ``_failed``. The three dictionaries are keyed by item
    URL. Mutating coroutines hold ``_lock`` while they work, and every state
    change is reported to ``tracker``.
    """

    def __init__(self, max_queue_size: int = 1000):
        """Create an empty manager.

        Args:
            max_queue_size: Maximum number of items ``add_item`` lets wait
                in the queue at once.
        """
        self.max_queue_size = max_queue_size

        # Queue storage
        self._queue: List[QueueItem] = []
        self._processing: Dict[str, QueueItem] = {}
        self._completed: Dict[str, QueueItem] = {}
        self._failed: Dict[str, QueueItem] = {}

        # Tracking: item URLs grouped by guild id and by channel id
        self._guild_queues: Dict[int, Set[str]] = {}
        self._channel_queues: Dict[int, Set[str]] = {}

        # State management
        self._lock = asyncio.Lock()
        self.validator = StateValidator()
        self.tracker = StateTracker()

    def _sort_queue(self) -> None:
        """Order the queue by descending priority, then by time added."""
        self._queue.sort(key=lambda x: (-x.priority, x.added_at))

    def _track(self, item: QueueItem) -> None:
        """Register the item's URL under its guild and its channel."""
        self._guild_queues.setdefault(item.guild_id, set()).add(item.url)
        self._channel_queues.setdefault(item.channel_id, set()).add(item.url)

    def _sync_state_counts(self) -> None:
        """Set the tracker's per-state tally from the current containers."""
        counts = self.tracker.state_counts
        for state in ItemState:
            counts[state] = 0
        counts[ItemState.PENDING] = len(self._queue)
        counts[ItemState.PROCESSING] = len(self._processing)
        counts[ItemState.COMPLETED] = len(self._completed)
        counts[ItemState.FAILED] = len(self._failed)

    async def add_item(self, item: QueueItem) -> bool:
        """Add an item to the queue.

        Returns:
            True when the item was queued; False when it failed validation
            or the queue already holds ``max_queue_size`` items.
        """
        if not self.validator.validate_item(item):
            logger.error("Invalid queue item: %s", item)
            return False

        async with self._lock:
            if len(self._queue) >= self.max_queue_size:
                return False

            # A new item has no earlier state, so its arrival is recorded as
            # a PENDING -> PENDING transition.
            self.tracker.record_transition(StateTransition(
                item_url=item.url,
                from_state=ItemState.PENDING,
                to_state=ItemState.PENDING,
                timestamp=datetime.utcnow(),
                reason="Initial add"
            ))

            # Add to main queue
            self._queue.append(item)
            self._sort_queue()

            # Update tracking
            self._track(item)

            # Take a snapshot each time the queue length reaches a multiple
            # of 100
            if len(self._queue) % 100 == 0:
                self.tracker.take_snapshot(self)

            return True

    async def get_next_items(self, count: int = 5) -> List[QueueItem]:
        """Remove up to *count* items from the front of the queue.

        Every returned item is registered in the processing container and a
        PENDING -> PROCESSING transition is recorded for it.

        Returns:
            The removed items in queue order; empty when the queue is empty
            or *count* is not positive.
        """
        items = []
        async with self._lock:
            while len(items) < count and self._queue:
                item = self._queue.pop(0)
                items.append(item)
                self._processing[item.url] = item

                # Record transition
                self.tracker.record_transition(StateTransition(
                    item_url=item.url,
                    from_state=ItemState.PENDING,
                    to_state=ItemState.PROCESSING,
                    timestamp=datetime.utcnow()
                ))
        return items

    async def mark_completed(
        self,
        item: QueueItem,
        success: bool,
        error: Optional[str] = None
    ) -> None:
        """Move an item out of processing into completed or failed.

        Args:
            item: The item that finished.
            success: True files the item under completed, False under failed.
            error: Optional message stored as the transition's reason; an
                empty string is recorded as ``None``.
        """
        async with self._lock:
            self._processing.pop(item.url, None)
            to_state = ItemState.COMPLETED if success else ItemState.FAILED

            self.tracker.record_transition(StateTransition(
                item_url=item.url,
                from_state=ItemState.PROCESSING,
                to_state=to_state,
                timestamp=datetime.utcnow(),
                reason=error or None
            ))

            if success:
                self._completed[item.url] = item
            else:
                self._failed[item.url] = item

    async def retry_item(self, item: QueueItem) -> None:
        """Put an item back into the queue for another attempt.

        The item is removed from the processing and failed containers, its
        status is reset to pending, ``last_retry`` is stamped, and its
        priority is lowered by one (never below zero) before it is re-queued.
        An item that is only known as completed is rejected, because the
        validator allows no transition out of COMPLETED; the rejection is
        logged and nothing changes.
        """
        async with self._lock:
            # Work out where the item really is rather than assuming FAILED,
            # and make sure it no longer lingers in either container.
            was_processing = self._processing.pop(item.url, None) is not None
            was_failed = self._failed.pop(item.url, None) is not None

            if was_processing:
                from_state = ItemState.PROCESSING
            elif was_failed or item.url not in self._completed:
                # Items the manager does not know are treated as failed.
                from_state = ItemState.FAILED
            else:
                from_state = ItemState.COMPLETED

            if not self.validator.validate_transition(
                item,
                from_state,
                ItemState.RETRYING
            ):
                logger.error("Invalid retry transition for item: %s", item)
                return

            item.status = ItemState.PENDING.value
            item.last_retry = datetime.utcnow()
            item.priority = max(0, item.priority - 1)

            # Record transitions
            self.tracker.record_transition(StateTransition(
                item_url=item.url,
                from_state=from_state,
                to_state=ItemState.RETRYING,
                timestamp=datetime.utcnow()
            ))
            self.tracker.record_transition(StateTransition(
                item_url=item.url,
                from_state=ItemState.RETRYING,
                to_state=ItemState.PENDING,
                timestamp=datetime.utcnow()
            ))

            self._queue.append(item)
            self._sort_queue()
            # Idempotent; keeps every queued item present in the tracking
            # sets even if they were rebuilt from the queue alone.
            self._track(item)

    async def get_guild_status(self, guild_id: int) -> Dict[str, int]:
        """Count the given guild's items in each container.

        Returns:
            A dictionary with the keys ``pending``, ``processing``,
            ``completed`` and ``failed``.
        """
        def owned(items) -> int:
            return sum(1 for item in items if item.guild_id == guild_id)

        async with self._lock:
            return {
                "pending": owned(self._queue),
                "processing": owned(self._processing.values()),
                "completed": owned(self._completed.values()),
                "failed": owned(self._failed.values())
            }

    async def clear_state(self) -> None:
        """Clear all containers after taking one last snapshot of them."""
        async with self._lock:
            # The snapshot must be taken before anything is cleared,
            # otherwise it would only ever capture empty containers.
            self.tracker.take_snapshot(self)

            self._queue.clear()
            self._processing.clear()
            self._completed.clear()
            self._failed.clear()
            self._guild_queues.clear()
            self._channel_queues.clear()

            # Nothing is left in any state.
            self._sync_state_counts()

    async def get_state_for_persistence(self) -> Dict[str, Any]:
        """Return the current state for persistence.

        A snapshot is taken first. The returned containers are shallow
        copies, so the caller can serialise them after the lock has been
        released without seeing later insertions or removals.
        """
        async with self._lock:
            # Take snapshot before persistence
            self.tracker.take_snapshot(self)
            return {
                "queue": list(self._queue),
                "processing": dict(self._processing),
                "completed": dict(self._completed),
                "failed": dict(self._failed),
                "history": self.tracker.get_state_history()
            }

    async def restore_state(self, state: Dict[str, Any]) -> None:
        """Replace the current state with previously persisted data.

        Queue items that fail validation are dropped with a warning. The
        containers are copied, so the manager does not share them with
        *state*. Tracking sets and the tracker's per-state tally are rebuilt
        from the restored data.
        """
        async with self._lock:
            # Build a new list instead of removing from the list being
            # iterated, which would skip the element after each removal.
            restored_queue = []
            for item in state.get("queue") or []:
                if self.validator.validate_item(item):
                    restored_queue.append(item)
                else:
                    logger.warning("Removing invalid restored item: %s", item)

            self._queue = restored_queue
            self._processing = dict(state.get("processing") or {})
            self._completed = dict(state.get("completed") or {})
            self._failed = dict(state.get("failed") or {})

            # Rebuild tracking
            self._rebuild_tracking()
            self._sync_state_counts()

    def _rebuild_tracking(self) -> None:
        """Rebuild guild and channel tracking from the queued items."""
        self._guild_queues.clear()
        self._channel_queues.clear()
        for item in self._queue:
            self._track(item)

    def get_state_stats(self) -> Dict[str, Any]:
        """Return container sizes, tracking sizes and the tracker history.

        This method does not take the lock.
        """
        return {
            "queue_size": len(self._queue),
            "processing_count": len(self._processing),
            "completed_count": len(self._completed),
            "failed_count": len(self._failed),
            "guild_count": len(self._guild_queues),
            "channel_count": len(self._channel_queues),
            "history": self.tracker.get_state_history()
        }