feat: Implement initial schema and add various API, service, and management command enhancements across the application.

This commit is contained in:
pacnpal
2026-01-01 15:13:01 -05:00
parent c95f99ca10
commit b243b17af7
413 changed files with 11164 additions and 17433 deletions

View File

@@ -9,4 +9,4 @@ system status, and other foundational features.
from .choices import core_choices
# Ensure choices are registered on app startup
__all__ = ['core_choices']
__all__ = ["core_choices"]

View File

@@ -23,9 +23,7 @@ from .models import SlugHistory
@admin.register(SlugHistory)
class SlugHistoryAdmin(
ReadOnlyAdminMixin, QueryOptimizationMixin, ExportActionMixin, BaseModelAdmin
):
class SlugHistoryAdmin(ReadOnlyAdminMixin, QueryOptimizationMixin, ExportActionMixin, BaseModelAdmin):
"""
Admin interface for SlugHistory management.

View File

@@ -221,13 +221,9 @@ class ExportActionMixin:
writer.writerow(row)
response = HttpResponse(output.getvalue(), content_type="text/csv")
response["Content-Disposition"] = (
f'attachment; filename="{self.get_export_filename("csv")}"'
)
response["Content-Disposition"] = f'attachment; filename="{self.get_export_filename("csv")}"'
self.message_user(
request, f"Successfully exported {queryset.count()} records to CSV."
)
self.message_user(request, f"Successfully exported {queryset.count()} records to CSV.")
return response
@admin.action(description="Export selected to JSON")
@@ -250,13 +246,9 @@ class ExportActionMixin:
json.dumps(data, indent=2, cls=DjangoJSONEncoder),
content_type="application/json",
)
response["Content-Disposition"] = (
f'attachment; filename="{self.get_export_filename("json")}"'
)
response["Content-Disposition"] = f'attachment; filename="{self.get_export_filename("json")}"'
self.message_user(
request, f"Successfully exported {queryset.count()} records to JSON."
)
self.message_user(request, f"Successfully exported {queryset.count()} records to JSON.")
return response
def get_actions(self, request):

View File

@@ -10,9 +10,7 @@ from django.utils import timezone
@pghistory.track()
class PageView(models.Model):
content_type = models.ForeignKey(
ContentType, on_delete=models.CASCADE, related_name="page_views"
)
content_type = models.ForeignKey(ContentType, on_delete=models.CASCADE, related_name="page_views")
object_id = models.PositiveIntegerField()
content_object = GenericForeignKey("content_type", "object_id")
@@ -64,9 +62,7 @@ class PageView(models.Model):
return model_class.objects.none()
@classmethod
def get_views_growth(
cls, content_type, object_id, current_period_hours, previous_period_hours
):
def get_views_growth(cls, content_type, object_id, current_period_hours, previous_period_hours):
"""Get view growth statistics between two time periods.
Args:
@@ -102,9 +98,7 @@ class PageView(models.Model):
if previous_views == 0:
growth_percentage = current_views * 100 if current_views > 0 else 0
else:
growth_percentage = (
(current_views - previous_views) / previous_views
) * 100
growth_percentage = ((current_views - previous_views) / previous_views) * 100
return current_views, previous_views, growth_percentage
@@ -121,6 +115,4 @@ class PageView(models.Model):
int: Total view count
"""
cutoff = timezone.now() - timedelta(hours=hours)
return cls.objects.filter(
content_type=content_type, object_id=object_id, timestamp__gte=cutoff
).count()
return cls.objects.filter(content_type=content_type, object_id=object_id, timestamp__gte=cutoff).count()

View File

@@ -31,9 +31,7 @@ from ..logging import get_logger, log_exception
logger = get_logger(__name__)
def custom_exception_handler(
exc: Exception, context: dict[str, Any]
) -> Response | None:
def custom_exception_handler(exc: Exception, context: dict[str, Any]) -> Response | None:
"""
Custom exception handler for DRF that provides standardized error responses.
@@ -172,9 +170,7 @@ def custom_exception_handler(
request=request,
)
response = Response(
custom_response_data, status=status.HTTP_500_INTERNAL_SERVER_ERROR
)
response = Response(custom_response_data, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
return response
@@ -234,10 +230,7 @@ def _format_django_validation_errors(
"""Format Django ValidationError for API response."""
if hasattr(exc, "error_dict"):
# Field-specific errors
return {
field: [str(error) for error in errors]
for field, errors in exc.error_dict.items()
}
return {field: [str(error) for error in errors] for field, errors in exc.error_dict.items()}
elif hasattr(exc, "error_list"):
# Non-field errors
return {"non_field_errors": [str(error) for error in exc.error_list]}

View File

@@ -103,15 +103,11 @@ class ApiMixin:
# These will raise if not implemented; they also inform static analyzers about their existence.
def paginate_queryset(self, queryset):
"""Override / implement in subclass or provided base if pagination is needed."""
raise NotImplementedError(
"Subclasses must implement paginate_queryset to enable pagination"
)
raise NotImplementedError("Subclasses must implement paginate_queryset to enable pagination")
def get_paginated_response(self, data):
"""Override / implement in subclass or provided base to return paginated responses."""
raise NotImplementedError(
"Subclasses must implement get_paginated_response to enable pagination"
)
raise NotImplementedError("Subclasses must implement get_paginated_response to enable pagination")
def get_object(self):
"""Default placeholder; subclasses should implement this."""
@@ -168,9 +164,7 @@ class UpdateApiMixin(ApiMixin):
def update(self, _request: Request, *_args, **_kwargs) -> Response:
"""Handle PUT/PATCH requests for updating resources."""
instance = self.get_object()
serializer = self.get_input_serializer(
data=_request.data, partial=_kwargs.get("partial", False)
)
serializer = self.get_input_serializer(data=_request.data, partial=_kwargs.get("partial", False))
serializer.is_valid(raise_exception=True)
# Update the object using the service layer
@@ -229,9 +223,7 @@ class ListApiMixin(ApiMixin):
Override this method to use selector patterns.
Should call selector functions, not access model managers directly.
"""
raise NotImplementedError(
"Subclasses must implement get_queryset using selectors"
)
raise NotImplementedError("Subclasses must implement get_queryset using selectors")
def get_output_serializer(self, *args, **kwargs):
"""Get the output serializer for response."""

View File

@@ -28,6 +28,7 @@ from django.core.checks import Error, Tags, Warning, register
# Secret Key Validation
# =============================================================================
@register(Tags.security)
def check_secret_key(app_configs, **kwargs):
"""
@@ -38,30 +39,30 @@ def check_secret_key(app_configs, **kwargs):
- Key has sufficient entropy (length and character variety)
"""
errors = []
secret_key = getattr(settings, 'SECRET_KEY', '')
secret_key = getattr(settings, "SECRET_KEY", "")
# Check for empty or missing key
if not secret_key:
errors.append(
Error(
'SECRET_KEY is not set.',
hint='Set a strong, random SECRET_KEY in your environment.',
id='security.E001',
"SECRET_KEY is not set.",
hint="Set a strong, random SECRET_KEY in your environment.",
id="security.E001",
)
)
return errors
# Check for known insecure default values
insecure_defaults = [
'django-insecure',
'your-secret-key',
'change-me',
'changeme',
'secret',
'xxx',
'test',
'development',
'dev-key',
"django-insecure",
"your-secret-key",
"change-me",
"changeme",
"secret",
"xxx",
"test",
"development",
"dev-key",
]
key_lower = secret_key.lower()
@@ -71,7 +72,7 @@ def check_secret_key(app_configs, **kwargs):
Error(
f'SECRET_KEY appears to contain an insecure default value: "{default}"',
hint='Generate a new secret key using: python -c "from django.core.management.utils import get_random_secret_key; print(get_random_secret_key())"',
id='security.E002',
id="security.E002",
)
)
break
@@ -80,25 +81,25 @@ def check_secret_key(app_configs, **kwargs):
if len(secret_key) < 50:
errors.append(
Warning(
f'SECRET_KEY is only {len(secret_key)} characters long.',
hint='A secret key should be at least 50 characters for proper security.',
id='security.W001',
f"SECRET_KEY is only {len(secret_key)} characters long.",
hint="A secret key should be at least 50 characters for proper security.",
id="security.W001",
)
)
# Check for sufficient character variety
has_upper = bool(re.search(r'[A-Z]', secret_key))
has_lower = bool(re.search(r'[a-z]', secret_key))
has_digit = bool(re.search(r'[0-9]', secret_key))
has_upper = bool(re.search(r"[A-Z]", secret_key))
has_lower = bool(re.search(r"[a-z]", secret_key))
has_digit = bool(re.search(r"[0-9]", secret_key))
has_special = bool(re.search(r'[!@#$%^&*()_+\-=\[\]{};\':"\\|,.<>\/?]', secret_key))
char_types = sum([has_upper, has_lower, has_digit, has_special])
if char_types < 3:
errors.append(
Warning(
'SECRET_KEY lacks character variety.',
hint='A good secret key should contain uppercase, lowercase, digits, and special characters.',
id='security.W002',
"SECRET_KEY lacks character variety.",
hint="A good secret key should contain uppercase, lowercase, digits, and special characters.",
id="security.W002",
)
)
@@ -109,6 +110,7 @@ def check_secret_key(app_configs, **kwargs):
# Debug Mode Check
# =============================================================================
@register(Tags.security)
def check_debug_mode(app_configs, **kwargs):
"""
@@ -117,27 +119,27 @@ def check_debug_mode(app_configs, **kwargs):
errors = []
# Check if we're in a production-like environment
env = os.environ.get('DJANGO_SETTINGS_MODULE', '')
is_production = 'production' in env.lower() or 'prod' in env.lower()
env = os.environ.get("DJANGO_SETTINGS_MODULE", "")
is_production = "production" in env.lower() or "prod" in env.lower()
if is_production and settings.DEBUG:
errors.append(
Error(
'DEBUG is True in what appears to be a production environment.',
hint='Set DEBUG=False in production settings.',
id='security.E003',
"DEBUG is True in what appears to be a production environment.",
hint="Set DEBUG=False in production settings.",
id="security.E003",
)
)
# Also check if DEBUG is True with ALLOWED_HOSTS configured
# (indicates possible production deployment with debug on)
if settings.DEBUG and settings.ALLOWED_HOSTS and '*' not in settings.ALLOWED_HOSTS:
if len(settings.ALLOWED_HOSTS) > 0 and 'localhost' not in settings.ALLOWED_HOSTS[0]:
if settings.DEBUG and settings.ALLOWED_HOSTS and "*" not in settings.ALLOWED_HOSTS: # noqa: SIM102
if len(settings.ALLOWED_HOSTS) > 0 and "localhost" not in settings.ALLOWED_HOSTS[0]:
errors.append(
Warning(
'DEBUG is True but ALLOWED_HOSTS contains non-localhost values.',
hint='This may indicate DEBUG is accidentally enabled in a deployed environment.',
id='security.W003',
"DEBUG is True but ALLOWED_HOSTS contains non-localhost values.",
hint="This may indicate DEBUG is accidentally enabled in a deployed environment.",
id="security.W003",
)
)
@@ -148,30 +150,31 @@ def check_debug_mode(app_configs, **kwargs):
# ALLOWED_HOSTS Check
# =============================================================================
@register(Tags.security)
def check_allowed_hosts(app_configs, **kwargs):
"""
Check ALLOWED_HOSTS configuration.
"""
errors = []
allowed_hosts = getattr(settings, 'ALLOWED_HOSTS', [])
allowed_hosts = getattr(settings, "ALLOWED_HOSTS", [])
if not settings.DEBUG:
# In non-debug mode, ALLOWED_HOSTS must be set
if not allowed_hosts:
errors.append(
Error(
'ALLOWED_HOSTS is empty but DEBUG is False.',
hint='Set ALLOWED_HOSTS to a list of allowed hostnames.',
id='security.E004',
"ALLOWED_HOSTS is empty but DEBUG is False.",
hint="Set ALLOWED_HOSTS to a list of allowed hostnames.",
id="security.E004",
)
)
elif '*' in allowed_hosts:
elif "*" in allowed_hosts:
errors.append(
Error(
'ALLOWED_HOSTS contains "*" which allows all hosts.',
hint='Specify explicit hostnames instead of wildcards.',
id='security.E005',
hint="Specify explicit hostnames instead of wildcards.",
id="security.E005",
)
)
@@ -182,6 +185,7 @@ def check_allowed_hosts(app_configs, **kwargs):
# Security Headers Check
# =============================================================================
@register(Tags.security)
def check_security_headers(app_configs, **kwargs):
"""
@@ -190,34 +194,34 @@ def check_security_headers(app_configs, **kwargs):
errors = []
# Check X-Frame-Options
x_frame_options = getattr(settings, 'X_FRAME_OPTIONS', None)
if x_frame_options not in ('DENY', 'SAMEORIGIN'):
x_frame_options = getattr(settings, "X_FRAME_OPTIONS", None)
if x_frame_options not in ("DENY", "SAMEORIGIN"):
errors.append(
Warning(
f'X_FRAME_OPTIONS is set to "{x_frame_options}" or not set.',
hint='Set X_FRAME_OPTIONS to "DENY" or "SAMEORIGIN" to prevent clickjacking.',
id='security.W004',
id="security.W004",
)
)
# Check content type sniffing protection
if not getattr(settings, 'SECURE_CONTENT_TYPE_NOSNIFF', False):
if not getattr(settings, "SECURE_CONTENT_TYPE_NOSNIFF", False):
errors.append(
Warning(
'SECURE_CONTENT_TYPE_NOSNIFF is not enabled.',
hint='Set SECURE_CONTENT_TYPE_NOSNIFF = True to prevent MIME type sniffing.',
id='security.W005',
"SECURE_CONTENT_TYPE_NOSNIFF is not enabled.",
hint="Set SECURE_CONTENT_TYPE_NOSNIFF = True to prevent MIME type sniffing.",
id="security.W005",
)
)
# Check referrer policy
referrer_policy = getattr(settings, 'SECURE_REFERRER_POLICY', None)
referrer_policy = getattr(settings, "SECURE_REFERRER_POLICY", None)
if not referrer_policy:
errors.append(
Warning(
'SECURE_REFERRER_POLICY is not set.',
hint='Set SECURE_REFERRER_POLICY to control referrer header behavior.',
id='security.W006',
"SECURE_REFERRER_POLICY is not set.",
hint="Set SECURE_REFERRER_POLICY to control referrer header behavior.",
id="security.W006",
)
)
@@ -228,6 +232,7 @@ def check_security_headers(app_configs, **kwargs):
# HTTPS Settings Check
# =============================================================================
@register(Tags.security)
def check_https_settings(app_configs, **kwargs):
"""
@@ -240,32 +245,32 @@ def check_https_settings(app_configs, **kwargs):
return errors
# Check SSL redirect
if not getattr(settings, 'SECURE_SSL_REDIRECT', False):
if not getattr(settings, "SECURE_SSL_REDIRECT", False):
errors.append(
Warning(
'SECURE_SSL_REDIRECT is not enabled.',
hint='Set SECURE_SSL_REDIRECT = True to redirect HTTP to HTTPS.',
id='security.W007',
"SECURE_SSL_REDIRECT is not enabled.",
hint="Set SECURE_SSL_REDIRECT = True to redirect HTTP to HTTPS.",
id="security.W007",
)
)
# Check HSTS settings
hsts_seconds = getattr(settings, 'SECURE_HSTS_SECONDS', 0)
hsts_seconds = getattr(settings, "SECURE_HSTS_SECONDS", 0)
if hsts_seconds < 31536000: # Less than 1 year
errors.append(
Warning(
f'SECURE_HSTS_SECONDS is {hsts_seconds} (less than 1 year).',
hint='Set SECURE_HSTS_SECONDS to at least 31536000 (1 year) for HSTS preload eligibility.',
id='security.W008',
f"SECURE_HSTS_SECONDS is {hsts_seconds} (less than 1 year).",
hint="Set SECURE_HSTS_SECONDS to at least 31536000 (1 year) for HSTS preload eligibility.",
id="security.W008",
)
)
if not getattr(settings, 'SECURE_HSTS_INCLUDE_SUBDOMAINS', False):
if not getattr(settings, "SECURE_HSTS_INCLUDE_SUBDOMAINS", False):
errors.append(
Warning(
'SECURE_HSTS_INCLUDE_SUBDOMAINS is not enabled.',
hint='Set SECURE_HSTS_INCLUDE_SUBDOMAINS = True to include all subdomains in HSTS.',
id='security.W009',
"SECURE_HSTS_INCLUDE_SUBDOMAINS is not enabled.",
hint="Set SECURE_HSTS_INCLUDE_SUBDOMAINS = True to include all subdomains in HSTS.",
id="security.W009",
)
)
@@ -276,6 +281,7 @@ def check_https_settings(app_configs, **kwargs):
# Cookie Security Check
# =============================================================================
@register(Tags.security)
def check_cookie_security(app_configs, **kwargs):
"""
@@ -288,42 +294,42 @@ def check_cookie_security(app_configs, **kwargs):
return errors
# Check session cookie security
if not getattr(settings, 'SESSION_COOKIE_SECURE', False):
if not getattr(settings, "SESSION_COOKIE_SECURE", False):
errors.append(
Warning(
'SESSION_COOKIE_SECURE is not enabled.',
hint='Set SESSION_COOKIE_SECURE = True to only send session cookies over HTTPS.',
id='security.W010',
"SESSION_COOKIE_SECURE is not enabled.",
hint="Set SESSION_COOKIE_SECURE = True to only send session cookies over HTTPS.",
id="security.W010",
)
)
if not getattr(settings, 'SESSION_COOKIE_HTTPONLY', True):
if not getattr(settings, "SESSION_COOKIE_HTTPONLY", True):
errors.append(
Warning(
'SESSION_COOKIE_HTTPONLY is disabled.',
hint='Set SESSION_COOKIE_HTTPONLY = True to prevent JavaScript access to session cookies.',
id='security.W011',
"SESSION_COOKIE_HTTPONLY is disabled.",
hint="Set SESSION_COOKIE_HTTPONLY = True to prevent JavaScript access to session cookies.",
id="security.W011",
)
)
# Check CSRF cookie security
if not getattr(settings, 'CSRF_COOKIE_SECURE', False):
if not getattr(settings, "CSRF_COOKIE_SECURE", False):
errors.append(
Warning(
'CSRF_COOKIE_SECURE is not enabled.',
hint='Set CSRF_COOKIE_SECURE = True to only send CSRF cookies over HTTPS.',
id='security.W012',
"CSRF_COOKIE_SECURE is not enabled.",
hint="Set CSRF_COOKIE_SECURE = True to only send CSRF cookies over HTTPS.",
id="security.W012",
)
)
# Check SameSite attributes
session_samesite = getattr(settings, 'SESSION_COOKIE_SAMESITE', 'Lax')
if session_samesite not in ('Strict', 'Lax'):
session_samesite = getattr(settings, "SESSION_COOKIE_SAMESITE", "Lax")
if session_samesite not in ("Strict", "Lax"):
errors.append(
Warning(
f'SESSION_COOKIE_SAMESITE is set to "{session_samesite}".',
hint='Set SESSION_COOKIE_SAMESITE to "Strict" or "Lax" for CSRF protection.',
id='security.W013',
id="security.W013",
)
)
@@ -334,6 +340,7 @@ def check_cookie_security(app_configs, **kwargs):
# Database Security Check
# =============================================================================
@register(Tags.security)
def check_database_security(app_configs, **kwargs):
"""
@@ -345,27 +352,27 @@ def check_database_security(app_configs, **kwargs):
if settings.DEBUG:
return errors
databases = getattr(settings, 'DATABASES', {})
default_db = databases.get('default', {})
databases = getattr(settings, "DATABASES", {})
default_db = databases.get("default", {})
# Check for empty password
if not default_db.get('PASSWORD') and default_db.get('ENGINE', '').endswith('postgresql'):
if not default_db.get("PASSWORD") and default_db.get("ENGINE", "").endswith("postgresql"):
errors.append(
Warning(
'Database password is empty.',
hint='Set a strong password for database authentication.',
id='security.W014',
"Database password is empty.",
hint="Set a strong password for database authentication.",
id="security.W014",
)
)
# Check for SSL mode in PostgreSQL
options = default_db.get('OPTIONS', {})
if 'sslmode' not in str(options) and default_db.get('ENGINE', '').endswith('postgresql'):
options = default_db.get("OPTIONS", {})
if "sslmode" not in str(options) and default_db.get("ENGINE", "").endswith("postgresql"):
errors.append(
Warning(
'Database SSL mode is not explicitly configured.',
hint='Consider setting sslmode in database OPTIONS for encrypted connections.',
id='security.W015',
"Database SSL mode is not explicitly configured.",
hint="Consider setting sslmode in database OPTIONS for encrypted connections.",
id="security.W015",
)
)

View File

@@ -19,14 +19,14 @@ from .serializers import RichChoiceOptionSerializer, RichChoiceSerializer
from .utils import get_choice_display, validate_choice_value
__all__ = [
'RichChoice',
'ChoiceCategory',
'ChoiceGroup',
'ChoiceRegistry',
'register_choices',
'RichChoiceField',
'RichChoiceSerializer',
'RichChoiceOptionSerializer',
'validate_choice_value',
'get_choice_display',
"RichChoice",
"ChoiceCategory",
"ChoiceGroup",
"ChoiceRegistry",
"register_choices",
"RichChoiceField",
"RichChoiceSerializer",
"RichChoiceOptionSerializer",
"validate_choice_value",
"get_choice_display",
]

View File

@@ -11,6 +11,7 @@ from typing import Any
class ChoiceCategory(Enum):
"""Categories for organizing choice types"""
STATUS = "status"
TYPE = "type"
CLASSIFICATION = "classification"
@@ -42,6 +43,7 @@ class RichChoice:
deprecated: Whether this choice is deprecated and should not be used for new entries
category: Category for organizing related choices
"""
value: str
label: str
description: str = ""
@@ -59,40 +61,38 @@ class RichChoice:
@property
def color(self) -> str | None:
"""Get the color from metadata if available"""
return self.metadata.get('color')
return self.metadata.get("color")
@property
def icon(self) -> str | None:
"""Get the icon from metadata if available"""
return self.metadata.get('icon')
return self.metadata.get("icon")
@property
def css_class(self) -> str | None:
"""Get the CSS class from metadata if available"""
return self.metadata.get('css_class')
return self.metadata.get("css_class")
@property
def sort_order(self) -> int:
"""Get the sort order from metadata, defaulting to 0"""
return self.metadata.get('sort_order', 0)
return self.metadata.get("sort_order", 0)
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary representation for API serialization"""
return {
'value': self.value,
'label': self.label,
'description': self.description,
'metadata': self.metadata,
'deprecated': self.deprecated,
'category': self.category.value,
'color': self.color,
'icon': self.icon,
'css_class': self.css_class,
'sort_order': self.sort_order,
"value": self.value,
"label": self.label,
"description": self.description,
"metadata": self.metadata,
"deprecated": self.deprecated,
"category": self.category.value,
"color": self.color,
"icon": self.icon,
"css_class": self.css_class,
"sort_order": self.sort_order,
}
def __str__(self) -> str:
return self.label
@@ -108,6 +108,7 @@ class ChoiceGroup:
This allows for organizing choices into logical groups with
common properties and behaviors.
"""
name: str
choices: list[RichChoice]
description: str = ""
@@ -147,8 +148,8 @@ class ChoiceGroup:
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary representation for API serialization"""
return {
'name': self.name,
'description': self.description,
'metadata': self.metadata,
'choices': [choice.to_dict() for choice in self.choices]
"name": self.name,
"description": self.description,
"metadata": self.metadata,
"choices": [choice.to_dict() for choice in self.choices],
}

View File

@@ -15,26 +15,26 @@ HEALTH_STATUSES = [
label="Healthy",
description="System is operating normally with no issues detected",
metadata={
'color': 'green',
'icon': 'check-circle',
'css_class': 'bg-green-100 text-green-800',
'sort_order': 1,
'http_status': 200
"color": "green",
"icon": "check-circle",
"css_class": "bg-green-100 text-green-800",
"sort_order": 1,
"http_status": 200,
},
category=ChoiceCategory.STATUS
category=ChoiceCategory.STATUS,
),
RichChoice(
value="unhealthy",
label="Unhealthy",
description="System has detected issues that may affect functionality",
metadata={
'color': 'red',
'icon': 'x-circle',
'css_class': 'bg-red-100 text-red-800',
'sort_order': 2,
'http_status': 503
"color": "red",
"icon": "x-circle",
"css_class": "bg-red-100 text-red-800",
"sort_order": 2,
"http_status": 503,
},
category=ChoiceCategory.STATUS
category=ChoiceCategory.STATUS,
),
]
@@ -45,26 +45,26 @@ SIMPLE_HEALTH_STATUSES = [
label="OK",
description="Basic health check passed",
metadata={
'color': 'green',
'icon': 'check',
'css_class': 'bg-green-100 text-green-800',
'sort_order': 1,
'http_status': 200
"color": "green",
"icon": "check",
"css_class": "bg-green-100 text-green-800",
"sort_order": 1,
"http_status": 200,
},
category=ChoiceCategory.STATUS
category=ChoiceCategory.STATUS,
),
RichChoice(
value="error",
label="Error",
description="Basic health check failed",
metadata={
'color': 'red',
'icon': 'x',
'css_class': 'bg-red-100 text-red-800',
'sort_order': 2,
'http_status': 500
"color": "red",
"icon": "x",
"css_class": "bg-red-100 text-red-800",
"sort_order": 2,
"http_status": 500,
},
category=ChoiceCategory.STATUS
category=ChoiceCategory.STATUS,
),
]
@@ -75,52 +75,52 @@ ENTITY_TYPES = [
label="Park",
description="Theme parks and amusement parks",
metadata={
'color': 'green',
'icon': 'map-pin',
'css_class': 'bg-green-100 text-green-800',
'sort_order': 1,
'search_weight': 1.0
"color": "green",
"icon": "map-pin",
"css_class": "bg-green-100 text-green-800",
"sort_order": 1,
"search_weight": 1.0,
},
category=ChoiceCategory.CLASSIFICATION
category=ChoiceCategory.CLASSIFICATION,
),
RichChoice(
value="ride",
label="Ride",
description="Individual rides and attractions",
metadata={
'color': 'blue',
'icon': 'activity',
'css_class': 'bg-blue-100 text-blue-800',
'sort_order': 2,
'search_weight': 1.0
"color": "blue",
"icon": "activity",
"css_class": "bg-blue-100 text-blue-800",
"sort_order": 2,
"search_weight": 1.0,
},
category=ChoiceCategory.CLASSIFICATION
category=ChoiceCategory.CLASSIFICATION,
),
RichChoice(
value="company",
label="Company",
description="Manufacturers, operators, and designers",
metadata={
'color': 'purple',
'icon': 'building',
'css_class': 'bg-purple-100 text-purple-800',
'sort_order': 3,
'search_weight': 0.8
"color": "purple",
"icon": "building",
"css_class": "bg-purple-100 text-purple-800",
"sort_order": 3,
"search_weight": 0.8,
},
category=ChoiceCategory.CLASSIFICATION
category=ChoiceCategory.CLASSIFICATION,
),
RichChoice(
value="user",
label="User",
description="User profiles and accounts",
metadata={
'color': 'orange',
'icon': 'user',
'css_class': 'bg-orange-100 text-orange-800',
'sort_order': 4,
'search_weight': 0.5
"color": "orange",
"icon": "user",
"css_class": "bg-orange-100 text-orange-800",
"sort_order": 4,
"search_weight": 0.5,
},
category=ChoiceCategory.CLASSIFICATION
category=ChoiceCategory.CLASSIFICATION,
),
]
@@ -133,7 +133,7 @@ def register_core_choices():
choices=HEALTH_STATUSES,
domain="core",
description="Health check status options",
metadata={'domain': 'core', 'type': 'health_status'}
metadata={"domain": "core", "type": "health_status"},
)
register_choices(
@@ -141,7 +141,7 @@ def register_core_choices():
choices=SIMPLE_HEALTH_STATUSES,
domain="core",
description="Simple health check status options",
metadata={'domain': 'core', 'type': 'simple_health_status'}
metadata={"domain": "core", "type": "simple_health_status"},
)
register_choices(
@@ -149,7 +149,7 @@ def register_core_choices():
choices=ENTITY_TYPES,
domain="core",
description="Entity type classifications for search functionality",
metadata={'domain': 'core', 'type': 'entity_type'}
metadata={"domain": "core", "type": "entity_type"},
)

View File

@@ -23,12 +23,7 @@ class RichChoiceField(models.CharField):
"""
def __init__(
self,
choice_group: str,
domain: str = "core",
max_length: int = 50,
allow_deprecated: bool = False,
**kwargs
self, choice_group: str, domain: str = "core", max_length: int = 50, allow_deprecated: bool = False, **kwargs
):
"""
Initialize the RichChoiceField.
@@ -52,8 +47,8 @@ class RichChoiceField(models.CharField):
choices = [(choice.value, choice.label) for choice in choices_list]
kwargs['choices'] = choices
kwargs['max_length'] = max_length
kwargs["choices"] = choices
kwargs["max_length"] = max_length
super().__init__(**kwargs)
@@ -61,21 +56,17 @@ class RichChoiceField(models.CharField):
"""Validate the choice value"""
super().validate(value, model_instance)
if value is None or value == '':
if value is None or value == "":
return
# Check if choice exists in registry
choice = registry.get_choice(self.choice_group, value, self.domain)
if choice is None:
raise ValidationError(
f"'{value}' is not a valid choice for {self.choice_group}"
)
raise ValidationError(f"'{value}' is not a valid choice for {self.choice_group}")
# Check if deprecated choices are allowed
if choice.deprecated and not self.allow_deprecated:
raise ValidationError(
f"'{value}' is deprecated and cannot be used for new entries"
)
raise ValidationError(f"'{value}' is deprecated and cannot be used for new entries")
def get_rich_choice(self, value: str) -> RichChoice | None:
"""Get the RichChoice object for a value"""
@@ -94,21 +85,21 @@ class RichChoiceField(models.CharField):
value = getattr(instance, name)
return self.get_rich_choice(value) if value else None
setattr(cls, f'get_{name}_rich_choice', get_rich_choice_method)
setattr(cls, f"get_{name}_rich_choice", get_rich_choice_method)
# Add get_FOO_display method (Django provides this, but we enhance it)
def get_display_method(instance):
value = getattr(instance, name)
return self.get_choice_display(value) if value else ''
return self.get_choice_display(value) if value else ""
setattr(cls, f'get_{name}_display', get_display_method)
setattr(cls, f"get_{name}_display", get_display_method)
def deconstruct(self):
"""Support for Django migrations"""
name, path, args, kwargs = super().deconstruct()
kwargs['choice_group'] = self.choice_group
kwargs['domain'] = self.domain
kwargs['allow_deprecated'] = self.allow_deprecated
kwargs["choice_group"] = self.choice_group
kwargs["domain"] = self.domain
kwargs["allow_deprecated"] = self.allow_deprecated
return name, path, args, kwargs
@@ -123,7 +114,7 @@ class RichChoiceFormField(ChoiceField):
domain: str = "core",
allow_deprecated: bool = False,
show_descriptions: bool = False,
**kwargs
**kwargs,
):
"""
Initialize the form field.
@@ -154,36 +145,28 @@ class RichChoiceFormField(ChoiceField):
label = f"{choice.label} - {choice.description}"
choices.append((choice.value, label))
kwargs['choices'] = choices
kwargs["choices"] = choices
super().__init__(**kwargs)
def validate(self, value: Any) -> None:
"""Validate the choice value"""
super().validate(value)
if value is None or value == '':
if value is None or value == "":
return
# Check if choice exists in registry
choice = registry.get_choice(self.choice_group, value, self.domain)
if choice is None:
raise ValidationError(
f"'{value}' is not a valid choice for {self.choice_group}"
)
raise ValidationError(f"'{value}' is not a valid choice for {self.choice_group}")
# Check if deprecated choices are allowed
if choice.deprecated and not self.allow_deprecated:
raise ValidationError(
f"'{value}' is deprecated and cannot be used"
)
raise ValidationError(f"'{value}' is deprecated and cannot be used")
def create_rich_choice_field(
choice_group: str,
domain: str = "core",
max_length: int = 50,
allow_deprecated: bool = False,
**kwargs
choice_group: str, domain: str = "core", max_length: int = 50, allow_deprecated: bool = False, **kwargs
) -> RichChoiceField:
"""
Factory function to create a RichChoiceField.
@@ -192,9 +175,5 @@ def create_rich_choice_field(
across multiple models.
"""
return RichChoiceField(
choice_group=choice_group,
domain=domain,
max_length=max_length,
allow_deprecated=allow_deprecated,
**kwargs
choice_group=choice_group, domain=domain, max_length=max_length, allow_deprecated=allow_deprecated, **kwargs
)

View File

@@ -29,7 +29,7 @@ class ChoiceRegistry:
choices: list[RichChoice],
domain: str = "core",
description: str = "",
metadata: dict[str, Any] | None = None
metadata: dict[str, Any] | None = None,
) -> ChoiceGroup:
"""
Register a group of choices.
@@ -65,12 +65,7 @@ class ChoiceRegistry:
f"Existing: {existing_values}, New: {new_values}"
)
choice_group = ChoiceGroup(
name=full_name,
choices=choices,
description=description,
metadata=metadata or {}
)
choice_group = ChoiceGroup(name=full_name, choices=choices, description=description, metadata=metadata or {})
self._choices[full_name] = choice_group
@@ -103,7 +98,6 @@ class ChoiceRegistry:
choice_group = self.get(name, domain)
return choice_group.get_active_choices() if choice_group else []
def get_domains(self) -> list[str]:
"""Get all registered domains"""
return list(self._domains.keys())
@@ -113,10 +107,7 @@ class ChoiceRegistry:
if domain not in self._domains:
return {}
return {
name: self._choices[f"{domain}.{name}"]
for name in self._domains[domain]
}
return {name: self._choices[f"{domain}.{name}"] for name in self._domains[domain]}
def list_all(self) -> dict[str, ChoiceGroup]:
"""Get all registered choice groups"""
@@ -159,7 +150,7 @@ def register_choices(
choices: list[RichChoice],
domain: str = "core",
description: str = "",
metadata: dict[str, Any] | None = None
metadata: dict[str, Any] | None = None,
) -> ChoiceGroup:
"""
Convenience function to register choices with the global registry.
@@ -187,8 +178,6 @@ def get_choice(group_name: str, value: str, domain: str = "core") -> RichChoice
return registry.get_choice(group_name, value, domain)
def validate_choice(group_name: str, value: str, domain: str = "core") -> bool:
"""Validate a choice value using the global registry"""
return registry.validate_choice(group_name, value, domain)

View File

@@ -20,6 +20,7 @@ class RichChoiceSerializer(serializers.Serializer):
This provides a consistent API representation for choice objects
with all their metadata.
"""
value = serializers.CharField()
label = serializers.CharField()
description = serializers.CharField()
@@ -42,6 +43,7 @@ class RichChoiceOptionSerializer(serializers.Serializer):
This replaces the legacy FilterOptionSerializer with rich choice support.
"""
value = serializers.CharField()
label = serializers.CharField()
description = serializers.CharField(allow_blank=True)
@@ -58,30 +60,30 @@ class RichChoiceOptionSerializer(serializers.Serializer):
if isinstance(instance, RichChoice):
# Convert RichChoice to option format
return {
'value': instance.value,
'label': instance.label,
'description': instance.description,
'count': None,
'selected': False,
'deprecated': instance.deprecated,
'color': instance.color,
'icon': instance.icon,
'css_class': instance.css_class,
'metadata': instance.metadata,
"value": instance.value,
"label": instance.label,
"description": instance.description,
"count": None,
"selected": False,
"deprecated": instance.deprecated,
"color": instance.color,
"icon": instance.icon,
"css_class": instance.css_class,
"metadata": instance.metadata,
}
elif isinstance(instance, dict):
# Handle dictionary input (for backwards compatibility)
return {
'value': instance.get('value', ''),
'label': instance.get('label', ''),
'description': instance.get('description', ''),
'count': instance.get('count'),
'selected': instance.get('selected', False),
'deprecated': instance.get('deprecated', False),
'color': instance.get('color'),
'icon': instance.get('icon'),
'css_class': instance.get('css_class'),
'metadata': instance.get('metadata', {}),
"value": instance.get("value", ""),
"label": instance.get("label", ""),
"description": instance.get("description", ""),
"count": instance.get("count"),
"selected": instance.get("selected", False),
"deprecated": instance.get("deprecated", False),
"color": instance.get("color"),
"icon": instance.get("icon"),
"css_class": instance.get("css_class"),
"metadata": instance.get("metadata", {}),
}
else:
return super().to_representation(instance)
@@ -94,6 +96,7 @@ class ChoiceGroupSerializer(serializers.Serializer):
This provides API representation for entire choice groups
with all their choices and metadata.
"""
name = serializers.CharField()
description = serializers.CharField()
metadata = serializers.DictField()
@@ -112,13 +115,7 @@ class RichChoiceFieldSerializer(serializers.CharField):
include rich choice metadata in the response.
"""
def __init__(
self,
choice_group: str,
domain: str = "core",
include_metadata: bool = False,
**kwargs
):
def __init__(self, choice_group: str, domain: str = "core", include_metadata: bool = False, **kwargs):
"""
Initialize the serializer field.
@@ -146,16 +143,16 @@ class RichChoiceFieldSerializer(serializers.CharField):
else:
# Fallback for unknown values
return {
'value': value,
'label': value,
'description': '',
'metadata': {},
'deprecated': False,
'category': 'other',
'color': None,
'icon': None,
'css_class': None,
'sort_order': 0,
"value": value,
"label": value,
"description": "",
"metadata": {},
"deprecated": False,
"category": "other",
"color": None,
"icon": None,
"css_class": None,
"sort_order": 0,
}
else:
# Return just the value
@@ -163,20 +160,16 @@ class RichChoiceFieldSerializer(serializers.CharField):
def to_internal_value(self, data: Any) -> str:
"""Convert input data to choice value"""
if isinstance(data, dict) and 'value' in data:
if isinstance(data, dict) and "value" in data:
# Handle rich choice object input
return data['value']
return data["value"]
else:
# Handle string input
return super().to_internal_value(data)
def create_choice_options_serializer(
choice_group: str,
domain: str = "core",
include_counts: bool = False,
queryset=None,
count_field: str = 'id'
choice_group: str, domain: str = "core", include_counts: bool = False, queryset=None, count_field: str = "id"
) -> list[dict[str, Any]]:
"""
Create choice options for filter endpoints.
@@ -199,47 +192,44 @@ def create_choice_options_serializer(
for choice in choices:
option_data = {
'value': choice.value,
'label': choice.label,
'description': choice.description,
'selected': False,
'deprecated': choice.deprecated,
'color': choice.color,
'icon': choice.icon,
'css_class': choice.css_class,
'metadata': choice.metadata,
"value": choice.value,
"label": choice.label,
"description": choice.description,
"selected": False,
"deprecated": choice.deprecated,
"color": choice.color,
"icon": choice.icon,
"css_class": choice.css_class,
"metadata": choice.metadata,
}
if include_counts and queryset is not None:
# Count items for this choice
try:
count = queryset.filter(**{count_field: choice.value}).count()
option_data['count'] = count
option_data["count"] = count
except Exception:
# If counting fails, set count to None
option_data['count'] = None
option_data["count"] = None
else:
option_data['count'] = None
option_data["count"] = None
options.append(option_data)
# Sort by sort_order, then by label
options.sort(key=lambda x: (
(lambda c: c.sort_order if (c is not None and hasattr(c, 'sort_order')) else 0)(
registry.get_choice(choice_group, x['value'], domain)
),
x['label']
))
options.sort(
key=lambda x: (
(lambda c: c.sort_order if (c is not None and hasattr(c, "sort_order")) else 0)(
registry.get_choice(choice_group, x["value"], domain)
),
x["label"],
)
)
return options
def serialize_choice_value(
value: str,
choice_group: str,
domain: str = "core",
include_metadata: bool = False
) -> Any:
def serialize_choice_value(value: str, choice_group: str, domain: str = "core", include_metadata: bool = False) -> Any:
"""
Serialize a single choice value.
@@ -262,16 +252,16 @@ def serialize_choice_value(
else:
# Fallback for unknown values
return {
'value': value,
'label': value,
'description': '',
'metadata': {},
'deprecated': False,
'category': 'other',
'color': None,
'icon': None,
'css_class': None,
'sort_order': 0,
"value": value,
"label": value,
"description": "",
"metadata": {},
"deprecated": False,
"category": "other",
"color": None,
"icon": None,
"css_class": None,
"sort_order": 0,
}
else:
return value

View File

@@ -10,12 +10,7 @@ from .base import ChoiceCategory, RichChoice
from .registry import registry
def validate_choice_value(
value: str,
choice_group: str,
domain: str = "core",
allow_deprecated: bool = False
) -> bool:
def validate_choice_value(value: str, choice_group: str, domain: str = "core", allow_deprecated: bool = False) -> bool:
"""
Validate that a choice value is valid for a given choice group.
@@ -38,11 +33,7 @@ def validate_choice_value(
return not (choice.deprecated and not allow_deprecated)
def get_choice_display(
value: str,
choice_group: str,
domain: str = "core"
) -> str:
def get_choice_display(value: str, choice_group: str, domain: str = "core") -> str:
"""
Get the display label for a choice value.
@@ -67,11 +58,8 @@ def get_choice_display(
raise ValueError(f"Choice value '{value}' not found in group '{choice_group}' for domain '{domain}'")
def create_status_choices(
statuses: dict[str, dict[str, Any]],
category: ChoiceCategory = ChoiceCategory.STATUS
statuses: dict[str, dict[str, Any]], category: ChoiceCategory = ChoiceCategory.STATUS
) -> list[RichChoice]:
"""
Create status choices with consistent color coding.
@@ -86,28 +74,28 @@ def create_status_choices(
choices = []
for value, config in statuses.items():
metadata = config.get('metadata', {})
metadata = config.get("metadata", {})
# Add default status colors if not specified
if 'color' not in metadata:
if 'operating' in value.lower() or 'active' in value.lower():
metadata['color'] = 'green'
elif 'closed' in value.lower() or 'inactive' in value.lower():
metadata['color'] = 'red'
elif 'temp' in value.lower() or 'pending' in value.lower():
metadata['color'] = 'yellow'
elif 'construction' in value.lower():
metadata['color'] = 'blue'
if "color" not in metadata:
if "operating" in value.lower() or "active" in value.lower():
metadata["color"] = "green"
elif "closed" in value.lower() or "inactive" in value.lower():
metadata["color"] = "red"
elif "temp" in value.lower() or "pending" in value.lower():
metadata["color"] = "yellow"
elif "construction" in value.lower():
metadata["color"] = "blue"
else:
metadata['color'] = 'gray'
metadata["color"] = "gray"
choice = RichChoice(
value=value,
label=config['label'],
description=config.get('description', ''),
label=config["label"],
description=config.get("description", ""),
metadata=metadata,
deprecated=config.get('deprecated', False),
category=category
deprecated=config.get("deprecated", False),
category=category,
)
choices.append(choice)
@@ -115,8 +103,7 @@ def create_status_choices(
def create_type_choices(
types: dict[str, dict[str, Any]],
category: ChoiceCategory = ChoiceCategory.TYPE
types: dict[str, dict[str, Any]], category: ChoiceCategory = ChoiceCategory.TYPE
) -> list[RichChoice]:
"""
Create type/classification choices.
@@ -133,21 +120,18 @@ def create_type_choices(
for value, config in types.items():
choice = RichChoice(
value=value,
label=config['label'],
description=config.get('description', ''),
metadata=config.get('metadata', {}),
deprecated=config.get('deprecated', False),
category=category
label=config["label"],
description=config.get("description", ""),
metadata=config.get("metadata", {}),
deprecated=config.get("deprecated", False),
category=category,
)
choices.append(choice)
return choices
def merge_choice_metadata(
base_metadata: dict[str, Any],
override_metadata: dict[str, Any]
) -> dict[str, Any]:
def merge_choice_metadata(base_metadata: dict[str, Any], override_metadata: dict[str, Any]) -> dict[str, Any]:
"""
Merge choice metadata dictionaries.
@@ -163,10 +147,7 @@ def merge_choice_metadata(
return merged
def filter_choices_by_category(
choices: list[RichChoice],
category: ChoiceCategory
) -> list[RichChoice]:
def filter_choices_by_category(choices: list[RichChoice], category: ChoiceCategory) -> list[RichChoice]:
"""
Filter choices by category.
@@ -180,10 +161,7 @@ def filter_choices_by_category(
return [choice for choice in choices if choice.category == category]
def sort_choices(
choices: list[RichChoice],
sort_by: str = "sort_order"
) -> list[RichChoice]:
def sort_choices(choices: list[RichChoice], sort_by: str = "sort_order") -> list[RichChoice]:
"""
Sort choices by specified criteria.
@@ -204,10 +182,7 @@ def sort_choices(
return choices
def get_choice_colors(
choice_group: str,
domain: str = "core"
) -> dict[str, str]:
def get_choice_colors(choice_group: str, domain: str = "core") -> dict[str, str]:
"""
Get a mapping of choice values to their colors.
@@ -219,18 +194,10 @@ def get_choice_colors(
Dictionary mapping choice values to colors
"""
choices = registry.get_choices(choice_group, domain)
return {
choice.value: choice.color
for choice in choices
if choice.color
}
return {choice.value: choice.color for choice in choices if choice.color}
def validate_choice_group_data(
name: str,
choices: list[RichChoice],
domain: str = "core"
) -> list[str]:
def validate_choice_group_data(name: str, choices: list[RichChoice], domain: str = "core") -> list[str]:
"""
Validate choice group data and return list of errors.
@@ -267,7 +234,7 @@ def validate_choice_group_data(
description=choice.description,
metadata=choice.metadata,
deprecated=choice.deprecated,
category=choice.category
category=choice.category,
)
except ValueError as e:
errors.append(f"Choice {i}: {str(e)}")
@@ -286,19 +253,16 @@ def create_choice_from_config(config: dict[str, Any]) -> RichChoice:
RichChoice object
"""
return RichChoice(
value=config['value'],
label=config['label'],
description=config.get('description', ''),
metadata=config.get('metadata', {}),
deprecated=config.get('deprecated', False),
category=ChoiceCategory(config.get('category', 'other'))
value=config["value"],
label=config["label"],
description=config.get("description", ""),
metadata=config.get("metadata", {}),
deprecated=config.get("deprecated", False),
category=ChoiceCategory(config.get("category", "other")),
)
def export_choices_to_dict(
choice_group: str,
domain: str = "core"
) -> dict[str, Any]:
def export_choices_to_dict(choice_group: str, domain: str = "core") -> dict[str, Any]:
"""
Export a choice group to a dictionary format.

View File

@@ -48,11 +48,7 @@ def cache_api_response(
cache_key_parts = [
key_prefix,
view_func.__name__,
(
str(getattr(request.user, "id", "anonymous"))
if request.user.is_authenticated
else "anonymous"
),
(str(getattr(request.user, "id", "anonymous")) if request.user.is_authenticated else "anonymous"),
str(hash(frozenset(request.GET.items()))),
]
@@ -72,9 +68,7 @@ def cache_api_response(
# Try to get from cache
cache_service = EnhancedCacheService()
cached_response = getattr(cache_service, cache_backend + "_cache").get(
cache_key
)
cached_response = getattr(cache_service, cache_backend + "_cache").get(cache_key)
if cached_response:
logger.debug(
@@ -87,11 +81,8 @@ def cache_api_response(
)
# If cached data is our dict format for DRF responses, reconstruct it
if isinstance(cached_response, dict) and '__drf_data__' in cached_response:
return DRFResponse(
data=cached_response['__drf_data__'],
status=cached_response.get('status', 200)
)
if isinstance(cached_response, dict) and "__drf_data__" in cached_response:
return DRFResponse(data=cached_response["__drf_data__"], status=cached_response.get("status", 200))
return cached_response
@@ -104,17 +95,12 @@ def cache_api_response(
if hasattr(response, "status_code") and response.status_code == 200:
# For DRF responses, we must cache the data, not the response object
# because the response object is not rendered yet and cannot be pickled
if hasattr(response, 'data'):
cache_payload = {
'__drf_data__': response.data,
'status': response.status_code
}
if hasattr(response, "data"):
cache_payload = {"__drf_data__": response.data, "status": response.status_code}
else:
cache_payload = response
getattr(cache_service, cache_backend + "_cache").set(
cache_key, cache_payload, timeout
)
getattr(cache_service, cache_backend + "_cache").set(cache_key, cache_payload, timeout)
logger.debug(
f"Cached API response for view {view_func.__name__}",
extra={
@@ -162,9 +148,7 @@ def cache_queryset_result(
cache_key = f"{cache_key_template}:{hash(str(args) + str(kwargs))}"
cache_service = EnhancedCacheService()
cached_result = getattr(cache_service, cache_backend + "_cache").get(
cache_key
)
cached_result = getattr(cache_service, cache_backend + "_cache").get(cache_key)
if cached_result is not None:
logger.debug(f"Cache hit for queryset operation: {func.__name__}")
@@ -175,9 +159,7 @@ def cache_queryset_result(
result = func(*args, **kwargs)
execution_time = time.time() - start_time
getattr(cache_service, cache_backend + "_cache").set(
cache_key, result, timeout
)
getattr(cache_service, cache_backend + "_cache").set(cache_key, result, timeout)
logger.debug(
f"Cached queryset result for {func.__name__}",
extra={
@@ -250,24 +232,18 @@ class CachedAPIViewMixin(View):
cache_backend = "api"
@method_decorator(vary_on_headers("User-Agent", "Accept-Language"))
def dispatch(
self, request: HttpRequest, *args: Any, **kwargs: Any
) -> HttpResponseBase:
def dispatch(self, request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponseBase:
"""Add caching to the dispatch method"""
if request.method == "GET" and getattr(self, "enable_caching", True):
return self._cached_dispatch(request, *args, **kwargs)
return super().dispatch(request, *args, **kwargs)
def _cached_dispatch(
self, request: HttpRequest, *args: Any, **kwargs: Any
) -> HttpResponseBase:
def _cached_dispatch(self, request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponseBase:
"""Handle cached dispatch for GET requests"""
cache_key = self._generate_cache_key(request, *args, **kwargs)
cache_service = EnhancedCacheService()
cached_response = getattr(cache_service, self.cache_backend + "_cache").get(
cache_key
)
cached_response = getattr(cache_service, self.cache_backend + "_cache").get(cache_key)
if cached_response:
logger.debug(f"Cache hit for view {self.__class__.__name__}")
@@ -278,26 +254,18 @@ class CachedAPIViewMixin(View):
# Cache successful responses
if hasattr(response, "status_code") and response.status_code == 200:
getattr(cache_service, self.cache_backend + "_cache").set(
cache_key, response, self.cache_timeout
)
getattr(cache_service, self.cache_backend + "_cache").set(cache_key, response, self.cache_timeout)
logger.debug(f"Cached response for view {self.__class__.__name__}")
return response
def _generate_cache_key(
self, request: HttpRequest, *args: Any, **kwargs: Any
) -> str:
def _generate_cache_key(self, request: HttpRequest, *args: Any, **kwargs: Any) -> str:
"""Generate cache key for the request"""
key_parts = [
self.cache_key_prefix,
self.__class__.__name__,
request.method,
(
str(getattr(request.user, "id", "anonymous"))
if request.user.is_authenticated
else "anonymous"
),
(str(getattr(request.user, "id", "anonymous")) if request.user.is_authenticated else "anonymous"),
str(hash(frozenset(request.GET.items()))),
]
@@ -344,15 +312,11 @@ def smart_cache(
"kwargs": json.dumps(kwargs, sort_keys=True, default=str),
}
key_string = json.dumps(key_data, sort_keys=True)
cache_key = (
f"smart_cache:{hashlib.md5(key_string.encode()).hexdigest()}"
)
cache_key = f"smart_cache:{hashlib.md5(key_string.encode()).hexdigest()}"
# Try to get from cache
cache_service = EnhancedCacheService()
cached_result = getattr(cache_service, cache_backend + "_cache").get(
cache_key
)
cached_result = getattr(cache_service, cache_backend + "_cache").get(cache_key)
if cached_result is not None:
logger.debug(f"Smart cache hit for {func.__name__}")
@@ -364,9 +328,7 @@ def smart_cache(
execution_time = time.time() - start_time
# Cache result
getattr(cache_service, cache_backend + "_cache").set(
cache_key, result, timeout
)
getattr(cache_service, cache_backend + "_cache").set(cache_key, result, timeout)
logger.debug(
f"Smart cached result for {func.__name__}",
@@ -426,16 +388,10 @@ def generate_model_cache_key(model_instance: Any, suffix: str = "") -> str:
"""Generate cache key based on model instance"""
model_name = model_instance._meta.model_name
instance_id = model_instance.id
return (
f"{model_name}:{instance_id}:{suffix}"
if suffix
else f"{model_name}:{instance_id}"
)
return f"{model_name}:{instance_id}:{suffix}" if suffix else f"{model_name}:{instance_id}"
def generate_queryset_cache_key(
queryset: Any, params: dict[str, Any] | None = None
) -> str:
def generate_queryset_cache_key(queryset: Any, params: dict[str, Any] | None = None) -> str:
"""Generate cache key for queryset with parameters"""
model_name = queryset.model._meta.model_name
params_str = json.dumps(params or {}, sort_keys=True, default=str)

View File

@@ -22,9 +22,7 @@ class BaseAutocomplete(Autocomplete):
# UI text configuration using gettext for i18n
no_result_text = _("No matches found")
narrow_search_text = _(
"Showing %(page_size)s of %(total)s matches. Please refine your search."
)
narrow_search_text = _("Showing %(page_size)s of %(total)s matches. Please refine your search.")
type_at_least_n_characters = _("Type at least %(n)s characters...")
# Project-wide component settings

View File

@@ -1,6 +1,7 @@
"""
Base forms and views for HTMX integration.
"""
from django.http import JsonResponse
from django.views.generic.edit import FormView
@@ -20,9 +21,6 @@ class HTMXFormView(FormView):
def post(self, request, *args, **kwargs):
# If HTMX field validation pattern: ?field=name
if (
request.headers.get("HX-Request") == "true"
and request.GET.get("validate_field")
):
if request.headers.get("HX-Request") == "true" and request.GET.get("validate_field"):
return self.validate_field(request.GET.get("validate_field"))
return super().post(request, *args, **kwargs)

View File

@@ -42,12 +42,8 @@ class LocationSearchForm(forms.Form):
)
# Hidden fields for coordinates
lat = forms.FloatField(
required=False, widget=forms.HiddenInput(attrs={"id": "lat-input"})
)
lng = forms.FloatField(
required=False, widget=forms.HiddenInput(attrs={"id": "lng-input"})
)
lat = forms.FloatField(required=False, widget=forms.HiddenInput(attrs={"id": "lat-input"}))
lng = forms.FloatField(required=False, widget=forms.HiddenInput(attrs={"id": "lng-input"}))
# Search radius
radius_km = forms.ChoiceField(
@@ -81,8 +77,7 @@ class LocationSearchForm(forms.Form):
widget=forms.CheckboxInput(
attrs={
"class": (
"rounded border-gray-300 text-blue-600 focus:ring-blue-500 "
"dark:border-gray-600 dark:bg-gray-700"
"rounded border-gray-300 text-blue-600 focus:ring-blue-500 " "dark:border-gray-600 dark:bg-gray-700"
)
}
),
@@ -93,8 +88,7 @@ class LocationSearchForm(forms.Form):
widget=forms.CheckboxInput(
attrs={
"class": (
"rounded border-gray-300 text-blue-600 focus:ring-blue-500 "
"dark:border-gray-600 dark:bg-gray-700"
"rounded border-gray-300 text-blue-600 focus:ring-blue-500 " "dark:border-gray-600 dark:bg-gray-700"
)
}
),
@@ -105,8 +99,7 @@ class LocationSearchForm(forms.Form):
widget=forms.CheckboxInput(
attrs={
"class": (
"rounded border-gray-300 text-blue-600 focus:ring-blue-500 "
"dark:border-gray-600 dark:bg-gray-700"
"rounded border-gray-300 text-blue-600 focus:ring-blue-500 " "dark:border-gray-600 dark:bg-gray-700"
)
}
),

View File

@@ -58,13 +58,9 @@ class CacheHealthCheck(BaseHealthCheckBackend):
if max_memory > 0:
memory_usage_percent = (used_memory / max_memory) * 100
if memory_usage_percent > 90:
self.add_error(
f"Redis memory usage critical: {memory_usage_percent:.1f}%"
)
self.add_error(f"Redis memory usage critical: {memory_usage_percent:.1f}%")
elif memory_usage_percent > 80:
logger.warning(
f"Redis memory usage high: {memory_usage_percent:.1f}%"
)
logger.warning(f"Redis memory usage high: {memory_usage_percent:.1f}%")
except ImportError:
# django-redis not available, skip additional checks
@@ -160,9 +156,7 @@ class ApplicationHealthCheck(BaseHealthCheckBackend):
try:
__import__(module_name)
except ImportError as e:
self.add_error(
f"Critical module import failed: {module_name} - {e}"
)
self.add_error(f"Critical module import failed: {module_name} - {e}")
# Check if we can access critical models
try:
@@ -179,9 +173,7 @@ class ApplicationHealthCheck(BaseHealthCheckBackend):
ride_count = Ride.objects.count()
user_count = User.objects.count()
logger.debug(
f"Model counts - Parks: {park_count}, Rides: {ride_count}, Users: {user_count}"
)
logger.debug(f"Model counts - Parks: {park_count}, Rides: {ride_count}, Users: {user_count}")
except Exception as e:
self.add_error(f"Model access check failed: {e}")
@@ -195,9 +187,7 @@ class ApplicationHealthCheck(BaseHealthCheckBackend):
self.add_error(f"Media directory does not exist: {settings.MEDIA_ROOT}")
if not os.path.exists(settings.STATIC_ROOT) and not settings.DEBUG:
self.add_error(
f"Static directory does not exist: {settings.STATIC_ROOT}"
)
self.add_error(f"Static directory does not exist: {settings.STATIC_ROOT}")
except Exception as e:
self.add_error(f"Application health check failed: {e}")
@@ -214,10 +204,7 @@ class ExternalServiceHealthCheck(BaseHealthCheckBackend):
from django.conf import settings
from django.core.mail import get_connection
if (
hasattr(settings, "EMAIL_BACKEND")
and "console" not in settings.EMAIL_BACKEND
):
if hasattr(settings, "EMAIL_BACKEND") and "console" not in settings.EMAIL_BACKEND:
# Only check if not using console backend
connection = get_connection()
if hasattr(connection, "open"):
@@ -304,9 +291,7 @@ class DiskSpaceHealthCheck(BaseHealthCheckBackend):
media_free_percent:.1f}% free in media directory"
)
elif media_free_percent < 20:
logger.warning(
f"Low disk space: {media_free_percent:.1f}% free in media directory"
)
logger.warning(f"Low disk space: {media_free_percent:.1f}% free in media directory")
if logs_free_percent < 10:
self.add_error(
@@ -314,9 +299,7 @@ class DiskSpaceHealthCheck(BaseHealthCheckBackend):
logs_free_percent:.1f}% free in logs directory"
)
elif logs_free_percent < 20:
logger.warning(
f"Low disk space: {logs_free_percent:.1f}% free in logs directory"
)
logger.warning(f"Low disk space: {logs_free_percent:.1f}% free in logs directory")
except Exception as e:
logger.warning(f"Disk space check failed: {e}")

View File

@@ -94,9 +94,7 @@ class TrackedModel(models.Model):
event_model = getattr(events, "model", None)
if event_model:
return event_model.objects.filter(pgh_obj_id=self.pk).order_by(
"-pgh_created_at"
)
return event_model.objects.filter(pgh_obj_id=self.pk).order_by("-pgh_created_at")
except (AttributeError, TypeError):
pass
return self.__class__.objects.none()

View File

@@ -23,9 +23,7 @@ class ThrillWikiFormatter(logging.Formatter):
if hasattr(record, "request"):
record.request_id = getattr(record.request, "id", "unknown")
record.user_id = (
getattr(record.request.user, "id", "anonymous")
if hasattr(record.request, "user")
else "unknown"
getattr(record.request.user, "id", "anonymous") if hasattr(record.request, "user") else "unknown"
)
record.path = getattr(record.request, "path", "unknown")
record.method = getattr(record.request, "method", "unknown")
@@ -52,9 +50,7 @@ def get_logger(name: str) -> logging.Logger:
# Only configure if not already configured
if not logger.handlers:
handler = logging.StreamHandler(sys.stdout)
formatter = ThrillWikiFormatter(
fmt="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
formatter = ThrillWikiFormatter(fmt="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
handler.setFormatter(formatter)
logger.addHandler(handler)
logger.setLevel(logging.INFO if settings.DEBUG else logging.WARNING)
@@ -91,11 +87,7 @@ def log_exception(
{
"request_path": getattr(request, "path", "unknown"),
"request_method": getattr(request, "method", "unknown"),
"user_id": (
getattr(request.user, "id", "anonymous")
if hasattr(request, "user")
else "unknown"
),
"user_id": (getattr(request.user, "id", "anonymous") if hasattr(request, "user") else "unknown"),
}
)
@@ -134,11 +126,7 @@ def log_business_event(
{
"request_path": getattr(request, "path", "unknown"),
"request_method": getattr(request, "method", "unknown"),
"user_id": (
getattr(request.user, "id", "anonymous")
if hasattr(request, "user")
else "unknown"
),
"user_id": (getattr(request.user, "id", "anonymous") if hasattr(request, "user") else "unknown"),
}
)
@@ -196,11 +184,7 @@ def log_api_request(
"request_type": "api",
"path": getattr(request, "path", "unknown"),
"method": getattr(request, "method", "unknown"),
"user_id": (
getattr(request.user, "id", "anonymous")
if hasattr(request, "user")
else "unknown"
),
"user_id": (getattr(request.user, "id", "anonymous") if hasattr(request, "user") else "unknown"),
"response_status": response_status,
"duration_ms": duration_ms,
}
@@ -246,11 +230,7 @@ def log_security_event(
{
"request_path": getattr(request, "path", "unknown"),
"request_method": getattr(request, "method", "unknown"),
"user_id": (
getattr(request.user, "id", "anonymous")
if hasattr(request, "user")
else "unknown"
),
"user_id": (getattr(request.user, "id", "anonymous") if hasattr(request, "user") else "unknown"),
"remote_addr": request.META.get("REMOTE_ADDR", "unknown"),
"user_agent": request.META.get("HTTP_USER_AGENT", "unknown"),
}

View File

@@ -43,9 +43,7 @@ class Command(BaseCommand):
default=50,
help="Maximum number of results to calculate (default: 50)",
)
parser.add_argument(
"--verbose", action="store_true", help="Enable verbose output"
)
parser.add_argument("--verbose", action="store_true", help="Enable verbose output")
def handle(self, *args, **options):
content_type = options["content_type"]
@@ -61,17 +59,13 @@ class Command(BaseCommand):
new_items = []
if content_type in ["all", "parks"]:
parks = self._get_new_parks(
cutoff_date, limit if content_type == "parks" else limit * 2
)
parks = self._get_new_parks(cutoff_date, limit if content_type == "parks" else limit * 2)
new_items.extend(parks)
if verbose:
self.stdout.write(f"Found {len(parks)} new parks")
if content_type in ["all", "rides"]:
rides = self._get_new_rides(
cutoff_date, limit if content_type == "rides" else limit * 2
)
rides = self._get_new_rides(cutoff_date, limit if content_type == "rides" else limit * 2)
new_items.extend(rides)
if verbose:
self.stdout.write(f"Found {len(rides)} new rides")
@@ -88,27 +82,22 @@ class Command(BaseCommand):
cache.set(cache_key, formatted_results, 1800) # Cache for 30 minutes
self.stdout.write(
self.style.SUCCESS(
f"Successfully calculated {len(formatted_results)} new items for {content_type}"
)
self.style.SUCCESS(f"Successfully calculated {len(formatted_results)} new items for {content_type}")
)
if verbose:
for item in formatted_results[:5]: # Show first 5 items
self.stdout.write(
f" {item['name']} ({item['park']}) - opened: {item['date_opened']}"
)
self.stdout.write(f" {item['name']} ({item['park']}) - opened: {item['date_opened']}")
except Exception as e:
logger.error(f"Error calculating new content: {e}", exc_info=True)
raise CommandError(f"Failed to calculate new content: {e}")
raise CommandError(f"Failed to calculate new content: {e}") from None
def _get_new_parks(self, cutoff_date: datetime, limit: int) -> list[dict[str, Any]]:
"""Get recently added parks using real data."""
new_parks = (
Park.objects.filter(
Q(created_at__gte=cutoff_date)
| Q(opening_date__gte=cutoff_date.date()),
Q(created_at__gte=cutoff_date) | Q(opening_date__gte=cutoff_date.date()),
status="OPERATING",
)
.select_related("location", "operator")
@@ -146,8 +135,7 @@ class Command(BaseCommand):
"""Get recently added rides using real data."""
new_rides = (
Ride.objects.filter(
Q(created_at__gte=cutoff_date)
| Q(opening_date__gte=cutoff_date.date()),
Q(created_at__gte=cutoff_date) | Q(opening_date__gte=cutoff_date.date()),
status="OPERATING",
)
.select_related("park", "park__location")
@@ -156,9 +144,7 @@ class Command(BaseCommand):
results = []
for ride in new_rides:
date_added = getattr(ride, "opening_date", None) or getattr(
ride, "created_at", None
)
date_added = getattr(ride, "opening_date", None) or getattr(ride, "created_at", None)
if date_added and isinstance(date_added, datetime):
date_added = date_added.date()
@@ -184,9 +170,7 @@ class Command(BaseCommand):
return results
def _format_new_content_results(
self, new_items: list[dict[str, Any]]
) -> list[dict[str, Any]]:
def _format_new_content_results(self, new_items: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""Format new content results for frontend consumption."""
formatted_results = []

View File

@@ -37,9 +37,7 @@ class Command(BaseCommand):
default=50,
help="Maximum number of results to calculate (default: 50)",
)
parser.add_argument(
"--verbose", action="store_true", help="Enable verbose output"
)
parser.add_argument("--verbose", action="store_true", help="Enable verbose output")
def handle(self, *args, **options):
content_type = options["content_type"]
@@ -98,29 +96,23 @@ class Command(BaseCommand):
if verbose:
for item in formatted_results[:5]: # Show first 5 items
self.stdout.write(
f" {item['name']} (score: {item.get('views_change', 'N/A')})"
)
self.stdout.write(f" {item['name']} (score: {item.get('views_change', 'N/A')})")
except Exception as e:
logger.error(f"Error calculating trending content: {e}", exc_info=True)
raise CommandError(f"Failed to calculate trending content: {e}")
raise CommandError(f"Failed to calculate trending content: {e}") from None
def _calculate_trending_parks(
self, current_period_hours: int, previous_period_hours: int, limit: int
) -> list[dict[str, Any]]:
"""Calculate trending scores for parks using real data."""
parks = Park.objects.filter(status="OPERATING").select_related(
"location", "operator"
)
parks = Park.objects.filter(status="OPERATING").select_related("location", "operator")
trending_parks = []
for park in parks:
try:
score = self._calculate_content_score(
park, "park", current_period_hours, previous_period_hours
)
score = self._calculate_content_score(park, "park", current_period_hours, previous_period_hours)
if score > 0: # Only include items with positive trending scores
trending_parks.append(
{
@@ -132,16 +124,8 @@ class Command(BaseCommand):
"slug": park.slug,
"park": park.name, # For parks, park field is the park name itself
"category": "park",
"rating": (
float(park.average_rating)
if park.average_rating
else 0.0
),
"date_opened": (
park.opening_date.isoformat()
if park.opening_date
else ""
),
"rating": (float(park.average_rating) if park.average_rating else 0.0),
"date_opened": (park.opening_date.isoformat() if park.opening_date else ""),
"url": park.url,
}
)
@@ -154,17 +138,13 @@ class Command(BaseCommand):
self, current_period_hours: int, previous_period_hours: int, limit: int
) -> list[dict[str, Any]]:
"""Calculate trending scores for rides using real data."""
rides = Ride.objects.filter(status="OPERATING").select_related(
"park", "park__location"
)
rides = Ride.objects.filter(status="OPERATING").select_related("park", "park__location")
trending_rides = []
for ride in rides:
try:
score = self._calculate_content_score(
ride, "ride", current_period_hours, previous_period_hours
)
score = self._calculate_content_score(ride, "ride", current_period_hours, previous_period_hours)
if score > 0: # Only include items with positive trending scores
trending_rides.append(
{
@@ -176,16 +156,8 @@ class Command(BaseCommand):
"slug": ride.slug,
"park": ride.park.name if ride.park else "",
"category": "ride",
"rating": (
float(ride.average_rating)
if ride.average_rating
else 0.0
),
"date_opened": (
ride.opening_date.isoformat()
if ride.opening_date
else ""
),
"rating": (float(ride.average_rating) if ride.average_rating else 0.0),
"date_opened": (ride.opening_date.isoformat() if ride.opening_date else ""),
"url": ride.url,
"park_url": ride.park.url if ride.park else "",
}
@@ -219,24 +191,15 @@ class Command(BaseCommand):
recency_score = self._calculate_recency_score(content_obj)
# 4. Popularity Score (10% weight)
popularity_score = self._calculate_popularity_score(
ct, content_obj.id, current_period_hours
)
popularity_score = self._calculate_popularity_score(ct, content_obj.id, current_period_hours)
# Calculate weighted final score
final_score = (
view_growth_score * 0.4
+ rating_score * 0.3
+ recency_score * 0.2
+ popularity_score * 0.1
)
final_score = view_growth_score * 0.4 + rating_score * 0.3 + recency_score * 0.2 + popularity_score * 0.1
return final_score
except Exception as e:
logger.error(
f"Error calculating score for {content_type} {content_obj.id}: {e}"
)
logger.error(f"Error calculating score for {content_type} {content_obj.id}: {e}")
return 0.0
def _calculate_view_growth_score(
@@ -248,13 +211,11 @@ class Command(BaseCommand):
) -> float:
"""Calculate normalized view growth score using real PageView data."""
try:
current_views, previous_views, growth_percentage = (
PageView.get_views_growth(
content_type,
object_id,
current_period_hours,
previous_period_hours,
)
current_views, previous_views, growth_percentage = PageView.get_views_growth(
content_type,
object_id,
current_period_hours,
previous_period_hours,
)
if previous_views == 0:
@@ -262,9 +223,7 @@ class Command(BaseCommand):
return min(current_views / 100.0, 1.0) if current_views > 0 else 0.0
# Normalize growth percentage to 0-1 scale
normalized_growth = (
min(growth_percentage / 500.0, 1.0) if growth_percentage > 0 else 0.0
)
normalized_growth = min(growth_percentage / 500.0, 1.0) if growth_percentage > 0 else 0.0
return max(normalized_growth, 0.0)
except Exception as e:
@@ -317,14 +276,10 @@ class Command(BaseCommand):
logger.warning(f"Error calculating recency score: {e}")
return 0.5
def _calculate_popularity_score(
self, content_type: ContentType, object_id: int, hours: int
) -> float:
def _calculate_popularity_score(self, content_type: ContentType, object_id: int, hours: int) -> float:
"""Calculate popularity score based on total view count."""
try:
total_views = PageView.get_total_views_count(
content_type, object_id, hours=hours
)
total_views = PageView.get_total_views_count(content_type, object_id, hours=hours)
# Normalize views to 0-1 scale
if total_views == 0:
@@ -352,13 +307,11 @@ class Command(BaseCommand):
# Get view change for display
content_obj = item["content_object"]
ct = ContentType.objects.get_for_model(content_obj)
current_views, previous_views, growth_percentage = (
PageView.get_views_growth(
ct,
content_obj.id,
current_period_hours,
previous_period_hours,
)
current_views, previous_views, growth_percentage = PageView.get_views_growth(
ct,
content_obj.id,
current_period_hours,
previous_period_hours,
)
# Format exactly as frontend expects
@@ -371,9 +324,7 @@ class Command(BaseCommand):
"rank": rank,
"views": current_views,
"views_change": (
f"+{growth_percentage:.1f}%"
if growth_percentage > 0
else f"{growth_percentage:.1f}%"
f"+{growth_percentage:.1f}%" if growth_percentage > 0 else f"{growth_percentage:.1f}%"
),
"slug": item["slug"],
"date_opened": item["date_opened"],

View File

@@ -21,10 +21,7 @@ from django.core.management.base import BaseCommand
class Command(BaseCommand):
help = (
"Clear all types of cache data including Django cache, "
"__pycache__, and build caches"
)
help = "Clear all types of cache data including Django cache, " "__pycache__, and build caches"
def add_arguments(self, parser):
parser.add_argument(
@@ -92,9 +89,7 @@ class Command(BaseCommand):
)
if self.dry_run:
self.stdout.write(
self.style.WARNING("🔍 DRY RUN MODE - No files will be deleted")
)
self.stdout.write(self.style.WARNING("🔍 DRY RUN MODE - No files will be deleted"))
self.stdout.write("")
self.stdout.write(self.style.SUCCESS("🧹 ThrillWiki Cache Clearing Utility"))
@@ -129,9 +124,7 @@ class Command(BaseCommand):
self.clear_opcache()
self.stdout.write("")
self.stdout.write(
self.style.SUCCESS("✅ Cache clearing completed successfully!")
)
self.stdout.write(self.style.SUCCESS("✅ Cache clearing completed successfully!"))
def clear_django_cache(self):
"""Clear Django cache framework cache."""
@@ -154,23 +147,13 @@ class Command(BaseCommand):
if not self.dry_run:
cache_backend.clear()
cache_info = (
f"{alias} cache ({cache_backend.__class__.__name__})"
)
self.stdout.write(
self.style.SUCCESS(f" ✅ Cleared {cache_info}")
)
cache_info = f"{alias} cache ({cache_backend.__class__.__name__})"
self.stdout.write(self.style.SUCCESS(f" ✅ Cleared {cache_info}"))
except Exception as e:
self.stdout.write(
self.style.WARNING(
f" ⚠️ Could not clear {alias} cache: {e}"
)
)
self.stdout.write(self.style.WARNING(f" ⚠️ Could not clear {alias} cache: {e}"))
except Exception as e:
self.stdout.write(
self.style.ERROR(f" ❌ Error clearing Django cache: {e}")
)
self.stdout.write(self.style.ERROR(f" ❌ Error clearing Django cache: {e}"))
def clear_pycache(self):
"""Clear Python __pycache__ directories and .pyc files."""
@@ -188,11 +171,7 @@ class Command(BaseCommand):
if pycache_dir.is_dir():
try:
# Calculate size before removal
dir_size = sum(
f.stat().st_size
for f in pycache_dir.rglob("*")
if f.is_file()
)
dir_size = sum(f.stat().st_size for f in pycache_dir.rglob("*") if f.is_file())
removed_size += dir_size
if self.verbose:
@@ -203,11 +182,7 @@ class Command(BaseCommand):
removed_count += 1
except Exception as e:
self.stdout.write(
self.style.WARNING(
f" ⚠️ Could not remove {pycache_dir}: {e}"
)
)
self.stdout.write(self.style.WARNING(f" ⚠️ Could not remove {pycache_dir}: {e}"))
# Find and remove .pyc files
for pyc_file in project_root.rglob("*.pyc"):
@@ -223,22 +198,14 @@ class Command(BaseCommand):
removed_count += 1
except Exception as e:
self.stdout.write(
self.style.WARNING(f" ⚠️ Could not remove {pyc_file}: {e}")
)
self.stdout.write(self.style.WARNING(f" ⚠️ Could not remove {pyc_file}: {e}"))
# Format file size
size_mb = removed_size / (1024 * 1024)
self.stdout.write(
self.style.SUCCESS(
f" ✅ Removed {removed_count} Python cache items ({size_mb:.2f} MB)"
)
)
self.stdout.write(self.style.SUCCESS(f" ✅ Removed {removed_count} Python cache items ({size_mb:.2f} MB)"))
except Exception as e:
self.stdout.write(
self.style.ERROR(f" ❌ Error clearing Python cache: {e}")
)
self.stdout.write(self.style.ERROR(f" ❌ Error clearing Python cache: {e}"))
def clear_static_cache(self):
"""Clear static files cache."""
@@ -251,9 +218,7 @@ class Command(BaseCommand):
static_path = Path(static_root)
# Calculate size
total_size = sum(
f.stat().st_size for f in static_path.rglob("*") if f.is_file()
)
total_size = sum(f.stat().st_size for f in static_path.rglob("*") if f.is_file())
size_mb = total_size / (1024 * 1024)
if self.verbose:
@@ -263,22 +228,12 @@ class Command(BaseCommand):
shutil.rmtree(static_path)
static_path.mkdir(parents=True, exist_ok=True)
self.stdout.write(
self.style.SUCCESS(
f" ✅ Cleared static files cache ({size_mb:.2f} MB)"
)
)
self.stdout.write(self.style.SUCCESS(f" ✅ Cleared static files cache ({size_mb:.2f} MB)"))
else:
self.stdout.write(
self.style.WARNING(
" ⚠️ No STATIC_ROOT configured or directory doesn't exist"
)
)
self.stdout.write(self.style.WARNING(" ⚠️ No STATIC_ROOT configured or directory doesn't exist"))
except Exception as e:
self.stdout.write(
self.style.ERROR(f" ❌ Error clearing static cache: {e}")
)
self.stdout.write(self.style.ERROR(f" ❌ Error clearing static cache: {e}"))
def clear_sessions_cache(self):
"""Clear session cache if using cache-based sessions."""
@@ -289,9 +244,7 @@ class Command(BaseCommand):
if "cache" in session_engine:
# Using cache-based sessions
session_cache_alias = getattr(
settings, "SESSION_CACHE_ALIAS", "default"
)
session_cache_alias = getattr(settings, "SESSION_CACHE_ALIAS", "default")
session_cache = caches[session_cache_alias]
if not self.dry_run:
@@ -299,20 +252,12 @@ class Command(BaseCommand):
# In production, you might want more sophisticated session clearing
session_cache.clear()
self.stdout.write(
self.style.SUCCESS(
f" ✅ Cleared cache-based sessions ({session_cache_alias})"
)
)
self.stdout.write(self.style.SUCCESS(f" ✅ Cleared cache-based sessions ({session_cache_alias})"))
else:
self.stdout.write(
self.style.WARNING(" ⚠️ Not using cache-based sessions")
)
self.stdout.write(self.style.WARNING(" ⚠️ Not using cache-based sessions"))
except Exception as e:
self.stdout.write(
self.style.ERROR(f" ❌ Error clearing session cache: {e}")
)
self.stdout.write(self.style.ERROR(f" ❌ Error clearing session cache: {e}"))
def clear_template_cache(self):
"""Clear template cache."""
@@ -332,18 +277,14 @@ class Command(BaseCommand):
# Get engine instance safely
engine_instance = getattr(engine, "engine", None)
if engine_instance:
template_loaders = getattr(
engine_instance, "template_loaders", []
)
template_loaders = getattr(engine_instance, "template_loaders", [])
for loader in template_loaders:
if isinstance(loader, CachedLoader):
if not self.dry_run:
loader.reset()
cleared_engines += 1
if self.verbose:
self.stdout.write(
f" 🗑️ Cleared cached loader: {loader}"
)
self.stdout.write(f" 🗑️ Cleared cached loader: {loader}")
# Check for Jinja2 engines (if present)
elif "Jinja2" in engine_backend and hasattr(engine, "env"):
@@ -353,34 +294,21 @@ class Command(BaseCommand):
env.cache.clear()
cleared_engines += 1
if self.verbose:
self.stdout.write(
f" 🗑️ Cleared Jinja2 cache: {engine}"
)
self.stdout.write(f" 🗑️ Cleared Jinja2 cache: {engine}")
except Exception as e:
if self.verbose:
self.stdout.write(
self.style.WARNING(
f" ⚠️ Could not clear cache for engine {engine}: {e}"
)
)
self.stdout.write(self.style.WARNING(f" ⚠️ Could not clear cache for engine {engine}: {e}"))
if cleared_engines > 0:
self.stdout.write(
self.style.SUCCESS(
f" ✅ Cleared template cache for "
f"{cleared_engines} loaders/engines"
)
self.style.SUCCESS(f" ✅ Cleared template cache for " f"{cleared_engines} loaders/engines")
)
else:
self.stdout.write(
self.style.WARNING(" ⚠️ No cached template loaders found")
)
self.stdout.write(self.style.WARNING(" ⚠️ No cached template loaders found"))
except Exception as e:
self.stdout.write(
self.style.ERROR(f" ❌ Error clearing template cache: {e}")
)
self.stdout.write(self.style.ERROR(f" ❌ Error clearing template cache: {e}"))
def clear_tailwind_cache(self):
"""Clear Tailwind CSS build cache."""
@@ -410,27 +338,15 @@ class Command(BaseCommand):
cleared_count += 1
except Exception as e:
self.stdout.write(
self.style.WARNING(
f" ⚠️ Could not remove {cache_path}: {e}"
)
)
self.stdout.write(self.style.WARNING(f" ⚠️ Could not remove {cache_path}: {e}"))
if cleared_count > 0:
self.stdout.write(
self.style.SUCCESS(
f" ✅ Cleared {cleared_count} Tailwind cache directories"
)
)
self.stdout.write(self.style.SUCCESS(f" ✅ Cleared {cleared_count} Tailwind cache directories"))
else:
self.stdout.write(
self.style.WARNING(" ⚠️ No Tailwind cache directories found")
)
self.stdout.write(self.style.WARNING(" ⚠️ No Tailwind cache directories found"))
except Exception as e:
self.stdout.write(
self.style.ERROR(f" ❌ Error clearing Tailwind cache: {e}")
)
self.stdout.write(self.style.ERROR(f" ❌ Error clearing Tailwind cache: {e}"))
def clear_opcache(self):
"""Clear PHP OPcache if available."""
@@ -452,21 +368,13 @@ class Command(BaseCommand):
if result.returncode == 0:
if "cleared" in result.stdout:
self.stdout.write(
self.style.SUCCESS(" ✅ OPcache cleared successfully")
)
self.stdout.write(self.style.SUCCESS(" ✅ OPcache cleared successfully"))
else:
self.stdout.write(self.style.WARNING(" ⚠️ OPcache not available"))
else:
self.stdout.write(
self.style.WARNING(
" ⚠️ PHP not available or OPcache not accessible"
)
)
self.stdout.write(self.style.WARNING(" ⚠️ PHP not available or OPcache not accessible"))
except (subprocess.TimeoutExpired, FileNotFoundError):
self.stdout.write(
self.style.WARNING(" ⚠️ PHP not found or not accessible")
)
self.stdout.write(self.style.WARNING(" ⚠️ PHP not found or not accessible"))
except Exception as e:
self.stdout.write(self.style.ERROR(f" ❌ Error clearing OPcache: {e}"))

View File

@@ -15,69 +15,69 @@ from apps.core.state_machine.config import callback_config
class Command(BaseCommand):
help = 'List all registered FSM transition callbacks'
help = "List all registered FSM transition callbacks"
def add_arguments(self, parser: CommandParser) -> None:
parser.add_argument(
'--model',
"--model",
type=str,
help='Filter by model name (e.g., EditSubmission, Ride)',
help="Filter by model name (e.g., EditSubmission, Ride)",
)
parser.add_argument(
'--stage',
"--stage",
type=str,
choices=['pre', 'post', 'error', 'all'],
default='all',
help='Filter by callback stage',
choices=["pre", "post", "error", "all"],
default="all",
help="Filter by callback stage",
)
parser.add_argument(
'--verbose',
'-v',
action='store_true',
help='Show detailed callback information',
"--verbose",
"-v",
action="store_true",
help="Show detailed callback information",
)
parser.add_argument(
'--format',
"--format",
type=str,
choices=['text', 'table', 'json'],
default='text',
help='Output format',
choices=["text", "table", "json"],
default="text",
help="Output format",
)
def handle(self, *args, **options):
model_filter = options.get('model')
stage_filter = options.get('stage')
verbose = options.get('verbose', False)
output_format = options.get('format', 'text')
model_filter = options.get("model")
stage_filter = options.get("stage")
verbose = options.get("verbose", False)
output_format = options.get("format", "text")
# Get all registrations
all_registrations = callback_registry.get_all_registrations()
if output_format == 'json':
if output_format == "json":
self._output_json(all_registrations, model_filter, stage_filter)
elif output_format == 'table':
elif output_format == "table":
self._output_table(all_registrations, model_filter, stage_filter, verbose)
else:
self._output_text(all_registrations, model_filter, stage_filter, verbose)
def _output_text(self, registrations, model_filter, stage_filter, verbose):
"""Output in text format."""
self.stdout.write(self.style.SUCCESS('\n=== FSM Transition Callbacks ===\n'))
self.stdout.write(self.style.SUCCESS("\n=== FSM Transition Callbacks ===\n"))
# Group by model
models_seen = set()
total_callbacks = 0
for stage in CallbackStage:
if stage_filter != 'all' and stage.value != stage_filter:
if stage_filter != "all" and stage.value != stage_filter:
continue
stage_regs = registrations.get(stage, [])
if not stage_regs:
continue
self.stdout.write(self.style.WARNING(f'\n{stage.value.upper()} Callbacks:'))
self.stdout.write('-' * 50)
self.stdout.write(self.style.WARNING(f"\n{stage.value.upper()} Callbacks:"))
self.stdout.write("-" * 50)
# Group by model
by_model = {}
@@ -92,42 +92,34 @@ class Command(BaseCommand):
total_callbacks += 1
for model_name, regs in sorted(by_model.items()):
self.stdout.write(f'\n {model_name}:')
self.stdout.write(f"\n {model_name}:")
for reg in regs:
transition = f'{reg.source}{reg.target}'
transition = f"{reg.source}{reg.target}"
callback_name = reg.callback.name
priority = reg.callback.priority
self.stdout.write(
f' [{transition}] {callback_name} (priority: {priority})'
)
self.stdout.write(f" [{transition}] {callback_name} (priority: {priority})")
if verbose:
self.stdout.write(
f' continue_on_error: {reg.callback.continue_on_error}'
)
if hasattr(reg.callback, 'patterns'):
self.stdout.write(
f' patterns: {reg.callback.patterns}'
)
self.stdout.write(f" continue_on_error: {reg.callback.continue_on_error}")
if hasattr(reg.callback, "patterns"):
self.stdout.write(f" patterns: {reg.callback.patterns}")
# Summary
self.stdout.write('\n' + '=' * 50)
self.stdout.write(self.style.SUCCESS(
f'Total: {total_callbacks} callbacks across {len(models_seen)} models'
))
self.stdout.write("\n" + "=" * 50)
self.stdout.write(self.style.SUCCESS(f"Total: {total_callbacks} callbacks across {len(models_seen)} models"))
# Configuration status
self.stdout.write(self.style.WARNING('\nConfiguration Status:'))
self.stdout.write(f' Callbacks enabled: {callback_config.enabled}')
self.stdout.write(f' Notifications enabled: {callback_config.notifications_enabled}')
self.stdout.write(f' Cache invalidation enabled: {callback_config.cache_invalidation_enabled}')
self.stdout.write(f' Related updates enabled: {callback_config.related_updates_enabled}')
self.stdout.write(f' Debug mode: {callback_config.debug_mode}')
self.stdout.write(self.style.WARNING("\nConfiguration Status:"))
self.stdout.write(f" Callbacks enabled: {callback_config.enabled}")
self.stdout.write(f" Notifications enabled: {callback_config.notifications_enabled}")
self.stdout.write(f" Cache invalidation enabled: {callback_config.cache_invalidation_enabled}")
self.stdout.write(f" Related updates enabled: {callback_config.related_updates_enabled}")
self.stdout.write(f" Debug mode: {callback_config.debug_mode}")
def _output_table(self, registrations, model_filter, stage_filter, verbose):
"""Output in table format."""
self.stdout.write(self.style.SUCCESS('\n=== FSM Transition Callbacks ===\n'))
self.stdout.write(self.style.SUCCESS("\n=== FSM Transition Callbacks ===\n"))
# Header
if verbose:
@@ -136,10 +128,10 @@ class Command(BaseCommand):
header = f"{'Model':<20} {'Source':<15} {'Target':<15} {'Stage':<8} {'Callback':<30}"
self.stdout.write(self.style.WARNING(header))
self.stdout.write('-' * len(header))
self.stdout.write("-" * len(header))
for stage in CallbackStage:
if stage_filter != 'all' and stage.value != stage_filter:
if stage_filter != "all" and stage.value != stage_filter:
continue
stage_regs = registrations.get(stage, [])
@@ -160,18 +152,18 @@ class Command(BaseCommand):
import json
output = {
'callbacks': [],
'configuration': {
'enabled': callback_config.enabled,
'notifications_enabled': callback_config.notifications_enabled,
'cache_invalidation_enabled': callback_config.cache_invalidation_enabled,
'related_updates_enabled': callback_config.related_updates_enabled,
'debug_mode': callback_config.debug_mode,
}
"callbacks": [],
"configuration": {
"enabled": callback_config.enabled,
"notifications_enabled": callback_config.notifications_enabled,
"cache_invalidation_enabled": callback_config.cache_invalidation_enabled,
"related_updates_enabled": callback_config.related_updates_enabled,
"debug_mode": callback_config.debug_mode,
},
}
for stage in CallbackStage:
if stage_filter != 'all' and stage.value != stage_filter:
if stage_filter != "all" and stage.value != stage_filter:
continue
stage_regs = registrations.get(stage, [])
@@ -180,15 +172,17 @@ class Command(BaseCommand):
if model_filter and model_name != model_filter:
continue
output['callbacks'].append({
'model': model_name,
'field': reg.field_name,
'source': reg.source,
'target': reg.target,
'stage': stage.value,
'callback': reg.callback.name,
'priority': reg.callback.priority,
'continue_on_error': reg.callback.continue_on_error,
})
output["callbacks"].append(
{
"model": model_name,
"field": reg.field_name,
"source": reg.source,
"target": reg.target,
"stage": stage.value,
"callback": reg.callback.name,
"priority": reg.callback.priority,
"continue_on_error": reg.callback.continue_on_error,
}
)
self.stdout.write(json.dumps(output, indent=2))

View File

@@ -52,26 +52,17 @@ class Command(BaseCommand):
import rjsmin
except ImportError:
rjsmin = None
self.stdout.write(
self.style.WARNING(
"rjsmin not installed. Install with: pip install rjsmin"
)
)
self.stdout.write(self.style.WARNING("rjsmin not installed. Install with: pip install rjsmin"))
try:
import rcssmin
except ImportError:
rcssmin = None
self.stdout.write(
self.style.WARNING(
"rcssmin not installed. Install with: pip install rcssmin"
)
)
self.stdout.write(self.style.WARNING("rcssmin not installed. Install with: pip install rcssmin"))
if not rjsmin and not rcssmin:
raise CommandError(
"Neither rjsmin nor rcssmin is installed. "
"Install at least one: pip install rjsmin rcssmin"
"Neither rjsmin nor rcssmin is installed. " "Install at least one: pip install rjsmin rcssmin"
)
# Get static file directories
@@ -93,9 +84,7 @@ class Command(BaseCommand):
if not css_only and rjsmin:
js_dir = static_dir / "js"
if js_dir.exists():
saved, count = self._process_js_files(
js_dir, rjsmin, dry_run, force
)
saved, count = self._process_js_files(js_dir, rjsmin, dry_run, force)
total_js_saved += saved
js_files_processed += count
@@ -103,9 +92,7 @@ class Command(BaseCommand):
if not js_only and rcssmin:
css_dir = static_dir / "css"
if css_dir.exists():
saved, count = self._process_css_files(
css_dir, rcssmin, dry_run, force
)
saved, count = self._process_css_files(css_dir, rcssmin, dry_run, force)
total_css_saved += saved
css_files_processed += count
@@ -114,17 +101,11 @@ class Command(BaseCommand):
self.stdout.write(self.style.SUCCESS("Static file optimization complete!"))
self.stdout.write(f"JavaScript files processed: {js_files_processed}")
self.stdout.write(f"CSS files processed: {css_files_processed}")
self.stdout.write(
f"Total JS savings: {self._format_size(total_js_saved)}"
)
self.stdout.write(
f"Total CSS savings: {self._format_size(total_css_saved)}"
)
self.stdout.write(f"Total JS savings: {self._format_size(total_js_saved)}")
self.stdout.write(f"Total CSS savings: {self._format_size(total_css_saved)}")
if dry_run:
self.stdout.write(
self.style.WARNING("\nDry run - no files were modified")
)
self.stdout.write(self.style.WARNING("\nDry run - no files were modified"))
def _process_js_files(self, js_dir, rjsmin, dry_run, force):
"""Process JavaScript files for minification."""
@@ -140,9 +121,7 @@ class Command(BaseCommand):
# Skip if minified version exists and not forcing
if min_file.exists() and not force:
self.stdout.write(
f" Skipping {js_file.name} (min version exists)"
)
self.stdout.write(f" Skipping {js_file.name} (min version exists)")
continue
try:
@@ -169,9 +148,7 @@ class Command(BaseCommand):
files_processed += 1
except Exception as e:
self.stdout.write(
self.style.ERROR(f" Error processing {js_file.name}: {e}")
)
self.stdout.write(self.style.ERROR(f" Error processing {js_file.name}: {e}"))
return total_saved, files_processed
@@ -189,9 +166,7 @@ class Command(BaseCommand):
# Skip if minified version exists and not forcing
if min_file.exists() and not force:
self.stdout.write(
f" Skipping {css_file.name} (min version exists)"
)
self.stdout.write(f" Skipping {css_file.name} (min version exists)")
continue
try:
@@ -218,9 +193,7 @@ class Command(BaseCommand):
files_processed += 1
except Exception as e:
self.stdout.write(
self.style.ERROR(f" Error processing {css_file.name}: {e}")
)
self.stdout.write(self.style.ERROR(f" Error processing {css_file.name}: {e}"))
return total_saved, files_processed

View File

@@ -39,19 +39,13 @@ class Command(BaseCommand):
def handle(self, *args, **options):
"""Run the development setup and start the server."""
if not options["skip_setup"]:
self.stdout.write(
self.style.SUCCESS(
"🚀 Setting up and starting ThrillWiki Development Server..."
)
)
self.stdout.write(self.style.SUCCESS("🚀 Setting up and starting ThrillWiki Development Server..."))
# Run the setup_dev command first
execute_from_command_line(["manage.py", "setup_dev"])
else:
self.stdout.write(
self.style.SUCCESS("🚀 Starting ThrillWiki Development Server...")
)
self.stdout.write(self.style.SUCCESS("🚀 Starting ThrillWiki Development Server..."))
# Determine which server command to use
self.get_server_command(options)
@@ -59,9 +53,7 @@ class Command(BaseCommand):
# Start the server
self.stdout.write("")
self.stdout.write(
self.style.SUCCESS(
f"🌟 Starting Django development server on http://{options['host']}:{options['port']}"
)
self.style.SUCCESS(f"🌟 Starting Django development server on http://{options['host']}:{options['port']}")
)
self.stdout.write("Press Ctrl+C to stop the server")
self.stdout.write("")
@@ -76,9 +68,7 @@ class Command(BaseCommand):
]
)
else:
execute_from_command_line(
["manage.py", "runserver", f"{options['host']}:{options['port']}"]
)
execute_from_command_line(["manage.py", "runserver", f"{options['host']}:{options['port']}"])
except KeyboardInterrupt:
self.stdout.write("")
self.stdout.write(self.style.SUCCESS("👋 Development server stopped"))

View File

@@ -16,23 +16,23 @@ from django.core.management.base import BaseCommand
class Command(BaseCommand):
help = 'Run security audit and generate a report'
help = "Run security audit and generate a report"
def add_arguments(self, parser):
parser.add_argument(
'--output',
"--output",
type=str,
help='Output file for the security report',
help="Output file for the security report",
)
parser.add_argument(
'--verbose',
action='store_true',
help='Show detailed information for each check',
"--verbose",
action="store_true",
help="Show detailed information for each check",
)
def handle(self, *args, **options):
self.verbose = options.get('verbose', False)
output_file = options.get('output')
self.verbose = options.get("verbose", False)
output_file = options.get("output")
report_lines = []
@@ -66,11 +66,9 @@ class Command(BaseCommand):
# Write to file if specified
if output_file:
with open(output_file, 'w') as f:
f.write('\n'.join(report_lines))
self.stdout.write(
self.style.SUCCESS(f'\nReport saved to: {output_file}')
)
with open(output_file, "w") as f:
f.write("\n".join(report_lines))
self.stdout.write(self.style.SUCCESS(f"\nReport saved to: {output_file}"))
def log(self, message, report_lines):
"""Log message to both stdout and report."""
@@ -82,10 +80,7 @@ class Command(BaseCommand):
errors = registry.run_checks(tags=[Tags.security])
if not errors:
self.log(
self.style.SUCCESS(" ✓ All Django security checks passed"),
report_lines
)
self.log(self.style.SUCCESS(" ✓ All Django security checks passed"), report_lines)
else:
for error in errors:
prefix = self.style.ERROR(" ✗ ERROR") if error.is_serious() else self.style.WARNING(" ! WARNING")
@@ -97,73 +92,71 @@ class Command(BaseCommand):
def check_configuration(self, report_lines):
"""Check various configuration settings."""
checks = [
('DEBUG mode', not settings.DEBUG, 'DEBUG should be False'),
("DEBUG mode", not settings.DEBUG, "DEBUG should be False"),
("SECRET_KEY length", len(settings.SECRET_KEY) >= 50, f"Length: {len(settings.SECRET_KEY)}"),
(
'SECRET_KEY length',
len(settings.SECRET_KEY) >= 50,
f'Length: {len(settings.SECRET_KEY)}'
"ALLOWED_HOSTS",
bool(settings.ALLOWED_HOSTS) and "*" not in settings.ALLOWED_HOSTS,
str(settings.ALLOWED_HOSTS),
),
(
'ALLOWED_HOSTS',
bool(settings.ALLOWED_HOSTS) and '*' not in settings.ALLOWED_HOSTS,
str(settings.ALLOWED_HOSTS)
"CSRF_TRUSTED_ORIGINS",
bool(getattr(settings, "CSRF_TRUSTED_ORIGINS", [])),
str(getattr(settings, "CSRF_TRUSTED_ORIGINS", [])),
),
(
'CSRF_TRUSTED_ORIGINS',
bool(getattr(settings, 'CSRF_TRUSTED_ORIGINS', [])),
str(getattr(settings, 'CSRF_TRUSTED_ORIGINS', []))
"X_FRAME_OPTIONS",
getattr(settings, "X_FRAME_OPTIONS", "") in ("DENY", "SAMEORIGIN"),
str(getattr(settings, "X_FRAME_OPTIONS", "Not set")),
),
(
'X_FRAME_OPTIONS',
getattr(settings, 'X_FRAME_OPTIONS', '') in ('DENY', 'SAMEORIGIN'),
str(getattr(settings, 'X_FRAME_OPTIONS', 'Not set'))
"SECURE_CONTENT_TYPE_NOSNIFF",
getattr(settings, "SECURE_CONTENT_TYPE_NOSNIFF", False),
str(getattr(settings, "SECURE_CONTENT_TYPE_NOSNIFF", False)),
),
(
'SECURE_CONTENT_TYPE_NOSNIFF',
getattr(settings, 'SECURE_CONTENT_TYPE_NOSNIFF', False),
str(getattr(settings, 'SECURE_CONTENT_TYPE_NOSNIFF', False))
"SECURE_BROWSER_XSS_FILTER",
getattr(settings, "SECURE_BROWSER_XSS_FILTER", False),
str(getattr(settings, "SECURE_BROWSER_XSS_FILTER", False)),
),
(
'SECURE_BROWSER_XSS_FILTER',
getattr(settings, 'SECURE_BROWSER_XSS_FILTER', False),
str(getattr(settings, 'SECURE_BROWSER_XSS_FILTER', False))
"SESSION_COOKIE_HTTPONLY",
getattr(settings, "SESSION_COOKIE_HTTPONLY", True),
str(getattr(settings, "SESSION_COOKIE_HTTPONLY", "Not set")),
),
(
'SESSION_COOKIE_HTTPONLY',
getattr(settings, 'SESSION_COOKIE_HTTPONLY', True),
str(getattr(settings, 'SESSION_COOKIE_HTTPONLY', 'Not set'))
),
(
'CSRF_COOKIE_HTTPONLY',
getattr(settings, 'CSRF_COOKIE_HTTPONLY', True),
str(getattr(settings, 'CSRF_COOKIE_HTTPONLY', 'Not set'))
"CSRF_COOKIE_HTTPONLY",
getattr(settings, "CSRF_COOKIE_HTTPONLY", True),
str(getattr(settings, "CSRF_COOKIE_HTTPONLY", "Not set")),
),
]
# Production-only checks
if not settings.DEBUG:
checks.extend([
(
'SECURE_SSL_REDIRECT',
getattr(settings, 'SECURE_SSL_REDIRECT', False),
str(getattr(settings, 'SECURE_SSL_REDIRECT', False))
),
(
'SESSION_COOKIE_SECURE',
getattr(settings, 'SESSION_COOKIE_SECURE', False),
str(getattr(settings, 'SESSION_COOKIE_SECURE', False))
),
(
'CSRF_COOKIE_SECURE',
getattr(settings, 'CSRF_COOKIE_SECURE', False),
str(getattr(settings, 'CSRF_COOKIE_SECURE', False))
),
(
'SECURE_HSTS_SECONDS',
getattr(settings, 'SECURE_HSTS_SECONDS', 0) >= 31536000,
str(getattr(settings, 'SECURE_HSTS_SECONDS', 0))
),
])
checks.extend(
[
(
"SECURE_SSL_REDIRECT",
getattr(settings, "SECURE_SSL_REDIRECT", False),
str(getattr(settings, "SECURE_SSL_REDIRECT", False)),
),
(
"SESSION_COOKIE_SECURE",
getattr(settings, "SESSION_COOKIE_SECURE", False),
str(getattr(settings, "SESSION_COOKIE_SECURE", False)),
),
(
"CSRF_COOKIE_SECURE",
getattr(settings, "CSRF_COOKIE_SECURE", False),
str(getattr(settings, "CSRF_COOKIE_SECURE", False)),
),
(
"SECURE_HSTS_SECONDS",
getattr(settings, "SECURE_HSTS_SECONDS", 0) >= 31536000,
str(getattr(settings, "SECURE_HSTS_SECONDS", 0)),
),
]
)
for name, is_secure, value in checks:
status = self.style.SUCCESS("") if is_secure else self.style.WARNING("!")
@@ -176,59 +169,43 @@ class Command(BaseCommand):
def check_middleware(self, report_lines):
"""Check security-related middleware is properly configured."""
middleware = getattr(settings, 'MIDDLEWARE', [])
middleware = getattr(settings, "MIDDLEWARE", [])
required_middleware = [
('django.middleware.security.SecurityMiddleware', 'SecurityMiddleware'),
('django.middleware.csrf.CsrfViewMiddleware', 'CSRF Middleware'),
('django.middleware.clickjacking.XFrameOptionsMiddleware', 'X-Frame-Options'),
("django.middleware.security.SecurityMiddleware", "SecurityMiddleware"),
("django.middleware.csrf.CsrfViewMiddleware", "CSRF Middleware"),
("django.middleware.clickjacking.XFrameOptionsMiddleware", "X-Frame-Options"),
]
custom_security_middleware = [
('apps.core.middleware.security_headers.SecurityHeadersMiddleware', 'Security Headers'),
('apps.core.middleware.rate_limiting.AuthRateLimitMiddleware', 'Rate Limiting'),
("apps.core.middleware.security_headers.SecurityHeadersMiddleware", "Security Headers"),
("apps.core.middleware.rate_limiting.AuthRateLimitMiddleware", "Rate Limiting"),
]
# Check required middleware
for mw_path, mw_name in required_middleware:
if mw_path in middleware:
self.log(
f" {self.style.SUCCESS('')} {mw_name} is enabled",
report_lines
)
self.log(f" {self.style.SUCCESS('')} {mw_name} is enabled", report_lines)
else:
self.log(
f" {self.style.ERROR('')} {mw_name} is NOT enabled",
report_lines
)
self.log(f" {self.style.ERROR('')} {mw_name} is NOT enabled", report_lines)
# Check custom security middleware
for mw_path, mw_name in custom_security_middleware:
if mw_path in middleware:
self.log(
f" {self.style.SUCCESS('')} {mw_name} is enabled",
report_lines
)
self.log(f" {self.style.SUCCESS('')} {mw_name} is enabled", report_lines)
else:
self.log(
f" {self.style.WARNING('!')} {mw_name} is not enabled (optional)",
report_lines
)
self.log(f" {self.style.WARNING('!')} {mw_name} is not enabled (optional)", report_lines)
# Check middleware order
try:
security_idx = middleware.index('django.middleware.security.SecurityMiddleware')
session_idx = middleware.index('django.contrib.sessions.middleware.SessionMiddleware')
security_idx = middleware.index("django.middleware.security.SecurityMiddleware")
session_idx = middleware.index("django.contrib.sessions.middleware.SessionMiddleware")
if security_idx < session_idx:
self.log(
f" {self.style.SUCCESS('')} Middleware ordering is correct",
report_lines
)
self.log(f" {self.style.SUCCESS('')} Middleware ordering is correct", report_lines)
else:
self.log(
f" {self.style.WARNING('!')} SecurityMiddleware should come before SessionMiddleware",
report_lines
f" {self.style.WARNING('!')} SecurityMiddleware should come before SessionMiddleware", report_lines
)
except ValueError:
pass # Middleware not found, already reported above

View File

@@ -39,9 +39,7 @@ class Command(BaseCommand):
def handle(self, *args, **options):
"""Run the development setup process."""
self.stdout.write(
self.style.SUCCESS("🚀 Setting up ThrillWiki Development Environment...")
)
self.stdout.write(self.style.SUCCESS("🚀 Setting up ThrillWiki Development Environment..."))
# Create necessary directories
self.create_directories()
@@ -71,9 +69,7 @@ class Command(BaseCommand):
# Display environment info
self.display_environment_info()
self.stdout.write(
self.style.SUCCESS("✅ Development environment setup complete!")
)
self.stdout.write(self.style.SUCCESS("✅ Development environment setup complete!"))
def create_directories(self):
"""Create necessary directories."""
@@ -99,36 +95,24 @@ class Command(BaseCommand):
)
if result.returncode == 0:
self.stdout.write(
self.style.SUCCESS("✅ Database migrations are up to date")
)
self.stdout.write(self.style.SUCCESS("✅ Database migrations are up to date"))
else:
self.stdout.write("🔄 Running database migrations...")
subprocess.run(
["uv", "run", "manage.py", "migrate", "--noinput"], check=True
)
self.stdout.write(
self.style.SUCCESS("✅ Database migrations completed")
)
subprocess.run(["uv", "run", "manage.py", "migrate", "--noinput"], check=True)
self.stdout.write(self.style.SUCCESS("✅ Database migrations completed"))
except subprocess.CalledProcessError as e:
self.stdout.write(
self.style.WARNING(f"⚠️ Migration error (continuing): {e}")
)
self.stdout.write(self.style.WARNING(f"⚠️ Migration error (continuing): {e}"))
def seed_sample_data(self):
"""Seed sample data to the database."""
self.stdout.write("🌱 Seeding sample data...")
try:
subprocess.run(
["uv", "run", "manage.py", "seed_sample_data"], check=True
)
subprocess.run(["uv", "run", "manage.py", "seed_sample_data"], check=True)
self.stdout.write(self.style.SUCCESS("✅ Sample data seeded"))
except subprocess.CalledProcessError:
self.stdout.write(
self.style.WARNING("⚠️ Could not seed sample data (continuing)")
)
self.stdout.write(self.style.WARNING("⚠️ Could not seed sample data (continuing)"))
def create_superuser(self):
"""Create development superuser if it doesn't exist."""
@@ -145,13 +129,9 @@ class Command(BaseCommand):
self.stdout.write("👤 Creating development superuser (admin/admin)...")
if not User.objects.filter(username="admin").exists():
User.objects.create_superuser("admin", "admin@example.com", "admin")
self.stdout.write(
self.style.SUCCESS("✅ Created superuser: admin/admin")
)
self.stdout.write(self.style.SUCCESS("✅ Created superuser: admin/admin"))
else:
self.stdout.write(
self.style.SUCCESS("✅ Admin user already exists")
)
self.stdout.write(self.style.SUCCESS("✅ Admin user already exists"))
except Exception as e:
self.stdout.write(self.style.WARNING(f"⚠️ Could not create superuser: {e}"))
@@ -167,9 +147,7 @@ class Command(BaseCommand):
)
self.stdout.write(self.style.SUCCESS("✅ Static files collected"))
except subprocess.CalledProcessError as e:
self.stdout.write(
self.style.WARNING(f"⚠️ Could not collect static files: {e}")
)
self.stdout.write(self.style.WARNING(f"⚠️ Could not collect static files: {e}"))
def build_tailwind(self):
"""Build Tailwind CSS if npm is available."""
@@ -180,17 +158,11 @@ class Command(BaseCommand):
subprocess.run(["npm", "--version"], capture_output=True, check=True)
# Build Tailwind CSS
subprocess.run(
["uv", "run", "manage.py", "tailwind", "build"], check=True
)
subprocess.run(["uv", "run", "manage.py", "tailwind", "build"], check=True)
self.stdout.write(self.style.SUCCESS("✅ Tailwind CSS built"))
except (subprocess.CalledProcessError, FileNotFoundError):
self.stdout.write(
self.style.WARNING(
"⚠️ npm not found or Tailwind build failed, skipping"
)
)
self.stdout.write(self.style.WARNING("⚠️ npm not found or Tailwind build failed, skipping"))
def run_system_checks(self):
"""Run Django system checks."""
@@ -200,9 +172,7 @@ class Command(BaseCommand):
subprocess.run(["uv", "run", "manage.py", "check"], check=True)
self.stdout.write(self.style.SUCCESS("✅ System checks passed"))
except subprocess.CalledProcessError:
self.stdout.write(
self.style.WARNING("❌ System checks failed, but continuing...")
)
self.stdout.write(self.style.WARNING("❌ System checks failed, but continuing..."))
def display_environment_info(self):
"""Display development environment information."""

View File

@@ -18,62 +18,62 @@ from apps.core.state_machine.monitoring import callback_monitor
class Command(BaseCommand):
help = 'Test FSM transition callbacks for specific transitions'
help = "Test FSM transition callbacks for specific transitions"
def add_arguments(self, parser: CommandParser) -> None:
parser.add_argument(
'model',
"model",
type=str,
help='Model name (e.g., EditSubmission, Ride, Park)',
help="Model name (e.g., EditSubmission, Ride, Park)",
)
parser.add_argument(
'source',
"source",
type=str,
help='Source state value',
help="Source state value",
)
parser.add_argument(
'target',
"target",
type=str,
help='Target state value',
help="Target state value",
)
parser.add_argument(
'--instance-id',
"--instance-id",
type=int,
help='ID of an existing instance to use for testing',
help="ID of an existing instance to use for testing",
)
parser.add_argument(
'--user-id',
"--user-id",
type=int,
help='ID of user to use for testing',
help="ID of user to use for testing",
)
parser.add_argument(
'--dry-run',
action='store_true',
help='Show what would be executed without running callbacks',
"--dry-run",
action="store_true",
help="Show what would be executed without running callbacks",
)
parser.add_argument(
'--stage',
"--stage",
type=str,
choices=['pre', 'post', 'error', 'all'],
default='all',
help='Which callback stage to test',
choices=["pre", "post", "error", "all"],
default="all",
help="Which callback stage to test",
)
parser.add_argument(
'--field',
"--field",
type=str,
default='status',
help='FSM field name (default: status)',
default="status",
help="FSM field name (default: status)",
)
def handle(self, *args, **options):
model_name = options['model']
source = options['source']
target = options['target']
instance_id = options.get('instance_id')
user_id = options.get('user_id')
dry_run = options.get('dry_run', False)
stage_filter = options.get('stage', 'all')
field_name = options.get('field', 'status')
model_name = options["model"]
source = options["source"]
target = options["target"]
instance_id = options.get("instance_id")
user_id = options.get("user_id")
dry_run = options.get("dry_run", False)
stage_filter = options.get("stage", "all")
field_name = options.get("field", "status")
# Find the model class
model_class = self._find_model(model_name)
@@ -90,7 +90,7 @@ class Command(BaseCommand):
try:
user = User.objects.get(pk=user_id)
except User.DoesNotExist:
raise CommandError(f"User with ID {user_id} not found")
raise CommandError(f"User with ID {user_id} not found") from None
# Create transition context
context = TransitionContext(
@@ -101,19 +101,21 @@ class Command(BaseCommand):
user=user,
)
self.stdout.write(self.style.SUCCESS(
f'\n=== Testing Transition Callbacks ===\n'
f'Model: {model_name}\n'
f'Transition: {source}{target}\n'
f'Field: {field_name}\n'
f'Instance: {instance}\n'
f'User: {user}\n'
f'Dry Run: {dry_run}\n'
))
self.stdout.write(
self.style.SUCCESS(
f"\n=== Testing Transition Callbacks ===\n"
f"Model: {model_name}\n"
f"Transition: {source}{target}\n"
f"Field: {field_name}\n"
f"Instance: {instance}\n"
f"User: {user}\n"
f"Dry Run: {dry_run}\n"
)
)
# Get callbacks for each stage
stages_to_test = []
if stage_filter == 'all':
if stage_filter == "all":
stages_to_test = [CallbackStage.PRE, CallbackStage.POST, CallbackStage.ERROR]
else:
stages_to_test = [CallbackStage(stage_filter)]
@@ -123,83 +125,69 @@ class Command(BaseCommand):
total_failures = 0
for stage in stages_to_test:
callbacks = callback_registry.get_callbacks(
model_class, field_name, source, target, stage
)
callbacks = callback_registry.get_callbacks(model_class, field_name, source, target, stage)
if not callbacks:
self.stdout.write(
self.style.WARNING(f'\nNo {stage.value.upper()} callbacks registered')
)
self.stdout.write(self.style.WARNING(f"\nNo {stage.value.upper()} callbacks registered"))
continue
self.stdout.write(
self.style.WARNING(f'\n{stage.value.upper()} Callbacks ({len(callbacks)}):')
)
self.stdout.write('-' * 50)
self.stdout.write(self.style.WARNING(f"\n{stage.value.upper()} Callbacks ({len(callbacks)}):"))
self.stdout.write("-" * 50)
for callback in callbacks:
total_callbacks += 1
callback_info = (
f' {callback.name} (priority: {callback.priority}, '
f'continue_on_error: {callback.continue_on_error})'
f" {callback.name} (priority: {callback.priority}, "
f"continue_on_error: {callback.continue_on_error})"
)
if dry_run:
self.stdout.write(callback_info)
self.stdout.write(self.style.NOTICE(' → Would execute (dry run)'))
self.stdout.write(self.style.NOTICE(" → Would execute (dry run)"))
else:
self.stdout.write(callback_info)
# Check should_execute
if not callback.should_execute(context):
self.stdout.write(
self.style.WARNING(' → Skipped (should_execute returned False)')
)
self.stdout.write(self.style.WARNING(" → Skipped (should_execute returned False)"))
continue
# Execute callback
try:
if stage == CallbackStage.ERROR:
result = callback.execute(
context,
exception=Exception("Test exception")
)
result = callback.execute(context, exception=Exception("Test exception"))
else:
result = callback.execute(context)
if result:
self.stdout.write(self.style.SUCCESS(' → Success'))
self.stdout.write(self.style.SUCCESS(" → Success"))
total_success += 1
else:
self.stdout.write(self.style.ERROR(' → Failed (returned False)'))
self.stdout.write(self.style.ERROR(" → Failed (returned False)"))
total_failures += 1
except Exception as e:
self.stdout.write(
self.style.ERROR(f' → Exception: {type(e).__name__}: {e}')
)
self.stdout.write(self.style.ERROR(f" → Exception: {type(e).__name__}: {e}"))
total_failures += 1
# Summary
self.stdout.write('\n' + '=' * 50)
self.stdout.write(self.style.SUCCESS(f'Total callbacks: {total_callbacks}'))
self.stdout.write("\n" + "=" * 50)
self.stdout.write(self.style.SUCCESS(f"Total callbacks: {total_callbacks}"))
if not dry_run:
self.stdout.write(self.style.SUCCESS(f'Successful: {total_success}'))
self.stdout.write(self.style.SUCCESS(f"Successful: {total_success}"))
self.stdout.write(
self.style.ERROR(f'Failed: {total_failures}') if total_failures
else self.style.SUCCESS(f'Failed: {total_failures}')
self.style.ERROR(f"Failed: {total_failures}")
if total_failures
else self.style.SUCCESS(f"Failed: {total_failures}")
)
# Show monitoring stats if available
if not dry_run:
self.stdout.write(self.style.WARNING('\nRecent Executions:'))
self.stdout.write(self.style.WARNING("\nRecent Executions:"))
recent = callback_monitor.get_recent_executions(limit=10)
for record in recent:
status = '' if record.success else ''
self.stdout.write(
f' {status} {record.callback_name} [{record.duration_ms:.2f}ms]'
)
status = "" if record.success else ""
self.stdout.write(f" {status} {record.callback_name} [{record.duration_ms:.2f}ms]")
def _find_model(self, model_name):
"""Find a model class by name."""
@@ -217,9 +205,7 @@ class Command(BaseCommand):
try:
return model_class.objects.get(pk=instance_id)
except model_class.DoesNotExist:
raise CommandError(
f"{model_class.__name__} with ID {instance_id} not found"
)
raise CommandError(f"{model_class.__name__} with ID {instance_id} not found") from None
# Create a mock instance for testing
# This won't be saved to the database
@@ -227,8 +213,6 @@ class Command(BaseCommand):
instance.pk = 0 # Fake ID
setattr(instance, field_name, source)
self.stdout.write(self.style.NOTICE(
'Using mock instance (no --instance-id provided)'
))
self.stdout.write(self.style.NOTICE("Using mock instance (no --instance-id provided)"))
return instance

View File

@@ -37,9 +37,7 @@ class Command(BaseCommand):
self.test_trending_algorithm()
self.test_api_format()
self.stdout.write(
self.style.SUCCESS("✓ Trending system test completed successfully!")
)
self.stdout.write(self.style.SUCCESS("✓ Trending system test completed successfully!"))
def clean_test_data(self):
"""Clean existing test data."""
@@ -101,9 +99,7 @@ class Command(BaseCommand):
# Create parks
parks = []
for park_data in parks_data:
park, created = Park.objects.get_or_create(
name=park_data["name"], defaults=park_data
)
park, created = Park.objects.get_or_create(name=park_data["name"], defaults=park_data)
parks.append(park)
if created and self.verbose:
self.stdout.write(f" Created park: {park.name}")
@@ -151,9 +147,7 @@ class Command(BaseCommand):
# Create rides
rides = []
for ride_data in rides_data:
ride, created = Ride.objects.get_or_create(
name=ride_data["name"], defaults=ride_data
)
ride, created = Ride.objects.get_or_create(name=ride_data["name"], defaults=ride_data)
rides.append(ride)
if created and self.verbose:
self.stdout.write(f" Created ride: {ride.name}")
@@ -169,48 +163,34 @@ class Command(BaseCommand):
# Pattern 1: Recently trending item (Steel Vengeance)
steel_vengeance = next(r for r in rides if r.name == "Steel Vengeance")
self.create_views_for_content(
steel_vengeance, recent_views=50, older_views=10, base_time=now
)
self.create_views_for_content(steel_vengeance, recent_views=50, older_views=10, base_time=now)
# Pattern 2: Consistently popular item (Space Mountain)
space_mountain = next(r for r in rides if r.name == "Space Mountain")
self.create_views_for_content(
space_mountain, recent_views=30, older_views=25, base_time=now
)
self.create_views_for_content(space_mountain, recent_views=30, older_views=25, base_time=now)
# Pattern 3: Declining popularity (Kingda Ka)
kingda_ka = next(r for r in rides if r.name == "Kingda Ka")
self.create_views_for_content(
kingda_ka, recent_views=5, older_views=40, base_time=now
)
self.create_views_for_content(kingda_ka, recent_views=5, older_views=40, base_time=now)
# Pattern 4: New but growing (Millennium Force)
millennium_force = next(r for r in rides if r.name == "Millennium Force")
self.create_views_for_content(
millennium_force, recent_views=25, older_views=5, base_time=now
)
self.create_views_for_content(millennium_force, recent_views=25, older_views=5, base_time=now)
# Create some park views too
cedar_point = next(p for p in parks if p.name == "Cedar Point")
self.create_views_for_content(
cedar_point, recent_views=35, older_views=20, base_time=now
)
self.create_views_for_content(cedar_point, recent_views=35, older_views=20, base_time=now)
if self.verbose:
self.stdout.write(" Created PageView data for trending analysis")
def create_views_for_content(
self, content_object, recent_views, older_views, base_time
):
def create_views_for_content(self, content_object, recent_views, older_views, base_time):
"""Create PageViews for a content object with specified patterns."""
content_type = ContentType.objects.get_for_model(type(content_object))
# Create recent views (last 2 hours)
for _i in range(recent_views):
view_time = base_time - timedelta(
minutes=random.randint(0, 120) # Last 2 hours
)
view_time = base_time - timedelta(minutes=random.randint(0, 120)) # Last 2 hours
PageView.objects.create(
content_type=content_type,
object_id=content_object.id,
@@ -235,15 +215,9 @@ class Command(BaseCommand):
self.stdout.write("Testing trending algorithm...")
# Test trending content for different content types
trending_parks = trending_service.get_trending_content(
content_type="parks", limit=3
)
trending_rides = trending_service.get_trending_content(
content_type="rides", limit=3
)
trending_all = trending_service.get_trending_content(
content_type="all", limit=5
)
trending_parks = trending_service.get_trending_content(content_type="parks", limit=3)
trending_rides = trending_service.get_trending_content(content_type="rides", limit=3)
trending_all = trending_service.get_trending_content(content_type="all", limit=5)
# Test new content
new_parks = trending_service.get_new_content(content_type="parks", limit=3)
@@ -265,12 +239,8 @@ class Command(BaseCommand):
self.stdout.write("Testing API response format...")
# Test trending content format
trending_parks = trending_service.get_trending_content(
content_type="parks", limit=3
)
trending_service.get_trending_content(
content_type="rides", limit=3
)
trending_parks = trending_service.get_trending_content(content_type="parks", limit=3)
trending_service.get_trending_content(content_type="rides", limit=3)
# Test new content format
new_parks = trending_service.get_new_content(content_type="parks", limit=3)

View File

@@ -94,13 +94,21 @@ class Command(BaseCommand):
try:
parks_list = list(
Park.objects.select_related("location", "operator")
.only("id", "name", "slug", "status", "location__city", "location__state_province", "location__country")
.only(
"id",
"name",
"slug",
"status",
"location__city",
"location__state_province",
"location__country",
)
.order_by("name")[:500]
)
cache_service.default_cache.set(
"warm:park_list",
[{"id": p.id, "name": p.name, "slug": p.slug} for p in parks_list],
timeout=3600
timeout=3600,
)
warmed_count += 1
if verbose:
@@ -116,11 +124,7 @@ class Command(BaseCommand):
if not dry_run:
try:
status_counts = Park.objects.values("status").annotate(count=Count("id"))
cache_service.default_cache.set(
"warm:park_status_counts",
list(status_counts),
timeout=3600
)
cache_service.default_cache.set("warm:park_status_counts", list(status_counts), timeout=3600)
warmed_count += 1
if verbose:
self.stdout.write(" Cached park status counts")
@@ -141,8 +145,11 @@ class Command(BaseCommand):
)
cache_service.default_cache.set(
"warm:popular_parks",
[{"id": p.id, "name": p.name, "slug": p.slug, "ride_count": p.ride_count} for p in popular_parks],
timeout=3600
[
{"id": p.id, "name": p.name, "slug": p.slug, "ride_count": p.ride_count}
for p in popular_parks
],
timeout=3600,
)
warmed_count += 1
if verbose:
@@ -168,8 +175,11 @@ class Command(BaseCommand):
)
cache_service.default_cache.set(
"warm:ride_list",
[{"id": r.id, "name": r.name, "slug": r.slug, "park": r.park.name if r.park else None} for r in rides_list],
timeout=3600
[
{"id": r.id, "name": r.name, "slug": r.slug, "park": r.park.name if r.park else None}
for r in rides_list
],
timeout=3600,
)
warmed_count += 1
if verbose:
@@ -185,11 +195,7 @@ class Command(BaseCommand):
if not dry_run:
try:
category_counts = Ride.objects.values("category").annotate(count=Count("id"))
cache_service.default_cache.set(
"warm:ride_category_counts",
list(category_counts),
timeout=3600
)
cache_service.default_cache.set("warm:ride_category_counts", list(category_counts), timeout=3600)
warmed_count += 1
if verbose:
self.stdout.write(" Cached ride category counts")
@@ -210,8 +216,16 @@ class Command(BaseCommand):
)
cache_service.default_cache.set(
"warm:top_rated_rides",
[{"id": r.id, "name": r.name, "slug": r.slug, "rating": float(r.average_rating) if r.average_rating else None} for r in top_rides],
timeout=3600
[
{
"id": r.id,
"name": r.name,
"slug": r.slug,
"rating": float(r.average_rating) if r.average_rating else None,
}
for r in top_rides
],
timeout=3600,
)
warmed_count += 1
if verbose:
@@ -231,12 +245,9 @@ class Command(BaseCommand):
try:
# Park filter metadata
from apps.parks.services.hybrid_loader import smart_park_loader
metadata = smart_park_loader.get_filter_metadata()
cache_service.default_cache.set(
"warm:park_filter_metadata",
metadata,
timeout=1800
)
cache_service.default_cache.set("warm:park_filter_metadata", metadata, timeout=1800)
warmed_count += 1
if verbose:
self.stdout.write(" Cached park filter metadata")
@@ -251,13 +262,10 @@ class Command(BaseCommand):
try:
# Ride filter metadata
from apps.rides.services.hybrid_loader import SmartRideLoader
ride_loader = SmartRideLoader()
metadata = ride_loader.get_filter_metadata()
cache_service.default_cache.set(
"warm:ride_filter_metadata",
metadata,
timeout=1800
)
cache_service.default_cache.set("warm:ride_filter_metadata", metadata, timeout=1800)
warmed_count += 1
if verbose:
self.stdout.write(" Cached ride filter metadata")

View File

@@ -92,9 +92,7 @@ class LocationQuerySet(BaseQuerySet):
"""Filter locations near a geographic point."""
if hasattr(self.model, "point"):
return (
self.filter(point__distance_lte=(point, Distance(km=distance_km)))
.distance(point)
.order_by("distance")
self.filter(point__distance_lte=(point, Distance(km=distance_km))).distance(point).order_by("distance")
)
return self
@@ -138,9 +136,7 @@ class LocationManager(BaseManager):
return self.get_queryset().near_point(point=point, distance_km=distance_km)
def within_bounds(self, *, north: float, south: float, east: float, west: float):
return self.get_queryset().within_bounds(
north=north, south=south, east=east, west=west
)
return self.get_queryset().within_bounds(north=north, south=south, east=east, west=west)
class ReviewableQuerySet(BaseQuerySet):
@@ -151,9 +147,7 @@ class ReviewableQuerySet(BaseQuerySet):
return self.annotate(
review_count=Count("reviews", filter=Q(reviews__is_published=True)),
average_rating=Avg("reviews__rating", filter=Q(reviews__is_published=True)),
latest_review_date=Max(
"reviews__created_at", filter=Q(reviews__is_published=True)
),
latest_review_date=Max("reviews__created_at", filter=Q(reviews__is_published=True)),
)
def highly_rated(self, *, min_rating: float = 8.0):
@@ -163,9 +157,7 @@ class ReviewableQuerySet(BaseQuerySet):
def recently_reviewed(self, *, days: int = 30):
"""Filter for items with recent reviews."""
cutoff_date = timezone.now() - timedelta(days=days)
return self.filter(
reviews__created_at__gte=cutoff_date, reviews__is_published=True
).distinct()
return self.filter(reviews__created_at__gte=cutoff_date, reviews__is_published=True).distinct()
class ReviewableManager(BaseManager):
@@ -237,9 +229,7 @@ class TimestampedManager(BaseManager):
return TimestampedQuerySet(self.model, using=self._db)
def created_between(self, *, start_date, end_date):
return self.get_queryset().created_between(
start_date=start_date, end_date=end_date
)
return self.get_queryset().created_between(start_date=start_date, end_date=end_date)
class StatusQuerySet(BaseQuerySet):

View File

@@ -16,16 +16,10 @@ class RequestContextProvider(pghistory.context):
def __call__(self, request: WSGIRequest) -> dict:
return {
"user": (
str(request.user)
if request.user and not isinstance(request.user, AnonymousUser)
else None
),
"user": (str(request.user) if request.user and not isinstance(request.user, AnonymousUser) else None),
"ip": request.META.get("REMOTE_ADDR"),
"user_agent": request.META.get("HTTP_USER_AGENT"),
"session_key": (
request.session.session_key if hasattr(request, "session") else None
),
"session_key": (request.session.session_key if hasattr(request, "session") else None),
}

View File

@@ -1,6 +1,7 @@
"""
Middleware for handling errors in HTMX requests.
"""
import logging
from django.http import HttpResponseServerError

View File

@@ -38,12 +38,8 @@ class APIResponseMiddleware(MiddlewareMixin):
response["Vary"] = "Origin"
# Helpful dev CORS headers (adjust for your frontend requests)
response["Access-Control-Allow-Methods"] = (
"GET, POST, PUT, PATCH, DELETE, OPTIONS"
)
response["Access-Control-Allow-Headers"] = (
"Authorization, Content-Type, X-Requested-With"
)
response["Access-Control-Allow-Methods"] = "GET, POST, PUT, PATCH, DELETE, OPTIONS"
response["Access-Control-Allow-Headers"] = "Authorization, Content-Type, X-Requested-With"
# Uncomment if your dev frontend needs to send cookies/auth credentials
# response['Access-Control-Allow-Credentials'] = 'true'
logger.debug(f"Added CORS headers for origin: {origin}")

View File

@@ -19,9 +19,7 @@ class PerformanceMiddleware(MiddlewareMixin):
def process_request(self, request):
"""Initialize performance tracking for the request"""
request._performance_start_time = time.time()
request._performance_initial_queries = (
len(connection.queries) if hasattr(connection, "queries") else 0
)
request._performance_initial_queries = len(connection.queries) if hasattr(connection, "queries") else 0
def process_response(self, request, response):
"""Log performance metrics after response is ready"""
@@ -42,11 +40,7 @@ class PerformanceMiddleware(MiddlewareMixin):
duration = end_time - start_time
initial_queries = getattr(request, "_performance_initial_queries", 0)
total_queries = (
len(connection.queries) - initial_queries
if hasattr(connection, "queries")
else 0
)
total_queries = len(connection.queries) - initial_queries if hasattr(connection, "queries") else 0
# Get content length
content_length = 0
@@ -70,9 +64,7 @@ class PerformanceMiddleware(MiddlewareMixin):
if hasattr(request, "user") and request.user.is_authenticated
else None
),
"user_agent": request.META.get("HTTP_USER_AGENT", "")[
:100
], # Truncate user agent
"user_agent": request.META.get("HTTP_USER_AGENT", "")[:100], # Truncate user agent
"remote_addr": self._get_client_ip(request),
}
@@ -81,11 +73,7 @@ class PerformanceMiddleware(MiddlewareMixin):
recent_queries = connection.queries[-total_queries:]
performance_data["queries"] = [
{
"sql": (
query["sql"][:200] + "..."
if len(query["sql"]) > 200
else query["sql"]
),
"sql": (query["sql"][:200] + "..." if len(query["sql"]) > 200 else query["sql"]),
"time": float(query["time"]),
}
for query in recent_queries[-10:] # Last 10 queries only
@@ -95,9 +83,7 @@ class PerformanceMiddleware(MiddlewareMixin):
slow_queries = [q for q in recent_queries if float(q["time"]) > 0.1]
if slow_queries:
performance_data["slow_query_count"] = len(slow_queries)
performance_data["slowest_query_time"] = max(
float(q["time"]) for q in slow_queries
)
performance_data["slowest_query_time"] = max(float(q["time"]) for q in slow_queries)
# Determine log level based on performance
log_level = self._get_log_level(duration, total_queries, response.status_code)
@@ -115,9 +101,7 @@ class PerformanceMiddleware(MiddlewareMixin):
response["X-Response-Time"] = f"{duration * 1000:.2f}ms"
response["X-Query-Count"] = str(total_queries)
if total_queries > 0 and hasattr(connection, "queries"):
total_query_time = sum(
float(q["time"]) for q in connection.queries[-total_queries:]
)
total_query_time = sum(float(q["time"]) for q in connection.queries[-total_queries:])
response["X-Query-Time"] = f"{total_query_time * 1000:.2f}ms"
return response
@@ -129,11 +113,7 @@ class PerformanceMiddleware(MiddlewareMixin):
duration = end_time - start_time
initial_queries = getattr(request, "_performance_initial_queries", 0)
total_queries = (
len(connection.queries) - initial_queries
if hasattr(connection, "queries")
else 0
)
total_queries = len(connection.queries) - initial_queries if hasattr(connection, "queries") else 0
performance_data = {
"path": request.path,
@@ -195,9 +175,7 @@ class QueryCountMiddleware(MiddlewareMixin):
def process_request(self, request):
"""Initialize query tracking"""
request._query_count_start = (
len(connection.queries) if hasattr(connection, "queries") else 0
)
request._query_count_start = len(connection.queries) if hasattr(connection, "queries") else 0
def process_response(self, request, response):
"""Check query count and warn if excessive"""
@@ -267,9 +245,7 @@ class CachePerformanceMiddleware(MiddlewareMixin):
def process_response(self, request, response):
"""Log cache performance metrics"""
cache_duration = time.time() - getattr(
request, "_cache_start_time", time.time()
)
cache_duration = time.time() - getattr(request, "_cache_start_time", time.time())
cache_hits = getattr(request, "_cache_hits", 0)
cache_misses = getattr(request, "_cache_misses", 0)

View File

@@ -35,20 +35,17 @@ class AuthRateLimitMiddleware:
# Endpoints to rate limit
RATE_LIMITED_PATHS = {
# Login endpoints
'/api/v1/auth/login/': {'per_minute': 5, 'per_hour': 30, 'per_day': 100},
'/accounts/login/': {'per_minute': 5, 'per_hour': 30, 'per_day': 100},
"/api/v1/auth/login/": {"per_minute": 5, "per_hour": 30, "per_day": 100},
"/accounts/login/": {"per_minute": 5, "per_hour": 30, "per_day": 100},
# Signup endpoints
'/api/v1/auth/signup/': {'per_minute': 3, 'per_hour': 10, 'per_day': 20},
'/accounts/signup/': {'per_minute': 3, 'per_hour': 10, 'per_day': 20},
"/api/v1/auth/signup/": {"per_minute": 3, "per_hour": 10, "per_day": 20},
"/accounts/signup/": {"per_minute": 3, "per_hour": 10, "per_day": 20},
# Password reset endpoints
'/api/v1/auth/password-reset/': {'per_minute': 2, 'per_hour': 5, 'per_day': 10},
'/accounts/password/reset/': {'per_minute': 2, 'per_hour': 5, 'per_day': 10},
"/api/v1/auth/password-reset/": {"per_minute": 2, "per_hour": 5, "per_day": 10},
"/accounts/password/reset/": {"per_minute": 2, "per_hour": 5, "per_day": 10},
# Token endpoints
'/api/v1/auth/token/': {'per_minute': 10, 'per_hour': 60, 'per_day': 200},
'/api/v1/auth/token/refresh/': {'per_minute': 20, 'per_hour': 120, 'per_day': 500},
"/api/v1/auth/token/": {"per_minute": 10, "per_hour": 60, "per_day": 200},
"/api/v1/auth/token/refresh/": {"per_minute": 20, "per_hour": 120, "per_day": 500},
}
def __init__(self, get_response: Callable[[HttpRequest], HttpResponse]):
@@ -56,7 +53,7 @@ class AuthRateLimitMiddleware:
def __call__(self, request: HttpRequest) -> HttpResponse:
# Only rate limit POST requests to auth endpoints
if request.method != 'POST':
if request.method != "POST":
return self.get_response(request)
# Check if this path should be rate limited
@@ -68,14 +65,10 @@ class AuthRateLimitMiddleware:
client_ip = self._get_client_ip(request)
# Check rate limits
is_allowed, message = self._check_rate_limits(
client_ip, request.path, limits
)
is_allowed, message = self._check_rate_limits(client_ip, request.path, limits)
if not is_allowed:
logger.warning(
f"Rate limit exceeded for {client_ip} on {request.path}"
)
logger.warning(f"Rate limit exceeded for {client_ip} on {request.path}")
return self._rate_limit_response(message)
# Process request
@@ -94,9 +87,9 @@ class AuthRateLimitMiddleware:
return self.RATE_LIMITED_PATHS[path]
# Prefix match (for paths with trailing slashes)
path_without_slash = path.rstrip('/')
path_without_slash = path.rstrip("/")
for limited_path, limits in self.RATE_LIMITED_PATHS.items():
if path_without_slash == limited_path.rstrip('/'):
if path_without_slash == limited_path.rstrip("/"):
return limits
return None
@@ -108,23 +101,18 @@ class AuthRateLimitMiddleware:
Handles common proxy headers (X-Forwarded-For, X-Real-IP).
"""
# Check for forwarded headers (set by reverse proxies)
x_forwarded_for = request.META.get('HTTP_X_FORWARDED_FOR')
x_forwarded_for = request.META.get("HTTP_X_FORWARDED_FOR")
if x_forwarded_for:
# Take the first IP in the chain (client IP)
return x_forwarded_for.split(',')[0].strip()
return x_forwarded_for.split(",")[0].strip()
x_real_ip = request.META.get('HTTP_X_REAL_IP')
x_real_ip = request.META.get("HTTP_X_REAL_IP")
if x_real_ip:
return x_real_ip
return request.META.get('REMOTE_ADDR', 'unknown')
return request.META.get("REMOTE_ADDR", "unknown")
def _check_rate_limits(
self,
client_ip: str,
path: str,
limits: dict
) -> tuple[bool, str]:
def _check_rate_limits(self, client_ip: str, path: str, limits: dict) -> tuple[bool, str]:
"""
Check if the client has exceeded rate limits.
@@ -132,31 +120,31 @@ class AuthRateLimitMiddleware:
Tuple of (is_allowed, reason_if_blocked)
"""
# Create a safe cache key from path
path_key = path.replace('/', '_').strip('_')
path_key = path.replace("/", "_").strip("_")
# Check per-minute limit
minute_key = f"auth_rate:{client_ip}:{path_key}:minute"
minute_count = cache.get(minute_key, 0)
if minute_count >= limits.get('per_minute', 10):
if minute_count >= limits.get("per_minute", 10):
return False, "Too many requests. Please wait a minute before trying again."
# Check per-hour limit
hour_key = f"auth_rate:{client_ip}:{path_key}:hour"
hour_count = cache.get(hour_key, 0)
if hour_count >= limits.get('per_hour', 60):
if hour_count >= limits.get("per_hour", 60):
return False, "Too many requests. Please try again later."
# Check per-day limit
day_key = f"auth_rate:{client_ip}:{path_key}:day"
day_count = cache.get(day_key, 0)
if day_count >= limits.get('per_day', 200):
if day_count >= limits.get("per_day", 200):
return False, "Daily limit exceeded. Please try again tomorrow."
return True, ""
def _increment_counters(self, client_ip: str, path: str) -> None:
"""Increment rate limit counters."""
path_key = path.replace('/', '_').strip('_')
path_key = path.replace("/", "_").strip("_")
# Increment per-minute counter
minute_key = f"auth_rate:{client_ip}:{path_key}:minute"
@@ -183,8 +171,8 @@ class AuthRateLimitMiddleware:
"""Generate a rate limit exceeded response."""
return JsonResponse(
{
'error': message,
'code': 'RATE_LIMIT_EXCEEDED',
"error": message,
"code": "RATE_LIMIT_EXCEEDED",
},
status=429, # Too Many Requests
)
@@ -201,50 +189,31 @@ class SecurityEventLogger:
"""
@staticmethod
def log_failed_login(
request: HttpRequest,
username: str,
reason: str = "Invalid credentials"
) -> None:
def log_failed_login(request: HttpRequest, username: str, reason: str = "Invalid credentials") -> None:
"""Log a failed login attempt."""
client_ip = AuthRateLimitMiddleware._get_client_ip(
AuthRateLimitMiddleware, request
)
client_ip = AuthRateLimitMiddleware._get_client_ip(AuthRateLimitMiddleware, request)
logger.warning(
f"Failed login attempt - IP: {client_ip}, Username: {username}, "
f"Reason: {reason}, User-Agent: {request.META.get('HTTP_USER_AGENT', 'unknown')}"
)
@staticmethod
def log_permission_denied(
request: HttpRequest,
resource: str,
action: str = "access"
) -> None:
def log_permission_denied(request: HttpRequest, resource: str, action: str = "access") -> None:
"""Log a permission denied event."""
client_ip = AuthRateLimitMiddleware._get_client_ip(
AuthRateLimitMiddleware, request
)
user = getattr(request, 'user', None)
username = user.username if user and user.is_authenticated else 'anonymous'
client_ip = AuthRateLimitMiddleware._get_client_ip(AuthRateLimitMiddleware, request)
user = getattr(request, "user", None)
username = user.username if user and user.is_authenticated else "anonymous"
logger.warning(
f"Permission denied - IP: {client_ip}, User: {username}, "
f"Resource: {resource}, Action: {action}"
f"Permission denied - IP: {client_ip}, User: {username}, " f"Resource: {resource}, Action: {action}"
)
@staticmethod
def log_suspicious_activity(
request: HttpRequest,
activity_type: str,
details: str = ""
) -> None:
def log_suspicious_activity(request: HttpRequest, activity_type: str, details: str = "") -> None:
"""Log suspicious activity."""
client_ip = AuthRateLimitMiddleware._get_client_ip(
AuthRateLimitMiddleware, request
)
user = getattr(request, 'user', None)
username = user.username if user and user.is_authenticated else 'anonymous'
client_ip = AuthRateLimitMiddleware._get_client_ip(AuthRateLimitMiddleware, request)
user = getattr(request, "user", None)
username = user.username if user and user.is_authenticated else "anonymous"
logger.error(
f"Suspicious activity detected - Type: {activity_type}, "

View File

@@ -9,7 +9,7 @@ import time
from django.utils.deprecation import MiddlewareMixin
logger = logging.getLogger('request_logging')
logger = logging.getLogger("request_logging")
class RequestLoggingMiddleware(MiddlewareMixin):
@@ -20,17 +20,16 @@ class RequestLoggingMiddleware(MiddlewareMixin):
# Paths to exclude from detailed logging (e.g., static files, health checks)
EXCLUDE_DETAILED_LOGGING_PATHS = [
'/static/',
'/media/',
'/favicon.ico',
'/health/',
'/admin/jsi18n/',
"/static/",
"/media/",
"/favicon.ico",
"/health/",
"/admin/jsi18n/",
]
def _should_log_detailed(self, request):
"""Determine if detailed logging should be enabled for this request."""
return not any(
path in request.path for path in self.EXCLUDE_DETAILED_LOGGING_PATHS)
return not any(path in request.path for path in self.EXCLUDE_DETAILED_LOGGING_PATHS)
def process_request(self, request):
"""Store request start time and capture request data for detailed logging."""
@@ -44,14 +43,17 @@ class RequestLoggingMiddleware(MiddlewareMixin):
try:
# Log request data
request_data = {}
if hasattr(request, 'data') and request.data:
if hasattr(request, "data") and request.data:
request_data = dict(request.data)
elif request.body:
try:
request_data = json.loads(request.body.decode('utf-8'))
request_data = json.loads(request.body.decode("utf-8"))
except (json.JSONDecodeError, UnicodeDecodeError):
request_data = {'body': str(request.body)[
:200] + '...' if len(str(request.body)) > 200 else str(request.body)}
request_data = {
"body": (
str(request.body)[:200] + "..." if len(str(request.body)) > 200 else str(request.body)
)
}
# Log query parameters
query_params = dict(request.GET) if request.GET else {}
@@ -61,9 +63,8 @@ class RequestLoggingMiddleware(MiddlewareMixin):
logger.info(f" Body: {self._safe_log_data(request_data)}")
if query_params:
logger.info(f" Query: {query_params}")
if hasattr(request, 'user') and request.user.is_authenticated:
logger.info(
f" User: {request.user.username} (ID: {request.user.id})")
if hasattr(request, "user") and request.user.is_authenticated:
logger.info(f" User: {request.user.username} (ID: {request.user.id})")
except Exception as e:
logger.warning(f"Failed to log request data: {e}")
@@ -75,34 +76,28 @@ class RequestLoggingMiddleware(MiddlewareMixin):
try:
# Calculate request duration
duration = 0
if hasattr(request, '_start_time'):
if hasattr(request, "_start_time"):
duration = time.time() - request._start_time
# Basic request logging
logger.info(
f"{request.method} {request.get_full_path()} -> {response.status_code} "
f"({duration:.3f}s)"
)
logger.info(f"{request.method} {request.get_full_path()} -> {response.status_code} " f"({duration:.3f}s)")
# Detailed response logging for specific endpoints
if getattr(request, '_log_request_data', False):
if getattr(request, "_log_request_data", False):
try:
# Log response data
if hasattr(response, 'data'):
logger.info(
f"RESPONSE DATA for {request.method} {request.path}:")
if hasattr(response, "data"):
logger.info(f"RESPONSE DATA for {request.method} {request.path}:")
logger.info(f" Status: {response.status_code}")
logger.info(f" Data: {self._safe_log_data(response.data)}")
elif hasattr(response, 'content'):
elif hasattr(response, "content"):
try:
content = json.loads(response.content.decode('utf-8'))
logger.info(
f"RESPONSE DATA for {request.method} {request.path}:")
content = json.loads(response.content.decode("utf-8"))
logger.info(f"RESPONSE DATA for {request.method} {request.path}:")
logger.info(f" Status: {response.status_code}")
logger.info(f" Content: {self._safe_log_data(content)}")
except (json.JSONDecodeError, UnicodeDecodeError):
logger.info(
f"RESPONSE DATA for {request.method} {request.path}:")
logger.info(f"RESPONSE DATA for {request.method} {request.path}:")
logger.info(f" Status: {response.status_code}")
logger.info(f" Content: {str(response.content)[:200]}...")
@@ -118,31 +113,31 @@ class RequestLoggingMiddleware(MiddlewareMixin):
# Sensitive field patterns that should be masked in logs
# Security: Comprehensive list of sensitive data patterns
SENSITIVE_PATTERNS = [
'password',
'passwd',
'pwd',
'token',
'secret',
'key',
'api_key',
'apikey',
'auth',
'authorization',
'credential',
'ssn',
'social_security',
'credit_card',
'creditcard',
'card_number',
'cvv',
'cvc',
'pin',
'access_token',
'refresh_token',
'jwt',
'session',
'cookie',
'private',
"password",
"passwd",
"pwd",
"token",
"secret",
"key",
"api_key",
"apikey",
"auth",
"authorization",
"credential",
"ssn",
"social_security",
"credit_card",
"creditcard",
"card_number",
"cvv",
"cvc",
"pin",
"access_token",
"refresh_token",
"jwt",
"session",
"cookie",
"private",
]
def _safe_log_data(self, data):
@@ -167,15 +162,15 @@ class RequestLoggingMiddleware(MiddlewareMixin):
# Truncate if too long
if len(data_str) > 1000:
return data_str[:1000] + '...[TRUNCATED]'
return data_str[:1000] + "...[TRUNCATED]"
return data_str
except Exception:
return str(data)[:500] + '...[ERROR_LOGGING]'
return str(data)[:500] + "...[ERROR_LOGGING]"
def _mask_sensitive_dict(self, data, depth=0):
"""Recursively mask sensitive fields in a dictionary."""
if depth > 5: # Prevent infinite recursion
return '***DEPTH_LIMIT***'
return "***DEPTH_LIMIT***"
safe_data = {}
for key, value in data.items():
@@ -183,7 +178,7 @@ class RequestLoggingMiddleware(MiddlewareMixin):
# Check if key contains any sensitive pattern
if any(pattern in key_lower for pattern in self.SENSITIVE_PATTERNS):
safe_data[key] = '***MASKED***'
safe_data[key] = "***MASKED***"
else:
safe_data[key] = self._mask_sensitive_value(value, depth)
@@ -197,11 +192,11 @@ class RequestLoggingMiddleware(MiddlewareMixin):
return [self._mask_sensitive_value(item, depth + 1) for item in value[:10]] # Limit list items
elif isinstance(value, str):
# Mask email addresses (show only domain)
if '@' in value and '.' in value.split('@')[-1]:
parts = value.split('@')
if "@" in value and "." in value.split("@")[-1]:
parts = value.split("@")
if len(parts) == 2:
return f"***@{parts[1]}"
# Truncate long strings
if len(value) > 200:
return value[:200] + '...[TRUNCATED]'
return value[:200] + "...[TRUNCATED]"
return value

View File

@@ -49,9 +49,7 @@ class SecurityHeadersMiddleware:
if not response.get("Content-Security-Policy"):
response["Content-Security-Policy"] = self._csp_header
else:
logger.warning(
f"CSP header already present for {request.path}, skipping"
)
logger.warning(f"CSP header already present for {request.path}, skipping")
# Permissions-Policy (successor to Feature-Policy)
if not response.get("Permissions-Policy"):
@@ -144,11 +142,13 @@ class SecurityHeadersMiddleware:
# Add debug-specific relaxations
if debug:
# Allow webpack dev server connections in development
directives["connect-src"].extend([
"ws://localhost:*",
"http://localhost:*",
"http://127.0.0.1:*",
])
directives["connect-src"].extend(
[
"ws://localhost:*",
"http://localhost:*",
"http://127.0.0.1:*",
]
)
# Build header string
parts = []
@@ -168,30 +168,34 @@ class SecurityHeadersMiddleware:
This header controls which browser features the page can use.
"""
# Get permissions policy from settings or use defaults
policy = getattr(settings, "PERMISSIONS_POLICY", {
"accelerometer": [],
"ambient-light-sensor": [],
"autoplay": [],
"camera": [],
"display-capture": [],
"document-domain": [],
"encrypted-media": [],
"fullscreen": ["self"],
"geolocation": ["self"],
"gyroscope": [],
"interest-cohort": [],
"magnetometer": [],
"microphone": [],
"midi": [],
"payment": [],
"picture-in-picture": [],
"publickey-credentials-get": [],
"screen-wake-lock": [],
"sync-xhr": [],
"usb": [],
"web-share": ["self"],
"xr-spatial-tracking": [],
})
policy = getattr(
settings,
"PERMISSIONS_POLICY",
{
"accelerometer": [],
"ambient-light-sensor": [],
"autoplay": [],
"camera": [],
"display-capture": [],
"document-domain": [],
"encrypted-media": [],
"fullscreen": ["self"],
"geolocation": ["self"],
"gyroscope": [],
"interest-cohort": [],
"magnetometer": [],
"microphone": [],
"midi": [],
"payment": [],
"picture-in-picture": [],
"publickey-credentials-get": [],
"screen-wake-lock": [],
"sync-xhr": [],
"usb": [],
"web-share": ["self"],
"xr-spatial-tracking": [],
},
)
parts = []
for feature, allowlist in policy.items():

View File

@@ -9,7 +9,6 @@ analytics for the trending algorithm.
import logging
import re
from datetime import timedelta
from typing import Union
from django.conf import settings
from django.contrib.contenttypes.models import ContentType
@@ -22,7 +21,7 @@ from apps.parks.models import Park
from apps.rides.models import Ride
# Type alias for content objects
ContentObject = Union[Park, Ride]
ContentObject = Park | Ride
logger = logging.getLogger(__name__)
@@ -50,8 +49,7 @@ class ViewTrackingMiddleware:
# Compile patterns for performance
self.compiled_patterns = [
(re.compile(pattern), content_type)
for pattern, content_type in self.tracked_patterns
(re.compile(pattern), content_type) for pattern, content_type in self.tracked_patterns
]
# Cache configuration
@@ -63,11 +61,7 @@ class ViewTrackingMiddleware:
response = self.get_response(request)
# Only track successful GET requests
if (
request.method == "GET"
and 200 <= response.status_code < 300
and not self._should_skip_tracking(request)
):
if request.method == "GET" and 200 <= response.status_code < 300 and not self._should_skip_tracking(request):
try:
self._track_view_if_applicable(request)
except Exception as e:
@@ -119,9 +113,7 @@ class ViewTrackingMiddleware:
self._record_page_view(request, content_type, slug)
break
def _record_page_view(
self, request: HttpRequest, content_type: str, slug: str
) -> None:
def _record_page_view(self, request: HttpRequest, content_type: str, slug: str) -> None:
"""Record a page view for the specified content."""
client_ip = self._get_client_ip(request)
if not client_ip:
@@ -131,33 +123,23 @@ class ViewTrackingMiddleware:
# Get the content object
content_obj = self._get_content_object(content_type, slug)
if not content_obj:
self.logger.warning(
f"Content not found: {content_type} with slug '{slug}'"
)
self.logger.warning(f"Content not found: {content_type} with slug '{slug}'")
return
# Check deduplication
if self._is_duplicate_view(content_obj, client_ip):
self.logger.debug(
f"Duplicate view skipped for {content_type} {slug} from {client_ip}"
)
self.logger.debug(f"Duplicate view skipped for {content_type} {slug} from {client_ip}")
return
# Create PageView record
self._create_page_view(content_obj, client_ip, request)
self.logger.debug(
f"Recorded view for {content_type} {slug} from {client_ip}"
)
self.logger.debug(f"Recorded view for {content_type} {slug} from {client_ip}")
except Exception as e:
self.logger.error(
f"Failed to record page view for {content_type} {slug}: {e}"
)
self.logger.error(f"Failed to record page view for {content_type} {slug}: {e}")
def _get_content_object(
self, content_type: str, slug: str
) -> ContentObject | None:
def _get_content_object(self, content_type: str, slug: str) -> ContentObject | None:
"""Get the content object by type and slug."""
try:
if content_type == "park":
@@ -202,16 +184,12 @@ class ViewTrackingMiddleware:
return existing_view
def _create_page_view(
self, content_obj: ContentObject, client_ip: str, request: HttpRequest
) -> None:
def _create_page_view(self, content_obj: ContentObject, client_ip: str, request: HttpRequest) -> None:
"""Create a new PageView record."""
content_type = ContentType.objects.get_for_model(content_obj)
# Extract additional metadata
user_agent = request.META.get("HTTP_USER_AGENT", "")[
:500
] # Truncate long user agents
user_agent = request.META.get("HTTP_USER_AGENT", "")[:500] # Truncate long user agents
referer = request.META.get("HTTP_REFERER", "")[:500]
PageView.objects.create(
@@ -267,11 +245,9 @@ class ViewTrackingMiddleware:
return False
# Skip localhost and private IPs in production
if getattr(settings, "SKIP_LOCAL_IPS", not settings.DEBUG):
if getattr(settings, "SKIP_LOCAL_IPS", not settings.DEBUG): # noqa: SIM102
if (ip.startswith(("127.", "192.168.", "10.")) or ip.startswith("172.")) and any(
16 <= int(ip.split(".")[1]) <= 31
for _ in [ip]
if ip.startswith("172.")
16 <= int(ip.split(".")[1]) <= 31 for _ in [ip] if ip.startswith("172.")
):
return False

View File

@@ -32,9 +32,7 @@ class Migration(migrations.Migration):
migrations.AlterField(
model_name="slughistory",
name="object_id",
field=models.CharField(
help_text="ID of the object this slug belongs to", max_length=50
),
field=models.CharField(help_text="ID of the object this slug belongs to", max_length=50),
),
migrations.AlterField(
model_name="slughistory",
@@ -56,15 +54,11 @@ class Migration(migrations.Migration):
migrations.AlterField(
model_name="slughistoryevent",
name="object_id",
field=models.CharField(
help_text="ID of the object this slug belongs to", max_length=50
),
field=models.CharField(help_text="ID of the object this slug belongs to", max_length=50),
),
migrations.AlterField(
model_name="slughistoryevent",
name="old_slug",
field=models.SlugField(
db_index=False, help_text="Previous slug value", max_length=200
),
field=models.SlugField(db_index=False, help_text="Previous slug value", max_length=200),
),
]

View File

@@ -58,10 +58,7 @@ class HTMXFormMixin(FormMixin):
def form_valid(self, form):
"""Add HX-Trigger header on successful form submission via HTMX."""
res = super().form_valid(form)
if (
self.request.headers.get("HX-Request") == "true"
and self.htmx_success_trigger
):
if self.request.headers.get("HX-Request") == "true" and self.htmx_success_trigger:
res["HX-Trigger"] = self.htmx_success_trigger
return res

View File

@@ -78,9 +78,7 @@ class SluggedModel(TrackedModel):
Returns the name of the read-only ID field for this model.
Should be overridden by subclasses.
"""
raise NotImplementedError(
"Subclasses of SluggedModel must implement get_id_field_name()"
)
raise NotImplementedError("Subclasses of SluggedModel must implement get_id_field_name()")
@classmethod
def get_by_slug(cls, slug):
@@ -123,4 +121,4 @@ class SluggedModel(TrackedModel):
True,
)
raise cls.DoesNotExist(f"{cls.__name__} with slug '{slug}' does not exist")
raise cls.DoesNotExist(f"{cls.__name__} with slug '{slug}' does not exist") from None

View File

@@ -14,10 +14,11 @@ class IsOwnerOrReadOnly(permissions.BasePermission):
# Write permissions are only allowed to the owner of the object.
# Assumes the model instance has an `user` attribute.
if hasattr(obj, 'user'):
if hasattr(obj, "user"):
return obj.user == request.user
return False
class IsStaffOrReadOnly(permissions.BasePermission):
"""
Custom permission to only allow staff to edit it.

View File

@@ -61,23 +61,20 @@ def unified_locations_for_map(
# Rides
if "ride" in location_types:
ride_queryset = Ride.objects.select_related(
"park", "manufacturer"
).prefetch_related("park__location", "location")
ride_queryset = Ride.objects.select_related("park", "manufacturer").prefetch_related(
"park__location", "location"
)
if bounds:
ride_queryset = ride_queryset.filter(
Q(location__coordinates__within=bounds)
| Q(park__location__coordinates__within=bounds)
Q(location__coordinates__within=bounds) | Q(park__location__coordinates__within=bounds)
)
if filters:
if "category" in filters:
ride_queryset = ride_queryset.filter(category=filters["category"])
if "manufacturer" in filters:
ride_queryset = ride_queryset.filter(
manufacturer=filters["manufacturer"]
)
ride_queryset = ride_queryset.filter(manufacturer=filters["manufacturer"])
if "park" in filters:
ride_queryset = ride_queryset.filter(park=filters["park"])
@@ -248,12 +245,7 @@ def popular_pages_summary(*, days: int = 30) -> dict[str, Any]:
total_views = PageView.objects.filter(timestamp__gte=cutoff_date).count()
# Unique visitors (based on IP)
unique_visitors = (
PageView.objects.filter(timestamp__gte=cutoff_date)
.values("ip_address")
.distinct()
.count()
)
unique_visitors = PageView.objects.filter(timestamp__gte=cutoff_date).values("ip_address").distinct().count()
return {
"popular_pages": list(popular_pages),
@@ -311,14 +303,10 @@ def system_health_metrics() -> dict[str, Any]:
"page_views_7d": PageView.objects.filter(timestamp__gte=last_7d).count(),
"data_freshness": {
"latest_park_update": (
Park.objects.order_by("-updated_at").first().updated_at
if Park.objects.exists()
else None
Park.objects.order_by("-updated_at").first().updated_at if Park.objects.exists() else None
),
"latest_ride_update": (
Ride.objects.order_by("-updated_at").first().updated_at
if Ride.objects.exists()
else None
Ride.objects.order_by("-updated_at").first().updated_at if Ride.objects.exists() else None
),
},
}

View File

@@ -63,9 +63,7 @@ class ClusteringService:
if zoom_level < self.MIN_ZOOM_FOR_CLUSTERING:
return True
config = self.ZOOM_CONFIGS.get(
zoom_level, {"min_points": self.MIN_POINTS_TO_CLUSTER}
)
config = self.ZOOM_CONFIGS.get(zoom_level, {"min_points": self.MIN_POINTS_TO_CLUSTER})
return point_count >= config["min_points"]
def cluster_locations(
@@ -94,9 +92,7 @@ class ClusteringService:
)
# Perform clustering
clustered_groups = self._cluster_points(
cluster_points, config["radius"], config["min_points"]
)
clustered_groups = self._cluster_points(cluster_points, config["radius"], config["min_points"])
# Separate individual locations from clusters
unclustered_locations = []
@@ -135,9 +131,7 @@ class ClusteringService:
# Simple equirectangular projection (good enough for clustering)
center_lat = (bounds.north + bounds.south) / 2
lat_scale = 111320 # meters per degree latitude
lng_scale = 111320 * math.cos(
math.radians(center_lat)
) # meters per degree longitude
lng_scale = 111320 * math.cos(math.radians(center_lat)) # meters per degree longitude
for location in locations:
# Convert to meters relative to bounds center
@@ -200,9 +194,7 @@ class ClusteringService:
# Calculate cluster bounds
lats = [loc.latitude for loc in locations]
lngs = [loc.longitude for loc in locations]
cluster_bounds = GeoBounds(
north=max(lats), south=min(lats), east=max(lngs), west=min(lngs)
)
cluster_bounds = GeoBounds(north=max(lats), south=min(lats), east=max(lngs), west=min(lngs))
# Collect location types in cluster
types = {loc.type for loc in locations}
@@ -223,9 +215,7 @@ class ClusteringService:
representative_location=representative,
)
def _select_representative_location(
self, locations: list[UnifiedLocation]
) -> UnifiedLocation | None:
def _select_representative_location(self, locations: list[UnifiedLocation]) -> UnifiedLocation | None:
"""Select the most representative location for a cluster."""
if not locations:
return None
@@ -291,9 +281,7 @@ class ClusteringService:
"category_distribution": dict(category_counts),
}
def expand_cluster(
self, cluster: ClusterData, zoom_level: int
) -> list[UnifiedLocation]:
def expand_cluster(self, cluster: ClusterData, zoom_level: int) -> list[UnifiedLocation]:
"""
Expand a cluster to show individual locations (for drill-down functionality).
This would typically require re-querying the database with the cluster bounds.
@@ -320,14 +308,11 @@ class SmartClusteringRules:
return True
# Major parks should resist clustering unless very close
if (
loc1.cluster_category == "major_park"
or loc2.cluster_category == "major_park"
):
if loc1.cluster_category == "major_park" or loc2.cluster_category == "major_park":
return False
# Similar types cluster more readily
if loc1.type == loc2.type:
if loc1.type == loc2.type: # noqa: SIM103
return True
# Different types can cluster but with higher threshold

View File

@@ -89,9 +89,7 @@ class MapFilters:
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary for caching and serialization."""
return {
"location_types": (
[t.value for t in self.location_types] if self.location_types else None
),
"location_types": ([t.value for t in self.location_types] if self.location_types else None),
"park_status": (list(self.park_status) if self.park_status else None),
"ride_types": list(self.ride_types) if self.ride_types else None,
"company_roles": (list(self.company_roles) if self.company_roles else None),
@@ -183,11 +181,7 @@ class ClusterData:
"count": self.count,
"types": [t.value for t in self.types],
"bounds": self.bounds.to_dict(),
"representative": (
self.representative_location.to_dict()
if self.representative_location
else None
),
"representative": (self.representative_location.to_dict() if self.representative_location else None),
}

View File

@@ -103,9 +103,7 @@ class EnhancedCacheService:
self.default_cache.set(cache_key, data, timeout)
logger.debug(f"Cached geographic data for bounds {bounds}")
def get_cached_geographic_data(
self, bounds: "GeoBounds", zoom_level: int
) -> Any | None:
def get_cached_geographic_data(self, bounds: "GeoBounds", zoom_level: int) -> Any | None:
"""Retrieve cached geographic data"""
cache_key = f"geo:{bounds.min_lat}:{bounds.min_lng}:{bounds.max_lat}:{
bounds.max_lng
@@ -119,20 +117,14 @@ class EnhancedCacheService:
# For Redis cache backends
if hasattr(self.default_cache, "delete_pattern"):
deleted_count = self.default_cache.delete_pattern(pattern)
logger.info(
f"Invalidated {deleted_count} cache keys matching pattern '{pattern}'"
)
logger.info(f"Invalidated {deleted_count} cache keys matching pattern '{pattern}'")
return deleted_count
else:
logger.warning(
f"Cache backend does not support pattern deletion for pattern '{pattern}'"
)
logger.warning(f"Cache backend does not support pattern deletion for pattern '{pattern}'")
except Exception as e:
logger.error(f"Error invalidating cache pattern '{pattern}': {e}")
def invalidate_model_cache(
self, model_name: str, instance_id: int | None = None
):
def invalidate_model_cache(self, model_name: str, instance_id: int | None = None):
"""Invalidate cache keys related to a specific model"""
pattern = f"*{model_name}:{instance_id}*" if instance_id else f"*{model_name}*"
@@ -175,11 +167,7 @@ def cache_api_response(timeout=1800, vary_on=None, key_prefix=""):
# Generate cache key based on view, user, and parameters
cache_key_parts = [
key_prefix or view_func.__name__,
(
str(request.user.id)
if request.user.is_authenticated
else "anonymous"
),
(str(request.user.id) if request.user.is_authenticated else "anonymous"),
str(hash(frozenset(request.GET.items()))),
]
@@ -219,9 +207,7 @@ def cache_queryset_result(cache_key_template: str, timeout: int = 3600):
cache_key = cache_key_template.format(*args, **kwargs)
cache_service = EnhancedCacheService()
return cache_service.cache_queryset(
cache_key, func, timeout, *args, **kwargs
)
return cache_service.cache_queryset(cache_key, func, timeout, *args, **kwargs)
return wrapper

View File

@@ -218,17 +218,14 @@ class EntityFuzzyMatcher:
return matches[: self.MAX_RESULTS], suggestion
def _get_candidates(
self, query: str, entity_type: EntityType
) -> list[dict[str, Any]]:
def _get_candidates(self, query: str, entity_type: EntityType) -> list[dict[str, Any]]:
"""Get potential matching candidates for an entity type."""
candidates = []
if entity_type == EntityType.PARK:
parks = Park.objects.filter(
Q(name__icontains=query)
| Q(slug__icontains=query.lower().replace(" ", "-"))
)[: self.MAX_CANDIDATES]
parks = Park.objects.filter(Q(name__icontains=query) | Q(slug__icontains=query.lower().replace(" ", "-")))[
: self.MAX_CANDIDATES
]
for park in parks:
candidates.append(
@@ -265,8 +262,7 @@ class EntityFuzzyMatcher:
elif entity_type == EntityType.COMPANY:
companies = Company.objects.filter(
Q(name__icontains=query)
| Q(slug__icontains=query.lower().replace(" ", "-"))
Q(name__icontains=query) | Q(slug__icontains=query.lower().replace(" ", "-"))
)[: self.MAX_CANDIDATES]
for company in companies:
@@ -284,9 +280,7 @@ class EntityFuzzyMatcher:
return candidates
def _score_and_rank_candidates(
self, query: str, candidates: list[dict[str, Any]]
) -> list[FuzzyMatchResult]:
def _score_and_rank_candidates(self, query: str, candidates: list[dict[str, Any]]) -> list[FuzzyMatchResult]:
"""Score and rank all candidates using multiple algorithms."""
scored_matches = []
@@ -354,9 +348,7 @@ class EntityFuzzyMatcher:
# Sort by score (highest first) and return
return sorted(scored_matches, key=lambda x: x.score, reverse=True)
def _generate_entity_suggestion(
self, query: str, entity_types: list[EntityType], user
) -> EntitySuggestion:
def _generate_entity_suggestion(self, query: str, entity_types: list[EntityType], user) -> EntitySuggestion:
"""Generate suggestion for creating new entity when no matches found."""
# Determine most likely entity type based on query characteristics
@@ -364,14 +356,9 @@ class EntityFuzzyMatcher:
# Simple heuristics for entity type detection
query_lower = query.lower()
if any(
word in query_lower
for word in ["roller coaster", "ride", "coaster", "attraction"]
):
if any(word in query_lower for word in ["roller coaster", "ride", "coaster", "attraction"]):
suggested_type = EntityType.RIDE
elif any(
word in query_lower for word in ["inc", "corp", "company", "manufacturer"]
):
elif any(word in query_lower for word in ["inc", "corp", "company", "manufacturer"]):
suggested_type = EntityType.COMPANY
elif EntityType.PARK in entity_types:
suggested_type = EntityType.PARK
@@ -382,21 +369,13 @@ class EntityFuzzyMatcher:
suggested_name = " ".join(word.capitalize() for word in query.split())
# Check if user is authenticated
is_authenticated = (
user and hasattr(user, "is_authenticated") and user.is_authenticated
)
is_authenticated = user and hasattr(user, "is_authenticated") and user.is_authenticated
# Generate appropriate prompts
entity_name = suggested_type.value
login_prompt = (
f"Log in to suggest adding '{suggested_name}' as a new {entity_name}"
)
signup_prompt = (
f"Sign up to contribute and add '{suggested_name}' to ThrillWiki"
)
creation_hint = (
f"Help expand ThrillWiki by adding information about '{suggested_name}'"
)
login_prompt = f"Log in to suggest adding '{suggested_name}' as a new {entity_name}"
signup_prompt = f"Sign up to contribute and add '{suggested_name}' to ThrillWiki"
creation_hint = f"Help expand ThrillWiki by adding information about '{suggested_name}'"
return EntitySuggestion(
suggested_name=suggested_name,

View File

@@ -2,7 +2,6 @@
Location adapters for converting between domain-specific models and UnifiedLocation.
"""
from django.db.models import QuerySet
from django.urls import reverse
@@ -45,15 +44,9 @@ class BaseLocationAdapter:
class ParkLocationAdapter(BaseLocationAdapter):
"""Converts Park/ParkLocation to UnifiedLocation."""
def to_unified_location(
self, location_obj: ParkLocation
) -> UnifiedLocation | None:
def to_unified_location(self, location_obj: ParkLocation) -> UnifiedLocation | None:
"""Convert ParkLocation to UnifiedLocation."""
if (
not location_obj.point
or location_obj.latitude is None
or location_obj.longitude is None
):
if not location_obj.point or location_obj.latitude is None or location_obj.longitude is None:
return None
park = location_obj.park
@@ -67,17 +60,11 @@ class ParkLocationAdapter(BaseLocationAdapter):
metadata={
"status": getattr(park, "status", "UNKNOWN"),
"rating": (
float(park.average_rating)
if hasattr(park, "average_rating") and park.average_rating
else None
float(park.average_rating) if hasattr(park, "average_rating") and park.average_rating else None
),
"ride_count": getattr(park, "ride_count", 0),
"coaster_count": getattr(park, "coaster_count", 0),
"operator": (
park.operator.name
if hasattr(park, "operator") and park.operator
else None
),
"operator": (park.operator.name if hasattr(park, "operator") and park.operator else None),
"city": location_obj.city,
"state": location_obj.state,
"country": location_obj.country,
@@ -85,18 +72,14 @@ class ParkLocationAdapter(BaseLocationAdapter):
type_data={
"slug": park.slug,
"opening_date": (
park.opening_date.isoformat()
if hasattr(park, "opening_date") and park.opening_date
else None
park.opening_date.isoformat() if hasattr(park, "opening_date") and park.opening_date else None
),
"website": getattr(park, "website", ""),
"operating_season": getattr(park, "operating_season", ""),
"highway_exit": location_obj.highway_exit,
"parking_notes": location_obj.parking_notes,
"best_arrival_time": (
location_obj.best_arrival_time.strftime("%H:%M")
if location_obj.best_arrival_time
else None
location_obj.best_arrival_time.strftime("%H:%M") if location_obj.best_arrival_time else None
),
"seasonal_notes": location_obj.seasonal_notes,
"url": self._get_park_url(park),
@@ -111,9 +94,7 @@ class ParkLocationAdapter(BaseLocationAdapter):
filters: MapFilters | None = None,
) -> QuerySet:
"""Get optimized queryset for park locations."""
queryset = ParkLocation.objects.select_related("park", "park__operator").filter(
point__isnull=False
)
queryset = ParkLocation.objects.select_related("park", "park__operator").filter(point__isnull=False)
# Spatial filtering
if bounds:
@@ -139,17 +120,9 @@ class ParkLocationAdapter(BaseLocationAdapter):
weight = 1
if hasattr(park, "ride_count") and park.ride_count and park.ride_count > 20:
weight += 2
if (
hasattr(park, "coaster_count")
and park.coaster_count
and park.coaster_count > 5
):
if hasattr(park, "coaster_count") and park.coaster_count and park.coaster_count > 5:
weight += 1
if (
hasattr(park, "average_rating")
and park.average_rating
and park.average_rating > 4.0
):
if hasattr(park, "average_rating") and park.average_rating and park.average_rating > 4.0:
weight += 1
return min(weight, 5) # Cap at 5
@@ -176,15 +149,9 @@ class ParkLocationAdapter(BaseLocationAdapter):
class RideLocationAdapter(BaseLocationAdapter):
"""Converts Ride/RideLocation to UnifiedLocation."""
def to_unified_location(
self, location_obj: RideLocation
) -> UnifiedLocation | None:
def to_unified_location(self, location_obj: RideLocation) -> UnifiedLocation | None:
"""Convert RideLocation to UnifiedLocation."""
if (
not location_obj.point
or location_obj.latitude is None
or location_obj.longitude is None
):
if not location_obj.point or location_obj.latitude is None or location_obj.longitude is None:
return None
ride = location_obj.ride
@@ -194,11 +161,7 @@ class RideLocationAdapter(BaseLocationAdapter):
type=LocationType.RIDE,
name=ride.name,
coordinates=[float(location_obj.latitude), float(location_obj.longitude)],
address=(
f"{location_obj.park_area}, {ride.park.name}"
if location_obj.park_area
else ride.park.name
),
address=(f"{location_obj.park_area}, {ride.park.name}" if location_obj.park_area else ride.park.name),
metadata={
"park_id": ride.park.id,
"park_name": ride.park.name,
@@ -206,22 +169,16 @@ class RideLocationAdapter(BaseLocationAdapter):
"ride_type": getattr(ride, "ride_type", "Unknown"),
"status": getattr(ride, "status", "UNKNOWN"),
"rating": (
float(ride.average_rating)
if hasattr(ride, "average_rating") and ride.average_rating
else None
float(ride.average_rating) if hasattr(ride, "average_rating") and ride.average_rating else None
),
"manufacturer": (
getattr(ride, "manufacturer", {}).get("name")
if hasattr(ride, "manufacturer")
else None
getattr(ride, "manufacturer", {}).get("name") if hasattr(ride, "manufacturer") else None
),
},
type_data={
"slug": ride.slug,
"opening_date": (
ride.opening_date.isoformat()
if hasattr(ride, "opening_date") and ride.opening_date
else None
ride.opening_date.isoformat() if hasattr(ride, "opening_date") and ride.opening_date else None
),
"height_requirement": getattr(ride, "height_requirement", ""),
"duration_minutes": getattr(ride, "duration_minutes", None),
@@ -240,9 +197,9 @@ class RideLocationAdapter(BaseLocationAdapter):
filters: MapFilters | None = None,
) -> QuerySet:
"""Get optimized queryset for ride locations."""
queryset = RideLocation.objects.select_related(
"ride", "ride__park", "ride__park__operator"
).filter(point__isnull=False)
queryset = RideLocation.objects.select_related("ride", "ride__park", "ride__park__operator").filter(
point__isnull=False
)
# Spatial filtering
if bounds:
@@ -263,11 +220,7 @@ class RideLocationAdapter(BaseLocationAdapter):
ride_type = getattr(ride, "ride_type", "").lower()
if "coaster" in ride_type or "roller" in ride_type:
weight += 1
if (
hasattr(ride, "average_rating")
and ride.average_rating
and ride.average_rating > 4.0
):
if hasattr(ride, "average_rating") and ride.average_rating and ride.average_rating > 4.0:
weight += 1
return min(weight, 3) # Cap at 3 for rides
@@ -292,9 +245,7 @@ class RideLocationAdapter(BaseLocationAdapter):
class CompanyLocationAdapter(BaseLocationAdapter):
"""Converts Company/CompanyHeadquarters to UnifiedLocation."""
def to_unified_location(
self, location_obj: CompanyHeadquarters
) -> UnifiedLocation | None:
def to_unified_location(self, location_obj: CompanyHeadquarters) -> UnifiedLocation | None:
"""Convert CompanyHeadquarters to UnifiedLocation."""
# Note: CompanyHeadquarters doesn't have coordinates, so we need to geocode
# For now, we'll skip companies without coordinates
@@ -312,13 +263,9 @@ class CompanyLocationAdapter(BaseLocationAdapter):
# Company-specific filters
if filters:
if filters.company_roles:
queryset = queryset.filter(
company__roles__overlap=filters.company_roles
)
queryset = queryset.filter(company__roles__overlap=filters.company_roles)
if filters.search_query:
queryset = queryset.filter(
company__name__icontains=filters.search_query
)
queryset = queryset.filter(company__name__icontains=filters.search_query)
if filters.country:
queryset = queryset.filter(country=filters.country)
if filters.city:
@@ -354,11 +301,7 @@ class LocationAbstractionLayer:
all_locations = []
# Determine which location types to include
location_types = (
filters.location_types
if filters and filters.location_types
else set(LocationType)
)
location_types = filters.location_types if filters and filters.location_types else set(LocationType)
for location_type in location_types:
adapter = self.adapters[location_type]
@@ -379,25 +322,17 @@ class LocationAbstractionLayer:
queryset = adapter.get_queryset(bounds, filters)
return adapter.bulk_convert(queryset)
def get_location_by_id(
self, location_type: LocationType, location_id: int
) -> UnifiedLocation | None:
def get_location_by_id(self, location_type: LocationType, location_id: int) -> UnifiedLocation | None:
"""Get single location with full details."""
adapter = self.adapters[location_type]
try:
if location_type == LocationType.PARK:
obj = ParkLocation.objects.select_related("park", "park__operator").get(
park_id=location_id
)
obj = ParkLocation.objects.select_related("park", "park__operator").get(park_id=location_id)
elif location_type == LocationType.RIDE:
obj = RideLocation.objects.select_related("ride", "ride__park").get(
ride_id=location_id
)
obj = RideLocation.objects.select_related("ride", "ride__park").get(ride_id=location_id)
elif location_type == LocationType.COMPANY:
obj = CompanyHeadquarters.objects.select_related("company").get(
company_id=location_id
)
obj = CompanyHeadquarters.objects.select_related("company").get(company_id=location_id)
# LocationType.GENERIC removed - generic location app deprecated
else:
return None

View File

@@ -128,9 +128,7 @@ class LocationSearchService:
# Apply max results limit
return results[: filters.max_results]
def _search_parks(
self, filters: LocationSearchFilters
) -> list[LocationSearchResult]:
def _search_parks(self, filters: LocationSearchFilters) -> list[LocationSearchResult]:
"""Search parks with location data."""
queryset = Park.objects.select_related("location", "operator").all()
@@ -154,9 +152,9 @@ class LocationSearchService:
# Add distance annotation if proximity search
if filters.location_point and filters.include_distance:
queryset = queryset.annotate(
distance=Distance("location__point", filters.location_point)
).order_by("distance")
queryset = queryset.annotate(distance=Distance("location__point", filters.location_point)).order_by(
"distance"
)
# Convert to search results
results = []
@@ -166,11 +164,7 @@ class LocationSearchService:
object_id=park.id,
name=park.name,
description=park.description,
url=(
park.get_absolute_url()
if hasattr(park, "get_absolute_url")
else None
),
url=(park.get_absolute_url() if hasattr(park, "get_absolute_url") else None),
status=park.get_status_display(),
rating=(float(park.average_rating) if park.average_rating else None),
tags=["park", park.status.lower()],
@@ -187,20 +181,14 @@ class LocationSearchService:
result.country = location.country
# Add distance if proximity search
if (
filters.location_point
and filters.include_distance
and hasattr(park, "distance")
):
if filters.location_point and filters.include_distance and hasattr(park, "distance"):
result.distance_km = float(park.distance.km)
results.append(result)
return results
def _search_rides(
self, filters: LocationSearchFilters
) -> list[LocationSearchResult]:
def _search_rides(self, filters: LocationSearchFilters) -> list[LocationSearchResult]:
"""Search rides with location data."""
queryset = Ride.objects.select_related("park", "location").all()
@@ -223,9 +211,9 @@ class LocationSearchService:
# Add distance annotation if proximity search
if filters.location_point and filters.include_distance:
queryset = queryset.annotate(
distance=Distance("location__point", filters.location_point)
).order_by("distance")
queryset = queryset.annotate(distance=Distance("location__point", filters.location_point)).order_by(
"distance"
)
# Convert to search results
results = []
@@ -235,11 +223,7 @@ class LocationSearchService:
object_id=ride.id,
name=ride.name,
description=ride.description,
url=(
ride.get_absolute_url()
if hasattr(ride, "get_absolute_url")
else None
),
url=(ride.get_absolute_url() if hasattr(ride, "get_absolute_url") else None),
status=ride.status,
tags=[
"ride",
@@ -253,18 +237,10 @@ class LocationSearchService:
location = ride.location
result.latitude = location.latitude
result.longitude = location.longitude
result.address = (
f"{ride.park.name} - {location.park_area}"
if location.park_area
else ride.park.name
)
result.address = f"{ride.park.name} - {location.park_area}" if location.park_area else ride.park.name
# Add distance if proximity search
if (
filters.location_point
and filters.include_distance
and hasattr(ride, "distance")
):
if filters.location_point and filters.include_distance and hasattr(ride, "distance"):
result.distance_km = float(ride.distance.km)
# Fall back to park location if no specific ride location
@@ -281,16 +257,12 @@ class LocationSearchService:
return results
def _search_companies(
self, filters: LocationSearchFilters
) -> list[LocationSearchResult]:
def _search_companies(self, filters: LocationSearchFilters) -> list[LocationSearchResult]:
"""Search companies with headquarters location data."""
queryset = Company.objects.select_related("headquarters").all()
# Apply location filters
queryset = self._apply_location_filters(
queryset, filters, "headquarters__point"
)
queryset = self._apply_location_filters(queryset, filters, "headquarters__point")
# Apply text search
if filters.search_query:
@@ -309,9 +281,9 @@ class LocationSearchService:
# Add distance annotation if proximity search
if filters.location_point and filters.include_distance:
queryset = queryset.annotate(
distance=Distance("headquarters__point", filters.location_point)
).order_by("distance")
queryset = queryset.annotate(distance=Distance("headquarters__point", filters.location_point)).order_by(
"distance"
)
# Convert to search results
results = []
@@ -321,11 +293,7 @@ class LocationSearchService:
object_id=company.id,
name=company.name,
description=company.description,
url=(
company.get_absolute_url()
if hasattr(company, "get_absolute_url")
else None
),
url=(company.get_absolute_url() if hasattr(company, "get_absolute_url") else None),
tags=["company"] + (company.roles or []),
)
@@ -340,20 +308,14 @@ class LocationSearchService:
result.country = hq.country
# Add distance if proximity search
if (
filters.location_point
and filters.include_distance
and hasattr(company, "distance")
):
if filters.location_point and filters.include_distance and hasattr(company, "distance"):
result.distance_km = float(company.distance.km)
results.append(result)
return results
def _apply_location_filters(
self, queryset, filters: LocationSearchFilters, point_field: str
):
def _apply_location_filters(self, queryset, filters: LocationSearchFilters, point_field: str):
"""Apply common location filters to a queryset."""
# Proximity filter
@@ -371,31 +333,21 @@ class LocationSearchService:
# Geographic filters - adjust field names based on model
if filters.country:
if "headquarters" in point_field:
queryset = queryset.filter(
headquarters__country__icontains=filters.country
)
queryset = queryset.filter(headquarters__country__icontains=filters.country)
else:
location_field = point_field.split("__")[0]
queryset = queryset.filter(
**{f"{location_field}__country__icontains": filters.country}
)
queryset = queryset.filter(**{f"{location_field}__country__icontains": filters.country})
if filters.state:
if "headquarters" in point_field:
queryset = queryset.filter(
headquarters__state_province__icontains=filters.state
)
queryset = queryset.filter(headquarters__state_province__icontains=filters.state)
else:
location_field = point_field.split("__")[0]
queryset = queryset.filter(
**{f"{location_field}__state__icontains": filters.state}
)
queryset = queryset.filter(**{f"{location_field}__state__icontains": filters.state})
if filters.city:
location_field = point_field.split("__")[0]
queryset = queryset.filter(
**{f"{location_field}__city__icontains": filters.city}
)
queryset = queryset.filter(**{f"{location_field}__city__icontains": filters.city})
return queryset
@@ -417,9 +369,7 @@ class LocationSearchService:
# Get park location suggestions
park_locations = ParkLocation.objects.filter(
Q(park__name__icontains=query)
| Q(city__icontains=query)
| Q(state__icontains=query)
Q(park__name__icontains=query) | Q(city__icontains=query) | Q(state__icontains=query)
).select_related("park")[: limit // 3]
for location in park_locations:
@@ -429,11 +379,7 @@ class LocationSearchService:
"name": location.park.name,
"address": location.formatted_address,
"coordinates": location.coordinates,
"url": (
location.park.get_absolute_url()
if hasattr(location.park, "get_absolute_url")
else None
),
"url": (location.park.get_absolute_url() if hasattr(location.park, "get_absolute_url") else None),
}
)

View File

@@ -93,9 +93,7 @@ class MapCacheService:
return ":".join(key_parts)
def get_location_detail_cache_key(
self, location_type: str, location_id: int
) -> str:
def get_location_detail_cache_key(self, location_type: str, location_id: int) -> str:
"""Generate cache key for individual location details."""
return f"{self.DETAIL_PREFIX}:{location_type}:{location_id}"
@@ -137,9 +135,7 @@ class MapCacheService:
except Exception as e:
print(f"Cache write error for clusters {cache_key}: {e}")
def cache_map_response(
self, cache_key: str, response: MapResponse, ttl: int | None = None
) -> None:
def cache_map_response(self, cache_key: str, response: MapResponse, ttl: int | None = None) -> None:
"""Cache complete map response."""
try:
cache_data = response.to_dict()
@@ -212,24 +208,18 @@ class MapCacheService:
self.cache_stats["misses"] += 1
return None
def invalidate_location_cache(
self, location_type: str, location_id: int | None = None
) -> None:
def invalidate_location_cache(self, location_type: str, location_id: int | None = None) -> None:
"""Invalidate cache for specific location or all locations of a type."""
try:
if location_id:
# Invalidate specific location detail
detail_key = self.get_location_detail_cache_key(
location_type, location_id
)
detail_key = self.get_location_detail_cache_key(location_type, location_id)
cache.delete(detail_key)
# Invalidate related location and cluster caches
# In a production system, you'd want more sophisticated cache
# tagging
cache.delete_many(
[f"{self.LOCATIONS_PREFIX}:*", f"{self.CLUSTERS_PREFIX}:*"]
)
cache.delete_many([f"{self.LOCATIONS_PREFIX}:*", f"{self.CLUSTERS_PREFIX}:*"])
self.cache_stats["invalidations"] += 1
@@ -271,11 +261,7 @@ class MapCacheService:
def get_cache_stats(self) -> dict[str, Any]:
"""Get cache performance statistics."""
total_requests = self.cache_stats["hits"] + self.cache_stats["misses"]
hit_rate = (
(self.cache_stats["hits"] / total_requests * 100)
if total_requests > 0
else 0
)
hit_rate = (self.cache_stats["hits"] / total_requests * 100) if total_requests > 0 else 0
return {
"hits": self.cache_stats["hits"],
@@ -408,12 +394,8 @@ class MapCacheService:
def _dict_to_map_response(self, data: dict[str, Any]) -> MapResponse:
"""Convert dictionary back to MapResponse object."""
locations = [
self._dict_to_unified_location(loc) for loc in data.get("locations", [])
]
clusters = [
self._dict_to_cluster_data(cluster) for cluster in data.get("clusters", [])
]
locations = [self._dict_to_unified_location(loc) for loc in data.get("locations", [])]
clusters = [self._dict_to_cluster_data(cluster) for cluster in data.get("clusters", [])]
bounds = None
if data.get("bounds"):

View File

@@ -67,17 +67,13 @@ class UnifiedMapService:
# Generate cache key
cache_key = None
if use_cache:
cache_key = self._generate_cache_key(
bounds, filters, zoom_level, cluster
)
cache_key = self._generate_cache_key(bounds, filters, zoom_level, cluster)
# Try to get from cache first
cached_response = self.cache_service.get_cached_map_response(cache_key)
if cached_response:
cached_response.cache_hit = True
cached_response.query_time_ms = int(
(time.time() - start_time) * 1000
)
cached_response.query_time_ms = int((time.time() - start_time) * 1000)
return cached_response
# Get locations from database
@@ -87,21 +83,15 @@ class UnifiedMapService:
locations = self._apply_smart_limiting(locations, bounds, zoom_level)
# Determine if clustering should be applied
should_cluster = cluster and self.clustering_service.should_cluster(
zoom_level, len(locations)
)
should_cluster = cluster and self.clustering_service.should_cluster(zoom_level, len(locations))
# Apply clustering if needed
clusters = []
if should_cluster:
locations, clusters = self.clustering_service.cluster_locations(
locations, zoom_level, bounds
)
locations, clusters = self.clustering_service.cluster_locations(locations, zoom_level, bounds)
# Calculate response bounds
response_bounds = self._calculate_response_bounds(
locations, clusters, bounds
)
response_bounds = self._calculate_response_bounds(locations, clusters, bounds)
# Create response
response = MapResponse(
@@ -144,9 +134,7 @@ class UnifiedMapService:
cache_hit=False,
)
def get_location_details(
self, location_type: str, location_id: int
) -> UnifiedLocation | None:
def get_location_details(self, location_type: str, location_id: int) -> UnifiedLocation | None:
"""
Get detailed information for a specific location.
@@ -159,18 +147,14 @@ class UnifiedMapService:
"""
try:
# Check cache first
cache_key = self.cache_service.get_location_detail_cache_key(
location_type, location_id
)
cache_key = self.cache_service.get_location_detail_cache_key(location_type, location_id)
cached_locations = self.cache_service.get_cached_locations(cache_key)
if cached_locations:
return cached_locations[0] if cached_locations else None
# Get from database
location_type_enum = LocationType(location_type.lower())
location = self.location_layer.get_location_by_id(
location_type_enum, location_id
)
location = self.location_layer.get_location_by_id(location_type_enum, location_id)
# Cache the result
if location:
@@ -245,19 +229,13 @@ class UnifiedMapService:
"""
try:
bounds = GeoBounds(north=north, south=south, east=east, west=west)
filters = (
MapFilters(location_types=location_types) if location_types else None
)
filters = MapFilters(location_types=location_types) if location_types else None
return self.get_map_data(
bounds=bounds, filters=filters, zoom_level=zoom_level
)
return self.get_map_data(bounds=bounds, filters=filters, zoom_level=zoom_level)
except ValueError:
# Invalid bounds
return MapResponse(
locations=[], clusters=[], total_count=0, filtered_count=0
)
return MapResponse(locations=[], clusters=[], total_count=0, filtered_count=0)
def get_clustered_locations(
self,
@@ -276,9 +254,7 @@ class UnifiedMapService:
Returns:
MapResponse with clustered data
"""
return self.get_map_data(
bounds=bounds, filters=filters, zoom_level=zoom_level, cluster=True
)
return self.get_map_data(bounds=bounds, filters=filters, zoom_level=zoom_level, cluster=True)
def get_locations_by_type(
self,
@@ -299,9 +275,7 @@ class UnifiedMapService:
"""
try:
filters = MapFilters(location_types={location_type})
locations = self.location_layer.get_locations_by_type(
location_type, bounds, filters
)
locations = self.location_layer.get_locations_by_type(location_type, bounds, filters)
if limit:
locations = locations[:limit]
@@ -346,9 +320,7 @@ class UnifiedMapService:
"service_version": "1.0.0",
}
def _get_locations_from_db(
self, bounds: GeoBounds | None, filters: MapFilters | None
) -> list[UnifiedLocation]:
def _get_locations_from_db(self, bounds: GeoBounds | None, filters: MapFilters | None) -> list[UnifiedLocation]:
"""Get locations from database using the abstraction layer."""
return self.location_layer.get_all_locations(bounds, filters)
@@ -363,10 +335,7 @@ class UnifiedMapService:
major_parks = [
loc
for loc in locations
if (
loc.type == LocationType.PARK
and loc.cluster_category in ["major_park", "theme_park"]
)
if (loc.type == LocationType.PARK and loc.cluster_category in ["major_park", "theme_park"])
]
return major_parks[:200]
elif zoom_level < 10: # Regional level
@@ -398,9 +367,7 @@ class UnifiedMapService:
return None
lats, lngs = zip(*all_coords, strict=False)
return GeoBounds(
north=max(lats), south=min(lats), east=max(lngs), west=min(lngs)
)
return GeoBounds(north=max(lats), south=min(lats), east=max(lngs), west=min(lngs))
def _get_applied_filters_list(self, filters: MapFilters | None) -> list[str]:
"""Get list of applied filter types for metadata."""
@@ -438,13 +405,9 @@ class UnifiedMapService:
) -> str:
"""Generate cache key for the request."""
if cluster:
return self.cache_service.get_clusters_cache_key(
bounds, filters, zoom_level
)
return self.cache_service.get_clusters_cache_key(bounds, filters, zoom_level)
else:
return self.cache_service.get_locations_cache_key(
bounds, filters, zoom_level
)
return self.cache_service.get_locations_cache_key(bounds, filters, zoom_level)
def _record_performance_metrics(
self,

View File

@@ -21,9 +21,7 @@ class MediaService:
"""Shared service for media upload and processing operations."""
@staticmethod
def generate_upload_path(
domain: str, identifier: str, filename: str, subdirectory: str | None = None
) -> str:
def generate_upload_path(domain: str, identifier: str, filename: str, subdirectory: str | None = None) -> str:
"""
Generate standardized upload path for media files.
@@ -83,9 +81,7 @@ class MediaService:
"""
try:
# Check file size
max_size = getattr(
settings, "MAX_PHOTO_SIZE", 10 * 1024 * 1024
) # 10MB default
max_size = getattr(settings, "MAX_PHOTO_SIZE", 10 * 1024 * 1024) # 10MB default
if image_file.size > max_size:
return (
False,

View File

@@ -32,7 +32,7 @@ class MediaURLService:
slug = slugify(caption)
# Limit length to avoid overly long URLs
if len(slug) > 50:
slug = slug[:50].rsplit('-', 1)[0] # Cut at word boundary
slug = slug[:50].rsplit("-", 1)[0] # Cut at word boundary
return f"{slug}-{photo_id}.{extension}"
else:
return f"photo-{photo_id}.{extension}"
@@ -55,13 +55,15 @@ class MediaURLService:
# Add variant to filename if not public
if variant != "public":
name, ext = filename.rsplit('.', 1)
name, ext = filename.rsplit(".", 1)
filename = f"{name}-{variant}.{ext}"
return f"/parks/{park_slug}/photos/{filename}"
@staticmethod
def generate_ride_photo_url(park_slug: str, ride_slug: str, caption: str, photo_id: int, variant: str = "public") -> str:
def generate_ride_photo_url(
park_slug: str, ride_slug: str, caption: str, photo_id: int, variant: str = "public"
) -> str:
"""
Generate a friendly URL for a ride photo.
@@ -78,7 +80,7 @@ class MediaURLService:
filename = MediaURLService.generate_friendly_filename(caption, photo_id)
if variant != "public":
name, ext = filename.rsplit('.', 1)
name, ext = filename.rsplit(".", 1)
filename = f"{name}-{variant}.{ext}"
return f"/parks/{park_slug}/rides/{ride_slug}/photos/{filename}"
@@ -95,7 +97,7 @@ class MediaURLService:
Dict with photo_id and variant, or None if parsing fails
"""
# Remove extension
name = filename.rsplit('.', 1)[0]
name = filename.rsplit(".", 1)[0]
# Check for variant suffix
variant = "public"
@@ -104,17 +106,14 @@ class MediaURLService:
for v in variant_patterns:
if name.endswith(f"-{v}"):
variant = v
name = name[:-len(f"-{v}")]
name = name[: -len(f"-{v}")]
break
# Extract photo ID (should be the last number)
match = re.search(r'-(\d+)$', name)
match = re.search(r"-(\d+)$", name)
if match:
photo_id = int(match.group(1))
return {
"photo_id": photo_id,
"variant": variant
}
return {"photo_id": photo_id, "variant": variant}
return None

View File

@@ -53,9 +53,7 @@ def monitor_performance(operation_name: str, **tags):
)
# Log performance data
log_level = (
logging.WARNING if duration > 2.0 or total_queries > 10 else logging.INFO
)
log_level = logging.WARNING if duration > 2.0 or total_queries > 10 else logging.INFO
logger.log(
log_level,
f"Performance: {operation_name} completed in {duration:.3f}s with {
@@ -108,11 +106,7 @@ def track_queries(operation_name: str, warn_threshold: int = 10):
recent_queries = connection.queries[-total_queries:]
query_details = [
{
"sql": (
query["sql"][:200] + "..."
if len(query["sql"]) > 200
else query["sql"]
),
"sql": (query["sql"][:200] + "..." if len(query["sql"]) > 200 else query["sql"]),
"time": float(query["time"]),
}
for query in recent_queries
@@ -127,14 +121,12 @@ def track_queries(operation_name: str, warn_threshold: int = 10):
if total_queries > warn_threshold or execution_time > 1.0:
logger.warning(
f"Performance concern in {operation_name}: "
f"{total_queries} queries, {execution_time:.2f}s",
f"Performance concern in {operation_name}: " f"{total_queries} queries, {execution_time:.2f}s",
extra=performance_data,
)
else:
logger.debug(
f"Query tracking for {operation_name}: "
f"{total_queries} queries, {execution_time:.2f}s",
f"Query tracking for {operation_name}: " f"{total_queries} queries, {execution_time:.2f}s",
extra=performance_data,
)
@@ -221,9 +213,7 @@ class PerformanceProfiler:
"total_queries": total_queries,
"checkpoints": self.checkpoints,
"memory_usage": self.memory_usage,
"queries_per_second": (
total_queries / total_duration if total_duration > 0 else 0
),
"queries_per_second": (total_queries / total_duration if total_duration > 0 else 0),
}
# Calculate checkpoint intervals
@@ -237,8 +227,7 @@ class PerformanceProfiler:
"from": prev["name"],
"to": curr["name"],
"duration": curr["elapsed_seconds"] - prev["elapsed_seconds"],
"queries": curr["queries_since_start"]
- prev["queries_since_start"],
"queries": curr["queries_since_start"] - prev["queries_since_start"],
}
)
report["checkpoint_intervals"] = intervals
@@ -288,9 +277,7 @@ class DatabaseQueryAnalyzer:
query_types[query_type] = query_types.get(query_type, 0) + 1
# Find slow queries (top 10% by time)
sorted_queries = sorted(
queries, key=lambda q: float(q.get("time", 0)), reverse=True
)
sorted_queries = sorted(queries, key=lambda q: float(q.get("time", 0)), reverse=True)
slow_query_count = max(1, query_count // 10)
slow_queries = sorted_queries[:slow_query_count]
@@ -302,9 +289,7 @@ class DatabaseQueryAnalyzer:
signature = " ".join(sql.split()) # Normalize whitespace
query_signatures[signature] = query_signatures.get(signature, 0) + 1
duplicates = {
sig: count for sig, count in query_signatures.items() if count > 1
}
duplicates = {sig: count for sig, count in query_signatures.items() if count > 1}
analysis = {
"total_queries": query_count,
@@ -313,21 +298,13 @@ class DatabaseQueryAnalyzer:
"query_types": query_types,
"slow_queries": [
{
"sql": (
q.get("sql", "")[:200] + "..."
if len(q.get("sql", "")) > 200
else q.get("sql", "")
),
"sql": (q.get("sql", "")[:200] + "..." if len(q.get("sql", "")) > 200 else q.get("sql", "")),
"time": float(q.get("time", 0)),
}
for q in slow_queries
],
"duplicate_query_count": len(duplicates),
"duplicate_queries": (
duplicates
if len(duplicates) <= 10
else dict(list(duplicates.items())[:10])
),
"duplicate_queries": (duplicates if len(duplicates) <= 10 else dict(list(duplicates.items())[:10])),
}
return analysis
@@ -348,9 +325,7 @@ def monitor_function_performance(operation_name: str | None = None):
@wraps(func)
def wrapper(*args, **kwargs):
name = operation_name or f"{func.__module__}.{func.__name__}"
with monitor_performance(
name, function=func.__name__, module=func.__module__
):
with monitor_performance(name, function=func.__name__, module=func.__module__):
return func(*args, **kwargs)
return wrapper

View File

@@ -74,9 +74,7 @@ class TrendingService:
if not force_refresh:
cached_result = cache.get(cache_key)
if cached_result is not None:
self.logger.debug(
f"Returning cached trending results for {content_type}"
)
self.logger.debug(f"Returning cached trending results for {content_type}")
return cached_result
self.logger.info(f"Getting trending content for {content_type}")
@@ -86,15 +84,11 @@ class TrendingService:
trending_items = []
if content_type in ["all", "parks"]:
park_items = self._calculate_trending_parks(
limit * 2 if content_type == "all" else limit
)
park_items = self._calculate_trending_parks(limit * 2 if content_type == "all" else limit)
trending_items.extend(park_items)
if content_type in ["all", "rides"]:
ride_items = self._calculate_trending_rides(
limit * 2 if content_type == "all" else limit
)
ride_items = self._calculate_trending_rides(limit * 2 if content_type == "all" else limit)
trending_items.extend(ride_items)
# Sort by trending score and apply limit
@@ -107,9 +101,7 @@ class TrendingService:
# Cache results
cache.set(cache_key, formatted_results, self.CACHE_TTL)
self.logger.info(
f"Calculated {len(formatted_results)} trending items for {content_type}"
)
self.logger.info(f"Calculated {len(formatted_results)} trending items for {content_type}")
return formatted_results
except Exception as e:
@@ -140,9 +132,7 @@ class TrendingService:
if not force_refresh:
cached_result = cache.get(cache_key)
if cached_result is not None:
self.logger.debug(
f"Returning cached new content results for {content_type}"
)
self.logger.debug(f"Returning cached new content results for {content_type}")
return cached_result
self.logger.info(f"Getting new content for {content_type}")
@@ -153,15 +143,11 @@ class TrendingService:
new_items = []
if content_type in ["all", "parks"]:
parks = self._get_new_parks(
cutoff_date, limit * 2 if content_type == "all" else limit
)
parks = self._get_new_parks(cutoff_date, limit * 2 if content_type == "all" else limit)
new_items.extend(parks)
if content_type in ["all", "rides"]:
rides = self._get_new_rides(
cutoff_date, limit * 2 if content_type == "all" else limit
)
rides = self._get_new_rides(cutoff_date, limit * 2 if content_type == "all" else limit)
new_items.extend(rides)
# Sort by date added (most recent first) and apply limit
@@ -174,9 +160,7 @@ class TrendingService:
# Cache results
cache.set(cache_key, formatted_results, 1800) # Cache for 30 minutes
self.logger.info(
f"Calculated {len(formatted_results)} new items for {content_type}"
)
self.logger.info(f"Calculated {len(formatted_results)} new items for {content_type}")
return formatted_results
except Exception as e:
@@ -185,9 +169,7 @@ class TrendingService:
def _calculate_trending_parks(self, limit: int) -> list[dict[str, Any]]:
"""Calculate trending scores for parks."""
parks = Park.objects.filter(status="OPERATING").select_related(
"location", "operator", "card_image"
)
parks = Park.objects.filter(status="OPERATING").select_related("location", "operator", "card_image")
trending_parks = []
@@ -216,9 +198,7 @@ class TrendingService:
# Get card image URL
card_image_url = ""
if park.card_image and hasattr(park.card_image, "image"):
card_image_url = (
park.card_image.image.url if park.card_image.image else ""
)
card_image_url = park.card_image.image.url if park.card_image.image else ""
# Get primary company (operator)
primary_company = park.operator.name if park.operator else ""
@@ -233,14 +213,8 @@ class TrendingService:
"slug": park.slug,
"park": park.name, # For parks, park field is the park name itself
"category": "park",
"rating": (
float(park.average_rating)
if park.average_rating
else 0.0
),
"date_opened": (
opening_date.isoformat() if opening_date else ""
),
"rating": (float(park.average_rating) if park.average_rating else 0.0),
"date_opened": (opening_date.isoformat() if opening_date else ""),
"url": park.url,
"card_image": card_image_url,
"city": city,
@@ -256,9 +230,7 @@ class TrendingService:
def _calculate_trending_rides(self, limit: int) -> list[dict[str, Any]]:
"""Calculate trending scores for rides."""
rides = Ride.objects.filter(status="OPERATING").select_related(
"park", "park__location", "card_image"
)
rides = Ride.objects.filter(status="OPERATING").select_related("park", "park__location", "card_image")
trending_rides = []
@@ -274,9 +246,7 @@ class TrendingService:
# Get card image URL
card_image_url = ""
if ride.card_image and hasattr(ride.card_image, "image"):
card_image_url = (
ride.card_image.image.url if ride.card_image.image else ""
)
card_image_url = ride.card_image.image.url if ride.card_image.image else ""
trending_rides.append(
{
@@ -288,14 +258,8 @@ class TrendingService:
"slug": ride.slug,
"park": ride.park.name if ride.park else "",
"category": "ride",
"rating": (
float(ride.average_rating)
if ride.average_rating
else 0.0
),
"date_opened": (
opening_date.isoformat() if opening_date else ""
),
"rating": (float(ride.average_rating) if ride.average_rating else 0.0),
"date_opened": (opening_date.isoformat() if opening_date else ""),
"url": ride.url,
"park_url": ride.park.url if ride.park else "",
"card_image": card_image_url,
@@ -347,23 +311,17 @@ class TrendingService:
return final_score
except Exception as e:
self.logger.error(
f"Error calculating score for {content_type} {content_obj.id}: {e}"
)
self.logger.error(f"Error calculating score for {content_type} {content_obj.id}: {e}")
return 0.0
def _calculate_view_growth_score(
self, content_type: ContentType, object_id: int
) -> float:
def _calculate_view_growth_score(self, content_type: ContentType, object_id: int) -> float:
"""Calculate normalized view growth score."""
try:
current_views, previous_views, growth_percentage = (
PageView.get_views_growth(
content_type,
object_id,
self.CURRENT_PERIOD_HOURS,
self.PREVIOUS_PERIOD_HOURS,
)
current_views, previous_views, growth_percentage = PageView.get_views_growth(
content_type,
object_id,
self.CURRENT_PERIOD_HOURS,
self.PREVIOUS_PERIOD_HOURS,
)
if previous_views == 0:
@@ -372,9 +330,7 @@ class TrendingService:
# Normalize growth percentage to 0-1 scale
# 100% growth = 0.5, 500% growth = 1.0
normalized_growth = (
min(growth_percentage / 500.0, 1.0) if growth_percentage > 0 else 0.0
)
normalized_growth = min(growth_percentage / 500.0, 1.0) if growth_percentage > 0 else 0.0
return max(normalized_growth, 0.0)
except Exception as e:
@@ -421,11 +377,7 @@ class TrendingService:
elif days_since_added <= 30:
return 1.0 - (days_since_added / 30.0) * 0.2 # 1.0 to 0.8
elif days_since_added <= self.RECENCY_BASELINE_DAYS:
return (
0.8
- ((days_since_added - 30) / (self.RECENCY_BASELINE_DAYS - 30))
* 0.7
) # 0.8 to 0.1
return 0.8 - ((days_since_added - 30) / (self.RECENCY_BASELINE_DAYS - 30)) * 0.7 # 0.8 to 0.1
else:
return 0.0
@@ -433,9 +385,7 @@ class TrendingService:
self.logger.warning(f"Error calculating recency score: {e}")
return 0.5
def _calculate_popularity_score(
self, content_type: ContentType, object_id: int
) -> float:
def _calculate_popularity_score(self, content_type: ContentType, object_id: int) -> float:
"""Calculate popularity score based on total view count."""
try:
total_views = PageView.get_total_views_count(
@@ -461,8 +411,7 @@ class TrendingService:
"""Get recently added parks."""
new_parks = (
Park.objects.filter(
Q(created_at__gte=cutoff_date)
| Q(opening_date__gte=cutoff_date.date()),
Q(created_at__gte=cutoff_date) | Q(opening_date__gte=cutoff_date.date()),
status="OPERATING",
)
.select_related("location", "operator", "card_image")
@@ -473,7 +422,7 @@ class TrendingService:
for park in new_parks:
date_added = park.opening_date or park.created_at
# Handle datetime to date conversion
if date_added:
if date_added: # noqa: SIM102
# If it's a datetime, convert to date
if isinstance(date_added, datetime):
date_added = date_added.date()
@@ -500,9 +449,7 @@ class TrendingService:
# Get card image URL
card_image_url = ""
if park.card_image and hasattr(park.card_image, "image"):
card_image_url = (
park.card_image.image.url if park.card_image.image else ""
)
card_image_url = park.card_image.image.url if park.card_image.image else ""
# Get primary company (operator)
primary_company = park.operator.name if park.operator else ""
@@ -533,8 +480,7 @@ class TrendingService:
"""Get recently added rides."""
new_rides = (
Ride.objects.filter(
Q(created_at__gte=cutoff_date)
| Q(opening_date__gte=cutoff_date.date()),
Q(created_at__gte=cutoff_date) | Q(opening_date__gte=cutoff_date.date()),
status="OPERATING",
)
.select_related("park", "park__location", "card_image")
@@ -543,11 +489,9 @@ class TrendingService:
results = []
for ride in new_rides:
date_added = getattr(ride, "opening_date", None) or getattr(
ride, "created_at", None
)
date_added = getattr(ride, "opening_date", None) or getattr(ride, "created_at", None)
# Handle datetime to date conversion
if date_added:
if date_added: # noqa: SIM102
# If it's a datetime, convert to date
if isinstance(date_added, datetime):
date_added = date_added.date()
@@ -561,9 +505,7 @@ class TrendingService:
# Get card image URL
card_image_url = ""
if ride.card_image and hasattr(ride.card_image, "image"):
card_image_url = (
ride.card_image.image.url if ride.card_image.image else ""
)
card_image_url = ride.card_image.image.url if ride.card_image.image else ""
results.append(
{
@@ -584,9 +526,7 @@ class TrendingService:
return results
def _format_trending_results(
self, trending_items: list[dict[str, Any]]
) -> list[dict[str, Any]]:
def _format_trending_results(self, trending_items: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""Format trending results for frontend consumption."""
formatted_results = []
@@ -595,13 +535,11 @@ class TrendingService:
# Get view change for display
content_obj = item["content_object"]
ct = ContentType.objects.get_for_model(content_obj)
current_views, previous_views, growth_percentage = (
PageView.get_views_growth(
ct,
content_obj.id,
self.CURRENT_PERIOD_HOURS,
self.PREVIOUS_PERIOD_HOURS,
)
current_views, previous_views, growth_percentage = PageView.get_views_growth(
ct,
content_obj.id,
self.CURRENT_PERIOD_HOURS,
self.PREVIOUS_PERIOD_HOURS,
)
# Format exactly as frontend expects
@@ -614,9 +552,7 @@ class TrendingService:
"rank": rank,
"views": current_views,
"views_change": (
f"+{growth_percentage:.1f}%"
if growth_percentage > 0
else f"{growth_percentage:.1f}%"
f"+{growth_percentage:.1f}%" if growth_percentage > 0 else f"{growth_percentage:.1f}%"
),
"slug": item["slug"],
"date_opened": item["date_opened"],
@@ -649,9 +585,7 @@ class TrendingService:
return formatted_results
def _format_new_content_results(
self, new_items: list[dict[str, Any]]
) -> list[dict[str, Any]]:
def _format_new_content_results(self, new_items: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""Format new content results for frontend consumption."""
formatted_results = []

View File

@@ -1,4 +1,5 @@
"""State machine utilities for core app."""
from .builder import (
StateTransitionBuilder,
determine_method_name_for_transition,

View File

@@ -60,6 +60,7 @@ See Also:
- apps.core.choices.registry: Central choice registry
- apps.core.state_machine.guards: Guard extraction from metadata
"""
from typing import Any
from django.core.exceptions import ImproperlyConfigured
@@ -129,9 +130,7 @@ class StateTransitionBuilder:
# Validate choice group exists
group = registry.get(choice_group, domain)
if group is None:
raise ImproperlyConfigured(
f"Choice group '{choice_group}' not found in domain '{domain}'"
)
raise ImproperlyConfigured(f"Choice group '{choice_group}' not found in domain '{domain}'")
self.choices = registry.get_choices(choice_group, domain)
@@ -172,20 +171,15 @@ class StateTransitionBuilder:
# Validate all target states exist
for target in transitions:
target_choice = registry.get_choice(
self.choice_group, target, self.domain
)
target_choice = registry.get_choice(self.choice_group, target, self.domain)
if target_choice is None:
raise ImproperlyConfigured(
f"State '{state_value}' references non-existent "
f"transition target '{target}'"
f"State '{state_value}' references non-existent " f"transition target '{target}'"
)
return transitions
def extract_permission_requirements(
self, state_value: str
) -> dict[str, bool]:
def extract_permission_requirements(self, state_value: str) -> dict[str, bool]:
"""
Extract permission requirements from metadata.
@@ -198,9 +192,7 @@ class StateTransitionBuilder:
metadata = self.get_choice_metadata(state_value)
return {
"requires_moderator": metadata.get("requires_moderator", False),
"requires_admin_approval": metadata.get(
"requires_admin_approval", False
),
"requires_admin_approval": metadata.get("requires_admin_approval", False),
}
def is_terminal_state(self, state_value: str) -> bool:

View File

@@ -181,10 +181,7 @@ class TransitionContext:
return self.model_class.__name__
def __str__(self) -> str:
return (
f"TransitionContext({self.model_name}.{self.field_name}: "
f"{self.source_state}{self.target_state})"
)
return f"TransitionContext({self.model_name}.{self.field_name}: " f"{self.source_state}{self.target_state})"
class BaseTransitionCallback(ABC):
@@ -324,9 +321,9 @@ class CallbackRegistration:
return False
if self.field_name != field_name:
return False
if self.source != '*' and self.source != source:
if self.source != "*" and self.source != source:
return False
return not (self.target != '*' and self.target != target)
return not (self.target != "*" and self.target != target)
class TransitionCallbackRegistry:
@@ -337,10 +334,10 @@ class TransitionCallbackRegistry:
for specific transitions.
"""
_instance: Optional['TransitionCallbackRegistry'] = None
_instance: Optional["TransitionCallbackRegistry"] = None
_initialized: bool = False
def __new__(cls) -> 'TransitionCallbackRegistry':
def __new__(cls) -> "TransitionCallbackRegistry":
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
@@ -483,10 +480,7 @@ class TransitionCallbackRegistry:
try:
# Check if callback should execute
if not callback.should_execute(context):
logger.debug(
f"Skipping callback {callback.name} - "
f"should_execute returned False"
)
logger.debug(f"Skipping callback {callback.name} - " f"should_execute returned False")
continue
# Execute callback
@@ -498,30 +492,24 @@ class TransitionCallbackRegistry:
result = callback.execute(context)
if not result:
logger.warning(
f"Callback {callback.name} returned False for {context}"
)
logger.warning(f"Callback {callback.name} returned False for {context}")
failures.append((callback, None))
overall_success = False
if not callback.continue_on_error:
logger.error(
f"Aborting callback chain - {callback.name} failed "
f"and continue_on_error=False"
f"Aborting callback chain - {callback.name} failed " f"and continue_on_error=False"
)
break
except Exception as e:
logger.exception(
f"Callback {callback.name} raised exception for {context}: {e}"
)
logger.exception(f"Callback {callback.name} raised exception for {context}: {e}")
failures.append((callback, e))
overall_success = False
if not callback.continue_on_error:
logger.error(
f"Aborting callback chain - {callback.name} raised exception "
f"and continue_on_error=False"
f"Aborting callback chain - {callback.name} raised exception " f"and continue_on_error=False"
)
break
@@ -540,10 +528,7 @@ class TransitionCallbackRegistry:
self._callbacks[stage] = []
else:
for stage in CallbackStage:
self._callbacks[stage] = [
r for r in self._callbacks[stage]
if r.model_class != model_class
]
self._callbacks[stage] = [r for r in self._callbacks[stage] if r.model_class != model_class]
def get_all_registrations(
self,
@@ -563,10 +548,7 @@ class TransitionCallbackRegistry:
filtered = {}
for stage, registrations in self._callbacks.items():
filtered[stage] = [
r for r in registrations
if r.model_class == model_class
]
filtered[stage] = [r for r in registrations if r.model_class == model_class]
return filtered
@classmethod
@@ -601,9 +583,7 @@ def register_pre_callback(
callback: PreTransitionCallback,
) -> None:
"""Convenience function to register a pre-transition callback."""
callback_registry.register(
model_class, field_name, source, target, callback, CallbackStage.PRE
)
callback_registry.register(model_class, field_name, source, target, callback, CallbackStage.PRE)
def register_post_callback(
@@ -614,9 +594,7 @@ def register_post_callback(
callback: PostTransitionCallback,
) -> None:
"""Convenience function to register a post-transition callback."""
callback_registry.register(
model_class, field_name, source, target, callback, CallbackStage.POST
)
callback_registry.register(model_class, field_name, source, target, callback, CallbackStage.POST)
def register_error_callback(
@@ -627,6 +605,4 @@ def register_error_callback(
callback: ErrorTransitionCallback,
) -> None:
"""Convenience function to register an error callback."""
callback_registry.register(
model_class, field_name, source, target, callback, CallbackStage.ERROR
)
callback_registry.register(model_class, field_name, source, target, callback, CallbackStage.ERROR)

View File

@@ -44,8 +44,8 @@ class CacheInvalidationCallback(PostTransitionCallback):
def should_execute(self, context: TransitionContext) -> bool:
"""Check if cache invalidation is enabled."""
callback_settings = getattr(settings, 'STATE_MACHINE_CALLBACKS', {})
if not callback_settings.get('cache_invalidation_enabled', True):
callback_settings = getattr(settings, "STATE_MACHINE_CALLBACKS", {})
if not callback_settings.get("cache_invalidation_enabled", True):
logger.debug("Cache invalidation disabled via settings")
return False
return True
@@ -54,6 +54,7 @@ class CacheInvalidationCallback(PostTransitionCallback):
"""Get the EnhancedCacheService instance."""
try:
from apps.core.services.enhanced_cache_service import EnhancedCacheService
return EnhancedCacheService()
except ImportError:
logger.warning("EnhancedCacheService not available")
@@ -85,11 +86,7 @@ class CacheInvalidationCallback(PostTransitionCallback):
substituted = set()
for pattern in all_patterns:
substituted.add(
pattern
.replace('{id}', instance_id)
.replace('{model}', model_name)
)
substituted.add(pattern.replace("{id}", instance_id).replace("{model}", model_name))
return substituted
@@ -108,20 +105,13 @@ class CacheInvalidationCallback(PostTransitionCallback):
cache_service.invalidate_pattern(pattern)
logger.debug(f"Invalidated cache pattern: {pattern}")
except Exception as e:
logger.warning(
f"Failed to invalidate cache pattern {pattern}: {e}"
)
logger.warning(f"Failed to invalidate cache pattern {pattern}: {e}")
logger.info(
f"Cache invalidation completed for {context}: "
f"{len(patterns)} patterns"
)
logger.info(f"Cache invalidation completed for {context}: " f"{len(patterns)} patterns")
return True
except Exception as e:
logger.exception(
f"Failed to invalidate cache for {context}: {e}"
)
logger.exception(f"Failed to invalidate cache for {context}: {e}")
return False
def _fallback_invalidation(self, context: TransitionContext) -> bool:
@@ -133,8 +123,7 @@ class CacheInvalidationCallback(PostTransitionCallback):
# Django's default cache doesn't support pattern deletion
# Log a warning and return True (don't fail the transition)
logger.warning(
f"EnhancedCacheService not available, skipping pattern "
f"invalidation for {len(patterns)} patterns"
f"EnhancedCacheService not available, skipping pattern " f"invalidation for {len(patterns)} patterns"
)
return True
@@ -155,13 +144,13 @@ class ModelCacheInvalidation(CacheInvalidationCallback):
# Default patterns by model type
MODEL_PATTERNS = {
'Park': ['*park:{id}*', '*parks*', 'geo:*'],
'Ride': ['*ride:{id}*', '*rides*', '*park:*', 'geo:*'],
'EditSubmission': ['*submission:{id}*', '*moderation*'],
'PhotoSubmission': ['*photo:{id}*', '*moderation*'],
'ModerationReport': ['*report:{id}*', '*moderation*'],
'ModerationQueue': ['*queue*', '*moderation*'],
'BulkOperation': ['*operation:{id}*', '*moderation*'],
"Park": ["*park:{id}*", "*parks*", "geo:*"],
"Ride": ["*ride:{id}*", "*rides*", "*park:*", "geo:*"],
"EditSubmission": ["*submission:{id}*", "*moderation*"],
"PhotoSubmission": ["*photo:{id}*", "*moderation*"],
"ModerationReport": ["*report:{id}*", "*moderation*"],
"ModerationQueue": ["*queue*", "*moderation*"],
"BulkOperation": ["*operation:{id}*", "*moderation*"],
}
def __init__(self, **kwargs):
@@ -178,7 +167,7 @@ class ModelCacheInvalidation(CacheInvalidationCallback):
# Substitute {id} placeholder
instance_id = str(context.instance.pk)
for pattern in model_patterns:
base_patterns.append(pattern.replace('{id}', instance_id))
base_patterns.append(pattern.replace("{id}", instance_id))
return base_patterns
@@ -217,14 +206,14 @@ class RelatedModelCacheInvalidation(CacheInvalidationCallback):
continue
# Handle foreign key relationships
if hasattr(related_obj, 'pk'):
if hasattr(related_obj, "pk"):
related_model = type(related_obj).__name__.lower()
related_id = related_obj.pk
patterns.append(f"*{related_model}:{related_id}*")
patterns.append(f"*{related_model}_{related_id}*")
# Handle many-to-many relationships
elif hasattr(related_obj, 'all'):
elif hasattr(related_obj, "all"):
for obj in related_obj.all():
related_model = type(obj).__name__.lower()
related_id = obj.pk
@@ -293,7 +282,7 @@ class APICacheInvalidation(CacheInvalidationCallback):
**kwargs: Additional arguments.
"""
super().__init__(**kwargs)
self.api_prefixes = api_prefixes or ['api:*']
self.api_prefixes = api_prefixes or ["api:*"]
self.include_geo_cache = include_geo_cache
def _get_all_patterns(self, context: TransitionContext) -> set[str]:
@@ -306,8 +295,8 @@ class APICacheInvalidation(CacheInvalidationCallback):
# Add geo cache if requested
if self.include_geo_cache:
patterns.add('geo:*')
patterns.add('map:*')
patterns.add("geo:*")
patterns.add("map:*")
# Add model-specific API patterns
model_name = context.model_name.lower()
@@ -329,10 +318,10 @@ class ParkCacheInvalidation(CacheInvalidationCallback):
def __init__(self, **kwargs):
super().__init__(
patterns=[
'*park:{id}*',
'*parks*',
'api:*',
'geo:*',
"*park:{id}*",
"*parks*",
"api:*",
"geo:*",
],
**kwargs,
)
@@ -346,10 +335,10 @@ class RideCacheInvalidation(CacheInvalidationCallback):
def __init__(self, **kwargs):
super().__init__(
patterns=[
'*ride:{id}*',
'*rides*',
'api:*',
'geo:*',
"*ride:{id}*",
"*rides*",
"api:*",
"geo:*",
],
**kwargs,
)
@@ -359,9 +348,9 @@ class RideCacheInvalidation(CacheInvalidationCallback):
patterns = super()._get_instance_patterns(context)
# Invalidate parent park's cache
park = getattr(context.instance, 'park', None)
park = getattr(context.instance, "park", None)
if park:
park_id = park.pk if hasattr(park, 'pk') else park
park_id = park.pk if hasattr(park, "pk") else park
patterns.append(f"*park:{park_id}*")
patterns.append(f"*park_{park_id}*")
@@ -376,9 +365,9 @@ class ModerationCacheInvalidation(CacheInvalidationCallback):
def __init__(self, **kwargs):
super().__init__(
patterns=[
'*submission*',
'*moderation*',
'api:moderation*',
"*submission*",
"*moderation*",
"api:moderation*",
],
**kwargs,
)

View File

@@ -53,17 +53,15 @@ class NotificationCallback(PostTransitionCallback):
def should_execute(self, context: TransitionContext) -> bool:
"""Check if notifications are enabled and recipient exists."""
# Check if notifications are disabled in settings
callback_settings = getattr(settings, 'STATE_MACHINE_CALLBACKS', {})
if not callback_settings.get('notifications_enabled', True):
callback_settings = getattr(settings, "STATE_MACHINE_CALLBACKS", {})
if not callback_settings.get("notifications_enabled", True):
logger.debug("Notifications disabled via settings")
return False
# Check if recipient exists
recipient = self._get_recipient(context.instance)
if not recipient:
logger.debug(
f"No recipient found at {self.recipient_field} for {context}"
)
logger.debug(f"No recipient found at {self.recipient_field} for {context}")
return False
return True
@@ -76,6 +74,7 @@ class NotificationCallback(PostTransitionCallback):
"""Get the NotificationService instance."""
try:
from apps.accounts.services.notification_service import NotificationService
return NotificationService()
except ImportError:
logger.warning("NotificationService not available")
@@ -86,18 +85,16 @@ class NotificationCallback(PostTransitionCallback):
extra_data = {}
if self.include_transition_data:
extra_data['transition'] = {
'source_state': context.source_state,
'target_state': context.target_state,
'field_name': context.field_name,
'timestamp': context.timestamp.isoformat(),
extra_data["transition"] = {
"source_state": context.source_state,
"target_state": context.target_state,
"field_name": context.field_name,
"timestamp": context.timestamp.isoformat(),
}
if context.user:
extra_data['transition']['by_user_id'] = context.user.id
extra_data['transition']['by_username'] = getattr(
context.user, 'username', str(context.user)
)
extra_data["transition"]["by_user_id"] = context.user.id
extra_data["transition"]["by_username"] = getattr(context.user, "username", str(context.user))
# Include any extra data from the context
extra_data.update(context.extra_data)
@@ -112,10 +109,7 @@ class NotificationCallback(PostTransitionCallback):
def _get_notification_message(self, context: TransitionContext) -> str:
"""Get the notification message based on context."""
model_name = context.model_name
return (
f"The {model_name} has transitioned from {context.source_state} "
f"to {context.target_state}."
)
return f"The {model_name} has transitioned from {context.source_state} " f"to {context.target_state}."
def execute(self, context: TransitionContext) -> bool:
"""Execute the notification callback."""
@@ -140,16 +134,11 @@ class NotificationCallback(PostTransitionCallback):
extra_data=extra_data,
)
logger.info(
f"Created {self.notification_type} notification for "
f"{recipient} on {context}"
)
logger.info(f"Created {self.notification_type} notification for " f"{recipient} on {context}")
return True
except Exception as e:
logger.exception(
f"Failed to create notification for {context}: {e}"
)
logger.exception(f"Failed to create notification for {context}: {e}")
return False
@@ -176,8 +165,8 @@ class SubmissionApprovedNotification(NotificationCallback):
def _get_submission_type(self, context: TransitionContext) -> str:
"""Get the submission type from context or instance."""
# Try to get from extra_data first
if 'submission_type' in context.extra_data:
return context.extra_data['submission_type']
if "submission_type" in context.extra_data:
return context.extra_data["submission_type"]
# Fall back to model name
return self.submission_type or context.model_name.lower()
@@ -193,10 +182,10 @@ class SubmissionApprovedNotification(NotificationCallback):
try:
submission_type = self._get_submission_type(context)
additional_message = context.extra_data.get('comment', '')
additional_message = context.extra_data.get("comment", "")
# Use the specific method if available
if hasattr(notification_service, 'create_submission_approved_notification'):
if hasattr(notification_service, "create_submission_approved_notification"):
notification_service.create_submission_approved_notification(
user=recipient,
submission_object=context.instance,
@@ -215,15 +204,11 @@ class SubmissionApprovedNotification(NotificationCallback):
extra_data=extra_data,
)
logger.info(
f"Created approval notification for {recipient} on {context}"
)
logger.info(f"Created approval notification for {recipient} on {context}")
return True
except Exception as e:
logger.exception(
f"Failed to create approval notification for {context}: {e}"
)
logger.exception(f"Failed to create approval notification for {context}: {e}")
return False
@@ -250,8 +235,8 @@ class SubmissionRejectedNotification(NotificationCallback):
def _get_submission_type(self, context: TransitionContext) -> str:
"""Get the submission type from context or instance."""
# Try to get from extra_data first
if 'submission_type' in context.extra_data:
return context.extra_data['submission_type']
if "submission_type" in context.extra_data:
return context.extra_data["submission_type"]
# Fall back to model name
return self.submission_type or context.model_name.lower()
@@ -268,11 +253,11 @@ class SubmissionRejectedNotification(NotificationCallback):
try:
submission_type = self._get_submission_type(context)
# Extract rejection reason from extra_data
rejection_reason = context.extra_data.get('reason', 'No reason provided')
additional_message = context.extra_data.get('comment', '')
rejection_reason = context.extra_data.get("reason", "No reason provided")
additional_message = context.extra_data.get("comment", "")
# Use the specific method if available
if hasattr(notification_service, 'create_submission_rejected_notification'):
if hasattr(notification_service, "create_submission_rejected_notification"):
notification_service.create_submission_rejected_notification(
user=recipient,
submission_object=context.instance,
@@ -291,15 +276,11 @@ class SubmissionRejectedNotification(NotificationCallback):
extra_data=extra_data,
)
logger.info(
f"Created rejection notification for {recipient} on {context}"
)
logger.info(f"Created rejection notification for {recipient} on {context}")
return True
except Exception as e:
logger.exception(
f"Failed to create rejection notification for {context}: {e}"
)
logger.exception(f"Failed to create rejection notification for {context}: {e}")
return False
@@ -326,6 +307,7 @@ class SubmissionEscalatedNotification(NotificationCallback):
"""Get admin users to notify."""
try:
from django.contrib.auth import get_user_model
user_model = get_user_model()
return user_model.objects.filter(is_staff=True, is_active=True)
except Exception as e:
@@ -340,9 +322,9 @@ class SubmissionEscalatedNotification(NotificationCallback):
try:
extra_data = self._build_extra_data(context)
escalation_reason = context.extra_data.get('reason', '')
escalation_reason = context.extra_data.get("reason", "")
if escalation_reason:
extra_data['escalation_reason'] = escalation_reason
extra_data["escalation_reason"] = escalation_reason
title = f"{context.model_name} escalated for review"
message = f"A {context.model_name} has been escalated and requires attention."
@@ -361,9 +343,7 @@ class SubmissionEscalatedNotification(NotificationCallback):
related_object=context.instance,
extra_data=extra_data,
)
logger.info(
f"Created escalation notifications for {admins.count()} admins"
)
logger.info(f"Created escalation notifications for {admins.count()} admins")
else:
# Notify the submitter
recipient = self._get_recipient(context.instance)
@@ -376,16 +356,12 @@ class SubmissionEscalatedNotification(NotificationCallback):
related_object=context.instance,
extra_data=extra_data,
)
logger.info(
f"Created escalation notification for {recipient}"
)
logger.info(f"Created escalation notification for {recipient}")
return True
except Exception as e:
logger.exception(
f"Failed to create escalation notification for {context}: {e}"
)
logger.exception(f"Failed to create escalation notification for {context}: {e}")
return False
@@ -415,16 +391,14 @@ class StatusChangeNotification(NotificationCallback):
notification_type="status_change",
**kwargs,
)
self.significant_states = significant_states or [
'CLOSED_PERM', 'DEMOLISHED', 'RELOCATED'
]
self.significant_states = significant_states or ["CLOSED_PERM", "DEMOLISHED", "RELOCATED"]
self.notify_admins = notify_admins
def should_execute(self, context: TransitionContext) -> bool:
"""Only execute for significant state changes."""
# Check if notifications are disabled
callback_settings = getattr(settings, 'STATE_MACHINE_CALLBACKS', {})
if not callback_settings.get('notifications_enabled', True):
callback_settings = getattr(settings, "STATE_MACHINE_CALLBACKS", {})
if not callback_settings.get("notifications_enabled", True):
return False
# Only notify for significant status changes
@@ -441,16 +415,13 @@ class StatusChangeNotification(NotificationCallback):
try:
extra_data = self._build_extra_data(context)
extra_data['entity_type'] = context.model_name
extra_data['entity_id'] = context.instance.pk
extra_data["entity_type"] = context.model_name
extra_data["entity_id"] = context.instance.pk
# Build title and message
entity_name = getattr(context.instance, 'name', str(context.instance))
entity_name = getattr(context.instance, "name", str(context.instance))
title = f"{context.model_name} status changed to {context.target_state}"
message = (
f"{entity_name} has changed status from {context.source_state} "
f"to {context.target_state}."
)
message = f"{entity_name} has changed status from {context.source_state} " f"to {context.target_state}."
# Notify admin users
admins = self._get_admin_users()
@@ -471,15 +442,14 @@ class StatusChangeNotification(NotificationCallback):
return True
except Exception as e:
logger.exception(
f"Failed to create status change notification for {context}: {e}"
)
logger.exception(f"Failed to create status change notification for {context}: {e}")
return False
def _get_admin_users(self):
"""Get admin users to notify."""
try:
from django.contrib.auth import get_user_model
user_model = get_user_model()
return user_model.objects.filter(is_staff=True, is_active=True)
except Exception as e:
@@ -499,13 +469,13 @@ class ModerationNotificationCallback(NotificationCallback):
# Mapping of (model_name, target_state) to notification type
NOTIFICATION_MAPPING = {
('ModerationReport', 'UNDER_REVIEW'): 'report_under_review',
('ModerationReport', 'RESOLVED'): 'report_resolved',
('ModerationQueue', 'IN_PROGRESS'): 'queue_in_progress',
('ModerationQueue', 'COMPLETED'): 'queue_completed',
('BulkOperation', 'RUNNING'): 'bulk_operation_started',
('BulkOperation', 'COMPLETED'): 'bulk_operation_completed',
('BulkOperation', 'FAILED'): 'bulk_operation_failed',
("ModerationReport", "UNDER_REVIEW"): "report_under_review",
("ModerationReport", "RESOLVED"): "report_resolved",
("ModerationQueue", "IN_PROGRESS"): "queue_in_progress",
("ModerationQueue", "COMPLETED"): "queue_completed",
("BulkOperation", "RUNNING"): "bulk_operation_started",
("BulkOperation", "COMPLETED"): "bulk_operation_completed",
("BulkOperation", "FAILED"): "bulk_operation_failed",
}
def __init__(self, **kwargs):
@@ -522,7 +492,7 @@ class ModerationNotificationCallback(NotificationCallback):
def _get_recipient(self, instance: models.Model) -> Any | None:
"""Get the appropriate recipient based on model type."""
# Try common recipient fields
for field in ['reporter', 'assigned_to', 'created_by', 'submitted_by']:
for field in ["reporter", "assigned_to", "created_by", "submitted_by"]:
recipient = getattr(instance, field, None)
if recipient:
return recipient
@@ -531,31 +501,28 @@ class ModerationNotificationCallback(NotificationCallback):
def _get_notification_title(self, context: TransitionContext, notification_type: str) -> str:
"""Get the notification title based on notification type."""
titles = {
'report_under_review': 'Your report is under review',
'report_resolved': 'Your report has been resolved',
'queue_in_progress': 'Moderation queue item in progress',
'queue_completed': 'Moderation queue item completed',
'bulk_operation_started': 'Bulk operation started',
'bulk_operation_completed': 'Bulk operation completed',
'bulk_operation_failed': 'Bulk operation failed',
"report_under_review": "Your report is under review",
"report_resolved": "Your report has been resolved",
"queue_in_progress": "Moderation queue item in progress",
"queue_completed": "Moderation queue item completed",
"bulk_operation_started": "Bulk operation started",
"bulk_operation_completed": "Bulk operation completed",
"bulk_operation_failed": "Bulk operation failed",
}
return titles.get(notification_type, f"{context.model_name} status updated")
def _get_notification_message(self, context: TransitionContext, notification_type: str) -> str:
"""Get the notification message based on notification type."""
messages = {
'report_under_review': 'Your moderation report is now being reviewed by our team.',
'report_resolved': 'Your moderation report has been reviewed and resolved.',
'queue_in_progress': 'A moderation queue item is now being processed.',
'queue_completed': 'A moderation queue item has been completed.',
'bulk_operation_started': 'Your bulk operation has started processing.',
'bulk_operation_completed': 'Your bulk operation has completed successfully.',
'bulk_operation_failed': 'Your bulk operation encountered an error and could not complete.',
"report_under_review": "Your moderation report is now being reviewed by our team.",
"report_resolved": "Your moderation report has been reviewed and resolved.",
"queue_in_progress": "A moderation queue item is now being processed.",
"queue_completed": "A moderation queue item has been completed.",
"bulk_operation_started": "Your bulk operation has started processing.",
"bulk_operation_completed": "Your bulk operation has completed successfully.",
"bulk_operation_failed": "Your bulk operation encountered an error and could not complete.",
}
return messages.get(
notification_type,
f"The {context.model_name} has been updated to {context.target_state}."
)
return messages.get(notification_type, f"The {context.model_name} has been updated to {context.target_state}.")
def execute(self, context: TransitionContext) -> bool:
"""Execute the moderation notification."""
@@ -565,10 +532,7 @@ class ModerationNotificationCallback(NotificationCallback):
notification_type = self._get_notification_type(context)
if not notification_type:
logger.debug(
f"No notification type defined for {context.model_name} "
f"{context.target_state}"
)
logger.debug(f"No notification type defined for {context.model_name} " f"{context.target_state}")
return True # Not an error, just no notification needed
recipient = self._get_recipient(context.instance)
@@ -587,13 +551,9 @@ class ModerationNotificationCallback(NotificationCallback):
extra_data=extra_data,
)
logger.info(
f"Created {notification_type} notification for {recipient}"
)
logger.info(f"Created {notification_type} notification for {recipient}")
return True
except Exception as e:
logger.exception(
f"Failed to create moderation notification for {context}: {e}"
)
logger.exception(f"Failed to create moderation notification for {context}: {e}")
return False

View File

@@ -45,8 +45,8 @@ class RelatedModelUpdateCallback(PostTransitionCallback):
def should_execute(self, context: TransitionContext) -> bool:
"""Check if related updates are enabled."""
callback_settings = getattr(settings, 'STATE_MACHINE_CALLBACKS', {})
if not callback_settings.get('related_updates_enabled', True):
callback_settings = getattr(settings, "STATE_MACHINE_CALLBACKS", {})
if not callback_settings.get("related_updates_enabled", True):
logger.debug("Related model updates disabled via settings")
return False
return True
@@ -77,9 +77,7 @@ class RelatedModelUpdateCallback(PostTransitionCallback):
return self.perform_update(context)
except Exception as e:
logger.exception(
f"Failed to update related models for {context}: {e}"
)
logger.exception(f"Failed to update related models for {context}: {e}")
return False
@@ -94,10 +92,10 @@ class ParkCountUpdateCallback(RelatedModelUpdateCallback):
name: str = "ParkCountUpdateCallback"
# Status values that count as "active" rides
ACTIVE_STATUSES = {'OPERATING', 'SEASONAL', 'UNDER_CONSTRUCTION'}
ACTIVE_STATUSES = {"OPERATING", "SEASONAL", "UNDER_CONSTRUCTION"}
# Status values that indicate a ride is no longer countable
INACTIVE_STATUSES = {'CLOSED_PERM', 'DEMOLISHED', 'RELOCATED', 'REMOVED'}
INACTIVE_STATUSES = {"CLOSED_PERM", "DEMOLISHED", "RELOCATED", "REMOVED"}
def should_execute(self, context: TransitionContext) -> bool:
"""Only execute when status affects ride counts."""
@@ -115,14 +113,14 @@ class ParkCountUpdateCallback(RelatedModelUpdateCallback):
return source_affects or target_affects
# Category value for roller coasters (from rides domain choices)
COASTER_CATEGORY = 'RC'
COASTER_CATEGORY = "RC"
def perform_update(self, context: TransitionContext) -> bool:
"""Update park ride counts."""
instance = context.instance
# Get the parent park
park = getattr(instance, 'park', None)
park = getattr(instance, "park", None)
if not park:
logger.debug(f"No park found for ride {instance.pk}")
return True
@@ -133,22 +131,17 @@ class ParkCountUpdateCallback(RelatedModelUpdateCallback):
from apps.rides.models.rides import Ride
# Get the park ID (handle both object and ID)
park_id = park.pk if hasattr(park, 'pk') else park
park_id = park.pk if hasattr(park, "pk") else park
# Calculate new counts efficiently
ride_queryset = Ride.objects.filter(park_id=park_id)
# Count active rides
active_statuses = list(self.ACTIVE_STATUSES)
ride_count = ride_queryset.filter(
status__in=active_statuses
).count()
ride_count = ride_queryset.filter(status__in=active_statuses).count()
# Count active coasters (category='RC' for Roller Coaster)
coaster_count = ride_queryset.filter(
status__in=active_statuses,
category=self.COASTER_CATEGORY
).count()
coaster_count = ride_queryset.filter(status__in=active_statuses, category=self.COASTER_CATEGORY).count()
# Update park counts
Park.objects.filter(id=park_id).update(
@@ -156,16 +149,11 @@ class ParkCountUpdateCallback(RelatedModelUpdateCallback):
coaster_count=coaster_count,
)
logger.info(
f"Updated park {park_id} counts: "
f"ride_count={ride_count}, coaster_count={coaster_count}"
)
logger.info(f"Updated park {park_id} counts: " f"ride_count={ride_count}, coaster_count={coaster_count}")
return True
except Exception as e:
logger.exception(
f"Failed to update park counts for {instance.pk}: {e}"
)
logger.exception(f"Failed to update park counts for {instance.pk}: {e}")
return False
@@ -184,20 +172,16 @@ class SearchTextUpdateCallback(RelatedModelUpdateCallback):
instance = context.instance
# Check if instance has search_text field
if not hasattr(instance, 'search_text'):
logger.debug(
f"{context.model_name} has no search_text field"
)
if not hasattr(instance, "search_text"):
logger.debug(f"{context.model_name} has no search_text field")
return True
try:
# Call the model's update_search_text method if available
if hasattr(instance, 'update_search_text'):
if hasattr(instance, "update_search_text"):
instance.update_search_text()
instance.save(update_fields=['search_text'])
logger.info(
f"Updated search_text for {context.model_name} {instance.pk}"
)
instance.save(update_fields=["search_text"])
logger.info(f"Updated search_text for {context.model_name} {instance.pk}")
else:
# Build search text manually
self._build_search_text(instance, context)
@@ -205,9 +189,7 @@ class SearchTextUpdateCallback(RelatedModelUpdateCallback):
return True
except Exception as e:
logger.exception(
f"Failed to update search_text for {instance.pk}: {e}"
)
logger.exception(f"Failed to update search_text for {instance.pk}: {e}")
return False
def _build_search_text(
@@ -219,7 +201,7 @@ class SearchTextUpdateCallback(RelatedModelUpdateCallback):
parts = []
# Common searchable fields
for field in ['name', 'title', 'description', 'location']:
for field in ["name", "title", "description", "location"]:
value = getattr(instance, field, None)
if value:
parts.append(str(value))
@@ -228,15 +210,15 @@ class SearchTextUpdateCallback(RelatedModelUpdateCallback):
status_field = getattr(instance, context.field_name, None)
if status_field:
# Try to get the display label
display_method = f'get_{context.field_name}_display'
display_method = f"get_{context.field_name}_display"
if hasattr(instance, display_method):
parts.append(getattr(instance, display_method)())
else:
parts.append(str(status_field))
# Update search_text
instance.search_text = ' '.join(parts)
instance.save(update_fields=['search_text'])
instance.search_text = " ".join(parts)
instance.save(update_fields=["search_text"])
class ComputedFieldUpdateCallback(RelatedModelUpdateCallback):
@@ -280,7 +262,7 @@ class ComputedFieldUpdateCallback(RelatedModelUpdateCallback):
# Update specific fields
updated_fields = []
for field_name in self.computed_fields:
update_method_name = f'compute_{field_name}'
update_method_name = f"compute_{field_name}"
if hasattr(instance, update_method_name):
method = getattr(instance, update_method_name)
if callable(method):
@@ -291,17 +273,12 @@ class ComputedFieldUpdateCallback(RelatedModelUpdateCallback):
# Save updated fields
if updated_fields:
instance.save(update_fields=updated_fields)
logger.info(
f"Updated computed fields {updated_fields} for "
f"{context.model_name} {instance.pk}"
)
logger.info(f"Updated computed fields {updated_fields} for " f"{context.model_name} {instance.pk}")
return True
except Exception as e:
logger.exception(
f"Failed to update computed fields for {instance.pk}: {e}"
)
logger.exception(f"Failed to update computed fields for {instance.pk}: {e}")
return False
@@ -320,7 +297,7 @@ class RideStatusUpdateCallback(RelatedModelUpdateCallback):
return False
# Only execute for Ride model
return context.model_name == 'Ride'
return context.model_name == "Ride"
def perform_update(self, context: TransitionContext) -> bool:
"""Perform ride-specific status updates."""
@@ -329,22 +306,18 @@ class RideStatusUpdateCallback(RelatedModelUpdateCallback):
try:
# Handle CLOSING → post_closing_status transition
if context.source_state == 'CLOSING' and target != 'CLOSING':
post_closing_status = getattr(instance, 'post_closing_status', None)
if context.source_state == "CLOSING" and target != "CLOSING":
post_closing_status = getattr(instance, "post_closing_status", None)
if post_closing_status and target == post_closing_status:
# Clear post_closing_status after applying it
instance.post_closing_status = None
instance.save(update_fields=['post_closing_status'])
logger.info(
f"Cleared post_closing_status for ride {instance.pk}"
)
instance.save(update_fields=["post_closing_status"])
logger.info(f"Cleared post_closing_status for ride {instance.pk}")
return True
except Exception as e:
logger.exception(
f"Failed to update ride status fields for {instance.pk}: {e}"
)
logger.exception(f"Failed to update ride status fields for {instance.pk}: {e}")
return False
@@ -362,9 +335,7 @@ class ModerationQueueUpdateCallback(RelatedModelUpdateCallback):
# Only for submission and report models
model_name = context.model_name
return model_name in (
'EditSubmission', 'PhotoSubmission', 'ModerationReport'
)
return model_name in ("EditSubmission", "PhotoSubmission", "ModerationReport")
def perform_update(self, context: TransitionContext) -> bool:
"""Update moderation queue entries."""
@@ -373,15 +344,13 @@ class ModerationQueueUpdateCallback(RelatedModelUpdateCallback):
try:
# Mark related queue items as completed when submission is resolved
if target in ('APPROVED', 'REJECTED', 'RESOLVED'):
if target in ("APPROVED", "REJECTED", "RESOLVED"):
self._update_queue_items(instance, context)
return True
except Exception as e:
logger.exception(
f"Failed to update moderation queue for {instance.pk}: {e}"
)
logger.exception(f"Failed to update moderation queue for {instance.pk}: {e}")
return False
def _update_queue_items(
@@ -401,20 +370,18 @@ class ModerationQueueUpdateCallback(RelatedModelUpdateCallback):
queue_items = ModerationQueue.objects.filter(
content_type_id=content_type_id,
object_id=instance.pk,
status='IN_PROGRESS',
status="IN_PROGRESS",
)
for item in queue_items:
if hasattr(item, 'complete'):
if hasattr(item, "complete"):
item.complete(user=context.user)
else:
item.status = 'COMPLETED'
item.save(update_fields=['status'])
item.status = "COMPLETED"
item.save(update_fields=["status"])
if queue_items.exists():
logger.info(
f"Marked {queue_items.count()} queue items as completed"
)
logger.info(f"Marked {queue_items.count()} queue items as completed")
except ImportError:
logger.debug("ModerationQueue model not available")
@@ -425,6 +392,7 @@ class ModerationQueueUpdateCallback(RelatedModelUpdateCallback):
"""Get content type ID for the instance."""
try:
from django.contrib.contenttypes.models import ContentType
content_type = ContentType.objects.get_for_model(type(instance))
return content_type.pk
except Exception:

View File

@@ -33,7 +33,7 @@ class ModelCallbackConfig:
"""Configuration for all callbacks on a model."""
model_name: str
field_name: str = 'status'
field_name: str = "status"
transitions: dict[tuple, TransitionCallbackConfig] = field(default_factory=dict)
default_config: TransitionCallbackConfig = field(default_factory=TransitionCallbackConfig)
@@ -53,12 +53,12 @@ class CallbackConfig:
# Default settings
DEFAULT_SETTINGS = {
'enabled': True,
'notifications_enabled': True,
'cache_invalidation_enabled': True,
'related_updates_enabled': True,
'debug_mode': False,
'log_callbacks': False,
"enabled": True,
"notifications_enabled": True,
"cache_invalidation_enabled": True,
"related_updates_enabled": True,
"debug_mode": False,
"log_callbacks": False,
}
# Model-specific configurations
@@ -70,7 +70,7 @@ class CallbackConfig:
def _load_settings(self) -> dict[str, Any]:
"""Load settings from Django configuration."""
django_settings = getattr(settings, 'STATE_MACHINE_CALLBACKS', {})
django_settings = getattr(settings, "STATE_MACHINE_CALLBACKS", {})
merged = dict(self.DEFAULT_SETTINGS)
merged.update(django_settings)
return merged
@@ -78,123 +78,123 @@ class CallbackConfig:
def _build_model_configs(self) -> dict[str, ModelCallbackConfig]:
"""Build model-specific configurations."""
return {
'EditSubmission': ModelCallbackConfig(
model_name='EditSubmission',
field_name='status',
"EditSubmission": ModelCallbackConfig(
model_name="EditSubmission",
field_name="status",
transitions={
('PENDING', 'APPROVED'): TransitionCallbackConfig(
notification_template='submission_approved',
cache_patterns=['*submission*', '*moderation*'],
("PENDING", "APPROVED"): TransitionCallbackConfig(
notification_template="submission_approved",
cache_patterns=["*submission*", "*moderation*"],
),
('PENDING', 'REJECTED'): TransitionCallbackConfig(
notification_template='submission_rejected',
cache_patterns=['*submission*', '*moderation*'],
("PENDING", "REJECTED"): TransitionCallbackConfig(
notification_template="submission_rejected",
cache_patterns=["*submission*", "*moderation*"],
),
('PENDING', 'ESCALATED'): TransitionCallbackConfig(
notification_template='submission_escalated',
cache_patterns=['*submission*', '*moderation*'],
("PENDING", "ESCALATED"): TransitionCallbackConfig(
notification_template="submission_escalated",
cache_patterns=["*submission*", "*moderation*"],
),
},
),
'PhotoSubmission': ModelCallbackConfig(
model_name='PhotoSubmission',
field_name='status',
"PhotoSubmission": ModelCallbackConfig(
model_name="PhotoSubmission",
field_name="status",
transitions={
('PENDING', 'APPROVED'): TransitionCallbackConfig(
notification_template='photo_approved',
cache_patterns=['*photo*', '*moderation*'],
("PENDING", "APPROVED"): TransitionCallbackConfig(
notification_template="photo_approved",
cache_patterns=["*photo*", "*moderation*"],
),
('PENDING', 'REJECTED'): TransitionCallbackConfig(
notification_template='photo_rejected',
cache_patterns=['*photo*', '*moderation*'],
("PENDING", "REJECTED"): TransitionCallbackConfig(
notification_template="photo_rejected",
cache_patterns=["*photo*", "*moderation*"],
),
},
),
'ModerationReport': ModelCallbackConfig(
model_name='ModerationReport',
field_name='status',
"ModerationReport": ModelCallbackConfig(
model_name="ModerationReport",
field_name="status",
transitions={
('PENDING', 'UNDER_REVIEW'): TransitionCallbackConfig(
notification_template='report_under_review',
cache_patterns=['*report*', '*moderation*'],
("PENDING", "UNDER_REVIEW"): TransitionCallbackConfig(
notification_template="report_under_review",
cache_patterns=["*report*", "*moderation*"],
),
('UNDER_REVIEW', 'RESOLVED'): TransitionCallbackConfig(
notification_template='report_resolved',
cache_patterns=['*report*', '*moderation*'],
("UNDER_REVIEW", "RESOLVED"): TransitionCallbackConfig(
notification_template="report_resolved",
cache_patterns=["*report*", "*moderation*"],
),
},
),
'ModerationQueue': ModelCallbackConfig(
model_name='ModerationQueue',
field_name='status',
"ModerationQueue": ModelCallbackConfig(
model_name="ModerationQueue",
field_name="status",
transitions={
('PENDING', 'IN_PROGRESS'): TransitionCallbackConfig(
notification_template='queue_in_progress',
cache_patterns=['*queue*', '*moderation*'],
("PENDING", "IN_PROGRESS"): TransitionCallbackConfig(
notification_template="queue_in_progress",
cache_patterns=["*queue*", "*moderation*"],
),
('IN_PROGRESS', 'COMPLETED'): TransitionCallbackConfig(
notification_template='queue_completed',
cache_patterns=['*queue*', '*moderation*'],
("IN_PROGRESS", "COMPLETED"): TransitionCallbackConfig(
notification_template="queue_completed",
cache_patterns=["*queue*", "*moderation*"],
),
},
),
'BulkOperation': ModelCallbackConfig(
model_name='BulkOperation',
field_name='status',
"BulkOperation": ModelCallbackConfig(
model_name="BulkOperation",
field_name="status",
transitions={
('PENDING', 'RUNNING'): TransitionCallbackConfig(
notification_template='bulk_operation_started',
cache_patterns=['*operation*', '*moderation*'],
("PENDING", "RUNNING"): TransitionCallbackConfig(
notification_template="bulk_operation_started",
cache_patterns=["*operation*", "*moderation*"],
),
('RUNNING', 'COMPLETED'): TransitionCallbackConfig(
notification_template='bulk_operation_completed',
cache_patterns=['*operation*', '*moderation*'],
("RUNNING", "COMPLETED"): TransitionCallbackConfig(
notification_template="bulk_operation_completed",
cache_patterns=["*operation*", "*moderation*"],
),
('RUNNING', 'FAILED'): TransitionCallbackConfig(
notification_template='bulk_operation_failed',
cache_patterns=['*operation*', '*moderation*'],
("RUNNING", "FAILED"): TransitionCallbackConfig(
notification_template="bulk_operation_failed",
cache_patterns=["*operation*", "*moderation*"],
),
},
),
'Park': ModelCallbackConfig(
model_name='Park',
field_name='status',
"Park": ModelCallbackConfig(
model_name="Park",
field_name="status",
default_config=TransitionCallbackConfig(
cache_patterns=['*park*', 'api:*', 'geo:*'],
cache_patterns=["*park*", "api:*", "geo:*"],
),
transitions={
('*', 'CLOSED_PERM'): TransitionCallbackConfig(
("*", "CLOSED_PERM"): TransitionCallbackConfig(
notifications_enabled=True,
notification_template='park_closed_permanently',
cache_patterns=['*park*', 'api:*', 'geo:*'],
notification_template="park_closed_permanently",
cache_patterns=["*park*", "api:*", "geo:*"],
),
('*', 'OPERATING'): TransitionCallbackConfig(
("*", "OPERATING"): TransitionCallbackConfig(
notifications_enabled=False,
cache_patterns=['*park*', 'api:*', 'geo:*'],
cache_patterns=["*park*", "api:*", "geo:*"],
),
},
),
'Ride': ModelCallbackConfig(
model_name='Ride',
field_name='status',
"Ride": ModelCallbackConfig(
model_name="Ride",
field_name="status",
default_config=TransitionCallbackConfig(
cache_patterns=['*ride*', '*park*', 'api:*', 'geo:*'],
cache_patterns=["*ride*", "*park*", "api:*", "geo:*"],
),
transitions={
('*', 'OPERATING'): TransitionCallbackConfig(
cache_patterns=['*ride*', '*park*', 'api:*', 'geo:*'],
("*", "OPERATING"): TransitionCallbackConfig(
cache_patterns=["*ride*", "*park*", "api:*", "geo:*"],
related_updates_enabled=True,
),
('*', 'CLOSED_PERM'): TransitionCallbackConfig(
cache_patterns=['*ride*', '*park*', 'api:*', 'geo:*'],
("*", "CLOSED_PERM"): TransitionCallbackConfig(
cache_patterns=["*ride*", "*park*", "api:*", "geo:*"],
related_updates_enabled=True,
),
('*', 'DEMOLISHED'): TransitionCallbackConfig(
cache_patterns=['*ride*', '*park*', 'api:*', 'geo:*'],
("*", "DEMOLISHED"): TransitionCallbackConfig(
cache_patterns=["*ride*", "*park*", "api:*", "geo:*"],
related_updates_enabled=True,
),
('*', 'RELOCATED'): TransitionCallbackConfig(
cache_patterns=['*ride*', '*park*', 'api:*', 'geo:*'],
("*", "RELOCATED"): TransitionCallbackConfig(
cache_patterns=["*ride*", "*park*", "api:*", "geo:*"],
related_updates_enabled=True,
),
},
@@ -204,32 +204,32 @@ class CallbackConfig:
@property
def enabled(self) -> bool:
"""Check if callbacks are globally enabled."""
return self._settings.get('enabled', True)
return self._settings.get("enabled", True)
@property
def notifications_enabled(self) -> bool:
"""Check if notification callbacks are enabled."""
return self._settings.get('notifications_enabled', True)
return self._settings.get("notifications_enabled", True)
@property
def cache_invalidation_enabled(self) -> bool:
"""Check if cache invalidation is enabled."""
return self._settings.get('cache_invalidation_enabled', True)
return self._settings.get("cache_invalidation_enabled", True)
@property
def related_updates_enabled(self) -> bool:
"""Check if related model updates are enabled."""
return self._settings.get('related_updates_enabled', True)
return self._settings.get("related_updates_enabled", True)
@property
def debug_mode(self) -> bool:
"""Check if debug mode is enabled."""
return self._settings.get('debug_mode', False)
return self._settings.get("debug_mode", False)
@property
def log_callbacks(self) -> bool:
"""Check if callback logging is enabled."""
return self._settings.get('log_callbacks', False)
return self._settings.get("log_callbacks", False)
def get_config(
self,
@@ -258,12 +258,12 @@ class CallbackConfig:
return config
# Try wildcard source
config = model_config.transitions.get(('*', target))
config = model_config.transitions.get(("*", target))
if config:
return config
# Try wildcard target
config = model_config.transitions.get((source, '*'))
config = model_config.transitions.get((source, "*"))
if config:
return config
@@ -362,9 +362,7 @@ class CallbackConfig:
**kwargs: Configuration values to update.
"""
if model_name not in self._model_configs:
self._model_configs[model_name] = ModelCallbackConfig(
model_name=model_name
)
self._model_configs[model_name] = ModelCallbackConfig(model_name=model_name)
model_config = self._model_configs[model_name]
transition_key = (source, target)
@@ -394,9 +392,9 @@ def get_callback_config() -> CallbackConfig:
__all__ = [
'TransitionCallbackConfig',
'ModelCallbackConfig',
'CallbackConfig',
'callback_config',
'get_callback_config',
"TransitionCallbackConfig",
"ModelCallbackConfig",
"CallbackConfig",
"callback_config",
"get_callback_config",
]

View File

@@ -1,4 +1,5 @@
"""Transition decorator generation for django-fsm integration."""
import logging
from collections.abc import Callable
from functools import wraps
@@ -51,51 +52,42 @@ def with_callbacks(
@wraps(func)
def wrapper(instance, *args, **kwargs):
# Extract user from kwargs
user = kwargs.get('user')
user = kwargs.get("user")
# Get source state before transition
source_state = getattr(instance, field_name, None)
# Get target state from the transition decorator
# The @transition decorator sets _django_fsm_target
target_state = getattr(func, '_django_fsm', {}).get('target', None)
target_state = getattr(func, "_django_fsm", {}).get("target", None)
# If we can't determine the target from decorator metadata,
# we'll capture it after the transition
if target_state is None:
# This happens when decorators are applied in wrong order
logger.debug(
f"Could not determine target state from decorator for {func.__name__}"
)
logger.debug(f"Could not determine target state from decorator for {func.__name__}")
# Create transition context
context = TransitionContext(
instance=instance,
field_name=field_name,
source_state=str(source_state) if source_state else '',
target_state=str(target_state) if target_state else '',
source_state=str(source_state) if source_state else "",
target_state=str(target_state) if target_state else "",
user=user,
extra_data=dict(kwargs),
)
# Execute pre-transition callbacks
pre_success, pre_failures = callback_registry.execute_callbacks(
context, CallbackStage.PRE
)
pre_success, pre_failures = callback_registry.execute_callbacks(context, CallbackStage.PRE)
# If pre-callbacks fail with continue_on_error=False, abort
if not pre_success and pre_failures:
for callback, exc in pre_failures:
if not callback.continue_on_error:
logger.error(
f"Pre-transition callback {callback.name} failed, "
f"aborting transition"
)
logger.error(f"Pre-transition callback {callback.name} failed, " f"aborting transition")
if exc:
raise exc
raise RuntimeError(
f"Pre-transition callback {callback.name} failed"
)
raise RuntimeError(f"Pre-transition callback {callback.name} failed")
# Emit pre-transition signal
if emit_signals:
@@ -114,19 +106,14 @@ def with_callbacks(
# Update context with actual target state after transition
actual_target = getattr(instance, field_name, None)
context.target_state = str(actual_target) if actual_target else ''
context.target_state = str(actual_target) if actual_target else ""
# Execute post-transition callbacks
post_success, post_failures = callback_registry.execute_callbacks(
context, CallbackStage.POST
)
post_success, post_failures = callback_registry.execute_callbacks(context, CallbackStage.POST)
if not post_success:
for callback, exc in post_failures:
logger.warning(
f"Post-transition callback {callback.name} failed "
f"for {context}"
)
for callback, _exc in post_failures:
logger.warning(f"Post-transition callback {callback.name} failed " f"for {context}")
# Emit post-transition signal
if emit_signals:
@@ -236,9 +223,7 @@ def create_transition_method(
on_success(instance, user=user, **kwargs)
transition_method.__name__ = method_name
transition_method.__doc__ = (
f"Transition from {source} to {target} on field {field_name}"
)
transition_method.__doc__ = f"Transition from {source} to {target} on field {field_name}"
# Apply callback wrapper if enabled
if enable_callbacks:
@@ -249,10 +234,10 @@ def create_transition_method(
# Store metadata for callback registration
transition_method._fsm_metadata = {
'source': source,
'target': target,
'field_name': field_name,
'callbacks': callbacks or [],
"source": source,
"target": target,
"field_name": field_name,
"callbacks": callbacks or [],
}
return transition_method
@@ -271,21 +256,21 @@ def register_method_callbacks(
model_class: The model class containing the method.
method: The transition method with _fsm_metadata.
"""
metadata = getattr(method, '_fsm_metadata', None)
if not metadata or not metadata.get('callbacks'):
metadata = getattr(method, "_fsm_metadata", None)
if not metadata or not metadata.get("callbacks"):
return
from .callback_base import CallbackStage, PreTransitionCallback
for callback in metadata['callbacks']:
for callback in metadata["callbacks"]:
# Determine stage from callback type
stage = CallbackStage.PRE if isinstance(callback, PreTransitionCallback) else CallbackStage.POST
callback_registry.register(
model_class=model_class,
field_name=metadata['field_name'],
source=metadata['source'],
target=metadata['target'],
field_name=metadata["field_name"],
source=metadata["source"],
target=metadata["target"],
callback=callback,
stage=stage,
)
@@ -490,9 +475,7 @@ class TransitionMethodFactory:
if docstring:
generic_transition.__doc__ = docstring
else:
generic_transition.__doc__ = (
f"Transition from {source} to {target}"
)
generic_transition.__doc__ = f"Transition from {source} to {target}"
# Apply callback wrapper if enabled
if enable_callbacks:

View File

@@ -12,6 +12,7 @@ Example usage:
'code': e.error_code
}, status=403)
"""
from typing import Any
from django_fsm import TransitionNotAllowed
@@ -214,29 +215,18 @@ ERROR_MESSAGES = {
"You need {required_role} permissions to {action}. "
"Please contact an administrator if you believe this is an error."
),
"PERMISSION_DENIED_OWNERSHIP": (
"You must be the owner of this item to perform this action."
),
"PERMISSION_DENIED_OWNERSHIP": ("You must be the owner of this item to perform this action."),
"PERMISSION_DENIED_ASSIGNMENT": (
"This item must be assigned to you before you can {action}. "
"Please assign it to yourself first."
),
"NO_ASSIGNMENT": (
"This item must be assigned before this action can be performed."
"This item must be assigned to you before you can {action}. " "Please assign it to yourself first."
),
"NO_ASSIGNMENT": ("This item must be assigned before this action can be performed."),
"INVALID_STATE_TRANSITION": (
"This action cannot be performed from the current state. "
"The item is currently '{current_state}' and cannot be modified."
),
"TRANSITION_NOT_AVAILABLE": (
"This {item_type} has already been {state} and cannot be modified."
),
"MISSING_REQUIRED_FIELD": (
"{field_name} is required to complete this action."
),
"EMPTY_REQUIRED_FIELD": (
"{field_name} must not be empty."
),
"TRANSITION_NOT_AVAILABLE": ("This {item_type} has already been {state} and cannot be modified."),
"MISSING_REQUIRED_FIELD": ("{field_name} is required to complete this action."),
"EMPTY_REQUIRED_FIELD": ("{field_name} must not be empty."),
"ESCALATED_REQUIRES_ADMIN": (
"This submission has been escalated and requires admin review. "
"Only administrators can approve or reject escalated items."

View File

@@ -47,6 +47,7 @@ See Also:
- apps.core.choices.registry: The central choice registry
- apps.core.state_machine.mixins.StateMachineMixin: Convenience helpers
"""
from typing import Any
from django.core.exceptions import ValidationError
@@ -138,14 +139,10 @@ class RichFSMField(DjangoFSMField):
choice = registry.get_choice(self.choice_group, value, self.domain)
if choice is None:
raise ValidationError(
f"'{value}' is not a valid state for {self.choice_group}"
)
raise ValidationError(f"'{value}' is not a valid state for {self.choice_group}")
if choice.deprecated and not self.allow_deprecated:
raise ValidationError(
f"'{value}' is deprecated and cannot be used for new entries"
)
raise ValidationError(f"'{value}' is deprecated and cannot be used for new entries")
def get_rich_choice(self, value: str) -> RichChoice | None:
"""Return the RichChoice object for a given state value."""
@@ -155,9 +152,7 @@ class RichFSMField(DjangoFSMField):
"""Return the label for the given state value."""
return registry.get_choice_display(self.choice_group, value, self.domain)
def contribute_to_class(
self, cls: Any, name: str, private_only: bool = False, **kwargs: Any
) -> None:
def contribute_to_class(self, cls: Any, name: str, private_only: bool = False, **kwargs: Any) -> None:
"""Attach helpers to the model for convenience."""
super().contribute_to_class(cls, name, private_only=private_only, **kwargs)

View File

@@ -18,6 +18,7 @@ Example usage:
OwnershipGuard()
], operator='OR')
"""
from collections.abc import Callable
from typing import Any, Optional
@@ -534,7 +535,7 @@ class MetadataGuard:
self._last_error_code = self.ERROR_CODE_EMPTY_FIELD
self._failed_field = field_name
return False
if isinstance(value, (list, dict)) and not value:
if isinstance(value, list | dict) and not value:
self._last_error_code = self.ERROR_CODE_EMPTY_FIELD
self._failed_field = field_name
return False
@@ -787,8 +788,7 @@ def extract_guards_from_metadata(metadata: dict[str, Any]) -> list[Callable]:
# Zero tolerance requires superuser
if zero_tolerance:
guard = PermissionGuard(
requires_superuser=True,
error_message="Zero tolerance violations require superuser permissions"
requires_superuser=True, error_message="Zero tolerance violations require superuser permissions"
)
guards.append(guard)
elif requires_moderator or requires_admin or escalation_level:
@@ -801,7 +801,7 @@ def extract_guards_from_metadata(metadata: dict[str, Any]) -> list[Callable]:
assignment_guard = AssignmentGuard(
require_assignment=True,
allow_admin_override=True,
error_message="This item must be assigned to you before this action can be performed"
error_message="This item must be assigned to you before this action can be performed",
)
guards.append(assignment_guard)
@@ -814,7 +814,7 @@ def extract_guards_from_metadata(metadata: dict[str, Any]) -> list[Callable]:
perm_guard = PermissionGuard(
custom_check=check_permissions,
error_message=f"Missing required permissions: {', '.join(required_permissions)}"
error_message=f"Missing required permissions: {', '.join(required_permissions)}",
)
guards.append(perm_guard)
@@ -1072,7 +1072,7 @@ def has_role(user: Any, required_roles: list[str]) -> bool:
# Only apply if role field is not set
if user_role is None:
# Check for superuser (Django's is_superuser flag)
if hasattr(user, "is_superuser") and user.is_superuser:
if hasattr(user, "is_superuser") and user.is_superuser: # noqa: SIM102
if "SUPERUSER" in required_roles or "ADMIN" in required_roles:
return True
@@ -1248,7 +1248,7 @@ def create_guard_from_drf_permission(
self._last_error_code = "PERMISSION_DENIED"
return False
if hasattr(permission, "has_object_permission"):
if hasattr(permission, "has_object_permission"): # noqa: SIM102
if not permission.has_object_permission(mock_request, None, instance):
self._last_error_code = "OBJECT_PERMISSION_DENIED"
return False

View File

@@ -1,4 +1,5 @@
"""Model integration utilities for applying state machines to Django models."""
from collections.abc import Callable
from typing import Any
@@ -46,18 +47,13 @@ def apply_state_machine(
if not result.is_valid:
error_messages = [str(e) for e in result.errors]
raise ValueError(
"Cannot apply state machine - validation failed:\n"
+ "\n".join(error_messages)
)
raise ValueError("Cannot apply state machine - validation failed:\n" + "\n".join(error_messages))
# Build transition registry
registry_instance.build_registry_from_choices(choice_group, domain)
# Generate and attach transition methods
generate_transition_methods_for_model(
model_class, field_name, choice_group, domain
)
generate_transition_methods_for_model(model_class, field_name, choice_group, domain)
def generate_transition_methods_for_model(
@@ -140,15 +136,10 @@ def generate_transition_methods_for_model(
setattr(model_class, method_name, method)
class StateMachineModelMixin:
"""Mixin providing state machine helper methods for models."""
def get_available_state_transitions(
self, field_name: str = "status"
) -> list[TransitionInfo]:
def get_available_state_transitions(self, field_name: str = "status") -> list[TransitionInfo]:
"""
Get available transitions from current state.
@@ -167,9 +158,7 @@ class StateMachineModelMixin:
domain = field.domain
current_state = getattr(self, field_name)
return registry_instance.get_available_transitions(
choice_group, domain, current_state
)
return registry_instance.get_available_transitions(choice_group, domain, current_state)
def can_transition_to(
self,
@@ -199,9 +188,7 @@ class StateMachineModelMixin:
domain = field.domain
# Check if transition exists in registry
transition = registry_instance.get_transition(
choice_group, domain, current_state, target_state
)
transition = registry_instance.get_transition(choice_group, domain, current_state, target_state)
if not transition:
return False
@@ -216,9 +203,7 @@ class StateMachineModelMixin:
# Use django-fsm's can_proceed
return can_proceed(method)
def get_transition_method(
self, target_state: str, field_name: str = "status"
) -> Callable | None:
def get_transition_method(self, target_state: str, field_name: str = "status") -> Callable | None:
"""
Get the transition method for moving to target state.
@@ -238,9 +223,7 @@ class StateMachineModelMixin:
choice_group = field.choice_group
domain = field.domain
transition = registry_instance.get_transition(
choice_group, domain, current_state, target_state
)
transition = registry_instance.get_transition(choice_group, domain, current_state, target_state)
if not transition:
return None
@@ -270,9 +253,7 @@ class StateMachineModelMixin:
ValueError: If transition is not allowed
"""
if not self.can_transition_to(target_state, field_name, user):
raise ValueError(
f"Cannot transition to {target_state} from current state"
)
raise ValueError(f"Cannot transition to {target_state} from current state")
method = self.get_transition_method(target_state, field_name)
if method is None:
@@ -283,9 +264,7 @@ class StateMachineModelMixin:
return True
def state_machine_model(
field_name: str, choice_group: str, domain: str = "core"
):
def state_machine_model(field_name: str, choice_group: str, domain: str = "core"):
"""
Class decorator to automatically apply state machine to models.
@@ -306,9 +285,7 @@ def state_machine_model(
return decorator
def validate_model_state_machine(
model_class: type[models.Model], field_name: str
) -> bool:
def validate_model_state_machine(model_class: type[models.Model], field_name: str) -> bool:
"""
Ensure model is properly configured with state machine.
@@ -326,13 +303,11 @@ def validate_model_state_machine(
try:
field = model_class._meta.get_field(field_name)
except Exception:
raise ValueError(f"Field {field_name} not found on {model_class}")
raise ValueError(f"Field {field_name} not found on {model_class}") from None
# Check if field has choice_group attribute
if not hasattr(field, "choice_group"):
raise ValueError(
f"Field {field_name} is not a RichFSMField or RichChoiceField"
)
raise ValueError(f"Field {field_name} is not a RichFSMField or RichChoiceField")
# Validate metadata
choice_group = field.choice_group
@@ -343,9 +318,7 @@ def validate_model_state_machine(
if not result.is_valid:
error_messages = [str(e) for e in result.errors]
raise ValueError(
"State machine validation failed:\n" + "\n".join(error_messages)
)
raise ValueError("State machine validation failed:\n" + "\n".join(error_messages))
return True

View File

@@ -38,6 +38,7 @@ See Also:
- apps.core.state_machine.fields.RichFSMField: The FSM field implementation
- django_fsm.can_proceed: FSM transition checking utility
"""
from collections.abc import Iterable
from typing import Any
@@ -47,25 +48,75 @@ from django_fsm import can_proceed
# Default transition metadata for styling
TRANSITION_METADATA = {
# Approval transitions
"approve": {"style": "green", "icon": "check", "requires_confirm": True, "confirm_message": "Are you sure you want to approve this?"},
"transition_to_approved": {"style": "green", "icon": "check", "requires_confirm": True, "confirm_message": "Are you sure you want to approve this?"},
"approve": {
"style": "green",
"icon": "check",
"requires_confirm": True,
"confirm_message": "Are you sure you want to approve this?",
},
"transition_to_approved": {
"style": "green",
"icon": "check",
"requires_confirm": True,
"confirm_message": "Are you sure you want to approve this?",
},
# Rejection transitions
"reject": {"style": "red", "icon": "times", "requires_confirm": True, "confirm_message": "Are you sure you want to reject this?"},
"transition_to_rejected": {"style": "red", "icon": "times", "requires_confirm": True, "confirm_message": "Are you sure you want to reject this?"},
"reject": {
"style": "red",
"icon": "times",
"requires_confirm": True,
"confirm_message": "Are you sure you want to reject this?",
},
"transition_to_rejected": {
"style": "red",
"icon": "times",
"requires_confirm": True,
"confirm_message": "Are you sure you want to reject this?",
},
# Escalation transitions
"escalate": {"style": "yellow", "icon": "arrow-up", "requires_confirm": True, "confirm_message": "Are you sure you want to escalate this?"},
"transition_to_escalated": {"style": "yellow", "icon": "arrow-up", "requires_confirm": True, "confirm_message": "Are you sure you want to escalate this?"},
"escalate": {
"style": "yellow",
"icon": "arrow-up",
"requires_confirm": True,
"confirm_message": "Are you sure you want to escalate this?",
},
"transition_to_escalated": {
"style": "yellow",
"icon": "arrow-up",
"requires_confirm": True,
"confirm_message": "Are you sure you want to escalate this?",
},
# Assignment transitions
"assign": {"style": "blue", "icon": "user-plus", "requires_confirm": False},
"unassign": {"style": "gray", "icon": "user-minus", "requires_confirm": False},
# Status transitions
"start": {"style": "blue", "icon": "play", "requires_confirm": False},
"complete": {"style": "green", "icon": "check-circle", "requires_confirm": True, "confirm_message": "Are you sure you want to complete this?"},
"cancel": {"style": "red", "icon": "ban", "requires_confirm": True, "confirm_message": "Are you sure you want to cancel this?"},
"complete": {
"style": "green",
"icon": "check-circle",
"requires_confirm": True,
"confirm_message": "Are you sure you want to complete this?",
},
"cancel": {
"style": "red",
"icon": "ban",
"requires_confirm": True,
"confirm_message": "Are you sure you want to cancel this?",
},
"reopen": {"style": "blue", "icon": "redo", "requires_confirm": False},
# Resolution transitions
"resolve": {"style": "green", "icon": "check-double", "requires_confirm": True, "confirm_message": "Are you sure you want to resolve this?"},
"dismiss": {"style": "gray", "icon": "times-circle", "requires_confirm": True, "confirm_message": "Are you sure you want to dismiss this?"},
"resolve": {
"style": "green",
"icon": "check-double",
"requires_confirm": True,
"confirm_message": "Are you sure you want to resolve this?",
},
"dismiss": {
"style": "gray",
"icon": "times-circle",
"requires_confirm": True,
"confirm_message": "Are you sure you want to dismiss this?",
},
# Default
"default": {"style": "gray", "icon": "arrow-right", "requires_confirm": False},
}
@@ -86,22 +137,22 @@ def _get_transition_metadata(transition_name: str) -> dict[str, Any]:
def _format_transition_label(transition_name: str) -> str:
"""Format a transition method name into a human-readable label."""
label = transition_name
for prefix in ['transition_to_', 'transition_', 'do_']:
for prefix in ["transition_to_", "transition_", "do_"]:
if label.startswith(prefix):
label = label[len(prefix):]
label = label[len(prefix) :]
break
if label.endswith('ed') and len(label) > 3:
if label.endswith('ied'):
label = label[:-3] + 'y'
if label.endswith("ed") and len(label) > 3:
if label.endswith("ied"):
label = label[:-3] + "y"
elif label[-3] == label[-4]:
label = label[:-3]
else:
label = label[:-1]
if not label.endswith('e'):
if not label.endswith("e"):
label = label[:-1]
return label.replace('_', ' ').title()
return label.replace("_", " ").title()
class StateMachineMixin(models.Model):
@@ -187,14 +238,10 @@ class StateMachineMixin(models.Model):
"""Check if a transition method can proceed for the current instance."""
method = getattr(self, transition_method_name, None)
if method is None or not callable(method):
raise AttributeError(
f"Transition method '{transition_method_name}' not found"
)
raise AttributeError(f"Transition method '{transition_method_name}' not found")
return can_proceed(method)
def get_available_transitions(
self, field_name: str | None = None
) -> Iterable[Any]:
def get_available_transitions(self, field_name: str | None = None) -> Iterable[Any]:
"""Return available transitions when helpers are present."""
name = field_name or self.state_field_name
helper_name = f"get_available_{name}_transitions"
@@ -246,14 +293,16 @@ class StateMachineMixin(models.Model):
try:
if can_proceed(method, user):
metadata = _get_transition_metadata(transition_name)
transitions.append({
'name': transition_name,
'label': _format_transition_label(transition_name),
'icon': metadata.get('icon', 'arrow-right'),
'style': metadata.get('style', 'gray'),
'requires_confirm': metadata.get('requires_confirm', False),
'confirm_message': metadata.get('confirm_message', 'Are you sure?'),
})
transitions.append(
{
"name": transition_name,
"label": _format_transition_label(transition_name),
"icon": metadata.get("icon", "arrow-right"),
"style": metadata.get("style", "gray"),
"requires_confirm": metadata.get("requires_confirm", False),
"confirm_message": metadata.get("confirm_message", "Are you sure?"),
}
)
except Exception:
# Skip transitions that raise errors
pass

View File

@@ -47,7 +47,7 @@ class CallbackStats:
successful_executions: int = 0
failed_executions: int = 0
total_duration_ms: float = 0.0
min_duration_ms: float = float('inf')
min_duration_ms: float = float("inf")
max_duration_ms: float = 0.0
last_execution: datetime | None = None
last_error: str | None = None
@@ -97,10 +97,10 @@ class CallbackMonitor:
- Performance statistics
"""
_instance: Optional['CallbackMonitor'] = None
_instance: Optional["CallbackMonitor"] = None
_lock = threading.Lock()
def __new__(cls) -> 'CallbackMonitor':
def __new__(cls) -> "CallbackMonitor":
if cls._instance is None:
with cls._lock:
if cls._instance is None:
@@ -112,9 +112,7 @@ class CallbackMonitor:
if self._initialized:
return
self._stats: dict[str, CallbackStats] = defaultdict(
lambda: CallbackStats(callback_name="")
)
self._stats: dict[str, CallbackStats] = defaultdict(lambda: CallbackStats(callback_name=""))
self._recent_executions: list[CallbackExecutionRecord] = []
self._max_recent_records = 1000
self._enabled = self._check_enabled()
@@ -123,13 +121,13 @@ class CallbackMonitor:
def _check_enabled(self) -> bool:
"""Check if monitoring is enabled."""
callback_settings = getattr(settings, 'STATE_MACHINE_CALLBACKS', {})
return callback_settings.get('monitoring_enabled', True)
callback_settings = getattr(settings, "STATE_MACHINE_CALLBACKS", {})
return callback_settings.get("monitoring_enabled", True)
def _check_debug_mode(self) -> bool:
"""Check if debug mode is enabled."""
callback_settings = getattr(settings, 'STATE_MACHINE_CALLBACKS', {})
return callback_settings.get('debug_mode', settings.DEBUG)
callback_settings = getattr(settings, "STATE_MACHINE_CALLBACKS", {})
return callback_settings.get("debug_mode", settings.DEBUG)
def is_enabled(self) -> bool:
"""Check if monitoring is currently enabled."""
@@ -197,7 +195,7 @@ class CallbackMonitor:
# Store recent executions (with size limit)
self._recent_executions.append(record)
if len(self._recent_executions) > self._max_recent_records:
self._recent_executions = self._recent_executions[-self._max_recent_records:]
self._recent_executions = self._recent_executions[-self._max_recent_records :]
# Log in debug mode
if self._debug_mode:
@@ -277,12 +275,12 @@ class CallbackMonitor:
# Build summary
summary = {
'total_failures': len(failures),
'by_callback': {
"total_failures": len(failures),
"by_callback": {
name: {
'count': len(records),
'last_error': records[-1].error_message if records else None,
'last_occurrence': records[-1].timestamp if records else None,
"count": len(records),
"last_error": records[-1].error_message if records else None,
"last_occurrence": records[-1].timestamp if records else None,
}
for name, records in by_callback.items()
},
@@ -293,12 +291,12 @@ class CallbackMonitor:
def get_performance_report(self) -> dict[str, Any]:
"""Get a performance report for all callbacks."""
report = {
'callbacks': {},
'summary': {
'total_callbacks': len(self._stats),
'total_executions': sum(s.total_executions for s in self._stats.values()),
'total_failures': sum(s.failed_executions for s in self._stats.values()),
'avg_duration_ms': 0.0,
"callbacks": {},
"summary": {
"total_callbacks": len(self._stats),
"total_executions": sum(s.total_executions for s in self._stats.values()),
"total_failures": sum(s.failed_executions for s in self._stats.values()),
"avg_duration_ms": 0.0,
},
}
@@ -306,19 +304,19 @@ class CallbackMonitor:
total_count = 0
for name, stats in self._stats.items():
report['callbacks'][name] = {
'executions': stats.total_executions,
'success_rate': f"{stats.success_rate:.1f}%",
'avg_duration_ms': f"{stats.avg_duration_ms:.2f}",
'min_duration_ms': f"{stats.min_duration_ms:.2f}" if stats.min_duration_ms != float('inf') else "N/A",
'max_duration_ms': f"{stats.max_duration_ms:.2f}",
'last_execution': stats.last_execution.isoformat() if stats.last_execution else None,
report["callbacks"][name] = {
"executions": stats.total_executions,
"success_rate": f"{stats.success_rate:.1f}%",
"avg_duration_ms": f"{stats.avg_duration_ms:.2f}",
"min_duration_ms": f"{stats.min_duration_ms:.2f}" if stats.min_duration_ms != float("inf") else "N/A",
"max_duration_ms": f"{stats.max_duration_ms:.2f}",
"last_execution": stats.last_execution.isoformat() if stats.last_execution else None,
}
total_duration += stats.total_duration_ms
total_count += stats.total_executions
if total_count > 0:
report['summary']['avg_duration_ms'] = total_duration / total_count
report["summary"]["avg_duration_ms"] = total_duration / total_count
return report
@@ -361,7 +359,7 @@ class TimedCallbackExecution:
self.success = True
self.error_message: str | None = None
def __enter__(self) -> 'TimedCallbackExecution':
def __enter__(self) -> "TimedCallbackExecution":
self.start_time = time.perf_counter()
return self
@@ -442,12 +440,12 @@ def get_callback_execution_order(
__all__ = [
'CallbackExecutionRecord',
'CallbackStats',
'CallbackMonitor',
'callback_monitor',
'TimedCallbackExecution',
'log_transition_start',
'log_transition_end',
'get_callback_execution_order',
"CallbackExecutionRecord",
"CallbackStats",
"CallbackMonitor",
"callback_monitor",
"TimedCallbackExecution",
"log_transition_start",
"log_transition_end",
"get_callback_execution_order",
]

View File

@@ -1,4 +1,5 @@
"""TransitionRegistry - Centralized registry for managing FSM transitions."""
import logging
from collections.abc import Callable
from dataclasses import dataclass, field
@@ -86,9 +87,7 @@ class TransitionRegistry:
self._transitions[key][transition_key] = transition_info
return transition_info
def get_transition(
self, choice_group: str, domain: str, source: str, target: str
) -> TransitionInfo | None:
def get_transition(self, choice_group: str, domain: str, source: str, target: str) -> TransitionInfo | None:
"""
Retrieve transition info.
@@ -109,9 +108,7 @@ class TransitionRegistry:
return self._transitions[key].get(transition_key)
def get_available_transitions(
self, choice_group: str, domain: str, current_state: str
) -> list[TransitionInfo]:
def get_available_transitions(self, choice_group: str, domain: str, current_state: str) -> list[TransitionInfo]:
"""
Get all valid transitions from a state.
@@ -135,9 +132,7 @@ class TransitionRegistry:
return available
def get_transition_method_name(
self, choice_group: str, domain: str, source: str, target: str
) -> str | None:
def get_transition_method_name(self, choice_group: str, domain: str, source: str, target: str) -> str | None:
"""
Get the method name for a transition.
@@ -153,9 +148,7 @@ class TransitionRegistry:
transition = self.get_transition(choice_group, domain, source, target)
return transition.method_name if transition else None
def validate_transition(
self, choice_group: str, domain: str, source: str, target: str
) -> bool:
def validate_transition(self, choice_group: str, domain: str, source: str, target: str) -> bool:
"""
Check if a transition is valid.
@@ -168,13 +161,9 @@ class TransitionRegistry:
Returns:
True if transition is valid
"""
return (
self.get_transition(choice_group, domain, source, target) is not None
)
return self.get_transition(choice_group, domain, source, target) is not None
def build_registry_from_choices(
self, choice_group: str, domain: str = "core"
) -> None:
def build_registry_from_choices(self, choice_group: str, domain: str = "core") -> None:
"""
Automatically populate registry from RichChoice metadata.
@@ -194,9 +183,7 @@ class TransitionRegistry:
for target in targets:
# Use shared method name determination
method_name = determine_method_name_for_transition(
source, target
)
method_name = determine_method_name_for_transition(source, target)
self.register_transition(
choice_group=choice_group,
@@ -226,9 +213,7 @@ class TransitionRegistry:
else:
self._transitions.clear()
def export_transition_graph(
self, choice_group: str, domain: str, format: str = "dict"
) -> Any:
def export_transition_graph(self, choice_group: str, domain: str, format: str = "dict") -> Any:
"""
Export state machine graph for visualization.
@@ -247,7 +232,7 @@ class TransitionRegistry:
if format == "dict":
graph: dict[str, list[str]] = {}
for (source, target), info in self._transitions[key].items():
for (source, target), _info in self._transitions[key].items():
if source not in graph:
graph[source] = []
graph[source].append(target)
@@ -262,10 +247,7 @@ class TransitionRegistry:
elif format == "dot":
lines = ["digraph {"]
for (source, target), info in self._transitions[key].items():
lines.append(
f' "{source}" -> "{target}" '
f'[label="{info.method_name}"];'
)
lines.append(f' "{source}" -> "{target}" ' f'[label="{info.method_name}"];')
lines.append("}")
return "\n".join(lines)
@@ -288,13 +270,14 @@ registry_instance = TransitionRegistry()
# Callback registration helpers
def register_callback(
model_class: type[models.Model],
field_name: str,
source: str,
target: str,
callback: Any,
stage: str = 'post',
stage: str = "post",
) -> None:
"""
Register a callback for a specific state transition.
@@ -325,7 +308,7 @@ def register_notification_callback(
source: str,
target: str,
notification_type: str,
recipient_field: str = 'submitted_by',
recipient_field: str = "submitted_by",
) -> None:
"""
Register a notification callback for a state transition.
@@ -344,15 +327,15 @@ def register_notification_callback(
notification_type=notification_type,
recipient_field=recipient_field,
)
register_callback(model_class, field_name, source, target, callback, 'post')
register_callback(model_class, field_name, source, target, callback, "post")
def register_cache_invalidation(
model_class: type[models.Model],
field_name: str,
cache_patterns: list[str] | None = None,
source: str = '*',
target: str = '*',
source: str = "*",
target: str = "*",
) -> None:
"""
Register cache invalidation for state transitions.
@@ -367,15 +350,15 @@ def register_cache_invalidation(
from .callbacks.cache import CacheInvalidationCallback
callback = CacheInvalidationCallback(patterns=cache_patterns or [])
register_callback(model_class, field_name, source, target, callback, 'post')
register_callback(model_class, field_name, source, target, callback, "post")
def register_related_update(
model_class: type[models.Model],
field_name: str,
update_func: Callable,
source: str = '*',
target: str = '*',
source: str = "*",
target: str = "*",
) -> None:
"""
Register a related model update callback.
@@ -390,7 +373,7 @@ def register_related_update(
from .callbacks.related_updates import RelatedModelUpdateCallback
callback = RelatedModelUpdateCallback(update_function=update_func)
register_callback(model_class, field_name, source, target, callback, 'post')
register_callback(model_class, field_name, source, target, callback, "post")
def register_transition_callbacks(cls: type[models.Model]) -> type[models.Model]:
@@ -414,20 +397,20 @@ def register_transition_callbacks(cls: type[models.Model]) -> type[models.Model]
Returns:
The decorated model class.
"""
meta = getattr(cls, 'Meta', None)
meta = getattr(cls, "Meta", None)
if not meta:
return cls
transition_callbacks = getattr(meta, 'transition_callbacks', None)
transition_callbacks = getattr(meta, "transition_callbacks", None)
if not transition_callbacks:
return cls
# Get the FSM field name
field_name = getattr(meta, 'fsm_field', 'status')
field_name = getattr(meta, "fsm_field", "status")
# Register each callback
for (source, target), callbacks in transition_callbacks.items():
if not isinstance(callbacks, (list, tuple)):
if not isinstance(callbacks, list | tuple):
callbacks = [callbacks]
for callback in callbacks:
@@ -455,23 +438,23 @@ def discover_and_register_callbacks() -> None:
for model in apps.get_models():
# Check if model has StateMachineMixin
if not hasattr(model, '_fsm_metadata') and not hasattr(model, 'Meta'):
if not hasattr(model, "_fsm_metadata") and not hasattr(model, "Meta"):
continue
meta = getattr(model, 'Meta', None)
meta = getattr(model, "Meta", None)
if not meta:
continue
transition_callbacks = getattr(meta, 'transition_callbacks', None)
transition_callbacks = getattr(meta, "transition_callbacks", None)
if not transition_callbacks:
continue
# Get the FSM field name
field_name = getattr(meta, 'fsm_field', 'status')
field_name = getattr(meta, "fsm_field", "status")
# Register callbacks
for (source, target), callbacks in transition_callbacks.items():
if not isinstance(callbacks, (list, tuple)):
if not isinstance(callbacks, list | tuple):
callbacks = [callbacks]
for callback in callbacks:

View File

@@ -74,7 +74,7 @@ class TransitionSignalHandler:
source: str,
target: str,
handler: Callable,
stage: str = 'post',
stage: str = "post",
) -> None:
"""
Register a handler for a specific transition.
@@ -95,10 +95,7 @@ class TransitionSignalHandler:
signal = self._get_signal(stage)
self._connect_signal(signal, model_class, source, target, handler)
logger.debug(
f"Registered {stage} transition handler for "
f"{model_class.__name__}: {source}{target}"
)
logger.debug(f"Registered {stage} transition handler for " f"{model_class.__name__}: {source}{target}")
def unregister(
self,
@@ -106,7 +103,7 @@ class TransitionSignalHandler:
source: str,
target: str,
handler: Callable,
stage: str = 'post',
stage: str = "post",
) -> None:
"""Unregister a previously registered handler."""
key = self._make_key(model_class, source, target, stage)
@@ -128,9 +125,9 @@ class TransitionSignalHandler:
def _get_signal(self, stage: str) -> Signal:
"""Get the signal for a given stage."""
if stage == 'pre':
if stage == "pre":
return pre_state_transition
elif stage == 'error':
elif stage == "error":
return state_transition_failed
return post_state_transition
@@ -150,13 +147,13 @@ class TransitionSignalHandler:
return
# Check source state
signal_source = kwargs.get('source', '')
if source != '*' and str(signal_source) != source:
signal_source = kwargs.get("source", "")
if source != "*" and str(signal_source) != source:
return
# Check target state
signal_target = kwargs.get('target', '')
if target != '*' and str(signal_target) != target:
signal_target = kwargs.get("target", "")
if target != "*" and str(signal_target) != target:
return
# Call the handler
@@ -174,7 +171,7 @@ def register_transition_handler(
source: str,
target: str,
handler: Callable,
stage: str = 'post',
stage: str = "post",
) -> None:
"""
Convenience function to register a transition signal handler.
@@ -186,9 +183,7 @@ def register_transition_handler(
handler: The handler function to call.
stage: 'pre', 'post', or 'error'.
"""
transition_signal_handler.register(
model_class, source, target, handler, stage
)
transition_signal_handler.register(model_class, source, target, handler, stage)
def connect_fsm_log_signals() -> None:
@@ -231,9 +226,9 @@ class TransitionHandlerDecorator:
def __init__(
self,
model_class: type[models.Model],
source: str = '*',
target: str = '*',
stage: str = 'post',
source: str = "*",
target: str = "*",
stage: str = "post",
):
"""
Initialize the decorator.
@@ -263,9 +258,9 @@ class TransitionHandlerDecorator:
def on_transition(
model_class: type[models.Model],
source: str = '*',
target: str = '*',
stage: str = 'post',
source: str = "*",
target: str = "*",
stage: str = "post",
) -> TransitionHandlerDecorator:
"""
Decorator factory for registering transition handlers.
@@ -289,44 +284,44 @@ def on_transition(
def on_pre_transition(
model_class: type[models.Model],
source: str = '*',
target: str = '*',
source: str = "*",
target: str = "*",
) -> TransitionHandlerDecorator:
"""Decorator for pre-transition handlers."""
return on_transition(model_class, source, target, stage='pre')
return on_transition(model_class, source, target, stage="pre")
def on_post_transition(
model_class: type[models.Model],
source: str = '*',
target: str = '*',
source: str = "*",
target: str = "*",
) -> TransitionHandlerDecorator:
"""Decorator for post-transition handlers."""
return on_transition(model_class, source, target, stage='post')
return on_transition(model_class, source, target, stage="post")
def on_transition_error(
model_class: type[models.Model],
source: str = '*',
target: str = '*',
source: str = "*",
target: str = "*",
) -> TransitionHandlerDecorator:
"""Decorator for transition error handlers."""
return on_transition(model_class, source, target, stage='error')
return on_transition(model_class, source, target, stage="error")
__all__ = [
# Signals
'pre_state_transition',
'post_state_transition',
'state_transition_failed',
"pre_state_transition",
"post_state_transition",
"state_transition_failed",
# Handler registration
'TransitionSignalHandler',
'transition_signal_handler',
'register_transition_handler',
'connect_fsm_log_signals',
"TransitionSignalHandler",
"transition_signal_handler",
"register_transition_handler",
"connect_fsm_log_signals",
# Decorators
'on_transition',
'on_pre_transition',
'on_post_transition',
'on_transition_error',
"on_transition",
"on_pre_transition",
"on_post_transition",
"on_transition_error",
]

View File

@@ -29,11 +29,11 @@ class UserFactory:
@classmethod
def create_user(
cls,
role: str = 'USER',
role: str = "USER",
username: str | None = None,
email: str | None = None,
password: str = 'testpass123',
**kwargs
password: str = "testpass123",
**kwargs,
) -> User:
"""
Create a user with specified role.
@@ -54,33 +54,27 @@ class UserFactory:
if email is None:
email = f"{role.lower()}_{uid}@example.com"
return User.objects.create_user(
username=username,
email=email,
password=password,
role=role,
**kwargs
)
return User.objects.create_user(username=username, email=email, password=password, role=role, **kwargs)
@classmethod
def create_regular_user(cls, **kwargs) -> User:
"""Create a regular user."""
return cls.create_user(role='USER', **kwargs)
return cls.create_user(role="USER", **kwargs)
@classmethod
def create_moderator(cls, **kwargs) -> User:
"""Create a moderator user."""
return cls.create_user(role='MODERATOR', **kwargs)
return cls.create_user(role="MODERATOR", **kwargs)
@classmethod
def create_admin(cls, **kwargs) -> User:
"""Create an admin user."""
return cls.create_user(role='ADMIN', **kwargs)
return cls.create_user(role="ADMIN", **kwargs)
@classmethod
def create_superuser(cls, **kwargs) -> User:
"""Create a superuser."""
return cls.create_user(role='SUPERUSER', **kwargs)
return cls.create_user(role="SUPERUSER", **kwargs)
class CompanyFactory:
@@ -102,11 +96,7 @@ class CompanyFactory:
if name is None:
name = f"Test Operator {uid}"
defaults = {
'name': name,
'description': f'Test operator company {uid}',
'roles': ['OPERATOR']
}
defaults = {"name": name, "description": f"Test operator company {uid}", "roles": ["OPERATOR"]}
defaults.update(kwargs)
return Company.objects.create(**defaults)
@@ -119,11 +109,7 @@ class CompanyFactory:
if name is None:
name = f"Test Manufacturer {uid}"
defaults = {
'name': name,
'description': f'Test manufacturer company {uid}',
'roles': ['MANUFACTURER']
}
defaults = {"name": name, "description": f"Test manufacturer company {uid}", "roles": ["MANUFACTURER"]}
defaults.update(kwargs)
return Company.objects.create(**defaults)
@@ -140,11 +126,7 @@ class ParkFactory:
@classmethod
def create_park(
cls,
name: str | None = None,
operator: Any | None = None,
status: str = 'OPERATING',
**kwargs
cls, name: str | None = None, operator: Any | None = None, status: str = "OPERATING", **kwargs
) -> Any:
"""
Create a park with specified status.
@@ -167,12 +149,12 @@ class ParkFactory:
operator = CompanyFactory.create_operator()
defaults = {
'name': name,
'slug': f'test-park-{uid}',
'description': f'A test park {uid}',
'operator': operator,
'status': status,
'timezone': 'America/New_York'
"name": name,
"slug": f"test-park-{uid}",
"description": f"A test park {uid}",
"operator": operator,
"status": status,
"timezone": "America/New_York",
}
defaults.update(kwargs)
return Park.objects.create(**defaults)
@@ -194,8 +176,8 @@ class RideFactory:
name: str | None = None,
park: Any | None = None,
manufacturer: Any | None = None,
status: str = 'OPERATING',
**kwargs
status: str = "OPERATING",
**kwargs,
) -> Any:
"""
Create a ride with specified status.
@@ -221,12 +203,12 @@ class RideFactory:
manufacturer = CompanyFactory.create_manufacturer()
defaults = {
'name': name,
'slug': f'test-ride-{uid}',
'description': f'A test ride {uid}',
'park': park,
'manufacturer': manufacturer,
'status': status
"name": name,
"slug": f"test-ride-{uid}",
"description": f"A test ride {uid}",
"park": park,
"manufacturer": manufacturer,
"status": status,
}
defaults.update(kwargs)
return Ride.objects.create(**defaults)
@@ -247,9 +229,9 @@ class EditSubmissionFactory:
cls,
user: Any | None = None,
target_object: Any | None = None,
status: str = 'PENDING',
status: str = "PENDING",
changes: dict[str, Any] | None = None,
**kwargs
**kwargs,
) -> Any:
"""
Create an edit submission.
@@ -271,23 +253,20 @@ class EditSubmissionFactory:
if user is None:
user = UserFactory.create_regular_user()
if target_object is None:
target_object = Company.objects.create(
name=f'Target Company {uid}',
description='Test company'
)
target_object = Company.objects.create(name=f"Target Company {uid}", description="Test company")
if changes is None:
changes = {'name': f'Updated Name {uid}'}
changes = {"name": f"Updated Name {uid}"}
content_type = ContentType.objects.get_for_model(target_object)
defaults = {
'user': user,
'content_type': content_type,
'object_id': target_object.id,
'submission_type': 'EDIT',
'changes': changes,
'status': status,
'reason': f'Test reason {uid}'
"user": user,
"content_type": content_type,
"object_id": target_object.id,
"submission_type": "EDIT",
"changes": changes,
"status": status,
"reason": f"Test reason {uid}",
}
defaults.update(kwargs)
return EditSubmission.objects.create(**defaults)
@@ -305,11 +284,7 @@ class ModerationReportFactory:
@classmethod
def create_report(
cls,
reporter: Any | None = None,
target_object: Any | None = None,
status: str = 'PENDING',
**kwargs
cls, reporter: Any | None = None, target_object: Any | None = None, status: str = "PENDING", **kwargs
) -> Any:
"""
Create a moderation report.
@@ -330,23 +305,20 @@ class ModerationReportFactory:
if reporter is None:
reporter = UserFactory.create_regular_user()
if target_object is None:
target_object = Company.objects.create(
name=f'Reported Company {uid}',
description='Test company'
)
target_object = Company.objects.create(name=f"Reported Company {uid}", description="Test company")
content_type = ContentType.objects.get_for_model(target_object)
defaults = {
'report_type': 'CONTENT',
'status': status,
'priority': 'MEDIUM',
'reported_entity_type': target_object._meta.model_name,
'reported_entity_id': target_object.id,
'content_type': content_type,
'reason': f'Test reason {uid}',
'description': f'Test report description {uid}',
'reported_by': reporter
"report_type": "CONTENT",
"status": status,
"priority": "MEDIUM",
"reported_entity_type": target_object._meta.model_name,
"reported_entity_id": target_object.id,
"content_type": content_type,
"reason": f"Test reason {uid}",
"description": f"Test report description {uid}",
"reported_by": reporter,
}
defaults.update(kwargs)
return ModerationReport.objects.create(**defaults)
@@ -369,5 +341,5 @@ class MockInstance:
setattr(self, key, value)
def __repr__(self):
attrs = ', '.join(f'{k}={v!r}' for k, v in self.__dict__.items())
return f'MockInstance({attrs})'
attrs = ", ".join(f"{k}={v!r}" for k, v in self.__dict__.items())
return f"MockInstance({attrs})"

View File

@@ -13,11 +13,7 @@ from typing import Any
from django.contrib.contenttypes.models import ContentType
def assert_transition_allowed(
instance: Any,
method_name: str,
user: Any | None = None
) -> bool:
def assert_transition_allowed(instance: Any, method_name: str, user: Any | None = None) -> bool:
"""
Assert that a transition is allowed.
@@ -43,11 +39,7 @@ def assert_transition_allowed(
return True
def assert_transition_denied(
instance: Any,
method_name: str,
user: Any | None = None
) -> bool:
def assert_transition_denied(instance: Any, method_name: str, user: Any | None = None) -> bool:
"""
Assert that a transition is denied.
@@ -73,11 +65,7 @@ def assert_transition_denied(
return True
def assert_state_log_created(
instance: Any,
expected_state: str,
user: Any | None = None
) -> Any:
def assert_state_log_created(instance: Any, expected_state: str, user: Any | None = None) -> Any:
"""
Assert that a StateLog entry was created for a transition.
@@ -98,11 +86,7 @@ def assert_state_log_created(
from django_fsm_log.models import StateLog
ct = ContentType.objects.get_for_model(instance)
log = StateLog.objects.filter(
content_type=ct,
object_id=instance.id,
state=expected_state
).first()
log = StateLog.objects.filter(content_type=ct, object_id=instance.id, state=expected_state).first()
assert log is not None, f"StateLog for state '{expected_state}' not found"
@@ -132,22 +116,15 @@ def assert_state_log_count(instance: Any, expected_count: int) -> list[Any]:
from django_fsm_log.models import StateLog
ct = ContentType.objects.get_for_model(instance)
logs = list(StateLog.objects.filter(
content_type=ct,
object_id=instance.id
).order_by('timestamp'))
logs = list(StateLog.objects.filter(content_type=ct, object_id=instance.id).order_by("timestamp"))
actual_count = len(logs)
assert actual_count == expected_count, \
f"Expected {expected_count} StateLog entries, got {actual_count}"
assert actual_count == expected_count, f"Expected {expected_count} StateLog entries, got {actual_count}"
return logs
def assert_state_transition_sequence(
instance: Any,
expected_states: list[str]
) -> list[Any]:
def assert_state_transition_sequence(instance: Any, expected_states: list[str]) -> list[Any]:
"""
Assert that state transitions occurred in a specific sequence.
@@ -167,24 +144,15 @@ def assert_state_transition_sequence(
from django_fsm_log.models import StateLog
ct = ContentType.objects.get_for_model(instance)
logs = list(StateLog.objects.filter(
content_type=ct,
object_id=instance.id
).order_by('timestamp'))
logs = list(StateLog.objects.filter(content_type=ct, object_id=instance.id).order_by("timestamp"))
actual_states = [log.state for log in logs]
assert actual_states == expected_states, \
f"Expected state sequence {expected_states}, got {actual_states}"
assert actual_states == expected_states, f"Expected state sequence {expected_states}, got {actual_states}"
return logs
def assert_guard_passes(
guard: Callable,
instance: Any,
user: Any | None = None,
message: str = ""
) -> bool:
def assert_guard_passes(guard: Callable, instance: Any, user: Any | None = None, message: str = "") -> bool:
"""
Assert that a guard function passes.
@@ -210,11 +178,7 @@ def assert_guard_passes(
def assert_guard_fails(
guard: Callable,
instance: Any,
user: Any | None = None,
expected_error_code: str | None = None,
message: str = ""
guard: Callable, instance: Any, user: Any | None = None, expected_error_code: str | None = None, message: str = ""
) -> bool:
"""
Assert that a guard function fails.
@@ -239,19 +203,15 @@ def assert_guard_fails(
fail_message = message or f"Guard should fail but returned {result}"
assert result is False, fail_message
if expected_error_code and hasattr(guard, 'error_code'):
assert guard.error_code == expected_error_code, \
f"Expected error code {expected_error_code}, got {guard.error_code}"
if expected_error_code and hasattr(guard, "error_code"):
assert (
guard.error_code == expected_error_code
), f"Expected error code {expected_error_code}, got {guard.error_code}"
return True
def transition_and_save(
instance: Any,
transition_method: str,
user: Any | None = None,
**kwargs
) -> Any:
def transition_and_save(instance: Any, transition_method: str, user: Any | None = None, **kwargs) -> Any:
"""
Execute a transition and save the instance.
@@ -290,10 +250,10 @@ def get_available_transitions(instance: Any) -> list[str]:
"""
# Get the state field name from the instance
state_field = getattr(instance, 'state_field_name', 'status')
state_field = getattr(instance, "state_field_name", "status")
# Build the function name dynamically
func_name = f'get_available_{state_field}_transitions'
func_name = f"get_available_{state_field}_transitions"
if hasattr(instance, func_name):
get_transitions = getattr(instance, func_name)
return [t.name for t in get_transitions()]
@@ -301,19 +261,13 @@ def get_available_transitions(instance: Any) -> list[str]:
# Fallback: look for transition methods
transitions = []
for attr_name in dir(instance):
if attr_name.startswith('transition_to_'):
if attr_name.startswith("transition_to_"):
transitions.append(attr_name)
return transitions
def create_transition_context(
instance: Any,
from_state: str,
to_state: str,
user: Any | None = None,
**extra
) -> dict:
def create_transition_context(instance: Any, from_state: str, to_state: str, user: Any | None = None, **extra) -> dict:
"""
Create a mock transition context dictionary.
@@ -331,11 +285,11 @@ def create_transition_context(
context = create_transition_context(submission, 'PENDING', 'APPROVED', moderator)
"""
return {
'instance': instance,
'from_state': from_state,
'to_state': to_state,
'user': user,
'model_class': type(instance),
'transition_name': f'transition_to_{to_state.lower()}',
**extra
"instance": instance,
"from_state": from_state,
"to_state": to_state,
"user": user,
"model_class": type(instance),
"transition_name": f"transition_to_{to_state.lower()}",
**extra,
}

View File

@@ -1,4 +1,5 @@
"""Tests for StateTransitionBuilder."""
import pytest
from django.core.exceptions import ImproperlyConfigured

View File

@@ -19,12 +19,7 @@ class CallbackContext:
"""Mock context for testing callbacks."""
def __init__(
self,
instance: Any = None,
from_state: str = 'PENDING',
to_state: str = 'APPROVED',
user: Any = None,
**extra
self, instance: Any = None, from_state: str = "PENDING", to_state: str = "APPROVED", user: Any = None, **extra
):
self.instance = instance or Mock()
self.from_state = from_state
@@ -34,18 +29,18 @@ class CallbackContext:
def to_dict(self) -> dict[str, Any]:
return {
'instance': self.instance,
'from_state': self.from_state,
'to_state': self.to_state,
'user': self.user,
**self.extra
"instance": self.instance,
"from_state": self.from_state,
"to_state": self.to_state,
"user": self.user,
**self.extra,
}
class MockCallback:
"""Mock callback for testing."""
def __init__(self, name: str = 'callback', should_raise: bool = False):
def __init__(self, name: str = "callback", should_raise: bool = False):
self.name = name
self.calls: list[dict] = []
self.should_raise = should_raise
@@ -71,40 +66,40 @@ class PreTransitionCallbackTests(TestCase):
def test_pre_callback_executes_before_state_change(self):
"""Test that pre-transition callback executes before state changes."""
callback = MockCallback('pre_callback')
context = CallbackContext(from_state='PENDING', to_state='APPROVED')
callback = MockCallback("pre_callback")
context = CallbackContext(from_state="PENDING", to_state="APPROVED")
# Simulate pre-transition execution
callback(context.to_dict())
self.assertTrue(callback.was_called())
self.assertEqual(callback.calls[0]['from_state'], 'PENDING')
self.assertEqual(callback.calls[0]['to_state'], 'APPROVED')
self.assertEqual(callback.calls[0]["from_state"], "PENDING")
self.assertEqual(callback.calls[0]["to_state"], "APPROVED")
def test_pre_callback_receives_instance(self):
"""Test that pre-callback receives the model instance."""
mock_instance = Mock()
mock_instance.id = 123
mock_instance.status = 'PENDING'
mock_instance.status = "PENDING"
callback = MockCallback()
context = CallbackContext(instance=mock_instance)
callback(context.to_dict())
self.assertEqual(callback.calls[0]['instance'], mock_instance)
self.assertEqual(callback.calls[0]["instance"], mock_instance)
def test_pre_callback_receives_user(self):
"""Test that pre-callback receives the user performing transition."""
mock_user = Mock()
mock_user.username = 'moderator'
mock_user.username = "moderator"
callback = MockCallback()
context = CallbackContext(user=mock_user)
callback(context.to_dict())
self.assertEqual(callback.calls[0]['user'], mock_user)
self.assertEqual(callback.calls[0]["user"], mock_user)
def test_pre_callback_can_prevent_transition(self):
"""Test that pre-callback can prevent transition by raising exception."""
@@ -119,13 +114,13 @@ class PreTransitionCallbackTests(TestCase):
execution_order = []
def callback_1(ctx):
execution_order.append('first')
execution_order.append("first")
def callback_2(ctx):
execution_order.append('second')
execution_order.append("second")
def callback_3(ctx):
execution_order.append('third')
execution_order.append("third")
context = CallbackContext().to_dict()
@@ -134,7 +129,7 @@ class PreTransitionCallbackTests(TestCase):
callback_2(context)
callback_3(context)
self.assertEqual(execution_order, ['first', 'second', 'third'])
self.assertEqual(execution_order, ["first", "second", "third"])
class PostTransitionCallbackTests(TestCase):
@@ -142,28 +137,24 @@ class PostTransitionCallbackTests(TestCase):
def test_post_callback_executes_after_state_change(self):
"""Test that post-transition callback executes after state changes."""
callback = MockCallback('post_callback')
callback = MockCallback("post_callback")
# Simulate instance after transition
mock_instance = Mock()
mock_instance.status = 'APPROVED' # Already changed
mock_instance.status = "APPROVED" # Already changed
context = CallbackContext(
instance=mock_instance,
from_state='PENDING',
to_state='APPROVED'
)
context = CallbackContext(instance=mock_instance, from_state="PENDING", to_state="APPROVED")
callback(context.to_dict())
self.assertTrue(callback.was_called())
self.assertEqual(callback.calls[0]['instance'].status, 'APPROVED')
self.assertEqual(callback.calls[0]["instance"].status, "APPROVED")
def test_post_callback_receives_updated_instance(self):
"""Test that post-callback receives instance with new state."""
mock_instance = Mock()
mock_instance.status = 'APPROVED'
mock_instance.approved_at = '2025-01-15'
mock_instance.status = "APPROVED"
mock_instance.approved_at = "2025-01-15"
mock_instance.handled_by_id = 456
callback = MockCallback()
@@ -171,9 +162,9 @@ class PostTransitionCallbackTests(TestCase):
callback(context.to_dict())
instance = callback.calls[0]['instance']
self.assertEqual(instance.status, 'APPROVED')
self.assertEqual(instance.approved_at, '2025-01-15')
instance = callback.calls[0]["instance"]
self.assertEqual(instance.status, "APPROVED")
self.assertEqual(instance.approved_at, "2025-01-15")
def test_post_callback_failure_does_not_rollback(self):
"""Test that post-callback failures don't rollback the transition."""
@@ -193,13 +184,13 @@ class PostTransitionCallbackTests(TestCase):
execution_order = []
def notification_callback(ctx):
execution_order.append('notification')
execution_order.append("notification")
def cache_callback(ctx):
execution_order.append('cache')
execution_order.append("cache")
def analytics_callback(ctx):
execution_order.append('analytics')
execution_order.append("analytics")
context = CallbackContext().to_dict()
@@ -207,7 +198,7 @@ class PostTransitionCallbackTests(TestCase):
cache_callback(context)
analytics_callback(context)
self.assertEqual(execution_order, ['notification', 'cache', 'analytics'])
self.assertEqual(execution_order, ["notification", "cache", "analytics"])
class ErrorCallbackTests(TestCase):
@@ -221,17 +212,17 @@ class ErrorCallbackTests(TestCase):
raise ValueError("Transition failed")
except ValueError as e:
error_context = {
'instance': Mock(),
'from_state': 'PENDING',
'to_state': 'APPROVED',
'exception': e,
'exception_type': type(e).__name__
"instance": Mock(),
"from_state": "PENDING",
"to_state": "APPROVED",
"exception": e,
"exception_type": type(e).__name__,
}
error_callback(error_context)
self.assertTrue(error_callback.was_called())
self.assertIn('exception', error_callback.calls[0])
self.assertEqual(error_callback.calls[0]['exception_type'], 'ValueError')
self.assertIn("exception", error_callback.calls[0])
self.assertEqual(error_callback.calls[0]["exception_type"], "ValueError")
def test_error_callback_for_cleanup(self):
"""Test that error callbacks can perform cleanup."""
@@ -244,7 +235,7 @@ class ErrorCallbackTests(TestCase):
try:
raise ValueError("Transition failed")
except ValueError:
cleanup_callback({'exception': 'test'})
cleanup_callback({"exception": "test"})
self.assertTrue(cleanup_performed)
@@ -256,17 +247,17 @@ class ErrorCallbackTests(TestCase):
error_callback = MockCallback()
error_context = {
'instance': mock_instance,
'from_state': 'PENDING',
'to_state': 'APPROVED',
'user': mock_user,
'exception': ValueError("Test error")
"instance": mock_instance,
"from_state": "PENDING",
"to_state": "APPROVED",
"user": mock_user,
"exception": ValueError("Test error"),
}
error_callback(error_context)
self.assertEqual(error_callback.calls[0]['instance'], mock_instance)
self.assertEqual(error_callback.calls[0]['user'], mock_user)
self.assertEqual(error_callback.calls[0]["instance"], mock_instance)
self.assertEqual(error_callback.calls[0]["user"], mock_user)
class ConditionalCallbackTests(TestCase):
@@ -277,15 +268,15 @@ class ConditionalCallbackTests(TestCase):
execution_log = []
def approval_only_callback(ctx):
if ctx.get('to_state') == 'APPROVED':
execution_log.append('approved')
if ctx.get("to_state") == "APPROVED":
execution_log.append("approved")
# Transition to APPROVED - should execute
approval_only_callback({'to_state': 'APPROVED'})
approval_only_callback({"to_state": "APPROVED"})
self.assertEqual(len(execution_log), 1)
# Transition to REJECTED - should not execute
approval_only_callback({'to_state': 'REJECTED'})
approval_only_callback({"to_state": "REJECTED"})
self.assertEqual(len(execution_log), 1) # Still 1
def test_callback_with_transition_filter(self):
@@ -293,15 +284,15 @@ class ConditionalCallbackTests(TestCase):
execution_log = []
def escalation_callback(ctx):
if ctx.get('to_state') == 'ESCALATED':
execution_log.append('escalated')
if ctx.get("to_state") == "ESCALATED":
execution_log.append("escalated")
# Escalation - should execute
escalation_callback({'to_state': 'ESCALATED'})
escalation_callback({"to_state": "ESCALATED"})
self.assertEqual(len(execution_log), 1)
# Other transitions - should not execute
escalation_callback({'to_state': 'APPROVED'})
escalation_callback({"to_state": "APPROVED"})
self.assertEqual(len(execution_log), 1)
def test_callback_with_user_role_filter(self):
@@ -309,17 +300,17 @@ class ConditionalCallbackTests(TestCase):
admin_notifications = []
def admin_only_notification(ctx):
user = ctx.get('user')
if user and getattr(user, 'role', None) == 'ADMIN':
user = ctx.get("user")
if user and getattr(user, "role", None) == "ADMIN":
admin_notifications.append(ctx)
admin_user = Mock(role='ADMIN')
moderator_user = Mock(role='MODERATOR')
admin_user = Mock(role="ADMIN")
moderator_user = Mock(role="MODERATOR")
admin_only_notification({'user': admin_user})
admin_only_notification({"user": admin_user})
self.assertEqual(len(admin_notifications), 1)
admin_only_notification({'user': moderator_user})
admin_only_notification({"user": moderator_user})
self.assertEqual(len(admin_notifications), 1) # Still 1
@@ -331,29 +322,29 @@ class CallbackChainTests(TestCase):
results = []
callbacks = [
lambda ctx: results.append('a'),
lambda ctx: results.append('b'),
lambda ctx: results.append('c'),
lambda ctx: results.append("a"),
lambda ctx: results.append("b"),
lambda ctx: results.append("c"),
]
context = {}
for cb in callbacks:
cb(context)
self.assertEqual(results, ['a', 'b', 'c'])
self.assertEqual(results, ["a", "b", "c"])
def test_callback_chain_stops_on_failure(self):
"""Test that callback chain stops when a callback fails."""
results = []
def callback_a(ctx):
results.append('a')
results.append("a")
def callback_b(ctx):
raise ValueError("B failed")
def callback_c(ctx):
results.append('c')
results.append("c")
callbacks = [callback_a, callback_b, callback_c]
@@ -364,7 +355,7 @@ class CallbackChainTests(TestCase):
except ValueError:
break
self.assertEqual(results, ['a']) # c never executed
self.assertEqual(results, ["a"]) # c never executed
def test_callback_chain_with_continue_on_error(self):
"""Test callback chain that continues despite errors."""
@@ -372,13 +363,13 @@ class CallbackChainTests(TestCase):
errors = []
def callback_a(ctx):
results.append('a')
results.append("a")
def callback_b(ctx):
raise ValueError("B failed")
def callback_c(ctx):
results.append('c')
results.append("c")
callbacks = [callback_a, callback_b, callback_c]
@@ -389,7 +380,7 @@ class CallbackChainTests(TestCase):
except Exception as e:
errors.append(str(e))
self.assertEqual(results, ['a', 'c'])
self.assertEqual(results, ["a", "c"])
self.assertEqual(len(errors), 1)
@@ -399,36 +390,30 @@ class CallbackContextEnrichmentTests(TestCase):
def test_context_includes_model_class(self):
"""Test that context includes the model class."""
mock_instance = Mock()
mock_instance.__class__.__name__ = 'EditSubmission'
mock_instance.__class__.__name__ = "EditSubmission"
context = {
'instance': mock_instance,
'model_class': type(mock_instance)
}
context = {"instance": mock_instance, "model_class": type(mock_instance)}
self.assertIn('model_class', context)
self.assertIn("model_class", context)
def test_context_includes_transition_name(self):
"""Test that context includes the transition method name."""
context = {
'instance': Mock(),
'from_state': 'PENDING',
'to_state': 'APPROVED',
'transition_name': 'transition_to_approved'
"instance": Mock(),
"from_state": "PENDING",
"to_state": "APPROVED",
"transition_name": "transition_to_approved",
}
self.assertEqual(context['transition_name'], 'transition_to_approved')
self.assertEqual(context["transition_name"], "transition_to_approved")
def test_context_includes_timestamp(self):
"""Test that context includes transition timestamp."""
from django.utils import timezone
context = {
'instance': Mock(),
'timestamp': timezone.now()
}
context = {"instance": Mock(), "timestamp": timezone.now()}
self.assertIn('timestamp', context)
self.assertIn("timestamp", context)
# ============================================================================
@@ -446,9 +431,9 @@ class NotificationCallbackTests(TestCase):
def _create_transition_context(
self,
model_name: str = 'EditSubmission',
source_state: str = 'PENDING',
target_state: str = 'APPROVED',
model_name: str = "EditSubmission",
source_state: str = "PENDING",
target_state: str = "APPROVED",
user=None,
instance=None,
):
@@ -465,18 +450,18 @@ class NotificationCallbackTests(TestCase):
if user is None:
user = Mock()
user.pk = 1
user.username = 'moderator'
user.username = "moderator"
return TransitionContext(
instance=instance,
field_name='status',
field_name="status",
source_state=source_state,
target_state=target_state,
user=user,
timestamp=timezone.now(),
)
@patch('apps.core.state_machine.callbacks.notifications.NotificationService')
@patch("apps.core.state_machine.callbacks.notifications.NotificationService")
def test_notification_callback_approval_title(self, mock_service_class):
"""Test NotificationCallback generates correct title for approvals."""
from ..callbacks.notifications import NotificationCallback
@@ -487,8 +472,8 @@ class NotificationCallbackTests(TestCase):
callback = NotificationCallback()
context = self._create_transition_context(
source_state='PENDING',
target_state='APPROVED',
source_state="PENDING",
target_state="APPROVED",
)
callback.execute(context)
@@ -496,9 +481,9 @@ class NotificationCallbackTests(TestCase):
# Check that notification was sent with correct title
if mock_service.send_notification.called:
call_args = mock_service.send_notification.call_args
self.assertIn('approved', call_args[1].get('title', '').lower())
self.assertIn("approved", call_args[1].get("title", "").lower())
@patch('apps.core.state_machine.callbacks.notifications.NotificationService')
@patch("apps.core.state_machine.callbacks.notifications.NotificationService")
def test_notification_callback_rejection_title(self, mock_service_class):
"""Test NotificationCallback generates correct title for rejections."""
from ..callbacks.notifications import NotificationCallback
@@ -509,17 +494,17 @@ class NotificationCallbackTests(TestCase):
callback = NotificationCallback()
context = self._create_transition_context(
source_state='PENDING',
target_state='REJECTED',
source_state="PENDING",
target_state="REJECTED",
)
callback.execute(context)
if mock_service.send_notification.called:
call_args = mock_service.send_notification.call_args
self.assertIn('rejected', call_args[1].get('title', '').lower())
self.assertIn("rejected", call_args[1].get("title", "").lower())
@patch('apps.core.state_machine.callbacks.notifications.NotificationService')
@patch("apps.core.state_machine.callbacks.notifications.NotificationService")
def test_moderation_notification_recipient_selection(self, mock_service_class):
"""Test ModerationNotificationCallback sends to correct recipient."""
from ..callbacks.notifications import ModerationNotificationCallback
@@ -530,16 +515,16 @@ class NotificationCallbackTests(TestCase):
submitter = Mock()
submitter.pk = 999
submitter.username = 'submitter'
submitter.username = "submitter"
instance = Mock()
instance.pk = 123
instance.__class__.__name__ = 'EditSubmission'
instance.__class__.__name__ = "EditSubmission"
instance.user = submitter # The submitter who should receive notification
callback = ModerationNotificationCallback()
context = self._create_transition_context(
target_state='APPROVED',
target_state="APPROVED",
instance=instance,
)
@@ -548,10 +533,10 @@ class NotificationCallbackTests(TestCase):
if mock_service.send_notification.called:
call_args = mock_service.send_notification.call_args
# Should notify the submitter about their submission
recipient = call_args[1].get('user') or call_args[0][0] if call_args[0] else None
recipient = call_args[1].get("user") or call_args[0][0] if call_args[0] else None
self.assertIsNotNone(recipient)
@patch('apps.core.state_machine.callbacks.notifications.NotificationService')
@patch("apps.core.state_machine.callbacks.notifications.NotificationService")
def test_notification_callback_handles_service_error(self, mock_service_class):
"""Test NotificationCallback handles service errors gracefully."""
from ..callbacks.notifications import NotificationCallback
@@ -568,7 +553,7 @@ class NotificationCallbackTests(TestCase):
# Callback may return False on error but should not raise
self.assertIsNotNone(result)
@patch('apps.core.state_machine.callbacks.notifications.NotificationService')
@patch("apps.core.state_machine.callbacks.notifications.NotificationService")
def test_notification_callback_message_includes_model_info(self, mock_service_class):
"""Test notification message includes model information."""
from ..callbacks.notifications import NotificationCallback
@@ -578,13 +563,13 @@ class NotificationCallbackTests(TestCase):
mock_service_class.return_value = mock_service
callback = NotificationCallback()
context = self._create_transition_context(model_name='PhotoSubmission')
context = self._create_transition_context(model_name="PhotoSubmission")
callback.execute(context)
if mock_service.send_notification.called:
call_args = mock_service.send_notification.call_args
message = call_args[1].get('message', '')
message = call_args[1].get("message", "")
# Should reference the submission type or model
self.assertIsInstance(message, str)
@@ -599,10 +584,10 @@ class CacheCallbackTests(TestCase):
def _create_transition_context(
self,
model_name: str = 'Park',
model_name: str = "Park",
instance_id: int = 123,
source_state: str = 'OPERATING',
target_state: str = 'CLOSED_TEMP',
source_state: str = "OPERATING",
target_state: str = "CLOSED_TEMP",
):
"""Helper to create a TransitionContext."""
from django.utils import timezone
@@ -615,14 +600,14 @@ class CacheCallbackTests(TestCase):
return TransitionContext(
instance=instance,
field_name='status',
field_name="status",
source_state=source_state,
target_state=target_state,
user=Mock(),
timestamp=timezone.now(),
)
@patch('apps.core.state_machine.callbacks.cache.CacheInvalidationCallback._get_cache_service')
@patch("apps.core.state_machine.callbacks.cache.CacheInvalidationCallback._get_cache_service")
def test_cache_callback_invalidates_model_patterns(self, mock_get_service):
"""Test CacheInvalidationCallback invalidates correct patterns."""
from ..callbacks.cache import CacheInvalidationCallback
@@ -631,9 +616,7 @@ class CacheCallbackTests(TestCase):
mock_cache.invalidate_pattern = Mock()
mock_get_service.return_value = mock_cache
callback = CacheInvalidationCallback(
patterns=['*park:123*', '*parks*']
)
callback = CacheInvalidationCallback(patterns=["*park:123*", "*parks*"])
context = self._create_transition_context()
callback.execute(context)
@@ -641,7 +624,7 @@ class CacheCallbackTests(TestCase):
# Should have called invalidate_pattern for each pattern
self.assertTrue(mock_cache.invalidate_pattern.called)
@patch('apps.core.state_machine.callbacks.cache.CacheInvalidationCallback._get_cache_service')
@patch("apps.core.state_machine.callbacks.cache.CacheInvalidationCallback._get_cache_service")
def test_cache_callback_generates_instance_patterns(self, mock_get_service):
"""Test CacheInvalidationCallback generates instance-specific patterns."""
from ..callbacks.cache import CacheInvalidationCallback
@@ -651,30 +634,25 @@ class CacheCallbackTests(TestCase):
mock_get_service.return_value = mock_cache
callback = CacheInvalidationCallback(include_instance_patterns=True)
context = self._create_transition_context(
model_name='Park',
instance_id=456
)
context = self._create_transition_context(model_name="Park", instance_id=456)
callback.execute(context)
# Should have called invalidate_pattern with instance-specific patterns
self.assertTrue(mock_cache.invalidate_pattern.called)
patterns_called = [
call[0][0] for call in mock_cache.invalidate_pattern.call_args_list
]
patterns_called = [call[0][0] for call in mock_cache.invalidate_pattern.call_args_list]
# Should include patterns containing the instance ID
has_instance_pattern = any('456' in p for p in patterns_called)
has_instance_pattern = any("456" in p for p in patterns_called)
self.assertTrue(has_instance_pattern, f"No pattern with instance ID in {patterns_called}")
@patch('apps.core.state_machine.callbacks.cache.CacheInvalidationCallback._get_cache_service')
@patch("apps.core.state_machine.callbacks.cache.CacheInvalidationCallback._get_cache_service")
def test_cache_callback_handles_service_unavailable(self, mock_get_service):
"""Test CacheInvalidationCallback handles unavailable cache service."""
from ..callbacks.cache import CacheInvalidationCallback
mock_get_service.return_value = None
callback = CacheInvalidationCallback(patterns=['*test*'])
callback = CacheInvalidationCallback(patterns=["*test*"])
context = self._create_transition_context()
# Should not raise, uses fallback
@@ -682,7 +660,7 @@ class CacheCallbackTests(TestCase):
# Should return True (fallback succeeds)
self.assertTrue(result)
@patch('apps.core.state_machine.callbacks.cache.CacheInvalidationCallback._get_cache_service')
@patch("apps.core.state_machine.callbacks.cache.CacheInvalidationCallback._get_cache_service")
def test_cache_callback_continues_on_pattern_error(self, mock_get_service):
"""Test CacheInvalidationCallback continues if individual pattern fails."""
from ..callbacks.cache import CacheInvalidationCallback
@@ -693,16 +671,13 @@ class CacheCallbackTests(TestCase):
def invalidate_side_effect(pattern):
nonlocal call_count
call_count += 1
if 'bad' in pattern:
if "bad" in pattern:
raise Exception("Pattern invalid")
mock_cache.invalidate_pattern = Mock(side_effect=invalidate_side_effect)
mock_get_service.return_value = mock_cache
callback = CacheInvalidationCallback(
patterns=['good:*', 'bad:*', 'another:*'],
include_instance_patterns=False
)
callback = CacheInvalidationCallback(patterns=["good:*", "bad:*", "another:*"], include_instance_patterns=False)
context = self._create_transition_context()
# Should not raise overall
@@ -716,7 +691,7 @@ class ModelCacheInvalidationTests(TestCase):
def _create_transition_context(
self,
model_name: str = 'Ride',
model_name: str = "Ride",
instance_id: int = 789,
):
from django.utils import timezone
@@ -728,20 +703,20 @@ class ModelCacheInvalidationTests(TestCase):
instance.__class__.__name__ = model_name
# Add park reference for rides
if model_name == 'Ride':
if model_name == "Ride":
instance.park = Mock()
instance.park.pk = 111
return TransitionContext(
instance=instance,
field_name='status',
source_state='OPERATING',
target_state='CLOSED_TEMP',
field_name="status",
source_state="OPERATING",
target_state="CLOSED_TEMP",
user=Mock(),
timestamp=timezone.now(),
)
@patch('apps.core.state_machine.callbacks.cache.CacheInvalidationCallback._get_cache_service')
@patch("apps.core.state_machine.callbacks.cache.CacheInvalidationCallback._get_cache_service")
def test_ride_cache_includes_park_patterns(self, mock_get_service):
"""Test RideCacheInvalidation includes parent park patterns."""
from ..callbacks.cache import RideCacheInvalidation
@@ -755,12 +730,10 @@ class ModelCacheInvalidationTests(TestCase):
callback.execute(context)
patterns_called = [
call[0][0] for call in mock_cache.invalidate_pattern.call_args_list
]
patterns_called = [call[0][0] for call in mock_cache.invalidate_pattern.call_args_list]
# Should include park patterns (parent park ID is 111)
has_park_pattern = any('park' in p.lower() for p in patterns_called)
has_park_pattern = any("park" in p.lower() for p in patterns_called)
self.assertTrue(has_park_pattern, f"No park pattern in {patterns_called}")
@@ -775,17 +748,18 @@ class RelatedUpdateCallbackTests(TestCase):
def setUp(self):
"""Set up test fixtures."""
from django.contrib.auth import get_user_model
get_user_model()
self.user = Mock()
self.user.pk = 1
self.user.username = 'testuser'
self.user.username = "testuser"
def _create_transition_context(
self,
model_name: str = 'Ride',
model_name: str = "Ride",
instance=None,
target_state: str = 'OPERATING',
target_state: str = "OPERATING",
):
from django.utils import timezone
@@ -798,8 +772,8 @@ class RelatedUpdateCallbackTests(TestCase):
return TransitionContext(
instance=instance,
field_name='status',
source_state='UNDER_CONSTRUCTION',
field_name="status",
source_state="UNDER_CONSTRUCTION",
target_state=target_state,
user=self.user,
timestamp=timezone.now(),
@@ -819,7 +793,7 @@ class RelatedUpdateCallbackTests(TestCase):
# Create mock ride that belongs to park
mock_ride = Mock()
mock_ride.pk = 200
mock_ride.__class__.__name__ = 'Ride'
mock_ride.__class__.__name__ = "Ride"
mock_ride.park = mock_park
mock_ride.is_coaster = True
@@ -830,9 +804,9 @@ class RelatedUpdateCallbackTests(TestCase):
callback = ParkCountUpdateCallback()
context = self._create_transition_context(
model_name='Ride',
model_name="Ride",
instance=mock_ride,
target_state='OPERATING',
target_state="OPERATING",
)
# Execute callback
@@ -847,12 +821,12 @@ class RelatedUpdateCallbackTests(TestCase):
mock_ride = Mock()
mock_ride.pk = 200
mock_ride.__class__.__name__ = 'Ride'
mock_ride.__class__.__name__ = "Ride"
mock_ride.park = None # No park
callback = ParkCountUpdateCallback()
context = self._create_transition_context(
model_name='Ride',
model_name="Ride",
instance=mock_ride,
)
@@ -872,15 +846,15 @@ class RelatedUpdateCallbackTests(TestCase):
mock_ride = Mock()
mock_ride.pk = 200
mock_ride.__class__.__name__ = 'Ride'
mock_ride.__class__.__name__ = "Ride"
mock_ride.park = mock_park
mock_ride.is_coaster = False
callback = ParkCountUpdateCallback()
context = self._create_transition_context(
model_name='Ride',
model_name="Ride",
instance=mock_ride,
target_state='OPERATING',
target_state="OPERATING",
)
callback.execute(context)
@@ -901,15 +875,15 @@ class RelatedUpdateCallbackTests(TestCase):
mock_ride = Mock()
mock_ride.pk = 200
mock_ride.__class__.__name__ = 'Ride'
mock_ride.__class__.__name__ = "Ride"
mock_ride.park = mock_park
mock_ride.is_coaster = True
callback = ParkCountUpdateCallback()
context = self._create_transition_context(
model_name='Ride',
model_name="Ride",
instance=mock_ride,
target_state='CLOSED_PERM',
target_state="CLOSED_PERM",
)
result = callback.execute(context)
@@ -931,18 +905,18 @@ class CallbackErrorHandlingTests(TestCase):
instance = Mock()
instance.pk = 1
instance.__class__.__name__ = 'EditSubmission'
instance.__class__.__name__ = "EditSubmission"
return TransitionContext(
instance=instance,
field_name='status',
source_state='PENDING',
target_state='APPROVED',
field_name="status",
source_state="PENDING",
target_state="APPROVED",
user=Mock(),
timestamp=timezone.now(),
)
@patch('apps.core.state_machine.callbacks.notifications.NotificationService')
@patch("apps.core.state_machine.callbacks.notifications.NotificationService")
def test_notification_callback_logs_error_on_failure(self, mock_service_class):
"""Test NotificationCallback logs errors when service fails."""
import logging
@@ -957,7 +931,7 @@ class CallbackErrorHandlingTests(TestCase):
context = self._create_transition_context()
with self.assertLogs(level=logging.WARNING):
try:
try: # noqa: SIM105
callback.execute(context)
except Exception:
pass # May or may not raise depending on implementation
@@ -965,7 +939,7 @@ class CallbackErrorHandlingTests(TestCase):
# Should have logged something about the error
# (Logging behavior depends on implementation)
@patch('apps.core.state_machine.callbacks.cache.CacheInvalidationCallback._get_cache_service')
@patch("apps.core.state_machine.callbacks.cache.CacheInvalidationCallback._get_cache_service")
def test_cache_callback_returns_false_on_total_failure(self, mock_get_service):
"""Test CacheInvalidationCallback returns False on complete failure."""
from ..callbacks.cache import CacheInvalidationCallback
@@ -974,10 +948,7 @@ class CallbackErrorHandlingTests(TestCase):
mock_cache.invalidate_pattern = Mock(side_effect=Exception("Cache error"))
mock_get_service.return_value = mock_cache
callback = CacheInvalidationCallback(
patterns=['*test*'],
include_instance_patterns=False
)
callback = CacheInvalidationCallback(patterns=["*test*"], include_instance_patterns=False)
context = self._create_transition_context()
result = callback.execute(context)
@@ -993,18 +964,18 @@ class CallbackErrorHandlingTests(TestCase):
instance = Mock()
instance.pk = 1
instance.__class__.__name__ = 'EditSubmission'
instance.__class__.__name__ = "EditSubmission"
context = TransitionContext(
instance=instance,
field_name='status',
source_state='PENDING',
target_state='APPROVED',
field_name="status",
source_state="PENDING",
target_state="APPROVED",
user=None, # No user
timestamp=timezone.now(),
)
with patch('apps.core.state_machine.callbacks.notifications.NotificationService'):
with patch("apps.core.state_machine.callbacks.notifications.NotificationService"):
callback = NotificationCallback()
# Should not raise with None user
try:

View File

@@ -1,4 +1,5 @@
"""Tests for transition decorator generation."""
from unittest.mock import Mock
from apps.core.state_machine.decorators import (
@@ -11,9 +12,7 @@ from apps.core.state_machine.decorators import (
def test_generate_transition_decorator():
"""Test basic transition decorator generation."""
decorator = generate_transition_decorator(
source="pending", target="approved", field_name="status"
)
decorator = generate_transition_decorator(source="pending", target="approved", field_name="status")
assert callable(decorator)
@@ -72,9 +71,7 @@ def test_create_transition_method_with_callbacks():
def test_factory_create_approve_method():
"""Test approval method creation."""
factory = TransitionMethodFactory()
method = factory.create_approve_method(
source="pending", target="approved", field_name="status"
)
method = factory.create_approve_method(source="pending", target="approved", field_name="status")
assert callable(method)
assert method.__name__ == "approve"
@@ -82,9 +79,7 @@ def test_factory_create_approve_method():
def test_factory_create_reject_method():
"""Test rejection method creation."""
factory = TransitionMethodFactory()
method = factory.create_reject_method(
source="pending", target="rejected", field_name="status"
)
method = factory.create_reject_method(source="pending", target="rejected", field_name="status")
assert callable(method)
assert method.__name__ == "reject"
@@ -92,9 +87,7 @@ def test_factory_create_reject_method():
def test_factory_create_escalate_method():
"""Test escalation method creation."""
factory = TransitionMethodFactory()
method = factory.create_escalate_method(
source="pending", target="escalated", field_name="status"
)
method = factory.create_escalate_method(source="pending", target="escalated", field_name="status")
assert callable(method)
assert method.__name__ == "escalate"
@@ -145,16 +138,14 @@ def test_with_transition_logging():
def test_method_signature_generation():
"""Test that generated methods have proper signatures."""
factory = TransitionMethodFactory()
method = factory.create_approve_method(
source="pending", target="approved"
)
method = factory.create_approve_method(source="pending", target="approved")
# Check method accepts expected parameters
mock_instance = Mock()
mock_user = Mock()
# Should not raise
try:
try: # noqa: SIM105
method(mock_instance, user=mock_user, comment="test")
except Exception:
# May fail due to django-fsm not being fully configured

View File

@@ -58,28 +58,16 @@ class PermissionGuardTests(TestCase):
def setUp(self):
"""Set up test fixtures."""
self.regular_user = User.objects.create_user(
username='user',
email='user@example.com',
password='testpass123',
role='USER'
username="user", email="user@example.com", password="testpass123", role="USER"
)
self.moderator = User.objects.create_user(
username='moderator',
email='moderator@example.com',
password='testpass123',
role='MODERATOR'
username="moderator", email="moderator@example.com", password="testpass123", role="MODERATOR"
)
self.admin = User.objects.create_user(
username='admin',
email='admin@example.com',
password='testpass123',
role='ADMIN'
username="admin", email="admin@example.com", password="testpass123", role="ADMIN"
)
self.superuser = User.objects.create_user(
username='superuser',
email='superuser@example.com',
password='testpass123',
role='SUPERUSER'
username="superuser", email="superuser@example.com", password="testpass123", role="SUPERUSER"
)
self.instance = MockInstance()
@@ -168,7 +156,7 @@ class PermissionGuardTests(TestCase):
def test_required_roles_explicit_list(self):
"""Test using explicit required_roles list."""
guard = PermissionGuard(required_roles=['ADMIN', 'SUPERUSER'])
guard = PermissionGuard(required_roles=["ADMIN", "SUPERUSER"])
self.assertTrue(guard(self.instance, user=self.admin))
self.assertTrue(guard(self.instance, user=self.superuser))
@@ -177,8 +165,9 @@ class PermissionGuardTests(TestCase):
def test_custom_check_passes(self):
"""Test custom check function that passes."""
def custom_check(instance, user):
return hasattr(instance, 'allow_access') and instance.allow_access
return hasattr(instance, "allow_access") and instance.allow_access
guard = PermissionGuard(custom_check=custom_check)
instance = MockInstance(allow_access=True)
@@ -189,8 +178,9 @@ class PermissionGuardTests(TestCase):
def test_custom_check_fails(self):
"""Test custom check function that fails."""
def custom_check(instance, user):
return hasattr(instance, 'allow_access') and instance.allow_access
return hasattr(instance, "allow_access") and instance.allow_access
guard = PermissionGuard(custom_check=custom_check)
instance = MockInstance(allow_access=False)
@@ -237,28 +227,16 @@ class OwnershipGuardTests(TestCase):
def setUp(self):
"""Set up test fixtures."""
self.owner = User.objects.create_user(
username='owner',
email='owner@example.com',
password='testpass123',
role='USER'
username="owner", email="owner@example.com", password="testpass123", role="USER"
)
self.other_user = User.objects.create_user(
username='other',
email='other@example.com',
password='testpass123',
role='USER'
username="other", email="other@example.com", password="testpass123", role="USER"
)
self.moderator = User.objects.create_user(
username='moderator',
email='moderator@example.com',
password='testpass123',
role='MODERATOR'
username="moderator", email="moderator@example.com", password="testpass123", role="MODERATOR"
)
self.admin = User.objects.create_user(
username='admin',
email='admin@example.com',
password='testpass123',
role='ADMIN'
username="admin", email="admin@example.com", password="testpass123", role="ADMIN"
)
def test_no_user_fails(self):
@@ -329,7 +307,7 @@ class OwnershipGuardTests(TestCase):
def test_custom_owner_fields(self):
"""Test custom owner field names."""
instance = MockInstance(author=self.owner)
guard = OwnershipGuard(owner_fields=['author'])
guard = OwnershipGuard(owner_fields=["author"])
result = guard(instance, user=self.owner)
@@ -357,22 +335,13 @@ class AssignmentGuardTests(TestCase):
def setUp(self):
"""Set up test fixtures."""
self.assigned_user = User.objects.create_user(
username='assigned',
email='assigned@example.com',
password='testpass123',
role='MODERATOR'
username="assigned", email="assigned@example.com", password="testpass123", role="MODERATOR"
)
self.other_user = User.objects.create_user(
username='other',
email='other@example.com',
password='testpass123',
role='MODERATOR'
username="other", email="other@example.com", password="testpass123", role="MODERATOR"
)
self.admin = User.objects.create_user(
username='admin',
email='admin@example.com',
password='testpass123',
role='ADMIN'
username="admin", email="admin@example.com", password="testpass123", role="ADMIN"
)
def test_no_user_fails(self):
@@ -426,7 +395,7 @@ class AssignmentGuardTests(TestCase):
def test_custom_assignment_fields(self):
"""Test custom assignment field names."""
instance = MockInstance(reviewer=self.assigned_user)
guard = AssignmentGuard(assignment_fields=['reviewer'])
guard = AssignmentGuard(assignment_fields=["reviewer"])
result = guard(instance, user=self.assigned_user)
@@ -439,7 +408,7 @@ class AssignmentGuardTests(TestCase):
guard(instance, user=self.assigned_user)
self.assertIn('assigned', guard.get_error_message().lower())
self.assertIn("assigned", guard.get_error_message().lower())
# ============================================================================
@@ -453,16 +422,13 @@ class StateGuardTests(TestCase):
def setUp(self):
"""Set up test fixtures."""
self.user = User.objects.create_user(
username='user',
email='user@example.com',
password='testpass123',
role='USER'
username="user", email="user@example.com", password="testpass123", role="USER"
)
def test_allowed_states_passes(self):
"""Test that guard passes when in allowed state."""
instance = MockInstance(status='PENDING')
guard = StateGuard(allowed_states=['PENDING', 'UNDER_REVIEW'])
instance = MockInstance(status="PENDING")
guard = StateGuard(allowed_states=["PENDING", "UNDER_REVIEW"])
result = guard(instance, user=self.user)
@@ -470,8 +436,8 @@ class StateGuardTests(TestCase):
def test_allowed_states_fails(self):
"""Test that guard fails when not in allowed state."""
instance = MockInstance(status='COMPLETED')
guard = StateGuard(allowed_states=['PENDING', 'UNDER_REVIEW'])
instance = MockInstance(status="COMPLETED")
guard = StateGuard(allowed_states=["PENDING", "UNDER_REVIEW"])
result = guard(instance, user=self.user)
@@ -480,8 +446,8 @@ class StateGuardTests(TestCase):
def test_blocked_states_passes(self):
"""Test that guard passes when not in blocked state."""
instance = MockInstance(status='PENDING')
guard = StateGuard(blocked_states=['COMPLETED', 'CANCELLED'])
instance = MockInstance(status="PENDING")
guard = StateGuard(blocked_states=["COMPLETED", "CANCELLED"])
result = guard(instance, user=self.user)
@@ -489,8 +455,8 @@ class StateGuardTests(TestCase):
def test_blocked_states_fails(self):
"""Test that guard fails when in blocked state."""
instance = MockInstance(status='COMPLETED')
guard = StateGuard(blocked_states=['COMPLETED', 'CANCELLED'])
instance = MockInstance(status="COMPLETED")
guard = StateGuard(blocked_states=["COMPLETED", "CANCELLED"])
result = guard(instance, user=self.user)
@@ -499,8 +465,8 @@ class StateGuardTests(TestCase):
def test_custom_state_field(self):
"""Test using custom state field name."""
instance = MockInstance(workflow_status='ACTIVE')
guard = StateGuard(allowed_states=['ACTIVE'], state_field='workflow_status')
instance = MockInstance(workflow_status="ACTIVE")
guard = StateGuard(allowed_states=["ACTIVE"], state_field="workflow_status")
result = guard(instance, user=self.user)
@@ -508,14 +474,14 @@ class StateGuardTests(TestCase):
def test_error_message_includes_states(self):
"""Test that error message includes allowed states."""
instance = MockInstance(status='COMPLETED')
guard = StateGuard(allowed_states=['PENDING', 'UNDER_REVIEW'])
instance = MockInstance(status="COMPLETED")
guard = StateGuard(allowed_states=["PENDING", "UNDER_REVIEW"])
guard(instance, user=self.user)
message = guard.get_error_message()
self.assertIn('PENDING', message)
self.assertIn('UNDER_REVIEW', message)
self.assertIn("PENDING", message)
self.assertIn("UNDER_REVIEW", message)
# ============================================================================
@@ -529,16 +495,13 @@ class MetadataGuardTests(TestCase):
def setUp(self):
"""Set up test fixtures."""
self.user = User.objects.create_user(
username='user',
email='user@example.com',
password='testpass123',
role='USER'
username="user", email="user@example.com", password="testpass123", role="USER"
)
def test_required_fields_present(self):
"""Test that guard passes when required fields are present."""
instance = MockInstance(resolution_notes='Fixed', assigned_to='user')
guard = MetadataGuard(required_fields=['resolution_notes', 'assigned_to'])
instance = MockInstance(resolution_notes="Fixed", assigned_to="user")
guard = MetadataGuard(required_fields=["resolution_notes", "assigned_to"])
result = guard(instance, user=self.user)
@@ -546,8 +509,8 @@ class MetadataGuardTests(TestCase):
def test_required_field_missing(self):
"""Test that guard fails when required field is missing."""
instance = MockInstance(resolution_notes='Fixed')
guard = MetadataGuard(required_fields=['resolution_notes', 'assigned_to'])
instance = MockInstance(resolution_notes="Fixed")
guard = MetadataGuard(required_fields=["resolution_notes", "assigned_to"])
result = guard(instance, user=self.user)
@@ -557,7 +520,7 @@ class MetadataGuardTests(TestCase):
def test_required_field_none(self):
"""Test that guard fails when required field is None."""
instance = MockInstance(resolution_notes=None)
guard = MetadataGuard(required_fields=['resolution_notes'])
guard = MetadataGuard(required_fields=["resolution_notes"])
result = guard(instance, user=self.user)
@@ -566,8 +529,8 @@ class MetadataGuardTests(TestCase):
def test_empty_string_fails_check_not_empty(self):
"""Test that empty string fails when check_not_empty is True."""
instance = MockInstance(resolution_notes=' ')
guard = MetadataGuard(required_fields=['resolution_notes'], check_not_empty=True)
instance = MockInstance(resolution_notes=" ")
guard = MetadataGuard(required_fields=["resolution_notes"], check_not_empty=True)
result = guard(instance, user=self.user)
@@ -577,7 +540,7 @@ class MetadataGuardTests(TestCase):
def test_empty_list_fails_check_not_empty(self):
"""Test that empty list fails when check_not_empty is True."""
instance = MockInstance(tags=[])
guard = MetadataGuard(required_fields=['tags'], check_not_empty=True)
guard = MetadataGuard(required_fields=["tags"], check_not_empty=True)
result = guard(instance, user=self.user)
@@ -587,7 +550,7 @@ class MetadataGuardTests(TestCase):
def test_empty_dict_fails_check_not_empty(self):
"""Test that empty dict fails when check_not_empty is True."""
instance = MockInstance(metadata={})
guard = MetadataGuard(required_fields=['metadata'], check_not_empty=True)
guard = MetadataGuard(required_fields=["metadata"], check_not_empty=True)
result = guard(instance, user=self.user)
@@ -597,12 +560,12 @@ class MetadataGuardTests(TestCase):
def test_error_message_includes_field_name(self):
"""Test that error message includes the field name."""
instance = MockInstance(resolution_notes=None)
guard = MetadataGuard(required_fields=['resolution_notes'])
guard = MetadataGuard(required_fields=["resolution_notes"])
guard(instance, user=self.user)
message = guard.get_error_message()
self.assertIn('Resolution Notes', message)
self.assertIn("Resolution Notes", message)
# ============================================================================
@@ -616,32 +579,23 @@ class CompositeGuardTests(TestCase):
def setUp(self):
"""Set up test fixtures."""
self.owner = User.objects.create_user(
username='owner',
email='owner@example.com',
password='testpass123',
role='USER'
username="owner", email="owner@example.com", password="testpass123", role="USER"
)
self.moderator = User.objects.create_user(
username='moderator',
email='moderator@example.com',
password='testpass123',
role='MODERATOR'
username="moderator", email="moderator@example.com", password="testpass123", role="MODERATOR"
)
self.non_owner_moderator = User.objects.create_user(
username='non_owner_moderator',
email='non_owner_moderator@example.com',
password='testpass123',
role='MODERATOR'
username="non_owner_moderator",
email="non_owner_moderator@example.com",
password="testpass123",
role="MODERATOR",
)
def test_and_operator_all_pass(self):
"""Test AND operator when all guards pass."""
instance = MockInstance(created_by=self.moderator)
guards = [
PermissionGuard(requires_moderator=True),
OwnershipGuard()
]
composite = CompositeGuard(guards, operator='AND')
guards = [PermissionGuard(requires_moderator=True), OwnershipGuard()]
composite = CompositeGuard(guards, operator="AND")
result = composite(instance, user=self.moderator)
@@ -652,9 +606,9 @@ class CompositeGuardTests(TestCase):
instance = MockInstance(created_by=self.owner)
guards = [
PermissionGuard(requires_moderator=True), # Will pass for moderator
OwnershipGuard() # Will fail - moderator is not owner
OwnershipGuard(), # Will fail - moderator is not owner
]
composite = CompositeGuard(guards, operator='AND')
composite = CompositeGuard(guards, operator="AND")
result = composite(instance, user=self.non_owner_moderator)
@@ -666,9 +620,9 @@ class CompositeGuardTests(TestCase):
instance = MockInstance(created_by=self.owner)
guards = [
PermissionGuard(requires_moderator=True), # Will fail for owner
OwnershipGuard() # Will pass - user is owner
OwnershipGuard(), # Will pass - user is owner
]
composite = CompositeGuard(guards, operator='OR')
composite = CompositeGuard(guards, operator="OR")
result = composite(instance, user=self.owner)
@@ -677,11 +631,8 @@ class CompositeGuardTests(TestCase):
def test_or_operator_all_fail(self):
"""Test OR operator when all guards fail."""
instance = MockInstance(created_by=self.moderator)
guards = [
PermissionGuard(requires_admin=True), # Regular user fails
OwnershipGuard() # Not the owner fails
]
composite = CompositeGuard(guards, operator='OR')
guards = [PermissionGuard(requires_admin=True), OwnershipGuard()] # Regular user fails # Not the owner fails
composite = CompositeGuard(guards, operator="OR")
result = composite(instance, user=self.owner)
@@ -690,19 +641,13 @@ class CompositeGuardTests(TestCase):
def test_nested_composite_guards(self):
"""Test nested composite guards."""
instance = MockInstance(created_by=self.moderator, status='PENDING')
instance = MockInstance(created_by=self.moderator, status="PENDING")
# Inner composite: moderator OR owner
inner = CompositeGuard([
PermissionGuard(requires_moderator=True),
OwnershipGuard()
], operator='OR')
inner = CompositeGuard([PermissionGuard(requires_moderator=True), OwnershipGuard()], operator="OR")
# Outer composite: (moderator OR owner) AND valid state
outer = CompositeGuard([
inner,
StateGuard(allowed_states=['PENDING'])
], operator='AND')
outer = CompositeGuard([inner, StateGuard(allowed_states=["PENDING"])], operator="AND")
result = outer(instance, user=self.moderator)
@@ -713,12 +658,12 @@ class CompositeGuardTests(TestCase):
instance = MockInstance(created_by=self.owner)
perm_guard = PermissionGuard(requires_admin=True)
guards = [perm_guard]
composite = CompositeGuard(guards, operator='AND')
composite = CompositeGuard(guards, operator="AND")
composite(instance, user=self.owner)
message = composite.get_error_message()
self.assertIn('admin', message.lower())
self.assertIn("admin", message.lower())
# ============================================================================
@@ -732,15 +677,12 @@ class GuardFactoryTests(TestCase):
def setUp(self):
"""Set up test fixtures."""
self.moderator = User.objects.create_user(
username='moderator',
email='moderator@example.com',
password='testpass123',
role='MODERATOR'
username="moderator", email="moderator@example.com", password="testpass123", role="MODERATOR"
)
def test_create_permission_guard_moderator(self):
"""Test create_permission_guard with moderator requirement."""
metadata = {'requires_moderator': True}
metadata = {"requires_moderator": True}
guard = create_permission_guard(metadata)
instance = MockInstance()
@@ -750,14 +692,14 @@ class GuardFactoryTests(TestCase):
def test_create_permission_guard_admin(self):
"""Test create_permission_guard with admin requirement."""
metadata = {'requires_admin_approval': True}
metadata = {"requires_admin_approval": True}
guard = create_permission_guard(metadata)
self.assertTrue(guard.requires_admin)
def test_create_permission_guard_escalation_level(self):
"""Test create_permission_guard with escalation level."""
metadata = {'escalation_level': 'admin'}
metadata = {"escalation_level": "admin"}
guard = create_permission_guard(metadata)
self.assertTrue(guard.requires_admin)
@@ -777,9 +719,9 @@ class GuardFactoryTests(TestCase):
def test_create_composite_guard(self):
"""Test create_composite_guard factory."""
guards = [PermissionGuard(), OwnershipGuard()]
composite = create_composite_guard(guards, operator='OR')
composite = create_composite_guard(guards, operator="OR")
self.assertEqual(composite.operator, 'OR')
self.assertEqual(composite.operator, "OR")
self.assertEqual(len(composite.guards), 2)
@@ -793,7 +735,7 @@ class MetadataExtractionTests(TestCase):
def test_extract_moderator_guard(self):
"""Test extracting guard for moderator requirement."""
metadata = {'requires_moderator': True}
metadata = {"requires_moderator": True}
guards = extract_guards_from_metadata(metadata)
self.assertEqual(len(guards), 1)
@@ -801,7 +743,7 @@ class MetadataExtractionTests(TestCase):
def test_extract_admin_guard(self):
"""Test extracting guard for admin requirement."""
metadata = {'requires_admin_approval': True}
metadata = {"requires_admin_approval": True}
guards = extract_guards_from_metadata(metadata)
self.assertEqual(len(guards), 1)
@@ -809,7 +751,7 @@ class MetadataExtractionTests(TestCase):
def test_extract_assignment_guard(self):
"""Test extracting assignment guard."""
metadata = {'requires_assignment': True}
metadata = {"requires_assignment": True}
guards = extract_guards_from_metadata(metadata)
self.assertEqual(len(guards), 1)
@@ -817,17 +759,14 @@ class MetadataExtractionTests(TestCase):
def test_extract_multiple_guards(self):
"""Test extracting multiple guards."""
metadata = {
'requires_moderator': True,
'requires_assignment': True
}
metadata = {"requires_moderator": True, "requires_assignment": True}
guards = extract_guards_from_metadata(metadata)
self.assertEqual(len(guards), 2)
def test_extract_zero_tolerance_guard(self):
"""Test extracting guard for zero tolerance (superuser required)."""
metadata = {'zero_tolerance': True}
metadata = {"zero_tolerance": True}
guards = extract_guards_from_metadata(metadata)
self.assertEqual(len(guards), 1)
@@ -835,7 +774,7 @@ class MetadataExtractionTests(TestCase):
def test_invalid_escalation_level_raises(self):
"""Test that invalid escalation level raises ValueError."""
metadata = {'escalation_level': 'invalid'}
metadata = {"escalation_level": "invalid"}
with self.assertRaises(ValueError):
extract_guards_from_metadata(metadata)
@@ -851,11 +790,7 @@ class MetadataValidationTests(TestCase):
def test_valid_metadata(self):
"""Test that valid metadata passes validation."""
metadata = {
'requires_moderator': True,
'escalation_level': 'admin',
'requires_assignment': False
}
metadata = {"requires_moderator": True, "escalation_level": "admin", "requires_assignment": False}
is_valid, errors = validate_guard_metadata(metadata)
@@ -864,30 +799,30 @@ class MetadataValidationTests(TestCase):
def test_invalid_escalation_level(self):
"""Test that invalid escalation level fails validation."""
metadata = {'escalation_level': 'invalid_level'}
metadata = {"escalation_level": "invalid_level"}
is_valid, errors = validate_guard_metadata(metadata)
self.assertFalse(is_valid)
self.assertTrue(any('escalation_level' in e for e in errors))
self.assertTrue(any("escalation_level" in e for e in errors))
def test_invalid_boolean_field(self):
"""Test that non-boolean value for boolean field fails validation."""
metadata = {'requires_moderator': 'yes'}
metadata = {"requires_moderator": "yes"}
is_valid, errors = validate_guard_metadata(metadata)
self.assertFalse(is_valid)
self.assertTrue(any('requires_moderator' in e for e in errors))
self.assertTrue(any("requires_moderator" in e for e in errors))
def test_required_permissions_not_list(self):
"""Test that non-list required_permissions fails validation."""
metadata = {'required_permissions': 'app.permission'}
metadata = {"required_permissions": "app.permission"}
is_valid, errors = validate_guard_metadata(metadata)
self.assertFalse(is_valid)
self.assertTrue(any('required_permissions' in e for e in errors))
self.assertTrue(any("required_permissions" in e for e in errors))
# ============================================================================
@@ -901,42 +836,30 @@ class RoleHelperTests(TestCase):
def setUp(self):
"""Set up test fixtures."""
self.regular_user = User.objects.create_user(
username='user',
email='user@example.com',
password='testpass123',
role='USER'
username="user", email="user@example.com", password="testpass123", role="USER"
)
self.moderator = User.objects.create_user(
username='moderator',
email='moderator@example.com',
password='testpass123',
role='MODERATOR'
username="moderator", email="moderator@example.com", password="testpass123", role="MODERATOR"
)
self.admin = User.objects.create_user(
username='admin',
email='admin@example.com',
password='testpass123',
role='ADMIN'
username="admin", email="admin@example.com", password="testpass123", role="ADMIN"
)
self.superuser = User.objects.create_user(
username='superuser',
email='superuser@example.com',
password='testpass123',
role='SUPERUSER'
username="superuser", email="superuser@example.com", password="testpass123", role="SUPERUSER"
)
def test_get_user_role(self):
"""Test get_user_role returns correct role."""
self.assertEqual(get_user_role(self.regular_user), 'USER')
self.assertEqual(get_user_role(self.moderator), 'MODERATOR')
self.assertEqual(get_user_role(self.admin), 'ADMIN')
self.assertEqual(get_user_role(self.superuser), 'SUPERUSER')
self.assertEqual(get_user_role(self.regular_user), "USER")
self.assertEqual(get_user_role(self.moderator), "MODERATOR")
self.assertEqual(get_user_role(self.admin), "ADMIN")
self.assertEqual(get_user_role(self.superuser), "SUPERUSER")
self.assertIsNone(get_user_role(None))
def test_has_role(self):
"""Test has_role function."""
self.assertTrue(has_role(self.moderator, ['MODERATOR', 'ADMIN']))
self.assertFalse(has_role(self.regular_user, ['MODERATOR', 'ADMIN']))
self.assertTrue(has_role(self.moderator, ["MODERATOR", "ADMIN"]))
self.assertFalse(has_role(self.regular_user, ["MODERATOR", "ADMIN"]))
def test_is_moderator_or_above(self):
"""Test is_moderator_or_above function."""
@@ -963,7 +886,7 @@ class RoleHelperTests(TestCase):
"""Test that anonymous user has no role."""
anonymous = AnonymousUser()
self.assertFalse(has_role(anonymous, ['USER']))
self.assertFalse(has_role(anonymous, ["USER"]))
self.assertFalse(is_moderator_or_above(anonymous))
self.assertFalse(is_admin_or_above(anonymous))
self.assertFalse(is_superuser_role(anonymous))

View File

@@ -1,4 +1,5 @@
"""Integration tests for state machine model integration."""
from unittest.mock import Mock, patch
import pytest
@@ -74,31 +75,23 @@ def test_generate_transition_methods(sample_choices):
"""Test generating transition methods on model."""
mock_model = type("MockModel", (), {})
generate_transition_methods_for_model(
mock_model, "status", "test_states", "test"
)
generate_transition_methods_for_model(mock_model, "status", "test_states", "test")
# Check that transition methods were added
# Method names may vary based on implementation
assert hasattr(mock_model, "approve") or hasattr(
mock_model, "transition_to_approved"
)
assert hasattr(mock_model, "approve") or hasattr(mock_model, "transition_to_approved")
def test_state_machine_model_decorator(sample_choices):
"""Test state_machine_model decorator."""
@state_machine_model(
field_name="status", choice_group="test_states", domain="test"
)
@state_machine_model(field_name="status", choice_group="test_states", domain="test")
class TestModel:
pass
# Decorator should apply state machine
# Check for transition methods
assert hasattr(TestModel, "approve") or hasattr(
TestModel, "transition_to_approved"
)
assert hasattr(TestModel, "approve") or hasattr(TestModel, "transition_to_approved")
def test_state_machine_mixin_get_available_transitions():

View File

@@ -1,4 +1,5 @@
"""Tests for TransitionRegistry."""
import pytest
from apps.core.choices.base import RichChoice
@@ -55,12 +56,8 @@ def test_transition_info_creation():
def test_transition_info_hashable():
"""Test TransitionInfo is hashable."""
info1 = TransitionInfo(
source="pending", target="approved", method_name="approve"
)
info2 = TransitionInfo(
source="pending", target="approved", method_name="approve"
)
info1 = TransitionInfo(source="pending", target="approved", method_name="approve")
info2 = TransitionInfo(source="pending", target="approved", method_name="approve")
assert hash(info1) == hash(info2)
@@ -82,9 +79,7 @@ def test_register_transition():
metadata={"requires_moderator": True},
)
transition = registry_instance.get_transition(
"test_states", "test", "pending", "approved"
)
transition = registry_instance.get_transition("test_states", "test", "pending", "approved")
assert transition is not None
assert transition.method_name == "approve"
assert transition.requires_moderator is True
@@ -92,9 +87,7 @@ def test_register_transition():
def test_get_transition_not_found():
"""Test getting non-existent transition."""
transition = registry_instance.get_transition(
"nonexistent", "test", "pending", "approved"
)
transition = registry_instance.get_transition("nonexistent", "test", "pending", "approved")
assert transition is None
@@ -102,9 +95,7 @@ def test_get_available_transitions(sample_choices):
"""Test getting available transitions from a state."""
registry_instance.build_registry_from_choices("test_states", "test")
available = registry_instance.get_available_transitions(
"test_states", "test", "pending"
)
available = registry_instance.get_available_transitions("test_states", "test", "pending")
assert len(available) == 2
targets = [t.target for t in available]
assert "approved" in targets
@@ -121,9 +112,7 @@ def test_get_transition_method_name():
method_name="approve",
)
method_name = registry_instance.get_transition_method_name(
"test_states", "test", "pending", "approved"
)
method_name = registry_instance.get_transition_method_name("test_states", "test", "pending", "approved")
assert method_name == "approve"
@@ -137,12 +126,8 @@ def test_validate_transition():
method_name="approve",
)
assert registry_instance.validate_transition(
"test_states", "test", "pending", "approved"
)
assert not registry_instance.validate_transition(
"test_states", "test", "pending", "nonexistent"
)
assert registry_instance.validate_transition("test_states", "test", "pending", "approved")
assert not registry_instance.validate_transition("test_states", "test", "pending", "nonexistent")
def test_build_registry_from_choices(sample_choices):
@@ -150,9 +135,7 @@ def test_build_registry_from_choices(sample_choices):
registry_instance.build_registry_from_choices("test_states", "test")
# Check transitions were registered
transition = registry_instance.get_transition(
"test_states", "test", "pending", "approved"
)
transition = registry_instance.get_transition("test_states", "test", "pending", "approved")
assert transition is not None
@@ -168,9 +151,7 @@ def test_clear_registry_specific():
registry_instance.clear_registry(choice_group="test_states", domain="test")
transition = registry_instance.get_transition(
"test_states", "test", "pending", "approved"
)
transition = registry_instance.get_transition("test_states", "test", "pending", "approved")
assert transition is None
@@ -186,9 +167,7 @@ def test_clear_registry_all():
registry_instance.clear_registry()
transition = registry_instance.get_transition(
"test_states", "test", "pending", "approved"
)
transition = registry_instance.get_transition("test_states", "test", "pending", "approved")
assert transition is None
@@ -196,9 +175,7 @@ def test_export_transition_graph_dict(sample_choices):
"""Test exporting transition graph as dict."""
registry_instance.build_registry_from_choices("test_states", "test")
graph = registry_instance.export_transition_graph(
"test_states", "test", format="dict"
)
graph = registry_instance.export_transition_graph("test_states", "test", format="dict")
assert isinstance(graph, dict)
assert "pending" in graph
assert set(graph["pending"]) == {"approved", "rejected"}
@@ -208,9 +185,7 @@ def test_export_transition_graph_mermaid(sample_choices):
"""Test exporting transition graph as mermaid."""
registry_instance.build_registry_from_choices("test_states", "test")
graph = registry_instance.export_transition_graph(
"test_states", "test", format="mermaid"
)
graph = registry_instance.export_transition_graph("test_states", "test", format="mermaid")
assert isinstance(graph, str)
assert "stateDiagram-v2" in graph
assert "pending" in graph
@@ -220,9 +195,7 @@ def test_export_transition_graph_dot(sample_choices):
"""Test exporting transition graph as DOT."""
registry_instance.build_registry_from_choices("test_states", "test")
graph = registry_instance.export_transition_graph(
"test_states", "test", format="dot"
)
graph = registry_instance.export_transition_graph("test_states", "test", format="dot")
assert isinstance(graph, str)
assert "digraph" in graph
assert "pending" in graph
@@ -233,9 +206,7 @@ def test_export_invalid_format(sample_choices):
registry_instance.build_registry_from_choices("test_states", "test")
with pytest.raises(ValueError):
registry_instance.export_transition_graph(
"test_states", "test", format="invalid"
)
registry_instance.export_transition_graph("test_states", "test", format="invalid")
def test_get_all_registered_groups():

View File

@@ -1,4 +1,5 @@
"""Tests for metadata validators."""
import pytest
from apps.core.choices.base import RichChoice
@@ -70,9 +71,7 @@ def terminal_with_transitions():
def test_validation_error_creation():
"""Test ValidationError creation."""
error = ValidationError(
code="TEST_ERROR", message="Test message", state="pending"
)
error = ValidationError(code="TEST_ERROR", message="Test message", state="pending")
assert error.code == "TEST_ERROR"
assert error.message == "Test message"
assert error.state == "pending"
@@ -81,9 +80,7 @@ def test_validation_error_creation():
def test_validation_warning_creation():
"""Test ValidationWarning creation."""
warning = ValidationWarning(
code="TEST_WARNING", message="Test warning", state="pending"
)
warning = ValidationWarning(code="TEST_WARNING", message="Test warning", state="pending")
assert warning.code == "TEST_WARNING"
assert warning.message == "Test warning"
@@ -166,15 +163,9 @@ def test_validate_no_cycles(valid_choices):
def test_validate_no_cycles_with_cycle():
"""Test cycle detection finds cycles."""
choices = [
RichChoice(
value="a", label="A", metadata={"can_transition_to": ["b"]}
),
RichChoice(
value="b", label="B", metadata={"can_transition_to": ["c"]}
),
RichChoice(
value="c", label="C", metadata={"can_transition_to": ["a"]}
),
RichChoice(value="a", label="A", metadata={"can_transition_to": ["b"]}),
RichChoice(value="b", label="B", metadata={"can_transition_to": ["c"]}),
RichChoice(value="c", label="C", metadata={"can_transition_to": ["a"]}),
]
registry.register("cycle_states", choices, domain="test")
@@ -202,9 +193,7 @@ def test_validate_reachability_unreachable():
label="Pending",
metadata={"can_transition_to": ["approved"]},
),
RichChoice(
value="approved", label="Approved", metadata={"is_final": True}
),
RichChoice(value="approved", label="Approved", metadata={"is_final": True}),
RichChoice(
value="orphan",
label="Orphan",

View File

@@ -1,4 +1,5 @@
"""Metadata validators for ensuring RichChoice metadata meets FSM requirements."""
from dataclasses import dataclass, field
from typing import Any
@@ -110,8 +111,7 @@ class MetadataValidator:
ValidationError(
code="MISSING_CAN_TRANSITION_TO",
message=(
"State metadata must explicitly define "
"'can_transition_to' (use [] for terminal states)"
"State metadata must explicitly define " "'can_transition_to' (use [] for terminal states)"
),
state=state,
)
@@ -138,9 +138,7 @@ class MetadataValidator:
errors.append(
ValidationError(
code="INVALID_TRANSITION_TARGET",
message=(
f"Transition target '{target}' does not exist"
),
message=(f"Transition target '{target}' does not exist"),
state=state,
)
)
@@ -188,17 +186,11 @@ class MetadataValidator:
perms = self.builder.extract_permission_requirements(state)
# Check for contradictory permissions
if (
perms.get("requires_admin_approval")
and not perms.get("requires_moderator")
):
if perms.get("requires_admin_approval") and not perms.get("requires_moderator"):
errors.append(
ValidationError(
code="PERMISSION_INCONSISTENCY",
message=(
"State requires admin approval but not moderator "
"(admin should imply moderator)"
),
message=("State requires admin approval but not moderator " "(admin should imply moderator)"),
state=state,
)
)
@@ -251,9 +243,7 @@ class MetadataValidator:
errors.append(
ValidationError(
code="STATE_CYCLE_DETECTED",
message=(
f"Cycle detected: {' -> '.join(cycle)}"
),
message=(f"Cycle detected: {' -> '.join(cycle)}"),
state=cycle[0],
)
)
@@ -278,9 +268,7 @@ class MetadataValidator:
for target in targets:
incoming[target].append(source)
initial_states = [
state for state in all_states if not incoming[state]
]
initial_states = [state for state in all_states if not incoming[state]]
if not initial_states:
errors.append(
@@ -327,9 +315,7 @@ class MetadataValidator:
result = self.validate_choice_group()
lines = []
lines.append(
f"Validation Report for {self.domain}.{self.choice_group}"
)
lines.append(f"Validation Report for {self.domain}.{self.choice_group}")
lines.append("=" * 60)
lines.append(f"Status: {'VALID' if result.is_valid else 'INVALID'}")
lines.append(f"Errors: {len(result.errors)}")
@@ -372,10 +358,7 @@ def validate_on_registration(choice_group: str, domain: str = "core") -> bool:
if not result.is_valid:
error_messages = [str(e) for e in result.errors]
raise ValueError(
f"Validation failed for {domain}.{choice_group}:\n"
+ "\n".join(error_messages)
)
raise ValueError(f"Validation failed for {domain}.{choice_group}:\n" + "\n".join(error_messages))
return True

View File

@@ -23,9 +23,7 @@ logger = logging.getLogger(__name__)
@shared_task(bind=True, max_retries=3, default_retry_delay=60)
def calculate_trending_content(
self, content_type: str = "all", limit: int = 50
) -> dict[str, Any]:
def calculate_trending_content(self, content_type: str = "all", limit: int = 50) -> dict[str, Any]:
"""
Calculate trending content using real analytics data.
@@ -72,17 +70,13 @@ def calculate_trending_content(
trending_items = trending_items[:limit]
# Format results for API consumption
formatted_results = _format_trending_results(
trending_items, current_period_hours, previous_period_hours
)
formatted_results = _format_trending_results(trending_items, current_period_hours, previous_period_hours)
# Cache results
cache_key = f"trending:calculated:{content_type}:{limit}"
cache.set(cache_key, formatted_results, 3600) # Cache for 1 hour
logger.info(
f"Calculated {len(formatted_results)} trending items for {content_type}"
)
logger.info(f"Calculated {len(formatted_results)} trending items for {content_type}")
return {
"success": True,
@@ -95,13 +89,11 @@ def calculate_trending_content(
except Exception as e:
logger.error(f"Error calculating trending content: {e}", exc_info=True)
# Retry the task
raise self.retry(exc=e)
raise self.retry(exc=e) from None
@shared_task(bind=True, max_retries=3, default_retry_delay=30)
def calculate_new_content(
self, content_type: str = "all", days_back: int = 30, limit: int = 50
) -> dict[str, Any]:
def calculate_new_content(self, content_type: str = "all", days_back: int = 30, limit: int = 50) -> dict[str, Any]:
"""
Calculate new content based on opening dates and creation dates.
@@ -120,15 +112,11 @@ def calculate_new_content(
new_items = []
if content_type in ["all", "parks"]:
parks = _get_new_parks(
cutoff_date, limit if content_type == "parks" else limit * 2
)
parks = _get_new_parks(cutoff_date, limit if content_type == "parks" else limit * 2)
new_items.extend(parks)
if content_type in ["all", "rides"]:
rides = _get_new_rides(
cutoff_date, limit if content_type == "rides" else limit * 2
)
rides = _get_new_rides(cutoff_date, limit if content_type == "rides" else limit * 2)
new_items.extend(rides)
# Sort by date added (most recent first) and apply limit
@@ -154,7 +142,7 @@ def calculate_new_content(
except Exception as e:
logger.error(f"Error calculating new content: {e}", exc_info=True)
raise self.retry(exc=e)
raise self.retry(exc=e) from None
@shared_task(bind=True)
@@ -185,9 +173,7 @@ def warm_trending_cache(self) -> dict[str, Any]:
calculate_new_content.delay(**query)
results[f"trending_{query['content_type']}_{query['limit']}"] = "scheduled"
results[f"new_content_{query['content_type']}_{query['limit']}"] = (
"scheduled"
)
results[f"new_content_{query['content_type']}_{query['limit']}"] = "scheduled"
logger.info("Trending cache warming completed")
@@ -211,17 +197,13 @@ def _calculate_trending_parks(
current_period_hours: int, previous_period_hours: int, limit: int
) -> list[dict[str, Any]]:
"""Calculate trending scores for parks using real data."""
parks = Park.objects.filter(status="OPERATING").select_related(
"location", "operator"
)
parks = Park.objects.filter(status="OPERATING").select_related("location", "operator")
trending_parks = []
for park in parks:
try:
score = _calculate_content_score(
park, "park", current_period_hours, previous_period_hours
)
score = _calculate_content_score(park, "park", current_period_hours, previous_period_hours)
if score > 0: # Only include items with positive trending scores
trending_parks.append(
{
@@ -231,13 +213,9 @@ def _calculate_trending_parks(
"id": park.id,
"name": park.name,
"slug": park.slug,
"location": (
park.formatted_location if hasattr(park, "location") else ""
),
"location": (park.formatted_location if hasattr(park, "location") else ""),
"category": "park",
"rating": (
float(park.average_rating) if park.average_rating else 0.0
),
"rating": (float(park.average_rating) if park.average_rating else 0.0),
}
)
except Exception as e:
@@ -250,17 +228,13 @@ def _calculate_trending_rides(
current_period_hours: int, previous_period_hours: int, limit: int
) -> list[dict[str, Any]]:
"""Calculate trending scores for rides using real data."""
rides = Ride.objects.filter(status="OPERATING").select_related(
"park", "park__location"
)
rides = Ride.objects.filter(status="OPERATING").select_related("park", "park__location")
trending_rides = []
for ride in rides:
try:
score = _calculate_content_score(
ride, "ride", current_period_hours, previous_period_hours
)
score = _calculate_content_score(ride, "ride", current_period_hours, previous_period_hours)
if score > 0: # Only include items with positive trending scores
# Get location from park
location = ""
@@ -277,9 +251,7 @@ def _calculate_trending_rides(
"slug": ride.slug,
"location": location,
"category": "ride",
"rating": (
float(ride.average_rating) if ride.average_rating else 0.0
),
"rating": (float(ride.average_rating) if ride.average_rating else 0.0),
}
)
except Exception as e:
@@ -322,17 +294,10 @@ def _calculate_content_score(
recency_score = _calculate_recency_score(content_obj)
# 4. Popularity Score (10% weight)
popularity_score = _calculate_popularity_score(
ct, content_obj.id, current_period_hours
)
popularity_score = _calculate_popularity_score(ct, content_obj.id, current_period_hours)
# Calculate weighted final score
final_score = (
view_growth_score * 0.4
+ rating_score * 0.3
+ recency_score * 0.2
+ popularity_score * 0.1
)
final_score = view_growth_score * 0.4 + rating_score * 0.3 + recency_score * 0.2 + popularity_score * 0.1
logger.debug(
f"{content_type} {content_obj.id}: "
@@ -344,9 +309,7 @@ def _calculate_content_score(
return final_score
except Exception as e:
logger.error(
f"Error calculating score for {content_type} {content_obj.id}: {e}"
)
logger.error(f"Error calculating score for {content_type} {content_obj.id}: {e}")
return 0.0
@@ -371,9 +334,7 @@ def _calculate_view_growth_score(
# Normalize growth percentage to 0-1 scale
# 100% growth = 0.5, 500% growth = 1.0
normalized_growth = (
min(growth_percentage / 500.0, 1.0) if growth_percentage > 0 else 0.0
)
normalized_growth = min(growth_percentage / 500.0, 1.0) if growth_percentage > 0 else 0.0
return max(normalized_growth, 0.0)
except Exception as e:
@@ -431,14 +392,10 @@ def _calculate_recency_score(content_obj: Any) -> float:
return 0.5
def _calculate_popularity_score(
content_type: ContentType, object_id: int, hours: int
) -> float:
def _calculate_popularity_score(content_type: ContentType, object_id: int, hours: int) -> float:
"""Calculate popularity score based on total view count."""
try:
total_views = PageView.get_total_views_count(
content_type, object_id, hours=hours
)
total_views = PageView.get_total_views_count(content_type, object_id, hours=hours)
# Normalize views to 0-1 scale
# 0 views = 0.0, 100 views = 0.5, 1000+ views = 1.0
@@ -505,9 +462,7 @@ def _get_new_rides(cutoff_date: datetime, limit: int) -> list[dict[str, Any]]:
results = []
for ride in new_rides:
date_added = getattr(ride, "opening_date", None) or getattr(
ride, "created_at", None
)
date_added = getattr(ride, "opening_date", None) or getattr(ride, "created_at", None)
if date_added and isinstance(date_added, datetime):
date_added = date_added.date()
@@ -545,13 +500,11 @@ def _format_trending_results(
# Get view change for display
content_obj = item["content_object"]
ct = ContentType.objects.get_for_model(content_obj)
current_views, previous_views, growth_percentage = (
PageView.get_views_growth(
ct,
content_obj.id,
current_period_hours,
previous_period_hours,
)
current_views, previous_views, growth_percentage = PageView.get_views_growth(
ct,
content_obj.id,
current_period_hours,
previous_period_hours,
)
# Format exactly as frontend expects
@@ -564,9 +517,7 @@ def _format_trending_results(
"rank": rank,
"views": current_views,
"views_change": (
f"+{growth_percentage:.1f}%"
if growth_percentage > 0
else f"{growth_percentage:.1f}%"
f"+{growth_percentage:.1f}%" if growth_percentage > 0 else f"{growth_percentage:.1f}%"
),
"slug": item["slug"],
}

View File

@@ -26,6 +26,7 @@ register = template.Library()
# Time and Date Filters
# =============================================================================
@register.filter
def humanize_timedelta(value):
"""
@@ -42,26 +43,26 @@ def humanize_timedelta(value):
Human-readable string like "2 hours ago"
"""
if value is None:
return ''
return ""
# Convert datetime to timedelta from now
if hasattr(value, 'tzinfo'): # It's a datetime
if hasattr(value, "tzinfo"): # It's a datetime
now = timezone.now()
if value > now:
return 'in the future'
return "in the future"
value = now - value
# Convert seconds to timedelta
if isinstance(value, (int, float)):
if isinstance(value, int | float):
value = timedelta(seconds=value)
if not isinstance(value, timedelta):
return ''
return ""
seconds = int(value.total_seconds())
if seconds < 60:
return 'just now'
return "just now"
elif seconds < 3600:
minutes = seconds // 60
return f'{minutes} minute{"s" if minutes != 1 else ""} ago'
@@ -92,22 +93,23 @@ def time_until(value):
Output: "in 2 days", "in 3 hours"
"""
if value is None:
return ''
return ""
if hasattr(value, 'tzinfo'):
if hasattr(value, "tzinfo"):
now = timezone.now()
if value <= now:
return 'now'
return "now"
diff = value - now
return humanize_timedelta(diff).replace(' ago', '')
return humanize_timedelta(diff).replace(" ago", "")
return ''
return ""
# =============================================================================
# Text Manipulation Filters
# =============================================================================
@register.filter
@stringfilter
def truncate_smart(value, max_length=50):
@@ -130,12 +132,12 @@ def truncate_smart(value, max_length=50):
# Find the last space before max_length
truncated = value[:max_length]
last_space = truncated.rfind(' ')
last_space = truncated.rfind(" ")
if last_space > max_length * 0.5: # Only use word boundary if reasonable
truncated = truncated[:last_space]
return truncated.rstrip('.,!?;:') + '...'
return truncated.rstrip(".,!?;:") + "..."
@register.filter
@@ -153,7 +155,7 @@ def truncate_middle(value, max_length=50):
return value
keep_chars = (max_length - 3) // 2
return f'{value[:keep_chars]}...{value[-keep_chars:]}'
return f"{value[:keep_chars]}...{value[-keep_chars:]}"
@register.filter
@@ -167,13 +169,14 @@ def initials(value, max_initials=2):
Output: "JD" for "John Doe"
"""
words = value.split()
return ''.join(word[0].upper() for word in words[:max_initials] if word)
return "".join(word[0].upper() for word in words[:max_initials] if word)
# =============================================================================
# Number Formatting Filters
# =============================================================================
@register.filter
def format_number(value, decimals=0):
"""
@@ -187,14 +190,14 @@ def format_number(value, decimals=0):
Output: "1,234.56"
"""
if value is None:
return ''
return ""
try:
value = float(value)
decimals = int(decimals)
if decimals > 0:
return f'{value:,.{decimals}f}'
return f'{int(value):,}'
return f"{value:,.{decimals}f}"
return f"{int(value):,}"
except (ValueError, TypeError):
return value
@@ -209,16 +212,16 @@ def format_compact(value):
Output: "1.2K", "3.4M", "2.1B"
"""
if value is None:
return ''
return ""
try:
value = float(value)
if value >= 1_000_000_000:
return f'{value / 1_000_000_000:.1f}B'
return f"{value / 1_000_000_000:.1f}B"
elif value >= 1_000_000:
return f'{value / 1_000_000:.1f}M'
return f"{value / 1_000_000:.1f}M"
elif value >= 1_000:
return f'{value / 1_000:.1f}K'
return f"{value / 1_000:.1f}K"
return str(int(value))
except (ValueError, TypeError):
return value
@@ -237,16 +240,17 @@ def percentage(value, total):
value = float(value)
total = float(total)
if total == 0:
return '0%'
return f'{(value / total * 100):.0f}%'
return "0%"
return f"{(value / total * 100):.0f}%"
except (ValueError, TypeError, ZeroDivisionError):
return '0%'
return "0%"
# =============================================================================
# Dictionary/List Filters
# =============================================================================
@register.filter
def get_item(dictionary, key):
"""
@@ -278,7 +282,7 @@ def getlist(querydict, key):
"""
if querydict is None:
return []
if hasattr(querydict, 'getlist'):
if hasattr(querydict, "getlist"):
return querydict.getlist(key)
return []
@@ -314,6 +318,7 @@ def index(sequence, i):
# Pluralization Filters
# =============================================================================
@register.filter
def pluralize_custom(count, forms):
"""
@@ -330,7 +335,7 @@ def pluralize_custom(count, forms):
"""
try:
count = int(count)
singular, plural = forms.split(',')
singular, plural = forms.split(",")
return singular if count == 1 else plural
except (ValueError, AttributeError):
return forms
@@ -347,9 +352,9 @@ def count_with_label(count, forms):
"""
try:
count = int(count)
singular, plural = forms.split(',')
singular, plural = forms.split(",")
label = singular if count == 1 else plural
return f'{count} {label}'
return f"{count} {label}"
except (ValueError, AttributeError):
return str(count)
@@ -358,6 +363,7 @@ def count_with_label(count, forms):
# CSS Class Manipulation
# =============================================================================
@register.filter
def add_class(field, css_class):
"""
@@ -366,10 +372,10 @@ def add_class(field, css_class):
Usage:
{{ form.email|add_class:"form-control" }}
"""
if hasattr(field, 'as_widget'):
existing = field.field.widget.attrs.get('class', '')
new_classes = f'{existing} {css_class}'.strip()
return field.as_widget(attrs={'class': new_classes})
if hasattr(field, "as_widget"):
existing = field.field.widget.attrs.get("class", "")
new_classes = f"{existing} {css_class}".strip()
return field.as_widget(attrs={"class": new_classes})
return field
@@ -381,8 +387,8 @@ def set_attr(field, attr_value):
Usage:
{{ form.email|set_attr:"placeholder:Enter email" }}
"""
if hasattr(field, 'as_widget'):
attr, value = attr_value.split(':')
if hasattr(field, "as_widget"):
attr, value = attr_value.split(":")
return field.as_widget(attrs={attr: value})
return field
@@ -391,6 +397,7 @@ def set_attr(field, attr_value):
# Conditional Filters
# =============================================================================
@register.filter
def default_if_none(value, default):
"""
@@ -413,5 +420,5 @@ def yesno_icon(value, icons="fa-check,fa-times"):
{{ has_feature|yesno_icon:"fa-star,fa-star-o" }}
"""
true_icon, false_icon = icons.split(',')
true_icon, false_icon = icons.split(",")
return true_icon if value else false_icon

View File

@@ -23,6 +23,7 @@ Usage:
{# Render a transition button #}
{% transition_button submission 'approve' request.user %}
"""
from typing import Any
from django import template
@@ -53,12 +54,12 @@ def get_state_value(obj) -> str | None:
Returns:
The current state value or None
"""
if hasattr(obj, 'get_state_value'):
if hasattr(obj, "get_state_value"):
return obj.get_state_value()
if hasattr(obj, 'state_field_name'):
if hasattr(obj, "state_field_name"):
return getattr(obj, obj.state_field_name, None)
# Try common field names
for field in ['status', 'state']:
for field in ["status", "state"]:
if hasattr(obj, field):
return getattr(obj, field, None)
return None
@@ -78,19 +79,19 @@ def get_state_display(obj) -> str:
Returns:
The human-readable state display value
"""
if hasattr(obj, 'get_state_display_value'):
if hasattr(obj, "get_state_display_value"):
return obj.get_state_display_value()
if hasattr(obj, 'state_field_name'):
if hasattr(obj, "state_field_name"):
field_name = obj.state_field_name
getter = getattr(obj, f'get_{field_name}_display', None)
getter = getattr(obj, f"get_{field_name}_display", None)
if callable(getter):
return getter()
# Try common field names
for field in ['status', 'state']:
getter = getattr(obj, f'get_{field}_display', None)
for field in ["status", "state"]:
getter = getattr(obj, f"get_{field}_display", None)
if callable(getter):
return getter()
return str(get_state_value(obj) or '')
return str(get_state_value(obj) or "")
@register.filter
@@ -109,7 +110,7 @@ def get_state_choice(obj):
Returns:
The RichChoice object or None
"""
if hasattr(obj, 'get_state_choice'):
if hasattr(obj, "get_state_choice"):
return obj.get_state_choice()
return None
@@ -205,18 +206,24 @@ def get_available_transitions(obj, user) -> list[dict[str, Any]]:
# Get list of available transitions
available_transition_names = []
if hasattr(obj, 'get_available_user_transitions'):
if hasattr(obj, "get_available_user_transitions"):
# Use the helper method if available
return obj.get_available_user_transitions(user)
if hasattr(obj, 'get_available_transitions'):
if hasattr(obj, "get_available_transitions"):
available_transition_names = list(obj.get_available_transitions())
else:
# Fallback: look for transition methods by convention
for attr_name in dir(obj):
if attr_name.startswith('transition_to_') or attr_name in ['approve', 'reject', 'escalate', 'complete', 'cancel']:
if attr_name.startswith("transition_to_") or attr_name in [
"approve",
"reject",
"escalate",
"complete",
"cancel",
]:
method = getattr(obj, attr_name, None)
if callable(method) and hasattr(method, '_django_fsm'):
if callable(method) and hasattr(method, "_django_fsm"):
available_transition_names.append(attr_name)
# Filter transitions by user permission
@@ -226,14 +233,16 @@ def get_available_transitions(obj, user) -> list[dict[str, Any]]:
try:
if can_proceed(method, user):
metadata = get_transition_metadata(transition_name)
transitions.append({
'name': transition_name,
'label': _format_transition_label(transition_name),
'icon': metadata.get('icon', 'arrow-right'),
'style': metadata.get('style', 'gray'),
'requires_confirm': metadata.get('requires_confirm', False),
'confirm_message': metadata.get('confirm_message', 'Are you sure?'),
})
transitions.append(
{
"name": transition_name,
"label": _format_transition_label(transition_name),
"icon": metadata.get("icon", "arrow-right"),
"style": metadata.get("style", "gray"),
"requires_confirm": metadata.get("requires_confirm", False),
"confirm_message": metadata.get("confirm_message", "Are you sure?"),
}
)
except Exception:
# Skip transitions that raise errors during can_proceed check
pass
@@ -289,14 +298,17 @@ def get_transition_url(obj, transition_name: str) -> str:
The URL string for the transition endpoint
"""
try:
return reverse('core:fsm_transition', kwargs={
'app_label': obj._meta.app_label,
'model_name': obj._meta.model_name,
'pk': obj.pk,
'transition_name': transition_name,
})
return reverse(
"core:fsm_transition",
kwargs={
"app_label": obj._meta.app_label,
"model_name": obj._meta.model_name,
"pk": obj.pk,
"transition_name": transition_name,
},
)
except NoReverseMatch:
return ''
return ""
# =============================================================================
@@ -304,7 +316,7 @@ def get_transition_url(obj, transition_name: str) -> str:
# =============================================================================
@register.inclusion_tag('htmx/state_actions.html', takes_context=True)
@register.inclusion_tag("htmx/state_actions.html", takes_context=True)
def render_state_actions(context, obj, user=None, **kwargs):
"""
Render the state action buttons for an FSM-enabled object.
@@ -323,17 +335,17 @@ def render_state_actions(context, obj, user=None, **kwargs):
Context for the state_actions.html template
"""
if user is None:
user = context.get('request', {}).user if 'request' in context else None
user = context.get("request", {}).user if "request" in context else None
return {
'object': obj,
'user': user,
'request': context.get('request'),
"object": obj,
"user": user,
"request": context.get("request"),
**kwargs,
}
@register.inclusion_tag('htmx/status_with_actions.html', takes_context=True)
@register.inclusion_tag("htmx/status_with_actions.html", takes_context=True)
def render_status_with_actions(context, obj, user=None, **kwargs):
"""
Render the status badge with action buttons for an FSM-enabled object.
@@ -352,12 +364,12 @@ def render_status_with_actions(context, obj, user=None, **kwargs):
Context for the status_with_actions.html template
"""
if user is None:
user = context.get('request', {}).user if 'request' in context else None
user = context.get("request", {}).user if "request" in context else None
return {
'object': obj,
'user': user,
'request': context.get('request'),
"object": obj,
"user": user,
"request": context.get("request"),
**kwargs,
}
@@ -384,28 +396,28 @@ def _format_transition_label(transition_name: str) -> str:
"""
# Remove common prefixes
label = transition_name
for prefix in ['transition_to_', 'transition_', 'do_']:
for prefix in ["transition_to_", "transition_", "do_"]:
if label.startswith(prefix):
label = label[len(prefix):]
label = label[len(prefix) :]
break
# Remove past tense suffix and capitalize
# e.g., 'approved' -> 'Approve'
if label.endswith('ed') and len(label) > 3:
if label.endswith("ed") and len(label) > 3:
# Handle special cases
if label.endswith('ied'):
label = label[:-3] + 'y'
if label.endswith("ied"):
label = label[:-3] + "y"
elif label[-3] == label[-4]: # doubled consonant (e.g., 'submitted')
label = label[:-3]
else:
label = label[:-1] # Remove 'd'
if label.endswith('e'):
if label.endswith("e"):
pass # Keep the 'e' for words like 'approve'
else:
label = label[:-1] # Remove 'e' for words like 'rejected' -> 'reject'
# Replace underscores with spaces and title case
label = label.replace('_', ' ').title()
label = label.replace("_", " ").title()
return label
@@ -418,17 +430,17 @@ def _format_transition_label(transition_name: str) -> str:
# Ensure all tags and filters are registered
__all__ = [
# Filters
'get_state_value',
'get_state_display',
'get_state_choice',
'app_label',
'model_name',
'default_target_id',
"get_state_value",
"get_state_display",
"get_state_choice",
"app_label",
"model_name",
"default_target_id",
# Tags
'get_available_transitions',
'can_transition',
'get_transition_url',
"get_available_transitions",
"can_transition",
"get_transition_url",
# Inclusion tags
'render_state_actions',
'render_status_with_actions',
"render_state_actions",
"render_status_with_actions",
]

View File

@@ -57,7 +57,8 @@ register = template.Library()
# HTML Sanitization Filters
# =============================================================================
@register.filter(name='sanitize', is_safe=True)
@register.filter(name="sanitize", is_safe=True)
def sanitize_filter(value):
"""
Sanitize HTML content to prevent XSS attacks.
@@ -68,11 +69,11 @@ def sanitize_filter(value):
{{ user_content|sanitize }}
"""
if not value:
return ''
return ""
return mark_safe(sanitize_html(str(value)))
@register.filter(name='sanitize_minimal', is_safe=True)
@register.filter(name="sanitize_minimal", is_safe=True)
def sanitize_minimal_filter(value):
"""
Sanitize HTML with minimal allowed tags.
@@ -83,11 +84,11 @@ def sanitize_minimal_filter(value):
{{ comment|sanitize_minimal }}
"""
if not value:
return ''
return ""
return mark_safe(_sanitize_minimal(str(value)))
@register.filter(name='sanitize_svg', is_safe=True)
@register.filter(name="sanitize_svg", is_safe=True)
def sanitize_svg_filter(value):
"""
Sanitize SVG content for safe inline rendering.
@@ -96,11 +97,11 @@ def sanitize_svg_filter(value):
{{ icon_svg|sanitize_svg }}
"""
if not value:
return ''
return ""
return mark_safe(sanitize_svg(str(value)))
@register.filter(name='strip_html')
@register.filter(name="strip_html")
def strip_html_filter(value):
"""
Remove all HTML tags from content.
@@ -109,7 +110,7 @@ def strip_html_filter(value):
{{ html_content|strip_html }}
"""
if not value:
return ''
return ""
return _strip_html(str(value))
@@ -117,7 +118,8 @@ def strip_html_filter(value):
# JavaScript/JSON Context Filters
# =============================================================================
@register.filter(name='json_safe', is_safe=True)
@register.filter(name="json_safe", is_safe=True)
def json_safe_filter(value):
"""
Safely serialize data for embedding in JavaScript.
@@ -131,11 +133,11 @@ def json_safe_filter(value):
</script>
"""
if value is None:
return 'null'
return "null"
return mark_safe(sanitize_for_json(value))
@register.filter(name='escapejs_safe')
@register.filter(name="escapejs_safe")
def escapejs_safe_filter(value):
"""
Escape a string for safe use in JavaScript string literals.
@@ -146,7 +148,7 @@ def escapejs_safe_filter(value):
</script>
"""
if not value:
return ''
return ""
return _escape_js_string(str(value))
@@ -154,7 +156,8 @@ def escapejs_safe_filter(value):
# URL and Attribute Filters
# =============================================================================
@register.filter(name='sanitize_url')
@register.filter(name="sanitize_url")
def sanitize_url_filter(value):
"""
Sanitize a URL to prevent javascript: and other dangerous protocols.
@@ -163,11 +166,11 @@ def sanitize_url_filter(value):
<a href="{{ user_url|sanitize_url }}">Link</a>
"""
if not value:
return ''
return ""
return _sanitize_url(str(value))
@register.filter(name='attr_safe')
@register.filter(name="attr_safe")
def attr_safe_filter(value):
"""
Escape a value for safe use in HTML attributes.
@@ -176,7 +179,7 @@ def attr_safe_filter(value):
<div data-value="{{ user_value|attr_safe }}">
"""
if not value:
return ''
return ""
return sanitize_attribute_value(str(value))
@@ -187,36 +190,36 @@ def attr_safe_filter(value):
# Predefined safe SVG icons
# These are trusted and can be rendered without sanitization
BUILTIN_ICONS = {
'check': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M4.5 12.75l6 6 9-13.5" /></svg>''',
'x': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M6 18L18 6M6 6l12 12" /></svg>''',
'plus': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M12 4.5v15m7.5-7.5h-15" /></svg>''',
'minus': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M19.5 12h-15" /></svg>''',
'chevron-down': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M19.5 8.25l-7.5 7.5-7.5-7.5" /></svg>''',
'chevron-up': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M4.5 15.75l7.5-7.5 7.5 7.5" /></svg>''',
'chevron-left': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M15.75 19.5L8.25 12l7.5-7.5" /></svg>''',
'chevron-right': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M8.25 4.5l7.5 7.5-7.5 7.5" /></svg>''',
'search': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M21 21l-5.197-5.197m0 0A7.5 7.5 0 105.196 5.196a7.5 7.5 0 0010.607 10.607z" /></svg>''',
'menu': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M3.75 6.75h16.5M3.75 12h16.5m-16.5 5.25h16.5" /></svg>''',
'user': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M15.75 6a3.75 3.75 0 11-7.5 0 3.75 3.75 0 017.5 0zM4.501 20.118a7.5 7.5 0 0114.998 0A17.933 17.933 0 0112 21.75c-2.676 0-5.216-.584-7.499-1.632z" /></svg>''',
'cog': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M9.594 3.94c.09-.542.56-.94 1.11-.94h2.593c.55 0 1.02.398 1.11.94l.213 1.281c.063.374.313.686.645.87.074.04.147.083.22.127.324.196.72.257 1.075.124l1.217-.456a1.125 1.125 0 011.37.49l1.296 2.247a1.125 1.125 0 01-.26 1.431l-1.003.827c-.293.24-.438.613-.431.992a6.759 6.759 0 010 .255c-.007.378.138.75.43.99l1.005.828c.424.35.534.954.26 1.43l-1.298 2.247a1.125 1.125 0 01-1.369.491l-1.217-.456c-.355-.133-.75-.072-1.076.124a6.57 6.57 0 01-.22.128c-.331.183-.581.495-.644.869l-.213 1.28c-.09.543-.56.941-1.11.941h-2.594c-.55 0-1.02-.398-1.11-.94l-.213-1.281c-.062-.374-.312-.686-.644-.87a6.52 6.52 0 01-.22-.127c-.325-.196-.72-.257-1.076-.124l-1.217.456a1.125 1.125 0 01-1.369-.49l-1.297-2.247a1.125 1.125 0 01.26-1.431l1.004-.827c.292-.24.437-.613.43-.992a6.932 6.932 0 010-.255c.007-.378-.138-.75-.43-.99l-1.004-.828a1.125 1.125 0 01-.26-1.43l1.297-2.247a1.125 1.125 0 011.37-.491l1.216.456c.356.133.751.072 1.076-.124.072-.044.146-.087.22-.128.332-.183.582-.495.644-.869l.214-1.281z" /><path stroke-linecap="round" stroke-linejoin="round" d="M15 12a3 3 0 11-6 0 3 3 0 016 0z" /></svg>''',
'trash': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M14.74 9l-.346 9m-4.788 0L9.26 9m9.968-3.21c.342.052.682.107 1.022.166m-1.022-.165L18.16 19.673a2.25 2.25 0 01-2.244 2.077H8.084a2.25 2.25 0 01-2.244-2.077L4.772 5.79m14.456 0a48.108 48.108 0 00-3.478-.397m-12 .562c.34-.059.68-.114 1.022-.165m0 0a48.11 48.11 0 013.478-.397m7.5 0v-.916c0-1.18-.91-2.164-2.09-2.201a51.964 51.964 0 00-3.32 0c-1.18.037-2.09 1.022-2.09 2.201v.916m7.5 0a48.667 48.667 0 00-7.5 0" /></svg>''',
'pencil': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M16.862 4.487l1.687-1.688a1.875 1.875 0 112.652 2.652L10.582 16.07a4.5 4.5 0 01-1.897 1.13L6 18l.8-2.685a4.5 4.5 0 011.13-1.897l8.932-8.931zm0 0L19.5 7.125M18 14v4.75A2.25 2.25 0 0115.75 21H5.25A2.25 2.25 0 013 18.75V8.25A2.25 2.25 0 015.25 6H10" /></svg>''',
'eye': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M2.036 12.322a1.012 1.012 0 010-.639C3.423 7.51 7.36 4.5 12 4.5c4.638 0 8.573 3.007 9.963 7.178.07.207.07.431 0 .639C20.577 16.49 16.64 19.5 12 19.5c-4.638 0-8.573-3.007-9.963-7.178z" /><path stroke-linecap="round" stroke-linejoin="round" d="M15 12a3 3 0 11-6 0 3 3 0 016 0z" /></svg>''',
'eye-slash': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M3.98 8.223A10.477 10.477 0 001.934 12C3.226 16.338 7.244 19.5 12 19.5c.993 0 1.953-.138 2.863-.395M6.228 6.228A10.45 10.45 0 0112 4.5c4.756 0 8.773 3.162 10.065 7.498a10.523 10.523 0 01-4.293 5.774M6.228 6.228L3 3m3.228 3.228l3.65 3.65m7.894 7.894L21 21m-3.228-3.228l-3.65-3.65m0 0a3 3 0 10-4.243-4.243m4.242 4.242L9.88 9.88" /></svg>''',
'arrow-left': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M10.5 19.5L3 12m0 0l7.5-7.5M3 12h18" /></svg>''',
'arrow-right': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M13.5 4.5L21 12m0 0l-7.5 7.5M21 12H3" /></svg>''',
'info': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M11.25 11.25l.041-.02a.75.75 0 011.063.852l-.708 2.836a.75.75 0 001.063.853l.041-.021M21 12a9 9 0 11-18 0 9 9 0 0118 0zm-9-3.75h.008v.008H12V8.25z" /></svg>''',
'warning': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M12 9v3.75m-9.303 3.376c-.866 1.5.217 3.374 1.948 3.374h14.71c1.73 0 2.813-1.874 1.948-3.374L13.949 3.378c-.866-1.5-3.032-1.5-3.898 0L2.697 16.126zM12 15.75h.007v.008H12v-.008z" /></svg>''',
'error': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M12 9v3.75m9-.75a9 9 0 11-18 0 9 9 0 0118 0zm-9 3.75h.008v.008H12v-.008z" /></svg>''',
'success': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M9 12.75L11.25 15 15 9.75M21 12a9 9 0 11-18 0 9 9 0 0118 0z" /></svg>''',
'loading': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" {attrs}><circle class="opacity-25" cx="12" cy="12" r="10" stroke="currentColor" stroke-width="4"></circle><path class="opacity-75" fill="currentColor" d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"></path></svg>''',
'external-link': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M13.5 6H5.25A2.25 2.25 0 003 8.25v10.5A2.25 2.25 0 005.25 21h10.5A2.25 2.25 0 0018 18.75V10.5m-10.5 6L21 3m0 0h-5.25M21 3v5.25" /></svg>''',
'download': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M3 16.5v2.25A2.25 2.25 0 005.25 21h13.5A2.25 2.25 0 0021 18.75V16.5M16.5 12L12 16.5m0 0L7.5 12m4.5 4.5V3" /></svg>''',
'upload': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M3 16.5v2.25A2.25 2.25 0 005.25 21h13.5A2.25 2.25 0 0021 18.75V16.5m-13.5-9L12 3m0 0l4.5 4.5M12 3v13.5" /></svg>''',
'star': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M11.48 3.499a.562.562 0 011.04 0l2.125 5.111a.563.563 0 00.475.345l5.518.442c.499.04.701.663.321.988l-4.204 3.602a.563.563 0 00-.182.557l1.285 5.385a.562.562 0 01-.84.61l-4.725-2.885a.563.563 0 00-.586 0L6.982 20.54a.562.562 0 01-.84-.61l1.285-5.386a.562.562 0 00-.182-.557l-4.204-3.602a.563.563 0 01.321-.988l5.518-.442a.563.563 0 00.475-.345L11.48 3.5z" /></svg>''',
'star-filled': '''<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" {attrs}><path fill-rule="evenodd" d="M10.788 3.21c.448-1.077 1.976-1.077 2.424 0l2.082 5.007 5.404.433c1.164.093 1.636 1.545.749 2.305l-4.117 3.527 1.257 5.273c.271 1.136-.964 2.033-1.96 1.425L12 18.354 7.373 21.18c-.996.608-2.231-.29-1.96-1.425l1.257-5.273-4.117-3.527c-.887-.76-.415-2.212.749-2.305l5.404-.433 2.082-5.006z" clip-rule="evenodd" /></svg>''',
'heart': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M21 8.25c0-2.485-2.099-4.5-4.688-4.5-1.935 0-3.597 1.126-4.312 2.733-.715-1.607-2.377-2.733-4.313-2.733C5.1 3.75 3 5.765 3 8.25c0 7.22 9 12 9 12s9-4.78 9-12z" /></svg>''',
'heart-filled': '''<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" {attrs}><path d="M11.645 20.91l-.007-.003-.022-.012a15.247 15.247 0 01-.383-.218 25.18 25.18 0 01-4.244-3.17C4.688 15.36 2.25 12.174 2.25 8.25 2.25 5.322 4.714 3 7.688 3A5.5 5.5 0 0112 5.052 5.5 5.5 0 0116.313 3c2.973 0 5.437 2.322 5.437 5.25 0 3.925-2.438 7.111-4.739 9.256a25.175 25.175 0 01-4.244 3.17 15.247 15.247 0 01-.383.219l-.022.012-.007.004-.003.001a.752.752 0 01-.704 0l-.003-.001z" /></svg>''',
"check": """<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M4.5 12.75l6 6 9-13.5" /></svg>""",
"x": """<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M6 18L18 6M6 6l12 12" /></svg>""",
"plus": """<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M12 4.5v15m7.5-7.5h-15" /></svg>""",
"minus": """<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M19.5 12h-15" /></svg>""",
"chevron-down": """<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M19.5 8.25l-7.5 7.5-7.5-7.5" /></svg>""",
"chevron-up": """<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M4.5 15.75l7.5-7.5 7.5 7.5" /></svg>""",
"chevron-left": """<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M15.75 19.5L8.25 12l7.5-7.5" /></svg>""",
"chevron-right": """<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M8.25 4.5l7.5 7.5-7.5 7.5" /></svg>""",
"search": """<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M21 21l-5.197-5.197m0 0A7.5 7.5 0 105.196 5.196a7.5 7.5 0 0010.607 10.607z" /></svg>""",
"menu": """<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M3.75 6.75h16.5M3.75 12h16.5m-16.5 5.25h16.5" /></svg>""",
"user": """<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M15.75 6a3.75 3.75 0 11-7.5 0 3.75 3.75 0 017.5 0zM4.501 20.118a7.5 7.5 0 0114.998 0A17.933 17.933 0 0112 21.75c-2.676 0-5.216-.584-7.499-1.632z" /></svg>""",
"cog": """<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M9.594 3.94c.09-.542.56-.94 1.11-.94h2.593c.55 0 1.02.398 1.11.94l.213 1.281c.063.374.313.686.645.87.074.04.147.083.22.127.324.196.72.257 1.075.124l1.217-.456a1.125 1.125 0 011.37.49l1.296 2.247a1.125 1.125 0 01-.26 1.431l-1.003.827c-.293.24-.438.613-.431.992a6.759 6.759 0 010 .255c-.007.378.138.75.43.99l1.005.828c.424.35.534.954.26 1.43l-1.298 2.247a1.125 1.125 0 01-1.369.491l-1.217-.456c-.355-.133-.75-.072-1.076.124a6.57 6.57 0 01-.22.128c-.331.183-.581.495-.644.869l-.213 1.28c-.09.543-.56.941-1.11.941h-2.594c-.55 0-1.02-.398-1.11-.94l-.213-1.281c-.062-.374-.312-.686-.644-.87a6.52 6.52 0 01-.22-.127c-.325-.196-.72-.257-1.076-.124l-1.217.456a1.125 1.125 0 01-1.369-.49l-1.297-2.247a1.125 1.125 0 01.26-1.431l1.004-.827c.292-.24.437-.613.43-.992a6.932 6.932 0 010-.255c.007-.378-.138-.75-.43-.99l-1.004-.828a1.125 1.125 0 01-.26-1.43l1.297-2.247a1.125 1.125 0 011.37-.491l1.216.456c.356.133.751.072 1.076-.124.072-.044.146-.087.22-.128.332-.183.582-.495.644-.869l.214-1.281z" /><path stroke-linecap="round" stroke-linejoin="round" d="M15 12a3 3 0 11-6 0 3 3 0 016 0z" /></svg>""",
"trash": """<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M14.74 9l-.346 9m-4.788 0L9.26 9m9.968-3.21c.342.052.682.107 1.022.166m-1.022-.165L18.16 19.673a2.25 2.25 0 01-2.244 2.077H8.084a2.25 2.25 0 01-2.244-2.077L4.772 5.79m14.456 0a48.108 48.108 0 00-3.478-.397m-12 .562c.34-.059.68-.114 1.022-.165m0 0a48.11 48.11 0 013.478-.397m7.5 0v-.916c0-1.18-.91-2.164-2.09-2.201a51.964 51.964 0 00-3.32 0c-1.18.037-2.09 1.022-2.09 2.201v.916m7.5 0a48.667 48.667 0 00-7.5 0" /></svg>""",
"pencil": """<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M16.862 4.487l1.687-1.688a1.875 1.875 0 112.652 2.652L10.582 16.07a4.5 4.5 0 01-1.897 1.13L6 18l.8-2.685a4.5 4.5 0 011.13-1.897l8.932-8.931zm0 0L19.5 7.125M18 14v4.75A2.25 2.25 0 0115.75 21H5.25A2.25 2.25 0 013 18.75V8.25A2.25 2.25 0 015.25 6H10" /></svg>""",
"eye": """<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M2.036 12.322a1.012 1.012 0 010-.639C3.423 7.51 7.36 4.5 12 4.5c4.638 0 8.573 3.007 9.963 7.178.07.207.07.431 0 .639C20.577 16.49 16.64 19.5 12 19.5c-4.638 0-8.573-3.007-9.963-7.178z" /><path stroke-linecap="round" stroke-linejoin="round" d="M15 12a3 3 0 11-6 0 3 3 0 016 0z" /></svg>""",
"eye-slash": """<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M3.98 8.223A10.477 10.477 0 001.934 12C3.226 16.338 7.244 19.5 12 19.5c.993 0 1.953-.138 2.863-.395M6.228 6.228A10.45 10.45 0 0112 4.5c4.756 0 8.773 3.162 10.065 7.498a10.523 10.523 0 01-4.293 5.774M6.228 6.228L3 3m3.228 3.228l3.65 3.65m7.894 7.894L21 21m-3.228-3.228l-3.65-3.65m0 0a3 3 0 10-4.243-4.243m4.242 4.242L9.88 9.88" /></svg>""",
"arrow-left": """<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M10.5 19.5L3 12m0 0l7.5-7.5M3 12h18" /></svg>""",
"arrow-right": """<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M13.5 4.5L21 12m0 0l-7.5 7.5M21 12H3" /></svg>""",
"info": """<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M11.25 11.25l.041-.02a.75.75 0 011.063.852l-.708 2.836a.75.75 0 001.063.853l.041-.021M21 12a9 9 0 11-18 0 9 9 0 0118 0zm-9-3.75h.008v.008H12V8.25z" /></svg>""",
"warning": """<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M12 9v3.75m-9.303 3.376c-.866 1.5.217 3.374 1.948 3.374h14.71c1.73 0 2.813-1.874 1.948-3.374L13.949 3.378c-.866-1.5-3.032-1.5-3.898 0L2.697 16.126zM12 15.75h.007v.008H12v-.008z" /></svg>""",
"error": """<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M12 9v3.75m9-.75a9 9 0 11-18 0 9 9 0 0118 0zm-9 3.75h.008v.008H12v-.008z" /></svg>""",
"success": """<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M9 12.75L11.25 15 15 9.75M21 12a9 9 0 11-18 0 9 9 0 0118 0z" /></svg>""",
"loading": """<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" {attrs}><circle class="opacity-25" cx="12" cy="12" r="10" stroke="currentColor" stroke-width="4"></circle><path class="opacity-75" fill="currentColor" d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"></path></svg>""",
"external-link": """<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M13.5 6H5.25A2.25 2.25 0 003 8.25v10.5A2.25 2.25 0 005.25 21h10.5A2.25 2.25 0 0018 18.75V10.5m-10.5 6L21 3m0 0h-5.25M21 3v5.25" /></svg>""",
"download": """<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M3 16.5v2.25A2.25 2.25 0 005.25 21h13.5A2.25 2.25 0 0021 18.75V16.5M16.5 12L12 16.5m0 0L7.5 12m4.5 4.5V3" /></svg>""",
"upload": """<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M3 16.5v2.25A2.25 2.25 0 005.25 21h13.5A2.25 2.25 0 0021 18.75V16.5m-13.5-9L12 3m0 0l4.5 4.5M12 3v13.5" /></svg>""",
"star": """<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M11.48 3.499a.562.562 0 011.04 0l2.125 5.111a.563.563 0 00.475.345l5.518.442c.499.04.701.663.321.988l-4.204 3.602a.563.563 0 00-.182.557l1.285 5.385a.562.562 0 01-.84.61l-4.725-2.885a.563.563 0 00-.586 0L6.982 20.54a.562.562 0 01-.84-.61l1.285-5.386a.562.562 0 00-.182-.557l-4.204-3.602a.563.563 0 01.321-.988l5.518-.442a.563.563 0 00.475-.345L11.48 3.5z" /></svg>""",
"star-filled": """<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" {attrs}><path fill-rule="evenodd" d="M10.788 3.21c.448-1.077 1.976-1.077 2.424 0l2.082 5.007 5.404.433c1.164.093 1.636 1.545.749 2.305l-4.117 3.527 1.257 5.273c.271 1.136-.964 2.033-1.96 1.425L12 18.354 7.373 21.18c-.996.608-2.231-.29-1.96-1.425l1.257-5.273-4.117-3.527c-.887-.76-.415-2.212.749-2.305l5.404-.433 2.082-5.006z" clip-rule="evenodd" /></svg>""",
"heart": """<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M21 8.25c0-2.485-2.099-4.5-4.688-4.5-1.935 0-3.597 1.126-4.312 2.733-.715-1.607-2.377-2.733-4.313-2.733C5.1 3.75 3 5.765 3 8.25c0 7.22 9 12 9 12s9-4.78 9-12z" /></svg>""",
"heart-filled": """<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" {attrs}><path d="M11.645 20.91l-.007-.003-.022-.012a15.247 15.247 0 01-.383-.218 25.18 25.18 0 01-4.244-3.17C4.688 15.36 2.25 12.174 2.25 8.25 2.25 5.322 4.714 3 7.688 3A5.5 5.5 0 0112 5.052 5.5 5.5 0 0116.313 3c2.973 0 5.437 2.322 5.437 5.25 0 3.925-2.438 7.111-4.739 9.256a25.175 25.175 0 01-4.244 3.17 15.247 15.247 0 01-.383.219l-.022.012-.007.004-.003.001a.752.752 0 01-.704 0l-.003-.001z" /></svg>""",
}
@@ -243,18 +246,18 @@ def icon(name, **kwargs):
if not svg_template:
# Return empty string for unknown icons (fail silently)
return ''
return ""
# Build attributes string
attrs_list = []
for key, value in kwargs.items():
# Convert underscore to hyphen for HTML attributes (e.g., aria_hidden -> aria-hidden)
attr_name = key.replace('_', '-')
attr_name = key.replace("_", "-")
# Escape attribute values to prevent XSS
safe_value = sanitize_attribute_value(str(value))
attrs_list.append(f'{attr_name}="{safe_value}"')
attrs_str = ' '.join(attrs_list)
attrs_str = " ".join(attrs_list)
# Substitute attributes into template
svg = svg_template.format(attrs=attrs_str)
@@ -263,7 +266,7 @@ def icon(name, **kwargs):
@register.simple_tag
def icon_class(name, size='w-5 h-5', extra_class=''):
def icon_class(name, size="w-5 h-5", extra_class=""):
"""
Render a trusted SVG icon with common class presets.
@@ -278,5 +281,5 @@ def icon_class(name, size='w-5 h-5', extra_class=''):
Returns:
Safe HTML for the icon SVG
"""
classes = f'{size} {extra_class}'.strip()
return icon(name, **{'class': classes})
classes = f"{size} {extra_class}".strip()
return icon(name, **{"class": classes})

View File

@@ -1,4 +1,3 @@
import pghistory
import pytest
from django.contrib.auth import get_user_model
@@ -7,6 +6,7 @@ from apps.parks.models import Company, Park
User = get_user_model()
@pytest.mark.django_db
class TestTrackedModel:
"""
@@ -20,10 +20,7 @@ class TestTrackedModel:
with pghistory.context(user=user.id):
park = Park.objects.create(
name="History Test Park",
description="Testing history",
operating_season="Summer",
operator=company
name="History Test Park", description="Testing history", operating_season="Summer", operator=company
)
# Verify history using the helper method from TrackedModel
@@ -50,6 +47,5 @@ class TestTrackedModel:
park.save()
assert park.get_history().count() == 2
latest = park.get_history().first() # Ordered by -pgh_created_at
latest = park.get_history().first() # Ordered by -pgh_created_at
assert latest.name == "Updated"

View File

@@ -17,9 +17,7 @@ app_name = "core"
entity_patterns = [
path("search/", EntityFuzzySearchView.as_view(), name="entity_fuzzy_search"),
path("not-found/", EntityNotFoundView.as_view(), name="entity_not_found"),
path(
"suggestions/", QuickEntitySuggestionView.as_view(), name="entity_suggestions"
),
path("suggestions/", QuickEntitySuggestionView.as_view(), name="entity_suggestions"),
]
# FSM transition endpoints

View File

@@ -6,6 +6,7 @@ from django.core.exceptions import ImproperlyConfigured
logger = logging.getLogger(__name__)
def get_direct_upload_url(user_id=None):
"""
Generates a direct upload URL for Cloudflare Images.
@@ -20,13 +21,11 @@ def get_direct_upload_url(user_id=None):
ImproperlyConfigured: If Cloudflare settings are missing.
requests.RequestException: If the Cloudflare API request fails.
"""
account_id = getattr(settings, 'CLOUDFLARE_IMAGES_ACCOUNT_ID', None)
api_token = getattr(settings, 'CLOUDFLARE_IMAGES_API_TOKEN', None)
account_id = getattr(settings, "CLOUDFLARE_IMAGES_ACCOUNT_ID", None)
api_token = getattr(settings, "CLOUDFLARE_IMAGES_API_TOKEN", None)
if not account_id or not api_token:
raise ImproperlyConfigured(
"CLOUDFLARE_IMAGES_ACCOUNT_ID and CLOUDFLARE_IMAGES_API_TOKEN must be set."
)
raise ImproperlyConfigured("CLOUDFLARE_IMAGES_ACCOUNT_ID and CLOUDFLARE_IMAGES_API_TOKEN must be set.")
url = f"https://api.cloudflare.com/client/v4/accounts/{account_id}/images/v2/direct_upload"

View File

@@ -37,6 +37,7 @@ from django.core.files.uploadedfile import UploadedFile
class FileValidationError(ValidationError):
"""Custom exception for file validation errors."""
pass
@@ -47,41 +48,49 @@ class FileValidationError(ValidationError):
# Magic number signatures for common image formats
# Format: (magic_bytes, offset, description)
IMAGE_SIGNATURES = {
'jpeg': [
(b'\xFF\xD8\xFF\xE0', 0, 'JPEG (JFIF)'),
(b'\xFF\xD8\xFF\xE1', 0, 'JPEG (EXIF)'),
(b'\xFF\xD8\xFF\xE2', 0, 'JPEG (ICC)'),
(b'\xFF\xD8\xFF\xE3', 0, 'JPEG (Samsung)'),
(b'\xFF\xD8\xFF\xE8', 0, 'JPEG (SPIFF)'),
(b'\xFF\xD8\xFF\xDB', 0, 'JPEG (Raw)'),
"jpeg": [
(b"\xff\xd8\xff\xe0", 0, "JPEG (JFIF)"),
(b"\xff\xd8\xff\xe1", 0, "JPEG (EXIF)"),
(b"\xff\xd8\xff\xe2", 0, "JPEG (ICC)"),
(b"\xff\xd8\xff\xe3", 0, "JPEG (Samsung)"),
(b"\xff\xd8\xff\xe8", 0, "JPEG (SPIFF)"),
(b"\xff\xd8\xff\xdb", 0, "JPEG (Raw)"),
],
'png': [
(b'\x89PNG\r\n\x1a\n', 0, 'PNG'),
"png": [
(b"\x89PNG\r\n\x1a\n", 0, "PNG"),
],
'gif': [
(b'GIF87a', 0, 'GIF87a'),
(b'GIF89a', 0, 'GIF89a'),
"gif": [
(b"GIF87a", 0, "GIF87a"),
(b"GIF89a", 0, "GIF89a"),
],
'webp': [
(b'RIFF', 0, 'RIFF'), # WebP starts with RIFF header
"webp": [
(b"RIFF", 0, "RIFF"), # WebP starts with RIFF header
],
'bmp': [
(b'BM', 0, 'BMP'),
"bmp": [
(b"BM", 0, "BMP"),
],
}
# All allowed MIME types
ALLOWED_IMAGE_MIME_TYPES: set[str] = frozenset({
'image/jpeg',
'image/png',
'image/gif',
'image/webp',
})
ALLOWED_IMAGE_MIME_TYPES: set[str] = frozenset(
{
"image/jpeg",
"image/png",
"image/gif",
"image/webp",
}
)
# Allowed file extensions
ALLOWED_IMAGE_EXTENSIONS: set[str] = frozenset({
'.jpg', '.jpeg', '.png', '.gif', '.webp',
})
ALLOWED_IMAGE_EXTENSIONS: set[str] = frozenset(
{
".jpg",
".jpeg",
".png",
".gif",
".webp",
}
)
# Maximum file size (10MB)
MAX_FILE_SIZE = 10 * 1024 * 1024
@@ -94,6 +103,7 @@ MIN_FILE_SIZE = 100 # 100 bytes
# File Validation Functions
# =============================================================================
def validate_image_upload(
file: UploadedFile,
max_size: int = MAX_FILE_SIZE,
@@ -133,39 +143,29 @@ def validate_image_upload(
# 2. Check file size
if file.size > max_size:
raise FileValidationError(
f"File too large. Maximum size is {max_size // (1024 * 1024)}MB"
)
raise FileValidationError(f"File too large. Maximum size is {max_size // (1024 * 1024)}MB")
if file.size < MIN_FILE_SIZE:
raise FileValidationError("File too small or empty")
# 3. Check file extension
filename = file.name or ''
filename = file.name or ""
ext = os.path.splitext(filename)[1].lower()
if ext not in allowed_extensions:
raise FileValidationError(
f"Invalid file extension '{ext}'. Allowed: {', '.join(allowed_extensions)}"
)
raise FileValidationError(f"Invalid file extension '{ext}'. Allowed: {', '.join(allowed_extensions)}")
# 4. Check Content-Type header
content_type = getattr(file, 'content_type', '')
content_type = getattr(file, "content_type", "")
if content_type and content_type not in allowed_types:
raise FileValidationError(
f"Invalid file type '{content_type}'. Allowed: {', '.join(allowed_types)}"
)
raise FileValidationError(f"Invalid file type '{content_type}'. Allowed: {', '.join(allowed_types)}")
# 5. Validate magic numbers (actual file content)
if not _validate_magic_number(file):
raise FileValidationError(
"File content doesn't match file extension. File may be corrupted or malicious."
)
raise FileValidationError("File content doesn't match file extension. File may be corrupted or malicious.")
# 6. Validate image integrity using PIL
if not _validate_image_integrity(file):
raise FileValidationError(
"Invalid or corrupted image file"
)
raise FileValidationError("Invalid or corrupted image file")
return True
@@ -191,10 +191,10 @@ def _validate_magic_number(file: UploadedFile) -> bool:
# Check against known signatures
for format_name, signatures in IMAGE_SIGNATURES.items():
for magic, offset, _description in signatures:
if len(header) >= offset + len(magic) and header[offset:offset + len(magic)] == magic:
if len(header) >= offset + len(magic) and header[offset : offset + len(magic)] == magic:
# Special handling for WebP (must also have WEBP marker)
if format_name == 'webp':
if len(header) >= 12 and header[8:12] == b'WEBP':
if format_name == "webp":
if len(header) >= 12 and header[8:12] == b"WEBP":
return True
else:
return True
@@ -233,9 +233,7 @@ def _validate_image_integrity(file: UploadedFile) -> bool:
# Prevent decompression bombs
max_dimension = 10000
if img2.width > max_dimension or img2.height > max_dimension:
raise FileValidationError(
f"Image dimensions too large. Maximum is {max_dimension}x{max_dimension}"
)
raise FileValidationError(f"Image dimensions too large. Maximum is {max_dimension}x{max_dimension}")
# Check for very small dimensions (might be suspicious)
if img2.width < 1 or img2.height < 1:
@@ -253,6 +251,7 @@ def _validate_image_integrity(file: UploadedFile) -> bool:
# Filename Sanitization
# =============================================================================
def sanitize_filename(filename: str, max_length: int = 100) -> str:
"""
Sanitize a filename to prevent directory traversal and other attacks.
@@ -281,13 +280,13 @@ def sanitize_filename(filename: str, max_length: int = 100) -> str:
# Remove or replace dangerous characters from name
# Allow alphanumeric, hyphens, underscores, dots
name = re.sub(r'[^\w\-.]', '_', name)
name = re.sub(r"[^\w\-.]", "_", name)
# Remove leading dots and underscores (hidden file prevention)
name = name.lstrip('._')
name = name.lstrip("._")
# Collapse multiple underscores
name = re.sub(r'_+', '_', name)
name = re.sub(r"_+", "_", name)
# Ensure name is not empty
if not name:
@@ -295,7 +294,7 @@ def sanitize_filename(filename: str, max_length: int = 100) -> str:
# Sanitize extension
ext = ext.lower()
ext = re.sub(r'[^\w.]', '', ext)
ext = re.sub(r"[^\w.]", "", ext)
# Combine and truncate
result = f"{name[:max_length - len(ext)]}{ext}"
@@ -303,7 +302,7 @@ def sanitize_filename(filename: str, max_length: int = 100) -> str:
return result
def generate_unique_filename(original_filename: str, prefix: str = '') -> str:
def generate_unique_filename(original_filename: str, prefix: str = "") -> str:
"""
Generate a unique filename using UUID while preserving extension.
@@ -317,7 +316,7 @@ def generate_unique_filename(original_filename: str, prefix: str = '') -> str:
ext = os.path.splitext(original_filename)[1].lower()
# Sanitize extension
ext = re.sub(r'[^\w.]', '', ext)
ext = re.sub(r"[^\w.]", "", ext)
# Generate unique filename
unique_id = uuid.uuid4().hex[:12]
@@ -332,9 +331,9 @@ def generate_unique_filename(original_filename: str, prefix: str = '') -> str:
# Rate limiting configuration
UPLOAD_RATE_LIMITS = {
'per_minute': 10,
'per_hour': 100,
'per_day': 500,
"per_minute": 10,
"per_hour": 100,
"per_day": 500,
}
@@ -351,24 +350,25 @@ def check_upload_rate_limit(user_id: int, cache_backend=None) -> tuple[bool, str
"""
if cache_backend is None:
from django.core.cache import cache
cache_backend = cache
# Check per-minute limit
minute_key = f"upload_rate:{user_id}:minute"
minute_count = cache_backend.get(minute_key, 0)
if minute_count >= UPLOAD_RATE_LIMITS['per_minute']:
if minute_count >= UPLOAD_RATE_LIMITS["per_minute"]:
return False, "Upload rate limit exceeded. Please wait a minute."
# Check per-hour limit
hour_key = f"upload_rate:{user_id}:hour"
hour_count = cache_backend.get(hour_key, 0)
if hour_count >= UPLOAD_RATE_LIMITS['per_hour']:
if hour_count >= UPLOAD_RATE_LIMITS["per_hour"]:
return False, "Hourly upload limit exceeded. Please try again later."
# Check per-day limit
day_key = f"upload_rate:{user_id}:day"
day_count = cache_backend.get(day_key, 0)
if day_count >= UPLOAD_RATE_LIMITS['per_day']:
if day_count >= UPLOAD_RATE_LIMITS["per_day"]:
return False, "Daily upload limit exceeded. Please try again tomorrow."
return True, ""
@@ -384,6 +384,7 @@ def increment_upload_count(user_id: int, cache_backend=None) -> None:
"""
if cache_backend is None:
from django.core.cache import cache
cache_backend = cache
# Increment per-minute counter (expires in 60 seconds)
@@ -412,6 +413,7 @@ def increment_upload_count(user_id: int, cache_backend=None) -> None:
# Antivirus Integration Point
# =============================================================================
def scan_file_for_malware(file: UploadedFile) -> tuple[bool, str]:
"""
Placeholder for antivirus/malware scanning integration.

View File

@@ -26,6 +26,7 @@ from typing import Any
try:
import bleach
BLEACH_AVAILABLE = True
except ImportError:
BLEACH_AVAILABLE = False
@@ -36,71 +37,135 @@ except ImportError:
# =============================================================================
# Default allowed HTML tags for user-generated content
ALLOWED_TAGS = frozenset([
# Text formatting
'p', 'br', 'hr',
'strong', 'b', 'em', 'i', 'u', 's', 'strike',
'sub', 'sup', 'small', 'mark',
# Headers
'h1', 'h2', 'h3', 'h4', 'h5', 'h6',
# Lists
'ul', 'ol', 'li',
# Links (with restrictions on attributes)
'a',
# Block elements
'blockquote', 'pre', 'code',
'div', 'span',
# Tables
'table', 'thead', 'tbody', 'tfoot', 'tr', 'th', 'td',
])
ALLOWED_TAGS = frozenset(
[
# Text formatting
"p",
"br",
"hr",
"strong",
"b",
"em",
"i",
"u",
"s",
"strike",
"sub",
"sup",
"small",
"mark",
# Headers
"h1",
"h2",
"h3",
"h4",
"h5",
"h6",
# Lists
"ul",
"ol",
"li",
# Links (with restrictions on attributes)
"a",
# Block elements
"blockquote",
"pre",
"code",
"div",
"span",
# Tables
"table",
"thead",
"tbody",
"tfoot",
"tr",
"th",
"td",
]
)
# Allowed attributes for each tag
ALLOWED_ATTRIBUTES = {
'a': ['href', 'title', 'rel', 'target'],
'img': ['src', 'alt', 'title', 'width', 'height'],
'div': ['class'],
'span': ['class'],
'p': ['class'],
'table': ['class'],
'th': ['class', 'colspan', 'rowspan'],
'td': ['class', 'colspan', 'rowspan'],
'*': ['class'], # Allow class on all elements
"a": ["href", "title", "rel", "target"],
"img": ["src", "alt", "title", "width", "height"],
"div": ["class"],
"span": ["class"],
"p": ["class"],
"table": ["class"],
"th": ["class", "colspan", "rowspan"],
"td": ["class", "colspan", "rowspan"],
"*": ["class"], # Allow class on all elements
}
# Allowed URL protocols
ALLOWED_PROTOCOLS = frozenset([
'http', 'https', 'mailto', 'tel',
])
ALLOWED_PROTOCOLS = frozenset(
[
"http",
"https",
"mailto",
"tel",
]
)
# Minimal tags for comments and short text
MINIMAL_TAGS = frozenset([
'p', 'br', 'strong', 'b', 'em', 'i', 'a',
])
MINIMAL_TAGS = frozenset(
[
"p",
"br",
"strong",
"b",
"em",
"i",
"a",
]
)
# Tags allowed in icon SVGs (for icon template rendering)
SVG_TAGS = frozenset([
'svg', 'path', 'g', 'circle', 'rect', 'line', 'polyline', 'polygon',
'ellipse', 'text', 'tspan', 'defs', 'use', 'symbol', 'clipPath',
'mask', 'linearGradient', 'radialGradient', 'stop', 'title',
])
SVG_TAGS = frozenset(
[
"svg",
"path",
"g",
"circle",
"rect",
"line",
"polyline",
"polygon",
"ellipse",
"text",
"tspan",
"defs",
"use",
"symbol",
"clipPath",
"mask",
"linearGradient",
"radialGradient",
"stop",
"title",
]
)
SVG_ATTRIBUTES = {
'svg': ['viewBox', 'width', 'height', 'fill', 'stroke', 'class',
'xmlns', 'aria-hidden', 'role'],
'path': ['d', 'fill', 'stroke', 'stroke-width', 'stroke-linecap',
'stroke-linejoin', 'class', 'fill-rule', 'clip-rule'],
'g': ['fill', 'stroke', 'transform', 'class'],
'circle': ['cx', 'cy', 'r', 'fill', 'stroke', 'class'],
'rect': ['x', 'y', 'width', 'height', 'rx', 'ry', 'fill', 'stroke', 'class'],
'line': ['x1', 'y1', 'x2', 'y2', 'stroke', 'stroke-width', 'class'],
'polyline': ['points', 'fill', 'stroke', 'class'],
'polygon': ['points', 'fill', 'stroke', 'class'],
'*': ['class', 'fill', 'stroke'],
"svg": ["viewBox", "width", "height", "fill", "stroke", "class", "xmlns", "aria-hidden", "role"],
"path": [
"d",
"fill",
"stroke",
"stroke-width",
"stroke-linecap",
"stroke-linejoin",
"class",
"fill-rule",
"clip-rule",
],
"g": ["fill", "stroke", "transform", "class"],
"circle": ["cx", "cy", "r", "fill", "stroke", "class"],
"rect": ["x", "y", "width", "height", "rx", "ry", "fill", "stroke", "class"],
"line": ["x1", "y1", "x2", "y2", "stroke", "stroke-width", "class"],
"polyline": ["points", "fill", "stroke", "class"],
"polygon": ["points", "fill", "stroke", "class"],
"*": ["class", "fill", "stroke"],
}
@@ -108,6 +173,7 @@ SVG_ATTRIBUTES = {
# Sanitization Functions
# =============================================================================
def sanitize_html(
html: str | None,
allowed_tags: frozenset | None = None,
@@ -133,7 +199,7 @@ def sanitize_html(
'<p>Hello</p>'
"""
if not html:
return ''
return ""
if not isinstance(html, str):
html = str(html)
@@ -170,7 +236,7 @@ def sanitize_minimal(html: str | None) -> str:
return sanitize_html(
html,
allowed_tags=MINIMAL_TAGS,
allowed_attributes={'a': ['href', 'title']},
allowed_attributes={"a": ["href", "title"]},
)
@@ -188,7 +254,7 @@ def sanitize_svg(svg: str | None) -> str:
Sanitized SVG string safe for inline rendering
"""
if not svg:
return ''
return ""
if not isinstance(svg, str):
svg = str(svg)
@@ -218,7 +284,7 @@ def strip_html(html: str | None) -> str:
Plain text with all HTML tags removed
"""
if not html:
return ''
return ""
if not isinstance(html, str):
html = str(html)
@@ -227,13 +293,14 @@ def strip_html(html: str | None) -> str:
return bleach.clean(html, tags=[], strip=True)
else:
# Fallback: use regex to strip tags
return re.sub(r'<[^>]+>', '', html)
return re.sub(r"<[^>]+>", "", html)
# =============================================================================
# JSON/JavaScript Context Sanitization
# =============================================================================
def sanitize_for_json(data: Any) -> str:
"""
Safely serialize data for embedding in JavaScript/JSON contexts.
@@ -251,14 +318,12 @@ def sanitize_for_json(data: Any) -> str:
'{"name": "\\u003c/script\\u003e\\u003cscript\\u003ealert(\\"xss\\")"}'
"""
# JSON encode with safe characters escaped
return json.dumps(data, ensure_ascii=False).replace(
'<', '\\u003c'
).replace(
'>', '\\u003e'
).replace(
'&', '\\u0026'
).replace(
"'", '\\u0027'
return (
json.dumps(data, ensure_ascii=False)
.replace("<", "\\u003c")
.replace(">", "\\u003e")
.replace("&", "\\u0026")
.replace("'", "\\u0027")
)
@@ -273,26 +338,21 @@ def escape_js_string(s: str | None) -> str:
Escaped string safe for JavaScript contexts
"""
if not s:
return ''
return ""
if not isinstance(s, str):
s = str(s)
# Escape backslashes first, then other special characters
return s.replace('\\', '\\\\').replace(
"'", "\\'"
).replace(
'"', '\\"'
).replace(
'\n', '\\n'
).replace(
'\r', '\\r'
).replace(
'<', '\\u003c'
).replace(
'>', '\\u003e'
).replace(
'&', '\\u0026'
return (
s.replace("\\", "\\\\")
.replace("'", "\\'")
.replace('"', '\\"')
.replace("\n", "\\n")
.replace("\r", "\\r")
.replace("<", "\\u003c")
.replace(">", "\\u003e")
.replace("&", "\\u0026")
)
@@ -300,6 +360,7 @@ def escape_js_string(s: str | None) -> str:
# URL Sanitization
# =============================================================================
def sanitize_url(url: str | None, allowed_protocols: frozenset | None = None) -> str:
"""
Sanitize a URL to prevent javascript: and other dangerous protocols.
@@ -312,7 +373,7 @@ def sanitize_url(url: str | None, allowed_protocols: frozenset | None = None) ->
Sanitized URL or empty string if unsafe
"""
if not url:
return ''
return ""
if not isinstance(url, str):
url = str(url)
@@ -320,7 +381,7 @@ def sanitize_url(url: str | None, allowed_protocols: frozenset | None = None) ->
url = url.strip()
if not url:
return ''
return ""
protocols = allowed_protocols if allowed_protocols is not None else ALLOWED_PROTOCOLS
@@ -328,12 +389,12 @@ def sanitize_url(url: str | None, allowed_protocols: frozenset | None = None) ->
url_lower = url.lower()
# Check for javascript:, data:, vbscript:, etc.
if ':' in url_lower:
protocol = url_lower.split(':')[0]
if protocol not in protocols:
if ":" in url_lower:
protocol = url_lower.split(":")[0]
if protocol not in protocols: # noqa: SIM102
# Allow relative URLs and anchor links
if not (url.startswith('/') or url.startswith('#') or url.startswith('?')):
return ''
if not (url.startswith("/") or url.startswith("#") or url.startswith("?")):
return ""
return url
@@ -342,6 +403,7 @@ def sanitize_url(url: str | None, allowed_protocols: frozenset | None = None) ->
# Attribute Sanitization
# =============================================================================
def sanitize_attribute_value(value: str | None) -> str:
"""
Sanitize a value for use in HTML attributes.
@@ -353,7 +415,7 @@ def sanitize_attribute_value(value: str | None) -> str:
Sanitized value safe for HTML attribute contexts
"""
if not value:
return ''
return ""
if not isinstance(value, str):
value = str(value)
@@ -373,10 +435,10 @@ def sanitize_class_name(name: str | None) -> str:
Sanitized class name containing only safe characters
"""
if not name:
return ''
return ""
if not isinstance(name, str):
name = str(name)
# Only allow alphanumeric, hyphens, and underscores
return re.sub(r'[^a-zA-Z0-9_-]', '', name)
return re.sub(r"[^a-zA-Z0-9_-]", "", name)

View File

@@ -16,9 +16,7 @@ logger = logging.getLogger("query_optimization")
@contextmanager
def track_queries(
operation_name: str, warn_threshold: int = 10, time_threshold: float = 1.0
):
def track_queries(operation_name: str, warn_threshold: int = 10, time_threshold: float = 1.0):
"""
Context manager to track database queries for specific operations
@@ -47,15 +45,9 @@ def track_queries(
recent_queries = connection.queries[-total_queries:]
query_details = [
{
"sql": (
query["sql"][:500] + "..."
if len(query["sql"]) > 500
else query["sql"]
),
"sql": (query["sql"][:500] + "..." if len(query["sql"]) > 500 else query["sql"]),
"time": float(query["time"]),
"duplicate_count": sum(
1 for q in recent_queries if q["sql"] == query["sql"]
),
"duplicate_count": sum(1 for q in recent_queries if q["sql"] == query["sql"]),
}
for query in recent_queries
]
@@ -65,22 +57,18 @@ def track_queries(
"query_count": total_queries,
"execution_time": execution_time,
"queries": query_details if settings.DEBUG else [],
"slow_queries": [
q for q in query_details if q["time"] > 0.1
], # Queries slower than 100ms
"slow_queries": [q for q in query_details if q["time"] > 0.1], # Queries slower than 100ms
}
# Log warnings for performance issues
if total_queries > warn_threshold or execution_time > time_threshold:
logger.warning(
f"Performance concern in {operation_name}: "
f"{total_queries} queries, {execution_time:.2f}s",
f"Performance concern in {operation_name}: " f"{total_queries} queries, {execution_time:.2f}s",
extra=performance_data,
)
else:
logger.debug(
f"Query tracking for {operation_name}: "
f"{total_queries} queries, {execution_time:.2f}s",
f"Query tracking for {operation_name}: " f"{total_queries} queries, {execution_time:.2f}s",
extra=performance_data,
)
@@ -109,9 +97,7 @@ class QueryOptimizer:
Optimize Ride queryset with proper relationships
"""
return (
queryset.select_related(
"park", "park__location", "manufacturer", "created_by"
)
queryset.select_related("park", "park__location", "manufacturer", "created_by")
.prefetch_related("reviews__user", "media_items")
.annotate(
review_count=Count("reviews"),
@@ -158,9 +144,7 @@ class QueryCache:
"""Caching utilities for expensive queries"""
@staticmethod
def cache_queryset_result(
cache_key: str, queryset_func, timeout: int = 3600, **kwargs
):
def cache_queryset_result(cache_key: str, queryset_func, timeout: int = 3600, **kwargs):
"""
Cache the result of an expensive queryset operation
@@ -202,13 +186,9 @@ class QueryCache:
# For Redis cache backends that support pattern deletion
if hasattr(cache, "delete_pattern"):
deleted_count = cache.delete_pattern(pattern)
logger.info(
f"Invalidated {deleted_count} cache keys for pattern: {pattern}"
)
logger.info(f"Invalidated {deleted_count} cache keys for pattern: {pattern}")
else:
logger.warning(
f"Cache backend does not support pattern deletion: {pattern}"
)
logger.warning(f"Cache backend does not support pattern deletion: {pattern}")
except Exception as e:
logger.error(f"Error invalidating cache pattern {pattern}: {e}")
@@ -249,10 +229,7 @@ class IndexAnalyzer:
sql_upper = sql.upper()
analysis = {
"has_where_clause": "WHERE" in sql_upper,
"has_join": any(
join in sql_upper
for join in ["JOIN", "INNER JOIN", "LEFT JOIN", "RIGHT JOIN"]
),
"has_join": any(join in sql_upper for join in ["JOIN", "INNER JOIN", "LEFT JOIN", "RIGHT JOIN"]),
"has_order_by": "ORDER BY" in sql_upper,
"has_group_by": "GROUP BY" in sql_upper,
"has_like": "LIKE" in sql_upper,
@@ -266,19 +243,13 @@ class IndexAnalyzer:
# Suggest indexes based on patterns
if analysis["has_where_clause"] and not analysis["has_join"]:
analysis["suggestions"].append(
"Consider adding indexes on WHERE clause columns"
)
analysis["suggestions"].append("Consider adding indexes on WHERE clause columns")
if analysis["has_order_by"]:
analysis["suggestions"].append(
"Consider adding indexes on ORDER BY columns"
)
analysis["suggestions"].append("Consider adding indexes on ORDER BY columns")
if analysis["has_like"] and "%" not in sql[: sql.find("LIKE") + 10]:
analysis["suggestions"].append(
"LIKE queries with leading wildcards cannot use indexes efficiently"
)
analysis["suggestions"].append("LIKE queries with leading wildcards cannot use indexes efficiently")
return analysis
@@ -294,28 +265,16 @@ class IndexAnalyzer:
# automatically)
for field in opts.fields:
if isinstance(field, models.ForeignKey):
suggestions.append(
f"Index on {field.name} (automatically created by Django)"
)
suggestions.append(f"Index on {field.name} (automatically created by Django)")
# Suggest composite indexes for common query patterns
date_fields = [
f.name
for f in opts.fields
if isinstance(f, (models.DateField, models.DateTimeField))
]
status_fields = [
f.name
for f in opts.fields
if f.name in ["status", "is_active", "is_published"]
]
date_fields = [f.name for f in opts.fields if isinstance(f, models.DateField | models.DateTimeField)]
status_fields = [f.name for f in opts.fields if f.name in ["status", "is_active", "is_published"]]
if date_fields and status_fields:
for date_field in date_fields:
for status_field in status_fields:
suggestions.append(
f"Composite index on ({status_field}, {date_field}) for filtered date queries"
)
suggestions.append(f"Composite index on ({status_field}, {date_field}) for filtered date queries")
# Suggest indexes for fields commonly used in WHERE clauses
common_filter_fields = ["slug", "name", "created_at", "updated_at"]
@@ -340,9 +299,7 @@ def log_query_performance():
return decorator
def optimize_queryset_for_serialization(
queryset: QuerySet, fields: list[str]
) -> QuerySet:
def optimize_queryset_for_serialization(queryset: QuerySet, fields: list[str]) -> QuerySet:
"""
Optimize a queryset for API serialization by only selecting needed fields
@@ -362,9 +319,7 @@ def optimize_queryset_for_serialization(
field = opts.get_field(field_name)
if isinstance(field, models.ForeignKey):
select_related_fields.append(field_name)
elif isinstance(
field, (models.ManyToManyField, models.reverse.ManyToManyRel)
):
elif isinstance(field, models.ManyToManyField | models.reverse.ManyToManyRel):
prefetch_related_fields.append(field_name)
except models.FieldDoesNotExist:
# Field might be a property or method, skip optimization
@@ -421,7 +376,6 @@ def monitor_db_performance(operation_name: str):
)
else:
logger.debug(
f"DB performance for {operation_name}: "
f"{duration:.3f}s, {total_queries} queries",
f"DB performance for {operation_name}: " f"{duration:.3f}s, {total_queries} queries",
extra=performance_data,
)

View File

@@ -4,6 +4,7 @@ Cloudflare Turnstile validation utilities.
This module provides a function to validate Turnstile tokens
on the server side before processing form submissions.
"""
import requests
from django.conf import settings
@@ -20,45 +21,41 @@ def validate_turnstile_token(token: str, ip: str = None) -> dict:
dict with 'success' boolean and optional 'error' message
"""
# Skip validation if configured (dev mode)
if getattr(settings, 'TURNSTILE_SKIP_VALIDATION', False):
return {'success': True}
if getattr(settings, "TURNSTILE_SKIP_VALIDATION", False):
return {"success": True}
secret = getattr(settings, 'TURNSTILE_SECRET', '')
secret = getattr(settings, "TURNSTILE_SECRET", "")
if not secret:
return {'success': True} # Skip if no secret configured
return {"success": True} # Skip if no secret configured
if not token:
return {'success': False, 'error': 'Captcha verification required'}
return {"success": False, "error": "Captcha verification required"}
try:
response = requests.post(
'https://challenges.cloudflare.com/turnstile/v0/siteverify',
"https://challenges.cloudflare.com/turnstile/v0/siteverify",
data={
'secret': secret,
'response': token,
'remoteip': ip,
"secret": secret,
"response": token,
"remoteip": ip,
},
timeout=10
timeout=10,
)
result = response.json()
if result.get('success'):
return {'success': True}
if result.get("success"):
return {"success": True}
else:
error_codes = result.get('error-codes', [])
return {
'success': False,
'error': 'Captcha verification failed',
'error_codes': error_codes
}
error_codes = result.get("error-codes", [])
return {"success": False, "error": "Captcha verification failed", "error_codes": error_codes}
except requests.RequestException:
# Log error but don't block user on network issues
return {'success': True} # Fail open to avoid blocking legitimate users
return {"success": True} # Fail open to avoid blocking legitimate users
def get_client_ip(request):
"""Extract client IP from request, handling proxies."""
x_forwarded_for = request.META.get('HTTP_X_FORWARDED_FOR')
x_forwarded_for = request.META.get("HTTP_X_FORWARDED_FOR")
if x_forwarded_for:
return x_forwarded_for.split(',')[0].strip()
return request.META.get('REMOTE_ADDR')
return x_forwarded_for.split(",")[0].strip()
return request.META.get("REMOTE_ADDR")

View File

@@ -5,7 +5,6 @@ This module provides base view classes that implement common patterns
such as automatic query optimization with select_related and prefetch_related.
"""
from django.db.models import QuerySet
from django.views.generic import DetailView, ListView

View File

@@ -2,7 +2,6 @@
Entity search views with fuzzy matching and authentication prompts.
"""
import contextlib
from rest_framework import status
@@ -67,9 +66,7 @@ class EntityFuzzySearchView(APIView):
try:
# Parse request data
query = request.data.get("query", "").strip()
entity_types_raw = request.data.get(
"entity_types", ["park", "ride", "company"]
)
entity_types_raw = request.data.get("entity_types", ["park", "ride", "company"])
include_suggestions = request.data.get("include_suggestions", True)
# Validate query
@@ -105,9 +102,7 @@ class EntityFuzzySearchView(APIView):
"query": query,
"matches": [match.to_dict() for match in matches],
"user_authenticated": (
request.user.is_authenticated
if hasattr(request.user, "is_authenticated")
else False
request.user.is_authenticated if hasattr(request.user, "is_authenticated") else False
),
}
@@ -211,9 +206,7 @@ class EntityNotFoundView(APIView):
"context": context,
"matches": [match.to_dict() for match in matches],
"user_authenticated": (
request.user.is_authenticated
if hasattr(request.user, "is_authenticated")
else False
request.user.is_authenticated if hasattr(request.user, "is_authenticated") else False
),
"has_matches": len(matches) > 0,
}
@@ -267,9 +260,7 @@ class QuickEntitySuggestionView(APIView):
limit = min(int(request.GET.get("limit", 5)), 10) # Cap at 10
if not query or len(query) < 2:
return Response(
{"suggestions": [], "query": query}, status=status.HTTP_200_OK
)
return Response({"suggestions": [], "query": query}, status=status.HTTP_200_OK)
# Parse entity types
entity_types = []
@@ -282,9 +273,7 @@ class QuickEntitySuggestionView(APIView):
entity_types = [EntityType.PARK, EntityType.RIDE, EntityType.COMPANY]
# Get fuzzy matches
matches, _ = entity_fuzzy_matcher.find_entity(
query=query, entity_types=entity_types, user=request.user
)
matches, _ = entity_fuzzy_matcher.find_entity(query=query, entity_types=entity_types, user=request.user)
# Format as simple suggestions
suggestions = []
@@ -313,9 +302,7 @@ class QuickEntitySuggestionView(APIView):
# Utility function for other views to use
def get_entity_suggestions(
query: str, entity_types: list[str] | None = None, user=None
):
def get_entity_suggestions(query: str, entity_types: list[str] | None = None, user=None):
"""
Utility function for other Django views to get entity suggestions.
@@ -340,8 +327,6 @@ def get_entity_suggestions(
if not parsed_types:
parsed_types = [EntityType.PARK, EntityType.RIDE, EntityType.COMPANY]
return entity_fuzzy_matcher.find_entity(
query=query, entity_types=parsed_types, user=user
)
return entity_fuzzy_matcher.find_entity(query=query, entity_types=parsed_types, user=user)
except Exception:
return [], None

View File

@@ -86,7 +86,7 @@ class MapAPIView(View):
return bounds
return None
except (ValueError, TypeError) as e:
raise ValidationError(f"Invalid bounds parameters: {e}")
raise ValidationError(f"Invalid bounds parameters: {e}") from None
def _parse_pagination(self, request: HttpRequest) -> dict[str, int]:
"""Parse pagination parameters from request."""
@@ -125,11 +125,7 @@ class MapAPIView(View):
if location_types_param:
type_strings = location_types_param.split(",")
valid_types = {lt.value for lt in LocationType}
filters.location_types = {
LocationType(t.strip())
for t in type_strings
if t.strip() in valid_types
}
filters.location_types = {LocationType(t.strip()) for t in type_strings if t.strip() in valid_types}
# Park status
park_status_param = request.GET.get("park_status")
@@ -199,7 +195,7 @@ class MapAPIView(View):
)
except (ValueError, TypeError) as e:
raise ValidationError(f"Invalid filter parameters: {e}")
raise ValidationError(f"Invalid filter parameters: {e}") from None
def _parse_zoom_level(self, request: HttpRequest) -> int:
"""Parse zoom level from request with default."""
@@ -218,9 +214,7 @@ class MapAPIView(View):
request: HttpRequest,
) -> dict[str, Any]:
"""Create paginated response with metadata."""
total_pages = (total_count + pagination["page_size"] - 1) // pagination[
"page_size"
]
total_pages = (total_count + pagination["page_size"] - 1) // pagination["page_size"]
# Build pagination URLs
base_url = request.build_absolute_uri(request.path)
@@ -278,9 +272,7 @@ class MapAPIView(View):
return JsonResponse(response_data, status=status)
def _success_response(
self, data: Any, message: str = None, metadata: dict[str, Any] = None
) -> JsonResponse:
def _success_response(self, data: Any, message: str = None, metadata: dict[str, Any] = None) -> JsonResponse:
"""Return standardized success response."""
response_data = {
"status": "success",
@@ -397,9 +389,7 @@ class MapLocationDetailView(MapAPIView):
"""
@method_decorator(cache_page(600)) # Cache for 10 minutes
def get(
self, request: HttpRequest, location_type: str, location_id: int
) -> JsonResponse:
def get(self, request: HttpRequest, location_type: str, location_id: int) -> JsonResponse:
"""Get detailed information for a specific location."""
try:
# Validate location type
@@ -422,9 +412,7 @@ class MapLocationDetailView(MapAPIView):
)
# Get location details
location = unified_map_service.get_location_details(
location_type, location_id
)
location = unified_map_service.get_location_details(location_type, location_id)
if not location:
return self._error_response(
@@ -499,9 +487,7 @@ class MapSearchView(MapAPIView):
try:
valid_types = {lt.value for lt in LocationType}
location_types = {
LocationType(t.strip())
for t in types_param.split(",")
if t.strip() in valid_types
LocationType(t.strip()) for t in types_param.split(",") if t.strip() in valid_types
}
except ValueError:
return self._error_response(
@@ -569,9 +555,7 @@ class MapBoundsView(MapAPIView):
# Parse required bounds
bounds = self._parse_bounds(request)
if not bounds:
return self._error_response(
"Bounds parameters required: north, south, east, west", 400
)
return self._error_response("Bounds parameters required: north, south, east, west", 400)
# Parse optional filters
location_types = None

View File

@@ -74,9 +74,7 @@ class UniversalMapView(MapViewMixin, TemplateView):
)
# Handle initial bounds from query parameters
if all(
param in self.request.GET for param in ["north", "south", "east", "west"]
):
if all(param in self.request.GET for param in ["north", "south", "east", "west"]):
with contextlib.suppress(ValueError, TypeError):
context["initial_bounds"] = {
"north": float(self.request.GET["north"]),
@@ -243,9 +241,7 @@ class LocationSearchView(MapViewMixin, View):
limit = min(20, max(5, int(request.GET.get("limit", "10"))))
# Perform search
results = unified_map_service.search_locations(
query=query, location_types=location_types, limit=limit
)
results = unified_map_service.search_locations(query=query, location_types=location_types, limit=limit)
return render(
request,
@@ -285,11 +281,7 @@ class MapBoundsUpdateView(MapViewMixin, View):
zoom_level = int(data.get("zoom", 10))
location_types = None
if "types" in data:
location_types = {
LocationType(t)
for t in data["types"]
if t in [lt.value for lt in LocationType]
}
location_types = {LocationType(t) for t in data["types"] if t in [lt.value for lt in LocationType]}
# Location types are used directly in the service call
@@ -324,9 +316,7 @@ class LocationDetailModalView(MapViewMixin, View):
URL: /maps/htmx/location/<type>/<id>/
"""
def get(
self, request: HttpRequest, location_type: str, location_id: int
) -> HttpResponse:
def get(self, request: HttpRequest, location_type: str, location_id: int) -> HttpResponse:
"""Return location detail modal content."""
try:
# Validate location type
@@ -338,9 +328,7 @@ class LocationDetailModalView(MapViewMixin, View):
)
# Get location details
location = unified_map_service.get_location_details(
location_type, location_id
)
location = unified_map_service.get_location_details(location_type, location_id)
if not location:
return render(
@@ -356,9 +344,7 @@ class LocationDetailModalView(MapViewMixin, View):
)
except Exception as e:
return render(
request, "maps/partials/location_modal.html", {"error": str(e)}
)
return render(request, "maps/partials/location_modal.html", {"error": str(e)})
class LocationListView(MapViewMixin, TemplateView):
@@ -392,9 +378,7 @@ class LocationListView(MapViewMixin, TemplateView):
)
# Get locations without clustering
map_response = unified_map_service.get_map_data(
filters=filters, cluster=False, use_cache=True
)
map_response = unified_map_service.get_map_data(filters=filters, cluster=False, use_cache=True)
# Paginate results
paginator = Paginator(map_response.locations, self.paginate_by)

View File

@@ -83,13 +83,15 @@ class PerformanceDashboardView(TemplateView):
try:
client = cache._cache.get_client()
info = client.info()
cache_stats.update({
"connected_clients": info.get("connected_clients"),
"used_memory_human": info.get("used_memory_human"),
"keyspace_hits": info.get("keyspace_hits", 0),
"keyspace_misses": info.get("keyspace_misses", 0),
"total_commands": info.get("total_commands_processed"),
})
cache_stats.update(
{
"connected_clients": info.get("connected_clients"),
"used_memory_human": info.get("used_memory_human"),
"keyspace_hits": info.get("keyspace_hits", 0),
"keyspace_misses": info.get("keyspace_misses", 0),
"total_commands": info.get("total_commands_processed"),
}
)
# Calculate hit rate
hits = info.get("keyspace_hits", 0)
@@ -127,8 +129,7 @@ class PerformanceDashboardView(TemplateView):
# Get connection count (PostgreSQL specific)
try:
cursor.execute(
"SELECT count(*) FROM pg_stat_activity WHERE datname = %s;",
[db_settings.get("NAME")]
"SELECT count(*) FROM pg_stat_activity WHERE datname = %s;", [db_settings.get("NAME")]
)
stats["active_connections"] = cursor.fetchone()[0]
except Exception:
@@ -244,16 +245,18 @@ class CacheStatsAPIView(View):
client = cache._cache.get_client()
info = client.info()
cache_info.update({
"used_memory": info.get("used_memory_human"),
"connected_clients": info.get("connected_clients"),
"keyspace_hits": info.get("keyspace_hits", 0),
"keyspace_misses": info.get("keyspace_misses", 0),
"expired_keys": info.get("expired_keys", 0),
"evicted_keys": info.get("evicted_keys", 0),
"total_connections_received": info.get("total_connections_received"),
"total_commands_processed": info.get("total_commands_processed"),
})
cache_info.update(
{
"used_memory": info.get("used_memory_human"),
"connected_clients": info.get("connected_clients"),
"keyspace_hits": info.get("keyspace_hits", 0),
"keyspace_misses": info.get("keyspace_misses", 0),
"expired_keys": info.get("expired_keys", 0),
"evicted_keys": info.get("evicted_keys", 0),
"total_connections_received": info.get("total_connections_received"),
"total_commands_processed": info.get("total_commands_processed"),
}
)
# Calculate metrics
hits = info.get("keyspace_hits", 0)

Some files were not shown because too many files have changed in this diff Show More