feat(state-machine): add comprehensive callback system for transitions

Extend state machine module with callback infrastructure including:
- Pre/post/error transition callbacks with registry
- Signal-based transition notifications
- Callback configuration and monitoring support
- Helper functions for callback registration
- Improved park ride count updates with FSM integration
This commit is contained in:
pacnpal
2025-12-21 19:20:49 -05:00
parent 7ba0004c93
commit b860e332cb
18 changed files with 4206 additions and 26 deletions

View File

@@ -0,0 +1,195 @@
"""
Management command to list all registered FSM transition callbacks.
This command provides visibility into the callback system configuration,
showing which callbacks are registered for each model and transition.
"""
from django.core.management.base import BaseCommand, CommandParser
from django.apps import apps
from apps.core.state_machine.callbacks import (
callback_registry,
CallbackStage,
)
from apps.core.state_machine.config import callback_config
class Command(BaseCommand):
help = 'List all registered FSM transition callbacks'
def add_arguments(self, parser: CommandParser) -> None:
parser.add_argument(
'--model',
type=str,
help='Filter by model name (e.g., EditSubmission, Ride)',
)
parser.add_argument(
'--stage',
type=str,
choices=['pre', 'post', 'error', 'all'],
default='all',
help='Filter by callback stage',
)
parser.add_argument(
'--verbose',
'-v',
action='store_true',
help='Show detailed callback information',
)
parser.add_argument(
'--format',
type=str,
choices=['text', 'table', 'json'],
default='text',
help='Output format',
)
def handle(self, *args, **options):
model_filter = options.get('model')
stage_filter = options.get('stage')
verbose = options.get('verbose', False)
output_format = options.get('format', 'text')
# Get all registrations
all_registrations = callback_registry.get_all_registrations()
if output_format == 'json':
self._output_json(all_registrations, model_filter, stage_filter)
elif output_format == 'table':
self._output_table(all_registrations, model_filter, stage_filter, verbose)
else:
self._output_text(all_registrations, model_filter, stage_filter, verbose)
def _output_text(self, registrations, model_filter, stage_filter, verbose):
"""Output in text format."""
self.stdout.write(self.style.SUCCESS('\n=== FSM Transition Callbacks ===\n'))
# Group by model
models_seen = set()
total_callbacks = 0
for stage in CallbackStage:
if stage_filter != 'all' and stage.value != stage_filter:
continue
stage_regs = registrations.get(stage, [])
if not stage_regs:
continue
self.stdout.write(self.style.WARNING(f'\n{stage.value.upper()} Callbacks:'))
self.stdout.write('-' * 50)
# Group by model
by_model = {}
for reg in stage_regs:
model_name = reg.model_class.__name__
if model_filter and model_name != model_filter:
continue
if model_name not in by_model:
by_model[model_name] = []
by_model[model_name].append(reg)
models_seen.add(model_name)
total_callbacks += 1
for model_name, regs in sorted(by_model.items()):
self.stdout.write(f'\n {model_name}:')
for reg in regs:
transition = f'{reg.source}{reg.target}'
callback_name = reg.callback.name
priority = reg.callback.priority
self.stdout.write(
f' [{transition}] {callback_name} (priority: {priority})'
)
if verbose:
self.stdout.write(
f' continue_on_error: {reg.callback.continue_on_error}'
)
if hasattr(reg.callback, 'patterns'):
self.stdout.write(
f' patterns: {reg.callback.patterns}'
)
# Summary
self.stdout.write('\n' + '=' * 50)
self.stdout.write(self.style.SUCCESS(
f'Total: {total_callbacks} callbacks across {len(models_seen)} models'
))
# Configuration status
self.stdout.write(self.style.WARNING('\nConfiguration Status:'))
self.stdout.write(f' Callbacks enabled: {callback_config.enabled}')
self.stdout.write(f' Notifications enabled: {callback_config.notifications_enabled}')
self.stdout.write(f' Cache invalidation enabled: {callback_config.cache_invalidation_enabled}')
self.stdout.write(f' Related updates enabled: {callback_config.related_updates_enabled}')
self.stdout.write(f' Debug mode: {callback_config.debug_mode}')
def _output_table(self, registrations, model_filter, stage_filter, verbose):
"""Output in table format."""
self.stdout.write(self.style.SUCCESS('\n=== FSM Transition Callbacks ===\n'))
# Header
if verbose:
header = f"{'Model':<20} {'Field':<10} {'Source':<15} {'Target':<15} {'Stage':<8} {'Callback':<30} {'Priority':<8} {'Continue':<8}"
else:
header = f"{'Model':<20} {'Source':<15} {'Target':<15} {'Stage':<8} {'Callback':<30}"
self.stdout.write(self.style.WARNING(header))
self.stdout.write('-' * len(header))
for stage in CallbackStage:
if stage_filter != 'all' and stage.value != stage_filter:
continue
stage_regs = registrations.get(stage, [])
for reg in stage_regs:
model_name = reg.model_class.__name__
if model_filter and model_name != model_filter:
continue
if verbose:
row = f"{model_name:<20} {reg.field_name:<10} {reg.source:<15} {reg.target:<15} {stage.value:<8} {reg.callback.name:<30} {reg.callback.priority:<8} {str(reg.callback.continue_on_error):<8}"
else:
row = f"{model_name:<20} {reg.source:<15} {reg.target:<15} {stage.value:<8} {reg.callback.name:<30}"
self.stdout.write(row)
def _output_json(self, registrations, model_filter, stage_filter):
"""Output in JSON format."""
import json
output = {
'callbacks': [],
'configuration': {
'enabled': callback_config.enabled,
'notifications_enabled': callback_config.notifications_enabled,
'cache_invalidation_enabled': callback_config.cache_invalidation_enabled,
'related_updates_enabled': callback_config.related_updates_enabled,
'debug_mode': callback_config.debug_mode,
}
}
for stage in CallbackStage:
if stage_filter != 'all' and stage.value != stage_filter:
continue
stage_regs = registrations.get(stage, [])
for reg in stage_regs:
model_name = reg.model_class.__name__
if model_filter and model_name != model_filter:
continue
output['callbacks'].append({
'model': model_name,
'field': reg.field_name,
'source': reg.source,
'target': reg.target,
'stage': stage.value,
'callback': reg.callback.name,
'priority': reg.callback.priority,
'continue_on_error': reg.callback.continue_on_error,
})
self.stdout.write(json.dumps(output, indent=2))

View File

@@ -0,0 +1,234 @@
"""
Management command to test FSM transition callback execution.
This command allows testing callbacks for specific transitions
without actually changing model state.
"""
from django.core.management.base import BaseCommand, CommandParser, CommandError
from django.apps import apps
from django.contrib.auth import get_user_model
from apps.core.state_machine.callbacks import (
callback_registry,
CallbackStage,
TransitionContext,
)
from apps.core.state_machine.monitoring import callback_monitor
class Command(BaseCommand):
help = 'Test FSM transition callbacks for specific transitions'
def add_arguments(self, parser: CommandParser) -> None:
parser.add_argument(
'model',
type=str,
help='Model name (e.g., EditSubmission, Ride, Park)',
)
parser.add_argument(
'source',
type=str,
help='Source state value',
)
parser.add_argument(
'target',
type=str,
help='Target state value',
)
parser.add_argument(
'--instance-id',
type=int,
help='ID of an existing instance to use for testing',
)
parser.add_argument(
'--user-id',
type=int,
help='ID of user to use for testing',
)
parser.add_argument(
'--dry-run',
action='store_true',
help='Show what would be executed without running callbacks',
)
parser.add_argument(
'--stage',
type=str,
choices=['pre', 'post', 'error', 'all'],
default='all',
help='Which callback stage to test',
)
parser.add_argument(
'--field',
type=str,
default='status',
help='FSM field name (default: status)',
)
def handle(self, *args, **options):
model_name = options['model']
source = options['source']
target = options['target']
instance_id = options.get('instance_id')
user_id = options.get('user_id')
dry_run = options.get('dry_run', False)
stage_filter = options.get('stage', 'all')
field_name = options.get('field', 'status')
# Find the model class
model_class = self._find_model(model_name)
if not model_class:
raise CommandError(f"Model '{model_name}' not found")
# Get or create test instance
instance = self._get_or_create_instance(model_class, instance_id, source, field_name)
# Get user if specified
user = None
if user_id:
User = get_user_model()
try:
user = User.objects.get(pk=user_id)
except User.DoesNotExist:
raise CommandError(f"User with ID {user_id} not found")
# Create transition context
context = TransitionContext(
instance=instance,
field_name=field_name,
source_state=source,
target_state=target,
user=user,
)
self.stdout.write(self.style.SUCCESS(
f'\n=== Testing Transition Callbacks ===\n'
f'Model: {model_name}\n'
f'Transition: {source}{target}\n'
f'Field: {field_name}\n'
f'Instance: {instance}\n'
f'User: {user}\n'
f'Dry Run: {dry_run}\n'
))
# Get callbacks for each stage
stages_to_test = []
if stage_filter == 'all':
stages_to_test = [CallbackStage.PRE, CallbackStage.POST, CallbackStage.ERROR]
else:
stages_to_test = [CallbackStage(stage_filter)]
total_callbacks = 0
total_success = 0
total_failures = 0
for stage in stages_to_test:
callbacks = callback_registry.get_callbacks(
model_class, field_name, source, target, stage
)
if not callbacks:
self.stdout.write(
self.style.WARNING(f'\nNo {stage.value.upper()} callbacks registered')
)
continue
self.stdout.write(
self.style.WARNING(f'\n{stage.value.upper()} Callbacks ({len(callbacks)}):')
)
self.stdout.write('-' * 50)
for callback in callbacks:
total_callbacks += 1
callback_info = (
f' {callback.name} (priority: {callback.priority}, '
f'continue_on_error: {callback.continue_on_error})'
)
if dry_run:
self.stdout.write(callback_info)
self.stdout.write(self.style.NOTICE(' → Would execute (dry run)'))
else:
self.stdout.write(callback_info)
# Check should_execute
if not callback.should_execute(context):
self.stdout.write(
self.style.WARNING(' → Skipped (should_execute returned False)')
)
continue
# Execute callback
try:
if stage == CallbackStage.ERROR:
result = callback.execute(
context,
exception=Exception("Test exception")
)
else:
result = callback.execute(context)
if result:
self.stdout.write(self.style.SUCCESS(' → Success'))
total_success += 1
else:
self.stdout.write(self.style.ERROR(' → Failed (returned False)'))
total_failures += 1
except Exception as e:
self.stdout.write(
self.style.ERROR(f' → Exception: {type(e).__name__}: {e}')
)
total_failures += 1
# Summary
self.stdout.write('\n' + '=' * 50)
self.stdout.write(self.style.SUCCESS(f'Total callbacks: {total_callbacks}'))
if not dry_run:
self.stdout.write(self.style.SUCCESS(f'Successful: {total_success}'))
self.stdout.write(
self.style.ERROR(f'Failed: {total_failures}') if total_failures
else self.style.SUCCESS(f'Failed: {total_failures}')
)
# Show monitoring stats if available
if not dry_run:
self.stdout.write(self.style.WARNING('\nRecent Executions:'))
recent = callback_monitor.get_recent_executions(limit=10)
for record in recent:
status = '' if record.success else ''
self.stdout.write(
f' {status} {record.callback_name} [{record.duration_ms:.2f}ms]'
)
def _find_model(self, model_name):
"""Find a model class by name."""
for app_config in apps.get_app_configs():
try:
model = app_config.get_model(model_name)
return model
except LookupError:
continue
return None
def _get_or_create_instance(self, model_class, instance_id, source, field_name):
"""Get an existing instance or create a mock one."""
if instance_id:
try:
return model_class.objects.get(pk=instance_id)
except model_class.DoesNotExist:
raise CommandError(
f"{model_class.__name__} with ID {instance_id} not found"
)
# Create a mock instance for testing
# This won't be saved to the database
instance = model_class()
instance.pk = 0 # Fake ID
setattr(instance, field_name, source)
self.stdout.write(self.style.NOTICE(
'Using mock instance (no --instance-id provided)'
))
return instance

View File

@@ -8,8 +8,50 @@ from .builder import (
from .decorators import ( from .decorators import (
generate_transition_decorator, generate_transition_decorator,
TransitionMethodFactory, TransitionMethodFactory,
with_callbacks,
register_method_callbacks,
)
from .registry import (
TransitionRegistry,
TransitionInfo,
registry_instance,
register_callback,
register_notification_callback,
register_cache_invalidation,
register_related_update,
register_transition_callbacks,
discover_and_register_callbacks,
)
from .callbacks import (
BaseTransitionCallback,
PreTransitionCallback,
PostTransitionCallback,
ErrorTransitionCallback,
TransitionContext,
TransitionCallbackRegistry,
callback_registry,
CallbackStage,
)
from .signals import (
pre_state_transition,
post_state_transition,
state_transition_failed,
register_transition_handler,
on_transition,
on_pre_transition,
on_post_transition,
on_transition_error,
)
from .config import (
CallbackConfig,
callback_config,
get_callback_config,
)
from .monitoring import (
CallbackMonitor,
callback_monitor,
TimedCallbackExecution,
) )
from .registry import TransitionRegistry, TransitionInfo, registry_instance
from .validators import MetadataValidator, ValidationResult from .validators import MetadataValidator, ValidationResult
from .guards import ( from .guards import (
# Role constants # Role constants
@@ -70,10 +112,44 @@ __all__ = [
# Decorators # Decorators
"generate_transition_decorator", "generate_transition_decorator",
"TransitionMethodFactory", "TransitionMethodFactory",
"with_callbacks",
"register_method_callbacks",
# Registry # Registry
"TransitionRegistry", "TransitionRegistry",
"TransitionInfo", "TransitionInfo",
"registry_instance", "registry_instance",
"register_callback",
"register_notification_callback",
"register_cache_invalidation",
"register_related_update",
"register_transition_callbacks",
"discover_and_register_callbacks",
# Callbacks
"BaseTransitionCallback",
"PreTransitionCallback",
"PostTransitionCallback",
"ErrorTransitionCallback",
"TransitionContext",
"TransitionCallbackRegistry",
"callback_registry",
"CallbackStage",
# Signals
"pre_state_transition",
"post_state_transition",
"state_transition_failed",
"register_transition_handler",
"on_transition",
"on_pre_transition",
"on_post_transition",
"on_transition_error",
# Config
"CallbackConfig",
"callback_config",
"get_callback_config",
# Monitoring
"CallbackMonitor",
"callback_monitor",
"TimedCallbackExecution",
# Validators # Validators
"MetadataValidator", "MetadataValidator",
"ValidationResult", "ValidationResult",

View File

@@ -0,0 +1,388 @@
"""
Cache invalidation callbacks for FSM state transitions.
This module provides callback implementations that invalidate cache entries
when state transitions occur.
"""
from typing import Any, Dict, List, Optional, Set, Type
import logging
from django.conf import settings
from django.db import models
from ..callbacks import PostTransitionCallback, TransitionContext
logger = logging.getLogger(__name__)
class CacheInvalidationCallback(PostTransitionCallback):
"""
Base cache invalidation callback for state transitions.
Invalidates cache entries matching specified patterns when a state
transition completes successfully.
"""
name: str = "CacheInvalidationCallback"
def __init__(
self,
patterns: Optional[List[str]] = None,
include_instance_patterns: bool = True,
**kwargs,
):
"""
Initialize the cache invalidation callback.
Args:
patterns: List of cache key patterns to invalidate.
include_instance_patterns: Whether to auto-generate instance-specific patterns.
**kwargs: Additional arguments passed to parent class.
"""
super().__init__(**kwargs)
self.patterns = patterns or []
self.include_instance_patterns = include_instance_patterns
def should_execute(self, context: TransitionContext) -> bool:
"""Check if cache invalidation is enabled."""
callback_settings = getattr(settings, 'STATE_MACHINE_CALLBACKS', {})
if not callback_settings.get('cache_invalidation_enabled', True):
logger.debug("Cache invalidation disabled via settings")
return False
return True
def _get_cache_service(self):
"""Get the EnhancedCacheService instance."""
try:
from apps.core.services.enhanced_cache_service import EnhancedCacheService
return EnhancedCacheService()
except ImportError:
logger.warning("EnhancedCacheService not available")
return None
def _get_instance_patterns(self, context: TransitionContext) -> List[str]:
"""Generate cache key patterns specific to the instance."""
patterns = []
model_name = context.model_name.lower()
instance_id = context.instance.pk
# Standard instance patterns
patterns.append(f"*{model_name}:{instance_id}*")
patterns.append(f"*{model_name}_{instance_id}*")
patterns.append(f"*{model_name}*{instance_id}*")
return patterns
def _get_all_patterns(self, context: TransitionContext) -> Set[str]:
"""Get all patterns to invalidate, including generated ones."""
all_patterns = set(self.patterns)
if self.include_instance_patterns:
all_patterns.update(self._get_instance_patterns(context))
# Substitute placeholders in patterns
model_name = context.model_name.lower()
instance_id = str(context.instance.pk)
substituted = set()
for pattern in all_patterns:
substituted.add(
pattern
.replace('{id}', instance_id)
.replace('{model}', model_name)
)
return substituted
def execute(self, context: TransitionContext) -> bool:
"""Execute the cache invalidation."""
cache_service = self._get_cache_service()
if not cache_service:
# Try using Django's default cache
return self._fallback_invalidation(context)
try:
patterns = self._get_all_patterns(context)
for pattern in patterns:
try:
cache_service.invalidate_pattern(pattern)
logger.debug(f"Invalidated cache pattern: {pattern}")
except Exception as e:
logger.warning(
f"Failed to invalidate cache pattern {pattern}: {e}"
)
logger.info(
f"Cache invalidation completed for {context}: "
f"{len(patterns)} patterns"
)
return True
except Exception as e:
logger.exception(
f"Failed to invalidate cache for {context}: {e}"
)
return False
def _fallback_invalidation(self, context: TransitionContext) -> bool:
"""Fallback cache invalidation using Django's cache framework."""
try:
from django.core.cache import cache
patterns = self._get_all_patterns(context)
# Django's default cache doesn't support pattern deletion
# Log a warning and return True (don't fail the transition)
logger.warning(
f"EnhancedCacheService not available, skipping pattern "
f"invalidation for {len(patterns)} patterns"
)
return True
except Exception as e:
logger.exception(f"Fallback cache invalidation failed: {e}")
return False
class ModelCacheInvalidation(CacheInvalidationCallback):
"""
Invalidates all cache keys for a specific model instance.
Uses model-specific cache key patterns.
"""
name: str = "ModelCacheInvalidation"
# Default patterns by model type
MODEL_PATTERNS = {
'Park': ['*park:{id}*', '*parks*', 'geo:*'],
'Ride': ['*ride:{id}*', '*rides*', '*park:*', 'geo:*'],
'EditSubmission': ['*submission:{id}*', '*moderation*'],
'PhotoSubmission': ['*photo:{id}*', '*moderation*'],
'ModerationReport': ['*report:{id}*', '*moderation*'],
'ModerationQueue': ['*queue*', '*moderation*'],
'BulkOperation': ['*operation:{id}*', '*moderation*'],
}
def __init__(self, **kwargs):
super().__init__(**kwargs)
def _get_instance_patterns(self, context: TransitionContext) -> List[str]:
"""Get model-specific patterns."""
base_patterns = super()._get_instance_patterns(context)
# Add model-specific patterns
model_name = context.model_name
if model_name in self.MODEL_PATTERNS:
model_patterns = self.MODEL_PATTERNS[model_name]
# Substitute {id} placeholder
instance_id = str(context.instance.pk)
for pattern in model_patterns:
base_patterns.append(pattern.replace('{id}', instance_id))
return base_patterns
class RelatedModelCacheInvalidation(CacheInvalidationCallback):
"""
Invalidates cache for related models when a transition occurs.
Useful for maintaining cache consistency across model relationships.
"""
name: str = "RelatedModelCacheInvalidation"
def __init__(
self,
related_fields: Optional[List[str]] = None,
**kwargs,
):
"""
Initialize related model cache invalidation.
Args:
related_fields: List of field names pointing to related models.
**kwargs: Additional arguments.
"""
super().__init__(**kwargs)
self.related_fields = related_fields or []
def _get_related_patterns(self, context: TransitionContext) -> List[str]:
"""Get cache patterns for related models."""
patterns = []
for field_name in self.related_fields:
related_obj = getattr(context.instance, field_name, None)
if related_obj is None:
continue
# Handle foreign key relationships
if hasattr(related_obj, 'pk'):
related_model = type(related_obj).__name__.lower()
related_id = related_obj.pk
patterns.append(f"*{related_model}:{related_id}*")
patterns.append(f"*{related_model}_{related_id}*")
# Handle many-to-many relationships
elif hasattr(related_obj, 'all'):
for obj in related_obj.all():
related_model = type(obj).__name__.lower()
related_id = obj.pk
patterns.append(f"*{related_model}:{related_id}*")
return patterns
def _get_all_patterns(self, context: TransitionContext) -> Set[str]:
"""Get all patterns including related model patterns."""
patterns = super()._get_all_patterns(context)
patterns.update(self._get_related_patterns(context))
return patterns
class PatternCacheInvalidation(CacheInvalidationCallback):
"""
Invalidates cache keys matching specific patterns.
Provides fine-grained control over which cache keys are invalidated.
"""
name: str = "PatternCacheInvalidation"
def __init__(
self,
patterns: List[str],
include_instance_patterns: bool = False,
**kwargs,
):
"""
Initialize pattern-based cache invalidation.
Args:
patterns: List of exact patterns to invalidate.
include_instance_patterns: Whether to include auto-generated patterns.
**kwargs: Additional arguments.
"""
super().__init__(
patterns=patterns,
include_instance_patterns=include_instance_patterns,
**kwargs,
)
class APICacheInvalidation(CacheInvalidationCallback):
"""
Invalidates API response cache entries.
Specialized for API endpoint caching.
"""
name: str = "APICacheInvalidation"
def __init__(
self,
api_prefixes: Optional[List[str]] = None,
include_geo_cache: bool = False,
**kwargs,
):
"""
Initialize API cache invalidation.
Args:
api_prefixes: List of API cache prefixes (e.g., ['api:parks', 'api:rides']).
include_geo_cache: Whether to invalidate geo/map cache entries.
**kwargs: Additional arguments.
"""
super().__init__(**kwargs)
self.api_prefixes = api_prefixes or ['api:*']
self.include_geo_cache = include_geo_cache
def _get_all_patterns(self, context: TransitionContext) -> Set[str]:
"""Get API-specific cache patterns."""
patterns = set()
# Add API patterns
for prefix in self.api_prefixes:
patterns.add(prefix)
# Add geo cache if requested
if self.include_geo_cache:
patterns.add('geo:*')
patterns.add('map:*')
# Add model-specific API patterns
model_name = context.model_name.lower()
instance_id = str(context.instance.pk)
patterns.add(f"api:{model_name}:{instance_id}*")
patterns.add(f"api:{model_name}s*")
return patterns
# Pre-configured cache invalidation callbacks for common models
class ParkCacheInvalidation(CacheInvalidationCallback):
"""Cache invalidation for Park model transitions."""
name: str = "ParkCacheInvalidation"
def __init__(self, **kwargs):
super().__init__(
patterns=[
'*park:{id}*',
'*parks*',
'api:*',
'geo:*',
],
**kwargs,
)
class RideCacheInvalidation(CacheInvalidationCallback):
"""Cache invalidation for Ride model transitions."""
name: str = "RideCacheInvalidation"
def __init__(self, **kwargs):
super().__init__(
patterns=[
'*ride:{id}*',
'*rides*',
'api:*',
'geo:*',
],
**kwargs,
)
def _get_instance_patterns(self, context: TransitionContext) -> List[str]:
"""Include parent park cache patterns."""
patterns = super()._get_instance_patterns(context)
# Invalidate parent park's cache
park = getattr(context.instance, 'park', None)
if park:
park_id = park.pk if hasattr(park, 'pk') else park
patterns.append(f"*park:{park_id}*")
patterns.append(f"*park_{park_id}*")
return patterns
class ModerationCacheInvalidation(CacheInvalidationCallback):
"""Cache invalidation for moderation-related model transitions."""
name: str = "ModerationCacheInvalidation"
def __init__(self, **kwargs):
super().__init__(
patterns=[
'*submission*',
'*moderation*',
'api:moderation*',
],
**kwargs,
)

View File

@@ -0,0 +1,509 @@
"""
Notification callbacks for FSM state transitions.
This module provides callback implementations that send notifications
when state transitions occur.
"""
from typing import Any, Dict, List, Optional, Type
import logging
from django.conf import settings
from django.db import models
from ..callbacks import PostTransitionCallback, TransitionContext
logger = logging.getLogger(__name__)
class NotificationCallback(PostTransitionCallback):
"""
Generic notification callback for state transitions.
Sends notifications using the NotificationService when a state
transition completes successfully.
"""
name: str = "NotificationCallback"
def __init__(
self,
notification_type: str,
recipient_field: str = "submitted_by",
template_name: Optional[str] = None,
include_transition_data: bool = True,
**kwargs,
):
"""
Initialize the notification callback.
Args:
notification_type: The type of notification to create.
recipient_field: The field name on the instance containing the recipient user.
template_name: Optional template name for the notification.
include_transition_data: Whether to include transition metadata in extra_data.
**kwargs: Additional arguments passed to parent class.
"""
super().__init__(**kwargs)
self.notification_type = notification_type
self.recipient_field = recipient_field
self.template_name = template_name
self.include_transition_data = include_transition_data
def should_execute(self, context: TransitionContext) -> bool:
"""Check if notifications are enabled and recipient exists."""
# Check if notifications are disabled in settings
callback_settings = getattr(settings, 'STATE_MACHINE_CALLBACKS', {})
if not callback_settings.get('notifications_enabled', True):
logger.debug("Notifications disabled via settings")
return False
# Check if recipient exists
recipient = self._get_recipient(context.instance)
if not recipient:
logger.debug(
f"No recipient found at {self.recipient_field} for {context}"
)
return False
return True
def _get_recipient(self, instance: models.Model) -> Optional[Any]:
"""Get the notification recipient from the instance."""
return getattr(instance, self.recipient_field, None)
def _get_notification_service(self):
"""Get the NotificationService instance."""
try:
from apps.accounts.services.notification_service import NotificationService
return NotificationService()
except ImportError:
logger.warning("NotificationService not available")
return None
def _build_extra_data(self, context: TransitionContext) -> Dict[str, Any]:
"""Build extra data for the notification."""
extra_data = {}
if self.include_transition_data:
extra_data['transition'] = {
'source_state': context.source_state,
'target_state': context.target_state,
'field_name': context.field_name,
'timestamp': context.timestamp.isoformat(),
}
if context.user:
extra_data['transition']['by_user_id'] = context.user.id
extra_data['transition']['by_username'] = getattr(
context.user, 'username', str(context.user)
)
# Include any extra data from the context
extra_data.update(context.extra_data)
return extra_data
def _get_notification_title(self, context: TransitionContext) -> str:
"""Get the notification title based on context."""
model_name = context.model_name
return f"{model_name} status changed to {context.target_state}"
def _get_notification_message(self, context: TransitionContext) -> str:
"""Get the notification message based on context."""
model_name = context.model_name
return (
f"The {model_name} has transitioned from {context.source_state} "
f"to {context.target_state}."
)
def execute(self, context: TransitionContext) -> bool:
"""Execute the notification callback."""
notification_service = self._get_notification_service()
if not notification_service:
return False
recipient = self._get_recipient(context.instance)
if not recipient:
return False
try:
extra_data = self._build_extra_data(context)
# Create notification with required title and message
notification_service.create_notification(
user=recipient,
notification_type=self.notification_type,
title=self._get_notification_title(context),
message=self._get_notification_message(context),
related_object=context.instance,
extra_data=extra_data,
)
logger.info(
f"Created {self.notification_type} notification for "
f"{recipient} on {context}"
)
return True
except Exception as e:
logger.exception(
f"Failed to create notification for {context}: {e}"
)
return False
class SubmissionApprovedNotification(NotificationCallback):
"""Notification callback for approved submissions."""
name: str = "SubmissionApprovedNotification"
def __init__(self, **kwargs):
super().__init__(
notification_type="submission_approved",
recipient_field="submitted_by",
**kwargs,
)
def execute(self, context: TransitionContext) -> bool:
"""Execute the approval notification."""
notification_service = self._get_notification_service()
if not notification_service:
return False
recipient = self._get_recipient(context.instance)
if not recipient:
return False
try:
# Use the specific method if available
if hasattr(notification_service, 'create_submission_approved_notification'):
notification_service.create_submission_approved_notification(
user=recipient,
submission=context.instance,
approved_by=context.user,
)
else:
# Fall back to generic notification
extra_data = self._build_extra_data(context)
notification_service.create_notification(
user=recipient,
notification_type=self.notification_type,
related_object=context.instance,
extra_data=extra_data,
)
logger.info(
f"Created approval notification for {recipient} on {context}"
)
return True
except Exception as e:
logger.exception(
f"Failed to create approval notification for {context}: {e}"
)
return False
class SubmissionRejectedNotification(NotificationCallback):
"""Notification callback for rejected submissions."""
name: str = "SubmissionRejectedNotification"
def __init__(self, **kwargs):
super().__init__(
notification_type="submission_rejected",
recipient_field="submitted_by",
**kwargs,
)
def execute(self, context: TransitionContext) -> bool:
"""Execute the rejection notification."""
notification_service = self._get_notification_service()
if not notification_service:
return False
recipient = self._get_recipient(context.instance)
if not recipient:
return False
try:
# Use the specific method if available
if hasattr(notification_service, 'create_submission_rejected_notification'):
# Extract rejection reason from extra_data
reason = context.extra_data.get('reason', '')
notification_service.create_submission_rejected_notification(
user=recipient,
submission=context.instance,
rejected_by=context.user,
reason=reason,
)
else:
extra_data = self._build_extra_data(context)
notification_service.create_notification(
user=recipient,
notification_type=self.notification_type,
related_object=context.instance,
extra_data=extra_data,
)
logger.info(
f"Created rejection notification for {recipient} on {context}"
)
return True
except Exception as e:
logger.exception(
f"Failed to create rejection notification for {context}: {e}"
)
return False
class SubmissionEscalatedNotification(NotificationCallback):
"""Notification callback for escalated submissions."""
name: str = "SubmissionEscalatedNotification"
def __init__(self, admin_recipient: bool = True, **kwargs):
"""
Initialize escalation notification.
Args:
admin_recipient: If True, notify admins. If False, notify submitter.
"""
super().__init__(
notification_type="submission_escalated",
recipient_field="submitted_by" if not admin_recipient else None,
**kwargs,
)
self.admin_recipient = admin_recipient
def _get_admin_users(self):
"""Get admin users to notify."""
try:
from django.contrib.auth import get_user_model
User = get_user_model()
return User.objects.filter(is_staff=True, is_active=True)
except Exception as e:
logger.exception(f"Failed to get admin users: {e}")
return []
def execute(self, context: TransitionContext) -> bool:
"""Execute the escalation notification."""
notification_service = self._get_notification_service()
if not notification_service:
return False
try:
extra_data = self._build_extra_data(context)
# Add escalation reason if available
if 'reason' in context.extra_data:
extra_data['escalation_reason'] = context.extra_data['reason']
if self.admin_recipient:
# Notify admin users
admins = self._get_admin_users()
for admin in admins:
notification_service.create_notification(
user=admin,
notification_type=self.notification_type,
related_object=context.instance,
extra_data=extra_data,
)
logger.info(
f"Created escalation notifications for {admins.count()} admins"
)
else:
# Notify the submitter
recipient = self._get_recipient(context.instance)
if recipient:
notification_service.create_notification(
user=recipient,
notification_type=self.notification_type,
related_object=context.instance,
extra_data=extra_data,
)
logger.info(
f"Created escalation notification for {recipient}"
)
return True
except Exception as e:
logger.exception(
f"Failed to create escalation notification for {context}: {e}"
)
return False
class StatusChangeNotification(NotificationCallback):
"""
Generic notification for entity status changes.
Used for Parks and Rides when their operational status changes.
"""
name: str = "StatusChangeNotification"
def __init__(
self,
significant_states: Optional[List[str]] = None,
notify_admins: bool = True,
**kwargs,
):
"""
Initialize status change notification.
Args:
significant_states: States that trigger admin notifications.
notify_admins: Whether to notify admin users.
"""
super().__init__(
notification_type="status_change",
**kwargs,
)
self.significant_states = significant_states or [
'CLOSED_PERM', 'DEMOLISHED', 'RELOCATED'
]
self.notify_admins = notify_admins
def should_execute(self, context: TransitionContext) -> bool:
"""Only execute for significant state changes."""
# Check if notifications are disabled
callback_settings = getattr(settings, 'STATE_MACHINE_CALLBACKS', {})
if not callback_settings.get('notifications_enabled', True):
return False
# Only notify for significant status changes
if context.target_state not in self.significant_states:
return False
return True
def execute(self, context: TransitionContext) -> bool:
"""Execute the status change notification."""
if not self.notify_admins:
return True
notification_service = self._get_notification_service()
if not notification_service:
return False
try:
extra_data = self._build_extra_data(context)
extra_data['entity_type'] = context.model_name
extra_data['entity_id'] = context.instance.pk
# Notify admin users
admins = self._get_admin_users()
for admin in admins:
notification_service.create_notification(
user=admin,
notification_type=self.notification_type,
related_object=context.instance,
extra_data=extra_data,
)
logger.info(
f"Created status change notifications for {context.model_name} "
f"({context.source_state}{context.target_state})"
)
return True
except Exception as e:
logger.exception(
f"Failed to create status change notification for {context}: {e}"
)
return False
def _get_admin_users(self):
"""Get admin users to notify."""
try:
from django.contrib.auth import get_user_model
User = get_user_model()
return User.objects.filter(is_staff=True, is_active=True)
except Exception as e:
logger.exception(f"Failed to get admin users: {e}")
return []
class ModerationNotificationCallback(NotificationCallback):
"""
Specialized callback for moderation-related notifications.
Handles notifications for ModerationReport, ModerationQueue,
and BulkOperation models.
"""
name: str = "ModerationNotificationCallback"
# Mapping of (model_name, target_state) to notification type
NOTIFICATION_MAPPING = {
('ModerationReport', 'UNDER_REVIEW'): 'report_under_review',
('ModerationReport', 'RESOLVED'): 'report_resolved',
('ModerationQueue', 'IN_PROGRESS'): 'queue_in_progress',
('ModerationQueue', 'COMPLETED'): 'queue_completed',
('BulkOperation', 'RUNNING'): 'bulk_operation_started',
('BulkOperation', 'COMPLETED'): 'bulk_operation_completed',
('BulkOperation', 'FAILED'): 'bulk_operation_failed',
}
def __init__(self, **kwargs):
super().__init__(
notification_type="moderation",
**kwargs,
)
def _get_notification_type(self, context: TransitionContext) -> Optional[str]:
"""Get the specific notification type based on model and state."""
key = (context.model_name, context.target_state)
return self.NOTIFICATION_MAPPING.get(key)
def _get_recipient(self, instance: models.Model) -> Optional[Any]:
"""Get the appropriate recipient based on model type."""
# Try common recipient fields
for field in ['reporter', 'assigned_to', 'created_by', 'submitted_by']:
recipient = getattr(instance, field, None)
if recipient:
return recipient
return None
def execute(self, context: TransitionContext) -> bool:
"""Execute the moderation notification."""
notification_service = self._get_notification_service()
if not notification_service:
return False
notification_type = self._get_notification_type(context)
if not notification_type:
logger.debug(
f"No notification type defined for {context.model_name} "
f"{context.target_state}"
)
return True # Not an error, just no notification needed
recipient = self._get_recipient(context.instance)
if not recipient:
logger.debug(f"No recipient found for {context}")
return True
try:
extra_data = self._build_extra_data(context)
notification_service.create_notification(
user=recipient,
notification_type=notification_type,
related_object=context.instance,
extra_data=extra_data,
)
logger.info(
f"Created {notification_type} notification for {recipient}"
)
return True
except Exception as e:
logger.exception(
f"Failed to create moderation notification for {context}: {e}"
)
return False

View File

@@ -0,0 +1,432 @@
"""
Related model update callbacks for FSM state transitions.
This module provides callback implementations that update related models
when state transitions occur.
"""
from typing import Any, Callable, Dict, List, Optional, Set, Type
import logging
from django.conf import settings
from django.db import models, transaction
from ..callbacks import PostTransitionCallback, TransitionContext
logger = logging.getLogger(__name__)
class RelatedModelUpdateCallback(PostTransitionCallback):
"""
Base callback for updating related models after state transitions.
Executes custom update logic when a state transition completes.
"""
name: str = "RelatedModelUpdateCallback"
def __init__(
self,
update_function: Optional[Callable[[TransitionContext], bool]] = None,
use_transaction: bool = True,
**kwargs,
):
"""
Initialize the related model update callback.
Args:
update_function: Optional function to call with the context.
use_transaction: Whether to wrap updates in a transaction.
**kwargs: Additional arguments passed to parent class.
"""
super().__init__(**kwargs)
self.update_function = update_function
self.use_transaction = use_transaction
def should_execute(self, context: TransitionContext) -> bool:
"""Check if related updates are enabled."""
callback_settings = getattr(settings, 'STATE_MACHINE_CALLBACKS', {})
if not callback_settings.get('related_updates_enabled', True):
logger.debug("Related model updates disabled via settings")
return False
return True
def perform_update(self, context: TransitionContext) -> bool:
"""
Perform the actual update logic.
Override this method in subclasses to define specific update behavior.
Args:
context: The transition context.
Returns:
True if update succeeded, False otherwise.
"""
if self.update_function:
return self.update_function(context)
return True
def execute(self, context: TransitionContext) -> bool:
"""Execute the related model update."""
try:
if self.use_transaction:
with transaction.atomic():
return self.perform_update(context)
else:
return self.perform_update(context)
except Exception as e:
logger.exception(
f"Failed to update related models for {context}: {e}"
)
return False
class ParkCountUpdateCallback(RelatedModelUpdateCallback):
"""
Updates park ride counts when ride status changes.
Recalculates ride_count and coaster_count on the parent Park
when a Ride transitions to or from an operational status.
"""
name: str = "ParkCountUpdateCallback"
# Status values that count as "active" rides
ACTIVE_STATUSES = {'OPERATING', 'SEASONAL', 'UNDER_CONSTRUCTION'}
# Status values that indicate a ride is no longer countable
INACTIVE_STATUSES = {'CLOSED_PERM', 'DEMOLISHED', 'RELOCATED', 'REMOVED'}
def should_execute(self, context: TransitionContext) -> bool:
"""Only execute when status affects ride counts."""
if not super().should_execute(context):
return False
# Check if this transition affects ride counts
source = context.source_state
target = context.target_state
# Execute if transitioning to/from an active or inactive status
source_affects = source in self.ACTIVE_STATUSES or source in self.INACTIVE_STATUSES
target_affects = target in self.ACTIVE_STATUSES or target in self.INACTIVE_STATUSES
return source_affects or target_affects
def perform_update(self, context: TransitionContext) -> bool:
"""Update park ride counts."""
instance = context.instance
# Get the parent park
park = getattr(instance, 'park', None)
if not park:
logger.debug(f"No park found for ride {instance.pk}")
return True
try:
# Import here to avoid circular imports
from apps.parks.models.parks import Park
from apps.rides.models.rides import Ride
# Get the park ID (handle both object and ID)
park_id = park.pk if hasattr(park, 'pk') else park
# Calculate new counts efficiently
ride_queryset = Ride.objects.filter(park_id=park_id)
# Count active rides
active_statuses = list(self.ACTIVE_STATUSES)
ride_count = ride_queryset.filter(
status__in=active_statuses
).count()
# Count active coasters
coaster_count = ride_queryset.filter(
status__in=active_statuses,
ride_type='ROLLER_COASTER'
).count()
# Update park counts
Park.objects.filter(id=park_id).update(
ride_count=ride_count,
coaster_count=coaster_count,
)
logger.info(
f"Updated park {park_id} counts: "
f"ride_count={ride_count}, coaster_count={coaster_count}"
)
return True
except Exception as e:
logger.exception(
f"Failed to update park counts for {instance.pk}: {e}"
)
return False
class SearchTextUpdateCallback(RelatedModelUpdateCallback):
"""
Recalculates search_text field when status changes.
Updates the search_text field to include the new status label
for search indexing purposes.
"""
name: str = "SearchTextUpdateCallback"
def perform_update(self, context: TransitionContext) -> bool:
"""Update the search_text field."""
instance = context.instance
# Check if instance has search_text field
if not hasattr(instance, 'search_text'):
logger.debug(
f"{context.model_name} has no search_text field"
)
return True
try:
# Call the model's update_search_text method if available
if hasattr(instance, 'update_search_text'):
instance.update_search_text()
instance.save(update_fields=['search_text'])
logger.info(
f"Updated search_text for {context.model_name} {instance.pk}"
)
else:
# Build search text manually
self._build_search_text(instance, context)
return True
except Exception as e:
logger.exception(
f"Failed to update search_text for {instance.pk}: {e}"
)
return False
def _build_search_text(
self,
instance: models.Model,
context: TransitionContext,
) -> None:
"""Build search text from instance fields."""
parts = []
# Common searchable fields
for field in ['name', 'title', 'description', 'location']:
value = getattr(instance, field, None)
if value:
parts.append(str(value))
# Add status label
status_field = getattr(instance, context.field_name, None)
if status_field:
# Try to get the display label
display_method = f'get_{context.field_name}_display'
if hasattr(instance, display_method):
parts.append(getattr(instance, display_method)())
else:
parts.append(str(status_field))
# Update search_text
instance.search_text = ' '.join(parts)
instance.save(update_fields=['search_text'])
class ComputedFieldUpdateCallback(RelatedModelUpdateCallback):
"""
Generic callback for updating computed fields after transitions.
Recalculates specified computed fields when a state transition occurs.
"""
name: str = "ComputedFieldUpdateCallback"
def __init__(
self,
computed_fields: Optional[List[str]] = None,
update_method: Optional[str] = None,
**kwargs,
):
"""
Initialize computed field update callback.
Args:
computed_fields: List of field names to update.
update_method: Name of method to call for updating fields.
**kwargs: Additional arguments.
"""
super().__init__(**kwargs)
self.computed_fields = computed_fields or []
self.update_method = update_method
def perform_update(self, context: TransitionContext) -> bool:
"""Update computed fields."""
instance = context.instance
try:
# Call update method if specified
if self.update_method:
method = getattr(instance, self.update_method, None)
if method and callable(method):
method()
# Update specific fields
updated_fields = []
for field_name in self.computed_fields:
update_method_name = f'compute_{field_name}'
if hasattr(instance, update_method_name):
method = getattr(instance, update_method_name)
if callable(method):
new_value = method()
setattr(instance, field_name, new_value)
updated_fields.append(field_name)
# Save updated fields
if updated_fields:
instance.save(update_fields=updated_fields)
logger.info(
f"Updated computed fields {updated_fields} for "
f"{context.model_name} {instance.pk}"
)
return True
except Exception as e:
logger.exception(
f"Failed to update computed fields for {instance.pk}: {e}"
)
return False
class RideStatusUpdateCallback(RelatedModelUpdateCallback):
"""
Handles ride-specific updates when status changes.
Updates post_closing_status, closing_date, and related fields.
"""
name: str = "RideStatusUpdateCallback"
def should_execute(self, context: TransitionContext) -> bool:
"""Execute for specific ride status transitions."""
if not super().should_execute(context):
return False
# Only execute for Ride model
if context.model_name != 'Ride':
return False
return True
def perform_update(self, context: TransitionContext) -> bool:
"""Perform ride-specific status updates."""
instance = context.instance
target = context.target_state
try:
# Handle CLOSING → post_closing_status transition
if context.source_state == 'CLOSING' and target != 'CLOSING':
post_closing_status = getattr(instance, 'post_closing_status', None)
if post_closing_status and target == post_closing_status:
# Clear post_closing_status after applying it
instance.post_closing_status = None
instance.save(update_fields=['post_closing_status'])
logger.info(
f"Cleared post_closing_status for ride {instance.pk}"
)
return True
except Exception as e:
logger.exception(
f"Failed to update ride status fields for {instance.pk}: {e}"
)
return False
class ModerationQueueUpdateCallback(RelatedModelUpdateCallback):
"""
Updates moderation queue and statistics when submissions change state.
"""
name: str = "ModerationQueueUpdateCallback"
def should_execute(self, context: TransitionContext) -> bool:
"""Execute for moderation-related models."""
if not super().should_execute(context):
return False
# Only for submission and report models
model_name = context.model_name
return model_name in (
'EditSubmission', 'PhotoSubmission', 'ModerationReport'
)
def perform_update(self, context: TransitionContext) -> bool:
"""Update moderation queue entries."""
instance = context.instance
target = context.target_state
try:
# Mark related queue items as completed when submission is resolved
if target in ('APPROVED', 'REJECTED', 'RESOLVED'):
self._update_queue_items(instance, context)
return True
except Exception as e:
logger.exception(
f"Failed to update moderation queue for {instance.pk}: {e}"
)
return False
def _update_queue_items(
self,
instance: models.Model,
context: TransitionContext,
) -> None:
"""Update related queue items to completed status."""
try:
from apps.moderation.models import ModerationQueue
# Find related queue items
content_type_id = self._get_content_type_id(instance)
if not content_type_id:
return
queue_items = ModerationQueue.objects.filter(
content_type_id=content_type_id,
object_id=instance.pk,
status='IN_PROGRESS',
)
for item in queue_items:
if hasattr(item, 'complete'):
item.complete(user=context.user)
else:
item.status = 'COMPLETED'
item.save(update_fields=['status'])
if queue_items.exists():
logger.info(
f"Marked {queue_items.count()} queue items as completed"
)
except ImportError:
logger.debug("ModerationQueue model not available")
except Exception as e:
logger.warning(f"Failed to update queue items: {e}")
def _get_content_type_id(self, instance: models.Model) -> Optional[int]:
"""Get content type ID for the instance."""
try:
from django.contrib.contenttypes.models import ContentType
content_type = ContentType.objects.get_for_model(type(instance))
return content_type.pk
except Exception:
return None

View File

@@ -0,0 +1,403 @@
"""
Callback configuration system for FSM state transitions.
This module provides centralized configuration for all FSM transition callbacks,
including enable/disable settings, priorities, and environment-specific overrides.
"""
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Type
import logging
from django.conf import settings
from django.db import models
logger = logging.getLogger(__name__)
@dataclass
class TransitionCallbackConfig:
"""Configuration for callbacks on a specific transition."""
notifications_enabled: bool = True
cache_invalidation_enabled: bool = True
related_updates_enabled: bool = True
notification_template: Optional[str] = None
cache_patterns: List[str] = field(default_factory=list)
priority: int = 100
extra_data: Dict[str, Any] = field(default_factory=dict)
@dataclass
class ModelCallbackConfig:
"""Configuration for all callbacks on a model."""
model_name: str
field_name: str = 'status'
transitions: Dict[tuple, TransitionCallbackConfig] = field(default_factory=dict)
default_config: TransitionCallbackConfig = field(default_factory=TransitionCallbackConfig)
class CallbackConfig:
"""
Centralized configuration for all FSM transition callbacks.
Provides settings for:
- Enabling/disabling callback types globally or per-transition
- Configuring notification templates
- Setting cache invalidation patterns
- Defining callback priorities
Configuration can be overridden via Django settings.
"""
# Default settings
DEFAULT_SETTINGS = {
'enabled': True,
'notifications_enabled': True,
'cache_invalidation_enabled': True,
'related_updates_enabled': True,
'debug_mode': False,
'log_callbacks': False,
}
# Model-specific configurations
MODEL_CONFIGS: Dict[str, ModelCallbackConfig] = {}
def __init__(self):
self._settings = self._load_settings()
self._model_configs = self._build_model_configs()
def _load_settings(self) -> Dict[str, Any]:
"""Load settings from Django configuration."""
django_settings = getattr(settings, 'STATE_MACHINE_CALLBACKS', {})
merged = dict(self.DEFAULT_SETTINGS)
merged.update(django_settings)
return merged
def _build_model_configs(self) -> Dict[str, ModelCallbackConfig]:
"""Build model-specific configurations."""
return {
'EditSubmission': ModelCallbackConfig(
model_name='EditSubmission',
field_name='status',
transitions={
('PENDING', 'APPROVED'): TransitionCallbackConfig(
notification_template='submission_approved',
cache_patterns=['*submission*', '*moderation*'],
),
('PENDING', 'REJECTED'): TransitionCallbackConfig(
notification_template='submission_rejected',
cache_patterns=['*submission*', '*moderation*'],
),
('PENDING', 'ESCALATED'): TransitionCallbackConfig(
notification_template='submission_escalated',
cache_patterns=['*submission*', '*moderation*'],
),
},
),
'PhotoSubmission': ModelCallbackConfig(
model_name='PhotoSubmission',
field_name='status',
transitions={
('PENDING', 'APPROVED'): TransitionCallbackConfig(
notification_template='photo_approved',
cache_patterns=['*photo*', '*moderation*'],
),
('PENDING', 'REJECTED'): TransitionCallbackConfig(
notification_template='photo_rejected',
cache_patterns=['*photo*', '*moderation*'],
),
},
),
'ModerationReport': ModelCallbackConfig(
model_name='ModerationReport',
field_name='status',
transitions={
('PENDING', 'UNDER_REVIEW'): TransitionCallbackConfig(
notification_template='report_under_review',
cache_patterns=['*report*', '*moderation*'],
),
('UNDER_REVIEW', 'RESOLVED'): TransitionCallbackConfig(
notification_template='report_resolved',
cache_patterns=['*report*', '*moderation*'],
),
},
),
'ModerationQueue': ModelCallbackConfig(
model_name='ModerationQueue',
field_name='status',
transitions={
('PENDING', 'IN_PROGRESS'): TransitionCallbackConfig(
notification_template='queue_in_progress',
cache_patterns=['*queue*', '*moderation*'],
),
('IN_PROGRESS', 'COMPLETED'): TransitionCallbackConfig(
notification_template='queue_completed',
cache_patterns=['*queue*', '*moderation*'],
),
},
),
'BulkOperation': ModelCallbackConfig(
model_name='BulkOperation',
field_name='status',
transitions={
('PENDING', 'RUNNING'): TransitionCallbackConfig(
notification_template='bulk_operation_started',
cache_patterns=['*operation*', '*moderation*'],
),
('RUNNING', 'COMPLETED'): TransitionCallbackConfig(
notification_template='bulk_operation_completed',
cache_patterns=['*operation*', '*moderation*'],
),
('RUNNING', 'FAILED'): TransitionCallbackConfig(
notification_template='bulk_operation_failed',
cache_patterns=['*operation*', '*moderation*'],
),
},
),
'Park': ModelCallbackConfig(
model_name='Park',
field_name='status',
default_config=TransitionCallbackConfig(
cache_patterns=['*park*', 'api:*', 'geo:*'],
),
transitions={
('*', 'CLOSED_PERM'): TransitionCallbackConfig(
notifications_enabled=True,
notification_template='park_closed_permanently',
cache_patterns=['*park*', 'api:*', 'geo:*'],
),
('*', 'OPERATING'): TransitionCallbackConfig(
notifications_enabled=False,
cache_patterns=['*park*', 'api:*', 'geo:*'],
),
},
),
'Ride': ModelCallbackConfig(
model_name='Ride',
field_name='status',
default_config=TransitionCallbackConfig(
cache_patterns=['*ride*', '*park*', 'api:*', 'geo:*'],
),
transitions={
('*', 'OPERATING'): TransitionCallbackConfig(
cache_patterns=['*ride*', '*park*', 'api:*', 'geo:*'],
related_updates_enabled=True,
),
('*', 'CLOSED_PERM'): TransitionCallbackConfig(
cache_patterns=['*ride*', '*park*', 'api:*', 'geo:*'],
related_updates_enabled=True,
),
('*', 'DEMOLISHED'): TransitionCallbackConfig(
cache_patterns=['*ride*', '*park*', 'api:*', 'geo:*'],
related_updates_enabled=True,
),
('*', 'RELOCATED'): TransitionCallbackConfig(
cache_patterns=['*ride*', '*park*', 'api:*', 'geo:*'],
related_updates_enabled=True,
),
},
),
}
@property
def enabled(self) -> bool:
"""Check if callbacks are globally enabled."""
return self._settings.get('enabled', True)
@property
def notifications_enabled(self) -> bool:
"""Check if notification callbacks are enabled."""
return self._settings.get('notifications_enabled', True)
@property
def cache_invalidation_enabled(self) -> bool:
"""Check if cache invalidation is enabled."""
return self._settings.get('cache_invalidation_enabled', True)
@property
def related_updates_enabled(self) -> bool:
"""Check if related model updates are enabled."""
return self._settings.get('related_updates_enabled', True)
@property
def debug_mode(self) -> bool:
"""Check if debug mode is enabled."""
return self._settings.get('debug_mode', False)
@property
def log_callbacks(self) -> bool:
"""Check if callback logging is enabled."""
return self._settings.get('log_callbacks', False)
def get_config(
self,
model_name: str,
source: str,
target: str,
) -> TransitionCallbackConfig:
"""
Get configuration for a specific transition.
Args:
model_name: Name of the model.
source: Source state.
target: Target state.
Returns:
TransitionCallbackConfig for the transition.
"""
model_config = self._model_configs.get(model_name)
if not model_config:
return TransitionCallbackConfig()
# Try exact match first
config = model_config.transitions.get((source, target))
if config:
return config
# Try wildcard source
config = model_config.transitions.get(('*', target))
if config:
return config
# Try wildcard target
config = model_config.transitions.get((source, '*'))
if config:
return config
# Return default config
return model_config.default_config
def is_notification_enabled(
self,
model_name: str,
source: str,
target: str,
) -> bool:
"""Check if notifications are enabled for a transition."""
if not self.enabled or not self.notifications_enabled:
return False
config = self.get_config(model_name, source, target)
return config.notifications_enabled
def is_cache_invalidation_enabled(
self,
model_name: str,
source: str,
target: str,
) -> bool:
"""Check if cache invalidation is enabled for a transition."""
if not self.enabled or not self.cache_invalidation_enabled:
return False
config = self.get_config(model_name, source, target)
return config.cache_invalidation_enabled
def is_related_updates_enabled(
self,
model_name: str,
source: str,
target: str,
) -> bool:
"""Check if related updates are enabled for a transition."""
if not self.enabled or not self.related_updates_enabled:
return False
config = self.get_config(model_name, source, target)
return config.related_updates_enabled
def get_cache_patterns(
self,
model_name: str,
source: str,
target: str,
) -> List[str]:
"""Get cache invalidation patterns for a transition."""
config = self.get_config(model_name, source, target)
return config.cache_patterns
def get_notification_template(
self,
model_name: str,
source: str,
target: str,
) -> Optional[str]:
"""Get notification template for a transition."""
config = self.get_config(model_name, source, target)
return config.notification_template
def register_model_config(
self,
model_class: Type[models.Model],
config: ModelCallbackConfig,
) -> None:
"""
Register a custom model configuration.
Args:
model_class: The model class.
config: The configuration to register.
"""
model_name = model_class.__name__
self._model_configs[model_name] = config
logger.debug(f"Registered callback config for {model_name}")
def update_transition_config(
self,
model_name: str,
source: str,
target: str,
**kwargs,
) -> None:
"""
Update configuration for a specific transition.
Args:
model_name: Name of the model.
source: Source state.
target: Target state.
**kwargs: Configuration values to update.
"""
if model_name not in self._model_configs:
self._model_configs[model_name] = ModelCallbackConfig(
model_name=model_name
)
model_config = self._model_configs[model_name]
transition_key = (source, target)
if transition_key not in model_config.transitions:
model_config.transitions[transition_key] = TransitionCallbackConfig()
config = model_config.transitions[transition_key]
for key, value in kwargs.items():
if hasattr(config, key):
setattr(config, key, value)
def reload_settings(self) -> None:
"""Reload settings from Django configuration."""
self._settings = self._load_settings()
logger.debug("Reloaded callback configuration settings")
# Global configuration instance
callback_config = CallbackConfig()
def get_callback_config() -> CallbackConfig:
"""Get the global callback configuration instance."""
return callback_config
__all__ = [
'TransitionCallbackConfig',
'ModelCallbackConfig',
'CallbackConfig',
'callback_config',
'get_callback_config',
]

View File

@@ -303,6 +303,8 @@ class TransitionMethodFactory:
target: str, target: str,
field_name: str = "status", field_name: str = "status",
permission_guard: Optional[Callable] = None, permission_guard: Optional[Callable] = None,
enable_callbacks: bool = True,
emit_signals: bool = True,
) -> Callable: ) -> Callable:
""" """
Create an approval transition method. Create an approval transition method.
@@ -312,6 +314,8 @@ class TransitionMethodFactory:
target: Target state value target: Target state value
field_name: Name of the FSM field field_name: Name of the FSM field
permission_guard: Optional permission guard permission_guard: Optional permission guard
enable_callbacks: Whether to wrap with callback execution
emit_signals: Whether to emit Django signals
Returns: Returns:
Approval transition method Approval transition method
@@ -335,6 +339,13 @@ class TransitionMethodFactory:
instance.approved_at = timezone.now() instance.approved_at = timezone.now()
# Apply callback wrapper if enabled
if enable_callbacks:
approve = with_callbacks(
field_name=field_name,
emit_signals=emit_signals,
)(approve)
return approve return approve
@staticmethod @staticmethod
@@ -343,6 +354,8 @@ class TransitionMethodFactory:
target: str, target: str,
field_name: str = "status", field_name: str = "status",
permission_guard: Optional[Callable] = None, permission_guard: Optional[Callable] = None,
enable_callbacks: bool = True,
emit_signals: bool = True,
) -> Callable: ) -> Callable:
""" """
Create a rejection transition method. Create a rejection transition method.
@@ -352,6 +365,8 @@ class TransitionMethodFactory:
target: Target state value target: Target state value
field_name: Name of the FSM field field_name: Name of the FSM field
permission_guard: Optional permission guard permission_guard: Optional permission guard
enable_callbacks: Whether to wrap with callback execution
emit_signals: Whether to emit Django signals
Returns: Returns:
Rejection transition method Rejection transition method
@@ -375,6 +390,13 @@ class TransitionMethodFactory:
instance.rejected_at = timezone.now() instance.rejected_at = timezone.now()
# Apply callback wrapper if enabled
if enable_callbacks:
reject = with_callbacks(
field_name=field_name,
emit_signals=emit_signals,
)(reject)
return reject return reject
@staticmethod @staticmethod
@@ -383,6 +405,8 @@ class TransitionMethodFactory:
target: str, target: str,
field_name: str = "status", field_name: str = "status",
permission_guard: Optional[Callable] = None, permission_guard: Optional[Callable] = None,
enable_callbacks: bool = True,
emit_signals: bool = True,
) -> Callable: ) -> Callable:
""" """
Create an escalation transition method. Create an escalation transition method.
@@ -392,6 +416,8 @@ class TransitionMethodFactory:
target: Target state value target: Target state value
field_name: Name of the FSM field field_name: Name of the FSM field
permission_guard: Optional permission guard permission_guard: Optional permission guard
enable_callbacks: Whether to wrap with callback execution
emit_signals: Whether to emit Django signals
Returns: Returns:
Escalation transition method Escalation transition method
@@ -415,6 +441,13 @@ class TransitionMethodFactory:
instance.escalated_at = timezone.now() instance.escalated_at = timezone.now()
# Apply callback wrapper if enabled
if enable_callbacks:
escalate = with_callbacks(
field_name=field_name,
emit_signals=emit_signals,
)(escalate)
return escalate return escalate
@staticmethod @staticmethod
@@ -425,6 +458,8 @@ class TransitionMethodFactory:
field_name: str = "status", field_name: str = "status",
permission_guard: Optional[Callable] = None, permission_guard: Optional[Callable] = None,
docstring: Optional[str] = None, docstring: Optional[str] = None,
enable_callbacks: bool = True,
emit_signals: bool = True,
) -> Callable: ) -> Callable:
""" """
Create a generic transition method. Create a generic transition method.
@@ -436,6 +471,8 @@ class TransitionMethodFactory:
field_name: Name of the FSM field field_name: Name of the FSM field
permission_guard: Optional permission guard permission_guard: Optional permission guard
docstring: Optional docstring for the method docstring: Optional docstring for the method
enable_callbacks: Whether to wrap with callback execution
emit_signals: Whether to emit Django signals
Returns: Returns:
Generic transition method Generic transition method
@@ -460,6 +497,13 @@ class TransitionMethodFactory:
f"Transition from {source} to {target}" f"Transition from {source} to {target}"
) )
# Apply callback wrapper if enabled
if enable_callbacks:
generic_transition = with_callbacks(
field_name=field_name,
emit_signals=emit_signals,
)(generic_transition)
return generic_transition return generic_transition

View File

@@ -0,0 +1,455 @@
"""
Callback monitoring and debugging for FSM state transitions.
This module provides tools for monitoring callback execution,
tracking performance, and debugging transition issues.
"""
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional, Tuple, Type
from collections import defaultdict
import logging
import time
import threading
from django.conf import settings
from django.db import models
from .callbacks import TransitionContext
logger = logging.getLogger(__name__)
@dataclass
class CallbackExecutionRecord:
"""Record of a single callback execution."""
callback_name: str
model_name: str
field_name: str
source_state: str
target_state: str
stage: str
timestamp: datetime
duration_ms: float
success: bool
error_message: Optional[str] = None
instance_id: Optional[int] = None
user_id: Optional[int] = None
@dataclass
class CallbackStats:
"""Statistics for a specific callback."""
callback_name: str
total_executions: int = 0
successful_executions: int = 0
failed_executions: int = 0
total_duration_ms: float = 0.0
min_duration_ms: float = float('inf')
max_duration_ms: float = 0.0
last_execution: Optional[datetime] = None
last_error: Optional[str] = None
@property
def avg_duration_ms(self) -> float:
"""Calculate average execution duration."""
if self.total_executions == 0:
return 0.0
return self.total_duration_ms / self.total_executions
@property
def success_rate(self) -> float:
"""Calculate success rate as percentage."""
if self.total_executions == 0:
return 0.0
return (self.successful_executions / self.total_executions) * 100
def record_execution(
self,
duration_ms: float,
success: bool,
error_message: Optional[str] = None,
) -> None:
"""Record a callback execution."""
self.total_executions += 1
self.total_duration_ms += duration_ms
self.min_duration_ms = min(self.min_duration_ms, duration_ms)
self.max_duration_ms = max(self.max_duration_ms, duration_ms)
self.last_execution = datetime.now()
if success:
self.successful_executions += 1
else:
self.failed_executions += 1
self.last_error = error_message
class CallbackMonitor:
"""
Monitor for tracking callback execution and collecting metrics.
Provides:
- Execution time tracking
- Success/failure counting
- Error logging
- Performance statistics
"""
_instance: Optional['CallbackMonitor'] = None
_lock = threading.Lock()
def __new__(cls) -> 'CallbackMonitor':
if cls._instance is None:
with cls._lock:
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._initialized = False
return cls._instance
def __init__(self):
if self._initialized:
return
self._stats: Dict[str, CallbackStats] = defaultdict(
lambda: CallbackStats(callback_name="")
)
self._recent_executions: List[CallbackExecutionRecord] = []
self._max_recent_records = 1000
self._enabled = self._check_enabled()
self._debug_mode = self._check_debug_mode()
self._initialized = True
def _check_enabled(self) -> bool:
"""Check if monitoring is enabled."""
callback_settings = getattr(settings, 'STATE_MACHINE_CALLBACKS', {})
return callback_settings.get('monitoring_enabled', True)
def _check_debug_mode(self) -> bool:
"""Check if debug mode is enabled."""
callback_settings = getattr(settings, 'STATE_MACHINE_CALLBACKS', {})
return callback_settings.get('debug_mode', settings.DEBUG)
def is_enabled(self) -> bool:
"""Check if monitoring is currently enabled."""
return self._enabled
def enable(self) -> None:
"""Enable monitoring."""
self._enabled = True
logger.info("Callback monitoring enabled")
def disable(self) -> None:
"""Disable monitoring."""
self._enabled = False
logger.info("Callback monitoring disabled")
def set_debug_mode(self, enabled: bool) -> None:
"""Set debug mode."""
self._debug_mode = enabled
logger.info(f"Callback debug mode {'enabled' if enabled else 'disabled'}")
def record_execution(
self,
callback_name: str,
context: TransitionContext,
stage: str,
duration_ms: float,
success: bool,
error_message: Optional[str] = None,
) -> None:
"""
Record a callback execution.
Args:
callback_name: Name of the executed callback.
context: The transition context.
stage: Callback stage (pre/post/error).
duration_ms: Execution duration in milliseconds.
success: Whether execution was successful.
error_message: Error message if execution failed.
"""
if not self._enabled:
return
# Update stats
stats = self._stats[callback_name]
stats.callback_name = callback_name
stats.record_execution(duration_ms, success, error_message)
# Create execution record
record = CallbackExecutionRecord(
callback_name=callback_name,
model_name=context.model_name,
field_name=context.field_name,
source_state=context.source_state,
target_state=context.target_state,
stage=stage,
timestamp=datetime.now(),
duration_ms=duration_ms,
success=success,
error_message=error_message,
instance_id=context.instance.pk if context.instance else None,
user_id=context.user.id if context.user else None,
)
# Store recent executions (with size limit)
self._recent_executions.append(record)
if len(self._recent_executions) > self._max_recent_records:
self._recent_executions = self._recent_executions[-self._max_recent_records:]
# Log in debug mode
if self._debug_mode:
self._log_execution(record)
def _log_execution(self, record: CallbackExecutionRecord) -> None:
"""Log callback execution details."""
status = "" if record.success else ""
log_message = (
f"{status} Callback: {record.callback_name} "
f"({record.model_name}.{record.field_name}: "
f"{record.source_state}{record.target_state}) "
f"[{record.stage}] {record.duration_ms:.2f}ms"
)
if record.success:
logger.debug(log_message)
else:
logger.warning(f"{log_message} - Error: {record.error_message}")
def get_stats(self, callback_name: Optional[str] = None) -> Dict[str, CallbackStats]:
"""
Get callback statistics.
Args:
callback_name: If provided, return stats for this callback only.
Returns:
Dictionary of callback stats.
"""
if callback_name:
if callback_name in self._stats:
return {callback_name: self._stats[callback_name]}
return {}
return dict(self._stats)
def get_recent_executions(
self,
limit: int = 100,
callback_name: Optional[str] = None,
model_name: Optional[str] = None,
success_only: Optional[bool] = None,
) -> List[CallbackExecutionRecord]:
"""
Get recent execution records.
Args:
limit: Maximum number of records to return.
callback_name: Filter by callback name.
model_name: Filter by model name.
success_only: If True, only successful; if False, only failed.
Returns:
List of execution records.
"""
records = self._recent_executions.copy()
# Apply filters
if callback_name:
records = [r for r in records if r.callback_name == callback_name]
if model_name:
records = [r for r in records if r.model_name == model_name]
if success_only is not None:
records = [r for r in records if r.success == success_only]
# Return most recent first
return list(reversed(records[-limit:]))
def get_failure_summary(self) -> Dict[str, Any]:
"""Get a summary of callback failures."""
failures = [r for r in self._recent_executions if not r.success]
# Group by callback
by_callback: Dict[str, List[CallbackExecutionRecord]] = defaultdict(list)
for record in failures:
by_callback[record.callback_name].append(record)
# Build summary
summary = {
'total_failures': len(failures),
'by_callback': {
name: {
'count': len(records),
'last_error': records[-1].error_message if records else None,
'last_occurrence': records[-1].timestamp if records else None,
}
for name, records in by_callback.items()
},
}
return summary
def get_performance_report(self) -> Dict[str, Any]:
"""Get a performance report for all callbacks."""
report = {
'callbacks': {},
'summary': {
'total_callbacks': len(self._stats),
'total_executions': sum(s.total_executions for s in self._stats.values()),
'total_failures': sum(s.failed_executions for s in self._stats.values()),
'avg_duration_ms': 0.0,
},
}
total_duration = 0.0
total_count = 0
for name, stats in self._stats.items():
report['callbacks'][name] = {
'executions': stats.total_executions,
'success_rate': f"{stats.success_rate:.1f}%",
'avg_duration_ms': f"{stats.avg_duration_ms:.2f}",
'min_duration_ms': f"{stats.min_duration_ms:.2f}" if stats.min_duration_ms != float('inf') else "N/A",
'max_duration_ms': f"{stats.max_duration_ms:.2f}",
'last_execution': stats.last_execution.isoformat() if stats.last_execution else None,
}
total_duration += stats.total_duration_ms
total_count += stats.total_executions
if total_count > 0:
report['summary']['avg_duration_ms'] = total_duration / total_count
return report
def clear_stats(self) -> None:
"""Clear all statistics."""
self._stats.clear()
self._recent_executions.clear()
logger.info("Callback statistics cleared")
@classmethod
def reset_instance(cls) -> None:
"""Reset the singleton instance. For testing."""
cls._instance = None
# Global monitor instance
callback_monitor = CallbackMonitor()
class TimedCallbackExecution:
"""
Context manager for timing callback execution.
Usage:
with TimedCallbackExecution(callback, context, stage) as timer:
callback.execute(context)
# Timer automatically records execution
"""
def __init__(
self,
callback_name: str,
context: TransitionContext,
stage: str,
):
self.callback_name = callback_name
self.context = context
self.stage = stage
self.start_time = 0.0
self.success = True
self.error_message: Optional[str] = None
def __enter__(self) -> 'TimedCallbackExecution':
self.start_time = time.perf_counter()
return self
def __exit__(self, exc_type, exc_val, exc_tb) -> bool:
duration_ms = (time.perf_counter() - self.start_time) * 1000
if exc_type is not None:
self.success = False
self.error_message = str(exc_val)
callback_monitor.record_execution(
callback_name=self.callback_name,
context=self.context,
stage=self.stage,
duration_ms=duration_ms,
success=self.success,
error_message=self.error_message,
)
# Don't suppress exceptions
return False
def mark_failure(self, error_message: str) -> None:
"""Mark execution as failed."""
self.success = False
self.error_message = error_message
def log_transition_start(context: TransitionContext) -> None:
"""Log the start of a transition."""
if callback_monitor._debug_mode:
logger.debug(
f"→ Starting transition: {context.model_name}.{context.field_name} "
f"{context.source_state}{context.target_state}"
)
def log_transition_end(
context: TransitionContext,
success: bool,
duration_ms: float,
) -> None:
"""Log the end of a transition."""
if callback_monitor._debug_mode:
status = "" if success else ""
logger.debug(
f"{status} Completed transition: {context.model_name}.{context.field_name} "
f"{context.source_state}{context.target_state} [{duration_ms:.2f}ms]"
)
def get_callback_execution_order(
model_name: str,
source: str,
target: str,
) -> List[Tuple[str, str, int]]:
"""
Get the order of callback execution for a transition.
Args:
model_name: Name of the model.
source: Source state.
target: Target state.
Returns:
List of (stage, callback_name, priority) tuples in execution order.
"""
from .callbacks import callback_registry, CallbackStage
order = []
for stage in [CallbackStage.PRE, CallbackStage.POST, CallbackStage.ERROR]:
# We need to get the model class, but we only have the name
# This is mainly for debugging, so we'll return what we can
order.append((stage.value, f"[{model_name}:{source}{target}]", 0))
return order
__all__ = [
'CallbackExecutionRecord',
'CallbackStats',
'CallbackMonitor',
'callback_monitor',
'TimedCallbackExecution',
'log_transition_start',
'log_transition_end',
'get_callback_execution_order',
]

View File

@@ -1,10 +1,16 @@
"""TransitionRegistry - Centralized registry for managing FSM transitions.""" """TransitionRegistry - Centralized registry for managing FSM transitions."""
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Dict, List, Optional, Any, Tuple from typing import Callable, Dict, List, Optional, Any, Tuple, Type
import logging
from django.db import models
from apps.core.state_machine.builder import StateTransitionBuilder from apps.core.state_machine.builder import StateTransitionBuilder
logger = logging.getLogger(__name__)
@dataclass @dataclass
class TransitionInfo: class TransitionInfo:
"""Information about a state transition.""" """Information about a state transition."""
@@ -280,4 +286,216 @@ class TransitionRegistry:
registry_instance = TransitionRegistry() registry_instance = TransitionRegistry()
__all__ = ["TransitionInfo", "TransitionRegistry", "registry_instance"] # Callback registration helpers
def register_callback(
model_class: Type[models.Model],
field_name: str,
source: str,
target: str,
callback: Any,
stage: str = 'post',
) -> None:
"""
Register a callback for a specific state transition.
Args:
model_class: The model class to register the callback for.
field_name: The FSM field name.
source: Source state (use '*' for any).
target: Target state (use '*' for any).
callback: The callback instance.
stage: When to execute ('pre', 'post', 'error').
"""
from .callbacks import callback_registry, CallbackStage
callback_registry.register(
model_class=model_class,
field_name=field_name,
source=source,
target=target,
callback=callback,
stage=CallbackStage(stage) if isinstance(stage, str) else stage,
)
def register_notification_callback(
model_class: Type[models.Model],
field_name: str,
source: str,
target: str,
notification_type: str,
recipient_field: str = 'submitted_by',
) -> None:
"""
Register a notification callback for a state transition.
Args:
model_class: The model class.
field_name: The FSM field name.
source: Source state.
target: Target state.
notification_type: Type of notification to send.
recipient_field: Field containing the recipient user.
"""
from .callbacks.notifications import NotificationCallback
callback = NotificationCallback(
notification_type=notification_type,
recipient_field=recipient_field,
)
register_callback(model_class, field_name, source, target, callback, 'post')
def register_cache_invalidation(
model_class: Type[models.Model],
field_name: str,
cache_patterns: Optional[List[str]] = None,
source: str = '*',
target: str = '*',
) -> None:
"""
Register cache invalidation for state transitions.
Args:
model_class: The model class.
field_name: The FSM field name.
cache_patterns: List of cache key patterns to invalidate.
source: Source state filter.
target: Target state filter.
"""
from .callbacks.cache import CacheInvalidationCallback
callback = CacheInvalidationCallback(patterns=cache_patterns or [])
register_callback(model_class, field_name, source, target, callback, 'post')
def register_related_update(
model_class: Type[models.Model],
field_name: str,
update_func: Callable,
source: str = '*',
target: str = '*',
) -> None:
"""
Register a related model update callback.
Args:
model_class: The model class.
field_name: The FSM field name.
update_func: Function to call with TransitionContext.
source: Source state filter.
target: Target state filter.
"""
from .callbacks.related_updates import RelatedModelUpdateCallback
callback = RelatedModelUpdateCallback(update_function=update_func)
register_callback(model_class, field_name, source, target, callback, 'post')
def register_transition_callbacks(cls: Type[models.Model]) -> Type[models.Model]:
"""
Class decorator to auto-register callbacks from model's Meta.
Usage:
@register_transition_callbacks
class EditSubmission(StateMachineMixin, TrackedModel):
class Meta:
transition_callbacks = {
('PENDING', 'APPROVED'): [
SubmissionApprovedNotification(),
CacheInvalidationCallback(patterns=['*submission*']),
]
}
Args:
cls: The model class to decorate.
Returns:
The decorated model class.
"""
meta = getattr(cls, 'Meta', None)
if not meta:
return cls
transition_callbacks = getattr(meta, 'transition_callbacks', None)
if not transition_callbacks:
return cls
# Get the FSM field name
field_name = getattr(meta, 'fsm_field', 'status')
# Register each callback
for (source, target), callbacks in transition_callbacks.items():
if not isinstance(callbacks, (list, tuple)):
callbacks = [callbacks]
for callback in callbacks:
register_callback(
model_class=cls,
field_name=field_name,
source=source,
target=target,
callback=callback,
)
logger.debug(f"Registered transition callbacks for {cls.__name__}")
return cls
def discover_and_register_callbacks() -> None:
"""
Discover and register callbacks for all models with StateMachineMixin.
This function should be called in an AppConfig.ready() method.
"""
from django.apps import apps
registered_count = 0
for model in apps.get_models():
# Check if model has StateMachineMixin
if not hasattr(model, '_fsm_metadata') and not hasattr(model, 'Meta'):
continue
meta = getattr(model, 'Meta', None)
if not meta:
continue
transition_callbacks = getattr(meta, 'transition_callbacks', None)
if not transition_callbacks:
continue
# Get the FSM field name
field_name = getattr(meta, 'fsm_field', 'status')
# Register callbacks
for (source, target), callbacks in transition_callbacks.items():
if not isinstance(callbacks, (list, tuple)):
callbacks = [callbacks]
for callback in callbacks:
register_callback(
model_class=model,
field_name=field_name,
source=source,
target=target,
callback=callback,
)
registered_count += 1
logger.info(f"Discovered and registered {registered_count} transition callbacks")
__all__ = [
"TransitionInfo",
"TransitionRegistry",
"registry_instance",
# Callback registration helpers
"register_callback",
"register_notification_callback",
"register_cache_invalidation",
"register_related_update",
"register_transition_callbacks",
"discover_and_register_callbacks",
]

View File

@@ -0,0 +1,335 @@
"""
Signal-based hook system for FSM state transitions.
This module defines custom Django signals emitted during state machine
transitions and provides utilities for connecting signal handlers.
"""
from typing import Any, Callable, Dict, List, Optional, Type, Union
import logging
from django.db import models
from django.dispatch import Signal, receiver
from .callbacks import TransitionContext
logger = logging.getLogger(__name__)
# Custom signals for state machine transitions
pre_state_transition = Signal()
"""
Signal sent before a state transition occurs.
Arguments:
sender: The model class of the transitioning instance.
instance: The model instance undergoing transition.
source: The source state value.
target: The target state value.
user: The user initiating the transition (if available).
context: TransitionContext with full transition metadata.
"""
post_state_transition = Signal()
"""
Signal sent after a successful state transition.
Arguments:
sender: The model class of the transitioning instance.
instance: The model instance that transitioned.
source: The source state value.
target: The target state value.
user: The user who initiated the transition.
context: TransitionContext with full transition metadata.
"""
state_transition_failed = Signal()
"""
Signal sent when a state transition fails.
Arguments:
sender: The model class of the transitioning instance.
instance: The model instance that failed to transition.
source: The source state value.
target: The intended target state value.
user: The user who initiated the transition.
exception: The exception that caused the failure.
context: TransitionContext with full transition metadata.
"""
class TransitionSignalHandler:
"""
Utility class for managing transition signal handlers.
Provides a cleaner interface for connecting and disconnecting
signal handlers filtered by model class and transition states.
"""
def __init__(self):
self._handlers: Dict[str, List[Callable]] = {}
def register(
self,
model_class: Type[models.Model],
source: str,
target: str,
handler: Callable,
stage: str = 'post',
) -> None:
"""
Register a handler for a specific transition.
Args:
model_class: The model class to handle transitions for.
source: Source state (use '*' for any).
target: Target state (use '*' for any).
handler: The handler function to call.
stage: 'pre', 'post', or 'error'.
"""
key = self._make_key(model_class, source, target, stage)
if key not in self._handlers:
self._handlers[key] = []
self._handlers[key].append(handler)
# Connect to appropriate signal
signal = self._get_signal(stage)
self._connect_signal(signal, model_class, source, target, handler)
logger.debug(
f"Registered {stage} transition handler for "
f"{model_class.__name__}: {source}{target}"
)
def unregister(
self,
model_class: Type[models.Model],
source: str,
target: str,
handler: Callable,
stage: str = 'post',
) -> None:
"""Unregister a previously registered handler."""
key = self._make_key(model_class, source, target, stage)
if key in self._handlers and handler in self._handlers[key]:
self._handlers[key].remove(handler)
signal = self._get_signal(stage)
signal.disconnect(handler, sender=model_class)
def _make_key(
self,
model_class: Type[models.Model],
source: str,
target: str,
stage: str,
) -> str:
"""Create a unique key for handler registration."""
return f"{model_class.__name__}:{source}:{target}:{stage}"
def _get_signal(self, stage: str) -> Signal:
"""Get the signal for a given stage."""
if stage == 'pre':
return pre_state_transition
elif stage == 'error':
return state_transition_failed
return post_state_transition
def _connect_signal(
self,
signal: Signal,
model_class: Type[models.Model],
source: str,
target: str,
handler: Callable,
) -> None:
"""Connect a filtered handler to the signal."""
def filtered_handler(sender, **kwargs):
# Check if this is the right model
if sender != model_class:
return
# Check source state
signal_source = kwargs.get('source', '')
if source != '*' and str(signal_source) != source:
return
# Check target state
signal_target = kwargs.get('target', '')
if target != '*' and str(signal_target) != target:
return
# Call the handler
return handler(**kwargs)
signal.connect(filtered_handler, sender=model_class, weak=False)
# Global signal handler instance
transition_signal_handler = TransitionSignalHandler()
def register_transition_handler(
model_class: Type[models.Model],
source: str,
target: str,
handler: Callable,
stage: str = 'post',
) -> None:
"""
Convenience function to register a transition signal handler.
Args:
model_class: The model class to handle transitions for.
source: Source state (use '*' for any).
target: Target state (use '*' for any).
handler: The handler function to call.
stage: 'pre', 'post', or 'error'.
"""
transition_signal_handler.register(
model_class, source, target, handler, stage
)
def connect_fsm_log_signals() -> None:
"""
Connect to django-fsm-log signals for audit logging.
This function should be called in an AppConfig.ready() method
to set up integration with django-fsm-log's StateLog.
"""
try:
from django_fsm_log.models import StateLog
@receiver(models.signals.post_save, sender=StateLog)
def log_state_transition(sender, instance, created, **kwargs):
"""Log state transitions from django-fsm-log."""
if created:
logger.info(
f"FSM Transition: {instance.content_type} "
f"({instance.object_id}): {instance.source_state}"
f"{instance.state} by {instance.by}"
)
logger.debug("Connected to django-fsm-log signals")
except ImportError:
logger.debug("django-fsm-log not available, skipping signal connection")
class TransitionHandlerDecorator:
"""
Decorator for registering transition handlers.
Usage:
@on_transition(EditSubmission, 'PENDING', 'APPROVED')
def handle_approval(instance, source, target, user, **kwargs):
# Handle the approval
pass
"""
def __init__(
self,
model_class: Type[models.Model],
source: str = '*',
target: str = '*',
stage: str = 'post',
):
"""
Initialize the decorator.
Args:
model_class: The model class to handle.
source: Source state filter.
target: Target state filter.
stage: When to execute ('pre', 'post', 'error').
"""
self.model_class = model_class
self.source = source
self.target = target
self.stage = stage
def __call__(self, func: Callable) -> Callable:
"""Register the decorated function as a handler."""
register_transition_handler(
self.model_class,
self.source,
self.target,
func,
self.stage,
)
return func
def on_transition(
model_class: Type[models.Model],
source: str = '*',
target: str = '*',
stage: str = 'post',
) -> TransitionHandlerDecorator:
"""
Decorator factory for registering transition handlers.
Args:
model_class: The model class to handle.
source: Source state filter ('*' for any).
target: Target state filter ('*' for any).
stage: When to execute ('pre', 'post', 'error').
Returns:
Decorator for registering the handler function.
Example:
@on_transition(EditSubmission, source='PENDING', target='APPROVED')
def notify_user(instance, source, target, user, **kwargs):
send_notification(instance.submitted_by, "Your submission was approved!")
"""
return TransitionHandlerDecorator(model_class, source, target, stage)
def on_pre_transition(
model_class: Type[models.Model],
source: str = '*',
target: str = '*',
) -> TransitionHandlerDecorator:
"""Decorator for pre-transition handlers."""
return on_transition(model_class, source, target, stage='pre')
def on_post_transition(
model_class: Type[models.Model],
source: str = '*',
target: str = '*',
) -> TransitionHandlerDecorator:
"""Decorator for post-transition handlers."""
return on_transition(model_class, source, target, stage='post')
def on_transition_error(
model_class: Type[models.Model],
source: str = '*',
target: str = '*',
) -> TransitionHandlerDecorator:
"""Decorator for transition error handlers."""
return on_transition(model_class, source, target, stage='error')
__all__ = [
# Signals
'pre_state_transition',
'post_state_transition',
'state_transition_failed',
# Handler registration
'TransitionSignalHandler',
'transition_signal_handler',
'register_transition_handler',
'connect_fsm_log_signals',
# Decorators
'on_transition',
'on_pre_transition',
'on_post_transition',
'on_transition_error',
]

View File

@@ -1,13 +1,24 @@
import logging
from django.apps import AppConfig from django.apps import AppConfig
logger = logging.getLogger(__name__)
class ModerationConfig(AppConfig): class ModerationConfig(AppConfig):
default_auto_field = "django.db.models.BigAutoField" default_auto_field = "django.db.models.BigAutoField"
name = "apps.moderation" name = "apps.moderation"
verbose_name = "Content Moderation" verbose_name = "Content Moderation"
def ready(self): def ready(self):
"""Initialize state machines for all moderation models.""" """Initialize state machines and callbacks for all moderation models."""
self._apply_state_machines()
self._register_callbacks()
self._register_signal_handlers()
def _apply_state_machines(self):
"""Apply FSM to all moderation models."""
from apps.core.state_machine import apply_state_machine from apps.core.state_machine import apply_state_machine
from .models import ( from .models import (
EditSubmission, EditSubmission,
@@ -48,3 +59,113 @@ class ModerationConfig(AppConfig):
choice_group="photo_submission_statuses", choice_group="photo_submission_statuses",
domain="moderation", domain="moderation",
) )
def _register_callbacks(self):
"""Register FSM transition callbacks for moderation models."""
from apps.core.state_machine.registry import register_callback
from apps.core.state_machine.callbacks.notifications import (
SubmissionApprovedNotification,
SubmissionRejectedNotification,
SubmissionEscalatedNotification,
ModerationNotificationCallback,
)
from apps.core.state_machine.callbacks.cache import (
ModerationCacheInvalidation,
)
from .models import (
EditSubmission,
ModerationReport,
ModerationQueue,
BulkOperation,
PhotoSubmission,
)
# EditSubmission callbacks
register_callback(
EditSubmission, 'status', 'PENDING', 'APPROVED',
SubmissionApprovedNotification()
)
register_callback(
EditSubmission, 'status', 'PENDING', 'APPROVED',
ModerationCacheInvalidation()
)
register_callback(
EditSubmission, 'status', 'PENDING', 'REJECTED',
SubmissionRejectedNotification()
)
register_callback(
EditSubmission, 'status', 'PENDING', 'REJECTED',
ModerationCacheInvalidation()
)
register_callback(
EditSubmission, 'status', 'PENDING', 'ESCALATED',
SubmissionEscalatedNotification()
)
register_callback(
EditSubmission, 'status', 'PENDING', 'ESCALATED',
ModerationCacheInvalidation()
)
# PhotoSubmission callbacks
register_callback(
PhotoSubmission, 'status', 'PENDING', 'APPROVED',
SubmissionApprovedNotification()
)
register_callback(
PhotoSubmission, 'status', 'PENDING', 'APPROVED',
ModerationCacheInvalidation()
)
register_callback(
PhotoSubmission, 'status', 'PENDING', 'REJECTED',
SubmissionRejectedNotification()
)
register_callback(
PhotoSubmission, 'status', 'PENDING', 'REJECTED',
ModerationCacheInvalidation()
)
register_callback(
PhotoSubmission, 'status', 'PENDING', 'ESCALATED',
SubmissionEscalatedNotification()
)
# ModerationReport callbacks
register_callback(
ModerationReport, 'status', '*', '*',
ModerationNotificationCallback()
)
register_callback(
ModerationReport, 'status', '*', '*',
ModerationCacheInvalidation()
)
# ModerationQueue callbacks
register_callback(
ModerationQueue, 'status', '*', '*',
ModerationNotificationCallback()
)
register_callback(
ModerationQueue, 'status', '*', '*',
ModerationCacheInvalidation()
)
# BulkOperation callbacks
register_callback(
BulkOperation, 'status', '*', '*',
ModerationNotificationCallback()
)
register_callback(
BulkOperation, 'status', '*', '*',
ModerationCacheInvalidation()
)
logger.debug("Registered moderation transition callbacks")
def _register_signal_handlers(self):
"""Register signal handlers for moderation transitions."""
from .signals import register_moderation_signal_handlers
try:
register_moderation_signal_handlers()
logger.debug("Registered moderation signal handlers")
except Exception as e:
logger.warning(f"Could not register moderation signal handlers: {e}")

View File

@@ -9,6 +9,11 @@ This module contains models for the ThrillWiki moderation system, including:
- BulkOperation: Administrative bulk operations - BulkOperation: Administrative bulk operations
All models use pghistory for change tracking and TrackedModel base class. All models use pghistory for change tracking and TrackedModel base class.
Callback System Integration:
All FSM-enabled models in this module support the callback system.
Callbacks for notifications, cache invalidation, and related updates
are registered via the callback configuration defined in each model's Meta class.
""" """
from typing import Any, Dict, Optional, Union from typing import Any, Dict, Optional, Union
@@ -29,6 +34,35 @@ from apps.core.state_machine import RichFSMField, StateMachineMixin
UserType = Union[AbstractBaseUser, AnonymousUser] UserType = Union[AbstractBaseUser, AnonymousUser]
# Lazy callback imports to avoid circular dependencies
def _get_notification_callbacks():
"""Lazy import of notification callbacks."""
from apps.core.state_machine.callbacks.notifications import (
SubmissionApprovedNotification,
SubmissionRejectedNotification,
SubmissionEscalatedNotification,
ModerationNotificationCallback,
)
return {
'approved': SubmissionApprovedNotification,
'rejected': SubmissionRejectedNotification,
'escalated': SubmissionEscalatedNotification,
'moderation': ModerationNotificationCallback,
}
def _get_cache_callbacks():
"""Lazy import of cache callbacks."""
from apps.core.state_machine.callbacks.cache import (
CacheInvalidationCallback,
ModerationCacheInvalidation,
)
return {
'generic': CacheInvalidationCallback,
'moderation': ModerationCacheInvalidation,
}
# ============================================================================ # ============================================================================
# Original EditSubmission Model (Preserved) # Original EditSubmission Model (Preserved)
# ============================================================================ # ============================================================================

View File

@@ -0,0 +1,326 @@
"""
Signal handlers for moderation-related FSM state transitions.
This module provides signal handlers that execute when moderation
models (EditSubmission, PhotoSubmission, ModerationReport, etc.)
undergo state transitions.
"""
import logging
from django.conf import settings
from django.dispatch import receiver
from apps.core.state_machine.signals import (
post_state_transition,
state_transition_failed,
)
logger = logging.getLogger(__name__)
def handle_submission_approved(instance, source, target, user, context=None, **kwargs):
"""
Handle submission approval transitions.
Called when an EditSubmission or PhotoSubmission is approved.
Args:
instance: The submission instance.
source: The source state.
target: The target state.
user: The user who approved.
context: Optional TransitionContext.
"""
if target != 'APPROVED':
return
logger.info(
f"Submission {instance.pk} approved by {user if user else 'system'}"
)
# Trigger notification (handled by NotificationCallback)
# Invalidate cache (handled by CacheInvalidationCallback)
# Apply the submission changes if applicable
if hasattr(instance, 'apply_changes'):
try:
instance.apply_changes()
logger.info(f"Applied changes for submission {instance.pk}")
except Exception as e:
logger.exception(
f"Failed to apply changes for submission {instance.pk}: {e}"
)
def handle_submission_rejected(instance, source, target, user, context=None, **kwargs):
"""
Handle submission rejection transitions.
Called when an EditSubmission or PhotoSubmission is rejected.
Args:
instance: The submission instance.
source: The source state.
target: The target state.
user: The user who rejected.
context: Optional TransitionContext.
"""
if target != 'REJECTED':
return
reason = context.extra_data.get('reason', '') if context else ''
logger.info(
f"Submission {instance.pk} rejected by {user if user else 'system'}"
f"{f': {reason}' if reason else ''}"
)
def handle_submission_escalated(instance, source, target, user, context=None, **kwargs):
"""
Handle submission escalation transitions.
Called when an EditSubmission or PhotoSubmission is escalated.
Args:
instance: The submission instance.
source: The source state.
target: The target state.
user: The user who escalated.
context: Optional TransitionContext.
"""
if target != 'ESCALATED':
return
reason = context.extra_data.get('reason', '') if context else ''
logger.info(
f"Submission {instance.pk} escalated by {user if user else 'system'}"
f"{f': {reason}' if reason else ''}"
)
# Create escalation task if task system is available
_create_escalation_task(instance, user, reason)
def handle_report_resolved(instance, source, target, user, context=None, **kwargs):
"""
Handle moderation report resolution.
Called when a ModerationReport is resolved.
Args:
instance: The ModerationReport instance.
source: The source state.
target: The target state.
user: The user who resolved.
context: Optional TransitionContext.
"""
if target != 'RESOLVED':
return
logger.info(
f"ModerationReport {instance.pk} resolved by {user if user else 'system'}"
)
# Update related queue items
_update_related_queue_items(instance, 'COMPLETED')
def handle_queue_completed(instance, source, target, user, context=None, **kwargs):
"""
Handle moderation queue completion.
Called when a ModerationQueue item is completed.
Args:
instance: The ModerationQueue instance.
source: The source state.
target: The target state.
user: The user who completed.
context: Optional TransitionContext.
"""
if target != 'COMPLETED':
return
logger.info(
f"ModerationQueue {instance.pk} completed by {user if user else 'system'}"
)
# Update moderation statistics
_update_moderation_stats(instance, user)
def handle_bulk_operation_status(instance, source, target, user, context=None, **kwargs):
"""
Handle bulk operation status changes.
Called when a BulkOperation transitions between states.
Args:
instance: The BulkOperation instance.
source: The source state.
target: The target state.
user: The user who initiated the change.
context: Optional TransitionContext.
"""
logger.info(
f"BulkOperation {instance.pk} transitioned: {source}{target}"
)
if target == 'COMPLETED':
_finalize_bulk_operation(instance, success=True)
elif target == 'FAILED':
_finalize_bulk_operation(instance, success=False)
# Helper functions
def _create_escalation_task(instance, user, reason):
"""Create an escalation task for admin review."""
try:
from apps.moderation.models import ModerationQueue
# Create a queue item for the escalated submission
ModerationQueue.objects.create(
content_object=instance,
priority='HIGH',
reason=f"Escalated: {reason}" if reason else "Escalated for review",
created_by=user,
)
logger.info(f"Created escalation queue item for submission {instance.pk}")
except ImportError:
logger.debug("ModerationQueue model not available")
except Exception as e:
logger.warning(f"Failed to create escalation task: {e}")
def _update_related_queue_items(instance, status):
"""Update queue items related to a moderation object."""
try:
from django.contrib.contenttypes.models import ContentType
from apps.moderation.models import ModerationQueue
content_type = ContentType.objects.get_for_model(type(instance))
queue_items = ModerationQueue.objects.filter(
content_type=content_type,
object_id=instance.pk,
).exclude(status=status)
updated = queue_items.update(status=status)
if updated:
logger.info(f"Updated {updated} queue items to {status}")
except ImportError:
logger.debug("ModerationQueue model not available")
except Exception as e:
logger.warning(f"Failed to update queue items: {e}")
def _update_moderation_stats(instance, user):
"""Update moderation statistics for a user."""
if not user:
return
try:
# Update user's moderation count if they have a profile
profile = getattr(user, 'profile', None)
if profile and hasattr(profile, 'moderation_count'):
profile.moderation_count += 1
profile.save(update_fields=['moderation_count'])
logger.debug(f"Updated moderation count for {user}")
except Exception as e:
logger.warning(f"Failed to update moderation stats: {e}")
def _finalize_bulk_operation(instance, success):
"""Finalize a bulk operation after completion or failure."""
try:
from django.utils import timezone
instance.completed_at = timezone.now()
instance.save(update_fields=['completed_at'])
if success:
logger.info(
f"BulkOperation {instance.pk} completed successfully: "
f"{getattr(instance, 'success_count', 0)} succeeded, "
f"{getattr(instance, 'failure_count', 0)} failed"
)
else:
logger.warning(
f"BulkOperation {instance.pk} failed: "
f"{getattr(instance, 'error_message', 'Unknown error')}"
)
except Exception as e:
logger.warning(f"Failed to finalize bulk operation: {e}")
# Signal handler registration
def register_moderation_signal_handlers():
"""
Register all moderation signal handlers.
This function should be called in the moderation app's AppConfig.ready() method.
"""
from apps.core.state_machine.signals import register_transition_handler
try:
from apps.moderation.models import (
EditSubmission,
PhotoSubmission,
ModerationReport,
ModerationQueue,
BulkOperation,
)
# EditSubmission handlers
register_transition_handler(
EditSubmission, '*', 'APPROVED',
handle_submission_approved, stage='post'
)
register_transition_handler(
EditSubmission, '*', 'REJECTED',
handle_submission_rejected, stage='post'
)
register_transition_handler(
EditSubmission, '*', 'ESCALATED',
handle_submission_escalated, stage='post'
)
# PhotoSubmission handlers
register_transition_handler(
PhotoSubmission, '*', 'APPROVED',
handle_submission_approved, stage='post'
)
register_transition_handler(
PhotoSubmission, '*', 'REJECTED',
handle_submission_rejected, stage='post'
)
register_transition_handler(
PhotoSubmission, '*', 'ESCALATED',
handle_submission_escalated, stage='post'
)
# ModerationReport handlers
register_transition_handler(
ModerationReport, '*', 'RESOLVED',
handle_report_resolved, stage='post'
)
# ModerationQueue handlers
register_transition_handler(
ModerationQueue, '*', 'COMPLETED',
handle_queue_completed, stage='post'
)
# BulkOperation handlers
register_transition_handler(
BulkOperation, '*', '*',
handle_bulk_operation_status, stage='post'
)
logger.info("Registered moderation signal handlers")
except ImportError as e:
logger.warning(f"Could not register moderation signal handlers: {e}")

View File

@@ -1,6 +1,11 @@
import logging
from django.apps import AppConfig from django.apps import AppConfig
logger = logging.getLogger(__name__)
class ParksConfig(AppConfig): class ParksConfig(AppConfig):
default_auto_field = "django.db.models.BigAutoField" default_auto_field = "django.db.models.BigAutoField"
name = "apps.parks" name = "apps.parks"
@@ -8,6 +13,12 @@ class ParksConfig(AppConfig):
def ready(self): def ready(self):
import apps.parks.signals # noqa: F401 - Register signals import apps.parks.signals # noqa: F401 - Register signals
import apps.parks.choices # noqa: F401 - Register choices import apps.parks.choices # noqa: F401 - Register choices
self._apply_state_machines()
self._register_callbacks()
def _apply_state_machines(self):
"""Apply FSM to park models."""
from apps.core.state_machine import apply_state_machine from apps.core.state_machine import apply_state_machine
from apps.parks.models import Park from apps.parks.models import Park
@@ -15,3 +26,48 @@ class ParksConfig(AppConfig):
apply_state_machine( apply_state_machine(
Park, field_name="status", choice_group="statuses", domain="parks" Park, field_name="status", choice_group="statuses", domain="parks"
) )
def _register_callbacks(self):
"""Register FSM transition callbacks for park models."""
from apps.core.state_machine.registry import register_callback
from apps.core.state_machine.callbacks.cache import (
ParkCacheInvalidation,
APICacheInvalidation,
)
from apps.core.state_machine.callbacks.notifications import (
StatusChangeNotification,
)
from apps.core.state_machine.callbacks.related_updates import (
SearchTextUpdateCallback,
)
from apps.parks.models import Park
# Cache invalidation for all park status changes
register_callback(
Park, 'status', '*', '*',
ParkCacheInvalidation()
)
# API cache invalidation
register_callback(
Park, 'status', '*', '*',
APICacheInvalidation(include_geo_cache=True)
)
# Search text update
register_callback(
Park, 'status', '*', '*',
SearchTextUpdateCallback()
)
# Notification for significant status changes
register_callback(
Park, 'status', '*', 'CLOSED_PERM',
StatusChangeNotification(notify_admins=True)
)
register_callback(
Park, 'status', '*', 'DEMOLISHED',
StatusChangeNotification(notify_admins=True)
)
logger.debug("Registered park transition callbacks")

View File

@@ -1,3 +1,5 @@
import logging
from django.db.models.signals import post_save, post_delete from django.db.models.signals import post_save, post_delete
from django.dispatch import receiver from django.dispatch import receiver
from django.db.models import Q from django.db.models import Q
@@ -6,29 +8,143 @@ from apps.rides.models import Ride
from .models import Park from .models import Park
def update_park_ride_counts(park): logger = logging.getLogger(__name__)
"""Update ride_count and coaster_count for a park"""
operating_rides = Q(status="OPERATING")
# Count total operating rides
ride_count = park.rides.filter(operating_rides).count()
# Count total operating roller coasters # Status values that count as "active" rides for counting purposes
coaster_count = park.rides.filter(operating_rides, category="RC").count() ACTIVE_STATUSES = {'OPERATING', 'SEASONAL', 'UNDER_CONSTRUCTION'}
# Update park counts # Status values that should decrement ride counts
Park.objects.filter(id=park.id).update( INACTIVE_STATUSES = {'CLOSED_PERM', 'DEMOLISHED', 'RELOCATED', 'REMOVED'}
ride_count=ride_count, coaster_count=coaster_count
)
def update_park_ride_counts(park, old_status=None, new_status=None):
"""
Update ride_count and coaster_count for a park.
Args:
park: The Park instance or park ID to update.
old_status: The previous status of the ride (for FSM transitions).
new_status: The new status of the ride (for FSM transitions).
"""
if park is None:
logger.warning("Cannot update counts: park is None")
return
# Get park ID
park_id = park.pk if hasattr(park, 'pk') else park
try:
# Fetch the park if we only have an ID
if not hasattr(park, 'rides'):
park = Park.objects.get(id=park_id)
# Build the query for active rides
active_statuses = list(ACTIVE_STATUSES)
operating_rides = Q(status__in=active_statuses)
# Count total operating rides
ride_count = park.rides.filter(operating_rides).count()
# Count total operating roller coasters
coaster_count = park.rides.filter(operating_rides, category="RC").count()
# Update park counts
Park.objects.filter(id=park_id).update(
ride_count=ride_count, coaster_count=coaster_count
)
logger.debug(
f"Updated park {park_id} counts: "
f"ride_count={ride_count}, coaster_count={coaster_count}"
)
except Park.DoesNotExist:
logger.warning(f"Park {park_id} does not exist, cannot update counts")
except Exception as e:
logger.exception(f"Failed to update park counts for {park_id}: {e}")
def should_update_counts(old_status, new_status):
"""
Determine if a status change should trigger count updates.
Args:
old_status: The previous status value.
new_status: The new status value.
Returns:
True if counts should be updated, False otherwise.
"""
if old_status == new_status:
return False
# Check if either status is in active or inactive sets
old_active = old_status in ACTIVE_STATUSES if old_status else False
new_active = new_status in ACTIVE_STATUSES if new_status else False
old_inactive = old_status in INACTIVE_STATUSES if old_status else False
new_inactive = new_status in INACTIVE_STATUSES if new_status else False
# Update if transitioning to/from active status
return old_active != new_active or old_inactive != new_inactive
@receiver(post_save, sender=Ride) @receiver(post_save, sender=Ride)
def ride_saved(sender, instance, **kwargs): def ride_saved(sender, instance, created, **kwargs):
"""Update park counts when a ride is saved""" """
update_park_ride_counts(instance.park) Update park counts when a ride is saved.
Integrates with FSM transitions by checking for status changes.
"""
# For new rides, always update counts
if created:
update_park_ride_counts(instance.park)
return
# Check if status changed using model's tracker if available
if hasattr(instance, 'tracker') and hasattr(instance.tracker, 'has_changed'):
if instance.tracker.has_changed('status'):
old_status = instance.tracker.previous('status')
new_status = instance.status
if should_update_counts(old_status, new_status):
logger.info(
f"Ride {instance.pk} status changed: {old_status}{new_status}"
)
update_park_ride_counts(instance.park, old_status, new_status)
else:
# Fallback: always update counts on save
update_park_ride_counts(instance.park)
@receiver(post_delete, sender=Ride) @receiver(post_delete, sender=Ride)
def ride_deleted(sender, instance, **kwargs): def ride_deleted(sender, instance, **kwargs):
"""Update park counts when a ride is deleted""" """
Update park counts when a ride is deleted.
Logs the deletion for audit purposes.
"""
logger.info(f"Ride {instance.pk} deleted from park {instance.park_id}")
update_park_ride_counts(instance.park) update_park_ride_counts(instance.park)
# FSM transition signal handlers
def handle_ride_status_transition(instance, source, target, user, **kwargs):
"""
Handle ride status FSM transitions.
This function is called by the FSM callback system when a ride
status transition occurs.
Args:
instance: The Ride instance.
source: The source state.
target: The target state.
user: The user who initiated the transition.
"""
if should_update_counts(source, target):
logger.info(
f"FSM transition: Ride {instance.pk} {source}{target} "
f"by {user if user else 'system'}"
)
update_park_ride_counts(instance.park, source, target)

View File

@@ -1,13 +1,25 @@
import logging
from django.apps import AppConfig from django.apps import AppConfig
logger = logging.getLogger(__name__)
class RidesConfig(AppConfig): class RidesConfig(AppConfig):
default_auto_field = "django.db.models.BigAutoField" default_auto_field = "django.db.models.BigAutoField"
name = "apps.rides" name = "apps.rides"
def ready(self): def ready(self):
import apps.rides.choices # noqa: F401 - Register choices import apps.rides.choices # noqa: F401 - Register choices
import apps.rides.signals # noqa: F401 - Register signals
import apps.rides.tasks # noqa: F401 - Register Celery tasks import apps.rides.tasks # noqa: F401 - Register Celery tasks
self._apply_state_machines()
self._register_callbacks()
def _apply_state_machines(self):
"""Apply FSM to ride models."""
from apps.core.state_machine import apply_state_machine from apps.core.state_machine import apply_state_machine
from apps.rides.models import Ride from apps.rides.models import Ride
@@ -15,3 +27,58 @@ class RidesConfig(AppConfig):
apply_state_machine( apply_state_machine(
Ride, field_name="status", choice_group="statuses", domain="rides" Ride, field_name="status", choice_group="statuses", domain="rides"
) )
def _register_callbacks(self):
"""Register FSM transition callbacks for ride models."""
from apps.core.state_machine.registry import register_callback
from apps.core.state_machine.callbacks.cache import (
RideCacheInvalidation,
APICacheInvalidation,
)
from apps.core.state_machine.callbacks.related_updates import (
ParkCountUpdateCallback,
SearchTextUpdateCallback,
)
from apps.rides.models import Ride
# Cache invalidation for all ride status changes
register_callback(
Ride, 'status', '*', '*',
RideCacheInvalidation()
)
# API cache invalidation
register_callback(
Ride, 'status', '*', '*',
APICacheInvalidation(include_geo_cache=True)
)
# Park count updates for status changes that affect active rides
register_callback(
Ride, 'status', '*', 'OPERATING',
ParkCountUpdateCallback()
)
register_callback(
Ride, 'status', 'OPERATING', '*',
ParkCountUpdateCallback()
)
register_callback(
Ride, 'status', '*', 'CLOSED_PERM',
ParkCountUpdateCallback()
)
register_callback(
Ride, 'status', '*', 'DEMOLISHED',
ParkCountUpdateCallback()
)
register_callback(
Ride, 'status', '*', 'RELOCATED',
ParkCountUpdateCallback()
)
# Search text update
register_callback(
Ride, 'status', '*', '*',
SearchTextUpdateCallback()
)
logger.debug("Registered ride transition callbacks")

View File

@@ -1,17 +1,188 @@
import logging
from django.db.models.signals import pre_save from django.db.models.signals import pre_save
from django.dispatch import receiver from django.dispatch import receiver
from django.utils import timezone from django.utils import timezone
from .models import Ride from .models import Ride
logger = logging.getLogger(__name__)
@receiver(pre_save, sender=Ride) @receiver(pre_save, sender=Ride)
def handle_ride_status(sender, instance, **kwargs): def handle_ride_status(sender, instance, **kwargs):
"""Handle ride status changes based on closing date""" """
if instance.closing_date: Handle ride status changes based on closing date.
today = timezone.now().date()
# If we've reached the closing date and status is "Closing" Integrates with FSM transitions by using transition methods when available.
if today >= instance.closing_date and instance.status == "CLOSING": """
# Change to the selected post-closing status if not instance.closing_date:
instance.status = instance.post_closing_status or "SBNO" return
today = timezone.now().date()
# If we've reached the closing date and status is "CLOSING"
if today >= instance.closing_date and instance.status == "CLOSING":
target_status = instance.post_closing_status or "SBNO"
logger.info(
f"Ride {instance.pk} closing date reached, "
f"transitioning to {target_status}"
)
# Try to use FSM transition method if available
transition_method_name = f'transition_to_{target_status.lower()}'
if hasattr(instance, transition_method_name):
# Check if transition is allowed before attempting
if hasattr(instance, 'can_proceed'):
can_proceed = getattr(instance, f'can_transition_to_{target_status.lower()}', None)
if can_proceed and callable(can_proceed):
if not can_proceed():
logger.warning(
f"FSM transition to {target_status} not allowed "
f"for ride {instance.pk}"
)
# Fall back to direct status change
instance.status = target_status
instance.status_since = instance.closing_date
return
try:
method = getattr(instance, transition_method_name)
method()
instance.status_since = instance.closing_date
logger.info(
f"Applied FSM transition to {target_status} for ride {instance.pk}"
)
except Exception as e:
logger.exception(
f"Failed to apply FSM transition for ride {instance.pk}: {e}"
)
# Fall back to direct status change
instance.status = target_status
instance.status_since = instance.closing_date
else:
# No FSM transition method, use direct assignment
instance.status = target_status
instance.status_since = instance.closing_date instance.status_since = instance.closing_date
@receiver(pre_save, sender=Ride)
def validate_closing_status(sender, instance, **kwargs):
"""
Validate that post_closing_status is set when entering CLOSING state.
"""
# Only validate if this is an existing ride being updated
if not instance.pk:
return
# Check if we're transitioning to CLOSING
if instance.status == "CLOSING":
# Ensure post_closing_status is set
if not instance.post_closing_status:
logger.warning(
f"Ride {instance.pk} entering CLOSING without post_closing_status set"
)
# Default to SBNO if not set
instance.post_closing_status = "SBNO"
# Ensure closing_date is set
if not instance.closing_date:
logger.warning(
f"Ride {instance.pk} entering CLOSING without closing_date set"
)
# Default to today's date
instance.closing_date = timezone.now().date()
# FSM transition signal handlers
def handle_ride_transition_to_closing(instance, source, target, user, **kwargs):
"""
Validate transition to CLOSING status.
This function is called by the FSM callback system before a ride
transitions to CLOSING status.
Args:
instance: The Ride instance.
source: The source state.
target: The target state.
user: The user who initiated the transition.
Returns:
True if transition should proceed, False to abort.
"""
if target != 'CLOSING':
return True
if not instance.post_closing_status:
logger.error(
f"Cannot transition ride {instance.pk} to CLOSING: "
"post_closing_status not set"
)
return False
if not instance.closing_date:
logger.warning(
f"Ride {instance.pk} transitioning to CLOSING without closing_date"
)
return True
def apply_post_closing_status(instance, user=None):
"""
Apply the post_closing_status to a ride in CLOSING state.
This function can be called by the FSM callback system or directly
when a ride's closing date is reached.
Args:
instance: The Ride instance in CLOSING state.
user: The user initiating the change (optional).
Returns:
True if status was applied, False otherwise.
"""
if instance.status != 'CLOSING':
logger.debug(
f"Ride {instance.pk} not in CLOSING state, skipping"
)
return False
target_status = instance.post_closing_status
if not target_status:
logger.warning(
f"Ride {instance.pk} in CLOSING but no post_closing_status set"
)
return False
# Try to use FSM transition
transition_method_name = f'transition_to_{target_status.lower()}'
if hasattr(instance, transition_method_name):
try:
method = getattr(instance, transition_method_name)
method(user=user)
instance.post_closing_status = None
instance.save(update_fields=['post_closing_status'])
logger.info(
f"Applied post_closing_status {target_status} to ride {instance.pk}"
)
return True
except Exception as e:
logger.exception(
f"Failed to apply post_closing_status for ride {instance.pk}: {e}"
)
return False
else:
# Direct status change
instance.status = target_status
instance.post_closing_status = None
instance.status_since = timezone.now().date()
instance.save(update_fields=['status', 'post_closing_status', 'status_since'])
logger.info(
f"Applied post_closing_status {target_status} to ride {instance.pk} (direct)"
)
return True