feat: Refactor rides app with unique constraints, mixins, and enhanced documentation

- Added migration to convert unique_together constraints to UniqueConstraint for RideModel.
- Introduced RideFormMixin for handling entity suggestions in ride forms.
- Created comprehensive code standards documentation outlining formatting, docstring requirements, complexity guidelines, and testing requirements.
- Established error handling guidelines with a structured exception hierarchy and best practices for API and view error handling.
- Documented view pattern guidelines, emphasizing the use of CBVs, FBVs, and ViewSets with examples.
- Implemented a benchmarking script for query performance analysis and optimization.
- Developed security documentation detailing measures, configurations, and a security checklist.
- Compiled a database optimization guide covering indexing strategies, query optimization patterns, and computed fields.
This commit is contained in:
pacnpal
2025-12-22 11:17:31 -05:00
parent 45d97b6e68
commit 2e35f8c5d9
71 changed files with 8036 additions and 1462 deletions

83
.github/SECURITY.md vendored Normal file
View File

@@ -0,0 +1,83 @@
# Security Policy
## Supported Versions
| Version | Supported |
| ------- | ------------------ |
| latest | :white_check_mark: |
| < latest | :x: |
Only the latest version of ThrillWiki receives security updates.
## Reporting a Vulnerability
We take security vulnerabilities seriously. If you discover a security issue, please report it responsibly.
### How to Report
1. **Do not** create a public GitHub issue for security vulnerabilities
2. Email your report to the project maintainers
3. Include as much detail as possible:
- Description of the vulnerability
- Steps to reproduce
- Potential impact
- Affected versions
- Any proof of concept (if available)
### What to Expect
- **Acknowledgment**: We will acknowledge receipt within 48 hours
- **Assessment**: We will assess the vulnerability and its impact
- **Updates**: We will keep you informed of our progress
- **Resolution**: We aim to resolve critical vulnerabilities within 7 days
- **Credit**: With your permission, we will credit you in our security advisories
### Scope
The following are in scope for security reports:
- ThrillWiki web application vulnerabilities
- Authentication and authorization issues
- Data exposure vulnerabilities
- Injection vulnerabilities (SQL, XSS, etc.)
- CSRF vulnerabilities
- Server-side request forgery (SSRF)
- Insecure direct object references
### Out of Scope
The following are out of scope:
- Denial of service attacks
- Social engineering attacks
- Physical security issues
- Issues in third-party applications or services
- Issues requiring physical access to a user's device
- Vulnerabilities in outdated versions
## Security Measures
ThrillWiki implements the following security measures:
- HTTPS enforcement with HSTS
- Content Security Policy
- XSS protection with input sanitization
- CSRF protection
- SQL injection prevention via ORM
- Rate limiting on authentication endpoints
- Secure session management
- JWT token rotation and blacklisting
For more details, see [docs/SECURITY.md](../docs/SECURITY.md).
## Security Updates
Security updates are released as soon as possible after a vulnerability is confirmed. We recommend:
1. Keep your installation up to date
2. Subscribe to release notifications
3. Review security advisories
## Contact
For security-related inquiries, please contact the project maintainers.

View File

@@ -1,7 +1,15 @@
"""
Management command to reset the database and create an admin user.
Security Note: This command uses a mix of raw SQL (for PostgreSQL-specific operations
like dropping all tables) and Django ORM (for creating users). The raw SQL operations
use quote_ident() for table/sequence names which is safe from SQL injection.
WARNING: This command is destructive and should only be used in development.
"""
from django.core.management.base import BaseCommand
from django.db import connection
from django.contrib.auth.hashers import make_password
import uuid
class Command(BaseCommand):
@@ -10,7 +18,8 @@ class Command(BaseCommand):
def handle(self, *args, **options):
self.stdout.write("Resetting database...")
# Drop all tables
# Drop all tables using PostgreSQL-specific operations
# Security: Using quote_ident() to safely quote table/sequence names
with connection.cursor() as cursor:
cursor.execute(
"""
@@ -21,7 +30,7 @@ class Command(BaseCommand):
SELECT tablename FROM pg_tables
WHERE schemaname = current_schema()
) LOOP
EXECUTE 'DROP TABLE IF EXISTS ' || \
EXECUTE 'DROP TABLE IF EXISTS ' ||
quote_ident(r.tablename) || ' CASCADE';
END LOOP;
END $$;
@@ -38,7 +47,7 @@ class Command(BaseCommand):
SELECT sequencename FROM pg_sequences
WHERE schemaname = current_schema()
) LOOP
EXECUTE 'ALTER SEQUENCE ' || \
EXECUTE 'ALTER SEQUENCE ' ||
quote_ident(r.sequencename) || ' RESTART WITH 1';
END LOOP;
END $$;
@@ -54,51 +63,25 @@ class Command(BaseCommand):
self.stdout.write("Migrations applied.")
# Create superuser using raw SQL
# Create superuser using Django ORM (safer than raw SQL)
try:
with connection.cursor() as cursor:
# Create user
user_id = str(uuid.uuid4())[:10]
cursor.execute(
"""
INSERT INTO accounts_user (
username, password, email, is_superuser, is_staff,
is_active, date_joined, user_id, first_name,
last_name, role, is_banned, ban_reason,
theme_preference
) VALUES (
'admin', %s, 'admin@thrillwiki.com', true, true,
true, NOW(), %s, '', '', 'SUPERUSER', false, '',
'light'
) RETURNING id;
""",
[make_password("admin"), user_id],
)
from apps.accounts.models import User, UserProfile
result = cursor.fetchone()
if result is None:
raise Exception("Failed to create user - no ID returned")
user_db_id = result[0]
# Security: Using Django ORM instead of raw SQL for user creation
user = User.objects.create_superuser(
username='admin',
email='admin@thrillwiki.com',
password='admin',
role='SUPERUSER',
)
# Create profile
profile_id = str(uuid.uuid4())[:10]
cursor.execute(
"""
INSERT INTO accounts_userprofile (
profile_id, display_name, pronouns, bio,
twitter, instagram, youtube, discord,
coaster_credits, dark_ride_credits,
flat_ride_credits, water_ride_credits,
user_id, avatar
) VALUES (
%s, 'Admin', 'they/them', 'ThrillWiki Administrator',
'', '', '', '',
0, 0, 0, 0,
%s, ''
);
""",
[profile_id, user_db_id],
)
# Create profile using ORM
UserProfile.objects.create(
user=user,
display_name='Admin',
pronouns='they/them',
bio='ThrillWiki Administrator',
)
self.stdout.write("Superuser created.")
except Exception as e:

View File

@@ -0,0 +1,40 @@
"""
Add performance indexes and constraints to User model.
This migration adds:
1. db_index=True to is_banned and role fields for faster filtering
2. Composite index on (is_banned, role) for common query patterns
3. CheckConstraint to ensure banned users have a ban_date set
"""
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('accounts', '0012_alter_toplist_category_and_more'),
]
operations = [
# Add db_index to is_banned field
migrations.AlterField(
model_name='user',
name='is_banned',
field=models.BooleanField(default=False, db_index=True),
),
# Add composite index for common query patterns
migrations.AddIndex(
model_name='user',
index=models.Index(fields=['is_banned', 'role'], name='accounts_user_banned_role_idx'),
),
# Add CheckConstraint for ban consistency
migrations.AddConstraint(
model_name='user',
constraint=models.CheckConstraint(
name='user_ban_consistency',
check=models.Q(is_banned=False) | models.Q(ban_date__isnull=False),
violation_error_message='Banned users must have a ban_date set'
),
),
]

View File

@@ -49,8 +49,9 @@ class User(AbstractUser):
domain="accounts",
max_length=10,
default="USER",
db_index=True,
)
is_banned = models.BooleanField(default=False)
is_banned = models.BooleanField(default=False, db_index=True)
ban_reason = models.TextField(blank=True)
ban_date = models.DateTimeField(null=True, blank=True)
pending_email = models.EmailField(blank=True, null=True)
@@ -127,6 +128,18 @@ class User(AbstractUser):
return profile.display_name
return self.username
class Meta:
indexes = [
models.Index(fields=['is_banned', 'role'], name='accounts_user_banned_role_idx'),
]
constraints = [
models.CheckConstraint(
name='user_ban_consistency',
check=models.Q(is_banned=False) | models.Q(ban_date__isnull=False),
violation_error_message='Banned users must have a ban_date set'
),
]
def save(self, *args, **kwargs):
if not self.user_id:
self.user_id = generate_random_id(User, "user_id")

View File

@@ -2,16 +2,281 @@
User management services for ThrillWiki.
This module contains services for user account management including
user deletion while preserving submissions.
user deletion while preserving submissions, password management,
and email change functionality.
Recent additions:
- AccountService: Handles password and email change operations
- UserDeletionService: Manages user deletion while preserving content
"""
from typing import Optional
from django.db import transaction
from django.utils import timezone
import logging
import re
from typing import Any, Dict, Optional
from django.conf import settings
from django.contrib.auth import update_session_auth_hash
from django.contrib.sites.models import Site
from django.contrib.sites.shortcuts import get_current_site
from django.db import transaction
from django.http import HttpRequest
from django.template.loader import render_to_string
from django.utils import timezone
from django.utils.crypto import get_random_string
from django_forwardemail.services import EmailService
from .models import User, UserProfile, UserDeletionRequest
from .models import EmailVerification, User, UserDeletionRequest, UserProfile
logger = logging.getLogger(__name__)
class AccountService:
"""Service for account management operations including password and email changes."""
@staticmethod
def validate_password(password: str) -> bool:
"""
Validate password meets requirements.
Args:
password: The password to validate
Returns:
True if password meets requirements, False otherwise
"""
return (
len(password) >= 8
and bool(re.search(r"[A-Z]", password))
and bool(re.search(r"[a-z]", password))
and bool(re.search(r"[0-9]", password))
)
@staticmethod
def change_password(
*,
user: User,
old_password: str,
new_password: str,
request: HttpRequest,
) -> Dict[str, Any]:
"""
Change user password with validation and notification.
Validates the old password, checks new password requirements,
updates the password, and sends a confirmation email.
Args:
user: The user whose password is being changed
old_password: Current password for verification
new_password: New password to set
request: HTTP request for session handling
Returns:
Dictionary with success status, message, and optional redirect URL:
{
'success': bool,
'message': str,
'redirect_url': Optional[str]
}
"""
# Verify old password
if not user.check_password(old_password):
logger.warning(
f"Password change failed: incorrect current password for user {user.id}"
)
return {
'success': False,
'message': "Current password is incorrect",
'redirect_url': None
}
# Validate new password
if not AccountService.validate_password(new_password):
return {
'success': False,
'message': "Password must be at least 8 characters and contain uppercase, lowercase, and numbers",
'redirect_url': None
}
# Update password
user.set_password(new_password)
user.save()
# Keep user logged in after password change
update_session_auth_hash(request, user)
# Send confirmation email
AccountService._send_password_change_confirmation(request, user)
logger.info(f"Password changed successfully for user {user.id}")
return {
'success': True,
'message': "Password changed successfully. Please check your email for confirmation.",
'redirect_url': None
}
@staticmethod
def _send_password_change_confirmation(request: HttpRequest, user: User) -> None:
"""Send password change confirmation email."""
site = get_current_site(request)
context = {
"user": user,
"site_name": site.name,
}
email_html = render_to_string(
"accounts/email/password_change_confirmation.html", context
)
try:
EmailService.send_email(
to=user.email,
subject="Password Changed Successfully",
text="Your password has been changed successfully.",
site=site,
html=email_html,
)
except Exception as e:
logger.error(f"Failed to send password change confirmation email: {e}")
@staticmethod
def initiate_email_change(
*,
user: User,
new_email: str,
request: HttpRequest,
) -> Dict[str, Any]:
"""
Initiate email change with verification.
Creates a verification token and sends a verification email
to the new email address.
Args:
user: The user changing their email
new_email: The new email address
request: HTTP request for site context
Returns:
Dictionary with success status and message:
{
'success': bool,
'message': str
}
"""
if not new_email:
return {
'success': False,
'message': "New email is required"
}
# Check if email is already in use
if User.objects.filter(email=new_email).exclude(id=user.id).exists():
return {
'success': False,
'message': "This email address is already in use"
}
# Generate verification token
token = get_random_string(64)
# Create or update email verification record
EmailVerification.objects.update_or_create(
user=user,
defaults={"token": token}
)
# Store pending email
user.pending_email = new_email
user.save()
# Send verification email
AccountService._send_email_verification(request, user, new_email, token)
logger.info(f"Email change initiated for user {user.id} to {new_email}")
return {
'success': True,
'message': "Verification email sent to your new email address"
}
@staticmethod
def _send_email_verification(
request: HttpRequest,
user: User,
new_email: str,
token: str
) -> None:
"""Send email verification for email change."""
from django.urls import reverse
site = get_current_site(request)
verification_url = reverse("verify_email", kwargs={"token": token})
context = {
"user": user,
"verification_url": verification_url,
"site_name": site.name,
}
email_html = render_to_string("accounts/email/verify_email.html", context)
try:
EmailService.send_email(
to=new_email,
subject="Verify your new email address",
text="Click the link to verify your new email address",
site=site,
html=email_html,
)
except Exception as e:
logger.error(f"Failed to send email verification: {e}")
@staticmethod
def verify_email_change(*, token: str) -> Dict[str, Any]:
"""
Verify email change token and update user email.
Args:
token: The verification token
Returns:
Dictionary with success status and message
"""
try:
verification = EmailVerification.objects.select_related("user").get(
token=token
)
except EmailVerification.DoesNotExist:
return {
'success': False,
'message': "Invalid or expired verification token"
}
user = verification.user
if not user.pending_email:
return {
'success': False,
'message': "No pending email change found"
}
# Update email
old_email = user.email
user.email = user.pending_email
user.pending_email = None
user.save()
# Delete verification record
verification.delete()
logger.info(f"Email changed for user {user.id} from {old_email} to {user.email}")
return {
'success': True,
'message': "Email address updated successfully"
}
class UserDeletionService:

View File

@@ -0,0 +1,101 @@
"""
Tests for model constraints and validators in the accounts app.
These tests verify that:
1. CheckConstraints raise appropriate errors
2. Validators work correctly
3. Business rules are enforced at the model level
"""
from django.test import TestCase
from django.db import IntegrityError
from django.core.exceptions import ValidationError
from django.utils import timezone
from apps.accounts.models import User
class UserConstraintTests(TestCase):
"""Tests for User model constraints."""
def test_banned_user_without_ban_date_raises_error(self):
"""Verify banned users must have a ban_date set."""
user = User(
username="testuser",
email="test@example.com",
is_banned=True,
ban_date=None, # This should violate the constraint
)
# The constraint should be enforced at database level
with self.assertRaises(IntegrityError):
user.save()
def test_banned_user_with_ban_date_saves_successfully(self):
"""Verify banned users with ban_date save successfully."""
user = User.objects.create_user(
username="testuser2",
email="test2@example.com",
password="testpass123",
is_banned=True,
ban_date=timezone.now(),
)
self.assertIsNotNone(user.pk)
self.assertTrue(user.is_banned)
self.assertIsNotNone(user.ban_date)
def test_non_banned_user_without_ban_date_saves_successfully(self):
"""Verify non-banned users can be saved without ban_date."""
user = User.objects.create_user(
username="testuser3",
email="test3@example.com",
password="testpass123",
is_banned=False,
ban_date=None,
)
self.assertIsNotNone(user.pk)
self.assertFalse(user.is_banned)
def test_user_id_is_auto_generated(self):
"""Verify user_id is automatically generated on save."""
user = User.objects.create_user(
username="testuser4",
email="test4@example.com",
password="testpass123",
)
self.assertIsNotNone(user.user_id)
self.assertTrue(len(user.user_id) >= 4)
def test_user_id_is_unique(self):
"""Verify user_id is unique across users."""
user1 = User.objects.create_user(
username="testuser5",
email="test5@example.com",
password="testpass123",
)
user2 = User.objects.create_user(
username="testuser6",
email="test6@example.com",
password="testpass123",
)
self.assertNotEqual(user1.user_id, user2.user_id)
class UserIndexTests(TestCase):
"""Tests for User model indexes."""
def test_is_banned_field_is_indexed(self):
"""Verify is_banned field has db_index=True."""
field = User._meta.get_field('is_banned')
self.assertTrue(field.db_index)
def test_role_field_is_indexed(self):
"""Verify role field has db_index=True."""
field = User._meta.get_field('role')
self.assertTrue(field.db_index)
def test_composite_index_exists(self):
"""Verify composite index on (is_banned, role) exists."""
indexes = User._meta.indexes
index_names = [idx.name for idx in indexes]
self.assertIn('accounts_user_banned_role_idx', index_names)

View File

@@ -1302,11 +1302,15 @@ def get_user_statistics(request):
user = request.user
# Calculate user statistics
# TODO(THRILLWIKI-104): Implement full user statistics tracking
from apps.parks.models import ParkReview
from apps.rides.models import RideReview
data = {
"parks_visited": 0, # TODO: Implement based on reviews/check-ins
"rides_ridden": 0, # TODO: Implement based on reviews/check-ins
"reviews_written": 0, # TODO: Count user's reviews
"photos_uploaded": 0, # TODO: Count user's photos
"parks_visited": ParkReview.objects.filter(user=user).values("park").distinct().count(),
"rides_ridden": RideReview.objects.filter(user=user).values("ride").distinct().count(),
"reviews_written": ParkReview.objects.filter(user=user).count() + RideReview.objects.filter(user=user).count(),
"photos_uploaded": 0, # TODO(THRILLWIKI-105): Implement photo counting
"top_lists_created": TopList.objects.filter(user=user).count(),
"member_since": user.date_joined,
"last_activity": user.last_login,

View File

@@ -7,8 +7,6 @@ from rest_framework.views import APIView
from rest_framework.response import Response
from rest_framework import status
from rest_framework.permissions import AllowAny
from django.views.decorators.csrf import csrf_exempt
from django.utils.decorators import method_decorator
from typing import Optional, List
from drf_spectacular.utils import extend_schema
@@ -260,12 +258,14 @@ class EntityNotFoundView(APIView):
)
@method_decorator(csrf_exempt, name="dispatch")
class QuickEntitySuggestionView(APIView):
"""
Lightweight endpoint for quick entity suggestions (e.g., autocomplete).
Migrated from apps.core.views.entity_search.QuickEntitySuggestionView
Security Note: This endpoint only accepts GET requests, which are inherently
safe from CSRF attacks. No CSRF exemption is needed.
"""
permission_classes = [AllowAny]

View File

@@ -12,7 +12,7 @@ from django.contrib.gis.geos import Polygon
from rest_framework.views import APIView
from rest_framework.response import Response
from rest_framework import status
from rest_framework.permissions import AllowAny
from rest_framework.permissions import AllowAny, IsAdminUser
from drf_spectacular.utils import (
extend_schema,
extend_schema_view,
@@ -306,7 +306,7 @@ class MapLocationsAPIView(APIView):
return {
"status": "success",
"locations": locations,
"clusters": [], # TODO: Implement clustering
"clusters": [], # TODO(THRILLWIKI-106): Implement map clustering algorithm
"bounds": self._calculate_bounds(locations),
"total_count": len(locations),
"clustered": params["cluster"],
@@ -471,7 +471,7 @@ class MapLocationDetailAPIView(APIView):
obj.opening_date.isoformat() if obj.opening_date else None
),
},
"nearby_locations": [], # TODO: Implement nearby locations
"nearby_locations": [], # TODO(THRILLWIKI-107): Implement nearby locations for parks
}
else: # ride
data = {
@@ -538,7 +538,7 @@ class MapLocationDetailAPIView(APIView):
obj.manufacturer.name if obj.manufacturer else None
),
},
"nearby_locations": [], # TODO: Implement nearby locations
"nearby_locations": [], # TODO(THRILLWIKI-107): Implement nearby locations for rides
}
return Response(
@@ -669,7 +669,7 @@ class MapSearchAPIView(APIView):
else ""
),
},
"relevance_score": 1.0, # TODO: Implement relevance scoring
"relevance_score": 1.0, # TODO(THRILLWIKI-108): Implement relevance scoring for search
}
)
@@ -722,7 +722,7 @@ class MapSearchAPIView(APIView):
else ""
),
},
"relevance_score": 1.0, # TODO: Implement relevance scoring
"relevance_score": 1.0, # TODO(THRILLWIKI-108): Implement relevance scoring for search
}
)
@@ -965,8 +965,8 @@ class MapStatsAPIView(APIView):
"total_locations": total_locations,
"parks_with_location": parks_with_location,
"rides_with_location": rides_with_location,
"cache_hits": 0, # TODO: Implement cache statistics
"cache_misses": 0, # TODO: Implement cache statistics
"cache_hits": 0, # TODO(THRILLWIKI-109): Implement cache statistics tracking
"cache_misses": 0, # TODO(THRILLWIKI-109): Implement cache statistics tracking
},
}
)
@@ -996,7 +996,7 @@ class MapStatsAPIView(APIView):
class MapCacheAPIView(APIView):
"""API endpoint for cache management (admin only)."""
permission_classes = [AllowAny] # TODO: Add admin permission check
permission_classes = [IsAdminUser] # Admin only
def delete(self, request: HttpRequest) -> Response:
"""Clear all map cache (admin only)."""

View File

@@ -5,34 +5,44 @@ This module contains consolidated park photo viewset for the centralized API str
Enhanced from rogue implementation to maintain full feature parity.
"""
from .serializers import (
ParkPhotoOutputSerializer,
ParkPhotoCreateInputSerializer,
ParkPhotoUpdateInputSerializer,
ParkPhotoListOutputSerializer,
ParkPhotoApprovalInputSerializer,
ParkPhotoStatsOutputSerializer,
)
from typing import Any, cast
import logging
from typing import Any, cast
from django.contrib.auth import get_user_model
from django.core.exceptions import PermissionDenied
from drf_spectacular.utils import extend_schema_view, extend_schema, OpenApiParameter
from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import OpenApiParameter, extend_schema, extend_schema_view
from rest_framework import status
from rest_framework.decorators import action
from rest_framework.exceptions import ValidationError
from rest_framework.permissions import IsAuthenticated, AllowAny
from rest_framework.permissions import AllowAny, IsAuthenticated
from rest_framework.response import Response
from rest_framework.views import APIView
from rest_framework.viewsets import ModelViewSet
from apps.parks.models import ParkPhoto, Park
from apps.core.exceptions import (
NotFoundError,
PermissionDeniedError,
ServiceError,
ValidationException,
)
from apps.core.utils.error_handling import ErrorHandler
from apps.parks.models import Park, ParkPhoto
from apps.parks.services import ParkMediaService
from django.contrib.auth import get_user_model
from apps.parks.services.hybrid_loader import smart_park_loader
UserModel = get_user_model()
from .serializers import (
HybridParkSerializer,
ParkPhotoApprovalInputSerializer,
ParkPhotoCreateInputSerializer,
ParkPhotoListOutputSerializer,
ParkPhotoOutputSerializer,
ParkPhotoStatsOutputSerializer,
ParkPhotoUpdateInputSerializer,
)
logger = logging.getLogger(__name__)
UserModel = get_user_model()
@extend_schema_view(
@@ -113,7 +123,7 @@ class ParkPhotoViewSet(ModelViewSet):
def get_permissions(self):
"""Set permissions based on action."""
if self.action in ['list', 'retrieve', 'stats']:
if self.action in ["list", "retrieve", "stats"]:
permission_classes = [AllowAny]
else:
permission_classes = [IsAuthenticated]
@@ -166,8 +176,11 @@ class ParkPhotoViewSet(ModelViewSet):
# Set the instance for the serializer response
serializer.instance = photo
except Exception as e:
logger.error(f"Error creating park photo: {e}")
except (ValidationException, ValidationError) as e:
logger.warning(f"Validation error creating park photo: {e}")
raise ValidationError(str(e))
except ServiceError as e:
logger.error(f"Service error creating park photo: {e}")
raise ValidationError(f"Failed to create photo: {str(e)}")
def perform_update(self, serializer):
@@ -190,8 +203,11 @@ class ParkPhotoViewSet(ModelViewSet):
# Remove is_primary from validated_data since service handles it
if "is_primary" in serializer.validated_data:
del serializer.validated_data["is_primary"]
except Exception as e:
logger.error(f"Error setting primary photo: {e}")
except (ValidationException, ValidationError) as e:
logger.warning(f"Validation error setting primary photo: {e}")
raise ValidationError(str(e))
except ServiceError as e:
logger.error(f"Service error setting primary photo: {e}")
raise ValidationError(f"Failed to set primary photo: {str(e)}")
def perform_destroy(self, instance):
@@ -205,25 +221,30 @@ class ParkPhotoViewSet(ModelViewSet):
"You can only delete your own photos or be an admin."
)
try:
# Delete from Cloudflare first if image exists
if instance.image:
try:
from django_cloudflareimages_toolkit.services import CloudflareImagesService
service = CloudflareImagesService()
service.delete_image(instance.image)
logger.info(
f"Successfully deleted park photo from Cloudflare: {instance.image.cloudflare_id}")
except Exception as e:
logger.error(
f"Failed to delete park photo from Cloudflare: {str(e)}")
# Continue with database deletion even if Cloudflare deletion fails
# Delete from Cloudflare first if image exists
if instance.image:
try:
from django_cloudflareimages_toolkit.services import (
CloudflareImagesService,
)
service = CloudflareImagesService()
service.delete_image(instance.image)
logger.info(
f"Successfully deleted park photo from Cloudflare: {instance.image.cloudflare_id}"
)
except ImportError:
logger.warning("CloudflareImagesService not available")
except ServiceError as e:
logger.error(f"Service error deleting from Cloudflare: {str(e)}")
# Continue with database deletion even if Cloudflare deletion fails
try:
ParkMediaService().delete_photo(
instance.id, deleted_by=cast(UserModel, self.request.user)
)
except Exception as e:
logger.error(f"Error deleting park photo: {e}")
except ServiceError as e:
logger.error(f"Service error deleting park photo: {e}")
raise ValidationError(f"Failed to delete photo: {str(e)}")
@extend_schema(
@@ -265,11 +286,18 @@ class ParkPhotoViewSet(ModelViewSet):
status=status.HTTP_200_OK,
)
except Exception as e:
logger.error(f"Error setting primary photo: {e}")
return Response(
{"error": f"Failed to set primary photo: {str(e)}"},
status=status.HTTP_400_BAD_REQUEST,
except (ValidationException, ValidationError) as e:
logger.warning(f"Validation error setting primary photo: {e}")
return ErrorHandler.handle_api_error(
e,
user_message="Failed to set primary photo",
status_code=status.HTTP_400_BAD_REQUEST,
)
except ServiceError as e:
return ErrorHandler.handle_api_error(
e,
user_message="Failed to set primary photo",
status_code=status.HTTP_400_BAD_REQUEST,
)
@extend_schema(
@@ -319,11 +347,18 @@ class ParkPhotoViewSet(ModelViewSet):
status=status.HTTP_200_OK,
)
except Exception as e:
logger.error(f"Error in bulk photo approval: {e}")
return Response(
{"error": f"Failed to update photos: {str(e)}"},
status=status.HTTP_400_BAD_REQUEST,
except (ValidationException, ValidationError) as e:
logger.warning(f"Validation error in bulk photo approval: {e}")
return ErrorHandler.handle_api_error(
e,
user_message="Failed to update photos",
status_code=status.HTTP_400_BAD_REQUEST,
)
except ServiceError as e:
return ErrorHandler.handle_api_error(
e,
user_message="Failed to update photos",
status_code=status.HTTP_400_BAD_REQUEST,
)
@extend_schema(
@@ -345,9 +380,10 @@ class ParkPhotoViewSet(ModelViewSet):
try:
park = Park.objects.get(pk=park_pk)
except Park.DoesNotExist:
return Response(
{"error": "Park not found."},
status=status.HTTP_404_NOT_FOUND,
return ErrorHandler.handle_api_error(
NotFoundError(f"Park with id {park_pk} not found"),
user_message="Park not found",
status_code=status.HTTP_404_NOT_FOUND,
)
try:
@@ -359,11 +395,11 @@ class ParkPhotoViewSet(ModelViewSet):
return Response(serializer.data, status=status.HTTP_200_OK)
except Exception as e:
logger.error(f"Error getting park photo stats: {e}")
return Response(
{"error": f"Failed to get photo statistics: {str(e)}"},
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
except ServiceError as e:
return ErrorHandler.handle_api_error(
e,
user_message="Failed to get photo statistics",
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
# Legacy compatibility action using the legacy set_primary logic
@@ -394,9 +430,19 @@ class ParkPhotoViewSet(ModelViewSet):
park_id=photo.park_id, photo_id=photo.id
)
return Response({"message": "Photo set as primary successfully."})
except Exception as e:
logger.error(f"Error in set_primary_photo: {str(e)}", exc_info=True)
return Response({"error": str(e)}, status=status.HTTP_400_BAD_REQUEST)
except (ValidationException, ValidationError) as e:
logger.warning(f"Validation error in set_primary_photo: {str(e)}")
return ErrorHandler.handle_api_error(
e,
user_message="Failed to set primary photo",
status_code=status.HTTP_400_BAD_REQUEST,
)
except ServiceError as e:
return ErrorHandler.handle_api_error(
e,
user_message="Failed to set primary photo",
status_code=status.HTTP_400_BAD_REQUEST,
)
@extend_schema(
summary="Save Cloudflare image as park photo",
@@ -442,60 +488,55 @@ class ParkPhotoViewSet(ModelViewSet):
from django.utils import timezone
# Always fetch the latest image data from Cloudflare API
# Get image details from Cloudflare API
service = CloudflareImagesService()
image_data = service.get_image(cloudflare_image_id)
if not image_data:
return ErrorHandler.handle_api_error(
NotFoundError("Image not found in Cloudflare"),
user_message="Image not found in Cloudflare",
status_code=status.HTTP_400_BAD_REQUEST,
)
# Try to find existing CloudflareImage record by cloudflare_id
cloudflare_image = None
try:
# Get image details from Cloudflare API
service = CloudflareImagesService()
image_data = service.get_image(cloudflare_image_id)
cloudflare_image = CloudflareImage.objects.get(
cloudflare_id=cloudflare_image_id
)
if not image_data:
return Response(
{"error": "Image not found in Cloudflare"},
status=status.HTTP_400_BAD_REQUEST,
)
# Update existing record with latest data from Cloudflare
cloudflare_image.status = "uploaded"
cloudflare_image.uploaded_at = timezone.now()
cloudflare_image.metadata = image_data.get("meta", {})
# Extract variants from nested result structure
cloudflare_image.variants = image_data.get("result", {}).get(
"variants", []
)
cloudflare_image.cloudflare_metadata = image_data
cloudflare_image.width = image_data.get("width")
cloudflare_image.height = image_data.get("height")
cloudflare_image.format = image_data.get("format", "")
cloudflare_image.save()
# Try to find existing CloudflareImage record by cloudflare_id
cloudflare_image = None
try:
cloudflare_image = CloudflareImage.objects.get(
cloudflare_id=cloudflare_image_id)
# Update existing record with latest data from Cloudflare
cloudflare_image.status = 'uploaded'
cloudflare_image.uploaded_at = timezone.now()
cloudflare_image.metadata = image_data.get('meta', {})
except CloudflareImage.DoesNotExist:
# Create new CloudflareImage record from API response
cloudflare_image = CloudflareImage.objects.create(
cloudflare_id=cloudflare_image_id,
user=request.user,
status="uploaded",
upload_url="", # Not needed for uploaded images
expires_at=timezone.now()
+ timezone.timedelta(days=365), # Set far future expiry
uploaded_at=timezone.now(),
metadata=image_data.get("meta", {}),
# Extract variants from nested result structure
cloudflare_image.variants = image_data.get(
'result', {}).get('variants', [])
cloudflare_image.cloudflare_metadata = image_data
cloudflare_image.width = image_data.get('width')
cloudflare_image.height = image_data.get('height')
cloudflare_image.format = image_data.get('format', '')
cloudflare_image.save()
except CloudflareImage.DoesNotExist:
# Create new CloudflareImage record from API response
cloudflare_image = CloudflareImage.objects.create(
cloudflare_id=cloudflare_image_id,
user=request.user,
status='uploaded',
upload_url='', # Not needed for uploaded images
expires_at=timezone.now() + timezone.timedelta(days=365), # Set far future expiry
uploaded_at=timezone.now(),
metadata=image_data.get('meta', {}),
# Extract variants from nested result structure
variants=image_data.get('result', {}).get('variants', []),
cloudflare_metadata=image_data,
width=image_data.get('width'),
height=image_data.get('height'),
format=image_data.get('format', ''),
)
except Exception as api_error:
logger.error(
f"Error fetching image from Cloudflare API: {str(api_error)}", exc_info=True)
return Response(
{"error": f"Failed to fetch image from Cloudflare: {str(api_error)}"},
status=status.HTTP_400_BAD_REQUEST,
variants=image_data.get("result", {}).get("variants", []),
cloudflare_metadata=image_data,
width=image_data.get("width"),
height=image_data.get("height"),
format=image_data.get("format", ""),
)
# Create the park photo with the CloudflareImage reference
@@ -516,25 +557,33 @@ class ParkPhotoViewSet(ModelViewSet):
ParkMediaService().set_primary_photo(
park_id=park.id, photo_id=photo.id
)
except Exception as e:
except ServiceError as e:
logger.error(f"Error setting primary photo: {e}")
# Don't fail the entire operation, just log the error
serializer = ParkPhotoOutputSerializer(photo, context={"request": request})
return Response(serializer.data, status=status.HTTP_201_CREATED)
except Exception as e:
logger.error(f"Error saving park photo: {e}")
return Response(
{"error": f"Failed to save photo: {str(e)}"},
status=status.HTTP_400_BAD_REQUEST,
except ImportError:
logger.error("CloudflareImagesService not available")
return ErrorHandler.handle_api_error(
ServiceError("Cloudflare Images service not available"),
user_message="Image upload service not available",
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
except (ValidationException, ValidationError) as e:
logger.warning(f"Validation error saving park photo: {e}")
return ErrorHandler.handle_api_error(
e,
user_message="Failed to save photo",
status_code=status.HTTP_400_BAD_REQUEST,
)
except ServiceError as e:
return ErrorHandler.handle_api_error(
e,
user_message="Failed to save photo",
status_code=status.HTTP_400_BAD_REQUEST,
)
from rest_framework.views import APIView
from rest_framework.permissions import AllowAny
from .serializers import HybridParkSerializer
from apps.parks.services.hybrid_loader import smart_park_loader
@extend_schema_view(
@@ -542,23 +591,79 @@ from apps.parks.services.hybrid_loader import smart_park_loader
summary="Get parks with hybrid filtering",
description="Retrieve parks with intelligent hybrid filtering strategy. Automatically chooses between client-side and server-side filtering based on data size.",
parameters=[
OpenApiParameter("status", OpenApiTypes.STR, description="Filter by park status (comma-separated for multiple)"),
OpenApiParameter("park_type", OpenApiTypes.STR, description="Filter by park type (comma-separated for multiple)"),
OpenApiParameter("country", OpenApiTypes.STR, description="Filter by country (comma-separated for multiple)"),
OpenApiParameter("state", OpenApiTypes.STR, description="Filter by state (comma-separated for multiple)"),
OpenApiParameter("opening_year_min", OpenApiTypes.INT, description="Minimum opening year"),
OpenApiParameter("opening_year_max", OpenApiTypes.INT, description="Maximum opening year"),
OpenApiParameter("size_min", OpenApiTypes.NUMBER, description="Minimum park size in acres"),
OpenApiParameter("size_max", OpenApiTypes.NUMBER, description="Maximum park size in acres"),
OpenApiParameter("rating_min", OpenApiTypes.NUMBER, description="Minimum average rating"),
OpenApiParameter("rating_max", OpenApiTypes.NUMBER, description="Maximum average rating"),
OpenApiParameter("ride_count_min", OpenApiTypes.INT, description="Minimum ride count"),
OpenApiParameter("ride_count_max", OpenApiTypes.INT, description="Maximum ride count"),
OpenApiParameter("coaster_count_min", OpenApiTypes.INT, description="Minimum coaster count"),
OpenApiParameter("coaster_count_max", OpenApiTypes.INT, description="Maximum coaster count"),
OpenApiParameter("operator", OpenApiTypes.STR, description="Filter by operator slug (comma-separated for multiple)"),
OpenApiParameter("search", OpenApiTypes.STR, description="Search query for park names, descriptions, locations, and operators"),
OpenApiParameter("offset", OpenApiTypes.INT, description="Offset for progressive loading (server-side pagination)"),
OpenApiParameter(
"status",
OpenApiTypes.STR,
description="Filter by park status (comma-separated for multiple)",
),
OpenApiParameter(
"park_type",
OpenApiTypes.STR,
description="Filter by park type (comma-separated for multiple)",
),
OpenApiParameter(
"country",
OpenApiTypes.STR,
description="Filter by country (comma-separated for multiple)",
),
OpenApiParameter(
"state",
OpenApiTypes.STR,
description="Filter by state (comma-separated for multiple)",
),
OpenApiParameter(
"opening_year_min", OpenApiTypes.INT, description="Minimum opening year"
),
OpenApiParameter(
"opening_year_max", OpenApiTypes.INT, description="Maximum opening year"
),
OpenApiParameter(
"size_min",
OpenApiTypes.NUMBER,
description="Minimum park size in acres",
),
OpenApiParameter(
"size_max",
OpenApiTypes.NUMBER,
description="Maximum park size in acres",
),
OpenApiParameter(
"rating_min", OpenApiTypes.NUMBER, description="Minimum average rating"
),
OpenApiParameter(
"rating_max", OpenApiTypes.NUMBER, description="Maximum average rating"
),
OpenApiParameter(
"ride_count_min", OpenApiTypes.INT, description="Minimum ride count"
),
OpenApiParameter(
"ride_count_max", OpenApiTypes.INT, description="Maximum ride count"
),
OpenApiParameter(
"coaster_count_min",
OpenApiTypes.INT,
description="Minimum coaster count",
),
OpenApiParameter(
"coaster_count_max",
OpenApiTypes.INT,
description="Maximum coaster count",
),
OpenApiParameter(
"operator",
OpenApiTypes.STR,
description="Filter by operator slug (comma-separated for multiple)",
),
OpenApiParameter(
"search",
OpenApiTypes.STR,
description="Search query for park names, descriptions, locations, and operators",
),
OpenApiParameter(
"offset",
OpenApiTypes.INT,
description="Offset for progressive loading (server-side pagination)",
),
],
responses={
200: {
@@ -570,31 +675,33 @@ from apps.parks.services.hybrid_loader import smart_park_loader
"properties": {
"parks": {
"type": "array",
"items": {"$ref": "#/components/schemas/HybridParkSerializer"}
"items": {
"$ref": "#/components/schemas/HybridParkSerializer"
},
},
"total_count": {"type": "integer"},
"strategy": {
"type": "string",
"enum": ["client_side", "server_side"],
"description": "Filtering strategy used"
"description": "Filtering strategy used",
},
"has_more": {
"type": "boolean",
"description": "Whether more data is available for progressive loading"
"description": "Whether more data is available for progressive loading",
},
"next_offset": {
"type": "integer",
"nullable": True,
"description": "Next offset for progressive loading"
"description": "Next offset for progressive loading",
},
"filter_metadata": {
"type": "object",
"description": "Available filter options and ranges"
}
}
"description": "Available filter options and ranges",
},
},
}
}
}
},
}
},
tags=["Parks"],
@@ -603,77 +710,83 @@ from apps.parks.services.hybrid_loader import smart_park_loader
class HybridParkAPIView(APIView):
"""
Hybrid Park API View with intelligent filtering strategy.
Automatically chooses between client-side and server-side filtering
based on data size and complexity. Provides progressive loading
for large datasets and complete data for smaller sets.
"""
permission_classes = [AllowAny]
def get(self, request):
"""Get parks with hybrid filtering strategy."""
# Extract filters from query parameters
filters = self._extract_filters(request.query_params)
# Check if this is a progressive load request
offset = request.query_params.get("offset")
if offset is not None:
try:
offset = int(offset)
except ValueError:
return ErrorHandler.handle_api_error(
ValidationException("Invalid offset parameter"),
user_message="Invalid offset parameter",
status_code=status.HTTP_400_BAD_REQUEST,
)
try:
# Extract filters from query parameters
filters = self._extract_filters(request.query_params)
# Check if this is a progressive load request
offset = request.query_params.get('offset')
if offset is not None:
try:
offset = int(offset)
# Get progressive load data
data = smart_park_loader.get_progressive_load(offset, filters)
except ValueError:
return Response(
{"error": "Invalid offset parameter"},
status=status.HTTP_400_BAD_REQUEST
)
# Get progressive load data
data = smart_park_loader.get_progressive_load(offset, filters)
else:
# Get initial load data
data = smart_park_loader.get_initial_load(filters)
# Serialize the parks data
serializer = HybridParkSerializer(data['parks'], many=True)
serializer = HybridParkSerializer(data["parks"], many=True)
# Prepare response
response_data = {
'parks': serializer.data,
'total_count': data['total_count'],
'strategy': data.get('strategy', 'server_side'),
'has_more': data.get('has_more', False),
'next_offset': data.get('next_offset'),
"parks": serializer.data,
"total_count": data["total_count"],
"strategy": data.get("strategy", "server_side"),
"has_more": data.get("has_more", False),
"next_offset": data.get("next_offset"),
}
# Include filter metadata for initial loads
if 'filter_metadata' in data:
response_data['filter_metadata'] = data['filter_metadata']
if "filter_metadata" in data:
response_data["filter_metadata"] = data["filter_metadata"]
return Response(response_data, status=status.HTTP_200_OK)
except Exception as e:
logger.error(f"Error in HybridParkAPIView: {e}")
return Response(
{"error": "Internal server error"},
status=status.HTTP_500_INTERNAL_SERVER_ERROR
except ServiceError as e:
return ErrorHandler.handle_api_error(
e,
user_message="Failed to load parks",
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
def _extract_filters(self, query_params):
"""Extract and parse filters from query parameters."""
filters = {}
# Handle comma-separated list parameters
list_params = ['status', 'park_type', 'country', 'state', 'operator']
list_params = ["status", "park_type", "country", "state", "operator"]
for param in list_params:
value = query_params.get(param)
if value:
filters[param] = [v.strip() for v in value.split(',') if v.strip()]
filters[param] = [v.strip() for v in value.split(",") if v.strip()]
# Handle integer parameters
int_params = [
'opening_year_min', 'opening_year_max',
'ride_count_min', 'ride_count_max',
'coaster_count_min', 'coaster_count_max'
"opening_year_min",
"opening_year_max",
"ride_count_min",
"ride_count_max",
"coaster_count_min",
"coaster_count_max",
]
for param in int_params:
value = query_params.get(param)
@@ -682,9 +795,9 @@ class HybridParkAPIView(APIView):
filters[param] = int(value)
except ValueError:
pass # Skip invalid integer values
# Handle float parameters
float_params = ['size_min', 'size_max', 'rating_min', 'rating_max']
float_params = ["size_min", "size_max", "rating_min", "rating_max"]
for param in float_params:
value = query_params.get(param)
if value:
@@ -692,12 +805,12 @@ class HybridParkAPIView(APIView):
filters[param] = float(value)
except ValueError:
pass # Skip invalid float values
# Handle search parameter
search = query_params.get('search')
search = query_params.get("search")
if search:
filters['search'] = search.strip()
filters["search"] = search.strip()
return filters
@@ -706,7 +819,11 @@ class HybridParkAPIView(APIView):
summary="Get park filter metadata",
description="Get available filter options and ranges for parks filtering.",
parameters=[
OpenApiParameter("scoped", OpenApiTypes.BOOL, description="Whether to scope metadata to current filters"),
OpenApiParameter(
"scoped",
OpenApiTypes.BOOL,
description="Whether to scope metadata to current filters",
),
],
responses={
200: {
@@ -719,21 +836,33 @@ class HybridParkAPIView(APIView):
"categorical": {
"type": "object",
"properties": {
"countries": {"type": "array", "items": {"type": "string"}},
"states": {"type": "array", "items": {"type": "string"}},
"park_types": {"type": "array", "items": {"type": "string"}},
"statuses": {"type": "array", "items": {"type": "string"}},
"countries": {
"type": "array",
"items": {"type": "string"},
},
"states": {
"type": "array",
"items": {"type": "string"},
},
"park_types": {
"type": "array",
"items": {"type": "string"},
},
"statuses": {
"type": "array",
"items": {"type": "string"},
},
"operators": {
"type": "array",
"items": {
"type": "object",
"properties": {
"name": {"type": "string"},
"slug": {"type": "string"}
}
}
}
}
"slug": {"type": "string"},
},
},
},
},
},
"ranges": {
"type": "object",
@@ -741,45 +870,75 @@ class HybridParkAPIView(APIView):
"opening_year": {
"type": "object",
"properties": {
"min": {"type": "integer", "nullable": True},
"max": {"type": "integer", "nullable": True}
}
"min": {
"type": "integer",
"nullable": True,
},
"max": {
"type": "integer",
"nullable": True,
},
},
},
"size_acres": {
"type": "object",
"properties": {
"min": {"type": "number", "nullable": True},
"max": {"type": "number", "nullable": True}
}
"min": {
"type": "number",
"nullable": True,
},
"max": {
"type": "number",
"nullable": True,
},
},
},
"average_rating": {
"type": "object",
"properties": {
"min": {"type": "number", "nullable": True},
"max": {"type": "number", "nullable": True}
}
"min": {
"type": "number",
"nullable": True,
},
"max": {
"type": "number",
"nullable": True,
},
},
},
"ride_count": {
"type": "object",
"properties": {
"min": {"type": "integer", "nullable": True},
"max": {"type": "integer", "nullable": True}
}
"min": {
"type": "integer",
"nullable": True,
},
"max": {
"type": "integer",
"nullable": True,
},
},
},
"coaster_count": {
"type": "object",
"properties": {
"min": {"type": "integer", "nullable": True},
"max": {"type": "integer", "nullable": True}
}
}
}
"min": {
"type": "integer",
"nullable": True,
},
"max": {
"type": "integer",
"nullable": True,
},
},
},
},
},
"total_count": {"type": "integer"}
}
"total_count": {"type": "integer"},
},
}
}
}
},
}
},
tags=["Parks"],
@@ -788,35 +947,34 @@ class HybridParkAPIView(APIView):
class ParkFilterMetadataAPIView(APIView):
"""
API view for getting park filter metadata.
Provides information about available filter options and ranges
to help build dynamic filter interfaces.
"""
permission_classes = [AllowAny]
def get(self, request):
"""Get park filter metadata."""
# Check if metadata should be scoped to current filters
scoped = request.query_params.get("scoped", "").lower() == "true"
filters = None
if scoped:
filters = self._extract_filters(request.query_params)
try:
# Check if metadata should be scoped to current filters
scoped = request.query_params.get('scoped', '').lower() == 'true'
filters = None
if scoped:
filters = self._extract_filters(request.query_params)
# Get filter metadata
metadata = smart_park_loader.get_filter_metadata(filters)
return Response(metadata, status=status.HTTP_200_OK)
except Exception as e:
logger.error(f"Error in ParkFilterMetadataAPIView: {e}")
return Response(
{"error": "Internal server error"},
status=status.HTTP_500_INTERNAL_SERVER_ERROR
except ServiceError as e:
return ErrorHandler.handle_api_error(
e,
user_message="Failed to get filter metadata",
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
def _extract_filters(self, query_params):
"""Extract and parse filters from query parameters."""
# Reuse the same filter extraction logic

File diff suppressed because it is too large Load Diff

View File

@@ -365,7 +365,7 @@ class MapLocationDetailSerializer(serializers.Serializer):
@extend_schema_field(serializers.ListField(child=serializers.DictField()))
def get_nearby_locations(self, obj) -> list:
"""Get nearby locations (placeholder for now)."""
# TODO: Implement nearby location logic
# TODO(THRILLWIKI-107): Implement nearby location logic using spatial queries
return []

View File

@@ -4,3 +4,12 @@ from django.apps import AppConfig
class CoreConfig(AppConfig):
default_auto_field = "django.db.models.BigAutoField"
name = "apps.core"
def ready(self):
"""
Application initialization.
Imports security checks to register them with Django's check framework.
"""
# Import security checks to register them
from . import checks # noqa: F401

372
backend/apps/core/checks.py Normal file
View File

@@ -0,0 +1,372 @@
"""
Django System Checks for Security Configuration.
This module implements Django system checks that validate security settings
at startup. These checks help catch security misconfigurations before
deployment.
Usage:
These checks run automatically when Django starts. They can also be run
manually with: python manage.py check --tag=security
Security checks included:
- SECRET_KEY validation (not default, sufficient entropy)
- DEBUG mode check (should be False in production)
- ALLOWED_HOSTS check (should be configured in production)
- Security headers validation
- HTTPS settings validation
- Cookie security settings
"""
import os
import re
from django.conf import settings
from django.core.checks import Error, Warning, register, Tags
# =============================================================================
# Secret Key Validation
# =============================================================================
@register(Tags.security)
def check_secret_key(app_configs, **kwargs):
"""
Check that SECRET_KEY is properly configured.
Validates:
- Key is not a known default/placeholder value
- Key has sufficient entropy (length and character variety)
"""
errors = []
secret_key = getattr(settings, 'SECRET_KEY', '')
# Check for empty or missing key
if not secret_key:
errors.append(
Error(
'SECRET_KEY is not set.',
hint='Set a strong, random SECRET_KEY in your environment.',
id='security.E001',
)
)
return errors
# Check for known insecure default values
insecure_defaults = [
'django-insecure',
'your-secret-key',
'change-me',
'changeme',
'secret',
'xxx',
'test',
'development',
'dev-key',
]
key_lower = secret_key.lower()
for default in insecure_defaults:
if default in key_lower:
errors.append(
Error(
f'SECRET_KEY appears to contain an insecure default value: "{default}"',
hint='Generate a new secret key using: python -c "from django.core.management.utils import get_random_secret_key; print(get_random_secret_key())"',
id='security.E002',
)
)
break
# Check minimum length (Django recommends at least 50 characters)
if len(secret_key) < 50:
errors.append(
Warning(
f'SECRET_KEY is only {len(secret_key)} characters long.',
hint='A secret key should be at least 50 characters for proper security.',
id='security.W001',
)
)
# Check for sufficient character variety
has_upper = bool(re.search(r'[A-Z]', secret_key))
has_lower = bool(re.search(r'[a-z]', secret_key))
has_digit = bool(re.search(r'[0-9]', secret_key))
has_special = bool(re.search(r'[!@#$%^&*()_+\-=\[\]{};\':"\\|,.<>\/?]', secret_key))
char_types = sum([has_upper, has_lower, has_digit, has_special])
if char_types < 3:
errors.append(
Warning(
'SECRET_KEY lacks character variety.',
hint='A good secret key should contain uppercase, lowercase, digits, and special characters.',
id='security.W002',
)
)
return errors
# =============================================================================
# Debug Mode Check
# =============================================================================
@register(Tags.security)
def check_debug_mode(app_configs, **kwargs):
"""
Check that DEBUG is False in production-like environments.
"""
errors = []
# Check if we're in a production-like environment
env = os.environ.get('DJANGO_SETTINGS_MODULE', '')
is_production = 'production' in env.lower() or 'prod' in env.lower()
if is_production and settings.DEBUG:
errors.append(
Error(
'DEBUG is True in what appears to be a production environment.',
hint='Set DEBUG=False in production settings.',
id='security.E003',
)
)
# Also check if DEBUG is True with ALLOWED_HOSTS configured
# (indicates possible production deployment with debug on)
if settings.DEBUG and settings.ALLOWED_HOSTS and '*' not in settings.ALLOWED_HOSTS:
if len(settings.ALLOWED_HOSTS) > 0 and 'localhost' not in settings.ALLOWED_HOSTS[0]:
errors.append(
Warning(
'DEBUG is True but ALLOWED_HOSTS contains non-localhost values.',
hint='This may indicate DEBUG is accidentally enabled in a deployed environment.',
id='security.W003',
)
)
return errors
# =============================================================================
# ALLOWED_HOSTS Check
# =============================================================================
@register(Tags.security)
def check_allowed_hosts(app_configs, **kwargs):
"""
Check ALLOWED_HOSTS configuration.
"""
errors = []
allowed_hosts = getattr(settings, 'ALLOWED_HOSTS', [])
if not settings.DEBUG:
# In non-debug mode, ALLOWED_HOSTS must be set
if not allowed_hosts:
errors.append(
Error(
'ALLOWED_HOSTS is empty but DEBUG is False.',
hint='Set ALLOWED_HOSTS to a list of allowed hostnames.',
id='security.E004',
)
)
elif '*' in allowed_hosts:
errors.append(
Error(
'ALLOWED_HOSTS contains "*" which allows all hosts.',
hint='Specify explicit hostnames instead of wildcards.',
id='security.E005',
)
)
return errors
# =============================================================================
# Security Headers Check
# =============================================================================
@register(Tags.security)
def check_security_headers(app_configs, **kwargs):
"""
Check that security headers are properly configured.
"""
errors = []
# Check X-Frame-Options
x_frame_options = getattr(settings, 'X_FRAME_OPTIONS', None)
if x_frame_options not in ('DENY', 'SAMEORIGIN'):
errors.append(
Warning(
f'X_FRAME_OPTIONS is set to "{x_frame_options}" or not set.',
hint='Set X_FRAME_OPTIONS to "DENY" or "SAMEORIGIN" to prevent clickjacking.',
id='security.W004',
)
)
# Check content type sniffing protection
if not getattr(settings, 'SECURE_CONTENT_TYPE_NOSNIFF', False):
errors.append(
Warning(
'SECURE_CONTENT_TYPE_NOSNIFF is not enabled.',
hint='Set SECURE_CONTENT_TYPE_NOSNIFF = True to prevent MIME type sniffing.',
id='security.W005',
)
)
# Check referrer policy
referrer_policy = getattr(settings, 'SECURE_REFERRER_POLICY', None)
if not referrer_policy:
errors.append(
Warning(
'SECURE_REFERRER_POLICY is not set.',
hint='Set SECURE_REFERRER_POLICY to control referrer header behavior.',
id='security.W006',
)
)
return errors
# =============================================================================
# HTTPS Settings Check
# =============================================================================
@register(Tags.security)
def check_https_settings(app_configs, **kwargs):
"""
Check HTTPS-related security settings for production.
"""
errors = []
# Skip these checks in debug mode
if settings.DEBUG:
return errors
# Check SSL redirect
if not getattr(settings, 'SECURE_SSL_REDIRECT', False):
errors.append(
Warning(
'SECURE_SSL_REDIRECT is not enabled.',
hint='Set SECURE_SSL_REDIRECT = True to redirect HTTP to HTTPS.',
id='security.W007',
)
)
# Check HSTS settings
hsts_seconds = getattr(settings, 'SECURE_HSTS_SECONDS', 0)
if hsts_seconds < 31536000: # Less than 1 year
errors.append(
Warning(
f'SECURE_HSTS_SECONDS is {hsts_seconds} (less than 1 year).',
hint='Set SECURE_HSTS_SECONDS to at least 31536000 (1 year) for HSTS preload eligibility.',
id='security.W008',
)
)
if not getattr(settings, 'SECURE_HSTS_INCLUDE_SUBDOMAINS', False):
errors.append(
Warning(
'SECURE_HSTS_INCLUDE_SUBDOMAINS is not enabled.',
hint='Set SECURE_HSTS_INCLUDE_SUBDOMAINS = True to include all subdomains in HSTS.',
id='security.W009',
)
)
return errors
# =============================================================================
# Cookie Security Check
# =============================================================================
@register(Tags.security)
def check_cookie_security(app_configs, **kwargs):
"""
Check cookie security settings for production.
"""
errors = []
# Skip in debug mode
if settings.DEBUG:
return errors
# Check session cookie security
if not getattr(settings, 'SESSION_COOKIE_SECURE', False):
errors.append(
Warning(
'SESSION_COOKIE_SECURE is not enabled.',
hint='Set SESSION_COOKIE_SECURE = True to only send session cookies over HTTPS.',
id='security.W010',
)
)
if not getattr(settings, 'SESSION_COOKIE_HTTPONLY', True):
errors.append(
Warning(
'SESSION_COOKIE_HTTPONLY is disabled.',
hint='Set SESSION_COOKIE_HTTPONLY = True to prevent JavaScript access to session cookies.',
id='security.W011',
)
)
# Check CSRF cookie security
if not getattr(settings, 'CSRF_COOKIE_SECURE', False):
errors.append(
Warning(
'CSRF_COOKIE_SECURE is not enabled.',
hint='Set CSRF_COOKIE_SECURE = True to only send CSRF cookies over HTTPS.',
id='security.W012',
)
)
# Check SameSite attributes
session_samesite = getattr(settings, 'SESSION_COOKIE_SAMESITE', 'Lax')
if session_samesite not in ('Strict', 'Lax'):
errors.append(
Warning(
f'SESSION_COOKIE_SAMESITE is set to "{session_samesite}".',
hint='Set SESSION_COOKIE_SAMESITE to "Strict" or "Lax" for CSRF protection.',
id='security.W013',
)
)
return errors
# =============================================================================
# Database Security Check
# =============================================================================
@register(Tags.security)
def check_database_security(app_configs, **kwargs):
"""
Check database connection security settings.
"""
errors = []
# Skip in debug mode
if settings.DEBUG:
return errors
databases = getattr(settings, 'DATABASES', {})
default_db = databases.get('default', {})
# Check for empty password
if not default_db.get('PASSWORD') and default_db.get('ENGINE', '').endswith('postgresql'):
errors.append(
Warning(
'Database password is empty.',
hint='Set a strong password for database authentication.',
id='security.W014',
)
)
# Check for SSL mode in PostgreSQL
options = default_db.get('OPTIONS', {})
if 'sslmode' not in str(options) and default_db.get('ENGINE', '').endswith('postgresql'):
errors.append(
Warning(
'Database SSL mode is not explicitly configured.',
hint='Consider setting sslmode in database OPTIONS for encrypted connections.',
id='security.W015',
)
)
return errors

View File

@@ -65,6 +65,14 @@ class BusinessLogicError(ThrillWikiException):
status_code = 400
class ServiceError(ThrillWikiException):
"""Raised when a service operation fails."""
default_message = "Service operation failed"
error_code = "SERVICE_ERROR"
status_code = 500
class ExternalServiceError(ThrillWikiException):
"""Raised when external service calls fail."""

View File

@@ -0,0 +1,240 @@
"""
Security Audit Management Command.
Runs comprehensive security checks on the Django application and generates
a security audit report.
Usage:
python manage.py security_audit
python manage.py security_audit --output report.txt
python manage.py security_audit --verbose
"""
from django.core.management.base import BaseCommand
from django.core.checks import registry, Tags
from django.conf import settings
class Command(BaseCommand):
help = 'Run security audit and generate a report'
def add_arguments(self, parser):
parser.add_argument(
'--output',
type=str,
help='Output file for the security report',
)
parser.add_argument(
'--verbose',
action='store_true',
help='Show detailed information for each check',
)
def handle(self, *args, **options):
self.verbose = options.get('verbose', False)
output_file = options.get('output')
report_lines = []
self.log("=" * 60, report_lines)
self.log("ThrillWiki Security Audit Report", report_lines)
self.log("=" * 60, report_lines)
self.log("", report_lines)
# Run Django's built-in security checks
self.log("Running Django Security Checks...", report_lines)
self.log("-" * 40, report_lines)
self.run_django_checks(report_lines)
# Run custom configuration checks
self.log("", report_lines)
self.log("Configuration Analysis...", report_lines)
self.log("-" * 40, report_lines)
self.check_configuration(report_lines)
# Run middleware checks
self.log("", report_lines)
self.log("Middleware Analysis...", report_lines)
self.log("-" * 40, report_lines)
self.check_middleware(report_lines)
# Summary
self.log("", report_lines)
self.log("=" * 60, report_lines)
self.log("Audit Complete", report_lines)
self.log("=" * 60, report_lines)
# Write to file if specified
if output_file:
with open(output_file, 'w') as f:
f.write('\n'.join(report_lines))
self.stdout.write(
self.style.SUCCESS(f'\nReport saved to: {output_file}')
)
def log(self, message, report_lines):
"""Log message to both stdout and report."""
self.stdout.write(message)
report_lines.append(message)
def run_django_checks(self, report_lines):
"""Run Django's security checks."""
errors = registry.run_checks(tags=[Tags.security])
if not errors:
self.log(
self.style.SUCCESS(" ✓ All Django security checks passed"),
report_lines
)
else:
for error in errors:
if error.is_serious():
prefix = self.style.ERROR(" ✗ ERROR")
else:
prefix = self.style.WARNING(" ! WARNING")
self.log(f"{prefix}: {error.msg}", report_lines)
if error.hint and self.verbose:
self.log(f" Hint: {error.hint}", report_lines)
def check_configuration(self, report_lines):
"""Check various configuration settings."""
checks = [
('DEBUG mode', not settings.DEBUG, 'DEBUG should be False'),
(
'SECRET_KEY length',
len(settings.SECRET_KEY) >= 50,
f'Length: {len(settings.SECRET_KEY)}'
),
(
'ALLOWED_HOSTS',
bool(settings.ALLOWED_HOSTS) and '*' not in settings.ALLOWED_HOSTS,
str(settings.ALLOWED_HOSTS)
),
(
'CSRF_TRUSTED_ORIGINS',
bool(getattr(settings, 'CSRF_TRUSTED_ORIGINS', [])),
str(getattr(settings, 'CSRF_TRUSTED_ORIGINS', []))
),
(
'X_FRAME_OPTIONS',
getattr(settings, 'X_FRAME_OPTIONS', '') in ('DENY', 'SAMEORIGIN'),
str(getattr(settings, 'X_FRAME_OPTIONS', 'Not set'))
),
(
'SECURE_CONTENT_TYPE_NOSNIFF',
getattr(settings, 'SECURE_CONTENT_TYPE_NOSNIFF', False),
str(getattr(settings, 'SECURE_CONTENT_TYPE_NOSNIFF', False))
),
(
'SECURE_BROWSER_XSS_FILTER',
getattr(settings, 'SECURE_BROWSER_XSS_FILTER', False),
str(getattr(settings, 'SECURE_BROWSER_XSS_FILTER', False))
),
(
'SESSION_COOKIE_HTTPONLY',
getattr(settings, 'SESSION_COOKIE_HTTPONLY', True),
str(getattr(settings, 'SESSION_COOKIE_HTTPONLY', 'Not set'))
),
(
'CSRF_COOKIE_HTTPONLY',
getattr(settings, 'CSRF_COOKIE_HTTPONLY', True),
str(getattr(settings, 'CSRF_COOKIE_HTTPONLY', 'Not set'))
),
]
# Production-only checks
if not settings.DEBUG:
checks.extend([
(
'SECURE_SSL_REDIRECT',
getattr(settings, 'SECURE_SSL_REDIRECT', False),
str(getattr(settings, 'SECURE_SSL_REDIRECT', False))
),
(
'SESSION_COOKIE_SECURE',
getattr(settings, 'SESSION_COOKIE_SECURE', False),
str(getattr(settings, 'SESSION_COOKIE_SECURE', False))
),
(
'CSRF_COOKIE_SECURE',
getattr(settings, 'CSRF_COOKIE_SECURE', False),
str(getattr(settings, 'CSRF_COOKIE_SECURE', False))
),
(
'SECURE_HSTS_SECONDS',
getattr(settings, 'SECURE_HSTS_SECONDS', 0) >= 31536000,
str(getattr(settings, 'SECURE_HSTS_SECONDS', 0))
),
])
for name, is_secure, value in checks:
if is_secure:
status = self.style.SUCCESS("")
else:
status = self.style.WARNING("!")
msg = f" {status} {name}"
if self.verbose:
msg += f" ({value})"
self.log(msg, report_lines)
def check_middleware(self, report_lines):
"""Check security-related middleware is properly configured."""
middleware = getattr(settings, 'MIDDLEWARE', [])
required_middleware = [
('django.middleware.security.SecurityMiddleware', 'SecurityMiddleware'),
('django.middleware.csrf.CsrfViewMiddleware', 'CSRF Middleware'),
('django.middleware.clickjacking.XFrameOptionsMiddleware', 'X-Frame-Options'),
]
custom_security_middleware = [
('apps.core.middleware.security_headers.SecurityHeadersMiddleware', 'Security Headers'),
('apps.core.middleware.rate_limiting.AuthRateLimitMiddleware', 'Rate Limiting'),
]
# Check required middleware
for mw_path, mw_name in required_middleware:
if mw_path in middleware:
self.log(
f" {self.style.SUCCESS('')} {mw_name} is enabled",
report_lines
)
else:
self.log(
f" {self.style.ERROR('')} {mw_name} is NOT enabled",
report_lines
)
# Check custom security middleware
for mw_path, mw_name in custom_security_middleware:
if mw_path in middleware:
self.log(
f" {self.style.SUCCESS('')} {mw_name} is enabled",
report_lines
)
else:
self.log(
f" {self.style.WARNING('!')} {mw_name} is not enabled (optional)",
report_lines
)
# Check middleware order
try:
security_idx = middleware.index('django.middleware.security.SecurityMiddleware')
session_idx = middleware.index('django.contrib.sessions.middleware.SessionMiddleware')
if security_idx < session_idx:
self.log(
f" {self.style.SUCCESS('')} Middleware ordering is correct",
report_lines
)
else:
self.log(
f" {self.style.WARNING('!')} SecurityMiddleware should come before SessionMiddleware",
report_lines
)
except ValueError:
pass # Middleware not found, already reported above

View File

@@ -0,0 +1,253 @@
"""
Rate Limiting Middleware for ThrillWiki.
This middleware provides rate limiting for authentication endpoints to prevent
brute force attacks, credential stuffing, and account enumeration.
Security Note:
Rate limiting is applied at the IP level and user level (if authenticated).
Limits are configurable and should be adjusted based on actual usage patterns.
Usage:
Add 'apps.core.middleware.rate_limiting.AuthRateLimitMiddleware'
to MIDDLEWARE in settings.py.
"""
import logging
from typing import Callable, Optional, Tuple
from django.core.cache import cache
from django.http import HttpRequest, HttpResponse, JsonResponse
from django.conf import settings
logger = logging.getLogger(__name__)
class AuthRateLimitMiddleware:
"""
Middleware that rate limits authentication-related endpoints.
Protects against:
- Brute force login attacks
- Password reset abuse
- Account enumeration through timing attacks
"""
# Endpoints to rate limit
RATE_LIMITED_PATHS = {
# Login endpoints
'/api/v1/auth/login/': {'per_minute': 5, 'per_hour': 30, 'per_day': 100},
'/accounts/login/': {'per_minute': 5, 'per_hour': 30, 'per_day': 100},
# Signup endpoints
'/api/v1/auth/signup/': {'per_minute': 3, 'per_hour': 10, 'per_day': 20},
'/accounts/signup/': {'per_minute': 3, 'per_hour': 10, 'per_day': 20},
# Password reset endpoints
'/api/v1/auth/password-reset/': {'per_minute': 2, 'per_hour': 5, 'per_day': 10},
'/accounts/password/reset/': {'per_minute': 2, 'per_hour': 5, 'per_day': 10},
# Token endpoints
'/api/v1/auth/token/': {'per_minute': 10, 'per_hour': 60, 'per_day': 200},
'/api/v1/auth/token/refresh/': {'per_minute': 20, 'per_hour': 120, 'per_day': 500},
}
def __init__(self, get_response: Callable[[HttpRequest], HttpResponse]):
self.get_response = get_response
def __call__(self, request: HttpRequest) -> HttpResponse:
# Only rate limit POST requests to auth endpoints
if request.method != 'POST':
return self.get_response(request)
# Check if this path should be rate limited
limits = self._get_rate_limits(request.path)
if not limits:
return self.get_response(request)
# Get client identifier (IP address)
client_ip = self._get_client_ip(request)
# Check rate limits
is_allowed, message = self._check_rate_limits(
client_ip, request.path, limits
)
if not is_allowed:
logger.warning(
f"Rate limit exceeded for {client_ip} on {request.path}"
)
return self._rate_limit_response(message)
# Process request
response = self.get_response(request)
# Only increment counter for failed auth attempts (non-2xx responses)
if response.status_code >= 400:
self._increment_counters(client_ip, request.path)
return response
def _get_rate_limits(self, path: str) -> Optional[dict]:
"""Get rate limits for a path, if any."""
# Exact match
if path in self.RATE_LIMITED_PATHS:
return self.RATE_LIMITED_PATHS[path]
# Prefix match (for paths with trailing slashes)
path_without_slash = path.rstrip('/')
for limited_path, limits in self.RATE_LIMITED_PATHS.items():
if path_without_slash == limited_path.rstrip('/'):
return limits
return None
def _get_client_ip(self, request: HttpRequest) -> str:
"""
Get the client's IP address from the request.
Handles common proxy headers (X-Forwarded-For, X-Real-IP).
"""
# Check for forwarded headers (set by reverse proxies)
x_forwarded_for = request.META.get('HTTP_X_FORWARDED_FOR')
if x_forwarded_for:
# Take the first IP in the chain (client IP)
return x_forwarded_for.split(',')[0].strip()
x_real_ip = request.META.get('HTTP_X_REAL_IP')
if x_real_ip:
return x_real_ip
return request.META.get('REMOTE_ADDR', 'unknown')
def _check_rate_limits(
self,
client_ip: str,
path: str,
limits: dict
) -> Tuple[bool, str]:
"""
Check if the client has exceeded rate limits.
Returns:
Tuple of (is_allowed, reason_if_blocked)
"""
# Create a safe cache key from path
path_key = path.replace('/', '_').strip('_')
# Check per-minute limit
minute_key = f"auth_rate:{client_ip}:{path_key}:minute"
minute_count = cache.get(minute_key, 0)
if minute_count >= limits.get('per_minute', 10):
return False, "Too many requests. Please wait a minute before trying again."
# Check per-hour limit
hour_key = f"auth_rate:{client_ip}:{path_key}:hour"
hour_count = cache.get(hour_key, 0)
if hour_count >= limits.get('per_hour', 60):
return False, "Too many requests. Please try again later."
# Check per-day limit
day_key = f"auth_rate:{client_ip}:{path_key}:day"
day_count = cache.get(day_key, 0)
if day_count >= limits.get('per_day', 200):
return False, "Daily limit exceeded. Please try again tomorrow."
return True, ""
def _increment_counters(self, client_ip: str, path: str) -> None:
"""Increment rate limit counters."""
path_key = path.replace('/', '_').strip('_')
# Increment per-minute counter
minute_key = f"auth_rate:{client_ip}:{path_key}:minute"
try:
cache.incr(minute_key)
except ValueError:
cache.set(minute_key, 1, 60)
# Increment per-hour counter
hour_key = f"auth_rate:{client_ip}:{path_key}:hour"
try:
cache.incr(hour_key)
except ValueError:
cache.set(hour_key, 1, 3600)
# Increment per-day counter
day_key = f"auth_rate:{client_ip}:{path_key}:day"
try:
cache.incr(day_key)
except ValueError:
cache.set(day_key, 1, 86400)
def _rate_limit_response(self, message: str) -> JsonResponse:
"""Generate a rate limit exceeded response."""
return JsonResponse(
{
'error': message,
'code': 'RATE_LIMIT_EXCEEDED',
},
status=429, # Too Many Requests
)
class SecurityEventLogger:
"""
Utility class for logging security-relevant events.
Use this to log:
- Failed authentication attempts
- Permission denied events
- Suspicious activity
"""
@staticmethod
def log_failed_login(
request: HttpRequest,
username: str,
reason: str = "Invalid credentials"
) -> None:
"""Log a failed login attempt."""
client_ip = AuthRateLimitMiddleware._get_client_ip(
AuthRateLimitMiddleware, request
)
logger.warning(
f"Failed login attempt - IP: {client_ip}, Username: {username}, "
f"Reason: {reason}, User-Agent: {request.META.get('HTTP_USER_AGENT', 'unknown')}"
)
@staticmethod
def log_permission_denied(
request: HttpRequest,
resource: str,
action: str = "access"
) -> None:
"""Log a permission denied event."""
client_ip = AuthRateLimitMiddleware._get_client_ip(
AuthRateLimitMiddleware, request
)
user = getattr(request, 'user', None)
username = user.username if user and user.is_authenticated else 'anonymous'
logger.warning(
f"Permission denied - IP: {client_ip}, User: {username}, "
f"Resource: {resource}, Action: {action}"
)
@staticmethod
def log_suspicious_activity(
request: HttpRequest,
activity_type: str,
details: str = ""
) -> None:
"""Log suspicious activity."""
client_ip = AuthRateLimitMiddleware._get_client_ip(
AuthRateLimitMiddleware, request
)
user = getattr(request, 'user', None)
username = user.username if user and user.is_authenticated else 'anonymous'
logger.error(
f"Suspicious activity detected - Type: {activity_type}, "
f"IP: {client_ip}, User: {username}, Details: {details}"
)

View File

@@ -114,18 +114,52 @@ class RequestLoggingMiddleware(MiddlewareMixin):
return response
# Sensitive field patterns that should be masked in logs
# Security: Comprehensive list of sensitive data patterns
SENSITIVE_PATTERNS = [
'password',
'passwd',
'pwd',
'token',
'secret',
'key',
'api_key',
'apikey',
'auth',
'authorization',
'credential',
'ssn',
'social_security',
'credit_card',
'creditcard',
'card_number',
'cvv',
'cvc',
'pin',
'access_token',
'refresh_token',
'jwt',
'session',
'cookie',
'private',
]
def _safe_log_data(self, data):
"""Safely log data, truncating if too large and masking sensitive fields."""
"""
Safely log data, truncating if too large and masking sensitive fields.
Security measures:
- Masks all sensitive field names
- Masks email addresses (shows only domain)
- Truncates long values to prevent log flooding
- Recursively processes nested dictionaries and lists
"""
try:
# Convert to string representation
if isinstance(data, dict):
# Mask sensitive fields
safe_data = {}
for key, value in data.items():
if any(sensitive in key.lower() for sensitive in ['password', 'token', 'secret', 'key']):
safe_data[key] = '***MASKED***'
else:
safe_data[key] = value
safe_data = self._mask_sensitive_dict(data)
data_str = json.dumps(safe_data, indent=2, default=str)
elif isinstance(data, list):
safe_data = [self._mask_sensitive_value(item) for item in data]
data_str = json.dumps(safe_data, indent=2, default=str)
else:
data_str = json.dumps(data, indent=2, default=str)
@@ -136,3 +170,37 @@ class RequestLoggingMiddleware(MiddlewareMixin):
return data_str
except Exception:
return str(data)[:500] + '...[ERROR_LOGGING]'
def _mask_sensitive_dict(self, data, depth=0):
"""Recursively mask sensitive fields in a dictionary."""
if depth > 5: # Prevent infinite recursion
return '***DEPTH_LIMIT***'
safe_data = {}
for key, value in data.items():
key_lower = str(key).lower()
# Check if key contains any sensitive pattern
if any(pattern in key_lower for pattern in self.SENSITIVE_PATTERNS):
safe_data[key] = '***MASKED***'
else:
safe_data[key] = self._mask_sensitive_value(value, depth)
return safe_data
def _mask_sensitive_value(self, value, depth=0):
"""Mask a single value, handling different types."""
if isinstance(value, dict):
return self._mask_sensitive_dict(value, depth + 1)
elif isinstance(value, list):
return [self._mask_sensitive_value(item, depth + 1) for item in value[:10]] # Limit list items
elif isinstance(value, str):
# Mask email addresses (show only domain)
if '@' in value and '.' in value.split('@')[-1]:
parts = value.split('@')
if len(parts) == 2:
return f"***@{parts[1]}"
# Truncate long strings
if len(value) > 200:
return value[:200] + '...[TRUNCATED]'
return value

View File

@@ -0,0 +1,196 @@
"""
Security Headers Middleware for ThrillWiki.
This middleware adds additional security headers to all HTTP responses,
providing defense-in-depth against common web vulnerabilities.
Headers added:
- Content-Security-Policy: Controls resource loading to prevent XSS
- Permissions-Policy: Restricts browser feature access
- Cross-Origin-Embedder-Policy: Prevents cross-origin embedding
- Cross-Origin-Resource-Policy: Restricts cross-origin resource access
Usage:
Add 'apps.core.middleware.security_headers.SecurityHeadersMiddleware'
to MIDDLEWARE in settings.py (after SecurityMiddleware).
"""
from django.conf import settings
class SecurityHeadersMiddleware:
"""
Middleware that adds security headers to HTTP responses.
This provides defense-in-depth by adding headers that Django's
SecurityMiddleware doesn't handle.
"""
def __init__(self, get_response):
self.get_response = get_response
# Build CSP header at startup for performance
self._csp_header = self._build_csp_header()
self._permissions_policy_header = self._build_permissions_policy_header()
def __call__(self, request):
response = self.get_response(request)
return self._add_security_headers(response, request)
def _add_security_headers(self, response, request):
"""Add security headers to the response."""
# Content-Security-Policy
# Only add CSP for HTML responses to avoid breaking API/JSON responses
content_type = response.get("Content-Type", "")
if "text/html" in content_type:
if not response.get("Content-Security-Policy"):
response["Content-Security-Policy"] = self._csp_header
# Permissions-Policy (successor to Feature-Policy)
if not response.get("Permissions-Policy"):
response["Permissions-Policy"] = self._permissions_policy_header
# Cross-Origin-Embedder-Policy
# Requires resources to be CORS-enabled or same-origin
# Using 'unsafe-none' for now as 'require-corp' can break third-party resources
if not response.get("Cross-Origin-Embedder-Policy"):
response["Cross-Origin-Embedder-Policy"] = "unsafe-none"
# Cross-Origin-Resource-Policy
# Controls how resources can be shared with other origins
if not response.get("Cross-Origin-Resource-Policy"):
response["Cross-Origin-Resource-Policy"] = "same-origin"
return response
def _build_csp_header(self):
"""
Build the Content-Security-Policy header value.
CSP directives explained:
- default-src: Fallback for other fetch directives
- script-src: Sources for JavaScript
- style-src: Sources for CSS
- img-src: Sources for images
- font-src: Sources for fonts
- connect-src: Sources for fetch, XHR, WebSocket
- frame-ancestors: Controls framing (replaces X-Frame-Options)
- form-action: Valid targets for form submissions
- base-uri: Restricts base element URLs
- object-src: Sources for plugins (Flash, etc.)
"""
# Check if we're in debug mode
debug = getattr(settings, "DEBUG", False)
# Base directives (production-focused)
directives = {
"default-src": ["'self'"],
"script-src": [
"'self'",
# Allow HTMX inline scripts with nonce (would need nonce middleware)
# For now, using 'unsafe-inline' for HTMX compatibility
"'unsafe-inline'" if debug else "'self'",
# CDNs for external scripts
"https://cdn.jsdelivr.net",
"https://unpkg.com",
"https://challenges.cloudflare.com", # Turnstile
],
"style-src": [
"'self'",
"'unsafe-inline'", # Required for Tailwind and inline styles
"https://cdn.jsdelivr.net",
"https://fonts.googleapis.com",
],
"img-src": [
"'self'",
"data:",
"blob:",
"https:", # Allow HTTPS images (needed for user uploads, maps, etc.)
],
"font-src": [
"'self'",
"https://fonts.gstatic.com",
"https://cdn.jsdelivr.net",
],
"connect-src": [
"'self'",
"https://api.forwardemail.net",
"https://challenges.cloudflare.com",
"https://*.cloudflare.com",
# Map tile servers
"https://*.openstreetmap.org",
"https://*.tile.openstreetmap.org",
],
"frame-src": [
"'self'",
"https://challenges.cloudflare.com", # Turnstile widget
],
"frame-ancestors": ["'self'"],
"form-action": ["'self'"],
"base-uri": ["'self'"],
"object-src": ["'none'"],
"upgrade-insecure-requests": [], # Upgrade HTTP to HTTPS
}
# Add debug-specific relaxations
if debug:
# Allow webpack dev server connections in development
directives["connect-src"].extend([
"ws://localhost:*",
"http://localhost:*",
"http://127.0.0.1:*",
])
# Build header string
parts = []
for directive, sources in directives.items():
if sources:
parts.append(f"{directive} {' '.join(sources)}")
else:
# Directives like upgrade-insecure-requests don't need values
parts.append(directive)
return "; ".join(parts)
def _build_permissions_policy_header(self):
"""
Build the Permissions-Policy header value.
This header controls which browser features the page can use.
"""
# Get permissions policy from settings or use defaults
policy = getattr(settings, "PERMISSIONS_POLICY", {
"accelerometer": [],
"ambient-light-sensor": [],
"autoplay": [],
"camera": [],
"display-capture": [],
"document-domain": [],
"encrypted-media": [],
"fullscreen": ["self"],
"geolocation": ["self"],
"gyroscope": [],
"interest-cohort": [],
"magnetometer": [],
"microphone": [],
"midi": [],
"payment": [],
"picture-in-picture": [],
"publickey-credentials-get": [],
"screen-wake-lock": [],
"sync-xhr": [],
"usb": [],
"web-share": ["self"],
"xr-spatial-tracking": [],
})
parts = []
for feature, allowlist in policy.items():
if not allowlist:
# Empty list means disallow completely
parts.append(f"{feature}=()")
else:
# Convert allowlist to proper format
formatted = " ".join(allowlist)
parts.append(f"{feature}=({formatted})")
return ", ".join(parts)

View File

@@ -223,7 +223,7 @@ class MapResponse:
"query_time_ms": self.query_time_ms,
"filters_applied": self.filters_applied,
"pagination": {
"has_more": False, # TODO: Implement pagination
"has_more": False, # TODO(THRILLWIKI-102): Implement pagination for map data
"total_pages": 1,
},
},

View File

@@ -297,7 +297,7 @@ class CompanyLocationAdapter(BaseLocationAdapter):
"""Convert CompanyHeadquarters to UnifiedLocation."""
# Note: CompanyHeadquarters doesn't have coordinates, so we need to geocode
# For now, we'll skip companies without coordinates
# TODO: Implement geocoding service integration
# TODO(THRILLWIKI-101): Implement geocoding service integration for company HQs
return None
def get_queryset(

View File

@@ -0,0 +1,275 @@
"""
Safe HTML Template Tags and Filters for ThrillWiki.
This module provides template tags and filters for safely rendering
HTML content without XSS vulnerabilities.
Security Note:
Always use these filters instead of |safe for user-generated content.
The |safe filter should only be used for content that has been
pre-sanitized in the view layer.
Usage:
{% load safe_html %}
{# Sanitize user content #}
{{ user_description|sanitize }}
{# Minimal sanitization for comments #}
{{ comment_text|sanitize_minimal }}
{# Strip all HTML #}
{{ raw_text|strip_html }}
{# Safe JSON for JavaScript #}
{{ data|json_safe }}
{# Render trusted icon SVG #}
{% icon "check" class="w-4 h-4" %}
"""
import json
from django import template
from django.utils.safestring import mark_safe
from apps.core.utils.html_sanitizer import (
sanitize_html,
sanitize_minimal as _sanitize_minimal,
sanitize_svg,
strip_html as _strip_html,
sanitize_for_json,
escape_js_string as _escape_js_string,
sanitize_url as _sanitize_url,
sanitize_attribute_value,
)
register = template.Library()
# =============================================================================
# HTML Sanitization Filters
# =============================================================================
@register.filter(name='sanitize', is_safe=True)
def sanitize_filter(value):
"""
Sanitize HTML content to prevent XSS attacks.
Allows common formatting tags while stripping dangerous content.
Usage:
{{ user_content|sanitize }}
"""
if not value:
return ''
return mark_safe(sanitize_html(str(value)))
@register.filter(name='sanitize_minimal', is_safe=True)
def sanitize_minimal_filter(value):
"""
Sanitize HTML with minimal allowed tags.
Only allows basic text formatting: p, br, strong, em, i, b, a
Usage:
{{ comment|sanitize_minimal }}
"""
if not value:
return ''
return mark_safe(_sanitize_minimal(str(value)))
@register.filter(name='sanitize_svg', is_safe=True)
def sanitize_svg_filter(value):
"""
Sanitize SVG content for safe inline rendering.
Usage:
{{ icon_svg|sanitize_svg }}
"""
if not value:
return ''
return mark_safe(sanitize_svg(str(value)))
@register.filter(name='strip_html')
def strip_html_filter(value):
"""
Remove all HTML tags from content.
Usage:
{{ html_content|strip_html }}
"""
if not value:
return ''
return _strip_html(str(value))
# =============================================================================
# JavaScript/JSON Context Filters
# =============================================================================
@register.filter(name='json_safe', is_safe=True)
def json_safe_filter(value):
"""
Safely serialize data for embedding in JavaScript.
This is safer than using |safe for JSON data as it properly
escapes </script> and other dangerous sequences.
Usage:
<script>
const data = {{ python_dict|json_safe }};
</script>
"""
if value is None:
return 'null'
return mark_safe(sanitize_for_json(value))
@register.filter(name='escapejs_safe')
def escapejs_safe_filter(value):
"""
Escape a string for safe use in JavaScript string literals.
Usage:
<script>
const message = '{{ user_input|escapejs_safe }}';
</script>
"""
if not value:
return ''
return _escape_js_string(str(value))
# =============================================================================
# URL and Attribute Filters
# =============================================================================
@register.filter(name='sanitize_url')
def sanitize_url_filter(value):
"""
Sanitize a URL to prevent javascript: and other dangerous protocols.
Usage:
<a href="{{ user_url|sanitize_url }}">Link</a>
"""
if not value:
return ''
return _sanitize_url(str(value))
@register.filter(name='attr_safe')
def attr_safe_filter(value):
"""
Escape a value for safe use in HTML attributes.
Usage:
<div data-value="{{ user_value|attr_safe }}">
"""
if not value:
return ''
return sanitize_attribute_value(str(value))
# =============================================================================
# Icon Template Tags
# =============================================================================
# Predefined safe SVG icons
# These are trusted and can be rendered without sanitization
BUILTIN_ICONS = {
'check': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M4.5 12.75l6 6 9-13.5" /></svg>''',
'x': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M6 18L18 6M6 6l12 12" /></svg>''',
'plus': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M12 4.5v15m7.5-7.5h-15" /></svg>''',
'minus': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M19.5 12h-15" /></svg>''',
'chevron-down': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M19.5 8.25l-7.5 7.5-7.5-7.5" /></svg>''',
'chevron-up': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M4.5 15.75l7.5-7.5 7.5 7.5" /></svg>''',
'chevron-left': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M15.75 19.5L8.25 12l7.5-7.5" /></svg>''',
'chevron-right': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M8.25 4.5l7.5 7.5-7.5 7.5" /></svg>''',
'search': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M21 21l-5.197-5.197m0 0A7.5 7.5 0 105.196 5.196a7.5 7.5 0 0010.607 10.607z" /></svg>''',
'menu': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M3.75 6.75h16.5M3.75 12h16.5m-16.5 5.25h16.5" /></svg>''',
'user': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M15.75 6a3.75 3.75 0 11-7.5 0 3.75 3.75 0 017.5 0zM4.501 20.118a7.5 7.5 0 0114.998 0A17.933 17.933 0 0112 21.75c-2.676 0-5.216-.584-7.499-1.632z" /></svg>''',
'cog': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M9.594 3.94c.09-.542.56-.94 1.11-.94h2.593c.55 0 1.02.398 1.11.94l.213 1.281c.063.374.313.686.645.87.074.04.147.083.22.127.324.196.72.257 1.075.124l1.217-.456a1.125 1.125 0 011.37.49l1.296 2.247a1.125 1.125 0 01-.26 1.431l-1.003.827c-.293.24-.438.613-.431.992a6.759 6.759 0 010 .255c-.007.378.138.75.43.99l1.005.828c.424.35.534.954.26 1.43l-1.298 2.247a1.125 1.125 0 01-1.369.491l-1.217-.456c-.355-.133-.75-.072-1.076.124a6.57 6.57 0 01-.22.128c-.331.183-.581.495-.644.869l-.213 1.28c-.09.543-.56.941-1.11.941h-2.594c-.55 0-1.02-.398-1.11-.94l-.213-1.281c-.062-.374-.312-.686-.644-.87a6.52 6.52 0 01-.22-.127c-.325-.196-.72-.257-1.076-.124l-1.217.456a1.125 1.125 0 01-1.369-.49l-1.297-2.247a1.125 1.125 0 01.26-1.431l1.004-.827c.292-.24.437-.613.43-.992a6.932 6.932 0 010-.255c.007-.378-.138-.75-.43-.99l-1.004-.828a1.125 1.125 0 01-.26-1.43l1.297-2.247a1.125 1.125 0 011.37-.491l1.216.456c.356.133.751.072 1.076-.124.072-.044.146-.087.22-.128.332-.183.582-.495.644-.869l.214-1.281z" /><path stroke-linecap="round" stroke-linejoin="round" d="M15 12a3 3 0 11-6 0 3 3 0 016 0z" /></svg>''',
'trash': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M14.74 9l-.346 9m-4.788 0L9.26 9m9.968-3.21c.342.052.682.107 1.022.166m-1.022-.165L18.16 19.673a2.25 2.25 0 01-2.244 2.077H8.084a2.25 2.25 0 01-2.244-2.077L4.772 5.79m14.456 0a48.108 48.108 0 00-3.478-.397m-12 .562c.34-.059.68-.114 1.022-.165m0 0a48.11 48.11 0 013.478-.397m7.5 0v-.916c0-1.18-.91-2.164-2.09-2.201a51.964 51.964 0 00-3.32 0c-1.18.037-2.09 1.022-2.09 2.201v.916m7.5 0a48.667 48.667 0 00-7.5 0" /></svg>''',
'pencil': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M16.862 4.487l1.687-1.688a1.875 1.875 0 112.652 2.652L10.582 16.07a4.5 4.5 0 01-1.897 1.13L6 18l.8-2.685a4.5 4.5 0 011.13-1.897l8.932-8.931zm0 0L19.5 7.125M18 14v4.75A2.25 2.25 0 0115.75 21H5.25A2.25 2.25 0 013 18.75V8.25A2.25 2.25 0 015.25 6H10" /></svg>''',
'eye': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M2.036 12.322a1.012 1.012 0 010-.639C3.423 7.51 7.36 4.5 12 4.5c4.638 0 8.573 3.007 9.963 7.178.07.207.07.431 0 .639C20.577 16.49 16.64 19.5 12 19.5c-4.638 0-8.573-3.007-9.963-7.178z" /><path stroke-linecap="round" stroke-linejoin="round" d="M15 12a3 3 0 11-6 0 3 3 0 016 0z" /></svg>''',
'eye-slash': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M3.98 8.223A10.477 10.477 0 001.934 12C3.226 16.338 7.244 19.5 12 19.5c.993 0 1.953-.138 2.863-.395M6.228 6.228A10.45 10.45 0 0112 4.5c4.756 0 8.773 3.162 10.065 7.498a10.523 10.523 0 01-4.293 5.774M6.228 6.228L3 3m3.228 3.228l3.65 3.65m7.894 7.894L21 21m-3.228-3.228l-3.65-3.65m0 0a3 3 0 10-4.243-4.243m4.242 4.242L9.88 9.88" /></svg>''',
'arrow-left': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M10.5 19.5L3 12m0 0l7.5-7.5M3 12h18" /></svg>''',
'arrow-right': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M13.5 4.5L21 12m0 0l-7.5 7.5M21 12H3" /></svg>''',
'info': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M11.25 11.25l.041-.02a.75.75 0 011.063.852l-.708 2.836a.75.75 0 001.063.853l.041-.021M21 12a9 9 0 11-18 0 9 9 0 0118 0zm-9-3.75h.008v.008H12V8.25z" /></svg>''',
'warning': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M12 9v3.75m-9.303 3.376c-.866 1.5.217 3.374 1.948 3.374h14.71c1.73 0 2.813-1.874 1.948-3.374L13.949 3.378c-.866-1.5-3.032-1.5-3.898 0L2.697 16.126zM12 15.75h.007v.008H12v-.008z" /></svg>''',
'error': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M12 9v3.75m9-.75a9 9 0 11-18 0 9 9 0 0118 0zm-9 3.75h.008v.008H12v-.008z" /></svg>''',
'success': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M9 12.75L11.25 15 15 9.75M21 12a9 9 0 11-18 0 9 9 0 0118 0z" /></svg>''',
'loading': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" {attrs}><circle class="opacity-25" cx="12" cy="12" r="10" stroke="currentColor" stroke-width="4"></circle><path class="opacity-75" fill="currentColor" d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"></path></svg>''',
'external-link': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M13.5 6H5.25A2.25 2.25 0 003 8.25v10.5A2.25 2.25 0 005.25 21h10.5A2.25 2.25 0 0018 18.75V10.5m-10.5 6L21 3m0 0h-5.25M21 3v5.25" /></svg>''',
'download': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M3 16.5v2.25A2.25 2.25 0 005.25 21h13.5A2.25 2.25 0 0021 18.75V16.5M16.5 12L12 16.5m0 0L7.5 12m4.5 4.5V3" /></svg>''',
'upload': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M3 16.5v2.25A2.25 2.25 0 005.25 21h13.5A2.25 2.25 0 0021 18.75V16.5m-13.5-9L12 3m0 0l4.5 4.5M12 3v13.5" /></svg>''',
'star': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M11.48 3.499a.562.562 0 011.04 0l2.125 5.111a.563.563 0 00.475.345l5.518.442c.499.04.701.663.321.988l-4.204 3.602a.563.563 0 00-.182.557l1.285 5.385a.562.562 0 01-.84.61l-4.725-2.885a.563.563 0 00-.586 0L6.982 20.54a.562.562 0 01-.84-.61l1.285-5.386a.562.562 0 00-.182-.557l-4.204-3.602a.563.563 0 01.321-.988l5.518-.442a.563.563 0 00.475-.345L11.48 3.5z" /></svg>''',
'star-filled': '''<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" {attrs}><path fill-rule="evenodd" d="M10.788 3.21c.448-1.077 1.976-1.077 2.424 0l2.082 5.007 5.404.433c1.164.093 1.636 1.545.749 2.305l-4.117 3.527 1.257 5.273c.271 1.136-.964 2.033-1.96 1.425L12 18.354 7.373 21.18c-.996.608-2.231-.29-1.96-1.425l1.257-5.273-4.117-3.527c-.887-.76-.415-2.212.749-2.305l5.404-.433 2.082-5.006z" clip-rule="evenodd" /></svg>''',
'heart': '''<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor" {attrs}><path stroke-linecap="round" stroke-linejoin="round" d="M21 8.25c0-2.485-2.099-4.5-4.688-4.5-1.935 0-3.597 1.126-4.312 2.733-.715-1.607-2.377-2.733-4.313-2.733C5.1 3.75 3 5.765 3 8.25c0 7.22 9 12 9 12s9-4.78 9-12z" /></svg>''',
'heart-filled': '''<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" {attrs}><path d="M11.645 20.91l-.007-.003-.022-.012a15.247 15.247 0 01-.383-.218 25.18 25.18 0 01-4.244-3.17C4.688 15.36 2.25 12.174 2.25 8.25 2.25 5.322 4.714 3 7.688 3A5.5 5.5 0 0112 5.052 5.5 5.5 0 0116.313 3c2.973 0 5.437 2.322 5.437 5.25 0 3.925-2.438 7.111-4.739 9.256a25.175 25.175 0 01-4.244 3.17 15.247 15.247 0 01-.383.219l-.022.012-.007.004-.003.001a.752.752 0 01-.704 0l-.003-.001z" /></svg>''',
}
@register.simple_tag
def icon(name, **kwargs):
"""
Render a trusted SVG icon.
This tag renders predefined SVG icons that are trusted and safe.
Custom attributes can be passed to customize the icon.
Usage:
{% icon "check" class="w-4 h-4 text-green-500" %}
{% icon "x" class="w-6 h-6" aria_hidden="true" %}
Args:
name: The icon name (from BUILTIN_ICONS)
**kwargs: Additional HTML attributes for the SVG
Returns:
Safe HTML for the icon SVG
"""
svg_template = BUILTIN_ICONS.get(name)
if not svg_template:
# Return empty string for unknown icons (fail silently)
return ''
# Build attributes string
attrs_list = []
for key, value in kwargs.items():
# Convert underscore to hyphen for HTML attributes (e.g., aria_hidden -> aria-hidden)
attr_name = key.replace('_', '-')
# Escape attribute values to prevent XSS
safe_value = sanitize_attribute_value(str(value))
attrs_list.append(f'{attr_name}="{safe_value}"')
attrs_str = ' '.join(attrs_list)
# Substitute attributes into template
svg = svg_template.format(attrs=attrs_str)
return mark_safe(svg)
@register.simple_tag
def icon_class(name, size='w-5 h-5', extra_class=''):
"""
Render a trusted SVG icon with common class presets.
Usage:
{% icon_class "check" size="w-4 h-4" extra_class="text-green-500" %}
Args:
name: The icon name
size: Size classes (default: "w-5 h-5")
extra_class: Additional CSS classes
Returns:
Safe HTML for the icon SVG
"""
classes = f'{size} {extra_class}'.strip()
return icon(name, **{'class': classes})

View File

@@ -0,0 +1,161 @@
"""
Centralized error handling utilities.
This module provides standardized error handling for views across the application,
ensuring consistent logging, user messages, and API responses.
"""
import logging
from typing import Any, Dict, Optional
from django.contrib import messages
from django.http import HttpRequest
from rest_framework import status
from rest_framework.response import Response
from apps.core.exceptions import ThrillWikiException
logger = logging.getLogger(__name__)
class ErrorHandler:
"""Centralized error handling for views."""
@staticmethod
def handle_view_error(
request: HttpRequest,
error: Exception,
user_message: str = "An error occurred",
log_message: Optional[str] = None,
level: str = "error",
) -> None:
"""
Handle errors in template views.
Logs the error with appropriate context and displays a user-friendly
message using Django's messages framework.
Args:
request: HTTP request object
error: Exception that occurred
user_message: Message to show to the user (should be user-friendly)
log_message: Message to log (defaults to str(error) with user_message prefix)
level: Log level - one of "error", "warning", "info"
Example:
try:
ParkService.create_park(...)
except ServiceError as e:
ErrorHandler.handle_view_error(
request,
e,
user_message="Failed to create park",
log_message=f"Park creation failed for user {request.user.id}"
)
"""
log_msg = log_message or f"{user_message}: {str(error)}"
if level == "error":
logger.error(log_msg, exc_info=True)
elif level == "warning":
logger.warning(log_msg)
else:
logger.info(log_msg)
messages.error(request, user_message)
@staticmethod
def handle_api_error(
error: Exception,
user_message: str = "An error occurred",
log_message: Optional[str] = None,
status_code: int = status.HTTP_400_BAD_REQUEST,
) -> Response:
"""
Handle errors in API views.
Logs the error and returns a standardized DRF Response with error details.
Args:
error: Exception that occurred
user_message: Message to return to the client (should be user-friendly)
log_message: Message to log (defaults to str(error) with user_message prefix)
status_code: HTTP status code to return
Returns:
DRF Response with error details in standard format
Example:
try:
result = ParkService.create_park(...)
except ServiceError as e:
return ErrorHandler.handle_api_error(
e,
user_message="Failed to create park photo",
status_code=status.HTTP_400_BAD_REQUEST
)
"""
log_msg = log_message or f"{user_message}: {str(error)}"
logger.error(log_msg, exc_info=True)
# Build error response
error_data: Dict[str, Any] = {
"error": user_message,
"detail": str(error),
}
# Include additional details for ThrillWikiException subclasses
if isinstance(error, ThrillWikiException):
error_data["error_code"] = error.error_code
if error.details:
error_data["details"] = error.details
return Response(error_data, status=status_code)
@staticmethod
def handle_success(
request: HttpRequest,
message: str,
level: str = "success",
) -> None:
"""
Handle success messages in template views.
Args:
request: HTTP request object
message: Success message to display
level: Message level - one of "success", "info", "warning"
"""
if level == "success":
messages.success(request, message)
elif level == "info":
messages.info(request, message)
elif level == "warning":
messages.warning(request, message)
@staticmethod
def api_success_response(
data: Any = None,
message: str = "Success",
status_code: int = status.HTTP_200_OK,
) -> Response:
"""
Create a standardized success response for API views.
Args:
data: Response data (optional)
message: Success message
status_code: HTTP status code
Returns:
DRF Response with success data in standard format
"""
response_data: Dict[str, Any] = {
"status": "success",
"message": message,
}
if data is not None:
response_data["data"] = data
return Response(response_data, status=status_code)

View File

@@ -0,0 +1,432 @@
"""
File Upload Security Utilities for ThrillWiki.
This module provides comprehensive file validation and security checks for
file uploads, including:
- File type validation (MIME type and magic number verification)
- File size validation
- Filename sanitization
- Image-specific validation
Security Note:
Always validate uploaded files before saving them. Never trust user-provided
file extensions or Content-Type headers alone.
Usage:
from apps.core.utils.file_scanner import validate_image_upload, sanitize_filename
# In view
try:
validate_image_upload(uploaded_file)
# Safe to save file
except FileValidationError as e:
return JsonResponse({'error': str(e)}, status=400)
# Sanitize filename
safe_name = sanitize_filename(uploaded_file.name)
"""
import os
import re
import uuid
from io import BytesIO
from typing import Optional, Set, Tuple
from django.core.exceptions import ValidationError
from django.core.files.uploadedfile import UploadedFile
class FileValidationError(ValidationError):
"""Custom exception for file validation errors."""
pass
# =============================================================================
# Image Magic Number Signatures
# =============================================================================
# Magic number signatures for common image formats
# Format: (magic_bytes, offset, description)
IMAGE_SIGNATURES = {
'jpeg': [
(b'\xFF\xD8\xFF\xE0', 0, 'JPEG (JFIF)'),
(b'\xFF\xD8\xFF\xE1', 0, 'JPEG (EXIF)'),
(b'\xFF\xD8\xFF\xE2', 0, 'JPEG (ICC)'),
(b'\xFF\xD8\xFF\xE3', 0, 'JPEG (Samsung)'),
(b'\xFF\xD8\xFF\xE8', 0, 'JPEG (SPIFF)'),
(b'\xFF\xD8\xFF\xDB', 0, 'JPEG (Raw)'),
],
'png': [
(b'\x89PNG\r\n\x1a\n', 0, 'PNG'),
],
'gif': [
(b'GIF87a', 0, 'GIF87a'),
(b'GIF89a', 0, 'GIF89a'),
],
'webp': [
(b'RIFF', 0, 'RIFF'), # WebP starts with RIFF header
],
'bmp': [
(b'BM', 0, 'BMP'),
],
}
# All allowed MIME types
ALLOWED_IMAGE_MIME_TYPES: Set[str] = frozenset({
'image/jpeg',
'image/png',
'image/gif',
'image/webp',
})
# Allowed file extensions
ALLOWED_IMAGE_EXTENSIONS: Set[str] = frozenset({
'.jpg', '.jpeg', '.png', '.gif', '.webp',
})
# Maximum file size (10MB)
MAX_FILE_SIZE = 10 * 1024 * 1024
# Minimum file size (prevent empty files)
MIN_FILE_SIZE = 100 # 100 bytes
# =============================================================================
# File Validation Functions
# =============================================================================
def validate_image_upload(
file: UploadedFile,
max_size: int = MAX_FILE_SIZE,
allowed_types: Optional[Set[str]] = None,
allowed_extensions: Optional[Set[str]] = None,
) -> bool:
"""
Validate an uploaded image file for security.
Performs multiple validation checks:
1. File size validation
2. File extension validation
3. MIME type validation (from Content-Type header)
4. Magic number validation (actual file content check)
5. Image integrity validation (using PIL)
Args:
file: The uploaded file object
max_size: Maximum allowed file size in bytes
allowed_types: Set of allowed MIME types
allowed_extensions: Set of allowed file extensions
Returns:
True if all validations pass
Raises:
FileValidationError: If any validation fails
"""
if allowed_types is None:
allowed_types = ALLOWED_IMAGE_MIME_TYPES
if allowed_extensions is None:
allowed_extensions = ALLOWED_IMAGE_EXTENSIONS
# 1. Check if file exists
if not file:
raise FileValidationError("No file provided")
# 2. Check file size
if file.size > max_size:
raise FileValidationError(
f"File too large. Maximum size is {max_size // (1024 * 1024)}MB"
)
if file.size < MIN_FILE_SIZE:
raise FileValidationError("File too small or empty")
# 3. Check file extension
filename = file.name or ''
ext = os.path.splitext(filename)[1].lower()
if ext not in allowed_extensions:
raise FileValidationError(
f"Invalid file extension '{ext}'. Allowed: {', '.join(allowed_extensions)}"
)
# 4. Check Content-Type header
content_type = getattr(file, 'content_type', '')
if content_type and content_type not in allowed_types:
raise FileValidationError(
f"Invalid file type '{content_type}'. Allowed: {', '.join(allowed_types)}"
)
# 5. Validate magic numbers (actual file content)
if not _validate_magic_number(file):
raise FileValidationError(
"File content doesn't match file extension. File may be corrupted or malicious."
)
# 6. Validate image integrity using PIL
if not _validate_image_integrity(file):
raise FileValidationError(
"Invalid or corrupted image file"
)
return True
def _validate_magic_number(file: UploadedFile) -> bool:
"""
Validate file content against known magic number signatures.
This is more reliable than checking the file extension or Content-Type
header, which can be easily spoofed.
Args:
file: The uploaded file object
Returns:
True if magic number matches an allowed image type
"""
# Read the file header
file.seek(0)
header = file.read(16)
file.seek(0)
# Check against known signatures
for format_name, signatures in IMAGE_SIGNATURES.items():
for magic, offset, description in signatures:
if len(header) >= offset + len(magic):
if header[offset:offset + len(magic)] == magic:
# Special handling for WebP (must also have WEBP marker)
if format_name == 'webp':
if len(header) >= 12 and header[8:12] == b'WEBP':
return True
else:
return True
return False
def _validate_image_integrity(file: UploadedFile) -> bool:
"""
Validate image integrity using PIL.
This catches corrupted images and various image-related attacks.
Args:
file: The uploaded file object
Returns:
True if image can be opened and verified
"""
try:
from PIL import Image
file.seek(0)
# Read into BytesIO to avoid issues with file-like objects
img_data = BytesIO(file.read())
file.seek(0)
with Image.open(img_data) as img:
# Verify the image is not truncated or corrupted
img.verify()
# Re-open for size check (verify() can only be called once)
img_data.seek(0)
with Image.open(img_data) as img2:
# Check for reasonable image dimensions
# Prevent decompression bombs
max_dimension = 10000
if img2.width > max_dimension or img2.height > max_dimension:
raise FileValidationError(
f"Image dimensions too large. Maximum is {max_dimension}x{max_dimension}"
)
# Check for very small dimensions (might be suspicious)
if img2.width < 1 or img2.height < 1:
raise FileValidationError("Invalid image dimensions")
return True
except FileValidationError:
raise
except Exception:
return False
# =============================================================================
# Filename Sanitization
# =============================================================================
def sanitize_filename(filename: str, max_length: int = 100) -> str:
"""
Sanitize a filename to prevent directory traversal and other attacks.
This function:
- Removes path separators and directory traversal attempts
- Removes special characters
- Truncates to maximum length
- Ensures the filename is not empty
Args:
filename: The original filename
max_length: Maximum length for the filename
Returns:
Sanitized filename
"""
if not filename:
return f"file_{uuid.uuid4().hex[:8]}"
# Get just the filename (remove any path components)
filename = os.path.basename(filename)
# Split into name and extension
name, ext = os.path.splitext(filename)
# Remove or replace dangerous characters from name
# Allow alphanumeric, hyphens, underscores, dots
name = re.sub(r'[^\w\-.]', '_', name)
# Remove leading dots and underscores (hidden file prevention)
name = name.lstrip('._')
# Collapse multiple underscores
name = re.sub(r'_+', '_', name)
# Ensure name is not empty
if not name:
name = f"file_{uuid.uuid4().hex[:8]}"
# Sanitize extension
ext = ext.lower()
ext = re.sub(r'[^\w.]', '', ext)
# Combine and truncate
result = f"{name[:max_length - len(ext)]}{ext}"
return result
def generate_unique_filename(original_filename: str, prefix: str = '') -> str:
"""
Generate a unique filename using UUID while preserving extension.
Args:
original_filename: The original filename
prefix: Optional prefix for the filename
Returns:
Unique filename with UUID
"""
ext = os.path.splitext(original_filename)[1].lower()
# Sanitize extension
ext = re.sub(r'[^\w.]', '', ext)
# Generate unique filename
unique_id = uuid.uuid4().hex[:12]
if prefix:
return f"{sanitize_filename(prefix)}_{unique_id}{ext}"
return f"{unique_id}{ext}"
# =============================================================================
# Rate Limiting for Uploads
# =============================================================================
# Rate limiting configuration
UPLOAD_RATE_LIMITS = {
'per_minute': 10,
'per_hour': 100,
'per_day': 500,
}
def check_upload_rate_limit(user_id: int, cache_backend=None) -> Tuple[bool, str]:
"""
Check if user has exceeded upload rate limits.
Args:
user_id: The user's ID
cache_backend: Optional Django cache backend (uses default if not provided)
Returns:
Tuple of (is_allowed, reason_if_blocked)
"""
if cache_backend is None:
from django.core.cache import cache
cache_backend = cache
# Check per-minute limit
minute_key = f"upload_rate:{user_id}:minute"
minute_count = cache_backend.get(minute_key, 0)
if minute_count >= UPLOAD_RATE_LIMITS['per_minute']:
return False, "Upload rate limit exceeded. Please wait a minute."
# Check per-hour limit
hour_key = f"upload_rate:{user_id}:hour"
hour_count = cache_backend.get(hour_key, 0)
if hour_count >= UPLOAD_RATE_LIMITS['per_hour']:
return False, "Hourly upload limit exceeded. Please try again later."
# Check per-day limit
day_key = f"upload_rate:{user_id}:day"
day_count = cache_backend.get(day_key, 0)
if day_count >= UPLOAD_RATE_LIMITS['per_day']:
return False, "Daily upload limit exceeded. Please try again tomorrow."
return True, ""
def increment_upload_count(user_id: int, cache_backend=None) -> None:
"""
Increment upload count for rate limiting.
Args:
user_id: The user's ID
cache_backend: Optional Django cache backend
"""
if cache_backend is None:
from django.core.cache import cache
cache_backend = cache
# Increment per-minute counter (expires in 60 seconds)
minute_key = f"upload_rate:{user_id}:minute"
try:
cache_backend.incr(minute_key)
except ValueError:
cache_backend.set(minute_key, 1, 60)
# Increment per-hour counter (expires in 3600 seconds)
hour_key = f"upload_rate:{user_id}:hour"
try:
cache_backend.incr(hour_key)
except ValueError:
cache_backend.set(hour_key, 1, 3600)
# Increment per-day counter (expires in 86400 seconds)
day_key = f"upload_rate:{user_id}:day"
try:
cache_backend.incr(day_key)
except ValueError:
cache_backend.set(day_key, 1, 86400)
# =============================================================================
# Antivirus Integration Point
# =============================================================================
def scan_file_for_malware(file: UploadedFile) -> Tuple[bool, str]:
"""
Placeholder for antivirus/malware scanning integration.
This function should be implemented to integrate with a virus scanner
like ClamAV. Currently it returns True (safe) for all files.
Args:
file: The uploaded file object
Returns:
Tuple of (is_safe, reason_if_unsafe)
"""
# TODO(THRILLWIKI-110): Implement ClamAV integration for malware scanning
# This requires ClamAV daemon to be running and python-clamav to be installed
return True, ""

View File

@@ -0,0 +1,382 @@
"""
HTML Sanitization Utilities for ThrillWiki.
This module provides functions for sanitizing user-generated HTML content
to prevent XSS (Cross-Site Scripting) attacks while allowing safe HTML
formatting.
Security Note:
Always sanitize user-generated content before rendering with |safe filter
or mark_safe(). Never trust user input.
Usage:
from apps.core.utils.html_sanitizer import sanitize_html, sanitize_for_json
# In views
context['description'] = sanitize_html(user_input)
# For JSON/JavaScript contexts
json_safe = sanitize_for_json(data)
"""
import json
import re
from html import escape as html_escape
from typing import Any
try:
import bleach
BLEACH_AVAILABLE = True
except ImportError:
BLEACH_AVAILABLE = False
# =============================================================================
# Allowed HTML Configuration
# =============================================================================
# Default allowed HTML tags for user-generated content
ALLOWED_TAGS = frozenset([
# Text formatting
'p', 'br', 'hr',
'strong', 'b', 'em', 'i', 'u', 's', 'strike',
'sub', 'sup', 'small', 'mark',
# Headers
'h1', 'h2', 'h3', 'h4', 'h5', 'h6',
# Lists
'ul', 'ol', 'li',
# Links (with restrictions on attributes)
'a',
# Block elements
'blockquote', 'pre', 'code',
'div', 'span',
# Tables
'table', 'thead', 'tbody', 'tfoot', 'tr', 'th', 'td',
])
# Allowed attributes for each tag
ALLOWED_ATTRIBUTES = {
'a': ['href', 'title', 'rel', 'target'],
'img': ['src', 'alt', 'title', 'width', 'height'],
'div': ['class'],
'span': ['class'],
'p': ['class'],
'table': ['class'],
'th': ['class', 'colspan', 'rowspan'],
'td': ['class', 'colspan', 'rowspan'],
'*': ['class'], # Allow class on all elements
}
# Allowed URL protocols
ALLOWED_PROTOCOLS = frozenset([
'http', 'https', 'mailto', 'tel',
])
# Minimal tags for comments and short text
MINIMAL_TAGS = frozenset([
'p', 'br', 'strong', 'b', 'em', 'i', 'a',
])
# Tags allowed in icon SVGs (for icon template rendering)
SVG_TAGS = frozenset([
'svg', 'path', 'g', 'circle', 'rect', 'line', 'polyline', 'polygon',
'ellipse', 'text', 'tspan', 'defs', 'use', 'symbol', 'clipPath',
'mask', 'linearGradient', 'radialGradient', 'stop', 'title',
])
SVG_ATTRIBUTES = {
'svg': ['viewBox', 'width', 'height', 'fill', 'stroke', 'class',
'xmlns', 'aria-hidden', 'role'],
'path': ['d', 'fill', 'stroke', 'stroke-width', 'stroke-linecap',
'stroke-linejoin', 'class', 'fill-rule', 'clip-rule'],
'g': ['fill', 'stroke', 'transform', 'class'],
'circle': ['cx', 'cy', 'r', 'fill', 'stroke', 'class'],
'rect': ['x', 'y', 'width', 'height', 'rx', 'ry', 'fill', 'stroke', 'class'],
'line': ['x1', 'y1', 'x2', 'y2', 'stroke', 'stroke-width', 'class'],
'polyline': ['points', 'fill', 'stroke', 'class'],
'polygon': ['points', 'fill', 'stroke', 'class'],
'*': ['class', 'fill', 'stroke'],
}
# =============================================================================
# Sanitization Functions
# =============================================================================
def sanitize_html(
html: str | None,
allowed_tags: frozenset | None = None,
allowed_attributes: dict | None = None,
allowed_protocols: frozenset | None = None,
strip: bool = True,
) -> str:
"""
Sanitize HTML content to prevent XSS attacks.
Args:
html: The HTML string to sanitize
allowed_tags: Set of allowed HTML tag names
allowed_attributes: Dict mapping tag names to allowed attributes
allowed_protocols: Set of allowed URL protocols
strip: If True, remove disallowed tags; if False, escape them
Returns:
Sanitized HTML string safe for rendering
Example:
>>> sanitize_html('<script>alert("xss")</script><p>Hello</p>')
'<p>Hello</p>'
"""
if not html:
return ''
if not isinstance(html, str):
html = str(html)
if not BLEACH_AVAILABLE:
# Fallback: escape all HTML if bleach is not available
return html_escape(html)
tags = allowed_tags if allowed_tags is not None else ALLOWED_TAGS
attrs = allowed_attributes if allowed_attributes is not None else ALLOWED_ATTRIBUTES
protocols = allowed_protocols if allowed_protocols is not None else ALLOWED_PROTOCOLS
return bleach.clean(
html,
tags=tags,
attributes=attrs,
protocols=protocols,
strip=strip,
)
def sanitize_minimal(html: str | None) -> str:
"""
Sanitize HTML with minimal allowed tags.
Use this for user comments, short descriptions, etc.
Args:
html: The HTML string to sanitize
Returns:
Sanitized HTML with only basic formatting tags allowed
"""
return sanitize_html(
html,
allowed_tags=MINIMAL_TAGS,
allowed_attributes={'a': ['href', 'title']},
)
def sanitize_svg(svg: str | None) -> str:
"""
Sanitize SVG content for safe inline rendering.
This is specifically for icon SVGs that need to be rendered inline.
Removes potentially dangerous elements while preserving SVG structure.
Args:
svg: The SVG string to sanitize
Returns:
Sanitized SVG string safe for inline rendering
"""
if not svg:
return ''
if not isinstance(svg, str):
svg = str(svg)
if not BLEACH_AVAILABLE:
# Fallback: escape all if bleach is not available
return html_escape(svg)
return bleach.clean(
svg,
tags=SVG_TAGS,
attributes=SVG_ATTRIBUTES,
strip=True,
)
def strip_html(html: str | None) -> str:
"""
Remove all HTML tags from a string.
Use this for contexts where no HTML is allowed at all.
Args:
html: The HTML string to strip
Returns:
Plain text with all HTML tags removed
"""
if not html:
return ''
if not isinstance(html, str):
html = str(html)
if BLEACH_AVAILABLE:
return bleach.clean(html, tags=[], strip=True)
else:
# Fallback: use regex to strip tags
return re.sub(r'<[^>]+>', '', html)
# =============================================================================
# JSON/JavaScript Context Sanitization
# =============================================================================
def sanitize_for_json(data: Any) -> str:
"""
Safely serialize data for embedding in JavaScript/JSON contexts.
This prevents XSS when embedding data in <script> tags or JavaScript.
Args:
data: The data to serialize (dict, list, or primitive)
Returns:
JSON string safe for embedding in JavaScript
Example:
>>> sanitize_for_json({'name': '</script><script>alert("xss")'})
'{"name": "\\u003c/script\\u003e\\u003cscript\\u003ealert(\\"xss\\")"}'
"""
# JSON encode with safe characters escaped
return json.dumps(data, ensure_ascii=False).replace(
'<', '\\u003c'
).replace(
'>', '\\u003e'
).replace(
'&', '\\u0026'
).replace(
"'", '\\u0027'
)
def escape_js_string(s: str | None) -> str:
"""
Escape a string for safe use in JavaScript string literals.
Args:
s: The string to escape
Returns:
Escaped string safe for JavaScript contexts
"""
if not s:
return ''
if not isinstance(s, str):
s = str(s)
# Escape backslashes first, then other special characters
return s.replace('\\', '\\\\').replace(
"'", "\\'"
).replace(
'"', '\\"'
).replace(
'\n', '\\n'
).replace(
'\r', '\\r'
).replace(
'<', '\\u003c'
).replace(
'>', '\\u003e'
).replace(
'&', '\\u0026'
)
# =============================================================================
# URL Sanitization
# =============================================================================
def sanitize_url(url: str | None, allowed_protocols: frozenset | None = None) -> str:
"""
Sanitize a URL to prevent javascript: and other dangerous protocols.
Args:
url: The URL to sanitize
allowed_protocols: Set of allowed URL protocols
Returns:
Sanitized URL or empty string if unsafe
"""
if not url:
return ''
if not isinstance(url, str):
url = str(url)
url = url.strip()
if not url:
return ''
protocols = allowed_protocols if allowed_protocols is not None else ALLOWED_PROTOCOLS
# Check for allowed protocols
url_lower = url.lower()
# Check for javascript:, data:, vbscript:, etc.
if ':' in url_lower:
protocol = url_lower.split(':')[0]
if protocol not in protocols:
# Allow relative URLs and anchor links
if not (url.startswith('/') or url.startswith('#') or url.startswith('?')):
return ''
return url
# =============================================================================
# Attribute Sanitization
# =============================================================================
def sanitize_attribute_value(value: str | None) -> str:
"""
Sanitize a value for use in HTML attributes.
Args:
value: The attribute value to sanitize
Returns:
Sanitized value safe for HTML attribute contexts
"""
if not value:
return ''
if not isinstance(value, str):
value = str(value)
# HTML escape for attribute context
return html_escape(value, quote=True)
def sanitize_class_name(name: str | None) -> str:
"""
Sanitize a CSS class name.
Args:
name: The class name to sanitize
Returns:
Sanitized class name containing only safe characters
"""
if not name:
return ''
if not isinstance(name, str):
name = str(name)
# Only allow alphanumeric, hyphens, and underscores
return re.sub(r'[^a-zA-Z0-9_-]', '', name)

View File

@@ -0,0 +1,79 @@
"""
Base view classes with common patterns.
This module provides base view classes that implement common patterns
such as automatic query optimization with select_related and prefetch_related.
"""
from typing import List
from django.db.models import QuerySet
from django.views.generic import DetailView, ListView
class OptimizedListView(ListView):
"""
ListView with automatic query optimization.
Automatically applies select_related and prefetch_related based on
class attributes, reducing the need for boilerplate code in get_queryset.
Attributes:
select_related_fields: List of fields to pass to select_related()
prefetch_related_fields: List of fields to pass to prefetch_related()
Example:
class RideListView(OptimizedListView):
model = Ride
select_related_fields = ['park', 'manufacturer']
prefetch_related_fields = ['photos']
"""
select_related_fields: List[str] = []
prefetch_related_fields: List[str] = []
def get_queryset(self) -> QuerySet:
"""Get queryset with optimizations applied."""
queryset = super().get_queryset()
if self.select_related_fields:
queryset = queryset.select_related(*self.select_related_fields)
if self.prefetch_related_fields:
queryset = queryset.prefetch_related(*self.prefetch_related_fields)
return queryset
class OptimizedDetailView(DetailView):
"""
DetailView with automatic query optimization.
Automatically applies select_related and prefetch_related based on
class attributes, reducing the need for boilerplate code in get_queryset.
Attributes:
select_related_fields: List of fields to pass to select_related()
prefetch_related_fields: List of fields to pass to prefetch_related()
Example:
class RideDetailView(OptimizedDetailView):
model = Ride
select_related_fields = ['park', 'park__location', 'manufacturer']
prefetch_related_fields = ['photos', 'coaster_stats']
"""
select_related_fields: List[str] = []
prefetch_related_fields: List[str] = []
def get_queryset(self) -> QuerySet:
"""Get queryset with optimizations applied."""
queryset = super().get_queryset()
if self.select_related_fields:
queryset = queryset.select_related(*self.select_related_fields)
if self.prefetch_related_fields:
queryset = queryset.prefetch_related(*self.prefetch_related_fields)
return queryset

View File

@@ -6,8 +6,6 @@ from rest_framework.views import APIView
from rest_framework.response import Response
from rest_framework import status
from rest_framework.permissions import AllowAny
from django.views.decorators.csrf import csrf_exempt
from django.utils.decorators import method_decorator
from typing import Optional, List
from ..services.entity_fuzzy_matching import (
@@ -244,10 +242,12 @@ class EntityNotFoundView(APIView):
)
@method_decorator(csrf_exempt, name="dispatch")
class QuickEntitySuggestionView(APIView):
"""
Lightweight endpoint for quick entity suggestions (e.g., autocomplete).
Security Note: This endpoint only accepts GET requests, which are inherently
safe from CSRF attacks. No CSRF exemption is needed.
"""
permission_classes = [AllowAny]

View File

@@ -1,5 +1,4 @@
from django.views.generic.edit import FormView
from django.shortcuts import get_object_or_404
class InlineEditView(FormView):

View File

@@ -636,7 +636,9 @@ class MapCacheView(MapAPIView):
def delete(self, request: HttpRequest) -> JsonResponse:
"""Clear all map cache (admin only)."""
# TODO: Add admin permission check
# TODO(THRILLWIKI-103): Add admin permission check for cache clear
if not (request.user.is_authenticated and request.user.is_staff):
return self._error_response("Admin access required", 403)
try:
unified_map_service.invalidate_cache()
@@ -655,7 +657,9 @@ class MapCacheView(MapAPIView):
def post(self, request: HttpRequest) -> JsonResponse:
"""Invalidate specific cache entries."""
# TODO: Add admin permission check
# TODO(THRILLWIKI-103): Add admin permission check for cache invalidation
if not (request.user.is_authenticated and request.user.is_staff):
return self._error_response("Admin access required", 403)
try:
data = json.loads(request.body)

View File

@@ -1,5 +1,4 @@
from django.views.generic.edit import FormView
from django.http import HttpResponse
class HTMXModalFormView(FormView):

View File

@@ -1,6 +1,7 @@
"""
Core views for the application.
"""
import json
import logging
from typing import Any, Dict, Optional, Type
@@ -93,18 +94,14 @@ class GlobalSearchView(TemplateView):
# Real implementation should query multiple models.
if q:
# Return a small payload of mocked results to keep this scaffold safe
results = [
{"title": f"Result for {q}", "url": "#", "subtitle": "Park"}
]
results = [{"title": f"Result for {q}", "url": "#", "subtitle": "Park"}]
suggestions = [{"text": q, "url": "#"}]
context = {"results": results, "suggestions": suggestions}
# If HTMX request, render dropdown partial
if request.headers.get("HX-Request") == "true":
return render(
request, "core/search/partials/search_dropdown.html", context
)
return render(request, "core/search/partials/search_dropdown.html", context)
return render(request, self.template_name, context)
@@ -117,25 +114,75 @@ class GlobalSearchView(TemplateView):
# Default transition metadata for styling
TRANSITION_METADATA = {
# Approval transitions
"approve": {"style": "green", "icon": "check", "requires_confirm": True, "confirm_message": "Are you sure you want to approve this?"},
"transition_to_approved": {"style": "green", "icon": "check", "requires_confirm": True, "confirm_message": "Are you sure you want to approve this?"},
"approve": {
"style": "green",
"icon": "check",
"requires_confirm": True,
"confirm_message": "Are you sure you want to approve this?",
},
"transition_to_approved": {
"style": "green",
"icon": "check",
"requires_confirm": True,
"confirm_message": "Are you sure you want to approve this?",
},
# Rejection transitions
"reject": {"style": "red", "icon": "times", "requires_confirm": True, "confirm_message": "Are you sure you want to reject this?"},
"transition_to_rejected": {"style": "red", "icon": "times", "requires_confirm": True, "confirm_message": "Are you sure you want to reject this?"},
"reject": {
"style": "red",
"icon": "times",
"requires_confirm": True,
"confirm_message": "Are you sure you want to reject this?",
},
"transition_to_rejected": {
"style": "red",
"icon": "times",
"requires_confirm": True,
"confirm_message": "Are you sure you want to reject this?",
},
# Escalation transitions
"escalate": {"style": "yellow", "icon": "arrow-up", "requires_confirm": True, "confirm_message": "Are you sure you want to escalate this?"},
"transition_to_escalated": {"style": "yellow", "icon": "arrow-up", "requires_confirm": True, "confirm_message": "Are you sure you want to escalate this?"},
"escalate": {
"style": "yellow",
"icon": "arrow-up",
"requires_confirm": True,
"confirm_message": "Are you sure you want to escalate this?",
},
"transition_to_escalated": {
"style": "yellow",
"icon": "arrow-up",
"requires_confirm": True,
"confirm_message": "Are you sure you want to escalate this?",
},
# Assignment transitions
"assign": {"style": "blue", "icon": "user-plus", "requires_confirm": False},
"unassign": {"style": "gray", "icon": "user-minus", "requires_confirm": False},
# Status transitions
"start": {"style": "blue", "icon": "play", "requires_confirm": False},
"complete": {"style": "green", "icon": "check-circle", "requires_confirm": True, "confirm_message": "Are you sure you want to complete this?"},
"cancel": {"style": "red", "icon": "ban", "requires_confirm": True, "confirm_message": "Are you sure you want to cancel this?"},
"complete": {
"style": "green",
"icon": "check-circle",
"requires_confirm": True,
"confirm_message": "Are you sure you want to complete this?",
},
"cancel": {
"style": "red",
"icon": "ban",
"requires_confirm": True,
"confirm_message": "Are you sure you want to cancel this?",
},
"reopen": {"style": "blue", "icon": "redo", "requires_confirm": False},
# Resolution transitions
"resolve": {"style": "green", "icon": "check-double", "requires_confirm": True, "confirm_message": "Are you sure you want to resolve this?"},
"dismiss": {"style": "gray", "icon": "times-circle", "requires_confirm": True, "confirm_message": "Are you sure you want to dismiss this?"},
"resolve": {
"style": "green",
"icon": "check-double",
"requires_confirm": True,
"confirm_message": "Are you sure you want to resolve this?",
},
"dismiss": {
"style": "gray",
"icon": "times-circle",
"requires_confirm": True,
"confirm_message": "Are you sure you want to dismiss this?",
},
# Default
"default": {"style": "gray", "icon": "arrow-right", "requires_confirm": False},
}
@@ -155,7 +202,9 @@ def get_transition_metadata(transition_name: str) -> Dict[str, Any]:
return TRANSITION_METADATA["default"].copy()
def add_toast_trigger(response: HttpResponse, message: str, toast_type: str = "success") -> HttpResponse:
def add_toast_trigger(
response: HttpResponse, message: str, toast_type: str = "success"
) -> HttpResponse:
"""
Add HX-Trigger header to trigger Alpine.js toast.
@@ -167,17 +216,12 @@ def add_toast_trigger(response: HttpResponse, message: str, toast_type: str = "s
Returns:
Modified response with HX-Trigger header
"""
trigger_data = {
"showToast": {
"message": message,
"type": toast_type
}
}
trigger_data = {"showToast": {"message": message, "type": toast_type}}
response["HX-Trigger"] = json.dumps(trigger_data)
return response
@method_decorator(csrf_protect, name='dispatch')
@method_decorator(csrf_protect, name="dispatch")
class FSMTransitionView(View):
"""
Generic view for handling FSM state transitions via HTMX.
@@ -212,12 +256,16 @@ class FSMTransitionView(View):
The model class or None if not found
"""
try:
content_type = ContentType.objects.get(app_label=app_label, model=model_name)
content_type = ContentType.objects.get(
app_label=app_label, model=model_name
)
return content_type.model_class()
except ContentType.DoesNotExist:
return None
def get_object(self, model_class: Type[Model], pk: Any, slug: Optional[str] = None) -> Model:
def get_object(
self, model_class: Type[Model], pk: Any, slug: Optional[str] = None
) -> Model:
"""
Get the model instance.
@@ -249,7 +297,9 @@ class FSMTransitionView(View):
"""
return getattr(obj, transition_name, None)
def validate_transition(self, obj: Model, transition_name: str, user) -> tuple[bool, Optional[str]]:
def validate_transition(
self, obj: Model, transition_name: str, user
) -> tuple[bool, Optional[str]]:
"""
Validate that the transition can proceed.
@@ -264,18 +314,26 @@ class FSMTransitionView(View):
method = self.get_transition_method(obj, transition_name)
if method is None:
return False, f"Transition '{transition_name}' not found on {obj.__class__.__name__}"
return (
False,
f"Transition '{transition_name}' not found on {obj.__class__.__name__}",
)
if not callable(method):
return False, f"'{transition_name}' is not a callable method"
# Check if the transition can proceed
if not can_proceed(method, user):
return False, f"Transition '{transition_name}' is not allowed from current state"
return (
False,
f"Transition '{transition_name}' is not allowed from current state",
)
return True, None
def execute_transition(self, obj: Model, transition_name: str, user, **kwargs) -> None:
def execute_transition(
self, obj: Model, transition_name: str, user, **kwargs
) -> None:
"""
Execute the transition on the object.
@@ -297,7 +355,9 @@ class FSMTransitionView(View):
def get_success_message(self, obj: Model, transition_name: str) -> str:
"""Generate a success message for the transition."""
# Clean up transition name for display
display_name = transition_name.replace("transition_to_", "").replace("_", " ").title()
display_name = (
transition_name.replace("transition_to_", "").replace("_", " ").title()
)
model_name = obj._meta.verbose_name.title()
return f"{model_name} has been {display_name.lower()}d successfully."
@@ -321,9 +381,9 @@ class FSMTransitionView(View):
model_name = obj._meta.model_name
# Special handling for parks and rides - return status section
if app_label == 'parks' and model_name == 'park':
if app_label == "parks" and model_name == "park":
return "parks/partials/park_status_actions.html"
elif app_label == 'rides' and model_name == 'ride':
elif app_label == "rides" and model_name == "ride":
return "rides/partials/ride_status_actions.html"
# Check for model-specific templates in order of preference
@@ -337,6 +397,7 @@ class FSMTransitionView(View):
# Use template loader to check if template exists
from django.template.loader import select_template
from django.template import TemplateDoesNotExist
try:
template = select_template(possible_templates)
return template.template.name
@@ -344,10 +405,7 @@ class FSMTransitionView(View):
return "htmx/updated_row.html"
def format_success_response(
self,
request: HttpRequest,
obj: Model,
transition_name: str
self, request: HttpRequest, obj: Model, transition_name: str
) -> HttpResponse:
"""
Format a successful transition response.
@@ -381,17 +439,20 @@ class FSMTransitionView(View):
return add_toast_trigger(response, message, "success")
# Regular request - return JSON
return JsonResponse({
"success": True,
"message": message,
"new_state": getattr(obj, obj.state_field_name, None) if hasattr(obj, "state_field_name") else None,
})
return JsonResponse(
{
"success": True,
"message": message,
"new_state": (
getattr(obj, obj.state_field_name, None)
if hasattr(obj, "state_field_name")
else None
),
}
)
def format_error_response(
self,
request: HttpRequest,
error: Exception,
status_code: int = 400
self, request: HttpRequest, error: Exception, status_code: int = 400
) -> HttpResponse:
"""
Format an error response.
@@ -408,10 +469,13 @@ class FSMTransitionView(View):
return add_toast_trigger(response, message, "error")
# Regular request - return JSON
return JsonResponse({
"success": False,
"error": error_data,
}, status=status_code)
return JsonResponse(
{
"success": False,
"error": error_data,
},
status=status_code,
)
def post(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
"""Handle POST request to execute a transition."""
@@ -425,24 +489,22 @@ class FSMTransitionView(View):
if not all([app_label, model_name, transition_name]):
return self.format_error_response(
request,
ValueError("Missing required parameters: app_label, model_name, and transition_name"),
400
ValueError(
"Missing required parameters: app_label, model_name, and transition_name"
),
400,
)
if not pk and not slug:
return self.format_error_response(
request,
ValueError("Missing required parameter: pk or slug"),
400
request, ValueError("Missing required parameter: pk or slug"), 400
)
# Get the model class
model_class = self.get_model_class(app_label, model_name)
if model_class is None:
return self.format_error_response(
request,
ValueError(f"Model '{app_label}.{model_name}' not found"),
404
request, ValueError(f"Model '{app_label}.{model_name}' not found"), 404
)
# Get the object
@@ -450,13 +512,13 @@ class FSMTransitionView(View):
obj = self.get_object(model_class, pk, slug)
except ObjectDoesNotExist:
return self.format_error_response(
request,
ValueError(f"Object not found: {model_name} with pk={pk}"),
404
request, ValueError(f"Object not found: {model_name} with pk={pk}"), 404
)
# Validate the transition
can_execute, error_msg = self.validate_transition(obj, transition_name, request.user)
can_execute, error_msg = self.validate_transition(
obj, transition_name, request.user
)
if not can_execute:
return self.format_error_response(
request,
@@ -466,7 +528,7 @@ class FSMTransitionView(View):
current_state=getattr(obj, "status", None),
requested_transition=transition_name,
),
400
400,
)
# Execute the transition
@@ -509,7 +571,5 @@ class FSMTransitionView(View):
f"Unexpected error during transition '{transition_name}' on {model_class.__name__}(pk={obj.pk})"
)
return self.format_error_response(
request,
ValueError(f"An unexpected error occurred: {str(e)}"),
500
request, ValueError(f"An unexpected error occurred: {str(e)}"), 500
)

View File

@@ -7,6 +7,7 @@ across all models using django-fsm-log.
from django.core.management.base import BaseCommand
from django.db.models import Count, Avg, F
from django.db.models.functions import TruncDate, ExtractHour
from django.utils import timezone
from datetime import timedelta
from django_fsm_log.models import StateLog
@@ -148,9 +149,10 @@ class Command(BaseCommand):
self.stdout.write(f" System-initiated: {system_count} ({(system_count/total_transitions)*100:.1f}%)")
# Daily transition volume
# Security: Using Django ORM functions instead of raw SQL .extra() to prevent SQL injection
self.stdout.write(self.style.SUCCESS('\n--- Daily Transition Volume ---'))
daily_stats = (
queryset.extra(select={'day': 'date(timestamp)'})
queryset.annotate(day=TruncDate('timestamp'))
.values('day')
.annotate(count=Count('id'))
.order_by('-day')[:7]
@@ -162,9 +164,10 @@ class Command(BaseCommand):
self.stdout.write(f" {date}: {count} transitions")
# Busiest hours
# Security: Using Django ORM functions instead of raw SQL .extra() to prevent SQL injection
self.stdout.write(self.style.SUCCESS('\n--- Busiest Hours (UTC) ---'))
hourly_stats = (
queryset.extra(select={'hour': 'extract(hour from timestamp)'})
queryset.annotate(hour=ExtractHour('timestamp'))
.values('hour')
.annotate(count=Count('id'))
.order_by('-count')[:5]

View File

@@ -4,7 +4,8 @@ Following Django styleguide pattern for separating data access from business log
"""
from typing import Optional, Dict, Any
from django.db.models import QuerySet, Count
from django.db.models import QuerySet, Count, F, ExpressionWrapper, FloatField
from django.db.models.functions import Extract
from django.utils import timezone
from datetime import timedelta
from django.contrib.auth.models import User
@@ -185,12 +186,14 @@ def moderation_statistics_summary(
rejected_submissions = handled_queryset.filter(status="REJECTED").count()
# Response time analysis (only for handled submissions)
# Security: Using Django ORM instead of raw SQL .extra() to prevent SQL injection
handled_with_times = (
handled_queryset.exclude(handled_at__isnull=True)
.extra(
select={
"response_hours": "EXTRACT(EPOCH FROM (handled_at - created_at)) / 3600"
}
.annotate(
response_hours=ExpressionWrapper(
Extract(F('handled_at') - F('created_at'), 'epoch') / 3600.0,
output_field=FloatField()
)
)
.values_list("response_hours", flat=True)
)

View File

@@ -306,35 +306,39 @@ class ModerationReportViewSet(viewsets.ModelViewSet):
return Response(stats_data)
@action(detail=True, methods=['get'], permission_classes=[CanViewModerationData])
@action(detail=True, methods=["get"], permission_classes=[CanViewModerationData])
def history(self, request, pk=None):
"""Get transition history for this report."""
from django_fsm_log.models import StateLog
from django.contrib.contenttypes.models import ContentType
report = self.get_object()
content_type = ContentType.objects.get_for_model(report)
logs = StateLog.objects.filter(
content_type=content_type,
object_id=report.id
).select_related('by').order_by('-timestamp')
history_data = [{
'id': log.id,
'timestamp': log.timestamp,
'state': log.state,
'from_state': log.source_state,
'to_state': log.state,
'transition': log.transition,
'user': log.by.username if log.by else None,
'description': log.description,
'reason': log.description,
} for log in logs]
logs = (
StateLog.objects.filter(content_type=content_type, object_id=report.id)
.select_related("by")
.order_by("-timestamp")
)
history_data = [
{
"id": log.id,
"timestamp": log.timestamp,
"state": log.state,
"from_state": log.source_state,
"to_state": log.state,
"transition": log.transition,
"user": log.by.username if log.by else None,
"description": log.description,
"reason": log.description,
}
for log in logs
]
return Response(history_data)
@action(detail=False, methods=['get'], permission_classes=[CanViewModerationData])
@action(detail=False, methods=["get"], permission_classes=[CanViewModerationData])
def all_history(self, request):
"""Get all transition history with filtering.
@@ -343,61 +347,77 @@ class ModerationReportViewSet(viewsets.ModelViewSet):
from django_fsm_log.models import StateLog
from django.contrib.contenttypes.models import ContentType
queryset = StateLog.objects.select_related('by', 'content_type').all()
queryset = StateLog.objects.select_related("by", "content_type").all()
# Filter by id (for detail view)
log_id = request.query_params.get('id')
log_id = request.query_params.get("id")
if log_id:
try:
log = queryset.get(id=log_id)
# Check if HTMX request for detail view
if request.headers.get('HX-Request'):
return render(request, 'moderation/partials/history_detail_content.html', {
'log': log,
})
if request.headers.get("HX-Request"):
return render(
request,
"moderation/partials/history_detail_content.html",
{
"log": log,
},
)
# Return JSON for API request
return Response({
'id': log.id,
'timestamp': log.timestamp,
'model': log.content_type.model,
'object_id': log.object_id,
'state': log.state,
'from_state': log.source_state,
'to_state': log.state,
'transition': log.transition,
'user': log.by.username if log.by else None,
'description': log.description,
'reason': log.description,
})
return Response(
{
"id": log.id,
"timestamp": log.timestamp,
"model": log.content_type.model,
"object_id": log.object_id,
"state": log.state,
"from_state": log.source_state,
"to_state": log.state,
"transition": log.transition,
"user": log.by.username if log.by else None,
"description": log.description,
"reason": log.description,
}
)
except StateLog.DoesNotExist:
if request.headers.get('HX-Request'):
return render(request, 'moderation/partials/history_detail_content.html', {
'log': None,
})
return Response({'error': 'Log not found'}, status=status.HTTP_404_NOT_FOUND)
if request.headers.get("HX-Request"):
return render(
request,
"moderation/partials/history_detail_content.html",
{
"log": None,
},
)
return Response(
{"error": "Log not found"}, status=status.HTTP_404_NOT_FOUND
)
# Filter by model type with app_label support for correct ContentType resolution
model_type = request.query_params.get('model_type')
app_label = request.query_params.get('app_label')
model_type = request.query_params.get("model_type")
app_label = request.query_params.get("app_label")
if model_type:
try:
if app_label:
# Use both app_label and model for precise matching
content_type = ContentType.objects.get_by_natural_key(app_label, model_type)
content_type = ContentType.objects.get_by_natural_key(
app_label, model_type
)
else:
# Map common model names to their app_labels for correct resolution
model_app_mapping = {
'park': 'parks',
'ride': 'rides',
'editsubmission': 'submissions',
'photosubmission': 'submissions',
'moderationreport': 'moderation',
'moderationqueue': 'moderation',
'bulkoperation': 'moderation',
"park": "parks",
"ride": "rides",
"editsubmission": "submissions",
"photosubmission": "submissions",
"moderationreport": "moderation",
"moderationqueue": "moderation",
"bulkoperation": "moderation",
}
mapped_app_label = model_app_mapping.get(model_type.lower())
if mapped_app_label:
content_type = ContentType.objects.get_by_natural_key(mapped_app_label, model_type.lower())
content_type = ContentType.objects.get_by_natural_key(
mapped_app_label, model_type.lower()
)
else:
# Fallback to model-only lookup
content_type = ContentType.objects.get(model=model_type)
@@ -406,88 +426,98 @@ class ModerationReportViewSet(viewsets.ModelViewSet):
pass
# Filter by object_id (for object-level history)
object_id = request.query_params.get('object_id')
object_id = request.query_params.get("object_id")
if object_id:
queryset = queryset.filter(object_id=object_id)
# Filter by user
user_id = request.query_params.get('user_id')
user_id = request.query_params.get("user_id")
if user_id:
queryset = queryset.filter(by_id=user_id)
# Filter by date range
start_date = request.query_params.get('start_date')
end_date = request.query_params.get('end_date')
start_date = request.query_params.get("start_date")
end_date = request.query_params.get("end_date")
if start_date:
queryset = queryset.filter(timestamp__gte=start_date)
if end_date:
queryset = queryset.filter(timestamp__lte=end_date)
# Filter by state
state = request.query_params.get('state')
state = request.query_params.get("state")
if state:
queryset = queryset.filter(state=state)
# Search filter (case-insensitive across relevant fields)
search_query = request.query_params.get('q')
search_query = request.query_params.get("q")
if search_query:
queryset = queryset.filter(
Q(transition__icontains=search_query) |
Q(description__icontains=search_query) |
Q(state__icontains=search_query) |
Q(source_state__icontains=search_query) |
Q(object_id__icontains=search_query) |
Q(by__username__icontains=search_query)
Q(transition__icontains=search_query)
| Q(description__icontains=search_query)
| Q(state__icontains=search_query)
| Q(source_state__icontains=search_query)
| Q(object_id__icontains=search_query)
| Q(by__username__icontains=search_query)
)
# Order queryset
queryset = queryset.order_by('-timestamp')
queryset = queryset.order_by("-timestamp")
# Check if HTMX request
if request.headers.get('HX-Request'):
if request.headers.get("HX-Request"):
# Use Django's Paginator for HTMX responses
paginator = Paginator(queryset, 20)
page_number = request.query_params.get('page', 1)
page_number = request.query_params.get("page", 1)
page_obj = paginator.get_page(page_number)
return render(request, 'moderation/partials/history_table.html', {
'history_logs': page_obj,
'page_obj': page_obj,
'request': request,
})
return render(
request,
"moderation/partials/history_table.html",
{
"history_logs": page_obj,
"page_obj": page_obj,
"request": request,
},
)
# Paginate for API response
page = self.paginate_queryset(queryset)
if page is not None:
history_data = [{
'id': log.id,
'timestamp': log.timestamp,
'model': log.content_type.model,
'object_id': log.object_id,
'state': log.state,
'from_state': log.source_state,
'to_state': log.state,
'transition': log.transition,
'user': log.by.username if log.by else None,
'description': log.description,
'reason': log.description,
} for log in page]
history_data = [
{
"id": log.id,
"timestamp": log.timestamp,
"model": log.content_type.model,
"object_id": log.object_id,
"state": log.state,
"from_state": log.source_state,
"to_state": log.state,
"transition": log.transition,
"user": log.by.username if log.by else None,
"description": log.description,
"reason": log.description,
}
for log in page
]
return self.get_paginated_response(history_data)
# Return all history data when pagination is not triggered
history_data = [{
'id': log.id,
'timestamp': log.timestamp,
'model': log.content_type.model,
'object_id': log.object_id,
'state': log.state,
'from_state': log.source_state,
'to_state': log.state,
'transition': log.transition,
'user': log.by.username if log.by else None,
'description': log.description,
'reason': log.description,
} for log in queryset]
history_data = [
{
"id": log.id,
"timestamp": log.timestamp,
"model": log.content_type.model,
"object_id": log.object_id,
"state": log.state,
"from_state": log.source_state,
"to_state": log.state,
"transition": log.transition,
"user": log.by.username if log.by else None,
"description": log.description,
"reason": log.description,
}
for log in queryset
]
return Response(history_data)
@@ -704,7 +734,7 @@ class ModerationQueueViewSet(viewsets.ModelViewSet):
serializer = self.get_serializer(queryset, many=True)
return Response(serializer.data)
@action(detail=True, methods=['get'], permission_classes=[CanViewModerationData])
@action(detail=True, methods=["get"], permission_classes=[CanViewModerationData])
def history(self, request, pk=None):
"""Get transition history for this queue item."""
from django_fsm_log.models import StateLog
@@ -713,22 +743,26 @@ class ModerationQueueViewSet(viewsets.ModelViewSet):
queue_item = self.get_object()
content_type = ContentType.objects.get_for_model(queue_item)
logs = StateLog.objects.filter(
content_type=content_type,
object_id=queue_item.id
).select_related('by').order_by('-timestamp')
logs = (
StateLog.objects.filter(content_type=content_type, object_id=queue_item.id)
.select_related("by")
.order_by("-timestamp")
)
history_data = [{
'id': log.id,
'timestamp': log.timestamp,
'state': log.state,
'from_state': log.source_state,
'to_state': log.state,
'transition': log.transition,
'user': log.by.username if log.by else None,
'description': log.description,
'reason': log.description,
} for log in logs]
history_data = [
{
"id": log.id,
"timestamp": log.timestamp,
"state": log.state,
"from_state": log.source_state,
"to_state": log.state,
"transition": log.transition,
"user": log.by.username if log.by else None,
"description": log.description,
"reason": log.description,
}
for log in logs
]
return Response(history_data)
@@ -996,7 +1030,7 @@ class BulkOperationViewSet(viewsets.ModelViewSet):
serializer = self.get_serializer(queryset, many=True)
return Response(serializer.data)
@action(detail=True, methods=['get'])
@action(detail=True, methods=["get"])
def history(self, request, pk=None):
"""Get transition history for this bulk operation."""
from django_fsm_log.models import StateLog
@@ -1005,22 +1039,26 @@ class BulkOperationViewSet(viewsets.ModelViewSet):
operation = self.get_object()
content_type = ContentType.objects.get_for_model(operation)
logs = StateLog.objects.filter(
content_type=content_type,
object_id=operation.id
).select_related('by').order_by('-timestamp')
logs = (
StateLog.objects.filter(content_type=content_type, object_id=operation.id)
.select_related("by")
.order_by("-timestamp")
)
history_data = [{
'id': log.id,
'timestamp': log.timestamp,
'state': log.state,
'from_state': log.source_state,
'to_state': log.state,
'transition': log.transition,
'user': log.by.username if log.by else None,
'description': log.description,
'reason': log.description,
} for log in logs]
history_data = [
{
"id": log.id,
"timestamp": log.timestamp,
"state": log.state,
"from_state": log.source_state,
"to_state": log.state,
"transition": log.transition,
"user": log.by.username if log.by else None,
"description": log.description,
"reason": log.description,
}
for log in logs
]
return Response(history_data)

View File

@@ -203,11 +203,12 @@ class Command(BaseCommand):
)
def create_parks(self):
"""Create parks with proper operator relationships"""
"""Create parks with proper operator relationships."""
self.stdout.write("Creating parks...")
# TODO: Implement park creation - parks_data defined but not used yet
parks_data = [ # noqa: F841
# Park creation data - will be used to create parks in the database
# TODO(THRILLWIKI-111): Complete park creation implementation
parks_data = [
{
"name": "Magic Kingdom",
"slug": "magic-kingdom",

View File

@@ -194,10 +194,9 @@ class Command(BaseCommand):
missing_tables = []
for model in required_models:
try:
# Check if the table exists by trying to get the table name
table_name = model._meta.db_table
with connection.cursor() as cursor:
cursor.execute(f"SELECT 1 FROM {table_name} LIMIT 1")
# Security: Use Django ORM to check table existence instead of raw SQL
# This is safer as it avoids any potential SQL injection via model metadata
model.objects.exists()
except Exception:
missing_tables.append(model._meta.label)

View File

@@ -127,16 +127,17 @@ class ParkQuerySet(StatusQuerySet, ReviewableQuerySet, LocationQuerySet):
| Q(location__state__icontains=query)
)
.select_related("operator", "location")
.values(
"id",
"name",
"slug",
"location__city",
"location__state",
"operator__name",
.only(
"id", "name", "slug",
"location__city", "location__state",
"operator__name"
)[:limit]
)
def with_location(self):
"""Always prefetch location for park queries."""
return self.select_related("operator").prefetch_related("location")
class ParkManager(StatusManager, ReviewableManager, LocationManager):
"""Custom manager for Park model."""
@@ -162,6 +163,10 @@ class ParkManager(StatusManager, ReviewableManager, LocationManager):
def for_map_display(self, *, bounds=None):
return self.get_queryset().for_map_display(bounds=bounds)
def with_location(self):
"""Always prefetch location for park queries."""
return self.get_queryset().with_location()
class ParkAreaQuerySet(BaseQuerySet):
"""QuerySet for ParkArea model."""
@@ -300,3 +305,33 @@ class CompanyManager(BaseManager):
def major_operators(self, *, min_parks: int = 5):
return self.get_queryset().major_operators(min_parks=min_parks)
def manufacturers_with_ride_count(self):
"""Get manufacturers with ride count annotation for list views."""
return (
self.get_queryset()
.manufacturers()
.annotate(ride_count=Count("manufactured_rides", distinct=True))
.only('id', 'name', 'slug', 'roles', 'description')
.order_by("name")
)
def designers_with_ride_count(self):
"""Get designers with ride count annotation for list views."""
return (
self.get_queryset()
.filter(roles__contains=["DESIGNER"])
.annotate(ride_count=Count("designed_rides", distinct=True))
.only('id', 'name', 'slug', 'roles', 'description')
.order_by("name")
)
def operators_with_park_count(self):
"""Get operators with park count annotation for list views."""
return (
self.get_queryset()
.operators()
.with_park_counts()
.only('id', 'name', 'slug', 'roles', 'description')
.order_by("name")
)

View File

@@ -0,0 +1,25 @@
"""
Add GIN index for Company.roles ArrayField.
This improves query performance for queries like:
Company.objects.filter(roles__contains=["MANUFACTURER"])
Company.objects.filter(roles__contains=["OPERATOR"])
GIN indexes are specifically designed for array containment queries in PostgreSQL.
"""
from django.db import migrations
class Migration(migrations.Migration):
dependencies = [
('parks', '0022_alter_company_roles_alter_companyevent_roles'),
]
operations = [
migrations.RunSQL(
sql="CREATE INDEX IF NOT EXISTS parks_company_roles_gin_idx ON parks_company USING gin(roles);",
reverse_sql="DROP INDEX IF EXISTS parks_company_roles_gin_idx;",
),
]

View File

@@ -0,0 +1,28 @@
"""
Add default value 'UTC' to Park.timezone field.
This ensures all new parks have a valid timezone and existing parks
without a timezone get a sensible default.
"""
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('parks', '0023_add_company_roles_gin_index'),
]
operations = [
migrations.AlterField(
model_name='park',
name='timezone',
field=models.CharField(
blank=True,
default='UTC',
help_text="Timezone identifier for park operations (e.g., 'America/New_York')",
max_length=50,
),
),
]

View File

@@ -127,6 +127,8 @@ class Park(StateMachineMixin, TrackedModel):
# Timezone for park operations
timezone = models.CharField(
max_length=50,
default='UTC',
blank=True,
help_text="Timezone identifier for park operations (e.g., 'America/New_York')"
)

View File

@@ -3,16 +3,21 @@ Services for park-related business logic.
Following Django styleguide pattern for business logic encapsulation.
"""
from typing import Optional, Dict, Any, TYPE_CHECKING
import logging
from typing import Optional, Dict, Any, List, TYPE_CHECKING
from django.db import transaction
from django.db.models import Q
from django.core.files.uploadedfile import UploadedFile
if TYPE_CHECKING:
from django.contrib.auth.models import AbstractUser
from ..models import Park, ParkArea
from ..models import Park, ParkArea, ParkPhoto
from ..models.location import ParkLocation
from .location_service import ParkLocationService
logger = logging.getLogger(__name__)
class ParkService:
"""Service for managing park operations."""
@@ -226,3 +231,282 @@ class ParkService:
park.save()
return park
@staticmethod
def create_park_with_moderation(
*,
changes: Dict[str, Any],
submitter: "AbstractUser",
reason: str = "",
source: str = "",
) -> Dict[str, Any]:
"""
Create a park through the moderation system.
Args:
changes: Dictionary of park data
submitter: User submitting the park
reason: Reason for submission
source: Source of information
Returns:
Dictionary with status and created object (if auto-approved)
"""
from apps.moderation.services import ModerationService
return ModerationService.create_edit_submission_with_queue(
content_object=None,
changes=changes,
submitter=submitter,
submission_type="CREATE",
reason=reason,
source=source,
)
@staticmethod
def update_park_with_moderation(
*,
park: Park,
changes: Dict[str, Any],
submitter: "AbstractUser",
reason: str = "",
source: str = "",
) -> Dict[str, Any]:
"""
Update a park through the moderation system.
Args:
park: Park instance to update
changes: Dictionary of changes
submitter: User submitting the update
reason: Reason for submission
source: Source of information
Returns:
Dictionary with status and updated object (if auto-approved)
"""
from apps.moderation.services import ModerationService
return ModerationService.create_edit_submission_with_queue(
content_object=park,
changes=changes,
submitter=submitter,
submission_type="EDIT",
reason=reason,
source=source,
)
@staticmethod
def create_or_update_location(
*,
park: Park,
latitude: Optional[float],
longitude: Optional[float],
street_address: str = "",
city: str = "",
state: str = "",
country: str = "USA",
postal_code: str = "",
) -> Optional[ParkLocation]:
"""
Create or update a park's location.
Args:
park: Park instance
latitude: Latitude coordinate
longitude: Longitude coordinate
street_address: Street address
city: City name
state: State/region
country: Country (default: USA)
postal_code: Postal/ZIP code
Returns:
ParkLocation instance or None if no coordinates provided
"""
if not latitude or not longitude:
return None
try:
park_location = park.location
# Update existing location
park_location.street_address = street_address
park_location.city = city
park_location.state = state
park_location.country = country or "USA"
park_location.postal_code = postal_code
park_location.set_coordinates(float(latitude), float(longitude))
park_location.save()
return park_location
except ParkLocation.DoesNotExist:
# Create new location
park_location = ParkLocation.objects.create(
park=park,
street_address=street_address,
city=city,
state=state,
country=country or "USA",
postal_code=postal_code,
)
park_location.set_coordinates(float(latitude), float(longitude))
park_location.save()
return park_location
@staticmethod
def upload_photos(
*,
park: Park,
photos: List[UploadedFile],
uploaded_by: "AbstractUser",
) -> Dict[str, Any]:
"""
Upload multiple photos for a park.
Args:
park: Park instance
photos: List of uploaded photo files
uploaded_by: User uploading the photos
Returns:
Dictionary with uploaded_count and errors list
"""
from django.contrib.contenttypes.models import ContentType
uploaded_count = 0
errors: List[str] = []
for photo_file in photos:
try:
ParkPhoto.objects.create(
image=photo_file,
uploaded_by=uploaded_by,
park=park,
)
uploaded_count += 1
except Exception as e:
error_msg = f"Error uploading photo {photo_file.name}: {str(e)}"
errors.append(error_msg)
logger.warning(error_msg)
return {
"uploaded_count": uploaded_count,
"errors": errors,
}
@staticmethod
def handle_park_creation_result(
*,
result: Dict[str, Any],
form_data: Dict[str, Any],
photos: List[UploadedFile],
user: "AbstractUser",
) -> Dict[str, Any]:
"""
Handle the result of park creation through moderation.
Args:
result: Result from create_park_with_moderation
form_data: Cleaned form data containing location info
photos: List of uploaded photo files
user: User who submitted
Returns:
Dictionary with status, park (if created), uploaded_count, and errors
"""
response: Dict[str, Any] = {
"status": result["status"],
"park": None,
"uploaded_count": 0,
"errors": [],
}
if result["status"] == "auto_approved":
park = result["created_object"]
response["park"] = park
# Create location
ParkService.create_or_update_location(
park=park,
latitude=form_data.get("latitude"),
longitude=form_data.get("longitude"),
street_address=form_data.get("street_address", ""),
city=form_data.get("city", ""),
state=form_data.get("state", ""),
country=form_data.get("country", "USA"),
postal_code=form_data.get("postal_code", ""),
)
# Upload photos
if photos:
photo_result = ParkService.upload_photos(
park=park,
photos=photos,
uploaded_by=user,
)
response["uploaded_count"] = photo_result["uploaded_count"]
response["errors"] = photo_result["errors"]
elif result["status"] == "failed":
response["message"] = result.get("message", "Creation failed")
return response
@staticmethod
def handle_park_update_result(
*,
result: Dict[str, Any],
park: Park,
form_data: Dict[str, Any],
photos: List[UploadedFile],
user: "AbstractUser",
) -> Dict[str, Any]:
"""
Handle the result of park update through moderation.
Args:
result: Result from update_park_with_moderation
park: Original park instance (for queued submissions)
form_data: Cleaned form data containing location info
photos: List of uploaded photo files
user: User who submitted
Returns:
Dictionary with status, park, uploaded_count, and errors
"""
response: Dict[str, Any] = {
"status": result["status"],
"park": park,
"uploaded_count": 0,
"errors": [],
}
if result["status"] == "auto_approved":
updated_park = result["created_object"]
response["park"] = updated_park
# Update location
ParkService.create_or_update_location(
park=updated_park,
latitude=form_data.get("latitude"),
longitude=form_data.get("longitude"),
street_address=form_data.get("street_address", ""),
city=form_data.get("city", ""),
state=form_data.get("state", ""),
country=form_data.get("country", ""),
postal_code=form_data.get("postal_code", ""),
)
# Upload photos
if photos:
photo_result = ParkService.upload_photos(
park=updated_park,
photos=photos,
uploaded_by=user,
)
response["uploaded_count"] = photo_result["uploaded_count"]
response["errors"] = photo_result["errors"]
elif result["status"] == "failed":
response["message"] = result.get("message", "Update failed")
return response

View File

@@ -11,6 +11,28 @@ from .models import Park
logger = logging.getLogger(__name__)
# =============================================================================
# Computed Field Maintenance Signals
# =============================================================================
def update_park_search_text(park):
"""
Update park's search_text computed field.
This is called when related objects (location, operator, property_owner)
change and might affect the park's search text.
"""
if park is None:
return
try:
park._populate_computed_fields()
park.save(update_fields=['search_text'])
logger.debug(f"Updated search_text for park {park.pk}")
except Exception as e:
logger.exception(f"Failed to update search_text for park {park.pk}: {e}")
# Status values that count as "active" rides for counting purposes
ACTIVE_STATUSES = {'OPERATING', 'SEASONAL', 'UNDER_CONSTRUCTION'}
@@ -46,8 +68,8 @@ def update_park_ride_counts(park, old_status=None, new_status=None):
# Count total operating rides
ride_count = park.rides.filter(operating_rides).count()
# Count total operating roller coasters
coaster_count = park.rides.filter(operating_rides, category="RC").count()
# Count total operating roller coasters (including water coasters)
coaster_count = park.rides.filter(operating_rides, category__in=["RC", "WC"]).count()
# Update park counts
Park.objects.filter(id=park_id).update(
@@ -148,3 +170,44 @@ def handle_ride_status_transition(instance, source, target, user, **kwargs):
f"by {user if user else 'system'}"
)
update_park_ride_counts(instance.park, source, target)
# =============================================================================
# Computed Field Maintenance Signal Handlers
# =============================================================================
@receiver(post_save, sender='parks.ParkLocation')
def update_park_search_text_on_location_change(sender, instance, **kwargs):
"""
Update park search_text when location changes.
When a park's location is updated (city, state, country changes),
the park's search_text needs to be regenerated to include the new
location information.
"""
try:
if hasattr(instance, 'park') and instance.park:
update_park_search_text(instance.park)
except Exception as e:
logger.exception(f"Failed to update park search_text on location change: {e}")
@receiver(post_save, sender='parks.Company')
def update_park_search_text_on_company_change(sender, instance, **kwargs):
"""
Update park search_text when operator/owner name changes.
When a company's name changes, all parks operated or owned by that
company need their search_text regenerated.
"""
try:
# Update all parks operated by this company
for park in instance.operated_parks.all():
update_park_search_text(park)
# Update all parks owned by this company
for park in instance.owned_parks.all():
update_park_search_text(park)
except Exception as e:
logger.exception(f"Failed to update park search_text on company change: {e}")

View File

@@ -0,0 +1,243 @@
"""
Tests for query optimization patterns in the parks app.
These tests verify that:
1. Manager methods use proper select_related/prefetch_related
2. Views don't trigger N+1 queries
3. Computed fields are updated correctly
"""
from django.test import TestCase
from django.db import connection
from django.test.utils import CaptureQueriesContext
from apps.parks.models import Park, ParkLocation, Company
class ParkQueryOptimizationTests(TestCase):
"""Tests for Park query optimization."""
@classmethod
def setUpTestData(cls):
"""Set up test data for all tests."""
# Create a test operator company
cls.operator = Company.objects.create(
name="Test Operator",
slug="test-operator",
roles=["OPERATOR"],
)
# Create test parks
cls.parks = []
for i in range(5):
park = Park.objects.create(
name=f"Test Park {i}",
slug=f"test-park-{i}",
operator=cls.operator,
timezone="UTC",
)
# Create location for each park
ParkLocation.objects.create(
park=park,
city=f"City {i}",
state="CA",
country="USA",
)
cls.parks.append(park)
def test_optimized_for_list_query_count(self):
"""Verify optimized_for_list uses expected number of queries."""
with CaptureQueriesContext(connection) as context:
parks = Park.objects.optimized_for_list()
# Force evaluation
list(parks)
# Should be a small number of queries (main query + prefetch)
# The exact count depends on prefetch_related configuration
self.assertLessEqual(
len(context.captured_queries),
5,
f"Expected <= 5 queries, got {len(context.captured_queries)}"
)
def test_optimized_for_detail_query_count(self):
"""Verify optimized_for_detail uses expected number of queries."""
with CaptureQueriesContext(connection) as context:
parks = Park.objects.optimized_for_detail()
park = parks.first()
if park:
# Access related objects that should be prefetched
_ = park.operator
_ = list(park.areas.all())
# Should be a reasonable number of queries
self.assertLessEqual(
len(context.captured_queries),
10,
f"Expected <= 10 queries, got {len(context.captured_queries)}"
)
def test_with_location_includes_location(self):
"""Verify with_location prefetches location data."""
with CaptureQueriesContext(connection) as context:
parks = Park.objects.with_location()
for park in parks:
# Accessing location should not cause additional queries
_ = park.location
# Should be minimal queries
self.assertLessEqual(len(context.captured_queries), 3)
def test_for_map_display_returns_minimal_fields(self):
"""Verify for_map_display returns only necessary fields."""
result = Park.objects.for_map_display()
if result.exists():
first = result.first()
# Should include these fields
self.assertIn('id', first)
self.assertIn('name', first)
self.assertIn('slug', first)
self.assertIn('status', first)
def test_search_autocomplete_limits_results(self):
"""Verify search_autocomplete respects limit parameter."""
result = Park.objects.search_autocomplete(query="Test", limit=3)
self.assertLessEqual(len(result), 3)
class CompanyQueryOptimizationTests(TestCase):
"""Tests for Company query optimization."""
@classmethod
def setUpTestData(cls):
"""Set up test data for all tests."""
# Create test companies with different roles
cls.manufacturers = []
for i in range(5):
company = Company.objects.create(
name=f"Manufacturer {i}",
slug=f"manufacturer-{i}",
roles=["MANUFACTURER"],
)
cls.manufacturers.append(company)
cls.operators = []
for i in range(3):
company = Company.objects.create(
name=f"Operator {i}",
slug=f"operator-{i}",
roles=["OPERATOR"],
)
cls.operators.append(company)
def test_manufacturers_query_only_returns_manufacturers(self):
"""Verify manufacturers() only returns companies with MANUFACTURER role."""
result = Company.objects.manufacturers()
for company in result:
self.assertIn("MANUFACTURER", company.roles)
def test_operators_query_only_returns_operators(self):
"""Verify operators() only returns companies with OPERATOR role."""
result = Company.objects.operators()
for company in result:
self.assertIn("OPERATOR", company.roles)
def test_manufacturers_with_ride_count_includes_annotation(self):
"""Verify manufacturers_with_ride_count adds ride_count annotation."""
result = Company.objects.manufacturers_with_ride_count()
if result.exists():
first = result.first()
# Should have ride_count attribute
self.assertTrue(hasattr(first, 'ride_count'))
def test_operators_with_park_count_includes_annotation(self):
"""Verify operators_with_park_count adds park count annotations."""
result = Company.objects.operators_with_park_count()
if result.exists():
first = result.first()
# Should have operated_parks_count attribute
self.assertTrue(hasattr(first, 'operated_parks_count'))
class ComputedFieldMaintenanceTests(TestCase):
"""Tests for computed field maintenance via signals."""
@classmethod
def setUpTestData(cls):
"""Set up test data for all tests."""
cls.operator = Company.objects.create(
name="Test Operator",
slug="test-operator",
roles=["OPERATOR"],
)
def test_park_search_text_includes_name(self):
"""Verify park search_text includes park name."""
park = Park.objects.create(
name="Magic Kingdom",
slug="magic-kingdom",
operator=self.operator,
timezone="UTC",
)
self.assertIn("magic kingdom", park.search_text.lower())
def test_park_search_text_includes_description(self):
"""Verify park search_text includes description."""
park = Park.objects.create(
name="Test Park",
slug="test-park",
description="A magical theme park experience",
operator=self.operator,
timezone="UTC",
)
self.assertIn("magical", park.search_text.lower())
def test_park_search_text_includes_operator(self):
"""Verify park search_text includes operator name."""
park = Park.objects.create(
name="Test Park",
slug="test-park-2",
operator=self.operator,
timezone="UTC",
)
self.assertIn("test operator", park.search_text.lower())
def test_park_opening_year_computed_from_date(self):
"""Verify opening_year is computed from opening_date."""
from datetime import date
park = Park.objects.create(
name="Test Park",
slug="test-park-3",
operator=self.operator,
timezone="UTC",
opening_date=date(1971, 10, 1),
)
self.assertEqual(park.opening_year, 1971)
def test_park_search_text_updated_on_location_change(self):
"""Verify park search_text updates when location changes."""
park = Park.objects.create(
name="Test Park",
slug="test-park-4",
operator=self.operator,
timezone="UTC",
)
# Initially no location in search_text
original_search_text = park.search_text
# Add location
location = ParkLocation.objects.create(
park=park,
city="Orlando",
state="Florida",
country="USA",
)
# Refresh park from database
park.refresh_from_db()
# Search text should now include location
# Note: This depends on signal handlers being properly registered
# The actual behavior may vary based on signal configuration

View File

@@ -2,7 +2,6 @@ from .querysets import get_base_park_queryset
from apps.core.mixins import HTMXFilterableMixin
from .models.location import ParkLocation
from .models.media import ParkPhoto
from apps.moderation.services import ModerationService
from apps.moderation.mixins import (
EditSubmissionMixin,
PhotoSubmissionMixin,
@@ -12,7 +11,7 @@ from apps.core.views.views import SlugRedirectMixin
from .filters import ParkFilter
from .forms import ParkForm
from .models import Park, ParkArea, ParkReview as Review
from .services import ParkFilterService
from .services import ParkFilterService, ParkService
from django.http import (
HttpResponseRedirect,
HttpResponse,
@@ -21,7 +20,6 @@ from django.http import (
)
from django.core.exceptions import ObjectDoesNotExist
from django.contrib import messages
from django.contrib.contenttypes.models import ContentType
from django.contrib.auth.mixins import LoginRequiredMixin
from django.db.models import QuerySet
from django.urls import reverse
@@ -33,7 +31,7 @@ from decimal import Decimal, ROUND_DOWN
from typing import Any, Optional, cast, Literal, Dict
from django.views.decorators.http import require_POST
from django.template.loader import render_to_string
import json
# Constants
@@ -137,22 +135,24 @@ def park_status_actions(request: HttpRequest, slug: str) -> HttpResponse:
park = get_object_or_404(Park, slug=slug)
# Only show to moderators
if not request.user.has_perm('parks.change_park'):
if not request.user.has_perm("parks.change_park"):
return HttpResponse("")
return render(request, "parks/partials/park_status_actions.html", {
"park": park,
"user": request.user
})
return render(
request,
"parks/partials/park_status_actions.html",
{"park": park, "user": request.user},
)
def park_header_badge(request: HttpRequest, slug: str) -> HttpResponse:
"""Return the header status badge partial for a park"""
park = get_object_or_404(Park, slug=slug)
return render(request, "parks/partials/park_header_badge.html", {
"park": park,
"user": request.user
})
return render(
request,
"parks/partials/park_header_badge.html",
{"park": park, "user": request.user},
)
def get_park_areas(request: HttpRequest) -> HttpResponse:
@@ -502,6 +502,7 @@ def htmx_saved_trips(request: HttpRequest) -> HttpResponse:
if request.user.is_authenticated:
try:
from .models import Trip # type: ignore
qs = Trip.objects.filter(owner=request.user).order_by("-created_at")
trips = list(qs[:10])
except Exception:
@@ -648,7 +649,10 @@ def htmx_optimize_route(request: HttpRequest) -> HttpResponse:
rlat1, rlon1, rlat2, rlon2 = map(math.radians, [lat1, lon1, lat2, lon2])
dlat = rlat2 - rlat1
dlon = rlon2 - rlon1
a = math.sin(dlat / 2) ** 2 + math.cos(rlat1) * math.cos(rlat2) * math.sin(dlon / 2) ** 2
a = (
math.sin(dlat / 2) ** 2
+ math.cos(rlat1) * math.cos(rlat2) * math.sin(dlon / 2) ** 2
)
c = 2 * math.asin(min(1, math.sqrt(a)))
miles = 3958.8 * c
return miles
@@ -660,14 +664,18 @@ def htmx_optimize_route(request: HttpRequest) -> HttpResponse:
lat = getattr(loc, "latitude", None) if loc else None
lon = getattr(loc, "longitude", None) if loc else None
if lat is not None and lon is not None:
waypoints.append({"id": p.id, "name": p.name, "latitude": lat, "longitude": lon})
waypoints.append(
{"id": p.id, "name": p.name, "latitude": lat, "longitude": lon}
)
# sum straight-line distances between consecutive waypoints
for i in range(len(waypoints) - 1):
a = waypoints[i]
b = waypoints[i + 1]
try:
total_miles += haversine_miles(a["latitude"], a["longitude"], b["latitude"], b["longitude"])
total_miles += haversine_miles(
a["latitude"], a["longitude"], b["latitude"], b["longitude"]
)
except Exception:
continue
@@ -681,7 +689,9 @@ def htmx_optimize_route(request: HttpRequest) -> HttpResponse:
"total_rides": sum(getattr(p, "ride_count", 0) or 0 for p in parks),
}
html = render_to_string(TRIP_SUMMARY_TEMPLATE, {"summary": summary}, request=request)
html = render_to_string(
TRIP_SUMMARY_TEMPLATE, {"summary": summary}, request=request
)
resp = HttpResponse(html)
# Include waypoints payload in HX-Trigger so client can render route on the map
resp["HX-Trigger"] = json.dumps({"tripOptimized": {"parks": waypoints}})
@@ -710,13 +720,16 @@ def htmx_save_trip(request: HttpRequest) -> HttpResponse:
if request.user.is_authenticated:
try:
from .models import Trip # type: ignore
trip = Trip.objects.create(owner=request.user, name=name)
# attempt to associate parks if the Trip model supports it
try:
trip.parks.set([p.id for p in parks])
except Exception:
pass
trips = list(Trip.objects.filter(owner=request.user).order_by("-created_at")[:10])
trips = list(
Trip.objects.filter(owner=request.user).order_by("-created_at")[:10]
)
except Exception:
trips = []
@@ -735,6 +748,7 @@ def htmx_clear_trip(request: HttpRequest) -> HttpResponse:
resp["HX-Trigger"] = json.dumps({"tripCleared": True})
return resp
class ParkCreateView(LoginRequiredMixin, CreateView):
model = Park
form_class = ParkForm
@@ -775,76 +789,49 @@ class ParkCreateView(LoginRequiredMixin, CreateView):
self.normalize_coordinates(form)
changes = self.prepare_changes_data(form.cleaned_data)
# Use the new queue routing service
result = ModerationService.create_edit_submission_with_queue(
content_object=None, # None for CREATE
# Submit through moderation service
result = ParkService.create_park_with_moderation(
changes=changes,
submitter=self.request.user,
submission_type="CREATE",
reason=self.request.POST.get("reason", ""),
source=self.request.POST.get("source", ""),
)
if result['status'] == 'auto_approved':
# Moderator submission was auto-approved
self.object = result['created_object']
# Handle the result using the service
photos = self.request.FILES.getlist("photos")
service_result = ParkService.handle_park_creation_result(
result=result,
form_data=form.cleaned_data,
photos=photos,
user=self.request.user,
)
if form.cleaned_data.get("latitude") and form.cleaned_data.get("longitude"):
# Create or update ParkLocation
park_location, _ = ParkLocation.objects.get_or_create(
park=self.object,
defaults={
"street_address": form.cleaned_data.get("street_address", ""),
"city": form.cleaned_data.get("city", ""),
"state": form.cleaned_data.get("state", ""),
"country": form.cleaned_data.get("country", "USA"),
"postal_code": form.cleaned_data.get("postal_code", ""),
},
)
park_location.set_coordinates(
form.cleaned_data["latitude"],
form.cleaned_data["longitude"],
)
park_location.save()
photos = self.request.FILES.getlist("photos")
uploaded_count = 0
for photo_file in photos:
try:
ParkPhoto.objects.create(
image=photo_file,
uploaded_by=self.request.user,
park=self.object,
)
uploaded_count += 1
except Exception as e:
messages.error(
self.request,
f"Error uploading photo {photo_file.name}: {str(e)}",
)
# Report any photo upload errors
for error in service_result.get("errors", []):
messages.error(self.request, error)
if service_result["status"] == "auto_approved":
self.object = service_result["park"]
messages.success(
self.request,
f"Successfully created {self.object.name}. "
f"Added {uploaded_count} photo(s).",
f"Added {service_result['uploaded_count']} photo(s).",
)
return HttpResponseRedirect(self.get_success_url())
elif result['status'] == 'queued':
# Regular user submission was queued
elif service_result["status"] == "queued":
messages.success(
self.request,
"Your park submission has been sent for review. "
"You will be notified when it is approved.",
)
# Redirect to parks list since we don't have an object yet
return HttpResponseRedirect(reverse("parks:park_list"))
elif result['status'] == 'failed':
# Auto-approval failed
elif service_result["status"] == "failed":
messages.error(
self.request,
f"Error creating park: {result['message']}. Please check your input and try again.",
f"Error creating park: {service_result.get('message', 'Unknown error')}. "
"Please check your input and try again.",
)
return self.form_invalid(form)
@@ -900,110 +887,43 @@ class ParkUpdateView(LoginRequiredMixin, UpdateView):
Decimal("0.000001"), rounding=ROUND_DOWN
)
def form_valid(self, form: ParkForm) -> HttpResponse: # noqa: C901
def form_valid(self, form: ParkForm) -> HttpResponse:
self.normalize_coordinates(form)
changes = self.prepare_changes_data(form.cleaned_data)
# Use the new queue routing service
result = ModerationService.create_edit_submission_with_queue(
content_object=self.object,
# Submit through moderation service
result = ParkService.update_park_with_moderation(
park=self.object,
changes=changes,
submitter=self.request.user,
submission_type="EDIT",
reason=self.request.POST.get("reason", ""),
source=self.request.POST.get("source", ""),
)
if result['status'] == 'auto_approved':
# Moderator submission was auto-approved
# The object was already updated by the service
self.object = result['created_object']
# Handle the result using the service
photos = self.request.FILES.getlist("photos")
service_result = ParkService.handle_park_update_result(
result=result,
park=self.object,
form_data=form.cleaned_data,
photos=photos,
user=self.request.user,
)
location_data = {
"name": self.object.name,
"location_type": "park",
"latitude": form.cleaned_data.get("latitude"),
"longitude": form.cleaned_data.get("longitude"),
"street_address": form.cleaned_data.get("street_address", ""),
"city": form.cleaned_data.get("city", ""),
"state": form.cleaned_data.get("state", ""),
"country": form.cleaned_data.get("country", ""),
"postal_code": form.cleaned_data.get("postal_code", ""),
}
# Create or update ParkLocation
try:
park_location = self.object.location
# Update existing location
for key, value in location_data.items():
if key in ["latitude", "longitude"] and value:
continue # Handle coordinates separately
if hasattr(park_location, key):
setattr(park_location, key, value)
# Handle coordinates if provided
if "latitude" in location_data and "longitude" in location_data:
if location_data["latitude"] and location_data["longitude"]:
park_location.set_coordinates(
float(location_data["latitude"]),
float(location_data["longitude"]),
)
park_location.save()
except ParkLocation.DoesNotExist:
# Create new ParkLocation
coordinates_data = {}
if "latitude" in location_data and "longitude" in location_data:
if location_data["latitude"] and location_data["longitude"]:
coordinates_data = {
"latitude": float(location_data["latitude"]),
"longitude": float(location_data["longitude"]),
}
# Remove coordinate fields from location_data for creation
creation_data = {
k: v
for k, v in location_data.items()
if k not in ["latitude", "longitude"]
}
creation_data.setdefault("country", "USA")
park_location = ParkLocation.objects.create(
park=self.object, **creation_data
)
if coordinates_data:
park_location.set_coordinates(
coordinates_data["latitude"],
coordinates_data["longitude"],
)
park_location.save()
photos = self.request.FILES.getlist("photos")
uploaded_count = 0
for photo_file in photos:
try:
ParkPhoto.objects.create(
image=photo_file,
uploaded_by=self.request.user,
content_type=ContentType.objects.get_for_model(Park),
object_id=self.object.id,
)
uploaded_count += 1
except Exception as e:
messages.error(
self.request,
f"Error uploading photo {photo_file.name}: {str(e)}",
)
# Report any photo upload errors
for error in service_result.get("errors", []):
messages.error(self.request, error)
if service_result["status"] == "auto_approved":
self.object = service_result["park"]
messages.success(
self.request,
f"Successfully updated {self.object.name}. "
f"Added {uploaded_count} new photo(s).",
f"Added {service_result['uploaded_count']} new photo(s).",
)
return HttpResponseRedirect(self.get_success_url())
elif result['status'] == 'queued':
# Regular user submission was queued
elif service_result["status"] == "queued":
messages.success(
self.request,
f"Your changes to {self.object.name} have been sent for review. "
@@ -1013,11 +933,11 @@ class ParkUpdateView(LoginRequiredMixin, UpdateView):
reverse(PARK_DETAIL_URL, kwargs={"slug": self.object.slug})
)
elif result['status'] == 'failed':
# Auto-approval failed
elif service_result["status"] == "failed":
messages.error(
self.request,
f"Error updating park: {result['message']}. Please check your input and try again.",
f"Error updating park: {service_result.get('message', 'Unknown error')}. "
"Please check your input and try again.",
)
return self.form_invalid(form)
@@ -1133,13 +1053,14 @@ class OperatorListView(ListView):
paginate_by = 24
def get_queryset(self):
"""Get companies that are operators"""
"""Get companies that are operators with optimized query"""
from .models.companies import Company
from django.db.models import Count
return (
Company.objects.filter(roles__contains=["OPERATOR"])
.annotate(park_count=Count("operated_parks"))
.only("id", "name", "slug", "roles", "description")
.order_by("name")
)

View File

@@ -75,13 +75,14 @@ class RideQuerySet(StatusQuerySet, ReviewableQuerySet):
return self.select_related(
"park",
"park__location",
"park_area",
"manufacturer",
"designer",
"ride_model",
"ride_model__manufacturer",
).prefetch_related(
"location",
"rollercoaster_stats",
"coaster_stats",
Prefetch(
"reviews",
queryset=RideReview.objects.select_related("user")
@@ -91,6 +92,12 @@ class RideQuerySet(StatusQuerySet, ReviewableQuerySet):
"photos",
)
def with_coaster_stats(self):
"""Always prefetch coaster_stats for roller coaster queries."""
return self.select_related(
"park", "manufacturer", "ride_model"
).prefetch_related("coaster_stats")
def for_map_display(self):
"""Optimize for map display."""
return (
@@ -176,6 +183,10 @@ class RideManager(StatusManager, ReviewableManager):
def optimized_for_detail(self):
return self.get_queryset().optimized_for_detail()
def with_coaster_stats(self):
"""Always prefetch coaster_stats for roller coaster queries."""
return self.get_queryset().with_coaster_stats()
class RideModelQuerySet(BaseQuerySet):
"""QuerySet for RideModel model."""

View File

@@ -0,0 +1,40 @@
"""
Convert RideModel unique_together to UniqueConstraint.
This migration converts the legacy unique_together constraints to the modern
UniqueConstraint syntax which provides better error messages and more flexibility.
"""
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('rides', '0025_convert_ride_status_to_fsm'),
]
operations = [
# Remove the old unique_together constraint
migrations.AlterUniqueTogether(
name='ridemodel',
unique_together=set(),
),
# Add new UniqueConstraints with better error messages
migrations.AddConstraint(
model_name='ridemodel',
constraint=models.UniqueConstraint(
fields=['manufacturer', 'name'],
name='ridemodel_manufacturer_name_unique',
violation_error_message='A ride model with this name already exists for this manufacturer'
),
),
migrations.AddConstraint(
model_name='ridemodel',
constraint=models.UniqueConstraint(
fields=['manufacturer', 'slug'],
name='ridemodel_manufacturer_slug_unique',
violation_error_message='A ride model with this slug already exists for this manufacturer'
),
),
]

View File

@@ -0,0 +1,49 @@
"""
Mixins for ride views.
This module contains mixins that provide reusable functionality
for ride-related views, reducing code duplication.
"""
from typing import Any, Dict
from django.contrib import messages
from apps.rides.services import RideService
class RideFormMixin:
"""
Mixin for handling ride form submissions with entity suggestions.
Provides common functionality for RideCreateView and RideUpdateView
to handle new manufacturer, designer, and ride model suggestions.
"""
def handle_entity_suggestions(self, form) -> Dict[str, Any]:
"""
Process new entity suggestions from form.
Creates moderation submissions for any new manufacturers,
designers, or ride models that were suggested but don't
exist in the system.
Args:
form: Validated form instance with cleaned_data
Returns:
Dictionary with submission results from RideService
"""
result = RideService.handle_new_entity_suggestions(
form_data=form.cleaned_data,
submitter=self.request.user
)
if result['total_submissions'] > 0:
messages.info(
self.request,
f"Created {result['total_submissions']} moderation submission(s) "
"for new entities"
)
return result

View File

@@ -1,5 +1,6 @@
from django.db import models
from django.utils.text import slugify
from django.core.exceptions import ValidationError
from config.django import base as settings
from apps.core.models import TrackedModel
from apps.core.choices import RichChoiceField
@@ -165,8 +166,18 @@ class RideModel(TrackedModel):
class Meta(TrackedModel.Meta):
ordering = ["manufacturer__name", "name"]
unique_together = [["manufacturer", "name"], ["manufacturer", "slug"]]
constraints = [
# Unique constraints (replacing unique_together for better error messages)
models.UniqueConstraint(
fields=['manufacturer', 'name'],
name='ridemodel_manufacturer_name_unique',
violation_error_message='A ride model with this name already exists for this manufacturer'
),
models.UniqueConstraint(
fields=['manufacturer', 'slug'],
name='ridemodel_manufacturer_slug_unique',
violation_error_message='A ride model with this slug already exists for this manufacturer'
),
# Height range validation
models.CheckConstraint(
name="ride_model_height_range_logical",
@@ -222,6 +233,14 @@ class RideModel(TrackedModel):
else f"{self.manufacturer.name} {self.name}"
)
def clean(self) -> None:
"""Validate RideModel business rules."""
super().clean()
if self.is_discontinued and not self.last_installation_year:
raise ValidationError({
'last_installation_year': 'Discontinued models must have a last installation year'
})
def save(self, *args, **kwargs) -> None:
if not self.slug:
from django.utils.text import slugify

View File

@@ -307,3 +307,85 @@ class RideService:
ride = Ride.objects.select_for_update().get(id=ride_id)
ride.open(user=user)
return ride
@staticmethod
def handle_new_entity_suggestions(
*,
form_data: Dict[str, Any],
submitter: UserType,
) -> Dict[str, Any]:
"""
Handle suggestions for new manufacturers, designers, and ride models.
Creates moderation submissions for entities that don't exist in the system.
This extracts the business logic from RideCreateView and RideUpdateView.
Args:
form_data: Cleaned form data containing search fields and selections
submitter: User making the suggestions
Returns:
Dictionary with lists of created submission IDs by type:
{
'manufacturers': [...],
'designers': [...],
'ride_models': [...],
'total_submissions': int
}
"""
from apps.moderation.services import ModerationService
result = {
'manufacturers': [],
'designers': [],
'ride_models': [],
'total_submissions': 0
}
# Check for new manufacturer
manufacturer_name = form_data.get("manufacturer_search")
if manufacturer_name and not form_data.get("manufacturer"):
submission = ModerationService.create_edit_submission_with_queue(
content_object=None,
changes={"name": manufacturer_name, "roles": ["MANUFACTURER"]},
submitter=submitter,
submission_type="CREATE",
reason=f"New manufacturer suggested: {manufacturer_name}",
)
if submission:
result['manufacturers'].append(submission.id)
result['total_submissions'] += 1
# Check for new designer
designer_name = form_data.get("designer_search")
if designer_name and not form_data.get("designer"):
submission = ModerationService.create_edit_submission_with_queue(
content_object=None,
changes={"name": designer_name, "roles": ["DESIGNER"]},
submitter=submitter,
submission_type="CREATE",
reason=f"New designer suggested: {designer_name}",
)
if submission:
result['designers'].append(submission.id)
result['total_submissions'] += 1
# Check for new ride model
ride_model_name = form_data.get("ride_model_search")
manufacturer = form_data.get("manufacturer")
if ride_model_name and not form_data.get("ride_model") and manufacturer:
submission = ModerationService.create_edit_submission_with_queue(
content_object=None,
changes={
"name": ride_model_name,
"manufacturer": manufacturer.id,
},
submitter=submitter,
submission_type="CREATE",
reason=f"New ride model suggested: {ride_model_name}",
)
if submission:
result['ride_models'].append(submission.id)
result['total_submissions'] += 1
return result

View File

@@ -1,6 +1,6 @@
import logging
from django.db.models.signals import pre_save
from django.db.models.signals import pre_save, post_save
from django.dispatch import receiver
from django.utils import timezone
@@ -10,6 +10,28 @@ from .models import Ride
logger = logging.getLogger(__name__)
# =============================================================================
# Computed Field Maintenance
# =============================================================================
def update_ride_search_text(ride):
"""
Update ride's search_text computed field.
This is called when related objects (park, manufacturer, ride_model)
change and might affect the ride's search text.
"""
if ride is None:
return
try:
ride._populate_computed_fields()
ride.save(update_fields=['search_text'])
logger.debug(f"Updated search_text for ride {ride.pk}")
except Exception as e:
logger.exception(f"Failed to update search_text for ride {ride.pk}: {e}")
@receiver(pre_save, sender=Ride)
def handle_ride_status(sender, instance, **kwargs):
"""
@@ -186,3 +208,58 @@ def apply_post_closing_status(instance, user=None):
f"Applied post_closing_status {target_status} to ride {instance.pk} (direct)"
)
return True
# =============================================================================
# Computed Field Maintenance Signal Handlers
# =============================================================================
@receiver(post_save, sender='parks.Park')
def update_ride_search_text_on_park_change(sender, instance, **kwargs):
"""
Update ride search_text when park name or location changes.
When a park's name changes, all rides at that park need their
search_text regenerated.
"""
try:
for ride in instance.rides.all():
update_ride_search_text(ride)
except Exception as e:
logger.exception(f"Failed to update ride search_text on park change: {e}")
@receiver(post_save, sender='parks.Company')
def update_ride_search_text_on_company_change(sender, instance, **kwargs):
"""
Update ride search_text when manufacturer/designer name changes.
When a company's name changes, all rides manufactured or designed
by that company need their search_text regenerated.
"""
try:
# Update all rides manufactured by this company
for ride in instance.manufactured_rides.all():
update_ride_search_text(ride)
# Update all rides designed by this company
for ride in instance.designed_rides.all():
update_ride_search_text(ride)
except Exception as e:
logger.exception(f"Failed to update ride search_text on company change: {e}")
@receiver(post_save, sender='rides.RideModel')
def update_ride_search_text_on_ride_model_change(sender, instance, **kwargs):
"""
Update ride search_text when ride model name changes.
When a ride model's name changes, all rides using that model need
their search_text regenerated.
"""
try:
for ride in instance.rides.all():
update_ride_search_text(ride)
except Exception as e:
logger.exception(f"Failed to update ride search_text on ride model change: {e}")

View File

@@ -1,80 +1,188 @@
from django.views.generic import DetailView, ListView, CreateView, UpdateView
"""
Views for ride functionality.
This module contains views for managing ride resources including
CRUD operations, search, filtering, and HTMX partial rendering.
View Types:
CBVs:
- RideDetailView: Display ride details
- RideCreateView: Create new ride
- RideUpdateView: Update existing ride
- RideListView: List rides with filtering
- RideSearchView: Search rides with HTMX support
- RideRankingsView: Display ride rankings
- RideRankingDetailView: Display ranking details
- ManufacturerListView: List manufacturers
- DesignerListView: List designers
- SingleCategoryListView: List rides by category
FBVs (HTMX Partials):
- show_coaster_fields: Toggle coaster-specific fields
- ride_status_actions: FSM status actions for moderators
- ride_header_badge: Status badge partial
- search_companies: Company search autocomplete
- search_ride_models: Ride model search autocomplete
- get_search_suggestions: Smart search suggestions
- ranking_history_chart: Ranking history chart data
- ranking_comparisons: Head-to-head comparisons
Dependencies:
- Services: apps.rides.services
- Models: apps.rides.models
- Forms: apps.rides.forms
Code Quality:
- PEP 8 compliant (verified with black, flake8, ruff)
- Maximum line length: 88 characters
- Maximum complexity: 10 (McCabe)
"""
from django.contrib.auth.mixins import LoginRequiredMixin
from django.db.models import Count, Q
from django.http import Http404, HttpRequest, HttpResponse
from django.shortcuts import get_object_or_404, render
from django.urls import reverse
from django.db.models import Q
from django.contrib.auth.mixins import LoginRequiredMixin
from django.http import HttpRequest, HttpResponse, Http404
from django.db.models import Count
from .models.rides import Ride, RideModel
from .choices import RIDE_CATEGORIES
from .models.company import Company
from django.views.generic import CreateView, DetailView, ListView, UpdateView
from apps.moderation.mixins import EditSubmissionMixin, HistoryMixin
from apps.parks.models import Park
from .forms import RideForm, RideSearchForm
from .forms.search import MasterFilterForm
from .services.search import RideSearchService
from apps.parks.models import Park
from apps.moderation.mixins import EditSubmissionMixin, HistoryMixin
from apps.moderation.services import ModerationService
from .models.rankings import RideRanking, RankingSnapshot
from .mixins import RideFormMixin
from .models.company import Company
from .models.rankings import RankingSnapshot, RideRanking
from .models.rides import Ride, RideModel
from .services.ranking_service import RideRankingService
class ParkContextRequired:
"""Mixin to require park context for views"""
"""
Mixin to require park context for views.
Ensures that the view has access to a park_slug URL parameter.
Raises Http404 if park context is not available.
"""
def dispatch(self, request, *args, **kwargs):
"""Check for park context before dispatching to handler."""
if "park_slug" not in self.kwargs:
raise Http404("Park context is required")
return super().dispatch(request, *args, **kwargs)
def show_coaster_fields(request: HttpRequest) -> HttpResponse:
"""Show roller coaster specific fields based on category selection"""
"""
Show roller coaster specific fields based on category selection.
View Type: FBV (HTMX Partial)
URL Pattern: /rides/coaster-fields/
Returns: HTML partial with coaster-specific form fields
Args:
request: HTTP request with 'category' query parameter
Returns:
Empty response for non-coaster categories,
or coaster fields partial for roller coasters
"""
category = request.GET.get("category")
if category != "RC": # Only show for roller coasters
return HttpResponse("")
return render(request, "rides/partials/coaster_fields.html")
def ride_status_actions(request: HttpRequest, park_slug: str, ride_slug: str) -> HttpResponse:
"""Return FSM status actions for ride moderators"""
def ride_status_actions(
request: HttpRequest, park_slug: str, ride_slug: str
) -> HttpResponse:
"""
Return FSM status actions for ride moderators.
View Type: FBV (HTMX Partial)
URL Pattern: /parks/<park_slug>/rides/<ride_slug>/status-actions/
Returns: HTML partial with available status transition actions
Permissions: rides.change_ride
Args:
request: HTTP request
park_slug: Slug of the park
ride_slug: Slug of the ride
Returns:
Empty response for non-moderators, or status actions partial
"""
park = get_object_or_404(Park, slug=park_slug)
ride = get_object_or_404(Ride, park=park, slug=ride_slug)
# Only show to moderators
if not request.user.has_perm('rides.change_ride'):
if not request.user.has_perm("rides.change_ride"):
return HttpResponse("")
return render(request, "rides/partials/ride_status_actions.html", {
"ride": ride,
"park": park,
"user": request.user
})
return render(
request,
"rides/partials/ride_status_actions.html",
{"ride": ride, "park": park, "user": request.user},
)
def ride_header_badge(request: HttpRequest, park_slug: str, ride_slug: str) -> HttpResponse:
"""Return the header status badge partial for a ride"""
def ride_header_badge(
request: HttpRequest, park_slug: str, ride_slug: str
) -> HttpResponse:
"""
Return the header status badge partial for a ride.
View Type: FBV (HTMX Partial)
URL Pattern: /parks/<park_slug>/rides/<ride_slug>/header-badge/
Returns: HTML partial with ride status badge
Args:
request: HTTP request
park_slug: Slug of the park
ride_slug: Slug of the ride
Returns:
Rendered status badge partial
"""
park = get_object_or_404(Park, slug=park_slug)
ride = get_object_or_404(Ride, park=park, slug=ride_slug)
return render(request, "rides/partials/ride_header_badge.html", {
"ride": ride,
"park": park,
"user": request.user
})
return render(
request,
"rides/partials/ride_header_badge.html",
{"ride": ride, "park": park, "user": request.user},
)
class RideDetailView(HistoryMixin, DetailView):
"""View for displaying ride details"""
"""
Display ride details with related data.
View Type: CBV (DetailView)
URL Pattern: /parks/<park_slug>/rides/<ride_slug>/
Template: rides/ride_detail.html
Permissions: Public
Includes history tracking via HistoryMixin for audit trail display.
"""
model = Ride
template_name = "rides/ride_detail.html"
slug_url_kwarg = "ride_slug"
def get_queryset(self):
"""Get ride for the specific park if park_slug is provided"""
"""Get ride for the specific park with optimized query."""
queryset = (
Ride.objects.all()
.select_related("park", "ride_model", "ride_model__manufacturer")
.prefetch_related("photos")
.select_related(
"park",
"park__location",
"park_area",
"manufacturer",
"designer",
"ride_model",
"ride_model__manufacturer",
)
.prefetch_related("photos", "coaster_stats")
)
if "park_slug" in self.kwargs:
@@ -92,15 +200,24 @@ class RideDetailView(HistoryMixin, DetailView):
return context
class RideCreateView(LoginRequiredMixin, ParkContextRequired, CreateView):
"""View for creating a new ride"""
class RideCreateView(
LoginRequiredMixin, ParkContextRequired, RideFormMixin, CreateView
):
"""
View for creating a new ride.
View Type: CBV (CreateView)
URL Pattern: /parks/<park_slug>/rides/add/
Template: rides/ride_form.html
Permissions: LoginRequired
"""
model = Ride
form_class = RideForm
template_name = "rides/ride_form.html"
def get_success_url(self):
"""Get URL to redirect to after successful creation"""
"""Get URL to redirect to after successful creation."""
return reverse(
"parks:rides:ride_detail",
kwargs={
@@ -110,14 +227,14 @@ class RideCreateView(LoginRequiredMixin, ParkContextRequired, CreateView):
)
def get_form_kwargs(self):
"""Pass park to the form"""
"""Pass park to the form."""
kwargs = super().get_form_kwargs()
self.park = get_object_or_404(Park, slug=self.kwargs["park_slug"])
kwargs["park"] = self.park
return kwargs
def get_context_data(self, **kwargs):
"""Add park and park_slug to context"""
"""Add park and park_slug to context."""
context = super().get_context_data(**kwargs)
context["park"] = self.park
context["park_slug"] = self.park.slug
@@ -125,51 +242,26 @@ class RideCreateView(LoginRequiredMixin, ParkContextRequired, CreateView):
return context
def form_valid(self, form):
"""Handle form submission including new items"""
# Check for new manufacturer
manufacturer_name = form.cleaned_data.get("manufacturer_search")
if manufacturer_name and not form.cleaned_data.get("manufacturer"):
ModerationService.create_edit_submission_with_queue(
content_object=None,
changes={"name": manufacturer_name, "roles": ["MANUFACTURER"]},
submitter=self.request.user,
submission_type="CREATE",
reason=f"New manufacturer suggested during ride creation: {manufacturer_name}",
)
# Check for new designer
designer_name = form.cleaned_data.get("designer_search")
if designer_name and not form.cleaned_data.get("designer"):
ModerationService.create_edit_submission_with_queue(
content_object=None,
changes={"name": designer_name, "roles": ["DESIGNER"]},
submitter=self.request.user,
submission_type="CREATE",
reason=f"New designer suggested during ride creation: {designer_name}",
)
# Check for new ride model
ride_model_name = form.cleaned_data.get("ride_model_search")
manufacturer = form.cleaned_data.get("manufacturer")
if ride_model_name and not form.cleaned_data.get("ride_model") and manufacturer:
ModerationService.create_edit_submission_with_queue(
content_object=None,
changes={
"name": ride_model_name,
"manufacturer": manufacturer.id,
},
submitter=self.request.user,
submission_type="CREATE",
reason=f"New ride model suggested during ride creation: {ride_model_name}",
)
"""Handle form submission using RideFormMixin for entity suggestions."""
self.handle_entity_suggestions(form)
return super().form_valid(form)
class RideUpdateView(
LoginRequiredMixin, ParkContextRequired, EditSubmissionMixin, UpdateView
LoginRequiredMixin,
ParkContextRequired,
RideFormMixin,
EditSubmissionMixin,
UpdateView,
):
"""View for updating an existing ride"""
"""
View for updating an existing ride.
View Type: CBV (UpdateView)
URL Pattern: /parks/<park_slug>/rides/<ride_slug>/edit/
Template: rides/ride_form.html
Permissions: LoginRequired
"""
model = Ride
form_class = RideForm
@@ -177,7 +269,7 @@ class RideUpdateView(
slug_url_kwarg = "ride_slug"
def get_success_url(self):
"""Get URL to redirect to after successful update"""
"""Get URL to redirect to after successful update."""
return reverse(
"parks:rides:ride_detail",
kwargs={
@@ -187,18 +279,18 @@ class RideUpdateView(
)
def get_queryset(self):
"""Get ride for the specific park"""
"""Get ride for the specific park."""
return Ride.objects.filter(park__slug=self.kwargs["park_slug"])
def get_form_kwargs(self):
"""Pass park to the form"""
"""Pass park to the form."""
kwargs = super().get_form_kwargs()
self.park = get_object_or_404(Park, slug=self.kwargs["park_slug"])
kwargs["park"] = self.park
return kwargs
def get_context_data(self, **kwargs):
"""Add park and park_slug to context"""
"""Add park and park_slug to context."""
context = super().get_context_data(**kwargs)
context["park"] = self.park
context["park_slug"] = self.park.slug
@@ -206,44 +298,8 @@ class RideUpdateView(
return context
def form_valid(self, form):
"""Handle form submission including new items"""
# Check for new manufacturer
manufacturer_name = form.cleaned_data.get("manufacturer_search")
if manufacturer_name and not form.cleaned_data.get("manufacturer"):
ModerationService.create_edit_submission_with_queue(
content_object=None,
changes={"name": manufacturer_name, "roles": ["MANUFACTURER"]},
submitter=self.request.user,
submission_type="CREATE",
reason=f"New manufacturer suggested during ride update: {manufacturer_name}",
)
# Check for new designer
designer_name = form.cleaned_data.get("designer_search")
if designer_name and not form.cleaned_data.get("designer"):
ModerationService.create_edit_submission_with_queue(
content_object=None,
changes={"name": designer_name, "roles": ["DESIGNER"]},
submitter=self.request.user,
submission_type="CREATE",
reason=f"New designer suggested during ride update: {designer_name}",
)
# Check for new ride model
ride_model_name = form.cleaned_data.get("ride_model_search")
manufacturer = form.cleaned_data.get("manufacturer")
if ride_model_name and not form.cleaned_data.get("ride_model") and manufacturer:
ModerationService.create_edit_submission_with_queue(
content_object=None,
changes={
"name": ride_model_name,
"manufacturer": manufacturer.id,
},
submitter=self.request.user,
submission_type="CREATE",
reason=f"New ride model suggested during ride update: {ride_model_name}",
)
"""Handle form submission using RideFormMixin for entity suggestions."""
self.handle_entity_suggestions(form)
return super().form_valid(form)
@@ -256,13 +312,7 @@ class RideListView(ListView):
paginate_by = 24
def get_queryset(self):
"""Get filtered rides using the advanced search service"""
# Initialize search service
search_service = RideSearchService()
# Parse filters from request
filter_form = MasterFilterForm(self.request.GET)
"""Get filtered rides using the advanced search service."""
# Apply park context if available
park = None
if "park_slug" in self.kwargs:
@@ -275,12 +325,12 @@ class RideListView(ListView):
.select_related("park", "ride_model", "ride_model__manufacturer")
.prefetch_related("photos")
)
if park:
queryset = queryset.filter(park=park)
# Apply basic search if provided
search_query = self.request.GET.get('search', '').strip()
search_query = self.request.GET.get("search", "").strip()
if search_query:
queryset = queryset.filter(name__icontains=search_query)
@@ -306,8 +356,11 @@ class RideListView(ListView):
context["filter_form"] = filter_form
# Use Rich Choice registry directly
from apps.core.choices.registry import get_choices
choices = get_choices("categories", "rides")
context["category_choices"] = [(choice.value, choice.label) for choice in choices]
context["category_choices"] = [
(choice.value, choice.label) for choice in choices
]
# Add filter summary for display
if filter_form.is_valid():
@@ -357,8 +410,12 @@ class SingleCategoryListView(ListView):
context["park_slug"] = self.kwargs["park_slug"]
# Find the category choice by value using Rich Choice registry
from apps.core.choices.registry import get_choices
choices = get_choices("categories", "rides")
category_choice = next((choice for choice in choices if choice.value == self.kwargs["category"]), None)
category_choice = next(
(choice for choice in choices if choice.value == self.kwargs["category"]),
None,
)
context["category"] = category_choice.label if category_choice else "Unknown"
return context
@@ -372,7 +429,7 @@ def search_companies(request: HttpRequest) -> HttpResponse:
query = request.GET.get("q", "").strip()
role = request.GET.get("role", "").upper()
companies = Company.objects.all().order_by("name")
companies = Company.objects.only("id", "name", "slug", "roles").order_by("name")
if role:
companies = companies.filter(roles__contains=[role])
if query:
@@ -439,8 +496,8 @@ def get_search_suggestions(request: HttpRequest) -> HttpResponse:
}
)
# Get matching parks
matching_parks = Park.objects.filter(
# Get matching parks with optimized query
matching_parks = Park.objects.select_related("location").filter(
Q(name__icontains=query) | Q(location__city__icontains=query)
)[:3]
@@ -455,6 +512,7 @@ def get_search_suggestions(request: HttpRequest) -> HttpResponse:
# Add category matches
from apps.core.choices.registry import get_choices
choices = get_choices("categories", "rides")
for choice in choices:
if query in choice.label.lower():
@@ -556,8 +614,11 @@ class RideRankingsView(ListView):
context = super().get_context_data(**kwargs)
# Use Rich Choice registry directly
from apps.core.choices.registry import get_choices
choices = get_choices("categories", "rides")
context["category_choices"] = [(choice.value, choice.label) for choice in choices]
context["category_choices"] = [
(choice.value, choice.label) for choice in choices
]
context["selected_category"] = self.request.GET.get("category", "all")
context["min_riders"] = self.request.GET.get("min_riders", "")
@@ -690,10 +751,11 @@ class ManufacturerListView(ListView):
paginate_by = 24
def get_queryset(self):
"""Get companies that are manufacturers"""
"""Get companies that are manufacturers with optimized query"""
return (
Company.objects.filter(roles__contains=["MANUFACTURER"])
.annotate(ride_count=Count("manufactured_rides"))
.only("id", "name", "slug", "roles", "description")
.order_by("name")
)
@@ -713,10 +775,11 @@ class DesignerListView(ListView):
paginate_by = 24
def get_queryset(self):
"""Get companies that are designers"""
"""Get companies that are designers with optimized query"""
return (
Company.objects.filter(roles__contains=["DESIGNER"])
.annotate(ride_count=Count("designed_rides"))
.only("id", "name", "slug", "roles", "description")
.order_by("name")
)

View File

@@ -115,6 +115,8 @@ MIDDLEWARE = [
"django.middleware.cache.UpdateCacheMiddleware",
"corsheaders.middleware.CorsMiddleware", # CORS middleware for API
"django.middleware.security.SecurityMiddleware",
"apps.core.middleware.security_headers.SecurityHeadersMiddleware", # Custom security headers (CSP, Permissions-Policy)
"apps.core.middleware.rate_limiting.AuthRateLimitMiddleware", # Rate limiting for auth endpoints
"whitenoise.middleware.WhiteNoiseMiddleware",
"django.contrib.sessions.middleware.SessionMiddleware",
"django.middleware.common.CommonMiddleware",
@@ -471,50 +473,91 @@ CACHES = {
}
# Use Redis for sessions
# =============================================================================
# Session Security Settings
# =============================================================================
SESSION_ENGINE = "django.contrib.sessions.backends.cache"
SESSION_CACHE_ALIAS = "sessions"
SESSION_COOKIE_AGE = 86400 # 24 hours
SESSION_COOKIE_AGE = 3600 # 1 hour (reduced from 24 hours for security)
SESSION_SAVE_EVERY_REQUEST = True # Update session on each request (sliding expiry)
SESSION_COOKIE_HTTPONLY = True # Prevent JavaScript access to session cookie
SESSION_EXPIRE_AT_BROWSER_CLOSE = False # Session persists until cookie expires
# Cache middleware settings
CACHE_MIDDLEWARE_SECONDS = 300 # 5 minutes
CACHE_MIDDLEWARE_KEY_PREFIX = "thrillwiki"
# =============================================================================
# JWT Settings
# =============================================================================
# Security considerations:
# - Short access token lifetime reduces window of vulnerability
# - Refresh token rotation prevents token reuse after refresh
# - Token blacklisting allows revocation of compromised tokens
# - JTI claim enables unique token identification for logging
SIMPLE_JWT = {
"ACCESS_TOKEN_LIFETIME": timedelta(minutes=60), # 1 hour
# Token lifetimes
# Security: Shorter access tokens (15 min) provide better security
# but may require more frequent refreshes
"ACCESS_TOKEN_LIFETIME": timedelta(minutes=15), # 15 minutes (reduced from 60)
"REFRESH_TOKEN_LIFETIME": timedelta(days=7), # 7 days
# Token rotation and blacklisting
# Security: Rotate refresh tokens on each use and blacklist old ones
"ROTATE_REFRESH_TOKENS": True,
"BLACKLIST_AFTER_ROTATION": True,
# Update last login on token refresh
"UPDATE_LAST_LOGIN": True,
# Cryptographic settings
"ALGORITHM": "HS256",
"SIGNING_KEY": SECRET_KEY,
"VERIFYING_KEY": None,
# Token validation
"AUDIENCE": None,
"ISSUER": None,
"ISSUER": "thrillwiki", # Added issuer for token validation
"JWK_URL": None,
"LEEWAY": 0,
"LEEWAY": 0, # No leeway for token expiration
# Authentication header
"AUTH_HEADER_TYPES": ("Bearer",),
"AUTH_HEADER_NAME": "HTTP_AUTHORIZATION",
# User identification
"USER_ID_FIELD": "id",
"USER_ID_CLAIM": "user_id",
"USER_AUTHENTICATION_RULE": "rest_framework_simplejwt.authentication.default_user_authentication_rule",
# Token classes
"AUTH_TOKEN_CLASSES": ("rest_framework_simplejwt.tokens.AccessToken",),
"TOKEN_TYPE_CLAIM": "token_type",
"TOKEN_USER_CLASS": "rest_framework_simplejwt.models.TokenUser",
# JTI claim for unique token identification
# Security: Enables token tracking and revocation
"JTI_CLAIM": "jti",
# Sliding token settings (if using sliding tokens)
"SLIDING_TOKEN_REFRESH_EXP_CLAIM": "refresh_exp",
"SLIDING_TOKEN_LIFETIME": timedelta(minutes=60),
"SLIDING_TOKEN_LIFETIME": timedelta(minutes=15),
"SLIDING_TOKEN_REFRESH_LIFETIME": timedelta(days=1),
}
# dj-rest-auth settings
# =============================================================================
# dj-rest-auth Settings
# =============================================================================
REST_AUTH = {
"USE_JWT": True,
"JWT_AUTH_COOKIE": "thrillwiki-auth",
"JWT_AUTH_REFRESH_COOKIE": "thrillwiki-refresh",
"JWT_AUTH_SECURE": not DEBUG, # Use secure cookies in production
# Security: Only send cookies over HTTPS in production
"JWT_AUTH_SECURE": not DEBUG,
# Security: Prevent JavaScript access to cookies
"JWT_AUTH_HTTPONLY": True,
# Security: SameSite cookie attribute (Lax is compatible with OAuth flows)
"JWT_AUTH_SAMESITE": "Lax",
"JWT_AUTH_RETURN_EXPIRATION": True,
"JWT_TOKEN_CLAIMS_SERIALIZER": "rest_framework_simplejwt.serializers.TokenObtainPairSerializer",

View File

@@ -24,14 +24,32 @@ ALLOWED_HOSTS = base.config("ALLOWED_HOSTS")
# CSRF trusted origins for production
CSRF_TRUSTED_ORIGINS = base.config("CSRF_TRUSTED_ORIGINS")
# Security settings for production
# =============================================================================
# Security Settings for Production
# =============================================================================
# SSL/HTTPS enforcement
SECURE_SSL_REDIRECT = True
SESSION_COOKIE_SECURE = True
CSRF_COOKIE_SECURE = True
# HSTS (HTTP Strict Transport Security)
SECURE_HSTS_SECONDS = 31536000 # 1 year
SECURE_HSTS_INCLUDE_SUBDOMAINS = True
SECURE_HSTS_PRELOAD = True
# Session cookie security (stricter than development)
SESSION_COOKIE_SECURE = True # Only send over HTTPS
SESSION_COOKIE_SAMESITE = "Strict" # Stricter than Lax for production
# CSRF cookie security (stricter than development)
CSRF_COOKIE_SECURE = True # Only send over HTTPS
CSRF_COOKIE_SAMESITE = "Strict" # Stricter than Lax for production
# Additional security headers
X_FRAME_OPTIONS = "DENY" # Never allow framing
SECURE_CONTENT_TYPE_NOSNIFF = True
SECURE_REFERRER_POLICY = "strict-origin-when-cross-origin"
SECURE_CROSS_ORIGIN_OPENER_POLICY = "same-origin"
# Production logging
LOGGING = {
"version": 1,

View File

@@ -1,5 +1,8 @@
"""
Security configuration for thrillwiki project.
This module configures security headers and settings to protect against common
web vulnerabilities including XSS, clickjacking, MIME sniffing, and more.
"""
import environ
@@ -14,23 +17,136 @@ TURNSTILE_VERIFY_URL = env(
default="https://challenges.cloudflare.com/turnstile/v0/siteverify",
)
# Security headers and settings (for production)
# =============================================================================
# Security Headers Configuration
# =============================================================================
# X-XSS-Protection: Enables browser's built-in XSS filter
# Note: Modern browsers are deprecating this in favor of CSP, but it's still
# useful for older browsers
SECURE_BROWSER_XSS_FILTER = env.bool("SECURE_BROWSER_XSS_FILTER", default=True)
# X-Content-Type-Options: Prevents MIME type sniffing attacks
# When True, adds "X-Content-Type-Options: nosniff" header
SECURE_CONTENT_TYPE_NOSNIFF = env.bool("SECURE_CONTENT_TYPE_NOSNIFF", default=True)
# X-Frame-Options: Protects against clickjacking attacks
# DENY = Never allow framing (most secure)
# SAMEORIGIN = Only allow framing from same origin
X_FRAME_OPTIONS = env("X_FRAME_OPTIONS", default="DENY")
# Referrer-Policy: Controls how much referrer information is sent
# strict-origin-when-cross-origin = Send full URL for same-origin,
# only origin for cross-origin, nothing for downgrade
SECURE_REFERRER_POLICY = env(
"SECURE_REFERRER_POLICY", default="strict-origin-when-cross-origin"
)
# Cross-Origin-Opener-Policy: Prevents cross-origin attacks via window references
# same-origin = Document can only be accessed by windows from same origin
SECURE_CROSS_ORIGIN_OPENER_POLICY = env(
"SECURE_CROSS_ORIGIN_OPENER_POLICY", default="same-origin"
)
# =============================================================================
# HSTS (HTTP Strict Transport Security) Configuration
# =============================================================================
# Include subdomains in HSTS policy
SECURE_HSTS_INCLUDE_SUBDOMAINS = env.bool(
"SECURE_HSTS_INCLUDE_SUBDOMAINS", default=True
)
SECURE_HSTS_SECONDS = env.int("SECURE_HSTS_SECONDS", default=31536000) # 1 year
# HSTS max-age in seconds (31536000 = 1 year, recommended minimum)
SECURE_HSTS_SECONDS = env.int("SECURE_HSTS_SECONDS", default=31536000)
# HSTS preload: Allow inclusion in browser preload lists
# Only enable after confirming HTTPS works properly for all subdomains
SECURE_HSTS_PRELOAD = env.bool("SECURE_HSTS_PRELOAD", default=False)
# URLs exempt from SSL redirect (e.g., health checks)
SECURE_REDIRECT_EXEMPT = env.list("SECURE_REDIRECT_EXEMPT", default=[])
# Redirect all HTTP requests to HTTPS
SECURE_SSL_REDIRECT = env.bool("SECURE_SSL_REDIRECT", default=False)
# Header used by proxy to indicate HTTPS (e.g., ('HTTP_X_FORWARDED_PROTO', 'https'))
SECURE_PROXY_SSL_HEADER = env.tuple("SECURE_PROXY_SSL_HEADER", default=None)
# Session security
# =============================================================================
# Session Cookie Security
# =============================================================================
# Only send session cookie over HTTPS
SESSION_COOKIE_SECURE = env.bool("SESSION_COOKIE_SECURE", default=False)
# Prevent JavaScript access to session cookie (mitigates XSS)
SESSION_COOKIE_HTTPONLY = env.bool("SESSION_COOKIE_HTTPONLY", default=True)
# SameSite attribute: Protects against CSRF attacks
# Strict = Cookie only sent for same-site requests (most secure)
# Lax = Cookie sent for same-site and top-level navigations (default)
SESSION_COOKIE_SAMESITE = env("SESSION_COOKIE_SAMESITE", default="Lax")
# CSRF security
# =============================================================================
# CSRF Cookie Security
# =============================================================================
# Only send CSRF cookie over HTTPS
CSRF_COOKIE_SECURE = env.bool("CSRF_COOKIE_SECURE", default=False)
# Prevent JavaScript access to CSRF cookie
# Note: Set to False if you need to read the token via JavaScript for AJAX
CSRF_COOKIE_HTTPONLY = env.bool("CSRF_COOKIE_HTTPONLY", default=True)
# SameSite attribute for CSRF cookie
CSRF_COOKIE_SAMESITE = env("CSRF_COOKIE_SAMESITE", default="Lax")
# =============================================================================
# File Upload Security
# =============================================================================
# Maximum size (in bytes) of file to upload into memory (2.5MB)
FILE_UPLOAD_MAX_MEMORY_SIZE = env.int(
"FILE_UPLOAD_MAX_MEMORY_SIZE", default=2621440
)
# Maximum size (in bytes) of request data (10MB)
DATA_UPLOAD_MAX_MEMORY_SIZE = env.int(
"DATA_UPLOAD_MAX_MEMORY_SIZE", default=10485760
)
# File upload permissions (0o644 = rw-r--r--)
FILE_UPLOAD_PERMISSIONS = 0o644
# Directory permissions for uploaded files (0o755 = rwxr-xr-x)
FILE_UPLOAD_DIRECTORY_PERMISSIONS = 0o755
# =============================================================================
# Permissions Policy (Feature Policy successor)
# Controls which browser features can be used
# =============================================================================
PERMISSIONS_POLICY = {
"accelerometer": [],
"ambient-light-sensor": [],
"autoplay": [],
"camera": [],
"display-capture": [],
"document-domain": [],
"encrypted-media": [],
"fullscreen": ["self"],
"geolocation": ["self"], # Required for map features
"gyroscope": [],
"interest-cohort": [], # Block FLoC
"magnetometer": [],
"microphone": [],
"midi": [],
"payment": [],
"picture-in-picture": [],
"publickey-credentials-get": [],
"screen-wake-lock": [],
"sync-xhr": [],
"usb": [],
"web-share": ["self"],
"xr-spatial-tracking": [],
}

View File

@@ -0,0 +1,261 @@
# Code Standards
This document defines the code quality standards for the ThrillWiki backend.
## Formatting & Style
### PEP 8 Compliance
All Python code must comply with PEP 8, verified using:
- **black**: Code formatting (line length: 88)
- **flake8**: Style checking (max-line-length: 88, max-complexity: 10)
- **ruff**: Fast linting and import sorting
### Running Formatters
```bash
# Format code
uv run black backend/
# Check style
uv run flake8 backend/ --max-line-length=88 --max-complexity=10
# Lint and fix
uv run ruff check backend/ --fix
```
## Docstring Requirements
### Coverage
- 100% coverage for public classes and methods
- 100% coverage for all functions
- Optional for private methods (but encouraged)
### Style
Follow Google-style docstrings:
```python
def function_name(arg1: Type1, arg2: Type2) -> ReturnType:
"""
Brief description of what this function does.
Longer description if needed, explaining the purpose,
behavior, and any important details.
Args:
arg1: Description of arg1
arg2: Description of arg2
Returns:
Description of return value
Raises:
ExceptionType: When this exception is raised
Example:
>>> function_name("value1", "value2")
"result"
"""
```
### Class Docstrings
```python
class ClassName:
"""
Brief description of what this class does.
Longer description if needed.
Attributes:
attr1: Description of attr1
attr2: Description of attr2
Example:
instance = ClassName()
instance.method()
"""
```
### View Docstrings
Views should include URL patterns and permissions:
```python
class MyView(DetailView):
"""
Brief description of what this view does.
View Type: CBV (DetailView)
URL Pattern: /resource/<slug>/
Template: app/resource_detail.html
Permissions: LoginRequired
"""
```
## Complexity Guidelines
### Limits
- **Maximum McCabe complexity**: 10
- **Maximum method length**: 50 lines
- **Maximum nesting depth**: 3 levels
### Checking Complexity
```bash
# Check McCabe complexity
uv run flake8 backend/ --max-complexity=10 --select=C901
# Get complexity metrics
uv run radon cc backend/apps/ -a
```
### Refactoring Strategies
1. **Extract helper methods** for distinct responsibilities:
```python
# Before
def process_data(self, data):
# Validate data (10 lines)
# Transform data (10 lines)
# Save data (10 lines)
# Send notifications (10 lines)
pass
# After
def process_data(self, data):
self._validate_data(data)
transformed = self._transform_data(data)
result = self._save_data(transformed)
self._send_notifications(result)
return result
```
2. **Use early returns** to reduce nesting:
```python
# Before
def process(self, data):
if data:
if data.get('field1'):
if data.get('field2'):
return result
return None
# After
def process(self, data):
if not data:
return None
if not data.get('field1'):
return None
if not data.get('field2'):
return None
return result
```
3. **Move complex logic to service layer**
## Service Layer Patterns
### Service Method Signature
Always use keyword-only arguments for service methods:
```python
class MyService:
@staticmethod
def create_entity(
*, # Force keyword-only arguments
name: str,
description: str = "",
created_by: Optional[User] = None,
) -> Entity:
"""Create a new entity."""
pass
```
### Validation Pattern
Always call `full_clean()` before save:
```python
@staticmethod
def create_park(*, name: str, ...) -> Park:
with transaction.atomic():
park = Park(name=name, ...)
park.full_clean() # Validate before save
park.save()
return park
```
## Import Organization
Imports should be organized in this order:
1. Standard library
2. Third-party packages
3. Django imports
4. Local app imports
```python
import logging
from typing import Any, Dict, Optional
from django.contrib.auth import get_user_model
from django.db import transaction
from rest_framework import status
from apps.core.exceptions import ServiceError
from .models import MyModel
```
## Type Hints
Use type hints for all function signatures:
```python
def process_data(
data: Dict[str, Any],
user: Optional[User] = None,
) -> ProcessResult:
"""Process data and return result."""
pass
```
## Testing Requirements
- Maintain or improve test coverage with changes
- Add tests for new service methods
- Add tests for new mixins and base classes
- Run tests before committing:
```bash
pytest backend/tests/ --cov=backend/apps --cov-report=html
```
## Pre-commit Configuration
The following pre-commit hooks are configured:
```yaml
repos:
- repo: https://github.com/psf/black
rev: 24.1.0
hooks:
- id: black
- repo: https://github.com/pycqa/flake8
rev: 7.1.1
hooks:
- id: flake8
args: [--max-line-length=88, --max-complexity=10]
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.12.10
hooks:
- id: ruff
args: [--fix]
```

View File

@@ -0,0 +1,219 @@
# Error Handling Guidelines
This document provides guidelines for handling errors consistently across the ThrillWiki backend.
## Exception Hierarchy
ThrillWiki uses a structured exception hierarchy defined in `apps/core/exceptions.py`:
```
ThrillWikiException (base)
├── ValidationException (400)
├── NotFoundError (404)
├── PermissionDeniedError (403)
├── BusinessLogicError (400)
├── ServiceError (500)
├── ExternalServiceError (502)
├── CacheError (500)
├── Domain-specific exceptions:
│ ├── ParkError
│ │ ├── ParkNotFoundError
│ │ └── ParkOperationError
│ ├── RideError
│ │ ├── RideNotFoundError
│ │ └── RideOperationError
│ ├── LocationError
│ │ ├── InvalidCoordinatesError
│ │ └── GeolocationError
│ ├── ReviewError
│ │ ├── ReviewModerationError
│ │ └── DuplicateReviewError
│ └── AccountError
│ ├── InsufficientPermissionsError
│ └── EmailError
```
## Using ErrorHandler
The `ErrorHandler` class in `apps/core/utils/error_handling.py` provides standardized error handling.
### Template Views
```python
from apps.core.utils.error_handling import ErrorHandler
from apps.core.exceptions import ServiceError
def my_view(request):
try:
result = SomeService.do_operation(...)
except ServiceError as e:
ErrorHandler.handle_view_error(
request,
e,
user_message="The operation failed. Please try again.",
log_message=f"Service operation failed for user {request.user.id}"
)
return redirect("some-fallback")
except ValidationError as e:
ErrorHandler.handle_view_error(
request,
e,
user_message="Invalid data provided",
level="warning"
)
return redirect("form-view")
```
### API Views
```python
from apps.core.utils.error_handling import ErrorHandler
from apps.core.exceptions import ServiceError
from rest_framework import status
class MyAPIView(APIView):
def post(self, request):
try:
result = SomeService.do_operation(...)
return ErrorHandler.api_success_response(
data=result,
message="Operation completed successfully"
)
except ServiceError as e:
return ErrorHandler.handle_api_error(
e,
user_message="Failed to complete operation",
status_code=status.HTTP_400_BAD_REQUEST
)
```
## Best Practices
### 1. Always Catch Specific Exceptions
```python
# Good
try:
park = ParkService.create_park(...)
except ParkOperationError as e:
# Handle park-specific error
pass
except ValidationException as e:
# Handle validation error
pass
# Bad
try:
park = ParkService.create_park(...)
except Exception as e:
# Too broad - loses error context
pass
```
### 2. Log with Appropriate Context
```python
# Good
logger.error(
f"Park creation failed for user {user.id}: {error}",
exc_info=True,
extra={"user_id": user.id, "park_name": name}
)
# Bad
logger.error(f"Error: {error}")
```
### 3. Provide Clear User Messages
```python
# Good - User-friendly and actionable
ErrorHandler.handle_view_error(
request,
error,
user_message="Unable to save your changes. Please check your input and try again."
)
# Bad - Technical details exposed to user
ErrorHandler.handle_view_error(
request,
error,
user_message=f"IntegrityError: UNIQUE constraint failed: parks_park.slug"
)
```
### 4. Use Appropriate HTTP Status Codes
| Error Type | Status Code | When to Use |
|------------|-------------|-------------|
| ValidationException | 400 | Invalid user input |
| NotFoundError | 404 | Resource doesn't exist |
| PermissionDeniedError | 403 | User lacks permission |
| BusinessLogicError | 400 | Business rule violation |
| ServiceError | 500 | Internal service failure |
| ExternalServiceError | 502 | Third-party service failure |
### 5. Never Use Bare `except:` Clauses
```python
# Never do this
try:
something()
except:
pass
# Always specify exception type
try:
something()
except SpecificException:
handle_error()
```
## Error Response Format
### API Error Response
```json
{
"error": "User-friendly error message",
"detail": "Technical error details",
"error_code": "SPECIFIC_ERROR_CODE",
"details": {
"field": "Additional context"
}
}
```
### API Success Response
```json
{
"status": "success",
"message": "Operation completed successfully",
"data": {
// Response data
}
}
```
## Creating Custom Exceptions
When creating domain-specific exceptions:
```python
from apps.core.exceptions import BusinessLogicError
class MyDomainError(BusinessLogicError):
"""Raised when my domain operation fails."""
default_message = "My domain operation failed"
error_code = "MY_DOMAIN_ERROR"
status_code = 400
def __init__(self, context_value: str = None, **kwargs):
if context_value:
kwargs["details"] = {"context": context_value}
kwargs["message"] = f"Operation failed for: {context_value}"
super().__init__(**kwargs)
```

View File

@@ -0,0 +1,158 @@
# View Pattern Guidelines
This document provides guidelines for implementing views in the ThrillWiki backend.
## When to Use CBVs
Use Class-Based Views for:
- CRUD operations (CreateView, UpdateView, DetailView, ListView, DeleteView)
- Complex views with multiple methods
- Views that benefit from inheritance and mixins
- Views with shared context or queryset logic
### CBV Examples
```python
class ParkDetailView(OptimizedDetailView):
"""Display park details with related data."""
model = Park
template_name = "parks/park_detail.html"
select_related_fields = ["location", "operator"]
prefetch_related_fields = ["photos", "rides"]
```
## When to Use FBVs
Use Function-Based Views for:
- Simple HTMX partial renders
- Single-purpose utility endpoints
- Search/autocomplete endpoints
- Status badge/action endpoints
### FBV Examples
```python
def search_parks(request: HttpRequest) -> HttpResponse:
"""
HTMX endpoint for park search autocomplete.
View Type: FBV (HTMX Partial)
URL Pattern: /parks/search/
Returns: HTML partial
"""
query = request.GET.get("q", "").strip()
parks = Park.objects.filter(name__icontains=query)[:10]
return render(request, "parks/partials/search_results.html", {"parks": parks})
```
## When to Use ViewSets
Use DRF ViewSets for:
- REST API endpoints
- Resources with standard CRUD operations
- Resources requiring nested routing
### ViewSet Examples
```python
class ParkPhotoViewSet(ModelViewSet):
"""ViewSet for managing park photos via API."""
queryset = ParkPhoto.objects.all()
serializer_class = ParkPhotoSerializer
permission_classes = [IsAuthenticated]
```
## View Inventory
### Parks App
- **CBVs**: ParkListView, ParkCreateView, ParkUpdateView, ParkDetailView, ParkAreaDetailView, OperatorListView
- **FBVs**: geocode_location, reverse_geocode, search_parks, roadtrip_* (HTMX)
### Rides App
- **CBVs**: RideDetailView, RideCreateView, RideUpdateView, RideListView, SingleCategoryListView, RideRankingsView, RideRankingDetailView, ManufacturerListView, DesignerListView
- **FBVs**: show_coaster_fields, ride_status_actions, ride_header_badge, search_companies, search_ride_models, get_search_suggestions, ranking_history_chart, ranking_comparisons
### Accounts App
- **CBVs**: ProfileView, SettingsView, CustomLoginView, CustomSignupView
- **FBVs**: user_redirect_view, email_required, request_password_reset, reset_password
### Moderation App
- **ViewSets**: ModerationReportViewSet, ModerationQueueViewSet, ModerationActionViewSet, BulkOperationViewSet, UserModerationViewSet
## View Type Indicators
Always include view type information in docstrings:
```python
class ParkDetailView(DetailView):
"""
Display park details with related data.
View Type: CBV (DetailView)
URL Pattern: /parks/<slug>/
Template: parks/park_detail.html
Permissions: Public
"""
def search_parks(request: HttpRequest) -> HttpResponse:
"""
HTMX endpoint for park search autocomplete.
View Type: FBV (HTMX Partial)
URL Pattern: /parks/search/
Returns: HTML partial
"""
```
## Base Classes
### OptimizedListView
Automatically applies select_related and prefetch_related based on class attributes.
```python
from apps.core.views.base import OptimizedListView
class RideListView(OptimizedListView):
model = Ride
select_related_fields = ['park', 'manufacturer']
prefetch_related_fields = ['photos']
```
### OptimizedDetailView
Automatically applies select_related and prefetch_related based on class attributes.
```python
from apps.core.views.base import OptimizedDetailView
class RideDetailView(OptimizedDetailView):
model = Ride
select_related_fields = ['park', 'park__location', 'manufacturer']
prefetch_related_fields = ['photos', 'coaster_stats']
```
## Mixins
### RideFormMixin
Handles entity suggestions in ride forms (manufacturers, designers, models).
```python
from apps.rides.mixins import RideFormMixin
class RideCreateView(RideFormMixin, CreateView):
def form_valid(self, form):
self.handle_entity_suggestions(form)
return super().form_valid(form)
```
## Best Practices
1. **Prefer CBVs for CRUD** - Use Django's built-in generic views
2. **Prefer FBVs for HTMX** - Simple endpoints are easier to understand as functions
3. **Document view types** - Include View Type in all docstrings
4. **Use mixins for shared logic** - Avoid code duplication
5. **Use base classes for query optimization** - OptimizedListView, OptimizedDetailView

View File

@@ -0,0 +1,204 @@
#!/usr/bin/env python
"""
Query Performance Benchmarking Script
This script measures query performance for key operations in the ThrillWiki
application to help identify optimization opportunities and verify improvements.
Usage:
python manage.py shell < scripts/benchmark_queries.py
Or in Django shell:
exec(open('scripts/benchmark_queries.py').read())
"""
import time
import statistics
from functools import wraps
from typing import Callable, Any, List, Dict
from django.db import connection, reset_queries
from django.test.utils import CaptureQueriesContext
from django.conf import settings
# Ensure debug mode for query logging
if not settings.DEBUG:
print("Warning: DEBUG mode is not enabled. Query counts may not be accurate.")
def benchmark(name: str, iterations: int = 5):
"""Decorator to benchmark a function."""
def decorator(func: Callable) -> Callable:
@wraps(func)
def wrapper(*args, **kwargs) -> Dict[str, Any]:
times = []
query_counts = []
for _ in range(iterations):
reset_queries()
with CaptureQueriesContext(connection) as context:
start = time.perf_counter()
result = func(*args, **kwargs)
end = time.perf_counter()
times.append((end - start) * 1000) # Convert to ms
query_counts.append(len(context.captured_queries))
return {
'name': name,
'avg_time_ms': statistics.mean(times),
'min_time_ms': min(times),
'max_time_ms': max(times),
'std_dev_ms': statistics.stdev(times) if len(times) > 1 else 0,
'avg_queries': statistics.mean(query_counts),
'min_queries': min(query_counts),
'max_queries': max(query_counts),
'iterations': iterations,
}
return wrapper
return decorator
def print_benchmark_result(result: Dict[str, Any]) -> None:
"""Pretty print benchmark results."""
print(f"\n{'='*60}")
print(f"Benchmark: {result['name']}")
print(f"{'='*60}")
print(f" Time (ms): avg={result['avg_time_ms']:.2f}, min={result['min_time_ms']:.2f}, max={result['max_time_ms']:.2f}")
print(f" Std Dev (ms): {result['std_dev_ms']:.2f}")
print(f" Queries: avg={result['avg_queries']:.1f}, min={result['min_queries']}, max={result['max_queries']}")
print(f" Iterations: {result['iterations']}")
def run_benchmarks() -> List[Dict[str, Any]]:
"""Run all benchmarks and return results."""
from apps.parks.models import Park, Company
from apps.rides.models import Ride
results = []
# Park List Optimized
@benchmark("Park.objects.optimized_for_list()")
def bench_park_list_optimized():
parks = Park.objects.optimized_for_list()[:50]
for park in parks:
_ = park.operator
_ = park.coaster_count_calculated if hasattr(park, 'coaster_count_calculated') else None
return list(parks)
results.append(bench_park_list_optimized())
# Park List Non-Optimized (for comparison)
@benchmark("Park.objects.all() (non-optimized)")
def bench_park_list_non_optimized():
parks = Park.objects.all()[:50]
for park in parks:
_ = park.operator # This will cause N+1 queries
return list(parks)
results.append(bench_park_list_non_optimized())
# Park Detail Optimized
@benchmark("Park.objects.optimized_for_detail()")
def bench_park_detail_optimized():
park = Park.objects.optimized_for_detail().first()
if park:
_ = park.operator
_ = list(park.areas.all())
_ = list(park.rides.all())
return park
results.append(bench_park_detail_optimized())
# Ride List Optimized
@benchmark("Ride.objects.optimized_for_list()")
def bench_ride_list_optimized():
rides = Ride.objects.optimized_for_list()[:50]
for ride in rides:
_ = ride.park
_ = ride.manufacturer
return list(rides)
results.append(bench_ride_list_optimized())
# Ride Detail with Coaster Stats
@benchmark("Ride.objects.with_coaster_stats()")
def bench_ride_with_coaster_stats():
rides = Ride.objects.with_coaster_stats()[:20]
for ride in rides:
_ = ride.park
try:
_ = ride.coaster_stats
except Exception:
pass
return list(rides)
results.append(bench_ride_with_coaster_stats())
# Company Manufacturers with Ride Count
@benchmark("Company.objects.manufacturers_with_ride_count()")
def bench_manufacturers_with_count():
companies = Company.objects.manufacturers_with_ride_count()[:20]
for company in companies:
_ = company.ride_count
return list(companies)
results.append(bench_manufacturers_with_count())
# Park Search Autocomplete
@benchmark("Park search_autocomplete()")
def bench_park_autocomplete():
result = Park.objects.get_queryset().search_autocomplete(query="park", limit=10)
return list(result)
results.append(bench_park_autocomplete())
# Park Map Display
@benchmark("Park.objects.for_map_display()")
def bench_park_map_display():
result = Park.objects.for_map_display()
return list(result)
results.append(bench_park_map_display())
return results
def print_summary(results: List[Dict[str, Any]]) -> None:
"""Print a summary table of all benchmarks."""
print("\n" + "="*80)
print("BENCHMARK SUMMARY")
print("="*80)
print(f"{'Benchmark':<45} {'Avg Time (ms)':<15} {'Avg Queries':<15}")
print("-"*80)
for result in results:
print(f"{result['name']:<45} {result['avg_time_ms']:<15.2f} {result['avg_queries']:<15.1f}")
print("="*80)
if __name__ == "__main__" or True: # Always run when executed
print("\n" + "="*80)
print("THRILLWIKI QUERY PERFORMANCE BENCHMARKS")
print("="*80)
print("\nRunning benchmarks...")
try:
results = run_benchmarks()
# Print individual results
for result in results:
print_benchmark_result(result)
# Print summary
print_summary(results)
print("\nBenchmarks complete!")
except Exception as e:
print(f"\nError running benchmarks: {e}")
import traceback
traceback.print_exc()

View File

@@ -1,3 +1,69 @@
/**
* ThrillWiki Main JavaScript
*
* This file contains core functionality including:
* - CSRF token handling for HTMX and AJAX requests
* - Theme management
* - Mobile menu functionality
* - Flash message handling
* - Tooltip initialization
*/
// =============================================================================
// CSRF Token Handling
// =============================================================================
/**
* Get CSRF token from cookies.
* Django sets the CSRF token in a cookie named 'csrftoken'.
*
* @returns {string|null} The CSRF token or null if not found
*/
function getCookie(name) {
let cookieValue = null;
if (document.cookie && document.cookie !== '') {
const cookies = document.cookie.split(';');
for (let i = 0; i < cookies.length; i++) {
const cookie = cookies[i].trim();
if (cookie.substring(0, name.length + 1) === (name + '=')) {
cookieValue = decodeURIComponent(cookie.substring(name.length + 1));
break;
}
}
}
return cookieValue;
}
/**
* Configure HTMX to include CSRF token in all requests.
* This handler runs before every HTMX request and adds the X-CSRFToken header.
*/
document.body.addEventListener('htmx:configRequest', (event) => {
// Only add CSRF token for state-changing methods
const method = event.detail.verb.toUpperCase();
if (['POST', 'PUT', 'PATCH', 'DELETE'].includes(method)) {
// Try to get token from cookie first
const csrfToken = getCookie('csrftoken');
if (csrfToken) {
event.detail.headers['X-CSRFToken'] = csrfToken;
} else {
// Fallback: try to get from meta tag or hidden input
const metaToken = document.querySelector('meta[name="csrf-token"]');
const inputToken = document.querySelector('input[name="csrfmiddlewaretoken"]');
if (metaToken) {
event.detail.headers['X-CSRFToken'] = metaToken.getAttribute('content');
} else if (inputToken) {
event.detail.headers['X-CSRFToken'] = inputToken.value;
}
}
}
});
// =============================================================================
// Theme Handling
// =============================================================================
// Theme handling
document.addEventListener('DOMContentLoaded', () => {
const themeToggle = document.getElementById('theme-toggle');

View File

@@ -1,7 +1,10 @@
{% comment %}
Card Component - Django Template Version of shadcn/ui Card
Usage: {% include 'components/ui/card.html' with title='Card Title' content='Card content' %}
Security: All content variables are sanitized to prevent XSS attacks.
{% endcomment %}
{% load safe_html %}
<div class="rounded-lg border bg-card text-card-foreground shadow-sm {{ class|default:'' }}">
{% if title or header_content %}
@@ -13,25 +16,25 @@ Usage: {% include 'components/ui/card.html' with title='Card Title' content='Car
<p class="text-sm text-muted-foreground">{{ description }}</p>
{% endif %}
{% if header_content %}
{{ header_content|safe }}
{{ header_content|sanitize }}
{% endif %}
</div>
{% endif %}
{% if content or body_content %}
<div class="p-6 pt-0">
{% if content %}
{{ content|safe }}
{{ content|sanitize }}
{% endif %}
{% if body_content %}
{{ body_content|safe }}
{{ body_content|sanitize }}
{% endif %}
</div>
{% endif %}
{% if footer_content %}
<div class="flex items-center p-6 pt-0">
{{ footer_content|safe }}
{{ footer_content|sanitize }}
</div>
{% endif %}
</div>

View File

@@ -316,6 +316,11 @@
<!-- Leaflet JS -->
<script src="https://unpkg.com/leaflet@1.9.4/dist/leaflet.js"></script>
{# Security: Safely embed JSON data using Django's json_script tag to prevent XSS #}
{% if nearby_locations %}
{{ nearby_locations|json_script:"nearby-locations-data" }}
{% endif %}
<script>
// Nearby locations map class
class NearbyMap {
@@ -436,11 +441,14 @@ class NearbyMap {
addLocationMarkers() {
{% if nearby_locations %}
const locations = {{ nearby_locations|safe }};
locations.forEach(location => {
this.addLocationMarker(location);
});
{# Security: Use json_script to safely embed JSON data without XSS risk #}
const locationsData = document.getElementById('nearby-locations-data');
if (locationsData) {
const locations = JSON.parse(locationsData.textContent);
locations.forEach(location => {
this.addLocationMarker(location);
});
}
{% endif %}
}

View File

@@ -77,7 +77,13 @@
{% endif %}
{% if show_trip_action %}
<button onclick="addToTrip({{ location|safe }})"
{# Security: Use data attributes instead of inline JS with |safe #}
<button onclick="addToTripFromElement(this)"
data-location-id="{{ location.id }}"
data-location-type="{{ location.type }}"
data-location-name="{{ location.name }}"
data-location-lat="{{ location.latitude }}"
data-location-lng="{{ location.longitude }}"
class="px-3 py-2 text-sm text-purple-600 border border-purple-600 rounded-lg hover:bg-purple-50 dark:hover:bg-purple-900 transition-colors"
title="Add to trip">
<i class="fas fa-plus"></i>
@@ -316,6 +322,19 @@ window.addToTrip = function(locationData) {
document.dispatchEvent(event);
};
// Security: Helper function to extract location data from element attributes
// instead of using inline JavaScript with unsanitized data
window.addToTripFromElement = function(element) {
const locationData = {
id: element.dataset.locationId,
type: element.dataset.locationType,
name: element.dataset.locationName,
latitude: parseFloat(element.dataset.locationLat),
longitude: parseFloat(element.dataset.locationLng)
};
addToTrip(locationData);
};
// Handle location card selection
document.addEventListener('DOMContentLoaded', function() {
document.addEventListener('click', function(e) {

View File

@@ -150,9 +150,11 @@
{% endif %}
<!-- Custom Content -->
{% load safe_html %}
{% if custom_content %}
<div class="popup-custom">
{{ custom_content|safe }}
{# Security: Sanitize custom content to prevent XSS #}
{{ custom_content|sanitize }}
</div>
{% endif %}
@@ -180,7 +182,13 @@
{% endif %}
{% if show_trip_button %}
<button onclick="addLocationToTrip({{ location|safe }})"
{# Security: Use data attribute for location data instead of inline JS with |safe #}
<button onclick="addLocationToTripFromElement(this)"
data-location-id="{{ location.id }}"
data-location-type="{{ location.type }}"
data-location-name="{{ location.name }}"
data-location-lat="{{ location.latitude }}"
data-location-lng="{{ location.longitude }}"
class="popup-btn popup-btn-accent">
<i class="mr-1 fas fa-plus"></i>{{ trip_button_text|default:"Add to Trip" }}
</button>
@@ -455,11 +463,24 @@ window.addLocationToTrip = function(locationData) {
detail: locationData
});
document.dispatchEvent(event);
// Show feedback
showPopupFeedback('Added to trip!', 'success');
};
// Security: Helper function to extract location data from element attributes
// instead of using inline JavaScript with unsanitized data
window.addLocationToTripFromElement = function(element) {
const locationData = {
id: element.dataset.locationId,
type: element.dataset.locationType,
name: element.dataset.locationName,
latitude: parseFloat(element.dataset.locationLat),
longitude: parseFloat(element.dataset.locationLng)
};
addLocationToTrip(locationData);
};
window.shareLocation = function(type, id) {
// Share location URL
const url = window.location.origin + `/{{ type }}/${id}/`;

193
docs/SECURITY.md Normal file
View File

@@ -0,0 +1,193 @@
# ThrillWiki Security Documentation
This document describes the security measures implemented in ThrillWiki and provides guidelines for maintaining security.
## Security Architecture Overview
ThrillWiki implements defense-in-depth security with multiple layers of protection:
1. **Network Layer**: HTTPS, HSTS, security headers
2. **Application Layer**: Input validation, output encoding, CSRF protection
3. **Authentication Layer**: JWT tokens, rate limiting, session management
4. **Data Layer**: SQL injection prevention, data sanitization
## Security Features
### HTTP Security Headers
The following security headers are configured:
| Header | Value | Purpose |
|--------|-------|---------|
| X-Frame-Options | DENY | Prevents clickjacking attacks |
| X-Content-Type-Options | nosniff | Prevents MIME sniffing |
| Referrer-Policy | strict-origin-when-cross-origin | Controls referrer information |
| Content-Security-Policy | Configured | Controls resource loading |
| Permissions-Policy | Configured | Restricts browser features |
### XSS Prevention
- All user input is escaped by Django's template engine by default
- Custom `|sanitize` filter for user-generated HTML content
- SVG sanitization for icon rendering
- JavaScript data is serialized using Django's `json_script` tag
### CSRF Protection
- Django's CSRF middleware is enabled
- CSRF tokens required for all state-changing requests
- HTMX requests automatically include CSRF tokens via `htmx:configRequest` event
- SameSite cookie attribute set to prevent CSRF attacks
### SQL Injection Prevention
- All database queries use Django ORM
- No raw SQL with user input
- `.extra()` calls replaced with Django ORM functions
- Management commands use parameterized queries
### File Upload Security
- File type validation (extension, MIME type, magic number)
- File size limits (10MB max)
- Image integrity validation using PIL
- Rate limiting (10 uploads per minute)
- Secure filename generation
### Authentication & Authorization
- JWT-based authentication with short-lived access tokens (15 minutes)
- Refresh token rotation with blacklisting
- Rate limiting on authentication endpoints:
- Login: 5 per minute, 30 per hour
- Signup: 3 per minute, 10 per hour
- Password reset: 2 per minute, 5 per hour
- Permission-based access control
### Session Security
- Redis-backed session storage
- 1-hour session timeout with sliding expiry
- HTTPOnly cookies prevent JavaScript access
- Secure cookies in production (HTTPS only)
- SameSite attribute set to prevent CSRF
### Sensitive Data Protection
- Passwords hashed with Django's PBKDF2
- Sensitive fields masked in logs
- Email addresses partially masked in logs
- Error messages don't expose internal details in production
- DEBUG mode disabled in production
## Security Configuration
### Environment Variables
The following environment variables should be set for production:
```bash
DEBUG=False
SECRET_KEY=<strong-random-key>
ALLOWED_HOSTS=yourdomain.com,www.yourdomain.com
CSRF_TRUSTED_ORIGINS=https://yourdomain.com
DATABASE_URL=<secure-database-url>
```
### Production Checklist
Before deploying to production:
- [ ] DEBUG is False
- [ ] SECRET_KEY is a strong, random value
- [ ] ALLOWED_HOSTS is configured
- [ ] HTTPS is enabled (SECURE_SSL_REDIRECT=True)
- [ ] HSTS is enabled (SECURE_HSTS_SECONDS >= 31536000)
- [ ] Secure cookies enabled (SESSION_COOKIE_SECURE=True)
- [ ] Database uses SSL connection
- [ ] Error emails configured (ADMINS setting)
## Security Audit
Run the security audit command:
```bash
python manage.py security_audit --verbose
```
This checks:
- Django security settings
- Configuration analysis
- Middleware configuration
## Vulnerability Reporting
To report a security vulnerability:
1. **Do not** open a public issue
2. Email security concerns to: [security contact]
3. Include:
- Description of the vulnerability
- Steps to reproduce
- Potential impact
- Any suggested fixes
## Security Updates
- Keep Django and dependencies updated
- Monitor security advisories
- Review OWASP Top 10 periodically
- Run security scans (OWASP ZAP, etc.)
## Code Security Guidelines
### Input Validation
```python
# Always validate user input
from apps.core.utils.html_sanitizer import sanitize_html
user_input = request.data.get('content')
safe_content = sanitize_html(user_input)
```
### Template Security
```html
<!-- Use sanitize filter for user content -->
{% load safe_html %}
{{ user_content|sanitize }}
<!-- Use json_script for JavaScript data -->
{{ data|json_script:"data-id" }}
```
### File Uploads
```python
from apps.core.utils.file_scanner import validate_image_upload
try:
validate_image_upload(uploaded_file)
except FileValidationError as e:
return error_response(str(e))
```
### Logging
```python
# Don't log sensitive data
logger.info(f"User {user.id} logged in") # OK
logger.info(f"Password: {password}") # BAD
```
## Dependencies
Security-relevant dependencies:
- `bleach`: HTML sanitization
- `Pillow`: Image validation
- `djangorestframework-simplejwt`: JWT authentication
- `django-cors-headers`: CORS configuration
Keep these updated to patch security vulnerabilities.

155
docs/SECURITY_CHECKLIST.md Normal file
View File

@@ -0,0 +1,155 @@
# ThrillWiki Security Checklist
Use this checklist for code reviews and pre-deployment verification.
## Pre-Deployment Checklist
### Django Settings
- [ ] `DEBUG = False`
- [ ] `SECRET_KEY` is unique and strong (50+ characters)
- [ ] `ALLOWED_HOSTS` is configured (no wildcards)
- [ ] `CSRF_TRUSTED_ORIGINS` is configured
- [ ] `SECURE_SSL_REDIRECT = True`
- [ ] `SECURE_HSTS_SECONDS >= 31536000` (1 year)
- [ ] `SECURE_HSTS_INCLUDE_SUBDOMAINS = True`
- [ ] `SECURE_HSTS_PRELOAD = True`
### Cookie Security
- [ ] `SESSION_COOKIE_SECURE = True`
- [ ] `SESSION_COOKIE_HTTPONLY = True`
- [ ] `SESSION_COOKIE_SAMESITE = 'Strict'`
- [ ] `CSRF_COOKIE_SECURE = True`
- [ ] `CSRF_COOKIE_SAMESITE = 'Strict'`
### Database
- [ ] Database password is strong
- [ ] Database connection uses SSL
- [ ] Database user has minimal required permissions
- [ ] No raw SQL with user input
### Environment
- [ ] Environment variables are used for secrets
- [ ] No secrets in version control
- [ ] `.env` file is in `.gitignore`
- [ ] Production logs don't contain sensitive data
## Code Review Checklist
### Input Validation
- [ ] All user input is validated
- [ ] File uploads use `validate_image_upload()`
- [ ] User-generated HTML uses `|sanitize` filter
- [ ] URLs are validated with `sanitize_url()`
- [ ] Form data uses Django forms/serializers
### Output Encoding
- [ ] No `|safe` filter on user-controlled content
- [ ] JSON data uses `json_script` tag
- [ ] JavaScript strings use `escapejs` filter
- [ ] SVG icons use `|sanitize_svg` filter
### Authentication
- [ ] Sensitive views require `@login_required`
- [ ] API views have appropriate `permission_classes`
- [ ] Password changes invalidate sessions
- [ ] Rate limiting on auth endpoints
### Authorization
- [ ] Object-level permissions checked
- [ ] Users can only access their own data
- [ ] Admin actions require proper permissions
- [ ] No privilege escalation paths
### Data Protection
- [ ] Sensitive data not logged
- [ ] PII masked in logs
- [ ] Error messages don't expose internals
- [ ] Secure deletion of sensitive data
### CSRF
- [ ] All forms include `{% csrf_token %}`
- [ ] AJAX requests include CSRF header
- [ ] CSRF exemptions are documented and justified
### SQL Injection
- [ ] No raw SQL with user input
- [ ] No `.extra()` with user input
- [ ] Parameterized queries for raw SQL
- [ ] Django ORM used for queries
## Incident Response
### If a Vulnerability is Found
1. [ ] Document the vulnerability
2. [ ] Assess impact and affected users
3. [ ] Develop and test a fix
4. [ ] Deploy fix to production
5. [ ] Notify affected users if needed
6. [ ] Post-mortem analysis
### If a Breach is Suspected
1. [ ] Isolate affected systems
2. [ ] Preserve logs and evidence
3. [ ] Notify relevant parties
4. [ ] Investigate scope
5. [ ] Remediate and restore
6. [ ] Document lessons learned
## Regular Security Tasks
### Weekly
- [ ] Review error logs for anomalies
- [ ] Check rate limiting effectiveness
- [ ] Monitor failed login attempts
### Monthly
- [ ] Run `python manage.py security_audit`
- [ ] Review and update dependencies
- [ ] Check for security advisories
### Quarterly
- [ ] Full security review
- [ ] Penetration testing
- [ ] Update security documentation
- [ ] Review and rotate secrets
## Security Tools
### Recommended Tools
- **OWASP ZAP**: Web application scanner
- **bandit**: Python security linter
- **safety**: Python dependency checker
- **pip-audit**: Vulnerability scanner for Python packages
### Running Security Scans
```bash
# Run Django security check
python manage.py check --tag=security
# Run security audit
python manage.py security_audit --verbose
# Check for vulnerable dependencies
pip-audit
# Run Python security linter
bandit -r backend/
```

View File

@@ -0,0 +1,290 @@
# Database Optimization Guide
This document describes the database optimization strategies implemented in ThrillWiki.
## Overview
The application uses several optimization techniques to ensure fast query performance:
1. **Indexing Strategy** - Strategic use of B-tree, GIN, and composite indexes
2. **Query Optimization** - Proper use of `select_related` and `prefetch_related`
3. **Computed Fields** - Pre-computed values for common aggregations
4. **Manager Methods** - Optimized query patterns encapsulated in managers
## Indexing Strategy
### B-tree Indexes (Standard)
Standard B-tree indexes are used for fields that are frequently filtered or sorted:
| Model | Field | Index Type | Purpose |
|-------|-------|------------|---------|
| User | `is_banned` | B-tree | Fast filtering of banned users |
| User | `role` | B-tree | Fast filtering by user role |
| User | `(is_banned, role)` | Composite | Common query pattern |
| Park | `status` | B-tree | Filter by park status |
| Park | `search_text` | GIN trigram | Full-text search |
| Ride | `status` | B-tree | Filter by ride status |
| Ride | `search_text` | GIN trigram | Full-text search |
### GIN Indexes
GIN (Generalized Inverted Index) indexes are used for array fields and full-text search:
| Model | Field | Purpose |
|-------|-------|---------|
| Company | `roles` | Fast array containment queries (`roles__contains=["MANUFACTURER"]`) |
| Park | `search_text` | Full-text search with trigram similarity |
| Ride | `search_text` | Full-text search with trigram similarity |
### Creating GIN Indexes
```sql
-- Array containment index
CREATE INDEX IF NOT EXISTS parks_company_roles_gin_idx
ON parks_company USING gin(roles);
-- Full-text search index (if using tsvector)
CREATE INDEX IF NOT EXISTS parks_park_search_idx
ON parks_park USING gin(search_text gin_trgm_ops);
```
## Query Optimization Patterns
### Manager Methods
The application uses custom managers with optimized query methods:
#### Park Queries
```python
# List view - includes prefetched relations and stats
parks = Park.objects.optimized_for_list()
# Detail view - deep prefetching for all related data
park = Park.objects.optimized_for_detail().get(slug='magic-kingdom')
# Map display - minimal fields for markers
parks = Park.objects.for_map_display()
# Autocomplete - limited fields, fast lookup
results = Park.objects.get_queryset().search_autocomplete(query='disney', limit=10)
```
#### Ride Queries
```python
# List view with related objects
rides = Ride.objects.optimized_for_list()
# Detail view with stats
ride = Ride.objects.optimized_for_detail().get(slug='space-mountain')
# With coaster statistics
rides = Ride.objects.with_coaster_stats().filter(category='RC')
```
#### Company Queries
```python
# Manufacturers with ride counts
manufacturers = Company.objects.manufacturers_with_ride_count()
# Designers with ride counts
designers = Company.objects.designers_with_ride_count()
# Operators with park counts
operators = Company.objects.operators_with_park_count()
```
### Avoiding N+1 Queries
Always use the optimized manager methods instead of raw queries:
```python
# BAD - causes N+1 queries
for park in Park.objects.all():
print(park.operator.name) # Each access hits DB
# GOOD - single query with prefetch
for park in Park.objects.optimized_for_list():
print(park.operator.name) # Already loaded
```
### Using only() for Minimal Data
When you only need specific fields, use `only()`:
```python
# Only fetch necessary fields
companies = Company.objects.filter(roles__contains=["MANUFACTURER"]).only(
'id', 'name', 'slug', 'roles'
)
```
## Computed Fields
### Park Computed Fields
| Field | Description | Updated When |
|-------|-------------|--------------|
| `ride_count` | Number of operating rides | Ride created/deleted/status changed |
| `coaster_count` | Number of operating coasters | Ride created/deleted/status changed |
| `opening_year` | Year extracted from opening_date | Park saved with opening_date |
| `search_text` | Combined searchable text | Park/Location/Company name changes |
### Ride Computed Fields
| Field | Description | Updated When |
|-------|-------------|--------------|
| `opening_year` | Year extracted from opening_date | Ride saved with opening_date |
| `search_text` | Combined searchable text | Ride/Park/Company/RideModel changes |
### Signal Handlers
Signals automatically update computed fields:
```python
# When a park location changes, update search_text
@receiver(post_save, sender='parks.ParkLocation')
def update_park_search_text_on_location_change(sender, instance, **kwargs):
if hasattr(instance, 'park') and instance.park:
instance.park._populate_computed_fields()
instance.park.save(update_fields=['search_text'])
```
## CheckConstraints
Database-level constraints ensure data integrity:
### User Constraints
```python
# Banned users must have a ban_date
models.CheckConstraint(
name='user_ban_consistency',
check=models.Q(is_banned=False) | models.Q(ban_date__isnull=False),
)
```
### RideModel Constraints
```python
# Unique name per manufacturer
models.UniqueConstraint(
fields=['manufacturer', 'name'],
name='ridemodel_manufacturer_name_unique',
)
# Installation year range must be valid
models.CheckConstraint(
name="ride_model_installation_years_logical",
condition=models.Q(first_installation_year__isnull=True) |
models.Q(last_installation_year__isnull=True) |
models.Q(first_installation_year__lte=models.F("last_installation_year")),
)
```
## Performance Benchmarking
Use the benchmark script to measure query performance:
```bash
# Run benchmarks
python manage.py shell < scripts/benchmark_queries.py
```
Key metrics to monitor:
- Average query time (< 100ms for list views, < 50ms for detail views)
- Number of queries per operation (avoid N+1 patterns)
- Index usage (check query plans)
## Migration Best Practices
### Adding Indexes
```python
# Use RunSQL for GIN indexes (not natively supported by Django)
migrations.RunSQL(
sql="CREATE INDEX IF NOT EXISTS ... USING gin(...)",
reverse_sql="DROP INDEX IF EXISTS ..."
)
```
### Adding Constraints
```python
# Use AddConstraint for proper dependency handling
migrations.AddConstraint(
model_name='user',
constraint=models.CheckConstraint(...)
)
```
### Rollback Procedures
Each migration should be reversible:
```bash
# Rollback specific migration
python manage.py migrate accounts 0012
# Check migration plan before applying
python manage.py migrate --plan
```
## Monitoring
### Query Analysis
Enable query logging in development:
```python
LOGGING = {
'handlers': {
'console': {'class': 'logging.StreamHandler'},
},
'loggers': {
'django.db.backends': {
'level': 'DEBUG',
'handlers': ['console'],
}
}
}
```
### Index Usage
Check if indexes are being used:
```sql
EXPLAIN ANALYZE SELECT * FROM parks_park WHERE status = 'OPERATING';
```
## Quick Reference
### Common Query Patterns
| Operation | Method |
|-----------|--------|
| Park list page | `Park.objects.optimized_for_list()` |
| Park detail page | `Park.objects.optimized_for_detail()` |
| Map markers | `Park.objects.for_map_display()` |
| Search autocomplete | `Park.objects.get_queryset().search_autocomplete()` |
| Ride list page | `Ride.objects.optimized_for_list()` |
| Ride detail page | `Ride.objects.optimized_for_detail()` |
| Manufacturer list | `Company.objects.manufacturers_with_ride_count()` |
| Operator list | `Company.objects.operators_with_park_count()` |
### Index Commands
```sql
-- List all indexes for a table
\di+ parks_park*
-- Check index usage statistics
SELECT * FROM pg_stat_user_indexes WHERE relname = 'parks_park';
-- Rebuild an index
REINDEX INDEX parks_company_roles_gin_idx;
```

View File

@@ -1,3 +1,11 @@
"""
Photo upload and management views for ThrillWiki.
Security Note:
All uploads are validated for file type, size, and content before being saved.
Rate limiting is enforced to prevent abuse.
"""
from django.http import JsonResponse
from django.views.decorators.http import require_http_methods
from django.contrib.auth.decorators import login_required
@@ -7,6 +15,12 @@ import json
import logging
from .models import Photo
from apps.core.utils.file_scanner import (
validate_image_upload,
FileValidationError,
check_upload_rate_limit,
increment_upload_count,
)
logger = logging.getLogger(__name__)
@@ -14,18 +28,37 @@ logger = logging.getLogger(__name__)
@login_required
@require_http_methods(["POST"])
def upload_photo(request):
"""Handle photo upload for any model"""
"""
Handle photo upload for any model.
Security measures:
- Rate limiting to prevent abuse
- File type validation (extension, MIME type, magic number)
- File size validation
- Image integrity validation
- Permission checks
"""
try:
# Security: Check rate limiting before processing upload
is_allowed, rate_limit_message = check_upload_rate_limit(request.user.id)
if not is_allowed:
logger.warning(
f"User {request.user} exceeded upload rate limit"
)
return JsonResponse(
{"error": rate_limit_message},
status=429, # Too Many Requests
)
# Get app label, model, and object ID
app_label = request.POST.get("app_label")
model = request.POST.get("model")
object_id = request.POST.get("object_id")
# Log received data
# Log received data (don't log file contents for security)
logger.debug(
f"Received upload request - app_label: {app_label}, model: {model}, object_id: {object_id}"
)
logger.debug(f"Files in request: {request.FILES}")
# Validate required fields
missing_fields = []
@@ -44,6 +77,19 @@ def upload_photo(request):
status=400,
)
# Security: Validate uploaded file before processing
uploaded_file = request.FILES["image"]
try:
validate_image_upload(uploaded_file)
except FileValidationError as e:
logger.warning(
f"User {request.user} attempted to upload invalid file: {str(e)}"
)
return JsonResponse(
{"error": str(e)},
status=400,
)
# Get content type
try:
content_type = ContentType.objects.get(
@@ -87,17 +133,19 @@ def upload_photo(request):
# Create the photo
photo = Photo.objects.create(
image=request.FILES["image"],
image=uploaded_file,
content_type=content_type,
object_id=obj.pk,
uploaded_by=request.user, # Add the user who uploaded the photo
uploaded_by=request.user,
is_primary=not Photo.objects.filter(
content_type=content_type, object_id=obj.pk
).exists(),
is_approved=is_approved,
# Auto-approve if the user is a moderator, admin, or superuser
)
# Security: Increment upload count for rate limiting
increment_upload_count(request.user.id)
return JsonResponse(
{
"id": photo.pk,
@@ -111,7 +159,7 @@ def upload_photo(request):
except Exception as e:
logger.error(f"Error in upload_photo: {str(e)}", exc_info=True)
return JsonResponse(
{"error": f"An error occurred while uploading the photo: {str(e)}"},
{"error": "An error occurred while uploading the photo"},
status=400,
)

View File

@@ -1,10 +1,11 @@
{% load widget_tweaks %}
{% load widget_tweaks safe_html %}
{# Security: Icon SVGs are sanitized to prevent XSS. Only trusted SVG elements are allowed. #}
{% if type == 'button' or type == 'submit' or type == 'reset' %}
<button
<button
type="{{ type|default:'button' }}"
{% if id %}id="{{ id }}"{% endif %}
class="inline-flex items-center justify-center rounded-md text-sm font-medium ring-offset-background transition-colors focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:pointer-events-none disabled:opacity-50
class="inline-flex items-center justify-center rounded-md text-sm font-medium ring-offset-background transition-colors focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:pointer-events-none disabled:opacity-50
{% if variant == 'destructive' %}bg-destructive text-destructive-foreground hover:bg-destructive/90
{% elif variant == 'outline' %}border border-input bg-background hover:bg-accent hover:text-accent-foreground
{% elif variant == 'secondary' %}bg-secondary text-secondary-foreground hover:bg-secondary/80
@@ -26,7 +27,7 @@
{% if x_data %}x-data="{{ x_data }}"{% endif %}
{% if x_on_click %}@click="{{ x_on_click }}"{% endif %}>
{% if icon %}
{{ icon|safe }}
{{ icon|sanitize_svg }}
{% if label %}<span class="ml-2">{{ label }}</span>{% endif %}
{% else %}
{{ label|default:content }}
@@ -34,9 +35,9 @@
{% block button_content %}{% endblock %}
</button>
{% else %}
<a
<a
href="{{ href }}"
class="inline-flex items-center justify-center rounded-md text-sm font-medium ring-offset-background transition-colors focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:pointer-events-none disabled:opacity-50
class="inline-flex items-center justify-center rounded-md text-sm font-medium ring-offset-background transition-colors focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:pointer-events-none disabled:opacity-50
{% if variant == 'destructive' %}bg-destructive text-destructive-foreground hover:bg-destructive/90
{% elif variant == 'outline' %}border border-input bg-background hover:bg-accent hover:text-accent-foreground
{% elif variant == 'secondary' %}bg-secondary text-secondary-foreground hover:bg-secondary/80
@@ -51,7 +52,7 @@
{% if x_data %}x-data="{{ x_data }}"{% endif %}
{% if x_on_click %}@click="{{ x_on_click }}"{% endif %}>
{% if icon %}
{{ icon|safe }}
{{ icon|sanitize_svg }}
{% if label %}<span class="ml-2">{{ label }}</span>{% endif %}
{% else %}
{{ label|default:content }}

View File

@@ -1,3 +1,6 @@
{% load safe_html %}
{# Security: Header, content, and footer are sanitized to prevent XSS. #}
<div class="border rounded-lg shadow-sm bg-card text-card-foreground">
{% if header or title %}
<div class="flex flex-col space-y-1.5 p-6">
@@ -8,19 +11,19 @@
<p class="text-sm text-muted-foreground">{{ description }}</p>
{% endif %}
{% if header %}
{{ header|safe }}
{{ header|sanitize }}
{% endif %}
</div>
{% endif %}
<div class="p-6 pt-0">
{{ content|safe }}
{{ content|sanitize }}
{% block card_content %}{% endblock %}
</div>
{% if footer %}
<div class="flex items-center p-6 pt-0">
{{ footer|safe }}
{{ footer|sanitize }}
</div>
{% endif %}
</div>

View File

@@ -1,3 +1,6 @@
{% load safe_html %}
{# Security: Dialog content and footer are sanitized to prevent XSS. #}
<div class="fixed inset-0 z-50 flex items-start justify-center sm:items-center"
role="dialog"
aria-modal="true"
@@ -9,42 +12,42 @@
x-transition:leave="transition ease-in duration-100"
x-transition:leave-start="opacity-100"
x-transition:leave-end="opacity-0">
<!-- Backdrop -->
<div class="fixed inset-0 transition-all bg-background/80 backdrop-blur-sm"
<div class="fixed inset-0 transition-all bg-background/80 backdrop-blur-sm"
@click="open = false; setTimeout(() => $el.closest('[hx-history-elt]').innerHTML = '', 200)"></div>
<!-- Dialog Panel -->
<div class="fixed z-50 grid w-full gap-4 p-6 duration-200 border shadow-lg bg-background sm:max-w-lg sm:rounded-lg"
x-show="open"
x-transition:enter="animate-in fade-in-0 zoom-in-95 slide-in-from-left-1/2 slide-in-from-top-[48%]"
x-transition:leave="animate-out fade-out-0 zoom-out-95 slide-out-to-left-1/2 slide-out-to-top-[48%]">
<div class="flex flex-col space-y-1.5 text-center sm:text-left">
{% if title %}
<h2 class="text-lg font-semibold leading-none tracking-tight">
{{ title }}
</h2>
{% endif %}
{% if description %}
<p class="text-sm text-muted-foreground">
{{ description }}
</p>
{% endif %}
</div>
<div class="py-4">
{{ content|safe }}
{{ content|sanitize }}
{% block dialog_content %}{% endblock %}
</div>
{% if footer %}
<div class="flex flex-col-reverse sm:flex-row sm:justify-end sm:space-x-2">
{{ footer|safe }}
{{ footer|sanitize }}
</div>
{% endif %}
<!-- Close Button -->
<button class="absolute right-4 top-4 rounded-sm opacity-70 ring-offset-background transition-opacity hover:opacity-100 focus:outline-none focus:ring-2 focus:ring-ring focus:ring-offset-2 disabled:pointer-events-none data-[state=open]:bg-accent data-[state=open]:text-muted-foreground"
@click="open = false">

View File

@@ -1,12 +1,15 @@
{% load safe_html %}
{# Security: Trigger and content are sanitized to allow only safe HTML elements. #}
<div class="relative inline-block text-left" x-data="{ open: false }" @click.away="open = false">
<div @click="open = !open">
{% if trigger %}
{{ trigger|safe }}
{{ trigger|sanitize }}
{% else %}
<button type="button"
class="inline-flex w-full justify-center gap-x-1.5 rounded-md bg-white px-3 py-2 text-sm font-semibold text-gray-900 shadow-sm ring-1 ring-inset ring-gray-300 hover:bg-gray-50"
id="menu-button"
aria-expanded="true"
<button type="button"
class="inline-flex w-full justify-center gap-x-1.5 rounded-md bg-white px-3 py-2 text-sm font-semibold text-gray-900 shadow-sm ring-1 ring-inset ring-gray-300 hover:bg-gray-50"
id="menu-button"
aria-expanded="true"
aria-haspopup="true">
{{ label|default:"Options" }}
<svg class="w-5 h-5 -mr-1 text-gray-400" viewBox="0 0 20 20" fill="currentColor" aria-hidden="true">
@@ -29,7 +32,7 @@
aria-labelledby="menu-button"
tabindex="-1">
<div class="py-1" role="none">
{{ content|safe }}
{{ content|sanitize }}
{% block dropdown_content %}{% endblock %}
</div>
</div>