mirror of
https://github.com/pacnpal/thrillwiki_django_no_react.git
synced 2025-12-30 01:27:00 -05:00
feat: Implement MFA authentication, add ride statistics model, and update various services, APIs, and tests across the application.
This commit is contained in:
@@ -13,7 +13,6 @@ from io import StringIO
|
||||
from django.contrib import admin, messages
|
||||
from django.core.serializers.json import DjangoJSONEncoder
|
||||
from django.http import HttpResponse
|
||||
from django.utils.html import format_html
|
||||
|
||||
|
||||
class QueryOptimizationMixin:
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
from django.db import models
|
||||
from datetime import timedelta
|
||||
|
||||
import pghistory
|
||||
from django.contrib.contenttypes.fields import GenericForeignKey
|
||||
from django.contrib.contenttypes.models import ContentType
|
||||
from django.utils import timezone
|
||||
from django.db import models
|
||||
from django.db.models import Count
|
||||
from datetime import timedelta
|
||||
import pghistory
|
||||
from django.utils import timezone
|
||||
|
||||
|
||||
@pghistory.track()
|
||||
|
||||
@@ -3,21 +3,27 @@ Custom exception handling for ThrillWiki API.
|
||||
Provides standardized error responses following Django styleguide patterns.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any
|
||||
|
||||
from django.http import Http404
|
||||
from django.core.exceptions import (
|
||||
PermissionDenied,
|
||||
)
|
||||
from django.core.exceptions import (
|
||||
ValidationError as DjangoValidationError,
|
||||
)
|
||||
from django.http import Http404
|
||||
from rest_framework import status
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.views import exception_handler
|
||||
from rest_framework.exceptions import (
|
||||
ValidationError as DRFValidationError,
|
||||
NotFound,
|
||||
)
|
||||
from rest_framework.exceptions import (
|
||||
PermissionDenied as DRFPermissionDenied,
|
||||
)
|
||||
from rest_framework.exceptions import (
|
||||
ValidationError as DRFValidationError,
|
||||
)
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.views import exception_handler
|
||||
|
||||
from ..exceptions import ThrillWikiException
|
||||
from ..logging import get_logger, log_exception
|
||||
@@ -26,8 +32,8 @@ logger = get_logger(__name__)
|
||||
|
||||
|
||||
def custom_exception_handler(
|
||||
exc: Exception, context: Dict[str, Any]
|
||||
) -> Optional[Response]:
|
||||
exc: Exception, context: dict[str, Any]
|
||||
) -> Response | None:
|
||||
"""
|
||||
Custom exception handler for DRF that provides standardized error responses.
|
||||
|
||||
@@ -209,7 +215,7 @@ def _get_error_message(exc: Exception, response_data: Any) -> str:
|
||||
return str(exc) if str(exc) else "An error occurred"
|
||||
|
||||
|
||||
def _get_error_details(exc: Exception, response_data: Any) -> Optional[Dict[str, Any]]:
|
||||
def _get_error_details(exc: Exception, response_data: Any) -> dict[str, Any] | None:
|
||||
"""Extract detailed error information for debugging."""
|
||||
if isinstance(response_data, dict) and len(response_data) > 1:
|
||||
return response_data
|
||||
@@ -224,7 +230,7 @@ def _get_error_details(exc: Exception, response_data: Any) -> Optional[Dict[str,
|
||||
|
||||
def _format_django_validation_errors(
|
||||
exc: DjangoValidationError,
|
||||
) -> Dict[str, Any]:
|
||||
) -> dict[str, Any]:
|
||||
"""Format Django ValidationError for API response."""
|
||||
if hasattr(exc, "error_dict"):
|
||||
# Field-specific errors
|
||||
|
||||
@@ -2,10 +2,11 @@
|
||||
Common mixins for API views following Django styleguide patterns.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Optional, Type
|
||||
from typing import Any
|
||||
|
||||
from rest_framework import status
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.response import Response
|
||||
from rest_framework import status
|
||||
|
||||
# Constants for error messages
|
||||
_MISSING_INPUT_SERIALIZER_MSG = "Subclasses must set input_serializer class attribute"
|
||||
@@ -20,17 +21,17 @@ class ApiMixin:
|
||||
|
||||
# Expose expected attributes so static type checkers know they exist on subclasses.
|
||||
# Subclasses or other bases (e.g. DRF GenericAPIView) will actually provide these.
|
||||
input_serializer: Optional[Type[Any]] = None
|
||||
output_serializer: Optional[Type[Any]] = None
|
||||
input_serializer: type[Any] | None = None
|
||||
output_serializer: type[Any] | None = None
|
||||
|
||||
def create_response(
|
||||
self,
|
||||
*,
|
||||
data: Any = None,
|
||||
message: Optional[str] = None,
|
||||
message: str | None = None,
|
||||
status_code: int = status.HTTP_200_OK,
|
||||
pagination: Optional[Dict[str, Any]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
pagination: dict[str, Any] | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> Response:
|
||||
"""
|
||||
Create standardized API response.
|
||||
@@ -66,8 +67,8 @@ class ApiMixin:
|
||||
*,
|
||||
message: str,
|
||||
status_code: int = status.HTTP_400_BAD_REQUEST,
|
||||
error_code: Optional[str] = None,
|
||||
details: Optional[Dict[str, Any]] = None,
|
||||
error_code: str | None = None,
|
||||
details: dict[str, Any] | None = None,
|
||||
) -> Response:
|
||||
"""
|
||||
Create standardized error response.
|
||||
@@ -82,7 +83,7 @@ class ApiMixin:
|
||||
Standardized error Response object
|
||||
"""
|
||||
# explicitly allow any-shaped values in the error_data dict
|
||||
error_data: Dict[str, Any] = {
|
||||
error_data: dict[str, Any] = {
|
||||
"code": error_code or "GENERIC_ERROR",
|
||||
"message": message,
|
||||
}
|
||||
|
||||
@@ -20,9 +20,9 @@ Security checks included:
|
||||
|
||||
import os
|
||||
import re
|
||||
from django.conf import settings
|
||||
from django.core.checks import Error, Warning, register, Tags
|
||||
|
||||
from django.conf import settings
|
||||
from django.core.checks import Error, Tags, Warning, register
|
||||
|
||||
# =============================================================================
|
||||
# Secret Key Validation
|
||||
|
||||
@@ -12,11 +12,11 @@ Key Components:
|
||||
- RichChoiceSerializer: DRF serializer for API responses
|
||||
"""
|
||||
|
||||
from .base import RichChoice, ChoiceCategory, ChoiceGroup
|
||||
from .registry import ChoiceRegistry, register_choices
|
||||
from .base import ChoiceCategory, ChoiceGroup, RichChoice
|
||||
from .fields import RichChoiceField
|
||||
from .serializers import RichChoiceSerializer, RichChoiceOptionSerializer
|
||||
from .utils import validate_choice_value, get_choice_display
|
||||
from .registry import ChoiceRegistry, register_choices
|
||||
from .serializers import RichChoiceOptionSerializer, RichChoiceSerializer
|
||||
from .utils import get_choice_display, validate_choice_value
|
||||
|
||||
__all__ = [
|
||||
'RichChoice',
|
||||
|
||||
@@ -5,8 +5,8 @@ This module defines the core dataclass structures for rich choice objects.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, Any, Optional
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
|
||||
class ChoiceCategory(Enum):
|
||||
@@ -30,10 +30,10 @@ class ChoiceCategory(Enum):
|
||||
class RichChoice:
|
||||
"""
|
||||
Rich choice object with metadata support.
|
||||
|
||||
|
||||
This replaces simple tuple choices with a comprehensive object that can
|
||||
carry additional information like descriptions, colors, icons, and custom metadata.
|
||||
|
||||
|
||||
Attributes:
|
||||
value: The stored value (equivalent to first element of tuple choice)
|
||||
label: Human-readable display name (equivalent to second element of tuple choice)
|
||||
@@ -45,39 +45,39 @@ class RichChoice:
|
||||
value: str
|
||||
label: str
|
||||
description: str = ""
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
deprecated: bool = False
|
||||
category: ChoiceCategory = ChoiceCategory.OTHER
|
||||
|
||||
|
||||
def __post_init__(self):
|
||||
"""Validate the choice object after initialization"""
|
||||
if not self.value:
|
||||
raise ValueError("Choice value cannot be empty")
|
||||
if not self.label:
|
||||
raise ValueError("Choice label cannot be empty")
|
||||
|
||||
|
||||
@property
|
||||
def color(self) -> Optional[str]:
|
||||
def color(self) -> str | None:
|
||||
"""Get the color from metadata if available"""
|
||||
return self.metadata.get('color')
|
||||
|
||||
|
||||
@property
|
||||
def icon(self) -> Optional[str]:
|
||||
def icon(self) -> str | None:
|
||||
"""Get the icon from metadata if available"""
|
||||
return self.metadata.get('icon')
|
||||
|
||||
|
||||
@property
|
||||
def css_class(self) -> Optional[str]:
|
||||
def css_class(self) -> str | None:
|
||||
"""Get the CSS class from metadata if available"""
|
||||
return self.metadata.get('css_class')
|
||||
|
||||
|
||||
@property
|
||||
def sort_order(self) -> int:
|
||||
"""Get the sort order from metadata, defaulting to 0"""
|
||||
return self.metadata.get('sort_order', 0)
|
||||
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary representation for API serialization"""
|
||||
return {
|
||||
'value': self.value,
|
||||
@@ -91,11 +91,11 @@ class RichChoice:
|
||||
'css_class': self.css_class,
|
||||
'sort_order': self.sort_order,
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.label
|
||||
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"RichChoice(value='{self.value}', label='{self.label}')"
|
||||
|
||||
@@ -104,47 +104,47 @@ class RichChoice:
|
||||
class ChoiceGroup:
|
||||
"""
|
||||
A group of related choices with shared metadata.
|
||||
|
||||
|
||||
This allows for organizing choices into logical groups with
|
||||
common properties and behaviors.
|
||||
"""
|
||||
name: str
|
||||
choices: list[RichChoice]
|
||||
description: str = ""
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self):
|
||||
"""Validate the choice group after initialization"""
|
||||
if not self.name:
|
||||
raise ValueError("Choice group name cannot be empty")
|
||||
if not self.choices:
|
||||
raise ValueError("Choice group must contain at least one choice")
|
||||
|
||||
|
||||
# Validate that all choice values are unique within the group
|
||||
values = [choice.value for choice in self.choices]
|
||||
if len(values) != len(set(values)):
|
||||
raise ValueError("All choice values within a group must be unique")
|
||||
|
||||
def get_choice(self, value: str) -> Optional[RichChoice]:
|
||||
|
||||
def get_choice(self, value: str) -> RichChoice | None:
|
||||
"""Get a choice by its value"""
|
||||
for choice in self.choices:
|
||||
if choice.value == value:
|
||||
return choice
|
||||
return None
|
||||
|
||||
|
||||
def get_choices_by_category(self, category: ChoiceCategory) -> list[RichChoice]:
|
||||
"""Get all choices in a specific category"""
|
||||
return [choice for choice in self.choices if choice.category == category]
|
||||
|
||||
|
||||
def get_active_choices(self) -> list[RichChoice]:
|
||||
"""Get all non-deprecated choices"""
|
||||
return [choice for choice in self.choices if not choice.deprecated]
|
||||
|
||||
|
||||
def to_tuple_choices(self) -> list[tuple[str, str]]:
|
||||
"""Convert to legacy tuple choices format"""
|
||||
return [(choice.value, choice.label) for choice in self.choices]
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary representation for API serialization"""
|
||||
return {
|
||||
'name': self.name,
|
||||
|
||||
@@ -5,10 +5,9 @@ This module defines all choice objects for core system functionality,
|
||||
including health checks, API statuses, and other system-level choices.
|
||||
"""
|
||||
|
||||
from .base import RichChoice, ChoiceCategory
|
||||
from .base import ChoiceCategory, RichChoice
|
||||
from .registry import register_choices
|
||||
|
||||
|
||||
# Health Check Status Choices
|
||||
HEALTH_STATUSES = [
|
||||
RichChoice(
|
||||
@@ -128,7 +127,7 @@ ENTITY_TYPES = [
|
||||
|
||||
def register_core_choices():
|
||||
"""Register all core system choices with the global registry"""
|
||||
|
||||
|
||||
register_choices(
|
||||
name="health_statuses",
|
||||
choices=HEALTH_STATUSES,
|
||||
@@ -136,7 +135,7 @@ def register_core_choices():
|
||||
description="Health check status options",
|
||||
metadata={'domain': 'core', 'type': 'health_status'}
|
||||
)
|
||||
|
||||
|
||||
register_choices(
|
||||
name="simple_health_statuses",
|
||||
choices=SIMPLE_HEALTH_STATUSES,
|
||||
@@ -144,7 +143,7 @@ def register_core_choices():
|
||||
description="Simple health check status options",
|
||||
metadata={'domain': 'core', 'type': 'simple_health_status'}
|
||||
)
|
||||
|
||||
|
||||
register_choices(
|
||||
name="entity_types",
|
||||
choices=ENTITY_TYPES,
|
||||
|
||||
@@ -4,10 +4,12 @@ Django Model Fields for Rich Choices
|
||||
This module provides Django model field implementations for rich choice objects.
|
||||
"""
|
||||
|
||||
from typing import Any, Optional
|
||||
from django.db import models
|
||||
from typing import Any
|
||||
|
||||
from django.core.exceptions import ValidationError
|
||||
from django.db import models
|
||||
from django.forms import ChoiceField
|
||||
|
||||
from .base import RichChoice
|
||||
from .registry import registry
|
||||
|
||||
@@ -15,11 +17,11 @@ from .registry import registry
|
||||
class RichChoiceField(models.CharField):
|
||||
"""
|
||||
Django model field for rich choice objects.
|
||||
|
||||
|
||||
This field stores the choice value as a CharField but provides
|
||||
rich choice functionality through the registry system.
|
||||
"""
|
||||
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
choice_group: str,
|
||||
@@ -30,7 +32,7 @@ class RichChoiceField(models.CharField):
|
||||
):
|
||||
"""
|
||||
Initialize the RichChoiceField.
|
||||
|
||||
|
||||
Args:
|
||||
choice_group: Name of the choice group in the registry
|
||||
domain: Domain namespace for the choice group
|
||||
@@ -41,66 +43,66 @@ class RichChoiceField(models.CharField):
|
||||
self.choice_group = choice_group
|
||||
self.domain = domain
|
||||
self.allow_deprecated = allow_deprecated
|
||||
|
||||
|
||||
# Set choices from registry for Django admin and forms
|
||||
if self.allow_deprecated:
|
||||
choices_list = registry.get_choices(choice_group, domain)
|
||||
else:
|
||||
choices_list = registry.get_active_choices(choice_group, domain)
|
||||
|
||||
|
||||
choices = [(choice.value, choice.label) for choice in choices_list]
|
||||
|
||||
|
||||
kwargs['choices'] = choices
|
||||
kwargs['max_length'] = max_length
|
||||
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
||||
def validate(self, value: Any, model_instance: Any) -> None:
|
||||
"""Validate the choice value"""
|
||||
super().validate(value, model_instance)
|
||||
|
||||
|
||||
if value is None or value == '':
|
||||
return
|
||||
|
||||
|
||||
# Check if choice exists in registry
|
||||
choice = registry.get_choice(self.choice_group, value, self.domain)
|
||||
if choice is None:
|
||||
raise ValidationError(
|
||||
f"'{value}' is not a valid choice for {self.choice_group}"
|
||||
)
|
||||
|
||||
|
||||
# Check if deprecated choices are allowed
|
||||
if choice.deprecated and not self.allow_deprecated:
|
||||
raise ValidationError(
|
||||
f"'{value}' is deprecated and cannot be used for new entries"
|
||||
)
|
||||
|
||||
def get_rich_choice(self, value: str) -> Optional[RichChoice]:
|
||||
|
||||
def get_rich_choice(self, value: str) -> RichChoice | None:
|
||||
"""Get the RichChoice object for a value"""
|
||||
return registry.get_choice(self.choice_group, value, self.domain)
|
||||
|
||||
|
||||
def get_choice_display(self, value: str) -> str:
|
||||
"""Get the display label for a choice value"""
|
||||
return registry.get_choice_display(self.choice_group, value, self.domain)
|
||||
|
||||
|
||||
def contribute_to_class(self, cls: Any, name: str, private_only: bool = False, **kwargs: Any) -> None:
|
||||
"""Add helper methods to the model class (signature compatible with Django Field)"""
|
||||
super().contribute_to_class(cls, name, private_only=private_only, **kwargs)
|
||||
|
||||
|
||||
# Add get_FOO_rich_choice method
|
||||
def get_rich_choice_method(instance):
|
||||
value = getattr(instance, name)
|
||||
return self.get_rich_choice(value) if value else None
|
||||
|
||||
|
||||
setattr(cls, f'get_{name}_rich_choice', get_rich_choice_method)
|
||||
|
||||
|
||||
# Add get_FOO_display method (Django provides this, but we enhance it)
|
||||
def get_display_method(instance):
|
||||
value = getattr(instance, name)
|
||||
return self.get_choice_display(value) if value else ''
|
||||
|
||||
|
||||
setattr(cls, f'get_{name}_display', get_display_method)
|
||||
|
||||
|
||||
def deconstruct(self):
|
||||
"""Support for Django migrations"""
|
||||
name, path, args, kwargs = super().deconstruct()
|
||||
@@ -114,7 +116,7 @@ class RichChoiceFormField(ChoiceField):
|
||||
"""
|
||||
Form field for rich choices with enhanced functionality.
|
||||
"""
|
||||
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
choice_group: str,
|
||||
@@ -125,7 +127,7 @@ class RichChoiceFormField(ChoiceField):
|
||||
):
|
||||
"""
|
||||
Initialize the form field.
|
||||
|
||||
|
||||
Args:
|
||||
choice_group: Name of the choice group in the registry
|
||||
domain: Domain namespace for the choice group
|
||||
@@ -137,13 +139,13 @@ class RichChoiceFormField(ChoiceField):
|
||||
self.domain = domain
|
||||
self.allow_deprecated = allow_deprecated
|
||||
self.show_descriptions = show_descriptions
|
||||
|
||||
|
||||
# Get choices from registry
|
||||
if allow_deprecated:
|
||||
choices_list = registry.get_choices(choice_group, domain)
|
||||
else:
|
||||
choices_list = registry.get_active_choices(choice_group, domain)
|
||||
|
||||
|
||||
# Format choices for display
|
||||
choices = []
|
||||
for choice in choices_list:
|
||||
@@ -151,24 +153,24 @@ class RichChoiceFormField(ChoiceField):
|
||||
if show_descriptions and choice.description:
|
||||
label = f"{choice.label} - {choice.description}"
|
||||
choices.append((choice.value, label))
|
||||
|
||||
|
||||
kwargs['choices'] = choices
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
||||
def validate(self, value: Any) -> None:
|
||||
"""Validate the choice value"""
|
||||
super().validate(value)
|
||||
|
||||
|
||||
if value is None or value == '':
|
||||
return
|
||||
|
||||
|
||||
# Check if choice exists in registry
|
||||
choice = registry.get_choice(self.choice_group, value, self.domain)
|
||||
if choice is None:
|
||||
raise ValidationError(
|
||||
f"'{value}' is not a valid choice for {self.choice_group}"
|
||||
)
|
||||
|
||||
|
||||
# Check if deprecated choices are allowed
|
||||
if choice.deprecated and not self.allow_deprecated:
|
||||
raise ValidationError(
|
||||
@@ -185,7 +187,7 @@ def create_rich_choice_field(
|
||||
) -> RichChoiceField:
|
||||
"""
|
||||
Factory function to create a RichChoiceField.
|
||||
|
||||
|
||||
This is useful for creating fields with consistent settings
|
||||
across multiple models.
|
||||
"""
|
||||
|
||||
@@ -4,55 +4,57 @@ Choice Registry
|
||||
Centralized registry for managing all choice definitions across the application.
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Optional, Any
|
||||
from typing import Any
|
||||
|
||||
from django.core.exceptions import ImproperlyConfigured
|
||||
from .base import RichChoice, ChoiceGroup
|
||||
|
||||
from .base import ChoiceGroup, RichChoice
|
||||
|
||||
|
||||
class ChoiceRegistry:
|
||||
"""
|
||||
Centralized registry for managing all choice definitions.
|
||||
|
||||
|
||||
This provides a single source of truth for all choice objects
|
||||
throughout the application, with support for namespacing by domain.
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
self._choices: Dict[str, ChoiceGroup] = {}
|
||||
self._domains: Dict[str, List[str]] = {}
|
||||
|
||||
self._choices: dict[str, ChoiceGroup] = {}
|
||||
self._domains: dict[str, list[str]] = {}
|
||||
|
||||
def register(
|
||||
self,
|
||||
name: str,
|
||||
choices: List[RichChoice],
|
||||
self,
|
||||
name: str,
|
||||
choices: list[RichChoice],
|
||||
domain: str = "core",
|
||||
description: str = "",
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
metadata: dict[str, Any] | None = None
|
||||
) -> ChoiceGroup:
|
||||
"""
|
||||
Register a group of choices.
|
||||
|
||||
|
||||
Args:
|
||||
name: Unique name for the choice group
|
||||
choices: List of RichChoice objects
|
||||
domain: Domain namespace (e.g., 'rides', 'parks', 'accounts')
|
||||
description: Description of the choice group
|
||||
metadata: Additional metadata for the group
|
||||
|
||||
|
||||
Returns:
|
||||
The registered ChoiceGroup
|
||||
|
||||
|
||||
Raises:
|
||||
ImproperlyConfigured: If name is already registered with different choices
|
||||
"""
|
||||
full_name = f"{domain}.{name}"
|
||||
|
||||
|
||||
if full_name in self._choices:
|
||||
# Check if the existing registration is identical
|
||||
existing_group = self._choices[full_name]
|
||||
existing_values = [choice.value for choice in existing_group.choices]
|
||||
new_values = [choice.value for choice in choices]
|
||||
|
||||
|
||||
if existing_values == new_values:
|
||||
# Same choices, return existing group (allow duplicate registration)
|
||||
return existing_group
|
||||
@@ -62,69 +64,69 @@ class ChoiceRegistry:
|
||||
f"Choice group '{full_name}' is already registered with different choices. "
|
||||
f"Existing: {existing_values}, New: {new_values}"
|
||||
)
|
||||
|
||||
|
||||
choice_group = ChoiceGroup(
|
||||
name=full_name,
|
||||
choices=choices,
|
||||
description=description,
|
||||
metadata=metadata or {}
|
||||
)
|
||||
|
||||
|
||||
self._choices[full_name] = choice_group
|
||||
|
||||
|
||||
# Track domain
|
||||
if domain not in self._domains:
|
||||
self._domains[domain] = []
|
||||
self._domains[domain].append(name)
|
||||
|
||||
|
||||
return choice_group
|
||||
|
||||
def get(self, name: str, domain: str = "core") -> Optional[ChoiceGroup]:
|
||||
|
||||
def get(self, name: str, domain: str = "core") -> ChoiceGroup | None:
|
||||
"""Get a choice group by name and domain"""
|
||||
full_name = f"{domain}.{name}"
|
||||
return self._choices.get(full_name)
|
||||
|
||||
def get_choice(self, group_name: str, value: str, domain: str = "core") -> Optional[RichChoice]:
|
||||
|
||||
def get_choice(self, group_name: str, value: str, domain: str = "core") -> RichChoice | None:
|
||||
"""Get a specific choice by group name, value, and domain"""
|
||||
choice_group = self.get(group_name, domain)
|
||||
if choice_group:
|
||||
return choice_group.get_choice(value)
|
||||
return None
|
||||
|
||||
def get_choices(self, name: str, domain: str = "core") -> List[RichChoice]:
|
||||
|
||||
def get_choices(self, name: str, domain: str = "core") -> list[RichChoice]:
|
||||
"""Get all choices in a group"""
|
||||
choice_group = self.get(name, domain)
|
||||
return choice_group.choices if choice_group else []
|
||||
|
||||
def get_active_choices(self, name: str, domain: str = "core") -> List[RichChoice]:
|
||||
|
||||
def get_active_choices(self, name: str, domain: str = "core") -> list[RichChoice]:
|
||||
"""Get all non-deprecated choices in a group"""
|
||||
choice_group = self.get(name, domain)
|
||||
return choice_group.get_active_choices() if choice_group else []
|
||||
|
||||
|
||||
def get_domains(self) -> List[str]:
|
||||
|
||||
|
||||
def get_domains(self) -> list[str]:
|
||||
"""Get all registered domains"""
|
||||
return list(self._domains.keys())
|
||||
|
||||
def get_domain_choices(self, domain: str) -> Dict[str, ChoiceGroup]:
|
||||
|
||||
def get_domain_choices(self, domain: str) -> dict[str, ChoiceGroup]:
|
||||
"""Get all choice groups for a specific domain"""
|
||||
if domain not in self._domains:
|
||||
return {}
|
||||
|
||||
|
||||
return {
|
||||
name: self._choices[f"{domain}.{name}"]
|
||||
for name in self._domains[domain]
|
||||
}
|
||||
|
||||
def list_all(self) -> Dict[str, ChoiceGroup]:
|
||||
|
||||
def list_all(self) -> dict[str, ChoiceGroup]:
|
||||
"""Get all registered choice groups"""
|
||||
return self._choices.copy()
|
||||
|
||||
|
||||
def validate_choice(self, group_name: str, value: str, domain: str = "core") -> bool:
|
||||
"""Validate that a choice value exists in a group"""
|
||||
choice = self.get_choice(group_name, value, domain)
|
||||
return choice is not None and not choice.deprecated
|
||||
|
||||
|
||||
def get_choice_display(self, group_name: str, value: str, domain: str = "core") -> str:
|
||||
"""Get the display label for a choice value"""
|
||||
choice = self.get_choice(group_name, value, domain)
|
||||
@@ -132,7 +134,7 @@ class ChoiceRegistry:
|
||||
return choice.label
|
||||
else:
|
||||
raise ValueError(f"Choice value '{value}' not found in group '{group_name}' for domain '{domain}'")
|
||||
|
||||
|
||||
def clear_domain(self, domain: str) -> None:
|
||||
"""Clear all choices for a specific domain (useful for testing)"""
|
||||
if domain in self._domains:
|
||||
@@ -141,7 +143,7 @@ class ChoiceRegistry:
|
||||
if full_name in self._choices:
|
||||
del self._choices[full_name]
|
||||
del self._domains[domain]
|
||||
|
||||
|
||||
def clear_all(self) -> None:
|
||||
"""Clear all registered choices (useful for testing)"""
|
||||
self._choices.clear()
|
||||
@@ -154,33 +156,33 @@ registry = ChoiceRegistry()
|
||||
|
||||
def register_choices(
|
||||
name: str,
|
||||
choices: List[RichChoice],
|
||||
choices: list[RichChoice],
|
||||
domain: str = "core",
|
||||
description: str = "",
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
metadata: dict[str, Any] | None = None
|
||||
) -> ChoiceGroup:
|
||||
"""
|
||||
Convenience function to register choices with the global registry.
|
||||
|
||||
|
||||
Args:
|
||||
name: Unique name for the choice group
|
||||
choices: List of RichChoice objects
|
||||
domain: Domain namespace
|
||||
description: Description of the choice group
|
||||
metadata: Additional metadata for the group
|
||||
|
||||
|
||||
Returns:
|
||||
The registered ChoiceGroup
|
||||
"""
|
||||
return registry.register(name, choices, domain, description, metadata)
|
||||
|
||||
|
||||
def get_choices(name: str, domain: str = "core") -> List[RichChoice]:
|
||||
def get_choices(name: str, domain: str = "core") -> list[RichChoice]:
|
||||
"""Get choices from the global registry"""
|
||||
return registry.get_choices(name, domain)
|
||||
|
||||
|
||||
def get_choice(group_name: str, value: str, domain: str = "core") -> Optional[RichChoice]:
|
||||
def get_choice(group_name: str, value: str, domain: str = "core") -> RichChoice | None:
|
||||
"""Get a specific choice from the global registry"""
|
||||
return registry.get_choice(group_name, value, domain)
|
||||
|
||||
|
||||
@@ -5,16 +5,18 @@ This module provides Django REST Framework serializer implementations
|
||||
for rich choice objects.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any
|
||||
|
||||
from rest_framework import serializers
|
||||
from .base import RichChoice, ChoiceGroup
|
||||
|
||||
from .base import ChoiceGroup, RichChoice
|
||||
from .registry import registry
|
||||
|
||||
|
||||
class RichChoiceSerializer(serializers.Serializer):
|
||||
"""
|
||||
Serializer for individual RichChoice objects.
|
||||
|
||||
|
||||
This provides a consistent API representation for choice objects
|
||||
with all their metadata.
|
||||
"""
|
||||
@@ -28,8 +30,8 @@ class RichChoiceSerializer(serializers.Serializer):
|
||||
icon = serializers.CharField(allow_null=True)
|
||||
css_class = serializers.CharField(allow_null=True)
|
||||
sort_order = serializers.IntegerField()
|
||||
|
||||
def to_representation(self, instance: RichChoice) -> Dict[str, Any]:
|
||||
|
||||
def to_representation(self, instance: RichChoice) -> dict[str, Any]:
|
||||
"""Convert RichChoice to dictionary representation"""
|
||||
return instance.to_dict()
|
||||
|
||||
@@ -37,7 +39,7 @@ class RichChoiceSerializer(serializers.Serializer):
|
||||
class RichChoiceOptionSerializer(serializers.Serializer):
|
||||
"""
|
||||
Serializer for choice options in filter endpoints.
|
||||
|
||||
|
||||
This replaces the legacy FilterOptionSerializer with rich choice support.
|
||||
"""
|
||||
value = serializers.CharField()
|
||||
@@ -50,8 +52,8 @@ class RichChoiceOptionSerializer(serializers.Serializer):
|
||||
icon = serializers.CharField(allow_null=True, required=False)
|
||||
css_class = serializers.CharField(allow_null=True, required=False)
|
||||
metadata = serializers.DictField(required=False)
|
||||
|
||||
def to_representation(self, instance) -> Dict[str, Any]:
|
||||
|
||||
def to_representation(self, instance) -> dict[str, Any]:
|
||||
"""Convert choice option to dictionary representation"""
|
||||
if isinstance(instance, RichChoice):
|
||||
# Convert RichChoice to option format
|
||||
@@ -88,7 +90,7 @@ class RichChoiceOptionSerializer(serializers.Serializer):
|
||||
class ChoiceGroupSerializer(serializers.Serializer):
|
||||
"""
|
||||
Serializer for ChoiceGroup objects.
|
||||
|
||||
|
||||
This provides API representation for entire choice groups
|
||||
with all their choices and metadata.
|
||||
"""
|
||||
@@ -96,8 +98,8 @@ class ChoiceGroupSerializer(serializers.Serializer):
|
||||
description = serializers.CharField()
|
||||
metadata = serializers.DictField()
|
||||
choices = RichChoiceSerializer(many=True)
|
||||
|
||||
def to_representation(self, instance: ChoiceGroup) -> Dict[str, Any]:
|
||||
|
||||
def to_representation(self, instance: ChoiceGroup) -> dict[str, Any]:
|
||||
"""Convert ChoiceGroup to dictionary representation"""
|
||||
return instance.to_dict()
|
||||
|
||||
@@ -105,11 +107,11 @@ class ChoiceGroupSerializer(serializers.Serializer):
|
||||
class RichChoiceFieldSerializer(serializers.CharField):
|
||||
"""
|
||||
Serializer field for rich choice values.
|
||||
|
||||
|
||||
This field serializes the choice value but can optionally
|
||||
include rich choice metadata in the response.
|
||||
"""
|
||||
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
choice_group: str,
|
||||
@@ -119,7 +121,7 @@ class RichChoiceFieldSerializer(serializers.CharField):
|
||||
):
|
||||
"""
|
||||
Initialize the serializer field.
|
||||
|
||||
|
||||
Args:
|
||||
choice_group: Name of the choice group in the registry
|
||||
domain: Domain namespace for the choice group
|
||||
@@ -130,12 +132,12 @@ class RichChoiceFieldSerializer(serializers.CharField):
|
||||
self.domain = domain
|
||||
self.include_metadata = include_metadata
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
||||
def to_representation(self, value: str) -> Any:
|
||||
"""Convert choice value to representation"""
|
||||
if not value:
|
||||
return value
|
||||
|
||||
|
||||
if self.include_metadata:
|
||||
# Return rich choice object
|
||||
choice = registry.get_choice(self.choice_group, value, self.domain)
|
||||
@@ -158,7 +160,7 @@ class RichChoiceFieldSerializer(serializers.CharField):
|
||||
else:
|
||||
# Return just the value
|
||||
return value
|
||||
|
||||
|
||||
def to_internal_value(self, data: Any) -> str:
|
||||
"""Convert input data to choice value"""
|
||||
if isinstance(data, dict) and 'value' in data:
|
||||
@@ -175,26 +177,26 @@ def create_choice_options_serializer(
|
||||
include_counts: bool = False,
|
||||
queryset=None,
|
||||
count_field: str = 'id'
|
||||
) -> List[Dict[str, Any]]:
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Create choice options for filter endpoints.
|
||||
|
||||
|
||||
This function generates choice options with optional counts
|
||||
for use in filter metadata endpoints.
|
||||
|
||||
|
||||
Args:
|
||||
choice_group: Name of the choice group in the registry
|
||||
domain: Domain namespace for the choice group
|
||||
include_counts: Whether to include counts for each option
|
||||
queryset: QuerySet to count against (required if include_counts=True)
|
||||
count_field: Field to filter on for counting (default: 'id')
|
||||
|
||||
|
||||
Returns:
|
||||
List of choice option dictionaries
|
||||
"""
|
||||
choices = registry.get_active_choices(choice_group, domain)
|
||||
options = []
|
||||
|
||||
|
||||
for choice in choices:
|
||||
option_data = {
|
||||
'value': choice.value,
|
||||
@@ -207,7 +209,7 @@ def create_choice_options_serializer(
|
||||
'css_class': choice.css_class,
|
||||
'metadata': choice.metadata,
|
||||
}
|
||||
|
||||
|
||||
if include_counts and queryset is not None:
|
||||
# Count items for this choice
|
||||
try:
|
||||
@@ -218,9 +220,9 @@ def create_choice_options_serializer(
|
||||
option_data['count'] = None
|
||||
else:
|
||||
option_data['count'] = None
|
||||
|
||||
|
||||
options.append(option_data)
|
||||
|
||||
|
||||
# Sort by sort_order, then by label
|
||||
options.sort(key=lambda x: (
|
||||
(lambda c: c.sort_order if (c is not None and hasattr(c, 'sort_order')) else 0)(
|
||||
@@ -228,7 +230,7 @@ def create_choice_options_serializer(
|
||||
),
|
||||
x['label']
|
||||
))
|
||||
|
||||
|
||||
return options
|
||||
|
||||
|
||||
@@ -240,19 +242,19 @@ def serialize_choice_value(
|
||||
) -> Any:
|
||||
"""
|
||||
Serialize a single choice value.
|
||||
|
||||
|
||||
Args:
|
||||
value: The choice value to serialize
|
||||
choice_group: Name of the choice group in the registry
|
||||
domain: Domain namespace for the choice group
|
||||
include_metadata: Whether to include rich choice metadata
|
||||
|
||||
|
||||
Returns:
|
||||
Serialized choice value (string or rich object)
|
||||
"""
|
||||
if not value:
|
||||
return value
|
||||
|
||||
|
||||
if include_metadata:
|
||||
choice = registry.get_choice(choice_group, value, domain)
|
||||
if choice:
|
||||
|
||||
@@ -4,8 +4,9 @@ Utility Functions for Rich Choices
|
||||
This module provides utility functions for working with rich choice objects.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from .base import RichChoice, ChoiceCategory
|
||||
from typing import Any
|
||||
|
||||
from .base import ChoiceCategory, RichChoice
|
||||
from .registry import registry
|
||||
|
||||
|
||||
@@ -17,27 +18,24 @@ def validate_choice_value(
|
||||
) -> bool:
|
||||
"""
|
||||
Validate that a choice value is valid for a given choice group.
|
||||
|
||||
|
||||
Args:
|
||||
value: The choice value to validate
|
||||
choice_group: Name of the choice group in the registry
|
||||
domain: Domain namespace for the choice group
|
||||
allow_deprecated: Whether to allow deprecated choices
|
||||
|
||||
|
||||
Returns:
|
||||
True if valid, False otherwise
|
||||
"""
|
||||
if not value:
|
||||
return True # Allow empty values (handled by field's null/blank settings)
|
||||
|
||||
|
||||
choice = registry.get_choice(choice_group, value, domain)
|
||||
if choice is None:
|
||||
return False
|
||||
|
||||
if choice.deprecated and not allow_deprecated:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
return not (choice.deprecated and not allow_deprecated)
|
||||
|
||||
|
||||
def get_choice_display(
|
||||
@@ -47,21 +45,21 @@ def get_choice_display(
|
||||
) -> str:
|
||||
"""
|
||||
Get the display label for a choice value.
|
||||
|
||||
|
||||
Args:
|
||||
value: The choice value
|
||||
choice_group: Name of the choice group in the registry
|
||||
domain: Domain namespace for the choice group
|
||||
|
||||
|
||||
Returns:
|
||||
Display label for the choice
|
||||
|
||||
|
||||
Raises:
|
||||
ValueError: If the choice value is not found in the registry
|
||||
"""
|
||||
if not value:
|
||||
return ""
|
||||
|
||||
|
||||
choice = registry.get_choice(choice_group, value, domain)
|
||||
if choice:
|
||||
return choice.label
|
||||
@@ -72,24 +70,24 @@ def get_choice_display(
|
||||
|
||||
|
||||
def create_status_choices(
|
||||
statuses: Dict[str, Dict[str, Any]],
|
||||
statuses: dict[str, dict[str, Any]],
|
||||
category: ChoiceCategory = ChoiceCategory.STATUS
|
||||
) -> List[RichChoice]:
|
||||
) -> list[RichChoice]:
|
||||
"""
|
||||
Create status choices with consistent color coding.
|
||||
|
||||
|
||||
Args:
|
||||
statuses: Dictionary mapping status value to config dict
|
||||
category: Choice category (defaults to STATUS)
|
||||
|
||||
|
||||
Returns:
|
||||
List of RichChoice objects for statuses
|
||||
"""
|
||||
choices = []
|
||||
|
||||
|
||||
for value, config in statuses.items():
|
||||
metadata = config.get('metadata', {})
|
||||
|
||||
|
||||
# Add default status colors if not specified
|
||||
if 'color' not in metadata:
|
||||
if 'operating' in value.lower() or 'active' in value.lower():
|
||||
@@ -102,7 +100,7 @@ def create_status_choices(
|
||||
metadata['color'] = 'blue'
|
||||
else:
|
||||
metadata['color'] = 'gray'
|
||||
|
||||
|
||||
choice = RichChoice(
|
||||
value=value,
|
||||
label=config['label'],
|
||||
@@ -112,26 +110,26 @@ def create_status_choices(
|
||||
category=category
|
||||
)
|
||||
choices.append(choice)
|
||||
|
||||
|
||||
return choices
|
||||
|
||||
|
||||
def create_type_choices(
|
||||
types: Dict[str, Dict[str, Any]],
|
||||
types: dict[str, dict[str, Any]],
|
||||
category: ChoiceCategory = ChoiceCategory.TYPE
|
||||
) -> List[RichChoice]:
|
||||
) -> list[RichChoice]:
|
||||
"""
|
||||
Create type/classification choices.
|
||||
|
||||
|
||||
Args:
|
||||
types: Dictionary mapping type value to config dict
|
||||
category: Choice category (defaults to TYPE)
|
||||
|
||||
|
||||
Returns:
|
||||
List of RichChoice objects for types
|
||||
"""
|
||||
choices = []
|
||||
|
||||
|
||||
for value, config in types.items():
|
||||
choice = RichChoice(
|
||||
value=value,
|
||||
@@ -142,21 +140,21 @@ def create_type_choices(
|
||||
category=category
|
||||
)
|
||||
choices.append(choice)
|
||||
|
||||
|
||||
return choices
|
||||
|
||||
|
||||
def merge_choice_metadata(
|
||||
base_metadata: Dict[str, Any],
|
||||
override_metadata: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
base_metadata: dict[str, Any],
|
||||
override_metadata: dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Merge choice metadata dictionaries.
|
||||
|
||||
|
||||
Args:
|
||||
base_metadata: Base metadata dictionary
|
||||
override_metadata: Override metadata dictionary
|
||||
|
||||
|
||||
Returns:
|
||||
Merged metadata dictionary
|
||||
"""
|
||||
@@ -166,16 +164,16 @@ def merge_choice_metadata(
|
||||
|
||||
|
||||
def filter_choices_by_category(
|
||||
choices: List[RichChoice],
|
||||
choices: list[RichChoice],
|
||||
category: ChoiceCategory
|
||||
) -> List[RichChoice]:
|
||||
) -> list[RichChoice]:
|
||||
"""
|
||||
Filter choices by category.
|
||||
|
||||
|
||||
Args:
|
||||
choices: List of RichChoice objects
|
||||
category: Category to filter by
|
||||
|
||||
|
||||
Returns:
|
||||
Filtered list of choices
|
||||
"""
|
||||
@@ -183,16 +181,16 @@ def filter_choices_by_category(
|
||||
|
||||
|
||||
def sort_choices(
|
||||
choices: List[RichChoice],
|
||||
choices: list[RichChoice],
|
||||
sort_by: str = "sort_order"
|
||||
) -> List[RichChoice]:
|
||||
) -> list[RichChoice]:
|
||||
"""
|
||||
Sort choices by specified criteria.
|
||||
|
||||
|
||||
Args:
|
||||
choices: List of RichChoice objects
|
||||
sort_by: Sort criteria ("sort_order", "label", "value")
|
||||
|
||||
|
||||
Returns:
|
||||
Sorted list of choices
|
||||
"""
|
||||
@@ -209,14 +207,14 @@ def sort_choices(
|
||||
def get_choice_colors(
|
||||
choice_group: str,
|
||||
domain: str = "core"
|
||||
) -> Dict[str, str]:
|
||||
) -> dict[str, str]:
|
||||
"""
|
||||
Get a mapping of choice values to their colors.
|
||||
|
||||
|
||||
Args:
|
||||
choice_group: Name of the choice group in the registry
|
||||
domain: Domain namespace for the choice group
|
||||
|
||||
|
||||
Returns:
|
||||
Dictionary mapping choice values to colors
|
||||
"""
|
||||
@@ -230,35 +228,35 @@ def get_choice_colors(
|
||||
|
||||
def validate_choice_group_data(
|
||||
name: str,
|
||||
choices: List[RichChoice],
|
||||
choices: list[RichChoice],
|
||||
domain: str = "core"
|
||||
) -> List[str]:
|
||||
) -> list[str]:
|
||||
"""
|
||||
Validate choice group data and return list of errors.
|
||||
|
||||
|
||||
Args:
|
||||
name: Choice group name
|
||||
choices: List of RichChoice objects
|
||||
domain: Domain namespace
|
||||
|
||||
|
||||
Returns:
|
||||
List of validation error messages
|
||||
"""
|
||||
errors = []
|
||||
|
||||
|
||||
if not name:
|
||||
errors.append("Choice group name cannot be empty")
|
||||
|
||||
|
||||
if not choices:
|
||||
errors.append("Choice group must contain at least one choice")
|
||||
return errors
|
||||
|
||||
|
||||
# Check for duplicate values
|
||||
values = [choice.value for choice in choices]
|
||||
if len(values) != len(set(values)):
|
||||
duplicates = [v for v in values if values.count(v) > 1]
|
||||
errors.append(f"Duplicate choice values found: {', '.join(set(duplicates))}")
|
||||
|
||||
|
||||
# Validate individual choices
|
||||
for i, choice in enumerate(choices):
|
||||
try:
|
||||
@@ -273,17 +271,17 @@ def validate_choice_group_data(
|
||||
)
|
||||
except ValueError as e:
|
||||
errors.append(f"Choice {i}: {str(e)}")
|
||||
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
def create_choice_from_config(config: Dict[str, Any]) -> RichChoice:
|
||||
def create_choice_from_config(config: dict[str, Any]) -> RichChoice:
|
||||
"""
|
||||
Create a RichChoice from a configuration dictionary.
|
||||
|
||||
|
||||
Args:
|
||||
config: Configuration dictionary with choice data
|
||||
|
||||
|
||||
Returns:
|
||||
RichChoice object
|
||||
"""
|
||||
@@ -300,19 +298,19 @@ def create_choice_from_config(config: Dict[str, Any]) -> RichChoice:
|
||||
def export_choices_to_dict(
|
||||
choice_group: str,
|
||||
domain: str = "core"
|
||||
) -> Dict[str, Any]:
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Export a choice group to a dictionary format.
|
||||
|
||||
|
||||
Args:
|
||||
choice_group: Name of the choice group in the registry
|
||||
domain: Domain namespace for the choice group
|
||||
|
||||
|
||||
Returns:
|
||||
Dictionary representation of the choice group
|
||||
"""
|
||||
group = registry.get(choice_group, domain)
|
||||
if not group:
|
||||
return {}
|
||||
|
||||
|
||||
return group.to_dict()
|
||||
|
||||
@@ -4,23 +4,26 @@ Advanced caching decorators for API views and functions.
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import Optional, List, Callable, Any, Dict
|
||||
from typing import Any
|
||||
|
||||
from django.http import HttpRequest, HttpResponseBase
|
||||
from django.utils.decorators import method_decorator
|
||||
from django.views.decorators.vary import vary_on_headers
|
||||
from django.views import View
|
||||
from django.views.decorators.vary import vary_on_headers
|
||||
from rest_framework.response import Response as DRFResponse
|
||||
|
||||
from apps.core.services.enhanced_cache_service import EnhancedCacheService
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def cache_api_response(
|
||||
timeout: int = 1800,
|
||||
vary_on: Optional[List[str]] = None,
|
||||
vary_on: list[str] | None = None,
|
||||
key_prefix: str = "api",
|
||||
cache_backend: str = "api",
|
||||
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
|
||||
@@ -82,14 +85,14 @@ def cache_api_response(
|
||||
"cache_hit": True,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# If cached data is our dict format for DRF responses, reconstruct it
|
||||
if isinstance(cached_response, dict) and '__drf_data__' in cached_response:
|
||||
return DRFResponse(
|
||||
data=cached_response['__drf_data__'],
|
||||
data=cached_response['__drf_data__'],
|
||||
status=cached_response.get('status', 200)
|
||||
)
|
||||
|
||||
|
||||
return cached_response
|
||||
|
||||
# Execute view and cache result
|
||||
@@ -108,7 +111,7 @@ def cache_api_response(
|
||||
}
|
||||
else:
|
||||
cache_payload = response
|
||||
|
||||
|
||||
getattr(cache_service, cache_backend + "_cache").set(
|
||||
cache_key, cache_payload, timeout
|
||||
)
|
||||
@@ -193,7 +196,7 @@ def cache_queryset_result(
|
||||
|
||||
|
||||
def invalidate_cache_on_save(
|
||||
model_name: str, cache_patterns: Optional[List[str]] = None
|
||||
model_name: str, cache_patterns: list[str] | None = None
|
||||
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
|
||||
"""
|
||||
Decorator to invalidate cache when model instances are saved
|
||||
@@ -313,8 +316,8 @@ class CachedAPIViewMixin(View):
|
||||
|
||||
def smart_cache(
|
||||
timeout: int = 3600,
|
||||
key_func: Optional[Callable[..., str]] = None,
|
||||
invalidate_on: Optional[List[str]] = None,
|
||||
key_func: Callable[..., str] | None = None,
|
||||
invalidate_on: list[str] | None = None,
|
||||
cache_backend: str = "default",
|
||||
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
|
||||
"""
|
||||
@@ -378,8 +381,8 @@ def smart_cache(
|
||||
|
||||
# Add cache invalidation if specified
|
||||
if invalidate_on:
|
||||
setattr(wrapper, "_cache_invalidate_on", invalidate_on)
|
||||
setattr(wrapper, "_cache_backend", cache_backend)
|
||||
wrapper._cache_invalidate_on = invalidate_on
|
||||
wrapper._cache_backend = cache_backend
|
||||
|
||||
return wrapper
|
||||
|
||||
@@ -431,7 +434,7 @@ def generate_model_cache_key(model_instance: Any, suffix: str = "") -> str:
|
||||
|
||||
|
||||
def generate_queryset_cache_key(
|
||||
queryset: Any, params: Optional[Dict[str, Any]] = None
|
||||
queryset: Any, params: dict[str, Any] | None = None
|
||||
) -> str:
|
||||
"""Generate cache key for queryset with parameters"""
|
||||
model_name = queryset.model._meta.model_name
|
||||
|
||||
@@ -3,7 +3,7 @@ Custom exception classes for ThrillWiki.
|
||||
Provides domain-specific exceptions with proper error codes and messages.
|
||||
"""
|
||||
|
||||
from typing import Optional, Dict, Any
|
||||
from typing import Any
|
||||
|
||||
|
||||
class ThrillWikiException(Exception):
|
||||
@@ -15,16 +15,16 @@ class ThrillWikiException(Exception):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: Optional[str] = None,
|
||||
error_code: Optional[str] = None,
|
||||
details: Optional[Dict[str, Any]] = None,
|
||||
message: str | None = None,
|
||||
error_code: str | None = None,
|
||||
details: dict[str, Any] | None = None,
|
||||
):
|
||||
self.message = message or self.default_message
|
||||
self.error_code = error_code or self.error_code
|
||||
self.details = details or {}
|
||||
super().__init__(self.message)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert exception to dictionary for API responses."""
|
||||
return {
|
||||
"error_code": self.error_code,
|
||||
@@ -96,7 +96,7 @@ class ParkNotFoundError(NotFoundError):
|
||||
default_message = "Park not found"
|
||||
error_code = "PARK_NOT_FOUND"
|
||||
|
||||
def __init__(self, park_slug: Optional[str] = None, **kwargs):
|
||||
def __init__(self, park_slug: str | None = None, **kwargs):
|
||||
if park_slug:
|
||||
kwargs["details"] = {"park_slug": park_slug}
|
||||
kwargs["message"] = f"Park with slug '{park_slug}' not found"
|
||||
@@ -122,7 +122,7 @@ class RideNotFoundError(NotFoundError):
|
||||
default_message = "Ride not found"
|
||||
error_code = "RIDE_NOT_FOUND"
|
||||
|
||||
def __init__(self, ride_slug: Optional[str] = None, **kwargs):
|
||||
def __init__(self, ride_slug: str | None = None, **kwargs):
|
||||
if ride_slug:
|
||||
kwargs["details"] = {"ride_slug": ride_slug}
|
||||
kwargs["message"] = f"Ride with slug '{ride_slug}' not found"
|
||||
@@ -150,8 +150,8 @@ class InvalidCoordinatesError(ValidationException):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
latitude: Optional[float] = None,
|
||||
longitude: Optional[float] = None,
|
||||
latitude: float | None = None,
|
||||
longitude: float | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
if latitude is not None or longitude is not None:
|
||||
@@ -198,7 +198,7 @@ class InsufficientPermissionsError(PermissionDeniedError):
|
||||
default_message = "Insufficient permissions"
|
||||
error_code = "INSUFFICIENT_PERMISSIONS"
|
||||
|
||||
def __init__(self, required_permission: Optional[str] = None, **kwargs):
|
||||
def __init__(self, required_permission: str | None = None, **kwargs):
|
||||
if required_permission:
|
||||
kwargs["details"] = {"required_permission": required_permission}
|
||||
kwargs["message"] = f"Permission '{required_permission}' required"
|
||||
@@ -226,7 +226,7 @@ class RoadTripError(ExternalServiceError):
|
||||
default_message = "Road trip planning error"
|
||||
error_code = "ROADTRIP_ERROR"
|
||||
|
||||
def __init__(self, service_name: Optional[str] = None, **kwargs):
|
||||
def __init__(self, service_name: str | None = None, **kwargs):
|
||||
if service_name:
|
||||
kwargs["details"] = {"service": service_name}
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
"""Core forms and form components."""
|
||||
|
||||
from autocomplete import Autocomplete
|
||||
from django.conf import settings
|
||||
from django.core.exceptions import PermissionDenied
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
from autocomplete import Autocomplete
|
||||
|
||||
|
||||
class BaseAutocomplete(Autocomplete):
|
||||
"""Base autocomplete class for consistent autocomplete behavior across the project.
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
"""
|
||||
Base forms and views for HTMX integration.
|
||||
"""
|
||||
from django.views.generic.edit import FormView
|
||||
from django.http import JsonResponse
|
||||
from django.views.generic.edit import FormView
|
||||
|
||||
|
||||
class HTMXFormView(FormView):
|
||||
|
||||
@@ -2,9 +2,10 @@
|
||||
Custom health checks for ThrillWiki application.
|
||||
"""
|
||||
|
||||
import time
|
||||
import logging
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
from django.core.cache import cache
|
||||
from django.db import connection
|
||||
from health_check.backends import BaseHealthCheckBackend
|
||||
@@ -165,9 +166,10 @@ class ApplicationHealthCheck(BaseHealthCheckBackend):
|
||||
|
||||
# Check if we can access critical models
|
||||
try:
|
||||
from parks.models import Park
|
||||
from apps.rides.models import Ride
|
||||
from django.contrib.auth import get_user_model
|
||||
from parks.models import Park
|
||||
|
||||
from apps.rides.models import Ride
|
||||
|
||||
User = get_user_model()
|
||||
|
||||
@@ -185,9 +187,10 @@ class ApplicationHealthCheck(BaseHealthCheckBackend):
|
||||
self.add_error(f"Model access check failed: {e}")
|
||||
|
||||
# Check media and static file configuration
|
||||
from django.conf import settings
|
||||
import os
|
||||
|
||||
from django.conf import settings
|
||||
|
||||
if not os.path.exists(settings.MEDIA_ROOT):
|
||||
self.add_error(f"Media directory does not exist: {settings.MEDIA_ROOT}")
|
||||
|
||||
@@ -208,8 +211,8 @@ class ExternalServiceHealthCheck(BaseHealthCheckBackend):
|
||||
def check_status(self):
|
||||
# Check email service if configured
|
||||
try:
|
||||
from django.core.mail import get_connection
|
||||
from django.conf import settings
|
||||
from django.core.mail import get_connection
|
||||
|
||||
if (
|
||||
hasattr(settings, "EMAIL_BACKEND")
|
||||
@@ -253,8 +256,8 @@ class ExternalServiceHealthCheck(BaseHealthCheckBackend):
|
||||
|
||||
# Check Redis connection if configured
|
||||
try:
|
||||
from django.core.cache import caches
|
||||
from django.conf import settings
|
||||
from django.core.cache import caches
|
||||
|
||||
cache_config = settings.CACHES.get("default", {})
|
||||
if "redis" in cache_config.get("BACKEND", "").lower():
|
||||
@@ -279,6 +282,7 @@ class DiskSpaceHealthCheck(BaseHealthCheckBackend):
|
||||
def check_status(self):
|
||||
try:
|
||||
import shutil
|
||||
|
||||
from django.conf import settings
|
||||
|
||||
# Check disk space for media directory
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
from django.db import models
|
||||
from django.contrib.contenttypes.models import ContentType
|
||||
from django.contrib.contenttypes.fields import GenericForeignKey
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from django.conf import settings
|
||||
from typing import Any, Dict, Optional, TYPE_CHECKING
|
||||
from django.contrib.contenttypes.fields import GenericForeignKey
|
||||
from django.contrib.contenttypes.models import ContentType
|
||||
from django.db import models
|
||||
from django.db.models import QuerySet
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -12,7 +13,7 @@ if TYPE_CHECKING:
|
||||
class DiffMixin:
|
||||
"""Mixin to add diffing capabilities to models with pghistory"""
|
||||
|
||||
def get_prev_record(self) -> Optional[Any]:
|
||||
def get_prev_record(self) -> Any | None:
|
||||
"""Get the previous record for this instance"""
|
||||
try:
|
||||
# Use getattr to safely access objects manager and pghistory fields
|
||||
@@ -37,7 +38,7 @@ class DiffMixin:
|
||||
except (AttributeError, TypeError):
|
||||
return None
|
||||
|
||||
def diff_against_previous(self) -> Dict:
|
||||
def diff_against_previous(self) -> dict:
|
||||
"""Compare this record against the previous one"""
|
||||
prev_record = self.get_prev_record()
|
||||
if not prev_record:
|
||||
|
||||
@@ -24,7 +24,7 @@ import json
|
||||
from functools import wraps
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from django.http import HttpResponse, JsonResponse
|
||||
from django.http import HttpResponse
|
||||
from django.template import TemplateDoesNotExist
|
||||
from django.template.loader import render_to_string
|
||||
|
||||
|
||||
@@ -5,7 +5,8 @@ Provides structured logging with proper formatting and context.
|
||||
|
||||
import logging
|
||||
import sys
|
||||
from typing import Dict, Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from django.conf import settings
|
||||
from django.utils import timezone
|
||||
|
||||
@@ -65,7 +66,7 @@ def log_exception(
|
||||
logger: logging.Logger,
|
||||
exception: Exception,
|
||||
*,
|
||||
context: Optional[Dict[str, Any]] = None,
|
||||
context: dict[str, Any] | None = None,
|
||||
request=None,
|
||||
level: int = logging.ERROR,
|
||||
) -> None:
|
||||
@@ -111,7 +112,7 @@ def log_business_event(
|
||||
event_type: str,
|
||||
*,
|
||||
message: str,
|
||||
context: Optional[Dict[str, Any]] = None,
|
||||
context: dict[str, Any] | None = None,
|
||||
request=None,
|
||||
level: int = logging.INFO,
|
||||
) -> None:
|
||||
@@ -149,7 +150,7 @@ def log_performance_metric(
|
||||
operation: str,
|
||||
*,
|
||||
duration_ms: float,
|
||||
context: Optional[Dict[str, Any]] = None,
|
||||
context: dict[str, Any] | None = None,
|
||||
level: int = logging.INFO,
|
||||
) -> None:
|
||||
"""
|
||||
@@ -177,8 +178,8 @@ def log_api_request(
|
||||
logger: logging.Logger,
|
||||
request,
|
||||
*,
|
||||
response_status: Optional[int] = None,
|
||||
duration_ms: Optional[float] = None,
|
||||
response_status: int | None = None,
|
||||
duration_ms: float | None = None,
|
||||
level: int = logging.INFO,
|
||||
) -> None:
|
||||
"""
|
||||
@@ -219,7 +220,7 @@ def log_security_event(
|
||||
*,
|
||||
message: str,
|
||||
severity: str = "medium",
|
||||
context: Optional[Dict[str, Any]] = None,
|
||||
context: dict[str, Any] | None = None,
|
||||
request=None,
|
||||
) -> None:
|
||||
"""
|
||||
|
||||
@@ -7,11 +7,12 @@ Run with: uv run manage.py calculate_new_content
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Any
|
||||
from django.core.management.base import BaseCommand, CommandError
|
||||
from django.utils import timezone
|
||||
from typing import Any
|
||||
|
||||
from django.core.cache import cache
|
||||
from django.core.management.base import BaseCommand, CommandError
|
||||
from django.db.models import Q
|
||||
from django.utils import timezone
|
||||
|
||||
from apps.parks.models import Park
|
||||
from apps.rides.models import Ride
|
||||
@@ -102,7 +103,7 @@ class Command(BaseCommand):
|
||||
logger.error(f"Error calculating new content: {e}", exc_info=True)
|
||||
raise CommandError(f"Failed to calculate new content: {e}")
|
||||
|
||||
def _get_new_parks(self, cutoff_date: datetime, limit: int) -> List[Dict[str, Any]]:
|
||||
def _get_new_parks(self, cutoff_date: datetime, limit: int) -> list[dict[str, Any]]:
|
||||
"""Get recently added parks using real data."""
|
||||
new_parks = (
|
||||
Park.objects.filter(
|
||||
@@ -117,9 +118,8 @@ class Command(BaseCommand):
|
||||
results = []
|
||||
for park in new_parks:
|
||||
date_added = park.opening_date or park.created_at
|
||||
if date_added:
|
||||
if isinstance(date_added, datetime):
|
||||
date_added = date_added.date()
|
||||
if date_added and isinstance(date_added, datetime):
|
||||
date_added = date_added.date()
|
||||
|
||||
opening_date = getattr(park, "opening_date", None)
|
||||
if opening_date and isinstance(opening_date, datetime):
|
||||
@@ -142,7 +142,7 @@ class Command(BaseCommand):
|
||||
|
||||
return results
|
||||
|
||||
def _get_new_rides(self, cutoff_date: datetime, limit: int) -> List[Dict[str, Any]]:
|
||||
def _get_new_rides(self, cutoff_date: datetime, limit: int) -> list[dict[str, Any]]:
|
||||
"""Get recently added rides using real data."""
|
||||
new_rides = (
|
||||
Ride.objects.filter(
|
||||
@@ -159,9 +159,8 @@ class Command(BaseCommand):
|
||||
date_added = getattr(ride, "opening_date", None) or getattr(
|
||||
ride, "created_at", None
|
||||
)
|
||||
if date_added:
|
||||
if isinstance(date_added, datetime):
|
||||
date_added = date_added.date()
|
||||
if date_added and isinstance(date_added, datetime):
|
||||
date_added = date_added.date()
|
||||
|
||||
opening_date = getattr(ride, "opening_date", None)
|
||||
if opening_date and isinstance(opening_date, datetime):
|
||||
@@ -186,8 +185,8 @@ class Command(BaseCommand):
|
||||
return results
|
||||
|
||||
def _format_new_content_results(
|
||||
self, new_items: List[Dict[str, Any]]
|
||||
) -> List[Dict[str, Any]]:
|
||||
self, new_items: list[dict[str, Any]]
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Format new content results for frontend consumption."""
|
||||
formatted_results = []
|
||||
|
||||
|
||||
@@ -6,11 +6,12 @@ Run with: uv run manage.py calculate_trending
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict, List, Any
|
||||
from typing import Any
|
||||
|
||||
from django.contrib.contenttypes.models import ContentType
|
||||
from django.core.cache import cache
|
||||
from django.core.management.base import BaseCommand, CommandError
|
||||
from django.utils import timezone
|
||||
from django.core.cache import cache
|
||||
from django.contrib.contenttypes.models import ContentType
|
||||
|
||||
from apps.core.analytics import PageView
|
||||
from apps.parks.models import Park
|
||||
@@ -107,7 +108,7 @@ class Command(BaseCommand):
|
||||
|
||||
def _calculate_trending_parks(
|
||||
self, current_period_hours: int, previous_period_hours: int, limit: int
|
||||
) -> List[Dict[str, Any]]:
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Calculate trending scores for parks using real data."""
|
||||
parks = Park.objects.filter(status="OPERATING").select_related(
|
||||
"location", "operator"
|
||||
@@ -151,7 +152,7 @@ class Command(BaseCommand):
|
||||
|
||||
def _calculate_trending_rides(
|
||||
self, current_period_hours: int, previous_period_hours: int, limit: int
|
||||
) -> List[Dict[str, Any]]:
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Calculate trending scores for rides using real data."""
|
||||
rides = Ride.objects.filter(status="OPERATING").select_related(
|
||||
"park", "park__location"
|
||||
@@ -339,10 +340,10 @@ class Command(BaseCommand):
|
||||
|
||||
def _format_trending_results(
|
||||
self,
|
||||
trending_items: List[Dict[str, Any]],
|
||||
trending_items: list[dict[str, Any]],
|
||||
current_period_hours: int,
|
||||
previous_period_hours: int,
|
||||
) -> List[Dict[str, Any]]:
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Format trending results for frontend consumption."""
|
||||
formatted_results = []
|
||||
|
||||
|
||||
@@ -15,9 +15,9 @@ import shutil
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
|
||||
from django.conf import settings
|
||||
from django.core.cache import cache, caches
|
||||
from django.core.management.base import BaseCommand
|
||||
from django.conf import settings
|
||||
|
||||
|
||||
class Command(BaseCommand):
|
||||
|
||||
@@ -6,11 +6,10 @@ showing which callbacks are registered for each model and transition.
|
||||
"""
|
||||
|
||||
from django.core.management.base import BaseCommand, CommandParser
|
||||
from django.apps import apps
|
||||
|
||||
from apps.core.state_machine.callback_base import (
|
||||
callback_registry,
|
||||
CallbackStage,
|
||||
callback_registry,
|
||||
)
|
||||
from apps.core.state_machine.config import callback_config
|
||||
|
||||
|
||||
@@ -10,10 +10,10 @@ Usage:
|
||||
python manage.py optimize_static --force
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from django.core.management.base import BaseCommand, CommandError
|
||||
|
||||
from django.conf import settings
|
||||
from django.core.management.base import BaseCommand, CommandError
|
||||
|
||||
|
||||
class Command(BaseCommand):
|
||||
|
||||
@@ -5,8 +5,8 @@ This command automatically sets up the development environment and starts
|
||||
the server, replacing the need for the dev_server.sh script.
|
||||
"""
|
||||
|
||||
from django.core.management.base import BaseCommand
|
||||
from django.core.management import execute_from_command_line
|
||||
from django.core.management.base import BaseCommand
|
||||
|
||||
|
||||
class Command(BaseCommand):
|
||||
@@ -92,7 +92,7 @@ class Command(BaseCommand):
|
||||
def has_runserver_plus(self):
|
||||
"""Check if runserver_plus is available (django-extensions)."""
|
||||
try:
|
||||
import django_extensions
|
||||
import django_extensions # noqa: F401
|
||||
|
||||
return True
|
||||
except ImportError:
|
||||
|
||||
@@ -10,9 +10,9 @@ Usage:
|
||||
python manage.py security_audit --verbose
|
||||
"""
|
||||
|
||||
from django.core.management.base import BaseCommand
|
||||
from django.core.checks import registry, Tags
|
||||
from django.conf import settings
|
||||
from django.core.checks import Tags, registry
|
||||
from django.core.management.base import BaseCommand
|
||||
|
||||
|
||||
class Command(BaseCommand):
|
||||
@@ -88,10 +88,7 @@ class Command(BaseCommand):
|
||||
)
|
||||
else:
|
||||
for error in errors:
|
||||
if error.is_serious():
|
||||
prefix = self.style.ERROR(" ✗ ERROR")
|
||||
else:
|
||||
prefix = self.style.WARNING(" ! WARNING")
|
||||
prefix = self.style.ERROR(" ✗ ERROR") if error.is_serious() else self.style.WARNING(" ! WARNING")
|
||||
|
||||
self.log(f"{prefix}: {error.msg}", report_lines)
|
||||
if error.hint and self.verbose:
|
||||
@@ -169,10 +166,7 @@ class Command(BaseCommand):
|
||||
])
|
||||
|
||||
for name, is_secure, value in checks:
|
||||
if is_secure:
|
||||
status = self.style.SUCCESS("✓")
|
||||
else:
|
||||
status = self.style.WARNING("!")
|
||||
status = self.style.SUCCESS("✓") if is_secure else self.style.WARNING("!")
|
||||
|
||||
msg = f" {status} {name}"
|
||||
if self.verbose:
|
||||
|
||||
@@ -6,11 +6,10 @@ allowing the project to run without requiring the shell script.
|
||||
"""
|
||||
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from django.core.management.base import BaseCommand
|
||||
from django.conf import settings
|
||||
from django.core.management.base import BaseCommand
|
||||
|
||||
|
||||
class Command(BaseCommand):
|
||||
|
||||
@@ -5,14 +5,14 @@ This command allows testing callbacks for specific transitions
|
||||
without actually changing model state.
|
||||
"""
|
||||
|
||||
from django.core.management.base import BaseCommand, CommandParser, CommandError
|
||||
from django.apps import apps
|
||||
from django.contrib.auth import get_user_model
|
||||
from django.core.management.base import BaseCommand, CommandError, CommandParser
|
||||
|
||||
from apps.core.state_machine.callback_base import (
|
||||
callback_registry,
|
||||
CallbackStage,
|
||||
TransitionContext,
|
||||
callback_registry,
|
||||
)
|
||||
from apps.core.state_machine.monitoring import callback_monitor
|
||||
|
||||
|
||||
@@ -1,13 +1,15 @@
|
||||
from django.core.management.base import BaseCommand
|
||||
import random
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from django.contrib.contenttypes.models import ContentType
|
||||
from django.core.management.base import BaseCommand
|
||||
from django.utils import timezone
|
||||
from apps.parks.models.parks import Park
|
||||
from apps.rides.models.rides import Ride
|
||||
from apps.parks.models.companies import Company
|
||||
|
||||
from apps.core.analytics import PageView
|
||||
from apps.core.services.trending_service import trending_service
|
||||
from datetime import datetime, timedelta
|
||||
import random
|
||||
from apps.parks.models.companies import Company
|
||||
from apps.parks.models.parks import Park
|
||||
from apps.rides.models.rides import Ride
|
||||
|
||||
|
||||
class Command(BaseCommand):
|
||||
@@ -205,7 +207,7 @@ class Command(BaseCommand):
|
||||
content_type = ContentType.objects.get_for_model(type(content_object))
|
||||
|
||||
# Create recent views (last 2 hours)
|
||||
for i in range(recent_views):
|
||||
for _i in range(recent_views):
|
||||
view_time = base_time - timedelta(
|
||||
minutes=random.randint(0, 120) # Last 2 hours
|
||||
)
|
||||
@@ -218,7 +220,7 @@ class Command(BaseCommand):
|
||||
)
|
||||
|
||||
# Create older views (2-24 hours ago)
|
||||
for i in range(older_views):
|
||||
for _i in range(older_views):
|
||||
view_time = base_time - timedelta(hours=random.randint(2, 24))
|
||||
PageView.objects.create(
|
||||
content_type=content_type,
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
from django.core.management.base import BaseCommand
|
||||
from django.core.cache import cache
|
||||
from django.core.management.base import BaseCommand
|
||||
|
||||
from apps.core.analytics import PageView
|
||||
from apps.parks.models import Park
|
||||
from apps.rides.models import Ride
|
||||
from apps.core.analytics import PageView
|
||||
|
||||
|
||||
class Command(BaseCommand):
|
||||
|
||||
@@ -12,14 +12,15 @@ Usage:
|
||||
|
||||
import json
|
||||
import sys
|
||||
from django.core.management.base import BaseCommand, CommandError
|
||||
|
||||
from django.core.management.base import BaseCommand
|
||||
|
||||
from config.settings.secrets import (
|
||||
check_secret_expiry,
|
||||
validate_required_secrets,
|
||||
)
|
||||
from config.settings.validation import (
|
||||
validate_all_settings,
|
||||
get_validation_report,
|
||||
)
|
||||
from config.settings.secrets import (
|
||||
validate_required_secrets,
|
||||
check_secret_expiry,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -12,12 +12,13 @@ Usage:
|
||||
python manage.py warm_cache --dry-run
|
||||
"""
|
||||
|
||||
import time
|
||||
import logging
|
||||
from django.core.management.base import BaseCommand
|
||||
from django.db.models import Count, Avg
|
||||
import time
|
||||
|
||||
from apps.core.services.enhanced_cache_service import EnhancedCacheService, CacheWarmer
|
||||
from django.core.management.base import BaseCommand
|
||||
from django.db.models import Count
|
||||
|
||||
from apps.core.services.enhanced_cache_service import EnhancedCacheService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -122,7 +123,7 @@ class Command(BaseCommand):
|
||||
)
|
||||
warmed_count += 1
|
||||
if verbose:
|
||||
self.stdout.write(f" Cached park status counts")
|
||||
self.stdout.write(" Cached park status counts")
|
||||
except Exception as e:
|
||||
failed_count += 1
|
||||
self.stdout.write(self.style.ERROR(f" Failed to cache park status counts: {e}"))
|
||||
@@ -191,7 +192,7 @@ class Command(BaseCommand):
|
||||
)
|
||||
warmed_count += 1
|
||||
if verbose:
|
||||
self.stdout.write(f" Cached ride category counts")
|
||||
self.stdout.write(" Cached ride category counts")
|
||||
except Exception as e:
|
||||
failed_count += 1
|
||||
self.stdout.write(self.style.ERROR(f" Failed to cache ride category counts: {e}"))
|
||||
|
||||
@@ -3,13 +3,13 @@ Custom managers and QuerySets for optimized database patterns.
|
||||
Following Django styleguide best practices for database access.
|
||||
"""
|
||||
|
||||
from typing import Optional, List, Union
|
||||
from django.db import models
|
||||
from django.db.models import Q, Count, Avg, Max
|
||||
from datetime import timedelta
|
||||
|
||||
from django.contrib.gis.geos import Point
|
||||
from django.contrib.gis.measure import Distance
|
||||
from django.db import models
|
||||
from django.db.models import Avg, Count, Max, Q
|
||||
from django.utils import timezone
|
||||
from datetime import timedelta
|
||||
|
||||
|
||||
class BaseQuerySet(models.QuerySet):
|
||||
@@ -32,7 +32,7 @@ class BaseQuerySet(models.QuerySet):
|
||||
cutoff_date = timezone.now() - timedelta(days=days)
|
||||
return self.filter(created_at__gte=cutoff_date)
|
||||
|
||||
def search(self, *, query: str, fields: Optional[List[str]] = None):
|
||||
def search(self, *, query: str, fields: list[str] | None = None):
|
||||
"""
|
||||
Full-text search across specified fields.
|
||||
|
||||
@@ -81,7 +81,7 @@ class BaseManager(models.Manager):
|
||||
def recent(self, *, days: int = 30):
|
||||
return self.get_queryset().recent(days=days)
|
||||
|
||||
def search(self, *, query: str, fields: Optional[List[str]] = None):
|
||||
def search(self, *, query: str, fields: list[str] | None = None):
|
||||
return self.get_queryset().search(query=query, fields=fields)
|
||||
|
||||
|
||||
@@ -245,7 +245,7 @@ class TimestampedManager(BaseManager):
|
||||
class StatusQuerySet(BaseQuerySet):
|
||||
"""QuerySet for models with status fields."""
|
||||
|
||||
def with_status(self, *, status: Union[str, List[str]]):
|
||||
def with_status(self, *, status: str | list[str]):
|
||||
"""Filter by status."""
|
||||
if isinstance(status, list):
|
||||
return self.filter(status__in=status)
|
||||
|
||||
@@ -5,9 +5,9 @@ This package contains middleware components for the Django application,
|
||||
including view tracking and other core functionality.
|
||||
"""
|
||||
|
||||
from .view_tracking import ViewTrackingMiddleware, get_view_stats_for_content
|
||||
from .analytics import PgHistoryContextMiddleware
|
||||
from .nextjs import APIResponseMiddleware
|
||||
from .view_tracking import ViewTrackingMiddleware, get_view_stats_for_content
|
||||
|
||||
__all__ = [
|
||||
"ViewTrackingMiddleware",
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
Middleware for handling errors in HTMX requests.
|
||||
"""
|
||||
import logging
|
||||
|
||||
from django.http import HttpResponseServerError
|
||||
from django.template.loader import render_to_string
|
||||
|
||||
|
||||
@@ -2,11 +2,12 @@
|
||||
Performance monitoring middleware for tracking request metrics.
|
||||
"""
|
||||
|
||||
import time
|
||||
import logging
|
||||
import time
|
||||
|
||||
from django.conf import settings
|
||||
from django.db import connection
|
||||
from django.utils.deprecation import MiddlewareMixin
|
||||
from django.conf import settings
|
||||
|
||||
performance_logger = logging.getLogger("performance")
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -162,10 +163,7 @@ class PerformanceMiddleware(MiddlewareMixin):
|
||||
def _get_client_ip(self, request):
|
||||
"""Extract client IP address from request"""
|
||||
x_forwarded_for = request.META.get("HTTP_X_FORWARDED_FOR")
|
||||
if x_forwarded_for:
|
||||
ip = x_forwarded_for.split(",")[0].strip()
|
||||
else:
|
||||
ip = request.META.get("REMOTE_ADDR", "")
|
||||
ip = x_forwarded_for.split(",")[0].strip() if x_forwarded_for else request.META.get("REMOTE_ADDR", "")
|
||||
return ip
|
||||
|
||||
def _get_log_level(self, duration, query_count, status_code):
|
||||
|
||||
@@ -14,11 +14,10 @@ Usage:
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Callable, Optional, Tuple
|
||||
from collections.abc import Callable
|
||||
|
||||
from django.core.cache import cache
|
||||
from django.http import HttpRequest, HttpResponse, JsonResponse
|
||||
from django.conf import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -88,7 +87,7 @@ class AuthRateLimitMiddleware:
|
||||
|
||||
return response
|
||||
|
||||
def _get_rate_limits(self, path: str) -> Optional[dict]:
|
||||
def _get_rate_limits(self, path: str) -> dict | None:
|
||||
"""Get rate limits for a path, if any."""
|
||||
# Exact match
|
||||
if path in self.RATE_LIMITED_PATHS:
|
||||
@@ -125,7 +124,7 @@ class AuthRateLimitMiddleware:
|
||||
client_ip: str,
|
||||
path: str,
|
||||
limits: dict
|
||||
) -> Tuple[bool, str]:
|
||||
) -> tuple[bool, str]:
|
||||
"""
|
||||
Check if the client has exceeded rate limits.
|
||||
|
||||
|
||||
@@ -3,9 +3,10 @@ Request logging middleware for comprehensive request/response logging.
|
||||
Logs all HTTP requests with detailed data for debugging and monitoring.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
import json
|
||||
|
||||
from django.utils.deprecation import MiddlewareMixin
|
||||
|
||||
logger = logging.getLogger('request_logging')
|
||||
|
||||
@@ -9,12 +9,13 @@ analytics for the trending algorithm.
|
||||
import logging
|
||||
import re
|
||||
from datetime import timedelta
|
||||
from typing import Optional, Union
|
||||
from django.http import HttpRequest, HttpResponse
|
||||
from django.utils import timezone
|
||||
from typing import Union
|
||||
|
||||
from django.conf import settings
|
||||
from django.contrib.contenttypes.models import ContentType
|
||||
from django.core.cache import cache
|
||||
from django.conf import settings
|
||||
from django.http import HttpRequest, HttpResponse
|
||||
from django.utils import timezone
|
||||
|
||||
from apps.core.analytics import PageView
|
||||
from apps.parks.models import Park
|
||||
@@ -105,10 +106,7 @@ class ViewTrackingMiddleware:
|
||||
return True
|
||||
|
||||
# Skip AJAX requests (optional - depending on requirements)
|
||||
if request.META.get("HTTP_X_REQUESTED_WITH") == "XMLHttpRequest":
|
||||
return True
|
||||
|
||||
return False
|
||||
return request.META.get("HTTP_X_REQUESTED_WITH") == "XMLHttpRequest"
|
||||
|
||||
def _track_view_if_applicable(self, request: HttpRequest) -> None:
|
||||
"""Track view if the URL matches tracked patterns."""
|
||||
@@ -159,7 +157,7 @@ class ViewTrackingMiddleware:
|
||||
|
||||
def _get_content_object(
|
||||
self, content_type: str, slug: str
|
||||
) -> Optional[ContentObject]:
|
||||
) -> ContentObject | None:
|
||||
"""Get the content object by type and slug."""
|
||||
try:
|
||||
if content_type == "park":
|
||||
@@ -234,7 +232,7 @@ class ViewTrackingMiddleware:
|
||||
content_type = ContentType.objects.get_for_model(content_obj)
|
||||
return f"pageview_dedup:{content_type.id}:{content_obj.pk}:{client_ip}"
|
||||
|
||||
def _get_client_ip(self, request: HttpRequest) -> Optional[str]:
|
||||
def _get_client_ip(self, request: HttpRequest) -> str | None:
|
||||
"""Extract client IP address from request."""
|
||||
# Check for forwarded IP (common in production with load balancers)
|
||||
x_forwarded_for = request.META.get("HTTP_X_FORWARDED_FOR")
|
||||
@@ -270,13 +268,12 @@ class ViewTrackingMiddleware:
|
||||
|
||||
# Skip localhost and private IPs in production
|
||||
if getattr(settings, "SKIP_LOCAL_IPS", not settings.DEBUG):
|
||||
if ip.startswith(("127.", "192.168.", "10.")) or ip.startswith("172."):
|
||||
if any(
|
||||
16 <= int(ip.split(".")[1]) <= 31
|
||||
for _ in [ip]
|
||||
if ip.startswith("172.")
|
||||
):
|
||||
return False
|
||||
if (ip.startswith(("127.", "192.168.", "10.")) or ip.startswith("172.")) and any(
|
||||
16 <= int(ip.split(".")[1]) <= 31
|
||||
for _ in [ip]
|
||||
if ip.startswith("172.")
|
||||
):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""HTMX mixins for views. Canonical definitions for partial rendering and triggers."""
|
||||
|
||||
from typing import Any, Optional, Type
|
||||
from typing import Any
|
||||
|
||||
from django.template import TemplateDoesNotExist
|
||||
from django.template.loader import select_template
|
||||
@@ -11,7 +11,7 @@ from django.views.generic.list import MultipleObjectMixin
|
||||
class HTMXFilterableMixin(MultipleObjectMixin):
|
||||
"""Enhance list views to return partial templates for HTMX requests."""
|
||||
|
||||
filter_class: Optional[Type[Any]] = None
|
||||
filter_class: type[Any] | None = None
|
||||
htmx_partial_suffix = "_partial.html"
|
||||
|
||||
def get_queryset(self):
|
||||
@@ -47,7 +47,7 @@ class HTMXFilterableMixin(MultipleObjectMixin):
|
||||
class HTMXFormMixin(FormMixin):
|
||||
"""FormMixin that returns partials and field-level errors for HTMX requests."""
|
||||
|
||||
htmx_success_trigger: Optional[str] = None
|
||||
htmx_success_trigger: str | None = None
|
||||
|
||||
def form_invalid(self, form):
|
||||
"""Return partial with errors on invalid form submission via HTMX."""
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
from django.db import models
|
||||
import pghistory
|
||||
from django.contrib.contenttypes.fields import GenericForeignKey
|
||||
from django.contrib.contenttypes.models import ContentType
|
||||
from django.db import models
|
||||
from django.utils.text import slugify
|
||||
|
||||
from apps.core.history import TrackedModel
|
||||
import pghistory
|
||||
|
||||
|
||||
@pghistory.track()
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from rest_framework import permissions
|
||||
|
||||
|
||||
class IsOwnerOrReadOnly(permissions.BasePermission):
|
||||
"""
|
||||
Custom permission to only allow owners of an object to edit it.
|
||||
|
||||
@@ -3,24 +3,26 @@ Selectors for core functionality including map services and analytics.
|
||||
Following Django styleguide pattern for separating data access from business logic.
|
||||
"""
|
||||
|
||||
from typing import Optional, Dict, Any, List
|
||||
from django.db.models import QuerySet, Q, Count
|
||||
from datetime import timedelta
|
||||
from typing import Any
|
||||
|
||||
from django.contrib.gis.geos import Point, Polygon
|
||||
from django.contrib.gis.measure import Distance
|
||||
from django.db.models import Count, Q, QuerySet
|
||||
from django.utils import timezone
|
||||
from datetime import timedelta
|
||||
|
||||
from .analytics import PageView
|
||||
from apps.parks.models import Park
|
||||
from apps.rides.models import Ride
|
||||
|
||||
from .analytics import PageView
|
||||
|
||||
|
||||
def unified_locations_for_map(
|
||||
*,
|
||||
bounds: Optional[Polygon] = None,
|
||||
location_types: Optional[List[str]] = None,
|
||||
filters: Optional[Dict[str, Any]] = None,
|
||||
) -> Dict[str, QuerySet]:
|
||||
bounds: Polygon | None = None,
|
||||
location_types: list[str] | None = None,
|
||||
filters: dict[str, Any] | None = None,
|
||||
) -> dict[str, QuerySet]:
|
||||
"""
|
||||
Get unified location data for map display across all location types.
|
||||
|
||||
@@ -88,9 +90,9 @@ def locations_near_point(
|
||||
*,
|
||||
point: Point,
|
||||
distance_km: float = 50,
|
||||
location_types: Optional[List[str]] = None,
|
||||
location_types: list[str] | None = None,
|
||||
limit: int = 20,
|
||||
) -> Dict[str, QuerySet]:
|
||||
) -> dict[str, QuerySet]:
|
||||
"""
|
||||
Get locations near a specific geographic point across all types.
|
||||
|
||||
@@ -149,7 +151,7 @@ def locations_near_point(
|
||||
return results
|
||||
|
||||
|
||||
def search_all_locations(*, query: str, limit: int = 20) -> Dict[str, QuerySet]:
|
||||
def search_all_locations(*, query: str, limit: int = 20) -> dict[str, QuerySet]:
|
||||
"""
|
||||
Search across all location types for a query string.
|
||||
|
||||
@@ -193,9 +195,9 @@ def search_all_locations(*, query: str, limit: int = 20) -> Dict[str, QuerySet]:
|
||||
|
||||
def page_views_for_analytics(
|
||||
*,
|
||||
start_date: Optional[timezone.datetime] = None,
|
||||
end_date: Optional[timezone.datetime] = None,
|
||||
path_pattern: Optional[str] = None,
|
||||
start_date: timezone.datetime | None = None,
|
||||
end_date: timezone.datetime | None = None,
|
||||
path_pattern: str | None = None,
|
||||
) -> QuerySet[PageView]:
|
||||
"""
|
||||
Get page views for analytics with optional filtering.
|
||||
@@ -222,7 +224,7 @@ def page_views_for_analytics(
|
||||
return queryset.order_by("-timestamp")
|
||||
|
||||
|
||||
def popular_pages_summary(*, days: int = 30) -> Dict[str, Any]:
|
||||
def popular_pages_summary(*, days: int = 30) -> dict[str, Any]:
|
||||
"""
|
||||
Get summary of most popular pages in the last N days.
|
||||
|
||||
@@ -261,7 +263,7 @@ def popular_pages_summary(*, days: int = 30) -> Dict[str, Any]:
|
||||
}
|
||||
|
||||
|
||||
def geographic_distribution_summary() -> Dict[str, Any]:
|
||||
def geographic_distribution_summary() -> dict[str, Any]:
|
||||
"""
|
||||
Get geographic distribution statistics for all locations.
|
||||
|
||||
@@ -290,7 +292,7 @@ def geographic_distribution_summary() -> Dict[str, Any]:
|
||||
}
|
||||
|
||||
|
||||
def system_health_metrics() -> Dict[str, Any]:
|
||||
def system_health_metrics() -> dict[str, Any]:
|
||||
"""
|
||||
Get system health and activity metrics.
|
||||
|
||||
|
||||
@@ -2,17 +2,17 @@
|
||||
Core services for ThrillWiki unified map functionality.
|
||||
"""
|
||||
|
||||
from .map_service import UnifiedMapService
|
||||
from .clustering_service import ClusteringService
|
||||
from .map_cache_service import MapCacheService
|
||||
from .data_structures import (
|
||||
UnifiedLocation,
|
||||
LocationType,
|
||||
ClusterData,
|
||||
GeoBounds,
|
||||
LocationType,
|
||||
MapFilters,
|
||||
MapResponse,
|
||||
ClusterData,
|
||||
UnifiedLocation,
|
||||
)
|
||||
from .map_cache_service import MapCacheService
|
||||
from .map_service import UnifiedMapService
|
||||
|
||||
__all__ = [
|
||||
"UnifiedMapService",
|
||||
|
||||
@@ -3,15 +3,15 @@ Clustering service for map locations to improve performance and user experience.
|
||||
"""
|
||||
|
||||
import math
|
||||
from typing import List, Dict, Any, Optional
|
||||
from dataclasses import dataclass
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from .data_structures import (
|
||||
UnifiedLocation,
|
||||
ClusterData,
|
||||
GeoBounds,
|
||||
LocationType,
|
||||
UnifiedLocation,
|
||||
)
|
||||
|
||||
|
||||
@@ -70,10 +70,10 @@ class ClusteringService:
|
||||
|
||||
def cluster_locations(
|
||||
self,
|
||||
locations: List[UnifiedLocation],
|
||||
locations: list[UnifiedLocation],
|
||||
zoom_level: int,
|
||||
bounds: Optional[GeoBounds] = None,
|
||||
) -> tuple[List[UnifiedLocation], List[ClusterData]]:
|
||||
bounds: GeoBounds | None = None,
|
||||
) -> tuple[list[UnifiedLocation], list[ClusterData]]:
|
||||
"""
|
||||
Cluster locations based on zoom level and density.
|
||||
Returns (unclustered_locations, clusters).
|
||||
@@ -115,9 +115,9 @@ class ClusteringService:
|
||||
|
||||
def _project_locations(
|
||||
self,
|
||||
locations: List[UnifiedLocation],
|
||||
bounds: Optional[GeoBounds] = None,
|
||||
) -> List[ClusterPoint]:
|
||||
locations: list[UnifiedLocation],
|
||||
bounds: GeoBounds | None = None,
|
||||
) -> list[ClusterPoint]:
|
||||
"""Convert lat/lng coordinates to projected x/y for clustering calculations."""
|
||||
cluster_points = []
|
||||
|
||||
@@ -149,8 +149,8 @@ class ClusteringService:
|
||||
return cluster_points
|
||||
|
||||
def _cluster_points(
|
||||
self, points: List[ClusterPoint], radius_pixels: int, min_points: int
|
||||
) -> List[List[ClusterPoint]]:
|
||||
self, points: list[ClusterPoint], radius_pixels: int, min_points: int
|
||||
) -> list[list[ClusterPoint]]:
|
||||
"""
|
||||
Cluster points using a simple distance-based approach.
|
||||
Radius is in pixels, converted to meters based on zoom level.
|
||||
@@ -189,7 +189,7 @@ class ClusteringService:
|
||||
dy = point1.y - point2.y
|
||||
return math.sqrt(dx * dx + dy * dy)
|
||||
|
||||
def _create_cluster(self, cluster_points: List[ClusterPoint]) -> ClusterData:
|
||||
def _create_cluster(self, cluster_points: list[ClusterPoint]) -> ClusterData:
|
||||
"""Create a ClusterData object from a group of points."""
|
||||
locations = [cp.location for cp in cluster_points]
|
||||
|
||||
@@ -205,7 +205,7 @@ class ClusteringService:
|
||||
)
|
||||
|
||||
# Collect location types in cluster
|
||||
types = set(loc.type for loc in locations)
|
||||
types = {loc.type for loc in locations}
|
||||
|
||||
# Select representative location (highest weight)
|
||||
representative = self._select_representative_location(locations)
|
||||
@@ -224,8 +224,8 @@ class ClusteringService:
|
||||
)
|
||||
|
||||
def _select_representative_location(
|
||||
self, locations: List[UnifiedLocation]
|
||||
) -> Optional[UnifiedLocation]:
|
||||
self, locations: list[UnifiedLocation]
|
||||
) -> UnifiedLocation | None:
|
||||
"""Select the most representative location for a cluster."""
|
||||
if not locations:
|
||||
return None
|
||||
@@ -259,7 +259,7 @@ class ClusteringService:
|
||||
# Fall back to highest weight location
|
||||
return max(locations, key=lambda x: x.cluster_weight)
|
||||
|
||||
def get_cluster_breakdown(self, clusters: List[ClusterData]) -> Dict[str, Any]:
|
||||
def get_cluster_breakdown(self, clusters: list[ClusterData]) -> dict[str, Any]:
|
||||
"""Get statistics about clustering results."""
|
||||
if not clusters:
|
||||
return {
|
||||
@@ -293,7 +293,7 @@ class ClusteringService:
|
||||
|
||||
def expand_cluster(
|
||||
self, cluster: ClusterData, zoom_level: int
|
||||
) -> List[UnifiedLocation]:
|
||||
) -> list[UnifiedLocation]:
|
||||
"""
|
||||
Expand a cluster to show individual locations (for drill-down functionality).
|
||||
This would typically require re-querying the database with the cluster bounds.
|
||||
@@ -335,7 +335,7 @@ class SmartClusteringRules:
|
||||
|
||||
@staticmethod
|
||||
def calculate_cluster_priority(
|
||||
locations: List[UnifiedLocation],
|
||||
locations: list[UnifiedLocation],
|
||||
) -> UnifiedLocation:
|
||||
"""Select the representative location for a cluster based on priority rules."""
|
||||
# Prioritize by: 1) Parks over rides, 2) Higher weight, 3) Better
|
||||
|
||||
@@ -4,7 +4,8 @@ Data structures for the unified map service.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Dict, List, Optional, Set, Any
|
||||
from typing import Any
|
||||
|
||||
from django.contrib.gis.geos import Polygon
|
||||
|
||||
|
||||
@@ -60,7 +61,7 @@ class GeoBounds:
|
||||
"""Check if a point is within these bounds."""
|
||||
return self.south <= lat <= self.north and self.west <= lng <= self.east
|
||||
|
||||
def to_dict(self) -> Dict[str, float]:
|
||||
def to_dict(self) -> dict[str, float]:
|
||||
"""Convert to dictionary for JSON serialization."""
|
||||
return {
|
||||
"north": self.north,
|
||||
@@ -74,18 +75,18 @@ class GeoBounds:
|
||||
class MapFilters:
|
||||
"""Filtering options for map queries."""
|
||||
|
||||
location_types: Optional[Set[LocationType]] = None
|
||||
park_status: Optional[Set[str]] = None # OPERATING, CLOSED_TEMP, etc.
|
||||
ride_types: Optional[Set[str]] = None
|
||||
company_roles: Optional[Set[str]] = None # OPERATOR, MANUFACTURER, etc.
|
||||
search_query: Optional[str] = None
|
||||
min_rating: Optional[float] = None
|
||||
location_types: set[LocationType] | None = None
|
||||
park_status: set[str] | None = None # OPERATING, CLOSED_TEMP, etc.
|
||||
ride_types: set[str] | None = None
|
||||
company_roles: set[str] | None = None # OPERATOR, MANUFACTURER, etc.
|
||||
search_query: str | None = None
|
||||
min_rating: float | None = None
|
||||
has_coordinates: bool = True
|
||||
country: Optional[str] = None
|
||||
state: Optional[str] = None
|
||||
city: Optional[str] = None
|
||||
country: str | None = None
|
||||
state: str | None = None
|
||||
city: str | None = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary for caching and serialization."""
|
||||
return {
|
||||
"location_types": (
|
||||
@@ -110,10 +111,10 @@ class UnifiedLocation:
|
||||
id: str # Composite: f"{type}_{id}"
|
||||
type: LocationType
|
||||
name: str
|
||||
coordinates: List[float] # [lat, lng]
|
||||
address: Optional[str] = None
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
type_data: Dict[str, Any] = field(default_factory=dict)
|
||||
coordinates: list[float] # [lat, lng]
|
||||
address: str | None = None
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
type_data: dict[str, Any] = field(default_factory=dict)
|
||||
cluster_weight: int = 1
|
||||
cluster_category: str = "default"
|
||||
|
||||
@@ -127,7 +128,7 @@ class UnifiedLocation:
|
||||
"""Get longitude from coordinates."""
|
||||
return self.coordinates[1]
|
||||
|
||||
def to_geojson_feature(self) -> Dict[str, Any]:
|
||||
def to_geojson_feature(self) -> dict[str, Any]:
|
||||
"""Convert to GeoJSON feature for mapping libraries."""
|
||||
return {
|
||||
"type": "Feature",
|
||||
@@ -148,7 +149,7 @@ class UnifiedLocation:
|
||||
},
|
||||
}
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary for JSON responses."""
|
||||
return {
|
||||
"id": self.id,
|
||||
@@ -168,13 +169,13 @@ class ClusterData:
|
||||
"""Represents a cluster of locations for map display."""
|
||||
|
||||
id: str
|
||||
coordinates: List[float] # [lat, lng]
|
||||
coordinates: list[float] # [lat, lng]
|
||||
count: int
|
||||
types: Set[LocationType]
|
||||
types: set[LocationType]
|
||||
bounds: GeoBounds
|
||||
representative_location: Optional[UnifiedLocation] = None
|
||||
representative_location: UnifiedLocation | None = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary for JSON responses."""
|
||||
return {
|
||||
"id": self.id,
|
||||
@@ -194,18 +195,18 @@ class ClusterData:
|
||||
class MapResponse:
|
||||
"""Response structure for map API calls."""
|
||||
|
||||
locations: List[UnifiedLocation] = field(default_factory=list)
|
||||
clusters: List[ClusterData] = field(default_factory=list)
|
||||
bounds: Optional[GeoBounds] = None
|
||||
locations: list[UnifiedLocation] = field(default_factory=list)
|
||||
clusters: list[ClusterData] = field(default_factory=list)
|
||||
bounds: GeoBounds | None = None
|
||||
total_count: int = 0
|
||||
filtered_count: int = 0
|
||||
zoom_level: Optional[int] = None
|
||||
zoom_level: int | None = None
|
||||
clustered: bool = False
|
||||
cache_hit: bool = False
|
||||
query_time_ms: Optional[int] = None
|
||||
filters_applied: List[str] = field(default_factory=list)
|
||||
query_time_ms: int | None = None
|
||||
filters_applied: list[str] = field(default_factory=list)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary for JSON responses."""
|
||||
return {
|
||||
"status": "success",
|
||||
@@ -241,7 +242,7 @@ class QueryPerformanceMetrics:
|
||||
bounds_used: bool
|
||||
clustering_used: bool
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary for logging."""
|
||||
return {
|
||||
"query_time_ms": self.query_time_ms,
|
||||
|
||||
@@ -2,13 +2,15 @@
|
||||
Enhanced caching service with multiple cache backends and strategies.
|
||||
"""
|
||||
|
||||
from typing import Optional, Any, Dict, Callable
|
||||
from django.core.cache import caches
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import Any
|
||||
|
||||
from django.core.cache import caches
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -64,7 +66,7 @@ class EnhancedCacheService:
|
||||
def cache_api_response(
|
||||
self,
|
||||
view_name: str,
|
||||
params: Dict,
|
||||
params: dict,
|
||||
response_data: Any,
|
||||
timeout: int = 1800,
|
||||
):
|
||||
@@ -73,7 +75,7 @@ class EnhancedCacheService:
|
||||
self.api_cache.set(cache_key, response_data, timeout)
|
||||
logger.debug(f"Cached API response for view '{view_name}'")
|
||||
|
||||
def get_cached_api_response(self, view_name: str, params: Dict) -> Optional[Any]:
|
||||
def get_cached_api_response(self, view_name: str, params: dict) -> Any | None:
|
||||
"""Retrieve cached API response"""
|
||||
cache_key = self._generate_api_cache_key(view_name, params)
|
||||
result = self.api_cache.get(cache_key)
|
||||
@@ -103,7 +105,7 @@ class EnhancedCacheService:
|
||||
|
||||
def get_cached_geographic_data(
|
||||
self, bounds: "GeoBounds", zoom_level: int
|
||||
) -> Optional[Any]:
|
||||
) -> Any | None:
|
||||
"""Retrieve cached geographic data"""
|
||||
cache_key = f"geo:{bounds.min_lat}:{bounds.min_lng}:{bounds.max_lat}:{
|
||||
bounds.max_lng
|
||||
@@ -129,13 +131,10 @@ class EnhancedCacheService:
|
||||
logger.error(f"Error invalidating cache pattern '{pattern}': {e}")
|
||||
|
||||
def invalidate_model_cache(
|
||||
self, model_name: str, instance_id: Optional[int] = None
|
||||
self, model_name: str, instance_id: int | None = None
|
||||
):
|
||||
"""Invalidate cache keys related to a specific model"""
|
||||
if instance_id:
|
||||
pattern = f"*{model_name}:{instance_id}*"
|
||||
else:
|
||||
pattern = f"*{model_name}*"
|
||||
pattern = f"*{model_name}:{instance_id}*" if instance_id else f"*{model_name}*"
|
||||
|
||||
self.invalidate_pattern(pattern)
|
||||
|
||||
@@ -155,7 +154,7 @@ class EnhancedCacheService:
|
||||
except Exception as e:
|
||||
logger.error(f"Error warming cache for key '{cache_key}': {e}")
|
||||
|
||||
def _generate_api_cache_key(self, view_name: str, params: Dict) -> str:
|
||||
def _generate_api_cache_key(self, view_name: str, params: dict) -> str:
|
||||
"""Generate consistent cache keys for API responses"""
|
||||
# Sort params to ensure consistent key generation
|
||||
params_str = json.dumps(params, sort_keys=True, default=str)
|
||||
@@ -275,7 +274,7 @@ class CacheMonitor:
|
||||
def __init__(self):
|
||||
self.cache_service = EnhancedCacheService()
|
||||
|
||||
def get_cache_stats(self) -> Dict[str, Any]:
|
||||
def get_cache_stats(self) -> dict[str, Any]:
|
||||
"""Get cache statistics if available"""
|
||||
stats = {}
|
||||
|
||||
@@ -319,7 +318,7 @@ class CacheMonitor:
|
||||
if stats:
|
||||
logger.info("Cache performance statistics", extra=stats)
|
||||
|
||||
def get_cache_statistics(self, key_prefix: str = "") -> Dict[str, Any]:
|
||||
def get_cache_statistics(self, key_prefix: str = "") -> dict[str, Any]:
|
||||
"""
|
||||
Get cache statistics for a given key prefix.
|
||||
|
||||
|
||||
@@ -13,16 +13,15 @@ Features:
|
||||
"""
|
||||
|
||||
import re
|
||||
from difflib import SequenceMatcher
|
||||
from typing import List, Dict, Any, Optional
|
||||
from dataclasses import dataclass
|
||||
from difflib import SequenceMatcher
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from django.db.models import Q
|
||||
|
||||
from apps.parks.models import Park
|
||||
from apps.parks.models import Company, Park
|
||||
from apps.rides.models import Ride
|
||||
from apps.parks.models import Company
|
||||
|
||||
|
||||
class EntityType(Enum):
|
||||
@@ -44,9 +43,9 @@ class FuzzyMatchResult:
|
||||
score: float # 0.0 to 1.0, higher is better match
|
||||
match_reason: str # Description of why this was matched
|
||||
confidence: str # 'high', 'medium', 'low'
|
||||
url: Optional[str] = None
|
||||
url: str | None = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary for API responses."""
|
||||
return {
|
||||
"entity_type": self.entity_type.value,
|
||||
@@ -180,8 +179,8 @@ class EntityFuzzyMatcher:
|
||||
self.algorithms = FuzzyMatchingAlgorithms()
|
||||
|
||||
def find_entity(
|
||||
self, query: str, entity_types: Optional[List[EntityType]] = None, user=None
|
||||
) -> tuple[List[FuzzyMatchResult], Optional[EntitySuggestion]]:
|
||||
self, query: str, entity_types: list[EntityType] | None = None, user=None
|
||||
) -> tuple[list[FuzzyMatchResult], EntitySuggestion | None]:
|
||||
"""
|
||||
Find entities matching the query with fuzzy matching.
|
||||
|
||||
@@ -221,7 +220,7 @@ class EntityFuzzyMatcher:
|
||||
|
||||
def _get_candidates(
|
||||
self, query: str, entity_type: EntityType
|
||||
) -> List[Dict[str, Any]]:
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Get potential matching candidates for an entity type."""
|
||||
candidates = []
|
||||
|
||||
@@ -286,8 +285,8 @@ class EntityFuzzyMatcher:
|
||||
return candidates
|
||||
|
||||
def _score_and_rank_candidates(
|
||||
self, query: str, candidates: List[Dict[str, Any]]
|
||||
) -> List[FuzzyMatchResult]:
|
||||
self, query: str, candidates: list[dict[str, Any]]
|
||||
) -> list[FuzzyMatchResult]:
|
||||
"""Score and rank all candidates using multiple algorithms."""
|
||||
scored_matches = []
|
||||
|
||||
@@ -356,7 +355,7 @@ class EntityFuzzyMatcher:
|
||||
return sorted(scored_matches, key=lambda x: x.score, reverse=True)
|
||||
|
||||
def _generate_entity_suggestion(
|
||||
self, query: str, entity_types: List[EntityType], user
|
||||
self, query: str, entity_types: list[EntityType], user
|
||||
) -> EntitySuggestion:
|
||||
"""Generate suggestion for creating new entity when no matches found."""
|
||||
|
||||
|
||||
@@ -2,36 +2,37 @@
|
||||
Location adapters for converting between domain-specific models and UnifiedLocation.
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from django.db.models import QuerySet
|
||||
from django.urls import reverse
|
||||
|
||||
from .data_structures import (
|
||||
UnifiedLocation,
|
||||
LocationType,
|
||||
GeoBounds,
|
||||
MapFilters,
|
||||
)
|
||||
from apps.parks.models import ParkLocation, CompanyHeadquarters
|
||||
from apps.parks.models import CompanyHeadquarters, ParkLocation
|
||||
from apps.rides.models import RideLocation
|
||||
|
||||
from .data_structures import (
|
||||
GeoBounds,
|
||||
LocationType,
|
||||
MapFilters,
|
||||
UnifiedLocation,
|
||||
)
|
||||
|
||||
|
||||
class BaseLocationAdapter:
|
||||
"""Base adapter class for location conversions."""
|
||||
|
||||
def to_unified_location(self, location_obj) -> Optional[UnifiedLocation]:
|
||||
def to_unified_location(self, location_obj) -> UnifiedLocation | None:
|
||||
"""Convert model instance to UnifiedLocation."""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_queryset(
|
||||
self,
|
||||
bounds: Optional[GeoBounds] = None,
|
||||
filters: Optional[MapFilters] = None,
|
||||
bounds: GeoBounds | None = None,
|
||||
filters: MapFilters | None = None,
|
||||
) -> QuerySet:
|
||||
"""Get optimized queryset for this location type."""
|
||||
raise NotImplementedError
|
||||
|
||||
def bulk_convert(self, queryset: QuerySet) -> List[UnifiedLocation]:
|
||||
def bulk_convert(self, queryset: QuerySet) -> list[UnifiedLocation]:
|
||||
"""Convert multiple location objects efficiently."""
|
||||
unified_locations = []
|
||||
for obj in queryset:
|
||||
@@ -46,7 +47,7 @@ class ParkLocationAdapter(BaseLocationAdapter):
|
||||
|
||||
def to_unified_location(
|
||||
self, location_obj: ParkLocation
|
||||
) -> Optional[UnifiedLocation]:
|
||||
) -> UnifiedLocation | None:
|
||||
"""Convert ParkLocation to UnifiedLocation."""
|
||||
if (
|
||||
not location_obj.point
|
||||
@@ -106,8 +107,8 @@ class ParkLocationAdapter(BaseLocationAdapter):
|
||||
|
||||
def get_queryset(
|
||||
self,
|
||||
bounds: Optional[GeoBounds] = None,
|
||||
filters: Optional[MapFilters] = None,
|
||||
bounds: GeoBounds | None = None,
|
||||
filters: MapFilters | None = None,
|
||||
) -> QuerySet:
|
||||
"""Get optimized queryset for park locations."""
|
||||
queryset = ParkLocation.objects.select_related("park", "park__operator").filter(
|
||||
@@ -177,7 +178,7 @@ class RideLocationAdapter(BaseLocationAdapter):
|
||||
|
||||
def to_unified_location(
|
||||
self, location_obj: RideLocation
|
||||
) -> Optional[UnifiedLocation]:
|
||||
) -> UnifiedLocation | None:
|
||||
"""Convert RideLocation to UnifiedLocation."""
|
||||
if (
|
||||
not location_obj.point
|
||||
@@ -235,8 +236,8 @@ class RideLocationAdapter(BaseLocationAdapter):
|
||||
|
||||
def get_queryset(
|
||||
self,
|
||||
bounds: Optional[GeoBounds] = None,
|
||||
filters: Optional[MapFilters] = None,
|
||||
bounds: GeoBounds | None = None,
|
||||
filters: MapFilters | None = None,
|
||||
) -> QuerySet:
|
||||
"""Get optimized queryset for ride locations."""
|
||||
queryset = RideLocation.objects.select_related(
|
||||
@@ -293,7 +294,7 @@ class CompanyLocationAdapter(BaseLocationAdapter):
|
||||
|
||||
def to_unified_location(
|
||||
self, location_obj: CompanyHeadquarters
|
||||
) -> Optional[UnifiedLocation]:
|
||||
) -> UnifiedLocation | None:
|
||||
"""Convert CompanyHeadquarters to UnifiedLocation."""
|
||||
# Note: CompanyHeadquarters doesn't have coordinates, so we need to geocode
|
||||
# For now, we'll skip companies without coordinates
|
||||
@@ -302,8 +303,8 @@ class CompanyLocationAdapter(BaseLocationAdapter):
|
||||
|
||||
def get_queryset(
|
||||
self,
|
||||
bounds: Optional[GeoBounds] = None,
|
||||
filters: Optional[MapFilters] = None,
|
||||
bounds: GeoBounds | None = None,
|
||||
filters: MapFilters | None = None,
|
||||
) -> QuerySet:
|
||||
"""Get optimized queryset for company locations."""
|
||||
queryset = CompanyHeadquarters.objects.select_related("company")
|
||||
@@ -346,9 +347,9 @@ class LocationAbstractionLayer:
|
||||
|
||||
def get_all_locations(
|
||||
self,
|
||||
bounds: Optional[GeoBounds] = None,
|
||||
filters: Optional[MapFilters] = None,
|
||||
) -> List[UnifiedLocation]:
|
||||
bounds: GeoBounds | None = None,
|
||||
filters: MapFilters | None = None,
|
||||
) -> list[UnifiedLocation]:
|
||||
"""Get locations from all sources within bounds."""
|
||||
all_locations = []
|
||||
|
||||
@@ -370,9 +371,9 @@ class LocationAbstractionLayer:
|
||||
def get_locations_by_type(
|
||||
self,
|
||||
location_type: LocationType,
|
||||
bounds: Optional[GeoBounds] = None,
|
||||
filters: Optional[MapFilters] = None,
|
||||
) -> List[UnifiedLocation]:
|
||||
bounds: GeoBounds | None = None,
|
||||
filters: MapFilters | None = None,
|
||||
) -> list[UnifiedLocation]:
|
||||
"""Get locations of specific type."""
|
||||
adapter = self.adapters[location_type]
|
||||
queryset = adapter.get_queryset(bounds, filters)
|
||||
@@ -380,7 +381,7 @@ class LocationAbstractionLayer:
|
||||
|
||||
def get_location_by_id(
|
||||
self, location_type: LocationType, location_id: int
|
||||
) -> Optional[UnifiedLocation]:
|
||||
) -> UnifiedLocation | None:
|
||||
"""Get single location with full details."""
|
||||
adapter = self.adapters[location_type]
|
||||
|
||||
|
||||
@@ -6,13 +6,14 @@ to provide proximity-based search, location filtering, and geographic
|
||||
search capabilities.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from django.contrib.gis.geos import Point
|
||||
from django.contrib.gis.measure import Distance
|
||||
from django.db.models import Q
|
||||
from typing import Optional, List, Dict, Any, Set
|
||||
from dataclasses import dataclass
|
||||
|
||||
from apps.parks.models import Park, Company, ParkLocation
|
||||
from apps.parks.models import Company, Park, ParkLocation
|
||||
from apps.rides.models import Ride
|
||||
|
||||
|
||||
@@ -21,22 +22,22 @@ class LocationSearchFilters:
|
||||
"""Filters for location-aware search queries."""
|
||||
|
||||
# Text search
|
||||
search_query: Optional[str] = None
|
||||
search_query: str | None = None
|
||||
|
||||
# Location-based filters
|
||||
location_point: Optional[Point] = None
|
||||
radius_km: Optional[float] = None
|
||||
location_types: Optional[Set[str]] = None # 'park', 'ride', 'company'
|
||||
location_point: Point | None = None
|
||||
radius_km: float | None = None
|
||||
location_types: set[str] | None = None # 'park', 'ride', 'company'
|
||||
|
||||
# Geographic filters
|
||||
country: Optional[str] = None
|
||||
state: Optional[str] = None
|
||||
city: Optional[str] = None
|
||||
country: str | None = None
|
||||
state: str | None = None
|
||||
city: str | None = None
|
||||
|
||||
# Content-specific filters
|
||||
park_status: Optional[List[str]] = None
|
||||
ride_types: Optional[List[str]] = None
|
||||
company_roles: Optional[List[str]] = None
|
||||
park_status: list[str] | None = None
|
||||
ride_types: list[str] | None = None
|
||||
company_roles: list[str] | None = None
|
||||
|
||||
# Result options
|
||||
include_distance: bool = True
|
||||
@@ -51,26 +52,26 @@ class LocationSearchResult:
|
||||
content_type: str # 'park', 'ride', 'company'
|
||||
object_id: int
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
url: Optional[str] = None
|
||||
description: str | None = None
|
||||
url: str | None = None
|
||||
|
||||
# Location data
|
||||
latitude: Optional[float] = None
|
||||
longitude: Optional[float] = None
|
||||
address: Optional[str] = None
|
||||
city: Optional[str] = None
|
||||
state: Optional[str] = None
|
||||
country: Optional[str] = None
|
||||
latitude: float | None = None
|
||||
longitude: float | None = None
|
||||
address: str | None = None
|
||||
city: str | None = None
|
||||
state: str | None = None
|
||||
country: str | None = None
|
||||
|
||||
# Distance data (if proximity search)
|
||||
distance_km: Optional[float] = None
|
||||
distance_km: float | None = None
|
||||
|
||||
# Additional metadata
|
||||
status: Optional[str] = None
|
||||
tags: Optional[List[str]] = None
|
||||
rating: Optional[float] = None
|
||||
status: str | None = None
|
||||
tags: list[str] | None = None
|
||||
rating: float | None = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Convert to dictionary for JSON serialization."""
|
||||
return {
|
||||
"content_type": self.content_type,
|
||||
@@ -96,7 +97,7 @@ class LocationSearchResult:
|
||||
class LocationSearchService:
|
||||
"""Service for performing location-aware searches across ThrillWiki content."""
|
||||
|
||||
def search(self, filters: LocationSearchFilters) -> List[LocationSearchResult]:
|
||||
def search(self, filters: LocationSearchFilters) -> list[LocationSearchResult]:
|
||||
"""
|
||||
Perform a comprehensive location-aware search.
|
||||
|
||||
@@ -129,7 +130,7 @@ class LocationSearchService:
|
||||
|
||||
def _search_parks(
|
||||
self, filters: LocationSearchFilters
|
||||
) -> List[LocationSearchResult]:
|
||||
) -> list[LocationSearchResult]:
|
||||
"""Search parks with location data."""
|
||||
queryset = Park.objects.select_related("location", "operator").all()
|
||||
|
||||
@@ -199,7 +200,7 @@ class LocationSearchService:
|
||||
|
||||
def _search_rides(
|
||||
self, filters: LocationSearchFilters
|
||||
) -> List[LocationSearchResult]:
|
||||
) -> list[LocationSearchResult]:
|
||||
"""Search rides with location data."""
|
||||
queryset = Ride.objects.select_related("park", "location").all()
|
||||
|
||||
@@ -282,7 +283,7 @@ class LocationSearchService:
|
||||
|
||||
def _search_companies(
|
||||
self, filters: LocationSearchFilters
|
||||
) -> List[LocationSearchResult]:
|
||||
) -> list[LocationSearchResult]:
|
||||
"""Search companies with headquarters location data."""
|
||||
queryset = Company.objects.select_related("headquarters").all()
|
||||
|
||||
@@ -398,7 +399,7 @@ class LocationSearchService:
|
||||
|
||||
return queryset
|
||||
|
||||
def suggest_locations(self, query: str, limit: int = 10) -> List[Dict[str, Any]]:
|
||||
def suggest_locations(self, query: str, limit: int = 10) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Get location suggestions for autocomplete.
|
||||
|
||||
|
||||
@@ -5,18 +5,18 @@ Caching service for map data to improve performance and reduce database load.
|
||||
import hashlib
|
||||
import json
|
||||
import time
|
||||
from typing import Dict, List, Optional, Any
|
||||
from typing import Any
|
||||
|
||||
from django.core.cache import cache
|
||||
from django.utils import timezone
|
||||
|
||||
from .data_structures import (
|
||||
UnifiedLocation,
|
||||
ClusterData,
|
||||
GeoBounds,
|
||||
MapFilters,
|
||||
MapResponse,
|
||||
QueryPerformanceMetrics,
|
||||
UnifiedLocation,
|
||||
)
|
||||
|
||||
|
||||
@@ -52,9 +52,9 @@ class MapCacheService:
|
||||
|
||||
def get_locations_cache_key(
|
||||
self,
|
||||
bounds: Optional[GeoBounds],
|
||||
filters: Optional[MapFilters],
|
||||
zoom_level: Optional[int] = None,
|
||||
bounds: GeoBounds | None,
|
||||
filters: MapFilters | None,
|
||||
zoom_level: int | None = None,
|
||||
) -> str:
|
||||
"""Generate cache key for location queries."""
|
||||
key_parts = [self.LOCATIONS_PREFIX]
|
||||
@@ -76,8 +76,8 @@ class MapCacheService:
|
||||
|
||||
def get_clusters_cache_key(
|
||||
self,
|
||||
bounds: Optional[GeoBounds],
|
||||
filters: Optional[MapFilters],
|
||||
bounds: GeoBounds | None,
|
||||
filters: MapFilters | None,
|
||||
zoom_level: int,
|
||||
) -> str:
|
||||
"""Generate cache key for cluster queries."""
|
||||
@@ -102,8 +102,8 @@ class MapCacheService:
|
||||
def cache_locations(
|
||||
self,
|
||||
cache_key: str,
|
||||
locations: List[UnifiedLocation],
|
||||
ttl: Optional[int] = None,
|
||||
locations: list[UnifiedLocation],
|
||||
ttl: int | None = None,
|
||||
) -> None:
|
||||
"""Cache location data."""
|
||||
try:
|
||||
@@ -122,8 +122,8 @@ class MapCacheService:
|
||||
def cache_clusters(
|
||||
self,
|
||||
cache_key: str,
|
||||
clusters: List[ClusterData],
|
||||
ttl: Optional[int] = None,
|
||||
clusters: list[ClusterData],
|
||||
ttl: int | None = None,
|
||||
) -> None:
|
||||
"""Cache cluster data."""
|
||||
try:
|
||||
@@ -138,7 +138,7 @@ class MapCacheService:
|
||||
print(f"Cache write error for clusters {cache_key}: {e}")
|
||||
|
||||
def cache_map_response(
|
||||
self, cache_key: str, response: MapResponse, ttl: Optional[int] = None
|
||||
self, cache_key: str, response: MapResponse, ttl: int | None = None
|
||||
) -> None:
|
||||
"""Cache complete map response."""
|
||||
try:
|
||||
@@ -149,7 +149,7 @@ class MapCacheService:
|
||||
except Exception as e:
|
||||
print(f"Cache write error for response {cache_key}: {e}")
|
||||
|
||||
def get_cached_locations(self, cache_key: str) -> Optional[List[UnifiedLocation]]:
|
||||
def get_cached_locations(self, cache_key: str) -> list[UnifiedLocation] | None:
|
||||
"""Retrieve cached location data."""
|
||||
try:
|
||||
cache_data = cache.get(cache_key)
|
||||
@@ -172,7 +172,7 @@ class MapCacheService:
|
||||
self.cache_stats["misses"] += 1
|
||||
return None
|
||||
|
||||
def get_cached_clusters(self, cache_key: str) -> Optional[List[ClusterData]]:
|
||||
def get_cached_clusters(self, cache_key: str) -> list[ClusterData] | None:
|
||||
"""Retrieve cached cluster data."""
|
||||
try:
|
||||
cache_data = cache.get(cache_key)
|
||||
@@ -194,7 +194,7 @@ class MapCacheService:
|
||||
self.cache_stats["misses"] += 1
|
||||
return None
|
||||
|
||||
def get_cached_map_response(self, cache_key: str) -> Optional[MapResponse]:
|
||||
def get_cached_map_response(self, cache_key: str) -> MapResponse | None:
|
||||
"""Retrieve cached map response."""
|
||||
try:
|
||||
cache_data = cache.get(cache_key)
|
||||
@@ -213,7 +213,7 @@ class MapCacheService:
|
||||
return None
|
||||
|
||||
def invalidate_location_cache(
|
||||
self, location_type: str, location_id: Optional[int] = None
|
||||
self, location_type: str, location_id: int | None = None
|
||||
) -> None:
|
||||
"""Invalidate cache for specific location or all locations of a type."""
|
||||
try:
|
||||
@@ -268,7 +268,7 @@ class MapCacheService:
|
||||
except Exception as e:
|
||||
print(f"Cache clear error: {e}")
|
||||
|
||||
def get_cache_stats(self) -> Dict[str, Any]:
|
||||
def get_cache_stats(self) -> dict[str, Any]:
|
||||
"""Get cache performance statistics."""
|
||||
total_requests = self.cache_stats["hits"] + self.cache_stats["misses"]
|
||||
hit_rate = (
|
||||
@@ -370,7 +370,7 @@ class MapCacheService:
|
||||
filter_str = json.dumps(filter_dict, sort_keys=True)
|
||||
return hashlib.md5(filter_str.encode()).hexdigest()[:8]
|
||||
|
||||
def _dict_to_unified_location(self, data: Dict[str, Any]) -> UnifiedLocation:
|
||||
def _dict_to_unified_location(self, data: dict[str, Any]) -> UnifiedLocation:
|
||||
"""Convert dictionary back to UnifiedLocation object."""
|
||||
from .data_structures import LocationType
|
||||
|
||||
@@ -386,7 +386,7 @@ class MapCacheService:
|
||||
cluster_category=data.get("cluster_category", "default"),
|
||||
)
|
||||
|
||||
def _dict_to_cluster_data(self, data: Dict[str, Any]) -> ClusterData:
|
||||
def _dict_to_cluster_data(self, data: dict[str, Any]) -> ClusterData:
|
||||
"""Convert dictionary back to ClusterData object."""
|
||||
from .data_structures import LocationType
|
||||
|
||||
@@ -406,7 +406,7 @@ class MapCacheService:
|
||||
representative_location=representative,
|
||||
)
|
||||
|
||||
def _dict_to_map_response(self, data: Dict[str, Any]) -> MapResponse:
|
||||
def _dict_to_map_response(self, data: dict[str, Any]) -> MapResponse:
|
||||
"""Convert dictionary back to MapResponse object."""
|
||||
locations = [
|
||||
self._dict_to_unified_location(loc) for loc in data.get("locations", [])
|
||||
|
||||
@@ -3,20 +3,21 @@ Unified Map Service - Main orchestrating service for all map functionality.
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import List, Optional, Dict, Any, Set
|
||||
from typing import Any
|
||||
|
||||
from django.db import connection
|
||||
|
||||
from .clustering_service import ClusteringService
|
||||
from .data_structures import (
|
||||
UnifiedLocation,
|
||||
ClusterData,
|
||||
GeoBounds,
|
||||
LocationType,
|
||||
MapFilters,
|
||||
MapResponse,
|
||||
LocationType,
|
||||
QueryPerformanceMetrics,
|
||||
UnifiedLocation,
|
||||
)
|
||||
from .location_adapters import LocationAbstractionLayer
|
||||
from .clustering_service import ClusteringService
|
||||
from .map_cache_service import MapCacheService
|
||||
|
||||
|
||||
@@ -39,8 +40,8 @@ class UnifiedMapService:
|
||||
def get_map_data(
|
||||
self,
|
||||
*,
|
||||
bounds: Optional[GeoBounds] = None,
|
||||
filters: Optional[MapFilters] = None,
|
||||
bounds: GeoBounds | None = None,
|
||||
filters: MapFilters | None = None,
|
||||
zoom_level: int = DEFAULT_ZOOM_LEVEL,
|
||||
cluster: bool = True,
|
||||
use_cache: bool = True,
|
||||
@@ -145,7 +146,7 @@ class UnifiedMapService:
|
||||
|
||||
def get_location_details(
|
||||
self, location_type: str, location_id: int
|
||||
) -> Optional[UnifiedLocation]:
|
||||
) -> UnifiedLocation | None:
|
||||
"""
|
||||
Get detailed information for a specific location.
|
||||
|
||||
@@ -188,10 +189,10 @@ class UnifiedMapService:
|
||||
def search_locations(
|
||||
self,
|
||||
query: str,
|
||||
bounds: Optional[GeoBounds] = None,
|
||||
location_types: Optional[Set[LocationType]] = None,
|
||||
bounds: GeoBounds | None = None,
|
||||
location_types: set[LocationType] | None = None,
|
||||
limit: int = 50,
|
||||
) -> List[UnifiedLocation]:
|
||||
) -> list[UnifiedLocation]:
|
||||
"""
|
||||
Search locations with text query.
|
||||
|
||||
@@ -228,7 +229,7 @@ class UnifiedMapService:
|
||||
south: float,
|
||||
east: float,
|
||||
west: float,
|
||||
location_types: Optional[Set[LocationType]] = None,
|
||||
location_types: set[LocationType] | None = None,
|
||||
zoom_level: int = DEFAULT_ZOOM_LEVEL,
|
||||
) -> MapResponse:
|
||||
"""
|
||||
@@ -261,8 +262,8 @@ class UnifiedMapService:
|
||||
def get_clustered_locations(
|
||||
self,
|
||||
zoom_level: int,
|
||||
bounds: Optional[GeoBounds] = None,
|
||||
filters: Optional[MapFilters] = None,
|
||||
bounds: GeoBounds | None = None,
|
||||
filters: MapFilters | None = None,
|
||||
) -> MapResponse:
|
||||
"""
|
||||
Get clustered location data for map display.
|
||||
@@ -282,9 +283,9 @@ class UnifiedMapService:
|
||||
def get_locations_by_type(
|
||||
self,
|
||||
location_type: LocationType,
|
||||
bounds: Optional[GeoBounds] = None,
|
||||
limit: Optional[int] = None,
|
||||
) -> List[UnifiedLocation]:
|
||||
bounds: GeoBounds | None = None,
|
||||
limit: int | None = None,
|
||||
) -> list[UnifiedLocation]:
|
||||
"""
|
||||
Get locations of a specific type.
|
||||
|
||||
@@ -313,9 +314,9 @@ class UnifiedMapService:
|
||||
|
||||
def invalidate_cache(
|
||||
self,
|
||||
location_type: Optional[str] = None,
|
||||
location_id: Optional[int] = None,
|
||||
bounds: Optional[GeoBounds] = None,
|
||||
location_type: str | None = None,
|
||||
location_id: int | None = None,
|
||||
bounds: GeoBounds | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Invalidate cached map data.
|
||||
@@ -332,7 +333,7 @@ class UnifiedMapService:
|
||||
else:
|
||||
self.cache_service.clear_all_map_cache()
|
||||
|
||||
def get_service_stats(self) -> Dict[str, Any]:
|
||||
def get_service_stats(self) -> dict[str, Any]:
|
||||
"""Get service performance and usage statistics."""
|
||||
cache_stats = self.cache_service.get_cache_stats()
|
||||
|
||||
@@ -346,17 +347,17 @@ class UnifiedMapService:
|
||||
}
|
||||
|
||||
def _get_locations_from_db(
|
||||
self, bounds: Optional[GeoBounds], filters: Optional[MapFilters]
|
||||
) -> List[UnifiedLocation]:
|
||||
self, bounds: GeoBounds | None, filters: MapFilters | None
|
||||
) -> list[UnifiedLocation]:
|
||||
"""Get locations from database using the abstraction layer."""
|
||||
return self.location_layer.get_all_locations(bounds, filters)
|
||||
|
||||
def _apply_smart_limiting(
|
||||
self,
|
||||
locations: List[UnifiedLocation],
|
||||
bounds: Optional[GeoBounds],
|
||||
locations: list[UnifiedLocation],
|
||||
bounds: GeoBounds | None,
|
||||
zoom_level: int,
|
||||
) -> List[UnifiedLocation]:
|
||||
) -> list[UnifiedLocation]:
|
||||
"""Apply intelligent limiting based on zoom level and density."""
|
||||
if zoom_level < 6: # Very zoomed out - show only major parks
|
||||
major_parks = [
|
||||
@@ -375,10 +376,10 @@ class UnifiedMapService:
|
||||
|
||||
def _calculate_response_bounds(
|
||||
self,
|
||||
locations: List[UnifiedLocation],
|
||||
clusters: List[ClusterData],
|
||||
request_bounds: Optional[GeoBounds],
|
||||
) -> Optional[GeoBounds]:
|
||||
locations: list[UnifiedLocation],
|
||||
clusters: list[ClusterData],
|
||||
request_bounds: GeoBounds | None,
|
||||
) -> GeoBounds | None:
|
||||
"""Calculate the actual bounds of the response data."""
|
||||
if request_bounds:
|
||||
return request_bounds
|
||||
@@ -396,12 +397,12 @@ class UnifiedMapService:
|
||||
if not all_coords:
|
||||
return None
|
||||
|
||||
lats, lngs = zip(*all_coords)
|
||||
lats, lngs = zip(*all_coords, strict=False)
|
||||
return GeoBounds(
|
||||
north=max(lats), south=min(lats), east=max(lngs), west=min(lngs)
|
||||
)
|
||||
|
||||
def _get_applied_filters_list(self, filters: Optional[MapFilters]) -> List[str]:
|
||||
def _get_applied_filters_list(self, filters: MapFilters | None) -> list[str]:
|
||||
"""Get list of applied filter types for metadata."""
|
||||
if not filters:
|
||||
return []
|
||||
@@ -430,8 +431,8 @@ class UnifiedMapService:
|
||||
|
||||
def _generate_cache_key(
|
||||
self,
|
||||
bounds: Optional[GeoBounds],
|
||||
filters: Optional[MapFilters],
|
||||
bounds: GeoBounds | None,
|
||||
filters: MapFilters | None,
|
||||
zoom_level: int,
|
||||
cluster: bool,
|
||||
) -> str:
|
||||
|
||||
@@ -6,12 +6,13 @@ that can be used across all domain-specific media implementations.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any, Optional, Dict
|
||||
from datetime import datetime
|
||||
from django.core.files.uploadedfile import UploadedFile
|
||||
from django.conf import settings
|
||||
from PIL import Image, ExifTags
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from django.conf import settings
|
||||
from django.core.files.uploadedfile import UploadedFile
|
||||
from PIL import ExifTags, Image
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -21,7 +22,7 @@ class MediaService:
|
||||
|
||||
@staticmethod
|
||||
def generate_upload_path(
|
||||
domain: str, identifier: str, filename: str, subdirectory: Optional[str] = None
|
||||
domain: str, identifier: str, filename: str, subdirectory: str | None = None
|
||||
) -> str:
|
||||
"""
|
||||
Generate standardized upload path for media files.
|
||||
@@ -44,7 +45,7 @@ class MediaService:
|
||||
return f"{domain}/{identifier}/{base_filename}"
|
||||
|
||||
@staticmethod
|
||||
def extract_exif_date(image_file: UploadedFile) -> Optional[datetime]:
|
||||
def extract_exif_date(image_file: UploadedFile) -> datetime | None:
|
||||
"""
|
||||
Extract the date taken from image EXIF data.
|
||||
|
||||
@@ -60,18 +61,17 @@ class MediaService:
|
||||
if exif:
|
||||
# Find the DateTime tag ID
|
||||
for tag_id in ExifTags.TAGS:
|
||||
if ExifTags.TAGS[tag_id] == "DateTimeOriginal":
|
||||
if tag_id in exif:
|
||||
# EXIF dates are typically in format: '2024:02:15 14:30:00'
|
||||
date_str = exif[tag_id]
|
||||
return datetime.strptime(date_str, "%Y:%m:%d %H:%M:%S")
|
||||
if ExifTags.TAGS[tag_id] == "DateTimeOriginal" and tag_id in exif:
|
||||
# EXIF dates are typically in format: '2024:02:15 14:30:00'
|
||||
date_str = exif[tag_id]
|
||||
return datetime.strptime(date_str, "%Y:%m:%d %H:%M:%S")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to extract EXIF date: {str(e)}")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def validate_image_file(image_file: UploadedFile) -> tuple[bool, Optional[str]]:
|
||||
def validate_image_file(image_file: UploadedFile) -> tuple[bool, str | None]:
|
||||
"""
|
||||
Validate uploaded image file.
|
||||
|
||||
@@ -144,6 +144,7 @@ class MediaService:
|
||||
|
||||
# Save processed image
|
||||
from io import BytesIO
|
||||
|
||||
from django.core.files.uploadedfile import InMemoryUploadedFile
|
||||
|
||||
output = BytesIO()
|
||||
@@ -180,7 +181,7 @@ class MediaService:
|
||||
return f"Uploaded by {username} on {current_time.strftime('%B %d, %Y at %I:%M %p')}"
|
||||
|
||||
@staticmethod
|
||||
def get_storage_stats() -> Dict[str, Any]:
|
||||
def get_storage_stats() -> dict[str, Any]:
|
||||
"""
|
||||
Get media storage statistics.
|
||||
|
||||
|
||||
@@ -6,7 +6,8 @@ while maintaining compatibility with Cloudflare Images.
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import Optional, Dict, Any
|
||||
from typing import Any
|
||||
|
||||
from django.utils.text import slugify
|
||||
|
||||
|
||||
@@ -83,7 +84,7 @@ class MediaURLService:
|
||||
return f"/parks/{park_slug}/rides/{ride_slug}/photos/{filename}"
|
||||
|
||||
@staticmethod
|
||||
def parse_photo_filename(filename: str) -> Optional[Dict[str, Any]]:
|
||||
def parse_photo_filename(filename: str) -> dict[str, Any] | None:
|
||||
"""
|
||||
Parse a friendly filename to extract photo ID and variant.
|
||||
|
||||
@@ -118,7 +119,7 @@ class MediaURLService:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def get_cloudflare_url_with_fallback(cloudflare_image, variant: str = "public") -> Optional[str]:
|
||||
def get_cloudflare_url_with_fallback(cloudflare_image, variant: str = "public") -> str | None:
|
||||
"""
|
||||
Get Cloudflare URL with fallback handling.
|
||||
|
||||
|
||||
@@ -2,13 +2,14 @@
|
||||
Performance monitoring utilities and context managers.
|
||||
"""
|
||||
|
||||
import time
|
||||
import logging
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from functools import wraps
|
||||
from typing import Optional, Dict, Any, List
|
||||
from django.db import connection
|
||||
from typing import Any
|
||||
|
||||
from django.conf import settings
|
||||
from django.db import connection
|
||||
from django.utils import timezone
|
||||
|
||||
logger = logging.getLogger("performance")
|
||||
@@ -271,7 +272,7 @@ class DatabaseQueryAnalyzer:
|
||||
"""Analyze database query patterns and performance"""
|
||||
|
||||
@staticmethod
|
||||
def analyze_queries(queries: List[Dict]) -> Dict[str, Any]:
|
||||
def analyze_queries(queries: list[dict]) -> dict[str, Any]:
|
||||
"""Analyze a list of queries for patterns and issues"""
|
||||
if not queries:
|
||||
return {}
|
||||
@@ -332,7 +333,7 @@ class DatabaseQueryAnalyzer:
|
||||
return analysis
|
||||
|
||||
@classmethod
|
||||
def analyze_current_queries(cls) -> Dict[str, Any]:
|
||||
def analyze_current_queries(cls) -> dict[str, Any]:
|
||||
"""Analyze the current request's queries"""
|
||||
if hasattr(connection, "queries"):
|
||||
return cls.analyze_queries(connection.queries)
|
||||
@@ -340,7 +341,7 @@ class DatabaseQueryAnalyzer:
|
||||
|
||||
|
||||
# Performance monitoring decorators
|
||||
def monitor_function_performance(operation_name: Optional[str] = None):
|
||||
def monitor_function_performance(operation_name: str | None = None):
|
||||
"""Decorator to monitor function performance"""
|
||||
|
||||
def decorator(func):
|
||||
@@ -379,7 +380,7 @@ class PerformanceMetrics:
|
||||
def __init__(self):
|
||||
self.metrics = []
|
||||
|
||||
def record_metric(self, name: str, value: float, tags: Optional[Dict] = None):
|
||||
def record_metric(self, name: str, value: float, tags: dict | None = None):
|
||||
"""Record a performance metric"""
|
||||
metric = {
|
||||
"name": name,
|
||||
@@ -392,7 +393,7 @@ class PerformanceMetrics:
|
||||
# Log the metric
|
||||
logger.info(f"Performance metric: {name} = {value}", extra=metric)
|
||||
|
||||
def get_metrics(self, name: Optional[str] = None) -> List[Dict]:
|
||||
def get_metrics(self, name: str | None = None) -> list[dict]:
|
||||
"""Get recorded metrics, optionally filtered by name"""
|
||||
if name:
|
||||
return [m for m in self.metrics if m["name"] == name]
|
||||
|
||||
@@ -12,11 +12,12 @@ Results are cached in Redis for performance optimization.
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Any
|
||||
from django.utils import timezone
|
||||
from typing import Any
|
||||
|
||||
from django.contrib.contenttypes.models import ContentType
|
||||
from django.core.cache import cache
|
||||
from django.db.models import Q
|
||||
from django.utils import timezone
|
||||
|
||||
from apps.core.analytics import PageView
|
||||
from apps.parks.models import Park
|
||||
@@ -56,7 +57,7 @@ class TrendingService:
|
||||
|
||||
def get_trending_content(
|
||||
self, content_type: str = "all", limit: int = 20, force_refresh: bool = False
|
||||
) -> List[Dict[str, Any]]:
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Get trending content using direct calculation.
|
||||
|
||||
@@ -121,7 +122,7 @@ class TrendingService:
|
||||
limit: int = 20,
|
||||
days_back: int = 30,
|
||||
force_refresh: bool = False,
|
||||
) -> List[Dict[str, Any]]:
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Get recently added content using direct calculation.
|
||||
|
||||
@@ -182,7 +183,7 @@ class TrendingService:
|
||||
self.logger.error(f"Error getting new content: {e}", exc_info=True)
|
||||
return []
|
||||
|
||||
def _calculate_trending_parks(self, limit: int) -> List[Dict[str, Any]]:
|
||||
def _calculate_trending_parks(self, limit: int) -> list[dict[str, Any]]:
|
||||
"""Calculate trending scores for parks."""
|
||||
parks = Park.objects.filter(status="OPERATING").select_related(
|
||||
"location", "operator", "card_image"
|
||||
@@ -253,7 +254,7 @@ class TrendingService:
|
||||
|
||||
return trending_parks
|
||||
|
||||
def _calculate_trending_rides(self, limit: int) -> List[Dict[str, Any]]:
|
||||
def _calculate_trending_rides(self, limit: int) -> list[dict[str, Any]]:
|
||||
"""Calculate trending scores for rides."""
|
||||
rides = Ride.objects.filter(status="OPERATING").select_related(
|
||||
"park", "park__location", "card_image"
|
||||
@@ -456,7 +457,7 @@ class TrendingService:
|
||||
self.logger.warning(f"Error calculating popularity score: {e}")
|
||||
return 0.0
|
||||
|
||||
def _get_new_parks(self, cutoff_date: datetime, limit: int) -> List[Dict[str, Any]]:
|
||||
def _get_new_parks(self, cutoff_date: datetime, limit: int) -> list[dict[str, Any]]:
|
||||
"""Get recently added parks."""
|
||||
new_parks = (
|
||||
Park.objects.filter(
|
||||
@@ -528,7 +529,7 @@ class TrendingService:
|
||||
|
||||
return results
|
||||
|
||||
def _get_new_rides(self, cutoff_date: datetime, limit: int) -> List[Dict[str, Any]]:
|
||||
def _get_new_rides(self, cutoff_date: datetime, limit: int) -> list[dict[str, Any]]:
|
||||
"""Get recently added rides."""
|
||||
new_rides = (
|
||||
Ride.objects.filter(
|
||||
@@ -584,8 +585,8 @@ class TrendingService:
|
||||
return results
|
||||
|
||||
def _format_trending_results(
|
||||
self, trending_items: List[Dict[str, Any]]
|
||||
) -> List[Dict[str, Any]]:
|
||||
self, trending_items: list[dict[str, Any]]
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Format trending results for frontend consumption."""
|
||||
formatted_results = []
|
||||
|
||||
@@ -649,8 +650,8 @@ class TrendingService:
|
||||
return formatted_results
|
||||
|
||||
def _format_new_content_results(
|
||||
self, new_items: List[Dict[str, Any]]
|
||||
) -> List[Dict[str, Any]]:
|
||||
self, new_items: list[dict[str, Any]]
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Format new content results for frontend consumption."""
|
||||
formatted_results = []
|
||||
|
||||
|
||||
@@ -1,106 +1,106 @@
|
||||
"""State machine utilities for core app."""
|
||||
from .fields import RichFSMField
|
||||
from .mixins import StateMachineMixin
|
||||
from .builder import (
|
||||
StateTransitionBuilder,
|
||||
determine_method_name_for_transition,
|
||||
)
|
||||
from .decorators import (
|
||||
generate_transition_decorator,
|
||||
TransitionMethodFactory,
|
||||
with_callbacks,
|
||||
register_method_callbacks,
|
||||
)
|
||||
from .registry import (
|
||||
TransitionRegistry,
|
||||
TransitionInfo,
|
||||
registry_instance,
|
||||
register_callback,
|
||||
register_notification_callback,
|
||||
register_cache_invalidation,
|
||||
register_related_update,
|
||||
register_transition_callbacks,
|
||||
discover_and_register_callbacks,
|
||||
)
|
||||
from .callback_base import (
|
||||
BaseTransitionCallback,
|
||||
PreTransitionCallback,
|
||||
PostTransitionCallback,
|
||||
ErrorTransitionCallback,
|
||||
TransitionContext,
|
||||
TransitionCallbackRegistry,
|
||||
callback_registry,
|
||||
CallbackStage,
|
||||
)
|
||||
from .signals import (
|
||||
pre_state_transition,
|
||||
post_state_transition,
|
||||
state_transition_failed,
|
||||
register_transition_handler,
|
||||
on_transition,
|
||||
on_pre_transition,
|
||||
on_post_transition,
|
||||
on_transition_error,
|
||||
ErrorTransitionCallback,
|
||||
PostTransitionCallback,
|
||||
PreTransitionCallback,
|
||||
TransitionCallbackRegistry,
|
||||
TransitionContext,
|
||||
callback_registry,
|
||||
)
|
||||
from .config import (
|
||||
CallbackConfig,
|
||||
callback_config,
|
||||
get_callback_config,
|
||||
)
|
||||
from .monitoring import (
|
||||
CallbackMonitor,
|
||||
callback_monitor,
|
||||
TimedCallbackExecution,
|
||||
)
|
||||
from .validators import MetadataValidator, ValidationResult
|
||||
from .guards import (
|
||||
# Role constants
|
||||
VALID_ROLES,
|
||||
MODERATOR_ROLES,
|
||||
ADMIN_ROLES,
|
||||
SUPERUSER_ROLES,
|
||||
ESCALATION_LEVEL_ROLES,
|
||||
# Guard classes
|
||||
PermissionGuard,
|
||||
OwnershipGuard,
|
||||
AssignmentGuard,
|
||||
StateGuard,
|
||||
MetadataGuard,
|
||||
CompositeGuard,
|
||||
# Guard extraction and creation
|
||||
extract_guards_from_metadata,
|
||||
create_permission_guard,
|
||||
create_ownership_guard,
|
||||
create_assignment_guard,
|
||||
create_composite_guard,
|
||||
validate_guard_metadata,
|
||||
# Registry
|
||||
GuardRegistry,
|
||||
guard_registry,
|
||||
# Role checking functions
|
||||
get_user_role,
|
||||
has_role,
|
||||
is_moderator_or_above,
|
||||
is_admin_or_above,
|
||||
is_superuser_role,
|
||||
has_permission,
|
||||
from .decorators import (
|
||||
TransitionMethodFactory,
|
||||
generate_transition_decorator,
|
||||
register_method_callbacks,
|
||||
with_callbacks,
|
||||
)
|
||||
from .exceptions import (
|
||||
ERROR_MESSAGES,
|
||||
TransitionNotAvailable,
|
||||
TransitionPermissionDenied,
|
||||
TransitionValidationError,
|
||||
TransitionNotAvailable,
|
||||
ERROR_MESSAGES,
|
||||
format_transition_error,
|
||||
get_permission_error_message,
|
||||
get_state_error_message,
|
||||
format_transition_error,
|
||||
raise_permission_denied,
|
||||
raise_validation_error,
|
||||
)
|
||||
from .fields import RichFSMField
|
||||
from .guards import (
|
||||
ADMIN_ROLES,
|
||||
ESCALATION_LEVEL_ROLES,
|
||||
MODERATOR_ROLES,
|
||||
SUPERUSER_ROLES,
|
||||
# Role constants
|
||||
VALID_ROLES,
|
||||
AssignmentGuard,
|
||||
CompositeGuard,
|
||||
# Registry
|
||||
GuardRegistry,
|
||||
MetadataGuard,
|
||||
OwnershipGuard,
|
||||
# Guard classes
|
||||
PermissionGuard,
|
||||
StateGuard,
|
||||
create_assignment_guard,
|
||||
create_composite_guard,
|
||||
create_ownership_guard,
|
||||
create_permission_guard,
|
||||
# Guard extraction and creation
|
||||
extract_guards_from_metadata,
|
||||
# Role checking functions
|
||||
get_user_role,
|
||||
guard_registry,
|
||||
has_permission,
|
||||
has_role,
|
||||
is_admin_or_above,
|
||||
is_moderator_or_above,
|
||||
is_superuser_role,
|
||||
validate_guard_metadata,
|
||||
)
|
||||
from .integration import (
|
||||
apply_state_machine,
|
||||
StateMachineModelMixin,
|
||||
apply_state_machine,
|
||||
state_machine_model,
|
||||
)
|
||||
from .mixins import StateMachineMixin
|
||||
from .monitoring import (
|
||||
CallbackMonitor,
|
||||
TimedCallbackExecution,
|
||||
callback_monitor,
|
||||
)
|
||||
from .registry import (
|
||||
TransitionInfo,
|
||||
TransitionRegistry,
|
||||
discover_and_register_callbacks,
|
||||
register_cache_invalidation,
|
||||
register_callback,
|
||||
register_notification_callback,
|
||||
register_related_update,
|
||||
register_transition_callbacks,
|
||||
registry_instance,
|
||||
)
|
||||
from .signals import (
|
||||
on_post_transition,
|
||||
on_pre_transition,
|
||||
on_transition,
|
||||
on_transition_error,
|
||||
post_state_transition,
|
||||
pre_state_transition,
|
||||
register_transition_handler,
|
||||
state_transition_failed,
|
||||
)
|
||||
from .validators import MetadataValidator, ValidationResult
|
||||
|
||||
__all__ = [
|
||||
# Fields and mixins
|
||||
|
||||
@@ -60,11 +60,12 @@ See Also:
|
||||
- apps.core.choices.registry: Central choice registry
|
||||
- apps.core.state_machine.guards: Guard extraction from metadata
|
||||
"""
|
||||
from typing import Dict, List, Optional, Any
|
||||
from typing import Any
|
||||
|
||||
from django.core.exceptions import ImproperlyConfigured
|
||||
|
||||
from apps.core.choices.registry import registry
|
||||
from apps.core.choices.base import RichChoice
|
||||
from apps.core.choices.registry import registry
|
||||
|
||||
|
||||
class StateTransitionBuilder:
|
||||
@@ -123,7 +124,7 @@ class StateTransitionBuilder:
|
||||
"""
|
||||
self.choice_group = choice_group
|
||||
self.domain = domain
|
||||
self._cache: Dict[str, Any] = {}
|
||||
self._cache: dict[str, Any] = {}
|
||||
|
||||
# Validate choice group exists
|
||||
group = registry.get(choice_group, domain)
|
||||
@@ -134,7 +135,7 @@ class StateTransitionBuilder:
|
||||
|
||||
self.choices = registry.get_choices(choice_group, domain)
|
||||
|
||||
def get_choice_metadata(self, state_value: str) -> Dict[str, Any]:
|
||||
def get_choice_metadata(self, state_value: str) -> dict[str, Any]:
|
||||
"""
|
||||
Retrieve metadata for a specific state.
|
||||
|
||||
@@ -156,7 +157,7 @@ class StateTransitionBuilder:
|
||||
self._cache[cache_key] = metadata
|
||||
return metadata
|
||||
|
||||
def extract_valid_transitions(self, state_value: str) -> List[str]:
|
||||
def extract_valid_transitions(self, state_value: str) -> list[str]:
|
||||
"""
|
||||
Get can_transition_to list from metadata.
|
||||
|
||||
@@ -184,7 +185,7 @@ class StateTransitionBuilder:
|
||||
|
||||
def extract_permission_requirements(
|
||||
self, state_value: str
|
||||
) -> Dict[str, bool]:
|
||||
) -> dict[str, bool]:
|
||||
"""
|
||||
Extract permission requirements from metadata.
|
||||
|
||||
@@ -228,7 +229,7 @@ class StateTransitionBuilder:
|
||||
metadata = self.get_choice_metadata(state_value)
|
||||
return metadata.get("is_actionable", False)
|
||||
|
||||
def build_transition_graph(self) -> Dict[str, List[str]]:
|
||||
def build_transition_graph(self) -> dict[str, list[str]]:
|
||||
"""
|
||||
Create a complete state transition graph.
|
||||
|
||||
@@ -247,7 +248,7 @@ class StateTransitionBuilder:
|
||||
self._cache[cache_key] = graph
|
||||
return graph
|
||||
|
||||
def get_all_states(self) -> List[str]:
|
||||
def get_all_states(self) -> list[str]:
|
||||
"""
|
||||
Get all state values in the choice group.
|
||||
|
||||
@@ -256,7 +257,7 @@ class StateTransitionBuilder:
|
||||
"""
|
||||
return [choice.value for choice in self.choices]
|
||||
|
||||
def get_choice(self, state_value: str) -> Optional[RichChoice]:
|
||||
def get_choice(self, state_value: str) -> RichChoice | None:
|
||||
"""
|
||||
Get the RichChoice object for a state.
|
||||
|
||||
@@ -276,7 +277,7 @@ class StateTransitionBuilder:
|
||||
def determine_method_name_for_transition(source: str, target: str) -> str:
|
||||
"""
|
||||
Determine appropriate method name for a transition.
|
||||
|
||||
|
||||
Always uses transition_to_<state> pattern to avoid conflicts with
|
||||
business logic methods (approve, reject, escalate, etc.).
|
||||
|
||||
|
||||
@@ -66,16 +66,15 @@ See Also:
|
||||
- apps.core.state_machine.callbacks.related_updates: Related model callbacks
|
||||
"""
|
||||
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
|
||||
from django.db import models
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -167,12 +166,12 @@ class TransitionContext:
|
||||
field_name: str
|
||||
source_state: str
|
||||
target_state: str
|
||||
user: Optional[Any] = None
|
||||
user: Any | None = None
|
||||
timestamp: datetime = field(default_factory=datetime.now)
|
||||
extra_data: Dict[str, Any] = field(default_factory=dict)
|
||||
extra_data: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
@property
|
||||
def model_class(self) -> Type[models.Model]:
|
||||
def model_class(self) -> type[models.Model]:
|
||||
"""Get the model class of the instance."""
|
||||
return type(self.instance)
|
||||
|
||||
@@ -206,9 +205,9 @@ class BaseTransitionCallback(ABC):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
priority: Optional[int] = None,
|
||||
continue_on_error: Optional[bool] = None,
|
||||
name: Optional[str] = None,
|
||||
priority: int | None = None,
|
||||
continue_on_error: bool | None = None,
|
||||
name: str | None = None,
|
||||
):
|
||||
if priority is not None:
|
||||
self.priority = priority
|
||||
@@ -288,7 +287,7 @@ class ErrorTransitionCallback(BaseTransitionCallback):
|
||||
# Error callbacks should always continue
|
||||
continue_on_error: bool = True
|
||||
|
||||
def execute(self, context: TransitionContext, exception: Optional[Exception] = None) -> bool:
|
||||
def execute(self, context: TransitionContext, exception: Exception | None = None) -> bool:
|
||||
"""
|
||||
Execute the error callback.
|
||||
|
||||
@@ -307,7 +306,7 @@ class CallbackRegistration:
|
||||
"""Represents a registered callback with its configuration."""
|
||||
|
||||
callback: BaseTransitionCallback
|
||||
model_class: Type[models.Model]
|
||||
model_class: type[models.Model]
|
||||
field_name: str
|
||||
source: str # Can be '*' for wildcard
|
||||
target: str # Can be '*' for wildcard
|
||||
@@ -315,7 +314,7 @@ class CallbackRegistration:
|
||||
|
||||
def matches(
|
||||
self,
|
||||
model_class: Type[models.Model],
|
||||
model_class: type[models.Model],
|
||||
field_name: str,
|
||||
source: str,
|
||||
target: str,
|
||||
@@ -327,9 +326,7 @@ class CallbackRegistration:
|
||||
return False
|
||||
if self.source != '*' and self.source != source:
|
||||
return False
|
||||
if self.target != '*' and self.target != target:
|
||||
return False
|
||||
return True
|
||||
return not (self.target != '*' and self.target != target)
|
||||
|
||||
|
||||
class TransitionCallbackRegistry:
|
||||
@@ -351,7 +348,7 @@ class TransitionCallbackRegistry:
|
||||
def __init__(self):
|
||||
if self._initialized:
|
||||
return
|
||||
self._callbacks: Dict[CallbackStage, List[CallbackRegistration]] = {
|
||||
self._callbacks: dict[CallbackStage, list[CallbackRegistration]] = {
|
||||
CallbackStage.PRE: [],
|
||||
CallbackStage.POST: [],
|
||||
CallbackStage.ERROR: [],
|
||||
@@ -360,12 +357,12 @@ class TransitionCallbackRegistry:
|
||||
|
||||
def register(
|
||||
self,
|
||||
model_class: Type[models.Model],
|
||||
model_class: type[models.Model],
|
||||
field_name: str,
|
||||
source: str,
|
||||
target: str,
|
||||
callback: BaseTransitionCallback,
|
||||
stage: Union[CallbackStage, str] = CallbackStage.POST,
|
||||
stage: CallbackStage | str = CallbackStage.POST,
|
||||
) -> None:
|
||||
"""
|
||||
Register a callback for a specific transition.
|
||||
@@ -402,10 +399,10 @@ class TransitionCallbackRegistry:
|
||||
|
||||
def register_bulk(
|
||||
self,
|
||||
model_class: Type[models.Model],
|
||||
model_class: type[models.Model],
|
||||
field_name: str,
|
||||
callbacks_config: Dict[Tuple[str, str], List[BaseTransitionCallback]],
|
||||
stage: Union[CallbackStage, str] = CallbackStage.POST,
|
||||
callbacks_config: dict[tuple[str, str], list[BaseTransitionCallback]],
|
||||
stage: CallbackStage | str = CallbackStage.POST,
|
||||
) -> None:
|
||||
"""
|
||||
Register multiple callbacks for multiple transitions.
|
||||
@@ -422,12 +419,12 @@ class TransitionCallbackRegistry:
|
||||
|
||||
def get_callbacks(
|
||||
self,
|
||||
model_class: Type[models.Model],
|
||||
model_class: type[models.Model],
|
||||
field_name: str,
|
||||
source: str,
|
||||
target: str,
|
||||
stage: Union[CallbackStage, str] = CallbackStage.POST,
|
||||
) -> List[BaseTransitionCallback]:
|
||||
stage: CallbackStage | str = CallbackStage.POST,
|
||||
) -> list[BaseTransitionCallback]:
|
||||
"""
|
||||
Get all callbacks matching the given transition.
|
||||
|
||||
@@ -454,9 +451,9 @@ class TransitionCallbackRegistry:
|
||||
def execute_callbacks(
|
||||
self,
|
||||
context: TransitionContext,
|
||||
stage: Union[CallbackStage, str] = CallbackStage.POST,
|
||||
exception: Optional[Exception] = None,
|
||||
) -> Tuple[bool, List[Tuple[BaseTransitionCallback, Optional[Exception]]]]:
|
||||
stage: CallbackStage | str = CallbackStage.POST,
|
||||
exception: Exception | None = None,
|
||||
) -> tuple[bool, list[tuple[BaseTransitionCallback, Exception | None]]]:
|
||||
"""
|
||||
Execute all callbacks for a transition.
|
||||
|
||||
@@ -479,7 +476,7 @@ class TransitionCallbackRegistry:
|
||||
stage,
|
||||
)
|
||||
|
||||
failures: List[Tuple[BaseTransitionCallback, Optional[Exception]]] = []
|
||||
failures: list[tuple[BaseTransitionCallback, Exception | None]] = []
|
||||
overall_success = True
|
||||
|
||||
for callback in callbacks:
|
||||
@@ -530,7 +527,7 @@ class TransitionCallbackRegistry:
|
||||
|
||||
return overall_success, failures
|
||||
|
||||
def clear(self, model_class: Optional[Type[models.Model]] = None) -> None:
|
||||
def clear(self, model_class: type[models.Model] | None = None) -> None:
|
||||
"""
|
||||
Clear registered callbacks.
|
||||
|
||||
@@ -550,8 +547,8 @@ class TransitionCallbackRegistry:
|
||||
|
||||
def get_all_registrations(
|
||||
self,
|
||||
model_class: Optional[Type[models.Model]] = None,
|
||||
) -> Dict[CallbackStage, List[CallbackRegistration]]:
|
||||
model_class: type[models.Model] | None = None,
|
||||
) -> dict[CallbackStage, list[CallbackRegistration]]:
|
||||
"""
|
||||
Get all registered callbacks, optionally filtered by model class.
|
||||
|
||||
@@ -585,19 +582,19 @@ callback_registry = TransitionCallbackRegistry()
|
||||
|
||||
# Convenience functions for common operations
|
||||
def register_callback(
|
||||
model_class: Type[models.Model],
|
||||
model_class: type[models.Model],
|
||||
field_name: str,
|
||||
source: str,
|
||||
target: str,
|
||||
callback: BaseTransitionCallback,
|
||||
stage: Union[CallbackStage, str] = CallbackStage.POST,
|
||||
stage: CallbackStage | str = CallbackStage.POST,
|
||||
) -> None:
|
||||
"""Convenience function to register a callback."""
|
||||
callback_registry.register(model_class, field_name, source, target, callback, stage)
|
||||
|
||||
|
||||
def register_pre_callback(
|
||||
model_class: Type[models.Model],
|
||||
model_class: type[models.Model],
|
||||
field_name: str,
|
||||
source: str,
|
||||
target: str,
|
||||
@@ -610,7 +607,7 @@ def register_pre_callback(
|
||||
|
||||
|
||||
def register_post_callback(
|
||||
model_class: Type[models.Model],
|
||||
model_class: type[models.Model],
|
||||
field_name: str,
|
||||
source: str,
|
||||
target: str,
|
||||
@@ -623,7 +620,7 @@ def register_post_callback(
|
||||
|
||||
|
||||
def register_error_callback(
|
||||
model_class: Type[models.Model],
|
||||
model_class: type[models.Model],
|
||||
field_name: str,
|
||||
source: str,
|
||||
target: str,
|
||||
|
||||
@@ -5,29 +5,28 @@ This package provides specialized callback implementations for
|
||||
FSM state transitions.
|
||||
"""
|
||||
|
||||
from .notifications import (
|
||||
NotificationCallback,
|
||||
SubmissionApprovedNotification,
|
||||
SubmissionRejectedNotification,
|
||||
SubmissionEscalatedNotification,
|
||||
StatusChangeNotification,
|
||||
ModerationNotificationCallback,
|
||||
)
|
||||
from .cache import (
|
||||
APICacheInvalidation,
|
||||
CacheInvalidationCallback,
|
||||
ModelCacheInvalidation,
|
||||
RelatedModelCacheInvalidation,
|
||||
PatternCacheInvalidation,
|
||||
APICacheInvalidation,
|
||||
RelatedModelCacheInvalidation,
|
||||
)
|
||||
from .notifications import (
|
||||
ModerationNotificationCallback,
|
||||
NotificationCallback,
|
||||
StatusChangeNotification,
|
||||
SubmissionApprovedNotification,
|
||||
SubmissionEscalatedNotification,
|
||||
SubmissionRejectedNotification,
|
||||
)
|
||||
from .related_updates import (
|
||||
RelatedModelUpdateCallback,
|
||||
ParkCountUpdateCallback,
|
||||
SearchTextUpdateCallback,
|
||||
ComputedFieldUpdateCallback,
|
||||
ParkCountUpdateCallback,
|
||||
RelatedModelUpdateCallback,
|
||||
SearchTextUpdateCallback,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
# Notification callbacks
|
||||
"NotificationCallback",
|
||||
|
||||
@@ -5,15 +5,12 @@ This module provides callback implementations that invalidate cache entries
|
||||
when state transitions occur.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional, Set, Type
|
||||
import logging
|
||||
|
||||
from django.conf import settings
|
||||
from django.db import models
|
||||
|
||||
from ..callback_base import PostTransitionCallback, TransitionContext
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -29,7 +26,7 @@ class CacheInvalidationCallback(PostTransitionCallback):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
patterns: Optional[List[str]] = None,
|
||||
patterns: list[str] | None = None,
|
||||
include_instance_patterns: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
@@ -62,7 +59,7 @@ class CacheInvalidationCallback(PostTransitionCallback):
|
||||
logger.warning("EnhancedCacheService not available")
|
||||
return None
|
||||
|
||||
def _get_instance_patterns(self, context: TransitionContext) -> List[str]:
|
||||
def _get_instance_patterns(self, context: TransitionContext) -> list[str]:
|
||||
"""Generate cache key patterns specific to the instance."""
|
||||
patterns = []
|
||||
model_name = context.model_name.lower()
|
||||
@@ -75,7 +72,7 @@ class CacheInvalidationCallback(PostTransitionCallback):
|
||||
|
||||
return patterns
|
||||
|
||||
def _get_all_patterns(self, context: TransitionContext) -> Set[str]:
|
||||
def _get_all_patterns(self, context: TransitionContext) -> set[str]:
|
||||
"""Get all patterns to invalidate, including generated ones."""
|
||||
all_patterns = set(self.patterns)
|
||||
|
||||
@@ -130,7 +127,6 @@ class CacheInvalidationCallback(PostTransitionCallback):
|
||||
def _fallback_invalidation(self, context: TransitionContext) -> bool:
|
||||
"""Fallback cache invalidation using Django's cache framework."""
|
||||
try:
|
||||
from django.core.cache import cache
|
||||
|
||||
patterns = self._get_all_patterns(context)
|
||||
|
||||
@@ -171,7 +167,7 @@ class ModelCacheInvalidation(CacheInvalidationCallback):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def _get_instance_patterns(self, context: TransitionContext) -> List[str]:
|
||||
def _get_instance_patterns(self, context: TransitionContext) -> list[str]:
|
||||
"""Get model-specific patterns."""
|
||||
base_patterns = super()._get_instance_patterns(context)
|
||||
|
||||
@@ -198,7 +194,7 @@ class RelatedModelCacheInvalidation(CacheInvalidationCallback):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
related_fields: Optional[List[str]] = None,
|
||||
related_fields: list[str] | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@@ -211,7 +207,7 @@ class RelatedModelCacheInvalidation(CacheInvalidationCallback):
|
||||
super().__init__(**kwargs)
|
||||
self.related_fields = related_fields or []
|
||||
|
||||
def _get_related_patterns(self, context: TransitionContext) -> List[str]:
|
||||
def _get_related_patterns(self, context: TransitionContext) -> list[str]:
|
||||
"""Get cache patterns for related models."""
|
||||
patterns = []
|
||||
|
||||
@@ -236,7 +232,7 @@ class RelatedModelCacheInvalidation(CacheInvalidationCallback):
|
||||
|
||||
return patterns
|
||||
|
||||
def _get_all_patterns(self, context: TransitionContext) -> Set[str]:
|
||||
def _get_all_patterns(self, context: TransitionContext) -> set[str]:
|
||||
"""Get all patterns including related model patterns."""
|
||||
patterns = super()._get_all_patterns(context)
|
||||
patterns.update(self._get_related_patterns(context))
|
||||
@@ -254,7 +250,7 @@ class PatternCacheInvalidation(CacheInvalidationCallback):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
patterns: List[str],
|
||||
patterns: list[str],
|
||||
include_instance_patterns: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
@@ -284,7 +280,7 @@ class APICacheInvalidation(CacheInvalidationCallback):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_prefixes: Optional[List[str]] = None,
|
||||
api_prefixes: list[str] | None = None,
|
||||
include_geo_cache: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
@@ -300,7 +296,7 @@ class APICacheInvalidation(CacheInvalidationCallback):
|
||||
self.api_prefixes = api_prefixes or ['api:*']
|
||||
self.include_geo_cache = include_geo_cache
|
||||
|
||||
def _get_all_patterns(self, context: TransitionContext) -> Set[str]:
|
||||
def _get_all_patterns(self, context: TransitionContext) -> set[str]:
|
||||
"""Get API-specific cache patterns."""
|
||||
patterns = set()
|
||||
|
||||
@@ -358,7 +354,7 @@ class RideCacheInvalidation(CacheInvalidationCallback):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _get_instance_patterns(self, context: TransitionContext) -> List[str]:
|
||||
def _get_instance_patterns(self, context: TransitionContext) -> list[str]:
|
||||
"""Include parent park cache patterns."""
|
||||
patterns = super()._get_instance_patterns(context)
|
||||
|
||||
|
||||
@@ -5,15 +5,14 @@ This module provides callback implementations that send notifications
|
||||
when state transitions occur.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional, Type
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from django.conf import settings
|
||||
from django.db import models
|
||||
|
||||
from ..callback_base import PostTransitionCallback, TransitionContext
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -31,7 +30,7 @@ class NotificationCallback(PostTransitionCallback):
|
||||
self,
|
||||
notification_type: str,
|
||||
recipient_field: str = "submitted_by",
|
||||
template_name: Optional[str] = None,
|
||||
template_name: str | None = None,
|
||||
include_transition_data: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
@@ -69,7 +68,7 @@ class NotificationCallback(PostTransitionCallback):
|
||||
|
||||
return True
|
||||
|
||||
def _get_recipient(self, instance: models.Model) -> Optional[Any]:
|
||||
def _get_recipient(self, instance: models.Model) -> Any | None:
|
||||
"""Get the notification recipient from the instance."""
|
||||
return getattr(instance, self.recipient_field, None)
|
||||
|
||||
@@ -82,7 +81,7 @@ class NotificationCallback(PostTransitionCallback):
|
||||
logger.warning("NotificationService not available")
|
||||
return None
|
||||
|
||||
def _build_extra_data(self, context: TransitionContext) -> Dict[str, Any]:
|
||||
def _build_extra_data(self, context: TransitionContext) -> dict[str, Any]:
|
||||
"""Build extra data for the notification."""
|
||||
extra_data = {}
|
||||
|
||||
@@ -401,7 +400,7 @@ class StatusChangeNotification(NotificationCallback):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
significant_states: Optional[List[str]] = None,
|
||||
significant_states: list[str] | None = None,
|
||||
notify_admins: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
@@ -429,10 +428,7 @@ class StatusChangeNotification(NotificationCallback):
|
||||
return False
|
||||
|
||||
# Only notify for significant status changes
|
||||
if context.target_state not in self.significant_states:
|
||||
return False
|
||||
|
||||
return True
|
||||
return context.target_state in self.significant_states
|
||||
|
||||
def execute(self, context: TransitionContext) -> bool:
|
||||
"""Execute the status change notification."""
|
||||
@@ -518,12 +514,12 @@ class ModerationNotificationCallback(NotificationCallback):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _get_notification_type(self, context: TransitionContext) -> Optional[str]:
|
||||
def _get_notification_type(self, context: TransitionContext) -> str | None:
|
||||
"""Get the specific notification type based on model and state."""
|
||||
key = (context.model_name, context.target_state)
|
||||
return self.NOTIFICATION_MAPPING.get(key)
|
||||
|
||||
def _get_recipient(self, instance: models.Model) -> Optional[Any]:
|
||||
def _get_recipient(self, instance: models.Model) -> Any | None:
|
||||
"""Get the appropriate recipient based on model type."""
|
||||
# Try common recipient fields
|
||||
for field in ['reporter', 'assigned_to', 'created_by', 'submitted_by']:
|
||||
|
||||
@@ -5,15 +5,14 @@ This module provides callback implementations that update related models
|
||||
when state transitions occur.
|
||||
"""
|
||||
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Type
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
|
||||
from django.conf import settings
|
||||
from django.db import models, transaction
|
||||
|
||||
from ..callback_base import PostTransitionCallback, TransitionContext
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -28,7 +27,7 @@ class RelatedModelUpdateCallback(PostTransitionCallback):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
update_function: Optional[Callable[[TransitionContext], bool]] = None,
|
||||
update_function: Callable[[TransitionContext], bool] | None = None,
|
||||
use_transaction: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
@@ -251,8 +250,8 @@ class ComputedFieldUpdateCallback(RelatedModelUpdateCallback):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
computed_fields: Optional[List[str]] = None,
|
||||
update_method: Optional[str] = None,
|
||||
computed_fields: list[str] | None = None,
|
||||
update_method: str | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@@ -321,10 +320,7 @@ class RideStatusUpdateCallback(RelatedModelUpdateCallback):
|
||||
return False
|
||||
|
||||
# Only execute for Ride model
|
||||
if context.model_name != 'Ride':
|
||||
return False
|
||||
|
||||
return True
|
||||
return context.model_name == 'Ride'
|
||||
|
||||
def perform_update(self, context: TransitionContext) -> bool:
|
||||
"""Perform ride-specific status updates."""
|
||||
@@ -425,7 +421,7 @@ class ModerationQueueUpdateCallback(RelatedModelUpdateCallback):
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to update queue items: {e}")
|
||||
|
||||
def _get_content_type_id(self, instance: models.Model) -> Optional[int]:
|
||||
def _get_content_type_id(self, instance: models.Model) -> int | None:
|
||||
"""Get content type ID for the instance."""
|
||||
try:
|
||||
from django.contrib.contenttypes.models import ContentType
|
||||
|
||||
@@ -5,14 +5,13 @@ This module provides centralized configuration for all FSM transition callbacks,
|
||||
including enable/disable settings, priorities, and environment-specific overrides.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Type
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from django.conf import settings
|
||||
from django.db import models
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -23,10 +22,10 @@ class TransitionCallbackConfig:
|
||||
notifications_enabled: bool = True
|
||||
cache_invalidation_enabled: bool = True
|
||||
related_updates_enabled: bool = True
|
||||
notification_template: Optional[str] = None
|
||||
cache_patterns: List[str] = field(default_factory=list)
|
||||
notification_template: str | None = None
|
||||
cache_patterns: list[str] = field(default_factory=list)
|
||||
priority: int = 100
|
||||
extra_data: Dict[str, Any] = field(default_factory=dict)
|
||||
extra_data: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -35,7 +34,7 @@ class ModelCallbackConfig:
|
||||
|
||||
model_name: str
|
||||
field_name: str = 'status'
|
||||
transitions: Dict[tuple, TransitionCallbackConfig] = field(default_factory=dict)
|
||||
transitions: dict[tuple, TransitionCallbackConfig] = field(default_factory=dict)
|
||||
default_config: TransitionCallbackConfig = field(default_factory=TransitionCallbackConfig)
|
||||
|
||||
|
||||
@@ -63,20 +62,20 @@ class CallbackConfig:
|
||||
}
|
||||
|
||||
# Model-specific configurations
|
||||
MODEL_CONFIGS: Dict[str, ModelCallbackConfig] = {}
|
||||
MODEL_CONFIGS: dict[str, ModelCallbackConfig] = {}
|
||||
|
||||
def __init__(self):
|
||||
self._settings = self._load_settings()
|
||||
self._model_configs = self._build_model_configs()
|
||||
|
||||
def _load_settings(self) -> Dict[str, Any]:
|
||||
def _load_settings(self) -> dict[str, Any]:
|
||||
"""Load settings from Django configuration."""
|
||||
django_settings = getattr(settings, 'STATE_MACHINE_CALLBACKS', {})
|
||||
merged = dict(self.DEFAULT_SETTINGS)
|
||||
merged.update(django_settings)
|
||||
return merged
|
||||
|
||||
def _build_model_configs(self) -> Dict[str, ModelCallbackConfig]:
|
||||
def _build_model_configs(self) -> dict[str, ModelCallbackConfig]:
|
||||
"""Build model-specific configurations."""
|
||||
return {
|
||||
'EditSubmission': ModelCallbackConfig(
|
||||
@@ -315,7 +314,7 @@ class CallbackConfig:
|
||||
model_name: str,
|
||||
source: str,
|
||||
target: str,
|
||||
) -> List[str]:
|
||||
) -> list[str]:
|
||||
"""Get cache invalidation patterns for a transition."""
|
||||
config = self.get_config(model_name, source, target)
|
||||
return config.cache_patterns
|
||||
@@ -325,14 +324,14 @@ class CallbackConfig:
|
||||
model_name: str,
|
||||
source: str,
|
||||
target: str,
|
||||
) -> Optional[str]:
|
||||
) -> str | None:
|
||||
"""Get notification template for a transition."""
|
||||
config = self.get_config(model_name, source, target)
|
||||
return config.notification_template
|
||||
|
||||
def register_model_config(
|
||||
self,
|
||||
model_class: Type[models.Model],
|
||||
model_class: type[models.Model],
|
||||
config: ModelCallbackConfig,
|
||||
) -> None:
|
||||
"""
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
"""Transition decorator generation for django-fsm integration."""
|
||||
from typing import Any, Callable, List, Optional, Type, Union
|
||||
from functools import wraps
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import Any
|
||||
|
||||
from django.db import models
|
||||
from django_fsm import transition
|
||||
@@ -14,12 +15,11 @@ from .callback_base import (
|
||||
callback_registry,
|
||||
)
|
||||
from .signals import (
|
||||
pre_state_transition,
|
||||
post_state_transition,
|
||||
pre_state_transition,
|
||||
state_transition_failed,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -193,10 +193,10 @@ def create_transition_method(
|
||||
source: str,
|
||||
target: str,
|
||||
field_name: str,
|
||||
permission_guard: Optional[Callable] = None,
|
||||
on_success: Optional[Callable] = None,
|
||||
on_error: Optional[Callable] = None,
|
||||
callbacks: Optional[List[BaseTransitionCallback]] = None,
|
||||
permission_guard: Callable | None = None,
|
||||
on_success: Callable | None = None,
|
||||
on_error: Callable | None = None,
|
||||
callbacks: list[BaseTransitionCallback] | None = None,
|
||||
enable_callbacks: bool = True,
|
||||
emit_signals: bool = True,
|
||||
) -> Callable:
|
||||
@@ -259,7 +259,7 @@ def create_transition_method(
|
||||
|
||||
|
||||
def register_method_callbacks(
|
||||
model_class: Type[models.Model],
|
||||
model_class: type[models.Model],
|
||||
method: Callable,
|
||||
) -> None:
|
||||
"""
|
||||
@@ -275,14 +275,11 @@ def register_method_callbacks(
|
||||
if not metadata or not metadata.get('callbacks'):
|
||||
return
|
||||
|
||||
from .callback_base import CallbackStage, PostTransitionCallback, PreTransitionCallback
|
||||
from .callback_base import CallbackStage, PreTransitionCallback
|
||||
|
||||
for callback in metadata['callbacks']:
|
||||
# Determine stage from callback type
|
||||
if isinstance(callback, PreTransitionCallback):
|
||||
stage = CallbackStage.PRE
|
||||
else:
|
||||
stage = CallbackStage.POST
|
||||
stage = CallbackStage.PRE if isinstance(callback, PreTransitionCallback) else CallbackStage.POST
|
||||
|
||||
callback_registry.register(
|
||||
model_class=model_class,
|
||||
@@ -302,7 +299,7 @@ class TransitionMethodFactory:
|
||||
source: str,
|
||||
target: str,
|
||||
field_name: str = "status",
|
||||
permission_guard: Optional[Callable] = None,
|
||||
permission_guard: Callable | None = None,
|
||||
enable_callbacks: bool = True,
|
||||
emit_signals: bool = True,
|
||||
) -> Callable:
|
||||
@@ -353,7 +350,7 @@ class TransitionMethodFactory:
|
||||
source: str,
|
||||
target: str,
|
||||
field_name: str = "status",
|
||||
permission_guard: Optional[Callable] = None,
|
||||
permission_guard: Callable | None = None,
|
||||
enable_callbacks: bool = True,
|
||||
emit_signals: bool = True,
|
||||
) -> Callable:
|
||||
@@ -404,7 +401,7 @@ class TransitionMethodFactory:
|
||||
source: str,
|
||||
target: str,
|
||||
field_name: str = "status",
|
||||
permission_guard: Optional[Callable] = None,
|
||||
permission_guard: Callable | None = None,
|
||||
enable_callbacks: bool = True,
|
||||
emit_signals: bool = True,
|
||||
) -> Callable:
|
||||
@@ -456,8 +453,8 @@ class TransitionMethodFactory:
|
||||
source: str,
|
||||
target: str,
|
||||
field_name: str = "status",
|
||||
permission_guard: Optional[Callable] = None,
|
||||
docstring: Optional[str] = None,
|
||||
permission_guard: Callable | None = None,
|
||||
docstring: str | None = None,
|
||||
enable_callbacks: bool = True,
|
||||
emit_signals: bool = True,
|
||||
) -> Callable:
|
||||
|
||||
@@ -12,7 +12,8 @@ Example usage:
|
||||
'code': e.error_code
|
||||
}, status=403)
|
||||
"""
|
||||
from typing import Any, Optional, List, Dict
|
||||
from typing import Any
|
||||
|
||||
from django_fsm import TransitionNotAllowed
|
||||
|
||||
|
||||
@@ -42,10 +43,10 @@ class TransitionPermissionDenied(TransitionNotAllowed):
|
||||
self,
|
||||
message: str = "Permission denied for this transition",
|
||||
error_code: str = "PERMISSION_DENIED",
|
||||
user_message: Optional[str] = None,
|
||||
required_roles: Optional[List[str]] = None,
|
||||
user_role: Optional[str] = None,
|
||||
guard: Optional[Any] = None,
|
||||
user_message: str | None = None,
|
||||
required_roles: list[str] | None = None,
|
||||
user_role: str | None = None,
|
||||
guard: Any | None = None,
|
||||
):
|
||||
"""
|
||||
Initialize permission denied exception.
|
||||
@@ -65,7 +66,7 @@ class TransitionPermissionDenied(TransitionNotAllowed):
|
||||
self.user_role = user_role
|
||||
self.guard = guard
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""
|
||||
Convert exception to dictionary for API responses.
|
||||
|
||||
@@ -106,10 +107,10 @@ class TransitionValidationError(TransitionNotAllowed):
|
||||
self,
|
||||
message: str = "Transition validation failed",
|
||||
error_code: str = "VALIDATION_FAILED",
|
||||
user_message: Optional[str] = None,
|
||||
field_name: Optional[str] = None,
|
||||
current_state: Optional[str] = None,
|
||||
guard: Optional[Any] = None,
|
||||
user_message: str | None = None,
|
||||
field_name: str | None = None,
|
||||
current_state: str | None = None,
|
||||
guard: Any | None = None,
|
||||
):
|
||||
"""
|
||||
Initialize validation error exception.
|
||||
@@ -129,7 +130,7 @@ class TransitionValidationError(TransitionNotAllowed):
|
||||
self.current_state = current_state
|
||||
self.guard = guard
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""
|
||||
Convert exception to dictionary for API responses.
|
||||
|
||||
@@ -168,10 +169,10 @@ class TransitionNotAvailable(TransitionNotAllowed):
|
||||
self,
|
||||
message: str = "This transition is not available",
|
||||
error_code: str = "TRANSITION_NOT_AVAILABLE",
|
||||
user_message: Optional[str] = None,
|
||||
current_state: Optional[str] = None,
|
||||
requested_transition: Optional[str] = None,
|
||||
available_transitions: Optional[List[str]] = None,
|
||||
user_message: str | None = None,
|
||||
current_state: str | None = None,
|
||||
requested_transition: str | None = None,
|
||||
available_transitions: list[str] | None = None,
|
||||
):
|
||||
"""
|
||||
Initialize transition not available exception.
|
||||
@@ -191,7 +192,7 @@ class TransitionNotAvailable(TransitionNotAllowed):
|
||||
self.requested_transition = requested_transition
|
||||
self.available_transitions = available_transitions or []
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""
|
||||
Convert exception to dictionary for API responses.
|
||||
|
||||
@@ -267,12 +268,12 @@ def get_permission_error_message(
|
||||
# "You need moderator permissions to approve submissions..."
|
||||
"""
|
||||
from .guards import (
|
||||
PermissionGuard,
|
||||
OwnershipGuard,
|
||||
AssignmentGuard,
|
||||
MODERATOR_ROLES,
|
||||
ADMIN_ROLES,
|
||||
MODERATOR_ROLES,
|
||||
SUPERUSER_ROLES,
|
||||
AssignmentGuard,
|
||||
OwnershipGuard,
|
||||
PermissionGuard,
|
||||
)
|
||||
|
||||
if hasattr(guard, "get_error_message"):
|
||||
@@ -348,7 +349,7 @@ def get_state_error_message(
|
||||
def format_transition_error(
|
||||
exception: Exception,
|
||||
include_details: bool = False,
|
||||
) -> Dict[str, Any]:
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Format a transition exception for API response.
|
||||
|
||||
@@ -412,7 +413,7 @@ def raise_permission_denied(
|
||||
user_role = get_user_role(user) if user else None
|
||||
|
||||
error_code = TransitionPermissionDenied.ERROR_CODE_PERMISSION_DENIED_ROLE
|
||||
required_roles: List[str] = []
|
||||
required_roles: list[str] = []
|
||||
|
||||
if isinstance(guard, PermissionGuard):
|
||||
required_roles = guard.get_required_roles()
|
||||
@@ -431,8 +432,8 @@ def raise_permission_denied(
|
||||
|
||||
def raise_validation_error(
|
||||
guard: Any,
|
||||
current_state: Optional[str] = None,
|
||||
field_name: Optional[str] = None,
|
||||
current_state: str | None = None,
|
||||
field_name: str | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Raise a TransitionValidationError exception with proper context.
|
||||
@@ -445,7 +446,7 @@ def raise_validation_error(
|
||||
Raises:
|
||||
TransitionValidationError: Always raised with proper context
|
||||
"""
|
||||
from .guards import StateGuard, MetadataGuard
|
||||
from .guards import MetadataGuard, StateGuard
|
||||
|
||||
error_code = TransitionValidationError.ERROR_CODE_VALIDATION_FAILED
|
||||
user_message = "Validation failed for this transition"
|
||||
|
||||
@@ -47,7 +47,7 @@ See Also:
|
||||
- apps.core.choices.registry: The central choice registry
|
||||
- apps.core.state_machine.mixins.StateMachineMixin: Convenience helpers
|
||||
"""
|
||||
from typing import Any, Optional
|
||||
from typing import Any
|
||||
|
||||
from django.core.exceptions import ValidationError
|
||||
from django_fsm import FSMField as DjangoFSMField
|
||||
@@ -147,7 +147,7 @@ class RichFSMField(DjangoFSMField):
|
||||
f"'{value}' is deprecated and cannot be used for new entries"
|
||||
)
|
||||
|
||||
def get_rich_choice(self, value: str) -> Optional[RichChoice]:
|
||||
def get_rich_choice(self, value: str) -> RichChoice | None:
|
||||
"""Return the RichChoice object for a given state value."""
|
||||
return registry.get_choice(self.choice_group, value, self.domain)
|
||||
|
||||
|
||||
@@ -18,7 +18,8 @@ Example usage:
|
||||
OwnershipGuard()
|
||||
], operator='OR')
|
||||
"""
|
||||
from typing import Callable, Dict, List, Optional, Any, Tuple, Union
|
||||
from collections.abc import Callable
|
||||
from typing import Any, Optional
|
||||
|
||||
# Valid user roles in order of increasing privilege
|
||||
VALID_ROLES = ["USER", "MODERATOR", "ADMIN", "SUPERUSER"]
|
||||
@@ -62,9 +63,9 @@ class PermissionGuard:
|
||||
requires_moderator: bool = False,
|
||||
requires_admin: bool = False,
|
||||
requires_superuser: bool = False,
|
||||
required_roles: Optional[List[str]] = None,
|
||||
custom_check: Optional[Callable] = None,
|
||||
error_message: Optional[str] = None,
|
||||
required_roles: list[str] | None = None,
|
||||
custom_check: Callable | None = None,
|
||||
error_message: str | None = None,
|
||||
):
|
||||
"""
|
||||
Initialize permission guard.
|
||||
@@ -83,10 +84,10 @@ class PermissionGuard:
|
||||
self.required_roles = required_roles
|
||||
self.custom_check = custom_check
|
||||
self._custom_error_message = error_message
|
||||
self._last_error_code: Optional[str] = None
|
||||
self._last_error_code: str | None = None
|
||||
|
||||
@property
|
||||
def error_code(self) -> Optional[str]:
|
||||
def error_code(self) -> str | None:
|
||||
"""Return the error code from the last failed check."""
|
||||
return self._last_error_code
|
||||
|
||||
@@ -126,16 +127,14 @@ class PermissionGuard:
|
||||
return False
|
||||
|
||||
# Check moderator (includes admin and superuser)
|
||||
elif self.requires_moderator:
|
||||
if not is_moderator_or_above(user):
|
||||
self._last_error_code = self.ERROR_CODE_PERMISSION_DENIED_ROLE
|
||||
return False
|
||||
elif self.requires_moderator and not is_moderator_or_above(user):
|
||||
self._last_error_code = self.ERROR_CODE_PERMISSION_DENIED_ROLE
|
||||
return False
|
||||
|
||||
# Apply custom check if provided
|
||||
if self.custom_check:
|
||||
if not self.custom_check(instance, user):
|
||||
self._last_error_code = self.ERROR_CODE_PERMISSION_DENIED_CUSTOM
|
||||
return False
|
||||
if self.custom_check and not self.custom_check(instance, user):
|
||||
self._last_error_code = self.ERROR_CODE_PERMISSION_DENIED_CUSTOM
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
@@ -162,7 +161,7 @@ class PermissionGuard:
|
||||
return "This transition requires special permissions"
|
||||
return "This transition is not allowed"
|
||||
|
||||
def get_required_roles(self) -> List[str]:
|
||||
def get_required_roles(self) -> list[str]:
|
||||
"""
|
||||
Return list of roles that would satisfy this guard.
|
||||
|
||||
@@ -207,10 +206,10 @@ class OwnershipGuard:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
owner_fields: Optional[List[str]] = None,
|
||||
owner_fields: list[str] | None = None,
|
||||
allow_moderator_override: bool = False,
|
||||
allow_admin_override: bool = False,
|
||||
error_message: Optional[str] = None,
|
||||
error_message: str | None = None,
|
||||
):
|
||||
"""
|
||||
Initialize ownership guard.
|
||||
@@ -225,10 +224,10 @@ class OwnershipGuard:
|
||||
self.allow_moderator_override = allow_moderator_override
|
||||
self.allow_admin_override = allow_admin_override
|
||||
self._custom_error_message = error_message
|
||||
self._last_error_code: Optional[str] = None
|
||||
self._last_error_code: str | None = None
|
||||
|
||||
@property
|
||||
def error_code(self) -> Optional[str]:
|
||||
def error_code(self) -> str | None:
|
||||
"""Return the error code from the last failed check."""
|
||||
return self._last_error_code
|
||||
|
||||
@@ -305,10 +304,10 @@ class AssignmentGuard:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
assignment_fields: Optional[List[str]] = None,
|
||||
assignment_fields: list[str] | None = None,
|
||||
require_assignment: bool = False,
|
||||
allow_admin_override: bool = False,
|
||||
error_message: Optional[str] = None,
|
||||
error_message: str | None = None,
|
||||
):
|
||||
"""
|
||||
Initialize assignment guard.
|
||||
@@ -323,10 +322,10 @@ class AssignmentGuard:
|
||||
self.require_assignment = require_assignment
|
||||
self.allow_admin_override = allow_admin_override
|
||||
self._custom_error_message = error_message
|
||||
self._last_error_code: Optional[str] = None
|
||||
self._last_error_code: str | None = None
|
||||
|
||||
@property
|
||||
def error_code(self) -> Optional[str]:
|
||||
def error_code(self) -> str | None:
|
||||
"""Return the error code from the last failed check."""
|
||||
return self._last_error_code
|
||||
|
||||
@@ -399,10 +398,10 @@ class StateGuard:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
allowed_states: Optional[List[str]] = None,
|
||||
blocked_states: Optional[List[str]] = None,
|
||||
allowed_states: list[str] | None = None,
|
||||
blocked_states: list[str] | None = None,
|
||||
state_field: str = "status",
|
||||
error_message: Optional[str] = None,
|
||||
error_message: str | None = None,
|
||||
):
|
||||
"""
|
||||
Initialize state guard.
|
||||
@@ -417,11 +416,11 @@ class StateGuard:
|
||||
self.blocked_states = blocked_states or []
|
||||
self.state_field = state_field
|
||||
self._custom_error_message = error_message
|
||||
self._last_error_code: Optional[str] = None
|
||||
self._current_state: Optional[str] = None
|
||||
self._last_error_code: str | None = None
|
||||
self._current_state: str | None = None
|
||||
|
||||
@property
|
||||
def error_code(self) -> Optional[str]:
|
||||
def error_code(self) -> str | None:
|
||||
"""Return the error code from the last failed check."""
|
||||
return self._last_error_code
|
||||
|
||||
@@ -445,10 +444,9 @@ class StateGuard:
|
||||
return False
|
||||
|
||||
# Check allowed states if specified
|
||||
if self.allowed_states is not None:
|
||||
if self._current_state not in self.allowed_states:
|
||||
self._last_error_code = self.ERROR_CODE_INVALID_STATE
|
||||
return False
|
||||
if self.allowed_states is not None and self._current_state not in self.allowed_states:
|
||||
self._last_error_code = self.ERROR_CODE_INVALID_STATE
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
@@ -485,9 +483,9 @@ class MetadataGuard:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
required_fields: Optional[List[str]] = None,
|
||||
required_fields: list[str] | None = None,
|
||||
check_not_empty: bool = True,
|
||||
error_message: Optional[str] = None,
|
||||
error_message: str | None = None,
|
||||
):
|
||||
"""
|
||||
Initialize metadata guard.
|
||||
@@ -500,11 +498,11 @@ class MetadataGuard:
|
||||
self.required_fields = required_fields or []
|
||||
self.check_not_empty = check_not_empty
|
||||
self._custom_error_message = error_message
|
||||
self._last_error_code: Optional[str] = None
|
||||
self._failed_field: Optional[str] = None
|
||||
self._last_error_code: str | None = None
|
||||
self._failed_field: str | None = None
|
||||
|
||||
@property
|
||||
def error_code(self) -> Optional[str]:
|
||||
def error_code(self) -> str | None:
|
||||
"""Return the error code from the last failed check."""
|
||||
return self._last_error_code
|
||||
|
||||
@@ -582,9 +580,9 @@ class CompositeGuard:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
guards: List[Callable],
|
||||
guards: list[Callable],
|
||||
operator: str = "AND",
|
||||
error_message: Optional[str] = None,
|
||||
error_message: str | None = None,
|
||||
):
|
||||
"""
|
||||
Initialize composite guard.
|
||||
@@ -597,11 +595,11 @@ class CompositeGuard:
|
||||
self.guards = guards
|
||||
self.operator = operator.upper()
|
||||
self._custom_error_message = error_message
|
||||
self._last_error_code: Optional[str] = None
|
||||
self._failed_guards: List[Callable] = []
|
||||
self._last_error_code: str | None = None
|
||||
self._failed_guards: list[Callable] = []
|
||||
|
||||
@property
|
||||
def error_code(self) -> Optional[str]:
|
||||
def error_code(self) -> str | None:
|
||||
"""Return the error code from the last failed check."""
|
||||
return self._last_error_code
|
||||
|
||||
@@ -658,7 +656,7 @@ class CompositeGuard:
|
||||
|
||||
|
||||
def create_ownership_guard(
|
||||
owner_fields: Optional[List[str]] = None,
|
||||
owner_fields: list[str] | None = None,
|
||||
allow_moderator_override: bool = False,
|
||||
allow_admin_override: bool = False,
|
||||
) -> OwnershipGuard:
|
||||
@@ -681,7 +679,7 @@ def create_ownership_guard(
|
||||
|
||||
|
||||
def create_assignment_guard(
|
||||
assignment_fields: Optional[List[str]] = None,
|
||||
assignment_fields: list[str] | None = None,
|
||||
require_assignment: bool = False,
|
||||
allow_admin_override: bool = False,
|
||||
) -> AssignmentGuard:
|
||||
@@ -704,7 +702,7 @@ def create_assignment_guard(
|
||||
|
||||
|
||||
def create_composite_guard(
|
||||
guards: List[Callable],
|
||||
guards: list[Callable],
|
||||
operator: str = "AND",
|
||||
) -> CompositeGuard:
|
||||
"""
|
||||
@@ -728,7 +726,7 @@ ESCALATION_LEVEL_ROLES = {
|
||||
}
|
||||
|
||||
|
||||
def extract_guards_from_metadata(metadata: Dict[str, Any]) -> List[Callable]:
|
||||
def extract_guards_from_metadata(metadata: dict[str, Any]) -> list[Callable]:
|
||||
"""
|
||||
Convert RichChoice metadata to guard functions.
|
||||
|
||||
@@ -823,7 +821,7 @@ def extract_guards_from_metadata(metadata: Dict[str, Any]) -> List[Callable]:
|
||||
return guards
|
||||
|
||||
|
||||
def create_permission_guard(metadata: Dict[str, Any]) -> PermissionGuard:
|
||||
def create_permission_guard(metadata: dict[str, Any]) -> PermissionGuard:
|
||||
"""
|
||||
Create a permission guard from RichChoice metadata.
|
||||
|
||||
@@ -874,7 +872,7 @@ def create_permission_guard(metadata: Dict[str, Any]) -> PermissionGuard:
|
||||
)
|
||||
|
||||
|
||||
def validate_guard_metadata(metadata: Dict[str, Any]) -> Tuple[bool, List[str]]:
|
||||
def validate_guard_metadata(metadata: dict[str, Any]) -> tuple[bool, list[str]]:
|
||||
"""
|
||||
Validate that metadata contains valid guard configuration.
|
||||
|
||||
@@ -892,12 +890,11 @@ def validate_guard_metadata(metadata: Dict[str, Any]) -> Tuple[bool, List[str]]:
|
||||
|
||||
# Validate escalation_level
|
||||
escalation_level = metadata.get("escalation_level")
|
||||
if escalation_level:
|
||||
if escalation_level.lower() not in ESCALATION_LEVEL_ROLES:
|
||||
errors.append(
|
||||
f"Invalid escalation_level: {escalation_level}. "
|
||||
f"Must be one of: {', '.join(ESCALATION_LEVEL_ROLES.keys())}"
|
||||
)
|
||||
if escalation_level and escalation_level.lower() not in ESCALATION_LEVEL_ROLES:
|
||||
errors.append(
|
||||
f"Invalid escalation_level: {escalation_level}. "
|
||||
f"Must be one of: {', '.join(ESCALATION_LEVEL_ROLES.keys())}"
|
||||
)
|
||||
|
||||
# Validate required_permissions is a list
|
||||
required_permissions = metadata.get("required_permissions")
|
||||
@@ -927,7 +924,7 @@ class GuardRegistry:
|
||||
"""Registry for storing and retrieving guard functions."""
|
||||
|
||||
_instance: Optional["GuardRegistry"] = None
|
||||
_guards: Dict[str, Callable]
|
||||
_guards: dict[str, Callable]
|
||||
|
||||
def __new__(cls):
|
||||
"""Implement singleton pattern."""
|
||||
@@ -946,7 +943,7 @@ class GuardRegistry:
|
||||
"""
|
||||
self._guards[name] = guard
|
||||
|
||||
def get_guard(self, name: str) -> Optional[Callable]:
|
||||
def get_guard(self, name: str) -> Callable | None:
|
||||
"""
|
||||
Retrieve a guard by name.
|
||||
|
||||
@@ -961,9 +958,9 @@ class GuardRegistry:
|
||||
def apply_guards(
|
||||
self,
|
||||
instance: Any,
|
||||
guards: List[Callable],
|
||||
user: Optional[Any] = None,
|
||||
) -> Tuple[bool, Optional[str]]:
|
||||
guards: list[Callable],
|
||||
user: Any | None = None,
|
||||
) -> tuple[bool, str | None]:
|
||||
"""
|
||||
Apply multiple guards.
|
||||
|
||||
@@ -993,8 +990,8 @@ class GuardRegistry:
|
||||
|
||||
|
||||
def create_condition_from_metadata(
|
||||
metadata: Dict[str, Any],
|
||||
) -> Optional[Callable]:
|
||||
metadata: dict[str, Any],
|
||||
) -> Callable | None:
|
||||
"""
|
||||
Create FSM condition from metadata.
|
||||
|
||||
@@ -1011,10 +1008,7 @@ def create_condition_from_metadata(
|
||||
|
||||
def combined_condition(instance, user=None):
|
||||
"""Combined condition from all guards."""
|
||||
for guard in guards:
|
||||
if not guard(instance, user):
|
||||
return False
|
||||
return True
|
||||
return all(guard(instance, user) for guard in guards)
|
||||
|
||||
return combined_condition
|
||||
|
||||
@@ -1022,7 +1016,7 @@ def create_condition_from_metadata(
|
||||
# Helper functions for permission checks
|
||||
|
||||
|
||||
def get_user_role(user: Any) -> Optional[str]:
|
||||
def get_user_role(user: Any) -> str | None:
|
||||
"""
|
||||
Get the user's role from the role field.
|
||||
|
||||
@@ -1043,7 +1037,7 @@ def get_user_role(user: Any) -> Optional[str]:
|
||||
return None
|
||||
|
||||
|
||||
def has_role(user: Any, required_roles: List[str]) -> bool:
|
||||
def has_role(user: Any, required_roles: list[str]) -> bool:
|
||||
"""
|
||||
Check if user has one of the required roles.
|
||||
|
||||
@@ -1083,9 +1077,8 @@ def has_role(user: Any, required_roles: List[str]) -> bool:
|
||||
return True
|
||||
|
||||
# Check for staff status (treat as moderator)
|
||||
if hasattr(user, "is_staff") and user.is_staff:
|
||||
if "MODERATOR" in required_roles:
|
||||
return True
|
||||
if hasattr(user, "is_staff") and user.is_staff and "MODERATOR" in required_roles:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@@ -1191,7 +1184,7 @@ def has_permission(user: Any, permission: str) -> bool:
|
||||
|
||||
def create_guard_from_drf_permission(
|
||||
permission_class: type,
|
||||
error_message: Optional[str] = None,
|
||||
error_message: str | None = None,
|
||||
) -> Callable:
|
||||
"""
|
||||
Create an FSM guard from a DRF permission class.
|
||||
@@ -1225,13 +1218,13 @@ def create_guard_from_drf_permission(
|
||||
class DRFPermissionGuard:
|
||||
"""Guard that wraps a DRF permission class."""
|
||||
|
||||
def __init__(self, perm_class: type, err_msg: Optional[str] = None):
|
||||
def __init__(self, perm_class: type, err_msg: str | None = None):
|
||||
self.permission_class = perm_class
|
||||
self._custom_error_message = err_msg
|
||||
self._last_error_code: Optional[str] = None
|
||||
self._last_error_code: str | None = None
|
||||
|
||||
@property
|
||||
def error_code(self) -> Optional[str]:
|
||||
def error_code(self) -> str | None:
|
||||
return self._last_error_code
|
||||
|
||||
def __call__(self, instance: Any, user: Any = None) -> bool:
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Model integration utilities for applying state machines to Django models."""
|
||||
from typing import Type, Optional, Dict, Any, List, Callable
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from django.db import models
|
||||
from django_fsm import can_proceed
|
||||
@@ -8,23 +9,21 @@ from apps.core.state_machine.builder import (
|
||||
StateTransitionBuilder,
|
||||
determine_method_name_for_transition,
|
||||
)
|
||||
from apps.core.state_machine.decorators import TransitionMethodFactory
|
||||
from apps.core.state_machine.guards import (
|
||||
CompositeGuard,
|
||||
create_guard_from_drf_permission,
|
||||
extract_guards_from_metadata,
|
||||
)
|
||||
from apps.core.state_machine.registry import (
|
||||
TransitionInfo,
|
||||
registry_instance,
|
||||
)
|
||||
from apps.core.state_machine.validators import MetadataValidator
|
||||
from apps.core.state_machine.decorators import TransitionMethodFactory
|
||||
from apps.core.state_machine.guards import (
|
||||
create_permission_guard,
|
||||
extract_guards_from_metadata,
|
||||
create_condition_from_metadata,
|
||||
create_guard_from_drf_permission,
|
||||
CompositeGuard,
|
||||
)
|
||||
|
||||
|
||||
def apply_state_machine(
|
||||
model_class: Type[models.Model],
|
||||
model_class: type[models.Model],
|
||||
field_name: str,
|
||||
choice_group: str,
|
||||
domain: str = "core",
|
||||
@@ -48,7 +47,7 @@ def apply_state_machine(
|
||||
if not result.is_valid:
|
||||
error_messages = [str(e) for e in result.errors]
|
||||
raise ValueError(
|
||||
f"Cannot apply state machine - validation failed:\n"
|
||||
"Cannot apply state machine - validation failed:\n"
|
||||
+ "\n".join(error_messages)
|
||||
)
|
||||
|
||||
@@ -62,7 +61,7 @@ def apply_state_machine(
|
||||
|
||||
|
||||
def generate_transition_methods_for_model(
|
||||
model_class: Type[models.Model],
|
||||
model_class: type[models.Model],
|
||||
field_name: str,
|
||||
choice_group: str,
|
||||
domain: str = "core",
|
||||
@@ -100,7 +99,7 @@ def generate_transition_methods_for_model(
|
||||
all_guards = guards + target_guards
|
||||
|
||||
# Create combined guard if we have multiple guards
|
||||
combined_guard: Optional[Callable] = None
|
||||
combined_guard: Callable | None = None
|
||||
if len(all_guards) == 1:
|
||||
combined_guard = all_guards[0]
|
||||
elif len(all_guards) > 1:
|
||||
@@ -149,7 +148,7 @@ class StateMachineModelMixin:
|
||||
|
||||
def get_available_state_transitions(
|
||||
self, field_name: str = "status"
|
||||
) -> List[TransitionInfo]:
|
||||
) -> list[TransitionInfo]:
|
||||
"""
|
||||
Get available transitions from current state.
|
||||
|
||||
@@ -176,7 +175,7 @@ class StateMachineModelMixin:
|
||||
self,
|
||||
target_state: str,
|
||||
field_name: str = "status",
|
||||
user: Optional[Any] = None,
|
||||
user: Any | None = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Check if transition to target state is allowed.
|
||||
@@ -219,7 +218,7 @@ class StateMachineModelMixin:
|
||||
|
||||
def get_transition_method(
|
||||
self, target_state: str, field_name: str = "status"
|
||||
) -> Optional[Callable]:
|
||||
) -> Callable | None:
|
||||
"""
|
||||
Get the transition method for moving to target state.
|
||||
|
||||
@@ -252,7 +251,7 @@ class StateMachineModelMixin:
|
||||
self,
|
||||
target_state: str,
|
||||
field_name: str = "status",
|
||||
user: Optional[Any] = None,
|
||||
user: Any | None = None,
|
||||
**kwargs: Any,
|
||||
) -> bool:
|
||||
"""
|
||||
@@ -299,7 +298,7 @@ def state_machine_model(
|
||||
Decorator function
|
||||
"""
|
||||
|
||||
def decorator(model_class: Type[models.Model]) -> Type[models.Model]:
|
||||
def decorator(model_class: type[models.Model]) -> type[models.Model]:
|
||||
"""Apply state machine to model class."""
|
||||
apply_state_machine(model_class, field_name, choice_group, domain)
|
||||
return model_class
|
||||
@@ -308,7 +307,7 @@ def state_machine_model(
|
||||
|
||||
|
||||
def validate_model_state_machine(
|
||||
model_class: Type[models.Model], field_name: str
|
||||
model_class: type[models.Model], field_name: str
|
||||
) -> bool:
|
||||
"""
|
||||
Ensure model is properly configured with state machine.
|
||||
@@ -345,7 +344,7 @@ def validate_model_state_machine(
|
||||
if not result.is_valid:
|
||||
error_messages = [str(e) for e in result.errors]
|
||||
raise ValueError(
|
||||
f"State machine validation failed:\n" + "\n".join(error_messages)
|
||||
"State machine validation failed:\n" + "\n".join(error_messages)
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
@@ -38,12 +38,12 @@ See Also:
|
||||
- apps.core.state_machine.fields.RichFSMField: The FSM field implementation
|
||||
- django_fsm.can_proceed: FSM transition checking utility
|
||||
"""
|
||||
from typing import Any, Dict, Iterable, List, Optional
|
||||
from collections.abc import Iterable
|
||||
from typing import Any
|
||||
|
||||
from django.db import models
|
||||
from django_fsm import can_proceed
|
||||
|
||||
|
||||
# Default transition metadata for styling
|
||||
TRANSITION_METADATA = {
|
||||
# Approval transitions
|
||||
@@ -71,7 +71,7 @@ TRANSITION_METADATA = {
|
||||
}
|
||||
|
||||
|
||||
def _get_transition_metadata(transition_name: str) -> Dict[str, Any]:
|
||||
def _get_transition_metadata(transition_name: str) -> dict[str, Any]:
|
||||
"""Get metadata for a transition by name."""
|
||||
if transition_name in TRANSITION_METADATA:
|
||||
return TRANSITION_METADATA[transition_name].copy()
|
||||
@@ -161,12 +161,12 @@ class StateMachineMixin(models.Model):
|
||||
class Meta:
|
||||
abstract = True
|
||||
|
||||
def get_state_value(self, field_name: Optional[str] = None) -> Any:
|
||||
def get_state_value(self, field_name: str | None = None) -> Any:
|
||||
"""Return the raw state value for the given field (default is `state`)."""
|
||||
name = field_name or self.state_field_name
|
||||
return getattr(self, name, None)
|
||||
|
||||
def get_state_display_value(self, field_name: Optional[str] = None) -> str:
|
||||
def get_state_display_value(self, field_name: str | None = None) -> str:
|
||||
"""Return the display label for the current state, if available."""
|
||||
name = field_name or self.state_field_name
|
||||
getter = getattr(self, f"get_{name}_display", None)
|
||||
@@ -175,7 +175,7 @@ class StateMachineMixin(models.Model):
|
||||
value = getattr(self, name, "")
|
||||
return value if value is not None else ""
|
||||
|
||||
def get_state_choice(self, field_name: Optional[str] = None):
|
||||
def get_state_choice(self, field_name: str | None = None):
|
||||
"""Return the RichChoice object when the field provides one."""
|
||||
name = field_name or self.state_field_name
|
||||
getter = getattr(self, f"get_{name}_rich_choice", None)
|
||||
@@ -193,7 +193,7 @@ class StateMachineMixin(models.Model):
|
||||
return can_proceed(method)
|
||||
|
||||
def get_available_transitions(
|
||||
self, field_name: Optional[str] = None
|
||||
self, field_name: str | None = None
|
||||
) -> Iterable[Any]:
|
||||
"""Return available transitions when helpers are present."""
|
||||
name = field_name or self.state_field_name
|
||||
@@ -203,12 +203,12 @@ class StateMachineMixin(models.Model):
|
||||
return helper() # type: ignore[misc]
|
||||
return []
|
||||
|
||||
def is_in_state(self, state: str, field_name: Optional[str] = None) -> bool:
|
||||
def is_in_state(self, state: str, field_name: str | None = None) -> bool:
|
||||
"""Convenience check for comparing the current state."""
|
||||
current_state = self.get_state_value(field_name)
|
||||
return current_state == state
|
||||
|
||||
def get_available_user_transitions(self, user) -> List[Dict[str, Any]]:
|
||||
def get_available_user_transitions(self, user) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Get transitions available to the given user.
|
||||
|
||||
|
||||
@@ -5,20 +5,18 @@ This module provides tools for monitoring callback execution,
|
||||
tracking performance, and debugging transition issues.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type
|
||||
from collections import defaultdict
|
||||
import logging
|
||||
import time
|
||||
import threading
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional
|
||||
|
||||
from django.conf import settings
|
||||
from django.db import models
|
||||
|
||||
from .callback_base import TransitionContext
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -35,9 +33,9 @@ class CallbackExecutionRecord:
|
||||
timestamp: datetime
|
||||
duration_ms: float
|
||||
success: bool
|
||||
error_message: Optional[str] = None
|
||||
instance_id: Optional[int] = None
|
||||
user_id: Optional[int] = None
|
||||
error_message: str | None = None
|
||||
instance_id: int | None = None
|
||||
user_id: int | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -51,8 +49,8 @@ class CallbackStats:
|
||||
total_duration_ms: float = 0.0
|
||||
min_duration_ms: float = float('inf')
|
||||
max_duration_ms: float = 0.0
|
||||
last_execution: Optional[datetime] = None
|
||||
last_error: Optional[str] = None
|
||||
last_execution: datetime | None = None
|
||||
last_error: str | None = None
|
||||
|
||||
@property
|
||||
def avg_duration_ms(self) -> float:
|
||||
@@ -72,7 +70,7 @@ class CallbackStats:
|
||||
self,
|
||||
duration_ms: float,
|
||||
success: bool,
|
||||
error_message: Optional[str] = None,
|
||||
error_message: str | None = None,
|
||||
) -> None:
|
||||
"""Record a callback execution."""
|
||||
self.total_executions += 1
|
||||
@@ -114,10 +112,10 @@ class CallbackMonitor:
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
self._stats: Dict[str, CallbackStats] = defaultdict(
|
||||
self._stats: dict[str, CallbackStats] = defaultdict(
|
||||
lambda: CallbackStats(callback_name="")
|
||||
)
|
||||
self._recent_executions: List[CallbackExecutionRecord] = []
|
||||
self._recent_executions: list[CallbackExecutionRecord] = []
|
||||
self._max_recent_records = 1000
|
||||
self._enabled = self._check_enabled()
|
||||
self._debug_mode = self._check_debug_mode()
|
||||
@@ -159,7 +157,7 @@ class CallbackMonitor:
|
||||
stage: str,
|
||||
duration_ms: float,
|
||||
success: bool,
|
||||
error_message: Optional[str] = None,
|
||||
error_message: str | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Record a callback execution.
|
||||
@@ -220,7 +218,7 @@ class CallbackMonitor:
|
||||
else:
|
||||
logger.warning(f"{log_message} - Error: {record.error_message}")
|
||||
|
||||
def get_stats(self, callback_name: Optional[str] = None) -> Dict[str, CallbackStats]:
|
||||
def get_stats(self, callback_name: str | None = None) -> dict[str, CallbackStats]:
|
||||
"""
|
||||
Get callback statistics.
|
||||
|
||||
@@ -239,10 +237,10 @@ class CallbackMonitor:
|
||||
def get_recent_executions(
|
||||
self,
|
||||
limit: int = 100,
|
||||
callback_name: Optional[str] = None,
|
||||
model_name: Optional[str] = None,
|
||||
success_only: Optional[bool] = None,
|
||||
) -> List[CallbackExecutionRecord]:
|
||||
callback_name: str | None = None,
|
||||
model_name: str | None = None,
|
||||
success_only: bool | None = None,
|
||||
) -> list[CallbackExecutionRecord]:
|
||||
"""
|
||||
Get recent execution records.
|
||||
|
||||
@@ -268,12 +266,12 @@ class CallbackMonitor:
|
||||
# Return most recent first
|
||||
return list(reversed(records[-limit:]))
|
||||
|
||||
def get_failure_summary(self) -> Dict[str, Any]:
|
||||
def get_failure_summary(self) -> dict[str, Any]:
|
||||
"""Get a summary of callback failures."""
|
||||
failures = [r for r in self._recent_executions if not r.success]
|
||||
|
||||
# Group by callback
|
||||
by_callback: Dict[str, List[CallbackExecutionRecord]] = defaultdict(list)
|
||||
by_callback: dict[str, list[CallbackExecutionRecord]] = defaultdict(list)
|
||||
for record in failures:
|
||||
by_callback[record.callback_name].append(record)
|
||||
|
||||
@@ -292,7 +290,7 @@ class CallbackMonitor:
|
||||
|
||||
return summary
|
||||
|
||||
def get_performance_report(self) -> Dict[str, Any]:
|
||||
def get_performance_report(self) -> dict[str, Any]:
|
||||
"""Get a performance report for all callbacks."""
|
||||
report = {
|
||||
'callbacks': {},
|
||||
@@ -361,7 +359,7 @@ class TimedCallbackExecution:
|
||||
self.stage = stage
|
||||
self.start_time = 0.0
|
||||
self.success = True
|
||||
self.error_message: Optional[str] = None
|
||||
self.error_message: str | None = None
|
||||
|
||||
def __enter__(self) -> 'TimedCallbackExecution':
|
||||
self.start_time = time.perf_counter()
|
||||
@@ -419,7 +417,7 @@ def get_callback_execution_order(
|
||||
model_name: str,
|
||||
source: str,
|
||||
target: str,
|
||||
) -> List[Tuple[str, str, int]]:
|
||||
) -> list[tuple[str, str, int]]:
|
||||
"""
|
||||
Get the order of callback execution for a transition.
|
||||
|
||||
@@ -431,7 +429,7 @@ def get_callback_execution_order(
|
||||
Returns:
|
||||
List of (stage, callback_name, priority) tuples in execution order.
|
||||
"""
|
||||
from .callback_base import callback_registry, CallbackStage
|
||||
from .callback_base import CallbackStage
|
||||
|
||||
order = []
|
||||
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
"""TransitionRegistry - Centralized registry for managing FSM transitions."""
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Callable, Dict, List, Optional, Any, Tuple, Type
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Optional
|
||||
|
||||
from django.db import models
|
||||
|
||||
from apps.core.state_machine.builder import StateTransitionBuilder
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ class TransitionInfo:
|
||||
method_name: str
|
||||
requires_moderator: bool = False
|
||||
requires_admin_approval: bool = False
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def __hash__(self):
|
||||
"""Make TransitionInfo hashable."""
|
||||
@@ -31,7 +31,7 @@ class TransitionRegistry:
|
||||
"""Centralized registry for managing and looking up FSM transitions."""
|
||||
|
||||
_instance: Optional["TransitionRegistry"] = None
|
||||
_transitions: Dict[Tuple[str, str], Dict[Tuple[str, str], TransitionInfo]]
|
||||
_transitions: dict[tuple[str, str], dict[tuple[str, str], TransitionInfo]]
|
||||
|
||||
def __new__(cls):
|
||||
"""Implement singleton pattern."""
|
||||
@@ -40,7 +40,7 @@ class TransitionRegistry:
|
||||
cls._instance._transitions = {}
|
||||
return cls._instance
|
||||
|
||||
def _get_key(self, choice_group: str, domain: str) -> Tuple[str, str]:
|
||||
def _get_key(self, choice_group: str, domain: str) -> tuple[str, str]:
|
||||
"""Generate registry key from choice group and domain."""
|
||||
return (domain, choice_group)
|
||||
|
||||
@@ -51,7 +51,7 @@ class TransitionRegistry:
|
||||
source: str,
|
||||
target: str,
|
||||
method_name: str,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> TransitionInfo:
|
||||
"""
|
||||
Register a transition.
|
||||
@@ -88,7 +88,7 @@ class TransitionRegistry:
|
||||
|
||||
def get_transition(
|
||||
self, choice_group: str, domain: str, source: str, target: str
|
||||
) -> Optional[TransitionInfo]:
|
||||
) -> TransitionInfo | None:
|
||||
"""
|
||||
Retrieve transition info.
|
||||
|
||||
@@ -111,7 +111,7 @@ class TransitionRegistry:
|
||||
|
||||
def get_available_transitions(
|
||||
self, choice_group: str, domain: str, current_state: str
|
||||
) -> List[TransitionInfo]:
|
||||
) -> list[TransitionInfo]:
|
||||
"""
|
||||
Get all valid transitions from a state.
|
||||
|
||||
@@ -129,7 +129,7 @@ class TransitionRegistry:
|
||||
return []
|
||||
|
||||
available = []
|
||||
for (source, target), info in self._transitions[key].items():
|
||||
for (source, _target), info in self._transitions[key].items():
|
||||
if source == current_state:
|
||||
available.append(info)
|
||||
|
||||
@@ -137,7 +137,7 @@ class TransitionRegistry:
|
||||
|
||||
def get_transition_method_name(
|
||||
self, choice_group: str, domain: str, source: str, target: str
|
||||
) -> Optional[str]:
|
||||
) -> str | None:
|
||||
"""
|
||||
Get the method name for a transition.
|
||||
|
||||
@@ -209,8 +209,8 @@ class TransitionRegistry:
|
||||
|
||||
def clear_registry(
|
||||
self,
|
||||
choice_group: Optional[str] = None,
|
||||
domain: Optional[str] = None,
|
||||
choice_group: str | None = None,
|
||||
domain: str | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Clear registry entries for testing.
|
||||
@@ -246,7 +246,7 @@ class TransitionRegistry:
|
||||
return {} if format == "dict" else ""
|
||||
|
||||
if format == "dict":
|
||||
graph: Dict[str, List[str]] = {}
|
||||
graph: dict[str, list[str]] = {}
|
||||
for (source, target), info in self._transitions[key].items():
|
||||
if source not in graph:
|
||||
graph[source] = []
|
||||
@@ -272,7 +272,7 @@ class TransitionRegistry:
|
||||
else:
|
||||
raise ValueError(f"Unsupported format: {format}")
|
||||
|
||||
def get_all_registered_groups(self) -> List[Tuple[str, str]]:
|
||||
def get_all_registered_groups(self) -> list[tuple[str, str]]:
|
||||
"""
|
||||
Get all registered choice groups.
|
||||
|
||||
@@ -289,7 +289,7 @@ registry_instance = TransitionRegistry()
|
||||
# Callback registration helpers
|
||||
|
||||
def register_callback(
|
||||
model_class: Type[models.Model],
|
||||
model_class: type[models.Model],
|
||||
field_name: str,
|
||||
source: str,
|
||||
target: str,
|
||||
@@ -307,7 +307,7 @@ def register_callback(
|
||||
callback: The callback instance.
|
||||
stage: When to execute ('pre', 'post', 'error').
|
||||
"""
|
||||
from .callback_base import callback_registry, CallbackStage
|
||||
from .callback_base import CallbackStage, callback_registry
|
||||
|
||||
callback_registry.register(
|
||||
model_class=model_class,
|
||||
@@ -320,7 +320,7 @@ def register_callback(
|
||||
|
||||
|
||||
def register_notification_callback(
|
||||
model_class: Type[models.Model],
|
||||
model_class: type[models.Model],
|
||||
field_name: str,
|
||||
source: str,
|
||||
target: str,
|
||||
@@ -348,9 +348,9 @@ def register_notification_callback(
|
||||
|
||||
|
||||
def register_cache_invalidation(
|
||||
model_class: Type[models.Model],
|
||||
model_class: type[models.Model],
|
||||
field_name: str,
|
||||
cache_patterns: Optional[List[str]] = None,
|
||||
cache_patterns: list[str] | None = None,
|
||||
source: str = '*',
|
||||
target: str = '*',
|
||||
) -> None:
|
||||
@@ -371,7 +371,7 @@ def register_cache_invalidation(
|
||||
|
||||
|
||||
def register_related_update(
|
||||
model_class: Type[models.Model],
|
||||
model_class: type[models.Model],
|
||||
field_name: str,
|
||||
update_func: Callable,
|
||||
source: str = '*',
|
||||
@@ -393,7 +393,7 @@ def register_related_update(
|
||||
register_callback(model_class, field_name, source, target, callback, 'post')
|
||||
|
||||
|
||||
def register_transition_callbacks(cls: Type[models.Model]) -> Type[models.Model]:
|
||||
def register_transition_callbacks(cls: type[models.Model]) -> type[models.Model]:
|
||||
"""
|
||||
Class decorator to auto-register callbacks from model's Meta.
|
||||
|
||||
|
||||
@@ -5,15 +5,12 @@ This module defines custom Django signals emitted during state machine
|
||||
transitions and provides utilities for connecting signal handlers.
|
||||
"""
|
||||
|
||||
from typing import Any, Callable, Dict, List, Optional, Type, Union
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
|
||||
from django.db import models
|
||||
from django.dispatch import Signal, receiver
|
||||
|
||||
from .callback_base import TransitionContext
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -69,11 +66,11 @@ class TransitionSignalHandler:
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._handlers: Dict[str, List[Callable]] = {}
|
||||
self._handlers: dict[str, list[Callable]] = {}
|
||||
|
||||
def register(
|
||||
self,
|
||||
model_class: Type[models.Model],
|
||||
model_class: type[models.Model],
|
||||
source: str,
|
||||
target: str,
|
||||
handler: Callable,
|
||||
@@ -105,7 +102,7 @@ class TransitionSignalHandler:
|
||||
|
||||
def unregister(
|
||||
self,
|
||||
model_class: Type[models.Model],
|
||||
model_class: type[models.Model],
|
||||
source: str,
|
||||
target: str,
|
||||
handler: Callable,
|
||||
@@ -121,7 +118,7 @@ class TransitionSignalHandler:
|
||||
|
||||
def _make_key(
|
||||
self,
|
||||
model_class: Type[models.Model],
|
||||
model_class: type[models.Model],
|
||||
source: str,
|
||||
target: str,
|
||||
stage: str,
|
||||
@@ -140,7 +137,7 @@ class TransitionSignalHandler:
|
||||
def _connect_signal(
|
||||
self,
|
||||
signal: Signal,
|
||||
model_class: Type[models.Model],
|
||||
model_class: type[models.Model],
|
||||
source: str,
|
||||
target: str,
|
||||
handler: Callable,
|
||||
@@ -173,7 +170,7 @@ transition_signal_handler = TransitionSignalHandler()
|
||||
|
||||
|
||||
def register_transition_handler(
|
||||
model_class: Type[models.Model],
|
||||
model_class: type[models.Model],
|
||||
source: str,
|
||||
target: str,
|
||||
handler: Callable,
|
||||
@@ -233,7 +230,7 @@ class TransitionHandlerDecorator:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_class: Type[models.Model],
|
||||
model_class: type[models.Model],
|
||||
source: str = '*',
|
||||
target: str = '*',
|
||||
stage: str = 'post',
|
||||
@@ -265,7 +262,7 @@ class TransitionHandlerDecorator:
|
||||
|
||||
|
||||
def on_transition(
|
||||
model_class: Type[models.Model],
|
||||
model_class: type[models.Model],
|
||||
source: str = '*',
|
||||
target: str = '*',
|
||||
stage: str = 'post',
|
||||
@@ -291,7 +288,7 @@ def on_transition(
|
||||
|
||||
|
||||
def on_pre_transition(
|
||||
model_class: Type[models.Model],
|
||||
model_class: type[models.Model],
|
||||
source: str = '*',
|
||||
target: str = '*',
|
||||
) -> TransitionHandlerDecorator:
|
||||
@@ -300,7 +297,7 @@ def on_pre_transition(
|
||||
|
||||
|
||||
def on_post_transition(
|
||||
model_class: Type[models.Model],
|
||||
model_class: type[models.Model],
|
||||
source: str = '*',
|
||||
target: str = '*',
|
||||
) -> TransitionHandlerDecorator:
|
||||
@@ -309,7 +306,7 @@ def on_post_transition(
|
||||
|
||||
|
||||
def on_transition_error(
|
||||
model_class: Type[models.Model],
|
||||
model_class: type[models.Model],
|
||||
source: str = '*',
|
||||
target: str = '*',
|
||||
) -> TransitionHandlerDecorator:
|
||||
|
||||
@@ -7,43 +7,44 @@ This module provides reusable fixtures for creating test data:
|
||||
- Mock objects for testing guards and callbacks
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from django.contrib.auth import get_user_model
|
||||
from django.contrib.contenttypes.models import ContentType
|
||||
from typing import Optional, Any, Dict
|
||||
|
||||
User = get_user_model()
|
||||
|
||||
|
||||
class UserFactory:
|
||||
"""Factory for creating users with different roles."""
|
||||
|
||||
|
||||
_counter = 0
|
||||
|
||||
|
||||
@classmethod
|
||||
def _get_unique_id(cls) -> int:
|
||||
"""Get a unique counter for creating unique usernames."""
|
||||
cls._counter += 1
|
||||
return cls._counter
|
||||
|
||||
|
||||
@classmethod
|
||||
def create_user(
|
||||
cls,
|
||||
role: str = 'USER',
|
||||
username: Optional[str] = None,
|
||||
email: Optional[str] = None,
|
||||
username: str | None = None,
|
||||
email: str | None = None,
|
||||
password: str = 'testpass123',
|
||||
**kwargs
|
||||
) -> User:
|
||||
"""
|
||||
Create a user with specified role.
|
||||
|
||||
|
||||
Args:
|
||||
role: User role (USER, MODERATOR, ADMIN, SUPERUSER)
|
||||
username: Username (auto-generated if not provided)
|
||||
email: Email (auto-generated if not provided)
|
||||
password: Password for the user
|
||||
**kwargs: Additional user fields
|
||||
|
||||
|
||||
Returns:
|
||||
Created User instance
|
||||
"""
|
||||
@@ -52,7 +53,7 @@ class UserFactory:
|
||||
username = f"user_{role.lower()}_{uid}"
|
||||
if email is None:
|
||||
email = f"{role.lower()}_{uid}@example.com"
|
||||
|
||||
|
||||
return User.objects.create_user(
|
||||
username=username,
|
||||
email=email,
|
||||
@@ -60,22 +61,22 @@ class UserFactory:
|
||||
role=role,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
|
||||
@classmethod
|
||||
def create_regular_user(cls, **kwargs) -> User:
|
||||
"""Create a regular user."""
|
||||
return cls.create_user(role='USER', **kwargs)
|
||||
|
||||
|
||||
@classmethod
|
||||
def create_moderator(cls, **kwargs) -> User:
|
||||
"""Create a moderator user."""
|
||||
return cls.create_user(role='MODERATOR', **kwargs)
|
||||
|
||||
|
||||
@classmethod
|
||||
def create_admin(cls, **kwargs) -> User:
|
||||
"""Create an admin user."""
|
||||
return cls.create_user(role='ADMIN', **kwargs)
|
||||
|
||||
|
||||
@classmethod
|
||||
def create_superuser(cls, **kwargs) -> User:
|
||||
"""Create a superuser."""
|
||||
@@ -84,23 +85,23 @@ class UserFactory:
|
||||
|
||||
class CompanyFactory:
|
||||
"""Factory for creating company instances."""
|
||||
|
||||
|
||||
_counter = 0
|
||||
|
||||
|
||||
@classmethod
|
||||
def _get_unique_id(cls) -> int:
|
||||
cls._counter += 1
|
||||
return cls._counter
|
||||
|
||||
|
||||
@classmethod
|
||||
def create_operator(cls, name: Optional[str] = None, **kwargs) -> Any:
|
||||
def create_operator(cls, name: str | None = None, **kwargs) -> Any:
|
||||
"""Create an operator company."""
|
||||
from apps.parks.models import Company
|
||||
|
||||
|
||||
uid = cls._get_unique_id()
|
||||
if name is None:
|
||||
name = f"Test Operator {uid}"
|
||||
|
||||
|
||||
defaults = {
|
||||
'name': name,
|
||||
'description': f'Test operator company {uid}',
|
||||
@@ -108,16 +109,16 @@ class CompanyFactory:
|
||||
}
|
||||
defaults.update(kwargs)
|
||||
return Company.objects.create(**defaults)
|
||||
|
||||
|
||||
@classmethod
|
||||
def create_manufacturer(cls, name: Optional[str] = None, **kwargs) -> Any:
|
||||
def create_manufacturer(cls, name: str | None = None, **kwargs) -> Any:
|
||||
"""Create a manufacturer company."""
|
||||
from apps.rides.models import Company
|
||||
|
||||
|
||||
uid = cls._get_unique_id()
|
||||
if name is None:
|
||||
name = f"Test Manufacturer {uid}"
|
||||
|
||||
|
||||
defaults = {
|
||||
'name': name,
|
||||
'description': f'Test manufacturer company {uid}',
|
||||
@@ -129,42 +130,42 @@ class CompanyFactory:
|
||||
|
||||
class ParkFactory:
|
||||
"""Factory for creating park instances."""
|
||||
|
||||
|
||||
_counter = 0
|
||||
|
||||
|
||||
@classmethod
|
||||
def _get_unique_id(cls) -> int:
|
||||
cls._counter += 1
|
||||
return cls._counter
|
||||
|
||||
|
||||
@classmethod
|
||||
def create_park(
|
||||
cls,
|
||||
name: Optional[str] = None,
|
||||
operator: Optional[Any] = None,
|
||||
name: str | None = None,
|
||||
operator: Any | None = None,
|
||||
status: str = 'OPERATING',
|
||||
**kwargs
|
||||
) -> Any:
|
||||
"""
|
||||
Create a park with specified status.
|
||||
|
||||
|
||||
Args:
|
||||
name: Park name (auto-generated if not provided)
|
||||
operator: Operator company (auto-created if not provided)
|
||||
status: Park status
|
||||
**kwargs: Additional park fields
|
||||
|
||||
|
||||
Returns:
|
||||
Created Park instance
|
||||
"""
|
||||
from apps.parks.models import Park
|
||||
|
||||
|
||||
uid = cls._get_unique_id()
|
||||
if name is None:
|
||||
name = f"Test Park {uid}"
|
||||
if operator is None:
|
||||
operator = CompanyFactory.create_operator()
|
||||
|
||||
|
||||
defaults = {
|
||||
'name': name,
|
||||
'slug': f'test-park-{uid}',
|
||||
@@ -179,38 +180,38 @@ class ParkFactory:
|
||||
|
||||
class RideFactory:
|
||||
"""Factory for creating ride instances."""
|
||||
|
||||
|
||||
_counter = 0
|
||||
|
||||
|
||||
@classmethod
|
||||
def _get_unique_id(cls) -> int:
|
||||
cls._counter += 1
|
||||
return cls._counter
|
||||
|
||||
|
||||
@classmethod
|
||||
def create_ride(
|
||||
cls,
|
||||
name: Optional[str] = None,
|
||||
park: Optional[Any] = None,
|
||||
manufacturer: Optional[Any] = None,
|
||||
name: str | None = None,
|
||||
park: Any | None = None,
|
||||
manufacturer: Any | None = None,
|
||||
status: str = 'OPERATING',
|
||||
**kwargs
|
||||
) -> Any:
|
||||
"""
|
||||
Create a ride with specified status.
|
||||
|
||||
|
||||
Args:
|
||||
name: Ride name (auto-generated if not provided)
|
||||
park: Park for the ride (auto-created if not provided)
|
||||
manufacturer: Manufacturer company (auto-created if not provided)
|
||||
status: Ride status
|
||||
**kwargs: Additional ride fields
|
||||
|
||||
|
||||
Returns:
|
||||
Created Ride instance
|
||||
"""
|
||||
from apps.rides.models import Ride
|
||||
|
||||
|
||||
uid = cls._get_unique_id()
|
||||
if name is None:
|
||||
name = f"Test Ride {uid}"
|
||||
@@ -218,7 +219,7 @@ class RideFactory:
|
||||
park = ParkFactory.create_park()
|
||||
if manufacturer is None:
|
||||
manufacturer = CompanyFactory.create_manufacturer()
|
||||
|
||||
|
||||
defaults = {
|
||||
'name': name,
|
||||
'slug': f'test-ride-{uid}',
|
||||
@@ -233,39 +234,39 @@ class RideFactory:
|
||||
|
||||
class EditSubmissionFactory:
|
||||
"""Factory for creating edit submission instances."""
|
||||
|
||||
|
||||
_counter = 0
|
||||
|
||||
|
||||
@classmethod
|
||||
def _get_unique_id(cls) -> int:
|
||||
cls._counter += 1
|
||||
return cls._counter
|
||||
|
||||
|
||||
@classmethod
|
||||
def create_submission(
|
||||
cls,
|
||||
user: Optional[Any] = None,
|
||||
target_object: Optional[Any] = None,
|
||||
user: Any | None = None,
|
||||
target_object: Any | None = None,
|
||||
status: str = 'PENDING',
|
||||
changes: Optional[Dict[str, Any]] = None,
|
||||
changes: dict[str, Any] | None = None,
|
||||
**kwargs
|
||||
) -> Any:
|
||||
"""
|
||||
Create an edit submission.
|
||||
|
||||
|
||||
Args:
|
||||
user: User who submitted (auto-created if not provided)
|
||||
target_object: Object being edited (auto-created if not provided)
|
||||
status: Submission status
|
||||
changes: Changes dictionary
|
||||
**kwargs: Additional fields
|
||||
|
||||
|
||||
Returns:
|
||||
Created EditSubmission instance
|
||||
"""
|
||||
from apps.moderation.models import EditSubmission
|
||||
from apps.parks.models import Company
|
||||
|
||||
|
||||
uid = cls._get_unique_id()
|
||||
if user is None:
|
||||
user = UserFactory.create_regular_user()
|
||||
@@ -276,9 +277,9 @@ class EditSubmissionFactory:
|
||||
)
|
||||
if changes is None:
|
||||
changes = {'name': f'Updated Name {uid}'}
|
||||
|
||||
|
||||
content_type = ContentType.objects.get_for_model(target_object)
|
||||
|
||||
|
||||
defaults = {
|
||||
'user': user,
|
||||
'content_type': content_type,
|
||||
@@ -294,37 +295,37 @@ class EditSubmissionFactory:
|
||||
|
||||
class ModerationReportFactory:
|
||||
"""Factory for creating moderation report instances."""
|
||||
|
||||
|
||||
_counter = 0
|
||||
|
||||
|
||||
@classmethod
|
||||
def _get_unique_id(cls) -> int:
|
||||
cls._counter += 1
|
||||
return cls._counter
|
||||
|
||||
|
||||
@classmethod
|
||||
def create_report(
|
||||
cls,
|
||||
reporter: Optional[Any] = None,
|
||||
target_object: Optional[Any] = None,
|
||||
reporter: Any | None = None,
|
||||
target_object: Any | None = None,
|
||||
status: str = 'PENDING',
|
||||
**kwargs
|
||||
) -> Any:
|
||||
"""
|
||||
Create a moderation report.
|
||||
|
||||
|
||||
Args:
|
||||
reporter: User who reported (auto-created if not provided)
|
||||
target_object: Object being reported (auto-created if not provided)
|
||||
status: Report status
|
||||
**kwargs: Additional fields
|
||||
|
||||
|
||||
Returns:
|
||||
Created ModerationReport instance
|
||||
"""
|
||||
from apps.moderation.models import ModerationReport
|
||||
from apps.parks.models import Company
|
||||
|
||||
|
||||
uid = cls._get_unique_id()
|
||||
if reporter is None:
|
||||
reporter = UserFactory.create_regular_user()
|
||||
@@ -333,9 +334,9 @@ class ModerationReportFactory:
|
||||
name=f'Reported Company {uid}',
|
||||
description='Test company'
|
||||
)
|
||||
|
||||
|
||||
content_type = ContentType.objects.get_for_model(target_object)
|
||||
|
||||
|
||||
defaults = {
|
||||
'report_type': 'CONTENT',
|
||||
'status': status,
|
||||
@@ -354,7 +355,7 @@ class ModerationReportFactory:
|
||||
class MockInstance:
|
||||
"""
|
||||
Mock instance for testing guards without database.
|
||||
|
||||
|
||||
Example:
|
||||
instance = MockInstance(
|
||||
status='PENDING',
|
||||
@@ -362,11 +363,11 @@ class MockInstance:
|
||||
assigned_to=moderator
|
||||
)
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
for key, value in kwargs.items():
|
||||
setattr(self, key, value)
|
||||
|
||||
|
||||
def __repr__(self):
|
||||
attrs = ', '.join(f'{k}={v!r}' for k, v in self.__dict__.items())
|
||||
return f'MockInstance({attrs})'
|
||||
|
||||
@@ -7,34 +7,36 @@ This module provides utility functions for testing state machine functionality:
|
||||
- Guard testing utilities
|
||||
"""
|
||||
|
||||
from typing import Any, Optional, List, Callable
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from django.contrib.contenttypes.models import ContentType
|
||||
|
||||
|
||||
def assert_transition_allowed(
|
||||
instance: Any,
|
||||
method_name: str,
|
||||
user: Optional[Any] = None
|
||||
user: Any | None = None
|
||||
) -> bool:
|
||||
"""
|
||||
Assert that a transition is allowed.
|
||||
|
||||
|
||||
Args:
|
||||
instance: Model instance with FSM field
|
||||
method_name: Name of the transition method
|
||||
user: User attempting the transition
|
||||
|
||||
|
||||
Returns:
|
||||
True if transition is allowed
|
||||
|
||||
|
||||
Raises:
|
||||
AssertionError: If transition is not allowed
|
||||
|
||||
|
||||
Example:
|
||||
assert_transition_allowed(submission, 'transition_to_approved', moderator)
|
||||
"""
|
||||
from django_fsm import can_proceed
|
||||
|
||||
|
||||
method = getattr(instance, method_name)
|
||||
result = can_proceed(method)
|
||||
assert result, f"Transition {method_name} should be allowed but was denied"
|
||||
@@ -44,27 +46,27 @@ def assert_transition_allowed(
|
||||
def assert_transition_denied(
|
||||
instance: Any,
|
||||
method_name: str,
|
||||
user: Optional[Any] = None
|
||||
user: Any | None = None
|
||||
) -> bool:
|
||||
"""
|
||||
Assert that a transition is denied.
|
||||
|
||||
|
||||
Args:
|
||||
instance: Model instance with FSM field
|
||||
method_name: Name of the transition method
|
||||
user: User attempting the transition
|
||||
|
||||
|
||||
Returns:
|
||||
True if transition is denied
|
||||
|
||||
|
||||
Raises:
|
||||
AssertionError: If transition is allowed
|
||||
|
||||
|
||||
Example:
|
||||
assert_transition_denied(submission, 'transition_to_approved', regular_user)
|
||||
"""
|
||||
from django_fsm import can_proceed
|
||||
|
||||
|
||||
method = getattr(instance, method_name)
|
||||
result = can_proceed(method)
|
||||
assert not result, f"Transition {method_name} should be denied but was allowed"
|
||||
@@ -74,130 +76,130 @@ def assert_transition_denied(
|
||||
def assert_state_log_created(
|
||||
instance: Any,
|
||||
expected_state: str,
|
||||
user: Optional[Any] = None
|
||||
user: Any | None = None
|
||||
) -> Any:
|
||||
"""
|
||||
Assert that a StateLog entry was created for a transition.
|
||||
|
||||
|
||||
Args:
|
||||
instance: Model instance that was transitioned
|
||||
expected_state: The expected final state in the log
|
||||
user: Expected user who made the transition (optional)
|
||||
|
||||
|
||||
Returns:
|
||||
The StateLog entry
|
||||
|
||||
|
||||
Raises:
|
||||
AssertionError: If StateLog entry not found or doesn't match
|
||||
|
||||
|
||||
Example:
|
||||
log = assert_state_log_created(submission, 'APPROVED', moderator)
|
||||
"""
|
||||
from django_fsm_log.models import StateLog
|
||||
|
||||
|
||||
ct = ContentType.objects.get_for_model(instance)
|
||||
log = StateLog.objects.filter(
|
||||
content_type=ct,
|
||||
object_id=instance.id,
|
||||
state=expected_state
|
||||
).first()
|
||||
|
||||
|
||||
assert log is not None, f"StateLog for state '{expected_state}' not found"
|
||||
|
||||
|
||||
if user is not None:
|
||||
assert log.by == user, f"Expected log.by={user}, got {log.by}"
|
||||
|
||||
|
||||
return log
|
||||
|
||||
|
||||
def assert_state_log_count(instance: Any, expected_count: int) -> List[Any]:
|
||||
def assert_state_log_count(instance: Any, expected_count: int) -> list[Any]:
|
||||
"""
|
||||
Assert the number of StateLog entries for an instance.
|
||||
|
||||
|
||||
Args:
|
||||
instance: Model instance to check logs for
|
||||
expected_count: Expected number of log entries
|
||||
|
||||
|
||||
Returns:
|
||||
List of StateLog entries
|
||||
|
||||
|
||||
Raises:
|
||||
AssertionError: If count doesn't match
|
||||
|
||||
|
||||
Example:
|
||||
logs = assert_state_log_count(submission, 2)
|
||||
"""
|
||||
from django_fsm_log.models import StateLog
|
||||
|
||||
|
||||
ct = ContentType.objects.get_for_model(instance)
|
||||
logs = list(StateLog.objects.filter(
|
||||
content_type=ct,
|
||||
object_id=instance.id
|
||||
).order_by('timestamp'))
|
||||
|
||||
|
||||
actual_count = len(logs)
|
||||
assert actual_count == expected_count, \
|
||||
f"Expected {expected_count} StateLog entries, got {actual_count}"
|
||||
|
||||
|
||||
return logs
|
||||
|
||||
|
||||
def assert_state_transition_sequence(
|
||||
instance: Any,
|
||||
expected_states: List[str]
|
||||
) -> List[Any]:
|
||||
expected_states: list[str]
|
||||
) -> list[Any]:
|
||||
"""
|
||||
Assert that state transitions occurred in a specific sequence.
|
||||
|
||||
|
||||
Args:
|
||||
instance: Model instance to check
|
||||
expected_states: List of expected states in order
|
||||
|
||||
|
||||
Returns:
|
||||
List of StateLog entries
|
||||
|
||||
|
||||
Raises:
|
||||
AssertionError: If sequence doesn't match
|
||||
|
||||
|
||||
Example:
|
||||
assert_state_transition_sequence(submission, ['ESCALATED', 'APPROVED'])
|
||||
"""
|
||||
from django_fsm_log.models import StateLog
|
||||
|
||||
|
||||
ct = ContentType.objects.get_for_model(instance)
|
||||
logs = list(StateLog.objects.filter(
|
||||
content_type=ct,
|
||||
object_id=instance.id
|
||||
).order_by('timestamp'))
|
||||
|
||||
|
||||
actual_states = [log.state for log in logs]
|
||||
assert actual_states == expected_states, \
|
||||
f"Expected state sequence {expected_states}, got {actual_states}"
|
||||
|
||||
|
||||
return logs
|
||||
|
||||
|
||||
def assert_guard_passes(
|
||||
guard: Callable,
|
||||
instance: Any,
|
||||
user: Optional[Any] = None,
|
||||
user: Any | None = None,
|
||||
message: str = ""
|
||||
) -> bool:
|
||||
"""
|
||||
Assert that a guard function passes.
|
||||
|
||||
|
||||
Args:
|
||||
guard: Guard function or callable
|
||||
instance: Model instance to check
|
||||
user: User attempting the action
|
||||
message: Optional message on failure
|
||||
|
||||
|
||||
Returns:
|
||||
True if guard passes
|
||||
|
||||
|
||||
Raises:
|
||||
AssertionError: If guard fails
|
||||
|
||||
|
||||
Example:
|
||||
assert_guard_passes(permission_guard, instance, moderator)
|
||||
"""
|
||||
@@ -210,58 +212,58 @@ def assert_guard_passes(
|
||||
def assert_guard_fails(
|
||||
guard: Callable,
|
||||
instance: Any,
|
||||
user: Optional[Any] = None,
|
||||
expected_error_code: Optional[str] = None,
|
||||
user: Any | None = None,
|
||||
expected_error_code: str | None = None,
|
||||
message: str = ""
|
||||
) -> bool:
|
||||
"""
|
||||
Assert that a guard function fails.
|
||||
|
||||
|
||||
Args:
|
||||
guard: Guard function or callable
|
||||
instance: Model instance to check
|
||||
user: User attempting the action
|
||||
expected_error_code: Expected error code from guard
|
||||
message: Optional message on failure
|
||||
|
||||
|
||||
Returns:
|
||||
True if guard fails as expected
|
||||
|
||||
|
||||
Raises:
|
||||
AssertionError: If guard passes or wrong error code
|
||||
|
||||
|
||||
Example:
|
||||
assert_guard_fails(permission_guard, instance, regular_user, 'PERMISSION_DENIED')
|
||||
"""
|
||||
result = guard(instance, user)
|
||||
fail_message = message or f"Guard should fail but returned {result}"
|
||||
assert result is False, fail_message
|
||||
|
||||
|
||||
if expected_error_code and hasattr(guard, 'error_code'):
|
||||
assert guard.error_code == expected_error_code, \
|
||||
f"Expected error code {expected_error_code}, got {guard.error_code}"
|
||||
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def transition_and_save(
|
||||
instance: Any,
|
||||
transition_method: str,
|
||||
user: Optional[Any] = None,
|
||||
user: Any | None = None,
|
||||
**kwargs
|
||||
) -> Any:
|
||||
"""
|
||||
Execute a transition and save the instance.
|
||||
|
||||
|
||||
Args:
|
||||
instance: Model instance with FSM field
|
||||
transition_method: Name of the transition method
|
||||
user: User performing the transition
|
||||
**kwargs: Additional arguments for the transition
|
||||
|
||||
|
||||
Returns:
|
||||
The saved instance
|
||||
|
||||
|
||||
Example:
|
||||
submission = transition_and_save(submission, 'transition_to_approved', moderator)
|
||||
"""
|
||||
@@ -272,37 +274,36 @@ def transition_and_save(
|
||||
return instance
|
||||
|
||||
|
||||
def get_available_transitions(instance: Any) -> List[str]:
|
||||
def get_available_transitions(instance: Any) -> list[str]:
|
||||
"""
|
||||
Get list of available transitions for an instance.
|
||||
|
||||
|
||||
Args:
|
||||
instance: Model instance with FSM field
|
||||
|
||||
|
||||
Returns:
|
||||
List of available transition method names
|
||||
|
||||
|
||||
Example:
|
||||
transitions = get_available_transitions(submission)
|
||||
# ['transition_to_approved', 'transition_to_rejected', 'transition_to_escalated']
|
||||
"""
|
||||
from django_fsm import get_available_FIELD_transitions
|
||||
|
||||
|
||||
# Get the state field name from the instance
|
||||
state_field = getattr(instance, 'state_field_name', 'status')
|
||||
|
||||
|
||||
# Build the function name dynamically
|
||||
func_name = f'get_available_{state_field}_transitions'
|
||||
if hasattr(instance, func_name):
|
||||
get_transitions = getattr(instance, func_name)
|
||||
return [t.name for t in get_transitions()]
|
||||
|
||||
|
||||
# Fallback: look for transition methods
|
||||
transitions = []
|
||||
for attr_name in dir(instance):
|
||||
if attr_name.startswith('transition_to_'):
|
||||
transitions.append(attr_name)
|
||||
|
||||
|
||||
return transitions
|
||||
|
||||
|
||||
@@ -310,22 +311,22 @@ def create_transition_context(
|
||||
instance: Any,
|
||||
from_state: str,
|
||||
to_state: str,
|
||||
user: Optional[Any] = None,
|
||||
user: Any | None = None,
|
||||
**extra
|
||||
) -> dict:
|
||||
"""
|
||||
Create a mock transition context dictionary.
|
||||
|
||||
|
||||
Args:
|
||||
instance: Model instance being transitioned
|
||||
from_state: Source state
|
||||
to_state: Target state
|
||||
user: User performing the transition
|
||||
**extra: Additional context data
|
||||
|
||||
|
||||
Returns:
|
||||
Dictionary matching TransitionContext structure
|
||||
|
||||
|
||||
Example:
|
||||
context = create_transition_context(submission, 'PENDING', 'APPROVED', moderator)
|
||||
"""
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
import pytest
|
||||
from django.core.exceptions import ImproperlyConfigured
|
||||
|
||||
from apps.core.choices.base import RichChoice, ChoiceCategory
|
||||
from apps.core.choices.base import ChoiceCategory, RichChoice
|
||||
from apps.core.choices.registry import registry
|
||||
from apps.core.state_machine.builder import StateTransitionBuilder
|
||||
|
||||
|
||||
@@ -9,14 +9,15 @@ This module tests:
|
||||
- Callback context handling
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from django.test import TestCase
|
||||
from unittest.mock import Mock, patch, call
|
||||
from typing import Any, Dict, List
|
||||
|
||||
|
||||
class CallbackContext:
|
||||
"""Mock context for testing callbacks."""
|
||||
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
instance: Any = None,
|
||||
@@ -30,8 +31,8 @@ class CallbackContext:
|
||||
self.to_state = to_state
|
||||
self.user = user
|
||||
self.extra = extra
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
'instance': self.instance,
|
||||
'from_state': self.from_state,
|
||||
@@ -43,179 +44,179 @@ class CallbackContext:
|
||||
|
||||
class MockCallback:
|
||||
"""Mock callback for testing."""
|
||||
|
||||
|
||||
def __init__(self, name: str = 'callback', should_raise: bool = False):
|
||||
self.name = name
|
||||
self.calls: List[Dict] = []
|
||||
self.calls: list[dict] = []
|
||||
self.should_raise = should_raise
|
||||
|
||||
def __call__(self, context: Dict[str, Any]) -> None:
|
||||
|
||||
def __call__(self, context: dict[str, Any]) -> None:
|
||||
self.calls.append(context)
|
||||
if self.should_raise:
|
||||
raise ValueError(f"Callback {self.name} failed")
|
||||
|
||||
|
||||
@property
|
||||
def call_count(self) -> int:
|
||||
return len(self.calls)
|
||||
|
||||
|
||||
def was_called(self) -> bool:
|
||||
return len(self.calls) > 0
|
||||
|
||||
|
||||
def reset(self):
|
||||
self.calls = []
|
||||
|
||||
|
||||
class PreTransitionCallbackTests(TestCase):
|
||||
"""Tests for pre-transition callbacks."""
|
||||
|
||||
|
||||
def test_pre_callback_executes_before_state_change(self):
|
||||
"""Test that pre-transition callback executes before state changes."""
|
||||
callback = MockCallback('pre_callback')
|
||||
context = CallbackContext(from_state='PENDING', to_state='APPROVED')
|
||||
|
||||
|
||||
# Simulate pre-transition execution
|
||||
callback(context.to_dict())
|
||||
|
||||
|
||||
self.assertTrue(callback.was_called())
|
||||
self.assertEqual(callback.calls[0]['from_state'], 'PENDING')
|
||||
self.assertEqual(callback.calls[0]['to_state'], 'APPROVED')
|
||||
|
||||
|
||||
def test_pre_callback_receives_instance(self):
|
||||
"""Test that pre-callback receives the model instance."""
|
||||
mock_instance = Mock()
|
||||
mock_instance.id = 123
|
||||
mock_instance.status = 'PENDING'
|
||||
|
||||
|
||||
callback = MockCallback()
|
||||
context = CallbackContext(instance=mock_instance)
|
||||
|
||||
|
||||
callback(context.to_dict())
|
||||
|
||||
|
||||
self.assertEqual(callback.calls[0]['instance'], mock_instance)
|
||||
|
||||
|
||||
def test_pre_callback_receives_user(self):
|
||||
"""Test that pre-callback receives the user performing transition."""
|
||||
mock_user = Mock()
|
||||
mock_user.username = 'moderator'
|
||||
|
||||
|
||||
callback = MockCallback()
|
||||
context = CallbackContext(user=mock_user)
|
||||
|
||||
|
||||
callback(context.to_dict())
|
||||
|
||||
|
||||
self.assertEqual(callback.calls[0]['user'], mock_user)
|
||||
|
||||
|
||||
def test_pre_callback_can_prevent_transition(self):
|
||||
"""Test that pre-callback can prevent transition by raising exception."""
|
||||
callback = MockCallback(should_raise=True)
|
||||
context = CallbackContext()
|
||||
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
callback(context.to_dict())
|
||||
|
||||
|
||||
def test_multiple_pre_callbacks_execute_in_order(self):
|
||||
"""Test that multiple pre-callbacks execute in registration order."""
|
||||
execution_order = []
|
||||
|
||||
|
||||
def callback_1(ctx):
|
||||
execution_order.append('first')
|
||||
|
||||
|
||||
def callback_2(ctx):
|
||||
execution_order.append('second')
|
||||
|
||||
|
||||
def callback_3(ctx):
|
||||
execution_order.append('third')
|
||||
|
||||
|
||||
context = CallbackContext().to_dict()
|
||||
|
||||
|
||||
# Execute in order
|
||||
callback_1(context)
|
||||
callback_2(context)
|
||||
callback_3(context)
|
||||
|
||||
|
||||
self.assertEqual(execution_order, ['first', 'second', 'third'])
|
||||
|
||||
|
||||
class PostTransitionCallbackTests(TestCase):
|
||||
"""Tests for post-transition callbacks."""
|
||||
|
||||
|
||||
def test_post_callback_executes_after_state_change(self):
|
||||
"""Test that post-transition callback executes after state changes."""
|
||||
callback = MockCallback('post_callback')
|
||||
|
||||
|
||||
# Simulate instance after transition
|
||||
mock_instance = Mock()
|
||||
mock_instance.status = 'APPROVED' # Already changed
|
||||
|
||||
|
||||
context = CallbackContext(
|
||||
instance=mock_instance,
|
||||
from_state='PENDING',
|
||||
to_state='APPROVED'
|
||||
)
|
||||
|
||||
|
||||
callback(context.to_dict())
|
||||
|
||||
|
||||
self.assertTrue(callback.was_called())
|
||||
self.assertEqual(callback.calls[0]['instance'].status, 'APPROVED')
|
||||
|
||||
|
||||
def test_post_callback_receives_updated_instance(self):
|
||||
"""Test that post-callback receives instance with new state."""
|
||||
mock_instance = Mock()
|
||||
mock_instance.status = 'APPROVED'
|
||||
mock_instance.approved_at = '2025-01-15'
|
||||
mock_instance.handled_by_id = 456
|
||||
|
||||
|
||||
callback = MockCallback()
|
||||
context = CallbackContext(instance=mock_instance)
|
||||
|
||||
|
||||
callback(context.to_dict())
|
||||
|
||||
|
||||
instance = callback.calls[0]['instance']
|
||||
self.assertEqual(instance.status, 'APPROVED')
|
||||
self.assertEqual(instance.approved_at, '2025-01-15')
|
||||
|
||||
|
||||
def test_post_callback_failure_does_not_rollback(self):
|
||||
"""Test that post-callback failures don't rollback the transition."""
|
||||
# In a real scenario, the transition would already be committed
|
||||
callback = MockCallback(should_raise=True)
|
||||
context = CallbackContext()
|
||||
|
||||
|
||||
# Post-callback failure should not affect already-committed transition
|
||||
with self.assertRaises(ValueError):
|
||||
callback(context.to_dict())
|
||||
|
||||
|
||||
# The transition would still be committed in real usage
|
||||
self.assertTrue(callback.was_called())
|
||||
|
||||
|
||||
def test_multiple_post_callbacks_execute_in_order(self):
|
||||
"""Test that multiple post-callbacks execute in order."""
|
||||
execution_order = []
|
||||
|
||||
|
||||
def notification_callback(ctx):
|
||||
execution_order.append('notification')
|
||||
|
||||
|
||||
def cache_callback(ctx):
|
||||
execution_order.append('cache')
|
||||
|
||||
|
||||
def analytics_callback(ctx):
|
||||
execution_order.append('analytics')
|
||||
|
||||
|
||||
context = CallbackContext().to_dict()
|
||||
|
||||
|
||||
notification_callback(context)
|
||||
cache_callback(context)
|
||||
analytics_callback(context)
|
||||
|
||||
|
||||
self.assertEqual(execution_order, ['notification', 'cache', 'analytics'])
|
||||
|
||||
|
||||
class ErrorCallbackTests(TestCase):
|
||||
"""Tests for error callbacks."""
|
||||
|
||||
|
||||
def test_error_callback_receives_exception(self):
|
||||
"""Test that error callback receives exception information."""
|
||||
error_callback = MockCallback()
|
||||
|
||||
|
||||
try:
|
||||
raise ValueError("Transition failed")
|
||||
except ValueError as e:
|
||||
@@ -227,33 +228,33 @@ class ErrorCallbackTests(TestCase):
|
||||
'exception_type': type(e).__name__
|
||||
}
|
||||
error_callback(error_context)
|
||||
|
||||
|
||||
self.assertTrue(error_callback.was_called())
|
||||
self.assertIn('exception', error_callback.calls[0])
|
||||
self.assertEqual(error_callback.calls[0]['exception_type'], 'ValueError')
|
||||
|
||||
|
||||
def test_error_callback_for_cleanup(self):
|
||||
"""Test that error callbacks can perform cleanup."""
|
||||
cleanup_performed = []
|
||||
|
||||
|
||||
def cleanup_callback(ctx):
|
||||
cleanup_performed.append(True)
|
||||
# In real usage, might release locks, revert partial changes, etc.
|
||||
|
||||
|
||||
try:
|
||||
raise ValueError("Transition failed")
|
||||
except ValueError:
|
||||
cleanup_callback({'exception': 'test'})
|
||||
|
||||
|
||||
self.assertTrue(cleanup_performed)
|
||||
|
||||
|
||||
def test_error_callback_receives_context(self):
|
||||
"""Test that error callback receives full transition context."""
|
||||
mock_instance = Mock()
|
||||
mock_user = Mock()
|
||||
|
||||
|
||||
error_callback = MockCallback()
|
||||
|
||||
|
||||
error_context = {
|
||||
'instance': mock_instance,
|
||||
'from_state': 'PENDING',
|
||||
@@ -261,152 +262,152 @@ class ErrorCallbackTests(TestCase):
|
||||
'user': mock_user,
|
||||
'exception': ValueError("Test error")
|
||||
}
|
||||
|
||||
|
||||
error_callback(error_context)
|
||||
|
||||
|
||||
self.assertEqual(error_callback.calls[0]['instance'], mock_instance)
|
||||
self.assertEqual(error_callback.calls[0]['user'], mock_user)
|
||||
|
||||
|
||||
class ConditionalCallbackTests(TestCase):
|
||||
"""Tests for conditional callback execution."""
|
||||
|
||||
|
||||
def test_callback_with_state_filter(self):
|
||||
"""Test callback that only executes for specific states."""
|
||||
execution_log = []
|
||||
|
||||
|
||||
def approval_only_callback(ctx):
|
||||
if ctx.get('to_state') == 'APPROVED':
|
||||
execution_log.append('approved')
|
||||
|
||||
|
||||
# Transition to APPROVED - should execute
|
||||
approval_only_callback({'to_state': 'APPROVED'})
|
||||
self.assertEqual(len(execution_log), 1)
|
||||
|
||||
|
||||
# Transition to REJECTED - should not execute
|
||||
approval_only_callback({'to_state': 'REJECTED'})
|
||||
self.assertEqual(len(execution_log), 1) # Still 1
|
||||
|
||||
|
||||
def test_callback_with_transition_filter(self):
|
||||
"""Test callback that only executes for specific transitions."""
|
||||
execution_log = []
|
||||
|
||||
|
||||
def escalation_callback(ctx):
|
||||
if ctx.get('to_state') == 'ESCALATED':
|
||||
execution_log.append('escalated')
|
||||
|
||||
|
||||
# Escalation - should execute
|
||||
escalation_callback({'to_state': 'ESCALATED'})
|
||||
self.assertEqual(len(execution_log), 1)
|
||||
|
||||
|
||||
# Other transitions - should not execute
|
||||
escalation_callback({'to_state': 'APPROVED'})
|
||||
self.assertEqual(len(execution_log), 1)
|
||||
|
||||
|
||||
def test_callback_with_user_role_filter(self):
|
||||
"""Test callback that checks user role."""
|
||||
admin_notifications = []
|
||||
|
||||
|
||||
def admin_only_notification(ctx):
|
||||
user = ctx.get('user')
|
||||
if user and getattr(user, 'role', None) == 'ADMIN':
|
||||
admin_notifications.append(ctx)
|
||||
|
||||
|
||||
admin_user = Mock(role='ADMIN')
|
||||
moderator_user = Mock(role='MODERATOR')
|
||||
|
||||
|
||||
admin_only_notification({'user': admin_user})
|
||||
self.assertEqual(len(admin_notifications), 1)
|
||||
|
||||
|
||||
admin_only_notification({'user': moderator_user})
|
||||
self.assertEqual(len(admin_notifications), 1) # Still 1
|
||||
|
||||
|
||||
class CallbackChainTests(TestCase):
|
||||
"""Tests for callback chains and pipelines."""
|
||||
|
||||
|
||||
def test_callback_chain_continues_on_success(self):
|
||||
"""Test that callback chain continues when callbacks succeed."""
|
||||
results = []
|
||||
|
||||
|
||||
callbacks = [
|
||||
lambda ctx: results.append('a'),
|
||||
lambda ctx: results.append('b'),
|
||||
lambda ctx: results.append('c'),
|
||||
]
|
||||
|
||||
|
||||
context = {}
|
||||
for cb in callbacks:
|
||||
cb(context)
|
||||
|
||||
|
||||
self.assertEqual(results, ['a', 'b', 'c'])
|
||||
|
||||
|
||||
def test_callback_chain_stops_on_failure(self):
|
||||
"""Test that callback chain stops when a callback fails."""
|
||||
results = []
|
||||
|
||||
|
||||
def callback_a(ctx):
|
||||
results.append('a')
|
||||
|
||||
|
||||
def callback_b(ctx):
|
||||
raise ValueError("B failed")
|
||||
|
||||
|
||||
def callback_c(ctx):
|
||||
results.append('c')
|
||||
|
||||
|
||||
callbacks = [callback_a, callback_b, callback_c]
|
||||
|
||||
|
||||
context = {}
|
||||
for cb in callbacks:
|
||||
try:
|
||||
cb(context)
|
||||
except ValueError:
|
||||
break
|
||||
|
||||
|
||||
self.assertEqual(results, ['a']) # c never executed
|
||||
|
||||
|
||||
def test_callback_chain_with_continue_on_error(self):
|
||||
"""Test callback chain that continues despite errors."""
|
||||
results = []
|
||||
errors = []
|
||||
|
||||
|
||||
def callback_a(ctx):
|
||||
results.append('a')
|
||||
|
||||
|
||||
def callback_b(ctx):
|
||||
raise ValueError("B failed")
|
||||
|
||||
|
||||
def callback_c(ctx):
|
||||
results.append('c')
|
||||
|
||||
|
||||
callbacks = [callback_a, callback_b, callback_c]
|
||||
|
||||
|
||||
context = {}
|
||||
for cb in callbacks:
|
||||
try:
|
||||
cb(context)
|
||||
except Exception as e:
|
||||
errors.append(str(e))
|
||||
|
||||
|
||||
self.assertEqual(results, ['a', 'c'])
|
||||
self.assertEqual(len(errors), 1)
|
||||
|
||||
|
||||
class CallbackContextEnrichmentTests(TestCase):
|
||||
"""Tests for callback context enrichment."""
|
||||
|
||||
|
||||
def test_context_includes_model_class(self):
|
||||
"""Test that context includes the model class."""
|
||||
mock_instance = Mock()
|
||||
mock_instance.__class__.__name__ = 'EditSubmission'
|
||||
|
||||
|
||||
context = {
|
||||
'instance': mock_instance,
|
||||
'model_class': type(mock_instance)
|
||||
}
|
||||
|
||||
|
||||
self.assertIn('model_class', context)
|
||||
|
||||
|
||||
def test_context_includes_transition_name(self):
|
||||
"""Test that context includes the transition method name."""
|
||||
context = {
|
||||
@@ -415,9 +416,9 @@ class CallbackContextEnrichmentTests(TestCase):
|
||||
'to_state': 'APPROVED',
|
||||
'transition_name': 'transition_to_approved'
|
||||
}
|
||||
|
||||
|
||||
self.assertEqual(context['transition_name'], 'transition_to_approved')
|
||||
|
||||
|
||||
def test_context_includes_timestamp(self):
|
||||
"""Test that context includes transition timestamp."""
|
||||
from django.utils import timezone
|
||||
@@ -452,9 +453,10 @@ class NotificationCallbackTests(TestCase):
|
||||
instance=None,
|
||||
):
|
||||
"""Helper to create a TransitionContext."""
|
||||
from ..callback_base import TransitionContext
|
||||
from django.utils import timezone
|
||||
|
||||
from ..callback_base import TransitionContext
|
||||
|
||||
if instance is None:
|
||||
instance = Mock()
|
||||
instance.pk = 123
|
||||
@@ -603,9 +605,10 @@ class CacheCallbackTests(TestCase):
|
||||
target_state: str = 'CLOSED_TEMP',
|
||||
):
|
||||
"""Helper to create a TransitionContext."""
|
||||
from ..callback_base import TransitionContext
|
||||
from django.utils import timezone
|
||||
|
||||
from ..callback_base import TransitionContext
|
||||
|
||||
instance = Mock()
|
||||
instance.pk = instance_id
|
||||
instance.__class__.__name__ = model_name
|
||||
@@ -703,7 +706,7 @@ class CacheCallbackTests(TestCase):
|
||||
context = self._create_transition_context()
|
||||
|
||||
# Should not raise overall
|
||||
result = callback.execute(context)
|
||||
callback.execute(context)
|
||||
# All patterns should have been attempted
|
||||
self.assertGreater(call_count, 1)
|
||||
|
||||
@@ -716,9 +719,10 @@ class ModelCacheInvalidationTests(TestCase):
|
||||
model_name: str = 'Ride',
|
||||
instance_id: int = 789,
|
||||
):
|
||||
from ..callback_base import TransitionContext
|
||||
from django.utils import timezone
|
||||
|
||||
from ..callback_base import TransitionContext
|
||||
|
||||
instance = Mock()
|
||||
instance.pk = instance_id
|
||||
instance.__class__.__name__ = model_name
|
||||
@@ -771,7 +775,7 @@ class RelatedUpdateCallbackTests(TestCase):
|
||||
def setUp(self):
|
||||
"""Set up test fixtures."""
|
||||
from django.contrib.auth import get_user_model
|
||||
User = get_user_model()
|
||||
get_user_model()
|
||||
|
||||
self.user = Mock()
|
||||
self.user.pk = 1
|
||||
@@ -783,9 +787,10 @@ class RelatedUpdateCallbackTests(TestCase):
|
||||
instance=None,
|
||||
target_state: str = 'OPERATING',
|
||||
):
|
||||
from ..callback_base import TransitionContext
|
||||
from django.utils import timezone
|
||||
|
||||
from ..callback_base import TransitionContext
|
||||
|
||||
if instance is None:
|
||||
instance = Mock()
|
||||
instance.pk = 123
|
||||
@@ -920,9 +925,10 @@ class CallbackErrorHandlingTests(TestCase):
|
||||
"""Tests for callback error handling paths."""
|
||||
|
||||
def _create_transition_context(self):
|
||||
from ..callback_base import TransitionContext
|
||||
from django.utils import timezone
|
||||
|
||||
from ..callback_base import TransitionContext
|
||||
|
||||
instance = Mock()
|
||||
instance.pk = 1
|
||||
instance.__class__.__name__ = 'EditSubmission'
|
||||
@@ -939,9 +945,10 @@ class CallbackErrorHandlingTests(TestCase):
|
||||
@patch('apps.core.state_machine.callbacks.notifications.NotificationService')
|
||||
def test_notification_callback_logs_error_on_failure(self, mock_service_class):
|
||||
"""Test NotificationCallback logs errors when service fails."""
|
||||
from ..callbacks.notifications import NotificationCallback
|
||||
import logging
|
||||
|
||||
from ..callbacks.notifications import NotificationCallback
|
||||
|
||||
mock_service = Mock()
|
||||
mock_service.send_notification = Mock(side_effect=Exception("Network error"))
|
||||
mock_service_class.return_value = mock_service
|
||||
@@ -949,7 +956,7 @@ class CallbackErrorHandlingTests(TestCase):
|
||||
callback = NotificationCallback()
|
||||
context = self._create_transition_context()
|
||||
|
||||
with self.assertLogs(level=logging.WARNING) as log_output:
|
||||
with self.assertLogs(level=logging.WARNING):
|
||||
try:
|
||||
callback.execute(context)
|
||||
except Exception:
|
||||
@@ -979,9 +986,10 @@ class CallbackErrorHandlingTests(TestCase):
|
||||
|
||||
def test_callback_with_none_user(self):
|
||||
"""Test callbacks handle None user gracefully."""
|
||||
from django.utils import timezone
|
||||
|
||||
from ..callback_base import TransitionContext
|
||||
from ..callbacks.notifications import NotificationCallback
|
||||
from django.utils import timezone
|
||||
|
||||
instance = Mock()
|
||||
instance.pk = 1
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
"""Tests for transition decorator generation."""
|
||||
import pytest
|
||||
from unittest.mock import Mock
|
||||
|
||||
from apps.core.state_machine.decorators import (
|
||||
generate_transition_decorator,
|
||||
create_transition_method,
|
||||
TransitionMethodFactory,
|
||||
create_transition_method,
|
||||
generate_transition_decorator,
|
||||
with_transition_logging,
|
||||
)
|
||||
|
||||
|
||||
@@ -10,33 +10,30 @@ This module contains tests for:
|
||||
- CompositeGuard (combining guards with AND/OR logic)
|
||||
"""
|
||||
|
||||
from django.test import TestCase
|
||||
from django.contrib.auth import get_user_model
|
||||
from django.contrib.auth.models import AnonymousUser
|
||||
from django.test import TestCase
|
||||
|
||||
from apps.core.state_machine.guards import (
|
||||
PermissionGuard,
|
||||
OwnershipGuard,
|
||||
ADMIN_ROLES,
|
||||
MODERATOR_ROLES,
|
||||
AssignmentGuard,
|
||||
StateGuard,
|
||||
MetadataGuard,
|
||||
CompositeGuard,
|
||||
extract_guards_from_metadata,
|
||||
create_permission_guard,
|
||||
create_ownership_guard,
|
||||
MetadataGuard,
|
||||
OwnershipGuard,
|
||||
PermissionGuard,
|
||||
StateGuard,
|
||||
create_assignment_guard,
|
||||
create_composite_guard,
|
||||
validate_guard_metadata,
|
||||
create_ownership_guard,
|
||||
create_permission_guard,
|
||||
extract_guards_from_metadata,
|
||||
get_user_role,
|
||||
has_role,
|
||||
is_moderator_or_above,
|
||||
is_admin_or_above,
|
||||
is_moderator_or_above,
|
||||
is_superuser_role,
|
||||
has_permission,
|
||||
VALID_ROLES,
|
||||
MODERATOR_ROLES,
|
||||
ADMIN_ROLES,
|
||||
SUPERUSER_ROLES,
|
||||
validate_guard_metadata,
|
||||
)
|
||||
|
||||
User = get_user_model()
|
||||
@@ -44,7 +41,7 @@ User = get_user_model()
|
||||
|
||||
class MockInstance:
|
||||
"""Mock instance for testing guards."""
|
||||
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
for key, value in kwargs.items():
|
||||
setattr(self, key, value)
|
||||
@@ -89,90 +86,90 @@ class PermissionGuardTests(TestCase):
|
||||
def test_no_user_fails(self):
|
||||
"""Test that guard fails when no user is provided."""
|
||||
guard = PermissionGuard(requires_moderator=True)
|
||||
|
||||
|
||||
result = guard(self.instance, user=None)
|
||||
|
||||
|
||||
self.assertFalse(result)
|
||||
self.assertEqual(guard.error_code, PermissionGuard.ERROR_CODE_NO_USER)
|
||||
|
||||
def test_requires_moderator_allows_moderator(self):
|
||||
"""Test that requires_moderator allows moderator role."""
|
||||
guard = PermissionGuard(requires_moderator=True)
|
||||
|
||||
|
||||
result = guard(self.instance, user=self.moderator)
|
||||
|
||||
|
||||
self.assertTrue(result)
|
||||
|
||||
def test_requires_moderator_allows_admin(self):
|
||||
"""Test that requires_moderator allows admin role."""
|
||||
guard = PermissionGuard(requires_moderator=True)
|
||||
|
||||
|
||||
result = guard(self.instance, user=self.admin)
|
||||
|
||||
|
||||
self.assertTrue(result)
|
||||
|
||||
def test_requires_moderator_allows_superuser(self):
|
||||
"""Test that requires_moderator allows superuser role."""
|
||||
guard = PermissionGuard(requires_moderator=True)
|
||||
|
||||
|
||||
result = guard(self.instance, user=self.superuser)
|
||||
|
||||
|
||||
self.assertTrue(result)
|
||||
|
||||
def test_requires_moderator_denies_regular_user(self):
|
||||
"""Test that requires_moderator denies regular user."""
|
||||
guard = PermissionGuard(requires_moderator=True)
|
||||
|
||||
|
||||
result = guard(self.instance, user=self.regular_user)
|
||||
|
||||
|
||||
self.assertFalse(result)
|
||||
self.assertEqual(guard.error_code, PermissionGuard.ERROR_CODE_PERMISSION_DENIED_ROLE)
|
||||
|
||||
def test_requires_admin_allows_admin(self):
|
||||
"""Test that requires_admin allows admin role."""
|
||||
guard = PermissionGuard(requires_admin=True)
|
||||
|
||||
|
||||
result = guard(self.instance, user=self.admin)
|
||||
|
||||
|
||||
self.assertTrue(result)
|
||||
|
||||
def test_requires_admin_allows_superuser(self):
|
||||
"""Test that requires_admin allows superuser role."""
|
||||
guard = PermissionGuard(requires_admin=True)
|
||||
|
||||
|
||||
result = guard(self.instance, user=self.superuser)
|
||||
|
||||
|
||||
self.assertTrue(result)
|
||||
|
||||
def test_requires_admin_denies_moderator(self):
|
||||
"""Test that requires_admin denies moderator role."""
|
||||
guard = PermissionGuard(requires_admin=True)
|
||||
|
||||
|
||||
result = guard(self.instance, user=self.moderator)
|
||||
|
||||
|
||||
self.assertFalse(result)
|
||||
self.assertEqual(guard.error_code, PermissionGuard.ERROR_CODE_PERMISSION_DENIED_ROLE)
|
||||
|
||||
def test_requires_superuser_allows_superuser(self):
|
||||
"""Test that requires_superuser allows superuser role."""
|
||||
guard = PermissionGuard(requires_superuser=True)
|
||||
|
||||
|
||||
result = guard(self.instance, user=self.superuser)
|
||||
|
||||
|
||||
self.assertTrue(result)
|
||||
|
||||
def test_requires_superuser_denies_admin(self):
|
||||
"""Test that requires_superuser denies admin role."""
|
||||
guard = PermissionGuard(requires_superuser=True)
|
||||
|
||||
|
||||
result = guard(self.instance, user=self.admin)
|
||||
|
||||
|
||||
self.assertFalse(result)
|
||||
|
||||
def test_required_roles_explicit_list(self):
|
||||
"""Test using explicit required_roles list."""
|
||||
guard = PermissionGuard(required_roles=['ADMIN', 'SUPERUSER'])
|
||||
|
||||
|
||||
self.assertTrue(guard(self.instance, user=self.admin))
|
||||
self.assertTrue(guard(self.instance, user=self.superuser))
|
||||
self.assertFalse(guard(self.instance, user=self.moderator))
|
||||
@@ -182,24 +179,24 @@ class PermissionGuardTests(TestCase):
|
||||
"""Test custom check function that passes."""
|
||||
def custom_check(instance, user):
|
||||
return hasattr(instance, 'allow_access') and instance.allow_access
|
||||
|
||||
|
||||
guard = PermissionGuard(custom_check=custom_check)
|
||||
instance = MockInstance(allow_access=True)
|
||||
|
||||
|
||||
result = guard(instance, user=self.regular_user)
|
||||
|
||||
|
||||
self.assertTrue(result)
|
||||
|
||||
def test_custom_check_fails(self):
|
||||
"""Test custom check function that fails."""
|
||||
def custom_check(instance, user):
|
||||
return hasattr(instance, 'allow_access') and instance.allow_access
|
||||
|
||||
|
||||
guard = PermissionGuard(custom_check=custom_check)
|
||||
instance = MockInstance(allow_access=False)
|
||||
|
||||
|
||||
result = guard(instance, user=self.regular_user)
|
||||
|
||||
|
||||
self.assertFalse(result)
|
||||
self.assertEqual(guard.error_code, PermissionGuard.ERROR_CODE_PERMISSION_DENIED_CUSTOM)
|
||||
|
||||
@@ -207,25 +204,25 @@ class PermissionGuardTests(TestCase):
|
||||
"""Test custom error message."""
|
||||
custom_message = "You need special access for this"
|
||||
guard = PermissionGuard(requires_moderator=True, error_message=custom_message)
|
||||
|
||||
|
||||
guard(self.instance, user=self.regular_user)
|
||||
|
||||
|
||||
self.assertEqual(guard.get_error_message(), custom_message)
|
||||
|
||||
def test_get_required_roles_moderator(self):
|
||||
"""Test get_required_roles for moderator requirement."""
|
||||
guard = PermissionGuard(requires_moderator=True)
|
||||
|
||||
|
||||
roles = guard.get_required_roles()
|
||||
|
||||
|
||||
self.assertEqual(set(roles), set(MODERATOR_ROLES))
|
||||
|
||||
def test_get_required_roles_admin(self):
|
||||
"""Test get_required_roles for admin requirement."""
|
||||
guard = PermissionGuard(requires_admin=True)
|
||||
|
||||
|
||||
roles = guard.get_required_roles()
|
||||
|
||||
|
||||
self.assertEqual(set(roles), set(ADMIN_ROLES))
|
||||
|
||||
|
||||
@@ -268,9 +265,9 @@ class OwnershipGuardTests(TestCase):
|
||||
"""Test that guard fails when no user is provided."""
|
||||
instance = MockInstance(created_by=self.owner)
|
||||
guard = OwnershipGuard()
|
||||
|
||||
|
||||
result = guard(instance, user=None)
|
||||
|
||||
|
||||
self.assertFalse(result)
|
||||
self.assertEqual(guard.error_code, OwnershipGuard.ERROR_CODE_NO_USER)
|
||||
|
||||
@@ -278,36 +275,36 @@ class OwnershipGuardTests(TestCase):
|
||||
"""Test that owner passes via created_by field."""
|
||||
instance = MockInstance(created_by=self.owner)
|
||||
guard = OwnershipGuard()
|
||||
|
||||
|
||||
result = guard(instance, user=self.owner)
|
||||
|
||||
|
||||
self.assertTrue(result)
|
||||
|
||||
def test_owner_passes_user_field(self):
|
||||
"""Test that owner passes via user field."""
|
||||
instance = MockInstance(user=self.owner)
|
||||
guard = OwnershipGuard()
|
||||
|
||||
|
||||
result = guard(instance, user=self.owner)
|
||||
|
||||
|
||||
self.assertTrue(result)
|
||||
|
||||
def test_owner_passes_submitted_by(self):
|
||||
"""Test that owner passes via submitted_by field."""
|
||||
instance = MockInstance(submitted_by=self.owner)
|
||||
guard = OwnershipGuard()
|
||||
|
||||
|
||||
result = guard(instance, user=self.owner)
|
||||
|
||||
|
||||
self.assertTrue(result)
|
||||
|
||||
def test_non_owner_fails(self):
|
||||
"""Test that non-owner fails."""
|
||||
instance = MockInstance(created_by=self.owner)
|
||||
guard = OwnershipGuard()
|
||||
|
||||
|
||||
result = guard(instance, user=self.other_user)
|
||||
|
||||
|
||||
self.assertFalse(result)
|
||||
self.assertEqual(guard.error_code, OwnershipGuard.ERROR_CODE_NOT_OWNER)
|
||||
|
||||
@@ -315,27 +312,27 @@ class OwnershipGuardTests(TestCase):
|
||||
"""Test that moderator can bypass ownership check."""
|
||||
instance = MockInstance(created_by=self.owner)
|
||||
guard = OwnershipGuard(allow_moderator_override=True)
|
||||
|
||||
|
||||
result = guard(instance, user=self.moderator)
|
||||
|
||||
|
||||
self.assertTrue(result)
|
||||
|
||||
def test_admin_override(self):
|
||||
"""Test that admin can bypass ownership check."""
|
||||
instance = MockInstance(created_by=self.owner)
|
||||
guard = OwnershipGuard(allow_admin_override=True)
|
||||
|
||||
|
||||
result = guard(instance, user=self.admin)
|
||||
|
||||
|
||||
self.assertTrue(result)
|
||||
|
||||
def test_custom_owner_fields(self):
|
||||
"""Test custom owner field names."""
|
||||
instance = MockInstance(author=self.owner)
|
||||
guard = OwnershipGuard(owner_fields=['author'])
|
||||
|
||||
|
||||
result = guard(instance, user=self.owner)
|
||||
|
||||
|
||||
self.assertTrue(result)
|
||||
|
||||
def test_anonymous_user_fails(self):
|
||||
@@ -343,9 +340,9 @@ class OwnershipGuardTests(TestCase):
|
||||
instance = MockInstance(created_by=self.owner)
|
||||
guard = OwnershipGuard()
|
||||
anonymous = AnonymousUser()
|
||||
|
||||
|
||||
result = guard(instance, user=anonymous)
|
||||
|
||||
|
||||
self.assertFalse(result)
|
||||
|
||||
|
||||
@@ -382,9 +379,9 @@ class AssignmentGuardTests(TestCase):
|
||||
"""Test that guard fails when no user is provided."""
|
||||
instance = MockInstance(assigned_to=self.assigned_user)
|
||||
guard = AssignmentGuard()
|
||||
|
||||
|
||||
result = guard(instance, user=None)
|
||||
|
||||
|
||||
self.assertFalse(result)
|
||||
self.assertEqual(guard.error_code, AssignmentGuard.ERROR_CODE_NO_USER)
|
||||
|
||||
@@ -392,18 +389,18 @@ class AssignmentGuardTests(TestCase):
|
||||
"""Test that assigned user passes."""
|
||||
instance = MockInstance(assigned_to=self.assigned_user)
|
||||
guard = AssignmentGuard()
|
||||
|
||||
|
||||
result = guard(instance, user=self.assigned_user)
|
||||
|
||||
|
||||
self.assertTrue(result)
|
||||
|
||||
def test_unassigned_user_fails(self):
|
||||
"""Test that unassigned user fails."""
|
||||
instance = MockInstance(assigned_to=self.assigned_user)
|
||||
guard = AssignmentGuard()
|
||||
|
||||
|
||||
result = guard(instance, user=self.other_user)
|
||||
|
||||
|
||||
self.assertFalse(result)
|
||||
self.assertEqual(guard.error_code, AssignmentGuard.ERROR_CODE_NOT_ASSIGNED)
|
||||
|
||||
@@ -411,18 +408,18 @@ class AssignmentGuardTests(TestCase):
|
||||
"""Test that admin can bypass assignment check."""
|
||||
instance = MockInstance(assigned_to=self.assigned_user)
|
||||
guard = AssignmentGuard(allow_admin_override=True)
|
||||
|
||||
|
||||
result = guard(instance, user=self.admin)
|
||||
|
||||
|
||||
self.assertTrue(result)
|
||||
|
||||
def test_require_assignment_with_no_assignment(self):
|
||||
"""Test require_assignment fails when no one is assigned."""
|
||||
instance = MockInstance(assigned_to=None)
|
||||
guard = AssignmentGuard(require_assignment=True)
|
||||
|
||||
|
||||
result = guard(instance, user=self.assigned_user)
|
||||
|
||||
|
||||
self.assertFalse(result)
|
||||
self.assertEqual(guard.error_code, AssignmentGuard.ERROR_CODE_NO_ASSIGNMENT)
|
||||
|
||||
@@ -430,18 +427,18 @@ class AssignmentGuardTests(TestCase):
|
||||
"""Test custom assignment field names."""
|
||||
instance = MockInstance(reviewer=self.assigned_user)
|
||||
guard = AssignmentGuard(assignment_fields=['reviewer'])
|
||||
|
||||
|
||||
result = guard(instance, user=self.assigned_user)
|
||||
|
||||
|
||||
self.assertTrue(result)
|
||||
|
||||
def test_error_message_for_no_assignment(self):
|
||||
"""Test error message when no assignment exists."""
|
||||
instance = MockInstance(assigned_to=None)
|
||||
guard = AssignmentGuard(require_assignment=True)
|
||||
|
||||
|
||||
guard(instance, user=self.assigned_user)
|
||||
|
||||
|
||||
self.assertIn('assigned', guard.get_error_message().lower())
|
||||
|
||||
|
||||
@@ -466,18 +463,18 @@ class StateGuardTests(TestCase):
|
||||
"""Test that guard passes when in allowed state."""
|
||||
instance = MockInstance(status='PENDING')
|
||||
guard = StateGuard(allowed_states=['PENDING', 'UNDER_REVIEW'])
|
||||
|
||||
|
||||
result = guard(instance, user=self.user)
|
||||
|
||||
|
||||
self.assertTrue(result)
|
||||
|
||||
def test_allowed_states_fails(self):
|
||||
"""Test that guard fails when not in allowed state."""
|
||||
instance = MockInstance(status='COMPLETED')
|
||||
guard = StateGuard(allowed_states=['PENDING', 'UNDER_REVIEW'])
|
||||
|
||||
|
||||
result = guard(instance, user=self.user)
|
||||
|
||||
|
||||
self.assertFalse(result)
|
||||
self.assertEqual(guard.error_code, StateGuard.ERROR_CODE_INVALID_STATE)
|
||||
|
||||
@@ -485,18 +482,18 @@ class StateGuardTests(TestCase):
|
||||
"""Test that guard passes when not in blocked state."""
|
||||
instance = MockInstance(status='PENDING')
|
||||
guard = StateGuard(blocked_states=['COMPLETED', 'CANCELLED'])
|
||||
|
||||
|
||||
result = guard(instance, user=self.user)
|
||||
|
||||
|
||||
self.assertTrue(result)
|
||||
|
||||
def test_blocked_states_fails(self):
|
||||
"""Test that guard fails when in blocked state."""
|
||||
instance = MockInstance(status='COMPLETED')
|
||||
guard = StateGuard(blocked_states=['COMPLETED', 'CANCELLED'])
|
||||
|
||||
|
||||
result = guard(instance, user=self.user)
|
||||
|
||||
|
||||
self.assertFalse(result)
|
||||
self.assertEqual(guard.error_code, StateGuard.ERROR_CODE_BLOCKED_STATE)
|
||||
|
||||
@@ -504,18 +501,18 @@ class StateGuardTests(TestCase):
|
||||
"""Test using custom state field name."""
|
||||
instance = MockInstance(workflow_status='ACTIVE')
|
||||
guard = StateGuard(allowed_states=['ACTIVE'], state_field='workflow_status')
|
||||
|
||||
|
||||
result = guard(instance, user=self.user)
|
||||
|
||||
|
||||
self.assertTrue(result)
|
||||
|
||||
def test_error_message_includes_states(self):
|
||||
"""Test that error message includes allowed states."""
|
||||
instance = MockInstance(status='COMPLETED')
|
||||
guard = StateGuard(allowed_states=['PENDING', 'UNDER_REVIEW'])
|
||||
|
||||
|
||||
guard(instance, user=self.user)
|
||||
|
||||
|
||||
message = guard.get_error_message()
|
||||
self.assertIn('PENDING', message)
|
||||
self.assertIn('UNDER_REVIEW', message)
|
||||
@@ -542,18 +539,18 @@ class MetadataGuardTests(TestCase):
|
||||
"""Test that guard passes when required fields are present."""
|
||||
instance = MockInstance(resolution_notes='Fixed', assigned_to='user')
|
||||
guard = MetadataGuard(required_fields=['resolution_notes', 'assigned_to'])
|
||||
|
||||
|
||||
result = guard(instance, user=self.user)
|
||||
|
||||
|
||||
self.assertTrue(result)
|
||||
|
||||
def test_required_field_missing(self):
|
||||
"""Test that guard fails when required field is missing."""
|
||||
instance = MockInstance(resolution_notes='Fixed')
|
||||
guard = MetadataGuard(required_fields=['resolution_notes', 'assigned_to'])
|
||||
|
||||
|
||||
result = guard(instance, user=self.user)
|
||||
|
||||
|
||||
self.assertFalse(result)
|
||||
self.assertEqual(guard.error_code, MetadataGuard.ERROR_CODE_MISSING_FIELD)
|
||||
|
||||
@@ -561,9 +558,9 @@ class MetadataGuardTests(TestCase):
|
||||
"""Test that guard fails when required field is None."""
|
||||
instance = MockInstance(resolution_notes=None)
|
||||
guard = MetadataGuard(required_fields=['resolution_notes'])
|
||||
|
||||
|
||||
result = guard(instance, user=self.user)
|
||||
|
||||
|
||||
self.assertFalse(result)
|
||||
self.assertEqual(guard.error_code, MetadataGuard.ERROR_CODE_MISSING_FIELD)
|
||||
|
||||
@@ -571,9 +568,9 @@ class MetadataGuardTests(TestCase):
|
||||
"""Test that empty string fails when check_not_empty is True."""
|
||||
instance = MockInstance(resolution_notes=' ')
|
||||
guard = MetadataGuard(required_fields=['resolution_notes'], check_not_empty=True)
|
||||
|
||||
|
||||
result = guard(instance, user=self.user)
|
||||
|
||||
|
||||
self.assertFalse(result)
|
||||
self.assertEqual(guard.error_code, MetadataGuard.ERROR_CODE_EMPTY_FIELD)
|
||||
|
||||
@@ -581,9 +578,9 @@ class MetadataGuardTests(TestCase):
|
||||
"""Test that empty list fails when check_not_empty is True."""
|
||||
instance = MockInstance(tags=[])
|
||||
guard = MetadataGuard(required_fields=['tags'], check_not_empty=True)
|
||||
|
||||
|
||||
result = guard(instance, user=self.user)
|
||||
|
||||
|
||||
self.assertFalse(result)
|
||||
self.assertEqual(guard.error_code, MetadataGuard.ERROR_CODE_EMPTY_FIELD)
|
||||
|
||||
@@ -591,9 +588,9 @@ class MetadataGuardTests(TestCase):
|
||||
"""Test that empty dict fails when check_not_empty is True."""
|
||||
instance = MockInstance(metadata={})
|
||||
guard = MetadataGuard(required_fields=['metadata'], check_not_empty=True)
|
||||
|
||||
|
||||
result = guard(instance, user=self.user)
|
||||
|
||||
|
||||
self.assertFalse(result)
|
||||
self.assertEqual(guard.error_code, MetadataGuard.ERROR_CODE_EMPTY_FIELD)
|
||||
|
||||
@@ -601,9 +598,9 @@ class MetadataGuardTests(TestCase):
|
||||
"""Test that error message includes the field name."""
|
||||
instance = MockInstance(resolution_notes=None)
|
||||
guard = MetadataGuard(required_fields=['resolution_notes'])
|
||||
|
||||
|
||||
guard(instance, user=self.user)
|
||||
|
||||
|
||||
message = guard.get_error_message()
|
||||
self.assertIn('Resolution Notes', message)
|
||||
|
||||
@@ -645,9 +642,9 @@ class CompositeGuardTests(TestCase):
|
||||
OwnershipGuard()
|
||||
]
|
||||
composite = CompositeGuard(guards, operator='AND')
|
||||
|
||||
|
||||
result = composite(instance, user=self.moderator)
|
||||
|
||||
|
||||
self.assertTrue(result)
|
||||
|
||||
def test_and_operator_one_fails(self):
|
||||
@@ -658,9 +655,9 @@ class CompositeGuardTests(TestCase):
|
||||
OwnershipGuard() # Will fail - moderator is not owner
|
||||
]
|
||||
composite = CompositeGuard(guards, operator='AND')
|
||||
|
||||
|
||||
result = composite(instance, user=self.non_owner_moderator)
|
||||
|
||||
|
||||
self.assertFalse(result)
|
||||
self.assertEqual(composite.error_code, CompositeGuard.ERROR_CODE_SOME_FAILED)
|
||||
|
||||
@@ -672,9 +669,9 @@ class CompositeGuardTests(TestCase):
|
||||
OwnershipGuard() # Will pass - user is owner
|
||||
]
|
||||
composite = CompositeGuard(guards, operator='OR')
|
||||
|
||||
|
||||
result = composite(instance, user=self.owner)
|
||||
|
||||
|
||||
self.assertTrue(result)
|
||||
|
||||
def test_or_operator_all_fail(self):
|
||||
@@ -685,30 +682,30 @@ class CompositeGuardTests(TestCase):
|
||||
OwnershipGuard() # Not the owner fails
|
||||
]
|
||||
composite = CompositeGuard(guards, operator='OR')
|
||||
|
||||
|
||||
result = composite(instance, user=self.owner)
|
||||
|
||||
|
||||
self.assertFalse(result)
|
||||
self.assertEqual(composite.error_code, CompositeGuard.ERROR_CODE_ALL_FAILED)
|
||||
|
||||
def test_nested_composite_guards(self):
|
||||
"""Test nested composite guards."""
|
||||
instance = MockInstance(created_by=self.moderator, status='PENDING')
|
||||
|
||||
|
||||
# Inner composite: moderator OR owner
|
||||
inner = CompositeGuard([
|
||||
PermissionGuard(requires_moderator=True),
|
||||
OwnershipGuard()
|
||||
], operator='OR')
|
||||
|
||||
|
||||
# Outer composite: (moderator OR owner) AND valid state
|
||||
outer = CompositeGuard([
|
||||
inner,
|
||||
StateGuard(allowed_states=['PENDING'])
|
||||
], operator='AND')
|
||||
|
||||
|
||||
result = outer(instance, user=self.moderator)
|
||||
|
||||
|
||||
self.assertTrue(result)
|
||||
|
||||
def test_error_message_from_failed_guard(self):
|
||||
@@ -717,9 +714,9 @@ class CompositeGuardTests(TestCase):
|
||||
perm_guard = PermissionGuard(requires_admin=True)
|
||||
guards = [perm_guard]
|
||||
composite = CompositeGuard(guards, operator='AND')
|
||||
|
||||
|
||||
composite(instance, user=self.owner)
|
||||
|
||||
|
||||
message = composite.get_error_message()
|
||||
self.assertIn('admin', message.lower())
|
||||
|
||||
@@ -746,42 +743,42 @@ class GuardFactoryTests(TestCase):
|
||||
metadata = {'requires_moderator': True}
|
||||
guard = create_permission_guard(metadata)
|
||||
instance = MockInstance()
|
||||
|
||||
|
||||
result = guard(instance, user=self.moderator)
|
||||
|
||||
|
||||
self.assertTrue(result)
|
||||
|
||||
def test_create_permission_guard_admin(self):
|
||||
"""Test create_permission_guard with admin requirement."""
|
||||
metadata = {'requires_admin_approval': True}
|
||||
guard = create_permission_guard(metadata)
|
||||
|
||||
|
||||
self.assertTrue(guard.requires_admin)
|
||||
|
||||
def test_create_permission_guard_escalation_level(self):
|
||||
"""Test create_permission_guard with escalation level."""
|
||||
metadata = {'escalation_level': 'admin'}
|
||||
guard = create_permission_guard(metadata)
|
||||
|
||||
|
||||
self.assertTrue(guard.requires_admin)
|
||||
|
||||
def test_create_ownership_guard(self):
|
||||
"""Test create_ownership_guard factory."""
|
||||
guard = create_ownership_guard(allow_moderator_override=True)
|
||||
|
||||
|
||||
self.assertTrue(guard.allow_moderator_override)
|
||||
|
||||
def test_create_assignment_guard(self):
|
||||
"""Test create_assignment_guard factory."""
|
||||
guard = create_assignment_guard(require_assignment=True)
|
||||
|
||||
|
||||
self.assertTrue(guard.require_assignment)
|
||||
|
||||
def test_create_composite_guard(self):
|
||||
"""Test create_composite_guard factory."""
|
||||
guards = [PermissionGuard(), OwnershipGuard()]
|
||||
composite = create_composite_guard(guards, operator='OR')
|
||||
|
||||
|
||||
self.assertEqual(composite.operator, 'OR')
|
||||
self.assertEqual(len(composite.guards), 2)
|
||||
|
||||
@@ -798,7 +795,7 @@ class MetadataExtractionTests(TestCase):
|
||||
"""Test extracting guard for moderator requirement."""
|
||||
metadata = {'requires_moderator': True}
|
||||
guards = extract_guards_from_metadata(metadata)
|
||||
|
||||
|
||||
self.assertEqual(len(guards), 1)
|
||||
self.assertIsInstance(guards[0], PermissionGuard)
|
||||
|
||||
@@ -806,7 +803,7 @@ class MetadataExtractionTests(TestCase):
|
||||
"""Test extracting guard for admin requirement."""
|
||||
metadata = {'requires_admin_approval': True}
|
||||
guards = extract_guards_from_metadata(metadata)
|
||||
|
||||
|
||||
self.assertEqual(len(guards), 1)
|
||||
self.assertTrue(guards[0].requires_admin)
|
||||
|
||||
@@ -814,7 +811,7 @@ class MetadataExtractionTests(TestCase):
|
||||
"""Test extracting assignment guard."""
|
||||
metadata = {'requires_assignment': True}
|
||||
guards = extract_guards_from_metadata(metadata)
|
||||
|
||||
|
||||
self.assertEqual(len(guards), 1)
|
||||
self.assertIsInstance(guards[0], AssignmentGuard)
|
||||
|
||||
@@ -825,21 +822,21 @@ class MetadataExtractionTests(TestCase):
|
||||
'requires_assignment': True
|
||||
}
|
||||
guards = extract_guards_from_metadata(metadata)
|
||||
|
||||
|
||||
self.assertEqual(len(guards), 2)
|
||||
|
||||
def test_extract_zero_tolerance_guard(self):
|
||||
"""Test extracting guard for zero tolerance (superuser required)."""
|
||||
metadata = {'zero_tolerance': True}
|
||||
guards = extract_guards_from_metadata(metadata)
|
||||
|
||||
|
||||
self.assertEqual(len(guards), 1)
|
||||
self.assertTrue(guards[0].requires_superuser)
|
||||
|
||||
def test_invalid_escalation_level_raises(self):
|
||||
"""Test that invalid escalation level raises ValueError."""
|
||||
metadata = {'escalation_level': 'invalid'}
|
||||
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
extract_guards_from_metadata(metadata)
|
||||
|
||||
@@ -859,36 +856,36 @@ class MetadataValidationTests(TestCase):
|
||||
'escalation_level': 'admin',
|
||||
'requires_assignment': False
|
||||
}
|
||||
|
||||
|
||||
is_valid, errors = validate_guard_metadata(metadata)
|
||||
|
||||
|
||||
self.assertTrue(is_valid)
|
||||
self.assertEqual(len(errors), 0)
|
||||
|
||||
def test_invalid_escalation_level(self):
|
||||
"""Test that invalid escalation level fails validation."""
|
||||
metadata = {'escalation_level': 'invalid_level'}
|
||||
|
||||
|
||||
is_valid, errors = validate_guard_metadata(metadata)
|
||||
|
||||
|
||||
self.assertFalse(is_valid)
|
||||
self.assertTrue(any('escalation_level' in e for e in errors))
|
||||
|
||||
def test_invalid_boolean_field(self):
|
||||
"""Test that non-boolean value for boolean field fails validation."""
|
||||
metadata = {'requires_moderator': 'yes'}
|
||||
|
||||
|
||||
is_valid, errors = validate_guard_metadata(metadata)
|
||||
|
||||
|
||||
self.assertFalse(is_valid)
|
||||
self.assertTrue(any('requires_moderator' in e for e in errors))
|
||||
|
||||
def test_required_permissions_not_list(self):
|
||||
"""Test that non-list required_permissions fails validation."""
|
||||
metadata = {'required_permissions': 'app.permission'}
|
||||
|
||||
|
||||
is_valid, errors = validate_guard_metadata(metadata)
|
||||
|
||||
|
||||
self.assertFalse(is_valid)
|
||||
self.assertTrue(any('required_permissions' in e for e in errors))
|
||||
|
||||
@@ -965,7 +962,7 @@ class RoleHelperTests(TestCase):
|
||||
def test_anonymous_user_has_no_role(self):
|
||||
"""Test that anonymous user has no role."""
|
||||
anonymous = AnonymousUser()
|
||||
|
||||
|
||||
self.assertFalse(has_role(anonymous, ['USER']))
|
||||
self.assertFalse(is_moderator_or_above(anonymous))
|
||||
self.assertFalse(is_admin_or_above(anonymous))
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
"""Integration tests for state machine model integration."""
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch
|
||||
from django.core.exceptions import ImproperlyConfigured
|
||||
|
||||
import pytest
|
||||
|
||||
from apps.core.choices.base import RichChoice
|
||||
from apps.core.choices.registry import registry
|
||||
from apps.core.state_machine.integration import (
|
||||
StateMachineModelMixin,
|
||||
apply_state_machine,
|
||||
generate_transition_methods_for_model,
|
||||
StateMachineModelMixin,
|
||||
state_machine_model,
|
||||
validate_model_state_machine,
|
||||
)
|
||||
|
||||
@@ -4,8 +4,8 @@ import pytest
|
||||
from apps.core.choices.base import RichChoice
|
||||
from apps.core.choices.registry import registry
|
||||
from apps.core.state_machine.registry import (
|
||||
TransitionRegistry,
|
||||
TransitionInfo,
|
||||
TransitionRegistry,
|
||||
registry_instance,
|
||||
)
|
||||
|
||||
|
||||
@@ -5,8 +5,8 @@ from apps.core.choices.base import RichChoice
|
||||
from apps.core.choices.registry import registry
|
||||
from apps.core.state_machine.validators import (
|
||||
MetadataValidator,
|
||||
ValidationResult,
|
||||
ValidationError,
|
||||
ValidationResult,
|
||||
ValidationWarning,
|
||||
validate_on_registration,
|
||||
)
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
"""Metadata validators for ensuring RichChoice metadata meets FSM requirements."""
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Dict, Set, Optional, Any
|
||||
from typing import Any
|
||||
|
||||
from apps.core.state_machine.builder import StateTransitionBuilder
|
||||
from apps.core.choices.registry import registry
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -12,8 +11,8 @@ class ValidationError:
|
||||
|
||||
code: str
|
||||
message: str
|
||||
state: Optional[str] = None
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
state: str | None = None
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def __str__(self):
|
||||
"""String representation of the error."""
|
||||
@@ -28,7 +27,7 @@ class ValidationWarning:
|
||||
|
||||
code: str
|
||||
message: str
|
||||
state: Optional[str] = None
|
||||
state: str | None = None
|
||||
|
||||
def __str__(self):
|
||||
"""String representation of the warning."""
|
||||
@@ -42,15 +41,15 @@ class ValidationResult:
|
||||
"""Result of metadata validation."""
|
||||
|
||||
is_valid: bool
|
||||
errors: List[ValidationError] = field(default_factory=list)
|
||||
warnings: List[ValidationWarning] = field(default_factory=list)
|
||||
errors: list[ValidationError] = field(default_factory=list)
|
||||
warnings: list[ValidationWarning] = field(default_factory=list)
|
||||
|
||||
def add_error(self, code: str, message: str, state: Optional[str] = None):
|
||||
def add_error(self, code: str, message: str, state: str | None = None):
|
||||
"""Add a validation error."""
|
||||
self.errors.append(ValidationError(code, message, state))
|
||||
self.is_valid = False
|
||||
|
||||
def add_warning(self, code: str, message: str, state: Optional[str] = None):
|
||||
def add_warning(self, code: str, message: str, state: str | None = None):
|
||||
"""Add a validation warning."""
|
||||
self.warnings.append(ValidationWarning(code, message, state))
|
||||
|
||||
@@ -91,7 +90,7 @@ class MetadataValidator:
|
||||
|
||||
return result
|
||||
|
||||
def validate_transitions(self) -> List[ValidationError]:
|
||||
def validate_transitions(self) -> list[ValidationError]:
|
||||
"""
|
||||
Check all can_transition_to references exist.
|
||||
|
||||
@@ -148,7 +147,7 @@ class MetadataValidator:
|
||||
|
||||
return errors
|
||||
|
||||
def validate_terminal_states(self) -> List[ValidationError]:
|
||||
def validate_terminal_states(self) -> list[ValidationError]:
|
||||
"""
|
||||
Ensure terminal states have no outgoing transitions.
|
||||
|
||||
@@ -175,7 +174,7 @@ class MetadataValidator:
|
||||
|
||||
return errors
|
||||
|
||||
def validate_permission_consistency(self) -> List[ValidationError]:
|
||||
def validate_permission_consistency(self) -> list[ValidationError]:
|
||||
"""
|
||||
Check permission requirements are consistent.
|
||||
|
||||
@@ -206,7 +205,7 @@ class MetadataValidator:
|
||||
|
||||
return errors
|
||||
|
||||
def validate_no_cycles(self) -> List[ValidationError]:
|
||||
def validate_no_cycles(self) -> list[ValidationError]:
|
||||
"""
|
||||
Detect invalid state cycles (excluding self-loops).
|
||||
|
||||
@@ -224,10 +223,10 @@ class MetadataValidator:
|
||||
pass
|
||||
|
||||
# Detect cycles using DFS
|
||||
visited: Set[str] = set()
|
||||
rec_stack: Set[str] = set()
|
||||
visited: set[str] = set()
|
||||
rec_stack: set[str] = set()
|
||||
|
||||
def has_cycle(node: str, path: List[str]) -> Optional[List[str]]:
|
||||
def has_cycle(node: str, path: list[str]) -> list[str] | None:
|
||||
visited.add(node)
|
||||
rec_stack.add(node)
|
||||
path.append(node)
|
||||
@@ -262,7 +261,7 @@ class MetadataValidator:
|
||||
|
||||
return errors
|
||||
|
||||
def validate_reachability(self) -> List[ValidationError]:
|
||||
def validate_reachability(self) -> list[ValidationError]:
|
||||
"""
|
||||
Ensure all states are reachable from initial states.
|
||||
|
||||
@@ -274,7 +273,7 @@ class MetadataValidator:
|
||||
all_states = set(self.builder.get_all_states())
|
||||
|
||||
# Find states with no incoming transitions (potential initial states)
|
||||
incoming: Dict[str, List[str]] = {state: [] for state in all_states}
|
||||
incoming: dict[str, list[str]] = {state: [] for state in all_states}
|
||||
for source, targets in graph.items():
|
||||
for target in targets:
|
||||
incoming[target].append(source)
|
||||
@@ -293,7 +292,7 @@ class MetadataValidator:
|
||||
return errors
|
||||
|
||||
# BFS from initial states to find reachable states
|
||||
reachable: Set[str] = set(initial_states)
|
||||
reachable: set[str] = set(initial_states)
|
||||
queue = list(initial_states)
|
||||
|
||||
while queue:
|
||||
|
||||
@@ -7,12 +7,13 @@ All tasks run asynchronously to avoid blocking the main application.
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Any
|
||||
from typing import Any
|
||||
|
||||
from celery import shared_task
|
||||
from django.utils import timezone
|
||||
from django.core.cache import cache
|
||||
from django.contrib.contenttypes.models import ContentType
|
||||
from django.core.cache import cache
|
||||
from django.db.models import Q
|
||||
from django.utils import timezone
|
||||
|
||||
from apps.core.analytics import PageView
|
||||
from apps.parks.models import Park
|
||||
@@ -24,7 +25,7 @@ logger = logging.getLogger(__name__)
|
||||
@shared_task(bind=True, max_retries=3, default_retry_delay=60)
|
||||
def calculate_trending_content(
|
||||
self, content_type: str = "all", limit: int = 50
|
||||
) -> Dict[str, Any]:
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Calculate trending content using real analytics data.
|
||||
|
||||
@@ -100,7 +101,7 @@ def calculate_trending_content(
|
||||
@shared_task(bind=True, max_retries=3, default_retry_delay=30)
|
||||
def calculate_new_content(
|
||||
self, content_type: str = "all", days_back: int = 30, limit: int = 50
|
||||
) -> Dict[str, Any]:
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Calculate new content based on opening dates and creation dates.
|
||||
|
||||
@@ -157,7 +158,7 @@ def calculate_new_content(
|
||||
|
||||
|
||||
@shared_task(bind=True)
|
||||
def warm_trending_cache(self) -> Dict[str, Any]:
|
||||
def warm_trending_cache(self) -> dict[str, Any]:
|
||||
"""
|
||||
Warm the trending cache by pre-calculating common queries.
|
||||
|
||||
@@ -208,7 +209,7 @@ def warm_trending_cache(self) -> Dict[str, Any]:
|
||||
|
||||
def _calculate_trending_parks(
|
||||
current_period_hours: int, previous_period_hours: int, limit: int
|
||||
) -> List[Dict[str, Any]]:
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Calculate trending scores for parks using real data."""
|
||||
parks = Park.objects.filter(status="OPERATING").select_related(
|
||||
"location", "operator"
|
||||
@@ -247,7 +248,7 @@ def _calculate_trending_parks(
|
||||
|
||||
def _calculate_trending_rides(
|
||||
current_period_hours: int, previous_period_hours: int, limit: int
|
||||
) -> List[Dict[str, Any]]:
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Calculate trending scores for rides using real data."""
|
||||
rides = Ride.objects.filter(status="OPERATING").select_related(
|
||||
"park", "park__location"
|
||||
@@ -453,7 +454,7 @@ def _calculate_popularity_score(
|
||||
return 0.0
|
||||
|
||||
|
||||
def _get_new_parks(cutoff_date: datetime, limit: int) -> List[Dict[str, Any]]:
|
||||
def _get_new_parks(cutoff_date: datetime, limit: int) -> list[dict[str, Any]]:
|
||||
"""Get recently added parks using real data."""
|
||||
new_parks = (
|
||||
Park.objects.filter(
|
||||
@@ -467,9 +468,8 @@ def _get_new_parks(cutoff_date: datetime, limit: int) -> List[Dict[str, Any]]:
|
||||
results = []
|
||||
for park in new_parks:
|
||||
date_added = park.opening_date or park.created_at
|
||||
if date_added:
|
||||
if isinstance(date_added, datetime):
|
||||
date_added = date_added.date()
|
||||
if date_added and isinstance(date_added, datetime):
|
||||
date_added = date_added.date()
|
||||
|
||||
opening_date = getattr(park, "opening_date", None)
|
||||
if opening_date and isinstance(opening_date, datetime):
|
||||
@@ -492,7 +492,7 @@ def _get_new_parks(cutoff_date: datetime, limit: int) -> List[Dict[str, Any]]:
|
||||
return results
|
||||
|
||||
|
||||
def _get_new_rides(cutoff_date: datetime, limit: int) -> List[Dict[str, Any]]:
|
||||
def _get_new_rides(cutoff_date: datetime, limit: int) -> list[dict[str, Any]]:
|
||||
"""Get recently added rides using real data."""
|
||||
new_rides = (
|
||||
Ride.objects.filter(
|
||||
@@ -508,9 +508,8 @@ def _get_new_rides(cutoff_date: datetime, limit: int) -> List[Dict[str, Any]]:
|
||||
date_added = getattr(ride, "opening_date", None) or getattr(
|
||||
ride, "created_at", None
|
||||
)
|
||||
if date_added:
|
||||
if isinstance(date_added, datetime):
|
||||
date_added = date_added.date()
|
||||
if date_added and isinstance(date_added, datetime):
|
||||
date_added = date_added.date()
|
||||
|
||||
opening_date = getattr(ride, "opening_date", None)
|
||||
if opening_date and isinstance(opening_date, datetime):
|
||||
@@ -534,10 +533,10 @@ def _get_new_rides(cutoff_date: datetime, limit: int) -> List[Dict[str, Any]]:
|
||||
|
||||
|
||||
def _format_trending_results(
|
||||
trending_items: List[Dict[str, Any]],
|
||||
trending_items: list[dict[str, Any]],
|
||||
current_period_hours: int,
|
||||
previous_period_hours: int,
|
||||
) -> List[Dict[str, Any]]:
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Format trending results for frontend consumption."""
|
||||
formatted_results = []
|
||||
|
||||
@@ -581,8 +580,8 @@ def _format_trending_results(
|
||||
|
||||
|
||||
def _format_new_content_results(
|
||||
new_items: List[Dict[str, Any]],
|
||||
) -> List[Dict[str, Any]]:
|
||||
new_items: list[dict[str, Any]],
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Format new content results for frontend consumption."""
|
||||
formatted_results = []
|
||||
|
||||
|
||||
@@ -14,10 +14,10 @@ Usage:
|
||||
"""
|
||||
|
||||
from datetime import timedelta
|
||||
|
||||
from django import template
|
||||
from django.template.defaultfilters import stringfilter
|
||||
from django.utils import timezone
|
||||
from django.utils.html import format_html
|
||||
|
||||
register = template.Library()
|
||||
|
||||
|
||||
@@ -23,13 +23,13 @@ Usage:
|
||||
{# Render a transition button #}
|
||||
{% transition_button submission 'approve' request.user %}
|
||||
"""
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any
|
||||
|
||||
from django import template
|
||||
from django.urls import reverse, NoReverseMatch
|
||||
from django.urls import NoReverseMatch, reverse
|
||||
from django_fsm import can_proceed
|
||||
|
||||
from apps.core.views.views import get_transition_metadata, TRANSITION_METADATA
|
||||
from apps.core.views.views import get_transition_metadata
|
||||
|
||||
register = template.Library()
|
||||
|
||||
@@ -40,7 +40,7 @@ register = template.Library()
|
||||
|
||||
|
||||
@register.filter
|
||||
def get_state_value(obj) -> Optional[str]:
|
||||
def get_state_value(obj) -> str | None:
|
||||
"""
|
||||
Get the current state value of an FSM-enabled object.
|
||||
|
||||
@@ -171,7 +171,7 @@ def default_target_id(obj) -> str:
|
||||
|
||||
|
||||
@register.simple_tag
|
||||
def get_available_transitions(obj, user) -> List[Dict[str, Any]]:
|
||||
def get_available_transitions(obj, user) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Get all available transitions for an object that the user can execute.
|
||||
|
||||
|
||||
@@ -28,19 +28,26 @@ Usage:
|
||||
{% icon "check" class="w-4 h-4" %}
|
||||
"""
|
||||
|
||||
import json
|
||||
from django import template
|
||||
from django.utils.safestring import mark_safe
|
||||
|
||||
from apps.core.utils.html_sanitizer import (
|
||||
sanitize_html,
|
||||
sanitize_minimal as _sanitize_minimal,
|
||||
sanitize_svg,
|
||||
strip_html as _strip_html,
|
||||
sanitize_for_json,
|
||||
escape_js_string as _escape_js_string,
|
||||
sanitize_url as _sanitize_url,
|
||||
)
|
||||
from apps.core.utils.html_sanitizer import (
|
||||
sanitize_attribute_value,
|
||||
sanitize_for_json,
|
||||
sanitize_html,
|
||||
sanitize_svg,
|
||||
)
|
||||
from apps.core.utils.html_sanitizer import (
|
||||
sanitize_minimal as _sanitize_minimal,
|
||||
)
|
||||
from apps.core.utils.html_sanitizer import (
|
||||
sanitize_url as _sanitize_url,
|
||||
)
|
||||
from apps.core.utils.html_sanitizer import (
|
||||
strip_html as _strip_html,
|
||||
)
|
||||
|
||||
register = template.Library()
|
||||
|
||||
@@ -5,7 +5,6 @@ These tests verify the functionality of the base admin classes and mixins
|
||||
that provide standardized behavior across all admin interfaces.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from django.contrib.admin.sites import AdminSite
|
||||
from django.contrib.auth import get_user_model
|
||||
from django.test import RequestFactory, TestCase
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
|
||||
import pghistory
|
||||
import pytest
|
||||
from django.contrib.auth import get_user_model
|
||||
from apps.parks.models import Park, Company
|
||||
import pghistory
|
||||
|
||||
from apps.parks.models import Company, Park
|
||||
|
||||
User = get_user_model()
|
||||
|
||||
@@ -16,7 +17,7 @@ class TestTrackedModel:
|
||||
"""Test that creating a model instance creates a history event."""
|
||||
user = User.objects.create_user(username="testuser", password="password")
|
||||
company = Company.objects.create(name="Test Operator", roles=["OPERATOR"])
|
||||
|
||||
|
||||
with pghistory.context(user=user.id):
|
||||
park = Park.objects.create(
|
||||
name="History Test Park",
|
||||
@@ -24,13 +25,13 @@ class TestTrackedModel:
|
||||
operating_season="Summer",
|
||||
operator=company
|
||||
)
|
||||
|
||||
|
||||
# Verify history using the helper method from TrackedModel
|
||||
events = park.get_history()
|
||||
assert events.count() == 1
|
||||
event = events.first()
|
||||
assert event.pgh_obj_id == park.pk
|
||||
|
||||
|
||||
# Verify context was captured
|
||||
# The middleware isn't running here, so we used pghistory.context explicitly
|
||||
# But pghistory.context stores data in pgh_context field if configured?
|
||||
@@ -40,15 +41,15 @@ class TestTrackedModel:
|
||||
def test_update_tracking(self):
|
||||
company = Company.objects.create(name="Test Operator 2", roles=["OPERATOR"])
|
||||
park = Park.objects.create(name="Original", operator=company)
|
||||
|
||||
|
||||
# Initial create event
|
||||
assert park.get_history().count() == 1
|
||||
|
||||
|
||||
# Update
|
||||
park.name = "Updated"
|
||||
park.save()
|
||||
|
||||
|
||||
assert park.get_history().count() == 2
|
||||
latest = park.get_history().first() # Ordered by -pgh_created_at
|
||||
assert latest.name == "Updated"
|
||||
|
||||
|
||||
|
||||
@@ -2,7 +2,8 @@
|
||||
Core app URL configuration.
|
||||
"""
|
||||
|
||||
from django.urls import path, include
|
||||
from django.urls import include, path
|
||||
|
||||
from ..views.entity_search import (
|
||||
EntityFuzzySearchView,
|
||||
EntityNotFoundView,
|
||||
|
||||
@@ -3,13 +3,14 @@ URL patterns for the unified map service API.
|
||||
"""
|
||||
|
||||
from django.urls import path
|
||||
|
||||
from ..views.map_views import (
|
||||
MapLocationsView,
|
||||
MapLocationDetailView,
|
||||
MapSearchView,
|
||||
MapBoundsView,
|
||||
MapStatsView,
|
||||
MapCacheView,
|
||||
MapLocationDetailView,
|
||||
MapLocationsView,
|
||||
MapSearchView,
|
||||
MapStatsView,
|
||||
)
|
||||
|
||||
app_name = "map_api"
|
||||
|
||||
@@ -4,15 +4,16 @@ Includes both HTML views and HTMX endpoints.
|
||||
"""
|
||||
|
||||
from django.urls import path
|
||||
|
||||
from ..views.maps import (
|
||||
UniversalMapView,
|
||||
ParkMapView,
|
||||
NearbyLocationsView,
|
||||
LocationDetailModalView,
|
||||
LocationFilterView,
|
||||
LocationListView,
|
||||
LocationSearchView,
|
||||
MapBoundsUpdateView,
|
||||
LocationDetailModalView,
|
||||
LocationListView,
|
||||
NearbyLocationsView,
|
||||
ParkMapView,
|
||||
UniversalMapView,
|
||||
)
|
||||
|
||||
app_name = "maps"
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from django.urls import path
|
||||
|
||||
from apps.core.views.search import (
|
||||
AdaptiveSearchView,
|
||||
FilterFormView,
|
||||
|
||||
@@ -29,7 +29,8 @@ Usage Examples:
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
import contextlib
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from urllib.parse import urljoin
|
||||
|
||||
@@ -351,19 +352,15 @@ def get_model_breadcrumb(
|
||||
parent_list_url = f"{parent_model_name}s:list"
|
||||
parent_list_label = f"{parent.__class__.__name__}s"
|
||||
|
||||
try:
|
||||
with contextlib.suppress(Exception):
|
||||
builder.add_from_url(parent_list_url, parent_list_label)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
builder.add_model(parent)
|
||||
|
||||
# Add list page breadcrumb
|
||||
if list_url_name and list_label:
|
||||
try:
|
||||
with contextlib.suppress(Exception):
|
||||
builder.add_from_url(list_url_name, list_label)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Add current model instance
|
||||
builder.add_model_current(instance)
|
||||
|
||||
@@ -1,53 +1,54 @@
|
||||
import logging
|
||||
|
||||
import requests
|
||||
from django.conf import settings
|
||||
from django.core.exceptions import ImproperlyConfigured
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def get_direct_upload_url(user_id=None):
|
||||
"""
|
||||
Generates a direct upload URL for Cloudflare Images.
|
||||
|
||||
|
||||
Args:
|
||||
user_id (str, optional): The user ID to associate with the upload.
|
||||
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing 'id' and 'uploadURL'.
|
||||
|
||||
|
||||
Raises:
|
||||
ImproperlyConfigured: If Cloudflare settings are missing.
|
||||
requests.RequestException: If the Cloudflare API request fails.
|
||||
"""
|
||||
account_id = getattr(settings, 'CLOUDFLARE_IMAGES_ACCOUNT_ID', None)
|
||||
api_token = getattr(settings, 'CLOUDFLARE_IMAGES_API_TOKEN', None)
|
||||
|
||||
|
||||
if not account_id or not api_token:
|
||||
raise ImproperlyConfigured(
|
||||
"CLOUDFLARE_IMAGES_ACCOUNT_ID and CLOUDFLARE_IMAGES_API_TOKEN must be set."
|
||||
)
|
||||
|
||||
url = f"https://api.cloudflare.com/client/v4/accounts/{account_id}/images/v2/direct_upload"
|
||||
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {api_token}",
|
||||
}
|
||||
|
||||
|
||||
data = {
|
||||
"requireSignedURLs": "false",
|
||||
}
|
||||
|
||||
|
||||
if user_id:
|
||||
data["metadata"] = f'{{"user_id": "{user_id}"}}'
|
||||
|
||||
response = requests.post(url, headers=headers, data=data)
|
||||
response.raise_for_status()
|
||||
|
||||
|
||||
result = response.json()
|
||||
|
||||
|
||||
if not result.get("success"):
|
||||
error_msg = result.get("errors", [{"message": "Unknown error"}])[0].get("message")
|
||||
logger.error(f"Cloudflare Direct Upload Error: {error_msg}")
|
||||
raise requests.RequestException(f"Cloudflare Error: {error_msg}")
|
||||
|
||||
|
||||
return result.get("result", {})
|
||||
|
||||
@@ -6,7 +6,7 @@ ensuring consistent logging, user messages, and API responses.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any
|
||||
|
||||
from django.contrib import messages
|
||||
from django.http import HttpRequest
|
||||
@@ -26,7 +26,7 @@ class ErrorHandler:
|
||||
request: HttpRequest,
|
||||
error: Exception,
|
||||
user_message: str = "An error occurred",
|
||||
log_message: Optional[str] = None,
|
||||
log_message: str | None = None,
|
||||
level: str = "error",
|
||||
) -> None:
|
||||
"""
|
||||
@@ -68,7 +68,7 @@ class ErrorHandler:
|
||||
def handle_api_error(
|
||||
error: Exception,
|
||||
user_message: str = "An error occurred",
|
||||
log_message: Optional[str] = None,
|
||||
log_message: str | None = None,
|
||||
status_code: int = status.HTTP_400_BAD_REQUEST,
|
||||
) -> Response:
|
||||
"""
|
||||
@@ -99,7 +99,7 @@ class ErrorHandler:
|
||||
logger.error(log_msg, exc_info=True)
|
||||
|
||||
# Build error response
|
||||
error_data: Dict[str, Any] = {
|
||||
error_data: dict[str, Any] = {
|
||||
"error": user_message,
|
||||
"detail": str(error),
|
||||
}
|
||||
@@ -150,7 +150,7 @@ class ErrorHandler:
|
||||
Returns:
|
||||
DRF Response with success data in standard format
|
||||
"""
|
||||
response_data: Dict[str, Any] = {
|
||||
response_data: dict[str, Any] = {
|
||||
"status": "success",
|
||||
"message": message,
|
||||
}
|
||||
|
||||
@@ -30,7 +30,6 @@ import os
|
||||
import re
|
||||
import uuid
|
||||
from io import BytesIO
|
||||
from typing import Optional, Set, Tuple
|
||||
|
||||
from django.core.exceptions import ValidationError
|
||||
from django.core.files.uploadedfile import UploadedFile
|
||||
@@ -72,7 +71,7 @@ IMAGE_SIGNATURES = {
|
||||
}
|
||||
|
||||
# All allowed MIME types
|
||||
ALLOWED_IMAGE_MIME_TYPES: Set[str] = frozenset({
|
||||
ALLOWED_IMAGE_MIME_TYPES: set[str] = frozenset({
|
||||
'image/jpeg',
|
||||
'image/png',
|
||||
'image/gif',
|
||||
@@ -80,7 +79,7 @@ ALLOWED_IMAGE_MIME_TYPES: Set[str] = frozenset({
|
||||
})
|
||||
|
||||
# Allowed file extensions
|
||||
ALLOWED_IMAGE_EXTENSIONS: Set[str] = frozenset({
|
||||
ALLOWED_IMAGE_EXTENSIONS: set[str] = frozenset({
|
||||
'.jpg', '.jpeg', '.png', '.gif', '.webp',
|
||||
})
|
||||
|
||||
@@ -98,8 +97,8 @@ MIN_FILE_SIZE = 100 # 100 bytes
|
||||
def validate_image_upload(
|
||||
file: UploadedFile,
|
||||
max_size: int = MAX_FILE_SIZE,
|
||||
allowed_types: Optional[Set[str]] = None,
|
||||
allowed_extensions: Optional[Set[str]] = None,
|
||||
allowed_types: set[str] | None = None,
|
||||
allowed_extensions: set[str] | None = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Validate an uploaded image file for security.
|
||||
@@ -191,15 +190,14 @@ def _validate_magic_number(file: UploadedFile) -> bool:
|
||||
|
||||
# Check against known signatures
|
||||
for format_name, signatures in IMAGE_SIGNATURES.items():
|
||||
for magic, offset, description in signatures:
|
||||
if len(header) >= offset + len(magic):
|
||||
if header[offset:offset + len(magic)] == magic:
|
||||
# Special handling for WebP (must also have WEBP marker)
|
||||
if format_name == 'webp':
|
||||
if len(header) >= 12 and header[8:12] == b'WEBP':
|
||||
return True
|
||||
else:
|
||||
for magic, offset, _description in signatures:
|
||||
if len(header) >= offset + len(magic) and header[offset:offset + len(magic)] == magic:
|
||||
# Special handling for WebP (must also have WEBP marker)
|
||||
if format_name == 'webp':
|
||||
if len(header) >= 12 and header[8:12] == b'WEBP':
|
||||
return True
|
||||
else:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@@ -340,7 +338,7 @@ UPLOAD_RATE_LIMITS = {
|
||||
}
|
||||
|
||||
|
||||
def check_upload_rate_limit(user_id: int, cache_backend=None) -> Tuple[bool, str]:
|
||||
def check_upload_rate_limit(user_id: int, cache_backend=None) -> tuple[bool, str]:
|
||||
"""
|
||||
Check if user has exceeded upload rate limits.
|
||||
|
||||
@@ -414,7 +412,7 @@ def increment_upload_count(user_id: int, cache_backend=None) -> None:
|
||||
# Antivirus Integration Point
|
||||
# =============================================================================
|
||||
|
||||
def scan_file_for_malware(file: UploadedFile) -> Tuple[bool, str]:
|
||||
def scan_file_for_malware(file: UploadedFile) -> tuple[bool, str]:
|
||||
"""
|
||||
Placeholder for antivirus/malware scanning integration.
|
||||
|
||||
|
||||
@@ -16,8 +16,6 @@ Usage Examples:
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
def success_created(
|
||||
model_name: str,
|
||||
|
||||
@@ -2,14 +2,15 @@
|
||||
Database query optimization utilities and helpers.
|
||||
"""
|
||||
|
||||
import time
|
||||
import logging
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from typing import Optional, Dict, Any, List, Type
|
||||
from django.db import connection, models
|
||||
from django.db.models import QuerySet, Prefetch, Count, Avg, Max
|
||||
from typing import Any
|
||||
|
||||
from django.conf import settings
|
||||
from django.core.cache import cache
|
||||
from django.db import connection, models
|
||||
from django.db.models import Avg, Count, Max, Prefetch, QuerySet
|
||||
|
||||
logger = logging.getLogger("query_optimization")
|
||||
|
||||
@@ -136,7 +137,7 @@ class QueryOptimizer:
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def create_bulk_queryset(model: Type[models.Model], ids: List[int]) -> QuerySet:
|
||||
def create_bulk_queryset(model: type[models.Model], ids: list[int]) -> QuerySet:
|
||||
"""
|
||||
Create an optimized queryset for bulk operations
|
||||
"""
|
||||
@@ -186,7 +187,7 @@ class QueryCache:
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def invalidate_model_cache(model_name: str, instance_id: Optional[int] = None):
|
||||
def invalidate_model_cache(model_name: str, instance_id: int | None = None):
|
||||
"""
|
||||
Invalidate cache keys related to a specific model
|
||||
|
||||
@@ -195,10 +196,7 @@ class QueryCache:
|
||||
instance_id: Specific instance ID, if applicable
|
||||
"""
|
||||
# Pattern-based cache invalidation (works with Redis)
|
||||
if instance_id:
|
||||
pattern = f"*{model_name}_{instance_id}*"
|
||||
else:
|
||||
pattern = f"*{model_name}*"
|
||||
pattern = f"*{model_name}_{instance_id}*" if instance_id else f"*{model_name}*"
|
||||
|
||||
try:
|
||||
# For Redis cache backends that support pattern deletion
|
||||
@@ -219,7 +217,7 @@ class IndexAnalyzer:
|
||||
"""Analyze and suggest database indexes"""
|
||||
|
||||
@staticmethod
|
||||
def analyze_slow_queries(min_time: float = 0.1) -> List[Dict[str, Any]]:
|
||||
def analyze_slow_queries(min_time: float = 0.1) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Analyze slow queries from the current request
|
||||
|
||||
@@ -244,7 +242,7 @@ class IndexAnalyzer:
|
||||
return slow_queries
|
||||
|
||||
@staticmethod
|
||||
def _analyze_query_sql(sql: str) -> Dict[str, Any]:
|
||||
def _analyze_query_sql(sql: str) -> dict[str, Any]:
|
||||
"""
|
||||
Analyze SQL to suggest potential optimizations
|
||||
"""
|
||||
@@ -285,7 +283,7 @@ class IndexAnalyzer:
|
||||
return analysis
|
||||
|
||||
@staticmethod
|
||||
def suggest_model_indexes(model: Type[models.Model]) -> List[str]:
|
||||
def suggest_model_indexes(model: type[models.Model]) -> list[str]:
|
||||
"""
|
||||
Suggest database indexes for a Django model based on its fields
|
||||
"""
|
||||
@@ -343,7 +341,7 @@ def log_query_performance():
|
||||
|
||||
|
||||
def optimize_queryset_for_serialization(
|
||||
queryset: QuerySet, fields: List[str]
|
||||
queryset: QuerySet, fields: list[str]
|
||||
) -> QuerySet:
|
||||
"""
|
||||
Optimize a queryset for API serialization by only selecting needed fields
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user