feat: Implement initial schema and add various API, service, and management command enhancements across the application.

This commit is contained in:
pacnpal
2026-01-01 15:13:01 -05:00
parent c95f99ca10
commit b243b17af7
413 changed files with 11164 additions and 17433 deletions

37
backend/.flake8 Normal file
View File

@@ -0,0 +1,37 @@
[flake8]
# Match Black and Ruff line length
max-line-length = 120
# Ignore rules that conflict with Black formatting or are handled by other tools
ignore =
# E203: whitespace before ':' - Black intentionally does this
E203,
# E501: line too long - handled by Black/Ruff
E501,
# W503: line break before binary operator - conflicts with Black
W503,
# E226: missing whitespace around arithmetic operator - Black style
E226,
# W391: blank line at end of file - not critical
W391,
# C901: function is too complex - these are intentional for complex business logic
C901,
# F401: imported but unused - star imports for choice registration are intentional
F401
# Exclude common directories
exclude =
.git,
__pycache__,
migrations,
.venv,
venv,
build,
dist,
*.egg-info,
node_modules,
htmlcov,
.pytest_cache
# Complexity threshold - set high since we have intentional complex functions
max-complexity = 50

View File

@@ -1,2 +1,2 @@
# Import choices to trigger registration
from .choices import *
from .choices import * # noqa: F403

View File

@@ -77,8 +77,6 @@ class UserProfileInline(admin.StackedInline):
)
@admin.register(User)
class CustomUserAdmin(QueryOptimizationMixin, ExportActionMixin, UserAdmin):
"""
@@ -332,8 +330,9 @@ class CustomUserAdmin(QueryOptimizationMixin, ExportActionMixin, UserAdmin):
try:
profile = user.profile
# Credits would be recalculated from ride history here
profile.save(update_fields=["coaster_credits", "dark_ride_credits",
"flat_ride_credits", "water_ride_credits"])
profile.save(
update_fields=["coaster_credits", "dark_ride_credits", "flat_ride_credits", "water_ride_credits"]
)
count += 1
except UserProfile.DoesNotExist:
pass
@@ -442,12 +441,14 @@ class UserProfileAdmin(QueryOptimizationMixin, ExportActionMixin, BaseModelAdmin
@admin.display(description="Completeness")
def profile_completeness(self, obj):
"""Display profile completeness indicator."""
fields_filled = sum([
bool(obj.display_name),
bool(obj.avatar),
bool(obj.bio),
bool(obj.twitter or obj.instagram or obj.youtube or obj.discord),
])
fields_filled = sum(
[
bool(obj.display_name),
bool(obj.avatar),
bool(obj.bio),
bool(obj.twitter or obj.instagram or obj.youtube or obj.discord),
]
)
percentage = (fields_filled / 4) * 100
color = "green" if percentage >= 75 else "orange" if percentage >= 50 else "red"
return format_html(
@@ -529,12 +530,8 @@ class EmailVerificationAdmin(QueryOptimizationMixin, BaseModelAdmin):
def expiration_status(self, obj):
"""Display expiration status with color coding."""
if timezone.now() - obj.last_sent > timedelta(days=1):
return format_html(
'<span style="color: red; font-weight: bold;">Expired</span>'
)
return format_html(
'<span style="color: green; font-weight: bold;">Valid</span>'
)
return format_html('<span style="color: red; font-weight: bold;">Expired</span>')
return format_html('<span style="color: green; font-weight: bold;">Valid</span>')
@admin.display(description="Can Resend", boolean=True)
def can_resend(self, obj):
@@ -665,6 +662,3 @@ class PasswordResetAdmin(ReadOnlyAdminMixin, BaseModelAdmin):
"Cleanup old tokens",
)
return actions

View File

@@ -26,7 +26,7 @@ user_roles = ChoiceGroup(
"css_class": "text-blue-600 bg-blue-50",
"permissions": ["create_content", "create_reviews", "create_lists"],
"sort_order": 1,
}
},
),
RichChoice(
value="MODERATOR",
@@ -38,7 +38,7 @@ user_roles = ChoiceGroup(
"css_class": "text-green-600 bg-green-50",
"permissions": ["moderate_content", "review_submissions", "manage_reports"],
"sort_order": 2,
}
},
),
RichChoice(
value="ADMIN",
@@ -50,7 +50,7 @@ user_roles = ChoiceGroup(
"css_class": "text-purple-600 bg-purple-50",
"permissions": ["manage_users", "site_configuration", "advanced_moderation"],
"sort_order": 3,
}
},
),
RichChoice(
value="SUPERUSER",
@@ -62,9 +62,9 @@ user_roles = ChoiceGroup(
"css_class": "text-red-600 bg-red-50",
"permissions": ["full_access", "system_administration", "database_access"],
"sort_order": 4,
}
},
),
]
],
)
@@ -83,13 +83,9 @@ theme_preferences = ChoiceGroup(
"color": "yellow",
"icon": "sun",
"css_class": "text-yellow-600 bg-yellow-50",
"preview_colors": {
"background": "#ffffff",
"text": "#1f2937",
"accent": "#3b82f6"
},
"preview_colors": {"background": "#ffffff", "text": "#1f2937", "accent": "#3b82f6"},
"sort_order": 1,
}
},
),
RichChoice(
value="dark",
@@ -99,15 +95,11 @@ theme_preferences = ChoiceGroup(
"color": "gray",
"icon": "moon",
"css_class": "text-gray-600 bg-gray-50",
"preview_colors": {
"background": "#1f2937",
"text": "#f9fafb",
"accent": "#60a5fa"
},
"preview_colors": {"background": "#1f2937", "text": "#f9fafb", "accent": "#60a5fa"},
"sort_order": 2,
}
},
),
]
],
)
@@ -133,7 +125,7 @@ unit_systems = ChoiceGroup(
"large_distance": "km",
},
"sort_order": 1,
}
},
),
RichChoice(
value="imperial",
@@ -150,9 +142,9 @@ unit_systems = ChoiceGroup(
"large_distance": "mi",
},
"sort_order": 2,
}
},
),
]
],
)
@@ -177,10 +169,10 @@ privacy_levels = ChoiceGroup(
"Profile visible to all users",
"Activity appears in public feeds",
"Searchable by search engines",
"Can be found by username search"
"Can be found by username search",
],
"sort_order": 1,
}
},
),
RichChoice(
value="friends",
@@ -196,10 +188,10 @@ privacy_levels = ChoiceGroup(
"Profile visible only to friends",
"Activity hidden from public feeds",
"Not searchable by search engines",
"Requires friend request approval"
"Requires friend request approval",
],
"sort_order": 2,
}
},
),
RichChoice(
value="private",
@@ -215,12 +207,12 @@ privacy_levels = ChoiceGroup(
"Profile completely hidden",
"No activity in any feeds",
"Not discoverable by other users",
"Maximum privacy protection"
"Maximum privacy protection",
],
"sort_order": 3,
}
},
),
]
],
)
@@ -242,7 +234,7 @@ top_list_categories = ChoiceGroup(
"ride_category": "roller_coaster",
"typical_list_size": 10,
"sort_order": 1,
}
},
),
RichChoice(
value="DR",
@@ -255,7 +247,7 @@ top_list_categories = ChoiceGroup(
"ride_category": "dark_ride",
"typical_list_size": 10,
"sort_order": 2,
}
},
),
RichChoice(
value="FR",
@@ -268,7 +260,7 @@ top_list_categories = ChoiceGroup(
"ride_category": "flat_ride",
"typical_list_size": 10,
"sort_order": 3,
}
},
),
RichChoice(
value="WR",
@@ -281,7 +273,7 @@ top_list_categories = ChoiceGroup(
"ride_category": "water_ride",
"typical_list_size": 10,
"sort_order": 4,
}
},
),
RichChoice(
value="PK",
@@ -294,9 +286,9 @@ top_list_categories = ChoiceGroup(
"entity_type": "park",
"typical_list_size": 10,
"sort_order": 5,
}
},
),
]
],
)
@@ -320,7 +312,7 @@ notification_types = ChoiceGroup(
"default_channels": ["email", "push", "inapp"],
"priority": "normal",
"sort_order": 1,
}
},
),
RichChoice(
value="submission_rejected",
@@ -334,7 +326,7 @@ notification_types = ChoiceGroup(
"default_channels": ["email", "push", "inapp"],
"priority": "normal",
"sort_order": 2,
}
},
),
RichChoice(
value="submission_pending",
@@ -348,7 +340,7 @@ notification_types = ChoiceGroup(
"default_channels": ["inapp"],
"priority": "low",
"sort_order": 3,
}
},
),
# Review related
RichChoice(
@@ -363,7 +355,7 @@ notification_types = ChoiceGroup(
"default_channels": ["email", "push", "inapp"],
"priority": "normal",
"sort_order": 4,
}
},
),
RichChoice(
value="review_helpful",
@@ -377,7 +369,7 @@ notification_types = ChoiceGroup(
"default_channels": ["push", "inapp"],
"priority": "low",
"sort_order": 5,
}
},
),
# Social related
RichChoice(
@@ -392,7 +384,7 @@ notification_types = ChoiceGroup(
"default_channels": ["email", "push", "inapp"],
"priority": "normal",
"sort_order": 6,
}
},
),
RichChoice(
value="friend_accepted",
@@ -406,7 +398,7 @@ notification_types = ChoiceGroup(
"default_channels": ["push", "inapp"],
"priority": "low",
"sort_order": 7,
}
},
),
RichChoice(
value="message_received",
@@ -420,7 +412,7 @@ notification_types = ChoiceGroup(
"default_channels": ["email", "push", "inapp"],
"priority": "normal",
"sort_order": 8,
}
},
),
RichChoice(
value="profile_comment",
@@ -434,7 +426,7 @@ notification_types = ChoiceGroup(
"default_channels": ["email", "push", "inapp"],
"priority": "normal",
"sort_order": 9,
}
},
),
# System related
RichChoice(
@@ -449,7 +441,7 @@ notification_types = ChoiceGroup(
"default_channels": ["email", "inapp"],
"priority": "normal",
"sort_order": 10,
}
},
),
RichChoice(
value="account_security",
@@ -463,7 +455,7 @@ notification_types = ChoiceGroup(
"default_channels": ["email", "push", "inapp"],
"priority": "high",
"sort_order": 11,
}
},
),
RichChoice(
value="feature_update",
@@ -477,7 +469,7 @@ notification_types = ChoiceGroup(
"default_channels": ["email", "inapp"],
"priority": "low",
"sort_order": 12,
}
},
),
RichChoice(
value="maintenance",
@@ -491,7 +483,7 @@ notification_types = ChoiceGroup(
"default_channels": ["email", "inapp"],
"priority": "normal",
"sort_order": 13,
}
},
),
# Achievement related
RichChoice(
@@ -506,7 +498,7 @@ notification_types = ChoiceGroup(
"default_channels": ["push", "inapp"],
"priority": "low",
"sort_order": 14,
}
},
),
RichChoice(
value="milestone_reached",
@@ -520,9 +512,9 @@ notification_types = ChoiceGroup(
"default_channels": ["push", "inapp"],
"priority": "low",
"sort_order": 15,
}
},
),
]
],
)
@@ -545,7 +537,7 @@ notification_priorities = ChoiceGroup(
"batch_eligible": True,
"delay_minutes": 60,
"sort_order": 1,
}
},
),
RichChoice(
value="normal",
@@ -559,7 +551,7 @@ notification_priorities = ChoiceGroup(
"batch_eligible": True,
"delay_minutes": 15,
"sort_order": 2,
}
},
),
RichChoice(
value="high",
@@ -573,7 +565,7 @@ notification_priorities = ChoiceGroup(
"batch_eligible": False,
"delay_minutes": 0,
"sort_order": 3,
}
},
),
RichChoice(
value="urgent",
@@ -588,9 +580,9 @@ notification_priorities = ChoiceGroup(
"delay_minutes": 0,
"bypass_preferences": True,
"sort_order": 4,
}
},
),
]
],
)

View File

@@ -53,28 +53,34 @@ class UserExportService:
"dark_ride": profile.dark_ride_credits,
"flat_ride": profile.flat_ride_credits,
"water_ride": profile.water_ride_credits,
}
},
}
# Reviews
park_reviews = list(ParkReview.objects.filter(user=user).values(
"park__name", "rating", "review", "created_at", "updated_at", "is_published"
))
park_reviews = list(
ParkReview.objects.filter(user=user).values(
"park__name", "rating", "review", "created_at", "updated_at", "is_published"
)
)
ride_reviews = list(RideReview.objects.filter(user=user).values(
"ride__name", "rating", "review", "created_at", "updated_at", "is_published"
))
ride_reviews = list(
RideReview.objects.filter(user=user).values(
"ride__name", "rating", "review", "created_at", "updated_at", "is_published"
)
)
# Lists
user_lists = []
for user_list in UserList.objects.filter(user=user):
items = list(user_list.items.values("order", "content_type__model", "object_id", "comment"))
user_lists.append({
"title": user_list.title,
"description": user_list.description,
"created_at": user_list.created_at,
"items": items
})
user_lists.append(
{
"title": user_list.title,
"description": user_list.description,
"created_at": user_list.created_at,
"items": items,
}
)
export_data = {
"account": user_data,
@@ -85,10 +91,7 @@ class UserExportService:
"ride_reviews": ride_reviews,
"lists": user_lists,
},
"export_info": {
"generated_at": timezone.now(),
"version": "1.0"
}
"export_info": {"generated_at": timezone.now(), "version": "1.0"},
}
return export_data

View File

@@ -99,8 +99,6 @@ class LoginHistory(models.Model):
# Default cleanup for entries older than the specified days
cutoff = timezone.now() - timedelta(days=days)
deleted_count, _ = cls.objects.filter(
login_timestamp__lt=cutoff
).delete()
deleted_count, _ = cls.objects.filter(login_timestamp__lt=cutoff).delete()
return deleted_count

View File

@@ -22,20 +22,14 @@ class Command(BaseCommand):
# Check SocialAccount
self.stdout.write("\nChecking SocialAccount table:")
for account in SocialAccount.objects.all():
self.stdout.write(
f"ID: {account.pk}, Provider: {account.provider}, UID: {account.uid}"
)
self.stdout.write(f"ID: {account.pk}, Provider: {account.provider}, UID: {account.uid}")
# Check SocialToken
self.stdout.write("\nChecking SocialToken table:")
for token in SocialToken.objects.all():
self.stdout.write(
f"ID: {token.pk}, Account: {token.account}, App: {token.app}"
)
self.stdout.write(f"ID: {token.pk}, Account: {token.account}, App: {token.app}")
# Check Site
self.stdout.write("\nChecking Site table:")
for site in Site.objects.all():
self.stdout.write(
f"ID: {site.pk}, Domain: {site.domain}, Name: {site.name}"
)
self.stdout.write(f"ID: {site.pk}, Domain: {site.domain}, Name: {site.name}")

View File

@@ -17,6 +17,4 @@ class Command(BaseCommand):
self.stdout.write(f"Name: {app.name}")
self.stdout.write(f"Client ID: {app.client_id}")
self.stdout.write(f"Secret: {app.secret}")
self.stdout.write(
f"Sites: {', '.join(str(site.domain) for site in app.sites.all())}"
)
self.stdout.write(f"Sites: {', '.join(str(site.domain) for site in app.sites.all())}")

View File

@@ -15,14 +15,9 @@ class Command(BaseCommand):
# Remove migration records
cursor.execute("DELETE FROM django_migrations WHERE app='socialaccount'")
cursor.execute(
"DELETE FROM django_migrations WHERE app='accounts' "
"AND name LIKE '%social%'"
)
cursor.execute("DELETE FROM django_migrations WHERE app='accounts' " "AND name LIKE '%social%'")
# Reset sequences
cursor.execute("DELETE FROM sqlite_sequence WHERE name LIKE '%social%'")
self.stdout.write(
self.style.SUCCESS("Successfully cleaned up social auth configuration")
)
self.stdout.write(self.style.SUCCESS("Successfully cleaned up social auth configuration"))

View File

@@ -18,24 +18,18 @@ class Command(BaseCommand):
self.stdout.write(self.style.SUCCESS(f"Deleted {count} test users"))
# Delete test reviews
reviews = ParkReview.objects.filter(
user__username__in=["testuser", "moderator"]
)
reviews = ParkReview.objects.filter(user__username__in=["testuser", "moderator"])
count = reviews.count()
reviews.delete()
self.stdout.write(self.style.SUCCESS(f"Deleted {count} test reviews"))
# Delete test photos - both park and ride photos
park_photos = ParkPhoto.objects.filter(
uploader__username__in=["testuser", "moderator"]
)
park_photos = ParkPhoto.objects.filter(uploader__username__in=["testuser", "moderator"])
park_count = park_photos.count()
park_photos.delete()
self.stdout.write(self.style.SUCCESS(f"Deleted {park_count} test park photos"))
ride_photos = RidePhoto.objects.filter(
uploader__username__in=["testuser", "moderator"]
)
ride_photos = RidePhoto.objects.filter(uploader__username__in=["testuser", "moderator"])
ride_count = ride_photos.count()
ride_photos.delete()
self.stdout.write(self.style.SUCCESS(f"Deleted {ride_count} test ride photos"))

View File

@@ -37,18 +37,12 @@ class Command(BaseCommand):
provider="google",
defaults={
"name": "Google",
"client_id": (
"135166769591-nopcgmo0fkqfqfs9qe783a137mtmcrt2."
"apps.googleusercontent.com"
),
"client_id": ("135166769591-nopcgmo0fkqfqfs9qe783a137mtmcrt2." "apps.googleusercontent.com"),
"secret": "GOCSPX-Wd_0Ue0Ue0Ue0Ue0Ue0Ue0Ue0Ue",
},
)
if not created:
google_app.client_id = (
"135166769591-nopcgmo0fkqfqfs9qe783a137mtmcrt2."
"apps.googleusercontent.com"
)
google_app.client_id = "135166769591-nopcgmo0fkqfqfs9qe783a137mtmcrt2." "apps.googleusercontent.com"
google_app.secret = "GOCSPX-Wd_0Ue0Ue0Ue0Ue0Ue0Ue0Ue0Ue"
google_app.save()
google_app.sites.add(site)

View File

@@ -14,9 +14,7 @@ class Command(BaseCommand):
)
user.set_password("testpass123")
user.save()
self.stdout.write(
self.style.SUCCESS(f"Created test user: {user.get_username()}")
)
self.stdout.write(self.style.SUCCESS(f"Created test user: {user.get_username()}"))
else:
self.stdout.write(self.style.WARNING("Test user already exists"))
@@ -47,11 +45,7 @@ class Command(BaseCommand):
# Add user to moderator group
moderator.groups.add(moderator_group)
self.stdout.write(
self.style.SUCCESS(
f"Created moderator user: {moderator.get_username()}"
)
)
self.stdout.write(self.style.SUCCESS(f"Created moderator user: {moderator.get_username()}"))
else:
self.stdout.write(self.style.WARNING("Moderator user already exists"))

View File

@@ -17,9 +17,7 @@ class Command(BaseCommand):
help = "Delete a user while preserving all their submissions"
def add_arguments(self, parser):
parser.add_argument(
"username", nargs="?", type=str, help="Username of the user to delete"
)
parser.add_argument("username", nargs="?", type=str, help="Username of the user to delete")
parser.add_argument(
"--user-id",
type=str,
@@ -30,9 +28,7 @@ class Command(BaseCommand):
action="store_true",
help="Show what would be deleted without actually deleting",
)
parser.add_argument(
"--force", action="store_true", help="Skip confirmation prompt"
)
parser.add_argument("--force", action="store_true", help="Skip confirmation prompt")
def handle(self, *args, **options):
username = options.get("username")
@@ -52,7 +48,7 @@ class Command(BaseCommand):
user = User.objects.get(username=username) if username else User.objects.get(user_id=user_id)
except User.DoesNotExist:
identifier = username or user_id
raise CommandError(f'User "{identifier}" does not exist')
raise CommandError(f'User "{identifier}" does not exist') from None
# Check if user can be deleted
can_delete, reason = UserDeletionService.can_delete_user(user)
@@ -61,27 +57,13 @@ class Command(BaseCommand):
# Count submissions
submission_counts = {
"park_reviews": getattr(
user, "park_reviews", user.__class__.objects.none()
).count(),
"ride_reviews": getattr(
user, "ride_reviews", user.__class__.objects.none()
).count(),
"uploaded_park_photos": getattr(
user, "uploaded_park_photos", user.__class__.objects.none()
).count(),
"uploaded_ride_photos": getattr(
user, "uploaded_ride_photos", user.__class__.objects.none()
).count(),
"top_lists": getattr(
user, "top_lists", user.__class__.objects.none()
).count(),
"edit_submissions": getattr(
user, "edit_submissions", user.__class__.objects.none()
).count(),
"photo_submissions": getattr(
user, "photo_submissions", user.__class__.objects.none()
).count(),
"park_reviews": getattr(user, "park_reviews", user.__class__.objects.none()).count(),
"ride_reviews": getattr(user, "ride_reviews", user.__class__.objects.none()).count(),
"uploaded_park_photos": getattr(user, "uploaded_park_photos", user.__class__.objects.none()).count(),
"uploaded_ride_photos": getattr(user, "uploaded_ride_photos", user.__class__.objects.none()).count(),
"top_lists": getattr(user, "top_lists", user.__class__.objects.none()).count(),
"edit_submissions": getattr(user, "edit_submissions", user.__class__.objects.none()).count(),
"photo_submissions": getattr(user, "photo_submissions", user.__class__.objects.none()).count(),
}
total_submissions = sum(submission_counts.values())
@@ -98,9 +80,7 @@ class Command(BaseCommand):
self.stdout.write(self.style.WARNING("\nSubmissions to preserve:"))
for submission_type, count in submission_counts.items():
if count > 0:
self.stdout.write(
f' {submission_type.replace("_", " ").title()}: {count}'
)
self.stdout.write(f' {submission_type.replace("_", " ").title()}: {count}')
self.stdout.write(f"\nTotal submissions: {total_submissions}")
@@ -111,9 +91,7 @@ class Command(BaseCommand):
)
)
else:
self.stdout.write(
self.style.WARNING("\nNo submissions found for this user.")
)
self.stdout.write(self.style.WARNING("\nNo submissions found for this user."))
if dry_run:
self.stdout.write(self.style.SUCCESS("\n[DRY RUN] No changes were made."))
@@ -136,11 +114,7 @@ class Command(BaseCommand):
try:
result = UserDeletionService.delete_user_preserve_submissions(user)
self.stdout.write(
self.style.SUCCESS(
f'\nSuccessfully deleted user "{result["deleted_user"]["username"]}"'
)
)
self.stdout.write(self.style.SUCCESS(f'\nSuccessfully deleted user "{result["deleted_user"]["username"]}"'))
preserved_count = sum(result["preserved_submissions"].values())
if preserved_count > 0:
@@ -154,9 +128,7 @@ class Command(BaseCommand):
self.stdout.write(self.style.WARNING("\nPreservation Summary:"))
for submission_type, count in result["preserved_submissions"].items():
if count > 0:
self.stdout.write(
f' {submission_type.replace("_", " ").title()}: {count}'
)
self.stdout.write(f' {submission_type.replace("_", " ").title()}: {count}')
except Exception as e:
raise CommandError(f"Error deleting user: {str(e)}")
raise CommandError(f"Error deleting user: {str(e)}") from None

View File

@@ -7,12 +7,5 @@ class Command(BaseCommand):
def handle(self, *args, **kwargs):
with connection.cursor() as cursor:
cursor.execute(
"DELETE FROM django_migrations WHERE app='rides' "
"AND name='0001_initial';"
)
self.stdout.write(
self.style.SUCCESS(
"Successfully removed rides.0001_initial from migration history"
)
)
cursor.execute("DELETE FROM django_migrations WHERE app='rides' " "AND name='0001_initial';")
self.stdout.write(self.style.SUCCESS("Successfully removed rides.0001_initial from migration history"))

View File

@@ -34,6 +34,4 @@ class Command(BaseCommand):
secret=os.getenv("DISCORD_CLIENT_SECRET"),
)
discord_app.sites.add(site)
self.stdout.write(
f"Created Discord app with client_id: {discord_app.client_id}"
)
self.stdout.write(f"Created Discord app with client_id: {discord_app.client_id}")

View File

@@ -47,9 +47,7 @@ class Command(BaseCommand):
help = "Generate avatars for letters A-Z and numbers 0-9"
def handle(self, *args, **kwargs):
characters = [chr(i) for i in range(65, 91)] + [
str(i) for i in range(10)
] # A-Z and 0-9
characters = [chr(i) for i in range(65, 91)] + [str(i) for i in range(10)] # A-Z and 0-9
for char in characters:
generate_avatar(char)
self.stdout.write(self.style.SUCCESS(f"Generated avatar for {char}"))

View File

@@ -11,6 +11,4 @@ class Command(BaseCommand):
for profile in profiles:
# This will trigger the avatar generation logic in the save method
profile.save()
self.stdout.write(
self.style.SUCCESS(f"Regenerated avatar for {profile.user.username}")
)
self.stdout.write(self.style.SUCCESS(f"Regenerated avatar for {profile.user.username}"))

View File

@@ -69,18 +69,18 @@ class Command(BaseCommand):
# Security: Using Django ORM instead of raw SQL for user creation
user = User.objects.create_superuser(
username='admin',
email='admin@thrillwiki.com',
password='admin',
role='SUPERUSER',
username="admin",
email="admin@thrillwiki.com",
password="admin",
role="SUPERUSER",
)
# Create profile using ORM
UserProfile.objects.create(
user=user,
display_name='Admin',
pronouns='they/them',
bio='ThrillWiki Administrator',
display_name="Admin",
pronouns="they/them",
bio="ThrillWiki Administrator",
)
self.stdout.write("Superuser created.")

View File

@@ -30,9 +30,7 @@ class Command(BaseCommand):
google_app = SocialApp.objects.create(
provider="google",
name="Google",
client_id=(
"135166769591-nopcgmo0fkqfqfs9qe783a137mtmcrt2.apps.googleusercontent.com"
),
client_id=("135166769591-nopcgmo0fkqfqfs9qe783a137mtmcrt2.apps.googleusercontent.com"),
secret="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",
)
google_app.sites.add(site)

View File

@@ -12,13 +12,7 @@ class Command(BaseCommand):
cursor.execute("DELETE FROM socialaccount_socialapp_sites")
# Reset sequences
cursor.execute(
"DELETE FROM sqlite_sequence WHERE name='socialaccount_socialapp'"
)
cursor.execute(
"DELETE FROM sqlite_sequence WHERE name='socialaccount_socialapp_sites'"
)
cursor.execute("DELETE FROM sqlite_sequence WHERE name='socialaccount_socialapp'")
cursor.execute("DELETE FROM sqlite_sequence WHERE name='socialaccount_socialapp_sites'")
self.stdout.write(
self.style.SUCCESS("Successfully reset social auth configuration")
)
self.stdout.write(self.style.SUCCESS("Successfully reset social auth configuration"))

View File

@@ -30,9 +30,7 @@ class Command(BaseCommand):
user.is_staff = True
user.save()
self.stdout.write(
self.style.SUCCESS("Successfully set up groups and permissions")
)
self.stdout.write(self.style.SUCCESS("Successfully set up groups and permissions"))
# Print summary
for group in Group.objects.all():

View File

@@ -10,7 +10,5 @@ class Command(BaseCommand):
Site.objects.all().delete()
# Create default site
site = Site.objects.create(
id=1, domain="localhost:8000", name="ThrillWiki Development"
)
site = Site.objects.create(id=1, domain="localhost:8000", name="ThrillWiki Development")
self.stdout.write(self.style.SUCCESS(f"Created site: {site.domain}"))

View File

@@ -49,27 +49,15 @@ class Command(BaseCommand):
discord_client_secret,
]
):
self.stdout.write(
self.style.ERROR("Missing required environment variables")
)
self.stdout.write(
f"DEBUG: google_client_id is None: {google_client_id is None}"
)
self.stdout.write(
f"DEBUG: google_client_secret is None: {google_client_secret is None}"
)
self.stdout.write(
f"DEBUG: discord_client_id is None: {discord_client_id is None}"
)
self.stdout.write(
f"DEBUG: discord_client_secret is None: {discord_client_secret is None}"
)
self.stdout.write(self.style.ERROR("Missing required environment variables"))
self.stdout.write(f"DEBUG: google_client_id is None: {google_client_id is None}")
self.stdout.write(f"DEBUG: google_client_secret is None: {google_client_secret is None}")
self.stdout.write(f"DEBUG: discord_client_id is None: {discord_client_id is None}")
self.stdout.write(f"DEBUG: discord_client_secret is None: {discord_client_secret is None}")
return
# Get or create the default site
site, _ = Site.objects.get_or_create(
id=1, defaults={"domain": "localhost:8000", "name": "localhost"}
)
site, _ = Site.objects.get_or_create(id=1, defaults={"domain": "localhost:8000", "name": "localhost"})
# Set up Google
google_app, created = SocialApp.objects.get_or_create(
@@ -92,11 +80,7 @@ class Command(BaseCommand):
google_app.save()
self.stdout.write("DEBUG: Successfully updated Google app")
else:
self.stdout.write(
self.style.ERROR(
"Google client_id or secret is None, skipping update."
)
)
self.stdout.write(self.style.ERROR("Google client_id or secret is None, skipping update."))
google_app.sites.add(site)
# Set up Discord
@@ -120,11 +104,7 @@ class Command(BaseCommand):
discord_app.save()
self.stdout.write("DEBUG: Successfully updated Discord app")
else:
self.stdout.write(
self.style.ERROR(
"Discord client_id or secret is None, skipping update."
)
)
self.stdout.write(self.style.ERROR("Discord client_id or secret is None, skipping update."))
discord_app.sites.add(site)
self.stdout.write(self.style.SUCCESS("Successfully set up social auth apps"))

View File

@@ -42,6 +42,4 @@ class Command(BaseCommand):
for app in SocialApp.objects.all():
self.stdout.write(f"- {app.name} ({app.provider}): {app.client_id}")
self.stdout.write(
self.style.SUCCESS(f"\nTotal social apps: {SocialApp.objects.count()}")
)
self.stdout.write(self.style.SUCCESS(f"\nTotal social apps: {SocialApp.objects.count()}"))

View File

@@ -40,9 +40,7 @@ class Command(BaseCommand):
# Show callback URL
callback_url = "http://localhost:8000/accounts/discord/login/callback/"
self.stdout.write(
"\nCallback URL to configure in Discord Developer Portal:"
)
self.stdout.write("\nCallback URL to configure in Discord Developer Portal:")
self.stdout.write(callback_url)
# Show frontend login URL

View File

@@ -18,6 +18,4 @@ class Command(BaseCommand):
# Add all sites
for site in sites:
app.sites.add(site)
self.stdout.write(
f"Added sites: {', '.join(site.domain for site in sites)}"
)
self.stdout.write(f"Added sites: {', '.join(site.domain for site in sites)}")

View File

@@ -22,17 +22,13 @@ class Command(BaseCommand):
# Show callback URL
callback_url = "http://localhost:8000/accounts/discord/login/callback/"
self.stdout.write(
"\nCallback URL to configure in Discord Developer Portal:"
)
self.stdout.write("\nCallback URL to configure in Discord Developer Portal:")
self.stdout.write(callback_url)
# Show OAuth2 settings
self.stdout.write("\nOAuth2 settings in settings.py:")
discord_settings = settings.SOCIALACCOUNT_PROVIDERS.get("discord", {})
self.stdout.write(
f"PKCE Enabled: {discord_settings.get('OAUTH_PKCE_ENABLED', False)}"
)
self.stdout.write(f"PKCE Enabled: {discord_settings.get('OAUTH_PKCE_ENABLED', False)}")
self.stdout.write(f"Scopes: {discord_settings.get('SCOPE', [])}")
except SocialApp.DoesNotExist:

View File

@@ -38,9 +38,7 @@ class Migration(migrations.Migration):
),
(
"last_login",
models.DateTimeField(
blank=True, null=True, verbose_name="last login"
),
models.DateTimeField(blank=True, null=True, verbose_name="last login"),
),
(
"is_superuser",
@@ -53,29 +51,21 @@ class Migration(migrations.Migration):
(
"username",
models.CharField(
error_messages={
"unique": "A user with that username already exists."
},
error_messages={"unique": "A user with that username already exists."},
help_text="Required. 150 characters or fewer. Letters, digits and @/./+/-/_ only.",
max_length=150,
unique=True,
validators=[
django.contrib.auth.validators.UnicodeUsernameValidator()
],
validators=[django.contrib.auth.validators.UnicodeUsernameValidator()],
verbose_name="username",
),
),
(
"first_name",
models.CharField(
blank=True, max_length=150, verbose_name="first name"
),
models.CharField(blank=True, max_length=150, verbose_name="first name"),
),
(
"last_name",
models.CharField(
blank=True, max_length=150, verbose_name="last name"
),
models.CharField(blank=True, max_length=150, verbose_name="last name"),
),
(
"email",

View File

@@ -57,9 +57,7 @@ class Migration(migrations.Migration):
("password", models.CharField(max_length=128, verbose_name="password")),
(
"last_login",
models.DateTimeField(
blank=True, null=True, verbose_name="last login"
),
models.DateTimeField(blank=True, null=True, verbose_name="last login"),
),
(
"is_superuser",
@@ -72,34 +70,24 @@ class Migration(migrations.Migration):
(
"username",
models.CharField(
error_messages={
"unique": "A user with that username already exists."
},
error_messages={"unique": "A user with that username already exists."},
help_text="Required. 150 characters or fewer. Letters, digits and @/./+/-/_ only.",
max_length=150,
validators=[
django.contrib.auth.validators.UnicodeUsernameValidator()
],
validators=[django.contrib.auth.validators.UnicodeUsernameValidator()],
verbose_name="username",
),
),
(
"first_name",
models.CharField(
blank=True, max_length=150, verbose_name="first name"
),
models.CharField(blank=True, max_length=150, verbose_name="first name"),
),
(
"last_name",
models.CharField(
blank=True, max_length=150, verbose_name="last name"
),
models.CharField(blank=True, max_length=150, verbose_name="last name"),
),
(
"email",
models.EmailField(
blank=True, max_length=254, verbose_name="email address"
),
models.EmailField(blank=True, max_length=254, verbose_name="email address"),
),
(
"is_staff",
@@ -119,9 +107,7 @@ class Migration(migrations.Migration):
),
(
"date_joined",
models.DateTimeField(
default=django.utils.timezone.now, verbose_name="date joined"
),
models.DateTimeField(default=django.utils.timezone.now, verbose_name="date joined"),
),
(
"user_id",

View File

@@ -41,9 +41,7 @@ class Migration(migrations.Migration):
("created_at", models.DateTimeField(auto_now_add=True)),
(
"expires_at",
models.DateTimeField(
help_text="When this deletion request expires"
),
models.DateTimeField(help_text="When this deletion request expires"),
),
(
"email_sent_at",
@@ -55,9 +53,7 @@ class Migration(migrations.Migration):
),
(
"attempts",
models.PositiveIntegerField(
default=0, help_text="Number of verification attempts made"
),
models.PositiveIntegerField(default=0, help_text="Number of verification attempts made"),
),
(
"max_attempts",
@@ -103,9 +99,7 @@ class Migration(migrations.Migration):
("created_at", models.DateTimeField(auto_now_add=True)),
(
"expires_at",
models.DateTimeField(
help_text="When this deletion request expires"
),
models.DateTimeField(help_text="When this deletion request expires"),
),
(
"email_sent_at",
@@ -117,9 +111,7 @@ class Migration(migrations.Migration):
),
(
"attempts",
models.PositiveIntegerField(
default=0, help_text="Number of verification attempts made"
),
models.PositiveIntegerField(default=0, help_text="Number of verification attempts made"),
),
(
"max_attempts",
@@ -171,21 +163,15 @@ class Migration(migrations.Migration):
),
migrations.AddIndex(
model_name="userdeletionrequest",
index=models.Index(
fields=["verification_code"], name="accounts_us_verific_94460d_idx"
),
index=models.Index(fields=["verification_code"], name="accounts_us_verific_94460d_idx"),
),
migrations.AddIndex(
model_name="userdeletionrequest",
index=models.Index(
fields=["expires_at"], name="accounts_us_expires_1d1dca_idx"
),
index=models.Index(fields=["expires_at"], name="accounts_us_expires_1d1dca_idx"),
),
migrations.AddIndex(
model_name="userdeletionrequest",
index=models.Index(
fields=["user", "is_used"], name="accounts_us_user_id_1ce18a_idx"
),
index=models.Index(fields=["user", "is_used"], name="accounts_us_user_id_1ce18a_idx"),
),
pgtrigger.migrations.AddTrigger(
model_name="userdeletionrequest",

View File

@@ -57,9 +57,7 @@ class Migration(migrations.Migration):
migrations.AddField(
model_name="user",
name="last_password_change",
field=models.DateTimeField(
auto_now_add=True, default=django.utils.timezone.now
),
field=models.DateTimeField(auto_now_add=True, default=django.utils.timezone.now),
preserve_default=False,
),
migrations.AddField(
@@ -185,9 +183,7 @@ class Migration(migrations.Migration):
migrations.AddField(
model_name="userevent",
name="last_password_change",
field=models.DateTimeField(
auto_now_add=True, default=django.utils.timezone.now
),
field=models.DateTimeField(auto_now_add=True, default=django.utils.timezone.now),
preserve_default=False,
),
migrations.AddField(

View File

@@ -454,9 +454,7 @@ class Migration(migrations.Migration):
),
migrations.AddIndex(
model_name="usernotification",
index=models.Index(
fields=["user", "is_read"], name="accounts_us_user_id_785929_idx"
),
index=models.Index(fields=["user", "is_read"], name="accounts_us_user_id_785929_idx"),
),
migrations.AddIndex(
model_name="usernotification",
@@ -467,15 +465,11 @@ class Migration(migrations.Migration):
),
migrations.AddIndex(
model_name="usernotification",
index=models.Index(
fields=["created_at"], name="accounts_us_created_a62f54_idx"
),
index=models.Index(fields=["created_at"], name="accounts_us_created_a62f54_idx"),
),
migrations.AddIndex(
model_name="usernotification",
index=models.Index(
fields=["expires_at"], name="accounts_us_expires_f267b1_idx"
),
index=models.Index(fields=["expires_at"], name="accounts_us_expires_f267b1_idx"),
),
pgtrigger.migrations.AddTrigger(
model_name="usernotification",

View File

@@ -26,25 +26,24 @@ def safe_add_avatar_field(apps, schema_editor):
"""
# Check if the column already exists
with schema_editor.connection.cursor() as cursor:
cursor.execute("""
cursor.execute(
"""
SELECT column_name
FROM information_schema.columns
WHERE table_name='accounts_userprofile'
AND column_name='avatar_id'
""")
"""
)
column_exists = cursor.fetchone() is not None
if not column_exists:
# Column doesn't exist, add it
UserProfile = apps.get_model('accounts', 'UserProfile')
UserProfile = apps.get_model("accounts", "UserProfile")
field = models.ForeignKey(
'django_cloudflareimages_toolkit.CloudflareImage',
on_delete=models.SET_NULL,
null=True,
blank=True
"django_cloudflareimages_toolkit.CloudflareImage", on_delete=models.SET_NULL, null=True, blank=True
)
field.set_attributes_from_name('avatar')
field.set_attributes_from_name("avatar")
schema_editor.add_field(UserProfile, field)
@@ -54,24 +53,23 @@ def reverse_safe_add_avatar_field(apps, schema_editor):
"""
# Check if the column exists and remove it
with schema_editor.connection.cursor() as cursor:
cursor.execute("""
cursor.execute(
"""
SELECT column_name
FROM information_schema.columns
WHERE table_name='accounts_userprofile'
AND column_name='avatar_id'
""")
"""
)
column_exists = cursor.fetchone() is not None
if column_exists:
UserProfile = apps.get_model('accounts', 'UserProfile')
UserProfile = apps.get_model("accounts", "UserProfile")
field = models.ForeignKey(
'django_cloudflareimages_toolkit.CloudflareImage',
on_delete=models.SET_NULL,
null=True,
blank=True
"django_cloudflareimages_toolkit.CloudflareImage", on_delete=models.SET_NULL, null=True, blank=True
)
field.set_attributes_from_name('avatar')
field.set_attributes_from_name("avatar")
schema_editor.remove_field(UserProfile, field)
@@ -89,15 +87,13 @@ class Migration(migrations.Migration):
# First, remove the old avatar column (CloudflareImageField)
migrations.RunSQL(
"ALTER TABLE accounts_userprofile DROP COLUMN IF EXISTS avatar;",
reverse_sql="-- Cannot reverse this operation"
reverse_sql="-- Cannot reverse this operation",
),
# Safely add the new avatar_id column for ForeignKey
migrations.RunPython(
safe_add_avatar_field,
reverse_safe_add_avatar_field,
),
# Run the data migration
migrations.RunPython(
migrate_avatar_data,

View File

@@ -6,17 +6,16 @@ from django.db import migrations
class Migration(migrations.Migration):
dependencies = [
('accounts', '0010_auto_20250830_1657'),
('django_cloudflareimages_toolkit', '0001_initial'),
("accounts", "0010_auto_20250830_1657"),
("django_cloudflareimages_toolkit", "0001_initial"),
]
operations = [
# Remove the old avatar field from the event table
migrations.RunSQL(
"ALTER TABLE accounts_userprofileevent DROP COLUMN IF EXISTS avatar;",
reverse_sql="-- Cannot reverse this operation"
reverse_sql="-- Cannot reverse this operation",
),
# Add the new avatar_id field to match the main table (only if it doesn't exist)
migrations.RunSQL(
"""
@@ -32,6 +31,6 @@ class Migration(migrations.Migration):
END IF;
END $$;
""",
reverse_sql="ALTER TABLE accounts_userprofileevent DROP COLUMN IF EXISTS avatar_id;"
reverse_sql="ALTER TABLE accounts_userprofileevent DROP COLUMN IF EXISTS avatar_id;",
),
]

View File

@@ -13,28 +13,28 @@ from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('accounts', '0012_alter_toplist_category_and_more'),
("accounts", "0012_alter_toplist_category_and_more"),
]
operations = [
# Add db_index to is_banned field
migrations.AlterField(
model_name='user',
name='is_banned',
model_name="user",
name="is_banned",
field=models.BooleanField(default=False, db_index=True),
),
# Add composite index for common query patterns
migrations.AddIndex(
model_name='user',
index=models.Index(fields=['is_banned', 'role'], name='accounts_user_banned_role_idx'),
model_name="user",
index=models.Index(fields=["is_banned", "role"], name="accounts_user_banned_role_idx"),
),
# Add CheckConstraint for ban consistency
migrations.AddConstraint(
model_name='user',
model_name="user",
constraint=models.CheckConstraint(
name='user_ban_consistency',
name="user_ban_consistency",
check=models.Q(is_banned=False) | models.Q(ban_date__isnull=False),
violation_error_message='Banned users must have a ban_date set'
violation_error_message="Banned users must have a ban_date set",
),
),
]

View File

@@ -18,7 +18,6 @@ class Migration(migrations.Migration):
]
operations = [
migrations.AlterModelOptions(
name="user",
options={"verbose_name": "User", "verbose_name_plural": "Users"},
@@ -58,9 +57,7 @@ class Migration(migrations.Migration):
migrations.AddField(
model_name="userprofile",
name="location",
field=models.CharField(
blank=True, help_text="User's location (City, Country)", max_length=100
),
field=models.CharField(blank=True, help_text="User's location (City, Country)", max_length=100),
),
migrations.AddField(
model_name="userprofile",
@@ -78,9 +75,7 @@ class Migration(migrations.Migration):
migrations.AddField(
model_name="userprofileevent",
name="location",
field=models.CharField(
blank=True, help_text="User's location (City, Country)", max_length=100
),
field=models.CharField(blank=True, help_text="User's location (City, Country)", max_length=100),
),
migrations.AddField(
model_name="userprofileevent",
@@ -98,23 +93,17 @@ class Migration(migrations.Migration):
migrations.AlterField(
model_name="emailverification",
name="created_at",
field=models.DateTimeField(
auto_now_add=True, help_text="When this verification was created"
),
field=models.DateTimeField(auto_now_add=True, help_text="When this verification was created"),
),
migrations.AlterField(
model_name="emailverification",
name="last_sent",
field=models.DateTimeField(
auto_now_add=True, help_text="When the verification email was last sent"
),
field=models.DateTimeField(auto_now_add=True, help_text="When the verification email was last sent"),
),
migrations.AlterField(
model_name="emailverification",
name="token",
field=models.CharField(
help_text="Verification token", max_length=64, unique=True
),
field=models.CharField(help_text="Verification token", max_length=64, unique=True),
),
migrations.AlterField(
model_name="emailverification",
@@ -128,16 +117,12 @@ class Migration(migrations.Migration):
migrations.AlterField(
model_name="emailverificationevent",
name="created_at",
field=models.DateTimeField(
auto_now_add=True, help_text="When this verification was created"
),
field=models.DateTimeField(auto_now_add=True, help_text="When this verification was created"),
),
migrations.AlterField(
model_name="emailverificationevent",
name="last_sent",
field=models.DateTimeField(
auto_now_add=True, help_text="When the verification email was last sent"
),
field=models.DateTimeField(auto_now_add=True, help_text="When the verification email was last sent"),
),
migrations.AlterField(
model_name="emailverificationevent",
@@ -181,9 +166,7 @@ class Migration(migrations.Migration):
migrations.AlterField(
model_name="passwordreset",
name="created_at",
field=models.DateTimeField(
auto_now_add=True, help_text="When this reset was requested"
),
field=models.DateTimeField(auto_now_add=True, help_text="When this reset was requested"),
),
migrations.AlterField(
model_name="passwordreset",
@@ -198,9 +181,7 @@ class Migration(migrations.Migration):
migrations.AlterField(
model_name="passwordreset",
name="used",
field=models.BooleanField(
default=False, help_text="Whether this token has been used"
),
field=models.BooleanField(default=False, help_text="Whether this token has been used"),
),
migrations.AlterField(
model_name="passwordreset",
@@ -214,9 +195,7 @@ class Migration(migrations.Migration):
migrations.AlterField(
model_name="passwordresetevent",
name="created_at",
field=models.DateTimeField(
auto_now_add=True, help_text="When this reset was requested"
),
field=models.DateTimeField(auto_now_add=True, help_text="When this reset was requested"),
),
migrations.AlterField(
model_name="passwordresetevent",
@@ -231,9 +210,7 @@ class Migration(migrations.Migration):
migrations.AlterField(
model_name="passwordresetevent",
name="used",
field=models.BooleanField(
default=False, help_text="Whether this token has been used"
),
field=models.BooleanField(default=False, help_text="Whether this token has been used"),
),
migrations.AlterField(
model_name="passwordresetevent",
@@ -267,30 +244,22 @@ class Migration(migrations.Migration):
migrations.AlterField(
model_name="user",
name="allow_friend_requests",
field=models.BooleanField(
default=True, help_text="Whether to allow friend requests"
),
field=models.BooleanField(default=True, help_text="Whether to allow friend requests"),
),
migrations.AlterField(
model_name="user",
name="allow_messages",
field=models.BooleanField(
default=True, help_text="Whether to allow direct messages"
),
field=models.BooleanField(default=True, help_text="Whether to allow direct messages"),
),
migrations.AlterField(
model_name="user",
name="allow_profile_comments",
field=models.BooleanField(
default=False, help_text="Whether to allow profile comments"
),
field=models.BooleanField(default=False, help_text="Whether to allow profile comments"),
),
migrations.AlterField(
model_name="user",
name="ban_date",
field=models.DateTimeField(
blank=True, help_text="Date the user was banned", null=True
),
field=models.DateTimeField(blank=True, help_text="Date the user was banned", null=True),
),
migrations.AlterField(
model_name="user",
@@ -300,37 +269,27 @@ class Migration(migrations.Migration):
migrations.AlterField(
model_name="user",
name="email_notifications",
field=models.BooleanField(
default=True, help_text="Whether to send email notifications"
),
field=models.BooleanField(default=True, help_text="Whether to send email notifications"),
),
migrations.AlterField(
model_name="user",
name="is_banned",
field=models.BooleanField(
db_index=True, default=False, help_text="Whether this user is banned"
),
field=models.BooleanField(db_index=True, default=False, help_text="Whether this user is banned"),
),
migrations.AlterField(
model_name="user",
name="last_password_change",
field=models.DateTimeField(
auto_now_add=True, help_text="When the password was last changed"
),
field=models.DateTimeField(auto_now_add=True, help_text="When the password was last changed"),
),
migrations.AlterField(
model_name="user",
name="login_history_retention",
field=models.IntegerField(
default=90, help_text="How long to retain login history (days)"
),
field=models.IntegerField(default=90, help_text="How long to retain login history (days)"),
),
migrations.AlterField(
model_name="user",
name="login_notifications",
field=models.BooleanField(
default=True, help_text="Whether to send login notifications"
),
field=models.BooleanField(default=True, help_text="Whether to send login notifications"),
),
migrations.AlterField(
model_name="user",
@@ -352,9 +311,7 @@ class Migration(migrations.Migration):
migrations.AlterField(
model_name="user",
name="push_notifications",
field=models.BooleanField(
default=False, help_text="Whether to send push notifications"
),
field=models.BooleanField(default=False, help_text="Whether to send push notifications"),
),
migrations.AlterField(
model_name="user",
@@ -378,9 +335,7 @@ class Migration(migrations.Migration):
migrations.AlterField(
model_name="user",
name="search_visibility",
field=models.BooleanField(
default=True, help_text="Whether profile appears in search results"
),
field=models.BooleanField(default=True, help_text="Whether profile appears in search results"),
),
migrations.AlterField(
model_name="user",
@@ -390,51 +345,37 @@ class Migration(migrations.Migration):
migrations.AlterField(
model_name="user",
name="show_email",
field=models.BooleanField(
default=False, help_text="Whether to show email on profile"
),
field=models.BooleanField(default=False, help_text="Whether to show email on profile"),
),
migrations.AlterField(
model_name="user",
name="show_join_date",
field=models.BooleanField(
default=True, help_text="Whether to show join date on profile"
),
field=models.BooleanField(default=True, help_text="Whether to show join date on profile"),
),
migrations.AlterField(
model_name="user",
name="show_photos",
field=models.BooleanField(
default=True, help_text="Whether to show photos on profile"
),
field=models.BooleanField(default=True, help_text="Whether to show photos on profile"),
),
migrations.AlterField(
model_name="user",
name="show_real_name",
field=models.BooleanField(
default=True, help_text="Whether to show real name on profile"
),
field=models.BooleanField(default=True, help_text="Whether to show real name on profile"),
),
migrations.AlterField(
model_name="user",
name="show_reviews",
field=models.BooleanField(
default=True, help_text="Whether to show reviews on profile"
),
field=models.BooleanField(default=True, help_text="Whether to show reviews on profile"),
),
migrations.AlterField(
model_name="user",
name="show_statistics",
field=models.BooleanField(
default=True, help_text="Whether to show statistics on profile"
),
field=models.BooleanField(default=True, help_text="Whether to show statistics on profile"),
),
migrations.AlterField(
model_name="user",
name="show_top_lists",
field=models.BooleanField(
default=True, help_text="Whether to show top lists on profile"
),
field=models.BooleanField(default=True, help_text="Whether to show top lists on profile"),
),
migrations.AlterField(
model_name="user",
@@ -452,9 +393,7 @@ class Migration(migrations.Migration):
migrations.AlterField(
model_name="user",
name="two_factor_enabled",
field=models.BooleanField(
default=False, help_text="Whether two-factor authentication is enabled"
),
field=models.BooleanField(default=False, help_text="Whether two-factor authentication is enabled"),
),
migrations.AlterField(
model_name="userevent",
@@ -476,30 +415,22 @@ class Migration(migrations.Migration):
migrations.AlterField(
model_name="userevent",
name="allow_friend_requests",
field=models.BooleanField(
default=True, help_text="Whether to allow friend requests"
),
field=models.BooleanField(default=True, help_text="Whether to allow friend requests"),
),
migrations.AlterField(
model_name="userevent",
name="allow_messages",
field=models.BooleanField(
default=True, help_text="Whether to allow direct messages"
),
field=models.BooleanField(default=True, help_text="Whether to allow direct messages"),
),
migrations.AlterField(
model_name="userevent",
name="allow_profile_comments",
field=models.BooleanField(
default=False, help_text="Whether to allow profile comments"
),
field=models.BooleanField(default=False, help_text="Whether to allow profile comments"),
),
migrations.AlterField(
model_name="userevent",
name="ban_date",
field=models.DateTimeField(
blank=True, help_text="Date the user was banned", null=True
),
field=models.DateTimeField(blank=True, help_text="Date the user was banned", null=True),
),
migrations.AlterField(
model_name="userevent",
@@ -509,37 +440,27 @@ class Migration(migrations.Migration):
migrations.AlterField(
model_name="userevent",
name="email_notifications",
field=models.BooleanField(
default=True, help_text="Whether to send email notifications"
),
field=models.BooleanField(default=True, help_text="Whether to send email notifications"),
),
migrations.AlterField(
model_name="userevent",
name="is_banned",
field=models.BooleanField(
default=False, help_text="Whether this user is banned"
),
field=models.BooleanField(default=False, help_text="Whether this user is banned"),
),
migrations.AlterField(
model_name="userevent",
name="last_password_change",
field=models.DateTimeField(
auto_now_add=True, help_text="When the password was last changed"
),
field=models.DateTimeField(auto_now_add=True, help_text="When the password was last changed"),
),
migrations.AlterField(
model_name="userevent",
name="login_history_retention",
field=models.IntegerField(
default=90, help_text="How long to retain login history (days)"
),
field=models.IntegerField(default=90, help_text="How long to retain login history (days)"),
),
migrations.AlterField(
model_name="userevent",
name="login_notifications",
field=models.BooleanField(
default=True, help_text="Whether to send login notifications"
),
field=models.BooleanField(default=True, help_text="Whether to send login notifications"),
),
migrations.AlterField(
model_name="userevent",
@@ -561,9 +482,7 @@ class Migration(migrations.Migration):
migrations.AlterField(
model_name="userevent",
name="push_notifications",
field=models.BooleanField(
default=False, help_text="Whether to send push notifications"
),
field=models.BooleanField(default=False, help_text="Whether to send push notifications"),
),
migrations.AlterField(
model_name="userevent",
@@ -586,9 +505,7 @@ class Migration(migrations.Migration):
migrations.AlterField(
model_name="userevent",
name="search_visibility",
field=models.BooleanField(
default=True, help_text="Whether profile appears in search results"
),
field=models.BooleanField(default=True, help_text="Whether profile appears in search results"),
),
migrations.AlterField(
model_name="userevent",
@@ -598,51 +515,37 @@ class Migration(migrations.Migration):
migrations.AlterField(
model_name="userevent",
name="show_email",
field=models.BooleanField(
default=False, help_text="Whether to show email on profile"
),
field=models.BooleanField(default=False, help_text="Whether to show email on profile"),
),
migrations.AlterField(
model_name="userevent",
name="show_join_date",
field=models.BooleanField(
default=True, help_text="Whether to show join date on profile"
),
field=models.BooleanField(default=True, help_text="Whether to show join date on profile"),
),
migrations.AlterField(
model_name="userevent",
name="show_photos",
field=models.BooleanField(
default=True, help_text="Whether to show photos on profile"
),
field=models.BooleanField(default=True, help_text="Whether to show photos on profile"),
),
migrations.AlterField(
model_name="userevent",
name="show_real_name",
field=models.BooleanField(
default=True, help_text="Whether to show real name on profile"
),
field=models.BooleanField(default=True, help_text="Whether to show real name on profile"),
),
migrations.AlterField(
model_name="userevent",
name="show_reviews",
field=models.BooleanField(
default=True, help_text="Whether to show reviews on profile"
),
field=models.BooleanField(default=True, help_text="Whether to show reviews on profile"),
),
migrations.AlterField(
model_name="userevent",
name="show_statistics",
field=models.BooleanField(
default=True, help_text="Whether to show statistics on profile"
),
field=models.BooleanField(default=True, help_text="Whether to show statistics on profile"),
),
migrations.AlterField(
model_name="userevent",
name="show_top_lists",
field=models.BooleanField(
default=True, help_text="Whether to show top lists on profile"
),
field=models.BooleanField(default=True, help_text="Whether to show top lists on profile"),
),
migrations.AlterField(
model_name="userevent",
@@ -660,9 +563,7 @@ class Migration(migrations.Migration):
migrations.AlterField(
model_name="userevent",
name="two_factor_enabled",
field=models.BooleanField(
default=False, help_text="Whether two-factor authentication is enabled"
),
field=models.BooleanField(default=False, help_text="Whether two-factor authentication is enabled"),
),
migrations.AlterField(
model_name="usernotification",
@@ -678,23 +579,17 @@ class Migration(migrations.Migration):
migrations.AlterField(
model_name="usernotification",
name="email_sent",
field=models.BooleanField(
default=False, help_text="Whether email was sent"
),
field=models.BooleanField(default=False, help_text="Whether email was sent"),
),
migrations.AlterField(
model_name="usernotification",
name="email_sent_at",
field=models.DateTimeField(
blank=True, help_text="When email was sent", null=True
),
field=models.DateTimeField(blank=True, help_text="When email was sent", null=True),
),
migrations.AlterField(
model_name="usernotification",
name="is_read",
field=models.BooleanField(
default=False, help_text="Whether this notification has been read"
),
field=models.BooleanField(default=False, help_text="Whether this notification has been read"),
),
migrations.AlterField(
model_name="usernotification",
@@ -704,30 +599,22 @@ class Migration(migrations.Migration):
migrations.AlterField(
model_name="usernotification",
name="object_id",
field=models.PositiveIntegerField(
blank=True, help_text="ID of related object", null=True
),
field=models.PositiveIntegerField(blank=True, help_text="ID of related object", null=True),
),
migrations.AlterField(
model_name="usernotification",
name="push_sent",
field=models.BooleanField(
default=False, help_text="Whether push notification was sent"
),
field=models.BooleanField(default=False, help_text="Whether push notification was sent"),
),
migrations.AlterField(
model_name="usernotification",
name="push_sent_at",
field=models.DateTimeField(
blank=True, help_text="When push notification was sent", null=True
),
field=models.DateTimeField(blank=True, help_text="When push notification was sent", null=True),
),
migrations.AlterField(
model_name="usernotification",
name="read_at",
field=models.DateTimeField(
blank=True, help_text="When this notification was read", null=True
),
field=models.DateTimeField(blank=True, help_text="When this notification was read", null=True),
),
migrations.AlterField(
model_name="usernotification",
@@ -761,23 +648,17 @@ class Migration(migrations.Migration):
migrations.AlterField(
model_name="usernotificationevent",
name="email_sent",
field=models.BooleanField(
default=False, help_text="Whether email was sent"
),
field=models.BooleanField(default=False, help_text="Whether email was sent"),
),
migrations.AlterField(
model_name="usernotificationevent",
name="email_sent_at",
field=models.DateTimeField(
blank=True, help_text="When email was sent", null=True
),
field=models.DateTimeField(blank=True, help_text="When email was sent", null=True),
),
migrations.AlterField(
model_name="usernotificationevent",
name="is_read",
field=models.BooleanField(
default=False, help_text="Whether this notification has been read"
),
field=models.BooleanField(default=False, help_text="Whether this notification has been read"),
),
migrations.AlterField(
model_name="usernotificationevent",
@@ -787,30 +668,22 @@ class Migration(migrations.Migration):
migrations.AlterField(
model_name="usernotificationevent",
name="object_id",
field=models.PositiveIntegerField(
blank=True, help_text="ID of related object", null=True
),
field=models.PositiveIntegerField(blank=True, help_text="ID of related object", null=True),
),
migrations.AlterField(
model_name="usernotificationevent",
name="push_sent",
field=models.BooleanField(
default=False, help_text="Whether push notification was sent"
),
field=models.BooleanField(default=False, help_text="Whether push notification was sent"),
),
migrations.AlterField(
model_name="usernotificationevent",
name="push_sent_at",
field=models.DateTimeField(
blank=True, help_text="When push notification was sent", null=True
),
field=models.DateTimeField(blank=True, help_text="When push notification was sent", null=True),
),
migrations.AlterField(
model_name="usernotificationevent",
name="read_at",
field=models.DateTimeField(
blank=True, help_text="When this notification was read", null=True
),
field=models.DateTimeField(blank=True, help_text="When this notification was read", null=True),
),
migrations.AlterField(
model_name="usernotificationevent",
@@ -844,37 +717,27 @@ class Migration(migrations.Migration):
migrations.AlterField(
model_name="userprofile",
name="bio",
field=models.TextField(
blank=True, help_text="User biography", max_length=500
),
field=models.TextField(blank=True, help_text="User biography", max_length=500),
),
migrations.AlterField(
model_name="userprofile",
name="coaster_credits",
field=models.IntegerField(
default=0, help_text="Number of roller coasters ridden"
),
field=models.IntegerField(default=0, help_text="Number of roller coasters ridden"),
),
migrations.AlterField(
model_name="userprofile",
name="dark_ride_credits",
field=models.IntegerField(
default=0, help_text="Number of dark rides ridden"
),
field=models.IntegerField(default=0, help_text="Number of dark rides ridden"),
),
migrations.AlterField(
model_name="userprofile",
name="discord",
field=models.CharField(
blank=True, help_text="Discord username", max_length=100
),
field=models.CharField(blank=True, help_text="Discord username", max_length=100),
),
migrations.AlterField(
model_name="userprofile",
name="flat_ride_credits",
field=models.IntegerField(
default=0, help_text="Number of flat rides ridden"
),
field=models.IntegerField(default=0, help_text="Number of flat rides ridden"),
),
migrations.AlterField(
model_name="userprofile",
@@ -884,9 +747,7 @@ class Migration(migrations.Migration):
migrations.AlterField(
model_name="userprofile",
name="pronouns",
field=models.CharField(
blank=True, help_text="User's preferred pronouns", max_length=50
),
field=models.CharField(blank=True, help_text="User's preferred pronouns", max_length=50),
),
migrations.AlterField(
model_name="userprofile",
@@ -906,9 +767,7 @@ class Migration(migrations.Migration):
migrations.AlterField(
model_name="userprofile",
name="water_ride_credits",
field=models.IntegerField(
default=0, help_text="Number of water rides ridden"
),
field=models.IntegerField(default=0, help_text="Number of water rides ridden"),
),
migrations.AlterField(
model_name="userprofile",
@@ -932,37 +791,27 @@ class Migration(migrations.Migration):
migrations.AlterField(
model_name="userprofileevent",
name="bio",
field=models.TextField(
blank=True, help_text="User biography", max_length=500
),
field=models.TextField(blank=True, help_text="User biography", max_length=500),
),
migrations.AlterField(
model_name="userprofileevent",
name="coaster_credits",
field=models.IntegerField(
default=0, help_text="Number of roller coasters ridden"
),
field=models.IntegerField(default=0, help_text="Number of roller coasters ridden"),
),
migrations.AlterField(
model_name="userprofileevent",
name="dark_ride_credits",
field=models.IntegerField(
default=0, help_text="Number of dark rides ridden"
),
field=models.IntegerField(default=0, help_text="Number of dark rides ridden"),
),
migrations.AlterField(
model_name="userprofileevent",
name="discord",
field=models.CharField(
blank=True, help_text="Discord username", max_length=100
),
field=models.CharField(blank=True, help_text="Discord username", max_length=100),
),
migrations.AlterField(
model_name="userprofileevent",
name="flat_ride_credits",
field=models.IntegerField(
default=0, help_text="Number of flat rides ridden"
),
field=models.IntegerField(default=0, help_text="Number of flat rides ridden"),
),
migrations.AlterField(
model_name="userprofileevent",
@@ -972,9 +821,7 @@ class Migration(migrations.Migration):
migrations.AlterField(
model_name="userprofileevent",
name="pronouns",
field=models.CharField(
blank=True, help_text="User's preferred pronouns", max_length=50
),
field=models.CharField(blank=True, help_text="User's preferred pronouns", max_length=50),
),
migrations.AlterField(
model_name="userprofileevent",
@@ -996,9 +843,7 @@ class Migration(migrations.Migration):
migrations.AlterField(
model_name="userprofileevent",
name="water_ride_credits",
field=models.IntegerField(
default=0, help_text="Number of water rides ridden"
),
field=models.IntegerField(default=0, help_text="Number of water rides ridden"),
),
migrations.AlterField(
model_name="userprofileevent",

View File

@@ -1,6 +1,7 @@
"""
Mixins for authentication views.
"""
from django.core.exceptions import ValidationError
from apps.core.utils.turnstile import get_client_ip, validate_turnstile_token
@@ -24,14 +25,14 @@ class TurnstileMixin:
token = None
# Check POST data (form submissions)
if hasattr(request, 'POST'):
if hasattr(request, "POST"):
token = request.POST.get("cf-turnstile-response")
# Check JSON body (API requests)
if not token and hasattr(request, 'data'):
data = getattr(request, 'data', {})
if hasattr(data, 'get'):
token = data.get('turnstile_token') or data.get('cf-turnstile-response')
if not token and hasattr(request, "data"):
data = getattr(request, "data", {})
if hasattr(data, "get"):
token = data.get("turnstile_token") or data.get("cf-turnstile-response")
# Get client IP
ip = get_client_ip(request)
@@ -39,6 +40,6 @@ class TurnstileMixin:
# Validate the token
result = validate_turnstile_token(token, ip)
if not result.get('success'):
error_msg = result.get('error', 'Captcha verification failed. Please try again.')
if not result.get("success"):
error_msg = result.get("error", "Captcha verification failed. Please try again.")
raise ValidationError(error_msg)

View File

@@ -41,10 +41,7 @@ class User(AbstractUser):
max_length=10,
unique=True,
editable=False,
help_text=(
"Unique identifier for this user that remains constant even if the "
"username changes"
),
help_text=("Unique identifier for this user that remains constant even if the " "username changes"),
)
role = RichChoiceField(
@@ -55,13 +52,9 @@ class User(AbstractUser):
db_index=True,
help_text="User role (user, moderator, admin)",
)
is_banned = models.BooleanField(
default=False, db_index=True, help_text="Whether this user is banned"
)
is_banned = models.BooleanField(default=False, db_index=True, help_text="Whether this user is banned")
ban_reason = models.TextField(blank=True, help_text="Reason for ban")
ban_date = models.DateTimeField(
null=True, blank=True, help_text="Date the user was banned"
)
ban_date = models.DateTimeField(null=True, blank=True, help_text="Date the user was banned")
pending_email = models.EmailField(blank=True, null=True)
theme_preference = RichChoiceField(
choice_group="theme_preferences",
@@ -72,12 +65,8 @@ class User(AbstractUser):
)
# Notification preferences
email_notifications = models.BooleanField(
default=True, help_text="Whether to send email notifications"
)
push_notifications = models.BooleanField(
default=False, help_text="Whether to send push notifications"
)
email_notifications = models.BooleanField(default=True, help_text="Whether to send email notifications")
push_notifications = models.BooleanField(default=False, help_text="Whether to send push notifications")
# Privacy settings
privacy_level = RichChoiceField(
@@ -87,39 +76,17 @@ class User(AbstractUser):
default="public",
help_text="Overall privacy level",
)
show_email = models.BooleanField(
default=False, help_text="Whether to show email on profile"
)
show_real_name = models.BooleanField(
default=True, help_text="Whether to show real name on profile"
)
show_join_date = models.BooleanField(
default=True, help_text="Whether to show join date on profile"
)
show_statistics = models.BooleanField(
default=True, help_text="Whether to show statistics on profile"
)
show_reviews = models.BooleanField(
default=True, help_text="Whether to show reviews on profile"
)
show_photos = models.BooleanField(
default=True, help_text="Whether to show photos on profile"
)
show_top_lists = models.BooleanField(
default=True, help_text="Whether to show top lists on profile"
)
allow_friend_requests = models.BooleanField(
default=True, help_text="Whether to allow friend requests"
)
allow_messages = models.BooleanField(
default=True, help_text="Whether to allow direct messages"
)
allow_profile_comments = models.BooleanField(
default=False, help_text="Whether to allow profile comments"
)
search_visibility = models.BooleanField(
default=True, help_text="Whether profile appears in search results"
)
show_email = models.BooleanField(default=False, help_text="Whether to show email on profile")
show_real_name = models.BooleanField(default=True, help_text="Whether to show real name on profile")
show_join_date = models.BooleanField(default=True, help_text="Whether to show join date on profile")
show_statistics = models.BooleanField(default=True, help_text="Whether to show statistics on profile")
show_reviews = models.BooleanField(default=True, help_text="Whether to show reviews on profile")
show_photos = models.BooleanField(default=True, help_text="Whether to show photos on profile")
show_top_lists = models.BooleanField(default=True, help_text="Whether to show top lists on profile")
allow_friend_requests = models.BooleanField(default=True, help_text="Whether to allow friend requests")
allow_messages = models.BooleanField(default=True, help_text="Whether to allow direct messages")
allow_profile_comments = models.BooleanField(default=False, help_text="Whether to allow profile comments")
search_visibility = models.BooleanField(default=True, help_text="Whether profile appears in search results")
activity_visibility = RichChoiceField(
choice_group="privacy_levels",
domain="accounts",
@@ -129,21 +96,11 @@ class User(AbstractUser):
)
# Security settings
two_factor_enabled = models.BooleanField(
default=False, help_text="Whether two-factor authentication is enabled"
)
login_notifications = models.BooleanField(
default=True, help_text="Whether to send login notifications"
)
session_timeout = models.IntegerField(
default=30, help_text="Session timeout in days"
)
login_history_retention = models.IntegerField(
default=90, help_text="How long to retain login history (days)"
)
last_password_change = models.DateTimeField(
auto_now_add=True, help_text="When the password was last changed"
)
two_factor_enabled = models.BooleanField(default=False, help_text="Whether two-factor authentication is enabled")
login_notifications = models.BooleanField(default=True, help_text="Whether to send login notifications")
session_timeout = models.IntegerField(default=30, help_text="Session timeout in days")
login_history_retention = models.IntegerField(default=90, help_text="How long to retain login history (days)")
last_password_change = models.DateTimeField(auto_now_add=True, help_text="When the password was last changed")
# Display name - core user data for better performance
display_name = models.CharField(
@@ -179,13 +136,13 @@ class User(AbstractUser):
verbose_name = "User"
verbose_name_plural = "Users"
indexes = [
models.Index(fields=['is_banned', 'role'], name='accounts_user_banned_role_idx'),
models.Index(fields=["is_banned", "role"], name="accounts_user_banned_role_idx"),
]
constraints = [
models.CheckConstraint(
name='user_ban_consistency',
name="user_ban_consistency",
check=models.Q(is_banned=False) | models.Q(ban_date__isnull=False),
violation_error_message='Banned users must have a ban_date set'
violation_error_message="Banned users must have a ban_date set",
),
]
@@ -224,14 +181,10 @@ class UserProfile(models.Model):
related_name="user_profiles",
help_text="User's avatar image",
)
pronouns = models.CharField(
max_length=50, blank=True, help_text="User's preferred pronouns"
)
pronouns = models.CharField(max_length=50, blank=True, help_text="User's preferred pronouns")
bio = models.TextField(max_length=500, blank=True, help_text="User biography")
location = models.CharField(
max_length=100, blank=True, help_text="User's location (City, Country)"
)
location = models.CharField(max_length=100, blank=True, help_text="User's location (City, Country)")
unit_system = RichChoiceField(
choice_group="unit_systems",
domain="accounts",
@@ -247,18 +200,10 @@ class UserProfile(models.Model):
discord = models.CharField(max_length=100, blank=True, help_text="Discord username")
# Ride statistics
coaster_credits = models.IntegerField(
default=0, help_text="Number of roller coasters ridden"
)
dark_ride_credits = models.IntegerField(
default=0, help_text="Number of dark rides ridden"
)
flat_ride_credits = models.IntegerField(
default=0, help_text="Number of flat rides ridden"
)
water_ride_credits = models.IntegerField(
default=0, help_text="Number of water rides ridden"
)
coaster_credits = models.IntegerField(default=0, help_text="Number of roller coasters ridden")
dark_ride_credits = models.IntegerField(default=0, help_text="Number of dark rides ridden")
flat_ride_credits = models.IntegerField(default=0, help_text="Number of flat rides ridden")
water_ride_credits = models.IntegerField(default=0, help_text="Number of water rides ridden")
def get_avatar_url(self):
"""
@@ -266,12 +211,12 @@ class UserProfile(models.Model):
"""
if self.avatar and self.avatar.is_uploaded:
# Try to get avatar variant first, fallback to public
avatar_url = self.avatar.get_url('avatar')
avatar_url = self.avatar.get_url("avatar")
if avatar_url:
return avatar_url
# Fallback to public variant
public_url = self.avatar.get_url('public')
public_url = self.avatar.get_url("public")
if public_url:
return public_url
@@ -298,10 +243,10 @@ class UserProfile(models.Model):
variants = {}
# Try to get specific variants
thumbnail_url = self.avatar.get_url('thumbnail')
avatar_url = self.avatar.get_url('avatar')
large_url = self.avatar.get_url('large')
public_url = self.avatar.get_url('public')
thumbnail_url = self.avatar.get_url("thumbnail")
avatar_url = self.avatar.get_url("avatar")
large_url = self.avatar.get_url("large")
public_url = self.avatar.get_url("public")
# Use specific variants if available, otherwise fallback to public or first available
fallback_url = public_url
@@ -354,18 +299,10 @@ class EmailVerification(models.Model):
on_delete=models.CASCADE,
help_text="User this verification belongs to",
)
token = models.CharField(
max_length=64, unique=True, help_text="Verification token"
)
created_at = models.DateTimeField(
auto_now_add=True, help_text="When this verification was created"
)
updated_at = models.DateTimeField(
auto_now=True, help_text="When this verification was last updated"
)
last_sent = models.DateTimeField(
auto_now_add=True, help_text="When the verification email was last sent"
)
token = models.CharField(max_length=64, unique=True, help_text="Verification token")
created_at = models.DateTimeField(auto_now_add=True, help_text="When this verification was created")
updated_at = models.DateTimeField(auto_now=True, help_text="When this verification was last updated")
last_sent = models.DateTimeField(auto_now_add=True, help_text="When the verification email was last sent")
def __str__(self):
return f"Email verification for {self.user.username}"
@@ -383,9 +320,7 @@ class PasswordReset(models.Model):
help_text="User requesting password reset",
)
token = models.CharField(max_length=64, help_text="Reset token")
created_at = models.DateTimeField(
auto_now_add=True, help_text="When this reset was requested"
)
created_at = models.DateTimeField(auto_now_add=True, help_text="When this reset was requested")
expires_at = models.DateTimeField(help_text="When this reset token expires")
used = models.BooleanField(default=False, help_text="Whether this token has been used")
@@ -397,8 +332,6 @@ class PasswordReset(models.Model):
verbose_name_plural = "Password Resets"
@pghistory.track()
class UserDeletionRequest(models.Model):
"""
@@ -409,9 +342,7 @@ class UserDeletionRequest(models.Model):
provide the correct code.
"""
user = models.OneToOneField(
User, on_delete=models.CASCADE, related_name="deletion_request"
)
user = models.OneToOneField(User, on_delete=models.CASCADE, related_name="deletion_request")
verification_code = models.CharField(
max_length=32,
@@ -422,21 +353,13 @@ class UserDeletionRequest(models.Model):
created_at = models.DateTimeField(auto_now_add=True)
expires_at = models.DateTimeField(help_text="When this deletion request expires")
email_sent_at = models.DateTimeField(
null=True, blank=True, help_text="When the verification email was sent"
)
email_sent_at = models.DateTimeField(null=True, blank=True, help_text="When the verification email was sent")
attempts = models.PositiveIntegerField(
default=0, help_text="Number of verification attempts made"
)
attempts = models.PositiveIntegerField(default=0, help_text="Number of verification attempts made")
max_attempts = models.PositiveIntegerField(
default=5, help_text="Maximum number of verification attempts allowed"
)
max_attempts = models.PositiveIntegerField(default=5, help_text="Maximum number of verification attempts allowed")
is_used = models.BooleanField(
default=False, help_text="Whether this deletion request has been used"
)
is_used = models.BooleanField(default=False, help_text="Whether this deletion request has been used")
class Meta:
verbose_name = "User Deletion Request"
@@ -466,9 +389,7 @@ class UserDeletionRequest(models.Model):
"""Generate a unique 8-character verification code."""
while True:
# Generate a random 8-character alphanumeric code
code = "".join(
secrets.choice("ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789") for _ in range(8)
)
code = "".join(secrets.choice("ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789") for _ in range(8))
# Ensure it's unique
if not UserDeletionRequest.objects.filter(verification_code=code).exists():
@@ -480,11 +401,7 @@ class UserDeletionRequest(models.Model):
def is_valid(self):
"""Check if this deletion request is still valid."""
return (
not self.is_used
and not self.is_expired()
and self.attempts < self.max_attempts
)
return not self.is_used and not self.is_expired() and self.attempts < self.max_attempts
def increment_attempts(self):
"""Increment the number of verification attempts."""
@@ -499,9 +416,7 @@ class UserDeletionRequest(models.Model):
@classmethod
def cleanup_expired(cls):
"""Remove expired deletion requests."""
expired_requests = cls.objects.filter(
expires_at__lt=timezone.now(), is_used=False
)
expired_requests = cls.objects.filter(expires_at__lt=timezone.now(), is_used=False)
count = expired_requests.count()
expired_requests.delete()
return count
@@ -541,9 +456,7 @@ class UserNotification(TrackedModel):
blank=True,
help_text="Type of related object",
)
object_id = models.PositiveIntegerField(
null=True, blank=True, help_text="ID of related object"
)
object_id = models.PositiveIntegerField(null=True, blank=True, help_text="ID of related object")
related_object = GenericForeignKey("content_type", "object_id")
# Metadata
@@ -555,24 +468,14 @@ class UserNotification(TrackedModel):
)
# Status tracking
is_read = models.BooleanField(
default=False, help_text="Whether this notification has been read"
)
read_at = models.DateTimeField(
null=True, blank=True, help_text="When this notification was read"
)
is_read = models.BooleanField(default=False, help_text="Whether this notification has been read")
read_at = models.DateTimeField(null=True, blank=True, help_text="When this notification was read")
# Delivery tracking
email_sent = models.BooleanField(default=False, help_text="Whether email was sent")
email_sent_at = models.DateTimeField(
null=True, blank=True, help_text="When email was sent"
)
push_sent = models.BooleanField(
default=False, help_text="Whether push notification was sent"
)
push_sent_at = models.DateTimeField(
null=True, blank=True, help_text="When push notification was sent"
)
email_sent_at = models.DateTimeField(null=True, blank=True, help_text="When email was sent")
push_sent = models.BooleanField(default=False, help_text="Whether push notification was sent")
push_sent_at = models.DateTimeField(null=True, blank=True, help_text="When push notification was sent")
# Additional data (JSON field for flexibility)
extra_data = models.JSONField(default=dict, blank=True)
@@ -619,9 +522,7 @@ class UserNotification(TrackedModel):
@classmethod
def mark_all_read_for_user(cls, user):
"""Mark all notifications as read for a specific user."""
return cls.objects.filter(user=user, is_read=False).update(
is_read=True, read_at=timezone.now()
)
return cls.objects.filter(user=user, is_read=False).update(is_read=True, read_at=timezone.now())
@pghistory.track()

View File

@@ -27,16 +27,10 @@ def user_profile_optimized(*, user_id: int) -> Any:
User.DoesNotExist: If user doesn't exist
"""
return (
User.objects.prefetch_related(
"park_reviews", "ride_reviews", "socialaccount_set"
)
User.objects.prefetch_related("park_reviews", "ride_reviews", "socialaccount_set")
.annotate(
park_review_count=Count(
"park_reviews", filter=Q(park_reviews__is_published=True)
),
ride_review_count=Count(
"ride_reviews", filter=Q(ride_reviews__is_published=True)
),
park_review_count=Count("park_reviews", filter=Q(park_reviews__is_published=True)),
ride_review_count=Count("ride_reviews", filter=Q(ride_reviews__is_published=True)),
total_review_count=F("park_review_count") + F("ride_review_count"),
)
.get(id=user_id)
@@ -53,12 +47,8 @@ def active_users_with_stats() -> QuerySet:
return (
User.objects.filter(is_active=True)
.annotate(
park_review_count=Count(
"park_reviews", filter=Q(park_reviews__is_published=True)
),
ride_review_count=Count(
"ride_reviews", filter=Q(ride_reviews__is_published=True)
),
park_review_count=Count("park_reviews", filter=Q(park_reviews__is_published=True)),
ride_review_count=Count("ride_reviews", filter=Q(ride_reviews__is_published=True)),
total_review_count=F("park_review_count") + F("ride_review_count"),
)
.order_by("-total_review_count")
@@ -112,12 +102,8 @@ def top_reviewers(*, limit: int = 10) -> QuerySet:
return (
User.objects.filter(is_active=True)
.annotate(
park_review_count=Count(
"park_reviews", filter=Q(park_reviews__is_published=True)
),
ride_review_count=Count(
"ride_reviews", filter=Q(ride_reviews__is_published=True)
),
park_review_count=Count("park_reviews", filter=Q(park_reviews__is_published=True)),
ride_review_count=Count("ride_reviews", filter=Q(ride_reviews__is_published=True)),
total_review_count=F("park_review_count") + F("ride_review_count"),
)
.filter(total_review_count__gt=0)
@@ -159,9 +145,9 @@ def users_by_registration_date(*, start_date, end_date) -> QuerySet:
Returns:
QuerySet of users registered in the date range
"""
return User.objects.filter(
date_joined__date__gte=start_date, date_joined__date__lte=end_date
).order_by("-date_joined")
return User.objects.filter(date_joined__date__gte=start_date, date_joined__date__lte=end_date).order_by(
"-date_joined"
)
def user_search_autocomplete(*, query: str, limit: int = 10) -> QuerySet:
@@ -176,8 +162,7 @@ def user_search_autocomplete(*, query: str, limit: int = 10) -> QuerySet:
QuerySet of matching users for autocomplete
"""
return User.objects.filter(
Q(username__icontains=query)
| Q(display_name__icontains=query),
Q(username__icontains=query) | Q(display_name__icontains=query),
is_active=True,
).order_by("username")[:limit]
@@ -210,11 +195,7 @@ def user_statistics_summary() -> dict[str, Any]:
# Users with reviews
users_with_reviews = (
User.objects.filter(
Q(park_reviews__isnull=False) | Q(ride_reviews__isnull=False)
)
.distinct()
.count()
User.objects.filter(Q(park_reviews__isnull=False) | Q(ride_reviews__isnull=False)).distinct().count()
)
# Recent registrations (last 30 days)
@@ -228,9 +209,7 @@ def user_statistics_summary() -> dict[str, Any]:
"staff_users": staff_users,
"users_with_reviews": users_with_reviews,
"recent_registrations": recent_registrations,
"review_participation_rate": (
(users_with_reviews / total_users * 100) if total_users > 0 else 0
),
"review_participation_rate": ((users_with_reviews / total_users * 100) if total_users > 0 else 0),
}
@@ -241,11 +220,7 @@ def users_needing_email_verification() -> QuerySet:
Returns:
QuerySet of users with unverified emails
"""
return (
User.objects.filter(is_active=True, emailaddress__verified=False)
.distinct()
.order_by("date_joined")
)
return User.objects.filter(is_active=True, emailaddress__verified=False).distinct().order_by("date_joined")
def users_by_review_activity(*, min_reviews: int = 1) -> QuerySet:
@@ -260,12 +235,8 @@ def users_by_review_activity(*, min_reviews: int = 1) -> QuerySet:
"""
return (
User.objects.annotate(
park_review_count=Count(
"park_reviews", filter=Q(park_reviews__is_published=True)
),
ride_review_count=Count(
"ride_reviews", filter=Q(ride_reviews__is_published=True)
),
park_review_count=Count("park_reviews", filter=Q(park_reviews__is_published=True)),
ride_review_count=Count("ride_reviews", filter=Q(ride_reviews__is_published=True)),
total_review_count=F("park_review_count") + F("ride_review_count"),
)
.filter(total_review_count__gte=min_reviews)

View File

@@ -62,12 +62,8 @@ class LoginSerializer(serializers.Serializer):
Serializer for user login
"""
username = serializers.CharField(
max_length=254, help_text="Username or email address"
)
password = serializers.CharField(
max_length=128, style={"input_type": "password"}, trim_whitespace=False
)
username = serializers.CharField(max_length=254, help_text="Username or email address")
password = serializers.CharField(max_length=128, style={"input_type": "password"}, trim_whitespace=False)
def validate(self, attrs):
username = attrs.get("username")
@@ -89,9 +85,7 @@ class SignupSerializer(serializers.ModelSerializer):
validators=[validate_password],
style={"input_type": "password"},
)
password_confirm = serializers.CharField(
write_only=True, style={"input_type": "password"}
)
password_confirm = serializers.CharField(write_only=True, style={"input_type": "password"})
class Meta:
model = User
@@ -118,9 +112,7 @@ class SignupSerializer(serializers.ModelSerializer):
def validate_username(self, value):
"""Validate username is unique"""
if UserModel.objects.filter(username=value).exists():
raise serializers.ValidationError(
"A user with this username already exists."
)
raise serializers.ValidationError("A user with this username already exists.")
return value
def validate(self, attrs):
@@ -129,9 +121,7 @@ class SignupSerializer(serializers.ModelSerializer):
password_confirm = attrs.get("password_confirm")
if password != password_confirm:
raise serializers.ValidationError(
{"password_confirm": "Passwords do not match."}
)
raise serializers.ValidationError({"password_confirm": "Passwords do not match."})
return attrs
@@ -194,9 +184,7 @@ class PasswordResetSerializer(serializers.Serializer):
"site_name": site.name,
}
email_html = render_to_string(
"accounts/email/password_reset.html", context
)
email_html = render_to_string("accounts/email/password_reset.html", context)
# Narrow and validate email type for the static checker
email = getattr(self.user, "email", None)
@@ -218,15 +206,11 @@ class PasswordChangeSerializer(serializers.Serializer):
Serializer for password change
"""
old_password = serializers.CharField(
max_length=128, style={"input_type": "password"}
)
old_password = serializers.CharField(max_length=128, style={"input_type": "password"})
new_password = serializers.CharField(
max_length=128, validators=[validate_password], style={"input_type": "password"}
)
new_password_confirm = serializers.CharField(
max_length=128, style={"input_type": "password"}
)
new_password_confirm = serializers.CharField(max_length=128, style={"input_type": "password"})
def validate_old_password(self, value):
"""Validate old password is correct"""
@@ -241,9 +225,7 @@ class PasswordChangeSerializer(serializers.Serializer):
new_password_confirm = attrs.get("new_password_confirm")
if new_password != new_password_confirm:
raise serializers.ValidationError(
{"new_password_confirm": "New passwords do not match."}
)
raise serializers.ValidationError({"new_password_confirm": "New passwords do not match."})
return attrs

View File

@@ -81,21 +81,15 @@ class AccountService:
"""
# Verify old password
if not user.check_password(old_password):
logger.warning(
f"Password change failed: incorrect current password for user {user.id}"
)
return {
'success': False,
'message': "Current password is incorrect",
'redirect_url': None
}
logger.warning(f"Password change failed: incorrect current password for user {user.id}")
return {"success": False, "message": "Current password is incorrect", "redirect_url": None}
# Validate new password
if not AccountService.validate_password(new_password):
return {
'success': False,
'message': "Password must be at least 8 characters and contain uppercase, lowercase, and numbers",
'redirect_url': None
"success": False,
"message": "Password must be at least 8 characters and contain uppercase, lowercase, and numbers",
"redirect_url": None,
}
# Update password
@@ -111,9 +105,9 @@ class AccountService:
logger.info(f"Password changed successfully for user {user.id}")
return {
'success': True,
'message': "Password changed successfully. Please check your email for confirmation.",
'redirect_url': None
"success": True,
"message": "Password changed successfully. Please check your email for confirmation.",
"redirect_url": None,
}
@staticmethod
@@ -125,9 +119,7 @@ class AccountService:
"site_name": site.name,
}
email_html = render_to_string(
"accounts/email/password_change_confirmation.html", context
)
email_html = render_to_string("accounts/email/password_change_confirmation.html", context)
try:
EmailService.send_email(
@@ -166,26 +158,17 @@ class AccountService:
}
"""
if not new_email:
return {
'success': False,
'message': "New email is required"
}
return {"success": False, "message": "New email is required"}
# Check if email is already in use
if User.objects.filter(email=new_email).exclude(id=user.id).exists():
return {
'success': False,
'message': "This email address is already in use"
}
return {"success": False, "message": "This email address is already in use"}
# Generate verification token
token = get_random_string(64)
# Create or update email verification record
EmailVerification.objects.update_or_create(
user=user,
defaults={"token": token}
)
EmailVerification.objects.update_or_create(user=user, defaults={"token": token})
# Store pending email
user.pending_email = new_email
@@ -196,18 +179,10 @@ class AccountService:
logger.info(f"Email change initiated for user {user.id} to {new_email}")
return {
'success': True,
'message': "Verification email sent to your new email address"
}
return {"success": True, "message": "Verification email sent to your new email address"}
@staticmethod
def _send_email_verification(
request: HttpRequest,
user: User,
new_email: str,
token: str
) -> None:
def _send_email_verification(request: HttpRequest, user: User, new_email: str, token: str) -> None:
"""Send email verification for email change."""
from django.urls import reverse
@@ -245,22 +220,14 @@ class AccountService:
Dictionary with success status and message
"""
try:
verification = EmailVerification.objects.select_related("user").get(
token=token
)
verification = EmailVerification.objects.select_related("user").get(token=token)
except EmailVerification.DoesNotExist:
return {
'success': False,
'message': "Invalid or expired verification token"
}
return {"success": False, "message": "Invalid or expired verification token"}
user = verification.user
if not user.pending_email:
return {
'success': False,
'message': "No pending email change found"
}
return {"success": False, "message": "No pending email change found"}
# Update email
old_email = user.email
@@ -273,10 +240,7 @@ class AccountService:
logger.info(f"Email changed for user {user.id} from {old_email} to {user.email}")
return {
'success': True,
'message': "Email address updated successfully"
}
return {"success": True, "message": "Email address updated successfully"}
class UserDeletionService:
@@ -337,39 +301,17 @@ class UserDeletionService:
# Count submissions before transfer
submission_counts = {
"park_reviews": getattr(
user, "park_reviews", user.__class__.objects.none()
).count(),
"ride_reviews": getattr(
user, "ride_reviews", user.__class__.objects.none()
).count(),
"uploaded_park_photos": getattr(
user, "uploaded_park_photos", user.__class__.objects.none()
).count(),
"uploaded_ride_photos": getattr(
user, "uploaded_ride_photos", user.__class__.objects.none()
).count(),
"top_lists": getattr(
user, "top_lists", user.__class__.objects.none()
).count(),
"edit_submissions": getattr(
user, "edit_submissions", user.__class__.objects.none()
).count(),
"photo_submissions": getattr(
user, "photo_submissions", user.__class__.objects.none()
).count(),
"moderated_park_reviews": getattr(
user, "moderated_park_reviews", user.__class__.objects.none()
).count(),
"moderated_ride_reviews": getattr(
user, "moderated_ride_reviews", user.__class__.objects.none()
).count(),
"handled_submissions": getattr(
user, "handled_submissions", user.__class__.objects.none()
).count(),
"handled_photos": getattr(
user, "handled_photos", user.__class__.objects.none()
).count(),
"park_reviews": getattr(user, "park_reviews", user.__class__.objects.none()).count(),
"ride_reviews": getattr(user, "ride_reviews", user.__class__.objects.none()).count(),
"uploaded_park_photos": getattr(user, "uploaded_park_photos", user.__class__.objects.none()).count(),
"uploaded_ride_photos": getattr(user, "uploaded_ride_photos", user.__class__.objects.none()).count(),
"top_lists": getattr(user, "top_lists", user.__class__.objects.none()).count(),
"edit_submissions": getattr(user, "edit_submissions", user.__class__.objects.none()).count(),
"photo_submissions": getattr(user, "photo_submissions", user.__class__.objects.none()).count(),
"moderated_park_reviews": getattr(user, "moderated_park_reviews", user.__class__.objects.none()).count(),
"moderated_ride_reviews": getattr(user, "moderated_ride_reviews", user.__class__.objects.none()).count(),
"handled_submissions": getattr(user, "handled_submissions", user.__class__.objects.none()).count(),
"handled_photos": getattr(user, "handled_photos", user.__class__.objects.none()).count(),
}
# Transfer all submissions to deleted user
@@ -440,11 +382,17 @@ class UserDeletionService:
return False, "Cannot delete the system deleted user placeholder"
if user.is_superuser:
return False, "Superuser accounts cannot be deleted for security reasons. Please contact system administrator or remove superuser privileges first."
return (
False,
"Superuser accounts cannot be deleted for security reasons. Please contact system administrator or remove superuser privileges first.",
)
# Check if user has critical admin role
if user.role == User.Roles.ADMIN and user.is_staff:
return False, "Admin accounts with staff privileges cannot be deleted. Please remove admin privileges first or contact system administrator."
return (
False,
"Admin accounts with staff privileges cannot be deleted. Please remove admin privileges first or contact system administrator.",
)
# Add any other business rules here
@@ -492,9 +440,7 @@ class UserDeletionService:
site = Site.objects.get_current()
except Site.DoesNotExist:
# Fallback to default site
site = Site.objects.get_or_create(
id=1, defaults={"domain": "localhost:8000", "name": "localhost:8000"}
)[0]
site = Site.objects.get_or_create(id=1, defaults={"domain": "localhost:8000", "name": "localhost:8000"})[0]
# Prepare email context
context = {
@@ -502,9 +448,7 @@ class UserDeletionService:
"verification_code": deletion_request.verification_code,
"expires_at": deletion_request.expires_at,
"site_name": getattr(settings, "SITE_NAME", "ThrillWiki"),
"frontend_domain": getattr(
settings, "FRONTEND_DOMAIN", "http://localhost:3000"
),
"frontend_domain": getattr(settings, "FRONTEND_DOMAIN", "http://localhost:3000"),
}
# Render email content
@@ -564,11 +508,9 @@ The ThrillWiki Team
ValueError: If verification fails
"""
try:
deletion_request = UserDeletionRequest.objects.get(
verification_code=verification_code
)
deletion_request = UserDeletionRequest.objects.get(verification_code=verification_code)
except UserDeletionRequest.DoesNotExist:
raise ValueError("Invalid verification code")
raise ValueError("Invalid verification code") from None
# Check if request is still valid
if not deletion_request.is_valid():

View File

@@ -8,4 +8,4 @@ including social provider management, user authentication, and profile services.
from .social_provider_service import SocialProviderService
from .user_deletion_service import UserDeletionService
__all__ = ['SocialProviderService', 'UserDeletionService']
__all__ = ["SocialProviderService", "UserDeletionService"]

View File

@@ -139,7 +139,9 @@ class NotificationService:
UserNotification: The created notification
"""
title = f"Your {submission_type} needs attention"
message = f"Your {submission_type} submission has been reviewed and needs some changes before it can be approved."
message = (
f"Your {submission_type} submission has been reviewed and needs some changes before it can be approved."
)
message += f"\n\nReason: {rejection_reason}"
if additional_message:
@@ -216,9 +218,7 @@ class NotificationService:
preferences = NotificationPreference.objects.create(user=user)
# Send email notification if enabled
if preferences.should_send_notification(
notification.notification_type, "email"
):
if preferences.should_send_notification(notification.notification_type, "email"):
NotificationService._send_email_notification(notification)
# Toast notifications are always created (the notification object itself)
@@ -261,14 +261,10 @@ class NotificationService:
notification.email_sent_at = timezone.now()
notification.save(update_fields=["email_sent", "email_sent_at"])
logger.info(
f"Email notification sent to {user.email} for notification {notification.id}"
)
logger.info(f"Email notification sent to {user.email} for notification {notification.id}")
except Exception as e:
logger.error(
f"Failed to send email notification {notification.id}: {str(e)}"
)
logger.error(f"Failed to send email notification {notification.id}: {str(e)}")
@staticmethod
def get_user_notifications(
@@ -298,9 +294,7 @@ class NotificationService:
queryset = queryset.filter(notification_type__in=notification_types)
# Exclude expired notifications
queryset = queryset.filter(
models.Q(expires_at__isnull=True) | models.Q(expires_at__gt=timezone.now())
)
queryset = queryset.filter(models.Q(expires_at__isnull=True) | models.Q(expires_at__gt=timezone.now()))
if limit:
queryset = queryset[:limit]
@@ -308,9 +302,7 @@ class NotificationService:
return list(queryset)
@staticmethod
def mark_notifications_read(
user: User, notification_ids: list[int] | None = None
) -> int:
def mark_notifications_read(user: User, notification_ids: list[int] | None = None) -> int:
"""
Mark notifications as read for a user.
@@ -341,9 +333,7 @@ class NotificationService:
"""
cutoff_date = timezone.now() - timedelta(days=days)
old_notifications = UserNotification.objects.filter(
is_read=True, read_at__lt=cutoff_date
)
old_notifications = UserNotification.objects.filter(is_read=True, read_at__lt=cutoff_date)
count = old_notifications.count()
old_notifications.delete()

View File

@@ -40,23 +40,20 @@ class SocialProviderService:
"""
try:
# Count remaining social accounts after disconnection
remaining_social_accounts = user.socialaccount_set.exclude(
provider=provider
).count()
remaining_social_accounts = user.socialaccount_set.exclude(provider=provider).count()
# Check if user has email/password auth
has_password_auth = (
user.email and
user.has_usable_password() and
bool(user.password) # Not empty/unusable
)
has_password_auth = user.email and user.has_usable_password() and bool(user.password) # Not empty/unusable
# Allow disconnection only if alternative auth exists
can_disconnect = remaining_social_accounts > 0 or has_password_auth
if not can_disconnect:
if remaining_social_accounts == 0 and not has_password_auth:
return False, "Cannot disconnect your only authentication method. Please set up a password or connect another social provider first."
return (
False,
"Cannot disconnect your only authentication method. Please set up a password or connect another social provider first.",
)
elif not has_password_auth:
return False, "Please set up email/password authentication before disconnecting this provider."
else:
@@ -65,8 +62,7 @@ class SocialProviderService:
return True, "Provider can be safely disconnected."
except Exception as e:
logger.error(
f"Error checking disconnect permission for user {user.id}, provider {provider}: {e}")
logger.error(f"Error checking disconnect permission for user {user.id}, provider {provider}: {e}")
return False, "Unable to verify disconnection safety. Please try again."
@staticmethod
@@ -84,18 +80,16 @@ class SocialProviderService:
connected_providers = []
for social_account in user.socialaccount_set.all():
can_disconnect, reason = SocialProviderService.can_disconnect_provider(
user, social_account.provider
)
can_disconnect, reason = SocialProviderService.can_disconnect_provider(user, social_account.provider)
provider_info = {
'provider': social_account.provider,
'provider_name': social_account.get_provider().name,
'uid': social_account.uid,
'date_joined': social_account.date_joined,
'can_disconnect': can_disconnect,
'disconnect_reason': reason if not can_disconnect else None,
'extra_data': social_account.extra_data
"provider": social_account.provider,
"provider_name": social_account.get_provider().name,
"uid": social_account.uid,
"date_joined": social_account.date_joined,
"can_disconnect": can_disconnect,
"disconnect_reason": reason if not can_disconnect else None,
"extra_data": social_account.extra_data,
}
connected_providers.append(provider_info)
@@ -122,28 +116,25 @@ class SocialProviderService:
available_providers = []
# Get all social apps configured for this site
social_apps = SocialApp.objects.filter(sites=site).order_by('provider')
social_apps = SocialApp.objects.filter(sites=site).order_by("provider")
for social_app in social_apps:
try:
provider = registry.by_id(social_app.provider)
provider_info = {
'id': social_app.provider,
'name': provider.name,
'auth_url': request.build_absolute_uri(
f'/accounts/{social_app.provider}/login/'
"id": social_app.provider,
"name": provider.name,
"auth_url": request.build_absolute_uri(f"/accounts/{social_app.provider}/login/"),
"connect_url": request.build_absolute_uri(
f"/api/v1/auth/social/connect/{social_app.provider}/"
),
'connect_url': request.build_absolute_uri(
f'/api/v1/auth/social/connect/{social_app.provider}/'
)
}
available_providers.append(provider_info)
except Exception as e:
logger.warning(
f"Error processing provider {social_app.provider}: {e}")
logger.warning(f"Error processing provider {social_app.provider}: {e}")
continue
return available_providers
@@ -166,8 +157,7 @@ class SocialProviderService:
"""
try:
# First check if disconnection is allowed
can_disconnect, reason = SocialProviderService.can_disconnect_provider(
user, provider)
can_disconnect, reason = SocialProviderService.can_disconnect_provider(user, provider)
if not can_disconnect:
return False, reason
@@ -182,8 +172,7 @@ class SocialProviderService:
deleted_count = social_accounts.count()
social_accounts.delete()
logger.info(
f"User {user.id} disconnected {deleted_count} {provider} account(s)")
logger.info(f"User {user.id} disconnected {deleted_count} {provider} account(s)")
return True, f"{provider.title()} account disconnected successfully."
@@ -205,31 +194,24 @@ class SocialProviderService:
try:
connected_providers = SocialProviderService.get_connected_providers(user)
has_password_auth = (
user.email and
user.has_usable_password() and
bool(user.password)
)
has_password_auth = user.email and user.has_usable_password() and bool(user.password)
auth_methods_count = len(connected_providers) + \
(1 if has_password_auth else 0)
auth_methods_count = len(connected_providers) + (1 if has_password_auth else 0)
return {
'user_id': user.id,
'username': user.username,
'email': user.email,
'has_password_auth': has_password_auth,
'connected_providers': connected_providers,
'total_auth_methods': auth_methods_count,
'can_disconnect_any': auth_methods_count > 1,
'requires_password_setup': not has_password_auth and len(connected_providers) == 1
"user_id": user.id,
"username": user.username,
"email": user.email,
"has_password_auth": has_password_auth,
"connected_providers": connected_providers,
"total_auth_methods": auth_methods_count,
"can_disconnect_any": auth_methods_count > 1,
"requires_password_setup": not has_password_auth and len(connected_providers) == 1,
}
except Exception as e:
logger.error(f"Error getting auth status for user {user.id}: {e}")
return {
'error': 'Unable to retrieve authentication status'
}
return {"error": "Unable to retrieve authentication status"}
@staticmethod
def validate_provider_exists(provider: str) -> tuple[bool, str]:

View File

@@ -59,7 +59,7 @@ class UserDeletionService:
return False, "Cannot delete staff accounts"
# Check for system users (if you have any special system accounts)
if hasattr(user, 'role') and user.role in ['ADMIN', 'MODERATOR']:
if hasattr(user, "role") and user.role in ["ADMIN", "MODERATOR"]:
return False, "Cannot delete admin or moderator accounts"
return True, None
@@ -84,8 +84,7 @@ class UserDeletionService:
raise ValueError(reason)
# Generate verification code
verification_code = ''.join(secrets.choice(
string.ascii_uppercase + string.digits) for _ in range(8))
verification_code = "".join(secrets.choice(string.ascii_uppercase + string.digits) for _ in range(8))
# Set expiration (24 hours from now)
expires_at = timezone.now() + timezone.timedelta(hours=24)
@@ -97,8 +96,7 @@ class UserDeletionService:
UserDeletionService._deletion_requests[verification_code] = deletion_request
# Send verification email
UserDeletionService._send_deletion_verification_email(
user, verification_code, expires_at)
UserDeletionService._send_deletion_verification_email(user, verification_code, expires_at)
return deletion_request
@@ -136,10 +134,10 @@ class UserDeletionService:
del UserDeletionService._deletion_requests[verification_code]
# Add verification info to result
result['deletion_request'] = {
'verification_code': verification_code,
'created_at': deletion_request.created_at,
'verified_at': timezone.now(),
result["deletion_request"] = {
"verification_code": verification_code,
"created_at": deletion_request.created_at,
"verified_at": timezone.now(),
}
return result
@@ -180,13 +178,13 @@ class UserDeletionService:
"""
# Get or create the "deleted_user" placeholder
deleted_user_placeholder, created = User.objects.get_or_create(
username='deleted_user',
username="deleted_user",
defaults={
'email': 'deleted@thrillwiki.com',
'first_name': 'Deleted',
'last_name': 'User',
'is_active': False,
}
"email": "deleted@thrillwiki.com",
"first_name": "Deleted",
"last_name": "User",
"is_active": False,
},
)
# Count submissions before transfer
@@ -197,22 +195,22 @@ class UserDeletionService:
# Store user info before deletion
deleted_user_info = {
'username': user.username,
'user_id': getattr(user, 'user_id', user.id),
'email': user.email,
'date_joined': user.date_joined,
"username": user.username,
"user_id": getattr(user, "user_id", user.id),
"email": user.email,
"date_joined": user.date_joined,
}
# Delete the user account
user.delete()
return {
'deleted_user': deleted_user_info,
'preserved_submissions': submission_counts,
'transferred_to': {
'username': deleted_user_placeholder.username,
'user_id': getattr(deleted_user_placeholder, 'user_id', deleted_user_placeholder.id),
}
"deleted_user": deleted_user_info,
"preserved_submissions": submission_counts,
"transferred_to": {
"username": deleted_user_placeholder.username,
"user_id": getattr(deleted_user_placeholder, "user_id", deleted_user_placeholder.id),
},
}
@staticmethod
@@ -222,20 +220,13 @@ class UserDeletionService:
# Count different types of submissions
# Note: These are placeholder counts - adjust based on your actual models
counts['park_reviews'] = getattr(
user, 'park_reviews', user.__class__.objects.none()).count()
counts['ride_reviews'] = getattr(
user, 'ride_reviews', user.__class__.objects.none()).count()
counts['uploaded_park_photos'] = getattr(
user, 'uploaded_park_photos', user.__class__.objects.none()).count()
counts['uploaded_ride_photos'] = getattr(
user, 'uploaded_ride_photos', user.__class__.objects.none()).count()
counts['top_lists'] = getattr(
user, 'top_lists', user.__class__.objects.none()).count()
counts['edit_submissions'] = getattr(
user, 'edit_submissions', user.__class__.objects.none()).count()
counts['photo_submissions'] = getattr(
user, 'photo_submissions', user.__class__.objects.none()).count()
counts["park_reviews"] = getattr(user, "park_reviews", user.__class__.objects.none()).count()
counts["ride_reviews"] = getattr(user, "ride_reviews", user.__class__.objects.none()).count()
counts["uploaded_park_photos"] = getattr(user, "uploaded_park_photos", user.__class__.objects.none()).count()
counts["uploaded_ride_photos"] = getattr(user, "uploaded_ride_photos", user.__class__.objects.none()).count()
counts["top_lists"] = getattr(user, "top_lists", user.__class__.objects.none()).count()
counts["edit_submissions"] = getattr(user, "edit_submissions", user.__class__.objects.none()).count()
counts["photo_submissions"] = getattr(user, "photo_submissions", user.__class__.objects.none()).count()
return counts
@@ -247,30 +238,30 @@ class UserDeletionService:
# Note: Adjust these based on your actual model relationships
# Park reviews
if hasattr(user, 'park_reviews'):
if hasattr(user, "park_reviews"):
user.park_reviews.all().update(user=placeholder_user)
# Ride reviews
if hasattr(user, 'ride_reviews'):
if hasattr(user, "ride_reviews"):
user.ride_reviews.all().update(user=placeholder_user)
# Uploaded photos
if hasattr(user, 'uploaded_park_photos'):
if hasattr(user, "uploaded_park_photos"):
user.uploaded_park_photos.all().update(user=placeholder_user)
if hasattr(user, 'uploaded_ride_photos'):
if hasattr(user, "uploaded_ride_photos"):
user.uploaded_ride_photos.all().update(user=placeholder_user)
# Top lists
if hasattr(user, 'top_lists'):
if hasattr(user, "top_lists"):
user.top_lists.all().update(user=placeholder_user)
# Edit submissions
if hasattr(user, 'edit_submissions'):
if hasattr(user, "edit_submissions"):
user.edit_submissions.all().update(user=placeholder_user)
# Photo submissions
if hasattr(user, 'photo_submissions'):
if hasattr(user, "photo_submissions"):
user.photo_submissions.all().update(user=placeholder_user)
@staticmethod
@@ -278,18 +269,16 @@ class UserDeletionService:
"""Send verification email for account deletion."""
try:
context = {
'user': user,
'verification_code': verification_code,
'expires_at': expires_at,
'site_name': 'ThrillWiki',
'site_url': getattr(settings, 'SITE_URL', 'https://thrillwiki.com'),
"user": user,
"verification_code": verification_code,
"expires_at": expires_at,
"site_name": "ThrillWiki",
"site_url": getattr(settings, "SITE_URL", "https://thrillwiki.com"),
}
subject = 'ThrillWiki: Confirm Account Deletion'
html_message = render_to_string(
'emails/account_deletion_verification.html', context)
plain_message = render_to_string(
'emails/account_deletion_verification.txt', context)
subject = "ThrillWiki: Confirm Account Deletion"
html_message = render_to_string("emails/account_deletion_verification.html", context)
plain_message = render_to_string("emails/account_deletion_verification.txt", context)
send_mail(
subject=subject,
@@ -303,6 +292,5 @@ class UserDeletionService:
logger.info(f"Deletion verification email sent to {user.email}")
except Exception as e:
logger.error(
f"Failed to send deletion verification email to {user.email}: {str(e)}")
logger.error(f"Failed to send deletion verification email to {user.email}: {str(e)}")
raise

View File

@@ -108,7 +108,7 @@ def sync_user_role_with_groups(sender, instance, **kwargs):
User.Roles.MODERATOR,
]:
instance.is_staff = True
elif old_instance.role in [
elif old_instance.role in [ # noqa: SIM102
User.Roles.ADMIN,
User.Roles.MODERATOR,
]:
@@ -119,9 +119,7 @@ def sync_user_role_with_groups(sender, instance, **kwargs):
except User.DoesNotExist:
pass
except Exception as e:
print(
f"Error syncing role with groups for user {instance.username}: {str(e)}"
)
print(f"Error syncing role with groups for user {instance.username}: {str(e)}")
def create_default_groups():
@@ -200,19 +198,19 @@ def log_successful_login(sender, user, request, **kwargs):
"""
try:
# Get IP address
x_forwarded_for = request.META.get('HTTP_X_FORWARDED_FOR')
ip_address = x_forwarded_for.split(',')[0].strip() if x_forwarded_for else request.META.get('REMOTE_ADDR')
x_forwarded_for = request.META.get("HTTP_X_FORWARDED_FOR")
ip_address = x_forwarded_for.split(",")[0].strip() if x_forwarded_for else request.META.get("REMOTE_ADDR")
# Get user agent
user_agent = request.META.get('HTTP_USER_AGENT', '')[:500]
user_agent = request.META.get("HTTP_USER_AGENT", "")[:500]
# Determine login method from session or request
login_method = 'PASSWORD'
if hasattr(request, 'session'):
sociallogin = getattr(request, '_sociallogin', None)
login_method = "PASSWORD"
if hasattr(request, "session"):
sociallogin = getattr(request, "_sociallogin", None)
if sociallogin:
provider = sociallogin.account.provider.upper()
if provider in ['GOOGLE', 'DISCORD']:
if provider in ["GOOGLE", "DISCORD"]:
login_method = provider
# Create login history entry

View File

@@ -113,16 +113,10 @@ class SignalsTestCase(TestCase):
moderator_group = Group.objects.get(name=User.Roles.MODERATOR)
self.assertIsNotNone(moderator_group)
self.assertTrue(
moderator_group.permissions.filter(codename="change_review").exists()
)
self.assertFalse(
moderator_group.permissions.filter(codename="change_user").exists()
)
self.assertTrue(moderator_group.permissions.filter(codename="change_review").exists())
self.assertFalse(moderator_group.permissions.filter(codename="change_user").exists())
admin_group = Group.objects.get(name=User.Roles.ADMIN)
self.assertIsNotNone(admin_group)
self.assertTrue(
admin_group.permissions.filter(codename="change_review").exists()
)
self.assertTrue(admin_group.permissions.filter(codename="change_review").exists())
self.assertTrue(admin_group.permissions.filter(codename="change_user").exists())

View File

@@ -150,6 +150,3 @@ class TestPasswordResetAdmin(TestCase):
request.user = UserModel(is_superuser=True)
actions = self.admin.get_actions(request)
assert "cleanup_old_tokens" in actions

View File

@@ -85,16 +85,16 @@ class UserIndexTests(TestCase):
def test_is_banned_field_is_indexed(self):
"""Verify is_banned field has db_index=True."""
field = User._meta.get_field('is_banned')
field = User._meta.get_field("is_banned")
self.assertTrue(field.db_index)
def test_role_field_is_indexed(self):
"""Verify role field has db_index=True."""
field = User._meta.get_field('role')
field = User._meta.get_field("role")
self.assertTrue(field.db_index)
def test_composite_index_exists(self):
"""Verify composite index on (is_banned, role) exists."""
indexes = User._meta.indexes
index_names = [idx.name for idx in indexes]
self.assertIn('accounts_user_banned_role_idx', index_names)
self.assertIn("accounts_user_banned_role_idx", index_names)

View File

@@ -15,9 +15,7 @@ class UserDeletionServiceTest(TestCase):
def setUp(self):
"""Set up test data."""
# Create test users
self.user = User.objects.create_user(
username="testuser", email="test@example.com", password="testpass123"
)
self.user = User.objects.create_user(username="testuser", email="test@example.com", password="testpass123")
self.admin_user = User.objects.create_user(
username="admin",
@@ -27,13 +25,9 @@ class UserDeletionServiceTest(TestCase):
)
# Create user profiles
UserProfile.objects.create(
user=self.user, display_name="Test User", bio="Test bio"
)
UserProfile.objects.create(user=self.user, display_name="Test User", bio="Test bio")
UserProfile.objects.create(
user=self.admin_user, display_name="Admin User", bio="Admin bio"
)
UserProfile.objects.create(user=self.admin_user, display_name="Admin User", bio="Admin bio")
def test_get_or_create_deleted_user(self):
"""Test that deleted user placeholder is created correctly."""
@@ -108,9 +102,7 @@ class UserDeletionServiceTest(TestCase):
with self.assertRaises(ValueError) as context:
UserDeletionService.delete_user_preserve_submissions(deleted_user)
self.assertIn(
"Cannot delete the system deleted user placeholder", str(context.exception)
)
self.assertIn("Cannot delete the system deleted user placeholder", str(context.exception))
def test_delete_user_with_submissions_transfers_correctly(self):
"""Test that user submissions are transferred to deleted user placeholder."""
@@ -141,7 +133,7 @@ class UserDeletionServiceTest(TestCase):
original_user_count = User.objects.count()
# Mock a failure during the deletion process
with self.assertRaises(Exception), transaction.atomic():
with self.assertRaises(Exception), transaction.atomic(): # noqa: B017
# Start the deletion process
UserDeletionService.get_or_create_deleted_user()

View File

@@ -61,11 +61,7 @@ class CustomLoginView(TurnstileMixin, LoginView):
context={"user_id": user.id, "username": user.username},
request=self.request,
)
return (
HttpResponseClientRefresh()
if getattr(self.request, "htmx", False)
else response
)
return HttpResponseClientRefresh() if getattr(self.request, "htmx", False) else response
def form_invalid(self, form):
log_security_event(
@@ -116,11 +112,7 @@ class CustomSignupView(TurnstileMixin, SignupView):
},
request=self.request,
)
return (
HttpResponseClientRefresh()
if getattr(self.request, "htmx", False)
else response
)
return HttpResponseClientRefresh() if getattr(self.request, "htmx", False) else response
def form_invalid(self, form):
if getattr(self.request, "htmx", False):
@@ -260,9 +252,7 @@ class SettingsView(LoginRequiredMixin, TemplateView):
and bool(re.search(r"[0-9]", password))
)
def _send_password_change_confirmation(
self, request: HttpRequest, user: User
) -> None:
def _send_password_change_confirmation(self, request: HttpRequest, user: User) -> None:
"""Send password change confirmation email."""
site = get_current_site(request)
context = {
@@ -270,9 +260,7 @@ class SettingsView(LoginRequiredMixin, TemplateView):
"site_name": site.name,
}
email_html = render_to_string(
"accounts/email/password_change_confirmation.html", context
)
email_html = render_to_string("accounts/email/password_change_confirmation.html", context)
EmailService.send_email(
to=user.email,
@@ -282,9 +270,7 @@ class SettingsView(LoginRequiredMixin, TemplateView):
html=email_html,
)
def _handle_password_change(
self, request: HttpRequest
) -> HttpResponseRedirect | None:
def _handle_password_change(self, request: HttpRequest) -> HttpResponseRedirect | None:
user = cast(User, request.user)
old_password = request.POST.get("old_password", "")
new_password = request.POST.get("new_password", "")
@@ -327,9 +313,7 @@ class SettingsView(LoginRequiredMixin, TemplateView):
def _handle_email_change(self, request: HttpRequest) -> None:
if new_email := request.POST.get("new_email"):
self._send_email_verification(request, new_email)
messages.success(
request, "Verification email sent to your new email address"
)
messages.success(request, "Verification email sent to your new email address")
else:
messages.error(request, "New email is required")
@@ -385,9 +369,7 @@ def create_password_reset_token(user: User) -> str:
return token
def send_password_reset_email(
user: User, site: Site | RequestSite, token: str
) -> None:
def send_password_reset_email(user: User, site: Site | RequestSite, token: str) -> None:
reset_url = reverse("password_reset_confirm", kwargs={"token": token})
context = {
"user": user,
@@ -457,16 +439,12 @@ def handle_password_reset(
messages.success(request, "Password reset successfully")
def send_password_reset_confirmation(
user: User, site: Site | RequestSite
) -> None:
def send_password_reset_confirmation(user: User, site: Site | RequestSite) -> None:
context = {
"user": user,
"site_name": site.name,
}
email_html = render_to_string(
"accounts/email/password_reset_complete.html", context
)
email_html = render_to_string("accounts/email/password_reset_complete.html", context)
EmailService.send_email(
to=user.email,
@@ -479,9 +457,7 @@ def send_password_reset_confirmation(
def reset_password(request: HttpRequest, token: str) -> HttpResponse:
try:
reset = PasswordReset.objects.select_related("user").get(
token=token, expires_at__gt=timezone.now(), used=False
)
reset = PasswordReset.objects.select_related("user").get(token=token, expires_at__gt=timezone.now(), used=False)
if request.method == "POST":
if new_password := request.POST.get("new_password"):

File diff suppressed because it is too large Load Diff

View File

@@ -23,6 +23,7 @@ class UserProfileUpdateInputSerializer(serializers.ModelSerializer):
cloudflare_id = validated_data.pop("cloudflare_image_id", None)
if cloudflare_id:
from django_cloudflareimages_toolkit.models import CloudflareImage
image, _ = CloudflareImage.objects.get_or_create(cloudflare_id=cloudflare_id)
instance.avatar = image

View File

@@ -76,9 +76,7 @@ urlpatterns = [
name="update_privacy_settings",
),
# Security settings endpoints
path(
"settings/security/", views.get_security_settings, name="get_security_settings"
),
path("settings/security/", views.get_security_settings, name="get_security_settings"),
path(
"settings/security/update/",
views.update_security_settings,
@@ -90,9 +88,7 @@ urlpatterns = [
path("top-lists/", views.get_user_top_lists, name="get_user_top_lists"),
path("top-lists/create/", views.create_top_list, name="create_top_list"),
path("top-lists/<int:list_id>/", views.update_top_list, name="update_top_list"),
path(
"top-lists/<int:list_id>/delete/", views.delete_top_list, name="delete_top_list"
),
path("top-lists/<int:list_id>/delete/", views.delete_top_list, name="delete_top_list"),
# Notification endpoints
path("notifications/", views.get_user_notifications, name="get_user_notifications"),
path(
@@ -114,18 +110,13 @@ urlpatterns = [
path("profile/avatar/upload/", views.upload_avatar, name="upload_avatar"),
path("profile/avatar/save/", views.save_avatar_image, name="save_avatar_image"),
path("profile/avatar/delete/", views.delete_avatar, name="delete_avatar"),
# Login history endpoint
path("login-history/", views.get_login_history, name="get_login_history"),
# Magic Link (Login by Code) endpoints
path("magic-link/request/", views_magic_link.request_magic_link, name="request_magic_link"),
path("magic-link/verify/", views_magic_link.verify_magic_link, name="verify_magic_link"),
# Public Profile
path("profiles/<str:username>/", views.get_public_user_profile, name="get_public_user_profile"),
# ViewSet routes
path("", include(router.urls)),
]

View File

@@ -69,8 +69,7 @@ logger = logging.getLogger(__name__)
200: {
"description": "User successfully deleted with submissions preserved",
"example": {
"success": True,
"message": "User successfully deleted with submissions preserved",
"detail": "User successfully deleted with submissions preserved",
"deleted_user": {
"username": "john_doe",
"user_id": "1234",
@@ -92,17 +91,16 @@ logger = logging.getLogger(__name__)
400: {
"description": "Bad request - user cannot be deleted",
"example": {
"success": False,
"error": "Cannot delete user: Cannot delete superuser accounts",
"detail": "Cannot delete user: Cannot delete superuser accounts",
},
},
404: {
"description": "User not found",
"example": {"success": False, "error": "User not found"},
"example": {"detail": "User not found"},
},
403: {
"description": "Permission denied - admin access required",
"example": {"success": False, "error": "Admin access required"},
"example": {"detail": "Admin access required"},
},
},
tags=["User Management"],
@@ -137,7 +135,7 @@ def delete_user_preserve_submissions(request, user_id):
"is_superuser": user.is_superuser,
"user_role": user.role,
"rejection_reason": reason,
}
},
)
# Determine error code based on reason
@@ -151,8 +149,7 @@ def delete_user_preserve_submissions(request, user_id):
return Response(
{
"success": False,
"error": f"Cannot delete user: {reason}",
"detail": f"Cannot delete user: {reason}",
"error_code": error_code,
"user_info": {
"username": user.username,
@@ -174,7 +171,7 @@ def delete_user_preserve_submissions(request, user_id):
"target_user": user.username,
"target_user_id": user_id,
"action": "user_deletion",
}
},
)
# Perform the deletion
@@ -185,17 +182,16 @@ def delete_user_preserve_submissions(request, user_id):
f"Successfully deleted user {result['deleted_user']['username']} (ID: {user_id}) by admin {request.user.username}",
extra={
"admin_user": request.user.username,
"deleted_user": result['deleted_user']['username'],
"deleted_user": result["deleted_user"]["username"],
"deleted_user_id": user_id,
"preserved_submissions": result['preserved_submissions'],
"preserved_submissions": result["preserved_submissions"],
"action": "user_deletion_completed",
}
},
)
return Response(
{
"success": True,
"message": "User successfully deleted with submissions preserved",
"detail": "User successfully deleted with submissions preserved",
**result,
},
status=status.HTTP_200_OK,
@@ -208,16 +204,15 @@ def delete_user_preserve_submissions(request, user_id):
extra={
"admin_user": request.user.username,
"target_user_id": user_id,
"error": str(e),
"detail": str(e),
"action": "user_deletion_error",
},
exc_info=True
exc_info=True,
)
return Response(
{
"success": False,
"error": f"Error deleting user: {str(e)}",
"detail": f"Error deleting user: {str(e)}",
"error_code": "DELETION_ERROR",
"help_text": "Please try again or contact system administrator if the problem persists.",
},
@@ -259,8 +254,7 @@ def delete_user_preserve_submissions(request, user_id):
},
},
"example": {
"success": True,
"message": "Avatar saved successfully",
"detail": "Avatar saved successfully",
"avatar_url": "https://imagedelivery.net/account-hash/image-id/avatar",
"avatar_variants": {
"thumbnail": "https://imagedelivery.net/account-hash/image-id/thumbnail",
@@ -285,7 +279,7 @@ def save_avatar_image(request):
if not cloudflare_image_id:
return Response(
{"success": False, "error": "cloudflare_image_id is required"},
{"detail": "cloudflare_image_id is required"},
status=status.HTTP_400_BAD_REQUEST,
)
@@ -299,26 +293,25 @@ def save_avatar_image(request):
if not image_data:
return Response(
{"success": False, "error": "Image not found in Cloudflare"},
{"detail": "Image not found in Cloudflare"},
status=status.HTTP_400_BAD_REQUEST,
)
# Try to find existing CloudflareImage record by cloudflare_id
cloudflare_image = None
try:
cloudflare_image = CloudflareImage.objects.get(
cloudflare_id=cloudflare_image_id)
cloudflare_image = CloudflareImage.objects.get(cloudflare_id=cloudflare_image_id)
# Update existing record with latest data from Cloudflare
cloudflare_image.status = 'uploaded'
cloudflare_image.status = "uploaded"
cloudflare_image.uploaded_at = timezone.now()
cloudflare_image.metadata = image_data.get('meta', {})
cloudflare_image.metadata = image_data.get("meta", {})
# Extract variants from nested result structure
cloudflare_image.variants = image_data.get('result', {}).get('variants', [])
cloudflare_image.variants = image_data.get("result", {}).get("variants", [])
cloudflare_image.cloudflare_metadata = image_data
cloudflare_image.width = image_data.get('width')
cloudflare_image.height = image_data.get('height')
cloudflare_image.format = image_data.get('format', '')
cloudflare_image.width = image_data.get("width")
cloudflare_image.height = image_data.get("height")
cloudflare_image.format = image_data.get("format", "")
cloudflare_image.save()
except CloudflareImage.DoesNotExist:
@@ -326,25 +319,23 @@ def save_avatar_image(request):
cloudflare_image = CloudflareImage.objects.create(
cloudflare_id=cloudflare_image_id,
user=user,
status='uploaded',
upload_url='', # Not needed for uploaded images
status="uploaded",
upload_url="", # Not needed for uploaded images
expires_at=timezone.now() + timezone.timedelta(days=365), # Set far future expiry
uploaded_at=timezone.now(),
metadata=image_data.get('meta', {}),
metadata=image_data.get("meta", {}),
# Extract variants from nested result structure
variants=image_data.get('result', {}).get('variants', []),
variants=image_data.get("result", {}).get("variants", []),
cloudflare_metadata=image_data,
width=image_data.get('width'),
height=image_data.get('height'),
format=image_data.get('format', ''),
width=image_data.get("width"),
height=image_data.get("height"),
format=image_data.get("format", ""),
)
except Exception as api_error:
logger.error(
f"Error fetching image from Cloudflare API: {str(api_error)}", exc_info=True)
logger.error(f"Error fetching image from Cloudflare API: {str(api_error)}", exc_info=True)
return Response(
{"success": False,
"error": f"Failed to fetch image from Cloudflare: {str(api_error)}"},
{"detail": f"Failed to fetch image from Cloudflare: {str(api_error)}"},
status=status.HTTP_400_BAD_REQUEST,
)
@@ -391,8 +382,7 @@ def save_avatar_image(request):
return Response(
{
"success": True,
"message": "Avatar saved successfully",
"detail": "Avatar saved successfully",
"avatar_url": avatar_url,
"avatar_variants": avatar_variants,
},
@@ -402,7 +392,7 @@ def save_avatar_image(request):
except Exception as e:
logger.error(f"Error saving avatar image: {str(e)}", exc_info=True)
return Response(
{"success": False, "error": f"Failed to save avatar: {str(e)}"},
{"detail": f"Failed to save avatar: {str(e)}"},
status=status.HTTP_400_BAD_REQUEST,
)
@@ -420,8 +410,7 @@ def save_avatar_image(request):
"avatar_url": {"type": "string"},
},
"example": {
"success": True,
"message": "Avatar deleted successfully",
"detail": "Avatar deleted successfully",
"avatar_url": "https://ui-avatars.com/api/?name=J&size=200&background=random&color=fff&bold=true",
},
},
@@ -447,6 +436,7 @@ def delete_avatar(request):
# Delete from Cloudflare first, then from database
try:
from django_cloudflareimages_toolkit.services import CloudflareImagesService
service = CloudflareImagesService()
service.delete_image(avatar_to_delete)
logger.info(f"Successfully deleted avatar from Cloudflare: {avatar_to_delete.cloudflare_id}")
@@ -461,8 +451,7 @@ def delete_avatar(request):
return Response(
{
"success": True,
"message": "Avatar deleted successfully",
"detail": "Avatar deleted successfully",
"avatar_url": avatar_url,
},
status=status.HTTP_200_OK,
@@ -471,8 +460,7 @@ def delete_avatar(request):
except UserProfile.DoesNotExist:
return Response(
{
"success": True,
"message": "No avatar to delete",
"detail": "No avatar to delete",
"avatar_url": f"https://ui-avatars.com/api/?name={user.username[0].upper()}&size=200&background=random&color=fff&bold=true",
},
status=status.HTTP_200_OK,
@@ -480,7 +468,7 @@ def delete_avatar(request):
except Exception as e:
return Response(
{"success": False, "error": f"Failed to delete avatar: {str(e)}"},
{"detail": f"Failed to delete avatar: {str(e)}"},
status=status.HTTP_400_BAD_REQUEST,
)
@@ -506,7 +494,7 @@ def request_account_deletion(request):
can_delete, reason = UserDeletionService.can_delete_user(user)
if not can_delete:
return Response(
{"success": False, "error": reason},
{"detail": reason},
status=status.HTTP_400_BAD_REQUEST,
)
@@ -515,8 +503,7 @@ def request_account_deletion(request):
return Response(
{
"success": True,
"message": "Verification code sent to your email",
"detail": "Verification code sent to your email",
"expires_at": deletion_request.expires_at,
"email": user.email,
},
@@ -534,7 +521,7 @@ def request_account_deletion(request):
"user_role": request.user.role,
"rejection_reason": str(e),
"action": "self_deletion_rejected",
}
},
)
# Determine error code based on reason
@@ -549,8 +536,7 @@ def request_account_deletion(request):
return Response(
{
"success": False,
"error": error_message,
"detail": error_message,
"error_code": error_code,
"user_info": {
"username": request.user.username,
@@ -570,16 +556,15 @@ def request_account_deletion(request):
extra={
"user": request.user.username,
"user_id": request.user.user_id,
"error": str(e),
"detail": str(e),
"action": "self_deletion_error",
},
exc_info=True
exc_info=True,
)
return Response(
{
"success": False,
"error": f"Error creating deletion request: {str(e)}",
"detail": f"Error creating deletion request: {str(e)}",
"error_code": "DELETION_REQUEST_ERROR",
"help_text": "Please try again or contact support if the problem persists.",
},
@@ -611,8 +596,7 @@ def request_account_deletion(request):
200: {
"description": "Account successfully deleted",
"example": {
"success": True,
"message": "Account successfully deleted with submissions preserved",
"detail": "Account successfully deleted with submissions preserved",
"deleted_user": {
"username": "john_doe",
"user_id": "1234",
@@ -637,7 +621,7 @@ def request_account_deletion(request):
},
400: {
"description": "Invalid or expired verification code",
"example": {"success": False, "error": "Verification code has expired"},
"example": {"detail": "Verification code has expired"},
},
},
tags=["Self-Service Account Management"],
@@ -663,7 +647,7 @@ def verify_account_deletion(request):
if not verification_code:
return Response(
{"success": False, "error": "Verification code is required"},
{"detail": "Verification code is required"},
status=status.HTTP_400_BAD_REQUEST,
)
@@ -672,20 +656,17 @@ def verify_account_deletion(request):
return Response(
{
"success": True,
"message": "Account successfully deleted with submissions preserved",
"detail": "Account successfully deleted with submissions preserved",
**result,
},
status=status.HTTP_200_OK,
)
except ValueError as e:
return Response(
{"success": False, "error": str(e)}, status=status.HTTP_400_BAD_REQUEST
)
return Response({"detail": str(e)}, status=status.HTTP_400_BAD_REQUEST)
except Exception as e:
return Response(
{"success": False, "error": f"Error verifying deletion: {str(e)}"},
{"detail": f"Error verifying deletion: {str(e)}"},
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
@@ -701,14 +682,13 @@ def verify_account_deletion(request):
200: {
"description": "Deletion request cancelled or no request found",
"example": {
"success": True,
"message": "Deletion request cancelled",
"detail": "Deletion request cancelled",
"had_pending_request": True,
},
},
401: {
"description": "Authentication required",
"example": {"success": False, "error": "Authentication required"},
"example": {"detail": "Authentication required"},
},
},
tags=["Self-Service Account Management"],
@@ -732,12 +712,7 @@ def cancel_account_deletion(request):
return Response(
{
"success": True,
"message": (
"Deletion request cancelled"
if had_request
else "No pending deletion request found"
),
"detail": ("Deletion request cancelled" if had_request else "No pending deletion request found"),
"had_pending_request": had_request,
},
status=status.HTTP_200_OK,
@@ -745,7 +720,7 @@ def cancel_account_deletion(request):
except Exception as e:
return Response(
{"success": False, "error": f"Error cancelling deletion request: {str(e)}"},
{"detail": f"Error cancelling deletion request: {str(e)}"},
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
@@ -753,10 +728,7 @@ def cancel_account_deletion(request):
@extend_schema(
operation_id="check_user_deletion_eligibility",
summary="Check if user can be deleted",
description=(
"Check if a user can be safely deleted and get a preview of "
"what submissions would be preserved."
),
description=("Check if a user can be safely deleted and get a preview of " "what submissions would be preserved."),
parameters=[
OpenApiParameter(
name="user_id",
@@ -792,11 +764,11 @@ def cancel_account_deletion(request):
},
404: {
"description": "User not found",
"example": {"success": False, "error": "User not found"},
"example": {"detail": "User not found"},
},
403: {
"description": "Permission denied - admin access required",
"example": {"success": False, "error": "Admin access required"},
"example": {"detail": "Admin access required"},
},
},
tags=["User Management"],
@@ -821,27 +793,13 @@ def check_user_deletion_eligibility(request, user_id):
# Count submissions
submission_counts = {
"park_reviews": getattr(
user, "park_reviews", user.__class__.objects.none()
).count(),
"ride_reviews": getattr(
user, "ride_reviews", user.__class__.objects.none()
).count(),
"uploaded_park_photos": getattr(
user, "uploaded_park_photos", user.__class__.objects.none()
).count(),
"uploaded_ride_photos": getattr(
user, "uploaded_ride_photos", user.__class__.objects.none()
).count(),
"top_lists": getattr(
user, "user_lists", user.__class__.objects.none()
).count(),
"edit_submissions": getattr(
user, "edit_submissions", user.__class__.objects.none()
).count(),
"photo_submissions": getattr(
user, "photo_submissions", user.__class__.objects.none()
).count(),
"park_reviews": getattr(user, "park_reviews", user.__class__.objects.none()).count(),
"ride_reviews": getattr(user, "ride_reviews", user.__class__.objects.none()).count(),
"uploaded_park_photos": getattr(user, "uploaded_park_photos", user.__class__.objects.none()).count(),
"uploaded_ride_photos": getattr(user, "uploaded_ride_photos", user.__class__.objects.none()).count(),
"top_lists": getattr(user, "user_lists", user.__class__.objects.none()).count(),
"edit_submissions": getattr(user, "edit_submissions", user.__class__.objects.none()).count(),
"photo_submissions": getattr(user, "photo_submissions", user.__class__.objects.none()).count(),
}
total_submissions = sum(submission_counts.values())
@@ -865,7 +823,7 @@ def check_user_deletion_eligibility(request, user_id):
except Exception as e:
return Response(
{"success": False, "error": f"Error checking user: {str(e)}"},
{"detail": f"Error checking user: {str(e)}"},
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
@@ -912,9 +870,7 @@ def get_user_profile(request):
@permission_classes([IsAuthenticated])
def update_user_account(request):
"""Update basic account information."""
serializer = AccountUpdateSerializer(
request.user, data=request.data, partial=True, context={"request": request}
)
serializer = AccountUpdateSerializer(request.user, data=request.data, partial=True, context={"request": request})
if serializer.is_valid():
serializer.save()
@@ -944,9 +900,7 @@ def update_user_profile(request):
"""Update user profile information."""
profile, created = UserProfile.objects.get_or_create(user=request.user)
serializer = ProfileUpdateSerializer(
profile, data=request.data, partial=True, context={"request": request}
)
serializer = ProfileUpdateSerializer(profile, data=request.data, partial=True, context={"request": request})
if serializer.is_valid():
serializer.save()
@@ -1046,9 +1000,7 @@ def update_user_preferences(request):
@permission_classes([IsAuthenticated])
def update_theme_preference(request):
"""Update theme preference."""
serializer = ThemePreferenceSerializer(
request.user, data=request.data, partial=True
)
serializer = ThemePreferenceSerializer(request.user, data=request.data, partial=True)
if serializer.is_valid():
serializer.save()
@@ -1395,14 +1347,9 @@ def update_top_list(request, list_id):
try:
top_list = UserList.objects.get(id=list_id, user=request.user)
except UserList.DoesNotExist:
return Response(
{"error": "Top list not found"},
status=status.HTTP_404_NOT_FOUND
)
return Response({"detail": "Top list not found"}, status=status.HTTP_404_NOT_FOUND)
serializer = UserListSerializer(
top_list, data=request.data, partial=True, context={"request": request}
)
serializer = UserListSerializer(top_list, data=request.data, partial=True, context={"request": request})
if serializer.is_valid():
serializer.save()
@@ -1430,10 +1377,7 @@ def delete_top_list(request, list_id):
top_list.delete()
return Response(status=status.HTTP_204_NO_CONTENT)
except UserList.DoesNotExist:
return Response(
{"error": "Top list not found"},
status=status.HTTP_404_NOT_FOUND
)
return Response({"detail": "Top list not found"}, status=status.HTTP_404_NOT_FOUND)
# === NOTIFICATION ENDPOINTS ===
@@ -1453,9 +1397,9 @@ def delete_top_list(request, list_id):
@permission_classes([IsAuthenticated])
def get_user_notifications(request):
"""Get user notifications."""
notifications = UserNotification.objects.filter(
user=request.user
).order_by("-created_at")[:50] # Limit to 50 most recent
notifications = UserNotification.objects.filter(user=request.user).order_by("-created_at")[
:50
] # Limit to 50 most recent
serializer = UserNotificationSerializer(notifications, many=True)
return Response(serializer.data, status=status.HTTP_200_OK)
@@ -1483,19 +1427,16 @@ def mark_notifications_read(request):
mark_all = serializer.validated_data.get("mark_all", False)
if mark_all:
UserNotification.objects.filter(
user=request.user, is_read=False
).update(is_read=True, read_at=timezone.now())
UserNotification.objects.filter(user=request.user, is_read=False).update(
is_read=True, read_at=timezone.now()
)
count = UserNotification.objects.filter(user=request.user).count()
else:
count = UserNotification.objects.filter(
id__in=notification_ids, user=request.user, is_read=False
).update(is_read=True, read_at=timezone.now())
count = UserNotification.objects.filter(id__in=notification_ids, user=request.user, is_read=False).update(
is_read=True, read_at=timezone.now()
)
return Response(
{"message": f"Marked {count} notifications as read"},
status=status.HTTP_200_OK
)
return Response({"detail": f"Marked {count} notifications as read"}, status=status.HTTP_200_OK)
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
@@ -1544,9 +1485,7 @@ def update_notification_preferences(request):
except NotificationPreference.DoesNotExist:
preferences = NotificationPreference.objects.create(user=request.user)
serializer = NotificationPreferenceSerializer(
preferences, data=request.data, partial=True
)
serializer = NotificationPreferenceSerializer(preferences, data=request.data, partial=True)
if serializer.is_valid():
serializer.save()
@@ -1578,10 +1517,7 @@ def upload_avatar(request):
if serializer.is_valid():
# Handle avatar upload logic here
# This would typically involve saving the file and updating the user profile
return Response(
{"message": "Avatar uploaded successfully"},
status=status.HTTP_200_OK
)
return Response({"detail": "Avatar uploaded successfully"}, status=status.HTTP_200_OK)
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
@@ -1596,8 +1532,8 @@ def upload_avatar(request):
"example": {
"account": {"username": "user", "email": "user@example.com"},
"profile": {"display_name": "User"},
"content": {"park_reviews": [], "lists": []}
}
"content": {"park_reviews": [], "lists": []},
},
},
401: {"description": "Authentication required"},
},
@@ -1612,10 +1548,7 @@ def export_user_data(request):
return Response(export_data, status=status.HTTP_200_OK)
except Exception as e:
logger.error(f"Error exporting data for user {request.user.id}: {e}", exc_info=True)
return Response(
{"error": "Failed to generate data export"},
status=status.HTTP_500_INTERNAL_SERVER_ERROR
)
return Response({"detail": "Failed to generate data export"}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
@extend_schema(
@@ -1690,20 +1623,25 @@ def get_login_history(request):
# Serialize
results = []
for entry in entries:
results.append({
"id": entry.id,
"ip_address": entry.ip_address,
"user_agent": entry.user_agent[:100] if entry.user_agent else None, # Truncate long user agents
"login_method": entry.login_method,
"login_method_display": dict(LoginHistory._meta.get_field('login_method').choices).get(entry.login_method, entry.login_method),
"login_timestamp": entry.login_timestamp.isoformat(),
"country": entry.country,
"city": entry.city,
"success": entry.success,
})
return Response({
"results": results,
"count": len(results),
})
results.append(
{
"id": entry.id,
"ip_address": entry.ip_address,
"user_agent": entry.user_agent[:100] if entry.user_agent else None, # Truncate long user agents
"login_method": entry.login_method,
"login_method_display": dict(LoginHistory._meta.get_field("login_method").choices).get(
entry.login_method, entry.login_method
),
"login_timestamp": entry.login_timestamp.isoformat(),
"country": entry.country,
"city": entry.city,
"success": entry.success,
}
)
return Response(
{
"results": results,
"count": len(results),
}
)

View File

@@ -15,22 +15,23 @@ class RideCreditViewSet(viewsets.ModelViewSet):
ViewSet for managing Ride Credits.
Allows users to track rides they have ridden.
"""
serializer_class = RideCreditSerializer
permission_classes = [permissions.IsAuthenticatedOrReadOnly]
filter_backends = [DjangoFilterBackend, filters.OrderingFilter]
filterset_fields = ['user__username', 'ride__park__slug', 'ride__manufacturer__slug']
ordering_fields = ['first_ridden_at', 'last_ridden_at', 'created_at', 'count', 'rating', 'display_order']
ordering = ['display_order', '-last_ridden_at']
filterset_fields = ["user__username", "ride__park__slug", "ride__manufacturer__slug"]
ordering_fields = ["first_ridden_at", "last_ridden_at", "created_at", "count", "rating", "display_order"]
ordering = ["display_order", "-last_ridden_at"]
def get_queryset(self):
"""
Return ride credits.
Optionally filter by user via query param ?user=username
"""
queryset = RideCredit.objects.all().select_related('ride', 'ride__park', 'user')
queryset = RideCredit.objects.all().select_related("ride", "ride__park", "user")
# Filter by user if provided
username = self.request.query_params.get('user')
username = self.request.query_params.get("user")
if username:
queryset = queryset.filter(user__username=username)
@@ -40,64 +41,49 @@ class RideCreditViewSet(viewsets.ModelViewSet):
"""Associate the current user with the ride credit."""
serializer.save(user=self.request.user)
@action(detail=False, methods=['post'], permission_classes=[permissions.IsAuthenticated])
@action(detail=False, methods=["post"], permission_classes=[permissions.IsAuthenticated])
@extend_schema(
summary="Reorder ride credits",
description="Bulk update the display order of ride credits. Send a list of {id, order} objects.",
request={
'application/json': {
'type': 'object',
'properties': {
'order': {
'type': 'array',
'items': {
'type': 'object',
'properties': {
'id': {'type': 'integer'},
'order': {'type': 'integer'}
},
'required': ['id', 'order']
}
"application/json": {
"type": "object",
"properties": {
"order": {
"type": "array",
"items": {
"type": "object",
"properties": {"id": {"type": "integer"}, "order": {"type": "integer"}},
"required": ["id", "order"],
},
}
}
},
}
}
},
)
def reorder(self, request):
"""
Bulk update display_order for multiple credits.
Expects: {"order": [{"id": 1, "order": 0}, {"id": 2, "order": 1}, ...]}
"""
order_data = request.data.get('order', [])
order_data = request.data.get("order", [])
if not order_data:
return Response(
{'error': 'No order data provided'},
status=status.HTTP_400_BAD_REQUEST
)
return Response({"detail": "No order data provided"}, status=status.HTTP_400_BAD_REQUEST)
# Validate that all credits belong to the current user
credit_ids = [item['id'] for item in order_data]
user_credits = RideCredit.objects.filter(
id__in=credit_ids,
user=request.user
).values_list('id', flat=True)
credit_ids = [item["id"] for item in order_data]
user_credits = RideCredit.objects.filter(id__in=credit_ids, user=request.user).values_list("id", flat=True)
if set(credit_ids) != set(user_credits):
return Response(
{'error': 'You can only reorder your own credits'},
status=status.HTTP_403_FORBIDDEN
)
return Response({"detail": "You can only reorder your own credits"}, status=status.HTTP_403_FORBIDDEN)
# Bulk update in a transaction
with transaction.atomic():
for item in order_data:
RideCredit.objects.filter(
id=item['id'],
user=request.user
).update(display_order=item['order'])
RideCredit.objects.filter(id=item["id"], user=request.user).update(display_order=item["order"])
return Response({'status': 'reordered', 'count': len(order_data)})
return Response({"status": "reordered", "count": len(order_data)})
@extend_schema(
summary="List ride credits",
@@ -109,8 +95,7 @@ class RideCreditViewSet(viewsets.ModelViewSet):
type=OpenApiTypes.STR,
description="Filter by username",
),
]
],
)
def list(self, request, *args, **kwargs):
return super().list(request, *args, **kwargs)

View File

@@ -4,6 +4,7 @@ Magic Link (Login by Code) API views.
Provides API endpoints for passwordless login via email code.
Uses django-allauth's built-in login-by-code functionality.
"""
from django.conf import settings
from drf_spectacular.utils import OpenApiExample, extend_schema
from rest_framework import status
@@ -15,6 +16,7 @@ try:
from allauth.account.internal.flows.login_by_code import perform_login_by_code, request_login_code
from allauth.account.models import EmailAddress
from allauth.account.utils import user_email # noqa: F401 - imported to verify availability
HAS_LOGIN_BY_CODE = True
except ImportError:
HAS_LOGIN_BY_CODE = False
@@ -24,27 +26,19 @@ except ImportError:
summary="Request magic link login code",
description="Send a one-time login code to the user's email address.",
request={
'application/json': {
'type': 'object',
'properties': {
'email': {'type': 'string', 'format': 'email'}
},
'required': ['email']
"application/json": {
"type": "object",
"properties": {"email": {"type": "string", "format": "email"}},
"required": ["email"],
}
},
responses={
200: {'description': 'Login code sent successfully'},
400: {'description': 'Invalid email or feature disabled'},
200: {"description": "Login code sent successfully"},
400: {"description": "Invalid email or feature disabled"},
},
examples=[
OpenApiExample(
'Request login code',
value={'email': 'user@example.com'},
request_only=True
)
]
examples=[OpenApiExample("Request login code", value={"email": "user@example.com"}, request_only=True)],
)
@api_view(['POST'])
@api_view(["POST"])
@permission_classes([AllowAny])
def request_magic_link(request):
"""
@@ -55,25 +49,18 @@ def request_magic_link(request):
2. If the email exists, a code is sent
3. User enters the code to complete login
"""
if not getattr(settings, 'ACCOUNT_LOGIN_BY_CODE_ENABLED', False):
return Response(
{'error': 'Magic link login is not enabled'},
status=status.HTTP_400_BAD_REQUEST
)
if not getattr(settings, "ACCOUNT_LOGIN_BY_CODE_ENABLED", False):
return Response({"detail": "Magic link login is not enabled"}, status=status.HTTP_400_BAD_REQUEST)
if not HAS_LOGIN_BY_CODE:
return Response(
{'error': 'Login by code is not available in this version of allauth'},
status=status.HTTP_400_BAD_REQUEST
{"detail": "Login by code is not available in this version of allauth"}, status=status.HTTP_400_BAD_REQUEST
)
email = request.data.get('email', '').lower().strip()
email = request.data.get("email", "").lower().strip()
if not email:
return Response(
{'error': 'Email is required'},
status=status.HTTP_400_BAD_REQUEST
)
return Response({"detail": "Email is required"}, status=status.HTTP_400_BAD_REQUEST)
# Check if email exists (don't reveal if it doesn't for security)
try:
@@ -83,40 +70,39 @@ def request_magic_link(request):
# Request the login code
request_login_code(request._request, user)
return Response({
'success': True,
'message': 'If an account exists with this email, a login code has been sent.',
'timeout': getattr(settings, 'ACCOUNT_LOGIN_BY_CODE_TIMEOUT', 300)
})
return Response(
{
"detail": "If an account exists with this email, a login code has been sent.",
"timeout": getattr(settings, "ACCOUNT_LOGIN_BY_CODE_TIMEOUT", 300),
}
)
except EmailAddress.DoesNotExist:
# Don't reveal that the email doesn't exist
return Response({
'success': True,
'message': 'If an account exists with this email, a login code has been sent.',
'timeout': getattr(settings, 'ACCOUNT_LOGIN_BY_CODE_TIMEOUT', 300)
})
return Response(
{
"detail": "If an account exists with this email, a login code has been sent.",
"timeout": getattr(settings, "ACCOUNT_LOGIN_BY_CODE_TIMEOUT", 300),
}
)
@extend_schema(
summary="Verify magic link code",
description="Verify the login code and complete the login process.",
request={
'application/json': {
'type': 'object',
'properties': {
'email': {'type': 'string', 'format': 'email'},
'code': {'type': 'string'}
},
'required': ['email', 'code']
"application/json": {
"type": "object",
"properties": {"email": {"type": "string", "format": "email"}, "code": {"type": "string"}},
"required": ["email", "code"],
}
},
responses={
200: {'description': 'Login successful'},
400: {'description': 'Invalid or expired code'},
}
200: {"description": "Login successful"},
400: {"description": "Invalid or expired code"},
},
)
@api_view(['POST'])
@api_view(["POST"])
@permission_classes([AllowAny])
def verify_magic_link(request):
"""
@@ -124,26 +110,17 @@ def verify_magic_link(request):
This is the second step of the magic link flow.
"""
if not getattr(settings, 'ACCOUNT_LOGIN_BY_CODE_ENABLED', False):
return Response(
{'error': 'Magic link login is not enabled'},
status=status.HTTP_400_BAD_REQUEST
)
if not getattr(settings, "ACCOUNT_LOGIN_BY_CODE_ENABLED", False):
return Response({"detail": "Magic link login is not enabled"}, status=status.HTTP_400_BAD_REQUEST)
if not HAS_LOGIN_BY_CODE:
return Response(
{'error': 'Login by code is not available'},
status=status.HTTP_400_BAD_REQUEST
)
return Response({"detail": "Login by code is not available"}, status=status.HTTP_400_BAD_REQUEST)
email = request.data.get('email', '').lower().strip()
code = request.data.get('code', '').strip()
email = request.data.get("email", "").lower().strip()
code = request.data.get("code", "").strip()
if not email or not code:
return Response(
{'error': 'Email and code are required'},
status=status.HTTP_400_BAD_REQUEST
)
return Response({"detail": "Email and code are required"}, status=status.HTTP_400_BAD_REQUEST)
try:
email_address = EmailAddress.objects.get(email__iexact=email, verified=True)
@@ -153,28 +130,20 @@ def verify_magic_link(request):
success = perform_login_by_code(request._request, user, code)
if success:
return Response({
'success': True,
'message': 'Login successful',
'user': {
'id': user.id,
'username': user.username,
'email': user.email
return Response(
{
"detail": "Login successful",
"user": {"id": user.id, "username": user.username, "email": user.email},
}
})
)
else:
return Response(
{'error': 'Invalid or expired code. Please request a new one.'},
status=status.HTTP_400_BAD_REQUEST
{"detail": "Invalid or expired code. Please request a new one."}, status=status.HTTP_400_BAD_REQUEST
)
except EmailAddress.DoesNotExist:
return Response(
{'error': 'Invalid email or code'},
status=status.HTTP_400_BAD_REQUEST
)
return Response({"detail": "Invalid email or code"}, status=status.HTTP_400_BAD_REQUEST)
except Exception:
return Response(
{'error': 'Invalid or expired code. Please request a new one.'},
status=status.HTTP_400_BAD_REQUEST
{"detail": "Invalid or expired code. Please request a new one."}, status=status.HTTP_400_BAD_REQUEST
)

View File

@@ -17,6 +17,7 @@ from rest_framework.response import Response
try:
import qrcode
HAS_QRCODE = True
except ImportError:
HAS_QRCODE = False
@@ -59,12 +60,14 @@ def get_mfa_status(request):
except Authenticator.DoesNotExist:
pass
return Response({
"mfa_enabled": totp_enabled,
"totp_enabled": totp_enabled,
"recovery_codes_enabled": recovery_enabled,
"recovery_codes_count": recovery_count,
})
return Response(
{
"mfa_enabled": totp_enabled,
"totp_enabled": totp_enabled,
"recovery_codes_enabled": recovery_enabled,
"recovery_codes_count": recovery_count,
}
)
@extend_schema(
@@ -110,11 +113,13 @@ def setup_totp(request):
# Store secret in session for later verification
request.session["pending_totp_secret"] = secret
return Response({
"secret": secret,
"provisioning_uri": uri,
"qr_code_base64": qr_code_base64,
})
return Response(
{
"secret": secret,
"provisioning_uri": uri,
"qr_code_base64": qr_code_base64,
}
)
@extend_schema(
@@ -138,8 +143,7 @@ def setup_totp(request):
200: {
"description": "TOTP activated successfully",
"example": {
"success": True,
"message": "Two-factor authentication enabled",
"detail": "Two-factor authentication enabled",
"recovery_codes": ["ABCD1234", "EFGH5678"],
},
},
@@ -160,7 +164,7 @@ def activate_totp(request):
if not code:
return Response(
{"success": False, "error": "Verification code is required"},
{"detail": "Verification code is required"},
status=status.HTTP_400_BAD_REQUEST,
)
@@ -168,21 +172,21 @@ def activate_totp(request):
secret = request.session.get("pending_totp_secret")
if not secret:
return Response(
{"success": False, "error": "No pending TOTP setup. Please start setup again."},
{"detail": "No pending TOTP setup. Please start setup again."},
status=status.HTTP_400_BAD_REQUEST,
)
# Verify the code
if not totp_auth.validate_totp_code(secret, code):
return Response(
{"success": False, "error": "Invalid verification code"},
{"detail": "Invalid verification code"},
status=status.HTTP_400_BAD_REQUEST,
)
# Check if already has TOTP
if Authenticator.objects.filter(user=user, type=Authenticator.Type.TOTP).exists():
return Response(
{"success": False, "error": "TOTP is already enabled"},
{"detail": "TOTP is already enabled"},
status=status.HTTP_400_BAD_REQUEST,
)
@@ -204,11 +208,12 @@ def activate_totp(request):
# Clear session
del request.session["pending_totp_secret"]
return Response({
"success": True,
"message": "Two-factor authentication enabled",
"recovery_codes": codes,
})
return Response(
{
"detail": "Two-factor authentication enabled",
"recovery_codes": codes,
}
)
@extend_schema(
@@ -230,7 +235,7 @@ def activate_totp(request):
responses={
200: {
"description": "TOTP disabled",
"example": {"success": True, "message": "Two-factor authentication disabled"},
"example": {"detail": "Two-factor authentication disabled"},
},
400: {"description": "Invalid password or MFA not enabled"},
},
@@ -248,26 +253,26 @@ def deactivate_totp(request):
# Verify password
if not user.check_password(password):
return Response(
{"success": False, "error": "Invalid password"},
{"detail": "Invalid password"},
status=status.HTTP_400_BAD_REQUEST,
)
# Remove TOTP and recovery codes
deleted_count, _ = Authenticator.objects.filter(
user=user,
type__in=[Authenticator.Type.TOTP, Authenticator.Type.RECOVERY_CODES]
user=user, type__in=[Authenticator.Type.TOTP, Authenticator.Type.RECOVERY_CODES]
).delete()
if deleted_count == 0:
return Response(
{"success": False, "error": "Two-factor authentication is not enabled"},
{"detail": "Two-factor authentication is not enabled"},
status=status.HTTP_400_BAD_REQUEST,
)
return Response({
"success": True,
"message": "Two-factor authentication disabled",
})
return Response(
{
"detail": "Two-factor authentication disabled",
}
)
@extend_schema(
@@ -277,9 +282,7 @@ def deactivate_totp(request):
request={
"application/json": {
"type": "object",
"properties": {
"code": {"type": "string", "description": "6-digit TOTP code"}
},
"properties": {"code": {"type": "string", "description": "6-digit TOTP code"}},
"required": ["code"],
}
},
@@ -301,7 +304,7 @@ def verify_totp(request):
if not code:
return Response(
{"success": False, "error": "Verification code is required"},
{"detail": "Verification code is required"},
status=status.HTTP_400_BAD_REQUEST,
)
@@ -313,12 +316,12 @@ def verify_totp(request):
return Response({"success": True})
else:
return Response(
{"success": False, "error": "Invalid verification code"},
{"detail": "Invalid verification code"},
status=status.HTTP_400_BAD_REQUEST,
)
except Authenticator.DoesNotExist:
return Response(
{"success": False, "error": "TOTP is not enabled"},
{"detail": "TOTP is not enabled"},
status=status.HTTP_400_BAD_REQUEST,
)
@@ -330,9 +333,7 @@ def verify_totp(request):
request={
"application/json": {
"type": "object",
"properties": {
"password": {"type": "string", "description": "Current password"}
},
"properties": {"password": {"type": "string", "description": "Current password"}},
"required": ["password"],
}
},
@@ -358,14 +359,14 @@ def regenerate_recovery_codes(request):
# Verify password
if not user.check_password(password):
return Response(
{"success": False, "error": "Invalid password"},
{"detail": "Invalid password"},
status=status.HTTP_400_BAD_REQUEST,
)
# Check if TOTP is enabled
if not Authenticator.objects.filter(user=user, type=Authenticator.Type.TOTP).exists():
return Response(
{"success": False, "error": "Two-factor authentication is not enabled"},
{"detail": "Two-factor authentication is not enabled"},
status=status.HTTP_400_BAD_REQUEST,
)
@@ -379,7 +380,9 @@ def regenerate_recovery_codes(request):
defaults={"data": {"codes": codes}},
)
return Response({
"success": True,
"recovery_codes": codes,
})
return Response(
{
"success": True,
"recovery_codes": codes,
}
)

View File

@@ -38,8 +38,6 @@ class ModelChoices:
"""Model choices utility class."""
# === AUTHENTICATION SERIALIZERS ===
@@ -95,12 +93,8 @@ class UserOutputSerializer(serializers.ModelSerializer):
class LoginInputSerializer(serializers.Serializer):
"""Input serializer for user login."""
username = serializers.CharField(
max_length=254, help_text="Username or email address"
)
password = serializers.CharField(
max_length=128, style={"input_type": "password"}, trim_whitespace=False
)
username = serializers.CharField(max_length=254, help_text="Username or email address")
password = serializers.CharField(max_length=128, style={"input_type": "password"}, trim_whitespace=False)
def validate(self, attrs):
username = attrs.get("username")
@@ -129,9 +123,7 @@ class SignupInputSerializer(serializers.ModelSerializer):
validators=[validate_password],
style={"input_type": "password"},
)
password_confirm = serializers.CharField(
write_only=True, style={"input_type": "password"}
)
password_confirm = serializers.CharField(write_only=True, style={"input_type": "password"})
class Meta:
model = UserModel
@@ -158,9 +150,7 @@ class SignupInputSerializer(serializers.ModelSerializer):
def validate_username(self, value):
"""Validate username is unique."""
if UserModel.objects.filter(username=value).exists():
raise serializers.ValidationError(
"A user with this username already exists."
)
raise serializers.ValidationError("A user with this username already exists.")
return value
def validate(self, attrs):
@@ -169,9 +159,7 @@ class SignupInputSerializer(serializers.ModelSerializer):
password_confirm = attrs.get("password_confirm")
if password != password_confirm:
raise serializers.ValidationError(
{"password_confirm": "Passwords do not match."}
)
raise serializers.ValidationError({"password_confirm": "Passwords do not match."})
return attrs
@@ -204,8 +192,7 @@ class SignupInputSerializer(serializers.ModelSerializer):
# Create or update email verification record
verification, created = EmailVerification.objects.get_or_create(
user=user,
defaults={'token': get_random_string(64)}
user=user, defaults={"token": get_random_string(64)}
)
if not created:
@@ -214,14 +201,12 @@ class SignupInputSerializer(serializers.ModelSerializer):
verification.save()
# Get current site from request context
request = self.context.get('request')
request = self.context.get("request")
if request:
site = get_current_site(request._request)
# Build verification URL
verification_url = request.build_absolute_uri(
f"/api/v1/auth/verify-email/{verification.token}/"
)
verification_url = request.build_absolute_uri(f"/api/v1/auth/verify-email/{verification.token}/")
# Send verification email
try:
@@ -243,13 +228,11 @@ The ThrillWiki Team
)
# Log the ForwardEmail email ID from the response
email_id = response.get('id') if response else None
email_id = response.get("id") if response else None
if email_id:
logger.info(
f"Verification email sent successfully to {user.email}. ForwardEmail ID: {email_id}")
logger.info(f"Verification email sent successfully to {user.email}. ForwardEmail ID: {email_id}")
else:
logger.info(
f"Verification email sent successfully to {user.email}. No email ID in response.")
logger.info(f"Verification email sent successfully to {user.email}. No email ID in response.")
except Exception as e:
# Log the error but don't fail registration
@@ -312,17 +295,13 @@ class PasswordResetOutputSerializer(serializers.Serializer):
class PasswordChangeInputSerializer(serializers.Serializer):
"""Input serializer for password change."""
old_password = serializers.CharField(
max_length=128, style={"input_type": "password"}
)
old_password = serializers.CharField(max_length=128, style={"input_type": "password"})
new_password = serializers.CharField(
max_length=128,
validators=[validate_password],
style={"input_type": "password"},
)
new_password_confirm = serializers.CharField(
max_length=128, style={"input_type": "password"}
)
new_password_confirm = serializers.CharField(max_length=128, style={"input_type": "password"})
def validate_old_password(self, value):
"""Validate old password is correct."""
@@ -337,9 +316,7 @@ class PasswordChangeInputSerializer(serializers.Serializer):
new_password_confirm = attrs.get("new_password_confirm")
if new_password != new_password_confirm:
raise serializers.ValidationError(
{"new_password_confirm": "New passwords do not match."}
)
raise serializers.ValidationError({"new_password_confirm": "New passwords do not match."})
return attrs
@@ -471,6 +448,3 @@ class UserProfileUpdateInputSerializer(serializers.Serializer):
dark_ride_credits = serializers.IntegerField(required=False)
flat_ride_credits = serializers.IntegerField(required=False)
water_ride_credits = serializers.IntegerField(required=False)

View File

@@ -19,13 +19,13 @@ from .social import (
__all__ = [
# Social authentication serializers
'ConnectedProviderSerializer',
'AvailableProviderSerializer',
'SocialAuthStatusSerializer',
'ConnectProviderInputSerializer',
'ConnectProviderOutputSerializer',
'DisconnectProviderOutputSerializer',
'SocialProviderListOutputSerializer',
'ConnectedProvidersListOutputSerializer',
'SocialProviderErrorSerializer',
"ConnectedProviderSerializer",
"AvailableProviderSerializer",
"SocialAuthStatusSerializer",
"ConnectProviderInputSerializer",
"ConnectProviderOutputSerializer",
"DisconnectProviderOutputSerializer",
"SocialProviderListOutputSerializer",
"ConnectedProvidersListOutputSerializer",
"SocialProviderErrorSerializer",
]

View File

@@ -14,74 +14,36 @@ User = get_user_model()
class ConnectedProviderSerializer(serializers.Serializer):
"""Serializer for connected social provider information."""
provider = serializers.CharField(
help_text="Provider ID (e.g., 'google', 'discord')"
)
provider_name = serializers.CharField(
help_text="Human-readable provider name"
)
uid = serializers.CharField(
help_text="User ID on the social provider"
)
date_joined = serializers.DateTimeField(
help_text="When this provider was connected"
)
can_disconnect = serializers.BooleanField(
help_text="Whether this provider can be safely disconnected"
)
provider = serializers.CharField(help_text="Provider ID (e.g., 'google', 'discord')")
provider_name = serializers.CharField(help_text="Human-readable provider name")
uid = serializers.CharField(help_text="User ID on the social provider")
date_joined = serializers.DateTimeField(help_text="When this provider was connected")
can_disconnect = serializers.BooleanField(help_text="Whether this provider can be safely disconnected")
disconnect_reason = serializers.CharField(
allow_null=True,
required=False,
help_text="Reason why provider cannot be disconnected (if applicable)"
)
extra_data = serializers.JSONField(
required=False,
help_text="Additional data from the social provider"
allow_null=True, required=False, help_text="Reason why provider cannot be disconnected (if applicable)"
)
extra_data = serializers.JSONField(required=False, help_text="Additional data from the social provider")
class AvailableProviderSerializer(serializers.Serializer):
"""Serializer for available social provider information."""
id = serializers.CharField(
help_text="Provider ID (e.g., 'google', 'discord')"
)
name = serializers.CharField(
help_text="Human-readable provider name"
)
auth_url = serializers.URLField(
help_text="URL to initiate authentication with this provider"
)
connect_url = serializers.URLField(
help_text="API URL to connect this provider"
)
id = serializers.CharField(help_text="Provider ID (e.g., 'google', 'discord')")
name = serializers.CharField(help_text="Human-readable provider name")
auth_url = serializers.URLField(help_text="URL to initiate authentication with this provider")
connect_url = serializers.URLField(help_text="API URL to connect this provider")
class SocialAuthStatusSerializer(serializers.Serializer):
"""Serializer for comprehensive social authentication status."""
user_id = serializers.IntegerField(
help_text="User's ID"
)
username = serializers.CharField(
help_text="User's username"
)
email = serializers.EmailField(
help_text="User's email address"
)
has_password_auth = serializers.BooleanField(
help_text="Whether user has email/password authentication set up"
)
connected_providers = ConnectedProviderSerializer(
many=True,
help_text="List of connected social providers"
)
total_auth_methods = serializers.IntegerField(
help_text="Total number of authentication methods available"
)
can_disconnect_any = serializers.BooleanField(
help_text="Whether user can safely disconnect any provider"
)
user_id = serializers.IntegerField(help_text="User's ID")
username = serializers.CharField(help_text="User's username")
email = serializers.EmailField(help_text="User's email address")
has_password_auth = serializers.BooleanField(help_text="Whether user has email/password authentication set up")
connected_providers = ConnectedProviderSerializer(many=True, help_text="List of connected social providers")
total_auth_methods = serializers.IntegerField(help_text="Total number of authentication methods available")
can_disconnect_any = serializers.BooleanField(help_text="Whether user can safely disconnect any provider")
requires_password_setup = serializers.BooleanField(
help_text="Whether user needs to set up password before disconnecting"
)
@@ -90,9 +52,7 @@ class SocialAuthStatusSerializer(serializers.Serializer):
class ConnectProviderInputSerializer(serializers.Serializer):
"""Serializer for social provider connection requests."""
provider = serializers.CharField(
help_text="Provider ID to connect (e.g., 'google', 'discord')"
)
provider = serializers.CharField(help_text="Provider ID to connect (e.g., 'google', 'discord')")
def validate_provider(self, value):
"""Validate that the provider is supported and configured."""
@@ -108,93 +68,51 @@ class ConnectProviderInputSerializer(serializers.Serializer):
class ConnectProviderOutputSerializer(serializers.Serializer):
"""Serializer for social provider connection responses."""
success = serializers.BooleanField(
help_text="Whether the connection was successful"
)
message = serializers.CharField(
help_text="Success or error message"
)
provider = serializers.CharField(
help_text="Provider that was connected"
)
auth_url = serializers.URLField(
required=False,
help_text="URL to complete the connection process"
)
success = serializers.BooleanField(help_text="Whether the connection was successful")
message = serializers.CharField(help_text="Success or error message")
provider = serializers.CharField(help_text="Provider that was connected")
auth_url = serializers.URLField(required=False, help_text="URL to complete the connection process")
class DisconnectProviderOutputSerializer(serializers.Serializer):
"""Serializer for social provider disconnection responses."""
success = serializers.BooleanField(
help_text="Whether the disconnection was successful"
)
message = serializers.CharField(
help_text="Success or error message"
)
provider = serializers.CharField(
help_text="Provider that was disconnected"
)
success = serializers.BooleanField(help_text="Whether the disconnection was successful")
message = serializers.CharField(help_text="Success or error message")
provider = serializers.CharField(help_text="Provider that was disconnected")
remaining_providers = serializers.ListField(
child=serializers.CharField(),
help_text="List of remaining connected providers"
)
has_password_auth = serializers.BooleanField(
help_text="Whether user still has password authentication"
child=serializers.CharField(), help_text="List of remaining connected providers"
)
has_password_auth = serializers.BooleanField(help_text="Whether user still has password authentication")
suggestions = serializers.ListField(
child=serializers.CharField(),
required=False,
help_text="Suggestions for maintaining account access (if applicable)"
help_text="Suggestions for maintaining account access (if applicable)",
)
class SocialProviderListOutputSerializer(serializers.Serializer):
"""Serializer for listing available social providers."""
available_providers = AvailableProviderSerializer(
many=True,
help_text="List of available social providers"
)
count = serializers.IntegerField(
help_text="Number of available providers"
)
available_providers = AvailableProviderSerializer(many=True, help_text="List of available social providers")
count = serializers.IntegerField(help_text="Number of available providers")
class ConnectedProvidersListOutputSerializer(serializers.Serializer):
"""Serializer for listing connected social providers."""
connected_providers = ConnectedProviderSerializer(
many=True,
help_text="List of connected social providers"
)
count = serializers.IntegerField(
help_text="Number of connected providers"
)
has_password_auth = serializers.BooleanField(
help_text="Whether user has password authentication"
)
can_disconnect_any = serializers.BooleanField(
help_text="Whether user can safely disconnect any provider"
)
connected_providers = ConnectedProviderSerializer(many=True, help_text="List of connected social providers")
count = serializers.IntegerField(help_text="Number of connected providers")
has_password_auth = serializers.BooleanField(help_text="Whether user has password authentication")
can_disconnect_any = serializers.BooleanField(help_text="Whether user can safely disconnect any provider")
class SocialProviderErrorSerializer(serializers.Serializer):
"""Serializer for social provider error responses."""
error = serializers.CharField(
help_text="Error message"
)
code = serializers.CharField(
required=False,
help_text="Error code for programmatic handling"
)
error = serializers.CharField(help_text="Error message")
code = serializers.CharField(required=False, help_text="Error code for programmatic handling")
suggestions = serializers.ListField(
child=serializers.CharField(),
required=False,
help_text="Suggestions for resolving the error"
)
provider = serializers.CharField(
required=False,
help_text="Provider related to the error (if applicable)"
child=serializers.CharField(), required=False, help_text="Suggestions for resolving the error"
)
provider = serializers.CharField(required=False, help_text="Provider related to the error (if applicable)")

View File

@@ -36,13 +36,10 @@ urlpatterns = [
path("signup/", SignupAPIView.as_view(), name="auth-signup"),
path("logout/", LogoutAPIView.as_view(), name="auth-logout"),
path("user/", CurrentUserAPIView.as_view(), name="auth-current-user"),
# JWT token management
path("token/refresh/", TokenRefreshView.as_view(), name="auth-token-refresh"),
# Social authentication endpoints (dj-rest-auth)
path("social/", include("dj_rest_auth.registration.urls")),
path(
"password/reset/",
PasswordResetAPIView.as_view(),
@@ -58,7 +55,6 @@ urlpatterns = [
SocialProvidersAPIView.as_view(),
name="auth-social-providers",
),
# Social provider management endpoints
path(
"social/providers/available/",
@@ -85,9 +81,7 @@ urlpatterns = [
SocialAuthStatusAPIView.as_view(),
name="auth-social-status",
),
path("status/", AuthStatusAPIView.as_view(), name="auth-status"),
# Email verification endpoints
path(
"verify-email/<str:token>/",
@@ -99,7 +93,6 @@ urlpatterns = [
ResendVerificationAPIView.as_view(),
name="auth-resend-verification",
),
# MFA (Multi-Factor Authentication) endpoints
path("mfa/status/", mfa_views.get_mfa_status, name="auth-mfa-status"),
path("mfa/totp/setup/", mfa_views.setup_totp, name="auth-mfa-totp-setup"),

View File

@@ -85,9 +85,7 @@ def _get_underlying_request(request: Request) -> HttpRequest:
# Helper: encapsulate user lookup + authenticate to reduce complexity in view
def _authenticate_user_by_lookup(
email_or_username: str, password: str, request: Request
) -> UserModel | None:
def _authenticate_user_by_lookup(email_or_username: str, password: str, request: Request) -> UserModel | None:
"""
Try a single optimized query to find a user by email OR username then authenticate.
Returns authenticated user or None.
@@ -154,7 +152,7 @@ class LoginAPIView(APIView):
# instantiate mixin before calling to avoid type-mismatch in static analysis
TurnstileMixin().validate_turnstile(request)
except ValidationError as e:
return Response({"error": str(e)}, status=status.HTTP_400_BAD_REQUEST)
return Response({"detail": str(e)}, status=status.HTTP_400_BAD_REQUEST)
except Exception:
# If mixin doesn't do anything, continue
pass
@@ -168,7 +166,7 @@ class LoginAPIView(APIView):
if not email_or_username or not password:
return Response(
{"error": "username and password are required"},
{"detail": "username and password are required"},
status=status.HTTP_400_BAD_REQUEST,
)
@@ -177,8 +175,7 @@ class LoginAPIView(APIView):
if user:
if getattr(user, "is_active", False):
# pass a real HttpRequest to Django login with backend specified
login(_get_underlying_request(request), user,
backend='django.contrib.auth.backends.ModelBackend')
login(_get_underlying_request(request), user, backend="django.contrib.auth.backends.ModelBackend")
# Generate JWT tokens
from rest_framework_simplejwt.tokens import RefreshToken
@@ -191,22 +188,22 @@ class LoginAPIView(APIView):
"access": str(access_token),
"refresh": str(refresh),
"user": user,
"message": "Login successful",
"detail": "Login successful",
}
)
return Response(response_serializer.data)
else:
return Response(
{
"error": "Email verification required",
"message": "Please verify your email address before logging in. Check your email for a verification link.",
"email_verification_required": True
"detail": "Please verify your email address before logging in. Check your email for a verification link.",
"code": "EMAIL_VERIFICATION_REQUIRED",
"email_verification_required": True,
},
status=status.HTTP_400_BAD_REQUEST,
)
else:
return Response(
{"error": "Invalid credentials"},
{"detail": "Invalid credentials"},
status=status.HTTP_400_BAD_REQUEST,
)
@@ -237,7 +234,7 @@ class SignupAPIView(APIView):
# instantiate mixin before calling to avoid type-mismatch in static analysis
TurnstileMixin().validate_turnstile(request)
except ValidationError as e:
return Response({"error": str(e)}, status=status.HTTP_400_BAD_REQUEST)
return Response({"detail": str(e)}, status=status.HTTP_400_BAD_REQUEST)
except Exception:
# If mixin doesn't do anything, continue
pass
@@ -252,7 +249,7 @@ class SignupAPIView(APIView):
"access": None,
"refresh": None,
"user": user,
"message": "Registration successful. Please check your email to verify your account.",
"detail": "Registration successful. Please check your email to verify your account.",
"email_verification_required": True,
}
)
@@ -282,18 +279,18 @@ class LogoutAPIView(APIView):
try:
# Get refresh token from request data with proper type handling
refresh_token = None
if hasattr(request, 'data') and request.data is not None:
data = getattr(request, 'data', {})
if hasattr(data, 'get'):
if hasattr(request, "data") and request.data is not None:
data = getattr(request, "data", {})
if hasattr(data, "get"):
refresh_token = data.get("refresh")
if refresh_token and isinstance(refresh_token, str):
# Blacklist the refresh token
from rest_framework_simplejwt.tokens import RefreshToken
try:
# Create RefreshToken from string and blacklist it
refresh_token_obj = RefreshToken(
refresh_token) # type: ignore[arg-type]
refresh_token_obj = RefreshToken(refresh_token) # type: ignore[arg-type]
refresh_token_obj.blacklist()
except Exception:
# Token might be invalid or already blacklisted
@@ -306,14 +303,10 @@ class LogoutAPIView(APIView):
# Logout from session using the underlying HttpRequest
logout(_get_underlying_request(request))
response_serializer = LogoutOutputSerializer(
{"message": "Logout successful"}
)
response_serializer = LogoutOutputSerializer({"detail": "Logout successful"})
return Response(response_serializer.data)
except Exception:
return Response(
{"error": "Logout failed"}, status=status.HTTP_500_INTERNAL_SERVER_ERROR
)
return Response({"detail": "Logout failed"}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
@extend_schema_view(
@@ -357,15 +350,11 @@ class PasswordResetAPIView(APIView):
serializer_class = PasswordResetInputSerializer
def post(self, request: Request) -> Response:
serializer = PasswordResetInputSerializer(
data=request.data, context={"request": request}
)
serializer = PasswordResetInputSerializer(data=request.data, context={"request": request})
if serializer.is_valid():
serializer.save()
response_serializer = PasswordResetOutputSerializer(
{"detail": "Password reset email sent"}
)
response_serializer = PasswordResetOutputSerializer({"detail": "Password reset email sent"})
return Response(response_serializer.data)
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
@@ -391,15 +380,11 @@ class PasswordChangeAPIView(APIView):
serializer_class = PasswordChangeInputSerializer
def post(self, request: Request) -> Response:
serializer = PasswordChangeInputSerializer(
data=request.data, context={"request": request}
)
serializer = PasswordChangeInputSerializer(data=request.data, context={"request": request})
if serializer.is_valid():
serializer.save()
response_serializer = PasswordChangeOutputSerializer(
{"detail": "Password changed successfully"}
)
response_serializer = PasswordChangeOutputSerializer({"detail": "Password changed successfully"})
return Response(response_serializer.data)
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
@@ -443,13 +428,9 @@ class SocialProvidersAPIView(APIView):
for social_app in social_apps:
try:
provider_name = (
social_app.name or getattr(social_app, "provider", "").title()
)
provider_name = social_app.name or getattr(social_app, "provider", "").title()
auth_url = request.build_absolute_uri(
f"/accounts/{social_app.provider}/login/"
)
auth_url = request.build_absolute_uri(f"/accounts/{social_app.provider}/login/")
providers_list.append(
{
@@ -532,7 +513,7 @@ class AvailableProvidersAPIView(APIView):
"name": "Discord",
"login_url": "/auth/social/discord/",
"connect_url": "/auth/social/connect/discord/",
}
},
]
serializer = AvailableProviderSerializer(providers, many=True)
@@ -585,31 +566,29 @@ class ConnectProviderAPIView(APIView):
def post(self, request: Request, provider: str) -> Response:
# Validate provider
if provider not in ['google', 'discord']:
if provider not in ["google", "discord"]:
return Response(
{
"success": False,
"error": "INVALID_PROVIDER",
"message": f"Provider '{provider}' is not supported",
"suggestions": ["Use 'google' or 'discord'"]
"detail": f"Provider '{provider}' is not supported",
"code": "INVALID_PROVIDER",
"suggestions": ["Use 'google' or 'discord'"],
},
status=status.HTTP_400_BAD_REQUEST
status=status.HTTP_400_BAD_REQUEST,
)
serializer = ConnectProviderInputSerializer(data=request.data)
if not serializer.is_valid():
return Response(
{
"success": False,
"error": "VALIDATION_ERROR",
"message": "Invalid request data",
"detail": "Invalid request data",
"code": "VALIDATION_ERROR",
"details": serializer.errors,
"suggestions": ["Provide a valid access_token"]
"suggestions": ["Provide a valid access_token"],
},
status=status.HTTP_400_BAD_REQUEST
status=status.HTTP_400_BAD_REQUEST,
)
access_token = serializer.validated_data['access_token']
access_token = serializer.validated_data["access_token"]
try:
service = SocialProviderService()
@@ -622,14 +601,14 @@ class ConnectProviderAPIView(APIView):
return Response(
{
"success": False,
"error": "CONNECTION_FAILED",
"detail": "CONNECTION_FAILED",
"message": str(e),
"suggestions": [
"Verify the access token is valid",
"Ensure the provider account is not already connected to another user"
]
"Ensure the provider account is not already connected to another user",
],
},
status=status.HTTP_400_BAD_REQUEST
status=status.HTTP_400_BAD_REQUEST,
)
@@ -653,35 +632,33 @@ class DisconnectProviderAPIView(APIView):
def post(self, request: Request, provider: str) -> Response:
# Validate provider
if provider not in ['google', 'discord']:
if provider not in ["google", "discord"]:
return Response(
{
"success": False,
"error": "INVALID_PROVIDER",
"message": f"Provider '{provider}' is not supported",
"suggestions": ["Use 'google' or 'discord'"]
"detail": f"Provider '{provider}' is not supported",
"code": "INVALID_PROVIDER",
"suggestions": ["Use 'google' or 'discord'"],
},
status=status.HTTP_400_BAD_REQUEST
status=status.HTTP_400_BAD_REQUEST,
)
try:
service = SocialProviderService()
# Check if disconnection is safe
can_disconnect, reason = service.can_disconnect_provider(
request.user, provider)
can_disconnect, reason = service.can_disconnect_provider(request.user, provider)
if not can_disconnect:
return Response(
{
"success": False,
"error": "UNSAFE_DISCONNECTION",
"detail": "UNSAFE_DISCONNECTION",
"message": reason,
"suggestions": [
"Set up email/password authentication before disconnecting",
"Connect another social provider before disconnecting this one"
]
"Connect another social provider before disconnecting this one",
],
},
status=status.HTTP_400_BAD_REQUEST
status=status.HTTP_400_BAD_REQUEST,
)
# Perform disconnection
@@ -694,14 +671,14 @@ class DisconnectProviderAPIView(APIView):
return Response(
{
"success": False,
"error": "DISCONNECTION_FAILED",
"detail": "DISCONNECTION_FAILED",
"message": str(e),
"suggestions": [
"Verify the provider is currently connected",
"Ensure you have alternative authentication methods"
]
"Ensure you have alternative authentication methods",
],
},
status=status.HTTP_400_BAD_REQUEST
status=status.HTTP_400_BAD_REQUEST,
)
@@ -755,7 +732,7 @@ class EmailVerificationAPIView(APIView):
from apps.accounts.models import EmailVerification
try:
verification = EmailVerification.objects.select_related('user').get(token=token)
verification = EmailVerification.objects.select_related("user").get(token=token)
user = verification.user
# Activate the user
@@ -765,16 +742,10 @@ class EmailVerificationAPIView(APIView):
# Delete the verification record
verification.delete()
return Response({
"message": "Email verified successfully. You can now log in.",
"success": True
})
return Response({"detail": "Email verified successfully. You can now log in.", "success": True})
except EmailVerification.DoesNotExist:
return Response(
{"error": "Invalid or expired verification token"},
status=status.HTTP_404_NOT_FOUND
)
return Response({"detail": "Invalid or expired verification token"}, status=status.HTTP_404_NOT_FOUND)
@extend_schema_view(
@@ -803,27 +774,20 @@ class ResendVerificationAPIView(APIView):
from apps.accounts.models import EmailVerification
email = request.data.get('email')
email = request.data.get("email")
if not email:
return Response(
{"error": "Email address is required"},
status=status.HTTP_400_BAD_REQUEST
)
return Response({"detail": "Email address is required"}, status=status.HTTP_400_BAD_REQUEST)
try:
user = UserModel.objects.get(email__iexact=email.strip().lower())
# Don't resend if user is already active
if user.is_active:
return Response(
{"error": "Email is already verified"},
status=status.HTTP_400_BAD_REQUEST
)
return Response({"detail": "Email is already verified"}, status=status.HTTP_400_BAD_REQUEST)
# Create or update verification record
verification, created = EmailVerification.objects.get_or_create(
user=user,
defaults={'token': get_random_string(64)}
user=user, defaults={"token": get_random_string(64)}
)
if not created:
@@ -833,9 +797,7 @@ class ResendVerificationAPIView(APIView):
# Send verification email
site = get_current_site(_get_underlying_request(request))
verification_url = request.build_absolute_uri(
f"/api/v1/auth/verify-email/{verification.token}/"
)
verification_url = request.build_absolute_uri(f"/api/v1/auth/verify-email/{verification.token}/")
try:
EmailService.send_email(
@@ -855,27 +817,21 @@ The ThrillWiki Team
site=site,
)
return Response({
"message": "Verification email sent successfully",
"success": True
})
return Response({"detail": "Verification email sent successfully", "success": True})
except Exception as e:
import logging
logger = logging.getLogger(__name__)
logger.error(f"Failed to send verification email to {user.email}: {e}")
return Response(
{"error": "Failed to send verification email"},
status=status.HTTP_500_INTERNAL_SERVER_ERROR
{"detail": "Failed to send verification email"}, status=status.HTTP_500_INTERNAL_SERVER_ERROR
)
except UserModel.DoesNotExist:
# Don't reveal whether email exists
return Response({
"message": "If the email exists, a verification email has been sent",
"success": True
})
return Response({"detail": "If the email exists, a verification email has been sent", "success": True})
# Note: User Profile, Top List, and Top List Item ViewSets are now handled

View File

@@ -8,7 +8,6 @@ Caching Strategy:
- EntityNotFoundView: No caching - POST requests with context-specific data
"""
import contextlib
from drf_spectacular.utils import extend_schema
@@ -82,9 +81,7 @@ class EntityFuzzySearchView(APIView):
try:
# Parse request data
query = request.data.get("query", "").strip()
entity_types_raw = request.data.get(
"entity_types", ["park", "ride", "company"]
)
entity_types_raw = request.data.get("entity_types", ["park", "ride", "company"])
include_suggestions = request.data.get("include_suggestions", True)
# Validate query
@@ -92,7 +89,7 @@ class EntityFuzzySearchView(APIView):
return Response(
{
"success": False,
"error": "Query must be at least 2 characters long",
"detail": "Query must be at least 2 characters long",
"code": "INVALID_QUERY",
},
status=status.HTTP_400_BAD_REQUEST,
@@ -120,9 +117,7 @@ class EntityFuzzySearchView(APIView):
"query": query,
"matches": [match.to_dict() for match in matches],
"user_authenticated": (
request.user.is_authenticated
if hasattr(request.user, "is_authenticated")
else False
request.user.is_authenticated if hasattr(request.user, "is_authenticated") else False
),
}
@@ -143,7 +138,7 @@ class EntityFuzzySearchView(APIView):
return Response(
{
"success": False,
"error": f"Internal server error: {str(e)}",
"detail": f"Internal server error: {str(e)}",
"code": "INTERNAL_ERROR",
},
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
@@ -192,7 +187,7 @@ class EntityNotFoundView(APIView):
return Response(
{
"success": False,
"error": "original_query is required",
"detail": "original_query is required",
"code": "MISSING_QUERY",
},
status=status.HTTP_400_BAD_REQUEST,
@@ -233,9 +228,7 @@ class EntityNotFoundView(APIView):
"context": context,
"matches": [match.to_dict() for match in matches],
"user_authenticated": (
request.user.is_authenticated
if hasattr(request.user, "is_authenticated")
else False
request.user.is_authenticated if hasattr(request.user, "is_authenticated") else False
),
"has_matches": len(matches) > 0,
}
@@ -257,7 +250,7 @@ class EntityNotFoundView(APIView):
return Response(
{
"success": False,
"error": f"Internal server error: {str(e)}",
"detail": f"Internal server error: {str(e)}",
"code": "INTERNAL_ERROR",
},
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
@@ -297,9 +290,7 @@ class QuickEntitySuggestionView(APIView):
limit = min(int(request.GET.get("limit", 5)), 10) # Cap at 10
if not query or len(query) < 2:
return Response(
{"suggestions": [], "query": query}, status=status.HTTP_200_OK
)
return Response({"suggestions": [], "query": query}, status=status.HTTP_200_OK)
# Parse entity types
entity_types = []
@@ -312,9 +303,7 @@ class QuickEntitySuggestionView(APIView):
entity_types = [EntityType.PARK, EntityType.RIDE, EntityType.COMPANY]
# Get fuzzy matches
matches, _ = entity_fuzzy_matcher.find_entity(
query=query, entity_types=entity_types, user=request.user
)
matches, _ = entity_fuzzy_matcher.find_entity(query=query, entity_types=entity_types, user=request.user)
# Format as simple suggestions
suggestions = []
@@ -337,15 +326,13 @@ class QuickEntitySuggestionView(APIView):
except Exception as e:
return Response(
{"suggestions": [], "query": request.GET.get("q", ""), "error": str(e)},
{"suggestions": [], "query": request.GET.get("q", ""), "detail": str(e)},
status=status.HTTP_200_OK,
) # Return 200 even on errors for autocomplete
# Utility function for other views to use
def get_entity_suggestions(
query: str, entity_types: list[str] | None = None, user=None
):
def get_entity_suggestions(query: str, entity_types: list[str] | None = None, user=None):
"""
Utility function for other Django views to get entity suggestions.
@@ -370,8 +357,6 @@ def get_entity_suggestions(
if not parsed_types:
parsed_types = [EntityType.PARK, EntityType.RIDE, EntityType.COMPANY]
return entity_fuzzy_matcher.find_entity(
query=query, entity_types=parsed_types, user=user
)
return entity_fuzzy_matcher.find_entity(query=query, entity_types=parsed_types, user=user)
except Exception:
return [], None

View File

@@ -76,7 +76,7 @@ class SendEmailView(APIView):
if not all([to, subject, text]):
return Response(
{
"error": "Missing required fields",
"detail": "Missing required fields",
"required_fields": ["to", "subject", "text"],
},
status=status.HTTP_400_BAD_REQUEST,
@@ -96,11 +96,9 @@ class SendEmailView(APIView):
)
return Response(
{"message": "Email sent successfully", "response": response},
{"detail": "Email sent successfully", "response": response},
status=status.HTTP_200_OK,
)
except Exception as e:
return Response(
{"error": str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR
)
return Response({"detail": str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)

View File

@@ -37,21 +37,11 @@ class _FallbackSerializer(drf_serializers.Serializer):
return {}
ParkHistoryEventSerializer = getattr(
history_serializers, "ParkHistoryEventSerializer", _FallbackSerializer
)
RideHistoryEventSerializer = getattr(
history_serializers, "RideHistoryEventSerializer", _FallbackSerializer
)
ParkHistoryOutputSerializer = getattr(
history_serializers, "ParkHistoryOutputSerializer", _FallbackSerializer
)
RideHistoryOutputSerializer = getattr(
history_serializers, "RideHistoryOutputSerializer", _FallbackSerializer
)
UnifiedHistoryTimelineSerializer = getattr(
history_serializers, "UnifiedHistoryTimelineSerializer", _FallbackSerializer
)
ParkHistoryEventSerializer = getattr(history_serializers, "ParkHistoryEventSerializer", _FallbackSerializer)
RideHistoryEventSerializer = getattr(history_serializers, "RideHistoryEventSerializer", _FallbackSerializer)
ParkHistoryOutputSerializer = getattr(history_serializers, "ParkHistoryOutputSerializer", _FallbackSerializer)
RideHistoryOutputSerializer = getattr(history_serializers, "RideHistoryOutputSerializer", _FallbackSerializer)
UnifiedHistoryTimelineSerializer = getattr(history_serializers, "UnifiedHistoryTimelineSerializer", _FallbackSerializer)
# --- Constants for model strings to avoid duplication ---
PARK_MODEL = "parks.park"
@@ -201,18 +191,14 @@ class ParkHistoryViewSet(ReadOnlyModelViewSet):
# Base queryset for park events
queryset = (
pghistory.models.Events.objects.filter(
pgh_model__in=[PARK_MODEL], pgh_obj_id=getattr(park, "id", None)
)
pghistory.models.Events.objects.filter(pgh_model__in=[PARK_MODEL], pgh_obj_id=getattr(park, "id", None))
.select_related()
.order_by("-pgh_created_at")
)
# Apply list filters via helper to reduce complexity
if self.action == "list":
queryset = _apply_list_filters(
queryset, cast(Request, self.request), default_limit=50, max_limit=500
)
queryset = _apply_list_filters(queryset, cast(Request, self.request), default_limit=50, max_limit=500)
return queryset
@@ -322,18 +308,14 @@ class RideHistoryViewSet(ReadOnlyModelViewSet):
# Base queryset for ride events
queryset = (
pghistory.models.Events.objects.filter(
pgh_model__in=RIDE_MODELS, pgh_obj_id=getattr(ride, "id", None)
)
pghistory.models.Events.objects.filter(pgh_model__in=RIDE_MODELS, pgh_obj_id=getattr(ride, "id", None))
.select_related()
.order_by("-pgh_created_at")
)
# Apply list filters via helper
if self.action == "list":
queryset = _apply_list_filters(
queryset, cast(Request, self.request), default_limit=50, max_limit=500
)
queryset = _apply_list_filters(queryset, cast(Request, self.request), default_limit=50, max_limit=500)
return queryset
@@ -462,9 +444,7 @@ class UnifiedHistoryViewSet(ReadOnlyModelViewSet):
# Apply shared list filters when serving the list action
if self.action == "list":
queryset = _apply_list_filters(
queryset, cast(Request, self.request), default_limit=100, max_limit=1000
)
queryset = _apply_list_filters(queryset, cast(Request, self.request), default_limit=100, max_limit=1000)
return queryset
@@ -477,9 +457,7 @@ class UnifiedHistoryViewSet(ReadOnlyModelViewSet):
events = list(self.get_queryset()) # evaluate for counts / earliest/latest use
# Summary statistics across all tracked models
total_events = pghistory.models.Events.objects.filter(
pgh_model__in=ALL_TRACKED_MODELS
).count()
total_events = pghistory.models.Events.objects.filter(pgh_model__in=ALL_TRACKED_MODELS).count()
event_type_counts = (
pghistory.models.Events.objects.filter(pgh_model__in=ALL_TRACKED_MODELS)
@@ -497,12 +475,8 @@ class UnifiedHistoryViewSet(ReadOnlyModelViewSet):
"summary": {
"total_events": total_events,
"events_returned": len(events),
"event_type_breakdown": {
item["pgh_label"]: item["count"] for item in event_type_counts
},
"model_type_breakdown": {
item["pgh_model"]: item["count"] for item in model_type_counts
},
"event_type_breakdown": {item["pgh_label"]: item["count"] for item in event_type_counts},
"model_type_breakdown": {item["pgh_model"]: item["count"] for item in model_type_counts},
"time_range": {
"earliest": events[-1].pgh_created_at if events else None,
"latest": events[0].pgh_created_at if events else None,

View File

@@ -11,6 +11,7 @@ from apps.core.utils.cloudflare import get_direct_upload_url
logger = logging.getLogger(__name__)
class GenerateUploadURLView(APIView):
permission_classes = [IsAuthenticated]
@@ -21,19 +22,10 @@ class GenerateUploadURLView(APIView):
return Response(result, status=status.HTTP_200_OK)
except ImproperlyConfigured as e:
logger.error(f"Configuration Error: {e}")
return Response(
{"detail": "Server configuration error."},
status=status.HTTP_500_INTERNAL_SERVER_ERROR
)
return Response({"detail": "Server configuration error."}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
except requests.RequestException as e:
logger.error(f"Cloudflare API Error: {e}")
return Response(
{"detail": "Failed to generate upload URL."},
status=status.HTTP_502_BAD_GATEWAY
)
return Response({"detail": "Failed to generate upload URL."}, status=status.HTTP_502_BAD_GATEWAY)
except Exception:
logger.exception("Unexpected error generating upload URL")
return Response(
{"detail": "An unexpected error occurred."},
status=status.HTTP_500_INTERNAL_SERVER_ERROR
)
return Response({"detail": "An unexpected error occurred."}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)

View File

@@ -162,16 +162,13 @@ class MapLocationsAPIView(APIView):
if not all([north, south, east, west]):
return None
try:
return Polygon.from_bbox(
(float(west), float(south), float(east), float(north))
)
return Polygon.from_bbox((float(west), float(south), float(east), float(north)))
except (ValueError, TypeError):
return None
def _serialize_park_location(self, park) -> dict:
"""Serialize park location data."""
location = park.location if hasattr(
park, "location") and park.location else None
location = park.location if hasattr(park, "location") and park.location else None
return {
"city": location.city if location else "",
"state": location.state if location else "",
@@ -181,8 +178,7 @@ class MapLocationsAPIView(APIView):
def _serialize_park_data(self, park) -> dict:
"""Serialize park data for map response."""
location = park.location if hasattr(
park, "location") and park.location else None
location = park.location if hasattr(park, "location") and park.location else None
return {
"id": park.id,
"type": "park",
@@ -195,9 +191,7 @@ class MapLocationsAPIView(APIView):
"stats": {
"coaster_count": park.coaster_count or 0,
"ride_count": park.ride_count or 0,
"average_rating": (
float(park.average_rating) if park.average_rating else None
),
"average_rating": (float(park.average_rating) if park.average_rating else None),
},
}
@@ -206,14 +200,10 @@ class MapLocationsAPIView(APIView):
if "park" not in params["types"]:
return []
parks_query = Park.objects.select_related(
"location", "operator"
).filter(location__point__isnull=False)
parks_query = Park.objects.select_related("location", "operator").filter(location__point__isnull=False)
# Apply bounds filtering
bounds_polygon = self._create_bounds_polygon(
params["north"], params["south"], params["east"], params["west"]
)
bounds_polygon = self._create_bounds_polygon(params["north"], params["south"], params["east"], params["west"])
if bounds_polygon:
parks_query = parks_query.filter(location__point__within=bounds_polygon)
@@ -229,11 +219,7 @@ class MapLocationsAPIView(APIView):
def _serialize_ride_location(self, ride) -> dict:
"""Serialize ride location data."""
location = (
ride.park.location
if hasattr(ride.park, "location") and ride.park.location
else None
)
location = ride.park.location if hasattr(ride.park, "location") and ride.park.location else None
return {
"city": location.city if location else "",
"state": location.state if location else "",
@@ -243,11 +229,7 @@ class MapLocationsAPIView(APIView):
def _serialize_ride_data(self, ride) -> dict:
"""Serialize ride data for map response."""
location = (
ride.park.location
if hasattr(ride.park, "location") and ride.park.location
else None
)
location = ride.park.location if hasattr(ride.park, "location") and ride.park.location else None
return {
"id": ride.id,
"type": "ride",
@@ -259,9 +241,7 @@ class MapLocationsAPIView(APIView):
"location": self._serialize_ride_location(ride),
"stats": {
"category": ride.get_category_display() if ride.category else None,
"average_rating": (
float(ride.average_rating) if ride.average_rating else None
),
"average_rating": (float(ride.average_rating) if ride.average_rating else None),
"park_name": ride.park.name,
},
}
@@ -271,17 +251,14 @@ class MapLocationsAPIView(APIView):
if "ride" not in params["types"]:
return []
rides_query = Ride.objects.select_related(
"park__location", "manufacturer"
).filter(park__location__point__isnull=False)
rides_query = Ride.objects.select_related("park__location", "manufacturer").filter(
park__location__point__isnull=False
)
# Apply bounds filtering
bounds_polygon = self._create_bounds_polygon(
params["north"], params["south"], params["east"], params["west"]
)
bounds_polygon = self._create_bounds_polygon(params["north"], params["south"], params["east"], params["west"])
if bounds_polygon:
rides_query = rides_query.filter(
park__location__point__within=bounds_polygon)
rides_query = rides_query.filter(park__location__point__within=bounds_polygon)
# Apply text search
if params["query"]:
@@ -335,7 +312,7 @@ class MapLocationsAPIView(APIView):
# Use EnhancedCacheService for improved caching with monitoring
cache_service = EnhancedCacheService()
cached_result = cache_service.get_cached_api_response('map_locations', params)
cached_result = cache_service.get_cached_api_response("map_locations", params)
if cached_result:
logger.debug(f"Cache hit for map_locations with key: {cache_key}")
return Response(cached_result)
@@ -349,7 +326,7 @@ class MapLocationsAPIView(APIView):
result = self._build_response(locations, params)
# Cache result for 5 minutes using EnhancedCacheService
cache_service.cache_api_response('map_locations', params, result, timeout=300)
cache_service.cache_api_response("map_locations", params, result, timeout=300)
logger.debug(f"Cached map_locations result for key: {cache_key}")
return Response(result)
@@ -357,7 +334,7 @@ class MapLocationsAPIView(APIView):
except Exception as e:
logger.error(f"Error in MapLocationsAPIView: {str(e)}", exc_info=True)
return Response(
{"status": "error", "message": "Failed to retrieve map locations"},
{"status": "error", "detail": "Failed to retrieve map locations"},
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
@@ -401,34 +378,28 @@ class MapLocationDetailAPIView(APIView):
permission_classes = [AllowAny]
@cache_api_response(timeout=1800, key_prefix="map_detail")
def get(
self, request: HttpRequest, location_type: str, location_id: int
) -> Response:
def get(self, request: HttpRequest, location_type: str, location_id: int) -> Response:
"""Get detailed information for a specific location."""
try:
if location_type == "park":
try:
obj = Park.objects.select_related("location", "operator").get(
id=location_id
)
obj = Park.objects.select_related("location", "operator").get(id=location_id)
except Park.DoesNotExist:
return Response(
{"status": "error", "message": "Park not found"},
{"status": "error", "detail": "Park not found"},
status=status.HTTP_404_NOT_FOUND,
)
elif location_type == "ride":
try:
obj = Ride.objects.select_related(
"park__location", "manufacturer"
).get(id=location_id)
obj = Ride.objects.select_related("park__location", "manufacturer").get(id=location_id)
except Ride.DoesNotExist:
return Response(
{"status": "error", "message": "Ride not found"},
{"status": "error", "detail": "Ride not found"},
status=status.HTTP_404_NOT_FOUND,
)
else:
return Response(
{"status": "error", "message": "Invalid location type"},
{"status": "error", "detail": "Invalid location type"},
status=status.HTTP_400_BAD_REQUEST,
)
@@ -440,59 +411,27 @@ class MapLocationDetailAPIView(APIView):
"name": obj.name,
"slug": obj.slug,
"description": obj.description,
"latitude": (
obj.location.latitude
if hasattr(obj, "location") and obj.location
else None
),
"longitude": (
obj.location.longitude
if hasattr(obj, "location") and obj.location
else None
),
"latitude": (obj.location.latitude if hasattr(obj, "location") and obj.location else None),
"longitude": (obj.location.longitude if hasattr(obj, "location") and obj.location else None),
"status": obj.status,
"location": {
"street_address": (
obj.location.street_address
if hasattr(obj, "location") and obj.location
else ""
),
"city": (
obj.location.city
if hasattr(obj, "location") and obj.location
else ""
),
"state": (
obj.location.state
if hasattr(obj, "location") and obj.location
else ""
),
"country": (
obj.location.country
if hasattr(obj, "location") and obj.location
else ""
),
"postal_code": (
obj.location.postal_code
if hasattr(obj, "location") and obj.location
else ""
obj.location.street_address if hasattr(obj, "location") and obj.location else ""
),
"city": (obj.location.city if hasattr(obj, "location") and obj.location else ""),
"state": (obj.location.state if hasattr(obj, "location") and obj.location else ""),
"country": (obj.location.country if hasattr(obj, "location") and obj.location else ""),
"postal_code": (obj.location.postal_code if hasattr(obj, "location") and obj.location else ""),
"formatted_address": (
obj.location.formatted_address
if hasattr(obj, "location") and obj.location
else ""
obj.location.formatted_address if hasattr(obj, "location") and obj.location else ""
),
},
"stats": {
"coaster_count": obj.coaster_count or 0,
"ride_count": obj.ride_count or 0,
"average_rating": (
float(obj.average_rating) if obj.average_rating else None
),
"average_rating": (float(obj.average_rating) if obj.average_rating else None),
"size_acres": float(obj.size_acres) if obj.size_acres else None,
"opening_date": (
obj.opening_date.isoformat() if obj.opening_date else None
),
"opening_date": (obj.opening_date.isoformat() if obj.opening_date else None),
},
"nearby_locations": [], # See FUTURE_WORK.md - THRILLWIKI-107
}
@@ -504,14 +443,10 @@ class MapLocationDetailAPIView(APIView):
"slug": obj.slug,
"description": obj.description,
"latitude": (
obj.park.location.latitude
if hasattr(obj.park, "location") and obj.park.location
else None
obj.park.location.latitude if hasattr(obj.park, "location") and obj.park.location else None
),
"longitude": (
obj.park.location.longitude
if hasattr(obj.park, "location") and obj.park.location
else None
obj.park.location.longitude if hasattr(obj.park, "location") and obj.park.location else None
),
"status": obj.status,
"location": {
@@ -520,25 +455,15 @@ class MapLocationDetailAPIView(APIView):
if hasattr(obj.park, "location") and obj.park.location
else ""
),
"city": (
obj.park.location.city
if hasattr(obj.park, "location") and obj.park.location
else ""
),
"city": (obj.park.location.city if hasattr(obj.park, "location") and obj.park.location else ""),
"state": (
obj.park.location.state
if hasattr(obj.park, "location") and obj.park.location
else ""
obj.park.location.state if hasattr(obj.park, "location") and obj.park.location else ""
),
"country": (
obj.park.location.country
if hasattr(obj.park, "location") and obj.park.location
else ""
obj.park.location.country if hasattr(obj.park, "location") and obj.park.location else ""
),
"postal_code": (
obj.park.location.postal_code
if hasattr(obj.park, "location") and obj.park.location
else ""
obj.park.location.postal_code if hasattr(obj.park, "location") and obj.park.location else ""
),
"formatted_address": (
obj.park.location.formatted_address
@@ -547,19 +472,11 @@ class MapLocationDetailAPIView(APIView):
),
},
"stats": {
"category": (
obj.get_category_display() if obj.category else None
),
"average_rating": (
float(obj.average_rating) if obj.average_rating else None
),
"category": (obj.get_category_display() if obj.category else None),
"average_rating": (float(obj.average_rating) if obj.average_rating else None),
"park_name": obj.park.name,
"opening_date": (
obj.opening_date.isoformat() if obj.opening_date else None
),
"manufacturer": (
obj.manufacturer.name if obj.manufacturer else None
),
"opening_date": (obj.opening_date.isoformat() if obj.opening_date else None),
"manufacturer": (obj.manufacturer.name if obj.manufacturer else None),
},
"nearby_locations": [], # See FUTURE_WORK.md - THRILLWIKI-107
}
@@ -574,7 +491,7 @@ class MapLocationDetailAPIView(APIView):
except Exception as e:
logger.error(f"Error in MapLocationDetailAPIView: {str(e)}", exc_info=True)
return Response(
{"status": "error", "message": "Failed to retrieve location details"},
{"status": "error", "detail": "Failed to retrieve location details"},
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
@@ -640,7 +557,7 @@ class MapSearchAPIView(APIView):
return Response(
{
"status": "error",
"message": "Search query 'q' parameter is required",
"detail": "Search query 'q' parameter is required",
},
status=status.HTTP_400_BAD_REQUEST,
)
@@ -672,30 +589,16 @@ class MapSearchAPIView(APIView):
"name": park.name,
"slug": park.slug,
"latitude": (
park.location.latitude
if hasattr(park, "location") and park.location
else None
park.location.latitude if hasattr(park, "location") and park.location else None
),
"longitude": (
park.location.longitude
if hasattr(park, "location") and park.location
else None
park.location.longitude if hasattr(park, "location") and park.location else None
),
"location": {
"city": (
park.location.city
if hasattr(park, "location") and park.location
else ""
),
"state": (
park.location.state
if hasattr(park, "location") and park.location
else ""
),
"city": (park.location.city if hasattr(park, "location") and park.location else ""),
"state": (park.location.state if hasattr(park, "location") and park.location else ""),
"country": (
park.location.country
if hasattr(park, "location") and park.location
else ""
park.location.country if hasattr(park, "location") and park.location else ""
),
},
"relevance_score": 1.0, # See FUTURE_WORK.md - THRILLWIKI-108
@@ -734,20 +637,17 @@ class MapSearchAPIView(APIView):
"location": {
"city": (
ride.park.location.city
if hasattr(ride.park, "location")
and ride.park.location
if hasattr(ride.park, "location") and ride.park.location
else ""
),
"state": (
ride.park.location.state
if hasattr(ride.park, "location")
and ride.park.location
if hasattr(ride.park, "location") and ride.park.location
else ""
),
"country": (
ride.park.location.country
if hasattr(ride.park, "location")
and ride.park.location
if hasattr(ride.park, "location") and ride.park.location
else ""
),
},
@@ -776,7 +676,7 @@ class MapSearchAPIView(APIView):
except Exception as e:
logger.error(f"Error in MapSearchAPIView: {str(e)}", exc_info=True)
return Response(
{"status": "error", "message": "Search failed due to internal error"},
{"status": "error", "detail": "Search failed due to internal error"},
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
@@ -848,8 +748,7 @@ class MapBoundsAPIView(APIView):
if not all([north_str, south_str, east_str, west_str]):
return Response(
{"status": "error",
"message": "All bounds parameters (north, south, east, west) are required"},
{"status": "error", "detail": "All bounds parameters (north, south, east, west) are required"},
status=status.HTTP_400_BAD_REQUEST,
)
@@ -860,7 +759,7 @@ class MapBoundsAPIView(APIView):
west = float(west_str) if west_str else 0.0
except (TypeError, ValueError):
return Response(
{"status": "error", "message": "Invalid bounds parameters"},
{"status": "error", "detail": "Invalid bounds parameters"},
status=status.HTTP_400_BAD_REQUEST,
)
@@ -869,7 +768,7 @@ class MapBoundsAPIView(APIView):
return Response(
{
"status": "error",
"message": "North bound must be greater than south bound",
"detail": "North bound must be greater than south bound",
},
status=status.HTTP_400_BAD_REQUEST,
)
@@ -878,7 +777,7 @@ class MapBoundsAPIView(APIView):
return Response(
{
"status": "error",
"message": "West bound must be less than east bound",
"detail": "West bound must be less than east bound",
},
status=status.HTTP_400_BAD_REQUEST,
)
@@ -891,9 +790,7 @@ class MapBoundsAPIView(APIView):
# Get parks within bounds
if "park" in types:
parks_query = Park.objects.select_related("location").filter(
location__point__within=bounds_polygon
)
parks_query = Park.objects.select_related("location").filter(location__point__within=bounds_polygon)
for park in parks_query[:100]: # Limit results
locations.append(
@@ -903,14 +800,10 @@ class MapBoundsAPIView(APIView):
"name": park.name,
"slug": park.slug,
"latitude": (
park.location.latitude
if hasattr(park, "location") and park.location
else None
park.location.latitude if hasattr(park, "location") and park.location else None
),
"longitude": (
park.location.longitude
if hasattr(park, "location") and park.location
else None
park.location.longitude if hasattr(park, "location") and park.location else None
),
"status": park.status,
}
@@ -960,7 +853,7 @@ class MapBoundsAPIView(APIView):
except Exception as e:
logger.error(f"Error in MapBoundsAPIView: {str(e)}", exc_info=True)
return Response(
{"status": "error", "message": "Failed to retrieve locations within bounds"},
{"status": "error", "detail": "Failed to retrieve locations within bounds"},
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
@@ -987,18 +880,15 @@ class MapStatsAPIView(APIView):
"""Get map service statistics and performance metrics."""
try:
# Count locations with coordinates
parks_with_location = Park.objects.filter(
location__point__isnull=False
).count()
rides_with_location = Ride.objects.filter(
park__location__point__isnull=False
).count()
parks_with_location = Park.objects.filter(location__point__isnull=False).count()
rides_with_location = Ride.objects.filter(park__location__point__isnull=False).count()
total_locations = parks_with_location + rides_with_location
# Get cache statistics
from apps.core.services.enhanced_cache_service import CacheMonitor
cache_monitor = CacheMonitor()
cache_stats = cache_monitor.get_cache_statistics('map_locations')
cache_stats = cache_monitor.get_cache_statistics("map_locations")
return Response(
{
@@ -1006,17 +896,17 @@ class MapStatsAPIView(APIView):
"total_locations": total_locations,
"parks_with_location": parks_with_location,
"rides_with_location": rides_with_location,
"cache_hits": cache_stats.get('hits', 0),
"cache_misses": cache_stats.get('misses', 0),
"cache_hit_rate": cache_stats.get('hit_rate', 0.0),
"cache_size": cache_stats.get('size', 0),
"cache_hits": cache_stats.get("hits", 0),
"cache_misses": cache_stats.get("misses", 0),
"cache_hit_rate": cache_stats.get("hit_rate", 0.0),
"cache_size": cache_stats.get("size", 0),
}
)
except Exception as e:
logger.error(f"Error in MapStatsAPIView: {str(e)}", exc_info=True)
return Response(
{"status": "error", "message": "Failed to retrieve map statistics"},
{"status": "error", "detail": "Failed to retrieve map statistics"},
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
@@ -1060,7 +950,7 @@ class MapCacheAPIView(APIView):
return Response(
{
"status": "success",
"message": f"Map cache cleared successfully. Cleared {cleared_count} entries.",
"detail": f"Map cache cleared successfully. Cleared {cleared_count} entries.",
"cleared_count": cleared_count,
}
)
@@ -1068,7 +958,7 @@ class MapCacheAPIView(APIView):
except Exception as e:
logger.error(f"Error in MapCacheAPIView.delete: {str(e)}", exc_info=True)
return Response(
{"status": "error", "message": "Failed to clear map cache"},
{"status": "error", "detail": "Failed to clear map cache"},
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
@@ -1076,7 +966,7 @@ class MapCacheAPIView(APIView):
"""Invalidate specific cache entries."""
try:
# Get cache keys to invalidate from request data
request_data = getattr(request, 'data', {})
request_data = getattr(request, "data", {})
cache_keys = request_data.get("cache_keys", []) if request_data else []
if cache_keys:
@@ -1088,7 +978,7 @@ class MapCacheAPIView(APIView):
return Response(
{
"status": "success",
"message": f"Cache invalidated successfully. Invalidated {invalidated_count} entries.",
"detail": f"Cache invalidated successfully. Invalidated {invalidated_count} entries.",
"invalidated_count": invalidated_count,
}
)
@@ -1096,7 +986,7 @@ class MapCacheAPIView(APIView):
except Exception as e:
logger.error(f"Error in MapCacheAPIView.post: {str(e)}", exc_info=True)
return Response(
{"status": "error", "message": "Failed to invalidate cache"},
{"status": "error", "detail": "Failed to invalidate cache"},
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
)

View File

@@ -33,7 +33,7 @@ class ContractValidationMiddleware(MiddlewareMixin):
def __init__(self, get_response):
super().__init__(get_response)
self.get_response = get_response
self.enabled = getattr(settings, 'DEBUG', False)
self.enabled = getattr(settings, "DEBUG", False)
if self.enabled:
logger.info("Contract validation middleware enabled (DEBUG mode)")
@@ -45,11 +45,11 @@ class ContractValidationMiddleware(MiddlewareMixin):
return response
# Only validate API endpoints
if not request.path.startswith('/api/'):
if not request.path.startswith("/api/"):
return response
# Only validate JSON responses
if not isinstance(response, (JsonResponse, Response)):
if not isinstance(response, JsonResponse | Response):
return response
# Only validate successful responses (2xx status codes)
@@ -58,7 +58,7 @@ class ContractValidationMiddleware(MiddlewareMixin):
try:
# Get response data
data = response.data if isinstance(response, Response) else json.loads(response.content.decode('utf-8'))
data = response.data if isinstance(response, Response) else json.loads(response.content.decode("utf-8"))
# Validate the response
self._validate_response_contract(request.path, data)
@@ -68,11 +68,11 @@ class ContractValidationMiddleware(MiddlewareMixin):
logger.warning(
f"Contract validation error for {request.path}: {str(e)}",
extra={
'path': request.path,
'method': request.method,
'status_code': response.status_code,
'validation_error': str(e)
}
"path": request.path,
"method": request.method,
"status_code": response.status_code,
"validation_error": str(e),
},
)
return response
@@ -81,15 +81,15 @@ class ContractValidationMiddleware(MiddlewareMixin):
"""Validate response data against expected contracts."""
# Check for filter metadata endpoints
if 'filter-options' in path or 'filter_options' in path:
if "filter-options" in path or "filter_options" in path:
self._validate_filter_metadata(path, data)
# Check for hybrid filtering endpoints
if 'hybrid' in path:
if "hybrid" in path:
self._validate_hybrid_response(path, data)
# Check for pagination responses
if isinstance(data, dict) and 'results' in data:
if isinstance(data, dict) and "results" in data:
self._validate_pagination_response(path, data)
# Check for common contract violations
@@ -100,22 +100,20 @@ class ContractValidationMiddleware(MiddlewareMixin):
if not isinstance(data, dict):
self._log_contract_violation(
path,
"FILTER_METADATA_NOT_DICT",
f"Filter metadata should be a dictionary, got {type(data).__name__}"
path, "FILTER_METADATA_NOT_DICT", f"Filter metadata should be a dictionary, got {type(data).__name__}"
)
return
# Check for categorical filters
if 'categorical' in data:
categorical = data['categorical']
if "categorical" in data:
categorical = data["categorical"]
if isinstance(categorical, dict):
for filter_name, filter_options in categorical.items():
self._validate_categorical_filter(path, filter_name, filter_options)
# Check for ranges
if 'ranges' in data:
ranges = data['ranges']
if "ranges" in data:
ranges = data["ranges"]
if isinstance(ranges, dict):
for range_name, range_data in ranges.items():
self._validate_range_filter(path, range_name, range_data)
@@ -127,7 +125,7 @@ class ContractValidationMiddleware(MiddlewareMixin):
self._log_contract_violation(
path,
"CATEGORICAL_FILTER_NOT_ARRAY",
f"Categorical filter '{filter_name}' should be an array, got {type(filter_options).__name__}"
f"Categorical filter '{filter_name}' should be an array, got {type(filter_options).__name__}",
)
return
@@ -138,28 +136,28 @@ class ContractValidationMiddleware(MiddlewareMixin):
path,
"CATEGORICAL_OPTION_IS_STRING",
f"Categorical filter '{filter_name}' option {i} is a string '{option}' but should be an object with value/label/count properties",
severity="ERROR"
severity="ERROR",
)
elif isinstance(option, dict):
# Validate object structure
if 'value' not in option:
if "value" not in option:
self._log_contract_violation(
path,
"MISSING_VALUE_PROPERTY",
f"Categorical filter '{filter_name}' option {i} missing 'value' property"
f"Categorical filter '{filter_name}' option {i} missing 'value' property",
)
if 'label' not in option:
if "label" not in option:
self._log_contract_violation(
path,
"MISSING_LABEL_PROPERTY",
f"Categorical filter '{filter_name}' option {i} missing 'label' property"
f"Categorical filter '{filter_name}' option {i} missing 'label' property",
)
# Count is optional but should be number if present
if 'count' in option and option['count'] is not None and not isinstance(option['count'], (int, float)):
if "count" in option and option["count"] is not None and not isinstance(option["count"], int | float):
self._log_contract_violation(
path,
"INVALID_COUNT_TYPE",
f"Categorical filter '{filter_name}' option {i} 'count' should be a number, got {type(option['count']).__name__}"
f"Categorical filter '{filter_name}' option {i} 'count' should be a number, got {type(option['count']).__name__}",
)
def _validate_range_filter(self, path: str, range_name: str, range_data: Any) -> None:
@@ -169,26 +167,24 @@ class ContractValidationMiddleware(MiddlewareMixin):
self._log_contract_violation(
path,
"RANGE_FILTER_NOT_OBJECT",
f"Range filter '{range_name}' should be an object, got {type(range_data).__name__}"
f"Range filter '{range_name}' should be an object, got {type(range_data).__name__}",
)
return
# Check required properties
required_props = ['min', 'max']
required_props = ["min", "max"]
for prop in required_props:
if prop not in range_data:
self._log_contract_violation(
path,
"MISSING_RANGE_PROPERTY",
f"Range filter '{range_name}' missing required property '{prop}'"
path, "MISSING_RANGE_PROPERTY", f"Range filter '{range_name}' missing required property '{prop}'"
)
# Check step property
if 'step' in range_data and not isinstance(range_data['step'], (int, float)):
if "step" in range_data and not isinstance(range_data["step"], int | float):
self._log_contract_violation(
path,
"INVALID_STEP_TYPE",
f"Range filter '{range_name}' 'step' should be a number, got {type(range_data['step']).__name__}"
f"Range filter '{range_name}' 'step' should be a number, got {type(range_data['step']).__name__}",
)
def _validate_hybrid_response(self, path: str, data: Any) -> None:
@@ -198,38 +194,36 @@ class ContractValidationMiddleware(MiddlewareMixin):
return
# Check for strategy field
if 'strategy' in data:
strategy = data['strategy']
if strategy not in ['client_side', 'server_side']:
if "strategy" in data:
strategy = data["strategy"]
if strategy not in ["client_side", "server_side"]:
self._log_contract_violation(
path,
"INVALID_STRATEGY_VALUE",
f"Hybrid response strategy should be 'client_side' or 'server_side', got '{strategy}'"
f"Hybrid response strategy should be 'client_side' or 'server_side', got '{strategy}'",
)
# Check filter_metadata structure
if 'filter_metadata' in data:
self._validate_filter_metadata(path, data['filter_metadata'])
if "filter_metadata" in data:
self._validate_filter_metadata(path, data["filter_metadata"])
def _validate_pagination_response(self, path: str, data: dict[str, Any]) -> None:
"""Validate pagination response structure."""
# Check for required pagination fields
required_fields = ['count', 'results']
required_fields = ["count", "results"]
for field in required_fields:
if field not in data:
self._log_contract_violation(
path,
"MISSING_PAGINATION_FIELD",
f"Pagination response missing required field '{field}'"
path, "MISSING_PAGINATION_FIELD", f"Pagination response missing required field '{field}'"
)
# Check results is array
if 'results' in data and not isinstance(data['results'], list):
if "results" in data and not isinstance(data["results"], list):
self._log_contract_violation(
path,
"RESULTS_NOT_ARRAY",
f"Pagination 'results' should be an array, got {type(data['results']).__name__}"
f"Pagination 'results' should be an array, got {type(data['results']).__name__}",
)
def _validate_common_patterns(self, path: str, data: Any) -> None:
@@ -238,38 +232,32 @@ class ContractValidationMiddleware(MiddlewareMixin):
if isinstance(data, dict):
# Check for null vs undefined issues
for key, value in data.items():
if value is None and key.endswith('_id'):
if value is None and key.endswith("_id"):
# ID fields should probably be null, not undefined
continue
# Check for numeric fields that might be strings
if key.endswith('_count') and isinstance(value, str):
if key.endswith("_count") and isinstance(value, str):
try:
int(value)
self._log_contract_violation(
path,
"NUMERIC_FIELD_AS_STRING",
f"Field '{key}' appears to be numeric but is a string: '{value}'"
f"Field '{key}' appears to be numeric but is a string: '{value}'",
)
except ValueError:
pass
def _log_contract_violation(
self,
path: str,
violation_type: str,
message: str,
severity: str = "WARNING"
) -> None:
def _log_contract_violation(self, path: str, violation_type: str, message: str, severity: str = "WARNING") -> None:
"""Log a contract violation with structured data."""
log_data = {
'contract_violation': True,
'violation_type': violation_type,
'api_path': path,
'severity': severity,
'message': message,
'suggestion': self._get_violation_suggestion(violation_type)
"contract_violation": True,
"violation_type": violation_type,
"api_path": path,
"severity": severity,
"message": message,
"suggestion": self._get_violation_suggestion(violation_type),
}
if severity == "ERROR":
@@ -302,9 +290,8 @@ class ContractValidationMiddleware(MiddlewareMixin):
"Check serializer field types and database field types."
),
"RESULTS_NOT_ARRAY": (
"Ensure pagination 'results' field is always an array. "
"Check serializer implementation."
)
"Ensure pagination 'results' field is always an array. " "Check serializer implementation."
),
}
return suggestions.get(violation_type, "Check the API response format against frontend TypeScript interfaces.")
@@ -326,9 +313,9 @@ class ContractValidationSettings:
# Paths to exclude from validation
EXCLUDED_PATHS = [
'/api/docs/',
'/api/schema/',
'/api/v1/auth/', # Auth endpoints might have different structures
"/api/docs/",
"/api/schema/",
"/api/v1/auth/", # Auth endpoints might have different structures
]
@classmethod

View File

@@ -17,6 +17,7 @@ class ParkHistoryViewSet(viewsets.GenericViewSet):
"""
ViewSet for retrieving park history.
"""
permission_classes = [AllowAny]
lookup_field = "slug"
lookup_url_kwarg = "park_slug"
@@ -40,12 +41,7 @@ class ParkHistoryViewSet(viewsets.GenericViewSet):
"last_modified": events.first().pgh_created_at if len(events) else None,
}
data = {
"park": park,
"current_state": park,
"summary": summary,
"events": events
}
data = {"park": park, "current_state": park, "summary": summary, "events": events}
serializer = ParkHistoryOutputSerializer(data)
return Response(serializer.data)
@@ -55,6 +51,7 @@ class RideHistoryViewSet(viewsets.GenericViewSet):
"""
ViewSet for retrieving ride history.
"""
permission_classes = [AllowAny]
lookup_field = "slug"
lookup_url_kwarg = "ride_slug"
@@ -79,12 +76,7 @@ class RideHistoryViewSet(viewsets.GenericViewSet):
"last_modified": events.first().pgh_created_at if len(events) else None,
}
data = {
"ride": ride,
"current_state": ride,
"summary": summary,
"events": events
}
data = {"ride": ride, "current_state": ride, "summary": summary, "events": events}
serializer = RideHistoryOutputSerializer(data)
return Response(serializer.data)

View File

@@ -65,14 +65,12 @@ class ParkReviewViewSet(ModelViewSet):
def get_permissions(self):
"""Set permissions based on action."""
permission_classes = [AllowAny] if self.action in ['list', 'retrieve', 'stats'] else [IsAuthenticated]
permission_classes = [AllowAny] if self.action in ["list", "retrieve", "stats"] else [IsAuthenticated]
return [permission() for permission in permission_classes]
def get_queryset(self):
"""Get reviews for the current park."""
queryset = ParkReview.objects.select_related(
"park", "user", "user__profile"
)
queryset = ParkReview.objects.select_related("park", "user", "user__profile")
park_slug = self.kwargs.get("park_slug")
if park_slug:
@@ -82,7 +80,7 @@ class ParkReviewViewSet(ModelViewSet):
except Park.DoesNotExist:
return queryset.none()
if not (hasattr(self.request, 'user') and getattr(self.request.user, 'is_staff', False)):
if not (hasattr(self.request, "user") and getattr(self.request.user, "is_staff", False)):
queryset = queryset.filter(is_published=True)
return queryset.order_by("-created_at")
@@ -102,16 +100,12 @@ class ParkReviewViewSet(ModelViewSet):
try:
park, _ = Park.get_by_slug(park_slug)
except Park.DoesNotExist:
raise NotFound("Park not found")
raise NotFound("Park not found") from None
if ParkReview.objects.filter(park=park, user=self.request.user).exists():
raise ValidationError("You have already reviewed this park")
serializer.save(
park=park,
user=self.request.user,
is_published=True
)
serializer.save(park=park, user=self.request.user, is_published=True)
def perform_update(self, serializer):
instance = self.get_object()
@@ -134,17 +128,18 @@ class ParkReviewViewSet(ModelViewSet):
try:
park, _ = Park.get_by_slug(park_slug)
except Park.DoesNotExist:
return Response({"error": "Park not found"}, status=status.HTTP_404_NOT_FOUND)
return Response({"detail": "Park not found"}, status=status.HTTP_404_NOT_FOUND)
reviews = ParkReview.objects.filter(park=park, is_published=True)
total_reviews = reviews.count()
avg_rating = reviews.aggregate(avg=Avg('rating'))['avg']
avg_rating = reviews.aggregate(avg=Avg("rating"))["avg"]
rating_distribution = {}
for i in range(1, 11):
rating_distribution[str(i)] = reviews.filter(rating=i).count()
from datetime import timedelta
recent_reviews = reviews.filter(created_at__gte=timezone.now() - timedelta(days=30)).count()
stats = {

View File

@@ -21,6 +21,7 @@ from rest_framework.views import APIView
try:
from apps.parks.models import Park
from apps.rides.models import Ride
MODELS_AVAILABLE = True
except Exception:
Park = None # type: ignore
@@ -31,6 +32,7 @@ except Exception:
try:
from apps.api.v1.serializers.parks import ParkDetailOutputSerializer
from apps.api.v1.serializers.rides import RideDetailOutputSerializer, RideListOutputSerializer
SERIALIZERS_AVAILABLE = True
except Exception:
SERIALIZERS_AVAILABLE = False
@@ -52,22 +54,41 @@ class ParkRidesListAPIView(APIView):
description="Get paginated list of rides at a specific park with filtering options",
parameters=[
# Pagination
OpenApiParameter(name="page", location=OpenApiParameter.QUERY,
type=OpenApiTypes.INT, description="Page number"),
OpenApiParameter(name="page_size", location=OpenApiParameter.QUERY,
type=OpenApiTypes.INT, description="Number of results per page (max 100)"),
OpenApiParameter(
name="page", location=OpenApiParameter.QUERY, type=OpenApiTypes.INT, description="Page number"
),
OpenApiParameter(
name="page_size",
location=OpenApiParameter.QUERY,
type=OpenApiTypes.INT,
description="Number of results per page (max 100)",
),
# Filtering
OpenApiParameter(name="category", location=OpenApiParameter.QUERY,
type=OpenApiTypes.STR, description="Filter by ride category"),
OpenApiParameter(name="status", location=OpenApiParameter.QUERY,
type=OpenApiTypes.STR, description="Filter by operational status"),
OpenApiParameter(name="search", location=OpenApiParameter.QUERY,
type=OpenApiTypes.STR, description="Search rides by name"),
OpenApiParameter(
name="category",
location=OpenApiParameter.QUERY,
type=OpenApiTypes.STR,
description="Filter by ride category",
),
OpenApiParameter(
name="status",
location=OpenApiParameter.QUERY,
type=OpenApiTypes.STR,
description="Filter by operational status",
),
OpenApiParameter(
name="search",
location=OpenApiParameter.QUERY,
type=OpenApiTypes.STR,
description="Search rides by name",
),
# Ordering
OpenApiParameter(name="ordering", location=OpenApiParameter.QUERY,
type=OpenApiTypes.STR, description="Order results by field"),
OpenApiParameter(
name="ordering",
location=OpenApiParameter.QUERY,
type=OpenApiTypes.STR,
description="Order results by field",
),
],
responses={
200: OpenApiTypes.OBJECT,
@@ -87,12 +108,14 @@ class ParkRidesListAPIView(APIView):
try:
park, is_historical = Park.get_by_slug(park_slug)
except Park.DoesNotExist:
raise NotFound("Park not found")
raise NotFound("Park not found") from None
# Get rides for this park
qs = Ride.objects.filter(park=park).select_related(
"manufacturer", "designer", "ride_model", "park_area"
).prefetch_related("photos")
qs = (
Ride.objects.filter(park=park)
.select_related("manufacturer", "designer", "ride_model", "park_area")
.prefetch_related("photos")
)
# Apply filtering
qs = self._apply_filters(qs, request.query_params)
@@ -107,9 +130,7 @@ class ParkRidesListAPIView(APIView):
page = paginator.paginate_queryset(qs, request)
if SERIALIZERS_AVAILABLE:
serializer = RideListOutputSerializer(
page, many=True, context={"request": request, "park": park}
)
serializer = RideListOutputSerializer(page, many=True, context={"request": request, "park": park})
return paginator.get_paginated_response(serializer.data)
else:
# Fallback serialization
@@ -145,9 +166,7 @@ class ParkRidesListAPIView(APIView):
search = params.get("search")
if search:
qs = qs.filter(
Q(name__icontains=search) |
Q(description__icontains=search) |
Q(manufacturer__name__icontains=search)
Q(name__icontains=search) | Q(description__icontains=search) | Q(manufacturer__name__icontains=search)
)
return qs
@@ -179,42 +198,46 @@ class ParkRideDetailAPIView(APIView):
try:
park, is_historical = Park.get_by_slug(park_slug)
except Park.DoesNotExist:
raise NotFound("Park not found")
raise NotFound("Park not found") from None
# Get the ride
try:
ride, is_historical = Ride.get_by_slug(ride_slug, park=park)
except Ride.DoesNotExist:
raise NotFound("Ride not found at this park")
raise NotFound("Ride not found at this park") from None
# Ensure ride belongs to this park
if ride.park_id != park.id:
raise NotFound("Ride not found at this park")
if SERIALIZERS_AVAILABLE:
serializer = RideDetailOutputSerializer(
ride, context={"request": request, "park": park}
)
serializer = RideDetailOutputSerializer(ride, context={"request": request, "park": park})
return Response(serializer.data)
else:
# Fallback serialization
return Response({
"id": ride.id,
"name": ride.name,
"slug": ride.slug,
"description": getattr(ride, "description", ""),
"category": getattr(ride, "category", ""),
"status": getattr(ride, "status", ""),
"park": {
"id": park.id,
"name": park.name,
"slug": park.slug,
},
"manufacturer": {
"name": ride.manufacturer.name if ride.manufacturer else "",
"slug": getattr(ride.manufacturer, "slug", "") if ride.manufacturer else "",
} if ride.manufacturer else None,
})
return Response(
{
"id": ride.id,
"name": ride.name,
"slug": ride.slug,
"description": getattr(ride, "description", ""),
"category": getattr(ride, "category", ""),
"status": getattr(ride, "status", ""),
"park": {
"id": park.id,
"name": park.name,
"slug": park.slug,
},
"manufacturer": (
{
"name": ride.manufacturer.name if ride.manufacturer else "",
"slug": getattr(ride.manufacturer, "slug", "") if ride.manufacturer else "",
}
if ride.manufacturer
else None
),
}
)
class ParkComprehensiveDetailAPIView(APIView):
@@ -243,25 +266,21 @@ class ParkComprehensiveDetailAPIView(APIView):
try:
park, is_historical = Park.get_by_slug(park_slug)
except Park.DoesNotExist:
raise NotFound("Park not found")
raise NotFound("Park not found") from None
# Get park with full related data
park = Park.objects.select_related(
"operator", "property_owner", "location"
).prefetch_related(
"areas", "rides", "photos"
).get(pk=park.pk)
park = (
Park.objects.select_related("operator", "property_owner", "location")
.prefetch_related("areas", "rides", "photos")
.get(pk=park.pk)
)
# Get a sample of rides (first 10) for preview
rides_sample = Ride.objects.filter(park=park).select_related(
"manufacturer", "designer", "ride_model"
)[:10]
rides_sample = Ride.objects.filter(park=park).select_related("manufacturer", "designer", "ride_model")[:10]
if SERIALIZERS_AVAILABLE:
# Get full park details
park_serializer = ParkDetailOutputSerializer(
park, context={"request": request}
)
park_serializer = ParkDetailOutputSerializer(park, context={"request": request})
park_data = park_serializer.data
# Add rides summary
@@ -279,25 +298,27 @@ class ParkComprehensiveDetailAPIView(APIView):
return Response(park_data)
else:
# Fallback serialization
return Response({
"id": park.id,
"name": park.name,
"slug": park.slug,
"description": getattr(park, "description", ""),
"location": str(getattr(park, "location", "")),
"operator": getattr(park.operator, "name", "") if hasattr(park, "operator") else "",
"ride_count": getattr(park, "ride_count", 0),
"rides_summary": {
"total_count": getattr(park, "ride_count", 0),
"sample": [
{
"id": ride.id,
"name": ride.name,
"slug": ride.slug,
"category": getattr(ride, "category", ""),
}
for ride in rides_sample
],
"full_list_url": f"/api/v1/parks/{park_slug}/rides/",
},
})
return Response(
{
"id": park.id,
"name": park.name,
"slug": park.slug,
"description": getattr(park, "description", ""),
"location": str(getattr(park, "location", "")),
"operator": getattr(park.operator, "name", "") if hasattr(park, "operator") else "",
"ride_count": getattr(park, "ride_count", 0),
"rides_summary": {
"total_count": getattr(park, "ride_count", 0),
"sample": [
{
"id": ride.id,
"name": ride.name,
"slug": ride.slug,
"category": getattr(ride, "category", ""),
}
for ride in rides_sample
],
"full_list_url": f"/api/v1/parks/{park_slug}/rides/",
},
}
)

File diff suppressed because it is too large Load Diff

View File

@@ -116,14 +116,12 @@ class RidePhotoViewSet(ModelViewSet):
def get_permissions(self):
"""Set permissions based on action."""
permission_classes = [AllowAny] if self.action in ['list', 'retrieve', 'stats'] else [IsAuthenticated]
permission_classes = [AllowAny] if self.action in ["list", "retrieve", "stats"] else [IsAuthenticated]
return [permission() for permission in permission_classes]
def get_queryset(self):
"""Get photos for the current ride with optimized queries."""
queryset = RidePhoto.objects.select_related(
"ride", "ride__park", "ride__park__operator", "uploaded_by"
)
queryset = RidePhoto.objects.select_related("ride", "ride__park", "ride__park__operator", "uploaded_by")
# Filter by park and ride from URL kwargs
park_slug = self.kwargs.get("park_slug")
@@ -163,9 +161,9 @@ class RidePhotoViewSet(ModelViewSet):
park, _ = Park.get_by_slug(park_slug)
ride, _ = Ride.get_by_slug(ride_slug, park=park)
except Park.DoesNotExist:
raise NotFound("Park not found")
raise NotFound("Park not found") from None
except Ride.DoesNotExist:
raise NotFound("Ride not found at this park")
raise NotFound("Ride not found at this park") from None
try:
# Use the service to create the photo with proper business logic
@@ -187,17 +185,14 @@ class RidePhotoViewSet(ModelViewSet):
except Exception as e:
logger.error(f"Error creating ride photo: {e}")
raise ValidationError(f"Failed to create photo: {str(e)}")
raise ValidationError(f"Failed to create photo: {str(e)}") from None
def perform_update(self, serializer):
"""Update ride photo with permission checking."""
instance = self.get_object()
# Check permissions - allow owner or staff
if not (
self.request.user == instance.uploaded_by
or getattr(self.request.user, "is_staff", False)
):
if not (self.request.user == instance.uploaded_by or getattr(self.request.user, "is_staff", False)):
raise PermissionDenied("You can only edit your own photos or be an admin.")
# Handle primary photo logic using service
@@ -209,48 +204,40 @@ class RidePhotoViewSet(ModelViewSet):
del serializer.validated_data["is_primary"]
except Exception as e:
logger.error(f"Error setting primary photo: {e}")
raise ValidationError(f"Failed to set primary photo: {str(e)}")
raise ValidationError(f"Failed to set primary photo: {str(e)}") from None
try:
serializer.save()
logger.info(f"Updated ride photo {instance.id} by user {self.request.user.username}")
except Exception as e:
logger.error(f"Error updating ride photo: {e}")
raise ValidationError(f"Failed to update photo: {str(e)}")
raise ValidationError(f"Failed to update photo: {str(e)}") from None
def perform_destroy(self, instance):
"""Delete ride photo with permission checking."""
# Check permissions - allow owner or staff
if not (
self.request.user == instance.uploaded_by
or getattr(self.request.user, "is_staff", False)
):
raise PermissionDenied(
"You can only delete your own photos or be an admin."
)
if not (self.request.user == instance.uploaded_by or getattr(self.request.user, "is_staff", False)):
raise PermissionDenied("You can only delete your own photos or be an admin.")
try:
# Delete from Cloudflare first if image exists
if instance.image:
try:
from django_cloudflareimages_toolkit.services import CloudflareImagesService
service = CloudflareImagesService()
service.delete_image(instance.image)
logger.info(
f"Successfully deleted ride photo from Cloudflare: {instance.image.cloudflare_id}")
logger.info(f"Successfully deleted ride photo from Cloudflare: {instance.image.cloudflare_id}")
except Exception as e:
logger.error(
f"Failed to delete ride photo from Cloudflare: {str(e)}")
logger.error(f"Failed to delete ride photo from Cloudflare: {str(e)}")
# Continue with database deletion even if Cloudflare deletion fails
RideMediaService.delete_photo(
instance, deleted_by=self.request.user
)
RideMediaService.delete_photo(instance, deleted_by=self.request.user)
logger.info(f"Deleted ride photo {instance.id} by user {self.request.user.username}")
except Exception as e:
logger.error(f"Error deleting ride photo: {e}")
raise ValidationError(f"Failed to delete photo: {str(e)}")
raise ValidationError(f"Failed to delete photo: {str(e)}") from None
@extend_schema(
summary="Set photo as primary",
@@ -269,13 +256,8 @@ class RidePhotoViewSet(ModelViewSet):
photo = self.get_object()
# Check permissions - allow owner or staff
if not (
request.user == photo.uploaded_by
or getattr(request.user, "is_staff", False)
):
raise PermissionDenied(
"You can only modify your own photos or be an admin."
)
if not (request.user == photo.uploaded_by or getattr(request.user, "is_staff", False)):
raise PermissionDenied("You can only modify your own photos or be an admin.")
try:
success = RideMediaService.set_primary_photo(ride=photo.ride, photo=photo)
@@ -287,21 +269,21 @@ class RidePhotoViewSet(ModelViewSet):
return Response(
{
"message": "Photo set as primary successfully",
"detail": "Photo set as primary successfully",
"photo": serializer.data,
},
status=status.HTTP_200_OK,
)
else:
return Response(
{"error": "Failed to set primary photo"},
{"detail": "Failed to set primary photo"},
status=status.HTTP_400_BAD_REQUEST,
)
except Exception as e:
logger.error(f"Error setting primary photo: {e}")
return Response(
{"error": f"Failed to set primary photo: {str(e)}"},
{"detail": f"Failed to set primary photo: {str(e)}"},
status=status.HTTP_400_BAD_REQUEST,
)
@@ -334,7 +316,7 @@ class RidePhotoViewSet(ModelViewSet):
if photo_ids is None or approve is None:
return Response(
{"error": "Missing required fields: photo_ids and/or approve."},
{"detail": "Missing required fields: photo_ids and/or approve."},
status=status.HTTP_400_BAD_REQUEST,
)
@@ -350,7 +332,7 @@ class RidePhotoViewSet(ModelViewSet):
return Response(
{
"message": f"Successfully {'approved' if approve else 'rejected'} {updated_count} photos",
"detail": f"Successfully {'approved' if approve else 'rejected'} {updated_count} photos",
"updated_count": updated_count,
},
status=status.HTTP_200_OK,
@@ -359,7 +341,7 @@ class RidePhotoViewSet(ModelViewSet):
except Exception as e:
logger.error(f"Error in bulk photo approval: {e}")
return Response(
{"error": f"Failed to update photos: {str(e)}"},
{"detail": f"Failed to update photos: {str(e)}"},
status=status.HTTP_400_BAD_REQUEST,
)
@@ -381,7 +363,7 @@ class RidePhotoViewSet(ModelViewSet):
if not park_slug or not ride_slug:
return Response(
{"error": "Park and ride slugs are required"},
{"detail": "Park and ride slugs are required"},
status=status.HTTP_400_BAD_REQUEST,
)
@@ -390,12 +372,12 @@ class RidePhotoViewSet(ModelViewSet):
ride, _ = Ride.get_by_slug(ride_slug, park=park)
except Park.DoesNotExist:
return Response(
{"error": "Park not found"},
{"detail": "Park not found"},
status=status.HTTP_404_NOT_FOUND,
)
except Ride.DoesNotExist:
return Response(
{"error": "Ride not found at this park"},
{"detail": "Ride not found at this park"},
status=status.HTTP_404_NOT_FOUND,
)
@@ -407,7 +389,7 @@ class RidePhotoViewSet(ModelViewSet):
except Exception as e:
logger.error(f"Error getting ride photo stats: {e}")
return Response(
{"error": f"Failed to get photo statistics: {str(e)}"},
{"detail": f"Failed to get photo statistics: {str(e)}"},
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
@@ -431,7 +413,7 @@ class RidePhotoViewSet(ModelViewSet):
if not park_slug or not ride_slug:
return Response(
{"error": "Park and ride slugs are required"},
{"detail": "Park and ride slugs are required"},
status=status.HTTP_400_BAD_REQUEST,
)
@@ -440,19 +422,19 @@ class RidePhotoViewSet(ModelViewSet):
ride, _ = Ride.get_by_slug(ride_slug, park=park)
except Park.DoesNotExist:
return Response(
{"error": "Park not found"},
{"detail": "Park not found"},
status=status.HTTP_404_NOT_FOUND,
)
except Ride.DoesNotExist:
return Response(
{"error": "Ride not found at this park"},
{"detail": "Ride not found at this park"},
status=status.HTTP_404_NOT_FOUND,
)
cloudflare_image_id = request.data.get("cloudflare_image_id")
if not cloudflare_image_id:
return Response(
{"error": "cloudflare_image_id is required"},
{"detail": "cloudflare_image_id is required"},
status=status.HTTP_400_BAD_REQUEST,
)
@@ -469,27 +451,25 @@ class RidePhotoViewSet(ModelViewSet):
if not image_data:
return Response(
{"error": "Image not found in Cloudflare"},
{"detail": "Image not found in Cloudflare"},
status=status.HTTP_400_BAD_REQUEST,
)
# Try to find existing CloudflareImage record by cloudflare_id
cloudflare_image = None
try:
cloudflare_image = CloudflareImage.objects.get(
cloudflare_id=cloudflare_image_id)
cloudflare_image = CloudflareImage.objects.get(cloudflare_id=cloudflare_image_id)
# Update existing record with latest data from Cloudflare
cloudflare_image.status = 'uploaded'
cloudflare_image.status = "uploaded"
cloudflare_image.uploaded_at = timezone.now()
cloudflare_image.metadata = image_data.get('meta', {})
cloudflare_image.metadata = image_data.get("meta", {})
# Extract variants from nested result structure
cloudflare_image.variants = image_data.get(
'result', {}).get('variants', [])
cloudflare_image.variants = image_data.get("result", {}).get("variants", [])
cloudflare_image.cloudflare_metadata = image_data
cloudflare_image.width = image_data.get('width')
cloudflare_image.height = image_data.get('height')
cloudflare_image.format = image_data.get('format', '')
cloudflare_image.width = image_data.get("width")
cloudflare_image.height = image_data.get("height")
cloudflare_image.format = image_data.get("format", "")
cloudflare_image.save()
except CloudflareImage.DoesNotExist:
@@ -497,24 +477,23 @@ class RidePhotoViewSet(ModelViewSet):
cloudflare_image = CloudflareImage.objects.create(
cloudflare_id=cloudflare_image_id,
user=request.user,
status='uploaded',
upload_url='', # Not needed for uploaded images
status="uploaded",
upload_url="", # Not needed for uploaded images
expires_at=timezone.now() + timezone.timedelta(days=365), # Set far future expiry
uploaded_at=timezone.now(),
metadata=image_data.get('meta', {}),
metadata=image_data.get("meta", {}),
# Extract variants from nested result structure
variants=image_data.get('result', {}).get('variants', []),
variants=image_data.get("result", {}).get("variants", []),
cloudflare_metadata=image_data,
width=image_data.get('width'),
height=image_data.get('height'),
format=image_data.get('format', ''),
width=image_data.get("width"),
height=image_data.get("height"),
format=image_data.get("format", ""),
)
except Exception as api_error:
logger.error(
f"Error fetching image from Cloudflare API: {str(api_error)}", exc_info=True)
logger.error(f"Error fetching image from Cloudflare API: {str(api_error)}", exc_info=True)
return Response(
{"error": f"Failed to fetch image from Cloudflare: {str(api_error)}"},
{"detail": f"Failed to fetch image from Cloudflare: {str(api_error)}"},
status=status.HTTP_400_BAD_REQUEST,
)
@@ -544,6 +523,6 @@ class RidePhotoViewSet(ModelViewSet):
except Exception as e:
logger.error(f"Error saving ride photo: {e}")
return Response(
{"error": f"Failed to save photo: {str(e)}"},
{"detail": f"Failed to save photo: {str(e)}"},
status=status.HTTP_400_BAD_REQUEST,
)

View File

@@ -115,14 +115,12 @@ class RideReviewViewSet(ModelViewSet):
def get_permissions(self):
"""Set permissions based on action."""
permission_classes = [AllowAny] if self.action in ['list', 'retrieve', 'stats'] else [IsAuthenticated]
permission_classes = [AllowAny] if self.action in ["list", "retrieve", "stats"] else [IsAuthenticated]
return [permission() for permission in permission_classes]
def get_queryset(self):
"""Get reviews for the current ride with optimized queries."""
queryset = RideReview.objects.select_related(
"ride", "ride__park", "user", "user__profile"
)
queryset = RideReview.objects.select_related("ride", "ride__park", "user", "user__profile")
# Filter by park and ride from URL kwargs
park_slug = self.kwargs.get("park_slug")
@@ -138,8 +136,7 @@ class RideReviewViewSet(ModelViewSet):
return queryset.none()
# Filter published reviews for non-staff users
if not (hasattr(self.request, 'user') and
getattr(self.request.user, 'is_staff', False)):
if not (hasattr(self.request, "user") and getattr(self.request.user, "is_staff", False)):
queryset = queryset.filter(is_published=True)
return queryset.order_by("-created_at")
@@ -167,9 +164,9 @@ class RideReviewViewSet(ModelViewSet):
park, _ = Park.get_by_slug(park_slug)
ride, _ = Ride.get_by_slug(ride_slug, park=park)
except Park.DoesNotExist:
raise NotFound("Park not found")
raise NotFound("Park not found") from None
except Ride.DoesNotExist:
raise NotFound("Ride not found at this park")
raise NotFound("Ride not found at this park") from None
# Check if user already has a review for this ride
if RideReview.objects.filter(ride=ride, user=self.request.user).exists():
@@ -178,26 +175,21 @@ class RideReviewViewSet(ModelViewSet):
try:
# Save the review
review = serializer.save(
ride=ride,
user=self.request.user,
is_published=True # Auto-publish for now, can add moderation later
ride=ride, user=self.request.user, is_published=True # Auto-publish for now, can add moderation later
)
logger.info(f"Created ride review {review.id} for ride {ride.name} by user {self.request.user.username}")
except Exception as e:
logger.error(f"Error creating ride review: {e}")
raise ValidationError(f"Failed to create review: {str(e)}")
raise ValidationError(f"Failed to create review: {str(e)}") from None
def perform_update(self, serializer):
"""Update ride review with permission checking."""
instance = self.get_object()
# Check permissions - allow owner or staff
if not (
self.request.user == instance.user
or getattr(self.request.user, "is_staff", False)
):
if not (self.request.user == instance.user or getattr(self.request.user, "is_staff", False)):
raise PermissionDenied("You can only edit your own reviews or be an admin.")
try:
@@ -205,15 +197,12 @@ class RideReviewViewSet(ModelViewSet):
logger.info(f"Updated ride review {instance.id} by user {self.request.user.username}")
except Exception as e:
logger.error(f"Error updating ride review: {e}")
raise ValidationError(f"Failed to update review: {str(e)}")
raise ValidationError(f"Failed to update review: {str(e)}") from None
def perform_destroy(self, instance):
"""Delete ride review with permission checking."""
# Check permissions - allow owner or staff
if not (
self.request.user == instance.user
or getattr(self.request.user, "is_staff", False)
):
if not (self.request.user == instance.user or getattr(self.request.user, "is_staff", False)):
raise PermissionDenied("You can only delete your own reviews or be an admin.")
try:
@@ -221,7 +210,7 @@ class RideReviewViewSet(ModelViewSet):
instance.delete()
except Exception as e:
logger.error(f"Error deleting ride review: {e}")
raise ValidationError(f"Failed to delete review: {str(e)}")
raise ValidationError(f"Failed to delete review: {str(e)}") from None
@extend_schema(
summary="Get ride review statistics",
@@ -241,7 +230,7 @@ class RideReviewViewSet(ModelViewSet):
if not park_slug or not ride_slug:
return Response(
{"error": "Park and ride slugs are required"},
{"detail": "Park and ride slugs are required"},
status=status.HTTP_400_BAD_REQUEST,
)
@@ -250,12 +239,12 @@ class RideReviewViewSet(ModelViewSet):
ride, _ = Ride.get_by_slug(ride_slug, park=park)
except Park.DoesNotExist:
return Response(
{"error": "Park not found"},
{"detail": "Park not found"},
status=status.HTTP_404_NOT_FOUND,
)
except Ride.DoesNotExist:
return Response(
{"error": "Ride not found at this park"},
{"detail": "Ride not found at this park"},
status=status.HTTP_404_NOT_FOUND,
)
@@ -268,7 +257,7 @@ class RideReviewViewSet(ModelViewSet):
pending_reviews = RideReview.objects.filter(ride=ride, is_published=False).count()
# Calculate average rating
avg_rating = reviews.aggregate(avg_rating=Avg('rating'))['avg_rating']
avg_rating = reviews.aggregate(avg_rating=Avg("rating"))["avg_rating"]
# Get rating distribution
rating_distribution = {}
@@ -277,6 +266,7 @@ class RideReviewViewSet(ModelViewSet):
# Get recent reviews count (last 30 days)
from datetime import timedelta
thirty_days_ago = timezone.now() - timedelta(days=30)
recent_reviews = reviews.filter(created_at__gte=thirty_days_ago).count()
@@ -295,7 +285,7 @@ class RideReviewViewSet(ModelViewSet):
except Exception as e:
logger.error(f"Error getting ride review stats: {e}")
return Response(
{"error": f"Failed to get review statistics: {str(e)}"},
{"detail": f"Failed to get review statistics: {str(e)}"},
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
@@ -340,7 +330,7 @@ class RideReviewViewSet(ModelViewSet):
is_published=True,
moderated_by=request.user,
moderated_at=timezone.now(),
moderation_notes=moderation_notes
moderation_notes=moderation_notes,
)
message = f"Successfully published {updated_count} reviews"
elif action_type == "unpublish":
@@ -348,7 +338,7 @@ class RideReviewViewSet(ModelViewSet):
is_published=False,
moderated_by=request.user,
moderated_at=timezone.now(),
moderation_notes=moderation_notes
moderation_notes=moderation_notes,
)
message = f"Successfully unpublished {updated_count} reviews"
elif action_type == "delete":
@@ -357,13 +347,13 @@ class RideReviewViewSet(ModelViewSet):
message = f"Successfully deleted {updated_count} reviews"
else:
return Response(
{"error": "Invalid action type"},
{"detail": "Invalid action type"},
status=status.HTTP_400_BAD_REQUEST,
)
return Response(
{
"message": message,
"detail": message,
"updated_count": updated_count,
},
status=status.HTTP_200_OK,
@@ -372,6 +362,6 @@ class RideReviewViewSet(ModelViewSet):
except Exception as e:
logger.error(f"Error in bulk review moderation: {e}")
return Response(
{"error": f"Failed to moderate reviews: {str(e)}"},
{"detail": f"Failed to moderate reviews: {str(e)}"},
status=status.HTTP_400_BAD_REQUEST,
)

View File

@@ -50,18 +50,14 @@ from apps.parks.models import Park, ParkPhoto
class ParkPhotoOutputSerializer(serializers.ModelSerializer):
"""Enhanced output serializer for park photos with Cloudflare Images support."""
uploaded_by_username = serializers.CharField(
source="uploaded_by.username", read_only=True
)
uploaded_by_username = serializers.CharField(source="uploaded_by.username", read_only=True)
file_size = serializers.SerializerMethodField()
dimensions = serializers.SerializerMethodField()
image_url = serializers.SerializerMethodField()
image_variants = serializers.SerializerMethodField()
@extend_schema_field(
serializers.IntegerField(allow_null=True, help_text="File size in bytes")
)
@extend_schema_field(serializers.IntegerField(allow_null=True, help_text="File size in bytes"))
def get_file_size(self, obj):
"""Get file size in bytes."""
return obj.file_size
@@ -79,11 +75,7 @@ class ParkPhotoOutputSerializer(serializers.ModelSerializer):
"""Get image dimensions as [width, height]."""
return obj.dimensions
@extend_schema_field(
serializers.URLField(
help_text="Full URL to the Cloudflare Images asset", allow_null=True
)
)
@extend_schema_field(serializers.URLField(help_text="Full URL to the Cloudflare Images asset", allow_null=True))
def get_image_url(self, obj):
"""Get the full Cloudflare Images URL."""
if obj.image:
@@ -175,9 +167,7 @@ class ParkPhotoUpdateInputSerializer(serializers.ModelSerializer):
class ParkPhotoListOutputSerializer(serializers.ModelSerializer):
"""Optimized output serializer for park photo lists."""
uploaded_by_username = serializers.CharField(
source="uploaded_by.username", read_only=True
)
uploaded_by_username = serializers.CharField(source="uploaded_by.username", read_only=True)
class Meta:
model = ParkPhoto
@@ -196,12 +186,8 @@ class ParkPhotoListOutputSerializer(serializers.ModelSerializer):
class ParkPhotoApprovalInputSerializer(serializers.Serializer):
"""Input serializer for bulk photo approval operations."""
photo_ids = serializers.ListField(
child=serializers.IntegerField(), help_text="List of photo IDs to approve"
)
approve = serializers.BooleanField(
default=True, help_text="Whether to approve (True) or reject (False) the photos"
)
photo_ids = serializers.ListField(child=serializers.IntegerField(), help_text="List of photo IDs to approve")
approve = serializers.BooleanField(default=True, help_text="Whether to approve (True) or reject (False) the photos")
class ParkPhotoStatsOutputSerializer(serializers.Serializer):
@@ -261,7 +247,7 @@ class HybridParkSerializer(serializers.ModelSerializer):
def get_city(self, obj):
"""Get city from related location."""
try:
return obj.location.city if hasattr(obj, 'location') and obj.location else None
return obj.location.city if hasattr(obj, "location") and obj.location else None
except AttributeError:
return None
@@ -269,7 +255,7 @@ class HybridParkSerializer(serializers.ModelSerializer):
def get_state(self, obj):
"""Get state from related location."""
try:
return obj.location.state if hasattr(obj, 'location') and obj.location else None
return obj.location.state if hasattr(obj, "location") and obj.location else None
except AttributeError:
return None
@@ -277,7 +263,7 @@ class HybridParkSerializer(serializers.ModelSerializer):
def get_country(self, obj):
"""Get country from related location."""
try:
return obj.location.country if hasattr(obj, 'location') and obj.location else None
return obj.location.country if hasattr(obj, "location") and obj.location else None
except AttributeError:
return None
@@ -285,7 +271,7 @@ class HybridParkSerializer(serializers.ModelSerializer):
def get_continent(self, obj):
"""Get continent from related location."""
try:
return obj.location.continent if hasattr(obj, 'location') and obj.location else None
return obj.location.continent if hasattr(obj, "location") and obj.location else None
except AttributeError:
return None
@@ -293,7 +279,7 @@ class HybridParkSerializer(serializers.ModelSerializer):
def get_latitude(self, obj):
"""Get latitude from related location."""
try:
if hasattr(obj, 'location') and obj.location and obj.location.coordinates:
if hasattr(obj, "location") and obj.location and obj.location.coordinates:
return obj.location.coordinates[1] # PostGIS returns [lon, lat]
return None
except (AttributeError, IndexError, TypeError):
@@ -303,7 +289,7 @@ class HybridParkSerializer(serializers.ModelSerializer):
def get_longitude(self, obj):
"""Get longitude from related location."""
try:
if hasattr(obj, 'location') and obj.location and obj.location.coordinates:
if hasattr(obj, "location") and obj.location and obj.location.coordinates:
return obj.location.coordinates[0] # PostGIS returns [lon, lat]
return None
except (AttributeError, IndexError, TypeError):
@@ -333,13 +319,11 @@ class HybridParkSerializer(serializers.ModelSerializer):
"description",
"status",
"park_type",
# Dates and computed fields
"opening_date",
"closing_date",
"opening_year",
"operating_season",
# Location fields
"city",
"state",
@@ -347,28 +331,22 @@ class HybridParkSerializer(serializers.ModelSerializer):
"continent",
"latitude",
"longitude",
# Company relationships
"operator_name",
"property_owner_name",
# Statistics
"size_acres",
"average_rating",
"ride_count",
"coaster_count",
# Images
"banner_image_url",
"card_image_url",
# URLs
"website",
"url",
# Computed fields for filtering
"search_text",
# Metadata
"created_at",
"updated_at",

View File

@@ -46,8 +46,8 @@ ride_photos_router.register(r"", RidePhotoViewSet, basename="ride-photo")
ride_reviews_router = DefaultRouter()
ride_reviews_router.register(r"", RideReviewViewSet, basename="ride-review")
from .history_views import ParkHistoryViewSet, RideHistoryViewSet
from .park_reviews_views import ParkReviewViewSet
from .history_views import ParkHistoryViewSet, RideHistoryViewSet # noqa: E402
from .park_reviews_views import ParkReviewViewSet # noqa: E402
# Create routers for nested park endpoints
reviews_router = DefaultRouter()
@@ -59,11 +59,9 @@ app_name = "api_v1_parks"
urlpatterns = [
# Core list/create endpoints
path("", ParkListCreateAPIView.as_view(), name="park-list-create"),
# Hybrid filtering endpoints
path("hybrid/", HybridParkAPIView.as_view(), name="park-hybrid-list"),
path("hybrid/filter-metadata/", ParkFilterMetadataAPIView.as_view(), name="park-hybrid-filter-metadata"),
# Filter options
path("filter-options/", FilterOptionsAPIView.as_view(), name="park-filter-options"),
# Autocomplete / suggestion endpoints
@@ -79,14 +77,11 @@ urlpatterns = [
),
# Detail and action endpoints - supports both ID and slug
path("<str:pk>/", ParkDetailAPIView.as_view(), name="park-detail"),
# Park rides endpoints
path("<str:park_slug>/rides/", ParkRidesListAPIView.as_view(), name="park-rides-list"),
path("<str:park_slug>/rides/<str:ride_slug>/", ParkRideDetailAPIView.as_view(), name="park-ride-detail"),
# Comprehensive park detail endpoint with rides summary
path("<str:park_slug>/detail/", ParkComprehensiveDetailAPIView.as_view(), name="park-comprehensive-detail"),
# Park image settings endpoint
path(
"<int:pk>/image-settings/",
@@ -95,33 +90,29 @@ urlpatterns = [
),
# Park photo endpoints - domain-specific photo management
path("<str:park_pk>/photos/", include(router.urls)),
# Nested ride photo endpoints - photos for specific rides within parks
path("<str:park_slug>/rides/<str:ride_slug>/photos/", include(ride_photos_router.urls)),
# Nested ride review endpoints - reviews for specific rides within parks
path("<str:park_slug>/rides/<str:ride_slug>/reviews/", include(ride_reviews_router.urls)),
# Nested ride review endpoints - reviews for specific rides within parks
path("<str:park_slug>/rides/<str:ride_slug>/reviews/", include(ride_reviews_router.urls)),
# Ride History
path("<str:park_slug>/rides/<str:ride_slug>/history/", RideHistoryViewSet.as_view({'get': 'list'}), name="ride-history"),
path(
"<str:park_slug>/rides/<str:ride_slug>/history/",
RideHistoryViewSet.as_view({"get": "list"}),
name="ride-history",
),
# Park Reviews
path("<str:park_slug>/reviews/", include(reviews_router.urls)),
# Park History
path("<str:park_slug>/history/", ParkHistoryViewSet.as_view({'get': 'list'}), name="park-history"),
path("<str:park_slug>/history/", ParkHistoryViewSet.as_view({"get": "list"}), name="park-history"),
# Roadtrip API endpoints
path("roadtrip/create/", CreateTripView.as_view(), name="roadtrip-create"),
path("roadtrip/find-along-route/", FindParksAlongRouteView.as_view(), name="roadtrip-find"),
path("roadtrip/geocode/", GeocodeAddressView.as_view(), name="roadtrip-geocode"),
path("roadtrip/distance/", ParkDistanceCalculatorView.as_view(), name="roadtrip-distance"),
# Operator endpoints
path("operators/", OperatorListAPIView.as_view(), name="operator-list"),
# Location search endpoints
path("search/location/", location_search, name="location-search"),
path("search/reverse-geocode/", reverse_geocode, name="reverse-geocode"),

View File

@@ -134,9 +134,7 @@ class ParkPhotoViewSet(ModelViewSet):
def get_queryset(self): # type: ignore[override]
"""Get photos for the current park with optimized queries."""
queryset = ParkPhoto.objects.select_related(
"park", "park__operator", "uploaded_by"
)
queryset = ParkPhoto.objects.select_related("park", "park__operator", "uploaded_by")
# If park_pk is provided in URL kwargs, filter by park
# If park_pk is provided in URL kwargs, filter by park
@@ -172,7 +170,7 @@ class ParkPhotoViewSet(ModelViewSet):
# Use real park ID
park_id = park.id
except Park.DoesNotExist:
raise ValidationError("Park not found")
raise ValidationError("Park not found") from None
try:
# Use the service to create the photo with proper business logic
@@ -188,48 +186,38 @@ class ParkPhotoViewSet(ModelViewSet):
except (ValidationException, ValidationError) as e:
logger.warning(f"Validation error creating park photo: {e}")
raise ValidationError(str(e))
raise ValidationError(str(e)) from None
except ServiceError as e:
logger.error(f"Service error creating park photo: {e}")
raise ValidationError(f"Failed to create photo: {str(e)}")
raise ValidationError(f"Failed to create photo: {str(e)}") from None
def perform_update(self, serializer):
"""Update park photo with permission checking."""
instance = self.get_object()
# Check permissions - allow owner or staff
if not (
self.request.user == instance.uploaded_by
or cast(Any, self.request.user).is_staff
):
if not (self.request.user == instance.uploaded_by or cast(Any, self.request.user).is_staff):
raise PermissionDenied("You can only edit your own photos or be an admin.")
# Handle primary photo logic using service
if serializer.validated_data.get("is_primary", False):
try:
ParkMediaService().set_primary_photo(
park_id=instance.park_id, photo_id=instance.id
)
ParkMediaService().set_primary_photo(park_id=instance.park_id, photo_id=instance.id)
# Remove is_primary from validated_data since service handles it
if "is_primary" in serializer.validated_data:
del serializer.validated_data["is_primary"]
except (ValidationException, ValidationError) as e:
logger.warning(f"Validation error setting primary photo: {e}")
raise ValidationError(str(e))
raise ValidationError(str(e)) from None
except ServiceError as e:
logger.error(f"Service error setting primary photo: {e}")
raise ValidationError(f"Failed to set primary photo: {str(e)}")
raise ValidationError(f"Failed to set primary photo: {str(e)}") from None
def perform_destroy(self, instance):
"""Delete park photo with permission checking."""
# Check permissions - allow owner or staff
if not (
self.request.user == instance.uploaded_by
or cast(Any, self.request.user).is_staff
):
raise PermissionDenied(
"You can only delete your own photos or be an admin."
)
if not (self.request.user == instance.uploaded_by or cast(Any, self.request.user).is_staff):
raise PermissionDenied("You can only delete your own photos or be an admin.")
# Delete from Cloudflare first if image exists
if instance.image:
@@ -240,9 +228,7 @@ class ParkPhotoViewSet(ModelViewSet):
service = CloudflareImagesService()
service.delete_image(instance.image)
logger.info(
f"Successfully deleted park photo from Cloudflare: {instance.image.cloudflare_id}"
)
logger.info(f"Successfully deleted park photo from Cloudflare: {instance.image.cloudflare_id}")
except ImportError:
logger.warning("CloudflareImagesService not available")
except ServiceError as e:
@@ -250,12 +236,10 @@ class ParkPhotoViewSet(ModelViewSet):
# Continue with database deletion even if Cloudflare deletion fails
try:
ParkMediaService().delete_photo(
instance.id, deleted_by=cast(UserModel, self.request.user)
)
ParkMediaService().delete_photo(instance.id, deleted_by=cast(UserModel, self.request.user))
except ServiceError as e:
logger.error(f"Service error deleting park photo: {e}")
raise ValidationError(f"Failed to delete photo: {str(e)}")
raise ValidationError(f"Failed to delete photo: {str(e)}") from None
@extend_schema(
summary="Set photo as primary",
@@ -275,14 +259,10 @@ class ParkPhotoViewSet(ModelViewSet):
# Check permissions - allow owner or staff
if not (request.user == photo.uploaded_by or cast(Any, request.user).is_staff):
raise PermissionDenied(
"You can only modify your own photos or be an admin."
)
raise PermissionDenied("You can only modify your own photos or be an admin.")
try:
ParkMediaService().set_primary_photo(
park_id=photo.park_id, photo_id=photo.id
)
ParkMediaService().set_primary_photo(park_id=photo.park_id, photo_id=photo.id)
# Refresh the photo instance
photo.refresh_from_db()
@@ -290,7 +270,7 @@ class ParkPhotoViewSet(ModelViewSet):
return Response(
{
"message": "Photo set as primary successfully",
"detail": "Photo set as primary successfully",
"photo": serializer.data,
},
status=status.HTTP_200_OK,
@@ -337,7 +317,7 @@ class ParkPhotoViewSet(ModelViewSet):
if photo_ids is None or approve is None:
return Response(
{"error": "Missing required fields: photo_ids and/or approve."},
{"detail": "Missing required fields: photo_ids and/or approve."},
status=status.HTTP_400_BAD_REQUEST,
)
@@ -354,7 +334,7 @@ class ParkPhotoViewSet(ModelViewSet):
return Response(
{
"message": f"Successfully {'approved' if approve else 'rejected'} {updated_count} photos",
"detail": f"Successfully {'approved' if approve else 'rejected'} {updated_count} photos",
"updated_count": updated_count,
},
status=status.HTTP_200_OK,
@@ -430,19 +410,14 @@ class ParkPhotoViewSet(ModelViewSet):
def set_primary_legacy(self, request, id=None):
"""Legacy set primary action for backwards compatibility."""
photo = self.get_object()
if not (
request.user == photo.uploaded_by
or request.user.has_perm("parks.change_parkphoto")
):
if not (request.user == photo.uploaded_by or request.user.has_perm("parks.change_parkphoto")):
return Response(
{"error": "You do not have permission to edit photos for this park."},
{"detail": "You do not have permission to edit photos for this park."},
status=status.HTTP_403_FORBIDDEN,
)
try:
ParkMediaService().set_primary_photo(
park_id=photo.park_id, photo_id=photo.id
)
return Response({"message": "Photo set as primary successfully."})
ParkMediaService().set_primary_photo(park_id=photo.park_id, photo_id=photo.id)
return Response({"detail": "Photo set as primary successfully."})
except (ValidationException, ValidationError) as e:
logger.warning(f"Validation error in set_primary_photo: {str(e)}")
return ErrorHandler.handle_api_error(
@@ -475,7 +450,7 @@ class ParkPhotoViewSet(ModelViewSet):
park_pk = self.kwargs.get("park_pk")
if not park_pk:
return Response(
{"error": "Park ID is required"},
{"detail": "Park ID is required"},
status=status.HTTP_400_BAD_REQUEST,
)
@@ -483,14 +458,14 @@ class ParkPhotoViewSet(ModelViewSet):
park = Park.objects.get(pk=park_pk) if str(park_pk).isdigit() else Park.objects.get(slug=park_pk)
except Park.DoesNotExist:
return Response(
{"error": "Park not found"},
{"detail": "Park not found"},
status=status.HTTP_404_NOT_FOUND,
)
cloudflare_image_id = request.data.get("cloudflare_image_id")
if not cloudflare_image_id:
return Response(
{"error": "cloudflare_image_id is required"},
{"detail": "cloudflare_image_id is required"},
status=status.HTTP_400_BAD_REQUEST,
)
@@ -515,18 +490,14 @@ class ParkPhotoViewSet(ModelViewSet):
# Try to find existing CloudflareImage record by cloudflare_id
cloudflare_image = None
try:
cloudflare_image = CloudflareImage.objects.get(
cloudflare_id=cloudflare_image_id
)
cloudflare_image = CloudflareImage.objects.get(cloudflare_id=cloudflare_image_id)
# Update existing record with latest data from Cloudflare
cloudflare_image.status = "uploaded"
cloudflare_image.uploaded_at = timezone.now()
cloudflare_image.metadata = image_data.get("meta", {})
# Extract variants from nested result structure
cloudflare_image.variants = image_data.get("result", {}).get(
"variants", []
)
cloudflare_image.variants = image_data.get("result", {}).get("variants", [])
cloudflare_image.cloudflare_metadata = image_data
cloudflare_image.width = image_data.get("width")
cloudflare_image.height = image_data.get("height")
@@ -540,8 +511,7 @@ class ParkPhotoViewSet(ModelViewSet):
user=request.user,
status="uploaded",
upload_url="", # Not needed for uploaded images
expires_at=timezone.now()
+ timezone.timedelta(days=365), # Set far future expiry
expires_at=timezone.now() + timezone.timedelta(days=365), # Set far future expiry
uploaded_at=timezone.now(),
metadata=image_data.get("meta", {}),
# Extract variants from nested result structure
@@ -567,9 +537,7 @@ class ParkPhotoViewSet(ModelViewSet):
# Handle primary photo logic if requested
if request.data.get("is_primary", False):
try:
ParkMediaService().set_primary_photo(
park_id=park.id, photo_id=photo.id
)
ParkMediaService().set_primary_photo(park_id=park.id, photo_id=photo.id)
except ServiceError as e:
logger.error(f"Error setting primary photo: {e}")
# Don't fail the entire operation, just log the error
@@ -624,12 +592,8 @@ class ParkPhotoViewSet(ModelViewSet):
OpenApiTypes.STR,
description="Filter by state (comma-separated for multiple)",
),
OpenApiParameter(
"opening_year_min", OpenApiTypes.INT, description="Minimum opening year"
),
OpenApiParameter(
"opening_year_max", OpenApiTypes.INT, description="Maximum opening year"
),
OpenApiParameter("opening_year_min", OpenApiTypes.INT, description="Minimum opening year"),
OpenApiParameter("opening_year_max", OpenApiTypes.INT, description="Maximum opening year"),
OpenApiParameter(
"size_min",
OpenApiTypes.NUMBER,
@@ -640,18 +604,10 @@ class ParkPhotoViewSet(ModelViewSet):
OpenApiTypes.NUMBER,
description="Maximum park size in acres",
),
OpenApiParameter(
"rating_min", OpenApiTypes.NUMBER, description="Minimum average rating"
),
OpenApiParameter(
"rating_max", OpenApiTypes.NUMBER, description="Maximum average rating"
),
OpenApiParameter(
"ride_count_min", OpenApiTypes.INT, description="Minimum ride count"
),
OpenApiParameter(
"ride_count_max", OpenApiTypes.INT, description="Maximum ride count"
),
OpenApiParameter("rating_min", OpenApiTypes.NUMBER, description="Minimum average rating"),
OpenApiParameter("rating_max", OpenApiTypes.NUMBER, description="Maximum average rating"),
OpenApiParameter("ride_count_min", OpenApiTypes.INT, description="Minimum ride count"),
OpenApiParameter("ride_count_max", OpenApiTypes.INT, description="Maximum ride count"),
OpenApiParameter(
"coaster_count_min",
OpenApiTypes.INT,
@@ -688,9 +644,7 @@ class ParkPhotoViewSet(ModelViewSet):
"properties": {
"parks": {
"type": "array",
"items": {
"$ref": "#/components/schemas/HybridParkSerializer"
},
"items": {"$ref": "#/components/schemas/HybridParkSerializer"},
},
"total_count": {"type": "integer"},
"strategy": {
@@ -808,7 +762,7 @@ class HybridParkAPIView(APIView):
for param in int_params:
value = query_params.get(param)
if value:
try:
try: # noqa: SIM105
filters[param] = int(value)
except ValueError:
pass # Skip invalid integer values
@@ -818,7 +772,7 @@ class HybridParkAPIView(APIView):
for param in float_params:
value = query_params.get(param)
if value:
try:
try: # noqa: SIM105
filters[param] = float(value)
except ValueError:
pass # Skip invalid float values

View File

@@ -0,0 +1,167 @@
"""
Standardized API response helpers for ThrillWiki.
This module provides consistent response formatting across all API endpoints:
Success responses:
- Action completed: {"detail": "Success message"}
- With data: {"detail": "...", "data": {...}}
Error responses:
- Validation: {"field": ["error"]} (DRF default)
- Application: {"detail": "Error message", "code": "ERROR_CODE"}
Usage:
from apps.api.v1.responses import success_response, error_response
# Success
return success_response("Avatar saved successfully")
# Error
return error_response("User not found", code="NOT_FOUND", status_code=404)
"""
from rest_framework import status
from rest_framework.response import Response
# Standard error codes for machine-readable error handling
class ErrorCodes:
"""Standard error codes for API responses."""
# Authentication / Authorization
UNAUTHORIZED = "UNAUTHORIZED"
FORBIDDEN = "FORBIDDEN"
INVALID_CREDENTIALS = "INVALID_CREDENTIALS"
TOKEN_EXPIRED = "TOKEN_EXPIRED"
TOKEN_INVALID = "TOKEN_INVALID"
# Resource errors
NOT_FOUND = "NOT_FOUND"
ALREADY_EXISTS = "ALREADY_EXISTS"
CONFLICT = "CONFLICT"
# Validation errors
VALIDATION_ERROR = "VALIDATION_ERROR"
INVALID_INPUT = "INVALID_INPUT"
MISSING_FIELD = "MISSING_FIELD"
# Operation errors
OPERATION_FAILED = "OPERATION_FAILED"
PERMISSION_DENIED = "PERMISSION_DENIED"
RATE_LIMITED = "RATE_LIMITED"
# User-specific errors
USER_NOT_FOUND = "USER_NOT_FOUND"
USER_INACTIVE = "USER_INACTIVE"
USER_BANNED = "USER_BANNED"
CANNOT_DELETE_SUPERUSER = "CANNOT_DELETE_SUPERUSER"
CANNOT_DELETE_SELF = "CANNOT_DELETE_SELF"
# Verification errors
VERIFICATION_EXPIRED = "VERIFICATION_EXPIRED"
VERIFICATION_INVALID = "VERIFICATION_INVALID"
ALREADY_VERIFIED = "ALREADY_VERIFIED"
# External service errors
EXTERNAL_SERVICE_ERROR = "EXTERNAL_SERVICE_ERROR"
CLOUDFLARE_ERROR = "CLOUDFLARE_ERROR"
def success_response(
detail: str,
data: dict | None = None,
status_code: int = status.HTTP_200_OK,
) -> Response:
"""
Create a standardized success response.
Args:
detail: Human-readable success message
data: Optional additional data to include
status_code: HTTP status code (default 200)
Returns:
DRF Response object
Example:
return success_response("Avatar saved successfully")
return success_response("User created", data={"id": user.id}, status_code=201)
"""
response_data = {"detail": detail}
if data:
response_data.update(data)
return Response(response_data, status=status_code)
def error_response(
detail: str,
code: str | None = None,
status_code: int = status.HTTP_400_BAD_REQUEST,
extra: dict | None = None,
) -> Response:
"""
Create a standardized error response.
Args:
detail: Human-readable error message
code: Machine-readable error code from ErrorCodes
status_code: HTTP status code (default 400)
extra: Optional additional data to include
Returns:
DRF Response object
Example:
return error_response("User not found", code=ErrorCodes.NOT_FOUND, status_code=404)
return error_response("Invalid input", code=ErrorCodes.VALIDATION_ERROR)
"""
response_data = {"detail": detail}
if code:
response_data["code"] = code
if extra:
response_data.update(extra)
return Response(response_data, status=status_code)
def created_response(detail: str, data: dict | None = None) -> Response:
"""Convenience wrapper for 201 Created responses."""
return success_response(detail, data=data, status_code=status.HTTP_201_CREATED)
def not_found_response(detail: str = "Resource not found") -> Response:
"""Convenience wrapper for 404 Not Found responses."""
return error_response(
detail,
code=ErrorCodes.NOT_FOUND,
status_code=status.HTTP_404_NOT_FOUND,
)
def forbidden_response(detail: str = "Permission denied") -> Response:
"""Convenience wrapper for 403 Forbidden responses."""
return error_response(
detail,
code=ErrorCodes.FORBIDDEN,
status_code=status.HTTP_403_FORBIDDEN,
)
def unauthorized_response(detail: str = "Authentication required") -> Response:
"""Convenience wrapper for 401 Unauthorized responses."""
return error_response(
detail,
code=ErrorCodes.UNAUTHORIZED,
status_code=status.HTTP_401_UNAUTHORIZED,
)
__all__ = [
"ErrorCodes",
"success_response",
"error_response",
"created_response",
"not_found_response",
"forbidden_response",
"unauthorized_response",
]

View File

@@ -24,6 +24,7 @@ from apps.api.v1.serializers.companies import (
try:
from apps.rides.models.company import Company
MODELS_AVAILABLE = True
except ImportError:
Company = None
@@ -65,9 +66,7 @@ class CompanyListCreateAPIView(APIView):
# Search filter
search = request.query_params.get("search", "")
if search:
qs = qs.filter(
Q(name__icontains=search) | Q(description__icontains=search)
)
qs = qs.filter(Q(name__icontains=search) | Q(description__icontains=search))
# Role filter
role = request.query_params.get("role", "")
@@ -120,7 +119,7 @@ class CompanyDetailAPIView(APIView):
try:
return Company.objects.get(pk=pk)
except Company.DoesNotExist:
raise NotFound("Company not found")
raise NotFound("Company not found") from None
@extend_schema(
summary="Retrieve a company",

View File

@@ -93,18 +93,10 @@ class RideModelListCreateAPIView(APIView):
type=OpenApiTypes.STR,
required=True,
),
OpenApiParameter(
name="page", location=OpenApiParameter.QUERY, type=OpenApiTypes.INT
),
OpenApiParameter(
name="page_size", location=OpenApiParameter.QUERY, type=OpenApiTypes.INT
),
OpenApiParameter(
name="search", location=OpenApiParameter.QUERY, type=OpenApiTypes.STR
),
OpenApiParameter(
name="category", location=OpenApiParameter.QUERY, type=OpenApiTypes.STR
),
OpenApiParameter(name="page", location=OpenApiParameter.QUERY, type=OpenApiTypes.INT),
OpenApiParameter(name="page_size", location=OpenApiParameter.QUERY, type=OpenApiTypes.INT),
OpenApiParameter(name="search", location=OpenApiParameter.QUERY, type=OpenApiTypes.STR),
OpenApiParameter(name="category", location=OpenApiParameter.QUERY, type=OpenApiTypes.STR),
OpenApiParameter(
name="target_market",
location=OpenApiParameter.QUERY,
@@ -134,7 +126,7 @@ class RideModelListCreateAPIView(APIView):
try:
manufacturer = Company.objects.get(slug=manufacturer_slug)
except Company.DoesNotExist:
raise NotFound("Manufacturer not found")
raise NotFound("Manufacturer not found") from None
qs = (
RideModel.objects.filter(manufacturer=manufacturer)
@@ -176,13 +168,9 @@ class RideModelListCreateAPIView(APIView):
# Year filters
if filters.get("first_installation_year_min"):
qs = qs.filter(
first_installation_year__gte=filters["first_installation_year_min"]
)
qs = qs.filter(first_installation_year__gte=filters["first_installation_year_min"])
if filters.get("first_installation_year_max"):
qs = qs.filter(
first_installation_year__lte=filters["first_installation_year_max"]
)
qs = qs.filter(first_installation_year__lte=filters["first_installation_year_max"])
# Installation count filter
if filters.get("min_installations"):
@@ -190,23 +178,15 @@ class RideModelListCreateAPIView(APIView):
# Height filters
if filters.get("min_height_ft"):
qs = qs.filter(
typical_height_range_max_ft__gte=filters["min_height_ft"]
)
qs = qs.filter(typical_height_range_max_ft__gte=filters["min_height_ft"])
if filters.get("max_height_ft"):
qs = qs.filter(
typical_height_range_min_ft__lte=filters["max_height_ft"]
)
qs = qs.filter(typical_height_range_min_ft__lte=filters["max_height_ft"])
# Speed filters
if filters.get("min_speed_mph"):
qs = qs.filter(
typical_speed_range_max_mph__gte=filters["min_speed_mph"]
)
qs = qs.filter(typical_speed_range_max_mph__gte=filters["min_speed_mph"])
if filters.get("max_speed_mph"):
qs = qs.filter(
typical_speed_range_min_mph__lte=filters["max_speed_mph"]
)
qs = qs.filter(typical_speed_range_min_mph__lte=filters["max_speed_mph"])
# Ordering
ordering = filters.get("ordering", "manufacturer__name,name")
@@ -216,9 +196,7 @@ class RideModelListCreateAPIView(APIView):
paginator = StandardResultsSetPagination()
page = paginator.paginate_queryset(qs, request)
serializer = RideModelListOutputSerializer(
page, many=True, context={"request": request}
)
serializer = RideModelListOutputSerializer(page, many=True, context={"request": request})
return paginator.get_paginated_response(serializer.data)
@extend_schema(
@@ -240,9 +218,7 @@ class RideModelListCreateAPIView(APIView):
"""Create a new ride model for a specific manufacturer."""
if not MODELS_AVAILABLE:
return Response(
{
"detail": "Ride model creation is not available because domain models are not imported."
},
{"detail": "Ride model creation is not available because domain models are not imported."},
status=status.HTTP_501_NOT_IMPLEMENTED,
)
@@ -250,7 +226,7 @@ class RideModelListCreateAPIView(APIView):
try:
manufacturer = Company.objects.get(slug=manufacturer_slug)
except Company.DoesNotExist:
raise NotFound("Manufacturer not found")
raise NotFound("Manufacturer not found") from None
serializer_in = RideModelCreateInputSerializer(data=request.data)
serializer_in.is_valid(raise_exception=True)
@@ -279,18 +255,14 @@ class RideModelListCreateAPIView(APIView):
target_market=validated.get("target_market", ""),
)
out_serializer = RideModelDetailOutputSerializer(
ride_model, context={"request": request}
)
out_serializer = RideModelDetailOutputSerializer(ride_model, context={"request": request})
return Response(out_serializer.data, status=status.HTTP_201_CREATED)
class RideModelDetailAPIView(APIView):
permission_classes = [permissions.AllowAny]
def _get_ride_model_or_404(
self, manufacturer_slug: str, ride_model_slug: str
) -> Any:
def _get_ride_model_or_404(self, manufacturer_slug: str, ride_model_slug: str) -> Any:
if not MODELS_AVAILABLE:
raise NotFound("Ride model models not available")
try:
@@ -300,7 +272,7 @@ class RideModelDetailAPIView(APIView):
.get(manufacturer__slug=manufacturer_slug, slug=ride_model_slug)
)
except RideModel.DoesNotExist:
raise NotFound("Ride model not found")
raise NotFound("Ride model not found") from None
@extend_schema(
summary="Retrieve a ride model",
@@ -322,13 +294,9 @@ class RideModelDetailAPIView(APIView):
responses={200: RideModelDetailOutputSerializer()},
tags=["Ride Models"],
)
def get(
self, request: Request, manufacturer_slug: str, ride_model_slug: str
) -> Response:
def get(self, request: Request, manufacturer_slug: str, ride_model_slug: str) -> Response:
ride_model = self._get_ride_model_or_404(manufacturer_slug, ride_model_slug)
serializer = RideModelDetailOutputSerializer(
ride_model, context={"request": request}
)
serializer = RideModelDetailOutputSerializer(ride_model, context={"request": request})
return Response(serializer.data)
@extend_schema(
@@ -352,9 +320,7 @@ class RideModelDetailAPIView(APIView):
responses={200: RideModelDetailOutputSerializer()},
tags=["Ride Models"],
)
def patch(
self, request: Request, manufacturer_slug: str, ride_model_slug: str
) -> Response:
def patch(self, request: Request, manufacturer_slug: str, ride_model_slug: str) -> Response:
ride_model = self._get_ride_model_or_404(manufacturer_slug, ride_model_slug)
serializer_in = RideModelUpdateInputSerializer(data=request.data, partial=True)
serializer_in.is_valid(raise_exception=True)
@@ -366,20 +332,16 @@ class RideModelDetailAPIView(APIView):
manufacturer = Company.objects.get(id=value)
ride_model.manufacturer = manufacturer
except Company.DoesNotExist:
raise ValidationError({"manufacturer_id": "Manufacturer not found"})
raise ValidationError({"manufacturer_id": "Manufacturer not found"}) from None
else:
setattr(ride_model, field, value)
ride_model.save()
serializer = RideModelDetailOutputSerializer(
ride_model, context={"request": request}
)
serializer = RideModelDetailOutputSerializer(ride_model, context={"request": request})
return Response(serializer.data)
def put(
self, request: Request, manufacturer_slug: str, ride_model_slug: str
) -> Response:
def put(self, request: Request, manufacturer_slug: str, ride_model_slug: str) -> Response:
# Full replace - reuse patch behavior for simplicity
return self.patch(request, manufacturer_slug, ride_model_slug)
@@ -403,9 +365,7 @@ class RideModelDetailAPIView(APIView):
responses={204: None},
tags=["Ride Models"],
)
def delete(
self, request: Request, manufacturer_slug: str, ride_model_slug: str
) -> Response:
def delete(self, request: Request, manufacturer_slug: str, ride_model_slug: str) -> Response:
ride_model = self._get_ride_model_or_404(manufacturer_slug, ride_model_slug)
ride_model.delete()
return Response(status=status.HTTP_204_NO_CONTENT)
@@ -449,9 +409,7 @@ class RideModelSearchAPIView(APIView):
)
qs = RideModel.objects.filter(
Q(name__icontains=q)
| Q(description__icontains=q)
| Q(manufacturer__name__icontains=q)
Q(name__icontains=q) | Q(description__icontains=q) | Q(manufacturer__name__icontains=q)
).select_related("manufacturer")[:20]
results = [
@@ -491,8 +449,8 @@ class RideModelFilterOptionsAPIView(APIView):
# Use Rich Choice Objects for fallback options
try:
# Get rich choice objects from registry
categories = get_choices('categories', 'rides')
target_markets = get_choices('target_markets', 'rides')
categories = get_choices("categories", "rides")
target_markets = get_choices("target_markets", "rides")
# Convert Rich Choice Objects to frontend format with metadata
categories_data = [
@@ -500,10 +458,10 @@ class RideModelFilterOptionsAPIView(APIView):
"value": choice.value,
"label": choice.label,
"description": choice.description,
"color": choice.metadata.get('color'),
"icon": choice.metadata.get('icon'),
"css_class": choice.metadata.get('css_class'),
"sort_order": choice.metadata.get('sort_order', 0)
"color": choice.metadata.get("color"),
"icon": choice.metadata.get("icon"),
"css_class": choice.metadata.get("css_class"),
"sort_order": choice.metadata.get("sort_order", 0),
}
for choice in categories
]
@@ -513,10 +471,10 @@ class RideModelFilterOptionsAPIView(APIView):
"value": choice.value,
"label": choice.label,
"description": choice.description,
"color": choice.metadata.get('color'),
"icon": choice.metadata.get('icon'),
"css_class": choice.metadata.get('css_class'),
"sort_order": choice.metadata.get('sort_order', 0)
"color": choice.metadata.get("color"),
"icon": choice.metadata.get("icon"),
"css_class": choice.metadata.get("css_class"),
"sort_order": choice.metadata.get("sort_order", 0),
}
for choice in target_markets
]
@@ -524,25 +482,173 @@ class RideModelFilterOptionsAPIView(APIView):
except Exception:
# Ultimate fallback with basic structure
categories_data = [
{"value": "RC", "label": "Roller Coaster", "description": "High-speed thrill rides with tracks", "color": "red", "icon": "roller-coaster", "css_class": "bg-red-100 text-red-800", "sort_order": 1},
{"value": "DR", "label": "Dark Ride", "description": "Indoor themed experiences", "color": "purple", "icon": "dark-ride", "css_class": "bg-purple-100 text-purple-800", "sort_order": 2},
{"value": "FR", "label": "Flat Ride", "description": "Spinning and rotating attractions", "color": "blue", "icon": "flat-ride", "css_class": "bg-blue-100 text-blue-800", "sort_order": 3},
{"value": "WR", "label": "Water Ride", "description": "Water-based attractions and slides", "color": "cyan", "icon": "water-ride", "css_class": "bg-cyan-100 text-cyan-800", "sort_order": 4},
{"value": "TR", "label": "Transport", "description": "Transportation systems within parks", "color": "green", "icon": "transport", "css_class": "bg-green-100 text-green-800", "sort_order": 5},
{"value": "OT", "label": "Other", "description": "Miscellaneous attractions", "color": "gray", "icon": "other", "css_class": "bg-gray-100 text-gray-800", "sort_order": 6},
{
"value": "RC",
"label": "Roller Coaster",
"description": "High-speed thrill rides with tracks",
"color": "red",
"icon": "roller-coaster",
"css_class": "bg-red-100 text-red-800",
"sort_order": 1,
},
{
"value": "DR",
"label": "Dark Ride",
"description": "Indoor themed experiences",
"color": "purple",
"icon": "dark-ride",
"css_class": "bg-purple-100 text-purple-800",
"sort_order": 2,
},
{
"value": "FR",
"label": "Flat Ride",
"description": "Spinning and rotating attractions",
"color": "blue",
"icon": "flat-ride",
"css_class": "bg-blue-100 text-blue-800",
"sort_order": 3,
},
{
"value": "WR",
"label": "Water Ride",
"description": "Water-based attractions and slides",
"color": "cyan",
"icon": "water-ride",
"css_class": "bg-cyan-100 text-cyan-800",
"sort_order": 4,
},
{
"value": "TR",
"label": "Transport",
"description": "Transportation systems within parks",
"color": "green",
"icon": "transport",
"css_class": "bg-green-100 text-green-800",
"sort_order": 5,
},
{
"value": "OT",
"label": "Other",
"description": "Miscellaneous attractions",
"color": "gray",
"icon": "other",
"css_class": "bg-gray-100 text-gray-800",
"sort_order": 6,
},
]
target_markets_data = [
{"value": "FAMILY", "label": "Family", "description": "Suitable for all family members", "color": "green", "icon": "family", "css_class": "bg-green-100 text-green-800", "sort_order": 1},
{"value": "THRILL", "label": "Thrill", "description": "High-intensity thrill experience", "color": "orange", "icon": "thrill", "css_class": "bg-orange-100 text-orange-800", "sort_order": 2},
{"value": "EXTREME", "label": "Extreme", "description": "Maximum intensity experience", "color": "red", "icon": "extreme", "css_class": "bg-red-100 text-red-800", "sort_order": 3},
{"value": "KIDDIE", "label": "Kiddie", "description": "Designed for young children", "color": "pink", "icon": "kiddie", "css_class": "bg-pink-100 text-pink-800", "sort_order": 4},
{"value": "ALL_AGES", "label": "All Ages", "description": "Enjoyable for all age groups", "color": "blue", "icon": "all-ages", "css_class": "bg-blue-100 text-blue-800", "sort_order": 5},
{
"value": "FAMILY",
"label": "Family",
"description": "Suitable for all family members",
"color": "green",
"icon": "family",
"css_class": "bg-green-100 text-green-800",
"sort_order": 1,
},
{
"value": "THRILL",
"label": "Thrill",
"description": "High-intensity thrill experience",
"color": "orange",
"icon": "thrill",
"css_class": "bg-orange-100 text-orange-800",
"sort_order": 2,
},
{
"value": "EXTREME",
"label": "Extreme",
"description": "Maximum intensity experience",
"color": "red",
"icon": "extreme",
"css_class": "bg-red-100 text-red-800",
"sort_order": 3,
},
{
"value": "KIDDIE",
"label": "Kiddie",
"description": "Designed for young children",
"color": "pink",
"icon": "kiddie",
"css_class": "bg-pink-100 text-pink-800",
"sort_order": 4,
},
{
"value": "ALL_AGES",
"label": "All Ages",
"description": "Enjoyable for all age groups",
"color": "blue",
"icon": "all-ages",
"css_class": "bg-blue-100 text-blue-800",
"sort_order": 5,
},
]
return Response({
return Response(
{
"categories": categories_data,
"target_markets": target_markets_data,
"manufacturers": [{"id": 1, "name": "Bolliger & Mabillard", "slug": "bolliger-mabillard"}],
"ordering_options": [
{"value": "name", "label": "Name A-Z"},
{"value": "-name", "label": "Name Z-A"},
{"value": "manufacturer__name", "label": "Manufacturer A-Z"},
{"value": "-manufacturer__name", "label": "Manufacturer Z-A"},
{"value": "first_installation_year", "label": "Oldest First"},
{"value": "-first_installation_year", "label": "Newest First"},
{"value": "total_installations", "label": "Fewest Installations"},
{"value": "-total_installations", "label": "Most Installations"},
],
}
)
# Get static choice definitions from Rich Choice Objects (primary source)
# Get dynamic data from database queries
# Get rich choice objects from registry
categories = get_choices("categories", "rides")
target_markets = get_choices("target_markets", "rides")
# Convert Rich Choice Objects to frontend format with metadata
categories_data = [
{
"value": choice.value,
"label": choice.label,
"description": choice.description,
"color": choice.metadata.get("color"),
"icon": choice.metadata.get("icon"),
"css_class": choice.metadata.get("css_class"),
"sort_order": choice.metadata.get("sort_order", 0),
}
for choice in categories
]
target_markets_data = [
{
"value": choice.value,
"label": choice.label,
"description": choice.description,
"color": choice.metadata.get("color"),
"icon": choice.metadata.get("icon"),
"css_class": choice.metadata.get("css_class"),
"sort_order": choice.metadata.get("sort_order", 0),
}
for choice in target_markets
]
# Get actual data from database
manufacturers = (
Company.objects.filter(roles__contains=["MANUFACTURER"], ride_models__isnull=False)
.distinct()
.values("id", "name", "slug")
)
return Response(
{
"categories": categories_data,
"target_markets": target_markets_data,
"manufacturers": [{"id": 1, "name": "Bolliger & Mabillard", "slug": "bolliger-mabillard"}],
"manufacturers": list(manufacturers),
"ordering_options": [
{"value": "name", "label": "Name A-Z"},
{"value": "-name", "label": "Name Z-A"},
@@ -553,68 +659,9 @@ class RideModelFilterOptionsAPIView(APIView):
{"value": "total_installations", "label": "Fewest Installations"},
{"value": "-total_installations", "label": "Most Installations"},
],
})
# Get static choice definitions from Rich Choice Objects (primary source)
# Get dynamic data from database queries
# Get rich choice objects from registry
categories = get_choices('categories', 'rides')
target_markets = get_choices('target_markets', 'rides')
# Convert Rich Choice Objects to frontend format with metadata
categories_data = [
{
"value": choice.value,
"label": choice.label,
"description": choice.description,
"color": choice.metadata.get('color'),
"icon": choice.metadata.get('icon'),
"css_class": choice.metadata.get('css_class'),
"sort_order": choice.metadata.get('sort_order', 0)
}
for choice in categories
]
target_markets_data = [
{
"value": choice.value,
"label": choice.label,
"description": choice.description,
"color": choice.metadata.get('color'),
"icon": choice.metadata.get('icon'),
"css_class": choice.metadata.get('css_class'),
"sort_order": choice.metadata.get('sort_order', 0)
}
for choice in target_markets
]
# Get actual data from database
manufacturers = (
Company.objects.filter(
roles__contains=["MANUFACTURER"], ride_models__isnull=False
)
.distinct()
.values("id", "name", "slug")
)
return Response({
"categories": categories_data,
"target_markets": target_markets_data,
"manufacturers": list(manufacturers),
"ordering_options": [
{"value": "name", "label": "Name A-Z"},
{"value": "-name", "label": "Name Z-A"},
{"value": "manufacturer__name", "label": "Manufacturer A-Z"},
{"value": "-manufacturer__name", "label": "Manufacturer Z-A"},
{"value": "first_installation_year", "label": "Oldest First"},
{"value": "-first_installation_year", "label": "Newest First"},
{"value": "total_installations", "label": "Fewest Installations"},
{"value": "-total_installations", "label": "Most Installations"},
],
})
# === RIDE MODEL STATISTICS ===
@@ -646,37 +693,23 @@ class RideModelStatsAPIView(APIView):
# Calculate statistics
total_models = RideModel.objects.count()
total_installations = (
RideModel.objects.aggregate(total=Count("rides"))["total"] or 0
)
total_installations = RideModel.objects.aggregate(total=Count("rides"))["total"] or 0
active_manufacturers = (
Company.objects.filter(
roles__contains=["MANUFACTURER"], ride_models__isnull=False
)
.distinct()
.count()
Company.objects.filter(roles__contains=["MANUFACTURER"], ride_models__isnull=False).distinct().count()
)
discontinued_models = RideModel.objects.filter(is_discontinued=True).count()
# Category breakdown
by_category = {}
category_counts = (
RideModel.objects.exclude(category="")
.values("category")
.annotate(count=Count("id"))
)
category_counts = RideModel.objects.exclude(category="").values("category").annotate(count=Count("id"))
for item in category_counts:
by_category[item["category"]] = item["count"]
# Target market breakdown
by_target_market = {}
market_counts = (
RideModel.objects.exclude(target_market="")
.values("target_market")
.annotate(count=Count("id"))
)
market_counts = RideModel.objects.exclude(target_market="").values("target_market").annotate(count=Count("id"))
for item in market_counts:
by_target_market[item["target_market"]] = item["count"]
@@ -693,9 +726,7 @@ class RideModelStatsAPIView(APIView):
# Recent models (last 30 days)
thirty_days_ago = timezone.now() - timedelta(days=30)
recent_models = RideModel.objects.filter(
created_at__gte=thirty_days_ago
).count()
recent_models = RideModel.objects.filter(created_at__gte=thirty_days_ago).count()
return Response(
{
@@ -730,7 +761,7 @@ class RideModelVariantListCreateAPIView(APIView):
try:
ride_model = RideModel.objects.get(pk=ride_model_pk)
except RideModel.DoesNotExist:
raise NotFound("Ride model not found")
raise NotFound("Ride model not found") from None
variants = RideModelVariant.objects.filter(ride_model=ride_model)
serializer = RideModelVariantOutputSerializer(variants, many=True)
@@ -753,7 +784,7 @@ class RideModelVariantListCreateAPIView(APIView):
try:
ride_model = RideModel.objects.get(pk=ride_model_pk)
except RideModel.DoesNotExist:
raise NotFound("Ride model not found")
raise NotFound("Ride model not found") from None
# Override ride_model_id in the data
data = request.data.copy()
@@ -787,7 +818,7 @@ class RideModelVariantDetailAPIView(APIView):
try:
return RideModelVariant.objects.get(ride_model_id=ride_model_pk, pk=pk)
except RideModelVariant.DoesNotExist:
raise NotFound("Variant not found")
raise NotFound("Variant not found") from None
@extend_schema(
summary="Get a ride model variant",
@@ -807,9 +838,7 @@ class RideModelVariantDetailAPIView(APIView):
)
def patch(self, request: Request, ride_model_pk: int, pk: int) -> Response:
variant = self._get_variant_or_404(ride_model_pk, pk)
serializer_in = RideModelVariantUpdateInputSerializer(
data=request.data, partial=True
)
serializer_in = RideModelVariantUpdateInputSerializer(data=request.data, partial=True)
serializer_in.is_valid(raise_exception=True)
for field, value in serializer_in.validated_data.items():

View File

@@ -118,9 +118,7 @@ class RidePhotoViewSet(ModelViewSet):
def get_queryset(self): # type: ignore[override]
"""Get photos for the current ride with optimized queries."""
queryset = RidePhoto.objects.select_related(
"ride", "ride__park", "ride__park__operator", "uploaded_by"
)
queryset = RidePhoto.objects.select_related("ride", "ride__park", "ride__park__operator", "uploaded_by")
# If ride_pk is provided in URL kwargs, filter by ride
ride_pk = self.kwargs.get("ride_pk")
@@ -149,7 +147,7 @@ class RidePhotoViewSet(ModelViewSet):
try:
ride = Ride.objects.get(pk=ride_id)
except Ride.DoesNotExist:
raise ValidationError("Ride not found")
raise ValidationError("Ride not found") from None
try:
# Use the service to create the photo with proper business logic
@@ -169,17 +167,14 @@ class RidePhotoViewSet(ModelViewSet):
except Exception as e:
logger.error(f"Error creating ride photo: {e}")
raise ValidationError(f"Failed to create photo: {str(e)}")
raise ValidationError(f"Failed to create photo: {str(e)}") from None
def perform_update(self, serializer):
"""Update ride photo with permission checking."""
instance = self.get_object()
# Check permissions - allow owner or staff
if not (
self.request.user == instance.uploaded_by
or getattr(self.request.user, "is_staff", False)
):
if not (self.request.user == instance.uploaded_by or getattr(self.request.user, "is_staff", False)):
raise PermissionDenied("You can only edit your own photos or be an admin.")
# Handle primary photo logic using service
@@ -191,39 +186,31 @@ class RidePhotoViewSet(ModelViewSet):
del serializer.validated_data["is_primary"]
except Exception as e:
logger.error(f"Error setting primary photo: {e}")
raise ValidationError(f"Failed to set primary photo: {str(e)}")
raise ValidationError(f"Failed to set primary photo: {str(e)}") from None
def perform_destroy(self, instance):
"""Delete ride photo with permission checking."""
# Check permissions - allow owner or staff
if not (
self.request.user == instance.uploaded_by
or getattr(self.request.user, "is_staff", False)
):
raise PermissionDenied(
"You can only delete your own photos or be an admin."
)
if not (self.request.user == instance.uploaded_by or getattr(self.request.user, "is_staff", False)):
raise PermissionDenied("You can only delete your own photos or be an admin.")
try:
# Delete from Cloudflare first if image exists
if instance.image:
try:
from django_cloudflareimages_toolkit.services import CloudflareImagesService
service = CloudflareImagesService()
service.delete_image(instance.image)
logger.info(
f"Successfully deleted ride photo from Cloudflare: {instance.image.cloudflare_id}")
logger.info(f"Successfully deleted ride photo from Cloudflare: {instance.image.cloudflare_id}")
except Exception as e:
logger.error(
f"Failed to delete ride photo from Cloudflare: {str(e)}")
logger.error(f"Failed to delete ride photo from Cloudflare: {str(e)}")
# Continue with database deletion even if Cloudflare deletion fails
RideMediaService.delete_photo(
instance, deleted_by=self.request.user # type: ignore
)
RideMediaService.delete_photo(instance, deleted_by=self.request.user) # type: ignore
except Exception as e:
logger.error(f"Error deleting ride photo: {e}")
raise ValidationError(f"Failed to delete photo: {str(e)}")
raise ValidationError(f"Failed to delete photo: {str(e)}") from None
@extend_schema(
summary="Set photo as primary",
@@ -242,13 +229,8 @@ class RidePhotoViewSet(ModelViewSet):
photo = self.get_object()
# Check permissions - allow owner or staff
if not (
request.user == photo.uploaded_by
or getattr(request.user, "is_staff", False)
):
raise PermissionDenied(
"You can only modify your own photos or be an admin."
)
if not (request.user == photo.uploaded_by or getattr(request.user, "is_staff", False)):
raise PermissionDenied("You can only modify your own photos or be an admin.")
try:
success = RideMediaService.set_primary_photo(ride=photo.ride, photo=photo)
@@ -260,21 +242,21 @@ class RidePhotoViewSet(ModelViewSet):
return Response(
{
"message": "Photo set as primary successfully",
"detail": "Photo set as primary successfully",
"photo": serializer.data,
},
status=status.HTTP_200_OK,
)
else:
return Response(
{"error": "Failed to set primary photo"},
{"detail": "Failed to set primary photo"},
status=status.HTTP_400_BAD_REQUEST,
)
except Exception as e:
logger.error(f"Error setting primary photo: {e}")
return Response(
{"error": f"Failed to set primary photo: {str(e)}"},
{"detail": f"Failed to set primary photo: {str(e)}"},
status=status.HTTP_400_BAD_REQUEST,
)
@@ -305,7 +287,7 @@ class RidePhotoViewSet(ModelViewSet):
if photo_ids is None or approve is None:
return Response(
{"error": "Missing required fields: photo_ids and/or approve."},
{"detail": "Missing required fields: photo_ids and/or approve."},
status=status.HTTP_400_BAD_REQUEST,
)
@@ -319,7 +301,7 @@ class RidePhotoViewSet(ModelViewSet):
return Response(
{
"message": f"Successfully {'approved' if approve else 'rejected'} {updated_count} photos",
"detail": f"Successfully {'approved' if approve else 'rejected'} {updated_count} photos",
"updated_count": updated_count,
},
status=status.HTTP_200_OK,
@@ -328,7 +310,7 @@ class RidePhotoViewSet(ModelViewSet):
except Exception as e:
logger.error(f"Error in bulk photo approval: {e}")
return Response(
{"error": f"Failed to update photos: {str(e)}"},
{"detail": f"Failed to update photos: {str(e)}"},
status=status.HTTP_400_BAD_REQUEST,
)
@@ -352,7 +334,7 @@ class RidePhotoViewSet(ModelViewSet):
ride = Ride.objects.get(pk=ride_pk)
except Ride.DoesNotExist:
return Response(
{"error": "Ride not found."},
{"detail": "Ride not found."},
status=status.HTTP_404_NOT_FOUND,
)
@@ -363,16 +345,10 @@ class RidePhotoViewSet(ModelViewSet):
# Global stats across all rides
stats = {
"total_photos": RidePhoto.objects.count(),
"approved_photos": RidePhoto.objects.filter(
is_approved=True
).count(),
"pending_photos": RidePhoto.objects.filter(
is_approved=False
).count(),
"approved_photos": RidePhoto.objects.filter(is_approved=True).count(),
"pending_photos": RidePhoto.objects.filter(is_approved=False).count(),
"has_primary": False, # Not applicable for global stats
"recent_uploads": RidePhoto.objects.order_by("-created_at")[
:5
].count(),
"recent_uploads": RidePhoto.objects.order_by("-created_at")[:5].count(),
"by_type": {},
}
@@ -382,7 +358,7 @@ class RidePhotoViewSet(ModelViewSet):
except Exception as e:
logger.error(f"Error getting ride photo stats: {e}")
return Response(
{"error": f"Failed to get photo statistics: {str(e)}"},
{"detail": f"Failed to get photo statistics: {str(e)}"},
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
@@ -401,26 +377,23 @@ class RidePhotoViewSet(ModelViewSet):
def set_primary_legacy(self, request, id=None):
"""Legacy set primary action for backwards compatibility."""
photo = self.get_object()
if not (
request.user == photo.uploaded_by
or request.user.has_perm("rides.change_ridephoto")
):
if not (request.user == photo.uploaded_by or request.user.has_perm("rides.change_ridephoto")):
return Response(
{"error": "You do not have permission to edit photos for this ride."},
{"detail": "You do not have permission to edit photos for this ride."},
status=status.HTTP_403_FORBIDDEN,
)
try:
success = RideMediaService.set_primary_photo(ride=photo.ride, photo=photo)
if success:
return Response({"message": "Photo set as primary successfully."})
return Response({"detail": "Photo set as primary successfully."})
else:
return Response(
{"error": "Failed to set primary photo"},
{"detail": "Failed to set primary photo"},
status=status.HTTP_400_BAD_REQUEST,
)
except Exception as e:
logger.error(f"Error in set_primary_photo: {str(e)}", exc_info=True)
return Response({"error": str(e)}, status=status.HTTP_400_BAD_REQUEST)
return Response({"detail": str(e)}, status=status.HTTP_400_BAD_REQUEST)
@extend_schema(
summary="Save Cloudflare image as ride photo",
@@ -440,7 +413,7 @@ class RidePhotoViewSet(ModelViewSet):
ride_pk = self.kwargs.get("ride_pk")
if not ride_pk:
return Response(
{"error": "Ride ID is required"},
{"detail": "Ride ID is required"},
status=status.HTTP_400_BAD_REQUEST,
)
@@ -448,14 +421,14 @@ class RidePhotoViewSet(ModelViewSet):
ride = Ride.objects.get(pk=ride_pk)
except Ride.DoesNotExist:
return Response(
{"error": "Ride not found"},
{"detail": "Ride not found"},
status=status.HTTP_404_NOT_FOUND,
)
cloudflare_image_id = request.data.get("cloudflare_image_id")
if not cloudflare_image_id:
return Response(
{"error": "cloudflare_image_id is required"},
{"detail": "cloudflare_image_id is required"},
status=status.HTTP_400_BAD_REQUEST,
)
@@ -473,27 +446,25 @@ class RidePhotoViewSet(ModelViewSet):
if not image_data:
return Response(
{"error": "Image not found in Cloudflare"},
{"detail": "Image not found in Cloudflare"},
status=status.HTTP_400_BAD_REQUEST,
)
# Try to find existing CloudflareImage record by cloudflare_id
cloudflare_image = None
try:
cloudflare_image = CloudflareImage.objects.get(
cloudflare_id=cloudflare_image_id)
cloudflare_image = CloudflareImage.objects.get(cloudflare_id=cloudflare_image_id)
# Update existing record with latest data from Cloudflare
cloudflare_image.status = 'uploaded'
cloudflare_image.status = "uploaded"
cloudflare_image.uploaded_at = timezone.now()
cloudflare_image.metadata = image_data.get('meta', {})
cloudflare_image.metadata = image_data.get("meta", {})
# Extract variants from nested result structure
cloudflare_image.variants = image_data.get(
'result', {}).get('variants', [])
cloudflare_image.variants = image_data.get("result", {}).get("variants", [])
cloudflare_image.cloudflare_metadata = image_data
cloudflare_image.width = image_data.get('width')
cloudflare_image.height = image_data.get('height')
cloudflare_image.format = image_data.get('format', '')
cloudflare_image.width = image_data.get("width")
cloudflare_image.height = image_data.get("height")
cloudflare_image.format = image_data.get("format", "")
cloudflare_image.save()
except CloudflareImage.DoesNotExist:
@@ -501,24 +472,23 @@ class RidePhotoViewSet(ModelViewSet):
cloudflare_image = CloudflareImage.objects.create(
cloudflare_id=cloudflare_image_id,
user=request.user,
status='uploaded',
upload_url='', # Not needed for uploaded images
status="uploaded",
upload_url="", # Not needed for uploaded images
expires_at=timezone.now() + timezone.timedelta(days=365), # Set far future expiry
uploaded_at=timezone.now(),
metadata=image_data.get('meta', {}),
metadata=image_data.get("meta", {}),
# Extract variants from nested result structure
variants=image_data.get('result', {}).get('variants', []),
variants=image_data.get("result", {}).get("variants", []),
cloudflare_metadata=image_data,
width=image_data.get('width'),
height=image_data.get('height'),
format=image_data.get('format', ''),
width=image_data.get("width"),
height=image_data.get("height"),
format=image_data.get("format", ""),
)
except Exception as api_error:
logger.error(
f"Error fetching image from Cloudflare API: {str(api_error)}", exc_info=True)
logger.error(f"Error fetching image from Cloudflare API: {str(api_error)}", exc_info=True)
return Response(
{"error": f"Failed to fetch image from Cloudflare: {str(api_error)}"},
{"detail": f"Failed to fetch image from Cloudflare: {str(api_error)}"},
status=status.HTTP_400_BAD_REQUEST,
)
@@ -548,6 +518,6 @@ class RidePhotoViewSet(ModelViewSet):
except Exception as e:
logger.error(f"Error saving ride photo: {e}")
return Response(
{"error": f"Failed to save photo: {str(e)}"},
{"detail": f"Failed to save photo: {str(e)}"},
status=status.HTTP_400_BAD_REQUEST,
)

View File

@@ -52,18 +52,14 @@ from apps.rides.models import Ride, RidePhoto
class RidePhotoOutputSerializer(serializers.ModelSerializer):
"""Output serializer for ride photos with Cloudflare Images support."""
uploaded_by_username = serializers.CharField(
source="uploaded_by.username", read_only=True
)
uploaded_by_username = serializers.CharField(source="uploaded_by.username", read_only=True)
file_size = serializers.SerializerMethodField()
dimensions = serializers.SerializerMethodField()
image_url = serializers.SerializerMethodField()
image_variants = serializers.SerializerMethodField()
@extend_schema_field(
serializers.IntegerField(allow_null=True, help_text="File size in bytes")
)
@extend_schema_field(serializers.IntegerField(allow_null=True, help_text="File size in bytes"))
def get_file_size(self, obj):
"""Get file size in bytes."""
return obj.file_size
@@ -81,11 +77,7 @@ class RidePhotoOutputSerializer(serializers.ModelSerializer):
"""Get image dimensions as [width, height]."""
return obj.dimensions
@extend_schema_field(
serializers.URLField(
help_text="Full URL to the Cloudflare Images asset", allow_null=True
)
)
@extend_schema_field(serializers.URLField(help_text="Full URL to the Cloudflare Images asset", allow_null=True))
def get_image_url(self, obj):
"""Get the full Cloudflare Images URL."""
if obj.image:
@@ -186,9 +178,7 @@ class RidePhotoUpdateInputSerializer(serializers.ModelSerializer):
class RidePhotoListOutputSerializer(serializers.ModelSerializer):
"""Simplified output serializer for ride photo lists."""
uploaded_by_username = serializers.CharField(
source="uploaded_by.username", read_only=True
)
uploaded_by_username = serializers.CharField(source="uploaded_by.username", read_only=True)
class Meta:
model = RidePhoto
@@ -208,12 +198,8 @@ class RidePhotoListOutputSerializer(serializers.ModelSerializer):
class RidePhotoApprovalInputSerializer(serializers.Serializer):
"""Input serializer for photo approval operations."""
photo_ids = serializers.ListField(
child=serializers.IntegerField(), help_text="List of photo IDs to approve"
)
approve = serializers.BooleanField(
default=True, help_text="Whether to approve (True) or reject (False) the photos"
)
photo_ids = serializers.ListField(child=serializers.IntegerField(), help_text="List of photo IDs to approve")
approve = serializers.BooleanField(default=True, help_text="Whether to approve (True) or reject (False) the photos")
class RidePhotoStatsOutputSerializer(serializers.Serializer):
@@ -224,9 +210,7 @@ class RidePhotoStatsOutputSerializer(serializers.Serializer):
pending_photos = serializers.IntegerField()
has_primary = serializers.BooleanField()
recent_uploads = serializers.IntegerField()
by_type = serializers.DictField(
child=serializers.IntegerField(), help_text="Photo counts by type"
)
by_type = serializers.DictField(child=serializers.IntegerField(), help_text="Photo counts by type")
class RidePhotoTypeFilterSerializer(serializers.Serializer):
@@ -292,8 +276,12 @@ class HybridRideSerializer(serializers.ModelSerializer):
ride_model_name = serializers.CharField(source="ride_model.name", read_only=True, allow_null=True)
ride_model_slug = serializers.CharField(source="ride_model.slug", read_only=True, allow_null=True)
ride_model_category = serializers.CharField(source="ride_model.category", read_only=True, allow_null=True)
ride_model_manufacturer_name = serializers.CharField(source="ride_model.manufacturer.name", read_only=True, allow_null=True)
ride_model_manufacturer_slug = serializers.CharField(source="ride_model.manufacturer.slug", read_only=True, allow_null=True)
ride_model_manufacturer_name = serializers.CharField(
source="ride_model.manufacturer.name", read_only=True, allow_null=True
)
ride_model_manufacturer_slug = serializers.CharField(
source="ride_model.manufacturer.slug", read_only=True, allow_null=True
)
# Roller coaster stats fields
coaster_height_ft = serializers.SerializerMethodField()
@@ -323,7 +311,7 @@ class HybridRideSerializer(serializers.ModelSerializer):
def get_park_city(self, obj):
"""Get city from park location."""
try:
if obj.park and hasattr(obj.park, 'location') and obj.park.location:
if obj.park and hasattr(obj.park, "location") and obj.park.location:
return obj.park.location.city
return None
except AttributeError:
@@ -333,7 +321,7 @@ class HybridRideSerializer(serializers.ModelSerializer):
def get_park_state(self, obj):
"""Get state from park location."""
try:
if obj.park and hasattr(obj.park, 'location') and obj.park.location:
if obj.park and hasattr(obj.park, "location") and obj.park.location:
return obj.park.location.state
return None
except AttributeError:
@@ -343,7 +331,7 @@ class HybridRideSerializer(serializers.ModelSerializer):
def get_park_country(self, obj):
"""Get country from park location."""
try:
if obj.park and hasattr(obj.park, 'location') and obj.park.location:
if obj.park and hasattr(obj.park, "location") and obj.park.location:
return obj.park.location.country
return None
except AttributeError:
@@ -353,7 +341,7 @@ class HybridRideSerializer(serializers.ModelSerializer):
def get_coaster_height_ft(self, obj):
"""Get roller coaster height."""
try:
if hasattr(obj, 'coaster_stats') and obj.coaster_stats:
if hasattr(obj, "coaster_stats") and obj.coaster_stats:
return float(obj.coaster_stats.height_ft) if obj.coaster_stats.height_ft else None
return None
except (AttributeError, TypeError):
@@ -363,7 +351,7 @@ class HybridRideSerializer(serializers.ModelSerializer):
def get_coaster_length_ft(self, obj):
"""Get roller coaster length."""
try:
if hasattr(obj, 'coaster_stats') and obj.coaster_stats:
if hasattr(obj, "coaster_stats") and obj.coaster_stats:
return float(obj.coaster_stats.length_ft) if obj.coaster_stats.length_ft else None
return None
except (AttributeError, TypeError):
@@ -373,7 +361,7 @@ class HybridRideSerializer(serializers.ModelSerializer):
def get_coaster_speed_mph(self, obj):
"""Get roller coaster speed."""
try:
if hasattr(obj, 'coaster_stats') and obj.coaster_stats:
if hasattr(obj, "coaster_stats") and obj.coaster_stats:
return float(obj.coaster_stats.speed_mph) if obj.coaster_stats.speed_mph else None
return None
except (AttributeError, TypeError):
@@ -383,7 +371,7 @@ class HybridRideSerializer(serializers.ModelSerializer):
def get_coaster_inversions(self, obj):
"""Get roller coaster inversions."""
try:
if hasattr(obj, 'coaster_stats') and obj.coaster_stats:
if hasattr(obj, "coaster_stats") and obj.coaster_stats:
return obj.coaster_stats.inversions
return None
except AttributeError:
@@ -393,7 +381,7 @@ class HybridRideSerializer(serializers.ModelSerializer):
def get_coaster_ride_time_seconds(self, obj):
"""Get roller coaster ride time."""
try:
if hasattr(obj, 'coaster_stats') and obj.coaster_stats:
if hasattr(obj, "coaster_stats") and obj.coaster_stats:
return obj.coaster_stats.ride_time_seconds
return None
except AttributeError:
@@ -403,7 +391,7 @@ class HybridRideSerializer(serializers.ModelSerializer):
def get_coaster_track_type(self, obj):
"""Get roller coaster track type."""
try:
if hasattr(obj, 'coaster_stats') and obj.coaster_stats:
if hasattr(obj, "coaster_stats") and obj.coaster_stats:
return obj.coaster_stats.track_type
return None
except AttributeError:
@@ -413,7 +401,7 @@ class HybridRideSerializer(serializers.ModelSerializer):
def get_coaster_track_material(self, obj):
"""Get roller coaster track material."""
try:
if hasattr(obj, 'coaster_stats') and obj.coaster_stats:
if hasattr(obj, "coaster_stats") and obj.coaster_stats:
return obj.coaster_stats.track_material
return None
except AttributeError:
@@ -423,7 +411,7 @@ class HybridRideSerializer(serializers.ModelSerializer):
def get_coaster_roller_coaster_type(self, obj):
"""Get roller coaster type."""
try:
if hasattr(obj, 'coaster_stats') and obj.coaster_stats:
if hasattr(obj, "coaster_stats") and obj.coaster_stats:
return obj.coaster_stats.roller_coaster_type
return None
except AttributeError:
@@ -433,7 +421,7 @@ class HybridRideSerializer(serializers.ModelSerializer):
def get_coaster_max_drop_height_ft(self, obj):
"""Get roller coaster max drop height."""
try:
if hasattr(obj, 'coaster_stats') and obj.coaster_stats:
if hasattr(obj, "coaster_stats") and obj.coaster_stats:
return float(obj.coaster_stats.max_drop_height_ft) if obj.coaster_stats.max_drop_height_ft else None
return None
except (AttributeError, TypeError):
@@ -443,7 +431,7 @@ class HybridRideSerializer(serializers.ModelSerializer):
def get_coaster_propulsion_system(self, obj):
"""Get roller coaster propulsion system."""
try:
if hasattr(obj, 'coaster_stats') and obj.coaster_stats:
if hasattr(obj, "coaster_stats") and obj.coaster_stats:
return obj.coaster_stats.propulsion_system
return None
except AttributeError:
@@ -453,7 +441,7 @@ class HybridRideSerializer(serializers.ModelSerializer):
def get_coaster_train_style(self, obj):
"""Get roller coaster train style."""
try:
if hasattr(obj, 'coaster_stats') and obj.coaster_stats:
if hasattr(obj, "coaster_stats") and obj.coaster_stats:
return obj.coaster_stats.train_style
return None
except AttributeError:
@@ -463,7 +451,7 @@ class HybridRideSerializer(serializers.ModelSerializer):
def get_coaster_trains_count(self, obj):
"""Get roller coaster trains count."""
try:
if hasattr(obj, 'coaster_stats') and obj.coaster_stats:
if hasattr(obj, "coaster_stats") and obj.coaster_stats:
return obj.coaster_stats.trains_count
return None
except AttributeError:
@@ -473,7 +461,7 @@ class HybridRideSerializer(serializers.ModelSerializer):
def get_coaster_cars_per_train(self, obj):
"""Get roller coaster cars per train."""
try:
if hasattr(obj, 'coaster_stats') and obj.coaster_stats:
if hasattr(obj, "coaster_stats") and obj.coaster_stats:
return obj.coaster_stats.cars_per_train
return None
except AttributeError:
@@ -483,7 +471,7 @@ class HybridRideSerializer(serializers.ModelSerializer):
def get_coaster_seats_per_car(self, obj):
"""Get roller coaster seats per car."""
try:
if hasattr(obj, 'coaster_stats') and obj.coaster_stats:
if hasattr(obj, "coaster_stats") and obj.coaster_stats:
return obj.coaster_stats.seats_per_car
return None
except AttributeError:
@@ -514,44 +502,37 @@ class HybridRideSerializer(serializers.ModelSerializer):
"category",
"status",
"post_closing_status",
# Dates and computed fields
"opening_date",
"closing_date",
"status_since",
"opening_year",
# Park fields
"park_name",
"park_slug",
"park_city",
"park_state",
"park_country",
# Park area fields
"park_area_name",
"park_area_slug",
# Company fields
"manufacturer_name",
"manufacturer_slug",
"designer_name",
"designer_slug",
# Ride model fields
"ride_model_name",
"ride_model_slug",
"ride_model_category",
"ride_model_manufacturer_name",
"ride_model_manufacturer_slug",
# Ride specifications
"min_height_in",
"max_height_in",
"capacity_per_hour",
"ride_duration_seconds",
"average_rating",
# Roller coaster stats
"coaster_height_ft",
"coaster_length_ft",
@@ -567,18 +548,14 @@ class HybridRideSerializer(serializers.ModelSerializer):
"coaster_trains_count",
"coaster_cars_per_train",
"coaster_seats_per_car",
# Images
"banner_image_url",
"card_image_url",
# URLs
"url",
"park_url",
# Computed fields for filtering
"search_text",
# Metadata
"created_at",
"updated_at",

View File

@@ -35,11 +35,9 @@ app_name = "api_v1_rides"
urlpatterns = [
# Core list/create endpoints
path("", RideListCreateAPIView.as_view(), name="ride-list-create"),
# Hybrid filtering endpoints
path("hybrid/", HybridRideAPIView.as_view(), name="ride-hybrid-filtering"),
path("hybrid/filter-metadata/", RideFilterMetadataAPIView.as_view(), name="ride-hybrid-filter-metadata"),
# Filter options
path("filter-options/", FilterOptionsAPIView.as_view(), name="ride-filter-options"),
# Autocomplete / suggestion endpoints
@@ -61,7 +59,6 @@ urlpatterns = [
# Manufacturer and Designer endpoints
path("manufacturers/", ManufacturerListAPIView.as_view(), name="manufacturer-list"),
path("designers/", DesignerListAPIView.as_view(), name="designer-list"),
# Ride model management endpoints - nested under rides/manufacturers
path(
"manufacturers/<slug:manufacturer_slug>/",

View File

@@ -28,6 +28,7 @@ import logging
from typing import Any
from django.db import models
from django.db.models import Count
from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import OpenApiParameter, extend_schema, extend_schema_view
from rest_framework import permissions, status
@@ -333,9 +334,7 @@ class RideListCreateAPIView(APIView):
paginator = StandardResultsSetPagination()
page = paginator.paginate_queryset(qs, request)
serializer = RideListOutputSerializer(
page, many=True, context={"request": request}
)
serializer = RideListOutputSerializer(page, many=True, context={"request": request})
return paginator.get_paginated_response(serializer.data)
def _apply_filters(self, qs, params):
@@ -567,9 +566,9 @@ class RideListCreateAPIView(APIView):
if ordering in valid_orderings:
if ordering in ["height_ft", "-height_ft", "speed_mph", "-speed_mph"]:
# For coaster stats ordering, we need to join and order by the stats
ordering_field = ordering.replace(
"height_ft", "coaster_stats__height_ft"
).replace("speed_mph", "coaster_stats__speed_mph")
ordering_field = ordering.replace("height_ft", "coaster_stats__height_ft").replace(
"speed_mph", "coaster_stats__speed_mph"
)
qs = qs.order_by(ordering_field)
else:
qs = qs.order_by(ordering)
@@ -602,7 +601,7 @@ class RideListCreateAPIView(APIView):
try:
park = Park.objects.get(id=validated["park_id"]) # type: ignore
except Park.DoesNotExist: # type: ignore
raise NotFound("Park not found")
raise NotFound("Park not found") from None
ride = Ride.objects.create( # type: ignore
name=validated["name"],
@@ -658,7 +657,7 @@ class RideDetailAPIView(APIView):
try:
return Ride.objects.select_related("park").get(pk=pk) # type: ignore
except Ride.DoesNotExist: # type: ignore
raise NotFound("Ride not found")
raise NotFound("Ride not found") from None
@cache_api_response(timeout=1800, key_prefix="ride_detail")
def get(self, request: Request, pk: int) -> Response:
@@ -672,9 +671,7 @@ class RideDetailAPIView(APIView):
serializer_in.is_valid(raise_exception=True)
if not MODELS_AVAILABLE:
return Response(
{
"detail": "Ride update is not available because domain models are not imported."
},
{"detail": "Ride update is not available because domain models are not imported."},
status=status.HTTP_501_NOT_IMPLEMENTED,
)
@@ -690,7 +687,7 @@ class RideDetailAPIView(APIView):
# Use the move_to_park method for proper handling
park_change_info = ride.move_to_park(new_park)
except Park.DoesNotExist: # type: ignore
raise NotFound("Target park not found")
raise NotFound("Target park not found") from None
# Apply other field updates
for key, value in validated_data.items():
@@ -715,9 +712,7 @@ class RideDetailAPIView(APIView):
def delete(self, request: Request, pk: int) -> Response:
if not MODELS_AVAILABLE:
return Response(
{
"detail": "Ride delete is not available because domain models are not imported."
},
{"detail": "Ride delete is not available because domain models are not imported."},
status=status.HTTP_501_NOT_IMPLEMENTED,
)
ride = self._get_ride_or_404(pk)
@@ -1491,16 +1486,12 @@ class FilterOptionsAPIView(APIView):
# Get manufacturers (companies with MANUFACTURER role)
manufacturers = list(
Company.objects.filter(roles__contains=["MANUFACTURER"])
.values("id", "name", "slug")
.order_by("name")
Company.objects.filter(roles__contains=["MANUFACTURER"]).values("id", "name", "slug").order_by("name")
)
# Get designers (companies with DESIGNER role)
designers = list(
Company.objects.filter(roles__contains=["DESIGNER"])
.values("id", "name", "slug")
.order_by("name")
Company.objects.filter(roles__contains=["DESIGNER"]).values("id", "name", "slug").order_by("name")
)
# Get ride models data from database
@@ -1722,11 +1713,7 @@ class FilterOptionsAPIView(APIView):
# --- Company search (autocomplete) -----------------------------------------
@extend_schema(
summary="Search companies (manufacturers/designers) for autocomplete",
parameters=[
OpenApiParameter(
name="q", location=OpenApiParameter.QUERY, type=OpenApiTypes.STR
)
],
parameters=[OpenApiParameter(name="q", location=OpenApiParameter.QUERY, type=OpenApiTypes.STR)],
responses={200: OpenApiTypes.OBJECT},
tags=["Rides"],
)
@@ -1753,20 +1740,14 @@ class CompanySearchAPIView(APIView):
)
qs = Company.objects.filter(name__icontains=q)[:20] # type: ignore
results = [
{"id": c.id, "name": c.name, "slug": getattr(c, "slug", "")} for c in qs
]
results = [{"id": c.id, "name": c.name, "slug": getattr(c, "slug", "")} for c in qs]
return Response(results)
# --- Ride model search (autocomplete) --------------------------------------
@extend_schema(
summary="Search ride models for autocomplete",
parameters=[
OpenApiParameter(
name="q", location=OpenApiParameter.QUERY, type=OpenApiTypes.STR
)
],
parameters=[OpenApiParameter(name="q", location=OpenApiParameter.QUERY, type=OpenApiTypes.STR)],
tags=["Rides"],
)
class RideModelSearchAPIView(APIView):
@@ -1795,21 +1776,14 @@ class RideModelSearchAPIView(APIView):
)
qs = RideModel.objects.filter(name__icontains=q)[:20] # type: ignore
results = [
{"id": m.id, "name": m.name, "category": getattr(m, "category", "")}
for m in qs
]
results = [{"id": m.id, "name": m.name, "category": getattr(m, "category", "")} for m in qs]
return Response(results)
# --- Search suggestions -----------------------------------------------------
@extend_schema(
summary="Search suggestions for ride search box",
parameters=[
OpenApiParameter(
name="q", location=OpenApiParameter.QUERY, type=OpenApiTypes.STR
)
],
parameters=[OpenApiParameter(name="q", location=OpenApiParameter.QUERY, type=OpenApiTypes.STR)],
tags=["Rides"],
)
class RideSearchSuggestionsAPIView(APIView):
@@ -1827,9 +1801,7 @@ class RideSearchSuggestionsAPIView(APIView):
# Very small suggestion implementation: look in ride names if available
if MODELS_AVAILABLE and Ride is not None:
qs = Ride.objects.filter(name__icontains=q).values_list("name", flat=True)[
:10
] # type: ignore
qs = Ride.objects.filter(name__icontains=q).values_list("name", flat=True)[:10] # type: ignore
return Response([{"suggestion": name} for name in qs])
# Fallback suggestions
@@ -1862,7 +1834,7 @@ class RideImageSettingsAPIView(APIView):
try:
return Ride.objects.get(pk=pk) # type: ignore
except Ride.DoesNotExist: # type: ignore
raise NotFound("Ride not found")
raise NotFound("Ride not found") from None
def patch(self, request: Request, pk: int) -> Response:
"""Set banner and card images for the ride."""
@@ -1878,9 +1850,7 @@ class RideImageSettingsAPIView(APIView):
ride.save()
# Return updated ride data
output_serializer = RideDetailOutputSerializer(
ride, context={"request": request}
)
output_serializer = RideDetailOutputSerializer(ride, context={"request": request})
return Response(output_serializer.data)
@@ -1902,12 +1872,8 @@ class RideImageSettingsAPIView(APIView):
OpenApiTypes.STR,
description="Filter by ride status (comma-separated for multiple)",
),
OpenApiParameter(
"park_slug", OpenApiTypes.STR, description="Filter by park slug"
),
OpenApiParameter(
"park_id", OpenApiTypes.INT, description="Filter by park ID"
),
OpenApiParameter("park_slug", OpenApiTypes.STR, description="Filter by park slug"),
OpenApiParameter("park_id", OpenApiTypes.INT, description="Filter by park ID"),
OpenApiParameter(
"manufacturer",
OpenApiTypes.STR,
@@ -1923,18 +1889,10 @@ class RideImageSettingsAPIView(APIView):
OpenApiTypes.STR,
description="Filter by ride model slug (comma-separated for multiple)",
),
OpenApiParameter(
"opening_year_min", OpenApiTypes.INT, description="Minimum opening year"
),
OpenApiParameter(
"opening_year_max", OpenApiTypes.INT, description="Maximum opening year"
),
OpenApiParameter(
"rating_min", OpenApiTypes.NUMBER, description="Minimum average rating"
),
OpenApiParameter(
"rating_max", OpenApiTypes.NUMBER, description="Maximum average rating"
),
OpenApiParameter("opening_year_min", OpenApiTypes.INT, description="Minimum opening year"),
OpenApiParameter("opening_year_max", OpenApiTypes.INT, description="Maximum opening year"),
OpenApiParameter("rating_min", OpenApiTypes.NUMBER, description="Minimum average rating"),
OpenApiParameter("rating_max", OpenApiTypes.NUMBER, description="Maximum average rating"),
OpenApiParameter(
"height_requirement_min",
OpenApiTypes.INT,
@@ -1945,12 +1903,8 @@ class RideImageSettingsAPIView(APIView):
OpenApiTypes.INT,
description="Maximum height requirement in inches",
),
OpenApiParameter(
"capacity_min", OpenApiTypes.INT, description="Minimum hourly capacity"
),
OpenApiParameter(
"capacity_max", OpenApiTypes.INT, description="Maximum hourly capacity"
),
OpenApiParameter("capacity_min", OpenApiTypes.INT, description="Minimum hourly capacity"),
OpenApiParameter("capacity_max", OpenApiTypes.INT, description="Maximum hourly capacity"),
OpenApiParameter(
"roller_coaster_type",
OpenApiTypes.STR,
@@ -2022,9 +1976,7 @@ class RideImageSettingsAPIView(APIView):
"properties": {
"rides": {
"type": "array",
"items": {
"$ref": "#/components/schemas/HybridRideSerializer"
},
"items": {"$ref": "#/components/schemas/HybridRideSerializer"},
},
"total_count": {"type": "integer"},
"strategy": {
@@ -2084,7 +2036,7 @@ class HybridRideAPIView(APIView):
data = smart_ride_loader.get_progressive_load(offset, filters)
except ValueError:
return Response(
{"error": "Invalid offset parameter"},
{"detail": "Invalid offset parameter"},
status=status.HTTP_400_BAD_REQUEST,
)
else:
@@ -2109,7 +2061,7 @@ class HybridRideAPIView(APIView):
except Exception as e:
logger.error(f"Error in HybridRideAPIView: {e}")
return Response(
{"error": "Internal server error"},
{"detail": "Internal server error"},
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
@@ -2158,7 +2110,7 @@ class HybridRideAPIView(APIView):
for param in int_params:
value = query_params.get(param)
if value:
try:
try: # noqa: SIM105
filters[param] = int(value)
except ValueError:
pass # Skip invalid integer values
@@ -2175,7 +2127,7 @@ class HybridRideAPIView(APIView):
for param in float_params:
value = query_params.get(param)
if value:
try:
try: # noqa: SIM105
filters[param] = float(value)
except ValueError:
pass # Skip invalid float values
@@ -2408,7 +2360,7 @@ class RideFilterMetadataAPIView(APIView):
except Exception as e:
logger.error(f"Error in RideFilterMetadataAPIView: {e}")
return Response(
{"error": "Internal server error"},
{"detail": "Internal server error"},
status=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
@@ -2417,18 +2369,18 @@ class RideFilterMetadataAPIView(APIView):
# Reuse the same filter extraction logic
view = HybridRideAPIView()
return view._extract_filters(query_params)
# === MANUFACTURER & DESIGNER LISTS ===
class BaseCompanyListAPIView(APIView):
permission_classes = [permissions.AllowAny]
role = None
def get(self, request: Request) -> Response:
if not MODELS_AVAILABLE:
return Response(
{"detail": "Models not available"},
status=status.HTTP_501_NOT_IMPLEMENTED
)
return Response({"detail": "Models not available"}, status=status.HTTP_501_NOT_IMPLEMENTED)
companies = (
Company.objects.filter(roles__contains=[self.role])
@@ -2448,10 +2400,8 @@ class BaseCompanyListAPIView(APIView):
for c in companies
]
return Response({
"results": data,
"count": len(data)
})
return Response({"results": data, "count": len(data)})
@extend_schema(
summary="List manufacturers",
@@ -2462,6 +2412,7 @@ class BaseCompanyListAPIView(APIView):
class ManufacturerListAPIView(BaseCompanyListAPIView):
role = "MANUFACTURER"
@extend_schema(
summary="List designers",
description="List all companies with DESIGNER role.",

View File

@@ -49,5 +49,4 @@ __all__ = (
"UserProfileCreateInputSerializer",
"UserProfileUpdateInputSerializer",
"UserProfileOutputSerializer",
)

View File

@@ -90,7 +90,6 @@ _ACCOUNTS_SYMBOLS: list[str] = [
"UserProfileOutputSerializer",
"UserProfileCreateInputSerializer",
"UserProfileUpdateInputSerializer",
"UserOutputSerializer",
"LoginInputSerializer",
"LoginOutputSerializer",

View File

@@ -187,6 +187,7 @@ class PublicUserSerializer(serializers.ModelSerializer):
Public user serializer for viewing other users' profiles.
Only exposes public information.
"""
profile = UserProfileSerializer(read_only=True)
class Meta:
@@ -228,37 +229,21 @@ class UserPreferencesSerializer(serializers.Serializer):
"""Serializer for user preferences and settings."""
theme_preference = RichChoiceFieldSerializer(
choice_group="theme_preferences",
domain="accounts",
help_text="User's theme preference"
)
email_notifications = serializers.BooleanField(
default=True, help_text="Whether to receive email notifications"
)
push_notifications = serializers.BooleanField(
default=False, help_text="Whether to receive push notifications"
choice_group="theme_preferences", domain="accounts", help_text="User's theme preference"
)
email_notifications = serializers.BooleanField(default=True, help_text="Whether to receive email notifications")
push_notifications = serializers.BooleanField(default=False, help_text="Whether to receive push notifications")
privacy_level = RichChoiceFieldSerializer(
choice_group="privacy_levels",
domain="accounts",
default="public",
help_text="Profile visibility level",
)
show_email = serializers.BooleanField(
default=False, help_text="Whether to show email on profile"
)
show_real_name = serializers.BooleanField(
default=True, help_text="Whether to show real name on profile"
)
show_statistics = serializers.BooleanField(
default=True, help_text="Whether to show ride statistics on profile"
)
allow_friend_requests = serializers.BooleanField(
default=True, help_text="Whether to allow friend requests"
)
allow_messages = serializers.BooleanField(
default=True, help_text="Whether to allow direct messages"
)
show_email = serializers.BooleanField(default=False, help_text="Whether to show email on profile")
show_real_name = serializers.BooleanField(default=True, help_text="Whether to show real name on profile")
show_statistics = serializers.BooleanField(default=True, help_text="Whether to show ride statistics on profile")
allow_friend_requests = serializers.BooleanField(default=True, help_text="Whether to allow friend requests")
allow_messages = serializers.BooleanField(default=True, help_text="Whether to allow direct messages")
# === NOTIFICATION SETTINGS SERIALIZERS ===
@@ -363,39 +348,17 @@ class PrivacySettingsSerializer(serializers.Serializer):
default="public",
help_text="Overall profile visibility",
)
show_email = serializers.BooleanField(
default=False, help_text="Show email address on profile"
)
show_real_name = serializers.BooleanField(
default=True, help_text="Show real name on profile"
)
show_join_date = serializers.BooleanField(
default=True, help_text="Show join date on profile"
)
show_statistics = serializers.BooleanField(
default=True, help_text="Show ride statistics on profile"
)
show_reviews = serializers.BooleanField(
default=True, help_text="Show reviews on profile"
)
show_photos = serializers.BooleanField(
default=True, help_text="Show uploaded photos on profile"
)
show_top_lists = serializers.BooleanField(
default=True, help_text="Show top lists on profile"
)
allow_friend_requests = serializers.BooleanField(
default=True, help_text="Allow others to send friend requests"
)
allow_messages = serializers.BooleanField(
default=True, help_text="Allow others to send direct messages"
)
allow_profile_comments = serializers.BooleanField(
default=False, help_text="Allow others to comment on profile"
)
search_visibility = serializers.BooleanField(
default=True, help_text="Allow profile to appear in search results"
)
show_email = serializers.BooleanField(default=False, help_text="Show email address on profile")
show_real_name = serializers.BooleanField(default=True, help_text="Show real name on profile")
show_join_date = serializers.BooleanField(default=True, help_text="Show join date on profile")
show_statistics = serializers.BooleanField(default=True, help_text="Show ride statistics on profile")
show_reviews = serializers.BooleanField(default=True, help_text="Show reviews on profile")
show_photos = serializers.BooleanField(default=True, help_text="Show uploaded photos on profile")
show_top_lists = serializers.BooleanField(default=True, help_text="Show top lists on profile")
allow_friend_requests = serializers.BooleanField(default=True, help_text="Allow others to send friend requests")
allow_messages = serializers.BooleanField(default=True, help_text="Allow others to send direct messages")
allow_profile_comments = serializers.BooleanField(default=False, help_text="Allow others to comment on profile")
search_visibility = serializers.BooleanField(default=True, help_text="Allow profile to appear in search results")
activity_visibility = RichChoiceFieldSerializer(
choice_group="privacy_levels",
domain="accounts",
@@ -431,21 +394,13 @@ class SecuritySettingsSerializer(serializers.Serializer):
two_factor_enabled = serializers.BooleanField(
default=False, help_text="Whether two-factor authentication is enabled"
)
login_notifications = serializers.BooleanField(
default=True, help_text="Send notifications for new logins"
)
login_notifications = serializers.BooleanField(default=True, help_text="Send notifications for new logins")
session_timeout = serializers.IntegerField(
default=30, min_value=5, max_value=180, help_text="Session timeout in days"
)
require_password_change = serializers.BooleanField(
default=False, help_text="Whether password change is required"
)
last_password_change = serializers.DateTimeField(
read_only=True, help_text="When password was last changed"
)
active_sessions = serializers.IntegerField(
read_only=True, help_text="Number of active sessions"
)
require_password_change = serializers.BooleanField(default=False, help_text="Whether password change is required")
last_password_change = serializers.DateTimeField(read_only=True, help_text="When password was last changed")
active_sessions = serializers.IntegerField(read_only=True, help_text="Number of active sessions")
login_history_retention = serializers.IntegerField(
default=90,
min_value=30,
@@ -699,7 +654,7 @@ class ThemePreferenceSerializer(serializers.ModelSerializer):
"id": 1,
"notification_type": "submission_approved",
"title": "Your submission has been approved!",
"message": "Your photo submission for Cedar Point has been approved and is now live on the site.",
"detail": "Your photo submission for Cedar Point has been approved and is now live on the site.",
"priority": "normal",
"is_read": False,
"read_at": None,
@@ -866,15 +821,11 @@ class MarkNotificationsReadSerializer(serializers.Serializer):
def validate_notification_ids(self, value):
"""Validate that all notification IDs belong to the requesting user."""
user = self.context["request"].user
valid_ids = UserNotification.objects.filter(
id__in=value, user=user
).values_list("id", flat=True)
valid_ids = UserNotification.objects.filter(id__in=value, user=user).values_list("id", flat=True)
invalid_ids = set(value) - set(valid_ids)
if invalid_ids:
raise serializers.ValidationError(
f"Invalid notification IDs: {list(invalid_ids)}"
)
raise serializers.ValidationError(f"Invalid notification IDs: {list(invalid_ids)}")
return value
@@ -901,9 +852,8 @@ class AvatarUploadSerializer(serializers.Serializer):
raise serializers.ValidationError("No file provided")
# Check file size constraints (max 10MB for Cloudflare Images)
if hasattr(value, 'size') and value.size > 10 * 1024 * 1024:
raise serializers.ValidationError(
"Image file too large. Maximum size is 10MB.")
if hasattr(value, "size") and value.size > 10 * 1024 * 1024:
raise serializers.ValidationError("Image file too large. Maximum size is 10MB.")
# Try to validate with PIL
try:
@@ -926,13 +876,13 @@ class AvatarUploadSerializer(serializers.Serializer):
# Check image dimensions (max 12,000x12,000 for Cloudflare Images)
if image.size[0] > 12000 or image.size[1] > 12000:
raise serializers.ValidationError(
"Image dimensions too large. Maximum is 12,000x12,000 pixels.")
raise serializers.ValidationError("Image dimensions too large. Maximum is 12,000x12,000 pixels.")
# Check if it's a supported format
if image.format not in ['JPEG', 'PNG', 'GIF', 'WEBP']:
if image.format not in ["JPEG", "PNG", "GIF", "WEBP"]:
raise serializers.ValidationError(
f"Unsupported image format: {image.format}. Supported formats: JPEG, PNG, GIF, WebP.")
f"Unsupported image format: {image.format}. Supported formats: JPEG, PNG, GIF, WebP."
)
except serializers.ValidationError:
raise # Re-raise validation errors

View File

@@ -97,7 +97,7 @@ class LoginInputSerializer(serializers.Serializer):
password=password,
)
if not user:
if not user: # noqa: SIM102
# Try email-based authentication if username failed
if "@" in username:
try:
@@ -138,7 +138,7 @@ class LoginInputSerializer(serializers.Serializer):
"first_name": "John",
"last_name": "Doe",
},
"message": "Login successful",
"detail": "Login successful",
},
)
]
@@ -213,7 +213,7 @@ class SignupInputSerializer(serializers.ModelSerializer):
try:
validate_password(value)
except DjangoValidationError as e:
raise serializers.ValidationError(list(e.messages))
raise serializers.ValidationError(list(e.messages)) from None
return value
def validate(self, attrs):
@@ -253,7 +253,7 @@ class SignupInputSerializer(serializers.ModelSerializer):
"first_name": "Jane",
"last_name": "Smith",
},
"message": "Registration successful",
"detail": "Registration successful",
},
)
]
@@ -276,7 +276,7 @@ class SignupOutputSerializer(serializers.Serializer):
summary="Example logout response",
description="Successful logout response",
value={
"message": "Logout successful",
"detail": "Logout successful",
},
)
]
@@ -318,9 +318,9 @@ class PasswordResetInputSerializer(serializers.Serializer):
"""Send password reset email."""
email = self.validated_data["email"] # type: ignore[index]
try:
_user = UserModel.objects.get(email=email)
# Check if email exists (but don't reveal the result for security)
UserModel.objects.get(email=email)
# Here you would typically send a password reset email
# For now, we'll just pass
pass
except UserModel.DoesNotExist:
# Don't reveal if email exists for security
@@ -393,7 +393,7 @@ class PasswordChangeInputSerializer(serializers.Serializer):
try:
validate_password(value, user=self.context["request"].user)
except DjangoValidationError as e:
raise serializers.ValidationError(list(e.messages))
raise serializers.ValidationError(list(e.messages)) from None
return value
def validate(self, attrs):
@@ -492,6 +492,4 @@ class AuthStatusOutputSerializer(serializers.Serializer):
"""Output serializer for authentication status."""
authenticated = serializers.BooleanField(help_text="Whether user is authenticated")
user = UserOutputSerializer(
allow_null=True, help_text="User information if authenticated"
)
user = UserOutputSerializer(allow_null=True, help_text="User information if authenticated")

View File

@@ -112,10 +112,7 @@ class RideModelDetailOutputSerializer(serializers.Serializer):
id = serializers.IntegerField()
name = serializers.CharField()
description = serializers.CharField()
category = RichChoiceFieldSerializer(
choice_group="categories",
domain="rides"
)
category = RichChoiceFieldSerializer(choice_group="categories", domain="rides")
# Manufacturer info
manufacturer = serializers.SerializerMethodField()

View File

@@ -99,9 +99,7 @@ class ParkHistoryOutputSerializer(serializers.Serializer):
"slug": park.slug,
"status": park.status,
"opening_date": (
park.opening_date.isoformat()
if hasattr(park, "opening_date") and park.opening_date
else None
park.opening_date.isoformat() if hasattr(park, "opening_date") and park.opening_date else None
),
"coaster_count": getattr(park, "coaster_count", 0),
"ride_count": getattr(park, "ride_count", 0),
@@ -143,9 +141,7 @@ class RideHistoryOutputSerializer(serializers.Serializer):
"park_name": ride.park.name if hasattr(ride, "park") else None,
"status": getattr(ride, "status", "UNKNOWN"),
"opening_date": (
ride.opening_date.isoformat()
if hasattr(ride, "opening_date") and ride.opening_date
else None
ride.opening_date.isoformat() if hasattr(ride, "opening_date") and ride.opening_date else None
),
"ride_type": getattr(ride, "ride_type", "Unknown"),
}

View File

@@ -79,16 +79,12 @@ class MapLocationSerializer(serializers.Serializer):
return {
"coaster_count": obj.coaster_count or 0,
"ride_count": obj.ride_count or 0,
"average_rating": (
float(obj.average_rating) if obj.average_rating else None
),
"average_rating": (float(obj.average_rating) if obj.average_rating else None),
}
elif obj._meta.model_name == "ride":
return {
"category": obj.get_category_display() if obj.category else None,
"average_rating": (
float(obj.average_rating) if obj.average_rating else None
),
"average_rating": (float(obj.average_rating) if obj.average_rating else None),
"park_name": obj.park.name if obj.park else None,
}
return {}
@@ -339,24 +335,16 @@ class MapLocationDetailSerializer(serializers.Serializer):
return {
"coaster_count": obj.coaster_count or 0,
"ride_count": obj.ride_count or 0,
"average_rating": (
float(obj.average_rating) if obj.average_rating else None
),
"average_rating": (float(obj.average_rating) if obj.average_rating else None),
"size_acres": float(obj.size_acres) if obj.size_acres else None,
"opening_date": (
obj.opening_date.isoformat() if obj.opening_date else None
),
"opening_date": (obj.opening_date.isoformat() if obj.opening_date else None),
}
elif obj._meta.model_name == "ride":
return {
"category": obj.get_category_display() if obj.category else None,
"average_rating": (
float(obj.average_rating) if obj.average_rating else None
),
"average_rating": (float(obj.average_rating) if obj.average_rating else None),
"park_name": obj.park.name if obj.park else None,
"opening_date": (
obj.opening_date.isoformat() if obj.opening_date else None
),
"opening_date": (obj.opening_date.isoformat() if obj.opening_date else None),
"manufacturer": obj.manufacturer.name if obj.manufacturer else None,
}
return {}
@@ -382,9 +370,7 @@ class MapBoundsInputSerializer(serializers.Serializer):
def validate(self, attrs):
"""Validate that bounds make geographic sense."""
if attrs["north"] <= attrs["south"]:
raise serializers.ValidationError(
"North bound must be greater than south bound"
)
raise serializers.ValidationError("North bound must be greater than south bound")
# Handle longitude wraparound (e.g., crossing the international date line)
# For now, we'll require west < east for simplicity

View File

@@ -31,9 +31,7 @@ class PhotoUploadInputSerializer(serializers.Serializer):
allow_blank=True,
help_text="Alt text for accessibility",
)
is_primary = serializers.BooleanField(
default=False, help_text="Whether this should be the primary photo"
)
is_primary = serializers.BooleanField(default=False, help_text="Whether this should be the primary photo")
@extend_schema_serializer(
@@ -89,9 +87,7 @@ class PhotoDetailOutputSerializer(serializers.Serializer):
return {
"id": obj.uploaded_by.id,
"username": obj.uploaded_by.username,
"display_name": getattr(
obj.uploaded_by, "get_display_name", lambda: obj.uploaded_by.username
)(),
"display_name": getattr(obj.uploaded_by, "get_display_name", lambda: obj.uploaded_by.username)(),
}

View File

@@ -24,12 +24,8 @@ class ParkStatsOutputSerializer(serializers.Serializer):
under_construction = serializers.IntegerField()
# Averages
average_rating = serializers.DecimalField(
max_digits=3, decimal_places=2, allow_null=True
)
average_coaster_count = serializers.DecimalField(
max_digits=5, decimal_places=2, allow_null=True
)
average_rating = serializers.DecimalField(max_digits=3, decimal_places=2, allow_null=True)
average_coaster_count = serializers.DecimalField(max_digits=5, decimal_places=2, allow_null=True)
# Top countries
top_countries = serializers.ListField(child=serializers.DictField())
@@ -50,12 +46,8 @@ class RideStatsOutputSerializer(serializers.Serializer):
rides_by_category = serializers.DictField()
# Averages
average_rating = serializers.DecimalField(
max_digits=3, decimal_places=2, allow_null=True
)
average_capacity = serializers.DecimalField(
max_digits=8, decimal_places=2, allow_null=True
)
average_rating = serializers.DecimalField(max_digits=3, decimal_places=2, allow_null=True)
average_capacity = serializers.DecimalField(max_digits=8, decimal_places=2, allow_null=True)
# Top manufacturers
top_manufacturers = serializers.ListField(child=serializers.DictField())
@@ -91,10 +83,7 @@ class ParkReviewOutputSerializer(serializers.Serializer):
class HealthCheckOutputSerializer(serializers.Serializer):
"""Output serializer for health check responses."""
status = RichChoiceFieldSerializer(
choice_group="health_statuses",
domain="core"
)
status = RichChoiceFieldSerializer(choice_group="health_statuses", domain="core")
timestamp = serializers.DateTimeField()
version = serializers.CharField()
environment = serializers.CharField()
@@ -115,9 +104,6 @@ class PerformanceMetricsOutputSerializer(serializers.Serializer):
class SimpleHealthOutputSerializer(serializers.Serializer):
"""Output serializer for simple health check."""
status = RichChoiceFieldSerializer(
choice_group="simple_health_statuses",
domain="core"
)
status = RichChoiceFieldSerializer(choice_group="simple_health_statuses", domain="core")
timestamp = serializers.DateTimeField()
error = serializers.CharField(required=False)

View File

@@ -29,14 +29,10 @@ from apps.parks.models.reviews import ParkReview
"user": {
"username": "park_fan",
"display_name": "Park Fan",
"avatar_url": "https://example.com/avatar.jpg"
"avatar_url": "https://example.com/avatar.jpg",
},
"park": {
"id": 101,
"name": "Cedar Point",
"slug": "cedar-point"
}
}
"park": {"id": 101, "name": "Cedar Point", "slug": "cedar-point"},
},
)
]
)
@@ -145,8 +141,7 @@ class ParkReviewStatsOutputSerializer(serializers.Serializer):
pending_reviews = serializers.IntegerField()
average_rating = serializers.FloatField(allow_null=True)
rating_distribution = serializers.DictField(
child=serializers.IntegerField(),
help_text="Count of reviews by rating (1-10)"
child=serializers.IntegerField(), help_text="Count of reviews by rating (1-10)"
)
recent_reviews = serializers.IntegerField()
@@ -154,20 +149,15 @@ class ParkReviewStatsOutputSerializer(serializers.Serializer):
class ParkReviewModerationInputSerializer(serializers.Serializer):
"""Input serializer for review moderation operations."""
review_ids = serializers.ListField(
child=serializers.IntegerField(),
help_text="List of review IDs to moderate"
)
review_ids = serializers.ListField(child=serializers.IntegerField(), help_text="List of review IDs to moderate")
action = serializers.ChoiceField(
choices=[
("publish", "Publish"),
("unpublish", "Unpublish"),
("delete", "Delete"),
],
help_text="Moderation action to perform"
help_text="Moderation action to perform",
)
moderation_notes = serializers.CharField(
required=False,
allow_blank=True,
help_text="Optional notes about the moderation action"
required=False, allow_blank=True, help_text="Optional notes about the moderation action"
)

View File

@@ -52,16 +52,11 @@ class ParkListOutputSerializer(serializers.Serializer):
id = serializers.IntegerField()
name = serializers.CharField()
slug = serializers.CharField()
status = RichChoiceFieldSerializer(
choice_group="statuses",
domain="parks"
)
status = RichChoiceFieldSerializer(choice_group="statuses", domain="parks")
description = serializers.CharField()
# Statistics
average_rating = serializers.DecimalField(
max_digits=3, decimal_places=2, allow_null=True
)
average_rating = serializers.DecimalField(max_digits=3, decimal_places=2, allow_null=True)
coaster_count = serializers.IntegerField(allow_null=True)
ride_count = serializers.IntegerField(allow_null=True)
@@ -145,25 +140,18 @@ class ParkDetailOutputSerializer(serializers.Serializer):
id = serializers.IntegerField()
name = serializers.CharField()
slug = serializers.CharField()
status = RichChoiceFieldSerializer(
choice_group="statuses",
domain="parks"
)
status = RichChoiceFieldSerializer(choice_group="statuses", domain="parks")
description = serializers.CharField()
# Details
opening_date = serializers.DateField(allow_null=True)
closing_date = serializers.DateField(allow_null=True)
operating_season = serializers.CharField()
size_acres = serializers.DecimalField(
max_digits=10, decimal_places=2, allow_null=True
)
size_acres = serializers.DecimalField(max_digits=10, decimal_places=2, allow_null=True)
website = serializers.URLField()
# Statistics
average_rating = serializers.DecimalField(
max_digits=3, decimal_places=2, allow_null=True
)
average_rating = serializers.DecimalField(max_digits=3, decimal_places=2, allow_null=True)
coaster_count = serializers.IntegerField(allow_null=True)
ride_count = serializers.IntegerField(allow_null=True)
@@ -211,9 +199,7 @@ class ParkDetailOutputSerializer(serializers.Serializer):
"""Get all approved photos for this park."""
from apps.parks.models import ParkPhoto
photos = ParkPhoto.objects.filter(park=obj, is_approved=True).order_by(
"-is_primary", "-created_at"
)[
photos = ParkPhoto.objects.filter(park=obj, is_approved=True).order_by("-is_primary", "-created_at")[
:10
] # Limit to 10 photos
@@ -228,7 +214,9 @@ class ParkDetailOutputSerializer(serializers.Serializer):
"public": MediaURLService.get_cloudflare_url_with_fallback(photo.image, "public"),
},
"friendly_urls": {
"thumbnail": MediaURLService.generate_park_photo_url(obj.slug, photo.caption, photo.pk, "thumbnail"),
"thumbnail": MediaURLService.generate_park_photo_url(
obj.slug, photo.caption, photo.pk, "thumbnail"
),
"medium": MediaURLService.generate_park_photo_url(obj.slug, photo.caption, photo.pk, "medium"),
"large": MediaURLService.generate_park_photo_url(obj.slug, photo.caption, photo.pk, "large"),
"public": MediaURLService.generate_park_photo_url(obj.slug, photo.caption, photo.pk, "public"),
@@ -246,9 +234,7 @@ class ParkDetailOutputSerializer(serializers.Serializer):
from apps.parks.models import ParkPhoto
try:
photo = ParkPhoto.objects.filter(
park=obj, is_primary=True, is_approved=True
).first()
photo = ParkPhoto.objects.filter(park=obj, is_primary=True, is_approved=True).first()
if photo and photo.image:
return {
@@ -261,7 +247,9 @@ class ParkDetailOutputSerializer(serializers.Serializer):
"public": MediaURLService.get_cloudflare_url_with_fallback(photo.image, "public"),
},
"friendly_urls": {
"thumbnail": MediaURLService.generate_park_photo_url(obj.slug, photo.caption, photo.pk, "thumbnail"),
"thumbnail": MediaURLService.generate_park_photo_url(
obj.slug, photo.caption, photo.pk, "thumbnail"
),
"medium": MediaURLService.generate_park_photo_url(obj.slug, photo.caption, photo.pk, "medium"),
"large": MediaURLService.generate_park_photo_url(obj.slug, photo.caption, photo.pk, "large"),
"public": MediaURLService.generate_park_photo_url(obj.slug, photo.caption, photo.pk, "public"),
@@ -289,10 +277,18 @@ class ParkDetailOutputSerializer(serializers.Serializer):
"public": MediaURLService.get_cloudflare_url_with_fallback(obj.banner_image.image, "public"),
},
"friendly_urls": {
"thumbnail": MediaURLService.generate_park_photo_url(obj.slug, obj.banner_image.caption, obj.banner_image.pk, "thumbnail"),
"medium": MediaURLService.generate_park_photo_url(obj.slug, obj.banner_image.caption, obj.banner_image.pk, "medium"),
"large": MediaURLService.generate_park_photo_url(obj.slug, obj.banner_image.caption, obj.banner_image.pk, "large"),
"public": MediaURLService.generate_park_photo_url(obj.slug, obj.banner_image.caption, obj.banner_image.pk, "public"),
"thumbnail": MediaURLService.generate_park_photo_url(
obj.slug, obj.banner_image.caption, obj.banner_image.pk, "thumbnail"
),
"medium": MediaURLService.generate_park_photo_url(
obj.slug, obj.banner_image.caption, obj.banner_image.pk, "medium"
),
"large": MediaURLService.generate_park_photo_url(
obj.slug, obj.banner_image.caption, obj.banner_image.pk, "large"
),
"public": MediaURLService.generate_park_photo_url(
obj.slug, obj.banner_image.caption, obj.banner_image.pk, "public"
),
},
"caption": obj.banner_image.caption,
"alt_text": obj.banner_image.alt_text,
@@ -303,9 +299,7 @@ class ParkDetailOutputSerializer(serializers.Serializer):
try:
latest_photo = (
ParkPhoto.objects.filter(
park=obj, is_approved=True, image__isnull=False
)
ParkPhoto.objects.filter(park=obj, is_approved=True, image__isnull=False)
.order_by("-created_at")
.first()
)
@@ -321,10 +315,18 @@ class ParkDetailOutputSerializer(serializers.Serializer):
"public": MediaURLService.get_cloudflare_url_with_fallback(latest_photo.image, "public"),
},
"friendly_urls": {
"thumbnail": MediaURLService.generate_park_photo_url(obj.slug, latest_photo.caption, latest_photo.pk, "thumbnail"),
"medium": MediaURLService.generate_park_photo_url(obj.slug, latest_photo.caption, latest_photo.pk, "medium"),
"large": MediaURLService.generate_park_photo_url(obj.slug, latest_photo.caption, latest_photo.pk, "large"),
"public": MediaURLService.generate_park_photo_url(obj.slug, latest_photo.caption, latest_photo.pk, "public"),
"thumbnail": MediaURLService.generate_park_photo_url(
obj.slug, latest_photo.caption, latest_photo.pk, "thumbnail"
),
"medium": MediaURLService.generate_park_photo_url(
obj.slug, latest_photo.caption, latest_photo.pk, "medium"
),
"large": MediaURLService.generate_park_photo_url(
obj.slug, latest_photo.caption, latest_photo.pk, "large"
),
"public": MediaURLService.generate_park_photo_url(
obj.slug, latest_photo.caption, latest_photo.pk, "public"
),
},
"caption": latest_photo.caption,
"alt_text": latest_photo.alt_text,
@@ -350,10 +352,18 @@ class ParkDetailOutputSerializer(serializers.Serializer):
"public": MediaURLService.get_cloudflare_url_with_fallback(obj.card_image.image, "public"),
},
"friendly_urls": {
"thumbnail": MediaURLService.generate_park_photo_url(obj.slug, obj.card_image.caption, obj.card_image.pk, "thumbnail"),
"medium": MediaURLService.generate_park_photo_url(obj.slug, obj.card_image.caption, obj.card_image.pk, "medium"),
"large": MediaURLService.generate_park_photo_url(obj.slug, obj.card_image.caption, obj.card_image.pk, "large"),
"public": MediaURLService.generate_park_photo_url(obj.slug, obj.card_image.caption, obj.card_image.pk, "public"),
"thumbnail": MediaURLService.generate_park_photo_url(
obj.slug, obj.card_image.caption, obj.card_image.pk, "thumbnail"
),
"medium": MediaURLService.generate_park_photo_url(
obj.slug, obj.card_image.caption, obj.card_image.pk, "medium"
),
"large": MediaURLService.generate_park_photo_url(
obj.slug, obj.card_image.caption, obj.card_image.pk, "large"
),
"public": MediaURLService.generate_park_photo_url(
obj.slug, obj.card_image.caption, obj.card_image.pk, "public"
),
},
"caption": obj.card_image.caption,
"alt_text": obj.card_image.alt_text,
@@ -364,9 +374,7 @@ class ParkDetailOutputSerializer(serializers.Serializer):
try:
latest_photo = (
ParkPhoto.objects.filter(
park=obj, is_approved=True, image__isnull=False
)
ParkPhoto.objects.filter(park=obj, is_approved=True, image__isnull=False)
.order_by("-created_at")
.first()
)
@@ -382,10 +390,18 @@ class ParkDetailOutputSerializer(serializers.Serializer):
"public": MediaURLService.get_cloudflare_url_with_fallback(latest_photo.image, "public"),
},
"friendly_urls": {
"thumbnail": MediaURLService.generate_park_photo_url(obj.slug, latest_photo.caption, latest_photo.pk, "thumbnail"),
"medium": MediaURLService.generate_park_photo_url(obj.slug, latest_photo.caption, latest_photo.pk, "medium"),
"large": MediaURLService.generate_park_photo_url(obj.slug, latest_photo.caption, latest_photo.pk, "large"),
"public": MediaURLService.generate_park_photo_url(obj.slug, latest_photo.caption, latest_photo.pk, "public"),
"thumbnail": MediaURLService.generate_park_photo_url(
obj.slug, latest_photo.caption, latest_photo.pk, "thumbnail"
),
"medium": MediaURLService.generate_park_photo_url(
obj.slug, latest_photo.caption, latest_photo.pk, "medium"
),
"large": MediaURLService.generate_park_photo_url(
obj.slug, latest_photo.caption, latest_photo.pk, "large"
),
"public": MediaURLService.generate_park_photo_url(
obj.slug, latest_photo.caption, latest_photo.pk, "public"
),
},
"caption": latest_photo.caption,
"alt_text": latest_photo.alt_text,
@@ -417,7 +433,7 @@ class ParkImageSettingsInputSerializer(serializers.Serializer):
# The park will be validated in the view
return value
except ParkPhoto.DoesNotExist:
raise serializers.ValidationError("Photo not found")
raise serializers.ValidationError("Photo not found") from None
return value
def validate_card_image_id(self, value):
@@ -430,7 +446,7 @@ class ParkImageSettingsInputSerializer(serializers.Serializer):
# The park will be validated in the view
return value
except ParkPhoto.DoesNotExist:
raise serializers.ValidationError("Photo not found")
raise serializers.ValidationError("Photo not found") from None
return value
@@ -439,19 +455,13 @@ class ParkCreateInputSerializer(serializers.Serializer):
name = serializers.CharField(max_length=255)
description = serializers.CharField(allow_blank=True, default="")
status = serializers.ChoiceField(
choices=ModelChoices.get_park_status_choices(), default="OPERATING"
)
status = serializers.ChoiceField(choices=ModelChoices.get_park_status_choices(), default="OPERATING")
# Optional details
opening_date = serializers.DateField(required=False, allow_null=True)
closing_date = serializers.DateField(required=False, allow_null=True)
operating_season = serializers.CharField(
max_length=255, required=False, allow_blank=True
)
size_acres = serializers.DecimalField(
max_digits=10, decimal_places=2, required=False, allow_null=True
)
operating_season = serializers.CharField(max_length=255, required=False, allow_blank=True)
size_acres = serializers.DecimalField(max_digits=10, decimal_places=2, required=False, allow_null=True)
website = serializers.URLField(required=False, allow_blank=True)
# Required operator
@@ -466,9 +476,7 @@ class ParkCreateInputSerializer(serializers.Serializer):
closing_date = attrs.get("closing_date")
if opening_date and closing_date and closing_date < opening_date:
raise serializers.ValidationError(
"Closing date cannot be before opening date"
)
raise serializers.ValidationError("Closing date cannot be before opening date")
return attrs
@@ -478,19 +486,13 @@ class ParkUpdateInputSerializer(serializers.Serializer):
name = serializers.CharField(max_length=255, required=False)
description = serializers.CharField(allow_blank=True, required=False)
status = serializers.ChoiceField(
choices=ModelChoices.get_park_status_choices(), required=False
)
status = serializers.ChoiceField(choices=ModelChoices.get_park_status_choices(), required=False)
# Optional details
opening_date = serializers.DateField(required=False, allow_null=True)
closing_date = serializers.DateField(required=False, allow_null=True)
operating_season = serializers.CharField(
max_length=255, required=False, allow_blank=True
)
size_acres = serializers.DecimalField(
max_digits=10, decimal_places=2, required=False, allow_null=True
)
operating_season = serializers.CharField(max_length=255, required=False, allow_blank=True)
size_acres = serializers.DecimalField(max_digits=10, decimal_places=2, required=False, allow_null=True)
website = serializers.URLField(required=False, allow_blank=True)
# Companies
@@ -503,9 +505,7 @@ class ParkUpdateInputSerializer(serializers.Serializer):
closing_date = attrs.get("closing_date")
if opening_date and closing_date and closing_date < opening_date:
raise serializers.ValidationError(
"Closing date cannot be before opening date"
)
raise serializers.ValidationError("Closing date cannot be before opening date")
return attrs
@@ -537,12 +537,8 @@ class ParkFilterInputSerializer(serializers.Serializer):
)
# Size filter
min_size_acres = serializers.DecimalField(
max_digits=10, decimal_places=2, required=False, min_value=0
)
max_size_acres = serializers.DecimalField(
max_digits=10, decimal_places=2, required=False, min_value=0
)
min_size_acres = serializers.DecimalField(max_digits=10, decimal_places=2, required=False, min_value=0)
max_size_acres = serializers.DecimalField(max_digits=10, decimal_places=2, required=False, min_value=0)
# Company filters
operator_id = serializers.IntegerField(required=False)
@@ -625,9 +621,7 @@ class ParkAreaCreateInputSerializer(serializers.Serializer):
closing_date = attrs.get("closing_date")
if opening_date and closing_date and closing_date < opening_date:
raise serializers.ValidationError(
"Closing date cannot be before opening date"
)
raise serializers.ValidationError("Closing date cannot be before opening date")
return attrs
@@ -646,9 +640,7 @@ class ParkAreaUpdateInputSerializer(serializers.Serializer):
closing_date = attrs.get("closing_date")
if opening_date and closing_date and closing_date < opening_date:
raise serializers.ValidationError(
"Closing date cannot be before opening date"
)
raise serializers.ValidationError("Closing date cannot be before opening date")
return attrs

View File

@@ -12,9 +12,7 @@ from apps.parks.models import ParkPhoto
class ParkPhotoOutputSerializer(serializers.ModelSerializer):
"""Output serializer for park photos."""
uploaded_by_username = serializers.CharField(
source="uploaded_by.username", read_only=True
)
uploaded_by_username = serializers.CharField(source="uploaded_by.username", read_only=True)
file_size = serializers.ReadOnlyField()
dimensions = serializers.ReadOnlyField()
park_slug = serializers.CharField(source="park.slug", read_only=True)
@@ -78,9 +76,7 @@ class ParkPhotoUpdateInputSerializer(serializers.ModelSerializer):
class ParkPhotoListOutputSerializer(serializers.ModelSerializer):
"""Simplified output serializer for park photo lists."""
uploaded_by_username = serializers.CharField(
source="uploaded_by.username", read_only=True
)
uploaded_by_username = serializers.CharField(source="uploaded_by.username", read_only=True)
class Meta:
model = ParkPhoto
@@ -99,12 +95,8 @@ class ParkPhotoListOutputSerializer(serializers.ModelSerializer):
class ParkPhotoApprovalInputSerializer(serializers.Serializer):
"""Input serializer for photo approval operations."""
photo_ids = serializers.ListField(
child=serializers.IntegerField(), help_text="List of photo IDs to approve"
)
approve = serializers.BooleanField(
default=True, help_text="Whether to approve (True) or reject (False) the photos"
)
photo_ids = serializers.ListField(child=serializers.IntegerField(), help_text="List of photo IDs to approve")
approve = serializers.BooleanField(default=True, help_text="Whether to approve (True) or reject (False) the photos")
class ParkPhotoStatsOutputSerializer(serializers.Serializer):

View File

@@ -8,35 +8,33 @@ from apps.rides.models.credits import RideCredit
class RideCreditSerializer(serializers.ModelSerializer):
"""Serializer for user ride credits."""
ride_id = serializers.PrimaryKeyRelatedField(
queryset=Ride.objects.all(), source='ride', write_only=True
)
ride_id = serializers.PrimaryKeyRelatedField(queryset=Ride.objects.all(), source="ride", write_only=True)
ride = RideListOutputSerializer(read_only=True)
class Meta:
model = RideCredit
fields = [
'id',
'ride',
'ride_id',
'count',
'rating',
'first_ridden_at',
'last_ridden_at',
'notes',
'display_order',
'created_at',
'updated_at',
"id",
"ride",
"ride_id",
"count",
"rating",
"first_ridden_at",
"last_ridden_at",
"notes",
"display_order",
"created_at",
"updated_at",
]
read_only_fields = ['id', 'created_at', 'updated_at']
read_only_fields = ["id", "created_at", "updated_at"]
def validate(self, attrs):
"""
Validate data.
"""
# Ensure dates make sense
first = attrs.get('first_ridden_at')
last = attrs.get('last_ridden_at')
first = attrs.get("first_ridden_at")
last = attrs.get("last_ridden_at")
if first and last and last < first:
raise serializers.ValidationError("Last ridden date cannot be before first ridden date.")
@@ -44,6 +42,6 @@ class RideCreditSerializer(serializers.ModelSerializer):
def create(self, validated_data):
"""Create a new ride credit."""
user = self.context['request'].user
validated_data['user'] = user
user = self.context["request"].user
validated_data["user"] = user
return super().create(validated_data)

View File

@@ -80,18 +80,10 @@ class RideModelVariantOutputSerializer(serializers.Serializer):
id = serializers.IntegerField()
name = serializers.CharField()
description = serializers.CharField()
min_height_ft = serializers.DecimalField(
max_digits=6, decimal_places=2, allow_null=True
)
max_height_ft = serializers.DecimalField(
max_digits=6, decimal_places=2, allow_null=True
)
min_speed_mph = serializers.DecimalField(
max_digits=5, decimal_places=2, allow_null=True
)
max_speed_mph = serializers.DecimalField(
max_digits=5, decimal_places=2, allow_null=True
)
min_height_ft = serializers.DecimalField(max_digits=6, decimal_places=2, allow_null=True)
max_height_ft = serializers.DecimalField(max_digits=6, decimal_places=2, allow_null=True)
min_speed_mph = serializers.DecimalField(max_digits=5, decimal_places=2, allow_null=True)
max_speed_mph = serializers.DecimalField(max_digits=5, decimal_places=2, allow_null=True)
distinguishing_features = serializers.CharField()
@@ -134,20 +126,14 @@ class RideModelListOutputSerializer(serializers.Serializer):
id = serializers.IntegerField()
name = serializers.CharField()
slug = serializers.CharField()
category = RichChoiceFieldSerializer(
choice_group="categories",
domain="rides"
)
category = RichChoiceFieldSerializer(choice_group="categories", domain="rides")
description = serializers.CharField()
# Manufacturer info
manufacturer = RideModelManufacturerOutputSerializer(allow_null=True)
# Market info
target_market = RichChoiceFieldSerializer(
choice_group="target_markets",
domain="rides"
)
target_market = RichChoiceFieldSerializer(choice_group="target_markets", domain="rides")
is_discontinued = serializers.BooleanField()
total_installations = serializers.IntegerField()
first_installation_year = serializers.IntegerField(allow_null=True)
@@ -258,18 +244,10 @@ class RideModelDetailOutputSerializer(serializers.Serializer):
manufacturer = RideModelManufacturerOutputSerializer(allow_null=True)
# Technical specifications
typical_height_range_min_ft = serializers.DecimalField(
max_digits=6, decimal_places=2, allow_null=True
)
typical_height_range_max_ft = serializers.DecimalField(
max_digits=6, decimal_places=2, allow_null=True
)
typical_speed_range_min_mph = serializers.DecimalField(
max_digits=5, decimal_places=2, allow_null=True
)
typical_speed_range_max_mph = serializers.DecimalField(
max_digits=5, decimal_places=2, allow_null=True
)
typical_height_range_min_ft = serializers.DecimalField(max_digits=6, decimal_places=2, allow_null=True)
typical_height_range_max_ft = serializers.DecimalField(max_digits=6, decimal_places=2, allow_null=True)
typical_speed_range_min_mph = serializers.DecimalField(max_digits=5, decimal_places=2, allow_null=True)
typical_speed_range_max_mph = serializers.DecimalField(max_digits=5, decimal_places=2, allow_null=True)
typical_capacity_range_min = serializers.IntegerField(allow_null=True)
typical_capacity_range_max = serializers.IntegerField(allow_null=True)
@@ -343,9 +321,7 @@ class RideModelCreateInputSerializer(serializers.Serializer):
name = serializers.CharField(max_length=255)
description = serializers.CharField(allow_blank=True, default="")
category = serializers.ChoiceField(
choices=ModelChoices.get_ride_category_choices(), allow_blank=True, default=""
)
category = serializers.ChoiceField(choices=ModelChoices.get_ride_category_choices(), allow_blank=True, default="")
# Required manufacturer
manufacturer_id = serializers.IntegerField()
@@ -363,32 +339,18 @@ class RideModelCreateInputSerializer(serializers.Serializer):
typical_speed_range_max_mph = serializers.DecimalField(
max_digits=5, decimal_places=2, required=False, allow_null=True
)
typical_capacity_range_min = serializers.IntegerField(
required=False, allow_null=True, min_value=1
)
typical_capacity_range_max = serializers.IntegerField(
required=False, allow_null=True, min_value=1
)
typical_capacity_range_min = serializers.IntegerField(required=False, allow_null=True, min_value=1)
typical_capacity_range_max = serializers.IntegerField(required=False, allow_null=True, min_value=1)
# Design characteristics
track_type = serializers.CharField(max_length=100, allow_blank=True, default="")
support_structure = serializers.CharField(
max_length=100, allow_blank=True, default=""
)
train_configuration = serializers.CharField(
max_length=200, allow_blank=True, default=""
)
restraint_system = serializers.CharField(
max_length=100, allow_blank=True, default=""
)
support_structure = serializers.CharField(max_length=100, allow_blank=True, default="")
train_configuration = serializers.CharField(max_length=200, allow_blank=True, default="")
restraint_system = serializers.CharField(max_length=100, allow_blank=True, default="")
# Market information
first_installation_year = serializers.IntegerField(
required=False, allow_null=True, min_value=1800, max_value=2100
)
last_installation_year = serializers.IntegerField(
required=False, allow_null=True, min_value=1800, max_value=2100
)
first_installation_year = serializers.IntegerField(required=False, allow_null=True, min_value=1800, max_value=2100)
last_installation_year = serializers.IntegerField(required=False, allow_null=True, min_value=1800, max_value=2100)
is_discontinued = serializers.BooleanField(default=False)
# Design features
@@ -406,36 +368,28 @@ class RideModelCreateInputSerializer(serializers.Serializer):
max_height = attrs.get("typical_height_range_max_ft")
if min_height and max_height and min_height > max_height:
raise serializers.ValidationError(
"Minimum height cannot be greater than maximum height"
)
raise serializers.ValidationError("Minimum height cannot be greater than maximum height")
# Speed range validation
min_speed = attrs.get("typical_speed_range_min_mph")
max_speed = attrs.get("typical_speed_range_max_mph")
if min_speed and max_speed and min_speed > max_speed:
raise serializers.ValidationError(
"Minimum speed cannot be greater than maximum speed"
)
raise serializers.ValidationError("Minimum speed cannot be greater than maximum speed")
# Capacity range validation
min_capacity = attrs.get("typical_capacity_range_min")
max_capacity = attrs.get("typical_capacity_range_max")
if min_capacity and max_capacity and min_capacity > max_capacity:
raise serializers.ValidationError(
"Minimum capacity cannot be greater than maximum capacity"
)
raise serializers.ValidationError("Minimum capacity cannot be greater than maximum capacity")
# Installation years validation
first_year = attrs.get("first_installation_year")
last_year = attrs.get("last_installation_year")
if first_year and last_year and first_year > last_year:
raise serializers.ValidationError(
"First installation year cannot be after last installation year"
)
raise serializers.ValidationError("First installation year cannot be after last installation year")
return attrs
@@ -467,32 +421,18 @@ class RideModelUpdateInputSerializer(serializers.Serializer):
typical_speed_range_max_mph = serializers.DecimalField(
max_digits=5, decimal_places=2, required=False, allow_null=True
)
typical_capacity_range_min = serializers.IntegerField(
required=False, allow_null=True, min_value=1
)
typical_capacity_range_max = serializers.IntegerField(
required=False, allow_null=True, min_value=1
)
typical_capacity_range_min = serializers.IntegerField(required=False, allow_null=True, min_value=1)
typical_capacity_range_max = serializers.IntegerField(required=False, allow_null=True, min_value=1)
# Design characteristics
track_type = serializers.CharField(max_length=100, allow_blank=True, required=False)
support_structure = serializers.CharField(
max_length=100, allow_blank=True, required=False
)
train_configuration = serializers.CharField(
max_length=200, allow_blank=True, required=False
)
restraint_system = serializers.CharField(
max_length=100, allow_blank=True, required=False
)
support_structure = serializers.CharField(max_length=100, allow_blank=True, required=False)
train_configuration = serializers.CharField(max_length=200, allow_blank=True, required=False)
restraint_system = serializers.CharField(max_length=100, allow_blank=True, required=False)
# Market information
first_installation_year = serializers.IntegerField(
required=False, allow_null=True, min_value=1800, max_value=2100
)
last_installation_year = serializers.IntegerField(
required=False, allow_null=True, min_value=1800, max_value=2100
)
first_installation_year = serializers.IntegerField(required=False, allow_null=True, min_value=1800, max_value=2100)
last_installation_year = serializers.IntegerField(required=False, allow_null=True, min_value=1800, max_value=2100)
is_discontinued = serializers.BooleanField(required=False)
# Design features
@@ -510,36 +450,28 @@ class RideModelUpdateInputSerializer(serializers.Serializer):
max_height = attrs.get("typical_height_range_max_ft")
if min_height and max_height and min_height > max_height:
raise serializers.ValidationError(
"Minimum height cannot be greater than maximum height"
)
raise serializers.ValidationError("Minimum height cannot be greater than maximum height")
# Speed range validation
min_speed = attrs.get("typical_speed_range_min_mph")
max_speed = attrs.get("typical_speed_range_max_mph")
if min_speed and max_speed and min_speed > max_speed:
raise serializers.ValidationError(
"Minimum speed cannot be greater than maximum speed"
)
raise serializers.ValidationError("Minimum speed cannot be greater than maximum speed")
# Capacity range validation
min_capacity = attrs.get("typical_capacity_range_min")
max_capacity = attrs.get("typical_capacity_range_max")
if min_capacity and max_capacity and min_capacity > max_capacity:
raise serializers.ValidationError(
"Minimum capacity cannot be greater than maximum capacity"
)
raise serializers.ValidationError("Minimum capacity cannot be greater than maximum capacity")
# Installation years validation
first_year = attrs.get("first_installation_year")
last_year = attrs.get("last_installation_year")
if first_year and last_year and first_year > last_year:
raise serializers.ValidationError(
"First installation year cannot be after last installation year"
)
raise serializers.ValidationError("First installation year cannot be after last installation year")
return attrs
@@ -551,9 +483,7 @@ class RideModelFilterInputSerializer(serializers.Serializer):
search = serializers.CharField(required=False, allow_blank=True)
# Category filter
category = serializers.MultipleChoiceField(
choices=ModelChoices.get_ride_category_choices(), required=False
)
category = serializers.MultipleChoiceField(choices=ModelChoices.get_ride_category_choices(), required=False)
# Manufacturer filter
manufacturer_id = serializers.IntegerField(required=False)
@@ -576,20 +506,12 @@ class RideModelFilterInputSerializer(serializers.Serializer):
min_installations = serializers.IntegerField(required=False, min_value=0)
# Height filters
min_height_ft = serializers.DecimalField(
max_digits=6, decimal_places=2, required=False
)
max_height_ft = serializers.DecimalField(
max_digits=6, decimal_places=2, required=False
)
min_height_ft = serializers.DecimalField(max_digits=6, decimal_places=2, required=False)
max_height_ft = serializers.DecimalField(max_digits=6, decimal_places=2, required=False)
# Speed filters
min_speed_mph = serializers.DecimalField(
max_digits=5, decimal_places=2, required=False
)
max_speed_mph = serializers.DecimalField(
max_digits=5, decimal_places=2, required=False
)
min_speed_mph = serializers.DecimalField(max_digits=5, decimal_places=2, required=False)
max_speed_mph = serializers.DecimalField(max_digits=5, decimal_places=2, required=False)
# Ordering
ordering = serializers.ChoiceField(
@@ -621,18 +543,10 @@ class RideModelVariantCreateInputSerializer(serializers.Serializer):
description = serializers.CharField(allow_blank=True, default="")
# Variant-specific specifications
min_height_ft = serializers.DecimalField(
max_digits=6, decimal_places=2, required=False, allow_null=True
)
max_height_ft = serializers.DecimalField(
max_digits=6, decimal_places=2, required=False, allow_null=True
)
min_speed_mph = serializers.DecimalField(
max_digits=5, decimal_places=2, required=False, allow_null=True
)
max_speed_mph = serializers.DecimalField(
max_digits=5, decimal_places=2, required=False, allow_null=True
)
min_height_ft = serializers.DecimalField(max_digits=6, decimal_places=2, required=False, allow_null=True)
max_height_ft = serializers.DecimalField(max_digits=6, decimal_places=2, required=False, allow_null=True)
min_speed_mph = serializers.DecimalField(max_digits=5, decimal_places=2, required=False, allow_null=True)
max_speed_mph = serializers.DecimalField(max_digits=5, decimal_places=2, required=False, allow_null=True)
# Distinguishing features
distinguishing_features = serializers.CharField(allow_blank=True, default="")
@@ -644,18 +558,14 @@ class RideModelVariantCreateInputSerializer(serializers.Serializer):
max_height = attrs.get("max_height_ft")
if min_height and max_height and min_height > max_height:
raise serializers.ValidationError(
"Minimum height cannot be greater than maximum height"
)
raise serializers.ValidationError("Minimum height cannot be greater than maximum height")
# Speed range validation
min_speed = attrs.get("min_speed_mph")
max_speed = attrs.get("max_speed_mph")
if min_speed and max_speed and min_speed > max_speed:
raise serializers.ValidationError(
"Minimum speed cannot be greater than maximum speed"
)
raise serializers.ValidationError("Minimum speed cannot be greater than maximum speed")
return attrs
@@ -667,18 +577,10 @@ class RideModelVariantUpdateInputSerializer(serializers.Serializer):
description = serializers.CharField(allow_blank=True, required=False)
# Variant-specific specifications
min_height_ft = serializers.DecimalField(
max_digits=6, decimal_places=2, required=False, allow_null=True
)
max_height_ft = serializers.DecimalField(
max_digits=6, decimal_places=2, required=False, allow_null=True
)
min_speed_mph = serializers.DecimalField(
max_digits=5, decimal_places=2, required=False, allow_null=True
)
max_speed_mph = serializers.DecimalField(
max_digits=5, decimal_places=2, required=False, allow_null=True
)
min_height_ft = serializers.DecimalField(max_digits=6, decimal_places=2, required=False, allow_null=True)
max_height_ft = serializers.DecimalField(max_digits=6, decimal_places=2, required=False, allow_null=True)
min_speed_mph = serializers.DecimalField(max_digits=5, decimal_places=2, required=False, allow_null=True)
max_speed_mph = serializers.DecimalField(max_digits=5, decimal_places=2, required=False, allow_null=True)
# Distinguishing features
distinguishing_features = serializers.CharField(allow_blank=True, required=False)
@@ -690,18 +592,14 @@ class RideModelVariantUpdateInputSerializer(serializers.Serializer):
max_height = attrs.get("max_height_ft")
if min_height and max_height and min_height > max_height:
raise serializers.ValidationError(
"Minimum height cannot be greater than maximum height"
)
raise serializers.ValidationError("Minimum height cannot be greater than maximum height")
# Speed range validation
min_speed = attrs.get("min_speed_mph")
max_speed = attrs.get("max_speed_mph")
if min_speed and max_speed and min_speed > max_speed:
raise serializers.ValidationError(
"Minimum speed cannot be greater than maximum speed"
)
raise serializers.ValidationError("Minimum speed cannot be greater than maximum speed")
return attrs
@@ -713,9 +611,7 @@ class RideModelTechnicalSpecCreateInputSerializer(serializers.Serializer):
"""Input serializer for creating ride model technical specifications."""
ride_model_id = serializers.IntegerField()
spec_category = serializers.ChoiceField(
choices=ModelChoices.get_technical_spec_category_choices()
)
spec_category = serializers.ChoiceField(choices=ModelChoices.get_technical_spec_category_choices())
spec_name = serializers.CharField(max_length=100)
spec_value = serializers.CharField(max_length=255)
spec_unit = serializers.CharField(max_length=20, allow_blank=True, default="")
@@ -765,13 +661,9 @@ class RideModelPhotoUpdateInputSerializer(serializers.Serializer):
required=False,
)
is_primary = serializers.BooleanField(required=False)
photographer = serializers.CharField(
max_length=255, allow_blank=True, required=False
)
photographer = serializers.CharField(max_length=255, allow_blank=True, required=False)
source = serializers.CharField(max_length=255, allow_blank=True, required=False)
copyright_info = serializers.CharField(
max_length=255, allow_blank=True, required=False
)
copyright_info = serializers.CharField(max_length=255, allow_blank=True, required=False)
# === RIDE MODEL STATS SERIALIZERS ===
@@ -784,15 +676,9 @@ class RideModelStatsOutputSerializer(serializers.Serializer):
total_installations = serializers.IntegerField()
active_manufacturers = serializers.IntegerField()
discontinued_models = serializers.IntegerField()
by_category = serializers.DictField(
child=serializers.IntegerField(), help_text="Model counts by category"
)
by_category = serializers.DictField(child=serializers.IntegerField(), help_text="Model counts by category")
by_target_market = serializers.DictField(
child=serializers.IntegerField(), help_text="Model counts by target market"
)
by_manufacturer = serializers.DictField(
child=serializers.IntegerField(), help_text="Model counts by manufacturer"
)
recent_models = serializers.IntegerField(
help_text="Models created in the last 30 days"
)
by_manufacturer = serializers.DictField(child=serializers.IntegerField(), help_text="Model counts by manufacturer")
recent_models = serializers.IntegerField(help_text="Models created in the last 30 days")

View File

@@ -54,19 +54,11 @@ class ReviewUserSerializer(serializers.ModelSerializer):
"id": 456,
"username": "coaster_fan",
"display_name": "Coaster Fan",
"avatar_url": "https://example.com/avatar.jpg"
"avatar_url": "https://example.com/avatar.jpg",
},
"ride": {
"id": 789,
"name": "Steel Vengeance",
"slug": "steel-vengeance"
},
"park": {
"id": 101,
"name": "Cedar Point",
"slug": "cedar-point"
}
}
"ride": {"id": 789, "name": "Steel Vengeance", "slug": "steel-vengeance"},
"park": {"id": 101, "name": "Cedar Point", "slug": "cedar-point"},
},
)
]
)
@@ -191,8 +183,7 @@ class RideReviewStatsOutputSerializer(serializers.Serializer):
pending_reviews = serializers.IntegerField()
average_rating = serializers.FloatField(allow_null=True)
rating_distribution = serializers.DictField(
child=serializers.IntegerField(),
help_text="Count of reviews by rating (1-10)"
child=serializers.IntegerField(), help_text="Count of reviews by rating (1-10)"
)
recent_reviews = serializers.IntegerField()
@@ -200,20 +191,15 @@ class RideReviewStatsOutputSerializer(serializers.Serializer):
class RideReviewModerationInputSerializer(serializers.Serializer):
"""Input serializer for review moderation operations."""
review_ids = serializers.ListField(
child=serializers.IntegerField(),
help_text="List of review IDs to moderate"
)
review_ids = serializers.ListField(child=serializers.IntegerField(), help_text="List of review IDs to moderate")
action = serializers.ChoiceField(
choices=[
("publish", "Publish"),
("unpublish", "Unpublish"),
("delete", "Delete"),
],
help_text="Moderation action to perform"
help_text="Moderation action to perform",
)
moderation_notes = serializers.CharField(
required=False,
allow_blank=True,
help_text="Optional notes about the moderation action"
required=False, allow_blank=True, help_text="Optional notes about the moderation action"
)

Some files were not shown because too many files have changed in this diff Show More