feat: Implement MFA authentication, add ride statistics model, and update various services, APIs, and tests across the application.

This commit is contained in:
pacnpal
2025-12-28 17:32:53 -05:00
parent aa56c46c27
commit c95f99ca10
452 changed files with 7948 additions and 6073 deletions

View File

@@ -14,32 +14,25 @@ Usage:
import random
from datetime import date
from decimal import Decimal
from typing import List
from django.core.management.base import BaseCommand, CommandError
from django.contrib.auth import get_user_model
from django.contrib.gis.geos import Point
from django.core.management.base import BaseCommand, CommandError
from django.db import transaction
from django.utils.text import slugify
# Import all models
from apps.accounts.models import (
User, UserProfile, UserNotification,
NotificationPreference, UserDeletionRequest
)
from apps.parks.models import (
Park, ParkLocation, ParkArea, ParkPhoto, ParkReview
)
from apps.parks.models.companies import Company as ParkCompany, CompanyHeadquarters
from apps.rides.models import (
Ride, RideModel, RollerCoasterStats, RidePhoto, RideReview, RideLocation
)
from apps.rides.models.company import Company as RideCompany
from apps.accounts.models import NotificationPreference, UserDeletionRequest, UserNotification, UserProfile
from apps.core.history import HistoricalSlug
from apps.parks.models import Park, ParkArea, ParkLocation, ParkPhoto, ParkReview
from apps.parks.models.companies import Company as ParkCompany
from apps.parks.models.companies import CompanyHeadquarters
from apps.rides.models import Ride, RideLocation, RideModel, RidePhoto, RideReview, RollerCoasterStats
from apps.rides.models.company import Company as RideCompany
# Try to import optional models that may not exist
try:
from apps.rides.models import RideModelVariant, RideModelPhoto, RideModelTechnicalSpec
from apps.rides.models import RideModelPhoto, RideModelTechnicalSpec, RideModelVariant
except ImportError:
RideModelVariant = None
RideModelPhoto = None
@@ -51,7 +44,7 @@ except ImportError:
RideRanking = None
try:
from apps.moderation.models import ModerationQueue, ModerationAction
from apps.moderation.models import ModerationAction, ModerationQueue
except ImportError:
ModerationQueue = None
ModerationAction = None
@@ -125,16 +118,16 @@ class Command(BaseCommand):
ride_models = self.create_ride_models(options['ride_models'], companies)
parks = self.create_parks(options['parks'], companies)
rides = self.create_rides(options['rides'], parks, companies, ride_models)
# Create content and interactions
self.create_reviews(options['reviews'], users, parks, rides)
self.create_notifications(users)
self.create_moderation_data(users, parks, rides)
# Create media and photos
self.create_photos(parks, rides, ride_models)
# Create rankings and statistics
self.create_rankings(rides)
@@ -146,26 +139,26 @@ class Command(BaseCommand):
def clear_data(self):
"""Clear existing data in reverse dependency order"""
self.stdout.write('🗑️ Clearing existing data...')
models_to_clear = [
# Content and interactions (clear first)
UserNotification, NotificationPreference,
ParkReview, RideReview, ModerationAction, ModerationQueue,
# Media
ParkPhoto, RidePhoto, CloudflareImage,
# Core entities
RollerCoasterStats, Ride, ParkArea, Park, ParkLocation,
RideModel, CompanyHeadquarters, ParkCompany, RideCompany,
# Users (clear last due to foreign key dependencies)
UserDeletionRequest, UserProfile, User,
# History
HistoricalSlug,
]
# Add optional models if they exist
if RideRanking:
models_to_clear.insert(4, RideRanking)
@@ -179,7 +172,7 @@ class Command(BaseCommand):
models_to_clear.insert(-6, RideModelVariant)
if ModerationQueue:
models_to_clear.insert(4, ModerationQueue)
for model in models_to_clear:
try:
count = model.objects.count()
@@ -193,12 +186,12 @@ class Command(BaseCommand):
# Continue with other models
continue
def create_users(self, count: int) -> List[User]:
def create_users(self, count: int) -> list[User]:
"""Create diverse users with comprehensive profiles"""
self.stdout.write(f'👥 Creating {count} users...')
users = []
# Create admin user if it doesn't exist
admin, created = User.objects.get_or_create(
username='admin',
@@ -216,7 +209,7 @@ class Command(BaseCommand):
admin.set_password('admin123')
admin.save()
users.append(admin)
# Create moderator if it doesn't exist
moderator, created = User.objects.get_or_create(
username='moderator',
@@ -233,7 +226,7 @@ class Command(BaseCommand):
moderator.set_password('mod123')
moderator.save()
users.append(moderator)
# Sample user data
first_names = [
'Alex', 'Jordan', 'Taylor', 'Casey', 'Morgan', 'Riley', 'Avery', 'Quinn',
@@ -241,23 +234,23 @@ class Command(BaseCommand):
'Jamie', 'Kendall', 'Logan', 'Parker', 'Peyton', 'Reese', 'Sage',
'Skyler', 'Sydney', 'Tanner'
]
last_names = [
'Smith', 'Johnson', 'Williams', 'Brown', 'Jones', 'Garcia', 'Miller',
'Davis', 'Rodriguez', 'Martinez', 'Hernandez', 'Lopez', 'Gonzalez',
'Wilson', 'Anderson', 'Thomas', 'Taylor', 'Moore', 'Jackson', 'Martin',
'Lee', 'Perez', 'Thompson', 'White', 'Harris'
]
domains = ['gmail.com', 'yahoo.com', 'hotmail.com', 'outlook.com', 'icloud.com']
# Create regular users
for i in range(count - 2): # -2 for admin and moderator
for _i in range(count - 2): # -2 for admin and moderator
first_name = random.choice(first_names)
last_name = random.choice(last_names)
username = f"{first_name.lower()}{last_name.lower()}{random.randint(1, 999)}"
email = f"{username}@{random.choice(domains)}"
user = User.objects.create_user(
username=username,
email=email,
@@ -275,7 +268,7 @@ class Command(BaseCommand):
two_factor_enabled=random.choice([True, False]),
login_notifications=random.choice([True, False]),
)
# Create detailed notification preferences
user.notification_preferences = {
'email': {
@@ -295,7 +288,7 @@ class Command(BaseCommand):
}
}
user.save()
# Create user profile with ride credits
profile = UserProfile.objects.get(user=user)
profile.bio = f"Thrill seeker from {random.choice(['California', 'Florida', 'Ohio', 'Pennsylvania', 'Texas'])}. Love roller coasters!"
@@ -304,7 +297,7 @@ class Command(BaseCommand):
profile.dark_ride_credits = random.randint(0, 100)
profile.flat_ride_credits = random.randint(0, 200)
profile.water_ride_credits = random.randint(0, 50)
# Add social media links for some users
if random.random() < 0.3:
profile.twitter = f"https://twitter.com/{username}"
@@ -312,19 +305,19 @@ class Command(BaseCommand):
profile.instagram = f"https://instagram.com/{username}"
if random.random() < 0.1:
profile.youtube = f"https://youtube.com/@{username}"
profile.save()
users.append(user)
self.stdout.write(f' ✅ Created {len(users)} users')
return users
def create_companies(self, count: int) -> List:
def create_companies(self, count: int) -> list:
"""Create companies with different roles"""
self.stdout.write(f'🏢 Creating {count} companies...')
companies = []
# Major theme park operators
operators_data = [
('Walt Disney Company', ['OPERATOR', 'PROPERTY_OWNER'], 1923, 'Burbank, CA, USA'),
@@ -335,7 +328,7 @@ class Command(BaseCommand):
('Busch Gardens', ['OPERATOR'], 1959, 'Tampa, FL, USA'),
('Knott\'s Berry Farm', ['OPERATOR'], 1920, 'Buena Park, CA, USA'),
]
# Major ride manufacturers
manufacturers_data = [
('Bolliger & Mabillard', ['MANUFACTURER'], 1988, 'Monthey, Switzerland'),
@@ -347,16 +340,16 @@ class Command(BaseCommand):
('Premier Rides', ['MANUFACTURER'], 1994, 'Baltimore, MD, USA'),
('S&S Worldwide', ['MANUFACTURER'], 1994, 'Logan, UT, USA'),
]
# Ride designers
designers_data = [
('Werner Stengel', ['DESIGNER'], 1965, 'Munich, Germany'),
('Alan Schilke', ['DESIGNER'], 1990, 'Hayden, ID, USA'),
('John Wardley', ['DESIGNER'], 1970, 'London, UK'),
]
all_company_data = operators_data + manufacturers_data + designers_data
for name, roles, founded_year, location in all_company_data:
# Determine which Company model to use based on roles
if 'OPERATOR' in roles or 'PROPERTY_OWNER' in roles:
@@ -387,7 +380,7 @@ class Command(BaseCommand):
'coasters_count': random.randint(5, 100) if 'MANUFACTURER' in roles else 0,
}
)
# Create headquarters if company was created and is a ParkCompany
if created and isinstance(company, ParkCompany):
city, state_country = location.rsplit(', ', 1)
@@ -397,7 +390,7 @@ class Command(BaseCommand):
else:
state = ''
country = state_country
CompanyHeadquarters.objects.get_or_create(
company=company,
defaults={
@@ -408,16 +401,16 @@ class Command(BaseCommand):
'postal_code': f"{random.randint(10000, 99999)}" if country == 'USA' else '',
}
)
companies.append(company)
# Create additional random companies to reach the target count
company_types = ['Theme Parks', 'Amusements', 'Entertainment', 'Rides', 'Design', 'Engineering']
for i in range(len(all_company_data), count):
for _i in range(len(all_company_data), count):
company_type = random.choice(company_types)
name = f"{random.choice(['Global', 'International', 'Premier', 'Elite', 'Advanced', 'Creative'])} {company_type} {'Group' if random.random() < 0.5 else 'Corporation'}"
roles = []
if 'Theme Parks' in name or 'Amusements' in name:
roles = ['OPERATOR']
@@ -429,7 +422,7 @@ class Command(BaseCommand):
roles = ['DESIGNER']
else:
roles = [random.choice(['OPERATOR', 'MANUFACTURER', 'DESIGNER'])]
# Use appropriate company model based on roles
if 'OPERATOR' in roles or 'PROPERTY_OWNER' in roles:
company = ParkCompany.objects.create(
@@ -453,12 +446,12 @@ class Command(BaseCommand):
rides_count=random.randint(5, 100) if 'MANUFACTURER' in roles else 0,
coasters_count=random.randint(2, 50) if 'MANUFACTURER' in roles else 0,
)
# Create headquarters
cities = ['Los Angeles', 'New York', 'Chicago', 'Houston', 'Phoenix', 'Philadelphia', 'San Antonio', 'San Diego', 'Dallas', 'San Jose']
states = ['CA', 'NY', 'IL', 'TX', 'AZ', 'PA', 'TX', 'CA', 'TX', 'CA']
city_state = random.choice(list(zip(cities, states)))
city_state = random.choice(list(zip(cities, states, strict=False)))
CompanyHeadquarters.objects.create(
company=company,
city=city_state[0],
@@ -467,23 +460,23 @@ class Command(BaseCommand):
street_address=f"{random.randint(100, 9999)} {random.choice(['Business', 'Corporate', 'Industry', 'Commerce'])} {random.choice(['Pkwy', 'Blvd', 'Dr', 'Way'])}",
postal_code=f"{random.randint(10000, 99999)}",
)
companies.append(company)
self.stdout.write(f' ✅ Created {len(companies)} companies')
return companies
def create_ride_models(self, count: int, companies: List) -> List[RideModel]:
def create_ride_models(self, count: int, companies: list) -> list[RideModel]:
"""Create ride models from manufacturers"""
self.stdout.write(f'🎢 Creating {count} ride models...')
manufacturers = [c for c in companies if 'MANUFACTURER' in c.roles]
if not manufacturers:
self.stdout.write(' ⚠️ No manufacturers found, skipping ride models')
return []
ride_models = []
# Famous ride models
famous_models = [
('Dive Coaster', 'RC', 'Bolliger & Mabillard', 'Vertical drop roller coaster with holding brake'),
@@ -507,12 +500,12 @@ class Command(BaseCommand):
('Drop Tower', 'FR', 'Intamin', 'Vertical drop ride'),
('Gyro Drop', 'FR', 'Intamin', 'Tilting drop tower'),
]
for model_name, category, manufacturer_name, description in famous_models:
manufacturer = next((c for c in manufacturers if manufacturer_name in c.name), None)
if not manufacturer:
manufacturer = random.choice(manufacturers)
ride_model, created = RideModel.objects.get_or_create(
name=model_name,
manufacturer=manufacturer,
@@ -536,7 +529,7 @@ class Command(BaseCommand):
'total_installations': random.randint(1, 50),
}
)
# Create technical specs if model exists
if category == 'RC' and RideModelTechnicalSpec:
specs = [
@@ -545,7 +538,7 @@ class Command(BaseCommand):
('CAPACITY', 'Riders per Train', f"{random.randint(20, 32)}", 'people'),
('SAFETY', 'Block Zones', f"{random.randint(4, 8)}", 'zones'),
]
for spec_category, spec_name, spec_value, spec_unit in specs:
RideModelTechnicalSpec.objects.create(
ride_model=ride_model,
@@ -554,31 +547,31 @@ class Command(BaseCommand):
spec_value=spec_value,
spec_unit=spec_unit,
)
# Create variants for some models if model exists
if random.random() < 0.3 and RideModelVariant:
variant_names = ['Compact', 'Extended', 'Family', 'Extreme', 'Custom']
variant_name = random.choice(variant_names)
RideModelVariant.objects.create(
ride_model=ride_model,
name=f"{variant_name} Version",
description=f"Modified version of {model_name} for {variant_name.lower()} installations",
distinguishing_features=f"Optimized for {variant_name.lower()} market segment",
)
ride_models.append(ride_model)
# Create additional random models
model_types = ['Coaster', 'Ride', 'System', 'Experience', 'Adventure']
prefixes = ['Mega', 'Super', 'Ultra', 'Hyper', 'Giga', 'Extreme', 'Family', 'Junior']
for i in range(len(famous_models), count):
for _i in range(len(famous_models), count):
manufacturer = random.choice(manufacturers)
category = random.choice(['RC', 'DR', 'FR', 'WR', 'TR'])
model_name = f"{random.choice(prefixes)} {random.choice(model_types)}"
ride_model = RideModel.objects.create(
name=model_name,
manufacturer=manufacturer,
@@ -606,31 +599,31 @@ class Command(BaseCommand):
]),
total_installations=random.randint(0, 25),
)
ride_models.append(ride_model)
self.stdout.write(f' ✅ Created {len(ride_models)} ride models')
return ride_models
def create_parks(self, count: int, companies: List) -> List[Park]:
def create_parks(self, count: int, companies: list) -> list[Park]:
"""Create parks with locations and areas"""
self.stdout.write(f'🏰 Creating {count} parks...')
if count == 0:
self.stdout.write(' Skipping park creation (count = 0)')
return []
operators = [c for c in companies if 'OPERATOR' in c.roles]
property_owners = [c for c in companies if 'PROPERTY_OWNER' in c.roles]
if not operators:
raise CommandError('No operators found. Create companies first.')
parks = []
# Famous theme parks with timezone information
famous_parks = [
('Magic Kingdom', 'Walt Disney World\'s flagship theme park', 'THEME_PARK', 'OPERATING',
('Magic Kingdom', 'Walt Disney World\'s flagship theme park', 'THEME_PARK', 'OPERATING',
date(1971, 10, 1), 107, 'Orlando', 'FL', 'USA', 28.4177, -81.5812, 'America/New_York'),
('Disneyland', 'The original Disney theme park', 'THEME_PARK', 'OPERATING',
date(1955, 7, 17), 85, 'Anaheim', 'CA', 'USA', 33.8121, -117.9190, 'America/Los_Angeles'),
@@ -647,7 +640,7 @@ class Command(BaseCommand):
('SeaWorld Orlando', 'Marine life theme park', 'THEME_PARK', 'OPERATING',
date(1973, 12, 15), 200, 'Orlando', 'FL', 'USA', 28.4110, -81.4610, 'America/New_York'),
]
for park_name, description, park_type, status, opening_date, size_acres, city, state, country, lat, lng, timezone_str in famous_parks:
# Find appropriate operator
operator = None
@@ -665,15 +658,15 @@ class Command(BaseCommand):
operator = next((c for c in operators if 'Busch' in c.name), None)
elif 'SeaWorld' in park_name:
operator = next((c for c in operators if 'SeaWorld' in c.name), None)
if not operator:
operator = random.choice(operators)
# Find property owner (could be same as operator)
property_owner = None
if property_owners and random.random() < 0.7:
property_owner = random.choice(property_owners)
# Use get_or_create to avoid duplicates
park, created = Park.objects.get_or_create(
name=park_name,
@@ -693,14 +686,14 @@ class Command(BaseCommand):
)
if not created:
self.stdout.write(f' Using existing park: {park_name}')
# Create park location only if it doesn't exist
location_exists = False
try:
location_exists = hasattr(park, 'location') and park.location is not None
except Exception:
location_exists = False
if created or not location_exists:
ParkLocation.objects.get_or_create(
park=park,
@@ -713,7 +706,7 @@ class Command(BaseCommand):
'postal_code': f"{random.randint(10000, 99999)}" if country == 'USA' else '',
}
)
# Create park areas only if park was created
if created:
area_names = ['Main Street', 'Fantasyland', 'Tomorrowland', 'Adventureland', 'Frontierland']
@@ -725,9 +718,9 @@ class Command(BaseCommand):
'description': f"Themed area within {park_name}",
}
)
parks.append(park)
# Create additional random parks
park_types = ['THEME_PARK', 'AMUSEMENT_PARK', 'WATER_PARK', 'FAMILY_ENTERTAINMENT_CENTER']
cities_data = [
@@ -740,28 +733,28 @@ class Command(BaseCommand):
('San Antonio', 'TX', 'USA', 29.4241, -98.4936),
('San Diego', 'CA', 'USA', 32.7157, -117.1611),
]
for i in range(len(famous_parks), count):
park_type = random.choice(park_types)
# Make park names more unique by adding a number
park_name = f"{random.choice(['Adventure', 'Magic', 'Wonder', 'Fantasy', 'Thrill', 'Family'])} {random.choice(['World', 'Land', 'Park', 'Kingdom', 'Gardens'])} {i + 1}"
operator = random.choice(operators)
property_owner = random.choice(property_owners) if property_owners and random.random() < 0.5 else None
city, state, country, lat, lng = random.choice(cities_data)
# Determine timezone based on state
timezone_map = {
'CA': 'America/Los_Angeles',
'NY': 'America/New_York',
'NY': 'America/New_York',
'IL': 'America/Chicago',
'TX': 'America/Chicago',
'AZ': 'America/Phoenix',
'PA': 'America/New_York',
}
park_timezone = timezone_map.get(state, 'America/New_York')
park = Park.objects.create(
name=park_name,
description=f"Exciting {park_type.lower().replace('_', ' ')} featuring thrilling rides and family entertainment",
@@ -776,11 +769,11 @@ class Command(BaseCommand):
coaster_count=random.randint(2, 15),
timezone=park_timezone,
)
# Create park location with slight coordinate variation
lat_offset = random.uniform(-0.1, 0.1)
lng_offset = random.uniform(-0.1, 0.1)
ParkLocation.objects.create(
park=park,
point=Point(lng + lng_offset, lat + lat_offset),
@@ -790,7 +783,7 @@ class Command(BaseCommand):
country=country,
postal_code=f"{random.randint(10000, 99999)}",
)
# Create park areas
area_names = ['Main Plaza', 'Adventure Zone', 'Family Area', 'Thrill Section', 'Water World', 'Kids Corner']
for area_name in random.sample(area_names, random.randint(2, 4)):
@@ -799,25 +792,25 @@ class Command(BaseCommand):
name=area_name,
description=f"Themed area within {park_name}",
)
parks.append(park)
self.stdout.write(f' ✅ Created {len(parks)} parks')
return parks
def create_rides(self, count: int, parks: List[Park], companies: List, ride_models: List[RideModel]) -> List[Ride]:
def create_rides(self, count: int, parks: list[Park], companies: list, ride_models: list[RideModel]) -> list[Ride]:
"""Create rides with comprehensive details"""
self.stdout.write(f'🎠 Creating {count} rides...')
if not parks:
self.stdout.write(' ⚠️ No parks found, skipping rides')
return []
manufacturers = [c for c in companies if 'MANUFACTURER' in c.roles]
designers = [c for c in companies if 'DESIGNER' in c.roles]
rides = []
# Famous roller coasters
famous_coasters = [
('Steel Vengeance', 'RC', 'Hybrid steel-wood roller coaster', 'Rocky Mountain Construction'),
@@ -831,7 +824,7 @@ class Command(BaseCommand):
('Twisted Timbers', 'RC', 'RMC conversion of wooden coaster', 'Rocky Mountain Construction'),
('Goliath', 'RC', 'Hyper coaster with massive drops', 'Bolliger & Mabillard'),
]
# Create famous coasters
for coaster_name, category, description, manufacturer_name in famous_coasters:
park = random.choice(parks)
@@ -840,14 +833,14 @@ class Command(BaseCommand):
manufacturer = next((c for c in manufacturers if manufacturer_name in c.name), None)
if not manufacturer and manufacturers:
manufacturer = random.choice(manufacturers)
designer = random.choice(designers) if designers and random.random() < 0.3 else None
ride_model = random.choice(ride_models) if ride_models and random.random() < 0.5 else None
# Get park areas for this park
park_areas = list(park.areas.all())
park_area = random.choice(park_areas) if park_areas else None
ride = Ride.objects.create(
name=coaster_name,
description=description,
@@ -864,7 +857,7 @@ class Command(BaseCommand):
ride_duration_seconds=random.randint(90, 240),
average_rating=Decimal(str(random.uniform(7.0, 9.5))),
)
# Create roller coaster stats
if category == 'RC':
RollerCoasterStats.objects.create(
@@ -884,9 +877,9 @@ class Command(BaseCommand):
cars_per_train=random.randint(6, 8),
seats_per_car=random.randint(2, 4),
)
rides.append(ride)
# Create additional random rides
ride_names = [
'Thunder Mountain', 'Space Coaster', 'Wild Eagle', 'Dragon Fire', 'Phoenix Rising',
@@ -894,21 +887,21 @@ class Command(BaseCommand):
'Viper', 'Cobra', 'Rattlesnake', 'Sidewinder', 'Diamondback', 'Copperhead',
'Banshee', 'Valkyrie', 'Griffon', 'Falcon', 'Eagle\'s Flight', 'Soaring Heights'
]
categories = ['RC', 'DR', 'FR', 'WR', 'TR', 'OT']
for i in range(len(famous_coasters), count):
for _i in range(len(famous_coasters), count):
park = random.choice(parks)
park_areas = list(park.areas.all())
park_area = random.choice(park_areas) if park_areas else None
ride_name = random.choice(ride_names)
category = random.choice(categories)
manufacturer = random.choice(manufacturers) if manufacturers and random.random() < 0.7 else None
designer = random.choice(designers) if designers and random.random() < 0.2 else None
ride_model = random.choice(ride_models) if ride_models and random.random() < 0.4 else None
ride = Ride.objects.create(
name=ride_name,
description=f"Exciting {category} ride with thrilling elements and smooth operation",
@@ -925,7 +918,7 @@ class Command(BaseCommand):
ride_duration_seconds=random.randint(60, 300),
average_rating=Decimal(str(random.uniform(6.0, 9.0))),
)
# Create roller coaster stats for RC category
if category == 'RC':
RollerCoasterStats.objects.create(
@@ -945,20 +938,20 @@ class Command(BaseCommand):
cars_per_train=random.randint(4, 8),
seats_per_car=random.randint(2, 4),
)
rides.append(ride)
self.stdout.write(f' ✅ Created {len(rides)} rides')
return rides
def create_reviews(self, count: int, users: List[User], parks: List[Park], rides: List[Ride]) -> None:
def create_reviews(self, count: int, users: list[User], parks: list[Park], rides: list[Ride]) -> None:
"""Create park and ride reviews"""
self.stdout.write(f'📝 Creating {count} reviews...')
if not users or (not parks and not rides):
self.stdout.write(' ⚠️ No users or content found, skipping reviews')
return
review_texts = [
"Amazing experience! The rides were thrilling and the staff was very friendly.",
"Great park with excellent theming. The roller coasters are world-class.",
@@ -971,21 +964,21 @@ class Command(BaseCommand):
"Family-friendly atmosphere with rides for all ages.",
"Outstanding park operations and friendly staff throughout.",
]
# Create park reviews
park_review_count = count // 2
created_park_reviews = 0
attempts = 0
max_attempts = park_review_count * 3 # Allow multiple attempts to avoid infinite loops
while created_park_reviews < park_review_count and attempts < max_attempts:
if not parks:
break
user = random.choice(users)
park = random.choice(parks)
attempts += 1
# Use get_or_create to avoid duplicates
review, created = ParkReview.objects.get_or_create(
user=user,
@@ -1002,24 +995,24 @@ class Command(BaseCommand):
),
}
)
if created:
created_park_reviews += 1
# Create ride reviews
ride_review_count = count - created_park_reviews
created_ride_reviews = 0
attempts = 0
max_attempts = ride_review_count * 3 # Allow multiple attempts to avoid infinite loops
while created_ride_reviews < ride_review_count and attempts < max_attempts:
if not rides:
break
user = random.choice(users)
ride = random.choice(rides)
attempts += 1
# Use get_or_create to avoid duplicates
review, created = RideReview.objects.get_or_create(
user=user,
@@ -1036,36 +1029,36 @@ class Command(BaseCommand):
),
}
)
if created:
created_ride_reviews += 1
self.stdout.write(f' ✅ Created {count} reviews')
def create_notifications(self, users: List[User]) -> None:
def create_notifications(self, users: list[User]) -> None:
"""Create sample notifications for users"""
self.stdout.write('🔔 Creating notifications...')
if not users:
self.stdout.write(' ⚠️ No users found, skipping notifications')
return
notification_count = 0
notification_types = [
("submission_approved", "Your park submission has been approved!", "Great news! Your submission for Adventure Park has been approved and is now live."),
("review_helpful", "Someone found your review helpful", "Your review of Steel Vengeance was marked as helpful by another user."),
("system_announcement", "New features available", "Check out our new ride comparison tool and enhanced search filters."),
("achievement_unlocked", "Achievement unlocked!", "Congratulations! You've unlocked the 'Coaster Enthusiast' achievement."),
]
# Create notifications for random users
for user in random.sample(users, min(len(users), 15)):
for _ in range(random.randint(1, 3)):
notification_type, title, message = random.choice(notification_types)
UserNotification.objects.create(
user=user,
notification_type=notification_type,
@@ -1077,50 +1070,50 @@ class Command(BaseCommand):
push_sent=random.choice([True, False]),
)
notification_count += 1
self.stdout.write(f' ✅ Created {notification_count} notifications')
def create_moderation_data(self, users: List[User], parks: List[Park], rides: List[Ride]) -> None:
def create_moderation_data(self, users: list[User], parks: list[Park], rides: list[Ride]) -> None:
"""Create moderation queue and actions"""
self.stdout.write('🛡️ Creating moderation data...')
if not ModerationQueue or not ModerationAction:
self.stdout.write(' ⚠️ Moderation models not available, skipping')
return
if not users or (not parks and not rides):
self.stdout.write(' ⚠️ No users or content found, skipping moderation data')
return
# This would create sample moderation queue items and actions
# Implementation depends on the actual moderation models structure
self.stdout.write(' ✅ Moderation data creation skipped (models not fully defined)')
def create_photos(self, parks: List[Park], rides: List[Ride], ride_models: List[RideModel]) -> None:
def create_photos(self, parks: list[Park], rides: list[Ride], ride_models: list[RideModel]) -> None:
"""Create sample photo records"""
self.stdout.write('📸 Creating photo records...')
if not CloudflareImage:
self.stdout.write(' ⚠️ CloudflareImage model not available, skipping photo creation')
return
# Since we don't have actual Cloudflare images, we'll skip photo creation
# In a real scenario, you would need actual CloudflareImage instances
self.stdout.write(' ⚠️ Photo creation skipped (requires actual CloudflareImage instances)')
self.stdout.write(' To create photos, you need to upload actual images to Cloudflare first')
def create_rankings(self, rides: List[Ride]) -> None:
def create_rankings(self, rides: list[Ride]) -> None:
"""Create ride rankings if model exists"""
self.stdout.write('🏆 Creating ride rankings...')
if not RideRanking:
self.stdout.write(' ⚠️ RideRanking model not available, skipping')
return
if not rides:
self.stdout.write(' ⚠️ No rides found, skipping rankings')
return
# This would create sample ride rankings
# Implementation depends on the actual RideRanking model structure
self.stdout.write(' ✅ Ride rankings creation skipped (model structure not fully defined)')
@@ -1129,7 +1122,7 @@ class Command(BaseCommand):
"""Print a summary of created data"""
self.stdout.write('\n📊 Data Seeding Summary:')
self.stdout.write('=' * 50)
# Count all created objects
counts = {
'Users': User.objects.count(),
@@ -1145,9 +1138,9 @@ class Command(BaseCommand):
'Park Photos': ParkPhoto.objects.count(),
'Ride Photos': RidePhoto.objects.count(),
}
for model_name, count in counts.items():
self.stdout.write(f' {model_name}: {count}')
self.stdout.write('=' * 50)
self.stdout.write('🎉 Seeding completed! Your ThrillWiki database is ready for testing.')

View File

@@ -1,4 +1,4 @@
from django.urls import path, include
from django.urls import include, path
urlpatterns = [
path("v1/", include("apps.api.v1.urls")),

View File

@@ -1,5 +1,6 @@
from rest_framework import serializers
from drf_spectacular.utils import extend_schema_field
from rest_framework import serializers
from apps.accounts.models import UserProfile
from apps.accounts.serializers import UserSerializer # existing shared user serializer
@@ -24,7 +25,7 @@ class UserProfileUpdateInputSerializer(serializers.ModelSerializer):
from django_cloudflareimages_toolkit.models import CloudflareImage
image, _ = CloudflareImage.objects.get_or_create(cloudflare_id=cloudflare_id)
instance.avatar = image
return super().update(instance, validated_data)

View File

@@ -2,8 +2,14 @@
URL configuration for user account management API endpoints.
"""
from django.urls import path
from . import views
from django.urls import include, path
from rest_framework.routers import DefaultRouter
from . import views, views_credits, views_magic_link
# Register ViewSets
router = DefaultRouter()
router.register(r"credits", views_credits.RideCreditViewSet, basename="ride-credit")
urlpatterns = [
# Admin endpoints for user management
@@ -108,19 +114,18 @@ urlpatterns = [
path("profile/avatar/upload/", views.upload_avatar, name="upload_avatar"),
path("profile/avatar/save/", views.save_avatar_image, name="save_avatar_image"),
path("profile/avatar/delete/", views.delete_avatar, name="delete_avatar"),
# Login history endpoint
path("login-history/", views.get_login_history, name="get_login_history"),
# Magic Link (Login by Code) endpoints
path("magic-link/request/", views_magic_link.request_magic_link, name="request_magic_link"),
path("magic-link/verify/", views_magic_link.verify_magic_link, name="verify_magic_link"),
# Public Profile
path("profiles/<str:username>/", views.get_public_user_profile, name="get_public_user_profile"),
]
# Register ViewSets
from rest_framework.routers import DefaultRouter
from . import views_credits
from django.urls import include
router = DefaultRouter()
router.register(r"credits", views_credits.RideCreditViewSet, basename="ride-credit")
urlpatterns += [
# ViewSet routes
path("", include(router.urls)),
]

View File

@@ -6,43 +6,44 @@ user deletion while preserving submissions, profile management, settings,
preferences, privacy, notifications, and security.
"""
from apps.api.v1.serializers.accounts import (
CompleteUserSerializer,
PublicUserSerializer,
UserPreferencesSerializer,
NotificationSettingsSerializer,
PrivacySettingsSerializer,
SecuritySettingsSerializer,
UserStatisticsSerializer,
UserListSerializer,
AccountUpdateSerializer,
ProfileUpdateSerializer,
ThemePreferenceSerializer,
UserNotificationSerializer,
NotificationPreferenceSerializer,
MarkNotificationsReadSerializer,
AvatarUploadSerializer,
)
from apps.accounts.services import UserDeletionService
from apps.accounts.export_service import UserExportService
from apps.accounts.models import (
User,
UserProfile,
UserNotification,
NotificationPreference,
)
from apps.lists.models import UserList
import logging
from rest_framework import status
from rest_framework.decorators import api_view, permission_classes
from rest_framework.permissions import IsAuthenticated, IsAdminUser
from rest_framework.response import Response
from drf_spectacular.utils import extend_schema, OpenApiParameter
from drf_spectacular.types import OpenApiTypes
from django.shortcuts import get_object_or_404
from rest_framework.permissions import AllowAny
from django.utils import timezone
from django_cloudflareimages_toolkit.models import CloudflareImage
from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import OpenApiParameter, extend_schema
from rest_framework import status
from rest_framework.decorators import api_view, permission_classes
from rest_framework.permissions import AllowAny, IsAdminUser, IsAuthenticated
from rest_framework.response import Response
from apps.accounts.export_service import UserExportService
from apps.accounts.models import (
NotificationPreference,
User,
UserNotification,
UserProfile,
)
from apps.accounts.services import UserDeletionService
from apps.api.v1.serializers.accounts import (
AccountUpdateSerializer,
AvatarUploadSerializer,
CompleteUserSerializer,
MarkNotificationsReadSerializer,
NotificationPreferenceSerializer,
NotificationSettingsSerializer,
PrivacySettingsSerializer,
ProfileUpdateSerializer,
PublicUserSerializer,
SecuritySettingsSerializer,
ThemePreferenceSerializer,
UserListSerializer,
UserNotificationSerializer,
UserPreferencesSerializer,
UserStatisticsSerializer,
)
from apps.lists.models import UserList
# Set up logging
logger = logging.getLogger(__name__)
@@ -307,7 +308,7 @@ def save_avatar_image(request):
try:
cloudflare_image = CloudflareImage.objects.get(
cloudflare_id=cloudflare_image_id)
# Update existing record with latest data from Cloudflare
cloudflare_image.status = 'uploaded'
cloudflare_image.uploaded_at = timezone.now()
@@ -319,7 +320,7 @@ def save_avatar_image(request):
cloudflare_image.height = image_data.get('height')
cloudflare_image.format = image_data.get('format', '')
cloudflare_image.save()
except CloudflareImage.DoesNotExist:
# Create new CloudflareImage record from API response
cloudflare_image = CloudflareImage.objects.create(
@@ -367,7 +368,7 @@ def save_avatar_image(request):
except Exception as e:
logger.error(f"Failed to delete old avatar from Cloudflare: {str(e)}")
# Continue with database deletion even if Cloudflare deletion fails
old_avatar.delete()
# Debug logging to see what's happening with the CloudflareImage
@@ -442,7 +443,7 @@ def delete_avatar(request):
avatar_to_delete = profile.avatar
profile.avatar = None
profile.save()
# Delete from Cloudflare first, then from database
try:
from django_cloudflareimages_toolkit.services import CloudflareImagesService
@@ -452,7 +453,7 @@ def delete_avatar(request):
except Exception as e:
logger.error(f"Failed to delete avatar from Cloudflare: {str(e)}")
# Continue with database deletion even if Cloudflare deletion fails
avatar_to_delete.delete()
# Get the default avatar URL
@@ -1273,10 +1274,10 @@ def update_security_settings(request):
# Handle security settings updates
if "two_factor_enabled" in request.data:
setattr(user, "two_factor_enabled", request.data["two_factor_enabled"])
user.two_factor_enabled = request.data["two_factor_enabled"]
if "login_notifications" in request.data:
setattr(user, "login_notifications", request.data["login_notifications"])
user.login_notifications = request.data["login_notifications"]
user.save()
@@ -1612,7 +1613,7 @@ def export_user_data(request):
except Exception as e:
logger.error(f"Error exporting data for user {request.user.id}: {e}", exc_info=True)
return Response(
{"error": "Failed to generate data export"},
{"error": "Failed to generate data export"},
status=status.HTTP_500_INTERNAL_SERVER_ERROR
)
@@ -1636,54 +1637,73 @@ def get_public_user_profile(request, username):
return Response(serializer.data, status=status.HTTP_200_OK)
# === MISSING FUNCTION IMPLEMENTATIONS ===
@extend_schema(
operation_id="request_account_deletion",
summary="Request account deletion",
description="Request deletion of the authenticated user's account.",
operation_id="get_login_history",
summary="Get user login history",
description=(
"Returns the authenticated user's recent login history including "
"IP addresses, devices, and timestamps for security auditing."
),
parameters=[
OpenApiParameter(
name="limit",
type=OpenApiTypes.INT,
location=OpenApiParameter.QUERY,
description="Maximum number of entries to return (default: 20, max: 100)",
),
],
responses={
200: {"description": "Deletion request created"},
400: {"description": "Cannot delete account"},
},
tags=["Self-Service Account Management"],
)
@api_view(["POST"])
@permission_classes([IsAuthenticated])
def request_account_deletion(request):
"""Request account deletion."""
try:
user = request.user
# Check if user can be deleted
can_delete, reason = UserDeletionService.can_delete_user(user)
if not can_delete:
return Response(
{"success": False, "error": reason},
status=status.HTTP_400_BAD_REQUEST,
)
# Create deletion request
deletion_request = UserDeletionService.create_deletion_request(user)
return Response(
{
"success": True,
"message": "Verification code sent to your email",
"expires_at": deletion_request.expires_at,
"email": user.email,
200: {
"description": "Login history entries",
"example": {
"results": [
{
"id": 1,
"ip_address": "192.168.1.1",
"user_agent": "Mozilla/5.0...",
"login_method": "PASSWORD",
"login_method_display": "Password",
"login_timestamp": "2024-12-27T10:30:00Z",
"country": "United States",
"city": "New York",
}
],
"count": 1,
},
status=status.HTTP_200_OK,
)
},
401: {"description": "Authentication required"},
},
tags=["User Security"],
)
@api_view(["GET"])
@permission_classes([IsAuthenticated])
def get_login_history(request):
"""Get user login history for security auditing."""
from apps.accounts.login_history import LoginHistory
user = request.user
limit = min(int(request.query_params.get("limit", 20)), 100)
# Get login history for user
entries = LoginHistory.objects.filter(user=user).order_by("-login_timestamp")[:limit]
# Serialize
results = []
for entry in entries:
results.append({
"id": entry.id,
"ip_address": entry.ip_address,
"user_agent": entry.user_agent[:100] if entry.user_agent else None, # Truncate long user agents
"login_method": entry.login_method,
"login_method_display": dict(LoginHistory._meta.get_field('login_method').choices).get(entry.login_method, entry.login_method),
"login_timestamp": entry.login_timestamp.isoformat(),
"country": entry.country,
"city": entry.city,
"success": entry.success,
})
return Response({
"results": results,
"count": len(results),
})
except ValueError as e:
return Response(
{"success": False, "error": str(e)},
status=status.HTTP_400_BAD_REQUEST,
)
except Exception as e:
return Response(
{"success": False, "error": f"Error creating deletion request: {str(e)}"},
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
)

View File

@@ -1,9 +1,14 @@
from rest_framework import viewsets, permissions, filters
from django.db import transaction
from django_filters.rest_framework import DjangoFilterBackend
from apps.rides.models.credits import RideCredit
from apps.api.v1.serializers.ride_credits import RideCreditSerializer
from drf_spectacular.utils import extend_schema, OpenApiParameter
from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import OpenApiParameter, extend_schema
from rest_framework import filters, permissions, status, viewsets
from rest_framework.decorators import action
from rest_framework.response import Response
from apps.api.v1.serializers.ride_credits import RideCreditSerializer
from apps.rides.models.credits import RideCredit
class RideCreditViewSet(viewsets.ModelViewSet):
"""
@@ -14,8 +19,8 @@ class RideCreditViewSet(viewsets.ModelViewSet):
permission_classes = [permissions.IsAuthenticatedOrReadOnly]
filter_backends = [DjangoFilterBackend, filters.OrderingFilter]
filterset_fields = ['user__username', 'ride__park__slug', 'ride__manufacturer__slug']
ordering_fields = ['first_ridden_at', 'last_ridden_at', 'created_at', 'count', 'rating']
ordering = ['-last_ridden_at']
ordering_fields = ['first_ridden_at', 'last_ridden_at', 'created_at', 'count', 'rating', 'display_order']
ordering = ['display_order', '-last_ridden_at']
def get_queryset(self):
"""
@@ -23,18 +28,77 @@ class RideCreditViewSet(viewsets.ModelViewSet):
Optionally filter by user via query param ?user=username
"""
queryset = RideCredit.objects.all().select_related('ride', 'ride__park', 'user')
# Filter by user if provided
username = self.request.query_params.get('user')
if username:
queryset = queryset.filter(user__username=username)
return queryset
def perform_create(self, serializer):
"""Associate the current user with the ride credit."""
serializer.save(user=self.request.user)
@action(detail=False, methods=['post'], permission_classes=[permissions.IsAuthenticated])
@extend_schema(
summary="Reorder ride credits",
description="Bulk update the display order of ride credits. Send a list of {id, order} objects.",
request={
'application/json': {
'type': 'object',
'properties': {
'order': {
'type': 'array',
'items': {
'type': 'object',
'properties': {
'id': {'type': 'integer'},
'order': {'type': 'integer'}
},
'required': ['id', 'order']
}
}
}
}
}
)
def reorder(self, request):
"""
Bulk update display_order for multiple credits.
Expects: {"order": [{"id": 1, "order": 0}, {"id": 2, "order": 1}, ...]}
"""
order_data = request.data.get('order', [])
if not order_data:
return Response(
{'error': 'No order data provided'},
status=status.HTTP_400_BAD_REQUEST
)
# Validate that all credits belong to the current user
credit_ids = [item['id'] for item in order_data]
user_credits = RideCredit.objects.filter(
id__in=credit_ids,
user=request.user
).values_list('id', flat=True)
if set(credit_ids) != set(user_credits):
return Response(
{'error': 'You can only reorder your own credits'},
status=status.HTTP_403_FORBIDDEN
)
# Bulk update in a transaction
with transaction.atomic():
for item in order_data:
RideCredit.objects.filter(
id=item['id'],
user=request.user
).update(display_order=item['order'])
return Response({'status': 'reordered', 'count': len(order_data)})
@extend_schema(
summary="List ride credits",
description="List ride credits. filter by user username.",
@@ -49,3 +113,4 @@ class RideCreditViewSet(viewsets.ModelViewSet):
)
def list(self, request, *args, **kwargs):
return super().list(request, *args, **kwargs)

View File

@@ -0,0 +1,180 @@
"""
Magic Link (Login by Code) API views.
Provides API endpoints for passwordless login via email code.
Uses django-allauth's built-in login-by-code functionality.
"""
from django.conf import settings
from drf_spectacular.utils import OpenApiExample, extend_schema
from rest_framework import status
from rest_framework.decorators import api_view, permission_classes
from rest_framework.permissions import AllowAny
from rest_framework.response import Response
try:
from allauth.account.internal.flows.login_by_code import perform_login_by_code, request_login_code
from allauth.account.models import EmailAddress
from allauth.account.utils import user_email # noqa: F401 - imported to verify availability
HAS_LOGIN_BY_CODE = True
except ImportError:
HAS_LOGIN_BY_CODE = False
@extend_schema(
summary="Request magic link login code",
description="Send a one-time login code to the user's email address.",
request={
'application/json': {
'type': 'object',
'properties': {
'email': {'type': 'string', 'format': 'email'}
},
'required': ['email']
}
},
responses={
200: {'description': 'Login code sent successfully'},
400: {'description': 'Invalid email or feature disabled'},
},
examples=[
OpenApiExample(
'Request login code',
value={'email': 'user@example.com'},
request_only=True
)
]
)
@api_view(['POST'])
@permission_classes([AllowAny])
def request_magic_link(request):
"""
Request a login code to be sent to the user's email.
This is the first step of the magic link flow:
1. User enters their email
2. If the email exists, a code is sent
3. User enters the code to complete login
"""
if not getattr(settings, 'ACCOUNT_LOGIN_BY_CODE_ENABLED', False):
return Response(
{'error': 'Magic link login is not enabled'},
status=status.HTTP_400_BAD_REQUEST
)
if not HAS_LOGIN_BY_CODE:
return Response(
{'error': 'Login by code is not available in this version of allauth'},
status=status.HTTP_400_BAD_REQUEST
)
email = request.data.get('email', '').lower().strip()
if not email:
return Response(
{'error': 'Email is required'},
status=status.HTTP_400_BAD_REQUEST
)
# Check if email exists (don't reveal if it doesn't for security)
try:
email_address = EmailAddress.objects.get(email__iexact=email, verified=True)
user = email_address.user
# Request the login code
request_login_code(request._request, user)
return Response({
'success': True,
'message': 'If an account exists with this email, a login code has been sent.',
'timeout': getattr(settings, 'ACCOUNT_LOGIN_BY_CODE_TIMEOUT', 300)
})
except EmailAddress.DoesNotExist:
# Don't reveal that the email doesn't exist
return Response({
'success': True,
'message': 'If an account exists with this email, a login code has been sent.',
'timeout': getattr(settings, 'ACCOUNT_LOGIN_BY_CODE_TIMEOUT', 300)
})
@extend_schema(
summary="Verify magic link code",
description="Verify the login code and complete the login process.",
request={
'application/json': {
'type': 'object',
'properties': {
'email': {'type': 'string', 'format': 'email'},
'code': {'type': 'string'}
},
'required': ['email', 'code']
}
},
responses={
200: {'description': 'Login successful'},
400: {'description': 'Invalid or expired code'},
}
)
@api_view(['POST'])
@permission_classes([AllowAny])
def verify_magic_link(request):
"""
Verify the login code and complete the login.
This is the second step of the magic link flow.
"""
if not getattr(settings, 'ACCOUNT_LOGIN_BY_CODE_ENABLED', False):
return Response(
{'error': 'Magic link login is not enabled'},
status=status.HTTP_400_BAD_REQUEST
)
if not HAS_LOGIN_BY_CODE:
return Response(
{'error': 'Login by code is not available'},
status=status.HTTP_400_BAD_REQUEST
)
email = request.data.get('email', '').lower().strip()
code = request.data.get('code', '').strip()
if not email or not code:
return Response(
{'error': 'Email and code are required'},
status=status.HTTP_400_BAD_REQUEST
)
try:
email_address = EmailAddress.objects.get(email__iexact=email, verified=True)
user = email_address.user
# Attempt to verify the code and log in
success = perform_login_by_code(request._request, user, code)
if success:
return Response({
'success': True,
'message': 'Login successful',
'user': {
'id': user.id,
'username': user.username,
'email': user.email
}
})
else:
return Response(
{'error': 'Invalid or expired code. Please request a new one.'},
status=status.HTTP_400_BAD_REQUEST
)
except EmailAddress.DoesNotExist:
return Response(
{'error': 'Invalid email or code'},
status=status.HTTP_400_BAD_REQUEST
)
except Exception:
return Response(
{'error': 'Invalid or expired code. Please request a new one.'},
status=status.HTTP_400_BAD_REQUEST
)

View File

@@ -0,0 +1,385 @@
"""
MFA (Multi-Factor Authentication) API Views
Provides REST API endpoints for MFA operations using django-allauth's mfa module.
Supports TOTP (Time-based One-Time Password) authentication.
"""
import base64
from io import BytesIO
from django.conf import settings
from drf_spectacular.utils import extend_schema
from rest_framework import status
from rest_framework.decorators import api_view, permission_classes
from rest_framework.permissions import IsAuthenticated
from rest_framework.response import Response
try:
import qrcode
HAS_QRCODE = True
except ImportError:
HAS_QRCODE = False
@extend_schema(
operation_id="get_mfa_status",
summary="Get MFA status for current user",
description="Returns whether MFA is enabled and what methods are configured.",
responses={
200: {
"description": "MFA status",
"example": {
"mfa_enabled": True,
"totp_enabled": True,
"recovery_codes_count": 10,
},
},
},
tags=["MFA"],
)
@api_view(["GET"])
@permission_classes([IsAuthenticated])
def get_mfa_status(request):
"""Get MFA status for current user."""
from allauth.mfa.models import Authenticator
user = request.user
authenticators = Authenticator.objects.filter(user=user)
totp_enabled = authenticators.filter(type=Authenticator.Type.TOTP).exists()
recovery_enabled = authenticators.filter(type=Authenticator.Type.RECOVERY_CODES).exists()
# Count recovery codes if any
recovery_count = 0
if recovery_enabled:
try:
recovery_auth = authenticators.get(type=Authenticator.Type.RECOVERY_CODES)
recovery_count = len(recovery_auth.data.get("codes", []))
except Authenticator.DoesNotExist:
pass
return Response({
"mfa_enabled": totp_enabled,
"totp_enabled": totp_enabled,
"recovery_codes_enabled": recovery_enabled,
"recovery_codes_count": recovery_count,
})
@extend_schema(
operation_id="setup_totp",
summary="Initialize TOTP setup",
description="Generates a new TOTP secret and returns the QR code for scanning.",
responses={
200: {
"description": "TOTP setup data",
"example": {
"secret": "ABCDEFGHIJKLMNOP",
"provisioning_uri": "otpauth://totp/ThrillWiki:user@example.com?secret=...",
"qr_code_base64": "data:image/png;base64,...",
},
},
},
tags=["MFA"],
)
@api_view(["POST"])
@permission_classes([IsAuthenticated])
def setup_totp(request):
"""Generate TOTP secret and QR code for setup."""
from allauth.mfa.totp.internal import auth as totp_auth
user = request.user
# Generate TOTP secret
secret = totp_auth.get_totp_secret(None) # Generate new secret
# Build provisioning URI
issuer = getattr(settings, "MFA_TOTP_ISSUER", "ThrillWiki")
account_name = user.email or user.username
uri = f"otpauth://totp/{issuer}:{account_name}?secret={secret}&issuer={issuer}"
# Generate QR code if qrcode library is available
qr_code_base64 = None
if HAS_QRCODE:
qr = qrcode.make(uri)
buffer = BytesIO()
qr.save(buffer, format="PNG")
qr_code_base64 = f"data:image/png;base64,{base64.b64encode(buffer.getvalue()).decode()}"
# Store secret in session for later verification
request.session["pending_totp_secret"] = secret
return Response({
"secret": secret,
"provisioning_uri": uri,
"qr_code_base64": qr_code_base64,
})
@extend_schema(
operation_id="activate_totp",
summary="Activate TOTP with verification code",
description="Verifies the TOTP code and activates 2FA for the user.",
request={
"application/json": {
"type": "object",
"properties": {
"code": {
"type": "string",
"description": "6-digit TOTP code from authenticator app",
"example": "123456",
}
},
"required": ["code"],
}
},
responses={
200: {
"description": "TOTP activated successfully",
"example": {
"success": True,
"message": "Two-factor authentication enabled",
"recovery_codes": ["ABCD1234", "EFGH5678"],
},
},
400: {"description": "Invalid code or missing setup data"},
},
tags=["MFA"],
)
@api_view(["POST"])
@permission_classes([IsAuthenticated])
def activate_totp(request):
"""Verify TOTP code and activate MFA."""
from allauth.mfa.models import Authenticator
from allauth.mfa.recovery_codes.internal import auth as recovery_auth
from allauth.mfa.totp.internal import auth as totp_auth
user = request.user
code = request.data.get("code", "").strip()
if not code:
return Response(
{"success": False, "error": "Verification code is required"},
status=status.HTTP_400_BAD_REQUEST,
)
# Get pending secret from session
secret = request.session.get("pending_totp_secret")
if not secret:
return Response(
{"success": False, "error": "No pending TOTP setup. Please start setup again."},
status=status.HTTP_400_BAD_REQUEST,
)
# Verify the code
if not totp_auth.validate_totp_code(secret, code):
return Response(
{"success": False, "error": "Invalid verification code"},
status=status.HTTP_400_BAD_REQUEST,
)
# Check if already has TOTP
if Authenticator.objects.filter(user=user, type=Authenticator.Type.TOTP).exists():
return Response(
{"success": False, "error": "TOTP is already enabled"},
status=status.HTTP_400_BAD_REQUEST,
)
# Create TOTP authenticator
Authenticator.objects.create(
user=user,
type=Authenticator.Type.TOTP,
data={"secret": secret},
)
# Generate recovery codes
codes = recovery_auth.generate_recovery_codes()
Authenticator.objects.create(
user=user,
type=Authenticator.Type.RECOVERY_CODES,
data={"codes": codes},
)
# Clear session
del request.session["pending_totp_secret"]
return Response({
"success": True,
"message": "Two-factor authentication enabled",
"recovery_codes": codes,
})
@extend_schema(
operation_id="deactivate_totp",
summary="Disable TOTP authentication",
description="Removes TOTP from the user's account after password verification.",
request={
"application/json": {
"type": "object",
"properties": {
"password": {
"type": "string",
"description": "Current password for confirmation",
}
},
"required": ["password"],
}
},
responses={
200: {
"description": "TOTP disabled",
"example": {"success": True, "message": "Two-factor authentication disabled"},
},
400: {"description": "Invalid password or MFA not enabled"},
},
tags=["MFA"],
)
@api_view(["POST"])
@permission_classes([IsAuthenticated])
def deactivate_totp(request):
"""Disable TOTP authentication."""
from allauth.mfa.models import Authenticator
user = request.user
password = request.data.get("password", "")
# Verify password
if not user.check_password(password):
return Response(
{"success": False, "error": "Invalid password"},
status=status.HTTP_400_BAD_REQUEST,
)
# Remove TOTP and recovery codes
deleted_count, _ = Authenticator.objects.filter(
user=user,
type__in=[Authenticator.Type.TOTP, Authenticator.Type.RECOVERY_CODES]
).delete()
if deleted_count == 0:
return Response(
{"success": False, "error": "Two-factor authentication is not enabled"},
status=status.HTTP_400_BAD_REQUEST,
)
return Response({
"success": True,
"message": "Two-factor authentication disabled",
})
@extend_schema(
operation_id="verify_totp",
summary="Verify TOTP code during login",
description="Verifies the TOTP code as part of the login process.",
request={
"application/json": {
"type": "object",
"properties": {
"code": {"type": "string", "description": "6-digit TOTP code"}
},
"required": ["code"],
}
},
responses={
200: {"description": "Code verified", "example": {"success": True}},
400: {"description": "Invalid code"},
},
tags=["MFA"],
)
@api_view(["POST"])
@permission_classes([IsAuthenticated])
def verify_totp(request):
"""Verify TOTP code."""
from allauth.mfa.models import Authenticator
from allauth.mfa.totp.internal import auth as totp_auth
user = request.user
code = request.data.get("code", "").strip()
if not code:
return Response(
{"success": False, "error": "Verification code is required"},
status=status.HTTP_400_BAD_REQUEST,
)
try:
authenticator = Authenticator.objects.get(user=user, type=Authenticator.Type.TOTP)
secret = authenticator.data.get("secret")
if totp_auth.validate_totp_code(secret, code):
return Response({"success": True})
else:
return Response(
{"success": False, "error": "Invalid verification code"},
status=status.HTTP_400_BAD_REQUEST,
)
except Authenticator.DoesNotExist:
return Response(
{"success": False, "error": "TOTP is not enabled"},
status=status.HTTP_400_BAD_REQUEST,
)
@extend_schema(
operation_id="regenerate_recovery_codes",
summary="Regenerate recovery codes",
description="Generates new recovery codes (invalidates old ones).",
request={
"application/json": {
"type": "object",
"properties": {
"password": {"type": "string", "description": "Current password"}
},
"required": ["password"],
}
},
responses={
200: {
"description": "New recovery codes",
"example": {"success": True, "recovery_codes": ["ABCD1234", "EFGH5678"]},
},
400: {"description": "Invalid password or MFA not enabled"},
},
tags=["MFA"],
)
@api_view(["POST"])
@permission_classes([IsAuthenticated])
def regenerate_recovery_codes(request):
"""Regenerate recovery codes."""
from allauth.mfa.models import Authenticator
from allauth.mfa.recovery_codes.internal import auth as recovery_auth
user = request.user
password = request.data.get("password", "")
# Verify password
if not user.check_password(password):
return Response(
{"success": False, "error": "Invalid password"},
status=status.HTTP_400_BAD_REQUEST,
)
# Check if TOTP is enabled
if not Authenticator.objects.filter(user=user, type=Authenticator.Type.TOTP).exists():
return Response(
{"success": False, "error": "Two-factor authentication is not enabled"},
status=status.HTTP_400_BAD_REQUEST,
)
# Generate new codes
codes = recovery_auth.generate_recovery_codes()
# Update or create recovery codes authenticator
authenticator, created = Authenticator.objects.update_or_create(
user=user,
type=Authenticator.Type.RECOVERY_CODES,
defaults={"data": {"codes": codes}},
)
return Response({
"success": True,
"recovery_codes": codes,
})

View File

@@ -5,21 +5,21 @@ This module contains all serializers related to authentication, user accounts,
profiles, top lists, and user statistics.
"""
from typing import Any, Dict
from rest_framework import serializers
from drf_spectacular.utils import (
extend_schema_serializer,
extend_schema_field,
OpenApiExample,
)
from django.contrib.auth.password_validation import validate_password
from django.utils.crypto import get_random_string
from django.contrib.auth import get_user_model
from django.utils import timezone
from datetime import timedelta
from apps.accounts.models import PasswordReset
from typing import Any
from django.contrib.auth import get_user_model
from django.contrib.auth.password_validation import validate_password
from django.utils import timezone
from django.utils.crypto import get_random_string
from drf_spectacular.utils import (
OpenApiExample,
extend_schema_field,
extend_schema_serializer,
)
from rest_framework import serializers
from apps.accounts.models import PasswordReset
UserModel = get_user_model()
@@ -192,11 +192,13 @@ class SignupInputSerializer(serializers.ModelSerializer):
def _send_verification_email(self, user):
"""Send email verification to the user."""
from apps.accounts.models import EmailVerification
import logging
from django.contrib.sites.shortcuts import get_current_site
from django.utils.crypto import get_random_string
from django_forwardemail.services import EmailService
from django.contrib.sites.shortcuts import get_current_site
import logging
from apps.accounts.models import EmailVerification
logger = logging.getLogger(__name__)
@@ -436,7 +438,7 @@ class UserProfileOutputSerializer(serializers.Serializer):
return obj.get_avatar_url()
@extend_schema_field(serializers.DictField())
def get_user(self, obj) -> Dict[str, Any]:
def get_user(self, obj) -> dict[str, Any]:
return {
"username": obj.user.username,
"date_joined": obj.user.date_joined,

View File

@@ -6,15 +6,15 @@ Main authentication serializers are imported directly from the parent serializer
"""
from .social import (
ConnectedProviderSerializer,
AvailableProviderSerializer,
SocialAuthStatusSerializer,
ConnectedProviderSerializer,
ConnectedProvidersListOutputSerializer,
ConnectProviderInputSerializer,
ConnectProviderOutputSerializer,
DisconnectProviderOutputSerializer,
SocialProviderListOutputSerializer,
ConnectedProvidersListOutputSerializer,
SocialAuthStatusSerializer,
SocialProviderErrorSerializer,
SocialProviderListOutputSerializer,
)
__all__ = [

View File

@@ -5,8 +5,8 @@ Serializers for handling social provider connection/disconnection requests
and responses in the ThrillWiki API.
"""
from rest_framework import serializers
from django.contrib.auth import get_user_model
from rest_framework import serializers
User = get_user_model()

View File

@@ -5,29 +5,30 @@ This module contains URL patterns for core authentication functionality only.
User profiles and top lists are handled by the dedicated accounts app.
"""
from django.urls import path, include
from django.urls import include, path
from rest_framework_simplejwt.views import TokenRefreshView
from . import mfa as mfa_views
from .views import (
# Main auth views
LoginAPIView,
SignupAPIView,
LogoutAPIView,
CurrentUserAPIView,
PasswordResetAPIView,
PasswordChangeAPIView,
SocialProvidersAPIView,
AuthStatusAPIView,
# Email verification views
EmailVerificationAPIView,
ResendVerificationAPIView,
# Social provider management views
AvailableProvidersAPIView,
ConnectedProvidersAPIView,
ConnectProviderAPIView,
CurrentUserAPIView,
DisconnectProviderAPIView,
# Email verification views
EmailVerificationAPIView,
# Main auth views
LoginAPIView,
LogoutAPIView,
PasswordChangeAPIView,
PasswordResetAPIView,
ResendVerificationAPIView,
SignupAPIView,
SocialAuthStatusAPIView,
SocialProvidersAPIView,
)
from rest_framework_simplejwt.views import TokenRefreshView
urlpatterns = [
# Core authentication endpoints
@@ -98,6 +99,14 @@ urlpatterns = [
ResendVerificationAPIView.as_view(),
name="auth-resend-verification",
),
# MFA (Multi-Factor Authentication) endpoints
path("mfa/status/", mfa_views.get_mfa_status, name="auth-mfa-status"),
path("mfa/totp/setup/", mfa_views.setup_totp, name="auth-mfa-totp-setup"),
path("mfa/totp/activate/", mfa_views.activate_totp, name="auth-mfa-totp-activate"),
path("mfa/totp/deactivate/", mfa_views.deactivate_totp, name="auth-mfa-totp-deactivate"),
path("mfa/totp/verify/", mfa_views.verify_totp, name="auth-mfa-totp-verify"),
path("mfa/recovery-codes/regenerate/", mfa_views.regenerate_recovery_codes, name="auth-mfa-recovery-regenerate"),
]
# Note: User profiles and top lists functionality is now handled by the accounts app

View File

@@ -6,44 +6,46 @@ login, signup, logout, password management, social authentication,
user profiles, and top lists.
"""
from .serializers_package.social import (
ConnectedProviderSerializer,
AvailableProviderSerializer,
SocialAuthStatusSerializer,
ConnectProviderInputSerializer,
ConnectProviderOutputSerializer,
DisconnectProviderOutputSerializer,
SocialProviderErrorSerializer,
)
from apps.accounts.services.social_provider_service import SocialProviderService
from django.contrib.auth import authenticate, login, logout, get_user_model
from typing import cast # added 'cast'
from django.contrib.auth import authenticate, get_user_model, login, logout
from django.contrib.sites.shortcuts import get_current_site
from django.core.exceptions import ValidationError
from django.db.models import Q
from typing import Optional, cast # added 'cast'
from django.http import HttpRequest # new import
from drf_spectacular.utils import extend_schema, extend_schema_view
from rest_framework import status
from rest_framework.views import APIView
from rest_framework.permissions import AllowAny, IsAuthenticated
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.permissions import AllowAny, IsAuthenticated
from drf_spectacular.utils import extend_schema, extend_schema_view
from rest_framework.views import APIView
from apps.accounts.services.social_provider_service import SocialProviderService
# Import directly from the auth serializers.py file (not the serializers package)
from .serializers import (
AuthStatusOutputSerializer,
# Authentication serializers
LoginInputSerializer,
LoginOutputSerializer,
SignupInputSerializer,
SignupOutputSerializer,
LogoutOutputSerializer,
UserOutputSerializer,
PasswordResetInputSerializer,
PasswordResetOutputSerializer,
PasswordChangeInputSerializer,
PasswordChangeOutputSerializer,
PasswordResetInputSerializer,
PasswordResetOutputSerializer,
SignupInputSerializer,
SignupOutputSerializer,
SocialProviderOutputSerializer,
AuthStatusOutputSerializer,
UserOutputSerializer,
)
from .serializers_package.social import (
AvailableProviderSerializer,
ConnectedProviderSerializer,
ConnectProviderInputSerializer,
ConnectProviderOutputSerializer,
DisconnectProviderOutputSerializer,
SocialAuthStatusSerializer,
SocialProviderErrorSerializer,
)
# Handle optional dependencies with fallback classes
@@ -62,10 +64,7 @@ try:
# Ensure the imported object is a class/type that can be used as a base class.
# If it's not a type for any reason, fall back to the safe mixin.
if isinstance(_ImportedTurnstileMixin, type):
TurnstileMixin = _ImportedTurnstileMixin
else:
TurnstileMixin = FallbackTurnstileMixin
TurnstileMixin = _ImportedTurnstileMixin if isinstance(_ImportedTurnstileMixin, type) else FallbackTurnstileMixin
except Exception:
# Catch any import errors or unexpected exceptions and use the fallback mixin.
TurnstileMixin = FallbackTurnstileMixin
@@ -88,7 +87,7 @@ def _get_underlying_request(request: Request) -> HttpRequest:
# Helper: encapsulate user lookup + authenticate to reduce complexity in view
def _authenticate_user_by_lookup(
email_or_username: str, password: str, request: Request
) -> Optional[UserModel]:
) -> UserModel | None:
"""
Try a single optimized query to find a user by email OR username then authenticate.
Returns authenticated user or None.
@@ -199,7 +198,7 @@ class LoginAPIView(APIView):
else:
return Response(
{
"error": "Email verification required",
"error": "Email verification required",
"message": "Please verify your email address before logging in. Check your email for a verification link.",
"email_verification_required": True
},
@@ -246,7 +245,7 @@ class SignupAPIView(APIView):
serializer = SignupInputSerializer(data=request.data, context={"request": request})
if serializer.is_valid():
user = serializer.save()
# Don't log in the user immediately - they need to verify their email first
response_serializer = SignupOutputSerializer(
{
@@ -754,23 +753,23 @@ class EmailVerificationAPIView(APIView):
def get(self, request: Request, token: str) -> Response:
from apps.accounts.models import EmailVerification
try:
verification = EmailVerification.objects.select_related('user').get(token=token)
user = verification.user
# Activate the user
user.is_active = True
user.save()
# Delete the verification record
verification.delete()
return Response({
"message": "Email verified successfully. You can now log in.",
"success": True
})
except EmailVerification.DoesNotExist:
return Response(
{"error": "Invalid or expired verification token"},
@@ -798,45 +797,46 @@ class ResendVerificationAPIView(APIView):
authentication_classes = []
def post(self, request: Request) -> Response:
from apps.accounts.models import EmailVerification
from django.contrib.sites.shortcuts import get_current_site
from django.utils.crypto import get_random_string
from django_forwardemail.services import EmailService
from django.contrib.sites.shortcuts import get_current_site
from apps.accounts.models import EmailVerification
email = request.data.get('email')
if not email:
return Response(
{"error": "Email address is required"},
status=status.HTTP_400_BAD_REQUEST
)
try:
user = UserModel.objects.get(email__iexact=email.strip().lower())
# Don't resend if user is already active
if user.is_active:
return Response(
{"error": "Email is already verified"},
status=status.HTTP_400_BAD_REQUEST
)
# Create or update verification record
verification, created = EmailVerification.objects.get_or_create(
user=user,
defaults={'token': get_random_string(64)}
)
if not created:
# Update existing token and timestamp
verification.token = get_random_string(64)
verification.save()
# Send verification email
site = get_current_site(_get_underlying_request(request))
verification_url = request.build_absolute_uri(
f"/api/v1/auth/verify-email/{verification.token}/"
)
try:
EmailService.send_email(
to=user.email,
@@ -854,22 +854,22 @@ The ThrillWiki Team
""".strip(),
site=site,
)
return Response({
"message": "Verification email sent successfully",
"success": True
})
except Exception as e:
import logging
logger = logging.getLogger(__name__)
logger.error(f"Failed to send verification email to {user.email}: {e}")
return Response(
{"error": "Failed to send verification email"},
status=status.HTTP_500_INTERNAL_SERVER_ERROR
)
except UserModel.DoesNotExist:
# Don't reveal whether email exists
return Response({

View File

@@ -4,6 +4,7 @@ Centralized from apps.core.urls
"""
from django.urls import path
from . import views
# Entity search endpoints - migrated from apps.core.urls

View File

@@ -8,18 +8,20 @@ Caching Strategy:
- EntityNotFoundView: No caching - POST requests with context-specific data
"""
from rest_framework.views import APIView
from rest_framework.response import Response
import contextlib
from drf_spectacular.utils import extend_schema
from rest_framework import status
from rest_framework.permissions import AllowAny
from typing import Optional, List
from drf_spectacular.utils import extend_schema
from rest_framework.response import Response
from rest_framework.views import APIView
from apps.core.services.entity_fuzzy_matching import (
entity_fuzzy_matcher,
EntityType,
)
from apps.core.decorators.cache_decorators import cache_api_response
from apps.core.services.entity_fuzzy_matching import (
EntityType,
entity_fuzzy_matcher,
)
class EntityFuzzySearchView(APIView):
@@ -199,10 +201,8 @@ class EntityNotFoundView(APIView):
# Determine entity types to search based on context
entity_types = []
if entity_type_hint:
try:
with contextlib.suppress(ValueError):
entity_types = [EntityType(entity_type_hint)]
except ValueError:
pass
# If we have park context, prioritize ride searches
if context.get("park_slug") and not entity_types:
@@ -344,7 +344,7 @@ class QuickEntitySuggestionView(APIView):
# Utility function for other views to use
def get_entity_suggestions(
query: str, entity_types: Optional[List[str]] = None, user=None
query: str, entity_types: list[str] | None = None, user=None
):
"""
Utility function for other Django views to get entity suggestions.

View File

@@ -4,6 +4,7 @@ Centralized from apps.email_service.urls
"""
from django.urls import path
from . import views
urlpatterns = [

View File

@@ -3,13 +3,13 @@ Centralized email service API views.
Migrated from apps.email_service.views
"""
from rest_framework.views import APIView
from rest_framework.response import Response
from django.contrib.sites.shortcuts import get_current_site
from django_forwardemail.services import EmailService
from drf_spectacular.utils import extend_schema
from rest_framework import status
from rest_framework.permissions import AllowAny
from django.contrib.sites.shortcuts import get_current_site
from drf_spectacular.utils import extend_schema
from django_forwardemail.services import EmailService
from rest_framework.response import Response
from rest_framework.views import APIView
@extend_schema(

View File

@@ -4,7 +4,7 @@ History API URLs
URL patterns for history-related API endpoints.
"""
from django.urls import path, include
from django.urls import include, path
from rest_framework.routers import DefaultRouter
from .views import (

View File

@@ -5,18 +5,21 @@ This module provides ViewSets for accessing historical data and change tracking
across all models in the ThrillWiki system using django-pghistory.
"""
from drf_spectacular.utils import extend_schema, extend_schema_view, OpenApiParameter
from collections.abc import Sequence
from datetime import datetime
from typing import cast
import pghistory.models
from django.db.models import Count, QuerySet
from django.shortcuts import get_object_or_404
from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import OpenApiParameter, extend_schema, extend_schema_view
from rest_framework import serializers as drf_serializers
from rest_framework.filters import OrderingFilter
from rest_framework.permissions import AllowAny
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.viewsets import ReadOnlyModelViewSet
from rest_framework.request import Request
from typing import Optional, cast, Sequence
from django.shortcuts import get_object_or_404
from django.db.models import Count, QuerySet
import pghistory.models
from datetime import datetime
# Import models
from apps.parks.models import Park
@@ -24,7 +27,6 @@ from apps.rides.models import Ride
# Import serializers
from .. import serializers as history_serializers
from rest_framework import serializers as drf_serializers
# Minimal fallback serializer used when a specific serializer symbol is missing.
@@ -79,7 +81,7 @@ ALL_TRACKED_MODELS: Sequence[str] = [
# --- Helper utilities to reduce duplicated logic / cognitive complexity ---
def _parse_date(date_str: Optional[str]) -> Optional[datetime]:
def _parse_date(date_str: str | None) -> datetime | None:
if not date_str:
return None
try:

View File

@@ -1,4 +1,5 @@
from django.urls import path
from .views import GenerateUploadURLView
urlpatterns = [

View File

@@ -1,12 +1,14 @@
from rest_framework.views import APIView
from rest_framework.response import Response
from rest_framework.permissions import IsAuthenticated
from rest_framework import status
from apps.core.utils.cloudflare import get_direct_upload_url
from django.core.exceptions import ImproperlyConfigured
import requests
import logging
import requests
from django.core.exceptions import ImproperlyConfigured
from rest_framework import status
from rest_framework.permissions import IsAuthenticated
from rest_framework.response import Response
from rest_framework.views import APIView
from apps.core.utils.cloudflare import get_direct_upload_url
logger = logging.getLogger(__name__)
class GenerateUploadURLView(APIView):
@@ -29,7 +31,7 @@ class GenerateUploadURLView(APIView):
{"detail": "Failed to generate upload URL."},
status=status.HTTP_502_BAD_GATEWAY
)
except Exception as e:
except Exception:
logger.exception("Unexpected error generating upload URL")
return Response(
{"detail": "An unexpected error occurred."},

View File

@@ -4,6 +4,7 @@ Migrated from apps.core.urls.map_urls to centralized API structure.
"""
from django.urls import path
from . import views
# Map API endpoints - migrated from apps.core.urls.map_urls

View File

@@ -12,30 +12,31 @@ Caching Strategy:
import logging
from django.core.cache import cache
from django.http import HttpRequest
from django.db.models import Q
from django.contrib.gis.geos import Polygon
from rest_framework.views import APIView
from rest_framework.response import Response
from rest_framework import status
from rest_framework.permissions import AllowAny, IsAdminUser
from django.core.cache import cache
from django.db.models import Q
from django.http import HttpRequest
from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import (
OpenApiExample,
OpenApiParameter,
extend_schema,
extend_schema_view,
OpenApiParameter,
OpenApiExample,
)
from drf_spectacular.types import OpenApiTypes
from rest_framework import status
from rest_framework.permissions import AllowAny, IsAdminUser
from rest_framework.response import Response
from rest_framework.views import APIView
from apps.core.decorators.cache_decorators import cache_api_response
from apps.core.services.enhanced_cache_service import EnhancedCacheService
from apps.parks.models import Park
from apps.rides.models import Ride
from apps.core.services.enhanced_cache_service import EnhancedCacheService
from apps.core.decorators.cache_decorators import cache_api_response
from ..serializers.maps import (
MapLocationDetailSerializer,
MapLocationsResponseSerializer,
MapSearchResponseSerializer,
MapLocationDetailSerializer,
)
logger = logging.getLogger(__name__)

View File

@@ -7,7 +7,8 @@ TypeScript interfaces, providing immediate feedback during development.
import json
import logging
from typing import Dict, Any
from typing import Any
from django.conf import settings
from django.http import JsonResponse
from django.utils.deprecation import MiddlewareMixin
@@ -19,52 +20,49 @@ logger = logging.getLogger(__name__)
class ContractValidationMiddleware(MiddlewareMixin):
"""
Development-only middleware that validates API responses against expected contracts.
This middleware:
1. Checks all API responses for contract compliance
2. Logs warnings when responses don't match expected TypeScript interfaces
3. Specifically validates filter metadata structure
4. Alerts when categorical filters are strings instead of objects
Only active when DEBUG=True to avoid performance impact in production.
"""
def __init__(self, get_response):
super().__init__(get_response)
self.get_response = get_response
self.enabled = getattr(settings, 'DEBUG', False)
if self.enabled:
logger.info("Contract validation middleware enabled (DEBUG mode)")
def process_response(self, request, response):
"""Process API responses to check for contract violations."""
if not self.enabled:
return response
# Only validate API endpoints
if not request.path.startswith('/api/'):
return response
# Only validate JSON responses
if not isinstance(response, (JsonResponse, Response)):
return response
# Only validate successful responses (2xx status codes)
if not (200 <= response.status_code < 300):
return response
try:
# Get response data
if isinstance(response, Response):
data = response.data
else:
data = json.loads(response.content.decode('utf-8'))
data = response.data if isinstance(response, Response) else json.loads(response.content.decode('utf-8'))
# Validate the response
self._validate_response_contract(request.path, data)
except Exception as e:
# Log validation errors but don't break the response
logger.warning(
@@ -76,55 +74,55 @@ class ContractValidationMiddleware(MiddlewareMixin):
'validation_error': str(e)
}
)
return response
def _validate_response_contract(self, path: str, data: Any) -> None:
"""Validate response data against expected contracts."""
# Check for filter metadata endpoints
if 'filter-options' in path or 'filter_options' in path:
self._validate_filter_metadata(path, data)
# Check for hybrid filtering endpoints
if 'hybrid' in path:
self._validate_hybrid_response(path, data)
# Check for pagination responses
if isinstance(data, dict) and 'results' in data:
self._validate_pagination_response(path, data)
# Check for common contract violations
self._validate_common_patterns(path, data)
def _validate_filter_metadata(self, path: str, data: Any) -> None:
"""Validate filter metadata structure."""
if not isinstance(data, dict):
self._log_contract_violation(
path,
path,
"FILTER_METADATA_NOT_DICT",
f"Filter metadata should be a dictionary, got {type(data).__name__}"
)
return
# Check for categorical filters
if 'categorical' in data:
categorical = data['categorical']
if isinstance(categorical, dict):
for filter_name, filter_options in categorical.items():
self._validate_categorical_filter(path, filter_name, filter_options)
# Check for ranges
if 'ranges' in data:
ranges = data['ranges']
if isinstance(ranges, dict):
for range_name, range_data in ranges.items():
self._validate_range_filter(path, range_name, range_data)
def _validate_categorical_filter(self, path: str, filter_name: str, filter_options: Any) -> None:
"""Validate categorical filter options format."""
if not isinstance(filter_options, list):
self._log_contract_violation(
path,
@@ -132,7 +130,7 @@ class ContractValidationMiddleware(MiddlewareMixin):
f"Categorical filter '{filter_name}' should be an array, got {type(filter_options).__name__}"
)
return
for i, option in enumerate(filter_options):
if isinstance(option, str):
# CRITICAL: This is the main contract violation we're trying to catch
@@ -163,10 +161,10 @@ class ContractValidationMiddleware(MiddlewareMixin):
"INVALID_COUNT_TYPE",
f"Categorical filter '{filter_name}' option {i} 'count' should be a number, got {type(option['count']).__name__}"
)
def _validate_range_filter(self, path: str, range_name: str, range_data: Any) -> None:
"""Validate range filter format."""
if not isinstance(range_data, dict):
self._log_contract_violation(
path,
@@ -174,7 +172,7 @@ class ContractValidationMiddleware(MiddlewareMixin):
f"Range filter '{range_name}' should be an object, got {type(range_data).__name__}"
)
return
# Check required properties
required_props = ['min', 'max']
for prop in required_props:
@@ -184,7 +182,7 @@ class ContractValidationMiddleware(MiddlewareMixin):
"MISSING_RANGE_PROPERTY",
f"Range filter '{range_name}' missing required property '{prop}'"
)
# Check step property
if 'step' in range_data and not isinstance(range_data['step'], (int, float)):
self._log_contract_violation(
@@ -192,13 +190,13 @@ class ContractValidationMiddleware(MiddlewareMixin):
"INVALID_STEP_TYPE",
f"Range filter '{range_name}' 'step' should be a number, got {type(range_data['step']).__name__}"
)
def _validate_hybrid_response(self, path: str, data: Any) -> None:
"""Validate hybrid filtering response structure."""
if not isinstance(data, dict):
return
# Check for strategy field
if 'strategy' in data:
strategy = data['strategy']
@@ -208,14 +206,14 @@ class ContractValidationMiddleware(MiddlewareMixin):
"INVALID_STRATEGY_VALUE",
f"Hybrid response strategy should be 'client_side' or 'server_side', got '{strategy}'"
)
# Check filter_metadata structure
if 'filter_metadata' in data:
self._validate_filter_metadata(path, data['filter_metadata'])
def _validate_pagination_response(self, path: str, data: Dict[str, Any]) -> None:
def _validate_pagination_response(self, path: str, data: dict[str, Any]) -> None:
"""Validate pagination response structure."""
# Check for required pagination fields
required_fields = ['count', 'results']
for field in required_fields:
@@ -225,7 +223,7 @@ class ContractValidationMiddleware(MiddlewareMixin):
"MISSING_PAGINATION_FIELD",
f"Pagination response missing required field '{field}'"
)
# Check results is array
if 'results' in data and not isinstance(data['results'], list):
self._log_contract_violation(
@@ -233,17 +231,17 @@ class ContractValidationMiddleware(MiddlewareMixin):
"RESULTS_NOT_ARRAY",
f"Pagination 'results' should be an array, got {type(data['results']).__name__}"
)
def _validate_common_patterns(self, path: str, data: Any) -> None:
"""Validate common API response patterns."""
if isinstance(data, dict):
# Check for null vs undefined issues
for key, value in data.items():
if value is None and key.endswith('_id'):
# ID fields should probably be null, not undefined
continue
# Check for numeric fields that might be strings
if key.endswith('_count') and isinstance(value, str):
try:
@@ -255,16 +253,16 @@ class ContractValidationMiddleware(MiddlewareMixin):
)
except ValueError:
pass
def _log_contract_violation(
self,
path: str,
violation_type: str,
message: str,
self,
path: str,
violation_type: str,
message: str,
severity: str = "WARNING"
) -> None:
"""Log a contract violation with structured data."""
log_data = {
'contract_violation': True,
'violation_type': violation_type,
@@ -273,15 +271,15 @@ class ContractValidationMiddleware(MiddlewareMixin):
'message': message,
'suggestion': self._get_violation_suggestion(violation_type)
}
if severity == "ERROR":
logger.error(f"CONTRACT VIOLATION [{violation_type}]: {message}", extra=log_data)
else:
logger.warning(f"CONTRACT VIOLATION [{violation_type}]: {message}", extra=log_data)
def _get_violation_suggestion(self, violation_type: str) -> str:
"""Get suggestion for fixing a contract violation."""
suggestions = {
"CATEGORICAL_OPTION_IS_STRING": (
"Convert string arrays to object arrays with {value, label, count} structure. "
@@ -308,31 +306,31 @@ class ContractValidationMiddleware(MiddlewareMixin):
"Check serializer implementation."
)
}
return suggestions.get(violation_type, "Check the API response format against frontend TypeScript interfaces.")
class ContractValidationSettings:
"""Settings for contract validation middleware."""
# Enable/disable specific validation checks
VALIDATE_FILTER_METADATA = True
VALIDATE_PAGINATION = True
VALIDATE_HYBRID_RESPONSES = True
VALIDATE_COMMON_PATTERNS = True
# Severity levels for different violations
CATEGORICAL_STRING_SEVERITY = "ERROR" # This is the critical issue
MISSING_PROPERTY_SEVERITY = "WARNING"
TYPE_MISMATCH_SEVERITY = "WARNING"
# Paths to exclude from validation
EXCLUDED_PATHS = [
'/api/docs/',
'/api/schema/',
'/api/v1/auth/', # Auth endpoints might have different structures
]
@classmethod
def should_validate_path(cls, path: str) -> bool:
"""Check if a path should be validated."""

View File

@@ -2,14 +2,16 @@
Park history API views.
"""
from rest_framework import viewsets, mixins
from rest_framework.response import Response
from rest_framework.permissions import AllowAny
from django.shortcuts import get_object_or_404
from drf_spectacular.utils import extend_schema
from rest_framework import viewsets
from rest_framework.permissions import AllowAny
from rest_framework.response import Response
from apps.api.v1.serializers.history import ParkHistoryOutputSerializer, RideHistoryOutputSerializer
from apps.parks.models import Park
from apps.rides.models import Ride
from apps.api.v1.serializers.history import ParkHistoryOutputSerializer, RideHistoryOutputSerializer
class ParkHistoryViewSet(viewsets.GenericViewSet):
"""
@@ -18,7 +20,7 @@ class ParkHistoryViewSet(viewsets.GenericViewSet):
permission_classes = [AllowAny]
lookup_field = "slug"
lookup_url_kwarg = "park_slug"
@extend_schema(
summary="Get park history",
description="Retrieve history events for a park.",
@@ -27,24 +29,24 @@ class ParkHistoryViewSet(viewsets.GenericViewSet):
)
def list(self, request, park_slug=None):
park = get_object_or_404(Park, slug=park_slug)
events = []
if hasattr(park, "events"):
events = park.events.all().order_by("-pgh_created_at")
summary = {
"total_events": len(events),
"first_recorded": events.last().pgh_created_at if len(events) else None,
"last_modified": events.first().pgh_created_at if len(events) else None,
}
data = {
"park": park,
"current_state": park,
"summary": summary,
"events": events
}
serializer = ParkHistoryOutputSerializer(data)
return Response(serializer.data)

View File

@@ -6,27 +6,26 @@ Provides CRUD operations for park reviews nested under parks/{slug}/reviews/
"""
import logging
from django.core.exceptions import PermissionDenied
from django.db.models import Avg
from django.utils import timezone
from drf_spectacular.utils import extend_schema_view, extend_schema
from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import extend_schema, extend_schema_view
from rest_framework import status
from rest_framework.decorators import action
from rest_framework.exceptions import ValidationError, NotFound
from rest_framework.permissions import IsAuthenticated, AllowAny
from rest_framework.exceptions import NotFound, ValidationError
from rest_framework.permissions import AllowAny, IsAuthenticated
from rest_framework.response import Response
from rest_framework.viewsets import ModelViewSet
from apps.parks.models import Park, ParkReview
from apps.api.v1.serializers.park_reviews import (
ParkReviewOutputSerializer,
ParkReviewCreateInputSerializer,
ParkReviewUpdateInputSerializer,
ParkReviewListOutputSerializer,
ParkReviewOutputSerializer,
ParkReviewStatsOutputSerializer,
ParkReviewModerationInputSerializer,
ParkReviewUpdateInputSerializer,
)
from apps.parks.models import Park, ParkReview
logger = logging.getLogger(__name__)
@@ -66,10 +65,7 @@ class ParkReviewViewSet(ModelViewSet):
def get_permissions(self):
"""Set permissions based on action."""
if self.action in ['list', 'retrieve', 'stats']:
permission_classes = [AllowAny]
else:
permission_classes = [IsAuthenticated]
permission_classes = [AllowAny] if self.action in ['list', 'retrieve', 'stats'] else [IsAuthenticated]
return [permission() for permission in permission_classes]
def get_queryset(self):
@@ -143,7 +139,7 @@ class ParkReviewViewSet(ModelViewSet):
reviews = ParkReview.objects.filter(park=park, is_published=True)
total_reviews = reviews.count()
avg_rating = reviews.aggregate(avg=Avg('rating'))['avg']
rating_distribution = {}
for i in range(1, 11):
rating_distribution[str(i)] = reviews.filter(rating=i).count()

View File

@@ -6,19 +6,16 @@ This module implements endpoints for accessing rides within specific parks:
- GET /parks/{park_slug}/rides/{ride_slug}/ - Get specific ride details within park context
"""
from typing import Any
from django.db import models
from django.db.models import Q, Count, Avg
from django.db.models import Q
from django.db.models.query import QuerySet
from rest_framework import status, permissions
from rest_framework.views import APIView
from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import OpenApiParameter, extend_schema
from rest_framework import permissions, status
from rest_framework.exceptions import NotFound
from rest_framework.pagination import PageNumberPagination
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.pagination import PageNumberPagination
from rest_framework.exceptions import NotFound
from drf_spectacular.utils import extend_schema, OpenApiParameter
from drf_spectacular.types import OpenApiTypes
from rest_framework.views import APIView
# Import models
try:
@@ -32,8 +29,8 @@ except Exception:
# Import serializers
try:
from apps.api.v1.serializers.rides import RideListOutputSerializer, RideDetailOutputSerializer
from apps.api.v1.serializers.parks import ParkDetailOutputSerializer
from apps.api.v1.serializers.rides import RideDetailOutputSerializer, RideListOutputSerializer
SERIALIZERS_AVAILABLE = True
except Exception:
SERIALIZERS_AVAILABLE = False
@@ -47,7 +44,7 @@ class StandardResultsSetPagination(PageNumberPagination):
class ParkRidesListAPIView(APIView):
"""List rides at a specific park with pagination and filtering."""
permission_classes = [permissions.AllowAny]
@extend_schema(
@@ -59,7 +56,7 @@ class ParkRidesListAPIView(APIView):
type=OpenApiTypes.INT, description="Page number"),
OpenApiParameter(name="page_size", location=OpenApiParameter.QUERY,
type=OpenApiTypes.INT, description="Number of results per page (max 100)"),
# Filtering
OpenApiParameter(name="category", location=OpenApiParameter.QUERY,
type=OpenApiTypes.STR, description="Filter by ride category"),
@@ -67,7 +64,7 @@ class ParkRidesListAPIView(APIView):
type=OpenApiTypes.STR, description="Filter by operational status"),
OpenApiParameter(name="search", location=OpenApiParameter.QUERY,
type=OpenApiTypes.STR, description="Search rides by name"),
# Ordering
OpenApiParameter(name="ordering", location=OpenApiParameter.QUERY,
type=OpenApiTypes.STR, description="Order results by field"),
@@ -158,7 +155,7 @@ class ParkRidesListAPIView(APIView):
class ParkRideDetailAPIView(APIView):
"""Get specific ride details within park context."""
permission_classes = [permissions.AllowAny]
@extend_schema(
@@ -222,7 +219,7 @@ class ParkRideDetailAPIView(APIView):
class ParkComprehensiveDetailAPIView(APIView):
"""Get comprehensive park details including summary of rides."""
permission_classes = [permissions.AllowAny]
@extend_schema(
@@ -271,7 +268,7 @@ class ParkComprehensiveDetailAPIView(APIView):
rides_serializer = RideListOutputSerializer(
rides_sample, many=True, context={"request": request, "park": park}
)
# Enhance response with rides data
park_data["rides_summary"] = {
"total_count": park.ride_count or 0,

View File

@@ -11,23 +11,24 @@ This module implements comprehensive park endpoints with full filtering support:
Supports all 24 filtering parameters from frontend API documentation.
"""
import contextlib
from typing import Any
from django.db import models
from django.db.models import Q, Count, Avg
from django.db.models.query import QuerySet
from rest_framework import status, permissions
from rest_framework.views import APIView
from django.db import models
from django.db.models import Avg, Count, Q
from django.db.models.query import QuerySet
from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import OpenApiParameter, extend_schema
from rest_framework import permissions, status
from rest_framework.exceptions import NotFound
from rest_framework.pagination import PageNumberPagination
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.pagination import PageNumberPagination
from rest_framework.exceptions import NotFound
from drf_spectacular.utils import extend_schema, OpenApiParameter
from drf_spectacular.types import OpenApiTypes
from rest_framework.views import APIView
# Import models
try:
from apps.parks.models import Park, Company
from apps.parks.models import Company, Park
MODELS_AVAILABLE = True
except Exception:
Park = None # type: ignore
@@ -45,11 +46,11 @@ except Exception:
# Import serializers
try:
from apps.api.v1.serializers.parks import (
ParkListOutputSerializer,
ParkDetailOutputSerializer,
ParkCreateInputSerializer,
ParkUpdateInputSerializer,
ParkDetailOutputSerializer,
ParkImageSettingsInputSerializer,
ParkListOutputSerializer,
ParkUpdateInputSerializer,
)
SERIALIZERS_AVAILABLE = True
except Exception:
@@ -247,12 +248,12 @@ class ParkListCreateAPIView(APIView):
'city': 'location__city__iexact',
'continent': 'location__continent__iexact'
}
for param_name, filter_field in location_filters.items():
value = params.get(param_name)
if value:
qs = qs.filter(**{filter_field: value})
return qs
def _apply_park_attribute_filters(self, qs: QuerySet, params: dict) -> QuerySet:
@@ -264,7 +265,7 @@ class ParkListCreateAPIView(APIView):
status_filter = params.get("status")
if status_filter:
qs = qs.filter(status=status_filter)
return qs
def _apply_company_filters(self, qs: QuerySet, params: dict) -> QuerySet:
@@ -275,73 +276,59 @@ class ParkListCreateAPIView(APIView):
'property_owner_id': 'property_owner_id',
'property_owner_slug': 'property_owner__slug'
}
for param_name, filter_field in company_filters.items():
value = params.get(param_name)
if value:
qs = qs.filter(**{filter_field: value})
return qs
def _apply_rating_filters(self, qs: QuerySet, params: dict) -> QuerySet:
"""Apply rating-based filtering to the queryset."""
min_rating = params.get("min_rating")
if min_rating:
try:
with contextlib.suppress(ValueError, TypeError):
qs = qs.filter(average_rating__gte=float(min_rating))
except (ValueError, TypeError):
pass
max_rating = params.get("max_rating")
if max_rating:
try:
with contextlib.suppress(ValueError, TypeError):
qs = qs.filter(average_rating__lte=float(max_rating))
except (ValueError, TypeError):
pass
return qs
def _apply_ride_count_filters(self, qs: QuerySet, params: dict) -> QuerySet:
"""Apply ride count filtering to the queryset."""
min_ride_count = params.get("min_ride_count")
if min_ride_count:
try:
with contextlib.suppress(ValueError, TypeError):
qs = qs.filter(ride_count__gte=int(min_ride_count))
except (ValueError, TypeError):
pass
max_ride_count = params.get("max_ride_count")
if max_ride_count:
try:
with contextlib.suppress(ValueError, TypeError):
qs = qs.filter(ride_count__lte=int(max_ride_count))
except (ValueError, TypeError):
pass
return qs
def _apply_opening_year_filters(self, qs: QuerySet, params: dict) -> QuerySet:
"""Apply opening year filtering to the queryset."""
opening_year = params.get("opening_year")
if opening_year:
try:
with contextlib.suppress(ValueError, TypeError):
qs = qs.filter(opening_date__year=int(opening_year))
except (ValueError, TypeError):
pass
min_opening_year = params.get("min_opening_year")
if min_opening_year:
try:
with contextlib.suppress(ValueError, TypeError):
qs = qs.filter(opening_date__year__gte=int(min_opening_year))
except (ValueError, TypeError):
pass
max_opening_year = params.get("max_opening_year")
if max_opening_year:
try:
with contextlib.suppress(ValueError, TypeError):
qs = qs.filter(opening_date__year__lte=int(max_opening_year))
except (ValueError, TypeError):
pass
return qs
def _apply_roller_coaster_filters(self, qs: QuerySet, params: dict) -> QuerySet:
@@ -355,18 +342,14 @@ class ParkListCreateAPIView(APIView):
min_roller_coaster_count = params.get("min_roller_coaster_count")
if min_roller_coaster_count:
try:
with contextlib.suppress(ValueError, TypeError):
qs = qs.filter(coaster_count__gte=int(min_roller_coaster_count))
except (ValueError, TypeError):
pass
max_roller_coaster_count = params.get("max_roller_coaster_count")
if max_roller_coaster_count:
try:
with contextlib.suppress(ValueError, TypeError):
qs = qs.filter(coaster_count__lte=int(max_roller_coaster_count))
except (ValueError, TypeError):
pass
return qs
@extend_schema(
@@ -440,13 +423,13 @@ class ParkDetailAPIView(APIView):
def _get_park_or_404(self, identifier: str) -> Any:
if not MODELS_AVAILABLE:
raise NotFound(
(
"Park detail is not available because domain models "
"are not imported. Implement apps.parks.models.Park "
"to enable detail endpoints."
)
)
# Try to parse as integer ID first
try:
pk = int(identifier)
@@ -475,36 +458,36 @@ class ParkDetailAPIView(APIView):
summary="Get park full details",
description="""
Retrieve comprehensive park details including:
**Core Information:**
- Basic park details (name, slug, description, status)
- Opening/closing dates and operating season
- Size in acres and website URL
- Statistics (average rating, ride count, coaster count)
**Location Data:**
- Full address with coordinates
- City, state, country information
- Formatted address string
**Company Information:**
- Operating company details
- Property owner information (if different)
**Media:**
- All approved photos with Cloudflare variants
- Primary photo designation
- Banner and card image settings
**Related Content:**
- Park areas/themed sections
- Associated rides (summary)
**Lookup Methods:**
- By ID: `/api/v1/parks/123/`
- By current slug: `/api/v1/parks/cedar-point/`
- By historical slug: `/api/v1/parks/old-cedar-point-name/`
**No Query Parameters Required** - This endpoint returns full details by default.
""",
responses={
@@ -598,11 +581,11 @@ class FilterOptionsAPIView(APIView):
"""Return comprehensive filter options with Rich Choice Objects metadata."""
# Import Rich Choice registry
from apps.core.choices.registry import get_choices
# Always get static choice definitions from Rich Choice Objects (primary source)
park_types = get_choices('types', 'parks')
statuses = get_choices('statuses', 'parks')
# Convert Rich Choice Objects to frontend format with metadata
park_types_data = [
{
@@ -616,7 +599,7 @@ class FilterOptionsAPIView(APIView):
}
for choice in park_types
]
statuses_data = [
{
"value": choice.value,
@@ -629,12 +612,12 @@ class FilterOptionsAPIView(APIView):
}
for choice in statuses
]
# Get dynamic data from database if models are available
if MODELS_AVAILABLE:
# Add any dynamic data queries here
pass
return Response({
"park_types": park_types_data,
"statuses": statuses_data,
@@ -707,7 +690,7 @@ class FilterOptionsAPIView(APIView):
# Get rich choice objects from registry
park_types = get_choices('types', 'parks')
statuses = get_choices('statuses', 'parks')
# Convert Rich Choice Objects to frontend format with metadata
park_types_data = [
{
@@ -721,7 +704,7 @@ class FilterOptionsAPIView(APIView):
}
for choice in park_types
]
statuses_data = [
{
"value": choice.value,
@@ -1118,7 +1101,7 @@ class OperatorListAPIView(APIView):
}
for op in operators
]
return Response({
"results": data,
"count": len(data)

View File

@@ -13,27 +13,27 @@ if TYPE_CHECKING:
from django.core.exceptions import PermissionDenied
from django.utils import timezone
from drf_spectacular.utils import extend_schema_view, extend_schema
from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import extend_schema, extend_schema_view
from rest_framework import status
from rest_framework.decorators import action
from rest_framework.exceptions import ValidationError, NotFound
from rest_framework.permissions import IsAuthenticated, AllowAny
from rest_framework.exceptions import NotFound, ValidationError
from rest_framework.permissions import AllowAny, IsAuthenticated
from rest_framework.response import Response
from rest_framework.viewsets import ModelViewSet
from apps.rides.models.media import RidePhoto
from apps.rides.models import Ride
from apps.parks.models import Park
from apps.rides.services.media_service import RideMediaService
from apps.api.v1.rides.serializers import (
RidePhotoOutputSerializer,
RidePhotoCreateInputSerializer,
RidePhotoUpdateInputSerializer,
RidePhotoListOutputSerializer,
RidePhotoApprovalInputSerializer,
RidePhotoCreateInputSerializer,
RidePhotoListOutputSerializer,
RidePhotoOutputSerializer,
RidePhotoStatsOutputSerializer,
RidePhotoUpdateInputSerializer,
)
from apps.parks.models import Park
from apps.rides.models import Ride
from apps.rides.models.media import RidePhoto
from apps.rides.services.media_service import RideMediaService
logger = logging.getLogger(__name__)
@@ -116,10 +116,7 @@ class RidePhotoViewSet(ModelViewSet):
def get_permissions(self):
"""Set permissions based on action."""
if self.action in ['list', 'retrieve', 'stats']:
permission_classes = [AllowAny]
else:
permission_classes = [IsAuthenticated]
permission_classes = [AllowAny] if self.action in ['list', 'retrieve', 'stats'] else [IsAuthenticated]
return [permission() for permission in permission_classes]
def get_queryset(self):
@@ -131,7 +128,7 @@ class RidePhotoViewSet(ModelViewSet):
# Filter by park and ride from URL kwargs
park_slug = self.kwargs.get("park_slug")
ride_slug = self.kwargs.get("ride_slug")
if park_slug and ride_slug:
try:
park, _ = Park.get_by_slug(park_slug)
@@ -158,7 +155,7 @@ class RidePhotoViewSet(ModelViewSet):
"""Create a new ride photo using RideMediaService."""
park_slug = self.kwargs.get("park_slug")
ride_slug = self.kwargs.get("ride_slug")
if not park_slug or not ride_slug:
raise ValidationError("Park and ride slugs are required")
@@ -185,7 +182,7 @@ class RidePhotoViewSet(ModelViewSet):
# Set the instance for the serializer response
serializer.instance = photo
logger.info(f"Created ride photo {photo.id} for ride {ride.name} by user {self.request.user.username}")
except Exception as e:
@@ -249,7 +246,7 @@ class RidePhotoViewSet(ModelViewSet):
RideMediaService.delete_photo(
instance, deleted_by=self.request.user
)
logger.info(f"Deleted ride photo {instance.id} by user {self.request.user.username}")
except Exception as e:
logger.error(f"Error deleting ride photo: {e}")
@@ -331,7 +328,7 @@ class RidePhotoViewSet(ModelViewSet):
validated_data = getattr(serializer, "validated_data", {})
photo_ids = validated_data.get("photo_ids")
approve = validated_data.get("approve")
park_slug = self.kwargs.get("park_slug")
ride_slug = self.kwargs.get("ride_slug")
@@ -381,7 +378,7 @@ class RidePhotoViewSet(ModelViewSet):
"""Get photo statistics for the ride."""
park_slug = self.kwargs.get("park_slug")
ride_slug = self.kwargs.get("ride_slug")
if not park_slug or not ride_slug:
return Response(
{"error": "Park and ride slugs are required"},
@@ -431,7 +428,7 @@ class RidePhotoViewSet(ModelViewSet):
"""Save a Cloudflare image as a ride photo after direct upload to Cloudflare."""
park_slug = self.kwargs.get("park_slug")
ride_slug = self.kwargs.get("ride_slug")
if not park_slug or not ride_slug:
return Response(
{"error": "Park and ride slugs are required"},

View File

@@ -12,28 +12,28 @@ if TYPE_CHECKING:
pass
from django.core.exceptions import PermissionDenied
from django.db.models import Avg, Count, Q
from django.db.models import Avg
from django.utils import timezone
from drf_spectacular.utils import extend_schema_view, extend_schema
from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import extend_schema, extend_schema_view
from rest_framework import status
from rest_framework.decorators import action
from rest_framework.exceptions import ValidationError, NotFound
from rest_framework.permissions import IsAuthenticated, AllowAny
from rest_framework.exceptions import NotFound, ValidationError
from rest_framework.permissions import AllowAny, IsAuthenticated
from rest_framework.response import Response
from rest_framework.viewsets import ModelViewSet
from apps.rides.models.reviews import RideReview
from apps.rides.models import Ride
from apps.parks.models import Park
from apps.api.v1.serializers.ride_reviews import (
RideReviewOutputSerializer,
RideReviewCreateInputSerializer,
RideReviewUpdateInputSerializer,
RideReviewListOutputSerializer,
RideReviewStatsOutputSerializer,
RideReviewModerationInputSerializer,
RideReviewOutputSerializer,
RideReviewStatsOutputSerializer,
RideReviewUpdateInputSerializer,
)
from apps.parks.models import Park
from apps.rides.models import Ride
from apps.rides.models.reviews import RideReview
logger = logging.getLogger(__name__)
@@ -115,10 +115,7 @@ class RideReviewViewSet(ModelViewSet):
def get_permissions(self):
"""Set permissions based on action."""
if self.action in ['list', 'retrieve', 'stats']:
permission_classes = [AllowAny]
else:
permission_classes = [IsAuthenticated]
permission_classes = [AllowAny] if self.action in ['list', 'retrieve', 'stats'] else [IsAuthenticated]
return [permission() for permission in permission_classes]
def get_queryset(self):
@@ -130,7 +127,7 @@ class RideReviewViewSet(ModelViewSet):
# Filter by park and ride from URL kwargs
park_slug = self.kwargs.get("park_slug")
ride_slug = self.kwargs.get("ride_slug")
if park_slug and ride_slug:
try:
park, _ = Park.get_by_slug(park_slug)
@@ -141,7 +138,7 @@ class RideReviewViewSet(ModelViewSet):
return queryset.none()
# Filter published reviews for non-staff users
if not (hasattr(self.request, 'user') and
if not (hasattr(self.request, 'user') and
getattr(self.request.user, 'is_staff', False)):
queryset = queryset.filter(is_published=True)
@@ -162,7 +159,7 @@ class RideReviewViewSet(ModelViewSet):
"""Create a new ride review."""
park_slug = self.kwargs.get("park_slug")
ride_slug = self.kwargs.get("ride_slug")
if not park_slug or not ride_slug:
raise ValidationError("Park and ride slugs are required")
@@ -185,7 +182,7 @@ class RideReviewViewSet(ModelViewSet):
user=self.request.user,
is_published=True # Auto-publish for now, can add moderation later
)
logger.info(f"Created ride review {review.id} for ride {ride.name} by user {self.request.user.username}")
except Exception as e:
@@ -241,7 +238,7 @@ class RideReviewViewSet(ModelViewSet):
"""Get review statistics for the ride."""
park_slug = self.kwargs.get("park_slug")
ride_slug = self.kwargs.get("ride_slug")
if not park_slug or not ride_slug:
return Response(
{"error": "Park and ride slugs are required"},
@@ -265,19 +262,19 @@ class RideReviewViewSet(ModelViewSet):
try:
# Get review statistics
reviews = RideReview.objects.filter(ride=ride, is_published=True)
total_reviews = reviews.count()
published_reviews = total_reviews # Since we're filtering published
pending_reviews = RideReview.objects.filter(ride=ride, is_published=False).count()
# Calculate average rating
avg_rating = reviews.aggregate(avg_rating=Avg('rating'))['avg_rating']
# Get rating distribution
rating_distribution = {}
for i in range(1, 11):
rating_distribution[str(i)] = reviews.filter(rating=i).count()
# Get recent reviews count (last 30 days)
from datetime import timedelta
thirty_days_ago = timezone.now() - timedelta(days=30)

View File

@@ -5,12 +5,13 @@ This module contains serializers for park-specific media functionality.
Enhanced from rogue implementation to maintain full feature parity.
"""
from rest_framework import serializers
from drf_spectacular.utils import (
OpenApiExample,
extend_schema_field,
extend_schema_serializer,
OpenApiExample,
)
from rest_framework import serializers
from apps.parks.models import Park, ParkPhoto
@@ -235,7 +236,7 @@ class HybridParkSerializer(serializers.ModelSerializer):
Enhanced serializer for hybrid filtering strategy.
Includes all filterable fields for client-side filtering.
"""
# Location fields from related ParkLocation
city = serializers.SerializerMethodField()
state = serializers.SerializerMethodField()
@@ -243,19 +244,19 @@ class HybridParkSerializer(serializers.ModelSerializer):
continent = serializers.SerializerMethodField()
latitude = serializers.SerializerMethodField()
longitude = serializers.SerializerMethodField()
# Company fields
operator_name = serializers.CharField(source="operator.name", read_only=True)
property_owner_name = serializers.CharField(source="property_owner.name", read_only=True, allow_null=True)
# Image URLs for display
banner_image_url = serializers.SerializerMethodField()
card_image_url = serializers.SerializerMethodField()
# Computed fields for filtering
opening_year = serializers.IntegerField(read_only=True)
search_text = serializers.CharField(read_only=True)
@extend_schema_field(serializers.CharField(allow_null=True))
def get_city(self, obj):
"""Get city from related location."""
@@ -263,7 +264,7 @@ class HybridParkSerializer(serializers.ModelSerializer):
return obj.location.city if hasattr(obj, 'location') and obj.location else None
except AttributeError:
return None
@extend_schema_field(serializers.CharField(allow_null=True))
def get_state(self, obj):
"""Get state from related location."""
@@ -271,7 +272,7 @@ class HybridParkSerializer(serializers.ModelSerializer):
return obj.location.state if hasattr(obj, 'location') and obj.location else None
except AttributeError:
return None
@extend_schema_field(serializers.CharField(allow_null=True))
def get_country(self, obj):
"""Get country from related location."""
@@ -279,7 +280,7 @@ class HybridParkSerializer(serializers.ModelSerializer):
return obj.location.country if hasattr(obj, 'location') and obj.location else None
except AttributeError:
return None
@extend_schema_field(serializers.CharField(allow_null=True))
def get_continent(self, obj):
"""Get continent from related location."""
@@ -287,7 +288,7 @@ class HybridParkSerializer(serializers.ModelSerializer):
return obj.location.continent if hasattr(obj, 'location') and obj.location else None
except AttributeError:
return None
@extend_schema_field(serializers.FloatField(allow_null=True))
def get_latitude(self, obj):
"""Get latitude from related location."""
@@ -297,7 +298,7 @@ class HybridParkSerializer(serializers.ModelSerializer):
return None
except (AttributeError, IndexError, TypeError):
return None
@extend_schema_field(serializers.FloatField(allow_null=True))
def get_longitude(self, obj):
"""Get longitude from related location."""
@@ -307,14 +308,14 @@ class HybridParkSerializer(serializers.ModelSerializer):
return None
except (AttributeError, IndexError, TypeError):
return None
@extend_schema_field(serializers.URLField(allow_null=True))
def get_banner_image_url(self, obj):
"""Get banner image URL."""
if obj.banner_image and obj.banner_image.image:
return obj.banner_image.image.url
return None
@extend_schema_field(serializers.URLField(allow_null=True))
def get_card_image_url(self, obj):
"""Get card image URL."""
@@ -332,42 +333,42 @@ class HybridParkSerializer(serializers.ModelSerializer):
"description",
"status",
"park_type",
# Dates and computed fields
"opening_date",
"closing_date",
"opening_year",
"operating_season",
# Location fields
"city",
"state",
"state",
"country",
"continent",
"latitude",
"longitude",
# Company relationships
"operator_name",
"property_owner_name",
# Statistics
"size_acres",
"average_rating",
"ride_count",
"coaster_count",
# Images
"banner_image_url",
"card_image_url",
# URLs
"website",
"url",
# Computed fields for filtering
"search_text",
# Metadata
"created_at",
"updated_at",

View File

@@ -6,28 +6,10 @@ intentionally expansive to match the rides API functionality and provide
complete feature parity for parks management.
"""
from django.urls import path, include
from django.urls import include, path
from rest_framework.routers import DefaultRouter
from .park_views import (
ParkListCreateAPIView,
ParkDetailAPIView,
FilterOptionsAPIView,
CompanySearchAPIView,
ParkSearchSuggestionsAPIView,
ParkImageSettingsAPIView,
OperatorListAPIView,
)
from .park_rides_views import (
ParkRidesListAPIView,
ParkRideDetailAPIView,
ParkComprehensiveDetailAPIView,
)
from apps.parks.views import location_search, reverse_geocode
from .views import ParkPhotoViewSet, HybridParkAPIView, ParkFilterMetadataAPIView
from .ride_photos_views import RidePhotoViewSet
from .ride_photos_views import RidePhotoViewSet
from .ride_reviews_views import RideReviewViewSet
from apps.parks.views_roadtrip import (
CreateTripView,
FindParksAlongRouteView,
@@ -35,6 +17,24 @@ from apps.parks.views_roadtrip import (
ParkDistanceCalculatorView,
)
from .park_rides_views import (
ParkComprehensiveDetailAPIView,
ParkRideDetailAPIView,
ParkRidesListAPIView,
)
from .park_views import (
CompanySearchAPIView,
FilterOptionsAPIView,
OperatorListAPIView,
ParkDetailAPIView,
ParkImageSettingsAPIView,
ParkListCreateAPIView,
ParkSearchSuggestionsAPIView,
)
from .ride_photos_views import RidePhotoViewSet
from .ride_reviews_views import RideReviewViewSet
from .views import HybridParkAPIView, ParkFilterMetadataAPIView, ParkPhotoViewSet
# Create router for nested photo endpoints
router = DefaultRouter()
router.register(r"", ParkPhotoViewSet, basename="park-photo")
@@ -42,13 +42,12 @@ router.register(r"", ParkPhotoViewSet, basename="park-photo")
# Create routers for nested ride endpoints
ride_photos_router = DefaultRouter()
ride_photos_router.register(r"", RidePhotoViewSet, basename="ride-photo")
from .ride_reviews_views import RideReviewViewSet
ride_reviews_router = DefaultRouter()
ride_reviews_router.register(r"", RideReviewViewSet, basename="ride-review")
from .park_reviews_views import ParkReviewViewSet
from .history_views import ParkHistoryViewSet, RideHistoryViewSet
from .park_reviews_views import ParkReviewViewSet
# Create routers for nested park endpoints
reviews_router = DefaultRouter()
@@ -60,11 +59,11 @@ app_name = "api_v1_parks"
urlpatterns = [
# Core list/create endpoints
path("", ParkListCreateAPIView.as_view(), name="park-list-create"),
# Hybrid filtering endpoints
path("hybrid/", HybridParkAPIView.as_view(), name="park-hybrid-list"),
path("hybrid/filter-metadata/", ParkFilterMetadataAPIView.as_view(), name="park-hybrid-filter-metadata"),
# Filter options
path("filter-options/", FilterOptionsAPIView.as_view(), name="park-filter-options"),
# Autocomplete / suggestion endpoints
@@ -80,14 +79,14 @@ urlpatterns = [
),
# Detail and action endpoints - supports both ID and slug
path("<str:pk>/", ParkDetailAPIView.as_view(), name="park-detail"),
# Park rides endpoints
path("<str:park_slug>/rides/", ParkRidesListAPIView.as_view(), name="park-rides-list"),
path("<str:park_slug>/rides/<str:ride_slug>/", ParkRideDetailAPIView.as_view(), name="park-ride-detail"),
# Comprehensive park detail endpoint with rides summary
path("<str:park_slug>/detail/", ParkComprehensiveDetailAPIView.as_view(), name="park-comprehensive-detail"),
# Park image settings endpoint
path(
"<int:pk>/image-settings/",
@@ -96,21 +95,21 @@ urlpatterns = [
),
# Park photo endpoints - domain-specific photo management
path("<str:park_pk>/photos/", include(router.urls)),
# Nested ride photo endpoints - photos for specific rides within parks
path("<str:park_slug>/rides/<str:ride_slug>/photos/", include(ride_photos_router.urls)),
# Nested ride review endpoints - reviews for specific rides within parks
path("<str:park_slug>/rides/<str:ride_slug>/reviews/", include(ride_reviews_router.urls)),
# Nested ride review endpoints - reviews for specific rides within parks
path("<str:park_slug>/rides/<str:ride_slug>/reviews/", include(ride_reviews_router.urls)),
# Ride History
path("<str:park_slug>/rides/<str:ride_slug>/history/", RideHistoryViewSet.as_view({'get': 'list'}), name="ride-history"),
# Park Reviews
path("<str:park_slug>/reviews/", include(reviews_router.urls)),
# Park History
path("<str:park_slug>/history/", ParkHistoryViewSet.as_view({'get': 'list'}), name="park-history"),

View File

@@ -26,14 +26,13 @@ from rest_framework.response import Response
from rest_framework.views import APIView
from rest_framework.viewsets import ModelViewSet
from apps.core.decorators.cache_decorators import cache_api_response
from apps.core.exceptions import (
NotFoundError,
PermissionDeniedError,
ServiceError,
ValidationException,
)
from apps.core.utils.error_handling import ErrorHandler
from apps.core.decorators.cache_decorators import cache_api_response
from apps.parks.models import Park, ParkPhoto
from apps.parks.services import ParkMediaService
from apps.parks.services.hybrid_loader import smart_park_loader
@@ -130,10 +129,7 @@ class ParkPhotoViewSet(ModelViewSet):
def get_permissions(self):
"""Set permissions based on action."""
if self.action in ["list", "retrieve", "stats"]:
permission_classes = [AllowAny]
else:
permission_classes = [IsAuthenticated]
permission_classes = [AllowAny] if self.action in ["list", "retrieve", "stats"] else [IsAuthenticated]
return [permission() for permission in permission_classes]
def get_queryset(self): # type: ignore[override]
@@ -171,11 +167,8 @@ class ParkPhotoViewSet(ModelViewSet):
raise ValidationError("Park ID/Slug is required")
try:
if str(park_id).isdigit():
park = Park.objects.get(pk=park_id)
else:
park = Park.objects.get(slug=park_id)
park = Park.objects.get(pk=park_id) if str(park_id).isdigit() else Park.objects.get(slug=park_id)
# Use real park ID
park_id = park.id
except Park.DoesNotExist:
@@ -398,10 +391,7 @@ class ParkPhotoViewSet(ModelViewSet):
park = None
if park_pk:
try:
if str(park_pk).isdigit():
park = Park.objects.get(pk=park_pk)
else:
park = Park.objects.get(slug=park_pk)
park = Park.objects.get(pk=park_pk) if str(park_pk).isdigit() else Park.objects.get(slug=park_pk)
except Park.DoesNotExist:
return ErrorHandler.handle_api_error(
NotFoundError(f"Park with id/slug {park_pk} not found"),
@@ -490,10 +480,7 @@ class ParkPhotoViewSet(ModelViewSet):
)
try:
if str(park_pk).isdigit():
park = Park.objects.get(pk=park_pk)
else:
park = Park.objects.get(slug=park_pk)
park = Park.objects.get(pk=park_pk) if str(park_pk).isdigit() else Park.objects.get(slug=park_pk)
except Park.DoesNotExist:
return Response(
{"error": "Park not found"},
@@ -509,9 +496,9 @@ class ParkPhotoViewSet(ModelViewSet):
try:
# Import CloudflareImage model and service
from django.utils import timezone
from django_cloudflareimages_toolkit.models import CloudflareImage
from django_cloudflareimages_toolkit.services import CloudflareImagesService
from django.utils import timezone
# Always fetch the latest image data from Cloudflare API
# Get image details from Cloudflare API

View File

@@ -0,0 +1,12 @@
"""URL routes for Company CRUD API."""
from django.urls import path
from .company_views import CompanyDetailAPIView, CompanyListCreateAPIView
app_name = "api_v1_companies"
urlpatterns = [
path("", CompanyListCreateAPIView.as_view(), name="company-list-create"),
path("<int:pk>/", CompanyDetailAPIView.as_view(), name="company-detail"),
]

View File

@@ -0,0 +1,167 @@
"""
Company API views for ThrillWiki API v1.
This module implements CRUD endpoints for company management:
- List / Create: GET /companies/ POST /companies/
- Retrieve / Update / Delete: GET /companies/{id}/ PATCH/PUT/DELETE
"""
from django.db.models import Q
from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import OpenApiParameter, extend_schema
from rest_framework import permissions, status
from rest_framework.exceptions import NotFound
from rest_framework.pagination import PageNumberPagination
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.views import APIView
from apps.api.v1.serializers.companies import (
CompanyCreateInputSerializer,
CompanyDetailOutputSerializer,
CompanyUpdateInputSerializer,
)
try:
from apps.rides.models.company import Company
MODELS_AVAILABLE = True
except ImportError:
Company = None
MODELS_AVAILABLE = False
class StandardResultsSetPagination(PageNumberPagination):
page_size = 20
page_size_query_param = "page_size"
max_page_size = 100
class CompanyListCreateAPIView(APIView):
"""List and create companies."""
permission_classes = [permissions.AllowAny]
@extend_schema(
summary="List all companies",
description="List companies with optional search and role filtering.",
parameters=[
OpenApiParameter(name="search", location=OpenApiParameter.QUERY, type=OpenApiTypes.STR),
OpenApiParameter(name="role", location=OpenApiParameter.QUERY, type=OpenApiTypes.STR),
OpenApiParameter(name="page", location=OpenApiParameter.QUERY, type=OpenApiTypes.INT),
OpenApiParameter(name="page_size", location=OpenApiParameter.QUERY, type=OpenApiTypes.INT),
],
responses={200: CompanyDetailOutputSerializer(many=True)},
tags=["Companies"],
)
def get(self, request: Request) -> Response:
if not MODELS_AVAILABLE:
return Response(
{"detail": "Company models not available"},
status=status.HTTP_501_NOT_IMPLEMENTED,
)
qs = Company.objects.all().order_by("name")
# Search filter
search = request.query_params.get("search", "")
if search:
qs = qs.filter(
Q(name__icontains=search) | Q(description__icontains=search)
)
# Role filter
role = request.query_params.get("role", "")
if role:
qs = qs.filter(roles__contains=[role])
paginator = StandardResultsSetPagination()
page = paginator.paginate_queryset(qs, request)
serializer = CompanyDetailOutputSerializer(page, many=True)
return paginator.get_paginated_response(serializer.data)
@extend_schema(
summary="Create a new company",
description="Create a new company with the given details.",
request=CompanyCreateInputSerializer,
responses={201: CompanyDetailOutputSerializer()},
tags=["Companies"],
)
def post(self, request: Request) -> Response:
if not MODELS_AVAILABLE:
return Response(
{"detail": "Company models not available"},
status=status.HTTP_501_NOT_IMPLEMENTED,
)
serializer_in = CompanyCreateInputSerializer(data=request.data)
serializer_in.is_valid(raise_exception=True)
validated = serializer_in.validated_data
company = Company.objects.create(
name=validated["name"],
roles=validated["roles"],
description=validated.get("description", ""),
website=validated.get("website", ""),
founded_date=validated.get("founded_date"),
)
serializer = CompanyDetailOutputSerializer(company)
return Response(serializer.data, status=status.HTTP_201_CREATED)
class CompanyDetailAPIView(APIView):
"""Retrieve, update, and delete a company."""
permission_classes = [permissions.AllowAny]
def _get_company_or_404(self, pk: int) -> "Company":
if not MODELS_AVAILABLE:
raise NotFound("Company models not available")
try:
return Company.objects.get(pk=pk)
except Company.DoesNotExist:
raise NotFound("Company not found")
@extend_schema(
summary="Retrieve a company",
description="Get detailed information about a specific company.",
responses={200: CompanyDetailOutputSerializer()},
tags=["Companies"],
)
def get(self, request: Request, pk: int) -> Response:
company = self._get_company_or_404(pk)
serializer = CompanyDetailOutputSerializer(company)
return Response(serializer.data)
@extend_schema(
summary="Update a company",
description="Update a company (partial update supported).",
request=CompanyUpdateInputSerializer,
responses={200: CompanyDetailOutputSerializer()},
tags=["Companies"],
)
def patch(self, request: Request, pk: int) -> Response:
company = self._get_company_or_404(pk)
serializer_in = CompanyUpdateInputSerializer(data=request.data, partial=True)
serializer_in.is_valid(raise_exception=True)
for field, value in serializer_in.validated_data.items():
setattr(company, field, value)
company.save()
serializer = CompanyDetailOutputSerializer(company)
return Response(serializer.data)
def put(self, request: Request, pk: int) -> Response:
return self.patch(request, pk)
@extend_schema(
summary="Delete a company",
description="Delete a company.",
responses={204: None},
tags=["Companies"],
)
def delete(self, request: Request, pk: int) -> Response:
company = self._get_company_or_404(pk)
company.delete()
return Response(status=status.HTTP_204_NO_CONTENT)

View File

@@ -11,17 +11,17 @@ This file exposes comprehensive endpoints for ride model management:
from django.urls import path
from .views import (
RideModelListCreateAPIView,
RideModelDetailAPIView,
RideModelSearchAPIView,
RideModelFilterOptionsAPIView,
RideModelStatsAPIView,
RideModelVariantListCreateAPIView,
RideModelVariantDetailAPIView,
RideModelTechnicalSpecListCreateAPIView,
RideModelTechnicalSpecDetailAPIView,
RideModelPhotoListCreateAPIView,
RideModelListCreateAPIView,
RideModelPhotoDetailAPIView,
RideModelPhotoListCreateAPIView,
RideModelSearchAPIView,
RideModelStatsAPIView,
RideModelTechnicalSpecDetailAPIView,
RideModelTechnicalSpecListCreateAPIView,
RideModelVariantDetailAPIView,
RideModelVariantListCreateAPIView,
)
app_name = "api_v1_ride_models"

View File

@@ -12,40 +12,40 @@ This module implements comprehensive endpoints for ride model management:
- Photos: CRUD operations for ride model photos
"""
from typing import Any
from datetime import timedelta
from typing import Any
from rest_framework import status, permissions
from rest_framework.views import APIView
from django.db.models import Count, Q
from django.utils import timezone
from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import OpenApiParameter, extend_schema
from rest_framework import permissions, status
from rest_framework.exceptions import NotFound, ValidationError
from rest_framework.pagination import PageNumberPagination
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.pagination import PageNumberPagination
from rest_framework.exceptions import NotFound, ValidationError
from drf_spectacular.utils import extend_schema, OpenApiParameter
from drf_spectacular.types import OpenApiTypes
from django.db.models import Q, Count
from django.utils import timezone
from rest_framework.views import APIView
# Import serializers
from apps.api.v1.serializers.ride_models import (
RideModelListOutputSerializer,
RideModelDetailOutputSerializer,
RideModelCreateInputSerializer,
RideModelUpdateInputSerializer,
RideModelDetailOutputSerializer,
RideModelFilterInputSerializer,
RideModelVariantOutputSerializer,
RideModelVariantCreateInputSerializer,
RideModelVariantUpdateInputSerializer,
RideModelListOutputSerializer,
RideModelStatsOutputSerializer,
RideModelUpdateInputSerializer,
RideModelVariantCreateInputSerializer,
RideModelVariantOutputSerializer,
RideModelVariantUpdateInputSerializer,
)
# Attempt to import models; fall back gracefully if not present
try:
from apps.rides.models import (
RideModel,
RideModelVariant,
RideModelPhoto,
RideModelTechnicalSpec,
RideModelVariant,
)
from apps.rides.models.company import Company
@@ -54,12 +54,12 @@ except ImportError:
try:
# Try alternative import path
from apps.rides.models.rides import (
Company,
RideModel,
RideModelVariant,
RideModelPhoto,
RideModelTechnicalSpec,
RideModelVariant,
)
from apps.rides.models.rides import Company
MODELS_AVAILABLE = True
except ImportError:
@@ -486,14 +486,14 @@ class RideModelFilterOptionsAPIView(APIView):
"""Return filter options for ride models with Rich Choice Objects metadata."""
# Import Rich Choice registry
from apps.core.choices.registry import get_choices
if not MODELS_AVAILABLE:
# Use Rich Choice Objects for fallback options
try:
# Get rich choice objects from registry
categories = get_choices('categories', 'rides')
target_markets = get_choices('target_markets', 'rides')
# Convert Rich Choice Objects to frontend format with metadata
categories_data = [
{
@@ -507,7 +507,7 @@ class RideModelFilterOptionsAPIView(APIView):
}
for choice in categories
]
target_markets_data = [
{
"value": choice.value,
@@ -520,7 +520,7 @@ class RideModelFilterOptionsAPIView(APIView):
}
for choice in target_markets
]
except Exception:
# Ultimate fallback with basic structure
categories_data = [
@@ -538,7 +538,7 @@ class RideModelFilterOptionsAPIView(APIView):
{"value": "KIDDIE", "label": "Kiddie", "description": "Designed for young children", "color": "pink", "icon": "kiddie", "css_class": "bg-pink-100 text-pink-800", "sort_order": 4},
{"value": "ALL_AGES", "label": "All Ages", "description": "Enjoyable for all age groups", "color": "blue", "icon": "all-ages", "css_class": "bg-blue-100 text-blue-800", "sort_order": 5},
]
return Response({
"categories": categories_data,
"target_markets": target_markets_data,
@@ -557,11 +557,11 @@ class RideModelFilterOptionsAPIView(APIView):
# Get static choice definitions from Rich Choice Objects (primary source)
# Get dynamic data from database queries
# Get rich choice objects from registry
categories = get_choices('categories', 'rides')
target_markets = get_choices('target_markets', 'rides')
# Convert Rich Choice Objects to frontend format with metadata
categories_data = [
{
@@ -575,7 +575,7 @@ class RideModelFilterOptionsAPIView(APIView):
}
for choice in categories
]
target_markets_data = [
{
"value": choice.value,

View File

@@ -5,23 +5,25 @@ This module contains ride photo ViewSet following the parks pattern for domain c
Enhanced from centralized media API to provide domain-specific ride photo management.
"""
from .serializers import (
RidePhotoOutputSerializer,
RidePhotoCreateInputSerializer,
RidePhotoUpdateInputSerializer,
RidePhotoListOutputSerializer,
RidePhotoApprovalInputSerializer,
RidePhotoStatsOutputSerializer,
)
from typing import TYPE_CHECKING
from .serializers import (
RidePhotoApprovalInputSerializer,
RidePhotoCreateInputSerializer,
RidePhotoListOutputSerializer,
RidePhotoOutputSerializer,
RidePhotoStatsOutputSerializer,
RidePhotoUpdateInputSerializer,
)
if TYPE_CHECKING:
pass
import logging
from django.contrib.auth import get_user_model
from django.core.exceptions import PermissionDenied
from drf_spectacular.utils import extend_schema_view, extend_schema
from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import extend_schema, extend_schema_view
from rest_framework import status
from rest_framework.decorators import action
from rest_framework.exceptions import ValidationError
@@ -29,9 +31,8 @@ from rest_framework.permissions import IsAuthenticated
from rest_framework.response import Response
from rest_framework.viewsets import ModelViewSet
from apps.rides.models import RidePhoto, Ride
from apps.rides.models import Ride, RidePhoto
from apps.rides.services.media_service import RideMediaService
from django.contrib.auth import get_user_model
UserModel = get_user_model()
@@ -460,9 +461,9 @@ class RidePhotoViewSet(ModelViewSet):
try:
# Import CloudflareImage model and service
from django.utils import timezone
from django_cloudflareimages_toolkit.models import CloudflareImage
from django_cloudflareimages_toolkit.services import CloudflareImagesService
from django.utils import timezone
# Always fetch the latest image data from Cloudflare API
try:

View File

@@ -4,12 +4,13 @@ Ride media serializers for ThrillWiki API v1.
This module contains serializers for ride-specific media functionality.
"""
from rest_framework import serializers
from drf_spectacular.utils import (
OpenApiExample,
extend_schema_field,
extend_schema_serializer,
OpenApiExample,
)
from rest_framework import serializers
from apps.rides.models import Ride, RidePhoto
@@ -267,33 +268,33 @@ class HybridRideSerializer(serializers.ModelSerializer):
Enhanced serializer for hybrid filtering strategy.
Includes all filterable fields for client-side filtering.
"""
# Park fields
park_name = serializers.CharField(source="park.name", read_only=True)
park_slug = serializers.CharField(source="park.slug", read_only=True)
# Park location fields
park_city = serializers.SerializerMethodField()
park_state = serializers.SerializerMethodField()
park_country = serializers.SerializerMethodField()
# Park area fields
park_area_name = serializers.CharField(source="park_area.name", read_only=True, allow_null=True)
park_area_slug = serializers.CharField(source="park_area.slug", read_only=True, allow_null=True)
# Company fields
manufacturer_name = serializers.CharField(source="manufacturer.name", read_only=True, allow_null=True)
manufacturer_slug = serializers.CharField(source="manufacturer.slug", read_only=True, allow_null=True)
designer_name = serializers.CharField(source="designer.name", read_only=True, allow_null=True)
designer_slug = serializers.CharField(source="designer.slug", read_only=True, allow_null=True)
# Ride model fields
ride_model_name = serializers.CharField(source="ride_model.name", read_only=True, allow_null=True)
ride_model_slug = serializers.CharField(source="ride_model.slug", read_only=True, allow_null=True)
ride_model_category = serializers.CharField(source="ride_model.category", read_only=True, allow_null=True)
ride_model_manufacturer_name = serializers.CharField(source="ride_model.manufacturer.name", read_only=True, allow_null=True)
ride_model_manufacturer_slug = serializers.CharField(source="ride_model.manufacturer.slug", read_only=True, allow_null=True)
# Roller coaster stats fields
coaster_height_ft = serializers.SerializerMethodField()
coaster_length_ft = serializers.SerializerMethodField()
@@ -309,15 +310,15 @@ class HybridRideSerializer(serializers.ModelSerializer):
coaster_trains_count = serializers.SerializerMethodField()
coaster_cars_per_train = serializers.SerializerMethodField()
coaster_seats_per_car = serializers.SerializerMethodField()
# Image URLs for display
banner_image_url = serializers.SerializerMethodField()
card_image_url = serializers.SerializerMethodField()
# Computed fields for filtering
opening_year = serializers.IntegerField(read_only=True)
search_text = serializers.CharField(read_only=True)
@extend_schema_field(serializers.CharField(allow_null=True))
def get_park_city(self, obj):
"""Get city from park location."""
@@ -327,7 +328,7 @@ class HybridRideSerializer(serializers.ModelSerializer):
return None
except AttributeError:
return None
@extend_schema_field(serializers.CharField(allow_null=True))
def get_park_state(self, obj):
"""Get state from park location."""
@@ -337,7 +338,7 @@ class HybridRideSerializer(serializers.ModelSerializer):
return None
except AttributeError:
return None
@extend_schema_field(serializers.CharField(allow_null=True))
def get_park_country(self, obj):
"""Get country from park location."""
@@ -347,7 +348,7 @@ class HybridRideSerializer(serializers.ModelSerializer):
return None
except AttributeError:
return None
@extend_schema_field(serializers.FloatField(allow_null=True))
def get_coaster_height_ft(self, obj):
"""Get roller coaster height."""
@@ -357,7 +358,7 @@ class HybridRideSerializer(serializers.ModelSerializer):
return None
except (AttributeError, TypeError):
return None
@extend_schema_field(serializers.FloatField(allow_null=True))
def get_coaster_length_ft(self, obj):
"""Get roller coaster length."""
@@ -367,7 +368,7 @@ class HybridRideSerializer(serializers.ModelSerializer):
return None
except (AttributeError, TypeError):
return None
@extend_schema_field(serializers.FloatField(allow_null=True))
def get_coaster_speed_mph(self, obj):
"""Get roller coaster speed."""
@@ -377,7 +378,7 @@ class HybridRideSerializer(serializers.ModelSerializer):
return None
except (AttributeError, TypeError):
return None
@extend_schema_field(serializers.IntegerField(allow_null=True))
def get_coaster_inversions(self, obj):
"""Get roller coaster inversions."""
@@ -387,7 +388,7 @@ class HybridRideSerializer(serializers.ModelSerializer):
return None
except AttributeError:
return None
@extend_schema_field(serializers.IntegerField(allow_null=True))
def get_coaster_ride_time_seconds(self, obj):
"""Get roller coaster ride time."""
@@ -397,7 +398,7 @@ class HybridRideSerializer(serializers.ModelSerializer):
return None
except AttributeError:
return None
@extend_schema_field(serializers.CharField(allow_null=True))
def get_coaster_track_type(self, obj):
"""Get roller coaster track type."""
@@ -407,7 +408,7 @@ class HybridRideSerializer(serializers.ModelSerializer):
return None
except AttributeError:
return None
@extend_schema_field(serializers.CharField(allow_null=True))
def get_coaster_track_material(self, obj):
"""Get roller coaster track material."""
@@ -417,7 +418,7 @@ class HybridRideSerializer(serializers.ModelSerializer):
return None
except AttributeError:
return None
@extend_schema_field(serializers.CharField(allow_null=True))
def get_coaster_roller_coaster_type(self, obj):
"""Get roller coaster type."""
@@ -427,7 +428,7 @@ class HybridRideSerializer(serializers.ModelSerializer):
return None
except AttributeError:
return None
@extend_schema_field(serializers.FloatField(allow_null=True))
def get_coaster_max_drop_height_ft(self, obj):
"""Get roller coaster max drop height."""
@@ -437,7 +438,7 @@ class HybridRideSerializer(serializers.ModelSerializer):
return None
except (AttributeError, TypeError):
return None
@extend_schema_field(serializers.CharField(allow_null=True))
def get_coaster_propulsion_system(self, obj):
"""Get roller coaster propulsion system."""
@@ -447,7 +448,7 @@ class HybridRideSerializer(serializers.ModelSerializer):
return None
except AttributeError:
return None
@extend_schema_field(serializers.CharField(allow_null=True))
def get_coaster_train_style(self, obj):
"""Get roller coaster train style."""
@@ -457,7 +458,7 @@ class HybridRideSerializer(serializers.ModelSerializer):
return None
except AttributeError:
return None
@extend_schema_field(serializers.IntegerField(allow_null=True))
def get_coaster_trains_count(self, obj):
"""Get roller coaster trains count."""
@@ -467,7 +468,7 @@ class HybridRideSerializer(serializers.ModelSerializer):
return None
except AttributeError:
return None
@extend_schema_field(serializers.IntegerField(allow_null=True))
def get_coaster_cars_per_train(self, obj):
"""Get roller coaster cars per train."""
@@ -477,7 +478,7 @@ class HybridRideSerializer(serializers.ModelSerializer):
return None
except AttributeError:
return None
@extend_schema_field(serializers.IntegerField(allow_null=True))
def get_coaster_seats_per_car(self, obj):
"""Get roller coaster seats per car."""
@@ -487,14 +488,14 @@ class HybridRideSerializer(serializers.ModelSerializer):
return None
except AttributeError:
return None
@extend_schema_field(serializers.URLField(allow_null=True))
def get_banner_image_url(self, obj):
"""Get banner image URL."""
if obj.banner_image and obj.banner_image.image:
return obj.banner_image.image.url
return None
@extend_schema_field(serializers.URLField(allow_null=True))
def get_card_image_url(self, obj):
"""Get card image URL."""
@@ -513,44 +514,44 @@ class HybridRideSerializer(serializers.ModelSerializer):
"category",
"status",
"post_closing_status",
# Dates and computed fields
"opening_date",
"closing_date",
"status_since",
"opening_year",
# Park fields
"park_name",
"park_slug",
"park_city",
"park_state",
"park_country",
# Park area fields
"park_area_name",
"park_area_slug",
# Company fields
"manufacturer_name",
"manufacturer_slug",
"designer_name",
"designer_slug",
# Ride model fields
"ride_model_name",
"ride_model_slug",
"ride_model_category",
"ride_model_manufacturer_name",
"ride_model_manufacturer_slug",
# Ride specifications
"min_height_in",
"max_height_in",
"capacity_per_hour",
"ride_duration_seconds",
"average_rating",
# Roller coaster stats
"coaster_height_ft",
"coaster_length_ft",
@@ -566,18 +567,18 @@ class HybridRideSerializer(serializers.ModelSerializer):
"coaster_trains_count",
"coaster_cars_per_train",
"coaster_seats_per_car",
# Images
"banner_image_url",
"card_image_url",
# URLs
"url",
"park_url",
# Computed fields for filtering
"search_text",
# Metadata
"created_at",
"updated_at",

View File

@@ -8,23 +8,23 @@ actions (bulk, publish, export, import, recommendations) should be added
to the views module when business logic is available.
"""
from django.urls import path, include
from django.urls import include, path
from rest_framework.routers import DefaultRouter
from .photo_views import RidePhotoViewSet
from .views import (
RideListCreateAPIView,
RideDetailAPIView,
FilterOptionsAPIView,
CompanySearchAPIView,
DesignerListAPIView,
FilterOptionsAPIView,
HybridRideAPIView,
ManufacturerListAPIView,
RideDetailAPIView,
RideFilterMetadataAPIView,
RideImageSettingsAPIView,
RideListCreateAPIView,
RideModelSearchAPIView,
RideSearchSuggestionsAPIView,
RideImageSettingsAPIView,
HybridRideAPIView,
RideFilterMetadataAPIView,
ManufacturerListAPIView,
DesignerListAPIView,
)
from .photo_views import RidePhotoViewSet
# Create router for nested photo endpoints
router = DefaultRouter()
@@ -35,11 +35,11 @@ app_name = "api_v1_rides"
urlpatterns = [
# Core list/create endpoints
path("", RideListCreateAPIView.as_view(), name="ride-list-create"),
# Hybrid filtering endpoints
path("hybrid/", HybridRideAPIView.as_view(), name="ride-hybrid-filtering"),
path("hybrid/filter-metadata/", RideFilterMetadataAPIView.as_view(), name="ride-hybrid-filter-metadata"),
# Filter options
path("filter-options/", FilterOptionsAPIView.as_view(), name="ride-filter-options"),
# Autocomplete / suggestion endpoints

View File

@@ -23,12 +23,13 @@ Caching Strategy:
- RideSearchSuggestionsAPIView.get: 5 minutes (300s) - suggestions should be fresh
"""
import contextlib
import logging
from typing import Any
from django.db import models
from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import extend_schema, extend_schema_view, OpenApiParameter
from drf_spectacular.utils import OpenApiParameter, extend_schema, extend_schema_view
from rest_framework import permissions, status
from rest_framework.exceptions import NotFound
from rest_framework.pagination import PageNumberPagination
@@ -53,9 +54,9 @@ smart_ride_loader = SmartRideLoader()
# Attempt to import model-level helpers; fall back gracefully if not present.
try:
from apps.parks.models import Company, Park
from apps.rides.models import Ride, RideModel
from apps.rides.models.rides import RollerCoasterStats
from apps.parks.models import Park, Company
MODELS_AVAILABLE = True
except Exception:
@@ -370,10 +371,8 @@ class RideListCreateAPIView(APIView):
park_id = params.get("park_id")
if park_id:
try:
with contextlib.suppress(ValueError, TypeError):
qs = qs.filter(park_id=int(park_id))
except (ValueError, TypeError):
pass
return qs
@@ -393,10 +392,8 @@ class RideListCreateAPIView(APIView):
"""Apply manufacturer and designer filtering."""
manufacturer_id = params.get("manufacturer_id")
if manufacturer_id:
try:
with contextlib.suppress(ValueError, TypeError):
qs = qs.filter(manufacturer_id=int(manufacturer_id))
except (ValueError, TypeError):
pass
manufacturer_slug = params.get("manufacturer_slug")
if manufacturer_slug:
@@ -404,10 +401,8 @@ class RideListCreateAPIView(APIView):
designer_id = params.get("designer_id")
if designer_id:
try:
with contextlib.suppress(ValueError, TypeError):
qs = qs.filter(designer_id=int(designer_id))
except (ValueError, TypeError):
pass
designer_slug = params.get("designer_slug")
if designer_slug:
@@ -419,10 +414,8 @@ class RideListCreateAPIView(APIView):
"""Apply ride model filtering."""
ride_model_id = params.get("ride_model_id")
if ride_model_id:
try:
with contextlib.suppress(ValueError, TypeError):
qs = qs.filter(ride_model_id=int(ride_model_id))
except (ValueError, TypeError):
pass
ride_model_slug = params.get("ride_model_slug")
manufacturer_slug_for_model = params.get("manufacturer_slug")
@@ -438,17 +431,13 @@ class RideListCreateAPIView(APIView):
"""Apply rating-based filtering."""
min_rating = params.get("min_rating")
if min_rating:
try:
with contextlib.suppress(ValueError, TypeError):
qs = qs.filter(average_rating__gte=float(min_rating))
except (ValueError, TypeError):
pass
max_rating = params.get("max_rating")
if max_rating:
try:
with contextlib.suppress(ValueError, TypeError):
qs = qs.filter(average_rating__lte=float(max_rating))
except (ValueError, TypeError):
pass
return qs
@@ -456,17 +445,13 @@ class RideListCreateAPIView(APIView):
"""Apply height requirement filtering."""
min_height_req = params.get("min_height_requirement")
if min_height_req:
try:
with contextlib.suppress(ValueError, TypeError):
qs = qs.filter(min_height_in__gte=int(min_height_req))
except (ValueError, TypeError):
pass
max_height_req = params.get("max_height_requirement")
if max_height_req:
try:
with contextlib.suppress(ValueError, TypeError):
qs = qs.filter(max_height_in__lte=int(max_height_req))
except (ValueError, TypeError):
pass
return qs
@@ -474,17 +459,13 @@ class RideListCreateAPIView(APIView):
"""Apply capacity filtering."""
min_capacity = params.get("min_capacity")
if min_capacity:
try:
with contextlib.suppress(ValueError, TypeError):
qs = qs.filter(capacity_per_hour__gte=int(min_capacity))
except (ValueError, TypeError):
pass
max_capacity = params.get("max_capacity")
if max_capacity:
try:
with contextlib.suppress(ValueError, TypeError):
qs = qs.filter(capacity_per_hour__lte=int(max_capacity))
except (ValueError, TypeError):
pass
return qs
@@ -492,24 +473,18 @@ class RideListCreateAPIView(APIView):
"""Apply opening year filtering."""
opening_year = params.get("opening_year")
if opening_year:
try:
with contextlib.suppress(ValueError, TypeError):
qs = qs.filter(opening_date__year=int(opening_year))
except (ValueError, TypeError):
pass
min_opening_year = params.get("min_opening_year")
if min_opening_year:
try:
with contextlib.suppress(ValueError, TypeError):
qs = qs.filter(opening_date__year__gte=int(min_opening_year))
except (ValueError, TypeError):
pass
max_opening_year = params.get("max_opening_year")
if max_opening_year:
try:
with contextlib.suppress(ValueError, TypeError):
qs = qs.filter(opening_date__year__lte=int(max_opening_year))
except (ValueError, TypeError):
pass
return qs
@@ -530,47 +505,35 @@ class RideListCreateAPIView(APIView):
# Height filters
min_height_ft = params.get("min_height_ft")
if min_height_ft:
try:
with contextlib.suppress(ValueError, TypeError):
qs = qs.filter(coaster_stats__height_ft__gte=float(min_height_ft))
except (ValueError, TypeError):
pass
max_height_ft = params.get("max_height_ft")
if max_height_ft:
try:
with contextlib.suppress(ValueError, TypeError):
qs = qs.filter(coaster_stats__height_ft__lte=float(max_height_ft))
except (ValueError, TypeError):
pass
# Speed filters
min_speed_mph = params.get("min_speed_mph")
if min_speed_mph:
try:
with contextlib.suppress(ValueError, TypeError):
qs = qs.filter(coaster_stats__speed_mph__gte=float(min_speed_mph))
except (ValueError, TypeError):
pass
max_speed_mph = params.get("max_speed_mph")
if max_speed_mph:
try:
with contextlib.suppress(ValueError, TypeError):
qs = qs.filter(coaster_stats__speed_mph__lte=float(max_speed_mph))
except (ValueError, TypeError):
pass
# Inversion filters
min_inversions = params.get("min_inversions")
if min_inversions:
try:
with contextlib.suppress(ValueError, TypeError):
qs = qs.filter(coaster_stats__inversions__gte=int(min_inversions))
except (ValueError, TypeError):
pass
max_inversions = params.get("max_inversions")
if max_inversions:
try:
with contextlib.suppress(ValueError, TypeError):
qs = qs.filter(coaster_stats__inversions__lte=int(max_inversions))
except (ValueError, TypeError):
pass
has_inversions = params.get("has_inversions")
if has_inversions is not None:
@@ -2176,10 +2139,8 @@ class HybridRideAPIView(APIView):
value = query_params.get(param)
if value:
if param == "park_id":
try:
with contextlib.suppress(ValueError):
filters[param] = int(value)
except ValueError:
pass
else:
filters[param] = value
@@ -2461,14 +2422,14 @@ class RideFilterMetadataAPIView(APIView):
class BaseCompanyListAPIView(APIView):
permission_classes = [permissions.AllowAny]
role = None
def get(self, request: Request) -> Response:
if not MODELS_AVAILABLE:
return Response(
{"detail": "Models not available"},
status=status.HTTP_501_NOT_IMPLEMENTED
)
companies = (
Company.objects.filter(roles__contains=[self.role])
.annotate(ride_count=Count("manufactured_rides" if self.role == "MANUFACTURER" else "designed_rides"))
@@ -2486,7 +2447,7 @@ class BaseCompanyListAPIView(APIView):
}
for c in companies
]
return Response({
"results": data,
"count": len(data)

View File

@@ -5,88 +5,88 @@ This module provides a unified interface to all serializers across different dom
while maintaining the modular structure for better organization and maintainability.
"""
import importlib
from typing import Any
# --- Companies and ride models domain ---
from .companies import (
CompanyCreateInputSerializer,
CompanyDetailOutputSerializer,
CompanyUpdateInputSerializer,
RideModelCreateInputSerializer,
RideModelDetailOutputSerializer,
RideModelUpdateInputSerializer,
) # noqa: F401
# --- Parks domain ---
from .parks import (
ParkAreaCreateInputSerializer,
ParkAreaDetailOutputSerializer,
ParkAreaUpdateInputSerializer,
ParkCreateInputSerializer,
ParkDetailOutputSerializer,
ParkFilterInputSerializer,
ParkListOutputSerializer,
ParkLocationCreateInputSerializer,
ParkLocationOutputSerializer,
ParkLocationUpdateInputSerializer,
ParkSuggestionOutputSerializer,
ParkSuggestionSerializer,
ParkUpdateInputSerializer,
) # noqa: F401
# --- Rides domain ---
from .rides import (
RideCreateInputSerializer,
RideDetailOutputSerializer,
RideFilterInputSerializer,
RideListOutputSerializer,
RideLocationCreateInputSerializer,
RideLocationOutputSerializer,
RideLocationUpdateInputSerializer,
RideModelOutputSerializer,
RideParkOutputSerializer,
RideReviewCreateInputSerializer,
RideReviewOutputSerializer,
RideReviewUpdateInputSerializer,
RideUpdateInputSerializer,
RollerCoasterStatsCreateInputSerializer,
RollerCoasterStatsOutputSerializer,
RollerCoasterStatsUpdateInputSerializer,
) # noqa: F401
from .services import (
HealthCheckOutputSerializer,
PerformanceMetricsOutputSerializer,
SimpleHealthOutputSerializer,
EmailSendInputSerializer,
EmailTemplateOutputSerializer,
MapDataOutputSerializer,
CoordinateInputSerializer,
HistoryEventSerializer,
HistoryEntryOutputSerializer,
HistoryCreateInputSerializer,
ModerationSubmissionSerializer,
ModerationSubmissionOutputSerializer,
RoadtripParkSerializer,
RoadtripCreateInputSerializer,
RoadtripOutputSerializer,
GeocodeInputSerializer,
GeocodeOutputSerializer,
DistanceCalculationInputSerializer,
DistanceCalculationOutputSerializer,
EmailSendInputSerializer,
EmailTemplateOutputSerializer,
GeocodeInputSerializer,
GeocodeOutputSerializer,
HealthCheckOutputSerializer,
HistoryCreateInputSerializer,
HistoryEntryOutputSerializer,
HistoryEventSerializer,
MapDataOutputSerializer,
ModerationSubmissionOutputSerializer,
ModerationSubmissionSerializer,
PerformanceMetricsOutputSerializer,
RoadtripCreateInputSerializer,
RoadtripOutputSerializer,
RoadtripParkSerializer,
SimpleHealthOutputSerializer,
) # noqa: F401
from typing import Any, Dict, List
import importlib
# --- Shared utilities and base classes ---
from .shared import (
FilterOptionSerializer,
FilterRangeSerializer,
StandardizedFilterMetadataSerializer,
validate_filter_metadata_contract,
ensure_filter_option_format,
) # noqa: F401
# --- Parks domain ---
from .parks import (
ParkListOutputSerializer,
ParkDetailOutputSerializer,
ParkCreateInputSerializer,
ParkUpdateInputSerializer,
ParkFilterInputSerializer,
ParkAreaDetailOutputSerializer,
ParkAreaCreateInputSerializer,
ParkAreaUpdateInputSerializer,
ParkLocationOutputSerializer,
ParkLocationCreateInputSerializer,
ParkLocationUpdateInputSerializer,
ParkSuggestionSerializer,
ParkSuggestionOutputSerializer,
) # noqa: F401
# --- Companies and ride models domain ---
from .companies import (
CompanyDetailOutputSerializer,
CompanyCreateInputSerializer,
CompanyUpdateInputSerializer,
RideModelDetailOutputSerializer,
RideModelCreateInputSerializer,
RideModelUpdateInputSerializer,
) # noqa: F401
# --- Rides domain ---
from .rides import (
RideParkOutputSerializer,
RideModelOutputSerializer,
RideListOutputSerializer,
RideDetailOutputSerializer,
RideCreateInputSerializer,
RideUpdateInputSerializer,
RideFilterInputSerializer,
RollerCoasterStatsOutputSerializer,
RollerCoasterStatsCreateInputSerializer,
RollerCoasterStatsUpdateInputSerializer,
RideLocationOutputSerializer,
RideLocationCreateInputSerializer,
RideLocationUpdateInputSerializer,
RideReviewOutputSerializer,
RideReviewCreateInputSerializer,
RideReviewUpdateInputSerializer,
validate_filter_metadata_contract,
) # noqa: F401
# --- Accounts domain: try multiple likely locations, fall back to placeholders ---
_ACCOUNTS_SYMBOLS: List[str] = [
_ACCOUNTS_SYMBOLS: list[str] = [
"UserProfileOutputSerializer",
"UserProfileCreateInputSerializer",
"UserProfileUpdateInputSerializer",
@@ -106,7 +106,7 @@ _ACCOUNTS_SYMBOLS: List[str] = [
]
def _import_accounts_symbols() -> Dict[str, Any]:
def _import_accounts_symbols() -> dict[str, Any]:
"""
Try a list of candidate module paths and return a dict mapping expected symbol
names to the objects found. If no candidate provides a symbol, the symbol maps to None.
@@ -119,7 +119,7 @@ def _import_accounts_symbols() -> Dict[str, Any]:
]
# Prepare default placeholders
result: Dict[str, Any] = {name: None for name in _ACCOUNTS_SYMBOLS}
result: dict[str, Any] = dict.fromkeys(_ACCOUNTS_SYMBOLS)
for modname in candidates:
try:

View File

@@ -5,21 +5,22 @@ This module contains all serializers related to user account management,
profile settings, preferences, privacy, notifications, and security.
"""
from rest_framework import serializers
from django.contrib.auth import get_user_model
from drf_spectacular.utils import (
extend_schema_serializer,
OpenApiExample,
extend_schema_serializer,
)
from rest_framework import serializers
from apps.accounts.models import (
User,
UserProfile,
UserNotification,
NotificationPreference,
User,
UserNotification,
UserProfile,
)
from apps.core.choices.serializers import RichChoiceFieldSerializer
from apps.lists.models import UserList
from apps.rides.models.credits import RideCredit
from apps.core.choices.serializers import RichChoiceFieldSerializer
UserModel = get_user_model()
@@ -187,7 +188,7 @@ class PublicUserSerializer(serializers.ModelSerializer):
Only exposes public information.
"""
profile = UserProfileSerializer(read_only=True)
class Meta:
model = User
fields = [
@@ -906,9 +907,10 @@ class AvatarUploadSerializer(serializers.Serializer):
# Try to validate with PIL
try:
from PIL import Image
import io
from PIL import Image
value.seek(0)
image_data = value.read()
value.seek(0) # Reset for later use

View File

@@ -5,14 +5,14 @@ This module contains all serializers related to user authentication,
registration, password management, and social authentication.
"""
from rest_framework import serializers
from django.contrib.auth import get_user_model, authenticate
from django.contrib.auth import authenticate, get_user_model
from django.contrib.auth.password_validation import validate_password
from django.core.exceptions import ValidationError as DjangoValidationError
from drf_spectacular.utils import (
extend_schema_serializer,
OpenApiExample,
extend_schema_serializer,
)
from rest_framework import serializers
UserModel = get_user_model()

View File

@@ -5,16 +5,16 @@ This module contains all serializers related to companies that operate parks
or manufacture rides, as well as ride model serializers.
"""
from rest_framework import serializers
from drf_spectacular.utils import (
extend_schema_serializer,
extend_schema_field,
OpenApiExample,
extend_schema_field,
extend_schema_serializer,
)
from rest_framework import serializers
from .shared import ModelChoices
from apps.core.choices.serializers import RichChoiceFieldSerializer
from .shared import ModelChoices
# === COMPANY SERIALIZERS ===

View File

@@ -5,8 +5,8 @@ This module contains serializers for history tracking and timeline functionality
using django-pghistory.
"""
from rest_framework import serializers
from drf_spectacular.utils import extend_schema_field
from rest_framework import serializers
class ParkHistoryEventSerializer(serializers.Serializer):

View File

@@ -5,13 +5,12 @@ This module contains all serializers related to map functionality,
including location data, search results, and clustering.
"""
from rest_framework import serializers
from drf_spectacular.utils import (
extend_schema_serializer,
extend_schema_field,
OpenApiExample,
extend_schema_field,
extend_schema_serializer,
)
from rest_framework import serializers
# === MAP LOCATION SERIALIZERS ===

View File

@@ -5,13 +5,12 @@ This module contains serializers for photo uploads, media management,
and related media functionality.
"""
from rest_framework import serializers
from drf_spectacular.utils import (
extend_schema_serializer,
extend_schema_field,
OpenApiExample,
extend_schema_field,
extend_schema_serializer,
)
from rest_framework import serializers
# === MEDIA SERIALIZERS ===

View File

@@ -5,13 +5,12 @@ This module contains serializers for statistics, health checks, and other
miscellaneous functionality.
"""
from rest_framework import serializers
from drf_spectacular.utils import (
extend_schema_field,
)
from .shared import ModelChoices
from apps.core.choices.serializers import RichChoiceFieldSerializer
from rest_framework import serializers
from apps.core.choices.serializers import RichChoiceFieldSerializer
# === STATISTICS SERIALIZERS ===

View File

@@ -4,10 +4,12 @@ Serializers for park review API endpoints.
This module contains serializers for park review CRUD operations.
"""
from drf_spectacular.utils import OpenApiExample, extend_schema_field, extend_schema_serializer
from rest_framework import serializers
from drf_spectacular.utils import extend_schema_field, extend_schema_serializer, OpenApiExample
from apps.parks.models.reviews import ParkReview
from apps.api.v1.serializers.reviews import ReviewUserSerializer
from apps.parks.models.reviews import ParkReview
@extend_schema_serializer(
examples=[

View File

@@ -5,18 +5,18 @@ This module contains all serializers related to parks, park areas, park location
and park search functionality.
"""
from rest_framework import serializers
from drf_spectacular.utils import (
extend_schema_serializer,
extend_schema_field,
OpenApiExample,
extend_schema_field,
extend_schema_serializer,
)
from rest_framework import serializers
from apps.core.choices.serializers import RichChoiceFieldSerializer
from apps.core.services.media_url_service import MediaURLService
from config.django import base as settings
from .shared import LocationOutputSerializer, CompanyOutputSerializer, ModelChoices
from apps.core.services.media_url_service import MediaURLService
from apps.core.choices.serializers import RichChoiceFieldSerializer
from .shared import CompanyOutputSerializer, LocationOutputSerializer, ModelChoices
# === PARK SERIALIZERS ===

View File

@@ -5,6 +5,7 @@ This module contains serializers for park-specific media functionality.
"""
from rest_framework import serializers
from apps.parks.models import ParkPhoto

View File

@@ -3,9 +3,10 @@ Serializers for review-related API endpoints.
"""
from rest_framework import serializers
from apps.accounts.models import User
from apps.parks.models.reviews import ParkReview
from apps.rides.models.reviews import RideReview
from apps.accounts.models import User
class ReviewUserSerializer(serializers.ModelSerializer):

View File

@@ -1,17 +1,18 @@
from rest_framework import serializers
from drf_spectacular.utils import extend_schema_field
from apps.rides.models.credits import RideCredit
from apps.rides.models import Ride
from apps.api.v1.serializers.rides import RideListOutputSerializer
from apps.rides.models import Ride
from apps.rides.models.credits import RideCredit
class RideCreditSerializer(serializers.ModelSerializer):
"""Serializer for user ride credits."""
ride_id = serializers.PrimaryKeyRelatedField(
queryset=Ride.objects.all(), source='ride', write_only=True
)
ride = RideListOutputSerializer(read_only=True)
class Meta:
model = RideCredit
fields = [
@@ -23,6 +24,7 @@ class RideCreditSerializer(serializers.ModelSerializer):
'first_ridden_at',
'last_ridden_at',
'notes',
'display_order',
'created_at',
'updated_at',
]
@@ -37,7 +39,7 @@ class RideCreditSerializer(serializers.ModelSerializer):
last = attrs.get('last_ridden_at')
if first and last and last < first:
raise serializers.ValidationError("Last ridden date cannot be before first ridden date.")
return attrs
def create(self, validated_data):

View File

@@ -5,16 +5,17 @@ This module contains all serializers related to ride models, variants,
technical specifications, and related functionality.
"""
from rest_framework import serializers
from drf_spectacular.utils import (
extend_schema_serializer,
extend_schema_field,
OpenApiExample,
extend_schema_field,
extend_schema_serializer,
)
from rest_framework import serializers
from apps.core.choices.serializers import RichChoiceFieldSerializer
from config.django import base as settings
from .shared import ModelChoices
from apps.core.choices.serializers import RichChoiceFieldSerializer
# Use dynamic imports to avoid circular import issues
@@ -23,9 +24,9 @@ def get_ride_model_classes():
"""Get ride model classes dynamically to avoid import issues."""
from apps.rides.models import (
RideModel,
RideModelVariant,
RideModelPhoto,
RideModelTechnicalSpec,
RideModelVariant,
)
return RideModel, RideModelVariant, RideModelPhoto, RideModelTechnicalSpec

View File

@@ -4,11 +4,11 @@ Serializers for ride review API endpoints.
This module contains serializers for ride review CRUD operations with Rich Choice Objects compliance.
"""
from drf_spectacular.utils import OpenApiExample, extend_schema_field, extend_schema_serializer
from rest_framework import serializers
from drf_spectacular.utils import extend_schema_field, extend_schema_serializer, OpenApiExample
from apps.rides.models.reviews import RideReview
from apps.accounts.models import User
from apps.core.choices.serializers import RichChoiceSerializer
from apps.rides.models.reviews import RideReview
class ReviewUserSerializer(serializers.ModelSerializer):
@@ -74,7 +74,7 @@ class RideReviewOutputSerializer(serializers.ModelSerializer):
"""Output serializer for ride reviews."""
user = ReviewUserSerializer(read_only=True)
# Ride information
ride = serializers.SerializerMethodField()
park = serializers.SerializerMethodField()

View File

@@ -5,16 +5,17 @@ This module contains all serializers related to rides, roller coaster statistics
ride locations, and ride reviews.
"""
from rest_framework import serializers
from drf_spectacular.utils import (
extend_schema_serializer,
extend_schema_field,
OpenApiExample,
extend_schema_field,
extend_schema_serializer,
)
from config.django import base as settings
from .shared import ModelChoices
from apps.core.choices.serializers import RichChoiceFieldSerializer
from rest_framework import serializers
from apps.core.choices.serializers import RichChoiceFieldSerializer
from config.django import base as settings
from .shared import ModelChoices
# === RIDE SERIALIZERS ===

View File

@@ -5,6 +5,7 @@ This module contains serializers for ride-specific media functionality.
"""
from rest_framework import serializers
from apps.rides.models import RidePhoto

View File

@@ -6,9 +6,10 @@ and other search functionality.
"""
from rest_framework import serializers
from ..shared import ModelChoices
from apps.core.choices.serializers import RichChoiceFieldSerializer
from ..shared import ModelChoices
# === CORE ENTITY SEARCH SERIALIZERS ===

View File

@@ -5,11 +5,10 @@ This module contains serializers for various services like email, maps,
history tracking, moderation, and roadtrip planning.
"""
from rest_framework import serializers
from drf_spectacular.utils import (
extend_schema_field,
)
from rest_framework import serializers
# === HEALTH CHECK SERIALIZERS ===

View File

@@ -8,14 +8,15 @@ These serializers prevent contract violations by providing a single source of tr
for common data structures used throughout the API.
"""
from typing import Any
from rest_framework import serializers
from typing import Dict, Any, List
class FilterOptionSerializer(serializers.Serializer):
"""
Standard filter option format - matches frontend TypeScript exactly.
Frontend TypeScript interface:
interface FilterOption {
value: string;
@@ -31,7 +32,7 @@ class FilterOptionSerializer(serializers.Serializer):
help_text="Human-readable display label"
)
count = serializers.IntegerField(
required=False,
required=False,
allow_null=True,
help_text="Number of items matching this filter option"
)
@@ -44,7 +45,7 @@ class FilterOptionSerializer(serializers.Serializer):
class FilterRangeSerializer(serializers.Serializer):
"""
Standard range filter format.
Frontend TypeScript interface:
interface FilterRange {
min: number;
@@ -66,7 +67,7 @@ class FilterRangeSerializer(serializers.Serializer):
help_text="Step size for range inputs"
)
unit = serializers.CharField(
required=False,
required=False,
allow_null=True,
help_text="Unit of measurement (e.g., 'feet', 'mph', 'stars')"
)
@@ -75,7 +76,7 @@ class FilterRangeSerializer(serializers.Serializer):
class BooleanFilterSerializer(serializers.Serializer):
"""
Standard boolean filter format.
Frontend TypeScript interface:
interface BooleanFilter {
key: string;
@@ -97,7 +98,7 @@ class BooleanFilterSerializer(serializers.Serializer):
class OrderingOptionSerializer(serializers.Serializer):
"""
Standard ordering option format.
Frontend TypeScript interface:
interface OrderingOption {
value: string;
@@ -115,7 +116,7 @@ class OrderingOptionSerializer(serializers.Serializer):
class StandardizedFilterMetadataSerializer(serializers.Serializer):
"""
Matches frontend TypeScript interface exactly.
This serializer ensures all filter metadata responses follow the same structure
that the frontend expects, preventing runtime type errors.
"""
@@ -131,7 +132,7 @@ class StandardizedFilterMetadataSerializer(serializers.Serializer):
help_text="Total number of items in the filtered dataset"
)
ordering_options = FilterOptionSerializer(
many=True,
many=True,
required=False,
help_text="Available ordering options"
)
@@ -145,7 +146,7 @@ class StandardizedFilterMetadataSerializer(serializers.Serializer):
class PaginationMetadataSerializer(serializers.Serializer):
"""
Standard pagination metadata format.
Frontend TypeScript interface:
interface PaginationMetadata {
count: number;
@@ -183,7 +184,7 @@ class PaginationMetadataSerializer(serializers.Serializer):
class ApiResponseSerializer(serializers.Serializer):
"""
Standard API response wrapper.
Frontend TypeScript interface:
interface ApiResponse<T> {
success: boolean;
@@ -214,7 +215,7 @@ class ApiResponseSerializer(serializers.Serializer):
class ErrorResponseSerializer(serializers.Serializer):
"""
Standard error response format.
Frontend TypeScript interface:
interface ApiError {
status: "error";
@@ -245,7 +246,7 @@ class ErrorResponseSerializer(serializers.Serializer):
class LocationSerializer(serializers.Serializer):
"""
Standard location format.
Frontend TypeScript interface:
interface Location {
city: string;
@@ -291,7 +292,7 @@ LocationOutputSerializer = LocationSerializer
class CompanyOutputSerializer(serializers.Serializer):
"""
Standard company output format.
Frontend TypeScript interface:
interface Company {
id: number;
@@ -322,24 +323,24 @@ class ModelChoices:
"""
Utility class to provide model choices for serializers using Rich Choice Objects.
This prevents circular imports while providing access to model choices from the registry.
NO FALLBACKS - All choices must be properly defined in Rich Choice Objects.
"""
@staticmethod
def get_park_status_choices():
"""Get park status choices from Rich Choice registry."""
from apps.core.choices.registry import get_choices
choices = get_choices("statuses", "parks")
return [(choice.value, choice.label) for choice in choices]
@staticmethod
def get_ride_status_choices():
"""Get ride status choices from Rich Choice registry."""
from apps.core.choices.registry import get_choices
choices = get_choices("statuses", "rides")
return [(choice.value, choice.label) for choice in choices]
@staticmethod
def get_company_role_choices():
"""Get company role choices from Rich Choice registry."""
@@ -350,91 +351,91 @@ class ModelChoices:
parks_choices = get_choices("company_roles", "parks")
all_choices = list(rides_choices) + list(parks_choices)
return [(choice.value, choice.label) for choice in all_choices]
@staticmethod
def get_ride_category_choices():
"""Get ride category choices from Rich Choice registry."""
from apps.core.choices.registry import get_choices
choices = get_choices("categories", "rides")
return [(choice.value, choice.label) for choice in choices]
@staticmethod
def get_ride_post_closing_choices():
"""Get ride post-closing status choices from Rich Choice registry."""
from apps.core.choices.registry import get_choices
choices = get_choices("post_closing_statuses", "rides")
return [(choice.value, choice.label) for choice in choices]
@staticmethod
def get_coaster_track_choices():
"""Get coaster track material choices from Rich Choice registry."""
from apps.core.choices.registry import get_choices
choices = get_choices("track_materials", "rides")
return [(choice.value, choice.label) for choice in choices]
@staticmethod
def get_coaster_type_choices():
"""Get coaster type choices from Rich Choice registry."""
from apps.core.choices.registry import get_choices
choices = get_choices("coaster_types", "rides")
return [(choice.value, choice.label) for choice in choices]
@staticmethod
def get_launch_choices():
"""Get launch system choices from Rich Choice registry (legacy method)."""
from apps.core.choices.registry import get_choices
choices = get_choices("propulsion_systems", "rides")
return [(choice.value, choice.label) for choice in choices]
@staticmethod
def get_propulsion_system_choices():
"""Get propulsion system choices from Rich Choice registry."""
from apps.core.choices.registry import get_choices
choices = get_choices("propulsion_systems", "rides")
return [(choice.value, choice.label) for choice in choices]
@staticmethod
def get_photo_type_choices():
"""Get photo type choices from Rich Choice registry."""
from apps.core.choices.registry import get_choices
choices = get_choices("photo_types", "rides")
return [(choice.value, choice.label) for choice in choices]
@staticmethod
def get_spec_category_choices():
"""Get technical specification category choices from Rich Choice registry."""
from apps.core.choices.registry import get_choices
choices = get_choices("spec_categories", "rides")
return [(choice.value, choice.label) for choice in choices]
@staticmethod
def get_technical_spec_category_choices():
"""Get technical specification category choices from Rich Choice registry."""
from apps.core.choices.registry import get_choices
choices = get_choices("spec_categories", "rides")
return [(choice.value, choice.label) for choice in choices]
@staticmethod
def get_target_market_choices():
"""Get target market choices from Rich Choice registry."""
from apps.core.choices.registry import get_choices
choices = get_choices("target_markets", "rides")
return [(choice.value, choice.label) for choice in choices]
@staticmethod
def get_entity_type_choices():
"""Get entity type choices for search functionality."""
from apps.core.choices.registry import get_choices
choices = get_choices("entity_types", "core")
return [(choice.value, choice.label) for choice in choices]
@staticmethod
def get_health_status_choices():
"""Get health check status choices from Rich Choice registry."""
from apps.core.choices.registry import get_choices
choices = get_choices("health_statuses", "core")
return [(choice.value, choice.label) for choice in choices]
@staticmethod
def get_simple_health_status_choices():
"""Get simple health check status choices from Rich Choice registry."""
@@ -446,7 +447,7 @@ class ModelChoices:
class EntityReferenceSerializer(serializers.Serializer):
"""
Standard entity reference format.
Frontend TypeScript interface:
interface Entity {
id: number;
@@ -468,7 +469,7 @@ class EntityReferenceSerializer(serializers.Serializer):
class ImageVariantsSerializer(serializers.Serializer):
"""
Standard image variants format.
Frontend TypeScript interface:
interface ImageVariants {
thumbnail: string;
@@ -495,7 +496,7 @@ class ImageVariantsSerializer(serializers.Serializer):
class PhotoSerializer(serializers.Serializer):
"""
Standard photo format.
Frontend TypeScript interface:
interface Photo {
id: number;
@@ -546,7 +547,7 @@ class PhotoSerializer(serializers.Serializer):
class UserInfoSerializer(serializers.Serializer):
"""
Standard user info format.
Frontend TypeScript interface:
interface UserInfo {
id: number;
@@ -571,19 +572,19 @@ class UserInfoSerializer(serializers.Serializer):
)
def validate_filter_metadata_contract(data: Dict[str, Any]) -> Dict[str, Any]:
def validate_filter_metadata_contract(data: dict[str, Any]) -> dict[str, Any]:
"""
Validate that filter metadata follows the expected contract.
This function can be used in views to ensure filter metadata
matches the frontend TypeScript interface before returning it.
Args:
data: Filter metadata dictionary
Returns:
Validated and potentially transformed data
Raises:
serializers.ValidationError: If data doesn't match contract
"""
@@ -593,21 +594,21 @@ def validate_filter_metadata_contract(data: Dict[str, Any]) -> Dict[str, Any]:
return serializer.validated_data
def ensure_filter_option_format(options: List[Any]) -> List[Dict[str, Any]]:
def ensure_filter_option_format(options: list[Any]) -> list[dict[str, Any]]:
"""
Ensure a list of filter options follows the expected format.
This utility function converts various input formats to the standard
FilterOption format expected by the frontend.
Args:
options: List of options in various formats
Returns:
List of options in standard format
"""
standardized = []
for option in options:
if isinstance(option, dict):
# Already in correct format or close to it
@@ -633,19 +634,19 @@ def ensure_filter_option_format(options: List[Any]) -> List[Dict[str, Any]]:
'count': None,
'selected': False
}
standardized.append(standardized_option)
return standardized
def ensure_range_format(range_data: Dict[str, Any]) -> Dict[str, Any]:
def ensure_range_format(range_data: dict[str, Any]) -> dict[str, Any]:
"""
Ensure range data follows the expected format.
Args:
range_data: Range data dictionary
Returns:
Range data in standard format
"""

View File

@@ -2,13 +2,14 @@
API serializers for the ride ranking system.
"""
from rest_framework import serializers
from drf_spectacular.utils import (
extend_schema_serializer,
extend_schema_field,
OpenApiExample,
extend_schema_field,
extend_schema_serializer,
)
from apps.rides.models import RideRanking, RankingSnapshot
from rest_framework import serializers
from apps.rides.models import RankingSnapshot, RideRanking
@extend_schema_serializer(
@@ -179,6 +180,7 @@ class RideRankingDetailSerializer(serializers.ModelSerializer):
def get_head_to_head_comparisons(self, obj):
"""Get top head-to-head comparisons."""
from django.db.models import Q
from apps.rides.models import RidePairComparison
comparisons = (

View File

@@ -5,17 +5,20 @@ This module contains signal handlers that invalidate the stats cache
whenever relevant entities are created, updated, or deleted.
"""
from django.db.models.signals import post_save, post_delete
from django.dispatch import receiver
from django.core.cache import cache
from django.db.models.signals import post_delete, post_save
from django.dispatch import receiver
from apps.parks.models import Park, ParkReview, ParkPhoto, Company as ParkCompany
from apps.parks.models import Company as ParkCompany
from apps.parks.models import Park, ParkPhoto, ParkReview
from apps.rides.models import (
Company as RideCompany,
)
from apps.rides.models import (
Ride,
RollerCoasterStats,
RideReview,
RidePhoto,
Company as RideCompany,
RideReview,
RollerCoasterStats,
)

View File

@@ -5,120 +5,120 @@ These tests verify that API responses match frontend TypeScript interfaces exact
preventing runtime errors and ensuring type safety.
"""
from django.test import TestCase, Client
from django.test import Client, TestCase
from rest_framework.test import APITestCase
from apps.api.v1.serializers.shared import (
ensure_filter_option_format,
ensure_range_format,
validate_filter_metadata_contract,
)
from apps.parks.services.hybrid_loader import smart_park_loader
from apps.rides.services.hybrid_loader import SmartRideLoader
from apps.api.v1.serializers.shared import (
validate_filter_metadata_contract,
ensure_filter_option_format,
ensure_range_format
)
class FilterMetadataContractTests(TestCase):
"""Test that filter metadata follows the expected contract."""
def setUp(self):
self.client = Client()
def test_parks_filter_metadata_structure(self):
"""Test that parks filter metadata has correct structure."""
# Get filter metadata from the service
metadata = smart_park_loader.get_filter_metadata()
# Should have required top-level keys
self.assertIn('categorical', metadata)
self.assertIn('ranges', metadata)
self.assertIn('total_count', metadata)
# Categorical filters should be objects with value/label/count
categorical = metadata['categorical']
self.assertIsInstance(categorical, dict)
for filter_name, filter_options in categorical.items():
with self.subTest(filter_name=filter_name):
self.assertIsInstance(filter_options, list,
self.assertIsInstance(filter_options, list,
f"Filter '{filter_name}' should be a list")
for i, option in enumerate(filter_options):
with self.subTest(filter_name=filter_name, option_index=i):
self.assertIsInstance(option, dict,
f"Filter '{filter_name}' option {i} should be an object, not {type(option).__name__}")
# Check required properties
self.assertIn('value', option,
f"Filter '{filter_name}' option {i} missing 'value' property")
self.assertIn('label', option,
f"Filter '{filter_name}' option {i} missing 'label' property")
# Check types
self.assertIsInstance(option['value'], str,
f"Filter '{filter_name}' option {i} 'value' should be string")
self.assertIsInstance(option['label'], str,
f"Filter '{filter_name}' option {i} 'label' should be string")
# Count is optional but should be int if present
if 'count' in option and option['count'] is not None:
self.assertIsInstance(option['count'], int,
f"Filter '{filter_name}' option {i} 'count' should be int")
def test_rides_filter_metadata_structure(self):
"""Test that rides filter metadata has correct structure."""
loader = SmartRideLoader()
metadata = loader.get_filter_metadata()
# Should have required top-level keys
self.assertIn('categorical', metadata)
self.assertIn('ranges', metadata)
self.assertIn('total_count', metadata)
# Categorical filters should be objects with value/label/count
categorical = metadata['categorical']
self.assertIsInstance(categorical, dict)
# Test specific categorical filters that were problematic
critical_filters = ['categories', 'statuses', 'roller_coaster_types', 'track_materials']
for filter_name in critical_filters:
if filter_name in categorical:
with self.subTest(filter_name=filter_name):
filter_options = categorical[filter_name]
self.assertIsInstance(filter_options, list)
for i, option in enumerate(filter_options):
with self.subTest(filter_name=filter_name, option_index=i):
self.assertIsInstance(option, dict,
f"CRITICAL: Filter '{filter_name}' option {i} is {type(option).__name__} but should be dict")
self.assertIn('value', option)
self.assertIn('label', option)
self.assertIn('count', option)
def test_range_metadata_structure(self):
"""Test that range metadata has correct structure."""
# Test parks ranges
parks_metadata = smart_park_loader.get_filter_metadata()
ranges = parks_metadata['ranges']
for range_name, range_data in ranges.items():
with self.subTest(range_name=range_name):
self.assertIsInstance(range_data, dict,
f"Range '{range_name}' should be an object")
# Check required properties
self.assertIn('min', range_data)
self.assertIn('max', range_data)
self.assertIn('step', range_data)
self.assertIn('unit', range_data)
# Check types (min/max can be None)
if range_data['min'] is not None:
self.assertIsInstance(range_data['min'], (int, float))
if range_data['max'] is not None:
self.assertIsInstance(range_data['max'], (int, float))
self.assertIsInstance(range_data['step'], (int, float))
# Unit can be None or string
if range_data['unit'] is not None:
@@ -127,7 +127,7 @@ class FilterMetadataContractTests(TestCase):
class ContractValidationUtilityTests(TestCase):
"""Test contract validation utility functions."""
def test_validate_filter_metadata_contract_valid(self):
"""Test validation passes for valid filter metadata."""
valid_metadata = {
@@ -147,16 +147,16 @@ class ContractValidationUtilityTests(TestCase):
},
'total_count': 100
}
# Should not raise an exception
validated = validate_filter_metadata_contract(valid_metadata)
self.assertIsInstance(validated, dict)
self.assertEqual(validated['total_count'], 100)
def test_validate_filter_metadata_contract_invalid(self):
"""Test validation fails for invalid filter metadata."""
from rest_framework import serializers
invalid_metadata = {
'categorical': {
'statuses': ['OPERATING', 'CLOSED_TEMP'] # Should be objects, not strings
@@ -164,17 +164,17 @@ class ContractValidationUtilityTests(TestCase):
'ranges': {},
'total_count': 100
}
# Should raise ValidationError
with self.assertRaises(serializers.ValidationError):
validate_filter_metadata_contract(invalid_metadata)
def test_ensure_filter_option_format_strings(self):
"""Test converting string arrays to proper format."""
string_options = ['OPERATING', 'CLOSED_TEMP', 'UNDER_CONSTRUCTION']
formatted = ensure_filter_option_format(string_options)
self.assertEqual(len(formatted), 3)
for i, option in enumerate(formatted):
self.assertIsInstance(option, dict)
@@ -182,44 +182,44 @@ class ContractValidationUtilityTests(TestCase):
self.assertIn('label', option)
self.assertIn('count', option)
self.assertIn('selected', option)
self.assertEqual(option['value'], string_options[i])
self.assertEqual(option['label'], string_options[i])
self.assertIsNone(option['count'])
self.assertFalse(option['selected'])
def test_ensure_filter_option_format_tuples(self):
"""Test converting tuple arrays to proper format."""
tuple_options = [
('OPERATING', 'Operating', 5),
('CLOSED_TEMP', 'Temporarily Closed', 2)
]
formatted = ensure_filter_option_format(tuple_options)
self.assertEqual(len(formatted), 2)
self.assertEqual(formatted[0]['value'], 'OPERATING')
self.assertEqual(formatted[0]['label'], 'Operating')
self.assertEqual(formatted[0]['count'], 5)
self.assertEqual(formatted[1]['value'], 'CLOSED_TEMP')
self.assertEqual(formatted[1]['label'], 'Temporarily Closed')
self.assertEqual(formatted[1]['count'], 2)
def test_ensure_filter_option_format_dicts(self):
"""Test that properly formatted dicts pass through correctly."""
dict_options = [
{'value': 'OPERATING', 'label': 'Operating', 'count': 5},
{'value': 'CLOSED_TEMP', 'label': 'Temporarily Closed', 'count': 2}
]
formatted = ensure_filter_option_format(dict_options)
self.assertEqual(len(formatted), 2)
self.assertEqual(formatted[0]['value'], 'OPERATING')
self.assertEqual(formatted[0]['label'], 'Operating')
self.assertEqual(formatted[0]['count'], 5)
def test_ensure_range_format(self):
"""Test range format utility."""
range_data = {
@@ -228,36 +228,36 @@ class ContractValidationUtilityTests(TestCase):
'step': 0.5,
'unit': 'stars'
}
formatted = ensure_range_format(range_data)
self.assertEqual(formatted['min'], 1.0)
self.assertEqual(formatted['max'], 10.0)
self.assertEqual(formatted['step'], 0.5)
self.assertEqual(formatted['unit'], 'stars')
def test_ensure_range_format_missing_step(self):
"""Test range format with missing step defaults to 1.0."""
range_data = {
'min': 1,
'max': 10
}
formatted = ensure_range_format(range_data)
self.assertEqual(formatted['step'], 1.0)
self.assertIsNone(formatted['unit'])
class APIEndpointContractTests(APITestCase):
"""Test actual API endpoints for contract compliance."""
def test_parks_hybrid_endpoint_contract(self):
"""Test parks hybrid endpoint returns proper contract."""
# This would require actual data in the database
# For now, we'll test the structure
pass
def test_rides_hybrid_endpoint_contract(self):
"""Test rides hybrid endpoint returns proper contract."""
# This would require actual data in the database
@@ -267,7 +267,7 @@ class APIEndpointContractTests(APITestCase):
class TypeScriptInterfaceComplianceTests(TestCase):
"""Test that responses match TypeScript interfaces exactly."""
def test_filter_option_interface_compliance(self):
"""Test FilterOption interface compliance."""
# TypeScript interface:
@@ -277,28 +277,28 @@ class TypeScriptInterfaceComplianceTests(TestCase):
# count?: number;
# selected?: boolean;
# }
option = {
'value': 'OPERATING',
'label': 'Operating',
'count': 5,
'selected': False
}
# All required fields present
self.assertIn('value', option)
self.assertIn('label', option)
# Correct types
self.assertIsInstance(option['value'], str)
self.assertIsInstance(option['label'], str)
# Optional fields have correct types if present
if 'count' in option and option['count'] is not None:
self.assertIsInstance(option['count'], int)
if 'selected' in option:
self.assertIsInstance(option['selected'], bool)
def test_filter_range_interface_compliance(self):
"""Test FilterRange interface compliance."""
# TypeScript interface:
@@ -308,27 +308,27 @@ class TypeScriptInterfaceComplianceTests(TestCase):
# step: number;
# unit?: string;
# }
range_data = {
'min': 1.0,
'max': 10.0,
'step': 0.1,
'unit': 'stars'
}
# All required fields present
self.assertIn('min', range_data)
self.assertIn('max', range_data)
self.assertIn('step', range_data)
# Correct types (min/max can be null)
if range_data['min'] is not None:
self.assertIsInstance(range_data['min'], (int, float))
if range_data['max'] is not None:
self.assertIsInstance(range_data['max'], (int, float))
self.assertIsInstance(range_data['step'], (int, float))
# Optional unit field
if 'unit' in range_data and range_data['unit'] is not None:
self.assertIsInstance(range_data['unit'], str)
@@ -336,72 +336,72 @@ class TypeScriptInterfaceComplianceTests(TestCase):
class RegressionTests(TestCase):
"""Regression tests for specific contract violations that were fixed."""
def test_categorical_filters_not_strings(self):
"""Regression test: Ensure categorical filters are never returned as strings."""
# This was the main issue - categorical filters were returned as:
# ['OPERATING', 'CLOSED_TEMP'] instead of
# ['OPERATING', 'CLOSED_TEMP'] instead of
# [{'value': 'OPERATING', 'label': 'Operating', 'count': 5}, ...]
# Test parks
parks_metadata = smart_park_loader.get_filter_metadata()
categorical = parks_metadata.get('categorical', {})
for filter_name, filter_options in categorical.items():
with self.subTest(filter_name=filter_name):
self.assertIsInstance(filter_options, list)
for i, option in enumerate(filter_options):
with self.subTest(filter_name=filter_name, option_index=i):
self.assertIsInstance(option, dict,
f"REGRESSION: Filter '{filter_name}' option {i} is a {type(option).__name__} "
f"but should be a dict. This causes frontend crashes!")
# Must not be a string
self.assertNotIsInstance(option, str,
f"CRITICAL REGRESSION: Filter '{filter_name}' option {i} is a string '{option}' "
f"but frontend expects object with value/label/count properties!")
# Test rides
rides_loader = SmartRideLoader()
rides_metadata = rides_loader.get_filter_metadata()
categorical = rides_metadata.get('categorical', {})
for filter_name, filter_options in categorical.items():
with self.subTest(filter_name=f"rides_{filter_name}"):
self.assertIsInstance(filter_options, list)
for i, option in enumerate(filter_options):
with self.subTest(filter_name=f"rides_{filter_name}", option_index=i):
self.assertIsInstance(option, dict,
f"REGRESSION: Rides filter '{filter_name}' option {i} is a {type(option).__name__} "
f"but should be a dict. This causes frontend crashes!")
def test_ranges_have_step_and_unit(self):
"""Regression test: Ensure ranges have step and unit properties."""
# Frontend expects: { min: number, max: number, step: number, unit?: string }
# Backend was sometimes missing step and unit
parks_metadata = smart_park_loader.get_filter_metadata()
ranges = parks_metadata.get('ranges', {})
for range_name, range_data in ranges.items():
with self.subTest(range_name=range_name):
self.assertIn('step', range_data,
f"Range '{range_name}' missing 'step' property required by frontend")
self.assertIn('unit', range_data,
f"Range '{range_name}' missing 'unit' property required by frontend")
# Step should be a number
self.assertIsInstance(range_data['step'], (int, float),
f"Range '{range_name}' step should be a number")
def test_no_undefined_values(self):
"""Regression test: Ensure no undefined values (should be null)."""
# JavaScript undefined !== null, and TypeScript interfaces expect null
parks_metadata = smart_park_loader.get_filter_metadata()
def check_no_undefined(obj, path=""):
if isinstance(obj, dict):
for key, value in obj.items():
@@ -413,6 +413,6 @@ class RegressionTests(TestCase):
for i, item in enumerate(obj):
current_path = f"{path}[{i}]"
check_no_undefined(item, current_path)
# This will recursively check the entire metadata structure
check_no_undefined(parks_metadata)

View File

@@ -5,23 +5,24 @@ This module provides unified API routing following RESTful conventions
and DRF Router patterns for automatic URL generation.
"""
from .viewsets_rankings import RideRankingViewSet, TriggerRankingCalculationView
from django.urls import include, path
from rest_framework.routers import DefaultRouter
# Import other views from the views directory
from .views import (
HealthCheckAPIView,
NewContentAPIView,
PerformanceMetricsAPIView,
SimpleHealthAPIView,
# Trending system views
TrendingAPIView,
NewContentAPIView,
TriggerTrendingCalculationAPIView,
)
from .views.discovery import DiscoveryAPIView
from .views.stats import StatsAPIView, StatsRecalculateAPIView
from .views.reviews import LatestReviewsAPIView
from .views.leaderboard import leaderboard
from django.urls import path, include
from rest_framework.routers import DefaultRouter
from .views.reviews import LatestReviewsAPIView
from .views.stats import StatsAPIView, StatsRecalculateAPIView
from .viewsets_rankings import RideRankingViewSet, TriggerRankingCalculationView
# Create the main API router
router = DefaultRouter()
@@ -79,6 +80,7 @@ urlpatterns = [
path("core/", include("apps.api.v1.core.urls")),
path("maps/", include("apps.api.v1.maps.urls")),
path("lists/", include("apps.lists.urls")),
path("companies/", include("apps.api.v1.rides.company_urls")),
path("moderation/", include("apps.moderation.urls")),
path("reviews/", include("apps.reviews.urls")),
path("media/", include("apps.media.urls")),

View File

@@ -9,25 +9,23 @@ This package contains all API view classes organized by functionality:
# Import all view classes for easy access
from .auth import (
LoginAPIView,
SignupAPIView,
LogoutAPIView,
CurrentUserAPIView,
PasswordResetAPIView,
PasswordChangeAPIView,
SocialProvidersAPIView,
AuthStatusAPIView,
CurrentUserAPIView,
LoginAPIView,
LogoutAPIView,
PasswordChangeAPIView,
PasswordResetAPIView,
SignupAPIView,
SocialProvidersAPIView,
)
from .health import (
HealthCheckAPIView,
PerformanceMetricsAPIView,
SimpleHealthAPIView,
)
from .trending import (
TrendingAPIView,
NewContentAPIView,
TrendingAPIView,
TriggerTrendingCalculationAPIView,
)

View File

@@ -7,34 +7,34 @@ login, signup, logout, password management, and social authentication.
# type: ignore[misc,attr-defined,arg-type,call-arg,index,assignment]
from typing import TYPE_CHECKING, Type, Any
from django.contrib.auth import login, logout, get_user_model
from typing import TYPE_CHECKING, Any
from django.contrib.auth import get_user_model, login, logout
from django.contrib.sites.shortcuts import get_current_site
from django.core.exceptions import ValidationError
from drf_spectacular.utils import extend_schema, extend_schema_view
from rest_framework import status
from rest_framework.views import APIView
from rest_framework.permissions import AllowAny, IsAuthenticated
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.permissions import AllowAny, IsAuthenticated
from drf_spectacular.utils import extend_schema, extend_schema_view
from rest_framework.views import APIView
# Import serializers from the auth serializers module
from ..serializers.auth import (
AuthStatusOutputSerializer,
LoginInputSerializer,
LoginOutputSerializer,
SignupInputSerializer,
SignupOutputSerializer,
LogoutOutputSerializer,
UserOutputSerializer,
PasswordResetInputSerializer,
PasswordResetOutputSerializer,
PasswordChangeInputSerializer,
PasswordChangeOutputSerializer,
PasswordResetInputSerializer,
PasswordResetOutputSerializer,
SignupInputSerializer,
SignupOutputSerializer,
SocialProviderOutputSerializer,
AuthStatusOutputSerializer,
UserOutputSerializer,
)
# Handle optional dependencies with fallback classes
@@ -56,7 +56,7 @@ except ImportError:
if TYPE_CHECKING:
from typing import Union
TurnstileMixinType = Union[Type[FallbackTurnstileMixin], Any]
TurnstileMixinType = Union[type[FallbackTurnstileMixin], Any]
else:
TurnstileMixinType = TurnstileMixin

View File

@@ -6,16 +6,15 @@ consistent formats that match frontend TypeScript interfaces exactly.
"""
import logging
from typing import Dict, Any, Optional, Type
from rest_framework.views import APIView
from rest_framework.response import Response
from rest_framework import status
from rest_framework.serializers import Serializer
from django.conf import settings
from typing import Any
from apps.api.v1.serializers.shared import (
validate_filter_metadata_contract
)
from django.conf import settings
from rest_framework import status
from rest_framework.response import Response
from rest_framework.serializers import Serializer
from rest_framework.views import APIView
from apps.api.v1.serializers.shared import validate_filter_metadata_contract
logger = logging.getLogger(__name__)
@@ -23,28 +22,28 @@ logger = logging.getLogger(__name__)
class ContractCompliantAPIView(APIView):
"""
Base API view that ensures all responses are contract-compliant.
This view provides:
- Standardized success response format
- Consistent error response format
- Automatic contract validation in DEBUG mode
- Proper error logging with context
"""
# Override in subclasses to specify response serializer
response_serializer_class: Optional[Type[Serializer]] = None
response_serializer_class: type[Serializer] | None = None
def dispatch(self, request, *args, **kwargs):
"""Override dispatch to add contract validation."""
try:
response = super().dispatch(request, *args, **kwargs)
# Validate contract in DEBUG mode
if settings.DEBUG and hasattr(response, 'data'):
self._validate_response_contract(response.data)
return response
except Exception as e:
# Log the error with context
logger.error(
@@ -58,66 +57,66 @@ class ContractCompliantAPIView(APIView):
},
exc_info=True
)
# Return standardized error response
return self.error_response(
message="An internal error occurred",
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR
)
def success_response(
self,
data: Any = None,
message: str = None,
self,
data: Any = None,
message: str = None,
status_code: int = status.HTTP_200_OK,
headers: Dict[str, str] = None
headers: dict[str, str] = None
) -> Response:
"""
Create a standardized success response.
Args:
data: Response data
message: Optional success message
status_code: HTTP status code
headers: Optional response headers
Returns:
Response with standardized format
"""
response_data = {
'success': True
}
if data is not None:
response_data['data'] = data
if message:
response_data['message'] = message
return Response(
response_data,
response_data,
status=status_code,
headers=headers
)
def error_response(
self,
message: str,
status_code: int = status.HTTP_400_BAD_REQUEST,
error_code: str = None,
details: Any = None,
headers: Dict[str, str] = None
headers: dict[str, str] = None
) -> Response:
"""
Create a standardized error response.
Args:
message: Error message
status_code: HTTP status code
error_code: Optional error code
details: Optional error details
headers: Optional response headers
Returns:
Response with standardized error format
"""
@@ -125,40 +124,40 @@ class ContractCompliantAPIView(APIView):
'code': error_code or 'API_ERROR',
'message': message
}
if details:
error_data['details'] = details
# Add user context if available
if hasattr(self, 'request') and hasattr(self.request, 'user'):
user = self.request.user
if user and user.is_authenticated:
error_data['request_user'] = user.username
response_data = {
'status': 'error',
'error': error_data,
'data': None
}
return Response(
response_data,
status=status_code,
headers=headers
)
def validation_error_response(
self,
errors: Dict[str, Any],
errors: dict[str, Any],
message: str = "Validation failed"
) -> Response:
"""
Create a standardized validation error response.
Args:
errors: Validation errors dictionary
message: Error message
Returns:
Response with validation errors
"""
@@ -170,11 +169,11 @@ class ContractCompliantAPIView(APIView):
},
status=status.HTTP_400_BAD_REQUEST
)
def _validate_response_contract(self, data: Any) -> None:
"""
Validate response data against expected contracts.
This method is called automatically in DEBUG mode to catch
contract violations during development.
"""
@@ -182,9 +181,9 @@ class ContractCompliantAPIView(APIView):
# Check if this looks like filter metadata
if isinstance(data, dict) and 'categorical' in data and 'ranges' in data:
validate_filter_metadata_contract(data)
# Add more contract validations as needed
except Exception as e:
logger.warning(
f"Contract validation failed in {self.__class__.__name__}: {str(e)}",
@@ -199,30 +198,30 @@ class ContractCompliantAPIView(APIView):
class FilterMetadataAPIView(ContractCompliantAPIView):
"""
Base view for filter metadata endpoints.
This view ensures filter metadata responses always follow the correct
contract that matches frontend TypeScript interfaces.
"""
def get_filter_metadata(self) -> Dict[str, Any]:
def get_filter_metadata(self) -> dict[str, Any]:
"""
Override this method in subclasses to provide filter metadata.
Returns:
Filter metadata dictionary
"""
raise NotImplementedError("Subclasses must implement get_filter_metadata()")
def get(self, request, *args, **kwargs):
"""Handle GET requests for filter metadata."""
try:
metadata = self.get_filter_metadata()
# Validate the metadata contract
validated_metadata = validate_filter_metadata_contract(metadata)
return self.success_response(validated_metadata)
except Exception as e:
logger.error(
f"Error getting filter metadata in {self.__class__.__name__}: {str(e)}",
@@ -232,7 +231,7 @@ class FilterMetadataAPIView(ContractCompliantAPIView):
},
exc_info=True
)
return self.error_response(
message="Failed to retrieve filter metadata",
error_code="FILTER_METADATA_ERROR"
@@ -242,37 +241,37 @@ class FilterMetadataAPIView(ContractCompliantAPIView):
class HybridFilteringAPIView(ContractCompliantAPIView):
"""
Base view for hybrid filtering endpoints.
This view provides common functionality for hybrid filtering responses
and ensures they follow the correct contract.
"""
def get_hybrid_data(self, filters: Dict[str, Any] = None) -> Dict[str, Any]:
def get_hybrid_data(self, filters: dict[str, Any] = None) -> dict[str, Any]:
"""
Override this method in subclasses to provide hybrid data.
Args:
filters: Filter parameters
Returns:
Hybrid response dictionary
"""
raise NotImplementedError("Subclasses must implement get_hybrid_data()")
def get(self, request, *args, **kwargs):
"""Handle GET requests for hybrid filtering."""
try:
# Extract filters from request parameters
filters = self.extract_filters(request)
# Get hybrid data
hybrid_data = self.get_hybrid_data(filters)
# Validate hybrid response structure
self._validate_hybrid_response(hybrid_data)
return self.success_response(hybrid_data)
except Exception as e:
logger.error(
f"Error in hybrid filtering for {self.__class__.__name__}: {str(e)}",
@@ -283,21 +282,21 @@ class HybridFilteringAPIView(ContractCompliantAPIView):
},
exc_info=True
)
return self.error_response(
message="Failed to retrieve filtered data",
error_code="HYBRID_FILTERING_ERROR"
)
def extract_filters(self, request) -> Dict[str, Any]:
def extract_filters(self, request) -> dict[str, Any]:
"""
Extract filter parameters from request.
Override this method in subclasses to customize filter extraction.
Args:
request: HTTP request object
Returns:
Dictionary of filter parameters
"""
@@ -306,24 +305,24 @@ class HybridFilteringAPIView(ContractCompliantAPIView):
for key, value in request.query_params.items():
if value: # Only include non-empty values
filters[key] = value
# Store for error logging
self._extracted_filters = filters
return filters
def _validate_hybrid_response(self, data: Dict[str, Any]) -> None:
def _validate_hybrid_response(self, data: dict[str, Any]) -> None:
"""Validate hybrid response structure."""
required_fields = ['strategy', 'total_count']
for field in required_fields:
if field not in data:
raise ValueError(f"Hybrid response missing required field: {field}")
# Validate strategy value
if data['strategy'] not in ['client_side', 'server_side']:
raise ValueError(f"Invalid strategy value: {data['strategy']}")
# Validate filter metadata if present
if 'filter_metadata' in data:
validate_filter_metadata_contract(data['filter_metadata'])
@@ -332,77 +331,77 @@ class HybridFilteringAPIView(ContractCompliantAPIView):
class PaginatedAPIView(ContractCompliantAPIView):
"""
Base view for paginated responses.
This view ensures paginated responses follow the correct contract
with consistent pagination metadata.
"""
default_page_size = 20
max_page_size = 100
def get_paginated_response(
self,
queryset,
serializer_class: Type[Serializer],
serializer_class: type[Serializer],
request,
page_size: int = None
) -> Response:
"""
Create a paginated response.
Args:
queryset: Django queryset to paginate
serializer_class: Serializer class for items
request: HTTP request object
page_size: Optional page size override
Returns:
Paginated response
"""
from django.core.paginator import Paginator, EmptyPage, PageNotAnInteger
from django.core.paginator import EmptyPage, PageNotAnInteger, Paginator
# Determine page size
if page_size is None:
page_size = min(
int(request.query_params.get('page_size', self.default_page_size)),
self.max_page_size
)
# Get page number
page_number = request.query_params.get('page', 1)
try:
page_number = int(page_number)
except (ValueError, TypeError):
page_number = 1
# Create paginator
paginator = Paginator(queryset, page_size)
try:
page = paginator.page(page_number)
except PageNotAnInteger:
page = paginator.page(1)
except EmptyPage:
page = paginator.page(paginator.num_pages)
# Serialize data
serializer = serializer_class(page.object_list, many=True)
# Build pagination URLs
request_url = request.build_absolute_uri().split('?')[0]
query_params = request.query_params.copy()
next_url = None
if page.has_next():
query_params['page'] = page.next_page_number()
next_url = f"{request_url}?{query_params.urlencode()}"
previous_url = None
if page.has_previous():
query_params['page'] = page.previous_page_number()
previous_url = f"{request_url}?{query_params.urlencode()}"
# Create response data
response_data = {
'count': paginator.count,
@@ -413,36 +412,36 @@ class PaginatedAPIView(ContractCompliantAPIView):
'current_page': page.number,
'total_pages': paginator.num_pages
}
return self.success_response(response_data)
def contract_compliant_view(view_class):
"""
Decorator to make any view contract-compliant.
This decorator can be applied to existing views to add contract
validation without changing the base class.
"""
original_dispatch = view_class.dispatch
def new_dispatch(self, request, *args, **kwargs):
try:
response = original_dispatch(self, request, *args, **kwargs)
# Add contract validation in DEBUG mode
if settings.DEBUG and hasattr(response, 'data'):
# Basic validation - can be extended
pass
return response
except Exception as e:
logger.error(
f"Error in decorated view {view_class.__name__}: {str(e)}",
exc_info=True
)
# Return basic error response
return Response(
{
@@ -455,6 +454,6 @@ def contract_compliant_view(view_class):
},
status=status.HTTP_500_INTERNAL_SERVER_ERROR
)
view_class.dispatch = new_dispatch
return view_class

View File

@@ -1,14 +1,14 @@
from rest_framework.views import APIView
from rest_framework.response import Response
from rest_framework import status
from rest_framework.permissions import AllowAny
from django.db.models import F
from django.utils import timezone
from drf_spectacular.utils import extend_schema
from datetime import timedelta
from rest_framework.permissions import AllowAny
from rest_framework.response import Response
from rest_framework.views import APIView
from apps.parks.models import Park
from apps.rides.models import Ride
class DiscoveryAPIView(APIView):
"""
API endpoint for discovery content (Top Lists, Opening/Closing Soon).
@@ -28,7 +28,7 @@ class DiscoveryAPIView(APIView):
# --- TOP LISTS ---
# Top Parks by average rating
top_parks = Park.objects.filter(average_rating__isnull=False).order_by("-average_rating")[:limit]
# Top Rides by average rating (fallback to RideRanking in future)
top_rides = Ride.objects.filter(average_rating__isnull=False).order_by("-average_rating")[:limit]
@@ -70,7 +70,7 @@ class DiscoveryAPIView(APIView):
"rides": self._serialize(recently_closed_rides, "ride"),
}
}
return Response(data)
def _serialize(self, queryset, type_):

View File

@@ -6,14 +6,15 @@ performance metrics, and database analysis.
"""
import time
from django.utils import timezone
from django.conf import settings
from rest_framework.views import APIView
from django.utils import timezone
from drf_spectacular.utils import extend_schema, extend_schema_view
from health_check.views import MainView
from rest_framework.permissions import AllowAny
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.permissions import AllowAny
from health_check.views import MainView
from drf_spectacular.utils import extend_schema, extend_schema_view
from rest_framework.views import APIView
# Import serializers
from ..serializers import (
@@ -150,9 +151,10 @@ class HealthCheckAPIView(APIView):
def _get_database_metrics(self) -> dict:
"""Get database performance metrics."""
try:
from django.db import connection
from typing import Any
from django.db import connection
# Get basic connection info
metrics: dict[str, Any] = {
"vendor": connection.vendor,

View File

@@ -1,18 +1,18 @@
"""
Leaderboard views for user rankings
"""
from rest_framework.decorators import api_view, permission_classes
from rest_framework.permissions import AllowAny
from rest_framework.response import Response
from datetime import timedelta
from django.db.models import Count, Sum
from django.db.models.functions import Coalesce
from django.utils import timezone
from datetime import timedelta
from rest_framework.decorators import api_view, permission_classes
from rest_framework.permissions import AllowAny
from rest_framework.response import Response
from apps.accounts.models import User
from apps.rides.models import RideCredit
from apps.reviews.models import Review
from apps.moderation.models import EditSubmission
from apps.reviews.models import Review
from apps.rides.models import RideCredit
@api_view(['GET'])
@@ -20,7 +20,7 @@ from apps.moderation.models import EditSubmission
def leaderboard(request):
"""
Get user leaderboard data.
Query params:
- category: 'credits' | 'reviews' | 'contributions' (default: credits)
- period: 'all' | 'monthly' | 'weekly' (default: all)
@@ -29,14 +29,14 @@ def leaderboard(request):
category = request.query_params.get('category', 'credits')
period = request.query_params.get('period', 'all')
limit = min(int(request.query_params.get('limit', 25)), 100)
# Calculate date filter based on period
date_filter = None
if period == 'weekly':
date_filter = timezone.now() - timedelta(days=7)
elif period == 'monthly':
date_filter = timezone.now() - timedelta(days=30)
if category == 'credits':
return _get_credits_leaderboard(date_filter, limit)
elif category == 'reviews':
@@ -50,16 +50,16 @@ def leaderboard(request):
def _get_credits_leaderboard(date_filter, limit):
"""Top users by total ride credits."""
queryset = RideCredit.objects.all()
if date_filter:
queryset = queryset.filter(created_at__gte=date_filter)
# Aggregate credits per user
users_data = queryset.values('user_id', 'user__username', 'user__display_name').annotate(
total_credits=Coalesce(Sum('count'), 0),
unique_rides=Count('ride', distinct=True),
).order_by('-total_credits')[:limit]
results = []
for rank, entry in enumerate(users_data, 1):
results.append({
@@ -70,7 +70,7 @@ def _get_credits_leaderboard(date_filter, limit):
'total_credits': entry['total_credits'],
'unique_rides': entry['unique_rides'],
})
return Response({
'category': 'credits',
'results': results,
@@ -80,15 +80,15 @@ def _get_credits_leaderboard(date_filter, limit):
def _get_reviews_leaderboard(date_filter, limit):
"""Top users by review count."""
queryset = Review.objects.all()
if date_filter:
queryset = queryset.filter(created_at__gte=date_filter)
# Count reviews per user
users_data = queryset.values('user_id', 'user__username', 'user__display_name').annotate(
review_count=Count('id'),
).order_by('-review_count')[:limit]
results = []
for rank, entry in enumerate(users_data, 1):
results.append({
@@ -98,7 +98,7 @@ def _get_reviews_leaderboard(date_filter, limit):
'display_name': entry['user__display_name'] or entry['user__username'],
'review_count': entry['review_count'],
})
return Response({
'category': 'reviews',
'results': results,
@@ -108,15 +108,15 @@ def _get_reviews_leaderboard(date_filter, limit):
def _get_contributions_leaderboard(date_filter, limit):
"""Top users by approved contributions."""
queryset = EditSubmission.objects.filter(status='approved')
if date_filter:
queryset = queryset.filter(created_at__gte=date_filter)
# Count contributions per user
users_data = queryset.values('submitted_by_id', 'submitted_by__username', 'submitted_by__display_name').annotate(
contribution_count=Count('id'),
).order_by('-contribution_count')[:limit]
results = []
for rank, entry in enumerate(users_data, 1):
results.append({
@@ -126,7 +126,7 @@ def _get_contributions_leaderboard(date_filter, limit):
'display_name': entry['submitted_by__display_name'] or entry['submitted_by__username'],
'contribution_count': entry['contribution_count'],
})
return Response({
'category': 'contributions',
'results': results,

View File

@@ -2,17 +2,19 @@
Views for review-related API endpoints.
"""
from rest_framework.views import APIView
from rest_framework.response import Response
from rest_framework.permissions import AllowAny
from rest_framework import status
from drf_spectacular.utils import extend_schema, OpenApiParameter
from drf_spectacular.types import OpenApiTypes
from itertools import chain
from operator import attrgetter
from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import OpenApiParameter, extend_schema
from rest_framework import status
from rest_framework.permissions import AllowAny
from rest_framework.response import Response
from rest_framework.views import APIView
from apps.parks.models.reviews import ParkReview
from apps.rides.models.reviews import RideReview
from ..serializers.reviews import LatestReviewSerializer

View File

@@ -5,24 +5,29 @@ Provides aggregate statistics about the platform's content including
counts of parks, rides, manufacturers, and other entities.
"""
from rest_framework.views import APIView
from rest_framework.response import Response
from rest_framework import status
from rest_framework.permissions import AllowAny, IsAdminUser
from django.db.models import Count
from django.core.cache import cache
from django.utils import timezone
from drf_spectacular.utils import extend_schema, OpenApiExample
from datetime import datetime
from apps.parks.models import Park, ParkReview, ParkPhoto, Company as ParkCompany
from django.core.cache import cache
from django.db.models import Count
from django.utils import timezone
from drf_spectacular.utils import OpenApiExample, extend_schema
from rest_framework import status
from rest_framework.permissions import AllowAny, IsAdminUser
from rest_framework.response import Response
from rest_framework.views import APIView
from apps.parks.models import Company as ParkCompany
from apps.parks.models import Park, ParkPhoto, ParkReview
from apps.rides.models import (
Ride,
RollerCoasterStats,
RideReview,
RidePhoto,
Company as RideCompany,
)
from apps.rides.models import (
Ride,
RidePhoto,
RideReview,
RollerCoasterStats,
)
from ..serializers.stats import StatsSerializer
@@ -103,17 +108,17 @@ class StatsAPIView(APIView):
summary="Get platform statistics",
description="""
Returns comprehensive aggregate statistics about the ThrillWiki platform.
This endpoint provides detailed counts and breakdowns of all major entities including:
- Parks, rides, and roller coasters
- Companies (manufacturers, operators, designers, property owners)
- Photos and reviews
- Ride categories (roller coasters, dark rides, flat rides, etc.)
- Status breakdowns (operating, closed, under construction, etc.)
Results are cached for 5 minutes for optimal performance and automatically
Results are cached for 5 minutes for optimal performance and automatically
invalidated when relevant data changes.
**No authentication required** - this is a public endpoint.
""".strip(),
responses={

View File

@@ -5,14 +5,15 @@ This module contains endpoints for trending and new content discovery
including trending parks, rides, and recently added content.
"""
from datetime import datetime, date
from rest_framework.views import APIView
from datetime import date, datetime
from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import OpenApiParameter, extend_schema, extend_schema_view
from rest_framework import status
from rest_framework.permissions import AllowAny, IsAdminUser
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.permissions import AllowAny, IsAdminUser
from rest_framework import status
from drf_spectacular.utils import extend_schema, extend_schema_view, OpenApiParameter
from drf_spectacular.types import OpenApiTypes
from rest_framework.views import APIView
@extend_schema_view(
@@ -111,9 +112,10 @@ class TriggerTrendingCalculationAPIView(APIView):
def post(self, request: Request) -> Response:
"""Trigger trending content calculation using management commands."""
try:
from django.core.management import call_command
import io
from contextlib import redirect_stdout, redirect_stderr
from contextlib import redirect_stderr, redirect_stdout
from django.core.management import call_command
# Capture command output
trending_output = io.StringIO()
@@ -227,10 +229,7 @@ class NewContentAPIView(APIView):
if date_added:
try:
# Parse the date string
if isinstance(date_added, str):
item_date = datetime.fromisoformat(date_added).date()
else:
item_date = date_added
item_date = datetime.fromisoformat(date_added).date() if isinstance(date_added, str) else date_added
# Calculate days difference
days_diff = (today - item_date).days

View File

@@ -2,32 +2,34 @@
API viewsets for the ride ranking system.
"""
from typing import TYPE_CHECKING, Any, Type, cast
from typing import TYPE_CHECKING, Any, cast
from django.db.models import Q, QuerySet
from django.utils import timezone
from django_filters.rest_framework import DjangoFilterBackend
from drf_spectacular.utils import extend_schema, extend_schema_view, OpenApiParameter
from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import OpenApiParameter, extend_schema, extend_schema_view
from rest_framework import status
from rest_framework.decorators import action
from rest_framework.filters import OrderingFilter
from rest_framework.permissions import IsAuthenticatedOrReadOnly, AllowAny
from rest_framework.permissions import AllowAny, IsAuthenticatedOrReadOnly
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.serializers import BaseSerializer
from rest_framework.viewsets import ReadOnlyModelViewSet
from rest_framework.views import APIView
from rest_framework.viewsets import ReadOnlyModelViewSet
if TYPE_CHECKING:
pass
# Import models inside methods to avoid Django initialization issues
import contextlib
from .serializers_rankings import (
RideRankingSerializer,
RideRankingDetailSerializer,
RankingSnapshotSerializer,
RankingStatsSerializer,
RideRankingDetailSerializer,
RideRankingSerializer,
)
@@ -127,10 +129,8 @@ class RideRankingViewSet(ReadOnlyModelViewSet):
# Filter by minimum mutual riders
min_riders = request.query_params.get("min_riders")
if min_riders:
try:
with contextlib.suppress(ValueError):
queryset = queryset.filter(mutual_riders_count__gte=int(min_riders))
except ValueError:
pass
# Filter by park
park_slug = request.query_params.get("park")
@@ -142,12 +142,12 @@ class RideRankingViewSet(ReadOnlyModelViewSet):
def get_serializer_class(self) -> Any: # type: ignore[override]
"""Use different serializers for list vs detail."""
if self.action == "retrieve":
return cast(Type[BaseSerializer], RideRankingDetailSerializer)
return cast(type[BaseSerializer], RideRankingDetailSerializer)
elif self.action == "history":
return cast(Type[BaseSerializer], RankingSnapshotSerializer)
return cast(type[BaseSerializer], RankingSnapshotSerializer)
elif self.action == "statistics":
return cast(Type[BaseSerializer], RankingStatsSerializer)
return cast(Type[BaseSerializer], RideRankingSerializer)
return cast(type[BaseSerializer], RankingStatsSerializer)
return cast(type[BaseSerializer], RideRankingSerializer)
@action(detail=True, methods=["get"])
def history(self, request, ride_slug=None):
@@ -167,7 +167,7 @@ class RideRankingViewSet(ReadOnlyModelViewSet):
@action(detail=False, methods=["get"])
def statistics(self, request):
"""Get overall ranking system statistics."""
from apps.rides.models import RideRanking, RidePairComparison, RankingSnapshot
from apps.rides.models import RankingSnapshot, RidePairComparison, RideRanking
total_rankings = RideRanking.objects.count()
total_comparisons = RidePairComparison.objects.count()