feat: Refactor rides app with unique constraints, mixins, and enhanced documentation

- Added migration to convert unique_together constraints to UniqueConstraint for RideModel.
- Introduced RideFormMixin for handling entity suggestions in ride forms.
- Created comprehensive code standards documentation outlining formatting, docstring requirements, complexity guidelines, and testing requirements.
- Established error handling guidelines with a structured exception hierarchy and best practices for API and view error handling.
- Documented view pattern guidelines, emphasizing the use of CBVs, FBVs, and ViewSets with examples.
- Implemented a benchmarking script for query performance analysis and optimization.
- Developed security documentation detailing measures, configurations, and a security checklist.
- Compiled a database optimization guide covering indexing strategies, query optimization patterns, and computed fields.
This commit is contained in:
pacnpal
2025-12-22 11:17:31 -05:00
parent 45d97b6e68
commit 2e35f8c5d9
71 changed files with 8036 additions and 1462 deletions

View File

@@ -1,7 +1,15 @@
"""
Management command to reset the database and create an admin user.
Security Note: This command uses a mix of raw SQL (for PostgreSQL-specific operations
like dropping all tables) and Django ORM (for creating users). The raw SQL operations
use quote_ident() for table/sequence names which is safe from SQL injection.
WARNING: This command is destructive and should only be used in development.
"""
from django.core.management.base import BaseCommand
from django.db import connection
from django.contrib.auth.hashers import make_password
import uuid
class Command(BaseCommand):
@@ -10,7 +18,8 @@ class Command(BaseCommand):
def handle(self, *args, **options):
self.stdout.write("Resetting database...")
# Drop all tables
# Drop all tables using PostgreSQL-specific operations
# Security: Using quote_ident() to safely quote table/sequence names
with connection.cursor() as cursor:
cursor.execute(
"""
@@ -21,7 +30,7 @@ class Command(BaseCommand):
SELECT tablename FROM pg_tables
WHERE schemaname = current_schema()
) LOOP
EXECUTE 'DROP TABLE IF EXISTS ' || \
EXECUTE 'DROP TABLE IF EXISTS ' ||
quote_ident(r.tablename) || ' CASCADE';
END LOOP;
END $$;
@@ -38,7 +47,7 @@ class Command(BaseCommand):
SELECT sequencename FROM pg_sequences
WHERE schemaname = current_schema()
) LOOP
EXECUTE 'ALTER SEQUENCE ' || \
EXECUTE 'ALTER SEQUENCE ' ||
quote_ident(r.sequencename) || ' RESTART WITH 1';
END LOOP;
END $$;
@@ -54,51 +63,25 @@ class Command(BaseCommand):
self.stdout.write("Migrations applied.")
# Create superuser using raw SQL
# Create superuser using Django ORM (safer than raw SQL)
try:
with connection.cursor() as cursor:
# Create user
user_id = str(uuid.uuid4())[:10]
cursor.execute(
"""
INSERT INTO accounts_user (
username, password, email, is_superuser, is_staff,
is_active, date_joined, user_id, first_name,
last_name, role, is_banned, ban_reason,
theme_preference
) VALUES (
'admin', %s, 'admin@thrillwiki.com', true, true,
true, NOW(), %s, '', '', 'SUPERUSER', false, '',
'light'
) RETURNING id;
""",
[make_password("admin"), user_id],
)
from apps.accounts.models import User, UserProfile
result = cursor.fetchone()
if result is None:
raise Exception("Failed to create user - no ID returned")
user_db_id = result[0]
# Security: Using Django ORM instead of raw SQL for user creation
user = User.objects.create_superuser(
username='admin',
email='admin@thrillwiki.com',
password='admin',
role='SUPERUSER',
)
# Create profile
profile_id = str(uuid.uuid4())[:10]
cursor.execute(
"""
INSERT INTO accounts_userprofile (
profile_id, display_name, pronouns, bio,
twitter, instagram, youtube, discord,
coaster_credits, dark_ride_credits,
flat_ride_credits, water_ride_credits,
user_id, avatar
) VALUES (
%s, 'Admin', 'they/them', 'ThrillWiki Administrator',
'', '', '', '',
0, 0, 0, 0,
%s, ''
);
""",
[profile_id, user_db_id],
)
# Create profile using ORM
UserProfile.objects.create(
user=user,
display_name='Admin',
pronouns='they/them',
bio='ThrillWiki Administrator',
)
self.stdout.write("Superuser created.")
except Exception as e:

View File

@@ -0,0 +1,40 @@
"""
Add performance indexes and constraints to User model.
This migration adds:
1. db_index=True to is_banned and role fields for faster filtering
2. Composite index on (is_banned, role) for common query patterns
3. CheckConstraint to ensure banned users have a ban_date set
"""
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('accounts', '0012_alter_toplist_category_and_more'),
]
operations = [
# Add db_index to is_banned field
migrations.AlterField(
model_name='user',
name='is_banned',
field=models.BooleanField(default=False, db_index=True),
),
# Add composite index for common query patterns
migrations.AddIndex(
model_name='user',
index=models.Index(fields=['is_banned', 'role'], name='accounts_user_banned_role_idx'),
),
# Add CheckConstraint for ban consistency
migrations.AddConstraint(
model_name='user',
constraint=models.CheckConstraint(
name='user_ban_consistency',
check=models.Q(is_banned=False) | models.Q(ban_date__isnull=False),
violation_error_message='Banned users must have a ban_date set'
),
),
]

View File

@@ -49,8 +49,9 @@ class User(AbstractUser):
domain="accounts",
max_length=10,
default="USER",
db_index=True,
)
is_banned = models.BooleanField(default=False)
is_banned = models.BooleanField(default=False, db_index=True)
ban_reason = models.TextField(blank=True)
ban_date = models.DateTimeField(null=True, blank=True)
pending_email = models.EmailField(blank=True, null=True)
@@ -127,6 +128,18 @@ class User(AbstractUser):
return profile.display_name
return self.username
class Meta:
indexes = [
models.Index(fields=['is_banned', 'role'], name='accounts_user_banned_role_idx'),
]
constraints = [
models.CheckConstraint(
name='user_ban_consistency',
check=models.Q(is_banned=False) | models.Q(ban_date__isnull=False),
violation_error_message='Banned users must have a ban_date set'
),
]
def save(self, *args, **kwargs):
if not self.user_id:
self.user_id = generate_random_id(User, "user_id")

View File

@@ -2,16 +2,281 @@
User management services for ThrillWiki.
This module contains services for user account management including
user deletion while preserving submissions.
user deletion while preserving submissions, password management,
and email change functionality.
Recent additions:
- AccountService: Handles password and email change operations
- UserDeletionService: Manages user deletion while preserving content
"""
from typing import Optional
from django.db import transaction
from django.utils import timezone
import logging
import re
from typing import Any, Dict, Optional
from django.conf import settings
from django.contrib.auth import update_session_auth_hash
from django.contrib.sites.models import Site
from django.contrib.sites.shortcuts import get_current_site
from django.db import transaction
from django.http import HttpRequest
from django.template.loader import render_to_string
from django.utils import timezone
from django.utils.crypto import get_random_string
from django_forwardemail.services import EmailService
from .models import User, UserProfile, UserDeletionRequest
from .models import EmailVerification, User, UserDeletionRequest, UserProfile
logger = logging.getLogger(__name__)
class AccountService:
"""Service for account management operations including password and email changes."""
@staticmethod
def validate_password(password: str) -> bool:
"""
Validate password meets requirements.
Args:
password: The password to validate
Returns:
True if password meets requirements, False otherwise
"""
return (
len(password) >= 8
and bool(re.search(r"[A-Z]", password))
and bool(re.search(r"[a-z]", password))
and bool(re.search(r"[0-9]", password))
)
@staticmethod
def change_password(
*,
user: User,
old_password: str,
new_password: str,
request: HttpRequest,
) -> Dict[str, Any]:
"""
Change user password with validation and notification.
Validates the old password, checks new password requirements,
updates the password, and sends a confirmation email.
Args:
user: The user whose password is being changed
old_password: Current password for verification
new_password: New password to set
request: HTTP request for session handling
Returns:
Dictionary with success status, message, and optional redirect URL:
{
'success': bool,
'message': str,
'redirect_url': Optional[str]
}
"""
# Verify old password
if not user.check_password(old_password):
logger.warning(
f"Password change failed: incorrect current password for user {user.id}"
)
return {
'success': False,
'message': "Current password is incorrect",
'redirect_url': None
}
# Validate new password
if not AccountService.validate_password(new_password):
return {
'success': False,
'message': "Password must be at least 8 characters and contain uppercase, lowercase, and numbers",
'redirect_url': None
}
# Update password
user.set_password(new_password)
user.save()
# Keep user logged in after password change
update_session_auth_hash(request, user)
# Send confirmation email
AccountService._send_password_change_confirmation(request, user)
logger.info(f"Password changed successfully for user {user.id}")
return {
'success': True,
'message': "Password changed successfully. Please check your email for confirmation.",
'redirect_url': None
}
@staticmethod
def _send_password_change_confirmation(request: HttpRequest, user: User) -> None:
"""Send password change confirmation email."""
site = get_current_site(request)
context = {
"user": user,
"site_name": site.name,
}
email_html = render_to_string(
"accounts/email/password_change_confirmation.html", context
)
try:
EmailService.send_email(
to=user.email,
subject="Password Changed Successfully",
text="Your password has been changed successfully.",
site=site,
html=email_html,
)
except Exception as e:
logger.error(f"Failed to send password change confirmation email: {e}")
@staticmethod
def initiate_email_change(
*,
user: User,
new_email: str,
request: HttpRequest,
) -> Dict[str, Any]:
"""
Initiate email change with verification.
Creates a verification token and sends a verification email
to the new email address.
Args:
user: The user changing their email
new_email: The new email address
request: HTTP request for site context
Returns:
Dictionary with success status and message:
{
'success': bool,
'message': str
}
"""
if not new_email:
return {
'success': False,
'message': "New email is required"
}
# Check if email is already in use
if User.objects.filter(email=new_email).exclude(id=user.id).exists():
return {
'success': False,
'message': "This email address is already in use"
}
# Generate verification token
token = get_random_string(64)
# Create or update email verification record
EmailVerification.objects.update_or_create(
user=user,
defaults={"token": token}
)
# Store pending email
user.pending_email = new_email
user.save()
# Send verification email
AccountService._send_email_verification(request, user, new_email, token)
logger.info(f"Email change initiated for user {user.id} to {new_email}")
return {
'success': True,
'message': "Verification email sent to your new email address"
}
@staticmethod
def _send_email_verification(
request: HttpRequest,
user: User,
new_email: str,
token: str
) -> None:
"""Send email verification for email change."""
from django.urls import reverse
site = get_current_site(request)
verification_url = reverse("verify_email", kwargs={"token": token})
context = {
"user": user,
"verification_url": verification_url,
"site_name": site.name,
}
email_html = render_to_string("accounts/email/verify_email.html", context)
try:
EmailService.send_email(
to=new_email,
subject="Verify your new email address",
text="Click the link to verify your new email address",
site=site,
html=email_html,
)
except Exception as e:
logger.error(f"Failed to send email verification: {e}")
@staticmethod
def verify_email_change(*, token: str) -> Dict[str, Any]:
"""
Verify email change token and update user email.
Args:
token: The verification token
Returns:
Dictionary with success status and message
"""
try:
verification = EmailVerification.objects.select_related("user").get(
token=token
)
except EmailVerification.DoesNotExist:
return {
'success': False,
'message': "Invalid or expired verification token"
}
user = verification.user
if not user.pending_email:
return {
'success': False,
'message': "No pending email change found"
}
# Update email
old_email = user.email
user.email = user.pending_email
user.pending_email = None
user.save()
# Delete verification record
verification.delete()
logger.info(f"Email changed for user {user.id} from {old_email} to {user.email}")
return {
'success': True,
'message': "Email address updated successfully"
}
class UserDeletionService:

View File

@@ -0,0 +1,101 @@
"""
Tests for model constraints and validators in the accounts app.
These tests verify that:
1. CheckConstraints raise appropriate errors
2. Validators work correctly
3. Business rules are enforced at the model level
"""
from django.test import TestCase
from django.db import IntegrityError
from django.core.exceptions import ValidationError
from django.utils import timezone
from apps.accounts.models import User
class UserConstraintTests(TestCase):
"""Tests for User model constraints."""
def test_banned_user_without_ban_date_raises_error(self):
"""Verify banned users must have a ban_date set."""
user = User(
username="testuser",
email="test@example.com",
is_banned=True,
ban_date=None, # This should violate the constraint
)
# The constraint should be enforced at database level
with self.assertRaises(IntegrityError):
user.save()
def test_banned_user_with_ban_date_saves_successfully(self):
"""Verify banned users with ban_date save successfully."""
user = User.objects.create_user(
username="testuser2",
email="test2@example.com",
password="testpass123",
is_banned=True,
ban_date=timezone.now(),
)
self.assertIsNotNone(user.pk)
self.assertTrue(user.is_banned)
self.assertIsNotNone(user.ban_date)
def test_non_banned_user_without_ban_date_saves_successfully(self):
"""Verify non-banned users can be saved without ban_date."""
user = User.objects.create_user(
username="testuser3",
email="test3@example.com",
password="testpass123",
is_banned=False,
ban_date=None,
)
self.assertIsNotNone(user.pk)
self.assertFalse(user.is_banned)
def test_user_id_is_auto_generated(self):
"""Verify user_id is automatically generated on save."""
user = User.objects.create_user(
username="testuser4",
email="test4@example.com",
password="testpass123",
)
self.assertIsNotNone(user.user_id)
self.assertTrue(len(user.user_id) >= 4)
def test_user_id_is_unique(self):
"""Verify user_id is unique across users."""
user1 = User.objects.create_user(
username="testuser5",
email="test5@example.com",
password="testpass123",
)
user2 = User.objects.create_user(
username="testuser6",
email="test6@example.com",
password="testpass123",
)
self.assertNotEqual(user1.user_id, user2.user_id)
class UserIndexTests(TestCase):
"""Tests for User model indexes."""
def test_is_banned_field_is_indexed(self):
"""Verify is_banned field has db_index=True."""
field = User._meta.get_field('is_banned')
self.assertTrue(field.db_index)
def test_role_field_is_indexed(self):
"""Verify role field has db_index=True."""
field = User._meta.get_field('role')
self.assertTrue(field.db_index)
def test_composite_index_exists(self):
"""Verify composite index on (is_banned, role) exists."""
indexes = User._meta.indexes
index_names = [idx.name for idx in indexes]
self.assertIn('accounts_user_banned_role_idx', index_names)

View File

@@ -1302,11 +1302,15 @@ def get_user_statistics(request):
user = request.user
# Calculate user statistics
# TODO(THRILLWIKI-104): Implement full user statistics tracking
from apps.parks.models import ParkReview
from apps.rides.models import RideReview
data = {
"parks_visited": 0, # TODO: Implement based on reviews/check-ins
"rides_ridden": 0, # TODO: Implement based on reviews/check-ins
"reviews_written": 0, # TODO: Count user's reviews
"photos_uploaded": 0, # TODO: Count user's photos
"parks_visited": ParkReview.objects.filter(user=user).values("park").distinct().count(),
"rides_ridden": RideReview.objects.filter(user=user).values("ride").distinct().count(),
"reviews_written": ParkReview.objects.filter(user=user).count() + RideReview.objects.filter(user=user).count(),
"photos_uploaded": 0, # TODO(THRILLWIKI-105): Implement photo counting
"top_lists_created": TopList.objects.filter(user=user).count(),
"member_since": user.date_joined,
"last_activity": user.last_login,

View File

@@ -7,8 +7,6 @@ from rest_framework.views import APIView
from rest_framework.response import Response
from rest_framework import status
from rest_framework.permissions import AllowAny
from django.views.decorators.csrf import csrf_exempt
from django.utils.decorators import method_decorator
from typing import Optional, List
from drf_spectacular.utils import extend_schema
@@ -260,12 +258,14 @@ class EntityNotFoundView(APIView):
)
@method_decorator(csrf_exempt, name="dispatch")
class QuickEntitySuggestionView(APIView):
"""
Lightweight endpoint for quick entity suggestions (e.g., autocomplete).
Migrated from apps.core.views.entity_search.QuickEntitySuggestionView
Security Note: This endpoint only accepts GET requests, which are inherently
safe from CSRF attacks. No CSRF exemption is needed.
"""
permission_classes = [AllowAny]

View File

@@ -12,7 +12,7 @@ from django.contrib.gis.geos import Polygon
from rest_framework.views import APIView
from rest_framework.response import Response
from rest_framework import status
from rest_framework.permissions import AllowAny
from rest_framework.permissions import AllowAny, IsAdminUser
from drf_spectacular.utils import (
extend_schema,
extend_schema_view,
@@ -306,7 +306,7 @@ class MapLocationsAPIView(APIView):
return {
"status": "success",
"locations": locations,
"clusters": [], # TODO: Implement clustering
"clusters": [], # TODO(THRILLWIKI-106): Implement map clustering algorithm
"bounds": self._calculate_bounds(locations),
"total_count": len(locations),
"clustered": params["cluster"],
@@ -471,7 +471,7 @@ class MapLocationDetailAPIView(APIView):
obj.opening_date.isoformat() if obj.opening_date else None
),
},
"nearby_locations": [], # TODO: Implement nearby locations
"nearby_locations": [], # TODO(THRILLWIKI-107): Implement nearby locations for parks
}
else: # ride
data = {
@@ -538,7 +538,7 @@ class MapLocationDetailAPIView(APIView):
obj.manufacturer.name if obj.manufacturer else None
),
},
"nearby_locations": [], # TODO: Implement nearby locations
"nearby_locations": [], # TODO(THRILLWIKI-107): Implement nearby locations for rides
}
return Response(
@@ -669,7 +669,7 @@ class MapSearchAPIView(APIView):
else ""
),
},
"relevance_score": 1.0, # TODO: Implement relevance scoring
"relevance_score": 1.0, # TODO(THRILLWIKI-108): Implement relevance scoring for search
}
)
@@ -722,7 +722,7 @@ class MapSearchAPIView(APIView):
else ""
),
},
"relevance_score": 1.0, # TODO: Implement relevance scoring
"relevance_score": 1.0, # TODO(THRILLWIKI-108): Implement relevance scoring for search
}
)
@@ -965,8 +965,8 @@ class MapStatsAPIView(APIView):
"total_locations": total_locations,
"parks_with_location": parks_with_location,
"rides_with_location": rides_with_location,
"cache_hits": 0, # TODO: Implement cache statistics
"cache_misses": 0, # TODO: Implement cache statistics
"cache_hits": 0, # TODO(THRILLWIKI-109): Implement cache statistics tracking
"cache_misses": 0, # TODO(THRILLWIKI-109): Implement cache statistics tracking
},
}
)
@@ -996,7 +996,7 @@ class MapStatsAPIView(APIView):
class MapCacheAPIView(APIView):
"""API endpoint for cache management (admin only)."""
permission_classes = [AllowAny] # TODO: Add admin permission check
permission_classes = [IsAdminUser] # Admin only
def delete(self, request: HttpRequest) -> Response:
"""Clear all map cache (admin only)."""

View File

@@ -5,34 +5,44 @@ This module contains consolidated park photo viewset for the centralized API str
Enhanced from rogue implementation to maintain full feature parity.
"""
from .serializers import (
ParkPhotoOutputSerializer,
ParkPhotoCreateInputSerializer,
ParkPhotoUpdateInputSerializer,
ParkPhotoListOutputSerializer,
ParkPhotoApprovalInputSerializer,
ParkPhotoStatsOutputSerializer,
)
from typing import Any, cast
import logging
from typing import Any, cast
from django.contrib.auth import get_user_model
from django.core.exceptions import PermissionDenied
from drf_spectacular.utils import extend_schema_view, extend_schema, OpenApiParameter
from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import OpenApiParameter, extend_schema, extend_schema_view
from rest_framework import status
from rest_framework.decorators import action
from rest_framework.exceptions import ValidationError
from rest_framework.permissions import IsAuthenticated, AllowAny
from rest_framework.permissions import AllowAny, IsAuthenticated
from rest_framework.response import Response
from rest_framework.views import APIView
from rest_framework.viewsets import ModelViewSet
from apps.parks.models import ParkPhoto, Park
from apps.core.exceptions import (
NotFoundError,
PermissionDeniedError,
ServiceError,
ValidationException,
)
from apps.core.utils.error_handling import ErrorHandler
from apps.parks.models import Park, ParkPhoto
from apps.parks.services import ParkMediaService
from django.contrib.auth import get_user_model
from apps.parks.services.hybrid_loader import smart_park_loader
UserModel = get_user_model()
from .serializers import (
HybridParkSerializer,
ParkPhotoApprovalInputSerializer,
ParkPhotoCreateInputSerializer,
ParkPhotoListOutputSerializer,
ParkPhotoOutputSerializer,
ParkPhotoStatsOutputSerializer,
ParkPhotoUpdateInputSerializer,
)
logger = logging.getLogger(__name__)
UserModel = get_user_model()
@extend_schema_view(
@@ -113,7 +123,7 @@ class ParkPhotoViewSet(ModelViewSet):
def get_permissions(self):
"""Set permissions based on action."""
if self.action in ['list', 'retrieve', 'stats']:
if self.action in ["list", "retrieve", "stats"]:
permission_classes = [AllowAny]
else:
permission_classes = [IsAuthenticated]
@@ -166,8 +176,11 @@ class ParkPhotoViewSet(ModelViewSet):
# Set the instance for the serializer response
serializer.instance = photo
except Exception as e:
logger.error(f"Error creating park photo: {e}")
except (ValidationException, ValidationError) as e:
logger.warning(f"Validation error creating park photo: {e}")
raise ValidationError(str(e))
except ServiceError as e:
logger.error(f"Service error creating park photo: {e}")
raise ValidationError(f"Failed to create photo: {str(e)}")
def perform_update(self, serializer):
@@ -190,8 +203,11 @@ class ParkPhotoViewSet(ModelViewSet):
# Remove is_primary from validated_data since service handles it
if "is_primary" in serializer.validated_data:
del serializer.validated_data["is_primary"]
except Exception as e:
logger.error(f"Error setting primary photo: {e}")
except (ValidationException, ValidationError) as e:
logger.warning(f"Validation error setting primary photo: {e}")
raise ValidationError(str(e))
except ServiceError as e:
logger.error(f"Service error setting primary photo: {e}")
raise ValidationError(f"Failed to set primary photo: {str(e)}")
def perform_destroy(self, instance):
@@ -205,25 +221,30 @@ class ParkPhotoViewSet(ModelViewSet):
"You can only delete your own photos or be an admin."
)
try:
# Delete from Cloudflare first if image exists
if instance.image:
try:
from django_cloudflareimages_toolkit.services import CloudflareImagesService
service = CloudflareImagesService()
service.delete_image(instance.image)
logger.info(
f"Successfully deleted park photo from Cloudflare: {instance.image.cloudflare_id}")
except Exception as e:
logger.error(
f"Failed to delete park photo from Cloudflare: {str(e)}")
# Continue with database deletion even if Cloudflare deletion fails
# Delete from Cloudflare first if image exists
if instance.image:
try:
from django_cloudflareimages_toolkit.services import (
CloudflareImagesService,
)
service = CloudflareImagesService()
service.delete_image(instance.image)
logger.info(
f"Successfully deleted park photo from Cloudflare: {instance.image.cloudflare_id}"
)
except ImportError:
logger.warning("CloudflareImagesService not available")
except ServiceError as e:
logger.error(f"Service error deleting from Cloudflare: {str(e)}")
# Continue with database deletion even if Cloudflare deletion fails
try:
ParkMediaService().delete_photo(
instance.id, deleted_by=cast(UserModel, self.request.user)
)
except Exception as e:
logger.error(f"Error deleting park photo: {e}")
except ServiceError as e:
logger.error(f"Service error deleting park photo: {e}")
raise ValidationError(f"Failed to delete photo: {str(e)}")
@extend_schema(
@@ -265,11 +286,18 @@ class ParkPhotoViewSet(ModelViewSet):
status=status.HTTP_200_OK,
)
except Exception as e:
logger.error(f"Error setting primary photo: {e}")
return Response(
{"error": f"Failed to set primary photo: {str(e)}"},
status=status.HTTP_400_BAD_REQUEST,
except (ValidationException, ValidationError) as e:
logger.warning(f"Validation error setting primary photo: {e}")
return ErrorHandler.handle_api_error(
e,
user_message="Failed to set primary photo",
status_code=status.HTTP_400_BAD_REQUEST,
)
except ServiceError as e:
return ErrorHandler.handle_api_error(
e,
user_message="Failed to set primary photo",
status_code=status.HTTP_400_BAD_REQUEST,
)
@extend_schema(
@@ -319,11 +347,18 @@ class ParkPhotoViewSet(ModelViewSet):
status=status.HTTP_200_OK,
)
except Exception as e:
logger.error(f"Error in bulk photo approval: {e}")
return Response(
{"error": f"Failed to update photos: {str(e)}"},
status=status.HTTP_400_BAD_REQUEST,
except (ValidationException, ValidationError) as e:
logger.warning(f"Validation error in bulk photo approval: {e}")
return ErrorHandler.handle_api_error(
e,
user_message="Failed to update photos",
status_code=status.HTTP_400_BAD_REQUEST,
)
except ServiceError as e:
return ErrorHandler.handle_api_error(
e,
user_message="Failed to update photos",
status_code=status.HTTP_400_BAD_REQUEST,
)
@extend_schema(
@@ -345,9 +380,10 @@ class ParkPhotoViewSet(ModelViewSet):
try:
park = Park.objects.get(pk=park_pk)
except Park.DoesNotExist:
return Response(
{"error": "Park not found."},
status=status.HTTP_404_NOT_FOUND,
return ErrorHandler.handle_api_error(
NotFoundError(f"Park with id {park_pk} not found"),
user_message="Park not found",
status_code=status.HTTP_404_NOT_FOUND,
)
try:
@@ -359,11 +395,11 @@ class ParkPhotoViewSet(ModelViewSet):
return Response(serializer.data, status=status.HTTP_200_OK)
except Exception as e:
logger.error(f"Error getting park photo stats: {e}")
return Response(
{"error": f"Failed to get photo statistics: {str(e)}"},
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
except ServiceError as e:
return ErrorHandler.handle_api_error(
e,
user_message="Failed to get photo statistics",
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
# Legacy compatibility action using the legacy set_primary logic
@@ -394,9 +430,19 @@ class ParkPhotoViewSet(ModelViewSet):
park_id=photo.park_id, photo_id=photo.id
)
return Response({"message": "Photo set as primary successfully."})
except Exception as e:
logger.error(f"Error in set_primary_photo: {str(e)}", exc_info=True)
return Response({"error": str(e)}, status=status.HTTP_400_BAD_REQUEST)
except (ValidationException, ValidationError) as e:
logger.warning(f"Validation error in set_primary_photo: {str(e)}")
return ErrorHandler.handle_api_error(
e,
user_message="Failed to set primary photo",
status_code=status.HTTP_400_BAD_REQUEST,
)
except ServiceError as e:
return ErrorHandler.handle_api_error(
e,
user_message="Failed to set primary photo",
status_code=status.HTTP_400_BAD_REQUEST,
)
@extend_schema(
summary="Save Cloudflare image as park photo",
@@ -442,60 +488,55 @@ class ParkPhotoViewSet(ModelViewSet):
from django.utils import timezone
# Always fetch the latest image data from Cloudflare API
# Get image details from Cloudflare API
service = CloudflareImagesService()
image_data = service.get_image(cloudflare_image_id)
if not image_data:
return ErrorHandler.handle_api_error(
NotFoundError("Image not found in Cloudflare"),
user_message="Image not found in Cloudflare",
status_code=status.HTTP_400_BAD_REQUEST,
)
# Try to find existing CloudflareImage record by cloudflare_id
cloudflare_image = None
try:
# Get image details from Cloudflare API
service = CloudflareImagesService()
image_data = service.get_image(cloudflare_image_id)
cloudflare_image = CloudflareImage.objects.get(
cloudflare_id=cloudflare_image_id
)
if not image_data:
return Response(
{"error": "Image not found in Cloudflare"},
status=status.HTTP_400_BAD_REQUEST,
)
# Update existing record with latest data from Cloudflare
cloudflare_image.status = "uploaded"
cloudflare_image.uploaded_at = timezone.now()
cloudflare_image.metadata = image_data.get("meta", {})
# Extract variants from nested result structure
cloudflare_image.variants = image_data.get("result", {}).get(
"variants", []
)
cloudflare_image.cloudflare_metadata = image_data
cloudflare_image.width = image_data.get("width")
cloudflare_image.height = image_data.get("height")
cloudflare_image.format = image_data.get("format", "")
cloudflare_image.save()
# Try to find existing CloudflareImage record by cloudflare_id
cloudflare_image = None
try:
cloudflare_image = CloudflareImage.objects.get(
cloudflare_id=cloudflare_image_id)
# Update existing record with latest data from Cloudflare
cloudflare_image.status = 'uploaded'
cloudflare_image.uploaded_at = timezone.now()
cloudflare_image.metadata = image_data.get('meta', {})
except CloudflareImage.DoesNotExist:
# Create new CloudflareImage record from API response
cloudflare_image = CloudflareImage.objects.create(
cloudflare_id=cloudflare_image_id,
user=request.user,
status="uploaded",
upload_url="", # Not needed for uploaded images
expires_at=timezone.now()
+ timezone.timedelta(days=365), # Set far future expiry
uploaded_at=timezone.now(),
metadata=image_data.get("meta", {}),
# Extract variants from nested result structure
cloudflare_image.variants = image_data.get(
'result', {}).get('variants', [])
cloudflare_image.cloudflare_metadata = image_data
cloudflare_image.width = image_data.get('width')
cloudflare_image.height = image_data.get('height')
cloudflare_image.format = image_data.get('format', '')
cloudflare_image.save()
except CloudflareImage.DoesNotExist:
# Create new CloudflareImage record from API response
cloudflare_image = CloudflareImage.objects.create(
cloudflare_id=cloudflare_image_id,
user=request.user,
status='uploaded',
upload_url='', # Not needed for uploaded images
expires_at=timezone.now() + timezone.timedelta(days=365), # Set far future expiry
uploaded_at=timezone.now(),
metadata=image_data.get('meta', {}),
# Extract variants from nested result structure
variants=image_data.get('result', {}).get('variants', []),
cloudflare_metadata=image_data,
width=image_data.get('width'),
height=image_data.get('height'),
format=image_data.get('format', ''),
)
except Exception as api_error:
logger.error(
f"Error fetching image from Cloudflare API: {str(api_error)}", exc_info=True)
return Response(
{"error": f"Failed to fetch image from Cloudflare: {str(api_error)}"},
status=status.HTTP_400_BAD_REQUEST,
variants=image_data.get("result", {}).get("variants", []),
cloudflare_metadata=image_data,
width=image_data.get("width"),
height=image_data.get("height"),
format=image_data.get("format", ""),
)
# Create the park photo with the CloudflareImage reference
@@ -516,25 +557,33 @@ class ParkPhotoViewSet(ModelViewSet):
ParkMediaService().set_primary_photo(
park_id=park.id, photo_id=photo.id
)
except Exception as e:
except ServiceError as e:
logger.error(f"Error setting primary photo: {e}")
# Don't fail the entire operation, just log the error
serializer = ParkPhotoOutputSerializer(photo, context={"request": request})
return Response(serializer.data, status=status.HTTP_201_CREATED)
except Exception as e:
logger.error(f"Error saving park photo: {e}")
return Response(
{"error": f"Failed to save photo: {str(e)}"},
status=status.HTTP_400_BAD_REQUEST,
except ImportError:
logger.error("CloudflareImagesService not available")
return ErrorHandler.handle_api_error(
ServiceError("Cloudflare Images service not available"),
user_message="Image upload service not available",
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
except (ValidationException, ValidationError) as e:
logger.warning(f"Validation error saving park photo: {e}")
return ErrorHandler.handle_api_error(
e,
user_message="Failed to save photo",
status_code=status.HTTP_400_BAD_REQUEST,
)
except ServiceError as e:
return ErrorHandler.handle_api_error(
e,
user_message="Failed to save photo",
status_code=status.HTTP_400_BAD_REQUEST,
)
from rest_framework.views import APIView
from rest_framework.permissions import AllowAny
from .serializers import HybridParkSerializer
from apps.parks.services.hybrid_loader import smart_park_loader
@extend_schema_view(
@@ -542,23 +591,79 @@ from apps.parks.services.hybrid_loader import smart_park_loader
summary="Get parks with hybrid filtering",
description="Retrieve parks with intelligent hybrid filtering strategy. Automatically chooses between client-side and server-side filtering based on data size.",
parameters=[
OpenApiParameter("status", OpenApiTypes.STR, description="Filter by park status (comma-separated for multiple)"),
OpenApiParameter("park_type", OpenApiTypes.STR, description="Filter by park type (comma-separated for multiple)"),
OpenApiParameter("country", OpenApiTypes.STR, description="Filter by country (comma-separated for multiple)"),
OpenApiParameter("state", OpenApiTypes.STR, description="Filter by state (comma-separated for multiple)"),
OpenApiParameter("opening_year_min", OpenApiTypes.INT, description="Minimum opening year"),
OpenApiParameter("opening_year_max", OpenApiTypes.INT, description="Maximum opening year"),
OpenApiParameter("size_min", OpenApiTypes.NUMBER, description="Minimum park size in acres"),
OpenApiParameter("size_max", OpenApiTypes.NUMBER, description="Maximum park size in acres"),
OpenApiParameter("rating_min", OpenApiTypes.NUMBER, description="Minimum average rating"),
OpenApiParameter("rating_max", OpenApiTypes.NUMBER, description="Maximum average rating"),
OpenApiParameter("ride_count_min", OpenApiTypes.INT, description="Minimum ride count"),
OpenApiParameter("ride_count_max", OpenApiTypes.INT, description="Maximum ride count"),
OpenApiParameter("coaster_count_min", OpenApiTypes.INT, description="Minimum coaster count"),
OpenApiParameter("coaster_count_max", OpenApiTypes.INT, description="Maximum coaster count"),
OpenApiParameter("operator", OpenApiTypes.STR, description="Filter by operator slug (comma-separated for multiple)"),
OpenApiParameter("search", OpenApiTypes.STR, description="Search query for park names, descriptions, locations, and operators"),
OpenApiParameter("offset", OpenApiTypes.INT, description="Offset for progressive loading (server-side pagination)"),
OpenApiParameter(
"status",
OpenApiTypes.STR,
description="Filter by park status (comma-separated for multiple)",
),
OpenApiParameter(
"park_type",
OpenApiTypes.STR,
description="Filter by park type (comma-separated for multiple)",
),
OpenApiParameter(
"country",
OpenApiTypes.STR,
description="Filter by country (comma-separated for multiple)",
),
OpenApiParameter(
"state",
OpenApiTypes.STR,
description="Filter by state (comma-separated for multiple)",
),
OpenApiParameter(
"opening_year_min", OpenApiTypes.INT, description="Minimum opening year"
),
OpenApiParameter(
"opening_year_max", OpenApiTypes.INT, description="Maximum opening year"
),
OpenApiParameter(
"size_min",
OpenApiTypes.NUMBER,
description="Minimum park size in acres",
),
OpenApiParameter(
"size_max",
OpenApiTypes.NUMBER,
description="Maximum park size in acres",
),
OpenApiParameter(
"rating_min", OpenApiTypes.NUMBER, description="Minimum average rating"
),
OpenApiParameter(
"rating_max", OpenApiTypes.NUMBER, description="Maximum average rating"
),
OpenApiParameter(
"ride_count_min", OpenApiTypes.INT, description="Minimum ride count"
),
OpenApiParameter(
"ride_count_max", OpenApiTypes.INT, description="Maximum ride count"
),
OpenApiParameter(
"coaster_count_min",
OpenApiTypes.INT,
description="Minimum coaster count",
),
OpenApiParameter(
"coaster_count_max",
OpenApiTypes.INT,
description="Maximum coaster count",
),
OpenApiParameter(
"operator",
OpenApiTypes.STR,
description="Filter by operator slug (comma-separated for multiple)",
),
OpenApiParameter(
"search",
OpenApiTypes.STR,
description="Search query for park names, descriptions, locations, and operators",
),
OpenApiParameter(
"offset",
OpenApiTypes.INT,
description="Offset for progressive loading (server-side pagination)",
),
],
responses={
200: {
@@ -570,31 +675,33 @@ from apps.parks.services.hybrid_loader import smart_park_loader
"properties": {
"parks": {
"type": "array",
"items": {"$ref": "#/components/schemas/HybridParkSerializer"}
"items": {
"$ref": "#/components/schemas/HybridParkSerializer"
},
},
"total_count": {"type": "integer"},
"strategy": {
"type": "string",
"enum": ["client_side", "server_side"],
"description": "Filtering strategy used"
"description": "Filtering strategy used",
},
"has_more": {
"type": "boolean",
"description": "Whether more data is available for progressive loading"
"description": "Whether more data is available for progressive loading",
},
"next_offset": {
"type": "integer",
"nullable": True,
"description": "Next offset for progressive loading"
"description": "Next offset for progressive loading",
},
"filter_metadata": {
"type": "object",
"description": "Available filter options and ranges"
}
}
"description": "Available filter options and ranges",
},
},
}
}
}
},
}
},
tags=["Parks"],
@@ -603,77 +710,83 @@ from apps.parks.services.hybrid_loader import smart_park_loader
class HybridParkAPIView(APIView):
"""
Hybrid Park API View with intelligent filtering strategy.
Automatically chooses between client-side and server-side filtering
based on data size and complexity. Provides progressive loading
for large datasets and complete data for smaller sets.
"""
permission_classes = [AllowAny]
def get(self, request):
"""Get parks with hybrid filtering strategy."""
# Extract filters from query parameters
filters = self._extract_filters(request.query_params)
# Check if this is a progressive load request
offset = request.query_params.get("offset")
if offset is not None:
try:
offset = int(offset)
except ValueError:
return ErrorHandler.handle_api_error(
ValidationException("Invalid offset parameter"),
user_message="Invalid offset parameter",
status_code=status.HTTP_400_BAD_REQUEST,
)
try:
# Extract filters from query parameters
filters = self._extract_filters(request.query_params)
# Check if this is a progressive load request
offset = request.query_params.get('offset')
if offset is not None:
try:
offset = int(offset)
# Get progressive load data
data = smart_park_loader.get_progressive_load(offset, filters)
except ValueError:
return Response(
{"error": "Invalid offset parameter"},
status=status.HTTP_400_BAD_REQUEST
)
# Get progressive load data
data = smart_park_loader.get_progressive_load(offset, filters)
else:
# Get initial load data
data = smart_park_loader.get_initial_load(filters)
# Serialize the parks data
serializer = HybridParkSerializer(data['parks'], many=True)
serializer = HybridParkSerializer(data["parks"], many=True)
# Prepare response
response_data = {
'parks': serializer.data,
'total_count': data['total_count'],
'strategy': data.get('strategy', 'server_side'),
'has_more': data.get('has_more', False),
'next_offset': data.get('next_offset'),
"parks": serializer.data,
"total_count": data["total_count"],
"strategy": data.get("strategy", "server_side"),
"has_more": data.get("has_more", False),
"next_offset": data.get("next_offset"),
}
# Include filter metadata for initial loads
if 'filter_metadata' in data:
response_data['filter_metadata'] = data['filter_metadata']
if "filter_metadata" in data:
response_data["filter_metadata"] = data["filter_metadata"]
return Response(response_data, status=status.HTTP_200_OK)
except Exception as e:
logger.error(f"Error in HybridParkAPIView: {e}")
return Response(
{"error": "Internal server error"},
status=status.HTTP_500_INTERNAL_SERVER_ERROR
except ServiceError as e:
return ErrorHandler.handle_api_error(
e,
user_message="Failed to load parks",
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
def _extract_filters(self, query_params):
"""Extract and parse filters from query parameters."""
filters = {}
# Handle comma-separated list parameters
list_params = ['status', 'park_type', 'country', 'state', 'operator']
list_params = ["status", "park_type", "country", "state", "operator"]
for param in list_params:
value = query_params.get(param)
if value:
filters[param] = [v.strip() for v in value.split(',') if v.strip()]
filters[param] = [v.strip() for v in value.split(",") if v.strip()]
# Handle integer parameters
int_params = [
'opening_year_min', 'opening_year_max',
'ride_count_min', 'ride_count_max',
'coaster_count_min', 'coaster_count_max'
"opening_year_min",
"opening_year_max",
"ride_count_min",
"ride_count_max",
"coaster_count_min",
"coaster_count_max",
]
for param in int_params:
value = query_params.get(param)
@@ -682,9 +795,9 @@ class HybridParkAPIView(APIView):
filters[param] = int(value)
except ValueError:
pass # Skip invalid integer values
# Handle float parameters
float_params = ['size_min', 'size_max', 'rating_min', 'rating_max']
float_params = ["size_min", "size_max", "rating_min", "rating_max"]
for param in float_params:
value = query_params.get(param)
if value:
@@ -692,12 +805,12 @@ class HybridParkAPIView(APIView):
filters[param] = float(value)
except ValueError:
pass # Skip invalid float values
# Handle search parameter
search = query_params.get('search')
search = query_params.get("search")
if search:
filters['search'] = search.strip()
filters["search"] = search.strip()
return filters
@@ -706,7 +819,11 @@ class HybridParkAPIView(APIView):
summary="Get park filter metadata",
description="Get available filter options and ranges for parks filtering.",
parameters=[
OpenApiParameter("scoped", OpenApiTypes.BOOL, description="Whether to scope metadata to current filters"),
OpenApiParameter(
"scoped",
OpenApiTypes.BOOL,
description="Whether to scope metadata to current filters",
),
],
responses={
200: {
@@ -719,21 +836,33 @@ class HybridParkAPIView(APIView):
"categorical": {
"type": "object",
"properties": {
"countries": {"type": "array", "items": {"type": "string"}},
"states": {"type": "array", "items": {"type": "string"}},
"park_types": {"type": "array", "items": {"type": "string"}},
"statuses": {"type": "array", "items": {"type": "string"}},
"countries": {
"type": "array",
"items": {"type": "string"},
},
"states": {
"type": "array",
"items": {"type": "string"},
},
"park_types": {
"type": "array",
"items": {"type": "string"},
},
"statuses": {
"type": "array",
"items": {"type": "string"},
},
"operators": {
"type": "array",
"items": {
"type": "object",
"properties": {
"name": {"type": "string"},
"slug": {"type": "string"}
}
}
}
}
"slug": {"type": "string"},
},
},
},
},
},
"ranges": {
"type": "object",
@@ -741,45 +870,75 @@ class HybridParkAPIView(APIView):
"opening_year": {
"type": "object",
"properties": {
"min": {"type": "integer", "nullable": True},
"max": {"type": "integer", "nullable": True}
}
"min": {
"type": "integer",
"nullable": True,
},
"max": {
"type": "integer",
"nullable": True,
},
},
},
"size_acres": {
"type": "object",
"properties": {
"min": {"type": "number", "nullable": True},
"max": {"type": "number", "nullable": True}
}
"min": {
"type": "number",
"nullable": True,
},
"max": {
"type": "number",
"nullable": True,
},
},
},
"average_rating": {
"type": "object",
"properties": {
"min": {"type": "number", "nullable": True},
"max": {"type": "number", "nullable": True}
}
"min": {
"type": "number",
"nullable": True,
},
"max": {
"type": "number",
"nullable": True,
},
},
},
"ride_count": {
"type": "object",
"properties": {
"min": {"type": "integer", "nullable": True},
"max": {"type": "integer", "nullable": True}
}
"min": {
"type": "integer",
"nullable": True,
},
"max": {
"type": "integer",
"nullable": True,
},
},
},
"coaster_count": {
"type": "object",
"properties": {
"min": {"type": "integer", "nullable": True},
"max": {"type": "integer", "nullable": True}
}
}
}
"min": {
"type": "integer",
"nullable": True,
},
"max": {
"type": "integer",
"nullable": True,
},
},
},
},
},
"total_count": {"type": "integer"}
}
"total_count": {"type": "integer"},
},
}
}
}
},
}
},
tags=["Parks"],
@@ -788,35 +947,34 @@ class HybridParkAPIView(APIView):
class ParkFilterMetadataAPIView(APIView):
"""
API view for getting park filter metadata.
Provides information about available filter options and ranges
to help build dynamic filter interfaces.
"""
permission_classes = [AllowAny]
def get(self, request):
"""Get park filter metadata."""
# Check if metadata should be scoped to current filters
scoped = request.query_params.get("scoped", "").lower() == "true"
filters = None
if scoped:
filters = self._extract_filters(request.query_params)
try:
# Check if metadata should be scoped to current filters
scoped = request.query_params.get('scoped', '').lower() == 'true'
filters = None
if scoped:
filters = self._extract_filters(request.query_params)
# Get filter metadata
metadata = smart_park_loader.get_filter_metadata(filters)
return Response(metadata, status=status.HTTP_200_OK)
except Exception as e:
logger.error(f"Error in ParkFilterMetadataAPIView: {e}")
return Response(
{"error": "Internal server error"},
status=status.HTTP_500_INTERNAL_SERVER_ERROR
except ServiceError as e:
return ErrorHandler.handle_api_error(
e,
user_message="Failed to get filter metadata",
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
def _extract_filters(self, query_params):
"""Extract and parse filters from query parameters."""
# Reuse the same filter extraction logic

File diff suppressed because it is too large Load Diff

View File

@@ -365,7 +365,7 @@ class MapLocationDetailSerializer(serializers.Serializer):
@extend_schema_field(serializers.ListField(child=serializers.DictField()))
def get_nearby_locations(self, obj) -> list:
"""Get nearby locations (placeholder for now)."""
# TODO: Implement nearby location logic
# TODO(THRILLWIKI-107): Implement nearby location logic using spatial queries
return []

View File

@@ -4,3 +4,12 @@ from django.apps import AppConfig
class CoreConfig(AppConfig):
default_auto_field = "django.db.models.BigAutoField"
name = "apps.core"
def ready(self):
"""
Application initialization.
Imports security checks to register them with Django's check framework.
"""
# Import security checks to register them
from . import checks # noqa: F401

372
backend/apps/core/checks.py Normal file
View File

@@ -0,0 +1,372 @@
"""
Django System Checks for Security Configuration.
This module implements Django system checks that validate security settings
at startup. These checks help catch security misconfigurations before
deployment.
Usage:
These checks run automatically when Django starts. They can also be run
manually with: python manage.py check --tag=security
Security checks included:
- SECRET_KEY validation (not default, sufficient entropy)
- DEBUG mode check (should be False in production)
- ALLOWED_HOSTS check (should be configured in production)
- Security headers validation
- HTTPS settings validation
- Cookie security settings
"""
import os
import re
from django.conf import settings
from django.core.checks import Error, Warning, register, Tags
# =============================================================================
# Secret Key Validation
# =============================================================================
@register(Tags.security)
def check_secret_key(app_configs, **kwargs):
"""
Check that SECRET_KEY is properly configured.
Validates:
- Key is not a known default/placeholder value
- Key has sufficient entropy (length and character variety)
"""
errors = []
secret_key = getattr(settings, 'SECRET_KEY', '')
# Check for empty or missing key
if not secret_key:
errors.append(
Error(
'SECRET_KEY is not set.',
hint='Set a strong, random SECRET_KEY in your environment.',
id='security.E001',
)
)
return errors
# Check for known insecure default values
insecure_defaults = [
'django-insecure',
'your-secret-key',
'change-me',
'changeme',
'secret',
'xxx',
'test',
'development',
'dev-key',
]
key_lower = secret_key.lower()
for default in insecure_defaults:
if default in key_lower:
errors.append(
Error(
f'SECRET_KEY appears to contain an insecure default value: "{default}"',
hint='Generate a new secret key using: python -c "from django.core.management.utils import get_random_secret_key; print(get_random_secret_key())"',
id='security.E002',
)
)
break
# Check minimum length (Django recommends at least 50 characters)
if len(secret_key) < 50:
errors.append(
Warning(
f'SECRET_KEY is only {len(secret_key)} characters long.',
hint='A secret key should be at least 50 characters for proper security.',
id='security.W001',
)
)
# Check for sufficient character variety
has_upper = bool(re.search(r'[A-Z]', secret_key))
has_lower = bool(re.search(r'[a-z]', secret_key))
has_digit = bool(re.search(r'[0-9]', secret_key))
has_special = bool(re.search(r'[!@#$%^&*()_+\-=\[\]{};\':"\\|,.<>\/?]', secret_key))
char_types = sum([has_upper, has_lower, has_digit, has_special])
if char_types < 3:
errors.append(
Warning(
'SECRET_KEY lacks character variety.',
hint='A good secret key should contain uppercase, lowercase, digits, and special characters.',
id='security.W002',
)
)
return errors
# =============================================================================
# Debug Mode Check
# =============================================================================
@register(Tags.security)
def check_debug_mode(app_configs, **kwargs):
"""
Check that DEBUG is False in production-like environments.
"""
errors = []
# Check if we're in a production-like environment
env = os.environ.get('DJANGO_SETTINGS_MODULE', '')
is_production = 'production' in env.lower() or 'prod' in env.lower()
if is_production and settings.DEBUG:
errors.append(
Error(
'DEBUG is True in what appears to be a production environment.',
hint='Set DEBUG=False in production settings.',
id='security.E003',
)
)
# Also check if DEBUG is True with ALLOWED_HOSTS configured
# (indicates possible production deployment with debug on)
if settings.DEBUG and settings.ALLOWED_HOSTS and '*' not in settings.ALLOWED_HOSTS:
if len(settings.ALLOWED_HOSTS) > 0 and 'localhost' not in settings.ALLOWED_HOSTS[0]:
errors.append(
Warning(
'DEBUG is True but ALLOWED_HOSTS contains non-localhost values.',
hint='This may indicate DEBUG is accidentally enabled in a deployed environment.',
id='security.W003',
)
)
return errors
# =============================================================================
# ALLOWED_HOSTS Check
# =============================================================================
@register(Tags.security)
def check_allowed_hosts(app_configs, **kwargs):
"""
Check ALLOWED_HOSTS configuration.
"""
errors = []
allowed_hosts = getattr(settings, 'ALLOWED_HOSTS', [])
if not settings.DEBUG:
# In non-debug mode, ALLOWED_HOSTS must be set
if not allowed_hosts:
errors.append(
Error(
'ALLOWED_HOSTS is empty but DEBUG is False.',
hint='Set ALLOWED_HOSTS to a list of allowed hostnames.',
id='security.E004',
)
)
elif '*' in allowed_hosts:
errors.append(
Error(
'ALLOWED_HOSTS contains "*" which allows all hosts.',
hint='Specify explicit hostnames instead of wildcards.',
id='security.E005',
)
)
return errors
# =============================================================================
# Security Headers Check
# =============================================================================
@register(Tags.security)
def check_security_headers(app_configs, **kwargs):
"""
Check that security headers are properly configured.
"""
errors = []
# Check X-Frame-Options
x_frame_options = getattr(settings, 'X_FRAME_OPTIONS', None)
if x_frame_options not in ('DENY', 'SAMEORIGIN'):
errors.append(
Warning(
f'X_FRAME_OPTIONS is set to "{x_frame_options}" or not set.',
hint='Set X_FRAME_OPTIONS to "DENY" or "SAMEORIGIN" to prevent clickjacking.',
id='security.W004',
)
)
# Check content type sniffing protection
if not getattr(settings, 'SECURE_CONTENT_TYPE_NOSNIFF', False):
errors.append(
Warning(
'SECURE_CONTENT_TYPE_NOSNIFF is not enabled.',
hint='Set SECURE_CONTENT_TYPE_NOSNIFF = True to prevent MIME type sniffing.',
id='security.W005',
)
)
# Check referrer policy
referrer_policy = getattr(settings, 'SECURE_REFERRER_POLICY', None)
if not referrer_policy:
errors.append(
Warning(
'SECURE_REFERRER_POLICY is not set.',
hint='Set SECURE_REFERRER_POLICY to control referrer header behavior.',
id='security.W006',
)
)
return errors
# =============================================================================
# HTTPS Settings Check
# =============================================================================
@register(Tags.security)
def check_https_settings(app_configs, **kwargs):
"""
Check HTTPS-related security settings for production.
"""
errors = []
# Skip these checks in debug mode
if settings.DEBUG:
return errors
# Check SSL redirect
if not getattr(settings, 'SECURE_SSL_REDIRECT', False):
errors.append(
Warning(
'SECURE_SSL_REDIRECT is not enabled.',
hint='Set SECURE_SSL_REDIRECT = True to redirect HTTP to HTTPS.',
id='security.W007',
)
)
# Check HSTS settings
hsts_seconds = getattr(settings, 'SECURE_HSTS_SECONDS', 0)
if hsts_seconds < 31536000: # Less than 1 year
errors.append(
Warning(
f'SECURE_HSTS_SECONDS is {hsts_seconds} (less than 1 year).',
hint='Set SECURE_HSTS_SECONDS to at least 31536000 (1 year) for HSTS preload eligibility.',
id='security.W008',
)
)
if not getattr(settings, 'SECURE_HSTS_INCLUDE_SUBDOMAINS', False):
errors.append(
Warning(
'SECURE_HSTS_INCLUDE_SUBDOMAINS is not enabled.',
hint='Set SECURE_HSTS_INCLUDE_SUBDOMAINS = True to include all subdomains in HSTS.',
id='security.W009',
)
)
return errors
# =============================================================================
# Cookie Security Check
# =============================================================================
@register(Tags.security)
def check_cookie_security(app_configs, **kwargs):
"""
Check cookie security settings for production.
"""
errors = []
# Skip in debug mode
if settings.DEBUG:
return errors
# Check session cookie security
if not getattr(settings, 'SESSION_COOKIE_SECURE', False):
errors.append(
Warning(
'SESSION_COOKIE_SECURE is not enabled.',
hint='Set SESSION_COOKIE_SECURE = True to only send session cookies over HTTPS.',
id='security.W010',
)
)
if not getattr(settings, 'SESSION_COOKIE_HTTPONLY', True):
errors.append(
Warning(
'SESSION_COOKIE_HTTPONLY is disabled.',
hint='Set SESSION_COOKIE_HTTPONLY = True to prevent JavaScript access to session cookies.',
id='security.W011',
)
)
# Check CSRF cookie security
if not getattr(settings, 'CSRF_COOKIE_SECURE', False):
errors.append(
Warning(
'CSRF_COOKIE_SECURE is not enabled.',
hint='Set CSRF_COOKIE_SECURE = True to only send CSRF cookies over HTTPS.',
id='security.W012',
)
)
# Check SameSite attributes
session_samesite = getattr(settings, 'SESSION_COOKIE_SAMESITE', 'Lax')
if session_samesite not in ('Strict', 'Lax'):
errors.append(
Warning(
f'SESSION_COOKIE_SAMESITE is set to "{session_samesite}".',
hint='Set SESSION_COOKIE_SAMESITE to "Strict" or "Lax" for CSRF protection.',
id='security.W013',
)
)
return errors
# =============================================================================
# Database Security Check
# =============================================================================
@register(Tags.security)
def check_database_security(app_configs, **kwargs):
"""
Check database connection security settings.
"""
errors = []
# Skip in debug mode
if settings.DEBUG:
return errors
databases = getattr(settings, 'DATABASES', {})
default_db = databases.get('default', {})
# Check for empty password
if not default_db.get('PASSWORD') and default_db.get('ENGINE', '').endswith('postgresql'):
errors.append(
Warning(
'Database password is empty.',
hint='Set a strong password for database authentication.',
id='security.W014',
)
)
# Check for SSL mode in PostgreSQL
options = default_db.get('OPTIONS', {})
if 'sslmode' not in str(options) and default_db.get('ENGINE', '').endswith('postgresql'):
errors.append(
Warning(
'Database SSL mode is not explicitly configured.',
hint='Consider setting sslmode in database OPTIONS for encrypted connections.',
id='security.W015',
)
)
return errors

View File

@@ -65,6 +65,14 @@ class BusinessLogicError(ThrillWikiException):
status_code = 400
class ServiceError(ThrillWikiException):
"""Raised when a service operation fails."""
default_message = "Service operation failed"
error_code = "SERVICE_ERROR"
status_code = 500
class ExternalServiceError(ThrillWikiException):
"""Raised when external service calls fail."""

View File

@@ -0,0 +1,240 @@
"""
Security Audit Management Command.
Runs comprehensive security checks on the Django application and generates
a security audit report.
Usage:
python manage.py security_audit
python manage.py security_audit --output report.txt
python manage.py security_audit --verbose
"""
from django.core.management.base import BaseCommand
from django.core.checks import registry, Tags
from django.conf import settings
class Command(BaseCommand):
help = 'Run security audit and generate a report'
def add_arguments(self, parser):
parser.add_argument(
'--output',
type=str,
help='Output file for the security report',
)
parser.add_argument(
'--verbose',
action='store_true',
help='Show detailed information for each check',
)
def handle(self, *args, **options):
self.verbose = options.get('verbose', False)
output_file = options.get('output')
report_lines = []
self.log("=" * 60, report_lines)
self.log("ThrillWiki Security Audit Report", report_lines)
self.log("=" * 60, report_lines)
self.log("", report_lines)
# Run Django's built-in security checks
self.log("Running Django Security Checks...", report_lines)
self.log("-" * 40, report_lines)
self.run_django_checks(report_lines)
# Run custom configuration checks
self.log("", report_lines)
self.log("Configuration Analysis...", report_lines)
self.log("-" * 40, report_lines)
self.check_configuration(report_lines)
# Run middleware checks
self.log("", report_lines)
self.log("Middleware Analysis...", report_lines)
self.log("-" * 40, report_lines)
self.check_middleware(report_lines)
# Summary
self.log("", report_lines)
self.log("=" * 60, report_lines)
self.log("Audit Complete", report_lines)
self.log("=" * 60, report_lines)
# Write to file if specified
if output_file:
with open(output_file, 'w') as f:
f.write('\n'.join(report_lines))
self.stdout.write(
self.style.SUCCESS(f'\nReport saved to: {output_file}')
)
def log(self, message, report_lines):
"""Log message to both stdout and report."""
self.stdout.write(message)
report_lines.append(message)
def run_django_checks(self, report_lines):
"""Run Django's security checks."""
errors = registry.run_checks(tags=[Tags.security])
if not errors:
self.log(
self.style.SUCCESS(" ✓ All Django security checks passed"),
report_lines
)
else:
for error in errors:
if error.is_serious():
prefix = self.style.ERROR(" ✗ ERROR")
else:
prefix = self.style.WARNING(" ! WARNING")
self.log(f"{prefix}: {error.msg}", report_lines)
if error.hint and self.verbose:
self.log(f" Hint: {error.hint}", report_lines)
def check_configuration(self, report_lines):
"""Check various configuration settings."""
checks = [
('DEBUG mode', not settings.DEBUG, 'DEBUG should be False'),
(
'SECRET_KEY length',
len(settings.SECRET_KEY) >= 50,
f'Length: {len(settings.SECRET_KEY)}'
),
(
'ALLOWED_HOSTS',
bool(settings.ALLOWED_HOSTS) and '*' not in settings.ALLOWED_HOSTS,
str(settings.ALLOWED_HOSTS)
),
(
'CSRF_TRUSTED_ORIGINS',
bool(getattr(settings, 'CSRF_TRUSTED_ORIGINS', [])),
str(getattr(settings, 'CSRF_TRUSTED_ORIGINS', []))
),
(
'X_FRAME_OPTIONS',
getattr(settings, 'X_FRAME_OPTIONS', '') in ('DENY', 'SAMEORIGIN'),
str(getattr(settings, 'X_FRAME_OPTIONS', 'Not set'))
),
(
'SECURE_CONTENT_TYPE_NOSNIFF',
getattr(settings, 'SECURE_CONTENT_TYPE_NOSNIFF', False),
str(getattr(settings, 'SECURE_CONTENT_TYPE_NOSNIFF', False))
),
(
'SECURE_BROWSER_XSS_FILTER',
getattr(settings, 'SECURE_BROWSER_XSS_FILTER', False),
str(getattr(settings, 'SECURE_BROWSER_XSS_FILTER', False))
),
(
'SESSION_COOKIE_HTTPONLY',
getattr(settings, 'SESSION_COOKIE_HTTPONLY', True),
str(getattr(settings, 'SESSION_COOKIE_HTTPONLY', 'Not set'))
),
(
'CSRF_COOKIE_HTTPONLY',
getattr(settings, 'CSRF_COOKIE_HTTPONLY', True),
str(getattr(settings, 'CSRF_COOKIE_HTTPONLY', 'Not set'))
),
]
# Production-only checks
if not settings.DEBUG:
checks.extend([
(
'SECURE_SSL_REDIRECT',
getattr(settings, 'SECURE_SSL_REDIRECT', False),
str(getattr(settings, 'SECURE_SSL_REDIRECT', False))
),
(
'SESSION_COOKIE_SECURE',
getattr(settings, 'SESSION_COOKIE_SECURE', False),
str(getattr(settings, 'SESSION_COOKIE_SECURE', False))
),
(
'CSRF_COOKIE_SECURE',
getattr(settings, 'CSRF_COOKIE_SECURE', False),
str(getattr(settings, 'CSRF_COOKIE_SECURE', False))
),
(
'SECURE_HSTS_SECONDS',
getattr(settings, 'SECURE_HSTS_SECONDS', 0) >= 31536000,
str(getattr(settings, 'SECURE_HSTS_SECONDS', 0))
),
])
for name, is_secure, value in checks:
if is_secure:
status = self.style.SUCCESS("")
else:
status = self.style.WARNING("!")
msg = f" {status} {name}"
if self.verbose:
msg += f" ({value})"
self.log(msg, report_lines)
def check_middleware(self, report_lines):
"""Check security-related middleware is properly configured."""
middleware = getattr(settings, 'MIDDLEWARE', [])
required_middleware = [
('django.middleware.security.SecurityMiddleware', 'SecurityMiddleware'),
('django.middleware.csrf.CsrfViewMiddleware', 'CSRF Middleware'),
('django.middleware.clickjacking.XFrameOptionsMiddleware', 'X-Frame-Options'),
]
custom_security_middleware = [
('apps.core.middleware.security_headers.SecurityHeadersMiddleware', 'Security Headers'),
('apps.core.middleware.rate_limiting.AuthRateLimitMiddleware', 'Rate Limiting'),
]
# Check required middleware
for mw_path, mw_name in required_middleware:
if mw_path in middleware:
self.log(
f" {self.style.SUCCESS('')} {mw_name} is enabled",
report_lines
)
else:
self.log(
f" {self.style.ERROR('')} {mw_name} is NOT enabled",
report_lines
)
# Check custom security middleware
for mw_path, mw_name in custom_security_middleware:
if mw_path in middleware:
self.log(
f" {self.style.SUCCESS('')} {mw_name} is enabled",
report_lines
)
else:
self.log(
f" {self.style.WARNING('!')} {mw_name} is not enabled (optional)",
report_lines
)
# Check middleware order
try:
security_idx = middleware.index('django.middleware.security.SecurityMiddleware')
session_idx = middleware.index('django.contrib.sessions.middleware.SessionMiddleware')
if security_idx < session_idx:
self.log(
f" {self.style.SUCCESS('')} Middleware ordering is correct",
report_lines
)
else:
self.log(
f" {self.style.WARNING('!')} SecurityMiddleware should come before SessionMiddleware",
report_lines
)
except ValueError:
pass # Middleware not found, already reported above

View File

@@ -0,0 +1,253 @@
"""
Rate Limiting Middleware for ThrillWiki.
This middleware provides rate limiting for authentication endpoints to prevent
brute force attacks, credential stuffing, and account enumeration.
Security Note:
Rate limiting is applied at the IP level and user level (if authenticated).
Limits are configurable and should be adjusted based on actual usage patterns.
Usage:
Add 'apps.core.middleware.rate_limiting.AuthRateLimitMiddleware'
to MIDDLEWARE in settings.py.
"""
import logging
from typing import Callable, Optional, Tuple
from django.core.cache import cache
from django.http import HttpRequest, HttpResponse, JsonResponse
from django.conf import settings
logger = logging.getLogger(__name__)
class AuthRateLimitMiddleware:
"""
Middleware that rate limits authentication-related endpoints.
Protects against:
- Brute force login attacks
- Password reset abuse
- Account enumeration through timing attacks
"""
# Endpoints to rate limit
RATE_LIMITED_PATHS = {
# Login endpoints
'/api/v1/auth/login/': {'per_minute': 5, 'per_hour': 30, 'per_day': 100},
'/accounts/login/': {'per_minute': 5, 'per_hour': 30, 'per_day': 100},
# Signup endpoints
'/api/v1/auth/signup/': {'per_minute': 3, 'per_hour': 10, 'per_day': 20},
'/accounts/signup/': {'per_minute': 3, 'per_hour': 10, 'per_day': 20},
# Password reset endpoints
'/api/v1/auth/password-reset/': {'per_minute': 2, 'per_hour': 5, 'per_day': 10},
'/accounts/password/reset/': {'per_minute': 2, 'per_hour': 5, 'per_day': 10},
# Token endpoints
'/api/v1/auth/token/': {'per_minute': 10, 'per_hour': 60, 'per_day': 200},
'/api/v1/auth/token/refresh/': {'per_minute': 20, 'per_hour': 120, 'per_day': 500},
}
def __init__(self, get_response: Callable[[HttpRequest], HttpResponse]):
self.get_response = get_response
def __call__(self, request: HttpRequest) -> HttpResponse:
# Only rate limit POST requests to auth endpoints
if request.method != 'POST':
return self.get_response(request)
# Check if this path should be rate limited
limits = self._get_rate_limits(request.path)
if not limits:
return self.get_response(request)
# Get client identifier (IP address)
client_ip = self._get_client_ip(request)
# Check rate limits
is_allowed, message = self._check_rate_limits(
client_ip, request.path, limits
)
if not is_allowed:
logger.warning(
f"Rate limit exceeded for {client_ip} on {request.path}"
)
return self._rate_limit_response(message)
# Process request
response = self.get_response(request)
# Only increment counter for failed auth attempts (non-2xx responses)
if response.status_code >= 400:
self._increment_counters(client_ip, request.path)
return response
def _get_rate_limits(self, path: str) -> Optional[dict]:
"""Get rate limits for a path, if any."""
# Exact match
if path in self.RATE_LIMITED_PATHS:
return self.RATE_LIMITED_PATHS[path]
# Prefix match (for paths with trailing slashes)
path_without_slash = path.rstrip('/')
for limited_path, limits in self.RATE_LIMITED_PATHS.items():
if path_without_slash == limited_path.rstrip('/'):
return limits
return None
def _get_client_ip(self, request: HttpRequest) -> str:
"""
Get the client's IP address from the request.
Handles common proxy headers (X-Forwarded-For, X-Real-IP).
"""
# Check for forwarded headers (set by reverse proxies)
x_forwarded_for = request.META.get('HTTP_X_FORWARDED_FOR')
if x_forwarded_for:
# Take the first IP in the chain (client IP)
return x_forwarded_for.split(',')[0].strip()
x_real_ip = request.META.get('HTTP_X_REAL_IP')
if x_real_ip:
return x_real_ip
return request.META.get('REMOTE_ADDR', 'unknown')
def _check_rate_limits(
self,
client_ip: str,
path: str,
limits: dict
) -> Tuple[bool, str]:
"""
Check if the client has exceeded rate limits.
Returns:
Tuple of (is_allowed, reason_if_blocked)
"""
# Create a safe cache key from path
path_key = path.replace('/', '_').strip('_')
# Check per-minute limit
minute_key = f"auth_rate:{client_ip}:{path_key}:minute"
minute_count = cache.get(minute_key, 0)
if minute_count >= limits.get('per_minute', 10):
return False, "Too many requests. Please wait a minute before trying again."
# Check per-hour limit
hour_key = f"auth_rate:{client_ip}:{path_key}:hour"
hour_count = cache.get(hour_key, 0)
if hour_count >= limits.get('per_hour', 60):
return False, "Too many requests. Please try again later."
# Check per-day limit
day_key = f"auth_rate:{client_ip}:{path_key}:day"
day_count = cache.get(day_key, 0)
if day_count >= limits.get('per_day', 200):
return False, "Daily limit exceeded. Please try again tomorrow."
return True, ""
def _increment_counters(self, client_ip: str, path: str) -> None:
"""Increment rate limit counters."""
path_key = path.replace('/', '_').strip('_')
# Increment per-minute counter
minute_key = f"auth_rate:{client_ip}:{path_key}:minute"
try:
cache.incr(minute_key)
except ValueError:
cache.set(minute_key, 1, 60)
# Increment per-hour counter
hour_key = f"auth_rate:{client_ip}:{path_key}:hour"
try:
cache.incr(hour_key)
except ValueError:
cache.set(hour_key, 1, 3600)
# Increment per-day counter
day_key = f"auth_rate:{client_ip}:{path_key}:day"
try:
cache.incr(day_key)
except ValueError:
cache.set(day_key, 1, 86400)
def _rate_limit_response(self, message: str) -> JsonResponse:
"""Generate a rate limit exceeded response."""
return JsonResponse(
{
'error': message,
'code': 'RATE_LIMIT_EXCEEDED',
},
status=429, # Too Many Requests
)
class SecurityEventLogger:
"""
Utility class for logging security-relevant events.
Use this to log:
- Failed authentication attempts
- Permission denied events
- Suspicious activity
"""
@staticmethod
def log_failed_login(
request: HttpRequest,
username: str,
reason: str = "Invalid credentials"
) -> None:
"""Log a failed login attempt."""
client_ip = AuthRateLimitMiddleware._get_client_ip(
AuthRateLimitMiddleware, request
)
logger.warning(
f"Failed login attempt - IP: {client_ip}, Username: {username}, "
f"Reason: {reason}, User-Agent: {request.META.get('HTTP_USER_AGENT', 'unknown')}"
)
@staticmethod
def log_permission_denied(
request: HttpRequest,
resource: str,
action: str = "access"
) -> None:
"""Log a permission denied event."""
client_ip = AuthRateLimitMiddleware._get_client_ip(
AuthRateLimitMiddleware, request
)
user = getattr(request, 'user', None)
username = user.username if user and user.is_authenticated else 'anonymous'
logger.warning(
f"Permission denied - IP: {client_ip}, User: {username}, "
f"Resource: {resource}, Action: {action}"
)
@staticmethod
def log_suspicious_activity(
request: HttpRequest,
activity_type: str,
details: str = ""
) -> None:
"""Log suspicious activity."""
client_ip = AuthRateLimitMiddleware._get_client_ip(
AuthRateLimitMiddleware, request
)
user = getattr(request, 'user', None)
username = user.username if user and user.is_authenticated else 'anonymous'
logger.error(
f"Suspicious activity detected - Type: {activity_type}, "
f"IP: {client_ip}, User: {username}, Details: {details}"
)

View File

@@ -114,18 +114,52 @@ class RequestLoggingMiddleware(MiddlewareMixin):
return response
# Sensitive field patterns that should be masked in logs
# Security: Comprehensive list of sensitive data patterns
SENSITIVE_PATTERNS = [
'password',
'passwd',
'pwd',
'token',
'secret',
'key',
'api_key',
'apikey',
'auth',
'authorization',
'credential',
'ssn',
'social_security',
'credit_card',
'creditcard',
'card_number',
'cvv',
'cvc',
'pin',
'access_token',
'refresh_token',
'jwt',
'session',
'cookie',
'private',
]
def _safe_log_data(self, data):
"""Safely log data, truncating if too large and masking sensitive fields."""
"""
Safely log data, truncating if too large and masking sensitive fields.
Security measures:
- Masks all sensitive field names
- Masks email addresses (shows only domain)
- Truncates long values to prevent log flooding
- Recursively processes nested dictionaries and lists
"""
try:
# Convert to string representation
if isinstance(data, dict):
# Mask sensitive fields
safe_data = {}
for key, value in data.items():
if any(sensitive in key.lower() for sensitive in ['password', 'token', 'secret', 'key']):
safe_data[key] = '***MASKED***'
else:
safe_data[key] = value
safe_data = self._mask_sensitive_dict(data)
data_str = json.dumps(safe_data, indent=2, default=str)
elif isinstance(data, list):
safe_data = [self._mask_sensitive_value(item) for item in data]
data_str = json.dumps(safe_data, indent=2, default=str)
else:
data_str = json.dumps(data, indent=2, default=str)
@@ -136,3 +170,37 @@ class RequestLoggingMiddleware(MiddlewareMixin):
return data_str
except Exception:
return str(data)[:500] + '...[ERROR_LOGGING]'
def _mask_sensitive_dict(self, data, depth=0):
"""Recursively mask sensitive fields in a dictionary."""
if depth > 5: # Prevent infinite recursion
return '***DEPTH_LIMIT***'
safe_data = {}
for key, value in data.items():
key_lower = str(key).lower()
# Check if key contains any sensitive pattern
if any(pattern in key_lower for pattern in self.SENSITIVE_PATTERNS):
safe_data[key] = '***MASKED***'
else:
safe_data[key] = self._mask_sensitive_value(value, depth)
return safe_data
def _mask_sensitive_value(self, value, depth=0):
"""Mask a single value, handling different types."""
if isinstance(value, dict):
return self._mask_sensitive_dict(value, depth + 1)
elif isinstance(value, list):
return [self._mask_sensitive_value(item, depth + 1) for item in value[:10]] # Limit list items
elif isinstance(value, str):
# Mask email addresses (show only domain)
if '@' in value and '.' in value.split('@')[-1]:
parts = value.split('@')
if len(parts) == 2:
return f"***@{parts[1]}"
# Truncate long strings
if len(value) > 200:
return value[:200] + '...[TRUNCATED]'
return value

View File

@@ -0,0 +1,196 @@
"""
Security Headers Middleware for ThrillWiki.
This middleware adds additional security headers to all HTTP responses,
providing defense-in-depth against common web vulnerabilities.
Headers added:
- Content-Security-Policy: Controls resource loading to prevent XSS
- Permissions-Policy: Restricts browser feature access
- Cross-Origin-Embedder-Policy: Prevents cross-origin embedding
- Cross-Origin-Resource-Policy: Restricts cross-origin resource access
Usage:
Add 'apps.core.middleware.security_headers.SecurityHeadersMiddleware'
to MIDDLEWARE in settings.py (after SecurityMiddleware).
"""
from django.conf import settings
class SecurityHeadersMiddleware:
"""
Middleware that adds security headers to HTTP responses.
This provides defense-in-depth by adding headers that Django's
SecurityMiddleware doesn't handle.
"""
def __init__(self, get_response):
self.get_response = get_response
# Build CSP header at startup for performance
self._csp_header = self._build_csp_header()
self._permissions_policy_header = self._build_permissions_policy_header()
def __call__(self, request):
response = self.get_response(request)
return self._add_security_headers(response, request)
def _add_security_headers(self, response, request):
"""Add security headers to the response."""
# Content-Security-Policy
# Only add CSP for HTML responses to avoid breaking API/JSON responses
content_type = response.get("Content-Type", "")
if "text/html" in content_type:
if not response.get("Content-Security-Policy"):
response["Content-Security-Policy"] = self._csp_header
# Permissions-Policy (successor to Feature-Policy)
if not response.get("Permissions-Policy"):
response["Permissions-Policy"] = self._permissions_policy_header
# Cross-Origin-Embedder-Policy
# Requires resources to be CORS-enabled or same-origin
# Using 'unsafe-none' for now as 'require-corp' can break third-party resources
if not response.get("Cross-Origin-Embedder-Policy"):
response["Cross-Origin-Embedder-Policy"] = "unsafe-none"
# Cross-Origin-Resource-Policy
# Controls how resources can be shared with other origins
if not response.get("Cross-Origin-Resource-Policy"):
response["Cross-Origin-Resource-Policy"] = "same-origin"
return response
def _build_csp_header(self):
"""
Build the Content-Security-Policy header value.
CSP directives explained:
- default-src: Fallback for other fetch directives
- script-src: Sources for JavaScript
- style-src: Sources for CSS
- img-src: Sources for images
- font-src: Sources for fonts
- connect-src: Sources for fetch, XHR, WebSocket
- frame-ancestors: Controls framing (replaces X-Frame-Options)
- form-action: Valid targets for form submissions
- base-uri: Restricts base element URLs
- object-src: Sources for plugins (Flash, etc.)
"""
# Check if we're in debug mode
debug = getattr(settings, "DEBUG", False)
# Base directives (production-focused)
directives = {
"default-src": ["'self'"],
"script-src": [
"'self'",
# Allow HTMX inline scripts with nonce (would need nonce middleware)
# For now, using 'unsafe-inline' for HTMX compatibility
"'unsafe-inline'" if debug else "'self'",
# CDNs for external scripts
"https://cdn.jsdelivr.net",
"https://unpkg.com",
"https://challenges.cloudflare.com", # Turnstile
],
"style-src": [
"'self'",
"'unsafe-inline'", # Required for Tailwind and inline styles
"https://cdn.jsdelivr.net",
"https://fonts.googleapis.com",
],
"img-src": [
"'self'",
"data:",
"blob:",
"https:", # Allow HTTPS images (needed for user uploads, maps, etc.)
],
"font-src": [
"'self'",
"https://fonts.gstatic.com",
"https://cdn.jsdelivr.net",
],
"connect-src": [
"'self'",
"https://api.forwardemail.net",
"https://challenges.cloudflare.com",
"https://*.cloudflare.com",
# Map tile servers
"https://*.openstreetmap.org",
"https://*.tile.openstreetmap.org",
],
"frame-src": [
"'self'",
"https://challenges.cloudflare.com", # Turnstile widget
],
"frame-ancestors": ["'self'"],
"form-action": ["'self'"],
"base-uri": ["'self'"],
"object-src": ["'none'"],
"upgrade-insecure-requests": [], # Upgrade HTTP to HTTPS
}
# Add debug-specific relaxations
if debug:
# Allow webpack dev server connections in development
directives["connect-src"].extend([
"ws://localhost:*",
"http://localhost:*",
"http://127.0.0.1:*",
])
# Build header string
parts = []
for directive, sources in directives.items():
if sources:
parts.append(f"{directive} {' '.join(sources)}")
else:
# Directives like upgrade-insecure-requests don't need values
parts.append(directive)
return "; ".join(parts)
def _build_permissions_policy_header(self):
"""
Build the Permissions-Policy header value.
This header controls which browser features the page can use.
"""
# Get permissions policy from settings or use defaults
policy = getattr(settings, "PERMISSIONS_POLICY", {
"accelerometer": [],
"ambient-light-sensor": [],
"autoplay": [],
"camera": [],
"display-capture": [],
"document-domain": [],
"encrypted-media": [],
"fullscreen": ["self"],
"geolocation": ["self"],
"gyroscope": [],
"interest-cohort": [],
"magnetometer": [],
"microphone": [],
"midi": [],
"payment": [],
"picture-in-picture": [],
"publickey-credentials-get": [],
"screen-wake-lock": [],
"sync-xhr": [],
"usb": [],
"web-share": ["self"],
"xr-spatial-tracking": [],
})
parts = []
for feature, allowlist in policy.items():
if not allowlist:
# Empty list means disallow completely
parts.append(f"{feature}=()")
else:
# Convert allowlist to proper format
formatted = " ".join(allowlist)
parts.append(f"{feature}=({formatted})")
return ", ".join(parts)

View File

@@ -223,7 +223,7 @@ class MapResponse:
"query_time_ms": self.query_time_ms,
"filters_applied": self.filters_applied,
"pagination": {
"has_more": False, # TODO: Implement pagination
"has_more": False, # TODO(THRILLWIKI-102): Implement pagination for map data
"total_pages": 1,
},
},

View File

@@ -297,7 +297,7 @@ class CompanyLocationAdapter(BaseLocationAdapter):
"""Convert CompanyHeadquarters to UnifiedLocation."""
# Note: CompanyHeadquarters doesn't have coordinates, so we need to geocode
# For now, we'll skip companies without coordinates
# TODO: Implement geocoding service integration
# TODO(THRILLWIKI-101): Implement geocoding service integration for company HQs
return None
def get_queryset(

View File

@@ -0,0 +1,275 @@
"""
Safe HTML Template Tags and Filters for ThrillWiki.
This module provides template tags and filters for safely rendering
HTML content without XSS vulnerabilities.
Security Note:
Always use these filters instead of |safe for user-generated content.
The |safe filter should only be used for content that has been
pre-sanitized in the view layer.
Usage:
{% load safe_html %}
{# Sanitize user content #}
{{ user_description|sanitize }}
{# Minimal sanitization for comments #}
{{ comment_text|sanitize_minimal }}
{# Strip all HTML #}
{{ raw_text|strip_html }}
{# Safe JSON for JavaScript #}
{{ data|json_safe }}
{# Render trusted icon SVG #}
{% icon "check" class="w-4 h-4" %}
"""
import json
from django import template
from django.utils.safestring import mark_safe
from apps.core.utils.html_sanitizer import (
sanitize_html,
sanitize_minimal as _sanitize_minimal,
sanitize_svg,
strip_html as _strip_html,
sanitize_for_json,
escape_js_string as _escape_js_string,
sanitize_url as _sanitize_url,
sanitize_attribute_value,
)
register = template.Library()
# =============================================================================
# HTML Sanitization Filters
# =============================================================================
@register.filter(name='sanitize', is_safe=True)
def sanitize_filter(value):
"""
Sanitize HTML content to prevent XSS attacks.
Allows common formatting tags while stripping dangerous content.
Usage:
{{ user_content|sanitize }}
"""
if not value:
return ''
return mark_safe(sanitize_html(str(value)))
@register.filter(name='sanitize_minimal', is_safe=True)
def sanitize_minimal_filter(value):
"""
Sanitize HTML with minimal allowed tags.
Only allows basic text formatting: p, br, strong, em, i, b, a
Usage:
{{ comment|sanitize_minimal }}
"""
if not value:
return ''
return mark_safe(_sanitize_minimal(str(value)))
@register.filter(name='sanitize_svg', is_safe=True)
def sanitize_svg_filter(value):
"""
Sanitize SVG content for safe inline rendering.
Usage:
{{ icon_svg|sanitize_svg }}
"""
if not value:
return ''
return mark_safe(sanitize_svg(str(value)))
@register.filter(name='strip_html')
def strip_html_filter(value):
"""
Remove all HTML tags from content.
Usage:
{{ html_content|strip_html }}
"""
if not value:
return ''
return _strip_html(str(value))
# =============================================================================
# JavaScript/JSON Context Filters
# =============================================================================
@register.filter(name='json_safe', is_safe=True)
def json_safe_filter(value):
"""
Safely serialize data for embedding in JavaScript.
This is safer than using |safe for JSON data as it properly
escapes </script> and other dangerous sequences.
Usage:
<script>
const data = {{ python_dict|json_safe }};
</script>
"""
if value is None:
return 'null'
return mark_safe(sanitize_for_json(value))
@register.filter(name='escapejs_safe')
def escapejs_safe_filter(value):
"""
Escape a string for safe use in JavaScript string literals.
Usage:
<script>
const message = '{{ user_input|escapejs_safe }}';
</script>
"""
if not value:
return ''
return _escape_js_string(str(value))
# =============================================================================
# URL and Attribute Filters
# =============================================================================
@register.filter(name='sanitize_url')
def sanitize_url_filter(value):
"""
Sanitize a URL to prevent javascript: and other dangerous protocols.
Usage:
<a href="{{ user_url|sanitize_url }}">Link</a>
"""
if not value:
return ''
return _sanitize_url(str(value))
@register.filter(name='attr_safe')
def attr_safe_filter(value):
"""
Escape a value for safe use in HTML attributes.
Usage:
<div data-value="{{ user_value|attr_safe }}">
"""
if not value:
return ''
return sanitize_attribute_value(str(value))
# =============================================================================
# Icon Template Tags
# =============================================================================
# Predefined safe SVG icons
# These are trusted and can be rendered without sanitization
BUILTIN_ICONS = {
'check': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M4.5 12.75l6 6 9-13.5" /></svg>''',
'x': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M6 18L18 6M6 6l12 12" /></svg>''',
'plus': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M12 4.5v15m7.5-7.5h-15" /></svg>''',
'minus': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M19.5 12h-15" /></svg>''',
'chevron-down': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M19.5 8.25l-7.5 7.5-7.5-7.5" /></svg>''',
'chevron-up': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M4.5 15.75l7.5-7.5 7.5 7.5" /></svg>''',
'chevron-left': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M15.75 19.5L8.25 12l7.5-7.5" /></svg>''',
'chevron-right': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M8.25 4.5l7.5 7.5-7.5 7.5" /></svg>''',
'search': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M21 21l-5.197-5.197m0 0A7.5 7.5 0 105.196 5.196a7.5 7.5 0 0010.607 10.607z" /></svg>''',
'menu': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M3.75 6.75h16.5M3.75 12h16.5m-16.5 5.25h16.5" /></svg>''',
'user': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M15.75 6a3.75 3.75 0 11-7.5 0 3.75 3.75 0 017.5 0zM4.501 20.118a7.5 7.5 0 0114.998 0A17.933 17.933 0 0112 21.75c-2.676 0-5.216-.584-7.499-1.632z" /></svg>''',
'cog': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M9.594 3.94c.09-.542.56-.94 1.11-.94h2.593c.55 0 1.02.398 1.11.94l.213 1.281c.063.374.313.686.645.87.074.04.147.083.22.127.324.196.72.257 1.075.124l1.217-.456a1.125 1.125 0 011.37.49l1.296 2.247a1.125 1.125 0 01-.26 1.431l-1.003.827c-.293.24-.438.613-.431.992a6.759 6.759 0 010 .255c-.007.378.138.75.43.99l1.005.828c.424.35.534.954.26 1.43l-1.298 2.247a1.125 1.125 0 01-1.369.491l-1.217-.456c-.355-.133-.75-.072-1.076.124a6.57 6.57 0 01-.22.128c-.331.183-.581.495-.644.869l-.213 1.28c-.09.543-.56.941-1.11.941h-2.594c-.55 0-1.02-.398-1.11-.94l-.213-1.281c-.062-.374-.312-.686-.644-.87a6.52 6.52 0 01-.22-.127c-.325-.196-.72-.257-1.076-.124l-1.217.456a1.125 1.125 0 01-1.369-.49l-1.297-2.247a1.125 1.125 0 01.26-1.431l1.004-.827c.292-.24.437-.613.43-.992a6.932 6.932 0 010-.255c.007-.378-.138-.75-.43-.99l-1.004-.828a1.125 1.125 0 01-.26-1.43l1.297-2.247a1.125 1.125 0 011.37-.491l1.216.456c.356.133.751.072 1.076-.124.072-.044.146-.087.22-.128.332-.183.582-.495.644-.869l.214-1.281z" /><path stroke-linecap="round" stroke-linejoin="round" d="M15 12a3 3 0 11-6 0 3 3 0 016 0z" /></svg>''',
'trash': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M14.74 9l-.346 9m-4.788 0L9.26 9m9.968-3.21c.342.052.682.107 1.022.166m-1.022-.165L18.16 19.673a2.25 2.25 0 01-2.244 2.077H8.084a2.25 2.25 0 01-2.244-2.077L4.772 5.79m14.456 0a48.108 48.108 0 00-3.478-.397m-12 .562c.34-.059.68-.114 1.022-.165m0 0a48.11 48.11 0 013.478-.397m7.5 0v-.916c0-1.18-.91-2.164-2.09-2.201a51.964 51.964 0 00-3.32 0c-1.18.037-2.09 1.022-2.09 2.201v.916m7.5 0a48.667 48.667 0 00-7.5 0" /></svg>''',
'pencil': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M16.862 4.487l1.687-1.688a1.875 1.875 0 112.652 2.652L10.582 16.07a4.5 4.5 0 01-1.897 1.13L6 18l.8-2.685a4.5 4.5 0 011.13-1.897l8.932-8.931zm0 0L19.5 7.125M18 14v4.75A2.25 2.25 0 0115.75 21H5.25A2.25 2.25 0 013 18.75V8.25A2.25 2.25 0 015.25 6H10" /></svg>''',
'eye': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M2.036 12.322a1.012 1.012 0 010-.639C3.423 7.51 7.36 4.5 12 4.5c4.638 0 8.573 3.007 9.963 7.178.07.207.07.431 0 .639C20.577 16.49 16.64 19.5 12 19.5c-4.638 0-8.573-3.007-9.963-7.178z" /><path stroke-linecap="round" stroke-linejoin="round" d="M15 12a3 3 0 11-6 0 3 3 0 016 0z" /></svg>''',
'eye-slash': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M3.98 8.223A10.477 10.477 0 001.934 12C3.226 16.338 7.244 19.5 12 19.5c.993 0 1.953-.138 2.863-.395M6.228 6.228A10.45 10.45 0 0112 4.5c4.756 0 8.773 3.162 10.065 7.498a10.523 10.523 0 01-4.293 5.774M6.228 6.228L3 3m3.228 3.228l3.65 3.65m7.894 7.894L21 21m-3.228-3.228l-3.65-3.65m0 0a3 3 0 10-4.243-4.243m4.242 4.242L9.88 9.88" /></svg>''',
'arrow-left': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M10.5 19.5L3 12m0 0l7.5-7.5M3 12h18" /></svg>''',
'arrow-right': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M13.5 4.5L21 12m0 0l-7.5 7.5M21 12H3" /></svg>''',
'info': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M11.25 11.25l.041-.02a.75.75 0 011.063.852l-.708 2.836a.75.75 0 001.063.853l.041-.021M21 12a9 9 0 11-18 0 9 9 0 0118 0zm-9-3.75h.008v.008H12V8.25z" /></svg>''',
'warning': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M12 9v3.75m-9.303 3.376c-.866 1.5.217 3.374 1.948 3.374h14.71c1.73 0 2.813-1.874 1.948-3.374L13.949 3.378c-.866-1.5-3.032-1.5-3.898 0L2.697 16.126zM12 15.75h.007v.008H12v-.008z" /></svg>''',
'error': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M12 9v3.75m9-.75a9 9 0 11-18 0 9 9 0 0118 0zm-9 3.75h.008v.008H12v-.008z" /></svg>''',
'success': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M9 12.75L11.25 15 15 9.75M21 12a9 9 0 11-18 0 9 9 0 0118 0z" /></svg>''',
'loading': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" {attrs}><circle class="opacity-25" cx="12" cy="12" r="10" stroke="currentColor" stroke-width="4"></circle><path class="opacity-75" fill="currentColor" d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"></path></svg>''',
'external-link': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M13.5 6H5.25A2.25 2.25 0 003 8.25v10.5A2.25 2.25 0 005.25 21h10.5A2.25 2.25 0 0018 18.75V10.5m-10.5 6L21 3m0 0h-5.25M21 3v5.25" /></svg>''',
'download': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M3 16.5v2.25A2.25 2.25 0 005.25 21h13.5A2.25 2.25 0 0021 18.75V16.5M16.5 12L12 16.5m0 0L7.5 12m4.5 4.5V3" /></svg>''',
'upload': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M3 16.5v2.25A2.25 2.25 0 005.25 21h13.5A2.25 2.25 0 0021 18.75V16.5m-13.5-9L12 3m0 0l4.5 4.5M12 3v13.5" /></svg>''',
'star': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M11.48 3.499a.562.562 0 011.04 0l2.125 5.111a.563.563 0 00.475.345l5.518.442c.499.04.701.663.321.988l-4.204 3.602a.563.563 0 00-.182.557l1.285 5.385a.562.562 0 01-.84.61l-4.725-2.885a.563.563 0 00-.586 0L6.982 20.54a.562.562 0 01-.84-.61l1.285-5.386a.562.562 0 00-.182-.557l-4.204-3.602a.563.563 0 01.321-.988l5.518-.442a.563.563 0 00.475-.345L11.48 3.5z" /></svg>''',
'star-filled': '''<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" {attrs}><path fill-rule="evenodd" d="M10.788 3.21c.448-1.077 1.976-1.077 2.424 0l2.082 5.007 5.404.433c1.164.093 1.636 1.545.749 2.305l-4.117 3.527 1.257 5.273c.271 1.136-.964 2.033-1.96 1.425L12 18.354 7.373 21.18c-.996.608-2.231-.29-1.96-1.425l1.257-5.273-4.117-3.527c-.887-.76-.415-2.212.749-2.305l5.404-.433 2.082-5.006z" clip-rule="evenodd" /></svg>''',
'heart': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M21 8.25c0-2.485-2.099-4.5-4.688-4.5-1.935 0-3.597 1.126-4.312 2.733-.715-1.607-2.377-2.733-4.313-2.733C5.1 3.75 3 5.765 3 8.25c0 7.22 9 12 9 12s9-4.78 9-12z" /></svg>''',
'heart-filled': '''<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" {attrs}><path d="M11.645 20.91l-.007-.003-.022-.012a15.247 15.247 0 01-.383-.218 25.18 25.18 0 01-4.244-3.17C4.688 15.36 2.25 12.174 2.25 8.25 2.25 5.322 4.714 3 7.688 3A5.5 5.5 0 0112 5.052 5.5 5.5 0 0116.313 3c2.973 0 5.437 2.322 5.437 5.25 0 3.925-2.438 7.111-4.739 9.256a25.175 25.175 0 01-4.244 3.17 15.247 15.247 0 01-.383.219l-.022.012-.007.004-.003.001a.752.752 0 01-.704 0l-.003-.001z" /></svg>''',
}
@register.simple_tag
def icon(name, **kwargs):
"""
Render a trusted SVG icon.
This tag renders predefined SVG icons that are trusted and safe.
Custom attributes can be passed to customize the icon.
Usage:
{% icon "check" class="w-4 h-4 text-green-500" %}
{% icon "x" class="w-6 h-6" aria_hidden="true" %}
Args:
name: The icon name (from BUILTIN_ICONS)
**kwargs: Additional HTML attributes for the SVG
Returns:
Safe HTML for the icon SVG
"""
svg_template = BUILTIN_ICONS.get(name)
if not svg_template:
# Return empty string for unknown icons (fail silently)
return ''
# Build attributes string
attrs_list = []
for key, value in kwargs.items():
# Convert underscore to hyphen for HTML attributes (e.g., aria_hidden -> aria-hidden)
attr_name = key.replace('_', '-')
# Escape attribute values to prevent XSS
safe_value = sanitize_attribute_value(str(value))
attrs_list.append(f'{attr_name}="{safe_value}"')
attrs_str = ' '.join(attrs_list)
# Substitute attributes into template
svg = svg_template.format(attrs=attrs_str)
return mark_safe(svg)
@register.simple_tag
def icon_class(name, size='w-5 h-5', extra_class=''):
"""
Render a trusted SVG icon with common class presets.
Usage:
{% icon_class "check" size="w-4 h-4" extra_class="text-green-500" %}
Args:
name: The icon name
size: Size classes (default: "w-5 h-5")
extra_class: Additional CSS classes
Returns:
Safe HTML for the icon SVG
"""
classes = f'{size} {extra_class}'.strip()
return icon(name, **{'class': classes})

View File

@@ -0,0 +1,161 @@
"""
Centralized error handling utilities.
This module provides standardized error handling for views across the application,
ensuring consistent logging, user messages, and API responses.
"""
import logging
from typing import Any, Dict, Optional
from django.contrib import messages
from django.http import HttpRequest
from rest_framework import status
from rest_framework.response import Response
from apps.core.exceptions import ThrillWikiException
logger = logging.getLogger(__name__)
class ErrorHandler:
"""Centralized error handling for views."""
@staticmethod
def handle_view_error(
request: HttpRequest,
error: Exception,
user_message: str = "An error occurred",
log_message: Optional[str] = None,
level: str = "error",
) -> None:
"""
Handle errors in template views.
Logs the error with appropriate context and displays a user-friendly
message using Django's messages framework.
Args:
request: HTTP request object
error: Exception that occurred
user_message: Message to show to the user (should be user-friendly)
log_message: Message to log (defaults to str(error) with user_message prefix)
level: Log level - one of "error", "warning", "info"
Example:
try:
ParkService.create_park(...)
except ServiceError as e:
ErrorHandler.handle_view_error(
request,
e,
user_message="Failed to create park",
log_message=f"Park creation failed for user {request.user.id}"
)
"""
log_msg = log_message or f"{user_message}: {str(error)}"
if level == "error":
logger.error(log_msg, exc_info=True)
elif level == "warning":
logger.warning(log_msg)
else:
logger.info(log_msg)
messages.error(request, user_message)
@staticmethod
def handle_api_error(
error: Exception,
user_message: str = "An error occurred",
log_message: Optional[str] = None,
status_code: int = status.HTTP_400_BAD_REQUEST,
) -> Response:
"""
Handle errors in API views.
Logs the error and returns a standardized DRF Response with error details.
Args:
error: Exception that occurred
user_message: Message to return to the client (should be user-friendly)
log_message: Message to log (defaults to str(error) with user_message prefix)
status_code: HTTP status code to return
Returns:
DRF Response with error details in standard format
Example:
try:
result = ParkService.create_park(...)
except ServiceError as e:
return ErrorHandler.handle_api_error(
e,
user_message="Failed to create park photo",
status_code=status.HTTP_400_BAD_REQUEST
)
"""
log_msg = log_message or f"{user_message}: {str(error)}"
logger.error(log_msg, exc_info=True)
# Build error response
error_data: Dict[str, Any] = {
"error": user_message,
"detail": str(error),
}
# Include additional details for ThrillWikiException subclasses
if isinstance(error, ThrillWikiException):
error_data["error_code"] = error.error_code
if error.details:
error_data["details"] = error.details
return Response(error_data, status=status_code)
@staticmethod
def handle_success(
request: HttpRequest,
message: str,
level: str = "success",
) -> None:
"""
Handle success messages in template views.
Args:
request: HTTP request object
message: Success message to display
level: Message level - one of "success", "info", "warning"
"""
if level == "success":
messages.success(request, message)
elif level == "info":
messages.info(request, message)
elif level == "warning":
messages.warning(request, message)
@staticmethod
def api_success_response(
data: Any = None,
message: str = "Success",
status_code: int = status.HTTP_200_OK,
) -> Response:
"""
Create a standardized success response for API views.
Args:
data: Response data (optional)
message: Success message
status_code: HTTP status code
Returns:
DRF Response with success data in standard format
"""
response_data: Dict[str, Any] = {
"status": "success",
"message": message,
}
if data is not None:
response_data["data"] = data
return Response(response_data, status=status_code)

View File

@@ -0,0 +1,432 @@
"""
File Upload Security Utilities for ThrillWiki.
This module provides comprehensive file validation and security checks for
file uploads, including:
- File type validation (MIME type and magic number verification)
- File size validation
- Filename sanitization
- Image-specific validation
Security Note:
Always validate uploaded files before saving them. Never trust user-provided
file extensions or Content-Type headers alone.
Usage:
from apps.core.utils.file_scanner import validate_image_upload, sanitize_filename
# In view
try:
validate_image_upload(uploaded_file)
# Safe to save file
except FileValidationError as e:
return JsonResponse({'error': str(e)}, status=400)
# Sanitize filename
safe_name = sanitize_filename(uploaded_file.name)
"""
import os
import re
import uuid
from io import BytesIO
from typing import Optional, Set, Tuple
from django.core.exceptions import ValidationError
from django.core.files.uploadedfile import UploadedFile
class FileValidationError(ValidationError):
"""Custom exception for file validation errors."""
pass
# =============================================================================
# Image Magic Number Signatures
# =============================================================================
# Magic number signatures for common image formats
# Format: (magic_bytes, offset, description)
IMAGE_SIGNATURES = {
'jpeg': [
(b'\xFF\xD8\xFF\xE0', 0, 'JPEG (JFIF)'),
(b'\xFF\xD8\xFF\xE1', 0, 'JPEG (EXIF)'),
(b'\xFF\xD8\xFF\xE2', 0, 'JPEG (ICC)'),
(b'\xFF\xD8\xFF\xE3', 0, 'JPEG (Samsung)'),
(b'\xFF\xD8\xFF\xE8', 0, 'JPEG (SPIFF)'),
(b'\xFF\xD8\xFF\xDB', 0, 'JPEG (Raw)'),
],
'png': [
(b'\x89PNG\r\n\x1a\n', 0, 'PNG'),
],
'gif': [
(b'GIF87a', 0, 'GIF87a'),
(b'GIF89a', 0, 'GIF89a'),
],
'webp': [
(b'RIFF', 0, 'RIFF'), # WebP starts with RIFF header
],
'bmp': [
(b'BM', 0, 'BMP'),
],
}
# All allowed MIME types
ALLOWED_IMAGE_MIME_TYPES: Set[str] = frozenset({
'image/jpeg',
'image/png',
'image/gif',
'image/webp',
})
# Allowed file extensions
ALLOWED_IMAGE_EXTENSIONS: Set[str] = frozenset({
'.jpg', '.jpeg', '.png', '.gif', '.webp',
})
# Maximum file size (10MB)
MAX_FILE_SIZE = 10 * 1024 * 1024
# Minimum file size (prevent empty files)
MIN_FILE_SIZE = 100 # 100 bytes
# =============================================================================
# File Validation Functions
# =============================================================================
def validate_image_upload(
file: UploadedFile,
max_size: int = MAX_FILE_SIZE,
allowed_types: Optional[Set[str]] = None,
allowed_extensions: Optional[Set[str]] = None,
) -> bool:
"""
Validate an uploaded image file for security.
Performs multiple validation checks:
1. File size validation
2. File extension validation
3. MIME type validation (from Content-Type header)
4. Magic number validation (actual file content check)
5. Image integrity validation (using PIL)
Args:
file: The uploaded file object
max_size: Maximum allowed file size in bytes
allowed_types: Set of allowed MIME types
allowed_extensions: Set of allowed file extensions
Returns:
True if all validations pass
Raises:
FileValidationError: If any validation fails
"""
if allowed_types is None:
allowed_types = ALLOWED_IMAGE_MIME_TYPES
if allowed_extensions is None:
allowed_extensions = ALLOWED_IMAGE_EXTENSIONS
# 1. Check if file exists
if not file:
raise FileValidationError("No file provided")
# 2. Check file size
if file.size > max_size:
raise FileValidationError(
f"File too large. Maximum size is {max_size // (1024 * 1024)}MB"
)
if file.size < MIN_FILE_SIZE:
raise FileValidationError("File too small or empty")
# 3. Check file extension
filename = file.name or ''
ext = os.path.splitext(filename)[1].lower()
if ext not in allowed_extensions:
raise FileValidationError(
f"Invalid file extension '{ext}'. Allowed: {', '.join(allowed_extensions)}"
)
# 4. Check Content-Type header
content_type = getattr(file, 'content_type', '')
if content_type and content_type not in allowed_types:
raise FileValidationError(
f"Invalid file type '{content_type}'. Allowed: {', '.join(allowed_types)}"
)
# 5. Validate magic numbers (actual file content)
if not _validate_magic_number(file):
raise FileValidationError(
"File content doesn't match file extension. File may be corrupted or malicious."
)
# 6. Validate image integrity using PIL
if not _validate_image_integrity(file):
raise FileValidationError(
"Invalid or corrupted image file"
)
return True
def _validate_magic_number(file: UploadedFile) -> bool:
"""
Validate file content against known magic number signatures.
This is more reliable than checking the file extension or Content-Type
header, which can be easily spoofed.
Args:
file: The uploaded file object
Returns:
True if magic number matches an allowed image type
"""
# Read the file header
file.seek(0)
header = file.read(16)
file.seek(0)
# Check against known signatures
for format_name, signatures in IMAGE_SIGNATURES.items():
for magic, offset, description in signatures:
if len(header) >= offset + len(magic):
if header[offset:offset + len(magic)] == magic:
# Special handling for WebP (must also have WEBP marker)
if format_name == 'webp':
if len(header) >= 12 and header[8:12] == b'WEBP':
return True
else:
return True
return False
def _validate_image_integrity(file: UploadedFile) -> bool:
"""
Validate image integrity using PIL.
This catches corrupted images and various image-related attacks.
Args:
file: The uploaded file object
Returns:
True if image can be opened and verified
"""
try:
from PIL import Image
file.seek(0)
# Read into BytesIO to avoid issues with file-like objects
img_data = BytesIO(file.read())
file.seek(0)
with Image.open(img_data) as img:
# Verify the image is not truncated or corrupted
img.verify()
# Re-open for size check (verify() can only be called once)
img_data.seek(0)
with Image.open(img_data) as img2:
# Check for reasonable image dimensions
# Prevent decompression bombs
max_dimension = 10000
if img2.width > max_dimension or img2.height > max_dimension:
raise FileValidationError(
f"Image dimensions too large. Maximum is {max_dimension}x{max_dimension}"
)
# Check for very small dimensions (might be suspicious)
if img2.width < 1 or img2.height < 1:
raise FileValidationError("Invalid image dimensions")
return True
except FileValidationError:
raise
except Exception:
return False
# =============================================================================
# Filename Sanitization
# =============================================================================
def sanitize_filename(filename: str, max_length: int = 100) -> str:
"""
Sanitize a filename to prevent directory traversal and other attacks.
This function:
- Removes path separators and directory traversal attempts
- Removes special characters
- Truncates to maximum length
- Ensures the filename is not empty
Args:
filename: The original filename
max_length: Maximum length for the filename
Returns:
Sanitized filename
"""
if not filename:
return f"file_{uuid.uuid4().hex[:8]}"
# Get just the filename (remove any path components)
filename = os.path.basename(filename)
# Split into name and extension
name, ext = os.path.splitext(filename)
# Remove or replace dangerous characters from name
# Allow alphanumeric, hyphens, underscores, dots
name = re.sub(r'[^\w\-.]', '_', name)
# Remove leading dots and underscores (hidden file prevention)
name = name.lstrip('._')
# Collapse multiple underscores
name = re.sub(r'_+', '_', name)
# Ensure name is not empty
if not name:
name = f"file_{uuid.uuid4().hex[:8]}"
# Sanitize extension
ext = ext.lower()
ext = re.sub(r'[^\w.]', '', ext)
# Combine and truncate
result = f"{name[:max_length - len(ext)]}{ext}"
return result
def generate_unique_filename(original_filename: str, prefix: str = '') -> str:
"""
Generate a unique filename using UUID while preserving extension.
Args:
original_filename: The original filename
prefix: Optional prefix for the filename
Returns:
Unique filename with UUID
"""
ext = os.path.splitext(original_filename)[1].lower()
# Sanitize extension
ext = re.sub(r'[^\w.]', '', ext)
# Generate unique filename
unique_id = uuid.uuid4().hex[:12]
if prefix:
return f"{sanitize_filename(prefix)}_{unique_id}{ext}"
return f"{unique_id}{ext}"
# =============================================================================
# Rate Limiting for Uploads
# =============================================================================
# Rate limiting configuration
UPLOAD_RATE_LIMITS = {
'per_minute': 10,
'per_hour': 100,
'per_day': 500,
}
def check_upload_rate_limit(user_id: int, cache_backend=None) -> Tuple[bool, str]:
"""
Check if user has exceeded upload rate limits.
Args:
user_id: The user's ID
cache_backend: Optional Django cache backend (uses default if not provided)
Returns:
Tuple of (is_allowed, reason_if_blocked)
"""
if cache_backend is None:
from django.core.cache import cache
cache_backend = cache
# Check per-minute limit
minute_key = f"upload_rate:{user_id}:minute"
minute_count = cache_backend.get(minute_key, 0)
if minute_count >= UPLOAD_RATE_LIMITS['per_minute']:
return False, "Upload rate limit exceeded. Please wait a minute."
# Check per-hour limit
hour_key = f"upload_rate:{user_id}:hour"
hour_count = cache_backend.get(hour_key, 0)
if hour_count >= UPLOAD_RATE_LIMITS['per_hour']:
return False, "Hourly upload limit exceeded. Please try again later."
# Check per-day limit
day_key = f"upload_rate:{user_id}:day"
day_count = cache_backend.get(day_key, 0)
if day_count >= UPLOAD_RATE_LIMITS['per_day']:
return False, "Daily upload limit exceeded. Please try again tomorrow."
return True, ""
def increment_upload_count(user_id: int, cache_backend=None) -> None:
"""
Increment upload count for rate limiting.
Args:
user_id: The user's ID
cache_backend: Optional Django cache backend
"""
if cache_backend is None:
from django.core.cache import cache
cache_backend = cache
# Increment per-minute counter (expires in 60 seconds)
minute_key = f"upload_rate:{user_id}:minute"
try:
cache_backend.incr(minute_key)
except ValueError:
cache_backend.set(minute_key, 1, 60)
# Increment per-hour counter (expires in 3600 seconds)
hour_key = f"upload_rate:{user_id}:hour"
try:
cache_backend.incr(hour_key)
except ValueError:
cache_backend.set(hour_key, 1, 3600)
# Increment per-day counter (expires in 86400 seconds)
day_key = f"upload_rate:{user_id}:day"
try:
cache_backend.incr(day_key)
except ValueError:
cache_backend.set(day_key, 1, 86400)
# =============================================================================
# Antivirus Integration Point
# =============================================================================
def scan_file_for_malware(file: UploadedFile) -> Tuple[bool, str]:
"""
Placeholder for antivirus/malware scanning integration.
This function should be implemented to integrate with a virus scanner
like ClamAV. Currently it returns True (safe) for all files.
Args:
file: The uploaded file object
Returns:
Tuple of (is_safe, reason_if_unsafe)
"""
# TODO(THRILLWIKI-110): Implement ClamAV integration for malware scanning
# This requires ClamAV daemon to be running and python-clamav to be installed
return True, ""

View File

@@ -0,0 +1,382 @@
"""
HTML Sanitization Utilities for ThrillWiki.
This module provides functions for sanitizing user-generated HTML content
to prevent XSS (Cross-Site Scripting) attacks while allowing safe HTML
formatting.
Security Note:
Always sanitize user-generated content before rendering with |safe filter
or mark_safe(). Never trust user input.
Usage:
from apps.core.utils.html_sanitizer import sanitize_html, sanitize_for_json
# In views
context['description'] = sanitize_html(user_input)
# For JSON/JavaScript contexts
json_safe = sanitize_for_json(data)
"""
import json
import re
from html import escape as html_escape
from typing import Any
try:
import bleach
BLEACH_AVAILABLE = True
except ImportError:
BLEACH_AVAILABLE = False
# =============================================================================
# Allowed HTML Configuration
# =============================================================================
# Default allowed HTML tags for user-generated content
ALLOWED_TAGS = frozenset([
# Text formatting
'p', 'br', 'hr',
'strong', 'b', 'em', 'i', 'u', 's', 'strike',
'sub', 'sup', 'small', 'mark',
# Headers
'h1', 'h2', 'h3', 'h4', 'h5', 'h6',
# Lists
'ul', 'ol', 'li',
# Links (with restrictions on attributes)
'a',
# Block elements
'blockquote', 'pre', 'code',
'div', 'span',
# Tables
'table', 'thead', 'tbody', 'tfoot', 'tr', 'th', 'td',
])
# Allowed attributes for each tag
ALLOWED_ATTRIBUTES = {
'a': ['href', 'title', 'rel', 'target'],
'img': ['src', 'alt', 'title', 'width', 'height'],
'div': ['class'],
'span': ['class'],
'p': ['class'],
'table': ['class'],
'th': ['class', 'colspan', 'rowspan'],
'td': ['class', 'colspan', 'rowspan'],
'*': ['class'], # Allow class on all elements
}
# Allowed URL protocols
ALLOWED_PROTOCOLS = frozenset([
'http', 'https', 'mailto', 'tel',
])
# Minimal tags for comments and short text
MINIMAL_TAGS = frozenset([
'p', 'br', 'strong', 'b', 'em', 'i', 'a',
])
# Tags allowed in icon SVGs (for icon template rendering)
SVG_TAGS = frozenset([
'svg', 'path', 'g', 'circle', 'rect', 'line', 'polyline', 'polygon',
'ellipse', 'text', 'tspan', 'defs', 'use', 'symbol', 'clipPath',
'mask', 'linearGradient', 'radialGradient', 'stop', 'title',
])
SVG_ATTRIBUTES = {
'svg': ['viewBox', 'width', 'height', 'fill', 'stroke', 'class',
'xmlns', 'aria-hidden', 'role'],
'path': ['d', 'fill', 'stroke', 'stroke-width', 'stroke-linecap',
'stroke-linejoin', 'class', 'fill-rule', 'clip-rule'],
'g': ['fill', 'stroke', 'transform', 'class'],
'circle': ['cx', 'cy', 'r', 'fill', 'stroke', 'class'],
'rect': ['x', 'y', 'width', 'height', 'rx', 'ry', 'fill', 'stroke', 'class'],
'line': ['x1', 'y1', 'x2', 'y2', 'stroke', 'stroke-width', 'class'],
'polyline': ['points', 'fill', 'stroke', 'class'],
'polygon': ['points', 'fill', 'stroke', 'class'],
'*': ['class', 'fill', 'stroke'],
}
# =============================================================================
# Sanitization Functions
# =============================================================================
def sanitize_html(
html: str | None,
allowed_tags: frozenset | None = None,
allowed_attributes: dict | None = None,
allowed_protocols: frozenset | None = None,
strip: bool = True,
) -> str:
"""
Sanitize HTML content to prevent XSS attacks.
Args:
html: The HTML string to sanitize
allowed_tags: Set of allowed HTML tag names
allowed_attributes: Dict mapping tag names to allowed attributes
allowed_protocols: Set of allowed URL protocols
strip: If True, remove disallowed tags; if False, escape them
Returns:
Sanitized HTML string safe for rendering
Example:
>>> sanitize_html('<script>alert("xss")</script><p>Hello</p>')
'<p>Hello</p>'
"""
if not html:
return ''
if not isinstance(html, str):
html = str(html)
if not BLEACH_AVAILABLE:
# Fallback: escape all HTML if bleach is not available
return html_escape(html)
tags = allowed_tags if allowed_tags is not None else ALLOWED_TAGS
attrs = allowed_attributes if allowed_attributes is not None else ALLOWED_ATTRIBUTES
protocols = allowed_protocols if allowed_protocols is not None else ALLOWED_PROTOCOLS
return bleach.clean(
html,
tags=tags,
attributes=attrs,
protocols=protocols,
strip=strip,
)
def sanitize_minimal(html: str | None) -> str:
"""
Sanitize HTML with minimal allowed tags.
Use this for user comments, short descriptions, etc.
Args:
html: The HTML string to sanitize
Returns:
Sanitized HTML with only basic formatting tags allowed
"""
return sanitize_html(
html,
allowed_tags=MINIMAL_TAGS,
allowed_attributes={'a': ['href', 'title']},
)
def sanitize_svg(svg: str | None) -> str:
"""
Sanitize SVG content for safe inline rendering.
This is specifically for icon SVGs that need to be rendered inline.
Removes potentially dangerous elements while preserving SVG structure.
Args:
svg: The SVG string to sanitize
Returns:
Sanitized SVG string safe for inline rendering
"""
if not svg:
return ''
if not isinstance(svg, str):
svg = str(svg)
if not BLEACH_AVAILABLE:
# Fallback: escape all if bleach is not available
return html_escape(svg)
return bleach.clean(
svg,
tags=SVG_TAGS,
attributes=SVG_ATTRIBUTES,
strip=True,
)
def strip_html(html: str | None) -> str:
"""
Remove all HTML tags from a string.
Use this for contexts where no HTML is allowed at all.
Args:
html: The HTML string to strip
Returns:
Plain text with all HTML tags removed
"""
if not html:
return ''
if not isinstance(html, str):
html = str(html)
if BLEACH_AVAILABLE:
return bleach.clean(html, tags=[], strip=True)
else:
# Fallback: use regex to strip tags
return re.sub(r'<[^>]+>', '', html)
# =============================================================================
# JSON/JavaScript Context Sanitization
# =============================================================================
def sanitize_for_json(data: Any) -> str:
"""
Safely serialize data for embedding in JavaScript/JSON contexts.
This prevents XSS when embedding data in <script> tags or JavaScript.
Args:
data: The data to serialize (dict, list, or primitive)
Returns:
JSON string safe for embedding in JavaScript
Example:
>>> sanitize_for_json({'name': '</script><script>alert("xss")'})
'{"name": "\\u003c/script\\u003e\\u003cscript\\u003ealert(\\"xss\\")"}'
"""
# JSON encode with safe characters escaped
return json.dumps(data, ensure_ascii=False).replace(
'<', '\\u003c'
).replace(
'>', '\\u003e'
).replace(
'&', '\\u0026'
).replace(
"'", '\\u0027'
)
def escape_js_string(s: str | None) -> str:
"""
Escape a string for safe use in JavaScript string literals.
Args:
s: The string to escape
Returns:
Escaped string safe for JavaScript contexts
"""
if not s:
return ''
if not isinstance(s, str):
s = str(s)
# Escape backslashes first, then other special characters
return s.replace('\\', '\\\\').replace(
"'", "\\'"
).replace(
'"', '\\"'
).replace(
'\n', '\\n'
).replace(
'\r', '\\r'
).replace(
'<', '\\u003c'
).replace(
'>', '\\u003e'
).replace(
'&', '\\u0026'
)
# =============================================================================
# URL Sanitization
# =============================================================================
def sanitize_url(url: str | None, allowed_protocols: frozenset | None = None) -> str:
"""
Sanitize a URL to prevent javascript: and other dangerous protocols.
Args:
url: The URL to sanitize
allowed_protocols: Set of allowed URL protocols
Returns:
Sanitized URL or empty string if unsafe
"""
if not url:
return ''
if not isinstance(url, str):
url = str(url)
url = url.strip()
if not url:
return ''
protocols = allowed_protocols if allowed_protocols is not None else ALLOWED_PROTOCOLS
# Check for allowed protocols
url_lower = url.lower()
# Check for javascript:, data:, vbscript:, etc.
if ':' in url_lower:
protocol = url_lower.split(':')[0]
if protocol not in protocols:
# Allow relative URLs and anchor links
if not (url.startswith('/') or url.startswith('#') or url.startswith('?')):
return ''
return url
# =============================================================================
# Attribute Sanitization
# =============================================================================
def sanitize_attribute_value(value: str | None) -> str:
"""
Sanitize a value for use in HTML attributes.
Args:
value: The attribute value to sanitize
Returns:
Sanitized value safe for HTML attribute contexts
"""
if not value:
return ''
if not isinstance(value, str):
value = str(value)
# HTML escape for attribute context
return html_escape(value, quote=True)
def sanitize_class_name(name: str | None) -> str:
"""
Sanitize a CSS class name.
Args:
name: The class name to sanitize
Returns:
Sanitized class name containing only safe characters
"""
if not name:
return ''
if not isinstance(name, str):
name = str(name)
# Only allow alphanumeric, hyphens, and underscores
return re.sub(r'[^a-zA-Z0-9_-]', '', name)

View File

@@ -0,0 +1,79 @@
"""
Base view classes with common patterns.
This module provides base view classes that implement common patterns
such as automatic query optimization with select_related and prefetch_related.
"""
from typing import List
from django.db.models import QuerySet
from django.views.generic import DetailView, ListView
class OptimizedListView(ListView):
"""
ListView with automatic query optimization.
Automatically applies select_related and prefetch_related based on
class attributes, reducing the need for boilerplate code in get_queryset.
Attributes:
select_related_fields: List of fields to pass to select_related()
prefetch_related_fields: List of fields to pass to prefetch_related()
Example:
class RideListView(OptimizedListView):
model = Ride
select_related_fields = ['park', 'manufacturer']
prefetch_related_fields = ['photos']
"""
select_related_fields: List[str] = []
prefetch_related_fields: List[str] = []
def get_queryset(self) -> QuerySet:
"""Get queryset with optimizations applied."""
queryset = super().get_queryset()
if self.select_related_fields:
queryset = queryset.select_related(*self.select_related_fields)
if self.prefetch_related_fields:
queryset = queryset.prefetch_related(*self.prefetch_related_fields)
return queryset
class OptimizedDetailView(DetailView):
"""
DetailView with automatic query optimization.
Automatically applies select_related and prefetch_related based on
class attributes, reducing the need for boilerplate code in get_queryset.
Attributes:
select_related_fields: List of fields to pass to select_related()
prefetch_related_fields: List of fields to pass to prefetch_related()
Example:
class RideDetailView(OptimizedDetailView):
model = Ride
select_related_fields = ['park', 'park__location', 'manufacturer']
prefetch_related_fields = ['photos', 'coaster_stats']
"""
select_related_fields: List[str] = []
prefetch_related_fields: List[str] = []
def get_queryset(self) -> QuerySet:
"""Get queryset with optimizations applied."""
queryset = super().get_queryset()
if self.select_related_fields:
queryset = queryset.select_related(*self.select_related_fields)
if self.prefetch_related_fields:
queryset = queryset.prefetch_related(*self.prefetch_related_fields)
return queryset

View File

@@ -6,8 +6,6 @@ from rest_framework.views import APIView
from rest_framework.response import Response
from rest_framework import status
from rest_framework.permissions import AllowAny
from django.views.decorators.csrf import csrf_exempt
from django.utils.decorators import method_decorator
from typing import Optional, List
from ..services.entity_fuzzy_matching import (
@@ -244,10 +242,12 @@ class EntityNotFoundView(APIView):
)
@method_decorator(csrf_exempt, name="dispatch")
class QuickEntitySuggestionView(APIView):
"""
Lightweight endpoint for quick entity suggestions (e.g., autocomplete).
Security Note: This endpoint only accepts GET requests, which are inherently
safe from CSRF attacks. No CSRF exemption is needed.
"""
permission_classes = [AllowAny]

View File

@@ -1,5 +1,4 @@
from django.views.generic.edit import FormView
from django.shortcuts import get_object_or_404
class InlineEditView(FormView):

View File

@@ -636,7 +636,9 @@ class MapCacheView(MapAPIView):
def delete(self, request: HttpRequest) -> JsonResponse:
"""Clear all map cache (admin only)."""
# TODO: Add admin permission check
# TODO(THRILLWIKI-103): Add admin permission check for cache clear
if not (request.user.is_authenticated and request.user.is_staff):
return self._error_response("Admin access required", 403)
try:
unified_map_service.invalidate_cache()
@@ -655,7 +657,9 @@ class MapCacheView(MapAPIView):
def post(self, request: HttpRequest) -> JsonResponse:
"""Invalidate specific cache entries."""
# TODO: Add admin permission check
# TODO(THRILLWIKI-103): Add admin permission check for cache invalidation
if not (request.user.is_authenticated and request.user.is_staff):
return self._error_response("Admin access required", 403)
try:
data = json.loads(request.body)

View File

@@ -1,5 +1,4 @@
from django.views.generic.edit import FormView
from django.http import HttpResponse
class HTMXModalFormView(FormView):

View File

@@ -1,6 +1,7 @@
"""
Core views for the application.
"""
import json
import logging
from typing import Any, Dict, Optional, Type
@@ -93,18 +94,14 @@ class GlobalSearchView(TemplateView):
# Real implementation should query multiple models.
if q:
# Return a small payload of mocked results to keep this scaffold safe
results = [
{"title": f"Result for {q}", "url": "#", "subtitle": "Park"}
]
results = [{"title": f"Result for {q}", "url": "#", "subtitle": "Park"}]
suggestions = [{"text": q, "url": "#"}]
context = {"results": results, "suggestions": suggestions}
# If HTMX request, render dropdown partial
if request.headers.get("HX-Request") == "true":
return render(
request, "core/search/partials/search_dropdown.html", context
)
return render(request, "core/search/partials/search_dropdown.html", context)
return render(request, self.template_name, context)
@@ -117,25 +114,75 @@ class GlobalSearchView(TemplateView):
# Default transition metadata for styling
TRANSITION_METADATA = {
# Approval transitions
"approve": {"style": "green", "icon": "check", "requires_confirm": True, "confirm_message": "Are you sure you want to approve this?"},
"transition_to_approved": {"style": "green", "icon": "check", "requires_confirm": True, "confirm_message": "Are you sure you want to approve this?"},
"approve": {
"style": "green",
"icon": "check",
"requires_confirm": True,
"confirm_message": "Are you sure you want to approve this?",
},
"transition_to_approved": {
"style": "green",
"icon": "check",
"requires_confirm": True,
"confirm_message": "Are you sure you want to approve this?",
},
# Rejection transitions
"reject": {"style": "red", "icon": "times", "requires_confirm": True, "confirm_message": "Are you sure you want to reject this?"},
"transition_to_rejected": {"style": "red", "icon": "times", "requires_confirm": True, "confirm_message": "Are you sure you want to reject this?"},
"reject": {
"style": "red",
"icon": "times",
"requires_confirm": True,
"confirm_message": "Are you sure you want to reject this?",
},
"transition_to_rejected": {
"style": "red",
"icon": "times",
"requires_confirm": True,
"confirm_message": "Are you sure you want to reject this?",
},
# Escalation transitions
"escalate": {"style": "yellow", "icon": "arrow-up", "requires_confirm": True, "confirm_message": "Are you sure you want to escalate this?"},
"transition_to_escalated": {"style": "yellow", "icon": "arrow-up", "requires_confirm": True, "confirm_message": "Are you sure you want to escalate this?"},
"escalate": {
"style": "yellow",
"icon": "arrow-up",
"requires_confirm": True,
"confirm_message": "Are you sure you want to escalate this?",
},
"transition_to_escalated": {
"style": "yellow",
"icon": "arrow-up",
"requires_confirm": True,
"confirm_message": "Are you sure you want to escalate this?",
},
# Assignment transitions
"assign": {"style": "blue", "icon": "user-plus", "requires_confirm": False},
"unassign": {"style": "gray", "icon": "user-minus", "requires_confirm": False},
# Status transitions
"start": {"style": "blue", "icon": "play", "requires_confirm": False},
"complete": {"style": "green", "icon": "check-circle", "requires_confirm": True, "confirm_message": "Are you sure you want to complete this?"},
"cancel": {"style": "red", "icon": "ban", "requires_confirm": True, "confirm_message": "Are you sure you want to cancel this?"},
"complete": {
"style": "green",
"icon": "check-circle",
"requires_confirm": True,
"confirm_message": "Are you sure you want to complete this?",
},
"cancel": {
"style": "red",
"icon": "ban",
"requires_confirm": True,
"confirm_message": "Are you sure you want to cancel this?",
},
"reopen": {"style": "blue", "icon": "redo", "requires_confirm": False},
# Resolution transitions
"resolve": {"style": "green", "icon": "check-double", "requires_confirm": True, "confirm_message": "Are you sure you want to resolve this?"},
"dismiss": {"style": "gray", "icon": "times-circle", "requires_confirm": True, "confirm_message": "Are you sure you want to dismiss this?"},
"resolve": {
"style": "green",
"icon": "check-double",
"requires_confirm": True,
"confirm_message": "Are you sure you want to resolve this?",
},
"dismiss": {
"style": "gray",
"icon": "times-circle",
"requires_confirm": True,
"confirm_message": "Are you sure you want to dismiss this?",
},
# Default
"default": {"style": "gray", "icon": "arrow-right", "requires_confirm": False},
}
@@ -155,7 +202,9 @@ def get_transition_metadata(transition_name: str) -> Dict[str, Any]:
return TRANSITION_METADATA["default"].copy()
def add_toast_trigger(response: HttpResponse, message: str, toast_type: str = "success") -> HttpResponse:
def add_toast_trigger(
response: HttpResponse, message: str, toast_type: str = "success"
) -> HttpResponse:
"""
Add HX-Trigger header to trigger Alpine.js toast.
@@ -167,17 +216,12 @@ def add_toast_trigger(response: HttpResponse, message: str, toast_type: str = "s
Returns:
Modified response with HX-Trigger header
"""
trigger_data = {
"showToast": {
"message": message,
"type": toast_type
}
}
trigger_data = {"showToast": {"message": message, "type": toast_type}}
response["HX-Trigger"] = json.dumps(trigger_data)
return response
@method_decorator(csrf_protect, name='dispatch')
@method_decorator(csrf_protect, name="dispatch")
class FSMTransitionView(View):
"""
Generic view for handling FSM state transitions via HTMX.
@@ -212,12 +256,16 @@ class FSMTransitionView(View):
The model class or None if not found
"""
try:
content_type = ContentType.objects.get(app_label=app_label, model=model_name)
content_type = ContentType.objects.get(
app_label=app_label, model=model_name
)
return content_type.model_class()
except ContentType.DoesNotExist:
return None
def get_object(self, model_class: Type[Model], pk: Any, slug: Optional[str] = None) -> Model:
def get_object(
self, model_class: Type[Model], pk: Any, slug: Optional[str] = None
) -> Model:
"""
Get the model instance.
@@ -249,7 +297,9 @@ class FSMTransitionView(View):
"""
return getattr(obj, transition_name, None)
def validate_transition(self, obj: Model, transition_name: str, user) -> tuple[bool, Optional[str]]:
def validate_transition(
self, obj: Model, transition_name: str, user
) -> tuple[bool, Optional[str]]:
"""
Validate that the transition can proceed.
@@ -264,18 +314,26 @@ class FSMTransitionView(View):
method = self.get_transition_method(obj, transition_name)
if method is None:
return False, f"Transition '{transition_name}' not found on {obj.__class__.__name__}"
return (
False,
f"Transition '{transition_name}' not found on {obj.__class__.__name__}",
)
if not callable(method):
return False, f"'{transition_name}' is not a callable method"
# Check if the transition can proceed
if not can_proceed(method, user):
return False, f"Transition '{transition_name}' is not allowed from current state"
return (
False,
f"Transition '{transition_name}' is not allowed from current state",
)
return True, None
def execute_transition(self, obj: Model, transition_name: str, user, **kwargs) -> None:
def execute_transition(
self, obj: Model, transition_name: str, user, **kwargs
) -> None:
"""
Execute the transition on the object.
@@ -297,7 +355,9 @@ class FSMTransitionView(View):
def get_success_message(self, obj: Model, transition_name: str) -> str:
"""Generate a success message for the transition."""
# Clean up transition name for display
display_name = transition_name.replace("transition_to_", "").replace("_", " ").title()
display_name = (
transition_name.replace("transition_to_", "").replace("_", " ").title()
)
model_name = obj._meta.verbose_name.title()
return f"{model_name} has been {display_name.lower()}d successfully."
@@ -321,9 +381,9 @@ class FSMTransitionView(View):
model_name = obj._meta.model_name
# Special handling for parks and rides - return status section
if app_label == 'parks' and model_name == 'park':
if app_label == "parks" and model_name == "park":
return "parks/partials/park_status_actions.html"
elif app_label == 'rides' and model_name == 'ride':
elif app_label == "rides" and model_name == "ride":
return "rides/partials/ride_status_actions.html"
# Check for model-specific templates in order of preference
@@ -337,6 +397,7 @@ class FSMTransitionView(View):
# Use template loader to check if template exists
from django.template.loader import select_template
from django.template import TemplateDoesNotExist
try:
template = select_template(possible_templates)
return template.template.name
@@ -344,10 +405,7 @@ class FSMTransitionView(View):
return "htmx/updated_row.html"
def format_success_response(
self,
request: HttpRequest,
obj: Model,
transition_name: str
self, request: HttpRequest, obj: Model, transition_name: str
) -> HttpResponse:
"""
Format a successful transition response.
@@ -381,17 +439,20 @@ class FSMTransitionView(View):
return add_toast_trigger(response, message, "success")
# Regular request - return JSON
return JsonResponse({
"success": True,
"message": message,
"new_state": getattr(obj, obj.state_field_name, None) if hasattr(obj, "state_field_name") else None,
})
return JsonResponse(
{
"success": True,
"message": message,
"new_state": (
getattr(obj, obj.state_field_name, None)
if hasattr(obj, "state_field_name")
else None
),
}
)
def format_error_response(
self,
request: HttpRequest,
error: Exception,
status_code: int = 400
self, request: HttpRequest, error: Exception, status_code: int = 400
) -> HttpResponse:
"""
Format an error response.
@@ -408,10 +469,13 @@ class FSMTransitionView(View):
return add_toast_trigger(response, message, "error")
# Regular request - return JSON
return JsonResponse({
"success": False,
"error": error_data,
}, status=status_code)
return JsonResponse(
{
"success": False,
"error": error_data,
},
status=status_code,
)
def post(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
"""Handle POST request to execute a transition."""
@@ -425,24 +489,22 @@ class FSMTransitionView(View):
if not all([app_label, model_name, transition_name]):
return self.format_error_response(
request,
ValueError("Missing required parameters: app_label, model_name, and transition_name"),
400
ValueError(
"Missing required parameters: app_label, model_name, and transition_name"
),
400,
)
if not pk and not slug:
return self.format_error_response(
request,
ValueError("Missing required parameter: pk or slug"),
400
request, ValueError("Missing required parameter: pk or slug"), 400
)
# Get the model class
model_class = self.get_model_class(app_label, model_name)
if model_class is None:
return self.format_error_response(
request,
ValueError(f"Model '{app_label}.{model_name}' not found"),
404
request, ValueError(f"Model '{app_label}.{model_name}' not found"), 404
)
# Get the object
@@ -450,13 +512,13 @@ class FSMTransitionView(View):
obj = self.get_object(model_class, pk, slug)
except ObjectDoesNotExist:
return self.format_error_response(
request,
ValueError(f"Object not found: {model_name} with pk={pk}"),
404
request, ValueError(f"Object not found: {model_name} with pk={pk}"), 404
)
# Validate the transition
can_execute, error_msg = self.validate_transition(obj, transition_name, request.user)
can_execute, error_msg = self.validate_transition(
obj, transition_name, request.user
)
if not can_execute:
return self.format_error_response(
request,
@@ -466,7 +528,7 @@ class FSMTransitionView(View):
current_state=getattr(obj, "status", None),
requested_transition=transition_name,
),
400
400,
)
# Execute the transition
@@ -509,7 +571,5 @@ class FSMTransitionView(View):
f"Unexpected error during transition '{transition_name}' on {model_class.__name__}(pk={obj.pk})"
)
return self.format_error_response(
request,
ValueError(f"An unexpected error occurred: {str(e)}"),
500
request, ValueError(f"An unexpected error occurred: {str(e)}"), 500
)

View File

@@ -7,6 +7,7 @@ across all models using django-fsm-log.
from django.core.management.base import BaseCommand
from django.db.models import Count, Avg, F
from django.db.models.functions import TruncDate, ExtractHour
from django.utils import timezone
from datetime import timedelta
from django_fsm_log.models import StateLog
@@ -148,9 +149,10 @@ class Command(BaseCommand):
self.stdout.write(f" System-initiated: {system_count} ({(system_count/total_transitions)*100:.1f}%)")
# Daily transition volume
# Security: Using Django ORM functions instead of raw SQL .extra() to prevent SQL injection
self.stdout.write(self.style.SUCCESS('\n--- Daily Transition Volume ---'))
daily_stats = (
queryset.extra(select={'day': 'date(timestamp)'})
queryset.annotate(day=TruncDate('timestamp'))
.values('day')
.annotate(count=Count('id'))
.order_by('-day')[:7]
@@ -162,9 +164,10 @@ class Command(BaseCommand):
self.stdout.write(f" {date}: {count} transitions")
# Busiest hours
# Security: Using Django ORM functions instead of raw SQL .extra() to prevent SQL injection
self.stdout.write(self.style.SUCCESS('\n--- Busiest Hours (UTC) ---'))
hourly_stats = (
queryset.extra(select={'hour': 'extract(hour from timestamp)'})
queryset.annotate(hour=ExtractHour('timestamp'))
.values('hour')
.annotate(count=Count('id'))
.order_by('-count')[:5]

View File

@@ -4,7 +4,8 @@ Following Django styleguide pattern for separating data access from business log
"""
from typing import Optional, Dict, Any
from django.db.models import QuerySet, Count
from django.db.models import QuerySet, Count, F, ExpressionWrapper, FloatField
from django.db.models.functions import Extract
from django.utils import timezone
from datetime import timedelta
from django.contrib.auth.models import User
@@ -185,12 +186,14 @@ def moderation_statistics_summary(
rejected_submissions = handled_queryset.filter(status="REJECTED").count()
# Response time analysis (only for handled submissions)
# Security: Using Django ORM instead of raw SQL .extra() to prevent SQL injection
handled_with_times = (
handled_queryset.exclude(handled_at__isnull=True)
.extra(
select={
"response_hours": "EXTRACT(EPOCH FROM (handled_at - created_at)) / 3600"
}
.annotate(
response_hours=ExpressionWrapper(
Extract(F('handled_at') - F('created_at'), 'epoch') / 3600.0,
output_field=FloatField()
)
)
.values_list("response_hours", flat=True)
)

View File

@@ -306,35 +306,39 @@ class ModerationReportViewSet(viewsets.ModelViewSet):
return Response(stats_data)
@action(detail=True, methods=['get'], permission_classes=[CanViewModerationData])
@action(detail=True, methods=["get"], permission_classes=[CanViewModerationData])
def history(self, request, pk=None):
"""Get transition history for this report."""
from django_fsm_log.models import StateLog
from django.contrib.contenttypes.models import ContentType
report = self.get_object()
content_type = ContentType.objects.get_for_model(report)
logs = StateLog.objects.filter(
content_type=content_type,
object_id=report.id
).select_related('by').order_by('-timestamp')
history_data = [{
'id': log.id,
'timestamp': log.timestamp,
'state': log.state,
'from_state': log.source_state,
'to_state': log.state,
'transition': log.transition,
'user': log.by.username if log.by else None,
'description': log.description,
'reason': log.description,
} for log in logs]
logs = (
StateLog.objects.filter(content_type=content_type, object_id=report.id)
.select_related("by")
.order_by("-timestamp")
)
history_data = [
{
"id": log.id,
"timestamp": log.timestamp,
"state": log.state,
"from_state": log.source_state,
"to_state": log.state,
"transition": log.transition,
"user": log.by.username if log.by else None,
"description": log.description,
"reason": log.description,
}
for log in logs
]
return Response(history_data)
@action(detail=False, methods=['get'], permission_classes=[CanViewModerationData])
@action(detail=False, methods=["get"], permission_classes=[CanViewModerationData])
def all_history(self, request):
"""Get all transition history with filtering.
@@ -343,61 +347,77 @@ class ModerationReportViewSet(viewsets.ModelViewSet):
from django_fsm_log.models import StateLog
from django.contrib.contenttypes.models import ContentType
queryset = StateLog.objects.select_related('by', 'content_type').all()
queryset = StateLog.objects.select_related("by", "content_type").all()
# Filter by id (for detail view)
log_id = request.query_params.get('id')
log_id = request.query_params.get("id")
if log_id:
try:
log = queryset.get(id=log_id)
# Check if HTMX request for detail view
if request.headers.get('HX-Request'):
return render(request, 'moderation/partials/history_detail_content.html', {
'log': log,
})
if request.headers.get("HX-Request"):
return render(
request,
"moderation/partials/history_detail_content.html",
{
"log": log,
},
)
# Return JSON for API request
return Response({
'id': log.id,
'timestamp': log.timestamp,
'model': log.content_type.model,
'object_id': log.object_id,
'state': log.state,
'from_state': log.source_state,
'to_state': log.state,
'transition': log.transition,
'user': log.by.username if log.by else None,
'description': log.description,
'reason': log.description,
})
return Response(
{
"id": log.id,
"timestamp": log.timestamp,
"model": log.content_type.model,
"object_id": log.object_id,
"state": log.state,
"from_state": log.source_state,
"to_state": log.state,
"transition": log.transition,
"user": log.by.username if log.by else None,
"description": log.description,
"reason": log.description,
}
)
except StateLog.DoesNotExist:
if request.headers.get('HX-Request'):
return render(request, 'moderation/partials/history_detail_content.html', {
'log': None,
})
return Response({'error': 'Log not found'}, status=status.HTTP_404_NOT_FOUND)
if request.headers.get("HX-Request"):
return render(
request,
"moderation/partials/history_detail_content.html",
{
"log": None,
},
)
return Response(
{"error": "Log not found"}, status=status.HTTP_404_NOT_FOUND
)
# Filter by model type with app_label support for correct ContentType resolution
model_type = request.query_params.get('model_type')
app_label = request.query_params.get('app_label')
model_type = request.query_params.get("model_type")
app_label = request.query_params.get("app_label")
if model_type:
try:
if app_label:
# Use both app_label and model for precise matching
content_type = ContentType.objects.get_by_natural_key(app_label, model_type)
content_type = ContentType.objects.get_by_natural_key(
app_label, model_type
)
else:
# Map common model names to their app_labels for correct resolution
model_app_mapping = {
'park': 'parks',
'ride': 'rides',
'editsubmission': 'submissions',
'photosubmission': 'submissions',
'moderationreport': 'moderation',
'moderationqueue': 'moderation',
'bulkoperation': 'moderation',
"park": "parks",
"ride": "rides",
"editsubmission": "submissions",
"photosubmission": "submissions",
"moderationreport": "moderation",
"moderationqueue": "moderation",
"bulkoperation": "moderation",
}
mapped_app_label = model_app_mapping.get(model_type.lower())
if mapped_app_label:
content_type = ContentType.objects.get_by_natural_key(mapped_app_label, model_type.lower())
content_type = ContentType.objects.get_by_natural_key(
mapped_app_label, model_type.lower()
)
else:
# Fallback to model-only lookup
content_type = ContentType.objects.get(model=model_type)
@@ -406,88 +426,98 @@ class ModerationReportViewSet(viewsets.ModelViewSet):
pass
# Filter by object_id (for object-level history)
object_id = request.query_params.get('object_id')
object_id = request.query_params.get("object_id")
if object_id:
queryset = queryset.filter(object_id=object_id)
# Filter by user
user_id = request.query_params.get('user_id')
user_id = request.query_params.get("user_id")
if user_id:
queryset = queryset.filter(by_id=user_id)
# Filter by date range
start_date = request.query_params.get('start_date')
end_date = request.query_params.get('end_date')
start_date = request.query_params.get("start_date")
end_date = request.query_params.get("end_date")
if start_date:
queryset = queryset.filter(timestamp__gte=start_date)
if end_date:
queryset = queryset.filter(timestamp__lte=end_date)
# Filter by state
state = request.query_params.get('state')
state = request.query_params.get("state")
if state:
queryset = queryset.filter(state=state)
# Search filter (case-insensitive across relevant fields)
search_query = request.query_params.get('q')
search_query = request.query_params.get("q")
if search_query:
queryset = queryset.filter(
Q(transition__icontains=search_query) |
Q(description__icontains=search_query) |
Q(state__icontains=search_query) |
Q(source_state__icontains=search_query) |
Q(object_id__icontains=search_query) |
Q(by__username__icontains=search_query)
Q(transition__icontains=search_query)
| Q(description__icontains=search_query)
| Q(state__icontains=search_query)
| Q(source_state__icontains=search_query)
| Q(object_id__icontains=search_query)
| Q(by__username__icontains=search_query)
)
# Order queryset
queryset = queryset.order_by('-timestamp')
queryset = queryset.order_by("-timestamp")
# Check if HTMX request
if request.headers.get('HX-Request'):
if request.headers.get("HX-Request"):
# Use Django's Paginator for HTMX responses
paginator = Paginator(queryset, 20)
page_number = request.query_params.get('page', 1)
page_number = request.query_params.get("page", 1)
page_obj = paginator.get_page(page_number)
return render(request, 'moderation/partials/history_table.html', {
'history_logs': page_obj,
'page_obj': page_obj,
'request': request,
})
return render(
request,
"moderation/partials/history_table.html",
{
"history_logs": page_obj,
"page_obj": page_obj,
"request": request,
},
)
# Paginate for API response
page = self.paginate_queryset(queryset)
if page is not None:
history_data = [{
'id': log.id,
'timestamp': log.timestamp,
'model': log.content_type.model,
'object_id': log.object_id,
'state': log.state,
'from_state': log.source_state,
'to_state': log.state,
'transition': log.transition,
'user': log.by.username if log.by else None,
'description': log.description,
'reason': log.description,
} for log in page]
history_data = [
{
"id": log.id,
"timestamp": log.timestamp,
"model": log.content_type.model,
"object_id": log.object_id,
"state": log.state,
"from_state": log.source_state,
"to_state": log.state,
"transition": log.transition,
"user": log.by.username if log.by else None,
"description": log.description,
"reason": log.description,
}
for log in page
]
return self.get_paginated_response(history_data)
# Return all history data when pagination is not triggered
history_data = [{
'id': log.id,
'timestamp': log.timestamp,
'model': log.content_type.model,
'object_id': log.object_id,
'state': log.state,
'from_state': log.source_state,
'to_state': log.state,
'transition': log.transition,
'user': log.by.username if log.by else None,
'description': log.description,
'reason': log.description,
} for log in queryset]
history_data = [
{
"id": log.id,
"timestamp": log.timestamp,
"model": log.content_type.model,
"object_id": log.object_id,
"state": log.state,
"from_state": log.source_state,
"to_state": log.state,
"transition": log.transition,
"user": log.by.username if log.by else None,
"description": log.description,
"reason": log.description,
}
for log in queryset
]
return Response(history_data)
@@ -704,7 +734,7 @@ class ModerationQueueViewSet(viewsets.ModelViewSet):
serializer = self.get_serializer(queryset, many=True)
return Response(serializer.data)
@action(detail=True, methods=['get'], permission_classes=[CanViewModerationData])
@action(detail=True, methods=["get"], permission_classes=[CanViewModerationData])
def history(self, request, pk=None):
"""Get transition history for this queue item."""
from django_fsm_log.models import StateLog
@@ -713,22 +743,26 @@ class ModerationQueueViewSet(viewsets.ModelViewSet):
queue_item = self.get_object()
content_type = ContentType.objects.get_for_model(queue_item)
logs = StateLog.objects.filter(
content_type=content_type,
object_id=queue_item.id
).select_related('by').order_by('-timestamp')
logs = (
StateLog.objects.filter(content_type=content_type, object_id=queue_item.id)
.select_related("by")
.order_by("-timestamp")
)
history_data = [{
'id': log.id,
'timestamp': log.timestamp,
'state': log.state,
'from_state': log.source_state,
'to_state': log.state,
'transition': log.transition,
'user': log.by.username if log.by else None,
'description': log.description,
'reason': log.description,
} for log in logs]
history_data = [
{
"id": log.id,
"timestamp": log.timestamp,
"state": log.state,
"from_state": log.source_state,
"to_state": log.state,
"transition": log.transition,
"user": log.by.username if log.by else None,
"description": log.description,
"reason": log.description,
}
for log in logs
]
return Response(history_data)
@@ -996,7 +1030,7 @@ class BulkOperationViewSet(viewsets.ModelViewSet):
serializer = self.get_serializer(queryset, many=True)
return Response(serializer.data)
@action(detail=True, methods=['get'])
@action(detail=True, methods=["get"])
def history(self, request, pk=None):
"""Get transition history for this bulk operation."""
from django_fsm_log.models import StateLog
@@ -1005,22 +1039,26 @@ class BulkOperationViewSet(viewsets.ModelViewSet):
operation = self.get_object()
content_type = ContentType.objects.get_for_model(operation)
logs = StateLog.objects.filter(
content_type=content_type,
object_id=operation.id
).select_related('by').order_by('-timestamp')
logs = (
StateLog.objects.filter(content_type=content_type, object_id=operation.id)
.select_related("by")
.order_by("-timestamp")
)
history_data = [{
'id': log.id,
'timestamp': log.timestamp,
'state': log.state,
'from_state': log.source_state,
'to_state': log.state,
'transition': log.transition,
'user': log.by.username if log.by else None,
'description': log.description,
'reason': log.description,
} for log in logs]
history_data = [
{
"id": log.id,
"timestamp": log.timestamp,
"state": log.state,
"from_state": log.source_state,
"to_state": log.state,
"transition": log.transition,
"user": log.by.username if log.by else None,
"description": log.description,
"reason": log.description,
}
for log in logs
]
return Response(history_data)

View File

@@ -203,11 +203,12 @@ class Command(BaseCommand):
)
def create_parks(self):
"""Create parks with proper operator relationships"""
"""Create parks with proper operator relationships."""
self.stdout.write("Creating parks...")
# TODO: Implement park creation - parks_data defined but not used yet
parks_data = [ # noqa: F841
# Park creation data - will be used to create parks in the database
# TODO(THRILLWIKI-111): Complete park creation implementation
parks_data = [
{
"name": "Magic Kingdom",
"slug": "magic-kingdom",

View File

@@ -194,10 +194,9 @@ class Command(BaseCommand):
missing_tables = []
for model in required_models:
try:
# Check if the table exists by trying to get the table name
table_name = model._meta.db_table
with connection.cursor() as cursor:
cursor.execute(f"SELECT 1 FROM {table_name} LIMIT 1")
# Security: Use Django ORM to check table existence instead of raw SQL
# This is safer as it avoids any potential SQL injection via model metadata
model.objects.exists()
except Exception:
missing_tables.append(model._meta.label)

View File

@@ -127,16 +127,17 @@ class ParkQuerySet(StatusQuerySet, ReviewableQuerySet, LocationQuerySet):
| Q(location__state__icontains=query)
)
.select_related("operator", "location")
.values(
"id",
"name",
"slug",
"location__city",
"location__state",
"operator__name",
.only(
"id", "name", "slug",
"location__city", "location__state",
"operator__name"
)[:limit]
)
def with_location(self):
"""Always prefetch location for park queries."""
return self.select_related("operator").prefetch_related("location")
class ParkManager(StatusManager, ReviewableManager, LocationManager):
"""Custom manager for Park model."""
@@ -162,6 +163,10 @@ class ParkManager(StatusManager, ReviewableManager, LocationManager):
def for_map_display(self, *, bounds=None):
return self.get_queryset().for_map_display(bounds=bounds)
def with_location(self):
"""Always prefetch location for park queries."""
return self.get_queryset().with_location()
class ParkAreaQuerySet(BaseQuerySet):
"""QuerySet for ParkArea model."""
@@ -300,3 +305,33 @@ class CompanyManager(BaseManager):
def major_operators(self, *, min_parks: int = 5):
return self.get_queryset().major_operators(min_parks=min_parks)
def manufacturers_with_ride_count(self):
"""Get manufacturers with ride count annotation for list views."""
return (
self.get_queryset()
.manufacturers()
.annotate(ride_count=Count("manufactured_rides", distinct=True))
.only('id', 'name', 'slug', 'roles', 'description')
.order_by("name")
)
def designers_with_ride_count(self):
"""Get designers with ride count annotation for list views."""
return (
self.get_queryset()
.filter(roles__contains=["DESIGNER"])
.annotate(ride_count=Count("designed_rides", distinct=True))
.only('id', 'name', 'slug', 'roles', 'description')
.order_by("name")
)
def operators_with_park_count(self):
"""Get operators with park count annotation for list views."""
return (
self.get_queryset()
.operators()
.with_park_counts()
.only('id', 'name', 'slug', 'roles', 'description')
.order_by("name")
)

View File

@@ -0,0 +1,25 @@
"""
Add GIN index for Company.roles ArrayField.
This improves query performance for queries like:
Company.objects.filter(roles__contains=["MANUFACTURER"])
Company.objects.filter(roles__contains=["OPERATOR"])
GIN indexes are specifically designed for array containment queries in PostgreSQL.
"""
from django.db import migrations
class Migration(migrations.Migration):
dependencies = [
('parks', '0022_alter_company_roles_alter_companyevent_roles'),
]
operations = [
migrations.RunSQL(
sql="CREATE INDEX IF NOT EXISTS parks_company_roles_gin_idx ON parks_company USING gin(roles);",
reverse_sql="DROP INDEX IF EXISTS parks_company_roles_gin_idx;",
),
]

View File

@@ -0,0 +1,28 @@
"""
Add default value 'UTC' to Park.timezone field.
This ensures all new parks have a valid timezone and existing parks
without a timezone get a sensible default.
"""
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('parks', '0023_add_company_roles_gin_index'),
]
operations = [
migrations.AlterField(
model_name='park',
name='timezone',
field=models.CharField(
blank=True,
default='UTC',
help_text="Timezone identifier for park operations (e.g., 'America/New_York')",
max_length=50,
),
),
]

View File

@@ -127,6 +127,8 @@ class Park(StateMachineMixin, TrackedModel):
# Timezone for park operations
timezone = models.CharField(
max_length=50,
default='UTC',
blank=True,
help_text="Timezone identifier for park operations (e.g., 'America/New_York')"
)

View File

@@ -3,16 +3,21 @@ Services for park-related business logic.
Following Django styleguide pattern for business logic encapsulation.
"""
from typing import Optional, Dict, Any, TYPE_CHECKING
import logging
from typing import Optional, Dict, Any, List, TYPE_CHECKING
from django.db import transaction
from django.db.models import Q
from django.core.files.uploadedfile import UploadedFile
if TYPE_CHECKING:
from django.contrib.auth.models import AbstractUser
from ..models import Park, ParkArea
from ..models import Park, ParkArea, ParkPhoto
from ..models.location import ParkLocation
from .location_service import ParkLocationService
logger = logging.getLogger(__name__)
class ParkService:
"""Service for managing park operations."""
@@ -226,3 +231,282 @@ class ParkService:
park.save()
return park
@staticmethod
def create_park_with_moderation(
*,
changes: Dict[str, Any],
submitter: "AbstractUser",
reason: str = "",
source: str = "",
) -> Dict[str, Any]:
"""
Create a park through the moderation system.
Args:
changes: Dictionary of park data
submitter: User submitting the park
reason: Reason for submission
source: Source of information
Returns:
Dictionary with status and created object (if auto-approved)
"""
from apps.moderation.services import ModerationService
return ModerationService.create_edit_submission_with_queue(
content_object=None,
changes=changes,
submitter=submitter,
submission_type="CREATE",
reason=reason,
source=source,
)
@staticmethod
def update_park_with_moderation(
*,
park: Park,
changes: Dict[str, Any],
submitter: "AbstractUser",
reason: str = "",
source: str = "",
) -> Dict[str, Any]:
"""
Update a park through the moderation system.
Args:
park: Park instance to update
changes: Dictionary of changes
submitter: User submitting the update
reason: Reason for submission
source: Source of information
Returns:
Dictionary with status and updated object (if auto-approved)
"""
from apps.moderation.services import ModerationService
return ModerationService.create_edit_submission_with_queue(
content_object=park,
changes=changes,
submitter=submitter,
submission_type="EDIT",
reason=reason,
source=source,
)
@staticmethod
def create_or_update_location(
*,
park: Park,
latitude: Optional[float],
longitude: Optional[float],
street_address: str = "",
city: str = "",
state: str = "",
country: str = "USA",
postal_code: str = "",
) -> Optional[ParkLocation]:
"""
Create or update a park's location.
Args:
park: Park instance
latitude: Latitude coordinate
longitude: Longitude coordinate
street_address: Street address
city: City name
state: State/region
country: Country (default: USA)
postal_code: Postal/ZIP code
Returns:
ParkLocation instance or None if no coordinates provided
"""
if not latitude or not longitude:
return None
try:
park_location = park.location
# Update existing location
park_location.street_address = street_address
park_location.city = city
park_location.state = state
park_location.country = country or "USA"
park_location.postal_code = postal_code
park_location.set_coordinates(float(latitude), float(longitude))
park_location.save()
return park_location
except ParkLocation.DoesNotExist:
# Create new location
park_location = ParkLocation.objects.create(
park=park,
street_address=street_address,
city=city,
state=state,
country=country or "USA",
postal_code=postal_code,
)
park_location.set_coordinates(float(latitude), float(longitude))
park_location.save()
return park_location
@staticmethod
def upload_photos(
*,
park: Park,
photos: List[UploadedFile],
uploaded_by: "AbstractUser",
) -> Dict[str, Any]:
"""
Upload multiple photos for a park.
Args:
park: Park instance
photos: List of uploaded photo files
uploaded_by: User uploading the photos
Returns:
Dictionary with uploaded_count and errors list
"""
from django.contrib.contenttypes.models import ContentType
uploaded_count = 0
errors: List[str] = []
for photo_file in photos:
try:
ParkPhoto.objects.create(
image=photo_file,
uploaded_by=uploaded_by,
park=park,
)
uploaded_count += 1
except Exception as e:
error_msg = f"Error uploading photo {photo_file.name}: {str(e)}"
errors.append(error_msg)
logger.warning(error_msg)
return {
"uploaded_count": uploaded_count,
"errors": errors,
}
@staticmethod
def handle_park_creation_result(
*,
result: Dict[str, Any],
form_data: Dict[str, Any],
photos: List[UploadedFile],
user: "AbstractUser",
) -> Dict[str, Any]:
"""
Handle the result of park creation through moderation.
Args:
result: Result from create_park_with_moderation
form_data: Cleaned form data containing location info
photos: List of uploaded photo files
user: User who submitted
Returns:
Dictionary with status, park (if created), uploaded_count, and errors
"""
response: Dict[str, Any] = {
"status": result["status"],
"park": None,
"uploaded_count": 0,
"errors": [],
}
if result["status"] == "auto_approved":
park = result["created_object"]
response["park"] = park
# Create location
ParkService.create_or_update_location(
park=park,
latitude=form_data.get("latitude"),
longitude=form_data.get("longitude"),
street_address=form_data.get("street_address", ""),
city=form_data.get("city", ""),
state=form_data.get("state", ""),
country=form_data.get("country", "USA"),
postal_code=form_data.get("postal_code", ""),
)
# Upload photos
if photos:
photo_result = ParkService.upload_photos(
park=park,
photos=photos,
uploaded_by=user,
)
response["uploaded_count"] = photo_result["uploaded_count"]
response["errors"] = photo_result["errors"]
elif result["status"] == "failed":
response["message"] = result.get("message", "Creation failed")
return response
@staticmethod
def handle_park_update_result(
*,
result: Dict[str, Any],
park: Park,
form_data: Dict[str, Any],
photos: List[UploadedFile],
user: "AbstractUser",
) -> Dict[str, Any]:
"""
Handle the result of park update through moderation.
Args:
result: Result from update_park_with_moderation
park: Original park instance (for queued submissions)
form_data: Cleaned form data containing location info
photos: List of uploaded photo files
user: User who submitted
Returns:
Dictionary with status, park, uploaded_count, and errors
"""
response: Dict[str, Any] = {
"status": result["status"],
"park": park,
"uploaded_count": 0,
"errors": [],
}
if result["status"] == "auto_approved":
updated_park = result["created_object"]
response["park"] = updated_park
# Update location
ParkService.create_or_update_location(
park=updated_park,
latitude=form_data.get("latitude"),
longitude=form_data.get("longitude"),
street_address=form_data.get("street_address", ""),
city=form_data.get("city", ""),
state=form_data.get("state", ""),
country=form_data.get("country", ""),
postal_code=form_data.get("postal_code", ""),
)
# Upload photos
if photos:
photo_result = ParkService.upload_photos(
park=updated_park,
photos=photos,
uploaded_by=user,
)
response["uploaded_count"] = photo_result["uploaded_count"]
response["errors"] = photo_result["errors"]
elif result["status"] == "failed":
response["message"] = result.get("message", "Update failed")
return response

View File

@@ -11,6 +11,28 @@ from .models import Park
logger = logging.getLogger(__name__)
# =============================================================================
# Computed Field Maintenance Signals
# =============================================================================
def update_park_search_text(park):
"""
Update park's search_text computed field.
This is called when related objects (location, operator, property_owner)
change and might affect the park's search text.
"""
if park is None:
return
try:
park._populate_computed_fields()
park.save(update_fields=['search_text'])
logger.debug(f"Updated search_text for park {park.pk}")
except Exception as e:
logger.exception(f"Failed to update search_text for park {park.pk}: {e}")
# Status values that count as "active" rides for counting purposes
ACTIVE_STATUSES = {'OPERATING', 'SEASONAL', 'UNDER_CONSTRUCTION'}
@@ -46,8 +68,8 @@ def update_park_ride_counts(park, old_status=None, new_status=None):
# Count total operating rides
ride_count = park.rides.filter(operating_rides).count()
# Count total operating roller coasters
coaster_count = park.rides.filter(operating_rides, category="RC").count()
# Count total operating roller coasters (including water coasters)
coaster_count = park.rides.filter(operating_rides, category__in=["RC", "WC"]).count()
# Update park counts
Park.objects.filter(id=park_id).update(
@@ -148,3 +170,44 @@ def handle_ride_status_transition(instance, source, target, user, **kwargs):
f"by {user if user else 'system'}"
)
update_park_ride_counts(instance.park, source, target)
# =============================================================================
# Computed Field Maintenance Signal Handlers
# =============================================================================
@receiver(post_save, sender='parks.ParkLocation')
def update_park_search_text_on_location_change(sender, instance, **kwargs):
"""
Update park search_text when location changes.
When a park's location is updated (city, state, country changes),
the park's search_text needs to be regenerated to include the new
location information.
"""
try:
if hasattr(instance, 'park') and instance.park:
update_park_search_text(instance.park)
except Exception as e:
logger.exception(f"Failed to update park search_text on location change: {e}")
@receiver(post_save, sender='parks.Company')
def update_park_search_text_on_company_change(sender, instance, **kwargs):
"""
Update park search_text when operator/owner name changes.
When a company's name changes, all parks operated or owned by that
company need their search_text regenerated.
"""
try:
# Update all parks operated by this company
for park in instance.operated_parks.all():
update_park_search_text(park)
# Update all parks owned by this company
for park in instance.owned_parks.all():
update_park_search_text(park)
except Exception as e:
logger.exception(f"Failed to update park search_text on company change: {e}")

View File

@@ -0,0 +1,243 @@
"""
Tests for query optimization patterns in the parks app.
These tests verify that:
1. Manager methods use proper select_related/prefetch_related
2. Views don't trigger N+1 queries
3. Computed fields are updated correctly
"""
from django.test import TestCase
from django.db import connection
from django.test.utils import CaptureQueriesContext
from apps.parks.models import Park, ParkLocation, Company
class ParkQueryOptimizationTests(TestCase):
"""Tests for Park query optimization."""
@classmethod
def setUpTestData(cls):
"""Set up test data for all tests."""
# Create a test operator company
cls.operator = Company.objects.create(
name="Test Operator",
slug="test-operator",
roles=["OPERATOR"],
)
# Create test parks
cls.parks = []
for i in range(5):
park = Park.objects.create(
name=f"Test Park {i}",
slug=f"test-park-{i}",
operator=cls.operator,
timezone="UTC",
)
# Create location for each park
ParkLocation.objects.create(
park=park,
city=f"City {i}",
state="CA",
country="USA",
)
cls.parks.append(park)
def test_optimized_for_list_query_count(self):
"""Verify optimized_for_list uses expected number of queries."""
with CaptureQueriesContext(connection) as context:
parks = Park.objects.optimized_for_list()
# Force evaluation
list(parks)
# Should be a small number of queries (main query + prefetch)
# The exact count depends on prefetch_related configuration
self.assertLessEqual(
len(context.captured_queries),
5,
f"Expected <= 5 queries, got {len(context.captured_queries)}"
)
def test_optimized_for_detail_query_count(self):
"""Verify optimized_for_detail uses expected number of queries."""
with CaptureQueriesContext(connection) as context:
parks = Park.objects.optimized_for_detail()
park = parks.first()
if park:
# Access related objects that should be prefetched
_ = park.operator
_ = list(park.areas.all())
# Should be a reasonable number of queries
self.assertLessEqual(
len(context.captured_queries),
10,
f"Expected <= 10 queries, got {len(context.captured_queries)}"
)
def test_with_location_includes_location(self):
"""Verify with_location prefetches location data."""
with CaptureQueriesContext(connection) as context:
parks = Park.objects.with_location()
for park in parks:
# Accessing location should not cause additional queries
_ = park.location
# Should be minimal queries
self.assertLessEqual(len(context.captured_queries), 3)
def test_for_map_display_returns_minimal_fields(self):
"""Verify for_map_display returns only necessary fields."""
result = Park.objects.for_map_display()
if result.exists():
first = result.first()
# Should include these fields
self.assertIn('id', first)
self.assertIn('name', first)
self.assertIn('slug', first)
self.assertIn('status', first)
def test_search_autocomplete_limits_results(self):
"""Verify search_autocomplete respects limit parameter."""
result = Park.objects.search_autocomplete(query="Test", limit=3)
self.assertLessEqual(len(result), 3)
class CompanyQueryOptimizationTests(TestCase):
"""Tests for Company query optimization."""
@classmethod
def setUpTestData(cls):
"""Set up test data for all tests."""
# Create test companies with different roles
cls.manufacturers = []
for i in range(5):
company = Company.objects.create(
name=f"Manufacturer {i}",
slug=f"manufacturer-{i}",
roles=["MANUFACTURER"],
)
cls.manufacturers.append(company)
cls.operators = []
for i in range(3):
company = Company.objects.create(
name=f"Operator {i}",
slug=f"operator-{i}",
roles=["OPERATOR"],
)
cls.operators.append(company)
def test_manufacturers_query_only_returns_manufacturers(self):
"""Verify manufacturers() only returns companies with MANUFACTURER role."""
result = Company.objects.manufacturers()
for company in result:
self.assertIn("MANUFACTURER", company.roles)
def test_operators_query_only_returns_operators(self):
"""Verify operators() only returns companies with OPERATOR role."""
result = Company.objects.operators()
for company in result:
self.assertIn("OPERATOR", company.roles)
def test_manufacturers_with_ride_count_includes_annotation(self):
"""Verify manufacturers_with_ride_count adds ride_count annotation."""
result = Company.objects.manufacturers_with_ride_count()
if result.exists():
first = result.first()
# Should have ride_count attribute
self.assertTrue(hasattr(first, 'ride_count'))
def test_operators_with_park_count_includes_annotation(self):
"""Verify operators_with_park_count adds park count annotations."""
result = Company.objects.operators_with_park_count()
if result.exists():
first = result.first()
# Should have operated_parks_count attribute
self.assertTrue(hasattr(first, 'operated_parks_count'))
class ComputedFieldMaintenanceTests(TestCase):
"""Tests for computed field maintenance via signals."""
@classmethod
def setUpTestData(cls):
"""Set up test data for all tests."""
cls.operator = Company.objects.create(
name="Test Operator",
slug="test-operator",
roles=["OPERATOR"],
)
def test_park_search_text_includes_name(self):
"""Verify park search_text includes park name."""
park = Park.objects.create(
name="Magic Kingdom",
slug="magic-kingdom",
operator=self.operator,
timezone="UTC",
)
self.assertIn("magic kingdom", park.search_text.lower())
def test_park_search_text_includes_description(self):
"""Verify park search_text includes description."""
park = Park.objects.create(
name="Test Park",
slug="test-park",
description="A magical theme park experience",
operator=self.operator,
timezone="UTC",
)
self.assertIn("magical", park.search_text.lower())
def test_park_search_text_includes_operator(self):
"""Verify park search_text includes operator name."""
park = Park.objects.create(
name="Test Park",
slug="test-park-2",
operator=self.operator,
timezone="UTC",
)
self.assertIn("test operator", park.search_text.lower())
def test_park_opening_year_computed_from_date(self):
"""Verify opening_year is computed from opening_date."""
from datetime import date
park = Park.objects.create(
name="Test Park",
slug="test-park-3",
operator=self.operator,
timezone="UTC",
opening_date=date(1971, 10, 1),
)
self.assertEqual(park.opening_year, 1971)
def test_park_search_text_updated_on_location_change(self):
"""Verify park search_text updates when location changes."""
park = Park.objects.create(
name="Test Park",
slug="test-park-4",
operator=self.operator,
timezone="UTC",
)
# Initially no location in search_text
original_search_text = park.search_text
# Add location
location = ParkLocation.objects.create(
park=park,
city="Orlando",
state="Florida",
country="USA",
)
# Refresh park from database
park.refresh_from_db()
# Search text should now include location
# Note: This depends on signal handlers being properly registered
# The actual behavior may vary based on signal configuration

View File

@@ -2,7 +2,6 @@ from .querysets import get_base_park_queryset
from apps.core.mixins import HTMXFilterableMixin
from .models.location import ParkLocation
from .models.media import ParkPhoto
from apps.moderation.services import ModerationService
from apps.moderation.mixins import (
EditSubmissionMixin,
PhotoSubmissionMixin,
@@ -12,7 +11,7 @@ from apps.core.views.views import SlugRedirectMixin
from .filters import ParkFilter
from .forms import ParkForm
from .models import Park, ParkArea, ParkReview as Review
from .services import ParkFilterService
from .services import ParkFilterService, ParkService
from django.http import (
HttpResponseRedirect,
HttpResponse,
@@ -21,7 +20,6 @@ from django.http import (
)
from django.core.exceptions import ObjectDoesNotExist
from django.contrib import messages
from django.contrib.contenttypes.models import ContentType
from django.contrib.auth.mixins import LoginRequiredMixin
from django.db.models import QuerySet
from django.urls import reverse
@@ -33,7 +31,7 @@ from decimal import Decimal, ROUND_DOWN
from typing import Any, Optional, cast, Literal, Dict
from django.views.decorators.http import require_POST
from django.template.loader import render_to_string
import json
# Constants
@@ -137,22 +135,24 @@ def park_status_actions(request: HttpRequest, slug: str) -> HttpResponse:
park = get_object_or_404(Park, slug=slug)
# Only show to moderators
if not request.user.has_perm('parks.change_park'):
if not request.user.has_perm("parks.change_park"):
return HttpResponse("")
return render(request, "parks/partials/park_status_actions.html", {
"park": park,
"user": request.user
})
return render(
request,
"parks/partials/park_status_actions.html",
{"park": park, "user": request.user},
)
def park_header_badge(request: HttpRequest, slug: str) -> HttpResponse:
"""Return the header status badge partial for a park"""
park = get_object_or_404(Park, slug=slug)
return render(request, "parks/partials/park_header_badge.html", {
"park": park,
"user": request.user
})
return render(
request,
"parks/partials/park_header_badge.html",
{"park": park, "user": request.user},
)
def get_park_areas(request: HttpRequest) -> HttpResponse:
@@ -502,6 +502,7 @@ def htmx_saved_trips(request: HttpRequest) -> HttpResponse:
if request.user.is_authenticated:
try:
from .models import Trip # type: ignore
qs = Trip.objects.filter(owner=request.user).order_by("-created_at")
trips = list(qs[:10])
except Exception:
@@ -648,7 +649,10 @@ def htmx_optimize_route(request: HttpRequest) -> HttpResponse:
rlat1, rlon1, rlat2, rlon2 = map(math.radians, [lat1, lon1, lat2, lon2])
dlat = rlat2 - rlat1
dlon = rlon2 - rlon1
a = math.sin(dlat / 2) ** 2 + math.cos(rlat1) * math.cos(rlat2) * math.sin(dlon / 2) ** 2
a = (
math.sin(dlat / 2) ** 2
+ math.cos(rlat1) * math.cos(rlat2) * math.sin(dlon / 2) ** 2
)
c = 2 * math.asin(min(1, math.sqrt(a)))
miles = 3958.8 * c
return miles
@@ -660,14 +664,18 @@ def htmx_optimize_route(request: HttpRequest) -> HttpResponse:
lat = getattr(loc, "latitude", None) if loc else None
lon = getattr(loc, "longitude", None) if loc else None
if lat is not None and lon is not None:
waypoints.append({"id": p.id, "name": p.name, "latitude": lat, "longitude": lon})
waypoints.append(
{"id": p.id, "name": p.name, "latitude": lat, "longitude": lon}
)
# sum straight-line distances between consecutive waypoints
for i in range(len(waypoints) - 1):
a = waypoints[i]
b = waypoints[i + 1]
try:
total_miles += haversine_miles(a["latitude"], a["longitude"], b["latitude"], b["longitude"])
total_miles += haversine_miles(
a["latitude"], a["longitude"], b["latitude"], b["longitude"]
)
except Exception:
continue
@@ -681,7 +689,9 @@ def htmx_optimize_route(request: HttpRequest) -> HttpResponse:
"total_rides": sum(getattr(p, "ride_count", 0) or 0 for p in parks),
}
html = render_to_string(TRIP_SUMMARY_TEMPLATE, {"summary": summary}, request=request)
html = render_to_string(
TRIP_SUMMARY_TEMPLATE, {"summary": summary}, request=request
)
resp = HttpResponse(html)
# Include waypoints payload in HX-Trigger so client can render route on the map
resp["HX-Trigger"] = json.dumps({"tripOptimized": {"parks": waypoints}})
@@ -710,13 +720,16 @@ def htmx_save_trip(request: HttpRequest) -> HttpResponse:
if request.user.is_authenticated:
try:
from .models import Trip # type: ignore
trip = Trip.objects.create(owner=request.user, name=name)
# attempt to associate parks if the Trip model supports it
try:
trip.parks.set([p.id for p in parks])
except Exception:
pass
trips = list(Trip.objects.filter(owner=request.user).order_by("-created_at")[:10])
trips = list(
Trip.objects.filter(owner=request.user).order_by("-created_at")[:10]
)
except Exception:
trips = []
@@ -735,6 +748,7 @@ def htmx_clear_trip(request: HttpRequest) -> HttpResponse:
resp["HX-Trigger"] = json.dumps({"tripCleared": True})
return resp
class ParkCreateView(LoginRequiredMixin, CreateView):
model = Park
form_class = ParkForm
@@ -775,76 +789,49 @@ class ParkCreateView(LoginRequiredMixin, CreateView):
self.normalize_coordinates(form)
changes = self.prepare_changes_data(form.cleaned_data)
# Use the new queue routing service
result = ModerationService.create_edit_submission_with_queue(
content_object=None, # None for CREATE
# Submit through moderation service
result = ParkService.create_park_with_moderation(
changes=changes,
submitter=self.request.user,
submission_type="CREATE",
reason=self.request.POST.get("reason", ""),
source=self.request.POST.get("source", ""),
)
if result['status'] == 'auto_approved':
# Moderator submission was auto-approved
self.object = result['created_object']
# Handle the result using the service
photos = self.request.FILES.getlist("photos")
service_result = ParkService.handle_park_creation_result(
result=result,
form_data=form.cleaned_data,
photos=photos,
user=self.request.user,
)
if form.cleaned_data.get("latitude") and form.cleaned_data.get("longitude"):
# Create or update ParkLocation
park_location, _ = ParkLocation.objects.get_or_create(
park=self.object,
defaults={
"street_address": form.cleaned_data.get("street_address", ""),
"city": form.cleaned_data.get("city", ""),
"state": form.cleaned_data.get("state", ""),
"country": form.cleaned_data.get("country", "USA"),
"postal_code": form.cleaned_data.get("postal_code", ""),
},
)
park_location.set_coordinates(
form.cleaned_data["latitude"],
form.cleaned_data["longitude"],
)
park_location.save()
photos = self.request.FILES.getlist("photos")
uploaded_count = 0
for photo_file in photos:
try:
ParkPhoto.objects.create(
image=photo_file,
uploaded_by=self.request.user,
park=self.object,
)
uploaded_count += 1
except Exception as e:
messages.error(
self.request,
f"Error uploading photo {photo_file.name}: {str(e)}",
)
# Report any photo upload errors
for error in service_result.get("errors", []):
messages.error(self.request, error)
if service_result["status"] == "auto_approved":
self.object = service_result["park"]
messages.success(
self.request,
f"Successfully created {self.object.name}. "
f"Added {uploaded_count} photo(s).",
f"Added {service_result['uploaded_count']} photo(s).",
)
return HttpResponseRedirect(self.get_success_url())
elif result['status'] == 'queued':
# Regular user submission was queued
elif service_result["status"] == "queued":
messages.success(
self.request,
"Your park submission has been sent for review. "
"You will be notified when it is approved.",
)
# Redirect to parks list since we don't have an object yet
return HttpResponseRedirect(reverse("parks:park_list"))
elif result['status'] == 'failed':
# Auto-approval failed
elif service_result["status"] == "failed":
messages.error(
self.request,
f"Error creating park: {result['message']}. Please check your input and try again.",
f"Error creating park: {service_result.get('message', 'Unknown error')}. "
"Please check your input and try again.",
)
return self.form_invalid(form)
@@ -900,110 +887,43 @@ class ParkUpdateView(LoginRequiredMixin, UpdateView):
Decimal("0.000001"), rounding=ROUND_DOWN
)
def form_valid(self, form: ParkForm) -> HttpResponse: # noqa: C901
def form_valid(self, form: ParkForm) -> HttpResponse:
self.normalize_coordinates(form)
changes = self.prepare_changes_data(form.cleaned_data)
# Use the new queue routing service
result = ModerationService.create_edit_submission_with_queue(
content_object=self.object,
# Submit through moderation service
result = ParkService.update_park_with_moderation(
park=self.object,
changes=changes,
submitter=self.request.user,
submission_type="EDIT",
reason=self.request.POST.get("reason", ""),
source=self.request.POST.get("source", ""),
)
if result['status'] == 'auto_approved':
# Moderator submission was auto-approved
# The object was already updated by the service
self.object = result['created_object']
# Handle the result using the service
photos = self.request.FILES.getlist("photos")
service_result = ParkService.handle_park_update_result(
result=result,
park=self.object,
form_data=form.cleaned_data,
photos=photos,
user=self.request.user,
)
location_data = {
"name": self.object.name,
"location_type": "park",
"latitude": form.cleaned_data.get("latitude"),
"longitude": form.cleaned_data.get("longitude"),
"street_address": form.cleaned_data.get("street_address", ""),
"city": form.cleaned_data.get("city", ""),
"state": form.cleaned_data.get("state", ""),
"country": form.cleaned_data.get("country", ""),
"postal_code": form.cleaned_data.get("postal_code", ""),
}
# Create or update ParkLocation
try:
park_location = self.object.location
# Update existing location
for key, value in location_data.items():
if key in ["latitude", "longitude"] and value:
continue # Handle coordinates separately
if hasattr(park_location, key):
setattr(park_location, key, value)
# Handle coordinates if provided
if "latitude" in location_data and "longitude" in location_data:
if location_data["latitude"] and location_data["longitude"]:
park_location.set_coordinates(
float(location_data["latitude"]),
float(location_data["longitude"]),
)
park_location.save()
except ParkLocation.DoesNotExist:
# Create new ParkLocation
coordinates_data = {}
if "latitude" in location_data and "longitude" in location_data:
if location_data["latitude"] and location_data["longitude"]:
coordinates_data = {
"latitude": float(location_data["latitude"]),
"longitude": float(location_data["longitude"]),
}
# Remove coordinate fields from location_data for creation
creation_data = {
k: v
for k, v in location_data.items()
if k not in ["latitude", "longitude"]
}
creation_data.setdefault("country", "USA")
park_location = ParkLocation.objects.create(
park=self.object, **creation_data
)
if coordinates_data:
park_location.set_coordinates(
coordinates_data["latitude"],
coordinates_data["longitude"],
)
park_location.save()
photos = self.request.FILES.getlist("photos")
uploaded_count = 0
for photo_file in photos:
try:
ParkPhoto.objects.create(
image=photo_file,
uploaded_by=self.request.user,
content_type=ContentType.objects.get_for_model(Park),
object_id=self.object.id,
)
uploaded_count += 1
except Exception as e:
messages.error(
self.request,
f"Error uploading photo {photo_file.name}: {str(e)}",
)
# Report any photo upload errors
for error in service_result.get("errors", []):
messages.error(self.request, error)
if service_result["status"] == "auto_approved":
self.object = service_result["park"]
messages.success(
self.request,
f"Successfully updated {self.object.name}. "
f"Added {uploaded_count} new photo(s).",
f"Added {service_result['uploaded_count']} new photo(s).",
)
return HttpResponseRedirect(self.get_success_url())
elif result['status'] == 'queued':
# Regular user submission was queued
elif service_result["status"] == "queued":
messages.success(
self.request,
f"Your changes to {self.object.name} have been sent for review. "
@@ -1013,11 +933,11 @@ class ParkUpdateView(LoginRequiredMixin, UpdateView):
reverse(PARK_DETAIL_URL, kwargs={"slug": self.object.slug})
)
elif result['status'] == 'failed':
# Auto-approval failed
elif service_result["status"] == "failed":
messages.error(
self.request,
f"Error updating park: {result['message']}. Please check your input and try again.",
f"Error updating park: {service_result.get('message', 'Unknown error')}. "
"Please check your input and try again.",
)
return self.form_invalid(form)
@@ -1133,13 +1053,14 @@ class OperatorListView(ListView):
paginate_by = 24
def get_queryset(self):
"""Get companies that are operators"""
"""Get companies that are operators with optimized query"""
from .models.companies import Company
from django.db.models import Count
return (
Company.objects.filter(roles__contains=["OPERATOR"])
.annotate(park_count=Count("operated_parks"))
.only("id", "name", "slug", "roles", "description")
.order_by("name")
)

View File

@@ -75,13 +75,14 @@ class RideQuerySet(StatusQuerySet, ReviewableQuerySet):
return self.select_related(
"park",
"park__location",
"park_area",
"manufacturer",
"designer",
"ride_model",
"ride_model__manufacturer",
).prefetch_related(
"location",
"rollercoaster_stats",
"coaster_stats",
Prefetch(
"reviews",
queryset=RideReview.objects.select_related("user")
@@ -91,6 +92,12 @@ class RideQuerySet(StatusQuerySet, ReviewableQuerySet):
"photos",
)
def with_coaster_stats(self):
"""Always prefetch coaster_stats for roller coaster queries."""
return self.select_related(
"park", "manufacturer", "ride_model"
).prefetch_related("coaster_stats")
def for_map_display(self):
"""Optimize for map display."""
return (
@@ -176,6 +183,10 @@ class RideManager(StatusManager, ReviewableManager):
def optimized_for_detail(self):
return self.get_queryset().optimized_for_detail()
def with_coaster_stats(self):
"""Always prefetch coaster_stats for roller coaster queries."""
return self.get_queryset().with_coaster_stats()
class RideModelQuerySet(BaseQuerySet):
"""QuerySet for RideModel model."""

View File

@@ -0,0 +1,40 @@
"""
Convert RideModel unique_together to UniqueConstraint.
This migration converts the legacy unique_together constraints to the modern
UniqueConstraint syntax which provides better error messages and more flexibility.
"""
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('rides', '0025_convert_ride_status_to_fsm'),
]
operations = [
# Remove the old unique_together constraint
migrations.AlterUniqueTogether(
name='ridemodel',
unique_together=set(),
),
# Add new UniqueConstraints with better error messages
migrations.AddConstraint(
model_name='ridemodel',
constraint=models.UniqueConstraint(
fields=['manufacturer', 'name'],
name='ridemodel_manufacturer_name_unique',
violation_error_message='A ride model with this name already exists for this manufacturer'
),
),
migrations.AddConstraint(
model_name='ridemodel',
constraint=models.UniqueConstraint(
fields=['manufacturer', 'slug'],
name='ridemodel_manufacturer_slug_unique',
violation_error_message='A ride model with this slug already exists for this manufacturer'
),
),
]

View File

@@ -0,0 +1,49 @@
"""
Mixins for ride views.
This module contains mixins that provide reusable functionality
for ride-related views, reducing code duplication.
"""
from typing import Any, Dict
from django.contrib import messages
from apps.rides.services import RideService
class RideFormMixin:
"""
Mixin for handling ride form submissions with entity suggestions.
Provides common functionality for RideCreateView and RideUpdateView
to handle new manufacturer, designer, and ride model suggestions.
"""
def handle_entity_suggestions(self, form) -> Dict[str, Any]:
"""
Process new entity suggestions from form.
Creates moderation submissions for any new manufacturers,
designers, or ride models that were suggested but don't
exist in the system.
Args:
form: Validated form instance with cleaned_data
Returns:
Dictionary with submission results from RideService
"""
result = RideService.handle_new_entity_suggestions(
form_data=form.cleaned_data,
submitter=self.request.user
)
if result['total_submissions'] > 0:
messages.info(
self.request,
f"Created {result['total_submissions']} moderation submission(s) "
"for new entities"
)
return result

View File

@@ -1,5 +1,6 @@
from django.db import models
from django.utils.text import slugify
from django.core.exceptions import ValidationError
from config.django import base as settings
from apps.core.models import TrackedModel
from apps.core.choices import RichChoiceField
@@ -165,8 +166,18 @@ class RideModel(TrackedModel):
class Meta(TrackedModel.Meta):
ordering = ["manufacturer__name", "name"]
unique_together = [["manufacturer", "name"], ["manufacturer", "slug"]]
constraints = [
# Unique constraints (replacing unique_together for better error messages)
models.UniqueConstraint(
fields=['manufacturer', 'name'],
name='ridemodel_manufacturer_name_unique',
violation_error_message='A ride model with this name already exists for this manufacturer'
),
models.UniqueConstraint(
fields=['manufacturer', 'slug'],
name='ridemodel_manufacturer_slug_unique',
violation_error_message='A ride model with this slug already exists for this manufacturer'
),
# Height range validation
models.CheckConstraint(
name="ride_model_height_range_logical",
@@ -222,6 +233,14 @@ class RideModel(TrackedModel):
else f"{self.manufacturer.name} {self.name}"
)
def clean(self) -> None:
"""Validate RideModel business rules."""
super().clean()
if self.is_discontinued and not self.last_installation_year:
raise ValidationError({
'last_installation_year': 'Discontinued models must have a last installation year'
})
def save(self, *args, **kwargs) -> None:
if not self.slug:
from django.utils.text import slugify

View File

@@ -307,3 +307,85 @@ class RideService:
ride = Ride.objects.select_for_update().get(id=ride_id)
ride.open(user=user)
return ride
@staticmethod
def handle_new_entity_suggestions(
*,
form_data: Dict[str, Any],
submitter: UserType,
) -> Dict[str, Any]:
"""
Handle suggestions for new manufacturers, designers, and ride models.
Creates moderation submissions for entities that don't exist in the system.
This extracts the business logic from RideCreateView and RideUpdateView.
Args:
form_data: Cleaned form data containing search fields and selections
submitter: User making the suggestions
Returns:
Dictionary with lists of created submission IDs by type:
{
'manufacturers': [...],
'designers': [...],
'ride_models': [...],
'total_submissions': int
}
"""
from apps.moderation.services import ModerationService
result = {
'manufacturers': [],
'designers': [],
'ride_models': [],
'total_submissions': 0
}
# Check for new manufacturer
manufacturer_name = form_data.get("manufacturer_search")
if manufacturer_name and not form_data.get("manufacturer"):
submission = ModerationService.create_edit_submission_with_queue(
content_object=None,
changes={"name": manufacturer_name, "roles": ["MANUFACTURER"]},
submitter=submitter,
submission_type="CREATE",
reason=f"New manufacturer suggested: {manufacturer_name}",
)
if submission:
result['manufacturers'].append(submission.id)
result['total_submissions'] += 1
# Check for new designer
designer_name = form_data.get("designer_search")
if designer_name and not form_data.get("designer"):
submission = ModerationService.create_edit_submission_with_queue(
content_object=None,
changes={"name": designer_name, "roles": ["DESIGNER"]},
submitter=submitter,
submission_type="CREATE",
reason=f"New designer suggested: {designer_name}",
)
if submission:
result['designers'].append(submission.id)
result['total_submissions'] += 1
# Check for new ride model
ride_model_name = form_data.get("ride_model_search")
manufacturer = form_data.get("manufacturer")
if ride_model_name and not form_data.get("ride_model") and manufacturer:
submission = ModerationService.create_edit_submission_with_queue(
content_object=None,
changes={
"name": ride_model_name,
"manufacturer": manufacturer.id,
},
submitter=submitter,
submission_type="CREATE",
reason=f"New ride model suggested: {ride_model_name}",
)
if submission:
result['ride_models'].append(submission.id)
result['total_submissions'] += 1
return result

View File

@@ -1,6 +1,6 @@
import logging
from django.db.models.signals import pre_save
from django.db.models.signals import pre_save, post_save
from django.dispatch import receiver
from django.utils import timezone
@@ -10,6 +10,28 @@ from .models import Ride
logger = logging.getLogger(__name__)
# =============================================================================
# Computed Field Maintenance
# =============================================================================
def update_ride_search_text(ride):
"""
Update ride's search_text computed field.
This is called when related objects (park, manufacturer, ride_model)
change and might affect the ride's search text.
"""
if ride is None:
return
try:
ride._populate_computed_fields()
ride.save(update_fields=['search_text'])
logger.debug(f"Updated search_text for ride {ride.pk}")
except Exception as e:
logger.exception(f"Failed to update search_text for ride {ride.pk}: {e}")
@receiver(pre_save, sender=Ride)
def handle_ride_status(sender, instance, **kwargs):
"""
@@ -186,3 +208,58 @@ def apply_post_closing_status(instance, user=None):
f"Applied post_closing_status {target_status} to ride {instance.pk} (direct)"
)
return True
# =============================================================================
# Computed Field Maintenance Signal Handlers
# =============================================================================
@receiver(post_save, sender='parks.Park')
def update_ride_search_text_on_park_change(sender, instance, **kwargs):
"""
Update ride search_text when park name or location changes.
When a park's name changes, all rides at that park need their
search_text regenerated.
"""
try:
for ride in instance.rides.all():
update_ride_search_text(ride)
except Exception as e:
logger.exception(f"Failed to update ride search_text on park change: {e}")
@receiver(post_save, sender='parks.Company')
def update_ride_search_text_on_company_change(sender, instance, **kwargs):
"""
Update ride search_text when manufacturer/designer name changes.
When a company's name changes, all rides manufactured or designed
by that company need their search_text regenerated.
"""
try:
# Update all rides manufactured by this company
for ride in instance.manufactured_rides.all():
update_ride_search_text(ride)
# Update all rides designed by this company
for ride in instance.designed_rides.all():
update_ride_search_text(ride)
except Exception as e:
logger.exception(f"Failed to update ride search_text on company change: {e}")
@receiver(post_save, sender='rides.RideModel')
def update_ride_search_text_on_ride_model_change(sender, instance, **kwargs):
"""
Update ride search_text when ride model name changes.
When a ride model's name changes, all rides using that model need
their search_text regenerated.
"""
try:
for ride in instance.rides.all():
update_ride_search_text(ride)
except Exception as e:
logger.exception(f"Failed to update ride search_text on ride model change: {e}")

View File

@@ -1,80 +1,188 @@
from django.views.generic import DetailView, ListView, CreateView, UpdateView
"""
Views for ride functionality.
This module contains views for managing ride resources including
CRUD operations, search, filtering, and HTMX partial rendering.
View Types:
CBVs:
- RideDetailView: Display ride details
- RideCreateView: Create new ride
- RideUpdateView: Update existing ride
- RideListView: List rides with filtering
- RideSearchView: Search rides with HTMX support
- RideRankingsView: Display ride rankings
- RideRankingDetailView: Display ranking details
- ManufacturerListView: List manufacturers
- DesignerListView: List designers
- SingleCategoryListView: List rides by category
FBVs (HTMX Partials):
- show_coaster_fields: Toggle coaster-specific fields
- ride_status_actions: FSM status actions for moderators
- ride_header_badge: Status badge partial
- search_companies: Company search autocomplete
- search_ride_models: Ride model search autocomplete
- get_search_suggestions: Smart search suggestions
- ranking_history_chart: Ranking history chart data
- ranking_comparisons: Head-to-head comparisons
Dependencies:
- Services: apps.rides.services
- Models: apps.rides.models
- Forms: apps.rides.forms
Code Quality:
- PEP 8 compliant (verified with black, flake8, ruff)
- Maximum line length: 88 characters
- Maximum complexity: 10 (McCabe)
"""
from django.contrib.auth.mixins import LoginRequiredMixin
from django.db.models import Count, Q
from django.http import Http404, HttpRequest, HttpResponse
from django.shortcuts import get_object_or_404, render
from django.urls import reverse
from django.db.models import Q
from django.contrib.auth.mixins import LoginRequiredMixin
from django.http import HttpRequest, HttpResponse, Http404
from django.db.models import Count
from .models.rides import Ride, RideModel
from .choices import RIDE_CATEGORIES
from .models.company import Company
from django.views.generic import CreateView, DetailView, ListView, UpdateView
from apps.moderation.mixins import EditSubmissionMixin, HistoryMixin
from apps.parks.models import Park
from .forms import RideForm, RideSearchForm
from .forms.search import MasterFilterForm
from .services.search import RideSearchService
from apps.parks.models import Park
from apps.moderation.mixins import EditSubmissionMixin, HistoryMixin
from apps.moderation.services import ModerationService
from .models.rankings import RideRanking, RankingSnapshot
from .mixins import RideFormMixin
from .models.company import Company
from .models.rankings import RankingSnapshot, RideRanking
from .models.rides import Ride, RideModel
from .services.ranking_service import RideRankingService
class ParkContextRequired:
"""Mixin to require park context for views"""
"""
Mixin to require park context for views.
Ensures that the view has access to a park_slug URL parameter.
Raises Http404 if park context is not available.
"""
def dispatch(self, request, *args, **kwargs):
"""Check for park context before dispatching to handler."""
if "park_slug" not in self.kwargs:
raise Http404("Park context is required")
return super().dispatch(request, *args, **kwargs)
def show_coaster_fields(request: HttpRequest) -> HttpResponse:
"""Show roller coaster specific fields based on category selection"""
"""
Show roller coaster specific fields based on category selection.
View Type: FBV (HTMX Partial)
URL Pattern: /rides/coaster-fields/
Returns: HTML partial with coaster-specific form fields
Args:
request: HTTP request with 'category' query parameter
Returns:
Empty response for non-coaster categories,
or coaster fields partial for roller coasters
"""
category = request.GET.get("category")
if category != "RC": # Only show for roller coasters
return HttpResponse("")
return render(request, "rides/partials/coaster_fields.html")
def ride_status_actions(request: HttpRequest, park_slug: str, ride_slug: str) -> HttpResponse:
"""Return FSM status actions for ride moderators"""
def ride_status_actions(
request: HttpRequest, park_slug: str, ride_slug: str
) -> HttpResponse:
"""
Return FSM status actions for ride moderators.
View Type: FBV (HTMX Partial)
URL Pattern: /parks/<park_slug>/rides/<ride_slug>/status-actions/
Returns: HTML partial with available status transition actions
Permissions: rides.change_ride
Args:
request: HTTP request
park_slug: Slug of the park
ride_slug: Slug of the ride
Returns:
Empty response for non-moderators, or status actions partial
"""
park = get_object_or_404(Park, slug=park_slug)
ride = get_object_or_404(Ride, park=park, slug=ride_slug)
# Only show to moderators
if not request.user.has_perm('rides.change_ride'):
if not request.user.has_perm("rides.change_ride"):
return HttpResponse("")
return render(request, "rides/partials/ride_status_actions.html", {
"ride": ride,
"park": park,
"user": request.user
})
return render(
request,
"rides/partials/ride_status_actions.html",
{"ride": ride, "park": park, "user": request.user},
)
def ride_header_badge(request: HttpRequest, park_slug: str, ride_slug: str) -> HttpResponse:
"""Return the header status badge partial for a ride"""
def ride_header_badge(
request: HttpRequest, park_slug: str, ride_slug: str
) -> HttpResponse:
"""
Return the header status badge partial for a ride.
View Type: FBV (HTMX Partial)
URL Pattern: /parks/<park_slug>/rides/<ride_slug>/header-badge/
Returns: HTML partial with ride status badge
Args:
request: HTTP request
park_slug: Slug of the park
ride_slug: Slug of the ride
Returns:
Rendered status badge partial
"""
park = get_object_or_404(Park, slug=park_slug)
ride = get_object_or_404(Ride, park=park, slug=ride_slug)
return render(request, "rides/partials/ride_header_badge.html", {
"ride": ride,
"park": park,
"user": request.user
})
return render(
request,
"rides/partials/ride_header_badge.html",
{"ride": ride, "park": park, "user": request.user},
)
class RideDetailView(HistoryMixin, DetailView):
"""View for displaying ride details"""
"""
Display ride details with related data.
View Type: CBV (DetailView)
URL Pattern: /parks/<park_slug>/rides/<ride_slug>/
Template: rides/ride_detail.html
Permissions: Public
Includes history tracking via HistoryMixin for audit trail display.
"""
model = Ride
template_name = "rides/ride_detail.html"
slug_url_kwarg = "ride_slug"
def get_queryset(self):
"""Get ride for the specific park if park_slug is provided"""
"""Get ride for the specific park with optimized query."""
queryset = (
Ride.objects.all()
.select_related("park", "ride_model", "ride_model__manufacturer")
.prefetch_related("photos")
.select_related(
"park",
"park__location",
"park_area",
"manufacturer",
"designer",
"ride_model",
"ride_model__manufacturer",
)
.prefetch_related("photos", "coaster_stats")
)
if "park_slug" in self.kwargs:
@@ -92,15 +200,24 @@ class RideDetailView(HistoryMixin, DetailView):
return context
class RideCreateView(LoginRequiredMixin, ParkContextRequired, CreateView):
"""View for creating a new ride"""
class RideCreateView(
LoginRequiredMixin, ParkContextRequired, RideFormMixin, CreateView
):
"""
View for creating a new ride.
View Type: CBV (CreateView)
URL Pattern: /parks/<park_slug>/rides/add/
Template: rides/ride_form.html
Permissions: LoginRequired
"""
model = Ride
form_class = RideForm
template_name = "rides/ride_form.html"
def get_success_url(self):
"""Get URL to redirect to after successful creation"""
"""Get URL to redirect to after successful creation."""
return reverse(
"parks:rides:ride_detail",
kwargs={
@@ -110,14 +227,14 @@ class RideCreateView(LoginRequiredMixin, ParkContextRequired, CreateView):
)
def get_form_kwargs(self):
"""Pass park to the form"""
"""Pass park to the form."""
kwargs = super().get_form_kwargs()
self.park = get_object_or_404(Park, slug=self.kwargs["park_slug"])
kwargs["park"] = self.park
return kwargs
def get_context_data(self, **kwargs):
"""Add park and park_slug to context"""
"""Add park and park_slug to context."""
context = super().get_context_data(**kwargs)
context["park"] = self.park
context["park_slug"] = self.park.slug
@@ -125,51 +242,26 @@ class RideCreateView(LoginRequiredMixin, ParkContextRequired, CreateView):
return context
def form_valid(self, form):
"""Handle form submission including new items"""
# Check for new manufacturer
manufacturer_name = form.cleaned_data.get("manufacturer_search")
if manufacturer_name and not form.cleaned_data.get("manufacturer"):
ModerationService.create_edit_submission_with_queue(
content_object=None,
changes={"name": manufacturer_name, "roles": ["MANUFACTURER"]},
submitter=self.request.user,
submission_type="CREATE",
reason=f"New manufacturer suggested during ride creation: {manufacturer_name}",
)
# Check for new designer
designer_name = form.cleaned_data.get("designer_search")
if designer_name and not form.cleaned_data.get("designer"):
ModerationService.create_edit_submission_with_queue(
content_object=None,
changes={"name": designer_name, "roles": ["DESIGNER"]},
submitter=self.request.user,
submission_type="CREATE",
reason=f"New designer suggested during ride creation: {designer_name}",
)
# Check for new ride model
ride_model_name = form.cleaned_data.get("ride_model_search")
manufacturer = form.cleaned_data.get("manufacturer")
if ride_model_name and not form.cleaned_data.get("ride_model") and manufacturer:
ModerationService.create_edit_submission_with_queue(
content_object=None,
changes={
"name": ride_model_name,
"manufacturer": manufacturer.id,
},
submitter=self.request.user,
submission_type="CREATE",
reason=f"New ride model suggested during ride creation: {ride_model_name}",
)
"""Handle form submission using RideFormMixin for entity suggestions."""
self.handle_entity_suggestions(form)
return super().form_valid(form)
class RideUpdateView(
LoginRequiredMixin, ParkContextRequired, EditSubmissionMixin, UpdateView
LoginRequiredMixin,
ParkContextRequired,
RideFormMixin,
EditSubmissionMixin,
UpdateView,
):
"""View for updating an existing ride"""
"""
View for updating an existing ride.
View Type: CBV (UpdateView)
URL Pattern: /parks/<park_slug>/rides/<ride_slug>/edit/
Template: rides/ride_form.html
Permissions: LoginRequired
"""
model = Ride
form_class = RideForm
@@ -177,7 +269,7 @@ class RideUpdateView(
slug_url_kwarg = "ride_slug"
def get_success_url(self):
"""Get URL to redirect to after successful update"""
"""Get URL to redirect to after successful update."""
return reverse(
"parks:rides:ride_detail",
kwargs={
@@ -187,18 +279,18 @@ class RideUpdateView(
)
def get_queryset(self):
"""Get ride for the specific park"""
"""Get ride for the specific park."""
return Ride.objects.filter(park__slug=self.kwargs["park_slug"])
def get_form_kwargs(self):
"""Pass park to the form"""
"""Pass park to the form."""
kwargs = super().get_form_kwargs()
self.park = get_object_or_404(Park, slug=self.kwargs["park_slug"])
kwargs["park"] = self.park
return kwargs
def get_context_data(self, **kwargs):
"""Add park and park_slug to context"""
"""Add park and park_slug to context."""
context = super().get_context_data(**kwargs)
context["park"] = self.park
context["park_slug"] = self.park.slug
@@ -206,44 +298,8 @@ class RideUpdateView(
return context
def form_valid(self, form):
"""Handle form submission including new items"""
# Check for new manufacturer
manufacturer_name = form.cleaned_data.get("manufacturer_search")
if manufacturer_name and not form.cleaned_data.get("manufacturer"):
ModerationService.create_edit_submission_with_queue(
content_object=None,
changes={"name": manufacturer_name, "roles": ["MANUFACTURER"]},
submitter=self.request.user,
submission_type="CREATE",
reason=f"New manufacturer suggested during ride update: {manufacturer_name}",
)
# Check for new designer
designer_name = form.cleaned_data.get("designer_search")
if designer_name and not form.cleaned_data.get("designer"):
ModerationService.create_edit_submission_with_queue(
content_object=None,
changes={"name": designer_name, "roles": ["DESIGNER"]},
submitter=self.request.user,
submission_type="CREATE",
reason=f"New designer suggested during ride update: {designer_name}",
)
# Check for new ride model
ride_model_name = form.cleaned_data.get("ride_model_search")
manufacturer = form.cleaned_data.get("manufacturer")
if ride_model_name and not form.cleaned_data.get("ride_model") and manufacturer:
ModerationService.create_edit_submission_with_queue(
content_object=None,
changes={
"name": ride_model_name,
"manufacturer": manufacturer.id,
},
submitter=self.request.user,
submission_type="CREATE",
reason=f"New ride model suggested during ride update: {ride_model_name}",
)
"""Handle form submission using RideFormMixin for entity suggestions."""
self.handle_entity_suggestions(form)
return super().form_valid(form)
@@ -256,13 +312,7 @@ class RideListView(ListView):
paginate_by = 24
def get_queryset(self):
"""Get filtered rides using the advanced search service"""
# Initialize search service
search_service = RideSearchService()
# Parse filters from request
filter_form = MasterFilterForm(self.request.GET)
"""Get filtered rides using the advanced search service."""
# Apply park context if available
park = None
if "park_slug" in self.kwargs:
@@ -275,12 +325,12 @@ class RideListView(ListView):
.select_related("park", "ride_model", "ride_model__manufacturer")
.prefetch_related("photos")
)
if park:
queryset = queryset.filter(park=park)
# Apply basic search if provided
search_query = self.request.GET.get('search', '').strip()
search_query = self.request.GET.get("search", "").strip()
if search_query:
queryset = queryset.filter(name__icontains=search_query)
@@ -306,8 +356,11 @@ class RideListView(ListView):
context["filter_form"] = filter_form
# Use Rich Choice registry directly
from apps.core.choices.registry import get_choices
choices = get_choices("categories", "rides")
context["category_choices"] = [(choice.value, choice.label) for choice in choices]
context["category_choices"] = [
(choice.value, choice.label) for choice in choices
]
# Add filter summary for display
if filter_form.is_valid():
@@ -357,8 +410,12 @@ class SingleCategoryListView(ListView):
context["park_slug"] = self.kwargs["park_slug"]
# Find the category choice by value using Rich Choice registry
from apps.core.choices.registry import get_choices
choices = get_choices("categories", "rides")
category_choice = next((choice for choice in choices if choice.value == self.kwargs["category"]), None)
category_choice = next(
(choice for choice in choices if choice.value == self.kwargs["category"]),
None,
)
context["category"] = category_choice.label if category_choice else "Unknown"
return context
@@ -372,7 +429,7 @@ def search_companies(request: HttpRequest) -> HttpResponse:
query = request.GET.get("q", "").strip()
role = request.GET.get("role", "").upper()
companies = Company.objects.all().order_by("name")
companies = Company.objects.only("id", "name", "slug", "roles").order_by("name")
if role:
companies = companies.filter(roles__contains=[role])
if query:
@@ -439,8 +496,8 @@ def get_search_suggestions(request: HttpRequest) -> HttpResponse:
}
)
# Get matching parks
matching_parks = Park.objects.filter(
# Get matching parks with optimized query
matching_parks = Park.objects.select_related("location").filter(
Q(name__icontains=query) | Q(location__city__icontains=query)
)[:3]
@@ -455,6 +512,7 @@ def get_search_suggestions(request: HttpRequest) -> HttpResponse:
# Add category matches
from apps.core.choices.registry import get_choices
choices = get_choices("categories", "rides")
for choice in choices:
if query in choice.label.lower():
@@ -556,8 +614,11 @@ class RideRankingsView(ListView):
context = super().get_context_data(**kwargs)
# Use Rich Choice registry directly
from apps.core.choices.registry import get_choices
choices = get_choices("categories", "rides")
context["category_choices"] = [(choice.value, choice.label) for choice in choices]
context["category_choices"] = [
(choice.value, choice.label) for choice in choices
]
context["selected_category"] = self.request.GET.get("category", "all")
context["min_riders"] = self.request.GET.get("min_riders", "")
@@ -690,10 +751,11 @@ class ManufacturerListView(ListView):
paginate_by = 24
def get_queryset(self):
"""Get companies that are manufacturers"""
"""Get companies that are manufacturers with optimized query"""
return (
Company.objects.filter(roles__contains=["MANUFACTURER"])
.annotate(ride_count=Count("manufactured_rides"))
.only("id", "name", "slug", "roles", "description")
.order_by("name")
)
@@ -713,10 +775,11 @@ class DesignerListView(ListView):
paginate_by = 24
def get_queryset(self):
"""Get companies that are designers"""
"""Get companies that are designers with optimized query"""
return (
Company.objects.filter(roles__contains=["DESIGNER"])
.annotate(ride_count=Count("designed_rides"))
.only("id", "name", "slug", "roles", "description")
.order_by("name")
)