mirror of https://github.com/pacnpal/thrillwiki_django_no_react.git
synced 2026-01-01 23:47:04 -05:00
feat: Implement initial schema and add various API, service, and management command enhancements across the application.
@@ -9,4 +9,4 @@ system status, and other foundational features.
 from .choices import core_choices
 
 # Ensure choices are registered on app startup
-__all__ = ['core_choices']
+__all__ = ["core_choices"]
@@ -23,9 +23,7 @@ from .models import SlugHistory
 
 
 @admin.register(SlugHistory)
-class SlugHistoryAdmin(
-    ReadOnlyAdminMixin, QueryOptimizationMixin, ExportActionMixin, BaseModelAdmin
-):
+class SlugHistoryAdmin(ReadOnlyAdminMixin, QueryOptimizationMixin, ExportActionMixin, BaseModelAdmin):
     """
     Admin interface for SlugHistory management.
@@ -221,13 +221,9 @@ class ExportActionMixin:
             writer.writerow(row)
 
         response = HttpResponse(output.getvalue(), content_type="text/csv")
-        response["Content-Disposition"] = (
-            f'attachment; filename="{self.get_export_filename("csv")}"'
-        )
+        response["Content-Disposition"] = f'attachment; filename="{self.get_export_filename("csv")}"'
 
-        self.message_user(
-            request, f"Successfully exported {queryset.count()} records to CSV."
-        )
+        self.message_user(request, f"Successfully exported {queryset.count()} records to CSV.")
         return response
 
     @admin.action(description="Export selected to JSON")
@@ -250,13 +246,9 @@ class ExportActionMixin:
             json.dumps(data, indent=2, cls=DjangoJSONEncoder),
             content_type="application/json",
         )
-        response["Content-Disposition"] = (
-            f'attachment; filename="{self.get_export_filename("json")}"'
-        )
+        response["Content-Disposition"] = f'attachment; filename="{self.get_export_filename("json")}"'
 
-        self.message_user(
-            request, f"Successfully exported {queryset.count()} records to JSON."
-        )
+        self.message_user(request, f"Successfully exported {queryset.count()} records to JSON.")
         return response
 
     def get_actions(self, request):
@@ -10,9 +10,7 @@ from django.utils import timezone
 
 @pghistory.track()
 class PageView(models.Model):
-    content_type = models.ForeignKey(
-        ContentType, on_delete=models.CASCADE, related_name="page_views"
-    )
+    content_type = models.ForeignKey(ContentType, on_delete=models.CASCADE, related_name="page_views")
     object_id = models.PositiveIntegerField()
     content_object = GenericForeignKey("content_type", "object_id")
@@ -64,9 +62,7 @@ class PageView(models.Model):
         return model_class.objects.none()
 
     @classmethod
-    def get_views_growth(
-        cls, content_type, object_id, current_period_hours, previous_period_hours
-    ):
+    def get_views_growth(cls, content_type, object_id, current_period_hours, previous_period_hours):
         """Get view growth statistics between two time periods.
 
         Args:
@@ -102,9 +98,7 @@ class PageView(models.Model):
         if previous_views == 0:
             growth_percentage = current_views * 100 if current_views > 0 else 0
         else:
-            growth_percentage = (
-                (current_views - previous_views) / previous_views
-            ) * 100
+            growth_percentage = ((current_views - previous_views) / previous_views) * 100
 
         return current_views, previous_views, growth_percentage
@@ -121,6 +115,4 @@ class PageView(models.Model):
             int: Total view count
         """
        cutoff = timezone.now() - timedelta(hours=hours)
-        return cls.objects.filter(
-            content_type=content_type, object_id=object_id, timestamp__gte=cutoff
-        ).count()
+        return cls.objects.filter(content_type=content_type, object_id=object_id, timestamp__gte=cutoff).count()
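Note: the growth computation reformatted above is easiest to read in isolation; a minimal sketch of the same arithmetic (a standalone helper, not part of the commit):

    def growth_percentage(current_views: int, previous_views: int) -> float:
        # Mirrors PageView.get_views_growth: with no previous traffic, any
        # current traffic is reported as current_views * 100 percent growth.
        if previous_views == 0:
            return float(current_views * 100) if current_views > 0 else 0.0
        return ((current_views - previous_views) / previous_views) * 100

    assert growth_percentage(150, 100) == 50.0  # 50% more views than the previous period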
@@ -31,9 +31,7 @@ from ..logging import get_logger, log_exception
 logger = get_logger(__name__)
 
 
-def custom_exception_handler(
-    exc: Exception, context: dict[str, Any]
-) -> Response | None:
+def custom_exception_handler(exc: Exception, context: dict[str, Any]) -> Response | None:
     """
     Custom exception handler for DRF that provides standardized error responses.
@@ -172,9 +170,7 @@ def custom_exception_handler(
             request=request,
         )
 
-        response = Response(
-            custom_response_data, status=status.HTTP_500_INTERNAL_SERVER_ERROR
-        )
+        response = Response(custom_response_data, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
 
     return response
@@ -234,10 +230,7 @@ def _format_django_validation_errors(
     """Format Django ValidationError for API response."""
     if hasattr(exc, "error_dict"):
         # Field-specific errors
-        return {
-            field: [str(error) for error in errors]
-            for field, errors in exc.error_dict.items()
-        }
+        return {field: [str(error) for error in errors] for field, errors in exc.error_dict.items()}
     elif hasattr(exc, "error_list"):
         # Non-field errors
         return {"non_field_errors": [str(error) for error in exc.error_list]}
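Note: a handler with this signature only takes effect once DRF is pointed at it; a minimal settings sketch (the EXCEPTION_HANDLER key is standard DRF, the dotted module path below is an assumption):

    # settings.py
    REST_FRAMEWORK = {
        # Dotted path to the handler above; adjust to the project's real module.
        "EXCEPTION_HANDLER": "core.api.exception_handlers.custom_exception_handler",
    }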
@@ -103,15 +103,11 @@ class ApiMixin:
     # These will raise if not implemented; they also inform static analyzers about their existence.
     def paginate_queryset(self, queryset):
         """Override / implement in subclass or provided base if pagination is needed."""
-        raise NotImplementedError(
-            "Subclasses must implement paginate_queryset to enable pagination"
-        )
+        raise NotImplementedError("Subclasses must implement paginate_queryset to enable pagination")
 
     def get_paginated_response(self, data):
         """Override / implement in subclass or provided base to return paginated responses."""
-        raise NotImplementedError(
-            "Subclasses must implement get_paginated_response to enable pagination"
-        )
+        raise NotImplementedError("Subclasses must implement get_paginated_response to enable pagination")
 
     def get_object(self):
         """Default placeholder; subclasses should implement this."""
@@ -168,9 +164,7 @@ class UpdateApiMixin(ApiMixin):
     def update(self, _request: Request, *_args, **_kwargs) -> Response:
         """Handle PUT/PATCH requests for updating resources."""
         instance = self.get_object()
-        serializer = self.get_input_serializer(
-            data=_request.data, partial=_kwargs.get("partial", False)
-        )
+        serializer = self.get_input_serializer(data=_request.data, partial=_kwargs.get("partial", False))
         serializer.is_valid(raise_exception=True)
 
         # Update the object using the service layer
@@ -229,9 +223,7 @@ class ListApiMixin(ApiMixin):
         Override this method to use selector patterns.
         Should call selector functions, not access model managers directly.
         """
-        raise NotImplementedError(
-            "Subclasses must implement get_queryset using selectors"
-        )
+        raise NotImplementedError("Subclasses must implement get_queryset using selectors")
 
     def get_output_serializer(self, *args, **kwargs):
         """Get the output serializer for response."""
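Note: the NotImplementedError stubs above define the contract a concrete view must fill in; a sketch of a conforming subclass (view, serializer, and selector names are hypothetical):

    class ParkListApi(ListApiMixin):
        def get_queryset(self):
            # Delegate to a selector function rather than a model manager,
            # as the docstring above requires.
            return park_list_published()  # hypothetical selector

        def get_output_serializer(self, *args, **kwargs):
            return ParkListOutputSerializer(*args, **kwargs)  # hypothetical serializer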
@@ -28,6 +28,7 @@ from django.core.checks import Error, Tags, Warning, register
 # Secret Key Validation
 # =============================================================================
 
+
 @register(Tags.security)
 def check_secret_key(app_configs, **kwargs):
     """
@@ -38,30 +39,30 @@ def check_secret_key(app_configs, **kwargs):
     - Key has sufficient entropy (length and character variety)
     """
     errors = []
-    secret_key = getattr(settings, 'SECRET_KEY', '')
+    secret_key = getattr(settings, "SECRET_KEY", "")
 
     # Check for empty or missing key
     if not secret_key:
         errors.append(
             Error(
-                'SECRET_KEY is not set.',
-                hint='Set a strong, random SECRET_KEY in your environment.',
-                id='security.E001',
+                "SECRET_KEY is not set.",
+                hint="Set a strong, random SECRET_KEY in your environment.",
+                id="security.E001",
             )
         )
         return errors
 
     # Check for known insecure default values
     insecure_defaults = [
-        'django-insecure',
-        'your-secret-key',
-        'change-me',
-        'changeme',
-        'secret',
-        'xxx',
-        'test',
-        'development',
-        'dev-key',
+        "django-insecure",
+        "your-secret-key",
+        "change-me",
+        "changeme",
+        "secret",
+        "xxx",
+        "test",
+        "development",
+        "dev-key",
     ]
 
     key_lower = secret_key.lower()
@@ -71,7 +72,7 @@ def check_secret_key(app_configs, **kwargs):
                 Error(
                     f'SECRET_KEY appears to contain an insecure default value: "{default}"',
                     hint='Generate a new secret key using: python -c "from django.core.management.utils import get_random_secret_key; print(get_random_secret_key())"',
-                    id='security.E002',
+                    id="security.E002",
                 )
             )
             break
@@ -80,25 +81,25 @@ def check_secret_key(app_configs, **kwargs):
     if len(secret_key) < 50:
         errors.append(
             Warning(
-                f'SECRET_KEY is only {len(secret_key)} characters long.',
-                hint='A secret key should be at least 50 characters for proper security.',
-                id='security.W001',
+                f"SECRET_KEY is only {len(secret_key)} characters long.",
+                hint="A secret key should be at least 50 characters for proper security.",
+                id="security.W001",
             )
         )
 
     # Check for sufficient character variety
-    has_upper = bool(re.search(r'[A-Z]', secret_key))
-    has_lower = bool(re.search(r'[a-z]', secret_key))
-    has_digit = bool(re.search(r'[0-9]', secret_key))
+    has_upper = bool(re.search(r"[A-Z]", secret_key))
+    has_lower = bool(re.search(r"[a-z]", secret_key))
+    has_digit = bool(re.search(r"[0-9]", secret_key))
     has_special = bool(re.search(r'[!@#$%^&*()_+\-=\[\]{};\':"\\|,.<>\/?]', secret_key))
 
     char_types = sum([has_upper, has_lower, has_digit, has_special])
     if char_types < 3:
         errors.append(
             Warning(
-                'SECRET_KEY lacks character variety.',
-                hint='A good secret key should contain uppercase, lowercase, digits, and special characters.',
-                id='security.W002',
+                "SECRET_KEY lacks character variety.",
+                hint="A good secret key should contain uppercase, lowercase, digits, and special characters.",
+                id="security.W002",
             )
         )
@@ -109,6 +110,7 @@ def check_secret_key(app_configs, **kwargs):
 # Debug Mode Check
 # =============================================================================
 
+
 @register(Tags.security)
 def check_debug_mode(app_configs, **kwargs):
     """
@@ -117,27 +119,27 @@ def check_debug_mode(app_configs, **kwargs):
     errors = []
 
     # Check if we're in a production-like environment
-    env = os.environ.get('DJANGO_SETTINGS_MODULE', '')
-    is_production = 'production' in env.lower() or 'prod' in env.lower()
+    env = os.environ.get("DJANGO_SETTINGS_MODULE", "")
+    is_production = "production" in env.lower() or "prod" in env.lower()
 
     if is_production and settings.DEBUG:
         errors.append(
             Error(
-                'DEBUG is True in what appears to be a production environment.',
-                hint='Set DEBUG=False in production settings.',
-                id='security.E003',
+                "DEBUG is True in what appears to be a production environment.",
+                hint="Set DEBUG=False in production settings.",
+                id="security.E003",
             )
         )
 
     # Also check if DEBUG is True with ALLOWED_HOSTS configured
     # (indicates possible production deployment with debug on)
-    if settings.DEBUG and settings.ALLOWED_HOSTS and '*' not in settings.ALLOWED_HOSTS:
-        if len(settings.ALLOWED_HOSTS) > 0 and 'localhost' not in settings.ALLOWED_HOSTS[0]:
+    if settings.DEBUG and settings.ALLOWED_HOSTS and "*" not in settings.ALLOWED_HOSTS:  # noqa: SIM102
+        if len(settings.ALLOWED_HOSTS) > 0 and "localhost" not in settings.ALLOWED_HOSTS[0]:
             errors.append(
                 Warning(
-                    'DEBUG is True but ALLOWED_HOSTS contains non-localhost values.',
-                    hint='This may indicate DEBUG is accidentally enabled in a deployed environment.',
-                    id='security.W003',
+                    "DEBUG is True but ALLOWED_HOSTS contains non-localhost values.",
+                    hint="This may indicate DEBUG is accidentally enabled in a deployed environment.",
+                    id="security.W003",
                 )
             )
@@ -148,30 +150,31 @@ def check_debug_mode(app_configs, **kwargs):
 # ALLOWED_HOSTS Check
 # =============================================================================
 
+
 @register(Tags.security)
 def check_allowed_hosts(app_configs, **kwargs):
     """
     Check ALLOWED_HOSTS configuration.
     """
     errors = []
-    allowed_hosts = getattr(settings, 'ALLOWED_HOSTS', [])
+    allowed_hosts = getattr(settings, "ALLOWED_HOSTS", [])
 
     if not settings.DEBUG:
         # In non-debug mode, ALLOWED_HOSTS must be set
         if not allowed_hosts:
             errors.append(
                 Error(
-                    'ALLOWED_HOSTS is empty but DEBUG is False.',
-                    hint='Set ALLOWED_HOSTS to a list of allowed hostnames.',
-                    id='security.E004',
+                    "ALLOWED_HOSTS is empty but DEBUG is False.",
+                    hint="Set ALLOWED_HOSTS to a list of allowed hostnames.",
+                    id="security.E004",
                 )
            )
-        elif '*' in allowed_hosts:
+        elif "*" in allowed_hosts:
             errors.append(
                 Error(
                     'ALLOWED_HOSTS contains "*" which allows all hosts.',
-                    hint='Specify explicit hostnames instead of wildcards.',
-                    id='security.E005',
+                    hint="Specify explicit hostnames instead of wildcards.",
+                    id="security.E005",
                 )
             )
@@ -182,6 +185,7 @@ def check_allowed_hosts(app_configs, **kwargs):
 # Security Headers Check
 # =============================================================================
 
+
 @register(Tags.security)
 def check_security_headers(app_configs, **kwargs):
     """
@@ -190,34 +194,34 @@ def check_security_headers(app_configs, **kwargs):
     errors = []
 
     # Check X-Frame-Options
-    x_frame_options = getattr(settings, 'X_FRAME_OPTIONS', None)
-    if x_frame_options not in ('DENY', 'SAMEORIGIN'):
+    x_frame_options = getattr(settings, "X_FRAME_OPTIONS", None)
+    if x_frame_options not in ("DENY", "SAMEORIGIN"):
         errors.append(
             Warning(
                 f'X_FRAME_OPTIONS is set to "{x_frame_options}" or not set.',
                 hint='Set X_FRAME_OPTIONS to "DENY" or "SAMEORIGIN" to prevent clickjacking.',
-                id='security.W004',
+                id="security.W004",
             )
         )
 
     # Check content type sniffing protection
-    if not getattr(settings, 'SECURE_CONTENT_TYPE_NOSNIFF', False):
+    if not getattr(settings, "SECURE_CONTENT_TYPE_NOSNIFF", False):
         errors.append(
             Warning(
-                'SECURE_CONTENT_TYPE_NOSNIFF is not enabled.',
-                hint='Set SECURE_CONTENT_TYPE_NOSNIFF = True to prevent MIME type sniffing.',
-                id='security.W005',
+                "SECURE_CONTENT_TYPE_NOSNIFF is not enabled.",
+                hint="Set SECURE_CONTENT_TYPE_NOSNIFF = True to prevent MIME type sniffing.",
+                id="security.W005",
             )
         )
 
     # Check referrer policy
-    referrer_policy = getattr(settings, 'SECURE_REFERRER_POLICY', None)
+    referrer_policy = getattr(settings, "SECURE_REFERRER_POLICY", None)
     if not referrer_policy:
         errors.append(
             Warning(
-                'SECURE_REFERRER_POLICY is not set.',
-                hint='Set SECURE_REFERRER_POLICY to control referrer header behavior.',
-                id='security.W006',
+                "SECURE_REFERRER_POLICY is not set.",
+                hint="Set SECURE_REFERRER_POLICY to control referrer header behavior.",
+                id="security.W006",
             )
         )
@@ -228,6 +232,7 @@ def check_security_headers(app_configs, **kwargs):
 # HTTPS Settings Check
 # =============================================================================
 
+
 @register(Tags.security)
 def check_https_settings(app_configs, **kwargs):
     """
@@ -240,32 +245,32 @@ def check_https_settings(app_configs, **kwargs):
         return errors
 
     # Check SSL redirect
-    if not getattr(settings, 'SECURE_SSL_REDIRECT', False):
+    if not getattr(settings, "SECURE_SSL_REDIRECT", False):
         errors.append(
             Warning(
-                'SECURE_SSL_REDIRECT is not enabled.',
-                hint='Set SECURE_SSL_REDIRECT = True to redirect HTTP to HTTPS.',
-                id='security.W007',
+                "SECURE_SSL_REDIRECT is not enabled.",
+                hint="Set SECURE_SSL_REDIRECT = True to redirect HTTP to HTTPS.",
+                id="security.W007",
             )
         )
 
     # Check HSTS settings
-    hsts_seconds = getattr(settings, 'SECURE_HSTS_SECONDS', 0)
+    hsts_seconds = getattr(settings, "SECURE_HSTS_SECONDS", 0)
     if hsts_seconds < 31536000:  # Less than 1 year
         errors.append(
             Warning(
-                f'SECURE_HSTS_SECONDS is {hsts_seconds} (less than 1 year).',
-                hint='Set SECURE_HSTS_SECONDS to at least 31536000 (1 year) for HSTS preload eligibility.',
-                id='security.W008',
+                f"SECURE_HSTS_SECONDS is {hsts_seconds} (less than 1 year).",
+                hint="Set SECURE_HSTS_SECONDS to at least 31536000 (1 year) for HSTS preload eligibility.",
+                id="security.W008",
             )
         )
 
-    if not getattr(settings, 'SECURE_HSTS_INCLUDE_SUBDOMAINS', False):
+    if not getattr(settings, "SECURE_HSTS_INCLUDE_SUBDOMAINS", False):
         errors.append(
             Warning(
-                'SECURE_HSTS_INCLUDE_SUBDOMAINS is not enabled.',
-                hint='Set SECURE_HSTS_INCLUDE_SUBDOMAINS = True to include all subdomains in HSTS.',
-                id='security.W009',
+                "SECURE_HSTS_INCLUDE_SUBDOMAINS is not enabled.",
+                hint="Set SECURE_HSTS_INCLUDE_SUBDOMAINS = True to include all subdomains in HSTS.",
+                id="security.W009",
             )
         )
@@ -276,6 +281,7 @@ def check_https_settings(app_configs, **kwargs):
 # Cookie Security Check
 # =============================================================================
 
+
 @register(Tags.security)
 def check_cookie_security(app_configs, **kwargs):
     """
@@ -288,42 +294,42 @@ def check_cookie_security(app_configs, **kwargs):
         return errors
 
     # Check session cookie security
-    if not getattr(settings, 'SESSION_COOKIE_SECURE', False):
+    if not getattr(settings, "SESSION_COOKIE_SECURE", False):
         errors.append(
             Warning(
-                'SESSION_COOKIE_SECURE is not enabled.',
-                hint='Set SESSION_COOKIE_SECURE = True to only send session cookies over HTTPS.',
-                id='security.W010',
+                "SESSION_COOKIE_SECURE is not enabled.",
+                hint="Set SESSION_COOKIE_SECURE = True to only send session cookies over HTTPS.",
+                id="security.W010",
            )
        )
 
-    if not getattr(settings, 'SESSION_COOKIE_HTTPONLY', True):
+    if not getattr(settings, "SESSION_COOKIE_HTTPONLY", True):
         errors.append(
             Warning(
-                'SESSION_COOKIE_HTTPONLY is disabled.',
-                hint='Set SESSION_COOKIE_HTTPONLY = True to prevent JavaScript access to session cookies.',
-                id='security.W011',
+                "SESSION_COOKIE_HTTPONLY is disabled.",
+                hint="Set SESSION_COOKIE_HTTPONLY = True to prevent JavaScript access to session cookies.",
+                id="security.W011",
             )
         )
 
     # Check CSRF cookie security
-    if not getattr(settings, 'CSRF_COOKIE_SECURE', False):
+    if not getattr(settings, "CSRF_COOKIE_SECURE", False):
         errors.append(
             Warning(
-                'CSRF_COOKIE_SECURE is not enabled.',
-                hint='Set CSRF_COOKIE_SECURE = True to only send CSRF cookies over HTTPS.',
-                id='security.W012',
+                "CSRF_COOKIE_SECURE is not enabled.",
+                hint="Set CSRF_COOKIE_SECURE = True to only send CSRF cookies over HTTPS.",
+                id="security.W012",
             )
        )
 
     # Check SameSite attributes
-    session_samesite = getattr(settings, 'SESSION_COOKIE_SAMESITE', 'Lax')
-    if session_samesite not in ('Strict', 'Lax'):
+    session_samesite = getattr(settings, "SESSION_COOKIE_SAMESITE", "Lax")
+    if session_samesite not in ("Strict", "Lax"):
         errors.append(
             Warning(
                 f'SESSION_COOKIE_SAMESITE is set to "{session_samesite}".',
                 hint='Set SESSION_COOKIE_SAMESITE to "Strict" or "Lax" for CSRF protection.',
-                id='security.W013',
+                id="security.W013",
             )
         )
@@ -334,6 +340,7 @@ def check_cookie_security(app_configs, **kwargs):
 # Database Security Check
 # =============================================================================
 
+
 @register(Tags.security)
 def check_database_security(app_configs, **kwargs):
     """
@@ -345,27 +352,27 @@ def check_database_security(app_configs, **kwargs):
     if settings.DEBUG:
         return errors
 
-    databases = getattr(settings, 'DATABASES', {})
-    default_db = databases.get('default', {})
+    databases = getattr(settings, "DATABASES", {})
+    default_db = databases.get("default", {})
 
     # Check for empty password
-    if not default_db.get('PASSWORD') and default_db.get('ENGINE', '').endswith('postgresql'):
+    if not default_db.get("PASSWORD") and default_db.get("ENGINE", "").endswith("postgresql"):
         errors.append(
             Warning(
-                'Database password is empty.',
-                hint='Set a strong password for database authentication.',
-                id='security.W014',
+                "Database password is empty.",
+                hint="Set a strong password for database authentication.",
+                id="security.W014",
             )
        )
 
     # Check for SSL mode in PostgreSQL
-    options = default_db.get('OPTIONS', {})
-    if 'sslmode' not in str(options) and default_db.get('ENGINE', '').endswith('postgresql'):
+    options = default_db.get("OPTIONS", {})
+    if "sslmode" not in str(options) and default_db.get("ENGINE", "").endswith("postgresql"):
         errors.append(
             Warning(
-                'Database SSL mode is not explicitly configured.',
-                hint='Consider setting sslmode in database OPTIONS for encrypted connections.',
-                id='security.W015',
+                "Database SSL mode is not explicitly configured.",
+                hint="Consider setting sslmode in database OPTIONS for encrypted connections.",
+                id="security.W015",
            )
        )
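Note: every check in this file is registered with @register(Tags.security), so Django's system check framework picks them up automatically; they can also be run on their own with the standard check command:

    # Run only the security-tagged checks (standard Django check command):
    python manage.py check --tag security
    # In CI, fail the build on warnings (W001-W015) as well:
    python manage.py check --tag security --fail-level WARNING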
@@ -19,14 +19,14 @@ from .serializers import RichChoiceOptionSerializer, RichChoiceSerializer
 from .utils import get_choice_display, validate_choice_value
 
 __all__ = [
-    'RichChoice',
-    'ChoiceCategory',
-    'ChoiceGroup',
-    'ChoiceRegistry',
-    'register_choices',
-    'RichChoiceField',
-    'RichChoiceSerializer',
-    'RichChoiceOptionSerializer',
-    'validate_choice_value',
-    'get_choice_display',
+    "RichChoice",
+    "ChoiceCategory",
+    "ChoiceGroup",
+    "ChoiceRegistry",
+    "register_choices",
+    "RichChoiceField",
+    "RichChoiceSerializer",
+    "RichChoiceOptionSerializer",
+    "validate_choice_value",
+    "get_choice_display",
 ]
@@ -11,6 +11,7 @@ from typing import Any
 
 class ChoiceCategory(Enum):
     """Categories for organizing choice types"""
+
     STATUS = "status"
     TYPE = "type"
     CLASSIFICATION = "classification"
@@ -42,6 +43,7 @@ class RichChoice:
         deprecated: Whether this choice is deprecated and should not be used for new entries
         category: Category for organizing related choices
     """
+
     value: str
     label: str
     description: str = ""
@@ -59,40 +61,38 @@ class RichChoice:
     @property
     def color(self) -> str | None:
         """Get the color from metadata if available"""
-        return self.metadata.get('color')
+        return self.metadata.get("color")
 
     @property
     def icon(self) -> str | None:
         """Get the icon from metadata if available"""
-        return self.metadata.get('icon')
+        return self.metadata.get("icon")
 
     @property
     def css_class(self) -> str | None:
         """Get the CSS class from metadata if available"""
-        return self.metadata.get('css_class')
+        return self.metadata.get("css_class")
 
     @property
     def sort_order(self) -> int:
         """Get the sort order from metadata, defaulting to 0"""
-        return self.metadata.get('sort_order', 0)
-
+        return self.metadata.get("sort_order", 0)
 
     def to_dict(self) -> dict[str, Any]:
         """Convert to dictionary representation for API serialization"""
         return {
-            'value': self.value,
-            'label': self.label,
-            'description': self.description,
-            'metadata': self.metadata,
-            'deprecated': self.deprecated,
-            'category': self.category.value,
-            'color': self.color,
-            'icon': self.icon,
-            'css_class': self.css_class,
-            'sort_order': self.sort_order,
+            "value": self.value,
+            "label": self.label,
+            "description": self.description,
+            "metadata": self.metadata,
+            "deprecated": self.deprecated,
+            "category": self.category.value,
+            "color": self.color,
+            "icon": self.icon,
+            "css_class": self.css_class,
+            "sort_order": self.sort_order,
         }
-
 
     def __str__(self) -> str:
         return self.label
@@ -108,6 +108,7 @@ class ChoiceGroup:
     This allows for organizing choices into logical groups with
     common properties and behaviors.
     """
+
     name: str
     choices: list[RichChoice]
     description: str = ""
@@ -147,8 +148,8 @@ class ChoiceGroup:
     def to_dict(self) -> dict[str, Any]:
         """Convert to dictionary representation for API serialization"""
         return {
-            'name': self.name,
-            'description': self.description,
-            'metadata': self.metadata,
-            'choices': [choice.to_dict() for choice in self.choices]
+            "name": self.name,
+            "description": self.description,
+            "metadata": self.metadata,
+            "choices": [choice.to_dict() for choice in self.choices],
         }
@@ -15,26 +15,26 @@ HEALTH_STATUSES = [
         label="Healthy",
         description="System is operating normally with no issues detected",
         metadata={
-            'color': 'green',
-            'icon': 'check-circle',
-            'css_class': 'bg-green-100 text-green-800',
-            'sort_order': 1,
-            'http_status': 200
+            "color": "green",
+            "icon": "check-circle",
+            "css_class": "bg-green-100 text-green-800",
+            "sort_order": 1,
+            "http_status": 200,
         },
-        category=ChoiceCategory.STATUS
+        category=ChoiceCategory.STATUS,
     ),
     RichChoice(
         value="unhealthy",
         label="Unhealthy",
         description="System has detected issues that may affect functionality",
         metadata={
-            'color': 'red',
-            'icon': 'x-circle',
-            'css_class': 'bg-red-100 text-red-800',
-            'sort_order': 2,
-            'http_status': 503
+            "color": "red",
+            "icon": "x-circle",
+            "css_class": "bg-red-100 text-red-800",
+            "sort_order": 2,
+            "http_status": 503,
         },
-        category=ChoiceCategory.STATUS
+        category=ChoiceCategory.STATUS,
     ),
 ]
@@ -45,26 +45,26 @@ SIMPLE_HEALTH_STATUSES = [
         label="OK",
         description="Basic health check passed",
         metadata={
-            'color': 'green',
-            'icon': 'check',
-            'css_class': 'bg-green-100 text-green-800',
-            'sort_order': 1,
-            'http_status': 200
+            "color": "green",
+            "icon": "check",
+            "css_class": "bg-green-100 text-green-800",
+            "sort_order": 1,
+            "http_status": 200,
         },
-        category=ChoiceCategory.STATUS
+        category=ChoiceCategory.STATUS,
     ),
     RichChoice(
         value="error",
         label="Error",
         description="Basic health check failed",
         metadata={
-            'color': 'red',
-            'icon': 'x',
-            'css_class': 'bg-red-100 text-red-800',
-            'sort_order': 2,
-            'http_status': 500
+            "color": "red",
+            "icon": "x",
+            "css_class": "bg-red-100 text-red-800",
+            "sort_order": 2,
+            "http_status": 500,
         },
-        category=ChoiceCategory.STATUS
+        category=ChoiceCategory.STATUS,
     ),
 ]
@@ -75,52 +75,52 @@ ENTITY_TYPES = [
         label="Park",
         description="Theme parks and amusement parks",
         metadata={
-            'color': 'green',
-            'icon': 'map-pin',
-            'css_class': 'bg-green-100 text-green-800',
-            'sort_order': 1,
-            'search_weight': 1.0
+            "color": "green",
+            "icon": "map-pin",
+            "css_class": "bg-green-100 text-green-800",
+            "sort_order": 1,
+            "search_weight": 1.0,
         },
-        category=ChoiceCategory.CLASSIFICATION
+        category=ChoiceCategory.CLASSIFICATION,
     ),
     RichChoice(
         value="ride",
         label="Ride",
         description="Individual rides and attractions",
         metadata={
-            'color': 'blue',
-            'icon': 'activity',
-            'css_class': 'bg-blue-100 text-blue-800',
-            'sort_order': 2,
-            'search_weight': 1.0
+            "color": "blue",
+            "icon": "activity",
+            "css_class": "bg-blue-100 text-blue-800",
+            "sort_order": 2,
+            "search_weight": 1.0,
         },
-        category=ChoiceCategory.CLASSIFICATION
+        category=ChoiceCategory.CLASSIFICATION,
     ),
     RichChoice(
         value="company",
         label="Company",
         description="Manufacturers, operators, and designers",
         metadata={
-            'color': 'purple',
-            'icon': 'building',
-            'css_class': 'bg-purple-100 text-purple-800',
-            'sort_order': 3,
-            'search_weight': 0.8
+            "color": "purple",
+            "icon": "building",
+            "css_class": "bg-purple-100 text-purple-800",
+            "sort_order": 3,
+            "search_weight": 0.8,
         },
-        category=ChoiceCategory.CLASSIFICATION
+        category=ChoiceCategory.CLASSIFICATION,
     ),
     RichChoice(
         value="user",
         label="User",
         description="User profiles and accounts",
         metadata={
-            'color': 'orange',
-            'icon': 'user',
-            'css_class': 'bg-orange-100 text-orange-800',
-            'sort_order': 4,
-            'search_weight': 0.5
+            "color": "orange",
+            "icon": "user",
+            "css_class": "bg-orange-100 text-orange-800",
+            "sort_order": 4,
+            "search_weight": 0.5,
         },
-        category=ChoiceCategory.CLASSIFICATION
+        category=ChoiceCategory.CLASSIFICATION,
     ),
 ]
@@ -133,7 +133,7 @@ def register_core_choices():
         choices=HEALTH_STATUSES,
         domain="core",
         description="Health check status options",
-        metadata={'domain': 'core', 'type': 'health_status'}
+        metadata={"domain": "core", "type": "health_status"},
     )
 
     register_choices(
@@ -141,7 +141,7 @@ def register_core_choices():
         choices=SIMPLE_HEALTH_STATUSES,
         domain="core",
         description="Simple health check status options",
-        metadata={'domain': 'core', 'type': 'simple_health_status'}
+        metadata={"domain": "core", "type": "simple_health_status"},
     )
 
     register_choices(
@@ -149,7 +149,7 @@ def register_core_choices():
         choices=ENTITY_TYPES,
         domain="core",
         description="Entity type classifications for search functionality",
-        metadata={'domain': 'core', 'type': 'entity_type'}
+        metadata={"domain": "core", "type": "entity_type"},
     )
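Note: after register_core_choices() runs, these groups resolve through the registry API seen elsewhere in this diff; a usage sketch (the group name "entity_types" is an assumption, since the name= argument falls outside the hunks above):

    choice = registry.get_choice("entity_types", "park", domain="core")
    if choice is not None:
        print(choice.label, choice.color)  # "Park" / "green" per the metadata above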
@@ -23,12 +23,7 @@ class RichChoiceField(models.CharField):
     """
 
     def __init__(
-        self,
-        choice_group: str,
-        domain: str = "core",
-        max_length: int = 50,
-        allow_deprecated: bool = False,
-        **kwargs
+        self, choice_group: str, domain: str = "core", max_length: int = 50, allow_deprecated: bool = False, **kwargs
     ):
         """
         Initialize the RichChoiceField.
@@ -52,8 +47,8 @@ class RichChoiceField(models.CharField):
 
         choices = [(choice.value, choice.label) for choice in choices_list]
 
-        kwargs['choices'] = choices
-        kwargs['max_length'] = max_length
+        kwargs["choices"] = choices
+        kwargs["max_length"] = max_length
 
         super().__init__(**kwargs)
@@ -61,21 +56,17 @@ class RichChoiceField(models.CharField):
         """Validate the choice value"""
         super().validate(value, model_instance)
 
-        if value is None or value == '':
+        if value is None or value == "":
             return
 
         # Check if choice exists in registry
         choice = registry.get_choice(self.choice_group, value, self.domain)
         if choice is None:
-            raise ValidationError(
-                f"'{value}' is not a valid choice for {self.choice_group}"
-            )
+            raise ValidationError(f"'{value}' is not a valid choice for {self.choice_group}")
 
         # Check if deprecated choices are allowed
         if choice.deprecated and not self.allow_deprecated:
-            raise ValidationError(
-                f"'{value}' is deprecated and cannot be used for new entries"
-            )
+            raise ValidationError(f"'{value}' is deprecated and cannot be used for new entries")
 
     def get_rich_choice(self, value: str) -> RichChoice | None:
         """Get the RichChoice object for a value"""
@@ -94,21 +85,21 @@ class RichChoiceField(models.CharField):
             value = getattr(instance, name)
             return self.get_rich_choice(value) if value else None
 
-        setattr(cls, f'get_{name}_rich_choice', get_rich_choice_method)
+        setattr(cls, f"get_{name}_rich_choice", get_rich_choice_method)
 
         # Add get_FOO_display method (Django provides this, but we enhance it)
         def get_display_method(instance):
             value = getattr(instance, name)
-            return self.get_choice_display(value) if value else ''
+            return self.get_choice_display(value) if value else ""
 
-        setattr(cls, f'get_{name}_display', get_display_method)
+        setattr(cls, f"get_{name}_display", get_display_method)
 
     def deconstruct(self):
         """Support for Django migrations"""
         name, path, args, kwargs = super().deconstruct()
-        kwargs['choice_group'] = self.choice_group
-        kwargs['domain'] = self.domain
-        kwargs['allow_deprecated'] = self.allow_deprecated
+        kwargs["choice_group"] = self.choice_group
+        kwargs["domain"] = self.domain
+        kwargs["allow_deprecated"] = self.allow_deprecated
         return name, path, args, kwargs
@@ -123,7 +114,7 @@ class RichChoiceFormField(ChoiceField):
         domain: str = "core",
         allow_deprecated: bool = False,
         show_descriptions: bool = False,
-        **kwargs
+        **kwargs,
     ):
         """
         Initialize the form field.
@@ -154,36 +145,28 @@ class RichChoiceFormField(ChoiceField):
                 label = f"{choice.label} - {choice.description}"
             choices.append((choice.value, label))
 
-        kwargs['choices'] = choices
+        kwargs["choices"] = choices
         super().__init__(**kwargs)
 
     def validate(self, value: Any) -> None:
         """Validate the choice value"""
         super().validate(value)
 
-        if value is None or value == '':
+        if value is None or value == "":
             return
 
         # Check if choice exists in registry
         choice = registry.get_choice(self.choice_group, value, self.domain)
         if choice is None:
-            raise ValidationError(
-                f"'{value}' is not a valid choice for {self.choice_group}"
-            )
+            raise ValidationError(f"'{value}' is not a valid choice for {self.choice_group}")
 
         # Check if deprecated choices are allowed
         if choice.deprecated and not self.allow_deprecated:
-            raise ValidationError(
-                f"'{value}' is deprecated and cannot be used"
-            )
+            raise ValidationError(f"'{value}' is deprecated and cannot be used")
 
 
 def create_rich_choice_field(
-    choice_group: str,
-    domain: str = "core",
-    max_length: int = 50,
-    allow_deprecated: bool = False,
-    **kwargs
+    choice_group: str, domain: str = "core", max_length: int = 50, allow_deprecated: bool = False, **kwargs
 ) -> RichChoiceField:
     """
     Factory function to create a RichChoiceField.
@@ -192,9 +175,5 @@ def create_rich_choice_field(
     across multiple models.
     """
     return RichChoiceField(
-        choice_group=choice_group,
-        domain=domain,
-        max_length=max_length,
-        allow_deprecated=allow_deprecated,
-        **kwargs
+        choice_group=choice_group, domain=domain, max_length=max_length, allow_deprecated=allow_deprecated, **kwargs
     )
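Note: a field declared through this factory stores the raw string value but validates it against the registry; a model sketch (the model name and choice group name are hypothetical):

    from django.db import models

    class Attraction(models.Model):
        # choice_group/domain must name a group previously registered via
        # register_choices(); max_length matches the field's default.
        entity_type = RichChoiceField(choice_group="entity_types", domain="core", max_length=50)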
@@ -29,7 +29,7 @@ class ChoiceRegistry:
         choices: list[RichChoice],
         domain: str = "core",
         description: str = "",
-        metadata: dict[str, Any] | None = None
+        metadata: dict[str, Any] | None = None,
     ) -> ChoiceGroup:
         """
         Register a group of choices.
@@ -65,12 +65,7 @@ class ChoiceRegistry:
                 f"Existing: {existing_values}, New: {new_values}"
             )
 
-        choice_group = ChoiceGroup(
-            name=full_name,
-            choices=choices,
-            description=description,
-            metadata=metadata or {}
-        )
+        choice_group = ChoiceGroup(name=full_name, choices=choices, description=description, metadata=metadata or {})
 
         self._choices[full_name] = choice_group
@@ -103,7 +98,6 @@ class ChoiceRegistry:
         choice_group = self.get(name, domain)
         return choice_group.get_active_choices() if choice_group else []
 
-
     def get_domains(self) -> list[str]:
         """Get all registered domains"""
         return list(self._domains.keys())
@@ -113,10 +107,7 @@ class ChoiceRegistry:
         if domain not in self._domains:
             return {}
 
-        return {
-            name: self._choices[f"{domain}.{name}"]
-            for name in self._domains[domain]
-        }
+        return {name: self._choices[f"{domain}.{name}"] for name in self._domains[domain]}
 
     def list_all(self) -> dict[str, ChoiceGroup]:
         """Get all registered choice groups"""
@@ -159,7 +150,7 @@ def register_choices(
     choices: list[RichChoice],
     domain: str = "core",
     description: str = "",
-    metadata: dict[str, Any] | None = None
+    metadata: dict[str, Any] | None = None,
 ) -> ChoiceGroup:
     """
     Convenience function to register choices with the global registry.
@@ -187,8 +178,6 @@ def get_choice(group_name: str, value: str, domain: str = "core") -> RichChoice
     return registry.get_choice(group_name, value, domain)
 
 
-
-
 def validate_choice(group_name: str, value: str, domain: str = "core") -> bool:
     """Validate a choice value using the global registry"""
     return registry.validate_choice(group_name, value, domain)
@@ -20,6 +20,7 @@ class RichChoiceSerializer(serializers.Serializer):
     This provides a consistent API representation for choice objects
     with all their metadata.
     """
+
     value = serializers.CharField()
     label = serializers.CharField()
     description = serializers.CharField()
@@ -42,6 +43,7 @@ class RichChoiceOptionSerializer(serializers.Serializer):
 
     This replaces the legacy FilterOptionSerializer with rich choice support.
     """
+
     value = serializers.CharField()
     label = serializers.CharField()
     description = serializers.CharField(allow_blank=True)
@@ -58,30 +60,30 @@ class RichChoiceOptionSerializer(serializers.Serializer):
         if isinstance(instance, RichChoice):
             # Convert RichChoice to option format
             return {
-                'value': instance.value,
-                'label': instance.label,
-                'description': instance.description,
-                'count': None,
-                'selected': False,
-                'deprecated': instance.deprecated,
-                'color': instance.color,
-                'icon': instance.icon,
-                'css_class': instance.css_class,
-                'metadata': instance.metadata,
+                "value": instance.value,
+                "label": instance.label,
+                "description": instance.description,
+                "count": None,
+                "selected": False,
+                "deprecated": instance.deprecated,
+                "color": instance.color,
+                "icon": instance.icon,
+                "css_class": instance.css_class,
+                "metadata": instance.metadata,
             }
         elif isinstance(instance, dict):
             # Handle dictionary input (for backwards compatibility)
             return {
-                'value': instance.get('value', ''),
-                'label': instance.get('label', ''),
-                'description': instance.get('description', ''),
-                'count': instance.get('count'),
-                'selected': instance.get('selected', False),
-                'deprecated': instance.get('deprecated', False),
-                'color': instance.get('color'),
-                'icon': instance.get('icon'),
-                'css_class': instance.get('css_class'),
-                'metadata': instance.get('metadata', {}),
+                "value": instance.get("value", ""),
+                "label": instance.get("label", ""),
+                "description": instance.get("description", ""),
+                "count": instance.get("count"),
+                "selected": instance.get("selected", False),
+                "deprecated": instance.get("deprecated", False),
+                "color": instance.get("color"),
+                "icon": instance.get("icon"),
+                "css_class": instance.get("css_class"),
+                "metadata": instance.get("metadata", {}),
             }
         else:
             return super().to_representation(instance)
@@ -94,6 +96,7 @@ class ChoiceGroupSerializer(serializers.Serializer):
     This provides API representation for entire choice groups
     with all their choices and metadata.
     """
+
     name = serializers.CharField()
     description = serializers.CharField()
     metadata = serializers.DictField()
@@ -112,13 +115,7 @@ class RichChoiceFieldSerializer(serializers.CharField):
     include rich choice metadata in the response.
     """
 
-    def __init__(
-        self,
-        choice_group: str,
-        domain: str = "core",
-        include_metadata: bool = False,
-        **kwargs
-    ):
+    def __init__(self, choice_group: str, domain: str = "core", include_metadata: bool = False, **kwargs):
         """
         Initialize the serializer field.
@@ -146,16 +143,16 @@ class RichChoiceFieldSerializer(serializers.CharField):
             else:
                 # Fallback for unknown values
                 return {
-                    'value': value,
-                    'label': value,
-                    'description': '',
-                    'metadata': {},
-                    'deprecated': False,
-                    'category': 'other',
-                    'color': None,
-                    'icon': None,
-                    'css_class': None,
-                    'sort_order': 0,
+                    "value": value,
+                    "label": value,
+                    "description": "",
+                    "metadata": {},
+                    "deprecated": False,
+                    "category": "other",
+                    "color": None,
+                    "icon": None,
+                    "css_class": None,
+                    "sort_order": 0,
                 }
         else:
             # Return just the value
@@ -163,20 +160,16 @@ class RichChoiceFieldSerializer(serializers.CharField):
 
     def to_internal_value(self, data: Any) -> str:
         """Convert input data to choice value"""
-        if isinstance(data, dict) and 'value' in data:
+        if isinstance(data, dict) and "value" in data:
             # Handle rich choice object input
-            return data['value']
+            return data["value"]
         else:
             # Handle string input
             return super().to_internal_value(data)
 
 
 def create_choice_options_serializer(
-    choice_group: str,
-    domain: str = "core",
-    include_counts: bool = False,
-    queryset=None,
-    count_field: str = 'id'
+    choice_group: str, domain: str = "core", include_counts: bool = False, queryset=None, count_field: str = "id"
 ) -> list[dict[str, Any]]:
     """
     Create choice options for filter endpoints.
@@ -199,47 +192,44 @@ def create_choice_options_serializer(
 
     for choice in choices:
         option_data = {
-            'value': choice.value,
-            'label': choice.label,
-            'description': choice.description,
-            'selected': False,
-            'deprecated': choice.deprecated,
-            'color': choice.color,
-            'icon': choice.icon,
-            'css_class': choice.css_class,
-            'metadata': choice.metadata,
+            "value": choice.value,
+            "label": choice.label,
+            "description": choice.description,
+            "selected": False,
+            "deprecated": choice.deprecated,
+            "color": choice.color,
+            "icon": choice.icon,
+            "css_class": choice.css_class,
+            "metadata": choice.metadata,
         }
 
         if include_counts and queryset is not None:
             # Count items for this choice
             try:
                 count = queryset.filter(**{count_field: choice.value}).count()
-                option_data['count'] = count
+                option_data["count"] = count
             except Exception:
                 # If counting fails, set count to None
-                option_data['count'] = None
+                option_data["count"] = None
         else:
-            option_data['count'] = None
+            option_data["count"] = None
 
         options.append(option_data)
 
     # Sort by sort_order, then by label
-    options.sort(key=lambda x: (
-        (lambda c: c.sort_order if (c is not None and hasattr(c, 'sort_order')) else 0)(
-            registry.get_choice(choice_group, x['value'], domain)
-        ),
-        x['label']
-    ))
+    options.sort(
+        key=lambda x: (
+            (lambda c: c.sort_order if (c is not None and hasattr(c, "sort_order")) else 0)(
+                registry.get_choice(choice_group, x["value"], domain)
+            ),
+            x["label"],
+        )
+    )
 
     return options
 
 
-def serialize_choice_value(
-    value: str,
-    choice_group: str,
-    domain: str = "core",
-    include_metadata: bool = False
-) -> Any:
+def serialize_choice_value(value: str, choice_group: str, domain: str = "core", include_metadata: bool = False) -> Any:
     """
     Serialize a single choice value.
@@ -262,16 +252,16 @@ def serialize_choice_value(
         else:
             # Fallback for unknown values
             return {
-                'value': value,
-                'label': value,
-                'description': '',
-                'metadata': {},
-                'deprecated': False,
-                'category': 'other',
-                'color': None,
-                'icon': None,
-                'css_class': None,
-                'sort_order': 0,
+                "value": value,
+                "label": value,
+                "description": "",
+                "metadata": {},
+                "deprecated": False,
+                "category": "other",
+                "color": None,
+                "icon": None,
+                "css_class": None,
+                "sort_order": 0,
             }
     else:
         return value
@@ -10,12 +10,7 @@ from .base import ChoiceCategory, RichChoice
 from .registry import registry
 
 
-def validate_choice_value(
-    value: str,
-    choice_group: str,
-    domain: str = "core",
-    allow_deprecated: bool = False
-) -> bool:
+def validate_choice_value(value: str, choice_group: str, domain: str = "core", allow_deprecated: bool = False) -> bool:
     """
     Validate that a choice value is valid for a given choice group.
@@ -38,11 +33,7 @@ def validate_choice_value(
     return not (choice.deprecated and not allow_deprecated)
 
 
-def get_choice_display(
-    value: str,
-    choice_group: str,
-    domain: str = "core"
-) -> str:
+def get_choice_display(value: str, choice_group: str, domain: str = "core") -> str:
     """
     Get the display label for a choice value.
@@ -67,11 +58,8 @@ def get_choice_display(
     raise ValueError(f"Choice value '{value}' not found in group '{choice_group}' for domain '{domain}'")
 
 
-
-
 def create_status_choices(
-    statuses: dict[str, dict[str, Any]],
-    category: ChoiceCategory = ChoiceCategory.STATUS
+    statuses: dict[str, dict[str, Any]], category: ChoiceCategory = ChoiceCategory.STATUS
 ) -> list[RichChoice]:
     """
     Create status choices with consistent color coding.
@@ -86,28 +74,28 @@ def create_status_choices(
     choices = []
 
     for value, config in statuses.items():
-        metadata = config.get('metadata', {})
+        metadata = config.get("metadata", {})
 
         # Add default status colors if not specified
-        if 'color' not in metadata:
-            if 'operating' in value.lower() or 'active' in value.lower():
-                metadata['color'] = 'green'
-            elif 'closed' in value.lower() or 'inactive' in value.lower():
-                metadata['color'] = 'red'
-            elif 'temp' in value.lower() or 'pending' in value.lower():
-                metadata['color'] = 'yellow'
-            elif 'construction' in value.lower():
-                metadata['color'] = 'blue'
+        if "color" not in metadata:
+            if "operating" in value.lower() or "active" in value.lower():
+                metadata["color"] = "green"
+            elif "closed" in value.lower() or "inactive" in value.lower():
+                metadata["color"] = "red"
+            elif "temp" in value.lower() or "pending" in value.lower():
+                metadata["color"] = "yellow"
+            elif "construction" in value.lower():
+                metadata["color"] = "blue"
             else:
-                metadata['color'] = 'gray'
+                metadata["color"] = "gray"
 
         choice = RichChoice(
             value=value,
-            label=config['label'],
-            description=config.get('description', ''),
+            label=config["label"],
+            description=config.get("description", ""),
             metadata=metadata,
-            deprecated=config.get('deprecated', False),
-            category=category
+            deprecated=config.get("deprecated", False),
+            category=category,
         )
         choices.append(choice)
@@ -115,8 +103,7 @@ def create_status_choices(
 
 
 def create_type_choices(
-    types: dict[str, dict[str, Any]],
-    category: ChoiceCategory = ChoiceCategory.TYPE
+    types: dict[str, dict[str, Any]], category: ChoiceCategory = ChoiceCategory.TYPE
 ) -> list[RichChoice]:
     """
     Create type/classification choices.
@@ -133,21 +120,18 @@ def create_type_choices(
     for value, config in types.items():
         choice = RichChoice(
             value=value,
-            label=config['label'],
-            description=config.get('description', ''),
-            metadata=config.get('metadata', {}),
-            deprecated=config.get('deprecated', False),
-            category=category
+            label=config["label"],
+            description=config.get("description", ""),
+            metadata=config.get("metadata", {}),
+            deprecated=config.get("deprecated", False),
+            category=category,
         )
         choices.append(choice)
 
     return choices
 
 
-def merge_choice_metadata(
-    base_metadata: dict[str, Any],
-    override_metadata: dict[str, Any]
-) -> dict[str, Any]:
+def merge_choice_metadata(base_metadata: dict[str, Any], override_metadata: dict[str, Any]) -> dict[str, Any]:
     """
     Merge choice metadata dictionaries.
@@ -163,10 +147,7 @@ def merge_choice_metadata(
     return merged
 
 
-def filter_choices_by_category(
-    choices: list[RichChoice],
-    category: ChoiceCategory
-) -> list[RichChoice]:
+def filter_choices_by_category(choices: list[RichChoice], category: ChoiceCategory) -> list[RichChoice]:
     """
     Filter choices by category.
@@ -180,10 +161,7 @@ def filter_choices_by_category(
     return [choice for choice in choices if choice.category == category]
 
 
-def sort_choices(
-    choices: list[RichChoice],
-    sort_by: str = "sort_order"
-) -> list[RichChoice]:
+def sort_choices(choices: list[RichChoice], sort_by: str = "sort_order") -> list[RichChoice]:
     """
     Sort choices by specified criteria.
@@ -204,10 +182,7 @@ def sort_choices(
     return choices
 
 
-def get_choice_colors(
-    choice_group: str,
-    domain: str = "core"
-) -> dict[str, str]:
+def get_choice_colors(choice_group: str, domain: str = "core") -> dict[str, str]:
     """
     Get a mapping of choice values to their colors.
@@ -219,18 +194,10 @@ def get_choice_colors(
         Dictionary mapping choice values to colors
     """
     choices = registry.get_choices(choice_group, domain)
-    return {
-        choice.value: choice.color
-        for choice in choices
-        if choice.color
-    }
+    return {choice.value: choice.color for choice in choices if choice.color}
 
 
-def validate_choice_group_data(
-    name: str,
-    choices: list[RichChoice],
-    domain: str = "core"
-) -> list[str]:
+def validate_choice_group_data(name: str, choices: list[RichChoice], domain: str = "core") -> list[str]:
     """
     Validate choice group data and return list of errors.
@@ -267,7 +234,7 @@ def validate_choice_group_data(
                 description=choice.description,
                 metadata=choice.metadata,
                 deprecated=choice.deprecated,
-                category=choice.category
+                category=choice.category,
             )
         except ValueError as e:
             errors.append(f"Choice {i}: {str(e)}")
@@ -286,19 +253,16 @@ def create_choice_from_config(config: dict[str, Any]) -> RichChoice:
         RichChoice object
     """
     return RichChoice(
-        value=config['value'],
-        label=config['label'],
-        description=config.get('description', ''),
-        metadata=config.get('metadata', {}),
-        deprecated=config.get('deprecated', False),
-        category=ChoiceCategory(config.get('category', 'other'))
+        value=config["value"],
+        label=config["label"],
+        description=config.get("description", ""),
+        metadata=config.get("metadata", {}),
+        deprecated=config.get("deprecated", False),
+        category=ChoiceCategory(config.get("category", "other")),
     )
 
 
-def export_choices_to_dict(
-    choice_group: str,
-    domain: str = "core"
-) -> dict[str, Any]:
+def export_choices_to_dict(choice_group: str, domain: str = "core") -> dict[str, Any]:
     """
     Export a choice group to a dictionary format.
@@ -48,11 +48,7 @@ def cache_api_response(
             cache_key_parts = [
                 key_prefix,
                 view_func.__name__,
-                (
-                    str(getattr(request.user, "id", "anonymous"))
-                    if request.user.is_authenticated
-                    else "anonymous"
-                ),
+                (str(getattr(request.user, "id", "anonymous")) if request.user.is_authenticated else "anonymous"),
                 str(hash(frozenset(request.GET.items()))),
             ]
@@ -72,9 +68,7 @@ def cache_api_response(
 
             # Try to get from cache
             cache_service = EnhancedCacheService()
-            cached_response = getattr(cache_service, cache_backend + "_cache").get(
-                cache_key
-            )
+            cached_response = getattr(cache_service, cache_backend + "_cache").get(cache_key)
 
             if cached_response:
                 logger.debug(
@@ -87,11 +81,8 @@ def cache_api_response(
                 )
 
                 # If cached data is our dict format for DRF responses, reconstruct it
-                if isinstance(cached_response, dict) and '__drf_data__' in cached_response:
-                    return DRFResponse(
-                        data=cached_response['__drf_data__'],
-                        status=cached_response.get('status', 200)
-                    )
+                if isinstance(cached_response, dict) and "__drf_data__" in cached_response:
+                    return DRFResponse(data=cached_response["__drf_data__"], status=cached_response.get("status", 200))
 
                 return cached_response
@@ -104,17 +95,12 @@ def cache_api_response(
             if hasattr(response, "status_code") and response.status_code == 200:
                 # For DRF responses, we must cache the data, not the response object
                 # because the response object is not rendered yet and cannot be pickled
-                if hasattr(response, 'data'):
-                    cache_payload = {
-                        '__drf_data__': response.data,
-                        'status': response.status_code
-                    }
+                if hasattr(response, "data"):
+                    cache_payload = {"__drf_data__": response.data, "status": response.status_code}
                 else:
                     cache_payload = response
 
-                getattr(cache_service, cache_backend + "_cache").set(
-                    cache_key, cache_payload, timeout
-                )
+                getattr(cache_service, cache_backend + "_cache").set(cache_key, cache_payload, timeout)
                 logger.debug(
                     f"Cached API response for view {view_func.__name__}",
                     extra={
@@ -162,9 +148,7 @@ def cache_queryset_result(
             cache_key = f"{cache_key_template}:{hash(str(args) + str(kwargs))}"
 
             cache_service = EnhancedCacheService()
-            cached_result = getattr(cache_service, cache_backend + "_cache").get(
-                cache_key
-            )
+            cached_result = getattr(cache_service, cache_backend + "_cache").get(cache_key)
 
             if cached_result is not None:
                 logger.debug(f"Cache hit for queryset operation: {func.__name__}")
@@ -175,9 +159,7 @@ def cache_queryset_result(
             result = func(*args, **kwargs)
             execution_time = time.time() - start_time
 
-            getattr(cache_service, cache_backend + "_cache").set(
-                cache_key, result, timeout
-            )
+            getattr(cache_service, cache_backend + "_cache").set(cache_key, result, timeout)
             logger.debug(
                 f"Cached queryset result for {func.__name__}",
                 extra={
@@ -250,24 +232,18 @@ class CachedAPIViewMixin(View):
|
||||
cache_backend = "api"
|
||||
|
||||
@method_decorator(vary_on_headers("User-Agent", "Accept-Language"))
|
||||
def dispatch(
|
||||
self, request: HttpRequest, *args: Any, **kwargs: Any
|
||||
) -> HttpResponseBase:
|
||||
def dispatch(self, request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponseBase:
|
||||
"""Add caching to the dispatch method"""
|
||||
if request.method == "GET" and getattr(self, "enable_caching", True):
|
||||
return self._cached_dispatch(request, *args, **kwargs)
|
||||
return super().dispatch(request, *args, **kwargs)
|
||||
|
||||
def _cached_dispatch(
|
||||
self, request: HttpRequest, *args: Any, **kwargs: Any
|
||||
) -> HttpResponseBase:
|
||||
def _cached_dispatch(self, request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponseBase:
|
||||
"""Handle cached dispatch for GET requests"""
|
||||
cache_key = self._generate_cache_key(request, *args, **kwargs)
|
||||
|
||||
cache_service = EnhancedCacheService()
|
||||
cached_response = getattr(cache_service, self.cache_backend + "_cache").get(
|
||||
cache_key
|
||||
)
|
||||
cached_response = getattr(cache_service, self.cache_backend + "_cache").get(cache_key)
|
||||
|
||||
if cached_response:
|
||||
logger.debug(f"Cache hit for view {self.__class__.__name__}")
|
||||
@@ -278,26 +254,18 @@ class CachedAPIViewMixin(View):
|
||||
|
||||
# Cache successful responses
|
||||
if hasattr(response, "status_code") and response.status_code == 200:
|
||||
getattr(cache_service, self.cache_backend + "_cache").set(
|
||||
cache_key, response, self.cache_timeout
|
||||
)
|
||||
getattr(cache_service, self.cache_backend + "_cache").set(cache_key, response, self.cache_timeout)
|
||||
logger.debug(f"Cached response for view {self.__class__.__name__}")
|
||||
|
||||
return response
|
||||
|
||||
def _generate_cache_key(
|
||||
self, request: HttpRequest, *args: Any, **kwargs: Any
|
||||
) -> str:
|
||||
def _generate_cache_key(self, request: HttpRequest, *args: Any, **kwargs: Any) -> str:
|
||||
"""Generate cache key for the request"""
|
||||
key_parts = [
|
||||
self.cache_key_prefix,
|
||||
self.__class__.__name__,
|
||||
request.method,
|
||||
(
|
||||
str(getattr(request.user, "id", "anonymous"))
|
||||
if request.user.is_authenticated
|
||||
else "anonymous"
|
||||
),
|
||||
(str(getattr(request.user, "id", "anonymous")) if request.user.is_authenticated else "anonymous"),
|
||||
str(hash(frozenset(request.GET.items()))),
|
||||
]
|
||||
|
||||
@@ -344,15 +312,11 @@ def smart_cache(
|
||||
"kwargs": json.dumps(kwargs, sort_keys=True, default=str),
|
||||
}
|
||||
key_string = json.dumps(key_data, sort_keys=True)
|
||||
cache_key = (
|
||||
f"smart_cache:{hashlib.md5(key_string.encode()).hexdigest()}"
|
||||
)
|
||||
cache_key = f"smart_cache:{hashlib.md5(key_string.encode()).hexdigest()}"
|
||||
|
||||
# Try to get from cache
|
||||
cache_service = EnhancedCacheService()
|
||||
cached_result = getattr(cache_service, cache_backend + "_cache").get(
|
||||
cache_key
|
||||
)
|
||||
cached_result = getattr(cache_service, cache_backend + "_cache").get(cache_key)
|
||||
|
||||
if cached_result is not None:
|
||||
logger.debug(f"Smart cache hit for {func.__name__}")
|
||||
@@ -364,9 +328,7 @@ def smart_cache(
|
||||
execution_time = time.time() - start_time
|
||||
|
||||
# Cache result
|
||||
getattr(cache_service, cache_backend + "_cache").set(
|
||||
cache_key, result, timeout
|
||||
)
|
||||
getattr(cache_service, cache_backend + "_cache").set(cache_key, result, timeout)
|
||||
|
||||
logger.debug(
|
||||
f"Smart cached result for {func.__name__}",
|
||||
@@ -426,16 +388,10 @@ def generate_model_cache_key(model_instance: Any, suffix: str = "") -> str:
|
||||
"""Generate cache key based on model instance"""
|
||||
model_name = model_instance._meta.model_name
|
||||
instance_id = model_instance.id
|
||||
return (
|
||||
f"{model_name}:{instance_id}:{suffix}"
|
||||
if suffix
|
||||
else f"{model_name}:{instance_id}"
|
||||
)
|
||||
return f"{model_name}:{instance_id}:{suffix}" if suffix else f"{model_name}:{instance_id}"
|
||||
|
||||
|
||||
def generate_queryset_cache_key(
|
||||
queryset: Any, params: dict[str, Any] | None = None
|
||||
) -> str:
|
||||
def generate_queryset_cache_key(queryset: Any, params: dict[str, Any] | None = None) -> str:
|
||||
"""Generate cache key for queryset with parameters"""
|
||||
model_name = queryset.model._meta.model_name
|
||||
params_str = json.dumps(params or {}, sort_keys=True, default=str)
|
||||
|
||||
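The `__drf_data__` convention above exists because an unrendered DRF Response cannot be pickled into the cache backend; only its serializable `.data` and status code survive the round trip. A standalone sketch of that store/rehydrate cycle (key and timeout are placeholders):

from django.core.cache import cache
from rest_framework.response import Response as DRFResponse

def store_response(response: DRFResponse, key: str, timeout: int = 300) -> None:
    # Cache the serializable payload, never the unrendered Response object.
    cache.set(key, {"__drf_data__": response.data, "status": response.status_code}, timeout)

def load_response(key: str):
    cached = cache.get(key)
    if isinstance(cached, dict) and "__drf_data__" in cached:
        return DRFResponse(data=cached["__drf_data__"], status=cached.get("status", 200))
    return cached  # plain HttpResponse objects pass through unchanged
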
@@ -22,9 +22,7 @@ class BaseAutocomplete(Autocomplete):

    # UI text configuration using gettext for i18n
    no_result_text = _("No matches found")
    narrow_search_text = _(
        "Showing %(page_size)s of %(total)s matches. Please refine your search."
    )
    narrow_search_text = _("Showing %(page_size)s of %(total)s matches. Please refine your search.")
    type_at_least_n_characters = _("Type at least %(n)s characters...")

    # Project-wide component settings

@@ -1,6 +1,7 @@
"""
Base forms and views for HTMX integration.
"""

from django.http import JsonResponse
from django.views.generic.edit import FormView

@@ -20,9 +21,6 @@ class HTMXFormView(FormView):

    def post(self, request, *args, **kwargs):
        # If HTMX field validation pattern: ?field=name
        if (
            request.headers.get("HX-Request") == "true"
            and request.GET.get("validate_field")
        ):
        if request.headers.get("HX-Request") == "true" and request.GET.get("validate_field"):
            return self.validate_field(request.GET.get("validate_field"))
        return super().post(request, *args, **kwargs)

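The branch above implements per-field HTMX validation: the client re-posts the form with `?validate_field=<name>` plus the `HX-Request` header, and only that field is checked. The diff does not show `validate_field` itself; one plausible shape, returning JSON as the imports suggest:

def validate_field(self, field_name):
    # Bind the posted data and report errors for a single field only.
    form = self.get_form_class()(self.request.POST)
    form.is_valid()  # populates form.errors as a side effect
    return JsonResponse({"field": field_name, "errors": form.errors.get(field_name, [])})
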
@@ -42,12 +42,8 @@ class LocationSearchForm(forms.Form):
    )

    # Hidden fields for coordinates
    lat = forms.FloatField(
        required=False, widget=forms.HiddenInput(attrs={"id": "lat-input"})
    )
    lng = forms.FloatField(
        required=False, widget=forms.HiddenInput(attrs={"id": "lng-input"})
    )
    lat = forms.FloatField(required=False, widget=forms.HiddenInput(attrs={"id": "lat-input"}))
    lng = forms.FloatField(required=False, widget=forms.HiddenInput(attrs={"id": "lng-input"}))

    # Search radius
    radius_km = forms.ChoiceField(
@@ -81,8 +77,7 @@ class LocationSearchForm(forms.Form):
        widget=forms.CheckboxInput(
            attrs={
                "class": (
                    "rounded border-gray-300 text-blue-600 focus:ring-blue-500 "
                    "dark:border-gray-600 dark:bg-gray-700"
                    "rounded border-gray-300 text-blue-600 focus:ring-blue-500 " "dark:border-gray-600 dark:bg-gray-700"
                )
            }
        ),
@@ -93,8 +88,7 @@ class LocationSearchForm(forms.Form):
        widget=forms.CheckboxInput(
            attrs={
                "class": (
                    "rounded border-gray-300 text-blue-600 focus:ring-blue-500 "
                    "dark:border-gray-600 dark:bg-gray-700"
                    "rounded border-gray-300 text-blue-600 focus:ring-blue-500 " "dark:border-gray-600 dark:bg-gray-700"
                )
            }
        ),
@@ -105,8 +99,7 @@ class LocationSearchForm(forms.Form):
        widget=forms.CheckboxInput(
            attrs={
                "class": (
                    "rounded border-gray-300 text-blue-600 focus:ring-blue-500 "
                    "dark:border-gray-600 dark:bg-gray-700"
                    "rounded border-gray-300 text-blue-600 focus:ring-blue-500 " "dark:border-gray-600 dark:bg-gray-700"
                )
            }
        ),

@@ -58,13 +58,9 @@ class CacheHealthCheck(BaseHealthCheckBackend):
                if max_memory > 0:
                    memory_usage_percent = (used_memory / max_memory) * 100
                    if memory_usage_percent > 90:
                        self.add_error(
                            f"Redis memory usage critical: {memory_usage_percent:.1f}%"
                        )
                        self.add_error(f"Redis memory usage critical: {memory_usage_percent:.1f}%")
                    elif memory_usage_percent > 80:
                        logger.warning(
                            f"Redis memory usage high: {memory_usage_percent:.1f}%"
                        )
                        logger.warning(f"Redis memory usage high: {memory_usage_percent:.1f}%")

        except ImportError:
            # django-redis not available, skip additional checks
@@ -160,9 +156,7 @@ class ApplicationHealthCheck(BaseHealthCheckBackend):
            try:
                __import__(module_name)
            except ImportError as e:
                self.add_error(
                    f"Critical module import failed: {module_name} - {e}"
                )
                self.add_error(f"Critical module import failed: {module_name} - {e}")

        # Check if we can access critical models
        try:
@@ -179,9 +173,7 @@ class ApplicationHealthCheck(BaseHealthCheckBackend):
            ride_count = Ride.objects.count()
            user_count = User.objects.count()

            logger.debug(
                f"Model counts - Parks: {park_count}, Rides: {ride_count}, Users: {user_count}"
            )
            logger.debug(f"Model counts - Parks: {park_count}, Rides: {ride_count}, Users: {user_count}")

        except Exception as e:
            self.add_error(f"Model access check failed: {e}")
@@ -195,9 +187,7 @@ class ApplicationHealthCheck(BaseHealthCheckBackend):
                self.add_error(f"Media directory does not exist: {settings.MEDIA_ROOT}")

            if not os.path.exists(settings.STATIC_ROOT) and not settings.DEBUG:
                self.add_error(
                    f"Static directory does not exist: {settings.STATIC_ROOT}"
                )
                self.add_error(f"Static directory does not exist: {settings.STATIC_ROOT}")

        except Exception as e:
            self.add_error(f"Application health check failed: {e}")
@@ -214,10 +204,7 @@ class ExternalServiceHealthCheck(BaseHealthCheckBackend):
            from django.conf import settings
            from django.core.mail import get_connection

            if (
                hasattr(settings, "EMAIL_BACKEND")
                and "console" not in settings.EMAIL_BACKEND
            ):
            if hasattr(settings, "EMAIL_BACKEND") and "console" not in settings.EMAIL_BACKEND:
                # Only check if not using console backend
                connection = get_connection()
                if hasattr(connection, "open"):
@@ -304,9 +291,7 @@ class DiskSpaceHealthCheck(BaseHealthCheckBackend):
                    media_free_percent:.1f}% free in media directory"
                )
            elif media_free_percent < 20:
                logger.warning(
                    f"Low disk space: {media_free_percent:.1f}% free in media directory"
                )
                logger.warning(f"Low disk space: {media_free_percent:.1f}% free in media directory")

            if logs_free_percent < 10:
                self.add_error(
@@ -314,9 +299,7 @@ class DiskSpaceHealthCheck(BaseHealthCheckBackend):
                    logs_free_percent:.1f}% free in logs directory"
                )
            elif logs_free_percent < 20:
                logger.warning(
                    f"Low disk space: {logs_free_percent:.1f}% free in logs directory"
                )
                logger.warning(f"Low disk space: {logs_free_percent:.1f}% free in logs directory")

        except Exception as e:
            logger.warning(f"Disk space check failed: {e}")

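All of these backends follow the django-health-check contract: subclass BaseHealthCheckBackend, implement check_status(), and call self.add_error() on failure. A minimal custom backend as a sketch; the metric lookup and threshold are invented for illustration:

from health_check.backends import BaseHealthCheckBackend

def get_queue_depth() -> int:
    return 0  # stand-in; a real check would query a broker or table

class QueueDepthHealthCheck(BaseHealthCheckBackend):
    critical_service = False  # degrade to a warning rather than a hard failure

    def check_status(self):
        depth = get_queue_depth()
        if depth > 1000:
            self.add_error(f"Queue depth critical: {depth}")
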
@@ -94,9 +94,7 @@ class TrackedModel(models.Model):

            event_model = getattr(events, "model", None)
            if event_model:
                return event_model.objects.filter(pgh_obj_id=self.pk).order_by(
                    "-pgh_created_at"
                )
                return event_model.objects.filter(pgh_obj_id=self.pk).order_by("-pgh_created_at")
        except (AttributeError, TypeError):
            pass
        return self.__class__.objects.none()

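The lookup above degrades gracefully: when pghistory tracking is not wired up, callers get an empty queryset instead of an exception. Assuming the surrounding method is exposed as something like get_history() (the name sits outside this hunk), usage would look like:

ride = Ride.objects.get(slug="some-ride")  # hypothetical instance
for event in ride.get_history()[:10]:  # newest first, per the order_by above
    print(event.pgh_created_at)
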
@@ -23,9 +23,7 @@ class ThrillWikiFormatter(logging.Formatter):
        if hasattr(record, "request"):
            record.request_id = getattr(record.request, "id", "unknown")
            record.user_id = (
                getattr(record.request.user, "id", "anonymous")
                if hasattr(record.request, "user")
                else "unknown"
                getattr(record.request.user, "id", "anonymous") if hasattr(record.request, "user") else "unknown"
            )
            record.path = getattr(record.request, "path", "unknown")
            record.method = getattr(record.request, "method", "unknown")
@@ -52,9 +50,7 @@ def get_logger(name: str) -> logging.Logger:
    # Only configure if not already configured
    if not logger.handlers:
        handler = logging.StreamHandler(sys.stdout)
        formatter = ThrillWikiFormatter(
            fmt="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
        )
        formatter = ThrillWikiFormatter(fmt="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
        handler.setFormatter(formatter)
        logger.addHandler(handler)
        logger.setLevel(logging.INFO if settings.DEBUG else logging.WARNING)
@@ -91,11 +87,7 @@ def log_exception(
            {
                "request_path": getattr(request, "path", "unknown"),
                "request_method": getattr(request, "method", "unknown"),
                "user_id": (
                    getattr(request.user, "id", "anonymous")
                    if hasattr(request, "user")
                    else "unknown"
                ),
                "user_id": (getattr(request.user, "id", "anonymous") if hasattr(request, "user") else "unknown"),
            }
        )

@@ -134,11 +126,7 @@ def log_business_event(
            {
                "request_path": getattr(request, "path", "unknown"),
                "request_method": getattr(request, "method", "unknown"),
                "user_id": (
                    getattr(request.user, "id", "anonymous")
                    if hasattr(request, "user")
                    else "unknown"
                ),
                "user_id": (getattr(request.user, "id", "anonymous") if hasattr(request, "user") else "unknown"),
            }
        )

@@ -196,11 +184,7 @@ def log_api_request(
            "request_type": "api",
            "path": getattr(request, "path", "unknown"),
            "method": getattr(request, "method", "unknown"),
            "user_id": (
                getattr(request.user, "id", "anonymous")
                if hasattr(request, "user")
                else "unknown"
            ),
            "user_id": (getattr(request.user, "id", "anonymous") if hasattr(request, "user") else "unknown"),
            "response_status": response_status,
            "duration_ms": duration_ms,
        }
@@ -246,11 +230,7 @@ def log_security_event(
            {
                "request_path": getattr(request, "path", "unknown"),
                "request_method": getattr(request, "method", "unknown"),
                "user_id": (
                    getattr(request.user, "id", "anonymous")
                    if hasattr(request, "user")
                    else "unknown"
                ),
                "user_id": (getattr(request.user, "id", "anonymous") if hasattr(request, "user") else "unknown"),
                "remote_addr": request.META.get("REMOTE_ADDR", "unknown"),
                "user_agent": request.META.get("HTTP_USER_AGENT", "unknown"),
            }

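Each of these logging helpers pushes the same request context through `extra`, which ThrillWikiFormatter then lifts onto the record. A usage sketch:

logger = get_logger(__name__)

def park_detail(request, slug):
    # The formatter derives request_id, user_id, path and method from record.request.
    logger.info("park page rendered", extra={"request": request})
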
@@ -43,9 +43,7 @@ class Command(BaseCommand):
            default=50,
            help="Maximum number of results to calculate (default: 50)",
        )
        parser.add_argument(
            "--verbose", action="store_true", help="Enable verbose output"
        )
        parser.add_argument("--verbose", action="store_true", help="Enable verbose output")

    def handle(self, *args, **options):
        content_type = options["content_type"]
@@ -61,17 +59,13 @@ class Command(BaseCommand):
            new_items = []

            if content_type in ["all", "parks"]:
                parks = self._get_new_parks(
                    cutoff_date, limit if content_type == "parks" else limit * 2
                )
                parks = self._get_new_parks(cutoff_date, limit if content_type == "parks" else limit * 2)
                new_items.extend(parks)
                if verbose:
                    self.stdout.write(f"Found {len(parks)} new parks")

            if content_type in ["all", "rides"]:
                rides = self._get_new_rides(
                    cutoff_date, limit if content_type == "rides" else limit * 2
                )
                rides = self._get_new_rides(cutoff_date, limit if content_type == "rides" else limit * 2)
                new_items.extend(rides)
                if verbose:
                    self.stdout.write(f"Found {len(rides)} new rides")
@@ -88,27 +82,22 @@ class Command(BaseCommand):
            cache.set(cache_key, formatted_results, 1800)  # Cache for 30 minutes

            self.stdout.write(
                self.style.SUCCESS(
                    f"Successfully calculated {len(formatted_results)} new items for {content_type}"
                )
                self.style.SUCCESS(f"Successfully calculated {len(formatted_results)} new items for {content_type}")
            )

            if verbose:
                for item in formatted_results[:5]:  # Show first 5 items
                    self.stdout.write(
                        f" {item['name']} ({item['park']}) - opened: {item['date_opened']}"
                    )
                    self.stdout.write(f" {item['name']} ({item['park']}) - opened: {item['date_opened']}")

        except Exception as e:
            logger.error(f"Error calculating new content: {e}", exc_info=True)
            raise CommandError(f"Failed to calculate new content: {e}")
            raise CommandError(f"Failed to calculate new content: {e}") from None

    def _get_new_parks(self, cutoff_date: datetime, limit: int) -> list[dict[str, Any]]:
        """Get recently added parks using real data."""
        new_parks = (
            Park.objects.filter(
                Q(created_at__gte=cutoff_date)
                | Q(opening_date__gte=cutoff_date.date()),
                Q(created_at__gte=cutoff_date) | Q(opening_date__gte=cutoff_date.date()),
                status="OPERATING",
            )
            .select_related("location", "operator")
@@ -146,8 +135,7 @@ class Command(BaseCommand):
        """Get recently added rides using real data."""
        new_rides = (
            Ride.objects.filter(
                Q(created_at__gte=cutoff_date)
                | Q(opening_date__gte=cutoff_date.date()),
                Q(created_at__gte=cutoff_date) | Q(opening_date__gte=cutoff_date.date()),
                status="OPERATING",
            )
            .select_related("park", "park__location")
@@ -156,9 +144,7 @@ class Command(BaseCommand):

        results = []
        for ride in new_rides:
            date_added = getattr(ride, "opening_date", None) or getattr(
                ride, "created_at", None
            )
            date_added = getattr(ride, "opening_date", None) or getattr(ride, "created_at", None)
            if date_added and isinstance(date_added, datetime):
                date_added = date_added.date()

@@ -184,9 +170,7 @@ class Command(BaseCommand):

        return results

    def _format_new_content_results(
        self, new_items: list[dict[str, Any]]
    ) -> list[dict[str, Any]]:
    def _format_new_content_results(self, new_items: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Format new content results for frontend consumption."""
        formatted_results = []


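The recency filter used by both helpers treats an item as new when either its database row was created recently or the attraction itself recently opened. The equivalent standalone query, as a sketch with an illustrative 30-day window (model names as used in this diff):

from datetime import timedelta
from django.db.models import Q
from django.utils import timezone

cutoff = timezone.now() - timedelta(days=30)
new_parks = Park.objects.filter(
    Q(created_at__gte=cutoff) | Q(opening_date__gte=cutoff.date()),
    status="OPERATING",
).select_related("location", "operator")[:50]
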
@@ -37,9 +37,7 @@ class Command(BaseCommand):
            default=50,
            help="Maximum number of results to calculate (default: 50)",
        )
        parser.add_argument(
            "--verbose", action="store_true", help="Enable verbose output"
        )
        parser.add_argument("--verbose", action="store_true", help="Enable verbose output")

    def handle(self, *args, **options):
        content_type = options["content_type"]
@@ -98,29 +96,23 @@ class Command(BaseCommand):

            if verbose:
                for item in formatted_results[:5]:  # Show first 5 items
                    self.stdout.write(
                        f" {item['name']} (score: {item.get('views_change', 'N/A')})"
                    )
                    self.stdout.write(f" {item['name']} (score: {item.get('views_change', 'N/A')})")

        except Exception as e:
            logger.error(f"Error calculating trending content: {e}", exc_info=True)
            raise CommandError(f"Failed to calculate trending content: {e}")
            raise CommandError(f"Failed to calculate trending content: {e}") from None

    def _calculate_trending_parks(
        self, current_period_hours: int, previous_period_hours: int, limit: int
    ) -> list[dict[str, Any]]:
        """Calculate trending scores for parks using real data."""
        parks = Park.objects.filter(status="OPERATING").select_related(
            "location", "operator"
        )
        parks = Park.objects.filter(status="OPERATING").select_related("location", "operator")

        trending_parks = []

        for park in parks:
            try:
                score = self._calculate_content_score(
                    park, "park", current_period_hours, previous_period_hours
                )
                score = self._calculate_content_score(park, "park", current_period_hours, previous_period_hours)
                if score > 0:  # Only include items with positive trending scores
                    trending_parks.append(
                        {
@@ -132,16 +124,8 @@ class Command(BaseCommand):
                            "slug": park.slug,
                            "park": park.name,  # For parks, park field is the park name itself
                            "category": "park",
                            "rating": (
                                float(park.average_rating)
                                if park.average_rating
                                else 0.0
                            ),
                            "date_opened": (
                                park.opening_date.isoformat()
                                if park.opening_date
                                else ""
                            ),
                            "rating": (float(park.average_rating) if park.average_rating else 0.0),
                            "date_opened": (park.opening_date.isoformat() if park.opening_date else ""),
                            "url": park.url,
                        }
                    )
@@ -154,17 +138,13 @@ class Command(BaseCommand):
        self, current_period_hours: int, previous_period_hours: int, limit: int
    ) -> list[dict[str, Any]]:
        """Calculate trending scores for rides using real data."""
        rides = Ride.objects.filter(status="OPERATING").select_related(
            "park", "park__location"
        )
        rides = Ride.objects.filter(status="OPERATING").select_related("park", "park__location")

        trending_rides = []

        for ride in rides:
            try:
                score = self._calculate_content_score(
                    ride, "ride", current_period_hours, previous_period_hours
                )
                score = self._calculate_content_score(ride, "ride", current_period_hours, previous_period_hours)
                if score > 0:  # Only include items with positive trending scores
                    trending_rides.append(
                        {
@@ -176,16 +156,8 @@ class Command(BaseCommand):
                            "slug": ride.slug,
                            "park": ride.park.name if ride.park else "",
                            "category": "ride",
                            "rating": (
                                float(ride.average_rating)
                                if ride.average_rating
                                else 0.0
                            ),
                            "date_opened": (
                                ride.opening_date.isoformat()
                                if ride.opening_date
                                else ""
                            ),
                            "rating": (float(ride.average_rating) if ride.average_rating else 0.0),
                            "date_opened": (ride.opening_date.isoformat() if ride.opening_date else ""),
                            "url": ride.url,
                            "park_url": ride.park.url if ride.park else "",
                        }
@@ -219,24 +191,15 @@ class Command(BaseCommand):
            recency_score = self._calculate_recency_score(content_obj)

            # 4. Popularity Score (10% weight)
            popularity_score = self._calculate_popularity_score(
                ct, content_obj.id, current_period_hours
            )
            popularity_score = self._calculate_popularity_score(ct, content_obj.id, current_period_hours)

            # Calculate weighted final score
            final_score = (
                view_growth_score * 0.4
                + rating_score * 0.3
                + recency_score * 0.2
                + popularity_score * 0.1
            )
            final_score = view_growth_score * 0.4 + rating_score * 0.3 + recency_score * 0.2 + popularity_score * 0.1

            return final_score

        except Exception as e:
            logger.error(
                f"Error calculating score for {content_type} {content_obj.id}: {e}"
            )
            logger.error(f"Error calculating score for {content_type} {content_obj.id}: {e}")
            return 0.0

    def _calculate_view_growth_score(
@@ -248,13 +211,11 @@ class Command(BaseCommand):
    ) -> float:
        """Calculate normalized view growth score using real PageView data."""
        try:
            current_views, previous_views, growth_percentage = (
                PageView.get_views_growth(
                    content_type,
                    object_id,
                    current_period_hours,
                    previous_period_hours,
                )
            current_views, previous_views, growth_percentage = PageView.get_views_growth(
                content_type,
                object_id,
                current_period_hours,
                previous_period_hours,
            )

            if previous_views == 0:
@@ -262,9 +223,7 @@ class Command(BaseCommand):
                return min(current_views / 100.0, 1.0) if current_views > 0 else 0.0

            # Normalize growth percentage to 0-1 scale
            normalized_growth = (
                min(growth_percentage / 500.0, 1.0) if growth_percentage > 0 else 0.0
            )
            normalized_growth = min(growth_percentage / 500.0, 1.0) if growth_percentage > 0 else 0.0
            return max(normalized_growth, 0.0)

        except Exception as e:
@@ -317,14 +276,10 @@ class Command(BaseCommand):
            logger.warning(f"Error calculating recency score: {e}")
            return 0.5

    def _calculate_popularity_score(
        self, content_type: ContentType, object_id: int, hours: int
    ) -> float:
    def _calculate_popularity_score(self, content_type: ContentType, object_id: int, hours: int) -> float:
        """Calculate popularity score based on total view count."""
        try:
            total_views = PageView.get_total_views_count(
                content_type, object_id, hours=hours
            )
            total_views = PageView.get_total_views_count(content_type, object_id, hours=hours)

            # Normalize views to 0-1 scale
            if total_views == 0:
@@ -352,13 +307,11 @@ class Command(BaseCommand):
            # Get view change for display
            content_obj = item["content_object"]
            ct = ContentType.objects.get_for_model(content_obj)
            current_views, previous_views, growth_percentage = (
                PageView.get_views_growth(
                    ct,
                    content_obj.id,
                    current_period_hours,
                    previous_period_hours,
                )
            current_views, previous_views, growth_percentage = PageView.get_views_growth(
                ct,
                content_obj.id,
                current_period_hours,
                previous_period_hours,
            )

            # Format exactly as frontend expects
@@ -371,9 +324,7 @@ class Command(BaseCommand):
                "rank": rank,
                "views": current_views,
                "views_change": (
                    f"+{growth_percentage:.1f}%"
                    if growth_percentage > 0
                    else f"{growth_percentage:.1f}%"
                    f"+{growth_percentage:.1f}%" if growth_percentage > 0 else f"{growth_percentage:.1f}%"
                ),
                "slug": item["slug"],
                "date_opened": item["date_opened"],

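To make the 40/30/20/10 weighting concrete: component scores of 0.8 view growth, 0.6 rating, 0.9 recency and 0.4 popularity blend to 0.8*0.4 + 0.6*0.3 + 0.9*0.2 + 0.4*0.1 = 0.72. Since each component is normalized to the 0-1 range, the blended score stays in 0-1 as well. A standalone sketch of the blend:

WEIGHTS = {"view_growth": 0.4, "rating": 0.3, "recency": 0.2, "popularity": 0.1}

def blend(scores: dict[str, float]) -> float:
    # Weighted sum of pre-normalized component scores.
    return sum(scores[name] * weight for name, weight in WEIGHTS.items())

assert abs(blend({"view_growth": 0.8, "rating": 0.6, "recency": 0.9, "popularity": 0.4}) - 0.72) < 1e-9
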
@@ -21,10 +21,7 @@ from django.core.management.base import BaseCommand


class Command(BaseCommand):
    help = (
        "Clear all types of cache data including Django cache, "
        "__pycache__, and build caches"
    )
    help = "Clear all types of cache data including Django cache, " "__pycache__, and build caches"

    def add_arguments(self, parser):
        parser.add_argument(
@@ -92,9 +89,7 @@ class Command(BaseCommand):
        )

        if self.dry_run:
            self.stdout.write(
                self.style.WARNING("🔍 DRY RUN MODE - No files will be deleted")
            )
            self.stdout.write(self.style.WARNING("🔍 DRY RUN MODE - No files will be deleted"))
            self.stdout.write("")

        self.stdout.write(self.style.SUCCESS("🧹 ThrillWiki Cache Clearing Utility"))
@@ -129,9 +124,7 @@ class Command(BaseCommand):
            self.clear_opcache()

        self.stdout.write("")
        self.stdout.write(
            self.style.SUCCESS("✅ Cache clearing completed successfully!")
        )
        self.stdout.write(self.style.SUCCESS("✅ Cache clearing completed successfully!"))

    def clear_django_cache(self):
        """Clear Django cache framework cache."""
@@ -154,23 +147,13 @@ class Command(BaseCommand):
                    if not self.dry_run:
                        cache_backend.clear()

                    cache_info = (
                        f"{alias} cache ({cache_backend.__class__.__name__})"
                    )
                    self.stdout.write(
                        self.style.SUCCESS(f" ✅ Cleared {cache_info}")
                    )
                    cache_info = f"{alias} cache ({cache_backend.__class__.__name__})"
                    self.stdout.write(self.style.SUCCESS(f" ✅ Cleared {cache_info}"))
                except Exception as e:
                    self.stdout.write(
                        self.style.WARNING(
                            f" ⚠️ Could not clear {alias} cache: {e}"
                        )
                    )
                    self.stdout.write(self.style.WARNING(f" ⚠️ Could not clear {alias} cache: {e}"))

        except Exception as e:
            self.stdout.write(
                self.style.ERROR(f" ❌ Error clearing Django cache: {e}")
            )
            self.stdout.write(self.style.ERROR(f" ❌ Error clearing Django cache: {e}"))

    def clear_pycache(self):
        """Clear Python __pycache__ directories and .pyc files."""
@@ -188,11 +171,7 @@ class Command(BaseCommand):
                if pycache_dir.is_dir():
                    try:
                        # Calculate size before removal
                        dir_size = sum(
                            f.stat().st_size
                            for f in pycache_dir.rglob("*")
                            if f.is_file()
                        )
                        dir_size = sum(f.stat().st_size for f in pycache_dir.rglob("*") if f.is_file())
                        removed_size += dir_size

                        if self.verbose:
@@ -203,11 +182,7 @@ class Command(BaseCommand):

                        removed_count += 1
                    except Exception as e:
                        self.stdout.write(
                            self.style.WARNING(
                                f" ⚠️ Could not remove {pycache_dir}: {e}"
                            )
                        )
                        self.stdout.write(self.style.WARNING(f" ⚠️ Could not remove {pycache_dir}: {e}"))

            # Find and remove .pyc files
            for pyc_file in project_root.rglob("*.pyc"):
@@ -223,22 +198,14 @@ class Command(BaseCommand):

                    removed_count += 1
                except Exception as e:
                    self.stdout.write(
                        self.style.WARNING(f" ⚠️ Could not remove {pyc_file}: {e}")
                    )
                    self.stdout.write(self.style.WARNING(f" ⚠️ Could not remove {pyc_file}: {e}"))

            # Format file size
            size_mb = removed_size / (1024 * 1024)
            self.stdout.write(
                self.style.SUCCESS(
                    f" ✅ Removed {removed_count} Python cache items ({size_mb:.2f} MB)"
                )
            )
            self.stdout.write(self.style.SUCCESS(f" ✅ Removed {removed_count} Python cache items ({size_mb:.2f} MB)"))

        except Exception as e:
            self.stdout.write(
                self.style.ERROR(f" ❌ Error clearing Python cache: {e}")
            )
            self.stdout.write(self.style.ERROR(f" ❌ Error clearing Python cache: {e}"))

    def clear_static_cache(self):
        """Clear static files cache."""
@@ -251,9 +218,7 @@ class Command(BaseCommand):
                static_path = Path(static_root)

                # Calculate size
                total_size = sum(
                    f.stat().st_size for f in static_path.rglob("*") if f.is_file()
                )
                total_size = sum(f.stat().st_size for f in static_path.rglob("*") if f.is_file())
                size_mb = total_size / (1024 * 1024)

                if self.verbose:
@@ -263,22 +228,12 @@ class Command(BaseCommand):
                    shutil.rmtree(static_path)
                    static_path.mkdir(parents=True, exist_ok=True)

                self.stdout.write(
                    self.style.SUCCESS(
                        f" ✅ Cleared static files cache ({size_mb:.2f} MB)"
                    )
                )
                self.stdout.write(self.style.SUCCESS(f" ✅ Cleared static files cache ({size_mb:.2f} MB)"))
            else:
                self.stdout.write(
                    self.style.WARNING(
                        " ⚠️ No STATIC_ROOT configured or directory doesn't exist"
                    )
                )
                self.stdout.write(self.style.WARNING(" ⚠️ No STATIC_ROOT configured or directory doesn't exist"))

        except Exception as e:
            self.stdout.write(
                self.style.ERROR(f" ❌ Error clearing static cache: {e}")
            )
            self.stdout.write(self.style.ERROR(f" ❌ Error clearing static cache: {e}"))

    def clear_sessions_cache(self):
        """Clear session cache if using cache-based sessions."""
@@ -289,9 +244,7 @@ class Command(BaseCommand):

            if "cache" in session_engine:
                # Using cache-based sessions
                session_cache_alias = getattr(
                    settings, "SESSION_CACHE_ALIAS", "default"
                )
                session_cache_alias = getattr(settings, "SESSION_CACHE_ALIAS", "default")
                session_cache = caches[session_cache_alias]

                if not self.dry_run:
@@ -299,20 +252,12 @@ class Command(BaseCommand):
                    # In production, you might want more sophisticated session clearing
                    session_cache.clear()

                self.stdout.write(
                    self.style.SUCCESS(
                        f" ✅ Cleared cache-based sessions ({session_cache_alias})"
                    )
                )
                self.stdout.write(self.style.SUCCESS(f" ✅ Cleared cache-based sessions ({session_cache_alias})"))
            else:
                self.stdout.write(
                    self.style.WARNING(" ⚠️ Not using cache-based sessions")
                )
                self.stdout.write(self.style.WARNING(" ⚠️ Not using cache-based sessions"))

        except Exception as e:
            self.stdout.write(
                self.style.ERROR(f" ❌ Error clearing session cache: {e}")
            )
            self.stdout.write(self.style.ERROR(f" ❌ Error clearing session cache: {e}"))

    def clear_template_cache(self):
        """Clear template cache."""
@@ -332,18 +277,14 @@ class Command(BaseCommand):
                    # Get engine instance safely
                    engine_instance = getattr(engine, "engine", None)
                    if engine_instance:
                        template_loaders = getattr(
                            engine_instance, "template_loaders", []
                        )
                        template_loaders = getattr(engine_instance, "template_loaders", [])
                        for loader in template_loaders:
                            if isinstance(loader, CachedLoader):
                                if not self.dry_run:
                                    loader.reset()
                                cleared_engines += 1
                                if self.verbose:
                                    self.stdout.write(
                                        f" 🗑️ Cleared cached loader: {loader}"
                                    )
                                    self.stdout.write(f" 🗑️ Cleared cached loader: {loader}")

                # Check for Jinja2 engines (if present)
                elif "Jinja2" in engine_backend and hasattr(engine, "env"):
@@ -353,34 +294,21 @@ class Command(BaseCommand):
                        env.cache.clear()
                        cleared_engines += 1
                        if self.verbose:
                            self.stdout.write(
                                f" 🗑️ Cleared Jinja2 cache: {engine}"
                            )
                            self.stdout.write(f" 🗑️ Cleared Jinja2 cache: {engine}")

                except Exception as e:
                    if self.verbose:
                        self.stdout.write(
                            self.style.WARNING(
                                f" ⚠️ Could not clear cache for engine {engine}: {e}"
                            )
                        )
                        self.stdout.write(self.style.WARNING(f" ⚠️ Could not clear cache for engine {engine}: {e}"))

            if cleared_engines > 0:
                self.stdout.write(
                    self.style.SUCCESS(
                        f" ✅ Cleared template cache for "
                        f"{cleared_engines} loaders/engines"
                    )
                    self.style.SUCCESS(f" ✅ Cleared template cache for " f"{cleared_engines} loaders/engines")
                )
            else:
                self.stdout.write(
                    self.style.WARNING(" ⚠️ No cached template loaders found")
                )
                self.stdout.write(self.style.WARNING(" ⚠️ No cached template loaders found"))

        except Exception as e:
            self.stdout.write(
                self.style.ERROR(f" ❌ Error clearing template cache: {e}")
            )
            self.stdout.write(self.style.ERROR(f" ❌ Error clearing template cache: {e}"))

    def clear_tailwind_cache(self):
        """Clear Tailwind CSS build cache."""
@@ -410,27 +338,15 @@ class Command(BaseCommand):

                    cleared_count += 1
                except Exception as e:
                    self.stdout.write(
                        self.style.WARNING(
                            f" ⚠️ Could not remove {cache_path}: {e}"
                        )
                    )
                    self.stdout.write(self.style.WARNING(f" ⚠️ Could not remove {cache_path}: {e}"))

            if cleared_count > 0:
                self.stdout.write(
                    self.style.SUCCESS(
                        f" ✅ Cleared {cleared_count} Tailwind cache directories"
                    )
                )
                self.stdout.write(self.style.SUCCESS(f" ✅ Cleared {cleared_count} Tailwind cache directories"))
            else:
                self.stdout.write(
                    self.style.WARNING(" ⚠️ No Tailwind cache directories found")
                )
                self.stdout.write(self.style.WARNING(" ⚠️ No Tailwind cache directories found"))

        except Exception as e:
            self.stdout.write(
                self.style.ERROR(f" ❌ Error clearing Tailwind cache: {e}")
            )
            self.stdout.write(self.style.ERROR(f" ❌ Error clearing Tailwind cache: {e}"))

    def clear_opcache(self):
        """Clear PHP OPcache if available."""
@@ -452,21 +368,13 @@ class Command(BaseCommand):

            if result.returncode == 0:
                if "cleared" in result.stdout:
                    self.stdout.write(
                        self.style.SUCCESS(" ✅ OPcache cleared successfully")
                    )
                    self.stdout.write(self.style.SUCCESS(" ✅ OPcache cleared successfully"))
                else:
                    self.stdout.write(self.style.WARNING(" ⚠️ OPcache not available"))
            else:
                self.stdout.write(
                    self.style.WARNING(
                        " ⚠️ PHP not available or OPcache not accessible"
                    )
                )
                self.stdout.write(self.style.WARNING(" ⚠️ PHP not available or OPcache not accessible"))

        except (subprocess.TimeoutExpired, FileNotFoundError):
            self.stdout.write(
                self.style.WARNING(" ⚠️ PHP not found or not accessible")
            )
            self.stdout.write(self.style.WARNING(" ⚠️ PHP not found or not accessible"))
        except Exception as e:
            self.stdout.write(self.style.ERROR(f" ❌ Error clearing OPcache: {e}"))

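Like any management command, the cache clearer can also be driven from code. A sketch, assuming the module is registered under a name like clear_cache and exposes a --dry-run flag (both hypothetical; the argument definitions sit outside these hunks):

from django.core.management import call_command

# Preview what would be removed without deleting anything.
call_command("clear_cache", "--dry-run")
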
@@ -15,69 +15,69 @@ from apps.core.state_machine.config import callback_config


class Command(BaseCommand):
    help = 'List all registered FSM transition callbacks'
    help = "List all registered FSM transition callbacks"

    def add_arguments(self, parser: CommandParser) -> None:
        parser.add_argument(
            '--model',
            "--model",
            type=str,
            help='Filter by model name (e.g., EditSubmission, Ride)',
            help="Filter by model name (e.g., EditSubmission, Ride)",
        )
        parser.add_argument(
            '--stage',
            "--stage",
            type=str,
            choices=['pre', 'post', 'error', 'all'],
            default='all',
            help='Filter by callback stage',
            choices=["pre", "post", "error", "all"],
            default="all",
            help="Filter by callback stage",
        )
        parser.add_argument(
            '--verbose',
            '-v',
            action='store_true',
            help='Show detailed callback information',
            "--verbose",
            "-v",
            action="store_true",
            help="Show detailed callback information",
        )
        parser.add_argument(
            '--format',
            "--format",
            type=str,
            choices=['text', 'table', 'json'],
            default='text',
            help='Output format',
            choices=["text", "table", "json"],
            default="text",
            help="Output format",
        )

    def handle(self, *args, **options):
        model_filter = options.get('model')
        stage_filter = options.get('stage')
        verbose = options.get('verbose', False)
        output_format = options.get('format', 'text')
        model_filter = options.get("model")
        stage_filter = options.get("stage")
        verbose = options.get("verbose", False)
        output_format = options.get("format", "text")

        # Get all registrations
        all_registrations = callback_registry.get_all_registrations()

        if output_format == 'json':
        if output_format == "json":
            self._output_json(all_registrations, model_filter, stage_filter)
        elif output_format == 'table':
        elif output_format == "table":
            self._output_table(all_registrations, model_filter, stage_filter, verbose)
        else:
            self._output_text(all_registrations, model_filter, stage_filter, verbose)

    def _output_text(self, registrations, model_filter, stage_filter, verbose):
        """Output in text format."""
        self.stdout.write(self.style.SUCCESS('\n=== FSM Transition Callbacks ===\n'))
        self.stdout.write(self.style.SUCCESS("\n=== FSM Transition Callbacks ===\n"))

        # Group by model
        models_seen = set()
        total_callbacks = 0

        for stage in CallbackStage:
            if stage_filter != 'all' and stage.value != stage_filter:
            if stage_filter != "all" and stage.value != stage_filter:
                continue

            stage_regs = registrations.get(stage, [])
            if not stage_regs:
                continue

            self.stdout.write(self.style.WARNING(f'\n{stage.value.upper()} Callbacks:'))
            self.stdout.write('-' * 50)
            self.stdout.write(self.style.WARNING(f"\n{stage.value.upper()} Callbacks:"))
            self.stdout.write("-" * 50)

            # Group by model
            by_model = {}
@@ -92,42 +92,34 @@ class Command(BaseCommand):
                total_callbacks += 1

            for model_name, regs in sorted(by_model.items()):
                self.stdout.write(f'\n {model_name}:')
                self.stdout.write(f"\n {model_name}:")
                for reg in regs:
                    transition = f'{reg.source} → {reg.target}'
                    transition = f"{reg.source} → {reg.target}"
                    callback_name = reg.callback.name
                    priority = reg.callback.priority

                    self.stdout.write(
                        f' [{transition}] {callback_name} (priority: {priority})'
                    )
                    self.stdout.write(f" [{transition}] {callback_name} (priority: {priority})")

                    if verbose:
                        self.stdout.write(
                            f' continue_on_error: {reg.callback.continue_on_error}'
                        )
                        if hasattr(reg.callback, 'patterns'):
                            self.stdout.write(
                                f' patterns: {reg.callback.patterns}'
                            )
                        self.stdout.write(f" continue_on_error: {reg.callback.continue_on_error}")
                        if hasattr(reg.callback, "patterns"):
                            self.stdout.write(f" patterns: {reg.callback.patterns}")

        # Summary
        self.stdout.write('\n' + '=' * 50)
        self.stdout.write(self.style.SUCCESS(
            f'Total: {total_callbacks} callbacks across {len(models_seen)} models'
        ))
        self.stdout.write("\n" + "=" * 50)
        self.stdout.write(self.style.SUCCESS(f"Total: {total_callbacks} callbacks across {len(models_seen)} models"))

        # Configuration status
        self.stdout.write(self.style.WARNING('\nConfiguration Status:'))
        self.stdout.write(f' Callbacks enabled: {callback_config.enabled}')
        self.stdout.write(f' Notifications enabled: {callback_config.notifications_enabled}')
        self.stdout.write(f' Cache invalidation enabled: {callback_config.cache_invalidation_enabled}')
        self.stdout.write(f' Related updates enabled: {callback_config.related_updates_enabled}')
        self.stdout.write(f' Debug mode: {callback_config.debug_mode}')
        self.stdout.write(self.style.WARNING("\nConfiguration Status:"))
        self.stdout.write(f" Callbacks enabled: {callback_config.enabled}")
        self.stdout.write(f" Notifications enabled: {callback_config.notifications_enabled}")
        self.stdout.write(f" Cache invalidation enabled: {callback_config.cache_invalidation_enabled}")
        self.stdout.write(f" Related updates enabled: {callback_config.related_updates_enabled}")
        self.stdout.write(f" Debug mode: {callback_config.debug_mode}")

    def _output_table(self, registrations, model_filter, stage_filter, verbose):
        """Output in table format."""
        self.stdout.write(self.style.SUCCESS('\n=== FSM Transition Callbacks ===\n'))
        self.stdout.write(self.style.SUCCESS("\n=== FSM Transition Callbacks ===\n"))

        # Header
        if verbose:
@@ -136,10 +128,10 @@ class Command(BaseCommand):
            header = f"{'Model':<20} {'Source':<15} {'Target':<15} {'Stage':<8} {'Callback':<30}"

        self.stdout.write(self.style.WARNING(header))
        self.stdout.write('-' * len(header))
        self.stdout.write("-" * len(header))

        for stage in CallbackStage:
            if stage_filter != 'all' and stage.value != stage_filter:
            if stage_filter != "all" and stage.value != stage_filter:
                continue

            stage_regs = registrations.get(stage, [])
@@ -160,18 +152,18 @@ class Command(BaseCommand):
        import json

        output = {
            'callbacks': [],
            'configuration': {
                'enabled': callback_config.enabled,
                'notifications_enabled': callback_config.notifications_enabled,
                'cache_invalidation_enabled': callback_config.cache_invalidation_enabled,
                'related_updates_enabled': callback_config.related_updates_enabled,
                'debug_mode': callback_config.debug_mode,
            }
            "callbacks": [],
            "configuration": {
                "enabled": callback_config.enabled,
                "notifications_enabled": callback_config.notifications_enabled,
                "cache_invalidation_enabled": callback_config.cache_invalidation_enabled,
                "related_updates_enabled": callback_config.related_updates_enabled,
                "debug_mode": callback_config.debug_mode,
            },
        }

        for stage in CallbackStage:
            if stage_filter != 'all' and stage.value != stage_filter:
            if stage_filter != "all" and stage.value != stage_filter:
                continue

            stage_regs = registrations.get(stage, [])
@@ -180,15 +172,17 @@ class Command(BaseCommand):
                if model_filter and model_name != model_filter:
                    continue

                output['callbacks'].append({
                    'model': model_name,
                    'field': reg.field_name,
                    'source': reg.source,
                    'target': reg.target,
                    'stage': stage.value,
                    'callback': reg.callback.name,
                    'priority': reg.callback.priority,
                    'continue_on_error': reg.callback.continue_on_error,
                })
                output["callbacks"].append(
                    {
                        "model": model_name,
                        "field": reg.field_name,
                        "source": reg.source,
                        "target": reg.target,
                        "stage": stage.value,
                        "callback": reg.callback.name,
                        "priority": reg.callback.priority,
                        "continue_on_error": reg.callback.continue_on_error,
                    }
                )

        self.stdout.write(json.dumps(output, indent=2))

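With --format json the command emits the machine-readable report built above. A consumption sketch, assuming the command is installed as list_fsm_callbacks (the actual module name is not visible in this diff):

import json
from io import StringIO

from django.core.management import call_command

buf = StringIO()
call_command("list_fsm_callbacks", "--format", "json", "--model", "Ride", stdout=buf)
report = json.loads(buf.getvalue())
print(report["configuration"]["enabled"], len(report["callbacks"]))
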
@@ -52,26 +52,17 @@ class Command(BaseCommand):
            import rjsmin
        except ImportError:
            rjsmin = None
            self.stdout.write(
                self.style.WARNING(
                    "rjsmin not installed. Install with: pip install rjsmin"
                )
            )
            self.stdout.write(self.style.WARNING("rjsmin not installed. Install with: pip install rjsmin"))

        try:
            import rcssmin
        except ImportError:
            rcssmin = None
            self.stdout.write(
                self.style.WARNING(
                    "rcssmin not installed. Install with: pip install rcssmin"
                )
            )
            self.stdout.write(self.style.WARNING("rcssmin not installed. Install with: pip install rcssmin"))

        if not rjsmin and not rcssmin:
            raise CommandError(
                "Neither rjsmin nor rcssmin is installed. "
                "Install at least one: pip install rjsmin rcssmin"
                "Neither rjsmin nor rcssmin is installed. " "Install at least one: pip install rjsmin rcssmin"
            )

        # Get static file directories
@@ -93,9 +84,7 @@ class Command(BaseCommand):
            if not css_only and rjsmin:
                js_dir = static_dir / "js"
                if js_dir.exists():
                    saved, count = self._process_js_files(
                        js_dir, rjsmin, dry_run, force
                    )
                    saved, count = self._process_js_files(js_dir, rjsmin, dry_run, force)
                    total_js_saved += saved
                    js_files_processed += count

@@ -103,9 +92,7 @@ class Command(BaseCommand):
            if not js_only and rcssmin:
                css_dir = static_dir / "css"
                if css_dir.exists():
                    saved, count = self._process_css_files(
                        css_dir, rcssmin, dry_run, force
                    )
                    saved, count = self._process_css_files(css_dir, rcssmin, dry_run, force)
                    total_css_saved += saved
                    css_files_processed += count

@@ -114,17 +101,11 @@ class Command(BaseCommand):
        self.stdout.write(self.style.SUCCESS("Static file optimization complete!"))
        self.stdout.write(f"JavaScript files processed: {js_files_processed}")
        self.stdout.write(f"CSS files processed: {css_files_processed}")
        self.stdout.write(
            f"Total JS savings: {self._format_size(total_js_saved)}"
        )
        self.stdout.write(
            f"Total CSS savings: {self._format_size(total_css_saved)}"
        )
        self.stdout.write(f"Total JS savings: {self._format_size(total_js_saved)}")
        self.stdout.write(f"Total CSS savings: {self._format_size(total_css_saved)}")

        if dry_run:
            self.stdout.write(
                self.style.WARNING("\nDry run - no files were modified")
            )
            self.stdout.write(self.style.WARNING("\nDry run - no files were modified"))

    def _process_js_files(self, js_dir, rjsmin, dry_run, force):
        """Process JavaScript files for minification."""
@@ -140,9 +121,7 @@ class Command(BaseCommand):

            # Skip if minified version exists and not forcing
            if min_file.exists() and not force:
                self.stdout.write(
                    f" Skipping {js_file.name} (min version exists)"
                )
                self.stdout.write(f" Skipping {js_file.name} (min version exists)")
                continue

            try:
@@ -169,9 +148,7 @@ class Command(BaseCommand):
                files_processed += 1

            except Exception as e:
                self.stdout.write(
                    self.style.ERROR(f" Error processing {js_file.name}: {e}")
                )
                self.stdout.write(self.style.ERROR(f" Error processing {js_file.name}: {e}"))

        return total_saved, files_processed

@@ -189,9 +166,7 @@ class Command(BaseCommand):

            # Skip if minified version exists and not forcing
            if min_file.exists() and not force:
                self.stdout.write(
                    f" Skipping {css_file.name} (min version exists)"
                )
                self.stdout.write(f" Skipping {css_file.name} (min version exists)")
                continue

            try:
@@ -218,9 +193,7 @@ class Command(BaseCommand):
                files_processed += 1

            except Exception as e:
                self.stdout.write(
                    self.style.ERROR(f" Error processing {css_file.name}: {e}")
                )
                self.stdout.write(self.style.ERROR(f" Error processing {css_file.name}: {e}"))

        return total_saved, files_processed


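Both helpers wrap the same skip-if-minified loop around a pure string transform. The core rjsmin pass in isolation, as a sketch with an assumed file path:

from pathlib import Path

import rjsmin  # pip install rjsmin

js_path = Path("static/js/app.js")  # illustrative path
source = js_path.read_text()
minified = rjsmin.jsmin(source)  # strips comments and redundant whitespace
js_path.with_suffix(".min.js").write_text(minified)
saved = len(source) - len(minified)  # bytes saved, as the command reports
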
@@ -39,19 +39,13 @@ class Command(BaseCommand):
    def handle(self, *args, **options):
        """Run the development setup and start the server."""
        if not options["skip_setup"]:
            self.stdout.write(
                self.style.SUCCESS(
                    "🚀 Setting up and starting ThrillWiki Development Server..."
                )
            )
            self.stdout.write(self.style.SUCCESS("🚀 Setting up and starting ThrillWiki Development Server..."))

            # Run the setup_dev command first
            execute_from_command_line(["manage.py", "setup_dev"])

        else:
            self.stdout.write(
                self.style.SUCCESS("🚀 Starting ThrillWiki Development Server...")
            )
            self.stdout.write(self.style.SUCCESS("🚀 Starting ThrillWiki Development Server..."))

        # Determine which server command to use
        self.get_server_command(options)
@@ -59,9 +53,7 @@ class Command(BaseCommand):
        # Start the server
        self.stdout.write("")
        self.stdout.write(
            self.style.SUCCESS(
                f"🌟 Starting Django development server on http://{options['host']}:{options['port']}"
            )
            self.style.SUCCESS(f"🌟 Starting Django development server on http://{options['host']}:{options['port']}")
        )
        self.stdout.write("Press Ctrl+C to stop the server")
        self.stdout.write("")
@@ -76,9 +68,7 @@ class Command(BaseCommand):
                ]
            )
        else:
            execute_from_command_line(
                ["manage.py", "runserver", f"{options['host']}:{options['port']}"]
            )
            execute_from_command_line(["manage.py", "runserver", f"{options['host']}:{options['port']}"])
        except KeyboardInterrupt:
            self.stdout.write("")
            self.stdout.write(self.style.SUCCESS("👋 Development server stopped"))

@@ -16,23 +16,23 @@ from django.core.management.base import BaseCommand


class Command(BaseCommand):
    help = 'Run security audit and generate a report'
    help = "Run security audit and generate a report"

    def add_arguments(self, parser):
        parser.add_argument(
            '--output',
            "--output",
            type=str,
            help='Output file for the security report',
            help="Output file for the security report",
        )
        parser.add_argument(
            '--verbose',
            action='store_true',
            help='Show detailed information for each check',
            "--verbose",
            action="store_true",
            help="Show detailed information for each check",
        )

    def handle(self, *args, **options):
        self.verbose = options.get('verbose', False)
        output_file = options.get('output')
        self.verbose = options.get("verbose", False)
        output_file = options.get("output")

        report_lines = []

@@ -66,11 +66,9 @@ class Command(BaseCommand):

        # Write to file if specified
        if output_file:
            with open(output_file, 'w') as f:
                f.write('\n'.join(report_lines))
            self.stdout.write(
                self.style.SUCCESS(f'\nReport saved to: {output_file}')
            )
            with open(output_file, "w") as f:
                f.write("\n".join(report_lines))
            self.stdout.write(self.style.SUCCESS(f"\nReport saved to: {output_file}"))

    def log(self, message, report_lines):
        """Log message to both stdout and report."""
@@ -82,10 +80,7 @@ class Command(BaseCommand):
        errors = registry.run_checks(tags=[Tags.security])

        if not errors:
            self.log(
                self.style.SUCCESS(" ✓ All Django security checks passed"),
                report_lines
            )
            self.log(self.style.SUCCESS(" ✓ All Django security checks passed"), report_lines)
        else:
            for error in errors:
                prefix = self.style.ERROR(" ✗ ERROR") if error.is_serious() else self.style.WARNING(" ! WARNING")
@@ -97,73 +92,71 @@ class Command(BaseCommand):
    def check_configuration(self, report_lines):
        """Check various configuration settings."""
        checks = [
            ('DEBUG mode', not settings.DEBUG, 'DEBUG should be False'),
            ("DEBUG mode", not settings.DEBUG, "DEBUG should be False"),
            ("SECRET_KEY length", len(settings.SECRET_KEY) >= 50, f"Length: {len(settings.SECRET_KEY)}"),
            (
                'SECRET_KEY length',
                len(settings.SECRET_KEY) >= 50,
                f'Length: {len(settings.SECRET_KEY)}'
                "ALLOWED_HOSTS",
                bool(settings.ALLOWED_HOSTS) and "*" not in settings.ALLOWED_HOSTS,
                str(settings.ALLOWED_HOSTS),
            ),
            (
                'ALLOWED_HOSTS',
                bool(settings.ALLOWED_HOSTS) and '*' not in settings.ALLOWED_HOSTS,
                str(settings.ALLOWED_HOSTS)
                "CSRF_TRUSTED_ORIGINS",
                bool(getattr(settings, "CSRF_TRUSTED_ORIGINS", [])),
                str(getattr(settings, "CSRF_TRUSTED_ORIGINS", [])),
            ),
            (
                'CSRF_TRUSTED_ORIGINS',
                bool(getattr(settings, 'CSRF_TRUSTED_ORIGINS', [])),
                str(getattr(settings, 'CSRF_TRUSTED_ORIGINS', []))
                "X_FRAME_OPTIONS",
                getattr(settings, "X_FRAME_OPTIONS", "") in ("DENY", "SAMEORIGIN"),
                str(getattr(settings, "X_FRAME_OPTIONS", "Not set")),
            ),
            (
                'X_FRAME_OPTIONS',
                getattr(settings, 'X_FRAME_OPTIONS', '') in ('DENY', 'SAMEORIGIN'),
                str(getattr(settings, 'X_FRAME_OPTIONS', 'Not set'))
                "SECURE_CONTENT_TYPE_NOSNIFF",
                getattr(settings, "SECURE_CONTENT_TYPE_NOSNIFF", False),
                str(getattr(settings, "SECURE_CONTENT_TYPE_NOSNIFF", False)),
            ),
            (
                'SECURE_CONTENT_TYPE_NOSNIFF',
                getattr(settings, 'SECURE_CONTENT_TYPE_NOSNIFF', False),
                str(getattr(settings, 'SECURE_CONTENT_TYPE_NOSNIFF', False))
                "SECURE_BROWSER_XSS_FILTER",
                getattr(settings, "SECURE_BROWSER_XSS_FILTER", False),
                str(getattr(settings, "SECURE_BROWSER_XSS_FILTER", False)),
            ),
            (
                'SECURE_BROWSER_XSS_FILTER',
                getattr(settings, 'SECURE_BROWSER_XSS_FILTER', False),
                str(getattr(settings, 'SECURE_BROWSER_XSS_FILTER', False))
                "SESSION_COOKIE_HTTPONLY",
                getattr(settings, "SESSION_COOKIE_HTTPONLY", True),
                str(getattr(settings, "SESSION_COOKIE_HTTPONLY", "Not set")),
            ),
            (
                'SESSION_COOKIE_HTTPONLY',
                getattr(settings, 'SESSION_COOKIE_HTTPONLY', True),
                str(getattr(settings, 'SESSION_COOKIE_HTTPONLY', 'Not set'))
            ),
            (
                'CSRF_COOKIE_HTTPONLY',
                getattr(settings, 'CSRF_COOKIE_HTTPONLY', True),
                str(getattr(settings, 'CSRF_COOKIE_HTTPONLY', 'Not set'))
                "CSRF_COOKIE_HTTPONLY",
                getattr(settings, "CSRF_COOKIE_HTTPONLY", True),
                str(getattr(settings, "CSRF_COOKIE_HTTPONLY", "Not set")),
            ),
        ]

        # Production-only checks
        if not settings.DEBUG:
            checks.extend([
                (
                    'SECURE_SSL_REDIRECT',
                    getattr(settings, 'SECURE_SSL_REDIRECT', False),
                    str(getattr(settings, 'SECURE_SSL_REDIRECT', False))
                ),
                (
                    'SESSION_COOKIE_SECURE',
                    getattr(settings, 'SESSION_COOKIE_SECURE', False),
                    str(getattr(settings, 'SESSION_COOKIE_SECURE', False))
                ),
                (
                    'CSRF_COOKIE_SECURE',
                    getattr(settings, 'CSRF_COOKIE_SECURE', False),
                    str(getattr(settings, 'CSRF_COOKIE_SECURE', False))
                ),
                (
                    'SECURE_HSTS_SECONDS',
                    getattr(settings, 'SECURE_HSTS_SECONDS', 0) >= 31536000,
                    str(getattr(settings, 'SECURE_HSTS_SECONDS', 0))
                ),
            ])
            checks.extend(
                [
                    (
                        "SECURE_SSL_REDIRECT",
                        getattr(settings, "SECURE_SSL_REDIRECT", False),
                        str(getattr(settings, "SECURE_SSL_REDIRECT", False)),
                    ),
                    (
                        "SESSION_COOKIE_SECURE",
                        getattr(settings, "SESSION_COOKIE_SECURE", False),
                        str(getattr(settings, "SESSION_COOKIE_SECURE", False)),
                    ),
                    (
                        "CSRF_COOKIE_SECURE",
                        getattr(settings, "CSRF_COOKIE_SECURE", False),
                        str(getattr(settings, "CSRF_COOKIE_SECURE", False)),
                    ),
                    (
                        "SECURE_HSTS_SECONDS",
                        getattr(settings, "SECURE_HSTS_SECONDS", 0) >= 31536000,
                        str(getattr(settings, "SECURE_HSTS_SECONDS", 0)),
                    ),
                ]
            )

        for name, is_secure, value in checks:
            status = self.style.SUCCESS("✓") if is_secure else self.style.WARNING("!")
@@ -176,59 +169,43 @@ class Command(BaseCommand):
|
||||
|
||||
def check_middleware(self, report_lines):
|
||||
"""Check security-related middleware is properly configured."""
|
||||
middleware = getattr(settings, 'MIDDLEWARE', [])
|
||||
middleware = getattr(settings, "MIDDLEWARE", [])
|
||||
|
||||
required_middleware = [
|
||||
('django.middleware.security.SecurityMiddleware', 'SecurityMiddleware'),
|
||||
('django.middleware.csrf.CsrfViewMiddleware', 'CSRF Middleware'),
|
||||
('django.middleware.clickjacking.XFrameOptionsMiddleware', 'X-Frame-Options'),
|
||||
("django.middleware.security.SecurityMiddleware", "SecurityMiddleware"),
|
||||
("django.middleware.csrf.CsrfViewMiddleware", "CSRF Middleware"),
|
||||
("django.middleware.clickjacking.XFrameOptionsMiddleware", "X-Frame-Options"),
|
||||
]
|
||||
|
||||
custom_security_middleware = [
|
||||
('apps.core.middleware.security_headers.SecurityHeadersMiddleware', 'Security Headers'),
|
||||
('apps.core.middleware.rate_limiting.AuthRateLimitMiddleware', 'Rate Limiting'),
|
||||
("apps.core.middleware.security_headers.SecurityHeadersMiddleware", "Security Headers"),
|
||||
("apps.core.middleware.rate_limiting.AuthRateLimitMiddleware", "Rate Limiting"),
|
||||
]
|
||||
|
||||
# Check required middleware
|
||||
for mw_path, mw_name in required_middleware:
|
||||
if mw_path in middleware:
|
||||
self.log(
|
||||
f" {self.style.SUCCESS('✓')} {mw_name} is enabled",
|
||||
report_lines
|
||||
)
|
||||
self.log(f" {self.style.SUCCESS('✓')} {mw_name} is enabled", report_lines)
|
||||
else:
|
||||
self.log(
|
||||
f" {self.style.ERROR('✗')} {mw_name} is NOT enabled",
|
||||
report_lines
|
||||
)
|
||||
self.log(f" {self.style.ERROR('✗')} {mw_name} is NOT enabled", report_lines)
|
||||
|
||||
# Check custom security middleware
|
||||
for mw_path, mw_name in custom_security_middleware:
|
||||
if mw_path in middleware:
|
||||
self.log(
|
||||
f" {self.style.SUCCESS('✓')} {mw_name} is enabled",
|
||||
report_lines
|
||||
)
|
||||
self.log(f" {self.style.SUCCESS('✓')} {mw_name} is enabled", report_lines)
|
||||
else:
|
||||
self.log(
|
||||
f" {self.style.WARNING('!')} {mw_name} is not enabled (optional)",
|
||||
report_lines
|
||||
)
|
||||
self.log(f" {self.style.WARNING('!')} {mw_name} is not enabled (optional)", report_lines)
|
||||
|
||||
# Check middleware order
|
||||
try:
|
||||
security_idx = middleware.index('django.middleware.security.SecurityMiddleware')
|
||||
session_idx = middleware.index('django.contrib.sessions.middleware.SessionMiddleware')
|
||||
security_idx = middleware.index("django.middleware.security.SecurityMiddleware")
|
||||
session_idx = middleware.index("django.contrib.sessions.middleware.SessionMiddleware")
|
||||
|
||||
if security_idx < session_idx:
|
||||
self.log(
|
||||
f" {self.style.SUCCESS('✓')} Middleware ordering is correct",
|
||||
report_lines
|
||||
)
|
||||
self.log(f" {self.style.SUCCESS('✓')} Middleware ordering is correct", report_lines)
|
||||
else:
|
||||
self.log(
|
||||
f" {self.style.WARNING('!')} SecurityMiddleware should come before SessionMiddleware",
|
||||
report_lines
|
||||
f" {self.style.WARNING('!')} SecurityMiddleware should come before SessionMiddleware", report_lines
|
||||
)
|
||||
except ValueError:
|
||||
pass # Middleware not found, already reported above
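
Each configuration audit entry is a plain (name, is_secure, value) tuple, so extending the audit is just appending more tuples. A minimal sketch of that pattern, using SECURE_REFERRER_POLICY purely as an illustrative extra setting:

from django.conf import settings

def extra_checks():
    # Same (name, is_secure, value) shape the audit loop consumes;
    # SECURE_REFERRER_POLICY is only an illustrative setting here.
    return [
        (
            "SECURE_REFERRER_POLICY",
            bool(getattr(settings, "SECURE_REFERRER_POLICY", "")),
            str(getattr(settings, "SECURE_REFERRER_POLICY", "Not set")),
        ),
    ]

for name, is_secure, value in extra_checks():
    marker = "✓" if is_secure else "!"
    print(f" {marker} {name}: {value}")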

@@ -39,9 +39,7 @@ class Command(BaseCommand):

def handle(self, *args, **options):
"""Run the development setup process."""
self.stdout.write(
self.style.SUCCESS("🚀 Setting up ThrillWiki Development Environment...")
)
self.stdout.write(self.style.SUCCESS("🚀 Setting up ThrillWiki Development Environment..."))

# Create necessary directories
self.create_directories()
@@ -71,9 +69,7 @@ class Command(BaseCommand):
# Display environment info
self.display_environment_info()

self.stdout.write(
self.style.SUCCESS("✅ Development environment setup complete!")
)
self.stdout.write(self.style.SUCCESS("✅ Development environment setup complete!"))

def create_directories(self):
"""Create necessary directories."""
@@ -99,36 +95,24 @@ class Command(BaseCommand):
)

if result.returncode == 0:
self.stdout.write(
self.style.SUCCESS("✅ Database migrations are up to date")
)
self.stdout.write(self.style.SUCCESS("✅ Database migrations are up to date"))
else:
self.stdout.write("🔄 Running database migrations...")
subprocess.run(
["uv", "run", "manage.py", "migrate", "--noinput"], check=True
)
self.stdout.write(
self.style.SUCCESS("✅ Database migrations completed")
)
subprocess.run(["uv", "run", "manage.py", "migrate", "--noinput"], check=True)
self.stdout.write(self.style.SUCCESS("✅ Database migrations completed"))

except subprocess.CalledProcessError as e:
self.stdout.write(
self.style.WARNING(f"⚠️ Migration error (continuing): {e}")
)
self.stdout.write(self.style.WARNING(f"⚠️ Migration error (continuing): {e}"))

def seed_sample_data(self):
"""Seed sample data to the database."""
self.stdout.write("🌱 Seeding sample data...")

try:
subprocess.run(
["uv", "run", "manage.py", "seed_sample_data"], check=True
)
subprocess.run(["uv", "run", "manage.py", "seed_sample_data"], check=True)
self.stdout.write(self.style.SUCCESS("✅ Sample data seeded"))
except subprocess.CalledProcessError:
self.stdout.write(
self.style.WARNING("⚠️ Could not seed sample data (continuing)")
)
self.stdout.write(self.style.WARNING("⚠️ Could not seed sample data (continuing)"))

def create_superuser(self):
"""Create development superuser if it doesn't exist."""
@@ -145,13 +129,9 @@ class Command(BaseCommand):
self.stdout.write("👤 Creating development superuser (admin/admin)...")
if not User.objects.filter(username="admin").exists():
User.objects.create_superuser("admin", "admin@example.com", "admin")
self.stdout.write(
self.style.SUCCESS("✅ Created superuser: admin/admin")
)
self.stdout.write(self.style.SUCCESS("✅ Created superuser: admin/admin"))
else:
self.stdout.write(
self.style.SUCCESS("✅ Admin user already exists")
)
self.stdout.write(self.style.SUCCESS("✅ Admin user already exists"))

except Exception as e:
self.stdout.write(self.style.WARNING(f"⚠️ Could not create superuser: {e}"))
@@ -167,9 +147,7 @@ class Command(BaseCommand):
)
self.stdout.write(self.style.SUCCESS("✅ Static files collected"))
except subprocess.CalledProcessError as e:
self.stdout.write(
self.style.WARNING(f"⚠️ Could not collect static files: {e}")
)
self.stdout.write(self.style.WARNING(f"⚠️ Could not collect static files: {e}"))

def build_tailwind(self):
"""Build Tailwind CSS if npm is available."""
@@ -180,17 +158,11 @@ class Command(BaseCommand):
subprocess.run(["npm", "--version"], capture_output=True, check=True)

# Build Tailwind CSS
subprocess.run(
["uv", "run", "manage.py", "tailwind", "build"], check=True
)
subprocess.run(["uv", "run", "manage.py", "tailwind", "build"], check=True)
self.stdout.write(self.style.SUCCESS("✅ Tailwind CSS built"))

except (subprocess.CalledProcessError, FileNotFoundError):
self.stdout.write(
self.style.WARNING(
"⚠️ npm not found or Tailwind build failed, skipping"
)
)
self.stdout.write(self.style.WARNING("⚠️ npm not found or Tailwind build failed, skipping"))

def run_system_checks(self):
"""Run Django system checks."""
@@ -200,9 +172,7 @@ class Command(BaseCommand):
subprocess.run(["uv", "run", "manage.py", "check"], check=True)
self.stdout.write(self.style.SUCCESS("✅ System checks passed"))
except subprocess.CalledProcessError:
self.stdout.write(
self.style.WARNING("❌ System checks failed, but continuing...")
)
self.stdout.write(self.style.WARNING("❌ System checks failed, but continuing..."))

def display_environment_info(self):
"""Display development environment information."""

@@ -18,62 +18,62 @@ from apps.core.state_machine.monitoring import callback_monitor


class Command(BaseCommand):
help = 'Test FSM transition callbacks for specific transitions'
help = "Test FSM transition callbacks for specific transitions"

def add_arguments(self, parser: CommandParser) -> None:
parser.add_argument(
'model',
"model",
type=str,
help='Model name (e.g., EditSubmission, Ride, Park)',
help="Model name (e.g., EditSubmission, Ride, Park)",
)
parser.add_argument(
'source',
"source",
type=str,
help='Source state value',
help="Source state value",
)
parser.add_argument(
'target',
"target",
type=str,
help='Target state value',
help="Target state value",
)
parser.add_argument(
'--instance-id',
"--instance-id",
type=int,
help='ID of an existing instance to use for testing',
help="ID of an existing instance to use for testing",
)
parser.add_argument(
'--user-id',
"--user-id",
type=int,
help='ID of user to use for testing',
help="ID of user to use for testing",
)
parser.add_argument(
'--dry-run',
action='store_true',
help='Show what would be executed without running callbacks',
"--dry-run",
action="store_true",
help="Show what would be executed without running callbacks",
)
parser.add_argument(
'--stage',
"--stage",
type=str,
choices=['pre', 'post', 'error', 'all'],
default='all',
help='Which callback stage to test',
choices=["pre", "post", "error", "all"],
default="all",
help="Which callback stage to test",
)
parser.add_argument(
'--field',
"--field",
type=str,
default='status',
help='FSM field name (default: status)',
default="status",
help="FSM field name (default: status)",
)

def handle(self, *args, **options):
model_name = options['model']
source = options['source']
target = options['target']
instance_id = options.get('instance_id')
user_id = options.get('user_id')
dry_run = options.get('dry_run', False)
stage_filter = options.get('stage', 'all')
field_name = options.get('field', 'status')
model_name = options["model"]
source = options["source"]
target = options["target"]
instance_id = options.get("instance_id")
user_id = options.get("user_id")
dry_run = options.get("dry_run", False)
stage_filter = options.get("stage", "all")
field_name = options.get("field", "status")

# Find the model class
model_class = self._find_model(model_name)
@@ -90,7 +90,7 @@ class Command(BaseCommand):
try:
user = User.objects.get(pk=user_id)
except User.DoesNotExist:
raise CommandError(f"User with ID {user_id} not found")
raise CommandError(f"User with ID {user_id} not found") from None

# Create transition context
context = TransitionContext(
@@ -101,19 +101,21 @@ class Command(BaseCommand):
user=user,
)

self.stdout.write(self.style.SUCCESS(
f'\n=== Testing Transition Callbacks ===\n'
f'Model: {model_name}\n'
f'Transition: {source} → {target}\n'
f'Field: {field_name}\n'
f'Instance: {instance}\n'
f'User: {user}\n'
f'Dry Run: {dry_run}\n'
))
self.stdout.write(
self.style.SUCCESS(
f"\n=== Testing Transition Callbacks ===\n"
f"Model: {model_name}\n"
f"Transition: {source} → {target}\n"
f"Field: {field_name}\n"
f"Instance: {instance}\n"
f"User: {user}\n"
f"Dry Run: {dry_run}\n"
)
)

# Get callbacks for each stage
stages_to_test = []
if stage_filter == 'all':
if stage_filter == "all":
stages_to_test = [CallbackStage.PRE, CallbackStage.POST, CallbackStage.ERROR]
else:
stages_to_test = [CallbackStage(stage_filter)]
@@ -123,83 +125,69 @@ class Command(BaseCommand):
total_failures = 0

for stage in stages_to_test:
callbacks = callback_registry.get_callbacks(
model_class, field_name, source, target, stage
)
callbacks = callback_registry.get_callbacks(model_class, field_name, source, target, stage)

if not callbacks:
self.stdout.write(
self.style.WARNING(f'\nNo {stage.value.upper()} callbacks registered')
)
self.stdout.write(self.style.WARNING(f"\nNo {stage.value.upper()} callbacks registered"))
continue

self.stdout.write(
self.style.WARNING(f'\n{stage.value.upper()} Callbacks ({len(callbacks)}):')
)
self.stdout.write('-' * 50)
self.stdout.write(self.style.WARNING(f"\n{stage.value.upper()} Callbacks ({len(callbacks)}):"))
self.stdout.write("-" * 50)

for callback in callbacks:
total_callbacks += 1
callback_info = (
f' {callback.name} (priority: {callback.priority}, '
f'continue_on_error: {callback.continue_on_error})'
f" {callback.name} (priority: {callback.priority}, "
f"continue_on_error: {callback.continue_on_error})"
)

if dry_run:
self.stdout.write(callback_info)
self.stdout.write(self.style.NOTICE(' → Would execute (dry run)'))
self.stdout.write(self.style.NOTICE(" → Would execute (dry run)"))
else:
self.stdout.write(callback_info)

# Check should_execute
if not callback.should_execute(context):
self.stdout.write(
self.style.WARNING(' → Skipped (should_execute returned False)')
)
self.stdout.write(self.style.WARNING(" → Skipped (should_execute returned False)"))
continue

# Execute callback
try:
if stage == CallbackStage.ERROR:
result = callback.execute(
context,
exception=Exception("Test exception")
)
result = callback.execute(context, exception=Exception("Test exception"))
else:
result = callback.execute(context)

if result:
self.stdout.write(self.style.SUCCESS(' → Success'))
self.stdout.write(self.style.SUCCESS(" → Success"))
total_success += 1
else:
self.stdout.write(self.style.ERROR(' → Failed (returned False)'))
self.stdout.write(self.style.ERROR(" → Failed (returned False)"))
total_failures += 1

except Exception as e:
self.stdout.write(
self.style.ERROR(f' → Exception: {type(e).__name__}: {e}')
)
self.stdout.write(self.style.ERROR(f" → Exception: {type(e).__name__}: {e}"))
total_failures += 1

# Summary
self.stdout.write('\n' + '=' * 50)
self.stdout.write(self.style.SUCCESS(f'Total callbacks: {total_callbacks}'))
self.stdout.write("\n" + "=" * 50)
self.stdout.write(self.style.SUCCESS(f"Total callbacks: {total_callbacks}"))
if not dry_run:
self.stdout.write(self.style.SUCCESS(f'Successful: {total_success}'))
self.stdout.write(self.style.SUCCESS(f"Successful: {total_success}"))
self.stdout.write(
self.style.ERROR(f'Failed: {total_failures}') if total_failures
else self.style.SUCCESS(f'Failed: {total_failures}')
self.style.ERROR(f"Failed: {total_failures}")
if total_failures
else self.style.SUCCESS(f"Failed: {total_failures}")
)

# Show monitoring stats if available
if not dry_run:
self.stdout.write(self.style.WARNING('\nRecent Executions:'))
self.stdout.write(self.style.WARNING("\nRecent Executions:"))
recent = callback_monitor.get_recent_executions(limit=10)
for record in recent:
status = '✓' if record.success else '✗'
self.stdout.write(
f' {status} {record.callback_name} [{record.duration_ms:.2f}ms]'
)
status = "✓" if record.success else "✗"
self.stdout.write(f" {status} {record.callback_name} [{record.duration_ms:.2f}ms]")

def _find_model(self, model_name):
"""Find a model class by name."""
@@ -217,9 +205,7 @@ class Command(BaseCommand):
try:
return model_class.objects.get(pk=instance_id)
except model_class.DoesNotExist:
raise CommandError(
f"{model_class.__name__} with ID {instance_id} not found"
)
raise CommandError(f"{model_class.__name__} with ID {instance_id} not found") from None

# Create a mock instance for testing
# This won't be saved to the database
@@ -227,8 +213,6 @@ class Command(BaseCommand):
instance.pk = 0  # Fake ID
setattr(instance, field_name, source)

self.stdout.write(self.style.NOTICE(
'Using mock instance (no --instance-id provided)'
))
self.stdout.write(self.style.NOTICE("Using mock instance (no --instance-id provided)"))

return instance
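
A dry run of this command only enumerates what is registered, without executing anything. A trimmed sketch of that path; the registry import path and the EditSubmission/state values are assumptions, since only the monitoring import is visible above:

# Assumed import path for the registry; the diff only shows the monitoring import.
from apps.core.state_machine.callbacks import CallbackStage, callback_registry
from apps.moderation.models import EditSubmission  # illustrative model

def list_callbacks(model_class, field_name="status", source="pending", target="approved"):
    # Mirrors the dry-run loop: list registered callbacks per stage without executing.
    for stage in (CallbackStage.PRE, CallbackStage.POST, CallbackStage.ERROR):
        for cb in callback_registry.get_callbacks(model_class, field_name, source, target, stage):
            print(f"{stage.value}: {cb.name} (priority: {cb.priority}) → would execute (dry run)")

list_callbacks(EditSubmission)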

@@ -37,9 +37,7 @@ class Command(BaseCommand):
self.test_trending_algorithm()
self.test_api_format()

self.stdout.write(
self.style.SUCCESS("✓ Trending system test completed successfully!")
)
self.stdout.write(self.style.SUCCESS("✓ Trending system test completed successfully!"))

def clean_test_data(self):
"""Clean existing test data."""
@@ -101,9 +99,7 @@ class Command(BaseCommand):
# Create parks
parks = []
for park_data in parks_data:
park, created = Park.objects.get_or_create(
name=park_data["name"], defaults=park_data
)
park, created = Park.objects.get_or_create(name=park_data["name"], defaults=park_data)
parks.append(park)
if created and self.verbose:
self.stdout.write(f" Created park: {park.name}")
@@ -151,9 +147,7 @@ class Command(BaseCommand):
# Create rides
rides = []
for ride_data in rides_data:
ride, created = Ride.objects.get_or_create(
name=ride_data["name"], defaults=ride_data
)
ride, created = Ride.objects.get_or_create(name=ride_data["name"], defaults=ride_data)
rides.append(ride)
if created and self.verbose:
self.stdout.write(f" Created ride: {ride.name}")
@@ -169,48 +163,34 @@ class Command(BaseCommand):

# Pattern 1: Recently trending item (Steel Vengeance)
steel_vengeance = next(r for r in rides if r.name == "Steel Vengeance")
self.create_views_for_content(
steel_vengeance, recent_views=50, older_views=10, base_time=now
)
self.create_views_for_content(steel_vengeance, recent_views=50, older_views=10, base_time=now)

# Pattern 2: Consistently popular item (Space Mountain)
space_mountain = next(r for r in rides if r.name == "Space Mountain")
self.create_views_for_content(
space_mountain, recent_views=30, older_views=25, base_time=now
)
self.create_views_for_content(space_mountain, recent_views=30, older_views=25, base_time=now)

# Pattern 3: Declining popularity (Kingda Ka)
kingda_ka = next(r for r in rides if r.name == "Kingda Ka")
self.create_views_for_content(
kingda_ka, recent_views=5, older_views=40, base_time=now
)
self.create_views_for_content(kingda_ka, recent_views=5, older_views=40, base_time=now)

# Pattern 4: New but growing (Millennium Force)
millennium_force = next(r for r in rides if r.name == "Millennium Force")
self.create_views_for_content(
millennium_force, recent_views=25, older_views=5, base_time=now
)
self.create_views_for_content(millennium_force, recent_views=25, older_views=5, base_time=now)

# Create some park views too
cedar_point = next(p for p in parks if p.name == "Cedar Point")
self.create_views_for_content(
cedar_point, recent_views=35, older_views=20, base_time=now
)
self.create_views_for_content(cedar_point, recent_views=35, older_views=20, base_time=now)

if self.verbose:
self.stdout.write(" Created PageView data for trending analysis")

def create_views_for_content(
self, content_object, recent_views, older_views, base_time
):
def create_views_for_content(self, content_object, recent_views, older_views, base_time):
"""Create PageViews for a content object with specified patterns."""
content_type = ContentType.objects.get_for_model(type(content_object))

# Create recent views (last 2 hours)
for _i in range(recent_views):
view_time = base_time - timedelta(
minutes=random.randint(0, 120)  # Last 2 hours
)
view_time = base_time - timedelta(minutes=random.randint(0, 120))  # Last 2 hours
PageView.objects.create(
content_type=content_type,
object_id=content_object.id,
@@ -235,15 +215,9 @@ class Command(BaseCommand):
self.stdout.write("Testing trending algorithm...")

# Test trending content for different content types
trending_parks = trending_service.get_trending_content(
content_type="parks", limit=3
)
trending_rides = trending_service.get_trending_content(
content_type="rides", limit=3
)
trending_all = trending_service.get_trending_content(
content_type="all", limit=5
)
trending_parks = trending_service.get_trending_content(content_type="parks", limit=3)
trending_rides = trending_service.get_trending_content(content_type="rides", limit=3)
trending_all = trending_service.get_trending_content(content_type="all", limit=5)

# Test new content
new_parks = trending_service.get_new_content(content_type="parks", limit=3)
@@ -265,12 +239,8 @@ class Command(BaseCommand):
self.stdout.write("Testing API response format...")

# Test trending content format
trending_parks = trending_service.get_trending_content(
content_type="parks", limit=3
)
trending_service.get_trending_content(
content_type="rides", limit=3
)
trending_parks = trending_service.get_trending_content(content_type="parks", limit=3)
trending_service.get_trending_content(content_type="rides", limit=3)

# Test new content format
new_parks = trending_service.get_new_content(content_type="parks", limit=3)
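
Each traffic pattern above is just two bursts of PageView rows at different time offsets. A condensed sketch of the seeding helper's core; the PageView import path and the viewed_at timestamp field name are assumptions, since the diff cuts off before those details:

import random
from datetime import timedelta

from django.contrib.contenttypes.models import ContentType
from django.utils import timezone

from apps.core.models import PageView  # import path assumed

def seed_views(obj, recent_views, older_views):
    # Recent views land in the last 2 hours, older ones in a 2-26 hour window,
    # so the trending comparison has two distinct periods to work with.
    ct = ContentType.objects.get_for_model(type(obj))
    now = timezone.now()
    for _ in range(recent_views):
        PageView.objects.create(
            content_type=ct,
            object_id=obj.id,
            viewed_at=now - timedelta(minutes=random.randint(0, 120)),  # field name assumed
        )
    for _ in range(older_views):
        PageView.objects.create(
            content_type=ct,
            object_id=obj.id,
            viewed_at=now - timedelta(hours=random.randint(2, 26)),  # field name assumed
        )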

@@ -94,13 +94,21 @@ class Command(BaseCommand):
try:
parks_list = list(
Park.objects.select_related("location", "operator")
.only("id", "name", "slug", "status", "location__city", "location__state_province", "location__country")
.only(
"id",
"name",
"slug",
"status",
"location__city",
"location__state_province",
"location__country",
)
.order_by("name")[:500]
)
cache_service.default_cache.set(
"warm:park_list",
[{"id": p.id, "name": p.name, "slug": p.slug} for p in parks_list],
timeout=3600
timeout=3600,
)
warmed_count += 1
if verbose:
@@ -116,11 +124,7 @@ class Command(BaseCommand):
if not dry_run:
try:
status_counts = Park.objects.values("status").annotate(count=Count("id"))
cache_service.default_cache.set(
"warm:park_status_counts",
list(status_counts),
timeout=3600
)
cache_service.default_cache.set("warm:park_status_counts", list(status_counts), timeout=3600)
warmed_count += 1
if verbose:
self.stdout.write(" Cached park status counts")
@@ -141,8 +145,11 @@ class Command(BaseCommand):
)
cache_service.default_cache.set(
"warm:popular_parks",
[{"id": p.id, "name": p.name, "slug": p.slug, "ride_count": p.ride_count} for p in popular_parks],
timeout=3600
[
{"id": p.id, "name": p.name, "slug": p.slug, "ride_count": p.ride_count}
for p in popular_parks
],
timeout=3600,
)
warmed_count += 1
if verbose:
@@ -168,8 +175,11 @@ class Command(BaseCommand):
)
cache_service.default_cache.set(
"warm:ride_list",
[{"id": r.id, "name": r.name, "slug": r.slug, "park": r.park.name if r.park else None} for r in rides_list],
timeout=3600
[
{"id": r.id, "name": r.name, "slug": r.slug, "park": r.park.name if r.park else None}
for r in rides_list
],
timeout=3600,
)
warmed_count += 1
if verbose:
@@ -185,11 +195,7 @@ class Command(BaseCommand):
if not dry_run:
try:
category_counts = Ride.objects.values("category").annotate(count=Count("id"))
cache_service.default_cache.set(
"warm:ride_category_counts",
list(category_counts),
timeout=3600
)
cache_service.default_cache.set("warm:ride_category_counts", list(category_counts), timeout=3600)
warmed_count += 1
if verbose:
self.stdout.write(" Cached ride category counts")
@@ -210,8 +216,16 @@ class Command(BaseCommand):
)
cache_service.default_cache.set(
"warm:top_rated_rides",
[{"id": r.id, "name": r.name, "slug": r.slug, "rating": float(r.average_rating) if r.average_rating else None} for r in top_rides],
timeout=3600
[
{
"id": r.id,
"name": r.name,
"slug": r.slug,
"rating": float(r.average_rating) if r.average_rating else None,
}
for r in top_rides
],
timeout=3600,
)
warmed_count += 1
if verbose:
@@ -231,12 +245,9 @@ class Command(BaseCommand):
try:
# Park filter metadata
from apps.parks.services.hybrid_loader import smart_park_loader

metadata = smart_park_loader.get_filter_metadata()
cache_service.default_cache.set(
"warm:park_filter_metadata",
metadata,
timeout=1800
)
cache_service.default_cache.set("warm:park_filter_metadata", metadata, timeout=1800)
warmed_count += 1
if verbose:
self.stdout.write(" Cached park filter metadata")
@@ -251,13 +262,10 @@ class Command(BaseCommand):
try:
# Ride filter metadata
from apps.rides.services.hybrid_loader import SmartRideLoader

ride_loader = SmartRideLoader()
metadata = ride_loader.get_filter_metadata()
cache_service.default_cache.set(
"warm:ride_filter_metadata",
metadata,
timeout=1800
)
cache_service.default_cache.set("warm:ride_filter_metadata", metadata, timeout=1800)
warmed_count += 1
if verbose:
self.stdout.write(" Cached ride filter metadata")
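
All of these warm steps share one shape: query a trimmed projection, serialize minimal dicts, and cache.set under a "warm:" key with an explicit timeout. A standalone sketch of one step; the cache_service import path is an assumption:

from apps.core.services import cache_service  # import path assumed
from apps.parks.models import Park

def warm_park_list():
    # Cache a lightweight projection for an hour under the shared "warm:" prefix.
    parks = Park.objects.only("id", "name", "slug").order_by("name")[:500]
    cache_service.default_cache.set(
        "warm:park_list",
        [{"id": p.id, "name": p.name, "slug": p.slug} for p in parks],
        timeout=3600,
    )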

@@ -92,9 +92,7 @@ class LocationQuerySet(BaseQuerySet):
"""Filter locations near a geographic point."""
if hasattr(self.model, "point"):
return (
self.filter(point__distance_lte=(point, Distance(km=distance_km)))
.distance(point)
.order_by("distance")
self.filter(point__distance_lte=(point, Distance(km=distance_km))).distance(point).order_by("distance")
)
return self

@@ -138,9 +136,7 @@ class LocationManager(BaseManager):
return self.get_queryset().near_point(point=point, distance_km=distance_km)

def within_bounds(self, *, north: float, south: float, east: float, west: float):
return self.get_queryset().within_bounds(
north=north, south=south, east=east, west=west
)
return self.get_queryset().within_bounds(north=north, south=south, east=east, west=west)


class ReviewableQuerySet(BaseQuerySet):
@@ -151,9 +147,7 @@ class ReviewableQuerySet(BaseQuerySet):
return self.annotate(
review_count=Count("reviews", filter=Q(reviews__is_published=True)),
average_rating=Avg("reviews__rating", filter=Q(reviews__is_published=True)),
latest_review_date=Max(
"reviews__created_at", filter=Q(reviews__is_published=True)
),
latest_review_date=Max("reviews__created_at", filter=Q(reviews__is_published=True)),
)

def highly_rated(self, *, min_rating: float = 8.0):
@@ -163,9 +157,7 @@ class ReviewableQuerySet(BaseQuerySet):
def recently_reviewed(self, *, days: int = 30):
"""Filter for items with recent reviews."""
cutoff_date = timezone.now() - timedelta(days=days)
return self.filter(
reviews__created_at__gte=cutoff_date, reviews__is_published=True
).distinct()
return self.filter(reviews__created_at__gte=cutoff_date, reviews__is_published=True).distinct()


class ReviewableManager(BaseManager):
@@ -237,9 +229,7 @@ class TimestampedManager(BaseManager):
return TimestampedQuerySet(self.model, using=self._db)

def created_between(self, *, start_date, end_date):
return self.get_queryset().created_between(
start_date=start_date, end_date=end_date
)
return self.get_queryset().created_between(start_date=start_date, end_date=end_date)
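
Because the managers proxy the queryset helpers with keyword-only arguments, call sites read identically through either layer. A usage sketch, assuming Ride's default manager mixes in the Reviewable and Timestamped helpers shown above:

from datetime import date

from apps.rides.models import Ride

# Recent, well-reviewed rides (assumes Ride.objects exposes these helpers).
fresh_hits = Ride.objects.recently_reviewed(days=30).highly_rated(min_rating=8.0)

# First-quarter additions via the Timestamped helpers.
q1_rides = Ride.objects.created_between(start_date=date(2024, 1, 1), end_date=date(2024, 3, 31))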


class StatusQuerySet(BaseQuerySet):

@@ -16,16 +16,10 @@ class RequestContextProvider(pghistory.context):

def __call__(self, request: WSGIRequest) -> dict:
return {
"user": (
str(request.user)
if request.user and not isinstance(request.user, AnonymousUser)
else None
),
"user": (str(request.user) if request.user and not isinstance(request.user, AnonymousUser) else None),
"ip": request.META.get("REMOTE_ADDR"),
"user_agent": request.META.get("HTTP_USER_AGENT"),
"session_key": (
request.session.session_key if hasattr(request, "session") else None
),
"session_key": (request.session.session_key if hasattr(request, "session") else None),
}



@@ -1,6 +1,7 @@
"""
Middleware for handling errors in HTMX requests.
"""

import logging

from django.http import HttpResponseServerError

@@ -38,12 +38,8 @@ class APIResponseMiddleware(MiddlewareMixin):
response["Vary"] = "Origin"

# Helpful dev CORS headers (adjust for your frontend requests)
response["Access-Control-Allow-Methods"] = (
"GET, POST, PUT, PATCH, DELETE, OPTIONS"
)
response["Access-Control-Allow-Headers"] = (
"Authorization, Content-Type, X-Requested-With"
)
response["Access-Control-Allow-Methods"] = "GET, POST, PUT, PATCH, DELETE, OPTIONS"
response["Access-Control-Allow-Headers"] = "Authorization, Content-Type, X-Requested-With"
# Uncomment if your dev frontend needs to send cookies/auth credentials
# response['Access-Control-Allow-Credentials'] = 'true'
logger.debug(f"Added CORS headers for origin: {origin}")

@@ -19,9 +19,7 @@ class PerformanceMiddleware(MiddlewareMixin):
def process_request(self, request):
"""Initialize performance tracking for the request"""
request._performance_start_time = time.time()
request._performance_initial_queries = (
len(connection.queries) if hasattr(connection, "queries") else 0
)
request._performance_initial_queries = len(connection.queries) if hasattr(connection, "queries") else 0

def process_response(self, request, response):
"""Log performance metrics after response is ready"""
@@ -42,11 +40,7 @@ class PerformanceMiddleware(MiddlewareMixin):
duration = end_time - start_time

initial_queries = getattr(request, "_performance_initial_queries", 0)
total_queries = (
len(connection.queries) - initial_queries
if hasattr(connection, "queries")
else 0
)
total_queries = len(connection.queries) - initial_queries if hasattr(connection, "queries") else 0

# Get content length
content_length = 0
@@ -70,9 +64,7 @@ class PerformanceMiddleware(MiddlewareMixin):
if hasattr(request, "user") and request.user.is_authenticated
else None
),
"user_agent": request.META.get("HTTP_USER_AGENT", "")[
:100
],  # Truncate user agent
"user_agent": request.META.get("HTTP_USER_AGENT", "")[:100],  # Truncate user agent
"remote_addr": self._get_client_ip(request),
}

@@ -81,11 +73,7 @@ class PerformanceMiddleware(MiddlewareMixin):
recent_queries = connection.queries[-total_queries:]
performance_data["queries"] = [
{
"sql": (
query["sql"][:200] + "..."
if len(query["sql"]) > 200
else query["sql"]
),
"sql": (query["sql"][:200] + "..." if len(query["sql"]) > 200 else query["sql"]),
"time": float(query["time"]),
}
for query in recent_queries[-10:]  # Last 10 queries only
@@ -95,9 +83,7 @@ class PerformanceMiddleware(MiddlewareMixin):
slow_queries = [q for q in recent_queries if float(q["time"]) > 0.1]
if slow_queries:
performance_data["slow_query_count"] = len(slow_queries)
performance_data["slowest_query_time"] = max(
float(q["time"]) for q in slow_queries
)
performance_data["slowest_query_time"] = max(float(q["time"]) for q in slow_queries)

# Determine log level based on performance
log_level = self._get_log_level(duration, total_queries, response.status_code)
@@ -115,9 +101,7 @@ class PerformanceMiddleware(MiddlewareMixin):
response["X-Response-Time"] = f"{duration * 1000:.2f}ms"
response["X-Query-Count"] = str(total_queries)
if total_queries > 0 and hasattr(connection, "queries"):
total_query_time = sum(
float(q["time"]) for q in connection.queries[-total_queries:]
)
total_query_time = sum(float(q["time"]) for q in connection.queries[-total_queries:])
response["X-Query-Time"] = f"{total_query_time * 1000:.2f}ms"

return response
@@ -129,11 +113,7 @@ class PerformanceMiddleware(MiddlewareMixin):
duration = end_time - start_time

initial_queries = getattr(request, "_performance_initial_queries", 0)
total_queries = (
len(connection.queries) - initial_queries
if hasattr(connection, "queries")
else 0
)
total_queries = len(connection.queries) - initial_queries if hasattr(connection, "queries") else 0

performance_data = {
"path": request.path,
@@ -195,9 +175,7 @@ class QueryCountMiddleware(MiddlewareMixin):

def process_request(self, request):
"""Initialize query tracking"""
request._query_count_start = (
len(connection.queries) if hasattr(connection, "queries") else 0
)
request._query_count_start = len(connection.queries) if hasattr(connection, "queries") else 0

def process_response(self, request, response):
"""Check query count and warn if excessive"""
@@ -267,9 +245,7 @@ class CachePerformanceMiddleware(MiddlewareMixin):

def process_response(self, request, response):
"""Log cache performance metrics"""
cache_duration = time.time() - getattr(
request, "_cache_start_time", time.time()
)
cache_duration = time.time() - getattr(request, "_cache_start_time", time.time())
cache_hits = getattr(request, "_cache_hits", 0)
cache_misses = getattr(request, "_cache_misses", 0)
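
The query accounting in these middlewares is just a length delta on connection.queries, which Django only populates when DEBUG is on. A minimal sketch of the same bookkeeping outside a middleware:

import time

from django.db import connection

def measure(fn, *args, **kwargs):
    # Wall time plus the connection.queries delta (queries are only
    # recorded when DEBUG=True, hence the hasattr guard).
    start = time.time()
    initial = len(connection.queries) if hasattr(connection, "queries") else 0
    result = fn(*args, **kwargs)
    duration = time.time() - start
    total = len(connection.queries) - initial if hasattr(connection, "queries") else 0
    print(f"{duration * 1000:.2f}ms, {total} queries")
    return result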


@@ -35,20 +35,17 @@ class AuthRateLimitMiddleware:
# Endpoints to rate limit
RATE_LIMITED_PATHS = {
# Login endpoints
'/api/v1/auth/login/': {'per_minute': 5, 'per_hour': 30, 'per_day': 100},
'/accounts/login/': {'per_minute': 5, 'per_hour': 30, 'per_day': 100},

"/api/v1/auth/login/": {"per_minute": 5, "per_hour": 30, "per_day": 100},
"/accounts/login/": {"per_minute": 5, "per_hour": 30, "per_day": 100},
# Signup endpoints
'/api/v1/auth/signup/': {'per_minute': 3, 'per_hour': 10, 'per_day': 20},
'/accounts/signup/': {'per_minute': 3, 'per_hour': 10, 'per_day': 20},

"/api/v1/auth/signup/": {"per_minute": 3, "per_hour": 10, "per_day": 20},
"/accounts/signup/": {"per_minute": 3, "per_hour": 10, "per_day": 20},
# Password reset endpoints
'/api/v1/auth/password-reset/': {'per_minute': 2, 'per_hour': 5, 'per_day': 10},
'/accounts/password/reset/': {'per_minute': 2, 'per_hour': 5, 'per_day': 10},

"/api/v1/auth/password-reset/": {"per_minute": 2, "per_hour": 5, "per_day": 10},
"/accounts/password/reset/": {"per_minute": 2, "per_hour": 5, "per_day": 10},
# Token endpoints
'/api/v1/auth/token/': {'per_minute': 10, 'per_hour': 60, 'per_day': 200},
'/api/v1/auth/token/refresh/': {'per_minute': 20, 'per_hour': 120, 'per_day': 500},
"/api/v1/auth/token/": {"per_minute": 10, "per_hour": 60, "per_day": 200},
"/api/v1/auth/token/refresh/": {"per_minute": 20, "per_hour": 120, "per_day": 500},
}

def __init__(self, get_response: Callable[[HttpRequest], HttpResponse]):
@@ -56,7 +53,7 @@ class AuthRateLimitMiddleware:

def __call__(self, request: HttpRequest) -> HttpResponse:
# Only rate limit POST requests to auth endpoints
if request.method != 'POST':
if request.method != "POST":
return self.get_response(request)

# Check if this path should be rate limited
@@ -68,14 +65,10 @@ class AuthRateLimitMiddleware:
client_ip = self._get_client_ip(request)

# Check rate limits
is_allowed, message = self._check_rate_limits(
client_ip, request.path, limits
)
is_allowed, message = self._check_rate_limits(client_ip, request.path, limits)

if not is_allowed:
logger.warning(
f"Rate limit exceeded for {client_ip} on {request.path}"
)
logger.warning(f"Rate limit exceeded for {client_ip} on {request.path}")
return self._rate_limit_response(message)

# Process request
@@ -94,9 +87,9 @@ class AuthRateLimitMiddleware:
return self.RATE_LIMITED_PATHS[path]

# Prefix match (for paths with trailing slashes)
path_without_slash = path.rstrip('/')
path_without_slash = path.rstrip("/")
for limited_path, limits in self.RATE_LIMITED_PATHS.items():
if path_without_slash == limited_path.rstrip('/'):
if path_without_slash == limited_path.rstrip("/"):
return limits

return None
@@ -108,23 +101,18 @@ class AuthRateLimitMiddleware:
Handles common proxy headers (X-Forwarded-For, X-Real-IP).
"""
# Check for forwarded headers (set by reverse proxies)
x_forwarded_for = request.META.get('HTTP_X_FORWARDED_FOR')
x_forwarded_for = request.META.get("HTTP_X_FORWARDED_FOR")
if x_forwarded_for:
# Take the first IP in the chain (client IP)
return x_forwarded_for.split(',')[0].strip()
return x_forwarded_for.split(",")[0].strip()

x_real_ip = request.META.get('HTTP_X_REAL_IP')
x_real_ip = request.META.get("HTTP_X_REAL_IP")
if x_real_ip:
return x_real_ip

return request.META.get('REMOTE_ADDR', 'unknown')
return request.META.get("REMOTE_ADDR", "unknown")

def _check_rate_limits(
self,
client_ip: str,
path: str,
limits: dict
) -> tuple[bool, str]:
def _check_rate_limits(self, client_ip: str, path: str, limits: dict) -> tuple[bool, str]:
"""
Check if the client has exceeded rate limits.

@@ -132,31 +120,31 @@ class AuthRateLimitMiddleware:
Tuple of (is_allowed, reason_if_blocked)
"""
# Create a safe cache key from path
path_key = path.replace('/', '_').strip('_')
path_key = path.replace("/", "_").strip("_")

# Check per-minute limit
minute_key = f"auth_rate:{client_ip}:{path_key}:minute"
minute_count = cache.get(minute_key, 0)
if minute_count >= limits.get('per_minute', 10):
if minute_count >= limits.get("per_minute", 10):
return False, "Too many requests. Please wait a minute before trying again."

# Check per-hour limit
hour_key = f"auth_rate:{client_ip}:{path_key}:hour"
hour_count = cache.get(hour_key, 0)
if hour_count >= limits.get('per_hour', 60):
if hour_count >= limits.get("per_hour", 60):
return False, "Too many requests. Please try again later."

# Check per-day limit
day_key = f"auth_rate:{client_ip}:{path_key}:day"
day_count = cache.get(day_key, 0)
if day_count >= limits.get('per_day', 200):
if day_count >= limits.get("per_day", 200):
return False, "Daily limit exceeded. Please try again tomorrow."

return True, ""

def _increment_counters(self, client_ip: str, path: str) -> None:
"""Increment rate limit counters."""
path_key = path.replace('/', '_').strip('_')
path_key = path.replace("/", "_").strip("_")

# Increment per-minute counter
minute_key = f"auth_rate:{client_ip}:{path_key}:minute"
@@ -183,8 +171,8 @@ class AuthRateLimitMiddleware:
"""Generate a rate limit exceeded response."""
return JsonResponse(
{
'error': message,
'code': 'RATE_LIMIT_EXCEEDED',
"error": message,
"code": "RATE_LIMIT_EXCEEDED",
},
status=429,  # Too Many Requests
)
@@ -201,50 +189,31 @@ class SecurityEventLogger:
"""

@staticmethod
def log_failed_login(
request: HttpRequest,
username: str,
reason: str = "Invalid credentials"
) -> None:
def log_failed_login(request: HttpRequest, username: str, reason: str = "Invalid credentials") -> None:
"""Log a failed login attempt."""
client_ip = AuthRateLimitMiddleware._get_client_ip(
AuthRateLimitMiddleware, request
)
client_ip = AuthRateLimitMiddleware._get_client_ip(AuthRateLimitMiddleware, request)
logger.warning(
f"Failed login attempt - IP: {client_ip}, Username: {username}, "
f"Reason: {reason}, User-Agent: {request.META.get('HTTP_USER_AGENT', 'unknown')}"
)

@staticmethod
def log_permission_denied(
request: HttpRequest,
resource: str,
action: str = "access"
) -> None:
def log_permission_denied(request: HttpRequest, resource: str, action: str = "access") -> None:
"""Log a permission denied event."""
client_ip = AuthRateLimitMiddleware._get_client_ip(
AuthRateLimitMiddleware, request
)
user = getattr(request, 'user', None)
username = user.username if user and user.is_authenticated else 'anonymous'
client_ip = AuthRateLimitMiddleware._get_client_ip(AuthRateLimitMiddleware, request)
user = getattr(request, "user", None)
username = user.username if user and user.is_authenticated else "anonymous"

logger.warning(
f"Permission denied - IP: {client_ip}, User: {username}, "
f"Resource: {resource}, Action: {action}"
f"Permission denied - IP: {client_ip}, User: {username}, " f"Resource: {resource}, Action: {action}"
)

@staticmethod
def log_suspicious_activity(
request: HttpRequest,
activity_type: str,
details: str = ""
) -> None:
def log_suspicious_activity(request: HttpRequest, activity_type: str, details: str = "") -> None:
"""Log suspicious activity."""
client_ip = AuthRateLimitMiddleware._get_client_ip(
AuthRateLimitMiddleware, request
)
user = getattr(request, 'user', None)
username = user.username if user and user.is_authenticated else 'anonymous'
client_ip = AuthRateLimitMiddleware._get_client_ip(AuthRateLimitMiddleware, request)
user = getattr(request, "user", None)
username = user.username if user and user.is_authenticated else "anonymous"

logger.error(
f"Suspicious activity detected - Type: {activity_type}, "

@@ -9,7 +9,7 @@ import time

from django.utils.deprecation import MiddlewareMixin

logger = logging.getLogger('request_logging')
logger = logging.getLogger("request_logging")


class RequestLoggingMiddleware(MiddlewareMixin):
@@ -20,17 +20,16 @@ class RequestLoggingMiddleware(MiddlewareMixin):

# Paths to exclude from detailed logging (e.g., static files, health checks)
EXCLUDE_DETAILED_LOGGING_PATHS = [
'/static/',
'/media/',
'/favicon.ico',
'/health/',
'/admin/jsi18n/',
"/static/",
"/media/",
"/favicon.ico",
"/health/",
"/admin/jsi18n/",
]

def _should_log_detailed(self, request):
"""Determine if detailed logging should be enabled for this request."""
return not any(
path in request.path for path in self.EXCLUDE_DETAILED_LOGGING_PATHS)
return not any(path in request.path for path in self.EXCLUDE_DETAILED_LOGGING_PATHS)

def process_request(self, request):
"""Store request start time and capture request data for detailed logging."""
@@ -44,14 +43,17 @@ class RequestLoggingMiddleware(MiddlewareMixin):
try:
# Log request data
request_data = {}
if hasattr(request, 'data') and request.data:
if hasattr(request, "data") and request.data:
request_data = dict(request.data)
elif request.body:
try:
request_data = json.loads(request.body.decode('utf-8'))
request_data = json.loads(request.body.decode("utf-8"))
except (json.JSONDecodeError, UnicodeDecodeError):
request_data = {'body': str(request.body)[
:200] + '...' if len(str(request.body)) > 200 else str(request.body)}
request_data = {
"body": (
str(request.body)[:200] + "..." if len(str(request.body)) > 200 else str(request.body)
)
}

# Log query parameters
query_params = dict(request.GET) if request.GET else {}
@@ -61,9 +63,8 @@ class RequestLoggingMiddleware(MiddlewareMixin):
logger.info(f" Body: {self._safe_log_data(request_data)}")
if query_params:
logger.info(f" Query: {query_params}")
if hasattr(request, 'user') and request.user.is_authenticated:
logger.info(
f" User: {request.user.username} (ID: {request.user.id})")
if hasattr(request, "user") and request.user.is_authenticated:
logger.info(f" User: {request.user.username} (ID: {request.user.id})")

except Exception as e:
logger.warning(f"Failed to log request data: {e}")
@@ -75,34 +76,28 @@ class RequestLoggingMiddleware(MiddlewareMixin):
try:
# Calculate request duration
duration = 0
if hasattr(request, '_start_time'):
if hasattr(request, "_start_time"):
duration = time.time() - request._start_time

# Basic request logging
logger.info(
f"{request.method} {request.get_full_path()} -> {response.status_code} "
f"({duration:.3f}s)"
)
logger.info(f"{request.method} {request.get_full_path()} -> {response.status_code} " f"({duration:.3f}s)")

# Detailed response logging for specific endpoints
if getattr(request, '_log_request_data', False):
if getattr(request, "_log_request_data", False):
try:
# Log response data
if hasattr(response, 'data'):
logger.info(
f"RESPONSE DATA for {request.method} {request.path}:")
if hasattr(response, "data"):
logger.info(f"RESPONSE DATA for {request.method} {request.path}:")
logger.info(f" Status: {response.status_code}")
logger.info(f" Data: {self._safe_log_data(response.data)}")
elif hasattr(response, 'content'):
elif hasattr(response, "content"):
try:
content = json.loads(response.content.decode('utf-8'))
logger.info(
f"RESPONSE DATA for {request.method} {request.path}:")
content = json.loads(response.content.decode("utf-8"))
logger.info(f"RESPONSE DATA for {request.method} {request.path}:")
logger.info(f" Status: {response.status_code}")
logger.info(f" Content: {self._safe_log_data(content)}")
except (json.JSONDecodeError, UnicodeDecodeError):
logger.info(
f"RESPONSE DATA for {request.method} {request.path}:")
logger.info(f"RESPONSE DATA for {request.method} {request.path}:")
logger.info(f" Status: {response.status_code}")
logger.info(f" Content: {str(response.content)[:200]}...")

@@ -118,31 +113,31 @@ class RequestLoggingMiddleware(MiddlewareMixin):
# Sensitive field patterns that should be masked in logs
# Security: Comprehensive list of sensitive data patterns
SENSITIVE_PATTERNS = [
'password',
'passwd',
'pwd',
'token',
'secret',
'key',
'api_key',
'apikey',
'auth',
'authorization',
'credential',
'ssn',
'social_security',
'credit_card',
'creditcard',
'card_number',
'cvv',
'cvc',
'pin',
'access_token',
'refresh_token',
'jwt',
'session',
'cookie',
'private',
"password",
"passwd",
"pwd",
"token",
"secret",
"key",
"api_key",
"apikey",
"auth",
"authorization",
"credential",
"ssn",
"social_security",
"credit_card",
"creditcard",
"card_number",
"cvv",
"cvc",
"pin",
"access_token",
"refresh_token",
"jwt",
"session",
"cookie",
"private",
]

def _safe_log_data(self, data):
@@ -167,15 +162,15 @@ class RequestLoggingMiddleware(MiddlewareMixin):

# Truncate if too long
if len(data_str) > 1000:
return data_str[:1000] + '...[TRUNCATED]'
return data_str[:1000] + "...[TRUNCATED]"
return data_str
except Exception:
return str(data)[:500] + '...[ERROR_LOGGING]'
return str(data)[:500] + "...[ERROR_LOGGING]"

def _mask_sensitive_dict(self, data, depth=0):
"""Recursively mask sensitive fields in a dictionary."""
if depth > 5:  # Prevent infinite recursion
return '***DEPTH_LIMIT***'
return "***DEPTH_LIMIT***"

safe_data = {}
for key, value in data.items():
@@ -183,7 +178,7 @@ class RequestLoggingMiddleware(MiddlewareMixin):

# Check if key contains any sensitive pattern
if any(pattern in key_lower for pattern in self.SENSITIVE_PATTERNS):
safe_data[key] = '***MASKED***'
safe_data[key] = "***MASKED***"
else:
safe_data[key] = self._mask_sensitive_value(value, depth)

@@ -197,11 +192,11 @@ class RequestLoggingMiddleware(MiddlewareMixin):
return [self._mask_sensitive_value(item, depth + 1) for item in value[:10]]  # Limit list items
elif isinstance(value, str):
# Mask email addresses (show only domain)
if '@' in value and '.' in value.split('@')[-1]:
parts = value.split('@')
if "@" in value and "." in value.split("@")[-1]:
parts = value.split("@")
if len(parts) == 2:
return f"***@{parts[1]}"
# Truncate long strings
if len(value) > 200:
return value[:200] + '...[TRUNCATED]'
return value[:200] + "...[TRUNCATED]"
return value
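
Taken together, the masking rules replace any value whose key matches a sensitive pattern and reduce email-shaped strings to their domain. A quick illustration of the expected behavior using the methods defined above:

# Illustrative use of the masking helpers defined above.
mw = RequestLoggingMiddleware(get_response=lambda request: None)
print(mw._safe_log_data({
    "username": "coasterfan",
    "password": "hunter2",       # key matches a sensitive pattern -> ***MASKED***
    "email": "fan@example.com",  # email value -> ***@example.com
}))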

@@ -49,9 +49,7 @@ class SecurityHeadersMiddleware:
if not response.get("Content-Security-Policy"):
response["Content-Security-Policy"] = self._csp_header
else:
logger.warning(
f"CSP header already present for {request.path}, skipping"
)
logger.warning(f"CSP header already present for {request.path}, skipping")

# Permissions-Policy (successor to Feature-Policy)
if not response.get("Permissions-Policy"):
@@ -144,11 +142,13 @@ class SecurityHeadersMiddleware:
# Add debug-specific relaxations
if debug:
# Allow webpack dev server connections in development
directives["connect-src"].extend([
"ws://localhost:*",
"http://localhost:*",
"http://127.0.0.1:*",
])
directives["connect-src"].extend(
[
"ws://localhost:*",
"http://localhost:*",
"http://127.0.0.1:*",
]
)

# Build header string
parts = []
@@ -168,30 +168,34 @@ class SecurityHeadersMiddleware:
This header controls which browser features the page can use.
"""
# Get permissions policy from settings or use defaults
policy = getattr(settings, "PERMISSIONS_POLICY", {
"accelerometer": [],
"ambient-light-sensor": [],
"autoplay": [],
"camera": [],
"display-capture": [],
"document-domain": [],
"encrypted-media": [],
"fullscreen": ["self"],
"geolocation": ["self"],
"gyroscope": [],
"interest-cohort": [],
"magnetometer": [],
"microphone": [],
"midi": [],
"payment": [],
"picture-in-picture": [],
"publickey-credentials-get": [],
"screen-wake-lock": [],
"sync-xhr": [],
"usb": [],
"web-share": ["self"],
"xr-spatial-tracking": [],
})
policy = getattr(
settings,
"PERMISSIONS_POLICY",
{
"accelerometer": [],
"ambient-light-sensor": [],
"autoplay": [],
"camera": [],
"display-capture": [],
"document-domain": [],
"encrypted-media": [],
"fullscreen": ["self"],
"geolocation": ["self"],
"gyroscope": [],
"interest-cohort": [],
"magnetometer": [],
"microphone": [],
"midi": [],
"payment": [],
"picture-in-picture": [],
"publickey-credentials-get": [],
"screen-wake-lock": [],
"sync-xhr": [],
"usb": [],
"web-share": ["self"],
"xr-spatial-tracking": [],
},
)

parts = []
for feature, allowlist in policy.items():
|
||||
|
||||
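A sketch of the header these directives produce, assuming the loop below
formats each feature as name=(allowlist) per the Permissions-Policy syntax
(the actual join lives outside this hunk):

    policy = {"fullscreen": ["self"], "geolocation": ["self"], "camera": []}
    parts = [f"{feature}=({' '.join(allowlist)})" for feature, allowlist in policy.items()]
    header = ", ".join(parts)  # 'fullscreen=(self), geolocation=(self), camera=()'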
@@ -9,7 +9,6 @@ analytics for the trending algorithm.
import logging
import re
from datetime import timedelta
from typing import Union

from django.conf import settings
from django.contrib.contenttypes.models import ContentType
@@ -22,7 +21,7 @@ from apps.parks.models import Park
from apps.rides.models import Ride

# Type alias for content objects
ContentObject = Union[Park, Ride]
ContentObject = Park | Ride

logger = logging.getLogger(__name__)

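The alias rewrite above is behavior-preserving on Python 3.10+: a PEP 604
union is a types.UnionType and still works for runtime checks. A small
sketch (function name is hypothetical):

    def is_trackable(obj: object) -> bool:
        return isinstance(obj, ContentObject)  # ContentObject == Park | Ride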
@@ -50,8 +49,7 @@ class ViewTrackingMiddleware:

        # Compile patterns for performance
        self.compiled_patterns = [
            (re.compile(pattern), content_type)
            for pattern, content_type in self.tracked_patterns
            (re.compile(pattern), content_type) for pattern, content_type in self.tracked_patterns
        ]

        # Cache configuration
@@ -63,11 +61,7 @@ class ViewTrackingMiddleware:
        response = self.get_response(request)

        # Only track successful GET requests
        if (
            request.method == "GET"
            and 200 <= response.status_code < 300
            and not self._should_skip_tracking(request)
        ):
        if request.method == "GET" and 200 <= response.status_code < 300 and not self._should_skip_tracking(request):
            try:
                self._track_view_if_applicable(request)
            except Exception as e:
@@ -119,9 +113,7 @@ class ViewTrackingMiddleware:
                    self._record_page_view(request, content_type, slug)
                break

    def _record_page_view(
        self, request: HttpRequest, content_type: str, slug: str
    ) -> None:
    def _record_page_view(self, request: HttpRequest, content_type: str, slug: str) -> None:
        """Record a page view for the specified content."""
        client_ip = self._get_client_ip(request)
        if not client_ip:
@@ -131,33 +123,23 @@ class ViewTrackingMiddleware:
            # Get the content object
            content_obj = self._get_content_object(content_type, slug)
            if not content_obj:
                self.logger.warning(
                    f"Content not found: {content_type} with slug '{slug}'"
                )
                self.logger.warning(f"Content not found: {content_type} with slug '{slug}'")
                return

            # Check deduplication
            if self._is_duplicate_view(content_obj, client_ip):
                self.logger.debug(
                    f"Duplicate view skipped for {content_type} {slug} from {client_ip}"
                )
                self.logger.debug(f"Duplicate view skipped for {content_type} {slug} from {client_ip}")
                return

            # Create PageView record
            self._create_page_view(content_obj, client_ip, request)

            self.logger.debug(
                f"Recorded view for {content_type} {slug} from {client_ip}"
            )
            self.logger.debug(f"Recorded view for {content_type} {slug} from {client_ip}")

        except Exception as e:
            self.logger.error(
                f"Failed to record page view for {content_type} {slug}: {e}"
            )
            self.logger.error(f"Failed to record page view for {content_type} {slug}: {e}")

    def _get_content_object(
        self, content_type: str, slug: str
    ) -> ContentObject | None:
    def _get_content_object(self, content_type: str, slug: str) -> ContentObject | None:
        """Get the content object by type and slug."""
        try:
            if content_type == "park":
@@ -202,16 +184,12 @@ class ViewTrackingMiddleware:

        return existing_view

    def _create_page_view(
        self, content_obj: ContentObject, client_ip: str, request: HttpRequest
    ) -> None:
    def _create_page_view(self, content_obj: ContentObject, client_ip: str, request: HttpRequest) -> None:
        """Create a new PageView record."""
        content_type = ContentType.objects.get_for_model(content_obj)

        # Extract additional metadata
        user_agent = request.META.get("HTTP_USER_AGENT", "")[
            :500
        ]  # Truncate long user agents
        user_agent = request.META.get("HTTP_USER_AGENT", "")[:500]  # Truncate long user agents
        referer = request.META.get("HTTP_REFERER", "")[:500]

        PageView.objects.create(
@@ -267,11 +245,9 @@ class ViewTrackingMiddleware:
            return False

        # Skip localhost and private IPs in production
        if getattr(settings, "SKIP_LOCAL_IPS", not settings.DEBUG):
        if getattr(settings, "SKIP_LOCAL_IPS", not settings.DEBUG):  # noqa: SIM102
            if (ip.startswith(("127.", "192.168.", "10.")) or ip.startswith("172.")) and any(
                16 <= int(ip.split(".")[1]) <= 31
                for _ in [ip]
                if ip.startswith("172.")
                16 <= int(ip.split(".")[1]) <= 31 for _ in [ip] if ip.startswith("172.")
            ):
                return False


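The 172.x branch above only matches the RFC 1918 172.16/12 block. A simpler
way to express the same private-address intent with the stdlib (a sketch,
not what this commit ships):

    import ipaddress

    def _is_local_ip(ip: str) -> bool:
        try:
            addr = ipaddress.ip_address(ip)
        except ValueError:
            return False
        return addr.is_private or addr.is_loopback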
@@ -32,9 +32,7 @@ class Migration(migrations.Migration):
        migrations.AlterField(
            model_name="slughistory",
            name="object_id",
            field=models.CharField(
                help_text="ID of the object this slug belongs to", max_length=50
            ),
            field=models.CharField(help_text="ID of the object this slug belongs to", max_length=50),
        ),
        migrations.AlterField(
            model_name="slughistory",
@@ -56,15 +54,11 @@ class Migration(migrations.Migration):
        migrations.AlterField(
            model_name="slughistoryevent",
            name="object_id",
            field=models.CharField(
                help_text="ID of the object this slug belongs to", max_length=50
            ),
            field=models.CharField(help_text="ID of the object this slug belongs to", max_length=50),
        ),
        migrations.AlterField(
            model_name="slughistoryevent",
            name="old_slug",
            field=models.SlugField(
                db_index=False, help_text="Previous slug value", max_length=200
            ),
            field=models.SlugField(db_index=False, help_text="Previous slug value", max_length=200),
        ),
    ]

@@ -58,10 +58,7 @@ class HTMXFormMixin(FormMixin):
    def form_valid(self, form):
        """Add HX-Trigger header on successful form submission via HTMX."""
        res = super().form_valid(form)
        if (
            self.request.headers.get("HX-Request") == "true"
            and self.htmx_success_trigger
        ):
        if self.request.headers.get("HX-Request") == "true" and self.htmx_success_trigger:
            res["HX-Trigger"] = self.htmx_success_trigger
        return res


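A usage sketch for the mixin (view, model, and trigger name are
hypothetical): HTMX dispatches the named event client-side when it sees the
HX-Trigger response header, so a listing fragment can refresh itself after
the form saves.

    class ParkCreateView(HTMXFormMixin, CreateView):
        model = Park
        fields = ["name"]
        htmx_success_trigger = "parkListChanged"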
@@ -78,9 +78,7 @@ class SluggedModel(TrackedModel):
        Returns the name of the read-only ID field for this model.
        Should be overridden by subclasses.
        """
        raise NotImplementedError(
            "Subclasses of SluggedModel must implement get_id_field_name()"
        )
        raise NotImplementedError("Subclasses of SluggedModel must implement get_id_field_name()")

    @classmethod
    def get_by_slug(cls, slug):
@@ -123,4 +121,4 @@ class SluggedModel(TrackedModel):
                True,
            )

        raise cls.DoesNotExist(f"{cls.__name__} with slug '{slug}' does not exist")
        raise cls.DoesNotExist(f"{cls.__name__} with slug '{slug}' does not exist") from None

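The `from None` added here suppresses implicit exception chaining, so callers
see only the DoesNotExist rather than the internal lookup failure that
preceded it. A minimal sketch of the effect:

    try:
        {}["missing"]
    except KeyError:
        raise ValueError("not found") from None  # traceback omits the KeyError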
@@ -14,10 +14,11 @@ class IsOwnerOrReadOnly(permissions.BasePermission):

        # Write permissions are only allowed to the owner of the object.
        # Assumes the model instance has a `user` attribute.
        if hasattr(obj, 'user'):
        if hasattr(obj, "user"):
            return obj.user == request.user
        return False


class IsStaffOrReadOnly(permissions.BasePermission):
    """
    Custom permission to only allow staff to edit objects.

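A sketch of wiring these permissions into a DRF view (viewset, model, and
serializer names are hypothetical):

    from rest_framework import viewsets

    class ReviewViewSet(viewsets.ModelViewSet):
        queryset = Review.objects.all()
        serializer_class = ReviewSerializer
        permission_classes = [IsOwnerOrReadOnly]  # writes require obj.user == request.user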
@@ -61,23 +61,20 @@ def unified_locations_for_map(

    # Rides
    if "ride" in location_types:
        ride_queryset = Ride.objects.select_related(
            "park", "manufacturer"
        ).prefetch_related("park__location", "location")
        ride_queryset = Ride.objects.select_related("park", "manufacturer").prefetch_related(
            "park__location", "location"
        )

        if bounds:
            ride_queryset = ride_queryset.filter(
                Q(location__coordinates__within=bounds)
                | Q(park__location__coordinates__within=bounds)
                Q(location__coordinates__within=bounds) | Q(park__location__coordinates__within=bounds)
            )

        if filters:
            if "category" in filters:
                ride_queryset = ride_queryset.filter(category=filters["category"])
            if "manufacturer" in filters:
                ride_queryset = ride_queryset.filter(
                    manufacturer=filters["manufacturer"]
                )
                ride_queryset = ride_queryset.filter(manufacturer=filters["manufacturer"])
            if "park" in filters:
                ride_queryset = ride_queryset.filter(park=filters["park"])

@@ -248,12 +245,7 @@ def popular_pages_summary(*, days: int = 30) -> dict[str, Any]:
    total_views = PageView.objects.filter(timestamp__gte=cutoff_date).count()

    # Unique visitors (based on IP)
    unique_visitors = (
        PageView.objects.filter(timestamp__gte=cutoff_date)
        .values("ip_address")
        .distinct()
        .count()
    )
    unique_visitors = PageView.objects.filter(timestamp__gte=cutoff_date).values("ip_address").distinct().count()

    return {
        "popular_pages": list(popular_pages),
@@ -311,14 +303,10 @@ def system_health_metrics() -> dict[str, Any]:
        "page_views_7d": PageView.objects.filter(timestamp__gte=last_7d).count(),
        "data_freshness": {
            "latest_park_update": (
                Park.objects.order_by("-updated_at").first().updated_at
                if Park.objects.exists()
                else None
                Park.objects.order_by("-updated_at").first().updated_at if Park.objects.exists() else None
            ),
            "latest_ride_update": (
                Ride.objects.order_by("-updated_at").first().updated_at
                if Ride.objects.exists()
                else None
                Ride.objects.order_by("-updated_at").first().updated_at if Ride.objects.exists() else None
            ),
        },
    }

@@ -63,9 +63,7 @@ class ClusteringService:
        if zoom_level < self.MIN_ZOOM_FOR_CLUSTERING:
            return True

        config = self.ZOOM_CONFIGS.get(
            zoom_level, {"min_points": self.MIN_POINTS_TO_CLUSTER}
        )
        config = self.ZOOM_CONFIGS.get(zoom_level, {"min_points": self.MIN_POINTS_TO_CLUSTER})
        return point_count >= config["min_points"]

    def cluster_locations(
@@ -94,9 +92,7 @@ class ClusteringService:
        )

        # Perform clustering
        clustered_groups = self._cluster_points(
            cluster_points, config["radius"], config["min_points"]
        )
        clustered_groups = self._cluster_points(cluster_points, config["radius"], config["min_points"])

        # Separate individual locations from clusters
        unclustered_locations = []
@@ -135,9 +131,7 @@ class ClusteringService:
        # Simple equirectangular projection (good enough for clustering)
        center_lat = (bounds.north + bounds.south) / 2
        lat_scale = 111320  # meters per degree latitude
        lng_scale = 111320 * math.cos(
            math.radians(center_lat)
        )  # meters per degree longitude
        lng_scale = 111320 * math.cos(math.radians(center_lat))  # meters per degree longitude

        for location in locations:
            # Convert to meters relative to bounds center
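A worked sketch of the projection scales above, for bounds centered near
40°N (numbers are approximate):

    import math

    lat_scale = 111320                                 # ~111.3 km per degree latitude
    lng_scale = 111320 * math.cos(math.radians(40.0))  # ~85,300 m per degree longitude
    dx_m = 0.1 * lng_scale                             # a 0.1° longitude step is ~8.5 km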
@@ -200,9 +194,7 @@ class ClusteringService:
        # Calculate cluster bounds
        lats = [loc.latitude for loc in locations]
        lngs = [loc.longitude for loc in locations]
        cluster_bounds = GeoBounds(
            north=max(lats), south=min(lats), east=max(lngs), west=min(lngs)
        )
        cluster_bounds = GeoBounds(north=max(lats), south=min(lats), east=max(lngs), west=min(lngs))

        # Collect location types in cluster
        types = {loc.type for loc in locations}
@@ -223,9 +215,7 @@ class ClusteringService:
            representative_location=representative,
        )

    def _select_representative_location(
        self, locations: list[UnifiedLocation]
    ) -> UnifiedLocation | None:
    def _select_representative_location(self, locations: list[UnifiedLocation]) -> UnifiedLocation | None:
        """Select the most representative location for a cluster."""
        if not locations:
            return None
@@ -291,9 +281,7 @@ class ClusteringService:
            "category_distribution": dict(category_counts),
        }

    def expand_cluster(
        self, cluster: ClusterData, zoom_level: int
    ) -> list[UnifiedLocation]:
    def expand_cluster(self, cluster: ClusterData, zoom_level: int) -> list[UnifiedLocation]:
        """
        Expand a cluster to show individual locations (for drill-down functionality).
        This would typically require re-querying the database with the cluster bounds.
@@ -320,14 +308,11 @@ class SmartClusteringRules:
            return True

        # Major parks should resist clustering unless very close
        if (
            loc1.cluster_category == "major_park"
            or loc2.cluster_category == "major_park"
        ):
        if loc1.cluster_category == "major_park" or loc2.cluster_category == "major_park":
            return False

        # Similar types cluster more readily
        if loc1.type == loc2.type:
        if loc1.type == loc2.type:  # noqa: SIM103
            return True

        # Different types can cluster but with higher threshold

@@ -89,9 +89,7 @@ class MapFilters:
    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary for caching and serialization."""
        return {
            "location_types": (
                [t.value for t in self.location_types] if self.location_types else None
            ),
            "location_types": ([t.value for t in self.location_types] if self.location_types else None),
            "park_status": (list(self.park_status) if self.park_status else None),
            "ride_types": list(self.ride_types) if self.ride_types else None,
            "company_roles": (list(self.company_roles) if self.company_roles else None),
@@ -183,11 +181,7 @@ class ClusterData:
            "count": self.count,
            "types": [t.value for t in self.types],
            "bounds": self.bounds.to_dict(),
            "representative": (
                self.representative_location.to_dict()
                if self.representative_location
                else None
            ),
            "representative": (self.representative_location.to_dict() if self.representative_location else None),
        }



@@ -103,9 +103,7 @@ class EnhancedCacheService:
        self.default_cache.set(cache_key, data, timeout)
        logger.debug(f"Cached geographic data for bounds {bounds}")

    def get_cached_geographic_data(
        self, bounds: "GeoBounds", zoom_level: int
    ) -> Any | None:
    def get_cached_geographic_data(self, bounds: "GeoBounds", zoom_level: int) -> Any | None:
        """Retrieve cached geographic data"""
        cache_key = f"geo:{bounds.min_lat}:{bounds.min_lng}:{bounds.max_lat}:{
            bounds.max_lng
@@ -119,20 +117,14 @@ class EnhancedCacheService:
            # For Redis cache backends
            if hasattr(self.default_cache, "delete_pattern"):
                deleted_count = self.default_cache.delete_pattern(pattern)
                logger.info(
                    f"Invalidated {deleted_count} cache keys matching pattern '{pattern}'"
                )
                logger.info(f"Invalidated {deleted_count} cache keys matching pattern '{pattern}'")
                return deleted_count
            else:
                logger.warning(
                    f"Cache backend does not support pattern deletion for pattern '{pattern}'"
                )
                logger.warning(f"Cache backend does not support pattern deletion for pattern '{pattern}'")
        except Exception as e:
            logger.error(f"Error invalidating cache pattern '{pattern}': {e}")

    def invalidate_model_cache(
        self, model_name: str, instance_id: int | None = None
    ):
    def invalidate_model_cache(self, model_name: str, instance_id: int | None = None):
        """Invalidate cache keys related to a specific model"""
        pattern = f"*{model_name}:{instance_id}*" if instance_id else f"*{model_name}*"

@@ -175,11 +167,7 @@ def cache_api_response(timeout=1800, vary_on=None, key_prefix=""):
            # Generate cache key based on view, user, and parameters
            cache_key_parts = [
                key_prefix or view_func.__name__,
                (
                    str(request.user.id)
                    if request.user.is_authenticated
                    else "anonymous"
                ),
                (str(request.user.id) if request.user.is_authenticated else "anonymous"),
                str(hash(frozenset(request.GET.items()))),
            ]

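A sketch of the key these parts produce (joining with ":" is assumed; the
join happens outside this hunk). Note that str hashes are salted per
process, so the GET-params component is only stable within one worker
unless PYTHONHASHSEED is pinned:

    parts = ["park_list", "42", str(hash(frozenset({"page": "2"}.items())))]
    cache_key = ":".join(parts)  # e.g. "park_list:42:893435042167..."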
@@ -219,9 +207,7 @@ def cache_queryset_result(cache_key_template: str, timeout: int = 3600):
            cache_key = cache_key_template.format(*args, **kwargs)

            cache_service = EnhancedCacheService()
            return cache_service.cache_queryset(
                cache_key, func, timeout, *args, **kwargs
            )
            return cache_service.cache_queryset(cache_key, func, timeout, *args, **kwargs)

        return wrapper


@@ -218,17 +218,14 @@ class EntityFuzzyMatcher:

        return matches[: self.MAX_RESULTS], suggestion

    def _get_candidates(
        self, query: str, entity_type: EntityType
    ) -> list[dict[str, Any]]:
    def _get_candidates(self, query: str, entity_type: EntityType) -> list[dict[str, Any]]:
        """Get potential matching candidates for an entity type."""
        candidates = []

        if entity_type == EntityType.PARK:
            parks = Park.objects.filter(
                Q(name__icontains=query)
                | Q(slug__icontains=query.lower().replace(" ", "-"))
            )[: self.MAX_CANDIDATES]
            parks = Park.objects.filter(Q(name__icontains=query) | Q(slug__icontains=query.lower().replace(" ", "-")))[
                : self.MAX_CANDIDATES
            ]

            for park in parks:
                candidates.append(
@@ -265,8 +262,7 @@ class EntityFuzzyMatcher:

        elif entity_type == EntityType.COMPANY:
            companies = Company.objects.filter(
                Q(name__icontains=query)
                | Q(slug__icontains=query.lower().replace(" ", "-"))
                Q(name__icontains=query) | Q(slug__icontains=query.lower().replace(" ", "-"))
            )[: self.MAX_CANDIDATES]

            for company in companies:
@@ -284,9 +280,7 @@ class EntityFuzzyMatcher:

        return candidates

    def _score_and_rank_candidates(
        self, query: str, candidates: list[dict[str, Any]]
    ) -> list[FuzzyMatchResult]:
    def _score_and_rank_candidates(self, query: str, candidates: list[dict[str, Any]]) -> list[FuzzyMatchResult]:
        """Score and rank all candidates using multiple algorithms."""
        scored_matches = []

@@ -354,9 +348,7 @@ class EntityFuzzyMatcher:
        # Sort by score (highest first) and return
        return sorted(scored_matches, key=lambda x: x.score, reverse=True)

    def _generate_entity_suggestion(
        self, query: str, entity_types: list[EntityType], user
    ) -> EntitySuggestion:
    def _generate_entity_suggestion(self, query: str, entity_types: list[EntityType], user) -> EntitySuggestion:
        """Generate suggestion for creating new entity when no matches found."""

        # Determine most likely entity type based on query characteristics
@@ -364,14 +356,9 @@ class EntityFuzzyMatcher:

        # Simple heuristics for entity type detection
        query_lower = query.lower()
        if any(
            word in query_lower
            for word in ["roller coaster", "ride", "coaster", "attraction"]
        ):
        if any(word in query_lower for word in ["roller coaster", "ride", "coaster", "attraction"]):
            suggested_type = EntityType.RIDE
        elif any(
            word in query_lower for word in ["inc", "corp", "company", "manufacturer"]
        ):
        elif any(word in query_lower for word in ["inc", "corp", "company", "manufacturer"]):
            suggested_type = EntityType.COMPANY
        elif EntityType.PARK in entity_types:
            suggested_type = EntityType.PARK
@@ -382,21 +369,13 @@ class EntityFuzzyMatcher:
        suggested_name = " ".join(word.capitalize() for word in query.split())

        # Check if user is authenticated
        is_authenticated = (
            user and hasattr(user, "is_authenticated") and user.is_authenticated
        )
        is_authenticated = user and hasattr(user, "is_authenticated") and user.is_authenticated

        # Generate appropriate prompts
        entity_name = suggested_type.value
        login_prompt = (
            f"Log in to suggest adding '{suggested_name}' as a new {entity_name}"
        )
        signup_prompt = (
            f"Sign up to contribute and add '{suggested_name}' to ThrillWiki"
        )
        creation_hint = (
            f"Help expand ThrillWiki by adding information about '{suggested_name}'"
        )
        login_prompt = f"Log in to suggest adding '{suggested_name}' as a new {entity_name}"
        signup_prompt = f"Sign up to contribute and add '{suggested_name}' to ThrillWiki"
        creation_hint = f"Help expand ThrillWiki by adding information about '{suggested_name}'"

        return EntitySuggestion(
            suggested_name=suggested_name,

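A quick sketch of the heuristics and name normalization above on sample
queries (outputs follow directly from the code in this hunk):

    # "giga coaster prototype" -> contains "coaster" -> EntityType.RIDE
    # "bolliger inc"           -> contains "inc"     -> EntityType.COMPANY
    " ".join(word.capitalize() for word in "six flags over texas".split())
    # -> "Six Flags Over Texas"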
@@ -2,7 +2,6 @@
Location adapters for converting between domain-specific models and UnifiedLocation.
"""


from django.db.models import QuerySet
from django.urls import reverse

@@ -45,15 +44,9 @@ class BaseLocationAdapter:
class ParkLocationAdapter(BaseLocationAdapter):
    """Converts Park/ParkLocation to UnifiedLocation."""

    def to_unified_location(
        self, location_obj: ParkLocation
    ) -> UnifiedLocation | None:
    def to_unified_location(self, location_obj: ParkLocation) -> UnifiedLocation | None:
        """Convert ParkLocation to UnifiedLocation."""
        if (
            not location_obj.point
            or location_obj.latitude is None
            or location_obj.longitude is None
        ):
        if not location_obj.point or location_obj.latitude is None or location_obj.longitude is None:
            return None

        park = location_obj.park
@@ -67,17 +60,11 @@ class ParkLocationAdapter(BaseLocationAdapter):
            metadata={
                "status": getattr(park, "status", "UNKNOWN"),
                "rating": (
                    float(park.average_rating)
                    if hasattr(park, "average_rating") and park.average_rating
                    else None
                    float(park.average_rating) if hasattr(park, "average_rating") and park.average_rating else None
                ),
                "ride_count": getattr(park, "ride_count", 0),
                "coaster_count": getattr(park, "coaster_count", 0),
                "operator": (
                    park.operator.name
                    if hasattr(park, "operator") and park.operator
                    else None
                ),
                "operator": (park.operator.name if hasattr(park, "operator") and park.operator else None),
                "city": location_obj.city,
                "state": location_obj.state,
                "country": location_obj.country,
@@ -85,18 +72,14 @@ class ParkLocationAdapter(BaseLocationAdapter):
            type_data={
                "slug": park.slug,
                "opening_date": (
                    park.opening_date.isoformat()
                    if hasattr(park, "opening_date") and park.opening_date
                    else None
                    park.opening_date.isoformat() if hasattr(park, "opening_date") and park.opening_date else None
                ),
                "website": getattr(park, "website", ""),
                "operating_season": getattr(park, "operating_season", ""),
                "highway_exit": location_obj.highway_exit,
                "parking_notes": location_obj.parking_notes,
                "best_arrival_time": (
                    location_obj.best_arrival_time.strftime("%H:%M")
                    if location_obj.best_arrival_time
                    else None
                    location_obj.best_arrival_time.strftime("%H:%M") if location_obj.best_arrival_time else None
                ),
                "seasonal_notes": location_obj.seasonal_notes,
                "url": self._get_park_url(park),
@@ -111,9 +94,7 @@ class ParkLocationAdapter(BaseLocationAdapter):
        filters: MapFilters | None = None,
    ) -> QuerySet:
        """Get optimized queryset for park locations."""
        queryset = ParkLocation.objects.select_related("park", "park__operator").filter(
            point__isnull=False
        )
        queryset = ParkLocation.objects.select_related("park", "park__operator").filter(point__isnull=False)

        # Spatial filtering
        if bounds:
@@ -139,17 +120,9 @@ class ParkLocationAdapter(BaseLocationAdapter):
        weight = 1
        if hasattr(park, "ride_count") and park.ride_count and park.ride_count > 20:
            weight += 2
        if (
            hasattr(park, "coaster_count")
            and park.coaster_count
            and park.coaster_count > 5
        ):
        if hasattr(park, "coaster_count") and park.coaster_count and park.coaster_count > 5:
            weight += 1
        if (
            hasattr(park, "average_rating")
            and park.average_rating
            and park.average_rating > 4.0
        ):
        if hasattr(park, "average_rating") and park.average_rating and park.average_rating > 4.0:
            weight += 1
        return min(weight, 5)  # Cap at 5

@@ -176,15 +149,9 @@ class ParkLocationAdapter(BaseLocationAdapter):
class RideLocationAdapter(BaseLocationAdapter):
    """Converts Ride/RideLocation to UnifiedLocation."""

    def to_unified_location(
        self, location_obj: RideLocation
    ) -> UnifiedLocation | None:
    def to_unified_location(self, location_obj: RideLocation) -> UnifiedLocation | None:
        """Convert RideLocation to UnifiedLocation."""
        if (
            not location_obj.point
            or location_obj.latitude is None
            or location_obj.longitude is None
        ):
        if not location_obj.point or location_obj.latitude is None or location_obj.longitude is None:
            return None

        ride = location_obj.ride
@@ -194,11 +161,7 @@ class RideLocationAdapter(BaseLocationAdapter):
            type=LocationType.RIDE,
            name=ride.name,
            coordinates=[float(location_obj.latitude), float(location_obj.longitude)],
            address=(
                f"{location_obj.park_area}, {ride.park.name}"
                if location_obj.park_area
                else ride.park.name
            ),
            address=(f"{location_obj.park_area}, {ride.park.name}" if location_obj.park_area else ride.park.name),
            metadata={
                "park_id": ride.park.id,
                "park_name": ride.park.name,
@@ -206,22 +169,16 @@ class RideLocationAdapter(BaseLocationAdapter):
                "ride_type": getattr(ride, "ride_type", "Unknown"),
                "status": getattr(ride, "status", "UNKNOWN"),
                "rating": (
                    float(ride.average_rating)
                    if hasattr(ride, "average_rating") and ride.average_rating
                    else None
                    float(ride.average_rating) if hasattr(ride, "average_rating") and ride.average_rating else None
                ),
                "manufacturer": (
                    getattr(ride, "manufacturer", {}).get("name")
                    if hasattr(ride, "manufacturer")
                    else None
                    getattr(ride, "manufacturer", {}).get("name") if hasattr(ride, "manufacturer") else None
                ),
            },
            type_data={
                "slug": ride.slug,
                "opening_date": (
                    ride.opening_date.isoformat()
                    if hasattr(ride, "opening_date") and ride.opening_date
                    else None
                    ride.opening_date.isoformat() if hasattr(ride, "opening_date") and ride.opening_date else None
                ),
                "height_requirement": getattr(ride, "height_requirement", ""),
                "duration_minutes": getattr(ride, "duration_minutes", None),
@@ -240,9 +197,9 @@ class RideLocationAdapter(BaseLocationAdapter):
        filters: MapFilters | None = None,
    ) -> QuerySet:
        """Get optimized queryset for ride locations."""
        queryset = RideLocation.objects.select_related(
            "ride", "ride__park", "ride__park__operator"
        ).filter(point__isnull=False)
        queryset = RideLocation.objects.select_related("ride", "ride__park", "ride__park__operator").filter(
            point__isnull=False
        )

        # Spatial filtering
        if bounds:
@@ -263,11 +220,7 @@ class RideLocationAdapter(BaseLocationAdapter):
        ride_type = getattr(ride, "ride_type", "").lower()
        if "coaster" in ride_type or "roller" in ride_type:
            weight += 1
        if (
            hasattr(ride, "average_rating")
            and ride.average_rating
            and ride.average_rating > 4.0
        ):
        if hasattr(ride, "average_rating") and ride.average_rating and ride.average_rating > 4.0:
            weight += 1
        return min(weight, 3)  # Cap at 3 for rides

@@ -292,9 +245,7 @@ class RideLocationAdapter(BaseLocationAdapter):
class CompanyLocationAdapter(BaseLocationAdapter):
    """Converts Company/CompanyHeadquarters to UnifiedLocation."""

    def to_unified_location(
        self, location_obj: CompanyHeadquarters
    ) -> UnifiedLocation | None:
    def to_unified_location(self, location_obj: CompanyHeadquarters) -> UnifiedLocation | None:
        """Convert CompanyHeadquarters to UnifiedLocation."""
        # Note: CompanyHeadquarters doesn't have coordinates, so we need to geocode
        # For now, we'll skip companies without coordinates
@@ -312,13 +263,9 @@ class CompanyLocationAdapter(BaseLocationAdapter):
        # Company-specific filters
        if filters:
            if filters.company_roles:
                queryset = queryset.filter(
                    company__roles__overlap=filters.company_roles
                )
                queryset = queryset.filter(company__roles__overlap=filters.company_roles)
            if filters.search_query:
                queryset = queryset.filter(
                    company__name__icontains=filters.search_query
                )
                queryset = queryset.filter(company__name__icontains=filters.search_query)
            if filters.country:
                queryset = queryset.filter(country=filters.country)
            if filters.city:
@@ -354,11 +301,7 @@ class LocationAbstractionLayer:
        all_locations = []

        # Determine which location types to include
        location_types = (
            filters.location_types
            if filters and filters.location_types
            else set(LocationType)
        )
        location_types = filters.location_types if filters and filters.location_types else set(LocationType)

        for location_type in location_types:
            adapter = self.adapters[location_type]
@@ -379,25 +322,17 @@ class LocationAbstractionLayer:
        queryset = adapter.get_queryset(bounds, filters)
        return adapter.bulk_convert(queryset)

    def get_location_by_id(
        self, location_type: LocationType, location_id: int
    ) -> UnifiedLocation | None:
    def get_location_by_id(self, location_type: LocationType, location_id: int) -> UnifiedLocation | None:
        """Get single location with full details."""
        adapter = self.adapters[location_type]

        try:
            if location_type == LocationType.PARK:
                obj = ParkLocation.objects.select_related("park", "park__operator").get(
                    park_id=location_id
                )
                obj = ParkLocation.objects.select_related("park", "park__operator").get(park_id=location_id)
            elif location_type == LocationType.RIDE:
                obj = RideLocation.objects.select_related("ride", "ride__park").get(
                    ride_id=location_id
                )
                obj = RideLocation.objects.select_related("ride", "ride__park").get(ride_id=location_id)
            elif location_type == LocationType.COMPANY:
                obj = CompanyHeadquarters.objects.select_related("company").get(
                    company_id=location_id
                )
                obj = CompanyHeadquarters.objects.select_related("company").get(company_id=location_id)
            # LocationType.GENERIC removed - generic location app deprecated
            else:
                return None

@@ -128,9 +128,7 @@ class LocationSearchService:
        # Apply max results limit
        return results[: filters.max_results]

    def _search_parks(
        self, filters: LocationSearchFilters
    ) -> list[LocationSearchResult]:
    def _search_parks(self, filters: LocationSearchFilters) -> list[LocationSearchResult]:
        """Search parks with location data."""
        queryset = Park.objects.select_related("location", "operator").all()

@@ -154,9 +152,9 @@ class LocationSearchService:

        # Add distance annotation if proximity search
        if filters.location_point and filters.include_distance:
            queryset = queryset.annotate(
                distance=Distance("location__point", filters.location_point)
            ).order_by("distance")
            queryset = queryset.annotate(distance=Distance("location__point", filters.location_point)).order_by(
                "distance"
            )

        # Convert to search results
        results = []
@@ -166,11 +164,7 @@ class LocationSearchService:
                object_id=park.id,
                name=park.name,
                description=park.description,
                url=(
                    park.get_absolute_url()
                    if hasattr(park, "get_absolute_url")
                    else None
                ),
                url=(park.get_absolute_url() if hasattr(park, "get_absolute_url") else None),
                status=park.get_status_display(),
                rating=(float(park.average_rating) if park.average_rating else None),
                tags=["park", park.status.lower()],
@@ -187,20 +181,14 @@ class LocationSearchService:
                result.country = location.country

            # Add distance if proximity search
            if (
                filters.location_point
                and filters.include_distance
                and hasattr(park, "distance")
            ):
            if filters.location_point and filters.include_distance and hasattr(park, "distance"):
                result.distance_km = float(park.distance.km)

            results.append(result)

        return results

    def _search_rides(
        self, filters: LocationSearchFilters
    ) -> list[LocationSearchResult]:
    def _search_rides(self, filters: LocationSearchFilters) -> list[LocationSearchResult]:
        """Search rides with location data."""
        queryset = Ride.objects.select_related("park", "location").all()

@@ -223,9 +211,9 @@ class LocationSearchService:

        # Add distance annotation if proximity search
        if filters.location_point and filters.include_distance:
            queryset = queryset.annotate(
                distance=Distance("location__point", filters.location_point)
            ).order_by("distance")
            queryset = queryset.annotate(distance=Distance("location__point", filters.location_point)).order_by(
                "distance"
            )

        # Convert to search results
        results = []
@@ -235,11 +223,7 @@ class LocationSearchService:
                object_id=ride.id,
                name=ride.name,
                description=ride.description,
                url=(
                    ride.get_absolute_url()
                    if hasattr(ride, "get_absolute_url")
                    else None
                ),
                url=(ride.get_absolute_url() if hasattr(ride, "get_absolute_url") else None),
                status=ride.status,
                tags=[
                    "ride",
@@ -253,18 +237,10 @@ class LocationSearchService:
                location = ride.location
                result.latitude = location.latitude
                result.longitude = location.longitude
                result.address = (
                    f"{ride.park.name} - {location.park_area}"
                    if location.park_area
                    else ride.park.name
                )
                result.address = f"{ride.park.name} - {location.park_area}" if location.park_area else ride.park.name

                # Add distance if proximity search
                if (
                    filters.location_point
                    and filters.include_distance
                    and hasattr(ride, "distance")
                ):
                if filters.location_point and filters.include_distance and hasattr(ride, "distance"):
                    result.distance_km = float(ride.distance.km)

            # Fall back to park location if no specific ride location
@@ -281,16 +257,12 @@ class LocationSearchService:

        return results

    def _search_companies(
        self, filters: LocationSearchFilters
    ) -> list[LocationSearchResult]:
    def _search_companies(self, filters: LocationSearchFilters) -> list[LocationSearchResult]:
        """Search companies with headquarters location data."""
        queryset = Company.objects.select_related("headquarters").all()

        # Apply location filters
        queryset = self._apply_location_filters(
            queryset, filters, "headquarters__point"
        )
        queryset = self._apply_location_filters(queryset, filters, "headquarters__point")

        # Apply text search
        if filters.search_query:
@@ -309,9 +281,9 @@ class LocationSearchService:

        # Add distance annotation if proximity search
        if filters.location_point and filters.include_distance:
            queryset = queryset.annotate(
                distance=Distance("headquarters__point", filters.location_point)
            ).order_by("distance")
            queryset = queryset.annotate(distance=Distance("headquarters__point", filters.location_point)).order_by(
                "distance"
            )

        # Convert to search results
        results = []
@@ -321,11 +293,7 @@ class LocationSearchService:
                object_id=company.id,
                name=company.name,
                description=company.description,
                url=(
                    company.get_absolute_url()
                    if hasattr(company, "get_absolute_url")
                    else None
                ),
                url=(company.get_absolute_url() if hasattr(company, "get_absolute_url") else None),
                tags=["company"] + (company.roles or []),
            )

@@ -340,20 +308,14 @@ class LocationSearchService:
                result.country = hq.country

            # Add distance if proximity search
            if (
                filters.location_point
                and filters.include_distance
                and hasattr(company, "distance")
            ):
            if filters.location_point and filters.include_distance and hasattr(company, "distance"):
                result.distance_km = float(company.distance.km)

            results.append(result)

        return results

    def _apply_location_filters(
        self, queryset, filters: LocationSearchFilters, point_field: str
    ):
    def _apply_location_filters(self, queryset, filters: LocationSearchFilters, point_field: str):
        """Apply common location filters to a queryset."""

        # Proximity filter
@@ -371,31 +333,21 @@ class LocationSearchService:
        # Geographic filters - adjust field names based on model
        if filters.country:
            if "headquarters" in point_field:
                queryset = queryset.filter(
                    headquarters__country__icontains=filters.country
                )
                queryset = queryset.filter(headquarters__country__icontains=filters.country)
            else:
                location_field = point_field.split("__")[0]
                queryset = queryset.filter(
                    **{f"{location_field}__country__icontains": filters.country}
                )
                queryset = queryset.filter(**{f"{location_field}__country__icontains": filters.country})

        if filters.state:
            if "headquarters" in point_field:
                queryset = queryset.filter(
                    headquarters__state_province__icontains=filters.state
                )
                queryset = queryset.filter(headquarters__state_province__icontains=filters.state)
            else:
                location_field = point_field.split("__")[0]
                queryset = queryset.filter(
                    **{f"{location_field}__state__icontains": filters.state}
                )
                queryset = queryset.filter(**{f"{location_field}__state__icontains": filters.state})

        if filters.city:
            location_field = point_field.split("__")[0]
            queryset = queryset.filter(
                **{f"{location_field}__city__icontains": filters.city}
            )
            queryset = queryset.filter(**{f"{location_field}__city__icontains": filters.city})

        return queryset

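A sketch of the dynamic-kwargs pattern used above: the lookup path is built
as a string and splatted into filter() (model and values are illustrative):

    location_field = "location__point".split("__")[0]  # -> "location"
    kwargs = {f"{location_field}__city__icontains": "Sandusky"}
    queryset = Park.objects.filter(**kwargs)  # == .filter(location__city__icontains="Sandusky")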
@@ -417,9 +369,7 @@ class LocationSearchService:

        # Get park location suggestions
        park_locations = ParkLocation.objects.filter(
            Q(park__name__icontains=query)
            | Q(city__icontains=query)
            | Q(state__icontains=query)
            Q(park__name__icontains=query) | Q(city__icontains=query) | Q(state__icontains=query)
        ).select_related("park")[: limit // 3]

        for location in park_locations:
@@ -429,11 +379,7 @@ class LocationSearchService:
                "name": location.park.name,
                "address": location.formatted_address,
                "coordinates": location.coordinates,
                "url": (
                    location.park.get_absolute_url()
                    if hasattr(location.park, "get_absolute_url")
                    else None
                ),
                "url": (location.park.get_absolute_url() if hasattr(location.park, "get_absolute_url") else None),
            }
        )


@@ -93,9 +93,7 @@ class MapCacheService:

        return ":".join(key_parts)

    def get_location_detail_cache_key(
        self, location_type: str, location_id: int
    ) -> str:
    def get_location_detail_cache_key(self, location_type: str, location_id: int) -> str:
        """Generate cache key for individual location details."""
        return f"{self.DETAIL_PREFIX}:{location_type}:{location_id}"

@@ -137,9 +135,7 @@ class MapCacheService:
        except Exception as e:
            print(f"Cache write error for clusters {cache_key}: {e}")

    def cache_map_response(
        self, cache_key: str, response: MapResponse, ttl: int | None = None
    ) -> None:
    def cache_map_response(self, cache_key: str, response: MapResponse, ttl: int | None = None) -> None:
        """Cache complete map response."""
        try:
            cache_data = response.to_dict()
@@ -212,24 +208,18 @@ class MapCacheService:
            self.cache_stats["misses"] += 1
            return None

    def invalidate_location_cache(
        self, location_type: str, location_id: int | None = None
    ) -> None:
    def invalidate_location_cache(self, location_type: str, location_id: int | None = None) -> None:
        """Invalidate cache for specific location or all locations of a type."""
        try:
            if location_id:
                # Invalidate specific location detail
                detail_key = self.get_location_detail_cache_key(
                    location_type, location_id
                )
                detail_key = self.get_location_detail_cache_key(location_type, location_id)
                cache.delete(detail_key)

            # Invalidate related location and cluster caches
            # In a production system, you'd want more sophisticated cache
            # tagging
            cache.delete_many(
                [f"{self.LOCATIONS_PREFIX}:*", f"{self.CLUSTERS_PREFIX}:*"]
            )
            cache.delete_many([f"{self.LOCATIONS_PREFIX}:*", f"{self.CLUSTERS_PREFIX}:*"])

            self.cache_stats["invalidations"] += 1

@@ -271,11 +261,7 @@ class MapCacheService:
    def get_cache_stats(self) -> dict[str, Any]:
        """Get cache performance statistics."""
        total_requests = self.cache_stats["hits"] + self.cache_stats["misses"]
        hit_rate = (
            (self.cache_stats["hits"] / total_requests * 100)
            if total_requests > 0
            else 0
        )
        hit_rate = (self.cache_stats["hits"] / total_requests * 100) if total_requests > 0 else 0

        return {
            "hits": self.cache_stats["hits"],
@@ -408,12 +394,8 @@ class MapCacheService:

    def _dict_to_map_response(self, data: dict[str, Any]) -> MapResponse:
        """Convert dictionary back to MapResponse object."""
        locations = [
            self._dict_to_unified_location(loc) for loc in data.get("locations", [])
        ]
        clusters = [
            self._dict_to_cluster_data(cluster) for cluster in data.get("clusters", [])
        ]
        locations = [self._dict_to_unified_location(loc) for loc in data.get("locations", [])]
        clusters = [self._dict_to_cluster_data(cluster) for cluster in data.get("clusters", [])]

        bounds = None
        if data.get("bounds"):

@@ -67,17 +67,13 @@ class UnifiedMapService:
            # Generate cache key
            cache_key = None
            if use_cache:
                cache_key = self._generate_cache_key(
                    bounds, filters, zoom_level, cluster
                )
                cache_key = self._generate_cache_key(bounds, filters, zoom_level, cluster)

                # Try to get from cache first
                cached_response = self.cache_service.get_cached_map_response(cache_key)
                if cached_response:
                    cached_response.cache_hit = True
                    cached_response.query_time_ms = int(
                        (time.time() - start_time) * 1000
                    )
                    cached_response.query_time_ms = int((time.time() - start_time) * 1000)
                    return cached_response

            # Get locations from database
@@ -87,21 +83,15 @@ class UnifiedMapService:
            locations = self._apply_smart_limiting(locations, bounds, zoom_level)

            # Determine if clustering should be applied
            should_cluster = cluster and self.clustering_service.should_cluster(
                zoom_level, len(locations)
            )
            should_cluster = cluster and self.clustering_service.should_cluster(zoom_level, len(locations))

            # Apply clustering if needed
            clusters = []
            if should_cluster:
                locations, clusters = self.clustering_service.cluster_locations(
                    locations, zoom_level, bounds
                )
                locations, clusters = self.clustering_service.cluster_locations(locations, zoom_level, bounds)

            # Calculate response bounds
            response_bounds = self._calculate_response_bounds(
                locations, clusters, bounds
            )
            response_bounds = self._calculate_response_bounds(locations, clusters, bounds)

            # Create response
            response = MapResponse(
@@ -144,9 +134,7 @@ class UnifiedMapService:
                cache_hit=False,
            )

    def get_location_details(
        self, location_type: str, location_id: int
    ) -> UnifiedLocation | None:
    def get_location_details(self, location_type: str, location_id: int) -> UnifiedLocation | None:
        """
        Get detailed information for a specific location.

@@ -159,18 +147,14 @@ class UnifiedMapService:
        """
        try:
            # Check cache first
            cache_key = self.cache_service.get_location_detail_cache_key(
                location_type, location_id
            )
            cache_key = self.cache_service.get_location_detail_cache_key(location_type, location_id)
            cached_locations = self.cache_service.get_cached_locations(cache_key)
            if cached_locations:
                return cached_locations[0] if cached_locations else None

            # Get from database
            location_type_enum = LocationType(location_type.lower())
            location = self.location_layer.get_location_by_id(
                location_type_enum, location_id
            )
            location = self.location_layer.get_location_by_id(location_type_enum, location_id)

            # Cache the result
            if location:
@@ -245,19 +229,13 @@ class UnifiedMapService:
        """
        try:
            bounds = GeoBounds(north=north, south=south, east=east, west=west)
            filters = (
                MapFilters(location_types=location_types) if location_types else None
            )
            filters = MapFilters(location_types=location_types) if location_types else None

            return self.get_map_data(
                bounds=bounds, filters=filters, zoom_level=zoom_level
            )
            return self.get_map_data(bounds=bounds, filters=filters, zoom_level=zoom_level)

        except ValueError:
            # Invalid bounds
            return MapResponse(
                locations=[], clusters=[], total_count=0, filtered_count=0
            )
            return MapResponse(locations=[], clusters=[], total_count=0, filtered_count=0)

    def get_clustered_locations(
        self,
@@ -276,9 +254,7 @@ class UnifiedMapService:
        Returns:
            MapResponse with clustered data
        """
        return self.get_map_data(
            bounds=bounds, filters=filters, zoom_level=zoom_level, cluster=True
        )
        return self.get_map_data(bounds=bounds, filters=filters, zoom_level=zoom_level, cluster=True)

    def get_locations_by_type(
        self,
@@ -299,9 +275,7 @@ class UnifiedMapService:
        """
        try:
            filters = MapFilters(location_types={location_type})
            locations = self.location_layer.get_locations_by_type(
                location_type, bounds, filters
            )
            locations = self.location_layer.get_locations_by_type(location_type, bounds, filters)

            if limit:
                locations = locations[:limit]
@@ -346,9 +320,7 @@ class UnifiedMapService:
            "service_version": "1.0.0",
        }

    def _get_locations_from_db(
        self, bounds: GeoBounds | None, filters: MapFilters | None
    ) -> list[UnifiedLocation]:
    def _get_locations_from_db(self, bounds: GeoBounds | None, filters: MapFilters | None) -> list[UnifiedLocation]:
        """Get locations from database using the abstraction layer."""
        return self.location_layer.get_all_locations(bounds, filters)

@@ -363,10 +335,7 @@ class UnifiedMapService:
            major_parks = [
                loc
                for loc in locations
                if (
                    loc.type == LocationType.PARK
                    and loc.cluster_category in ["major_park", "theme_park"]
                )
                if (loc.type == LocationType.PARK and loc.cluster_category in ["major_park", "theme_park"])
            ]
            return major_parks[:200]
        elif zoom_level < 10:  # Regional level
@@ -398,9 +367,7 @@ class UnifiedMapService:
            return None

        lats, lngs = zip(*all_coords, strict=False)
        return GeoBounds(
            north=max(lats), south=min(lats), east=max(lngs), west=min(lngs)
        )
        return GeoBounds(north=max(lats), south=min(lats), east=max(lngs), west=min(lngs))

    def _get_applied_filters_list(self, filters: MapFilters | None) -> list[str]:
        """Get list of applied filter types for metadata."""
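A tiny worked example of the bounds computation above (coordinates are made
up):

    all_coords = [(41.48, -82.68), (28.42, -81.58), (40.88, -74.26)]
    lats, lngs = zip(*all_coords, strict=False)
    # north=41.48, south=28.42, east=-74.26, west=-82.68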
@@ -438,13 +405,9 @@ class UnifiedMapService:
    ) -> str:
        """Generate cache key for the request."""
        if cluster:
            return self.cache_service.get_clusters_cache_key(
                bounds, filters, zoom_level
            )
            return self.cache_service.get_clusters_cache_key(bounds, filters, zoom_level)
        else:
            return self.cache_service.get_locations_cache_key(
                bounds, filters, zoom_level
            )
            return self.cache_service.get_locations_cache_key(bounds, filters, zoom_level)

    def _record_performance_metrics(
        self,

@@ -21,9 +21,7 @@ class MediaService:
    """Shared service for media upload and processing operations."""

    @staticmethod
    def generate_upload_path(
        domain: str, identifier: str, filename: str, subdirectory: str | None = None
    ) -> str:
    def generate_upload_path(domain: str, identifier: str, filename: str, subdirectory: str | None = None) -> str:
        """
        Generate standardized upload path for media files.

@@ -83,9 +81,7 @@ class MediaService:
        """
        try:
            # Check file size
            max_size = getattr(
                settings, "MAX_PHOTO_SIZE", 10 * 1024 * 1024
            )  # 10MB default
            max_size = getattr(settings, "MAX_PHOTO_SIZE", 10 * 1024 * 1024)  # 10MB default
            if image_file.size > max_size:
                return (
                    False,

@@ -32,7 +32,7 @@ class MediaURLService:
            slug = slugify(caption)
            # Limit length to avoid overly long URLs
            if len(slug) > 50:
                slug = slug[:50].rsplit('-', 1)[0]  # Cut at word boundary
                slug = slug[:50].rsplit("-", 1)[0]  # Cut at word boundary
            return f"{slug}-{photo_id}.{extension}"
        else:
            return f"photo-{photo_id}.{extension}"
@@ -55,13 +55,15 @@ class MediaURLService:

        # Add variant to filename if not public
        if variant != "public":
            name, ext = filename.rsplit('.', 1)
            name, ext = filename.rsplit(".", 1)
            filename = f"{name}-{variant}.{ext}"

        return f"/parks/{park_slug}/photos/{filename}"

    @staticmethod
    def generate_ride_photo_url(park_slug: str, ride_slug: str, caption: str, photo_id: int, variant: str = "public") -> str:
    def generate_ride_photo_url(
        park_slug: str, ride_slug: str, caption: str, photo_id: int, variant: str = "public"
    ) -> str:
        """
        Generate a friendly URL for a ride photo.

@@ -78,7 +80,7 @@ class MediaURLService:
        filename = MediaURLService.generate_friendly_filename(caption, photo_id)

        if variant != "public":
            name, ext = filename.rsplit('.', 1)
            name, ext = filename.rsplit(".", 1)
            filename = f"{name}-{variant}.{ext}"

        return f"/parks/{park_slug}/rides/{ride_slug}/photos/{filename}"
@@ -95,7 +97,7 @@ class MediaURLService:
            Dict with photo_id and variant, or None if parsing fails
        """
        # Remove extension
        name = filename.rsplit('.', 1)[0]
        name = filename.rsplit(".", 1)[0]

        # Check for variant suffix
        variant = "public"
@@ -104,17 +106,14 @@ class MediaURLService:
        for v in variant_patterns:
            if name.endswith(f"-{v}"):
                variant = v
                name = name[:-len(f"-{v}")]
                name = name[: -len(f"-{v}")]
                break

        # Extract photo ID (should be the last number)
        match = re.search(r'-(\d+)$', name)
        match = re.search(r"-(\d+)$", name)
        if match:
            photo_id = int(match.group(1))
            return {
                "photo_id": photo_id,
                "variant": variant
            }
            return {"photo_id": photo_id, "variant": variant}

        return None


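A round-trip sketch of the helpers above ("thumb" stands in for whatever
variant_patterns contains; extension handling lives outside these hunks):

    MediaURLService.generate_friendly_filename("Blue Streak at night", 42)
    # -> "blue-streak-at-night-42.jpg" (assuming a jpg extension)
    MediaURLService.parse_friendly_filename("blue-streak-at-night-42-thumb.jpg")
    # -> {"photo_id": 42, "variant": "thumb"}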
@@ -53,9 +53,7 @@ def monitor_performance(operation_name: str, **tags):
|
||||
)
|
||||
|
||||
# Log performance data
|
||||
log_level = (
|
||||
logging.WARNING if duration > 2.0 or total_queries > 10 else logging.INFO
|
||||
)
|
||||
log_level = logging.WARNING if duration > 2.0 or total_queries > 10 else logging.INFO
|
||||
logger.log(
|
||||
log_level,
|
||||
f"Performance: {operation_name} completed in {duration:.3f}s with {
|
||||
@@ -108,11 +106,7 @@ def track_queries(operation_name: str, warn_threshold: int = 10):
|
||||
recent_queries = connection.queries[-total_queries:]
|
||||
query_details = [
|
||||
{
|
||||
"sql": (
|
||||
query["sql"][:200] + "..."
|
||||
if len(query["sql"]) > 200
|
||||
else query["sql"]
|
||||
),
|
||||
"sql": (query["sql"][:200] + "..." if len(query["sql"]) > 200 else query["sql"]),
|
||||
"time": float(query["time"]),
|
||||
}
|
||||
for query in recent_queries
|
||||
@@ -127,14 +121,12 @@ def track_queries(operation_name: str, warn_threshold: int = 10):
|
||||
|
||||
if total_queries > warn_threshold or execution_time > 1.0:
|
||||
logger.warning(
|
||||
f"Performance concern in {operation_name}: "
|
||||
f"{total_queries} queries, {execution_time:.2f}s",
|
||||
f"Performance concern in {operation_name}: " f"{total_queries} queries, {execution_time:.2f}s",
|
||||
extra=performance_data,
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
f"Query tracking for {operation_name}: "
|
||||
f"{total_queries} queries, {execution_time:.2f}s",
|
||||
f"Query tracking for {operation_name}: " f"{total_queries} queries, {execution_time:.2f}s",
|
||||
extra=performance_data,
|
||||
)
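For reference, the core of the pattern these hunks reformat reduces to a context manager over Django's connection.queries list (which is only populated when DEBUG is True). A hedged sketch, not the project's actual implementation:

    import logging
    import time
    from contextlib import contextmanager

    from django.db import connection

    logger = logging.getLogger(__name__)

    @contextmanager
    def track_queries_sketch(operation_name: str, warn_threshold: int = 10):
        # Snapshot the query count and clock before the block runs.
        start_queries = len(connection.queries)
        start_time = time.monotonic()
        try:
            yield
        finally:
            execution_time = time.monotonic() - start_time
            total_queries = len(connection.queries) - start_queries
            # Same escalation rule as above: warn on many queries or slow blocks.
            level = logging.WARNING if total_queries > warn_threshold or execution_time > 1.0 else logging.DEBUG
            logger.log(level, f"{operation_name}: {total_queries} queries, {execution_time:.2f}s")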
@@ -221,9 +213,7 @@ class PerformanceProfiler:
"total_queries": total_queries,
"checkpoints": self.checkpoints,
"memory_usage": self.memory_usage,
"queries_per_second": (
total_queries / total_duration if total_duration > 0 else 0
),
"queries_per_second": (total_queries / total_duration if total_duration > 0 else 0),
}

# Calculate checkpoint intervals
@@ -237,8 +227,7 @@ class PerformanceProfiler:
"from": prev["name"],
"to": curr["name"],
"duration": curr["elapsed_seconds"] - prev["elapsed_seconds"],
"queries": curr["queries_since_start"]
- prev["queries_since_start"],
"queries": curr["queries_since_start"] - prev["queries_since_start"],
}
)
report["checkpoint_intervals"] = intervals
@@ -288,9 +277,7 @@ class DatabaseQueryAnalyzer:
query_types[query_type] = query_types.get(query_type, 0) + 1

# Find slow queries (top 10% by time)
sorted_queries = sorted(
queries, key=lambda q: float(q.get("time", 0)), reverse=True
)
sorted_queries = sorted(queries, key=lambda q: float(q.get("time", 0)), reverse=True)
slow_query_count = max(1, query_count // 10)
slow_queries = sorted_queries[:slow_query_count]

@@ -302,9 +289,7 @@ class DatabaseQueryAnalyzer:
signature = " ".join(sql.split()) # Normalize whitespace
query_signatures[signature] = query_signatures.get(signature, 0) + 1

duplicates = {
sig: count for sig, count in query_signatures.items() if count > 1
}
duplicates = {sig: count for sig, count in query_signatures.items() if count > 1}

analysis = {
"total_queries": query_count,
@@ -313,21 +298,13 @@ class DatabaseQueryAnalyzer:
"query_types": query_types,
"slow_queries": [
{
"sql": (
q.get("sql", "")[:200] + "..."
if len(q.get("sql", "")) > 200
else q.get("sql", "")
),
"sql": (q.get("sql", "")[:200] + "..." if len(q.get("sql", "")) > 200 else q.get("sql", "")),
"time": float(q.get("time", 0)),
}
for q in slow_queries
],
"duplicate_query_count": len(duplicates),
"duplicate_queries": (
duplicates
if len(duplicates) <= 10
else dict(list(duplicates.items())[:10])
),
"duplicate_queries": (duplicates if len(duplicates) <= 10 else dict(list(duplicates.items())[:10])),
}

return analysis
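The duplicate-detection step above hinges on whitespace normalization: two queries that differ only in formatting collapse to one signature. Isolated as a sketch:

    def find_duplicate_queries(queries: list[dict]) -> dict[str, int]:
        # Normalize whitespace so formatting differences don't hide repeats.
        signatures: dict[str, int] = {}
        for query in queries:
            signature = " ".join(query.get("sql", "").split())
            signatures[signature] = signatures.get(signature, 0) + 1
        return {sig: count for sig, count in signatures.items() if count > 1}

    sample = [{"sql": "SELECT *  FROM parks"}, {"sql": "SELECT * FROM parks"}]
    assert find_duplicate_queries(sample) == {"SELECT * FROM parks": 2}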
@@ -348,9 +325,7 @@ def monitor_function_performance(operation_name: str | None = None):
@wraps(func)
def wrapper(*args, **kwargs):
name = operation_name or f"{func.__module__}.{func.__name__}"
with monitor_performance(
name, function=func.__name__, module=func.__module__
):
with monitor_performance(name, function=func.__name__, module=func.__module__):
return func(*args, **kwargs)

return wrapper
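Hypothetical usage of the decorator, assuming it is importable from the module these hunks touch (the exact path is not shown in this diff):

    # Assumed import path; get_trending_parks is an illustrative function name.
    from apps.core.utils.performance import monitor_function_performance

    @monitor_function_performance("trending.calculate")
    def get_trending_parks(limit: int = 10):
        ...

    @monitor_function_performance()  # name falls back to f"{module}.{function}"
    def rebuild_counts():
        ...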
@@ -74,9 +74,7 @@ class TrendingService:
if not force_refresh:
cached_result = cache.get(cache_key)
if cached_result is not None:
self.logger.debug(
f"Returning cached trending results for {content_type}"
)
self.logger.debug(f"Returning cached trending results for {content_type}")
return cached_result

self.logger.info(f"Getting trending content for {content_type}")
@@ -86,15 +84,11 @@ class TrendingService:
trending_items = []

if content_type in ["all", "parks"]:
park_items = self._calculate_trending_parks(
limit * 2 if content_type == "all" else limit
)
park_items = self._calculate_trending_parks(limit * 2 if content_type == "all" else limit)
trending_items.extend(park_items)

if content_type in ["all", "rides"]:
ride_items = self._calculate_trending_rides(
limit * 2 if content_type == "all" else limit
)
ride_items = self._calculate_trending_rides(limit * 2 if content_type == "all" else limit)
trending_items.extend(ride_items)

# Sort by trending score and apply limit
@@ -107,9 +101,7 @@ class TrendingService:
# Cache results
cache.set(cache_key, formatted_results, self.CACHE_TTL)

self.logger.info(
f"Calculated {len(formatted_results)} trending items for {content_type}"
)
self.logger.info(f"Calculated {len(formatted_results)} trending items for {content_type}")
return formatted_results

except Exception as e:
@@ -140,9 +132,7 @@ class TrendingService:
if not force_refresh:
cached_result = cache.get(cache_key)
if cached_result is not None:
self.logger.debug(
f"Returning cached new content results for {content_type}"
)
self.logger.debug(f"Returning cached new content results for {content_type}")
return cached_result

self.logger.info(f"Getting new content for {content_type}")
@@ -153,15 +143,11 @@ class TrendingService:
new_items = []

if content_type in ["all", "parks"]:
parks = self._get_new_parks(
cutoff_date, limit * 2 if content_type == "all" else limit
)
parks = self._get_new_parks(cutoff_date, limit * 2 if content_type == "all" else limit)
new_items.extend(parks)

if content_type in ["all", "rides"]:
rides = self._get_new_rides(
cutoff_date, limit * 2 if content_type == "all" else limit
)
rides = self._get_new_rides(cutoff_date, limit * 2 if content_type == "all" else limit)
new_items.extend(rides)

# Sort by date added (most recent first) and apply limit
@@ -174,9 +160,7 @@ class TrendingService:
# Cache results
cache.set(cache_key, formatted_results, 1800) # Cache for 30 minutes

self.logger.info(
f"Calculated {len(formatted_results)} new items for {content_type}"
)
self.logger.info(f"Calculated {len(formatted_results)} new items for {content_type}")
return formatted_results

except Exception as e:
@@ -185,9 +169,7 @@ class TrendingService:

def _calculate_trending_parks(self, limit: int) -> list[dict[str, Any]]:
"""Calculate trending scores for parks."""
parks = Park.objects.filter(status="OPERATING").select_related(
"location", "operator", "card_image"
)
parks = Park.objects.filter(status="OPERATING").select_related("location", "operator", "card_image")

trending_parks = []

@@ -216,9 +198,7 @@ class TrendingService:
# Get card image URL
card_image_url = ""
if park.card_image and hasattr(park.card_image, "image"):
card_image_url = (
park.card_image.image.url if park.card_image.image else ""
)
card_image_url = park.card_image.image.url if park.card_image.image else ""

# Get primary company (operator)
primary_company = park.operator.name if park.operator else ""
@@ -233,14 +213,8 @@ class TrendingService:
"slug": park.slug,
"park": park.name, # For parks, park field is the park name itself
"category": "park",
"rating": (
float(park.average_rating)
if park.average_rating
else 0.0
),
"date_opened": (
opening_date.isoformat() if opening_date else ""
),
"rating": (float(park.average_rating) if park.average_rating else 0.0),
"date_opened": (opening_date.isoformat() if opening_date else ""),
"url": park.url,
"card_image": card_image_url,
"city": city,
@@ -256,9 +230,7 @@ class TrendingService:

def _calculate_trending_rides(self, limit: int) -> list[dict[str, Any]]:
"""Calculate trending scores for rides."""
rides = Ride.objects.filter(status="OPERATING").select_related(
"park", "park__location", "card_image"
)
rides = Ride.objects.filter(status="OPERATING").select_related("park", "park__location", "card_image")

trending_rides = []

@@ -274,9 +246,7 @@ class TrendingService:
# Get card image URL
card_image_url = ""
if ride.card_image and hasattr(ride.card_image, "image"):
card_image_url = (
ride.card_image.image.url if ride.card_image.image else ""
)
card_image_url = ride.card_image.image.url if ride.card_image.image else ""

trending_rides.append(
{
@@ -288,14 +258,8 @@ class TrendingService:
"slug": ride.slug,
"park": ride.park.name if ride.park else "",
"category": "ride",
"rating": (
float(ride.average_rating)
if ride.average_rating
else 0.0
),
"date_opened": (
opening_date.isoformat() if opening_date else ""
),
"rating": (float(ride.average_rating) if ride.average_rating else 0.0),
"date_opened": (opening_date.isoformat() if opening_date else ""),
"url": ride.url,
"park_url": ride.park.url if ride.park else "",
"card_image": card_image_url,
@@ -347,23 +311,17 @@ class TrendingService:
return final_score

except Exception as e:
self.logger.error(
f"Error calculating score for {content_type} {content_obj.id}: {e}"
)
self.logger.error(f"Error calculating score for {content_type} {content_obj.id}: {e}")
return 0.0

def _calculate_view_growth_score(
self, content_type: ContentType, object_id: int
) -> float:
def _calculate_view_growth_score(self, content_type: ContentType, object_id: int) -> float:
"""Calculate normalized view growth score."""
try:
current_views, previous_views, growth_percentage = (
PageView.get_views_growth(
content_type,
object_id,
self.CURRENT_PERIOD_HOURS,
self.PREVIOUS_PERIOD_HOURS,
)
current_views, previous_views, growth_percentage = PageView.get_views_growth(
content_type,
object_id,
self.CURRENT_PERIOD_HOURS,
self.PREVIOUS_PERIOD_HOURS,
)

if previous_views == 0:
@@ -372,9 +330,7 @@ class TrendingService:

# Normalize growth percentage to 0-1 scale
# 100% growth = 0.5, 500% growth = 1.0
normalized_growth = (
min(growth_percentage / 500.0, 1.0) if growth_percentage > 0 else 0.0
)
normalized_growth = min(growth_percentage / 500.0, 1.0) if growth_percentage > 0 else 0.0
return max(normalized_growth, 0.0)
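As written, the normalization divides the growth percentage by 500 and clamps to [0, 1], so 100% growth maps to 0.2 and anything at or above 500% saturates at 1.0 (the inline "100% growth = 0.5" comment appears to describe an earlier scale). The computation in isolation:

    def normalize_growth(growth_percentage: float) -> float:
        # growth / 500, clamped to [0, 1]; non-positive growth scores zero.
        return min(growth_percentage / 500.0, 1.0) if growth_percentage > 0 else 0.0

    assert normalize_growth(100.0) == 0.2
    assert normalize_growth(250.0) == 0.5
    assert normalize_growth(750.0) == 1.0
    assert normalize_growth(-20.0) == 0.0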
except Exception as e:
@@ -421,11 +377,7 @@ class TrendingService:
elif days_since_added <= 30:
return 1.0 - (days_since_added / 30.0) * 0.2 # 1.0 to 0.8
elif days_since_added <= self.RECENCY_BASELINE_DAYS:
return (
0.8
- ((days_since_added - 30) / (self.RECENCY_BASELINE_DAYS - 30))
* 0.7
) # 0.8 to 0.1
return 0.8 - ((days_since_added - 30) / (self.RECENCY_BASELINE_DAYS - 30)) * 0.7 # 0.8 to 0.1
else:
return 0.0
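Pulling the piecewise decay out of the reformatted hunk, with RECENCY_BASELINE_DAYS and the first branch (not visible in this diff) filled in as assumptions:

    RECENCY_BASELINE_DAYS = 365  # assumed value; the constant's definition is not in this hunk

    def recency_score(days_since_added: float) -> float:
        if days_since_added <= 7:  # assumed: very new items score 1.0
            return 1.0
        elif days_since_added <= 30:
            return 1.0 - (days_since_added / 30.0) * 0.2  # decays toward 0.8
        elif days_since_added <= RECENCY_BASELINE_DAYS:
            return 0.8 - ((days_since_added - 30) / (RECENCY_BASELINE_DAYS - 30)) * 0.7  # 0.8 down to 0.1
        return 0.0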
@@ -433,9 +385,7 @@ class TrendingService:
self.logger.warning(f"Error calculating recency score: {e}")
return 0.5

def _calculate_popularity_score(
self, content_type: ContentType, object_id: int
) -> float:
def _calculate_popularity_score(self, content_type: ContentType, object_id: int) -> float:
"""Calculate popularity score based on total view count."""
try:
total_views = PageView.get_total_views_count(
@@ -461,8 +411,7 @@ class TrendingService:
"""Get recently added parks."""
new_parks = (
Park.objects.filter(
Q(created_at__gte=cutoff_date)
| Q(opening_date__gte=cutoff_date.date()),
Q(created_at__gte=cutoff_date) | Q(opening_date__gte=cutoff_date.date()),
status="OPERATING",
)
.select_related("location", "operator", "card_image")
@@ -473,7 +422,7 @@ class TrendingService:
for park in new_parks:
date_added = park.opening_date or park.created_at
# Handle datetime to date conversion
if date_added:
if date_added: # noqa: SIM102
# If it's a datetime, convert to date
if isinstance(date_added, datetime):
date_added = date_added.date()
@@ -500,9 +449,7 @@ class TrendingService:
# Get card image URL
card_image_url = ""
if park.card_image and hasattr(park.card_image, "image"):
card_image_url = (
park.card_image.image.url if park.card_image.image else ""
)
card_image_url = park.card_image.image.url if park.card_image.image else ""

# Get primary company (operator)
primary_company = park.operator.name if park.operator else ""
@@ -533,8 +480,7 @@ class TrendingService:
"""Get recently added rides."""
new_rides = (
Ride.objects.filter(
Q(created_at__gte=cutoff_date)
| Q(opening_date__gte=cutoff_date.date()),
Q(created_at__gte=cutoff_date) | Q(opening_date__gte=cutoff_date.date()),
status="OPERATING",
)
.select_related("park", "park__location", "card_image")
@@ -543,11 +489,9 @@ class TrendingService:

results = []
for ride in new_rides:
date_added = getattr(ride, "opening_date", None) or getattr(
ride, "created_at", None
)
date_added = getattr(ride, "opening_date", None) or getattr(ride, "created_at", None)
# Handle datetime to date conversion
if date_added:
if date_added: # noqa: SIM102
# If it's a datetime, convert to date
if isinstance(date_added, datetime):
date_added = date_added.date()
@@ -561,9 +505,7 @@ class TrendingService:
# Get card image URL
card_image_url = ""
if ride.card_image and hasattr(ride.card_image, "image"):
card_image_url = (
ride.card_image.image.url if ride.card_image.image else ""
)
card_image_url = ride.card_image.image.url if ride.card_image.image else ""

results.append(
{
@@ -584,9 +526,7 @@ class TrendingService:

return results

def _format_trending_results(
self, trending_items: list[dict[str, Any]]
) -> list[dict[str, Any]]:
def _format_trending_results(self, trending_items: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""Format trending results for frontend consumption."""
formatted_results = []

@@ -595,13 +535,11 @@ class TrendingService:
# Get view change for display
content_obj = item["content_object"]
ct = ContentType.objects.get_for_model(content_obj)
current_views, previous_views, growth_percentage = (
PageView.get_views_growth(
ct,
content_obj.id,
self.CURRENT_PERIOD_HOURS,
self.PREVIOUS_PERIOD_HOURS,
)
current_views, previous_views, growth_percentage = PageView.get_views_growth(
ct,
content_obj.id,
self.CURRENT_PERIOD_HOURS,
self.PREVIOUS_PERIOD_HOURS,
)

# Format exactly as frontend expects
@@ -614,9 +552,7 @@ class TrendingService:
"rank": rank,
"views": current_views,
"views_change": (
f"+{growth_percentage:.1f}%"
if growth_percentage > 0
else f"{growth_percentage:.1f}%"
f"+{growth_percentage:.1f}%" if growth_percentage > 0 else f"{growth_percentage:.1f}%"
),
"slug": item["slug"],
"date_opened": item["date_opened"],
@@ -649,9 +585,7 @@ class TrendingService:

return formatted_results

def _format_new_content_results(
self, new_items: list[dict[str, Any]]
) -> list[dict[str, Any]]:
def _format_new_content_results(self, new_items: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""Format new content results for frontend consumption."""
formatted_results = []


@@ -1,4 +1,5 @@
"""State machine utilities for core app."""

from .builder import (
StateTransitionBuilder,
determine_method_name_for_transition,

@@ -60,6 +60,7 @@ See Also:
- apps.core.choices.registry: Central choice registry
- apps.core.state_machine.guards: Guard extraction from metadata
"""

from typing import Any

from django.core.exceptions import ImproperlyConfigured
@@ -129,9 +130,7 @@ class StateTransitionBuilder:
# Validate choice group exists
group = registry.get(choice_group, domain)
if group is None:
raise ImproperlyConfigured(
f"Choice group '{choice_group}' not found in domain '{domain}'"
)
raise ImproperlyConfigured(f"Choice group '{choice_group}' not found in domain '{domain}'")

self.choices = registry.get_choices(choice_group, domain)

@@ -172,20 +171,15 @@ class StateTransitionBuilder:

# Validate all target states exist
for target in transitions:
target_choice = registry.get_choice(
self.choice_group, target, self.domain
)
target_choice = registry.get_choice(self.choice_group, target, self.domain)
if target_choice is None:
raise ImproperlyConfigured(
f"State '{state_value}' references non-existent "
f"transition target '{target}'"
f"State '{state_value}' references non-existent " f"transition target '{target}'"
)

return transitions

def extract_permission_requirements(
self, state_value: str
) -> dict[str, bool]:
def extract_permission_requirements(self, state_value: str) -> dict[str, bool]:
"""
Extract permission requirements from metadata.

@@ -198,9 +192,7 @@ class StateTransitionBuilder:
metadata = self.get_choice_metadata(state_value)
return {
"requires_moderator": metadata.get("requires_moderator", False),
"requires_admin_approval": metadata.get(
"requires_admin_approval", False
),
"requires_admin_approval": metadata.get("requires_admin_approval", False),
}

def is_terminal_state(self, state_value: str) -> bool:

@@ -181,10 +181,7 @@ class TransitionContext:
return self.model_class.__name__

def __str__(self) -> str:
return (
f"TransitionContext({self.model_name}.{self.field_name}: "
f"{self.source_state} → {self.target_state})"
)
return f"TransitionContext({self.model_name}.{self.field_name}: " f"{self.source_state} → {self.target_state})"


class BaseTransitionCallback(ABC):
@@ -324,9 +321,9 @@ class CallbackRegistration:
return False
if self.field_name != field_name:
return False
if self.source != '*' and self.source != source:
if self.source != "*" and self.source != source:
return False
return not (self.target != '*' and self.target != target)
return not (self.target != "*" and self.target != target)


class TransitionCallbackRegistry:
@@ -337,10 +334,10 @@ class TransitionCallbackRegistry:
for specific transitions.
"""

_instance: Optional['TransitionCallbackRegistry'] = None
_instance: Optional["TransitionCallbackRegistry"] = None
_initialized: bool = False

def __new__(cls) -> 'TransitionCallbackRegistry':
def __new__(cls) -> "TransitionCallbackRegistry":
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
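The registry's __new__ override is the classic singleton pattern; in isolation:

    class SingletonRegistry:
        _instance = None

        def __new__(cls) -> "SingletonRegistry":
            # First construction stores the instance; later calls reuse it.
            if cls._instance is None:
                cls._instance = super().__new__(cls)
            return cls._instance

    assert SingletonRegistry() is SingletonRegistry()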
@@ -483,10 +480,7 @@ class TransitionCallbackRegistry:
try:
# Check if callback should execute
if not callback.should_execute(context):
logger.debug(
f"Skipping callback {callback.name} - "
f"should_execute returned False"
)
logger.debug(f"Skipping callback {callback.name} - " f"should_execute returned False")
continue

# Execute callback
@@ -498,30 +492,24 @@ class TransitionCallbackRegistry:
result = callback.execute(context)

if not result:
logger.warning(
f"Callback {callback.name} returned False for {context}"
)
logger.warning(f"Callback {callback.name} returned False for {context}")
failures.append((callback, None))
overall_success = False

if not callback.continue_on_error:
logger.error(
f"Aborting callback chain - {callback.name} failed "
f"and continue_on_error=False"
f"Aborting callback chain - {callback.name} failed " f"and continue_on_error=False"
)
break

except Exception as e:
logger.exception(
f"Callback {callback.name} raised exception for {context}: {e}"
)
logger.exception(f"Callback {callback.name} raised exception for {context}: {e}")
failures.append((callback, e))
overall_success = False

if not callback.continue_on_error:
logger.error(
f"Aborting callback chain - {callback.name} raised exception "
f"and continue_on_error=False"
f"Aborting callback chain - {callback.name} raised exception " f"and continue_on_error=False"
)
break

@@ -540,10 +528,7 @@ class TransitionCallbackRegistry:
self._callbacks[stage] = []
else:
for stage in CallbackStage:
self._callbacks[stage] = [
r for r in self._callbacks[stage]
if r.model_class != model_class
]
self._callbacks[stage] = [r for r in self._callbacks[stage] if r.model_class != model_class]

def get_all_registrations(
self,
@@ -563,10 +548,7 @@ class TransitionCallbackRegistry:

filtered = {}
for stage, registrations in self._callbacks.items():
filtered[stage] = [
r for r in registrations
if r.model_class == model_class
]
filtered[stage] = [r for r in registrations if r.model_class == model_class]
return filtered

@classmethod
@@ -601,9 +583,7 @@ def register_pre_callback(
callback: PreTransitionCallback,
) -> None:
"""Convenience function to register a pre-transition callback."""
callback_registry.register(
model_class, field_name, source, target, callback, CallbackStage.PRE
)
callback_registry.register(model_class, field_name, source, target, callback, CallbackStage.PRE)


def register_post_callback(
@@ -614,9 +594,7 @@ def register_post_callback(
callback: PostTransitionCallback,
) -> None:
"""Convenience function to register a post-transition callback."""
callback_registry.register(
model_class, field_name, source, target, callback, CallbackStage.POST
)
callback_registry.register(model_class, field_name, source, target, callback, CallbackStage.POST)


def register_error_callback(
@@ -627,6 +605,4 @@ def register_error_callback(
callback: ErrorTransitionCallback,
) -> None:
"""Convenience function to register an error callback."""
callback_registry.register(
model_class, field_name, source, target, callback, CallbackStage.ERROR
)
callback_registry.register(model_class, field_name, source, target, callback, CallbackStage.ERROR)
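A hypothetical wiring of these convenience functions, with import paths assumed (the diff does not show the module layout) and wildcards matching as in CallbackRegistration.matches above:

    # Assumed import paths; keyword names follow the signatures shown above.
    from apps.core.state_machine.callbacks import ParkCountUpdateCallback, register_post_callback
    from apps.rides.models.rides import Ride

    register_post_callback(
        model_class=Ride,
        field_name="status",
        source="*",            # "*" matches any source state
        target="CLOSED_PERM",  # fire only on permanent closure
        callback=ParkCountUpdateCallback(),
    )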
@@ -44,8 +44,8 @@ class CacheInvalidationCallback(PostTransitionCallback):

def should_execute(self, context: TransitionContext) -> bool:
"""Check if cache invalidation is enabled."""
callback_settings = getattr(settings, 'STATE_MACHINE_CALLBACKS', {})
if not callback_settings.get('cache_invalidation_enabled', True):
callback_settings = getattr(settings, "STATE_MACHINE_CALLBACKS", {})
if not callback_settings.get("cache_invalidation_enabled", True):
logger.debug("Cache invalidation disabled via settings")
return False
return True
@@ -54,6 +54,7 @@ class CacheInvalidationCallback(PostTransitionCallback):
"""Get the EnhancedCacheService instance."""
try:
from apps.core.services.enhanced_cache_service import EnhancedCacheService

return EnhancedCacheService()
except ImportError:
logger.warning("EnhancedCacheService not available")
@@ -85,11 +86,7 @@ class CacheInvalidationCallback(PostTransitionCallback):

substituted = set()
for pattern in all_patterns:
substituted.add(
pattern
.replace('{id}', instance_id)
.replace('{model}', model_name)
)
substituted.add(pattern.replace("{id}", instance_id).replace("{model}", model_name))

return substituted
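The placeholder substitution above, reduced to a pure function (the {id}/{model} tokens and glob-style patterns are the project's own convention):

    def substitute(patterns: set[str], instance_id: str, model_name: str) -> set[str]:
        # Replace both placeholders in every pattern.
        return {
            p.replace("{id}", instance_id).replace("{model}", model_name)
            for p in patterns
        }

    assert substitute({"*{model}:{id}*"}, "42", "park") == {"*park:42*"}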
@@ -108,20 +105,13 @@ class CacheInvalidationCallback(PostTransitionCallback):
cache_service.invalidate_pattern(pattern)
logger.debug(f"Invalidated cache pattern: {pattern}")
except Exception as e:
logger.warning(
f"Failed to invalidate cache pattern {pattern}: {e}"
)
logger.warning(f"Failed to invalidate cache pattern {pattern}: {e}")

logger.info(
f"Cache invalidation completed for {context}: "
f"{len(patterns)} patterns"
)
logger.info(f"Cache invalidation completed for {context}: " f"{len(patterns)} patterns")
return True

except Exception as e:
logger.exception(
f"Failed to invalidate cache for {context}: {e}"
)
logger.exception(f"Failed to invalidate cache for {context}: {e}")
return False

def _fallback_invalidation(self, context: TransitionContext) -> bool:
@@ -133,8 +123,7 @@ class CacheInvalidationCallback(PostTransitionCallback):
# Django's default cache doesn't support pattern deletion
# Log a warning and return True (don't fail the transition)
logger.warning(
f"EnhancedCacheService not available, skipping pattern "
f"invalidation for {len(patterns)} patterns"
f"EnhancedCacheService not available, skipping pattern " f"invalidation for {len(patterns)} patterns"
)

return True
@@ -155,13 +144,13 @@ class ModelCacheInvalidation(CacheInvalidationCallback):

# Default patterns by model type
MODEL_PATTERNS = {
'Park': ['*park:{id}*', '*parks*', 'geo:*'],
'Ride': ['*ride:{id}*', '*rides*', '*park:*', 'geo:*'],
'EditSubmission': ['*submission:{id}*', '*moderation*'],
'PhotoSubmission': ['*photo:{id}*', '*moderation*'],
'ModerationReport': ['*report:{id}*', '*moderation*'],
'ModerationQueue': ['*queue*', '*moderation*'],
'BulkOperation': ['*operation:{id}*', '*moderation*'],
"Park": ["*park:{id}*", "*parks*", "geo:*"],
"Ride": ["*ride:{id}*", "*rides*", "*park:*", "geo:*"],
"EditSubmission": ["*submission:{id}*", "*moderation*"],
"PhotoSubmission": ["*photo:{id}*", "*moderation*"],
"ModerationReport": ["*report:{id}*", "*moderation*"],
"ModerationQueue": ["*queue*", "*moderation*"],
"BulkOperation": ["*operation:{id}*", "*moderation*"],
}

def __init__(self, **kwargs):
@@ -178,7 +167,7 @@ class ModelCacheInvalidation(CacheInvalidationCallback):
# Substitute {id} placeholder
instance_id = str(context.instance.pk)
for pattern in model_patterns:
base_patterns.append(pattern.replace('{id}', instance_id))
base_patterns.append(pattern.replace("{id}", instance_id))

return base_patterns

@@ -217,14 +206,14 @@ class RelatedModelCacheInvalidation(CacheInvalidationCallback):
continue

# Handle foreign key relationships
if hasattr(related_obj, 'pk'):
if hasattr(related_obj, "pk"):
related_model = type(related_obj).__name__.lower()
related_id = related_obj.pk
patterns.append(f"*{related_model}:{related_id}*")
patterns.append(f"*{related_model}_{related_id}*")

# Handle many-to-many relationships
elif hasattr(related_obj, 'all'):
elif hasattr(related_obj, "all"):
for obj in related_obj.all():
related_model = type(obj).__name__.lower()
related_id = obj.pk
@@ -293,7 +282,7 @@ class APICacheInvalidation(CacheInvalidationCallback):
**kwargs: Additional arguments.
"""
super().__init__(**kwargs)
self.api_prefixes = api_prefixes or ['api:*']
self.api_prefixes = api_prefixes or ["api:*"]
self.include_geo_cache = include_geo_cache

def _get_all_patterns(self, context: TransitionContext) -> set[str]:
@@ -306,8 +295,8 @@ class APICacheInvalidation(CacheInvalidationCallback):

# Add geo cache if requested
if self.include_geo_cache:
patterns.add('geo:*')
patterns.add('map:*')
patterns.add("geo:*")
patterns.add("map:*")

# Add model-specific API patterns
model_name = context.model_name.lower()
@@ -329,10 +318,10 @@ class ParkCacheInvalidation(CacheInvalidationCallback):
def __init__(self, **kwargs):
super().__init__(
patterns=[
'*park:{id}*',
'*parks*',
'api:*',
'geo:*',
"*park:{id}*",
"*parks*",
"api:*",
"geo:*",
],
**kwargs,
)
@@ -346,10 +335,10 @@ class RideCacheInvalidation(CacheInvalidationCallback):
def __init__(self, **kwargs):
super().__init__(
patterns=[
'*ride:{id}*',
'*rides*',
'api:*',
'geo:*',
"*ride:{id}*",
"*rides*",
"api:*",
"geo:*",
],
**kwargs,
)
@@ -359,9 +348,9 @@ class RideCacheInvalidation(CacheInvalidationCallback):
patterns = super()._get_instance_patterns(context)

# Invalidate parent park's cache
park = getattr(context.instance, 'park', None)
park = getattr(context.instance, "park", None)
if park:
park_id = park.pk if hasattr(park, 'pk') else park
park_id = park.pk if hasattr(park, "pk") else park
patterns.append(f"*park:{park_id}*")
patterns.append(f"*park_{park_id}*")

@@ -376,9 +365,9 @@ class ModerationCacheInvalidation(CacheInvalidationCallback):
def __init__(self, **kwargs):
super().__init__(
patterns=[
'*submission*',
'*moderation*',
'api:moderation*',
"*submission*",
"*moderation*",
"api:moderation*",
],
**kwargs,
)

@@ -53,17 +53,15 @@ class NotificationCallback(PostTransitionCallback):
def should_execute(self, context: TransitionContext) -> bool:
"""Check if notifications are enabled and recipient exists."""
# Check if notifications are disabled in settings
callback_settings = getattr(settings, 'STATE_MACHINE_CALLBACKS', {})
if not callback_settings.get('notifications_enabled', True):
callback_settings = getattr(settings, "STATE_MACHINE_CALLBACKS", {})
if not callback_settings.get("notifications_enabled", True):
logger.debug("Notifications disabled via settings")
return False

# Check if recipient exists
recipient = self._get_recipient(context.instance)
if not recipient:
logger.debug(
f"No recipient found at {self.recipient_field} for {context}"
)
logger.debug(f"No recipient found at {self.recipient_field} for {context}")
return False

return True
@@ -76,6 +74,7 @@ class NotificationCallback(PostTransitionCallback):
"""Get the NotificationService instance."""
try:
from apps.accounts.services.notification_service import NotificationService

return NotificationService()
except ImportError:
logger.warning("NotificationService not available")
@@ -86,18 +85,16 @@ class NotificationCallback(PostTransitionCallback):
extra_data = {}

if self.include_transition_data:
extra_data['transition'] = {
'source_state': context.source_state,
'target_state': context.target_state,
'field_name': context.field_name,
'timestamp': context.timestamp.isoformat(),
extra_data["transition"] = {
"source_state": context.source_state,
"target_state": context.target_state,
"field_name": context.field_name,
"timestamp": context.timestamp.isoformat(),
}

if context.user:
extra_data['transition']['by_user_id'] = context.user.id
extra_data['transition']['by_username'] = getattr(
context.user, 'username', str(context.user)
)
extra_data["transition"]["by_user_id"] = context.user.id
extra_data["transition"]["by_username"] = getattr(context.user, "username", str(context.user))

# Include any extra data from the context
extra_data.update(context.extra_data)
@@ -112,10 +109,7 @@ class NotificationCallback(PostTransitionCallback):
def _get_notification_message(self, context: TransitionContext) -> str:
"""Get the notification message based on context."""
model_name = context.model_name
return (
f"The {model_name} has transitioned from {context.source_state} "
f"to {context.target_state}."
)
return f"The {model_name} has transitioned from {context.source_state} " f"to {context.target_state}."

def execute(self, context: TransitionContext) -> bool:
"""Execute the notification callback."""
@@ -140,16 +134,11 @@ class NotificationCallback(PostTransitionCallback):
extra_data=extra_data,
)

logger.info(
f"Created {self.notification_type} notification for "
f"{recipient} on {context}"
)
logger.info(f"Created {self.notification_type} notification for " f"{recipient} on {context}")
return True

except Exception as e:
logger.exception(
f"Failed to create notification for {context}: {e}"
)
logger.exception(f"Failed to create notification for {context}: {e}")
return False


@@ -176,8 +165,8 @@ class SubmissionApprovedNotification(NotificationCallback):
def _get_submission_type(self, context: TransitionContext) -> str:
"""Get the submission type from context or instance."""
# Try to get from extra_data first
if 'submission_type' in context.extra_data:
return context.extra_data['submission_type']
if "submission_type" in context.extra_data:
return context.extra_data["submission_type"]
# Fall back to model name
return self.submission_type or context.model_name.lower()

@@ -193,10 +182,10 @@ class SubmissionApprovedNotification(NotificationCallback):

try:
submission_type = self._get_submission_type(context)
additional_message = context.extra_data.get('comment', '')
additional_message = context.extra_data.get("comment", "")

# Use the specific method if available
if hasattr(notification_service, 'create_submission_approved_notification'):
if hasattr(notification_service, "create_submission_approved_notification"):
notification_service.create_submission_approved_notification(
user=recipient,
submission_object=context.instance,
@@ -215,15 +204,11 @@ class SubmissionApprovedNotification(NotificationCallback):
extra_data=extra_data,
)

logger.info(
f"Created approval notification for {recipient} on {context}"
)
logger.info(f"Created approval notification for {recipient} on {context}")
return True

except Exception as e:
logger.exception(
f"Failed to create approval notification for {context}: {e}"
)
logger.exception(f"Failed to create approval notification for {context}: {e}")
return False


@@ -250,8 +235,8 @@ class SubmissionRejectedNotification(NotificationCallback):
def _get_submission_type(self, context: TransitionContext) -> str:
"""Get the submission type from context or instance."""
# Try to get from extra_data first
if 'submission_type' in context.extra_data:
return context.extra_data['submission_type']
if "submission_type" in context.extra_data:
return context.extra_data["submission_type"]
# Fall back to model name
return self.submission_type or context.model_name.lower()

@@ -268,11 +253,11 @@ class SubmissionRejectedNotification(NotificationCallback):
try:
submission_type = self._get_submission_type(context)
# Extract rejection reason from extra_data
rejection_reason = context.extra_data.get('reason', 'No reason provided')
additional_message = context.extra_data.get('comment', '')
rejection_reason = context.extra_data.get("reason", "No reason provided")
additional_message = context.extra_data.get("comment", "")

# Use the specific method if available
if hasattr(notification_service, 'create_submission_rejected_notification'):
if hasattr(notification_service, "create_submission_rejected_notification"):
notification_service.create_submission_rejected_notification(
user=recipient,
submission_object=context.instance,
@@ -291,15 +276,11 @@ class SubmissionRejectedNotification(NotificationCallback):
extra_data=extra_data,
)

logger.info(
f"Created rejection notification for {recipient} on {context}"
)
logger.info(f"Created rejection notification for {recipient} on {context}")
return True

except Exception as e:
logger.exception(
f"Failed to create rejection notification for {context}: {e}"
)
logger.exception(f"Failed to create rejection notification for {context}: {e}")
return False


@@ -326,6 +307,7 @@ class SubmissionEscalatedNotification(NotificationCallback):
"""Get admin users to notify."""
try:
from django.contrib.auth import get_user_model

user_model = get_user_model()
return user_model.objects.filter(is_staff=True, is_active=True)
except Exception as e:
@@ -340,9 +322,9 @@ class SubmissionEscalatedNotification(NotificationCallback):

try:
extra_data = self._build_extra_data(context)
escalation_reason = context.extra_data.get('reason', '')
escalation_reason = context.extra_data.get("reason", "")
if escalation_reason:
extra_data['escalation_reason'] = escalation_reason
extra_data["escalation_reason"] = escalation_reason

title = f"{context.model_name} escalated for review"
message = f"A {context.model_name} has been escalated and requires attention."
@@ -361,9 +343,7 @@ class SubmissionEscalatedNotification(NotificationCallback):
related_object=context.instance,
extra_data=extra_data,
)
logger.info(
f"Created escalation notifications for {admins.count()} admins"
)
logger.info(f"Created escalation notifications for {admins.count()} admins")
else:
# Notify the submitter
recipient = self._get_recipient(context.instance)
@@ -376,16 +356,12 @@ class SubmissionEscalatedNotification(NotificationCallback):
related_object=context.instance,
extra_data=extra_data,
)
logger.info(
f"Created escalation notification for {recipient}"
)
logger.info(f"Created escalation notification for {recipient}")

return True

except Exception as e:
logger.exception(
f"Failed to create escalation notification for {context}: {e}"
)
logger.exception(f"Failed to create escalation notification for {context}: {e}")
return False


@@ -415,16 +391,14 @@ class StatusChangeNotification(NotificationCallback):
notification_type="status_change",
**kwargs,
)
self.significant_states = significant_states or [
'CLOSED_PERM', 'DEMOLISHED', 'RELOCATED'
]
self.significant_states = significant_states or ["CLOSED_PERM", "DEMOLISHED", "RELOCATED"]
self.notify_admins = notify_admins

def should_execute(self, context: TransitionContext) -> bool:
"""Only execute for significant state changes."""
# Check if notifications are disabled
callback_settings = getattr(settings, 'STATE_MACHINE_CALLBACKS', {})
if not callback_settings.get('notifications_enabled', True):
callback_settings = getattr(settings, "STATE_MACHINE_CALLBACKS", {})
if not callback_settings.get("notifications_enabled", True):
return False

# Only notify for significant status changes
@@ -441,16 +415,13 @@ class StatusChangeNotification(NotificationCallback):

try:
extra_data = self._build_extra_data(context)
extra_data['entity_type'] = context.model_name
extra_data['entity_id'] = context.instance.pk
extra_data["entity_type"] = context.model_name
extra_data["entity_id"] = context.instance.pk

# Build title and message
entity_name = getattr(context.instance, 'name', str(context.instance))
entity_name = getattr(context.instance, "name", str(context.instance))
title = f"{context.model_name} status changed to {context.target_state}"
message = (
f"{entity_name} has changed status from {context.source_state} "
f"to {context.target_state}."
)
message = f"{entity_name} has changed status from {context.source_state} " f"to {context.target_state}."

# Notify admin users
admins = self._get_admin_users()
@@ -471,15 +442,14 @@ class StatusChangeNotification(NotificationCallback):
return True

except Exception as e:
logger.exception(
f"Failed to create status change notification for {context}: {e}"
)
logger.exception(f"Failed to create status change notification for {context}: {e}")
return False

def _get_admin_users(self):
"""Get admin users to notify."""
try:
from django.contrib.auth import get_user_model

user_model = get_user_model()
return user_model.objects.filter(is_staff=True, is_active=True)
except Exception as e:
@@ -499,13 +469,13 @@ class ModerationNotificationCallback(NotificationCallback):

# Mapping of (model_name, target_state) to notification type
NOTIFICATION_MAPPING = {
('ModerationReport', 'UNDER_REVIEW'): 'report_under_review',
('ModerationReport', 'RESOLVED'): 'report_resolved',
('ModerationQueue', 'IN_PROGRESS'): 'queue_in_progress',
('ModerationQueue', 'COMPLETED'): 'queue_completed',
('BulkOperation', 'RUNNING'): 'bulk_operation_started',
('BulkOperation', 'COMPLETED'): 'bulk_operation_completed',
('BulkOperation', 'FAILED'): 'bulk_operation_failed',
("ModerationReport", "UNDER_REVIEW"): "report_under_review",
("ModerationReport", "RESOLVED"): "report_resolved",
("ModerationQueue", "IN_PROGRESS"): "queue_in_progress",
("ModerationQueue", "COMPLETED"): "queue_completed",
("BulkOperation", "RUNNING"): "bulk_operation_started",
("BulkOperation", "COMPLETED"): "bulk_operation_completed",
("BulkOperation", "FAILED"): "bulk_operation_failed",
}

def __init__(self, **kwargs):
@@ -522,7 +492,7 @@ class ModerationNotificationCallback(NotificationCallback):
def _get_recipient(self, instance: models.Model) -> Any | None:
"""Get the appropriate recipient based on model type."""
# Try common recipient fields
for field in ['reporter', 'assigned_to', 'created_by', 'submitted_by']:
for field in ["reporter", "assigned_to", "created_by", "submitted_by"]:
recipient = getattr(instance, field, None)
if recipient:
return recipient
@@ -531,31 +501,28 @@ class ModerationNotificationCallback(NotificationCallback):
def _get_notification_title(self, context: TransitionContext, notification_type: str) -> str:
"""Get the notification title based on notification type."""
titles = {
'report_under_review': 'Your report is under review',
'report_resolved': 'Your report has been resolved',
'queue_in_progress': 'Moderation queue item in progress',
'queue_completed': 'Moderation queue item completed',
'bulk_operation_started': 'Bulk operation started',
'bulk_operation_completed': 'Bulk operation completed',
'bulk_operation_failed': 'Bulk operation failed',
"report_under_review": "Your report is under review",
"report_resolved": "Your report has been resolved",
"queue_in_progress": "Moderation queue item in progress",
"queue_completed": "Moderation queue item completed",
"bulk_operation_started": "Bulk operation started",
"bulk_operation_completed": "Bulk operation completed",
"bulk_operation_failed": "Bulk operation failed",
}
return titles.get(notification_type, f"{context.model_name} status updated")

def _get_notification_message(self, context: TransitionContext, notification_type: str) -> str:
"""Get the notification message based on notification type."""
messages = {
'report_under_review': 'Your moderation report is now being reviewed by our team.',
'report_resolved': 'Your moderation report has been reviewed and resolved.',
'queue_in_progress': 'A moderation queue item is now being processed.',
'queue_completed': 'A moderation queue item has been completed.',
'bulk_operation_started': 'Your bulk operation has started processing.',
'bulk_operation_completed': 'Your bulk operation has completed successfully.',
'bulk_operation_failed': 'Your bulk operation encountered an error and could not complete.',
"report_under_review": "Your moderation report is now being reviewed by our team.",
"report_resolved": "Your moderation report has been reviewed and resolved.",
"queue_in_progress": "A moderation queue item is now being processed.",
"queue_completed": "A moderation queue item has been completed.",
"bulk_operation_started": "Your bulk operation has started processing.",
"bulk_operation_completed": "Your bulk operation has completed successfully.",
"bulk_operation_failed": "Your bulk operation encountered an error and could not complete.",
}
return messages.get(
notification_type,
f"The {context.model_name} has been updated to {context.target_state}."
)
return messages.get(notification_type, f"The {context.model_name} has been updated to {context.target_state}.")

def execute(self, context: TransitionContext) -> bool:
"""Execute the moderation notification."""
@@ -565,10 +532,7 @@ class ModerationNotificationCallback(NotificationCallback):

notification_type = self._get_notification_type(context)
if not notification_type:
logger.debug(
f"No notification type defined for {context.model_name} "
f"→ {context.target_state}"
)
logger.debug(f"No notification type defined for {context.model_name} " f"→ {context.target_state}")
return True # Not an error, just no notification needed

recipient = self._get_recipient(context.instance)
@@ -587,13 +551,9 @@ class ModerationNotificationCallback(NotificationCallback):
extra_data=extra_data,
)

logger.info(
f"Created {notification_type} notification for {recipient}"
)
logger.info(f"Created {notification_type} notification for {recipient}")
return True

except Exception as e:
logger.exception(
f"Failed to create moderation notification for {context}: {e}"
)
logger.exception(f"Failed to create moderation notification for {context}: {e}")
return False

@@ -45,8 +45,8 @@ class RelatedModelUpdateCallback(PostTransitionCallback):

def should_execute(self, context: TransitionContext) -> bool:
"""Check if related updates are enabled."""
callback_settings = getattr(settings, 'STATE_MACHINE_CALLBACKS', {})
if not callback_settings.get('related_updates_enabled', True):
callback_settings = getattr(settings, "STATE_MACHINE_CALLBACKS", {})
if not callback_settings.get("related_updates_enabled", True):
logger.debug("Related model updates disabled via settings")
return False
return True
@@ -77,9 +77,7 @@ class RelatedModelUpdateCallback(PostTransitionCallback):
return self.perform_update(context)

except Exception as e:
logger.exception(
f"Failed to update related models for {context}: {e}"
)
logger.exception(f"Failed to update related models for {context}: {e}")
return False


@@ -94,10 +92,10 @@ class ParkCountUpdateCallback(RelatedModelUpdateCallback):
name: str = "ParkCountUpdateCallback"

# Status values that count as "active" rides
ACTIVE_STATUSES = {'OPERATING', 'SEASONAL', 'UNDER_CONSTRUCTION'}
ACTIVE_STATUSES = {"OPERATING", "SEASONAL", "UNDER_CONSTRUCTION"}

# Status values that indicate a ride is no longer countable
INACTIVE_STATUSES = {'CLOSED_PERM', 'DEMOLISHED', 'RELOCATED', 'REMOVED'}
INACTIVE_STATUSES = {"CLOSED_PERM", "DEMOLISHED", "RELOCATED", "REMOVED"}

def should_execute(self, context: TransitionContext) -> bool:
"""Only execute when status affects ride counts."""
@@ -115,14 +113,14 @@ class ParkCountUpdateCallback(RelatedModelUpdateCallback):
return source_affects or target_affects

# Category value for roller coasters (from rides domain choices)
COASTER_CATEGORY = 'RC'
COASTER_CATEGORY = "RC"

def perform_update(self, context: TransitionContext) -> bool:
"""Update park ride counts."""
instance = context.instance

# Get the parent park
park = getattr(instance, 'park', None)
park = getattr(instance, "park", None)
if not park:
logger.debug(f"No park found for ride {instance.pk}")
return True
@@ -133,22 +131,17 @@ class ParkCountUpdateCallback(RelatedModelUpdateCallback):
from apps.rides.models.rides import Ride

# Get the park ID (handle both object and ID)
park_id = park.pk if hasattr(park, 'pk') else park
park_id = park.pk if hasattr(park, "pk") else park

# Calculate new counts efficiently
ride_queryset = Ride.objects.filter(park_id=park_id)

# Count active rides
active_statuses = list(self.ACTIVE_STATUSES)
ride_count = ride_queryset.filter(
status__in=active_statuses
).count()
ride_count = ride_queryset.filter(status__in=active_statuses).count()

# Count active coasters (category='RC' for Roller Coaster)
coaster_count = ride_queryset.filter(
status__in=active_statuses,
category=self.COASTER_CATEGORY
).count()
coaster_count = ride_queryset.filter(status__in=active_statuses, category=self.COASTER_CATEGORY).count()

# Update park counts
Park.objects.filter(id=park_id).update(
@@ -156,16 +149,11 @@ class ParkCountUpdateCallback(RelatedModelUpdateCallback):
coaster_count=coaster_count,
)

logger.info(
f"Updated park {park_id} counts: "
f"ride_count={ride_count}, coaster_count={coaster_count}"
)
logger.info(f"Updated park {park_id} counts: " f"ride_count={ride_count}, coaster_count={coaster_count}")
return True
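The two counts above issue two queries; if that ever matters, a single-query variant with conditional aggregation is possible. A sketch reusing names from the surrounding method:

    # Sketch only; Ride, park_id and active_statuses come from the method above.
    from django.db.models import Count, Q

    counts = Ride.objects.filter(park_id=park_id).aggregate(
        ride_count=Count("id", filter=Q(status__in=active_statuses)),
        coaster_count=Count("id", filter=Q(status__in=active_statuses, category="RC")),
    )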
except Exception as e:
logger.exception(
f"Failed to update park counts for {instance.pk}: {e}"
)
logger.exception(f"Failed to update park counts for {instance.pk}: {e}")
return False


@@ -184,20 +172,16 @@ class SearchTextUpdateCallback(RelatedModelUpdateCallback):
instance = context.instance

# Check if instance has search_text field
if not hasattr(instance, 'search_text'):
logger.debug(
f"{context.model_name} has no search_text field"
)
if not hasattr(instance, "search_text"):
logger.debug(f"{context.model_name} has no search_text field")
return True

try:
# Call the model's update_search_text method if available
if hasattr(instance, 'update_search_text'):
if hasattr(instance, "update_search_text"):
instance.update_search_text()
instance.save(update_fields=['search_text'])
logger.info(
f"Updated search_text for {context.model_name} {instance.pk}"
)
instance.save(update_fields=["search_text"])
logger.info(f"Updated search_text for {context.model_name} {instance.pk}")
else:
# Build search text manually
self._build_search_text(instance, context)
@@ -205,9 +189,7 @@ class SearchTextUpdateCallback(RelatedModelUpdateCallback):
return True

except Exception as e:
logger.exception(
f"Failed to update search_text for {instance.pk}: {e}"
)
logger.exception(f"Failed to update search_text for {instance.pk}: {e}")
return False

def _build_search_text(
@@ -219,7 +201,7 @@ class SearchTextUpdateCallback(RelatedModelUpdateCallback):
parts = []

# Common searchable fields
for field in ['name', 'title', 'description', 'location']:
for field in ["name", "title", "description", "location"]:
value = getattr(instance, field, None)
if value:
parts.append(str(value))
@@ -228,15 +210,15 @@ class SearchTextUpdateCallback(RelatedModelUpdateCallback):
status_field = getattr(instance, context.field_name, None)
if status_field:
# Try to get the display label
display_method = f'get_{context.field_name}_display'
display_method = f"get_{context.field_name}_display"
if hasattr(instance, display_method):
parts.append(getattr(instance, display_method)())
else:
parts.append(str(status_field))

# Update search_text
instance.search_text = ' '.join(parts)
instance.save(update_fields=['search_text'])
instance.search_text = " ".join(parts)
instance.save(update_fields=["search_text"])
|
||||
|
||||
|
||||
class ComputedFieldUpdateCallback(RelatedModelUpdateCallback):
|
||||
@@ -280,7 +262,7 @@ class ComputedFieldUpdateCallback(RelatedModelUpdateCallback):
|
||||
# Update specific fields
|
||||
updated_fields = []
|
||||
for field_name in self.computed_fields:
|
||||
update_method_name = f'compute_{field_name}'
|
||||
update_method_name = f"compute_{field_name}"
|
||||
if hasattr(instance, update_method_name):
|
||||
method = getattr(instance, update_method_name)
|
||||
if callable(method):
|
||||
@@ -291,17 +273,12 @@ class ComputedFieldUpdateCallback(RelatedModelUpdateCallback):
|
||||
# Save updated fields
|
||||
if updated_fields:
|
||||
instance.save(update_fields=updated_fields)
|
||||
logger.info(
|
||||
f"Updated computed fields {updated_fields} for "
|
||||
f"{context.model_name} {instance.pk}"
|
||||
)
|
||||
logger.info(f"Updated computed fields {updated_fields} for {context.model_name} {instance.pk}")

return True

except Exception as e:
logger.exception(
f"Failed to update computed fields for {instance.pk}: {e}"
)
logger.exception(f"Failed to update computed fields for {instance.pk}: {e}")
return False


@@ -320,7 +297,7 @@ class RideStatusUpdateCallback(RelatedModelUpdateCallback):
return False

# Only execute for Ride model
return context.model_name == 'Ride'
return context.model_name == "Ride"

def perform_update(self, context: TransitionContext) -> bool:
"""Perform ride-specific status updates."""
@@ -329,22 +306,18 @@ class RideStatusUpdateCallback(RelatedModelUpdateCallback):

try:
# Handle CLOSING → post_closing_status transition
if context.source_state == 'CLOSING' and target != 'CLOSING':
post_closing_status = getattr(instance, 'post_closing_status', None)
if context.source_state == "CLOSING" and target != "CLOSING":
post_closing_status = getattr(instance, "post_closing_status", None)
if post_closing_status and target == post_closing_status:
# Clear post_closing_status after applying it
instance.post_closing_status = None
instance.save(update_fields=['post_closing_status'])
logger.info(
f"Cleared post_closing_status for ride {instance.pk}"
)
instance.save(update_fields=["post_closing_status"])
logger.info(f"Cleared post_closing_status for ride {instance.pk}")

return True

except Exception as e:
logger.exception(
f"Failed to update ride status fields for {instance.pk}: {e}"
)
logger.exception(f"Failed to update ride status fields for {instance.pk}: {e}")
return False


@@ -362,9 +335,7 @@ class ModerationQueueUpdateCallback(RelatedModelUpdateCallback):

# Only for submission and report models
model_name = context.model_name
return model_name in (
'EditSubmission', 'PhotoSubmission', 'ModerationReport'
)
return model_name in ("EditSubmission", "PhotoSubmission", "ModerationReport")

def perform_update(self, context: TransitionContext) -> bool:
"""Update moderation queue entries."""
@@ -373,15 +344,13 @@ class ModerationQueueUpdateCallback(RelatedModelUpdateCallback):

try:
# Mark related queue items as completed when submission is resolved
if target in ('APPROVED', 'REJECTED', 'RESOLVED'):
if target in ("APPROVED", "REJECTED", "RESOLVED"):
self._update_queue_items(instance, context)

return True

except Exception as e:
logger.exception(
f"Failed to update moderation queue for {instance.pk}: {e}"
)
logger.exception(f"Failed to update moderation queue for {instance.pk}: {e}")
return False

def _update_queue_items(
@@ -401,20 +370,18 @@ class ModerationQueueUpdateCallback(RelatedModelUpdateCallback):
queue_items = ModerationQueue.objects.filter(
content_type_id=content_type_id,
object_id=instance.pk,
status='IN_PROGRESS',
status="IN_PROGRESS",
)

for item in queue_items:
if hasattr(item, 'complete'):
if hasattr(item, "complete"):
item.complete(user=context.user)
else:
item.status = 'COMPLETED'
item.save(update_fields=['status'])
item.status = "COMPLETED"
item.save(update_fields=["status"])

if queue_items.exists():
logger.info(
f"Marked {queue_items.count()} queue items as completed"
)
logger.info(f"Marked {queue_items.count()} queue items as completed")

except ImportError:
logger.debug("ModerationQueue model not available")
@@ -425,6 +392,7 @@ class ModerationQueueUpdateCallback(RelatedModelUpdateCallback):
"""Get content type ID for the instance."""
try:
from django.contrib.contenttypes.models import ContentType

content_type = ContentType.objects.get_for_model(type(instance))
return content_type.pk
except Exception:

@@ -33,7 +33,7 @@ class ModelCallbackConfig:
"""Configuration for all callbacks on a model."""

model_name: str
field_name: str = 'status'
field_name: str = "status"
transitions: dict[tuple, TransitionCallbackConfig] = field(default_factory=dict)
default_config: TransitionCallbackConfig = field(default_factory=TransitionCallbackConfig)

@@ -53,12 +53,12 @@ class CallbackConfig:

# Default settings
DEFAULT_SETTINGS = {
'enabled': True,
'notifications_enabled': True,
'cache_invalidation_enabled': True,
'related_updates_enabled': True,
'debug_mode': False,
'log_callbacks': False,
"enabled": True,
"notifications_enabled": True,
"cache_invalidation_enabled": True,
"related_updates_enabled": True,
"debug_mode": False,
"log_callbacks": False,
}

# Model-specific configurations
@@ -70,7 +70,7 @@ class CallbackConfig:

def _load_settings(self) -> dict[str, Any]:
"""Load settings from Django configuration."""
django_settings = getattr(settings, 'STATE_MACHINE_CALLBACKS', {})
django_settings = getattr(settings, "STATE_MACHINE_CALLBACKS", {})
merged = dict(self.DEFAULT_SETTINGS)
merged.update(django_settings)
return merged
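As the merge above implies, any key in DEFAULT_SETTINGS can be overridden per deployment through Django settings. A minimal sketch (the STATE_MACHINE_CALLBACKS setting name comes from the code above; the values shown are illustrative):

# settings.py — override only the keys you need; the rest fall back to DEFAULT_SETTINGS
STATE_MACHINE_CALLBACKS = {
    "debug_mode": True,              # verbose callback logging during development
    "notifications_enabled": False,  # silence notification callbacks in tests
}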
@@ -78,123 +78,123 @@ class CallbackConfig:
def _build_model_configs(self) -> dict[str, ModelCallbackConfig]:
"""Build model-specific configurations."""
return {
'EditSubmission': ModelCallbackConfig(
model_name='EditSubmission',
field_name='status',
"EditSubmission": ModelCallbackConfig(
model_name="EditSubmission",
field_name="status",
transitions={
('PENDING', 'APPROVED'): TransitionCallbackConfig(
notification_template='submission_approved',
cache_patterns=['*submission*', '*moderation*'],
("PENDING", "APPROVED"): TransitionCallbackConfig(
notification_template="submission_approved",
cache_patterns=["*submission*", "*moderation*"],
),
('PENDING', 'REJECTED'): TransitionCallbackConfig(
notification_template='submission_rejected',
cache_patterns=['*submission*', '*moderation*'],
("PENDING", "REJECTED"): TransitionCallbackConfig(
notification_template="submission_rejected",
cache_patterns=["*submission*", "*moderation*"],
),
('PENDING', 'ESCALATED'): TransitionCallbackConfig(
notification_template='submission_escalated',
cache_patterns=['*submission*', '*moderation*'],
("PENDING", "ESCALATED"): TransitionCallbackConfig(
notification_template="submission_escalated",
cache_patterns=["*submission*", "*moderation*"],
),
},
),
'PhotoSubmission': ModelCallbackConfig(
model_name='PhotoSubmission',
field_name='status',
"PhotoSubmission": ModelCallbackConfig(
model_name="PhotoSubmission",
field_name="status",
transitions={
('PENDING', 'APPROVED'): TransitionCallbackConfig(
notification_template='photo_approved',
cache_patterns=['*photo*', '*moderation*'],
("PENDING", "APPROVED"): TransitionCallbackConfig(
notification_template="photo_approved",
cache_patterns=["*photo*", "*moderation*"],
),
('PENDING', 'REJECTED'): TransitionCallbackConfig(
notification_template='photo_rejected',
cache_patterns=['*photo*', '*moderation*'],
("PENDING", "REJECTED"): TransitionCallbackConfig(
notification_template="photo_rejected",
cache_patterns=["*photo*", "*moderation*"],
),
},
),
'ModerationReport': ModelCallbackConfig(
model_name='ModerationReport',
field_name='status',
"ModerationReport": ModelCallbackConfig(
model_name="ModerationReport",
field_name="status",
transitions={
('PENDING', 'UNDER_REVIEW'): TransitionCallbackConfig(
notification_template='report_under_review',
cache_patterns=['*report*', '*moderation*'],
("PENDING", "UNDER_REVIEW"): TransitionCallbackConfig(
notification_template="report_under_review",
cache_patterns=["*report*", "*moderation*"],
),
('UNDER_REVIEW', 'RESOLVED'): TransitionCallbackConfig(
notification_template='report_resolved',
cache_patterns=['*report*', '*moderation*'],
("UNDER_REVIEW", "RESOLVED"): TransitionCallbackConfig(
notification_template="report_resolved",
cache_patterns=["*report*", "*moderation*"],
),
},
),
'ModerationQueue': ModelCallbackConfig(
model_name='ModerationQueue',
field_name='status',
"ModerationQueue": ModelCallbackConfig(
model_name="ModerationQueue",
field_name="status",
transitions={
('PENDING', 'IN_PROGRESS'): TransitionCallbackConfig(
notification_template='queue_in_progress',
cache_patterns=['*queue*', '*moderation*'],
("PENDING", "IN_PROGRESS"): TransitionCallbackConfig(
notification_template="queue_in_progress",
cache_patterns=["*queue*", "*moderation*"],
),
('IN_PROGRESS', 'COMPLETED'): TransitionCallbackConfig(
notification_template='queue_completed',
cache_patterns=['*queue*', '*moderation*'],
("IN_PROGRESS", "COMPLETED"): TransitionCallbackConfig(
notification_template="queue_completed",
cache_patterns=["*queue*", "*moderation*"],
),
},
),
'BulkOperation': ModelCallbackConfig(
model_name='BulkOperation',
field_name='status',
"BulkOperation": ModelCallbackConfig(
model_name="BulkOperation",
field_name="status",
transitions={
('PENDING', 'RUNNING'): TransitionCallbackConfig(
notification_template='bulk_operation_started',
cache_patterns=['*operation*', '*moderation*'],
("PENDING", "RUNNING"): TransitionCallbackConfig(
notification_template="bulk_operation_started",
cache_patterns=["*operation*", "*moderation*"],
),
('RUNNING', 'COMPLETED'): TransitionCallbackConfig(
notification_template='bulk_operation_completed',
cache_patterns=['*operation*', '*moderation*'],
("RUNNING", "COMPLETED"): TransitionCallbackConfig(
notification_template="bulk_operation_completed",
cache_patterns=["*operation*", "*moderation*"],
),
('RUNNING', 'FAILED'): TransitionCallbackConfig(
notification_template='bulk_operation_failed',
cache_patterns=['*operation*', '*moderation*'],
("RUNNING", "FAILED"): TransitionCallbackConfig(
notification_template="bulk_operation_failed",
cache_patterns=["*operation*", "*moderation*"],
),
},
),
'Park': ModelCallbackConfig(
model_name='Park',
field_name='status',
"Park": ModelCallbackConfig(
model_name="Park",
field_name="status",
default_config=TransitionCallbackConfig(
cache_patterns=['*park*', 'api:*', 'geo:*'],
cache_patterns=["*park*", "api:*", "geo:*"],
),
transitions={
('*', 'CLOSED_PERM'): TransitionCallbackConfig(
("*", "CLOSED_PERM"): TransitionCallbackConfig(
notifications_enabled=True,
notification_template='park_closed_permanently',
cache_patterns=['*park*', 'api:*', 'geo:*'],
notification_template="park_closed_permanently",
cache_patterns=["*park*", "api:*", "geo:*"],
),
('*', 'OPERATING'): TransitionCallbackConfig(
("*", "OPERATING"): TransitionCallbackConfig(
notifications_enabled=False,
cache_patterns=['*park*', 'api:*', 'geo:*'],
cache_patterns=["*park*", "api:*", "geo:*"],
),
},
),
'Ride': ModelCallbackConfig(
model_name='Ride',
field_name='status',
"Ride": ModelCallbackConfig(
model_name="Ride",
field_name="status",
default_config=TransitionCallbackConfig(
cache_patterns=['*ride*', '*park*', 'api:*', 'geo:*'],
cache_patterns=["*ride*", "*park*", "api:*", "geo:*"],
),
transitions={
('*', 'OPERATING'): TransitionCallbackConfig(
cache_patterns=['*ride*', '*park*', 'api:*', 'geo:*'],
("*", "OPERATING"): TransitionCallbackConfig(
cache_patterns=["*ride*", "*park*", "api:*", "geo:*"],
related_updates_enabled=True,
),
('*', 'CLOSED_PERM'): TransitionCallbackConfig(
cache_patterns=['*ride*', '*park*', 'api:*', 'geo:*'],
("*", "CLOSED_PERM"): TransitionCallbackConfig(
cache_patterns=["*ride*", "*park*", "api:*", "geo:*"],
related_updates_enabled=True,
),
('*', 'DEMOLISHED'): TransitionCallbackConfig(
cache_patterns=['*ride*', '*park*', 'api:*', 'geo:*'],
("*", "DEMOLISHED"): TransitionCallbackConfig(
cache_patterns=["*ride*", "*park*", "api:*", "geo:*"],
related_updates_enabled=True,
),
('*', 'RELOCATED'): TransitionCallbackConfig(
cache_patterns=['*ride*', '*park*', 'api:*', 'geo:*'],
("*", "RELOCATED"): TransitionCallbackConfig(
cache_patterns=["*ride*", "*park*", "api:*", "geo:*"],
related_updates_enabled=True,
),
},
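Each entry pairs a (source, target) tuple with a TransitionCallbackConfig, so wiring up a new model follows the same shape. A hypothetical sketch (the model name and template name here are invented for illustration):

"Manufacturer": ModelCallbackConfig(
    model_name="Manufacturer",
    field_name="status",
    transitions={
        ("*", "DEFUNCT"): TransitionCallbackConfig(
            notification_template="manufacturer_defunct",  # assumed template name
            cache_patterns=["*manufacturer*", "api:*"],
        ),
    },
),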
@@ -204,32 +204,32 @@ class CallbackConfig:
@property
def enabled(self) -> bool:
"""Check if callbacks are globally enabled."""
return self._settings.get('enabled', True)
return self._settings.get("enabled", True)

@property
def notifications_enabled(self) -> bool:
"""Check if notification callbacks are enabled."""
return self._settings.get('notifications_enabled', True)
return self._settings.get("notifications_enabled", True)

@property
def cache_invalidation_enabled(self) -> bool:
"""Check if cache invalidation is enabled."""
return self._settings.get('cache_invalidation_enabled', True)
return self._settings.get("cache_invalidation_enabled", True)

@property
def related_updates_enabled(self) -> bool:
"""Check if related model updates are enabled."""
return self._settings.get('related_updates_enabled', True)
return self._settings.get("related_updates_enabled", True)

@property
def debug_mode(self) -> bool:
"""Check if debug mode is enabled."""
return self._settings.get('debug_mode', False)
return self._settings.get("debug_mode", False)

@property
def log_callbacks(self) -> bool:
"""Check if callback logging is enabled."""
return self._settings.get('log_callbacks', False)
return self._settings.get("log_callbacks", False)

def get_config(
self,
@@ -258,12 +258,12 @@ class CallbackConfig:
return config

# Try wildcard source
config = model_config.transitions.get(('*', target))
config = model_config.transitions.get(("*", target))
if config:
return config

# Try wildcard target
config = model_config.transitions.get((source, '*'))
config = model_config.transitions.get((source, "*"))
if config:
return config

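The lookup order above means an exact (source, target) match always wins over a wildcard, with a wildcard source tried before a wildcard target. Condensed into a single equivalent expression (a sketch, not the committed code):

config = (
    transitions.get((source, target))
    or transitions.get(("*", target))
    or transitions.get((source, "*"))
)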
@@ -362,9 +362,7 @@ class CallbackConfig:
**kwargs: Configuration values to update.
"""
if model_name not in self._model_configs:
self._model_configs[model_name] = ModelCallbackConfig(
model_name=model_name
)
self._model_configs[model_name] = ModelCallbackConfig(model_name=model_name)

model_config = self._model_configs[model_name]
transition_key = (source, target)
@@ -394,9 +392,9 @@ def get_callback_config() -> CallbackConfig:


__all__ = [
'TransitionCallbackConfig',
'ModelCallbackConfig',
'CallbackConfig',
'callback_config',
'get_callback_config',
"TransitionCallbackConfig",
"ModelCallbackConfig",
"CallbackConfig",
"callback_config",
"get_callback_config",
]

@@ -1,4 +1,5 @@
"""Transition decorator generation for django-fsm integration."""

import logging
from collections.abc import Callable
from functools import wraps
@@ -51,51 +52,42 @@ def with_callbacks(
@wraps(func)
def wrapper(instance, *args, **kwargs):
# Extract user from kwargs
user = kwargs.get('user')
user = kwargs.get("user")

# Get source state before transition
source_state = getattr(instance, field_name, None)

# Get target state from the transition decorator
# The @transition decorator sets _django_fsm_target
target_state = getattr(func, '_django_fsm', {}).get('target', None)
target_state = getattr(func, "_django_fsm", {}).get("target", None)

# If we can't determine the target from decorator metadata,
# we'll capture it after the transition
if target_state is None:
# This happens when decorators are applied in the wrong order
logger.debug(
f"Could not determine target state from decorator for {func.__name__}"
)
logger.debug(f"Could not determine target state from decorator for {func.__name__}")

# Create transition context
context = TransitionContext(
instance=instance,
field_name=field_name,
source_state=str(source_state) if source_state else '',
target_state=str(target_state) if target_state else '',
source_state=str(source_state) if source_state else "",
target_state=str(target_state) if target_state else "",
user=user,
extra_data=dict(kwargs),
)

# Execute pre-transition callbacks
pre_success, pre_failures = callback_registry.execute_callbacks(
context, CallbackStage.PRE
)
pre_success, pre_failures = callback_registry.execute_callbacks(context, CallbackStage.PRE)

# If pre-callbacks fail with continue_on_error=False, abort
if not pre_success and pre_failures:
for callback, exc in pre_failures:
if not callback.continue_on_error:
logger.error(
f"Pre-transition callback {callback.name} failed, "
f"aborting transition"
)
logger.error(f"Pre-transition callback {callback.name} failed, aborting transition")
if exc:
raise exc
raise RuntimeError(
f"Pre-transition callback {callback.name} failed"
)
raise RuntimeError(f"Pre-transition callback {callback.name} failed")

# Emit pre-transition signal
if emit_signals:
@@ -114,19 +106,14 @@ def with_callbacks(

# Update context with actual target state after transition
actual_target = getattr(instance, field_name, None)
context.target_state = str(actual_target) if actual_target else ''
context.target_state = str(actual_target) if actual_target else ""

# Execute post-transition callbacks
post_success, post_failures = callback_registry.execute_callbacks(
context, CallbackStage.POST
)
post_success, post_failures = callback_registry.execute_callbacks(context, CallbackStage.POST)

if not post_success:
for callback, exc in post_failures:
logger.warning(
f"Post-transition callback {callback.name} failed "
f"for {context}"
)
for callback, _exc in post_failures:
logger.warning(f"Post-transition callback {callback.name} failed for {context}")

# Emit post-transition signal
if emit_signals:
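For this wrapper to see FSM metadata at all, @with_callbacks has to sit outside the @transition decorator, as the debug message above hints. A minimal usage sketch (the decorator ordering and the field_name keyword are assumptions; the model and states are illustrative):

class EditSubmission(models.Model):
    status = FSMField(default="PENDING")

    @with_callbacks(field_name="status")   # outermost: wraps the FSM-decorated method
    @transition(field=status, source="PENDING", target="APPROVED")
    def approve(self, user=None):
        """Pre/post callbacks and signals fire around this body."""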
@@ -236,9 +223,7 @@ def create_transition_method(
on_success(instance, user=user, **kwargs)

transition_method.__name__ = method_name
transition_method.__doc__ = (
f"Transition from {source} to {target} on field {field_name}"
)
transition_method.__doc__ = f"Transition from {source} to {target} on field {field_name}"

# Apply callback wrapper if enabled
if enable_callbacks:
@@ -249,10 +234,10 @@ def create_transition_method(

# Store metadata for callback registration
transition_method._fsm_metadata = {
'source': source,
'target': target,
'field_name': field_name,
'callbacks': callbacks or [],
"source": source,
"target": target,
"field_name": field_name,
"callbacks": callbacks or [],
}

return transition_method
@@ -271,21 +256,21 @@ def register_method_callbacks(
model_class: The model class containing the method.
method: The transition method with _fsm_metadata.
"""
metadata = getattr(method, '_fsm_metadata', None)
if not metadata or not metadata.get('callbacks'):
metadata = getattr(method, "_fsm_metadata", None)
if not metadata or not metadata.get("callbacks"):
return

from .callback_base import CallbackStage, PreTransitionCallback

for callback in metadata['callbacks']:
for callback in metadata["callbacks"]:
# Determine stage from callback type
stage = CallbackStage.PRE if isinstance(callback, PreTransitionCallback) else CallbackStage.POST

callback_registry.register(
model_class=model_class,
field_name=metadata['field_name'],
source=metadata['source'],
target=metadata['target'],
field_name=metadata["field_name"],
source=metadata["source"],
target=metadata["target"],
callback=callback,
stage=stage,
)
@@ -490,9 +475,7 @@ class TransitionMethodFactory:
if docstring:
generic_transition.__doc__ = docstring
else:
generic_transition.__doc__ = (
f"Transition from {source} to {target}"
)
generic_transition.__doc__ = f"Transition from {source} to {target}"

# Apply callback wrapper if enabled
if enable_callbacks:

@@ -12,6 +12,7 @@ Example usage:
'code': e.error_code
}, status=403)
"""

from typing import Any

from django_fsm import TransitionNotAllowed
@@ -214,29 +215,18 @@ ERROR_MESSAGES = {
"You need {required_role} permissions to {action}. "
"Please contact an administrator if you believe this is an error."
),
"PERMISSION_DENIED_OWNERSHIP": (
"You must be the owner of this item to perform this action."
),
"PERMISSION_DENIED_OWNERSHIP": ("You must be the owner of this item to perform this action."),
"PERMISSION_DENIED_ASSIGNMENT": (
"This item must be assigned to you before you can {action}. "
"Please assign it to yourself first."
),
"NO_ASSIGNMENT": (
"This item must be assigned before this action can be performed."
"This item must be assigned to you before you can {action}. Please assign it to yourself first."
),
"NO_ASSIGNMENT": ("This item must be assigned before this action can be performed."),
"INVALID_STATE_TRANSITION": (
"This action cannot be performed from the current state. "
"The item is currently '{current_state}' and cannot be modified."
),
"TRANSITION_NOT_AVAILABLE": (
"This {item_type} has already been {state} and cannot be modified."
),
"MISSING_REQUIRED_FIELD": (
"{field_name} is required to complete this action."
),
"EMPTY_REQUIRED_FIELD": (
"{field_name} must not be empty."
),
"TRANSITION_NOT_AVAILABLE": ("This {item_type} has already been {state} and cannot be modified."),
"MISSING_REQUIRED_FIELD": ("{field_name} is required to complete this action."),
"EMPTY_REQUIRED_FIELD": ("{field_name} must not be empty."),
"ESCALATED_REQUIRES_ADMIN": (
"This submission has been escalated and requires admin review. "
"Only administrators can approve or reject escalated items."

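These templates are plain str.format strings, so callers fill the placeholders at raise time. For example, using a key defined above:

message = ERROR_MESSAGES["INVALID_STATE_TRANSITION"].format(current_state="APPROVED")
# "This action cannot be performed from the current state. The item is currently 'APPROVED' and cannot be modified."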
@@ -47,6 +47,7 @@ See Also:
- apps.core.choices.registry: The central choice registry
- apps.core.state_machine.mixins.StateMachineMixin: Convenience helpers
"""

from typing import Any

from django.core.exceptions import ValidationError
@@ -138,14 +139,10 @@ class RichFSMField(DjangoFSMField):

choice = registry.get_choice(self.choice_group, value, self.domain)
if choice is None:
raise ValidationError(
f"'{value}' is not a valid state for {self.choice_group}"
)
raise ValidationError(f"'{value}' is not a valid state for {self.choice_group}")

if choice.deprecated and not self.allow_deprecated:
raise ValidationError(
f"'{value}' is deprecated and cannot be used for new entries"
)
raise ValidationError(f"'{value}' is deprecated and cannot be used for new entries")

def get_rich_choice(self, value: str) -> RichChoice | None:
"""Return the RichChoice object for a given state value."""
@@ -155,9 +152,7 @@ class RichFSMField(DjangoFSMField):
"""Return the label for the given state value."""
return registry.get_choice_display(self.choice_group, value, self.domain)

def contribute_to_class(
self, cls: Any, name: str, private_only: bool = False, **kwargs: Any
) -> None:
def contribute_to_class(self, cls: Any, name: str, private_only: bool = False, **kwargs: Any) -> None:
"""Attach helpers to the model for convenience."""
super().contribute_to_class(cls, name, private_only=private_only, **kwargs)


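In practice a model declares the field once and gets validation against the choice registry for free. A hypothetical declaration (the choice_group/domain keywords are inferred from the attributes used above, and the group name is invented):

class Ride(models.Model):
    status = RichFSMField(choice_group="ride_status", domain="parks", default="OPERATING")
    # status validates against the registry; deprecated values are rejected for new entries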
@@ -18,6 +18,7 @@ Example usage:
OwnershipGuard()
], operator='OR')
"""

from collections.abc import Callable
from typing import Any, Optional

@@ -534,7 +535,7 @@ class MetadataGuard:
self._last_error_code = self.ERROR_CODE_EMPTY_FIELD
self._failed_field = field_name
return False
if isinstance(value, (list, dict)) and not value:
if isinstance(value, list | dict) and not value:
self._last_error_code = self.ERROR_CODE_EMPTY_FIELD
self._failed_field = field_name
return False
@@ -787,8 +788,7 @@ def extract_guards_from_metadata(metadata: dict[str, Any]) -> list[Callable]:
# Zero tolerance requires superuser
if zero_tolerance:
guard = PermissionGuard(
requires_superuser=True,
error_message="Zero tolerance violations require superuser permissions"
requires_superuser=True, error_message="Zero tolerance violations require superuser permissions"
)
guards.append(guard)
elif requires_moderator or requires_admin or escalation_level:
@@ -801,7 +801,7 @@ def extract_guards_from_metadata(metadata: dict[str, Any]) -> list[Callable]:
assignment_guard = AssignmentGuard(
require_assignment=True,
allow_admin_override=True,
error_message="This item must be assigned to you before this action can be performed"
error_message="This item must be assigned to you before this action can be performed",
)
guards.append(assignment_guard)

@@ -814,7 +814,7 @@ def extract_guards_from_metadata(metadata: dict[str, Any]) -> list[Callable]:

perm_guard = PermissionGuard(
custom_check=check_permissions,
error_message=f"Missing required permissions: {', '.join(required_permissions)}"
error_message=f"Missing required permissions: {', '.join(required_permissions)}",
)
guards.append(perm_guard)

@@ -1072,7 +1072,7 @@ def has_role(user: Any, required_roles: list[str]) -> bool:
# Only apply if role field is not set
if user_role is None:
# Check for superuser (Django's is_superuser flag)
if hasattr(user, "is_superuser") and user.is_superuser:
if hasattr(user, "is_superuser") and user.is_superuser:  # noqa: SIM102
if "SUPERUSER" in required_roles or "ADMIN" in required_roles:
return True

@@ -1248,7 +1248,7 @@ def create_guard_from_drf_permission(
self._last_error_code = "PERMISSION_DENIED"
return False

if hasattr(permission, "has_object_permission"):
if hasattr(permission, "has_object_permission"):  # noqa: SIM102
if not permission.has_object_permission(mock_request, None, instance):
self._last_error_code = "OBJECT_PERMISSION_DENIED"
return False

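The guard classes in this file compose the way the module docstring's operator='OR' example suggests. A hedged sketch built only from the constructors shown in this hunk:

# Allow the action when the item is assigned to the actor, with an admin escape hatch
guard = AssignmentGuard(
    require_assignment=True,
    allow_admin_override=True,
    error_message="This item must be assigned to you before this action can be performed",
)

# Zero-tolerance paths demand a superuser, mirroring extract_guards_from_metadata above
strict = PermissionGuard(
    requires_superuser=True,
    error_message="Zero tolerance violations require superuser permissions",
)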
@@ -1,4 +1,5 @@
"""Model integration utilities for applying state machines to Django models."""

from collections.abc import Callable
from typing import Any

@@ -46,18 +47,13 @@ def apply_state_machine(

if not result.is_valid:
error_messages = [str(e) for e in result.errors]
raise ValueError(
"Cannot apply state machine - validation failed:\n"
+ "\n".join(error_messages)
)
raise ValueError("Cannot apply state machine - validation failed:\n" + "\n".join(error_messages))

# Build transition registry
registry_instance.build_registry_from_choices(choice_group, domain)

# Generate and attach transition methods
generate_transition_methods_for_model(
model_class, field_name, choice_group, domain
)
generate_transition_methods_for_model(model_class, field_name, choice_group, domain)


def generate_transition_methods_for_model(
@@ -140,15 +136,10 @@ def generate_transition_methods_for_model(
setattr(model_class, method_name, method)




class StateMachineModelMixin:
"""Mixin providing state machine helper methods for models."""

def get_available_state_transitions(
self, field_name: str = "status"
) -> list[TransitionInfo]:
def get_available_state_transitions(self, field_name: str = "status") -> list[TransitionInfo]:
"""
Get available transitions from current state.

@@ -167,9 +158,7 @@ class StateMachineModelMixin:
domain = field.domain
current_state = getattr(self, field_name)

return registry_instance.get_available_transitions(
choice_group, domain, current_state
)
return registry_instance.get_available_transitions(choice_group, domain, current_state)

def can_transition_to(
self,
@@ -199,9 +188,7 @@ class StateMachineModelMixin:
domain = field.domain

# Check if transition exists in registry
transition = registry_instance.get_transition(
choice_group, domain, current_state, target_state
)
transition = registry_instance.get_transition(choice_group, domain, current_state, target_state)

if not transition:
return False
@@ -216,9 +203,7 @@ class StateMachineModelMixin:
# Use django-fsm's can_proceed
return can_proceed(method)

def get_transition_method(
self, target_state: str, field_name: str = "status"
) -> Callable | None:
def get_transition_method(self, target_state: str, field_name: str = "status") -> Callable | None:
"""
Get the transition method for moving to target state.

@@ -238,9 +223,7 @@ class StateMachineModelMixin:
choice_group = field.choice_group
domain = field.domain

transition = registry_instance.get_transition(
choice_group, domain, current_state, target_state
)
transition = registry_instance.get_transition(choice_group, domain, current_state, target_state)

if not transition:
return None
@@ -270,9 +253,7 @@ class StateMachineModelMixin:
ValueError: If transition is not allowed
"""
if not self.can_transition_to(target_state, field_name, user):
raise ValueError(
f"Cannot transition to {target_state} from current state"
)
raise ValueError(f"Cannot transition to {target_state} from current state")

method = self.get_transition_method(target_state, field_name)
if method is None:
@@ -283,9 +264,7 @@ class StateMachineModelMixin:
return True


def state_machine_model(
field_name: str, choice_group: str, domain: str = "core"
):
def state_machine_model(field_name: str, choice_group: str, domain: str = "core"):
"""
Class decorator to automatically apply state machine to models.

@@ -306,9 +285,7 @@ def state_machine_model(
return decorator


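Taken together, the decorator and mixin give models a uniform transition surface. A hedged usage sketch (the model, states, and choice group are illustrative; the RichFSMField keywords are inferred):

@state_machine_model(field_name="status", choice_group="park_status", domain="parks")
class Park(StateMachineModelMixin, models.Model):
    status = RichFSMField(choice_group="park_status", domain="parks")

park.can_transition_to("CLOSED_PERM", "status", user)   # registry- and guard-aware check
park.get_available_state_transitions()                  # TransitionInfo list, e.g. for the UI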
def validate_model_state_machine(
model_class: type[models.Model], field_name: str
) -> bool:
def validate_model_state_machine(model_class: type[models.Model], field_name: str) -> bool:
"""
Ensure model is properly configured with state machine.

@@ -326,13 +303,11 @@ def validate_model_state_machine(
try:
field = model_class._meta.get_field(field_name)
except Exception:
raise ValueError(f"Field {field_name} not found on {model_class}")
raise ValueError(f"Field {field_name} not found on {model_class}") from None

# Check if field has choice_group attribute
if not hasattr(field, "choice_group"):
raise ValueError(
f"Field {field_name} is not a RichFSMField or RichChoiceField"
)
raise ValueError(f"Field {field_name} is not a RichFSMField or RichChoiceField")

# Validate metadata
choice_group = field.choice_group
@@ -343,9 +318,7 @@ def validate_model_state_machine(

if not result.is_valid:
error_messages = [str(e) for e in result.errors]
raise ValueError(
"State machine validation failed:\n" + "\n".join(error_messages)
)
raise ValueError("State machine validation failed:\n" + "\n".join(error_messages))

return True


@@ -38,6 +38,7 @@ See Also:
- apps.core.state_machine.fields.RichFSMField: The FSM field implementation
- django_fsm.can_proceed: FSM transition checking utility
"""

from collections.abc import Iterable
from typing import Any

@@ -47,25 +48,75 @@ from django_fsm import can_proceed
# Default transition metadata for styling
TRANSITION_METADATA = {
# Approval transitions
"approve": {"style": "green", "icon": "check", "requires_confirm": True, "confirm_message": "Are you sure you want to approve this?"},
"transition_to_approved": {"style": "green", "icon": "check", "requires_confirm": True, "confirm_message": "Are you sure you want to approve this?"},
"approve": {
"style": "green",
"icon": "check",
"requires_confirm": True,
"confirm_message": "Are you sure you want to approve this?",
},
"transition_to_approved": {
"style": "green",
"icon": "check",
"requires_confirm": True,
"confirm_message": "Are you sure you want to approve this?",
},
# Rejection transitions
"reject": {"style": "red", "icon": "times", "requires_confirm": True, "confirm_message": "Are you sure you want to reject this?"},
"transition_to_rejected": {"style": "red", "icon": "times", "requires_confirm": True, "confirm_message": "Are you sure you want to reject this?"},
"reject": {
"style": "red",
"icon": "times",
"requires_confirm": True,
"confirm_message": "Are you sure you want to reject this?",
},
"transition_to_rejected": {
"style": "red",
"icon": "times",
"requires_confirm": True,
"confirm_message": "Are you sure you want to reject this?",
},
# Escalation transitions
"escalate": {"style": "yellow", "icon": "arrow-up", "requires_confirm": True, "confirm_message": "Are you sure you want to escalate this?"},
"transition_to_escalated": {"style": "yellow", "icon": "arrow-up", "requires_confirm": True, "confirm_message": "Are you sure you want to escalate this?"},
"escalate": {
"style": "yellow",
"icon": "arrow-up",
"requires_confirm": True,
"confirm_message": "Are you sure you want to escalate this?",
},
"transition_to_escalated": {
"style": "yellow",
"icon": "arrow-up",
"requires_confirm": True,
"confirm_message": "Are you sure you want to escalate this?",
},
# Assignment transitions
"assign": {"style": "blue", "icon": "user-plus", "requires_confirm": False},
"unassign": {"style": "gray", "icon": "user-minus", "requires_confirm": False},
# Status transitions
"start": {"style": "blue", "icon": "play", "requires_confirm": False},
"complete": {"style": "green", "icon": "check-circle", "requires_confirm": True, "confirm_message": "Are you sure you want to complete this?"},
"cancel": {"style": "red", "icon": "ban", "requires_confirm": True, "confirm_message": "Are you sure you want to cancel this?"},
"complete": {
"style": "green",
"icon": "check-circle",
"requires_confirm": True,
"confirm_message": "Are you sure you want to complete this?",
},
"cancel": {
"style": "red",
"icon": "ban",
"requires_confirm": True,
"confirm_message": "Are you sure you want to cancel this?",
},
"reopen": {"style": "blue", "icon": "redo", "requires_confirm": False},
# Resolution transitions
"resolve": {"style": "green", "icon": "check-double", "requires_confirm": True, "confirm_message": "Are you sure you want to resolve this?"},
"dismiss": {"style": "gray", "icon": "times-circle", "requires_confirm": True, "confirm_message": "Are you sure you want to dismiss this?"},
"resolve": {
"style": "green",
"icon": "check-double",
"requires_confirm": True,
"confirm_message": "Are you sure you want to resolve this?",
},
"dismiss": {
"style": "gray",
"icon": "times-circle",
"requires_confirm": True,
"confirm_message": "Are you sure you want to dismiss this?",
},
# Default
"default": {"style": "gray", "icon": "arrow-right", "requires_confirm": False},
}
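_get_transition_metadata presumably falls back to the "default" entry when a method name has no dedicated styling, so unknown transitions still render with a neutral gray arrow. A one-line sketch of that lookup (assumed behavior, based only on the dict above):

metadata = TRANSITION_METADATA.get(transition_name, TRANSITION_METADATA["default"])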
@@ -86,22 +137,22 @@ def _get_transition_metadata(transition_name: str) -> dict[str, Any]:
def _format_transition_label(transition_name: str) -> str:
"""Format a transition method name into a human-readable label."""
label = transition_name
for prefix in ['transition_to_', 'transition_', 'do_']:
for prefix in ["transition_to_", "transition_", "do_"]:
if label.startswith(prefix):
label = label[len(prefix):]
label = label[len(prefix) :]
break

if label.endswith('ed') and len(label) > 3:
if label.endswith('ied'):
label = label[:-3] + 'y'
if label.endswith("ed") and len(label) > 3:
if label.endswith("ied"):
label = label[:-3] + "y"
elif label[-3] == label[-4]:
label = label[:-3]
else:
label = label[:-1]
if not label.endswith('e'):
if not label.endswith("e"):
label = label[:-1]

return label.replace('_', ' ').title()
return label.replace("_", " ").title()
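A few worked examples of that heuristic, traced by hand through the code above: "transition_to_approved" strips its prefix to "approved", drops the past-tense suffix to "approve", and titles to "Approve"; "transition_to_escalated" becomes "Escalate" by the same path; "reopen" has no "ed" suffix and stays "Reopen".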
class StateMachineMixin(models.Model):
@@ -187,14 +238,10 @@ class StateMachineMixin(models.Model):
"""Check if a transition method can proceed for the current instance."""
method = getattr(self, transition_method_name, None)
if method is None or not callable(method):
raise AttributeError(
f"Transition method '{transition_method_name}' not found"
)
raise AttributeError(f"Transition method '{transition_method_name}' not found")
return can_proceed(method)

def get_available_transitions(
self, field_name: str | None = None
) -> Iterable[Any]:
def get_available_transitions(self, field_name: str | None = None) -> Iterable[Any]:
"""Return available transitions when helpers are present."""
name = field_name or self.state_field_name
helper_name = f"get_available_{name}_transitions"
@@ -246,14 +293,16 @@ class StateMachineMixin(models.Model):
try:
if can_proceed(method, user):
metadata = _get_transition_metadata(transition_name)
transitions.append({
'name': transition_name,
'label': _format_transition_label(transition_name),
'icon': metadata.get('icon', 'arrow-right'),
'style': metadata.get('style', 'gray'),
'requires_confirm': metadata.get('requires_confirm', False),
'confirm_message': metadata.get('confirm_message', 'Are you sure?'),
})
transitions.append(
{
"name": transition_name,
"label": _format_transition_label(transition_name),
"icon": metadata.get("icon", "arrow-right"),
"style": metadata.get("style", "gray"),
"requires_confirm": metadata.get("requires_confirm", False),
"confirm_message": metadata.get("confirm_message", "Are you sure?"),
}
)
except Exception:
# Skip transitions that raise errors
pass

@@ -47,7 +47,7 @@ class CallbackStats:
successful_executions: int = 0
failed_executions: int = 0
total_duration_ms: float = 0.0
min_duration_ms: float = float('inf')
min_duration_ms: float = float("inf")
max_duration_ms: float = 0.0
last_execution: datetime | None = None
last_error: str | None = None
@@ -97,10 +97,10 @@ class CallbackMonitor:
- Performance statistics
"""

_instance: Optional['CallbackMonitor'] = None
_instance: Optional["CallbackMonitor"] = None
_lock = threading.Lock()

def __new__(cls) -> 'CallbackMonitor':
def __new__(cls) -> "CallbackMonitor":
if cls._instance is None:
with cls._lock:
if cls._instance is None:
@@ -112,9 +112,7 @@ class CallbackMonitor:
if self._initialized:
return

self._stats: dict[str, CallbackStats] = defaultdict(
lambda: CallbackStats(callback_name="")
)
self._stats: dict[str, CallbackStats] = defaultdict(lambda: CallbackStats(callback_name=""))
self._recent_executions: list[CallbackExecutionRecord] = []
self._max_recent_records = 1000
self._enabled = self._check_enabled()
@@ -123,13 +121,13 @@ class CallbackMonitor:

def _check_enabled(self) -> bool:
"""Check if monitoring is enabled."""
callback_settings = getattr(settings, 'STATE_MACHINE_CALLBACKS', {})
return callback_settings.get('monitoring_enabled', True)
callback_settings = getattr(settings, "STATE_MACHINE_CALLBACKS", {})
return callback_settings.get("monitoring_enabled", True)

def _check_debug_mode(self) -> bool:
"""Check if debug mode is enabled."""
callback_settings = getattr(settings, 'STATE_MACHINE_CALLBACKS', {})
return callback_settings.get('debug_mode', settings.DEBUG)
callback_settings = getattr(settings, "STATE_MACHINE_CALLBACKS", {})
return callback_settings.get("debug_mode", settings.DEBUG)

def is_enabled(self) -> bool:
"""Check if monitoring is currently enabled."""
@@ -197,7 +195,7 @@ class CallbackMonitor:
# Store recent executions (with size limit)
self._recent_executions.append(record)
if len(self._recent_executions) > self._max_recent_records:
self._recent_executions = self._recent_executions[-self._max_recent_records:]
self._recent_executions = self._recent_executions[-self._max_recent_records :]

# Log in debug mode
if self._debug_mode:
@@ -277,12 +275,12 @@ class CallbackMonitor:

# Build summary
summary = {
'total_failures': len(failures),
'by_callback': {
"total_failures": len(failures),
"by_callback": {
name: {
'count': len(records),
'last_error': records[-1].error_message if records else None,
'last_occurrence': records[-1].timestamp if records else None,
"count": len(records),
"last_error": records[-1].error_message if records else None,
"last_occurrence": records[-1].timestamp if records else None,
}
for name, records in by_callback.items()
},
@@ -293,12 +291,12 @@ class CallbackMonitor:
def get_performance_report(self) -> dict[str, Any]:
"""Get a performance report for all callbacks."""
report = {
'callbacks': {},
'summary': {
'total_callbacks': len(self._stats),
'total_executions': sum(s.total_executions for s in self._stats.values()),
'total_failures': sum(s.failed_executions for s in self._stats.values()),
'avg_duration_ms': 0.0,
"callbacks": {},
"summary": {
"total_callbacks": len(self._stats),
"total_executions": sum(s.total_executions for s in self._stats.values()),
"total_failures": sum(s.failed_executions for s in self._stats.values()),
"avg_duration_ms": 0.0,
},
}

@@ -306,19 +304,19 @@ class CallbackMonitor:
total_count = 0

for name, stats in self._stats.items():
report['callbacks'][name] = {
'executions': stats.total_executions,
'success_rate': f"{stats.success_rate:.1f}%",
'avg_duration_ms': f"{stats.avg_duration_ms:.2f}",
'min_duration_ms': f"{stats.min_duration_ms:.2f}" if stats.min_duration_ms != float('inf') else "N/A",
'max_duration_ms': f"{stats.max_duration_ms:.2f}",
'last_execution': stats.last_execution.isoformat() if stats.last_execution else None,
report["callbacks"][name] = {
"executions": stats.total_executions,
"success_rate": f"{stats.success_rate:.1f}%",
"avg_duration_ms": f"{stats.avg_duration_ms:.2f}",
"min_duration_ms": f"{stats.min_duration_ms:.2f}" if stats.min_duration_ms != float("inf") else "N/A",
"max_duration_ms": f"{stats.max_duration_ms:.2f}",
"last_execution": stats.last_execution.isoformat() if stats.last_execution else None,
}
total_duration += stats.total_duration_ms
total_count += stats.total_executions

if total_count > 0:
report['summary']['avg_duration_ms'] = total_duration / total_count
report["summary"]["avg_duration_ms"] = total_duration / total_count

return report

@@ -361,7 +359,7 @@ class TimedCallbackExecution:
self.success = True
self.error_message: str | None = None

def __enter__(self) -> 'TimedCallbackExecution':
def __enter__(self) -> "TimedCallbackExecution":
self.start_time = time.perf_counter()
return self

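The monitor is meant to be queried at runtime, for example from a shell or a health endpoint. A hedged snippet (the import path is an assumption; callback_monitor is the module-level singleton exported below):

from apps.core.state_machine.monitoring import callback_monitor  # assumed module path

report = callback_monitor.get_performance_report()
print(report["summary"]["total_executions"], report["summary"]["avg_duration_ms"])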
@@ -442,12 +440,12 @@ def get_callback_execution_order(


__all__ = [
'CallbackExecutionRecord',
'CallbackStats',
'CallbackMonitor',
'callback_monitor',
'TimedCallbackExecution',
'log_transition_start',
'log_transition_end',
'get_callback_execution_order',
"CallbackExecutionRecord",
"CallbackStats",
"CallbackMonitor",
"callback_monitor",
"TimedCallbackExecution",
"log_transition_start",
"log_transition_end",
"get_callback_execution_order",
]

@@ -1,4 +1,5 @@
"""TransitionRegistry - Centralized registry for managing FSM transitions."""

import logging
from collections.abc import Callable
from dataclasses import dataclass, field
@@ -86,9 +87,7 @@ class TransitionRegistry:
self._transitions[key][transition_key] = transition_info
return transition_info

def get_transition(
self, choice_group: str, domain: str, source: str, target: str
) -> TransitionInfo | None:
def get_transition(self, choice_group: str, domain: str, source: str, target: str) -> TransitionInfo | None:
"""
Retrieve transition info.

@@ -109,9 +108,7 @@ class TransitionRegistry:

return self._transitions[key].get(transition_key)

def get_available_transitions(
self, choice_group: str, domain: str, current_state: str
) -> list[TransitionInfo]:
def get_available_transitions(self, choice_group: str, domain: str, current_state: str) -> list[TransitionInfo]:
"""
Get all valid transitions from a state.

@@ -135,9 +132,7 @@ class TransitionRegistry:

return available

def get_transition_method_name(
self, choice_group: str, domain: str, source: str, target: str
) -> str | None:
def get_transition_method_name(self, choice_group: str, domain: str, source: str, target: str) -> str | None:
"""
Get the method name for a transition.

@@ -153,9 +148,7 @@ class TransitionRegistry:
transition = self.get_transition(choice_group, domain, source, target)
return transition.method_name if transition else None

def validate_transition(
self, choice_group: str, domain: str, source: str, target: str
) -> bool:
def validate_transition(self, choice_group: str, domain: str, source: str, target: str) -> bool:
"""
Check if a transition is valid.

@@ -168,13 +161,9 @@ class TransitionRegistry:
Returns:
True if transition is valid
"""
return (
self.get_transition(choice_group, domain, source, target) is not None
)
return self.get_transition(choice_group, domain, source, target) is not None

def build_registry_from_choices(
self, choice_group: str, domain: str = "core"
) -> None:
def build_registry_from_choices(self, choice_group: str, domain: str = "core") -> None:
"""
Automatically populate registry from RichChoice metadata.

@@ -194,9 +183,7 @@ class TransitionRegistry:

for target in targets:
# Use shared method name determination
method_name = determine_method_name_for_transition(
source, target
)
method_name = determine_method_name_for_transition(source, target)

self.register_transition(
choice_group=choice_group,
@@ -226,9 +213,7 @@ class TransitionRegistry:
else:
self._transitions.clear()

def export_transition_graph(
self, choice_group: str, domain: str, format: str = "dict"
) -> Any:
def export_transition_graph(self, choice_group: str, domain: str, format: str = "dict") -> Any:
"""
Export state machine graph for visualization.

@@ -247,7 +232,7 @@ class TransitionRegistry:

if format == "dict":
graph: dict[str, list[str]] = {}
for (source, target), info in self._transitions[key].items():
for (source, target), _info in self._transitions[key].items():
if source not in graph:
graph[source] = []
graph[source].append(target)
@@ -262,10 +247,7 @@ class TransitionRegistry:
elif format == "dot":
lines = ["digraph {"]
for (source, target), info in self._transitions[key].items():
lines.append(
f' "{source}" -> "{target}" '
f'[label="{info.method_name}"];'
)
lines.append(f' "{source}" -> "{target}" [label="{info.method_name}"];')
lines.append("}")
return "\n".join(lines)

@@ -288,13 +270,14 @@ registry_instance = TransitionRegistry()

# Callback registration helpers


def register_callback(
model_class: type[models.Model],
field_name: str,
source: str,
target: str,
callback: Any,
stage: str = 'post',
stage: str = "post",
) -> None:
"""
Register a callback for a specific state transition.
@@ -325,7 +308,7 @@ def register_notification_callback(
source: str,
target: str,
notification_type: str,
recipient_field: str = 'submitted_by',
recipient_field: str = "submitted_by",
) -> None:
"""
Register a notification callback for a state transition.
@@ -344,15 +327,15 @@ def register_notification_callback(
notification_type=notification_type,
recipient_field=recipient_field,
)
register_callback(model_class, field_name, source, target, callback, 'post')
register_callback(model_class, field_name, source, target, callback, "post")


def register_cache_invalidation(
model_class: type[models.Model],
field_name: str,
cache_patterns: list[str] | None = None,
source: str = '*',
target: str = '*',
source: str = "*",
target: str = "*",
) -> None:
"""
Register cache invalidation for state transitions.
@@ -367,15 +350,15 @@ def register_cache_invalidation(
from .callbacks.cache import CacheInvalidationCallback

callback = CacheInvalidationCallback(patterns=cache_patterns or [])
register_callback(model_class, field_name, source, target, callback, 'post')
register_callback(model_class, field_name, source, target, callback, "post")


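Wiring cache invalidation to a model is then a one-liner. An illustrative call (the patterns are reused from the Park config earlier in this commit):

register_cache_invalidation(Park, "status", cache_patterns=["*park*", "api:*", "geo:*"])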
def register_related_update(
model_class: type[models.Model],
field_name: str,
update_func: Callable,
source: str = '*',
target: str = '*',
source: str = "*",
target: str = "*",
) -> None:
"""
Register a related model update callback.
@@ -390,7 +373,7 @@ def register_related_update(
from .callbacks.related_updates import RelatedModelUpdateCallback

callback = RelatedModelUpdateCallback(update_function=update_func)
register_callback(model_class, field_name, source, target, callback, 'post')
register_callback(model_class, field_name, source, target, callback, "post")


def register_transition_callbacks(cls: type[models.Model]) -> type[models.Model]:
@@ -414,20 +397,20 @@ def register_transition_callbacks(cls: type[models.Model]) -> type[models.Model]
Returns:
The decorated model class.
"""
meta = getattr(cls, 'Meta', None)
meta = getattr(cls, "Meta", None)
if not meta:
return cls

transition_callbacks = getattr(meta, 'transition_callbacks', None)
transition_callbacks = getattr(meta, "transition_callbacks", None)
if not transition_callbacks:
return cls

# Get the FSM field name
field_name = getattr(meta, 'fsm_field', 'status')
field_name = getattr(meta, "fsm_field", "status")

# Register each callback
for (source, target), callbacks in transition_callbacks.items():
if not isinstance(callbacks, (list, tuple)):
if not isinstance(callbacks, list | tuple):
callbacks = [callbacks]

for callback in callbacks:
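The decorator therefore reads its wiring straight off the model's Meta. A hedged sketch of what it expects (the model, states, and callback instance are illustrative; CacheInvalidationCallback's constructor is taken from the registration helper above):

@register_transition_callbacks
class EditSubmission(models.Model):
    status = FSMField(default="PENDING")

    class Meta:
        fsm_field = "status"
        transition_callbacks = {
            ("PENDING", "APPROVED"): [CacheInvalidationCallback(patterns=["*submission*"])],
        }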
@@ -455,23 +438,23 @@ def discover_and_register_callbacks() -> None:
|
||||
|
||||
for model in apps.get_models():
|
||||
# Check if model has StateMachineMixin
|
||||
if not hasattr(model, '_fsm_metadata') and not hasattr(model, 'Meta'):
|
||||
if not hasattr(model, "_fsm_metadata") and not hasattr(model, "Meta"):
|
||||
continue
|
||||
|
||||
meta = getattr(model, 'Meta', None)
|
||||
meta = getattr(model, "Meta", None)
|
||||
if not meta:
|
||||
continue
|
||||
|
||||
transition_callbacks = getattr(meta, 'transition_callbacks', None)
|
||||
transition_callbacks = getattr(meta, "transition_callbacks", None)
|
||||
if not transition_callbacks:
|
||||
continue
|
||||
|
||||
# Get the FSM field name
|
||||
field_name = getattr(meta, 'fsm_field', 'status')
|
||||
field_name = getattr(meta, "fsm_field", "status")
|
||||
|
||||
# Register callbacks
|
||||
for (source, target), callbacks in transition_callbacks.items():
|
||||
if not isinstance(callbacks, (list, tuple)):
|
||||
if not isinstance(callbacks, list | tuple):
|
||||
callbacks = [callbacks]
|
||||
|
||||
for callback in callbacks:
|
||||
|
||||
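The registration helpers above are driven by a `transition_callbacks` mapping (keyed by `(source, target)` pairs) and an optional `fsm_field` name, both read off a model's `Meta` with `getattr`. A minimal sketch of that lookup convention, using a plain class because stock Django rejects unknown `Meta` attributes unless the app's model mixins strip them first:

    def notify_team(context):
        # Hypothetical callback; real callbacks receive the transition context.
        print("approved:", context)

    class FakeSubmission:
        # Stand-in for a model class, for illustration only.
        class Meta:
            fsm_field = "status"
            transition_callbacks = {("PENDING", "APPROVED"): [notify_team]}

    meta = getattr(FakeSubmission, "Meta", None)
    field_name = getattr(meta, "fsm_field", "status")
    for (source, target), callbacks in meta.transition_callbacks.items():
        if not isinstance(callbacks, list | tuple):  # same normalization as above
            callbacks = [callbacks]
        print(field_name, source, target, callbacks)

`register_transition_callbacks` (or `discover_and_register_callbacks` at startup) walks exactly this structure and hands each callback to `register_callback`.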
@@ -74,7 +74,7 @@ class TransitionSignalHandler:
         source: str,
         target: str,
         handler: Callable,
-        stage: str = 'post',
+        stage: str = "post",
     ) -> None:
         """
         Register a handler for a specific transition.
@@ -95,10 +95,7 @@ class TransitionSignalHandler:
         signal = self._get_signal(stage)
         self._connect_signal(signal, model_class, source, target, handler)

-        logger.debug(
-            f"Registered {stage} transition handler for "
-            f"{model_class.__name__}: {source} → {target}"
-        )
+        logger.debug(f"Registered {stage} transition handler for " f"{model_class.__name__}: {source} → {target}")

     def unregister(
         self,
@@ -106,7 +103,7 @@ class TransitionSignalHandler:
         source: str,
         target: str,
         handler: Callable,
-        stage: str = 'post',
+        stage: str = "post",
     ) -> None:
         """Unregister a previously registered handler."""
         key = self._make_key(model_class, source, target, stage)
@@ -128,9 +125,9 @@ class TransitionSignalHandler:

     def _get_signal(self, stage: str) -> Signal:
         """Get the signal for a given stage."""
-        if stage == 'pre':
+        if stage == "pre":
             return pre_state_transition
-        elif stage == 'error':
+        elif stage == "error":
             return state_transition_failed
         return post_state_transition

@@ -150,13 +147,13 @@ class TransitionSignalHandler:
             return

         # Check source state
-        signal_source = kwargs.get('source', '')
-        if source != '*' and str(signal_source) != source:
+        signal_source = kwargs.get("source", "")
+        if source != "*" and str(signal_source) != source:
             return

         # Check target state
-        signal_target = kwargs.get('target', '')
-        if target != '*' and str(signal_target) != target:
+        signal_target = kwargs.get("target", "")
+        if target != "*" and str(signal_target) != target:
             return

         # Call the handler
@@ -174,7 +171,7 @@ def register_transition_handler(
     source: str,
     target: str,
     handler: Callable,
-    stage: str = 'post',
+    stage: str = "post",
 ) -> None:
     """
     Convenience function to register a transition signal handler.
@@ -186,9 +183,7 @@ def register_transition_handler(
         handler: The handler function to call.
         stage: 'pre', 'post', or 'error'.
     """
-    transition_signal_handler.register(
-        model_class, source, target, handler, stage
-    )
+    transition_signal_handler.register(model_class, source, target, handler, stage)


 def connect_fsm_log_signals() -> None:
@@ -231,9 +226,9 @@ class TransitionHandlerDecorator:
     def __init__(
         self,
         model_class: type[models.Model],
-        source: str = '*',
-        target: str = '*',
-        stage: str = 'post',
+        source: str = "*",
+        target: str = "*",
+        stage: str = "post",
     ):
         """
         Initialize the decorator.
@@ -263,9 +258,9 @@ class TransitionHandlerDecorator:

 def on_transition(
     model_class: type[models.Model],
-    source: str = '*',
-    target: str = '*',
-    stage: str = 'post',
+    source: str = "*",
+    target: str = "*",
+    stage: str = "post",
 ) -> TransitionHandlerDecorator:
     """
     Decorator factory for registering transition handlers.
@@ -289,44 +284,44 @@ def on_transition(

 def on_pre_transition(
     model_class: type[models.Model],
-    source: str = '*',
-    target: str = '*',
+    source: str = "*",
+    target: str = "*",
 ) -> TransitionHandlerDecorator:
     """Decorator for pre-transition handlers."""
-    return on_transition(model_class, source, target, stage='pre')
+    return on_transition(model_class, source, target, stage="pre")


 def on_post_transition(
     model_class: type[models.Model],
-    source: str = '*',
-    target: str = '*',
+    source: str = "*",
+    target: str = "*",
 ) -> TransitionHandlerDecorator:
     """Decorator for post-transition handlers."""
-    return on_transition(model_class, source, target, stage='post')
+    return on_transition(model_class, source, target, stage="post")


 def on_transition_error(
     model_class: type[models.Model],
-    source: str = '*',
-    target: str = '*',
+    source: str = "*",
+    target: str = "*",
 ) -> TransitionHandlerDecorator:
     """Decorator for transition error handlers."""
-    return on_transition(model_class, source, target, stage='error')
+    return on_transition(model_class, source, target, stage="error")


 __all__ = [
     # Signals
-    'pre_state_transition',
-    'post_state_transition',
-    'state_transition_failed',
+    "pre_state_transition",
+    "post_state_transition",
+    "state_transition_failed",
     # Handler registration
-    'TransitionSignalHandler',
-    'transition_signal_handler',
-    'register_transition_handler',
-    'connect_fsm_log_signals',
+    "TransitionSignalHandler",
+    "transition_signal_handler",
+    "register_transition_handler",
+    "connect_fsm_log_signals",
     # Decorators
-    'on_transition',
-    'on_pre_transition',
-    'on_post_transition',
-    'on_transition_error',
+    "on_transition",
+    "on_pre_transition",
+    "on_post_transition",
+    "on_transition_error",
 ]
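The decorator trio is the intended public surface for attaching handlers. A minimal sketch, assuming `EditSubmission` is an importable model and that handlers receive the signal's keyword arguments (`instance`, `source`, `target`, ...) — an assumption, not something this diff pins down:

    from apps.moderation.models import EditSubmission  # assumed import path

    @on_post_transition(EditSubmission, source="PENDING", target="APPROVED")
    def log_approval(sender, **kwargs):
        # Runs after a PENDING -> APPROVED transition on EditSubmission.
        instance = kwargs.get("instance")
        print(f"submission {instance.pk if instance else '?'} approved")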
@@ -29,11 +29,11 @@ class UserFactory:
     @classmethod
     def create_user(
         cls,
-        role: str = 'USER',
+        role: str = "USER",
         username: str | None = None,
         email: str | None = None,
-        password: str = 'testpass123',
-        **kwargs
+        password: str = "testpass123",
+        **kwargs,
     ) -> User:
         """
         Create a user with specified role.
@@ -54,33 +54,27 @@ class UserFactory:
         if email is None:
             email = f"{role.lower()}_{uid}@example.com"

-        return User.objects.create_user(
-            username=username,
-            email=email,
-            password=password,
-            role=role,
-            **kwargs
-        )
+        return User.objects.create_user(username=username, email=email, password=password, role=role, **kwargs)

     @classmethod
     def create_regular_user(cls, **kwargs) -> User:
         """Create a regular user."""
-        return cls.create_user(role='USER', **kwargs)
+        return cls.create_user(role="USER", **kwargs)

     @classmethod
     def create_moderator(cls, **kwargs) -> User:
         """Create a moderator user."""
-        return cls.create_user(role='MODERATOR', **kwargs)
+        return cls.create_user(role="MODERATOR", **kwargs)

     @classmethod
     def create_admin(cls, **kwargs) -> User:
         """Create an admin user."""
-        return cls.create_user(role='ADMIN', **kwargs)
+        return cls.create_user(role="ADMIN", **kwargs)

     @classmethod
     def create_superuser(cls, **kwargs) -> User:
         """Create a superuser."""
-        return cls.create_user(role='SUPERUSER', **kwargs)
+        return cls.create_user(role="SUPERUSER", **kwargs)


 class CompanyFactory:
@@ -102,11 +96,7 @@ class CompanyFactory:
         if name is None:
             name = f"Test Operator {uid}"

-        defaults = {
-            'name': name,
-            'description': f'Test operator company {uid}',
-            'roles': ['OPERATOR']
-        }
+        defaults = {"name": name, "description": f"Test operator company {uid}", "roles": ["OPERATOR"]}
         defaults.update(kwargs)
         return Company.objects.create(**defaults)

@@ -119,11 +109,7 @@ class CompanyFactory:
         if name is None:
             name = f"Test Manufacturer {uid}"

-        defaults = {
-            'name': name,
-            'description': f'Test manufacturer company {uid}',
-            'roles': ['MANUFACTURER']
-        }
+        defaults = {"name": name, "description": f"Test manufacturer company {uid}", "roles": ["MANUFACTURER"]}
         defaults.update(kwargs)
         return Company.objects.create(**defaults)

@@ -140,11 +126,7 @@ class ParkFactory:

     @classmethod
     def create_park(
-        cls,
-        name: str | None = None,
-        operator: Any | None = None,
-        status: str = 'OPERATING',
-        **kwargs
+        cls, name: str | None = None, operator: Any | None = None, status: str = "OPERATING", **kwargs
     ) -> Any:
         """
         Create a park with specified status.
@@ -167,12 +149,12 @@ class ParkFactory:
             operator = CompanyFactory.create_operator()

         defaults = {
-            'name': name,
-            'slug': f'test-park-{uid}',
-            'description': f'A test park {uid}',
-            'operator': operator,
-            'status': status,
-            'timezone': 'America/New_York'
+            "name": name,
+            "slug": f"test-park-{uid}",
+            "description": f"A test park {uid}",
+            "operator": operator,
+            "status": status,
+            "timezone": "America/New_York",
         }
         defaults.update(kwargs)
         return Park.objects.create(**defaults)
@@ -194,8 +176,8 @@ class RideFactory:
         name: str | None = None,
         park: Any | None = None,
         manufacturer: Any | None = None,
-        status: str = 'OPERATING',
-        **kwargs
+        status: str = "OPERATING",
+        **kwargs,
     ) -> Any:
         """
         Create a ride with specified status.
@@ -221,12 +203,12 @@ class RideFactory:
             manufacturer = CompanyFactory.create_manufacturer()

         defaults = {
-            'name': name,
-            'slug': f'test-ride-{uid}',
-            'description': f'A test ride {uid}',
-            'park': park,
-            'manufacturer': manufacturer,
-            'status': status
+            "name": name,
+            "slug": f"test-ride-{uid}",
+            "description": f"A test ride {uid}",
+            "park": park,
+            "manufacturer": manufacturer,
+            "status": status,
         }
         defaults.update(kwargs)
         return Ride.objects.create(**defaults)
@@ -247,9 +229,9 @@ class EditSubmissionFactory:
         cls,
         user: Any | None = None,
         target_object: Any | None = None,
-        status: str = 'PENDING',
+        status: str = "PENDING",
         changes: dict[str, Any] | None = None,
-        **kwargs
+        **kwargs,
     ) -> Any:
         """
         Create an edit submission.
@@ -271,23 +253,20 @@ class EditSubmissionFactory:
         if user is None:
             user = UserFactory.create_regular_user()
         if target_object is None:
-            target_object = Company.objects.create(
-                name=f'Target Company {uid}',
-                description='Test company'
-            )
+            target_object = Company.objects.create(name=f"Target Company {uid}", description="Test company")
         if changes is None:
-            changes = {'name': f'Updated Name {uid}'}
+            changes = {"name": f"Updated Name {uid}"}

         content_type = ContentType.objects.get_for_model(target_object)

         defaults = {
-            'user': user,
-            'content_type': content_type,
-            'object_id': target_object.id,
-            'submission_type': 'EDIT',
-            'changes': changes,
-            'status': status,
-            'reason': f'Test reason {uid}'
+            "user": user,
+            "content_type": content_type,
+            "object_id": target_object.id,
+            "submission_type": "EDIT",
+            "changes": changes,
+            "status": status,
+            "reason": f"Test reason {uid}",
         }
         defaults.update(kwargs)
         return EditSubmission.objects.create(**defaults)
@@ -305,11 +284,7 @@ class ModerationReportFactory:

     @classmethod
     def create_report(
-        cls,
-        reporter: Any | None = None,
-        target_object: Any | None = None,
-        status: str = 'PENDING',
-        **kwargs
+        cls, reporter: Any | None = None, target_object: Any | None = None, status: str = "PENDING", **kwargs
     ) -> Any:
         """
         Create a moderation report.
@@ -330,23 +305,20 @@ class ModerationReportFactory:
         if reporter is None:
             reporter = UserFactory.create_regular_user()
         if target_object is None:
-            target_object = Company.objects.create(
-                name=f'Reported Company {uid}',
-                description='Test company'
-            )
+            target_object = Company.objects.create(name=f"Reported Company {uid}", description="Test company")

         content_type = ContentType.objects.get_for_model(target_object)

         defaults = {
-            'report_type': 'CONTENT',
-            'status': status,
-            'priority': 'MEDIUM',
-            'reported_entity_type': target_object._meta.model_name,
-            'reported_entity_id': target_object.id,
-            'content_type': content_type,
-            'reason': f'Test reason {uid}',
-            'description': f'Test report description {uid}',
-            'reported_by': reporter
+            "report_type": "CONTENT",
+            "status": status,
+            "priority": "MEDIUM",
+            "reported_entity_type": target_object._meta.model_name,
+            "reported_entity_id": target_object.id,
+            "content_type": content_type,
+            "reason": f"Test reason {uid}",
+            "description": f"Test report description {uid}",
+            "reported_by": reporter,
         }
         defaults.update(kwargs)
         return ModerationReport.objects.create(**defaults)
@@ -369,5 +341,5 @@ class MockInstance:
             setattr(self, key, value)

     def __repr__(self):
-        attrs = ', '.join(f'{k}={v!r}' for k, v in self.__dict__.items())
-        return f'MockInstance({attrs})'
+        attrs = ", ".join(f"{k}={v!r}" for k, v in self.__dict__.items())
+        return f"MockInstance({attrs})"
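The factories are designed to compose, each filling in missing collaborators via the others. A sketch of a typical test setup; `create_ride` and `create_submission` are assumed method names following the `create_park`/`create_report` pattern visible above:

    operator = CompanyFactory.create_operator()
    park = ParkFactory.create_park(operator=operator, status="OPERATING")
    ride = RideFactory.create_ride(park=park, status="OPERATING")  # assumed name
    moderator = UserFactory.create_moderator()
    submission = EditSubmissionFactory.create_submission(  # assumed name
        user=moderator,
        target_object=ride,
        changes={"name": "Renamed Ride"},
    )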
@@ -13,11 +13,7 @@ from typing import Any
 from django.contrib.contenttypes.models import ContentType


-def assert_transition_allowed(
-    instance: Any,
-    method_name: str,
-    user: Any | None = None
-) -> bool:
+def assert_transition_allowed(instance: Any, method_name: str, user: Any | None = None) -> bool:
     """
     Assert that a transition is allowed.

@@ -43,11 +39,7 @@ def assert_transition_allowed(
     return True


-def assert_transition_denied(
-    instance: Any,
-    method_name: str,
-    user: Any | None = None
-) -> bool:
+def assert_transition_denied(instance: Any, method_name: str, user: Any | None = None) -> bool:
     """
     Assert that a transition is denied.

@@ -73,11 +65,7 @@ def assert_transition_denied(
     return True


-def assert_state_log_created(
-    instance: Any,
-    expected_state: str,
-    user: Any | None = None
-) -> Any:
+def assert_state_log_created(instance: Any, expected_state: str, user: Any | None = None) -> Any:
     """
     Assert that a StateLog entry was created for a transition.

@@ -98,11 +86,7 @@ def assert_state_log_created(
     from django_fsm_log.models import StateLog

     ct = ContentType.objects.get_for_model(instance)
-    log = StateLog.objects.filter(
-        content_type=ct,
-        object_id=instance.id,
-        state=expected_state
-    ).first()
+    log = StateLog.objects.filter(content_type=ct, object_id=instance.id, state=expected_state).first()

     assert log is not None, f"StateLog for state '{expected_state}' not found"

@@ -132,22 +116,15 @@ def assert_state_log_count(instance: Any, expected_count: int) -> list[Any]:
     from django_fsm_log.models import StateLog

     ct = ContentType.objects.get_for_model(instance)
-    logs = list(StateLog.objects.filter(
-        content_type=ct,
-        object_id=instance.id
-    ).order_by('timestamp'))
+    logs = list(StateLog.objects.filter(content_type=ct, object_id=instance.id).order_by("timestamp"))

     actual_count = len(logs)
-    assert actual_count == expected_count, \
-        f"Expected {expected_count} StateLog entries, got {actual_count}"
+    assert actual_count == expected_count, f"Expected {expected_count} StateLog entries, got {actual_count}"

     return logs


-def assert_state_transition_sequence(
-    instance: Any,
-    expected_states: list[str]
-) -> list[Any]:
+def assert_state_transition_sequence(instance: Any, expected_states: list[str]) -> list[Any]:
     """
     Assert that state transitions occurred in a specific sequence.

@@ -167,24 +144,15 @@ def assert_state_transition_sequence(
     from django_fsm_log.models import StateLog

     ct = ContentType.objects.get_for_model(instance)
-    logs = list(StateLog.objects.filter(
-        content_type=ct,
-        object_id=instance.id
-    ).order_by('timestamp'))
+    logs = list(StateLog.objects.filter(content_type=ct, object_id=instance.id).order_by("timestamp"))

     actual_states = [log.state for log in logs]
-    assert actual_states == expected_states, \
-        f"Expected state sequence {expected_states}, got {actual_states}"
+    assert actual_states == expected_states, f"Expected state sequence {expected_states}, got {actual_states}"

     return logs


-def assert_guard_passes(
-    guard: Callable,
-    instance: Any,
-    user: Any | None = None,
-    message: str = ""
-) -> bool:
+def assert_guard_passes(guard: Callable, instance: Any, user: Any | None = None, message: str = "") -> bool:
     """
     Assert that a guard function passes.

@@ -210,11 +178,7 @@ def assert_guard_passes(


 def assert_guard_fails(
-    guard: Callable,
-    instance: Any,
-    user: Any | None = None,
-    expected_error_code: str | None = None,
-    message: str = ""
+    guard: Callable, instance: Any, user: Any | None = None, expected_error_code: str | None = None, message: str = ""
 ) -> bool:
     """
     Assert that a guard function fails.
@@ -239,19 +203,15 @@ def assert_guard_fails(
     fail_message = message or f"Guard should fail but returned {result}"
     assert result is False, fail_message

-    if expected_error_code and hasattr(guard, 'error_code'):
-        assert guard.error_code == expected_error_code, \
-            f"Expected error code {expected_error_code}, got {guard.error_code}"
+    if expected_error_code and hasattr(guard, "error_code"):
+        assert (
+            guard.error_code == expected_error_code
+        ), f"Expected error code {expected_error_code}, got {guard.error_code}"

     return True


-def transition_and_save(
-    instance: Any,
-    transition_method: str,
-    user: Any | None = None,
-    **kwargs
-) -> Any:
+def transition_and_save(instance: Any, transition_method: str, user: Any | None = None, **kwargs) -> Any:
     """
     Execute a transition and save the instance.

@@ -290,10 +250,10 @@ def get_available_transitions(instance: Any) -> list[str]:
     """

     # Get the state field name from the instance
-    state_field = getattr(instance, 'state_field_name', 'status')
+    state_field = getattr(instance, "state_field_name", "status")

     # Build the function name dynamically
-    func_name = f'get_available_{state_field}_transitions'
+    func_name = f"get_available_{state_field}_transitions"
     if hasattr(instance, func_name):
         get_transitions = getattr(instance, func_name)
         return [t.name for t in get_transitions()]
@@ -301,19 +261,13 @@ def get_available_transitions(instance: Any) -> list[str]:
     # Fallback: look for transition methods
     transitions = []
     for attr_name in dir(instance):
-        if attr_name.startswith('transition_to_'):
+        if attr_name.startswith("transition_to_"):
             transitions.append(attr_name)

     return transitions


-def create_transition_context(
-    instance: Any,
-    from_state: str,
-    to_state: str,
-    user: Any | None = None,
-    **extra
-) -> dict:
+def create_transition_context(instance: Any, from_state: str, to_state: str, user: Any | None = None, **extra) -> dict:
     """
     Create a mock transition context dictionary.

@@ -331,11 +285,11 @@ def create_transition_context(
         context = create_transition_context(submission, 'PENDING', 'APPROVED', moderator)
     """
     return {
-        'instance': instance,
-        'from_state': from_state,
-        'to_state': to_state,
-        'user': user,
-        'model_class': type(instance),
-        'transition_name': f'transition_to_{to_state.lower()}',
-        **extra
+        "instance": instance,
+        "from_state": from_state,
+        "to_state": to_state,
+        "user": user,
+        "model_class": type(instance),
+        "transition_name": f"transition_to_{to_state.lower()}",
+        **extra,
     }
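A sketch of these helpers in a test body, assuming a submission model with an `approve` transition and a StateLog history that starts at PENDING:

    def test_approval_flow(submission, moderator):
        assert_transition_allowed(submission, "approve", user=moderator)
        transition_and_save(submission, "approve", user=moderator)
        assert_state_log_created(submission, "APPROVED", user=moderator)
        assert_state_log_count(submission, 2)
        assert_state_transition_sequence(submission, ["PENDING", "APPROVED"])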
@@ -1,4 +1,5 @@
 """Tests for StateTransitionBuilder."""
+
 import pytest
 from django.core.exceptions import ImproperlyConfigured


@@ -19,12 +19,7 @@ class CallbackContext:
     """Mock context for testing callbacks."""

     def __init__(
-        self,
-        instance: Any = None,
-        from_state: str = 'PENDING',
-        to_state: str = 'APPROVED',
-        user: Any = None,
-        **extra
+        self, instance: Any = None, from_state: str = "PENDING", to_state: str = "APPROVED", user: Any = None, **extra
     ):
         self.instance = instance or Mock()
         self.from_state = from_state
@@ -34,18 +29,18 @@ class CallbackContext:

     def to_dict(self) -> dict[str, Any]:
         return {
-            'instance': self.instance,
-            'from_state': self.from_state,
-            'to_state': self.to_state,
-            'user': self.user,
-            **self.extra
+            "instance": self.instance,
+            "from_state": self.from_state,
+            "to_state": self.to_state,
+            "user": self.user,
+            **self.extra,
         }


 class MockCallback:
     """Mock callback for testing."""

-    def __init__(self, name: str = 'callback', should_raise: bool = False):
+    def __init__(self, name: str = "callback", should_raise: bool = False):
         self.name = name
         self.calls: list[dict] = []
         self.should_raise = should_raise
@@ -71,40 +66,40 @@ class PreTransitionCallbackTests(TestCase):

     def test_pre_callback_executes_before_state_change(self):
         """Test that pre-transition callback executes before state changes."""
-        callback = MockCallback('pre_callback')
-        context = CallbackContext(from_state='PENDING', to_state='APPROVED')
+        callback = MockCallback("pre_callback")
+        context = CallbackContext(from_state="PENDING", to_state="APPROVED")

         # Simulate pre-transition execution
         callback(context.to_dict())

         self.assertTrue(callback.was_called())
-        self.assertEqual(callback.calls[0]['from_state'], 'PENDING')
-        self.assertEqual(callback.calls[0]['to_state'], 'APPROVED')
+        self.assertEqual(callback.calls[0]["from_state"], "PENDING")
+        self.assertEqual(callback.calls[0]["to_state"], "APPROVED")

     def test_pre_callback_receives_instance(self):
         """Test that pre-callback receives the model instance."""
         mock_instance = Mock()
         mock_instance.id = 123
-        mock_instance.status = 'PENDING'
+        mock_instance.status = "PENDING"

         callback = MockCallback()
         context = CallbackContext(instance=mock_instance)

         callback(context.to_dict())

-        self.assertEqual(callback.calls[0]['instance'], mock_instance)
+        self.assertEqual(callback.calls[0]["instance"], mock_instance)

     def test_pre_callback_receives_user(self):
         """Test that pre-callback receives the user performing transition."""
         mock_user = Mock()
-        mock_user.username = 'moderator'
+        mock_user.username = "moderator"

         callback = MockCallback()
         context = CallbackContext(user=mock_user)

         callback(context.to_dict())

-        self.assertEqual(callback.calls[0]['user'], mock_user)
+        self.assertEqual(callback.calls[0]["user"], mock_user)

     def test_pre_callback_can_prevent_transition(self):
         """Test that pre-callback can prevent transition by raising exception."""
@@ -119,13 +114,13 @@ class PreTransitionCallbackTests(TestCase):
         execution_order = []

         def callback_1(ctx):
-            execution_order.append('first')
+            execution_order.append("first")

         def callback_2(ctx):
-            execution_order.append('second')
+            execution_order.append("second")

         def callback_3(ctx):
-            execution_order.append('third')
+            execution_order.append("third")

         context = CallbackContext().to_dict()

@@ -134,7 +129,7 @@ class PreTransitionCallbackTests(TestCase):
         callback_2(context)
         callback_3(context)

-        self.assertEqual(execution_order, ['first', 'second', 'third'])
+        self.assertEqual(execution_order, ["first", "second", "third"])


 class PostTransitionCallbackTests(TestCase):
@@ -142,28 +137,24 @@ class PostTransitionCallbackTests(TestCase):

     def test_post_callback_executes_after_state_change(self):
         """Test that post-transition callback executes after state changes."""
-        callback = MockCallback('post_callback')
+        callback = MockCallback("post_callback")

         # Simulate instance after transition
         mock_instance = Mock()
-        mock_instance.status = 'APPROVED'  # Already changed
+        mock_instance.status = "APPROVED"  # Already changed

-        context = CallbackContext(
-            instance=mock_instance,
-            from_state='PENDING',
-            to_state='APPROVED'
-        )
+        context = CallbackContext(instance=mock_instance, from_state="PENDING", to_state="APPROVED")

         callback(context.to_dict())

         self.assertTrue(callback.was_called())
-        self.assertEqual(callback.calls[0]['instance'].status, 'APPROVED')
+        self.assertEqual(callback.calls[0]["instance"].status, "APPROVED")

     def test_post_callback_receives_updated_instance(self):
         """Test that post-callback receives instance with new state."""
         mock_instance = Mock()
-        mock_instance.status = 'APPROVED'
-        mock_instance.approved_at = '2025-01-15'
+        mock_instance.status = "APPROVED"
+        mock_instance.approved_at = "2025-01-15"
         mock_instance.handled_by_id = 456

         callback = MockCallback()
@@ -171,9 +162,9 @@ class PostTransitionCallbackTests(TestCase):

         callback(context.to_dict())

-        instance = callback.calls[0]['instance']
-        self.assertEqual(instance.status, 'APPROVED')
-        self.assertEqual(instance.approved_at, '2025-01-15')
+        instance = callback.calls[0]["instance"]
+        self.assertEqual(instance.status, "APPROVED")
+        self.assertEqual(instance.approved_at, "2025-01-15")

     def test_post_callback_failure_does_not_rollback(self):
         """Test that post-callback failures don't rollback the transition."""
@@ -193,13 +184,13 @@ class PostTransitionCallbackTests(TestCase):
         execution_order = []

         def notification_callback(ctx):
-            execution_order.append('notification')
+            execution_order.append("notification")

         def cache_callback(ctx):
-            execution_order.append('cache')
+            execution_order.append("cache")

         def analytics_callback(ctx):
-            execution_order.append('analytics')
+            execution_order.append("analytics")

         context = CallbackContext().to_dict()

@@ -207,7 +198,7 @@ class PostTransitionCallbackTests(TestCase):
         cache_callback(context)
         analytics_callback(context)

-        self.assertEqual(execution_order, ['notification', 'cache', 'analytics'])
+        self.assertEqual(execution_order, ["notification", "cache", "analytics"])


 class ErrorCallbackTests(TestCase):
@@ -221,17 +212,17 @@ class ErrorCallbackTests(TestCase):
             raise ValueError("Transition failed")
         except ValueError as e:
             error_context = {
-                'instance': Mock(),
-                'from_state': 'PENDING',
-                'to_state': 'APPROVED',
-                'exception': e,
-                'exception_type': type(e).__name__
+                "instance": Mock(),
+                "from_state": "PENDING",
+                "to_state": "APPROVED",
+                "exception": e,
+                "exception_type": type(e).__name__,
             }
             error_callback(error_context)

         self.assertTrue(error_callback.was_called())
-        self.assertIn('exception', error_callback.calls[0])
-        self.assertEqual(error_callback.calls[0]['exception_type'], 'ValueError')
+        self.assertIn("exception", error_callback.calls[0])
+        self.assertEqual(error_callback.calls[0]["exception_type"], "ValueError")

     def test_error_callback_for_cleanup(self):
         """Test that error callbacks can perform cleanup."""
@@ -244,7 +235,7 @@ class ErrorCallbackTests(TestCase):
         try:
             raise ValueError("Transition failed")
         except ValueError:
-            cleanup_callback({'exception': 'test'})
+            cleanup_callback({"exception": "test"})

         self.assertTrue(cleanup_performed)

@@ -256,17 +247,17 @@ class ErrorCallbackTests(TestCase):
         error_callback = MockCallback()

         error_context = {
-            'instance': mock_instance,
-            'from_state': 'PENDING',
-            'to_state': 'APPROVED',
-            'user': mock_user,
-            'exception': ValueError("Test error")
+            "instance": mock_instance,
+            "from_state": "PENDING",
+            "to_state": "APPROVED",
+            "user": mock_user,
+            "exception": ValueError("Test error"),
         }

         error_callback(error_context)

-        self.assertEqual(error_callback.calls[0]['instance'], mock_instance)
-        self.assertEqual(error_callback.calls[0]['user'], mock_user)
+        self.assertEqual(error_callback.calls[0]["instance"], mock_instance)
+        self.assertEqual(error_callback.calls[0]["user"], mock_user)


 class ConditionalCallbackTests(TestCase):
@@ -277,15 +268,15 @@ class ConditionalCallbackTests(TestCase):
         execution_log = []

         def approval_only_callback(ctx):
-            if ctx.get('to_state') == 'APPROVED':
-                execution_log.append('approved')
+            if ctx.get("to_state") == "APPROVED":
+                execution_log.append("approved")

         # Transition to APPROVED - should execute
-        approval_only_callback({'to_state': 'APPROVED'})
+        approval_only_callback({"to_state": "APPROVED"})
         self.assertEqual(len(execution_log), 1)

         # Transition to REJECTED - should not execute
-        approval_only_callback({'to_state': 'REJECTED'})
+        approval_only_callback({"to_state": "REJECTED"})
         self.assertEqual(len(execution_log), 1)  # Still 1

     def test_callback_with_transition_filter(self):
@@ -293,15 +284,15 @@ class ConditionalCallbackTests(TestCase):
         execution_log = []

         def escalation_callback(ctx):
-            if ctx.get('to_state') == 'ESCALATED':
-                execution_log.append('escalated')
+            if ctx.get("to_state") == "ESCALATED":
+                execution_log.append("escalated")

         # Escalation - should execute
-        escalation_callback({'to_state': 'ESCALATED'})
+        escalation_callback({"to_state": "ESCALATED"})
         self.assertEqual(len(execution_log), 1)

         # Other transitions - should not execute
-        escalation_callback({'to_state': 'APPROVED'})
+        escalation_callback({"to_state": "APPROVED"})
         self.assertEqual(len(execution_log), 1)

     def test_callback_with_user_role_filter(self):
@@ -309,17 +300,17 @@ class ConditionalCallbackTests(TestCase):
         admin_notifications = []

         def admin_only_notification(ctx):
-            user = ctx.get('user')
-            if user and getattr(user, 'role', None) == 'ADMIN':
+            user = ctx.get("user")
+            if user and getattr(user, "role", None) == "ADMIN":
                 admin_notifications.append(ctx)

-        admin_user = Mock(role='ADMIN')
-        moderator_user = Mock(role='MODERATOR')
+        admin_user = Mock(role="ADMIN")
+        moderator_user = Mock(role="MODERATOR")

-        admin_only_notification({'user': admin_user})
+        admin_only_notification({"user": admin_user})
         self.assertEqual(len(admin_notifications), 1)

-        admin_only_notification({'user': moderator_user})
+        admin_only_notification({"user": moderator_user})
         self.assertEqual(len(admin_notifications), 1)  # Still 1


@@ -331,29 +322,29 @@ class CallbackChainTests(TestCase):
         results = []

         callbacks = [
-            lambda ctx: results.append('a'),
-            lambda ctx: results.append('b'),
-            lambda ctx: results.append('c'),
+            lambda ctx: results.append("a"),
+            lambda ctx: results.append("b"),
+            lambda ctx: results.append("c"),
         ]

         context = {}
         for cb in callbacks:
             cb(context)

-        self.assertEqual(results, ['a', 'b', 'c'])
+        self.assertEqual(results, ["a", "b", "c"])

     def test_callback_chain_stops_on_failure(self):
         """Test that callback chain stops when a callback fails."""
         results = []

         def callback_a(ctx):
-            results.append('a')
+            results.append("a")

         def callback_b(ctx):
             raise ValueError("B failed")

         def callback_c(ctx):
-            results.append('c')
+            results.append("c")

         callbacks = [callback_a, callback_b, callback_c]

@@ -364,7 +355,7 @@ class CallbackChainTests(TestCase):
         except ValueError:
             break

-        self.assertEqual(results, ['a'])  # c never executed
+        self.assertEqual(results, ["a"])  # c never executed

     def test_callback_chain_with_continue_on_error(self):
         """Test callback chain that continues despite errors."""
@@ -372,13 +363,13 @@ class CallbackChainTests(TestCase):
         errors = []

         def callback_a(ctx):
-            results.append('a')
+            results.append("a")

         def callback_b(ctx):
             raise ValueError("B failed")

         def callback_c(ctx):
-            results.append('c')
+            results.append("c")

         callbacks = [callback_a, callback_b, callback_c]

@@ -389,7 +380,7 @@ class CallbackChainTests(TestCase):
         except Exception as e:
             errors.append(str(e))

-        self.assertEqual(results, ['a', 'c'])
+        self.assertEqual(results, ["a", "c"])
         self.assertEqual(len(errors), 1)
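Both chain behaviours exercised above (stop on the first error vs. collect errors and continue) reduce to one small driver; a sketch of the loop the tests write out by hand:

    def run_chain(callbacks, context, continue_on_error=False):
        errors = []
        for cb in callbacks:
            try:
                cb(context)
            except Exception as exc:  # broad catch is deliberate in this demo
                if not continue_on_error:
                    raise
                errors.append(exc)
        return errors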
@@ -399,36 +390,30 @@ class CallbackContextEnrichmentTests(TestCase):
     def test_context_includes_model_class(self):
         """Test that context includes the model class."""
         mock_instance = Mock()
-        mock_instance.__class__.__name__ = 'EditSubmission'
+        mock_instance.__class__.__name__ = "EditSubmission"

-        context = {
-            'instance': mock_instance,
-            'model_class': type(mock_instance)
-        }
+        context = {"instance": mock_instance, "model_class": type(mock_instance)}

-        self.assertIn('model_class', context)
+        self.assertIn("model_class", context)

     def test_context_includes_transition_name(self):
         """Test that context includes the transition method name."""
         context = {
-            'instance': Mock(),
-            'from_state': 'PENDING',
-            'to_state': 'APPROVED',
-            'transition_name': 'transition_to_approved'
+            "instance": Mock(),
+            "from_state": "PENDING",
+            "to_state": "APPROVED",
+            "transition_name": "transition_to_approved",
         }

-        self.assertEqual(context['transition_name'], 'transition_to_approved')
+        self.assertEqual(context["transition_name"], "transition_to_approved")

     def test_context_includes_timestamp(self):
         """Test that context includes transition timestamp."""
         from django.utils import timezone

-        context = {
-            'instance': Mock(),
-            'timestamp': timezone.now()
-        }
+        context = {"instance": Mock(), "timestamp": timezone.now()}

-        self.assertIn('timestamp', context)
+        self.assertIn("timestamp", context)


 # ============================================================================
@@ -446,9 +431,9 @@ class NotificationCallbackTests(TestCase):

     def _create_transition_context(
         self,
-        model_name: str = 'EditSubmission',
-        source_state: str = 'PENDING',
-        target_state: str = 'APPROVED',
+        model_name: str = "EditSubmission",
+        source_state: str = "PENDING",
+        target_state: str = "APPROVED",
         user=None,
         instance=None,
     ):
@@ -465,18 +450,18 @@ class NotificationCallbackTests(TestCase):
         if user is None:
             user = Mock()
             user.pk = 1
-            user.username = 'moderator'
+            user.username = "moderator"

         return TransitionContext(
             instance=instance,
-            field_name='status',
+            field_name="status",
             source_state=source_state,
             target_state=target_state,
             user=user,
             timestamp=timezone.now(),
         )

-    @patch('apps.core.state_machine.callbacks.notifications.NotificationService')
+    @patch("apps.core.state_machine.callbacks.notifications.NotificationService")
     def test_notification_callback_approval_title(self, mock_service_class):
         """Test NotificationCallback generates correct title for approvals."""
         from ..callbacks.notifications import NotificationCallback
@@ -487,8 +472,8 @@ class NotificationCallbackTests(TestCase):

         callback = NotificationCallback()
         context = self._create_transition_context(
-            source_state='PENDING',
-            target_state='APPROVED',
+            source_state="PENDING",
+            target_state="APPROVED",
         )

         callback.execute(context)
@@ -496,9 +481,9 @@ class NotificationCallbackTests(TestCase):
         # Check that notification was sent with correct title
         if mock_service.send_notification.called:
             call_args = mock_service.send_notification.call_args
-            self.assertIn('approved', call_args[1].get('title', '').lower())
+            self.assertIn("approved", call_args[1].get("title", "").lower())

-    @patch('apps.core.state_machine.callbacks.notifications.NotificationService')
+    @patch("apps.core.state_machine.callbacks.notifications.NotificationService")
     def test_notification_callback_rejection_title(self, mock_service_class):
         """Test NotificationCallback generates correct title for rejections."""
         from ..callbacks.notifications import NotificationCallback
@@ -509,17 +494,17 @@ class NotificationCallbackTests(TestCase):

         callback = NotificationCallback()
         context = self._create_transition_context(
-            source_state='PENDING',
-            target_state='REJECTED',
+            source_state="PENDING",
+            target_state="REJECTED",
         )

         callback.execute(context)

         if mock_service.send_notification.called:
             call_args = mock_service.send_notification.call_args
-            self.assertIn('rejected', call_args[1].get('title', '').lower())
+            self.assertIn("rejected", call_args[1].get("title", "").lower())

-    @patch('apps.core.state_machine.callbacks.notifications.NotificationService')
+    @patch("apps.core.state_machine.callbacks.notifications.NotificationService")
     def test_moderation_notification_recipient_selection(self, mock_service_class):
         """Test ModerationNotificationCallback sends to correct recipient."""
         from ..callbacks.notifications import ModerationNotificationCallback
@@ -530,16 +515,16 @@ class NotificationCallbackTests(TestCase):

         submitter = Mock()
         submitter.pk = 999
-        submitter.username = 'submitter'
+        submitter.username = "submitter"

         instance = Mock()
         instance.pk = 123
-        instance.__class__.__name__ = 'EditSubmission'
+        instance.__class__.__name__ = "EditSubmission"
         instance.user = submitter  # The submitter who should receive notification

         callback = ModerationNotificationCallback()
         context = self._create_transition_context(
-            target_state='APPROVED',
+            target_state="APPROVED",
             instance=instance,
         )

@@ -548,10 +533,10 @@ class NotificationCallbackTests(TestCase):
         if mock_service.send_notification.called:
             call_args = mock_service.send_notification.call_args
             # Should notify the submitter about their submission
-            recipient = call_args[1].get('user') or call_args[0][0] if call_args[0] else None
+            recipient = call_args[1].get("user") or call_args[0][0] if call_args[0] else None
             self.assertIsNotNone(recipient)

-    @patch('apps.core.state_machine.callbacks.notifications.NotificationService')
+    @patch("apps.core.state_machine.callbacks.notifications.NotificationService")
     def test_notification_callback_handles_service_error(self, mock_service_class):
         """Test NotificationCallback handles service errors gracefully."""
         from ..callbacks.notifications import NotificationCallback
@@ -568,7 +553,7 @@ class NotificationCallbackTests(TestCase):
         # Callback may return False on error but should not raise
         self.assertIsNotNone(result)

-    @patch('apps.core.state_machine.callbacks.notifications.NotificationService')
+    @patch("apps.core.state_machine.callbacks.notifications.NotificationService")
     def test_notification_callback_message_includes_model_info(self, mock_service_class):
         """Test notification message includes model information."""
         from ..callbacks.notifications import NotificationCallback
@@ -578,13 +563,13 @@ class NotificationCallbackTests(TestCase):
         mock_service_class.return_value = mock_service

         callback = NotificationCallback()
-        context = self._create_transition_context(model_name='PhotoSubmission')
+        context = self._create_transition_context(model_name="PhotoSubmission")

         callback.execute(context)

         if mock_service.send_notification.called:
             call_args = mock_service.send_notification.call_args
-            message = call_args[1].get('message', '')
+            message = call_args[1].get("message", "")
             # Should reference the submission type or model
             self.assertIsInstance(message, str)
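Every notification test above follows one pattern: patch `NotificationService` at the path where the callbacks module imports it, then inspect `send_notification.call_args`. Reduced to its skeleton (the patch target string is taken from the tests; the rest is generic unittest.mock):

    from unittest.mock import Mock, patch

    with patch("apps.core.state_machine.callbacks.notifications.NotificationService") as service_cls:
        service = Mock()
        service_cls.return_value = service
        # ... run the callback under test here ...
        if service.send_notification.called:
            kwargs = service.send_notification.call_args[1]
            print(kwargs.get("title", ""))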
@@ -599,10 +584,10 @@ class CacheCallbackTests(TestCase):

     def _create_transition_context(
         self,
-        model_name: str = 'Park',
+        model_name: str = "Park",
         instance_id: int = 123,
-        source_state: str = 'OPERATING',
-        target_state: str = 'CLOSED_TEMP',
+        source_state: str = "OPERATING",
+        target_state: str = "CLOSED_TEMP",
     ):
         """Helper to create a TransitionContext."""
         from django.utils import timezone
@@ -615,14 +600,14 @@ class CacheCallbackTests(TestCase):

         return TransitionContext(
             instance=instance,
-            field_name='status',
+            field_name="status",
             source_state=source_state,
             target_state=target_state,
             user=Mock(),
             timestamp=timezone.now(),
         )

-    @patch('apps.core.state_machine.callbacks.cache.CacheInvalidationCallback._get_cache_service')
+    @patch("apps.core.state_machine.callbacks.cache.CacheInvalidationCallback._get_cache_service")
     def test_cache_callback_invalidates_model_patterns(self, mock_get_service):
         """Test CacheInvalidationCallback invalidates correct patterns."""
         from ..callbacks.cache import CacheInvalidationCallback
@@ -631,9 +616,7 @@ class CacheCallbackTests(TestCase):
         mock_cache.invalidate_pattern = Mock()
         mock_get_service.return_value = mock_cache

-        callback = CacheInvalidationCallback(
-            patterns=['*park:123*', '*parks*']
-        )
+        callback = CacheInvalidationCallback(patterns=["*park:123*", "*parks*"])
         context = self._create_transition_context()

         callback.execute(context)
@@ -641,7 +624,7 @@ class CacheCallbackTests(TestCase):
         # Should have called invalidate_pattern for each pattern
         self.assertTrue(mock_cache.invalidate_pattern.called)

-    @patch('apps.core.state_machine.callbacks.cache.CacheInvalidationCallback._get_cache_service')
+    @patch("apps.core.state_machine.callbacks.cache.CacheInvalidationCallback._get_cache_service")
     def test_cache_callback_generates_instance_patterns(self, mock_get_service):
         """Test CacheInvalidationCallback generates instance-specific patterns."""
         from ..callbacks.cache import CacheInvalidationCallback
@@ -651,30 +634,25 @@ class CacheCallbackTests(TestCase):
         mock_get_service.return_value = mock_cache

         callback = CacheInvalidationCallback(include_instance_patterns=True)
-        context = self._create_transition_context(
-            model_name='Park',
-            instance_id=456
-        )
+        context = self._create_transition_context(model_name="Park", instance_id=456)

         callback.execute(context)

         # Should have called invalidate_pattern with instance-specific patterns
         self.assertTrue(mock_cache.invalidate_pattern.called)
-        patterns_called = [
-            call[0][0] for call in mock_cache.invalidate_pattern.call_args_list
-        ]
+        patterns_called = [call[0][0] for call in mock_cache.invalidate_pattern.call_args_list]
         # Should include patterns containing the instance ID
-        has_instance_pattern = any('456' in p for p in patterns_called)
+        has_instance_pattern = any("456" in p for p in patterns_called)
         self.assertTrue(has_instance_pattern, f"No pattern with instance ID in {patterns_called}")

-    @patch('apps.core.state_machine.callbacks.cache.CacheInvalidationCallback._get_cache_service')
+    @patch("apps.core.state_machine.callbacks.cache.CacheInvalidationCallback._get_cache_service")
     def test_cache_callback_handles_service_unavailable(self, mock_get_service):
         """Test CacheInvalidationCallback handles unavailable cache service."""
         from ..callbacks.cache import CacheInvalidationCallback

         mock_get_service.return_value = None

-        callback = CacheInvalidationCallback(patterns=['*test*'])
+        callback = CacheInvalidationCallback(patterns=["*test*"])
         context = self._create_transition_context()

         # Should not raise, uses fallback
@@ -682,7 +660,7 @@ class CacheCallbackTests(TestCase):
         # Should return True (fallback succeeds)
         self.assertTrue(result)

-    @patch('apps.core.state_machine.callbacks.cache.CacheInvalidationCallback._get_cache_service')
+    @patch("apps.core.state_machine.callbacks.cache.CacheInvalidationCallback._get_cache_service")
     def test_cache_callback_continues_on_pattern_error(self, mock_get_service):
         """Test CacheInvalidationCallback continues if individual pattern fails."""
         from ..callbacks.cache import CacheInvalidationCallback
@@ -693,16 +671,13 @@ class CacheCallbackTests(TestCase):
         def invalidate_side_effect(pattern):
             nonlocal call_count
             call_count += 1
-            if 'bad' in pattern:
+            if "bad" in pattern:
                 raise Exception("Pattern invalid")

         mock_cache.invalidate_pattern = Mock(side_effect=invalidate_side_effect)
         mock_get_service.return_value = mock_cache

-        callback = CacheInvalidationCallback(
-            patterns=['good:*', 'bad:*', 'another:*'],
-            include_instance_patterns=False
-        )
+        callback = CacheInvalidationCallback(patterns=["good:*", "bad:*", "another:*"], include_instance_patterns=False)
         context = self._create_transition_context()

         # Should not raise overall
@@ -716,7 +691,7 @@ class ModelCacheInvalidationTests(TestCase):

     def _create_transition_context(
         self,
-        model_name: str = 'Ride',
+        model_name: str = "Ride",
         instance_id: int = 789,
     ):
         from django.utils import timezone
@@ -728,20 +703,20 @@ class ModelCacheInvalidationTests(TestCase):
         instance.__class__.__name__ = model_name

         # Add park reference for rides
-        if model_name == 'Ride':
+        if model_name == "Ride":
             instance.park = Mock()
             instance.park.pk = 111

         return TransitionContext(
             instance=instance,
-            field_name='status',
-            source_state='OPERATING',
-            target_state='CLOSED_TEMP',
+            field_name="status",
+            source_state="OPERATING",
+            target_state="CLOSED_TEMP",
             user=Mock(),
             timestamp=timezone.now(),
         )

-    @patch('apps.core.state_machine.callbacks.cache.CacheInvalidationCallback._get_cache_service')
+    @patch("apps.core.state_machine.callbacks.cache.CacheInvalidationCallback._get_cache_service")
     def test_ride_cache_includes_park_patterns(self, mock_get_service):
         """Test RideCacheInvalidation includes parent park patterns."""
         from ..callbacks.cache import RideCacheInvalidation
@@ -755,12 +730,10 @@ class ModelCacheInvalidationTests(TestCase):

         callback.execute(context)

-        patterns_called = [
-            call[0][0] for call in mock_cache.invalidate_pattern.call_args_list
-        ]
+        patterns_called = [call[0][0] for call in mock_cache.invalidate_pattern.call_args_list]

         # Should include park patterns (parent park ID is 111)
-        has_park_pattern = any('park' in p.lower() for p in patterns_called)
+        has_park_pattern = any("park" in p.lower() for p in patterns_called)
         self.assertTrue(has_park_pattern, f"No park pattern in {patterns_called}")
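These tests call `RideCacheInvalidation.execute()` directly; in application code the same callback would be attached through the registration API from earlier in this commit. A sketch, assuming `Ride` is importable (the wiring itself is not part of this diff):

    from apps.core.state_machine.callbacks.cache import RideCacheInvalidation

    register_callback(Ride, "status", "*", "*", RideCacheInvalidation(), "post")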
@@ -775,17 +748,18 @@ class RelatedUpdateCallbackTests(TestCase):
     def setUp(self):
         """Set up test fixtures."""
         from django.contrib.auth import get_user_model
+
         get_user_model()

         self.user = Mock()
         self.user.pk = 1
-        self.user.username = 'testuser'
+        self.user.username = "testuser"

     def _create_transition_context(
         self,
-        model_name: str = 'Ride',
+        model_name: str = "Ride",
         instance=None,
-        target_state: str = 'OPERATING',
+        target_state: str = "OPERATING",
     ):
         from django.utils import timezone

@@ -798,8 +772,8 @@ class RelatedUpdateCallbackTests(TestCase):

         return TransitionContext(
             instance=instance,
-            field_name='status',
-            source_state='UNDER_CONSTRUCTION',
+            field_name="status",
+            source_state="UNDER_CONSTRUCTION",
             target_state=target_state,
             user=self.user,
             timestamp=timezone.now(),
@@ -819,7 +793,7 @@ class RelatedUpdateCallbackTests(TestCase):
         # Create mock ride that belongs to park
         mock_ride = Mock()
         mock_ride.pk = 200
-        mock_ride.__class__.__name__ = 'Ride'
+        mock_ride.__class__.__name__ = "Ride"
         mock_ride.park = mock_park
         mock_ride.is_coaster = True

@@ -830,9 +804,9 @@ class RelatedUpdateCallbackTests(TestCase):

         callback = ParkCountUpdateCallback()
         context = self._create_transition_context(
-            model_name='Ride',
+            model_name="Ride",
             instance=mock_ride,
-            target_state='OPERATING',
+            target_state="OPERATING",
         )

         # Execute callback
@@ -847,12 +821,12 @@ class RelatedUpdateCallbackTests(TestCase):

         mock_ride = Mock()
         mock_ride.pk = 200
-        mock_ride.__class__.__name__ = 'Ride'
+        mock_ride.__class__.__name__ = "Ride"
         mock_ride.park = None  # No park

         callback = ParkCountUpdateCallback()
         context = self._create_transition_context(
-            model_name='Ride',
+            model_name="Ride",
             instance=mock_ride,
         )

@@ -872,15 +846,15 @@ class RelatedUpdateCallbackTests(TestCase):

         mock_ride = Mock()
         mock_ride.pk = 200
-        mock_ride.__class__.__name__ = 'Ride'
+        mock_ride.__class__.__name__ = "Ride"
         mock_ride.park = mock_park
         mock_ride.is_coaster = False

         callback = ParkCountUpdateCallback()
         context = self._create_transition_context(
-            model_name='Ride',
+            model_name="Ride",
             instance=mock_ride,
-            target_state='OPERATING',
+            target_state="OPERATING",
         )

         callback.execute(context)
@@ -901,15 +875,15 @@ class RelatedUpdateCallbackTests(TestCase):

         mock_ride = Mock()
         mock_ride.pk = 200
-        mock_ride.__class__.__name__ = 'Ride'
+        mock_ride.__class__.__name__ = "Ride"
         mock_ride.park = mock_park
         mock_ride.is_coaster = True

         callback = ParkCountUpdateCallback()
         context = self._create_transition_context(
-            model_name='Ride',
+            model_name="Ride",
             instance=mock_ride,
-            target_state='CLOSED_PERM',
+            target_state="CLOSED_PERM",
         )

         result = callback.execute(context)
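`ParkCountUpdateCallback` is likewise driven directly here; declaratively it would go through `register_related_update`. A sketch in which the update function and the `refresh_ride_counts` method are hypothetical:

    def refresh_park_counts(context):
        ride = context.instance
        if ride.park is not None:
            ride.park.refresh_ride_counts()  # assumed method on Park

    register_related_update(Ride, "status", update_func=refresh_park_counts)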
@@ -931,18 +905,18 @@ class CallbackErrorHandlingTests(TestCase):

         instance = Mock()
         instance.pk = 1
-        instance.__class__.__name__ = 'EditSubmission'
+        instance.__class__.__name__ = "EditSubmission"

         return TransitionContext(
             instance=instance,
-            field_name='status',
-            source_state='PENDING',
-            target_state='APPROVED',
+            field_name="status",
+            source_state="PENDING",
+            target_state="APPROVED",
             user=Mock(),
             timestamp=timezone.now(),
         )

-    @patch('apps.core.state_machine.callbacks.notifications.NotificationService')
+    @patch("apps.core.state_machine.callbacks.notifications.NotificationService")
     def test_notification_callback_logs_error_on_failure(self, mock_service_class):
         """Test NotificationCallback logs errors when service fails."""
         import logging
@@ -957,7 +931,7 @@ class CallbackErrorHandlingTests(TestCase):
         context = self._create_transition_context()

         with self.assertLogs(level=logging.WARNING):
-            try:
+            try:  # noqa: SIM105
                 callback.execute(context)
             except Exception:
                 pass  # May or may not raise depending on implementation
@@ -965,7 +939,7 @@ class CallbackErrorHandlingTests(TestCase):
         # Should have logged something about the error
         # (Logging behavior depends on implementation)

-    @patch('apps.core.state_machine.callbacks.cache.CacheInvalidationCallback._get_cache_service')
+    @patch("apps.core.state_machine.callbacks.cache.CacheInvalidationCallback._get_cache_service")
     def test_cache_callback_returns_false_on_total_failure(self, mock_get_service):
         """Test CacheInvalidationCallback returns False on complete failure."""
         from ..callbacks.cache import CacheInvalidationCallback
@@ -974,10 +948,7 @@ class CallbackErrorHandlingTests(TestCase):
         mock_cache.invalidate_pattern = Mock(side_effect=Exception("Cache error"))
         mock_get_service.return_value = mock_cache

-        callback = CacheInvalidationCallback(
-            patterns=['*test*'],
-            include_instance_patterns=False
-        )
+        callback = CacheInvalidationCallback(patterns=["*test*"], include_instance_patterns=False)
         context = self._create_transition_context()

         result = callback.execute(context)
@@ -993,18 +964,18 @@ class CallbackErrorHandlingTests(TestCase):

         instance = Mock()
         instance.pk = 1
-        instance.__class__.__name__ = 'EditSubmission'
+        instance.__class__.__name__ = "EditSubmission"

         context = TransitionContext(
             instance=instance,
-            field_name='status',
-            source_state='PENDING',
-            target_state='APPROVED',
+            field_name="status",
+            source_state="PENDING",
+            target_state="APPROVED",
             user=None,  # No user
             timestamp=timezone.now(),
         )

-        with patch('apps.core.state_machine.callbacks.notifications.NotificationService'):
+        with patch("apps.core.state_machine.callbacks.notifications.NotificationService"):
             callback = NotificationCallback()
             # Should not raise with None user
             try:
@@ -1,4 +1,5 @@
"""Tests for transition decorator generation."""

from unittest.mock import Mock

from apps.core.state_machine.decorators import (
@@ -11,9 +12,7 @@ from apps.core.state_machine.decorators import (

def test_generate_transition_decorator():
"""Test basic transition decorator generation."""
decorator = generate_transition_decorator(
source="pending", target="approved", field_name="status"
)
decorator = generate_transition_decorator(source="pending", target="approved", field_name="status")
assert callable(decorator)


@@ -72,9 +71,7 @@ def test_create_transition_method_with_callbacks():
def test_factory_create_approve_method():
"""Test approval method creation."""
factory = TransitionMethodFactory()
method = factory.create_approve_method(
source="pending", target="approved", field_name="status"
)
method = factory.create_approve_method(source="pending", target="approved", field_name="status")
assert callable(method)
assert method.__name__ == "approve"

@@ -82,9 +79,7 @@ def test_factory_create_approve_method():
def test_factory_create_reject_method():
"""Test rejection method creation."""
factory = TransitionMethodFactory()
method = factory.create_reject_method(
source="pending", target="rejected", field_name="status"
)
method = factory.create_reject_method(source="pending", target="rejected", field_name="status")
assert callable(method)
assert method.__name__ == "reject"

@@ -92,9 +87,7 @@ def test_factory_create_reject_method():
def test_factory_create_escalate_method():
"""Test escalation method creation."""
factory = TransitionMethodFactory()
method = factory.create_escalate_method(
source="pending", target="escalated", field_name="status"
)
method = factory.create_escalate_method(source="pending", target="escalated", field_name="status")
assert callable(method)
assert method.__name__ == "escalate"

@@ -145,16 +138,14 @@ def test_with_transition_logging():
def test_method_signature_generation():
"""Test that generated methods have proper signatures."""
factory = TransitionMethodFactory()
method = factory.create_approve_method(
source="pending", target="approved"
)
method = factory.create_approve_method(source="pending", target="approved")

# Check method accepts expected parameters
mock_instance = Mock()
mock_user = Mock()

# Should not raise
try:
try: # noqa: SIM105
method(mock_instance, user=mock_user, comment="test")
except Exception:
# May fail due to django-fsm not being fully configured

@@ -58,28 +58,16 @@ class PermissionGuardTests(TestCase):
def setUp(self):
"""Set up test fixtures."""
self.regular_user = User.objects.create_user(
username='user',
email='user@example.com',
password='testpass123',
role='USER'
username="user", email="user@example.com", password="testpass123", role="USER"
)
self.moderator = User.objects.create_user(
username='moderator',
email='moderator@example.com',
password='testpass123',
role='MODERATOR'
username="moderator", email="moderator@example.com", password="testpass123", role="MODERATOR"
)
self.admin = User.objects.create_user(
username='admin',
email='admin@example.com',
password='testpass123',
role='ADMIN'
username="admin", email="admin@example.com", password="testpass123", role="ADMIN"
)
self.superuser = User.objects.create_user(
username='superuser',
email='superuser@example.com',
password='testpass123',
role='SUPERUSER'
username="superuser", email="superuser@example.com", password="testpass123", role="SUPERUSER"
)
self.instance = MockInstance()

@@ -168,7 +156,7 @@ class PermissionGuardTests(TestCase):

def test_required_roles_explicit_list(self):
"""Test using explicit required_roles list."""
guard = PermissionGuard(required_roles=['ADMIN', 'SUPERUSER'])
guard = PermissionGuard(required_roles=["ADMIN", "SUPERUSER"])

self.assertTrue(guard(self.instance, user=self.admin))
self.assertTrue(guard(self.instance, user=self.superuser))
@@ -177,8 +165,9 @@ class PermissionGuardTests(TestCase):

def test_custom_check_passes(self):
"""Test custom check function that passes."""

def custom_check(instance, user):
return hasattr(instance, 'allow_access') and instance.allow_access
return hasattr(instance, "allow_access") and instance.allow_access

guard = PermissionGuard(custom_check=custom_check)
instance = MockInstance(allow_access=True)
@@ -189,8 +178,9 @@ class PermissionGuardTests(TestCase):

def test_custom_check_fails(self):
"""Test custom check function that fails."""

def custom_check(instance, user):
return hasattr(instance, 'allow_access') and instance.allow_access
return hasattr(instance, "allow_access") and instance.allow_access

guard = PermissionGuard(custom_check=custom_check)
instance = MockInstance(allow_access=False)
@@ -237,28 +227,16 @@ class OwnershipGuardTests(TestCase):
def setUp(self):
"""Set up test fixtures."""
self.owner = User.objects.create_user(
username='owner',
email='owner@example.com',
password='testpass123',
role='USER'
username="owner", email="owner@example.com", password="testpass123", role="USER"
)
self.other_user = User.objects.create_user(
username='other',
email='other@example.com',
password='testpass123',
role='USER'
username="other", email="other@example.com", password="testpass123", role="USER"
)
self.moderator = User.objects.create_user(
username='moderator',
email='moderator@example.com',
password='testpass123',
role='MODERATOR'
username="moderator", email="moderator@example.com", password="testpass123", role="MODERATOR"
)
self.admin = User.objects.create_user(
username='admin',
email='admin@example.com',
password='testpass123',
role='ADMIN'
username="admin", email="admin@example.com", password="testpass123", role="ADMIN"
)

def test_no_user_fails(self):
@@ -329,7 +307,7 @@ class OwnershipGuardTests(TestCase):
def test_custom_owner_fields(self):
"""Test custom owner field names."""
instance = MockInstance(author=self.owner)
guard = OwnershipGuard(owner_fields=['author'])
guard = OwnershipGuard(owner_fields=["author"])

result = guard(instance, user=self.owner)

@@ -357,22 +335,13 @@ class AssignmentGuardTests(TestCase):
def setUp(self):
"""Set up test fixtures."""
self.assigned_user = User.objects.create_user(
username='assigned',
email='assigned@example.com',
password='testpass123',
role='MODERATOR'
username="assigned", email="assigned@example.com", password="testpass123", role="MODERATOR"
)
self.other_user = User.objects.create_user(
username='other',
email='other@example.com',
password='testpass123',
role='MODERATOR'
username="other", email="other@example.com", password="testpass123", role="MODERATOR"
)
self.admin = User.objects.create_user(
username='admin',
email='admin@example.com',
password='testpass123',
role='ADMIN'
username="admin", email="admin@example.com", password="testpass123", role="ADMIN"
)

def test_no_user_fails(self):
@@ -426,7 +395,7 @@ class AssignmentGuardTests(TestCase):
def test_custom_assignment_fields(self):
"""Test custom assignment field names."""
instance = MockInstance(reviewer=self.assigned_user)
guard = AssignmentGuard(assignment_fields=['reviewer'])
guard = AssignmentGuard(assignment_fields=["reviewer"])

result = guard(instance, user=self.assigned_user)

@@ -439,7 +408,7 @@ class AssignmentGuardTests(TestCase):

guard(instance, user=self.assigned_user)

self.assertIn('assigned', guard.get_error_message().lower())
self.assertIn("assigned", guard.get_error_message().lower())


# ============================================================================
@@ -453,16 +422,13 @@ class StateGuardTests(TestCase):
def setUp(self):
"""Set up test fixtures."""
self.user = User.objects.create_user(
username='user',
email='user@example.com',
password='testpass123',
role='USER'
username="user", email="user@example.com", password="testpass123", role="USER"
)

def test_allowed_states_passes(self):
"""Test that guard passes when in allowed state."""
instance = MockInstance(status='PENDING')
guard = StateGuard(allowed_states=['PENDING', 'UNDER_REVIEW'])
instance = MockInstance(status="PENDING")
guard = StateGuard(allowed_states=["PENDING", "UNDER_REVIEW"])

result = guard(instance, user=self.user)

@@ -470,8 +436,8 @@ class StateGuardTests(TestCase):

def test_allowed_states_fails(self):
"""Test that guard fails when not in allowed state."""
instance = MockInstance(status='COMPLETED')
guard = StateGuard(allowed_states=['PENDING', 'UNDER_REVIEW'])
instance = MockInstance(status="COMPLETED")
guard = StateGuard(allowed_states=["PENDING", "UNDER_REVIEW"])

result = guard(instance, user=self.user)

@@ -480,8 +446,8 @@ class StateGuardTests(TestCase):

def test_blocked_states_passes(self):
"""Test that guard passes when not in blocked state."""
instance = MockInstance(status='PENDING')
guard = StateGuard(blocked_states=['COMPLETED', 'CANCELLED'])
instance = MockInstance(status="PENDING")
guard = StateGuard(blocked_states=["COMPLETED", "CANCELLED"])

result = guard(instance, user=self.user)

@@ -489,8 +455,8 @@ class StateGuardTests(TestCase):

def test_blocked_states_fails(self):
"""Test that guard fails when in blocked state."""
instance = MockInstance(status='COMPLETED')
guard = StateGuard(blocked_states=['COMPLETED', 'CANCELLED'])
instance = MockInstance(status="COMPLETED")
guard = StateGuard(blocked_states=["COMPLETED", "CANCELLED"])

result = guard(instance, user=self.user)

@@ -499,8 +465,8 @@ class StateGuardTests(TestCase):

def test_custom_state_field(self):
"""Test using custom state field name."""
instance = MockInstance(workflow_status='ACTIVE')
guard = StateGuard(allowed_states=['ACTIVE'], state_field='workflow_status')
instance = MockInstance(workflow_status="ACTIVE")
guard = StateGuard(allowed_states=["ACTIVE"], state_field="workflow_status")

result = guard(instance, user=self.user)

@@ -508,14 +474,14 @@ class StateGuardTests(TestCase):

def test_error_message_includes_states(self):
"""Test that error message includes allowed states."""
instance = MockInstance(status='COMPLETED')
guard = StateGuard(allowed_states=['PENDING', 'UNDER_REVIEW'])
instance = MockInstance(status="COMPLETED")
guard = StateGuard(allowed_states=["PENDING", "UNDER_REVIEW"])

guard(instance, user=self.user)

message = guard.get_error_message()
self.assertIn('PENDING', message)
self.assertIn('UNDER_REVIEW', message)
self.assertIn("PENDING", message)
self.assertIn("UNDER_REVIEW", message)


# ============================================================================
@@ -529,16 +495,13 @@ class MetadataGuardTests(TestCase):
def setUp(self):
"""Set up test fixtures."""
self.user = User.objects.create_user(
username='user',
email='user@example.com',
password='testpass123',
role='USER'
username="user", email="user@example.com", password="testpass123", role="USER"
)

def test_required_fields_present(self):
"""Test that guard passes when required fields are present."""
instance = MockInstance(resolution_notes='Fixed', assigned_to='user')
guard = MetadataGuard(required_fields=['resolution_notes', 'assigned_to'])
instance = MockInstance(resolution_notes="Fixed", assigned_to="user")
guard = MetadataGuard(required_fields=["resolution_notes", "assigned_to"])

result = guard(instance, user=self.user)

@@ -546,8 +509,8 @@ class MetadataGuardTests(TestCase):

def test_required_field_missing(self):
"""Test that guard fails when required field is missing."""
instance = MockInstance(resolution_notes='Fixed')
guard = MetadataGuard(required_fields=['resolution_notes', 'assigned_to'])
instance = MockInstance(resolution_notes="Fixed")
guard = MetadataGuard(required_fields=["resolution_notes", "assigned_to"])

result = guard(instance, user=self.user)

@@ -557,7 +520,7 @@ class MetadataGuardTests(TestCase):
def test_required_field_none(self):
"""Test that guard fails when required field is None."""
instance = MockInstance(resolution_notes=None)
guard = MetadataGuard(required_fields=['resolution_notes'])
guard = MetadataGuard(required_fields=["resolution_notes"])

result = guard(instance, user=self.user)

@@ -566,8 +529,8 @@ class MetadataGuardTests(TestCase):

def test_empty_string_fails_check_not_empty(self):
"""Test that empty string fails when check_not_empty is True."""
instance = MockInstance(resolution_notes=' ')
guard = MetadataGuard(required_fields=['resolution_notes'], check_not_empty=True)
instance = MockInstance(resolution_notes=" ")
guard = MetadataGuard(required_fields=["resolution_notes"], check_not_empty=True)

result = guard(instance, user=self.user)

@@ -577,7 +540,7 @@ class MetadataGuardTests(TestCase):
def test_empty_list_fails_check_not_empty(self):
"""Test that empty list fails when check_not_empty is True."""
instance = MockInstance(tags=[])
guard = MetadataGuard(required_fields=['tags'], check_not_empty=True)
guard = MetadataGuard(required_fields=["tags"], check_not_empty=True)

result = guard(instance, user=self.user)

@@ -587,7 +550,7 @@ class MetadataGuardTests(TestCase):
def test_empty_dict_fails_check_not_empty(self):
"""Test that empty dict fails when check_not_empty is True."""
instance = MockInstance(metadata={})
guard = MetadataGuard(required_fields=['metadata'], check_not_empty=True)
guard = MetadataGuard(required_fields=["metadata"], check_not_empty=True)

result = guard(instance, user=self.user)

@@ -597,12 +560,12 @@ class MetadataGuardTests(TestCase):
def test_error_message_includes_field_name(self):
"""Test that error message includes the field name."""
instance = MockInstance(resolution_notes=None)
guard = MetadataGuard(required_fields=['resolution_notes'])
guard = MetadataGuard(required_fields=["resolution_notes"])

guard(instance, user=self.user)

message = guard.get_error_message()
self.assertIn('Resolution Notes', message)
self.assertIn("Resolution Notes", message)


# ============================================================================
@@ -616,32 +579,23 @@ class CompositeGuardTests(TestCase):
def setUp(self):
"""Set up test fixtures."""
self.owner = User.objects.create_user(
username='owner',
email='owner@example.com',
password='testpass123',
role='USER'
username="owner", email="owner@example.com", password="testpass123", role="USER"
)
self.moderator = User.objects.create_user(
username='moderator',
email='moderator@example.com',
password='testpass123',
role='MODERATOR'
username="moderator", email="moderator@example.com", password="testpass123", role="MODERATOR"
)
self.non_owner_moderator = User.objects.create_user(
username='non_owner_moderator',
email='non_owner_moderator@example.com',
password='testpass123',
role='MODERATOR'
username="non_owner_moderator",
email="non_owner_moderator@example.com",
password="testpass123",
role="MODERATOR",
)

def test_and_operator_all_pass(self):
"""Test AND operator when all guards pass."""
instance = MockInstance(created_by=self.moderator)
guards = [
PermissionGuard(requires_moderator=True),
OwnershipGuard()
]
composite = CompositeGuard(guards, operator='AND')
guards = [PermissionGuard(requires_moderator=True), OwnershipGuard()]
composite = CompositeGuard(guards, operator="AND")

result = composite(instance, user=self.moderator)

@@ -652,9 +606,9 @@ class CompositeGuardTests(TestCase):
instance = MockInstance(created_by=self.owner)
guards = [
PermissionGuard(requires_moderator=True), # Will pass for moderator
OwnershipGuard() # Will fail - moderator is not owner
OwnershipGuard(), # Will fail - moderator is not owner
]
composite = CompositeGuard(guards, operator='AND')
composite = CompositeGuard(guards, operator="AND")

result = composite(instance, user=self.non_owner_moderator)

@@ -666,9 +620,9 @@ class CompositeGuardTests(TestCase):
instance = MockInstance(created_by=self.owner)
guards = [
PermissionGuard(requires_moderator=True), # Will fail for owner
OwnershipGuard() # Will pass - user is owner
OwnershipGuard(), # Will pass - user is owner
]
composite = CompositeGuard(guards, operator='OR')
composite = CompositeGuard(guards, operator="OR")

result = composite(instance, user=self.owner)

@@ -677,11 +631,8 @@ class CompositeGuardTests(TestCase):
def test_or_operator_all_fail(self):
"""Test OR operator when all guards fail."""
instance = MockInstance(created_by=self.moderator)
guards = [
PermissionGuard(requires_admin=True), # Regular user fails
OwnershipGuard() # Not the owner fails
]
composite = CompositeGuard(guards, operator='OR')
guards = [PermissionGuard(requires_admin=True), OwnershipGuard()] # Regular user fails # Not the owner fails
composite = CompositeGuard(guards, operator="OR")

result = composite(instance, user=self.owner)

@@ -690,19 +641,13 @@ class CompositeGuardTests(TestCase):

def test_nested_composite_guards(self):
"""Test nested composite guards."""
instance = MockInstance(created_by=self.moderator, status='PENDING')
instance = MockInstance(created_by=self.moderator, status="PENDING")

# Inner composite: moderator OR owner
inner = CompositeGuard([
PermissionGuard(requires_moderator=True),
OwnershipGuard()
], operator='OR')
inner = CompositeGuard([PermissionGuard(requires_moderator=True), OwnershipGuard()], operator="OR")

# Outer composite: (moderator OR owner) AND valid state
outer = CompositeGuard([
inner,
StateGuard(allowed_states=['PENDING'])
], operator='AND')
outer = CompositeGuard([inner, StateGuard(allowed_states=["PENDING"])], operator="AND")

result = outer(instance, user=self.moderator)

@@ -713,12 +658,12 @@ class CompositeGuardTests(TestCase):
instance = MockInstance(created_by=self.owner)
perm_guard = PermissionGuard(requires_admin=True)
guards = [perm_guard]
composite = CompositeGuard(guards, operator='AND')
composite = CompositeGuard(guards, operator="AND")

composite(instance, user=self.owner)

message = composite.get_error_message()
self.assertIn('admin', message.lower())
self.assertIn("admin", message.lower())


# ============================================================================
@@ -732,15 +677,12 @@ class GuardFactoryTests(TestCase):
def setUp(self):
"""Set up test fixtures."""
self.moderator = User.objects.create_user(
username='moderator',
email='moderator@example.com',
password='testpass123',
role='MODERATOR'
username="moderator", email="moderator@example.com", password="testpass123", role="MODERATOR"
)

def test_create_permission_guard_moderator(self):
"""Test create_permission_guard with moderator requirement."""
metadata = {'requires_moderator': True}
metadata = {"requires_moderator": True}
guard = create_permission_guard(metadata)
instance = MockInstance()

@@ -750,14 +692,14 @@ class GuardFactoryTests(TestCase):

def test_create_permission_guard_admin(self):
"""Test create_permission_guard with admin requirement."""
metadata = {'requires_admin_approval': True}
metadata = {"requires_admin_approval": True}
guard = create_permission_guard(metadata)

self.assertTrue(guard.requires_admin)

def test_create_permission_guard_escalation_level(self):
"""Test create_permission_guard with escalation level."""
metadata = {'escalation_level': 'admin'}
metadata = {"escalation_level": "admin"}
guard = create_permission_guard(metadata)

self.assertTrue(guard.requires_admin)
@@ -777,9 +719,9 @@ class GuardFactoryTests(TestCase):
def test_create_composite_guard(self):
"""Test create_composite_guard factory."""
guards = [PermissionGuard(), OwnershipGuard()]
composite = create_composite_guard(guards, operator='OR')
composite = create_composite_guard(guards, operator="OR")

self.assertEqual(composite.operator, 'OR')
self.assertEqual(composite.operator, "OR")
self.assertEqual(len(composite.guards), 2)


@@ -793,7 +735,7 @@ class MetadataExtractionTests(TestCase):

def test_extract_moderator_guard(self):
"""Test extracting guard for moderator requirement."""
metadata = {'requires_moderator': True}
metadata = {"requires_moderator": True}
guards = extract_guards_from_metadata(metadata)

self.assertEqual(len(guards), 1)
@@ -801,7 +743,7 @@ class MetadataExtractionTests(TestCase):

def test_extract_admin_guard(self):
"""Test extracting guard for admin requirement."""
metadata = {'requires_admin_approval': True}
metadata = {"requires_admin_approval": True}
guards = extract_guards_from_metadata(metadata)

self.assertEqual(len(guards), 1)
@@ -809,7 +751,7 @@ class MetadataExtractionTests(TestCase):

def test_extract_assignment_guard(self):
"""Test extracting assignment guard."""
metadata = {'requires_assignment': True}
metadata = {"requires_assignment": True}
guards = extract_guards_from_metadata(metadata)

self.assertEqual(len(guards), 1)
@@ -817,17 +759,14 @@ class MetadataExtractionTests(TestCase):

def test_extract_multiple_guards(self):
"""Test extracting multiple guards."""
metadata = {
'requires_moderator': True,
'requires_assignment': True
}
metadata = {"requires_moderator": True, "requires_assignment": True}
guards = extract_guards_from_metadata(metadata)

self.assertEqual(len(guards), 2)

def test_extract_zero_tolerance_guard(self):
"""Test extracting guard for zero tolerance (superuser required)."""
metadata = {'zero_tolerance': True}
metadata = {"zero_tolerance": True}
guards = extract_guards_from_metadata(metadata)

self.assertEqual(len(guards), 1)
@@ -835,7 +774,7 @@ class MetadataExtractionTests(TestCase):

def test_invalid_escalation_level_raises(self):
"""Test that invalid escalation level raises ValueError."""
metadata = {'escalation_level': 'invalid'}
metadata = {"escalation_level": "invalid"}

with self.assertRaises(ValueError):
extract_guards_from_metadata(metadata)
@@ -851,11 +790,7 @@ class MetadataValidationTests(TestCase):

def test_valid_metadata(self):
"""Test that valid metadata passes validation."""
metadata = {
'requires_moderator': True,
'escalation_level': 'admin',
'requires_assignment': False
}
metadata = {"requires_moderator": True, "escalation_level": "admin", "requires_assignment": False}

is_valid, errors = validate_guard_metadata(metadata)

@@ -864,30 +799,30 @@ class MetadataValidationTests(TestCase):

def test_invalid_escalation_level(self):
"""Test that invalid escalation level fails validation."""
metadata = {'escalation_level': 'invalid_level'}
metadata = {"escalation_level": "invalid_level"}

is_valid, errors = validate_guard_metadata(metadata)

self.assertFalse(is_valid)
self.assertTrue(any('escalation_level' in e for e in errors))
self.assertTrue(any("escalation_level" in e for e in errors))

def test_invalid_boolean_field(self):
"""Test that non-boolean value for boolean field fails validation."""
metadata = {'requires_moderator': 'yes'}
metadata = {"requires_moderator": "yes"}

is_valid, errors = validate_guard_metadata(metadata)

self.assertFalse(is_valid)
self.assertTrue(any('requires_moderator' in e for e in errors))
self.assertTrue(any("requires_moderator" in e for e in errors))

def test_required_permissions_not_list(self):
"""Test that non-list required_permissions fails validation."""
metadata = {'required_permissions': 'app.permission'}
metadata = {"required_permissions": "app.permission"}

is_valid, errors = validate_guard_metadata(metadata)

self.assertFalse(is_valid)
self.assertTrue(any('required_permissions' in e for e in errors))
self.assertTrue(any("required_permissions" in e for e in errors))


# ============================================================================
@@ -901,42 +836,30 @@ class RoleHelperTests(TestCase):
def setUp(self):
"""Set up test fixtures."""
self.regular_user = User.objects.create_user(
username='user',
email='user@example.com',
password='testpass123',
role='USER'
username="user", email="user@example.com", password="testpass123", role="USER"
)
self.moderator = User.objects.create_user(
username='moderator',
email='moderator@example.com',
password='testpass123',
role='MODERATOR'
username="moderator", email="moderator@example.com", password="testpass123", role="MODERATOR"
)
self.admin = User.objects.create_user(
username='admin',
email='admin@example.com',
password='testpass123',
role='ADMIN'
username="admin", email="admin@example.com", password="testpass123", role="ADMIN"
)
self.superuser = User.objects.create_user(
username='superuser',
email='superuser@example.com',
password='testpass123',
role='SUPERUSER'
username="superuser", email="superuser@example.com", password="testpass123", role="SUPERUSER"
)

def test_get_user_role(self):
"""Test get_user_role returns correct role."""
self.assertEqual(get_user_role(self.regular_user), 'USER')
self.assertEqual(get_user_role(self.moderator), 'MODERATOR')
self.assertEqual(get_user_role(self.admin), 'ADMIN')
self.assertEqual(get_user_role(self.superuser), 'SUPERUSER')
self.assertEqual(get_user_role(self.regular_user), "USER")
self.assertEqual(get_user_role(self.moderator), "MODERATOR")
self.assertEqual(get_user_role(self.admin), "ADMIN")
self.assertEqual(get_user_role(self.superuser), "SUPERUSER")
self.assertIsNone(get_user_role(None))

def test_has_role(self):
"""Test has_role function."""
self.assertTrue(has_role(self.moderator, ['MODERATOR', 'ADMIN']))
self.assertFalse(has_role(self.regular_user, ['MODERATOR', 'ADMIN']))
self.assertTrue(has_role(self.moderator, ["MODERATOR", "ADMIN"]))
self.assertFalse(has_role(self.regular_user, ["MODERATOR", "ADMIN"]))

def test_is_moderator_or_above(self):
"""Test is_moderator_or_above function."""
@@ -963,7 +886,7 @@ class RoleHelperTests(TestCase):
"""Test that anonymous user has no role."""
anonymous = AnonymousUser()

self.assertFalse(has_role(anonymous, ['USER']))
self.assertFalse(has_role(anonymous, ["USER"]))
self.assertFalse(is_moderator_or_above(anonymous))
self.assertFalse(is_admin_or_above(anonymous))
self.assertFalse(is_superuser_role(anonymous))

@@ -1,4 +1,5 @@
"""Integration tests for state machine model integration."""

from unittest.mock import Mock, patch

import pytest
@@ -74,31 +75,23 @@ def test_generate_transition_methods(sample_choices):
"""Test generating transition methods on model."""
mock_model = type("MockModel", (), {})

generate_transition_methods_for_model(
mock_model, "status", "test_states", "test"
)
generate_transition_methods_for_model(mock_model, "status", "test_states", "test")

# Check that transition methods were added
# Method names may vary based on implementation
assert hasattr(mock_model, "approve") or hasattr(
mock_model, "transition_to_approved"
)
assert hasattr(mock_model, "approve") or hasattr(mock_model, "transition_to_approved")


def test_state_machine_model_decorator(sample_choices):
"""Test state_machine_model decorator."""

@state_machine_model(
field_name="status", choice_group="test_states", domain="test"
)
@state_machine_model(field_name="status", choice_group="test_states", domain="test")
class TestModel:
pass

# Decorator should apply state machine
# Check for transition methods
assert hasattr(TestModel, "approve") or hasattr(
TestModel, "transition_to_approved"
)
assert hasattr(TestModel, "approve") or hasattr(TestModel, "transition_to_approved")


def test_state_machine_mixin_get_available_transitions():

@@ -1,4 +1,5 @@
"""Tests for TransitionRegistry."""

import pytest

from apps.core.choices.base import RichChoice
@@ -55,12 +56,8 @@ def test_transition_info_creation():

def test_transition_info_hashable():
"""Test TransitionInfo is hashable."""
info1 = TransitionInfo(
source="pending", target="approved", method_name="approve"
)
info2 = TransitionInfo(
source="pending", target="approved", method_name="approve"
)
info1 = TransitionInfo(source="pending", target="approved", method_name="approve")
info2 = TransitionInfo(source="pending", target="approved", method_name="approve")
assert hash(info1) == hash(info2)


@@ -82,9 +79,7 @@ def test_register_transition():
metadata={"requires_moderator": True},
)

transition = registry_instance.get_transition(
"test_states", "test", "pending", "approved"
)
transition = registry_instance.get_transition("test_states", "test", "pending", "approved")
assert transition is not None
assert transition.method_name == "approve"
assert transition.requires_moderator is True
@@ -92,9 +87,7 @@ def test_register_transition():

def test_get_transition_not_found():
"""Test getting non-existent transition."""
transition = registry_instance.get_transition(
"nonexistent", "test", "pending", "approved"
)
transition = registry_instance.get_transition("nonexistent", "test", "pending", "approved")
assert transition is None


@@ -102,9 +95,7 @@ def test_get_available_transitions(sample_choices):
"""Test getting available transitions from a state."""
registry_instance.build_registry_from_choices("test_states", "test")

available = registry_instance.get_available_transitions(
"test_states", "test", "pending"
)
available = registry_instance.get_available_transitions("test_states", "test", "pending")
assert len(available) == 2
targets = [t.target for t in available]
assert "approved" in targets
@@ -121,9 +112,7 @@ def test_get_transition_method_name():
method_name="approve",
)

method_name = registry_instance.get_transition_method_name(
"test_states", "test", "pending", "approved"
)
method_name = registry_instance.get_transition_method_name("test_states", "test", "pending", "approved")
assert method_name == "approve"


@@ -137,12 +126,8 @@ def test_validate_transition():
method_name="approve",
)

assert registry_instance.validate_transition(
"test_states", "test", "pending", "approved"
)
assert not registry_instance.validate_transition(
"test_states", "test", "pending", "nonexistent"
)
assert registry_instance.validate_transition("test_states", "test", "pending", "approved")
assert not registry_instance.validate_transition("test_states", "test", "pending", "nonexistent")


def test_build_registry_from_choices(sample_choices):
@@ -150,9 +135,7 @@ def test_build_registry_from_choices(sample_choices):
registry_instance.build_registry_from_choices("test_states", "test")

# Check transitions were registered
transition = registry_instance.get_transition(
"test_states", "test", "pending", "approved"
)
transition = registry_instance.get_transition("test_states", "test", "pending", "approved")
assert transition is not None


@@ -168,9 +151,7 @@ def test_clear_registry_specific():

registry_instance.clear_registry(choice_group="test_states", domain="test")

transition = registry_instance.get_transition(
"test_states", "test", "pending", "approved"
)
transition = registry_instance.get_transition("test_states", "test", "pending", "approved")
assert transition is None


@@ -186,9 +167,7 @@ def test_clear_registry_all():

registry_instance.clear_registry()

transition = registry_instance.get_transition(
"test_states", "test", "pending", "approved"
)
transition = registry_instance.get_transition("test_states", "test", "pending", "approved")
assert transition is None


@@ -196,9 +175,7 @@ def test_export_transition_graph_dict(sample_choices):
"""Test exporting transition graph as dict."""
registry_instance.build_registry_from_choices("test_states", "test")

graph = registry_instance.export_transition_graph(
"test_states", "test", format="dict"
)
graph = registry_instance.export_transition_graph("test_states", "test", format="dict")
assert isinstance(graph, dict)
assert "pending" in graph
assert set(graph["pending"]) == {"approved", "rejected"}
@@ -208,9 +185,7 @@ def test_export_transition_graph_mermaid(sample_choices):
"""Test exporting transition graph as mermaid."""
registry_instance.build_registry_from_choices("test_states", "test")

graph = registry_instance.export_transition_graph(
"test_states", "test", format="mermaid"
)
graph = registry_instance.export_transition_graph("test_states", "test", format="mermaid")
assert isinstance(graph, str)
assert "stateDiagram-v2" in graph
assert "pending" in graph
@@ -220,9 +195,7 @@ def test_export_transition_graph_dot(sample_choices):
"""Test exporting transition graph as DOT."""
registry_instance.build_registry_from_choices("test_states", "test")

graph = registry_instance.export_transition_graph(
"test_states", "test", format="dot"
)
graph = registry_instance.export_transition_graph("test_states", "test", format="dot")
assert isinstance(graph, str)
assert "digraph" in graph
assert "pending" in graph
@@ -233,9 +206,7 @@ def test_export_invalid_format(sample_choices):
registry_instance.build_registry_from_choices("test_states", "test")

with pytest.raises(ValueError):
registry_instance.export_transition_graph(
"test_states", "test", format="invalid"
)
registry_instance.export_transition_graph("test_states", "test", format="invalid")


def test_get_all_registered_groups():

@@ -1,4 +1,5 @@
"""Tests for metadata validators."""

import pytest

from apps.core.choices.base import RichChoice
@@ -70,9 +71,7 @@ def terminal_with_transitions():

def test_validation_error_creation():
"""Test ValidationError creation."""
error = ValidationError(
code="TEST_ERROR", message="Test message", state="pending"
)
error = ValidationError(code="TEST_ERROR", message="Test message", state="pending")
assert error.code == "TEST_ERROR"
assert error.message == "Test message"
assert error.state == "pending"
@@ -81,9 +80,7 @@ def test_validation_error_creation():

def test_validation_warning_creation():
"""Test ValidationWarning creation."""
warning = ValidationWarning(
code="TEST_WARNING", message="Test warning", state="pending"
)
warning = ValidationWarning(code="TEST_WARNING", message="Test warning", state="pending")
assert warning.code == "TEST_WARNING"
assert warning.message == "Test warning"

@@ -166,15 +163,9 @@ def test_validate_no_cycles(valid_choices):
def test_validate_no_cycles_with_cycle():
"""Test cycle detection finds cycles."""
choices = [
RichChoice(
value="a", label="A", metadata={"can_transition_to": ["b"]}
),
RichChoice(
value="b", label="B", metadata={"can_transition_to": ["c"]}
),
RichChoice(
value="c", label="C", metadata={"can_transition_to": ["a"]}
),
RichChoice(value="a", label="A", metadata={"can_transition_to": ["b"]}),
RichChoice(value="b", label="B", metadata={"can_transition_to": ["c"]}),
RichChoice(value="c", label="C", metadata={"can_transition_to": ["a"]}),
]
registry.register("cycle_states", choices, domain="test")

@@ -202,9 +193,7 @@ def test_validate_reachability_unreachable():
label="Pending",
metadata={"can_transition_to": ["approved"]},
),
RichChoice(
value="approved", label="Approved", metadata={"is_final": True}
),
RichChoice(value="approved", label="Approved", metadata={"is_final": True}),
RichChoice(
value="orphan",
label="Orphan",

@@ -1,4 +1,5 @@
"""Metadata validators for ensuring RichChoice metadata meets FSM requirements."""

from dataclasses import dataclass, field
from typing import Any

@@ -110,8 +111,7 @@ class MetadataValidator:
ValidationError(
code="MISSING_CAN_TRANSITION_TO",
message=(
"State metadata must explicitly define "
"'can_transition_to' (use [] for terminal states)"
"State metadata must explicitly define " "'can_transition_to' (use [] for terminal states)"
),
state=state,
)
@@ -138,9 +138,7 @@ class MetadataValidator:
errors.append(
ValidationError(
code="INVALID_TRANSITION_TARGET",
message=(
f"Transition target '{target}' does not exist"
),
message=(f"Transition target '{target}' does not exist"),
state=state,
)
)
@@ -188,17 +186,11 @@ class MetadataValidator:
perms = self.builder.extract_permission_requirements(state)

# Check for contradictory permissions
if (
perms.get("requires_admin_approval")
and not perms.get("requires_moderator")
):
if perms.get("requires_admin_approval") and not perms.get("requires_moderator"):
errors.append(
ValidationError(
code="PERMISSION_INCONSISTENCY",
message=(
"State requires admin approval but not moderator "
"(admin should imply moderator)"
),
message=("State requires admin approval but not moderator " "(admin should imply moderator)"),
state=state,
)
)
@@ -251,9 +243,7 @@ class MetadataValidator:
errors.append(
ValidationError(
code="STATE_CYCLE_DETECTED",
message=(
f"Cycle detected: {' -> '.join(cycle)}"
),
message=(f"Cycle detected: {' -> '.join(cycle)}"),
state=cycle[0],
)
)
@@ -278,9 +268,7 @@ class MetadataValidator:
for target in targets:
incoming[target].append(source)

initial_states = [
state for state in all_states if not incoming[state]
]
initial_states = [state for state in all_states if not incoming[state]]

if not initial_states:
errors.append(
@@ -327,9 +315,7 @@ class MetadataValidator:
result = self.validate_choice_group()

lines = []
lines.append(
f"Validation Report for {self.domain}.{self.choice_group}"
)
lines.append(f"Validation Report for {self.domain}.{self.choice_group}")
lines.append("=" * 60)
lines.append(f"Status: {'VALID' if result.is_valid else 'INVALID'}")
lines.append(f"Errors: {len(result.errors)}")
@@ -372,10 +358,7 @@ def validate_on_registration(choice_group: str, domain: str = "core") -> bool:

if not result.is_valid:
error_messages = [str(e) for e in result.errors]
raise ValueError(
f"Validation failed for {domain}.{choice_group}:\n"
+ "\n".join(error_messages)
)
raise ValueError(f"Validation failed for {domain}.{choice_group}:\n" + "\n".join(error_messages))

return True


@@ -23,9 +23,7 @@ logger = logging.getLogger(__name__)


@shared_task(bind=True, max_retries=3, default_retry_delay=60)
def calculate_trending_content(
self, content_type: str = "all", limit: int = 50
) -> dict[str, Any]:
def calculate_trending_content(self, content_type: str = "all", limit: int = 50) -> dict[str, Any]:
"""
Calculate trending content using real analytics data.

@@ -72,17 +70,13 @@ def calculate_trending_content(
trending_items = trending_items[:limit]

# Format results for API consumption
formatted_results = _format_trending_results(
trending_items, current_period_hours, previous_period_hours
)
formatted_results = _format_trending_results(trending_items, current_period_hours, previous_period_hours)

# Cache results
cache_key = f"trending:calculated:{content_type}:{limit}"
cache.set(cache_key, formatted_results, 3600) # Cache for 1 hour

logger.info(
f"Calculated {len(formatted_results)} trending items for {content_type}"
)
logger.info(f"Calculated {len(formatted_results)} trending items for {content_type}")

return {
"success": True,
@@ -95,13 +89,11 @@ def calculate_trending_content(
except Exception as e:
logger.error(f"Error calculating trending content: {e}", exc_info=True)
# Retry the task
raise self.retry(exc=e)
raise self.retry(exc=e) from None


@shared_task(bind=True, max_retries=3, default_retry_delay=30)
def calculate_new_content(
self, content_type: str = "all", days_back: int = 30, limit: int = 50
) -> dict[str, Any]:
def calculate_new_content(self, content_type: str = "all", days_back: int = 30, limit: int = 50) -> dict[str, Any]:
"""
Calculate new content based on opening dates and creation dates.

@@ -120,15 +112,11 @@ def calculate_new_content(
new_items = []

if content_type in ["all", "parks"]:
parks = _get_new_parks(
cutoff_date, limit if content_type == "parks" else limit * 2
)
parks = _get_new_parks(cutoff_date, limit if content_type == "parks" else limit * 2)
new_items.extend(parks)

if content_type in ["all", "rides"]:
rides = _get_new_rides(
cutoff_date, limit if content_type == "rides" else limit * 2
)
rides = _get_new_rides(cutoff_date, limit if content_type == "rides" else limit * 2)
new_items.extend(rides)

# Sort by date added (most recent first) and apply limit
@@ -154,7 +142,7 @@ def calculate_new_content(

except Exception as e:
logger.error(f"Error calculating new content: {e}", exc_info=True)
raise self.retry(exc=e)
raise self.retry(exc=e) from None


@shared_task(bind=True)
@@ -185,9 +173,7 @@ def warm_trending_cache(self) -> dict[str, Any]:
calculate_new_content.delay(**query)

results[f"trending_{query['content_type']}_{query['limit']}"] = "scheduled"
results[f"new_content_{query['content_type']}_{query['limit']}"] = (
"scheduled"
)
results[f"new_content_{query['content_type']}_{query['limit']}"] = "scheduled"

logger.info("Trending cache warming completed")

@@ -211,17 +197,13 @@ def _calculate_trending_parks(
current_period_hours: int, previous_period_hours: int, limit: int
) -> list[dict[str, Any]]:
"""Calculate trending scores for parks using real data."""
parks = Park.objects.filter(status="OPERATING").select_related(
"location", "operator"
)
parks = Park.objects.filter(status="OPERATING").select_related("location", "operator")

trending_parks = []

for park in parks:
try:
score = _calculate_content_score(
park, "park", current_period_hours, previous_period_hours
)
score = _calculate_content_score(park, "park", current_period_hours, previous_period_hours)
if score > 0: # Only include items with positive trending scores
trending_parks.append(
{
@@ -231,13 +213,9 @@ def _calculate_trending_parks(
"id": park.id,
"name": park.name,
"slug": park.slug,
"location": (
park.formatted_location if hasattr(park, "location") else ""
),
"location": (park.formatted_location if hasattr(park, "location") else ""),
"category": "park",
"rating": (
float(park.average_rating) if park.average_rating else 0.0
),
"rating": (float(park.average_rating) if park.average_rating else 0.0),
}
)
except Exception as e:
@@ -250,17 +228,13 @@ def _calculate_trending_rides(
current_period_hours: int, previous_period_hours: int, limit: int
) -> list[dict[str, Any]]:
"""Calculate trending scores for rides using real data."""
rides = Ride.objects.filter(status="OPERATING").select_related(
"park", "park__location"
)
rides = Ride.objects.filter(status="OPERATING").select_related("park", "park__location")

trending_rides = []

for ride in rides:
try:
score = _calculate_content_score(
ride, "ride", current_period_hours, previous_period_hours
)
score = _calculate_content_score(ride, "ride", current_period_hours, previous_period_hours)
if score > 0: # Only include items with positive trending scores
# Get location from park
location = ""
@@ -277,9 +251,7 @@ def _calculate_trending_rides(
"slug": ride.slug,
"location": location,
"category": "ride",
"rating": (
float(ride.average_rating) if ride.average_rating else 0.0
),
"rating": (float(ride.average_rating) if ride.average_rating else 0.0),
}
)
except Exception as e:
@@ -322,17 +294,10 @@ def _calculate_content_score(
recency_score = _calculate_recency_score(content_obj)

# 4. Popularity Score (10% weight)
popularity_score = _calculate_popularity_score(
ct, content_obj.id, current_period_hours
)
popularity_score = _calculate_popularity_score(ct, content_obj.id, current_period_hours)

# Calculate weighted final score
final_score = (
view_growth_score * 0.4
+ rating_score * 0.3
+ recency_score * 0.2
+ popularity_score * 0.1
)
final_score = view_growth_score * 0.4 + rating_score * 0.3 + recency_score * 0.2 + popularity_score * 0.1

logger.debug(
f"{content_type} {content_obj.id}: "
@@ -344,9 +309,7 @@ def _calculate_content_score(
return final_score

except Exception as e:
logger.error(
f"Error calculating score for {content_type} {content_obj.id}: {e}"
)
logger.error(f"Error calculating score for {content_type} {content_obj.id}: {e}")
return 0.0


@@ -371,9 +334,7 @@ def _calculate_view_growth_score(

# Normalize growth percentage to 0-1 scale
# 100% growth = 0.5, 500% growth = 1.0
normalized_growth = (
min(growth_percentage / 500.0, 1.0) if growth_percentage > 0 else 0.0
)
normalized_growth = min(growth_percentage / 500.0, 1.0) if growth_percentage > 0 else 0.0
return max(normalized_growth, 0.0)

except Exception as e:
@@ -431,14 +392,10 @@ def _calculate_recency_score(content_obj: Any) -> float:
return 0.5


def _calculate_popularity_score(
content_type: ContentType, object_id: int, hours: int
) -> float:
def _calculate_popularity_score(content_type: ContentType, object_id: int, hours: int) -> float:
"""Calculate popularity score based on total view count."""
try:
total_views = PageView.get_total_views_count(
content_type, object_id, hours=hours
)
total_views = PageView.get_total_views_count(content_type, object_id, hours=hours)

# Normalize views to 0-1 scale
# 0 views = 0.0, 100 views = 0.5, 1000+ views = 1.0
@@ -505,9 +462,7 @@ def _get_new_rides(cutoff_date: datetime, limit: int) -> list[dict[str, Any]]:

results = []
for ride in new_rides:
date_added = getattr(ride, "opening_date", None) or getattr(
ride, "created_at", None
)
date_added = getattr(ride, "opening_date", None) or getattr(ride, "created_at", None)
if date_added and isinstance(date_added, datetime):
date_added = date_added.date()

@@ -545,13 +500,11 @@ def _format_trending_results(
# Get view change for display
content_obj = item["content_object"]
ct = ContentType.objects.get_for_model(content_obj)
current_views, previous_views, growth_percentage = (
PageView.get_views_growth(
ct,
content_obj.id,
current_period_hours,
previous_period_hours,
)
current_views, previous_views, growth_percentage = PageView.get_views_growth(
ct,
content_obj.id,
current_period_hours,
previous_period_hours,
)

# Format exactly as frontend expects
@@ -564,9 +517,7 @@ def _format_trending_results(
"rank": rank,
"views": current_views,
"views_change": (
f"+{growth_percentage:.1f}%"
if growth_percentage > 0
else f"{growth_percentage:.1f}%"
f"+{growth_percentage:.1f}%" if growth_percentage > 0 else f"{growth_percentage:.1f}%"
),
"slug": item["slug"],
}

@@ -26,6 +26,7 @@ register = template.Library()
# Time and Date Filters
# =============================================================================


@register.filter
def humanize_timedelta(value):
"""
@@ -42,26 +43,26 @@ def humanize_timedelta(value):
Human-readable string like "2 hours ago"
"""
if value is None:
return ''
return ""

# Convert datetime to timedelta from now
if hasattr(value, 'tzinfo'): # It's a datetime
if hasattr(value, "tzinfo"): # It's a datetime
now = timezone.now()
if value > now:
return 'in the future'
return "in the future"
value = now - value

# Convert seconds to timedelta
if isinstance(value, (int, float)):
if isinstance(value, int | float):
value = timedelta(seconds=value)

if not isinstance(value, timedelta):
return ''
return ""

seconds = int(value.total_seconds())

if seconds < 60:
return 'just now'
return "just now"
elif seconds < 3600:
minutes = seconds // 60
return f'{minutes} minute{"s" if minutes != 1 else ""} ago'
@@ -92,22 +93,23 @@ def time_until(value):
Output: "in 2 days", "in 3 hours"
"""
if value is None:
return ''
return ""

if hasattr(value, 'tzinfo'):
if hasattr(value, "tzinfo"):
now = timezone.now()
if value <= now:
return 'now'
return "now"
diff = value - now
return humanize_timedelta(diff).replace(' ago', '')
return humanize_timedelta(diff).replace(" ago", "")

return ''
return ""


# =============================================================================
# Text Manipulation Filters
# =============================================================================


@register.filter
@stringfilter
def truncate_smart(value, max_length=50):
@@ -130,12 +132,12 @@ def truncate_smart(value, max_length=50):

# Find the last space before max_length
truncated = value[:max_length]
last_space = truncated.rfind(' ')
last_space = truncated.rfind(" ")

if last_space > max_length * 0.5: # Only use word boundary if reasonable
truncated = truncated[:last_space]

return truncated.rstrip('.,!?;:') + '...'
return truncated.rstrip(".,!?;:") + "..."


@register.filter
@@ -153,7 +155,7 @@ def truncate_middle(value, max_length=50):
return value

keep_chars = (max_length - 3) // 2
return f'{value[:keep_chars]}...{value[-keep_chars:]}'
return f"{value[:keep_chars]}...{value[-keep_chars:]}"


@register.filter
@@ -167,13 +169,14 @@ def initials(value, max_initials=2):
Output: "JD" for "John Doe"
"""
words = value.split()
return ''.join(word[0].upper() for word in words[:max_initials] if word)
return "".join(word[0].upper() for word in words[:max_initials] if word)


# =============================================================================
|
||||
# Number Formatting Filters
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@register.filter
|
||||
def format_number(value, decimals=0):
|
||||
"""
|
||||
@@ -187,14 +190,14 @@ def format_number(value, decimals=0):
|
||||
Output: "1,234.56"
|
||||
"""
|
||||
if value is None:
|
||||
return ''
|
||||
return ""
|
||||
|
||||
try:
|
||||
value = float(value)
|
||||
decimals = int(decimals)
|
||||
if decimals > 0:
|
||||
return f'{value:,.{decimals}f}'
|
||||
return f'{int(value):,}'
|
||||
return f"{value:,.{decimals}f}"
|
||||
return f"{int(value):,}"
|
||||
except (ValueError, TypeError):
|
||||
return value
|
||||
|
||||
@@ -209,16 +212,16 @@ def format_compact(value):
|
||||
Output: "1.2K", "3.4M", "2.1B"
|
||||
"""
|
||||
if value is None:
|
||||
return ''
|
||||
return ""
|
||||
|
||||
try:
|
||||
value = float(value)
|
||||
if value >= 1_000_000_000:
|
||||
return f'{value / 1_000_000_000:.1f}B'
|
||||
return f"{value / 1_000_000_000:.1f}B"
|
||||
elif value >= 1_000_000:
|
||||
return f'{value / 1_000_000:.1f}M'
|
||||
return f"{value / 1_000_000:.1f}M"
|
||||
elif value >= 1_000:
|
||||
return f'{value / 1_000:.1f}K'
|
||||
return f"{value / 1_000:.1f}K"
|
||||
return str(int(value))
|
||||
except (ValueError, TypeError):
|
||||
return value
|
||||
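
A standalone sketch of the format_compact thresholds above (same cutoffs, simplified error handling):

def compact(value):
    # Same magnitude cutoffs as format_compact above.
    value = float(value)
    if value >= 1_000_000_000:
        return f"{value / 1_000_000_000:.1f}B"
    if value >= 1_000_000:
        return f"{value / 1_000_000:.1f}M"
    if value >= 1_000:
        return f"{value / 1_000:.1f}K"
    return str(int(value))

print(compact(1234), compact(3_400_000), compact(2_100_000_000))  # 1.2K 3.4M 2.1B
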
@@ -237,16 +240,17 @@ def percentage(value, total):
        value = float(value)
        total = float(total)
        if total == 0:
            return '0%'
        return f'{(value / total * 100):.0f}%'
            return "0%"
        return f"{(value / total * 100):.0f}%"
    except (ValueError, TypeError, ZeroDivisionError):
        return '0%'
        return "0%"


# =============================================================================
# Dictionary/List Filters
# =============================================================================


@register.filter
def get_item(dictionary, key):
    """
@@ -278,7 +282,7 @@ def getlist(querydict, key):
    """
    if querydict is None:
        return []
    if hasattr(querydict, 'getlist'):
    if hasattr(querydict, "getlist"):
        return querydict.getlist(key)
    return []

@@ -314,6 +318,7 @@ def index(sequence, i):
# Pluralization Filters
# =============================================================================


@register.filter
def pluralize_custom(count, forms):
    """
@@ -330,7 +335,7 @@ def pluralize_custom(count, forms):
    """
    try:
        count = int(count)
        singular, plural = forms.split(',')
        singular, plural = forms.split(",")
        return singular if count == 1 else plural
    except (ValueError, AttributeError):
        return forms
@@ -347,9 +352,9 @@ def count_with_label(count, forms):
    """
    try:
        count = int(count)
        singular, plural = forms.split(',')
        singular, plural = forms.split(",")
        label = singular if count == 1 else plural
        return f'{count} {label}'
        return f"{count} {label}"
    except (ValueError, AttributeError):
        return str(count)

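
A standalone sketch of the "singular,plural" convention both filters above parse with forms.split(","):

def count_with_label_sketch(count, forms):
    # Same parsing convention as pluralize_custom and count_with_label above.
    singular, plural = forms.split(",")
    return f"{count} {singular if count == 1 else plural}"

print(count_with_label_sketch(1, "ride,rides"))  # 1 ride
print(count_with_label_sketch(3, "ride,rides"))  # 3 rides
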
@@ -358,6 +363,7 @@ def count_with_label(count, forms):
# CSS Class Manipulation
# =============================================================================


@register.filter
def add_class(field, css_class):
    """
@@ -366,10 +372,10 @@ def add_class(field, css_class):
    Usage:
        {{ form.email|add_class:"form-control" }}
    """
    if hasattr(field, 'as_widget'):
        existing = field.field.widget.attrs.get('class', '')
        new_classes = f'{existing} {css_class}'.strip()
        return field.as_widget(attrs={'class': new_classes})
    if hasattr(field, "as_widget"):
        existing = field.field.widget.attrs.get("class", "")
        new_classes = f"{existing} {css_class}".strip()
        return field.as_widget(attrs={"class": new_classes})
    return field


@@ -381,8 +387,8 @@ def set_attr(field, attr_value):
    Usage:
        {{ form.email|set_attr:"placeholder:Enter email" }}
    """
    if hasattr(field, 'as_widget'):
        attr, value = attr_value.split(':')
    if hasattr(field, "as_widget"):
        attr, value = attr_value.split(":")
        return field.as_widget(attrs={attr: value})
    return field

@@ -391,6 +397,7 @@ def set_attr(field, attr_value):
# Conditional Filters
# =============================================================================


@register.filter
def default_if_none(value, default):
    """
@@ -413,5 +420,5 @@ def yesno_icon(value, icons="fa-check,fa-times"):

        {{ has_feature|yesno_icon:"fa-star,fa-star-o" }}
    """
    true_icon, false_icon = icons.split(',')
    true_icon, false_icon = icons.split(",")
    return true_icon if value else false_icon

@@ -23,6 +23,7 @@ Usage:
    {# Render a transition button #}
    {% transition_button submission 'approve' request.user %}
"""

from typing import Any

from django import template
@@ -53,12 +54,12 @@ def get_state_value(obj) -> str | None:
    Returns:
        The current state value or None
    """
    if hasattr(obj, 'get_state_value'):
    if hasattr(obj, "get_state_value"):
        return obj.get_state_value()
    if hasattr(obj, 'state_field_name'):
    if hasattr(obj, "state_field_name"):
        return getattr(obj, obj.state_field_name, None)
    # Try common field names
    for field in ['status', 'state']:
    for field in ["status", "state"]:
        if hasattr(obj, field):
            return getattr(obj, field, None)
    return None
@@ -78,19 +79,19 @@ def get_state_display(obj) -> str:
    Returns:
        The human-readable state display value
    """
    if hasattr(obj, 'get_state_display_value'):
    if hasattr(obj, "get_state_display_value"):
        return obj.get_state_display_value()
    if hasattr(obj, 'state_field_name'):
    if hasattr(obj, "state_field_name"):
        field_name = obj.state_field_name
        getter = getattr(obj, f'get_{field_name}_display', None)
        getter = getattr(obj, f"get_{field_name}_display", None)
        if callable(getter):
            return getter()
    # Try common field names
    for field in ['status', 'state']:
        getter = getattr(obj, f'get_{field}_display', None)
    for field in ["status", "state"]:
        getter = getattr(obj, f"get_{field}_display", None)
        if callable(getter):
            return getter()
    return str(get_state_value(obj) or '')
    return str(get_state_value(obj) or "")


@register.filter
@@ -109,7 +110,7 @@ def get_state_choice(obj):
    Returns:
        The RichChoice object or None
    """
    if hasattr(obj, 'get_state_choice'):
    if hasattr(obj, "get_state_choice"):
        return obj.get_state_choice()
    return None

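
A standalone sketch of the duck-typing fallbacks in get_state_value above; both classes are hypothetical stand-ins for FSM-enabled models:

class WithHelper:
    def get_state_value(self):
        return "approved"

class WithStatusField:
    status = "pending"

def state_of(obj):
    # Same lookup order as get_state_value above (state_field_name omitted).
    if hasattr(obj, "get_state_value"):
        return obj.get_state_value()
    for field in ["status", "state"]:
        if hasattr(obj, field):
            return getattr(obj, field, None)
    return None

print(state_of(WithHelper()), state_of(WithStatusField()))  # approved pending
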
@@ -205,18 +206,24 @@ def get_available_transitions(obj, user) -> list[dict[str, Any]]:
    # Get list of available transitions
    available_transition_names = []

    if hasattr(obj, 'get_available_user_transitions'):
    if hasattr(obj, "get_available_user_transitions"):
        # Use the helper method if available
        return obj.get_available_user_transitions(user)

    if hasattr(obj, 'get_available_transitions'):
    if hasattr(obj, "get_available_transitions"):
        available_transition_names = list(obj.get_available_transitions())
    else:
        # Fallback: look for transition methods by convention
        for attr_name in dir(obj):
            if attr_name.startswith('transition_to_') or attr_name in ['approve', 'reject', 'escalate', 'complete', 'cancel']:
            if attr_name.startswith("transition_to_") or attr_name in [
                "approve",
                "reject",
                "escalate",
                "complete",
                "cancel",
            ]:
                method = getattr(obj, attr_name, None)
                if callable(method) and hasattr(method, '_django_fsm'):
                if callable(method) and hasattr(method, "_django_fsm"):
                    available_transition_names.append(attr_name)

    # Filter transitions by user permission
@@ -226,14 +233,16 @@ def get_available_transitions(obj, user) -> list[dict[str, Any]]:
        try:
            if can_proceed(method, user):
                metadata = get_transition_metadata(transition_name)
                transitions.append({
                    'name': transition_name,
                    'label': _format_transition_label(transition_name),
                    'icon': metadata.get('icon', 'arrow-right'),
                    'style': metadata.get('style', 'gray'),
                    'requires_confirm': metadata.get('requires_confirm', False),
                    'confirm_message': metadata.get('confirm_message', 'Are you sure?'),
                })
                transitions.append(
                    {
                        "name": transition_name,
                        "label": _format_transition_label(transition_name),
                        "icon": metadata.get("icon", "arrow-right"),
                        "style": metadata.get("style", "gray"),
                        "requires_confirm": metadata.get("requires_confirm", False),
                        "confirm_message": metadata.get("confirm_message", "Are you sure?"),
                    }
                )
        except Exception:
            # Skip transitions that raise errors during can_proceed check
            pass
@@ -289,14 +298,17 @@ def get_transition_url(obj, transition_name: str) -> str:
        The URL string for the transition endpoint
    """
    try:
        return reverse('core:fsm_transition', kwargs={
            'app_label': obj._meta.app_label,
            'model_name': obj._meta.model_name,
            'pk': obj.pk,
            'transition_name': transition_name,
        })
        return reverse(
            "core:fsm_transition",
            kwargs={
                "app_label": obj._meta.app_label,
                "model_name": obj._meta.model_name,
                "pk": obj.pk,
                "transition_name": transition_name,
            },
        )
    except NoReverseMatch:
        return ''
        return ""


# =============================================================================
@@ -304,7 +316,7 @@ def get_transition_url(obj, transition_name: str) -> str:
# =============================================================================


@register.inclusion_tag('htmx/state_actions.html', takes_context=True)
@register.inclusion_tag("htmx/state_actions.html", takes_context=True)
def render_state_actions(context, obj, user=None, **kwargs):
    """
    Render the state action buttons for an FSM-enabled object.
@@ -323,17 +335,17 @@ def render_state_actions(context, obj, user=None, **kwargs):
        Context for the state_actions.html template
    """
    if user is None:
        user = context.get('request', {}).user if 'request' in context else None
        user = context.get("request", {}).user if "request" in context else None

    return {
        'object': obj,
        'user': user,
        'request': context.get('request'),
        "object": obj,
        "user": user,
        "request": context.get("request"),
        **kwargs,
    }


@register.inclusion_tag('htmx/status_with_actions.html', takes_context=True)
@register.inclusion_tag("htmx/status_with_actions.html", takes_context=True)
def render_status_with_actions(context, obj, user=None, **kwargs):
    """
    Render the status badge with action buttons for an FSM-enabled object.
@@ -352,12 +364,12 @@ def render_status_with_actions(context, obj, user=None, **kwargs):
        Context for the status_with_actions.html template
    """
    if user is None:
        user = context.get('request', {}).user if 'request' in context else None
        user = context.get("request", {}).user if "request" in context else None

    return {
        'object': obj,
        'user': user,
        'request': context.get('request'),
        "object": obj,
        "user": user,
        "request": context.get("request"),
        **kwargs,
    }

@@ -384,28 +396,28 @@ def _format_transition_label(transition_name: str) -> str:
    """
    # Remove common prefixes
    label = transition_name
    for prefix in ['transition_to_', 'transition_', 'do_']:
    for prefix in ["transition_to_", "transition_", "do_"]:
        if label.startswith(prefix):
            label = label[len(prefix):]
            label = label[len(prefix) :]
            break

    # Remove past tense suffix and capitalize
    # e.g., 'approved' -> 'Approve'
    if label.endswith('ed') and len(label) > 3:
    if label.endswith("ed") and len(label) > 3:
        # Handle special cases
        if label.endswith('ied'):
            label = label[:-3] + 'y'
        if label.endswith("ied"):
            label = label[:-3] + "y"
        elif label[-3] == label[-4]:  # doubled consonant (e.g., 'submitted')
            label = label[:-3]
        else:
            label = label[:-1]  # Remove 'd'
            if label.endswith('e'):
            if label.endswith("e"):
                pass  # Keep the 'e' for words like 'approve'
            else:
                label = label[:-1]  # Remove 'e' for words like 'rejected' -> 'reject'

    # Replace underscores with spaces and title case
    label = label.replace('_', ' ').title()
    label = label.replace("_", " ").title()

    return label

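
A standalone sketch of the prefix-stripping and title-casing in _format_transition_label above; the past-tense suffix handling is omitted here for brevity:

def format_label_sketch(name):
    # Prefix stripping and title-casing only; no past-tense suffix handling.
    for prefix in ["transition_to_", "transition_", "do_"]:
        if name.startswith(prefix):
            name = name[len(prefix):]
            break
    return name.replace("_", " ").title()

print(format_label_sketch("transition_to_under_review"))  # Under Review
print(format_label_sketch("do_escalate"))                 # Escalate
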
@@ -418,17 +430,17 @@ def _format_transition_label(transition_name: str) -> str:
# Ensure all tags and filters are registered
__all__ = [
    # Filters
    'get_state_value',
    'get_state_display',
    'get_state_choice',
    'app_label',
    'model_name',
    'default_target_id',
    "get_state_value",
    "get_state_display",
    "get_state_choice",
    "app_label",
    "model_name",
    "default_target_id",
    # Tags
    'get_available_transitions',
    'can_transition',
    'get_transition_url',
    "get_available_transitions",
    "can_transition",
    "get_transition_url",
    # Inclusion tags
    'render_state_actions',
    'render_status_with_actions',
    "render_state_actions",
    "render_status_with_actions",
]

@@ -57,7 +57,8 @@ register = template.Library()
# HTML Sanitization Filters
# =============================================================================

@register.filter(name='sanitize', is_safe=True)

@register.filter(name="sanitize", is_safe=True)
def sanitize_filter(value):
    """
    Sanitize HTML content to prevent XSS attacks.
@@ -68,11 +69,11 @@ def sanitize_filter(value):
        {{ user_content|sanitize }}
    """
    if not value:
        return ''
        return ""
    return mark_safe(sanitize_html(str(value)))


@register.filter(name='sanitize_minimal', is_safe=True)
@register.filter(name="sanitize_minimal", is_safe=True)
def sanitize_minimal_filter(value):
    """
    Sanitize HTML with minimal allowed tags.
@@ -83,11 +84,11 @@ def sanitize_minimal_filter(value):
        {{ comment|sanitize_minimal }}
    """
    if not value:
        return ''
        return ""
    return mark_safe(_sanitize_minimal(str(value)))


@register.filter(name='sanitize_svg', is_safe=True)
@register.filter(name="sanitize_svg", is_safe=True)
def sanitize_svg_filter(value):
    """
    Sanitize SVG content for safe inline rendering.
@@ -96,11 +97,11 @@ def sanitize_svg_filter(value):
        {{ icon_svg|sanitize_svg }}
    """
    if not value:
        return ''
        return ""
    return mark_safe(sanitize_svg(str(value)))


@register.filter(name='strip_html')
@register.filter(name="strip_html")
def strip_html_filter(value):
    """
    Remove all HTML tags from content.
@@ -109,7 +110,7 @@ def strip_html_filter(value):
        {{ html_content|strip_html }}
    """
    if not value:
        return ''
        return ""
    return _strip_html(str(value))


@@ -117,7 +118,8 @@ def strip_html_filter(value):
# JavaScript/JSON Context Filters
# =============================================================================

@register.filter(name='json_safe', is_safe=True)

@register.filter(name="json_safe", is_safe=True)
def json_safe_filter(value):
    """
    Safely serialize data for embedding in JavaScript.
@@ -131,11 +133,11 @@ def json_safe_filter(value):
    </script>
    """
    if value is None:
        return 'null'
        return "null"
    return mark_safe(sanitize_for_json(value))


@register.filter(name='escapejs_safe')
@register.filter(name="escapejs_safe")
def escapejs_safe_filter(value):
    """
    Escape a string for safe use in JavaScript string literals.
@@ -146,7 +148,7 @@ def escapejs_safe_filter(value):
    </script>
    """
    if not value:
        return ''
        return ""
    return _escape_js_string(str(value))


@@ -154,7 +156,8 @@ def escapejs_safe_filter(value):
# URL and Attribute Filters
# =============================================================================

@register.filter(name='sanitize_url')

@register.filter(name="sanitize_url")
def sanitize_url_filter(value):
    """
    Sanitize a URL to prevent javascript: and other dangerous protocols.
@@ -163,11 +166,11 @@ def sanitize_url_filter(value):
        <a href="{{ user_url|sanitize_url }}">Link</a>
    """
    if not value:
        return ''
        return ""
    return _sanitize_url(str(value))


@register.filter(name='attr_safe')
@register.filter(name="attr_safe")
def attr_safe_filter(value):
    """
    Escape a value for safe use in HTML attributes.
@@ -176,7 +179,7 @@ def attr_safe_filter(value):
        <div data-value="{{ user_value|attr_safe }}">
    """
    if not value:
        return ''
        return ""
    return sanitize_attribute_value(str(value))


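
The filters above all follow the same sanitize-then-mark_safe pattern; a minimal sketch of that pattern using bleach directly, assuming bleach is installed (the tag set here is illustrative, not the module's actual allowlist):

import bleach
from django.utils.safestring import mark_safe

def sanitize_sketch(value):
    # Clean first, then mark the result safe for template rendering.
    cleaned = bleach.clean(str(value), tags={"p", "strong", "em", "a"}, strip=True)
    return mark_safe(cleaned)

print(sanitize_sketch('<p onclick="x()">hi</p><script>alert(1)</script>'))
# <p>hi</p>alert(1)
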
@@ -187,36 +190,36 @@ def attr_safe_filter(value):
# Predefined safe SVG icons
# These are trusted and can be rendered without sanitization
BUILTIN_ICONS = {
    'check': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M4.5 12.75l6 6 9-13.5" /></svg>''',
    'x': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M6 18L18 6M6 6l12 12" /></svg>''',
    'plus': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M12 4.5v15m7.5-7.5h-15" /></svg>''',
    'minus': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M19.5 12h-15" /></svg>''',
    'chevron-down': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M19.5 8.25l-7.5 7.5-7.5-7.5" /></svg>''',
    'chevron-up': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M4.5 15.75l7.5-7.5 7.5 7.5" /></svg>''',
    'chevron-left': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M15.75 19.5L8.25 12l7.5-7.5" /></svg>''',
    'chevron-right': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M8.25 4.5l7.5 7.5-7.5 7.5" /></svg>''',
    'search': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M21 21l-5.197-5.197m0 0A7.5 7.5 0 105.196 5.196a7.5 7.5 0 0010.607 10.607z" /></svg>''',
    'menu': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M3.75 6.75h16.5M3.75 12h16.5m-16.5 5.25h16.5" /></svg>''',
    'user': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M15.75 6a3.75 3.75 0 11-7.5 0 3.75 3.75 0 017.5 0zM4.501 20.118a7.5 7.5 0 0114.998 0A17.933 17.933 0 0112 21.75c-2.676 0-5.216-.584-7.499-1.632z" /></svg>''',
    'cog': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M9.594 3.94c.09-.542.56-.94 1.11-.94h2.593c.55 0 1.02.398 1.11.94l.213 1.281c.063.374.313.686.645.87.074.04.147.083.22.127.324.196.72.257 1.075.124l1.217-.456a1.125 1.125 0 011.37.49l1.296 2.247a1.125 1.125 0 01-.26 1.431l-1.003.827c-.293.24-.438.613-.431.992a6.759 6.759 0 010 .255c-.007.378.138.75.43.99l1.005.828c.424.35.534.954.26 1.43l-1.298 2.247a1.125 1.125 0 01-1.369.491l-1.217-.456c-.355-.133-.75-.072-1.076.124a6.57 6.57 0 01-.22.128c-.331.183-.581.495-.644.869l-.213 1.28c-.09.543-.56.941-1.11.941h-2.594c-.55 0-1.02-.398-1.11-.94l-.213-1.281c-.062-.374-.312-.686-.644-.87a6.52 6.52 0 01-.22-.127c-.325-.196-.72-.257-1.076-.124l-1.217.456a1.125 1.125 0 01-1.369-.49l-1.297-2.247a1.125 1.125 0 01.26-1.431l1.004-.827c.292-.24.437-.613.43-.992a6.932 6.932 0 010-.255c.007-.378-.138-.75-.43-.99l-1.004-.828a1.125 1.125 0 01-.26-1.43l1.297-2.247a1.125 1.125 0 011.37-.491l1.216.456c.356.133.751.072 1.076-.124.072-.044.146-.087.22-.128.332-.183.582-.495.644-.869l.214-1.281z" /><path stroke-linecap="round" stroke-linejoin="round" d="M15 12a3 3 0 11-6 0 3 3 0 016 0z" /></svg>''',
    'trash': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M14.74 9l-.346 9m-4.788 0L9.26 9m9.968-3.21c.342.052.682.107 1.022.166m-1.022-.165L18.16 19.673a2.25 2.25 0 01-2.244 2.077H8.084a2.25 2.25 0 01-2.244-2.077L4.772 5.79m14.456 0a48.108 48.108 0 00-3.478-.397m-12 .562c.34-.059.68-.114 1.022-.165m0 0a48.11 48.11 0 013.478-.397m7.5 0v-.916c0-1.18-.91-2.164-2.09-2.201a51.964 51.964 0 00-3.32 0c-1.18.037-2.09 1.022-2.09 2.201v.916m7.5 0a48.667 48.667 0 00-7.5 0" /></svg>''',
    'pencil': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M16.862 4.487l1.687-1.688a1.875 1.875 0 112.652 2.652L10.582 16.07a4.5 4.5 0 01-1.897 1.13L6 18l.8-2.685a4.5 4.5 0 011.13-1.897l8.932-8.931zm0 0L19.5 7.125M18 14v4.75A2.25 2.25 0 0115.75 21H5.25A2.25 2.25 0 013 18.75V8.25A2.25 2.25 0 015.25 6H10" /></svg>''',
    'eye': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M2.036 12.322a1.012 1.012 0 010-.639C3.423 7.51 7.36 4.5 12 4.5c4.638 0 8.573 3.007 9.963 7.178.07.207.07.431 0 .639C20.577 16.49 16.64 19.5 12 19.5c-4.638 0-8.573-3.007-9.963-7.178z" /><path stroke-linecap="round" stroke-linejoin="round" d="M15 12a3 3 0 11-6 0 3 3 0 016 0z" /></svg>''',
    'eye-slash': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M3.98 8.223A10.477 10.477 0 001.934 12C3.226 16.338 7.244 19.5 12 19.5c.993 0 1.953-.138 2.863-.395M6.228 6.228A10.45 10.45 0 0112 4.5c4.756 0 8.773 3.162 10.065 7.498a10.523 10.523 0 01-4.293 5.774M6.228 6.228L3 3m3.228 3.228l3.65 3.65m7.894 7.894L21 21m-3.228-3.228l-3.65-3.65m0 0a3 3 0 10-4.243-4.243m4.242 4.242L9.88 9.88" /></svg>''',
    'arrow-left': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M10.5 19.5L3 12m0 0l7.5-7.5M3 12h18" /></svg>''',
    'arrow-right': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M13.5 4.5L21 12m0 0l-7.5 7.5M21 12H3" /></svg>''',
    'info': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M11.25 11.25l.041-.02a.75.75 0 011.063.852l-.708 2.836a.75.75 0 001.063.853l.041-.021M21 12a9 9 0 11-18 0 9 9 0 0118 0zm-9-3.75h.008v.008H12V8.25z" /></svg>''',
    'warning': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M12 9v3.75m-9.303 3.376c-.866 1.5.217 3.374 1.948 3.374h14.71c1.73 0 2.813-1.874 1.948-3.374L13.949 3.378c-.866-1.5-3.032-1.5-3.898 0L2.697 16.126zM12 15.75h.007v.008H12v-.008z" /></svg>''',
    'error': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M12 9v3.75m9-.75a9 9 0 11-18 0 9 9 0 0118 0zm-9 3.75h.008v.008H12v-.008z" /></svg>''',
    'success': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M9 12.75L11.25 15 15 9.75M21 12a9 9 0 11-18 0 9 9 0 0118 0z" /></svg>''',
    'loading': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" {attrs}><circle class="opacity-25" cx="12" cy="12" r="10" stroke="currentColor" stroke-width="4"></circle><path class="opacity-75" fill="currentColor" d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"></path></svg>''',
    'external-link': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M13.5 6H5.25A2.25 2.25 0 003 8.25v10.5A2.25 2.25 0 005.25 21h10.5A2.25 2.25 0 0018 18.75V10.5m-10.5 6L21 3m0 0h-5.25M21 3v5.25" /></svg>''',
    'download': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M3 16.5v2.25A2.25 2.25 0 005.25 21h13.5A2.25 2.25 0 0021 18.75V16.5M16.5 12L12 16.5m0 0L7.5 12m4.5 4.5V3" /></svg>''',
    'upload': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M3 16.5v2.25A2.25 2.25 0 005.25 21h13.5A2.25 2.25 0 0021 18.75V16.5m-13.5-9L12 3m0 0l4.5 4.5M12 3v13.5" /></svg>''',
    'star': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M11.48 3.499a.562.562 0 011.04 0l2.125 5.111a.563.563 0 00.475.345l5.518.442c.499.04.701.663.321.988l-4.204 3.602a.563.563 0 00-.182.557l1.285 5.385a.562.562 0 01-.84.61l-4.725-2.885a.563.563 0 00-.586 0L6.982 20.54a.562.562 0 01-.84-.61l1.285-5.386a.562.562 0 00-.182-.557l-4.204-3.602a.563.563 0 01.321-.988l5.518-.442a.563.563 0 00.475-.345L11.48 3.5z" /></svg>''',
    'star-filled': '''<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" {attrs}><path fill-rule="evenodd" d="M10.788 3.21c.448-1.077 1.976-1.077 2.424 0l2.082 5.007 5.404.433c1.164.093 1.636 1.545.749 2.305l-4.117 3.527 1.257 5.273c.271 1.136-.964 2.033-1.96 1.425L12 18.354 7.373 21.18c-.996.608-2.231-.29-1.96-1.425l1.257-5.273-4.117-3.527c-.887-.76-.415-2.212.749-2.305l5.404-.433 2.082-5.006z" clip-rule="evenodd" /></svg>''',
    'heart': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M21 8.25c0-2.485-2.099-4.5-4.688-4.5-1.935 0-3.597 1.126-4.312 2.733-.715-1.607-2.377-2.733-4.313-2.733C5.1 3.75 3 5.765 3 8.25c0 7.22 9 12 9 12s9-4.78 9-12z" /></svg>''',
    'heart-filled': '''<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" {attrs}><path d="M11.645 20.91l-.007-.003-.022-.012a15.247 15.247 0 01-.383-.218 25.18 25.18 0 01-4.244-3.17C4.688 15.36 2.25 12.174 2.25 8.25 2.25 5.322 4.714 3 7.688 3A5.5 5.5 0 0112 5.052 5.5 5.5 0 0116.313 3c2.973 0 5.437 2.322 5.437 5.25 0 3.925-2.438 7.111-4.739 9.256a25.175 25.175 0 01-4.244 3.17 15.247 15.247 0 01-.383.219l-.022.012-.007.004-.003.001a.752.752 0 01-.704 0l-.003-.001z" /></svg>''',
    "check": """<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M4.5 12.75l6 6 9-13.5" /></svg>""",
    "x": """<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M6 18L18 6M6 6l12 12" /></svg>""",
    "plus": """<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M12 4.5v15m7.5-7.5h-15" /></svg>""",
    "minus": """<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M19.5 12h-15" /></svg>""",
    "chevron-down": """<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M19.5 8.25l-7.5 7.5-7.5-7.5" /></svg>""",
    "chevron-up": """<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M4.5 15.75l7.5-7.5 7.5 7.5" /></svg>""",
    "chevron-left": """<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M15.75 19.5L8.25 12l7.5-7.5" /></svg>""",
    "chevron-right": """<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M8.25 4.5l7.5 7.5-7.5 7.5" /></svg>""",
    "search": """<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M21 21l-5.197-5.197m0 0A7.5 7.5 0 105.196 5.196a7.5 7.5 0 0010.607 10.607z" /></svg>""",
    "menu": """<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M3.75 6.75h16.5M3.75 12h16.5m-16.5 5.25h16.5" /></svg>""",
    "user": """<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M15.75 6a3.75 3.75 0 11-7.5 0 3.75 3.75 0 017.5 0zM4.501 20.118a7.5 7.5 0 0114.998 0A17.933 17.933 0 0112 21.75c-2.676 0-5.216-.584-7.499-1.632z" /></svg>""",
    "cog": """<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M9.594 3.94c.09-.542.56-.94 1.11-.94h2.593c.55 0 1.02.398 1.11.94l.213 1.281c.063.374.313.686.645.87.074.04.147.083.22.127.324.196.72.257 1.075.124l1.217-.456a1.125 1.125 0 011.37.49l1.296 2.247a1.125 1.125 0 01-.26 1.431l-1.003.827c-.293.24-.438.613-.431.992a6.759 6.759 0 010 .255c-.007.378.138.75.43.99l1.005.828c.424.35.534.954.26 1.43l-1.298 2.247a1.125 1.125 0 01-1.369.491l-1.217-.456c-.355-.133-.75-.072-1.076.124a6.57 6.57 0 01-.22.128c-.331.183-.581.495-.644.869l-.213 1.28c-.09.543-.56.941-1.11.941h-2.594c-.55 0-1.02-.398-1.11-.94l-.213-1.281c-.062-.374-.312-.686-.644-.87a6.52 6.52 0 01-.22-.127c-.325-.196-.72-.257-1.076-.124l-1.217.456a1.125 1.125 0 01-1.369-.49l-1.297-2.247a1.125 1.125 0 01.26-1.431l1.004-.827c.292-.24.437-.613.43-.992a6.932 6.932 0 010-.255c.007-.378-.138-.75-.43-.99l-1.004-.828a1.125 1.125 0 01-.26-1.43l1.297-2.247a1.125 1.125 0 011.37-.491l1.216.456c.356.133.751.072 1.076-.124.072-.044.146-.087.22-.128.332-.183.582-.495.644-.869l.214-1.281z" /><path stroke-linecap="round" stroke-linejoin="round" d="M15 12a3 3 0 11-6 0 3 3 0 016 0z" /></svg>""",
    "trash": """<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M14.74 9l-.346 9m-4.788 0L9.26 9m9.968-3.21c.342.052.682.107 1.022.166m-1.022-.165L18.16 19.673a2.25 2.25 0 01-2.244 2.077H8.084a2.25 2.25 0 01-2.244-2.077L4.772 5.79m14.456 0a48.108 48.108 0 00-3.478-.397m-12 .562c.34-.059.68-.114 1.022-.165m0 0a48.11 48.11 0 013.478-.397m7.5 0v-.916c0-1.18-.91-2.164-2.09-2.201a51.964 51.964 0 00-3.32 0c-1.18.037-2.09 1.022-2.09 2.201v.916m7.5 0a48.667 48.667 0 00-7.5 0" /></svg>""",
    "pencil": """<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M16.862 4.487l1.687-1.688a1.875 1.875 0 112.652 2.652L10.582 16.07a4.5 4.5 0 01-1.897 1.13L6 18l.8-2.685a4.5 4.5 0 011.13-1.897l8.932-8.931zm0 0L19.5 7.125M18 14v4.75A2.25 2.25 0 0115.75 21H5.25A2.25 2.25 0 013 18.75V8.25A2.25 2.25 0 015.25 6H10" /></svg>""",
    "eye": """<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M2.036 12.322a1.012 1.012 0 010-.639C3.423 7.51 7.36 4.5 12 4.5c4.638 0 8.573 3.007 9.963 7.178.07.207.07.431 0 .639C20.577 16.49 16.64 19.5 12 19.5c-4.638 0-8.573-3.007-9.963-7.178z" /><path stroke-linecap="round" stroke-linejoin="round" d="M15 12a3 3 0 11-6 0 3 3 0 016 0z" /></svg>""",
    "eye-slash": """<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M3.98 8.223A10.477 10.477 0 001.934 12C3.226 16.338 7.244 19.5 12 19.5c.993 0 1.953-.138 2.863-.395M6.228 6.228A10.45 10.45 0 0112 4.5c4.756 0 8.773 3.162 10.065 7.498a10.523 10.523 0 01-4.293 5.774M6.228 6.228L3 3m3.228 3.228l3.65 3.65m7.894 7.894L21 21m-3.228-3.228l-3.65-3.65m0 0a3 3 0 10-4.243-4.243m4.242 4.242L9.88 9.88" /></svg>""",
    "arrow-left": """<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M10.5 19.5L3 12m0 0l7.5-7.5M3 12h18" /></svg>""",
    "arrow-right": """<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M13.5 4.5L21 12m0 0l-7.5 7.5M21 12H3" /></svg>""",
    "info": """<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M11.25 11.25l.041-.02a.75.75 0 011.063.852l-.708 2.836a.75.75 0 001.063.853l.041-.021M21 12a9 9 0 11-18 0 9 9 0 0118 0zm-9-3.75h.008v.008H12V8.25z" /></svg>""",
    "warning": """<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M12 9v3.75m-9.303 3.376c-.866 1.5.217 3.374 1.948 3.374h14.71c1.73 0 2.813-1.874 1.948-3.374L13.949 3.378c-.866-1.5-3.032-1.5-3.898 0L2.697 16.126zM12 15.75h.007v.008H12v-.008z" /></svg>""",
    "error": """<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M12 9v3.75m9-.75a9 9 0 11-18 0 9 9 0 0118 0zm-9 3.75h.008v.008H12v-.008z" /></svg>""",
    "success": """<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M9 12.75L11.25 15 15 9.75M21 12a9 9 0 11-18 0 9 9 0 0118 0z" /></svg>""",
    "loading": """<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" {attrs}><circle class="opacity-25" cx="12" cy="12" r="10" stroke="currentColor" stroke-width="4"></circle><path class="opacity-75" fill="currentColor" d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"></path></svg>""",
    "external-link": """<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M13.5 6H5.25A2.25 2.25 0 003 8.25v10.5A2.25 2.25 0 005.25 21h10.5A2.25 2.25 0 0018 18.75V10.5m-10.5 6L21 3m0 0h-5.25M21 3v5.25" /></svg>""",
    "download": """<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M3 16.5v2.25A2.25 2.25 0 005.25 21h13.5A2.25 2.25 0 0021 18.75V16.5M16.5 12L12 16.5m0 0L7.5 12m4.5 4.5V3" /></svg>""",
    "upload": """<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M3 16.5v2.25A2.25 2.25 0 005.25 21h13.5A2.25 2.25 0 0021 18.75V16.5m-13.5-9L12 3m0 0l4.5 4.5M12 3v13.5" /></svg>""",
    "star": """<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M11.48 3.499a.562.562 0 011.04 0l2.125 5.111a.563.563 0 00.475.345l5.518.442c.499.04.701.663.321.988l-4.204 3.602a.563.563 0 00-.182.557l1.285 5.385a.562.562 0 01-.84.61l-4.725-2.885a.563.563 0 00-.586 0L6.982 20.54a.562.562 0 01-.84-.61l1.285-5.386a.562.562 0 00-.182-.557l-4.204-3.602a.563.563 0 01.321-.988l5.518-.442a.563.563 0 00.475-.345L11.48 3.5z" /></svg>""",
    "star-filled": """<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" {attrs}><path fill-rule="evenodd" d="M10.788 3.21c.448-1.077 1.976-1.077 2.424 0l2.082 5.007 5.404.433c1.164.093 1.636 1.545.749 2.305l-4.117 3.527 1.257 5.273c.271 1.136-.964 2.033-1.96 1.425L12 18.354 7.373 21.18c-.996.608-2.231-.29-1.96-1.425l1.257-5.273-4.117-3.527c-.887-.76-.415-2.212.749-2.305l5.404-.433 2.082-5.006z" clip-rule="evenodd" /></svg>""",
    "heart": """<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M21 8.25c0-2.485-2.099-4.5-4.688-4.5-1.935 0-3.597 1.126-4.312 2.733-.715-1.607-2.377-2.733-4.313-2.733C5.1 3.75 3 5.765 3 8.25c0 7.22 9 12 9 12s9-4.78 9-12z" /></svg>""",
    "heart-filled": """<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" {attrs}><path d="M11.645 20.91l-.007-.003-.022-.012a15.247 15.247 0 01-.383-.218 25.18 25.18 0 01-4.244-3.17C4.688 15.36 2.25 12.174 2.25 8.25 2.25 5.322 4.714 3 7.688 3A5.5 5.5 0 0112 5.052 5.5 5.5 0 0116.313 3c2.973 0 5.437 2.322 5.437 5.25 0 3.925-2.438 7.111-4.739 9.256a25.175 25.175 0 01-4.244 3.17 15.247 15.247 0 01-.383.219l-.022.012-.007.004-.003.001a.752.752 0 01-.704 0l-.003-.001z" /></svg>""",
}


@@ -243,18 +246,18 @@ def icon(name, **kwargs):

    if not svg_template:
        # Return empty string for unknown icons (fail silently)
        return ''
        return ""

    # Build attributes string
    attrs_list = []
    for key, value in kwargs.items():
        # Convert underscore to hyphen for HTML attributes (e.g., aria_hidden -> aria-hidden)
        attr_name = key.replace('_', '-')
        attr_name = key.replace("_", "-")
        # Escape attribute values to prevent XSS
        safe_value = sanitize_attribute_value(str(value))
        attrs_list.append(f'{attr_name}="{safe_value}"')

    attrs_str = ' '.join(attrs_list)
    attrs_str = " ".join(attrs_list)

    # Substitute attributes into template
    svg = svg_template.format(attrs=attrs_str)
@@ -263,7 +266,7 @@ def icon(name, **kwargs):


@register.simple_tag
def icon_class(name, size='w-5 h-5', extra_class=''):
def icon_class(name, size="w-5 h-5", extra_class=""):
    """
    Render a trusted SVG icon with common class presets.

@@ -278,5 +281,5 @@ def icon_class(name, size='w-5 h-5', extra_class=''):
    Returns:
        Safe HTML for the icon SVG
    """
    classes = f'{size} {extra_class}'.strip()
    return icon(name, **{'class': classes})
    classes = f"{size} {extra_class}".strip()
    return icon(name, **{"class": classes})

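
A standalone sketch of the attribute-building loop in icon() above: Python kwargs use underscores, so they are mapped to hyphenated HTML attribute names. This is also why icon_class passes **{"class": classes} rather than a literal keyword, since class is reserved in Python:

kwargs = {"aria_hidden": "true", "class": "w-5 h-5"}
attrs_str = " ".join(f'{key.replace("_", "-")}="{value}"' for key, value in kwargs.items())
print(attrs_str)  # aria-hidden="true" class="w-5 h-5"
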
@@ -1,4 +1,3 @@

import pghistory
import pytest
from django.contrib.auth import get_user_model
@@ -7,6 +6,7 @@ from apps.parks.models import Company, Park

User = get_user_model()


@pytest.mark.django_db
class TestTrackedModel:
    """
@@ -20,10 +20,7 @@ class TestTrackedModel:

        with pghistory.context(user=user.id):
            park = Park.objects.create(
                name="History Test Park",
                description="Testing history",
                operating_season="Summer",
                operator=company
                name="History Test Park", description="Testing history", operating_season="Summer", operator=company
            )

        # Verify history using the helper method from TrackedModel
@@ -50,6 +47,5 @@ class TestTrackedModel:
        park.save()

        assert park.get_history().count() == 2
        latest = park.get_history().first() # Ordered by -pgh_created_at
        latest = park.get_history().first()  # Ordered by -pgh_created_at
        assert latest.name == "Updated"


@@ -17,9 +17,7 @@ app_name = "core"
entity_patterns = [
    path("search/", EntityFuzzySearchView.as_view(), name="entity_fuzzy_search"),
    path("not-found/", EntityNotFoundView.as_view(), name="entity_not_found"),
    path(
        "suggestions/", QuickEntitySuggestionView.as_view(), name="entity_suggestions"
    ),
    path("suggestions/", QuickEntitySuggestionView.as_view(), name="entity_suggestions"),
]

# FSM transition endpoints

@@ -6,6 +6,7 @@ from django.core.exceptions import ImproperlyConfigured

logger = logging.getLogger(__name__)


def get_direct_upload_url(user_id=None):
    """
    Generates a direct upload URL for Cloudflare Images.
@@ -20,13 +21,11 @@ def get_direct_upload_url(user_id=None):
        ImproperlyConfigured: If Cloudflare settings are missing.
        requests.RequestException: If the Cloudflare API request fails.
    """
    account_id = getattr(settings, 'CLOUDFLARE_IMAGES_ACCOUNT_ID', None)
    api_token = getattr(settings, 'CLOUDFLARE_IMAGES_API_TOKEN', None)
    account_id = getattr(settings, "CLOUDFLARE_IMAGES_ACCOUNT_ID", None)
    api_token = getattr(settings, "CLOUDFLARE_IMAGES_API_TOKEN", None)

    if not account_id or not api_token:
        raise ImproperlyConfigured(
            "CLOUDFLARE_IMAGES_ACCOUNT_ID and CLOUDFLARE_IMAGES_API_TOKEN must be set."
        )
        raise ImproperlyConfigured("CLOUDFLARE_IMAGES_ACCOUNT_ID and CLOUDFLARE_IMAGES_API_TOKEN must be set.")

    url = f"https://api.cloudflare.com/client/v4/accounts/{account_id}/images/v2/direct_upload"


@@ -37,6 +37,7 @@ from django.core.files.uploadedfile import UploadedFile

class FileValidationError(ValidationError):
    """Custom exception for file validation errors."""

    pass


@@ -47,41 +48,49 @@ class FileValidationError(ValidationError):
# Magic number signatures for common image formats
# Format: (magic_bytes, offset, description)
IMAGE_SIGNATURES = {
    'jpeg': [
        (b'\xFF\xD8\xFF\xE0', 0, 'JPEG (JFIF)'),
        (b'\xFF\xD8\xFF\xE1', 0, 'JPEG (EXIF)'),
        (b'\xFF\xD8\xFF\xE2', 0, 'JPEG (ICC)'),
        (b'\xFF\xD8\xFF\xE3', 0, 'JPEG (Samsung)'),
        (b'\xFF\xD8\xFF\xE8', 0, 'JPEG (SPIFF)'),
        (b'\xFF\xD8\xFF\xDB', 0, 'JPEG (Raw)'),
    "jpeg": [
        (b"\xff\xd8\xff\xe0", 0, "JPEG (JFIF)"),
        (b"\xff\xd8\xff\xe1", 0, "JPEG (EXIF)"),
        (b"\xff\xd8\xff\xe2", 0, "JPEG (ICC)"),
        (b"\xff\xd8\xff\xe3", 0, "JPEG (Samsung)"),
        (b"\xff\xd8\xff\xe8", 0, "JPEG (SPIFF)"),
        (b"\xff\xd8\xff\xdb", 0, "JPEG (Raw)"),
    ],
    'png': [
        (b'\x89PNG\r\n\x1a\n', 0, 'PNG'),
    "png": [
        (b"\x89PNG\r\n\x1a\n", 0, "PNG"),
    ],
    'gif': [
        (b'GIF87a', 0, 'GIF87a'),
        (b'GIF89a', 0, 'GIF89a'),
    "gif": [
        (b"GIF87a", 0, "GIF87a"),
        (b"GIF89a", 0, "GIF89a"),
    ],
    'webp': [
        (b'RIFF', 0, 'RIFF'),  # WebP starts with RIFF header
    "webp": [
        (b"RIFF", 0, "RIFF"),  # WebP starts with RIFF header
    ],
    'bmp': [
        (b'BM', 0, 'BMP'),
    "bmp": [
        (b"BM", 0, "BMP"),
    ],
}

# All allowed MIME types
ALLOWED_IMAGE_MIME_TYPES: set[str] = frozenset({
    'image/jpeg',
    'image/png',
    'image/gif',
    'image/webp',
})
ALLOWED_IMAGE_MIME_TYPES: set[str] = frozenset(
    {
        "image/jpeg",
        "image/png",
        "image/gif",
        "image/webp",
    }
)

# Allowed file extensions
ALLOWED_IMAGE_EXTENSIONS: set[str] = frozenset({
    '.jpg', '.jpeg', '.png', '.gif', '.webp',
})
ALLOWED_IMAGE_EXTENSIONS: set[str] = frozenset(
    {
        ".jpg",
        ".jpeg",
        ".png",
        ".gif",
        ".webp",
    }
)

# Maximum file size (10MB)
MAX_FILE_SIZE = 10 * 1024 * 1024
@@ -94,6 +103,7 @@ MIN_FILE_SIZE = 100  # 100 bytes
# File Validation Functions
# =============================================================================


def validate_image_upload(
    file: UploadedFile,
    max_size: int = MAX_FILE_SIZE,
@@ -133,39 +143,29 @@ def validate_image_upload(

    # 2. Check file size
    if file.size > max_size:
        raise FileValidationError(
            f"File too large. Maximum size is {max_size // (1024 * 1024)}MB"
        )
        raise FileValidationError(f"File too large. Maximum size is {max_size // (1024 * 1024)}MB")

    if file.size < MIN_FILE_SIZE:
        raise FileValidationError("File too small or empty")

    # 3. Check file extension
    filename = file.name or ''
    filename = file.name or ""
    ext = os.path.splitext(filename)[1].lower()
    if ext not in allowed_extensions:
        raise FileValidationError(
            f"Invalid file extension '{ext}'. Allowed: {', '.join(allowed_extensions)}"
        )
        raise FileValidationError(f"Invalid file extension '{ext}'. Allowed: {', '.join(allowed_extensions)}")

    # 4. Check Content-Type header
    content_type = getattr(file, 'content_type', '')
    content_type = getattr(file, "content_type", "")
    if content_type and content_type not in allowed_types:
        raise FileValidationError(
            f"Invalid file type '{content_type}'. Allowed: {', '.join(allowed_types)}"
        )
        raise FileValidationError(f"Invalid file type '{content_type}'. Allowed: {', '.join(allowed_types)}")

    # 5. Validate magic numbers (actual file content)
    if not _validate_magic_number(file):
        raise FileValidationError(
            "File content doesn't match file extension. File may be corrupted or malicious."
        )
        raise FileValidationError("File content doesn't match file extension. File may be corrupted or malicious.")

    # 6. Validate image integrity using PIL
    if not _validate_image_integrity(file):
        raise FileValidationError(
            "Invalid or corrupted image file"
        )
        raise FileValidationError("Invalid or corrupted image file")

    return True

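
A sketch of the pipeline above rejecting a text payload renamed to .png; it passes the size, extension, and content-type checks but fails at step 5, the magic-number check (illustrative only):

from django.core.files.uploadedfile import SimpleUploadedFile

fake = SimpleUploadedFile("photo.png", b"not really an image" * 10, content_type="image/png")
try:
    validate_image_upload(fake)
except FileValidationError as exc:
    print(exc)  # -> File content doesn't match file extension...
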
@@ -191,10 +191,10 @@ def _validate_magic_number(file: UploadedFile) -> bool:
    # Check against known signatures
    for format_name, signatures in IMAGE_SIGNATURES.items():
        for magic, offset, _description in signatures:
            if len(header) >= offset + len(magic) and header[offset:offset + len(magic)] == magic:
            if len(header) >= offset + len(magic) and header[offset : offset + len(magic)] == magic:
                # Special handling for WebP (must also have WEBP marker)
                if format_name == 'webp':
                    if len(header) >= 12 and header[8:12] == b'WEBP':
                if format_name == "webp":
                    if len(header) >= 12 and header[8:12] == b"WEBP":
                        return True
                else:
                    return True
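
A standalone sketch of the WebP special case above: the RIFF signature alone is not enough, the container must also carry a WEBP marker at bytes 8-12:

# Hypothetical 12-byte WebP header: "RIFF" + little-endian chunk size + "WEBP".
header = b"RIFF" + (1234).to_bytes(4, "little") + b"WEBP"
print(header[:4] == b"RIFF" and header[8:12] == b"WEBP")  # True
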
@@ -233,9 +233,7 @@ def _validate_image_integrity(file: UploadedFile) -> bool:
        # Prevent decompression bombs
        max_dimension = 10000
        if img2.width > max_dimension or img2.height > max_dimension:
            raise FileValidationError(
                f"Image dimensions too large. Maximum is {max_dimension}x{max_dimension}"
            )
            raise FileValidationError(f"Image dimensions too large. Maximum is {max_dimension}x{max_dimension}")

        # Check for very small dimensions (might be suspicious)
        if img2.width < 1 or img2.height < 1:
@@ -253,6 +251,7 @@ def _validate_image_integrity(file: UploadedFile) -> bool:
# Filename Sanitization
# =============================================================================


def sanitize_filename(filename: str, max_length: int = 100) -> str:
    """
    Sanitize a filename to prevent directory traversal and other attacks.
@@ -281,13 +280,13 @@ def sanitize_filename(filename: str, max_length: int = 100) -> str:

    # Remove or replace dangerous characters from name
    # Allow alphanumeric, hyphens, underscores, dots
    name = re.sub(r'[^\w\-.]', '_', name)
    name = re.sub(r"[^\w\-.]", "_", name)

    # Remove leading dots and underscores (hidden file prevention)
    name = name.lstrip('._')
    name = name.lstrip("._")

    # Collapse multiple underscores
    name = re.sub(r'_+', '_', name)
    name = re.sub(r"_+", "_", name)

    # Ensure name is not empty
    if not name:
@@ -295,7 +294,7 @@ def sanitize_filename(filename: str, max_length: int = 100) -> str:

    # Sanitize extension
    ext = ext.lower()
    ext = re.sub(r'[^\w.]', '', ext)
    ext = re.sub(r"[^\w.]", "", ext)

    # Combine and truncate
    result = f"{name[:max_length - len(ext)]}{ext}"
@@ -303,7 +302,7 @@ def sanitize_filename(filename: str, max_length: int = 100) -> str:
    return result


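
A standalone sketch of the character rules in sanitize_filename above; the name/extension split and truncation from the full function are omitted:

import re

def clean_name_sketch(name):
    name = re.sub(r"[^\w\-.]", "_", name)  # replace dangerous characters
    name = name.lstrip("._")               # drop hidden-file prefixes
    return re.sub(r"_+", "_", name)        # collapse underscore runs

print(clean_name_sketch("..secret file!!.txt"))  # secret_file_.txt
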
def generate_unique_filename(original_filename: str, prefix: str = '') -> str:
def generate_unique_filename(original_filename: str, prefix: str = "") -> str:
    """
    Generate a unique filename using UUID while preserving extension.

@@ -317,7 +316,7 @@ def generate_unique_filename(original_filename: str, prefix: str = '') -> str:
    ext = os.path.splitext(original_filename)[1].lower()

    # Sanitize extension
    ext = re.sub(r'[^\w.]', '', ext)
    ext = re.sub(r"[^\w.]", "", ext)

    # Generate unique filename
    unique_id = uuid.uuid4().hex[:12]
@@ -332,9 +331,9 @@ def generate_unique_filename(original_filename: str, prefix: str = '') -> str:

# Rate limiting configuration
UPLOAD_RATE_LIMITS = {
    'per_minute': 10,
    'per_hour': 100,
    'per_day': 500,
    "per_minute": 10,
    "per_hour": 100,
    "per_day": 500,
}


@@ -351,24 +350,25 @@ def check_upload_rate_limit(user_id: int, cache_backend=None) -> tuple[bool, str
    """
    if cache_backend is None:
        from django.core.cache import cache

        cache_backend = cache

    # Check per-minute limit
    minute_key = f"upload_rate:{user_id}:minute"
    minute_count = cache_backend.get(minute_key, 0)
    if minute_count >= UPLOAD_RATE_LIMITS['per_minute']:
    if minute_count >= UPLOAD_RATE_LIMITS["per_minute"]:
        return False, "Upload rate limit exceeded. Please wait a minute."

    # Check per-hour limit
    hour_key = f"upload_rate:{user_id}:hour"
    hour_count = cache_backend.get(hour_key, 0)
    if hour_count >= UPLOAD_RATE_LIMITS['per_hour']:
    if hour_count >= UPLOAD_RATE_LIMITS["per_hour"]:
        return False, "Hourly upload limit exceeded. Please try again later."

    # Check per-day limit
    day_key = f"upload_rate:{user_id}:day"
    day_count = cache_backend.get(day_key, 0)
    if day_count >= UPLOAD_RATE_LIMITS['per_day']:
    if day_count >= UPLOAD_RATE_LIMITS["per_day"]:
        return False, "Daily upload limit exceeded. Please try again tomorrow."

    return True, ""
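
Because cache_backend is injectable, the limiter above can be exercised with a plain dict standing in for the Django cache (only .get is used by the checks shown):

cache_stub = {"upload_rate:42:minute": 10}
ok, message = check_upload_rate_limit(42, cache_backend=cache_stub)
print(ok, message)  # False Upload rate limit exceeded. Please wait a minute.
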
@@ -384,6 +384,7 @@ def increment_upload_count(user_id: int, cache_backend=None) -> None:
|
||||
"""
|
||||
if cache_backend is None:
|
||||
from django.core.cache import cache
|
||||
|
||||
cache_backend = cache
|
||||
|
||||
# Increment per-minute counter (expires in 60 seconds)
|
||||
@@ -412,6 +413,7 @@ def increment_upload_count(user_id: int, cache_backend=None) -> None:
|
||||
# Antivirus Integration Point
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def scan_file_for_malware(file: UploadedFile) -> tuple[bool, str]:
|
||||
"""
|
||||
Placeholder for antivirus/malware scanning integration.
|
||||
|
||||
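For context, a sketch of how `check_upload_rate_limit` and `increment_upload_count` compose in a view; the view itself is assumed, only the two helpers come from this diff:

from django.http import JsonResponse


def handle_upload(request):
    allowed, message = check_upload_rate_limit(request.user.id)
    if not allowed:
        # 429 Too Many Requests, with the human-readable reason from the helper.
        return JsonResponse({"error": message}, status=429)

    # ... validate and save the uploaded file here ...

    # Count only uploads that were actually accepted.
    increment_upload_count(request.user.id)
    return JsonResponse({"status": "ok"})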
@@ -26,6 +26,7 @@ from typing import Any

 try:
     import bleach

     BLEACH_AVAILABLE = True
 except ImportError:
     BLEACH_AVAILABLE = False
@@ -36,71 +37,135 @@ except ImportError:
 # =============================================================================

 # Default allowed HTML tags for user-generated content
-ALLOWED_TAGS = frozenset([
-    # Text formatting
-    'p', 'br', 'hr',
-    'strong', 'b', 'em', 'i', 'u', 's', 'strike',
-    'sub', 'sup', 'small', 'mark',
-
-    # Headers
-    'h1', 'h2', 'h3', 'h4', 'h5', 'h6',
-
-    # Lists
-    'ul', 'ol', 'li',
-
-    # Links (with restrictions on attributes)
-    'a',
-
-    # Block elements
-    'blockquote', 'pre', 'code',
-    'div', 'span',
-
-    # Tables
-    'table', 'thead', 'tbody', 'tfoot', 'tr', 'th', 'td',
-])
+ALLOWED_TAGS = frozenset(
+    [
+        # Text formatting
+        "p",
+        "br",
+        "hr",
+        "strong",
+        "b",
+        "em",
+        "i",
+        "u",
+        "s",
+        "strike",
+        "sub",
+        "sup",
+        "small",
+        "mark",
+        # Headers
+        "h1",
+        "h2",
+        "h3",
+        "h4",
+        "h5",
+        "h6",
+        # Lists
+        "ul",
+        "ol",
+        "li",
+        # Links (with restrictions on attributes)
+        "a",
+        # Block elements
+        "blockquote",
+        "pre",
+        "code",
+        "div",
+        "span",
+        # Tables
+        "table",
+        "thead",
+        "tbody",
+        "tfoot",
+        "tr",
+        "th",
+        "td",
+    ]
+)

 # Allowed attributes for each tag
 ALLOWED_ATTRIBUTES = {
-    'a': ['href', 'title', 'rel', 'target'],
-    'img': ['src', 'alt', 'title', 'width', 'height'],
-    'div': ['class'],
-    'span': ['class'],
-    'p': ['class'],
-    'table': ['class'],
-    'th': ['class', 'colspan', 'rowspan'],
-    'td': ['class', 'colspan', 'rowspan'],
-    '*': ['class'],  # Allow class on all elements
+    "a": ["href", "title", "rel", "target"],
+    "img": ["src", "alt", "title", "width", "height"],
+    "div": ["class"],
+    "span": ["class"],
+    "p": ["class"],
+    "table": ["class"],
+    "th": ["class", "colspan", "rowspan"],
+    "td": ["class", "colspan", "rowspan"],
+    "*": ["class"],  # Allow class on all elements
 }

 # Allowed URL protocols
-ALLOWED_PROTOCOLS = frozenset([
-    'http', 'https', 'mailto', 'tel',
-])
+ALLOWED_PROTOCOLS = frozenset(
+    [
+        "http",
+        "https",
+        "mailto",
+        "tel",
+    ]
+)

 # Minimal tags for comments and short text
-MINIMAL_TAGS = frozenset([
-    'p', 'br', 'strong', 'b', 'em', 'i', 'a',
-])
+MINIMAL_TAGS = frozenset(
+    [
+        "p",
+        "br",
+        "strong",
+        "b",
+        "em",
+        "i",
+        "a",
+    ]
+)

 # Tags allowed in icon SVGs (for icon template rendering)
-SVG_TAGS = frozenset([
-    'svg', 'path', 'g', 'circle', 'rect', 'line', 'polyline', 'polygon',
-    'ellipse', 'text', 'tspan', 'defs', 'use', 'symbol', 'clipPath',
-    'mask', 'linearGradient', 'radialGradient', 'stop', 'title',
-])
+SVG_TAGS = frozenset(
+    [
+        "svg",
+        "path",
+        "g",
+        "circle",
+        "rect",
+        "line",
+        "polyline",
+        "polygon",
+        "ellipse",
+        "text",
+        "tspan",
+        "defs",
+        "use",
+        "symbol",
+        "clipPath",
+        "mask",
+        "linearGradient",
+        "radialGradient",
+        "stop",
+        "title",
+    ]
+)

 SVG_ATTRIBUTES = {
-    'svg': ['viewBox', 'width', 'height', 'fill', 'stroke', 'class',
-            'xmlns', 'aria-hidden', 'role'],
-    'path': ['d', 'fill', 'stroke', 'stroke-width', 'stroke-linecap',
-             'stroke-linejoin', 'class', 'fill-rule', 'clip-rule'],
-    'g': ['fill', 'stroke', 'transform', 'class'],
-    'circle': ['cx', 'cy', 'r', 'fill', 'stroke', 'class'],
-    'rect': ['x', 'y', 'width', 'height', 'rx', 'ry', 'fill', 'stroke', 'class'],
-    'line': ['x1', 'y1', 'x2', 'y2', 'stroke', 'stroke-width', 'class'],
-    'polyline': ['points', 'fill', 'stroke', 'class'],
-    'polygon': ['points', 'fill', 'stroke', 'class'],
-    '*': ['class', 'fill', 'stroke'],
+    "svg": ["viewBox", "width", "height", "fill", "stroke", "class", "xmlns", "aria-hidden", "role"],
+    "path": [
+        "d",
+        "fill",
+        "stroke",
+        "stroke-width",
+        "stroke-linecap",
+        "stroke-linejoin",
+        "class",
+        "fill-rule",
+        "clip-rule",
+    ],
+    "g": ["fill", "stroke", "transform", "class"],
+    "circle": ["cx", "cy", "r", "fill", "stroke", "class"],
+    "rect": ["x", "y", "width", "height", "rx", "ry", "fill", "stroke", "class"],
+    "line": ["x1", "y1", "x2", "y2", "stroke", "stroke-width", "class"],
+    "polyline": ["points", "fill", "stroke", "class"],
+    "polygon": ["points", "fill", "stroke", "class"],
+    "*": ["class", "fill", "stroke"],
 }
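For context, a sketch of how these allow-lists plug into bleach; `bleach.clean` and its `tags`/`attributes`/`protocols`/`strip` parameters are the library's real API, but wrapping them in exactly this call is an assumption about the module's `sanitize_html`:

import bleach

dirty = '<p onclick="steal()">Hi</p><script>alert(1)</script>'
clean = bleach.clean(
    dirty,
    tags=ALLOWED_TAGS,              # drop any tag not on the allow-list
    attributes=ALLOWED_ATTRIBUTES,  # per-tag attribute allow-list
    protocols=ALLOWED_PROTOCOLS,    # href/src protocols that survive
    strip=True,                     # remove disallowed tags instead of escaping them
)
print(clean)  # '<p>Hi</p>alert(1)' -- the script tag and onclick handler are gone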
@@ -108,6 +173,7 @@ SVG_ATTRIBUTES = {
 # Sanitization Functions
 # =============================================================================


 def sanitize_html(
     html: str | None,
     allowed_tags: frozenset | None = None,
@@ -133,7 +199,7 @@ def sanitize_html(
     '<p>Hello</p>'
     """
     if not html:
-        return ''
+        return ""

     if not isinstance(html, str):
         html = str(html)
@@ -170,7 +236,7 @@ def sanitize_minimal(html: str | None) -> str:
     return sanitize_html(
         html,
         allowed_tags=MINIMAL_TAGS,
-        allowed_attributes={'a': ['href', 'title']},
+        allowed_attributes={"a": ["href", "title"]},
     )


@@ -188,7 +254,7 @@ def sanitize_svg(svg: str | None) -> str:
         Sanitized SVG string safe for inline rendering
     """
     if not svg:
-        return ''
+        return ""

     if not isinstance(svg, str):
         svg = str(svg)
@@ -218,7 +284,7 @@ def strip_html(html: str | None) -> str:
         Plain text with all HTML tags removed
     """
     if not html:
-        return ''
+        return ""

     if not isinstance(html, str):
         html = str(html)
@@ -227,13 +293,14 @@ def strip_html(html: str | None) -> str:
         return bleach.clean(html, tags=[], strip=True)
     else:
         # Fallback: use regex to strip tags
-        return re.sub(r'<[^>]+>', '', html)
+        return re.sub(r"<[^>]+>", "", html)


 # =============================================================================
 # JSON/JavaScript Context Sanitization
 # =============================================================================


 def sanitize_for_json(data: Any) -> str:
     """
     Safely serialize data for embedding in JavaScript/JSON contexts.

@@ -251,14 +318,12 @@ def sanitize_for_json(data: Any) -> str:
     '{"name": "\\u003c/script\\u003e\\u003cscript\\u003ealert(\\"xss\\")"}'
     """
     # JSON encode with safe characters escaped
-    return json.dumps(data, ensure_ascii=False).replace(
-        '<', '\\u003c'
-    ).replace(
-        '>', '\\u003e'
-    ).replace(
-        '&', '\\u0026'
-    ).replace(
-        "'", '\\u0027'
+    return (
+        json.dumps(data, ensure_ascii=False)
+        .replace("<", "\\u003c")
+        .replace(">", "\\u003e")
+        .replace("&", "\\u0026")
+        .replace("'", "\\u0027")
     )
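A quick behavior probe for the escaping above, following the doctest in the hunk; the `\u003c` sequences are literal backslash escapes in the output string, so a browser can no longer close a surrounding <script> tag early:

payload = {"name": "</script><script>alert('xss')"}
print(sanitize_for_json(payload))
# {"name": "\u003c/script\u003e\u003cscript\u003ealert(\u0027xss\u0027)"}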
@@ -273,26 +338,21 @@ def escape_js_string(s: str | None) -> str:
         Escaped string safe for JavaScript contexts
     """
     if not s:
-        return ''
+        return ""

     if not isinstance(s, str):
         s = str(s)

     # Escape backslashes first, then other special characters
-    return s.replace('\\', '\\\\').replace(
-        "'", "\\'"
-    ).replace(
-        '"', '\\"'
-    ).replace(
-        '\n', '\\n'
-    ).replace(
-        '\r', '\\r'
-    ).replace(
-        '<', '\\u003c'
-    ).replace(
-        '>', '\\u003e'
-    ).replace(
-        '&', '\\u0026'
+    return (
+        s.replace("\\", "\\\\")
+        .replace("'", "\\'")
+        .replace('"', '\\"')
+        .replace("\n", "\\n")
+        .replace("\r", "\\r")
+        .replace("<", "\\u003c")
+        .replace(">", "\\u003e")
+        .replace("&", "\\u0026")
     )


@@ -300,6 +360,7 @@ def escape_js_string(s: str | None) -> str:
 # URL Sanitization
 # =============================================================================


 def sanitize_url(url: str | None, allowed_protocols: frozenset | None = None) -> str:
     """
     Sanitize a URL to prevent javascript: and other dangerous protocols.

@@ -312,7 +373,7 @@ def sanitize_url(url: str | None, allowed_protocols: frozenset | None = None) ->
         Sanitized URL or empty string if unsafe
     """
     if not url:
-        return ''
+        return ""

     if not isinstance(url, str):
         url = str(url)
@@ -320,7 +381,7 @@ def sanitize_url(url: str | None, allowed_protocols: frozenset | None = None) ->
     url = url.strip()

     if not url:
-        return ''
+        return ""

     protocols = allowed_protocols if allowed_protocols is not None else ALLOWED_PROTOCOLS

@@ -328,12 +389,12 @@ def sanitize_url(url: str | None, allowed_protocols: frozenset | None = None) ->
     url_lower = url.lower()

     # Check for javascript:, data:, vbscript:, etc.
-    if ':' in url_lower:
-        protocol = url_lower.split(':')[0]
-        if protocol not in protocols:
+    if ":" in url_lower:
+        protocol = url_lower.split(":")[0]
+        if protocol not in protocols:  # noqa: SIM102
             # Allow relative URLs and anchor links
-            if not (url.startswith('/') or url.startswith('#') or url.startswith('?')):
-                return ''
+            if not (url.startswith("/") or url.startswith("#") or url.startswith("?")):
+                return ""

     return url
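A few behavior probes for the protocol allow-list logic, assuming the function as shown in the hunks above:

print(sanitize_url("javascript:alert(1)"))         # '' -- protocol not allow-listed
print(sanitize_url("https://example.com/parks"))   # unchanged, https is allowed
print(sanitize_url("/parks/42/?tab=rides"))        # unchanged, relative URLs pass
print(sanitize_url("data:text/html;base64,AAAA"))  # '' -- data: URLs are rejected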
@@ -342,6 +403,7 @@ def sanitize_url(url: str | None, allowed_protocols: frozenset | None = None) ->
 # Attribute Sanitization
 # =============================================================================


 def sanitize_attribute_value(value: str | None) -> str:
     """
     Sanitize a value for use in HTML attributes.

@@ -353,7 +415,7 @@ def sanitize_attribute_value(value: str | None) -> str:
         Sanitized value safe for HTML attribute contexts
     """
     if not value:
-        return ''
+        return ""

     if not isinstance(value, str):
         value = str(value)
@@ -373,10 +435,10 @@ def sanitize_class_name(name: str | None) -> str:
         Sanitized class name containing only safe characters
     """
     if not name:
-        return ''
+        return ""

     if not isinstance(name, str):
         name = str(name)

     # Only allow alphanumeric, hyphens, and underscores
-    return re.sub(r'[^a-zA-Z0-9_-]', '', name)
+    return re.sub(r"[^a-zA-Z0-9_-]", "", name)

@@ -16,9 +16,7 @@ logger = logging.getLogger("query_optimization")


 @contextmanager
-def track_queries(
-    operation_name: str, warn_threshold: int = 10, time_threshold: float = 1.0
-):
+def track_queries(operation_name: str, warn_threshold: int = 10, time_threshold: float = 1.0):
     """
     Context manager to track database queries for specific operations

@@ -47,15 +45,9 @@ def track_queries(
         recent_queries = connection.queries[-total_queries:]
         query_details = [
             {
-                "sql": (
-                    query["sql"][:500] + "..."
-                    if len(query["sql"]) > 500
-                    else query["sql"]
-                ),
+                "sql": (query["sql"][:500] + "..." if len(query["sql"]) > 500 else query["sql"]),
                 "time": float(query["time"]),
-                "duplicate_count": sum(
-                    1 for q in recent_queries if q["sql"] == query["sql"]
-                ),
+                "duplicate_count": sum(1 for q in recent_queries if q["sql"] == query["sql"]),
             }
             for query in recent_queries
         ]
@@ -65,22 +57,18 @@ def track_queries(
             "query_count": total_queries,
             "execution_time": execution_time,
             "queries": query_details if settings.DEBUG else [],
-            "slow_queries": [
-                q for q in query_details if q["time"] > 0.1
-            ],  # Queries slower than 100ms
+            "slow_queries": [q for q in query_details if q["time"] > 0.1],  # Queries slower than 100ms
         }

         # Log warnings for performance issues
         if total_queries > warn_threshold or execution_time > time_threshold:
             logger.warning(
-                f"Performance concern in {operation_name}: "
-                f"{total_queries} queries, {execution_time:.2f}s",
+                f"Performance concern in {operation_name}: " f"{total_queries} queries, {execution_time:.2f}s",
                 extra=performance_data,
             )
         else:
             logger.debug(
-                f"Query tracking for {operation_name}: "
-                f"{total_queries} queries, {execution_time:.2f}s",
+                f"Query tracking for {operation_name}: " f"{total_queries} queries, {execution_time:.2f}s",
                 extra=performance_data,
             )
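For context, a minimal sketch of a call site for the context manager above; `Park` stands in for any model, and `connection.queries` is only populated when `DEBUG` is on, which the `settings.DEBUG` guard in the hunk already accounts for:

with track_queries("park_list_page", warn_threshold=15, time_threshold=0.5):
    parks = list(Park.objects.select_related("location")[:50])
    # Every query issued inside this block is counted; past either
    # threshold, the block is logged as a performance warning.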
@@ -109,9 +97,7 @@ class QueryOptimizer:
         Optimize Ride queryset with proper relationships
         """
         return (
-            queryset.select_related(
-                "park", "park__location", "manufacturer", "created_by"
-            )
+            queryset.select_related("park", "park__location", "manufacturer", "created_by")
             .prefetch_related("reviews__user", "media_items")
             .annotate(
                 review_count=Count("reviews"),
@@ -158,9 +144,7 @@ class QueryCache:
     """Caching utilities for expensive queries"""

     @staticmethod
-    def cache_queryset_result(
-        cache_key: str, queryset_func, timeout: int = 3600, **kwargs
-    ):
+    def cache_queryset_result(cache_key: str, queryset_func, timeout: int = 3600, **kwargs):
         """
         Cache the result of an expensive queryset operation

@@ -202,13 +186,9 @@ class QueryCache:
             # For Redis cache backends that support pattern deletion
             if hasattr(cache, "delete_pattern"):
                 deleted_count = cache.delete_pattern(pattern)
-                logger.info(
-                    f"Invalidated {deleted_count} cache keys for pattern: {pattern}"
-                )
+                logger.info(f"Invalidated {deleted_count} cache keys for pattern: {pattern}")
             else:
-                logger.warning(
-                    f"Cache backend does not support pattern deletion: {pattern}"
-                )
+                logger.warning(f"Cache backend does not support pattern deletion: {pattern}")
         except Exception as e:
             logger.error(f"Error invalidating cache pattern {pattern}: {e}")
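The body of `cache_queryset_result` is not visible in this hunk; a plausible get-or-set sketch consistent with its signature, offered purely as an assumption:

from django.core.cache import cache


def cache_queryset_result(cache_key: str, queryset_func, timeout: int = 3600, **kwargs):
    # Return the cached value when present; otherwise evaluate and store.
    result = cache.get(cache_key)
    if result is None:
        result = list(queryset_func(**kwargs))  # force evaluation before caching
        cache.set(cache_key, result, timeout)
    return result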
@@ -249,10 +229,7 @@ class IndexAnalyzer:
         sql_upper = sql.upper()
         analysis = {
             "has_where_clause": "WHERE" in sql_upper,
-            "has_join": any(
-                join in sql_upper
-                for join in ["JOIN", "INNER JOIN", "LEFT JOIN", "RIGHT JOIN"]
-            ),
+            "has_join": any(join in sql_upper for join in ["JOIN", "INNER JOIN", "LEFT JOIN", "RIGHT JOIN"]),
             "has_order_by": "ORDER BY" in sql_upper,
             "has_group_by": "GROUP BY" in sql_upper,
             "has_like": "LIKE" in sql_upper,
@@ -266,19 +243,13 @@ class IndexAnalyzer:

         # Suggest indexes based on patterns
         if analysis["has_where_clause"] and not analysis["has_join"]:
-            analysis["suggestions"].append(
-                "Consider adding indexes on WHERE clause columns"
-            )
+            analysis["suggestions"].append("Consider adding indexes on WHERE clause columns")

         if analysis["has_order_by"]:
-            analysis["suggestions"].append(
-                "Consider adding indexes on ORDER BY columns"
-            )
+            analysis["suggestions"].append("Consider adding indexes on ORDER BY columns")

         if analysis["has_like"] and "%" not in sql[: sql.find("LIKE") + 10]:
-            analysis["suggestions"].append(
-                "LIKE queries with leading wildcards cannot use indexes efficiently"
-            )
+            analysis["suggestions"].append("LIKE queries with leading wildcards cannot use indexes efficiently")

         return analysis

@@ -294,28 +265,16 @@ class IndexAnalyzer:
         # automatically)
         for field in opts.fields:
             if isinstance(field, models.ForeignKey):
-                suggestions.append(
-                    f"Index on {field.name} (automatically created by Django)"
-                )
+                suggestions.append(f"Index on {field.name} (automatically created by Django)")

         # Suggest composite indexes for common query patterns
-        date_fields = [
-            f.name
-            for f in opts.fields
-            if isinstance(f, (models.DateField, models.DateTimeField))
-        ]
-        status_fields = [
-            f.name
-            for f in opts.fields
-            if f.name in ["status", "is_active", "is_published"]
-        ]
+        date_fields = [f.name for f in opts.fields if isinstance(f, models.DateField | models.DateTimeField)]
+        status_fields = [f.name for f in opts.fields if f.name in ["status", "is_active", "is_published"]]

         if date_fields and status_fields:
             for date_field in date_fields:
                 for status_field in status_fields:
-                    suggestions.append(
-                        f"Composite index on ({status_field}, {date_field}) for filtered date queries"
-                    )
+                    suggestions.append(f"Composite index on ({status_field}, {date_field}) for filtered date queries")

         # Suggest indexes for fields commonly used in WHERE clauses
         common_filter_fields = ["slug", "name", "created_at", "updated_at"]
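Acting on the composite-index suggestion looks like this in a model definition; the model and field names are illustrative, while `models.Index` is standard Django:

from django.db import models


class Ride(models.Model):
    status = models.CharField(max_length=20)
    created_at = models.DateTimeField(auto_now_add=True)

    class Meta:
        indexes = [
            # Composite index on (status, created_at) for filtered date queries.
            models.Index(fields=["status", "created_at"]),
        ]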
@@ -340,9 +299,7 @@ def log_query_performance():
     return decorator


-def optimize_queryset_for_serialization(
-    queryset: QuerySet, fields: list[str]
-) -> QuerySet:
+def optimize_queryset_for_serialization(queryset: QuerySet, fields: list[str]) -> QuerySet:
     """
     Optimize a queryset for API serialization by only selecting needed fields

@@ -362,9 +319,7 @@ def optimize_queryset_for_serialization(
             field = opts.get_field(field_name)
             if isinstance(field, models.ForeignKey):
                 select_related_fields.append(field_name)
-            elif isinstance(
-                field, (models.ManyToManyField, models.reverse.ManyToManyRel)
-            ):
+            elif isinstance(field, models.ManyToManyField | models.reverse.ManyToManyRel):
                 prefetch_related_fields.append(field_name)
         except models.FieldDoesNotExist:
             # Field might be a property or method, skip optimization
@@ -421,7 +376,6 @@ def monitor_db_performance(operation_name: str):
             )
         else:
             logger.debug(
-                f"DB performance for {operation_name}: "
-                f"{duration:.3f}s, {total_queries} queries",
+                f"DB performance for {operation_name}: " f"{duration:.3f}s, {total_queries} queries",
                 extra=performance_data,
             )

@@ -4,6 +4,7 @@ Cloudflare Turnstile validation utilities.

 This module provides a function to validate Turnstile tokens
 on the server side before processing form submissions.
 """

 import requests
 from django.conf import settings

@@ -20,45 +21,41 @@ def validate_turnstile_token(token: str, ip: str = None) -> dict:
         dict with 'success' boolean and optional 'error' message
     """
     # Skip validation if configured (dev mode)
-    if getattr(settings, 'TURNSTILE_SKIP_VALIDATION', False):
-        return {'success': True}
+    if getattr(settings, "TURNSTILE_SKIP_VALIDATION", False):
+        return {"success": True}

-    secret = getattr(settings, 'TURNSTILE_SECRET', '')
+    secret = getattr(settings, "TURNSTILE_SECRET", "")
     if not secret:
-        return {'success': True}  # Skip if no secret configured
+        return {"success": True}  # Skip if no secret configured

     if not token:
-        return {'success': False, 'error': 'Captcha verification required'}
+        return {"success": False, "error": "Captcha verification required"}

     try:
         response = requests.post(
-            'https://challenges.cloudflare.com/turnstile/v0/siteverify',
+            "https://challenges.cloudflare.com/turnstile/v0/siteverify",
             data={
-                'secret': secret,
-                'response': token,
-                'remoteip': ip,
+                "secret": secret,
+                "response": token,
+                "remoteip": ip,
             },
-            timeout=10
+            timeout=10,
         )
         result = response.json()

-        if result.get('success'):
-            return {'success': True}
+        if result.get("success"):
+            return {"success": True}
         else:
-            error_codes = result.get('error-codes', [])
-            return {
-                'success': False,
-                'error': 'Captcha verification failed',
-                'error_codes': error_codes
-            }
+            error_codes = result.get("error-codes", [])
+            return {"success": False, "error": "Captcha verification failed", "error_codes": error_codes}
     except requests.RequestException:
         # Log error but don't block user on network issues
-        return {'success': True}  # Fail open to avoid blocking legitimate users
+        return {"success": True}  # Fail open to avoid blocking legitimate users


 def get_client_ip(request):
     """Extract client IP from request, handling proxies."""
-    x_forwarded_for = request.META.get('HTTP_X_FORWARDED_FOR')
+    x_forwarded_for = request.META.get("HTTP_X_FORWARDED_FOR")
     if x_forwarded_for:
-        return x_forwarded_for.split(',')[0].strip()
-    return request.META.get('REMOTE_ADDR')
+        return x_forwarded_for.split(",")[0].strip()
+    return request.META.get("REMOTE_ADDR")
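For context, a sketch of wiring the two helpers above into a form endpoint; `cf-turnstile-response` is the field name Cloudflare's widget posts, but the view itself is assumed:

from django.http import JsonResponse


def submit_form(request):
    token = request.POST.get("cf-turnstile-response", "")
    result = validate_turnstile_token(token, ip=get_client_ip(request))
    if not result["success"]:
        return JsonResponse({"error": result.get("error", "Captcha failed")}, status=400)
    # ... process the verified submission ...
    return JsonResponse({"status": "ok"})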
@@ -5,7 +5,6 @@ This module provides base view classes that implement common patterns
 such as automatic query optimization with select_related and prefetch_related.
 """

 from django.db.models import QuerySet
 from django.views.generic import DetailView, ListView

@@ -2,7 +2,6 @@
 Entity search views with fuzzy matching and authentication prompts.
 """

 import contextlib

 from rest_framework import status
@@ -67,9 +66,7 @@ class EntityFuzzySearchView(APIView):
         try:
             # Parse request data
             query = request.data.get("query", "").strip()
-            entity_types_raw = request.data.get(
-                "entity_types", ["park", "ride", "company"]
-            )
+            entity_types_raw = request.data.get("entity_types", ["park", "ride", "company"])
             include_suggestions = request.data.get("include_suggestions", True)

             # Validate query
@@ -105,9 +102,7 @@ class EntityFuzzySearchView(APIView):
                 "query": query,
                 "matches": [match.to_dict() for match in matches],
                 "user_authenticated": (
-                    request.user.is_authenticated
-                    if hasattr(request.user, "is_authenticated")
-                    else False
+                    request.user.is_authenticated if hasattr(request.user, "is_authenticated") else False
                 ),
             }

@@ -211,9 +206,7 @@ class EntityNotFoundView(APIView):
                 "context": context,
                 "matches": [match.to_dict() for match in matches],
                 "user_authenticated": (
-                    request.user.is_authenticated
-                    if hasattr(request.user, "is_authenticated")
-                    else False
+                    request.user.is_authenticated if hasattr(request.user, "is_authenticated") else False
                 ),
                 "has_matches": len(matches) > 0,
             }
@@ -267,9 +260,7 @@ class QuickEntitySuggestionView(APIView):
         limit = min(int(request.GET.get("limit", 5)), 10)  # Cap at 10

         if not query or len(query) < 2:
-            return Response(
-                {"suggestions": [], "query": query}, status=status.HTTP_200_OK
-            )
+            return Response({"suggestions": [], "query": query}, status=status.HTTP_200_OK)

         # Parse entity types
         entity_types = []
@@ -282,9 +273,7 @@ class QuickEntitySuggestionView(APIView):
             entity_types = [EntityType.PARK, EntityType.RIDE, EntityType.COMPANY]

         # Get fuzzy matches
-        matches, _ = entity_fuzzy_matcher.find_entity(
-            query=query, entity_types=entity_types, user=request.user
-        )
+        matches, _ = entity_fuzzy_matcher.find_entity(query=query, entity_types=entity_types, user=request.user)

         # Format as simple suggestions
         suggestions = []
@@ -313,9 +302,7 @@ class QuickEntitySuggestionView(APIView):


 # Utility function for other views to use
-def get_entity_suggestions(
-    query: str, entity_types: list[str] | None = None, user=None
-):
+def get_entity_suggestions(query: str, entity_types: list[str] | None = None, user=None):
     """
     Utility function for other Django views to get entity suggestions.

@@ -340,8 +327,6 @@ def get_entity_suggestions(
         if not parsed_types:
             parsed_types = [EntityType.PARK, EntityType.RIDE, EntityType.COMPANY]

-        return entity_fuzzy_matcher.find_entity(
-            query=query, entity_types=parsed_types, user=user
-        )
+        return entity_fuzzy_matcher.find_entity(query=query, entity_types=parsed_types, user=user)
     except Exception:
         return [], None

@@ -86,7 +86,7 @@ class MapAPIView(View):
                 return bounds
             return None
         except (ValueError, TypeError) as e:
-            raise ValidationError(f"Invalid bounds parameters: {e}")
+            raise ValidationError(f"Invalid bounds parameters: {e}") from None

     def _parse_pagination(self, request: HttpRequest) -> dict[str, int]:
         """Parse pagination parameters from request."""
@@ -125,11 +125,7 @@ class MapAPIView(View):
         if location_types_param:
             type_strings = location_types_param.split(",")
             valid_types = {lt.value for lt in LocationType}
-            filters.location_types = {
-                LocationType(t.strip())
-                for t in type_strings
-                if t.strip() in valid_types
-            }
+            filters.location_types = {LocationType(t.strip()) for t in type_strings if t.strip() in valid_types}

         # Park status
         park_status_param = request.GET.get("park_status")
@@ -199,7 +195,7 @@ class MapAPIView(View):
             )

         except (ValueError, TypeError) as e:
-            raise ValidationError(f"Invalid filter parameters: {e}")
+            raise ValidationError(f"Invalid filter parameters: {e}") from None

     def _parse_zoom_level(self, request: HttpRequest) -> int:
         """Parse zoom level from request with default."""
@@ -218,9 +214,7 @@ class MapAPIView(View):
         request: HttpRequest,
     ) -> dict[str, Any]:
         """Create paginated response with metadata."""
-        total_pages = (total_count + pagination["page_size"] - 1) // pagination[
-            "page_size"
-        ]
+        total_pages = (total_count + pagination["page_size"] - 1) // pagination["page_size"]

         # Build pagination URLs
         base_url = request.build_absolute_uri(request.path)
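The `total_pages` line uses the integer ceiling-division idiom; a standalone check of the arithmetic:

def ceil_div(total_count: int, page_size: int) -> int:
    # (a + b - 1) // b rounds up without touching floats.
    return (total_count + page_size - 1) // page_size


assert ceil_div(40, 20) == 2  # exact fit
assert ceil_div(41, 20) == 3  # one leftover item still needs a page
assert ceil_div(0, 20) == 0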
@@ -278,9 +272,7 @@ class MapAPIView(View):

         return JsonResponse(response_data, status=status)

-    def _success_response(
-        self, data: Any, message: str = None, metadata: dict[str, Any] = None
-    ) -> JsonResponse:
+    def _success_response(self, data: Any, message: str = None, metadata: dict[str, Any] = None) -> JsonResponse:
         """Return standardized success response."""
         response_data = {
             "status": "success",
@@ -397,9 +389,7 @@ class MapLocationDetailView(MapAPIView):
     """

     @method_decorator(cache_page(600))  # Cache for 10 minutes
-    def get(
-        self, request: HttpRequest, location_type: str, location_id: int
-    ) -> JsonResponse:
+    def get(self, request: HttpRequest, location_type: str, location_id: int) -> JsonResponse:
         """Get detailed information for a specific location."""
         try:
             # Validate location type
@@ -422,9 +412,7 @@ class MapLocationDetailView(MapAPIView):
                 )

             # Get location details
-            location = unified_map_service.get_location_details(
-                location_type, location_id
-            )
+            location = unified_map_service.get_location_details(location_type, location_id)

             if not location:
                 return self._error_response(
@@ -499,9 +487,7 @@ class MapSearchView(MapAPIView):
             try:
                 valid_types = {lt.value for lt in LocationType}
                 location_types = {
-                    LocationType(t.strip())
-                    for t in types_param.split(",")
-                    if t.strip() in valid_types
+                    LocationType(t.strip()) for t in types_param.split(",") if t.strip() in valid_types
                 }
             except ValueError:
                 return self._error_response(
@@ -569,9 +555,7 @@ class MapBoundsView(MapAPIView):
         # Parse required bounds
         bounds = self._parse_bounds(request)
         if not bounds:
-            return self._error_response(
-                "Bounds parameters required: north, south, east, west", 400
-            )
+            return self._error_response("Bounds parameters required: north, south, east, west", 400)

         # Parse optional filters
         location_types = None

@@ -74,9 +74,7 @@ class UniversalMapView(MapViewMixin, TemplateView):
         )

         # Handle initial bounds from query parameters
-        if all(
-            param in self.request.GET for param in ["north", "south", "east", "west"]
-        ):
+        if all(param in self.request.GET for param in ["north", "south", "east", "west"]):
             with contextlib.suppress(ValueError, TypeError):
                 context["initial_bounds"] = {
                     "north": float(self.request.GET["north"]),
@@ -243,9 +241,7 @@ class LocationSearchView(MapViewMixin, View):
         limit = min(20, max(5, int(request.GET.get("limit", "10"))))

         # Perform search
-        results = unified_map_service.search_locations(
-            query=query, location_types=location_types, limit=limit
-        )
+        results = unified_map_service.search_locations(query=query, location_types=location_types, limit=limit)

         return render(
             request,
@@ -285,11 +281,7 @@ class MapBoundsUpdateView(MapViewMixin, View):
             zoom_level = int(data.get("zoom", 10))
             location_types = None
             if "types" in data:
-                location_types = {
-                    LocationType(t)
-                    for t in data["types"]
-                    if t in [lt.value for lt in LocationType]
-                }
+                location_types = {LocationType(t) for t in data["types"] if t in [lt.value for lt in LocationType]}

             # Location types are used directly in the service call

@@ -324,9 +316,7 @@ class LocationDetailModalView(MapViewMixin, View):
     URL: /maps/htmx/location/<type>/<id>/
     """

-    def get(
-        self, request: HttpRequest, location_type: str, location_id: int
-    ) -> HttpResponse:
+    def get(self, request: HttpRequest, location_type: str, location_id: int) -> HttpResponse:
         """Return location detail modal content."""
         try:
             # Validate location type
@@ -338,9 +328,7 @@ class LocationDetailModalView(MapViewMixin, View):
             )

             # Get location details
-            location = unified_map_service.get_location_details(
-                location_type, location_id
-            )
+            location = unified_map_service.get_location_details(location_type, location_id)

             if not location:
                 return render(
@@ -356,9 +344,7 @@ class LocationDetailModalView(MapViewMixin, View):
             )

         except Exception as e:
-            return render(
-                request, "maps/partials/location_modal.html", {"error": str(e)}
-            )
+            return render(request, "maps/partials/location_modal.html", {"error": str(e)})


 class LocationListView(MapViewMixin, TemplateView):
@@ -392,9 +378,7 @@ class LocationListView(MapViewMixin, TemplateView):
         )

         # Get locations without clustering
-        map_response = unified_map_service.get_map_data(
-            filters=filters, cluster=False, use_cache=True
-        )
+        map_response = unified_map_service.get_map_data(filters=filters, cluster=False, use_cache=True)

         # Paginate results
         paginator = Paginator(map_response.locations, self.paginate_by)

@@ -83,13 +83,15 @@ class PerformanceDashboardView(TemplateView):
         try:
             client = cache._cache.get_client()
             info = client.info()
-            cache_stats.update({
-                "connected_clients": info.get("connected_clients"),
-                "used_memory_human": info.get("used_memory_human"),
-                "keyspace_hits": info.get("keyspace_hits", 0),
-                "keyspace_misses": info.get("keyspace_misses", 0),
-                "total_commands": info.get("total_commands_processed"),
-            })
+            cache_stats.update(
+                {
+                    "connected_clients": info.get("connected_clients"),
+                    "used_memory_human": info.get("used_memory_human"),
+                    "keyspace_hits": info.get("keyspace_hits", 0),
+                    "keyspace_misses": info.get("keyspace_misses", 0),
+                    "total_commands": info.get("total_commands_processed"),
+                }
+            )

             # Calculate hit rate
             hits = info.get("keyspace_hits", 0)
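The hit-rate computation that follows is cut off in this hunk; the standard formula from Redis's keyspace counters, sketched with a cold-cache guard:

hits = info.get("keyspace_hits", 0)
misses = info.get("keyspace_misses", 0)
total = hits + misses
# Avoid division by zero when the cache has served no reads yet.
cache_stats["hit_rate"] = round(hits / total * 100, 2) if total else 0.0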
@@ -127,8 +129,7 @@ class PerformanceDashboardView(TemplateView):
         # Get connection count (PostgreSQL specific)
         try:
             cursor.execute(
-                "SELECT count(*) FROM pg_stat_activity WHERE datname = %s;",
-                [db_settings.get("NAME")]
+                "SELECT count(*) FROM pg_stat_activity WHERE datname = %s;", [db_settings.get("NAME")]
             )
             stats["active_connections"] = cursor.fetchone()[0]
         except Exception:
@@ -244,16 +245,18 @@ class CacheStatsAPIView(View):
             client = cache._cache.get_client()
             info = client.info()

-            cache_info.update({
-                "used_memory": info.get("used_memory_human"),
-                "connected_clients": info.get("connected_clients"),
-                "keyspace_hits": info.get("keyspace_hits", 0),
-                "keyspace_misses": info.get("keyspace_misses", 0),
-                "expired_keys": info.get("expired_keys", 0),
-                "evicted_keys": info.get("evicted_keys", 0),
-                "total_connections_received": info.get("total_connections_received"),
-                "total_commands_processed": info.get("total_commands_processed"),
-            })
+            cache_info.update(
+                {
+                    "used_memory": info.get("used_memory_human"),
+                    "connected_clients": info.get("connected_clients"),
+                    "keyspace_hits": info.get("keyspace_hits", 0),
+                    "keyspace_misses": info.get("keyspace_misses", 0),
+                    "expired_keys": info.get("expired_keys", 0),
+                    "evicted_keys": info.get("evicted_keys", 0),
+                    "total_connections_received": info.get("total_connections_received"),
+                    "total_commands_processed": info.get("total_commands_processed"),
+                }
+            )

             # Calculate metrics
             hits = info.get("keyspace_hits", 0)
Some files were not shown because too many files have changed in this diff.