feat: Implement MFA authentication, add ride statistics model, and update various services, APIs, and tests across the application.

This commit is contained in:
pacnpal
2025-12-28 17:32:53 -05:00
parent aa56c46c27
commit c95f99ca10
452 changed files with 7948 additions and 6073 deletions

View File

@@ -13,7 +13,6 @@ from io import StringIO
from django.contrib import admin, messages
from django.core.serializers.json import DjangoJSONEncoder
from django.http import HttpResponse
from django.utils.html import format_html
class QueryOptimizationMixin:

View File

@@ -1,10 +1,11 @@
from django.db import models
from datetime import timedelta
import pghistory
from django.contrib.contenttypes.fields import GenericForeignKey
from django.contrib.contenttypes.models import ContentType
from django.utils import timezone
from django.db import models
from django.db.models import Count
from datetime import timedelta
import pghistory
from django.utils import timezone
@pghistory.track()

View File

@@ -3,21 +3,27 @@ Custom exception handling for ThrillWiki API.
Provides standardized error responses following Django styleguide patterns.
"""
from typing import Any, Dict, Optional
from typing import Any
from django.http import Http404
from django.core.exceptions import (
PermissionDenied,
)
from django.core.exceptions import (
ValidationError as DjangoValidationError,
)
from django.http import Http404
from rest_framework import status
from rest_framework.response import Response
from rest_framework.views import exception_handler
from rest_framework.exceptions import (
ValidationError as DRFValidationError,
NotFound,
)
from rest_framework.exceptions import (
PermissionDenied as DRFPermissionDenied,
)
from rest_framework.exceptions import (
ValidationError as DRFValidationError,
)
from rest_framework.response import Response
from rest_framework.views import exception_handler
from ..exceptions import ThrillWikiException
from ..logging import get_logger, log_exception
@@ -26,8 +32,8 @@ logger = get_logger(__name__)
def custom_exception_handler(
exc: Exception, context: Dict[str, Any]
) -> Optional[Response]:
exc: Exception, context: dict[str, Any]
) -> Response | None:
"""
Custom exception handler for DRF that provides standardized error responses.
@@ -209,7 +215,7 @@ def _get_error_message(exc: Exception, response_data: Any) -> str:
return str(exc) if str(exc) else "An error occurred"
def _get_error_details(exc: Exception, response_data: Any) -> Optional[Dict[str, Any]]:
def _get_error_details(exc: Exception, response_data: Any) -> dict[str, Any] | None:
"""Extract detailed error information for debugging."""
if isinstance(response_data, dict) and len(response_data) > 1:
return response_data
@@ -224,7 +230,7 @@ def _get_error_details(exc: Exception, response_data: Any) -> Optional[Dict[str,
def _format_django_validation_errors(
exc: DjangoValidationError,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""Format Django ValidationError for API response."""
if hasattr(exc, "error_dict"):
# Field-specific errors

View File

@@ -2,10 +2,11 @@
Common mixins for API views following Django styleguide patterns.
"""
from typing import Dict, Any, Optional, Type
from typing import Any
from rest_framework import status
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework import status
# Constants for error messages
_MISSING_INPUT_SERIALIZER_MSG = "Subclasses must set input_serializer class attribute"
@@ -20,17 +21,17 @@ class ApiMixin:
# Expose expected attributes so static type checkers know they exist on subclasses.
# Subclasses or other bases (e.g. DRF GenericAPIView) will actually provide these.
input_serializer: Optional[Type[Any]] = None
output_serializer: Optional[Type[Any]] = None
input_serializer: type[Any] | None = None
output_serializer: type[Any] | None = None
def create_response(
self,
*,
data: Any = None,
message: Optional[str] = None,
message: str | None = None,
status_code: int = status.HTTP_200_OK,
pagination: Optional[Dict[str, Any]] = None,
metadata: Optional[Dict[str, Any]] = None,
pagination: dict[str, Any] | None = None,
metadata: dict[str, Any] | None = None,
) -> Response:
"""
Create standardized API response.
@@ -66,8 +67,8 @@ class ApiMixin:
*,
message: str,
status_code: int = status.HTTP_400_BAD_REQUEST,
error_code: Optional[str] = None,
details: Optional[Dict[str, Any]] = None,
error_code: str | None = None,
details: dict[str, Any] | None = None,
) -> Response:
"""
Create standardized error response.
@@ -82,7 +83,7 @@ class ApiMixin:
Standardized error Response object
"""
# explicitly allow any-shaped values in the error_data dict
error_data: Dict[str, Any] = {
error_data: dict[str, Any] = {
"code": error_code or "GENERIC_ERROR",
"message": message,
}

View File

@@ -20,9 +20,9 @@ Security checks included:
import os
import re
from django.conf import settings
from django.core.checks import Error, Warning, register, Tags
from django.conf import settings
from django.core.checks import Error, Tags, Warning, register
# =============================================================================
# Secret Key Validation

View File

@@ -12,11 +12,11 @@ Key Components:
- RichChoiceSerializer: DRF serializer for API responses
"""
from .base import RichChoice, ChoiceCategory, ChoiceGroup
from .registry import ChoiceRegistry, register_choices
from .base import ChoiceCategory, ChoiceGroup, RichChoice
from .fields import RichChoiceField
from .serializers import RichChoiceSerializer, RichChoiceOptionSerializer
from .utils import validate_choice_value, get_choice_display
from .registry import ChoiceRegistry, register_choices
from .serializers import RichChoiceOptionSerializer, RichChoiceSerializer
from .utils import get_choice_display, validate_choice_value
__all__ = [
'RichChoice',

View File

@@ -5,8 +5,8 @@ This module defines the core dataclass structures for rich choice objects.
"""
from dataclasses import dataclass, field
from typing import Dict, Any, Optional
from enum import Enum
from typing import Any
class ChoiceCategory(Enum):
@@ -30,10 +30,10 @@ class ChoiceCategory(Enum):
class RichChoice:
"""
Rich choice object with metadata support.
This replaces simple tuple choices with a comprehensive object that can
carry additional information like descriptions, colors, icons, and custom metadata.
Attributes:
value: The stored value (equivalent to first element of tuple choice)
label: Human-readable display name (equivalent to second element of tuple choice)
@@ -45,39 +45,39 @@ class RichChoice:
value: str
label: str
description: str = ""
metadata: Dict[str, Any] = field(default_factory=dict)
metadata: dict[str, Any] = field(default_factory=dict)
deprecated: bool = False
category: ChoiceCategory = ChoiceCategory.OTHER
def __post_init__(self):
"""Validate the choice object after initialization"""
if not self.value:
raise ValueError("Choice value cannot be empty")
if not self.label:
raise ValueError("Choice label cannot be empty")
@property
def color(self) -> Optional[str]:
def color(self) -> str | None:
"""Get the color from metadata if available"""
return self.metadata.get('color')
@property
def icon(self) -> Optional[str]:
def icon(self) -> str | None:
"""Get the icon from metadata if available"""
return self.metadata.get('icon')
@property
def css_class(self) -> Optional[str]:
def css_class(self) -> str | None:
"""Get the CSS class from metadata if available"""
return self.metadata.get('css_class')
@property
def sort_order(self) -> int:
"""Get the sort order from metadata, defaulting to 0"""
return self.metadata.get('sort_order', 0)
def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary representation for API serialization"""
return {
'value': self.value,
@@ -91,11 +91,11 @@ class RichChoice:
'css_class': self.css_class,
'sort_order': self.sort_order,
}
def __str__(self) -> str:
return self.label
def __repr__(self) -> str:
return f"RichChoice(value='{self.value}', label='{self.label}')"
@@ -104,47 +104,47 @@ class RichChoice:
class ChoiceGroup:
"""
A group of related choices with shared metadata.
This allows for organizing choices into logical groups with
common properties and behaviors.
"""
name: str
choices: list[RichChoice]
description: str = ""
metadata: Dict[str, Any] = field(default_factory=dict)
metadata: dict[str, Any] = field(default_factory=dict)
def __post_init__(self):
"""Validate the choice group after initialization"""
if not self.name:
raise ValueError("Choice group name cannot be empty")
if not self.choices:
raise ValueError("Choice group must contain at least one choice")
# Validate that all choice values are unique within the group
values = [choice.value for choice in self.choices]
if len(values) != len(set(values)):
raise ValueError("All choice values within a group must be unique")
def get_choice(self, value: str) -> Optional[RichChoice]:
def get_choice(self, value: str) -> RichChoice | None:
"""Get a choice by its value"""
for choice in self.choices:
if choice.value == value:
return choice
return None
def get_choices_by_category(self, category: ChoiceCategory) -> list[RichChoice]:
"""Get all choices in a specific category"""
return [choice for choice in self.choices if choice.category == category]
def get_active_choices(self) -> list[RichChoice]:
"""Get all non-deprecated choices"""
return [choice for choice in self.choices if not choice.deprecated]
def to_tuple_choices(self) -> list[tuple[str, str]]:
"""Convert to legacy tuple choices format"""
return [(choice.value, choice.label) for choice in self.choices]
def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary representation for API serialization"""
return {
'name': self.name,

View File

@@ -5,10 +5,9 @@ This module defines all choice objects for core system functionality,
including health checks, API statuses, and other system-level choices.
"""
from .base import RichChoice, ChoiceCategory
from .base import ChoiceCategory, RichChoice
from .registry import register_choices
# Health Check Status Choices
HEALTH_STATUSES = [
RichChoice(
@@ -128,7 +127,7 @@ ENTITY_TYPES = [
def register_core_choices():
"""Register all core system choices with the global registry"""
register_choices(
name="health_statuses",
choices=HEALTH_STATUSES,
@@ -136,7 +135,7 @@ def register_core_choices():
description="Health check status options",
metadata={'domain': 'core', 'type': 'health_status'}
)
register_choices(
name="simple_health_statuses",
choices=SIMPLE_HEALTH_STATUSES,
@@ -144,7 +143,7 @@ def register_core_choices():
description="Simple health check status options",
metadata={'domain': 'core', 'type': 'simple_health_status'}
)
register_choices(
name="entity_types",
choices=ENTITY_TYPES,

View File

@@ -4,10 +4,12 @@ Django Model Fields for Rich Choices
This module provides Django model field implementations for rich choice objects.
"""
from typing import Any, Optional
from django.db import models
from typing import Any
from django.core.exceptions import ValidationError
from django.db import models
from django.forms import ChoiceField
from .base import RichChoice
from .registry import registry
@@ -15,11 +17,11 @@ from .registry import registry
class RichChoiceField(models.CharField):
"""
Django model field for rich choice objects.
This field stores the choice value as a CharField but provides
rich choice functionality through the registry system.
"""
def __init__(
self,
choice_group: str,
@@ -30,7 +32,7 @@ class RichChoiceField(models.CharField):
):
"""
Initialize the RichChoiceField.
Args:
choice_group: Name of the choice group in the registry
domain: Domain namespace for the choice group
@@ -41,66 +43,66 @@ class RichChoiceField(models.CharField):
self.choice_group = choice_group
self.domain = domain
self.allow_deprecated = allow_deprecated
# Set choices from registry for Django admin and forms
if self.allow_deprecated:
choices_list = registry.get_choices(choice_group, domain)
else:
choices_list = registry.get_active_choices(choice_group, domain)
choices = [(choice.value, choice.label) for choice in choices_list]
kwargs['choices'] = choices
kwargs['max_length'] = max_length
super().__init__(**kwargs)
def validate(self, value: Any, model_instance: Any) -> None:
"""Validate the choice value"""
super().validate(value, model_instance)
if value is None or value == '':
return
# Check if choice exists in registry
choice = registry.get_choice(self.choice_group, value, self.domain)
if choice is None:
raise ValidationError(
f"'{value}' is not a valid choice for {self.choice_group}"
)
# Check if deprecated choices are allowed
if choice.deprecated and not self.allow_deprecated:
raise ValidationError(
f"'{value}' is deprecated and cannot be used for new entries"
)
def get_rich_choice(self, value: str) -> Optional[RichChoice]:
def get_rich_choice(self, value: str) -> RichChoice | None:
"""Get the RichChoice object for a value"""
return registry.get_choice(self.choice_group, value, self.domain)
def get_choice_display(self, value: str) -> str:
"""Get the display label for a choice value"""
return registry.get_choice_display(self.choice_group, value, self.domain)
def contribute_to_class(self, cls: Any, name: str, private_only: bool = False, **kwargs: Any) -> None:
"""Add helper methods to the model class (signature compatible with Django Field)"""
super().contribute_to_class(cls, name, private_only=private_only, **kwargs)
# Add get_FOO_rich_choice method
def get_rich_choice_method(instance):
value = getattr(instance, name)
return self.get_rich_choice(value) if value else None
setattr(cls, f'get_{name}_rich_choice', get_rich_choice_method)
# Add get_FOO_display method (Django provides this, but we enhance it)
def get_display_method(instance):
value = getattr(instance, name)
return self.get_choice_display(value) if value else ''
setattr(cls, f'get_{name}_display', get_display_method)
def deconstruct(self):
"""Support for Django migrations"""
name, path, args, kwargs = super().deconstruct()
@@ -114,7 +116,7 @@ class RichChoiceFormField(ChoiceField):
"""
Form field for rich choices with enhanced functionality.
"""
def __init__(
self,
choice_group: str,
@@ -125,7 +127,7 @@ class RichChoiceFormField(ChoiceField):
):
"""
Initialize the form field.
Args:
choice_group: Name of the choice group in the registry
domain: Domain namespace for the choice group
@@ -137,13 +139,13 @@ class RichChoiceFormField(ChoiceField):
self.domain = domain
self.allow_deprecated = allow_deprecated
self.show_descriptions = show_descriptions
# Get choices from registry
if allow_deprecated:
choices_list = registry.get_choices(choice_group, domain)
else:
choices_list = registry.get_active_choices(choice_group, domain)
# Format choices for display
choices = []
for choice in choices_list:
@@ -151,24 +153,24 @@ class RichChoiceFormField(ChoiceField):
if show_descriptions and choice.description:
label = f"{choice.label} - {choice.description}"
choices.append((choice.value, label))
kwargs['choices'] = choices
super().__init__(**kwargs)
def validate(self, value: Any) -> None:
"""Validate the choice value"""
super().validate(value)
if value is None or value == '':
return
# Check if choice exists in registry
choice = registry.get_choice(self.choice_group, value, self.domain)
if choice is None:
raise ValidationError(
f"'{value}' is not a valid choice for {self.choice_group}"
)
# Check if deprecated choices are allowed
if choice.deprecated and not self.allow_deprecated:
raise ValidationError(
@@ -185,7 +187,7 @@ def create_rich_choice_field(
) -> RichChoiceField:
"""
Factory function to create a RichChoiceField.
This is useful for creating fields with consistent settings
across multiple models.
"""

View File

@@ -4,55 +4,57 @@ Choice Registry
Centralized registry for managing all choice definitions across the application.
"""
from typing import Dict, List, Optional, Any
from typing import Any
from django.core.exceptions import ImproperlyConfigured
from .base import RichChoice, ChoiceGroup
from .base import ChoiceGroup, RichChoice
class ChoiceRegistry:
"""
Centralized registry for managing all choice definitions.
This provides a single source of truth for all choice objects
throughout the application, with support for namespacing by domain.
"""
def __init__(self):
self._choices: Dict[str, ChoiceGroup] = {}
self._domains: Dict[str, List[str]] = {}
self._choices: dict[str, ChoiceGroup] = {}
self._domains: dict[str, list[str]] = {}
def register(
self,
name: str,
choices: List[RichChoice],
self,
name: str,
choices: list[RichChoice],
domain: str = "core",
description: str = "",
metadata: Optional[Dict[str, Any]] = None
metadata: dict[str, Any] | None = None
) -> ChoiceGroup:
"""
Register a group of choices.
Args:
name: Unique name for the choice group
choices: List of RichChoice objects
domain: Domain namespace (e.g., 'rides', 'parks', 'accounts')
description: Description of the choice group
metadata: Additional metadata for the group
Returns:
The registered ChoiceGroup
Raises:
ImproperlyConfigured: If name is already registered with different choices
"""
full_name = f"{domain}.{name}"
if full_name in self._choices:
# Check if the existing registration is identical
existing_group = self._choices[full_name]
existing_values = [choice.value for choice in existing_group.choices]
new_values = [choice.value for choice in choices]
if existing_values == new_values:
# Same choices, return existing group (allow duplicate registration)
return existing_group
@@ -62,69 +64,69 @@ class ChoiceRegistry:
f"Choice group '{full_name}' is already registered with different choices. "
f"Existing: {existing_values}, New: {new_values}"
)
choice_group = ChoiceGroup(
name=full_name,
choices=choices,
description=description,
metadata=metadata or {}
)
self._choices[full_name] = choice_group
# Track domain
if domain not in self._domains:
self._domains[domain] = []
self._domains[domain].append(name)
return choice_group
def get(self, name: str, domain: str = "core") -> Optional[ChoiceGroup]:
def get(self, name: str, domain: str = "core") -> ChoiceGroup | None:
"""Get a choice group by name and domain"""
full_name = f"{domain}.{name}"
return self._choices.get(full_name)
def get_choice(self, group_name: str, value: str, domain: str = "core") -> Optional[RichChoice]:
def get_choice(self, group_name: str, value: str, domain: str = "core") -> RichChoice | None:
"""Get a specific choice by group name, value, and domain"""
choice_group = self.get(group_name, domain)
if choice_group:
return choice_group.get_choice(value)
return None
def get_choices(self, name: str, domain: str = "core") -> List[RichChoice]:
def get_choices(self, name: str, domain: str = "core") -> list[RichChoice]:
"""Get all choices in a group"""
choice_group = self.get(name, domain)
return choice_group.choices if choice_group else []
def get_active_choices(self, name: str, domain: str = "core") -> List[RichChoice]:
def get_active_choices(self, name: str, domain: str = "core") -> list[RichChoice]:
"""Get all non-deprecated choices in a group"""
choice_group = self.get(name, domain)
return choice_group.get_active_choices() if choice_group else []
def get_domains(self) -> List[str]:
def get_domains(self) -> list[str]:
"""Get all registered domains"""
return list(self._domains.keys())
def get_domain_choices(self, domain: str) -> Dict[str, ChoiceGroup]:
def get_domain_choices(self, domain: str) -> dict[str, ChoiceGroup]:
"""Get all choice groups for a specific domain"""
if domain not in self._domains:
return {}
return {
name: self._choices[f"{domain}.{name}"]
for name in self._domains[domain]
}
def list_all(self) -> Dict[str, ChoiceGroup]:
def list_all(self) -> dict[str, ChoiceGroup]:
"""Get all registered choice groups"""
return self._choices.copy()
def validate_choice(self, group_name: str, value: str, domain: str = "core") -> bool:
"""Validate that a choice value exists in a group"""
choice = self.get_choice(group_name, value, domain)
return choice is not None and not choice.deprecated
def get_choice_display(self, group_name: str, value: str, domain: str = "core") -> str:
"""Get the display label for a choice value"""
choice = self.get_choice(group_name, value, domain)
@@ -132,7 +134,7 @@ class ChoiceRegistry:
return choice.label
else:
raise ValueError(f"Choice value '{value}' not found in group '{group_name}' for domain '{domain}'")
def clear_domain(self, domain: str) -> None:
"""Clear all choices for a specific domain (useful for testing)"""
if domain in self._domains:
@@ -141,7 +143,7 @@ class ChoiceRegistry:
if full_name in self._choices:
del self._choices[full_name]
del self._domains[domain]
def clear_all(self) -> None:
"""Clear all registered choices (useful for testing)"""
self._choices.clear()
@@ -154,33 +156,33 @@ registry = ChoiceRegistry()
def register_choices(
name: str,
choices: List[RichChoice],
choices: list[RichChoice],
domain: str = "core",
description: str = "",
metadata: Optional[Dict[str, Any]] = None
metadata: dict[str, Any] | None = None
) -> ChoiceGroup:
"""
Convenience function to register choices with the global registry.
Args:
name: Unique name for the choice group
choices: List of RichChoice objects
domain: Domain namespace
description: Description of the choice group
metadata: Additional metadata for the group
Returns:
The registered ChoiceGroup
"""
return registry.register(name, choices, domain, description, metadata)
def get_choices(name: str, domain: str = "core") -> List[RichChoice]:
def get_choices(name: str, domain: str = "core") -> list[RichChoice]:
"""Get choices from the global registry"""
return registry.get_choices(name, domain)
def get_choice(group_name: str, value: str, domain: str = "core") -> Optional[RichChoice]:
def get_choice(group_name: str, value: str, domain: str = "core") -> RichChoice | None:
"""Get a specific choice from the global registry"""
return registry.get_choice(group_name, value, domain)

View File

@@ -5,16 +5,18 @@ This module provides Django REST Framework serializer implementations
for rich choice objects.
"""
from typing import Any, Dict, List
from typing import Any
from rest_framework import serializers
from .base import RichChoice, ChoiceGroup
from .base import ChoiceGroup, RichChoice
from .registry import registry
class RichChoiceSerializer(serializers.Serializer):
"""
Serializer for individual RichChoice objects.
This provides a consistent API representation for choice objects
with all their metadata.
"""
@@ -28,8 +30,8 @@ class RichChoiceSerializer(serializers.Serializer):
icon = serializers.CharField(allow_null=True)
css_class = serializers.CharField(allow_null=True)
sort_order = serializers.IntegerField()
def to_representation(self, instance: RichChoice) -> Dict[str, Any]:
def to_representation(self, instance: RichChoice) -> dict[str, Any]:
"""Convert RichChoice to dictionary representation"""
return instance.to_dict()
@@ -37,7 +39,7 @@ class RichChoiceSerializer(serializers.Serializer):
class RichChoiceOptionSerializer(serializers.Serializer):
"""
Serializer for choice options in filter endpoints.
This replaces the legacy FilterOptionSerializer with rich choice support.
"""
value = serializers.CharField()
@@ -50,8 +52,8 @@ class RichChoiceOptionSerializer(serializers.Serializer):
icon = serializers.CharField(allow_null=True, required=False)
css_class = serializers.CharField(allow_null=True, required=False)
metadata = serializers.DictField(required=False)
def to_representation(self, instance) -> Dict[str, Any]:
def to_representation(self, instance) -> dict[str, Any]:
"""Convert choice option to dictionary representation"""
if isinstance(instance, RichChoice):
# Convert RichChoice to option format
@@ -88,7 +90,7 @@ class RichChoiceOptionSerializer(serializers.Serializer):
class ChoiceGroupSerializer(serializers.Serializer):
"""
Serializer for ChoiceGroup objects.
This provides API representation for entire choice groups
with all their choices and metadata.
"""
@@ -96,8 +98,8 @@ class ChoiceGroupSerializer(serializers.Serializer):
description = serializers.CharField()
metadata = serializers.DictField()
choices = RichChoiceSerializer(many=True)
def to_representation(self, instance: ChoiceGroup) -> Dict[str, Any]:
def to_representation(self, instance: ChoiceGroup) -> dict[str, Any]:
"""Convert ChoiceGroup to dictionary representation"""
return instance.to_dict()
@@ -105,11 +107,11 @@ class ChoiceGroupSerializer(serializers.Serializer):
class RichChoiceFieldSerializer(serializers.CharField):
"""
Serializer field for rich choice values.
This field serializes the choice value but can optionally
include rich choice metadata in the response.
"""
def __init__(
self,
choice_group: str,
@@ -119,7 +121,7 @@ class RichChoiceFieldSerializer(serializers.CharField):
):
"""
Initialize the serializer field.
Args:
choice_group: Name of the choice group in the registry
domain: Domain namespace for the choice group
@@ -130,12 +132,12 @@ class RichChoiceFieldSerializer(serializers.CharField):
self.domain = domain
self.include_metadata = include_metadata
super().__init__(**kwargs)
def to_representation(self, value: str) -> Any:
"""Convert choice value to representation"""
if not value:
return value
if self.include_metadata:
# Return rich choice object
choice = registry.get_choice(self.choice_group, value, self.domain)
@@ -158,7 +160,7 @@ class RichChoiceFieldSerializer(serializers.CharField):
else:
# Return just the value
return value
def to_internal_value(self, data: Any) -> str:
"""Convert input data to choice value"""
if isinstance(data, dict) and 'value' in data:
@@ -175,26 +177,26 @@ def create_choice_options_serializer(
include_counts: bool = False,
queryset=None,
count_field: str = 'id'
) -> List[Dict[str, Any]]:
) -> list[dict[str, Any]]:
"""
Create choice options for filter endpoints.
This function generates choice options with optional counts
for use in filter metadata endpoints.
Args:
choice_group: Name of the choice group in the registry
domain: Domain namespace for the choice group
include_counts: Whether to include counts for each option
queryset: QuerySet to count against (required if include_counts=True)
count_field: Field to filter on for counting (default: 'id')
Returns:
List of choice option dictionaries
"""
choices = registry.get_active_choices(choice_group, domain)
options = []
for choice in choices:
option_data = {
'value': choice.value,
@@ -207,7 +209,7 @@ def create_choice_options_serializer(
'css_class': choice.css_class,
'metadata': choice.metadata,
}
if include_counts and queryset is not None:
# Count items for this choice
try:
@@ -218,9 +220,9 @@ def create_choice_options_serializer(
option_data['count'] = None
else:
option_data['count'] = None
options.append(option_data)
# Sort by sort_order, then by label
options.sort(key=lambda x: (
(lambda c: c.sort_order if (c is not None and hasattr(c, 'sort_order')) else 0)(
@@ -228,7 +230,7 @@ def create_choice_options_serializer(
),
x['label']
))
return options
@@ -240,19 +242,19 @@ def serialize_choice_value(
) -> Any:
"""
Serialize a single choice value.
Args:
value: The choice value to serialize
choice_group: Name of the choice group in the registry
domain: Domain namespace for the choice group
include_metadata: Whether to include rich choice metadata
Returns:
Serialized choice value (string or rich object)
"""
if not value:
return value
if include_metadata:
choice = registry.get_choice(choice_group, value, domain)
if choice:

View File

@@ -4,8 +4,9 @@ Utility Functions for Rich Choices
This module provides utility functions for working with rich choice objects.
"""
from typing import Any, Dict, List, Optional, Tuple
from .base import RichChoice, ChoiceCategory
from typing import Any
from .base import ChoiceCategory, RichChoice
from .registry import registry
@@ -17,27 +18,24 @@ def validate_choice_value(
) -> bool:
"""
Validate that a choice value is valid for a given choice group.
Args:
value: The choice value to validate
choice_group: Name of the choice group in the registry
domain: Domain namespace for the choice group
allow_deprecated: Whether to allow deprecated choices
Returns:
True if valid, False otherwise
"""
if not value:
return True # Allow empty values (handled by field's null/blank settings)
choice = registry.get_choice(choice_group, value, domain)
if choice is None:
return False
if choice.deprecated and not allow_deprecated:
return False
return True
return not (choice.deprecated and not allow_deprecated)
def get_choice_display(
@@ -47,21 +45,21 @@ def get_choice_display(
) -> str:
"""
Get the display label for a choice value.
Args:
value: The choice value
choice_group: Name of the choice group in the registry
domain: Domain namespace for the choice group
Returns:
Display label for the choice
Raises:
ValueError: If the choice value is not found in the registry
"""
if not value:
return ""
choice = registry.get_choice(choice_group, value, domain)
if choice:
return choice.label
@@ -72,24 +70,24 @@ def get_choice_display(
def create_status_choices(
statuses: Dict[str, Dict[str, Any]],
statuses: dict[str, dict[str, Any]],
category: ChoiceCategory = ChoiceCategory.STATUS
) -> List[RichChoice]:
) -> list[RichChoice]:
"""
Create status choices with consistent color coding.
Args:
statuses: Dictionary mapping status value to config dict
category: Choice category (defaults to STATUS)
Returns:
List of RichChoice objects for statuses
"""
choices = []
for value, config in statuses.items():
metadata = config.get('metadata', {})
# Add default status colors if not specified
if 'color' not in metadata:
if 'operating' in value.lower() or 'active' in value.lower():
@@ -102,7 +100,7 @@ def create_status_choices(
metadata['color'] = 'blue'
else:
metadata['color'] = 'gray'
choice = RichChoice(
value=value,
label=config['label'],
@@ -112,26 +110,26 @@ def create_status_choices(
category=category
)
choices.append(choice)
return choices
def create_type_choices(
types: Dict[str, Dict[str, Any]],
types: dict[str, dict[str, Any]],
category: ChoiceCategory = ChoiceCategory.TYPE
) -> List[RichChoice]:
) -> list[RichChoice]:
"""
Create type/classification choices.
Args:
types: Dictionary mapping type value to config dict
category: Choice category (defaults to TYPE)
Returns:
List of RichChoice objects for types
"""
choices = []
for value, config in types.items():
choice = RichChoice(
value=value,
@@ -142,21 +140,21 @@ def create_type_choices(
category=category
)
choices.append(choice)
return choices
def merge_choice_metadata(
base_metadata: Dict[str, Any],
override_metadata: Dict[str, Any]
) -> Dict[str, Any]:
base_metadata: dict[str, Any],
override_metadata: dict[str, Any]
) -> dict[str, Any]:
"""
Merge choice metadata dictionaries.
Args:
base_metadata: Base metadata dictionary
override_metadata: Override metadata dictionary
Returns:
Merged metadata dictionary
"""
@@ -166,16 +164,16 @@ def merge_choice_metadata(
def filter_choices_by_category(
choices: List[RichChoice],
choices: list[RichChoice],
category: ChoiceCategory
) -> List[RichChoice]:
) -> list[RichChoice]:
"""
Filter choices by category.
Args:
choices: List of RichChoice objects
category: Category to filter by
Returns:
Filtered list of choices
"""
@@ -183,16 +181,16 @@ def filter_choices_by_category(
def sort_choices(
choices: List[RichChoice],
choices: list[RichChoice],
sort_by: str = "sort_order"
) -> List[RichChoice]:
) -> list[RichChoice]:
"""
Sort choices by specified criteria.
Args:
choices: List of RichChoice objects
sort_by: Sort criteria ("sort_order", "label", "value")
Returns:
Sorted list of choices
"""
@@ -209,14 +207,14 @@ def sort_choices(
def get_choice_colors(
choice_group: str,
domain: str = "core"
) -> Dict[str, str]:
) -> dict[str, str]:
"""
Get a mapping of choice values to their colors.
Args:
choice_group: Name of the choice group in the registry
domain: Domain namespace for the choice group
Returns:
Dictionary mapping choice values to colors
"""
@@ -230,35 +228,35 @@ def get_choice_colors(
def validate_choice_group_data(
name: str,
choices: List[RichChoice],
choices: list[RichChoice],
domain: str = "core"
) -> List[str]:
) -> list[str]:
"""
Validate choice group data and return list of errors.
Args:
name: Choice group name
choices: List of RichChoice objects
domain: Domain namespace
Returns:
List of validation error messages
"""
errors = []
if not name:
errors.append("Choice group name cannot be empty")
if not choices:
errors.append("Choice group must contain at least one choice")
return errors
# Check for duplicate values
values = [choice.value for choice in choices]
if len(values) != len(set(values)):
duplicates = [v for v in values if values.count(v) > 1]
errors.append(f"Duplicate choice values found: {', '.join(set(duplicates))}")
# Validate individual choices
for i, choice in enumerate(choices):
try:
@@ -273,17 +271,17 @@ def validate_choice_group_data(
)
except ValueError as e:
errors.append(f"Choice {i}: {str(e)}")
return errors
def create_choice_from_config(config: Dict[str, Any]) -> RichChoice:
def create_choice_from_config(config: dict[str, Any]) -> RichChoice:
"""
Create a RichChoice from a configuration dictionary.
Args:
config: Configuration dictionary with choice data
Returns:
RichChoice object
"""
@@ -300,19 +298,19 @@ def create_choice_from_config(config: Dict[str, Any]) -> RichChoice:
def export_choices_to_dict(
choice_group: str,
domain: str = "core"
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""
Export a choice group to a dictionary format.
Args:
choice_group: Name of the choice group in the registry
domain: Domain namespace for the choice group
Returns:
Dictionary representation of the choice group
"""
group = registry.get(choice_group, domain)
if not group:
return {}
return group.to_dict()

View File

@@ -4,23 +4,26 @@ Advanced caching decorators for API views and functions.
import hashlib
import json
import logging
import time
from collections.abc import Callable
from functools import wraps
from typing import Optional, List, Callable, Any, Dict
from typing import Any
from django.http import HttpRequest, HttpResponseBase
from django.utils.decorators import method_decorator
from django.views.decorators.vary import vary_on_headers
from django.views import View
from django.views.decorators.vary import vary_on_headers
from rest_framework.response import Response as DRFResponse
from apps.core.services.enhanced_cache_service import EnhancedCacheService
import logging
logger = logging.getLogger(__name__)
def cache_api_response(
timeout: int = 1800,
vary_on: Optional[List[str]] = None,
vary_on: list[str] | None = None,
key_prefix: str = "api",
cache_backend: str = "api",
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
@@ -82,14 +85,14 @@ def cache_api_response(
"cache_hit": True,
},
)
# If cached data is our dict format for DRF responses, reconstruct it
if isinstance(cached_response, dict) and '__drf_data__' in cached_response:
return DRFResponse(
data=cached_response['__drf_data__'],
data=cached_response['__drf_data__'],
status=cached_response.get('status', 200)
)
return cached_response
# Execute view and cache result
@@ -108,7 +111,7 @@ def cache_api_response(
}
else:
cache_payload = response
getattr(cache_service, cache_backend + "_cache").set(
cache_key, cache_payload, timeout
)
@@ -193,7 +196,7 @@ def cache_queryset_result(
def invalidate_cache_on_save(
model_name: str, cache_patterns: Optional[List[str]] = None
model_name: str, cache_patterns: list[str] | None = None
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
"""
Decorator to invalidate cache when model instances are saved
@@ -313,8 +316,8 @@ class CachedAPIViewMixin(View):
def smart_cache(
timeout: int = 3600,
key_func: Optional[Callable[..., str]] = None,
invalidate_on: Optional[List[str]] = None,
key_func: Callable[..., str] | None = None,
invalidate_on: list[str] | None = None,
cache_backend: str = "default",
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
"""
@@ -378,8 +381,8 @@ def smart_cache(
# Add cache invalidation if specified
if invalidate_on:
setattr(wrapper, "_cache_invalidate_on", invalidate_on)
setattr(wrapper, "_cache_backend", cache_backend)
wrapper._cache_invalidate_on = invalidate_on
wrapper._cache_backend = cache_backend
return wrapper
@@ -431,7 +434,7 @@ def generate_model_cache_key(model_instance: Any, suffix: str = "") -> str:
def generate_queryset_cache_key(
queryset: Any, params: Optional[Dict[str, Any]] = None
queryset: Any, params: dict[str, Any] | None = None
) -> str:
"""Generate cache key for queryset with parameters"""
model_name = queryset.model._meta.model_name

View File

@@ -3,7 +3,7 @@ Custom exception classes for ThrillWiki.
Provides domain-specific exceptions with proper error codes and messages.
"""
from typing import Optional, Dict, Any
from typing import Any
class ThrillWikiException(Exception):
@@ -15,16 +15,16 @@ class ThrillWikiException(Exception):
def __init__(
self,
message: Optional[str] = None,
error_code: Optional[str] = None,
details: Optional[Dict[str, Any]] = None,
message: str | None = None,
error_code: str | None = None,
details: dict[str, Any] | None = None,
):
self.message = message or self.default_message
self.error_code = error_code or self.error_code
self.details = details or {}
super().__init__(self.message)
def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> dict[str, Any]:
"""Convert exception to dictionary for API responses."""
return {
"error_code": self.error_code,
@@ -96,7 +96,7 @@ class ParkNotFoundError(NotFoundError):
default_message = "Park not found"
error_code = "PARK_NOT_FOUND"
def __init__(self, park_slug: Optional[str] = None, **kwargs):
def __init__(self, park_slug: str | None = None, **kwargs):
if park_slug:
kwargs["details"] = {"park_slug": park_slug}
kwargs["message"] = f"Park with slug '{park_slug}' not found"
@@ -122,7 +122,7 @@ class RideNotFoundError(NotFoundError):
default_message = "Ride not found"
error_code = "RIDE_NOT_FOUND"
def __init__(self, ride_slug: Optional[str] = None, **kwargs):
def __init__(self, ride_slug: str | None = None, **kwargs):
if ride_slug:
kwargs["details"] = {"ride_slug": ride_slug}
kwargs["message"] = f"Ride with slug '{ride_slug}' not found"
@@ -150,8 +150,8 @@ class InvalidCoordinatesError(ValidationException):
def __init__(
self,
latitude: Optional[float] = None,
longitude: Optional[float] = None,
latitude: float | None = None,
longitude: float | None = None,
**kwargs,
):
if latitude is not None or longitude is not None:
@@ -198,7 +198,7 @@ class InsufficientPermissionsError(PermissionDeniedError):
default_message = "Insufficient permissions"
error_code = "INSUFFICIENT_PERMISSIONS"
def __init__(self, required_permission: Optional[str] = None, **kwargs):
def __init__(self, required_permission: str | None = None, **kwargs):
if required_permission:
kwargs["details"] = {"required_permission": required_permission}
kwargs["message"] = f"Permission '{required_permission}' required"
@@ -226,7 +226,7 @@ class RoadTripError(ExternalServiceError):
default_message = "Road trip planning error"
error_code = "ROADTRIP_ERROR"
def __init__(self, service_name: Optional[str] = None, **kwargs):
def __init__(self, service_name: str | None = None, **kwargs):
if service_name:
kwargs["details"] = {"service": service_name}
super().__init__(**kwargs)

View File

@@ -1,11 +1,10 @@
"""Core forms and form components."""
from autocomplete import Autocomplete
from django.conf import settings
from django.core.exceptions import PermissionDenied
from django.utils.translation import gettext_lazy as _
from autocomplete import Autocomplete
class BaseAutocomplete(Autocomplete):
"""Base autocomplete class for consistent autocomplete behavior across the project.

View File

@@ -1,8 +1,8 @@
"""
Base forms and views for HTMX integration.
"""
from django.views.generic.edit import FormView
from django.http import JsonResponse
from django.views.generic.edit import FormView
class HTMXFormView(FormView):

View File

@@ -2,9 +2,10 @@
Custom health checks for ThrillWiki application.
"""
import time
import logging
import time
from pathlib import Path
from django.core.cache import cache
from django.db import connection
from health_check.backends import BaseHealthCheckBackend
@@ -165,9 +166,10 @@ class ApplicationHealthCheck(BaseHealthCheckBackend):
# Check if we can access critical models
try:
from parks.models import Park
from apps.rides.models import Ride
from django.contrib.auth import get_user_model
from parks.models import Park
from apps.rides.models import Ride
User = get_user_model()
@@ -185,9 +187,10 @@ class ApplicationHealthCheck(BaseHealthCheckBackend):
self.add_error(f"Model access check failed: {e}")
# Check media and static file configuration
from django.conf import settings
import os
from django.conf import settings
if not os.path.exists(settings.MEDIA_ROOT):
self.add_error(f"Media directory does not exist: {settings.MEDIA_ROOT}")
@@ -208,8 +211,8 @@ class ExternalServiceHealthCheck(BaseHealthCheckBackend):
def check_status(self):
# Check email service if configured
try:
from django.core.mail import get_connection
from django.conf import settings
from django.core.mail import get_connection
if (
hasattr(settings, "EMAIL_BACKEND")
@@ -253,8 +256,8 @@ class ExternalServiceHealthCheck(BaseHealthCheckBackend):
# Check Redis connection if configured
try:
from django.core.cache import caches
from django.conf import settings
from django.core.cache import caches
cache_config = settings.CACHES.get("default", {})
if "redis" in cache_config.get("BACKEND", "").lower():
@@ -279,6 +282,7 @@ class DiskSpaceHealthCheck(BaseHealthCheckBackend):
def check_status(self):
try:
import shutil
from django.conf import settings
# Check disk space for media directory

View File

@@ -1,8 +1,9 @@
from django.db import models
from django.contrib.contenttypes.models import ContentType
from django.contrib.contenttypes.fields import GenericForeignKey
from typing import TYPE_CHECKING, Any
from django.conf import settings
from typing import Any, Dict, Optional, TYPE_CHECKING
from django.contrib.contenttypes.fields import GenericForeignKey
from django.contrib.contenttypes.models import ContentType
from django.db import models
from django.db.models import QuerySet
if TYPE_CHECKING:
@@ -12,7 +13,7 @@ if TYPE_CHECKING:
class DiffMixin:
"""Mixin to add diffing capabilities to models with pghistory"""
def get_prev_record(self) -> Optional[Any]:
def get_prev_record(self) -> Any | None:
"""Get the previous record for this instance"""
try:
# Use getattr to safely access objects manager and pghistory fields
@@ -37,7 +38,7 @@ class DiffMixin:
except (AttributeError, TypeError):
return None
def diff_against_previous(self) -> Dict:
def diff_against_previous(self) -> dict:
"""Compare this record against the previous one"""
prev_record = self.get_prev_record()
if not prev_record:

View File

@@ -24,7 +24,7 @@ import json
from functools import wraps
from typing import TYPE_CHECKING, Any
from django.http import HttpResponse, JsonResponse
from django.http import HttpResponse
from django.template import TemplateDoesNotExist
from django.template.loader import render_to_string

View File

@@ -5,7 +5,8 @@ Provides structured logging with proper formatting and context.
import logging
import sys
from typing import Dict, Any, Optional
from typing import Any
from django.conf import settings
from django.utils import timezone
@@ -65,7 +66,7 @@ def log_exception(
logger: logging.Logger,
exception: Exception,
*,
context: Optional[Dict[str, Any]] = None,
context: dict[str, Any] | None = None,
request=None,
level: int = logging.ERROR,
) -> None:
@@ -111,7 +112,7 @@ def log_business_event(
event_type: str,
*,
message: str,
context: Optional[Dict[str, Any]] = None,
context: dict[str, Any] | None = None,
request=None,
level: int = logging.INFO,
) -> None:
@@ -149,7 +150,7 @@ def log_performance_metric(
operation: str,
*,
duration_ms: float,
context: Optional[Dict[str, Any]] = None,
context: dict[str, Any] | None = None,
level: int = logging.INFO,
) -> None:
"""
@@ -177,8 +178,8 @@ def log_api_request(
logger: logging.Logger,
request,
*,
response_status: Optional[int] = None,
duration_ms: Optional[float] = None,
response_status: int | None = None,
duration_ms: float | None = None,
level: int = logging.INFO,
) -> None:
"""
@@ -219,7 +220,7 @@ def log_security_event(
*,
message: str,
severity: str = "medium",
context: Optional[Dict[str, Any]] = None,
context: dict[str, Any] | None = None,
request=None,
) -> None:
"""

View File

@@ -7,11 +7,12 @@ Run with: uv run manage.py calculate_new_content
import logging
from datetime import datetime, timedelta
from typing import Dict, List, Any
from django.core.management.base import BaseCommand, CommandError
from django.utils import timezone
from typing import Any
from django.core.cache import cache
from django.core.management.base import BaseCommand, CommandError
from django.db.models import Q
from django.utils import timezone
from apps.parks.models import Park
from apps.rides.models import Ride
@@ -102,7 +103,7 @@ class Command(BaseCommand):
logger.error(f"Error calculating new content: {e}", exc_info=True)
raise CommandError(f"Failed to calculate new content: {e}")
def _get_new_parks(self, cutoff_date: datetime, limit: int) -> List[Dict[str, Any]]:
def _get_new_parks(self, cutoff_date: datetime, limit: int) -> list[dict[str, Any]]:
"""Get recently added parks using real data."""
new_parks = (
Park.objects.filter(
@@ -117,9 +118,8 @@ class Command(BaseCommand):
results = []
for park in new_parks:
date_added = park.opening_date or park.created_at
if date_added:
if isinstance(date_added, datetime):
date_added = date_added.date()
if date_added and isinstance(date_added, datetime):
date_added = date_added.date()
opening_date = getattr(park, "opening_date", None)
if opening_date and isinstance(opening_date, datetime):
@@ -142,7 +142,7 @@ class Command(BaseCommand):
return results
def _get_new_rides(self, cutoff_date: datetime, limit: int) -> List[Dict[str, Any]]:
def _get_new_rides(self, cutoff_date: datetime, limit: int) -> list[dict[str, Any]]:
"""Get recently added rides using real data."""
new_rides = (
Ride.objects.filter(
@@ -159,9 +159,8 @@ class Command(BaseCommand):
date_added = getattr(ride, "opening_date", None) or getattr(
ride, "created_at", None
)
if date_added:
if isinstance(date_added, datetime):
date_added = date_added.date()
if date_added and isinstance(date_added, datetime):
date_added = date_added.date()
opening_date = getattr(ride, "opening_date", None)
if opening_date and isinstance(opening_date, datetime):
@@ -186,8 +185,8 @@ class Command(BaseCommand):
return results
def _format_new_content_results(
self, new_items: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
self, new_items: list[dict[str, Any]]
) -> list[dict[str, Any]]:
"""Format new content results for frontend consumption."""
formatted_results = []

View File

@@ -6,11 +6,12 @@ Run with: uv run manage.py calculate_trending
"""
import logging
from typing import Dict, List, Any
from typing import Any
from django.contrib.contenttypes.models import ContentType
from django.core.cache import cache
from django.core.management.base import BaseCommand, CommandError
from django.utils import timezone
from django.core.cache import cache
from django.contrib.contenttypes.models import ContentType
from apps.core.analytics import PageView
from apps.parks.models import Park
@@ -107,7 +108,7 @@ class Command(BaseCommand):
def _calculate_trending_parks(
self, current_period_hours: int, previous_period_hours: int, limit: int
) -> List[Dict[str, Any]]:
) -> list[dict[str, Any]]:
"""Calculate trending scores for parks using real data."""
parks = Park.objects.filter(status="OPERATING").select_related(
"location", "operator"
@@ -151,7 +152,7 @@ class Command(BaseCommand):
def _calculate_trending_rides(
self, current_period_hours: int, previous_period_hours: int, limit: int
) -> List[Dict[str, Any]]:
) -> list[dict[str, Any]]:
"""Calculate trending scores for rides using real data."""
rides = Ride.objects.filter(status="OPERATING").select_related(
"park", "park__location"
@@ -339,10 +340,10 @@ class Command(BaseCommand):
def _format_trending_results(
self,
trending_items: List[Dict[str, Any]],
trending_items: list[dict[str, Any]],
current_period_hours: int,
previous_period_hours: int,
) -> List[Dict[str, Any]]:
) -> list[dict[str, Any]]:
"""Format trending results for frontend consumption."""
formatted_results = []

View File

@@ -15,9 +15,9 @@ import shutil
import subprocess
from pathlib import Path
from django.conf import settings
from django.core.cache import cache, caches
from django.core.management.base import BaseCommand
from django.conf import settings
class Command(BaseCommand):

View File

@@ -6,11 +6,10 @@ showing which callbacks are registered for each model and transition.
"""
from django.core.management.base import BaseCommand, CommandParser
from django.apps import apps
from apps.core.state_machine.callback_base import (
callback_registry,
CallbackStage,
callback_registry,
)
from apps.core.state_machine.config import callback_config

View File

@@ -10,10 +10,10 @@ Usage:
python manage.py optimize_static --force
"""
import os
from pathlib import Path
from django.core.management.base import BaseCommand, CommandError
from django.conf import settings
from django.core.management.base import BaseCommand, CommandError
class Command(BaseCommand):

View File

@@ -5,8 +5,8 @@ This command automatically sets up the development environment and starts
the server, replacing the need for the dev_server.sh script.
"""
from django.core.management.base import BaseCommand
from django.core.management import execute_from_command_line
from django.core.management.base import BaseCommand
class Command(BaseCommand):
@@ -92,7 +92,7 @@ class Command(BaseCommand):
def has_runserver_plus(self):
"""Check if runserver_plus is available (django-extensions)."""
try:
import django_extensions
import django_extensions # noqa: F401
return True
except ImportError:

View File

@@ -10,9 +10,9 @@ Usage:
python manage.py security_audit --verbose
"""
from django.core.management.base import BaseCommand
from django.core.checks import registry, Tags
from django.conf import settings
from django.core.checks import Tags, registry
from django.core.management.base import BaseCommand
class Command(BaseCommand):
@@ -88,10 +88,7 @@ class Command(BaseCommand):
)
else:
for error in errors:
if error.is_serious():
prefix = self.style.ERROR(" ✗ ERROR")
else:
prefix = self.style.WARNING(" ! WARNING")
prefix = self.style.ERROR(" ✗ ERROR") if error.is_serious() else self.style.WARNING(" ! WARNING")
self.log(f"{prefix}: {error.msg}", report_lines)
if error.hint and self.verbose:
@@ -169,10 +166,7 @@ class Command(BaseCommand):
])
for name, is_secure, value in checks:
if is_secure:
status = self.style.SUCCESS("")
else:
status = self.style.WARNING("!")
status = self.style.SUCCESS("") if is_secure else self.style.WARNING("!")
msg = f" {status} {name}"
if self.verbose:

View File

@@ -6,11 +6,10 @@ allowing the project to run without requiring the shell script.
"""
import subprocess
import sys
from pathlib import Path
from django.core.management.base import BaseCommand
from django.conf import settings
from django.core.management.base import BaseCommand
class Command(BaseCommand):

View File

@@ -5,14 +5,14 @@ This command allows testing callbacks for specific transitions
without actually changing model state.
"""
from django.core.management.base import BaseCommand, CommandParser, CommandError
from django.apps import apps
from django.contrib.auth import get_user_model
from django.core.management.base import BaseCommand, CommandError, CommandParser
from apps.core.state_machine.callback_base import (
callback_registry,
CallbackStage,
TransitionContext,
callback_registry,
)
from apps.core.state_machine.monitoring import callback_monitor

View File

@@ -1,13 +1,15 @@
from django.core.management.base import BaseCommand
import random
from datetime import datetime, timedelta
from django.contrib.contenttypes.models import ContentType
from django.core.management.base import BaseCommand
from django.utils import timezone
from apps.parks.models.parks import Park
from apps.rides.models.rides import Ride
from apps.parks.models.companies import Company
from apps.core.analytics import PageView
from apps.core.services.trending_service import trending_service
from datetime import datetime, timedelta
import random
from apps.parks.models.companies import Company
from apps.parks.models.parks import Park
from apps.rides.models.rides import Ride
class Command(BaseCommand):
@@ -205,7 +207,7 @@ class Command(BaseCommand):
content_type = ContentType.objects.get_for_model(type(content_object))
# Create recent views (last 2 hours)
for i in range(recent_views):
for _i in range(recent_views):
view_time = base_time - timedelta(
minutes=random.randint(0, 120) # Last 2 hours
)
@@ -218,7 +220,7 @@ class Command(BaseCommand):
)
# Create older views (2-24 hours ago)
for i in range(older_views):
for _i in range(older_views):
view_time = base_time - timedelta(hours=random.randint(2, 24))
PageView.objects.create(
content_type=content_type,

View File

@@ -1,8 +1,9 @@
from django.core.management.base import BaseCommand
from django.core.cache import cache
from django.core.management.base import BaseCommand
from apps.core.analytics import PageView
from apps.parks.models import Park
from apps.rides.models import Ride
from apps.core.analytics import PageView
class Command(BaseCommand):

View File

@@ -12,14 +12,15 @@ Usage:
import json
import sys
from django.core.management.base import BaseCommand, CommandError
from django.core.management.base import BaseCommand
from config.settings.secrets import (
check_secret_expiry,
validate_required_secrets,
)
from config.settings.validation import (
validate_all_settings,
get_validation_report,
)
from config.settings.secrets import (
validate_required_secrets,
check_secret_expiry,
)

View File

@@ -12,12 +12,13 @@ Usage:
python manage.py warm_cache --dry-run
"""
import time
import logging
from django.core.management.base import BaseCommand
from django.db.models import Count, Avg
import time
from apps.core.services.enhanced_cache_service import EnhancedCacheService, CacheWarmer
from django.core.management.base import BaseCommand
from django.db.models import Count
from apps.core.services.enhanced_cache_service import EnhancedCacheService
logger = logging.getLogger(__name__)
@@ -122,7 +123,7 @@ class Command(BaseCommand):
)
warmed_count += 1
if verbose:
self.stdout.write(f" Cached park status counts")
self.stdout.write(" Cached park status counts")
except Exception as e:
failed_count += 1
self.stdout.write(self.style.ERROR(f" Failed to cache park status counts: {e}"))
@@ -191,7 +192,7 @@ class Command(BaseCommand):
)
warmed_count += 1
if verbose:
self.stdout.write(f" Cached ride category counts")
self.stdout.write(" Cached ride category counts")
except Exception as e:
failed_count += 1
self.stdout.write(self.style.ERROR(f" Failed to cache ride category counts: {e}"))

View File

@@ -3,13 +3,13 @@ Custom managers and QuerySets for optimized database patterns.
Following Django styleguide best practices for database access.
"""
from typing import Optional, List, Union
from django.db import models
from django.db.models import Q, Count, Avg, Max
from datetime import timedelta
from django.contrib.gis.geos import Point
from django.contrib.gis.measure import Distance
from django.db import models
from django.db.models import Avg, Count, Max, Q
from django.utils import timezone
from datetime import timedelta
class BaseQuerySet(models.QuerySet):
@@ -32,7 +32,7 @@ class BaseQuerySet(models.QuerySet):
cutoff_date = timezone.now() - timedelta(days=days)
return self.filter(created_at__gte=cutoff_date)
def search(self, *, query: str, fields: Optional[List[str]] = None):
def search(self, *, query: str, fields: list[str] | None = None):
"""
Full-text search across specified fields.
@@ -81,7 +81,7 @@ class BaseManager(models.Manager):
def recent(self, *, days: int = 30):
return self.get_queryset().recent(days=days)
def search(self, *, query: str, fields: Optional[List[str]] = None):
def search(self, *, query: str, fields: list[str] | None = None):
return self.get_queryset().search(query=query, fields=fields)
@@ -245,7 +245,7 @@ class TimestampedManager(BaseManager):
class StatusQuerySet(BaseQuerySet):
"""QuerySet for models with status fields."""
def with_status(self, *, status: Union[str, List[str]]):
def with_status(self, *, status: str | list[str]):
"""Filter by status."""
if isinstance(status, list):
return self.filter(status__in=status)

View File

@@ -5,9 +5,9 @@ This package contains middleware components for the Django application,
including view tracking and other core functionality.
"""
from .view_tracking import ViewTrackingMiddleware, get_view_stats_for_content
from .analytics import PgHistoryContextMiddleware
from .nextjs import APIResponseMiddleware
from .view_tracking import ViewTrackingMiddleware, get_view_stats_for_content
__all__ = [
"ViewTrackingMiddleware",

View File

@@ -2,6 +2,7 @@
Middleware for handling errors in HTMX requests.
"""
import logging
from django.http import HttpResponseServerError
from django.template.loader import render_to_string

View File

@@ -2,11 +2,12 @@
Performance monitoring middleware for tracking request metrics.
"""
import time
import logging
import time
from django.conf import settings
from django.db import connection
from django.utils.deprecation import MiddlewareMixin
from django.conf import settings
performance_logger = logging.getLogger("performance")
logger = logging.getLogger(__name__)
@@ -162,10 +163,7 @@ class PerformanceMiddleware(MiddlewareMixin):
def _get_client_ip(self, request):
"""Extract client IP address from request"""
x_forwarded_for = request.META.get("HTTP_X_FORWARDED_FOR")
if x_forwarded_for:
ip = x_forwarded_for.split(",")[0].strip()
else:
ip = request.META.get("REMOTE_ADDR", "")
ip = x_forwarded_for.split(",")[0].strip() if x_forwarded_for else request.META.get("REMOTE_ADDR", "")
return ip
def _get_log_level(self, duration, query_count, status_code):

View File

@@ -14,11 +14,10 @@ Usage:
"""
import logging
from typing import Callable, Optional, Tuple
from collections.abc import Callable
from django.core.cache import cache
from django.http import HttpRequest, HttpResponse, JsonResponse
from django.conf import settings
logger = logging.getLogger(__name__)
@@ -88,7 +87,7 @@ class AuthRateLimitMiddleware:
return response
def _get_rate_limits(self, path: str) -> Optional[dict]:
def _get_rate_limits(self, path: str) -> dict | None:
"""Get rate limits for a path, if any."""
# Exact match
if path in self.RATE_LIMITED_PATHS:
@@ -125,7 +124,7 @@ class AuthRateLimitMiddleware:
client_ip: str,
path: str,
limits: dict
) -> Tuple[bool, str]:
) -> tuple[bool, str]:
"""
Check if the client has exceeded rate limits.

View File

@@ -3,9 +3,10 @@ Request logging middleware for comprehensive request/response logging.
Logs all HTTP requests with detailed data for debugging and monitoring.
"""
import json
import logging
import time
import json
from django.utils.deprecation import MiddlewareMixin
logger = logging.getLogger('request_logging')

View File

@@ -9,12 +9,13 @@ analytics for the trending algorithm.
import logging
import re
from datetime import timedelta
from typing import Optional, Union
from django.http import HttpRequest, HttpResponse
from django.utils import timezone
from typing import Union
from django.conf import settings
from django.contrib.contenttypes.models import ContentType
from django.core.cache import cache
from django.conf import settings
from django.http import HttpRequest, HttpResponse
from django.utils import timezone
from apps.core.analytics import PageView
from apps.parks.models import Park
@@ -105,10 +106,7 @@ class ViewTrackingMiddleware:
return True
# Skip AJAX requests (optional - depending on requirements)
if request.META.get("HTTP_X_REQUESTED_WITH") == "XMLHttpRequest":
return True
return False
return request.META.get("HTTP_X_REQUESTED_WITH") == "XMLHttpRequest"
def _track_view_if_applicable(self, request: HttpRequest) -> None:
"""Track view if the URL matches tracked patterns."""
@@ -159,7 +157,7 @@ class ViewTrackingMiddleware:
def _get_content_object(
self, content_type: str, slug: str
) -> Optional[ContentObject]:
) -> ContentObject | None:
"""Get the content object by type and slug."""
try:
if content_type == "park":
@@ -234,7 +232,7 @@ class ViewTrackingMiddleware:
content_type = ContentType.objects.get_for_model(content_obj)
return f"pageview_dedup:{content_type.id}:{content_obj.pk}:{client_ip}"
def _get_client_ip(self, request: HttpRequest) -> Optional[str]:
def _get_client_ip(self, request: HttpRequest) -> str | None:
"""Extract client IP address from request."""
# Check for forwarded IP (common in production with load balancers)
x_forwarded_for = request.META.get("HTTP_X_FORWARDED_FOR")
@@ -270,13 +268,12 @@ class ViewTrackingMiddleware:
# Skip localhost and private IPs in production
if getattr(settings, "SKIP_LOCAL_IPS", not settings.DEBUG):
if ip.startswith(("127.", "192.168.", "10.")) or ip.startswith("172."):
if any(
16 <= int(ip.split(".")[1]) <= 31
for _ in [ip]
if ip.startswith("172.")
):
return False
if (ip.startswith(("127.", "192.168.", "10.")) or ip.startswith("172.")) and any(
16 <= int(ip.split(".")[1]) <= 31
for _ in [ip]
if ip.startswith("172.")
):
return False
return True

View File

@@ -1,6 +1,6 @@
"""HTMX mixins for views. Canonical definitions for partial rendering and triggers."""
from typing import Any, Optional, Type
from typing import Any
from django.template import TemplateDoesNotExist
from django.template.loader import select_template
@@ -11,7 +11,7 @@ from django.views.generic.list import MultipleObjectMixin
class HTMXFilterableMixin(MultipleObjectMixin):
"""Enhance list views to return partial templates for HTMX requests."""
filter_class: Optional[Type[Any]] = None
filter_class: type[Any] | None = None
htmx_partial_suffix = "_partial.html"
def get_queryset(self):
@@ -47,7 +47,7 @@ class HTMXFilterableMixin(MultipleObjectMixin):
class HTMXFormMixin(FormMixin):
"""FormMixin that returns partials and field-level errors for HTMX requests."""
htmx_success_trigger: Optional[str] = None
htmx_success_trigger: str | None = None
def form_invalid(self, form):
"""Return partial with errors on invalid form submission via HTMX."""

View File

@@ -1,9 +1,10 @@
from django.db import models
import pghistory
from django.contrib.contenttypes.fields import GenericForeignKey
from django.contrib.contenttypes.models import ContentType
from django.db import models
from django.utils.text import slugify
from apps.core.history import TrackedModel
import pghistory
@pghistory.track()

View File

@@ -1,5 +1,6 @@
from rest_framework import permissions
class IsOwnerOrReadOnly(permissions.BasePermission):
"""
Custom permission to only allow owners of an object to edit it.

View File

@@ -3,24 +3,26 @@ Selectors for core functionality including map services and analytics.
Following Django styleguide pattern for separating data access from business logic.
"""
from typing import Optional, Dict, Any, List
from django.db.models import QuerySet, Q, Count
from datetime import timedelta
from typing import Any
from django.contrib.gis.geos import Point, Polygon
from django.contrib.gis.measure import Distance
from django.db.models import Count, Q, QuerySet
from django.utils import timezone
from datetime import timedelta
from .analytics import PageView
from apps.parks.models import Park
from apps.rides.models import Ride
from .analytics import PageView
def unified_locations_for_map(
*,
bounds: Optional[Polygon] = None,
location_types: Optional[List[str]] = None,
filters: Optional[Dict[str, Any]] = None,
) -> Dict[str, QuerySet]:
bounds: Polygon | None = None,
location_types: list[str] | None = None,
filters: dict[str, Any] | None = None,
) -> dict[str, QuerySet]:
"""
Get unified location data for map display across all location types.
@@ -88,9 +90,9 @@ def locations_near_point(
*,
point: Point,
distance_km: float = 50,
location_types: Optional[List[str]] = None,
location_types: list[str] | None = None,
limit: int = 20,
) -> Dict[str, QuerySet]:
) -> dict[str, QuerySet]:
"""
Get locations near a specific geographic point across all types.
@@ -149,7 +151,7 @@ def locations_near_point(
return results
def search_all_locations(*, query: str, limit: int = 20) -> Dict[str, QuerySet]:
def search_all_locations(*, query: str, limit: int = 20) -> dict[str, QuerySet]:
"""
Search across all location types for a query string.
@@ -193,9 +195,9 @@ def search_all_locations(*, query: str, limit: int = 20) -> Dict[str, QuerySet]:
def page_views_for_analytics(
*,
start_date: Optional[timezone.datetime] = None,
end_date: Optional[timezone.datetime] = None,
path_pattern: Optional[str] = None,
start_date: timezone.datetime | None = None,
end_date: timezone.datetime | None = None,
path_pattern: str | None = None,
) -> QuerySet[PageView]:
"""
Get page views for analytics with optional filtering.
@@ -222,7 +224,7 @@ def page_views_for_analytics(
return queryset.order_by("-timestamp")
def popular_pages_summary(*, days: int = 30) -> Dict[str, Any]:
def popular_pages_summary(*, days: int = 30) -> dict[str, Any]:
"""
Get summary of most popular pages in the last N days.
@@ -261,7 +263,7 @@ def popular_pages_summary(*, days: int = 30) -> Dict[str, Any]:
}
def geographic_distribution_summary() -> Dict[str, Any]:
def geographic_distribution_summary() -> dict[str, Any]:
"""
Get geographic distribution statistics for all locations.
@@ -290,7 +292,7 @@ def geographic_distribution_summary() -> Dict[str, Any]:
}
def system_health_metrics() -> Dict[str, Any]:
def system_health_metrics() -> dict[str, Any]:
"""
Get system health and activity metrics.

View File

@@ -2,17 +2,17 @@
Core services for ThrillWiki unified map functionality.
"""
from .map_service import UnifiedMapService
from .clustering_service import ClusteringService
from .map_cache_service import MapCacheService
from .data_structures import (
UnifiedLocation,
LocationType,
ClusterData,
GeoBounds,
LocationType,
MapFilters,
MapResponse,
ClusterData,
UnifiedLocation,
)
from .map_cache_service import MapCacheService
from .map_service import UnifiedMapService
__all__ = [
"UnifiedMapService",

View File

@@ -3,15 +3,15 @@ Clustering service for map locations to improve performance and user experience.
"""
import math
from typing import List, Dict, Any, Optional
from dataclasses import dataclass
from collections import defaultdict
from dataclasses import dataclass
from typing import Any
from .data_structures import (
UnifiedLocation,
ClusterData,
GeoBounds,
LocationType,
UnifiedLocation,
)
@@ -70,10 +70,10 @@ class ClusteringService:
def cluster_locations(
self,
locations: List[UnifiedLocation],
locations: list[UnifiedLocation],
zoom_level: int,
bounds: Optional[GeoBounds] = None,
) -> tuple[List[UnifiedLocation], List[ClusterData]]:
bounds: GeoBounds | None = None,
) -> tuple[list[UnifiedLocation], list[ClusterData]]:
"""
Cluster locations based on zoom level and density.
Returns (unclustered_locations, clusters).
@@ -115,9 +115,9 @@ class ClusteringService:
def _project_locations(
self,
locations: List[UnifiedLocation],
bounds: Optional[GeoBounds] = None,
) -> List[ClusterPoint]:
locations: list[UnifiedLocation],
bounds: GeoBounds | None = None,
) -> list[ClusterPoint]:
"""Convert lat/lng coordinates to projected x/y for clustering calculations."""
cluster_points = []
@@ -149,8 +149,8 @@ class ClusteringService:
return cluster_points
def _cluster_points(
self, points: List[ClusterPoint], radius_pixels: int, min_points: int
) -> List[List[ClusterPoint]]:
self, points: list[ClusterPoint], radius_pixels: int, min_points: int
) -> list[list[ClusterPoint]]:
"""
Cluster points using a simple distance-based approach.
Radius is in pixels, converted to meters based on zoom level.
@@ -189,7 +189,7 @@ class ClusteringService:
dy = point1.y - point2.y
return math.sqrt(dx * dx + dy * dy)
def _create_cluster(self, cluster_points: List[ClusterPoint]) -> ClusterData:
def _create_cluster(self, cluster_points: list[ClusterPoint]) -> ClusterData:
"""Create a ClusterData object from a group of points."""
locations = [cp.location for cp in cluster_points]
@@ -205,7 +205,7 @@ class ClusteringService:
)
# Collect location types in cluster
types = set(loc.type for loc in locations)
types = {loc.type for loc in locations}
# Select representative location (highest weight)
representative = self._select_representative_location(locations)
@@ -224,8 +224,8 @@ class ClusteringService:
)
def _select_representative_location(
self, locations: List[UnifiedLocation]
) -> Optional[UnifiedLocation]:
self, locations: list[UnifiedLocation]
) -> UnifiedLocation | None:
"""Select the most representative location for a cluster."""
if not locations:
return None
@@ -259,7 +259,7 @@ class ClusteringService:
# Fall back to highest weight location
return max(locations, key=lambda x: x.cluster_weight)
def get_cluster_breakdown(self, clusters: List[ClusterData]) -> Dict[str, Any]:
def get_cluster_breakdown(self, clusters: list[ClusterData]) -> dict[str, Any]:
"""Get statistics about clustering results."""
if not clusters:
return {
@@ -293,7 +293,7 @@ class ClusteringService:
def expand_cluster(
self, cluster: ClusterData, zoom_level: int
) -> List[UnifiedLocation]:
) -> list[UnifiedLocation]:
"""
Expand a cluster to show individual locations (for drill-down functionality).
This would typically require re-querying the database with the cluster bounds.
@@ -335,7 +335,7 @@ class SmartClusteringRules:
@staticmethod
def calculate_cluster_priority(
locations: List[UnifiedLocation],
locations: list[UnifiedLocation],
) -> UnifiedLocation:
"""Select the representative location for a cluster based on priority rules."""
# Prioritize by: 1) Parks over rides, 2) Higher weight, 3) Better

View File

@@ -4,7 +4,8 @@ Data structures for the unified map service.
from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, List, Optional, Set, Any
from typing import Any
from django.contrib.gis.geos import Polygon
@@ -60,7 +61,7 @@ class GeoBounds:
"""Check if a point is within these bounds."""
return self.south <= lat <= self.north and self.west <= lng <= self.east
def to_dict(self) -> Dict[str, float]:
def to_dict(self) -> dict[str, float]:
"""Convert to dictionary for JSON serialization."""
return {
"north": self.north,
@@ -74,18 +75,18 @@ class GeoBounds:
class MapFilters:
"""Filtering options for map queries."""
location_types: Optional[Set[LocationType]] = None
park_status: Optional[Set[str]] = None # OPERATING, CLOSED_TEMP, etc.
ride_types: Optional[Set[str]] = None
company_roles: Optional[Set[str]] = None # OPERATOR, MANUFACTURER, etc.
search_query: Optional[str] = None
min_rating: Optional[float] = None
location_types: set[LocationType] | None = None
park_status: set[str] | None = None # OPERATING, CLOSED_TEMP, etc.
ride_types: set[str] | None = None
company_roles: set[str] | None = None # OPERATOR, MANUFACTURER, etc.
search_query: str | None = None
min_rating: float | None = None
has_coordinates: bool = True
country: Optional[str] = None
state: Optional[str] = None
city: Optional[str] = None
country: str | None = None
state: str | None = None
city: str | None = None
def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary for caching and serialization."""
return {
"location_types": (
@@ -110,10 +111,10 @@ class UnifiedLocation:
id: str # Composite: f"{type}_{id}"
type: LocationType
name: str
coordinates: List[float] # [lat, lng]
address: Optional[str] = None
metadata: Dict[str, Any] = field(default_factory=dict)
type_data: Dict[str, Any] = field(default_factory=dict)
coordinates: list[float] # [lat, lng]
address: str | None = None
metadata: dict[str, Any] = field(default_factory=dict)
type_data: dict[str, Any] = field(default_factory=dict)
cluster_weight: int = 1
cluster_category: str = "default"
@@ -127,7 +128,7 @@ class UnifiedLocation:
"""Get longitude from coordinates."""
return self.coordinates[1]
def to_geojson_feature(self) -> Dict[str, Any]:
def to_geojson_feature(self) -> dict[str, Any]:
"""Convert to GeoJSON feature for mapping libraries."""
return {
"type": "Feature",
@@ -148,7 +149,7 @@ class UnifiedLocation:
},
}
def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary for JSON responses."""
return {
"id": self.id,
@@ -168,13 +169,13 @@ class ClusterData:
"""Represents a cluster of locations for map display."""
id: str
coordinates: List[float] # [lat, lng]
coordinates: list[float] # [lat, lng]
count: int
types: Set[LocationType]
types: set[LocationType]
bounds: GeoBounds
representative_location: Optional[UnifiedLocation] = None
representative_location: UnifiedLocation | None = None
def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary for JSON responses."""
return {
"id": self.id,
@@ -194,18 +195,18 @@ class ClusterData:
class MapResponse:
"""Response structure for map API calls."""
locations: List[UnifiedLocation] = field(default_factory=list)
clusters: List[ClusterData] = field(default_factory=list)
bounds: Optional[GeoBounds] = None
locations: list[UnifiedLocation] = field(default_factory=list)
clusters: list[ClusterData] = field(default_factory=list)
bounds: GeoBounds | None = None
total_count: int = 0
filtered_count: int = 0
zoom_level: Optional[int] = None
zoom_level: int | None = None
clustered: bool = False
cache_hit: bool = False
query_time_ms: Optional[int] = None
filters_applied: List[str] = field(default_factory=list)
query_time_ms: int | None = None
filters_applied: list[str] = field(default_factory=list)
def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary for JSON responses."""
return {
"status": "success",
@@ -241,7 +242,7 @@ class QueryPerformanceMetrics:
bounds_used: bool
clustering_used: bool
def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary for logging."""
return {
"query_time_ms": self.query_time_ms,

View File

@@ -2,13 +2,15 @@
Enhanced caching service with multiple cache backends and strategies.
"""
from typing import Optional, Any, Dict, Callable
from django.core.cache import caches
import hashlib
import json
import logging
import time
from collections.abc import Callable
from functools import wraps
from typing import Any
from django.core.cache import caches
logger = logging.getLogger(__name__)
@@ -64,7 +66,7 @@ class EnhancedCacheService:
def cache_api_response(
self,
view_name: str,
params: Dict,
params: dict,
response_data: Any,
timeout: int = 1800,
):
@@ -73,7 +75,7 @@ class EnhancedCacheService:
self.api_cache.set(cache_key, response_data, timeout)
logger.debug(f"Cached API response for view '{view_name}'")
def get_cached_api_response(self, view_name: str, params: Dict) -> Optional[Any]:
def get_cached_api_response(self, view_name: str, params: dict) -> Any | None:
"""Retrieve cached API response"""
cache_key = self._generate_api_cache_key(view_name, params)
result = self.api_cache.get(cache_key)
@@ -103,7 +105,7 @@ class EnhancedCacheService:
def get_cached_geographic_data(
self, bounds: "GeoBounds", zoom_level: int
) -> Optional[Any]:
) -> Any | None:
"""Retrieve cached geographic data"""
cache_key = f"geo:{bounds.min_lat}:{bounds.min_lng}:{bounds.max_lat}:{
bounds.max_lng
@@ -129,13 +131,10 @@ class EnhancedCacheService:
logger.error(f"Error invalidating cache pattern '{pattern}': {e}")
def invalidate_model_cache(
self, model_name: str, instance_id: Optional[int] = None
self, model_name: str, instance_id: int | None = None
):
"""Invalidate cache keys related to a specific model"""
if instance_id:
pattern = f"*{model_name}:{instance_id}*"
else:
pattern = f"*{model_name}*"
pattern = f"*{model_name}:{instance_id}*" if instance_id else f"*{model_name}*"
self.invalidate_pattern(pattern)
@@ -155,7 +154,7 @@ class EnhancedCacheService:
except Exception as e:
logger.error(f"Error warming cache for key '{cache_key}': {e}")
def _generate_api_cache_key(self, view_name: str, params: Dict) -> str:
def _generate_api_cache_key(self, view_name: str, params: dict) -> str:
"""Generate consistent cache keys for API responses"""
# Sort params to ensure consistent key generation
params_str = json.dumps(params, sort_keys=True, default=str)
@@ -275,7 +274,7 @@ class CacheMonitor:
def __init__(self):
self.cache_service = EnhancedCacheService()
def get_cache_stats(self) -> Dict[str, Any]:
def get_cache_stats(self) -> dict[str, Any]:
"""Get cache statistics if available"""
stats = {}
@@ -319,7 +318,7 @@ class CacheMonitor:
if stats:
logger.info("Cache performance statistics", extra=stats)
def get_cache_statistics(self, key_prefix: str = "") -> Dict[str, Any]:
def get_cache_statistics(self, key_prefix: str = "") -> dict[str, Any]:
"""
Get cache statistics for a given key prefix.

View File

@@ -13,16 +13,15 @@ Features:
"""
import re
from difflib import SequenceMatcher
from typing import List, Dict, Any, Optional
from dataclasses import dataclass
from difflib import SequenceMatcher
from enum import Enum
from typing import Any
from django.db.models import Q
from apps.parks.models import Park
from apps.parks.models import Company, Park
from apps.rides.models import Ride
from apps.parks.models import Company
class EntityType(Enum):
@@ -44,9 +43,9 @@ class FuzzyMatchResult:
score: float # 0.0 to 1.0, higher is better match
match_reason: str # Description of why this was matched
confidence: str # 'high', 'medium', 'low'
url: Optional[str] = None
url: str | None = None
def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary for API responses."""
return {
"entity_type": self.entity_type.value,
@@ -180,8 +179,8 @@ class EntityFuzzyMatcher:
self.algorithms = FuzzyMatchingAlgorithms()
def find_entity(
self, query: str, entity_types: Optional[List[EntityType]] = None, user=None
) -> tuple[List[FuzzyMatchResult], Optional[EntitySuggestion]]:
self, query: str, entity_types: list[EntityType] | None = None, user=None
) -> tuple[list[FuzzyMatchResult], EntitySuggestion | None]:
"""
Find entities matching the query with fuzzy matching.
@@ -221,7 +220,7 @@ class EntityFuzzyMatcher:
def _get_candidates(
self, query: str, entity_type: EntityType
) -> List[Dict[str, Any]]:
) -> list[dict[str, Any]]:
"""Get potential matching candidates for an entity type."""
candidates = []
@@ -286,8 +285,8 @@ class EntityFuzzyMatcher:
return candidates
def _score_and_rank_candidates(
self, query: str, candidates: List[Dict[str, Any]]
) -> List[FuzzyMatchResult]:
self, query: str, candidates: list[dict[str, Any]]
) -> list[FuzzyMatchResult]:
"""Score and rank all candidates using multiple algorithms."""
scored_matches = []
@@ -356,7 +355,7 @@ class EntityFuzzyMatcher:
return sorted(scored_matches, key=lambda x: x.score, reverse=True)
def _generate_entity_suggestion(
self, query: str, entity_types: List[EntityType], user
self, query: str, entity_types: list[EntityType], user
) -> EntitySuggestion:
"""Generate suggestion for creating new entity when no matches found."""

View File

@@ -2,36 +2,37 @@
Location adapters for converting between domain-specific models and UnifiedLocation.
"""
from typing import List, Optional
from django.db.models import QuerySet
from django.urls import reverse
from .data_structures import (
UnifiedLocation,
LocationType,
GeoBounds,
MapFilters,
)
from apps.parks.models import ParkLocation, CompanyHeadquarters
from apps.parks.models import CompanyHeadquarters, ParkLocation
from apps.rides.models import RideLocation
from .data_structures import (
GeoBounds,
LocationType,
MapFilters,
UnifiedLocation,
)
class BaseLocationAdapter:
"""Base adapter class for location conversions."""
def to_unified_location(self, location_obj) -> Optional[UnifiedLocation]:
def to_unified_location(self, location_obj) -> UnifiedLocation | None:
"""Convert model instance to UnifiedLocation."""
raise NotImplementedError
def get_queryset(
self,
bounds: Optional[GeoBounds] = None,
filters: Optional[MapFilters] = None,
bounds: GeoBounds | None = None,
filters: MapFilters | None = None,
) -> QuerySet:
"""Get optimized queryset for this location type."""
raise NotImplementedError
def bulk_convert(self, queryset: QuerySet) -> List[UnifiedLocation]:
def bulk_convert(self, queryset: QuerySet) -> list[UnifiedLocation]:
"""Convert multiple location objects efficiently."""
unified_locations = []
for obj in queryset:
@@ -46,7 +47,7 @@ class ParkLocationAdapter(BaseLocationAdapter):
def to_unified_location(
self, location_obj: ParkLocation
) -> Optional[UnifiedLocation]:
) -> UnifiedLocation | None:
"""Convert ParkLocation to UnifiedLocation."""
if (
not location_obj.point
@@ -106,8 +107,8 @@ class ParkLocationAdapter(BaseLocationAdapter):
def get_queryset(
self,
bounds: Optional[GeoBounds] = None,
filters: Optional[MapFilters] = None,
bounds: GeoBounds | None = None,
filters: MapFilters | None = None,
) -> QuerySet:
"""Get optimized queryset for park locations."""
queryset = ParkLocation.objects.select_related("park", "park__operator").filter(
@@ -177,7 +178,7 @@ class RideLocationAdapter(BaseLocationAdapter):
def to_unified_location(
self, location_obj: RideLocation
) -> Optional[UnifiedLocation]:
) -> UnifiedLocation | None:
"""Convert RideLocation to UnifiedLocation."""
if (
not location_obj.point
@@ -235,8 +236,8 @@ class RideLocationAdapter(BaseLocationAdapter):
def get_queryset(
self,
bounds: Optional[GeoBounds] = None,
filters: Optional[MapFilters] = None,
bounds: GeoBounds | None = None,
filters: MapFilters | None = None,
) -> QuerySet:
"""Get optimized queryset for ride locations."""
queryset = RideLocation.objects.select_related(
@@ -293,7 +294,7 @@ class CompanyLocationAdapter(BaseLocationAdapter):
def to_unified_location(
self, location_obj: CompanyHeadquarters
) -> Optional[UnifiedLocation]:
) -> UnifiedLocation | None:
"""Convert CompanyHeadquarters to UnifiedLocation."""
# Note: CompanyHeadquarters doesn't have coordinates, so we need to geocode
# For now, we'll skip companies without coordinates
@@ -302,8 +303,8 @@ class CompanyLocationAdapter(BaseLocationAdapter):
def get_queryset(
self,
bounds: Optional[GeoBounds] = None,
filters: Optional[MapFilters] = None,
bounds: GeoBounds | None = None,
filters: MapFilters | None = None,
) -> QuerySet:
"""Get optimized queryset for company locations."""
queryset = CompanyHeadquarters.objects.select_related("company")
@@ -346,9 +347,9 @@ class LocationAbstractionLayer:
def get_all_locations(
self,
bounds: Optional[GeoBounds] = None,
filters: Optional[MapFilters] = None,
) -> List[UnifiedLocation]:
bounds: GeoBounds | None = None,
filters: MapFilters | None = None,
) -> list[UnifiedLocation]:
"""Get locations from all sources within bounds."""
all_locations = []
@@ -370,9 +371,9 @@ class LocationAbstractionLayer:
def get_locations_by_type(
self,
location_type: LocationType,
bounds: Optional[GeoBounds] = None,
filters: Optional[MapFilters] = None,
) -> List[UnifiedLocation]:
bounds: GeoBounds | None = None,
filters: MapFilters | None = None,
) -> list[UnifiedLocation]:
"""Get locations of specific type."""
adapter = self.adapters[location_type]
queryset = adapter.get_queryset(bounds, filters)
@@ -380,7 +381,7 @@ class LocationAbstractionLayer:
def get_location_by_id(
self, location_type: LocationType, location_id: int
) -> Optional[UnifiedLocation]:
) -> UnifiedLocation | None:
"""Get single location with full details."""
adapter = self.adapters[location_type]

View File

@@ -6,13 +6,14 @@ to provide proximity-based search, location filtering, and geographic
search capabilities.
"""
from dataclasses import dataclass
from typing import Any
from django.contrib.gis.geos import Point
from django.contrib.gis.measure import Distance
from django.db.models import Q
from typing import Optional, List, Dict, Any, Set
from dataclasses import dataclass
from apps.parks.models import Park, Company, ParkLocation
from apps.parks.models import Company, Park, ParkLocation
from apps.rides.models import Ride
@@ -21,22 +22,22 @@ class LocationSearchFilters:
"""Filters for location-aware search queries."""
# Text search
search_query: Optional[str] = None
search_query: str | None = None
# Location-based filters
location_point: Optional[Point] = None
radius_km: Optional[float] = None
location_types: Optional[Set[str]] = None # 'park', 'ride', 'company'
location_point: Point | None = None
radius_km: float | None = None
location_types: set[str] | None = None # 'park', 'ride', 'company'
# Geographic filters
country: Optional[str] = None
state: Optional[str] = None
city: Optional[str] = None
country: str | None = None
state: str | None = None
city: str | None = None
# Content-specific filters
park_status: Optional[List[str]] = None
ride_types: Optional[List[str]] = None
company_roles: Optional[List[str]] = None
park_status: list[str] | None = None
ride_types: list[str] | None = None
company_roles: list[str] | None = None
# Result options
include_distance: bool = True
@@ -51,26 +52,26 @@ class LocationSearchResult:
content_type: str # 'park', 'ride', 'company'
object_id: int
name: str
description: Optional[str] = None
url: Optional[str] = None
description: str | None = None
url: str | None = None
# Location data
latitude: Optional[float] = None
longitude: Optional[float] = None
address: Optional[str] = None
city: Optional[str] = None
state: Optional[str] = None
country: Optional[str] = None
latitude: float | None = None
longitude: float | None = None
address: str | None = None
city: str | None = None
state: str | None = None
country: str | None = None
# Distance data (if proximity search)
distance_km: Optional[float] = None
distance_km: float | None = None
# Additional metadata
status: Optional[str] = None
tags: Optional[List[str]] = None
rating: Optional[float] = None
status: str | None = None
tags: list[str] | None = None
rating: float | None = None
def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary for JSON serialization."""
return {
"content_type": self.content_type,
@@ -96,7 +97,7 @@ class LocationSearchResult:
class LocationSearchService:
"""Service for performing location-aware searches across ThrillWiki content."""
def search(self, filters: LocationSearchFilters) -> List[LocationSearchResult]:
def search(self, filters: LocationSearchFilters) -> list[LocationSearchResult]:
"""
Perform a comprehensive location-aware search.
@@ -129,7 +130,7 @@ class LocationSearchService:
def _search_parks(
self, filters: LocationSearchFilters
) -> List[LocationSearchResult]:
) -> list[LocationSearchResult]:
"""Search parks with location data."""
queryset = Park.objects.select_related("location", "operator").all()
@@ -199,7 +200,7 @@ class LocationSearchService:
def _search_rides(
self, filters: LocationSearchFilters
) -> List[LocationSearchResult]:
) -> list[LocationSearchResult]:
"""Search rides with location data."""
queryset = Ride.objects.select_related("park", "location").all()
@@ -282,7 +283,7 @@ class LocationSearchService:
def _search_companies(
self, filters: LocationSearchFilters
) -> List[LocationSearchResult]:
) -> list[LocationSearchResult]:
"""Search companies with headquarters location data."""
queryset = Company.objects.select_related("headquarters").all()
@@ -398,7 +399,7 @@ class LocationSearchService:
return queryset
def suggest_locations(self, query: str, limit: int = 10) -> List[Dict[str, Any]]:
def suggest_locations(self, query: str, limit: int = 10) -> list[dict[str, Any]]:
"""
Get location suggestions for autocomplete.

View File

@@ -5,18 +5,18 @@ Caching service for map data to improve performance and reduce database load.
import hashlib
import json
import time
from typing import Dict, List, Optional, Any
from typing import Any
from django.core.cache import cache
from django.utils import timezone
from .data_structures import (
UnifiedLocation,
ClusterData,
GeoBounds,
MapFilters,
MapResponse,
QueryPerformanceMetrics,
UnifiedLocation,
)
@@ -52,9 +52,9 @@ class MapCacheService:
def get_locations_cache_key(
self,
bounds: Optional[GeoBounds],
filters: Optional[MapFilters],
zoom_level: Optional[int] = None,
bounds: GeoBounds | None,
filters: MapFilters | None,
zoom_level: int | None = None,
) -> str:
"""Generate cache key for location queries."""
key_parts = [self.LOCATIONS_PREFIX]
@@ -76,8 +76,8 @@ class MapCacheService:
def get_clusters_cache_key(
self,
bounds: Optional[GeoBounds],
filters: Optional[MapFilters],
bounds: GeoBounds | None,
filters: MapFilters | None,
zoom_level: int,
) -> str:
"""Generate cache key for cluster queries."""
@@ -102,8 +102,8 @@ class MapCacheService:
def cache_locations(
self,
cache_key: str,
locations: List[UnifiedLocation],
ttl: Optional[int] = None,
locations: list[UnifiedLocation],
ttl: int | None = None,
) -> None:
"""Cache location data."""
try:
@@ -122,8 +122,8 @@ class MapCacheService:
def cache_clusters(
self,
cache_key: str,
clusters: List[ClusterData],
ttl: Optional[int] = None,
clusters: list[ClusterData],
ttl: int | None = None,
) -> None:
"""Cache cluster data."""
try:
@@ -138,7 +138,7 @@ class MapCacheService:
print(f"Cache write error for clusters {cache_key}: {e}")
def cache_map_response(
self, cache_key: str, response: MapResponse, ttl: Optional[int] = None
self, cache_key: str, response: MapResponse, ttl: int | None = None
) -> None:
"""Cache complete map response."""
try:
@@ -149,7 +149,7 @@ class MapCacheService:
except Exception as e:
print(f"Cache write error for response {cache_key}: {e}")
def get_cached_locations(self, cache_key: str) -> Optional[List[UnifiedLocation]]:
def get_cached_locations(self, cache_key: str) -> list[UnifiedLocation] | None:
"""Retrieve cached location data."""
try:
cache_data = cache.get(cache_key)
@@ -172,7 +172,7 @@ class MapCacheService:
self.cache_stats["misses"] += 1
return None
def get_cached_clusters(self, cache_key: str) -> Optional[List[ClusterData]]:
def get_cached_clusters(self, cache_key: str) -> list[ClusterData] | None:
"""Retrieve cached cluster data."""
try:
cache_data = cache.get(cache_key)
@@ -194,7 +194,7 @@ class MapCacheService:
self.cache_stats["misses"] += 1
return None
def get_cached_map_response(self, cache_key: str) -> Optional[MapResponse]:
def get_cached_map_response(self, cache_key: str) -> MapResponse | None:
"""Retrieve cached map response."""
try:
cache_data = cache.get(cache_key)
@@ -213,7 +213,7 @@ class MapCacheService:
return None
def invalidate_location_cache(
self, location_type: str, location_id: Optional[int] = None
self, location_type: str, location_id: int | None = None
) -> None:
"""Invalidate cache for specific location or all locations of a type."""
try:
@@ -268,7 +268,7 @@ class MapCacheService:
except Exception as e:
print(f"Cache clear error: {e}")
def get_cache_stats(self) -> Dict[str, Any]:
def get_cache_stats(self) -> dict[str, Any]:
"""Get cache performance statistics."""
total_requests = self.cache_stats["hits"] + self.cache_stats["misses"]
hit_rate = (
@@ -370,7 +370,7 @@ class MapCacheService:
filter_str = json.dumps(filter_dict, sort_keys=True)
return hashlib.md5(filter_str.encode()).hexdigest()[:8]
def _dict_to_unified_location(self, data: Dict[str, Any]) -> UnifiedLocation:
def _dict_to_unified_location(self, data: dict[str, Any]) -> UnifiedLocation:
"""Convert dictionary back to UnifiedLocation object."""
from .data_structures import LocationType
@@ -386,7 +386,7 @@ class MapCacheService:
cluster_category=data.get("cluster_category", "default"),
)
def _dict_to_cluster_data(self, data: Dict[str, Any]) -> ClusterData:
def _dict_to_cluster_data(self, data: dict[str, Any]) -> ClusterData:
"""Convert dictionary back to ClusterData object."""
from .data_structures import LocationType
@@ -406,7 +406,7 @@ class MapCacheService:
representative_location=representative,
)
def _dict_to_map_response(self, data: Dict[str, Any]) -> MapResponse:
def _dict_to_map_response(self, data: dict[str, Any]) -> MapResponse:
"""Convert dictionary back to MapResponse object."""
locations = [
self._dict_to_unified_location(loc) for loc in data.get("locations", [])

View File

@@ -3,20 +3,21 @@ Unified Map Service - Main orchestrating service for all map functionality.
"""
import time
from typing import List, Optional, Dict, Any, Set
from typing import Any
from django.db import connection
from .clustering_service import ClusteringService
from .data_structures import (
UnifiedLocation,
ClusterData,
GeoBounds,
LocationType,
MapFilters,
MapResponse,
LocationType,
QueryPerformanceMetrics,
UnifiedLocation,
)
from .location_adapters import LocationAbstractionLayer
from .clustering_service import ClusteringService
from .map_cache_service import MapCacheService
@@ -39,8 +40,8 @@ class UnifiedMapService:
def get_map_data(
self,
*,
bounds: Optional[GeoBounds] = None,
filters: Optional[MapFilters] = None,
bounds: GeoBounds | None = None,
filters: MapFilters | None = None,
zoom_level: int = DEFAULT_ZOOM_LEVEL,
cluster: bool = True,
use_cache: bool = True,
@@ -145,7 +146,7 @@ class UnifiedMapService:
def get_location_details(
self, location_type: str, location_id: int
) -> Optional[UnifiedLocation]:
) -> UnifiedLocation | None:
"""
Get detailed information for a specific location.
@@ -188,10 +189,10 @@ class UnifiedMapService:
def search_locations(
self,
query: str,
bounds: Optional[GeoBounds] = None,
location_types: Optional[Set[LocationType]] = None,
bounds: GeoBounds | None = None,
location_types: set[LocationType] | None = None,
limit: int = 50,
) -> List[UnifiedLocation]:
) -> list[UnifiedLocation]:
"""
Search locations with text query.
@@ -228,7 +229,7 @@ class UnifiedMapService:
south: float,
east: float,
west: float,
location_types: Optional[Set[LocationType]] = None,
location_types: set[LocationType] | None = None,
zoom_level: int = DEFAULT_ZOOM_LEVEL,
) -> MapResponse:
"""
@@ -261,8 +262,8 @@ class UnifiedMapService:
def get_clustered_locations(
self,
zoom_level: int,
bounds: Optional[GeoBounds] = None,
filters: Optional[MapFilters] = None,
bounds: GeoBounds | None = None,
filters: MapFilters | None = None,
) -> MapResponse:
"""
Get clustered location data for map display.
@@ -282,9 +283,9 @@ class UnifiedMapService:
def get_locations_by_type(
self,
location_type: LocationType,
bounds: Optional[GeoBounds] = None,
limit: Optional[int] = None,
) -> List[UnifiedLocation]:
bounds: GeoBounds | None = None,
limit: int | None = None,
) -> list[UnifiedLocation]:
"""
Get locations of a specific type.
@@ -313,9 +314,9 @@ class UnifiedMapService:
def invalidate_cache(
self,
location_type: Optional[str] = None,
location_id: Optional[int] = None,
bounds: Optional[GeoBounds] = None,
location_type: str | None = None,
location_id: int | None = None,
bounds: GeoBounds | None = None,
) -> None:
"""
Invalidate cached map data.
@@ -332,7 +333,7 @@ class UnifiedMapService:
else:
self.cache_service.clear_all_map_cache()
def get_service_stats(self) -> Dict[str, Any]:
def get_service_stats(self) -> dict[str, Any]:
"""Get service performance and usage statistics."""
cache_stats = self.cache_service.get_cache_stats()
@@ -346,17 +347,17 @@ class UnifiedMapService:
}
def _get_locations_from_db(
self, bounds: Optional[GeoBounds], filters: Optional[MapFilters]
) -> List[UnifiedLocation]:
self, bounds: GeoBounds | None, filters: MapFilters | None
) -> list[UnifiedLocation]:
"""Get locations from database using the abstraction layer."""
return self.location_layer.get_all_locations(bounds, filters)
def _apply_smart_limiting(
self,
locations: List[UnifiedLocation],
bounds: Optional[GeoBounds],
locations: list[UnifiedLocation],
bounds: GeoBounds | None,
zoom_level: int,
) -> List[UnifiedLocation]:
) -> list[UnifiedLocation]:
"""Apply intelligent limiting based on zoom level and density."""
if zoom_level < 6: # Very zoomed out - show only major parks
major_parks = [
@@ -375,10 +376,10 @@ class UnifiedMapService:
def _calculate_response_bounds(
self,
locations: List[UnifiedLocation],
clusters: List[ClusterData],
request_bounds: Optional[GeoBounds],
) -> Optional[GeoBounds]:
locations: list[UnifiedLocation],
clusters: list[ClusterData],
request_bounds: GeoBounds | None,
) -> GeoBounds | None:
"""Calculate the actual bounds of the response data."""
if request_bounds:
return request_bounds
@@ -396,12 +397,12 @@ class UnifiedMapService:
if not all_coords:
return None
lats, lngs = zip(*all_coords)
lats, lngs = zip(*all_coords, strict=False)
return GeoBounds(
north=max(lats), south=min(lats), east=max(lngs), west=min(lngs)
)
def _get_applied_filters_list(self, filters: Optional[MapFilters]) -> List[str]:
def _get_applied_filters_list(self, filters: MapFilters | None) -> list[str]:
"""Get list of applied filter types for metadata."""
if not filters:
return []
@@ -430,8 +431,8 @@ class UnifiedMapService:
def _generate_cache_key(
self,
bounds: Optional[GeoBounds],
filters: Optional[MapFilters],
bounds: GeoBounds | None,
filters: MapFilters | None,
zoom_level: int,
cluster: bool,
) -> str:

View File

@@ -6,12 +6,13 @@ that can be used across all domain-specific media implementations.
"""
import logging
from typing import Any, Optional, Dict
from datetime import datetime
from django.core.files.uploadedfile import UploadedFile
from django.conf import settings
from PIL import Image, ExifTags
import os
from datetime import datetime
from typing import Any
from django.conf import settings
from django.core.files.uploadedfile import UploadedFile
from PIL import ExifTags, Image
logger = logging.getLogger(__name__)
@@ -21,7 +22,7 @@ class MediaService:
@staticmethod
def generate_upload_path(
domain: str, identifier: str, filename: str, subdirectory: Optional[str] = None
domain: str, identifier: str, filename: str, subdirectory: str | None = None
) -> str:
"""
Generate standardized upload path for media files.
@@ -44,7 +45,7 @@ class MediaService:
return f"{domain}/{identifier}/{base_filename}"
@staticmethod
def extract_exif_date(image_file: UploadedFile) -> Optional[datetime]:
def extract_exif_date(image_file: UploadedFile) -> datetime | None:
"""
Extract the date taken from image EXIF data.
@@ -60,18 +61,17 @@ class MediaService:
if exif:
# Find the DateTime tag ID
for tag_id in ExifTags.TAGS:
if ExifTags.TAGS[tag_id] == "DateTimeOriginal":
if tag_id in exif:
# EXIF dates are typically in format: '2024:02:15 14:30:00'
date_str = exif[tag_id]
return datetime.strptime(date_str, "%Y:%m:%d %H:%M:%S")
if ExifTags.TAGS[tag_id] == "DateTimeOriginal" and tag_id in exif:
# EXIF dates are typically in format: '2024:02:15 14:30:00'
date_str = exif[tag_id]
return datetime.strptime(date_str, "%Y:%m:%d %H:%M:%S")
return None
except Exception as e:
logger.warning(f"Failed to extract EXIF date: {str(e)}")
return None
@staticmethod
def validate_image_file(image_file: UploadedFile) -> tuple[bool, Optional[str]]:
def validate_image_file(image_file: UploadedFile) -> tuple[bool, str | None]:
"""
Validate uploaded image file.
@@ -144,6 +144,7 @@ class MediaService:
# Save processed image
from io import BytesIO
from django.core.files.uploadedfile import InMemoryUploadedFile
output = BytesIO()
@@ -180,7 +181,7 @@ class MediaService:
return f"Uploaded by {username} on {current_time.strftime('%B %d, %Y at %I:%M %p')}"
@staticmethod
def get_storage_stats() -> Dict[str, Any]:
def get_storage_stats() -> dict[str, Any]:
"""
Get media storage statistics.

View File

@@ -6,7 +6,8 @@ while maintaining compatibility with Cloudflare Images.
"""
import re
from typing import Optional, Dict, Any
from typing import Any
from django.utils.text import slugify
@@ -83,7 +84,7 @@ class MediaURLService:
return f"/parks/{park_slug}/rides/{ride_slug}/photos/{filename}"
@staticmethod
def parse_photo_filename(filename: str) -> Optional[Dict[str, Any]]:
def parse_photo_filename(filename: str) -> dict[str, Any] | None:
"""
Parse a friendly filename to extract photo ID and variant.
@@ -118,7 +119,7 @@ class MediaURLService:
return None
@staticmethod
def get_cloudflare_url_with_fallback(cloudflare_image, variant: str = "public") -> Optional[str]:
def get_cloudflare_url_with_fallback(cloudflare_image, variant: str = "public") -> str | None:
"""
Get Cloudflare URL with fallback handling.

View File

@@ -2,13 +2,14 @@
Performance monitoring utilities and context managers.
"""
import time
import logging
import time
from contextlib import contextmanager
from functools import wraps
from typing import Optional, Dict, Any, List
from django.db import connection
from typing import Any
from django.conf import settings
from django.db import connection
from django.utils import timezone
logger = logging.getLogger("performance")
@@ -271,7 +272,7 @@ class DatabaseQueryAnalyzer:
"""Analyze database query patterns and performance"""
@staticmethod
def analyze_queries(queries: List[Dict]) -> Dict[str, Any]:
def analyze_queries(queries: list[dict]) -> dict[str, Any]:
"""Analyze a list of queries for patterns and issues"""
if not queries:
return {}
@@ -332,7 +333,7 @@ class DatabaseQueryAnalyzer:
return analysis
@classmethod
def analyze_current_queries(cls) -> Dict[str, Any]:
def analyze_current_queries(cls) -> dict[str, Any]:
"""Analyze the current request's queries"""
if hasattr(connection, "queries"):
return cls.analyze_queries(connection.queries)
@@ -340,7 +341,7 @@ class DatabaseQueryAnalyzer:
# Performance monitoring decorators
def monitor_function_performance(operation_name: Optional[str] = None):
def monitor_function_performance(operation_name: str | None = None):
"""Decorator to monitor function performance"""
def decorator(func):
@@ -379,7 +380,7 @@ class PerformanceMetrics:
def __init__(self):
self.metrics = []
def record_metric(self, name: str, value: float, tags: Optional[Dict] = None):
def record_metric(self, name: str, value: float, tags: dict | None = None):
"""Record a performance metric"""
metric = {
"name": name,
@@ -392,7 +393,7 @@ class PerformanceMetrics:
# Log the metric
logger.info(f"Performance metric: {name} = {value}", extra=metric)
def get_metrics(self, name: Optional[str] = None) -> List[Dict]:
def get_metrics(self, name: str | None = None) -> list[dict]:
"""Get recorded metrics, optionally filtered by name"""
if name:
return [m for m in self.metrics if m["name"] == name]

View File

@@ -12,11 +12,12 @@ Results are cached in Redis for performance optimization.
import logging
from datetime import datetime, timedelta
from typing import Dict, List, Any
from django.utils import timezone
from typing import Any
from django.contrib.contenttypes.models import ContentType
from django.core.cache import cache
from django.db.models import Q
from django.utils import timezone
from apps.core.analytics import PageView
from apps.parks.models import Park
@@ -56,7 +57,7 @@ class TrendingService:
def get_trending_content(
self, content_type: str = "all", limit: int = 20, force_refresh: bool = False
) -> List[Dict[str, Any]]:
) -> list[dict[str, Any]]:
"""
Get trending content using direct calculation.
@@ -121,7 +122,7 @@ class TrendingService:
limit: int = 20,
days_back: int = 30,
force_refresh: bool = False,
) -> List[Dict[str, Any]]:
) -> list[dict[str, Any]]:
"""
Get recently added content using direct calculation.
@@ -182,7 +183,7 @@ class TrendingService:
self.logger.error(f"Error getting new content: {e}", exc_info=True)
return []
def _calculate_trending_parks(self, limit: int) -> List[Dict[str, Any]]:
def _calculate_trending_parks(self, limit: int) -> list[dict[str, Any]]:
"""Calculate trending scores for parks."""
parks = Park.objects.filter(status="OPERATING").select_related(
"location", "operator", "card_image"
@@ -253,7 +254,7 @@ class TrendingService:
return trending_parks
def _calculate_trending_rides(self, limit: int) -> List[Dict[str, Any]]:
def _calculate_trending_rides(self, limit: int) -> list[dict[str, Any]]:
"""Calculate trending scores for rides."""
rides = Ride.objects.filter(status="OPERATING").select_related(
"park", "park__location", "card_image"
@@ -456,7 +457,7 @@ class TrendingService:
self.logger.warning(f"Error calculating popularity score: {e}")
return 0.0
def _get_new_parks(self, cutoff_date: datetime, limit: int) -> List[Dict[str, Any]]:
def _get_new_parks(self, cutoff_date: datetime, limit: int) -> list[dict[str, Any]]:
"""Get recently added parks."""
new_parks = (
Park.objects.filter(
@@ -528,7 +529,7 @@ class TrendingService:
return results
def _get_new_rides(self, cutoff_date: datetime, limit: int) -> List[Dict[str, Any]]:
def _get_new_rides(self, cutoff_date: datetime, limit: int) -> list[dict[str, Any]]:
"""Get recently added rides."""
new_rides = (
Ride.objects.filter(
@@ -584,8 +585,8 @@ class TrendingService:
return results
def _format_trending_results(
self, trending_items: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
self, trending_items: list[dict[str, Any]]
) -> list[dict[str, Any]]:
"""Format trending results for frontend consumption."""
formatted_results = []
@@ -649,8 +650,8 @@ class TrendingService:
return formatted_results
def _format_new_content_results(
self, new_items: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
self, new_items: list[dict[str, Any]]
) -> list[dict[str, Any]]:
"""Format new content results for frontend consumption."""
formatted_results = []

View File

@@ -1,106 +1,106 @@
"""State machine utilities for core app."""
from .fields import RichFSMField
from .mixins import StateMachineMixin
from .builder import (
StateTransitionBuilder,
determine_method_name_for_transition,
)
from .decorators import (
generate_transition_decorator,
TransitionMethodFactory,
with_callbacks,
register_method_callbacks,
)
from .registry import (
TransitionRegistry,
TransitionInfo,
registry_instance,
register_callback,
register_notification_callback,
register_cache_invalidation,
register_related_update,
register_transition_callbacks,
discover_and_register_callbacks,
)
from .callback_base import (
BaseTransitionCallback,
PreTransitionCallback,
PostTransitionCallback,
ErrorTransitionCallback,
TransitionContext,
TransitionCallbackRegistry,
callback_registry,
CallbackStage,
)
from .signals import (
pre_state_transition,
post_state_transition,
state_transition_failed,
register_transition_handler,
on_transition,
on_pre_transition,
on_post_transition,
on_transition_error,
ErrorTransitionCallback,
PostTransitionCallback,
PreTransitionCallback,
TransitionCallbackRegistry,
TransitionContext,
callback_registry,
)
from .config import (
CallbackConfig,
callback_config,
get_callback_config,
)
from .monitoring import (
CallbackMonitor,
callback_monitor,
TimedCallbackExecution,
)
from .validators import MetadataValidator, ValidationResult
from .guards import (
# Role constants
VALID_ROLES,
MODERATOR_ROLES,
ADMIN_ROLES,
SUPERUSER_ROLES,
ESCALATION_LEVEL_ROLES,
# Guard classes
PermissionGuard,
OwnershipGuard,
AssignmentGuard,
StateGuard,
MetadataGuard,
CompositeGuard,
# Guard extraction and creation
extract_guards_from_metadata,
create_permission_guard,
create_ownership_guard,
create_assignment_guard,
create_composite_guard,
validate_guard_metadata,
# Registry
GuardRegistry,
guard_registry,
# Role checking functions
get_user_role,
has_role,
is_moderator_or_above,
is_admin_or_above,
is_superuser_role,
has_permission,
from .decorators import (
TransitionMethodFactory,
generate_transition_decorator,
register_method_callbacks,
with_callbacks,
)
from .exceptions import (
ERROR_MESSAGES,
TransitionNotAvailable,
TransitionPermissionDenied,
TransitionValidationError,
TransitionNotAvailable,
ERROR_MESSAGES,
format_transition_error,
get_permission_error_message,
get_state_error_message,
format_transition_error,
raise_permission_denied,
raise_validation_error,
)
from .fields import RichFSMField
from .guards import (
ADMIN_ROLES,
ESCALATION_LEVEL_ROLES,
MODERATOR_ROLES,
SUPERUSER_ROLES,
# Role constants
VALID_ROLES,
AssignmentGuard,
CompositeGuard,
# Registry
GuardRegistry,
MetadataGuard,
OwnershipGuard,
# Guard classes
PermissionGuard,
StateGuard,
create_assignment_guard,
create_composite_guard,
create_ownership_guard,
create_permission_guard,
# Guard extraction and creation
extract_guards_from_metadata,
# Role checking functions
get_user_role,
guard_registry,
has_permission,
has_role,
is_admin_or_above,
is_moderator_or_above,
is_superuser_role,
validate_guard_metadata,
)
from .integration import (
apply_state_machine,
StateMachineModelMixin,
apply_state_machine,
state_machine_model,
)
from .mixins import StateMachineMixin
from .monitoring import (
CallbackMonitor,
TimedCallbackExecution,
callback_monitor,
)
from .registry import (
TransitionInfo,
TransitionRegistry,
discover_and_register_callbacks,
register_cache_invalidation,
register_callback,
register_notification_callback,
register_related_update,
register_transition_callbacks,
registry_instance,
)
from .signals import (
on_post_transition,
on_pre_transition,
on_transition,
on_transition_error,
post_state_transition,
pre_state_transition,
register_transition_handler,
state_transition_failed,
)
from .validators import MetadataValidator, ValidationResult
__all__ = [
# Fields and mixins

View File

@@ -60,11 +60,12 @@ See Also:
- apps.core.choices.registry: Central choice registry
- apps.core.state_machine.guards: Guard extraction from metadata
"""
from typing import Dict, List, Optional, Any
from typing import Any
from django.core.exceptions import ImproperlyConfigured
from apps.core.choices.registry import registry
from apps.core.choices.base import RichChoice
from apps.core.choices.registry import registry
class StateTransitionBuilder:
@@ -123,7 +124,7 @@ class StateTransitionBuilder:
"""
self.choice_group = choice_group
self.domain = domain
self._cache: Dict[str, Any] = {}
self._cache: dict[str, Any] = {}
# Validate choice group exists
group = registry.get(choice_group, domain)
@@ -134,7 +135,7 @@ class StateTransitionBuilder:
self.choices = registry.get_choices(choice_group, domain)
def get_choice_metadata(self, state_value: str) -> Dict[str, Any]:
def get_choice_metadata(self, state_value: str) -> dict[str, Any]:
"""
Retrieve metadata for a specific state.
@@ -156,7 +157,7 @@ class StateTransitionBuilder:
self._cache[cache_key] = metadata
return metadata
def extract_valid_transitions(self, state_value: str) -> List[str]:
def extract_valid_transitions(self, state_value: str) -> list[str]:
"""
Get can_transition_to list from metadata.
@@ -184,7 +185,7 @@ class StateTransitionBuilder:
def extract_permission_requirements(
self, state_value: str
) -> Dict[str, bool]:
) -> dict[str, bool]:
"""
Extract permission requirements from metadata.
@@ -228,7 +229,7 @@ class StateTransitionBuilder:
metadata = self.get_choice_metadata(state_value)
return metadata.get("is_actionable", False)
def build_transition_graph(self) -> Dict[str, List[str]]:
def build_transition_graph(self) -> dict[str, list[str]]:
"""
Create a complete state transition graph.
@@ -247,7 +248,7 @@ class StateTransitionBuilder:
self._cache[cache_key] = graph
return graph
def get_all_states(self) -> List[str]:
def get_all_states(self) -> list[str]:
"""
Get all state values in the choice group.
@@ -256,7 +257,7 @@ class StateTransitionBuilder:
"""
return [choice.value for choice in self.choices]
def get_choice(self, state_value: str) -> Optional[RichChoice]:
def get_choice(self, state_value: str) -> RichChoice | None:
"""
Get the RichChoice object for a state.
@@ -276,7 +277,7 @@ class StateTransitionBuilder:
def determine_method_name_for_transition(source: str, target: str) -> str:
"""
Determine appropriate method name for a transition.
Always uses transition_to_<state> pattern to avoid conflicts with
business logic methods (approve, reject, escalate, etc.).

View File

@@ -66,16 +66,15 @@ See Also:
- apps.core.state_machine.callbacks.related_updates: Related model callbacks
"""
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
import logging
from typing import Any, Optional
from django.db import models
logger = logging.getLogger(__name__)
@@ -167,12 +166,12 @@ class TransitionContext:
field_name: str
source_state: str
target_state: str
user: Optional[Any] = None
user: Any | None = None
timestamp: datetime = field(default_factory=datetime.now)
extra_data: Dict[str, Any] = field(default_factory=dict)
extra_data: dict[str, Any] = field(default_factory=dict)
@property
def model_class(self) -> Type[models.Model]:
def model_class(self) -> type[models.Model]:
"""Get the model class of the instance."""
return type(self.instance)
@@ -206,9 +205,9 @@ class BaseTransitionCallback(ABC):
def __init__(
self,
priority: Optional[int] = None,
continue_on_error: Optional[bool] = None,
name: Optional[str] = None,
priority: int | None = None,
continue_on_error: bool | None = None,
name: str | None = None,
):
if priority is not None:
self.priority = priority
@@ -288,7 +287,7 @@ class ErrorTransitionCallback(BaseTransitionCallback):
# Error callbacks should always continue
continue_on_error: bool = True
def execute(self, context: TransitionContext, exception: Optional[Exception] = None) -> bool:
def execute(self, context: TransitionContext, exception: Exception | None = None) -> bool:
"""
Execute the error callback.
@@ -307,7 +306,7 @@ class CallbackRegistration:
"""Represents a registered callback with its configuration."""
callback: BaseTransitionCallback
model_class: Type[models.Model]
model_class: type[models.Model]
field_name: str
source: str # Can be '*' for wildcard
target: str # Can be '*' for wildcard
@@ -315,7 +314,7 @@ class CallbackRegistration:
def matches(
self,
model_class: Type[models.Model],
model_class: type[models.Model],
field_name: str,
source: str,
target: str,
@@ -327,9 +326,7 @@ class CallbackRegistration:
return False
if self.source != '*' and self.source != source:
return False
if self.target != '*' and self.target != target:
return False
return True
return not (self.target != '*' and self.target != target)
class TransitionCallbackRegistry:
@@ -351,7 +348,7 @@ class TransitionCallbackRegistry:
def __init__(self):
if self._initialized:
return
self._callbacks: Dict[CallbackStage, List[CallbackRegistration]] = {
self._callbacks: dict[CallbackStage, list[CallbackRegistration]] = {
CallbackStage.PRE: [],
CallbackStage.POST: [],
CallbackStage.ERROR: [],
@@ -360,12 +357,12 @@ class TransitionCallbackRegistry:
def register(
self,
model_class: Type[models.Model],
model_class: type[models.Model],
field_name: str,
source: str,
target: str,
callback: BaseTransitionCallback,
stage: Union[CallbackStage, str] = CallbackStage.POST,
stage: CallbackStage | str = CallbackStage.POST,
) -> None:
"""
Register a callback for a specific transition.
@@ -402,10 +399,10 @@ class TransitionCallbackRegistry:
def register_bulk(
self,
model_class: Type[models.Model],
model_class: type[models.Model],
field_name: str,
callbacks_config: Dict[Tuple[str, str], List[BaseTransitionCallback]],
stage: Union[CallbackStage, str] = CallbackStage.POST,
callbacks_config: dict[tuple[str, str], list[BaseTransitionCallback]],
stage: CallbackStage | str = CallbackStage.POST,
) -> None:
"""
Register multiple callbacks for multiple transitions.
@@ -422,12 +419,12 @@ class TransitionCallbackRegistry:
def get_callbacks(
self,
model_class: Type[models.Model],
model_class: type[models.Model],
field_name: str,
source: str,
target: str,
stage: Union[CallbackStage, str] = CallbackStage.POST,
) -> List[BaseTransitionCallback]:
stage: CallbackStage | str = CallbackStage.POST,
) -> list[BaseTransitionCallback]:
"""
Get all callbacks matching the given transition.
@@ -454,9 +451,9 @@ class TransitionCallbackRegistry:
def execute_callbacks(
self,
context: TransitionContext,
stage: Union[CallbackStage, str] = CallbackStage.POST,
exception: Optional[Exception] = None,
) -> Tuple[bool, List[Tuple[BaseTransitionCallback, Optional[Exception]]]]:
stage: CallbackStage | str = CallbackStage.POST,
exception: Exception | None = None,
) -> tuple[bool, list[tuple[BaseTransitionCallback, Exception | None]]]:
"""
Execute all callbacks for a transition.
@@ -479,7 +476,7 @@ class TransitionCallbackRegistry:
stage,
)
failures: List[Tuple[BaseTransitionCallback, Optional[Exception]]] = []
failures: list[tuple[BaseTransitionCallback, Exception | None]] = []
overall_success = True
for callback in callbacks:
@@ -530,7 +527,7 @@ class TransitionCallbackRegistry:
return overall_success, failures
def clear(self, model_class: Optional[Type[models.Model]] = None) -> None:
def clear(self, model_class: type[models.Model] | None = None) -> None:
"""
Clear registered callbacks.
@@ -550,8 +547,8 @@ class TransitionCallbackRegistry:
def get_all_registrations(
self,
model_class: Optional[Type[models.Model]] = None,
) -> Dict[CallbackStage, List[CallbackRegistration]]:
model_class: type[models.Model] | None = None,
) -> dict[CallbackStage, list[CallbackRegistration]]:
"""
Get all registered callbacks, optionally filtered by model class.
@@ -585,19 +582,19 @@ callback_registry = TransitionCallbackRegistry()
# Convenience functions for common operations
def register_callback(
model_class: Type[models.Model],
model_class: type[models.Model],
field_name: str,
source: str,
target: str,
callback: BaseTransitionCallback,
stage: Union[CallbackStage, str] = CallbackStage.POST,
stage: CallbackStage | str = CallbackStage.POST,
) -> None:
"""Convenience function to register a callback."""
callback_registry.register(model_class, field_name, source, target, callback, stage)
def register_pre_callback(
model_class: Type[models.Model],
model_class: type[models.Model],
field_name: str,
source: str,
target: str,
@@ -610,7 +607,7 @@ def register_pre_callback(
def register_post_callback(
model_class: Type[models.Model],
model_class: type[models.Model],
field_name: str,
source: str,
target: str,
@@ -623,7 +620,7 @@ def register_post_callback(
def register_error_callback(
model_class: Type[models.Model],
model_class: type[models.Model],
field_name: str,
source: str,
target: str,

View File

@@ -5,29 +5,28 @@ This package provides specialized callback implementations for
FSM state transitions.
"""
from .notifications import (
NotificationCallback,
SubmissionApprovedNotification,
SubmissionRejectedNotification,
SubmissionEscalatedNotification,
StatusChangeNotification,
ModerationNotificationCallback,
)
from .cache import (
APICacheInvalidation,
CacheInvalidationCallback,
ModelCacheInvalidation,
RelatedModelCacheInvalidation,
PatternCacheInvalidation,
APICacheInvalidation,
RelatedModelCacheInvalidation,
)
from .notifications import (
ModerationNotificationCallback,
NotificationCallback,
StatusChangeNotification,
SubmissionApprovedNotification,
SubmissionEscalatedNotification,
SubmissionRejectedNotification,
)
from .related_updates import (
RelatedModelUpdateCallback,
ParkCountUpdateCallback,
SearchTextUpdateCallback,
ComputedFieldUpdateCallback,
ParkCountUpdateCallback,
RelatedModelUpdateCallback,
SearchTextUpdateCallback,
)
__all__ = [
# Notification callbacks
"NotificationCallback",

View File

@@ -5,15 +5,12 @@ This module provides callback implementations that invalidate cache entries
when state transitions occur.
"""
from typing import Any, Dict, List, Optional, Set, Type
import logging
from django.conf import settings
from django.db import models
from ..callback_base import PostTransitionCallback, TransitionContext
logger = logging.getLogger(__name__)
@@ -29,7 +26,7 @@ class CacheInvalidationCallback(PostTransitionCallback):
def __init__(
self,
patterns: Optional[List[str]] = None,
patterns: list[str] | None = None,
include_instance_patterns: bool = True,
**kwargs,
):
@@ -62,7 +59,7 @@ class CacheInvalidationCallback(PostTransitionCallback):
logger.warning("EnhancedCacheService not available")
return None
def _get_instance_patterns(self, context: TransitionContext) -> List[str]:
def _get_instance_patterns(self, context: TransitionContext) -> list[str]:
"""Generate cache key patterns specific to the instance."""
patterns = []
model_name = context.model_name.lower()
@@ -75,7 +72,7 @@ class CacheInvalidationCallback(PostTransitionCallback):
return patterns
def _get_all_patterns(self, context: TransitionContext) -> Set[str]:
def _get_all_patterns(self, context: TransitionContext) -> set[str]:
"""Get all patterns to invalidate, including generated ones."""
all_patterns = set(self.patterns)
@@ -130,7 +127,6 @@ class CacheInvalidationCallback(PostTransitionCallback):
def _fallback_invalidation(self, context: TransitionContext) -> bool:
"""Fallback cache invalidation using Django's cache framework."""
try:
from django.core.cache import cache
patterns = self._get_all_patterns(context)
@@ -171,7 +167,7 @@ class ModelCacheInvalidation(CacheInvalidationCallback):
def __init__(self, **kwargs):
super().__init__(**kwargs)
def _get_instance_patterns(self, context: TransitionContext) -> List[str]:
def _get_instance_patterns(self, context: TransitionContext) -> list[str]:
"""Get model-specific patterns."""
base_patterns = super()._get_instance_patterns(context)
@@ -198,7 +194,7 @@ class RelatedModelCacheInvalidation(CacheInvalidationCallback):
def __init__(
self,
related_fields: Optional[List[str]] = None,
related_fields: list[str] | None = None,
**kwargs,
):
"""
@@ -211,7 +207,7 @@ class RelatedModelCacheInvalidation(CacheInvalidationCallback):
super().__init__(**kwargs)
self.related_fields = related_fields or []
def _get_related_patterns(self, context: TransitionContext) -> List[str]:
def _get_related_patterns(self, context: TransitionContext) -> list[str]:
"""Get cache patterns for related models."""
patterns = []
@@ -236,7 +232,7 @@ class RelatedModelCacheInvalidation(CacheInvalidationCallback):
return patterns
def _get_all_patterns(self, context: TransitionContext) -> Set[str]:
def _get_all_patterns(self, context: TransitionContext) -> set[str]:
"""Get all patterns including related model patterns."""
patterns = super()._get_all_patterns(context)
patterns.update(self._get_related_patterns(context))
@@ -254,7 +250,7 @@ class PatternCacheInvalidation(CacheInvalidationCallback):
def __init__(
self,
patterns: List[str],
patterns: list[str],
include_instance_patterns: bool = False,
**kwargs,
):
@@ -284,7 +280,7 @@ class APICacheInvalidation(CacheInvalidationCallback):
def __init__(
self,
api_prefixes: Optional[List[str]] = None,
api_prefixes: list[str] | None = None,
include_geo_cache: bool = False,
**kwargs,
):
@@ -300,7 +296,7 @@ class APICacheInvalidation(CacheInvalidationCallback):
self.api_prefixes = api_prefixes or ['api:*']
self.include_geo_cache = include_geo_cache
def _get_all_patterns(self, context: TransitionContext) -> Set[str]:
def _get_all_patterns(self, context: TransitionContext) -> set[str]:
"""Get API-specific cache patterns."""
patterns = set()
@@ -358,7 +354,7 @@ class RideCacheInvalidation(CacheInvalidationCallback):
**kwargs,
)
def _get_instance_patterns(self, context: TransitionContext) -> List[str]:
def _get_instance_patterns(self, context: TransitionContext) -> list[str]:
"""Include parent park cache patterns."""
patterns = super()._get_instance_patterns(context)

View File

@@ -5,15 +5,14 @@ This module provides callback implementations that send notifications
when state transitions occur.
"""
from typing import Any, Dict, List, Optional, Type
import logging
from typing import Any
from django.conf import settings
from django.db import models
from ..callback_base import PostTransitionCallback, TransitionContext
logger = logging.getLogger(__name__)
@@ -31,7 +30,7 @@ class NotificationCallback(PostTransitionCallback):
self,
notification_type: str,
recipient_field: str = "submitted_by",
template_name: Optional[str] = None,
template_name: str | None = None,
include_transition_data: bool = True,
**kwargs,
):
@@ -69,7 +68,7 @@ class NotificationCallback(PostTransitionCallback):
return True
def _get_recipient(self, instance: models.Model) -> Optional[Any]:
def _get_recipient(self, instance: models.Model) -> Any | None:
"""Get the notification recipient from the instance."""
return getattr(instance, self.recipient_field, None)
@@ -82,7 +81,7 @@ class NotificationCallback(PostTransitionCallback):
logger.warning("NotificationService not available")
return None
def _build_extra_data(self, context: TransitionContext) -> Dict[str, Any]:
def _build_extra_data(self, context: TransitionContext) -> dict[str, Any]:
"""Build extra data for the notification."""
extra_data = {}
@@ -401,7 +400,7 @@ class StatusChangeNotification(NotificationCallback):
def __init__(
self,
significant_states: Optional[List[str]] = None,
significant_states: list[str] | None = None,
notify_admins: bool = True,
**kwargs,
):
@@ -429,10 +428,7 @@ class StatusChangeNotification(NotificationCallback):
return False
# Only notify for significant status changes
if context.target_state not in self.significant_states:
return False
return True
return context.target_state in self.significant_states
def execute(self, context: TransitionContext) -> bool:
"""Execute the status change notification."""
@@ -518,12 +514,12 @@ class ModerationNotificationCallback(NotificationCallback):
**kwargs,
)
def _get_notification_type(self, context: TransitionContext) -> Optional[str]:
def _get_notification_type(self, context: TransitionContext) -> str | None:
"""Get the specific notification type based on model and state."""
key = (context.model_name, context.target_state)
return self.NOTIFICATION_MAPPING.get(key)
def _get_recipient(self, instance: models.Model) -> Optional[Any]:
def _get_recipient(self, instance: models.Model) -> Any | None:
"""Get the appropriate recipient based on model type."""
# Try common recipient fields
for field in ['reporter', 'assigned_to', 'created_by', 'submitted_by']:

View File

@@ -5,15 +5,14 @@ This module provides callback implementations that update related models
when state transitions occur.
"""
from typing import Any, Callable, Dict, List, Optional, Set, Type
import logging
from collections.abc import Callable
from django.conf import settings
from django.db import models, transaction
from ..callback_base import PostTransitionCallback, TransitionContext
logger = logging.getLogger(__name__)
@@ -28,7 +27,7 @@ class RelatedModelUpdateCallback(PostTransitionCallback):
def __init__(
self,
update_function: Optional[Callable[[TransitionContext], bool]] = None,
update_function: Callable[[TransitionContext], bool] | None = None,
use_transaction: bool = True,
**kwargs,
):
@@ -251,8 +250,8 @@ class ComputedFieldUpdateCallback(RelatedModelUpdateCallback):
def __init__(
self,
computed_fields: Optional[List[str]] = None,
update_method: Optional[str] = None,
computed_fields: list[str] | None = None,
update_method: str | None = None,
**kwargs,
):
"""
@@ -321,10 +320,7 @@ class RideStatusUpdateCallback(RelatedModelUpdateCallback):
return False
# Only execute for Ride model
if context.model_name != 'Ride':
return False
return True
return context.model_name == 'Ride'
def perform_update(self, context: TransitionContext) -> bool:
"""Perform ride-specific status updates."""
@@ -425,7 +421,7 @@ class ModerationQueueUpdateCallback(RelatedModelUpdateCallback):
except Exception as e:
logger.warning(f"Failed to update queue items: {e}")
def _get_content_type_id(self, instance: models.Model) -> Optional[int]:
def _get_content_type_id(self, instance: models.Model) -> int | None:
"""Get content type ID for the instance."""
try:
from django.contrib.contenttypes.models import ContentType

View File

@@ -5,14 +5,13 @@ This module provides centralized configuration for all FSM transition callbacks,
including enable/disable settings, priorities, and environment-specific overrides.
"""
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Type
import logging
from dataclasses import dataclass, field
from typing import Any
from django.conf import settings
from django.db import models
logger = logging.getLogger(__name__)
@@ -23,10 +22,10 @@ class TransitionCallbackConfig:
notifications_enabled: bool = True
cache_invalidation_enabled: bool = True
related_updates_enabled: bool = True
notification_template: Optional[str] = None
cache_patterns: List[str] = field(default_factory=list)
notification_template: str | None = None
cache_patterns: list[str] = field(default_factory=list)
priority: int = 100
extra_data: Dict[str, Any] = field(default_factory=dict)
extra_data: dict[str, Any] = field(default_factory=dict)
@dataclass
@@ -35,7 +34,7 @@ class ModelCallbackConfig:
model_name: str
field_name: str = 'status'
transitions: Dict[tuple, TransitionCallbackConfig] = field(default_factory=dict)
transitions: dict[tuple, TransitionCallbackConfig] = field(default_factory=dict)
default_config: TransitionCallbackConfig = field(default_factory=TransitionCallbackConfig)
@@ -63,20 +62,20 @@ class CallbackConfig:
}
# Model-specific configurations
MODEL_CONFIGS: Dict[str, ModelCallbackConfig] = {}
MODEL_CONFIGS: dict[str, ModelCallbackConfig] = {}
def __init__(self):
self._settings = self._load_settings()
self._model_configs = self._build_model_configs()
def _load_settings(self) -> Dict[str, Any]:
def _load_settings(self) -> dict[str, Any]:
"""Load settings from Django configuration."""
django_settings = getattr(settings, 'STATE_MACHINE_CALLBACKS', {})
merged = dict(self.DEFAULT_SETTINGS)
merged.update(django_settings)
return merged
def _build_model_configs(self) -> Dict[str, ModelCallbackConfig]:
def _build_model_configs(self) -> dict[str, ModelCallbackConfig]:
"""Build model-specific configurations."""
return {
'EditSubmission': ModelCallbackConfig(
@@ -315,7 +314,7 @@ class CallbackConfig:
model_name: str,
source: str,
target: str,
) -> List[str]:
) -> list[str]:
"""Get cache invalidation patterns for a transition."""
config = self.get_config(model_name, source, target)
return config.cache_patterns
@@ -325,14 +324,14 @@ class CallbackConfig:
model_name: str,
source: str,
target: str,
) -> Optional[str]:
) -> str | None:
"""Get notification template for a transition."""
config = self.get_config(model_name, source, target)
return config.notification_template
def register_model_config(
self,
model_class: Type[models.Model],
model_class: type[models.Model],
config: ModelCallbackConfig,
) -> None:
"""

View File

@@ -1,7 +1,8 @@
"""Transition decorator generation for django-fsm integration."""
from typing import Any, Callable, List, Optional, Type, Union
from functools import wraps
import logging
from collections.abc import Callable
from functools import wraps
from typing import Any
from django.db import models
from django_fsm import transition
@@ -14,12 +15,11 @@ from .callback_base import (
callback_registry,
)
from .signals import (
pre_state_transition,
post_state_transition,
pre_state_transition,
state_transition_failed,
)
logger = logging.getLogger(__name__)
@@ -193,10 +193,10 @@ def create_transition_method(
source: str,
target: str,
field_name: str,
permission_guard: Optional[Callable] = None,
on_success: Optional[Callable] = None,
on_error: Optional[Callable] = None,
callbacks: Optional[List[BaseTransitionCallback]] = None,
permission_guard: Callable | None = None,
on_success: Callable | None = None,
on_error: Callable | None = None,
callbacks: list[BaseTransitionCallback] | None = None,
enable_callbacks: bool = True,
emit_signals: bool = True,
) -> Callable:
@@ -259,7 +259,7 @@ def create_transition_method(
def register_method_callbacks(
model_class: Type[models.Model],
model_class: type[models.Model],
method: Callable,
) -> None:
"""
@@ -275,14 +275,11 @@ def register_method_callbacks(
if not metadata or not metadata.get('callbacks'):
return
from .callback_base import CallbackStage, PostTransitionCallback, PreTransitionCallback
from .callback_base import CallbackStage, PreTransitionCallback
for callback in metadata['callbacks']:
# Determine stage from callback type
if isinstance(callback, PreTransitionCallback):
stage = CallbackStage.PRE
else:
stage = CallbackStage.POST
stage = CallbackStage.PRE if isinstance(callback, PreTransitionCallback) else CallbackStage.POST
callback_registry.register(
model_class=model_class,
@@ -302,7 +299,7 @@ class TransitionMethodFactory:
source: str,
target: str,
field_name: str = "status",
permission_guard: Optional[Callable] = None,
permission_guard: Callable | None = None,
enable_callbacks: bool = True,
emit_signals: bool = True,
) -> Callable:
@@ -353,7 +350,7 @@ class TransitionMethodFactory:
source: str,
target: str,
field_name: str = "status",
permission_guard: Optional[Callable] = None,
permission_guard: Callable | None = None,
enable_callbacks: bool = True,
emit_signals: bool = True,
) -> Callable:
@@ -404,7 +401,7 @@ class TransitionMethodFactory:
source: str,
target: str,
field_name: str = "status",
permission_guard: Optional[Callable] = None,
permission_guard: Callable | None = None,
enable_callbacks: bool = True,
emit_signals: bool = True,
) -> Callable:
@@ -456,8 +453,8 @@ class TransitionMethodFactory:
source: str,
target: str,
field_name: str = "status",
permission_guard: Optional[Callable] = None,
docstring: Optional[str] = None,
permission_guard: Callable | None = None,
docstring: str | None = None,
enable_callbacks: bool = True,
emit_signals: bool = True,
) -> Callable:

View File

@@ -12,7 +12,8 @@ Example usage:
'code': e.error_code
}, status=403)
"""
from typing import Any, Optional, List, Dict
from typing import Any
from django_fsm import TransitionNotAllowed
@@ -42,10 +43,10 @@ class TransitionPermissionDenied(TransitionNotAllowed):
self,
message: str = "Permission denied for this transition",
error_code: str = "PERMISSION_DENIED",
user_message: Optional[str] = None,
required_roles: Optional[List[str]] = None,
user_role: Optional[str] = None,
guard: Optional[Any] = None,
user_message: str | None = None,
required_roles: list[str] | None = None,
user_role: str | None = None,
guard: Any | None = None,
):
"""
Initialize permission denied exception.
@@ -65,7 +66,7 @@ class TransitionPermissionDenied(TransitionNotAllowed):
self.user_role = user_role
self.guard = guard
def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> dict[str, Any]:
"""
Convert exception to dictionary for API responses.
@@ -106,10 +107,10 @@ class TransitionValidationError(TransitionNotAllowed):
self,
message: str = "Transition validation failed",
error_code: str = "VALIDATION_FAILED",
user_message: Optional[str] = None,
field_name: Optional[str] = None,
current_state: Optional[str] = None,
guard: Optional[Any] = None,
user_message: str | None = None,
field_name: str | None = None,
current_state: str | None = None,
guard: Any | None = None,
):
"""
Initialize validation error exception.
@@ -129,7 +130,7 @@ class TransitionValidationError(TransitionNotAllowed):
self.current_state = current_state
self.guard = guard
def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> dict[str, Any]:
"""
Convert exception to dictionary for API responses.
@@ -168,10 +169,10 @@ class TransitionNotAvailable(TransitionNotAllowed):
self,
message: str = "This transition is not available",
error_code: str = "TRANSITION_NOT_AVAILABLE",
user_message: Optional[str] = None,
current_state: Optional[str] = None,
requested_transition: Optional[str] = None,
available_transitions: Optional[List[str]] = None,
user_message: str | None = None,
current_state: str | None = None,
requested_transition: str | None = None,
available_transitions: list[str] | None = None,
):
"""
Initialize transition not available exception.
@@ -191,7 +192,7 @@ class TransitionNotAvailable(TransitionNotAllowed):
self.requested_transition = requested_transition
self.available_transitions = available_transitions or []
def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> dict[str, Any]:
"""
Convert exception to dictionary for API responses.
@@ -267,12 +268,12 @@ def get_permission_error_message(
# "You need moderator permissions to approve submissions..."
"""
from .guards import (
PermissionGuard,
OwnershipGuard,
AssignmentGuard,
MODERATOR_ROLES,
ADMIN_ROLES,
MODERATOR_ROLES,
SUPERUSER_ROLES,
AssignmentGuard,
OwnershipGuard,
PermissionGuard,
)
if hasattr(guard, "get_error_message"):
@@ -348,7 +349,7 @@ def get_state_error_message(
def format_transition_error(
exception: Exception,
include_details: bool = False,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""
Format a transition exception for API response.
@@ -412,7 +413,7 @@ def raise_permission_denied(
user_role = get_user_role(user) if user else None
error_code = TransitionPermissionDenied.ERROR_CODE_PERMISSION_DENIED_ROLE
required_roles: List[str] = []
required_roles: list[str] = []
if isinstance(guard, PermissionGuard):
required_roles = guard.get_required_roles()
@@ -431,8 +432,8 @@ def raise_permission_denied(
def raise_validation_error(
guard: Any,
current_state: Optional[str] = None,
field_name: Optional[str] = None,
current_state: str | None = None,
field_name: str | None = None,
) -> None:
"""
Raise a TransitionValidationError exception with proper context.
@@ -445,7 +446,7 @@ def raise_validation_error(
Raises:
TransitionValidationError: Always raised with proper context
"""
from .guards import StateGuard, MetadataGuard
from .guards import MetadataGuard, StateGuard
error_code = TransitionValidationError.ERROR_CODE_VALIDATION_FAILED
user_message = "Validation failed for this transition"

View File

@@ -47,7 +47,7 @@ See Also:
- apps.core.choices.registry: The central choice registry
- apps.core.state_machine.mixins.StateMachineMixin: Convenience helpers
"""
from typing import Any, Optional
from typing import Any
from django.core.exceptions import ValidationError
from django_fsm import FSMField as DjangoFSMField
@@ -147,7 +147,7 @@ class RichFSMField(DjangoFSMField):
f"'{value}' is deprecated and cannot be used for new entries"
)
def get_rich_choice(self, value: str) -> Optional[RichChoice]:
def get_rich_choice(self, value: str) -> RichChoice | None:
"""Return the RichChoice object for a given state value."""
return registry.get_choice(self.choice_group, value, self.domain)

View File

@@ -18,7 +18,8 @@ Example usage:
OwnershipGuard()
], operator='OR')
"""
from typing import Callable, Dict, List, Optional, Any, Tuple, Union
from collections.abc import Callable
from typing import Any, Optional
# Valid user roles in order of increasing privilege
VALID_ROLES = ["USER", "MODERATOR", "ADMIN", "SUPERUSER"]
@@ -62,9 +63,9 @@ class PermissionGuard:
requires_moderator: bool = False,
requires_admin: bool = False,
requires_superuser: bool = False,
required_roles: Optional[List[str]] = None,
custom_check: Optional[Callable] = None,
error_message: Optional[str] = None,
required_roles: list[str] | None = None,
custom_check: Callable | None = None,
error_message: str | None = None,
):
"""
Initialize permission guard.
@@ -83,10 +84,10 @@ class PermissionGuard:
self.required_roles = required_roles
self.custom_check = custom_check
self._custom_error_message = error_message
self._last_error_code: Optional[str] = None
self._last_error_code: str | None = None
@property
def error_code(self) -> Optional[str]:
def error_code(self) -> str | None:
"""Return the error code from the last failed check."""
return self._last_error_code
@@ -126,16 +127,14 @@ class PermissionGuard:
return False
# Check moderator (includes admin and superuser)
elif self.requires_moderator:
if not is_moderator_or_above(user):
self._last_error_code = self.ERROR_CODE_PERMISSION_DENIED_ROLE
return False
elif self.requires_moderator and not is_moderator_or_above(user):
self._last_error_code = self.ERROR_CODE_PERMISSION_DENIED_ROLE
return False
# Apply custom check if provided
if self.custom_check:
if not self.custom_check(instance, user):
self._last_error_code = self.ERROR_CODE_PERMISSION_DENIED_CUSTOM
return False
if self.custom_check and not self.custom_check(instance, user):
self._last_error_code = self.ERROR_CODE_PERMISSION_DENIED_CUSTOM
return False
return True
@@ -162,7 +161,7 @@ class PermissionGuard:
return "This transition requires special permissions"
return "This transition is not allowed"
def get_required_roles(self) -> List[str]:
def get_required_roles(self) -> list[str]:
"""
Return list of roles that would satisfy this guard.
@@ -207,10 +206,10 @@ class OwnershipGuard:
def __init__(
self,
owner_fields: Optional[List[str]] = None,
owner_fields: list[str] | None = None,
allow_moderator_override: bool = False,
allow_admin_override: bool = False,
error_message: Optional[str] = None,
error_message: str | None = None,
):
"""
Initialize ownership guard.
@@ -225,10 +224,10 @@ class OwnershipGuard:
self.allow_moderator_override = allow_moderator_override
self.allow_admin_override = allow_admin_override
self._custom_error_message = error_message
self._last_error_code: Optional[str] = None
self._last_error_code: str | None = None
@property
def error_code(self) -> Optional[str]:
def error_code(self) -> str | None:
"""Return the error code from the last failed check."""
return self._last_error_code
@@ -305,10 +304,10 @@ class AssignmentGuard:
def __init__(
self,
assignment_fields: Optional[List[str]] = None,
assignment_fields: list[str] | None = None,
require_assignment: bool = False,
allow_admin_override: bool = False,
error_message: Optional[str] = None,
error_message: str | None = None,
):
"""
Initialize assignment guard.
@@ -323,10 +322,10 @@ class AssignmentGuard:
self.require_assignment = require_assignment
self.allow_admin_override = allow_admin_override
self._custom_error_message = error_message
self._last_error_code: Optional[str] = None
self._last_error_code: str | None = None
@property
def error_code(self) -> Optional[str]:
def error_code(self) -> str | None:
"""Return the error code from the last failed check."""
return self._last_error_code
@@ -399,10 +398,10 @@ class StateGuard:
def __init__(
self,
allowed_states: Optional[List[str]] = None,
blocked_states: Optional[List[str]] = None,
allowed_states: list[str] | None = None,
blocked_states: list[str] | None = None,
state_field: str = "status",
error_message: Optional[str] = None,
error_message: str | None = None,
):
"""
Initialize state guard.
@@ -417,11 +416,11 @@ class StateGuard:
self.blocked_states = blocked_states or []
self.state_field = state_field
self._custom_error_message = error_message
self._last_error_code: Optional[str] = None
self._current_state: Optional[str] = None
self._last_error_code: str | None = None
self._current_state: str | None = None
@property
def error_code(self) -> Optional[str]:
def error_code(self) -> str | None:
"""Return the error code from the last failed check."""
return self._last_error_code
@@ -445,10 +444,9 @@ class StateGuard:
return False
# Check allowed states if specified
if self.allowed_states is not None:
if self._current_state not in self.allowed_states:
self._last_error_code = self.ERROR_CODE_INVALID_STATE
return False
if self.allowed_states is not None and self._current_state not in self.allowed_states:
self._last_error_code = self.ERROR_CODE_INVALID_STATE
return False
return True
@@ -485,9 +483,9 @@ class MetadataGuard:
def __init__(
self,
required_fields: Optional[List[str]] = None,
required_fields: list[str] | None = None,
check_not_empty: bool = True,
error_message: Optional[str] = None,
error_message: str | None = None,
):
"""
Initialize metadata guard.
@@ -500,11 +498,11 @@ class MetadataGuard:
self.required_fields = required_fields or []
self.check_not_empty = check_not_empty
self._custom_error_message = error_message
self._last_error_code: Optional[str] = None
self._failed_field: Optional[str] = None
self._last_error_code: str | None = None
self._failed_field: str | None = None
@property
def error_code(self) -> Optional[str]:
def error_code(self) -> str | None:
"""Return the error code from the last failed check."""
return self._last_error_code
@@ -582,9 +580,9 @@ class CompositeGuard:
def __init__(
self,
guards: List[Callable],
guards: list[Callable],
operator: str = "AND",
error_message: Optional[str] = None,
error_message: str | None = None,
):
"""
Initialize composite guard.
@@ -597,11 +595,11 @@ class CompositeGuard:
self.guards = guards
self.operator = operator.upper()
self._custom_error_message = error_message
self._last_error_code: Optional[str] = None
self._failed_guards: List[Callable] = []
self._last_error_code: str | None = None
self._failed_guards: list[Callable] = []
@property
def error_code(self) -> Optional[str]:
def error_code(self) -> str | None:
"""Return the error code from the last failed check."""
return self._last_error_code
@@ -658,7 +656,7 @@ class CompositeGuard:
def create_ownership_guard(
owner_fields: Optional[List[str]] = None,
owner_fields: list[str] | None = None,
allow_moderator_override: bool = False,
allow_admin_override: bool = False,
) -> OwnershipGuard:
@@ -681,7 +679,7 @@ def create_ownership_guard(
def create_assignment_guard(
assignment_fields: Optional[List[str]] = None,
assignment_fields: list[str] | None = None,
require_assignment: bool = False,
allow_admin_override: bool = False,
) -> AssignmentGuard:
@@ -704,7 +702,7 @@ def create_assignment_guard(
def create_composite_guard(
guards: List[Callable],
guards: list[Callable],
operator: str = "AND",
) -> CompositeGuard:
"""
@@ -728,7 +726,7 @@ ESCALATION_LEVEL_ROLES = {
}
def extract_guards_from_metadata(metadata: Dict[str, Any]) -> List[Callable]:
def extract_guards_from_metadata(metadata: dict[str, Any]) -> list[Callable]:
"""
Convert RichChoice metadata to guard functions.
@@ -823,7 +821,7 @@ def extract_guards_from_metadata(metadata: Dict[str, Any]) -> List[Callable]:
return guards
def create_permission_guard(metadata: Dict[str, Any]) -> PermissionGuard:
def create_permission_guard(metadata: dict[str, Any]) -> PermissionGuard:
"""
Create a permission guard from RichChoice metadata.
@@ -874,7 +872,7 @@ def create_permission_guard(metadata: Dict[str, Any]) -> PermissionGuard:
)
def validate_guard_metadata(metadata: Dict[str, Any]) -> Tuple[bool, List[str]]:
def validate_guard_metadata(metadata: dict[str, Any]) -> tuple[bool, list[str]]:
"""
Validate that metadata contains valid guard configuration.
@@ -892,12 +890,11 @@ def validate_guard_metadata(metadata: Dict[str, Any]) -> Tuple[bool, List[str]]:
# Validate escalation_level
escalation_level = metadata.get("escalation_level")
if escalation_level:
if escalation_level.lower() not in ESCALATION_LEVEL_ROLES:
errors.append(
f"Invalid escalation_level: {escalation_level}. "
f"Must be one of: {', '.join(ESCALATION_LEVEL_ROLES.keys())}"
)
if escalation_level and escalation_level.lower() not in ESCALATION_LEVEL_ROLES:
errors.append(
f"Invalid escalation_level: {escalation_level}. "
f"Must be one of: {', '.join(ESCALATION_LEVEL_ROLES.keys())}"
)
# Validate required_permissions is a list
required_permissions = metadata.get("required_permissions")
@@ -927,7 +924,7 @@ class GuardRegistry:
"""Registry for storing and retrieving guard functions."""
_instance: Optional["GuardRegistry"] = None
_guards: Dict[str, Callable]
_guards: dict[str, Callable]
def __new__(cls):
"""Implement singleton pattern."""
@@ -946,7 +943,7 @@ class GuardRegistry:
"""
self._guards[name] = guard
def get_guard(self, name: str) -> Optional[Callable]:
def get_guard(self, name: str) -> Callable | None:
"""
Retrieve a guard by name.
@@ -961,9 +958,9 @@ class GuardRegistry:
def apply_guards(
self,
instance: Any,
guards: List[Callable],
user: Optional[Any] = None,
) -> Tuple[bool, Optional[str]]:
guards: list[Callable],
user: Any | None = None,
) -> tuple[bool, str | None]:
"""
Apply multiple guards.
@@ -993,8 +990,8 @@ class GuardRegistry:
def create_condition_from_metadata(
metadata: Dict[str, Any],
) -> Optional[Callable]:
metadata: dict[str, Any],
) -> Callable | None:
"""
Create FSM condition from metadata.
@@ -1011,10 +1008,7 @@ def create_condition_from_metadata(
def combined_condition(instance, user=None):
"""Combined condition from all guards."""
for guard in guards:
if not guard(instance, user):
return False
return True
return all(guard(instance, user) for guard in guards)
return combined_condition
@@ -1022,7 +1016,7 @@ def create_condition_from_metadata(
# Helper functions for permission checks
def get_user_role(user: Any) -> Optional[str]:
def get_user_role(user: Any) -> str | None:
"""
Get the user's role from the role field.
@@ -1043,7 +1037,7 @@ def get_user_role(user: Any) -> Optional[str]:
return None
def has_role(user: Any, required_roles: List[str]) -> bool:
def has_role(user: Any, required_roles: list[str]) -> bool:
"""
Check if user has one of the required roles.
@@ -1083,9 +1077,8 @@ def has_role(user: Any, required_roles: List[str]) -> bool:
return True
# Check for staff status (treat as moderator)
if hasattr(user, "is_staff") and user.is_staff:
if "MODERATOR" in required_roles:
return True
if hasattr(user, "is_staff") and user.is_staff and "MODERATOR" in required_roles:
return True
return False
@@ -1191,7 +1184,7 @@ def has_permission(user: Any, permission: str) -> bool:
def create_guard_from_drf_permission(
permission_class: type,
error_message: Optional[str] = None,
error_message: str | None = None,
) -> Callable:
"""
Create an FSM guard from a DRF permission class.
@@ -1225,13 +1218,13 @@ def create_guard_from_drf_permission(
class DRFPermissionGuard:
"""Guard that wraps a DRF permission class."""
def __init__(self, perm_class: type, err_msg: Optional[str] = None):
def __init__(self, perm_class: type, err_msg: str | None = None):
self.permission_class = perm_class
self._custom_error_message = err_msg
self._last_error_code: Optional[str] = None
self._last_error_code: str | None = None
@property
def error_code(self) -> Optional[str]:
def error_code(self) -> str | None:
return self._last_error_code
def __call__(self, instance: Any, user: Any = None) -> bool:

View File

@@ -1,5 +1,6 @@
"""Model integration utilities for applying state machines to Django models."""
from typing import Type, Optional, Dict, Any, List, Callable
from collections.abc import Callable
from typing import Any
from django.db import models
from django_fsm import can_proceed
@@ -8,23 +9,21 @@ from apps.core.state_machine.builder import (
StateTransitionBuilder,
determine_method_name_for_transition,
)
from apps.core.state_machine.decorators import TransitionMethodFactory
from apps.core.state_machine.guards import (
CompositeGuard,
create_guard_from_drf_permission,
extract_guards_from_metadata,
)
from apps.core.state_machine.registry import (
TransitionInfo,
registry_instance,
)
from apps.core.state_machine.validators import MetadataValidator
from apps.core.state_machine.decorators import TransitionMethodFactory
from apps.core.state_machine.guards import (
create_permission_guard,
extract_guards_from_metadata,
create_condition_from_metadata,
create_guard_from_drf_permission,
CompositeGuard,
)
def apply_state_machine(
model_class: Type[models.Model],
model_class: type[models.Model],
field_name: str,
choice_group: str,
domain: str = "core",
@@ -48,7 +47,7 @@ def apply_state_machine(
if not result.is_valid:
error_messages = [str(e) for e in result.errors]
raise ValueError(
f"Cannot apply state machine - validation failed:\n"
"Cannot apply state machine - validation failed:\n"
+ "\n".join(error_messages)
)
@@ -62,7 +61,7 @@ def apply_state_machine(
def generate_transition_methods_for_model(
model_class: Type[models.Model],
model_class: type[models.Model],
field_name: str,
choice_group: str,
domain: str = "core",
@@ -100,7 +99,7 @@ def generate_transition_methods_for_model(
all_guards = guards + target_guards
# Create combined guard if we have multiple guards
combined_guard: Optional[Callable] = None
combined_guard: Callable | None = None
if len(all_guards) == 1:
combined_guard = all_guards[0]
elif len(all_guards) > 1:
@@ -149,7 +148,7 @@ class StateMachineModelMixin:
def get_available_state_transitions(
self, field_name: str = "status"
) -> List[TransitionInfo]:
) -> list[TransitionInfo]:
"""
Get available transitions from current state.
@@ -176,7 +175,7 @@ class StateMachineModelMixin:
self,
target_state: str,
field_name: str = "status",
user: Optional[Any] = None,
user: Any | None = None,
) -> bool:
"""
Check if transition to target state is allowed.
@@ -219,7 +218,7 @@ class StateMachineModelMixin:
def get_transition_method(
self, target_state: str, field_name: str = "status"
) -> Optional[Callable]:
) -> Callable | None:
"""
Get the transition method for moving to target state.
@@ -252,7 +251,7 @@ class StateMachineModelMixin:
self,
target_state: str,
field_name: str = "status",
user: Optional[Any] = None,
user: Any | None = None,
**kwargs: Any,
) -> bool:
"""
@@ -299,7 +298,7 @@ def state_machine_model(
Decorator function
"""
def decorator(model_class: Type[models.Model]) -> Type[models.Model]:
def decorator(model_class: type[models.Model]) -> type[models.Model]:
"""Apply state machine to model class."""
apply_state_machine(model_class, field_name, choice_group, domain)
return model_class
@@ -308,7 +307,7 @@ def state_machine_model(
def validate_model_state_machine(
model_class: Type[models.Model], field_name: str
model_class: type[models.Model], field_name: str
) -> bool:
"""
Ensure model is properly configured with state machine.
@@ -345,7 +344,7 @@ def validate_model_state_machine(
if not result.is_valid:
error_messages = [str(e) for e in result.errors]
raise ValueError(
f"State machine validation failed:\n" + "\n".join(error_messages)
"State machine validation failed:\n" + "\n".join(error_messages)
)
return True

View File

@@ -38,12 +38,12 @@ See Also:
- apps.core.state_machine.fields.RichFSMField: The FSM field implementation
- django_fsm.can_proceed: FSM transition checking utility
"""
from typing import Any, Dict, Iterable, List, Optional
from collections.abc import Iterable
from typing import Any
from django.db import models
from django_fsm import can_proceed
# Default transition metadata for styling
TRANSITION_METADATA = {
# Approval transitions
@@ -71,7 +71,7 @@ TRANSITION_METADATA = {
}
def _get_transition_metadata(transition_name: str) -> Dict[str, Any]:
def _get_transition_metadata(transition_name: str) -> dict[str, Any]:
"""Get metadata for a transition by name."""
if transition_name in TRANSITION_METADATA:
return TRANSITION_METADATA[transition_name].copy()
@@ -161,12 +161,12 @@ class StateMachineMixin(models.Model):
class Meta:
abstract = True
def get_state_value(self, field_name: Optional[str] = None) -> Any:
def get_state_value(self, field_name: str | None = None) -> Any:
"""Return the raw state value for the given field (default is `state`)."""
name = field_name or self.state_field_name
return getattr(self, name, None)
def get_state_display_value(self, field_name: Optional[str] = None) -> str:
def get_state_display_value(self, field_name: str | None = None) -> str:
"""Return the display label for the current state, if available."""
name = field_name or self.state_field_name
getter = getattr(self, f"get_{name}_display", None)
@@ -175,7 +175,7 @@ class StateMachineMixin(models.Model):
value = getattr(self, name, "")
return value if value is not None else ""
def get_state_choice(self, field_name: Optional[str] = None):
def get_state_choice(self, field_name: str | None = None):
"""Return the RichChoice object when the field provides one."""
name = field_name or self.state_field_name
getter = getattr(self, f"get_{name}_rich_choice", None)
@@ -193,7 +193,7 @@ class StateMachineMixin(models.Model):
return can_proceed(method)
def get_available_transitions(
self, field_name: Optional[str] = None
self, field_name: str | None = None
) -> Iterable[Any]:
"""Return available transitions when helpers are present."""
name = field_name or self.state_field_name
@@ -203,12 +203,12 @@ class StateMachineMixin(models.Model):
return helper() # type: ignore[misc]
return []
def is_in_state(self, state: str, field_name: Optional[str] = None) -> bool:
def is_in_state(self, state: str, field_name: str | None = None) -> bool:
"""Convenience check for comparing the current state."""
current_state = self.get_state_value(field_name)
return current_state == state
def get_available_user_transitions(self, user) -> List[Dict[str, Any]]:
def get_available_user_transitions(self, user) -> list[dict[str, Any]]:
"""
Get transitions available to the given user.

View File

@@ -5,20 +5,18 @@ This module provides tools for monitoring callback execution,
tracking performance, and debugging transition issues.
"""
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional, Tuple, Type
from collections import defaultdict
import logging
import time
import threading
import time
from collections import defaultdict
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Optional
from django.conf import settings
from django.db import models
from .callback_base import TransitionContext
logger = logging.getLogger(__name__)
@@ -35,9 +33,9 @@ class CallbackExecutionRecord:
timestamp: datetime
duration_ms: float
success: bool
error_message: Optional[str] = None
instance_id: Optional[int] = None
user_id: Optional[int] = None
error_message: str | None = None
instance_id: int | None = None
user_id: int | None = None
@dataclass
@@ -51,8 +49,8 @@ class CallbackStats:
total_duration_ms: float = 0.0
min_duration_ms: float = float('inf')
max_duration_ms: float = 0.0
last_execution: Optional[datetime] = None
last_error: Optional[str] = None
last_execution: datetime | None = None
last_error: str | None = None
@property
def avg_duration_ms(self) -> float:
@@ -72,7 +70,7 @@ class CallbackStats:
self,
duration_ms: float,
success: bool,
error_message: Optional[str] = None,
error_message: str | None = None,
) -> None:
"""Record a callback execution."""
self.total_executions += 1
@@ -114,10 +112,10 @@ class CallbackMonitor:
if self._initialized:
return
self._stats: Dict[str, CallbackStats] = defaultdict(
self._stats: dict[str, CallbackStats] = defaultdict(
lambda: CallbackStats(callback_name="")
)
self._recent_executions: List[CallbackExecutionRecord] = []
self._recent_executions: list[CallbackExecutionRecord] = []
self._max_recent_records = 1000
self._enabled = self._check_enabled()
self._debug_mode = self._check_debug_mode()
@@ -159,7 +157,7 @@ class CallbackMonitor:
stage: str,
duration_ms: float,
success: bool,
error_message: Optional[str] = None,
error_message: str | None = None,
) -> None:
"""
Record a callback execution.
@@ -220,7 +218,7 @@ class CallbackMonitor:
else:
logger.warning(f"{log_message} - Error: {record.error_message}")
def get_stats(self, callback_name: Optional[str] = None) -> Dict[str, CallbackStats]:
def get_stats(self, callback_name: str | None = None) -> dict[str, CallbackStats]:
"""
Get callback statistics.
@@ -239,10 +237,10 @@ class CallbackMonitor:
def get_recent_executions(
self,
limit: int = 100,
callback_name: Optional[str] = None,
model_name: Optional[str] = None,
success_only: Optional[bool] = None,
) -> List[CallbackExecutionRecord]:
callback_name: str | None = None,
model_name: str | None = None,
success_only: bool | None = None,
) -> list[CallbackExecutionRecord]:
"""
Get recent execution records.
@@ -268,12 +266,12 @@ class CallbackMonitor:
# Return most recent first
return list(reversed(records[-limit:]))
def get_failure_summary(self) -> Dict[str, Any]:
def get_failure_summary(self) -> dict[str, Any]:
"""Get a summary of callback failures."""
failures = [r for r in self._recent_executions if not r.success]
# Group by callback
by_callback: Dict[str, List[CallbackExecutionRecord]] = defaultdict(list)
by_callback: dict[str, list[CallbackExecutionRecord]] = defaultdict(list)
for record in failures:
by_callback[record.callback_name].append(record)
@@ -292,7 +290,7 @@ class CallbackMonitor:
return summary
def get_performance_report(self) -> Dict[str, Any]:
def get_performance_report(self) -> dict[str, Any]:
"""Get a performance report for all callbacks."""
report = {
'callbacks': {},
@@ -361,7 +359,7 @@ class TimedCallbackExecution:
self.stage = stage
self.start_time = 0.0
self.success = True
self.error_message: Optional[str] = None
self.error_message: str | None = None
def __enter__(self) -> 'TimedCallbackExecution':
self.start_time = time.perf_counter()
@@ -419,7 +417,7 @@ def get_callback_execution_order(
model_name: str,
source: str,
target: str,
) -> List[Tuple[str, str, int]]:
) -> list[tuple[str, str, int]]:
"""
Get the order of callback execution for a transition.
@@ -431,7 +429,7 @@ def get_callback_execution_order(
Returns:
List of (stage, callback_name, priority) tuples in execution order.
"""
from .callback_base import callback_registry, CallbackStage
from .callback_base import CallbackStage
order = []

View File

@@ -1,13 +1,13 @@
"""TransitionRegistry - Centralized registry for managing FSM transitions."""
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Optional, Any, Tuple, Type
import logging
from collections.abc import Callable
from dataclasses import dataclass, field
from typing import Any, Optional
from django.db import models
from apps.core.state_machine.builder import StateTransitionBuilder
logger = logging.getLogger(__name__)
@@ -20,7 +20,7 @@ class TransitionInfo:
method_name: str
requires_moderator: bool = False
requires_admin_approval: bool = False
metadata: Dict[str, Any] = field(default_factory=dict)
metadata: dict[str, Any] = field(default_factory=dict)
def __hash__(self):
"""Make TransitionInfo hashable."""
@@ -31,7 +31,7 @@ class TransitionRegistry:
"""Centralized registry for managing and looking up FSM transitions."""
_instance: Optional["TransitionRegistry"] = None
_transitions: Dict[Tuple[str, str], Dict[Tuple[str, str], TransitionInfo]]
_transitions: dict[tuple[str, str], dict[tuple[str, str], TransitionInfo]]
def __new__(cls):
"""Implement singleton pattern."""
@@ -40,7 +40,7 @@ class TransitionRegistry:
cls._instance._transitions = {}
return cls._instance
def _get_key(self, choice_group: str, domain: str) -> Tuple[str, str]:
def _get_key(self, choice_group: str, domain: str) -> tuple[str, str]:
"""Generate registry key from choice group and domain."""
return (domain, choice_group)
@@ -51,7 +51,7 @@ class TransitionRegistry:
source: str,
target: str,
method_name: str,
metadata: Optional[Dict[str, Any]] = None,
metadata: dict[str, Any] | None = None,
) -> TransitionInfo:
"""
Register a transition.
@@ -88,7 +88,7 @@ class TransitionRegistry:
def get_transition(
self, choice_group: str, domain: str, source: str, target: str
) -> Optional[TransitionInfo]:
) -> TransitionInfo | None:
"""
Retrieve transition info.
@@ -111,7 +111,7 @@ class TransitionRegistry:
def get_available_transitions(
self, choice_group: str, domain: str, current_state: str
) -> List[TransitionInfo]:
) -> list[TransitionInfo]:
"""
Get all valid transitions from a state.
@@ -129,7 +129,7 @@ class TransitionRegistry:
return []
available = []
for (source, target), info in self._transitions[key].items():
for (source, _target), info in self._transitions[key].items():
if source == current_state:
available.append(info)
@@ -137,7 +137,7 @@ class TransitionRegistry:
def get_transition_method_name(
self, choice_group: str, domain: str, source: str, target: str
) -> Optional[str]:
) -> str | None:
"""
Get the method name for a transition.
@@ -209,8 +209,8 @@ class TransitionRegistry:
def clear_registry(
self,
choice_group: Optional[str] = None,
domain: Optional[str] = None,
choice_group: str | None = None,
domain: str | None = None,
) -> None:
"""
Clear registry entries for testing.
@@ -246,7 +246,7 @@ class TransitionRegistry:
return {} if format == "dict" else ""
if format == "dict":
graph: Dict[str, List[str]] = {}
graph: dict[str, list[str]] = {}
for (source, target), info in self._transitions[key].items():
if source not in graph:
graph[source] = []
@@ -272,7 +272,7 @@ class TransitionRegistry:
else:
raise ValueError(f"Unsupported format: {format}")
def get_all_registered_groups(self) -> List[Tuple[str, str]]:
def get_all_registered_groups(self) -> list[tuple[str, str]]:
"""
Get all registered choice groups.
@@ -289,7 +289,7 @@ registry_instance = TransitionRegistry()
# Callback registration helpers
def register_callback(
model_class: Type[models.Model],
model_class: type[models.Model],
field_name: str,
source: str,
target: str,
@@ -307,7 +307,7 @@ def register_callback(
callback: The callback instance.
stage: When to execute ('pre', 'post', 'error').
"""
from .callback_base import callback_registry, CallbackStage
from .callback_base import CallbackStage, callback_registry
callback_registry.register(
model_class=model_class,
@@ -320,7 +320,7 @@ def register_callback(
def register_notification_callback(
model_class: Type[models.Model],
model_class: type[models.Model],
field_name: str,
source: str,
target: str,
@@ -348,9 +348,9 @@ def register_notification_callback(
def register_cache_invalidation(
model_class: Type[models.Model],
model_class: type[models.Model],
field_name: str,
cache_patterns: Optional[List[str]] = None,
cache_patterns: list[str] | None = None,
source: str = '*',
target: str = '*',
) -> None:
@@ -371,7 +371,7 @@ def register_cache_invalidation(
def register_related_update(
model_class: Type[models.Model],
model_class: type[models.Model],
field_name: str,
update_func: Callable,
source: str = '*',
@@ -393,7 +393,7 @@ def register_related_update(
register_callback(model_class, field_name, source, target, callback, 'post')
def register_transition_callbacks(cls: Type[models.Model]) -> Type[models.Model]:
def register_transition_callbacks(cls: type[models.Model]) -> type[models.Model]:
"""
Class decorator to auto-register callbacks from model's Meta.

View File

@@ -5,15 +5,12 @@ This module defines custom Django signals emitted during state machine
transitions and provides utilities for connecting signal handlers.
"""
from typing import Any, Callable, Dict, List, Optional, Type, Union
import logging
from collections.abc import Callable
from django.db import models
from django.dispatch import Signal, receiver
from .callback_base import TransitionContext
logger = logging.getLogger(__name__)
@@ -69,11 +66,11 @@ class TransitionSignalHandler:
"""
def __init__(self):
self._handlers: Dict[str, List[Callable]] = {}
self._handlers: dict[str, list[Callable]] = {}
def register(
self,
model_class: Type[models.Model],
model_class: type[models.Model],
source: str,
target: str,
handler: Callable,
@@ -105,7 +102,7 @@ class TransitionSignalHandler:
def unregister(
self,
model_class: Type[models.Model],
model_class: type[models.Model],
source: str,
target: str,
handler: Callable,
@@ -121,7 +118,7 @@ class TransitionSignalHandler:
def _make_key(
self,
model_class: Type[models.Model],
model_class: type[models.Model],
source: str,
target: str,
stage: str,
@@ -140,7 +137,7 @@ class TransitionSignalHandler:
def _connect_signal(
self,
signal: Signal,
model_class: Type[models.Model],
model_class: type[models.Model],
source: str,
target: str,
handler: Callable,
@@ -173,7 +170,7 @@ transition_signal_handler = TransitionSignalHandler()
def register_transition_handler(
model_class: Type[models.Model],
model_class: type[models.Model],
source: str,
target: str,
handler: Callable,
@@ -233,7 +230,7 @@ class TransitionHandlerDecorator:
def __init__(
self,
model_class: Type[models.Model],
model_class: type[models.Model],
source: str = '*',
target: str = '*',
stage: str = 'post',
@@ -265,7 +262,7 @@ class TransitionHandlerDecorator:
def on_transition(
model_class: Type[models.Model],
model_class: type[models.Model],
source: str = '*',
target: str = '*',
stage: str = 'post',
@@ -291,7 +288,7 @@ def on_transition(
def on_pre_transition(
model_class: Type[models.Model],
model_class: type[models.Model],
source: str = '*',
target: str = '*',
) -> TransitionHandlerDecorator:
@@ -300,7 +297,7 @@ def on_pre_transition(
def on_post_transition(
model_class: Type[models.Model],
model_class: type[models.Model],
source: str = '*',
target: str = '*',
) -> TransitionHandlerDecorator:
@@ -309,7 +306,7 @@ def on_post_transition(
def on_transition_error(
model_class: Type[models.Model],
model_class: type[models.Model],
source: str = '*',
target: str = '*',
) -> TransitionHandlerDecorator:

View File

@@ -7,43 +7,44 @@ This module provides reusable fixtures for creating test data:
- Mock objects for testing guards and callbacks
"""
from typing import Any
from django.contrib.auth import get_user_model
from django.contrib.contenttypes.models import ContentType
from typing import Optional, Any, Dict
User = get_user_model()
class UserFactory:
"""Factory for creating users with different roles."""
_counter = 0
@classmethod
def _get_unique_id(cls) -> int:
"""Get a unique counter for creating unique usernames."""
cls._counter += 1
return cls._counter
@classmethod
def create_user(
cls,
role: str = 'USER',
username: Optional[str] = None,
email: Optional[str] = None,
username: str | None = None,
email: str | None = None,
password: str = 'testpass123',
**kwargs
) -> User:
"""
Create a user with specified role.
Args:
role: User role (USER, MODERATOR, ADMIN, SUPERUSER)
username: Username (auto-generated if not provided)
email: Email (auto-generated if not provided)
password: Password for the user
**kwargs: Additional user fields
Returns:
Created User instance
"""
@@ -52,7 +53,7 @@ class UserFactory:
username = f"user_{role.lower()}_{uid}"
if email is None:
email = f"{role.lower()}_{uid}@example.com"
return User.objects.create_user(
username=username,
email=email,
@@ -60,22 +61,22 @@ class UserFactory:
role=role,
**kwargs
)
@classmethod
def create_regular_user(cls, **kwargs) -> User:
"""Create a regular user."""
return cls.create_user(role='USER', **kwargs)
@classmethod
def create_moderator(cls, **kwargs) -> User:
"""Create a moderator user."""
return cls.create_user(role='MODERATOR', **kwargs)
@classmethod
def create_admin(cls, **kwargs) -> User:
"""Create an admin user."""
return cls.create_user(role='ADMIN', **kwargs)
@classmethod
def create_superuser(cls, **kwargs) -> User:
"""Create a superuser."""
@@ -84,23 +85,23 @@ class UserFactory:
class CompanyFactory:
"""Factory for creating company instances."""
_counter = 0
@classmethod
def _get_unique_id(cls) -> int:
cls._counter += 1
return cls._counter
@classmethod
def create_operator(cls, name: Optional[str] = None, **kwargs) -> Any:
def create_operator(cls, name: str | None = None, **kwargs) -> Any:
"""Create an operator company."""
from apps.parks.models import Company
uid = cls._get_unique_id()
if name is None:
name = f"Test Operator {uid}"
defaults = {
'name': name,
'description': f'Test operator company {uid}',
@@ -108,16 +109,16 @@ class CompanyFactory:
}
defaults.update(kwargs)
return Company.objects.create(**defaults)
@classmethod
def create_manufacturer(cls, name: Optional[str] = None, **kwargs) -> Any:
def create_manufacturer(cls, name: str | None = None, **kwargs) -> Any:
"""Create a manufacturer company."""
from apps.rides.models import Company
uid = cls._get_unique_id()
if name is None:
name = f"Test Manufacturer {uid}"
defaults = {
'name': name,
'description': f'Test manufacturer company {uid}',
@@ -129,42 +130,42 @@ class CompanyFactory:
class ParkFactory:
"""Factory for creating park instances."""
_counter = 0
@classmethod
def _get_unique_id(cls) -> int:
cls._counter += 1
return cls._counter
@classmethod
def create_park(
cls,
name: Optional[str] = None,
operator: Optional[Any] = None,
name: str | None = None,
operator: Any | None = None,
status: str = 'OPERATING',
**kwargs
) -> Any:
"""
Create a park with specified status.
Args:
name: Park name (auto-generated if not provided)
operator: Operator company (auto-created if not provided)
status: Park status
**kwargs: Additional park fields
Returns:
Created Park instance
"""
from apps.parks.models import Park
uid = cls._get_unique_id()
if name is None:
name = f"Test Park {uid}"
if operator is None:
operator = CompanyFactory.create_operator()
defaults = {
'name': name,
'slug': f'test-park-{uid}',
@@ -179,38 +180,38 @@ class ParkFactory:
class RideFactory:
"""Factory for creating ride instances."""
_counter = 0
@classmethod
def _get_unique_id(cls) -> int:
cls._counter += 1
return cls._counter
@classmethod
def create_ride(
cls,
name: Optional[str] = None,
park: Optional[Any] = None,
manufacturer: Optional[Any] = None,
name: str | None = None,
park: Any | None = None,
manufacturer: Any | None = None,
status: str = 'OPERATING',
**kwargs
) -> Any:
"""
Create a ride with specified status.
Args:
name: Ride name (auto-generated if not provided)
park: Park for the ride (auto-created if not provided)
manufacturer: Manufacturer company (auto-created if not provided)
status: Ride status
**kwargs: Additional ride fields
Returns:
Created Ride instance
"""
from apps.rides.models import Ride
uid = cls._get_unique_id()
if name is None:
name = f"Test Ride {uid}"
@@ -218,7 +219,7 @@ class RideFactory:
park = ParkFactory.create_park()
if manufacturer is None:
manufacturer = CompanyFactory.create_manufacturer()
defaults = {
'name': name,
'slug': f'test-ride-{uid}',
@@ -233,39 +234,39 @@ class RideFactory:
class EditSubmissionFactory:
"""Factory for creating edit submission instances."""
_counter = 0
@classmethod
def _get_unique_id(cls) -> int:
cls._counter += 1
return cls._counter
@classmethod
def create_submission(
cls,
user: Optional[Any] = None,
target_object: Optional[Any] = None,
user: Any | None = None,
target_object: Any | None = None,
status: str = 'PENDING',
changes: Optional[Dict[str, Any]] = None,
changes: dict[str, Any] | None = None,
**kwargs
) -> Any:
"""
Create an edit submission.
Args:
user: User who submitted (auto-created if not provided)
target_object: Object being edited (auto-created if not provided)
status: Submission status
changes: Changes dictionary
**kwargs: Additional fields
Returns:
Created EditSubmission instance
"""
from apps.moderation.models import EditSubmission
from apps.parks.models import Company
uid = cls._get_unique_id()
if user is None:
user = UserFactory.create_regular_user()
@@ -276,9 +277,9 @@ class EditSubmissionFactory:
)
if changes is None:
changes = {'name': f'Updated Name {uid}'}
content_type = ContentType.objects.get_for_model(target_object)
defaults = {
'user': user,
'content_type': content_type,
@@ -294,37 +295,37 @@ class EditSubmissionFactory:
class ModerationReportFactory:
"""Factory for creating moderation report instances."""
_counter = 0
@classmethod
def _get_unique_id(cls) -> int:
cls._counter += 1
return cls._counter
@classmethod
def create_report(
cls,
reporter: Optional[Any] = None,
target_object: Optional[Any] = None,
reporter: Any | None = None,
target_object: Any | None = None,
status: str = 'PENDING',
**kwargs
) -> Any:
"""
Create a moderation report.
Args:
reporter: User who reported (auto-created if not provided)
target_object: Object being reported (auto-created if not provided)
status: Report status
**kwargs: Additional fields
Returns:
Created ModerationReport instance
"""
from apps.moderation.models import ModerationReport
from apps.parks.models import Company
uid = cls._get_unique_id()
if reporter is None:
reporter = UserFactory.create_regular_user()
@@ -333,9 +334,9 @@ class ModerationReportFactory:
name=f'Reported Company {uid}',
description='Test company'
)
content_type = ContentType.objects.get_for_model(target_object)
defaults = {
'report_type': 'CONTENT',
'status': status,
@@ -354,7 +355,7 @@ class ModerationReportFactory:
class MockInstance:
"""
Mock instance for testing guards without database.
Example:
instance = MockInstance(
status='PENDING',
@@ -362,11 +363,11 @@ class MockInstance:
assigned_to=moderator
)
"""
def __init__(self, **kwargs):
for key, value in kwargs.items():
setattr(self, key, value)
def __repr__(self):
attrs = ', '.join(f'{k}={v!r}' for k, v in self.__dict__.items())
return f'MockInstance({attrs})'

View File

@@ -7,34 +7,36 @@ This module provides utility functions for testing state machine functionality:
- Guard testing utilities
"""
from typing import Any, Optional, List, Callable
from collections.abc import Callable
from typing import Any
from django.contrib.contenttypes.models import ContentType
def assert_transition_allowed(
instance: Any,
method_name: str,
user: Optional[Any] = None
user: Any | None = None
) -> bool:
"""
Assert that a transition is allowed.
Args:
instance: Model instance with FSM field
method_name: Name of the transition method
user: User attempting the transition
Returns:
True if transition is allowed
Raises:
AssertionError: If transition is not allowed
Example:
assert_transition_allowed(submission, 'transition_to_approved', moderator)
"""
from django_fsm import can_proceed
method = getattr(instance, method_name)
result = can_proceed(method)
assert result, f"Transition {method_name} should be allowed but was denied"
@@ -44,27 +46,27 @@ def assert_transition_allowed(
def assert_transition_denied(
instance: Any,
method_name: str,
user: Optional[Any] = None
user: Any | None = None
) -> bool:
"""
Assert that a transition is denied.
Args:
instance: Model instance with FSM field
method_name: Name of the transition method
user: User attempting the transition
Returns:
True if transition is denied
Raises:
AssertionError: If transition is allowed
Example:
assert_transition_denied(submission, 'transition_to_approved', regular_user)
"""
from django_fsm import can_proceed
method = getattr(instance, method_name)
result = can_proceed(method)
assert not result, f"Transition {method_name} should be denied but was allowed"
@@ -74,130 +76,130 @@ def assert_transition_denied(
def assert_state_log_created(
instance: Any,
expected_state: str,
user: Optional[Any] = None
user: Any | None = None
) -> Any:
"""
Assert that a StateLog entry was created for a transition.
Args:
instance: Model instance that was transitioned
expected_state: The expected final state in the log
user: Expected user who made the transition (optional)
Returns:
The StateLog entry
Raises:
AssertionError: If StateLog entry not found or doesn't match
Example:
log = assert_state_log_created(submission, 'APPROVED', moderator)
"""
from django_fsm_log.models import StateLog
ct = ContentType.objects.get_for_model(instance)
log = StateLog.objects.filter(
content_type=ct,
object_id=instance.id,
state=expected_state
).first()
assert log is not None, f"StateLog for state '{expected_state}' not found"
if user is not None:
assert log.by == user, f"Expected log.by={user}, got {log.by}"
return log
def assert_state_log_count(instance: Any, expected_count: int) -> List[Any]:
def assert_state_log_count(instance: Any, expected_count: int) -> list[Any]:
"""
Assert the number of StateLog entries for an instance.
Args:
instance: Model instance to check logs for
expected_count: Expected number of log entries
Returns:
List of StateLog entries
Raises:
AssertionError: If count doesn't match
Example:
logs = assert_state_log_count(submission, 2)
"""
from django_fsm_log.models import StateLog
ct = ContentType.objects.get_for_model(instance)
logs = list(StateLog.objects.filter(
content_type=ct,
object_id=instance.id
).order_by('timestamp'))
actual_count = len(logs)
assert actual_count == expected_count, \
f"Expected {expected_count} StateLog entries, got {actual_count}"
return logs
def assert_state_transition_sequence(
instance: Any,
expected_states: List[str]
) -> List[Any]:
expected_states: list[str]
) -> list[Any]:
"""
Assert that state transitions occurred in a specific sequence.
Args:
instance: Model instance to check
expected_states: List of expected states in order
Returns:
List of StateLog entries
Raises:
AssertionError: If sequence doesn't match
Example:
assert_state_transition_sequence(submission, ['ESCALATED', 'APPROVED'])
"""
from django_fsm_log.models import StateLog
ct = ContentType.objects.get_for_model(instance)
logs = list(StateLog.objects.filter(
content_type=ct,
object_id=instance.id
).order_by('timestamp'))
actual_states = [log.state for log in logs]
assert actual_states == expected_states, \
f"Expected state sequence {expected_states}, got {actual_states}"
return logs
def assert_guard_passes(
guard: Callable,
instance: Any,
user: Optional[Any] = None,
user: Any | None = None,
message: str = ""
) -> bool:
"""
Assert that a guard function passes.
Args:
guard: Guard function or callable
instance: Model instance to check
user: User attempting the action
message: Optional message on failure
Returns:
True if guard passes
Raises:
AssertionError: If guard fails
Example:
assert_guard_passes(permission_guard, instance, moderator)
"""
@@ -210,58 +212,58 @@ def assert_guard_passes(
def assert_guard_fails(
guard: Callable,
instance: Any,
user: Optional[Any] = None,
expected_error_code: Optional[str] = None,
user: Any | None = None,
expected_error_code: str | None = None,
message: str = ""
) -> bool:
"""
Assert that a guard function fails.
Args:
guard: Guard function or callable
instance: Model instance to check
user: User attempting the action
expected_error_code: Expected error code from guard
message: Optional message on failure
Returns:
True if guard fails as expected
Raises:
AssertionError: If guard passes or wrong error code
Example:
assert_guard_fails(permission_guard, instance, regular_user, 'PERMISSION_DENIED')
"""
result = guard(instance, user)
fail_message = message or f"Guard should fail but returned {result}"
assert result is False, fail_message
if expected_error_code and hasattr(guard, 'error_code'):
assert guard.error_code == expected_error_code, \
f"Expected error code {expected_error_code}, got {guard.error_code}"
return True
def transition_and_save(
instance: Any,
transition_method: str,
user: Optional[Any] = None,
user: Any | None = None,
**kwargs
) -> Any:
"""
Execute a transition and save the instance.
Args:
instance: Model instance with FSM field
transition_method: Name of the transition method
user: User performing the transition
**kwargs: Additional arguments for the transition
Returns:
The saved instance
Example:
submission = transition_and_save(submission, 'transition_to_approved', moderator)
"""
@@ -272,37 +274,36 @@ def transition_and_save(
return instance
def get_available_transitions(instance: Any) -> List[str]:
def get_available_transitions(instance: Any) -> list[str]:
"""
Get list of available transitions for an instance.
Args:
instance: Model instance with FSM field
Returns:
List of available transition method names
Example:
transitions = get_available_transitions(submission)
# ['transition_to_approved', 'transition_to_rejected', 'transition_to_escalated']
"""
from django_fsm import get_available_FIELD_transitions
# Get the state field name from the instance
state_field = getattr(instance, 'state_field_name', 'status')
# Build the function name dynamically
func_name = f'get_available_{state_field}_transitions'
if hasattr(instance, func_name):
get_transitions = getattr(instance, func_name)
return [t.name for t in get_transitions()]
# Fallback: look for transition methods
transitions = []
for attr_name in dir(instance):
if attr_name.startswith('transition_to_'):
transitions.append(attr_name)
return transitions
@@ -310,22 +311,22 @@ def create_transition_context(
instance: Any,
from_state: str,
to_state: str,
user: Optional[Any] = None,
user: Any | None = None,
**extra
) -> dict:
"""
Create a mock transition context dictionary.
Args:
instance: Model instance being transitioned
from_state: Source state
to_state: Target state
user: User performing the transition
**extra: Additional context data
Returns:
Dictionary matching TransitionContext structure
Example:
context = create_transition_context(submission, 'PENDING', 'APPROVED', moderator)
"""

View File

@@ -2,7 +2,7 @@
import pytest
from django.core.exceptions import ImproperlyConfigured
from apps.core.choices.base import RichChoice, ChoiceCategory
from apps.core.choices.base import ChoiceCategory, RichChoice
from apps.core.choices.registry import registry
from apps.core.state_machine.builder import StateTransitionBuilder

View File

@@ -9,14 +9,15 @@ This module tests:
- Callback context handling
"""
from typing import Any
from unittest.mock import Mock, patch
from django.test import TestCase
from unittest.mock import Mock, patch, call
from typing import Any, Dict, List
class CallbackContext:
"""Mock context for testing callbacks."""
def __init__(
self,
instance: Any = None,
@@ -30,8 +31,8 @@ class CallbackContext:
self.to_state = to_state
self.user = user
self.extra = extra
def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> dict[str, Any]:
return {
'instance': self.instance,
'from_state': self.from_state,
@@ -43,179 +44,179 @@ class CallbackContext:
class MockCallback:
"""Mock callback for testing."""
def __init__(self, name: str = 'callback', should_raise: bool = False):
self.name = name
self.calls: List[Dict] = []
self.calls: list[dict] = []
self.should_raise = should_raise
def __call__(self, context: Dict[str, Any]) -> None:
def __call__(self, context: dict[str, Any]) -> None:
self.calls.append(context)
if self.should_raise:
raise ValueError(f"Callback {self.name} failed")
@property
def call_count(self) -> int:
return len(self.calls)
def was_called(self) -> bool:
return len(self.calls) > 0
def reset(self):
self.calls = []
class PreTransitionCallbackTests(TestCase):
"""Tests for pre-transition callbacks."""
def test_pre_callback_executes_before_state_change(self):
"""Test that pre-transition callback executes before state changes."""
callback = MockCallback('pre_callback')
context = CallbackContext(from_state='PENDING', to_state='APPROVED')
# Simulate pre-transition execution
callback(context.to_dict())
self.assertTrue(callback.was_called())
self.assertEqual(callback.calls[0]['from_state'], 'PENDING')
self.assertEqual(callback.calls[0]['to_state'], 'APPROVED')
def test_pre_callback_receives_instance(self):
"""Test that pre-callback receives the model instance."""
mock_instance = Mock()
mock_instance.id = 123
mock_instance.status = 'PENDING'
callback = MockCallback()
context = CallbackContext(instance=mock_instance)
callback(context.to_dict())
self.assertEqual(callback.calls[0]['instance'], mock_instance)
def test_pre_callback_receives_user(self):
"""Test that pre-callback receives the user performing transition."""
mock_user = Mock()
mock_user.username = 'moderator'
callback = MockCallback()
context = CallbackContext(user=mock_user)
callback(context.to_dict())
self.assertEqual(callback.calls[0]['user'], mock_user)
def test_pre_callback_can_prevent_transition(self):
"""Test that pre-callback can prevent transition by raising exception."""
callback = MockCallback(should_raise=True)
context = CallbackContext()
with self.assertRaises(ValueError):
callback(context.to_dict())
def test_multiple_pre_callbacks_execute_in_order(self):
"""Test that multiple pre-callbacks execute in registration order."""
execution_order = []
def callback_1(ctx):
execution_order.append('first')
def callback_2(ctx):
execution_order.append('second')
def callback_3(ctx):
execution_order.append('third')
context = CallbackContext().to_dict()
# Execute in order
callback_1(context)
callback_2(context)
callback_3(context)
self.assertEqual(execution_order, ['first', 'second', 'third'])
class PostTransitionCallbackTests(TestCase):
"""Tests for post-transition callbacks."""
def test_post_callback_executes_after_state_change(self):
"""Test that post-transition callback executes after state changes."""
callback = MockCallback('post_callback')
# Simulate instance after transition
mock_instance = Mock()
mock_instance.status = 'APPROVED' # Already changed
context = CallbackContext(
instance=mock_instance,
from_state='PENDING',
to_state='APPROVED'
)
callback(context.to_dict())
self.assertTrue(callback.was_called())
self.assertEqual(callback.calls[0]['instance'].status, 'APPROVED')
def test_post_callback_receives_updated_instance(self):
"""Test that post-callback receives instance with new state."""
mock_instance = Mock()
mock_instance.status = 'APPROVED'
mock_instance.approved_at = '2025-01-15'
mock_instance.handled_by_id = 456
callback = MockCallback()
context = CallbackContext(instance=mock_instance)
callback(context.to_dict())
instance = callback.calls[0]['instance']
self.assertEqual(instance.status, 'APPROVED')
self.assertEqual(instance.approved_at, '2025-01-15')
def test_post_callback_failure_does_not_rollback(self):
"""Test that post-callback failures don't rollback the transition."""
# In a real scenario, the transition would already be committed
callback = MockCallback(should_raise=True)
context = CallbackContext()
# Post-callback failure should not affect already-committed transition
with self.assertRaises(ValueError):
callback(context.to_dict())
# The transition would still be committed in real usage
self.assertTrue(callback.was_called())
def test_multiple_post_callbacks_execute_in_order(self):
"""Test that multiple post-callbacks execute in order."""
execution_order = []
def notification_callback(ctx):
execution_order.append('notification')
def cache_callback(ctx):
execution_order.append('cache')
def analytics_callback(ctx):
execution_order.append('analytics')
context = CallbackContext().to_dict()
notification_callback(context)
cache_callback(context)
analytics_callback(context)
self.assertEqual(execution_order, ['notification', 'cache', 'analytics'])
class ErrorCallbackTests(TestCase):
"""Tests for error callbacks."""
def test_error_callback_receives_exception(self):
"""Test that error callback receives exception information."""
error_callback = MockCallback()
try:
raise ValueError("Transition failed")
except ValueError as e:
@@ -227,33 +228,33 @@ class ErrorCallbackTests(TestCase):
'exception_type': type(e).__name__
}
error_callback(error_context)
self.assertTrue(error_callback.was_called())
self.assertIn('exception', error_callback.calls[0])
self.assertEqual(error_callback.calls[0]['exception_type'], 'ValueError')
def test_error_callback_for_cleanup(self):
"""Test that error callbacks can perform cleanup."""
cleanup_performed = []
def cleanup_callback(ctx):
cleanup_performed.append(True)
# In real usage, might release locks, revert partial changes, etc.
try:
raise ValueError("Transition failed")
except ValueError:
cleanup_callback({'exception': 'test'})
self.assertTrue(cleanup_performed)
def test_error_callback_receives_context(self):
"""Test that error callback receives full transition context."""
mock_instance = Mock()
mock_user = Mock()
error_callback = MockCallback()
error_context = {
'instance': mock_instance,
'from_state': 'PENDING',
@@ -261,152 +262,152 @@ class ErrorCallbackTests(TestCase):
'user': mock_user,
'exception': ValueError("Test error")
}
error_callback(error_context)
self.assertEqual(error_callback.calls[0]['instance'], mock_instance)
self.assertEqual(error_callback.calls[0]['user'], mock_user)
class ConditionalCallbackTests(TestCase):
"""Tests for conditional callback execution."""
def test_callback_with_state_filter(self):
"""Test callback that only executes for specific states."""
execution_log = []
def approval_only_callback(ctx):
if ctx.get('to_state') == 'APPROVED':
execution_log.append('approved')
# Transition to APPROVED - should execute
approval_only_callback({'to_state': 'APPROVED'})
self.assertEqual(len(execution_log), 1)
# Transition to REJECTED - should not execute
approval_only_callback({'to_state': 'REJECTED'})
self.assertEqual(len(execution_log), 1) # Still 1
def test_callback_with_transition_filter(self):
"""Test callback that only executes for specific transitions."""
execution_log = []
def escalation_callback(ctx):
if ctx.get('to_state') == 'ESCALATED':
execution_log.append('escalated')
# Escalation - should execute
escalation_callback({'to_state': 'ESCALATED'})
self.assertEqual(len(execution_log), 1)
# Other transitions - should not execute
escalation_callback({'to_state': 'APPROVED'})
self.assertEqual(len(execution_log), 1)
def test_callback_with_user_role_filter(self):
"""Test callback that checks user role."""
admin_notifications = []
def admin_only_notification(ctx):
user = ctx.get('user')
if user and getattr(user, 'role', None) == 'ADMIN':
admin_notifications.append(ctx)
admin_user = Mock(role='ADMIN')
moderator_user = Mock(role='MODERATOR')
admin_only_notification({'user': admin_user})
self.assertEqual(len(admin_notifications), 1)
admin_only_notification({'user': moderator_user})
self.assertEqual(len(admin_notifications), 1) # Still 1
class CallbackChainTests(TestCase):
"""Tests for callback chains and pipelines."""
def test_callback_chain_continues_on_success(self):
"""Test that callback chain continues when callbacks succeed."""
results = []
callbacks = [
lambda ctx: results.append('a'),
lambda ctx: results.append('b'),
lambda ctx: results.append('c'),
]
context = {}
for cb in callbacks:
cb(context)
self.assertEqual(results, ['a', 'b', 'c'])
def test_callback_chain_stops_on_failure(self):
"""Test that callback chain stops when a callback fails."""
results = []
def callback_a(ctx):
results.append('a')
def callback_b(ctx):
raise ValueError("B failed")
def callback_c(ctx):
results.append('c')
callbacks = [callback_a, callback_b, callback_c]
context = {}
for cb in callbacks:
try:
cb(context)
except ValueError:
break
self.assertEqual(results, ['a']) # c never executed
def test_callback_chain_with_continue_on_error(self):
"""Test callback chain that continues despite errors."""
results = []
errors = []
def callback_a(ctx):
results.append('a')
def callback_b(ctx):
raise ValueError("B failed")
def callback_c(ctx):
results.append('c')
callbacks = [callback_a, callback_b, callback_c]
context = {}
for cb in callbacks:
try:
cb(context)
except Exception as e:
errors.append(str(e))
self.assertEqual(results, ['a', 'c'])
self.assertEqual(len(errors), 1)
class CallbackContextEnrichmentTests(TestCase):
"""Tests for callback context enrichment."""
def test_context_includes_model_class(self):
"""Test that context includes the model class."""
mock_instance = Mock()
mock_instance.__class__.__name__ = 'EditSubmission'
context = {
'instance': mock_instance,
'model_class': type(mock_instance)
}
self.assertIn('model_class', context)
def test_context_includes_transition_name(self):
"""Test that context includes the transition method name."""
context = {
@@ -415,9 +416,9 @@ class CallbackContextEnrichmentTests(TestCase):
'to_state': 'APPROVED',
'transition_name': 'transition_to_approved'
}
self.assertEqual(context['transition_name'], 'transition_to_approved')
def test_context_includes_timestamp(self):
"""Test that context includes transition timestamp."""
from django.utils import timezone
@@ -452,9 +453,10 @@ class NotificationCallbackTests(TestCase):
instance=None,
):
"""Helper to create a TransitionContext."""
from ..callback_base import TransitionContext
from django.utils import timezone
from ..callback_base import TransitionContext
if instance is None:
instance = Mock()
instance.pk = 123
@@ -603,9 +605,10 @@ class CacheCallbackTests(TestCase):
target_state: str = 'CLOSED_TEMP',
):
"""Helper to create a TransitionContext."""
from ..callback_base import TransitionContext
from django.utils import timezone
from ..callback_base import TransitionContext
instance = Mock()
instance.pk = instance_id
instance.__class__.__name__ = model_name
@@ -703,7 +706,7 @@ class CacheCallbackTests(TestCase):
context = self._create_transition_context()
# Should not raise overall
result = callback.execute(context)
callback.execute(context)
# All patterns should have been attempted
self.assertGreater(call_count, 1)
@@ -716,9 +719,10 @@ class ModelCacheInvalidationTests(TestCase):
model_name: str = 'Ride',
instance_id: int = 789,
):
from ..callback_base import TransitionContext
from django.utils import timezone
from ..callback_base import TransitionContext
instance = Mock()
instance.pk = instance_id
instance.__class__.__name__ = model_name
@@ -771,7 +775,7 @@ class RelatedUpdateCallbackTests(TestCase):
def setUp(self):
"""Set up test fixtures."""
from django.contrib.auth import get_user_model
User = get_user_model()
get_user_model()
self.user = Mock()
self.user.pk = 1
@@ -783,9 +787,10 @@ class RelatedUpdateCallbackTests(TestCase):
instance=None,
target_state: str = 'OPERATING',
):
from ..callback_base import TransitionContext
from django.utils import timezone
from ..callback_base import TransitionContext
if instance is None:
instance = Mock()
instance.pk = 123
@@ -920,9 +925,10 @@ class CallbackErrorHandlingTests(TestCase):
"""Tests for callback error handling paths."""
def _create_transition_context(self):
from ..callback_base import TransitionContext
from django.utils import timezone
from ..callback_base import TransitionContext
instance = Mock()
instance.pk = 1
instance.__class__.__name__ = 'EditSubmission'
@@ -939,9 +945,10 @@ class CallbackErrorHandlingTests(TestCase):
@patch('apps.core.state_machine.callbacks.notifications.NotificationService')
def test_notification_callback_logs_error_on_failure(self, mock_service_class):
"""Test NotificationCallback logs errors when service fails."""
from ..callbacks.notifications import NotificationCallback
import logging
from ..callbacks.notifications import NotificationCallback
mock_service = Mock()
mock_service.send_notification = Mock(side_effect=Exception("Network error"))
mock_service_class.return_value = mock_service
@@ -949,7 +956,7 @@ class CallbackErrorHandlingTests(TestCase):
callback = NotificationCallback()
context = self._create_transition_context()
with self.assertLogs(level=logging.WARNING) as log_output:
with self.assertLogs(level=logging.WARNING):
try:
callback.execute(context)
except Exception:
@@ -979,9 +986,10 @@ class CallbackErrorHandlingTests(TestCase):
def test_callback_with_none_user(self):
"""Test callbacks handle None user gracefully."""
from django.utils import timezone
from ..callback_base import TransitionContext
from ..callbacks.notifications import NotificationCallback
from django.utils import timezone
instance = Mock()
instance.pk = 1

View File

@@ -1,11 +1,10 @@
"""Tests for transition decorator generation."""
import pytest
from unittest.mock import Mock
from apps.core.state_machine.decorators import (
generate_transition_decorator,
create_transition_method,
TransitionMethodFactory,
create_transition_method,
generate_transition_decorator,
with_transition_logging,
)

View File

@@ -10,33 +10,30 @@ This module contains tests for:
- CompositeGuard (combining guards with AND/OR logic)
"""
from django.test import TestCase
from django.contrib.auth import get_user_model
from django.contrib.auth.models import AnonymousUser
from django.test import TestCase
from apps.core.state_machine.guards import (
PermissionGuard,
OwnershipGuard,
ADMIN_ROLES,
MODERATOR_ROLES,
AssignmentGuard,
StateGuard,
MetadataGuard,
CompositeGuard,
extract_guards_from_metadata,
create_permission_guard,
create_ownership_guard,
MetadataGuard,
OwnershipGuard,
PermissionGuard,
StateGuard,
create_assignment_guard,
create_composite_guard,
validate_guard_metadata,
create_ownership_guard,
create_permission_guard,
extract_guards_from_metadata,
get_user_role,
has_role,
is_moderator_or_above,
is_admin_or_above,
is_moderator_or_above,
is_superuser_role,
has_permission,
VALID_ROLES,
MODERATOR_ROLES,
ADMIN_ROLES,
SUPERUSER_ROLES,
validate_guard_metadata,
)
User = get_user_model()
@@ -44,7 +41,7 @@ User = get_user_model()
class MockInstance:
"""Mock instance for testing guards."""
def __init__(self, **kwargs):
for key, value in kwargs.items():
setattr(self, key, value)
@@ -89,90 +86,90 @@ class PermissionGuardTests(TestCase):
def test_no_user_fails(self):
"""Test that guard fails when no user is provided."""
guard = PermissionGuard(requires_moderator=True)
result = guard(self.instance, user=None)
self.assertFalse(result)
self.assertEqual(guard.error_code, PermissionGuard.ERROR_CODE_NO_USER)
def test_requires_moderator_allows_moderator(self):
"""Test that requires_moderator allows moderator role."""
guard = PermissionGuard(requires_moderator=True)
result = guard(self.instance, user=self.moderator)
self.assertTrue(result)
def test_requires_moderator_allows_admin(self):
"""Test that requires_moderator allows admin role."""
guard = PermissionGuard(requires_moderator=True)
result = guard(self.instance, user=self.admin)
self.assertTrue(result)
def test_requires_moderator_allows_superuser(self):
"""Test that requires_moderator allows superuser role."""
guard = PermissionGuard(requires_moderator=True)
result = guard(self.instance, user=self.superuser)
self.assertTrue(result)
def test_requires_moderator_denies_regular_user(self):
"""Test that requires_moderator denies regular user."""
guard = PermissionGuard(requires_moderator=True)
result = guard(self.instance, user=self.regular_user)
self.assertFalse(result)
self.assertEqual(guard.error_code, PermissionGuard.ERROR_CODE_PERMISSION_DENIED_ROLE)
def test_requires_admin_allows_admin(self):
"""Test that requires_admin allows admin role."""
guard = PermissionGuard(requires_admin=True)
result = guard(self.instance, user=self.admin)
self.assertTrue(result)
def test_requires_admin_allows_superuser(self):
"""Test that requires_admin allows superuser role."""
guard = PermissionGuard(requires_admin=True)
result = guard(self.instance, user=self.superuser)
self.assertTrue(result)
def test_requires_admin_denies_moderator(self):
"""Test that requires_admin denies moderator role."""
guard = PermissionGuard(requires_admin=True)
result = guard(self.instance, user=self.moderator)
self.assertFalse(result)
self.assertEqual(guard.error_code, PermissionGuard.ERROR_CODE_PERMISSION_DENIED_ROLE)
def test_requires_superuser_allows_superuser(self):
"""Test that requires_superuser allows superuser role."""
guard = PermissionGuard(requires_superuser=True)
result = guard(self.instance, user=self.superuser)
self.assertTrue(result)
def test_requires_superuser_denies_admin(self):
"""Test that requires_superuser denies admin role."""
guard = PermissionGuard(requires_superuser=True)
result = guard(self.instance, user=self.admin)
self.assertFalse(result)
def test_required_roles_explicit_list(self):
"""Test using explicit required_roles list."""
guard = PermissionGuard(required_roles=['ADMIN', 'SUPERUSER'])
self.assertTrue(guard(self.instance, user=self.admin))
self.assertTrue(guard(self.instance, user=self.superuser))
self.assertFalse(guard(self.instance, user=self.moderator))
@@ -182,24 +179,24 @@ class PermissionGuardTests(TestCase):
"""Test custom check function that passes."""
def custom_check(instance, user):
return hasattr(instance, 'allow_access') and instance.allow_access
guard = PermissionGuard(custom_check=custom_check)
instance = MockInstance(allow_access=True)
result = guard(instance, user=self.regular_user)
self.assertTrue(result)
def test_custom_check_fails(self):
"""Test custom check function that fails."""
def custom_check(instance, user):
return hasattr(instance, 'allow_access') and instance.allow_access
guard = PermissionGuard(custom_check=custom_check)
instance = MockInstance(allow_access=False)
result = guard(instance, user=self.regular_user)
self.assertFalse(result)
self.assertEqual(guard.error_code, PermissionGuard.ERROR_CODE_PERMISSION_DENIED_CUSTOM)
@@ -207,25 +204,25 @@ class PermissionGuardTests(TestCase):
"""Test custom error message."""
custom_message = "You need special access for this"
guard = PermissionGuard(requires_moderator=True, error_message=custom_message)
guard(self.instance, user=self.regular_user)
self.assertEqual(guard.get_error_message(), custom_message)
def test_get_required_roles_moderator(self):
"""Test get_required_roles for moderator requirement."""
guard = PermissionGuard(requires_moderator=True)
roles = guard.get_required_roles()
self.assertEqual(set(roles), set(MODERATOR_ROLES))
def test_get_required_roles_admin(self):
"""Test get_required_roles for admin requirement."""
guard = PermissionGuard(requires_admin=True)
roles = guard.get_required_roles()
self.assertEqual(set(roles), set(ADMIN_ROLES))
@@ -268,9 +265,9 @@ class OwnershipGuardTests(TestCase):
"""Test that guard fails when no user is provided."""
instance = MockInstance(created_by=self.owner)
guard = OwnershipGuard()
result = guard(instance, user=None)
self.assertFalse(result)
self.assertEqual(guard.error_code, OwnershipGuard.ERROR_CODE_NO_USER)
@@ -278,36 +275,36 @@ class OwnershipGuardTests(TestCase):
"""Test that owner passes via created_by field."""
instance = MockInstance(created_by=self.owner)
guard = OwnershipGuard()
result = guard(instance, user=self.owner)
self.assertTrue(result)
def test_owner_passes_user_field(self):
"""Test that owner passes via user field."""
instance = MockInstance(user=self.owner)
guard = OwnershipGuard()
result = guard(instance, user=self.owner)
self.assertTrue(result)
def test_owner_passes_submitted_by(self):
"""Test that owner passes via submitted_by field."""
instance = MockInstance(submitted_by=self.owner)
guard = OwnershipGuard()
result = guard(instance, user=self.owner)
self.assertTrue(result)
def test_non_owner_fails(self):
"""Test that non-owner fails."""
instance = MockInstance(created_by=self.owner)
guard = OwnershipGuard()
result = guard(instance, user=self.other_user)
self.assertFalse(result)
self.assertEqual(guard.error_code, OwnershipGuard.ERROR_CODE_NOT_OWNER)
@@ -315,27 +312,27 @@ class OwnershipGuardTests(TestCase):
"""Test that moderator can bypass ownership check."""
instance = MockInstance(created_by=self.owner)
guard = OwnershipGuard(allow_moderator_override=True)
result = guard(instance, user=self.moderator)
self.assertTrue(result)
def test_admin_override(self):
"""Test that admin can bypass ownership check."""
instance = MockInstance(created_by=self.owner)
guard = OwnershipGuard(allow_admin_override=True)
result = guard(instance, user=self.admin)
self.assertTrue(result)
def test_custom_owner_fields(self):
"""Test custom owner field names."""
instance = MockInstance(author=self.owner)
guard = OwnershipGuard(owner_fields=['author'])
result = guard(instance, user=self.owner)
self.assertTrue(result)
def test_anonymous_user_fails(self):
@@ -343,9 +340,9 @@ class OwnershipGuardTests(TestCase):
instance = MockInstance(created_by=self.owner)
guard = OwnershipGuard()
anonymous = AnonymousUser()
result = guard(instance, user=anonymous)
self.assertFalse(result)
@@ -382,9 +379,9 @@ class AssignmentGuardTests(TestCase):
"""Test that guard fails when no user is provided."""
instance = MockInstance(assigned_to=self.assigned_user)
guard = AssignmentGuard()
result = guard(instance, user=None)
self.assertFalse(result)
self.assertEqual(guard.error_code, AssignmentGuard.ERROR_CODE_NO_USER)
@@ -392,18 +389,18 @@ class AssignmentGuardTests(TestCase):
"""Test that assigned user passes."""
instance = MockInstance(assigned_to=self.assigned_user)
guard = AssignmentGuard()
result = guard(instance, user=self.assigned_user)
self.assertTrue(result)
def test_unassigned_user_fails(self):
"""Test that unassigned user fails."""
instance = MockInstance(assigned_to=self.assigned_user)
guard = AssignmentGuard()
result = guard(instance, user=self.other_user)
self.assertFalse(result)
self.assertEqual(guard.error_code, AssignmentGuard.ERROR_CODE_NOT_ASSIGNED)
@@ -411,18 +408,18 @@ class AssignmentGuardTests(TestCase):
"""Test that admin can bypass assignment check."""
instance = MockInstance(assigned_to=self.assigned_user)
guard = AssignmentGuard(allow_admin_override=True)
result = guard(instance, user=self.admin)
self.assertTrue(result)
def test_require_assignment_with_no_assignment(self):
"""Test require_assignment fails when no one is assigned."""
instance = MockInstance(assigned_to=None)
guard = AssignmentGuard(require_assignment=True)
result = guard(instance, user=self.assigned_user)
self.assertFalse(result)
self.assertEqual(guard.error_code, AssignmentGuard.ERROR_CODE_NO_ASSIGNMENT)
@@ -430,18 +427,18 @@ class AssignmentGuardTests(TestCase):
"""Test custom assignment field names."""
instance = MockInstance(reviewer=self.assigned_user)
guard = AssignmentGuard(assignment_fields=['reviewer'])
result = guard(instance, user=self.assigned_user)
self.assertTrue(result)
def test_error_message_for_no_assignment(self):
"""Test error message when no assignment exists."""
instance = MockInstance(assigned_to=None)
guard = AssignmentGuard(require_assignment=True)
guard(instance, user=self.assigned_user)
self.assertIn('assigned', guard.get_error_message().lower())
@@ -466,18 +463,18 @@ class StateGuardTests(TestCase):
"""Test that guard passes when in allowed state."""
instance = MockInstance(status='PENDING')
guard = StateGuard(allowed_states=['PENDING', 'UNDER_REVIEW'])
result = guard(instance, user=self.user)
self.assertTrue(result)
def test_allowed_states_fails(self):
"""Test that guard fails when not in allowed state."""
instance = MockInstance(status='COMPLETED')
guard = StateGuard(allowed_states=['PENDING', 'UNDER_REVIEW'])
result = guard(instance, user=self.user)
self.assertFalse(result)
self.assertEqual(guard.error_code, StateGuard.ERROR_CODE_INVALID_STATE)
@@ -485,18 +482,18 @@ class StateGuardTests(TestCase):
"""Test that guard passes when not in blocked state."""
instance = MockInstance(status='PENDING')
guard = StateGuard(blocked_states=['COMPLETED', 'CANCELLED'])
result = guard(instance, user=self.user)
self.assertTrue(result)
def test_blocked_states_fails(self):
"""Test that guard fails when in blocked state."""
instance = MockInstance(status='COMPLETED')
guard = StateGuard(blocked_states=['COMPLETED', 'CANCELLED'])
result = guard(instance, user=self.user)
self.assertFalse(result)
self.assertEqual(guard.error_code, StateGuard.ERROR_CODE_BLOCKED_STATE)
@@ -504,18 +501,18 @@ class StateGuardTests(TestCase):
"""Test using custom state field name."""
instance = MockInstance(workflow_status='ACTIVE')
guard = StateGuard(allowed_states=['ACTIVE'], state_field='workflow_status')
result = guard(instance, user=self.user)
self.assertTrue(result)
def test_error_message_includes_states(self):
"""Test that error message includes allowed states."""
instance = MockInstance(status='COMPLETED')
guard = StateGuard(allowed_states=['PENDING', 'UNDER_REVIEW'])
guard(instance, user=self.user)
message = guard.get_error_message()
self.assertIn('PENDING', message)
self.assertIn('UNDER_REVIEW', message)
@@ -542,18 +539,18 @@ class MetadataGuardTests(TestCase):
"""Test that guard passes when required fields are present."""
instance = MockInstance(resolution_notes='Fixed', assigned_to='user')
guard = MetadataGuard(required_fields=['resolution_notes', 'assigned_to'])
result = guard(instance, user=self.user)
self.assertTrue(result)
def test_required_field_missing(self):
"""Test that guard fails when required field is missing."""
instance = MockInstance(resolution_notes='Fixed')
guard = MetadataGuard(required_fields=['resolution_notes', 'assigned_to'])
result = guard(instance, user=self.user)
self.assertFalse(result)
self.assertEqual(guard.error_code, MetadataGuard.ERROR_CODE_MISSING_FIELD)
@@ -561,9 +558,9 @@ class MetadataGuardTests(TestCase):
"""Test that guard fails when required field is None."""
instance = MockInstance(resolution_notes=None)
guard = MetadataGuard(required_fields=['resolution_notes'])
result = guard(instance, user=self.user)
self.assertFalse(result)
self.assertEqual(guard.error_code, MetadataGuard.ERROR_CODE_MISSING_FIELD)
@@ -571,9 +568,9 @@ class MetadataGuardTests(TestCase):
"""Test that empty string fails when check_not_empty is True."""
instance = MockInstance(resolution_notes=' ')
guard = MetadataGuard(required_fields=['resolution_notes'], check_not_empty=True)
result = guard(instance, user=self.user)
self.assertFalse(result)
self.assertEqual(guard.error_code, MetadataGuard.ERROR_CODE_EMPTY_FIELD)
@@ -581,9 +578,9 @@ class MetadataGuardTests(TestCase):
"""Test that empty list fails when check_not_empty is True."""
instance = MockInstance(tags=[])
guard = MetadataGuard(required_fields=['tags'], check_not_empty=True)
result = guard(instance, user=self.user)
self.assertFalse(result)
self.assertEqual(guard.error_code, MetadataGuard.ERROR_CODE_EMPTY_FIELD)
@@ -591,9 +588,9 @@ class MetadataGuardTests(TestCase):
"""Test that empty dict fails when check_not_empty is True."""
instance = MockInstance(metadata={})
guard = MetadataGuard(required_fields=['metadata'], check_not_empty=True)
result = guard(instance, user=self.user)
self.assertFalse(result)
self.assertEqual(guard.error_code, MetadataGuard.ERROR_CODE_EMPTY_FIELD)
@@ -601,9 +598,9 @@ class MetadataGuardTests(TestCase):
"""Test that error message includes the field name."""
instance = MockInstance(resolution_notes=None)
guard = MetadataGuard(required_fields=['resolution_notes'])
guard(instance, user=self.user)
message = guard.get_error_message()
self.assertIn('Resolution Notes', message)
@@ -645,9 +642,9 @@ class CompositeGuardTests(TestCase):
OwnershipGuard()
]
composite = CompositeGuard(guards, operator='AND')
result = composite(instance, user=self.moderator)
self.assertTrue(result)
def test_and_operator_one_fails(self):
@@ -658,9 +655,9 @@ class CompositeGuardTests(TestCase):
OwnershipGuard() # Will fail - moderator is not owner
]
composite = CompositeGuard(guards, operator='AND')
result = composite(instance, user=self.non_owner_moderator)
self.assertFalse(result)
self.assertEqual(composite.error_code, CompositeGuard.ERROR_CODE_SOME_FAILED)
@@ -672,9 +669,9 @@ class CompositeGuardTests(TestCase):
OwnershipGuard() # Will pass - user is owner
]
composite = CompositeGuard(guards, operator='OR')
result = composite(instance, user=self.owner)
self.assertTrue(result)
def test_or_operator_all_fail(self):
@@ -685,30 +682,30 @@ class CompositeGuardTests(TestCase):
OwnershipGuard() # Not the owner fails
]
composite = CompositeGuard(guards, operator='OR')
result = composite(instance, user=self.owner)
self.assertFalse(result)
self.assertEqual(composite.error_code, CompositeGuard.ERROR_CODE_ALL_FAILED)
def test_nested_composite_guards(self):
"""Test nested composite guards."""
instance = MockInstance(created_by=self.moderator, status='PENDING')
# Inner composite: moderator OR owner
inner = CompositeGuard([
PermissionGuard(requires_moderator=True),
OwnershipGuard()
], operator='OR')
# Outer composite: (moderator OR owner) AND valid state
outer = CompositeGuard([
inner,
StateGuard(allowed_states=['PENDING'])
], operator='AND')
result = outer(instance, user=self.moderator)
self.assertTrue(result)
def test_error_message_from_failed_guard(self):
@@ -717,9 +714,9 @@ class CompositeGuardTests(TestCase):
perm_guard = PermissionGuard(requires_admin=True)
guards = [perm_guard]
composite = CompositeGuard(guards, operator='AND')
composite(instance, user=self.owner)
message = composite.get_error_message()
self.assertIn('admin', message.lower())
@@ -746,42 +743,42 @@ class GuardFactoryTests(TestCase):
metadata = {'requires_moderator': True}
guard = create_permission_guard(metadata)
instance = MockInstance()
result = guard(instance, user=self.moderator)
self.assertTrue(result)
def test_create_permission_guard_admin(self):
"""Test create_permission_guard with admin requirement."""
metadata = {'requires_admin_approval': True}
guard = create_permission_guard(metadata)
self.assertTrue(guard.requires_admin)
def test_create_permission_guard_escalation_level(self):
"""Test create_permission_guard with escalation level."""
metadata = {'escalation_level': 'admin'}
guard = create_permission_guard(metadata)
self.assertTrue(guard.requires_admin)
def test_create_ownership_guard(self):
"""Test create_ownership_guard factory."""
guard = create_ownership_guard(allow_moderator_override=True)
self.assertTrue(guard.allow_moderator_override)
def test_create_assignment_guard(self):
"""Test create_assignment_guard factory."""
guard = create_assignment_guard(require_assignment=True)
self.assertTrue(guard.require_assignment)
def test_create_composite_guard(self):
"""Test create_composite_guard factory."""
guards = [PermissionGuard(), OwnershipGuard()]
composite = create_composite_guard(guards, operator='OR')
self.assertEqual(composite.operator, 'OR')
self.assertEqual(len(composite.guards), 2)
@@ -798,7 +795,7 @@ class MetadataExtractionTests(TestCase):
"""Test extracting guard for moderator requirement."""
metadata = {'requires_moderator': True}
guards = extract_guards_from_metadata(metadata)
self.assertEqual(len(guards), 1)
self.assertIsInstance(guards[0], PermissionGuard)
@@ -806,7 +803,7 @@ class MetadataExtractionTests(TestCase):
"""Test extracting guard for admin requirement."""
metadata = {'requires_admin_approval': True}
guards = extract_guards_from_metadata(metadata)
self.assertEqual(len(guards), 1)
self.assertTrue(guards[0].requires_admin)
@@ -814,7 +811,7 @@ class MetadataExtractionTests(TestCase):
"""Test extracting assignment guard."""
metadata = {'requires_assignment': True}
guards = extract_guards_from_metadata(metadata)
self.assertEqual(len(guards), 1)
self.assertIsInstance(guards[0], AssignmentGuard)
@@ -825,21 +822,21 @@ class MetadataExtractionTests(TestCase):
'requires_assignment': True
}
guards = extract_guards_from_metadata(metadata)
self.assertEqual(len(guards), 2)
def test_extract_zero_tolerance_guard(self):
"""Test extracting guard for zero tolerance (superuser required)."""
metadata = {'zero_tolerance': True}
guards = extract_guards_from_metadata(metadata)
self.assertEqual(len(guards), 1)
self.assertTrue(guards[0].requires_superuser)
def test_invalid_escalation_level_raises(self):
"""Test that invalid escalation level raises ValueError."""
metadata = {'escalation_level': 'invalid'}
with self.assertRaises(ValueError):
extract_guards_from_metadata(metadata)
@@ -859,36 +856,36 @@ class MetadataValidationTests(TestCase):
'escalation_level': 'admin',
'requires_assignment': False
}
is_valid, errors = validate_guard_metadata(metadata)
self.assertTrue(is_valid)
self.assertEqual(len(errors), 0)
def test_invalid_escalation_level(self):
"""Test that invalid escalation level fails validation."""
metadata = {'escalation_level': 'invalid_level'}
is_valid, errors = validate_guard_metadata(metadata)
self.assertFalse(is_valid)
self.assertTrue(any('escalation_level' in e for e in errors))
def test_invalid_boolean_field(self):
"""Test that non-boolean value for boolean field fails validation."""
metadata = {'requires_moderator': 'yes'}
is_valid, errors = validate_guard_metadata(metadata)
self.assertFalse(is_valid)
self.assertTrue(any('requires_moderator' in e for e in errors))
def test_required_permissions_not_list(self):
"""Test that non-list required_permissions fails validation."""
metadata = {'required_permissions': 'app.permission'}
is_valid, errors = validate_guard_metadata(metadata)
self.assertFalse(is_valid)
self.assertTrue(any('required_permissions' in e for e in errors))
@@ -965,7 +962,7 @@ class RoleHelperTests(TestCase):
def test_anonymous_user_has_no_role(self):
"""Test that anonymous user has no role."""
anonymous = AnonymousUser()
self.assertFalse(has_role(anonymous, ['USER']))
self.assertFalse(is_moderator_or_above(anonymous))
self.assertFalse(is_admin_or_above(anonymous))

View File

@@ -1,14 +1,14 @@
"""Integration tests for state machine model integration."""
import pytest
from unittest.mock import Mock, patch
from django.core.exceptions import ImproperlyConfigured
import pytest
from apps.core.choices.base import RichChoice
from apps.core.choices.registry import registry
from apps.core.state_machine.integration import (
StateMachineModelMixin,
apply_state_machine,
generate_transition_methods_for_model,
StateMachineModelMixin,
state_machine_model,
validate_model_state_machine,
)

View File

@@ -4,8 +4,8 @@ import pytest
from apps.core.choices.base import RichChoice
from apps.core.choices.registry import registry
from apps.core.state_machine.registry import (
TransitionRegistry,
TransitionInfo,
TransitionRegistry,
registry_instance,
)

View File

@@ -5,8 +5,8 @@ from apps.core.choices.base import RichChoice
from apps.core.choices.registry import registry
from apps.core.state_machine.validators import (
MetadataValidator,
ValidationResult,
ValidationError,
ValidationResult,
ValidationWarning,
validate_on_registration,
)

View File

@@ -1,9 +1,8 @@
"""Metadata validators for ensuring RichChoice metadata meets FSM requirements."""
from dataclasses import dataclass, field
from typing import List, Dict, Set, Optional, Any
from typing import Any
from apps.core.state_machine.builder import StateTransitionBuilder
from apps.core.choices.registry import registry
@dataclass
@@ -12,8 +11,8 @@ class ValidationError:
code: str
message: str
state: Optional[str] = None
metadata: Dict[str, Any] = field(default_factory=dict)
state: str | None = None
metadata: dict[str, Any] = field(default_factory=dict)
def __str__(self):
"""String representation of the error."""
@@ -28,7 +27,7 @@ class ValidationWarning:
code: str
message: str
state: Optional[str] = None
state: str | None = None
def __str__(self):
"""String representation of the warning."""
@@ -42,15 +41,15 @@ class ValidationResult:
"""Result of metadata validation."""
is_valid: bool
errors: List[ValidationError] = field(default_factory=list)
warnings: List[ValidationWarning] = field(default_factory=list)
errors: list[ValidationError] = field(default_factory=list)
warnings: list[ValidationWarning] = field(default_factory=list)
def add_error(self, code: str, message: str, state: Optional[str] = None):
def add_error(self, code: str, message: str, state: str | None = None):
"""Add a validation error."""
self.errors.append(ValidationError(code, message, state))
self.is_valid = False
def add_warning(self, code: str, message: str, state: Optional[str] = None):
def add_warning(self, code: str, message: str, state: str | None = None):
"""Add a validation warning."""
self.warnings.append(ValidationWarning(code, message, state))
@@ -91,7 +90,7 @@ class MetadataValidator:
return result
def validate_transitions(self) -> List[ValidationError]:
def validate_transitions(self) -> list[ValidationError]:
"""
Check all can_transition_to references exist.
@@ -148,7 +147,7 @@ class MetadataValidator:
return errors
def validate_terminal_states(self) -> List[ValidationError]:
def validate_terminal_states(self) -> list[ValidationError]:
"""
Ensure terminal states have no outgoing transitions.
@@ -175,7 +174,7 @@ class MetadataValidator:
return errors
def validate_permission_consistency(self) -> List[ValidationError]:
def validate_permission_consistency(self) -> list[ValidationError]:
"""
Check permission requirements are consistent.
@@ -206,7 +205,7 @@ class MetadataValidator:
return errors
def validate_no_cycles(self) -> List[ValidationError]:
def validate_no_cycles(self) -> list[ValidationError]:
"""
Detect invalid state cycles (excluding self-loops).
@@ -224,10 +223,10 @@ class MetadataValidator:
pass
# Detect cycles using DFS
visited: Set[str] = set()
rec_stack: Set[str] = set()
visited: set[str] = set()
rec_stack: set[str] = set()
def has_cycle(node: str, path: List[str]) -> Optional[List[str]]:
def has_cycle(node: str, path: list[str]) -> list[str] | None:
visited.add(node)
rec_stack.add(node)
path.append(node)
@@ -262,7 +261,7 @@ class MetadataValidator:
return errors
def validate_reachability(self) -> List[ValidationError]:
def validate_reachability(self) -> list[ValidationError]:
"""
Ensure all states are reachable from initial states.
@@ -274,7 +273,7 @@ class MetadataValidator:
all_states = set(self.builder.get_all_states())
# Find states with no incoming transitions (potential initial states)
incoming: Dict[str, List[str]] = {state: [] for state in all_states}
incoming: dict[str, list[str]] = {state: [] for state in all_states}
for source, targets in graph.items():
for target in targets:
incoming[target].append(source)
@@ -293,7 +292,7 @@ class MetadataValidator:
return errors
# BFS from initial states to find reachable states
reachable: Set[str] = set(initial_states)
reachable: set[str] = set(initial_states)
queue = list(initial_states)
while queue:

View File

@@ -7,12 +7,13 @@ All tasks run asynchronously to avoid blocking the main application.
import logging
from datetime import datetime, timedelta
from typing import Dict, List, Any
from typing import Any
from celery import shared_task
from django.utils import timezone
from django.core.cache import cache
from django.contrib.contenttypes.models import ContentType
from django.core.cache import cache
from django.db.models import Q
from django.utils import timezone
from apps.core.analytics import PageView
from apps.parks.models import Park
@@ -24,7 +25,7 @@ logger = logging.getLogger(__name__)
@shared_task(bind=True, max_retries=3, default_retry_delay=60)
def calculate_trending_content(
self, content_type: str = "all", limit: int = 50
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""
Calculate trending content using real analytics data.
@@ -100,7 +101,7 @@ def calculate_trending_content(
@shared_task(bind=True, max_retries=3, default_retry_delay=30)
def calculate_new_content(
self, content_type: str = "all", days_back: int = 30, limit: int = 50
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""
Calculate new content based on opening dates and creation dates.
@@ -157,7 +158,7 @@ def calculate_new_content(
@shared_task(bind=True)
def warm_trending_cache(self) -> Dict[str, Any]:
def warm_trending_cache(self) -> dict[str, Any]:
"""
Warm the trending cache by pre-calculating common queries.
@@ -208,7 +209,7 @@ def warm_trending_cache(self) -> Dict[str, Any]:
def _calculate_trending_parks(
current_period_hours: int, previous_period_hours: int, limit: int
) -> List[Dict[str, Any]]:
) -> list[dict[str, Any]]:
"""Calculate trending scores for parks using real data."""
parks = Park.objects.filter(status="OPERATING").select_related(
"location", "operator"
@@ -247,7 +248,7 @@ def _calculate_trending_parks(
def _calculate_trending_rides(
current_period_hours: int, previous_period_hours: int, limit: int
) -> List[Dict[str, Any]]:
) -> list[dict[str, Any]]:
"""Calculate trending scores for rides using real data."""
rides = Ride.objects.filter(status="OPERATING").select_related(
"park", "park__location"
@@ -453,7 +454,7 @@ def _calculate_popularity_score(
return 0.0
def _get_new_parks(cutoff_date: datetime, limit: int) -> List[Dict[str, Any]]:
def _get_new_parks(cutoff_date: datetime, limit: int) -> list[dict[str, Any]]:
"""Get recently added parks using real data."""
new_parks = (
Park.objects.filter(
@@ -467,9 +468,8 @@ def _get_new_parks(cutoff_date: datetime, limit: int) -> List[Dict[str, Any]]:
results = []
for park in new_parks:
date_added = park.opening_date or park.created_at
if date_added:
if isinstance(date_added, datetime):
date_added = date_added.date()
if date_added and isinstance(date_added, datetime):
date_added = date_added.date()
opening_date = getattr(park, "opening_date", None)
if opening_date and isinstance(opening_date, datetime):
@@ -492,7 +492,7 @@ def _get_new_parks(cutoff_date: datetime, limit: int) -> List[Dict[str, Any]]:
return results
def _get_new_rides(cutoff_date: datetime, limit: int) -> List[Dict[str, Any]]:
def _get_new_rides(cutoff_date: datetime, limit: int) -> list[dict[str, Any]]:
"""Get recently added rides using real data."""
new_rides = (
Ride.objects.filter(
@@ -508,9 +508,8 @@ def _get_new_rides(cutoff_date: datetime, limit: int) -> List[Dict[str, Any]]:
date_added = getattr(ride, "opening_date", None) or getattr(
ride, "created_at", None
)
if date_added:
if isinstance(date_added, datetime):
date_added = date_added.date()
if date_added and isinstance(date_added, datetime):
date_added = date_added.date()
opening_date = getattr(ride, "opening_date", None)
if opening_date and isinstance(opening_date, datetime):
@@ -534,10 +533,10 @@ def _get_new_rides(cutoff_date: datetime, limit: int) -> List[Dict[str, Any]]:
def _format_trending_results(
trending_items: List[Dict[str, Any]],
trending_items: list[dict[str, Any]],
current_period_hours: int,
previous_period_hours: int,
) -> List[Dict[str, Any]]:
) -> list[dict[str, Any]]:
"""Format trending results for frontend consumption."""
formatted_results = []
@@ -581,8 +580,8 @@ def _format_trending_results(
def _format_new_content_results(
new_items: List[Dict[str, Any]],
) -> List[Dict[str, Any]]:
new_items: list[dict[str, Any]],
) -> list[dict[str, Any]]:
"""Format new content results for frontend consumption."""
formatted_results = []

View File

@@ -14,10 +14,10 @@ Usage:
"""
from datetime import timedelta
from django import template
from django.template.defaultfilters import stringfilter
from django.utils import timezone
from django.utils.html import format_html
register = template.Library()

View File

@@ -23,13 +23,13 @@ Usage:
{# Render a transition button #}
{% transition_button submission 'approve' request.user %}
"""
from typing import Any, Dict, List, Optional
from typing import Any
from django import template
from django.urls import reverse, NoReverseMatch
from django.urls import NoReverseMatch, reverse
from django_fsm import can_proceed
from apps.core.views.views import get_transition_metadata, TRANSITION_METADATA
from apps.core.views.views import get_transition_metadata
register = template.Library()
@@ -40,7 +40,7 @@ register = template.Library()
@register.filter
def get_state_value(obj) -> Optional[str]:
def get_state_value(obj) -> str | None:
"""
Get the current state value of an FSM-enabled object.
@@ -171,7 +171,7 @@ def default_target_id(obj) -> str:
@register.simple_tag
def get_available_transitions(obj, user) -> List[Dict[str, Any]]:
def get_available_transitions(obj, user) -> list[dict[str, Any]]:
"""
Get all available transitions for an object that the user can execute.

View File

@@ -28,19 +28,26 @@ Usage:
{% icon "check" class="w-4 h-4" %}
"""
import json
from django import template
from django.utils.safestring import mark_safe
from apps.core.utils.html_sanitizer import (
sanitize_html,
sanitize_minimal as _sanitize_minimal,
sanitize_svg,
strip_html as _strip_html,
sanitize_for_json,
escape_js_string as _escape_js_string,
sanitize_url as _sanitize_url,
)
from apps.core.utils.html_sanitizer import (
sanitize_attribute_value,
sanitize_for_json,
sanitize_html,
sanitize_svg,
)
from apps.core.utils.html_sanitizer import (
sanitize_minimal as _sanitize_minimal,
)
from apps.core.utils.html_sanitizer import (
sanitize_url as _sanitize_url,
)
from apps.core.utils.html_sanitizer import (
strip_html as _strip_html,
)
register = template.Library()

View File

@@ -5,7 +5,6 @@ These tests verify the functionality of the base admin classes and mixins
that provide standardized behavior across all admin interfaces.
"""
import pytest
from django.contrib.admin.sites import AdminSite
from django.contrib.auth import get_user_model
from django.test import RequestFactory, TestCase

View File

@@ -1,8 +1,9 @@
import pghistory
import pytest
from django.contrib.auth import get_user_model
from apps.parks.models import Park, Company
import pghistory
from apps.parks.models import Company, Park
User = get_user_model()
@@ -16,7 +17,7 @@ class TestTrackedModel:
"""Test that creating a model instance creates a history event."""
user = User.objects.create_user(username="testuser", password="password")
company = Company.objects.create(name="Test Operator", roles=["OPERATOR"])
with pghistory.context(user=user.id):
park = Park.objects.create(
name="History Test Park",
@@ -24,13 +25,13 @@ class TestTrackedModel:
operating_season="Summer",
operator=company
)
# Verify history using the helper method from TrackedModel
events = park.get_history()
assert events.count() == 1
event = events.first()
assert event.pgh_obj_id == park.pk
# Verify context was captured
# The middleware isn't running here, so we used pghistory.context explicitly
# But pghistory.context stores data in pgh_context field if configured?
@@ -40,15 +41,15 @@ class TestTrackedModel:
def test_update_tracking(self):
company = Company.objects.create(name="Test Operator 2", roles=["OPERATOR"])
park = Park.objects.create(name="Original", operator=company)
# Initial create event
assert park.get_history().count() == 1
# Update
park.name = "Updated"
park.save()
assert park.get_history().count() == 2
latest = park.get_history().first() # Ordered by -pgh_created_at
assert latest.name == "Updated"

View File

@@ -2,7 +2,8 @@
Core app URL configuration.
"""
from django.urls import path, include
from django.urls import include, path
from ..views.entity_search import (
EntityFuzzySearchView,
EntityNotFoundView,

View File

@@ -3,13 +3,14 @@ URL patterns for the unified map service API.
"""
from django.urls import path
from ..views.map_views import (
MapLocationsView,
MapLocationDetailView,
MapSearchView,
MapBoundsView,
MapStatsView,
MapCacheView,
MapLocationDetailView,
MapLocationsView,
MapSearchView,
MapStatsView,
)
app_name = "map_api"

View File

@@ -4,15 +4,16 @@ Includes both HTML views and HTMX endpoints.
"""
from django.urls import path
from ..views.maps import (
UniversalMapView,
ParkMapView,
NearbyLocationsView,
LocationDetailModalView,
LocationFilterView,
LocationListView,
LocationSearchView,
MapBoundsUpdateView,
LocationDetailModalView,
LocationListView,
NearbyLocationsView,
ParkMapView,
UniversalMapView,
)
app_name = "maps"

View File

@@ -1,4 +1,5 @@
from django.urls import path
from apps.core.views.search import (
AdaptiveSearchView,
FilterFormView,

View File

@@ -29,7 +29,8 @@ Usage Examples:
from __future__ import annotations
from dataclasses import dataclass, field
import contextlib
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
from urllib.parse import urljoin
@@ -351,19 +352,15 @@ def get_model_breadcrumb(
parent_list_url = f"{parent_model_name}s:list"
parent_list_label = f"{parent.__class__.__name__}s"
try:
with contextlib.suppress(Exception):
builder.add_from_url(parent_list_url, parent_list_label)
except Exception:
pass
builder.add_model(parent)
# Add list page breadcrumb
if list_url_name and list_label:
try:
with contextlib.suppress(Exception):
builder.add_from_url(list_url_name, list_label)
except Exception:
pass
# Add current model instance
builder.add_model_current(instance)

View File

@@ -1,53 +1,54 @@
import logging
import requests
from django.conf import settings
from django.core.exceptions import ImproperlyConfigured
import logging
logger = logging.getLogger(__name__)
def get_direct_upload_url(user_id=None):
"""
Generates a direct upload URL for Cloudflare Images.
Args:
user_id (str, optional): The user ID to associate with the upload.
Returns:
dict: A dictionary containing 'id' and 'uploadURL'.
Raises:
ImproperlyConfigured: If Cloudflare settings are missing.
requests.RequestException: If the Cloudflare API request fails.
"""
account_id = getattr(settings, 'CLOUDFLARE_IMAGES_ACCOUNT_ID', None)
api_token = getattr(settings, 'CLOUDFLARE_IMAGES_API_TOKEN', None)
if not account_id or not api_token:
raise ImproperlyConfigured(
"CLOUDFLARE_IMAGES_ACCOUNT_ID and CLOUDFLARE_IMAGES_API_TOKEN must be set."
)
url = f"https://api.cloudflare.com/client/v4/accounts/{account_id}/images/v2/direct_upload"
headers = {
"Authorization": f"Bearer {api_token}",
}
data = {
"requireSignedURLs": "false",
}
if user_id:
data["metadata"] = f'{{"user_id": "{user_id}"}}'
response = requests.post(url, headers=headers, data=data)
response.raise_for_status()
result = response.json()
if not result.get("success"):
error_msg = result.get("errors", [{"message": "Unknown error"}])[0].get("message")
logger.error(f"Cloudflare Direct Upload Error: {error_msg}")
raise requests.RequestException(f"Cloudflare Error: {error_msg}")
return result.get("result", {})

View File

@@ -6,7 +6,7 @@ ensuring consistent logging, user messages, and API responses.
"""
import logging
from typing import Any, Dict, Optional
from typing import Any
from django.contrib import messages
from django.http import HttpRequest
@@ -26,7 +26,7 @@ class ErrorHandler:
request: HttpRequest,
error: Exception,
user_message: str = "An error occurred",
log_message: Optional[str] = None,
log_message: str | None = None,
level: str = "error",
) -> None:
"""
@@ -68,7 +68,7 @@ class ErrorHandler:
def handle_api_error(
error: Exception,
user_message: str = "An error occurred",
log_message: Optional[str] = None,
log_message: str | None = None,
status_code: int = status.HTTP_400_BAD_REQUEST,
) -> Response:
"""
@@ -99,7 +99,7 @@ class ErrorHandler:
logger.error(log_msg, exc_info=True)
# Build error response
error_data: Dict[str, Any] = {
error_data: dict[str, Any] = {
"error": user_message,
"detail": str(error),
}
@@ -150,7 +150,7 @@ class ErrorHandler:
Returns:
DRF Response with success data in standard format
"""
response_data: Dict[str, Any] = {
response_data: dict[str, Any] = {
"status": "success",
"message": message,
}

View File

@@ -30,7 +30,6 @@ import os
import re
import uuid
from io import BytesIO
from typing import Optional, Set, Tuple
from django.core.exceptions import ValidationError
from django.core.files.uploadedfile import UploadedFile
@@ -72,7 +71,7 @@ IMAGE_SIGNATURES = {
}
# All allowed MIME types
ALLOWED_IMAGE_MIME_TYPES: Set[str] = frozenset({
ALLOWED_IMAGE_MIME_TYPES: set[str] = frozenset({
'image/jpeg',
'image/png',
'image/gif',
@@ -80,7 +79,7 @@ ALLOWED_IMAGE_MIME_TYPES: Set[str] = frozenset({
})
# Allowed file extensions
ALLOWED_IMAGE_EXTENSIONS: Set[str] = frozenset({
ALLOWED_IMAGE_EXTENSIONS: set[str] = frozenset({
'.jpg', '.jpeg', '.png', '.gif', '.webp',
})
@@ -98,8 +97,8 @@ MIN_FILE_SIZE = 100 # 100 bytes
def validate_image_upload(
file: UploadedFile,
max_size: int = MAX_FILE_SIZE,
allowed_types: Optional[Set[str]] = None,
allowed_extensions: Optional[Set[str]] = None,
allowed_types: set[str] | None = None,
allowed_extensions: set[str] | None = None,
) -> bool:
"""
Validate an uploaded image file for security.
@@ -191,15 +190,14 @@ def _validate_magic_number(file: UploadedFile) -> bool:
# Check against known signatures
for format_name, signatures in IMAGE_SIGNATURES.items():
for magic, offset, description in signatures:
if len(header) >= offset + len(magic):
if header[offset:offset + len(magic)] == magic:
# Special handling for WebP (must also have WEBP marker)
if format_name == 'webp':
if len(header) >= 12 and header[8:12] == b'WEBP':
return True
else:
for magic, offset, _description in signatures:
if len(header) >= offset + len(magic) and header[offset:offset + len(magic)] == magic:
# Special handling for WebP (must also have WEBP marker)
if format_name == 'webp':
if len(header) >= 12 and header[8:12] == b'WEBP':
return True
else:
return True
return False
@@ -340,7 +338,7 @@ UPLOAD_RATE_LIMITS = {
}
def check_upload_rate_limit(user_id: int, cache_backend=None) -> Tuple[bool, str]:
def check_upload_rate_limit(user_id: int, cache_backend=None) -> tuple[bool, str]:
"""
Check if user has exceeded upload rate limits.
@@ -414,7 +412,7 @@ def increment_upload_count(user_id: int, cache_backend=None) -> None:
# Antivirus Integration Point
# =============================================================================
def scan_file_for_malware(file: UploadedFile) -> Tuple[bool, str]:
def scan_file_for_malware(file: UploadedFile) -> tuple[bool, str]:
"""
Placeholder for antivirus/malware scanning integration.

View File

@@ -16,8 +16,6 @@ Usage Examples:
from __future__ import annotations
from typing import Any
def success_created(
model_name: str,

View File

@@ -2,14 +2,15 @@
Database query optimization utilities and helpers.
"""
import time
import logging
import time
from contextlib import contextmanager
from typing import Optional, Dict, Any, List, Type
from django.db import connection, models
from django.db.models import QuerySet, Prefetch, Count, Avg, Max
from typing import Any
from django.conf import settings
from django.core.cache import cache
from django.db import connection, models
from django.db.models import Avg, Count, Max, Prefetch, QuerySet
logger = logging.getLogger("query_optimization")
@@ -136,7 +137,7 @@ class QueryOptimizer:
)
@staticmethod
def create_bulk_queryset(model: Type[models.Model], ids: List[int]) -> QuerySet:
def create_bulk_queryset(model: type[models.Model], ids: list[int]) -> QuerySet:
"""
Create an optimized queryset for bulk operations
"""
@@ -186,7 +187,7 @@ class QueryCache:
return result
@staticmethod
def invalidate_model_cache(model_name: str, instance_id: Optional[int] = None):
def invalidate_model_cache(model_name: str, instance_id: int | None = None):
"""
Invalidate cache keys related to a specific model
@@ -195,10 +196,7 @@ class QueryCache:
instance_id: Specific instance ID, if applicable
"""
# Pattern-based cache invalidation (works with Redis)
if instance_id:
pattern = f"*{model_name}_{instance_id}*"
else:
pattern = f"*{model_name}*"
pattern = f"*{model_name}_{instance_id}*" if instance_id else f"*{model_name}*"
try:
# For Redis cache backends that support pattern deletion
@@ -219,7 +217,7 @@ class IndexAnalyzer:
"""Analyze and suggest database indexes"""
@staticmethod
def analyze_slow_queries(min_time: float = 0.1) -> List[Dict[str, Any]]:
def analyze_slow_queries(min_time: float = 0.1) -> list[dict[str, Any]]:
"""
Analyze slow queries from the current request
@@ -244,7 +242,7 @@ class IndexAnalyzer:
return slow_queries
@staticmethod
def _analyze_query_sql(sql: str) -> Dict[str, Any]:
def _analyze_query_sql(sql: str) -> dict[str, Any]:
"""
Analyze SQL to suggest potential optimizations
"""
@@ -285,7 +283,7 @@ class IndexAnalyzer:
return analysis
@staticmethod
def suggest_model_indexes(model: Type[models.Model]) -> List[str]:
def suggest_model_indexes(model: type[models.Model]) -> list[str]:
"""
Suggest database indexes for a Django model based on its fields
"""
@@ -343,7 +341,7 @@ def log_query_performance():
def optimize_queryset_for_serialization(
queryset: QuerySet, fields: List[str]
queryset: QuerySet, fields: list[str]
) -> QuerySet:
"""
Optimize a queryset for API serialization by only selecting needed fields

Some files were not shown because too many files have changed in this diff Show More