mirror of
https://github.com/pacnpal/thrillwiki_django_no_react.git
synced 2025-12-30 01:27:00 -05:00
feat: Implement MFA authentication, add ride statistics model, and update various services, APIs, and tests across the application.
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
from rest_framework import serializers
|
||||
from drf_spectacular.utils import extend_schema_field
|
||||
from rest_framework import serializers
|
||||
|
||||
from apps.accounts.models import UserProfile
|
||||
from apps.accounts.serializers import UserSerializer # existing shared user serializer
|
||||
|
||||
@@ -24,7 +25,7 @@ class UserProfileUpdateInputSerializer(serializers.ModelSerializer):
|
||||
from django_cloudflareimages_toolkit.models import CloudflareImage
|
||||
image, _ = CloudflareImage.objects.get_or_create(cloudflare_id=cloudflare_id)
|
||||
instance.avatar = image
|
||||
|
||||
|
||||
return super().update(instance, validated_data)
|
||||
|
||||
|
||||
|
||||
@@ -2,8 +2,14 @@
|
||||
URL configuration for user account management API endpoints.
|
||||
"""
|
||||
|
||||
from django.urls import path
|
||||
from . import views
|
||||
from django.urls import include, path
|
||||
from rest_framework.routers import DefaultRouter
|
||||
|
||||
from . import views, views_credits, views_magic_link
|
||||
|
||||
# Register ViewSets
|
||||
router = DefaultRouter()
|
||||
router.register(r"credits", views_credits.RideCreditViewSet, basename="ride-credit")
|
||||
|
||||
urlpatterns = [
|
||||
# Admin endpoints for user management
|
||||
@@ -108,19 +114,18 @@ urlpatterns = [
|
||||
path("profile/avatar/upload/", views.upload_avatar, name="upload_avatar"),
|
||||
path("profile/avatar/save/", views.save_avatar_image, name="save_avatar_image"),
|
||||
path("profile/avatar/delete/", views.delete_avatar, name="delete_avatar"),
|
||||
|
||||
|
||||
# Login history endpoint
|
||||
path("login-history/", views.get_login_history, name="get_login_history"),
|
||||
|
||||
# Magic Link (Login by Code) endpoints
|
||||
path("magic-link/request/", views_magic_link.request_magic_link, name="request_magic_link"),
|
||||
path("magic-link/verify/", views_magic_link.verify_magic_link, name="verify_magic_link"),
|
||||
|
||||
# Public Profile
|
||||
path("profiles/<str:username>/", views.get_public_user_profile, name="get_public_user_profile"),
|
||||
]
|
||||
|
||||
# Register ViewSets
|
||||
from rest_framework.routers import DefaultRouter
|
||||
from . import views_credits
|
||||
from django.urls import include
|
||||
|
||||
router = DefaultRouter()
|
||||
router.register(r"credits", views_credits.RideCreditViewSet, basename="ride-credit")
|
||||
|
||||
urlpatterns += [
|
||||
# ViewSet routes
|
||||
path("", include(router.urls)),
|
||||
]
|
||||
|
||||
|
||||
@@ -6,43 +6,44 @@ user deletion while preserving submissions, profile management, settings,
|
||||
preferences, privacy, notifications, and security.
|
||||
"""
|
||||
|
||||
from apps.api.v1.serializers.accounts import (
|
||||
CompleteUserSerializer,
|
||||
PublicUserSerializer,
|
||||
UserPreferencesSerializer,
|
||||
NotificationSettingsSerializer,
|
||||
PrivacySettingsSerializer,
|
||||
SecuritySettingsSerializer,
|
||||
UserStatisticsSerializer,
|
||||
UserListSerializer,
|
||||
AccountUpdateSerializer,
|
||||
ProfileUpdateSerializer,
|
||||
ThemePreferenceSerializer,
|
||||
UserNotificationSerializer,
|
||||
NotificationPreferenceSerializer,
|
||||
MarkNotificationsReadSerializer,
|
||||
AvatarUploadSerializer,
|
||||
)
|
||||
from apps.accounts.services import UserDeletionService
|
||||
from apps.accounts.export_service import UserExportService
|
||||
from apps.accounts.models import (
|
||||
User,
|
||||
UserProfile,
|
||||
UserNotification,
|
||||
NotificationPreference,
|
||||
)
|
||||
from apps.lists.models import UserList
|
||||
import logging
|
||||
from rest_framework import status
|
||||
from rest_framework.decorators import api_view, permission_classes
|
||||
from rest_framework.permissions import IsAuthenticated, IsAdminUser
|
||||
from rest_framework.response import Response
|
||||
from drf_spectacular.utils import extend_schema, OpenApiParameter
|
||||
from drf_spectacular.types import OpenApiTypes
|
||||
|
||||
from django.shortcuts import get_object_or_404
|
||||
from rest_framework.permissions import AllowAny
|
||||
from django.utils import timezone
|
||||
from django_cloudflareimages_toolkit.models import CloudflareImage
|
||||
from drf_spectacular.types import OpenApiTypes
|
||||
from drf_spectacular.utils import OpenApiParameter, extend_schema
|
||||
from rest_framework import status
|
||||
from rest_framework.decorators import api_view, permission_classes
|
||||
from rest_framework.permissions import AllowAny, IsAdminUser, IsAuthenticated
|
||||
from rest_framework.response import Response
|
||||
|
||||
from apps.accounts.export_service import UserExportService
|
||||
from apps.accounts.models import (
|
||||
NotificationPreference,
|
||||
User,
|
||||
UserNotification,
|
||||
UserProfile,
|
||||
)
|
||||
from apps.accounts.services import UserDeletionService
|
||||
from apps.api.v1.serializers.accounts import (
|
||||
AccountUpdateSerializer,
|
||||
AvatarUploadSerializer,
|
||||
CompleteUserSerializer,
|
||||
MarkNotificationsReadSerializer,
|
||||
NotificationPreferenceSerializer,
|
||||
NotificationSettingsSerializer,
|
||||
PrivacySettingsSerializer,
|
||||
ProfileUpdateSerializer,
|
||||
PublicUserSerializer,
|
||||
SecuritySettingsSerializer,
|
||||
ThemePreferenceSerializer,
|
||||
UserListSerializer,
|
||||
UserNotificationSerializer,
|
||||
UserPreferencesSerializer,
|
||||
UserStatisticsSerializer,
|
||||
)
|
||||
from apps.lists.models import UserList
|
||||
|
||||
# Set up logging
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -307,7 +308,7 @@ def save_avatar_image(request):
|
||||
try:
|
||||
cloudflare_image = CloudflareImage.objects.get(
|
||||
cloudflare_id=cloudflare_image_id)
|
||||
|
||||
|
||||
# Update existing record with latest data from Cloudflare
|
||||
cloudflare_image.status = 'uploaded'
|
||||
cloudflare_image.uploaded_at = timezone.now()
|
||||
@@ -319,7 +320,7 @@ def save_avatar_image(request):
|
||||
cloudflare_image.height = image_data.get('height')
|
||||
cloudflare_image.format = image_data.get('format', '')
|
||||
cloudflare_image.save()
|
||||
|
||||
|
||||
except CloudflareImage.DoesNotExist:
|
||||
# Create new CloudflareImage record from API response
|
||||
cloudflare_image = CloudflareImage.objects.create(
|
||||
@@ -367,7 +368,7 @@ def save_avatar_image(request):
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete old avatar from Cloudflare: {str(e)}")
|
||||
# Continue with database deletion even if Cloudflare deletion fails
|
||||
|
||||
|
||||
old_avatar.delete()
|
||||
|
||||
# Debug logging to see what's happening with the CloudflareImage
|
||||
@@ -442,7 +443,7 @@ def delete_avatar(request):
|
||||
avatar_to_delete = profile.avatar
|
||||
profile.avatar = None
|
||||
profile.save()
|
||||
|
||||
|
||||
# Delete from Cloudflare first, then from database
|
||||
try:
|
||||
from django_cloudflareimages_toolkit.services import CloudflareImagesService
|
||||
@@ -452,7 +453,7 @@ def delete_avatar(request):
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete avatar from Cloudflare: {str(e)}")
|
||||
# Continue with database deletion even if Cloudflare deletion fails
|
||||
|
||||
|
||||
avatar_to_delete.delete()
|
||||
|
||||
# Get the default avatar URL
|
||||
@@ -1273,10 +1274,10 @@ def update_security_settings(request):
|
||||
|
||||
# Handle security settings updates
|
||||
if "two_factor_enabled" in request.data:
|
||||
setattr(user, "two_factor_enabled", request.data["two_factor_enabled"])
|
||||
user.two_factor_enabled = request.data["two_factor_enabled"]
|
||||
|
||||
if "login_notifications" in request.data:
|
||||
setattr(user, "login_notifications", request.data["login_notifications"])
|
||||
user.login_notifications = request.data["login_notifications"]
|
||||
|
||||
user.save()
|
||||
|
||||
@@ -1612,7 +1613,7 @@ def export_user_data(request):
|
||||
except Exception as e:
|
||||
logger.error(f"Error exporting data for user {request.user.id}: {e}", exc_info=True)
|
||||
return Response(
|
||||
{"error": "Failed to generate data export"},
|
||||
{"error": "Failed to generate data export"},
|
||||
status=status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
)
|
||||
|
||||
@@ -1636,54 +1637,73 @@ def get_public_user_profile(request, username):
|
||||
return Response(serializer.data, status=status.HTTP_200_OK)
|
||||
|
||||
|
||||
# === MISSING FUNCTION IMPLEMENTATIONS ===
|
||||
|
||||
|
||||
@extend_schema(
|
||||
operation_id="request_account_deletion",
|
||||
summary="Request account deletion",
|
||||
description="Request deletion of the authenticated user's account.",
|
||||
operation_id="get_login_history",
|
||||
summary="Get user login history",
|
||||
description=(
|
||||
"Returns the authenticated user's recent login history including "
|
||||
"IP addresses, devices, and timestamps for security auditing."
|
||||
),
|
||||
parameters=[
|
||||
OpenApiParameter(
|
||||
name="limit",
|
||||
type=OpenApiTypes.INT,
|
||||
location=OpenApiParameter.QUERY,
|
||||
description="Maximum number of entries to return (default: 20, max: 100)",
|
||||
),
|
||||
],
|
||||
responses={
|
||||
200: {"description": "Deletion request created"},
|
||||
400: {"description": "Cannot delete account"},
|
||||
},
|
||||
tags=["Self-Service Account Management"],
|
||||
)
|
||||
@api_view(["POST"])
|
||||
@permission_classes([IsAuthenticated])
|
||||
def request_account_deletion(request):
|
||||
"""Request account deletion."""
|
||||
try:
|
||||
user = request.user
|
||||
|
||||
# Check if user can be deleted
|
||||
can_delete, reason = UserDeletionService.can_delete_user(user)
|
||||
if not can_delete:
|
||||
return Response(
|
||||
{"success": False, "error": reason},
|
||||
status=status.HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
|
||||
# Create deletion request
|
||||
deletion_request = UserDeletionService.create_deletion_request(user)
|
||||
|
||||
return Response(
|
||||
{
|
||||
"success": True,
|
||||
"message": "Verification code sent to your email",
|
||||
"expires_at": deletion_request.expires_at,
|
||||
"email": user.email,
|
||||
200: {
|
||||
"description": "Login history entries",
|
||||
"example": {
|
||||
"results": [
|
||||
{
|
||||
"id": 1,
|
||||
"ip_address": "192.168.1.1",
|
||||
"user_agent": "Mozilla/5.0...",
|
||||
"login_method": "PASSWORD",
|
||||
"login_method_display": "Password",
|
||||
"login_timestamp": "2024-12-27T10:30:00Z",
|
||||
"country": "United States",
|
||||
"city": "New York",
|
||||
}
|
||||
],
|
||||
"count": 1,
|
||||
},
|
||||
status=status.HTTP_200_OK,
|
||||
)
|
||||
},
|
||||
401: {"description": "Authentication required"},
|
||||
},
|
||||
tags=["User Security"],
|
||||
)
|
||||
@api_view(["GET"])
|
||||
@permission_classes([IsAuthenticated])
|
||||
def get_login_history(request):
|
||||
"""Get user login history for security auditing."""
|
||||
from apps.accounts.login_history import LoginHistory
|
||||
|
||||
user = request.user
|
||||
limit = min(int(request.query_params.get("limit", 20)), 100)
|
||||
|
||||
# Get login history for user
|
||||
entries = LoginHistory.objects.filter(user=user).order_by("-login_timestamp")[:limit]
|
||||
|
||||
# Serialize
|
||||
results = []
|
||||
for entry in entries:
|
||||
results.append({
|
||||
"id": entry.id,
|
||||
"ip_address": entry.ip_address,
|
||||
"user_agent": entry.user_agent[:100] if entry.user_agent else None, # Truncate long user agents
|
||||
"login_method": entry.login_method,
|
||||
"login_method_display": dict(LoginHistory._meta.get_field('login_method').choices).get(entry.login_method, entry.login_method),
|
||||
"login_timestamp": entry.login_timestamp.isoformat(),
|
||||
"country": entry.country,
|
||||
"city": entry.city,
|
||||
"success": entry.success,
|
||||
})
|
||||
|
||||
return Response({
|
||||
"results": results,
|
||||
"count": len(results),
|
||||
})
|
||||
|
||||
except ValueError as e:
|
||||
return Response(
|
||||
{"success": False, "error": str(e)},
|
||||
status=status.HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
except Exception as e:
|
||||
return Response(
|
||||
{"success": False, "error": f"Error creating deletion request: {str(e)}"},
|
||||
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
)
|
||||
|
||||
@@ -1,9 +1,14 @@
|
||||
from rest_framework import viewsets, permissions, filters
|
||||
from django.db import transaction
|
||||
from django_filters.rest_framework import DjangoFilterBackend
|
||||
from apps.rides.models.credits import RideCredit
|
||||
from apps.api.v1.serializers.ride_credits import RideCreditSerializer
|
||||
from drf_spectacular.utils import extend_schema, OpenApiParameter
|
||||
from drf_spectacular.types import OpenApiTypes
|
||||
from drf_spectacular.utils import OpenApiParameter, extend_schema
|
||||
from rest_framework import filters, permissions, status, viewsets
|
||||
from rest_framework.decorators import action
|
||||
from rest_framework.response import Response
|
||||
|
||||
from apps.api.v1.serializers.ride_credits import RideCreditSerializer
|
||||
from apps.rides.models.credits import RideCredit
|
||||
|
||||
|
||||
class RideCreditViewSet(viewsets.ModelViewSet):
|
||||
"""
|
||||
@@ -14,8 +19,8 @@ class RideCreditViewSet(viewsets.ModelViewSet):
|
||||
permission_classes = [permissions.IsAuthenticatedOrReadOnly]
|
||||
filter_backends = [DjangoFilterBackend, filters.OrderingFilter]
|
||||
filterset_fields = ['user__username', 'ride__park__slug', 'ride__manufacturer__slug']
|
||||
ordering_fields = ['first_ridden_at', 'last_ridden_at', 'created_at', 'count', 'rating']
|
||||
ordering = ['-last_ridden_at']
|
||||
ordering_fields = ['first_ridden_at', 'last_ridden_at', 'created_at', 'count', 'rating', 'display_order']
|
||||
ordering = ['display_order', '-last_ridden_at']
|
||||
|
||||
def get_queryset(self):
|
||||
"""
|
||||
@@ -23,18 +28,77 @@ class RideCreditViewSet(viewsets.ModelViewSet):
|
||||
Optionally filter by user via query param ?user=username
|
||||
"""
|
||||
queryset = RideCredit.objects.all().select_related('ride', 'ride__park', 'user')
|
||||
|
||||
|
||||
# Filter by user if provided
|
||||
username = self.request.query_params.get('user')
|
||||
if username:
|
||||
queryset = queryset.filter(user__username=username)
|
||||
|
||||
|
||||
return queryset
|
||||
|
||||
def perform_create(self, serializer):
|
||||
"""Associate the current user with the ride credit."""
|
||||
serializer.save(user=self.request.user)
|
||||
|
||||
@action(detail=False, methods=['post'], permission_classes=[permissions.IsAuthenticated])
|
||||
@extend_schema(
|
||||
summary="Reorder ride credits",
|
||||
description="Bulk update the display order of ride credits. Send a list of {id, order} objects.",
|
||||
request={
|
||||
'application/json': {
|
||||
'type': 'object',
|
||||
'properties': {
|
||||
'order': {
|
||||
'type': 'array',
|
||||
'items': {
|
||||
'type': 'object',
|
||||
'properties': {
|
||||
'id': {'type': 'integer'},
|
||||
'order': {'type': 'integer'}
|
||||
},
|
||||
'required': ['id', 'order']
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
def reorder(self, request):
|
||||
"""
|
||||
Bulk update display_order for multiple credits.
|
||||
Expects: {"order": [{"id": 1, "order": 0}, {"id": 2, "order": 1}, ...]}
|
||||
"""
|
||||
order_data = request.data.get('order', [])
|
||||
|
||||
if not order_data:
|
||||
return Response(
|
||||
{'error': 'No order data provided'},
|
||||
status=status.HTTP_400_BAD_REQUEST
|
||||
)
|
||||
|
||||
# Validate that all credits belong to the current user
|
||||
credit_ids = [item['id'] for item in order_data]
|
||||
user_credits = RideCredit.objects.filter(
|
||||
id__in=credit_ids,
|
||||
user=request.user
|
||||
).values_list('id', flat=True)
|
||||
|
||||
if set(credit_ids) != set(user_credits):
|
||||
return Response(
|
||||
{'error': 'You can only reorder your own credits'},
|
||||
status=status.HTTP_403_FORBIDDEN
|
||||
)
|
||||
|
||||
# Bulk update in a transaction
|
||||
with transaction.atomic():
|
||||
for item in order_data:
|
||||
RideCredit.objects.filter(
|
||||
id=item['id'],
|
||||
user=request.user
|
||||
).update(display_order=item['order'])
|
||||
|
||||
return Response({'status': 'reordered', 'count': len(order_data)})
|
||||
|
||||
@extend_schema(
|
||||
summary="List ride credits",
|
||||
description="List ride credits. filter by user username.",
|
||||
@@ -49,3 +113,4 @@ class RideCreditViewSet(viewsets.ModelViewSet):
|
||||
)
|
||||
def list(self, request, *args, **kwargs):
|
||||
return super().list(request, *args, **kwargs)
|
||||
|
||||
|
||||
180
backend/apps/api/v1/accounts/views_magic_link.py
Normal file
180
backend/apps/api/v1/accounts/views_magic_link.py
Normal file
@@ -0,0 +1,180 @@
|
||||
"""
|
||||
Magic Link (Login by Code) API views.
|
||||
|
||||
Provides API endpoints for passwordless login via email code.
|
||||
Uses django-allauth's built-in login-by-code functionality.
|
||||
"""
|
||||
from django.conf import settings
|
||||
from drf_spectacular.utils import OpenApiExample, extend_schema
|
||||
from rest_framework import status
|
||||
from rest_framework.decorators import api_view, permission_classes
|
||||
from rest_framework.permissions import AllowAny
|
||||
from rest_framework.response import Response
|
||||
|
||||
try:
|
||||
from allauth.account.internal.flows.login_by_code import perform_login_by_code, request_login_code
|
||||
from allauth.account.models import EmailAddress
|
||||
from allauth.account.utils import user_email # noqa: F401 - imported to verify availability
|
||||
HAS_LOGIN_BY_CODE = True
|
||||
except ImportError:
|
||||
HAS_LOGIN_BY_CODE = False
|
||||
|
||||
|
||||
@extend_schema(
|
||||
summary="Request magic link login code",
|
||||
description="Send a one-time login code to the user's email address.",
|
||||
request={
|
||||
'application/json': {
|
||||
'type': 'object',
|
||||
'properties': {
|
||||
'email': {'type': 'string', 'format': 'email'}
|
||||
},
|
||||
'required': ['email']
|
||||
}
|
||||
},
|
||||
responses={
|
||||
200: {'description': 'Login code sent successfully'},
|
||||
400: {'description': 'Invalid email or feature disabled'},
|
||||
},
|
||||
examples=[
|
||||
OpenApiExample(
|
||||
'Request login code',
|
||||
value={'email': 'user@example.com'},
|
||||
request_only=True
|
||||
)
|
||||
]
|
||||
)
|
||||
@api_view(['POST'])
|
||||
@permission_classes([AllowAny])
|
||||
def request_magic_link(request):
|
||||
"""
|
||||
Request a login code to be sent to the user's email.
|
||||
|
||||
This is the first step of the magic link flow:
|
||||
1. User enters their email
|
||||
2. If the email exists, a code is sent
|
||||
3. User enters the code to complete login
|
||||
"""
|
||||
if not getattr(settings, 'ACCOUNT_LOGIN_BY_CODE_ENABLED', False):
|
||||
return Response(
|
||||
{'error': 'Magic link login is not enabled'},
|
||||
status=status.HTTP_400_BAD_REQUEST
|
||||
)
|
||||
|
||||
if not HAS_LOGIN_BY_CODE:
|
||||
return Response(
|
||||
{'error': 'Login by code is not available in this version of allauth'},
|
||||
status=status.HTTP_400_BAD_REQUEST
|
||||
)
|
||||
|
||||
email = request.data.get('email', '').lower().strip()
|
||||
|
||||
if not email:
|
||||
return Response(
|
||||
{'error': 'Email is required'},
|
||||
status=status.HTTP_400_BAD_REQUEST
|
||||
)
|
||||
|
||||
# Check if email exists (don't reveal if it doesn't for security)
|
||||
try:
|
||||
email_address = EmailAddress.objects.get(email__iexact=email, verified=True)
|
||||
user = email_address.user
|
||||
|
||||
# Request the login code
|
||||
request_login_code(request._request, user)
|
||||
|
||||
return Response({
|
||||
'success': True,
|
||||
'message': 'If an account exists with this email, a login code has been sent.',
|
||||
'timeout': getattr(settings, 'ACCOUNT_LOGIN_BY_CODE_TIMEOUT', 300)
|
||||
})
|
||||
|
||||
except EmailAddress.DoesNotExist:
|
||||
# Don't reveal that the email doesn't exist
|
||||
return Response({
|
||||
'success': True,
|
||||
'message': 'If an account exists with this email, a login code has been sent.',
|
||||
'timeout': getattr(settings, 'ACCOUNT_LOGIN_BY_CODE_TIMEOUT', 300)
|
||||
})
|
||||
|
||||
|
||||
@extend_schema(
|
||||
summary="Verify magic link code",
|
||||
description="Verify the login code and complete the login process.",
|
||||
request={
|
||||
'application/json': {
|
||||
'type': 'object',
|
||||
'properties': {
|
||||
'email': {'type': 'string', 'format': 'email'},
|
||||
'code': {'type': 'string'}
|
||||
},
|
||||
'required': ['email', 'code']
|
||||
}
|
||||
},
|
||||
responses={
|
||||
200: {'description': 'Login successful'},
|
||||
400: {'description': 'Invalid or expired code'},
|
||||
}
|
||||
)
|
||||
@api_view(['POST'])
|
||||
@permission_classes([AllowAny])
|
||||
def verify_magic_link(request):
|
||||
"""
|
||||
Verify the login code and complete the login.
|
||||
|
||||
This is the second step of the magic link flow.
|
||||
"""
|
||||
if not getattr(settings, 'ACCOUNT_LOGIN_BY_CODE_ENABLED', False):
|
||||
return Response(
|
||||
{'error': 'Magic link login is not enabled'},
|
||||
status=status.HTTP_400_BAD_REQUEST
|
||||
)
|
||||
|
||||
if not HAS_LOGIN_BY_CODE:
|
||||
return Response(
|
||||
{'error': 'Login by code is not available'},
|
||||
status=status.HTTP_400_BAD_REQUEST
|
||||
)
|
||||
|
||||
email = request.data.get('email', '').lower().strip()
|
||||
code = request.data.get('code', '').strip()
|
||||
|
||||
if not email or not code:
|
||||
return Response(
|
||||
{'error': 'Email and code are required'},
|
||||
status=status.HTTP_400_BAD_REQUEST
|
||||
)
|
||||
|
||||
try:
|
||||
email_address = EmailAddress.objects.get(email__iexact=email, verified=True)
|
||||
user = email_address.user
|
||||
|
||||
# Attempt to verify the code and log in
|
||||
success = perform_login_by_code(request._request, user, code)
|
||||
|
||||
if success:
|
||||
return Response({
|
||||
'success': True,
|
||||
'message': 'Login successful',
|
||||
'user': {
|
||||
'id': user.id,
|
||||
'username': user.username,
|
||||
'email': user.email
|
||||
}
|
||||
})
|
||||
else:
|
||||
return Response(
|
||||
{'error': 'Invalid or expired code. Please request a new one.'},
|
||||
status=status.HTTP_400_BAD_REQUEST
|
||||
)
|
||||
|
||||
except EmailAddress.DoesNotExist:
|
||||
return Response(
|
||||
{'error': 'Invalid email or code'},
|
||||
status=status.HTTP_400_BAD_REQUEST
|
||||
)
|
||||
except Exception:
|
||||
return Response(
|
||||
{'error': 'Invalid or expired code. Please request a new one.'},
|
||||
status=status.HTTP_400_BAD_REQUEST
|
||||
)
|
||||
385
backend/apps/api/v1/auth/mfa.py
Normal file
385
backend/apps/api/v1/auth/mfa.py
Normal file
@@ -0,0 +1,385 @@
|
||||
"""
|
||||
MFA (Multi-Factor Authentication) API Views
|
||||
|
||||
Provides REST API endpoints for MFA operations using django-allauth's mfa module.
|
||||
Supports TOTP (Time-based One-Time Password) authentication.
|
||||
"""
|
||||
|
||||
import base64
|
||||
from io import BytesIO
|
||||
|
||||
from django.conf import settings
|
||||
from drf_spectacular.utils import extend_schema
|
||||
from rest_framework import status
|
||||
from rest_framework.decorators import api_view, permission_classes
|
||||
from rest_framework.permissions import IsAuthenticated
|
||||
from rest_framework.response import Response
|
||||
|
||||
try:
|
||||
import qrcode
|
||||
HAS_QRCODE = True
|
||||
except ImportError:
|
||||
HAS_QRCODE = False
|
||||
|
||||
|
||||
@extend_schema(
|
||||
operation_id="get_mfa_status",
|
||||
summary="Get MFA status for current user",
|
||||
description="Returns whether MFA is enabled and what methods are configured.",
|
||||
responses={
|
||||
200: {
|
||||
"description": "MFA status",
|
||||
"example": {
|
||||
"mfa_enabled": True,
|
||||
"totp_enabled": True,
|
||||
"recovery_codes_count": 10,
|
||||
},
|
||||
},
|
||||
},
|
||||
tags=["MFA"],
|
||||
)
|
||||
@api_view(["GET"])
|
||||
@permission_classes([IsAuthenticated])
|
||||
def get_mfa_status(request):
|
||||
"""Get MFA status for current user."""
|
||||
from allauth.mfa.models import Authenticator
|
||||
|
||||
user = request.user
|
||||
authenticators = Authenticator.objects.filter(user=user)
|
||||
|
||||
totp_enabled = authenticators.filter(type=Authenticator.Type.TOTP).exists()
|
||||
recovery_enabled = authenticators.filter(type=Authenticator.Type.RECOVERY_CODES).exists()
|
||||
|
||||
# Count recovery codes if any
|
||||
recovery_count = 0
|
||||
if recovery_enabled:
|
||||
try:
|
||||
recovery_auth = authenticators.get(type=Authenticator.Type.RECOVERY_CODES)
|
||||
recovery_count = len(recovery_auth.data.get("codes", []))
|
||||
except Authenticator.DoesNotExist:
|
||||
pass
|
||||
|
||||
return Response({
|
||||
"mfa_enabled": totp_enabled,
|
||||
"totp_enabled": totp_enabled,
|
||||
"recovery_codes_enabled": recovery_enabled,
|
||||
"recovery_codes_count": recovery_count,
|
||||
})
|
||||
|
||||
|
||||
@extend_schema(
|
||||
operation_id="setup_totp",
|
||||
summary="Initialize TOTP setup",
|
||||
description="Generates a new TOTP secret and returns the QR code for scanning.",
|
||||
responses={
|
||||
200: {
|
||||
"description": "TOTP setup data",
|
||||
"example": {
|
||||
"secret": "ABCDEFGHIJKLMNOP",
|
||||
"provisioning_uri": "otpauth://totp/ThrillWiki:user@example.com?secret=...",
|
||||
"qr_code_base64": "data:image/png;base64,...",
|
||||
},
|
||||
},
|
||||
},
|
||||
tags=["MFA"],
|
||||
)
|
||||
@api_view(["POST"])
|
||||
@permission_classes([IsAuthenticated])
|
||||
def setup_totp(request):
|
||||
"""Generate TOTP secret and QR code for setup."""
|
||||
from allauth.mfa.totp.internal import auth as totp_auth
|
||||
|
||||
user = request.user
|
||||
|
||||
# Generate TOTP secret
|
||||
secret = totp_auth.get_totp_secret(None) # Generate new secret
|
||||
|
||||
# Build provisioning URI
|
||||
issuer = getattr(settings, "MFA_TOTP_ISSUER", "ThrillWiki")
|
||||
account_name = user.email or user.username
|
||||
uri = f"otpauth://totp/{issuer}:{account_name}?secret={secret}&issuer={issuer}"
|
||||
|
||||
# Generate QR code if qrcode library is available
|
||||
qr_code_base64 = None
|
||||
if HAS_QRCODE:
|
||||
qr = qrcode.make(uri)
|
||||
buffer = BytesIO()
|
||||
qr.save(buffer, format="PNG")
|
||||
qr_code_base64 = f"data:image/png;base64,{base64.b64encode(buffer.getvalue()).decode()}"
|
||||
|
||||
# Store secret in session for later verification
|
||||
request.session["pending_totp_secret"] = secret
|
||||
|
||||
return Response({
|
||||
"secret": secret,
|
||||
"provisioning_uri": uri,
|
||||
"qr_code_base64": qr_code_base64,
|
||||
})
|
||||
|
||||
|
||||
@extend_schema(
|
||||
operation_id="activate_totp",
|
||||
summary="Activate TOTP with verification code",
|
||||
description="Verifies the TOTP code and activates 2FA for the user.",
|
||||
request={
|
||||
"application/json": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"code": {
|
||||
"type": "string",
|
||||
"description": "6-digit TOTP code from authenticator app",
|
||||
"example": "123456",
|
||||
}
|
||||
},
|
||||
"required": ["code"],
|
||||
}
|
||||
},
|
||||
responses={
|
||||
200: {
|
||||
"description": "TOTP activated successfully",
|
||||
"example": {
|
||||
"success": True,
|
||||
"message": "Two-factor authentication enabled",
|
||||
"recovery_codes": ["ABCD1234", "EFGH5678"],
|
||||
},
|
||||
},
|
||||
400: {"description": "Invalid code or missing setup data"},
|
||||
},
|
||||
tags=["MFA"],
|
||||
)
|
||||
@api_view(["POST"])
|
||||
@permission_classes([IsAuthenticated])
|
||||
def activate_totp(request):
|
||||
"""Verify TOTP code and activate MFA."""
|
||||
from allauth.mfa.models import Authenticator
|
||||
from allauth.mfa.recovery_codes.internal import auth as recovery_auth
|
||||
from allauth.mfa.totp.internal import auth as totp_auth
|
||||
|
||||
user = request.user
|
||||
code = request.data.get("code", "").strip()
|
||||
|
||||
if not code:
|
||||
return Response(
|
||||
{"success": False, "error": "Verification code is required"},
|
||||
status=status.HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
|
||||
# Get pending secret from session
|
||||
secret = request.session.get("pending_totp_secret")
|
||||
if not secret:
|
||||
return Response(
|
||||
{"success": False, "error": "No pending TOTP setup. Please start setup again."},
|
||||
status=status.HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
|
||||
# Verify the code
|
||||
if not totp_auth.validate_totp_code(secret, code):
|
||||
return Response(
|
||||
{"success": False, "error": "Invalid verification code"},
|
||||
status=status.HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
|
||||
# Check if already has TOTP
|
||||
if Authenticator.objects.filter(user=user, type=Authenticator.Type.TOTP).exists():
|
||||
return Response(
|
||||
{"success": False, "error": "TOTP is already enabled"},
|
||||
status=status.HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
|
||||
# Create TOTP authenticator
|
||||
Authenticator.objects.create(
|
||||
user=user,
|
||||
type=Authenticator.Type.TOTP,
|
||||
data={"secret": secret},
|
||||
)
|
||||
|
||||
# Generate recovery codes
|
||||
codes = recovery_auth.generate_recovery_codes()
|
||||
Authenticator.objects.create(
|
||||
user=user,
|
||||
type=Authenticator.Type.RECOVERY_CODES,
|
||||
data={"codes": codes},
|
||||
)
|
||||
|
||||
# Clear session
|
||||
del request.session["pending_totp_secret"]
|
||||
|
||||
return Response({
|
||||
"success": True,
|
||||
"message": "Two-factor authentication enabled",
|
||||
"recovery_codes": codes,
|
||||
})
|
||||
|
||||
|
||||
@extend_schema(
|
||||
operation_id="deactivate_totp",
|
||||
summary="Disable TOTP authentication",
|
||||
description="Removes TOTP from the user's account after password verification.",
|
||||
request={
|
||||
"application/json": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"password": {
|
||||
"type": "string",
|
||||
"description": "Current password for confirmation",
|
||||
}
|
||||
},
|
||||
"required": ["password"],
|
||||
}
|
||||
},
|
||||
responses={
|
||||
200: {
|
||||
"description": "TOTP disabled",
|
||||
"example": {"success": True, "message": "Two-factor authentication disabled"},
|
||||
},
|
||||
400: {"description": "Invalid password or MFA not enabled"},
|
||||
},
|
||||
tags=["MFA"],
|
||||
)
|
||||
@api_view(["POST"])
|
||||
@permission_classes([IsAuthenticated])
|
||||
def deactivate_totp(request):
|
||||
"""Disable TOTP authentication."""
|
||||
from allauth.mfa.models import Authenticator
|
||||
|
||||
user = request.user
|
||||
password = request.data.get("password", "")
|
||||
|
||||
# Verify password
|
||||
if not user.check_password(password):
|
||||
return Response(
|
||||
{"success": False, "error": "Invalid password"},
|
||||
status=status.HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
|
||||
# Remove TOTP and recovery codes
|
||||
deleted_count, _ = Authenticator.objects.filter(
|
||||
user=user,
|
||||
type__in=[Authenticator.Type.TOTP, Authenticator.Type.RECOVERY_CODES]
|
||||
).delete()
|
||||
|
||||
if deleted_count == 0:
|
||||
return Response(
|
||||
{"success": False, "error": "Two-factor authentication is not enabled"},
|
||||
status=status.HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
|
||||
return Response({
|
||||
"success": True,
|
||||
"message": "Two-factor authentication disabled",
|
||||
})
|
||||
|
||||
|
||||
@extend_schema(
|
||||
operation_id="verify_totp",
|
||||
summary="Verify TOTP code during login",
|
||||
description="Verifies the TOTP code as part of the login process.",
|
||||
request={
|
||||
"application/json": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"code": {"type": "string", "description": "6-digit TOTP code"}
|
||||
},
|
||||
"required": ["code"],
|
||||
}
|
||||
},
|
||||
responses={
|
||||
200: {"description": "Code verified", "example": {"success": True}},
|
||||
400: {"description": "Invalid code"},
|
||||
},
|
||||
tags=["MFA"],
|
||||
)
|
||||
@api_view(["POST"])
|
||||
@permission_classes([IsAuthenticated])
|
||||
def verify_totp(request):
|
||||
"""Verify TOTP code."""
|
||||
from allauth.mfa.models import Authenticator
|
||||
from allauth.mfa.totp.internal import auth as totp_auth
|
||||
|
||||
user = request.user
|
||||
code = request.data.get("code", "").strip()
|
||||
|
||||
if not code:
|
||||
return Response(
|
||||
{"success": False, "error": "Verification code is required"},
|
||||
status=status.HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
|
||||
try:
|
||||
authenticator = Authenticator.objects.get(user=user, type=Authenticator.Type.TOTP)
|
||||
secret = authenticator.data.get("secret")
|
||||
|
||||
if totp_auth.validate_totp_code(secret, code):
|
||||
return Response({"success": True})
|
||||
else:
|
||||
return Response(
|
||||
{"success": False, "error": "Invalid verification code"},
|
||||
status=status.HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
except Authenticator.DoesNotExist:
|
||||
return Response(
|
||||
{"success": False, "error": "TOTP is not enabled"},
|
||||
status=status.HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
|
||||
|
||||
@extend_schema(
|
||||
operation_id="regenerate_recovery_codes",
|
||||
summary="Regenerate recovery codes",
|
||||
description="Generates new recovery codes (invalidates old ones).",
|
||||
request={
|
||||
"application/json": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"password": {"type": "string", "description": "Current password"}
|
||||
},
|
||||
"required": ["password"],
|
||||
}
|
||||
},
|
||||
responses={
|
||||
200: {
|
||||
"description": "New recovery codes",
|
||||
"example": {"success": True, "recovery_codes": ["ABCD1234", "EFGH5678"]},
|
||||
},
|
||||
400: {"description": "Invalid password or MFA not enabled"},
|
||||
},
|
||||
tags=["MFA"],
|
||||
)
|
||||
@api_view(["POST"])
|
||||
@permission_classes([IsAuthenticated])
|
||||
def regenerate_recovery_codes(request):
|
||||
"""Regenerate recovery codes."""
|
||||
from allauth.mfa.models import Authenticator
|
||||
from allauth.mfa.recovery_codes.internal import auth as recovery_auth
|
||||
|
||||
user = request.user
|
||||
password = request.data.get("password", "")
|
||||
|
||||
# Verify password
|
||||
if not user.check_password(password):
|
||||
return Response(
|
||||
{"success": False, "error": "Invalid password"},
|
||||
status=status.HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
|
||||
# Check if TOTP is enabled
|
||||
if not Authenticator.objects.filter(user=user, type=Authenticator.Type.TOTP).exists():
|
||||
return Response(
|
||||
{"success": False, "error": "Two-factor authentication is not enabled"},
|
||||
status=status.HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
|
||||
# Generate new codes
|
||||
codes = recovery_auth.generate_recovery_codes()
|
||||
|
||||
# Update or create recovery codes authenticator
|
||||
authenticator, created = Authenticator.objects.update_or_create(
|
||||
user=user,
|
||||
type=Authenticator.Type.RECOVERY_CODES,
|
||||
defaults={"data": {"codes": codes}},
|
||||
)
|
||||
|
||||
return Response({
|
||||
"success": True,
|
||||
"recovery_codes": codes,
|
||||
})
|
||||
@@ -5,21 +5,21 @@ This module contains all serializers related to authentication, user accounts,
|
||||
profiles, top lists, and user statistics.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from rest_framework import serializers
|
||||
from drf_spectacular.utils import (
|
||||
extend_schema_serializer,
|
||||
extend_schema_field,
|
||||
OpenApiExample,
|
||||
)
|
||||
from django.contrib.auth.password_validation import validate_password
|
||||
from django.utils.crypto import get_random_string
|
||||
from django.contrib.auth import get_user_model
|
||||
from django.utils import timezone
|
||||
from datetime import timedelta
|
||||
from apps.accounts.models import PasswordReset
|
||||
from typing import Any
|
||||
|
||||
from django.contrib.auth import get_user_model
|
||||
from django.contrib.auth.password_validation import validate_password
|
||||
from django.utils import timezone
|
||||
from django.utils.crypto import get_random_string
|
||||
from drf_spectacular.utils import (
|
||||
OpenApiExample,
|
||||
extend_schema_field,
|
||||
extend_schema_serializer,
|
||||
)
|
||||
from rest_framework import serializers
|
||||
|
||||
from apps.accounts.models import PasswordReset
|
||||
|
||||
UserModel = get_user_model()
|
||||
|
||||
@@ -192,11 +192,13 @@ class SignupInputSerializer(serializers.ModelSerializer):
|
||||
|
||||
def _send_verification_email(self, user):
|
||||
"""Send email verification to the user."""
|
||||
from apps.accounts.models import EmailVerification
|
||||
import logging
|
||||
|
||||
from django.contrib.sites.shortcuts import get_current_site
|
||||
from django.utils.crypto import get_random_string
|
||||
from django_forwardemail.services import EmailService
|
||||
from django.contrib.sites.shortcuts import get_current_site
|
||||
import logging
|
||||
|
||||
from apps.accounts.models import EmailVerification
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -436,7 +438,7 @@ class UserProfileOutputSerializer(serializers.Serializer):
|
||||
return obj.get_avatar_url()
|
||||
|
||||
@extend_schema_field(serializers.DictField())
|
||||
def get_user(self, obj) -> Dict[str, Any]:
|
||||
def get_user(self, obj) -> dict[str, Any]:
|
||||
return {
|
||||
"username": obj.user.username,
|
||||
"date_joined": obj.user.date_joined,
|
||||
|
||||
@@ -6,15 +6,15 @@ Main authentication serializers are imported directly from the parent serializer
|
||||
"""
|
||||
|
||||
from .social import (
|
||||
ConnectedProviderSerializer,
|
||||
AvailableProviderSerializer,
|
||||
SocialAuthStatusSerializer,
|
||||
ConnectedProviderSerializer,
|
||||
ConnectedProvidersListOutputSerializer,
|
||||
ConnectProviderInputSerializer,
|
||||
ConnectProviderOutputSerializer,
|
||||
DisconnectProviderOutputSerializer,
|
||||
SocialProviderListOutputSerializer,
|
||||
ConnectedProvidersListOutputSerializer,
|
||||
SocialAuthStatusSerializer,
|
||||
SocialProviderErrorSerializer,
|
||||
SocialProviderListOutputSerializer,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
|
||||
@@ -5,8 +5,8 @@ Serializers for handling social provider connection/disconnection requests
|
||||
and responses in the ThrillWiki API.
|
||||
"""
|
||||
|
||||
from rest_framework import serializers
|
||||
from django.contrib.auth import get_user_model
|
||||
from rest_framework import serializers
|
||||
|
||||
User = get_user_model()
|
||||
|
||||
|
||||
@@ -5,29 +5,30 @@ This module contains URL patterns for core authentication functionality only.
|
||||
User profiles and top lists are handled by the dedicated accounts app.
|
||||
"""
|
||||
|
||||
from django.urls import path, include
|
||||
from django.urls import include, path
|
||||
from rest_framework_simplejwt.views import TokenRefreshView
|
||||
|
||||
from . import mfa as mfa_views
|
||||
from .views import (
|
||||
# Main auth views
|
||||
LoginAPIView,
|
||||
SignupAPIView,
|
||||
LogoutAPIView,
|
||||
CurrentUserAPIView,
|
||||
PasswordResetAPIView,
|
||||
PasswordChangeAPIView,
|
||||
SocialProvidersAPIView,
|
||||
AuthStatusAPIView,
|
||||
# Email verification views
|
||||
EmailVerificationAPIView,
|
||||
ResendVerificationAPIView,
|
||||
# Social provider management views
|
||||
AvailableProvidersAPIView,
|
||||
ConnectedProvidersAPIView,
|
||||
ConnectProviderAPIView,
|
||||
CurrentUserAPIView,
|
||||
DisconnectProviderAPIView,
|
||||
# Email verification views
|
||||
EmailVerificationAPIView,
|
||||
# Main auth views
|
||||
LoginAPIView,
|
||||
LogoutAPIView,
|
||||
PasswordChangeAPIView,
|
||||
PasswordResetAPIView,
|
||||
ResendVerificationAPIView,
|
||||
SignupAPIView,
|
||||
SocialAuthStatusAPIView,
|
||||
SocialProvidersAPIView,
|
||||
)
|
||||
from rest_framework_simplejwt.views import TokenRefreshView
|
||||
|
||||
|
||||
urlpatterns = [
|
||||
# Core authentication endpoints
|
||||
@@ -98,6 +99,14 @@ urlpatterns = [
|
||||
ResendVerificationAPIView.as_view(),
|
||||
name="auth-resend-verification",
|
||||
),
|
||||
|
||||
# MFA (Multi-Factor Authentication) endpoints
|
||||
path("mfa/status/", mfa_views.get_mfa_status, name="auth-mfa-status"),
|
||||
path("mfa/totp/setup/", mfa_views.setup_totp, name="auth-mfa-totp-setup"),
|
||||
path("mfa/totp/activate/", mfa_views.activate_totp, name="auth-mfa-totp-activate"),
|
||||
path("mfa/totp/deactivate/", mfa_views.deactivate_totp, name="auth-mfa-totp-deactivate"),
|
||||
path("mfa/totp/verify/", mfa_views.verify_totp, name="auth-mfa-totp-verify"),
|
||||
path("mfa/recovery-codes/regenerate/", mfa_views.regenerate_recovery_codes, name="auth-mfa-recovery-regenerate"),
|
||||
]
|
||||
|
||||
# Note: User profiles and top lists functionality is now handled by the accounts app
|
||||
|
||||
@@ -6,44 +6,46 @@ login, signup, logout, password management, social authentication,
|
||||
user profiles, and top lists.
|
||||
"""
|
||||
|
||||
from .serializers_package.social import (
|
||||
ConnectedProviderSerializer,
|
||||
AvailableProviderSerializer,
|
||||
SocialAuthStatusSerializer,
|
||||
ConnectProviderInputSerializer,
|
||||
ConnectProviderOutputSerializer,
|
||||
DisconnectProviderOutputSerializer,
|
||||
SocialProviderErrorSerializer,
|
||||
)
|
||||
from apps.accounts.services.social_provider_service import SocialProviderService
|
||||
from django.contrib.auth import authenticate, login, logout, get_user_model
|
||||
from typing import cast # added 'cast'
|
||||
|
||||
from django.contrib.auth import authenticate, get_user_model, login, logout
|
||||
from django.contrib.sites.shortcuts import get_current_site
|
||||
from django.core.exceptions import ValidationError
|
||||
from django.db.models import Q
|
||||
from typing import Optional, cast # added 'cast'
|
||||
from django.http import HttpRequest # new import
|
||||
from drf_spectacular.utils import extend_schema, extend_schema_view
|
||||
from rest_framework import status
|
||||
from rest_framework.views import APIView
|
||||
from rest_framework.permissions import AllowAny, IsAuthenticated
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.permissions import AllowAny, IsAuthenticated
|
||||
from drf_spectacular.utils import extend_schema, extend_schema_view
|
||||
from rest_framework.views import APIView
|
||||
|
||||
from apps.accounts.services.social_provider_service import SocialProviderService
|
||||
|
||||
# Import directly from the auth serializers.py file (not the serializers package)
|
||||
from .serializers import (
|
||||
AuthStatusOutputSerializer,
|
||||
# Authentication serializers
|
||||
LoginInputSerializer,
|
||||
LoginOutputSerializer,
|
||||
SignupInputSerializer,
|
||||
SignupOutputSerializer,
|
||||
LogoutOutputSerializer,
|
||||
UserOutputSerializer,
|
||||
PasswordResetInputSerializer,
|
||||
PasswordResetOutputSerializer,
|
||||
PasswordChangeInputSerializer,
|
||||
PasswordChangeOutputSerializer,
|
||||
PasswordResetInputSerializer,
|
||||
PasswordResetOutputSerializer,
|
||||
SignupInputSerializer,
|
||||
SignupOutputSerializer,
|
||||
SocialProviderOutputSerializer,
|
||||
AuthStatusOutputSerializer,
|
||||
UserOutputSerializer,
|
||||
)
|
||||
from .serializers_package.social import (
|
||||
AvailableProviderSerializer,
|
||||
ConnectedProviderSerializer,
|
||||
ConnectProviderInputSerializer,
|
||||
ConnectProviderOutputSerializer,
|
||||
DisconnectProviderOutputSerializer,
|
||||
SocialAuthStatusSerializer,
|
||||
SocialProviderErrorSerializer,
|
||||
)
|
||||
|
||||
# Handle optional dependencies with fallback classes
|
||||
@@ -62,10 +64,7 @@ try:
|
||||
|
||||
# Ensure the imported object is a class/type that can be used as a base class.
|
||||
# If it's not a type for any reason, fall back to the safe mixin.
|
||||
if isinstance(_ImportedTurnstileMixin, type):
|
||||
TurnstileMixin = _ImportedTurnstileMixin
|
||||
else:
|
||||
TurnstileMixin = FallbackTurnstileMixin
|
||||
TurnstileMixin = _ImportedTurnstileMixin if isinstance(_ImportedTurnstileMixin, type) else FallbackTurnstileMixin
|
||||
except Exception:
|
||||
# Catch any import errors or unexpected exceptions and use the fallback mixin.
|
||||
TurnstileMixin = FallbackTurnstileMixin
|
||||
@@ -88,7 +87,7 @@ def _get_underlying_request(request: Request) -> HttpRequest:
|
||||
# Helper: encapsulate user lookup + authenticate to reduce complexity in view
|
||||
def _authenticate_user_by_lookup(
|
||||
email_or_username: str, password: str, request: Request
|
||||
) -> Optional[UserModel]:
|
||||
) -> UserModel | None:
|
||||
"""
|
||||
Try a single optimized query to find a user by email OR username then authenticate.
|
||||
Returns authenticated user or None.
|
||||
@@ -199,7 +198,7 @@ class LoginAPIView(APIView):
|
||||
else:
|
||||
return Response(
|
||||
{
|
||||
"error": "Email verification required",
|
||||
"error": "Email verification required",
|
||||
"message": "Please verify your email address before logging in. Check your email for a verification link.",
|
||||
"email_verification_required": True
|
||||
},
|
||||
@@ -246,7 +245,7 @@ class SignupAPIView(APIView):
|
||||
serializer = SignupInputSerializer(data=request.data, context={"request": request})
|
||||
if serializer.is_valid():
|
||||
user = serializer.save()
|
||||
|
||||
|
||||
# Don't log in the user immediately - they need to verify their email first
|
||||
response_serializer = SignupOutputSerializer(
|
||||
{
|
||||
@@ -754,23 +753,23 @@ class EmailVerificationAPIView(APIView):
|
||||
|
||||
def get(self, request: Request, token: str) -> Response:
|
||||
from apps.accounts.models import EmailVerification
|
||||
|
||||
|
||||
try:
|
||||
verification = EmailVerification.objects.select_related('user').get(token=token)
|
||||
user = verification.user
|
||||
|
||||
|
||||
# Activate the user
|
||||
user.is_active = True
|
||||
user.save()
|
||||
|
||||
|
||||
# Delete the verification record
|
||||
verification.delete()
|
||||
|
||||
|
||||
return Response({
|
||||
"message": "Email verified successfully. You can now log in.",
|
||||
"success": True
|
||||
})
|
||||
|
||||
|
||||
except EmailVerification.DoesNotExist:
|
||||
return Response(
|
||||
{"error": "Invalid or expired verification token"},
|
||||
@@ -798,45 +797,46 @@ class ResendVerificationAPIView(APIView):
|
||||
authentication_classes = []
|
||||
|
||||
def post(self, request: Request) -> Response:
|
||||
from apps.accounts.models import EmailVerification
|
||||
from django.contrib.sites.shortcuts import get_current_site
|
||||
from django.utils.crypto import get_random_string
|
||||
from django_forwardemail.services import EmailService
|
||||
from django.contrib.sites.shortcuts import get_current_site
|
||||
|
||||
|
||||
from apps.accounts.models import EmailVerification
|
||||
|
||||
email = request.data.get('email')
|
||||
if not email:
|
||||
return Response(
|
||||
{"error": "Email address is required"},
|
||||
status=status.HTTP_400_BAD_REQUEST
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
user = UserModel.objects.get(email__iexact=email.strip().lower())
|
||||
|
||||
|
||||
# Don't resend if user is already active
|
||||
if user.is_active:
|
||||
return Response(
|
||||
{"error": "Email is already verified"},
|
||||
status=status.HTTP_400_BAD_REQUEST
|
||||
)
|
||||
|
||||
|
||||
# Create or update verification record
|
||||
verification, created = EmailVerification.objects.get_or_create(
|
||||
user=user,
|
||||
defaults={'token': get_random_string(64)}
|
||||
)
|
||||
|
||||
|
||||
if not created:
|
||||
# Update existing token and timestamp
|
||||
verification.token = get_random_string(64)
|
||||
verification.save()
|
||||
|
||||
|
||||
# Send verification email
|
||||
site = get_current_site(_get_underlying_request(request))
|
||||
verification_url = request.build_absolute_uri(
|
||||
f"/api/v1/auth/verify-email/{verification.token}/"
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
EmailService.send_email(
|
||||
to=user.email,
|
||||
@@ -854,22 +854,22 @@ The ThrillWiki Team
|
||||
""".strip(),
|
||||
site=site,
|
||||
)
|
||||
|
||||
|
||||
return Response({
|
||||
"message": "Verification email sent successfully",
|
||||
"success": True
|
||||
})
|
||||
|
||||
|
||||
except Exception as e:
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.error(f"Failed to send verification email to {user.email}: {e}")
|
||||
|
||||
|
||||
return Response(
|
||||
{"error": "Failed to send verification email"},
|
||||
status=status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
)
|
||||
|
||||
|
||||
except UserModel.DoesNotExist:
|
||||
# Don't reveal whether email exists
|
||||
return Response({
|
||||
|
||||
@@ -4,6 +4,7 @@ Centralized from apps.core.urls
|
||||
"""
|
||||
|
||||
from django.urls import path
|
||||
|
||||
from . import views
|
||||
|
||||
# Entity search endpoints - migrated from apps.core.urls
|
||||
|
||||
@@ -8,18 +8,20 @@ Caching Strategy:
|
||||
- EntityNotFoundView: No caching - POST requests with context-specific data
|
||||
"""
|
||||
|
||||
from rest_framework.views import APIView
|
||||
from rest_framework.response import Response
|
||||
|
||||
import contextlib
|
||||
|
||||
from drf_spectacular.utils import extend_schema
|
||||
from rest_framework import status
|
||||
from rest_framework.permissions import AllowAny
|
||||
from typing import Optional, List
|
||||
from drf_spectacular.utils import extend_schema
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.views import APIView
|
||||
|
||||
from apps.core.services.entity_fuzzy_matching import (
|
||||
entity_fuzzy_matcher,
|
||||
EntityType,
|
||||
)
|
||||
from apps.core.decorators.cache_decorators import cache_api_response
|
||||
from apps.core.services.entity_fuzzy_matching import (
|
||||
EntityType,
|
||||
entity_fuzzy_matcher,
|
||||
)
|
||||
|
||||
|
||||
class EntityFuzzySearchView(APIView):
|
||||
@@ -199,10 +201,8 @@ class EntityNotFoundView(APIView):
|
||||
# Determine entity types to search based on context
|
||||
entity_types = []
|
||||
if entity_type_hint:
|
||||
try:
|
||||
with contextlib.suppress(ValueError):
|
||||
entity_types = [EntityType(entity_type_hint)]
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# If we have park context, prioritize ride searches
|
||||
if context.get("park_slug") and not entity_types:
|
||||
@@ -344,7 +344,7 @@ class QuickEntitySuggestionView(APIView):
|
||||
|
||||
# Utility function for other views to use
|
||||
def get_entity_suggestions(
|
||||
query: str, entity_types: Optional[List[str]] = None, user=None
|
||||
query: str, entity_types: list[str] | None = None, user=None
|
||||
):
|
||||
"""
|
||||
Utility function for other Django views to get entity suggestions.
|
||||
|
||||
@@ -4,6 +4,7 @@ Centralized from apps.email_service.urls
|
||||
"""
|
||||
|
||||
from django.urls import path
|
||||
|
||||
from . import views
|
||||
|
||||
urlpatterns = [
|
||||
|
||||
@@ -3,13 +3,13 @@ Centralized email service API views.
|
||||
Migrated from apps.email_service.views
|
||||
"""
|
||||
|
||||
from rest_framework.views import APIView
|
||||
from rest_framework.response import Response
|
||||
from django.contrib.sites.shortcuts import get_current_site
|
||||
from django_forwardemail.services import EmailService
|
||||
from drf_spectacular.utils import extend_schema
|
||||
from rest_framework import status
|
||||
from rest_framework.permissions import AllowAny
|
||||
from django.contrib.sites.shortcuts import get_current_site
|
||||
from drf_spectacular.utils import extend_schema
|
||||
from django_forwardemail.services import EmailService
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.views import APIView
|
||||
|
||||
|
||||
@extend_schema(
|
||||
|
||||
@@ -4,7 +4,7 @@ History API URLs
|
||||
URL patterns for history-related API endpoints.
|
||||
"""
|
||||
|
||||
from django.urls import path, include
|
||||
from django.urls import include, path
|
||||
from rest_framework.routers import DefaultRouter
|
||||
|
||||
from .views import (
|
||||
|
||||
@@ -5,18 +5,21 @@ This module provides ViewSets for accessing historical data and change tracking
|
||||
across all models in the ThrillWiki system using django-pghistory.
|
||||
"""
|
||||
|
||||
from drf_spectacular.utils import extend_schema, extend_schema_view, OpenApiParameter
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
import pghistory.models
|
||||
from django.db.models import Count, QuerySet
|
||||
from django.shortcuts import get_object_or_404
|
||||
from drf_spectacular.types import OpenApiTypes
|
||||
from drf_spectacular.utils import OpenApiParameter, extend_schema, extend_schema_view
|
||||
from rest_framework import serializers as drf_serializers
|
||||
from rest_framework.filters import OrderingFilter
|
||||
from rest_framework.permissions import AllowAny
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.viewsets import ReadOnlyModelViewSet
|
||||
from rest_framework.request import Request
|
||||
from typing import Optional, cast, Sequence
|
||||
from django.shortcuts import get_object_or_404
|
||||
from django.db.models import Count, QuerySet
|
||||
import pghistory.models
|
||||
from datetime import datetime
|
||||
|
||||
# Import models
|
||||
from apps.parks.models import Park
|
||||
@@ -24,7 +27,6 @@ from apps.rides.models import Ride
|
||||
|
||||
# Import serializers
|
||||
from .. import serializers as history_serializers
|
||||
from rest_framework import serializers as drf_serializers
|
||||
|
||||
# Minimal fallback serializer used when a specific serializer symbol is missing.
|
||||
|
||||
@@ -79,7 +81,7 @@ ALL_TRACKED_MODELS: Sequence[str] = [
|
||||
# --- Helper utilities to reduce duplicated logic / cognitive complexity ---
|
||||
|
||||
|
||||
def _parse_date(date_str: Optional[str]) -> Optional[datetime]:
|
||||
def _parse_date(date_str: str | None) -> datetime | None:
|
||||
if not date_str:
|
||||
return None
|
||||
try:
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from django.urls import path
|
||||
|
||||
from .views import GenerateUploadURLView
|
||||
|
||||
urlpatterns = [
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
from rest_framework.views import APIView
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.permissions import IsAuthenticated
|
||||
from rest_framework import status
|
||||
from apps.core.utils.cloudflare import get_direct_upload_url
|
||||
from django.core.exceptions import ImproperlyConfigured
|
||||
import requests
|
||||
import logging
|
||||
|
||||
import requests
|
||||
from django.core.exceptions import ImproperlyConfigured
|
||||
from rest_framework import status
|
||||
from rest_framework.permissions import IsAuthenticated
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.views import APIView
|
||||
|
||||
from apps.core.utils.cloudflare import get_direct_upload_url
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class GenerateUploadURLView(APIView):
|
||||
@@ -29,7 +31,7 @@ class GenerateUploadURLView(APIView):
|
||||
{"detail": "Failed to generate upload URL."},
|
||||
status=status.HTTP_502_BAD_GATEWAY
|
||||
)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
logger.exception("Unexpected error generating upload URL")
|
||||
return Response(
|
||||
{"detail": "An unexpected error occurred."},
|
||||
|
||||
@@ -4,6 +4,7 @@ Migrated from apps.core.urls.map_urls to centralized API structure.
|
||||
"""
|
||||
|
||||
from django.urls import path
|
||||
|
||||
from . import views
|
||||
|
||||
# Map API endpoints - migrated from apps.core.urls.map_urls
|
||||
|
||||
@@ -12,30 +12,31 @@ Caching Strategy:
|
||||
|
||||
import logging
|
||||
|
||||
from django.core.cache import cache
|
||||
from django.http import HttpRequest
|
||||
from django.db.models import Q
|
||||
from django.contrib.gis.geos import Polygon
|
||||
from rest_framework.views import APIView
|
||||
from rest_framework.response import Response
|
||||
from rest_framework import status
|
||||
from rest_framework.permissions import AllowAny, IsAdminUser
|
||||
from django.core.cache import cache
|
||||
from django.db.models import Q
|
||||
from django.http import HttpRequest
|
||||
from drf_spectacular.types import OpenApiTypes
|
||||
from drf_spectacular.utils import (
|
||||
OpenApiExample,
|
||||
OpenApiParameter,
|
||||
extend_schema,
|
||||
extend_schema_view,
|
||||
OpenApiParameter,
|
||||
OpenApiExample,
|
||||
)
|
||||
from drf_spectacular.types import OpenApiTypes
|
||||
from rest_framework import status
|
||||
from rest_framework.permissions import AllowAny, IsAdminUser
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.views import APIView
|
||||
|
||||
from apps.core.decorators.cache_decorators import cache_api_response
|
||||
from apps.core.services.enhanced_cache_service import EnhancedCacheService
|
||||
from apps.parks.models import Park
|
||||
from apps.rides.models import Ride
|
||||
from apps.core.services.enhanced_cache_service import EnhancedCacheService
|
||||
from apps.core.decorators.cache_decorators import cache_api_response
|
||||
|
||||
from ..serializers.maps import (
|
||||
MapLocationDetailSerializer,
|
||||
MapLocationsResponseSerializer,
|
||||
MapSearchResponseSerializer,
|
||||
MapLocationDetailSerializer,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -7,7 +7,8 @@ TypeScript interfaces, providing immediate feedback during development.
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Dict, Any
|
||||
from typing import Any
|
||||
|
||||
from django.conf import settings
|
||||
from django.http import JsonResponse
|
||||
from django.utils.deprecation import MiddlewareMixin
|
||||
@@ -19,52 +20,49 @@ logger = logging.getLogger(__name__)
|
||||
class ContractValidationMiddleware(MiddlewareMixin):
|
||||
"""
|
||||
Development-only middleware that validates API responses against expected contracts.
|
||||
|
||||
|
||||
This middleware:
|
||||
1. Checks all API responses for contract compliance
|
||||
2. Logs warnings when responses don't match expected TypeScript interfaces
|
||||
3. Specifically validates filter metadata structure
|
||||
4. Alerts when categorical filters are strings instead of objects
|
||||
|
||||
|
||||
Only active when DEBUG=True to avoid performance impact in production.
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self, get_response):
|
||||
super().__init__(get_response)
|
||||
self.get_response = get_response
|
||||
self.enabled = getattr(settings, 'DEBUG', False)
|
||||
|
||||
|
||||
if self.enabled:
|
||||
logger.info("Contract validation middleware enabled (DEBUG mode)")
|
||||
|
||||
|
||||
def process_response(self, request, response):
|
||||
"""Process API responses to check for contract violations."""
|
||||
|
||||
|
||||
if not self.enabled:
|
||||
return response
|
||||
|
||||
|
||||
# Only validate API endpoints
|
||||
if not request.path.startswith('/api/'):
|
||||
return response
|
||||
|
||||
|
||||
# Only validate JSON responses
|
||||
if not isinstance(response, (JsonResponse, Response)):
|
||||
return response
|
||||
|
||||
|
||||
# Only validate successful responses (2xx status codes)
|
||||
if not (200 <= response.status_code < 300):
|
||||
return response
|
||||
|
||||
|
||||
try:
|
||||
# Get response data
|
||||
if isinstance(response, Response):
|
||||
data = response.data
|
||||
else:
|
||||
data = json.loads(response.content.decode('utf-8'))
|
||||
|
||||
data = response.data if isinstance(response, Response) else json.loads(response.content.decode('utf-8'))
|
||||
|
||||
# Validate the response
|
||||
self._validate_response_contract(request.path, data)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
# Log validation errors but don't break the response
|
||||
logger.warning(
|
||||
@@ -76,55 +74,55 @@ class ContractValidationMiddleware(MiddlewareMixin):
|
||||
'validation_error': str(e)
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
return response
|
||||
|
||||
|
||||
def _validate_response_contract(self, path: str, data: Any) -> None:
|
||||
"""Validate response data against expected contracts."""
|
||||
|
||||
|
||||
# Check for filter metadata endpoints
|
||||
if 'filter-options' in path or 'filter_options' in path:
|
||||
self._validate_filter_metadata(path, data)
|
||||
|
||||
|
||||
# Check for hybrid filtering endpoints
|
||||
if 'hybrid' in path:
|
||||
self._validate_hybrid_response(path, data)
|
||||
|
||||
|
||||
# Check for pagination responses
|
||||
if isinstance(data, dict) and 'results' in data:
|
||||
self._validate_pagination_response(path, data)
|
||||
|
||||
|
||||
# Check for common contract violations
|
||||
self._validate_common_patterns(path, data)
|
||||
|
||||
|
||||
def _validate_filter_metadata(self, path: str, data: Any) -> None:
|
||||
"""Validate filter metadata structure."""
|
||||
|
||||
|
||||
if not isinstance(data, dict):
|
||||
self._log_contract_violation(
|
||||
path,
|
||||
path,
|
||||
"FILTER_METADATA_NOT_DICT",
|
||||
f"Filter metadata should be a dictionary, got {type(data).__name__}"
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
# Check for categorical filters
|
||||
if 'categorical' in data:
|
||||
categorical = data['categorical']
|
||||
if isinstance(categorical, dict):
|
||||
for filter_name, filter_options in categorical.items():
|
||||
self._validate_categorical_filter(path, filter_name, filter_options)
|
||||
|
||||
|
||||
# Check for ranges
|
||||
if 'ranges' in data:
|
||||
ranges = data['ranges']
|
||||
if isinstance(ranges, dict):
|
||||
for range_name, range_data in ranges.items():
|
||||
self._validate_range_filter(path, range_name, range_data)
|
||||
|
||||
|
||||
def _validate_categorical_filter(self, path: str, filter_name: str, filter_options: Any) -> None:
|
||||
"""Validate categorical filter options format."""
|
||||
|
||||
|
||||
if not isinstance(filter_options, list):
|
||||
self._log_contract_violation(
|
||||
path,
|
||||
@@ -132,7 +130,7 @@ class ContractValidationMiddleware(MiddlewareMixin):
|
||||
f"Categorical filter '{filter_name}' should be an array, got {type(filter_options).__name__}"
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
for i, option in enumerate(filter_options):
|
||||
if isinstance(option, str):
|
||||
# CRITICAL: This is the main contract violation we're trying to catch
|
||||
@@ -163,10 +161,10 @@ class ContractValidationMiddleware(MiddlewareMixin):
|
||||
"INVALID_COUNT_TYPE",
|
||||
f"Categorical filter '{filter_name}' option {i} 'count' should be a number, got {type(option['count']).__name__}"
|
||||
)
|
||||
|
||||
|
||||
def _validate_range_filter(self, path: str, range_name: str, range_data: Any) -> None:
|
||||
"""Validate range filter format."""
|
||||
|
||||
|
||||
if not isinstance(range_data, dict):
|
||||
self._log_contract_violation(
|
||||
path,
|
||||
@@ -174,7 +172,7 @@ class ContractValidationMiddleware(MiddlewareMixin):
|
||||
f"Range filter '{range_name}' should be an object, got {type(range_data).__name__}"
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
# Check required properties
|
||||
required_props = ['min', 'max']
|
||||
for prop in required_props:
|
||||
@@ -184,7 +182,7 @@ class ContractValidationMiddleware(MiddlewareMixin):
|
||||
"MISSING_RANGE_PROPERTY",
|
||||
f"Range filter '{range_name}' missing required property '{prop}'"
|
||||
)
|
||||
|
||||
|
||||
# Check step property
|
||||
if 'step' in range_data and not isinstance(range_data['step'], (int, float)):
|
||||
self._log_contract_violation(
|
||||
@@ -192,13 +190,13 @@ class ContractValidationMiddleware(MiddlewareMixin):
|
||||
"INVALID_STEP_TYPE",
|
||||
f"Range filter '{range_name}' 'step' should be a number, got {type(range_data['step']).__name__}"
|
||||
)
|
||||
|
||||
|
||||
def _validate_hybrid_response(self, path: str, data: Any) -> None:
|
||||
"""Validate hybrid filtering response structure."""
|
||||
|
||||
|
||||
if not isinstance(data, dict):
|
||||
return
|
||||
|
||||
|
||||
# Check for strategy field
|
||||
if 'strategy' in data:
|
||||
strategy = data['strategy']
|
||||
@@ -208,14 +206,14 @@ class ContractValidationMiddleware(MiddlewareMixin):
|
||||
"INVALID_STRATEGY_VALUE",
|
||||
f"Hybrid response strategy should be 'client_side' or 'server_side', got '{strategy}'"
|
||||
)
|
||||
|
||||
|
||||
# Check filter_metadata structure
|
||||
if 'filter_metadata' in data:
|
||||
self._validate_filter_metadata(path, data['filter_metadata'])
|
||||
|
||||
def _validate_pagination_response(self, path: str, data: Dict[str, Any]) -> None:
|
||||
|
||||
def _validate_pagination_response(self, path: str, data: dict[str, Any]) -> None:
|
||||
"""Validate pagination response structure."""
|
||||
|
||||
|
||||
# Check for required pagination fields
|
||||
required_fields = ['count', 'results']
|
||||
for field in required_fields:
|
||||
@@ -225,7 +223,7 @@ class ContractValidationMiddleware(MiddlewareMixin):
|
||||
"MISSING_PAGINATION_FIELD",
|
||||
f"Pagination response missing required field '{field}'"
|
||||
)
|
||||
|
||||
|
||||
# Check results is array
|
||||
if 'results' in data and not isinstance(data['results'], list):
|
||||
self._log_contract_violation(
|
||||
@@ -233,17 +231,17 @@ class ContractValidationMiddleware(MiddlewareMixin):
|
||||
"RESULTS_NOT_ARRAY",
|
||||
f"Pagination 'results' should be an array, got {type(data['results']).__name__}"
|
||||
)
|
||||
|
||||
|
||||
def _validate_common_patterns(self, path: str, data: Any) -> None:
|
||||
"""Validate common API response patterns."""
|
||||
|
||||
|
||||
if isinstance(data, dict):
|
||||
# Check for null vs undefined issues
|
||||
for key, value in data.items():
|
||||
if value is None and key.endswith('_id'):
|
||||
# ID fields should probably be null, not undefined
|
||||
continue
|
||||
|
||||
|
||||
# Check for numeric fields that might be strings
|
||||
if key.endswith('_count') and isinstance(value, str):
|
||||
try:
|
||||
@@ -255,16 +253,16 @@ class ContractValidationMiddleware(MiddlewareMixin):
|
||||
)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
|
||||
def _log_contract_violation(
|
||||
self,
|
||||
path: str,
|
||||
violation_type: str,
|
||||
message: str,
|
||||
self,
|
||||
path: str,
|
||||
violation_type: str,
|
||||
message: str,
|
||||
severity: str = "WARNING"
|
||||
) -> None:
|
||||
"""Log a contract violation with structured data."""
|
||||
|
||||
|
||||
log_data = {
|
||||
'contract_violation': True,
|
||||
'violation_type': violation_type,
|
||||
@@ -273,15 +271,15 @@ class ContractValidationMiddleware(MiddlewareMixin):
|
||||
'message': message,
|
||||
'suggestion': self._get_violation_suggestion(violation_type)
|
||||
}
|
||||
|
||||
|
||||
if severity == "ERROR":
|
||||
logger.error(f"CONTRACT VIOLATION [{violation_type}]: {message}", extra=log_data)
|
||||
else:
|
||||
logger.warning(f"CONTRACT VIOLATION [{violation_type}]: {message}", extra=log_data)
|
||||
|
||||
|
||||
def _get_violation_suggestion(self, violation_type: str) -> str:
|
||||
"""Get suggestion for fixing a contract violation."""
|
||||
|
||||
|
||||
suggestions = {
|
||||
"CATEGORICAL_OPTION_IS_STRING": (
|
||||
"Convert string arrays to object arrays with {value, label, count} structure. "
|
||||
@@ -308,31 +306,31 @@ class ContractValidationMiddleware(MiddlewareMixin):
|
||||
"Check serializer implementation."
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
return suggestions.get(violation_type, "Check the API response format against frontend TypeScript interfaces.")
|
||||
|
||||
|
||||
class ContractValidationSettings:
|
||||
"""Settings for contract validation middleware."""
|
||||
|
||||
|
||||
# Enable/disable specific validation checks
|
||||
VALIDATE_FILTER_METADATA = True
|
||||
VALIDATE_PAGINATION = True
|
||||
VALIDATE_HYBRID_RESPONSES = True
|
||||
VALIDATE_COMMON_PATTERNS = True
|
||||
|
||||
|
||||
# Severity levels for different violations
|
||||
CATEGORICAL_STRING_SEVERITY = "ERROR" # This is the critical issue
|
||||
MISSING_PROPERTY_SEVERITY = "WARNING"
|
||||
TYPE_MISMATCH_SEVERITY = "WARNING"
|
||||
|
||||
|
||||
# Paths to exclude from validation
|
||||
EXCLUDED_PATHS = [
|
||||
'/api/docs/',
|
||||
'/api/schema/',
|
||||
'/api/v1/auth/', # Auth endpoints might have different structures
|
||||
]
|
||||
|
||||
|
||||
@classmethod
|
||||
def should_validate_path(cls, path: str) -> bool:
|
||||
"""Check if a path should be validated."""
|
||||
|
||||
@@ -2,14 +2,16 @@
|
||||
Park history API views.
|
||||
"""
|
||||
|
||||
from rest_framework import viewsets, mixins
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.permissions import AllowAny
|
||||
from django.shortcuts import get_object_or_404
|
||||
from drf_spectacular.utils import extend_schema
|
||||
from rest_framework import viewsets
|
||||
from rest_framework.permissions import AllowAny
|
||||
from rest_framework.response import Response
|
||||
|
||||
from apps.api.v1.serializers.history import ParkHistoryOutputSerializer, RideHistoryOutputSerializer
|
||||
from apps.parks.models import Park
|
||||
from apps.rides.models import Ride
|
||||
from apps.api.v1.serializers.history import ParkHistoryOutputSerializer, RideHistoryOutputSerializer
|
||||
|
||||
|
||||
class ParkHistoryViewSet(viewsets.GenericViewSet):
|
||||
"""
|
||||
@@ -18,7 +20,7 @@ class ParkHistoryViewSet(viewsets.GenericViewSet):
|
||||
permission_classes = [AllowAny]
|
||||
lookup_field = "slug"
|
||||
lookup_url_kwarg = "park_slug"
|
||||
|
||||
|
||||
@extend_schema(
|
||||
summary="Get park history",
|
||||
description="Retrieve history events for a park.",
|
||||
@@ -27,24 +29,24 @@ class ParkHistoryViewSet(viewsets.GenericViewSet):
|
||||
)
|
||||
def list(self, request, park_slug=None):
|
||||
park = get_object_or_404(Park, slug=park_slug)
|
||||
|
||||
|
||||
events = []
|
||||
if hasattr(park, "events"):
|
||||
events = park.events.all().order_by("-pgh_created_at")
|
||||
|
||||
|
||||
summary = {
|
||||
"total_events": len(events),
|
||||
"first_recorded": events.last().pgh_created_at if len(events) else None,
|
||||
"last_modified": events.first().pgh_created_at if len(events) else None,
|
||||
}
|
||||
|
||||
|
||||
data = {
|
||||
"park": park,
|
||||
"current_state": park,
|
||||
"summary": summary,
|
||||
"events": events
|
||||
}
|
||||
|
||||
|
||||
serializer = ParkHistoryOutputSerializer(data)
|
||||
return Response(serializer.data)
|
||||
|
||||
|
||||
@@ -6,27 +6,26 @@ Provides CRUD operations for park reviews nested under parks/{slug}/reviews/
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from django.core.exceptions import PermissionDenied
|
||||
from django.db.models import Avg
|
||||
from django.utils import timezone
|
||||
from drf_spectacular.utils import extend_schema_view, extend_schema
|
||||
from drf_spectacular.types import OpenApiTypes
|
||||
from drf_spectacular.utils import extend_schema, extend_schema_view
|
||||
from rest_framework import status
|
||||
from rest_framework.decorators import action
|
||||
from rest_framework.exceptions import ValidationError, NotFound
|
||||
from rest_framework.permissions import IsAuthenticated, AllowAny
|
||||
from rest_framework.exceptions import NotFound, ValidationError
|
||||
from rest_framework.permissions import AllowAny, IsAuthenticated
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.viewsets import ModelViewSet
|
||||
|
||||
from apps.parks.models import Park, ParkReview
|
||||
from apps.api.v1.serializers.park_reviews import (
|
||||
ParkReviewOutputSerializer,
|
||||
ParkReviewCreateInputSerializer,
|
||||
ParkReviewUpdateInputSerializer,
|
||||
ParkReviewListOutputSerializer,
|
||||
ParkReviewOutputSerializer,
|
||||
ParkReviewStatsOutputSerializer,
|
||||
ParkReviewModerationInputSerializer,
|
||||
ParkReviewUpdateInputSerializer,
|
||||
)
|
||||
from apps.parks.models import Park, ParkReview
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -66,10 +65,7 @@ class ParkReviewViewSet(ModelViewSet):
|
||||
|
||||
def get_permissions(self):
|
||||
"""Set permissions based on action."""
|
||||
if self.action in ['list', 'retrieve', 'stats']:
|
||||
permission_classes = [AllowAny]
|
||||
else:
|
||||
permission_classes = [IsAuthenticated]
|
||||
permission_classes = [AllowAny] if self.action in ['list', 'retrieve', 'stats'] else [IsAuthenticated]
|
||||
return [permission() for permission in permission_classes]
|
||||
|
||||
def get_queryset(self):
|
||||
@@ -143,7 +139,7 @@ class ParkReviewViewSet(ModelViewSet):
|
||||
reviews = ParkReview.objects.filter(park=park, is_published=True)
|
||||
total_reviews = reviews.count()
|
||||
avg_rating = reviews.aggregate(avg=Avg('rating'))['avg']
|
||||
|
||||
|
||||
rating_distribution = {}
|
||||
for i in range(1, 11):
|
||||
rating_distribution[str(i)] = reviews.filter(rating=i).count()
|
||||
|
||||
@@ -6,19 +6,16 @@ This module implements endpoints for accessing rides within specific parks:
|
||||
- GET /parks/{park_slug}/rides/{ride_slug}/ - Get specific ride details within park context
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
from django.db import models
|
||||
from django.db.models import Q, Count, Avg
|
||||
from django.db.models import Q
|
||||
from django.db.models.query import QuerySet
|
||||
|
||||
from rest_framework import status, permissions
|
||||
from rest_framework.views import APIView
|
||||
from drf_spectacular.types import OpenApiTypes
|
||||
from drf_spectacular.utils import OpenApiParameter, extend_schema
|
||||
from rest_framework import permissions, status
|
||||
from rest_framework.exceptions import NotFound
|
||||
from rest_framework.pagination import PageNumberPagination
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.pagination import PageNumberPagination
|
||||
from rest_framework.exceptions import NotFound
|
||||
from drf_spectacular.utils import extend_schema, OpenApiParameter
|
||||
from drf_spectacular.types import OpenApiTypes
|
||||
from rest_framework.views import APIView
|
||||
|
||||
# Import models
|
||||
try:
|
||||
@@ -32,8 +29,8 @@ except Exception:
|
||||
|
||||
# Import serializers
|
||||
try:
|
||||
from apps.api.v1.serializers.rides import RideListOutputSerializer, RideDetailOutputSerializer
|
||||
from apps.api.v1.serializers.parks import ParkDetailOutputSerializer
|
||||
from apps.api.v1.serializers.rides import RideDetailOutputSerializer, RideListOutputSerializer
|
||||
SERIALIZERS_AVAILABLE = True
|
||||
except Exception:
|
||||
SERIALIZERS_AVAILABLE = False
|
||||
@@ -47,7 +44,7 @@ class StandardResultsSetPagination(PageNumberPagination):
|
||||
|
||||
class ParkRidesListAPIView(APIView):
|
||||
"""List rides at a specific park with pagination and filtering."""
|
||||
|
||||
|
||||
permission_classes = [permissions.AllowAny]
|
||||
|
||||
@extend_schema(
|
||||
@@ -59,7 +56,7 @@ class ParkRidesListAPIView(APIView):
|
||||
type=OpenApiTypes.INT, description="Page number"),
|
||||
OpenApiParameter(name="page_size", location=OpenApiParameter.QUERY,
|
||||
type=OpenApiTypes.INT, description="Number of results per page (max 100)"),
|
||||
|
||||
|
||||
# Filtering
|
||||
OpenApiParameter(name="category", location=OpenApiParameter.QUERY,
|
||||
type=OpenApiTypes.STR, description="Filter by ride category"),
|
||||
@@ -67,7 +64,7 @@ class ParkRidesListAPIView(APIView):
|
||||
type=OpenApiTypes.STR, description="Filter by operational status"),
|
||||
OpenApiParameter(name="search", location=OpenApiParameter.QUERY,
|
||||
type=OpenApiTypes.STR, description="Search rides by name"),
|
||||
|
||||
|
||||
# Ordering
|
||||
OpenApiParameter(name="ordering", location=OpenApiParameter.QUERY,
|
||||
type=OpenApiTypes.STR, description="Order results by field"),
|
||||
@@ -158,7 +155,7 @@ class ParkRidesListAPIView(APIView):
|
||||
|
||||
class ParkRideDetailAPIView(APIView):
|
||||
"""Get specific ride details within park context."""
|
||||
|
||||
|
||||
permission_classes = [permissions.AllowAny]
|
||||
|
||||
@extend_schema(
|
||||
@@ -222,7 +219,7 @@ class ParkRideDetailAPIView(APIView):
|
||||
|
||||
class ParkComprehensiveDetailAPIView(APIView):
|
||||
"""Get comprehensive park details including summary of rides."""
|
||||
|
||||
|
||||
permission_classes = [permissions.AllowAny]
|
||||
|
||||
@extend_schema(
|
||||
@@ -271,7 +268,7 @@ class ParkComprehensiveDetailAPIView(APIView):
|
||||
rides_serializer = RideListOutputSerializer(
|
||||
rides_sample, many=True, context={"request": request, "park": park}
|
||||
)
|
||||
|
||||
|
||||
# Enhance response with rides data
|
||||
park_data["rides_summary"] = {
|
||||
"total_count": park.ride_count or 0,
|
||||
|
||||
@@ -11,23 +11,24 @@ This module implements comprehensive park endpoints with full filtering support:
|
||||
Supports all 24 filtering parameters from frontend API documentation.
|
||||
"""
|
||||
|
||||
import contextlib
|
||||
from typing import Any
|
||||
from django.db import models
|
||||
from django.db.models import Q, Count, Avg
|
||||
from django.db.models.query import QuerySet
|
||||
|
||||
from rest_framework import status, permissions
|
||||
from rest_framework.views import APIView
|
||||
from django.db import models
|
||||
from django.db.models import Avg, Count, Q
|
||||
from django.db.models.query import QuerySet
|
||||
from drf_spectacular.types import OpenApiTypes
|
||||
from drf_spectacular.utils import OpenApiParameter, extend_schema
|
||||
from rest_framework import permissions, status
|
||||
from rest_framework.exceptions import NotFound
|
||||
from rest_framework.pagination import PageNumberPagination
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.pagination import PageNumberPagination
|
||||
from rest_framework.exceptions import NotFound
|
||||
from drf_spectacular.utils import extend_schema, OpenApiParameter
|
||||
from drf_spectacular.types import OpenApiTypes
|
||||
from rest_framework.views import APIView
|
||||
|
||||
# Import models
|
||||
try:
|
||||
from apps.parks.models import Park, Company
|
||||
from apps.parks.models import Company, Park
|
||||
MODELS_AVAILABLE = True
|
||||
except Exception:
|
||||
Park = None # type: ignore
|
||||
@@ -45,11 +46,11 @@ except Exception:
|
||||
# Import serializers
|
||||
try:
|
||||
from apps.api.v1.serializers.parks import (
|
||||
ParkListOutputSerializer,
|
||||
ParkDetailOutputSerializer,
|
||||
ParkCreateInputSerializer,
|
||||
ParkUpdateInputSerializer,
|
||||
ParkDetailOutputSerializer,
|
||||
ParkImageSettingsInputSerializer,
|
||||
ParkListOutputSerializer,
|
||||
ParkUpdateInputSerializer,
|
||||
)
|
||||
SERIALIZERS_AVAILABLE = True
|
||||
except Exception:
|
||||
@@ -247,12 +248,12 @@ class ParkListCreateAPIView(APIView):
|
||||
'city': 'location__city__iexact',
|
||||
'continent': 'location__continent__iexact'
|
||||
}
|
||||
|
||||
|
||||
for param_name, filter_field in location_filters.items():
|
||||
value = params.get(param_name)
|
||||
if value:
|
||||
qs = qs.filter(**{filter_field: value})
|
||||
|
||||
|
||||
return qs
|
||||
|
||||
def _apply_park_attribute_filters(self, qs: QuerySet, params: dict) -> QuerySet:
|
||||
@@ -264,7 +265,7 @@ class ParkListCreateAPIView(APIView):
|
||||
status_filter = params.get("status")
|
||||
if status_filter:
|
||||
qs = qs.filter(status=status_filter)
|
||||
|
||||
|
||||
return qs
|
||||
|
||||
def _apply_company_filters(self, qs: QuerySet, params: dict) -> QuerySet:
|
||||
@@ -275,73 +276,59 @@ class ParkListCreateAPIView(APIView):
|
||||
'property_owner_id': 'property_owner_id',
|
||||
'property_owner_slug': 'property_owner__slug'
|
||||
}
|
||||
|
||||
|
||||
for param_name, filter_field in company_filters.items():
|
||||
value = params.get(param_name)
|
||||
if value:
|
||||
qs = qs.filter(**{filter_field: value})
|
||||
|
||||
|
||||
return qs
|
||||
|
||||
def _apply_rating_filters(self, qs: QuerySet, params: dict) -> QuerySet:
|
||||
"""Apply rating-based filtering to the queryset."""
|
||||
min_rating = params.get("min_rating")
|
||||
if min_rating:
|
||||
try:
|
||||
with contextlib.suppress(ValueError, TypeError):
|
||||
qs = qs.filter(average_rating__gte=float(min_rating))
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
max_rating = params.get("max_rating")
|
||||
if max_rating:
|
||||
try:
|
||||
with contextlib.suppress(ValueError, TypeError):
|
||||
qs = qs.filter(average_rating__lte=float(max_rating))
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
|
||||
return qs
|
||||
|
||||
def _apply_ride_count_filters(self, qs: QuerySet, params: dict) -> QuerySet:
|
||||
"""Apply ride count filtering to the queryset."""
|
||||
min_ride_count = params.get("min_ride_count")
|
||||
if min_ride_count:
|
||||
try:
|
||||
with contextlib.suppress(ValueError, TypeError):
|
||||
qs = qs.filter(ride_count__gte=int(min_ride_count))
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
max_ride_count = params.get("max_ride_count")
|
||||
if max_ride_count:
|
||||
try:
|
||||
with contextlib.suppress(ValueError, TypeError):
|
||||
qs = qs.filter(ride_count__lte=int(max_ride_count))
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
|
||||
return qs
|
||||
|
||||
def _apply_opening_year_filters(self, qs: QuerySet, params: dict) -> QuerySet:
|
||||
"""Apply opening year filtering to the queryset."""
|
||||
opening_year = params.get("opening_year")
|
||||
if opening_year:
|
||||
try:
|
||||
with contextlib.suppress(ValueError, TypeError):
|
||||
qs = qs.filter(opening_date__year=int(opening_year))
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
min_opening_year = params.get("min_opening_year")
|
||||
if min_opening_year:
|
||||
try:
|
||||
with contextlib.suppress(ValueError, TypeError):
|
||||
qs = qs.filter(opening_date__year__gte=int(min_opening_year))
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
max_opening_year = params.get("max_opening_year")
|
||||
if max_opening_year:
|
||||
try:
|
||||
with contextlib.suppress(ValueError, TypeError):
|
||||
qs = qs.filter(opening_date__year__lte=int(max_opening_year))
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
|
||||
return qs
|
||||
|
||||
def _apply_roller_coaster_filters(self, qs: QuerySet, params: dict) -> QuerySet:
|
||||
@@ -355,18 +342,14 @@ class ParkListCreateAPIView(APIView):
|
||||
|
||||
min_roller_coaster_count = params.get("min_roller_coaster_count")
|
||||
if min_roller_coaster_count:
|
||||
try:
|
||||
with contextlib.suppress(ValueError, TypeError):
|
||||
qs = qs.filter(coaster_count__gte=int(min_roller_coaster_count))
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
max_roller_coaster_count = params.get("max_roller_coaster_count")
|
||||
if max_roller_coaster_count:
|
||||
try:
|
||||
with contextlib.suppress(ValueError, TypeError):
|
||||
qs = qs.filter(coaster_count__lte=int(max_roller_coaster_count))
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
|
||||
return qs
|
||||
|
||||
@extend_schema(
|
||||
@@ -440,13 +423,13 @@ class ParkDetailAPIView(APIView):
|
||||
def _get_park_or_404(self, identifier: str) -> Any:
|
||||
if not MODELS_AVAILABLE:
|
||||
raise NotFound(
|
||||
(
|
||||
|
||||
"Park detail is not available because domain models "
|
||||
"are not imported. Implement apps.parks.models.Park "
|
||||
"to enable detail endpoints."
|
||||
)
|
||||
|
||||
)
|
||||
|
||||
|
||||
# Try to parse as integer ID first
|
||||
try:
|
||||
pk = int(identifier)
|
||||
@@ -475,36 +458,36 @@ class ParkDetailAPIView(APIView):
|
||||
summary="Get park full details",
|
||||
description="""
|
||||
Retrieve comprehensive park details including:
|
||||
|
||||
|
||||
**Core Information:**
|
||||
- Basic park details (name, slug, description, status)
|
||||
- Opening/closing dates and operating season
|
||||
- Size in acres and website URL
|
||||
- Statistics (average rating, ride count, coaster count)
|
||||
|
||||
|
||||
**Location Data:**
|
||||
- Full address with coordinates
|
||||
- City, state, country information
|
||||
- Formatted address string
|
||||
|
||||
|
||||
**Company Information:**
|
||||
- Operating company details
|
||||
- Property owner information (if different)
|
||||
|
||||
|
||||
**Media:**
|
||||
- All approved photos with Cloudflare variants
|
||||
- Primary photo designation
|
||||
- Banner and card image settings
|
||||
|
||||
|
||||
**Related Content:**
|
||||
- Park areas/themed sections
|
||||
- Associated rides (summary)
|
||||
|
||||
|
||||
**Lookup Methods:**
|
||||
- By ID: `/api/v1/parks/123/`
|
||||
- By current slug: `/api/v1/parks/cedar-point/`
|
||||
- By historical slug: `/api/v1/parks/old-cedar-point-name/`
|
||||
|
||||
|
||||
**No Query Parameters Required** - This endpoint returns full details by default.
|
||||
""",
|
||||
responses={
|
||||
@@ -598,11 +581,11 @@ class FilterOptionsAPIView(APIView):
|
||||
"""Return comprehensive filter options with Rich Choice Objects metadata."""
|
||||
# Import Rich Choice registry
|
||||
from apps.core.choices.registry import get_choices
|
||||
|
||||
|
||||
# Always get static choice definitions from Rich Choice Objects (primary source)
|
||||
park_types = get_choices('types', 'parks')
|
||||
statuses = get_choices('statuses', 'parks')
|
||||
|
||||
|
||||
# Convert Rich Choice Objects to frontend format with metadata
|
||||
park_types_data = [
|
||||
{
|
||||
@@ -616,7 +599,7 @@ class FilterOptionsAPIView(APIView):
|
||||
}
|
||||
for choice in park_types
|
||||
]
|
||||
|
||||
|
||||
statuses_data = [
|
||||
{
|
||||
"value": choice.value,
|
||||
@@ -629,12 +612,12 @@ class FilterOptionsAPIView(APIView):
|
||||
}
|
||||
for choice in statuses
|
||||
]
|
||||
|
||||
|
||||
# Get dynamic data from database if models are available
|
||||
if MODELS_AVAILABLE:
|
||||
# Add any dynamic data queries here
|
||||
pass
|
||||
|
||||
|
||||
return Response({
|
||||
"park_types": park_types_data,
|
||||
"statuses": statuses_data,
|
||||
@@ -707,7 +690,7 @@ class FilterOptionsAPIView(APIView):
|
||||
# Get rich choice objects from registry
|
||||
park_types = get_choices('types', 'parks')
|
||||
statuses = get_choices('statuses', 'parks')
|
||||
|
||||
|
||||
# Convert Rich Choice Objects to frontend format with metadata
|
||||
park_types_data = [
|
||||
{
|
||||
@@ -721,7 +704,7 @@ class FilterOptionsAPIView(APIView):
|
||||
}
|
||||
for choice in park_types
|
||||
]
|
||||
|
||||
|
||||
statuses_data = [
|
||||
{
|
||||
"value": choice.value,
|
||||
@@ -1118,7 +1101,7 @@ class OperatorListAPIView(APIView):
|
||||
}
|
||||
for op in operators
|
||||
]
|
||||
|
||||
|
||||
return Response({
|
||||
"results": data,
|
||||
"count": len(data)
|
||||
|
||||
@@ -13,27 +13,27 @@ if TYPE_CHECKING:
|
||||
|
||||
from django.core.exceptions import PermissionDenied
|
||||
from django.utils import timezone
|
||||
from drf_spectacular.utils import extend_schema_view, extend_schema
|
||||
from drf_spectacular.types import OpenApiTypes
|
||||
from drf_spectacular.utils import extend_schema, extend_schema_view
|
||||
from rest_framework import status
|
||||
from rest_framework.decorators import action
|
||||
from rest_framework.exceptions import ValidationError, NotFound
|
||||
from rest_framework.permissions import IsAuthenticated, AllowAny
|
||||
from rest_framework.exceptions import NotFound, ValidationError
|
||||
from rest_framework.permissions import AllowAny, IsAuthenticated
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.viewsets import ModelViewSet
|
||||
|
||||
from apps.rides.models.media import RidePhoto
|
||||
from apps.rides.models import Ride
|
||||
from apps.parks.models import Park
|
||||
from apps.rides.services.media_service import RideMediaService
|
||||
from apps.api.v1.rides.serializers import (
|
||||
RidePhotoOutputSerializer,
|
||||
RidePhotoCreateInputSerializer,
|
||||
RidePhotoUpdateInputSerializer,
|
||||
RidePhotoListOutputSerializer,
|
||||
RidePhotoApprovalInputSerializer,
|
||||
RidePhotoCreateInputSerializer,
|
||||
RidePhotoListOutputSerializer,
|
||||
RidePhotoOutputSerializer,
|
||||
RidePhotoStatsOutputSerializer,
|
||||
RidePhotoUpdateInputSerializer,
|
||||
)
|
||||
from apps.parks.models import Park
|
||||
from apps.rides.models import Ride
|
||||
from apps.rides.models.media import RidePhoto
|
||||
from apps.rides.services.media_service import RideMediaService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -116,10 +116,7 @@ class RidePhotoViewSet(ModelViewSet):
|
||||
|
||||
def get_permissions(self):
|
||||
"""Set permissions based on action."""
|
||||
if self.action in ['list', 'retrieve', 'stats']:
|
||||
permission_classes = [AllowAny]
|
||||
else:
|
||||
permission_classes = [IsAuthenticated]
|
||||
permission_classes = [AllowAny] if self.action in ['list', 'retrieve', 'stats'] else [IsAuthenticated]
|
||||
return [permission() for permission in permission_classes]
|
||||
|
||||
def get_queryset(self):
|
||||
@@ -131,7 +128,7 @@ class RidePhotoViewSet(ModelViewSet):
|
||||
# Filter by park and ride from URL kwargs
|
||||
park_slug = self.kwargs.get("park_slug")
|
||||
ride_slug = self.kwargs.get("ride_slug")
|
||||
|
||||
|
||||
if park_slug and ride_slug:
|
||||
try:
|
||||
park, _ = Park.get_by_slug(park_slug)
|
||||
@@ -158,7 +155,7 @@ class RidePhotoViewSet(ModelViewSet):
|
||||
"""Create a new ride photo using RideMediaService."""
|
||||
park_slug = self.kwargs.get("park_slug")
|
||||
ride_slug = self.kwargs.get("ride_slug")
|
||||
|
||||
|
||||
if not park_slug or not ride_slug:
|
||||
raise ValidationError("Park and ride slugs are required")
|
||||
|
||||
@@ -185,7 +182,7 @@ class RidePhotoViewSet(ModelViewSet):
|
||||
|
||||
# Set the instance for the serializer response
|
||||
serializer.instance = photo
|
||||
|
||||
|
||||
logger.info(f"Created ride photo {photo.id} for ride {ride.name} by user {self.request.user.username}")
|
||||
|
||||
except Exception as e:
|
||||
@@ -249,7 +246,7 @@ class RidePhotoViewSet(ModelViewSet):
|
||||
RideMediaService.delete_photo(
|
||||
instance, deleted_by=self.request.user
|
||||
)
|
||||
|
||||
|
||||
logger.info(f"Deleted ride photo {instance.id} by user {self.request.user.username}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting ride photo: {e}")
|
||||
@@ -331,7 +328,7 @@ class RidePhotoViewSet(ModelViewSet):
|
||||
validated_data = getattr(serializer, "validated_data", {})
|
||||
photo_ids = validated_data.get("photo_ids")
|
||||
approve = validated_data.get("approve")
|
||||
|
||||
|
||||
park_slug = self.kwargs.get("park_slug")
|
||||
ride_slug = self.kwargs.get("ride_slug")
|
||||
|
||||
@@ -381,7 +378,7 @@ class RidePhotoViewSet(ModelViewSet):
|
||||
"""Get photo statistics for the ride."""
|
||||
park_slug = self.kwargs.get("park_slug")
|
||||
ride_slug = self.kwargs.get("ride_slug")
|
||||
|
||||
|
||||
if not park_slug or not ride_slug:
|
||||
return Response(
|
||||
{"error": "Park and ride slugs are required"},
|
||||
@@ -431,7 +428,7 @@ class RidePhotoViewSet(ModelViewSet):
|
||||
"""Save a Cloudflare image as a ride photo after direct upload to Cloudflare."""
|
||||
park_slug = self.kwargs.get("park_slug")
|
||||
ride_slug = self.kwargs.get("ride_slug")
|
||||
|
||||
|
||||
if not park_slug or not ride_slug:
|
||||
return Response(
|
||||
{"error": "Park and ride slugs are required"},
|
||||
|
||||
@@ -12,28 +12,28 @@ if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
from django.core.exceptions import PermissionDenied
|
||||
from django.db.models import Avg, Count, Q
|
||||
from django.db.models import Avg
|
||||
from django.utils import timezone
|
||||
from drf_spectacular.utils import extend_schema_view, extend_schema
|
||||
from drf_spectacular.types import OpenApiTypes
|
||||
from drf_spectacular.utils import extend_schema, extend_schema_view
|
||||
from rest_framework import status
|
||||
from rest_framework.decorators import action
|
||||
from rest_framework.exceptions import ValidationError, NotFound
|
||||
from rest_framework.permissions import IsAuthenticated, AllowAny
|
||||
from rest_framework.exceptions import NotFound, ValidationError
|
||||
from rest_framework.permissions import AllowAny, IsAuthenticated
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.viewsets import ModelViewSet
|
||||
|
||||
from apps.rides.models.reviews import RideReview
|
||||
from apps.rides.models import Ride
|
||||
from apps.parks.models import Park
|
||||
from apps.api.v1.serializers.ride_reviews import (
|
||||
RideReviewOutputSerializer,
|
||||
RideReviewCreateInputSerializer,
|
||||
RideReviewUpdateInputSerializer,
|
||||
RideReviewListOutputSerializer,
|
||||
RideReviewStatsOutputSerializer,
|
||||
RideReviewModerationInputSerializer,
|
||||
RideReviewOutputSerializer,
|
||||
RideReviewStatsOutputSerializer,
|
||||
RideReviewUpdateInputSerializer,
|
||||
)
|
||||
from apps.parks.models import Park
|
||||
from apps.rides.models import Ride
|
||||
from apps.rides.models.reviews import RideReview
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -115,10 +115,7 @@ class RideReviewViewSet(ModelViewSet):
|
||||
|
||||
def get_permissions(self):
|
||||
"""Set permissions based on action."""
|
||||
if self.action in ['list', 'retrieve', 'stats']:
|
||||
permission_classes = [AllowAny]
|
||||
else:
|
||||
permission_classes = [IsAuthenticated]
|
||||
permission_classes = [AllowAny] if self.action in ['list', 'retrieve', 'stats'] else [IsAuthenticated]
|
||||
return [permission() for permission in permission_classes]
|
||||
|
||||
def get_queryset(self):
|
||||
@@ -130,7 +127,7 @@ class RideReviewViewSet(ModelViewSet):
|
||||
# Filter by park and ride from URL kwargs
|
||||
park_slug = self.kwargs.get("park_slug")
|
||||
ride_slug = self.kwargs.get("ride_slug")
|
||||
|
||||
|
||||
if park_slug and ride_slug:
|
||||
try:
|
||||
park, _ = Park.get_by_slug(park_slug)
|
||||
@@ -141,7 +138,7 @@ class RideReviewViewSet(ModelViewSet):
|
||||
return queryset.none()
|
||||
|
||||
# Filter published reviews for non-staff users
|
||||
if not (hasattr(self.request, 'user') and
|
||||
if not (hasattr(self.request, 'user') and
|
||||
getattr(self.request.user, 'is_staff', False)):
|
||||
queryset = queryset.filter(is_published=True)
|
||||
|
||||
@@ -162,7 +159,7 @@ class RideReviewViewSet(ModelViewSet):
|
||||
"""Create a new ride review."""
|
||||
park_slug = self.kwargs.get("park_slug")
|
||||
ride_slug = self.kwargs.get("ride_slug")
|
||||
|
||||
|
||||
if not park_slug or not ride_slug:
|
||||
raise ValidationError("Park and ride slugs are required")
|
||||
|
||||
@@ -185,7 +182,7 @@ class RideReviewViewSet(ModelViewSet):
|
||||
user=self.request.user,
|
||||
is_published=True # Auto-publish for now, can add moderation later
|
||||
)
|
||||
|
||||
|
||||
logger.info(f"Created ride review {review.id} for ride {ride.name} by user {self.request.user.username}")
|
||||
|
||||
except Exception as e:
|
||||
@@ -241,7 +238,7 @@ class RideReviewViewSet(ModelViewSet):
|
||||
"""Get review statistics for the ride."""
|
||||
park_slug = self.kwargs.get("park_slug")
|
||||
ride_slug = self.kwargs.get("ride_slug")
|
||||
|
||||
|
||||
if not park_slug or not ride_slug:
|
||||
return Response(
|
||||
{"error": "Park and ride slugs are required"},
|
||||
@@ -265,19 +262,19 @@ class RideReviewViewSet(ModelViewSet):
|
||||
try:
|
||||
# Get review statistics
|
||||
reviews = RideReview.objects.filter(ride=ride, is_published=True)
|
||||
|
||||
|
||||
total_reviews = reviews.count()
|
||||
published_reviews = total_reviews # Since we're filtering published
|
||||
pending_reviews = RideReview.objects.filter(ride=ride, is_published=False).count()
|
||||
|
||||
|
||||
# Calculate average rating
|
||||
avg_rating = reviews.aggregate(avg_rating=Avg('rating'))['avg_rating']
|
||||
|
||||
|
||||
# Get rating distribution
|
||||
rating_distribution = {}
|
||||
for i in range(1, 11):
|
||||
rating_distribution[str(i)] = reviews.filter(rating=i).count()
|
||||
|
||||
|
||||
# Get recent reviews count (last 30 days)
|
||||
from datetime import timedelta
|
||||
thirty_days_ago = timezone.now() - timedelta(days=30)
|
||||
|
||||
@@ -5,12 +5,13 @@ This module contains serializers for park-specific media functionality.
|
||||
Enhanced from rogue implementation to maintain full feature parity.
|
||||
"""
|
||||
|
||||
from rest_framework import serializers
|
||||
from drf_spectacular.utils import (
|
||||
OpenApiExample,
|
||||
extend_schema_field,
|
||||
extend_schema_serializer,
|
||||
OpenApiExample,
|
||||
)
|
||||
from rest_framework import serializers
|
||||
|
||||
from apps.parks.models import Park, ParkPhoto
|
||||
|
||||
|
||||
@@ -235,7 +236,7 @@ class HybridParkSerializer(serializers.ModelSerializer):
|
||||
Enhanced serializer for hybrid filtering strategy.
|
||||
Includes all filterable fields for client-side filtering.
|
||||
"""
|
||||
|
||||
|
||||
# Location fields from related ParkLocation
|
||||
city = serializers.SerializerMethodField()
|
||||
state = serializers.SerializerMethodField()
|
||||
@@ -243,19 +244,19 @@ class HybridParkSerializer(serializers.ModelSerializer):
|
||||
continent = serializers.SerializerMethodField()
|
||||
latitude = serializers.SerializerMethodField()
|
||||
longitude = serializers.SerializerMethodField()
|
||||
|
||||
|
||||
# Company fields
|
||||
operator_name = serializers.CharField(source="operator.name", read_only=True)
|
||||
property_owner_name = serializers.CharField(source="property_owner.name", read_only=True, allow_null=True)
|
||||
|
||||
|
||||
# Image URLs for display
|
||||
banner_image_url = serializers.SerializerMethodField()
|
||||
card_image_url = serializers.SerializerMethodField()
|
||||
|
||||
|
||||
# Computed fields for filtering
|
||||
opening_year = serializers.IntegerField(read_only=True)
|
||||
search_text = serializers.CharField(read_only=True)
|
||||
|
||||
|
||||
@extend_schema_field(serializers.CharField(allow_null=True))
|
||||
def get_city(self, obj):
|
||||
"""Get city from related location."""
|
||||
@@ -263,7 +264,7 @@ class HybridParkSerializer(serializers.ModelSerializer):
|
||||
return obj.location.city if hasattr(obj, 'location') and obj.location else None
|
||||
except AttributeError:
|
||||
return None
|
||||
|
||||
|
||||
@extend_schema_field(serializers.CharField(allow_null=True))
|
||||
def get_state(self, obj):
|
||||
"""Get state from related location."""
|
||||
@@ -271,7 +272,7 @@ class HybridParkSerializer(serializers.ModelSerializer):
|
||||
return obj.location.state if hasattr(obj, 'location') and obj.location else None
|
||||
except AttributeError:
|
||||
return None
|
||||
|
||||
|
||||
@extend_schema_field(serializers.CharField(allow_null=True))
|
||||
def get_country(self, obj):
|
||||
"""Get country from related location."""
|
||||
@@ -279,7 +280,7 @@ class HybridParkSerializer(serializers.ModelSerializer):
|
||||
return obj.location.country if hasattr(obj, 'location') and obj.location else None
|
||||
except AttributeError:
|
||||
return None
|
||||
|
||||
|
||||
@extend_schema_field(serializers.CharField(allow_null=True))
|
||||
def get_continent(self, obj):
|
||||
"""Get continent from related location."""
|
||||
@@ -287,7 +288,7 @@ class HybridParkSerializer(serializers.ModelSerializer):
|
||||
return obj.location.continent if hasattr(obj, 'location') and obj.location else None
|
||||
except AttributeError:
|
||||
return None
|
||||
|
||||
|
||||
@extend_schema_field(serializers.FloatField(allow_null=True))
|
||||
def get_latitude(self, obj):
|
||||
"""Get latitude from related location."""
|
||||
@@ -297,7 +298,7 @@ class HybridParkSerializer(serializers.ModelSerializer):
|
||||
return None
|
||||
except (AttributeError, IndexError, TypeError):
|
||||
return None
|
||||
|
||||
|
||||
@extend_schema_field(serializers.FloatField(allow_null=True))
|
||||
def get_longitude(self, obj):
|
||||
"""Get longitude from related location."""
|
||||
@@ -307,14 +308,14 @@ class HybridParkSerializer(serializers.ModelSerializer):
|
||||
return None
|
||||
except (AttributeError, IndexError, TypeError):
|
||||
return None
|
||||
|
||||
|
||||
@extend_schema_field(serializers.URLField(allow_null=True))
|
||||
def get_banner_image_url(self, obj):
|
||||
"""Get banner image URL."""
|
||||
if obj.banner_image and obj.banner_image.image:
|
||||
return obj.banner_image.image.url
|
||||
return None
|
||||
|
||||
|
||||
@extend_schema_field(serializers.URLField(allow_null=True))
|
||||
def get_card_image_url(self, obj):
|
||||
"""Get card image URL."""
|
||||
@@ -332,42 +333,42 @@ class HybridParkSerializer(serializers.ModelSerializer):
|
||||
"description",
|
||||
"status",
|
||||
"park_type",
|
||||
|
||||
|
||||
# Dates and computed fields
|
||||
"opening_date",
|
||||
"closing_date",
|
||||
"opening_year",
|
||||
"operating_season",
|
||||
|
||||
|
||||
# Location fields
|
||||
"city",
|
||||
"state",
|
||||
"state",
|
||||
"country",
|
||||
"continent",
|
||||
"latitude",
|
||||
"longitude",
|
||||
|
||||
|
||||
# Company relationships
|
||||
"operator_name",
|
||||
"property_owner_name",
|
||||
|
||||
|
||||
# Statistics
|
||||
"size_acres",
|
||||
"average_rating",
|
||||
"ride_count",
|
||||
"coaster_count",
|
||||
|
||||
|
||||
# Images
|
||||
"banner_image_url",
|
||||
"card_image_url",
|
||||
|
||||
|
||||
# URLs
|
||||
"website",
|
||||
"url",
|
||||
|
||||
|
||||
# Computed fields for filtering
|
||||
"search_text",
|
||||
|
||||
|
||||
# Metadata
|
||||
"created_at",
|
||||
"updated_at",
|
||||
|
||||
@@ -6,28 +6,10 @@ intentionally expansive to match the rides API functionality and provide
|
||||
complete feature parity for parks management.
|
||||
"""
|
||||
|
||||
from django.urls import path, include
|
||||
from django.urls import include, path
|
||||
from rest_framework.routers import DefaultRouter
|
||||
|
||||
from .park_views import (
|
||||
ParkListCreateAPIView,
|
||||
ParkDetailAPIView,
|
||||
FilterOptionsAPIView,
|
||||
CompanySearchAPIView,
|
||||
ParkSearchSuggestionsAPIView,
|
||||
ParkImageSettingsAPIView,
|
||||
OperatorListAPIView,
|
||||
)
|
||||
from .park_rides_views import (
|
||||
ParkRidesListAPIView,
|
||||
ParkRideDetailAPIView,
|
||||
ParkComprehensiveDetailAPIView,
|
||||
)
|
||||
from apps.parks.views import location_search, reverse_geocode
|
||||
from .views import ParkPhotoViewSet, HybridParkAPIView, ParkFilterMetadataAPIView
|
||||
from .ride_photos_views import RidePhotoViewSet
|
||||
from .ride_photos_views import RidePhotoViewSet
|
||||
from .ride_reviews_views import RideReviewViewSet
|
||||
from apps.parks.views_roadtrip import (
|
||||
CreateTripView,
|
||||
FindParksAlongRouteView,
|
||||
@@ -35,6 +17,24 @@ from apps.parks.views_roadtrip import (
|
||||
ParkDistanceCalculatorView,
|
||||
)
|
||||
|
||||
from .park_rides_views import (
|
||||
ParkComprehensiveDetailAPIView,
|
||||
ParkRideDetailAPIView,
|
||||
ParkRidesListAPIView,
|
||||
)
|
||||
from .park_views import (
|
||||
CompanySearchAPIView,
|
||||
FilterOptionsAPIView,
|
||||
OperatorListAPIView,
|
||||
ParkDetailAPIView,
|
||||
ParkImageSettingsAPIView,
|
||||
ParkListCreateAPIView,
|
||||
ParkSearchSuggestionsAPIView,
|
||||
)
|
||||
from .ride_photos_views import RidePhotoViewSet
|
||||
from .ride_reviews_views import RideReviewViewSet
|
||||
from .views import HybridParkAPIView, ParkFilterMetadataAPIView, ParkPhotoViewSet
|
||||
|
||||
# Create router for nested photo endpoints
|
||||
router = DefaultRouter()
|
||||
router.register(r"", ParkPhotoViewSet, basename="park-photo")
|
||||
@@ -42,13 +42,12 @@ router.register(r"", ParkPhotoViewSet, basename="park-photo")
|
||||
# Create routers for nested ride endpoints
|
||||
ride_photos_router = DefaultRouter()
|
||||
ride_photos_router.register(r"", RidePhotoViewSet, basename="ride-photo")
|
||||
from .ride_reviews_views import RideReviewViewSet
|
||||
|
||||
ride_reviews_router = DefaultRouter()
|
||||
ride_reviews_router.register(r"", RideReviewViewSet, basename="ride-review")
|
||||
|
||||
from .park_reviews_views import ParkReviewViewSet
|
||||
from .history_views import ParkHistoryViewSet, RideHistoryViewSet
|
||||
from .park_reviews_views import ParkReviewViewSet
|
||||
|
||||
# Create routers for nested park endpoints
|
||||
reviews_router = DefaultRouter()
|
||||
@@ -60,11 +59,11 @@ app_name = "api_v1_parks"
|
||||
urlpatterns = [
|
||||
# Core list/create endpoints
|
||||
path("", ParkListCreateAPIView.as_view(), name="park-list-create"),
|
||||
|
||||
|
||||
# Hybrid filtering endpoints
|
||||
path("hybrid/", HybridParkAPIView.as_view(), name="park-hybrid-list"),
|
||||
path("hybrid/filter-metadata/", ParkFilterMetadataAPIView.as_view(), name="park-hybrid-filter-metadata"),
|
||||
|
||||
|
||||
# Filter options
|
||||
path("filter-options/", FilterOptionsAPIView.as_view(), name="park-filter-options"),
|
||||
# Autocomplete / suggestion endpoints
|
||||
@@ -80,14 +79,14 @@ urlpatterns = [
|
||||
),
|
||||
# Detail and action endpoints - supports both ID and slug
|
||||
path("<str:pk>/", ParkDetailAPIView.as_view(), name="park-detail"),
|
||||
|
||||
|
||||
# Park rides endpoints
|
||||
path("<str:park_slug>/rides/", ParkRidesListAPIView.as_view(), name="park-rides-list"),
|
||||
path("<str:park_slug>/rides/<str:ride_slug>/", ParkRideDetailAPIView.as_view(), name="park-ride-detail"),
|
||||
|
||||
|
||||
# Comprehensive park detail endpoint with rides summary
|
||||
path("<str:park_slug>/detail/", ParkComprehensiveDetailAPIView.as_view(), name="park-comprehensive-detail"),
|
||||
|
||||
|
||||
# Park image settings endpoint
|
||||
path(
|
||||
"<int:pk>/image-settings/",
|
||||
@@ -96,21 +95,21 @@ urlpatterns = [
|
||||
),
|
||||
# Park photo endpoints - domain-specific photo management
|
||||
path("<str:park_pk>/photos/", include(router.urls)),
|
||||
|
||||
|
||||
# Nested ride photo endpoints - photos for specific rides within parks
|
||||
path("<str:park_slug>/rides/<str:ride_slug>/photos/", include(ride_photos_router.urls)),
|
||||
|
||||
|
||||
# Nested ride review endpoints - reviews for specific rides within parks
|
||||
path("<str:park_slug>/rides/<str:ride_slug>/reviews/", include(ride_reviews_router.urls)),
|
||||
# Nested ride review endpoints - reviews for specific rides within parks
|
||||
path("<str:park_slug>/rides/<str:ride_slug>/reviews/", include(ride_reviews_router.urls)),
|
||||
|
||||
|
||||
# Ride History
|
||||
path("<str:park_slug>/rides/<str:ride_slug>/history/", RideHistoryViewSet.as_view({'get': 'list'}), name="ride-history"),
|
||||
|
||||
# Park Reviews
|
||||
path("<str:park_slug>/reviews/", include(reviews_router.urls)),
|
||||
|
||||
|
||||
# Park History
|
||||
path("<str:park_slug>/history/", ParkHistoryViewSet.as_view({'get': 'list'}), name="park-history"),
|
||||
|
||||
|
||||
@@ -26,14 +26,13 @@ from rest_framework.response import Response
|
||||
from rest_framework.views import APIView
|
||||
from rest_framework.viewsets import ModelViewSet
|
||||
|
||||
from apps.core.decorators.cache_decorators import cache_api_response
|
||||
from apps.core.exceptions import (
|
||||
NotFoundError,
|
||||
PermissionDeniedError,
|
||||
ServiceError,
|
||||
ValidationException,
|
||||
)
|
||||
from apps.core.utils.error_handling import ErrorHandler
|
||||
from apps.core.decorators.cache_decorators import cache_api_response
|
||||
from apps.parks.models import Park, ParkPhoto
|
||||
from apps.parks.services import ParkMediaService
|
||||
from apps.parks.services.hybrid_loader import smart_park_loader
|
||||
@@ -130,10 +129,7 @@ class ParkPhotoViewSet(ModelViewSet):
|
||||
|
||||
def get_permissions(self):
|
||||
"""Set permissions based on action."""
|
||||
if self.action in ["list", "retrieve", "stats"]:
|
||||
permission_classes = [AllowAny]
|
||||
else:
|
||||
permission_classes = [IsAuthenticated]
|
||||
permission_classes = [AllowAny] if self.action in ["list", "retrieve", "stats"] else [IsAuthenticated]
|
||||
return [permission() for permission in permission_classes]
|
||||
|
||||
def get_queryset(self): # type: ignore[override]
|
||||
@@ -171,11 +167,8 @@ class ParkPhotoViewSet(ModelViewSet):
|
||||
raise ValidationError("Park ID/Slug is required")
|
||||
|
||||
try:
|
||||
if str(park_id).isdigit():
|
||||
park = Park.objects.get(pk=park_id)
|
||||
else:
|
||||
park = Park.objects.get(slug=park_id)
|
||||
|
||||
park = Park.objects.get(pk=park_id) if str(park_id).isdigit() else Park.objects.get(slug=park_id)
|
||||
|
||||
# Use real park ID
|
||||
park_id = park.id
|
||||
except Park.DoesNotExist:
|
||||
@@ -398,10 +391,7 @@ class ParkPhotoViewSet(ModelViewSet):
|
||||
park = None
|
||||
if park_pk:
|
||||
try:
|
||||
if str(park_pk).isdigit():
|
||||
park = Park.objects.get(pk=park_pk)
|
||||
else:
|
||||
park = Park.objects.get(slug=park_pk)
|
||||
park = Park.objects.get(pk=park_pk) if str(park_pk).isdigit() else Park.objects.get(slug=park_pk)
|
||||
except Park.DoesNotExist:
|
||||
return ErrorHandler.handle_api_error(
|
||||
NotFoundError(f"Park with id/slug {park_pk} not found"),
|
||||
@@ -490,10 +480,7 @@ class ParkPhotoViewSet(ModelViewSet):
|
||||
)
|
||||
|
||||
try:
|
||||
if str(park_pk).isdigit():
|
||||
park = Park.objects.get(pk=park_pk)
|
||||
else:
|
||||
park = Park.objects.get(slug=park_pk)
|
||||
park = Park.objects.get(pk=park_pk) if str(park_pk).isdigit() else Park.objects.get(slug=park_pk)
|
||||
except Park.DoesNotExist:
|
||||
return Response(
|
||||
{"error": "Park not found"},
|
||||
@@ -509,9 +496,9 @@ class ParkPhotoViewSet(ModelViewSet):
|
||||
|
||||
try:
|
||||
# Import CloudflareImage model and service
|
||||
from django.utils import timezone
|
||||
from django_cloudflareimages_toolkit.models import CloudflareImage
|
||||
from django_cloudflareimages_toolkit.services import CloudflareImagesService
|
||||
from django.utils import timezone
|
||||
|
||||
# Always fetch the latest image data from Cloudflare API
|
||||
# Get image details from Cloudflare API
|
||||
|
||||
12
backend/apps/api/v1/rides/company_urls.py
Normal file
12
backend/apps/api/v1/rides/company_urls.py
Normal file
@@ -0,0 +1,12 @@
|
||||
"""URL routes for Company CRUD API."""
|
||||
|
||||
from django.urls import path
|
||||
|
||||
from .company_views import CompanyDetailAPIView, CompanyListCreateAPIView
|
||||
|
||||
app_name = "api_v1_companies"
|
||||
|
||||
urlpatterns = [
|
||||
path("", CompanyListCreateAPIView.as_view(), name="company-list-create"),
|
||||
path("<int:pk>/", CompanyDetailAPIView.as_view(), name="company-detail"),
|
||||
]
|
||||
167
backend/apps/api/v1/rides/company_views.py
Normal file
167
backend/apps/api/v1/rides/company_views.py
Normal file
@@ -0,0 +1,167 @@
|
||||
"""
|
||||
Company API views for ThrillWiki API v1.
|
||||
|
||||
This module implements CRUD endpoints for company management:
|
||||
- List / Create: GET /companies/ POST /companies/
|
||||
- Retrieve / Update / Delete: GET /companies/{id}/ PATCH/PUT/DELETE
|
||||
"""
|
||||
|
||||
from django.db.models import Q
|
||||
from drf_spectacular.types import OpenApiTypes
|
||||
from drf_spectacular.utils import OpenApiParameter, extend_schema
|
||||
from rest_framework import permissions, status
|
||||
from rest_framework.exceptions import NotFound
|
||||
from rest_framework.pagination import PageNumberPagination
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.views import APIView
|
||||
|
||||
from apps.api.v1.serializers.companies import (
|
||||
CompanyCreateInputSerializer,
|
||||
CompanyDetailOutputSerializer,
|
||||
CompanyUpdateInputSerializer,
|
||||
)
|
||||
|
||||
try:
|
||||
from apps.rides.models.company import Company
|
||||
MODELS_AVAILABLE = True
|
||||
except ImportError:
|
||||
Company = None
|
||||
MODELS_AVAILABLE = False
|
||||
|
||||
|
||||
class StandardResultsSetPagination(PageNumberPagination):
|
||||
page_size = 20
|
||||
page_size_query_param = "page_size"
|
||||
max_page_size = 100
|
||||
|
||||
|
||||
class CompanyListCreateAPIView(APIView):
|
||||
"""List and create companies."""
|
||||
|
||||
permission_classes = [permissions.AllowAny]
|
||||
|
||||
@extend_schema(
|
||||
summary="List all companies",
|
||||
description="List companies with optional search and role filtering.",
|
||||
parameters=[
|
||||
OpenApiParameter(name="search", location=OpenApiParameter.QUERY, type=OpenApiTypes.STR),
|
||||
OpenApiParameter(name="role", location=OpenApiParameter.QUERY, type=OpenApiTypes.STR),
|
||||
OpenApiParameter(name="page", location=OpenApiParameter.QUERY, type=OpenApiTypes.INT),
|
||||
OpenApiParameter(name="page_size", location=OpenApiParameter.QUERY, type=OpenApiTypes.INT),
|
||||
],
|
||||
responses={200: CompanyDetailOutputSerializer(many=True)},
|
||||
tags=["Companies"],
|
||||
)
|
||||
def get(self, request: Request) -> Response:
|
||||
if not MODELS_AVAILABLE:
|
||||
return Response(
|
||||
{"detail": "Company models not available"},
|
||||
status=status.HTTP_501_NOT_IMPLEMENTED,
|
||||
)
|
||||
|
||||
qs = Company.objects.all().order_by("name")
|
||||
|
||||
# Search filter
|
||||
search = request.query_params.get("search", "")
|
||||
if search:
|
||||
qs = qs.filter(
|
||||
Q(name__icontains=search) | Q(description__icontains=search)
|
||||
)
|
||||
|
||||
# Role filter
|
||||
role = request.query_params.get("role", "")
|
||||
if role:
|
||||
qs = qs.filter(roles__contains=[role])
|
||||
|
||||
paginator = StandardResultsSetPagination()
|
||||
page = paginator.paginate_queryset(qs, request)
|
||||
serializer = CompanyDetailOutputSerializer(page, many=True)
|
||||
return paginator.get_paginated_response(serializer.data)
|
||||
|
||||
@extend_schema(
|
||||
summary="Create a new company",
|
||||
description="Create a new company with the given details.",
|
||||
request=CompanyCreateInputSerializer,
|
||||
responses={201: CompanyDetailOutputSerializer()},
|
||||
tags=["Companies"],
|
||||
)
|
||||
def post(self, request: Request) -> Response:
|
||||
if not MODELS_AVAILABLE:
|
||||
return Response(
|
||||
{"detail": "Company models not available"},
|
||||
status=status.HTTP_501_NOT_IMPLEMENTED,
|
||||
)
|
||||
|
||||
serializer_in = CompanyCreateInputSerializer(data=request.data)
|
||||
serializer_in.is_valid(raise_exception=True)
|
||||
validated = serializer_in.validated_data
|
||||
|
||||
company = Company.objects.create(
|
||||
name=validated["name"],
|
||||
roles=validated["roles"],
|
||||
description=validated.get("description", ""),
|
||||
website=validated.get("website", ""),
|
||||
founded_date=validated.get("founded_date"),
|
||||
)
|
||||
|
||||
serializer = CompanyDetailOutputSerializer(company)
|
||||
return Response(serializer.data, status=status.HTTP_201_CREATED)
|
||||
|
||||
|
||||
class CompanyDetailAPIView(APIView):
|
||||
"""Retrieve, update, and delete a company."""
|
||||
|
||||
permission_classes = [permissions.AllowAny]
|
||||
|
||||
def _get_company_or_404(self, pk: int) -> "Company":
|
||||
if not MODELS_AVAILABLE:
|
||||
raise NotFound("Company models not available")
|
||||
try:
|
||||
return Company.objects.get(pk=pk)
|
||||
except Company.DoesNotExist:
|
||||
raise NotFound("Company not found")
|
||||
|
||||
@extend_schema(
|
||||
summary="Retrieve a company",
|
||||
description="Get detailed information about a specific company.",
|
||||
responses={200: CompanyDetailOutputSerializer()},
|
||||
tags=["Companies"],
|
||||
)
|
||||
def get(self, request: Request, pk: int) -> Response:
|
||||
company = self._get_company_or_404(pk)
|
||||
serializer = CompanyDetailOutputSerializer(company)
|
||||
return Response(serializer.data)
|
||||
|
||||
@extend_schema(
|
||||
summary="Update a company",
|
||||
description="Update a company (partial update supported).",
|
||||
request=CompanyUpdateInputSerializer,
|
||||
responses={200: CompanyDetailOutputSerializer()},
|
||||
tags=["Companies"],
|
||||
)
|
||||
def patch(self, request: Request, pk: int) -> Response:
|
||||
company = self._get_company_or_404(pk)
|
||||
serializer_in = CompanyUpdateInputSerializer(data=request.data, partial=True)
|
||||
serializer_in.is_valid(raise_exception=True)
|
||||
|
||||
for field, value in serializer_in.validated_data.items():
|
||||
setattr(company, field, value)
|
||||
company.save()
|
||||
|
||||
serializer = CompanyDetailOutputSerializer(company)
|
||||
return Response(serializer.data)
|
||||
|
||||
def put(self, request: Request, pk: int) -> Response:
|
||||
return self.patch(request, pk)
|
||||
|
||||
@extend_schema(
|
||||
summary="Delete a company",
|
||||
description="Delete a company.",
|
||||
responses={204: None},
|
||||
tags=["Companies"],
|
||||
)
|
||||
def delete(self, request: Request, pk: int) -> Response:
|
||||
company = self._get_company_or_404(pk)
|
||||
company.delete()
|
||||
return Response(status=status.HTTP_204_NO_CONTENT)
|
||||
@@ -11,17 +11,17 @@ This file exposes comprehensive endpoints for ride model management:
|
||||
from django.urls import path
|
||||
|
||||
from .views import (
|
||||
RideModelListCreateAPIView,
|
||||
RideModelDetailAPIView,
|
||||
RideModelSearchAPIView,
|
||||
RideModelFilterOptionsAPIView,
|
||||
RideModelStatsAPIView,
|
||||
RideModelVariantListCreateAPIView,
|
||||
RideModelVariantDetailAPIView,
|
||||
RideModelTechnicalSpecListCreateAPIView,
|
||||
RideModelTechnicalSpecDetailAPIView,
|
||||
RideModelPhotoListCreateAPIView,
|
||||
RideModelListCreateAPIView,
|
||||
RideModelPhotoDetailAPIView,
|
||||
RideModelPhotoListCreateAPIView,
|
||||
RideModelSearchAPIView,
|
||||
RideModelStatsAPIView,
|
||||
RideModelTechnicalSpecDetailAPIView,
|
||||
RideModelTechnicalSpecListCreateAPIView,
|
||||
RideModelVariantDetailAPIView,
|
||||
RideModelVariantListCreateAPIView,
|
||||
)
|
||||
|
||||
app_name = "api_v1_ride_models"
|
||||
|
||||
@@ -12,40 +12,40 @@ This module implements comprehensive endpoints for ride model management:
|
||||
- Photos: CRUD operations for ride model photos
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
from datetime import timedelta
|
||||
from typing import Any
|
||||
|
||||
from rest_framework import status, permissions
|
||||
from rest_framework.views import APIView
|
||||
from django.db.models import Count, Q
|
||||
from django.utils import timezone
|
||||
from drf_spectacular.types import OpenApiTypes
|
||||
from drf_spectacular.utils import OpenApiParameter, extend_schema
|
||||
from rest_framework import permissions, status
|
||||
from rest_framework.exceptions import NotFound, ValidationError
|
||||
from rest_framework.pagination import PageNumberPagination
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.pagination import PageNumberPagination
|
||||
from rest_framework.exceptions import NotFound, ValidationError
|
||||
from drf_spectacular.utils import extend_schema, OpenApiParameter
|
||||
from drf_spectacular.types import OpenApiTypes
|
||||
from django.db.models import Q, Count
|
||||
from django.utils import timezone
|
||||
from rest_framework.views import APIView
|
||||
|
||||
# Import serializers
|
||||
from apps.api.v1.serializers.ride_models import (
|
||||
RideModelListOutputSerializer,
|
||||
RideModelDetailOutputSerializer,
|
||||
RideModelCreateInputSerializer,
|
||||
RideModelUpdateInputSerializer,
|
||||
RideModelDetailOutputSerializer,
|
||||
RideModelFilterInputSerializer,
|
||||
RideModelVariantOutputSerializer,
|
||||
RideModelVariantCreateInputSerializer,
|
||||
RideModelVariantUpdateInputSerializer,
|
||||
RideModelListOutputSerializer,
|
||||
RideModelStatsOutputSerializer,
|
||||
RideModelUpdateInputSerializer,
|
||||
RideModelVariantCreateInputSerializer,
|
||||
RideModelVariantOutputSerializer,
|
||||
RideModelVariantUpdateInputSerializer,
|
||||
)
|
||||
|
||||
# Attempt to import models; fall back gracefully if not present
|
||||
try:
|
||||
from apps.rides.models import (
|
||||
RideModel,
|
||||
RideModelVariant,
|
||||
RideModelPhoto,
|
||||
RideModelTechnicalSpec,
|
||||
RideModelVariant,
|
||||
)
|
||||
from apps.rides.models.company import Company
|
||||
|
||||
@@ -54,12 +54,12 @@ except ImportError:
|
||||
try:
|
||||
# Try alternative import path
|
||||
from apps.rides.models.rides import (
|
||||
Company,
|
||||
RideModel,
|
||||
RideModelVariant,
|
||||
RideModelPhoto,
|
||||
RideModelTechnicalSpec,
|
||||
RideModelVariant,
|
||||
)
|
||||
from apps.rides.models.rides import Company
|
||||
|
||||
MODELS_AVAILABLE = True
|
||||
except ImportError:
|
||||
@@ -486,14 +486,14 @@ class RideModelFilterOptionsAPIView(APIView):
|
||||
"""Return filter options for ride models with Rich Choice Objects metadata."""
|
||||
# Import Rich Choice registry
|
||||
from apps.core.choices.registry import get_choices
|
||||
|
||||
|
||||
if not MODELS_AVAILABLE:
|
||||
# Use Rich Choice Objects for fallback options
|
||||
try:
|
||||
# Get rich choice objects from registry
|
||||
categories = get_choices('categories', 'rides')
|
||||
target_markets = get_choices('target_markets', 'rides')
|
||||
|
||||
|
||||
# Convert Rich Choice Objects to frontend format with metadata
|
||||
categories_data = [
|
||||
{
|
||||
@@ -507,7 +507,7 @@ class RideModelFilterOptionsAPIView(APIView):
|
||||
}
|
||||
for choice in categories
|
||||
]
|
||||
|
||||
|
||||
target_markets_data = [
|
||||
{
|
||||
"value": choice.value,
|
||||
@@ -520,7 +520,7 @@ class RideModelFilterOptionsAPIView(APIView):
|
||||
}
|
||||
for choice in target_markets
|
||||
]
|
||||
|
||||
|
||||
except Exception:
|
||||
# Ultimate fallback with basic structure
|
||||
categories_data = [
|
||||
@@ -538,7 +538,7 @@ class RideModelFilterOptionsAPIView(APIView):
|
||||
{"value": "KIDDIE", "label": "Kiddie", "description": "Designed for young children", "color": "pink", "icon": "kiddie", "css_class": "bg-pink-100 text-pink-800", "sort_order": 4},
|
||||
{"value": "ALL_AGES", "label": "All Ages", "description": "Enjoyable for all age groups", "color": "blue", "icon": "all-ages", "css_class": "bg-blue-100 text-blue-800", "sort_order": 5},
|
||||
]
|
||||
|
||||
|
||||
return Response({
|
||||
"categories": categories_data,
|
||||
"target_markets": target_markets_data,
|
||||
@@ -557,11 +557,11 @@ class RideModelFilterOptionsAPIView(APIView):
|
||||
|
||||
# Get static choice definitions from Rich Choice Objects (primary source)
|
||||
# Get dynamic data from database queries
|
||||
|
||||
|
||||
# Get rich choice objects from registry
|
||||
categories = get_choices('categories', 'rides')
|
||||
target_markets = get_choices('target_markets', 'rides')
|
||||
|
||||
|
||||
# Convert Rich Choice Objects to frontend format with metadata
|
||||
categories_data = [
|
||||
{
|
||||
@@ -575,7 +575,7 @@ class RideModelFilterOptionsAPIView(APIView):
|
||||
}
|
||||
for choice in categories
|
||||
]
|
||||
|
||||
|
||||
target_markets_data = [
|
||||
{
|
||||
"value": choice.value,
|
||||
|
||||
@@ -5,23 +5,25 @@ This module contains ride photo ViewSet following the parks pattern for domain c
|
||||
Enhanced from centralized media API to provide domain-specific ride photo management.
|
||||
"""
|
||||
|
||||
from .serializers import (
|
||||
RidePhotoOutputSerializer,
|
||||
RidePhotoCreateInputSerializer,
|
||||
RidePhotoUpdateInputSerializer,
|
||||
RidePhotoListOutputSerializer,
|
||||
RidePhotoApprovalInputSerializer,
|
||||
RidePhotoStatsOutputSerializer,
|
||||
)
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from .serializers import (
|
||||
RidePhotoApprovalInputSerializer,
|
||||
RidePhotoCreateInputSerializer,
|
||||
RidePhotoListOutputSerializer,
|
||||
RidePhotoOutputSerializer,
|
||||
RidePhotoStatsOutputSerializer,
|
||||
RidePhotoUpdateInputSerializer,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
import logging
|
||||
|
||||
from django.contrib.auth import get_user_model
|
||||
from django.core.exceptions import PermissionDenied
|
||||
from drf_spectacular.utils import extend_schema_view, extend_schema
|
||||
from drf_spectacular.types import OpenApiTypes
|
||||
from drf_spectacular.utils import extend_schema, extend_schema_view
|
||||
from rest_framework import status
|
||||
from rest_framework.decorators import action
|
||||
from rest_framework.exceptions import ValidationError
|
||||
@@ -29,9 +31,8 @@ from rest_framework.permissions import IsAuthenticated
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.viewsets import ModelViewSet
|
||||
|
||||
from apps.rides.models import RidePhoto, Ride
|
||||
from apps.rides.models import Ride, RidePhoto
|
||||
from apps.rides.services.media_service import RideMediaService
|
||||
from django.contrib.auth import get_user_model
|
||||
|
||||
UserModel = get_user_model()
|
||||
|
||||
@@ -460,9 +461,9 @@ class RidePhotoViewSet(ModelViewSet):
|
||||
|
||||
try:
|
||||
# Import CloudflareImage model and service
|
||||
from django.utils import timezone
|
||||
from django_cloudflareimages_toolkit.models import CloudflareImage
|
||||
from django_cloudflareimages_toolkit.services import CloudflareImagesService
|
||||
from django.utils import timezone
|
||||
|
||||
# Always fetch the latest image data from Cloudflare API
|
||||
try:
|
||||
|
||||
@@ -4,12 +4,13 @@ Ride media serializers for ThrillWiki API v1.
|
||||
This module contains serializers for ride-specific media functionality.
|
||||
"""
|
||||
|
||||
from rest_framework import serializers
|
||||
from drf_spectacular.utils import (
|
||||
OpenApiExample,
|
||||
extend_schema_field,
|
||||
extend_schema_serializer,
|
||||
OpenApiExample,
|
||||
)
|
||||
from rest_framework import serializers
|
||||
|
||||
from apps.rides.models import Ride, RidePhoto
|
||||
|
||||
|
||||
@@ -267,33 +268,33 @@ class HybridRideSerializer(serializers.ModelSerializer):
|
||||
Enhanced serializer for hybrid filtering strategy.
|
||||
Includes all filterable fields for client-side filtering.
|
||||
"""
|
||||
|
||||
|
||||
# Park fields
|
||||
park_name = serializers.CharField(source="park.name", read_only=True)
|
||||
park_slug = serializers.CharField(source="park.slug", read_only=True)
|
||||
|
||||
|
||||
# Park location fields
|
||||
park_city = serializers.SerializerMethodField()
|
||||
park_state = serializers.SerializerMethodField()
|
||||
park_country = serializers.SerializerMethodField()
|
||||
|
||||
|
||||
# Park area fields
|
||||
park_area_name = serializers.CharField(source="park_area.name", read_only=True, allow_null=True)
|
||||
park_area_slug = serializers.CharField(source="park_area.slug", read_only=True, allow_null=True)
|
||||
|
||||
|
||||
# Company fields
|
||||
manufacturer_name = serializers.CharField(source="manufacturer.name", read_only=True, allow_null=True)
|
||||
manufacturer_slug = serializers.CharField(source="manufacturer.slug", read_only=True, allow_null=True)
|
||||
designer_name = serializers.CharField(source="designer.name", read_only=True, allow_null=True)
|
||||
designer_slug = serializers.CharField(source="designer.slug", read_only=True, allow_null=True)
|
||||
|
||||
|
||||
# Ride model fields
|
||||
ride_model_name = serializers.CharField(source="ride_model.name", read_only=True, allow_null=True)
|
||||
ride_model_slug = serializers.CharField(source="ride_model.slug", read_only=True, allow_null=True)
|
||||
ride_model_category = serializers.CharField(source="ride_model.category", read_only=True, allow_null=True)
|
||||
ride_model_manufacturer_name = serializers.CharField(source="ride_model.manufacturer.name", read_only=True, allow_null=True)
|
||||
ride_model_manufacturer_slug = serializers.CharField(source="ride_model.manufacturer.slug", read_only=True, allow_null=True)
|
||||
|
||||
|
||||
# Roller coaster stats fields
|
||||
coaster_height_ft = serializers.SerializerMethodField()
|
||||
coaster_length_ft = serializers.SerializerMethodField()
|
||||
@@ -309,15 +310,15 @@ class HybridRideSerializer(serializers.ModelSerializer):
|
||||
coaster_trains_count = serializers.SerializerMethodField()
|
||||
coaster_cars_per_train = serializers.SerializerMethodField()
|
||||
coaster_seats_per_car = serializers.SerializerMethodField()
|
||||
|
||||
|
||||
# Image URLs for display
|
||||
banner_image_url = serializers.SerializerMethodField()
|
||||
card_image_url = serializers.SerializerMethodField()
|
||||
|
||||
|
||||
# Computed fields for filtering
|
||||
opening_year = serializers.IntegerField(read_only=True)
|
||||
search_text = serializers.CharField(read_only=True)
|
||||
|
||||
|
||||
@extend_schema_field(serializers.CharField(allow_null=True))
|
||||
def get_park_city(self, obj):
|
||||
"""Get city from park location."""
|
||||
@@ -327,7 +328,7 @@ class HybridRideSerializer(serializers.ModelSerializer):
|
||||
return None
|
||||
except AttributeError:
|
||||
return None
|
||||
|
||||
|
||||
@extend_schema_field(serializers.CharField(allow_null=True))
|
||||
def get_park_state(self, obj):
|
||||
"""Get state from park location."""
|
||||
@@ -337,7 +338,7 @@ class HybridRideSerializer(serializers.ModelSerializer):
|
||||
return None
|
||||
except AttributeError:
|
||||
return None
|
||||
|
||||
|
||||
@extend_schema_field(serializers.CharField(allow_null=True))
|
||||
def get_park_country(self, obj):
|
||||
"""Get country from park location."""
|
||||
@@ -347,7 +348,7 @@ class HybridRideSerializer(serializers.ModelSerializer):
|
||||
return None
|
||||
except AttributeError:
|
||||
return None
|
||||
|
||||
|
||||
@extend_schema_field(serializers.FloatField(allow_null=True))
|
||||
def get_coaster_height_ft(self, obj):
|
||||
"""Get roller coaster height."""
|
||||
@@ -357,7 +358,7 @@ class HybridRideSerializer(serializers.ModelSerializer):
|
||||
return None
|
||||
except (AttributeError, TypeError):
|
||||
return None
|
||||
|
||||
|
||||
@extend_schema_field(serializers.FloatField(allow_null=True))
|
||||
def get_coaster_length_ft(self, obj):
|
||||
"""Get roller coaster length."""
|
||||
@@ -367,7 +368,7 @@ class HybridRideSerializer(serializers.ModelSerializer):
|
||||
return None
|
||||
except (AttributeError, TypeError):
|
||||
return None
|
||||
|
||||
|
||||
@extend_schema_field(serializers.FloatField(allow_null=True))
|
||||
def get_coaster_speed_mph(self, obj):
|
||||
"""Get roller coaster speed."""
|
||||
@@ -377,7 +378,7 @@ class HybridRideSerializer(serializers.ModelSerializer):
|
||||
return None
|
||||
except (AttributeError, TypeError):
|
||||
return None
|
||||
|
||||
|
||||
@extend_schema_field(serializers.IntegerField(allow_null=True))
|
||||
def get_coaster_inversions(self, obj):
|
||||
"""Get roller coaster inversions."""
|
||||
@@ -387,7 +388,7 @@ class HybridRideSerializer(serializers.ModelSerializer):
|
||||
return None
|
||||
except AttributeError:
|
||||
return None
|
||||
|
||||
|
||||
@extend_schema_field(serializers.IntegerField(allow_null=True))
|
||||
def get_coaster_ride_time_seconds(self, obj):
|
||||
"""Get roller coaster ride time."""
|
||||
@@ -397,7 +398,7 @@ class HybridRideSerializer(serializers.ModelSerializer):
|
||||
return None
|
||||
except AttributeError:
|
||||
return None
|
||||
|
||||
|
||||
@extend_schema_field(serializers.CharField(allow_null=True))
|
||||
def get_coaster_track_type(self, obj):
|
||||
"""Get roller coaster track type."""
|
||||
@@ -407,7 +408,7 @@ class HybridRideSerializer(serializers.ModelSerializer):
|
||||
return None
|
||||
except AttributeError:
|
||||
return None
|
||||
|
||||
|
||||
@extend_schema_field(serializers.CharField(allow_null=True))
|
||||
def get_coaster_track_material(self, obj):
|
||||
"""Get roller coaster track material."""
|
||||
@@ -417,7 +418,7 @@ class HybridRideSerializer(serializers.ModelSerializer):
|
||||
return None
|
||||
except AttributeError:
|
||||
return None
|
||||
|
||||
|
||||
@extend_schema_field(serializers.CharField(allow_null=True))
|
||||
def get_coaster_roller_coaster_type(self, obj):
|
||||
"""Get roller coaster type."""
|
||||
@@ -427,7 +428,7 @@ class HybridRideSerializer(serializers.ModelSerializer):
|
||||
return None
|
||||
except AttributeError:
|
||||
return None
|
||||
|
||||
|
||||
@extend_schema_field(serializers.FloatField(allow_null=True))
|
||||
def get_coaster_max_drop_height_ft(self, obj):
|
||||
"""Get roller coaster max drop height."""
|
||||
@@ -437,7 +438,7 @@ class HybridRideSerializer(serializers.ModelSerializer):
|
||||
return None
|
||||
except (AttributeError, TypeError):
|
||||
return None
|
||||
|
||||
|
||||
@extend_schema_field(serializers.CharField(allow_null=True))
|
||||
def get_coaster_propulsion_system(self, obj):
|
||||
"""Get roller coaster propulsion system."""
|
||||
@@ -447,7 +448,7 @@ class HybridRideSerializer(serializers.ModelSerializer):
|
||||
return None
|
||||
except AttributeError:
|
||||
return None
|
||||
|
||||
|
||||
@extend_schema_field(serializers.CharField(allow_null=True))
|
||||
def get_coaster_train_style(self, obj):
|
||||
"""Get roller coaster train style."""
|
||||
@@ -457,7 +458,7 @@ class HybridRideSerializer(serializers.ModelSerializer):
|
||||
return None
|
||||
except AttributeError:
|
||||
return None
|
||||
|
||||
|
||||
@extend_schema_field(serializers.IntegerField(allow_null=True))
|
||||
def get_coaster_trains_count(self, obj):
|
||||
"""Get roller coaster trains count."""
|
||||
@@ -467,7 +468,7 @@ class HybridRideSerializer(serializers.ModelSerializer):
|
||||
return None
|
||||
except AttributeError:
|
||||
return None
|
||||
|
||||
|
||||
@extend_schema_field(serializers.IntegerField(allow_null=True))
|
||||
def get_coaster_cars_per_train(self, obj):
|
||||
"""Get roller coaster cars per train."""
|
||||
@@ -477,7 +478,7 @@ class HybridRideSerializer(serializers.ModelSerializer):
|
||||
return None
|
||||
except AttributeError:
|
||||
return None
|
||||
|
||||
|
||||
@extend_schema_field(serializers.IntegerField(allow_null=True))
|
||||
def get_coaster_seats_per_car(self, obj):
|
||||
"""Get roller coaster seats per car."""
|
||||
@@ -487,14 +488,14 @@ class HybridRideSerializer(serializers.ModelSerializer):
|
||||
return None
|
||||
except AttributeError:
|
||||
return None
|
||||
|
||||
|
||||
@extend_schema_field(serializers.URLField(allow_null=True))
|
||||
def get_banner_image_url(self, obj):
|
||||
"""Get banner image URL."""
|
||||
if obj.banner_image and obj.banner_image.image:
|
||||
return obj.banner_image.image.url
|
||||
return None
|
||||
|
||||
|
||||
@extend_schema_field(serializers.URLField(allow_null=True))
|
||||
def get_card_image_url(self, obj):
|
||||
"""Get card image URL."""
|
||||
@@ -513,44 +514,44 @@ class HybridRideSerializer(serializers.ModelSerializer):
|
||||
"category",
|
||||
"status",
|
||||
"post_closing_status",
|
||||
|
||||
|
||||
# Dates and computed fields
|
||||
"opening_date",
|
||||
"closing_date",
|
||||
"status_since",
|
||||
"opening_year",
|
||||
|
||||
|
||||
# Park fields
|
||||
"park_name",
|
||||
"park_slug",
|
||||
"park_city",
|
||||
"park_state",
|
||||
"park_country",
|
||||
|
||||
|
||||
# Park area fields
|
||||
"park_area_name",
|
||||
"park_area_slug",
|
||||
|
||||
|
||||
# Company fields
|
||||
"manufacturer_name",
|
||||
"manufacturer_slug",
|
||||
"designer_name",
|
||||
"designer_slug",
|
||||
|
||||
|
||||
# Ride model fields
|
||||
"ride_model_name",
|
||||
"ride_model_slug",
|
||||
"ride_model_category",
|
||||
"ride_model_manufacturer_name",
|
||||
"ride_model_manufacturer_slug",
|
||||
|
||||
|
||||
# Ride specifications
|
||||
"min_height_in",
|
||||
"max_height_in",
|
||||
"capacity_per_hour",
|
||||
"ride_duration_seconds",
|
||||
"average_rating",
|
||||
|
||||
|
||||
# Roller coaster stats
|
||||
"coaster_height_ft",
|
||||
"coaster_length_ft",
|
||||
@@ -566,18 +567,18 @@ class HybridRideSerializer(serializers.ModelSerializer):
|
||||
"coaster_trains_count",
|
||||
"coaster_cars_per_train",
|
||||
"coaster_seats_per_car",
|
||||
|
||||
|
||||
# Images
|
||||
"banner_image_url",
|
||||
"card_image_url",
|
||||
|
||||
|
||||
# URLs
|
||||
"url",
|
||||
"park_url",
|
||||
|
||||
|
||||
# Computed fields for filtering
|
||||
"search_text",
|
||||
|
||||
|
||||
# Metadata
|
||||
"created_at",
|
||||
"updated_at",
|
||||
|
||||
@@ -8,23 +8,23 @@ actions (bulk, publish, export, import, recommendations) should be added
|
||||
to the views module when business logic is available.
|
||||
"""
|
||||
|
||||
from django.urls import path, include
|
||||
from django.urls import include, path
|
||||
from rest_framework.routers import DefaultRouter
|
||||
|
||||
from .photo_views import RidePhotoViewSet
|
||||
from .views import (
|
||||
RideListCreateAPIView,
|
||||
RideDetailAPIView,
|
||||
FilterOptionsAPIView,
|
||||
CompanySearchAPIView,
|
||||
DesignerListAPIView,
|
||||
FilterOptionsAPIView,
|
||||
HybridRideAPIView,
|
||||
ManufacturerListAPIView,
|
||||
RideDetailAPIView,
|
||||
RideFilterMetadataAPIView,
|
||||
RideImageSettingsAPIView,
|
||||
RideListCreateAPIView,
|
||||
RideModelSearchAPIView,
|
||||
RideSearchSuggestionsAPIView,
|
||||
RideImageSettingsAPIView,
|
||||
HybridRideAPIView,
|
||||
RideFilterMetadataAPIView,
|
||||
ManufacturerListAPIView,
|
||||
DesignerListAPIView,
|
||||
)
|
||||
from .photo_views import RidePhotoViewSet
|
||||
|
||||
# Create router for nested photo endpoints
|
||||
router = DefaultRouter()
|
||||
@@ -35,11 +35,11 @@ app_name = "api_v1_rides"
|
||||
urlpatterns = [
|
||||
# Core list/create endpoints
|
||||
path("", RideListCreateAPIView.as_view(), name="ride-list-create"),
|
||||
|
||||
|
||||
# Hybrid filtering endpoints
|
||||
path("hybrid/", HybridRideAPIView.as_view(), name="ride-hybrid-filtering"),
|
||||
path("hybrid/filter-metadata/", RideFilterMetadataAPIView.as_view(), name="ride-hybrid-filter-metadata"),
|
||||
|
||||
|
||||
# Filter options
|
||||
path("filter-options/", FilterOptionsAPIView.as_view(), name="ride-filter-options"),
|
||||
# Autocomplete / suggestion endpoints
|
||||
|
||||
@@ -23,12 +23,13 @@ Caching Strategy:
|
||||
- RideSearchSuggestionsAPIView.get: 5 minutes (300s) - suggestions should be fresh
|
||||
"""
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from django.db import models
|
||||
from drf_spectacular.types import OpenApiTypes
|
||||
from drf_spectacular.utils import extend_schema, extend_schema_view, OpenApiParameter
|
||||
from drf_spectacular.utils import OpenApiParameter, extend_schema, extend_schema_view
|
||||
from rest_framework import permissions, status
|
||||
from rest_framework.exceptions import NotFound
|
||||
from rest_framework.pagination import PageNumberPagination
|
||||
@@ -53,9 +54,9 @@ smart_ride_loader = SmartRideLoader()
|
||||
|
||||
# Attempt to import model-level helpers; fall back gracefully if not present.
|
||||
try:
|
||||
from apps.parks.models import Company, Park
|
||||
from apps.rides.models import Ride, RideModel
|
||||
from apps.rides.models.rides import RollerCoasterStats
|
||||
from apps.parks.models import Park, Company
|
||||
|
||||
MODELS_AVAILABLE = True
|
||||
except Exception:
|
||||
@@ -370,10 +371,8 @@ class RideListCreateAPIView(APIView):
|
||||
|
||||
park_id = params.get("park_id")
|
||||
if park_id:
|
||||
try:
|
||||
with contextlib.suppress(ValueError, TypeError):
|
||||
qs = qs.filter(park_id=int(park_id))
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
return qs
|
||||
|
||||
@@ -393,10 +392,8 @@ class RideListCreateAPIView(APIView):
|
||||
"""Apply manufacturer and designer filtering."""
|
||||
manufacturer_id = params.get("manufacturer_id")
|
||||
if manufacturer_id:
|
||||
try:
|
||||
with contextlib.suppress(ValueError, TypeError):
|
||||
qs = qs.filter(manufacturer_id=int(manufacturer_id))
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
manufacturer_slug = params.get("manufacturer_slug")
|
||||
if manufacturer_slug:
|
||||
@@ -404,10 +401,8 @@ class RideListCreateAPIView(APIView):
|
||||
|
||||
designer_id = params.get("designer_id")
|
||||
if designer_id:
|
||||
try:
|
||||
with contextlib.suppress(ValueError, TypeError):
|
||||
qs = qs.filter(designer_id=int(designer_id))
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
designer_slug = params.get("designer_slug")
|
||||
if designer_slug:
|
||||
@@ -419,10 +414,8 @@ class RideListCreateAPIView(APIView):
|
||||
"""Apply ride model filtering."""
|
||||
ride_model_id = params.get("ride_model_id")
|
||||
if ride_model_id:
|
||||
try:
|
||||
with contextlib.suppress(ValueError, TypeError):
|
||||
qs = qs.filter(ride_model_id=int(ride_model_id))
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
ride_model_slug = params.get("ride_model_slug")
|
||||
manufacturer_slug_for_model = params.get("manufacturer_slug")
|
||||
@@ -438,17 +431,13 @@ class RideListCreateAPIView(APIView):
|
||||
"""Apply rating-based filtering."""
|
||||
min_rating = params.get("min_rating")
|
||||
if min_rating:
|
||||
try:
|
||||
with contextlib.suppress(ValueError, TypeError):
|
||||
qs = qs.filter(average_rating__gte=float(min_rating))
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
max_rating = params.get("max_rating")
|
||||
if max_rating:
|
||||
try:
|
||||
with contextlib.suppress(ValueError, TypeError):
|
||||
qs = qs.filter(average_rating__lte=float(max_rating))
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
return qs
|
||||
|
||||
@@ -456,17 +445,13 @@ class RideListCreateAPIView(APIView):
|
||||
"""Apply height requirement filtering."""
|
||||
min_height_req = params.get("min_height_requirement")
|
||||
if min_height_req:
|
||||
try:
|
||||
with contextlib.suppress(ValueError, TypeError):
|
||||
qs = qs.filter(min_height_in__gte=int(min_height_req))
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
max_height_req = params.get("max_height_requirement")
|
||||
if max_height_req:
|
||||
try:
|
||||
with contextlib.suppress(ValueError, TypeError):
|
||||
qs = qs.filter(max_height_in__lte=int(max_height_req))
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
return qs
|
||||
|
||||
@@ -474,17 +459,13 @@ class RideListCreateAPIView(APIView):
|
||||
"""Apply capacity filtering."""
|
||||
min_capacity = params.get("min_capacity")
|
||||
if min_capacity:
|
||||
try:
|
||||
with contextlib.suppress(ValueError, TypeError):
|
||||
qs = qs.filter(capacity_per_hour__gte=int(min_capacity))
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
max_capacity = params.get("max_capacity")
|
||||
if max_capacity:
|
||||
try:
|
||||
with contextlib.suppress(ValueError, TypeError):
|
||||
qs = qs.filter(capacity_per_hour__lte=int(max_capacity))
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
return qs
|
||||
|
||||
@@ -492,24 +473,18 @@ class RideListCreateAPIView(APIView):
|
||||
"""Apply opening year filtering."""
|
||||
opening_year = params.get("opening_year")
|
||||
if opening_year:
|
||||
try:
|
||||
with contextlib.suppress(ValueError, TypeError):
|
||||
qs = qs.filter(opening_date__year=int(opening_year))
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
min_opening_year = params.get("min_opening_year")
|
||||
if min_opening_year:
|
||||
try:
|
||||
with contextlib.suppress(ValueError, TypeError):
|
||||
qs = qs.filter(opening_date__year__gte=int(min_opening_year))
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
max_opening_year = params.get("max_opening_year")
|
||||
if max_opening_year:
|
||||
try:
|
||||
with contextlib.suppress(ValueError, TypeError):
|
||||
qs = qs.filter(opening_date__year__lte=int(max_opening_year))
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
return qs
|
||||
|
||||
@@ -530,47 +505,35 @@ class RideListCreateAPIView(APIView):
|
||||
# Height filters
|
||||
min_height_ft = params.get("min_height_ft")
|
||||
if min_height_ft:
|
||||
try:
|
||||
with contextlib.suppress(ValueError, TypeError):
|
||||
qs = qs.filter(coaster_stats__height_ft__gte=float(min_height_ft))
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
max_height_ft = params.get("max_height_ft")
|
||||
if max_height_ft:
|
||||
try:
|
||||
with contextlib.suppress(ValueError, TypeError):
|
||||
qs = qs.filter(coaster_stats__height_ft__lte=float(max_height_ft))
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
# Speed filters
|
||||
min_speed_mph = params.get("min_speed_mph")
|
||||
if min_speed_mph:
|
||||
try:
|
||||
with contextlib.suppress(ValueError, TypeError):
|
||||
qs = qs.filter(coaster_stats__speed_mph__gte=float(min_speed_mph))
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
max_speed_mph = params.get("max_speed_mph")
|
||||
if max_speed_mph:
|
||||
try:
|
||||
with contextlib.suppress(ValueError, TypeError):
|
||||
qs = qs.filter(coaster_stats__speed_mph__lte=float(max_speed_mph))
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
# Inversion filters
|
||||
min_inversions = params.get("min_inversions")
|
||||
if min_inversions:
|
||||
try:
|
||||
with contextlib.suppress(ValueError, TypeError):
|
||||
qs = qs.filter(coaster_stats__inversions__gte=int(min_inversions))
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
max_inversions = params.get("max_inversions")
|
||||
if max_inversions:
|
||||
try:
|
||||
with contextlib.suppress(ValueError, TypeError):
|
||||
qs = qs.filter(coaster_stats__inversions__lte=int(max_inversions))
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
has_inversions = params.get("has_inversions")
|
||||
if has_inversions is not None:
|
||||
@@ -2176,10 +2139,8 @@ class HybridRideAPIView(APIView):
|
||||
value = query_params.get(param)
|
||||
if value:
|
||||
if param == "park_id":
|
||||
try:
|
||||
with contextlib.suppress(ValueError):
|
||||
filters[param] = int(value)
|
||||
except ValueError:
|
||||
pass
|
||||
else:
|
||||
filters[param] = value
|
||||
|
||||
@@ -2461,14 +2422,14 @@ class RideFilterMetadataAPIView(APIView):
|
||||
class BaseCompanyListAPIView(APIView):
|
||||
permission_classes = [permissions.AllowAny]
|
||||
role = None
|
||||
|
||||
|
||||
def get(self, request: Request) -> Response:
|
||||
if not MODELS_AVAILABLE:
|
||||
return Response(
|
||||
{"detail": "Models not available"},
|
||||
status=status.HTTP_501_NOT_IMPLEMENTED
|
||||
)
|
||||
|
||||
|
||||
companies = (
|
||||
Company.objects.filter(roles__contains=[self.role])
|
||||
.annotate(ride_count=Count("manufactured_rides" if self.role == "MANUFACTURER" else "designed_rides"))
|
||||
@@ -2486,7 +2447,7 @@ class BaseCompanyListAPIView(APIView):
|
||||
}
|
||||
for c in companies
|
||||
]
|
||||
|
||||
|
||||
return Response({
|
||||
"results": data,
|
||||
"count": len(data)
|
||||
|
||||
@@ -5,88 +5,88 @@ This module provides a unified interface to all serializers across different dom
|
||||
while maintaining the modular structure for better organization and maintainability.
|
||||
"""
|
||||
|
||||
import importlib
|
||||
from typing import Any
|
||||
|
||||
# --- Companies and ride models domain ---
|
||||
from .companies import (
|
||||
CompanyCreateInputSerializer,
|
||||
CompanyDetailOutputSerializer,
|
||||
CompanyUpdateInputSerializer,
|
||||
RideModelCreateInputSerializer,
|
||||
RideModelDetailOutputSerializer,
|
||||
RideModelUpdateInputSerializer,
|
||||
) # noqa: F401
|
||||
|
||||
# --- Parks domain ---
|
||||
from .parks import (
|
||||
ParkAreaCreateInputSerializer,
|
||||
ParkAreaDetailOutputSerializer,
|
||||
ParkAreaUpdateInputSerializer,
|
||||
ParkCreateInputSerializer,
|
||||
ParkDetailOutputSerializer,
|
||||
ParkFilterInputSerializer,
|
||||
ParkListOutputSerializer,
|
||||
ParkLocationCreateInputSerializer,
|
||||
ParkLocationOutputSerializer,
|
||||
ParkLocationUpdateInputSerializer,
|
||||
ParkSuggestionOutputSerializer,
|
||||
ParkSuggestionSerializer,
|
||||
ParkUpdateInputSerializer,
|
||||
) # noqa: F401
|
||||
|
||||
# --- Rides domain ---
|
||||
from .rides import (
|
||||
RideCreateInputSerializer,
|
||||
RideDetailOutputSerializer,
|
||||
RideFilterInputSerializer,
|
||||
RideListOutputSerializer,
|
||||
RideLocationCreateInputSerializer,
|
||||
RideLocationOutputSerializer,
|
||||
RideLocationUpdateInputSerializer,
|
||||
RideModelOutputSerializer,
|
||||
RideParkOutputSerializer,
|
||||
RideReviewCreateInputSerializer,
|
||||
RideReviewOutputSerializer,
|
||||
RideReviewUpdateInputSerializer,
|
||||
RideUpdateInputSerializer,
|
||||
RollerCoasterStatsCreateInputSerializer,
|
||||
RollerCoasterStatsOutputSerializer,
|
||||
RollerCoasterStatsUpdateInputSerializer,
|
||||
) # noqa: F401
|
||||
from .services import (
|
||||
HealthCheckOutputSerializer,
|
||||
PerformanceMetricsOutputSerializer,
|
||||
SimpleHealthOutputSerializer,
|
||||
EmailSendInputSerializer,
|
||||
EmailTemplateOutputSerializer,
|
||||
MapDataOutputSerializer,
|
||||
CoordinateInputSerializer,
|
||||
HistoryEventSerializer,
|
||||
HistoryEntryOutputSerializer,
|
||||
HistoryCreateInputSerializer,
|
||||
ModerationSubmissionSerializer,
|
||||
ModerationSubmissionOutputSerializer,
|
||||
RoadtripParkSerializer,
|
||||
RoadtripCreateInputSerializer,
|
||||
RoadtripOutputSerializer,
|
||||
GeocodeInputSerializer,
|
||||
GeocodeOutputSerializer,
|
||||
DistanceCalculationInputSerializer,
|
||||
DistanceCalculationOutputSerializer,
|
||||
EmailSendInputSerializer,
|
||||
EmailTemplateOutputSerializer,
|
||||
GeocodeInputSerializer,
|
||||
GeocodeOutputSerializer,
|
||||
HealthCheckOutputSerializer,
|
||||
HistoryCreateInputSerializer,
|
||||
HistoryEntryOutputSerializer,
|
||||
HistoryEventSerializer,
|
||||
MapDataOutputSerializer,
|
||||
ModerationSubmissionOutputSerializer,
|
||||
ModerationSubmissionSerializer,
|
||||
PerformanceMetricsOutputSerializer,
|
||||
RoadtripCreateInputSerializer,
|
||||
RoadtripOutputSerializer,
|
||||
RoadtripParkSerializer,
|
||||
SimpleHealthOutputSerializer,
|
||||
) # noqa: F401
|
||||
from typing import Any, Dict, List
|
||||
import importlib
|
||||
|
||||
# --- Shared utilities and base classes ---
|
||||
from .shared import (
|
||||
FilterOptionSerializer,
|
||||
FilterRangeSerializer,
|
||||
StandardizedFilterMetadataSerializer,
|
||||
validate_filter_metadata_contract,
|
||||
ensure_filter_option_format,
|
||||
) # noqa: F401
|
||||
|
||||
# --- Parks domain ---
|
||||
from .parks import (
|
||||
ParkListOutputSerializer,
|
||||
ParkDetailOutputSerializer,
|
||||
ParkCreateInputSerializer,
|
||||
ParkUpdateInputSerializer,
|
||||
ParkFilterInputSerializer,
|
||||
ParkAreaDetailOutputSerializer,
|
||||
ParkAreaCreateInputSerializer,
|
||||
ParkAreaUpdateInputSerializer,
|
||||
ParkLocationOutputSerializer,
|
||||
ParkLocationCreateInputSerializer,
|
||||
ParkLocationUpdateInputSerializer,
|
||||
ParkSuggestionSerializer,
|
||||
ParkSuggestionOutputSerializer,
|
||||
) # noqa: F401
|
||||
|
||||
# --- Companies and ride models domain ---
|
||||
from .companies import (
|
||||
CompanyDetailOutputSerializer,
|
||||
CompanyCreateInputSerializer,
|
||||
CompanyUpdateInputSerializer,
|
||||
RideModelDetailOutputSerializer,
|
||||
RideModelCreateInputSerializer,
|
||||
RideModelUpdateInputSerializer,
|
||||
) # noqa: F401
|
||||
|
||||
# --- Rides domain ---
|
||||
from .rides import (
|
||||
RideParkOutputSerializer,
|
||||
RideModelOutputSerializer,
|
||||
RideListOutputSerializer,
|
||||
RideDetailOutputSerializer,
|
||||
RideCreateInputSerializer,
|
||||
RideUpdateInputSerializer,
|
||||
RideFilterInputSerializer,
|
||||
RollerCoasterStatsOutputSerializer,
|
||||
RollerCoasterStatsCreateInputSerializer,
|
||||
RollerCoasterStatsUpdateInputSerializer,
|
||||
RideLocationOutputSerializer,
|
||||
RideLocationCreateInputSerializer,
|
||||
RideLocationUpdateInputSerializer,
|
||||
RideReviewOutputSerializer,
|
||||
RideReviewCreateInputSerializer,
|
||||
RideReviewUpdateInputSerializer,
|
||||
validate_filter_metadata_contract,
|
||||
) # noqa: F401
|
||||
|
||||
# --- Accounts domain: try multiple likely locations, fall back to placeholders ---
|
||||
_ACCOUNTS_SYMBOLS: List[str] = [
|
||||
_ACCOUNTS_SYMBOLS: list[str] = [
|
||||
"UserProfileOutputSerializer",
|
||||
"UserProfileCreateInputSerializer",
|
||||
"UserProfileUpdateInputSerializer",
|
||||
@@ -106,7 +106,7 @@ _ACCOUNTS_SYMBOLS: List[str] = [
|
||||
]
|
||||
|
||||
|
||||
def _import_accounts_symbols() -> Dict[str, Any]:
|
||||
def _import_accounts_symbols() -> dict[str, Any]:
|
||||
"""
|
||||
Try a list of candidate module paths and return a dict mapping expected symbol
|
||||
names to the objects found. If no candidate provides a symbol, the symbol maps to None.
|
||||
@@ -119,7 +119,7 @@ def _import_accounts_symbols() -> Dict[str, Any]:
|
||||
]
|
||||
|
||||
# Prepare default placeholders
|
||||
result: Dict[str, Any] = {name: None for name in _ACCOUNTS_SYMBOLS}
|
||||
result: dict[str, Any] = dict.fromkeys(_ACCOUNTS_SYMBOLS)
|
||||
|
||||
for modname in candidates:
|
||||
try:
|
||||
|
||||
@@ -5,21 +5,22 @@ This module contains all serializers related to user account management,
|
||||
profile settings, preferences, privacy, notifications, and security.
|
||||
"""
|
||||
|
||||
from rest_framework import serializers
|
||||
from django.contrib.auth import get_user_model
|
||||
from drf_spectacular.utils import (
|
||||
extend_schema_serializer,
|
||||
OpenApiExample,
|
||||
extend_schema_serializer,
|
||||
)
|
||||
from rest_framework import serializers
|
||||
|
||||
from apps.accounts.models import (
|
||||
User,
|
||||
UserProfile,
|
||||
UserNotification,
|
||||
NotificationPreference,
|
||||
User,
|
||||
UserNotification,
|
||||
UserProfile,
|
||||
)
|
||||
from apps.core.choices.serializers import RichChoiceFieldSerializer
|
||||
from apps.lists.models import UserList
|
||||
from apps.rides.models.credits import RideCredit
|
||||
from apps.core.choices.serializers import RichChoiceFieldSerializer
|
||||
|
||||
UserModel = get_user_model()
|
||||
|
||||
@@ -187,7 +188,7 @@ class PublicUserSerializer(serializers.ModelSerializer):
|
||||
Only exposes public information.
|
||||
"""
|
||||
profile = UserProfileSerializer(read_only=True)
|
||||
|
||||
|
||||
class Meta:
|
||||
model = User
|
||||
fields = [
|
||||
@@ -906,9 +907,10 @@ class AvatarUploadSerializer(serializers.Serializer):
|
||||
|
||||
# Try to validate with PIL
|
||||
try:
|
||||
from PIL import Image
|
||||
import io
|
||||
|
||||
from PIL import Image
|
||||
|
||||
value.seek(0)
|
||||
image_data = value.read()
|
||||
value.seek(0) # Reset for later use
|
||||
|
||||
@@ -5,14 +5,14 @@ This module contains all serializers related to user authentication,
|
||||
registration, password management, and social authentication.
|
||||
"""
|
||||
|
||||
from rest_framework import serializers
|
||||
from django.contrib.auth import get_user_model, authenticate
|
||||
from django.contrib.auth import authenticate, get_user_model
|
||||
from django.contrib.auth.password_validation import validate_password
|
||||
from django.core.exceptions import ValidationError as DjangoValidationError
|
||||
from drf_spectacular.utils import (
|
||||
extend_schema_serializer,
|
||||
OpenApiExample,
|
||||
extend_schema_serializer,
|
||||
)
|
||||
from rest_framework import serializers
|
||||
|
||||
UserModel = get_user_model()
|
||||
|
||||
|
||||
@@ -5,16 +5,16 @@ This module contains all serializers related to companies that operate parks
|
||||
or manufacture rides, as well as ride model serializers.
|
||||
"""
|
||||
|
||||
from rest_framework import serializers
|
||||
from drf_spectacular.utils import (
|
||||
extend_schema_serializer,
|
||||
extend_schema_field,
|
||||
OpenApiExample,
|
||||
extend_schema_field,
|
||||
extend_schema_serializer,
|
||||
)
|
||||
from rest_framework import serializers
|
||||
|
||||
from .shared import ModelChoices
|
||||
from apps.core.choices.serializers import RichChoiceFieldSerializer
|
||||
|
||||
from .shared import ModelChoices
|
||||
|
||||
# === COMPANY SERIALIZERS ===
|
||||
|
||||
|
||||
@@ -5,8 +5,8 @@ This module contains serializers for history tracking and timeline functionality
|
||||
using django-pghistory.
|
||||
"""
|
||||
|
||||
from rest_framework import serializers
|
||||
from drf_spectacular.utils import extend_schema_field
|
||||
from rest_framework import serializers
|
||||
|
||||
|
||||
class ParkHistoryEventSerializer(serializers.Serializer):
|
||||
|
||||
@@ -5,13 +5,12 @@ This module contains all serializers related to map functionality,
|
||||
including location data, search results, and clustering.
|
||||
"""
|
||||
|
||||
from rest_framework import serializers
|
||||
from drf_spectacular.utils import (
|
||||
extend_schema_serializer,
|
||||
extend_schema_field,
|
||||
OpenApiExample,
|
||||
extend_schema_field,
|
||||
extend_schema_serializer,
|
||||
)
|
||||
|
||||
from rest_framework import serializers
|
||||
|
||||
# === MAP LOCATION SERIALIZERS ===
|
||||
|
||||
|
||||
@@ -5,13 +5,12 @@ This module contains serializers for photo uploads, media management,
|
||||
and related media functionality.
|
||||
"""
|
||||
|
||||
from rest_framework import serializers
|
||||
from drf_spectacular.utils import (
|
||||
extend_schema_serializer,
|
||||
extend_schema_field,
|
||||
OpenApiExample,
|
||||
extend_schema_field,
|
||||
extend_schema_serializer,
|
||||
)
|
||||
|
||||
from rest_framework import serializers
|
||||
|
||||
# === MEDIA SERIALIZERS ===
|
||||
|
||||
|
||||
@@ -5,13 +5,12 @@ This module contains serializers for statistics, health checks, and other
|
||||
miscellaneous functionality.
|
||||
"""
|
||||
|
||||
from rest_framework import serializers
|
||||
from drf_spectacular.utils import (
|
||||
extend_schema_field,
|
||||
)
|
||||
from .shared import ModelChoices
|
||||
from apps.core.choices.serializers import RichChoiceFieldSerializer
|
||||
from rest_framework import serializers
|
||||
|
||||
from apps.core.choices.serializers import RichChoiceFieldSerializer
|
||||
|
||||
# === STATISTICS SERIALIZERS ===
|
||||
|
||||
|
||||
@@ -4,10 +4,12 @@ Serializers for park review API endpoints.
|
||||
This module contains serializers for park review CRUD operations.
|
||||
"""
|
||||
|
||||
from drf_spectacular.utils import OpenApiExample, extend_schema_field, extend_schema_serializer
|
||||
from rest_framework import serializers
|
||||
from drf_spectacular.utils import extend_schema_field, extend_schema_serializer, OpenApiExample
|
||||
from apps.parks.models.reviews import ParkReview
|
||||
|
||||
from apps.api.v1.serializers.reviews import ReviewUserSerializer
|
||||
from apps.parks.models.reviews import ParkReview
|
||||
|
||||
|
||||
@extend_schema_serializer(
|
||||
examples=[
|
||||
|
||||
@@ -5,18 +5,18 @@ This module contains all serializers related to parks, park areas, park location
|
||||
and park search functionality.
|
||||
"""
|
||||
|
||||
from rest_framework import serializers
|
||||
from drf_spectacular.utils import (
|
||||
extend_schema_serializer,
|
||||
extend_schema_field,
|
||||
OpenApiExample,
|
||||
extend_schema_field,
|
||||
extend_schema_serializer,
|
||||
)
|
||||
from rest_framework import serializers
|
||||
|
||||
from apps.core.choices.serializers import RichChoiceFieldSerializer
|
||||
from apps.core.services.media_url_service import MediaURLService
|
||||
from config.django import base as settings
|
||||
|
||||
from .shared import LocationOutputSerializer, CompanyOutputSerializer, ModelChoices
|
||||
from apps.core.services.media_url_service import MediaURLService
|
||||
from apps.core.choices.serializers import RichChoiceFieldSerializer
|
||||
|
||||
from .shared import CompanyOutputSerializer, LocationOutputSerializer, ModelChoices
|
||||
|
||||
# === PARK SERIALIZERS ===
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ This module contains serializers for park-specific media functionality.
|
||||
"""
|
||||
|
||||
from rest_framework import serializers
|
||||
|
||||
from apps.parks.models import ParkPhoto
|
||||
|
||||
|
||||
|
||||
@@ -3,9 +3,10 @@ Serializers for review-related API endpoints.
|
||||
"""
|
||||
|
||||
from rest_framework import serializers
|
||||
|
||||
from apps.accounts.models import User
|
||||
from apps.parks.models.reviews import ParkReview
|
||||
from apps.rides.models.reviews import RideReview
|
||||
from apps.accounts.models import User
|
||||
|
||||
|
||||
class ReviewUserSerializer(serializers.ModelSerializer):
|
||||
|
||||
@@ -1,17 +1,18 @@
|
||||
from rest_framework import serializers
|
||||
from drf_spectacular.utils import extend_schema_field
|
||||
from apps.rides.models.credits import RideCredit
|
||||
from apps.rides.models import Ride
|
||||
|
||||
from apps.api.v1.serializers.rides import RideListOutputSerializer
|
||||
from apps.rides.models import Ride
|
||||
from apps.rides.models.credits import RideCredit
|
||||
|
||||
|
||||
class RideCreditSerializer(serializers.ModelSerializer):
|
||||
"""Serializer for user ride credits."""
|
||||
|
||||
|
||||
ride_id = serializers.PrimaryKeyRelatedField(
|
||||
queryset=Ride.objects.all(), source='ride', write_only=True
|
||||
)
|
||||
ride = RideListOutputSerializer(read_only=True)
|
||||
|
||||
|
||||
class Meta:
|
||||
model = RideCredit
|
||||
fields = [
|
||||
@@ -23,6 +24,7 @@ class RideCreditSerializer(serializers.ModelSerializer):
|
||||
'first_ridden_at',
|
||||
'last_ridden_at',
|
||||
'notes',
|
||||
'display_order',
|
||||
'created_at',
|
||||
'updated_at',
|
||||
]
|
||||
@@ -37,7 +39,7 @@ class RideCreditSerializer(serializers.ModelSerializer):
|
||||
last = attrs.get('last_ridden_at')
|
||||
if first and last and last < first:
|
||||
raise serializers.ValidationError("Last ridden date cannot be before first ridden date.")
|
||||
|
||||
|
||||
return attrs
|
||||
|
||||
def create(self, validated_data):
|
||||
|
||||
@@ -5,16 +5,17 @@ This module contains all serializers related to ride models, variants,
|
||||
technical specifications, and related functionality.
|
||||
"""
|
||||
|
||||
from rest_framework import serializers
|
||||
from drf_spectacular.utils import (
|
||||
extend_schema_serializer,
|
||||
extend_schema_field,
|
||||
OpenApiExample,
|
||||
extend_schema_field,
|
||||
extend_schema_serializer,
|
||||
)
|
||||
from rest_framework import serializers
|
||||
|
||||
from apps.core.choices.serializers import RichChoiceFieldSerializer
|
||||
from config.django import base as settings
|
||||
|
||||
from .shared import ModelChoices
|
||||
from apps.core.choices.serializers import RichChoiceFieldSerializer
|
||||
|
||||
# Use dynamic imports to avoid circular import issues
|
||||
|
||||
@@ -23,9 +24,9 @@ def get_ride_model_classes():
|
||||
"""Get ride model classes dynamically to avoid import issues."""
|
||||
from apps.rides.models import (
|
||||
RideModel,
|
||||
RideModelVariant,
|
||||
RideModelPhoto,
|
||||
RideModelTechnicalSpec,
|
||||
RideModelVariant,
|
||||
)
|
||||
|
||||
return RideModel, RideModelVariant, RideModelPhoto, RideModelTechnicalSpec
|
||||
|
||||
@@ -4,11 +4,11 @@ Serializers for ride review API endpoints.
|
||||
This module contains serializers for ride review CRUD operations with Rich Choice Objects compliance.
|
||||
"""
|
||||
|
||||
from drf_spectacular.utils import OpenApiExample, extend_schema_field, extend_schema_serializer
|
||||
from rest_framework import serializers
|
||||
from drf_spectacular.utils import extend_schema_field, extend_schema_serializer, OpenApiExample
|
||||
from apps.rides.models.reviews import RideReview
|
||||
|
||||
from apps.accounts.models import User
|
||||
from apps.core.choices.serializers import RichChoiceSerializer
|
||||
from apps.rides.models.reviews import RideReview
|
||||
|
||||
|
||||
class ReviewUserSerializer(serializers.ModelSerializer):
|
||||
@@ -74,7 +74,7 @@ class RideReviewOutputSerializer(serializers.ModelSerializer):
|
||||
"""Output serializer for ride reviews."""
|
||||
|
||||
user = ReviewUserSerializer(read_only=True)
|
||||
|
||||
|
||||
# Ride information
|
||||
ride = serializers.SerializerMethodField()
|
||||
park = serializers.SerializerMethodField()
|
||||
|
||||
@@ -5,16 +5,17 @@ This module contains all serializers related to rides, roller coaster statistics
|
||||
ride locations, and ride reviews.
|
||||
"""
|
||||
|
||||
from rest_framework import serializers
|
||||
from drf_spectacular.utils import (
|
||||
extend_schema_serializer,
|
||||
extend_schema_field,
|
||||
OpenApiExample,
|
||||
extend_schema_field,
|
||||
extend_schema_serializer,
|
||||
)
|
||||
from config.django import base as settings
|
||||
from .shared import ModelChoices
|
||||
from apps.core.choices.serializers import RichChoiceFieldSerializer
|
||||
from rest_framework import serializers
|
||||
|
||||
from apps.core.choices.serializers import RichChoiceFieldSerializer
|
||||
from config.django import base as settings
|
||||
|
||||
from .shared import ModelChoices
|
||||
|
||||
# === RIDE SERIALIZERS ===
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ This module contains serializers for ride-specific media functionality.
|
||||
"""
|
||||
|
||||
from rest_framework import serializers
|
||||
|
||||
from apps.rides.models import RidePhoto
|
||||
|
||||
|
||||
|
||||
@@ -6,9 +6,10 @@ and other search functionality.
|
||||
"""
|
||||
|
||||
from rest_framework import serializers
|
||||
from ..shared import ModelChoices
|
||||
|
||||
from apps.core.choices.serializers import RichChoiceFieldSerializer
|
||||
|
||||
from ..shared import ModelChoices
|
||||
|
||||
# === CORE ENTITY SEARCH SERIALIZERS ===
|
||||
|
||||
|
||||
@@ -5,11 +5,10 @@ This module contains serializers for various services like email, maps,
|
||||
history tracking, moderation, and roadtrip planning.
|
||||
"""
|
||||
|
||||
from rest_framework import serializers
|
||||
from drf_spectacular.utils import (
|
||||
extend_schema_field,
|
||||
)
|
||||
|
||||
from rest_framework import serializers
|
||||
|
||||
# === HEALTH CHECK SERIALIZERS ===
|
||||
|
||||
|
||||
@@ -8,14 +8,15 @@ These serializers prevent contract violations by providing a single source of tr
|
||||
for common data structures used throughout the API.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from rest_framework import serializers
|
||||
from typing import Dict, Any, List
|
||||
|
||||
|
||||
class FilterOptionSerializer(serializers.Serializer):
|
||||
"""
|
||||
Standard filter option format - matches frontend TypeScript exactly.
|
||||
|
||||
|
||||
Frontend TypeScript interface:
|
||||
interface FilterOption {
|
||||
value: string;
|
||||
@@ -31,7 +32,7 @@ class FilterOptionSerializer(serializers.Serializer):
|
||||
help_text="Human-readable display label"
|
||||
)
|
||||
count = serializers.IntegerField(
|
||||
required=False,
|
||||
required=False,
|
||||
allow_null=True,
|
||||
help_text="Number of items matching this filter option"
|
||||
)
|
||||
@@ -44,7 +45,7 @@ class FilterOptionSerializer(serializers.Serializer):
|
||||
class FilterRangeSerializer(serializers.Serializer):
|
||||
"""
|
||||
Standard range filter format.
|
||||
|
||||
|
||||
Frontend TypeScript interface:
|
||||
interface FilterRange {
|
||||
min: number;
|
||||
@@ -66,7 +67,7 @@ class FilterRangeSerializer(serializers.Serializer):
|
||||
help_text="Step size for range inputs"
|
||||
)
|
||||
unit = serializers.CharField(
|
||||
required=False,
|
||||
required=False,
|
||||
allow_null=True,
|
||||
help_text="Unit of measurement (e.g., 'feet', 'mph', 'stars')"
|
||||
)
|
||||
@@ -75,7 +76,7 @@ class FilterRangeSerializer(serializers.Serializer):
|
||||
class BooleanFilterSerializer(serializers.Serializer):
|
||||
"""
|
||||
Standard boolean filter format.
|
||||
|
||||
|
||||
Frontend TypeScript interface:
|
||||
interface BooleanFilter {
|
||||
key: string;
|
||||
@@ -97,7 +98,7 @@ class BooleanFilterSerializer(serializers.Serializer):
|
||||
class OrderingOptionSerializer(serializers.Serializer):
|
||||
"""
|
||||
Standard ordering option format.
|
||||
|
||||
|
||||
Frontend TypeScript interface:
|
||||
interface OrderingOption {
|
||||
value: string;
|
||||
@@ -115,7 +116,7 @@ class OrderingOptionSerializer(serializers.Serializer):
|
||||
class StandardizedFilterMetadataSerializer(serializers.Serializer):
|
||||
"""
|
||||
Matches frontend TypeScript interface exactly.
|
||||
|
||||
|
||||
This serializer ensures all filter metadata responses follow the same structure
|
||||
that the frontend expects, preventing runtime type errors.
|
||||
"""
|
||||
@@ -131,7 +132,7 @@ class StandardizedFilterMetadataSerializer(serializers.Serializer):
|
||||
help_text="Total number of items in the filtered dataset"
|
||||
)
|
||||
ordering_options = FilterOptionSerializer(
|
||||
many=True,
|
||||
many=True,
|
||||
required=False,
|
||||
help_text="Available ordering options"
|
||||
)
|
||||
@@ -145,7 +146,7 @@ class StandardizedFilterMetadataSerializer(serializers.Serializer):
|
||||
class PaginationMetadataSerializer(serializers.Serializer):
|
||||
"""
|
||||
Standard pagination metadata format.
|
||||
|
||||
|
||||
Frontend TypeScript interface:
|
||||
interface PaginationMetadata {
|
||||
count: number;
|
||||
@@ -183,7 +184,7 @@ class PaginationMetadataSerializer(serializers.Serializer):
|
||||
class ApiResponseSerializer(serializers.Serializer):
|
||||
"""
|
||||
Standard API response wrapper.
|
||||
|
||||
|
||||
Frontend TypeScript interface:
|
||||
interface ApiResponse<T> {
|
||||
success: boolean;
|
||||
@@ -214,7 +215,7 @@ class ApiResponseSerializer(serializers.Serializer):
|
||||
class ErrorResponseSerializer(serializers.Serializer):
|
||||
"""
|
||||
Standard error response format.
|
||||
|
||||
|
||||
Frontend TypeScript interface:
|
||||
interface ApiError {
|
||||
status: "error";
|
||||
@@ -245,7 +246,7 @@ class ErrorResponseSerializer(serializers.Serializer):
|
||||
class LocationSerializer(serializers.Serializer):
|
||||
"""
|
||||
Standard location format.
|
||||
|
||||
|
||||
Frontend TypeScript interface:
|
||||
interface Location {
|
||||
city: string;
|
||||
@@ -291,7 +292,7 @@ LocationOutputSerializer = LocationSerializer
|
||||
class CompanyOutputSerializer(serializers.Serializer):
|
||||
"""
|
||||
Standard company output format.
|
||||
|
||||
|
||||
Frontend TypeScript interface:
|
||||
interface Company {
|
||||
id: number;
|
||||
@@ -322,24 +323,24 @@ class ModelChoices:
|
||||
"""
|
||||
Utility class to provide model choices for serializers using Rich Choice Objects.
|
||||
This prevents circular imports while providing access to model choices from the registry.
|
||||
|
||||
|
||||
NO FALLBACKS - All choices must be properly defined in Rich Choice Objects.
|
||||
"""
|
||||
|
||||
|
||||
@staticmethod
|
||||
def get_park_status_choices():
|
||||
"""Get park status choices from Rich Choice registry."""
|
||||
from apps.core.choices.registry import get_choices
|
||||
choices = get_choices("statuses", "parks")
|
||||
return [(choice.value, choice.label) for choice in choices]
|
||||
|
||||
|
||||
@staticmethod
|
||||
def get_ride_status_choices():
|
||||
"""Get ride status choices from Rich Choice registry."""
|
||||
from apps.core.choices.registry import get_choices
|
||||
choices = get_choices("statuses", "rides")
|
||||
return [(choice.value, choice.label) for choice in choices]
|
||||
|
||||
|
||||
@staticmethod
|
||||
def get_company_role_choices():
|
||||
"""Get company role choices from Rich Choice registry."""
|
||||
@@ -350,91 +351,91 @@ class ModelChoices:
|
||||
parks_choices = get_choices("company_roles", "parks")
|
||||
all_choices = list(rides_choices) + list(parks_choices)
|
||||
return [(choice.value, choice.label) for choice in all_choices]
|
||||
|
||||
|
||||
@staticmethod
|
||||
def get_ride_category_choices():
|
||||
"""Get ride category choices from Rich Choice registry."""
|
||||
from apps.core.choices.registry import get_choices
|
||||
choices = get_choices("categories", "rides")
|
||||
return [(choice.value, choice.label) for choice in choices]
|
||||
|
||||
|
||||
@staticmethod
|
||||
def get_ride_post_closing_choices():
|
||||
"""Get ride post-closing status choices from Rich Choice registry."""
|
||||
from apps.core.choices.registry import get_choices
|
||||
choices = get_choices("post_closing_statuses", "rides")
|
||||
return [(choice.value, choice.label) for choice in choices]
|
||||
|
||||
|
||||
@staticmethod
|
||||
def get_coaster_track_choices():
|
||||
"""Get coaster track material choices from Rich Choice registry."""
|
||||
from apps.core.choices.registry import get_choices
|
||||
choices = get_choices("track_materials", "rides")
|
||||
return [(choice.value, choice.label) for choice in choices]
|
||||
|
||||
|
||||
@staticmethod
|
||||
def get_coaster_type_choices():
|
||||
"""Get coaster type choices from Rich Choice registry."""
|
||||
from apps.core.choices.registry import get_choices
|
||||
choices = get_choices("coaster_types", "rides")
|
||||
return [(choice.value, choice.label) for choice in choices]
|
||||
|
||||
|
||||
@staticmethod
|
||||
def get_launch_choices():
|
||||
"""Get launch system choices from Rich Choice registry (legacy method)."""
|
||||
from apps.core.choices.registry import get_choices
|
||||
choices = get_choices("propulsion_systems", "rides")
|
||||
return [(choice.value, choice.label) for choice in choices]
|
||||
|
||||
|
||||
@staticmethod
|
||||
def get_propulsion_system_choices():
|
||||
"""Get propulsion system choices from Rich Choice registry."""
|
||||
from apps.core.choices.registry import get_choices
|
||||
choices = get_choices("propulsion_systems", "rides")
|
||||
return [(choice.value, choice.label) for choice in choices]
|
||||
|
||||
|
||||
@staticmethod
|
||||
def get_photo_type_choices():
|
||||
"""Get photo type choices from Rich Choice registry."""
|
||||
from apps.core.choices.registry import get_choices
|
||||
choices = get_choices("photo_types", "rides")
|
||||
return [(choice.value, choice.label) for choice in choices]
|
||||
|
||||
|
||||
@staticmethod
|
||||
def get_spec_category_choices():
|
||||
"""Get technical specification category choices from Rich Choice registry."""
|
||||
from apps.core.choices.registry import get_choices
|
||||
choices = get_choices("spec_categories", "rides")
|
||||
return [(choice.value, choice.label) for choice in choices]
|
||||
|
||||
|
||||
@staticmethod
|
||||
def get_technical_spec_category_choices():
|
||||
"""Get technical specification category choices from Rich Choice registry."""
|
||||
from apps.core.choices.registry import get_choices
|
||||
choices = get_choices("spec_categories", "rides")
|
||||
return [(choice.value, choice.label) for choice in choices]
|
||||
|
||||
|
||||
@staticmethod
|
||||
def get_target_market_choices():
|
||||
"""Get target market choices from Rich Choice registry."""
|
||||
from apps.core.choices.registry import get_choices
|
||||
choices = get_choices("target_markets", "rides")
|
||||
return [(choice.value, choice.label) for choice in choices]
|
||||
|
||||
|
||||
@staticmethod
|
||||
def get_entity_type_choices():
|
||||
"""Get entity type choices for search functionality."""
|
||||
from apps.core.choices.registry import get_choices
|
||||
choices = get_choices("entity_types", "core")
|
||||
return [(choice.value, choice.label) for choice in choices]
|
||||
|
||||
|
||||
@staticmethod
|
||||
def get_health_status_choices():
|
||||
"""Get health check status choices from Rich Choice registry."""
|
||||
from apps.core.choices.registry import get_choices
|
||||
choices = get_choices("health_statuses", "core")
|
||||
return [(choice.value, choice.label) for choice in choices]
|
||||
|
||||
|
||||
@staticmethod
|
||||
def get_simple_health_status_choices():
|
||||
"""Get simple health check status choices from Rich Choice registry."""
|
||||
@@ -446,7 +447,7 @@ class ModelChoices:
|
||||
class EntityReferenceSerializer(serializers.Serializer):
|
||||
"""
|
||||
Standard entity reference format.
|
||||
|
||||
|
||||
Frontend TypeScript interface:
|
||||
interface Entity {
|
||||
id: number;
|
||||
@@ -468,7 +469,7 @@ class EntityReferenceSerializer(serializers.Serializer):
|
||||
class ImageVariantsSerializer(serializers.Serializer):
|
||||
"""
|
||||
Standard image variants format.
|
||||
|
||||
|
||||
Frontend TypeScript interface:
|
||||
interface ImageVariants {
|
||||
thumbnail: string;
|
||||
@@ -495,7 +496,7 @@ class ImageVariantsSerializer(serializers.Serializer):
|
||||
class PhotoSerializer(serializers.Serializer):
|
||||
"""
|
||||
Standard photo format.
|
||||
|
||||
|
||||
Frontend TypeScript interface:
|
||||
interface Photo {
|
||||
id: number;
|
||||
@@ -546,7 +547,7 @@ class PhotoSerializer(serializers.Serializer):
|
||||
class UserInfoSerializer(serializers.Serializer):
|
||||
"""
|
||||
Standard user info format.
|
||||
|
||||
|
||||
Frontend TypeScript interface:
|
||||
interface UserInfo {
|
||||
id: number;
|
||||
@@ -571,19 +572,19 @@ class UserInfoSerializer(serializers.Serializer):
|
||||
)
|
||||
|
||||
|
||||
def validate_filter_metadata_contract(data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
def validate_filter_metadata_contract(data: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Validate that filter metadata follows the expected contract.
|
||||
|
||||
|
||||
This function can be used in views to ensure filter metadata
|
||||
matches the frontend TypeScript interface before returning it.
|
||||
|
||||
|
||||
Args:
|
||||
data: Filter metadata dictionary
|
||||
|
||||
|
||||
Returns:
|
||||
Validated and potentially transformed data
|
||||
|
||||
|
||||
Raises:
|
||||
serializers.ValidationError: If data doesn't match contract
|
||||
"""
|
||||
@@ -593,21 +594,21 @@ def validate_filter_metadata_contract(data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
return serializer.validated_data
|
||||
|
||||
|
||||
def ensure_filter_option_format(options: List[Any]) -> List[Dict[str, Any]]:
|
||||
def ensure_filter_option_format(options: list[Any]) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Ensure a list of filter options follows the expected format.
|
||||
|
||||
|
||||
This utility function converts various input formats to the standard
|
||||
FilterOption format expected by the frontend.
|
||||
|
||||
|
||||
Args:
|
||||
options: List of options in various formats
|
||||
|
||||
|
||||
Returns:
|
||||
List of options in standard format
|
||||
"""
|
||||
standardized = []
|
||||
|
||||
|
||||
for option in options:
|
||||
if isinstance(option, dict):
|
||||
# Already in correct format or close to it
|
||||
@@ -633,19 +634,19 @@ def ensure_filter_option_format(options: List[Any]) -> List[Dict[str, Any]]:
|
||||
'count': None,
|
||||
'selected': False
|
||||
}
|
||||
|
||||
|
||||
standardized.append(standardized_option)
|
||||
|
||||
|
||||
return standardized
|
||||
|
||||
|
||||
def ensure_range_format(range_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
def ensure_range_format(range_data: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Ensure range data follows the expected format.
|
||||
|
||||
|
||||
Args:
|
||||
range_data: Range data dictionary
|
||||
|
||||
|
||||
Returns:
|
||||
Range data in standard format
|
||||
"""
|
||||
|
||||
@@ -2,13 +2,14 @@
|
||||
API serializers for the ride ranking system.
|
||||
"""
|
||||
|
||||
from rest_framework import serializers
|
||||
from drf_spectacular.utils import (
|
||||
extend_schema_serializer,
|
||||
extend_schema_field,
|
||||
OpenApiExample,
|
||||
extend_schema_field,
|
||||
extend_schema_serializer,
|
||||
)
|
||||
from apps.rides.models import RideRanking, RankingSnapshot
|
||||
from rest_framework import serializers
|
||||
|
||||
from apps.rides.models import RankingSnapshot, RideRanking
|
||||
|
||||
|
||||
@extend_schema_serializer(
|
||||
@@ -179,6 +180,7 @@ class RideRankingDetailSerializer(serializers.ModelSerializer):
|
||||
def get_head_to_head_comparisons(self, obj):
|
||||
"""Get top head-to-head comparisons."""
|
||||
from django.db.models import Q
|
||||
|
||||
from apps.rides.models import RidePairComparison
|
||||
|
||||
comparisons = (
|
||||
|
||||
@@ -5,17 +5,20 @@ This module contains signal handlers that invalidate the stats cache
|
||||
whenever relevant entities are created, updated, or deleted.
|
||||
"""
|
||||
|
||||
from django.db.models.signals import post_save, post_delete
|
||||
from django.dispatch import receiver
|
||||
from django.core.cache import cache
|
||||
from django.db.models.signals import post_delete, post_save
|
||||
from django.dispatch import receiver
|
||||
|
||||
from apps.parks.models import Park, ParkReview, ParkPhoto, Company as ParkCompany
|
||||
from apps.parks.models import Company as ParkCompany
|
||||
from apps.parks.models import Park, ParkPhoto, ParkReview
|
||||
from apps.rides.models import (
|
||||
Company as RideCompany,
|
||||
)
|
||||
from apps.rides.models import (
|
||||
Ride,
|
||||
RollerCoasterStats,
|
||||
RideReview,
|
||||
RidePhoto,
|
||||
Company as RideCompany,
|
||||
RideReview,
|
||||
RollerCoasterStats,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -5,120 +5,120 @@ These tests verify that API responses match frontend TypeScript interfaces exact
|
||||
preventing runtime errors and ensuring type safety.
|
||||
"""
|
||||
|
||||
from django.test import TestCase, Client
|
||||
from django.test import Client, TestCase
|
||||
from rest_framework.test import APITestCase
|
||||
|
||||
from apps.api.v1.serializers.shared import (
|
||||
ensure_filter_option_format,
|
||||
ensure_range_format,
|
||||
validate_filter_metadata_contract,
|
||||
)
|
||||
from apps.parks.services.hybrid_loader import smart_park_loader
|
||||
from apps.rides.services.hybrid_loader import SmartRideLoader
|
||||
from apps.api.v1.serializers.shared import (
|
||||
validate_filter_metadata_contract,
|
||||
ensure_filter_option_format,
|
||||
ensure_range_format
|
||||
)
|
||||
|
||||
|
||||
class FilterMetadataContractTests(TestCase):
|
||||
"""Test that filter metadata follows the expected contract."""
|
||||
|
||||
|
||||
def setUp(self):
|
||||
self.client = Client()
|
||||
|
||||
|
||||
def test_parks_filter_metadata_structure(self):
|
||||
"""Test that parks filter metadata has correct structure."""
|
||||
# Get filter metadata from the service
|
||||
metadata = smart_park_loader.get_filter_metadata()
|
||||
|
||||
|
||||
# Should have required top-level keys
|
||||
self.assertIn('categorical', metadata)
|
||||
self.assertIn('ranges', metadata)
|
||||
self.assertIn('total_count', metadata)
|
||||
|
||||
|
||||
# Categorical filters should be objects with value/label/count
|
||||
categorical = metadata['categorical']
|
||||
self.assertIsInstance(categorical, dict)
|
||||
|
||||
|
||||
for filter_name, filter_options in categorical.items():
|
||||
with self.subTest(filter_name=filter_name):
|
||||
self.assertIsInstance(filter_options, list,
|
||||
self.assertIsInstance(filter_options, list,
|
||||
f"Filter '{filter_name}' should be a list")
|
||||
|
||||
|
||||
for i, option in enumerate(filter_options):
|
||||
with self.subTest(filter_name=filter_name, option_index=i):
|
||||
self.assertIsInstance(option, dict,
|
||||
f"Filter '{filter_name}' option {i} should be an object, not {type(option).__name__}")
|
||||
|
||||
|
||||
# Check required properties
|
||||
self.assertIn('value', option,
|
||||
f"Filter '{filter_name}' option {i} missing 'value' property")
|
||||
self.assertIn('label', option,
|
||||
f"Filter '{filter_name}' option {i} missing 'label' property")
|
||||
|
||||
|
||||
# Check types
|
||||
self.assertIsInstance(option['value'], str,
|
||||
f"Filter '{filter_name}' option {i} 'value' should be string")
|
||||
self.assertIsInstance(option['label'], str,
|
||||
f"Filter '{filter_name}' option {i} 'label' should be string")
|
||||
|
||||
|
||||
# Count is optional but should be int if present
|
||||
if 'count' in option and option['count'] is not None:
|
||||
self.assertIsInstance(option['count'], int,
|
||||
f"Filter '{filter_name}' option {i} 'count' should be int")
|
||||
|
||||
|
||||
def test_rides_filter_metadata_structure(self):
|
||||
"""Test that rides filter metadata has correct structure."""
|
||||
loader = SmartRideLoader()
|
||||
metadata = loader.get_filter_metadata()
|
||||
|
||||
|
||||
# Should have required top-level keys
|
||||
self.assertIn('categorical', metadata)
|
||||
self.assertIn('ranges', metadata)
|
||||
self.assertIn('total_count', metadata)
|
||||
|
||||
|
||||
# Categorical filters should be objects with value/label/count
|
||||
categorical = metadata['categorical']
|
||||
self.assertIsInstance(categorical, dict)
|
||||
|
||||
|
||||
# Test specific categorical filters that were problematic
|
||||
critical_filters = ['categories', 'statuses', 'roller_coaster_types', 'track_materials']
|
||||
|
||||
|
||||
for filter_name in critical_filters:
|
||||
if filter_name in categorical:
|
||||
with self.subTest(filter_name=filter_name):
|
||||
filter_options = categorical[filter_name]
|
||||
self.assertIsInstance(filter_options, list)
|
||||
|
||||
|
||||
for i, option in enumerate(filter_options):
|
||||
with self.subTest(filter_name=filter_name, option_index=i):
|
||||
self.assertIsInstance(option, dict,
|
||||
f"CRITICAL: Filter '{filter_name}' option {i} is {type(option).__name__} but should be dict")
|
||||
|
||||
|
||||
self.assertIn('value', option)
|
||||
self.assertIn('label', option)
|
||||
self.assertIn('count', option)
|
||||
|
||||
|
||||
def test_range_metadata_structure(self):
|
||||
"""Test that range metadata has correct structure."""
|
||||
# Test parks ranges
|
||||
parks_metadata = smart_park_loader.get_filter_metadata()
|
||||
ranges = parks_metadata['ranges']
|
||||
|
||||
|
||||
for range_name, range_data in ranges.items():
|
||||
with self.subTest(range_name=range_name):
|
||||
self.assertIsInstance(range_data, dict,
|
||||
f"Range '{range_name}' should be an object")
|
||||
|
||||
|
||||
# Check required properties
|
||||
self.assertIn('min', range_data)
|
||||
self.assertIn('max', range_data)
|
||||
self.assertIn('step', range_data)
|
||||
self.assertIn('unit', range_data)
|
||||
|
||||
|
||||
# Check types (min/max can be None)
|
||||
if range_data['min'] is not None:
|
||||
self.assertIsInstance(range_data['min'], (int, float))
|
||||
if range_data['max'] is not None:
|
||||
self.assertIsInstance(range_data['max'], (int, float))
|
||||
|
||||
|
||||
self.assertIsInstance(range_data['step'], (int, float))
|
||||
# Unit can be None or string
|
||||
if range_data['unit'] is not None:
|
||||
@@ -127,7 +127,7 @@ class FilterMetadataContractTests(TestCase):
|
||||
|
||||
class ContractValidationUtilityTests(TestCase):
|
||||
"""Test contract validation utility functions."""
|
||||
|
||||
|
||||
def test_validate_filter_metadata_contract_valid(self):
|
||||
"""Test validation passes for valid filter metadata."""
|
||||
valid_metadata = {
|
||||
@@ -147,16 +147,16 @@ class ContractValidationUtilityTests(TestCase):
|
||||
},
|
||||
'total_count': 100
|
||||
}
|
||||
|
||||
|
||||
# Should not raise an exception
|
||||
validated = validate_filter_metadata_contract(valid_metadata)
|
||||
self.assertIsInstance(validated, dict)
|
||||
self.assertEqual(validated['total_count'], 100)
|
||||
|
||||
|
||||
def test_validate_filter_metadata_contract_invalid(self):
|
||||
"""Test validation fails for invalid filter metadata."""
|
||||
from rest_framework import serializers
|
||||
|
||||
|
||||
invalid_metadata = {
|
||||
'categorical': {
|
||||
'statuses': ['OPERATING', 'CLOSED_TEMP'] # Should be objects, not strings
|
||||
@@ -164,17 +164,17 @@ class ContractValidationUtilityTests(TestCase):
|
||||
'ranges': {},
|
||||
'total_count': 100
|
||||
}
|
||||
|
||||
|
||||
# Should raise ValidationError
|
||||
with self.assertRaises(serializers.ValidationError):
|
||||
validate_filter_metadata_contract(invalid_metadata)
|
||||
|
||||
|
||||
def test_ensure_filter_option_format_strings(self):
|
||||
"""Test converting string arrays to proper format."""
|
||||
string_options = ['OPERATING', 'CLOSED_TEMP', 'UNDER_CONSTRUCTION']
|
||||
|
||||
|
||||
formatted = ensure_filter_option_format(string_options)
|
||||
|
||||
|
||||
self.assertEqual(len(formatted), 3)
|
||||
for i, option in enumerate(formatted):
|
||||
self.assertIsInstance(option, dict)
|
||||
@@ -182,44 +182,44 @@ class ContractValidationUtilityTests(TestCase):
|
||||
self.assertIn('label', option)
|
||||
self.assertIn('count', option)
|
||||
self.assertIn('selected', option)
|
||||
|
||||
|
||||
self.assertEqual(option['value'], string_options[i])
|
||||
self.assertEqual(option['label'], string_options[i])
|
||||
self.assertIsNone(option['count'])
|
||||
self.assertFalse(option['selected'])
|
||||
|
||||
|
||||
def test_ensure_filter_option_format_tuples(self):
|
||||
"""Test converting tuple arrays to proper format."""
|
||||
tuple_options = [
|
||||
('OPERATING', 'Operating', 5),
|
||||
('CLOSED_TEMP', 'Temporarily Closed', 2)
|
||||
]
|
||||
|
||||
|
||||
formatted = ensure_filter_option_format(tuple_options)
|
||||
|
||||
|
||||
self.assertEqual(len(formatted), 2)
|
||||
self.assertEqual(formatted[0]['value'], 'OPERATING')
|
||||
self.assertEqual(formatted[0]['label'], 'Operating')
|
||||
self.assertEqual(formatted[0]['count'], 5)
|
||||
|
||||
|
||||
self.assertEqual(formatted[1]['value'], 'CLOSED_TEMP')
|
||||
self.assertEqual(formatted[1]['label'], 'Temporarily Closed')
|
||||
self.assertEqual(formatted[1]['count'], 2)
|
||||
|
||||
|
||||
def test_ensure_filter_option_format_dicts(self):
|
||||
"""Test that properly formatted dicts pass through correctly."""
|
||||
dict_options = [
|
||||
{'value': 'OPERATING', 'label': 'Operating', 'count': 5},
|
||||
{'value': 'CLOSED_TEMP', 'label': 'Temporarily Closed', 'count': 2}
|
||||
]
|
||||
|
||||
|
||||
formatted = ensure_filter_option_format(dict_options)
|
||||
|
||||
|
||||
self.assertEqual(len(formatted), 2)
|
||||
self.assertEqual(formatted[0]['value'], 'OPERATING')
|
||||
self.assertEqual(formatted[0]['label'], 'Operating')
|
||||
self.assertEqual(formatted[0]['count'], 5)
|
||||
|
||||
|
||||
def test_ensure_range_format(self):
|
||||
"""Test range format utility."""
|
||||
range_data = {
|
||||
@@ -228,36 +228,36 @@ class ContractValidationUtilityTests(TestCase):
|
||||
'step': 0.5,
|
||||
'unit': 'stars'
|
||||
}
|
||||
|
||||
|
||||
formatted = ensure_range_format(range_data)
|
||||
|
||||
|
||||
self.assertEqual(formatted['min'], 1.0)
|
||||
self.assertEqual(formatted['max'], 10.0)
|
||||
self.assertEqual(formatted['step'], 0.5)
|
||||
self.assertEqual(formatted['unit'], 'stars')
|
||||
|
||||
|
||||
def test_ensure_range_format_missing_step(self):
|
||||
"""Test range format with missing step defaults to 1.0."""
|
||||
range_data = {
|
||||
'min': 1,
|
||||
'max': 10
|
||||
}
|
||||
|
||||
|
||||
formatted = ensure_range_format(range_data)
|
||||
|
||||
|
||||
self.assertEqual(formatted['step'], 1.0)
|
||||
self.assertIsNone(formatted['unit'])
|
||||
|
||||
|
||||
class APIEndpointContractTests(APITestCase):
|
||||
"""Test actual API endpoints for contract compliance."""
|
||||
|
||||
|
||||
def test_parks_hybrid_endpoint_contract(self):
|
||||
"""Test parks hybrid endpoint returns proper contract."""
|
||||
# This would require actual data in the database
|
||||
# For now, we'll test the structure
|
||||
pass
|
||||
|
||||
|
||||
def test_rides_hybrid_endpoint_contract(self):
|
||||
"""Test rides hybrid endpoint returns proper contract."""
|
||||
# This would require actual data in the database
|
||||
@@ -267,7 +267,7 @@ class APIEndpointContractTests(APITestCase):
|
||||
|
||||
class TypeScriptInterfaceComplianceTests(TestCase):
|
||||
"""Test that responses match TypeScript interfaces exactly."""
|
||||
|
||||
|
||||
def test_filter_option_interface_compliance(self):
|
||||
"""Test FilterOption interface compliance."""
|
||||
# TypeScript interface:
|
||||
@@ -277,28 +277,28 @@ class TypeScriptInterfaceComplianceTests(TestCase):
|
||||
# count?: number;
|
||||
# selected?: boolean;
|
||||
# }
|
||||
|
||||
|
||||
option = {
|
||||
'value': 'OPERATING',
|
||||
'label': 'Operating',
|
||||
'count': 5,
|
||||
'selected': False
|
||||
}
|
||||
|
||||
|
||||
# All required fields present
|
||||
self.assertIn('value', option)
|
||||
self.assertIn('label', option)
|
||||
|
||||
|
||||
# Correct types
|
||||
self.assertIsInstance(option['value'], str)
|
||||
self.assertIsInstance(option['label'], str)
|
||||
|
||||
|
||||
# Optional fields have correct types if present
|
||||
if 'count' in option and option['count'] is not None:
|
||||
self.assertIsInstance(option['count'], int)
|
||||
if 'selected' in option:
|
||||
self.assertIsInstance(option['selected'], bool)
|
||||
|
||||
|
||||
def test_filter_range_interface_compliance(self):
|
||||
"""Test FilterRange interface compliance."""
|
||||
# TypeScript interface:
|
||||
@@ -308,27 +308,27 @@ class TypeScriptInterfaceComplianceTests(TestCase):
|
||||
# step: number;
|
||||
# unit?: string;
|
||||
# }
|
||||
|
||||
|
||||
range_data = {
|
||||
'min': 1.0,
|
||||
'max': 10.0,
|
||||
'step': 0.1,
|
||||
'unit': 'stars'
|
||||
}
|
||||
|
||||
|
||||
# All required fields present
|
||||
self.assertIn('min', range_data)
|
||||
self.assertIn('max', range_data)
|
||||
self.assertIn('step', range_data)
|
||||
|
||||
|
||||
# Correct types (min/max can be null)
|
||||
if range_data['min'] is not None:
|
||||
self.assertIsInstance(range_data['min'], (int, float))
|
||||
if range_data['max'] is not None:
|
||||
self.assertIsInstance(range_data['max'], (int, float))
|
||||
|
||||
|
||||
self.assertIsInstance(range_data['step'], (int, float))
|
||||
|
||||
|
||||
# Optional unit field
|
||||
if 'unit' in range_data and range_data['unit'] is not None:
|
||||
self.assertIsInstance(range_data['unit'], str)
|
||||
@@ -336,72 +336,72 @@ class TypeScriptInterfaceComplianceTests(TestCase):
|
||||
|
||||
class RegressionTests(TestCase):
|
||||
"""Regression tests for specific contract violations that were fixed."""
|
||||
|
||||
|
||||
def test_categorical_filters_not_strings(self):
|
||||
"""Regression test: Ensure categorical filters are never returned as strings."""
|
||||
# This was the main issue - categorical filters were returned as:
|
||||
# ['OPERATING', 'CLOSED_TEMP'] instead of
|
||||
# ['OPERATING', 'CLOSED_TEMP'] instead of
|
||||
# [{'value': 'OPERATING', 'label': 'Operating', 'count': 5}, ...]
|
||||
|
||||
|
||||
# Test parks
|
||||
parks_metadata = smart_park_loader.get_filter_metadata()
|
||||
categorical = parks_metadata.get('categorical', {})
|
||||
|
||||
|
||||
for filter_name, filter_options in categorical.items():
|
||||
with self.subTest(filter_name=filter_name):
|
||||
self.assertIsInstance(filter_options, list)
|
||||
|
||||
|
||||
for i, option in enumerate(filter_options):
|
||||
with self.subTest(filter_name=filter_name, option_index=i):
|
||||
self.assertIsInstance(option, dict,
|
||||
f"REGRESSION: Filter '{filter_name}' option {i} is a {type(option).__name__} "
|
||||
f"but should be a dict. This causes frontend crashes!")
|
||||
|
||||
|
||||
# Must not be a string
|
||||
self.assertNotIsInstance(option, str,
|
||||
f"CRITICAL REGRESSION: Filter '{filter_name}' option {i} is a string '{option}' "
|
||||
f"but frontend expects object with value/label/count properties!")
|
||||
|
||||
|
||||
# Test rides
|
||||
rides_loader = SmartRideLoader()
|
||||
rides_metadata = rides_loader.get_filter_metadata()
|
||||
categorical = rides_metadata.get('categorical', {})
|
||||
|
||||
|
||||
for filter_name, filter_options in categorical.items():
|
||||
with self.subTest(filter_name=f"rides_{filter_name}"):
|
||||
self.assertIsInstance(filter_options, list)
|
||||
|
||||
|
||||
for i, option in enumerate(filter_options):
|
||||
with self.subTest(filter_name=f"rides_{filter_name}", option_index=i):
|
||||
self.assertIsInstance(option, dict,
|
||||
f"REGRESSION: Rides filter '{filter_name}' option {i} is a {type(option).__name__} "
|
||||
f"but should be a dict. This causes frontend crashes!")
|
||||
|
||||
|
||||
def test_ranges_have_step_and_unit(self):
|
||||
"""Regression test: Ensure ranges have step and unit properties."""
|
||||
# Frontend expects: { min: number, max: number, step: number, unit?: string }
|
||||
# Backend was sometimes missing step and unit
|
||||
|
||||
|
||||
parks_metadata = smart_park_loader.get_filter_metadata()
|
||||
ranges = parks_metadata.get('ranges', {})
|
||||
|
||||
|
||||
for range_name, range_data in ranges.items():
|
||||
with self.subTest(range_name=range_name):
|
||||
self.assertIn('step', range_data,
|
||||
f"Range '{range_name}' missing 'step' property required by frontend")
|
||||
self.assertIn('unit', range_data,
|
||||
f"Range '{range_name}' missing 'unit' property required by frontend")
|
||||
|
||||
|
||||
# Step should be a number
|
||||
self.assertIsInstance(range_data['step'], (int, float),
|
||||
f"Range '{range_name}' step should be a number")
|
||||
|
||||
|
||||
def test_no_undefined_values(self):
|
||||
"""Regression test: Ensure no undefined values (should be null)."""
|
||||
# JavaScript undefined !== null, and TypeScript interfaces expect null
|
||||
|
||||
|
||||
parks_metadata = smart_park_loader.get_filter_metadata()
|
||||
|
||||
|
||||
def check_no_undefined(obj, path=""):
|
||||
if isinstance(obj, dict):
|
||||
for key, value in obj.items():
|
||||
@@ -413,6 +413,6 @@ class RegressionTests(TestCase):
|
||||
for i, item in enumerate(obj):
|
||||
current_path = f"{path}[{i}]"
|
||||
check_no_undefined(item, current_path)
|
||||
|
||||
|
||||
# This will recursively check the entire metadata structure
|
||||
check_no_undefined(parks_metadata)
|
||||
|
||||
@@ -5,23 +5,24 @@ This module provides unified API routing following RESTful conventions
|
||||
and DRF Router patterns for automatic URL generation.
|
||||
"""
|
||||
|
||||
from .viewsets_rankings import RideRankingViewSet, TriggerRankingCalculationView
|
||||
from django.urls import include, path
|
||||
from rest_framework.routers import DefaultRouter
|
||||
|
||||
# Import other views from the views directory
|
||||
from .views import (
|
||||
HealthCheckAPIView,
|
||||
NewContentAPIView,
|
||||
PerformanceMetricsAPIView,
|
||||
SimpleHealthAPIView,
|
||||
# Trending system views
|
||||
TrendingAPIView,
|
||||
NewContentAPIView,
|
||||
TriggerTrendingCalculationAPIView,
|
||||
)
|
||||
from .views.discovery import DiscoveryAPIView
|
||||
from .views.stats import StatsAPIView, StatsRecalculateAPIView
|
||||
from .views.reviews import LatestReviewsAPIView
|
||||
from .views.leaderboard import leaderboard
|
||||
from django.urls import path, include
|
||||
from rest_framework.routers import DefaultRouter
|
||||
from .views.reviews import LatestReviewsAPIView
|
||||
from .views.stats import StatsAPIView, StatsRecalculateAPIView
|
||||
from .viewsets_rankings import RideRankingViewSet, TriggerRankingCalculationView
|
||||
|
||||
# Create the main API router
|
||||
router = DefaultRouter()
|
||||
@@ -79,6 +80,7 @@ urlpatterns = [
|
||||
path("core/", include("apps.api.v1.core.urls")),
|
||||
path("maps/", include("apps.api.v1.maps.urls")),
|
||||
path("lists/", include("apps.lists.urls")),
|
||||
path("companies/", include("apps.api.v1.rides.company_urls")),
|
||||
path("moderation/", include("apps.moderation.urls")),
|
||||
path("reviews/", include("apps.reviews.urls")),
|
||||
path("media/", include("apps.media.urls")),
|
||||
|
||||
@@ -9,25 +9,23 @@ This package contains all API view classes organized by functionality:
|
||||
|
||||
# Import all view classes for easy access
|
||||
from .auth import (
|
||||
LoginAPIView,
|
||||
SignupAPIView,
|
||||
LogoutAPIView,
|
||||
CurrentUserAPIView,
|
||||
PasswordResetAPIView,
|
||||
PasswordChangeAPIView,
|
||||
SocialProvidersAPIView,
|
||||
AuthStatusAPIView,
|
||||
CurrentUserAPIView,
|
||||
LoginAPIView,
|
||||
LogoutAPIView,
|
||||
PasswordChangeAPIView,
|
||||
PasswordResetAPIView,
|
||||
SignupAPIView,
|
||||
SocialProvidersAPIView,
|
||||
)
|
||||
|
||||
from .health import (
|
||||
HealthCheckAPIView,
|
||||
PerformanceMetricsAPIView,
|
||||
SimpleHealthAPIView,
|
||||
)
|
||||
|
||||
from .trending import (
|
||||
TrendingAPIView,
|
||||
NewContentAPIView,
|
||||
TrendingAPIView,
|
||||
TriggerTrendingCalculationAPIView,
|
||||
)
|
||||
|
||||
|
||||
@@ -7,34 +7,34 @@ login, signup, logout, password management, and social authentication.
|
||||
|
||||
# type: ignore[misc,attr-defined,arg-type,call-arg,index,assignment]
|
||||
|
||||
from typing import TYPE_CHECKING, Type, Any
|
||||
from django.contrib.auth import login, logout, get_user_model
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from django.contrib.auth import get_user_model, login, logout
|
||||
from django.contrib.sites.shortcuts import get_current_site
|
||||
from django.core.exceptions import ValidationError
|
||||
from drf_spectacular.utils import extend_schema, extend_schema_view
|
||||
from rest_framework import status
|
||||
from rest_framework.views import APIView
|
||||
from rest_framework.permissions import AllowAny, IsAuthenticated
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.permissions import AllowAny, IsAuthenticated
|
||||
from drf_spectacular.utils import extend_schema, extend_schema_view
|
||||
from rest_framework.views import APIView
|
||||
|
||||
# Import serializers from the auth serializers module
|
||||
from ..serializers.auth import (
|
||||
AuthStatusOutputSerializer,
|
||||
LoginInputSerializer,
|
||||
LoginOutputSerializer,
|
||||
SignupInputSerializer,
|
||||
SignupOutputSerializer,
|
||||
LogoutOutputSerializer,
|
||||
UserOutputSerializer,
|
||||
PasswordResetInputSerializer,
|
||||
PasswordResetOutputSerializer,
|
||||
PasswordChangeInputSerializer,
|
||||
PasswordChangeOutputSerializer,
|
||||
PasswordResetInputSerializer,
|
||||
PasswordResetOutputSerializer,
|
||||
SignupInputSerializer,
|
||||
SignupOutputSerializer,
|
||||
SocialProviderOutputSerializer,
|
||||
AuthStatusOutputSerializer,
|
||||
UserOutputSerializer,
|
||||
)
|
||||
|
||||
|
||||
# Handle optional dependencies with fallback classes
|
||||
|
||||
|
||||
@@ -56,7 +56,7 @@ except ImportError:
|
||||
if TYPE_CHECKING:
|
||||
from typing import Union
|
||||
|
||||
TurnstileMixinType = Union[Type[FallbackTurnstileMixin], Any]
|
||||
TurnstileMixinType = Union[type[FallbackTurnstileMixin], Any]
|
||||
else:
|
||||
TurnstileMixinType = TurnstileMixin
|
||||
|
||||
|
||||
@@ -6,16 +6,15 @@ consistent formats that match frontend TypeScript interfaces exactly.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict, Any, Optional, Type
|
||||
from rest_framework.views import APIView
|
||||
from rest_framework.response import Response
|
||||
from rest_framework import status
|
||||
from rest_framework.serializers import Serializer
|
||||
from django.conf import settings
|
||||
from typing import Any
|
||||
|
||||
from apps.api.v1.serializers.shared import (
|
||||
validate_filter_metadata_contract
|
||||
)
|
||||
from django.conf import settings
|
||||
from rest_framework import status
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.serializers import Serializer
|
||||
from rest_framework.views import APIView
|
||||
|
||||
from apps.api.v1.serializers.shared import validate_filter_metadata_contract
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -23,28 +22,28 @@ logger = logging.getLogger(__name__)
|
||||
class ContractCompliantAPIView(APIView):
|
||||
"""
|
||||
Base API view that ensures all responses are contract-compliant.
|
||||
|
||||
|
||||
This view provides:
|
||||
- Standardized success response format
|
||||
- Consistent error response format
|
||||
- Automatic contract validation in DEBUG mode
|
||||
- Proper error logging with context
|
||||
"""
|
||||
|
||||
|
||||
# Override in subclasses to specify response serializer
|
||||
response_serializer_class: Optional[Type[Serializer]] = None
|
||||
|
||||
response_serializer_class: type[Serializer] | None = None
|
||||
|
||||
def dispatch(self, request, *args, **kwargs):
|
||||
"""Override dispatch to add contract validation."""
|
||||
try:
|
||||
response = super().dispatch(request, *args, **kwargs)
|
||||
|
||||
|
||||
# Validate contract in DEBUG mode
|
||||
if settings.DEBUG and hasattr(response, 'data'):
|
||||
self._validate_response_contract(response.data)
|
||||
|
||||
|
||||
return response
|
||||
|
||||
|
||||
except Exception as e:
|
||||
# Log the error with context
|
||||
logger.error(
|
||||
@@ -58,66 +57,66 @@ class ContractCompliantAPIView(APIView):
|
||||
},
|
||||
exc_info=True
|
||||
)
|
||||
|
||||
|
||||
# Return standardized error response
|
||||
return self.error_response(
|
||||
message="An internal error occurred",
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
)
|
||||
|
||||
|
||||
def success_response(
|
||||
self,
|
||||
data: Any = None,
|
||||
message: str = None,
|
||||
self,
|
||||
data: Any = None,
|
||||
message: str = None,
|
||||
status_code: int = status.HTTP_200_OK,
|
||||
headers: Dict[str, str] = None
|
||||
headers: dict[str, str] = None
|
||||
) -> Response:
|
||||
"""
|
||||
Create a standardized success response.
|
||||
|
||||
|
||||
Args:
|
||||
data: Response data
|
||||
message: Optional success message
|
||||
status_code: HTTP status code
|
||||
headers: Optional response headers
|
||||
|
||||
|
||||
Returns:
|
||||
Response with standardized format
|
||||
"""
|
||||
response_data = {
|
||||
'success': True
|
||||
}
|
||||
|
||||
|
||||
if data is not None:
|
||||
response_data['data'] = data
|
||||
|
||||
|
||||
if message:
|
||||
response_data['message'] = message
|
||||
|
||||
|
||||
return Response(
|
||||
response_data,
|
||||
response_data,
|
||||
status=status_code,
|
||||
headers=headers
|
||||
)
|
||||
|
||||
|
||||
def error_response(
|
||||
self,
|
||||
message: str,
|
||||
status_code: int = status.HTTP_400_BAD_REQUEST,
|
||||
error_code: str = None,
|
||||
details: Any = None,
|
||||
headers: Dict[str, str] = None
|
||||
headers: dict[str, str] = None
|
||||
) -> Response:
|
||||
"""
|
||||
Create a standardized error response.
|
||||
|
||||
|
||||
Args:
|
||||
message: Error message
|
||||
status_code: HTTP status code
|
||||
error_code: Optional error code
|
||||
details: Optional error details
|
||||
headers: Optional response headers
|
||||
|
||||
|
||||
Returns:
|
||||
Response with standardized error format
|
||||
"""
|
||||
@@ -125,40 +124,40 @@ class ContractCompliantAPIView(APIView):
|
||||
'code': error_code or 'API_ERROR',
|
||||
'message': message
|
||||
}
|
||||
|
||||
|
||||
if details:
|
||||
error_data['details'] = details
|
||||
|
||||
|
||||
# Add user context if available
|
||||
if hasattr(self, 'request') and hasattr(self.request, 'user'):
|
||||
user = self.request.user
|
||||
if user and user.is_authenticated:
|
||||
error_data['request_user'] = user.username
|
||||
|
||||
|
||||
response_data = {
|
||||
'status': 'error',
|
||||
'error': error_data,
|
||||
'data': None
|
||||
}
|
||||
|
||||
|
||||
return Response(
|
||||
response_data,
|
||||
status=status_code,
|
||||
headers=headers
|
||||
)
|
||||
|
||||
|
||||
def validation_error_response(
|
||||
self,
|
||||
errors: Dict[str, Any],
|
||||
errors: dict[str, Any],
|
||||
message: str = "Validation failed"
|
||||
) -> Response:
|
||||
"""
|
||||
Create a standardized validation error response.
|
||||
|
||||
|
||||
Args:
|
||||
errors: Validation errors dictionary
|
||||
message: Error message
|
||||
|
||||
|
||||
Returns:
|
||||
Response with validation errors
|
||||
"""
|
||||
@@ -170,11 +169,11 @@ class ContractCompliantAPIView(APIView):
|
||||
},
|
||||
status=status.HTTP_400_BAD_REQUEST
|
||||
)
|
||||
|
||||
|
||||
def _validate_response_contract(self, data: Any) -> None:
|
||||
"""
|
||||
Validate response data against expected contracts.
|
||||
|
||||
|
||||
This method is called automatically in DEBUG mode to catch
|
||||
contract violations during development.
|
||||
"""
|
||||
@@ -182,9 +181,9 @@ class ContractCompliantAPIView(APIView):
|
||||
# Check if this looks like filter metadata
|
||||
if isinstance(data, dict) and 'categorical' in data and 'ranges' in data:
|
||||
validate_filter_metadata_contract(data)
|
||||
|
||||
|
||||
# Add more contract validations as needed
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Contract validation failed in {self.__class__.__name__}: {str(e)}",
|
||||
@@ -199,30 +198,30 @@ class ContractCompliantAPIView(APIView):
|
||||
class FilterMetadataAPIView(ContractCompliantAPIView):
|
||||
"""
|
||||
Base view for filter metadata endpoints.
|
||||
|
||||
|
||||
This view ensures filter metadata responses always follow the correct
|
||||
contract that matches frontend TypeScript interfaces.
|
||||
"""
|
||||
|
||||
def get_filter_metadata(self) -> Dict[str, Any]:
|
||||
|
||||
def get_filter_metadata(self) -> dict[str, Any]:
|
||||
"""
|
||||
Override this method in subclasses to provide filter metadata.
|
||||
|
||||
|
||||
Returns:
|
||||
Filter metadata dictionary
|
||||
"""
|
||||
raise NotImplementedError("Subclasses must implement get_filter_metadata()")
|
||||
|
||||
|
||||
def get(self, request, *args, **kwargs):
|
||||
"""Handle GET requests for filter metadata."""
|
||||
try:
|
||||
metadata = self.get_filter_metadata()
|
||||
|
||||
|
||||
# Validate the metadata contract
|
||||
validated_metadata = validate_filter_metadata_contract(metadata)
|
||||
|
||||
|
||||
return self.success_response(validated_metadata)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error getting filter metadata in {self.__class__.__name__}: {str(e)}",
|
||||
@@ -232,7 +231,7 @@ class FilterMetadataAPIView(ContractCompliantAPIView):
|
||||
},
|
||||
exc_info=True
|
||||
)
|
||||
|
||||
|
||||
return self.error_response(
|
||||
message="Failed to retrieve filter metadata",
|
||||
error_code="FILTER_METADATA_ERROR"
|
||||
@@ -242,37 +241,37 @@ class FilterMetadataAPIView(ContractCompliantAPIView):
|
||||
class HybridFilteringAPIView(ContractCompliantAPIView):
|
||||
"""
|
||||
Base view for hybrid filtering endpoints.
|
||||
|
||||
|
||||
This view provides common functionality for hybrid filtering responses
|
||||
and ensures they follow the correct contract.
|
||||
"""
|
||||
|
||||
def get_hybrid_data(self, filters: Dict[str, Any] = None) -> Dict[str, Any]:
|
||||
|
||||
def get_hybrid_data(self, filters: dict[str, Any] = None) -> dict[str, Any]:
|
||||
"""
|
||||
Override this method in subclasses to provide hybrid data.
|
||||
|
||||
|
||||
Args:
|
||||
filters: Filter parameters
|
||||
|
||||
|
||||
Returns:
|
||||
Hybrid response dictionary
|
||||
"""
|
||||
raise NotImplementedError("Subclasses must implement get_hybrid_data()")
|
||||
|
||||
|
||||
def get(self, request, *args, **kwargs):
|
||||
"""Handle GET requests for hybrid filtering."""
|
||||
try:
|
||||
# Extract filters from request parameters
|
||||
filters = self.extract_filters(request)
|
||||
|
||||
|
||||
# Get hybrid data
|
||||
hybrid_data = self.get_hybrid_data(filters)
|
||||
|
||||
|
||||
# Validate hybrid response structure
|
||||
self._validate_hybrid_response(hybrid_data)
|
||||
|
||||
|
||||
return self.success_response(hybrid_data)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error in hybrid filtering for {self.__class__.__name__}: {str(e)}",
|
||||
@@ -283,21 +282,21 @@ class HybridFilteringAPIView(ContractCompliantAPIView):
|
||||
},
|
||||
exc_info=True
|
||||
)
|
||||
|
||||
|
||||
return self.error_response(
|
||||
message="Failed to retrieve filtered data",
|
||||
error_code="HYBRID_FILTERING_ERROR"
|
||||
)
|
||||
|
||||
def extract_filters(self, request) -> Dict[str, Any]:
|
||||
|
||||
def extract_filters(self, request) -> dict[str, Any]:
|
||||
"""
|
||||
Extract filter parameters from request.
|
||||
|
||||
|
||||
Override this method in subclasses to customize filter extraction.
|
||||
|
||||
|
||||
Args:
|
||||
request: HTTP request object
|
||||
|
||||
|
||||
Returns:
|
||||
Dictionary of filter parameters
|
||||
"""
|
||||
@@ -306,24 +305,24 @@ class HybridFilteringAPIView(ContractCompliantAPIView):
|
||||
for key, value in request.query_params.items():
|
||||
if value: # Only include non-empty values
|
||||
filters[key] = value
|
||||
|
||||
|
||||
# Store for error logging
|
||||
self._extracted_filters = filters
|
||||
|
||||
|
||||
return filters
|
||||
|
||||
def _validate_hybrid_response(self, data: Dict[str, Any]) -> None:
|
||||
|
||||
def _validate_hybrid_response(self, data: dict[str, Any]) -> None:
|
||||
"""Validate hybrid response structure."""
|
||||
required_fields = ['strategy', 'total_count']
|
||||
|
||||
|
||||
for field in required_fields:
|
||||
if field not in data:
|
||||
raise ValueError(f"Hybrid response missing required field: {field}")
|
||||
|
||||
|
||||
# Validate strategy value
|
||||
if data['strategy'] not in ['client_side', 'server_side']:
|
||||
raise ValueError(f"Invalid strategy value: {data['strategy']}")
|
||||
|
||||
|
||||
# Validate filter metadata if present
|
||||
if 'filter_metadata' in data:
|
||||
validate_filter_metadata_contract(data['filter_metadata'])
|
||||
@@ -332,77 +331,77 @@ class HybridFilteringAPIView(ContractCompliantAPIView):
|
||||
class PaginatedAPIView(ContractCompliantAPIView):
|
||||
"""
|
||||
Base view for paginated responses.
|
||||
|
||||
|
||||
This view ensures paginated responses follow the correct contract
|
||||
with consistent pagination metadata.
|
||||
"""
|
||||
|
||||
|
||||
default_page_size = 20
|
||||
max_page_size = 100
|
||||
|
||||
|
||||
def get_paginated_response(
|
||||
self,
|
||||
queryset,
|
||||
serializer_class: Type[Serializer],
|
||||
serializer_class: type[Serializer],
|
||||
request,
|
||||
page_size: int = None
|
||||
) -> Response:
|
||||
"""
|
||||
Create a paginated response.
|
||||
|
||||
|
||||
Args:
|
||||
queryset: Django queryset to paginate
|
||||
serializer_class: Serializer class for items
|
||||
request: HTTP request object
|
||||
page_size: Optional page size override
|
||||
|
||||
|
||||
Returns:
|
||||
Paginated response
|
||||
"""
|
||||
from django.core.paginator import Paginator, EmptyPage, PageNotAnInteger
|
||||
|
||||
from django.core.paginator import EmptyPage, PageNotAnInteger, Paginator
|
||||
|
||||
# Determine page size
|
||||
if page_size is None:
|
||||
page_size = min(
|
||||
int(request.query_params.get('page_size', self.default_page_size)),
|
||||
self.max_page_size
|
||||
)
|
||||
|
||||
|
||||
# Get page number
|
||||
page_number = request.query_params.get('page', 1)
|
||||
|
||||
|
||||
try:
|
||||
page_number = int(page_number)
|
||||
except (ValueError, TypeError):
|
||||
page_number = 1
|
||||
|
||||
|
||||
# Create paginator
|
||||
paginator = Paginator(queryset, page_size)
|
||||
|
||||
|
||||
try:
|
||||
page = paginator.page(page_number)
|
||||
except PageNotAnInteger:
|
||||
page = paginator.page(1)
|
||||
except EmptyPage:
|
||||
page = paginator.page(paginator.num_pages)
|
||||
|
||||
|
||||
# Serialize data
|
||||
serializer = serializer_class(page.object_list, many=True)
|
||||
|
||||
|
||||
# Build pagination URLs
|
||||
request_url = request.build_absolute_uri().split('?')[0]
|
||||
query_params = request.query_params.copy()
|
||||
|
||||
|
||||
next_url = None
|
||||
if page.has_next():
|
||||
query_params['page'] = page.next_page_number()
|
||||
next_url = f"{request_url}?{query_params.urlencode()}"
|
||||
|
||||
|
||||
previous_url = None
|
||||
if page.has_previous():
|
||||
query_params['page'] = page.previous_page_number()
|
||||
previous_url = f"{request_url}?{query_params.urlencode()}"
|
||||
|
||||
|
||||
# Create response data
|
||||
response_data = {
|
||||
'count': paginator.count,
|
||||
@@ -413,36 +412,36 @@ class PaginatedAPIView(ContractCompliantAPIView):
|
||||
'current_page': page.number,
|
||||
'total_pages': paginator.num_pages
|
||||
}
|
||||
|
||||
|
||||
return self.success_response(response_data)
|
||||
|
||||
|
||||
def contract_compliant_view(view_class):
|
||||
"""
|
||||
Decorator to make any view contract-compliant.
|
||||
|
||||
|
||||
This decorator can be applied to existing views to add contract
|
||||
validation without changing the base class.
|
||||
"""
|
||||
original_dispatch = view_class.dispatch
|
||||
|
||||
|
||||
def new_dispatch(self, request, *args, **kwargs):
|
||||
try:
|
||||
response = original_dispatch(self, request, *args, **kwargs)
|
||||
|
||||
|
||||
# Add contract validation in DEBUG mode
|
||||
if settings.DEBUG and hasattr(response, 'data'):
|
||||
# Basic validation - can be extended
|
||||
pass
|
||||
|
||||
|
||||
return response
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error in decorated view {view_class.__name__}: {str(e)}",
|
||||
exc_info=True
|
||||
)
|
||||
|
||||
|
||||
# Return basic error response
|
||||
return Response(
|
||||
{
|
||||
@@ -455,6 +454,6 @@ def contract_compliant_view(view_class):
|
||||
},
|
||||
status=status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
)
|
||||
|
||||
|
||||
view_class.dispatch = new_dispatch
|
||||
return view_class
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
from rest_framework.views import APIView
|
||||
from rest_framework.response import Response
|
||||
from rest_framework import status
|
||||
from rest_framework.permissions import AllowAny
|
||||
from django.db.models import F
|
||||
|
||||
from django.utils import timezone
|
||||
from drf_spectacular.utils import extend_schema
|
||||
from datetime import timedelta
|
||||
from rest_framework.permissions import AllowAny
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.views import APIView
|
||||
|
||||
from apps.parks.models import Park
|
||||
from apps.rides.models import Ride
|
||||
|
||||
|
||||
class DiscoveryAPIView(APIView):
|
||||
"""
|
||||
API endpoint for discovery content (Top Lists, Opening/Closing Soon).
|
||||
@@ -28,7 +28,7 @@ class DiscoveryAPIView(APIView):
|
||||
# --- TOP LISTS ---
|
||||
# Top Parks by average rating
|
||||
top_parks = Park.objects.filter(average_rating__isnull=False).order_by("-average_rating")[:limit]
|
||||
|
||||
|
||||
# Top Rides by average rating (fallback to RideRanking in future)
|
||||
top_rides = Ride.objects.filter(average_rating__isnull=False).order_by("-average_rating")[:limit]
|
||||
|
||||
@@ -70,7 +70,7 @@ class DiscoveryAPIView(APIView):
|
||||
"rides": self._serialize(recently_closed_rides, "ride"),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
return Response(data)
|
||||
|
||||
def _serialize(self, queryset, type_):
|
||||
|
||||
@@ -6,14 +6,15 @@ performance metrics, and database analysis.
|
||||
"""
|
||||
|
||||
import time
|
||||
from django.utils import timezone
|
||||
|
||||
from django.conf import settings
|
||||
from rest_framework.views import APIView
|
||||
from django.utils import timezone
|
||||
from drf_spectacular.utils import extend_schema, extend_schema_view
|
||||
from health_check.views import MainView
|
||||
from rest_framework.permissions import AllowAny
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.permissions import AllowAny
|
||||
from health_check.views import MainView
|
||||
from drf_spectacular.utils import extend_schema, extend_schema_view
|
||||
from rest_framework.views import APIView
|
||||
|
||||
# Import serializers
|
||||
from ..serializers import (
|
||||
@@ -150,9 +151,10 @@ class HealthCheckAPIView(APIView):
|
||||
def _get_database_metrics(self) -> dict:
|
||||
"""Get database performance metrics."""
|
||||
try:
|
||||
from django.db import connection
|
||||
from typing import Any
|
||||
|
||||
from django.db import connection
|
||||
|
||||
# Get basic connection info
|
||||
metrics: dict[str, Any] = {
|
||||
"vendor": connection.vendor,
|
||||
|
||||
@@ -1,18 +1,18 @@
|
||||
"""
|
||||
Leaderboard views for user rankings
|
||||
"""
|
||||
from rest_framework.decorators import api_view, permission_classes
|
||||
from rest_framework.permissions import AllowAny
|
||||
from rest_framework.response import Response
|
||||
from datetime import timedelta
|
||||
|
||||
from django.db.models import Count, Sum
|
||||
from django.db.models.functions import Coalesce
|
||||
from django.utils import timezone
|
||||
from datetime import timedelta
|
||||
from rest_framework.decorators import api_view, permission_classes
|
||||
from rest_framework.permissions import AllowAny
|
||||
from rest_framework.response import Response
|
||||
|
||||
from apps.accounts.models import User
|
||||
from apps.rides.models import RideCredit
|
||||
from apps.reviews.models import Review
|
||||
from apps.moderation.models import EditSubmission
|
||||
from apps.reviews.models import Review
|
||||
from apps.rides.models import RideCredit
|
||||
|
||||
|
||||
@api_view(['GET'])
|
||||
@@ -20,7 +20,7 @@ from apps.moderation.models import EditSubmission
|
||||
def leaderboard(request):
|
||||
"""
|
||||
Get user leaderboard data.
|
||||
|
||||
|
||||
Query params:
|
||||
- category: 'credits' | 'reviews' | 'contributions' (default: credits)
|
||||
- period: 'all' | 'monthly' | 'weekly' (default: all)
|
||||
@@ -29,14 +29,14 @@ def leaderboard(request):
|
||||
category = request.query_params.get('category', 'credits')
|
||||
period = request.query_params.get('period', 'all')
|
||||
limit = min(int(request.query_params.get('limit', 25)), 100)
|
||||
|
||||
|
||||
# Calculate date filter based on period
|
||||
date_filter = None
|
||||
if period == 'weekly':
|
||||
date_filter = timezone.now() - timedelta(days=7)
|
||||
elif period == 'monthly':
|
||||
date_filter = timezone.now() - timedelta(days=30)
|
||||
|
||||
|
||||
if category == 'credits':
|
||||
return _get_credits_leaderboard(date_filter, limit)
|
||||
elif category == 'reviews':
|
||||
@@ -50,16 +50,16 @@ def leaderboard(request):
|
||||
def _get_credits_leaderboard(date_filter, limit):
|
||||
"""Top users by total ride credits."""
|
||||
queryset = RideCredit.objects.all()
|
||||
|
||||
|
||||
if date_filter:
|
||||
queryset = queryset.filter(created_at__gte=date_filter)
|
||||
|
||||
|
||||
# Aggregate credits per user
|
||||
users_data = queryset.values('user_id', 'user__username', 'user__display_name').annotate(
|
||||
total_credits=Coalesce(Sum('count'), 0),
|
||||
unique_rides=Count('ride', distinct=True),
|
||||
).order_by('-total_credits')[:limit]
|
||||
|
||||
|
||||
results = []
|
||||
for rank, entry in enumerate(users_data, 1):
|
||||
results.append({
|
||||
@@ -70,7 +70,7 @@ def _get_credits_leaderboard(date_filter, limit):
|
||||
'total_credits': entry['total_credits'],
|
||||
'unique_rides': entry['unique_rides'],
|
||||
})
|
||||
|
||||
|
||||
return Response({
|
||||
'category': 'credits',
|
||||
'results': results,
|
||||
@@ -80,15 +80,15 @@ def _get_credits_leaderboard(date_filter, limit):
|
||||
def _get_reviews_leaderboard(date_filter, limit):
|
||||
"""Top users by review count."""
|
||||
queryset = Review.objects.all()
|
||||
|
||||
|
||||
if date_filter:
|
||||
queryset = queryset.filter(created_at__gte=date_filter)
|
||||
|
||||
|
||||
# Count reviews per user
|
||||
users_data = queryset.values('user_id', 'user__username', 'user__display_name').annotate(
|
||||
review_count=Count('id'),
|
||||
).order_by('-review_count')[:limit]
|
||||
|
||||
|
||||
results = []
|
||||
for rank, entry in enumerate(users_data, 1):
|
||||
results.append({
|
||||
@@ -98,7 +98,7 @@ def _get_reviews_leaderboard(date_filter, limit):
|
||||
'display_name': entry['user__display_name'] or entry['user__username'],
|
||||
'review_count': entry['review_count'],
|
||||
})
|
||||
|
||||
|
||||
return Response({
|
||||
'category': 'reviews',
|
||||
'results': results,
|
||||
@@ -108,15 +108,15 @@ def _get_reviews_leaderboard(date_filter, limit):
|
||||
def _get_contributions_leaderboard(date_filter, limit):
|
||||
"""Top users by approved contributions."""
|
||||
queryset = EditSubmission.objects.filter(status='approved')
|
||||
|
||||
|
||||
if date_filter:
|
||||
queryset = queryset.filter(created_at__gte=date_filter)
|
||||
|
||||
|
||||
# Count contributions per user
|
||||
users_data = queryset.values('submitted_by_id', 'submitted_by__username', 'submitted_by__display_name').annotate(
|
||||
contribution_count=Count('id'),
|
||||
).order_by('-contribution_count')[:limit]
|
||||
|
||||
|
||||
results = []
|
||||
for rank, entry in enumerate(users_data, 1):
|
||||
results.append({
|
||||
@@ -126,7 +126,7 @@ def _get_contributions_leaderboard(date_filter, limit):
|
||||
'display_name': entry['submitted_by__display_name'] or entry['submitted_by__username'],
|
||||
'contribution_count': entry['contribution_count'],
|
||||
})
|
||||
|
||||
|
||||
return Response({
|
||||
'category': 'contributions',
|
||||
'results': results,
|
||||
|
||||
@@ -2,17 +2,19 @@
|
||||
Views for review-related API endpoints.
|
||||
"""
|
||||
|
||||
from rest_framework.views import APIView
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.permissions import AllowAny
|
||||
from rest_framework import status
|
||||
from drf_spectacular.utils import extend_schema, OpenApiParameter
|
||||
from drf_spectacular.types import OpenApiTypes
|
||||
from itertools import chain
|
||||
from operator import attrgetter
|
||||
|
||||
from drf_spectacular.types import OpenApiTypes
|
||||
from drf_spectacular.utils import OpenApiParameter, extend_schema
|
||||
from rest_framework import status
|
||||
from rest_framework.permissions import AllowAny
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.views import APIView
|
||||
|
||||
from apps.parks.models.reviews import ParkReview
|
||||
from apps.rides.models.reviews import RideReview
|
||||
|
||||
from ..serializers.reviews import LatestReviewSerializer
|
||||
|
||||
|
||||
|
||||
@@ -5,24 +5,29 @@ Provides aggregate statistics about the platform's content including
|
||||
counts of parks, rides, manufacturers, and other entities.
|
||||
"""
|
||||
|
||||
from rest_framework.views import APIView
|
||||
from rest_framework.response import Response
|
||||
from rest_framework import status
|
||||
from rest_framework.permissions import AllowAny, IsAdminUser
|
||||
from django.db.models import Count
|
||||
from django.core.cache import cache
|
||||
from django.utils import timezone
|
||||
from drf_spectacular.utils import extend_schema, OpenApiExample
|
||||
from datetime import datetime
|
||||
|
||||
from apps.parks.models import Park, ParkReview, ParkPhoto, Company as ParkCompany
|
||||
from django.core.cache import cache
|
||||
from django.db.models import Count
|
||||
from django.utils import timezone
|
||||
from drf_spectacular.utils import OpenApiExample, extend_schema
|
||||
from rest_framework import status
|
||||
from rest_framework.permissions import AllowAny, IsAdminUser
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.views import APIView
|
||||
|
||||
from apps.parks.models import Company as ParkCompany
|
||||
from apps.parks.models import Park, ParkPhoto, ParkReview
|
||||
from apps.rides.models import (
|
||||
Ride,
|
||||
RollerCoasterStats,
|
||||
RideReview,
|
||||
RidePhoto,
|
||||
Company as RideCompany,
|
||||
)
|
||||
from apps.rides.models import (
|
||||
Ride,
|
||||
RidePhoto,
|
||||
RideReview,
|
||||
RollerCoasterStats,
|
||||
)
|
||||
|
||||
from ..serializers.stats import StatsSerializer
|
||||
|
||||
|
||||
@@ -103,17 +108,17 @@ class StatsAPIView(APIView):
|
||||
summary="Get platform statistics",
|
||||
description="""
|
||||
Returns comprehensive aggregate statistics about the ThrillWiki platform.
|
||||
|
||||
|
||||
This endpoint provides detailed counts and breakdowns of all major entities including:
|
||||
- Parks, rides, and roller coasters
|
||||
- Companies (manufacturers, operators, designers, property owners)
|
||||
- Photos and reviews
|
||||
- Ride categories (roller coasters, dark rides, flat rides, etc.)
|
||||
- Status breakdowns (operating, closed, under construction, etc.)
|
||||
|
||||
Results are cached for 5 minutes for optimal performance and automatically
|
||||
|
||||
Results are cached for 5 minutes for optimal performance and automatically
|
||||
invalidated when relevant data changes.
|
||||
|
||||
|
||||
**No authentication required** - this is a public endpoint.
|
||||
""".strip(),
|
||||
responses={
|
||||
|
||||
@@ -5,14 +5,15 @@ This module contains endpoints for trending and new content discovery
|
||||
including trending parks, rides, and recently added content.
|
||||
"""
|
||||
|
||||
from datetime import datetime, date
|
||||
from rest_framework.views import APIView
|
||||
from datetime import date, datetime
|
||||
|
||||
from drf_spectacular.types import OpenApiTypes
|
||||
from drf_spectacular.utils import OpenApiParameter, extend_schema, extend_schema_view
|
||||
from rest_framework import status
|
||||
from rest_framework.permissions import AllowAny, IsAdminUser
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.permissions import AllowAny, IsAdminUser
|
||||
from rest_framework import status
|
||||
from drf_spectacular.utils import extend_schema, extend_schema_view, OpenApiParameter
|
||||
from drf_spectacular.types import OpenApiTypes
|
||||
from rest_framework.views import APIView
|
||||
|
||||
|
||||
@extend_schema_view(
|
||||
@@ -111,9 +112,10 @@ class TriggerTrendingCalculationAPIView(APIView):
|
||||
def post(self, request: Request) -> Response:
|
||||
"""Trigger trending content calculation using management commands."""
|
||||
try:
|
||||
from django.core.management import call_command
|
||||
import io
|
||||
from contextlib import redirect_stdout, redirect_stderr
|
||||
from contextlib import redirect_stderr, redirect_stdout
|
||||
|
||||
from django.core.management import call_command
|
||||
|
||||
# Capture command output
|
||||
trending_output = io.StringIO()
|
||||
@@ -227,10 +229,7 @@ class NewContentAPIView(APIView):
|
||||
if date_added:
|
||||
try:
|
||||
# Parse the date string
|
||||
if isinstance(date_added, str):
|
||||
item_date = datetime.fromisoformat(date_added).date()
|
||||
else:
|
||||
item_date = date_added
|
||||
item_date = datetime.fromisoformat(date_added).date() if isinstance(date_added, str) else date_added
|
||||
|
||||
# Calculate days difference
|
||||
days_diff = (today - item_date).days
|
||||
|
||||
@@ -2,32 +2,34 @@
|
||||
API viewsets for the ride ranking system.
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Type, cast
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from django.db.models import Q, QuerySet
|
||||
from django.utils import timezone
|
||||
from django_filters.rest_framework import DjangoFilterBackend
|
||||
from drf_spectacular.utils import extend_schema, extend_schema_view, OpenApiParameter
|
||||
from drf_spectacular.types import OpenApiTypes
|
||||
from drf_spectacular.utils import OpenApiParameter, extend_schema, extend_schema_view
|
||||
from rest_framework import status
|
||||
from rest_framework.decorators import action
|
||||
from rest_framework.filters import OrderingFilter
|
||||
from rest_framework.permissions import IsAuthenticatedOrReadOnly, AllowAny
|
||||
from rest_framework.permissions import AllowAny, IsAuthenticatedOrReadOnly
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.serializers import BaseSerializer
|
||||
from rest_framework.viewsets import ReadOnlyModelViewSet
|
||||
from rest_framework.views import APIView
|
||||
from rest_framework.viewsets import ReadOnlyModelViewSet
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
# Import models inside methods to avoid Django initialization issues
|
||||
import contextlib
|
||||
|
||||
from .serializers_rankings import (
|
||||
RideRankingSerializer,
|
||||
RideRankingDetailSerializer,
|
||||
RankingSnapshotSerializer,
|
||||
RankingStatsSerializer,
|
||||
RideRankingDetailSerializer,
|
||||
RideRankingSerializer,
|
||||
)
|
||||
|
||||
|
||||
@@ -127,10 +129,8 @@ class RideRankingViewSet(ReadOnlyModelViewSet):
|
||||
# Filter by minimum mutual riders
|
||||
min_riders = request.query_params.get("min_riders")
|
||||
if min_riders:
|
||||
try:
|
||||
with contextlib.suppress(ValueError):
|
||||
queryset = queryset.filter(mutual_riders_count__gte=int(min_riders))
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# Filter by park
|
||||
park_slug = request.query_params.get("park")
|
||||
@@ -142,12 +142,12 @@ class RideRankingViewSet(ReadOnlyModelViewSet):
|
||||
def get_serializer_class(self) -> Any: # type: ignore[override]
|
||||
"""Use different serializers for list vs detail."""
|
||||
if self.action == "retrieve":
|
||||
return cast(Type[BaseSerializer], RideRankingDetailSerializer)
|
||||
return cast(type[BaseSerializer], RideRankingDetailSerializer)
|
||||
elif self.action == "history":
|
||||
return cast(Type[BaseSerializer], RankingSnapshotSerializer)
|
||||
return cast(type[BaseSerializer], RankingSnapshotSerializer)
|
||||
elif self.action == "statistics":
|
||||
return cast(Type[BaseSerializer], RankingStatsSerializer)
|
||||
return cast(Type[BaseSerializer], RideRankingSerializer)
|
||||
return cast(type[BaseSerializer], RankingStatsSerializer)
|
||||
return cast(type[BaseSerializer], RideRankingSerializer)
|
||||
|
||||
@action(detail=True, methods=["get"])
|
||||
def history(self, request, ride_slug=None):
|
||||
@@ -167,7 +167,7 @@ class RideRankingViewSet(ReadOnlyModelViewSet):
|
||||
@action(detail=False, methods=["get"])
|
||||
def statistics(self, request):
|
||||
"""Get overall ranking system statistics."""
|
||||
from apps.rides.models import RideRanking, RidePairComparison, RankingSnapshot
|
||||
from apps.rides.models import RankingSnapshot, RidePairComparison, RideRanking
|
||||
|
||||
total_rankings = RideRanking.objects.count()
|
||||
total_comparisons = RidePairComparison.objects.count()
|
||||
|
||||
Reference in New Issue
Block a user