feat: Implement MFA authentication, add ride statistics model, and update various services, APIs, and tests across the application.

This commit is contained in:
pacnpal
2025-12-28 17:32:53 -05:00
parent aa56c46c27
commit c95f99ca10
452 changed files with 7948 additions and 6073 deletions

View File

@@ -9,25 +9,23 @@ This package contains all API view classes organized by functionality:
# Import all view classes for easy access
from .auth import (
LoginAPIView,
SignupAPIView,
LogoutAPIView,
CurrentUserAPIView,
PasswordResetAPIView,
PasswordChangeAPIView,
SocialProvidersAPIView,
AuthStatusAPIView,
CurrentUserAPIView,
LoginAPIView,
LogoutAPIView,
PasswordChangeAPIView,
PasswordResetAPIView,
SignupAPIView,
SocialProvidersAPIView,
)
from .health import (
HealthCheckAPIView,
PerformanceMetricsAPIView,
SimpleHealthAPIView,
)
from .trending import (
TrendingAPIView,
NewContentAPIView,
TrendingAPIView,
TriggerTrendingCalculationAPIView,
)

View File

@@ -7,34 +7,34 @@ login, signup, logout, password management, and social authentication.
# type: ignore[misc,attr-defined,arg-type,call-arg,index,assignment]
from typing import TYPE_CHECKING, Type, Any
from django.contrib.auth import login, logout, get_user_model
from typing import TYPE_CHECKING, Any
from django.contrib.auth import get_user_model, login, logout
from django.contrib.sites.shortcuts import get_current_site
from django.core.exceptions import ValidationError
from drf_spectacular.utils import extend_schema, extend_schema_view
from rest_framework import status
from rest_framework.views import APIView
from rest_framework.permissions import AllowAny, IsAuthenticated
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.permissions import AllowAny, IsAuthenticated
from drf_spectacular.utils import extend_schema, extend_schema_view
from rest_framework.views import APIView
# Import serializers from the auth serializers module
from ..serializers.auth import (
AuthStatusOutputSerializer,
LoginInputSerializer,
LoginOutputSerializer,
SignupInputSerializer,
SignupOutputSerializer,
LogoutOutputSerializer,
UserOutputSerializer,
PasswordResetInputSerializer,
PasswordResetOutputSerializer,
PasswordChangeInputSerializer,
PasswordChangeOutputSerializer,
PasswordResetInputSerializer,
PasswordResetOutputSerializer,
SignupInputSerializer,
SignupOutputSerializer,
SocialProviderOutputSerializer,
AuthStatusOutputSerializer,
UserOutputSerializer,
)
# Handle optional dependencies with fallback classes
@@ -56,7 +56,7 @@ except ImportError:
if TYPE_CHECKING:
from typing import Union
TurnstileMixinType = Union[Type[FallbackTurnstileMixin], Any]
TurnstileMixinType = Union[type[FallbackTurnstileMixin], Any]
else:
TurnstileMixinType = TurnstileMixin

View File

@@ -6,16 +6,15 @@ consistent formats that match frontend TypeScript interfaces exactly.
"""
import logging
from typing import Dict, Any, Optional, Type
from rest_framework.views import APIView
from rest_framework.response import Response
from rest_framework import status
from rest_framework.serializers import Serializer
from django.conf import settings
from typing import Any
from apps.api.v1.serializers.shared import (
validate_filter_metadata_contract
)
from django.conf import settings
from rest_framework import status
from rest_framework.response import Response
from rest_framework.serializers import Serializer
from rest_framework.views import APIView
from apps.api.v1.serializers.shared import validate_filter_metadata_contract
logger = logging.getLogger(__name__)
@@ -23,28 +22,28 @@ logger = logging.getLogger(__name__)
class ContractCompliantAPIView(APIView):
"""
Base API view that ensures all responses are contract-compliant.
This view provides:
- Standardized success response format
- Consistent error response format
- Automatic contract validation in DEBUG mode
- Proper error logging with context
"""
# Override in subclasses to specify response serializer
response_serializer_class: Optional[Type[Serializer]] = None
response_serializer_class: type[Serializer] | None = None
def dispatch(self, request, *args, **kwargs):
"""Override dispatch to add contract validation."""
try:
response = super().dispatch(request, *args, **kwargs)
# Validate contract in DEBUG mode
if settings.DEBUG and hasattr(response, 'data'):
self._validate_response_contract(response.data)
return response
except Exception as e:
# Log the error with context
logger.error(
@@ -58,66 +57,66 @@ class ContractCompliantAPIView(APIView):
},
exc_info=True
)
# Return standardized error response
return self.error_response(
message="An internal error occurred",
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
)
def success_response(
self,
data: Any = None,
message: str = None,
self,
data: Any = None,
message: str = None,
status_code: int = status.HTTP_200_OK,
headers: Dict[str, str] = None
headers: dict[str, str] = None
) -> Response:
"""
Create a standardized success response.
Args:
data: Response data
message: Optional success message
status_code: HTTP status code
headers: Optional response headers
Returns:
Response with standardized format
"""
response_data = {
'success': True
}
if data is not None:
response_data['data'] = data
if message:
response_data['message'] = message
return Response(
response_data,
response_data,
status=status_code,
headers=headers
)
def error_response(
self,
message: str,
status_code: int = status.HTTP_400_BAD_REQUEST,
error_code: str = None,
details: Any = None,
headers: Dict[str, str] = None
headers: dict[str, str] = None
) -> Response:
"""
Create a standardized error response.
Args:
message: Error message
status_code: HTTP status code
error_code: Optional error code
details: Optional error details
headers: Optional response headers
Returns:
Response with standardized error format
"""
@@ -125,40 +124,40 @@ class ContractCompliantAPIView(APIView):
'code': error_code or 'API_ERROR',
'message': message
}
if details:
error_data['details'] = details
# Add user context if available
if hasattr(self, 'request') and hasattr(self.request, 'user'):
user = self.request.user
if user and user.is_authenticated:
error_data['request_user'] = user.username
response_data = {
'status': 'error',
'error': error_data,
'data': None
}
return Response(
response_data,
status=status_code,
headers=headers
)
def validation_error_response(
self,
errors: Dict[str, Any],
errors: dict[str, Any],
message: str = "Validation failed"
) -> Response:
"""
Create a standardized validation error response.
Args:
errors: Validation errors dictionary
message: Error message
Returns:
Response with validation errors
"""
@@ -170,11 +169,11 @@ class ContractCompliantAPIView(APIView):
},
status=status.HTTP_400_BAD_REQUEST
)
def _validate_response_contract(self, data: Any) -> None:
"""
Validate response data against expected contracts.
This method is called automatically in DEBUG mode to catch
contract violations during development.
"""
@@ -182,9 +181,9 @@ class ContractCompliantAPIView(APIView):
# Check if this looks like filter metadata
if isinstance(data, dict) and 'categorical' in data and 'ranges' in data:
validate_filter_metadata_contract(data)
# Add more contract validations as needed
except Exception as e:
logger.warning(
f"Contract validation failed in {self.__class__.__name__}: {str(e)}",
@@ -199,30 +198,30 @@ class ContractCompliantAPIView(APIView):
class FilterMetadataAPIView(ContractCompliantAPIView):
"""
Base view for filter metadata endpoints.
This view ensures filter metadata responses always follow the correct
contract that matches frontend TypeScript interfaces.
"""
def get_filter_metadata(self) -> Dict[str, Any]:
def get_filter_metadata(self) -> dict[str, Any]:
"""
Override this method in subclasses to provide filter metadata.
Returns:
Filter metadata dictionary
"""
raise NotImplementedError("Subclasses must implement get_filter_metadata()")
def get(self, request, *args, **kwargs):
"""Handle GET requests for filter metadata."""
try:
metadata = self.get_filter_metadata()
# Validate the metadata contract
validated_metadata = validate_filter_metadata_contract(metadata)
return self.success_response(validated_metadata)
except Exception as e:
logger.error(
f"Error getting filter metadata in {self.__class__.__name__}: {str(e)}",
@@ -232,7 +231,7 @@ class FilterMetadataAPIView(ContractCompliantAPIView):
},
exc_info=True
)
return self.error_response(
message="Failed to retrieve filter metadata",
error_code="FILTER_METADATA_ERROR"
@@ -242,37 +241,37 @@ class FilterMetadataAPIView(ContractCompliantAPIView):
class HybridFilteringAPIView(ContractCompliantAPIView):
"""
Base view for hybrid filtering endpoints.
This view provides common functionality for hybrid filtering responses
and ensures they follow the correct contract.
"""
def get_hybrid_data(self, filters: Dict[str, Any] = None) -> Dict[str, Any]:
def get_hybrid_data(self, filters: dict[str, Any] = None) -> dict[str, Any]:
"""
Override this method in subclasses to provide hybrid data.
Args:
filters: Filter parameters
Returns:
Hybrid response dictionary
"""
raise NotImplementedError("Subclasses must implement get_hybrid_data()")
def get(self, request, *args, **kwargs):
"""Handle GET requests for hybrid filtering."""
try:
# Extract filters from request parameters
filters = self.extract_filters(request)
# Get hybrid data
hybrid_data = self.get_hybrid_data(filters)
# Validate hybrid response structure
self._validate_hybrid_response(hybrid_data)
return self.success_response(hybrid_data)
except Exception as e:
logger.error(
f"Error in hybrid filtering for {self.__class__.__name__}: {str(e)}",
@@ -283,21 +282,21 @@ class HybridFilteringAPIView(ContractCompliantAPIView):
},
exc_info=True
)
return self.error_response(
message="Failed to retrieve filtered data",
error_code="HYBRID_FILTERING_ERROR"
)
def extract_filters(self, request) -> Dict[str, Any]:
def extract_filters(self, request) -> dict[str, Any]:
"""
Extract filter parameters from request.
Override this method in subclasses to customize filter extraction.
Args:
request: HTTP request object
Returns:
Dictionary of filter parameters
"""
@@ -306,24 +305,24 @@ class HybridFilteringAPIView(ContractCompliantAPIView):
for key, value in request.query_params.items():
if value: # Only include non-empty values
filters[key] = value
# Store for error logging
self._extracted_filters = filters
return filters
def _validate_hybrid_response(self, data: Dict[str, Any]) -> None:
def _validate_hybrid_response(self, data: dict[str, Any]) -> None:
"""Validate hybrid response structure."""
required_fields = ['strategy', 'total_count']
for field in required_fields:
if field not in data:
raise ValueError(f"Hybrid response missing required field: {field}")
# Validate strategy value
if data['strategy'] not in ['client_side', 'server_side']:
raise ValueError(f"Invalid strategy value: {data['strategy']}")
# Validate filter metadata if present
if 'filter_metadata' in data:
validate_filter_metadata_contract(data['filter_metadata'])
@@ -332,77 +331,77 @@ class HybridFilteringAPIView(ContractCompliantAPIView):
class PaginatedAPIView(ContractCompliantAPIView):
"""
Base view for paginated responses.
This view ensures paginated responses follow the correct contract
with consistent pagination metadata.
"""
default_page_size = 20
max_page_size = 100
def get_paginated_response(
self,
queryset,
serializer_class: Type[Serializer],
serializer_class: type[Serializer],
request,
page_size: int = None
) -> Response:
"""
Create a paginated response.
Args:
queryset: Django queryset to paginate
serializer_class: Serializer class for items
request: HTTP request object
page_size: Optional page size override
Returns:
Paginated response
"""
from django.core.paginator import Paginator, EmptyPage, PageNotAnInteger
from django.core.paginator import EmptyPage, PageNotAnInteger, Paginator
# Determine page size
if page_size is None:
page_size = min(
int(request.query_params.get('page_size', self.default_page_size)),
self.max_page_size
)
# Get page number
page_number = request.query_params.get('page', 1)
try:
page_number = int(page_number)
except (ValueError, TypeError):
page_number = 1
# Create paginator
paginator = Paginator(queryset, page_size)
try:
page = paginator.page(page_number)
except PageNotAnInteger:
page = paginator.page(1)
except EmptyPage:
page = paginator.page(paginator.num_pages)
# Serialize data
serializer = serializer_class(page.object_list, many=True)
# Build pagination URLs
request_url = request.build_absolute_uri().split('?')[0]
query_params = request.query_params.copy()
next_url = None
if page.has_next():
query_params['page'] = page.next_page_number()
next_url = f"{request_url}?{query_params.urlencode()}"
previous_url = None
if page.has_previous():
query_params['page'] = page.previous_page_number()
previous_url = f"{request_url}?{query_params.urlencode()}"
# Create response data
response_data = {
'count': paginator.count,
@@ -413,36 +412,36 @@ class PaginatedAPIView(ContractCompliantAPIView):
'current_page': page.number,
'total_pages': paginator.num_pages
}
return self.success_response(response_data)
def contract_compliant_view(view_class):
"""
Decorator to make any view contract-compliant.
This decorator can be applied to existing views to add contract
validation without changing the base class.
"""
original_dispatch = view_class.dispatch
def new_dispatch(self, request, *args, **kwargs):
try:
response = original_dispatch(self, request, *args, **kwargs)
# Add contract validation in DEBUG mode
if settings.DEBUG and hasattr(response, 'data'):
# Basic validation - can be extended
pass
return response
except Exception as e:
logger.error(
f"Error in decorated view {view_class.__name__}: {str(e)}",
exc_info=True
)
# Return basic error response
return Response(
{
@@ -455,6 +454,6 @@ def contract_compliant_view(view_class):
},
status=status.HTTP_500_INTERNAL_SERVER_ERROR
)
view_class.dispatch = new_dispatch
return view_class

View File

@@ -1,14 +1,14 @@
from rest_framework.views import APIView
from rest_framework.response import Response
from rest_framework import status
from rest_framework.permissions import AllowAny
from django.db.models import F
from django.utils import timezone
from drf_spectacular.utils import extend_schema
from datetime import timedelta
from rest_framework.permissions import AllowAny
from rest_framework.response import Response
from rest_framework.views import APIView
from apps.parks.models import Park
from apps.rides.models import Ride
class DiscoveryAPIView(APIView):
"""
API endpoint for discovery content (Top Lists, Opening/Closing Soon).
@@ -28,7 +28,7 @@ class DiscoveryAPIView(APIView):
# --- TOP LISTS ---
# Top Parks by average rating
top_parks = Park.objects.filter(average_rating__isnull=False).order_by("-average_rating")[:limit]
# Top Rides by average rating (fallback to RideRanking in future)
top_rides = Ride.objects.filter(average_rating__isnull=False).order_by("-average_rating")[:limit]
@@ -70,7 +70,7 @@ class DiscoveryAPIView(APIView):
"rides": self._serialize(recently_closed_rides, "ride"),
}
}
return Response(data)
def _serialize(self, queryset, type_):

View File

@@ -6,14 +6,15 @@ performance metrics, and database analysis.
"""
import time
from django.utils import timezone
from django.conf import settings
from rest_framework.views import APIView
from django.utils import timezone
from drf_spectacular.utils import extend_schema, extend_schema_view
from health_check.views import MainView
from rest_framework.permissions import AllowAny
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.permissions import AllowAny
from health_check.views import MainView
from drf_spectacular.utils import extend_schema, extend_schema_view
from rest_framework.views import APIView
# Import serializers
from ..serializers import (
@@ -150,9 +151,10 @@ class HealthCheckAPIView(APIView):
def _get_database_metrics(self) -> dict:
"""Get database performance metrics."""
try:
from django.db import connection
from typing import Any
from django.db import connection
# Get basic connection info
metrics: dict[str, Any] = {
"vendor": connection.vendor,

View File

@@ -1,18 +1,18 @@
"""
Leaderboard views for user rankings
"""
from rest_framework.decorators import api_view, permission_classes
from rest_framework.permissions import AllowAny
from rest_framework.response import Response
from datetime import timedelta
from django.db.models import Count, Sum
from django.db.models.functions import Coalesce
from django.utils import timezone
from datetime import timedelta
from rest_framework.decorators import api_view, permission_classes
from rest_framework.permissions import AllowAny
from rest_framework.response import Response
from apps.accounts.models import User
from apps.rides.models import RideCredit
from apps.reviews.models import Review
from apps.moderation.models import EditSubmission
from apps.reviews.models import Review
from apps.rides.models import RideCredit
@api_view(['GET'])
@@ -20,7 +20,7 @@ from apps.moderation.models import EditSubmission
def leaderboard(request):
"""
Get user leaderboard data.
Query params:
- category: 'credits' | 'reviews' | 'contributions' (default: credits)
- period: 'all' | 'monthly' | 'weekly' (default: all)
@@ -29,14 +29,14 @@ def leaderboard(request):
category = request.query_params.get('category', 'credits')
period = request.query_params.get('period', 'all')
limit = min(int(request.query_params.get('limit', 25)), 100)
# Calculate date filter based on period
date_filter = None
if period == 'weekly':
date_filter = timezone.now() - timedelta(days=7)
elif period == 'monthly':
date_filter = timezone.now() - timedelta(days=30)
if category == 'credits':
return _get_credits_leaderboard(date_filter, limit)
elif category == 'reviews':
@@ -50,16 +50,16 @@ def leaderboard(request):
def _get_credits_leaderboard(date_filter, limit):
"""Top users by total ride credits."""
queryset = RideCredit.objects.all()
if date_filter:
queryset = queryset.filter(created_at__gte=date_filter)
# Aggregate credits per user
users_data = queryset.values('user_id', 'user__username', 'user__display_name').annotate(
total_credits=Coalesce(Sum('count'), 0),
unique_rides=Count('ride', distinct=True),
).order_by('-total_credits')[:limit]
results = []
for rank, entry in enumerate(users_data, 1):
results.append({
@@ -70,7 +70,7 @@ def _get_credits_leaderboard(date_filter, limit):
'total_credits': entry['total_credits'],
'unique_rides': entry['unique_rides'],
})
return Response({
'category': 'credits',
'results': results,
@@ -80,15 +80,15 @@ def _get_credits_leaderboard(date_filter, limit):
def _get_reviews_leaderboard(date_filter, limit):
"""Top users by review count."""
queryset = Review.objects.all()
if date_filter:
queryset = queryset.filter(created_at__gte=date_filter)
# Count reviews per user
users_data = queryset.values('user_id', 'user__username', 'user__display_name').annotate(
review_count=Count('id'),
).order_by('-review_count')[:limit]
results = []
for rank, entry in enumerate(users_data, 1):
results.append({
@@ -98,7 +98,7 @@ def _get_reviews_leaderboard(date_filter, limit):
'display_name': entry['user__display_name'] or entry['user__username'],
'review_count': entry['review_count'],
})
return Response({
'category': 'reviews',
'results': results,
@@ -108,15 +108,15 @@ def _get_reviews_leaderboard(date_filter, limit):
def _get_contributions_leaderboard(date_filter, limit):
"""Top users by approved contributions."""
queryset = EditSubmission.objects.filter(status='approved')
if date_filter:
queryset = queryset.filter(created_at__gte=date_filter)
# Count contributions per user
users_data = queryset.values('submitted_by_id', 'submitted_by__username', 'submitted_by__display_name').annotate(
contribution_count=Count('id'),
).order_by('-contribution_count')[:limit]
results = []
for rank, entry in enumerate(users_data, 1):
results.append({
@@ -126,7 +126,7 @@ def _get_contributions_leaderboard(date_filter, limit):
'display_name': entry['submitted_by__display_name'] or entry['submitted_by__username'],
'contribution_count': entry['contribution_count'],
})
return Response({
'category': 'contributions',
'results': results,

View File

@@ -2,17 +2,19 @@
Views for review-related API endpoints.
"""
from rest_framework.views import APIView
from rest_framework.response import Response
from rest_framework.permissions import AllowAny
from rest_framework import status
from drf_spectacular.utils import extend_schema, OpenApiParameter
from drf_spectacular.types import OpenApiTypes
from itertools import chain
from operator import attrgetter
from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import OpenApiParameter, extend_schema
from rest_framework import status
from rest_framework.permissions import AllowAny
from rest_framework.response import Response
from rest_framework.views import APIView
from apps.parks.models.reviews import ParkReview
from apps.rides.models.reviews import RideReview
from ..serializers.reviews import LatestReviewSerializer

View File

@@ -5,24 +5,29 @@ Provides aggregate statistics about the platform's content including
counts of parks, rides, manufacturers, and other entities.
"""
from rest_framework.views import APIView
from rest_framework.response import Response
from rest_framework import status
from rest_framework.permissions import AllowAny, IsAdminUser
from django.db.models import Count
from django.core.cache import cache
from django.utils import timezone
from drf_spectacular.utils import extend_schema, OpenApiExample
from datetime import datetime
from apps.parks.models import Park, ParkReview, ParkPhoto, Company as ParkCompany
from django.core.cache import cache
from django.db.models import Count
from django.utils import timezone
from drf_spectacular.utils import OpenApiExample, extend_schema
from rest_framework import status
from rest_framework.permissions import AllowAny, IsAdminUser
from rest_framework.response import Response
from rest_framework.views import APIView
from apps.parks.models import Company as ParkCompany
from apps.parks.models import Park, ParkPhoto, ParkReview
from apps.rides.models import (
Ride,
RollerCoasterStats,
RideReview,
RidePhoto,
Company as RideCompany,
)
from apps.rides.models import (
Ride,
RidePhoto,
RideReview,
RollerCoasterStats,
)
from ..serializers.stats import StatsSerializer
@@ -103,17 +108,17 @@ class StatsAPIView(APIView):
summary="Get platform statistics",
description="""
Returns comprehensive aggregate statistics about the ThrillWiki platform.
This endpoint provides detailed counts and breakdowns of all major entities including:
- Parks, rides, and roller coasters
- Companies (manufacturers, operators, designers, property owners)
- Photos and reviews
- Ride categories (roller coasters, dark rides, flat rides, etc.)
- Status breakdowns (operating, closed, under construction, etc.)
Results are cached for 5 minutes for optimal performance and automatically
Results are cached for 5 minutes for optimal performance and automatically
invalidated when relevant data changes.
**No authentication required** - this is a public endpoint.
""".strip(),
responses={

View File

@@ -5,14 +5,15 @@ This module contains endpoints for trending and new content discovery
including trending parks, rides, and recently added content.
"""
from datetime import datetime, date
from rest_framework.views import APIView
from datetime import date, datetime
from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import OpenApiParameter, extend_schema, extend_schema_view
from rest_framework import status
from rest_framework.permissions import AllowAny, IsAdminUser
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.permissions import AllowAny, IsAdminUser
from rest_framework import status
from drf_spectacular.utils import extend_schema, extend_schema_view, OpenApiParameter
from drf_spectacular.types import OpenApiTypes
from rest_framework.views import APIView
@extend_schema_view(
@@ -111,9 +112,10 @@ class TriggerTrendingCalculationAPIView(APIView):
def post(self, request: Request) -> Response:
"""Trigger trending content calculation using management commands."""
try:
from django.core.management import call_command
import io
from contextlib import redirect_stdout, redirect_stderr
from contextlib import redirect_stderr, redirect_stdout
from django.core.management import call_command
# Capture command output
trending_output = io.StringIO()
@@ -227,10 +229,7 @@ class NewContentAPIView(APIView):
if date_added:
try:
# Parse the date string
if isinstance(date_added, str):
item_date = datetime.fromisoformat(date_added).date()
else:
item_date = date_added
item_date = datetime.fromisoformat(date_added).date() if isinstance(date_added, str) else date_added
# Calculate days difference
days_diff = (today - item_date).days